This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 225d067fb8 [Unity] Support Padding Reversal in Alter-Op pass (#15679)
225d067fb8 is described below

commit 225d067fb85e37723bbb3a5ca51f872c0e0f6abf
Author: rutkoor <120498024+rutk...@users.noreply.github.com>
AuthorDate: Thu Sep 28 23:19:01 2023 +0530

    [Unity] Support Padding Reversal in Alter-Op pass (#15679)
    
    * Support for padding Reversal in Alter-op pass
    
    * Removing lambda from remove_pad te.compute
    
    * Applying clang-format on .cc file
    
    * Removing extra line from manipulate.h file
---
 .../tvm/relax/transform/legalize_ops/manipulate.py | 10 ++-
 python/tvm/relax/transform/transform.py            |  2 +-
 src/relax/transform/alter_op_impl.cc               | 81 +++++++++++++++++++---
 src/te/operation/compute_op.cc                     |  6 +-
 tests/python/relax/test_transform_alter_op_impl.py | 80 +++++++++++++++++----
 5 files changed, 153 insertions(+), 26 deletions(-)

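For context (not part of the patch itself), a minimal usage sketch of the behavior this change enables, mirroring the updated test in tests/python/relax/test_transform_alter_op_impl.py further down; Before and relu_pad are the IRModule and replacement PrimFunc defined in that test, and the argument layout follows the AlterOpImpl call that the old test wrapped in pytest.raises:

    from tvm import relax

    # Index map that implicitly pads a 14-element buffer to 16 elements. AlterOpImpl
    # previously rejected such non-bijective transforms; with this change it inserts a
    # private "remove_pad" PrimFunc that copies the padded output back to its original shape.
    index_map = lambda i: (i % 16)
    operator_name = "relax.relu"

    After = relax.transform.AlterOpImpl(
        {operator_name: relu_pad},                # replacement PrimFunc (defined in the test)
        {operator_name: [index_map, index_map]},  # layout transforms for the input/output buffers
        {operator_name: None},                    # left as None, as in the original test
    )(Before)
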
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 4e06a0df39..e56240dc0d 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -182,7 +182,15 @@ def _layout_transform(bb: BlockBuilder, call: Call) -> Expr:
         )
 
     index_map: tvm.tir.IndexMap = call.attrs.index_map
-    pad_value = call.attrs.pad_value.value
+    pad_value = call.attrs.pad_value
+    if pad_value is not None:
+        pad_value = pad_value.value
+    else:
+        if "int" in call.args[0].struct_info.dtype:
+            pad_value = int(0)
+        else:
+            pad_value = float(0.0)
+
     axis_separators: tvm.tir.IndexMap.AXIS_SEPARATOR = call.attrs.axis_separators
     # Convert to list from array
     axis_separators = list(map(lambda x: x.value, axis_separators))
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 2a06d5098e..72a9966a4b 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -1034,7 +1034,7 @@ def AlterOpImpl(
         l = []
         for transform in transform_list:
             if isinstance(transform, Callable):
-                transform = IndexMap.from_func(transform)
+                transform = IndexMap.from_func_with_separators(transform)[0]
             l.append(transform)
         op_buffer_transforms[operator_name] = l
 
diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc
index 9813c4ed24..98d64dd7a8 100644
--- a/src/relax/transform/alter_op_impl.cc
+++ b/src/relax/transform/alter_op_impl.cc
@@ -30,7 +30,11 @@
 #include <tvm/relax/attrs/manipulate.h>
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/transform.h>
+#include <tvm/te/operation.h>
 #include <tvm/tir/transform.h>
+#include <tvm/topi/tags.h>
+
+#include "../../te/operation/create_primfunc.h"
 namespace tvm {
 namespace relax {
 
@@ -162,8 +166,18 @@ class AlterOpImplMutator : public ExprMutator {
     return arr_tensor_sinfo;
   }
 
+  bool IsScalarConstant(const Expr& expr) {
+    if (expr->IsInstance<ConstantNode>() && expr.as<ConstantNode>()->is_scalar()) {
+      return true;
+    }
+    return false;
+  }
+
   Expr TransformLayout(const Expr& expr, const IndexMap& index_map,
-                       const Array<IntImm> axis_separators) {
+                       const Array<IntImm>& axis_separators) {
+    if (IsScalarConstant(expr) || index_map.get() == nullptr) {
+      return expr;
+    }
     ObjectPtr<LayoutTransformAttrs> attrs = make_object<LayoutTransformAttrs>();
     // We want to avoid two layout_transform ops to share the same index map even if they are
     // identical. The scope of vars used in index map initial indices is local to the op. Not doing
@@ -173,19 +187,70 @@ class AlterOpImplMutator : public ExprMutator {
     return Call(layout_transform_op_, {expr}, Attrs{std::move(attrs)}, {});
   }
 
+  /*!
+   * \brief Adds the \p remove_pad op to the module if it has not already been added before.
+   * \returns The global var associated with the remove_pad PrimFunc.
+   */
+  GlobalVar GetOrCreateRemovePadOp(const Array<PrimExpr>& old_shape, const DataType& dtype) {
+    int t_shape = old_shape.size();
+    if (remove_pad_map_.count(t_shape) != 0) {
+      return remove_pad_map_[t_shape];
+    }
+    // Create dynamic shapes for input and output tensors
+    Array<PrimExpr> dyn_padded_shape, dyn_old_shape;
+    for (int i = 0; i < t_shape; i++) {
+      tir::Var var1("p" + std::to_string(i), old_shape[i].dtype());
+      tir::Var var2("i" + std::to_string(i), old_shape[i].dtype());
+      dyn_padded_shape.push_back(var1);
+      dyn_old_shape.push_back(var2);
+    }
+
+    // Input tensor of remove_pad op
+    te::Tensor placeholder_tensor = te::placeholder(dyn_padded_shape, dtype, "input");
+    // Output tensor of remove_pad op
+    te::Tensor output_tensor = te::compute(
+        dyn_old_shape,
+        [&placeholder_tensor](const Array<tir::Var>& indices) {
+          return placeholder_tensor(indices);
+        },
+        "output", topi::kElementWise);
+
+    String op_name = "remove_pad";
+    // Create PrimFunc and add op_name to func.attrs
+    PrimFunc remove_pad_with_frozen_layout =
+        WithAttr(CreatePrimFunc({placeholder_tensor, output_tensor}), kOperatorName, op_name);
+    // Add PrimFunc to module
+    GlobalVar gv_remove_pad = builder_->AddFunction(remove_pad_with_frozen_layout, op_name);
+    // Mark the remove_pad PrimFunc as private by removing it from global scope
+    builder_->UpdateFunction(gv_remove_pad,
+                             WithoutAttr(remove_pad_with_frozen_layout, "global_symbol"));
+
+    remove_pad_map_[t_shape] = gv_remove_pad;
+    return gv_remove_pad;
+  }
+
   Expr TransformLayoutInverse(const Expr& expr, const IndexMap& index_map,
                               const TensorStructInfo& old_tensor_sinfo,
                               const Array<IntImm>& axis_separator) {
+    if (IsScalarConstant(expr) || index_map.get() == nullptr) {
+      return expr;
+    }
     Array<PrimExpr> old_shape = GetShapeFromTensorStructInfo(old_tensor_sinfo);
     Array<Range> initial_ranges = ConstructRangeFromShape(old_shape);
     arith::Analyzer analyzer;
     auto [inverse_index_map, padding_predicate] =
         index_map.NonSurjectiveInverse(initial_ranges, &analyzer);
-    ICHECK(tir::is_zero(padding_predicate))
-        << "Only bijective transformations on input/output buffers are 
supported, but found "
-           "padding predicate "
-        << padding_predicate << " on initial range " << initial_ranges;
-    return TransformLayout(expr, inverse_index_map, axis_separator);
+
+    if (tir::is_zero(padding_predicate)) {
+      return TransformLayout(expr, inverse_index_map, axis_separator);
+    } else {
+      auto padded_expr =
+          builder_->Normalize(TransformLayout(expr, inverse_index_map, axis_separator));
+      const auto& tensor_sinfo = Downcast<TensorStructInfo>(padded_expr->struct_info_);
+
+      GlobalVar gv_remove_pad = GetOrCreateRemovePadOp(old_shape, tensor_sinfo->dtype);
+      return Call(call_tir_op_, {gv_remove_pad, Tuple({padded_expr})}, {}, {old_tensor_sinfo});
+    }
   }
 
   /*!
@@ -223,8 +288,6 @@ class AlterOpImplMutator : public ExprMutator {
         axis_separator = axis_separators_value[index];
       }
       auto transform = transforms[index++];
-      ICHECK(IsTransformBijective(input, transform))
-          << "Non bijective transforms on input and output buffers are not 
supported.";
       updated_inputs.push_back(TransformLayout(input, transform, axis_separator));
     }
     return Tuple(updated_inputs);
@@ -314,6 +377,8 @@ class AlterOpImplMutator : public ExprMutator {
   Map<PrimFunc, GlobalVar> cache_;
   /*! \brief Input IRModule */
   const IRModule& mod_;
+  /*! \brief Map from shape_dim.size to the remove_pad GlobalVar */
+  std::unordered_map<int, GlobalVar> remove_pad_map_;
   /*! \brief Map from kOperatorName attribute to the replacement PrimFunc */
   const Map<String, PrimFunc>& op_impl_map_;
   /*! \brief Map from kOperatorName attribute to the layout transforms on i/o buffers */
diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc
index 3ca40c9a6b..5797d2295b 100644
--- a/src/te/operation/compute_op.cc
+++ b/src/te/operation/compute_op.cc
@@ -98,7 +98,8 @@ Tensor compute(Array<PrimExpr> shape, FCompute fcompute, std::string name, std::
   for (size_t i = 0; i < ndim; ++i) {
     std::ostringstream os;
     os << "ax" << i;
-    axis.emplace_back(IterVar(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
+    axis.emplace_back(IterVar(Range(IntImm(shape[i]->dtype, 0), shape[i]),
+                              Var(os.str(), shape[i].dtype()), kDataPar));
     args.push_back(axis.back()->var);
   }
 
@@ -114,7 +115,8 @@ Array<Tensor> compute(Array<PrimExpr> shape, FBatchCompute fcompute, std::string
   for (size_t i = 0; i < ndim; ++i) {
     std::ostringstream os;
     os << "ax" << i;
-    axis.emplace_back(IterVar(Range(0, shape[i]), Var(os.str(), shape[i].dtype()), kDataPar));
+    axis.emplace_back(IterVar(Range(IntImm(shape[i]->dtype, 0), shape[i]),
+                              Var(os.str(), shape[i].dtype()), kDataPar));
     args.push_back(axis.back()->var);
   }
 
diff --git a/tests/python/relax/test_transform_alter_op_impl.py b/tests/python/relax/test_transform_alter_op_impl.py
index 81bc480785..3cbba9a031 100644
--- a/tests/python/relax/test_transform_alter_op_impl.py
+++ b/tests/python/relax/test_transform_alter_op_impl.py
@@ -303,13 +303,13 @@ def test_multiple_outputs_with_axis_sep():
     )
 
 
-def test_unsupported_implicit_padding():
+def test_supported_implicit_padding():
     @I.ir_module
-    class InputModule:
+    class Before:
         @R.function
         def foo(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32"):
             with R.dataflow():
-                lv = R.call_tir(InputModule.relu, (x,), out_sinfo=R.Tensor((14,), dtype="float32"))
+                lv = R.call_tir(Before.relu, (x,), out_sinfo=R.Tensor((14,), dtype="float32"))
                 gv: R.Tensor((14,), dtype="float32") = lv
                 R.output(gv)
             return gv
@@ -324,7 +324,62 @@ def test_unsupported_implicit_padding():
                     T.writes(output[v_ax0])
                     output[v_ax0] = T.max(arg0[v_ax0], T.float32(0))
 
-    before = InputModule
+    @I.ir_module
+    class Expected:
+        @R.function
+        def foo(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((16,), dtype="float32") = R.layout_transform(
+                    x,
+                    index_map=T.index_map(lambda i: (i % 16,)),
+                    pad_value=None,
+                    axis_separators=[],
+                )
+                lv1 = R.call_tir(
+                    Expected.relax_relu_replacement,
+                    (lv,),
+                    out_sinfo=R.Tensor((16,), dtype="float32"),
+                )
+                lv2: R.Tensor((16,), dtype="float32") = R.layout_transform(
+                    lv1,
+                    index_map=T.index_map(lambda axis0: (axis0,)),
+                    pad_value=None,
+                    axis_separators=[],
+                )
+                lv_1 = R.call_tir(
+                    Expected.remove_pad, (lv2,), out_sinfo=R.Tensor((14,), dtype="float32")
+                )
+                gv: R.Tensor((14,), dtype="float32") = lv_1
+                R.output(gv)
+            return gv
+
+        @T.prim_func(private=True)
+        def relax_relu_replacement(
+            arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")
+        ):
+            T.func_attr({"operator_name": "relax.relu"})
+            # with T.block("root"):
+            for ax0 in range(16):
+                with T.block("T_add"):
+                    v_ax0 = T.axis.spatial(16, ax0)
+                    T.reads(arg0[v_ax0])
+                    T.writes(output[v_ax0])
+                    output[v_ax0] = T.max(arg0[v_ax0], T.float32(0))
+
+        @T.prim_func(private=True)
+        def remove_pad(var_input: T.handle, var_output: T.handle):
+            T.func_attr({"operator_name": "remove_pad", "tir.noalias": T.bool(True)})
+            p0 = T.int64()
+            input = T.match_buffer(var_input, (p0,))
+            i0 = T.int64()
+            output = T.match_buffer(var_output, (i0,))
+            # with T.block("root"):
+            for ax0 in range(i0):
+                with T.block("output"):
+                    v_ax0 = T.axis.spatial(i0, ax0)
+                    T.reads(input[v_ax0])
+                    T.writes(output[v_ax0])
+                    output[v_ax0] = input[v_ax0]
 
     @T.prim_func(private=True)
     def relu_pad(arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")):
@@ -338,16 +393,13 @@ def test_unsupported_implicit_padding():
     # introduces implicit padding for shape (14,)
     index_map = lambda i: (i % 16)
     operator_name = "relax.relu"
-    with pytest.raises(
-        tvm.TVMError, match="Non bijective transforms on input and output buffers are not supported"
-    ):
-        _ = relax.transform.AlterOpImpl(
-            {operator_name: relu_pad},
-            {
-                operator_name: [index_map, index_map],
-            },
-            {operator_name: None},
-        )(before)
+    _check(
+        Before,
+        Expected,
+        operator_name="relax.relu",
+        replacement_primfunc=relu_pad,
+        layout_changes=[index_map, index_map],
+    )
 
 
 def test_multiple_call_sites():
