ekalda commented on a change in pull request #10345:
URL: https://github.com/apache/tvm/pull/10345#discussion_r813800424



##########
File path: python/tvm/relay/backend/contrib/ethosu/legalize.py
##########
@@ -1577,6 +1577,88 @@ def __call__(self, *args, **kwargs):
         pass
 
 
+class FullyConnectedRewriter(DFPatternCallback):
+    """Legalize Fully Connected (with bias and clip) to an EthosU operator"""
+
+    def __init__(self):
+        super().__init__(require_type=True)
+        self.pattern = (
+            wildcard().has_attr({"Composite": 
ethosu_patterns.FullyConnectedParams.composite_name})
+        )(wildcard())
+
+    def callback(self, pre, post, node_map):
+        params = ethosu_patterns.FullyConnectedParams(post.op.body)
+        params.ifm.tensor = post.args[0]
+        activation_map = {"clip": "CLIP"}

Review comment:
       nit: we don't expect that dict to grow, so we can just do `if activation == "clip":` etc.

##########
File path: python/tvm/relay/backend/contrib/ethosu/legalize.py
##########
@@ -1577,6 +1577,88 @@ def __call__(self, *args, **kwargs):
         pass
 
 
+class FullyConnectedRewriter(DFPatternCallback):
+    """Legalize Fully Connected (with bias and clip) to an EthosU operator"""
+
+    def __init__(self):
+        super().__init__(require_type=True)
+        self.pattern = (
+            wildcard().has_attr({"Composite": 
ethosu_patterns.FullyConnectedParams.composite_name})
+        )(wildcard())
+
+    def callback(self, pre, post, node_map):
+        params = ethosu_patterns.FullyConnectedParams(post.op.body)
+        params.ifm.tensor = post.args[0]
+        activation_map = {"clip": "CLIP"}
+
+        # IFM reshapes
+        ifm = post.args[0]
+        if len(params.ifm.shape) != 4 or not params.ifm.shape[1] == 
params.ifm.shape[2] == 1:
+            ifm = relay.reshape(ifm, (-1, 1, 1, params.ifm.shape[-1]))
+
+        # Weight transformations
+        weights_values = params.weights.values
+        weights_values_ohwi = np.expand_dims(weights_values, axis=(1, 2))
+        if params.activation:
+            activation = activation_map[params.activation.op.name]
+            clip_min = int(params.activation.attrs.a_min)
+            clip_max = int(params.activation.attrs.a_max)
+        else:
+            activation = "NONE"
+            clip_min = 0
+            clip_max = 0
+        scale_bias = vela_api.pack_biases(
+            biases=params.biases.tensor.data.asnumpy(),
+            ifm_scale=params.ifm.q_params.scale_f32,
+            ifm_dtype=np.dtype(params.ifm.dtype),
+            weight_scales=params.weights.q_params.scale_f32,
+            ofm_scale=params.ofm.q_params.scale_f32,
+            is_activation_tanh_or_sigmoid=False,
+        )
+        ethosu_fc = ethosu_ops.ethosu_conv2d(
+            ifm=ifm,
+            weight=relay.const(weights_values_ohwi, 
params.weights.values.dtype),
+            scale_bias=relay.const(scale_bias, "uint8"),
+            lut=relay.const([], dtype="int8"),
+            ifm_scale=float(params.ifm.q_params.scale_f32),
+            ifm_zero_point=int(params.ifm.q_params.zero_point),
+            weight_zero_point=int(params.weights.q_params.zero_point),
+            ofm_scale=float(params.ofm.q_params.scale_f32),
+            ofm_zero_point=int(params.ofm.q_params.zero_point),
+            kernel_shape=[1, 1],
+            ofm_channels=params.weights.shape[0],
+            strides=(1, 1),
+            padding=(0, 0, 0, 0),
+            dilation=(1, 1),
+            activation=activation,
+            clip_min=clip_min,
+            clip_max=clip_max,
+            upscale="NONE",
+            ifm_layout="NHWC",
+            ofm_layout="NHWC",
+        )
+
+        if len(params.ofm.shape) != 4 or not params.ofm.shape[1] == 
params.ofm.shape[2] == 1:
+            ethosu_fc = relay.reshape(ethosu_fc, params.ofm.shape)

Review comment:
       I suspect there isn't a test case that exercises this branch: on line 
1700 this pass runs after the no-op legalizer, so the final reshape won't be 
followed by an identity op and will fall over in TE.

