This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 4c65ceee91 [Unity] Lowering of axis separator in Layout Transform (#15390)
4c65ceee91 is described below

commit 4c65ceee91bec61cb036c4e4022f58da0442863d
Author: Abhikrant Sharma <[email protected]>
AuthorDate: Thu Jul 27 06:43:47 2023 +0530

    [Unity] Lowering of axis separator in Layout Transform (#15390)
    
    * [Unity] Lowering of axis separator in Layout Transform
    
    * Fix LINT errors
    
    * Add comments to code
---
 .../tvm/relax/transform/legalize_ops/manipulate.py | 61 +++++++++-------
 .../test_transform_legalize_ops_manipulate.py      | 81 ++++++++++++++++------
 2 files changed, 96 insertions(+), 46 deletions(-)

diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 2e5f44b406..4e06a0df39 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -21,6 +21,9 @@ from typing import Optional
 
 import tvm
 from tvm import topi, tir, relax, te
+from tvm.relax.op.base import call_tir
+from tvm.relax.struct_info import TensorStructInfo
+from tvm.relax.utils import gen_call_tir_inputs
 from tvm.tir.expr import IntImm
 from ...block_builder import BlockBuilder
 from ...expr import Call, Expr, Var, Tuple, TupleGetItem, ShapeExpr
@@ -167,30 +170,38 @@ def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
 
 @register_legalize("relax.layout_transform")
 def _layout_transform(bb: BlockBuilder, call: Call) -> Expr:
+    def te_layout_transform(data, name):
+        """
+        Returns a passthrough TE compute with appropriate name. This is needed 
to generate
+        TIR function, output shape info, TIR vars from gen_call_tir_inputs 
function.
+        """
+        return te.compute(
+            data.shape,
+            data,
+            name=name,
+        )
+
     index_map: tvm.tir.IndexMap = call.attrs.index_map
     pad_value = call.attrs.pad_value.value
-
-    def te_layout_transform(data):
-        inverse, padding_predicate = 
index_map.non_surjective_inverse(data.shape)
-        output_shape = index_map.map_shape(data.shape)
-        if isinstance(padding_predicate, tvm.tir.expr.IntImm) and 
bool(padding_predicate) is False:
-            return te.compute(
-                output_shape,
-                lambda *idx: data(*inverse.map_indices(idx)),
-                name="te_layout_transform",
-            )
-        else:
-            return te.compute(
-                output_shape,
-                lambda *idx: tvm.te.if_then_else(
-                    tir.stmt_functor.substitute(
-                        padding_predicate,
-                        {old_idx: idx[i] for i, old_idx in 
enumerate(inverse.initial_indices)},
-                    ),
-                    pad_value,
-                    data(*inverse.map_indices(idx)),
-                ),
-                name="te_layout_transform_with_pad",
-            )
-
-    return bb.call_te(te_layout_transform, call.args[0])
+    axis_separators: tvm.tir.IndexMap.AXIS_SEPARATOR = 
call.attrs.axis_separators
+    # Convert to list from array
+    axis_separators = list(map(lambda x: x.value, axis_separators))
+    primfunc_name = "te_layout_transform"
+    _, padding_predicate = 
index_map.non_surjective_inverse(call.args[0].struct_info.shape)
+    if not isinstance(padding_predicate, tvm.tir.expr.IntImm):
+        primfunc_name += "_with_pad"
+    if len(axis_separators) != 0:
+        primfunc_name += "_axis_separator"
+    tir_func, call_args, _, tir_vars = gen_call_tir_inputs(
+        te_layout_transform, call.args[0], primfunc_name
+    )
+    # Create TIR schedule to apply layout changes with axis separators
+    sch = tir.Schedule(tir_func)
+    sch.transform_layout(primfunc_name, ("write", 0), index_map, pad_value)
+    if len(axis_separators) != 0:
+        sch.set_axis_separator(primfunc_name, ("write", 0), 
axis_separators=axis_separators)
+    gvar = bb.add_func(sch.mod["main"], primfunc_name)
+    output_shape = index_map.map_shape(list(call_args[0].struct_info.shape))
+    output_dtype = call_args[0].struct_info.dtype
+    output_sinfo = [TensorStructInfo(output_shape, output_dtype)]
+    return call_tir(gvar, call_args, output_sinfo, tir_vars)
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index ff01c508ec..5f556730d9 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -1541,12 +1541,12 @@ def test_layout_transform():
         def te_layout_transform(A: T.Buffer((T.int64(10), T.int64(21), 
T.int64(30)), "float32"), te_layout_transform_1: T.Buffer((T.int64(10), 
T.int64(30), T.int64(7), T.int64(3)), "float32")):
             T.func_attr({"tir.noalias": T.bool(True)})
             # with T.block("root"):
-            for i0, i1, i2, i3 in T.grid(T.int64(10), T.int64(30), T.int64(7), 
T.int64(3)):
+            for i0, i1, i2 in T.grid(T.int64(10), T.int64(21), T.int64(30)):
                 with T.block("te_layout_transform"):
-                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, 
i3])
-                    T.reads(A[v_i0, v_i2 * T.int64(3) + v_i3, v_i1])
-                    T.writes(te_layout_transform_1[v_i0, v_i1, v_i2, v_i3])
-                    te_layout_transform_1[v_i0, v_i1, v_i2, v_i3] = A[v_i0, 
v_i2 * T.int64(3) + v_i3, v_i1]
+                    v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+                    T.reads(A[v_i0, v_i1, v_i2])
+                    T.writes(te_layout_transform_1[v_i0, v_i2, v_i1 // 
T.int64(3), v_i1 % T.int64(3)])
+                    te_layout_transform_1[v_i0, v_i2, v_i1 // T.int64(3), v_i1 
% T.int64(3)] = A[v_i0, v_i1, v_i2]
 
         @R.function
         def main(x: R.Tensor((10, 21, 30), dtype="float32")) -> R.Tensor((10, 
30, 7, 3), dtype="float32"):
@@ -1575,20 +1575,20 @@ def test_layout_transform_with_pad():
     @I.ir_module
     class Expected:
         @T.prim_func
