manupa-arm commented on a change in pull request #9418:
URL: https://github.com/apache/tvm/pull/9418#discussion_r763193087



##########
File path: src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
##########
@@ -0,0 +1,350 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/usmp/transform/convert_pool_allocations_to_offsets.cc
+ * \brief This pass would convert the pool allocations to offsets from pools
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+#include <tvm/tir/usmp/utils.h>
+
+#include <stack>
+
+namespace tvm {
+namespace tir {
+namespace usmp {
+
+class PoolAllocationToOffsetConverter : public StmtExprMutator {
+ public:
+  explicit PoolAllocationToOffsetConverter(const IRModule& module,
+                                           const Map<tir::Stmt, 
PoolAllocation>& pool_allocations,
+                                           bool emit_tvmscript_printable = 
false)
+      : pool_allocations_(pool_allocations), 
emit_tvmscript_printable_(emit_tvmscript_printable) {
+    module_ = module->ShallowCopy();
+    for (const auto& gv_func : module_->functions) {
+      function_global_vars_.Set(gv_func.first->name_hint, gv_func.first);
+    }
+    for (const auto& kv : pool_allocations) {
+      // TODO(@manupa-arm): add AllocateConstNode when it is available
+      ICHECK(kv.first->IsInstance<AllocateNode>());
+      Allocate allocate_node = Downcast<Allocate>(kv.first);
+      PoolAllocation pool_allocation = kv.second;
+      PoolInfo pool_info = pool_allocation->pool_info;
+      int byte_pool_offset = pool_allocation->byte_offset->value;
+      int required_pool_size_for_allocation =
+          byte_pool_offset + CalculateExtentsSize(allocate_node.operator->());
+      if (all_pools_sizes_.find(pool_info) == all_pools_sizes_.end()) {
+        all_pools_sizes_[pool_info] = required_pool_size_for_allocation;
+      } else {
+        int prev_required_pool_size = all_pools_sizes_[pool_info];
+        if (prev_required_pool_size < required_pool_size_for_allocation) {
+          all_pools_sizes_[pool_info] = required_pool_size_for_allocation;
+        }
+      }
+    }
+
+    for (const auto& kv : all_pools_sizes_) {
+      PoolInfo pi = kv.first;
+      int allocated_size = kv.second;
+      allocated_pool_ordering_.push_back(AllocatedPoolInfo(pi, 
allocated_size));
+    }
+    std::sort(allocated_pool_ordering_.begin(), allocated_pool_ordering_.end(),
+              [](const AllocatedPoolInfo& lhs, const AllocatedPoolInfo& rhs) {
+                if (lhs->pool_info->pool_name < rhs->pool_info->pool_name) {
+                  return true;
+                }
+                return false;
+              });
+  }
+  IRModule operator()();
+
+ private:
+  PrimExpr VisitExpr_(const CallNode* op) override;
+  Stmt VisitStmt_(const AllocateNode* op) override;
+  //  PrimExpr VisitExpr_(const VarNode* op) override;
+  PrimExpr VisitExpr_(const LoadNode* op) override;
+  Stmt VisitStmt_(const StoreNode* op) override;
+
+  /*! \brief This is a structure where the modified function
+   * signature is kept while the body of the function is mutated
+   */
+  struct ScopeInfo {
+    Array<tir::Var> params;
+    Map<PoolInfo, tir::Var> pools_to_params;
+    Array<AllocatedPoolInfo> allocated_pool_params;
+    Map<tir::Var, Buffer> buffer_map;
+  };
+
+  /*! \brief The function scope information that is needed
+   * in the mutation of the function needs to be stacked and
+   * popped when each function is entered/exited in the
+   * mutation process.
+   */
+  std::stack<ScopeInfo> scope_stack;
+  /*! \brief Each PrimFunc signature needs to be updated
+   * with pool variables. This is a helper function to
+   * capture the updated information to ScopeInfo object.
+   */
+  ScopeInfo UpdateFunctionScopeInfo(const PrimFunc& original_func);
+  /*! \brief This is a helper to create the PrimFunc with
+   * pool variables that calls the UpdateFunctionScopeInfo
+   * inside of it.
+   */
+  PrimFunc CreatePrimFuncWithPoolParams(const PrimFunc& original_primfunc);
+  /*! \brief This is a helper to append the pool args to
+   * the callsite of the function.
+   */
+  Array<PrimExpr> AppendPoolParamsToArgs(const Array<PrimExpr>& args);
+  /*! \brief Some arguments that used to be Allocate nodes
+   * should be replaced by Let nodes in the pass that loads
+   * the space from a pool variable.
+   */
+  Array<PrimExpr> ReplaceAllocateArgsWithLetArgs(const Array<PrimExpr>& args);
+
+  /*! \brief The tir::Var map to PoolInfo objects */
+  Map<tir::Var, PoolInfo> primfunc_args_to_pool_info_map_;
+  /*! \brief The buffer var map to their allocate nodes */
+  Map<tir::Var, tir::Stmt> allocate_var_to_stmt_map_;
+  /*! \brief The IRModule being constructed/mutated */
+  IRModule module_;
+  /*! \brief The input allocate node to PoolAllocation map */
+  Map<tir::Stmt, PoolAllocation> pool_allocations_;
+  /*! \brief The set of ordered pools to ensure a unique order of args for 
functions */
+  std::vector<AllocatedPoolInfo> allocated_pool_ordering_;
+  /*! \brief The storage of calculated pool size at init */
+  std::unordered_map<PoolInfo, int, ObjectPtrHash, ObjectPtrEqual> 
all_pools_sizes_;
+  /*! \brief The AoT codegen uses extern_calls because some functions are not 
exposed in the TIR
+   * IRModule. This map maintains the mapping from each function name to its GlobalVar.
+   */
+  Map<String, GlobalVar> function_global_vars_;

Review comment:
       Yes, you are right! removed it altogether.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to