This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new b54f9ff194 [Unity] Add support for AXIS_SEPARATOR in AlterOpImpl Pass 
(#15315)
b54f9ff194 is described below

commit b54f9ff194e65988b57be1d46fd833824b7c548f
Author: Abhikrant Sharma <[email protected]>
AuthorDate: Tue Jul 25 01:30:17 2023 +0530

    [Unity] Add support for AXIS_SEPARATOR in AlterOpImpl Pass (#15315)
    
    * [Unity] Add support for AXIS_SEPARATOR in AlterOpImpl Pass
    
    Enable support for AXIS_SEPARATOR to handle non-flat buffers.
    Modified the AlterOpImpl pass to handle AXIS_SEPARATOR.
    
    * Fix LINT errors.
---
 include/tvm/relax/attrs/manipulate.h               |  9 +++
 include/tvm/relax/transform.h                      |  4 +-
 python/tvm/relax/op/manipulate.py                  | 10 ++-
 python/tvm/relax/transform/transform.py            |  8 +-
 src/relax/op/tensor/manipulate.cc                  |  4 +-
 src/relax/op/tensor/manipulate.h                   |  5 +-
 src/relax/transform/alter_op_impl.cc               | 55 ++++++++++----
 tests/python/relax/test_transform_alter_op_impl.py | 88 +++++++++++++++++++++-
 8 files changed, 161 insertions(+), 22 deletions(-)

diff --git a/include/tvm/relax/attrs/manipulate.h 
b/include/tvm/relax/attrs/manipulate.h
index 550515c032..b9d0b9f53b 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -59,12 +59,21 @@ struct LayoutTransformAttrs : public 
tvm::AttrsNode<LayoutTransformAttrs> {
   // pad_value is chosen to be of PrimValue type, as it represents constant 
TIR POD expression. This
   // needs to be revisited in case PrimValue is evolved to represent symbolic 
expression in future.
   Optional<PrimValue> pad_value;
+  /*!
+   * axis_separators between input axes when generating flattened output axes. 
For buffers
+   * representing flat 1-d memory (e.g. any buffer in RAM), this should be an 
empty array.
+   * For buffers representing non-flat memory, each entry in axis_separators 
should be the
+   * first input axis that is part of a new flattened axis.
+   */
+  Optional<Array<IntImm>> axis_separators;
 
   TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relax.attrs.LayoutTransformAttrs") {
     TVM_ATTR_FIELD(index_map).describe("The layout transformation to apply.");
     TVM_ATTR_FIELD(pad_value).describe(
         "The specific value to be used to pad if the layout transform would 
result in implicit "
         "padding. If not specified, the compiler is free to choose any 
value.");
+    TVM_ATTR_FIELD(axis_separators)
+        .describe("The separators between input axes when generating flat 
output axes");
   }
 };  // struct LayoutTransformAttrs
 
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index b618672292..dc2476f383 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -464,10 +464,12 @@ TVM_DLL Pass DecomposeOpsForTraining(Optional<String> 
func_name);
  * \param op_impl_map Map from from kOperatorName attr (e.g., relax.conv2d) to 
replacement PrimFunc
  * \param op_buffer_transforms Map from kOperatorName attr to layout 
transformations on each of the
  * PrimFunc i/o buffers.
+ * \param axis_separators Map from kOperatorName attr to the axis_separators for each of the
PrimFunc i/o buffer transforms in op_buffer_transforms.
  * \return The Pass.
  */
 TVM_DLL Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
-                         const Map<String, Array<tir::IndexMap>>& 
op_buffer_transforms);
+                         const Map<String, Array<tir::IndexMap>>& 
op_buffer_transforms,
+                         const Map<String, Array<Array<IntImm>>>& 
axis_separators);
 
 /*!
  * \brief Layout conversion pass.
diff --git a/python/tvm/relax/op/manipulate.py 
b/python/tvm/relax/op/manipulate.py
index a51b516b2e..a6f1c580a4 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -114,6 +114,7 @@ def layout_transform(
     x: Expr,
     index_map: Union[Callable, IndexMap],
     pad_value: Optional[Union[int, float, PrimValue]] = None,
+    axis_separators: Optional[Union[int, IndexMap.AXIS_SEPARATOR]] = None,
 ):
     """Modifies the layout of a tensor.
 
@@ -129,6 +130,9 @@ def layout_transform(
         The value used for padding if the transformation results in implicit 
padding.
         If not specified, any value can be used.
 
+    axis_separators : Optional[Union[int, IndexMap.AXIS_SEPARATOR]]
+        The axis separators (a list of integer axis indices) for index_map, used to create non-flat buffers.
+
     Returns
     -------
     result : relax.Expr
@@ -150,7 +154,11 @@ def layout_transform(
         elif "float" in x_dtype and (isinstance(pad_value, (int, float))):
             pad_value = FloatImm(x_dtype, float(pad_value))
         pad_value = PrimValue(pad_value)
-    return _ffi_api.layout_transform(x, index_map, pad_value)  # type: ignore
+
+    if axis_separators is None:
+        axis_separators = []
+
+    return _ffi_api.layout_transform(x, index_map, pad_value, axis_separators) 
 # type: ignore
 
 
 def permute_dims(x: Expr, axes: Optional[List[int]] = None) -> Expr:
diff --git a/python/tvm/relax/transform/transform.py 
b/python/tvm/relax/transform/transform.py
index d1d337a8c0..73267c43ae 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -903,6 +903,7 @@ def DecomposeOpsForTraining(func_name: Optional[str] = 
None) -> tvm.ir.transform
 def AlterOpImpl(
     op_impl_map: Dict[str, PrimFunc],
     op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]],
+    op_buffer_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, 
Callable]]],
 ):
     """Replace all PrimFunc's which have matching 'operator_name' attribute, 
with replacement
     PrimFunc that could possibly have different layouts on i/o buffers. The 
layout
@@ -916,6 +917,9 @@ def AlterOpImpl(
         op_kind to PrimFunc map
     op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]
         op_kind to layout transformation map for each of the buffers
+    op_buffer_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, 
Callable]]]
+        op_kind to the axis separators for each of the buffer index maps
+
     Returns
     -------
     ret: tvm.ir.transform.Pass
@@ -928,7 +932,9 @@ def AlterOpImpl(
             l.append(transform)
         op_buffer_transforms[operator_name] = l
 
-    return _ffi_api.AlterOpImpl(op_impl_map, op_buffer_transforms)  # type: 
ignore
+    return _ffi_api.AlterOpImpl(
+        op_impl_map, op_buffer_transforms, op_buffer_axis_separators
+    )  # type: ignore
 
 
 def ConvertLayout(desired_layouts: Dict[str, List[str]]) -> 
tvm.ir.transform.Pass:
diff --git a/src/relax/op/tensor/manipulate.cc 
b/src/relax/op/tensor/manipulate.cc
index a55d199822..2d7e60c4f0 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -425,10 +425,12 @@ TVM_REGISTER_OP("relax.flatten")
 /* relax.layout_transform */
 TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);
 
-Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> 
pad_value) {
+Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> 
pad_value,
+                      Optional<Array<IntImm>> axis_separators) {
   ObjectPtr<LayoutTransformAttrs> attrs = make_object<LayoutTransformAttrs>();
   attrs->index_map = std::move(index_map);
   attrs->pad_value = std::move(pad_value);
+  attrs->axis_separators = std::move(axis_separators);
 
   static const Op& op = Op::Get("relax.layout_transform");
   return Call(op, {std::move(x)}, Attrs{attrs}, {});
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 1593b19c17..b19e3b8507 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -65,9 +65,12 @@ Expr flatten(Expr x);
  * \param index_map The transformation to apply.
  * \param pad_value The value used for padding if the transformation results 
in implicit padding. If
  * not specified, any value can be used.
+ * \param axis_separators Array of values to differentiate between input axes
+ * when generating flattened output axes.
  * \return The transformed result.
  */
-Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> 
pad_value);
+Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> 
pad_value,
+                      Optional<Array<IntImm>> axis_separators);
 
 /*!
  * \brief Permutes the dimensions of an array.
diff --git a/src/relax/transform/alter_op_impl.cc 
b/src/relax/transform/alter_op_impl.cc
index f40ee3b3bf..c303a2c8f0 100644
--- a/src/relax/transform/alter_op_impl.cc
+++ b/src/relax/transform/alter_op_impl.cc
@@ -76,11 +76,13 @@ bool IsTransformBijective(const Expr& expr, const IndexMap& 
transform) {
 class AlterOpImplMutator : public ExprMutator {
  public:
   AlterOpImplMutator(const IRModule& mod, const Map<String, tir::PrimFunc>& 
op_impl_map,
-                     const Map<String, Array<IndexMap>>& op_buffer_transforms_)
+                     const Map<String, Array<IndexMap>>& op_buffer_transforms_,
+                     const Map<String, Array<Array<IntImm>>>& axis_separators_)
       : ExprMutator(mod),
         mod_(mod),
         op_impl_map_(op_impl_map),
-        op_buffer_transforms__(op_buffer_transforms_) {}
+        op_buffer_transforms__(op_buffer_transforms_),
+        op_buffer_axis_separators__(axis_separators_) {}
 
   IRModule Run() {
     for (const auto& [gv, func] : mod_->functions) {
@@ -119,7 +121,10 @@ class AlterOpImplMutator : public ExprMutator {
     const auto& replacement_func = op_impl_map_[op_kind];
 
     Array<IndexMap> buffer_transforms;
+    Optional<Array<Array<IntImm>>> axis_separators;
     if (op_buffer_transforms__.count(op_kind)) buffer_transforms = 
op_buffer_transforms__[op_kind];
+    if (op_buffer_axis_separators__.count(op_kind))
+      axis_separators = op_buffer_axis_separators__[op_kind];
 
     ICHECK(buffer_transforms.empty() || buffer_transforms.size() == 
replacement_func->params.size())
         << "Either the i/o buffers do not require any transformations or 
transformations for each "
@@ -130,7 +135,7 @@ class AlterOpImplMutator : public ExprMutator {
     GlobalVar replacement_gv = GetOrCreateGlobalVarForFunc(replacement_func, 
op_kind);
 
     auto call_tir_inputs_tuple = GetRef<Tuple>(call->args[1].as<TupleNode>());
-    Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple, 
buffer_transforms);
+    Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple, 
buffer_transforms, axis_separators);
 
     ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir sinfo_args.size() is 
expected to be 1";
     StructInfo updated_ret_sinfo = UpdateStructInfo(call->sinfo_args[0], 
buffer_transforms);
@@ -138,7 +143,7 @@ class AlterOpImplMutator : public ExprMutator {
         Call(call_tir_op_, {replacement_gv, updated_inputs}, call->attrs, 
{updated_ret_sinfo}));
 
     // Now transform each of the outputs to previous layout.
-    return TransformOutputs(updated_call, buffer_transforms, 
call->sinfo_args[0]);
+    return TransformOutputs(updated_call, buffer_transforms, 
call->sinfo_args[0], axis_separators);
   }
 
   Array<TensorStructInfo> GetTensorStructInfoPerOutput(const StructInfo& 
output_sinfo) {
@@ -157,17 +162,20 @@ class AlterOpImplMutator : public ExprMutator {
     return arr_tensor_sinfo;
   }
 
-  Expr TransformLayout(const Expr& expr, const IndexMap& index_map) {
+  Expr TransformLayout(const Expr& expr, const IndexMap& index_map,
+                       const Array<IntImm> axis_separators) {
     ObjectPtr<LayoutTransformAttrs> attrs = 
make_object<LayoutTransformAttrs>();
     // We want to avoid two layout_transform ops to share the same index map 
even if they are
     // identical. The scope of vars used in index map initial indices is local 
to the op. Not doing
     // so would confuse the structural equality check.
     attrs->index_map = std::move(DeepCopyIndexMap(index_map));
+    attrs->axis_separators = std::move(axis_separators);
     return Call(layout_transform_op_, {expr}, Attrs{std::move(attrs)}, {});
   }
 
   Expr TransformLayoutInverse(const Expr& expr, const IndexMap& index_map,
-                              const TensorStructInfo& old_tensor_sinfo) {
+                              const TensorStructInfo& old_tensor_sinfo,
+                              const Array<IntImm>& axis_separator) {
     Array<PrimExpr> old_shape = GetShapeFromTensorStructInfo(old_tensor_sinfo);
     Array<Range> initial_ranges = ConstructRangeFromShape(old_shape);
     arith::Analyzer analyzer;
@@ -177,7 +185,7 @@ class AlterOpImplMutator : public ExprMutator {
         << "Only bijective transformations on input/output buffers are 
supported, but found "
            "padding predicate "
         << padding_predicate << " on initial range " << initial_ranges;
-    return TransformLayout(expr, inverse_index_map);
+    return TransformLayout(expr, inverse_index_map, axis_separator);
   }
 
   /*!
@@ -202,16 +210,22 @@ class AlterOpImplMutator : public ExprMutator {
   /*!
    * \brief Updates call inputs with layout transformed inputs
    */
