Mousius commented on a change in pull request #9331: URL: https://github.com/apache/tvm/pull/9331#discussion_r736327262
########## File path: src/relay/backend/contrib/cmsisnn/extract_constants.cc ########## @@ -0,0 +1,158 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include <tvm/relay/attrs/nn.h> +#include <tvm/relay/expr_functor.h> +#include <tvm/relay/transform.h> +#include <tvm/runtime/ndarray.h> + +#include "../../../qnn/utils.h" +#include "../../../transforms/pattern_utils.h" + +namespace tvm { +namespace relay { +namespace contrib { +namespace cmsisnn { + +class ExtractConstantsMutator : public MixedModeMutator { + public: + explicit ExtractConstantsMutator(IRModule& mod) : mod_(mod) {} + + private: + String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); } + + Expr VisitExpr_(const FunctionNode* func) final { + Function final_func = GetRef<Function>(func); + ++func_nesting_level_; + auto new_body = VisitExpr(func->body); + --func_nesting_level_; + if (!new_body.same_as(func->body)) { + final_func = Function(FreeVars(new_body), new_body, func->ret_type, + FreeTypeVars(new_body, mod_), func->attrs); + function_to_constants_.Set(GetRef<Function>(func), constants_within_function_); + constants_within_function_.clear(); + } + return final_func; + } + + Expr Rewrite_(const 
CallNode* call, const Expr& post) final { + Expr final_call = post; + auto* post_call = post.as<CallNode>(); + if (post_call == nullptr) { + return final_call; + } + + // Replace Constant arguments with Vars for ML Operators + // Perform this for non-main Call Nodes only + if (func_nesting_level_ && call->op.as<OpNode>()) { + Array<Expr> new_args; + for (auto& arg : post_call->args) { + auto* const_arg = arg.as<ConstantNode>(); + if (const_arg && !const_arg->is_scalar()) { + Var var_arg = Var(gen_var_name(), const_arg->tensor_type()); + new_args.push_back(var_arg); + constants_within_function_.push_back(GetRef<Constant>(const_arg)); + } else { + new_args.push_back(arg); + } + } + final_call = Call(call->op, new_args, call->attrs, {}); + } + + // Since the constants are kicked out of partitioned functions + // a new call to global function is needed + if (auto* glob_var_node = post_call->op.as<GlobalVarNode>()) { + auto glob_var = GetRef<GlobalVar>(glob_var_node); + auto glob_func = Downcast<Function>(mod_->Lookup(glob_var)); + auto new_glob_func = VisitExpr(glob_func); + if (!new_glob_func.same_as(glob_func)) { + mod_->Update(glob_var, Downcast<Function>(new_glob_func)); + Array<Expr> new_args = post_call->args; + ICHECK(function_to_constants_.find(glob_func) != function_to_constants_.end()); Review comment: @manupa-arm my bad actually, I rechecked and `unordered_set` is hash based not tree based - sorry about that! In general though, for smaller sets, iterating over the entire set can be either similar time or even faster depending on the complexity of the hashing or tree algorithm. As for TVM's `Map` vs `unordered_map`, both are hash maps so I don't believe there'll be any performance difference between them. 
Internally it looks like `Map` uses `unordered_map`: https://github.com/apache/tvm/blob/75a8fa1fb7011636066c839c6abd5f1c2e9056ba/include/tvm/runtime/container/map.h#L51 And there are pushes elsewhere in the codebase to use `Map` for consistency: https://github.com/apache/tvm/blob/75a8fa1fb7011636066c839c6abd5f1c2e9056ba/src/relay/backend/te_compiler.h#L61-L63 Given that this pass is only meant for CMSIS-NN, it is bounded to the small number of functions that will fit on an embedded device; that is the scope in which this contribution was made. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
