This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 37fe645f94 [Relax] Ingest Tensor.clamp from torch export (#17725)
37fe645f94 is described below
commit 37fe645f945d56e0d4bfac1d9f3bcf355f950a1b
Author: Hugo Latendresse <[email protected]>
AuthorDate: Tue Mar 11 20:27:53 2025 -0400
[Relax] Ingest Tensor.clamp from torch export (#17725)
Allow handling of Torch.clamp when only min is passed, only max is
passed, or tensors are passed as arguments.
---
.../frontend/torch/base_fx_graph_translator.py | 96 +++++++++++++++++++---
.../frontend/torch/exported_program_translator.py | 3 +
python/tvm/relax/frontend/torch/fx_translator.py | 3 +-
tests/python/relax/test_from_exported_to_cuda.py | 81 ++++++++++++++++++
.../relax/test_frontend_from_exported_program.py | 62 +++++++++++++-
tests/python/relax/test_frontend_from_fx.py | 64 ++++++++++-----
6 files changed, 276 insertions(+), 33 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index a0f00e1f4b..6bbc9d5de6 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -19,6 +19,7 @@
# pylint: disable=import-outside-toplevel
"""Base class for PyTorch FX Graph importer."""
import abc
+import math
from typing import Callable, Dict, Optional, Tuple, Union
from tvm import relax
@@ -141,19 +142,94 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
def _clamp(self, node: fx.Node) -> relax.Expr:
args = self.retrieve_args(node)
- a_min = args[1] if len(args) > 1 else node.kwargs["min"]
- a_max = args[2] if len(args) > 2 else node.kwargs["max"]
+ x = args[0]
+ a_min = args[1] if len(args) > 1 else node.kwargs.get("min", -math.inf)
+ a_max = args[2] if len(args) > 2 else node.kwargs.get("max", math.inf)
+
+ a_min = -math.inf if a_min is None else a_min
+ a_max = math.inf if a_max is None else a_max
+
+ # Handle the case where a_min is a tensor
if not isinstance(a_min, (int, float)):
- raise ValueError(
- f"TVM only supports constant min value for torch.clamp/clip, "
- f"but got {a_min} with type {type(a_min)}"
+ from torch import fx
+
+ if isinstance(a_min, fx.Node):
+ # Extract relax Expr (needed for fx.tracer)
+ a_min = self.env[a_min]
+ assert isinstance(a_min, relax.Expr), (
+ f"Unexpected argument type "
+ f"passed to torch.clamp/clip: {a_min} with type {type(a_min)}"
)
+            a_min = self.block_builder.emit(relax.op.broadcast_to(a_min, self.shape_of(x)))
+ x = self.block_builder.emit(relax.op.maximum(x, a_min))
+ a_min = -math.inf
+
+ # Handle the case where a_max is a tensor
if not isinstance(a_max, (int, float)):
- raise ValueError(
- f"TVM only supports constant max value for torch.clamp/clip, "
- f"but got {a_max} with type {type(a_max)}"
+ from torch import fx
+
+ if isinstance(a_max, fx.Node):
+ # Extract relax Expr (needed for fx.tracer)
+ a_max = self.env[a_max]
+ assert isinstance(a_max, relax.Expr), (
+ f"Unexpected argument type "
+ f"passed to torch.clamp/clip: {a_max} with type {type(a_max)}"
+ )
+            a_max = self.block_builder.emit(relax.op.broadcast_to(a_max, self.shape_of(x)))
+ x = self.block_builder.emit(relax.op.minimum(x, a_max))
+ a_max = math.inf
+
+ return self.block_builder.emit(relax.op.clip(x, a_min, a_max))
+
+ def _clamp_min(self, node: fx.Node) -> relax.Expr:
+ args = self.retrieve_args(node)
+ x = args[0]
+ a_min = args[1] if len(args) > 1 else node.kwargs.get("min", -math.inf)
+ a_max = math.inf
+
+ a_min = -math.inf if a_min is None else a_min
+
+ # Handle the case where a_min is a tensor
+ if not isinstance(a_min, (int, float)):
+ from torch import fx
+
+ if isinstance(a_min, fx.Node):
+ # Extract relax Expr (needed for fx.tracer)
+ a_min = self.env[a_min]
+ assert isinstance(a_min, relax.Expr), (
+ f"Unexpected argument type "
+ f"passed to torch.clamp/clip: {a_min} with type {type(a_min)}"
)
- return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))
+            a_min = self.block_builder.emit(relax.op.broadcast_to(a_min, self.shape_of(x)))
+ x = self.block_builder.emit(relax.op.maximum(x, a_min))
+ a_min = -math.inf
+
+ return self.block_builder.emit(relax.op.clip(x, a_min, a_max))
+
+ def _clamp_max(self, node: fx.Node) -> relax.Expr:
+ args = self.retrieve_args(node)
+ x = args[0]
+ a_min = -math.inf
+ a_max = args[2] if len(args) > 2 else node.kwargs.get("max", math.inf)
+
+ a_max = math.inf if a_max is None else a_max
+
+ # Handle the case where a_max is a tensor
+ if not isinstance(a_max, (int, float)):
+ from torch import fx
+
+ if isinstance(a_max, fx.Node):
+ # Extract relax Expr (needed for fx.tracer)
+ a_max = self.env[a_max]
+ assert isinstance(a_max, relax.Expr), (
+ f"Unexpected argument type "
+ f"passed to torch.clamp/clip: {a_max} with type {type(a_max)}"
+ )
+            a_max = self.block_builder.emit(relax.op.broadcast_to(a_max, self.shape_of(x)))
+ x = self.block_builder.emit(relax.op.minimum(x, a_max))
+ a_max = math.inf
+
+ return self.block_builder.emit(relax.op.clip(x, a_min, a_max))
def _elu(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
@@ -696,8 +772,8 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size]))
def _layer_norm_impl(self, x, gamma, beta, eps, normalized_shape) -> relax.Var:
- from torch.fx.immutable_collections import immutable_list
import numpy as np # type: ignore
+ from torch.fx.immutable_collections import immutable_list
if isinstance(normalized_shape, (immutable_list, tuple)):
normalized_shape = tuple(normalized_shape)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 2103365c6c..71a3d13aa1 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -193,6 +193,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"bitwise_not.default": self._unary_op(relax.op.bitwise_not),
"ceil.default": self._unary_op(relax.op.ceil),
"clamp.default": self._clamp,
+ "clamp_min.default": self._clamp_min,
+ "clamp_max.default": self._clamp_max,
"cos.default": self._unary_op(relax.op.cos),
"cosh.default": self._unary_op(relax.op.cosh),
"dropout.default": lambda node: self.env[node.args[0]],
@@ -294,6 +296,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"argmin.default": self._argmax_argmin(relax.op.argmin),
# tensor manipulation
"cat.default": self._cat,
+ "clamp.Tensor": self._clamp,
"concat.default": self._cat,
"copy_.default": self._copy_,
"cumsum.default": self._cumsum,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index abda5088db..952fb6f971 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -18,8 +18,8 @@
# pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck
# pylint: disable=import-outside-toplevel
"""PyTorch FX frontend of Relax."""
-from typing import Callable, Dict, List, Tuple, Union
from functools import partial, reduce
+from typing import Callable, Dict, List, Tuple, Union
import tvm
from tvm import relax
@@ -598,6 +598,7 @@ class TorchFXImporter(BaseFXGraphImporter):
self,
) -> Dict[Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]]:
import operator
+
from torch import nn
return {
diff --git a/tests/python/relax/test_from_exported_to_cuda.py
b/tests/python/relax/test_from_exported_to_cuda.py
index e8b5da0dc2..6cc12370d6 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -56,6 +56,87 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
[email protected]_targets("cuda")
+def test_tensor_clamp(target, dev):
+ class ClampBothTensor(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.register_buffer("min_val", torch.tensor(-1.0))
+ self.register_buffer("max_val", torch.tensor(1.0))
+
+ def forward(self, x):
+ return x.clamp(min=self.min_val, max=self.max_val)
+
+ class ClampBothInt(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.min_val = -1
+ self.max_val = 1
+
+ def forward(self, x):
+ return x.clamp(min=self.min_val, max=self.max_val)
+
+ class ClampMinOnlyTensor(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.register_buffer("min_val", torch.tensor(0.0))
+
+ def forward(self, x):
+ return x.clamp(min=self.min_val)
+
+ class ClampMinOnlyInt(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.min_val = 0
+
+ def forward(self, x):
+ return x.clamp(min=self.min_val)
+
+ class ClampMaxOnlyTensor(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.register_buffer("max_val", torch.tensor(0.5))
+
+ def forward(self, x):
+ return x.clamp(max=self.max_val)
+
+ class ClampMaxOnlyInt(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.max_val = 0.5
+
+ def forward(self, x):
+ return x.clamp(max=self.max_val)
+
+ class ClampDifferentValues(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.min_val = -2
+ self.max_val = 2
+
+ def forward(self, x):
+ return x.clamp(min=self.min_val, max=self.max_val)
+
+ # Create random data with values outside our clamp ranges
+ raw_data = np.random.uniform(-3.0, 3.0, (2, 3, 4, 5)).astype(np.float32)
+
+ torch_module0 = ClampBothTensor().eval()
+ torch_module1 = ClampBothInt().eval()
+ torch_module2 = ClampMinOnlyTensor().eval()
+ torch_module3 = ClampMinOnlyInt().eval()
+ torch_module4 = ClampMaxOnlyTensor().eval()
+ torch_module5 = ClampMaxOnlyInt().eval()
+ torch_module6 = ClampDifferentValues().eval()
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module1, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module2, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module4, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module5, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module6, target, dev)
+
+
@tvm.testing.parametrize_targets("cuda")
def test_tensor_expand_as(target, dev):
class ExpandAs0(torch.nn.Module):
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 6406610bf5..8b0a711a52 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -135,18 +135,70 @@ def test_extended_unary_ops():
class expected_clamp:
@R.function
def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ input: R.Tensor((1, 3, 10, 10), dtype="float32"),
) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
- # block 0
with R.dataflow():
            lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(input_1, 0.1, 0.5)
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+ input,
+ R.prim_value(T.float64(0.10000000000000001)),
+ R.prim_value(T.float64(0.5)),
+ )
gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
R.output(gv)
return gv
verify_model(Clamp(), example_args, {}, expected_clamp)
+ class ClampMinOnly(Module):
+ def forward(self, input):
+ return torch.clamp(input, min=0.5, max=None)
+
+ @tvm.script.ir_module
+ class expected_clamp_min_only:
+ @R.function
+ def main(
+ input: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+                    input, R.prim_value(T.float64(0.5)), R.prim_value(T.float64("inf"))
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(ClampMinOnly(), example_args, {}, expected_clamp_min_only)
+
+ class ClampTensors(Module):
+ def forward(self, input):
+ return torch.clamp(input, min=input, max=input)
+
+ @tvm.script.ir_module
+ class expected_clamp_tensors:
+ @R.function
+ def main(
+ input: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.broadcast_to(
+                    input, R.shape([1, 3, 10, 10])
+                )
+                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.maximum(input, lv)
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.broadcast_to(
+                    input, R.shape([1, 3, 10, 10])
+                )
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.minimum(lv1, lv2)
+                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+                    lv3, R.prim_value(T.float64("-inf")), R.prim_value(T.float64("inf"))
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv4,)
+                R.output(gv)
+                return gv
+
+ verify_model(ClampTensors(), example_args, {}, expected_clamp_tensors)
+
# dropout
+
class Dropout1(Module):
def __init__(self):
super().__init__()
@@ -3248,3 +3300,7 @@ def test_no_bind_return_tuple():
exported_program = export(Identity(), args=example_args)
mod = from_exported_program(exported_program, no_bind_return_tuple=True)
tvm.ir.assert_structural_equal(mod, Expected)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index 020fc8f5b3..fbea8b7388 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
import operator
import pytest
import torch
@@ -21,6 +22,7 @@ import torch.nn.functional as F
from torch import fx
from torch.nn import Module
import torchvision
+import math
import tvm
from tvm import relax
@@ -1970,7 +1972,7 @@ def test_extended_unary_ops():
class expected_clamp:
@R.function
def main(
- input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
# block 0
with R.dataflow():
@@ -1981,29 +1983,53 @@ def test_extended_unary_ops():
verify_model(Clamp(), input_info, {}, expected_clamp)
- from tvm.relax.frontend.torch import from_fx
-
- with pytest.raises(
-        ValueError, match="TVM only supports constant max value for torch.clamp/clip"
- ):
+ class ClampMinOnly(Module):
+ def forward(self, input):
+ return torch.clamp(input, min=0.5, max=None)
- class Clamp_Error(Module):
- def forward(self, input):
- return torch.clamp(input, min=0.5, max=None)
+ @tvm.script.ir_module
+ class expected_clamp_min_only:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(input_1, 0.5, math.inf)
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
- gm = fx.symbolic_trace(Clamp_Error())
- from_fx(gm, input_info)
+ verify_model(ClampMinOnly(), input_info, {}, expected_clamp_min_only)
- with pytest.raises(
-        ValueError, match="TVM only supports constant min value for torch.clamp/clip"
- ):
+ class ClampTensors(Module):
+ def forward(self, input):
+ return torch.clamp(input, min=input, max=input)
- class Clamp_Error(Module):
- def forward(self, input):
- return torch.clamp(input, min=input, max=input)
+ @tvm.script.ir_module
+ class expected_clamp_tensors:
+ @R.function
+ def main(
+ inp_0: R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.broadcast_to(
+                    inp_0, R.shape([1, 3, 10, 10])
+                )
+                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.maximum(inp_0, lv)
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.broadcast_to(
+                    inp_0, R.shape([1, 3, 10, 10])
+                )
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.minimum(lv1, lv2)
+                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+                    lv3, R.prim_value(T.float64("-inf")), R.prim_value(T.float64("inf"))
+                )
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv4
+                R.output(gv)
+                return gv
- gm = fx.symbolic_trace(Clamp_Error())
- from_fx(gm, input_info)
+ verify_model(ClampTensors(), input_info, {}, expected_clamp_tensors)
# dropout
class Dropout1(Module):