This is an automated email from the ASF dual-hosted git repository.

manupa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 94d01d3565 [microNPU] Add support for hard swish (#12120)
94d01d3565 is described below

commit 94d01d35650ff4cb73dfd5360f6f613f5dadc41c
Author: Luke Hutton <[email protected]>
AuthorDate: Tue Jul 19 08:28:45 2022 +0100

    [microNPU] Add support for hard swish (#12120)
    
    Adds support for hard swish by populating a LUT similar to Vela's
    implementation.
    
    Change-Id: I7ca15a3e21bc91c1b41cdd4547fabaa00de96e90
---
 .../tvm/relay/backend/contrib/ethosu/legalize.py   | 78 ++++++++++++++++++++++
 python/tvm/relay/op/contrib/ethosu.py              | 53 +++++++++++++++
 tests/python/contrib/test_ethosu/test_codegen.py   | 15 +++++
 tests/python/contrib/test_ethosu/test_legalize.py  | 55 +++++++++++++++
 4 files changed, 201 insertions(+)

diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index c940abdeab..77ef51ef9c 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -298,6 +298,83 @@ class LeakyReLURewriter(DFPatternCallback):
         return identity
 
 
+class HardSwishRewriter(DFPatternCallback):
+    """Convert ethosu.hard_swish composite function to add operation with LUT."""
+
+    def __init__(self):
+        super().__init__(require_type=True, rewrite_once=True)
+        self.params_class = ethosu_patterns.HardSwishParams
+        self.pattern = wildcard().has_attr({"Composite": self.params_class.composite_name})(
+            wildcard()
+        )
+
+    def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map):
+        params = self.params_class(post.op.body)
+        params.ifm.tensor = post.args[0]
+
+        # The calculation of the LUT values is similar to that in Vela
+        # convert_hardswish_to_lut(op, arch, nng)
+        # (https://review.mlplatform.org/plugins/gitiles/ml/ethos-u/ethos-u-vela/+/refs/tags/3.2.0/ethosu/vela/tflite_graph_optimiser.py#719)  # pylint: disable=line-too-long
+        input_scale = np.double(params.ifm.q_params.scale_f32)
+        input_zp = int(params.ifm.q_params.zero_point)
+        hires_input_scale = (1 / 128) * input_scale
+
+        output_scale = np.double(params.ofm.q_params.scale_f32)
+        output_zp = int(params.ofm.q_params.zero_point)
+        output_scale, output_shift = scaling.quantise_scale(hires_input_scale / output_scale)
+        output_scale_16 = fp_math.downscale_multiplier_int32_to_int16(output_scale)
+        output_shift = 31 - output_shift
+        output_shift = -output_shift if output_shift < 0 else 0
+
+        dtype = params.ifm.dtype
+        qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max
+
+        def calculate_relu_multiplier(inp, input_scale):
+            rmultiplier = np.double(3 / 32768)
+            rscale, rshift = scaling.quantise_scale(input_scale / rmultiplier)
+            rscale_16 = fp_math.downscale_multiplier_int32_to_int16(rscale)
+
+            rvalue = np.int16(inp)
+            if rshift < 31:
+                rvalue = fp_math.shift_left16(rvalue, 30 - rshift)
+                rvalue = fp_math.saturating_rounding_mul16(rvalue, rscale_16)
+                rvalue = fp_math.shift_left16(rvalue, 1)
+            elif rshift > 31:
+                rvalue = fp_math.saturating_rounding_mul16(rvalue, rscale_16)
+                rvalue = fp_math.rounding_divide_by_pot(rvalue, rshift - 31)
+            else:
+                rvalue = fp_math.saturating_rounding_mul16(rvalue, rscale_16)
+
+            rvalue = (rvalue + (1 << 15)) >> 1
+            return rvalue
+
+        def calculate_lut_values(i):
+            hires_input_value = (i - input_zp) * 128
+            preshift_input_value = fp_math.saturating_rounding_mul16(
+                hires_input_value, output_scale_16
+            )
+            relu_value = calculate_relu_multiplier(hires_input_value, hires_input_scale)
+            lut_result = fp_math.saturating_mul16(relu_value, preshift_input_value)
+            lut_result = fp_math.rounding_divide_by_pot(lut_result, output_shift) + output_zp
+            return min(qmax, max(qmin, lut_result))
+
+        values = list(map(calculate_lut_values, range(-128, 128)))
+        lut = relay.const(values, dtype=dtype)
+
+        # We baked the requantization into the LUT, so we don't requantize the identity operator
+        identity = ethosu_ops.ethosu_identity(
+            ifm=params.ifm.tensor,
+            lut=lut,
+            ifm_scale=input_scale,
+            ifm_zero_point=input_zp,
+            ofm_scale=input_scale,
+            ofm_zero_point=input_zp,
+            activation="LUT",
+        )
+
+        return identity
+
+
 class Conv2DRewriter(DFPatternCallback):
     """Convert conv2d related composite functions into ethosu_conv2d operators"""
 
@@ -1306,6 +1383,7 @@ class LegalizeEthosU:
             ShlRewriter(),
             AbsRewriter(),
             TanhRewriter(),
+            HardSwishRewriter(),
             LeakyReLURewriter(),
             MeanRewriter(),
             ConcatRewriter(),
diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index 4c3dcc2fc4..c0f8e5e970 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -1724,6 +1724,54 @@ def qnn_fc_pattern():
     return optional_clip
 
 
+class HardSwishParams:
+    """
+    This class will parse a call to a ethos-u.hard_swish composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethos-u.hard_swish"
+
+    def __init__(self, func_body):
+        from tvm.relay.backend.contrib.ethosu.util import QuantizeArgs
+        from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs
+
+        quantize = func_body
+        divide = quantize.args[0]
+        multiply = divide.args[0]
+        clip = multiply.args[1]
+        add = clip.args[0]
+        dequantize = add.args[0]
+
+        self.ifm = TensorParams(
+            dequantize.args[0],
+            scale=dequantize.args[DequantizeArgs.IFM_SCALE.value],
+            zero_point=dequantize.args[DequantizeArgs.IFM_ZERO_POINT.value],
+        )
+        self.ofm = TensorParams(
+            quantize,
+            scale=quantize.args[QuantizeArgs.OFM_SCALE.value],
+            zero_point=quantize.args[QuantizeArgs.OFM_ZERO_POINT.value],
+        )
+
+    def is_valid(self):
+        tensor_params = [self.ifm, self.ofm]
+        if not check_valid_dtypes(tensor_params, supported_dtypes=[np.int8]):
+            return False
+        return True
+
+
+def hard_swish_pattern():
+    """Create the pattern for hard swish."""
+    dequantize = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
+    add = is_op("add")(dequantize, is_constant())
+    clip = is_op("clip")(add)
+    multiply = is_op("multiply")(dequantize, clip)
+    divide = is_op("divide")(multiply, is_constant())
+    quantize = is_op("qnn.quantize")(divide, is_constant(), is_constant())
+    return quantize
+
+
 @register_pattern_table("ethos-u")
 def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
     return [
@@ -1844,6 +1892,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
             squeeze_pattern(),
             lambda pat: SqueezeParams(pat).is_valid(),
         ),
+        (
+            HardSwishParams.composite_name,
+            hard_swish_pattern(),
+            lambda pat: HardSwishParams(pat).is_valid(),
+        ),
     ]
 
 
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index 2d3489889e..920cfff178 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -819,6 +819,21 @@ def test_tflite_tanh(accel_type):
     )
 
 
+@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
+@pytest.mark.parametrize("ifm_shape", [(1, 5, 5, 3), (1, 12, 9, 1)])
+def test_tflite_hard_swish(accel_type, ifm_shape):
+    np.random.seed(0)
+
+    @tf.function
+    def hard_swish_func(x):
+        op = tf.keras.layers.Lambda(
+            lambda x: x * tf.keras.activations.relu(x + 3.0, max_value=6.0) / 6.0
+        )(x)
+        return op
+
+    infra.compare_tvm_with_tflite(hard_swish_func, [ifm_shape], accel_type, ranges=[(-1, 1)])
+
+
 @pytest.mark.parametrize("accel_type", ACCEL_TYPES)
 @pytest.mark.parametrize(
     "shapes, axis",
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py
index 3f8b5f7d5b..0f8fa4d84b 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -2751,5 +2751,60 @@ def test_tflite_fully_connected(
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+@pytest.mark.parametrize("ifm_shape", [(1, 5, 5, 3), (1, 12, 9, 1)])
+def test_tflite_hard_swish(ifm_shape):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                op = tf.keras.layers.Lambda(
+                    lambda x: x * tf.keras.activations.relu(x + 3.0, max_value=6.0) / 6.0
+                )(x)
+                return op
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, tf.float32)
+        )
+
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+
+        return tflite_model
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+
+    mod = ethosu.partition_for_ethosu(mod, params)
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.HardSwishRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+    mod = relay.transform.InferType()(mod)
+
+    func_body = mod["tvmgen_default_ethos_u_main_0"].body
+    assert func_body.op.name == "contrib.ethosu.identity"
+    assert func_body.attrs.activation == "LUT"
+    assert tuple(func_body.args[0].checked_type.shape) == (ifm_shape)
+    assert tuple(func_body.args[1].checked_type.shape) == (256,)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

Reply via email to