manupa-arm commented on a change in pull request #8468:
URL: https://github.com/apache/tvm/pull/8468#discussion_r724330521



##########
File path: src/tir/usmp/analysis/extract_buffer_info.cc
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/usmp/analysis/extract_buffer_info.cc
+ * \brief Extract BufferInfo objects (size and pool candidates) for the Allocate
+ *        nodes of a PrimFunc, for use by USMP.
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/usmp/utils.h>
+
+#include <sstream>
+#include <stack>
+#include <string>
+#include <unordered_set>
+
+namespace tvm {
+namespace tir {
+namespace usmp {
+
+/*!
+ * \brief Visitor that collects a BufferInfo for the Allocate nodes of a PrimFunc.
+ *
+ * Every visited statement is numbered (see VisitStmt) so that per-allocate
+ * start/end statement indices can be recorded. A stack of ScopeInfo is kept so
+ * the visitor can tell which for loop (if any) encloses the current statement.
+ */
+class BufferInfoExtractor : public StmtExprVisitor {
+ public:
+  /*!
+   * \brief Index the PrimFuncs of \p module by name and open the outermost scope.
+   * \param module The IRModule whose PrimFuncs may be referenced while visiting.
+   *
+   * Entries of \p module that are not PrimFuncs are skipped: this analysis only
+   * understands TIR, and Downcast<PrimFunc> would abort on any other BaseFunc.
+   */
+  explicit BufferInfoExtractor(const IRModule& module) : module_(module) {
+    for (const auto& gv_func : module_->functions) {
+      if (gv_func.second->IsInstance<PrimFuncNode>()) {
+        functions.Set(gv_func.first->name_hint, Downcast<PrimFunc>(gv_func.second));
+      }
+    }
+    // Pushing a scope info for the initial body of the main function
+    scope_stack.push(ScopeInfo());
+  }
+  /*!
+   * \brief Run the analysis on \p func.
+   * \return Map from statement (presumably the Allocate node) to the BufferInfo
+   *         extracted for it.
+   */
+  Map<tir::Stmt, BufferInfo> operator()(const PrimFunc& func);
+
+ private:
+  void VisitStmt(const Stmt& n) override;
+  void VisitStmt_(const AllocateNode* op) override;
+  void VisitExpr_(const CallNode* op) override;
+  void VisitExpr_(const VarNode* op) override;
+  void VisitExpr_(const LoadNode* op) override;
+  void VisitStmt_(const StoreNode* op) override;
+  void VisitStmt_(const ForNode* op) override;
+
+  void UpdateAliases(const Array<PrimExpr>& args, const PrimFunc& func);
+
+  /*! \brief BufferInfo recorded per Allocate statement. */
+  Map<tir::Stmt, BufferInfo> buffer_info_map;
+  /*! \brief First statement index (see current_stmt_idx) recorded per allocate. */
+  Map<tir::Stmt, Integer> buffer_info_start_stmt_idx;
+  /*! \brief Last statement index (see current_stmt_idx) recorded per allocate. */
+  Map<tir::Stmt, Integer> buffer_info_end_stmt_idx;
+  /*! \brief Presumably maps an allocate's buffer variable back to its Allocate statement. */
+  Map<tir::Var, tir::Stmt> allocate_var_to_stmt_map;
+
+  /*! \brief Allocate statements treated as live at the current traversal point. */
+  std::unordered_set<Stmt, ObjectPtrHash, ObjectPtrEqual> currently_live_allocates;
+  /*! \brief Running count of visited statements; incremented by VisitStmt. */
+  int current_stmt_idx = 0;
+  /*! \brief Per-scope state; for_loop is left undefined outside of any for loop. */
+  struct ScopeInfo {
+    For for_loop;
+  };
+  /*! \brief Stack of enclosing scopes; the bottom entry is the function body. */
+  std::stack<ScopeInfo> scope_stack;
+
+  /*! \brief PrimFuncs of module_, keyed by their GlobalVar name_hint. */
+  Map<String, PrimFunc> functions;
+  /*! \brief The module being analysed. */
+  IRModule module_;
+};
+
+/*!
+ * \brief Assign the next sequential index to \p n, then dispatch to the base visitor.
+ * \param n The statement about to be visited.
+ *
+ * The counter is advanced before dispatch, so any handler reached for \p n
+ * observes \p n's own index in current_stmt_idx.
+ */
+void BufferInfoExtractor::VisitStmt(const Stmt& n) {
+  ++current_stmt_idx;
+  StmtExprVisitor::VisitStmt(n);
+}
+
+/*!
+ * \brief Compute the size in bytes of the buffer allocated by \p op.
+ * \param op The Allocate node.
+ * \return The allocation size in bytes, or 0 when any extent is not a
+ *         compile-time IntImm (dynamic shape).
+ *
+ * DataType::bytes() is the size of a single scalar lane, so the element size
+ * is scaled by lanes() to account for vector dtypes (e.g. float32x4).
+ */
+static size_t CalculateExtentsSize(const AllocateNode* op) {
+  const size_t element_size_bytes =
+      static_cast<size_t>(op->dtype.bytes()) * static_cast<size_t>(op->dtype.lanes());
+  size_t num_elements = 1;
+  for (const auto& ext : op->extents) {
+    if (!ext->IsInstance<IntImmNode>()) {
+      // We can't statically calculate workspace for dynamic shapes.
+      return 0;
+    }
+    num_elements *= static_cast<size_t>(Downcast<IntImm>(ext)->value);
+  }
+  return num_elements * element_size_bytes;
+}
+
+/*!
+ * \brief Split a comma separated list of storage pool names.
+ * \param cs_string The comma separated string, e.g. "pool_a,pool_b".
+ * \return The non-empty names in order of appearance.
+ *
+ * An empty input yields an empty array rather than one empty name. Empty
+ * segments, such as those produced by a trailing comma or ",,", are dropped.
+ */
+static Array<String> ParseCommaSeperatedString(const String& cs_string) {
+  std::stringstream ss(cs_string->data);
+  Array<String> storage_pools;
+  std::string storage_pool_name;
+  // Loop on the extraction result: getline fails once no characters remain,
+  // so the stream is never sampled before a read (unlike a good()-guarded loop).
+  while (std::getline(ss, storage_pool_name, ',')) {
+    if (!storage_pool_name.empty()) {
+      storage_pools.push_back(storage_pool_name);
+    }
+  }
+  return storage_pools;
+}
+
+void BufferInfoExtractor::VisitStmt_(const AllocateNode* op) {
+  const auto& currect_scope_info = scope_stack.top();
+  const auto& type = Downcast<PointerType>(op->buffer_var->type_annotation);
+  const auto& storage_scope = type->storage_scope;
+
+  // If the allocate is in a for loop,
+  // USMP currently only looks at serial for loops.
+  if ((!currect_scope_info.for_loop.defined()) ||
+      (currect_scope_info.for_loop.defined() &&
+       currect_scope_info.for_loop->kind == ForKind::kSerial && storage_scope 
== "global")) {
+    // USMP can only work with buffers that have global storage_scope
+    auto size_bytes = CalculateExtentsSize(op);
+    if (size_bytes) {
+      // We only statically memory plan only allocates with known
+      // compile time sizes.
+      auto buffer_info = BufferInfo(op->buffer_var->name_hint, size_bytes);
+      
buffer_info->SetPoolCandidates(ParseCommaSeperatedString(op->pinned_memory));

Review comment:
       Explained more and inserted checks




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to