comaniac commented on a change in pull request #7304:
URL: https://github.com/apache/tvm/pull/7304#discussion_r576483634
##########
File path: python/tvm/driver/tvmc/compiler.py
##########
@@ -185,13 +185,21 @@ def compile_model(
"""
dump_code = [x.strip() for x in dump_code.split(",")] if dump_code else None
mod, params = frontends.load_model(path, model_format, shape_dict)
+ config = {}
if alter_layout:
mod = common.convert_graph_layout(mod, alter_layout)
- tvm_target = common.target_from_cli(target)
+ tvm_target, extra_targets = common.target_from_cli(target)
target_host = tvm_target if not target_host else target_host
+ for codegen_from_cli in extra_targets:
+ codegen = composite_target.get_target_by_name(codegen_from_cli["name"])
Review comment:
It's odd that `get_target_by_name` returns a codegen. Maybe rename it to
`get_codegen_by_target`?
##########
File path: python/tvm/driver/tvmc/composite_target.py
##########
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Provides support to composite target on TVMC.
+"""
+import logging
+
+from tvm.relay.op.contrib.arm_compute_lib import partition_for_arm_compute_lib
+from tvm.relay.op.contrib.ethosn import partition_for_ethosn
+
+from .common import TVMCException
+
+
+# pylint: disable=invalid-name
+logger = logging.getLogger("TVMC")
+
+# Global dictionary to map targets with the configuration key
+# to be used in the PassContext (if any), and a function
+# responsible for partitioning to that target.
+REGISTERED_TARGET = {
+ "acl": {
+ "config_key": None,
+ "pass_pipeline": partition_for_arm_compute_lib,
+ },
+ "ethos-n77": {
+ "config_key": "relay.ext.ethos-n.options",
+ "pass_pipeline": partition_for_ethosn,
+ },
+}
+
+
+def get_target_names():
+ """Return a list of all registered codegens.
+
+ Returns
+ --------
+ list of str
+ all registered targets
+ """
+ return list(REGISTERED_TARGET.keys())
+
+
+def get_target_by_name(name):
+ """Return a target entry by name.
+
+ Returns
+ --------
Review comment:
```suggestion
-------
```
##########
File path: python/tvm/driver/tvmc/composite_target.py
##########
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Provides support to composite target on TVMC.
+"""
+import logging
+
+from tvm.relay.op.contrib.arm_compute_lib import partition_for_arm_compute_lib
+from tvm.relay.op.contrib.ethosn import partition_for_ethosn
+
+from .common import TVMCException
+
+
+# pylint: disable=invalid-name
+logger = logging.getLogger("TVMC")
+
+# Global dictionary to map targets with the configuration key
+# to be used in the PassContext (if any), and a function
+# responsible for partitioning to that target.
+REGISTERED_TARGET = {
+ "acl": {
+ "config_key": None,
+ "pass_pipeline": partition_for_arm_compute_lib,
+ },
+ "ethos-n77": {
+ "config_key": "relay.ext.ethos-n.options",
+ "pass_pipeline": partition_for_ethosn,
+ },
+}
+
+
+def get_target_names():
+ """Return a list of all registered codegens.
+
+ Returns
+ --------
Review comment:
```suggestion
-------
```
##########
File path: python/tvm/driver/tvmc/byoc.py
##########
@@ -0,0 +1,191 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Provides support to Bring Your Own Codegen (BYOC) on TVMC.
+"""
+import logging
+
+from abc import ABC
+from abc import abstractmethod
+
+import tvm
+
+from tvm import relay
+
+from tvm.relay.op.contrib import get_pattern_table
+
+from .common import TVMCException
+
+
+# pylint: disable=invalid-name
+logger = logging.getLogger("TVMC")
+
+# Global dictionary to map existing custom codegens
+# with the names used to use them
+REGISTERED_CODEGEN = {}
+
+
+def register_codegen(kind):
+ """
+ Utility function to register a BYOC class for TVMC.
+
+ Classes decorated with `tvm.driver.tvmc.target.register_codegen` will
+ be added to the codegens dictionary.
+
+ Example
+ -------
+
+ @register_codegen(kind="samplebyoc")
+ class MyCustomTarget(TVMCCodegen):
+ ...
+ """
+
+ def codegen_decorator(cls):
+ cls.kind = kind
+
+ assert kind not in REGISTERED_CODEGEN, "there is already a codegen '%s': %s" % (
+ kind,
+ REGISTERED_CODEGEN[kind],
+ )
+ assert issubclass(cls, TVMCCodegen), "%s is expected to be a subclass of TVMCCodegen" % cls
+
+ REGISTERED_CODEGEN[kind] = cls
+ return cls
+
+ return codegen_decorator
+
+
+def get_codegen_kinds():
+ """Return a list of all registered codegens.
+
+ Returns
+ --------
+ list of str
+ all registered codegens
+ """
+ return REGISTERED_CODEGEN.keys()
+
+
+def get_codegen_by_kind(kind):
+ """Return a custom codegen by kind.
+
+ Returns
+ --------
+ TVMCCodegen
+ The requested codegen or None in case it is
+ not valid
+ """
+ try:
+ return REGISTERED_CODEGEN[kind]
+ except KeyError:
+ raise TVMCException("Target %s is not defined." % kind)
+
+
+class TVMCCodegen(ABC):
+ """Abstract class for command line driver BYOC definition.
+
+ Provide a unified way to create a codegen adapter with a set of
+ callback functions to be used within TVMC.
+
+ """
+
+ @staticmethod
+ @abstractmethod
+ def get_config_key_name():
+ """Return the name of the dictionary key to be used
+ at compile time (i.e. relay.build) as part of the
+ "config" argument.
+
+ Returns
+ -------
+ str
+ Name of the Codegen config dictionary key
+ to match with the expected name in the relay
+ implementation of this Codegen.
+ """
+
+ @staticmethod
+ @abstractmethod
+ def run_custom_passes(mod, params):
+ """Apply a set of transformations to the module
+ before compilation happens.
+
+ Parameters
+ ----------
+ mod : tvm.relay.Module
+ The relay module to convert.
+ params : dict
+ The parameters (weights) for the TVM module.
+
+ Returns
+ -------
+ mod: tvm.relay.Module
+ The converted module.
+ """
+
+
+@register_codegen(kind="ethos-n77")
+class TVMCEthosNCodegen(TVMCCodegen):
Review comment:
In fact, this is exactly why I would like to add other codegens: to make
sure this registration mechanism works for more than one or a few specific
codegens. If adding TensorRT would require changing how we register codegens,
the registry should be improved. Of course, it's also possible to keep the
current registry map and change the TensorRT partition function instead.
Either way, it would be better to at least figure out what needs to change as
part of this PR.
cc @trevor-m
##########
File path: python/tvm/driver/tvmc/composite_target.py
##########
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Provides support to composite target on TVMC.
+"""
+import logging
+
+from tvm.relay.op.contrib.arm_compute_lib import partition_for_arm_compute_lib
+from tvm.relay.op.contrib.ethosn import partition_for_ethosn
+
+from .common import TVMCException
+
+
+# pylint: disable=invalid-name
+logger = logging.getLogger("TVMC")
+
+# Global dictionary to map targets with the configuration key
+# to be used in the PassContext (if any), and a function
+# responsible for partitioning to that target.
+REGISTERED_TARGET = {
+ "acl": {
+ "config_key": None,
+ "pass_pipeline": partition_for_arm_compute_lib,
+ },
+ "ethos-n77": {
+ "config_key": "relay.ext.ethos-n.options",
+ "pass_pipeline": partition_for_ethosn,
+ },
+}
+
+
+def get_target_names():
+ """Return a list of all registered codegens.
+
+ Returns
+ --------
+ list of str
+ all registered targets
+ """
+ return list(REGISTERED_TARGET.keys())
+
+
+def get_target_by_name(name):
+ """Return a target entry by name.
+
+ Returns
+ --------
+ dict
+ requested target information
+ """
+ try:
+ return REGISTERED_TARGET[name]
+ except KeyError:
+ raise TVMCException("Target %s is not defined." % name)
Review comment:
```suggestion
raise TVMCException("Composite target %s is not defined in TVMC." % name)
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]