##########
File path: python/tvm/relay/backend/contrib/ethosu/legalize.py
##########
@@ -1577,6 +1577,88 @@ def __call__(self, *args, **kwargs):
         pass
 
 
+class FullyConnectedRewriter(DFPatternCallback):
+    """Legalize Fully Connected (with bias and clip) to an EthosU operator"""
+
+    def __init__(self):
+        super().__init__(require_type=True)
+        self.pattern = (
+            wildcard().has_attr({"Composite": 
ethosu_patterns.FullyConnectedParams.composite_name})
+        )(wildcard())
+
+    def callback(self, pre, post, node_map):
+        params = ethosu_patterns.FullyConnectedParams(post.op.body)
+        params.ifm.tensor = post.args[0]
+        activation_map = {"clip": "CLIP"}
+
+        # IFM reshapes
+        ifm = post.args[0]
+        if len(params.ifm.shape) != 4 or not params.ifm.shape[1] == 
params.ifm.shape[2] == 1:
+            ifm = relay.reshape(ifm, (-1, 1, 1, params.ifm.shape[-1]))

Review comment:
       ```suggestion
               ifm = relay.reshape(ifm, (1, 1, 1, params.ifm.shape[-1]))
   ```
   should be safer, since the NPU doesn't support IFMs with a batch size other 
than 1, so this kind of fully connected layer wouldn't be offloaded 
anyway.

##########
File path: python/tvm/relay/backend/contrib/ethosu/legalize.py
##########
@@ -1577,6 +1577,88 @@ def __call__(self, *args, **kwargs):
         pass
 
 
+class FullyConnectedRewriter(DFPatternCallback):
+    """Legalize Fully Connected (with bias and clip) to an EthosU operator"""
+
+    def __init__(self):
+        super().__init__(require_type=True)
+        self.pattern = (
+            wildcard().has_attr({"Composite": 
ethosu_patterns.FullyConnectedParams.composite_name})
+        )(wildcard())
+
+    def callback(self, pre, post, node_map):
+        params = ethosu_patterns.FullyConnectedParams(post.op.body)
+        params.ifm.tensor = post.args[0]
+        activation_map = {"clip": "CLIP"}
+
+        # IFM reshapes
+        ifm = post.args[0]
+        if len(params.ifm.shape) != 4 or not params.ifm.shape[1] == 
params.ifm.shape[2] == 1:
+            ifm = relay.reshape(ifm, (-1, 1, 1, params.ifm.shape[-1]))
+
+        # Weight transformations
+        weights_values = params.weights.values
+        weights_values_ohwi = np.expand_dims(weights_values, axis=(1, 2))
+        if params.activation:
+            activation = activation_map[params.activation.op.name]
+            clip_min = int(params.activation.attrs.a_min)
+            clip_max = int(params.activation.attrs.a_max)
+        else:
+            activation = "NONE"
+            clip_min = 0
+            clip_max = 0
+        scale_bias = vela_api.pack_biases(
+            biases=params.biases.tensor.data.asnumpy(),
+            ifm_scale=params.ifm.q_params.scale_f32,
+            ifm_dtype=np.dtype(params.ifm.dtype),
+            weight_scales=params.weights.q_params.scale_f32,
+            ofm_scale=params.ofm.q_params.scale_f32,
+            is_activation_tanh_or_sigmoid=False,
+        )
+        ethosu_fc = ethosu_ops.ethosu_conv2d(
+            ifm=ifm,
+            weight=relay.const(weights_values_ohwi, 
params.weights.values.dtype),
+            scale_bias=relay.const(scale_bias, "uint8"),
+            lut=relay.const([], dtype="int8"),
+            ifm_scale=float(params.ifm.q_params.scale_f32),
+            ifm_zero_point=int(params.ifm.q_params.zero_point),
+            weight_zero_point=int(params.weights.q_params.zero_point),
+            ofm_scale=float(params.ofm.q_params.scale_f32),
+            ofm_zero_point=int(params.ofm.q_params.zero_point),
+            kernel_shape=[1, 1],
+            ofm_channels=params.weights.shape[0],
+            strides=(1, 1),
+            padding=(0, 0, 0, 0),
+            dilation=(1, 1),
+            activation=activation,
+            clip_min=clip_min,
+            clip_max=clip_max,
+            upscale="NONE",
+            ifm_layout="NHWC",
+            ofm_layout="NHWC",
+        )
+
+        if len(params.ofm.shape) != 4 or not params.ofm.shape[1] == 
params.ofm.shape[2] == 1:
+            ethosu_fc = relay.reshape(ethosu_fc, params.ofm.shape)
+        return ethosu_fc
+
+
+@ir.transform.module_pass(opt_level=1)
+class LegalizeFullyConnected:
+    """This is the pass that wraps the AddRewriter"""

