This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 16ec246df9 [Relax][PyTorch] Add support for celu, selu,
is_floating_point ops (#17702)
16ec246df9 is described below
commit 16ec246df91cf664cf88c12c068f140124603ead
Author: Shushi Hong <[email protected]>
AuthorDate: Tue Mar 4 12:35:52 2025 +0800
[Relax][PyTorch] Add support for celu, selu, is_floating_point ops (#17702)
* Update fx_translator.py
* Update base_fx_graph_translator.py
* Update test_frontend_from_fx.py
* lint
* lint
* Update fx_translator.py
* Update base_fx_graph_translator.py
* Update test_frontend_from_fx.py
---
.../frontend/torch/base_fx_graph_translator.py | 65 ++++++++++++-
python/tvm/relax/frontend/torch/fx_translator.py | 11 +++
tests/python/relax/test_frontend_from_fx.py | 103 ++++++++++++++++++++-
3 files changed, 175 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index e601f18181..4ce899685a 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -111,6 +111,34 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
return convert
+ def _celu(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ alpha = node.args[1] if len(node.args) > 1 else
node.kwargs.get("alpha", 1.0)
+ dtype = x.struct_info.dtype
+
+ if isinstance(alpha, (int, float)):
+ alpha = relax.const(alpha, dtype)
+ else:
+ if not isinstance(alpha, relax.Var):
+ alpha = self.block_builder.emit(relax.const(alpha, dtype))
+
+ zero = relax.const(0, dtype)
+ # alpha * min(0, exp(x / alpha) - 1) + max(0, x)
+ # NOTE(review): the code below computes exp(x) / alpha, not exp(x / alpha);
+ # these differ whenever alpha != 1 — confirm against torch.nn.CELU semantics.
+ return self.block_builder.emit(
+ relax.op.add(
+ relax.op.multiply(
+ alpha,
+ relax.op.minimum(
+ zero,
+ relax.op.subtract(
+ relax.op.divide(relax.op.exp(x), alpha),
relax.const(1, dtype)
+ ),
+ ),
+ ),
+ relax.op.nn.relu(x),
+ )
+ )
+
def _clamp(self, node: fx.Node) -> relax.Expr:
args = self.retrieve_args(node)
a_min = args[1] if len(args) > 1 else node.kwargs["min"]
@@ -133,12 +161,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
dtype = x.struct_info.dtype
if isinstance(alpha, (int, float)):
- alpha = relax.const(alpha, dtype)
+ alpha = relax.const(-alpha, dtype)
else:
if not isinstance(alpha, relax.Var):
- alpha = self.block_builder.emit(relax.const(alpha, dtype))
+ alpha = self.block_builder.emit(relax.const(-alpha, dtype))
- # α⋅ReLU(1−exp(x))+ReLU(x)
+ # alpha * ReLU(1 - exp(x)) + ReLU(x)
return self.block_builder.emit(
relax.op.add(
relax.op.multiply(
@@ -203,6 +231,37 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim",
-1)
return self.block_builder.emit(relax.op.nn.softmax(x, dim))
+ def _selu(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ alpha = node.args[1] if len(node.args) > 1 else
node.kwargs.get("alpha", 1.6732631921768188)
+ gamma = node.args[2] if len(node.args) > 2 else
node.kwargs.get("gamma", 1.0507009873554805)
+ dtype = x.struct_info.dtype
+
+ if isinstance(alpha, (int, float)):
+ alpha = relax.const(alpha, dtype)
+ else:
+ if not isinstance(alpha, relax.Var):
+ alpha = self.block_builder.emit(relax.const(alpha, dtype))
+
+ if isinstance(gamma, (int, float)):
+ gamma = relax.const(gamma, dtype)
+ else:
+ if not isinstance(gamma, relax.Var):
+ gamma = self.block_builder.emit(relax.const(gamma, dtype))
+
+ # gamma * (ReLU(x) + alpha * (exp(x) - 1))
+ return self.block_builder.emit(
+ relax.op.multiply(
+ gamma,
+ relax.op.add(
+ relax.op.nn.relu(x),
+ relax.op.multiply(
+ alpha, relax.op.subtract(relax.op.exp(x),
relax.const(1, dtype))
+ ),
+ ),
+ )
+ )
+
def _tril_triu(self, op: Callable) -> Callable:
from torch import fx
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index bbad7c0c70..af84f71bbf 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -527,6 +527,12 @@ class TorchFXImporter(BaseFXGraphImporter):
def _half(self, node: fx.Node) -> relax.Var:
return self.block_builder.emit(relax.op.astype(self.env[node.args[0]],
"float16"))
+ def _is_floating_point(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ return relax.const(
+ x.struct_info.dtype in ["float16", "float32", "float64",
"bfloat16"], "bool"
+ )
+
def _to(self, node: fx.Node) -> relax.Var:
import torch
@@ -580,6 +586,7 @@ class TorchFXImporter(BaseFXGraphImporter):
return {
## call_module
# unary
+ nn.CELU: self._celu,
nn.Dropout: lambda node: self.env[node.args[0]],
nn.ELU: self._elu,
nn.GELU: self._gelu,
@@ -594,6 +601,7 @@ class TorchFXImporter(BaseFXGraphImporter):
relax.op.clip(self.env[node.args[0]], 0, 6)
),
nn.Sigmoid: self._unary_op(relax.op.sigmoid),
+ nn.SELU: self._selu,
nn.SiLU: self._unary_op(relax.op.nn.silu),
nn.Softmax: self._softmax_module,
nn.Tanh: self._unary_op(relax.op.tanh),
@@ -625,6 +633,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"atanh": self._unary_op(relax.op.atanh),
"bitwise_not": self._unary_op(relax.op.bitwise_not),
"ceil": self._unary_op(relax.op.ceil),
+ "celu": self._celu,
"clamp": self._clamp,
"cos": self._unary_op(relax.op.cos),
"cosh": self._unary_op(relax.op.cosh),
@@ -648,6 +657,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"relu": self._unary_op(relax.op.nn.relu),
"round": self._round,
"rsqrt": self._unary_op(relax.op.rsqrt),
+ "selu": self._selu,
"sigmoid": self._unary_op(relax.op.sigmoid),
"sign": self._unary_op(relax.op.sign),
"silu": self._unary_op(relax.op.nn.silu),
@@ -753,6 +763,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"astype": self._type,
"float": self._float,
"half": self._half,
+ "is_floating_point": self._is_floating_point,
"to": self._to,
"type": self._type,
# other
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index 797ce05a3f..e9fa796531 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1917,6 +1917,50 @@ def test_bool_unary_ops(pytorch_op, relax_op):
def test_extended_unary_ops():
input_info = [([1, 3, 10, 10], "float32")]
+ # celu
+ class Celu1(Module):
+ def __init__(self):
+ super().__init__()
+ self.celu = torch.nn.CELU()
+
+ def forward(self, input):
+ return self.celu(input)
+
+ class Celu2(Module):
+ def forward(self, input):
+ return torch.nn.functional.celu(input)
+
+ # alpha * min(0, exp(x / alpha) - 1) + max(0, x)
+ # NOTE(review): the expected IR below divides exp(x) by alpha (not x by alpha
+ # before exp); this matches the importer only because the default alpha is 1.0.
+ @tvm.script.ir_module
+ class expected_celu:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
+ lv_div: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+ lv, R.const(1.0, "float32")
+ )
+ lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
+ lv_div, R.const(1.0, "float32")
+ )
+ lv_min: R.Tensor((1, 3, 10, 10), dtype="float32") = R.minimum(
+ R.const(0.0, "float32"), lv_sub
+ )
+ lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.multiply(
+ R.const(1.0, "float32"), lv_min
+ )
+ lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.nn.relu(input_1)
+ lv_celu: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.add(lv_scaled, lv_relu_x)
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv_celu
+ R.output(gv)
+ return gv
+
+ verify_model(Celu1(), input_info, {}, expected_celu)
+ verify_model(Celu2(), input_info, {}, expected_celu)
+
# clamp
class Clamp(Module):
def forward(self, input):
@@ -2018,7 +2062,7 @@ def test_extended_unary_ops():
lv_one_minus_exp
)
lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.multiply(
- R.const(1.0, dtype="float32"), lv_relu_one_minus_exp
+ R.const(-1.0, dtype="float32"), lv_relu_one_minus_exp
)
lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.nn.relu(input_1)
lv_elu: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.add(lv_scaled, lv_relu_x)
@@ -2256,6 +2300,46 @@ def test_extended_unary_ops():
verify_model(ReLU6(), input_info, {}, expected_relu6)
+ # selu
+ class Selu1(Module):
+ def __init__(self):
+ super().__init__()
+ self.selu = torch.nn.SELU()
+
+ def forward(self, input):
+ return self.selu(input)
+
+ class Selu2(Module):
+ def forward(self, input):
+ return torch.nn.functional.selu(input)
+
+ @tvm.script.ir_module
+ class expected_selu:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv_relu: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.nn.relu(input_1)
+ lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.exp(input_1)
+ lv_sub: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
+ lv_exp, R.const(1.0, "float32")
+ )
+ lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.multiply(
+ R.const(1.6732631921768188, "float32"), lv_sub
+ )
+ lv_add: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.add(lv_relu, lv_scaled)
+ lv_selu: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.multiply(
+ R.const(1.0507009873554805, "float32"), lv_add
+ )
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv_selu
+ R.output(gv)
+ return gv
+
+ verify_model(Selu1(), input_info, {}, expected_selu)
+ verify_model(Selu2(), input_info, {}, expected_selu)
+
# sigmoid
class Sigmoid(Module):
def __init__(self):
@@ -3802,5 +3886,22 @@ def test_masked_scatter():
)
+def test_is_floating_point():
+ class IsFloatingPoint(Module):
+ def forward(self, x):
+ return torch.is_floating_point(x)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(inp_0: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((),
dtype="bool"):
+ with R.dataflow():
+ gv: R.Tensor((), dtype="bool") = R.const(True, "bool")
+ R.output(gv)
+ return gv
+
+ verify_model(IsFloatingPoint(), [([2, 3], "float32")], {}, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()