areusch commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r603671640



##########
File path: src/relay/backend/graph_runtime_codegen.cc
##########
@@ -181,23 +186,56 @@ class GraphOpNode : public GraphNode {
   const std::string op_type_name_{"tvm_op"};
 };
 
-/*! \brief Code generator for graph runtime */
+/*! \brief Code generator for the graph runtime, produces a module containing 
the graph JSON,
+ * module, and parameters.
+ */
 class GraphRuntimeCodegen : public 
backend::MemoizedExprTranslator<std::vector<GraphNodeRef>> {
  public:
   GraphRuntimeCodegen(runtime::Module* mod, const TargetsMap& targets) : 
mod_(mod) {
-    compile_engine_ = CompileEngine::Global();
     targets_ = targets;
   }
 
   LoweredOutput Codegen(relay::Function func) {
-    auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
-    storage_device_map_ = (*pf)(func);
+    // Jared: why do we do this? just call C++ API.

Review comment:
       At some point we should support overriding this. It seems like you'd 
need to build a C++ impl right now to do that, so I don't know that we need 
this hook now. But we should plan to put some customization back in the future?

##########
File path: src/relay/backend/graph_plan_memory.cc
##########
@@ -252,9 +265,22 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
   }
   // The call map
   void VisitExpr_(const CallNode* op) final {
+    // temporary hack for change of style.

Review comment:
       same

##########
File path: src/driver/driver_api.cc
##########
@@ -244,14 +244,17 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule 
mod_mixed, const Target
   }
 
   if (target->kind->device_type == kDLCPU && target_host == target) {
-    ICHECK(mdevice->functions.empty()) << "No device code should be generated 
when target "
-                                       << "and host_target are both llvm 
target."
-                                       << "\n";
+    // We need to relax this check for just TIR functions.

Review comment:
       +1

##########
File path: src/relay/backend/te_compiler_cache.cc
##########
@@ -0,0 +1,681 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "./te_compiler_cache.h"
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/type_functor.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/topi/tags.h>
+
+#include <functional>
+#include <limits>
+#include <mutex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../transforms/pass_utils.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relay {
+namespace tec {
+
+TVM_REGISTER_NODE_TYPE(LoweredOutputNode);
+TVM_REGISTER_NODE_TYPE(CachedFuncNode);
+TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
+TVM_REGISTER_NODE_TYPE(CCacheValueNode);
+
+LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation 
impl) {
+  auto n = make_object<LoweredOutputNode>();
+  n->outputs = std::move(outputs);
+  n->implementation = std::move(impl);
+  data_ = std::move(n);
+}
+
+CCacheKey::CCacheKey(Function source_func, Target target) {
+  auto n = make_object<CCacheKeyNode>();
+  n->source_func = std::move(source_func);
+  n->target = std::move(target);
+  data_ = std::move(n);
+}
+
+CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var, 
tvm::Array<te::Tensor> inputs,
+                       tvm::Array<te::Tensor> outputs, te::Schedule schedule,
+                       tvm::Array<Integer> shape_func_param_states, IRModule 
funcs) {
+  auto n = make_object<CachedFuncNode>();
+  n->target = target;
+  n->prim_fn_var = prim_fn_var;
+  n->inputs = inputs;
+  n->outputs = outputs;
+  n->schedule = schedule;
+  n->shape_func_param_states = shape_func_param_states;
+  n->funcs = funcs;
+  data_ = std::move(n);
+}
+
+Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
+  // for now, we always use int32 shape when possible
+  // even if the result of shape inference becomes int64.
+  Array<IndexExpr> res;
+  for (IndexExpr val : shape) {
+    const int64_t* pval = tir::as_const_int(val);
+    if (pval != nullptr) {
+#ifndef TVM_INDEX_DEFAULT_I64
+      ICHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
+      ICHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
+      res.push_back(IntImm(DataType::Int(32), *pval));
+#else
+      res.push_back(val);
+#endif  // TVM_INDEX_DEFAULT_I64
+    } else if (val->IsInstance<tir::AnyNode>()) {
+      res.push_back(val.as<tir::AnyNode>()->ToVar());
+    } else {
+      res.push_back(val);
+    }
+  }
+  return res;
+}
+
+// The getter to get schedule from compile engine.

Review comment:
       This isn't really a getter, is it? What does it translate from and to?

##########
File path: tests/python/relay/test_backend_graph_runtime.py
##########
@@ -142,7 +143,7 @@ def test_plan_memory():
     # Current rule requires vars have unique storage id
     # because we don't do inplace, we will need another
     # two alternating temporary space.
-    assert len(storage_ids) == 4
+    assert len(storage_ids) == 4, f"found storaged_ids: {storage_ids}"

Review comment:
       nit: storage_ids (the assertion message says "storaged_ids")

##########
File path: tests/python/relay/test_backend_graph_runtime.py
##########
@@ -126,6 +126,7 @@ def test_plan_memory():
     func = relay.Function([x, y], z)
     mod = tvm.IRModule.from_expr(func)
     mod = relay.transform.InferType()(mod)
+    print(mod)

Review comment:
       remove?

##########
File path: tests/python/relay/test_backend_graph_runtime.py
##########
@@ -231,10 +232,12 @@ def test_graph_executor_nested_tuples():
 
 
 if __name__ == "__main__":
-    test_plan_memory()
-    test_with_params()
-    test_add_op_scalar()
     test_add_op_tensor()
-    test_add_op_broadcast()
-    test_gru_like()
-    test_compile_nested_tuples()
+    # test_plan_memory()
+    # test_with_params()
+    # test_add_op_scalar()
+    # test_add_op_tensor()
+    # test_add_op_broadcast()
+    # test_gru_like()
+    # test_compile_nested_tuples()
+    # test_add_op_tensor()

Review comment:
       sys.exit(pytest.main([__file__] + sys.argv[1:]))

##########
File path: src/relay/backend/graph_plan_memory.cc
##########
@@ -166,10 +167,22 @@ class StorageAllocaInit : protected 
StorageAllocaBaseVisitor {
   }
 
   void VisitExpr_(const CallNode* op) final {
+    // temporary hack for change of style.

Review comment:
       can you state which arg is being dropped here?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to