This is an automated email from the ASF dual-hosted git repository.

leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6ed3ab3e33 [TFLite] Support quantized EQUAL op in TFLite frontend (#11520)
6ed3ab3e33 is described below

commit 6ed3ab3e33f8eafa4acaf53b7a671831de7587e9
Author: Dhruv Chauhan <[email protected]>
AuthorDate: Wed Jun 22 10:42:00 2022 +0100

    [TFLite] Support quantized EQUAL op in TFLite frontend (#11520)
    
    * [TFLite] Support quantized EQUAL op in TFLite frontend
    
    Support EQUAL quantization operation conversion as part of issue #9187
    
    * [TFLite] Support quantized EQUAL op in TFLite frontend
    
    Update elementwise quantized test for EQUAL op
    Change-Id: I3897d1ac07051ebfc10356ad45397117b592f878
---
 python/tvm/relay/frontend/tflite.py          |  6 +--
 tests/python/frontend/tflite/test_forward.py | 81 +++++++++++++++++++---------
 2 files changed, 57 insertions(+), 30 deletions(-)

diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 981074b6ad..2a9d66acff 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -1448,11 +1448,7 @@ class OperatorConverter(object):
 
     def convert_equal(self, op):
         """Convert TFLite EQUAL"""
-        if self.is_quantized(op):
-            raise tvm.error.OpNotImplemented(
-                "TFlite quantized EQUAL operator is not supported yet."
-            )
-        return self._convert_elemwise(_op.equal, op)
+        return self._convert_elemwise(_op.equal, op, self.is_quantized(op))
 
     def convert_not_equal(self, op):
         """Convert TFLite NOT_EQUAL"""
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 76b0766dae..23b5a03ffb 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -2214,22 +2214,33 @@ def _test_elemwise(
                 if None != x[0]
             }
 
-            out = math_op(inq_data[0], inq_data[1])
-            out = with_fused_activation_function(out, fused_activation_function)
-            out = tf.quantization.fake_quant_with_min_max_args(
-                out, min=out_min, max=out_max, name="out"
-            )
+            if math_op is math_ops.equal:
+                out = math_op(inq_data[0], inq_data[1])
+                out = with_fused_activation_function(out, fused_activation_function)
 
-            # Note same_qnn_params uses experimental_new_converter as toco failed
-            compare_tflite_with_tvm(
-                [x[1] for x in zip(in_data, data) if None != x[0]],
-                [x + ":0" for x in input_range.keys()],
-                [x[1] for x in zip(in_data, inq_data) if None != x[0]],
-                [out],
-                quantized=True,
-                input_range=input_range,
-                experimental_new_converter=same_qnn_params,
-            )
+                compare_tflite_with_tvm(
+                    [x[1] for x in zip(in_data, data) if None != x[0]],
+                    [x + ":0" for x in input_range.keys()],
+                    [x[1] for x in zip(in_data, inq_data) if None != x[0]],
+                    [out],
+                )
+            else:
+                out = math_op(inq_data[0], inq_data[1])
+                out = with_fused_activation_function(out, fused_activation_function)
+                out = tf.quantization.fake_quant_with_min_max_args(
+                    out, min=out_min, max=out_max, name="out"
+                )
+
+                # Note same_qnn_params uses experimental_new_converter as toco failed
+                compare_tflite_with_tvm(
+                    [x[1] for x in zip(in_data, data) if None != x[0]],
+                    [x + ":0" for x in input_range.keys()],
+                    [x[1] for x in zip(in_data, inq_data) if None != x[0]],
+                    [out],
+                    quantized=True,
+                    input_range=input_range,
+                    experimental_new_converter=same_qnn_params,
+                )
         else:
             out = math_op(
                 in_data[0]
@@ -2386,9 +2397,16 @@ def _test_less_equal(data):
 # -----
 
 
-def _test_equal(data):
+def _test_equal(data, fused_activation_function=None, quantized=False, qnn_op=None):
     """One iteration of equal"""
-    return _test_elemwise(math_ops.equal, data)
+    return _test_elemwise(
+        math_ops.equal,
+        data,
+        fused_activation_function,
+        quantized,
+        qnn_op,
+        same_qnn_params=True,
+    )
 
 
 #######################################################################
@@ -2454,14 +2472,25 @@ def _test_forward_elemwise(testop):
 
 
 def _test_forward_elemwise_quantized(testop):
-    testop(
-        [
-            np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8),
-            np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8),
-        ],
-        quantized=True,
-        qnn_op=testop,
-    )
+    if testop is not _test_equal:
+        testop(
+            [
+                np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8),
+                np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.uint8),
+            ],
+            quantized=True,
+            qnn_op=testop,
+        )
+    else:
+        # no need for fake_quant to hold tensors in float32 until conversion
+        testop(
+            [
+                np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.float32),
+                np.array(np.random.uniform(0, 255, (3, 6)), dtype=np.float32),
+            ],
+            quantized=True,
+            qnn_op=testop,
+        )
 
 
 def _test_elemwise_qnn_out_range(qnn_op):
@@ -2472,6 +2501,7 @@ def _test_elemwise_qnn_out_range(qnn_op):
         _test_mul: (-5e3, 5e3),
         _test_maximum: (-112, 111),
         _test_minimum: (-128, 127),
+        _test_equal: (-150, 150),
     }
 
     return qnn_out_range[qnn_op]
@@ -2506,6 +2536,7 @@ def test_all_elemwise():
     _test_forward_elemwise(_test_less)
     _test_forward_elemwise(_test_less_equal)
     _test_forward_elemwise(_test_equal)
+    _test_forward_elemwise_quantized(_test_equal)
     _test_forward_elemwise(_test_not_equal)
     if package_version.parse(tf.VERSION) >= package_version.parse("1.14.0"):
         _test_forward_elemwise(_test_floor_divide)

Reply via email to