areusch commented on code in PR #11464:
URL: https://github.com/apache/tvm/pull/11464#discussion_r894786749


##########
python/tvm/micro/contrib/stm32/emitter.py:
##########
@@ -482,8 +482,18 @@ def parse_library_format(self, model_library_format_path, 
quantization=None):
         with tarfile.TarFile(model_library_format_path) as f:
             f.extractall(extract_path)
 
+        with open(os.path.join(extract_path, "metadata.json")) as metadata_f:
+            metadata = json.load(metadata_f)
+
+        all_module_names = []
+        for name in metadata["modules"].keys():
+            all_module_names.append(name)
+        assert len(all_module_names) == 1, "Multiple modules is not supported."

Review Comment:
   I think you could simplify this to just `len(metadata["modules"])` instead of building `all_module_names`.



##########
python/tvm/micro/model_library_format.py:
##########
@@ -67,56 +69,75 @@ def generate_c_interface_header(
 EPHEMERAL_MODULE_TYPE_KEYS = ("metadata_module",)
 
 
-def _populate_codegen_dir(mod, codegen_dir: str, module_name: str = None):
+def _populate_codegen_dir(
+    mods: Union[
+        typing.List[executor_factory.ExecutorFactoryModule], 
typing.List[tvm.runtime.Module]

Review Comment:
   Why is `tvm.runtime.Module` allowed here? Also, do we want
   `List[Union[ExecutorFactoryModule, tvm.runtime.Module]]` instead, i.e. a single list that may mix both kinds?



##########
tests/python/unittest/test_micro_model_library_format.py:
##########
@@ -95,8 +101,35 @@ def test_export_operator_model_library_format():
             assert tir_f.read() == str(ir_mod)
 
 
[email protected]_micro
+def test_export_multiple_operator_model_library_format():
+    target = tvm.target.target.micro("host")
+    with tvm.transform.PassContext(opt_level=3, 
config={"tir.disable_vectorize": True}):
+        A = tvm.te.placeholder((2,), dtype="int8")
+        B = tvm.te.placeholder((1,), dtype="int8")
+        C = tvm.te.compute(A.shape, lambda i: A[i] + B[0], name="C")
+        sched = tvm.te.create_schedule(C.op)
+        mod = tvm.build(
+            sched,
+            [A, B, C],
+            tvm.target.Target(target, target),
+            runtime=Runtime("crt", {"system-lib": True}),
+            name="add",
+        )
+
+    temp_dir = utils.tempdir()
+    mlf_tar_path = temp_dir.relpath("lib.tar")
+
+    with pytest.raises(RuntimeError) as exc:
+        micro.export_model_library_format([mod, mod], mlf_tar_path)
+
+        assert str(exc.exception) == ("Multiple operator is not supported.")

Review Comment:
   It seems the real error in this case is that you've passed the same `mod` twice, right?
   Maybe we should explicitly check for duplicates, or de-duplicate them?



##########
tests/python/unittest/test_micro_model_library_format.py:
##########
@@ -439,5 +474,140 @@ def test_export_byoc_c_module():
             ]
 
 
[email protected]_micro
+def test_multiple_relay_modules_same_module_name():
+    mod = get_conv2d_relay_module()
+
+    executor = Executor("graph")
+    runtime = Runtime("crt")
+    target = tvm.target.target.micro("host")
+
+    with tvm.transform.PassContext(opt_level=3, 
config={"tir.disable_vectorize": True}):
+        factory1 = tvm.relay.build(mod, target, runtime=runtime, 
executor=executor, mod_name="mod")
+        factory2 = tvm.relay.build(mod, target, runtime=runtime, 
executor=executor, mod_name="mod")
+
+    temp_dir = utils.tempdir()
+    mlf_tar_path = temp_dir.relpath("lib.tar")
+
+    with pytest.raises(AssertionError, match="Multiple modules should have 
unique names"):
+        micro.export_model_library_format([factory1, factory2], mlf_tar_path)
+
+
[email protected]_micro
+def test_multiple_relay_modules_graph():
+    mod = get_conv2d_relay_module()
+
+    executor = Executor("graph")
+    runtime = Runtime("crt")
+    target = tvm.target.target.micro("host")
+
+    with tvm.transform.PassContext(opt_level=3, 
config={"tir.disable_vectorize": True}):
+        factory1 = tvm.relay.build(mod, target, runtime=runtime, 
executor=executor, mod_name="mod1")
+        factory2 = tvm.relay.build(mod, target, runtime=runtime, 
executor=executor, mod_name="mod2")
+
+    temp_dir = utils.tempdir()
+    mlf_tar_path = temp_dir.relpath("lib.tar")
+    micro.export_model_library_format([factory1, factory2], mlf_tar_path)
+
+    with tarfile.open(mlf_tar_path, "r:*") as tf:
+        tar_members = [ti.name for ti in tf.getmembers()]
+        print("tar members", tar_members)
+        assert "./metadata.json" in tar_members
+        assert "./codegen/host/src/mod1_lib0.c" in tar_members
+        assert "./codegen/host/src/mod2_lib0.c" in tar_members
+
+        with tf.extractfile("./metadata.json") as f:
+            metadata = json.load(f)
+        mod2_main_md = 
metadata["modules"]["mod2"]["memory"]["functions"]["main"]
+        assert mod2_main_md == [
+            {
+                "constants_size_bytes": 0,
+                "device": 1,
+                "io_size_bytes": 143960,
+                "workspace_size_bytes": 158088,
+            }
+        ]
+        assert metadata["modules"]["mod1"]["model_name"] == "mod1"
+        assert metadata["modules"]["mod2"]["model_name"] == "mod2"
+
+
[email protected]_micro
+def test_multiple_relay_modules_c():
+    mod = get_conv2d_relay_module()
+
+    executor = Executor("aot", {"unpacked-api": True, "interface-api": "c"})
+    runtime = Runtime("crt")
+    target = tvm.target.target.micro("host")
+
+    with tvm.transform.PassContext(opt_level=3, 
config={"tir.disable_vectorize": True}):
+        factory1 = tvm.relay.build(mod, target, runtime=runtime, 
executor=executor, mod_name="mod1")
+        factory2 = tvm.relay.build(mod, target, runtime=runtime, 
executor=executor, mod_name="mod2")
+
+    temp_dir = utils.tempdir()
+    mlf_tar_path = temp_dir.relpath("lib.tar")
+
+    micro.export_model_library_format([factory1, factory2], mlf_tar_path)
+
+    tf = tarfile.open(mlf_tar_path)
+
+    extract_dir = temp_dir.relpath("extract")
+    os.mkdir(extract_dir)
+    tf.extractall(extract_dir)
+
+    assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", 
"mod1_lib0.c"))
+    assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", 
"mod1_lib1.c"))
+    assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", 
"mod2_lib0.c"))
+    assert os.path.exists(os.path.join(extract_dir, "codegen", "host", "src", 
"mod2_lib1.c"))
+
+    assert os.path.exists(os.path.join(extract_dir, "codegen", "host", 
"include", "tvmgen_mod1.h"))
+    assert os.path.exists(os.path.join(extract_dir, "codegen", "host", 
"include", "tvmgen_mod2.h"))
+
+    # check CRT runtime directory
+    assert os.path.exists(os.path.join(extract_dir, "runtime"))
+
+
[email protected]_micro
+def test_multiple_relay_modules_aot_graph():

Review Comment:
   Do we need to test mixing executor styles (AOT and graph) in one export? I think it's probably legitimate,
   but I wanted to understand what test coverage this adds.



##########
python/tvm/micro/model_library_format.py:
##########
@@ -24,6 +24,7 @@
 import re
 import tarfile
 import typing
+from typing import Union

Review Comment:
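   I think we should unify the style here. Either also import `List` etc. directly with `from typing import ...`,
   or drop this import and just use `typing.Union`, matching the existing `typing.List` usage.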



##########
python/tvm/micro/model_library_format.py:
##########
@@ -67,56 +69,75 @@ def generate_c_interface_header(
 EPHEMERAL_MODULE_TYPE_KEYS = ("metadata_module",)
 
 
-def _populate_codegen_dir(mod, codegen_dir: str, module_name: str = None):
+def _populate_codegen_dir(
+    mods: Union[
+        typing.List[executor_factory.ExecutorFactoryModule], 
typing.List[tvm.runtime.Module]
+    ],
+    codegen_dir: str,
+):
     """Populate the codegen sub-directory as part of a Model Library Format 
export.
 
     Parameters
     ----------
-    mod : tvm.runtime.Module
-        Module which should be written to codegen_dir.
+    mods : List[tvm.relay.backend.executor_factory.ExecutorFactoryModule], 
List[tvm.runtime.Module]
+        A list of the return value of tvm.relay.build, which
+        will be exported into Model Library Format.
     codegen_dir : str
         Path to the codegen directory on disk.
     module_name: Optional[str]
         Name used to prefix the generated source files
 
     """
-    dso_modules = mod._collect_dso_modules()
-    non_dso_modules = mod._collect_from_import_tree(lambda m: m not in 
dso_modules)
-
-    # Filter ephemeral modules which cannot be exported.
-    dso_modules = [m for m in dso_modules if m.type_key not in 
EPHEMERAL_MODULE_TYPE_KEYS]
-    non_dso_modules = [m for m in non_dso_modules if m.type_key not in 
EPHEMERAL_MODULE_TYPE_KEYS]
+    dso_modules = []
+    for mod in mods:
+        if isinstance(mod, executor_factory.ExecutorFactoryModule):
+            lib = mod.lib
+        elif isinstance(mod, tvm.runtime.Module):

Review Comment:
   Add an `else` block that raises an error for unsupported module types.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to