manupa-arm commented on a change in pull request #9331:
URL: https://github.com/apache/tvm/pull/9331#discussion_r735463335



##########
File path: src/relay/backend/contrib/cmsisnn/extract_constants.cc
##########
@@ -0,0 +1,158 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+class ExtractConstantsMutator : public MixedModeMutator {
+ public:
+  explicit ExtractConstantsMutator(IRModule& mod) : mod_(mod) {}
+
+ private:
+  String gen_var_name() { return "tvm_var_extract_const_" + 
std::to_string(var_count_++); }  // Returns the fixed prefix followed by the current var_count_, then increments the counter so each call yields a new name.
+
+  Expr VisitExpr_(const FunctionNode* func) final {  // Visits the function body and rebuilds the function only if the body changed.
+    Function final_func = GetRef<Function>(func);  // Returned as-is when the visited body is the same object as the original.
+    ++func_nesting_level_;  // Non-zero while a function body is being visited; Rewrite_ checks this before replacing constants.
+    auto new_body = VisitExpr(func->body);
+    --func_nesting_level_;
+    if (!new_body.same_as(func->body)) {
+      final_func = Function(FreeVars(new_body), new_body, func->ret_type,  // Parameters are recomputed as the free variables of the new body.
+                            FreeTypeVars(new_body, mod_), func->attrs);
+      function_to_constants_.Set(GetRef<Function>(func), 
constants_within_function_);  // Keyed by the original function, not the rebuilt one; Rewrite_ looks it up with the original.
+      constants_within_function_.clear();  // Reset the accumulator that Rewrite_ appends extracted constants to.
+    }
+    return final_func;
+  }
+
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
+    Expr final_call = post;
+    auto* post_call = post.as<CallNode>();
+    if (post_call == nullptr) {
+      return final_call;
+    }
+
+    // Replace Constant arguments with Vars for ML Operators
+    // Perform this for non-main Call Nodes only
+    if (func_nesting_level_ && call->op.as<OpNode>()) {
+      Array<Expr> new_args;
+      for (auto& arg : post_call->args) {
+        auto* const_arg = arg.as<ConstantNode>();
+        if (const_arg && !const_arg->is_scalar()) {
+          Var var_arg = Var(gen_var_name(), const_arg->tensor_type());
+          new_args.push_back(var_arg);
+          constants_within_function_.push_back(GetRef<Constant>(const_arg));
+        } else {
+          new_args.push_back(arg);
+        }
+      }
+      final_call = Call(call->op, new_args, call->attrs, {});
+    }
+
+    // Since the constants are kicked out of partitioned functions
+    // a new call to global function is needed
+    if (auto* glob_var_node = post_call->op.as<GlobalVarNode>()) {
+      auto glob_var = GetRef<GlobalVar>(glob_var_node);
+      auto glob_func = Downcast<Function>(mod_->Lookup(glob_var));
+      auto new_glob_func = VisitExpr(glob_func);
+      if (!new_glob_func.same_as(glob_func)) {
+        mod_->Update(glob_var, Downcast<Function>(new_glob_func));
+        Array<Expr> new_args = post_call->args;
+        ICHECK(function_to_constants_.find(glob_func) != 
function_to_constants_.end());

Review comment:
       Oops, it seems we are discussing a Map here rather than an Array.
   Can't we use an unordered_map then?
   
   My reasoning is about using .find(..) on an Array/Map -- I'm not sure 
Array/Map is commonly used like that, especially when it is an internal data 
structure that is not exposed to Python.
   
   
https://github.com/apache/tvm/blob/df2bcfbd239c486fbad1b2ddf335389cb6c35fb7/src/tir/transforms/loop_partition.cc#L331
   
   
https://github.com/apache/tvm/blob/df2bcfbd239c486fbad1b2ddf335389cb6c35fb7/src/relay/transforms/fuse_ops.cc#L854
   
   
https://github.com/apache/tvm/blob/df2bcfbd239c486fbad1b2ddf335389cb6c35fb7/src/relay/transforms/fuse_ops.cc#L1008
   
   So, two questions here from a code-quality perspective:
   
    @ashutosh-arm, if the number of partitioned functions is bounded 
(preferably by a constant), then it might not make a difference. Is that the 
case -- say, if we think of a large network with many operators? If so, 
it may be fine to overlook this.
   
   @Mousius, hmm, the tree is used to make the .find(..) faster, right?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to