This is an automated email from the ASF dual-hosted git repository.
manupa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 94d01d3565 [microNPU] Add support for hard swish (#12120)
94d01d3565 is described below
commit 94d01d35650ff4cb73dfd5360f6f613f5dadc41c
Author: Luke Hutton <[email protected]>
AuthorDate: Tue Jul 19 08:28:45 2022 +0100
[microNPU] Add support for hard swish (#12120)
Adds support for hard swish by populating a LUT similar to Vela's
implementation.
Change-Id: I7ca15a3e21bc91c1b41cdd4547fabaa00de96e90
---
.../tvm/relay/backend/contrib/ethosu/legalize.py | 78 ++++++++++++++++++++++
python/tvm/relay/op/contrib/ethosu.py | 53 +++++++++++++++
tests/python/contrib/test_ethosu/test_codegen.py | 15 +++++
tests/python/contrib/test_ethosu/test_legalize.py | 55 +++++++++++++++
4 files changed, 201 insertions(+)
diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py
b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index c940abdeab..77ef51ef9c 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -298,6 +298,83 @@ class LeakyReLURewriter(DFPatternCallback):
return identity
+class HardSwishRewriter(DFPatternCallback):
+    """Convert ethosu.hard_swish composite function to an identity operation with a LUT.
+
+    The quantized hard swish (including the requantization from the input to the
+    output quantization parameters) is precomputed into a lookup table with one
+    entry per representable input value, which is then applied by ethosu_identity.
+    """
+
+    def __init__(self):
+        super().__init__(require_type=True, rewrite_once=True)
+        self.params_class = ethosu_patterns.HardSwishParams
+        self.pattern = wildcard().has_attr({"Composite": self.params_class.composite_name})(
+            wildcard()
+        )
+
+    def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map):
+        params = self.params_class(post.op.body)
+        params.ifm.tensor = post.args[0]
+
+        # The calculation of the LUT values is similar to that in Vela
+        # convert_hardswish_to_lut(op, arch, nng)
+        # (https://review.mlplatform.org/plugins/gitiles/ml/ethos-u/ethos-u-vela/+/refs/tags/3.2.0/ethosu/vela/tflite_graph_optimiser.py#719)  # pylint: disable=line-too-long
+        input_scale = np.double(params.ifm.q_params.scale_f32)
+        input_zp = int(params.ifm.q_params.zero_point)
+        # The arithmetic below works on the input multiplied by 128, so its scale
+        # is the input scale divided by 128.
+        hires_input_scale = (1 / 128) * input_scale
+
+        output_scale = np.double(params.ofm.q_params.scale_f32)
+        output_zp = int(params.ofm.q_params.zero_point)
+        output_multiplier, output_shift = scaling.quantise_scale(hires_input_scale / output_scale)
+        output_multiplier_16 = fp_math.downscale_multiplier_int32_to_int16(output_multiplier)
+        # Only a right shift is applied at the end; clamp left shifts to zero.
+        output_shift = 31 - output_shift
+        output_shift = -output_shift if output_shift < 0 else 0
+
+        # Scale/shift that map the hires input onto the 3 / 32768 "relu-ish
+        # multiplier" domain. They depend only on the input scale, so they are
+        # computed once instead of once per LUT entry.
+        relu_multiplier = np.double(3 / 32768)
+        relu_scale, relu_shift = scaling.quantise_scale(hires_input_scale / relu_multiplier)
+        relu_scale_16 = fp_math.downscale_multiplier_int32_to_int16(relu_scale)
+
+        dtype = params.ifm.dtype
+        qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max
+
+        def calculate_relu_multiplier(hires_value):
+            """Return the 16-bit fixed-point relu6(x + 3) / 6 factor for a hires input."""
+            rvalue = np.int16(hires_value)
+            if relu_shift < 31:
+                rvalue = fp_math.shift_left16(rvalue, 30 - relu_shift)
+                rvalue = fp_math.saturating_rounding_mul16(rvalue, relu_scale_16)
+                rvalue = fp_math.shift_left16(rvalue, 1)
+            elif relu_shift > 31:
+                rvalue = fp_math.saturating_rounding_mul16(rvalue, relu_scale_16)
+                rvalue = fp_math.rounding_divide_by_pot(rvalue, relu_shift - 31)
+            else:
+                rvalue = fp_math.saturating_rounding_mul16(rvalue, relu_scale_16)
+            # Map the signed 16-bit value onto a non-negative fixed-point factor.
+            return (rvalue + (1 << 15)) >> 1
+
+        def calculate_lut_value(quantized_input):
+            """Return the quantized hard swish output for one quantized input value."""
+            hires_input_value = (quantized_input - input_zp) * 128
+            # Input on (roughly) the output scale, before the final shift.
+            preshift_input_value = fp_math.saturating_rounding_mul16(
+                hires_input_value, output_multiplier_16
+            )
+            relu_value = calculate_relu_multiplier(hires_input_value)
+            lut_result = fp_math.saturating_mul16(relu_value, preshift_input_value)
+            lut_result = fp_math.rounding_divide_by_pot(lut_result, output_shift) + output_zp
+            return min(qmax, max(qmin, lut_result))
+
+        # One entry per representable input value in ascending order
+        # (-128..127 for int8).
+        values = [calculate_lut_value(i) for i in range(qmin, qmax + 1)]
+        lut = relay.const(values, dtype=dtype)
+
+        # We baked the requantization into the LUT, so we don't requantize the identity operator
+        identity = ethosu_ops.ethosu_identity(
+            ifm=params.ifm.tensor,
+            lut=lut,
+            ifm_scale=input_scale,
+            ifm_zero_point=input_zp,
+            ofm_scale=input_scale,
+            ofm_zero_point=input_zp,
+            activation="LUT",
+        )
+
+        return identity
+
+
class Conv2DRewriter(DFPatternCallback):
"""Convert conv2d related composite functions into ethosu_conv2d
operators"""
@@ -1306,6 +1383,7 @@ class LegalizeEthosU:
ShlRewriter(),
AbsRewriter(),
TanhRewriter(),
+ HardSwishRewriter(),
LeakyReLURewriter(),
MeanRewriter(),
ConcatRewriter(),
diff --git a/python/tvm/relay/op/contrib/ethosu.py
b/python/tvm/relay/op/contrib/ethosu.py
index 4c3dcc2fc4..c0f8e5e970 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -1724,6 +1724,54 @@ def qnn_fc_pattern():
return optional_clip
+class HardSwishParams:
+    """
+    This class will parse a call to an ethos-u.hard_swish composite function
+    and extract the parameter information.
+
+    The function body is expected to have the shape
+    quantize(divide(multiply(dequantize(x), clip(add(dequantize(x), c0))), c1)).
+    """
+
+    composite_name = "ethos-u.hard_swish"
+
+    def __init__(self, func_body):
+        from tvm.relay.backend.contrib.ethosu.util import QuantizeArgs
+        from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs
+
+        quantize = func_body
+        divide = quantize.args[0]
+        multiply = divide.args[0]
+        clip = multiply.args[1]
+        add = clip.args[0]
+        dequantize = add.args[0]
+
+        # Kept so that is_valid can check the constants the LUT legalization hard-codes.
+        self.add = add
+        self.clip = clip
+        self.divide = divide
+
+        self.ifm = TensorParams(
+            dequantize.args[0],
+            scale=dequantize.args[DequantizeArgs.IFM_SCALE.value],
+            zero_point=dequantize.args[DequantizeArgs.IFM_ZERO_POINT.value],
+        )
+        self.ofm = TensorParams(
+            quantize,
+            scale=quantize.args[QuantizeArgs.OFM_SCALE.value],
+            zero_point=quantize.args[QuantizeArgs.OFM_ZERO_POINT.value],
+        )
+
+    @staticmethod
+    def _is_scalar_const(expr, value):
+        """Return True if the relay constant expr holds a single element equal to value."""
+        data = expr.data.numpy()
+        return data.size == 1 and float(data.item()) == value
+
+    def is_valid(self):
+        """
+        This function checks whether hard swish has compatible attributes with the NPU
+        """
+        tensor_params = [self.ifm, self.ofm]
+        if not check_valid_dtypes(tensor_params, supported_dtypes=[np.int8]):
+            return False
+        # The legalization bakes exactly x * relu6(x + 3) / 6 into a LUT, so a
+        # graph with any other constants must not be offloaded as hard swish.
+        if not (self.clip.attrs.a_min == 0.0 and self.clip.attrs.a_max == 6.0):
+            return False
+        return self._is_scalar_const(self.add.args[1], 3.0) and self._is_scalar_const(
+            self.divide.args[1], 6.0
+        )
+
+
+def hard_swish_pattern():
+    """Create the pattern for hard swish.
+
+    Matches quantize(dequantize(x) * clip(dequantize(x) + c0) / c1), where the same
+    dequantize node feeds both the add and the multiply. Constant values are not
+    inspected by the pattern itself.
+    """
+    ifm = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
+    clipped = is_op("clip")(is_op("add")(ifm, is_constant()))
+    scaled = is_op("divide")(is_op("multiply")(ifm, clipped), is_constant())
+    return is_op("qnn.quantize")(scaled, is_constant(), is_constant())
+
+
@register_pattern_table("ethos-u")
def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern,
Callable]]:
return [
@@ -1844,6 +1892,11 @@ def pattern_table() -> List[Tuple[str,
tvm.relay.dataflow_pattern.DFPattern, Cal
squeeze_pattern(),
lambda pat: SqueezeParams(pat).is_valid(),
),
+ (
+ HardSwishParams.composite_name,
+ hard_swish_pattern(),
+ lambda pat: HardSwishParams(pat).is_valid(),
+ ),
]
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py
b/tests/python/contrib/test_ethosu/test_codegen.py
index 2d3489889e..920cfff178 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -819,6 +819,21 @@ def test_tflite_tanh(accel_type):
)
+@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
+@pytest.mark.parametrize("ifm_shape", [(1, 5, 5, 3), (1, 12, 9, 1)])
+def test_tflite_hard_swish(accel_type, ifm_shape):
+    """Compare TVM's Ethos-U hard swish results against TFLite's."""
+    np.random.seed(0)
+
+    @tf.function
+    def hard_swish_func(x):
+        # hard swish: x * relu6(x + 3) / 6
+        return tf.keras.layers.Lambda(
+            lambda x: x * tf.keras.activations.relu(x + 3.0, max_value=6.0) / 6.0
+        )(x)
+
+    infra.compare_tvm_with_tflite(hard_swish_func, [ifm_shape], accel_type, ranges=[(-1, 1)])
+
+
@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize(
"shapes, axis",
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py
b/tests/python/contrib/test_ethosu/test_legalize.py
index 3f8b5f7d5b..0f8fa4d84b 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -2751,5 +2751,60 @@ def test_tflite_fully_connected(
verify(mod["tvmgen_default_ethos_u_main_0"])
+@pytest.mark.parametrize("ifm_shape", [(1, 5, 5, 3), (1, 12, 9, 1)])
+def test_tflite_hard_swish(ifm_shape):
+    """Check that a TFLite hard swish legalizes to an ethosu identity with a 256-entry LUT."""
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                # hard swish: x * relu6(x + 3) / 6
+                return tf.keras.layers.Lambda(
+                    lambda x: x * tf.keras.activations.relu(x + 3.0, max_value=6.0) / 6.0
+                )(x)
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, tf.float32)
+        )
+
+        # Calibration data for the full-integer (int8) conversion below.
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        return converter.convert()
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+
+    mod = ethosu.partition_for_ethosu(mod, params)
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.HardSwishRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+    mod = relay.transform.InferType()(mod)
+
+    func_body = mod["tvmgen_default_ethos_u_main_0"].body
+    assert func_body.op.name == "contrib.ethosu.identity"
+    assert func_body.attrs.activation == "LUT"
+    # The identity consumes the original IFM, and the LUT has one entry per int8 value.
+    assert tuple(func_body.args[0].checked_type.shape) == ifm_shape
+    assert tuple(func_body.args[1].checked_type.shape) == (256,)
+
+
if __name__ == "__main__":
pytest.main([__file__])