dchauhan-arm commented on a change in pull request #10345:
URL: https://github.com/apache/tvm/pull/10345#discussion_r813708528



##########
File path: tests/python/contrib/test_ethosu/test_codegen.py
##########
@@ -1167,5 +1167,24 @@ def leaky_relu_func(x):
     _compare_tvm_with_tflite(leaky_relu_func, [ifm_shape], accel_type)
 
 
+@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
+@pytest.mark.parametrize("units", [32, 64])
+@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.parametrize("activation_function", ["RELU", "NONE"])
+def test_tflite_fully_connected(
+    accel_type,
+    units,
+    use_bias,
+    activation_function,
+):
+    @tf.function
+    def fully_connected():
+        return tf.keras.layers.Dense(

Review comment:
       this is a very welcome change, I'll try and make this work!

##########
File path: python/tvm/relay/op/contrib/ethosu.py
##########
@@ -1537,6 +1537,105 @@ def squeeze_pattern():
     return is_op("squeeze")(wildcard())
 
 
+class FullyConnectedParams:
+    """
+    This class will parse a call to an ethos-u.fully_connected composite
+    function and extract the parameter information.
+    """
+
+    composite_name = "ethosu.fully_connected"
+    activation_map = {"clip": "CLIP"}
+
+    @requires_vela
+    def __init__(self, func_body):
+        from tvm.relay.backend.contrib.ethosu.util import QDenseArgs  # type: ignore
+        from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs
+        from tvm.relay.backend.contrib.ethosu.util import RequantArgs
+
+        activation = None
+        if str(func_body.op) in self.activation_map.keys():
+            activation = func_body
+            requantize_op = activation.args[0]
+        else:
+            requantize_op = func_body
+
+        bias_add = requantize_op.args[0]
+        qnn_dense = bias_add.args[0]
+        
+        # We consider the weights & biases as params as they should be constant
+        self.weights = TensorParams(
+            qnn_dense.args[QDenseArgs.weights.value],
+            "OI",
+            qnn_dense.args[QDenseArgs.weights_scale.value],
+            qnn_dense.args[QDenseArgs.weights_zero_point.value],
+        )
+        self.biases = TensorParams(
+            bias_add.args[BiasAddArgs.BIASES.value],
+            None,
+            requantize_op.args[RequantArgs.IFM_SCALE.value],
+            requantize_op.args[RequantArgs.IFM_ZERO_POINT.value],
+        )
+        self.ifm = TensorParams(
+            qnn_dense.args[QDenseArgs.ifm.value],

Review comment:
       ack!




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to