This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 12e72c29e8 [Relax][PyTorch] Add support for elu, hardtanh ops (#17694)
12e72c29e8 is described below
commit 12e72c29e877bced07f4d0cb5a0861e92cddb38c
Author: Shushi Hong <[email protected]>
AuthorDate: Sun Mar 2 11:47:52 2025 +0800
[Relax][PyTorch] Add support for elu, hardtanh ops (#17694)
* Update fx_translator.py
* Update test_frontend_from_fx.py
* Update base_fx_graph_translator.py
* Update fx_translator.py
* Update test_frontend_from_fx.py
* Update fx_translator.py
* Update fx_translator.py
* lint
* lint
* lint
---
.../frontend/torch/base_fx_graph_translator.py | 29 +++++++++
python/tvm/relax/frontend/torch/fx_translator.py | 4 ++
tests/python/relax/test_frontend_from_fx.py | 70 ++++++++++++++++++++++
3 files changed, 103 insertions(+)
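For quick reference, a minimal sketch of the PyTorch semantics of the two ops added here; the values are illustrative and this snippet is not part of the commit:

import torch

x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
# ELU: x for x > 0, alpha * (exp(x) - 1) otherwise
print(torch.nn.functional.elu(x, alpha=1.0))
# Hardtanh: clamp x to [min_val, max_val], defaults are [-1.0, 1.0]
print(torch.nn.functional.hardtanh(x, min_val=-1.0, max_val=1.0))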
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index d84993c68d..e601f18181 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -127,6 +127,28 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
)
return self.block_builder.emit(relax.op.clip(args[0], a_min, a_max))
+ def _elu(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("alpha", 1.0)
+ dtype = x.struct_info.dtype
+
+ if isinstance(alpha, (int, float)):
+ alpha = relax.const(alpha, dtype)
+ else:
+ if not isinstance(alpha, relax.Var):
+ alpha = self.block_builder.emit(relax.const(alpha, dtype))
+
+ # α⋅ReLU(1−exp(x))+ReLU(x)
+ return self.block_builder.emit(
+ relax.op.add(
+ relax.op.multiply(
+ alpha,
+ relax.op.nn.relu(relax.op.subtract(relax.const(1, dtype), relax.op.exp(x))),
+ ),
+ relax.op.nn.relu(x),
+ )
+ )
+
def _gelu(self, node: fx.Node) -> relax.Expr:
approximate = node.kwargs.get("approximate", "none")
if approximate == "none":
@@ -153,6 +175,13 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
x2 = relax.op.divide(x1, relax.const(6, dtype))
return self.block_builder.emit(relax.op.multiply(x, x2))
+ def _hardtanh(self, node: fx.Node) -> relax.Expr:
+ args = self.retrieve_args(node)
+ x = args[0]
+ min_val = node.kwargs.get("min_val", -1.0)
+ max_val = node.kwargs.get("max_val", 1.0)
+ return self.block_builder.emit(relax.op.clip(x, min_val, max_val))
+
def _leakyrelu(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("negative_slope", 0.01)
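As a reading aid, a minimal NumPy sketch of what the expression emitted by the new _elu converter computes; the function names and test values are illustrative and not part of the commit:

import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def emitted_elu(x, alpha=1.0):
    # Mirrors the Relax expression built in _elu above:
    # alpha * relu(1 - exp(x)) + relu(x)
    return alpha * relu(1.0 - np.exp(x)) + relu(x)

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0], dtype="float32")
print(emitted_elu(x))

The _hardtanh converter, by contrast, maps directly onto relax.op.clip using the min_val/max_val defaults of torch.nn.Hardtanh (-1.0 and 1.0).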
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index dffe2b60eb..bbad7c0c70 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -581,9 +581,11 @@ class TorchFXImporter(BaseFXGraphImporter):
## call_module
# unary
nn.Dropout: lambda node: self.env[node.args[0]],
+ nn.ELU: self._elu,
nn.GELU: self._gelu,
nn.Hardsigmoid: self._hardsigmoid,
nn.Hardswish: self._hardswish,
+ nn.Hardtanh: self._hardtanh,
nn.Identity: lambda node: self.env[node.args[0]],
nn.LeakyReLU: self._leakyrelu_module,
nn.LogSoftmax: self._log_softmax_module,
@@ -627,12 +629,14 @@ class TorchFXImporter(BaseFXGraphImporter):
"cos": self._unary_op(relax.op.cos),
"cosh": self._unary_op(relax.op.cosh),
"dropout": lambda node: self.env[node.args[0]],
+ "elu": self._elu,
"erf": self._unary_op(relax.op.erf),
"exp": self._unary_op(relax.op.exp),
"floor": self._unary_op(relax.op.floor),
"gelu": self._gelu,
"hardsigmoid": self._hardsigmoid,
"hardswish": self._hardswish,
+ "hardtanh": self._hardtanh,
"isfinite": self._unary_op(relax.op.isfinite),
"isinf": self._unary_op(relax.op.isinf),
"isnan": self._unary_op(relax.op.isnan),
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 9e7e1ff2ea..797ce05a3f 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1989,6 +1989,46 @@ def test_extended_unary_ops():
verify_model(Dropout1(), input_info, {}, expected_dropout)
verify_model(Dropout2(), input_info, {}, expected_dropout)
+ # elu
+ class Elu(Module):
+ def __init__(self):
+ super().__init__()
+ self.elu = torch.nn.ELU()
+
+ def forward(self, input):
+ return self.elu(input)
+
+ class Elu2(Module):
+ def forward(self, input):
+ return torch.nn.functional.elu(input)
+
+ @tvm.script.ir_module
+ class expected_elu:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ # block 0
+ with R.dataflow():
+ lv_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(input_1)
+ lv_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
+ R.const(1.0, dtype="float32"), lv_exp
+ )
+ lv_relu_one_minus_exp: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(
+ lv_one_minus_exp
+ )
+ lv_scaled: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+ R.const(1.0, dtype="float32"), lv_relu_one_minus_exp
+ )
+ lv_relu_x: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1)
+ lv_elu: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv_scaled, lv_relu_x)
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv_elu
+ R.output(gv)
+ return gv
+
+ verify_model(Elu(), input_info, {}, expected_elu)
+ verify_model(Elu2(), input_info, {}, expected_elu)
+
# gelu
class Gelu(Module):
def __init__(self):
@@ -2086,6 +2126,36 @@ def test_extended_unary_ops():
verify_model(Hardswish(), input_info, {}, expected_hardswish)
verify_model(Hardswish2(), input_info, {}, expected_hardswish)
+ # hardtanh
+ class Hardtanh(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.ht = torch.nn.Hardtanh()
+
+ def forward(self, input):
+ return self.ht(input)
+
+ class Hardtanh2(torch.nn.Module):
+ def forward(self, input):
+ return torch.nn.functional.hardtanh(input)
+
+ @tvm.script.ir_module
+ class expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+ inp_0, R.prim_value(T.float64(-1.0)), R.prim_value(T.float64(1.0))
+ )
+ gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Hardtanh(), input_info, {}, expected1)
+ verify_model(Hardtanh2(), input_info, {}, expected1)
+
# logical_not
class LogicalNot(Module):
def forward(self, input):