This is an automated email from the ASF dual-hosted git repository.

mousius pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 034dc67d03 [TFLite] Enable int64 biases for int16 quantized operators 
(#12042)
034dc67d03 is described below

commit 034dc67d032aac3b848e15a87a7fbb5b72a0b909
Author: Leandro Nunes <[email protected]>
AuthorDate: Tue Nov 15 10:30:50 2022 +0000

    [TFLite] Enable int64 biases for int16 quantized operators (#12042)
    
    This enables int64 biases for quantized fully connected and transpose
    convolution operators in TFLite networks, and allows int16 inputs to
    requantize. It builds on the existing int16 support in the TFLite
    frontend.
    
    Add a test case using the int16-quantized DS_CNN model.
---
 python/tvm/relay/frontend/tflite.py                |   6 +-
 src/relay/qnn/op/convolution_transpose.cc          |  10 +-
 src/relay/qnn/op/dense.cc                          |  10 +-
 src/relay/qnn/op/requantize.cc                     |   5 +-
 .../test_ethosn/test_convert_equivalents.py        |   4 +-
 tests/python/frontend/tflite/test_forward.py       |  23 +
 tests/python/relay/test_op_qnn_requantize.py       | 495 ++++++++++++---------
 7 files changed, 329 insertions(+), 224 deletions(-)

diff --git a/python/tvm/relay/frontend/tflite.py 
b/python/tvm/relay/frontend/tflite.py
index 1915eb9322..3d2f4a2f25 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -1966,7 +1966,7 @@ class OperatorConverter(object):
                 input_scale=input_tensor.qnn_params["scale"],
                 kernel_scale=weight_tensor.qnn_params["scale"],
                 units=weight_shape[0],
-                out_dtype="int32",
+                out_dtype="int64" if output_tensor_type_str == "int16" else 
"int32",
             )
         else:
             out = _op.nn.dense(in_expr, weight_expr, units=weight_shape[0])
@@ -1977,7 +1977,7 @@ class OperatorConverter(object):
             if bias_tensor.tensor_idx != -1:
                 bias_tensor_type = bias_tensor.tensor.Type()
                 # bias tensor type should be INT32 (quantization) or FLOAT32
-                assert bias_tensor_type in (TensorType.INT32, 
TensorType.FLOAT32)
+                assert bias_tensor_type in (TensorType.INT32, 
TensorType.INT64, TensorType.FLOAT32)
                 bias_tensor_type_str = 
self.get_tensor_type_str(bias_tensor_type)
                 if self.has_expr(bias_tensor.tensor_idx):
                     bias_expr = self.get_expr(bias_tensor.tensor_idx)
@@ -3175,7 +3175,7 @@ class OperatorConverter(object):
             bias_tensor = input_tensors[3]
             bias_tensor_type = bias_tensor.tensor.Type()
             # bias tensor type should be INT32 (quantization) or FLOAT32
-            assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
+            assert bias_tensor_type in (TensorType.INT32, TensorType.INT64, 
TensorType.FLOAT32)
             bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
             if self.has_expr(bias_tensor.tensor_idx):
                 bias_expr = self.get_expr(bias_tensor.tensor_idx)
diff --git a/src/relay/qnn/op/convolution_transpose.cc 
b/src/relay/qnn/op/convolution_transpose.cc
index 6163e1c204..951c1bdfb0 100644
--- a/src/relay/qnn/op/convolution_transpose.cc
+++ b/src/relay/qnn/op/convolution_transpose.cc
@@ -93,12 +93,14 @@ bool QnnConv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs
   if (data == nullptr || weight == nullptr) return false;
   const auto* param = attrs.as<Conv2DTransposeAttrs>();
   ICHECK(param != nullptr) << "Conv2DTransposeAttrs cannot be nullptr.";
-  ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
-      << "Expected qnn conv2d type(int8, uint8) for input but was " << data->dtype;
+  ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8) ||
+         data->dtype == DataType::Int(16) || data->dtype == DataType::UInt(16))
+      << "Expected qnn conv2d type(int8, uint8, int16, uint16) for input but was " << data->dtype;
   ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8))
       << "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype;
