manupa-arm commented on a change in pull request #8468:
URL: https://github.com/apache/tvm/pull/8468#discussion_r724328703



##########
File path: include/tvm/tir/usmp/utils.h
##########
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/usmp/utils.h
+ * \brief Utilities for Unified Static Memory Planner
+ */
+
+#ifndef TVM_TIR_USMP_UTILS_H_
+#define TVM_TIR_USMP_UTILS_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/tir/stmt.h>
+
+namespace tvm {
+namespace tir {
+namespace usmp {
+
+/*!
+ * \brief The buffer information to be used by USMP
+ */
+struct BufferInfoNode : public Object {

Review comment:
       Done

##########
File path: src/tir/usmp/analysis/extract_buffer_info.cc
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/usmp/analysis/extract_buffer_info.cc
+ * \brief Extract buffer information from a PrimFunc for use by USMP
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/usmp/utils.h>
+
+#include <stack>
+
+namespace tvm {
+namespace tir {
+namespace usmp {
+
+class BufferInfoExtractor : public StmtExprVisitor {
+ public:
+  explicit BufferInfoExtractor(const IRModule& module) : module_(module) {
+    for (const auto& gv_func : module_->functions) {
+      functions.Set(gv_func.first->name_hint, 
Downcast<PrimFunc>(gv_func.second));
+    }
+    // Pushing a scope info for the initial body of the main function
+    scope_stack.push(ScopeInfo());
+  }
+  Map<tir::Stmt, BufferInfo> operator()(const PrimFunc& func);
+
+ private:
+  void VisitStmt(const Stmt& n) override;
+  void VisitStmt_(const AllocateNode* op) override;
+  void VisitExpr_(const CallNode* op) override;
+  void VisitExpr_(const VarNode* op) override;
+  void VisitExpr_(const LoadNode* op) override;
+  void VisitStmt_(const StoreNode* op) override;
+  void VisitStmt_(const ForNode* op) override;
+
+  void UpdateAliases(const Array<PrimExpr>& args, const PrimFunc& func);
+
+  Map<tir::Stmt, BufferInfo> buffer_info_map;
+  Map<tir::Stmt, Integer> buffer_info_start_stmt_idx;
+  Map<tir::Stmt, Integer> buffer_info_end_stmt_idx;
+  Map<tir::Var, tir::Stmt> allocate_var_to_stmt_map;
+
+  std::unordered_set<Stmt, ObjectPtrHash, ObjectPtrEqual> 
currently_live_allocates;
+  int current_stmt_idx = 0;
+  struct ScopeInfo {
+    For for_loop;
+  };
+  std::stack<ScopeInfo> scope_stack;
+
+  Map<String, PrimFunc> functions;
+  IRModule module_;
+};
+
+void BufferInfoExtractor::VisitStmt(const Stmt& n) {
+  current_stmt_idx += 1;
+  StmtExprVisitor::VisitStmt(n);
+}
+
+size_t static CalculateExtentsSize(const AllocateNode* op) {
+  size_t element_size_bytes = op->dtype.bytes();
+  size_t num_elements = 1;
+  for (const auto& ext : op->extents) {
+    if (ext->IsInstance<IntImmNode>()) {
+      num_elements *= Downcast<IntImm>(ext)->value;
+    } else {
+      // We can't statically calculate workspace for dynamic shapes
+      num_elements = 0;
+    }
+  }
+  return (num_elements * element_size_bytes);
+}
+
+Array<String> static ParseCommaSeperatedString(const String& cs_string) {
+  std::stringstream ss(cs_string->data);
+  Array<String> storage_pools;
+  while (ss.good()) {
+    std::string storage_pool_name;
+    std::getline(ss, storage_pool_name, ',');
+    storage_pools.push_back(storage_pool_name);
+  }
+  return storage_pools;
+}
+
+void BufferInfoExtractor::VisitStmt_(const AllocateNode* op) {
+  const auto& currect_scope_info = scope_stack.top();
+  const auto& type = Downcast<PointerType>(op->buffer_var->type_annotation);
+  const auto& storage_scope = type->storage_scope;
+
+  // If the allocate is in a for loop,
+  // USMP currently only looks at serial for loops.
+  if ((!currect_scope_info.for_loop.defined()) ||
+      (currect_scope_info.for_loop.defined() &&
+       currect_scope_info.for_loop->kind == ForKind::kSerial && storage_scope 
== "global")) {
+    // USMP can only work with buffers that have global storage_scope
+    auto size_bytes = CalculateExtentsSize(op);
+    if (size_bytes) {
+      // We statically memory plan only those allocates whose
+      // sizes are known at compile time.
+      auto buffer_info = BufferInfo(op->buffer_var->name_hint, size_bytes);
+      
buffer_info->SetPoolCandidates(ParseCommaSeperatedString(op->pinned_memory));

Review comment:
       Explained more and inserted checks

##########
File path: include/tvm/tir/usmp/utils.h
##########
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/usmp/utils.h
+ * \brief Utilities for Unified Static Memory Planner
+ */
+
+#ifndef TVM_TIR_USMP_UTILS_H_
+#define TVM_TIR_USMP_UTILS_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/tir/stmt.h>
+
+namespace tvm {
+namespace tir {
+namespace usmp {
+
+/*!
+ * \brief The buffer information to be used by USMP
+ */
+struct BufferInfoNode : public Object {
+  /*! \brief The name of the buffer var */
+  String name_hint;
+  /*! \brief The size in terms of bytes */
+  Integer size_bytes;
+  /*! \brief The byte alignment required within the pool */
+  Integer alignment;
+  /*! \brief The other buffer info objects whose liveness conflicts with this one */
+  Array<ObjectRef> conflicts;
+  /*! \brief The names of the pool candidates that this buffer can get pooled to */

Review comment:
       It will never be empty -- the default will always be a size-unrestricted pool.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to