This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dbe95c43b2 [MSC][BugFix] Bugfix for strided_slice op (#17315)
dbe95c43b2 is described below
commit dbe95c43b2afde26eab428181d47cfc939d153c1
Author: Archermmt <[email protected]>
AuthorDate: Fri Sep 6 20:45:36 2024 +0800
[MSC][BugFix] Bugfix for strided_slice op (#17315)
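Support strided_slice with Tuple-of-PrimValue arguments (axes, begin, end, strides) in the MSC graph builder, and drop the "MSC does not support Tuple of PrimValue" test markers.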
---
src/contrib/msc/core/codegen/base_codegen.h | 6 +-
src/contrib/msc/core/ir/graph_builder.cc | 13 ++++-
.../msc/core/transform/bind_named_params.cc | 2 +-
src/contrib/msc/core/utils.cc | 67 +++++++++++++++++++++-
src/contrib/msc/core/utils.h | 54 ++++++++++++++---
tests/python/contrib/test_msc/test_graph_build.py | 3 -
.../contrib/test_msc/test_translate_relax.py | 4 --
.../contrib/test_msc/test_translate_tensorflow.py | 4 --
.../contrib/test_msc/test_translate_torch.py | 3 -
9 files changed, 128 insertions(+), 28 deletions(-)
diff --git a/src/contrib/msc/core/codegen/base_codegen.h
b/src/contrib/msc/core/codegen/base_codegen.h
index 19d8b524b9..acaac896a1 100644
--- a/src/contrib/msc/core/codegen/base_codegen.h
+++ b/src/contrib/msc/core/codegen/base_codegen.h
@@ -179,17 +179,17 @@ class BaseCodeGen {
return 1;
}
if (node->scope.size() == scopes_.top().size()) {
- ICHECK(StringUtils::CompareArrays(node->scope, scopes_.top()))
+ ICHECK(ArrayUtils::CompareArrays(node->scope, scopes_.top()))
<< "Scope mismatch, node " << node->scope << " compare to current "
<< scopes_.top();
return 0;
} else if (node->scope.size() == scopes_.top().size() + 1) {
- ICHECK(StringUtils::CompareArrays(node->scope, scopes_.top(),
scopes_.top().size()))
+ ICHECK(ArrayUtils::CompareArrays(node->scope, scopes_.top(),
scopes_.top().size()))
<< "Scope increase mismatch, node " << node->scope << " compare to
current "
<< scopes_.top();
scopes_.push(node->scope);
return 1;
} else if (node->scope.size() == scopes_.top().size() - 1) {
- ICHECK(StringUtils::CompareArrays(node->scope, scopes_.top(),
node->scope.size()))
+ ICHECK(ArrayUtils::CompareArrays(node->scope, scopes_.top(),
node->scope.size()))
<< "Scope decrease mismatch, node " << node->scope << " compare to
current "
<< scopes_.top();
scopes_.pop();
diff --git a/src/contrib/msc/core/ir/graph_builder.cc
b/src/contrib/msc/core/ir/graph_builder.cc
index d35a462579..a968df4204 100644
--- a/src/contrib/msc/core/ir/graph_builder.cc
+++ b/src/contrib/msc/core/ir/graph_builder.cc
@@ -23,6 +23,7 @@
#include "graph_builder.h"
+#include <algorithm>
#include <set>
namespace tvm {
@@ -71,6 +72,13 @@ void RelaxFuncValueGetter::VisitExpr_(const relax::CallNode*
op) {
for (const auto& arg : op->args) {
if (const auto* s_node = arg.as<relax::PrimValueNode>()) {
values_.push_back(StringUtils::ToString(s_node->value));
+ } else if (const auto* s_node = arg.as<relax::TupleNode>()) {
+ bool all_values =
+ std::all_of(s_node->fields.begin(), s_node->fields.end(),
+ [](const relax::Expr& e) { return
e->IsInstance<relax::PrimValueNode>(); });
+ if (all_values) {
+ values_.push_back(StringUtils::ToString(s_node->fields));
+ }
}
}
}
@@ -337,6 +345,8 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr,
const Optional<Expr>
ICHECK(input_types[i] != "input") << i << " th PrimValue of " << optype
<< " should has special type, get "
<< input_types;
attrs.Set(input_types[i], StringUtils::ToString(s_node->value));
+ } else if (input_types[i] != "input" &&
arg->IsInstance<relax::TupleNode>()) {
+ attrs.Set(input_types[i], StringUtils::ToString(arg));
}
}
for (size_t i = call->args.size(); i < input_types.size(); i++) {
@@ -371,7 +381,8 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr,
const Optional<Expr>
Array<String> arg_names;
if (expr_tensor_map_.count(arg)) {
arg_names = expr_tensor_map_[arg];
- } else if (const auto* tuple_node = arg.as<relax::TupleNode>()) {
+ } else if (input_types[i] == "input" &&
arg->IsInstance<relax::TupleNode>()) {
+ const auto* tuple_node = arg.as<relax::TupleNode>();
for (const auto& f : tuple_node->fields) {
ICHECK(expr_tensor_map_.count(f)) << "Can not find tuple field " <<
f;
for (const auto& in_name : expr_tensor_map_[f]) {
diff --git a/src/contrib/msc/core/transform/bind_named_params.cc
b/src/contrib/msc/core/transform/bind_named_params.cc
index 5ba1ca30eb..6256fae05f 100644
--- a/src/contrib/msc/core/transform/bind_named_params.cc
+++ b/src/contrib/msc/core/transform/bind_named_params.cc
@@ -84,7 +84,7 @@ std::tuple<Map<Var, Expr>, Map<tir::Var, PrimExpr>>
NormalizeNamedBindings(
if (auto opt = obj.as<relax::Expr>()) {
return opt.value();
} else if (auto opt = obj.as<runtime::NDArray>()) {
- const auto& span = SpanUtils::SetAttr(Span(), msc_attr::kName,
key->name_hint());
+ const auto& span = SpanUtils::CreateWithAttr(msc_attr::kName,
key->name_hint());
return Constant(opt.value(), StructInfo(), span);
} else {
LOG(FATAL) << "Cannot coerce object of type " << obj->GetTypeKey()
diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc
index 5fcbe924ae..c6e74d4284 100644
--- a/src/contrib/msc/core/utils.cc
+++ b/src/contrib/msc/core/utils.cc
@@ -280,6 +280,8 @@ const String StringUtils::ToString(const
runtime::ObjectRef& obj) {
}
} else if (const auto* n = obj.as<relax::PrimValueNode>()) {
obj_string = ToString(n->value);
+ } else if (const auto* n = obj.as<relax::TupleNode>()) {
+ obj_string = ToString(n->fields);
} else {
std::ostringstream obj_des;
obj_des << obj;
@@ -288,7 +290,7 @@ const String StringUtils::ToString(const
runtime::ObjectRef& obj) {
return obj_string;
}
-bool StringUtils::CompareArrays(const Array<String>& left, const
Array<String>& right, int size) {
+bool ArrayUtils::CompareArrays(const Array<String>& left, const Array<String>&
right, int size) {
if (left.size() == right.size() && left.size() == 0) {
return true;
}
@@ -311,6 +313,37 @@ bool StringUtils::CompareArrays(const Array<String>& left,
const Array<String>&
return true;
}
+PrimExpr ArrayUtils::Accumulate(const Array<PrimExpr>& array, int pos) {
+ size_t t_pos = pos < 0 ? array.size() + pos + 1 : pos;
+ PrimExpr accumulate = Integer(1);
+ for (size_t i = 0; i < t_pos; i++) {
+ accumulate = accumulate * array[i];
+ }
+ return accumulate;
+}
+
+bool ArrayUtils::Broadcastable(const Array<PrimExpr>& lhs, const
Array<PrimExpr>& rhs) {
+ if (lhs.size() != rhs.size()) {
+ return false;
+ }
+ for (size_t i = 0; i < lhs.size(); i++) {
+ const auto& lp = lhs[i];
+ const auto& rp = rhs[i];
+ if (lp->IsInstance<tvm::tir::VarNode>() &&
rp->IsInstance<tvm::tir::VarNode>()) {
+ continue;
+ }
+ if (lp->IsInstance<IntImmNode>() && rp->IsInstance<IntImmNode>() &&
+ Downcast<Integer>(lp)->value == Downcast<Integer>(rp)->value) {
+ continue;
+ }
+ if (lp->IsInstance<IntImmNode>() && Downcast<Integer>(lp)->value == 1) {
+ continue;
+ }
+ return false;
+ }
+ return true;
+}
+
const Span SpanUtils::SetAttr(const Span& span, const String& key, const
String& value) {
if (value.size() == 0) {
return span;
@@ -353,6 +386,10 @@ const Map<String, String> SpanUtils::GetAttrs(const Span&
span) {
return attrs;
}
+const Span SpanUtils::CreateWithAttr(const String& key, const String& value) {
+ return SetAttr(Span(), key, value);
+}
+
const Array<String> ExprUtils::GetInputTypes(const String& optype, size_t
inputs_num,
bool as_relax) {
Array<String> input_types;
@@ -370,6 +407,14 @@ const Array<String> ExprUtils::GetInputTypes(const String&
optype, size_t inputs
} else if (optype == "full" && as_relax) {
input_types.push_back("shape");
input_types.push_back("input");
+ } else if (optype == "strided_slice") {
+ input_types.push_back("input");
+ if (inputs_num > 1) {
+ input_types.push_back("axes");
+ input_types.push_back("begin");
+ input_types.push_back("end");
+ input_types.push_back("strides");
+ }
} else if (optype == "triu") {
input_types.push_back("input");
input_types.push_back("k");
@@ -454,13 +499,31 @@ const Array<String> ExprUtils::GetInputTypes(const
RelayCall& call) {
return GetInputTypes(optype, call->args.size(), false);
}
+const String ExprUtils::GetSpanName(const Expr& expr, const String& suffix) {
+ const auto& name = SpanUtils::GetAttr(expr->span, msc_attr::kName);
+ if (suffix.size() > 0) {
+ return name + "_" + suffix;
+ }
+ return name;
+}
+
+const Array<PrimExpr> ExprUtils::GetShape(const Expr& expr) {
+ const auto& shape_opt =
Downcast<relax::TensorStructInfo>(relax::GetStructInfo(expr))->GetShape();
+ ICHECK(shape_opt.defined()) << "Shape is not defined for " << expr;
+ return shape_opt.value();
+}
+
+const DataType ExprUtils::GetDataType(const Expr& expr) {
+ return Downcast<relax::TensorStructInfo>(relax::GetStructInfo(expr))->dtype;
+}
+
TVM_REGISTER_GLOBAL("msc.core.SpanGetAttr").set_body_typed(SpanUtils::GetAttr);
TVM_REGISTER_GLOBAL("msc.core.SpanGetAttrs").set_body_typed(SpanUtils::GetAttrs);
TVM_REGISTER_GLOBAL("msc.core.SpanCreateWithAttr")
.set_body_typed([](const String& key, const String& value) -> Span {
- return SpanUtils::SetAttr(Span(), key, value);
+ return SpanUtils::CreateWithAttr(key, value);
});
TVM_REGISTER_GLOBAL("msc.core.SpanSetAttr")
diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h
index 6c39a8d0a1..d7758cc23d 100644
--- a/src/contrib/msc/core/utils.h
+++ b/src/contrib/msc/core/utils.h
@@ -26,6 +26,7 @@
#include <tvm/ir/source_map.h>
#include <tvm/relax/expr.h>
+#include <tvm/relax/struct_info.h>
#include <tvm/relay/expr.h>
#include <string>
@@ -175,13 +176,6 @@ class StringUtils {
* \return The String.
*/
TVM_DLL static const String ToString(const runtime::ObjectRef& obj);
-
- /*!
- * \brief Compare String arrays.
- * \return Whether two array are same.
- */
- TVM_DLL static bool CompareArrays(const Array<String>& left, const
Array<String>& right,
- int size = -1);
};
/*!
@@ -238,6 +232,10 @@ class ArrayUtils {
return new_array;
}
+ /*!
+ * \brief Compute the product of the given arrays.
+ * \return The product arrays.
+ */
template <typename T>
TVM_DLL static const Array<Array<T>> Product(const Array<Array<T>>& arrays) {
Array<Array<T>> p_arrays;
@@ -260,6 +258,24 @@ class ArrayUtils {
}
return p_arrays;
}
+
+ /*!
+ * \brief Compare String arrays.
+ * \return Whether the two arrays are the same.
+ */
+ TVM_DLL static bool CompareArrays(const Array<String>& left, const
Array<String>& right,
+ int size = -1);
+ /*!
+ * \brief Multiply the first pos elements of the array (negative pos counts from the end, so -1 means all elements).
+ * \return The accumulated product.
+ */
+ TVM_DLL static PrimExpr Accumulate(const Array<PrimExpr>& array, int pos =
-1);
+
+ /*!
+ * \brief Check if lhs shape is broadcastable to rhs (both must have the same rank).
+ * \return Whether lhs is broadcastable to rhs.
+ */
+ TVM_DLL static bool Broadcastable(const Array<PrimExpr>& lhs, const
Array<PrimExpr>& rhs);
};
/*!
@@ -284,6 +300,12 @@ class SpanUtils {
* \return The Attrs Map.
*/
TVM_DLL static const Map<String, String> GetAttrs(const Span& span);
+
+ /*!
+ * \brief Create a span with <key>value</key>.
+ * \return The created Span.
+ */
+ TVM_DLL static const Span CreateWithAttr(const String& key, const String&
value);
};
/*!
@@ -365,6 +387,24 @@ class ExprUtils {
TVM_DLL static const T GetScalar(const relay::Constant& constant, size_t i =
0) {
return GetScalar<T>(constant->data, i);
}
+
+ /*!
+ * \brief Get name in span.
+ * \return The name.
+ */
+ TVM_DLL static const String GetSpanName(const Expr& expr, const String&
suffix = "");
+
+ /*!
+ * \brief Get shape of expr.
+ * \return The shape.
+ */
+ TVM_DLL static const Array<PrimExpr> GetShape(const Expr& expr);
+
+ /*!
+ * \brief Get dtype of expr.
+ * \return The dtype.
+ */
+ TVM_DLL static const DataType GetDataType(const Expr& expr);
};
} // namespace msc
diff --git a/tests/python/contrib/test_msc/test_graph_build.py
b/tests/python/contrib/test_msc/test_graph_build.py
index 069ffff53b..d027672082 100644
--- a/tests/python/contrib/test_msc/test_graph_build.py
+++ b/tests/python/contrib/test_msc/test_graph_build.py
@@ -17,8 +17,6 @@
""" Test graph builder && graph. """
-import pytest
-
import torch
from torch import fx
from torch.nn import Module
@@ -1101,7 +1099,6 @@ def test_getattr():
verify_model(GetAttr1(), input_info, expected)
[email protected](reason="MSC does not support Tuple of PrimValue")
def test_getitem():
"""test graph builder for getitem"""
diff --git a/tests/python/contrib/test_msc/test_translate_relax.py
b/tests/python/contrib/test_msc/test_translate_relax.py
index e8b7149a68..66aa90a625 100644
--- a/tests/python/contrib/test_msc/test_translate_relax.py
+++ b/tests/python/contrib/test_msc/test_translate_relax.py
@@ -17,8 +17,6 @@
""" Test translate from relax. """
-import pytest
-
import torch
from torch import fx
from torch.nn import Module
@@ -57,7 +55,6 @@ def _verify_model(torch_model, input_info, opt_config=None):
relax_exec = tvm.relax.build(relax_mod, target)
vm_runner = tvm.relax.VirtualMachine(relax_exec, dev)
res = vm_runner["main"](*args)
-
return _tvm_runtime_to_np(res)
rt_mod = tvm_codegen.to_relax(
@@ -629,7 +626,6 @@ def test_getattr():
_verify_model(GetAttr1(), input_info)
[email protected](reason="MSC does not support Tuple of PrimValue")
def test_getitem():
"""test relax translator for getitem"""
diff --git a/tests/python/contrib/test_msc/test_translate_tensorflow.py
b/tests/python/contrib/test_msc/test_translate_tensorflow.py
index 61f8ce1a97..cb4ea3c02e 100644
--- a/tests/python/contrib/test_msc/test_translate_tensorflow.py
+++ b/tests/python/contrib/test_msc/test_translate_tensorflow.py
@@ -18,8 +18,6 @@
""" Test translate from tensorflow. """
-import pytest
-
from packaging import version as package_version
import numpy as np
@@ -504,7 +502,6 @@ def _test_stridedslice(
verify_model(graph_def, golden, **io_info)
[email protected](reason="MSC does not support Tuple of PrimValue")
def test_stridedslice():
"""test tensorflow translator for stridedslice"""
@@ -1065,7 +1062,6 @@ def _test_slice_operation_input(input_value, begin_value,
size_value):
verify_model(graph_def, golden, **io_info)
[email protected](reason="MSC does not support Tuple of PrimValue")
def test_slice():
"""test tensorflow translator for slice"""
diff --git a/tests/python/contrib/test_msc/test_translate_torch.py
b/tests/python/contrib/test_msc/test_translate_torch.py
index 60dcbb293a..f3e01493d9 100644
--- a/tests/python/contrib/test_msc/test_translate_torch.py
+++ b/tests/python/contrib/test_msc/test_translate_torch.py
@@ -17,8 +17,6 @@
""" Test translate from torch. """
-import pytest
-
import numpy as np
import torch
@@ -589,7 +587,6 @@ def test_getattr():
verify_model(GetAttr1(), input_info)
[email protected](reason="MSC does not support Tuple of PrimValue")
def test_getitem():
"""test torch translator for getitem"""