manupa-arm commented on a change in pull request #10753:
URL: https://github.com/apache/tvm/pull/10753#discussion_r840519145



##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -260,9 +260,168 @@ class AOTOnDemandAllocator : public 
transform::DeviceAwareExprVisitor {
   std::vector<TensorType> return_ttypes_;
 };
 
-/*! \brief Code generator for AOT executor */
-class AOTExecutorCodegen : public MixedModeVisitor {
- protected:
+namespace {
+
+/*!
+ * \brief Utility function to convert a concrete integer to a PrimExpr.
+ * \param num the number to convert
+ * \return PrimExpr representing num
+ */
+inline PrimExpr ConstInt32(int32_t num) {
+  ICHECK_LE(num, std::numeric_limits<int>::max());
+  return tir::make_const(DataType::Int(32), static_cast<int>(num));
+}
+
+/*!
+ * \brief Emit a call to the C Device API.
+ * \param device_name Name of the device, used to prefix the function name.
+ * \param hook Name of the Device API function.
+ * \param context void* context arg passed to this API function.
+ */
+tir::Stmt MakeDeviceHookCall(const std::string& device_name, const 
std::string& hook,
+                             PrimExpr context) {
+  Array<String> sections = {"Device", device_name, hook};
+  String device_hook = ToCFunctionStyle(PrefixName(sections));
+
+  return tir::Evaluate(tir::Call(DataType::Int(32), 
tvm::tir::builtin::call_extern(),
+                                 {tvm::tir::StringImm(device_hook), context}));
+}
+}  // namespace
+
+class AOTCallGenerator {
+ public:
+  explicit AOTCallGenerator(std::string func_name)
+      : func_name_{func_name}, args_{tvm::tir::StringImm(func_name)} {}
+
+  tir::Var PushArg(PrimExpr arg) {
+    if (!arg->IsInstance<tir::VarNode>()) {
+      arg = MakeLetBind(arg);
+    }
+    args_.push_back(arg);
+    return Downcast<tir::Var>(arg);
+  }
+
+  void PushStackDLTensor(const TensorType& ttype, PrimExpr data) {
+    auto dltensor_var = MakeLetBind(StackAlloca("array", 1));
+    auto shape_var = MakeLetBind(StackAlloca("shape", ttype->shape.size()));
+
+    // Populate DLTensor.data
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), 
tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, tir::builtin::kArrData, 
data})));
+
+    // Populate DLTensor.device
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), 
tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, 
tir::builtin::kArrDeviceType, kDLCPU})));
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), 
tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, 
tir::builtin::kArrDeviceId, 0})));
+
+    // Populate DLTensor.ndim
+    prep_stmts_.push_back(tir::Evaluate(tvm::tir::Call(
+        DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+        {dltensor_var, 0, tir::builtin::kArrNDim, 
static_cast<int32_t>(ttype->shape.size())})));
+
+    // Populate DLTensor.dtype
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), 
tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, 
tir::builtin::kArrTypeCode,
+                                      IntImm(DataType(kDLUInt, 8, 1), 
ttype->dtype.code())})));
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), 
tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, 
tir::builtin::kArrTypeBits,
+                                      IntImm(DataType(kDLUInt, 8, 1), 
ttype->dtype.bits())})));
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), 
tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, 
tir::builtin::kArrTypeLanes,
+                                      IntImm(DataType(kDLUInt, 16, 1), 
ttype->dtype.lanes())})));
+
+    // Populate DLTensor.shape
+    for (size_t i = 0; i < ttype->shape.size(); ++i) {
+      prep_stmts_.push_back(tvm::tir::Store(
+          shape_var, IntImm(DataType(kDLInt, 64, 1), 
Downcast<IntImm>(ttype->shape[i])->value),
+          IntImm(DataType(kDLUInt, 64, 1), i), tir::const_true()));
+    }
+
+    prep_stmts_.push_back(
+        tir::Evaluate(tvm::tir::Call(DataType::Handle(), 
tvm::tir::builtin::tvm_struct_set(),
+                                     {dltensor_var, 0, 
tir::builtin::kArrShape, shape_var})));
+
+    // Populate DLTensor.strides. DNS -- TODO actually pull correct byte_offset
+    prep_stmts_.push_back(tir::Evaluate(tvm::tir::Call(
+        DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+        {dltensor_var, 0, tir::builtin::kArrStrides, IntImm(DataType(kDLUInt, 
64, 1), 0)})));
+
+    // Populate DLTensor.byte_offset. DNS -- TODO actually pull correct 
byte_offset
+    prep_stmts_.push_back(tir::Evaluate(tvm::tir::Call(
+        DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+        {dltensor_var, 0, tir::builtin::kArrByteOffset, 
IntImm(DataType(kDLUInt, 64, 1), 0)})));
+
+    args_.push_back(dltensor_var);
+  }
+
+  void PushStackDLTensors(const Expr& expr, std::vector<tir::Var> sids) {
+    const TupleNode* t = expr.as<TupleNode>();
+    if (t != nullptr) {
+      CHECK_EQ(sids.size(), t->fields.size()) << "Relay tuple does not map 1:1 
into TIR; AOT can't "
+                                                 "handle this type of Relay 
Expr in a CallNode.";
+      for (size_t i = 0; i < sids.size(); i++) {
+        PushStackDLTensor(Downcast<TensorType>(t->fields[i]->checked_type()), 
sids[i]);
+      }
+    } else {
+      PushStackDLTensor(Downcast<TensorType>(expr->checked_type()), sids[0]);
+    }
+  }
+
+  tir::Stmt GenerateUnpacked(std::string device_name, PrimExpr device_context) 
{
+    auto make_call = [this] {
+      return tir::Evaluate(
+          tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), 
args_));
+    };
+    if (device_context.defined()) {
+      tir::Var context_var = PushArg(device_context);
+      return Generate(tir::SeqStmt({
+          MakeDeviceHookCall(device_name, "Open", context_var),
+          make_call(),
+          MakeDeviceHookCall(device_name, "Close", context_var),
+      }));
+    } else {
+      return Generate(make_call());
+    }
+  }
+
+  tir::Stmt GeneratePacked() {
+    return Generate(
+        tir::Evaluate(tvm::tir::Call(DataType::Int(32), 
tir::builtin::tvm_call_packed(), args_)));
+  }
+
+  tir::Stmt GenerateCPacked() {
+    // call_cpacked calling convention does not use a context
+    PushArg(tir::make_zero(DataType::Handle()));
+    return Generate(
+        tir::Evaluate(tvm::tir::Call(DataType::Int(32), 
tir::builtin::tvm_call_cpacked(), args_)));
+  }
+
+ private:
+  tir::Stmt Generate(tir::Stmt call_stmts) {
+    tir::Stmt body = tir::SeqStmt::Flatten(prep_stmts_, call_stmts);
+
+    for (auto bind : let_binds_) {
+      body = tir::LetStmt(bind.first, bind.second, body);

Review comment:
       Why do we need a Let binding here? Can't we just bind the argument directly without introducing a Let node here?
   
   
https://github.com/apache/tvm/blob/95df0eb1461718d9d1453d2ba4beb9441c5cab3c/src/tir/transforms/arg_binder.h#L74-L75

##########
File path: src/target/llvm/llvm_module.cc
##########
@@ -527,6 +527,46 @@ TVM_REGISTER_GLOBAL("codegen.codegen_blob")
       return runtime::Module(n);
     });
 
