areusch commented on code in PR #10753:
URL: https://github.com/apache/tvm/pull/10753#discussion_r843339941


##########
src/relay/backend/aot_executor_codegen.cc:
##########
@@ -260,9 +260,168 @@ class AOTOnDemandAllocator : public 
transform::DeviceAwareExprVisitor {
   std::vector<TensorType> return_ttypes_;
 };
 
-/*! \brief Code generator for AOT executor */
-class AOTExecutorCodegen : public MixedModeVisitor {
- protected:
+namespace {
+
+/*!
+ * \brief Utility function to convert a concrete integer to a PrimExpr.
+ * \param num the number to convert
+ * \return PrimExpr representing num
+ */
+inline PrimExpr ConstInt32(int32_t num) {
+  ICHECK_LE(num, std::numeric_limits<int>::max());
+  return tir::make_const(DataType::Int(32), static_cast<int>(num));
+}
+
+/*!
+ * \brief Emit a call to the C Device API.
+ * \param device_name Name of the device, used to prefix the function name.
+ * \param hook Name of the Device API function.
+ * \param context void* context arg passed to this API function.
+ */
+tir::Stmt MakeDeviceHookCall(const std::string& device_name, const 
std::string& hook,
+                             PrimExpr context) {
+  Array<String> sections = {"Device", device_name, hook};
+  String device_hook = ToCFunctionStyle(PrefixName(sections));
+
+  return tir::Evaluate(tir::Call(DataType::Int(32), 
tvm::tir::builtin::call_extern(),
+                                 {tvm::tir::StringImm(device_hook), context}));
+}
+}  // namespace

Review Comment:
   This is to make it a "static" function, i.e. it has internal linkage and 
cannot be linked against from elsewhere.



##########
src/relay/backend/aot_executor_codegen.cc:
##########
@@ -260,9 +260,168 @@ class AOTOnDemandAllocator : public 
transform::DeviceAwareExprVisitor {
   std::vector<TensorType> return_ttypes_;
 };
 
-/*! \brief Code generator for AOT executor */
-class AOTExecutorCodegen : public MixedModeVisitor {
- protected:
+namespace {
+
+/*!
+ * \brief Utility function to convert a concrete integer to a PrimExpr.
+ * \param num the number to convert
+ * \return PrimExpr representing num
+ */
+inline PrimExpr ConstInt32(int32_t num) {

Review Comment:
   Changed to use `tir::make_const`, which I think handles it.



##########
src/target/llvm/codegen_cpu.cc:
##########
@@ -822,14 +827,46 @@ CodeGenCPU::PackedCall 
CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>&
   TypedPointer ret_tcode =
       CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(end)}, 
DataType::Int(32));
 
+  llvm::FunctionType* callee_ftype = nullptr;
+  llvm::Value* callee_value = nullptr;
+  std::vector<llvm::Value*> call_args;
+
+  if (use_string_lookup) {
+    callee_ftype = ftype_tvm_func_call_;
+    callee_value = RuntimeTVMFuncCall();
+    call_args.push_back(handle);
+  } else {
+    callee_ftype = ftype_tvm_backend_packed_c_func_;
+    callee_value = module_->getFunction(func_name);
+    if (callee_value == nullptr) {
+      callee_value =
+          llvm::Function::Create(ftype_tvm_backend_packed_c_func_, 
llvm::Function::ExternalLinkage,
+                                 func_name, module_.get());
+    }
+  }
+
+  if (use_string_lookup) {

Review Comment:
   done



##########
src/target/metadata_utils.cc:
##########
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/target/metadata_utils.cc
+ * \brief Defines utility functions and classes for emitting metadata.
+ */
+#include "metadata_utils.h"

Review Comment:
   done (let me know if more are needed)



##########
src/target/metadata_module.cc:
##########
@@ -144,6 +144,12 @@ static runtime::Module CreateCppMetadataModule(
         auto metadata_module = 
CreateCSourceCppMetadataModule(runtime_metadata);
         metadata_module->Import(target_module);
         target_module = metadata_module;
+#ifdef TVM_LLVM_VERSION

Review Comment:
   `TVM_LLVM_VERSION` being defined indicates USE_LLVM is ON. Added a comment.



##########
src/target/llvm/llvm_module.cc:
##########
@@ -527,6 +527,46 @@ TVM_REGISTER_GLOBAL("codegen.codegen_blob")
       return runtime::Module(n);
     });
 
+runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata 
metadata, Target target,
+                                            tvm::relay::Runtime runtime) {
+  InitializeLLVM();
+  auto tm = GetLLVMTargetMachine(target);
+  bool system_lib = runtime->GetAttr<Bool>("system-lib").value_or(Bool(false));
+  auto ctx = std::make_shared<llvm::LLVMContext>();
+  std::unique_ptr<CodeGenCPU> cg{new CodeGenCPU()};
+
+  cg->Init("TVMMetadataMod", tm.get(), ctx.get(), system_lib, system_lib,
+           false /* target_c_runtime */);
+
+  cg->DefineMetadata(metadata);
+  auto mod = cg->Finish();
+  mod->addModuleFlag(llvm::Module::Warning, "tvm_target",
+                     llvm::MDString::get(*ctx, LLVMTargetToString(target)));
+  mod->addModuleFlag(llvm::Module::Override, "Debug Info Version", 
llvm::DEBUG_METADATA_VERSION);
+
+  if (tm->getTargetTriple().isOSDarwin()) {
+    mod->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2);
+  }
+
+  std::string verify_errors_storage;
+  llvm::raw_string_ostream verify_errors(verify_errors_storage);
+  LOG_IF(FATAL, llvm::verifyModule(*mod, &verify_errors))
+      << "LLVM module verification failed with the following errors: \n"
+      << verify_errors.str();
+
+  // std::string tmp;
+  // llvm::raw_string_ostream stream(tmp);
+  // mod->print(stream, nullptr);
+  // LOG(INFO) << "LLVM metadata IR: " << stream.str();

Review Comment:
   done



##########
src/target/llvm/codegen_cpu.cc:
##########
@@ -914,6 +952,321 @@ llvm::Value* CodeGenCPU::RuntimeTVMParallelBarrier() {
   return GetContextPtr(gv_tvm_parallel_barrier_);
 }
 
+/*! \brief Defines LLVM Types for each Metadata member type. */
+struct MetadataLlvmTypes {
+  llvm::Type* t_float64;
+  llvm::Type* t_uint8;
+  llvm::Type* t_int64;
+  llvm::Type* t_bool;
+  llvm::Type* t_cstring;
+  llvm::Type* t_void_p;
+  llvm::StructType* t_data_type;
+
+  /*! \brief Maps a MetadataBase subclass' type_key to its corresponding LLVM 
StructType. */
+  ::std::unordered_map<std::string, llvm::StructType*> structs_by_type_key;
+};
+
+class MetadataTypeDefiner : public AttrVisitor {
+ public:
+  MetadataTypeDefiner(llvm::LLVMContext* ctx, struct MetadataLlvmTypes* 
llvm_types)
+      : ctx_{ctx}, llvm_types_{llvm_types} {}
+
+  void Visit(const char* key, double* value) final {
+    elements_.emplace_back(llvm_types_->t_float64);
+  }
+  void Visit(const char* key, int64_t* value) final {
+    elements_.emplace_back(llvm_types_->t_int64);
+  }
+  void Visit(const char* key, uint64_t* value) final {
+    elements_.emplace_back(llvm_types_->t_int64);
+  }
+  void Visit(const char* key, int* value) final { 
elements_.emplace_back(llvm_types_->t_int64); }
+  void Visit(const char* key, bool* value) final { 
elements_.emplace_back(llvm_types_->t_bool); }
+  void Visit(const char* key, std::string* value) final {
+    elements_.emplace_back(llvm_types_->t_cstring);
+  }
+  void Visit(const char* key, void** value) final { 
elements_.emplace_back(llvm_types_->t_void_p); }
+  void Visit(const char* key, DataType* value) final {
+    elements_.emplace_back(llvm_types_->t_data_type);
+  }
+  void Visit(const char* key, runtime::NDArray* value) final {
+    CHECK(false) << "Do not support serializing NDArray";
+  }
+
+ private:
+  void VisitMetadataBase(runtime::metadata::MetadataBase metadata) {
+    elements_.emplace_back(llvm::PointerType::getUnqual(
+        llvm::StructType::create(*ctx_, metadata->get_c_struct_name())));
+    if (visited_.find(metadata->get_c_struct_name()) != visited_.end()) {
+      return;
+    }
+
+    if (to_visit_.find(metadata->get_c_struct_name()) != to_visit_.end()) {
+      return;
+    }
+    to_visit_[metadata->get_c_struct_name()] = metadata;
+  }
+
+ public:
+  using MetadataKind = runtime::metadata::MetadataKind;
+
+  void VisitArray(const runtime::metadata::MetadataArrayNode* arr) {
+    switch (arr->kind) {
+      case MetadataKind::kUint64:  // LLVM encodes signed and unsigned with 
same types.
+      case MetadataKind::kInt64:
+        
elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_int64));
+        break;
+      case MetadataKind::kBool:
+        
elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_bool));
+        break;
+      case MetadataKind::kString:
+        
elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_cstring));
+        break;
+      case MetadataKind::kHandle:
+        CHECK(false) << "Do not support handle";
+        break;
+      case MetadataKind::kMetadata:
+        elements_.emplace_back(
+            
llvm::PointerType::getUnqual(llvm_types_->structs_by_type_key[arr->type_key]));
+        break;
+      default:
+        CHECK(false) << "Unsupported metadata kind " << arr->kind;
+        break;
+    }
+  }
+
+  void Visit(const char* key, ObjectRef* value) final {
+    const runtime::metadata::MetadataArrayNode* arr =
+        value->as<runtime::metadata::MetadataArrayNode>();
+    if (arr != nullptr) {
+      VisitArray(arr);
+    } else {
+      elements_.emplace_back(
+          
llvm::PointerType::getUnqual(llvm_types_->structs_by_type_key[(*value)->GetTypeKey()]));
+    }
+  }
+
+  void DefineType(runtime::metadata::MetadataBase metadata) {
+    ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this);
+    LOG(INFO) << "Created type for " << metadata->GetTypeKey() << ":";
+    for (auto e : elements_) {
+      std::string value;
+      llvm::raw_string_ostream os(value);
+      e->print(os, true);
+      //      LOG(INFO) << " - " << e << ", tyid=" << e->getTypeID() << " == " 
<< value;

Review Comment:
   done



##########
src/relay/backend/aot_executor_codegen.cc:
##########
@@ -260,9 +260,168 @@ class AOTOnDemandAllocator : public 
transform::DeviceAwareExprVisitor {
   std::vector<TensorType> return_ttypes_;
 };
 
-/*! \brief Code generator for AOT executor */
-class AOTExecutorCodegen : public MixedModeVisitor {
- protected:
+namespace {
+
+/*!
+ * \brief Utility function to convert a concrete integer to a PrimExpr.
+ * \param num the number to convert
+ * \return PrimExpr representing num
+ */
+inline PrimExpr ConstInt32(int32_t num) {
+  ICHECK_LE(num, std::numeric_limits<int>::max());
+  return tir::make_const(DataType::Int(32), static_cast<int>(num));
+}
+
+/*!
+ * \brief Emit a call to the C Device API.
+ * \param device_name Name of the device, used to prefix the function name.
+ * \param hook Name of the Device API function.
+ * \param context void* context arg passed to this API function.
+ */
+tir::Stmt MakeDeviceHookCall(const std::string& device_name, const 
std::string& hook,
+                             PrimExpr context) {
+  Array<String> sections = {"Device", device_name, hook};
+  String device_hook = ToCFunctionStyle(PrefixName(sections));
+
+  return tir::Evaluate(tir::Call(DataType::Int(32), 
tvm::tir::builtin::call_extern(),
+                                 {tvm::tir::StringImm(device_hook), context}));
+}
+}  // namespace
+
+class AOTCallGenerator {

Review Comment:
   Good point--this is cruft from the old approach, before I started relying on 
LegalizePackedCalls/LowerTVMBuiltin, so I can remove it now :)



##########
src/target/llvm/codegen_cpu.cc:
##########
@@ -802,10 +803,14 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const 
std::string& fname) {
 
 CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const 
Array<PrimExpr>& args,
                                                          const DataType& 
r_type,
-                                                         const int64_t begin, 
const int64_t end) {
+                                                         const int64_t begin, 
const int64_t end,
+                                                         bool 
use_string_lookup) {
   PackedCall pc;
   std::string func_name = args[0].as<StringImmNode>()->value;
-  llvm::Value* handle = GetPackedFuncHandle(func_name);
+  llvm::Value* handle = nullptr;
+  if (use_string_lookup) {

Review Comment:
   done!



##########
src/target/llvm/codegen_llvm.cc:
##########
@@ -1399,9 +1411,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const 
BufferLoadNode* op) {
 }
 
 llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
+  //  LOG(INFO) << "Visit Call:" << GetRef<Call>(op);
   if (auto* ptr_op = op->op.as<OpNode>()) {
     auto call_op = GetRef<Op>(ptr_op);
-    if (op->op.same_as(builtin_call_extern_) || 
op->op.same_as(builtin_call_pure_extern_)) {
+    if (op->op.same_as(builtin_lookup_param_)) {
+      //      return llvm::ConstantInt::get(t_void_p_, 0);

Review Comment:
   done



##########
src/target/llvm/codegen_llvm.cc:
##########
@@ -1399,9 +1411,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const 
BufferLoadNode* op) {
 }
 
 llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
+  //  LOG(INFO) << "Visit Call:" << GetRef<Call>(op);

Review Comment:
   done



##########
src/target/llvm/codegen_llvm.h:
##########
@@ -389,6 +406,16 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const 
PrimExpr&)>,
                                              unsigned int 
shared_address_space, int alignment,
                                              llvm::GlobalValue::LinkageTypes 
linkage);
 
+  llvm::Argument* GetArg(const llvm::Function* function, int i) const {

Review Comment:
   done



##########
src/relay/backend/aot_executor_codegen.cc:
##########
@@ -260,9 +260,168 @@ class AOTOnDemandAllocator : public 
transform::DeviceAwareExprVisitor {
   std::vector<TensorType> return_ttypes_;
 };
 
-/*! \brief Code generator for AOT executor */
-class AOTExecutorCodegen : public MixedModeVisitor {
- protected:
+namespace {
+
+/*!
+ * \brief Utility function to convert a concrete integer to a PrimExpr.
+ * \param num the number to convert
+ * \return PrimExpr representing num
+ */
+inline PrimExpr ConstInt32(int32_t num) {
+  ICHECK_LE(num, std::numeric_limits<int>::max());
+  return tir::make_const(DataType::Int(32), static_cast<int>(num));
+}
+
+/*!
+ * \brief Emit a call to the C Device API.
+ * \param device_name Name of the device, used to prefix the function name.
+ * \param hook Name of the Device API function.
+ * \param context void* context arg passed to this API function.
+ */
+tir::Stmt MakeDeviceHookCall(const std::string& device_name, const 
std::string& hook,
+                             PrimExpr context) {
+  Array<String> sections = {"Device", device_name, hook};
+  String device_hook = ToCFunctionStyle(PrefixName(sections));
+
+  return tir::Evaluate(tir::Call(DataType::Int(32), 
tvm::tir::builtin::call_extern(),
+                                 {tvm::tir::StringImm(device_hook), context}));
+}
+}  // namespace
+
+class AOTCallGenerator {
+ public:
+  explicit AOTCallGenerator(std::string func_name)
+      : func_name_{func_name}, args_{tvm::tir::StringImm(func_name)} {}
+
+  tir::Var PushArg(PrimExpr arg) {
+    if (!arg->IsInstance<tir::VarNode>()) {
+      arg = MakeLetBind(arg);
+    }
+    args_.push_back(arg);
+    return Downcast<tir::Var>(arg);
+  }
+
+  void PushStackDLTensor(const TensorType& ttype, PrimExpr data) {
+    auto dltensor_var = MakeLetBind(StackAlloca("array", 1));
+    auto shape_var = MakeLetBind(StackAlloca("shape", ttype->shape.size()));
+
+    // Populate DLTensor.data
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), 
tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, tir::builtin::kArrData, 
data})));
+
+    // Populate DLTensor.device
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), 
tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, 
tir::builtin::kArrDeviceType, kDLCPU})));
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), 
tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, 
tir::builtin::kArrDeviceId, 0})));
+
+    // Populate DLTensor.ndim
+    prep_stmts_.push_back(tir::Evaluate(tvm::tir::Call(
+        DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+        {dltensor_var, 0, tir::builtin::kArrNDim, 
static_cast<int32_t>(ttype->shape.size())})));
+
+    // Populate DLTensor.dtype
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), 
tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, 
tir::builtin::kArrTypeCode,
+                                      IntImm(DataType(kDLUInt, 8, 1), 
ttype->dtype.code())})));
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), 
tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, 
tir::builtin::kArrTypeBits,
+                                      IntImm(DataType(kDLUInt, 8, 1), 
ttype->dtype.bits())})));
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), 
tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, 
tir::builtin::kArrTypeLanes,
+                                      IntImm(DataType(kDLUInt, 16, 1), 
ttype->dtype.lanes())})));
+
+    // Populate DLTensor.shape
+    for (size_t i = 0; i < ttype->shape.size(); ++i) {
+      prep_stmts_.push_back(tvm::tir::Store(
+          shape_var, IntImm(DataType(kDLInt, 64, 1), 
Downcast<IntImm>(ttype->shape[i])->value),
+          IntImm(DataType(kDLUInt, 64, 1), i), tir::const_true()));
+    }
+
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), 
tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, 
tir::builtin::kArrShape, shape_var})));
+
+    // Populate DLTensor.strides. DNS -- TODO actually pull correct byte_offset
+    prep_stmts_.push_back(tir::Evaluate(tvm::tir::Call(
+        DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+        {dltensor_var, 0, tir::builtin::kArrStrides, IntImm(DataType(kDLUInt, 
64, 1), 0)})));
+
+    // Populate DLTensor.byte_offset. DNS -- TODO actually pull correct 
byte_offset
+    prep_stmts_.push_back(tir::Evaluate(tvm::tir::Call(
+        DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+        {dltensor_var, 0, tir::builtin::kArrByteOffset, 
IntImm(DataType(kDLUInt, 64, 1), 0)})));
+
+    args_.push_back(dltensor_var);
+  }
+
+  void PushStackDLTensors(const Expr& expr, std::vector<tir::Var> sids) {
+    const TupleNode* t = expr.as<TupleNode>();
+    if (t != nullptr) {
+      CHECK_EQ(sids.size(), t->fields.size()) << "Relay tuple does not map 1:1 
into TIR; AOT can't "
+                                                 "handle this type of Relay 
Expr in a CallNode.";
+      for (size_t i = 0; i < sids.size(); i++) {
+        PushStackDLTensor(Downcast<TensorType>(t->fields[i]->checked_type()), 
sids[i]);
+      }
+    } else {
+      PushStackDLTensor(Downcast<TensorType>(expr->checked_type()), sids[0]);
+    }
+  }
+
+  tir::Stmt GenerateUnpacked(std::string device_name, PrimExpr device_context) 
{
+    auto make_call = [this] {
+      return tir::Evaluate(
+          tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), 
args_));
+    };
+    if (device_context.defined()) {
+      tir::Var context_var = PushArg(device_context);
+      return Generate(tir::SeqStmt({
+          MakeDeviceHookCall(device_name, "Open", context_var),
+          make_call(),
+          MakeDeviceHookCall(device_name, "Close", context_var),
+      }));
+    } else {
+      return Generate(make_call());
+    }
+  }
+
+  tir::Stmt GeneratePacked() {
+    return Generate(
+        tir::Evaluate(tvm::tir::Call(DataType::Int(32), 
tir::builtin::tvm_call_packed(), args_)));
+  }
+
+  tir::Stmt GenerateCPacked() {
+    // call_cpacked calling convention does not use a context
+    PushArg(tir::make_zero(DataType::Handle()));
+    return Generate(
+        tir::Evaluate(tvm::tir::Call(DataType::Int(32), 
tir::builtin::tvm_call_cpacked(), args_)));
+  }
+
+ private:
+  tir::Stmt Generate(tir::Stmt call_stmts) {
+    tir::Stmt body = tir::SeqStmt::Flatten(prep_stmts_, call_stmts);
+
+    for (auto bind : let_binds_) {
+      body = tir::LetStmt(bind.first, bind.second, body);

Review Comment:
   Yeah, sorry, this part is cruft and should be deleted.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to