This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 22e592b744 [FRONTEND][TFLITE][BugFix] Fix int16 transpose conv loading (#15173)
22e592b744 is described below

commit 22e592b744744e9e7fef0ae74034dffb5d3b31dc
Author: Wooseok Lee <[email protected]>
AuthorDate: Thu Jun 29 01:23:05 2023 -0500

    [FRONTEND][TFLITE][BugFix] Fix int16 transpose conv loading (#15173)
    
    Loading an int16 quantized transpose convolution op from a TFLite
    model currently fails because its output type is not int64.
    
    This patch sets the output type to int64 for int16 quantized
    transpose convolution. In addition, a typo in QnnConv2DTransposeRel
    is fixed: its output-dtype check tested data->dtype instead of
    param->out_dtype for the int64 case.
    
    A test is also included that loads int8 and int16 quantized
    transpose convolution ops and compares TVM's output with TFLite's.
    
    Co-authored-by: Wooseok <[email protected]>
---
 python/tvm/relay/frontend/tflite.py          |  3 +-
 src/relay/qnn/op/convolution_transpose.cc    |  2 +-
 tests/python/frontend/tflite/test_forward.py | 80 ++++++++++++++++++++++++++++
 3 files changed, 83 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 9e2e244cb1..9e88a85e03 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -3300,6 +3300,7 @@ class OperatorConverter(object):
             kernel_zero_point = weights_tensor.qnn_params["zero_point"]
             input_scale = input_tensor.qnn_params["scale"]
             kernel_scale = weights_tensor.qnn_params["scale"]
+            out_dtype = "int64" if output_tensor_type_str == "int16" else "int32"
             out = _qnn.op.conv2d_transpose(
                 in_expr,
                 weight_expr_iohw,
@@ -3313,7 +3314,7 @@ class OperatorConverter(object):
                 kernel_size=(int(kernel_h), int(kernel_w)),
                 data_layout="NHWC",
                 kernel_layout="IOHW",
-                out_dtype="int32",
+                out_dtype=out_dtype,
             )
         else:
             out = _op.nn.conv2d_transpose(
diff --git a/src/relay/qnn/op/convolution_transpose.cc b/src/relay/qnn/op/convolution_transpose.cc
index 951c1bdfb0..0b24ae71ca 100644
--- a/src/relay/qnn/op/convolution_transpose.cc
+++ b/src/relay/qnn/op/convolution_transpose.cc
@@ -99,7 +99,7 @@ bool QnnConv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs
  ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8))
      << "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype;
  ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32) ||
-         data->dtype == DataType::Int(64))
+         param->out_dtype == DataType::Int(64))
      << "Expected qnn conv2d type(int16, int32, int64) for output but was " << param->out_dtype;
  ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";
 
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 3b3dcc59f0..c65e48b402 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -1648,6 +1648,86 @@ def test_forward_transpose_conv():
             )
 
 
+def _test_tflite2_quantized_transpose_conv(
+    input_shape,
+    kernel_shape,
+    filters,
+    padding="valid",
+    strides=(1, 1),
+    data_format=None,
+    int_quant_dtype=tf.int8,
+):
+    """One iteration of TFLite2 quantized transpose conv with given shapes and attributes"""
+    data_format = "channels_last" if data_format == "NHWC" else "channels_first"
+    data = np.random.uniform(0, 1, input_shape).astype("float32")
+    _ = np.random.uniform(0, 1, kernel_shape).astype("float32")
+
+    data_in = tf.keras.layers.Input(shape=data.shape[1:], batch_size=1)
+    transpose_conv = tf.keras.layers.Conv2DTranspose(
+        filters=filters,
+        kernel_size=(kernel_shape[0], kernel_shape[1]),
+        padding=padding,
+        strides=strides,
+        use_bias=True,
+    )(data_in)
+    keras_model = tf.keras.models.Model(data_in, transpose_conv)
+
+    # To create quantized values with dynamic range of activations, needs representative dataset
+    def representative_data_gen():
+        for _ in range(1):
+            yield [data]
+
+    tflite_model_quant = _quantize_keras_model(
+        keras_model,
+        representative_data_gen,
+        is_float_input=True,
+        is_float_output=True,
+        int_quant_dtype=int_quant_dtype,
+    )
+
+    # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
+    try:
+        import tflite.Model
+
+        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_quant, 0)
+    except AttributeError:
+        import tflite
+
+        tflite_model = tflite.Model.GetRootAsModel(tflite_model_quant, 0)
+    except ImportError as exc:
+        raise ImportError("The tflite package must be installed") from exc
+
+    subgraph = tflite_model.Subgraphs(0)
+    model_input = subgraph.InputsAsNumpy()
+    input_node = subgraph.Tensors(model_input).Name().decode("utf-8")
+
+    tflite_output = run_tflite_graph(tflite_model_quant, data)
+
+    if tf.__version__ < LooseVersion("2.9"):
+        input_node = data_in.name.replace(":0", "")
+    else:
+        input_node = "serving_default_" + data_in.name + ":0"
+
+    tvm_output = run_tvm_graph(tflite_model_quant, data, input_node)
+    tvm.testing.assert_allclose(
+        np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-2, atol=1e-2
+    )
+
+
+def test_forward_quantized_transpose_conv():
+    """Quantized transpose convolution with int8 and int16 quantization"""
+    for int_quant_dtype in [tf.int8, tf.int16]:
+        _test_tflite2_quantized_transpose_conv(
+            (1, 1, 5, 64),
+            (3, 3),
+            64,
+            padding="same",
+            strides=(1, 2),
+            data_format="NHWC",
+            int_quant_dtype=int_quant_dtype,
+        )
+
+
 #######################################################################
 # Reshape
 # -------

Reply via email to