-  ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32))
-      << "Expected qnn conv2d type(int32, int16) for output but was " << param->out_dtype;
+  ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32) ||
+         param->out_dtype == DataType::Int(64))
+      << "Expected qnn conv2d type(int16, int32, int64) for output but was " << param->out_dtype;
   ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";
 
   // Check the types of scale and zero points.
diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc
index adaf509e7d..09d51e3c9c 100644
--- a/src/relay/qnn/op/dense.cc
+++ b/src/relay/qnn/op/dense.cc
@@ -47,12 +47,14 @@ bool QnnDenseRel(const Array<Type>& types, int num_inputs, 
const Attrs& attrs,
   if (data == nullptr || weight == nullptr) return false;
   const auto* param = attrs.as<DenseAttrs>();
   ICHECK(param != nullptr) << "DenseAttrs cannot be nullptr.";
-  ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
-      << "Expected quantized dense type(int8, uint8) for input but was " << 
data->dtype;
+  ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8) ||
+         data->dtype == DataType::Int(16) || data->dtype == DataType::UInt(16))
+      << "Expected quantized dense type(int8, uint8, int16, uint16) for input 
but was "
+      << data->dtype;
   ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == 
DataType::UInt(8))
       << "Expected quantized dense type(int8, uint8) for weight but was " << 
weight->dtype;
-  ICHECK(param->out_dtype == DataType::Int(32))
-      << "Expected quantized dense type(int32) for output but was " << 
param->out_dtype;
+  ICHECK(param->out_dtype == DataType::Int(32) || param->out_dtype == 
DataType::Int(64))
+      << "Expected quantized dense type(int32, int64) for output but was " << 
param->out_dtype;
 
   // Check the types of scale and zero points.
   for (size_t i = 2; i < 5; ++i) {
diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc
index 1614652719..e199ea27f1 100644
--- a/src/relay/qnn/op/requantize.cc
+++ b/src/relay/qnn/op/requantize.cc
@@ -480,8 +480,9 @@ bool RequantizeRel(const Array<Type>& types, int 
num_inputs, const Attrs& attrs,
   }
   const auto in_dtype = data->dtype;
   ICHECK(in_dtype == DataType::Int(8) || in_dtype == DataType::UInt(8) ||
-         in_dtype == DataType::Int(32) || in_dtype == DataType::Int(64))
-      << "Input type should be one of [int8, uint8, int32, int64] but was " << 
in_dtype;
+         in_dtype == DataType::Int(16) || in_dtype == DataType::Int(32) ||
+         in_dtype == DataType::Int(64))
+      << "Input type should be one of [int8, uint8, int16, int32, int64] but 
was " << in_dtype;
 
   const RequantizeAttrs* requantize_attrs = attrs.as<RequantizeAttrs>();
   int axis = requantize_attrs->axis;