-        def te_layout_transform(A: T.Buffer((T.int64(10), T.int64(20), 
T.int64(30)), "float32"), te_layout_transform_1: T.Buffer((T.int64(10), 
T.int64(30), T.int64(7), T.int64(3)), "float32")):
+        def te_layout_transform_with_pad(A: T.Buffer((T.int64(10), 
T.int64(20), T.int64(30)), "float32"), te_layout_transform_with_pad_1: 
T.Buffer((T.int64(10), T.int64(30), T.int64(7), T.int64(3)), "float32")):
             T.func_attr({"tir.noalias": T.bool(True)})
             # with T.block("root"):
-            for i0, i1, i2, i3 in T.grid(T.int64(10), T.int64(30), T.int64(7), 
T.int64(3)):
+            for axis0, axis1, axis2, axis3 in T.grid(T.int64(10), T.int64(30), 
T.int64(7), T.int64(3)):
                 with T.block("te_layout_transform_with_pad"):
-                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, 
i3])
-                    T.reads(A[v_i0, v_i2 * T.int64(3) + v_i3, v_i1])
-                    T.writes(te_layout_transform_1[v_i0, v_i1, v_i2, v_i3])
-                    te_layout_transform_1[v_i0, v_i1, v_i2, v_i3] = 
T.if_then_else(v_i2 == T.int64(6) and v_i3 == T.int64(2), T.float32(2), A[v_i0, 
v_i2 * T.int64(3) + v_i3, v_i1])
+                    v_axis0, v_axis1, v_axis2, v_axis3 = T.axis.remap("SSSS", 
[axis0, axis1, axis2, axis3])
+                    T.reads(A[v_axis0, v_axis2 * T.int64(3) + v_axis3, 
v_axis1])
+                    T.writes(te_layout_transform_with_pad_1[v_axis0, v_axis1, 
v_axis2, v_axis3])
+                    te_layout_transform_with_pad_1[v_axis0, v_axis1, v_axis2, 
v_axis3] = T.if_then_else(v_axis2 == T.int64(6) and v_axis3 == T.int64(2), 
T.float32(2), A[v_axis0, v_axis2 * T.int64(3) + v_axis3, v_axis1])
 
         @R.function
         def main(x: R.Tensor((10, 20, 30), dtype="float32")) -> R.Tensor((10, 
30, 7, 3), dtype="float32"):
             cls = Expected
-            gv = R.call_tir(cls.te_layout_transform, (x,), 
out_sinfo=R.Tensor((10, 30, 7, 3), dtype="float32"))
+            gv = R.call_tir(cls.te_layout_transform_with_pad, (x,), 
out_sinfo=R.Tensor((10, 30, 7, 3), dtype="float32"))
             return gv
     # fmt: on
 
@@ -1612,18 +1612,18 @@ def test_layout_transform_symbolic():
     @I.ir_module
     class Expected:
         @T.prim_func
-        def te_layout_transform(var_A: T.handle, var_te_layout_transform: 
T.handle):
+        def te_layout_transform_with_pad(var_A: T.handle, 
var_te_layout_transform_with_pad: T.handle):
             T.func_attr({"tir.noalias": T.bool(True)})
             a, b, c = T.int64(), T.int64(), T.int64()
             A = T.match_buffer(var_A, (a, b, c))
