This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new bfb0dd6a16 [Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(8) (#18428)
bfb0dd6a16 is described below
commit bfb0dd6a161d33c58f67469988244b139366e063
Author: Shushi Hong <[email protected]>
AuthorDate: Mon Nov 10 00:35:43 2025 -0500
    [Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(8) (#18428)
---
.../frontend/torch/base_fx_graph_translator.py | 8 +
.../frontend/torch/exported_program_translator.py | 2 +
.../relax/test_frontend_from_exported_program.py | 222 ++++++++++++++++-----
3 files changed, 179 insertions(+), 53 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 03e3b8d557..177e3d91f9 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1379,6 +1379,14 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
        keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
        return self.block_builder.emit(relax.op.variance(x, dim, keepdims=keepdim))
+    def _any(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
+        # For boolean tensors, any is equivalent to max (checking if any element is True)
+        return self.block_builder.emit(relax.op.max(x, dim, keepdims=keepdim))
+
########## Search ##########
def _argmax_argmin(self, op: Callable) -> Callable:
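Note: the new _any handler lowers torch.any(x, dim, keepdim) to relax.op.max, relying on the fact that a max-reduction over a boolean tensor is True exactly when at least one element is True. A minimal sketch of that equivalence in plain PyTorch (the input values are illustrative only):

    # any-reduction expressed as a max-reduction over a boolean tensor
    import torch

    x = torch.tensor([[True, False, False], [False, False, False]])
    any_dim = torch.any(x, dim=1)                   # reference reduction
    max_dim = x.to(torch.int32).amax(dim=1).bool()  # same result via max, as _any assumes
    assert torch.equal(any_dim, max_dim)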
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 4f3132b8d8..ddd19f2b58 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -930,6 +930,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"remainder.Tensor": self._binary_op(relax.op.floor_mod,
operator.mod),
"remainder.Scalar": self._binary_op(relax.op.floor_mod,
operator.mod),
"mul.Tensor": self._binary_op(relax.op.multiply, operator.mul),
+ "mul.Scalar": self._binary_op(relax.op.multiply, operator.mul),
"mul_.Tensor": self._binary_op(relax.op.multiply, operator.mul),
"ne.Tensor": self._binary_op(relax.op.not_equal, operator.ne),
"ne.Scalar": self._binary_op(relax.op.not_equal, operator.ne),
@@ -988,6 +989,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"upsample_nearest2d.vec": self._upsample_nearest2d,
"upsample_bicubic2d.vec": self._upsample_bicubic2d,
# statistical
+ "any.dim": self._any,
"mean.dim": self._mean,
"prod.default": self._prod,
"std.correction": self._std,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index c2ec57ee28..fb4f77567e 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2778,16 +2778,26 @@ def test_pixel_shuffle():
x: R.Tensor((1, 8, 10, 15), dtype="float32")
) -> R.Tuple(R.Tensor((1, 2, 20, 30), dtype="float32")):
with R.dataflow():
-                lv: R.Tensor((1, 2, 20, 30), dtype="float32") = R.nn.pixel_shuffle(
-                    x, upscale_factor=2
+                lv: R.Tensor((1, 2, 2, 2, 10, 15), dtype="float32") = R.reshape(
+                    x, R.shape([1, 2, 2, 2, 10, 15])
                )
-                gv: R.Tuple(R.Tensor((1, 2, 20, 30), dtype="float32")) = (lv,)
+                lv1: R.Tensor((1, 2, 10, 2, 15, 2), dtype="float32") = R.permute_dims(
+                    lv, axes=[0, 1, 4, 2, 5, 3]
+                )
+ lv2: R.Tensor((1, 2, 20, 30), dtype="float32") = R.reshape(
+ lv1, R.shape([1, 2, 20, 30])
+ )
+ gv: R.Tuple(R.Tensor((1, 2, 20, 30), dtype="float32")) = (lv2,)
R.output(gv)
return gv
example_args = (torch.randn(1, 8, 10, 15, dtype=torch.float32),)
- verify_model(PixelShuffle1(upscale_factor=2), example_args, {}, expected)
- verify_model(PixelShuffle2(upscale_factor=2), example_args, {}, expected)
+    verify_model(
+        PixelShuffle1(upscale_factor=2), example_args, {}, expected, run_ep_decomposition=True
+    )
+    verify_model(
+        PixelShuffle2(upscale_factor=2), example_args, {}, expected, run_ep_decomposition=True
+    )
def test_einsum():
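Note: the updated expectation reflects pixel_shuffle being decomposed into reshape, permute_dims, and reshape. A small plain-PyTorch check of that rewrite for the test's (1, 8, 10, 15) input (random illustrative values):

    import torch
    import torch.nn.functional as F

    x, r = torch.randn(1, 8, 10, 15), 2
    y = x.reshape(1, 2, r, r, 10, 15)    # split channels into (C_out, r, r)
    y = y.permute(0, 1, 4, 2, 5, 3)      # interleave the r factors with H and W
    y = y.reshape(1, 2, 10 * r, 15 * r)  # merge into the upscaled spatial dims
    assert torch.allclose(y, F.pixel_shuffle(x, upscale_factor=r))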
@@ -2832,10 +2842,10 @@ def test_einsum():
return gv
example_args = (torch.randn(4, 4, dtype=torch.float32),)
- verify_model(Einsum1(), example_args, {}, Expected1)
+    verify_model(Einsum1(), example_args, {}, Expected1, run_ep_decomposition=False)
    example_args = (torch.randn(5, dtype=torch.float32), torch.randn(4, dtype=torch.float32))
-    verify_model(Einsum2(), example_args, {}, Expected2)
+    verify_model(Einsum2(), example_args, {}, Expected2, run_ep_decomposition=False)
def test_outer():
@@ -2847,11 +2857,12 @@ def test_outer():
class expected:
@R.function
def main(
-            a: R.Tensor((3,), dtype="float32"), b: R.Tensor((4,), dtype="float32")
+            x: R.Tensor((3,), dtype="float32"), y: R.Tensor((4,), dtype="float32")
) -> R.Tuple(R.Tensor((3, 4), dtype="float32")):
with R.dataflow():
- lv: R.Tensor((3, 4), dtype="float32") = R.outer(a, b)
- gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,)
+                lv: R.Tensor((3, 1), dtype="float32") = R.reshape(x, R.shape([3, 1]))
+ lv1: R.Tensor((3, 4), dtype="float32") = R.multiply(lv, y)
+ gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv1,)
R.output(gv)
return gv
@@ -2859,7 +2870,7 @@ def test_outer():
torch.randn(3, dtype=torch.float32),
torch.randn(4, dtype=torch.float32),
)
- verify_model(Outer(), example_args, {}, expected)
+    verify_model(Outer(), example_args, {}, expected, run_ep_decomposition=True)
def test_embedding():
@@ -2889,7 +2900,7 @@ def test_embedding():
model = Embedding()
binding = {"w1": model.embedding.weight.detach().numpy()}
- verify_model(model, example_args, binding, expected1)
+    verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
def test_groupnorm():
@@ -3056,12 +3067,14 @@ def test_linear():
) -> R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")):
# block 0
with R.dataflow():
-                lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=None)
-                lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul(
-                    input_1, lv, out_dtype="float32"
+                lv: R.Tensor((30, 10), dtype="float32") = R.reshape(input_1, R.shape([30, 10]))
+                lv1: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=[1, 0])
+                lv2: R.Tensor((30, 7), dtype="float32") = R.matmul(lv, lv1, out_dtype="float32")
+ lv3: R.Tensor((30, 7), dtype="float32") = R.add(w2, lv2)
+ lv4: R.Tensor((1, 3, 10, 7), dtype="float32") = R.reshape(
+ lv3, R.shape([1, 3, 10, 7])
)
- lv2: R.Tensor((1, 3, 10, 7), dtype="float32") = R.add(lv1, w2)
- gv: R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")) = (lv2,)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")) = (lv4,)
R.output(gv)
return gv
@@ -3082,11 +3095,13 @@ def test_linear():
) -> R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")):
# block 0
with R.dataflow():
-                lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=None)
-                lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul(
-                    input_1, lv, out_dtype="float32"
+                lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=[1, 0])
+                lv1: R.Tensor((30, 10), dtype="float32") = R.reshape(input_1, R.shape([30, 10]))
+                lv2: R.Tensor((30, 7), dtype="float32") = R.matmul(lv1, lv, out_dtype="float32")
+ lv3: R.Tensor((1, 3, 10, 7), dtype="float32") = R.reshape(
+ lv2, R.shape([1, 3, 10, 7])
)
- gv: R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")) = (lv1,)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")) = (lv3,)
R.output(gv)
return gv
@@ -3094,15 +3109,15 @@ def test_linear():
model = Dense1()
    binding = {"w1": model.linear.weight.detach().numpy(), "w2": model.linear.bias.detach().numpy()}
- verify_model(model, example_args, binding, expected1)
+    verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
model = Dense1Func()
    binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()}
- verify_model(model, example_args, binding, expected1)
+    verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
model = Dense2()
binding = {"w1": model.linear.weight.detach().numpy()}
- verify_model(model, example_args, binding, expected2)
+    verify_model(model, example_args, binding, expected2, run_ep_decomposition=True)
def test_maxpool1d():
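Note: with decomposition enabled, Linear on a rank-4 input lowers to reshape, matmul against the transposed weight, bias add, and a reshape back, as in expected1 above. A plain-PyTorch sketch of the same equivalence (weight and bias values are illustrative):

    import torch

    x = torch.randn(1, 3, 10, 10)
    w, b = torch.randn(7, 10), torch.randn(7)  # nn.Linear(10, 7) parameters
    out = (x.reshape(30, 10) @ w.T + b).reshape(1, 3, 10, 7)
    assert torch.allclose(out, torch.nn.functional.linear(x, w, b), atol=1e-5)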
@@ -3415,27 +3430,76 @@ def test_scaled_dot_product_attention():
class Expected1:
@R.function
def main(
- inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"),
- inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"),
- inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"),
+ q: R.Tensor((32, 8, 128, 64), dtype="float32"),
+ k: R.Tensor((32, 8, 128, 64), dtype="float32"),
+ v: R.Tensor((32, 8, 128, 64), dtype="float32"),
) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")):
with R.dataflow():
-                lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
-                    inp_0, axes=[0, 2, 1, 3]
+                lv: R.Tensor((32, 8, 128, 64), dtype="float32") = R.multiply(
+                    q, R.const(0.35355338454246521, "float32")
+                )
+                lv1: R.Tensor((32, 8, 64, 128), dtype="float32") = R.permute_dims(
+                    k, axes=[0, 1, 3, 2]
+                )
+                lv2: R.Tensor((32, 8, 64, 128), dtype="float32") = R.multiply(
+                    lv1, R.const(0.35355338454246521, "float32")
+                )
+                lv3: R.Tensor((32, 8, 128, 64), dtype="float32") = R.broadcast_to(
+                    lv, R.shape([32, 8, 128, 64])
+                )
+                lv4: R.Tensor((256, 128, 64), dtype="float32") = R.reshape(
+                    lv3, R.shape([256, 128, 64])
+                )
+                lv5: R.Tensor((32, 8, 64, 128), dtype="float32") = R.broadcast_to(
+                    lv2, R.shape([32, 8, 64, 128])
+                )
+                lv6: R.Tensor((256, 64, 128), dtype="float32") = R.reshape(
+                    lv5, R.shape([256, 64, 128])
+                )
+                lv7: R.Tensor((256, 128, 128), dtype="float32") = R.matmul(
+                    lv4, lv6, out_dtype="float32"
+                )
+                lv8: R.Tensor((32, 8, 128, 128), dtype="float32") = R.reshape(
+                    lv7, R.shape([32, 8, 128, 128])
+                )
+                lv9: R.Tensor((32, 8, 128, 128), dtype="float32") = R.nn.softmax(lv8, axis=-1)
+                lv10: R.Tensor((32, 8, 128, 128), dtype="bool") = R.equal(
+                    lv8, R.const(float("-inf"), "float32")
                )
-                lv1: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
-                    inp_1, axes=[0, 2, 1, 3]
+                lv11: R.Tensor((32, 8, 128, 128), dtype="bool") = R.logical_not(lv10)
+                lv12: R.Tensor((32, 8, 128, 1), dtype="bool") = R.max(
+                    lv11, axis=[-1], keepdims=True
                )
-                lv2: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
-                    inp_2, axes=[0, 2, 1, 3]
+                lv13: R.Tensor((32, 8, 128, 1), dtype="bool") = R.logical_not(lv12)
+                lv14: R.Tensor((32, 8, 128, 128), dtype="float32") = R.full_like(
+                    lv9, R.const(0, "int32"), dtype="void"
                )
-                lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention(
-                    lv, lv1, lv2, scale=None
+                lv15: R.Tensor((32, 8, 128, 128), dtype="float32") = R.where(lv13, lv14, lv9)
+                lv16: R.Tensor((32, 8, 128, 128), dtype="float32") = R.broadcast_to(
+                    lv15, R.shape([32, 8, 128, 128])
                )
-                lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims(
-                    lv3, axes=[0, 2, 1, 3]
+                lv17: R.Tensor((256, 128, 128), dtype="float32") = R.reshape(
+                    lv16, R.shape([256, 128, 128])
                )
-                gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv4,)
+                lv18: R.Tensor((32, 8, 128, 64), dtype="float32") = R.broadcast_to(
+                    v, R.shape([32, 8, 128, 64])
+                )
+                lv19: R.Tensor((256, 128, 64), dtype="float32") = R.reshape(
+                    lv18, R.shape([256, 128, 64])
+                )
+                lv20: R.Tensor((256, 128, 64), dtype="float32") = R.matmul(
+                    lv17, lv19, out_dtype="float32"
+                )
+                lv21: R.Tensor((32, 8, 128, 64), dtype="float32") = R.reshape(
+                    lv20, R.shape([32, 8, 128, 64])
+                )
+                lv22: R.Tensor((128, 32, 8, 64), dtype="float32") = R.permute_dims(
+                    lv21, axes=[2, 0, 1, 3]
+                )
+                lv23: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims(
+                    lv22, axes=[1, 2, 0, 3]
+                )
+                gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv23,)
R.output(gv)
return gv
@@ -3447,28 +3511,78 @@ def test_scaled_dot_product_attention():
class Expected2:
@R.function
def main(
- inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"),
- inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"),
- inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"),
- inp_3: R.Tensor((32, 8, 128, 128), dtype="float32"),
+ q: R.Tensor((32, 8, 128, 64), dtype="float32"),
+ k: R.Tensor((32, 8, 128, 64), dtype="float32"),
+ v: R.Tensor((32, 8, 128, 64), dtype="float32"),
+ mask: R.Tensor((32, 8, 128, 128), dtype="float32"),
) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")):
with R.dataflow():
-                lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
-                    inp_0, axes=[0, 2, 1, 3]
+                lv: R.Tensor((32, 8, 128, 64), dtype="float32") = R.multiply(
+                    q, R.const(0.35355338454246521, "float32")
+                )
+                lv1: R.Tensor((32, 8, 64, 128), dtype="float32") = R.permute_dims(
+                    k, axes=[0, 1, 3, 2]
+                )
+                lv2: R.Tensor((32, 8, 64, 128), dtype="float32") = R.multiply(
+                    lv1, R.const(0.35355338454246521, "float32")
+                )
+                lv3: R.Tensor((32, 8, 128, 64), dtype="float32") = R.broadcast_to(
+                    lv, R.shape([32, 8, 128, 64])
+                )
+                lv4: R.Tensor((256, 128, 64), dtype="float32") = R.reshape(
+                    lv3, R.shape([256, 128, 64])
+                )
+                lv5: R.Tensor((32, 8, 64, 128), dtype="float32") = R.broadcast_to(
+                    lv2, R.shape([32, 8, 64, 128])
+                )
+                lv6: R.Tensor((256, 64, 128), dtype="float32") = R.reshape(
+                    lv5, R.shape([256, 64, 128])
                )
-                lv1: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
-                    inp_1, axes=[0, 2, 1, 3]
+                lv7: R.Tensor((256, 128, 128), dtype="float32") = R.matmul(
+                    lv4, lv6, out_dtype="float32"
                )
-                lv2: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
-                    inp_2, axes=[0, 2, 1, 3]
+                lv8: R.Tensor((32, 8, 128, 128), dtype="float32") = R.reshape(
+                    lv7, R.shape([32, 8, 128, 128])
                )
-                lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention(
-                    lv, lv1, lv2, inp_3, scale=None
+                lv9: R.Tensor((32, 8, 128, 128), dtype="float32") = R.add(lv8, mask)
+                lv10: R.Tensor((32, 8, 128, 128), dtype="float32") = R.nn.softmax(lv9, axis=-1)
+                lv11: R.Tensor((32, 8, 128, 128), dtype="bool") = R.equal(
+                    lv9, R.const(float("-inf"), "float32")
                )
-                lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims(
-                    lv3, axes=[0, 2, 1, 3]
+                lv12: R.Tensor((32, 8, 128, 128), dtype="bool") = R.logical_not(lv11)
+                lv13: R.Tensor((32, 8, 128, 1), dtype="bool") = R.max(
+                    lv12, axis=[-1], keepdims=True
                )
-                gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv4,)
+                lv14: R.Tensor((32, 8, 128, 1), dtype="bool") = R.logical_not(lv13)
+                lv15: R.Tensor((32, 8, 128, 128), dtype="float32") = R.full_like(
+                    lv10, R.const(0, "int32"), dtype="void"
+                )
+                lv16: R.Tensor((32, 8, 128, 128), dtype="float32") = R.where(lv14, lv15, lv10)
+                lv17: R.Tensor((32, 8, 128, 128), dtype="float32") = R.broadcast_to(
+                    lv16, R.shape([32, 8, 128, 128])
+                )
+                lv18: R.Tensor((256, 128, 128), dtype="float32") = R.reshape(
+                    lv17, R.shape([256, 128, 128])
+                )
+                lv19: R.Tensor((32, 8, 128, 64), dtype="float32") = R.broadcast_to(
+                    v, R.shape([32, 8, 128, 64])
+                )
+                lv20: R.Tensor((256, 128, 64), dtype="float32") = R.reshape(
+                    lv19, R.shape([256, 128, 64])
+                )
+                lv21: R.Tensor((256, 128, 64), dtype="float32") = R.matmul(
+                    lv18, lv20, out_dtype="float32"
+                )
+                lv22: R.Tensor((32, 8, 128, 64), dtype="float32") = R.reshape(
+                    lv21, R.shape([32, 8, 128, 64])
+                )
+                lv23: R.Tensor((128, 32, 8, 64), dtype="float32") = R.permute_dims(
+                    lv22, axes=[2, 0, 1, 3]
+                )
+                lv24: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims(
+                    lv23, axes=[1, 2, 0, 3]
+                )
+                gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv24,)
R.output(gv)
return gv
@@ -3481,6 +3595,7 @@ def test_scaled_dot_product_attention():
),
{},
Expected1,
+ run_ep_decomposition=True,
)
verify_model(
@@ -3493,6 +3608,7 @@ def test_scaled_dot_product_attention():
),
{},
Expected2,
+ run_ep_decomposition=True,
)
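Note: the constant 0.35355338454246521 in the decomposed attention IR above is 1 / sqrt(sqrt(64)) printed at float32 precision: the usual 1 / sqrt(head_dim) attention scale is split evenly between Q and K, so the two factors multiply back to 1/8. A quick numeric check:

    import math

    head_dim = 64
    split_scale = 1.0 / math.sqrt(math.sqrt(head_dim))
    assert abs(split_scale - 0.35355338454246521) < 1e-7
    assert abs(split_scale * split_scale - 1.0 / math.sqrt(head_dim)) < 1e-12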