This is an automated email from the ASF dual-hosted git repository.
lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 28aead905b [microNPU][ETHOSU] Fix SoftMax legalization parameters
(#15069)
28aead905b is described below
commit 28aead905b2b8ea3984f75d16819a4079b11b23e
Author: Aleksei-grovety <[email protected]>
AuthorDate: Mon Jun 26 20:57:54 2023 +0400
[microNPU][ETHOSU] Fix SoftMax legalization parameters (#15069)
* [microNPU][ETHOSU] Fix Softmax activation parameters
Fix activation parameters for operations according to the values in Vela.
* fix legalization parameters
* Update test_legalize.py
* Update test_legalize.py
---
.../backend/contrib/ethosu/softmax_rewriter.py | 75 +++++++++-----
tests/python/contrib/test_ethosu/test_legalize.py | 114 +++++++++++++--------
2 files changed, 122 insertions(+), 67 deletions(-)
diff --git a/python/tvm/relay/backend/contrib/ethosu/softmax_rewriter.py b/python/tvm/relay/backend/contrib/ethosu/softmax_rewriter.py
index 6c0a1dffc3..23d4f4b45b 100644
--- a/python/tvm/relay/backend/contrib/ethosu/softmax_rewriter.py
+++ b/python/tvm/relay/backend/contrib/ethosu/softmax_rewriter.py
@@ -82,6 +82,8 @@ class SoftmaxRewriter(DFPatternCallback):
self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map:
tvm.ir.container.Map
) -> tvm.relay.Expr:
params = self.params_class(post.op.body)
+ quant_min = -128
+ quant_max = 127
ifm = post.args[0]
ifm_dtype = ifm.checked_type.dtype
@@ -121,12 +123,14 @@ class SoftmaxRewriter(DFPatternCallback):
ifm2_scale=0.0,
ifm2_zero_point=int(params.ifm.q_params.zero_point),
ofm_scale=1.0,
- ofm_zero_point=127,
+ ofm_zero_point=quant_max,
ifm_channels=depth,
ifm2_channels=1,
reversed_operands=False,
ofm_dtype="int32",
activation="LUT",
+ clip_min=-255,
+ clip_max=0,
)
# PASS 2 - SHR
@@ -147,8 +151,8 @@ class SoftmaxRewriter(DFPatternCallback):
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
- clip_min=-128,
- clip_max=127,
+ clip_min=quant_min,
+ clip_max=quant_max,
rounding_mode="NATURAL",
)
@@ -165,6 +169,9 @@ class SoftmaxRewriter(DFPatternCallback):
ofm_channels=1,
upscale="NONE",
ofm_dtype="int32",
+ activation="CLIP",
+ clip_min=quant_min,
+ clip_max=quant_max,
)
# PASS 4 - CLZ
@@ -177,6 +184,9 @@ class SoftmaxRewriter(DFPatternCallback):
ofm_scale=0.0,
ofm_zero_point=int(params.ifm.q_params.zero_point),
ofm_channels=1,
+ activation="CLIP",
+ clip_min=quant_min,
+ clip_max=quant_max,
)
# PASS 5 - Sub
@@ -196,6 +206,9 @@ class SoftmaxRewriter(DFPatternCallback):
ifm2_channels=1,
reversed_operands=False,
ofm_dtype="int32",
+ activation="CLIP",
+ clip_min=quant_min,
+ clip_max=quant_max,
)
# PASS 6 - Sub
@@ -215,6 +228,9 @@ class SoftmaxRewriter(DFPatternCallback):
ifm2_channels=1,
reversed_operands=False,
ofm_dtype="int32",
+ activation="CLIP",
+ clip_min=quant_min,
+ clip_max=quant_max,
)
# PASS 7 - SHL
@@ -229,13 +245,13 @@ class SoftmaxRewriter(DFPatternCallback):
ifm2_zero_point=0,
ofm_scale=0.0,
ofm_zero_point=int(params.ifm.q_params.zero_point),
- ifm_channels=depth,
+ ifm_channels=1,
ifm2_channels=1,
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
- clip_min=-128,
- clip_max=127,
+ clip_min=quant_min,
+ clip_max=quant_max,
)
# PASS 8 - Sub
@@ -255,6 +271,9 @@ class SoftmaxRewriter(DFPatternCallback):
ifm2_channels=1,
reversed_operands=False,
ofm_dtype="int32",
+ activation="CLIP",
+ clip_min=quant_min,
+ clip_max=quant_max,
)
# PASS 9 - SHL
@@ -274,8 +293,8 @@ class SoftmaxRewriter(DFPatternCallback):
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
- clip_min=-128,
- clip_max=127,
+ clip_min=quant_min,
+ clip_max=quant_max,
)
# PASS 10 - Add
@@ -296,8 +315,8 @@ class SoftmaxRewriter(DFPatternCallback):
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
- clip_min=-128,
- clip_max=127,
+ clip_min=quant_min,
+ clip_max=quant_max,
use_rescale=True,
rescale_scale=1,
rescale_shift=1,
@@ -316,13 +335,13 @@ class SoftmaxRewriter(DFPatternCallback):
ifm2_zero_point=0,
ofm_scale=2.0,
ofm_zero_point=0,
- ifm_channels=depth,
+ ifm_channels=1,
ifm2_channels=1,
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
- clip_min=-128 * 2,
- clip_max=127 * 2,
+ clip_min=quant_min,
+ clip_max=quant_max,
)
# PASS 12 - Add
@@ -343,8 +362,8 @@ class SoftmaxRewriter(DFPatternCallback):
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
- clip_min=-128,
- clip_max=127,
+ clip_min=quant_min,
+ clip_max=quant_max,
)
nr_x = rescale_w_offset
@@ -368,8 +387,8 @@ class SoftmaxRewriter(DFPatternCallback):
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
- clip_min=-128 * 2,
- clip_max=127 * 2,
+ clip_min=quant_min,
+ clip_max=quant_max,
)
# PASS 14, 19, 24 - Sub
@@ -388,6 +407,9 @@ class SoftmaxRewriter(DFPatternCallback):
ifm2_channels=1,
reversed_operands=False,
ofm_dtype="int32",
+ activation="CLIP",
+ clip_min=quant_min,
+ clip_max=quant_max,
)
# PASS 15, 20, 25 - Mul
@@ -407,8 +429,8 @@ class SoftmaxRewriter(DFPatternCallback):
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
- clip_min=-128 * 2,
- clip_max=127 * 2,
+ clip_min=quant_min,
+ clip_max=quant_max,
)
# PASS 16, 21, 26 - Mul
@@ -428,8 +450,8 @@ class SoftmaxRewriter(DFPatternCallback):
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
- clip_min=-128,
- clip_max=127,
+ clip_min=quant_min,
+ clip_max=quant_max,
)
# PASS 17, 22, 27 - Add
@@ -448,6 +470,9 @@ class SoftmaxRewriter(DFPatternCallback):
ifm2_channels=1,
reversed_operands=False,
ofm_dtype="int32",
+ activation="CLIP",
+ clip_min=quant_min,
+ clip_max=quant_max,
)
# PASS 28 - Mul
@@ -468,8 +493,8 @@ class SoftmaxRewriter(DFPatternCallback):
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
- clip_min=-128,
- clip_max=127,
+ clip_min=quant_min,
+ clip_max=quant_max,
)
# PASS 29 - Mul
@@ -489,8 +514,8 @@ class SoftmaxRewriter(DFPatternCallback):
reversed_operands=False,
ofm_dtype="int32",
activation="CLIP",
- clip_min=-128 * 2,
- clip_max=127 * 2,
+ clip_min=quant_min,
+ clip_max=quant_max,
)
# PASS 30 - SHR
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py
index 1b643f8157..c952a13c52 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -3526,9 +3526,10 @@ def test_tflite_hard_swish(ifm_shape):
assert tuple(func_body.args[1].checked_type.shape) == (256,)
-@pytest.mark.parametrize("ifm_shape", [(1, 12), (1, 12, 32)])
-def test_tflite_softmax(ifm_shape):
+def test_tflite_softmax():
+ np.random.seed(0)
dtype = "int8"
+ ifm_shape = (1, 12)
def create_tflite_graph():
@tf.function
@@ -3539,7 +3540,7 @@ def test_tflite_softmax(ifm_shape):
# Convert the model
def representative_dataset():
for _ in range(100):
- data = np.random.rand(*tuple(ifm_shape))
+ data = np.random.uniform(low=-1, high=2, size=tuple(ifm_shape))
yield [data.astype(np.float32)]
converter =
tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
@@ -3554,44 +3555,54 @@ def test_tflite_softmax(ifm_shape):
def verify(ext_func):
out_op = ext_func.body
ops = []
- # List of expected operations and their type if it exists
- expected_ops = [
- ("reshape", None),
- ("reshape", None),
- ("contrib.ethosu.pooling", "MAX"),
- ("contrib.ethosu.binary_elementwise", "SUB"),
- ("contrib.ethosu.binary_elementwise", "SHR"),
- ("contrib.ethosu.pooling", "SUM"),
- ("contrib.ethosu.unary_elementwise", "CLZ"),
- ("contrib.ethosu.binary_elementwise", "SUB"),
- ("contrib.ethosu.binary_elementwise", "SHL"),
- ("contrib.ethosu.binary_elementwise", "SUB"),
- ("contrib.ethosu.binary_elementwise", "SHL"),
- ("contrib.ethosu.binary_elementwise", "ADD"),
- ("contrib.ethosu.binary_elementwise", "MUL"),
- ("contrib.ethosu.binary_elementwise", "ADD"),
- ("contrib.ethosu.binary_elementwise", "MUL"),
- ("contrib.ethosu.binary_elementwise", "SUB"),
- ("contrib.ethosu.binary_elementwise", "MUL"),
- ("contrib.ethosu.binary_elementwise", "MUL"),
- ("contrib.ethosu.binary_elementwise", "ADD"),
- ("contrib.ethosu.binary_elementwise", "MUL"),
- ("contrib.ethosu.binary_elementwise", "SUB"),
- ("contrib.ethosu.binary_elementwise", "MUL"),
- ("contrib.ethosu.binary_elementwise", "MUL"),
- ("contrib.ethosu.binary_elementwise", "ADD"),
- ("contrib.ethosu.binary_elementwise", "MUL"),
- ("contrib.ethosu.binary_elementwise", "SUB"),
- ("contrib.ethosu.binary_elementwise", "MUL"),
- ("contrib.ethosu.binary_elementwise", "MUL"),
- ("contrib.ethosu.binary_elementwise", "ADD"),
- ("contrib.ethosu.binary_elementwise", "MUL"),
- ("contrib.ethosu.binary_elementwise", "MUL"),
- ("contrib.ethosu.binary_elementwise", "SUB"),
- ("contrib.ethosu.binary_elementwise", "SHR"),
- ("reshape", None),
+ # List of expected operations, their type and activation parameters if
it exists
+ expected_ops_params = [
+ ("reshape", None, [None, None, None, None, None, None]),
+ ("reshape", None, [None, None, None, None, None, None]),
+ ("contrib.ethosu.pooling", "MAX", [0.011756093241274357, -43,
None, None, 0.0, -43]),
+ (
+ "contrib.ethosu.binary_elementwise",
+ "SUB",
+ [0.011756093241274357, -43, 0.0, -43, 1.0, 127],
+ ),
+ ("contrib.ethosu.binary_elementwise", "SHR", [1.0, 0, 0.0, 0, 0.0,
-43]),
+ ("contrib.ethosu.pooling", "SUM", [0.0, 0, None, None, 0.0, -43]),
+ ("contrib.ethosu.unary_elementwise", "CLZ", [0.0, 0, None, None,
0.0, -43]),
+ ("contrib.ethosu.binary_elementwise", "SUB", [0.0, 0, 0.0, 0, 0.0,
-43]),
+ ("contrib.ethosu.binary_elementwise", "SHL", [0.0, 0, 0.0, 0, 0.0,
-43]),
+ ("contrib.ethosu.binary_elementwise", "SUB", [0.0, 0, 0.0, 0, 0.0,
-43]),
+ ("contrib.ethosu.binary_elementwise", "SHL", [0.0, 0, 0.0, 0, 0.0,
-43]),
+ ("contrib.ethosu.binary_elementwise", "ADD", [0.0, 0, 0.0, 0, 1.0,
0]),
+ ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0,
0]),
+ ("contrib.ethosu.binary_elementwise", "ADD", [2.0, 0, 0.0, 0, 1.0,
0]),
+ ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0,
0]),
+ ("contrib.ethosu.binary_elementwise", "SUB", [2.0, 0, 0.0, 0, 1.0,
0]),
+ ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0,
0]),
+ ("contrib.ethosu.binary_elementwise", "MUL", [2.0, 0, 0.0, 0, 0.0,
-43]),
+ ("contrib.ethosu.binary_elementwise", "ADD", [1.0, 0, 0.0, 0, 1.0,
0]),
+ ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0,
0]),
+ ("contrib.ethosu.binary_elementwise", "SUB", [2.0, 0, 0.0, 0, 1.0,
0]),
+ ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0,
0]),
+ ("contrib.ethosu.binary_elementwise", "MUL", [2.0, 0, 0.0, 0, 0.0,
-43]),
+ ("contrib.ethosu.binary_elementwise", "ADD", [1.0, 0, 0.0, 0, 1.0,
0]),
+ ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0,
0]),
+ ("contrib.ethosu.binary_elementwise", "SUB", [2.0, 0, 0.0, 0, 1.0,
0]),
+ ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0,
0]),
+ ("contrib.ethosu.binary_elementwise", "MUL", [2.0, 0, 0.0, 0, 0.0,
-43]),
+ ("contrib.ethosu.binary_elementwise", "ADD", [1.0, 0, 0.0, 0, 1.0,
0]),
+ ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 0.0, 0, 1.0,
0]),
+ ("contrib.ethosu.binary_elementwise", "MUL", [1.0, 0, 1.0, 0, 2.0,
0]),
+ ("contrib.ethosu.binary_elementwise", "SUB", [0.0, 0, 0.0, 0, 0.0,
-43]),
+ ("contrib.ethosu.binary_elementwise", "SHR", [2.0, 0, 0.0, 0,
0.00390625, -128]),
+ ("reshape", None, [None, None, None, None, None, None]),
]
+ def get_attr_value(op, attr_name):
+ if hasattr(op.attrs, attr_name):
+ return op.attrs[attr_name]
+ else:
+ return None
+
def get_op_type(op):
if hasattr(op.attrs, "pooling_type"):
return op.attrs.pooling_type
@@ -3599,6 +3610,16 @@ def test_tflite_softmax(ifm_shape):
return op.attrs.operator_type
return None
+ def get_activation_params(op):
+ activation_params = []
+ activation_params.append(get_attr_value(op, "ifm_scale"))
+ activation_params.append(get_attr_value(op, "ifm_zero_point"))
+ activation_params.append(get_attr_value(op, "ifm2_scale"))
+ activation_params.append(get_attr_value(op, "ifm2_zero_point"))
+ activation_params.append(get_attr_value(op, "ofm_scale"))
+ activation_params.append(get_attr_value(op, "ofm_zero_point"))
+ return activation_params
+
def _visit(stmt):
if isinstance(stmt, relay.expr.Call):
ops.append(stmt)
@@ -3616,9 +3637,18 @@ def test_tflite_softmax(ifm_shape):
assert ofm.dtype == dtype
# check operations
-
- ops = [(op.op.name, get_op_type(op)) for op in ops]
- assert expected_ops == ops
+ for op, expected_op_params in zip(ops, expected_ops_params):
+ activation_params = get_activation_params(op)
+ expected_op_name, expected_op_type, expected_activation_params =
expected_op_params
+ assert op.op.name == expected_op_name
+ assert expected_op_type == get_op_type(op)
+ for activation_param, expected_activation_param in zip(
+ activation_params, expected_activation_params
+ ):
+ if isinstance(activation_param, float):
+ assert math.isclose(expected_activation_param,
activation_param, abs_tol=1e-7)
+ else:
+ assert expected_activation_param == activation_param
softmax_pattern_table = [
(