areusch commented on a change in pull request #8468:
URL: https://github.com/apache/tvm/pull/8468#discussion_r692438124



##########
File path: include/tvm/tir/stmt.h
##########
@@ -521,6 +521,11 @@ class AllocateNode : public StmtNode {
   PrimExpr condition;
   /*! \brief The body to be executed. */
   Stmt body;
+  /*! \brief If the allocate is scoped global, this field indicates
+   *  which external memories it could be pinned to as a comma seperated
+   *  string.
+   */
+  String pinned_memory;

Review comment:
       It seems this should be a vector (e.g. `Array<String>`). Why is it a comma-separated string?

##########
File path: src/tir/usmp/analysis/extract_buffer_info.cc
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/analysis/usmp/convert_for_loops_serial.cc
+ * \brief Convert all for loops to serial for lesser memory consumption
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/usmp/utils.h>
+
+#include <stack>
+
+namespace tvm {
+namespace tir {
+namespace usmp {
+
+class BufferInfoExtractor : public StmtExprVisitor {
+ public:
+  explicit BufferInfoExtractor(const IRModule& module) : module_(module) {
+    for (const auto& gv_func : module_->functions) {
+      functions.Set(gv_func.first->name_hint, 
Downcast<PrimFunc>(gv_func.second));
+    }
+    // Pushing a scope info for the initial body of the main function
+    scope_stack.push(ScopeInfo());
+  }
+  Map<tir::Stmt, BufferInfo> operator()(const PrimFunc& func);
+
+ private:
+  void VisitStmt(const Stmt& n) override;
+  void VisitStmt_(const AllocateNode* op) override;
+  void VisitExpr_(const CallNode* op) override;
+  void VisitExpr_(const VarNode* op) override;
+  void VisitExpr_(const LoadNode* op) override;
+  void VisitStmt_(const StoreNode* op) override;
+  void VisitStmt_(const ForNode* op) override;
+
+  void UpdateAliases(const Array<PrimExpr>& args, const PrimFunc& func);
+
+  Map<tir::Stmt, BufferInfo> buffer_info_map;

Review comment:
       Per https://google.github.io/styleguide/cppguide.html#Variable_Names, class data members should have a trailing `_` (e.g. `buffer_info_map_`), as `module_` already does.

##########
File path: include/tvm/tir/usmp/utils.h
##########
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/usmp/utils.h
+ * \brief Utilities for Unified Static Memory Planner
+ */
+
+#ifndef TVM_TIR_USMP_UTILS_H_
+#define TVM_TIR_USMP_UTILS_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/tir/stmt.h>
+
+namespace tvm {
+namespace tir {
+namespace usmp {
+
+/*!
+ * \brief The buffer information to be used by USMP
+ */
+struct BufferInfoNode : public Object {

Review comment:
       Per my RFC comment, can you split this into a planner-input struct and an annotation struct?

##########
File path: src/tir/usmp/analysis/extract_buffer_info.cc
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/analysis/usmp/convert_for_loops_serial.cc
+ * \brief Convert all for loops to serial for lesser memory consumption
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/usmp/utils.h>
+
+#include <stack>
+
+namespace tvm {
+namespace tir {
+namespace usmp {
+
+class BufferInfoExtractor : public StmtExprVisitor {
+ public:
+  explicit BufferInfoExtractor(const IRModule& module) : module_(module) {
+    for (const auto& gv_func : module_->functions) {
+      functions.Set(gv_func.first->name_hint, 
Downcast<PrimFunc>(gv_func.second));
+    }
+    // Pushing a scope info for the initial body of the main function
+    scope_stack.push(ScopeInfo());
+  }
+  Map<tir::Stmt, BufferInfo> operator()(const PrimFunc& func);
+
+ private:
+  void VisitStmt(const Stmt& n) override;
+  void VisitStmt_(const AllocateNode* op) override;
+  void VisitExpr_(const CallNode* op) override;
+  void VisitExpr_(const VarNode* op) override;
+  void VisitExpr_(const LoadNode* op) override;
+  void VisitStmt_(const StoreNode* op) override;
+  void VisitStmt_(const ForNode* op) override;
+
+  void UpdateAliases(const Array<PrimExpr>& args, const PrimFunc& func);
+
+  Map<tir::Stmt, BufferInfo> buffer_info_map;
+  Map<tir::Stmt, Integer> buffer_info_start_stmt_idx;
+  Map<tir::Stmt, Integer> buffer_info_end_stmt_idx;
+  Map<tir::Var, tir::Stmt> allocate_var_to_stmt_map;
+
+  std::unordered_set<Stmt, ObjectPtrHash, ObjectPtrEqual> 
currently_live_allocates;
+  int current_stmt_idx = 0;
+  struct ScopeInfo {
+    For for_loop;
+  };
+  std::stack<ScopeInfo> scope_stack;
+
+  Map<String, PrimFunc> functions;
+  IRModule module_;
+};
+
+void BufferInfoExtractor::VisitStmt(const Stmt& n) {
+  current_stmt_idx += 1;
+  StmtExprVisitor::VisitStmt(n);
+}
+
+size_t static CalculateExtentsSize(const AllocateNode* op) {
+  size_t element_size_bytes = op->dtype.bytes();
+  size_t num_elements = 1;
+  for (const auto& ext : op->extents) {
+    if (ext->IsInstance<IntImmNode>()) {
+      num_elements *= Downcast<IntImm>(ext)->value;
+    } else {
+      // We cant statically calculate workspace for dynamic shapes
+      num_elements = 0;
+    }
+  }
+  return (num_elements * element_size_bytes);
+}
+
+Array<String> static ParseCommaSeperatedString(const String& cs_string) {
+  std::stringstream ss(cs_string->data);
+  Array<String> storage_pools;
+  while (ss.good()) {
+    std::string storage_pool_name;
+    std::getline(ss, storage_pool_name, ',');
+    storage_pools.push_back(storage_pool_name);
+  }
+  return storage_pools;
+}
+
+void BufferInfoExtractor::VisitStmt_(const AllocateNode* op) {
+  const auto& currect_scope_info = scope_stack.top();
+  const auto& type = Downcast<PointerType>(op->buffer_var->type_annotation);
+  const auto& storage_scope = type->storage_scope;
+
+  // If the allocate is in a for loop,
+  // USMP currently only looks at serial for loops.
+  if ((!currect_scope_info.for_loop.defined()) ||
+      (currect_scope_info.for_loop.defined() &&
+       currect_scope_info.for_loop->kind == ForKind::kSerial && storage_scope 
== "global")) {
+    // USMP can only work with buffers that have global storage_scope
+    auto size_bytes = CalculateExtentsSize(op);
+    if (size_bytes) {
+      // We only statically memory plan only allocates with known
+      // compile time sizes.
+      auto buffer_info = BufferInfo(op->buffer_var->name_hint, size_bytes);
+      
buffer_info->SetPoolCandidates(ParseCommaSeperatedString(op->pinned_memory));

Review comment:
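       What if `pinned_memory` is empty? As written, `ParseCommaSeperatedString("")` returns a one-element array containing an empty string, so the buffer would get a bogus `""` pool candidate instead of no candidates.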

##########
File path: include/tvm/tir/usmp/analysis.h
##########
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/tir/analysis.h
+ * \brief Analysis utilities and passes for TIR Unified Static Memory Planner.
+ */
+#ifndef TVM_TIR_USMP_ANALYSIS_H_

Review comment:
       Why are you adding this file now?

##########
File path: include/tvm/tir/usmp/utils.h
##########
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/usmp/utils.h
+ * \brief Utilities for Unified Static Memory Planner
+ */
+
+#ifndef TVM_TIR_USMP_UTILS_H_
+#define TVM_TIR_USMP_UTILS_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/tir/stmt.h>
+
+namespace tvm {
+namespace tir {
+namespace usmp {
+
+/*!
+ * \brief The buffer information to be used by USMP
+ */
+struct BufferInfoNode : public Object {
+  /*! \brief The name of the buffer var */
+  String name_hint;
+  /*! \brief The size in terms of bytes */
+  Integer size_bytes;
+  /*! \brief The byte alignment required within the pool */
+  Integer alignment;
+  /*! \brief The liveness conflicting other buffer info objects */
+  Array<ObjectRef> conflicts;
+  /*! \brief The names of the pool candidates that this buffer can get pooled 
to*/

Review comment:
       It seems like this could be empty; what should a planner do then? Does it get a list of global memories it could pick from? Is there a policy defined around each memory, e.g. "can_fallback" or "readable_from_targets"?

##########
File path: src/tir/usmp/analysis/extract_buffer_info.cc
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/analysis/usmp/convert_for_loops_serial.cc
+ * \brief Convert all for loops to serial for lesser memory consumption
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/usmp/utils.h>
+
+#include <stack>
+
+namespace tvm {
+namespace tir {
+namespace usmp {
+
+class BufferInfoExtractor : public StmtExprVisitor {
+ public:
+  explicit BufferInfoExtractor(const IRModule& module) : module_(module) {
+    for (const auto& gv_func : module_->functions) {
+      functions.Set(gv_func.first->name_hint, 
Downcast<PrimFunc>(gv_func.second));
+    }
+    // Pushing a scope info for the initial body of the main function
+    scope_stack.push(ScopeInfo());
+  }
+  Map<tir::Stmt, BufferInfo> operator()(const PrimFunc& func);
+
+ private:
+  void VisitStmt(const Stmt& n) override;
+  void VisitStmt_(const AllocateNode* op) override;
+  void VisitExpr_(const CallNode* op) override;
+  void VisitExpr_(const VarNode* op) override;
+  void VisitExpr_(const LoadNode* op) override;
+  void VisitStmt_(const StoreNode* op) override;
+  void VisitStmt_(const ForNode* op) override;
+
+  void UpdateAliases(const Array<PrimExpr>& args, const PrimFunc& func);
+
+  Map<tir::Stmt, BufferInfo> buffer_info_map;
+  Map<tir::Stmt, Integer> buffer_info_start_stmt_idx;
+  Map<tir::Stmt, Integer> buffer_info_end_stmt_idx;
+  Map<tir::Var, tir::Stmt> allocate_var_to_stmt_map;
+
+  std::unordered_set<Stmt, ObjectPtrHash, ObjectPtrEqual> 
currently_live_allocates;
+  int current_stmt_idx = 0;
+  struct ScopeInfo {
+    For for_loop;
+  };
+  std::stack<ScopeInfo> scope_stack;
+
+  Map<String, PrimFunc> functions;
+  IRModule module_;
+};
+
+void BufferInfoExtractor::VisitStmt(const Stmt& n) {
+  current_stmt_idx += 1;
+  StmtExprVisitor::VisitStmt(n);
+}
+
+size_t static CalculateExtentsSize(const AllocateNode* op) {

Review comment:
       nit: should this be `static size_t`? Putting the storage-class specifier first is the conventional order.

##########
File path: src/tir/usmp/analysis/extract_buffer_info.cc
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/analysis/usmp/convert_for_loops_serial.cc
+ * \brief Convert all for loops to serial for lesser memory consumption
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/usmp/utils.h>
+
+#include <stack>
+
+namespace tvm {
+namespace tir {
+namespace usmp {
+
+class BufferInfoExtractor : public StmtExprVisitor {
+ public:
+  explicit BufferInfoExtractor(const IRModule& module) : module_(module) {
+    for (const auto& gv_func : module_->functions) {
+      functions.Set(gv_func.first->name_hint, 
Downcast<PrimFunc>(gv_func.second));
+    }
+    // Pushing a scope info for the initial body of the main function
+    scope_stack.push(ScopeInfo());
+  }
+  Map<tir::Stmt, BufferInfo> operator()(const PrimFunc& func);
+
+ private:
+  void VisitStmt(const Stmt& n) override;
+  void VisitStmt_(const AllocateNode* op) override;
+  void VisitExpr_(const CallNode* op) override;
+  void VisitExpr_(const VarNode* op) override;
+  void VisitExpr_(const LoadNode* op) override;
+  void VisitStmt_(const StoreNode* op) override;
+  void VisitStmt_(const ForNode* op) override;
+
+  void UpdateAliases(const Array<PrimExpr>& args, const PrimFunc& func);
+
+  Map<tir::Stmt, BufferInfo> buffer_info_map;
+  Map<tir::Stmt, Integer> buffer_info_start_stmt_idx;
+  Map<tir::Stmt, Integer> buffer_info_end_stmt_idx;
+  Map<tir::Var, tir::Stmt> allocate_var_to_stmt_map;
+
+  std::unordered_set<Stmt, ObjectPtrHash, ObjectPtrEqual> 
currently_live_allocates;
+  int current_stmt_idx = 0;
+  struct ScopeInfo {
+    For for_loop;
+  };
+  std::stack<ScopeInfo> scope_stack;
+
+  Map<String, PrimFunc> functions;
+  IRModule module_;
+};
+
+void BufferInfoExtractor::VisitStmt(const Stmt& n) {
+  current_stmt_idx += 1;
+  StmtExprVisitor::VisitStmt(n);
+}
+
+size_t static CalculateExtentsSize(const AllocateNode* op) {
+  size_t element_size_bytes = op->dtype.bytes();
+  size_t num_elements = 1;
+  for (const auto& ext : op->extents) {
+    if (ext->IsInstance<IntImmNode>()) {
+      num_elements *= Downcast<IntImm>(ext)->value;
+    } else {
+      // We cant statically calculate workspace for dynamic shapes

Review comment:
       nit: typo, `cant` should be `can't`.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to