This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c94623a TFLite failures resulted from TF latest version upgrade resolved (#6774)
c94623a is described below
commit c94623ad0dac5fa5ce7a9a3c4ecb794351ebc610
Author: ANSHUMAN TRIPATHY <[email protected]>
AuthorDate: Thu Oct 29 18:40:06 2020 +0530
TFLite failures resulted from TF latest version upgrade resolved (#6774)
* TFLite failures resulted from TF latest version upgrade resolved
* [1] Review comments handled
---
docker/install/ubuntu_install_tflite.sh | 6 +-
python/tvm/relay/frontend/tflite.py | 15 +++-
tests/python/frontend/tflite/test_forward.py | 115 ++++++++++++++-------------
3 files changed, 77 insertions(+), 59 deletions(-)
diff --git a/docker/install/ubuntu_install_tflite.sh b/docker/install/ubuntu_install_tflite.sh
index 123ff52..2dfbb06 100755
--- a/docker/install/ubuntu_install_tflite.sh
+++ b/docker/install/ubuntu_install_tflite.sh
@@ -33,14 +33,14 @@ pip3 install flatbuffers
# Build the TFLite static library, necessary for building with TFLite ON.
# The library is built at:
# tensorflow/tensorflow/lite/tools/make/gen/*/lib/libtensorflow-lite.a.
-git clone https://github.com/tensorflow/tensorflow --branch=r2.1
+git clone https://github.com/tensorflow/tensorflow --branch=r2.3
./tensorflow/tensorflow/lite/tools/make/download_dependencies.sh
./tensorflow/tensorflow/lite/tools/make/build_lib.sh
# Setup tflite from schema
mkdir tflite
cd tflite
-wget -q
https://raw.githubusercontent.com/tensorflow/tensorflow/r2.1/tensorflow/lite/schema/schema.fbs
+wget -q
https://raw.githubusercontent.com/tensorflow/tensorflow/r2.3/tensorflow/lite/schema/schema.fbs
flatc --python schema.fbs
cat <<EOM >setup.py
@@ -48,7 +48,7 @@ import setuptools
setuptools.setup(
name="tflite",
- version="2.1.0",
+ version="2.3.1",
author="google",
author_email="[email protected]",
description="TFLite",
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index f52c318..6da06ac 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -2770,7 +2770,7 @@ class OperatorConverter(object):
raise ImportError("The tflite package must be installed")
input_tensors = self.get_input_tensors(op)
- assert len(input_tensors) == 3, "input tensors length should be 3"
+ assert len(input_tensors) >= 3, "input tensors length should be >= 3"
# Input (data) Tensor. NHWC layout
input_tensor = input_tensors[2]
@@ -2843,6 +2843,19 @@ class OperatorConverter(object):
out_dtype=output_tensor_type_str,
)
+ # if we have bias
+ if len(input_tensors) == 4:
+ bias_tensor = input_tensors[3]
+ bias_tensor_type = bias_tensor.tensor.Type()
+ # bias tensor type should be INT32 (quantization) or FLOAT32
+ assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
+ bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
+ bias_expr = self.exp_tab.new_const(
+ self.get_tensor_value(bias_tensor), dtype=bias_tensor_type_str
+ )
+ channel_axis = 3
+ out = _op.nn.bias_add(out, bias_expr, axis=channel_axis)
+
return out
def convert_quantize(self, op):
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index caa4180..3f860a3 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -136,14 +136,20 @@ def vmobj_to_list(o):
raise RuntimeError("Unknown object type: %s" % type(o))
-def _quantize_keras_model(keras_model, representative_data_gen):
+def _quantize_keras_model(
+ keras_model, representative_data_gen, is_float_input=False,
is_float_output=False
+):
"""Utility function to quantize a Keras model using TFLite converter."""
converter =
interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model)
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
converter.representative_dataset = representative_data_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
- converter.inference_input_type = tf.uint8
- converter.inference_output_type = tf.uint8
+ # NOTE: If representative dataset is provided, and inference input type is
not set,
+ # then converter will self add quant & dequant Op accordingly.
+ if not is_float_input:
+ converter.inference_input_type = tf.uint8
+ if not is_float_output:
+ converter.inference_output_type = tf.uint8
return converter.convert()
@@ -973,6 +979,7 @@ def _test_convolution(
[out],
quantized=quantized,
input_range=input_range,
+ experimental_new_converter=True,
)
else:
# Quantized the inputs and feed them to the convolution
@@ -1000,6 +1007,7 @@ def _test_convolution(
[out],
quantized=quantized,
input_range=input_range,
+ experimental_new_converter=True,
)
else:
data_array = np.reshape(data_array,
tensor_in_sizes).astype("float32")
@@ -1078,18 +1086,18 @@ def test_forward_convolution():
)
# TFLite2 quantized convolution testing
- if package_version.parse(tf.VERSION) >= package_version.parse("2.1.0"):
- _test_tflite2_quantized_convolution(
- [1, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], "SAME", "NHWC"
+ if package_version.parse(tf.VERSION) >= package_version.parse("2.3.0"):
+ _test_convolution(
+ [1, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], "SAME", "NHWC",
quantized=True
)
- _test_tflite2_quantized_convolution(
- [1, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], "VALID", "NHWC"
+ _test_convolution(
+ [1, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], "VALID", "NHWC",
quantized=True
)
- _test_tflite2_quantized_convolution(
- [1, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], "VALID", "NHWC"
+ _test_convolution(
+ [1, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], "VALID", "NHWC",
quantized=True
)
- _test_tflite2_quantized_convolution(
- [1, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], "SAME", "NHWC"
+ _test_convolution(
+ [1, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], "SAME", "NHWC",
quantized=True
)
# Disable as tests are flaky -
https://github.com/apache/incubator-tvm/issues/6064
@@ -2280,7 +2288,7 @@ def _test_quantize_dequantize(data):
for i in range(1):
yield [data]
- tflite_model_quant = _quantize_keras_model(keras_model,
representative_data_gen)
+ tflite_model_quant = _quantize_keras_model(keras_model,
representative_data_gen, True, True)
tflite_output = run_tflite_graph(tflite_model_quant, data)
tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
@@ -2307,7 +2315,7 @@ def _test_quantize_dequantize_const(data):
for i in range(1):
yield [data]
- tflite_model_quant = _quantize_keras_model(keras_model,
representative_data_gen)
+ tflite_model_quant = _quantize_keras_model(keras_model,
representative_data_gen, True, True)
tflite_output = run_tflite_graph(tflite_model_quant, data)
tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
@@ -2548,14 +2556,17 @@ def test_forward_padv2():
np.array([2], dtype=np.float32),
]
)
- _test_padv2(
- [
- np.arange(0, 256, dtype=np.uint8).reshape((1, 256)),
- np.array([[1, 1], [2, 2]], dtype=np.int32),
- np.array([2], dtype=np.uint8),
- ],
- quantized=True,
- )
+ # NOTE: In versions > 2.1.0, there is a bug in Tensorflow package for this
scenario.
+ # Hence, it is disabled temporarily for TF version > 2.1.0 .
+ if package_version.parse(tf.VERSION) <= package_version.parse("2.1.0"):
+ _test_padv2(
+ [
+ np.arange(0, 256, dtype=np.uint8).reshape((1, 256)),
+ np.array([[1, 1], [2, 2]], dtype=np.int32),
+ np.array([2], dtype=np.float32),
+ ],
+ quantized=True,
+ )
# Constant Values input can be scalar
_test_padv2(
@@ -2565,14 +2576,17 @@ def test_forward_padv2():
np.float32(2),
]
)
- _test_padv2(
- [
- np.arange(0, 256, dtype=np.uint8).reshape((1, 256)),
- np.array([[1, 1], [2, 2]], dtype=np.int32),
- np.uint8(10),
- ],
- quantized=True,
- )
+ # NOTE: In versions > 2.1.0, there is a bug in Tensorflow package for this
scenario.
+ # Hence, it is disabled temporarily for TF versions > 2.1.0.
+ if package_version.parse(tf.VERSION) <= package_version.parse("2.1.0"):
+ _test_padv2(
+ [
+ np.arange(0, 256, dtype=np.uint8).reshape((1, 256)),
+ np.array([[1, 1], [2, 2]], dtype=np.int32),
+ np.uint8(10),
+ ],
+ quantized=True,
+ )
#######################################################################
@@ -2870,37 +2884,28 @@ def test_forward_tanh():
def _test_relu(data, quantized=False):
""" One iteration of ReLU """
- if quantized:
- if package_version.parse(tf.VERSION) < package_version.parse("2.1.0"):
- pytest.skip("Testcase requires tflite version >= 2.1.0")
- data_in = tf.keras.layers.Input(shape=data.shape[1:])
- relu = tf.keras.layers.ReLU()(data_in)
- keras_model = tf.keras.models.Model(inputs=data_in, outputs=relu)
- input_name = data_in.name.split(":")[0]
-
- # To create quantized values with dynamic range of activations, needs
representative dataset
- def representative_data_gen():
- for i in range(1):
- yield [data]
-
- tflite_model_quant = _quantize_keras_model(keras_model,
representative_data_gen)
-
- tflite_output = run_tflite_graph(tflite_model_quant, data)
- tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
- tvm.testing.assert_allclose(
- np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5
- )
- else:
- with tf.Graph().as_default():
- in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+ with tf.Graph().as_default():
+ in_data = array_ops.placeholder(shape=data.shape, dtype="float32",
name="in_0")
+
+ if quantized:
+ inq_data = tf.quantization.fake_quant_with_min_max_args(
+ in_data, min=-10, max=10, name="inq_0"
+ )
+ input_range = {"inq_0": (-10, 10)}
+ out = nn_ops.relu(inq_data)
+ out = tf.quantization.fake_quant_with_min_max_args(out, min=0,
max=6, name="out")
+ compare_tflite_with_tvm(
+ data, "inq_0:0", [inq_data], [out], quantized=True,
input_range=input_range
+ )
+ else:
out = nn_ops.relu(in_data)
- compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
+ compare_tflite_with_tvm(data, "in_0:0", [in_data], [out])
def test_forward_relu():
""" ReLU """
_test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
- _test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)),
quantized=True)
+ _test_relu(np.random.uniform(0, 255, (3, 6)).astype(np.uint8),
quantized=True)
#######################################################################