diff --git a/tests/python/contrib/test_ethosn/test_convert_equivalents.py 
b/tests/python/contrib/test_ethosn/test_convert_equivalents.py
index 7777729372..a3e48f4424 100644
--- a/tests/python/contrib/test_ethosn/test_convert_equivalents.py
+++ b/tests/python/contrib/test_ethosn/test_convert_equivalents.py
@@ -227,7 +227,7 @@ def test_multiply_to_reinterpret_quantize(shape, 
constant_shape, reverse_inputs)
 @requires_ethosn
 @pytest.mark.parametrize(
     "dtype,shape,constant_shape",
-    [("int16", (1, 16, 12, 4), None)],
+    [("float32", (1, 16, 12, 4), None)],
 )
 def test_unsupported_multiply_to_reinterpret_quantize(dtype, shape, 
constant_shape):
     """
@@ -445,7 +445,7 @@ def test_add_to_reinterpret_quantize(shape, constant_shape, 
reverse_inputs):
 @pytest.mark.parametrize(
     "dtype,shape,constant_shape",
     [
-        ("int16", (1, 16, 12, 4), None),
+        ("float32", (1, 16, 12, 4), None),
     ],
 )
 def test_unsupported_add_to_reinterpret_quantize(dtype, shape, constant_shape):
diff --git a/tests/python/frontend/tflite/test_forward.py 
b/tests/python/frontend/tflite/test_forward.py
index 7b2bd60d8a..877406ae2a 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -4878,6 +4878,28 @@ def test_forward_mobilenet_int16():
     tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
 
 
+def test_forward_ds_cnn_int16():
+    """Test DS_CNN int16 quantized model"""
+    tflite_model_file = download_testdata(
+        "https://github.com/ARM-software/ML-zoo/blob/48f458af1e9065d9aad2ad94d24b58d6e7c00817/"
+        "models/keyword_spotting/ds_cnn_small/tflite_int16/ds_cnn_quantized.tflite?raw=true",
+        "ds_cnn_quantized_int16.tflite",
+    )
+
+    with open(tflite_model_file, "rb") as f:
+        tflite_model_buf = f.read()
+
+    data = np.random.uniform(size=(1, 490)).astype("int16")
+
+    tflite_output = run_tflite_graph(tflite_model_buf, data)
+    tflite_predictions = np.squeeze(tflite_output)
+    tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
+    tvm_output = run_tvm_graph(tflite_model_buf, data, "serving_default_input:0")
+    tvm_predictions = np.squeeze(tvm_output)
+    tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
+    tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
+
+
 #######################################################################
 # Unidirectional Sequence LSTM
 # ---------------------
@@ -5250,3 +5272,4 @@ if __name__ == "__main__":
     test_forward_tflite_float16()
 
     test_forward_tflite_int16()
+    test_forward_ds_cnn_int16()
diff --git a/tests/python/relay/test_op_qnn_requantize.py 
b/tests/python/relay/test_op_qnn_requantize.py
index 64306476df..1dee1f5b61 100644
--- a/tests/python/relay/test_op_qnn_requantize.py
+++ b/tests/python/relay/test_op_qnn_requantize.py
@@ -23,6 +23,7 @@ from tvm.contrib import graph_executor
 
 roundings = ["UPWARD", "TONEAREST"]
 compute_dtypes = ["float32", "float64", "int64"]
+out_dtypes = ["int8", "int16"]
 
 
 def verify(mod, goldens, target="llvm"):
@@ -83,17 +84,18 @@ def test_same_scale():
     golden_output = golden_data
     for compute_dtype in compute_dtypes:
         for rounding in roundings:
-            mod = get_mod(
-                data_shape=(200,),
-                data_dtype="int32",
-                out_dtype="int8",
-                input_scale=0.5,
-                output_scale=0.5,
-                rounding=rounding,
-                compute_dtype=compute_dtype,
-            )
-            assert "right_shift" not in mod.astext()
-            verify(mod, (golden_data, golden_output))
+            for qnn_out_dtype in out_dtypes:
+                mod = get_mod(
+                    data_shape=(200,),
+                    data_dtype="int32",
+                    out_dtype=qnn_out_dtype,
+                    input_scale=0.5,
+                    output_scale=0.5,
+                    rounding=rounding,
+                    compute_dtype=compute_dtype,
+                )
+                assert "right_shift" not in mod.astext()
+                verify(mod, (golden_data, golden_output))
 
 
 def test_scalar_same_scale():
@@ -102,75 +104,77 @@ def test_scalar_same_scale():
     golden_output = golden_data
     for compute_dtype in compute_dtypes:
         for rounding in roundings:
-            mod = get_mod(
-                data_shape=(),
-                data_dtype="int32",
-                out_dtype="int8",
-                input_scale=0.5,
-                output_scale=0.5,
-                rounding=rounding,
-                compute_dtype=compute_dtype,
-            )
-            assert "right_shift" not in mod.astext()
-            verify(mod, (golden_data, golden_output))
+            for qnn_out_dtype in out_dtypes:
+                mod = get_mod(
+                    data_shape=(),
+                    data_dtype="int32",
+                    out_dtype=qnn_out_dtype,
+                    input_scale=0.5,
+                    output_scale=0.5,
+                    rounding=rounding,
+                    compute_dtype=compute_dtype,
+                )
+                assert "right_shift" not in mod.astext()
+                verify(mod, (golden_data, golden_output))
 
 
 def test_downscale():
     for compute_dtype in compute_dtypes:
         for rounding in roundings:
-            mod = get_mod(
-                data_shape=(32,),
-                data_dtype="int32",
-                out_dtype="int8",
-                input_scale=1,
-                output_scale=16,
-                rounding=rounding,
-                compute_dtype=compute_dtype,
-            )
+            for qnn_out_dtype in out_dtypes:
+                mod = get_mod(
+                    data_shape=(32,),
+                    data_dtype="int32",
+                    out_dtype=qnn_out_dtype,
+                    input_scale=1,
+                    output_scale=16,
+                    rounding=rounding,
+                    compute_dtype=compute_dtype,
+                )
 
-            # Try positive values
-            # 8 corresponds to 0.5, resulting in 1
-            golden_data = np.arange(0, 32, 1).astype("int32")
-            golden_output = np.repeat([0, 1, 2], [8, 16, 8])
-            verify(mod, (golden_data, golden_output))
+                # Try positive values
+                # 8 corresponds to 0.5, resulting in 1
+                golden_data = np.arange(0, 32, 1).astype("int32")
+                golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+                verify(mod, (golden_data, golden_output))
 
-            # Try negative values
-            # -8 corresponds to -0.5. For UPWARD, this is 0
-            golden_data = np.arange(0, -32, -1).astype("int32")
-            if rounding == "UPWARD":
-                golden_output = np.repeat([0, -1, -2], [9, 16, 7])
-            else:
-                golden_output = np.repeat([0, -1, -2], [8, 16, 8])
-            verify(mod, (golden_data, golden_output))
+                # Try negative values
+                # -8 corresponds to -0.5. For UPWARD, this is 0
+                golden_data = np.arange(0, -32, -1).astype("int32")
+                if rounding == "UPWARD":
+                    golden_output = np.repeat([0, -1, -2], [9, 16, 7])
+                else:
+                    golden_output = np.repeat([0, -1, -2], [8, 16, 8])
+                verify(mod, (golden_data, golden_output))
 
-            # Try a different scale
-            mod = get_mod(
-                data_shape=(32,),
-                data_dtype="int32",
-                out_dtype="int8",
-                input_scale=1,
-                output_scale=4,
-                rounding=rounding,
-            )
+                # Try a different scale
+                mod = get_mod(
+                    data_shape=(32,),
+                    data_dtype="int32",
+                    out_dtype=qnn_out_dtype,
+                    input_scale=1,
+                    output_scale=4,
+                    rounding=rounding,
+                )
 
-            # Try positive values
-            # 2I corresponds to 0.5, resulting in 1
-            golden_data = np.arange(0, 32, 1).astype("int32")
-            golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8], [2, 4, 4, 
4, 4, 4, 4, 4, 2])
-            verify(mod, (golden_data, golden_output))
+                # Try positive values
+                # 2I corresponds to 0.5, resulting in 1
+                golden_data = np.arange(0, 32, 1).astype("int32")
+                golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8], [2, 4, 
4, 4, 4, 4, 4, 4, 2])
+                verify(mod, (golden_data, golden_output))
 
-            # Try negative values
-            # -8 corresponds to -0.5. For UPWARD, this is 0
-            golden_data = np.arange(0, -32, -1).astype("int32")
-            if rounding == "UPWARD":
-                golden_output = np.repeat(
-                    [0, -1, -2, -3, -4, -5, -6, -7, -8], [3, 4, 4, 4, 4, 4, 4, 
4, 1]
-                )
-            else:
-                golden_output = np.repeat(
-                    [0, -1, -2, -3, -4, -5, -6, -7, -8], [2, 4, 4, 4, 4, 4, 4, 
4, 2]
-                )
-            verify(mod, (golden_data, golden_output))
+                # Try negative values
+                # -8 corresponds to -0.5. For UPWARD, this is 0
+                golden_data = np.arange(0, -32, -1).astype("int32")
+                if rounding == "UPWARD":
+                    golden_output = np.repeat(
+                        [0, -1, -2, -3, -4, -5, -6, -7, -8], [3, 4, 4, 4, 4, 
4, 4, 4, 1]
+                    )
+                else:
+                    golden_output = np.repeat(
+                        [0, -1, -2, -3, -4, -5, -6, -7, -8], [2, 4, 4, 4, 4, 
4, 4, 4, 2]
+                    )
+                verify(mod, (golden_data, golden_output))
 
             # Try uint8 out_dtype
             mod = get_mod(
@@ -208,74 +212,76 @@ def test_downscale():
 def test_upscale():
     for compute_dtype in compute_dtypes:
         for rounding in roundings:
-            mod = get_mod(
-                data_shape=(32,),
-                data_dtype="int32",
-                out_dtype="int8",
-                input_scale=2,
-                output_scale=1,
-                rounding=rounding,
-                compute_dtype=compute_dtype,
-            )
+            for qnn_out_dtype in out_dtypes:
+                mod = get_mod(
+                    data_shape=(32,),
+                    data_dtype="int32",
+                    out_dtype=qnn_out_dtype,
+                    input_scale=2,
+                    output_scale=1,
+                    rounding=rounding,
+                    compute_dtype=compute_dtype,
+                )
 
-            # Try positive values
-            # 8 corresponds to 0.5, resulting in 1
-            golden_data = np.arange(0, 32, 1).astype("int32")
-            golden_output = np.multiply(2, golden_data)
-            verify(mod, (golden_data, golden_output))
+                # Try positive values
+                # 8 corresponds to 0.5, resulting in 1
+                golden_data = np.arange(0, 32, 1).astype("int32")
+                golden_output = np.multiply(2, golden_data)
+                verify(mod, (golden_data, golden_output))
 
-            # Try negative values
-            # -8 corresponds to -0.5. For UPWARD, this is 0
-            golden_data = np.arange(0, -32, -1).astype("int32")
-            golden_output = np.multiply(2, golden_data)
-            verify(mod, (golden_data, golden_output))
+                # Try negative values
+                # -8 corresponds to -0.5. For UPWARD, this is 0
+                golden_data = np.arange(0, -32, -1).astype("int32")
+                golden_output = np.multiply(2, golden_data)
+                verify(mod, (golden_data, golden_output))
 
 
 def test_non_power_of_two():
     for compute_dtype in compute_dtypes:
         for rounding in roundings:
-            mod = get_mod(
-                data_shape=(32,),
-                data_dtype="int32",
-                out_dtype="int8",
-                input_scale=1,
-                output_scale=3,
-                rounding=rounding,
-                compute_dtype=compute_dtype,
-            )
+            for qnn_out_dtype in out_dtypes:
+                mod = get_mod(
+                    data_shape=(32,),
+                    data_dtype="int32",
+                    out_dtype=qnn_out_dtype,
+                    input_scale=1,
+                    output_scale=3,
+                    rounding=rounding,
+                    compute_dtype=compute_dtype,
+                )
 
-            # Try positive values
-            golden_data = np.multiply(np.arange(0, 32, 1).astype("int32"), 3)
-            golden_output = np.arange(0, 32, 1)
-            verify(mod, (golden_data, golden_output))
+                # Try positive values
+                golden_data = np.multiply(np.arange(0, 32, 1).astype("int32"), 
3)
+                golden_output = np.arange(0, 32, 1)
+                verify(mod, (golden_data, golden_output))
 
-            # Try negative values
-            golden_data = np.multiply(np.arange(0, -32, -1).astype("int32"), 3)
-            golden_output = np.arange(0, -32, -1)
-            verify(mod, (golden_data, golden_output))
+                # Try negative values
+                golden_data = np.multiply(np.arange(0, -32, 
-1).astype("int32"), 3)
+                golden_output = np.arange(0, -32, -1)
+                verify(mod, (golden_data, golden_output))
 
-            # Try a different scale
-            mod = get_mod(
-                data_shape=(32,),
-                data_dtype="int32",
-                out_dtype="int8",
-                input_scale=3,
-                output_scale=1,
-                rounding=rounding,
-            )
+                # Try a different scale
+                mod = get_mod(
+                    data_shape=(32,),
+                    data_dtype="int32",
+                    out_dtype=qnn_out_dtype,
+                    input_scale=3,
+                    output_scale=1,
+                    rounding=rounding,
+                )
 
-            # Try positive values
-            golden_data = np.arange(0, 32, 1).astype("int32")
-            golden_output = np.multiply(golden_data, 3)
-            verify(mod, (golden_data, golden_output))
+                # Try positive values
+                golden_data = np.arange(0, 32, 1).astype("int32")
+                golden_output = np.multiply(golden_data, 3)
+                verify(mod, (golden_data, golden_output))
 
-            # Try negative values
-            golden_data = np.arange(0, -32, -1).astype("int32")
-            golden_output = np.multiply(golden_data, 3)
-            verify(mod, (golden_data, golden_output))
+                # Try negative values
+                golden_data = np.arange(0, -32, -1).astype("int32")
+                golden_output = np.multiply(golden_data, 3)
+                verify(mod, (golden_data, golden_output))
 
 
-def test_saturation():
+def test_saturation_int8():
     for compute_dtype in compute_dtypes:
         for rounding in roundings:
             mod = get_mod(
@@ -322,6 +328,70 @@ def test_saturation():
             verify(mod, (golden_data, golden_output))
 
 
+def test_saturation_int16():
+    for compute_dtype in compute_dtypes:
+        for rounding in roundings:
+            mod = get_mod(
+                data_shape=(16,),
+                data_dtype="int32",
+                out_dtype="int16",
+                input_scale=0.5,
+                output_scale=0.5,
+                rounding=rounding,
+                compute_dtype=compute_dtype,
+            )
+            golden_data = np.arange(0, 16, 1).astype("int32")
+            golden_data = np.add(32760, golden_data)
+            output = np.array(
+                [
+                    32760,
+                    32761,
+                    32762,
+                    32763,
+                    32764,
+                    32765,
+                    32766,
+                    32767,
+                    32767,
+                    32767,
+                    32767,
+                    32767,
+                    32767,
+                    32767,
+                    32767,
+                    32767,
+                ]
+            )
+            golden_output = output
+            verify(mod, (golden_data, golden_output))
+
+            # Try negative numbers
+            golden_data = np.arange(0, -16, -1).astype("int32")
+            golden_data = np.add(-32760, golden_data)
+            output = np.array(
+                [
+                    -32760,
+                    -32761,
+                    -32762,
+                    -32763,
+                    -32764,
+                    -32765,
+                    -32766,
+                    -32767,
+                    -32768,
+                    -32768,
+                    -32768,
+                    -32768,
+                    -32768,
+                    -32768,
+                    -32768,
+                    -32768,
+                ]
+            )
+            golden_output = output
+            verify(mod, (golden_data, golden_output))
+
+
 def test_zero_point():
     # Output zero point
     for compute_dtype in compute_dtypes:
@@ -357,31 +427,32 @@ def test_zero_point():
     # Input zero point
     for compute_dtype in compute_dtypes:
         for rounding in roundings:
-            mod = get_mod(
-                data_shape=(32,),
-                data_dtype="int32",
-                out_dtype="int8",
-                input_scale=1,
-                output_scale=16,
-                input_zero_point=16,
-                rounding=rounding,
-                compute_dtype=compute_dtype,
-            )
+            for qnn_out_dtype in out_dtypes:
+                mod = get_mod(
+                    data_shape=(32,),
+                    data_dtype="int32",
+                    out_dtype=qnn_out_dtype,
+                    input_scale=1,
+                    output_scale=16,
+                    input_zero_point=16,
+                    rounding=rounding,
+                    compute_dtype=compute_dtype,
+                )
 
-            # Try positive values
-            golden_data = np.arange(32, 64, 1).astype("int32")
-            golden_output = np.repeat([2, 3, 4], [8, 16, 8])
-            golden_output = np.subtract(golden_output, 1)
-            verify(mod, (golden_data, golden_output))
+                # Try positive values
+                golden_data = np.arange(32, 64, 1).astype("int32")
+                golden_output = np.repeat([2, 3, 4], [8, 16, 8])
+                golden_output = np.subtract(golden_output, 1)
+                verify(mod, (golden_data, golden_output))
 
-            # Try negative values
-            golden_data = np.arange(-32, -64, -1).astype("int32")
-            if rounding == "UPWARD":
-                golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
-            else:
-                golden_output = np.repeat([-2, -3, -4], [8, 16, 8])
-            golden_output = np.subtract(golden_output, 1)
-            verify(mod, (golden_data, golden_output))
+                # Try negative values
+                golden_data = np.arange(-32, -64, -1).astype("int32")
+                if rounding == "UPWARD":
+                    golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
+                else:
+                    golden_output = np.repeat([-2, -3, -4], [8, 16, 8])
+                golden_output = np.subtract(golden_output, 1)
+                verify(mod, (golden_data, golden_output))
 
 
 def test_per_channel_same_scale():
@@ -390,17 +461,18 @@ def test_per_channel_same_scale():
     golden_output = golden_data
     for compute_dtype in compute_dtypes:
         for rounding in roundings:
-            mod = get_mod(
-                data_shape=(5, 2),
-                data_dtype="int32",
-                out_dtype="int8",
-                input_scale=[0.5, 0.5],
-                output_scale=0.5,
-                axis=1,
-                rounding=rounding,
-                compute_dtype=compute_dtype,
-            )
-            verify(mod, (golden_data, golden_output))
+            for qnn_out_dtype in out_dtypes:
+                mod = get_mod(
+                    data_shape=(5, 2),
+                    data_dtype="int32",
+                    out_dtype=qnn_out_dtype,
+                    input_scale=[0.5, 0.5],
+                    output_scale=0.5,
+                    axis=1,
+                    rounding=rounding,
+                    compute_dtype=compute_dtype,
+                )
+                verify(mod, (golden_data, golden_output))
 
     # Change axis
     golden_data = np.arange(-10, 10, 1).astype("int32").reshape((2, 2, 5))
@@ -480,88 +552,93 @@ def test_per_channel_different_scale():
 
 
 def test_default_cfg_and_no_args():
-    mod = get_mod(
-        data_shape=(32,),
-        data_dtype="int32",
-        out_dtype="int8",
-        input_scale=1,
-        output_scale=16,
-    )
-    golden_data = np.arange(0, -32, -1).astype("int32")
-    golden_output = np.repeat([0, -1, -2], [9, 16, 7])
-    verify(mod, (golden_data, golden_output))
+    for qnn_out_dtype in out_dtypes:
+        mod = get_mod(
+            data_shape=(32,),
+            data_dtype="int32",
+            out_dtype=qnn_out_dtype,
+            input_scale=1,
+            output_scale=16,
+        )
+        golden_data = np.arange(0, -32, -1).astype("int32")
+        golden_output = np.repeat([0, -1, -2], [9, 16, 7])
+        verify(mod, (golden_data, golden_output))
 
 
 def test_non_default_cfg_and_no_args():
     for rounding_cfg in roundings:
-        with relay.qnn.op.requantize_config(rounding=rounding_cfg):
-            mod = get_mod(
-                data_shape=(32,),
-                data_dtype="int32",
-                out_dtype="int8",
-                input_scale=1,
-                output_scale=16,
-            )
+        for qnn_out_dtype in out_dtypes:
+            with relay.qnn.op.requantize_config(rounding=rounding_cfg):
+                mod = get_mod(
+                    data_shape=(32,),
+                    data_dtype="int32",
+                    out_dtype=qnn_out_dtype,
+                    input_scale=1,
+                    output_scale=16,
+                )
 
-            golden_data = np.arange(0, -32, -1).astype("int32")
+                golden_data = np.arange(0, -32, -1).astype("int32")
 
-            if rounding_cfg == "UPWARD":
-                golden_output = np.repeat([0, -1, -2], [9, 16, 7])
-            else:
-                golden_output = np.repeat([0, -1, -2], [8, 16, 8])
-            verify(mod, (golden_data, golden_output))
+                if rounding_cfg == "UPWARD":
+                    golden_output = np.repeat([0, -1, -2], [9, 16, 7])
+                else:
+                    golden_output = np.repeat([0, -1, -2], [8, 16, 8])
+                verify(mod, (golden_data, golden_output))
 
 
 def test_default_cfg_and_args():
     for rounding in roundings:
-        with relay.qnn.op.requantize_config(rounding="UPWARD"):
-            mod = get_mod(
-                data_shape=(32,),
-                data_dtype="int32",
-                out_dtype="int8",
-                input_scale=1,
-                output_scale=16,
-                rounding=rounding,
-            )
-
-            golden_data = np.arange(0, -32, -1).astype("int32")
-
-            if rounding == "UPWARD":
-                golden_output = np.repeat([0, -1, -2], [9, 16, 7])
-            else:
-                golden_output = np.repeat([0, -1, -2], [8, 16, 8])
-            verify(mod, (golden_data, golden_output))
-
-
-def test_non_default_cfg_and_args():
-    for rounding_arg in roundings:
-        for rounding_cfg in roundings:
-            with relay.qnn.op.requantize_config(rounding=rounding_cfg):
+        for qnn_out_dtype in out_dtypes:
+            with relay.qnn.op.requantize_config(rounding="UPWARD"):
                 mod = get_mod(
                     data_shape=(32,),
                     data_dtype="int32",
-                    out_dtype="int8",
+                    out_dtype=qnn_out_dtype,
                     input_scale=1,
                     output_scale=16,
-                    rounding=rounding_arg,
+                    rounding=rounding,
                 )
 
                 golden_data = np.arange(0, -32, -1).astype("int32")
 
-                if rounding_arg == "UPWARD":
+                if rounding == "UPWARD":
                     golden_output = np.repeat([0, -1, -2], [9, 16, 7])
                 else:
                     golden_output = np.repeat([0, -1, -2], [8, 16, 8])
                 verify(mod, (golden_data, golden_output))
 
 
+def test_non_default_cfg_and_args():
+    for rounding_arg in roundings:
+        for rounding_cfg in roundings:
+            for qnn_out_dtype in out_dtypes:
+                with relay.qnn.op.requantize_config(rounding=rounding_cfg):
+                    mod = get_mod(
+                        data_shape=(32,),
+                        data_dtype="int32",
+                        out_dtype=qnn_out_dtype,
+                        input_scale=1,
+                        output_scale=16,
+                        rounding=rounding_arg,
+                    )
+
+                    golden_data = np.arange(0, -32, -1).astype("int32")
+
+                    if rounding_arg == "UPWARD":
+                        golden_output = np.repeat([0, -1, -2], [9, 16, 7])
+                    else:
+                        golden_output = np.repeat([0, -1, -2], [8, 16, 8])
+                    verify(mod, (golden_data, golden_output))
+
+
 if __name__ == "__main__":
     test_same_scale()
     test_scalar_same_scale()
     test_downscale()
     test_upscale()
     test_non_power_of_two()
-    test_saturation()
+    test_saturation_int8()
+    test_saturation_int16()
     test_zero_point()
     test_per_channel_same_scale()
     test_per_channel_different_scale()

Reply via email to