This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 4c65ceee91 [Unity] Lowering of axis separator in Layout Transform (#15390)
4c65ceee91 is described below
commit 4c65ceee91bec61cb036c4e4022f58da0442863d
Author: Abhikrant Sharma <[email protected]>
AuthorDate: Thu Jul 27 06:43:47 2023 +0530
[Unity] Lowering of axis separator in Layout Transform (#15390)
* [Unity] Lowering of axis separator in Layout Transform
* Fix LINT errors
* Add comments to code
---
.../tvm/relax/transform/legalize_ops/manipulate.py | 61 +++++++++-------
.../test_transform_legalize_ops_manipulate.py | 81 ++++++++++++++++------
2 files changed, 96 insertions(+), 46 deletions(-)
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 2e5f44b406..4e06a0df39 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -21,6 +21,9 @@ from typing import Optional
import tvm
from tvm import topi, tir, relax, te
+from tvm.relax.op.base import call_tir
+from tvm.relax.struct_info import TensorStructInfo
+from tvm.relax.utils import gen_call_tir_inputs
from tvm.tir.expr import IntImm
from ...block_builder import BlockBuilder
from ...expr import Call, Expr, Var, Tuple, TupleGetItem, ShapeExpr
@@ -167,30 +170,38 @@ def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
@register_legalize("relax.layout_transform")
def _layout_transform(bb: BlockBuilder, call: Call) -> Expr:
+ def te_layout_transform(data, name):
+ """
+ Returns a passthrough TE compute with appropriate name. This is needed to generate
+ TIR function, output shape info, TIR vars from gen_call_tir_inputs function.
+ """
+ return te.compute(
+ data.shape,
+ data,
+ name=name,
+ )
+
index_map: tvm.tir.IndexMap = call.attrs.index_map
pad_value = call.attrs.pad_value.value
-
- def te_layout_transform(data):
- inverse, padding_predicate =
index_map.non_surjective_inverse(data.shape)
- output_shape = index_map.map_shape(data.shape)
- if isinstance(padding_predicate, tvm.tir.expr.IntImm) and
bool(padding_predicate) is False:
- return te.compute(
- output_shape,
- lambda *idx: data(*inverse.map_indices(idx)),
- name="te_layout_transform",
- )
- else:
- return te.compute(
- output_shape,
- lambda *idx: tvm.te.if_then_else(
- tir.stmt_functor.substitute(
- padding_predicate,
- {old_idx: idx[i] for i, old_idx in
enumerate(inverse.initial_indices)},
- ),
- pad_value,
- data(*inverse.map_indices(idx)),
- ),
- name="te_layout_transform_with_pad",
- )
-
- return bb.call_te(te_layout_transform, call.args[0])
+ axis_separators: tvm.tir.IndexMap.AXIS_SEPARATOR = call.attrs.axis_separators
+ # Convert to list from array
+ axis_separators = list(map(lambda x: x.value, axis_separators))
+ primfunc_name = "te_layout_transform"
+ _, padding_predicate = index_map.non_surjective_inverse(call.args[0].struct_info.shape)
+ if not isinstance(padding_predicate, tvm.tir.expr.IntImm):
+ primfunc_name += "_with_pad"
+ if len(axis_separators) != 0:
+ primfunc_name += "_axis_separator"
+ tir_func, call_args, _, tir_vars = gen_call_tir_inputs(
+ te_layout_transform, call.args[0], primfunc_name
+ )
+ # Create TIR schedule to apply layout changes with axis separators
+ sch = tir.Schedule(tir_func)
+ sch.transform_layout(primfunc_name, ("write", 0), index_map, pad_value)
+ if len(axis_separators) != 0:
+ sch.set_axis_separator(primfunc_name, ("write", 0), axis_separators=axis_separators)
+ gvar = bb.add_func(sch.mod["main"], primfunc_name)
+ output_shape = index_map.map_shape(list(call_args[0].struct_info.shape))
+ output_dtype = call_args[0].struct_info.dtype
+ output_sinfo = [TensorStructInfo(output_shape, output_dtype)]
+ return call_tir(gvar, call_args, output_sinfo, tir_vars)
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index ff01c508ec..5f556730d9 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -1541,12 +1541,12 @@ def test_layout_transform():
def te_layout_transform(A: T.Buffer((T.int64(10), T.int64(21),
T.int64(30)), "float32"), te_layout_transform_1: T.Buffer((T.int64(10),
T.int64(30), T.int64(7), T.int64(3)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
- for i0, i1, i2, i3 in T.grid(T.int64(10), T.int64(30), T.int64(7),
T.int64(3)):
+ for i0, i1, i2 in T.grid(T.int64(10), T.int64(21), T.int64(30)):
with T.block("te_layout_transform"):
- v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2,
i3])
- T.reads(A[v_i0, v_i2 * T.int64(3) + v_i3, v_i1])
- T.writes(te_layout_transform_1[v_i0, v_i1, v_i2, v_i3])
- te_layout_transform_1[v_i0, v_i1, v_i2, v_i3] = A[v_i0,
v_i2 * T.int64(3) + v_i3, v_i1]
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+ T.reads(A[v_i0, v_i1, v_i2])
+ T.writes(te_layout_transform_1[v_i0, v_i2, v_i1 //
T.int64(3), v_i1 % T.int64(3)])
+ te_layout_transform_1[v_i0, v_i2, v_i1 // T.int64(3), v_i1
% T.int64(3)] = A[v_i0, v_i1, v_i2]
@R.function
def main(x: R.Tensor((10, 21, 30), dtype="float32")) -> R.Tensor((10,
30, 7, 3), dtype="float32"):
@@ -1575,20 +1575,20 @@ def test_layout_transform_with_pad():
@I.ir_module
class Expected:
@T.prim_func
- def te_layout_transform(A: T.Buffer((T.int64(10), T.int64(20),
T.int64(30)), "float32"), te_layout_transform_1: T.Buffer((T.int64(10),
T.int64(30), T.int64(7), T.int64(3)), "float32")):
+ def te_layout_transform_with_pad(A: T.Buffer((T.int64(10),
T.int64(20), T.int64(30)), "float32"), te_layout_transform_with_pad_1:
T.Buffer((T.int64(10), T.int64(30), T.int64(7), T.int64(3)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
- for i0, i1, i2, i3 in T.grid(T.int64(10), T.int64(30), T.int64(7),
T.int64(3)):
+ for axis0, axis1, axis2, axis3 in T.grid(T.int64(10), T.int64(30),
T.int64(7), T.int64(3)):
with T.block("te_layout_transform_with_pad"):
- v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2,
i3])
- T.reads(A[v_i0, v_i2 * T.int64(3) + v_i3, v_i1])
- T.writes(te_layout_transform_1[v_i0, v_i1, v_i2, v_i3])
- te_layout_transform_1[v_i0, v_i1, v_i2, v_i3] =
T.if_then_else(v_i2 == T.int64(6) and v_i3 == T.int64(2), T.float32(2), A[v_i0,
v_i2 * T.int64(3) + v_i3, v_i1])
+ v_axis0, v_axis1, v_axis2, v_axis3 = T.axis.remap("SSSS",
[axis0, axis1, axis2, axis3])
+ T.reads(A[v_axis0, v_axis2 * T.int64(3) + v_axis3,
v_axis1])
+ T.writes(te_layout_transform_with_pad_1[v_axis0, v_axis1,
v_axis2, v_axis3])
+ te_layout_transform_with_pad_1[v_axis0, v_axis1, v_axis2,
v_axis3] = T.if_then_else(v_axis2 == T.int64(6) and v_axis3 == T.int64(2),
T.float32(2), A[v_axis0, v_axis2 * T.int64(3) + v_axis3, v_axis1])
@R.function
def main(x: R.Tensor((10, 20, 30), dtype="float32")) -> R.Tensor((10,
30, 7, 3), dtype="float32"):
cls = Expected
- gv = R.call_tir(cls.te_layout_transform, (x,),
out_sinfo=R.Tensor((10, 30, 7, 3), dtype="float32"))
+ gv = R.call_tir(cls.te_layout_transform_with_pad, (x,),
out_sinfo=R.Tensor((10, 30, 7, 3), dtype="float32"))
return gv
# fmt: on
@@ -1612,18 +1612,18 @@ def test_layout_transform_symbolic():
@I.ir_module
class Expected:
@T.prim_func
- def te_layout_transform(var_A: T.handle, var_te_layout_transform:
T.handle):
+ def te_layout_transform_with_pad(var_A: T.handle,
var_te_layout_transform_with_pad: T.handle):
T.func_attr({"tir.noalias": T.bool(True)})
a, b, c = T.int64(), T.int64(), T.int64()
A = T.match_buffer(var_A, (a, b, c))
- te_layout_transform_1 = T.match_buffer(var_te_layout_transform,
(a, c, (b - b % T.int64(-3)) // T.int64(3), T.int64(3)))
+ te_layout_transform_with_pad_1 =
T.match_buffer(var_te_layout_transform_with_pad, (a, c, (b - b % T.int64(-3))
// T.int64(3), T.int64(3)))
# with T.block("root"):
- for i0, i1, i2, i3 in T.grid(a, c, (b - b % T.int64(-3)) //
T.int64(3), T.int64(3)):
- with T.block("te_layout_transform_with_pad"):
- v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2,
i3])
- T.reads(A[v_i0, v_i2 * T.int64(3) + v_i3, v_i1])
- T.writes(te_layout_transform_1[v_i0, v_i1, v_i2, v_i3])
- te_layout_transform_1[v_i0, v_i1, v_i2, v_i3] =
T.if_then_else(b % T.int64(-3) < T.int64(0) and v_i2 == b // T.int64(3) and b %
T.int64(3) <= v_i3, T.float32(2), A[v_i0, v_i2 * T.int64(3) + v_i3, v_i1])
+ for axis0, axis1, axis2, axis3 in T.grid(a, c, (b - b %
T.int64(-3)) // T.int64(3), T.int64(3)):
+ with T.block("te_layout_transform_with_pad_with_pad"):
+ v_axis0, v_axis1, v_axis2, v_axis3 = T.axis.remap("SSSS",
[axis0, axis1, axis2, axis3])
+ T.reads(A[v_axis0, v_axis2 * T.int64(3) + v_axis3,
v_axis1])
+ T.writes(te_layout_transform_with_pad_1[v_axis0, v_axis1,
v_axis2, v_axis3])
+ te_layout_transform_with_pad_1[v_axis0, v_axis1, v_axis2,
v_axis3] = T.if_then_else(b % T.int64(-3) < T.int64(0) and v_axis2 == b //
T.int64(3) and b % T.int64(3) <= v_axis3, T.float32(2), A[v_axis0, v_axis2 *
T.int64(3) + v_axis3, v_axis1])
@R.function
def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) ->
R.Tensor(("a", "c", "(b - b % -3) // 3", 3), dtype="float32"):
@@ -1631,7 +1631,46 @@ def test_layout_transform_symbolic():
c = T.int64()
b = T.int64()
cls = Expected
- gv = R.call_tir(cls.te_layout_transform, (x,),
out_sinfo=R.Tensor((a, c, (b - b % -3) // 3, 3), dtype="float32"))
+ gv = R.call_tir(cls.te_layout_transform_with_pad, (x,),
out_sinfo=R.Tensor((a, c, (b - b % -3) // 3, 3), dtype="float32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(LayoutTransform)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_layout_transform_with_pad_axis_sep():
+ transformation = lambda a, b, c: (a, c, b // 3, b % 3)
+ pad_value = 2
+ axis_separator = [3]
+ # fmt: off
+ @I.ir_module
+ class LayoutTransform:
+ @R.function
+ def main(x: R.Tensor((10, 20, 30), "float32")):
+ gv = R.layout_transform(
+ x, index_map=transformation, pad_value=pad_value,
axis_separators=axis_separator,
+ )
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def te_layout_transform_with_pad_axis_separator(A:
T.Buffer((T.int64(10), T.int64(20), T.int64(30)), "float32"),
var_te_layout_transform_with_pad_axis_separator: T.handle):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ te_layout_transform_with_pad_axis_separator_1 =
T.match_buffer(var_te_layout_transform_with_pad_axis_separator, (T.int64(10),
T.int64(30), T.int64(7), T.int64(3)), axis_separators=[3])
+ # with T.block("root"):
+ for axis0, axis1, axis2, axis3 in T.grid(T.int64(10), T.int64(30),
T.int64(7), T.int64(3)):
+ with T.block("te_layout_transform_with_pad_axis_separator"):
+ v_axis0, v_axis1, v_axis2, v_axis3 = T.axis.remap("SSSS",
[axis0, axis1, axis2, axis3])
+ T.reads(A[v_axis0, v_axis2 * T.int64(3) + v_axis3,
v_axis1])
+
T.writes(te_layout_transform_with_pad_axis_separator_1[v_axis0, v_axis1,
v_axis2, v_axis3])
+ te_layout_transform_with_pad_axis_separator_1[v_axis0,
v_axis1, v_axis2, v_axis3] = T.if_then_else(v_axis2 == T.int64(6) and v_axis3
== T.int64(2), T.float32(2), A[v_axis0, v_axis2 * T.int64(3) + v_axis3,
v_axis1])
+
+ @R.function
+ def main(x: R.Tensor((10, 20, 30), dtype="float32")) -> R.Tensor((10,
30, 7, 3), dtype="float32"):
+ cls = Expected
+ gv = R.call_tir(cls.te_layout_transform_with_pad_axis_separator,
(x,), out_sinfo=R.Tensor((10, 30, 7, 3), dtype="float32"))
return gv
# fmt: on