areusch commented on a change in pull request #6950:
URL: https://github.com/apache/tvm/pull/6950#discussion_r556736523



##########
File path: tests/micro/qemu/test_zephyr.py
##########
@@ -198,5 +200,143 @@ def test_relay(platform):
         tvm.testing.assert_allclose(result, x_in * x_in + 1)
 
 
+class CcompilerAnnotator(ExprMutator):
+    """
+    This is used to create external functions for ccompiler.
+    A simple annotator that creates the following program:
+           |
+      -- begin --
+           |
+          add
+           |
+        subtract
+           |
+        multiply
+           |
+       -- end --
+           |
+    """
+
+    def __init__(self):
+        super(CcompilerAnnotator, self).__init__()
+        self.in_compiler = 0
+
+    def visit_call(self, call):
+        if call.op.name == "add":  # Annotate begin at args
+            if self.in_compiler == 1:
+                lhs = compiler_begin(super().visit(call.args[0]), "ccompiler")
+                rhs = compiler_begin(super().visit(call.args[1]), "ccompiler")
+                op = relay.add(lhs, rhs)
+                self.in_compiler = 2
+                return op
+        elif call.op.name == "subtract":
+            if self.in_compiler == 1:
+                lhs = super().visit(call.args[0])
+                rhs = super().visit(call.args[1])
+                if isinstance(lhs, relay.expr.Var):
+                    lhs = compiler_begin(lhs, "ccompiler")
+                if isinstance(rhs, relay.expr.Var):
+                    rhs = compiler_begin(rhs, "ccompiler")
+                return relay.subtract(lhs, rhs)
+        elif call.op.name == "multiply":  # Annotate end at output
+            self.in_compiler = 1
+            lhs = super().visit(call.args[0])
+            rhs = super().visit(call.args[1])
+            if isinstance(lhs, relay.expr.Var):
+                lhs = compiler_begin(lhs, "ccompiler")
+            if isinstance(rhs, relay.expr.Var):
+                rhs = compiler_begin(rhs, "ccompiler")
+            op = relay.multiply(lhs, rhs)
+            if self.in_compiler == 2:
+                op = compiler_end(op, "ccompiler")
+            self.in_compiler = 0
+            return op
+        return super().visit_call(call)
+
+
+def check_result(relay_mod, model, zephyr_board, map_inputs, out_shape, result):
+    """Helper function to verify results"""
+    tol = 1e-5
+    target = tvm.target.target.micro(model)
+    with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
+        graph, mod, params = tvm.relay.build(relay_mod, target=target)
+
+    with _make_session(model, target, zephyr_board, mod) as session:

Review comment:
       Ah, that's correct — thanks for jogging my memory. Great, so this test should do everything you need for now.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to