Review comment:
       ```suggestion
       """This is the pass that wraps the FullyConnectedRewriter"""
   ```

##########
File path: tests/python/contrib/test_ethosu/test_legalize.py
##########
@@ -2346,5 +2346,87 @@ def verify(ext_func):
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+@pytest.mark.parametrize("units", [32, 64])
+@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.parametrize("activation_function", ["RELU", "NONE"])
+def test_tflite_fully_connected(
+    units,
+    use_bias,
+    activation_function,
+):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def fully_connected(self, x):
+                return tf.keras.layers.Dense(
+                    units=units,
+                    activation=activation_function,
+                    use_bias=use_bias,
+                )(x)
+
+        model = Model()
+        concrete_func = model.fully_connected.get_concrete_function(
+            tf.TensorSpec([1, 3, units, 1], dtype=tf.float32)
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple([1, 3, units, 1]))
+                yield [data.astype(np.float32)]
+
+        converter = 
tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = 
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func):
+        op = ext_func.body
+        ofm_channels = op.attrs.ofm_channels
+
+        # check IFM
+        ifm = op.args[0].checked_type
+        assert list([1, 3, units, 1]) == list([1, 3, units, 1])
+        assert str(ifm.dtype) == dtype
+        assert ifm.shape[3] == ofm_channels
+
+        # Check that scale_bias matches weight tensor
+        assert list(op.args[2].checked_type.shape)[0] == ofm_channels
+
+        if activation_function == "RELU":
+            assert str(op.attrs.activation) == "CLIP"
+
+    dense_pattern_table = [

Review comment:
       nit: it would be better to keep the naming consistent, so maybe rename 
this to `fc_pattern_table` or `fully_connected_pattern_table`

##########
File path: tests/python/contrib/test_ethosu/test_legalize.py
##########
@@ -2346,5 +2346,87 @@ def verify(ext_func):
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+@pytest.mark.parametrize("units", [32, 64])
+@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.parametrize("activation_function", ["RELU", "NONE"])
+def test_tflite_fully_connected(
+    units,
+    use_bias,
+    activation_function,
+):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def fully_connected(self, x):
+                return tf.keras.layers.Dense(
+                    units=units,
+                    activation=activation_function,
+                    use_bias=use_bias,
+                )(x)
+
+        model = Model()
+        concrete_func = model.fully_connected.get_concrete_function(
+            tf.TensorSpec([1, 3, units, 1], dtype=tf.float32)
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple([1, 3, units, 1]))
+                yield [data.astype(np.float32)]
+
+        converter = 
tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = 
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func):
+        op = ext_func.body
+        ofm_channels = op.attrs.ofm_channels
+
+        # check IFM
+        ifm = op.args[0].checked_type
+        assert list([1, 3, units, 1]) == list([1, 3, units, 1])

Review comment:
       This assert doesn't check anything: it compares a literal list with an 
identical literal list, so it always passes. Some things to potentially check:
   * That we have ended up with an `ethosu_conv2d` op (taking into account that 
there might be reshape ops before and after the conv2d)
   * That the IFM has a shape of (1, 1, 1, c)
   * That the weights have a shape of (o, 1, 1, c), where o is the number of 
output channels of the weights
   * That the kernel and dilation are (1, 1)

##########
File path: python/tvm/relay/backend/contrib/ethosu/legalize.py
##########
@@ -1615,6 +1697,7 @@ def transform_module(
         mod = LegalizeReshape()(mod)
         mod = LegalizeStridedSlice()(mod)
         mod = LegalizeNoOps()(mod)
+        mod = LegalizeFullyConnected()(mod)

Review comment:
       This should run before `LegalizeNoOps`, so that the reshapes inserted by 
`LegalizeFullyConnected` are followed by identity ops.

##########
File path: python/tvm/relay/op/contrib/ethosu.py
##########
@@ -1537,6 +1537,106 @@ def squeeze_pattern():
     return is_op("squeeze")(wildcard())
 
 
+class FullyConnectedParams:
+    """
+    This class will parse a call to an ethos-u.fully_connected composite
+    function and extract the parameter information.
+    """
+
+    composite_name = "ethosu.fully_connected"
+    activation_map = {"clip": "CLIP"}

Review comment:
       Same nit about the clip dict as before :) 




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to