This is an automated email from the ASF dual-hosted git repository.

leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b54f57aa72 [TFLite] Add support for GELU conversion (#16936)
b54f57aa72 is described below

commit b54f57aa721a4e619dbe187bb4ac0cfd37988c71
Author: Luke Hutton <[email protected]>
AuthorDate: Sun Apr 28 10:33:37 2024 +0100

    [TFLite] Add support for GELU conversion (#16936)
    
    This commit adds support for converting a TFLite fp32 GELU operation
    to Relay.
    
    Also includes some neighbouring cleanup of version checks to silence
    warnings.
    
    Change-Id: Ic43b1525c4b80cf7f47281c52bb9a8f2643c4073
---
 python/tvm/relay/frontend/tflite.py          | 21 +++++++++++++++++++++
 tests/python/frontend/tflite/test_forward.py | 19 ++++++++++++++++---
 2 files changed, 37 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 3648864239..e939895ade 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -109,6 +109,7 @@ class OperatorConverter(object):
             "GATHER_ND": self.convert_gather_nd,
             "GREATER_EQUAL": self.convert_greater_equal,
             "GREATER": self.convert_greater,
+            "GELU": self.convert_gelu,
             "HARD_SWISH": self.convert_hard_swish,
             "L2_NORMALIZATION": self.convert_l2_normalization,
             "L2_POOL_2D": self.convert_l2_pool2d,
@@ -1287,6 +1288,26 @@ class OperatorConverter(object):
 
         return out
 
+    def convert_gelu(self, op):
+        """Convert TFLite GELU"""
+        if self.is_quantized(op):
+            raise tvm.error.OpNotImplemented(
+                "The TFLite to Relay converter does not support quantized GELU operator yet."
+            )
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 1, "input tensors length should be 1"
+
+        input_tensor = input_tensors[0]
+        in_expr = self.get_expr(input_tensor.tensor_idx)
+        in_type = self.get_tensor_type_str(input_tensor.tensor.Type())
+
+        return in_expr * (
+            _expr.const(0.5, dtype=in_type)
+            + _op.erf(in_expr * _expr.const(0.5**0.5, dtype=in_type))
+            * _expr.const(0.5, dtype=in_type)
+        )
+
     def convert_square(self, op):
         """Convert TFLite SQUARE"""
         input_tensors = self.get_input_tensors(op)
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 7f65cfbc85..ebf7bce250 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -2150,7 +2150,9 @@ def _test_unary_elemwise(math_op, data, quantized, quant_range=(-6, 6), int_quan
         with tf.Graph().as_default():
             in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="in")
             out = math_op(in_data)
-            compare_tflite_with_tvm(data, ["in:0"], [in_data], [out])
+            compare_tflite_with_tvm(
+                data, ["in:0"], [in_data], [out], experimental_new_converter=True
+            )
 
 
 def _unary_elewise_create_model(math_op, data, offset=0, int_quant_dtype=tf.int8):
@@ -2400,6 +2402,16 @@ def _test_elu(data, quantized, int_quant_dtype=tf.int8):
     return _test_unary_elemwise(nn_ops.elu, data, quantized, int_quant_dtype=int_quant_dtype)
 
 
+#######################################################################
+# Gelu
+# ---
+
+
+def _test_gelu(data, quantized, int_quant_dtype=tf.int8):
+    """One iteration of elu"""
+    return _test_unary_elemwise(nn_ops.gelu, data, quantized, int_quant_dtype=int_quant_dtype)
+
+
 def _test_forward_unary_elemwise(test_op, int_quant_dtype=None, quantized=True, negative=True):
     # input data
     in_data, inq_data = [], []
@@ -2439,15 +2451,16 @@ def test_all_unary_elemwise():
     _test_forward_unary_elemwise(_test_sin)
     _test_forward_unary_elemwise(_test_neg)
     _test_forward_unary_elemwise(_test_sqrt, negative=False)
+    _test_forward_unary_elemwise(_test_gelu, quantized=False)
     # tensorflow version upgrade support
-    if tf.__version__ < LooseVersion("2.6.1"):
+    if package_version.parse(tf.VERSION) < package_version.parse("2.6.1"):
         _test_forward_unary_elemwise(_test_rsqrt, negative=False, int_quant_dtype=tf.uint8)
     else:
         _test_forward_unary_elemwise(_test_rsqrt, negative=False, int_quant_dtype=tf.int8)
     # ceil and cos come with TFLite 1.14.0.post1 fbs schema
     if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
         _test_forward_unary_elemwise(_test_ceil)
-        if tf.__version__ < LooseVersion("2.6.1"):
+        if package_version.parse(tf.VERSION) < package_version.parse("2.6.1"):
             _test_forward_unary_elemwise(_test_cos, quantized=False)
         else:
             _test_forward_unary_elemwise(_test_cos, int_quant_dtype=tf.int8)

Reply via email to