This is an automated email from the ASF dual-hosted git repository.
leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b54f57aa72 [TFLite] Add support for GELU conversion (#16936)
b54f57aa72 is described below
commit b54f57aa721a4e619dbe187bb4ac0cfd37988c71
Author: Luke Hutton <[email protected]>
AuthorDate: Sun Apr 28 10:33:37 2024 +0100
[TFLite] Add support for GELU conversion (#16936)
This commit adds support for converting a TFLite fp32 GELU operation
to Relay.
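Also replaces the `LooseVersion` TensorFlow version checks in the neighbouring
test code with `package_version.parse` comparisons to silence deprecation
warnings, and runs the unary element-wise tests through the TFLite converter
with `experimental_new_converter=True`.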
Change-Id: Ic43b1525c4b80cf7f47281c52bb9a8f2643c4073
---
python/tvm/relay/frontend/tflite.py | 21 +++++++++++++++++++++
tests/python/frontend/tflite/test_forward.py | 19 ++++++++++++++++---
2 files changed, 37 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 3648864239..e939895ade 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -109,6 +109,7 @@ class OperatorConverter(object):
"GATHER_ND": self.convert_gather_nd,
"GREATER_EQUAL": self.convert_greater_equal,
"GREATER": self.convert_greater,
+ "GELU": self.convert_gelu,
"HARD_SWISH": self.convert_hard_swish,
"L2_NORMALIZATION": self.convert_l2_normalization,
"L2_POOL_2D": self.convert_l2_pool2d,
@@ -1287,6 +1288,26 @@ class OperatorConverter(object):
return out
+    def convert_gelu(self, op):
+        """Convert TFLite GELU to Relay.
+
+        TFLite's GELU operator carries a ``GeluOptions.approximate`` flag. When the
+        flag is unset (or the options table is absent) the exact form
+        ``0.5 * x * (1 + erf(x / sqrt(2)))`` is emitted. When it is set, the tanh
+        approximation ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))``
+        is emitted instead, so the Relay graph reproduces the TFLite kernel.
+
+        Raises
+        ------
+        tvm.error.OpNotImplemented
+            If the operator is quantized; only floating-point GELU is supported.
+        """
+        if self.is_quantized(op):
+            raise tvm.error.OpNotImplemented(
+                "The TFLite to Relay converter does not support quantized GELU operator yet."
+            )
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 1, "input tensors length should be 1"
+
+        input_tensor = input_tensors[0]
+        in_expr = self.get_expr(input_tensor.tensor_idx)
+        in_type = self.get_tensor_type_str(input_tensor.tensor.Type())
+
+        approximate = False
+        try:
+            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.GeluOptions import GeluOptions
+        except ImportError:
+            # NOTE(review): a tflite schema package without GeluOptions cannot
+            # describe the flag, so fall back to the exact (default) GELU.
+            pass
+        else:
+            if op.BuiltinOptionsType() == BuiltinOptions.GeluOptions:
+                op_options = op.BuiltinOptions()
+                gelu_options = GeluOptions()
+                gelu_options.Init(op_options.Bytes, op_options.Pos)
+                approximate = bool(gelu_options.Approximate())
+
+        half = _expr.const(0.5, dtype=in_type)
+
+        if approximate:
+            # tanh approximation: 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
+            cubic = in_expr * in_expr * in_expr
+            inner = (in_expr + _expr.const(0.044715, dtype=in_type) * cubic) * _expr.const(
+                0.7978845608028654, dtype=in_type  # sqrt(2 / pi)
+            )
+            return in_expr * (half + _op.tanh(inner) * half)
+
+        # Exact form: x * (0.5 + 0.5 * erf(x * sqrt(0.5))) == 0.5 * x * (1 + erf(x / sqrt(2)))
+        return in_expr * (half + _op.erf(in_expr * _expr.const(0.5**0.5, dtype=in_type)) * half)
+
def convert_square(self, op):
"""Convert TFLite SQUARE"""
input_tensors = self.get_input_tensors(op)
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 7f65cfbc85..ebf7bce250 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -2150,7 +2150,9 @@ def _test_unary_elemwise(math_op, data, quantized, quant_range=(-6, 6), int_quan
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape,
dtype=data.dtype, name="in")
out = math_op(in_data)
- compare_tflite_with_tvm(data, ["in:0"], [in_data], [out])
+ compare_tflite_with_tvm(
+ data, ["in:0"], [in_data], [out],
experimental_new_converter=True
+ )
def _unary_elewise_create_model(math_op, data, offset=0,
int_quant_dtype=tf.int8):
@@ -2400,6 +2402,16 @@ def _test_elu(data, quantized, int_quant_dtype=tf.int8):
return _test_unary_elemwise(nn_ops.elu, data, quantized,
int_quant_dtype=int_quant_dtype)
+#######################################################################
+# Gelu
+# ----
+
+
+def _test_gelu(data, quantized, int_quant_dtype=tf.int8):
+    """One iteration of gelu"""
+    return _test_unary_elemwise(nn_ops.gelu, data, quantized, int_quant_dtype=int_quant_dtype)
+
+
def _test_forward_unary_elemwise(test_op, int_quant_dtype=None,
quantized=True, negative=True):
# input data
in_data, inq_data = [], []
@@ -2439,15 +2451,16 @@ def test_all_unary_elemwise():
_test_forward_unary_elemwise(_test_sin)
_test_forward_unary_elemwise(_test_neg)
_test_forward_unary_elemwise(_test_sqrt, negative=False)
+ _test_forward_unary_elemwise(_test_gelu, quantized=False)
# tensorflow version upgrade support
- if tf.__version__ < LooseVersion("2.6.1"):
+ if package_version.parse(tf.VERSION) < package_version.parse("2.6.1"):
_test_forward_unary_elemwise(_test_rsqrt, negative=False,
int_quant_dtype=tf.uint8)
else:
_test_forward_unary_elemwise(_test_rsqrt, negative=False,
int_quant_dtype=tf.int8)
# ceil and cos come with TFLite 1.14.0.post1 fbs schema
if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
_test_forward_unary_elemwise(_test_ceil)
- if tf.__version__ < LooseVersion("2.6.1"):
+ if package_version.parse(tf.VERSION) < package_version.parse("2.6.1"):
_test_forward_unary_elemwise(_test_cos, quantized=False)
else:
_test_forward_unary_elemwise(_test_cos, int_quant_dtype=tf.int8)