This is an automated email from the ASF dual-hosted git repository.
mousius pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 034dc67d03 [TFLite] Enable int64 biases for int16 quantized operators (#12042)
034dc67d03 is described below
commit 034dc67d032aac3b848e15a87a7fbb5b72a0b909
Author: Leandro Nunes <[email protected]>
AuthorDate: Tue Nov 15 10:30:50 2022 +0000
[TFLite] Enable int64 biases for int16 quantized operators (#12042)
This enables int64 biases for the quantized fully connected, requantize,
and transpose convolution operators in TFLite networks, building on the
existing int16 support in the TFLite frontend.
Adds a test case using an int16-quantized DS_CNN model.
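For context: TFLite's 16x8 quantization scheme pairs int16 activations with
int8 weights, so biases and accumulators need more headroom than int32
provides, which is why such models carry int64 biases. A minimal sketch of
producing a model in this scheme with the TensorFlow Lite converter
(build_model and rep_dataset are hypothetical placeholders, not part of
this change):

    import tensorflow as tf

    # Sketch, assuming a Keras model: convert to the TFLite 16x8 scheme
    # (int16 activations, int8 weights, int64 biases). build_model() and
    # rep_dataset are hypothetical placeholders.
    converter = tf.lite.TFLiteConverter.from_keras_model(build_model())
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
    ]
    converter.representative_dataset = rep_dataset  # calibration data generator
    tflite_int16_model = converter.convert()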
---
python/tvm/relay/frontend/tflite.py | 6 +-
src/relay/qnn/op/convolution_transpose.cc | 10 +-
src/relay/qnn/op/dense.cc | 10 +-
src/relay/qnn/op/requantize.cc | 5 +-
.../test_ethosn/test_convert_equivalents.py | 4 +-
tests/python/frontend/tflite/test_forward.py | 23 +
tests/python/relay/test_op_qnn_requantize.py | 495 ++++++++++++---------
7 files changed, 329 insertions(+), 224 deletions(-)
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 1915eb9322..3d2f4a2f25 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -1966,7 +1966,7 @@ class OperatorConverter(object):
input_scale=input_tensor.qnn_params["scale"],
kernel_scale=weight_tensor.qnn_params["scale"],
units=weight_shape[0],
- out_dtype="int32",
+                    out_dtype="int64" if output_tensor_type_str == "int16" else "int32",
)
else:
out = _op.nn.dense(in_expr, weight_expr, units=weight_shape[0])
@@ -1977,7 +1977,7 @@ class OperatorConverter(object):
if bias_tensor.tensor_idx != -1:
bias_tensor_type = bias_tensor.tensor.Type()
# bias tensor type should be INT32 (quantization) or FLOAT32
-            assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
+            assert bias_tensor_type in (TensorType.INT32, TensorType.INT64, TensorType.FLOAT32)
             bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
if self.has_expr(bias_tensor.tensor_idx):
bias_expr = self.get_expr(bias_tensor.tensor_idx)
@@ -3175,7 +3175,7 @@ class OperatorConverter(object):
bias_tensor = input_tensors[3]
bias_tensor_type = bias_tensor.tensor.Type()
# bias tensor type should be INT32 (quantization) or FLOAT32
- assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
+        assert bias_tensor_type in (TensorType.INT32, TensorType.INT64, TensorType.FLOAT32)
bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
if self.has_expr(bias_tensor.tensor_idx):
bias_expr = self.get_expr(bias_tensor.tensor_idx)
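With the frontend changes above, an int16-quantized TFLite graph now imports
with int64 bias adds instead of tripping the int32-only assert. A minimal
import sketch (the input name and shape are assumptions taken from the
DS_CNN test added further down, not a general requirement):

    import tflite
    from tvm import relay

    # Sketch: import an int16-quantized TFLite model through the frontend
    # change above. Input name/shape follow the DS_CNN test below.
    with open("ds_cnn_quantized_int16.tflite", "rb") as f:
        tflite_model = tflite.Model.GetRootAsModel(f.read(), 0)

    mod, params = relay.frontend.from_tflite(
        tflite_model,
        shape_dict={"serving_default_input:0": (1, 490)},
        dtype_dict={"serving_default_input:0": "int16"},
    )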
diff --git a/src/relay/qnn/op/convolution_transpose.cc b/src/relay/qnn/op/convolution_transpose.cc
index 6163e1c204..951c1bdfb0 100644
--- a/src/relay/qnn/op/convolution_transpose.cc
+++ b/src/relay/qnn/op/convolution_transpose.cc
@@ -93,12 +93,14 @@ bool QnnConv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs
if (data == nullptr || weight == nullptr) return false;
const auto* param = attrs.as<Conv2DTransposeAttrs>();
ICHECK(param != nullptr) << "Conv2DTransposeAttrs cannot be nullptr.";
- ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
-      << "Expected qnn conv2d type(int8, uint8) for input but was " << data->dtype;
+ ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8) ||
+ data->dtype == DataType::Int(16) || data->dtype == DataType::UInt(16))
+      << "Expected qnn conv2d type(int8, uint8, int16, uint16) for input but was " << data->dtype;
   ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8))
       << "Expected qnn conv2d type(int8, uint8) for weight but was " << weight->dtype;
-  ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32))
-      << "Expected qnn conv2d type(int32, int16) for output but was " << param->out_dtype;
+  ICHECK(param->out_dtype == DataType::Int(16) || param->out_dtype == DataType::Int(32) ||
+         param->out_dtype == DataType::Int(64))
+      << "Expected qnn conv2d type(int16, int32, int64) for output but was " << param->out_dtype;
   ICHECK(param->out_dtype.bits() > 0) << "Output dtype bits should be greater than 0.";
// Check the types of scale and zero points.
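The relaxed checks above admit, for example, a qnn.conv2d_transpose with
int16 input and an int64 accumulator. A minimal Relay sketch (shapes and
quantization parameters are illustrative, not taken from this commit):

    from tvm import relay

    # Sketch: int16 data, int8 weights, int64 output -- the combination
    # the widened ICHECKs above now accept. All values are illustrative.
    data = relay.var("data", shape=(1, 8, 16, 16), dtype="int16")
    weight = relay.var("weight", shape=(8, 4, 3, 3), dtype="int8")
    out = relay.qnn.op.conv2d_transpose(
        data,
        weight,
        input_zero_point=relay.const(0, "int32"),
        kernel_zero_point=relay.const(0, "int32"),
        input_scale=relay.const(0.0039, "float32"),
        kernel_scale=relay.const(0.0039, "float32"),
        kernel_size=(3, 3),
        channels=4,
        out_dtype="int64",
    )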
diff --git a/src/relay/qnn/op/dense.cc b/src/relay/qnn/op/dense.cc
index adaf509e7d..09d51e3c9c 100644
--- a/src/relay/qnn/op/dense.cc
+++ b/src/relay/qnn/op/dense.cc
@@ -47,12 +47,14 @@ bool QnnDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
if (data == nullptr || weight == nullptr) return false;
const auto* param = attrs.as<DenseAttrs>();
ICHECK(param != nullptr) << "DenseAttrs cannot be nullptr.";
- ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8))
-      << "Expected quantized dense type(int8, uint8) for input but was " << data->dtype;
+ ICHECK(data->dtype == DataType::Int(8) || data->dtype == DataType::UInt(8) ||
+ data->dtype == DataType::Int(16) || data->dtype == DataType::UInt(16))
+      << "Expected quantized dense type(int8, uint8, int16, uint16) for input but was "
+ << data->dtype;
   ICHECK(weight->dtype == DataType::Int(8) || weight->dtype == DataType::UInt(8))
       << "Expected quantized dense type(int8, uint8) for weight but was " << weight->dtype;
- ICHECK(param->out_dtype == DataType::Int(32))
-      << "Expected quantized dense type(int32) for output but was " << param->out_dtype;
+  ICHECK(param->out_dtype == DataType::Int(32) || param->out_dtype == DataType::Int(64))
+      << "Expected quantized dense type(int32, int64) for output but was " << param->out_dtype;
// Check the types of scale and zero points.
for (size_t i = 2; i < 5; ++i) {
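Correspondingly, qnn.dense now type-checks with an int64 out_dtype, matching
what the TFLite frontend emits for int16 fully connected layers. A minimal
sketch (shapes and scales are illustrative):

    from tvm import relay

    # Sketch: int16 data x int8 weights accumulated into int64, the form
    # the widened dense check above permits. All values are illustrative.
    data = relay.var("data", shape=(1, 16), dtype="int16")
    weight = relay.var("weight", shape=(8, 16), dtype="int8")
    dense = relay.qnn.op.dense(
        data,
        weight,
        input_zero_point=relay.const(0, "int32"),
        kernel_zero_point=relay.const(0, "int32"),
        input_scale=relay.const(0.0039, "float32"),
        kernel_scale=relay.const(0.0039, "float32"),
        units=8,
        out_dtype="int64",
    )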
diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc
index 1614652719..e199ea27f1 100644
--- a/src/relay/qnn/op/requantize.cc
+++ b/src/relay/qnn/op/requantize.cc
@@ -480,8 +480,9 @@ bool RequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
}
const auto in_dtype = data->dtype;
ICHECK(in_dtype == DataType::Int(8) || in_dtype == DataType::UInt(8) ||
- in_dtype == DataType::Int(32) || in_dtype == DataType::Int(64))
-      << "Input type should be one of [int8, uint8, int32, int64] but was " << in_dtype;
+ in_dtype == DataType::Int(16) || in_dtype == DataType::Int(32) ||
+ in_dtype == DataType::Int(64))
+      << "Input type should be one of [int8, uint8, int16, int32, int64] but was " << in_dtype;
const RequantizeAttrs* requantize_attrs = attrs.as<RequantizeAttrs>();
int axis = requantize_attrs->axis;
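This lets requantize bring an int64 accumulator (for example, the bias add
of an int16 dense layer) back down to a narrow output type. A minimal
sketch (scales and zero points are illustrative):

    from tvm import relay

    # Sketch: requantize an int64 accumulator down to an int16 output.
    # Scales and zero points are illustrative.
    acc = relay.var("acc", shape=(1, 8), dtype="int64")
    out = relay.qnn.op.requantize(
        acc,
        input_scale=relay.const(0.25, "float32"),
        input_zero_point=relay.const(0, "int32"),
        output_scale=relay.const(0.5, "float32"),
        output_zero_point=relay.const(0, "int32"),
        out_dtype="int16",
    )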
diff --git a/tests/python/contrib/test_ethosn/test_convert_equivalents.py b/tests/python/contrib/test_ethosn/test_convert_equivalents.py
index 7777729372..a3e48f4424 100644
--- a/tests/python/contrib/test_ethosn/test_convert_equivalents.py
+++ b/tests/python/contrib/test_ethosn/test_convert_equivalents.py
@@ -227,7 +227,7 @@ def test_multiply_to_reinterpret_quantize(shape, constant_shape, reverse_inputs)
@requires_ethosn
@pytest.mark.parametrize(
"dtype,shape,constant_shape",
- [("int16", (1, 16, 12, 4), None)],
+ [("float32", (1, 16, 12, 4), None)],
)
def test_unsupported_multiply_to_reinterpret_quantize(dtype, shape, constant_shape):
"""
@@ -445,7 +445,7 @@ def test_add_to_reinterpret_quantize(shape, constant_shape, reverse_inputs):
@pytest.mark.parametrize(
"dtype,shape,constant_shape",
[
- ("int16", (1, 16, 12, 4), None),
+ ("float32", (1, 16, 12, 4), None),
],
)
def test_unsupported_add_to_reinterpret_quantize(dtype, shape, constant_shape):
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 7b2bd60d8a..877406ae2a 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -4878,6 +4878,28 @@ def test_forward_mobilenet_int16():
tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
+def test_forward_ds_cnn_int16():
+ """Test DS_CNN int16 quantized model"""
+ tflite_model_file = download_testdata(
+        "https://github.com/ARM-software/ML-zoo/blob/48f458af1e9065d9aad2ad94d24b58d6e7c00817/"
+        "models/keyword_spotting/ds_cnn_small/tflite_int16/ds_cnn_quantized.tflite?raw=true",
+ "ds_cnn_quantized_int16.tflite",
+ )
+
+ with open(tflite_model_file, "rb") as f:
+ tflite_model_buf = f.read()
+
+ data = np.random.uniform(size=(1, 490)).astype("int16")
+
+ tflite_output = run_tflite_graph(tflite_model_buf, data)
+ tflite_predictions = np.squeeze(tflite_output)
+ tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
+    tvm_output = run_tvm_graph(tflite_model_buf, data, "serving_default_input:0")
+ tvm_predictions = np.squeeze(tvm_output)
+ tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
+ tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
+
+
#######################################################################
# Unidirectional Sequence LSTM
# ---------------------
@@ -5250,3 +5272,4 @@ if __name__ == "__main__":
test_forward_tflite_float16()
test_forward_tflite_int16()
+ test_forward_ds_cnn_int16()
diff --git a/tests/python/relay/test_op_qnn_requantize.py b/tests/python/relay/test_op_qnn_requantize.py
index 64306476df..1dee1f5b61 100644
--- a/tests/python/relay/test_op_qnn_requantize.py
+++ b/tests/python/relay/test_op_qnn_requantize.py
@@ -23,6 +23,7 @@ from tvm.contrib import graph_executor
roundings = ["UPWARD", "TONEAREST"]
compute_dtypes = ["float32", "float64", "int64"]
+out_dtypes = ["int8", "int16"]
def verify(mod, goldens, target="llvm"):
@@ -83,17 +84,18 @@ def test_same_scale():
golden_output = golden_data
for compute_dtype in compute_dtypes:
for rounding in roundings:
- mod = get_mod(
- data_shape=(200,),
- data_dtype="int32",
- out_dtype="int8",
- input_scale=0.5,
- output_scale=0.5,
- rounding=rounding,
- compute_dtype=compute_dtype,
- )
- assert "right_shift" not in mod.astext()
- verify(mod, (golden_data, golden_output))
+ for qnn_out_dtype in out_dtypes:
+ mod = get_mod(
+ data_shape=(200,),
+ data_dtype="int32",
+ out_dtype=qnn_out_dtype,
+ input_scale=0.5,
+ output_scale=0.5,
+ rounding=rounding,
+ compute_dtype=compute_dtype,
+ )
+ assert "right_shift" not in mod.astext()
+ verify(mod, (golden_data, golden_output))
def test_scalar_same_scale():
@@ -102,75 +104,77 @@ def test_scalar_same_scale():
golden_output = golden_data
for compute_dtype in compute_dtypes:
for rounding in roundings:
- mod = get_mod(
- data_shape=(),
- data_dtype="int32",
- out_dtype="int8",
- input_scale=0.5,
- output_scale=0.5,
- rounding=rounding,
- compute_dtype=compute_dtype,
- )
- assert "right_shift" not in mod.astext()
- verify(mod, (golden_data, golden_output))
+ for qnn_out_dtype in out_dtypes:
+ mod = get_mod(
+ data_shape=(),
+ data_dtype="int32",
+ out_dtype=qnn_out_dtype,
+ input_scale=0.5,
+ output_scale=0.5,
+ rounding=rounding,
+ compute_dtype=compute_dtype,
+ )
+ assert "right_shift" not in mod.astext()
+ verify(mod, (golden_data, golden_output))
def test_downscale():
for compute_dtype in compute_dtypes:
for rounding in roundings:
- mod = get_mod(
- data_shape=(32,),
- data_dtype="int32",
- out_dtype="int8",
- input_scale=1,
- output_scale=16,
- rounding=rounding,
- compute_dtype=compute_dtype,
- )
+ for qnn_out_dtype in out_dtypes:
+ mod = get_mod(
+ data_shape=(32,),
+ data_dtype="int32",
+ out_dtype=qnn_out_dtype,
+ input_scale=1,
+ output_scale=16,
+ rounding=rounding,
+ compute_dtype=compute_dtype,
+ )
- # Try positive values
- # 8 corresponds to 0.5, resulting in 1
- golden_data = np.arange(0, 32, 1).astype("int32")
- golden_output = np.repeat([0, 1, 2], [8, 16, 8])
- verify(mod, (golden_data, golden_output))
+ # Try positive values
+ # 8 corresponds to 0.5, resulting in 1
+ golden_data = np.arange(0, 32, 1).astype("int32")
+ golden_output = np.repeat([0, 1, 2], [8, 16, 8])
+ verify(mod, (golden_data, golden_output))
- # Try negative values
- # -8 corresponds to -0.5. For UPWARD, this is 0
- golden_data = np.arange(0, -32, -1).astype("int32")
- if rounding == "UPWARD":
- golden_output = np.repeat([0, -1, -2], [9, 16, 7])
- else:
- golden_output = np.repeat([0, -1, -2], [8, 16, 8])
- verify(mod, (golden_data, golden_output))
+ # Try negative values
+ # -8 corresponds to -0.5. For UPWARD, this is 0
+ golden_data = np.arange(0, -32, -1).astype("int32")
+ if rounding == "UPWARD":
+ golden_output = np.repeat([0, -1, -2], [9, 16, 7])
+ else:
+ golden_output = np.repeat([0, -1, -2], [8, 16, 8])
+ verify(mod, (golden_data, golden_output))
- # Try a different scale
- mod = get_mod(
- data_shape=(32,),
- data_dtype="int32",
- out_dtype="int8",
- input_scale=1,
- output_scale=4,
- rounding=rounding,
- )
+ # Try a different scale
+ mod = get_mod(
+ data_shape=(32,),
+ data_dtype="int32",
+ out_dtype=qnn_out_dtype,
+ input_scale=1,
+ output_scale=4,
+ rounding=rounding,
+ )
- # Try positive values
- # 2I corresponds to 0.5, resulting in 1
- golden_data = np.arange(0, 32, 1).astype("int32")
-            golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8], [2, 4, 4, 4, 4, 4, 4, 4, 2])
- verify(mod, (golden_data, golden_output))
+ # Try positive values
+ # 2I corresponds to 0.5, resulting in 1
+ golden_data = np.arange(0, 32, 1).astype("int32")
+                golden_output = np.repeat([0, 1, 2, 3, 4, 5, 6, 7, 8], [2, 4, 4, 4, 4, 4, 4, 4, 2])
+ verify(mod, (golden_data, golden_output))
- # Try negative values
- # -8 corresponds to -0.5. For UPWARD, this is 0
- golden_data = np.arange(0, -32, -1).astype("int32")
- if rounding == "UPWARD":
- golden_output = np.repeat(
-                    [0, -1, -2, -3, -4, -5, -6, -7, -8], [3, 4, 4, 4, 4, 4, 4, 4, 1]
- )
- else:
- golden_output = np.repeat(
-                    [0, -1, -2, -3, -4, -5, -6, -7, -8], [2, 4, 4, 4, 4, 4, 4, 4, 2]
- )
- verify(mod, (golden_data, golden_output))
+ # Try negative values
+ # -8 corresponds to -0.5. For UPWARD, this is 0
+ golden_data = np.arange(0, -32, -1).astype("int32")
+ if rounding == "UPWARD":
+ golden_output = np.repeat(
+                        [0, -1, -2, -3, -4, -5, -6, -7, -8], [3, 4, 4, 4, 4, 4, 4, 4, 1]
+ )
+ else:
+ golden_output = np.repeat(
+                        [0, -1, -2, -3, -4, -5, -6, -7, -8], [2, 4, 4, 4, 4, 4, 4, 4, 2]
+ )
+ verify(mod, (golden_data, golden_output))
# Try uint8 out_dtype
mod = get_mod(
@@ -208,74 +212,76 @@ def test_downscale():
def test_upscale():
for compute_dtype in compute_dtypes:
for rounding in roundings:
- mod = get_mod(
- data_shape=(32,),
- data_dtype="int32",
- out_dtype="int8",
- input_scale=2,
- output_scale=1,
- rounding=rounding,
- compute_dtype=compute_dtype,
- )
+ for qnn_out_dtype in out_dtypes:
+ mod = get_mod(
+ data_shape=(32,),
+ data_dtype="int32",
+ out_dtype=qnn_out_dtype,
+ input_scale=2,
+ output_scale=1,
+ rounding=rounding,
+ compute_dtype=compute_dtype,
+ )
- # Try positive values
- # 8 corresponds to 0.5, resulting in 1
- golden_data = np.arange(0, 32, 1).astype("int32")
- golden_output = np.multiply(2, golden_data)
- verify(mod, (golden_data, golden_output))
+ # Try positive values
+ # 8 corresponds to 0.5, resulting in 1
+ golden_data = np.arange(0, 32, 1).astype("int32")
+ golden_output = np.multiply(2, golden_data)
+ verify(mod, (golden_data, golden_output))
- # Try negative values
- # -8 corresponds to -0.5. For UPWARD, this is 0
- golden_data = np.arange(0, -32, -1).astype("int32")
- golden_output = np.multiply(2, golden_data)
- verify(mod, (golden_data, golden_output))
+ # Try negative values
+ # -8 corresponds to -0.5. For UPWARD, this is 0
+ golden_data = np.arange(0, -32, -1).astype("int32")
+ golden_output = np.multiply(2, golden_data)
+ verify(mod, (golden_data, golden_output))
def test_non_power_of_two():
for compute_dtype in compute_dtypes:
for rounding in roundings:
- mod = get_mod(
- data_shape=(32,),
- data_dtype="int32",
- out_dtype="int8",
- input_scale=1,
- output_scale=3,
- rounding=rounding,
- compute_dtype=compute_dtype,
- )
+ for qnn_out_dtype in out_dtypes:
+ mod = get_mod(
+ data_shape=(32,),
+ data_dtype="int32",
+ out_dtype=qnn_out_dtype,
+ input_scale=1,
+ output_scale=3,
+ rounding=rounding,
+ compute_dtype=compute_dtype,
+ )
- # Try positive values
- golden_data = np.multiply(np.arange(0, 32, 1).astype("int32"), 3)
- golden_output = np.arange(0, 32, 1)
- verify(mod, (golden_data, golden_output))
+ # Try positive values
+                golden_data = np.multiply(np.arange(0, 32, 1).astype("int32"), 3)
+ golden_output = np.arange(0, 32, 1)
+ verify(mod, (golden_data, golden_output))
- # Try negative values
- golden_data = np.multiply(np.arange(0, -32, -1).astype("int32"), 3)
- golden_output = np.arange(0, -32, -1)
- verify(mod, (golden_data, golden_output))
+ # Try negative values
+                golden_data = np.multiply(np.arange(0, -32, -1).astype("int32"), 3)
+ golden_output = np.arange(0, -32, -1)
+ verify(mod, (golden_data, golden_output))
- # Try a different scale
- mod = get_mod(
- data_shape=(32,),
- data_dtype="int32",
- out_dtype="int8",
- input_scale=3,
- output_scale=1,
- rounding=rounding,
- )
+ # Try a different scale
+ mod = get_mod(
+ data_shape=(32,),
+ data_dtype="int32",
+ out_dtype=qnn_out_dtype,
+ input_scale=3,
+ output_scale=1,
+ rounding=rounding,
+ )
- # Try positive values
- golden_data = np.arange(0, 32, 1).astype("int32")
- golden_output = np.multiply(golden_data, 3)
- verify(mod, (golden_data, golden_output))
+ # Try positive values
+ golden_data = np.arange(0, 32, 1).astype("int32")
+ golden_output = np.multiply(golden_data, 3)
+ verify(mod, (golden_data, golden_output))
- # Try negative values
- golden_data = np.arange(0, -32, -1).astype("int32")
- golden_output = np.multiply(golden_data, 3)
- verify(mod, (golden_data, golden_output))
+ # Try negative values
+ golden_data = np.arange(0, -32, -1).astype("int32")
+ golden_output = np.multiply(golden_data, 3)
+ verify(mod, (golden_data, golden_output))
-def test_saturation():
+def test_saturation_int8():
for compute_dtype in compute_dtypes:
for rounding in roundings:
mod = get_mod(
@@ -322,6 +328,70 @@ def test_saturation():
verify(mod, (golden_data, golden_output))
+def test_saturation_int16():
+ for compute_dtype in compute_dtypes:
+ for rounding in roundings:
+ mod = get_mod(
+ data_shape=(16,),
+ data_dtype="int32",
+ out_dtype="int16",
+ input_scale=0.5,
+ output_scale=0.5,
+ rounding=rounding,
+ compute_dtype=compute_dtype,
+ )
+ golden_data = np.arange(0, 16, 1).astype("int32")
+ golden_data = np.add(32760, golden_data)
+ output = np.array(
+ [
+ 32760,
+ 32761,
+ 32762,
+ 32763,
+ 32764,
+ 32765,
+ 32766,
+ 32767,
+ 32767,
+ 32767,
+ 32767,
+ 32767,
+ 32767,
+ 32767,
+ 32767,
+ 32767,
+ ]
+ )
+ golden_output = output
+ verify(mod, (golden_data, golden_output))
+
+ # Try negative numbers
+ golden_data = np.arange(0, -16, -1).astype("int32")
+ golden_data = np.add(-32760, golden_data)
+ output = np.array(
+ [
+ -32760,
+ -32761,
+ -32762,
+ -32763,
+ -32764,
+ -32765,
+ -32766,
+ -32767,
+ -32768,
+ -32768,
+ -32768,
+ -32768,
+ -32768,
+ -32768,
+ -32768,
+ -32768,
+ ]
+ )
+ golden_output = output
+ verify(mod, (golden_data, golden_output))
+
+
def test_zero_point():
# Output zero point
for compute_dtype in compute_dtypes:
@@ -357,31 +427,32 @@ def test_zero_point():
# Input zero point
for compute_dtype in compute_dtypes:
for rounding in roundings:
- mod = get_mod(
- data_shape=(32,),
- data_dtype="int32",
- out_dtype="int8",
- input_scale=1,
- output_scale=16,
- input_zero_point=16,
- rounding=rounding,
- compute_dtype=compute_dtype,
- )
+ for qnn_out_dtype in out_dtypes:
+ mod = get_mod(
+ data_shape=(32,),
+ data_dtype="int32",
+ out_dtype=qnn_out_dtype,
+ input_scale=1,
+ output_scale=16,
+ input_zero_point=16,
+ rounding=rounding,
+ compute_dtype=compute_dtype,
+ )
- # Try positive values
- golden_data = np.arange(32, 64, 1).astype("int32")
- golden_output = np.repeat([2, 3, 4], [8, 16, 8])
- golden_output = np.subtract(golden_output, 1)
- verify(mod, (golden_data, golden_output))
+ # Try positive values
+ golden_data = np.arange(32, 64, 1).astype("int32")
+ golden_output = np.repeat([2, 3, 4], [8, 16, 8])
+ golden_output = np.subtract(golden_output, 1)
+ verify(mod, (golden_data, golden_output))
- # Try negative values
- golden_data = np.arange(-32, -64, -1).astype("int32")
- if rounding == "UPWARD":
- golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
- else:
- golden_output = np.repeat([-2, -3, -4], [8, 16, 8])
- golden_output = np.subtract(golden_output, 1)
- verify(mod, (golden_data, golden_output))
+ # Try negative values
+ golden_data = np.arange(-32, -64, -1).astype("int32")
+ if rounding == "UPWARD":
+ golden_output = np.repeat([-2, -3, -4], [9, 16, 7])
+ else:
+ golden_output = np.repeat([-2, -3, -4], [8, 16, 8])
+ golden_output = np.subtract(golden_output, 1)
+ verify(mod, (golden_data, golden_output))
def test_per_channel_same_scale():
@@ -390,17 +461,18 @@ def test_per_channel_same_scale():
golden_output = golden_data
for compute_dtype in compute_dtypes:
for rounding in roundings:
- mod = get_mod(
- data_shape=(5, 2),
- data_dtype="int32",
- out_dtype="int8",
- input_scale=[0.5, 0.5],
- output_scale=0.5,
- axis=1,
- rounding=rounding,
- compute_dtype=compute_dtype,
- )
- verify(mod, (golden_data, golden_output))
+ for qnn_out_dtype in out_dtypes:
+ mod = get_mod(
+ data_shape=(5, 2),
+ data_dtype="int32",
+ out_dtype=qnn_out_dtype,
+ input_scale=[0.5, 0.5],
+ output_scale=0.5,
+ axis=1,
+ rounding=rounding,
+ compute_dtype=compute_dtype,
+ )
+ verify(mod, (golden_data, golden_output))
# Change axis
golden_data = np.arange(-10, 10, 1).astype("int32").reshape((2, 2, 5))
@@ -480,88 +552,93 @@ def test_per_channel_different_scale():
def test_default_cfg_and_no_args():
- mod = get_mod(
- data_shape=(32,),
- data_dtype="int32",
- out_dtype="int8",
- input_scale=1,
- output_scale=16,
- )
- golden_data = np.arange(0, -32, -1).astype("int32")
- golden_output = np.repeat([0, -1, -2], [9, 16, 7])
- verify(mod, (golden_data, golden_output))
+ for qnn_out_dtype in out_dtypes:
+ mod = get_mod(
+ data_shape=(32,),
+ data_dtype="int32",
+ out_dtype=qnn_out_dtype,
+ input_scale=1,
+ output_scale=16,
+ )
+ golden_data = np.arange(0, -32, -1).astype("int32")
+ golden_output = np.repeat([0, -1, -2], [9, 16, 7])
+ verify(mod, (golden_data, golden_output))
def test_non_default_cfg_and_no_args():
for rounding_cfg in roundings:
- with relay.qnn.op.requantize_config(rounding=rounding_cfg):
- mod = get_mod(
- data_shape=(32,),
- data_dtype="int32",
- out_dtype="int8",
- input_scale=1,
- output_scale=16,
- )
+ for qnn_out_dtype in out_dtypes:
+ with relay.qnn.op.requantize_config(rounding=rounding_cfg):
+ mod = get_mod(
+ data_shape=(32,),
+ data_dtype="int32",
+ out_dtype=qnn_out_dtype,
+ input_scale=1,
+ output_scale=16,
+ )
- golden_data = np.arange(0, -32, -1).astype("int32")
+ golden_data = np.arange(0, -32, -1).astype("int32")
- if rounding_cfg == "UPWARD":
- golden_output = np.repeat([0, -1, -2], [9, 16, 7])
- else:
- golden_output = np.repeat([0, -1, -2], [8, 16, 8])
- verify(mod, (golden_data, golden_output))
+ if rounding_cfg == "UPWARD":
+ golden_output = np.repeat([0, -1, -2], [9, 16, 7])
+ else:
+ golden_output = np.repeat([0, -1, -2], [8, 16, 8])
+ verify(mod, (golden_data, golden_output))
def test_default_cfg_and_args():
for rounding in roundings:
- with relay.qnn.op.requantize_config(rounding="UPWARD"):
- mod = get_mod(
- data_shape=(32,),
- data_dtype="int32",
- out_dtype="int8",
- input_scale=1,
- output_scale=16,
- rounding=rounding,
- )
-
- golden_data = np.arange(0, -32, -1).astype("int32")
-
- if rounding == "UPWARD":
- golden_output = np.repeat([0, -1, -2], [9, 16, 7])
- else:
- golden_output = np.repeat([0, -1, -2], [8, 16, 8])
- verify(mod, (golden_data, golden_output))
-
-
-def test_non_default_cfg_and_args():
- for rounding_arg in roundings:
- for rounding_cfg in roundings:
- with relay.qnn.op.requantize_config(rounding=rounding_cfg):
+ for qnn_out_dtype in out_dtypes:
+ with relay.qnn.op.requantize_config(rounding="UPWARD"):
mod = get_mod(
data_shape=(32,),
data_dtype="int32",
- out_dtype="int8",
+ out_dtype=qnn_out_dtype,
input_scale=1,
output_scale=16,
- rounding=rounding_arg,
+ rounding=rounding,
)
golden_data = np.arange(0, -32, -1).astype("int32")
- if rounding_arg == "UPWARD":
+ if rounding == "UPWARD":
golden_output = np.repeat([0, -1, -2], [9, 16, 7])
else:
golden_output = np.repeat([0, -1, -2], [8, 16, 8])
verify(mod, (golden_data, golden_output))
+def test_non_default_cfg_and_args():
+ for rounding_arg in roundings:
+ for rounding_cfg in roundings:
+ for qnn_out_dtype in out_dtypes:
+ with relay.qnn.op.requantize_config(rounding=rounding_cfg):
+ mod = get_mod(
+ data_shape=(32,),
+ data_dtype="int32",
+ out_dtype=qnn_out_dtype,
+ input_scale=1,
+ output_scale=16,
+ rounding=rounding_arg,
+ )
+
+ golden_data = np.arange(0, -32, -1).astype("int32")
+
+ if rounding_arg == "UPWARD":
+ golden_output = np.repeat([0, -1, -2], [9, 16, 7])
+ else:
+ golden_output = np.repeat([0, -1, -2], [8, 16, 8])
+ verify(mod, (golden_data, golden_output))
+
+
if __name__ == "__main__":
test_same_scale()
test_scalar_same_scale()
test_downscale()
test_upscale()
test_non_power_of_two()
- test_saturation()
+ test_saturation_int8()
+ test_saturation_int16()
test_zero_point()
test_per_channel_same_scale()
test_per_channel_different_scale()