lhutton1 commented on a change in pull request #10599:
URL: https://github.com/apache/tvm/pull/10599#discussion_r828315695



##########
File path: python/tvm/relay/backend/contrib/ethosu/codegen.py
##########
@@ -322,43 +354,36 @@ def constant_updater(expr, symbol):  # pylint: 
disable=unused-argument
     return dict()
 
 
-@tvm._ffi.register_func("relay.ext.ethos-u.relay_to_tir_func")
-def relay_to_tir_func(ext_func: relay.Function) -> tvm.tir.PrimFunc:
+@tvm._ffi.register_func("relay.ext.ethos-u.relay_to_tir")
+def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
     """
-    This is the hook for python-based lowering of relay function
-    that gets offloaded to the microNPU.
+    This is the hook for python-based lowering of a Relay module which lowers 
NPU
+    external functions to TIR.
 
     Parameters
     ----------
-    ext_func : relay.Function
-        This is the partitioned relay function
+    mod : tvm.ir.IRModule
+        This is the Relay module.
 
     Returns
     -------
-    primfunc : tir.PrimFunc
-        This returns the scheduled PrimFunc
+    mod : tvm.ir.IRModule
+        The Relay module with scheduled NPU external functions.
     """
-    assert len(ext_func.params) == 1
-    mod = tvm.IRModule()
-    mod["main"] = ext_func
+    mod = OutlineCompilerFunctions("ethos-u")(mod)
     mod = LegalizeEthosU()(mod)
     mod = LUTsOptimizer()(mod)
     mod = IdentityOptimizer()(mod)
     mod = LayoutOptimizer()(mod)
     mod = relay.transform.InferType()(mod)
-    # We are currently using copy_constants scheduler In the long run,
-    # this should be a single intelligent and a composite scheduler
-    # that can perform scheduling based on user inputs such as
-    # scratch memory size.
-    tir_mod, const_dict = lower_to_tir(mod["main"], copy_constants())
-
-    for param in const_dict.keys():
-        const_dict[param] = tvm.nd.array(const_dict[param])
-
-    primfunc = tir_mod["main"]
-    primfunc = primfunc.with_attr("global_symbol", 
ext_func.attrs["global_symbol"])
-    primfunc = primfunc.with_attr("ethos-u.constants", const_dict)
-    return primfunc
+
+    device_contexts = {
+        gv: "ethos-u" for gv, _ in filter(lambda x: util.is_npu_func(x[1]), 
mod.functions.items())
+    }
+    mod = mod.with_attr("device_contexts", device_contexts)
+    mod = RelayToTIR()(mod)

Review comment:
       Agreed, I wasn't too happy with this either. I'll revise this so that 
more of this functionality is rewritten in C++, although it's a bit tricky to 
define a clear boundary between Python and C++. See what you think of the 
update.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to