This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ae839848b2 [Relax][PyTorch] Add support for decomposed operators and
fix IR of ops tests(6) (#18420)
ae839848b2 is described below
commit ae839848b22f16aa92adb2a83ab050ecde8ee3cc
Author: Shushi Hong <[email protected]>
AuthorDate: Wed Nov 5 21:46:48 2025 -0500
[Relax][PyTorch] Add support for decomposed operators and fix IR of ops
tests(6) (#18420)
* finish1
* finish2
---
.../frontend/torch/base_fx_graph_translator.py | 14 ++
.../frontend/torch/exported_program_translator.py | 20 +-
.../relax/test_frontend_from_exported_program.py | 229 +++++++++------------
3 files changed, 130 insertions(+), 133 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index aedef8acf8..03e3b8d557 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1725,6 +1725,20 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
# Support both "dim" and "dims" parameters
if dim is None:
dim = node.kwargs.get("dims", None)
+
+ # If dims is a list, filter out axes where dimension is not 1
+ # This is needed because PyTorch decomposition may pass all axes
+ if isinstance(dim, (list, tuple)) and len(dim) > 0:
+ shape = self.shape_of(x)
+ # Filter to only include axes where the dimension is 1
+ valid_dims = []
+ for d in dim:
+ axis = d if d >= 0 else len(shape) + d
+ if axis < len(shape) and shape[axis] == 1:
+ valid_dims.append(d)
+ # If no valid dims, use None to squeeze all size-1 dimensions
+ dim = valid_dims if valid_dims else None
+
return self.block_builder.emit(relax.op.squeeze(x, dim))
def _stack(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 3be255a29a..4f3132b8d8 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -701,11 +701,23 @@ class ExportedProgramImporter(BaseFXGraphImporter):
return self.block_builder.emit(relax.op.take(x, index, dim))
def _slice(self, node: fx.Node) -> relax.Var:
+ import sys
+
x = self.env[node.args[0]]
- axes = [node.args[1]]
- begin = [node.args[2]]
- end = [node.args[3]]
- stride = [node.args[4] if len(node.args) > 4 else 1]
+ dim = node.args[1] if len(node.args) > 1 else 0
+ start = node.args[2] if len(node.args) > 2 else None
+ end_val = node.args[3] if len(node.args) > 3 else None
+ step = node.args[4] if len(node.args) > 4 else 1
+
+ if start is None:
+ start = 0
+ if end_val is None:
+ end_val = sys.maxsize
+
+ axes = [dim]
+ begin = [start]
+ end = [end_val]
+ stride = [step]
return self.block_builder.emit(relax.op.strided_slice(x, axes, begin,
end, stride))
def _unflatten(self, node: fx.Node) -> relax.Var:
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 8a9fe66a0f..44248c1c59 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -4111,7 +4111,7 @@ def test_reshape():
return gv
example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
- verify_model(Reshape(), example_args, {}, expected1)
+ verify_model(Reshape(), example_args, {}, expected1,
run_ep_decomposition=True)
def test_reshape_as():
@@ -4137,7 +4137,7 @@ def test_reshape_as():
torch.randn(1, 2, 3, 4, dtype=torch.float32),
torch.randn(2, 12, dtype=torch.float32),
)
- verify_model(ReshapeAs(), example_args, {}, expected1)
+ verify_model(ReshapeAs(), example_args, {}, expected1,
run_ep_decomposition=True)
def test_roll():
@@ -4160,25 +4160,14 @@ def test_roll():
def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4,
2), dtype="int64")):
with R.dataflow():
lv: R.Tensor((8,), dtype="int64") = R.reshape(x, R.shape([8]))
- lv1: R.Tensor((7,), dtype="int64") = R.strided_slice(
- lv,
- axes=[0],
- begin=[R.prim_value(0)],
- end=[R.prim_value(7)],
- strides=[R.prim_value(1)],
- assume_inbound=False,
- )
- lv2: R.Tensor((1,), dtype="int64") = R.strided_slice(
- lv,
- axes=[0],
- begin=[R.prim_value(7)],
- end=[R.prim_value(8)],
- strides=[R.prim_value(1)],
- assume_inbound=False,
+ lv1: R.Tensor((8,), dtype="int64") = R.arange(
+ R.prim_value(0), R.prim_value(8), R.prim_value(1),
dtype="int64"
)
- lv3: R.Tensor((8,), dtype="int64") = R.concat((lv2, lv1),
axis=0)
- lv4: R.Tensor((4, 2), dtype="int64") = R.reshape(lv3,
R.shape([4, 2]))
- gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv4,)
+ lv2: R.Tensor((8,), dtype="int64") = R.add(lv1, R.const(7,
"int64"))
+ lv3: R.Tensor((8,), dtype="int64") = R.mod(lv2, R.const(8,
"int64"))
+ lv4: R.Tensor((8,), dtype="int64") = R.take(lv, lv3, axis=0,
mode="fast")
+ lv5: R.Tensor((4, 2), dtype="int64") = R.reshape(lv4,
R.shape([4, 2]))
+ gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv5,)
R.output(gv)
return gv
@@ -4188,24 +4177,13 @@ def test_roll():
@R.function
def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4,
2), dtype="int64")):
with R.dataflow():
- lv: R.Tensor((1, 2), dtype="int64") = R.strided_slice(
- x,
- axes=[0],
- begin=[R.prim_value(0)],
- end=[R.prim_value(1)],
- strides=[R.prim_value(1)],
- assume_inbound=False,
+ lv: R.Tensor((4,), dtype="int64") = R.arange(
+ R.prim_value(0), R.prim_value(4), R.prim_value(1),
dtype="int64"
)
- lv1: R.Tensor((3, 2), dtype="int64") = R.strided_slice(
- x,
- axes=[0],
- begin=[R.prim_value(1)],
- end=[R.prim_value(4)],
- strides=[R.prim_value(1)],
- assume_inbound=False,
- )
- lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv),
axis=0)
- gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv2,)
+ lv1: R.Tensor((4,), dtype="int64") = R.add(lv, R.const(1,
"int64"))
+ lv2: R.Tensor((4,), dtype="int64") = R.mod(lv1, R.const(4,
"int64"))
+ lv3: R.Tensor((4, 2), dtype="int64") = R.take(x, lv2, axis=0,
mode="fast")
+ gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv3,)
R.output(gv)
return gv
@@ -4216,43 +4194,20 @@ def test_roll():
def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4,
2), dtype="int64")):
with R.dataflow():
# First roll along dim=0 with shift=2
- lv: R.Tensor((2, 2), dtype="int64") = R.strided_slice(
- x,
- axes=[0],
- begin=[R.prim_value(0)],
- end=[R.prim_value(2)],
- strides=[R.prim_value(1)],
- assume_inbound=False,
+ lv: R.Tensor((4,), dtype="int64") = R.arange(
+ R.prim_value(0), R.prim_value(4), R.prim_value(1),
dtype="int64"
)
- lv1: R.Tensor((2, 2), dtype="int64") = R.strided_slice(
- x,
- axes=[0],
- begin=[R.prim_value(2)],
- end=[R.prim_value(4)],
- strides=[R.prim_value(1)],
- assume_inbound=False,
- )
- lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv),
axis=0)
-
+ lv1: R.Tensor((4,), dtype="int64") = R.add(lv, R.const(2,
"int64"))
+ lv2: R.Tensor((4,), dtype="int64") = R.mod(lv1, R.const(4,
"int64"))
+ lv3: R.Tensor((4, 2), dtype="int64") = R.take(x, lv2, axis=0,
mode="fast")
# Second roll along dim=1 with shift=1
- lv3: R.Tensor((4, 1), dtype="int64") = R.strided_slice(
- lv2,
- axes=[1],
- begin=[R.prim_value(0)],
- end=[R.prim_value(1)],
- strides=[R.prim_value(1)],
- assume_inbound=False,
- )
- lv4: R.Tensor((4, 1), dtype="int64") = R.strided_slice(
- lv2,
- axes=[1],
- begin=[R.prim_value(1)],
- end=[R.prim_value(2)],
- strides=[R.prim_value(1)],
- assume_inbound=False,
+ lv4: R.Tensor((2,), dtype="int64") = R.arange(
+ R.prim_value(0), R.prim_value(2), R.prim_value(1),
dtype="int64"
)
- lv5: R.Tensor((4, 2), dtype="int64") = R.concat((lv4, lv3),
axis=1)
- gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv5,)
+ lv5: R.Tensor((2,), dtype="int64") = R.add(lv4, R.const(1,
"int64"))
+ lv6: R.Tensor((2,), dtype="int64") = R.mod(lv5, R.const(2,
"int64"))
+ lv7: R.Tensor((4, 2), dtype="int64") = R.take(lv3, lv6,
axis=1, mode="fast")
+ gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv7,)
R.output(gv)
return gv
@@ -4260,9 +4215,9 @@ def test_roll():
example_input = torch.randint(0, 10, (4, 2), dtype=torch.int64)
# Run verification for each case
- verify_model(Roll1(), (example_input,), {}, Expected1)
- verify_model(Roll2(), (example_input,), {}, Expected2)
- verify_model(Roll3(), (example_input,), {}, Expected3)
+ verify_model(Roll1(), (example_input,), {}, Expected1,
run_ep_decomposition=True)
+ verify_model(Roll2(), (example_input,), {}, Expected2,
run_ep_decomposition=True)
+ verify_model(Roll3(), (example_input,), {}, Expected3,
run_ep_decomposition=True)
def test_select_slice():
@@ -4342,10 +4297,10 @@ def test_select_slice():
return gv
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(Slice1(), example_args, {}, expected1)
+ verify_model(Slice1(), example_args, {}, expected1,
run_ep_decomposition=True)
example_args = (torch.randn(8, 16, dtype=torch.float32),)
- verify_model(Slice2(), example_args, {}, expected2)
+ verify_model(Slice2(), example_args, {}, expected2,
run_ep_decomposition=True)
def test_slice_scatter():
@@ -4387,10 +4342,10 @@ def test_slice_scatter():
return gv
example_args = (torch.randn(8, 8, 10, 10, dtype=torch.float32),
torch.randn(8, 3, 10, 10))
- verify_model(SliceScatter1(), example_args, {}, expected1)
+ verify_model(SliceScatter1(), example_args, {}, expected1,
run_ep_decomposition=True)
example_args = (torch.randn(8, 16, dtype=torch.float32), torch.randn(6,
16))
- verify_model(SliceScatter2(), example_args, {}, expected2)
+ verify_model(SliceScatter2(), example_args, {}, expected2,
run_ep_decomposition=True)
def test_split():
@@ -4402,7 +4357,7 @@ def test_split():
class Expected:
@R.function
def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ input: R.Tensor((1, 3, 10, 10), dtype="float32")
) -> R.Tuple(
R.Tensor((1, 1, 10, 10), dtype="float32"),
R.Tensor((1, 1, 10, 10), dtype="float32"),
@@ -4414,7 +4369,7 @@ def test_split():
R.Tensor((1, 1, 10, 10), dtype="float32"),
R.Tensor((1, 1, 10, 10), dtype="float32"),
R.Tensor((1, 1, 10, 10), dtype="float32"),
- ) = R.split(input_1, indices_or_sections=3, axis=1)
+ ) = R.split(input, indices_or_sections=[1, 2], axis=1)
lv1: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[0]
lv2: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[1]
lv3: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[2]
@@ -4434,7 +4389,7 @@ def test_split():
class expected1:
@R.function
def main(
- input_1: R.Tensor((3, 3, 10, 10), dtype="float32")
+ data: R.Tensor((3, 3, 10, 10), dtype="float32")
) -> R.Tuple(
R.Tensor((3, 10, 10), dtype="float32"),
R.Tensor((3, 10, 10), dtype="float32"),
@@ -4442,30 +4397,38 @@ def test_split():
):
# block 0
with R.dataflow():
- lv: R.Tuple(
- R.Tensor((1, 3, 10, 10), dtype="float32"),
- R.Tensor((1, 3, 10, 10), dtype="float32"),
- R.Tensor((1, 3, 10, 10), dtype="float32"),
- ) = R.split(input_1, indices_or_sections=3, axis=0)
- lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
- lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1,
axis=[0])
- lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1]
- lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3,
axis=[0])
- lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[2]
- lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5,
axis=[0])
- lv7: R.Tuple(
- R.Tensor((3, 10, 10), dtype="float32"),
- R.Tensor((3, 10, 10), dtype="float32"),
- R.Tensor((3, 10, 10), dtype="float32"),
- ) = (lv2, lv4, lv6)
- lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0]
- lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1]
- lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2]
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.strided_slice(
+ data,
+ (R.prim_value(0),),
+ (R.prim_value(0),),
+ (R.prim_value(1),),
+ (R.prim_value(1),),
+ assume_inbound=False,
+ )
+ lv1: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.strided_slice(
+ data,
+ (R.prim_value(0),),
+ (R.prim_value(1),),
+ (R.prim_value(2),),
+ (R.prim_value(1),),
+ assume_inbound=False,
+ )
+ lv2: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.strided_slice(
+ data,
+ (R.prim_value(0),),
+ (R.prim_value(2),),
+ (R.prim_value(3),),
+ (R.prim_value(1),),
+ assume_inbound=False,
+ )
+ lv3: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv,
axis=[0])
+ lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1,
axis=[0])
+ lv5: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv2,
axis=[0])
gv: R.Tuple(
R.Tensor((3, 10, 10), dtype="float32"),
R.Tensor((3, 10, 10), dtype="float32"),
R.Tensor((3, 10, 10), dtype="float32"),
- ) = (lv8, lv9, lv10)
+ ) = (lv3, lv4, lv5)
R.output(gv)
return gv
@@ -4477,7 +4440,7 @@ def test_split():
class expected2:
@R.function
def main(
- input_1: R.Tensor((3, 3, 10, 10), dtype="float32")
+ data: R.Tensor((3, 3, 10, 10), dtype="float32")
) -> R.Tuple(
R.Tensor((3, 10, 10), dtype="float32"),
R.Tensor((3, 10, 10), dtype="float32"),
@@ -4485,39 +4448,47 @@ def test_split():
):
# block 0
with R.dataflow():
- lv: R.Tuple(
- R.Tensor((3, 1, 10, 10), dtype="float32"),
- R.Tensor((3, 1, 10, 10), dtype="float32"),
- R.Tensor((3, 1, 10, 10), dtype="float32"),
- ) = R.split(input_1, indices_or_sections=3, axis=1)
- lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0]
- lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1,
axis=[1])
- lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1]
- lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3,
axis=[1])
- lv5: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[2]
- lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5,
axis=[1])
- lv7: R.Tuple(
- R.Tensor((3, 10, 10), dtype="float32"),
- R.Tensor((3, 10, 10), dtype="float32"),
- R.Tensor((3, 10, 10), dtype="float32"),
- ) = (lv2, lv4, lv6)
- lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0]
- lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1]
- lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2]
+ lv: R.Tensor((3, 1, 10, 10), dtype="float32") =
R.strided_slice(
+ data,
+ (R.prim_value(1),),
+ (R.prim_value(0),),
+ (R.prim_value(1),),
+ (R.prim_value(1),),
+ assume_inbound=False,
+ )
+ lv1: R.Tensor((3, 1, 10, 10), dtype="float32") =
R.strided_slice(
+ data,
+ (R.prim_value(1),),
+ (R.prim_value(1),),
+ (R.prim_value(2),),
+ (R.prim_value(1),),
+ assume_inbound=False,
+ )
+ lv2: R.Tensor((3, 1, 10, 10), dtype="float32") =
R.strided_slice(
+ data,
+ (R.prim_value(1),),
+ (R.prim_value(2),),
+ (R.prim_value(3),),
+ (R.prim_value(1),),
+ assume_inbound=False,
+ )
+ lv3: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv,
axis=[1])
+ lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1,
axis=[1])
+ lv5: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv2,
axis=[1])
gv: R.Tuple(
R.Tensor((3, 10, 10), dtype="float32"),
R.Tensor((3, 10, 10), dtype="float32"),
R.Tensor((3, 10, 10), dtype="float32"),
- ) = (lv8, lv9, lv10)
+ ) = (lv3, lv4, lv5)
R.output(gv)
return gv
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(Chunk(), example_args, {}, Expected)
+ verify_model(Chunk(), example_args, {}, Expected,
run_ep_decomposition=True)
example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),)
- verify_model(Unbind1(), example_args, {}, expected1)
- verify_model(Unbind2(), example_args, {}, expected2)
+ verify_model(Unbind1(), example_args, {}, expected1,
run_ep_decomposition=True)
+ verify_model(Unbind2(), example_args, {}, expected2,
run_ep_decomposition=True)
def test_squeeze():
@@ -4545,18 +4516,18 @@ def test_squeeze():
class Expected2:
@R.function
def main(
- inp_0: R.Tensor((3, 1, 4, 1), dtype="float32")
+ input: R.Tensor((3, 1, 4, 1), dtype="float32")
) -> R.Tuple(R.Tensor((3, 4), dtype="float32")):
with R.dataflow():
- lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(inp_0,
axis=None)
+ lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(input,
axis=[1, 3])
gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,)
R.output(gv)
return gv
example_args = (torch.randn(3, 1, 4, 1, dtype=torch.float32),)
- verify_model(Squeeze1(), example_args, {}, Expected1)
- verify_model(Squeeze2(), example_args, {}, Expected2)
+ verify_model(Squeeze1(), example_args, {}, Expected1,
run_ep_decomposition=True)
+ verify_model(Squeeze2(), example_args, {}, Expected2,
run_ep_decomposition=True)
def test_stack():