+runtime::Module CreateLLVMCppMetadataModule(runtime::metadata::Metadata 
metadata, Target target,
+                                            tvm::relay::Runtime runtime) {
+  InitializeLLVM();
+  auto tm = GetLLVMTargetMachine(target);
+  bool system_lib = runtime->GetAttr<Bool>("system-lib").value_or(Bool(false));
+  auto ctx = std::make_shared<llvm::LLVMContext>();
+  std::unique_ptr<CodeGenCPU> cg{new CodeGenCPU()};
+
+  cg->Init("TVMMetadataMod", tm.get(), ctx.get(), system_lib, system_lib,
+           false /* target_c_runtime */);
+
+  cg->DefineMetadata(metadata);
+  auto mod = cg->Finish();
+  mod->addModuleFlag(llvm::Module::Warning, "tvm_target",
+                     llvm::MDString::get(*ctx, LLVMTargetToString(target)));
+  mod->addModuleFlag(llvm::Module::Override, "Debug Info Version", 
llvm::DEBUG_METADATA_VERSION);
+
+  if (tm->getTargetTriple().isOSDarwin()) {
+    mod->addModuleFlag(llvm::Module::Override, "Dwarf Version", 2);
+  }
+
+  std::string verify_errors_storage;
+  llvm::raw_string_ostream verify_errors(verify_errors_storage);
+  LOG_IF(FATAL, llvm::verifyModule(*mod, &verify_errors))
+      << "LLVM module verification failed with the following errors: \n"
+      << verify_errors.str();
+
+  // std::string tmp;
+  // llvm::raw_string_ostream stream(tmp);
+  // mod->print(stream, nullptr);
+  // LOG(INFO) << "LLVM metadata IR: " << stream.str();

Review comment:
       Please remove this commented-out debug-printing code.

##########
File path: src/target/llvm/codegen_cpu.cc
##########
@@ -802,10 +803,14 @@ llvm::Value* CodeGenCPU::GetPackedFuncHandle(const 
std::string& fname) {
 
 CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const 
Array<PrimExpr>& args,
                                                          const DataType& 
r_type,
-                                                         const int64_t begin, 
const int64_t end) {
+                                                         const int64_t begin, 
const int64_t end,
+                                                         bool 
use_string_lookup) {
   PackedCall pc;
   std::string func_name = args[0].as<StringImmNode>()->value;
-  llvm::Value* handle = GetPackedFuncHandle(func_name);
+  llvm::Value* handle = nullptr;
+  if (use_string_lookup) {

Review comment:
       I don't think we need to introduce `handle` just yet.
   
   Shall we merge this into the single if/else block further down, so it's clear what happens when string-based function lookup is not used?

##########
File path: src/target/llvm/codegen_llvm.cc
##########
@@ -1399,9 +1411,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const 
BufferLoadNode* op) {
 }
 
 llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
+  //  LOG(INFO) << "Visit Call:" << GetRef<Call>(op);

Review comment:
       Please remove this commented-out LOG line.

##########
File path: src/target/metadata_module.cc
##########
@@ -144,6 +144,12 @@ static runtime::Module CreateCppMetadataModule(
         auto metadata_module = 
CreateCSourceCppMetadataModule(runtime_metadata);
         metadata_module->Import(target_module);
         target_module = metadata_module;
+#ifdef TVM_LLVM_VERSION

Review comment:
       Why is this `#ifdef TVM_LLVM_VERSION` needed here?

##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -260,9 +260,168 @@ class AOTOnDemandAllocator : public 
transform::DeviceAwareExprVisitor {
   std::vector<TensorType> return_ttypes_;
 };
 
-/*! \brief Code generator for AOT executor */
-class AOTExecutorCodegen : public MixedModeVisitor {
- protected:
+namespace {
+
+/*!
+ * \brief Utility function to convert a concrete integer to a PrimExpr.
+ * \param num the number to convert
+ * \return PrimExpr representing num
+ */
+inline PrimExpr ConstInt32(int32_t num) {
+  ICHECK_LE(num, std::numeric_limits<int>::max());
+  return tir::make_const(DataType::Int(32), static_cast<int>(num));
+}
+
+/*!
+ * \brief Emit a call to the C Device API.
+ * \param device_name Name of the device, used to prefix the function name.
+ * \param hook Name of the Device API function.
+ * \param context void* context arg passed to this API function.
+ */
+tir::Stmt MakeDeviceHookCall(const std::string& device_name, const 
std::string& hook,
+                             PrimExpr context) {
+  Array<String> sections = {"Device", device_name, hook};
+  String device_hook = ToCFunctionStyle(PrefixName(sections));
+
+  return tir::Evaluate(tir::Call(DataType::Int(32), 
tvm::tir::builtin::call_extern(),
+                                 {tvm::tir::StringImm(device_hook), context}));
+}
+}  // namespace

Review comment:
       Why this anonymous (unnamed) namespace?

##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -260,9 +260,168 @@ class AOTOnDemandAllocator : public 
transform::DeviceAwareExprVisitor {
   std::vector<TensorType> return_ttypes_;
 };
 
-/*! \brief Code generator for AOT executor */
-class AOTExecutorCodegen : public MixedModeVisitor {
- protected:
+namespace {
+
+/*!
+ * \brief Utility function to convert a concrete integer to a PrimExpr.
+ * \param num the number to convert
+ * \return PrimExpr representing num
+ */
+inline PrimExpr ConstInt32(int32_t num) {

Review comment:
       Maybe not needed for this PR, but should we move this to `tir::make_const_int32`?

##########
File path: src/target/llvm/codegen_llvm.h
##########
@@ -389,6 +406,16 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const 
PrimExpr&)>,
                                              unsigned int 
shared_address_space, int alignment,
                                              llvm::GlobalValue::LinkageTypes 
linkage);
 
+  llvm::Argument* GetArg(const llvm::Function* function, int i) const {

Review comment:
       Should we add docs for this?

##########
File path: src/target/llvm/codegen_cpu.cc
##########
@@ -822,14 +827,46 @@ CodeGenCPU::PackedCall 
CodeGenCPU::MakeCallPackedLowered(const Array<PrimExpr>&
   TypedPointer ret_tcode =
       CreateBufferPtr(stack_tcode, DataType::Int(32), {ConstInt32(end)}, 
DataType::Int(32));
 
+  llvm::FunctionType* callee_ftype = nullptr;
+  llvm::Value* callee_value = nullptr;
+  std::vector<llvm::Value*> call_args;
+
+  if (use_string_lookup) {
+    callee_ftype = ftype_tvm_func_call_;
+    callee_value = RuntimeTVMFuncCall();
+    call_args.push_back(handle);
+  } else {
+    callee_ftype = ftype_tvm_backend_packed_c_func_;
+    callee_value = module_->getFunction(func_name);
+    if (callee_value == nullptr) {
+      callee_value =
+          llvm::Function::Create(ftype_tvm_backend_packed_c_func_, 
llvm::Function::ExternalLinkage,
+                                 func_name, module_.get());
+    }
+  }
+
+  if (use_string_lookup) {

Review comment:
       Here too, let's merge these if/else blocks.

##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -260,9 +260,168 @@ class AOTOnDemandAllocator : public 
transform::DeviceAwareExprVisitor {
   std::vector<TensorType> return_ttypes_;
 };
 
-/*! \brief Code generator for AOT executor */
-class AOTExecutorCodegen : public MixedModeVisitor {
- protected:
+namespace {
+
+/*!
+ * \brief Utility function to convert a concrete integer to a PrimExpr.
+ * \param num the number to convert
+ * \return PrimExpr representing num
+ */
+inline PrimExpr ConstInt32(int32_t num) {
+  ICHECK_LE(num, std::numeric_limits<int>::max());
+  return tir::make_const(DataType::Int(32), static_cast<int>(num));
+}
+
+/*!
+ * \brief Emit a call to the C Device API.
+ * \param device_name Name of the device, used to prefix the function name.
+ * \param hook Name of the Device API function.
+ * \param context void* context arg passed to this API function.
+ */
+tir::Stmt MakeDeviceHookCall(const std::string& device_name, const 
std::string& hook,
+                             PrimExpr context) {
+  Array<String> sections = {"Device", device_name, hook};
+  String device_hook = ToCFunctionStyle(PrefixName(sections));
+
+  return tir::Evaluate(tir::Call(DataType::Int(32), 
tvm::tir::builtin::call_extern(),
+                                 {tvm::tir::StringImm(device_hook), context}));
+}
+}  // namespace
+
+class AOTCallGenerator {

Review comment:
       Docs: we need docs for this, but I think this part is still WIP, as I did not see who uses it.

##########
File path: src/target/llvm/codegen_cpu.cc
##########
@@ -914,6 +952,321 @@ llvm::Value* CodeGenCPU::RuntimeTVMParallelBarrier() {
   return GetContextPtr(gv_tvm_parallel_barrier_);
 }
 
+/*! \brief Defines LLVM Types for each Metadata member type. */
+struct MetadataLlvmTypes {
+  llvm::Type* t_float64;
+  llvm::Type* t_uint8;
+  llvm::Type* t_int64;
+  llvm::Type* t_bool;
+  llvm::Type* t_cstring;
+  llvm::Type* t_void_p;
+  llvm::StructType* t_data_type;
+
+  /*! \brief Maps a MetadataBase subclass' type_key to its corresponding LLVM 
StructType. */
+  ::std::unordered_map<std::string, llvm::StructType*> structs_by_type_key;
+};
+
+class MetadataTypeDefiner : public AttrVisitor {
+ public:
+  MetadataTypeDefiner(llvm::LLVMContext* ctx, struct MetadataLlvmTypes* 
llvm_types)
+      : ctx_{ctx}, llvm_types_{llvm_types} {}
+
+  void Visit(const char* key, double* value) final {
+    elements_.emplace_back(llvm_types_->t_float64);
+  }
+  void Visit(const char* key, int64_t* value) final {
+    elements_.emplace_back(llvm_types_->t_int64);
+  }
+  void Visit(const char* key, uint64_t* value) final {
+    elements_.emplace_back(llvm_types_->t_int64);
+  }
+  void Visit(const char* key, int* value) final { 
elements_.emplace_back(llvm_types_->t_int64); }
+  void Visit(const char* key, bool* value) final { 
elements_.emplace_back(llvm_types_->t_bool); }
+  void Visit(const char* key, std::string* value) final {
+    elements_.emplace_back(llvm_types_->t_cstring);
+  }
+  void Visit(const char* key, void** value) final { 
elements_.emplace_back(llvm_types_->t_void_p); }
+  void Visit(const char* key, DataType* value) final {
+    elements_.emplace_back(llvm_types_->t_data_type);
+  }
+  void Visit(const char* key, runtime::NDArray* value) final {
+    CHECK(false) << "Do not support serializing NDArray";
+  }
+
+ private:
+  void VisitMetadataBase(runtime::metadata::MetadataBase metadata) {
+    elements_.emplace_back(llvm::PointerType::getUnqual(
+        llvm::StructType::create(*ctx_, metadata->get_c_struct_name())));
+    if (visited_.find(metadata->get_c_struct_name()) != visited_.end()) {
+      return;
+    }
+
+    if (to_visit_.find(metadata->get_c_struct_name()) != to_visit_.end()) {
+      return;
+    }
+    to_visit_[metadata->get_c_struct_name()] = metadata;
+  }
+
+ public:
+  using MetadataKind = runtime::metadata::MetadataKind;
+
+  void VisitArray(const runtime::metadata::MetadataArrayNode* arr) {
+    switch (arr->kind) {
+      case MetadataKind::kUint64:  // LLVM encodes signed and unsigned with 
same types.
+      case MetadataKind::kInt64:
+        
elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_int64));
+        break;
+      case MetadataKind::kBool:
+        
elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_bool));
+        break;
+      case MetadataKind::kString:
+        
elements_.emplace_back(llvm::PointerType::getUnqual(llvm_types_->t_cstring));
+        break;
+      case MetadataKind::kHandle:
+        CHECK(false) << "Do not support handle";
+        break;
+      case MetadataKind::kMetadata:
+        elements_.emplace_back(
+            
llvm::PointerType::getUnqual(llvm_types_->structs_by_type_key[arr->type_key]));
+        break;
+      default:
+        CHECK(false) << "Unsupported metadata kind " << arr->kind;
+        break;
+    }
+  }
+
+  void Visit(const char* key, ObjectRef* value) final {
+    const runtime::metadata::MetadataArrayNode* arr =
+        value->as<runtime::metadata::MetadataArrayNode>();
+    if (arr != nullptr) {
+      VisitArray(arr);
+    } else {
+      elements_.emplace_back(
+          
llvm::PointerType::getUnqual(llvm_types_->structs_by_type_key[(*value)->GetTypeKey()]));
+    }
+  }
+
+  void DefineType(runtime::metadata::MetadataBase metadata) {
+    ReflectionVTable::Global()->VisitAttrs(metadata.operator->(), this);
+    LOG(INFO) << "Created type for " << metadata->GetTypeKey() << ":";
+    for (auto e : elements_) {
+      std::string value;
+      llvm::raw_string_ostream os(value);
+      e->print(os, true);
+      //      LOG(INFO) << " - " << e << ", tyid=" << e->getTypeID() << " == " 
<< value;

Review comment:
       Please remove this commented-out LOG line.

##########
File path: src/target/llvm/codegen_llvm.cc
##########
@@ -1399,9 +1411,13 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const 
BufferLoadNode* op) {
 }
 
 llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
+  //  LOG(INFO) << "Visit Call:" << GetRef<Call>(op);
   if (auto* ptr_op = op->op.as<OpNode>()) {
     auto call_op = GetRef<Op>(ptr_op);
-    if (op->op.same_as(builtin_call_extern_) || 
op->op.same_as(builtin_call_pure_extern_)) {
+    if (op->op.same_as(builtin_lookup_param_)) {
+      //      return llvm::ConstantInt::get(t_void_p_, 0);

Review comment:
       Please remove this commented-out return statement.

##########
File path: src/target/metadata_utils.cc
##########
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/target/metadata_utils.cc
+ * \brief Defines utility functions and classes for emitting metadata.
+ */
+#include "metadata_utils.h"

Review comment:
       Please add docs for all the functions introduced here.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to