This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 18a2ee1 [Frontend][TFLite] Implement fake quant (#8780)
18a2ee1 is described below
commit 18a2ee16036a4d70d7cfc572fe9807e6e2a70eda
Author: Euntaik <[email protected]>
AuthorDate: Sat Aug 21 07:29:43 2021 +0900
[Frontend][TFLite] Implement fake quant (#8780)
* [Frontend][TFLite] Implement fake quant
* remove unused variable
* fix linting errors
* add more tests
* use pytest parametrize instead of a separate function
---
python/tvm/relay/frontend/tflite.py | 51 ++++++++++++++++++++++++++++
tests/python/frontend/tflite/test_forward.py | 17 +++++++++-
2 files changed, 67 insertions(+), 1 deletion(-)
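
For context: the converter added below lowers FAKE_QUANT the same way TensorFlow's fake_quant_with_min_max_args does, by nudging the requested [min, max] range so that zero lands exactly on the quantization grid, clamping inputs to the nudged range, and rounding onto the grid. A minimal NumPy sketch of that math (the helper name fake_quant_reference is illustrative and not part of the patch):

    import numpy as np

    def fake_quant_reference(x, opt_min, opt_max, num_bits=8, narrow_range=False):
        # Integer quantization range, e.g. [0, 255] for 8 bits ([1, 255] if narrow_range).
        quant_min = 1 if narrow_range else 0
        quant_max = (1 << num_bits) - 1
        scale = (opt_max - opt_min) / (quant_max - quant_min)

        # Nudge the zero point to an integer so that 0.0 is exactly representable.
        zero_point_from_min = quant_min - opt_min / scale
        nudged_zero_point = int(np.clip(round(zero_point_from_min), quant_min, quant_max))
        nudged_min = (quant_min - nudged_zero_point) * scale
        nudged_max = (quant_max - nudged_zero_point) * scale

        # Clamp to the nudged range, snap to the grid, and map back to float.
        clamped = np.clip(x, nudged_min, nudged_max)
        return np.floor((clamped - nudged_min) / scale + 0.5) * scale + nudged_min

    print(fake_quant_reference(np.array([3.55], dtype="float32"), -6.0, 6.0))
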
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index db6e053..4d607e4 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -91,6 +91,7 @@ class OperatorConverter(object):
"EQUAL": self.convert_equal,
"EXP": self.convert_exp,
"EXPAND_DIMS": self.convert_expand_dims,
+ "FAKE_QUANT": self.convert_fake_quant,
"FILL": self.convert_fill,
"FLOOR_DIV": self.convert_floor_div,
"FLOOR_MOD": self.convert_floor_mod,
@@ -3336,6 +3337,56 @@ class OperatorConverter(object):
self.set_prefetched_node(output_tensor.tensor_idx, dense_weight)
+    def convert_fake_quant(self, op):
+        """Convert TFLite FAKE_QUANT"""
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 1, "input tensors length should be 1"
+
+        input_tensor = input_tensors[0]
+        in_expr = self.get_expr(input_tensor.tensor_idx)
+
+        from tflite.BuiltinOptions import BuiltinOptions
+        from tflite.FakeQuantOptions import FakeQuantOptions
+
+        assert op.BuiltinOptionsType() == BuiltinOptions.FakeQuantOptions
+
+        op_options = op.BuiltinOptions()
+        fake_quant_options = FakeQuantOptions()
+        fake_quant_options.Init(op_options.Bytes, op_options.Pos)
+
+        opt_min = fake_quant_options.Min()
+        opt_max = fake_quant_options.Max()
+        narrow_range = fake_quant_options.NarrowRange()
+        num_bits = fake_quant_options.NumBits()
+
+        assert 2 <= num_bits <= 16
+
+        quant_min = 1 if narrow_range else 0
+        quant_max = (1 << num_bits) - 1
+        scale = (opt_max - opt_min) / (quant_max - quant_min)
+
+        zero_point_from_min = quant_min - opt_min / scale
+        if zero_point_from_min <= quant_min:
+            nudged_zero_point = quant_min
+        elif zero_point_from_min >= quant_max:
+            nudged_zero_point = quant_max
+        else:
+            nudged_zero_point = round(zero_point_from_min)
+
+        nudged_min = (quant_min - nudged_zero_point) * scale
+        nudged_max = (quant_max - nudged_zero_point) * scale
+
+        nudged_min_expr = _op.const(nudged_min)
+        clamped = _op.clip(in_expr, nudged_min, nudged_max)
+        clamped_shifted = _op.subtract(clamped, nudged_min_expr)
+
+        half = _op.const(0.5)
+        one = _op.const(1.0)
+        scale_expr = _op.const(scale)
+        inv_scale = _op.divide(one, scale_expr)
+        rounded = _op.floor(_op.add(_op.multiply(clamped_shifted, inv_scale), half))
+        return _op.add(_op.multiply(rounded, scale_expr), nudged_min_expr)
+
     def get_expr(self, input_tensor_idx):
         return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))
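
In Relay terms, convert_fake_quant emits clip, subtract, multiply, a floor-based round, and a rescale, all in float32. A small standalone sketch of that expression shape, using hand-computed illustrative constants for min=-6, max=6, num_bits=8 (the converter computes scale and the nudged range at import time; these values are not exported by the patch):

    import numpy as np
    import tvm
    from tvm import relay

    # Illustrative constants for min=-6, max=6, num_bits=8.
    scale = 12.0 / 255.0
    nudged_min, nudged_max = -128 * scale, 127 * scale

    x = relay.var("x", shape=(1,), dtype="float32")
    clamped = relay.clip(x, nudged_min, nudged_max)
    shifted = relay.subtract(clamped, relay.const(nudged_min))
    rounded = relay.floor(relay.add(relay.multiply(shifted, relay.const(1.0 / scale)), relay.const(0.5)))
    y = relay.add(relay.multiply(rounded, relay.const(scale)), relay.const(nudged_min))

    mod = tvm.IRModule.from_expr(relay.Function([x], y))
    out = relay.create_executor("graph", mod=mod).evaluate()(np.array([3.55], dtype="float32"))
    print(out)
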
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 7b7f1b1..f294103 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -322,7 +322,6 @@ def compare_tflite_with_tvm(
                 out_names=out_names,
                 mode=mode,
             )
-
             # WARNING: the results could well be random values clipped to 0 or 255 because of badly tuned output
             # range for the specific operator. While adding test ensure that we aren't getting only clipped values
             # in output tensors that still pass the assertion. For reference see _test_elemwise_qnn_out_range()
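
The warning above is straightforward to act on with a quick check; a hypothetical helper (not part of the test suite) that would catch a fully clipped output looks like this:

    import numpy as np

    def assert_not_saturated(tvm_output, qmin=0, qmax=255):
        # If every element equals a clip boundary, the comparison can pass
        # without exercising the operator's real output range.
        out = np.asarray(tvm_output)
        assert not np.all(np.isin(out, [qmin, qmax])), "output fully clipped; retune the qnn output range"
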
@@ -2618,6 +2617,22 @@ def test_forward_select():
)
[email protected]("quant_bits", [2, 4, 8, 16])
[email protected](
+ "value, min, max", [[-10.11, -6, 6], [-3.55, -6, 6], [0, -6, 6], [3.55,
-6, 6], [10.11, -6, 6]]
+)
+def test_forward_fake_quant(value, min, max, quant_bits):
+ with tf.Graph().as_default():
+ with tf.Session() as sess:
+ input = tf.placeholder(tf.float32, shape=[1], name="input")
+ out = tf.quantization.fake_quant_with_min_max_args(
+ input, min=min, max=max, num_bits=quant_bits, name=None
+ )
+
+ in_data = np.float32(value)
+ compare_tflite_with_tvm([in_data], ["input:0"], [input], [out])
+
+
# Squeeze
# -------
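
For reference, a rough sketch (not part of the patch) of the path compare_tflite_with_tvm exercises for this test: build a TF1-style graph containing fake_quant_with_min_max_args, convert the session to a TFLite flatbuffer, then import it through the TFLite frontend. TF 1.x-style APIs are assumed, as in test_forward.py:

    import tensorflow as tf
    import tvm
    from tvm import relay

    with tf.Graph().as_default():
        with tf.Session() as sess:
            inp = tf.placeholder(tf.float32, shape=[1], name="input")
            out = tf.quantization.fake_quant_with_min_max_args(inp, min=-6, max=6, num_bits=8)
            # Convert the live session to a TFLite flatbuffer.
            converter = tf.lite.TFLiteConverter.from_session(sess, [inp], [out])
            tflite_buf = converter.convert()

    import tflite  # the "tflite" pip package; older versions expose tflite.Model.Model instead
    tflite_model = tflite.Model.GetRootAsModel(tflite_buf, 0)

    mod, params = relay.frontend.from_tflite(
        tflite_model, shape_dict={"input": (1,)}, dtype_dict={"input": "float32"}
    )
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="llvm", params=params)
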