zhiics commented on a change in pull request #7369:
URL: https://github.com/apache/tvm/pull/7369#discussion_r566548070



##########
File path: src/relay/transforms/memory_alloc.cc
##########
@@ -0,0 +1,508 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/memory_alloc.cc
+ * \brief A pass for manifesting explicit memory allocations.
+ */
+
+#include <tvm/node/structural_equal.h>
+#include <tvm/node/structural_hash.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/transform.h>
+#include <tvm/support/logging.h>
+#include <tvm/target/target.h>
+
+#include <cstdint>
+#include <cstdio>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "../backend/compile_engine.h"
+#include "let_list.h"
+
+using namespace tvm::runtime;
+
+namespace tvm {
+namespace relay {
+
+extern bool IsDynamic(const Type& ty);
+extern Expr ToTupleType(const Type& ty, const std::vector<Expr>& exprs);
+extern std::vector<Expr> FromTupleType(const Type& type, const Expr& expr);
+extern std::vector<TensorType> FlattenTupleType(const Type& type);
+
+using AnalysisResultMap =
+    std::unordered_map<Expr, TVMContext, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
+
+inline Constant MakeConstant(int64_t value) {
+  auto tensor = NDArray::Empty({}, {kDLInt, 64, 1}, {kDLCPU, 0});
+  reinterpret_cast<int64_t*>(tensor->data)[0] = value;
+  return Constant(tensor);
+}
+
+inline Constant MakeConstant(const std::vector<int64_t>& value) {
+  auto tensor = NDArray::Empty({static_cast<int>(value.size())}, {kDLInt, 64, 1}, {kDLCPU, 0});
+  for (size_t i = 0; i < value.size(); i++) {
+    reinterpret_cast<int64_t*>(tensor->data)[i] = value[i];
+  }
+  return Constant(tensor);
+}
+
+inline Expr AllocTensor(const Expr& storage, tvm::relay::Expr shape, DataType dtype,
+                        Array<IndexExpr> assert_shape) {
+  auto f = runtime::Registry::Get("relay.op.memory._make.alloc_tensor");
+  CHECK(f != nullptr) << "unable to find alloc_tensor op";
+  auto offset = MakeConstant(0);
+  return (*f)(storage, offset, shape, dtype, assert_shape);
+}
+
+// A pass to check if the fused op contains only reshape ops.
+class CheckReshapeOnly : public ExprVisitor {
+ public:
+  CheckReshapeOnly()
+      : reshape_(Op::Get("reshape")),
+        contr_reshape_(Op::Get("contrib_reverse_reshape")),
+        dyn_reshape_(Op::Get("dyn.reshape")) {}
+
+  void VisitExpr_(const CallNode* cn) final {
+    if (!reshape_only) return;
+    if (cn->op != reshape_ && cn->op != contr_reshape_ && cn->op != dyn_reshape_) {
+      reshape_only = false;
+    }
+    for (auto arg : cn->args) ExprVisitor::VisitExpr(arg);
+  }
+
+  void VisitExpr_(const VarNode* vn) final {
+    if (!vn->checked_type_->IsInstance<TensorTypeNode>()) {
+      reshape_only = false;
+    }
+  }
+
+  const Op& reshape_;
+  const Op& contr_reshape_;
+  const Op& dyn_reshape_;
+  bool reshape_only{true};
+};
+
+// Check if the primitive function contains only reshape ops.
+bool IsReshapeOnly(const Expr& expr) {
+  auto check = CheckReshapeOnly();
+  check.VisitExpr(expr);
+  return check.reshape_only;
+}
+
+class DialectRewriter : public ExprMutator {
+ public:
+  DialectRewriter(const Target& target_host, const AnalysisResultMap& context_analysis_map)
+      : target_host_(target_host),
+        context_analysis_map_(context_analysis_map),
+        scopes_{LetList()},
+        device_copy_(runtime::Registry::Get("relay.op._make.device_copy")),
+        invoke_tvm_(runtime::Registry::Get("relay.op.vm.invoke_tvm_op")),
+        alloc_storage_(runtime::Registry::Get("relay.op.memory._make.alloc_storage")),
+        shape_func_(runtime::Registry::Get("relay.op.vm.shape_func")),
+        shape_of_(runtime::Registry::Get("relay.op.vm.shape_of")),
+        reshape_tensor_(runtime::Registry::Get("relay.op.vm.reshape_tensor")) {}
+
+  // Get the context of an expression.
+  TVMContext GetContext(const Expr& expr) const {
+    auto it = context_analysis_map_.find(expr);
+    CHECK(it != context_analysis_map_.end()) << "Cannot find expr in the context analysis map:\n"
+                                             << AsText(expr, false);
+    return it->second;
+  }
+
+  Function Rewrite(const Function& expr) {
+    auto ret = ExprMutator::Mutate(expr);
+    return Downcast<Function>(ret);
+  }
+
+  Expr VisitExpr_(const TupleNode* tn) final {
+    LetList& scope = scopes_.back();
+    Array<Expr> new_fields;
+    for (auto field : tn->fields) {
+      auto new_field = ExprMutator::Mutate(field);
+      if (new_field->IsInstance<ConstantNode>()) {
+        Var const_var("const", Type(nullptr));
+        new_field = scope.Push(const_var, new_field);
+      }
+      new_fields.push_back(new_field);
+    }
+    return Tuple(new_fields);
+  }
+
+  Expr VisitExpr_(const LetNode* ln) final {
+    scopes_.emplace_back();
+
+    const LetNode* let = ln;
+    Expr body;
+    while (let) {
+      auto new_value = ExprMutator::Mutate(let->value);
+      scopes_.back().Push(let->var, new_value);
+      body = let->body;
+      let = body.as<LetNode>();
+    }
+
+    CHECK(body.defined());
+    auto new_body = ExprMutator::Mutate(body);
+    auto ret = scopes_.back().Get(new_body);
+    scopes_.pop_back();
+    return ret;
+  }
+
+  Expr VisitExpr_(const CallNode* cn) final {
+    if (IsPrimitive(cn)) {
+      // Because we are in ANF we do not need to visit the arguments.
+      LetList& scope = scopes_.back();
+      std::vector<Expr> new_args;
+      for (const auto& it : cn->args) {
+        new_args.push_back(ExprMutator::Mutate(it));
+      }
+
+      Tuple ins(new_args);
+      Type ret_type = cn->checked_type_;
+      std::vector<TensorType> out_types = FlattenTupleType(ret_type);
+
+      // Handle fused op that only contains reshape op
+      if (IsReshapeOnly(cn->op)) {
+        Function func = Downcast<Function>(cn->op);
+        return EmitReshapeTensor(&scope, func, new_args, ret_type);
+      }
+
+      // Handle device copy op
+      if (IsDeviceCopy(cn->op)) {
+        Attrs attr;
+        if (const auto* fn = cn->op.as<FunctionNode>()) {
+          const auto* copy_call = fn->body.as<CallNode>();
+          CHECK(copy_call);
+          attr = copy_call->attrs;
+        } else {
+          attr = cn->attrs;
+        }
+        const DeviceCopyAttrs* copy_attr = attr.as<DeviceCopyAttrs>();
+        CHECK(copy_attr);
+        return DeviceCopy(new_args[0], copy_attr->src_dev_type, copy_attr->dst_dev_type);
+      } else if (IsDynamic(ret_type)) {
+        Function func = Downcast<Function>(cn->op);
+        return DynamicInvoke(&scope, func, ins, new_args, out_types, ret_type);
+      } else {
+        // Handle the static case
+        Array<Expr> outs;
+        for (size_t i = 0; i < out_types.size(); ++i) {
+          TVMContext ctx = GetContext(GetRef<Call>(cn));
+          auto out = MakeStaticAllocation(&scope, out_types[i], ctx, std::to_string(i));
+          outs.push_back(out);
+        }
+        Tuple output(outs);
+        Expr invoke = (*invoke_tvm_)(cn->op, ins, output);
+        scope.Push(invoke);
+        return ToTupleType(ret_type,
+                           std::vector<Expr>(output->fields.begin(), output->fields.end()));
+      }
+    } else {
+      return ExprMutator::VisitExpr_(cn);
+    }
+  }
+
+ private:
+  // Insert a device copy node.
+  Expr DeviceCopy(const Expr& inp, int src_ctx, int dst_ctx) {
+    return ExprMutator::Mutate((*device_copy_)(inp, src_ctx, dst_ctx));
+  }
+
+  // Check if a call invokes a primitive function.
+  bool IsPrimitive(const CallNode* call) const {
+    if (const auto* fn = call->op.as<FunctionNode>()) {
+      return fn->HasNonzeroAttr(attr::kPrimitive);
+    }
+    return false;
+  }
+
+  // Check if the current relay expression is a device copy call. We can simply
+  // check the body of it if it is a function because the device_copy op is opaque.
+  bool IsDeviceCopy(const Expr& expr) const {
+    if (const auto* fn = expr.as<FunctionNode>()) {
+      auto body = fn->body;
+      const CallNode* call = body.as<CallNode>();
+      return call && call->op == Op::Get("device_copy");
+    } else if (const CallNode* cn = expr.as<CallNode>()) {
+      return cn->op == Op::Get("device_copy");
+    } else {
+      return false;
+    }
+  }
+
+  Expr ComputeAlignment(const DataType& dtype) const {
+    int64_t align = dtype.bits() / 8 * dtype.lanes();
+    if (align < 64) {
+      align = 64;
+    }
+    return MakeConstant(align);
+  }
+
+  Expr ComputeStorageInRelay(const Expr& shape, const TensorType& type) const {
+    auto dtype = DataType(type->dtype);
+    auto fp = runtime::Registry::Get("relay.op._make.prod");
+    CHECK(fp) << "cannot find operator prod from the registry";
+    auto fa = runtime::Registry::Get("relay.op._make.add");
+    CHECK(fa) << "cannot find operator add from the registry";
+    auto fd = runtime::Registry::Get("relay.op._make.divide");
+    CHECK(fd) << "cannot find operator devide from the registry";
+    auto fm = runtime::Registry::Get("relay.op._make.multiply");
+    CHECK(fm) << "cannot find operator multiply from the registry";

Review comment:
       I don't have a strong preference. Maybe I'd still keep the C++ one. The Python one saves a few lines, but I don't really want to keep the original file or let the C++ pass call into Python.
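
       For reference, a minimal sketch of how the C++ pass could stay self-contained and still be driven from Python through the FFI, so no Python implementation of the rewrite itself is needed. The global name `relay.transform.ManifestAlloc` and the constructor signature below are illustrative assumptions, not necessarily what this PR registers:

       ```cpp
       // Hypothetical registration sketch; the pass name and signature are illustrative only.
       namespace transform {

       Pass ManifestAlloc(Target target_host) {
         runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
             [=](IRModule mod, const PassContext& pass_ctx) {
               // Here the pass would run context analysis and then apply
               // DialectRewriter to every Relay function in the module.
               return mod;
             };
         return tvm::transform::CreateModulePass(pass_func, 0, "ManifestAlloc", {});
       }

       TVM_REGISTER_GLOBAL("relay.transform.ManifestAlloc").set_body_typed(ManifestAlloc);

       }  // namespace transform
       ```

       The Python side would then only fetch the registered global (e.g. via `tvm.get_global_func`) rather than having the C++ pass call back into Python.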




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

