This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new dbe95c43b2 [MSC][BugFix] Bugfix for strided_slice op (#17315)
dbe95c43b2 is described below

commit dbe95c43b2afde26eab428181d47cfc939d153c1
Author: Archermmt <[email protected]>
AuthorDate: Fri Sep 6 20:45:36 2024 +0800

    [MSC][BugFix] Bugfix for strided_slice op (#17315)
    
    support strided_slice
---
 src/contrib/msc/core/codegen/base_codegen.h        |  6 +-
 src/contrib/msc/core/ir/graph_builder.cc           | 13 ++++-
 .../msc/core/transform/bind_named_params.cc        |  2 +-
 src/contrib/msc/core/utils.cc                      | 67 +++++++++++++++++++++-
 src/contrib/msc/core/utils.h                       | 54 ++++++++++++++---
 tests/python/contrib/test_msc/test_graph_build.py  |  3 -
 .../contrib/test_msc/test_translate_relax.py       |  4 --
 .../contrib/test_msc/test_translate_tensorflow.py  |  4 --
 .../contrib/test_msc/test_translate_torch.py       |  3 -
 9 files changed, 128 insertions(+), 28 deletions(-)

diff --git a/src/contrib/msc/core/codegen/base_codegen.h b/src/contrib/msc/core/codegen/base_codegen.h
index 19d8b524b9..acaac896a1 100644
--- a/src/contrib/msc/core/codegen/base_codegen.h
+++ b/src/contrib/msc/core/codegen/base_codegen.h
@@ -179,17 +179,17 @@ class BaseCodeGen {
       return 1;
     }
     if (node->scope.size() == scopes_.top().size()) {
-      ICHECK(StringUtils::CompareArrays(node->scope, scopes_.top()))
+      ICHECK(ArrayUtils::CompareArrays(node->scope, scopes_.top()))
           << "Scope mismatch, node " << node->scope << " compare to current " << scopes_.top();
       return 0;
     } else if (node->scope.size() == scopes_.top().size() + 1) {
-      ICHECK(StringUtils::CompareArrays(node->scope, scopes_.top(), scopes_.top().size()))
+      ICHECK(ArrayUtils::CompareArrays(node->scope, scopes_.top(), scopes_.top().size()))
           << "Scope increase mismatch, node " << node->scope << " compare to current "
           << scopes_.top();
       scopes_.push(node->scope);
       return 1;
     } else if (node->scope.size() == scopes_.top().size() - 1) {
-      ICHECK(StringUtils::CompareArrays(node->scope, scopes_.top(), node->scope.size()))
+      ICHECK(ArrayUtils::CompareArrays(node->scope, scopes_.top(), node->scope.size()))
           << "Scope decrease mismatch, node " << node->scope << " compare to current "
           << scopes_.top();
       scopes_.pop();
diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc
index d35a462579..a968df4204 100644
--- a/src/contrib/msc/core/ir/graph_builder.cc
+++ b/src/contrib/msc/core/ir/graph_builder.cc
@@ -23,6 +23,7 @@
 
 #include "graph_builder.h"
 
+#include <algorithm>
 #include <set>
 
 namespace tvm {
@@ -71,6 +72,13 @@ void RelaxFuncValueGetter::VisitExpr_(const relax::CallNode* op) {
   for (const auto& arg : op->args) {
     if (const auto* s_node = arg.as<relax::PrimValueNode>()) {
       values_.push_back(StringUtils::ToString(s_node->value));
+    } else if (const auto* s_node = arg.as<relax::TupleNode>()) {
+      bool all_values =
+          std::all_of(s_node->fields.begin(), s_node->fields.end(),
+                    [](const relax::Expr& e) { return e->IsInstance<relax::PrimValueNode>(); });
+      if (all_values) {
+        values_.push_back(StringUtils::ToString(s_node->fields));
+      }
     }
   }
 }
@@ -337,6 +345,8 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional<Expr>
         ICHECK(input_types[i] != "input") << i << " th PrimValue of " << optype
                                           << " should has special type, get " << input_types;
         attrs.Set(input_types[i], StringUtils::ToString(s_node->value));
+      } else if (input_types[i] != "input" && arg->IsInstance<relax::TupleNode>()) {
+        attrs.Set(input_types[i], StringUtils::ToString(arg));
       }
     }
     for (size_t i = call->args.size(); i < input_types.size(); i++) {
@@ -371,7 +381,8 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional<Expr>
       Array<String> arg_names;
       if (expr_tensor_map_.count(arg)) {
         arg_names = expr_tensor_map_[arg];
-      } else if (const auto* tuple_node = arg.as<relax::TupleNode>()) {
+      } else if (input_types[i] == "input" && arg->IsInstance<relax::TupleNode>()) {
+        const auto* tuple_node = arg.as<relax::TupleNode>();
         for (const auto& f : tuple_node->fields) {
           ICHECK(expr_tensor_map_.count(f)) << "Can not find tuple field " << f;
           for (const auto& in_name : expr_tensor_map_[f]) {
diff --git a/src/contrib/msc/core/transform/bind_named_params.cc b/src/contrib/msc/core/transform/bind_named_params.cc
index 5ba1ca30eb..6256fae05f 100644
--- a/src/contrib/msc/core/transform/bind_named_params.cc
+++ b/src/contrib/msc/core/transform/bind_named_params.cc
@@ -84,7 +84,7 @@ std::tuple<Map<Var, Expr>, Map<tir::Var, PrimExpr>> NormalizeNamedBindings(
     if (auto opt = obj.as<relax::Expr>()) {
       return opt.value();
     } else if (auto opt = obj.as<runtime::NDArray>()) {
-      const auto& span = SpanUtils::SetAttr(Span(), msc_attr::kName, key->name_hint());
+      const auto& span = SpanUtils::CreateWithAttr(msc_attr::kName, key->name_hint());
       return Constant(opt.value(), StructInfo(), span);
     } else {
       LOG(FATAL) << "Cannot coerce object of type " << obj->GetTypeKey()
diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc
index 5fcbe924ae..c6e74d4284 100644
--- a/src/contrib/msc/core/utils.cc
+++ b/src/contrib/msc/core/utils.cc
@@ -280,6 +280,8 @@ const String StringUtils::ToString(const runtime::ObjectRef& obj) {
     }
   } else if (const auto* n = obj.as<relax::PrimValueNode>()) {
     obj_string = ToString(n->value);
+  } else if (const auto* n = obj.as<relax::TupleNode>()) {
+    obj_string = ToString(n->fields);
   } else {
     std::ostringstream obj_des;
     obj_des << obj;
@@ -288,7 +290,7 @@ const String StringUtils::ToString(const runtime::ObjectRef& obj) {
   return obj_string;
 }
 
-bool StringUtils::CompareArrays(const Array<String>& left, const Array<String>& right, int size) {
+bool ArrayUtils::CompareArrays(const Array<String>& left, const Array<String>& right, int size) {
   if (left.size() == right.size() && left.size() == 0) {
     return true;
   }
@@ -311,6 +313,37 @@ bool StringUtils::CompareArrays(const Array<String>& left, const Array<String>&
   return true;
 }
 
+PrimExpr ArrayUtils::Accumulate(const Array<PrimExpr>& array, int pos) {
+  size_t t_pos = pos < 0 ? array.size() + pos + 1 : pos;
+  PrimExpr accumulate = Integer(1);
+  for (size_t i = 0; i < t_pos; i++) {
+    accumulate = accumulate * array[i];
+  }
+  return accumulate;
+}
+
+bool ArrayUtils::Broadcastable(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
+  if (lhs.size() != rhs.size()) {
+    return false;
+  }
+  for (size_t i = 0; i < lhs.size(); i++) {
+    const auto& lp = lhs[i];
+    const auto& rp = rhs[i];
+    if (lp->IsInstance<tvm::tir::VarNode>() && rp->IsInstance<tvm::tir::VarNode>()) {
+      continue;
+    }
+    if (lp->IsInstance<IntImmNode>() && rp->IsInstance<IntImmNode>() &&
+        Downcast<Integer>(lp)->value == Downcast<Integer>(rp)->value) {
+      continue;
+    }
+    if (lp->IsInstance<IntImmNode>() && Downcast<Integer>(lp)->value == 1) {
+      continue;
+    }
+    return false;
+  }
+  return true;
+}
+
 const Span SpanUtils::SetAttr(const Span& span, const String& key, const String& value) {
   if (value.size() == 0) {
     return span;
@@ -353,6 +386,10 @@ const Map<String, String> SpanUtils::GetAttrs(const Span& span) {
   return attrs;
 }
 
+const Span SpanUtils::CreateWithAttr(const String& key, const String& value) {
+  return SetAttr(Span(), key, value);
+}
+
 const Array<String> ExprUtils::GetInputTypes(const String& optype, size_t inputs_num,
                                              bool as_relax) {
   Array<String> input_types;
@@ -370,6 +407,14 @@ const Array<String> ExprUtils::GetInputTypes(const String& optype, size_t inputs
   } else if (optype == "full" && as_relax) {
     input_types.push_back("shape");
     input_types.push_back("input");
+  } else if (optype == "strided_slice") {
+    input_types.push_back("input");
+    if (inputs_num > 1) {
+      input_types.push_back("axes");
+      input_types.push_back("begin");
+      input_types.push_back("end");
+      input_types.push_back("strides");
+    }
   } else if (optype == "triu") {
     input_types.push_back("input");
     input_types.push_back("k");
@@ -454,13 +499,31 @@ const Array<String> ExprUtils::GetInputTypes(const RelayCall& call) {
   return GetInputTypes(optype, call->args.size(), false);
 }
 
+const String ExprUtils::GetSpanName(const Expr& expr, const String& suffix) {
+  const auto& name = SpanUtils::GetAttr(expr->span, msc_attr::kName);
+  if (suffix.size() > 0) {
+    return name + "_" + suffix;
+  }
+  return name;
+}
+
+const Array<PrimExpr> ExprUtils::GetShape(const Expr& expr) {
+  const auto& shape_opt = Downcast<relax::TensorStructInfo>(relax::GetStructInfo(expr))->GetShape();
+  ICHECK(shape_opt.defined()) << "Shape is not defined for " << expr;
+  return shape_opt.value();
+}
+
+const DataType ExprUtils::GetDataType(const Expr& expr) {
+  return Downcast<relax::TensorStructInfo>(relax::GetStructInfo(expr))->dtype;
+}
+
 TVM_REGISTER_GLOBAL("msc.core.SpanGetAttr").set_body_typed(SpanUtils::GetAttr);
 
 TVM_REGISTER_GLOBAL("msc.core.SpanGetAttrs").set_body_typed(SpanUtils::GetAttrs);
 
 TVM_REGISTER_GLOBAL("msc.core.SpanCreateWithAttr")
     .set_body_typed([](const String& key, const String& value) -> Span {
-      return SpanUtils::SetAttr(Span(), key, value);
+      return SpanUtils::CreateWithAttr(key, value);
     });
 
 TVM_REGISTER_GLOBAL("msc.core.SpanSetAttr")
diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h
index 6c39a8d0a1..d7758cc23d 100644
--- a/src/contrib/msc/core/utils.h
+++ b/src/contrib/msc/core/utils.h
@@ -26,6 +26,7 @@
 
 #include <tvm/ir/source_map.h>
 #include <tvm/relax/expr.h>
+#include <tvm/relax/struct_info.h>
 #include <tvm/relay/expr.h>
 
 #include <string>
@@ -175,13 +176,6 @@ class StringUtils {
    * \return The String.
    */
   TVM_DLL static const String ToString(const runtime::ObjectRef& obj);
-
-  /*!
-   * \brief Compare String arrays.
-   * \return Whether two array are same.
-   */
-  TVM_DLL static bool CompareArrays(const Array<String>& left, const Array<String>& right,
-                                    int size = -1);
 };
 
 /*!
@@ -238,6 +232,10 @@ class ArrayUtils {
     return new_array;
   }
 
+  /*!
+   * \brief Product elements in the arrays.
+   * \return The producted array
+   */
   template <typename T>
   TVM_DLL static const Array<Array<T>> Product(const Array<Array<T>>& arrays) {
     Array<Array<T>> p_arrays;
@@ -260,6 +258,24 @@ class ArrayUtils {
     }
     return p_arrays;
   }
+
+  /*!
+   * \brief Compare String arrays.
+   * \return Whether two array are same.
+   */
+  TVM_DLL static bool CompareArrays(const Array<String>& left, const Array<String>& right,
+                                    int size = -1);
+  /*!
+   * \brief Accumulate array.
+   * \return The accumulate result
+   */
+  TVM_DLL static PrimExpr Accumulate(const Array<PrimExpr>& array, int pos = -1);
+
+  /*!
+   * \brief Check if lhs array is broadcastable to rhs.
+   * \return broadcastable
+   */
+  TVM_DLL static bool Broadcastable(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs);
 };
 
 /*!
@@ -284,6 +300,12 @@ class SpanUtils {
    * \return The Attrs Map.
    */
   TVM_DLL static const Map<String, String> GetAttrs(const Span& span);
+
+  /*!
+   * \brief Create a span with <key>value</key>.
+   * \return The created Span.
+   */
+  TVM_DLL static const Span CreateWithAttr(const String& key, const String& value);
 };
 
 /*!
@@ -365,6 +387,24 @@ class ExprUtils {
   TVM_DLL static const T GetScalar(const relay::Constant& constant, size_t i = 0) {
     return GetScalar<T>(constant->data, i);
   }
+
+  /*!
+   * \brief Get name in span.
+   * \return The name.
+   */
+  TVM_DLL static const String GetSpanName(const Expr& expr, const String& suffix = "");
+
+  /*!
+   * \brief Get shape of expr.
+   * \return The shape.
+   */
+  TVM_DLL static const Array<PrimExpr> GetShape(const Expr& expr);
+
+  /*!
+   * \brief Get dtype of expr.
+   * \return The shape.
+   */
+  TVM_DLL static const DataType GetDataType(const Expr& expr);
 };
 
 }  // namespace msc
diff --git a/tests/python/contrib/test_msc/test_graph_build.py b/tests/python/contrib/test_msc/test_graph_build.py
index 069ffff53b..d027672082 100644
--- a/tests/python/contrib/test_msc/test_graph_build.py
+++ b/tests/python/contrib/test_msc/test_graph_build.py
@@ -17,8 +17,6 @@
 
 """ Test graph builder && graph. """
 
-import pytest
-
 import torch
 from torch import fx
 from torch.nn import Module
@@ -1101,7 +1099,6 @@ def test_getattr():
     verify_model(GetAttr1(), input_info, expected)
 
 
-@pytest.mark.xfail(reason="MSC does not support Tuple of PrimValue")
 def test_getitem():
     """test graph builder for getitem"""
 
diff --git a/tests/python/contrib/test_msc/test_translate_relax.py b/tests/python/contrib/test_msc/test_translate_relax.py
index e8b7149a68..66aa90a625 100644
--- a/tests/python/contrib/test_msc/test_translate_relax.py
+++ b/tests/python/contrib/test_msc/test_translate_relax.py
@@ -17,8 +17,6 @@
 
 """ Test translate from relax. """
 
-import pytest
-
 import torch
 from torch import fx
 from torch.nn import Module
@@ -57,7 +55,6 @@ def _verify_model(torch_model, input_info, opt_config=None):
         relax_exec = tvm.relax.build(relax_mod, target)
         vm_runner = tvm.relax.VirtualMachine(relax_exec, dev)
         res = vm_runner["main"](*args)
-
         return _tvm_runtime_to_np(res)
 
     rt_mod = tvm_codegen.to_relax(
@@ -629,7 +626,6 @@ def test_getattr():
     _verify_model(GetAttr1(), input_info)
 
 
-@pytest.mark.xfail(reason="MSC does not support Tuple of PrimValue")
 def test_getitem():
     """test relax translator for getitem"""
 
diff --git a/tests/python/contrib/test_msc/test_translate_tensorflow.py b/tests/python/contrib/test_msc/test_translate_tensorflow.py
index 61f8ce1a97..cb4ea3c02e 100644
--- a/tests/python/contrib/test_msc/test_translate_tensorflow.py
+++ b/tests/python/contrib/test_msc/test_translate_tensorflow.py
@@ -18,8 +18,6 @@
 
 """ Test translate from tensorflow. """
 
-import pytest
-
 from packaging import version as package_version
 import numpy as np
 
@@ -504,7 +502,6 @@ def _test_stridedslice(
     verify_model(graph_def, golden, **io_info)
 
 
-@pytest.mark.xfail(reason="MSC does not support Tuple of PrimValue")
 def test_stridedslice():
     """test tensorflow translator for stridedslice"""
 
@@ -1065,7 +1062,6 @@ def _test_slice_operation_input(input_value, begin_value, size_value):
     verify_model(graph_def, golden, **io_info)
 
 
-@pytest.mark.xfail(reason="MSC does not support Tuple of PrimValue")
 def test_slice():
     """test tensorflow translator for slice"""
 
diff --git a/tests/python/contrib/test_msc/test_translate_torch.py b/tests/python/contrib/test_msc/test_translate_torch.py
index 60dcbb293a..f3e01493d9 100644
--- a/tests/python/contrib/test_msc/test_translate_torch.py
+++ b/tests/python/contrib/test_msc/test_translate_torch.py
@@ -17,8 +17,6 @@
 
 """ Test translate from torch. """
 
-import pytest
-
 import numpy as np
 
 import torch
@@ -589,7 +587,6 @@ def test_getattr():
     verify_model(GetAttr1(), input_info)
 
 
-@pytest.mark.xfail(reason="MSC does not support Tuple of PrimValue")
 def test_getitem():
     """test torch translator for getitem"""
 

Reply via email to