-            te_layout_transform_1 = T.match_buffer(var_te_layout_transform, 
(a, c, (b - b % T.int64(-3)) // T.int64(3), T.int64(3)))
+            te_layout_transform_with_pad_1 = 
T.match_buffer(var_te_layout_transform_with_pad, (a, c, (b - b % T.int64(-3)) 
// T.int64(3), T.int64(3)))
             # with T.block("root"):
-            for i0, i1, i2, i3 in T.grid(a, c, (b - b % T.int64(-3)) // 
T.int64(3), T.int64(3)):
-                with T.block("te_layout_transform_with_pad"):
-                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, 
i3])
-                    T.reads(A[v_i0, v_i2 * T.int64(3) + v_i3, v_i1])
-                    T.writes(te_layout_transform_1[v_i0, v_i1, v_i2, v_i3])
-                    te_layout_transform_1[v_i0, v_i1, v_i2, v_i3] = 
T.if_then_else(b % T.int64(-3) < T.int64(0) and v_i2 == b // T.int64(3) and b % 
T.int64(3) <= v_i3, T.float32(2), A[v_i0, v_i2 * T.int64(3) + v_i3, v_i1])
+            for axis0, axis1, axis2, axis3 in T.grid(a, c, (b - b % 
T.int64(-3)) // T.int64(3), T.int64(3)):
+                with T.block("te_layout_transform_with_pad_with_pad"):
+                    v_axis0, v_axis1, v_axis2, v_axis3 = T.axis.remap("SSSS", 
[axis0, axis1, axis2, axis3])
+                    T.reads(A[v_axis0, v_axis2 * T.int64(3) + v_axis3, 
v_axis1])
+                    T.writes(te_layout_transform_with_pad_1[v_axis0, v_axis1, 
v_axis2, v_axis3])
+                    te_layout_transform_with_pad_1[v_axis0, v_axis1, v_axis2, 
v_axis3] = T.if_then_else(b % T.int64(-3) < T.int64(0) and v_axis2 == b // 
T.int64(3) and b % T.int64(3) <= v_axis3, T.float32(2), A[v_axis0, v_axis2 * 
T.int64(3) + v_axis3, v_axis1])
 
         @R.function
         def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) -> 
R.Tensor(("a", "c", "(b - b % -3) // 3", 3), dtype="float32"):
@@ -1631,7 +1631,46 @@ def test_layout_transform_symbolic():
             c = T.int64()
             b = T.int64()
             cls = Expected
-            gv = R.call_tir(cls.te_layout_transform, (x,), 
out_sinfo=R.Tensor((a, c, (b - b % -3) // 3, 3), dtype="float32"))
+            gv = R.call_tir(cls.te_layout_transform_with_pad, (x,), 
out_sinfo=R.Tensor((a, c, (b - b % -3) // 3, 3), dtype="float32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(LayoutTransform)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_layout_transform_with_pad_axis_sep():
+    transformation = lambda a, b, c: (a, c, b // 3, b % 3)
+    pad_value = 2
+    axis_separator = [3]
+    # fmt: off
+    @I.ir_module
+    class LayoutTransform:
+        @R.function
+        def main(x: R.Tensor((10, 20, 30), "float32")):
+            gv = R.layout_transform(
+                x, index_map=transformation, pad_value=pad_value, 
axis_separators=axis_separator,
+            )
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def te_layout_transform_with_pad_axis_separator(A: 
T.Buffer((T.int64(10), T.int64(20), T.int64(30)), "float32"), 
var_te_layout_transform_with_pad_axis_separator: T.handle):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            te_layout_transform_with_pad_axis_separator_1 = 
T.match_buffer(var_te_layout_transform_with_pad_axis_separator, (T.int64(10), 
T.int64(30), T.int64(7), T.int64(3)), axis_separators=[3])
+            # with T.block("root"):
+            for axis0, axis1, axis2, axis3 in T.grid(T.int64(10), T.int64(30), 
T.int64(7), T.int64(3)):
+                with T.block("te_layout_transform_with_pad_axis_separator"):
+                    v_axis0, v_axis1, v_axis2, v_axis3 = T.axis.remap("SSSS", 
[axis0, axis1, axis2, axis3])
+                    T.reads(A[v_axis0, v_axis2 * T.int64(3) + v_axis3, 
v_axis1])
+                    
T.writes(te_layout_transform_with_pad_axis_separator_1[v_axis0, v_axis1, 
v_axis2, v_axis3])
+                    te_layout_transform_with_pad_axis_separator_1[v_axis0, 
v_axis1, v_axis2, v_axis3] = T.if_then_else(v_axis2 == T.int64(6) and v_axis3 
== T.int64(2), T.float32(2), A[v_axis0, v_axis2 * T.int64(3) + v_axis3, 
v_axis1])
+
+        @R.function
+        def main(x: R.Tensor((10, 20, 30), dtype="float32")) -> R.Tensor((10, 
30, 7, 3), dtype="float32"):
+            cls = Expected
+            gv = R.call_tir(cls.te_layout_transform_with_pad_axis_separator, 
(x,), out_sinfo=R.Tensor((10, 30, 7, 3), dtype="float32"))
             return gv
     # fmt: on
 

Reply via email to