This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 8a2ca3405b [Unity] Use PrimValue as offset in R.tril and R.triu
(#15783)
8a2ca3405b is described below
commit 8a2ca3405baef5ea2d33b3f9da1c6344a8bdae15
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Oct 5 14:56:27 2023 -0500
[Unity] Use PrimValue as offset in R.tril and R.triu (#15783)
* [Unity] Use PrimValue as offset in R.tril and R.triu
This mirrors the support in `topi`, which supports a `PrimExpr` as the
offset of the diagonal.
* Update implementation to avoid
I believe the `-Wsequence-point` raised by gcc is spurious, as the
`index++` occurs within a braced-initialization list, which has a
defined left-to-right execution order. However, better to avoid the
warning altogether.
* Updated attr usage to args
* Correct relax op names in msc
* Parametrize failing MSC unit tests, mark with xfail
* Lint fix
* Marked relay to relax tests as known failures
---
python/tvm/relax/op/_op_gradient.py | 2 +-
python/tvm/relax/op/create.py | 10 +++-
python/tvm/relax/transform/legalize_ops/create.py | 5 +-
src/contrib/msc/core/utils.cc | 5 +-
src/relax/op/op_common.h | 60 ++++++++++++++++++++++
src/relax/op/tensor/create.cc | 31 ++++++-----
src/relax/op/tensor/create.h | 12 +++++
.../contrib/test_msc/test_translate_relay.py | 8 +++
.../contrib/test_msc/test_translate_torch.py | 39 +++++++++++---
tests/python/relax/test_op_create.py | 8 +++
10 files changed, 150 insertions(+), 30 deletions(-)
diff --git a/python/tvm/relax/op/_op_gradient.py
b/python/tvm/relax/op/_op_gradient.py
index 17bafe0a37..2873c70ba7 100644
--- a/python/tvm/relax/op/_op_gradient.py
+++ b/python/tvm/relax/op/_op_gradient.py
@@ -433,7 +433,7 @@ def triu_grad(
Backward:
Returns `[triu(y_grad, k)]`.
"""
- k = orig_call.attrs.k
+ k = orig_call.args[1]
return [triu(output_grad, k)]
diff --git a/python/tvm/relax/op/create.py b/python/tvm/relax/op/create.py
index eb981efdb5..8fd3b2cde1 100644
--- a/python/tvm/relax/op/create.py
+++ b/python/tvm/relax/op/create.py
@@ -215,7 +215,7 @@ def arange(
return _ffi_api.arange(start, end, step, dtype) # type: ignore
-def tril(x: Expr, k: int = 0) -> Expr:
+def tril(x: Expr, k: Union[int, PrimExpr, Expr] = 0) -> Expr:
"""Return the lower triangular part of a matrix or a batch of matrices.
Parameters
@@ -235,10 +235,13 @@ def tril(x: Expr, k: int = 0) -> Expr:
ret : relax.Expr
The result tensor.
"""
+ if not isinstance(k, Expr):
+ k = PrimValue(k)
+
return _ffi_api.tril(x, k) # type: ignore
-def triu(x: Expr, k: int = 0) -> Expr:
+def triu(x: Expr, k: [int, PrimExpr, Expr] = 0) -> Expr:
"""Return the upper triangular part of a matrix or a batch of matrices.
Parameters
@@ -258,4 +261,7 @@ def triu(x: Expr, k: int = 0) -> Expr:
ret : relax.Expr
The result tensor.
"""
+ if not isinstance(k, Expr):
+ k = PrimValue(k)
+
return _ffi_api.triu(x, k) # type: ignore
diff --git a/python/tvm/relax/transform/legalize_ops/create.py
b/python/tvm/relax/transform/legalize_ops/create.py
index 972dac5501..1b022672d0 100644
--- a/python/tvm/relax/transform/legalize_ops/create.py
+++ b/python/tvm/relax/transform/legalize_ops/create.py
@@ -48,10 +48,11 @@ def _full(is_like: bool, fill_value: Optional[float],
primfunc_name: str) -> Leg
def _tril_triu(is_upper: bool, primfunc_name: str) -> LegalizeFunc:
def tril_triu_call_te(bb: BlockBuilder, call: Call) -> Expr:
+ data, k = call.args
return bb.call_te(
topi.trilu,
- call.args[0],
- tir.const(call.attrs.k, "int32"),
+ data,
+ k,
upper=is_upper,
primfunc_name_hint=primfunc_name,
)
diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc
index 908c4d9507..f7ec1f1cbf 100644
--- a/src/contrib/msc/core/utils.cc
+++ b/src/contrib/msc/core/utils.cc
@@ -236,7 +236,10 @@ const Array<String> ExprUtils::GetInputTypes(const String&
optype, size_t inputs
} else if (optype == "full" && as_relax) {
input_types.push_back("shape");
input_types.push_back("input");
- } else if (optype == "trilu") {
+ } else if (optype == "triu") {
+ input_types.push_back("input");
+ input_types.push_back("k");
+ } else if (optype == "tril") {
input_types.push_back("input");
input_types.push_back("k");
} else if (optype == "image.resize2d" && as_relax) {
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index dd4c3ac173..290cdef0d5 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -31,6 +31,7 @@
#include <tvm/relay/op.h>
#include <tvm/tir/data_layout.h>
+#include <tuple>
#include <utility>
#include <vector>
@@ -75,6 +76,65 @@ inline TensorStructInfo GetUnaryInputTensorStructInfo(const
Call& call, const Bl
Array<TensorStructInfo> GetTensorStructInfoFromTuple(const Call& call, const
BlockBuilder& ctx,
const Expr& tup);
+namespace detail {
+/*! \brief Implementation helper for GetArgStructInfo */
+template <typename ArgType>
+ArgType GetArgStructInfoByIndex(const Call& call, const Op& op, const
BlockBuilder& ctx,
+ size_t index) {
+ if (!call->args[index]->struct_info_.defined()) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << op << " op should have arguments with defined
StructInfo. "
+ << "However, args[" << index << "] has undefined struct
info.");
+ }
+
+ auto sinfo = GetStructInfo(call->args[index]);
+ auto typed_sinfo = sinfo.as<ArgType>();
+
+ if (!typed_sinfo.defined()) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << op << " requires that args[" << index << "] be a "
+ << ArgType::ContainerType::_type_key << ", but was
instead " << sinfo
+ << " of type " << sinfo->GetTypeKey());
+ }
+
+ return typed_sinfo.value();
+}
+
+/*! \brief Implementation helper for GetArgStructInfo */
+template <typename... ArgTypes, size_t... Indices>
+std::tuple<ArgTypes...> GetArgStructInfoHelper(const Call& call, const Op& op,
+ const BlockBuilder& ctx,
+
std::index_sequence<Indices...>) {
+ return std::tuple<ArgTypes...>{GetArgStructInfoByIndex<ArgTypes>(call, op,
ctx, Indices)...};
+}
+} // namespace detail
+
+/*!
+ * \brief Get all arg struct infos as expected types
+ *
+ * \tparam ArgTypes The expected types of arguments, in the order they appear.
+ * \param call The context Call to the operator.
+ * \param ctx The error reporting context.
+ * \return The tensor struct infos of tuple input.
+ * \throw Throw exception if input expression is not a tuple.
+ */
+template <typename... ArgTypes>
+std::tuple<ArgTypes...> GetArgStructInfo(const Call& call, const BlockBuilder&
ctx) {
+ Op op = Downcast<Op>(call->op);
+ size_t n_input = op->arguments.size();
+
+ // Unfortunately, because the `.add_argument()` calls in
+ // TVM_REGISTER_OP occur during initialization of globals and are
+ // not available at compile-time, this cannot be a static_assert.
+ ICHECK_EQ(n_input, sizeof...(ArgTypes))
+ << "Internal error: " << op << " op defines " << n_input
+ << " arguments in its TVM_REGISTER_OP() call, "
+ << "but GetArgStructInfo was given " << sizeof...(ArgTypes) << "
template arguments.";
+
+ return detail::GetArgStructInfoHelper<ArgTypes...>(
+ call, op, ctx, std::make_index_sequence<sizeof...(ArgTypes)>());
+}
+
/************ Op registration macro ************/
/*!
diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc
index 3a4de79b11..f5893d64b1 100644
--- a/src/relax/op/tensor/create.cc
+++ b/src/relax/op/tensor/create.cc
@@ -286,27 +286,26 @@ TVM_REGISTER_OP("relax.arange")
/* relax.tril & relax.triu */
TVM_REGISTER_NODE_TYPE(TriluAttrs);
-Expr tril(Expr x, int k) {
- ObjectPtr<TriluAttrs> attrs = make_object<TriluAttrs>();
- attrs->k = k;
-
+Expr tril(Expr x, Expr k) {
static const Op& op = Op::Get("relax.tril");
- return Call(op, {std::move(x)}, Attrs(attrs), {});
+ return Call(op, {x, k});
}
-Expr triu(Expr x, int k) {
- ObjectPtr<TriluAttrs> attrs = make_object<TriluAttrs>();
- attrs->k = k;
+Expr tril(Expr x, int k) { return tril(x, relax::PrimValue::Int64(k)); }
+Expr triu(Expr x, Expr k) {
static const Op& op = Op::Get("relax.triu");
- return Call(op, {std::move(x)}, Attrs(attrs), {});
+ return Call(op, {x, k});
}
-TVM_REGISTER_GLOBAL("relax.op.tril").set_body_typed(tril);
-TVM_REGISTER_GLOBAL("relax.op.triu").set_body_typed(triu);
+Expr triu(Expr x, int k) { return triu(x, relax::PrimValue::Int64(k)); }
+
+TVM_REGISTER_GLOBAL("relax.op.tril").set_body_typed(static_cast<Expr (*)(Expr,
Expr)>(tril));
+TVM_REGISTER_GLOBAL("relax.op.triu").set_body_typed(static_cast<Expr (*)(Expr,
Expr)>(triu));
StructInfo InferStructInfoTrilTriu(const Call& call, const BlockBuilder& ctx) {
- TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+ auto [data_sinfo, offset] = GetArgStructInfo<TensorStructInfo,
PrimStructInfo>(call, ctx);
+
if (!data_sinfo->IsUnknownNdim() && data_sinfo->ndim < 2) {
ctx->ReportFatal(Diagnostic::Error(call) << call->op
<< " requires the input tensor to
have at least two "
@@ -317,16 +316,16 @@ StructInfo InferStructInfoTrilTriu(const Call& call,
const BlockBuilder& ctx) {
}
TVM_REGISTER_OP("relax.tril")
- .set_attrs_type<TriluAttrs>()
- .set_num_inputs(1)
+ .set_num_inputs(2)
.add_argument("x", "Tensor", "The input tensor.")
+ .add_argument("k", "PrimValue", "The offset of the diagonal.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTrilTriu)
.set_attr<Bool>("FPurity", Bool(true));
TVM_REGISTER_OP("relax.triu")
- .set_attrs_type<TriluAttrs>()
- .set_num_inputs(1)
+ .set_num_inputs(2)
.add_argument("x", "Tensor", "The input tensor.")
+ .add_argument("k", "PrimValue", "The offset of the diagonal.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTrilTriu)
.set_attr<Bool>("FPurity", Bool(true));
diff --git a/src/relax/op/tensor/create.h b/src/relax/op/tensor/create.h
index 40a46abe35..989eaa12fd 100644
--- a/src/relax/op/tensor/create.h
+++ b/src/relax/op/tensor/create.h
@@ -82,9 +82,21 @@ Expr zeros_like(Expr x, DataType dtype);
Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype);
/*! \brief Return the lower triangular part of a matrix or a batch of
matrices. */
+Expr tril(Expr x, Expr k);
+
+/*! \brief Return the lower triangular part of a matrix or a batch of matrices.
+ *
+ * Overload provided for backwards compatibility.
+ */
Expr tril(Expr x, int k);
/*! \brief Return the upper triangular part of a matrix or a batch of
matrices. */
+Expr triu(Expr x, Expr k);
+
+/*! \brief Return the upper triangular part of a matrix or a batch of matrices.
+ *
+ * Overload provided for backwards compatibility.
+ */
Expr triu(Expr x, int k);
} // namespace relax
diff --git a/tests/python/contrib/test_msc/test_translate_relay.py
b/tests/python/contrib/test_msc/test_translate_relay.py
index 70f5e8cda5..f543c8c292 100644
--- a/tests/python/contrib/test_msc/test_translate_relay.py
+++ b/tests/python/contrib/test_msc/test_translate_relay.py
@@ -19,6 +19,8 @@
""" Test translate from relay. """
import numpy as np
+import pytest
+
import torch
from torch import fx
from torch.nn import Module
@@ -805,6 +807,9 @@ def test_tensor():
verify_model(Empty2(), [([10, 10], "float32")], build_target="llvm")
[email protected](
+ reason="Failure to convert from R.PrimValue argument in
msc/framework/tvm/codegen.cc"
+)
def test_tril():
"""test relay to relax for tril"""
@@ -822,6 +827,9 @@ def test_tril():
verify_model(InplaceTril(), input_info)
[email protected](
+ reason="Failure to convert from R.PrimValue argument in
msc/framework/tvm/codegen.cc"
+)
def test_triu():
"""test relay to relax for triu"""
diff --git a/tests/python/contrib/test_msc/test_translate_torch.py
b/tests/python/contrib/test_msc/test_translate_torch.py
index 6b9a7c9332..ae377e52cb 100644
--- a/tests/python/contrib/test_msc/test_translate_torch.py
+++ b/tests/python/contrib/test_msc/test_translate_torch.py
@@ -18,6 +18,7 @@
""" Test translate from torch. """
import numpy as np
+import pytest
import torch
from torch.nn import Module
@@ -781,40 +782,62 @@ def test_arange():
verify_model(Arange(), [([10, 10], "float32")])
-def test_tril():
+# pylint: disable=redefined-outer-name
+via_relax_param = tvm.testing.parameter(
+ True,
+ pytest.param(
+ False,
+ marks=pytest.mark.xfail(
+ reason="Failure to convert from R.PrimValue argument in
msc/framework/tvm/codegen.cc"
+ ),
+ ),
+)
+
+
+def test_tril(via_relax_param):
"""test torch translator for tril"""
class Tril(Module):
def forward(self, data):
return torch.tril(data, 1)
+ input_info = [([10, 10], "float32")]
+ verify_model(Tril(), input_info, via_relax_param)
+
+
+def test_tril_inplace(via_relax_param):
+ """test torch translator for tril"""
+
class InplaceTril(Module):
def forward(self, data):
data.tril_(1)
return data
input_info = [([10, 10], "float32")]
- for via_relax in [True, False]:
- verify_model(Tril(), input_info, via_relax)
- verify_model(InplaceTril(), input_info, via_relax)
+ verify_model(InplaceTril(), input_info, via_relax_param)
-def test_triu():
+def test_triu(via_relax_param):
"""test torch translator for triu"""
class Triu(Module):
def forward(self, data):
return torch.triu(data, 1)
+ input_info = [([10, 10], "float32")]
+ verify_model(Triu(), input_info, via_relax_param)
+
+
+def test_triu_inplace(via_relax_param):
+ """test torch translator for triu"""
+
class InplaceTriu(Module):
def forward(self, data):
data.triu_(1)
return data
input_info = [([10, 10], "float32")]
- for via_relax in [True, False]:
- verify_model(Triu(), input_info, via_relax)
- verify_model(InplaceTriu(), input_info, via_relax)
+ verify_model(InplaceTriu(), input_info, via_relax_param)
def test_new_ones():
diff --git a/tests/python/relax/test_op_create.py
b/tests/python/relax/test_op_create.py
index c0b4308529..1e895169f6 100644
--- a/tests/python/relax/test_op_create.py
+++ b/tests/python/relax/test_op_create.py
@@ -643,11 +643,19 @@ def test_tril_triu_infer_struct_info_shape_symbolic():
x0 = relax.Var("x", R.Tensor((a, b, c), "float32"))
x1 = relax.Var("x", R.Tensor((a, b, c)))
x2 = relax.Var("x", R.Tensor((a, b, c), "float32", vdev0))
+ x3 = relax.Var("x", R.Tensor((16, 32, 64)))
+ # Dynamic tensor, static offset
_check_inference(bb, relax.op.tril(x0), relax.TensorStructInfo((a, b, c),
"float32"))
_check_inference(bb, relax.op.triu(x1), relax.TensorStructInfo((a, b, c),
dtype=""))
_check_inference(bb, relax.op.tril(x2), relax.TensorStructInfo((a, b, c),
"float32", vdev0))
+ # Static tensor, dynamic offset
+ _check_inference(bb, relax.op.tril(x3, a), relax.TensorStructInfo((16, 32,
64), dtype=""))
+
+ # Dynamic tensor, dynamic offset
+ _check_inference(bb, relax.op.tril(x0, a), relax.TensorStructInfo((a, b,
c), "float32"))
+
def test_tril_triu_infer_struct_info_shape_var():
bb = relax.BlockBuilder()