-  Tuple UpdateInputs(const Tuple& inputs, const Array<IndexMap>& transforms) {
+  Tuple UpdateInputs(const Tuple& inputs, const Array<IndexMap>& transforms,
+                     const Optional<Array<Array<IntImm>>>& axis_separators) {
     if (transforms.empty()) return inputs;
 
     Array<Expr> updated_inputs;
     int index = 0;
     for (const auto& input : inputs->fields) {
+      Array<IntImm> axis_separator;
+      if (axis_separators.defined()) {
+        Array<Array<IntImm>> axis_separators_value = axis_separators.value();
+        axis_separator = axis_separators_value[index];
+      }
       auto transform = transforms[index++];
       ICHECK(IsTransformBijective(input, transform))
           << "Non bijective transforms on input and output buffers are not 
supported.";
-      updated_inputs.push_back(TransformLayout(input, transform));
+      updated_inputs.push_back(TransformLayout(input, transform, 
axis_separator));
     }
     return Tuple(updated_inputs);
   }
@@ -254,11 +268,13 @@ class AlterOpImplMutator : public ExprMutator {
   }
 
   Expr TransformOutputs(const Expr& expr, const Array<IndexMap>& 
buffer_transforms,
-                        const StructInfo& old_struct_info) {
+                        const StructInfo& old_struct_info,
+                        const Optional<Array<Array<IntImm>>>& axis_separators) 
{
     if (buffer_transforms.empty()) return expr;
 
     Array<TensorStructInfo> old_output_sinfo = 
GetTensorStructInfoPerOutput(old_struct_info);
 
+    Array<IntImm> axis_sep;
     size_t num_outputs = old_output_sinfo.size();
     if (num_outputs == 0) return expr;
 
@@ -266,7 +282,11 @@ class AlterOpImplMutator : public ExprMutator {
     // If there is a single output, return the transformed output.
     if (num_outputs == 1) {
       IndexMap output_map = buffer_transforms[first_output_index];
-      return TransformLayoutInverse(expr, output_map, old_output_sinfo[0]);
+      if (axis_separators.defined()) {
+        Array<Array<IntImm>> axis_separators_value = axis_separators.value();
+        axis_sep = axis_separators_value[first_output_index];
+      }
+      return TransformLayoutInverse(expr, output_map, old_output_sinfo[0], 
axis_sep);
     }
 
     // In case of more than one output, we would have to get each item of the 
output tuple,
@@ -274,9 +294,13 @@ class AlterOpImplMutator : public ExprMutator {
     Array<Expr> transformed_outputs;
     for (size_t i = 0; i + first_output_index < buffer_transforms.size(); ++i) 
{
       const auto& output_map = buffer_transforms[i + first_output_index];
+      if (axis_separators.defined()) {
+        Array<Array<IntImm>> axis_separators_value = axis_separators.value();
+        axis_sep = axis_separators_value[i + first_output_index];
+      }
       auto output = builder_->Normalize(TupleGetItem(expr, 
static_cast<int>(i)));
       transformed_outputs.push_back(
-          TransformLayoutInverse(output, output_map, old_output_sinfo[i]));
+          TransformLayoutInverse(output, output_map, old_output_sinfo[i], 
axis_sep));
     }
     return Tuple(transformed_outputs);
   }
@@ -290,6 +314,8 @@ class AlterOpImplMutator : public ExprMutator {
   const Map<String, PrimFunc>& op_impl_map_;
   /*! \brief Map from kOperatorName attribute to the layout transforms on i/o 
buffers */
   const Map<String, Array<IndexMap>>& op_buffer_transforms__;
+  /*! \brief Map from kOperatorName attribute to the axis separators on i/o 
buffers */
+  const Map<String, Array<Array<IntImm>>>& op_buffer_axis_separators__;
 
   const Op& call_tir_op_ = Op::Get("relax.call_tir");
   const Op& layout_transform_op_ = Op::Get("relax.layout_transform");
@@ -298,10 +324,11 @@ class AlterOpImplMutator : public ExprMutator {
 namespace transform {
 
 Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
-                 const Map<String, Array<IndexMap>>& op_buffer_transforms_) {
+                 const Map<String, Array<IndexMap>>& op_buffer_transforms_,
+                 const Map<String, Array<Array<IntImm>>>& axis_separators_) {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = 
[=](IRModule mod,
                                                                             
PassContext pc) {
-    return AlterOpImplMutator(mod, op_impl_map, op_buffer_transforms_).Run();
+    return AlterOpImplMutator(mod, op_impl_map, op_buffer_transforms_, 
axis_separators_).Run();
   };
   return CreateModulePass(/*pass_function=*/pass_func,  //
                           /*opt_level=*/0,              //
diff --git a/tests/python/relax/test_transform_alter_op_impl.py 
b/tests/python/relax/test_transform_alter_op_impl.py
index 77e2d4e359..aa35067d58 100644
--- a/tests/python/relax/test_transform_alter_op_impl.py
+++ b/tests/python/relax/test_transform_alter_op_impl.py
@@ -20,13 +20,18 @@ import tvm.testing
 
 from tvm import relax
 from tvm.script import tir as T, ir as I, relax as R
+from tvm.tir import IndexMap
 
 kOperatorName = "operator_name"
 
 
-def _check(before, expected, operator_name, replacement_primfunc, 
layout_changes):
+def _check(
+    before, expected, operator_name, replacement_primfunc, layout_changes, 
axis_separator=None
+):
     after = relax.transform.AlterOpImpl(
-        {operator_name: replacement_primfunc}, {operator_name: layout_changes}
+        {operator_name: replacement_primfunc},
+        {operator_name: layout_changes},
+        {operator_name: axis_separator},
     )(before)
     after = relax.transform.DeadCodeElimination()(after)
     tvm.ir.assert_structural_equal(after, expected)
@@ -225,6 +230,79 @@ def test_multiple_outputs():
     )
 
 
+def test_multiple_outputs_with_axis_sep():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def some_op(arg0: T.Buffer((16,), "float32"), arg1: T.Buffer((16,), 
"float32"), output0: T.Buffer((16,), "float32"), output1: T.Buffer((16,), 
"float32")):
+            T.func_attr({"operator_name": "relax.some_op"})
+            for ax0 in range(16):
+                with T.block("T_add"):
+                    v_ax0 = T.axis.spatial(16, ax0)
+                    T.reads(arg0[v_ax0], arg1[v_ax0])
+                    T.writes(output0[v_ax0], output1[v_ax0])
+                    output0[v_ax0] = arg0[v_ax0] + arg1[v_ax0]
+                    output1[v_ax0] = arg0[v_ax0] - arg1[v_ax0]
+
+        @R.function
+        def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), 
dtype="float32")) -> R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,), 
dtype="float32")):
+            with R.dataflow():
+                gv = R.call_tir(Before.some_op, (x, y), 
out_sinfo=[R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")])
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def relax_some_op_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: 
T.Buffer((4, 4), "float32"), output0: T.Buffer((4, 4), "float32"), output1: 
T.Buffer((4, 4), "float32")):
+            T.func_attr({"operator_name": "relax.some_op"})
+            for ax0, ax1 in T.grid(4, 4):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+                    T.writes(output0[v_ax0, v_ax1], output1[v_ax0, v_ax1])
+                    output0[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, 
v_ax1]
+                    output1[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] - arg1[v_ax0, 
v_ax1]
+
+        @R.function
+        def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), 
dtype="float32")) -> R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,), 
dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x, 
index_map=lambda i: (i // 4, i % 4), pad_value=None, axis_separators=[1])
+                lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y, 
index_map=lambda i: (i // 4, i % 4), pad_value=None, axis_separators=[1])
+                lv2 = R.call_tir(Expected.relax_some_op_replacement, (lv, 
lv1), out_sinfo=[R.Tensor((4, 4), dtype="float32"), R.Tensor((4, 4), 
dtype="float32")])
+                lv3: R.Tensor((4, 4), dtype="float32") = lv2[0]
+                lv4: R.Tensor((16,), dtype="float32") = 
R.layout_transform(lv3, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), 
pad_value=None, axis_separators=[1])
+                lv5: R.Tensor((4, 4), dtype="float32") = lv2[1]
+                lv6: R.Tensor((16,), dtype="float32") = 
R.layout_transform(lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), 
pad_value=None, axis_separators=[1])
+                gv: R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,), 
dtype="float32")) = (lv4, lv6)
+                R.output(gv)
+            return gv
+
+    @T.prim_func
+    def some_op_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), 
"float32"), output0: T.Buffer((4, 4), "float32"), output1: T.Buffer((4, 4), 
"float32")):
+        for ax0, ax1 in T.grid(4, 4):
+            with T.block("T_add"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+                T.writes(output0[v_ax0, v_ax1], output1[v_ax0, v_ax1])
+                output0[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+                output1[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] - arg1[v_ax0, v_ax1]
+    # fmt: on
+
+    index_map, axis_sep = IndexMap.from_func_with_separators(
+        lambda i: (i // 4, IndexMap.AXIS_SEPARATOR, i % 4)
+    )
+    _check(
+        Before,
+        Expected,
+        operator_name="relax.some_op",
+        replacement_primfunc=some_op_2d,
+        layout_changes=[index_map, index_map, index_map, index_map],
+        axis_separator=[axis_sep, axis_sep, axis_sep, axis_sep],
+    )
+
+
 def test_unsupported_implicit_padding():
     @I.ir_module
     class InputModule:
@@ -264,7 +342,11 @@ def test_unsupported_implicit_padding():
         tvm.TVMError, match="Non bijective transforms on input and output 
buffers are not supported"
     ):
         _ = relax.transform.AlterOpImpl(
-            {operator_name: relu_pad}, {operator_name: [index_map, index_map]}
+            {operator_name: relu_pad},
+            {
+                operator_name: [index_map, index_map],
+            },
+            {operator_name: None},
         )(before)
 
 

Reply via email to