This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 8a2ca3405b [Unity] Use PrimValue as offset in R.tril and R.triu 
(#15783)
8a2ca3405b is described below

commit 8a2ca3405baef5ea2d33b3f9da1c6344a8bdae15
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Oct 5 14:56:27 2023 -0500

    [Unity] Use PrimValue as offset in R.tril and R.triu (#15783)
    
    * [Unity] Use PrimValue as offset in R.tril and R.triu
    
    This mirrors the support in `topi`, which supports a `PrimExpr` as the
    offset of the diagonal.
    
    * Update implementation to avoid the gcc `-Wsequence-point` warning
    
    I believe the `-Wsequence-point` raised by gcc is spurious, as the
    `index++` occurs within a braced-initialization list, which has a
    defined left-to-right execution order.  However, better to avoid the
    warning altogether.
    
    * Updated attr usage to args
    
    * Correct relax op names in msc
    
    * Parametrize failing MSC unit tests, mark with xfail
    
    * Lint fix
    
    * Marked relay to relax tests as known failures
---
 python/tvm/relax/op/_op_gradient.py                |  2 +-
 python/tvm/relax/op/create.py                      | 10 +++-
 python/tvm/relax/transform/legalize_ops/create.py  |  5 +-
 src/contrib/msc/core/utils.cc                      |  5 +-
 src/relax/op/op_common.h                           | 60 ++++++++++++++++++++++
 src/relax/op/tensor/create.cc                      | 31 ++++++-----
 src/relax/op/tensor/create.h                       | 12 +++++
 .../contrib/test_msc/test_translate_relay.py       |  8 +++
 .../contrib/test_msc/test_translate_torch.py       | 39 +++++++++++---
 tests/python/relax/test_op_create.py               |  8 +++
 10 files changed, 150 insertions(+), 30 deletions(-)

diff --git a/python/tvm/relax/op/_op_gradient.py 
b/python/tvm/relax/op/_op_gradient.py
index 17bafe0a37..2873c70ba7 100644
--- a/python/tvm/relax/op/_op_gradient.py
+++ b/python/tvm/relax/op/_op_gradient.py
@@ -433,7 +433,7 @@ def triu_grad(
     Backward:
         Returns `[triu(y_grad, k)]`.
     """
-    k = orig_call.attrs.k
+    k = orig_call.args[1]
     return [triu(output_grad, k)]
 
 
diff --git a/python/tvm/relax/op/create.py b/python/tvm/relax/op/create.py
index eb981efdb5..8fd3b2cde1 100644
--- a/python/tvm/relax/op/create.py
+++ b/python/tvm/relax/op/create.py
@@ -215,7 +215,7 @@ def arange(
     return _ffi_api.arange(start, end, step, dtype)  # type: ignore
 
 
-def tril(x: Expr, k: int = 0) -> Expr:
+def tril(x: Expr, k: Union[int, PrimExpr, Expr] = 0) -> Expr:
     """Return the lower triangular part of a matrix or a batch of matrices.
 
     Parameters
@@ -235,10 +235,13 @@ def tril(x: Expr, k: int = 0) -> Expr:
     ret : relax.Expr
         The result tensor.
     """
+    if not isinstance(k, Expr):
+        k = PrimValue(k)
+
     return _ffi_api.tril(x, k)  # type: ignore
 
 
-def triu(x: Expr, k: int = 0) -> Expr:
+def triu(x: Expr, k: [int, PrimExpr, Expr] = 0) -> Expr:
     """Return the upper triangular part of a matrix or a batch of matrices.
 
     Parameters
@@ -258,4 +261,7 @@ def triu(x: Expr, k: int = 0) -> Expr:
     ret : relax.Expr
         The result tensor.
     """
+    if not isinstance(k, Expr):
+        k = PrimValue(k)
+
     return _ffi_api.triu(x, k)  # type: ignore
diff --git a/python/tvm/relax/transform/legalize_ops/create.py 
b/python/tvm/relax/transform/legalize_ops/create.py
index 972dac5501..1b022672d0 100644
--- a/python/tvm/relax/transform/legalize_ops/create.py
+++ b/python/tvm/relax/transform/legalize_ops/create.py
@@ -48,10 +48,11 @@ def _full(is_like: bool, fill_value: Optional[float], 
primfunc_name: str) -> Leg
 
 def _tril_triu(is_upper: bool, primfunc_name: str) -> LegalizeFunc:
     def tril_triu_call_te(bb: BlockBuilder, call: Call) -> Expr:
+        data, k = call.args
         return bb.call_te(
             topi.trilu,
-            call.args[0],
-            tir.const(call.attrs.k, "int32"),
+            data,
+            k,
             upper=is_upper,
             primfunc_name_hint=primfunc_name,
         )
diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc
index 908c4d9507..f7ec1f1cbf 100644
--- a/src/contrib/msc/core/utils.cc
+++ b/src/contrib/msc/core/utils.cc
@@ -236,7 +236,10 @@ const Array<String> ExprUtils::GetInputTypes(const String& 
optype, size_t inputs
   } else if (optype == "full" && as_relax) {
     input_types.push_back("shape");
     input_types.push_back("input");
-  } else if (optype == "trilu") {
+  } else if (optype == "triu") {
+    input_types.push_back("input");
+    input_types.push_back("k");
+  } else if (optype == "tril") {
     input_types.push_back("input");
     input_types.push_back("k");
   } else if (optype == "image.resize2d" && as_relax) {
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index dd4c3ac173..290cdef0d5 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -31,6 +31,7 @@
 #include <tvm/relay/op.h>
 #include <tvm/tir/data_layout.h>
 
+#include <tuple>
 #include <utility>
 #include <vector>
 
@@ -75,6 +76,65 @@ inline TensorStructInfo GetUnaryInputTensorStructInfo(const 
Call& call, const Bl
 Array<TensorStructInfo> GetTensorStructInfoFromTuple(const Call& call, const 
BlockBuilder& ctx,
                                                      const Expr& tup);
 
+namespace detail {
+/*! \brief Implementation helper for GetArgStructInfo */
+template <typename ArgType>
+ArgType GetArgStructInfoByIndex(const Call& call, const Op& op, const 
BlockBuilder& ctx,
+                                size_t index) {
+  if (!call->args[index]->struct_info_.defined()) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << op << " op should have arguments with defined 
StructInfo.  "
+                     << "However, args[" << index << "] has undefined struct 
info.");
+  }
+
+  auto sinfo = GetStructInfo(call->args[index]);
+  auto typed_sinfo = sinfo.as<ArgType>();
+
+  if (!typed_sinfo.defined()) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << op << " requires that args[" << index << "] be a "
+                     << ArgType::ContainerType::_type_key << ", but was 
instead " << sinfo
+                     << " of type " << sinfo->GetTypeKey());
+  }
+
+  return typed_sinfo.value();
+}
+
+/*! \brief Implementation helper for GetArgStructInfo */
+template <typename... ArgTypes, size_t... Indices>
+std::tuple<ArgTypes...> GetArgStructInfoHelper(const Call& call, const Op& op,
+                                               const BlockBuilder& ctx,
+                                               
std::index_sequence<Indices...>) {
+  return std::tuple<ArgTypes...>{GetArgStructInfoByIndex<ArgTypes>(call, op, 
ctx, Indices)...};
+}
+}  // namespace detail
+
+/*!
+ * \brief Get all arg struct infos as expected types
+ *
+ * \tparam ArgTypes The expected types of arguments, in the order they appear.
+ * \param call The context Call to the operator.
+ * \param ctx The error reporting context.
+ * \return A tuple holding each argument's struct info, cast to the expected type.
+ * \throw Reports a fatal error via ctx if an argument's StructInfo is undefined or mistyped.
+ */
+template <typename... ArgTypes>
+std::tuple<ArgTypes...> GetArgStructInfo(const Call& call, const BlockBuilder& 
ctx) {
+  Op op = Downcast<Op>(call->op);
+  size_t n_input = op->arguments.size();
+
+  // Unfortunately, because the `.add_argument()` calls in
+  // TVM_REGISTER_OP occur during initialization of globals and are
+  // not available at compile-time, this cannot be a static_assert.
+  ICHECK_EQ(n_input, sizeof...(ArgTypes))
+      << "Internal error: " << op << " op defines " << n_input
+      << " arguments in its TVM_REGISTER_OP() call, "
+      << "but GetArgStructInfo was given " << sizeof...(ArgTypes) << " 
template arguments.";
+
+  return detail::GetArgStructInfoHelper<ArgTypes...>(
+      call, op, ctx, std::make_index_sequence<sizeof...(ArgTypes)>());
+}
+
 /************ Op registration macro ************/
 
 /*!
diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc
index 3a4de79b11..f5893d64b1 100644
--- a/src/relax/op/tensor/create.cc
+++ b/src/relax/op/tensor/create.cc
@@ -286,27 +286,26 @@ TVM_REGISTER_OP("relax.arange")
 /* relax.tril & relax.triu */
 TVM_REGISTER_NODE_TYPE(TriluAttrs);
 
-Expr tril(Expr x, int k) {
-  ObjectPtr<TriluAttrs> attrs = make_object<TriluAttrs>();
-  attrs->k = k;
-
+Expr tril(Expr x, Expr k) {
   static const Op& op = Op::Get("relax.tril");
-  return Call(op, {std::move(x)}, Attrs(attrs), {});
+  return Call(op, {x, k});
 }
 
-Expr triu(Expr x, int k) {
-  ObjectPtr<TriluAttrs> attrs = make_object<TriluAttrs>();
-  attrs->k = k;
+Expr tril(Expr x, int k) { return tril(x, relax::PrimValue::Int64(k)); }
 
+Expr triu(Expr x, Expr k) {
   static const Op& op = Op::Get("relax.triu");
-  return Call(op, {std::move(x)}, Attrs(attrs), {});
+  return Call(op, {x, k});
 }
 
-TVM_REGISTER_GLOBAL("relax.op.tril").set_body_typed(tril);
-TVM_REGISTER_GLOBAL("relax.op.triu").set_body_typed(triu);
+Expr triu(Expr x, int k) { return triu(x, relax::PrimValue::Int64(k)); }
+
+TVM_REGISTER_GLOBAL("relax.op.tril").set_body_typed(static_cast<Expr (*)(Expr, 
Expr)>(tril));
+TVM_REGISTER_GLOBAL("relax.op.triu").set_body_typed(static_cast<Expr (*)(Expr, 
Expr)>(triu));
 
 StructInfo InferStructInfoTrilTriu(const Call& call, const BlockBuilder& ctx) {
-  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  auto [data_sinfo, offset] = GetArgStructInfo<TensorStructInfo, 
PrimStructInfo>(call, ctx);
+
   if (!data_sinfo->IsUnknownNdim() && data_sinfo->ndim < 2) {
     ctx->ReportFatal(Diagnostic::Error(call) << call->op
                                              << " requires the input tensor to 
have at least two "
@@ -317,16 +316,16 @@ StructInfo InferStructInfoTrilTriu(const Call& call, 
const BlockBuilder& ctx) {
 }
 
 TVM_REGISTER_OP("relax.tril")
-    .set_attrs_type<TriluAttrs>()
-    .set_num_inputs(1)
+    .set_num_inputs(2)
     .add_argument("x", "Tensor", "The input tensor.")
+    .add_argument("k", "PrimValue", "The offset of the diagonal.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTrilTriu)
     .set_attr<Bool>("FPurity", Bool(true));
 
 TVM_REGISTER_OP("relax.triu")
-    .set_attrs_type<TriluAttrs>()
-    .set_num_inputs(1)
+    .set_num_inputs(2)
     .add_argument("x", "Tensor", "The input tensor.")
+    .add_argument("k", "PrimValue", "The offset of the diagonal.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTrilTriu)
     .set_attr<Bool>("FPurity", Bool(true));
 
diff --git a/src/relax/op/tensor/create.h b/src/relax/op/tensor/create.h
index 40a46abe35..989eaa12fd 100644
--- a/src/relax/op/tensor/create.h
+++ b/src/relax/op/tensor/create.h
@@ -82,9 +82,21 @@ Expr zeros_like(Expr x, DataType dtype);
 Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype);
 
 /*! \brief Return the lower triangular part of a matrix or a batch of 
matrices. */
+Expr tril(Expr x, Expr k);
+
+/*! \brief Return the lower triangular part of a matrix or a batch of matrices.
+ *
+ * Overload provided for backwards compatibility.
+ */
 Expr tril(Expr x, int k);
 
 /*! \brief Return the upper triangular part of a matrix or a batch of 
matrices. */
+Expr triu(Expr x, Expr k);
+
+/*! \brief Return the upper triangular part of a matrix or a batch of matrices.
+ *
+ * Overload provided for backwards compatibility.
+ */
 Expr triu(Expr x, int k);
 
 }  // namespace relax
diff --git a/tests/python/contrib/test_msc/test_translate_relay.py 
b/tests/python/contrib/test_msc/test_translate_relay.py
index 70f5e8cda5..f543c8c292 100644
--- a/tests/python/contrib/test_msc/test_translate_relay.py
+++ b/tests/python/contrib/test_msc/test_translate_relay.py
@@ -19,6 +19,8 @@
 """ Test translate from relay. """
 
 import numpy as np
+import pytest
+
 import torch
 from torch import fx
 from torch.nn import Module
@@ -805,6 +807,9 @@ def test_tensor():
     verify_model(Empty2(), [([10, 10], "float32")], build_target="llvm")
 
 
+@pytest.mark.xfail(
+    reason="Failure to convert from R.PrimValue argument in 
msc/framework/tvm/codegen.cc"
+)
 def test_tril():
     """test relay to relax for tril"""
 
@@ -822,6 +827,9 @@ def test_tril():
     verify_model(InplaceTril(), input_info)
 
 
+@pytest.mark.xfail(
+    reason="Failure to convert from R.PrimValue argument in 
msc/framework/tvm/codegen.cc"
+)
 def test_triu():
     """test relay to relax for triu"""
 
diff --git a/tests/python/contrib/test_msc/test_translate_torch.py 
b/tests/python/contrib/test_msc/test_translate_torch.py
index 6b9a7c9332..ae377e52cb 100644
--- a/tests/python/contrib/test_msc/test_translate_torch.py
+++ b/tests/python/contrib/test_msc/test_translate_torch.py
@@ -18,6 +18,7 @@
 """ Test translate from torch. """
 
 import numpy as np
+import pytest
 
 import torch
 from torch.nn import Module
@@ -781,40 +782,62 @@ def test_arange():
     verify_model(Arange(), [([10, 10], "float32")])
 
 
-def test_tril():
+# pylint: disable=redefined-outer-name
+via_relax_param = tvm.testing.parameter(
+    True,
+    pytest.param(
+        False,
+        marks=pytest.mark.xfail(
+            reason="Failure to convert from R.PrimValue argument in 
msc/framework/tvm/codegen.cc"
+        ),
+    ),
+)
+
+
+def test_tril(via_relax_param):
     """test torch translator for tril"""
 
     class Tril(Module):
         def forward(self, data):
             return torch.tril(data, 1)
 
+    input_info = [([10, 10], "float32")]
+    verify_model(Tril(), input_info, via_relax_param)
+
+
+def test_tril_inplace(via_relax_param):
+    """test torch translator for tril"""
+
     class InplaceTril(Module):
         def forward(self, data):
             data.tril_(1)
             return data
 
     input_info = [([10, 10], "float32")]
-    for via_relax in [True, False]:
-        verify_model(Tril(), input_info, via_relax)
-        verify_model(InplaceTril(), input_info, via_relax)
+    verify_model(InplaceTril(), input_info, via_relax_param)
 
 
-def test_triu():
+def test_triu(via_relax_param):
     """test torch translator for triu"""
 
     class Triu(Module):
         def forward(self, data):
             return torch.triu(data, 1)
 
+    input_info = [([10, 10], "float32")]
+    verify_model(Triu(), input_info, via_relax_param)
+
+
+def test_triu_inplace(via_relax_param):
+    """test torch translator for triu"""
+
     class InplaceTriu(Module):
         def forward(self, data):
             data.triu_(1)
             return data
 
     input_info = [([10, 10], "float32")]
-    for via_relax in [True, False]:
-        verify_model(Triu(), input_info, via_relax)
-        verify_model(InplaceTriu(), input_info, via_relax)
+    verify_model(InplaceTriu(), input_info, via_relax_param)
 
 
 def test_new_ones():
diff --git a/tests/python/relax/test_op_create.py 
b/tests/python/relax/test_op_create.py
index c0b4308529..1e895169f6 100644
--- a/tests/python/relax/test_op_create.py
+++ b/tests/python/relax/test_op_create.py
@@ -643,11 +643,19 @@ def test_tril_triu_infer_struct_info_shape_symbolic():
     x0 = relax.Var("x", R.Tensor((a, b, c), "float32"))
     x1 = relax.Var("x", R.Tensor((a, b, c)))
     x2 = relax.Var("x", R.Tensor((a, b, c), "float32", vdev0))
+    x3 = relax.Var("x", R.Tensor((16, 32, 64)))
 
+    # Dynamic tensor, static offset
     _check_inference(bb, relax.op.tril(x0), relax.TensorStructInfo((a, b, c), 
"float32"))
     _check_inference(bb, relax.op.triu(x1), relax.TensorStructInfo((a, b, c), 
dtype=""))
     _check_inference(bb, relax.op.tril(x2), relax.TensorStructInfo((a, b, c), 
"float32", vdev0))
 
+    # Static tensor, dynamic offset
+    _check_inference(bb, relax.op.tril(x3, a), relax.TensorStructInfo((16, 32, 
64), dtype=""))
+
+    # Dynamic tensor, dynamic offset
+    _check_inference(bb, relax.op.tril(x0, a), relax.TensorStructInfo((a, b, 
c), "float32"))
+
 
 def test_tril_triu_infer_struct_info_shape_var():
     bb = relax.BlockBuilder()

Reply via email to