lhutton1 commented on a change in pull request #10345:
URL: https://github.com/apache/tvm/pull/10345#discussion_r822557744



##########
File path: tests/python/contrib/test_ethosu/test_legalize.py
##########
@@ -2421,5 +2421,108 @@ def verify(ext_func):
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+@pytest.mark.parametrize("ifm_shape", [(1, 14), (1, 151)])
+@pytest.mark.parametrize("ofm_channels", [32, 64])
+@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.parametrize("activation_function", ["RELU", "NONE"])
+def test_tflite_fully_connected(
+    ifm_shape,
+    ofm_channels,
+    use_bias,
+    activation_function,
+):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def fully_connected(self, x):
+                bias_shape = ofm_channels
+                bias = tf.constant(np.random.uniform(size=bias_shape), dtype=tf.float32)
+                w = tf.constant(
+                    np.random.uniform(size=[ifm_shape[1], ofm_channels]),
+                    dtype=tf.float32,
+                )
+                x = tf.matmul(x, w)
+                if use_bias:
+                    x = tf.nn.bias_add(x, bias)
+                if activation_function:
+                    x = tf.nn.relu(x)
+                return x
+
+        model = Model()
+        concrete_func = model.fully_connected.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func):
+        op = ext_func.body.args[0]
+        ofm_channels = op.attrs.ofm_channels
+
+        # check IFM
+        ifm = op.args[0].checked_type
+        assert [ifm.shape[2], ifm.shape[3]] == list(ifm_shape)
+        assert str(ifm.dtype) == dtype
+
+        # check OFM
+        ofm = op.checked_type
+        assert [ofm.shape[2], ofm.shape[3]] == [1, ofm_channels]
+        assert str(ofm.dtype) == dtype
+
+        # check weights
+        weights_ohwi = op.args[1].data.asnumpy()
+        assert str(weights_ohwi.dtype) == dtype
+        assert list(weights_ohwi) == [ofm_channels, 1, 1, ifm_shape[1]]

Review comment:
       This one's just missing `.shape`, i.e. `assert list(weights_ohwi.shape) == [ofm_channels, 1, 1, ifm_shape[1]]`

##########
File path: tests/python/contrib/test_ethosu/test_legalize.py
##########
@@ -2346,5 +2348,121 @@ def verify(ext_func):
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+@pytest.mark.parametrize("ifm_shape", [(1, 14), (1, 151)])
+@pytest.mark.parametrize("ofm_channels", [32, 64])
+@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.parametrize("activation_function", ["RELU", "NONE"])
+def test_tflite_fully_connected(
+    ifm_shape,
+    ofm_channels,
+    use_bias,
+    activation_function,
+):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def fully_connected(self, x):
+                bias_shape = ofm_channels
+                bias = tf.constant(np.random.uniform(size=bias_shape), dtype=tf.float32)
+                w = tf.constant(
+                    np.random.uniform(size=[ifm_shape[1], ofm_channels]),
+                    dtype=tf.float32,
+                )
+                x = tf.matmul(x, w)
+                if use_bias:
+                    x = tf.nn.bias_add(x, bias)
+                if activation_function:
+                    x = tf.nn.relu(x)
+                return x
+
+        model = Model()
+        concrete_func = model.fully_connected.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func):
+        op = ext_func.body.args[0]
+        ofm_channels = op.attrs.ofm_channels
+
+        # check IFM
+        ifm = op.args[0].checked_type
+        assert [ifm.shape[2], ifm.shape[3]] == list(ifm_shape)

Review comment:
       To get this change to work, we would need `assert list(ifm.shape) == [1, 1] + list(ifm_shape)`, since `ifm.shape` is not a `list`.

##########
File path: tests/python/contrib/test_ethosu/test_legalize.py
##########
@@ -2346,5 +2348,121 @@ def verify(ext_func):
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+@pytest.mark.parametrize("ifm_shape", [(1, 14), (1, 151)])
+@pytest.mark.parametrize("ofm_channels", [32, 64])
+@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.parametrize("activation_function", ["RELU", "NONE"])
+def test_tflite_fully_connected(
+    ifm_shape,
+    ofm_channels,
+    use_bias,
+    activation_function,
+):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def fully_connected(self, x):
+                bias_shape = ofm_channels
+                bias = tf.constant(np.random.uniform(size=bias_shape), dtype=tf.float32)
+                w = tf.constant(
+                    np.random.uniform(size=[ifm_shape[1], ofm_channels]),
+                    dtype=tf.float32,
+                )
+                x = tf.matmul(x, w)
+                if use_bias:
+                    x = tf.nn.bias_add(x, bias)
+                if activation_function:
+                    x = tf.nn.relu(x)
+                return x
+
+        model = Model()
+        concrete_func = model.fully_connected.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func):
+        op = ext_func.body.args[0]
+        ofm_channels = op.attrs.ofm_channels
+
+        # check IFM
+        ifm = op.args[0].checked_type
+        assert [ifm.shape[2], ifm.shape[3]] == list(ifm_shape)
+        assert str(ifm.dtype) == dtype
+
+        # check OFM
+        ofm = op.checked_type
+        assert [ofm.shape[2], ofm.shape[3]] == [1, ofm_channels]

Review comment:
       This would need to be `assert list(ofm.shape) == [1, 1, 1, ofm_channels]`




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to