comaniac commented on a change in pull request #6302:
URL: https://github.com/apache/incubator-tvm/pull/6302#discussion_r479420731
##########
File path: python/tvm/driver/tvmc/common.py
##########
@@ -17,6 +17,74 @@
"""
Common utility functions shared by TVMC modules.
"""
+import argparse
+import re
+
+from tvm import relay
+from tvm import transform
class TVMCException(Exception):
"""TVMC Exception"""
+
+
+def convert_graph_layout(mod, desired_layout):
+ """Alter the layout of the input graph.
+
+ Parameters
+ ----------
+ mod : tvm.relay.Module
+ The relay module to convert.
+ desired_layout : str
+ The layout to convert to.
+
+ Returns
+ -------
+ mod : tvm.relay.Module
+ The converted module.
+ """
+
+ # Assume for the time being that graphs only have
+ # conv2d as heavily-sensitive operators.
+ desired_layouts = {
+ "nn.conv2d": [desired_layout, "default"],
+ "qnn.conv2d": [desired_layout, "default"],
+ }
+
+ # Convert the layout of the graph where possible.
+ seq = transform.Sequential(
+ [
+ relay.transform.RemoveUnusedFunctions(),
+ relay.transform.ConvertLayout(desired_layouts),
+ ]
+ )
+ with transform.PassContext(opt_level=3):
+ return seq(mod)
+
+
+def parse_input_shapes(shapes_str):
+ """ Parsing function for tensor shape syntax. """
+ shapes = []
+ # Split up string into comma seperated sections ignoring commas in ()s
+ match = re.findall(r"(\(.*?\)|.+?),?", shapes_str)
+ if match:
+ for inp in match:
+ # Test for and remove brackets
+ shape = re.match(r"\((.*)\)", inp)
+ if shape and shape.lastindex == 1:
+ # Remove white space and extract numbers
+ strshape = shape[1].replace(" ", "").split(",")
+ try:
+ shapes.append([int(i) for i in strshape])
+ except ValueError:
+ raise argparse.ArgumentTypeError(
+ f"expected numbers in shape '{shape[1]}'"
Review comment:
Per https://github.com/apache/incubator-tvm/pull/4250, we don't use
f-string for now.
##########
File path: python/tvm/driver/tvmc/compiler.py
##########
@@ -0,0 +1,305 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Provides support to compile networks both AOT and JIT.
+"""
+import logging
+import os.path
+import tarfile
+from pathlib import Path
+
+import tvm
+from tvm import autotvm
+from tvm import relay
+from tvm.contrib import cc
+from tvm.contrib import util
+
+from . import common, frontends
+from .main import register_parser
+
+
+@register_parser
+def add_compile_parser(subparsers):
+ """ Include parser for 'compile' subcommand """
+
+ parser = subparsers.add_parser("compile", help="compile a model")
+ parser.set_defaults(func=drive_compile)
+ parser.add_argument(
+ "--cross-compiler",
+ default="",
+ help="the cross compiler to generate target libraries, e.g.
'aarch64-linux-gnu-gcc'",
+ )
+ parser.add_argument(
+ "--dump-code",
+ metavar="FORMAT",
+ default="",
+ help="comma separarated list of formats to export, e.g. 'asm,ll,relay'
"
+ )
+ parser.add_argument(
+ "--model-format",
+ choices=frontends.get_frontends(),
+ help="specify input model format",
+ )
+ parser.add_argument(
+ "--input-shape",
+ type=common.parse_input_shapes,
+ metavar="INPUT_SHAPE,[INPUT_SHAPE]...",
+ help="for pytorch, e.g. '(1,3,224,224)'",
+ )
+ parser.add_argument(
+ "-o",
+ "--output",
+ default="module.tar",
+ help="output the compiled module to an archive",
+ )
+ parser.add_argument(
+ "--target",
+ help="compilation target as plain string, inline JSON or path to a
JSON file",
+ required=True
+ )
+ parser.add_argument(
+ "--tuning-records",
+ metavar="PATH",
+ default="",
+ help="path to an auto-tuning log file from AutoTVM"
+ )
+ parser.add_argument(
+ "--desired-layout",
+ choices=["NCHW", "NHWC"],
+ default=None,
+ help="change the data layout of the whole graph",
+ )
+ parser.add_argument(
+ "-v", "--verbose", action="count", default=0, help="increase verbosity"
+ )
+ parser.add_argument("FILE")
Review comment:
Add description to this argument.
Out of scope: Maybe we could consider adding a feature to pull a model from
Gluon CV modelzoo or torchvision in the future.
##########
File path: python/tvm/driver/tvmc/compiler.py
##########
@@ -0,0 +1,305 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Provides support to compile networks both AOT and JIT.
+"""
+import logging
+import os.path
+import tarfile
+from pathlib import Path
+
+import tvm
+from tvm import autotvm
+from tvm import relay
+from tvm.contrib import cc
+from tvm.contrib import util
+
+from . import common, frontends
+from .main import register_parser
+
+
+@register_parser
+def add_compile_parser(subparsers):
+ """ Include parser for 'compile' subcommand """
+
+ parser = subparsers.add_parser("compile", help="compile a model")
+ parser.set_defaults(func=drive_compile)
+ parser.add_argument(
+ "--cross-compiler",
+ default="",
+ help="the cross compiler to generate target libraries, e.g.
'aarch64-linux-gnu-gcc'",
+ )
+ parser.add_argument(
+ "--dump-code",
+ metavar="FORMAT",
+ default="",
+ help="comma separarated list of formats to export, e.g. 'asm,ll,relay'
"
+ )
+ parser.add_argument(
+ "--model-format",
+ choices=frontends.get_frontends(),
+ help="specify input model format",
+ )
+ parser.add_argument(
+ "--input-shape",
+ type=common.parse_input_shapes,
+ metavar="INPUT_SHAPE,[INPUT_SHAPE]...",
+ help="for pytorch, e.g. '(1,3,224,224)'",
+ )
+ parser.add_argument(
+ "-o",
+ "--output",
+ default="module.tar",
+ help="output the compiled module to an archive",
+ )
+ parser.add_argument(
+ "--target",
+ help="compilation target as plain string, inline JSON or path to a
JSON file",
+ required=True
+ )
+ parser.add_argument(
+ "--tuning-records",
+ metavar="PATH",
+ default="",
+ help="path to an auto-tuning log file from AutoTVM"
+ )
+ parser.add_argument(
+ "--desired-layout",
+ choices=["NCHW", "NHWC"],
+ default=None,
+ help="change the data layout of the whole graph",
+ )
+ parser.add_argument(
+ "-v", "--verbose", action="count", default=0, help="increase verbosity"
+ )
+ parser.add_argument("FILE")
+
+
+def drive_compile(args):
+ """ Invoke tvmc.compiler module with command line arguments """
+
+ graph, lib, params, dumps = compile_model(
+ args.FILE,
+ args.target,
+ args.dump_code,
+ "",
+ args.model_format,
+ args.input_shape,
+ args.tuning_records,
+ args.tensor_layout,
+ )
+
+ if dumps:
+ save_dumps(args.output, dumps)
+
+ save_module(args.output, graph, lib, params, args.cross_compiler)
+ return 0
+
+
+def compile_model(
+ path,
+ target,
+ dump_sources=None,
+ target_host=None,
+ model_format=None,
+ shapes=None,
+ tuning_records=None,
+ alter_layout=None,
+):
+ """Compile a model from a supported framework into a TVM module.
+
+ This function takes a union of the arguments of both frontends.load_model
+ and compiler.compile_relay. The resulting TVM module can be executed using
+ the graph runtime.
+
+ Returns
+ -------
+ graph : str
+ A JSON-serialized TVM execution graph.
+ lib : tvm.module.Module
+ A TVM module containing the compiled functions.
+ params : dict
+ The parameters (weights) for the TVM module.
+ dumps : dict
+ Dictionary containing the dumps specified.
+
+ """
+ dump_sources = [x.strip() for x in dump_sources.split(',')] if
dump_sources else None
+ mod, params = frontends.load_model(path, model_format, shapes)
+
+ return compile_relay(
+ mod,
+ params,
+ target,
+ dump_sources=dump_sources,
+ target_host=target_host,
+ tuning_records=tuning_records,
+ alter_layout=alter_layout,
+ )
+
+
+def compile_relay(
+ mod,
+ params,
+ target,
+ dump_sources=None,
+ target_host=None,
+ tuning_records=None,
+ alter_layout=None,
+):
+ """Compile a relay module to a TVM module for the graph runtime.
+
+ Parameters
+ ----------
+ mod : tvm.relay.Module
+ The relay module to compile.
+ params : dict
+ The parameters (weights) for the relay module.
+ target : str
+ The target for which to compile. Can be a plain string or
+ a path.
+ dump_sources : list, optional
+ Dump the generated code for the specified source types, on
+ the requested target.
+ target_host : Union[str, tvm.target.Target], optional
+ The target of the host machine if host-side code
+ needs to be generated.
+ tuning_records: str, optional
+ Name of the file produced by the tuning to be used during
+ compilation.
+ alter_layout: str, optional
+ The layout to convert the graph to. Note, the convert layout
+ pass doesn't currently guarantee the whole of the graph will
+ be converted to the chosen layout.
+
+ Returns
+ -------
+ graph : str
+ A JSON-serialized TVM execution graph.
+ lib : tvm.module.Module
+ A TVM module containing the compiled functions.
+ params : dict
+ The parameters (weights) for the TVM module.
+ dumps : dict
+ Dictionary containing the dumps specified.
+
+ """
+
+ if alter_layout:
+ mod = common.convert_graph_layout(mod, alter_layout)
+
+ if os.path.exists(str(target)):
+ with open(target) as target_file:
+ logging.info("using target input from file: %s", target)
+ target = "".join(target_file.readlines())
+
+ # TODO: We don't have an API to collect a list of supported
+ # targets yet. (@leandron)
+ logging.debug("creating target from input: %s", target)
+ tvm_target = tvm.target.create(target)
+ target_host = target_host or ""
+
+ if tuning_records:
+ logging.debug("tuning records file provided: %s", tuning_records)
+ with autotvm.apply_history_best(tuning_records):
+ with tvm.transform.PassContext(opt_level=3):
+ logging.debug("building relay graph with tuning records")
+ graph_module = relay.build(mod, tvm_target, params=params,
target_host=tvm_target)
Review comment:
We may need to provide an option to enable the graph tuner. If that is not
planned for this PR, please add a TODO.
##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -0,0 +1,389 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Provides support to parse models from different frameworks into Relay networks.
+
+Frontend classes do lazy-loading of modules on purpose, to reduce time spent on
+loading the tool.
+"""
+import logging
+import os
+import sys
+from abc import ABC
+from abc import abstractmethod
+from pathlib import Path
+
+from tvm.driver.tvmc.common import TVMCException
+
+
+class Frontend(ABC):
+ """Abstract class for frontend"""
+
+ @staticmethod
+ @abstractmethod
+ def name():
+ """Frontend name"""
+
+ @staticmethod
+ @abstractmethod
+ def suffixes():
+ """File suffixes (extensions) used by this frontend"""
+
+ @abstractmethod
+ def load(self, path, shapes):
+ """Load network"""
+
+
+def import_keras():
+ """ Lazy import function for Keras"""
+ # Keras writes the message "Using TensorFlow backend." to stderr
+ # Redirect stderr during the import to disable this
+ stderr = sys.stderr
+ sys.stderr = open(os.devnull, "w")
+ try:
+ # pylint: disable=C0415
+ import tensorflow as tf
+ from tensorflow import keras
+
+ return tf, keras
+ finally:
+ sys.stderr = stderr
+
+
+class KerasFrontend(Frontend):
+ """ Keras frontend for TVMC """
+
+ @staticmethod
+ def name():
+ return "keras"
+
+ @staticmethod
+ def suffixes():
+ return ["h5"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ import numpy as np
+ from tvm import relay
+
+ # pylint: disable=C0103
+ tf, keras = import_keras()
+
+ if shapes:
+ raise TVMCException(
+ "--input-shape is not supported for {}".format(self.name())
+ )
+
+ # tvm build currently imports keras directly instead of
tensorflow.keras
+ try:
+ model = keras.models.load_model(path)
+ except ValueError as err:
+ raise TVMCException(str(err))
+
+ # There are two flavours of keras model, sequential and
+ # functional, TVM expects a functional model, so convert
+ # if required:
+ if self.is_sequential_p(model):
+ model = self.sequential_to_functional(model)
+
+ in_shapes = []
+ for layer in model._input_layers:
+ if tf.executing_eagerly():
+ in_shapes.append(
+ tuple(dim if dim is not None else 1 for dim in
layer.input.shape)
+ )
+ else:
+ in_shapes.append(
+ tuple(
+ dim.value if dim.value is not None else 1
+ for dim in layer.input.shape
+ )
+ )
+
+ inputs = [
+ np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in
in_shapes
+ ]
+ shape_dict = {name: x.shape for (name, x) in zip(model.input_names,
inputs)}
+ return relay.frontend.from_keras(model, shape_dict, layout="NHWC")
+
+ def is_sequential_p(self, model):
+ _, keras = import_keras()
+ return isinstance(model, keras.models.Sequential)
+
+ def sequential_to_functional(self, model):
+ _, keras = import_keras()
+ assert self.is_sequential_p(model)
+ input_layer =
keras.layers.Input(batch_shape=model.layers[0].input_shape)
+ prev_layer = input_layer
+ for layer in model.layers:
+ prev_layer = layer(prev_layer)
+ model = keras.models.Model([input_layer], [prev_layer])
+ return model
+
+
+class OnnxFrontend(Frontend):
+ """ ONNX frontend for TVMC """
+
+ @staticmethod
+ def name():
+ return "onnx"
+
+ @staticmethod
+ def suffixes():
+ return ["onnx"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ import onnx
+ from tvm import relay
+
+ if shapes:
+ raise TVMCException(
+ "--input-shape is not supported for {}".format(self.name())
+ )
+
+ model = onnx.load(path)
+
+ # Find the name and shape of the first input in the graph
+
+ # pylint: disable=E1101
+ name = model.graph.input[0].name
+
+ # pylint: disable=E1101
+ proto_shape = model.graph.input[0].type.tensor_type.shape.dim
+ shape = [d.dim_value for d in proto_shape]
+
+ shape_dict = {name: shape}
+
+ return relay.frontend.from_onnx(model, shape_dict)
+
+
+class TensorflowFrontend(Frontend):
+ """ TensorFlow frontend for TVMC """
+
+ @staticmethod
+ def name():
+ return "pb"
+
+ @staticmethod
+ def suffixes():
+ return ["pb"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ from tvm import relay
+ import tensorflow as tf
+ import tvm.relay.testing.tf as tf_testing
+
+ if shapes:
+ raise TVMCException(
+ "--input-shape is not supported for {}".format(self.name())
+ )
+
+ with tf.io.gfile.GFile(path, "rb") as tf_graph:
+ content = tf_graph.read()
+
+ graph_def = tf.compat.v1.GraphDef()
+ graph_def.ParseFromString(content)
+ graph_def = tf_testing.ProcessGraphDefParam(graph_def)
+
+ logging.debug("relay.frontend.from_tensorflow")
+ return relay.frontend.from_tensorflow(graph_def)
+
+
+class TFLiteFrontend(Frontend):
+ """ TFLite frontend for TVMC """
+
+ _tflite_m = {
+ 0: "float32",
+ 1: "float16",
+ 2: "int32",
+ 3: "uint8",
+ 4: "int64",
+ 5: "string",
+ 6: "bool",
+ 7: "int16",
+ 8: "complex64",
+ 9: "int8",
+ }
+
+ @staticmethod
+ def name():
+ return "tflite"
+
+ @staticmethod
+ def suffixes():
+ return ["tflite"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ import tflite.Model as model
+ from tvm import relay
+
+ if shapes:
+ raise TVMCException(
+ "--input-shape is not supported for {}".format(self.name())
+ )
+
+ with open(path, "rb") as tf_graph:
+ content = tf_graph.read()
+
+ # tflite.Model.Model is tflite.Model in 1.14 and 2.1.0
+ try:
+ tflite_model = model.Model.GetRootAsModel(content, 0)
+ except AttributeError:
+ tflite_model = model.GetRootAsModel(content, 0)
+
+ try:
+ version = tflite_model.Version()
+ logging.debug("tflite version %s", version)
+ except Exception:
+ raise TVMCException("input file not tflite")
+
+ if version != 3:
+ raise TVMCException("input file not tflite version 3")
+
+ logging.debug("tflite_input_type")
+ shape_dict, dtype_dict = TFLiteFrontend._input_type(tflite_model)
+
+ # parse TFLite model and convert into Relay computation graph
+ logging.debug("relay.frontend.from_tflite")
+ mod, params = relay.frontend.from_tflite(
+ tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict
+ )
+ return mod, params
+
+ @staticmethod
+ def _decode_type(n):
+ return TFLiteFrontend._tflite_m[n]
+
+ @staticmethod
+ def _input_type(model):
+ subgraph_count = model.SubgraphsLength()
+ assert subgraph_count > 0
+ shape_dict = {}
+ dtype_dict = {}
+ for subgraph_index in range(subgraph_count):
+ subgraph = model.Subgraphs(subgraph_index)
+ inputs_count = subgraph.InputsLength()
+ assert inputs_count >= 1
+ for input_index in range(inputs_count):
+ input_ = subgraph.Inputs(input_index)
+ assert subgraph.TensorsLength() > input_
+ tensor = subgraph.Tensors(input_)
+ input_shape = tuple(tensor.ShapeAsNumpy())
+ tensor_type = tensor.Type()
+ input_name = tensor.Name().decode("utf8")
+ shape_dict[input_name] = input_shape
+ dtype_dict[input_name] =
TFLiteFrontend._decode_type(tensor_type)
+
+ return shape_dict, dtype_dict
+
+
+class PyTorchFrontend(Frontend):
+ """ PyTorch frontend for TVMC """
+
+ @staticmethod
+ def name():
+ return "pytorch"
+
+ @staticmethod
+ def suffixes():
+ # Torch Script is a zip file, but can be named pth
+ return ["pth", "zip"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ import torch
+ from tvm import relay
+
+ if not shapes:
+ raise TVMCException(
+ "--input-shape must be specified for {}".format(self.name())
+ )
+
+ traced_model = torch.jit.load(path)
+ traced_model.eval() # Switch to inference mode
+ input_shapes = [
+ ("input{}".format(idx), shape) for idx, shape in enumerate(shapes)
+ ]
+ logging.debug("relay.frontend.from_pytorch")
+ return relay.frontend.from_pytorch(traced_model, input_shapes)
+
+
+ALL_FRONTENDS = [
+ KerasFrontend,
+ OnnxFrontend,
+ TensorflowFrontend,
+ TFLiteFrontend,
+ PyTorchFrontend,
+]
+
+
+def get_frontends():
Review comment:
Since you also have `lookup_frontend`, this name might be confusing.
`get_frontend_names` would be more accurate.
##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -0,0 +1,389 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Provides support to parse models from different frameworks into Relay networks.
+
+Frontend classes do lazy-loading of modules on purpose, to reduce time spent on
+loading the tool.
+"""
+import logging
+import os
+import sys
+from abc import ABC
+from abc import abstractmethod
+from pathlib import Path
+
+from tvm.driver.tvmc.common import TVMCException
+
+
+class Frontend(ABC):
+ """Abstract class for frontend"""
+
+ @staticmethod
+ @abstractmethod
+ def name():
+ """Frontend name"""
+
+ @staticmethod
+ @abstractmethod
+ def suffixes():
+ """File suffixes (extensions) used by this frontend"""
+
+ @abstractmethod
+ def load(self, path, shapes):
+ """Load network"""
+
+
+def import_keras():
+ """ Lazy import function for Keras"""
+ # Keras writes the message "Using TensorFlow backend." to stderr
+ # Redirect stderr during the import to disable this
+ stderr = sys.stderr
+ sys.stderr = open(os.devnull, "w")
+ try:
+ # pylint: disable=C0415
+ import tensorflow as tf
+ from tensorflow import keras
+
+ return tf, keras
+ finally:
+ sys.stderr = stderr
+
+
+class KerasFrontend(Frontend):
+ """ Keras frontend for TVMC """
+
+ @staticmethod
+ def name():
+ return "keras"
+
+ @staticmethod
+ def suffixes():
+ return ["h5"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ import numpy as np
+ from tvm import relay
+
+ # pylint: disable=C0103
+ tf, keras = import_keras()
+
+ if shapes:
+ raise TVMCException(
+ "--input-shape is not supported for {}".format(self.name())
+ )
+
+ # tvm build currently imports keras directly instead of
tensorflow.keras
+ try:
+ model = keras.models.load_model(path)
+ except ValueError as err:
+ raise TVMCException(str(err))
+
+ # There are two flavours of keras model, sequential and
+ # functional, TVM expects a functional model, so convert
+ # if required:
+ if self.is_sequential_p(model):
+ model = self.sequential_to_functional(model)
+
+ in_shapes = []
+ for layer in model._input_layers:
+ if tf.executing_eagerly():
+ in_shapes.append(
+ tuple(dim if dim is not None else 1 for dim in
layer.input.shape)
+ )
+ else:
+ in_shapes.append(
+ tuple(
+ dim.value if dim.value is not None else 1
+ for dim in layer.input.shape
+ )
+ )
+
+ inputs = [
+ np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in
in_shapes
+ ]
+ shape_dict = {name: x.shape for (name, x) in zip(model.input_names,
inputs)}
+ return relay.frontend.from_keras(model, shape_dict, layout="NHWC")
+
+ def is_sequential_p(self, model):
+ _, keras = import_keras()
+ return isinstance(model, keras.models.Sequential)
+
+ def sequential_to_functional(self, model):
+ _, keras = import_keras()
+ assert self.is_sequential_p(model)
+ input_layer =
keras.layers.Input(batch_shape=model.layers[0].input_shape)
+ prev_layer = input_layer
+ for layer in model.layers:
+ prev_layer = layer(prev_layer)
+ model = keras.models.Model([input_layer], [prev_layer])
+ return model
+
+
+class OnnxFrontend(Frontend):
+ """ ONNX frontend for TVMC """
+
+ @staticmethod
+ def name():
+ return "onnx"
+
+ @staticmethod
+ def suffixes():
+ return ["onnx"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ import onnx
+ from tvm import relay
+
+ if shapes:
+ raise TVMCException(
+ "--input-shape is not supported for {}".format(self.name())
+ )
+
+ model = onnx.load(path)
+
+ # Find the name and shape of the first input in the graph
+
+ # pylint: disable=E1101
+ name = model.graph.input[0].name
+
+ # pylint: disable=E1101
+ proto_shape = model.graph.input[0].type.tensor_type.shape.dim
+ shape = [d.dim_value for d in proto_shape]
+
+ shape_dict = {name: shape}
+
+ return relay.frontend.from_onnx(model, shape_dict)
+
+
+class TensorflowFrontend(Frontend):
+ """ TensorFlow frontend for TVMC """
+
+ @staticmethod
+ def name():
+ return "pb"
+
+ @staticmethod
+ def suffixes():
+ return ["pb"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ from tvm import relay
+ import tensorflow as tf
+ import tvm.relay.testing.tf as tf_testing
+
+ if shapes:
+ raise TVMCException(
+ "--input-shape is not supported for {}".format(self.name())
+ )
+
+ with tf.io.gfile.GFile(path, "rb") as tf_graph:
+ content = tf_graph.read()
+
+ graph_def = tf.compat.v1.GraphDef()
+ graph_def.ParseFromString(content)
+ graph_def = tf_testing.ProcessGraphDefParam(graph_def)
+
+ logging.debug("relay.frontend.from_tensorflow")
+ return relay.frontend.from_tensorflow(graph_def)
+
+
+class TFLiteFrontend(Frontend):
+ """ TFLite frontend for TVMC """
+
+ _tflite_m = {
+ 0: "float32",
+ 1: "float16",
+ 2: "int32",
+ 3: "uint8",
+ 4: "int64",
+ 5: "string",
+ 6: "bool",
+ 7: "int16",
+ 8: "complex64",
+ 9: "int8",
+ }
+
+ @staticmethod
+ def name():
+ return "tflite"
+
+ @staticmethod
+ def suffixes():
+ return ["tflite"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ import tflite.Model as model
+ from tvm import relay
+
+ if shapes:
+ raise TVMCException(
+ "--input-shape is not supported for {}".format(self.name())
+ )
+
+ with open(path, "rb") as tf_graph:
+ content = tf_graph.read()
+
+ # tflite.Model.Model is tflite.Model in 1.14 and 2.1.0
+ try:
+ tflite_model = model.Model.GetRootAsModel(content, 0)
+ except AttributeError:
+ tflite_model = model.GetRootAsModel(content, 0)
+
+ try:
+ version = tflite_model.Version()
+ logging.debug("tflite version %s", version)
+ except Exception:
+ raise TVMCException("input file not tflite")
+
+ if version != 3:
+ raise TVMCException("input file not tflite version 3")
+
+ logging.debug("tflite_input_type")
+ shape_dict, dtype_dict = TFLiteFrontend._input_type(tflite_model)
+
+ # parse TFLite model and convert into Relay computation graph
+ logging.debug("relay.frontend.from_tflite")
+ mod, params = relay.frontend.from_tflite(
+ tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict
+ )
+ return mod, params
+
+ @staticmethod
+ def _decode_type(n):
+ return TFLiteFrontend._tflite_m[n]
+
+ @staticmethod
+ def _input_type(model):
+ subgraph_count = model.SubgraphsLength()
+ assert subgraph_count > 0
+ shape_dict = {}
+ dtype_dict = {}
+ for subgraph_index in range(subgraph_count):
+ subgraph = model.Subgraphs(subgraph_index)
+ inputs_count = subgraph.InputsLength()
+ assert inputs_count >= 1
+ for input_index in range(inputs_count):
+ input_ = subgraph.Inputs(input_index)
+ assert subgraph.TensorsLength() > input_
+ tensor = subgraph.Tensors(input_)
+ input_shape = tuple(tensor.ShapeAsNumpy())
+ tensor_type = tensor.Type()
+ input_name = tensor.Name().decode("utf8")
+ shape_dict[input_name] = input_shape
+ dtype_dict[input_name] =
TFLiteFrontend._decode_type(tensor_type)
+
+ return shape_dict, dtype_dict
+
+
+class PyTorchFrontend(Frontend):
+ """ PyTorch frontend for TVMC """
+
+ @staticmethod
+ def name():
+ return "pytorch"
+
+ @staticmethod
+ def suffixes():
+ # Torch Script is a zip file, but can be named pth
+ return ["pth", "zip"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ import torch
+ from tvm import relay
+
+ if not shapes:
+ raise TVMCException(
+ "--input-shape must be specified for {}".format(self.name())
+ )
+
+ traced_model = torch.jit.load(path)
+ traced_model.eval() # Switch to inference mode
+ input_shapes = [
+ ("input{}".format(idx), shape) for idx, shape in enumerate(shapes)
+ ]
+ logging.debug("relay.frontend.from_pytorch")
+ return relay.frontend.from_pytorch(traced_model, input_shapes)
+
+
+ALL_FRONTENDS = [
+ KerasFrontend,
+ OnnxFrontend,
+ TensorflowFrontend,
+ TFLiteFrontend,
+ PyTorchFrontend,
+]
+
+
+def get_frontends():
+ """Return the names of all supported frontends"""
+ return [frontend.name() for frontend in ALL_FRONTENDS]
+
+
+def lookup_frontend(name):
+ for frontend in ALL_FRONTENDS:
+ if name == frontend.name():
+ return frontend()
+ raise TVMCException("unrecognized frontend")
+
+
+def guess_input_language(path):
+ suffix = Path(path).suffix.lower()
+ if suffix.startswith("."):
+ suffix = suffix[1:]
+
+ for frontend in ALL_FRONTENDS:
+ if suffix in frontend.suffixes():
+ return frontend()
+
+ raise TVMCException("cannot guess input language")
+
+
+def load_model(path, language=None, shapes=None):
+ """Load a model from a supported framework and convert it
+ into an equivalent relay representation.
+
+ Parameters
+ ----------
Review comment:
shapes?
##########
File path: python/tvm/driver/tvmc/common.py
##########
@@ -17,6 +17,74 @@
"""
Common utility functions shared by TVMC modules.
"""
+import argparse
+import re
+
+from tvm import relay
+from tvm import transform
class TVMCException(Exception):
"""TVMC Exception"""
+
+
+def convert_graph_layout(mod, desired_layout):
Review comment:
ConvertLayout pass doesn't check if the input and desired layouts are
the same, so you need to make sure it won't be a problem. Please cover it in
the unit test.
##########
File path: tests/python/driver/tvmc/test_frontends.py
##########
@@ -0,0 +1,216 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import os
+import tarfile
+
+import pytest
+
+from tvm.ir.module import IRModule
+
+from tvm.driver import tvmc
+from tvm.driver.tvmc.common import TVMCException
+
+
+def test_get_frontends_is_list():
+ sut = tvmc.frontends.get_frontends()
+ assert type(sut) is list
+
+
+def test_get_frontends_contains_only_strings():
+ sut = tvmc.frontends.get_frontends()
+ assert all([type(x) is str for x in sut]) is True
+
+
+def test_lookup_frontend_valid():
+ sut = tvmc.frontends.lookup_frontend("keras")
+ assert type(sut) is tvmc.frontends.KerasFrontend
+
+
+def test_lookup_frontend_invalid():
+ with pytest.raises(TVMCException) as e:
+ def f():
+ tvmc.frontends.lookup_frontend("unsupported_thingy")
Review comment:
You meant "thing"?
##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -0,0 +1,389 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Provides support to parse models from different frameworks into Relay networks.
+
+Frontend classes do lazy-loading of modules on purpose, to reduce time spent on
+loading the tool.
+"""
+import logging
+import os
+import sys
+from abc import ABC
+from abc import abstractmethod
+from pathlib import Path
+
+from tvm.driver.tvmc.common import TVMCException
+
+
+class Frontend(ABC):
+ """Abstract class for frontend"""
+
+ @staticmethod
+ @abstractmethod
+ def name():
+ """Frontend name"""
+
+ @staticmethod
+ @abstractmethod
+ def suffixes():
+ """File suffixes (extensions) used by this frontend"""
+
+ @abstractmethod
+ def load(self, path, shapes):
+ """Load network"""
+
+
+def import_keras():
+ """ Lazy import function for Keras"""
+ # Keras writes the message "Using TensorFlow backend." to stderr
+ # Redirect stderr during the import to disable this
+ stderr = sys.stderr
+ sys.stderr = open(os.devnull, "w")
+ try:
+ # pylint: disable=C0415
+ import tensorflow as tf
+ from tensorflow import keras
+
+ return tf, keras
+ finally:
+ sys.stderr = stderr
+
+
+class KerasFrontend(Frontend):
+ """ Keras frontend for TVMC """
+
+ @staticmethod
+ def name():
+ return "keras"
+
+ @staticmethod
+ def suffixes():
+ return ["h5"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ import numpy as np
+ from tvm import relay
Review comment:
- This could be put at the top of this file since every frontend will
need it.
- You are now in the same package as TVM.
##########
File path: python/tvm/driver/tvmc/compiler.py
##########
@@ -0,0 +1,305 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Provides support to compile networks both AOT and JIT.
+"""
+import logging
+import os.path
+import tarfile
+from pathlib import Path
+
+import tvm
+from tvm import autotvm
+from tvm import relay
+from tvm.contrib import cc
+from tvm.contrib import util
+
+from . import common, frontends
+from .main import register_parser
+
+
+@register_parser
+def add_compile_parser(subparsers):
+ """ Include parser for 'compile' subcommand """
+
+ parser = subparsers.add_parser("compile", help="compile a model")
+ parser.set_defaults(func=drive_compile)
+ parser.add_argument(
+ "--cross-compiler",
+ default="",
+ help="the cross compiler to generate target libraries, e.g.
'aarch64-linux-gnu-gcc'",
+ )
+ parser.add_argument(
+ "--dump-code",
+ metavar="FORMAT",
+ default="",
+ help="comma separarated list of formats to export, e.g. 'asm,ll,relay'
"
+ )
+ parser.add_argument(
+ "--model-format",
+ choices=frontends.get_frontends(),
+ help="specify input model format",
+ )
+ parser.add_argument(
+ "--input-shape",
+ type=common.parse_input_shapes,
+ metavar="INPUT_SHAPE,[INPUT_SHAPE]...",
+ help="for pytorch, e.g. '(1,3,224,224)'",
+ )
+ parser.add_argument(
+ "-o",
+ "--output",
+ default="module.tar",
+ help="output the compiled module to an archive",
+ )
+ parser.add_argument(
+ "--target",
+ help="compilation target as plain string, inline JSON or path to a
JSON file",
+ required=True
+ )
+ parser.add_argument(
+ "--tuning-records",
+ metavar="PATH",
+ default="",
+ help="path to an auto-tuning log file from AutoTVM"
+ )
+ parser.add_argument(
+ "--desired-layout",
+ choices=["NCHW", "NHWC"],
+ default=None,
+ help="change the data layout of the whole graph",
+ )
+ parser.add_argument(
+ "-v", "--verbose", action="count", default=0, help="increase verbosity"
+ )
+ parser.add_argument("FILE")
+
+
+def drive_compile(args):
+ """ Invoke tvmc.compiler module with command line arguments """
+
+ graph, lib, params, dumps = compile_model(
+ args.FILE,
+ args.target,
+ args.dump_code,
+ "",
+ args.model_format,
+ args.input_shape,
+ args.tuning_records,
+ args.tensor_layout,
+ )
+
+ if dumps:
+ save_dumps(args.output, dumps)
+
+ save_module(args.output, graph, lib, params, args.cross_compiler)
+ return 0
+
+
+def compile_model(
+ path,
+ target,
+ dump_sources=None,
+ target_host=None,
+ model_format=None,
+ shapes=None,
+ tuning_records=None,
+ alter_layout=None,
+):
Review comment:
Looks like you can just merge `compile_relay` into `compile_model`.
##########
File path: python/tvm/driver/tvmc/compiler.py
##########
@@ -0,0 +1,305 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Provides support to compile networks both AOT and JIT.
+"""
+import logging
+import os.path
+import tarfile
+from pathlib import Path
+
+import tvm
+from tvm import autotvm
+from tvm import relay
+from tvm.contrib import cc
+from tvm.contrib import util
+
+from . import common, frontends
+from .main import register_parser
+
+
+@register_parser
+def add_compile_parser(subparsers):
+ """ Include parser for 'compile' subcommand """
+
+ parser = subparsers.add_parser("compile", help="compile a model")
+ parser.set_defaults(func=drive_compile)
+ parser.add_argument(
+ "--cross-compiler",
+ default="",
+ help="the cross compiler to generate target libraries, e.g.
'aarch64-linux-gnu-gcc'",
+ )
+ parser.add_argument(
+ "--dump-code",
+ metavar="FORMAT",
+ default="",
+ help="comma separarated list of formats to export, e.g. 'asm,ll,relay'
"
+ )
+ parser.add_argument(
+ "--model-format",
+ choices=frontends.get_frontends(),
+ help="specify input model format",
+ )
+ parser.add_argument(
+ "--input-shape",
+ type=common.parse_input_shapes,
+ metavar="INPUT_SHAPE,[INPUT_SHAPE]...",
+ help="for pytorch, e.g. '(1,3,224,224)'",
+ )
+ parser.add_argument(
+ "-o",
+ "--output",
+ default="module.tar",
+ help="output the compiled module to an archive",
+ )
+ parser.add_argument(
+ "--target",
+ help="compilation target as plain string, inline JSON or path to a
JSON file",
+ required=True
+ )
+ parser.add_argument(
+ "--tuning-records",
+ metavar="PATH",
+ default="",
+ help="path to an auto-tuning log file from AutoTVM"
+ )
+ parser.add_argument(
+ "--desired-layout",
+ choices=["NCHW", "NHWC"],
+ default=None,
+ help="change the data layout of the whole graph",
+ )
+ parser.add_argument(
+ "-v", "--verbose", action="count", default=0, help="increase verbosity"
+ )
+ parser.add_argument("FILE")
+
+
+def drive_compile(args):
+ """ Invoke tvmc.compiler module with command line arguments """
+
+ graph, lib, params, dumps = compile_model(
+ args.FILE,
+ args.target,
+ args.dump_code,
+ "",
+ args.model_format,
+ args.input_shape,
+ args.tuning_records,
+ args.tensor_layout,
+ )
+
+ if dumps:
+ save_dumps(args.output, dumps)
+
+ save_module(args.output, graph, lib, params, args.cross_compiler)
+ return 0
+
+
+def compile_model(
+ path,
+ target,
+ dump_sources=None,
+ target_host=None,
+ model_format=None,
+ shapes=None,
+ tuning_records=None,
+ alter_layout=None,
+):
+ """Compile a model from a supported framework into a TVM module.
+
+ This function takes a union of the arguments of both frontends.load_model
+ and compiler.compile_relay. The resulting TVM module can be executed using
+ the graph runtime.
+
+ Returns
+ -------
+ graph : str
+ A JSON-serialized TVM execution graph.
+ lib : tvm.module.Module
+ A TVM module containing the compiled functions.
+ params : dict
+ The parameters (weights) for the TVM module.
+ dumps : dict
+ Dictionary containing the dumps specified.
+
+ """
+ dump_sources = [x.strip() for x in dump_sources.split(',')] if
dump_sources else None
+ mod, params = frontends.load_model(path, model_format, shapes)
+
+ return compile_relay(
+ mod,
+ params,
+ target,
+ dump_sources=dump_sources,
+ target_host=target_host,
+ tuning_records=tuning_records,
+ alter_layout=alter_layout,
+ )
+
+
+def compile_relay(
+ mod,
+ params,
+ target,
+ dump_sources=None,
+ target_host=None,
+ tuning_records=None,
+ alter_layout=None,
+):
+ """Compile a relay module to a TVM module for the graph runtime.
+
+ Parameters
+ ----------
+ mod : tvm.relay.Module
+ The relay module to compile.
+ params : dict
+ The parameters (weights) for the relay module.
+ target : str
+ The target for which to compile. Can be a plain string or
+ a path.
+ dump_sources : list, optional
+ Dump the generated code for the specified source types, on
+ the requested target.
+ target_host : Union[str, tvm.target.Target], optional
+ The target of the host machine if host-side code
+ needs to be generated.
+ tuning_records: str, optional
+ Name of the file produced by the tuning to be used during
+ compilation.
+ alter_layout: str, optional
+ The layout to convert the graph to. Note, the convert layout
+ pass doesn't currently guarantee the whole of the graph will
+ be converted to the chosen layout.
+
+ Returns
+ -------
+ graph : str
+ A JSON-serialized TVM execution graph.
+ lib : tvm.module.Module
+ A TVM module containing the compiled functions.
+ params : dict
+ The parameters (weights) for the TVM module.
+ dumps : dict
+ Dictionary containing the dumps specified.
+
+ """
+
+ if alter_layout:
+ mod = common.convert_graph_layout(mod, alter_layout)
+
+ if os.path.exists(str(target)):
+ with open(target) as target_file:
+ logging.info("using target input from file: %s", target)
+ target = "".join(target_file.readlines())
+
+ # TODO: We don't have an API to collect a list of supported
+ # targets yet. (@leandron)
+ logging.debug("creating target from input: %s", target)
+ tvm_target = tvm.target.create(target)
+ target_host = target_host or ""
+
+ if tuning_records:
+ logging.debug("tuning records file provided: %s", tuning_records)
+ with autotvm.apply_history_best(tuning_records):
+ with tvm.transform.PassContext(opt_level=3):
+ logging.debug("building relay graph with tuning records")
+ graph_module = relay.build(mod, tvm_target, params=params,
target_host=tvm_target)
+ else:
+ with tvm.transform.PassContext(opt_level=3):
+ logging.debug("building relay graph (no tuning records provided)")
+ graph_module = relay.build(mod, tvm_target, params=params,
target_host=tvm_target)
+
+ # Generate output dump files with sources
+ dump_sources = dump_sources or []
+ dumps = {}
+ for source_type in dump_sources:
+ lib = graph_module.get_lib()
+ # TODO lib.get_source call have inconsistent behavior for unsupported
+ # formats (@leandron).
+ source = str(mod) if source_type == "relay" else
lib.get_source(source_type)
+ dumps[source_type] = source
+
+ return graph_module.get_json(), graph_module.get_lib(),
graph_module.get_params(), dumps
+
+
+def save_module(module_path, graph, lib, params, cross=None):
+ """
+ Create a tarball containing the generated TVM graph,
+ exported library and parameters
+
+ Parameters
+ ----------
+ module_path : str
+ path to the target tar.gz file to be created,
+ including the file name
+ graph : str
+ A JSON-serialized TVM execution graph.
+ lib : tvm.module.Module
+ A TVM module containing the compiled functions.
+ params : dict
+ The parameters (weights) for the TVM module.
+ cross : Union[str, Callable[[str, str, Optional[str]], None]]
Review comment:
Please be consistent in the type annotation style. For example, you used
numpy-style (e.g., `dict, optional`) in other functions like `compile_relay`.
##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -0,0 +1,389 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Provides support to parse models from different frameworks into Relay networks.
+
+Frontend classes do lazy-loading of modules on purpose, to reduce time spent on
+loading the tool.
+"""
+import logging
+import os
+import sys
+from abc import ABC
+from abc import abstractmethod
+from pathlib import Path
+
+from tvm.driver.tvmc.common import TVMCException
+
+
+class Frontend(ABC):
+ """Abstract class for frontend"""
+
+ @staticmethod
+ @abstractmethod
+ def name():
+ """Frontend name"""
+
+ @staticmethod
+ @abstractmethod
+ def suffixes():
+ """File suffixes (extensions) used by this frontend"""
+
+ @abstractmethod
+ def load(self, path, shapes):
+ """Load network"""
Review comment:
Please document the parameter and return types in the docstring.
##########
File path: python/tvm/driver/tvmc/compiler.py
##########
@@ -0,0 +1,305 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Provides support to compile networks both AOT and JIT.
+"""
+import logging
+import os.path
+import tarfile
+from pathlib import Path
+
+import tvm
+from tvm import autotvm
+from tvm import relay
+from tvm.contrib import cc
+from tvm.contrib import util
+
+from . import common, frontends
+from .main import register_parser
+
+
+@register_parser
+def add_compile_parser(subparsers):
+ """ Include parser for 'compile' subcommand """
+
+ parser = subparsers.add_parser("compile", help="compile a model")
+ parser.set_defaults(func=drive_compile)
+ parser.add_argument(
+ "--cross-compiler",
+ default="",
+ help="the cross compiler to generate target libraries, e.g.
'aarch64-linux-gnu-gcc'",
+ )
+ parser.add_argument(
+ "--dump-code",
+ metavar="FORMAT",
+ default="",
+ help="comma separarated list of formats to export, e.g. 'asm,ll,relay'
"
+ )
+ parser.add_argument(
+ "--model-format",
+ choices=frontends.get_frontends(),
+ help="specify input model format",
+ )
+ parser.add_argument(
+ "--input-shape",
+ type=common.parse_input_shapes,
+ metavar="INPUT_SHAPE,[INPUT_SHAPE]...",
+ help="for pytorch, e.g. '(1,3,224,224)'",
+ )
+ parser.add_argument(
+ "-o",
+ "--output",
+ default="module.tar",
+ help="output the compiled module to an archive",
+ )
+ parser.add_argument(
+ "--target",
+ help="compilation target as plain string, inline JSON or path to a
JSON file",
+ required=True
+ )
+ parser.add_argument(
+ "--tuning-records",
+ metavar="PATH",
+ default="",
+ help="path to an auto-tuning log file from AutoTVM"
+ )
+ parser.add_argument(
+ "--desired-layout",
+ choices=["NCHW", "NHWC"],
+ default=None,
+ help="change the data layout of the whole graph",
+ )
+ parser.add_argument(
+ "-v", "--verbose", action="count", default=0, help="increase verbosity"
+ )
+ parser.add_argument("FILE")
+
+
+def drive_compile(args):
+ """ Invoke tvmc.compiler module with command line arguments """
+
+ graph, lib, params, dumps = compile_model(
+ args.FILE,
+ args.target,
+ args.dump_code,
+ "",
+ args.model_format,
+ args.input_shape,
+ args.tuning_records,
+ args.tensor_layout,
+ )
+
+ if dumps:
+ save_dumps(args.output, dumps)
+
+ save_module(args.output, graph, lib, params, args.cross_compiler)
+ return 0
+
+
+def compile_model(
+ path,
+ target,
+ dump_sources=None,
+ target_host=None,
+ model_format=None,
+ shapes=None,
+ tuning_records=None,
+ alter_layout=None,
+):
+ """Compile a model from a supported framework into a TVM module.
+
+ This function takes a union of the arguments of both frontends.load_model
+ and compiler.compile_relay. The resulting TVM module can be executed using
+ the graph runtime.
+
+ Returns
+ -------
+ graph : str
+ A JSON-serialized TVM execution graph.
+ lib : tvm.module.Module
+ A TVM module containing the compiled functions.
+ params : dict
+ The parameters (weights) for the TVM module.
+ dumps : dict
+ Dictionary containing the dumps specified.
+
+ """
+ dump_sources = [x.strip() for x in dump_sources.split(',')] if
dump_sources else None
+ mod, params = frontends.load_model(path, model_format, shapes)
+
+ return compile_relay(
+ mod,
+ params,
+ target,
+ dump_sources=dump_sources,
+ target_host=target_host,
+ tuning_records=tuning_records,
+ alter_layout=alter_layout,
+ )
+
+
+def compile_relay(
+ mod,
+ params,
+ target,
+ dump_sources=None,
+ target_host=None,
+ tuning_records=None,
+ alter_layout=None,
+):
+ """Compile a relay module to a TVM module for the graph runtime.
+
+ Parameters
+ ----------
+ mod : tvm.relay.Module
+ The relay module to compile.
+ params : dict
+ The parameters (weights) for the relay module.
+ target : str
+ The target for which to compile. Can be a plain string or
+ a path.
+ dump_sources : list, optional
+ Dump the generated code for the specified source types, on
+ the requested target.
+ target_host : Union[str, tvm.target.Target], optional
+ The target of the host machine if host-side code
+ needs to be generated.
+ tuning_records: str, optional
+ Name of the file produced by the tuning to be used during
+ compilation.
+ alter_layout: str, optional
+ The layout to convert the graph to. Note, the convert layout
+ pass doesn't currently guarantee the whole of the graph will
+ be converted to the chosen layout.
+
+ Returns
+ -------
+ graph : str
+ A JSON-serialized TVM execution graph.
+ lib : tvm.module.Module
+ A TVM module containing the compiled functions.
+ params : dict
+ The parameters (weights) for the TVM module.
+ dumps : dict
+ Dictionary containing the dumps specified.
+
+ """
+
+ if alter_layout:
+ mod = common.convert_graph_layout(mod, alter_layout)
+
+ if os.path.exists(str(target)):
+ with open(target) as target_file:
+ logging.info("using target input from file: %s", target)
+ target = "".join(target_file.readlines())
+
+ # TODO: We don't have an API to collect a list of supported
+ # targets yet. (@leandron)
+ logging.debug("creating target from input: %s", target)
+ tvm_target = tvm.target.create(target)
+ target_host = target_host or ""
+
+ if tuning_records:
+ logging.debug("tuning records file provided: %s", tuning_records)
+ with autotvm.apply_history_best(tuning_records):
+ with tvm.transform.PassContext(opt_level=3):
+ logging.debug("building relay graph with tuning records")
+ graph_module = relay.build(mod, tvm_target, params=params,
target_host=tvm_target)
+ else:
+ with tvm.transform.PassContext(opt_level=3):
+ logging.debug("building relay graph (no tuning records provided)")
+ graph_module = relay.build(mod, tvm_target, params=params,
target_host=tvm_target)
+
+ # Generate output dump files with sources
+ dump_sources = dump_sources or []
+ dumps = {}
+ for source_type in dump_sources:
+ lib = graph_module.get_lib()
+ # TODO lib.get_source call have inconsistent behavior for unsupported
+ # formats (@leandron).
+ source = str(mod) if source_type == "relay" else
lib.get_source(source_type)
+ dumps[source_type] = source
+
+ return graph_module.get_json(), graph_module.get_lib(),
graph_module.get_params(), dumps
+
+
+def save_module(module_path, graph, lib, params, cross=None):
+ """
+ Create a tarball containing the generated TVM graph,
+ exported library and parameters
+
+ Parameters
+ ----------
+ module_path : str
+ path to the target tar.gz file to be created,
+ including the file name
+ graph : str
+ A JSON-serialized TVM execution graph.
+ lib : tvm.module.Module
+ A TVM module containing the compiled functions.
+ params : dict
+ The parameters (weights) for the TVM module.
+ cross : Union[str, Callable[[str, str, Optional[str]], None]]
+ Function that performs the actual compilation
+
+ """
+ lib_name = "mod.so"
+ graph_name = "mod.json"
+ param_name = "mod.params"
+ temp = util.tempdir()
+ path_lib = temp.relpath(lib_name)
+ if not cross:
+ logging.debug("exporting library to %s", path_lib)
+ lib.export_library(path_lib)
+ else:
+ logging.debug("exporting library to %s , using cross compiler %s",
path_lib, cross)
+ lib.export_library(path_lib, cc.cross_compiler(cross))
+
+ with open(temp.relpath(graph_name), "w") as graph_file:
+ logging.debug("writing graph to file to %s", graph_file.name)
+ graph_file.write(graph)
+
+ with open(temp.relpath(param_name), "wb") as params_file:
+ logging.debug("writing params to file to %s", params_file.name)
+ params_file.write(relay.save_param_dict(params))
+
+ logging.debug("saving module as tar file to %s", module_path)
+ with tarfile.open(module_path, "w") as tar:
+ tar.add(path_lib, lib_name)
+ tar.add(temp.relpath(graph_name), graph_name)
+ tar.add(temp.relpath(param_name), param_name)
+
+
+def save_dumps(module_name, dumps, dump_root="."):
+ """
+ Serialize dump files to the disk.
+
+ Parameters
+ ----------
+ module_name : list(Union[str, tvm.target.Target])
+ file name, referring to the module that generated
+ the dump contents
+ dumps : dict
Review comment:
```suggestion
dumps : Dict[?]
```
##########
File path: python/tvm/driver/tvmc/compiler.py
##########
@@ -0,0 +1,305 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Provides support to compile networks both AOT and JIT.
+"""
+import logging
+import os.path
+import tarfile
+from pathlib import Path
+
+import tvm
+from tvm import autotvm
+from tvm import relay
+from tvm.contrib import cc
+from tvm.contrib import util
+
+from . import common, frontends
+from .main import register_parser
+
+
+@register_parser
+def add_compile_parser(subparsers):
+ """ Include parser for 'compile' subcommand """
+
+ parser = subparsers.add_parser("compile", help="compile a model")
+ parser.set_defaults(func=drive_compile)
+ parser.add_argument(
+ "--cross-compiler",
+ default="",
+ help="the cross compiler to generate target libraries, e.g.
'aarch64-linux-gnu-gcc'",
+ )
+ parser.add_argument(
+ "--dump-code",
+ metavar="FORMAT",
+ default="",
+ help="comma separarated list of formats to export, e.g. 'asm,ll,relay'
"
+ )
+ parser.add_argument(
+ "--model-format",
+ choices=frontends.get_frontends(),
+ help="specify input model format",
+ )
+ parser.add_argument(
+ "--input-shape",
+ type=common.parse_input_shapes,
+ metavar="INPUT_SHAPE,[INPUT_SHAPE]...",
+ help="for pytorch, e.g. '(1,3,224,224)'",
+ )
+ parser.add_argument(
+ "-o",
+ "--output",
+ default="module.tar",
+ help="output the compiled module to an archive",
+ )
+ parser.add_argument(
+ "--target",
+ help="compilation target as plain string, inline JSON or path to a
JSON file",
+ required=True
+ )
+ parser.add_argument(
+ "--tuning-records",
+ metavar="PATH",
+ default="",
+ help="path to an auto-tuning log file from AutoTVM"
+ )
+ parser.add_argument(
+ "--desired-layout",
+ choices=["NCHW", "NHWC"],
+ default=None,
+ help="change the data layout of the whole graph",
+ )
+ parser.add_argument(
+ "-v", "--verbose", action="count", default=0, help="increase verbosity"
+ )
+ parser.add_argument("FILE")
+
+
+def drive_compile(args):
+ """ Invoke tvmc.compiler module with command line arguments """
+
+ graph, lib, params, dumps = compile_model(
+ args.FILE,
+ args.target,
+ args.dump_code,
+ "",
+ args.model_format,
+ args.input_shape,
+ args.tuning_records,
+ args.tensor_layout,
+ )
+
+ if dumps:
+ save_dumps(args.output, dumps)
+
+ save_module(args.output, graph, lib, params, args.cross_compiler)
+ return 0
+
+
+def compile_model(
+ path,
+ target,
+ dump_sources=None,
+ target_host=None,
+ model_format=None,
+ shapes=None,
+ tuning_records=None,
+ alter_layout=None,
+):
+ """Compile a model from a supported framework into a TVM module.
+
+ This function takes a union of the arguments of both frontends.load_model
+ and compiler.compile_relay. The resulting TVM module can be executed using
+ the graph runtime.
+
+ Returns
+ -------
+ graph : str
+ A JSON-serialized TVM execution graph.
+ lib : tvm.module.Module
+ A TVM module containing the compiled functions.
+ params : dict
+ The parameters (weights) for the TVM module.
+ dumps : dict
+ Dictionary containing the dumps specified.
+
+ """
+ dump_sources = [x.strip() for x in dump_sources.split(',')] if
dump_sources else None
+ mod, params = frontends.load_model(path, model_format, shapes)
+
+ return compile_relay(
+ mod,
+ params,
+ target,
+ dump_sources=dump_sources,
+ target_host=target_host,
+ tuning_records=tuning_records,
+ alter_layout=alter_layout,
+ )
+
+
+def compile_relay(
+ mod,
+ params,
+ target,
+ dump_sources=None,
+ target_host=None,
+ tuning_records=None,
+ alter_layout=None,
+):
+ """Compile a relay module to a TVM module for the graph runtime.
+
+ Parameters
+ ----------
+ mod : tvm.relay.Module
+ The relay module to compile.
+ params : dict
+ The parameters (weights) for the relay module.
+ target : str
+ The target for which to compile. Can be a plain string or
+ a path.
+ dump_sources : list, optional
+ Dump the generated code for the specified source types, on
+ the requested target.
+ target_host : Union[str, tvm.target.Target], optional
+ The target of the host machine if host-side code
+ needs to be generated.
+ tuning_records: str, optional
+ Name of the file produced by the tuning to be used during
+ compilation.
+ alter_layout: str, optional
+ The layout to convert the graph to. Note, the convert layout
+ pass doesn't currently guarantee the whole of the graph will
+ be converted to the chosen layout.
+
+ Returns
+ -------
+ graph : str
+ A JSON-serialized TVM execution graph.
+ lib : tvm.module.Module
+ A TVM module containing the compiled functions.
+ params : dict
+ The parameters (weights) for the TVM module.
+ dumps : dict
+ Dictionary containing the dumps specified.
+
+ """
+
+ if alter_layout:
+ mod = common.convert_graph_layout(mod, alter_layout)
+
+ if os.path.exists(str(target)):
+ with open(target) as target_file:
+ logging.info("using target input from file: %s", target)
+ target = "".join(target_file.readlines())
+
+ # TODO: We don't have an API to collect a list of supported
+ # targets yet. (@leandron)
+ logging.debug("creating target from input: %s", target)
+ tvm_target = tvm.target.create(target)
+ target_host = target_host or ""
+
+ if tuning_records:
+ logging.debug("tuning records file provided: %s", tuning_records)
+ with autotvm.apply_history_best(tuning_records):
+ with tvm.transform.PassContext(opt_level=3):
+ logging.debug("building relay graph with tuning records")
+ graph_module = relay.build(mod, tvm_target, params=params,
target_host=tvm_target)
+ else:
+ with tvm.transform.PassContext(opt_level=3):
+ logging.debug("building relay graph (no tuning records provided)")
+ graph_module = relay.build(mod, tvm_target, params=params,
target_host=tvm_target)
+
+ # Generate output dump files with sources
+ dump_sources = dump_sources or []
+ dumps = {}
+ for source_type in dump_sources:
+ lib = graph_module.get_lib()
+ # TODO lib.get_source call have inconsistent behavior for unsupported
+ # formats (@leandron).
+ source = str(mod) if source_type == "relay" else
lib.get_source(source_type)
+ dumps[source_type] = source
+
+ return graph_module.get_json(), graph_module.get_lib(),
graph_module.get_params(), dumps
+
+
+def save_module(module_path, graph, lib, params, cross=None):
+ """
+ Create a tarball containing the generated TVM graph,
+ exported library and parameters
+
+ Parameters
+ ----------
+ module_path : str
+ path to the target tar.gz file to be created,
+ including the file name
+ graph : str
+ A JSON-serialized TVM execution graph.
+ lib : tvm.module.Module
+ A TVM module containing the compiled functions.
+ params : dict
+ The parameters (weights) for the TVM module.
+ cross : Union[str, Callable[[str, str, Optional[str]], None]]
+ Function that performs the actual compilation
+
+ """
+ lib_name = "mod.so"
+ graph_name = "mod.json"
+ param_name = "mod.params"
+ temp = util.tempdir()
+ path_lib = temp.relpath(lib_name)
+ if not cross:
+ logging.debug("exporting library to %s", path_lib)
+ lib.export_library(path_lib)
+ else:
+ logging.debug("exporting library to %s , using cross compiler %s",
path_lib, cross)
+ lib.export_library(path_lib, cc.cross_compiler(cross))
+
+ with open(temp.relpath(graph_name), "w") as graph_file:
+ logging.debug("writing graph to file to %s", graph_file.name)
+ graph_file.write(graph)
+
+ with open(temp.relpath(param_name), "wb") as params_file:
+ logging.debug("writing params to file to %s", params_file.name)
+ params_file.write(relay.save_param_dict(params))
+
+ logging.debug("saving module as tar file to %s", module_path)
+ with tarfile.open(module_path, "w") as tar:
+ tar.add(path_lib, lib_name)
+ tar.add(temp.relpath(graph_name), graph_name)
+ tar.add(temp.relpath(param_name), param_name)
+
+
+def save_dumps(module_name, dumps, dump_root="."):
+ """
+ Serialize dump files to the disk.
+
+ Parameters
+ ----------
+ module_name : list(Union[str, tvm.target.Target])
Review comment:
```suggestion
module_name : List[Union[str, tvm.target.Target]]
```
##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -0,0 +1,389 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Provides support to parse models from different frameworks into Relay networks.
+
+Frontend classes do lazy-loading of modules on purpose, to reduce time spent on
+loading the tool.
+"""
+import logging
+import os
+import sys
+from abc import ABC
+from abc import abstractmethod
+from pathlib import Path
+
+from tvm.driver.tvmc.common import TVMCException
+
+
+class Frontend(ABC):
+ """Abstract class for frontend"""
+
+ @staticmethod
+ @abstractmethod
+ def name():
+ """Frontend name"""
+
+ @staticmethod
+ @abstractmethod
+ def suffixes():
+ """File suffixes (extensions) used by this frontend"""
+
+ @abstractmethod
+ def load(self, path, shapes):
+ """Load network"""
+
+
+def import_keras():
+ """ Lazy import function for Keras"""
+ # Keras writes the message "Using TensorFlow backend." to stderr
+ # Redirect stderr during the import to disable this
+ stderr = sys.stderr
+ sys.stderr = open(os.devnull, "w")
+ try:
+ # pylint: disable=C0415
+ import tensorflow as tf
+ from tensorflow import keras
+
+ return tf, keras
+ finally:
+ sys.stderr = stderr
+
+
+class KerasFrontend(Frontend):
+ """ Keras frontend for TVMC """
+
+ @staticmethod
+ def name():
+ return "keras"
+
+ @staticmethod
+ def suffixes():
+ return ["h5"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ import numpy as np
+ from tvm import relay
+
+ # pylint: disable=C0103
+ tf, keras = import_keras()
+
+ if shapes:
+ raise TVMCException(
+ "--input-shape is not supported for {}".format(self.name())
+ )
+
+ # tvm build currently imports keras directly instead of
tensorflow.keras
+ try:
+ model = keras.models.load_model(path)
+ except ValueError as err:
+ raise TVMCException(str(err))
+
+ # There are two flavours of keras model, sequential and
+ # functional, TVM expects a functional model, so convert
+ # if required:
+ if self.is_sequential_p(model):
+ model = self.sequential_to_functional(model)
+
+ in_shapes = []
+ for layer in model._input_layers:
+ if tf.executing_eagerly():
+ in_shapes.append(
+ tuple(dim if dim is not None else 1 for dim in
layer.input.shape)
+ )
+ else:
+ in_shapes.append(
+ tuple(
+ dim.value if dim.value is not None else 1
+ for dim in layer.input.shape
+ )
+ )
+
+ inputs = [
+ np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in
in_shapes
+ ]
+ shape_dict = {name: x.shape for (name, x) in zip(model.input_names,
inputs)}
+ return relay.frontend.from_keras(model, shape_dict, layout="NHWC")
+
+ def is_sequential_p(self, model):
+ _, keras = import_keras()
+ return isinstance(model, keras.models.Sequential)
+
+ def sequential_to_functional(self, model):
+ _, keras = import_keras()
+ assert self.is_sequential_p(model)
+ input_layer =
keras.layers.Input(batch_shape=model.layers[0].input_shape)
+ prev_layer = input_layer
+ for layer in model.layers:
+ prev_layer = layer(prev_layer)
+ model = keras.models.Model([input_layer], [prev_layer])
+ return model
+
+
+class OnnxFrontend(Frontend):
+ """ ONNX frontend for TVMC """
+
+ @staticmethod
+ def name():
+ return "onnx"
+
+ @staticmethod
+ def suffixes():
+ return ["onnx"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ import onnx
+ from tvm import relay
+
+ if shapes:
+ raise TVMCException(
+ "--input-shape is not supported for {}".format(self.name())
+ )
+
+ model = onnx.load(path)
+
+ # Find the name and shape of the first input in the graph
+
Review comment:
Remove this line.
##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -0,0 +1,389 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Provides support to parse models from different frameworks into Relay networks.
+
+Frontend classes do lazy-loading of modules on purpose, to reduce time spent on
+loading the tool.
+"""
+import logging
+import os
+import sys
+from abc import ABC
+from abc import abstractmethod
+from pathlib import Path
+
+from tvm.driver.tvmc.common import TVMCException
+
+
+class Frontend(ABC):
+ """Abstract class for frontend"""
+
+ @staticmethod
+ @abstractmethod
+ def name():
+ """Frontend name"""
+
+ @staticmethod
+ @abstractmethod
+ def suffixes():
+ """File suffixes (extensions) used by this frontend"""
+
+ @abstractmethod
+ def load(self, path, shapes):
+ """Load network"""
+
+
+def import_keras():
+ """ Lazy import function for Keras"""
+ # Keras writes the message "Using TensorFlow backend." to stderr
+ # Redirect stderr during the import to disable this
+ stderr = sys.stderr
+ sys.stderr = open(os.devnull, "w")
+ try:
+ # pylint: disable=C0415
+ import tensorflow as tf
+ from tensorflow import keras
+
+ return tf, keras
+ finally:
+ sys.stderr = stderr
+
+
+class KerasFrontend(Frontend):
+ """ Keras frontend for TVMC """
+
+ @staticmethod
+ def name():
+ return "keras"
+
+ @staticmethod
+ def suffixes():
+ return ["h5"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ import numpy as np
+ from tvm import relay
+
+ # pylint: disable=C0103
+ tf, keras = import_keras()
+
+ if shapes:
+ raise TVMCException(
+ "--input-shape is not supported for {}".format(self.name())
+ )
+
+ # tvm build currently imports keras directly instead of
tensorflow.keras
+ try:
+ model = keras.models.load_model(path)
+ except ValueError as err:
+ raise TVMCException(str(err))
+
+ # There are two flavours of keras model, sequential and
+ # functional, TVM expects a functional model, so convert
+ # if required:
+ if self.is_sequential_p(model):
+ model = self.sequential_to_functional(model)
+
+ in_shapes = []
+ for layer in model._input_layers:
+ if tf.executing_eagerly():
+ in_shapes.append(
+ tuple(dim if dim is not None else 1 for dim in
layer.input.shape)
+ )
+ else:
+ in_shapes.append(
+ tuple(
+ dim.value if dim.value is not None else 1
+ for dim in layer.input.shape
+ )
+ )
+
+ inputs = [
+ np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in
in_shapes
+ ]
+ shape_dict = {name: x.shape for (name, x) in zip(model.input_names,
inputs)}
+ return relay.frontend.from_keras(model, shape_dict, layout="NHWC")
+
+ def is_sequential_p(self, model):
+ _, keras = import_keras()
+ return isinstance(model, keras.models.Sequential)
+
+ def sequential_to_functional(self, model):
+ _, keras = import_keras()
+ assert self.is_sequential_p(model)
+ input_layer =
keras.layers.Input(batch_shape=model.layers[0].input_shape)
+ prev_layer = input_layer
+ for layer in model.layers:
+ prev_layer = layer(prev_layer)
+ model = keras.models.Model([input_layer], [prev_layer])
+ return model
+
+
+class OnnxFrontend(Frontend):
+ """ ONNX frontend for TVMC """
+
+ @staticmethod
+ def name():
+ return "onnx"
+
+ @staticmethod
+ def suffixes():
+ return ["onnx"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ import onnx
+ from tvm import relay
+
+ if shapes:
+ raise TVMCException(
+ "--input-shape is not supported for {}".format(self.name())
+ )
+
+ model = onnx.load(path)
+
+ # Find the name and shape of the first input in the graph
+
+ # pylint: disable=E1101
+ name = model.graph.input[0].name
+
+ # pylint: disable=E1101
+ proto_shape = model.graph.input[0].type.tensor_type.shape.dim
+ shape = [d.dim_value for d in proto_shape]
+
+ shape_dict = {name: shape}
+
+ return relay.frontend.from_onnx(model, shape_dict)
+
+
+class TensorflowFrontend(Frontend):
+ """ TensorFlow frontend for TVMC """
+
+ @staticmethod
+ def name():
+ return "pb"
+
+ @staticmethod
+ def suffixes():
+ return ["pb"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ from tvm import relay
+ import tensorflow as tf
+ import tvm.relay.testing.tf as tf_testing
+
+ if shapes:
+ raise TVMCException(
+ "--input-shape is not supported for {}".format(self.name())
+ )
+
+ with tf.io.gfile.GFile(path, "rb") as tf_graph:
+ content = tf_graph.read()
+
+ graph_def = tf.compat.v1.GraphDef()
+ graph_def.ParseFromString(content)
+ graph_def = tf_testing.ProcessGraphDefParam(graph_def)
+
+ logging.debug("relay.frontend.from_tensorflow")
+ return relay.frontend.from_tensorflow(graph_def)
+
+
+class TFLiteFrontend(Frontend):
+ """ TFLite frontend for TVMC """
+
+ _tflite_m = {
+ 0: "float32",
+ 1: "float16",
+ 2: "int32",
+ 3: "uint8",
+ 4: "int64",
+ 5: "string",
+ 6: "bool",
+ 7: "int16",
+ 8: "complex64",
+ 9: "int8",
+ }
+
+ @staticmethod
+ def name():
+ return "tflite"
+
+ @staticmethod
+ def suffixes():
+ return ["tflite"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ import tflite.Model as model
+ from tvm import relay
+
+ if shapes:
+ raise TVMCException(
+ "--input-shape is not supported for {}".format(self.name())
+ )
+
+ with open(path, "rb") as tf_graph:
+ content = tf_graph.read()
+
+ # tflite.Model.Model is tflite.Model in 1.14 and 2.1.0
+ try:
+ tflite_model = model.Model.GetRootAsModel(content, 0)
+ except AttributeError:
+ tflite_model = model.GetRootAsModel(content, 0)
+
+ try:
+ version = tflite_model.Version()
+ logging.debug("tflite version %s", version)
+ except Exception:
+ raise TVMCException("input file not tflite")
+
+ if version != 3:
+ raise TVMCException("input file not tflite version 3")
+
+ logging.debug("tflite_input_type")
+ shape_dict, dtype_dict = TFLiteFrontend._input_type(tflite_model)
+
+ # parse TFLite model and convert into Relay computation graph
+ logging.debug("relay.frontend.from_tflite")
+ mod, params = relay.frontend.from_tflite(
+ tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict
+ )
+ return mod, params
+
+ @staticmethod
+ def _decode_type(n):
+ return TFLiteFrontend._tflite_m[n]
+
+ @staticmethod
+ def _input_type(model):
+ subgraph_count = model.SubgraphsLength()
+ assert subgraph_count > 0
+ shape_dict = {}
+ dtype_dict = {}
+ for subgraph_index in range(subgraph_count):
+ subgraph = model.Subgraphs(subgraph_index)
+ inputs_count = subgraph.InputsLength()
+ assert inputs_count >= 1
+ for input_index in range(inputs_count):
+ input_ = subgraph.Inputs(input_index)
+ assert subgraph.TensorsLength() > input_
+ tensor = subgraph.Tensors(input_)
+ input_shape = tuple(tensor.ShapeAsNumpy())
+ tensor_type = tensor.Type()
+ input_name = tensor.Name().decode("utf8")
+ shape_dict[input_name] = input_shape
+ dtype_dict[input_name] =
TFLiteFrontend._decode_type(tensor_type)
+
+ return shape_dict, dtype_dict
+
+
+class PyTorchFrontend(Frontend):
+ """ PyTorch frontend for TVMC """
+
+ @staticmethod
+ def name():
+ return "pytorch"
+
+ @staticmethod
+ def suffixes():
+ # Torch Script is a zip file, but can be named pth
+ return ["pth", "zip"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ import torch
+ from tvm import relay
+
+ if not shapes:
+ raise TVMCException(
+ "--input-shape must be specified for {}".format(self.name())
+ )
+
+ traced_model = torch.jit.load(path)
+ traced_model.eval() # Switch to inference mode
+ input_shapes = [
+ ("input{}".format(idx), shape) for idx, shape in enumerate(shapes)
+ ]
+ logging.debug("relay.frontend.from_pytorch")
+ return relay.frontend.from_pytorch(traced_model, input_shapes)
+
+
+ALL_FRONTENDS = [
+ KerasFrontend,
+ OnnxFrontend,
+ TensorflowFrontend,
+ TFLiteFrontend,
+ PyTorchFrontend,
+]
+
+
+def get_frontends():
+ """Return the names of all supported frontends"""
+ return [frontend.name() for frontend in ALL_FRONTENDS]
+
+
+def lookup_frontend(name):
+ for frontend in ALL_FRONTENDS:
+ if name == frontend.name():
+ return frontend()
+ raise TVMCException("unrecognized frontend")
+
+
+def guess_input_language(path):
Review comment:
- We are not using "language" anymore.
- By its name, I don't expect this function to return a frontend but
just a frontend name; otherwise you may need to rename it to something like
`guess_frontend`.
##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -0,0 +1,389 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Provides support to parse models from different frameworks into Relay networks.
+
+Frontend classes do lazy-loading of modules on purpose, to reduce time spent on
+loading the tool.
+"""
+import logging
+import os
+import sys
+from abc import ABC
+from abc import abstractmethod
+from pathlib import Path
+
+from tvm.driver.tvmc.common import TVMCException
+
+
+class Frontend(ABC):
+ """Abstract class for frontend"""
+
+ @staticmethod
+ @abstractmethod
+ def name():
+ """Frontend name"""
+
+ @staticmethod
+ @abstractmethod
+ def suffixes():
+ """File suffixes (extensions) used by this frontend"""
+
+ @abstractmethod
+ def load(self, path, shapes):
+ """Load network"""
+
+
+def import_keras():
+ """ Lazy import function for Keras"""
+ # Keras writes the message "Using TensorFlow backend." to stderr
+ # Redirect stderr during the import to disable this
+ stderr = sys.stderr
+ sys.stderr = open(os.devnull, "w")
+ try:
+ # pylint: disable=C0415
+ import tensorflow as tf
+ from tensorflow import keras
+
+ return tf, keras
+ finally:
+ sys.stderr = stderr
+
+
+class KerasFrontend(Frontend):
+ """ Keras frontend for TVMC """
+
+ @staticmethod
+ def name():
+ return "keras"
+
+ @staticmethod
+ def suffixes():
+ return ["h5"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ import numpy as np
+ from tvm import relay
+
+ # pylint: disable=C0103
+ tf, keras = import_keras()
+
+ if shapes:
+ raise TVMCException(
+ "--input-shape is not supported for {}".format(self.name())
+ )
+
+ # tvm build currently imports keras directly instead of
tensorflow.keras
+ try:
+ model = keras.models.load_model(path)
+ except ValueError as err:
+ raise TVMCException(str(err))
+
+ # There are two flavours of keras model, sequential and
+ # functional, TVM expects a functional model, so convert
+ # if required:
+ if self.is_sequential_p(model):
+ model = self.sequential_to_functional(model)
+
+ in_shapes = []
+ for layer in model._input_layers:
+ if tf.executing_eagerly():
+ in_shapes.append(
+ tuple(dim if dim is not None else 1 for dim in
layer.input.shape)
+ )
+ else:
+ in_shapes.append(
+ tuple(
+ dim.value if dim.value is not None else 1
+ for dim in layer.input.shape
+ )
+ )
+
+ inputs = [
+ np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in
in_shapes
+ ]
+ shape_dict = {name: x.shape for (name, x) in zip(model.input_names,
inputs)}
+ return relay.frontend.from_keras(model, shape_dict, layout="NHWC")
+
+ def is_sequential_p(self, model):
+ _, keras = import_keras()
+ return isinstance(model, keras.models.Sequential)
+
+ def sequential_to_functional(self, model):
+ _, keras = import_keras()
+ assert self.is_sequential_p(model)
+ input_layer =
keras.layers.Input(batch_shape=model.layers[0].input_shape)
+ prev_layer = input_layer
+ for layer in model.layers:
+ prev_layer = layer(prev_layer)
+ model = keras.models.Model([input_layer], [prev_layer])
+ return model
+
+
+class OnnxFrontend(Frontend):
+ """ ONNX frontend for TVMC """
+
+ @staticmethod
+ def name():
+ return "onnx"
+
+ @staticmethod
+ def suffixes():
+ return ["onnx"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ import onnx
+ from tvm import relay
+
+ if shapes:
+ raise TVMCException(
+ "--input-shape is not supported for {}".format(self.name())
+ )
+
+ model = onnx.load(path)
+
+ # Find the name and shape of the first input in the graph
+
+ # pylint: disable=E1101
+ name = model.graph.input[0].name
+
+ # pylint: disable=E1101
+ proto_shape = model.graph.input[0].type.tensor_type.shape.dim
+ shape = [d.dim_value for d in proto_shape]
+
+ shape_dict = {name: shape}
+
+ return relay.frontend.from_onnx(model, shape_dict)
+
+
+class TensorflowFrontend(Frontend):
+ """ TensorFlow frontend for TVMC """
+
+ @staticmethod
+ def name():
+ return "pb"
+
+ @staticmethod
+ def suffixes():
+ return ["pb"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ from tvm import relay
+ import tensorflow as tf
+ import tvm.relay.testing.tf as tf_testing
+
+ if shapes:
+ raise TVMCException(
+ "--input-shape is not supported for {}".format(self.name())
+ )
+
+ with tf.io.gfile.GFile(path, "rb") as tf_graph:
+ content = tf_graph.read()
+
+ graph_def = tf.compat.v1.GraphDef()
+ graph_def.ParseFromString(content)
+ graph_def = tf_testing.ProcessGraphDefParam(graph_def)
+
+ logging.debug("relay.frontend.from_tensorflow")
+ return relay.frontend.from_tensorflow(graph_def)
+
+
+class TFLiteFrontend(Frontend):
+ """ TFLite frontend for TVMC """
+
+ _tflite_m = {
+ 0: "float32",
+ 1: "float16",
+ 2: "int32",
+ 3: "uint8",
+ 4: "int64",
+ 5: "string",
+ 6: "bool",
+ 7: "int16",
+ 8: "complex64",
+ 9: "int8",
+ }
+
+ @staticmethod
+ def name():
+ return "tflite"
+
+ @staticmethod
+ def suffixes():
+ return ["tflite"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ import tflite.Model as model
+ from tvm import relay
+
+ if shapes:
+ raise TVMCException(
+ "--input-shape is not supported for {}".format(self.name())
+ )
+
+ with open(path, "rb") as tf_graph:
+ content = tf_graph.read()
+
+ # tflite.Model.Model is tflite.Model in 1.14 and 2.1.0
+ try:
+ tflite_model = model.Model.GetRootAsModel(content, 0)
+ except AttributeError:
+ tflite_model = model.GetRootAsModel(content, 0)
+
+ try:
+ version = tflite_model.Version()
+ logging.debug("tflite version %s", version)
+ except Exception:
+ raise TVMCException("input file not tflite")
+
+ if version != 3:
+ raise TVMCException("input file not tflite version 3")
+
+ logging.debug("tflite_input_type")
+ shape_dict, dtype_dict = TFLiteFrontend._input_type(tflite_model)
+
+ # parse TFLite model and convert into Relay computation graph
+ logging.debug("relay.frontend.from_tflite")
+ mod, params = relay.frontend.from_tflite(
+ tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict
+ )
+ return mod, params
+
+ @staticmethod
+ def _decode_type(n):
+ return TFLiteFrontend._tflite_m[n]
+
+ @staticmethod
+ def _input_type(model):
+ subgraph_count = model.SubgraphsLength()
+ assert subgraph_count > 0
+ shape_dict = {}
+ dtype_dict = {}
+ for subgraph_index in range(subgraph_count):
+ subgraph = model.Subgraphs(subgraph_index)
+ inputs_count = subgraph.InputsLength()
+ assert inputs_count >= 1
+ for input_index in range(inputs_count):
+ input_ = subgraph.Inputs(input_index)
+ assert subgraph.TensorsLength() > input_
+ tensor = subgraph.Tensors(input_)
+ input_shape = tuple(tensor.ShapeAsNumpy())
+ tensor_type = tensor.Type()
+ input_name = tensor.Name().decode("utf8")
+ shape_dict[input_name] = input_shape
+ dtype_dict[input_name] =
TFLiteFrontend._decode_type(tensor_type)
+
+ return shape_dict, dtype_dict
+
+
+class PyTorchFrontend(Frontend):
+ """ PyTorch frontend for TVMC """
+
+ @staticmethod
+ def name():
+ return "pytorch"
+
+ @staticmethod
+ def suffixes():
+ # Torch Script is a zip file, but can be named pth
+ return ["pth", "zip"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ import torch
+ from tvm import relay
+
+ if not shapes:
+ raise TVMCException(
+ "--input-shape must be specified for {}".format(self.name())
+ )
+
+ traced_model = torch.jit.load(path)
+ traced_model.eval() # Switch to inference mode
+ input_shapes = [
+ ("input{}".format(idx), shape) for idx, shape in enumerate(shapes)
+ ]
+ logging.debug("relay.frontend.from_pytorch")
+ return relay.frontend.from_pytorch(traced_model, input_shapes)
+
+
+ALL_FRONTENDS = [
+ KerasFrontend,
+ OnnxFrontend,
+ TensorflowFrontend,
+ TFLiteFrontend,
+ PyTorchFrontend,
+]
+
+
+def get_frontends():
+ """Return the names of all supported frontends"""
+ return [frontend.name() for frontend in ALL_FRONTENDS]
+
+
+def lookup_frontend(name):
Review comment:
`get_frontend_by_name`?
##########
File path: python/tvm/driver/tvmc/frontends.py
##########
@@ -0,0 +1,389 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Provides support to parse models from different frameworks into Relay networks.
+
+Frontend classes do lazy-loading of modules on purpose, to reduce time spent on
+loading the tool.
+"""
+import logging
+import os
+import sys
+from abc import ABC
+from abc import abstractmethod
+from pathlib import Path
+
+from tvm.driver.tvmc.common import TVMCException
+
+
+class Frontend(ABC):
+ """Abstract class for frontend"""
+
+ @staticmethod
+ @abstractmethod
+ def name():
+ """Frontend name"""
+
+ @staticmethod
+ @abstractmethod
+ def suffixes():
+ """File suffixes (extensions) used by this frontend"""
+
+ @abstractmethod
+ def load(self, path, shapes):
+ """Load network"""
+
+
+def import_keras():
+ """ Lazy import function for Keras"""
+ # Keras writes the message "Using TensorFlow backend." to stderr
+ # Redirect stderr during the import to disable this
+ stderr = sys.stderr
+ sys.stderr = open(os.devnull, "w")
+ try:
+ # pylint: disable=C0415
+ import tensorflow as tf
+ from tensorflow import keras
+
+ return tf, keras
+ finally:
+ sys.stderr = stderr
+
+
+class KerasFrontend(Frontend):
+ """ Keras frontend for TVMC """
+
+ @staticmethod
+ def name():
+ return "keras"
+
+ @staticmethod
+ def suffixes():
+ return ["h5"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ import numpy as np
+ from tvm import relay
+
+ # pylint: disable=C0103
+ tf, keras = import_keras()
+
+ if shapes:
+ raise TVMCException(
+ "--input-shape is not supported for {}".format(self.name())
+ )
+
+ # tvm build currently imports keras directly instead of
tensorflow.keras
+ try:
+ model = keras.models.load_model(path)
+ except ValueError as err:
+ raise TVMCException(str(err))
+
+ # There are two flavours of keras model, sequential and
+ # functional, TVM expects a functional model, so convert
+ # if required:
+ if self.is_sequential_p(model):
+ model = self.sequential_to_functional(model)
+
+ in_shapes = []
+ for layer in model._input_layers:
+ if tf.executing_eagerly():
+ in_shapes.append(
+ tuple(dim if dim is not None else 1 for dim in
layer.input.shape)
+ )
+ else:
+ in_shapes.append(
+ tuple(
+ dim.value if dim.value is not None else 1
+ for dim in layer.input.shape
+ )
+ )
+
+ inputs = [
+ np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in
in_shapes
+ ]
+ shape_dict = {name: x.shape for (name, x) in zip(model.input_names,
inputs)}
+ return relay.frontend.from_keras(model, shape_dict, layout="NHWC")
+
+ def is_sequential_p(self, model):
+ _, keras = import_keras()
+ return isinstance(model, keras.models.Sequential)
+
+ def sequential_to_functional(self, model):
+ _, keras = import_keras()
+ assert self.is_sequential_p(model)
+ input_layer =
keras.layers.Input(batch_shape=model.layers[0].input_shape)
+ prev_layer = input_layer
+ for layer in model.layers:
+ prev_layer = layer(prev_layer)
+ model = keras.models.Model([input_layer], [prev_layer])
+ return model
+
+
+class OnnxFrontend(Frontend):
+ """ ONNX frontend for TVMC """
+
+ @staticmethod
+ def name():
+ return "onnx"
+
+ @staticmethod
+ def suffixes():
+ return ["onnx"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ import onnx
+ from tvm import relay
+
+ if shapes:
+ raise TVMCException(
+ "--input-shape is not supported for {}".format(self.name())
+ )
+
+ model = onnx.load(path)
+
+ # Find the name and shape of the first input in the graph
+
+ # pylint: disable=E1101
+ name = model.graph.input[0].name
+
+ # pylint: disable=E1101
+ proto_shape = model.graph.input[0].type.tensor_type.shape.dim
+ shape = [d.dim_value for d in proto_shape]
+
+ shape_dict = {name: shape}
+
+ return relay.frontend.from_onnx(model, shape_dict)
+
+
+class TensorflowFrontend(Frontend):
+ """ TensorFlow frontend for TVMC """
+
+ @staticmethod
+ def name():
+ return "pb"
+
+ @staticmethod
+ def suffixes():
+ return ["pb"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ from tvm import relay
+ import tensorflow as tf
+ import tvm.relay.testing.tf as tf_testing
+
+ if shapes:
+ raise TVMCException(
+ "--input-shape is not supported for {}".format(self.name())
+ )
+
+ with tf.io.gfile.GFile(path, "rb") as tf_graph:
+ content = tf_graph.read()
+
+ graph_def = tf.compat.v1.GraphDef()
+ graph_def.ParseFromString(content)
+ graph_def = tf_testing.ProcessGraphDefParam(graph_def)
+
+ logging.debug("relay.frontend.from_tensorflow")
+ return relay.frontend.from_tensorflow(graph_def)
+
+
+class TFLiteFrontend(Frontend):
+ """ TFLite frontend for TVMC """
+
+ _tflite_m = {
+ 0: "float32",
+ 1: "float16",
+ 2: "int32",
+ 3: "uint8",
+ 4: "int64",
+ 5: "string",
+ 6: "bool",
+ 7: "int16",
+ 8: "complex64",
+ 9: "int8",
+ }
+
+ @staticmethod
+ def name():
+ return "tflite"
+
+ @staticmethod
+ def suffixes():
+ return ["tflite"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ import tflite.Model as model
+ from tvm import relay
+
+ if shapes:
+ raise TVMCException(
+ "--input-shape is not supported for {}".format(self.name())
+ )
+
+ with open(path, "rb") as tf_graph:
+ content = tf_graph.read()
+
+ # tflite.Model.Model is tflite.Model in 1.14 and 2.1.0
+ try:
+ tflite_model = model.Model.GetRootAsModel(content, 0)
+ except AttributeError:
+ tflite_model = model.GetRootAsModel(content, 0)
+
+ try:
+ version = tflite_model.Version()
+ logging.debug("tflite version %s", version)
+ except Exception:
+ raise TVMCException("input file not tflite")
+
+ if version != 3:
+ raise TVMCException("input file not tflite version 3")
+
+ logging.debug("tflite_input_type")
+ shape_dict, dtype_dict = TFLiteFrontend._input_type(tflite_model)
+
+ # parse TFLite model and convert into Relay computation graph
+ logging.debug("relay.frontend.from_tflite")
+ mod, params = relay.frontend.from_tflite(
+ tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict
+ )
+ return mod, params
+
+ @staticmethod
+ def _decode_type(n):
+ return TFLiteFrontend._tflite_m[n]
+
+ @staticmethod
+ def _input_type(model):
+ subgraph_count = model.SubgraphsLength()
+ assert subgraph_count > 0
+ shape_dict = {}
+ dtype_dict = {}
+ for subgraph_index in range(subgraph_count):
+ subgraph = model.Subgraphs(subgraph_index)
+ inputs_count = subgraph.InputsLength()
+ assert inputs_count >= 1
+ for input_index in range(inputs_count):
+ input_ = subgraph.Inputs(input_index)
+ assert subgraph.TensorsLength() > input_
+ tensor = subgraph.Tensors(input_)
+ input_shape = tuple(tensor.ShapeAsNumpy())
+ tensor_type = tensor.Type()
+ input_name = tensor.Name().decode("utf8")
+ shape_dict[input_name] = input_shape
+ dtype_dict[input_name] =
TFLiteFrontend._decode_type(tensor_type)
+
+ return shape_dict, dtype_dict
+
+
+class PyTorchFrontend(Frontend):
+ """ PyTorch frontend for TVMC """
+
+ @staticmethod
+ def name():
+ return "pytorch"
+
+ @staticmethod
+ def suffixes():
+ # Torch Script is a zip file, but can be named pth
+ return ["pth", "zip"]
+
+ def load(self, path, shapes):
+ # pylint: disable=C0415
+ import torch
+ from tvm import relay
+
+ if not shapes:
+ raise TVMCException(
+ "--input-shape must be specified for {}".format(self.name())
+ )
+
+ traced_model = torch.jit.load(path)
+ traced_model.eval() # Switch to inference mode
+ input_shapes = [
+ ("input{}".format(idx), shape) for idx, shape in enumerate(shapes)
+ ]
+ logging.debug("relay.frontend.from_pytorch")
+ return relay.frontend.from_pytorch(traced_model, input_shapes)
+
+
+ALL_FRONTENDS = [
+ KerasFrontend,
+ OnnxFrontend,
+ TensorflowFrontend,
+ TFLiteFrontend,
+ PyTorchFrontend,
+]
+
+
+def get_frontends():
+ """Return the names of all supported frontends"""
+ return [frontend.name() for frontend in ALL_FRONTENDS]
+
+
+def lookup_frontend(name):
+ for frontend in ALL_FRONTENDS:
+ if name == frontend.name():
+ return frontend()
+ raise TVMCException("unrecognized frontend")
+
+
+def guess_input_language(path):
+ suffix = Path(path).suffix.lower()
+ if suffix.startswith("."):
+ suffix = suffix[1:]
+
+ for frontend in ALL_FRONTENDS:
+ if suffix in frontend.suffixes():
+ return frontend()
+
+ raise TVMCException("cannot guess input language")
+
+
+def load_model(path, language=None, shapes=None):
Review comment:
Ditto to "language"
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]