masahi commented on a change in pull request #8777:
URL: https://github.com/apache/tvm/pull/8777#discussion_r735245631



##########
File path: cmake/modules/contrib/PT_TVMDSOOP.cmake
##########
@@ -0,0 +1,64 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+if(NOT USE_PT_TVMDSOOP STREQUAL "OFF")
+  find_package(Python3 COMPONENTS Interpreter Development)
+  include_directories(${Python3_INCLUDE_DIRS})
+
+  message(STATUS "Python3_INCLUDE_DIRS: ${Python3_INCLUDE_DIRS}")
+
+  execute_process(COMMAND ${Python3_EXECUTABLE} -c "import torch; 
print(torch.__path__[0].strip())"
+    OUTPUT_VARIABLE PT_PATH
+    RESULT_VARIABLE PT_STATUS)
+  if (NOT ${PT_STATUS} EQUAL 0)
+    message(FATAL_ERROR "Fail to get pytorch path")
+  endif()
+
+  string(REGEX REPLACE "\n" "" PT_PATH "${PT_PATH}")
+
+  set(PT_COMPILE_FLAGS_STR "-I${PT_PATH}/include -D_GLIBCXX_USE_CXX11_ABI=0")
+  set(PT_LINK_FLAGS_STR "-L${PT_PATH}/lib -l:libtorch.so 
-l:libtorch_python.so")
+
+  if(NOT USE_CUDA STREQUAL "OFF")
+    add_definitions(-DPT_TVMDSOOP_ENABLE_GPU)
+  endif()
+
+
+  string(REGEX REPLACE "\n" " " PT_FLAGS "${PT_COMPILE_FLAGS} 
${PT_LINK_FLAGS}")
+  separate_arguments(PT_COMPILE_FLAGS UNIX_COMMAND ${PT_COMPILE_FLAGS_STR})
+  separate_arguments(PT_LINK_FLAGS UNIX_COMMAND ${PT_LINK_FLAGS_STR})
+
+
+  set(LIBRARY_NAME pt_tvmdsoop)
+  file(GLOB_RECURSE PTTVM_SRCS 
${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/**/*.cc)
+  add_library(${LIBRARY_NAME} SHARED ${PTTVM_SRCS})
+  # add_library(${STATIC_NAME} STATIC ${PTTVM_SRCS})
+  # set(PTTVM_LINK_FLAGS -ltvm -ltvm_runtime -L${CMAKE_CURRENT_BINARY_DIR})
+  set(PTTVM_LINK_FLAGS -ltvm -L${CMAKE_CURRENT_BINARY_DIR})
+
+  if (NOT BUILD_PT_TVMDSOOP_ONLY STREQUAL "ON")
+    add_dependencies(${LIBRARY_NAME} tvm) 
+  endif()
+  # add_dependencies(${LIBRARY_NAME} tvm)
+
+  target_compile_options(${LIBRARY_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} 
${PT_COMPILE_FLAGS})
+  target_link_libraries(${LIBRARY_NAME} PUBLIC ${PTTVM_LINK_FLAGS} 
${PT_LINK_FLAGS})
+  # target_compile_options(${STATIC_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} 
${PT_COMPILE_FLAGS})
+  # target_link_libraries(${STATIC_NAME} PUBLIC ${PTTVM_LINK_FLAGS} 
${PT_LINK_FLAGS})

Review comment:
       Please remove the commented-out lines.

##########
File path: python/tvm/contrib/torch/module.py
##########
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Module container of PyTorch custom class"""
+from typing import List
+import torch
+
+
+class GraphModule(torch.nn.Module):
+    r"""Module container of Pytorch class which wraps exported
+    TVM op implementation library to be called on Pytorch side"""
+
+    @classmethod
+    def shape_repr(cls, input_shapes):
+        return torch.ops.tvm_dsoop.tvm_shape_repr(input_shapes)
+
+    def __init__(self, num_inputs, num_outputs, device=None):
+        super().__init__()
+        self.dummy_param = torch.nn.Parameter(torch.empty(0))
+        self.engine = None
+
+        if device is not None:
+            self.to(device)
+        self.engine = torch.classes.tvm_dsoop.TvmGraphModule(num_inputs, 
num_outputs, self.device)
+
+    def init(self, input_shapes, lib_path, graph_path, params_path):
+        r"""Load tvm module"""
+        self.engine.load_tvm_module(input_shapes, lib_path, graph_path, 
params_path)
+
+    def forward(self, inputs: List[torch.Tensor]):
+        r"""Call tvm module to forward"""
+        return self.engine.forward(inputs)
+
+    @property
+    def device(self):
+        r"""Get the device string"""
+        return str(self.dummy_param.device)
+
+    def _apply(self, func):
+        r"""Override to device function, manually move tvm module to desired 
device"""
+        super()._apply(func)
+        if self.engine is not None:
+            self.engine.to(self.device)
+        return self
+
+
+class VMModule(torch.nn.Module):

Review comment:
       It seems this class is not used anywhere. Please add a test for it or 
remove it.

##########
File path: python/tvm/contrib/torch/pytorch_tvm.py
##########
@@ -0,0 +1,226 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""`compile` api that convert torch module to torch tvm module"""
+import os
+import tvm
+import tvm.testing
+from tvm import relay, autotvm
+from tvm.runtime import load_module
+from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
+from tvm.contrib import graph_executor
+from tvm.contrib.debugger import debug_executor
+from . import GraphModule
+
+
+def tune_tasks(
+    tasks,
+    measure_option,
+    tuner="xgb",
+    n_trial=1000,
+    early_stopping=None,
+    log_filename="tuning.log",
+    use_transfer_learning=True,
+):
+    """Tune tasks and generate tuning log to file"""
+    # create tmp log file
+    tmp_log_file = log_filename + ".tmp"
+    if os.path.exists(tmp_log_file):
+        os.remove(tmp_log_file)
+
+    for i, tsk in enumerate(reversed(tasks)):
+        prefix = f"[Task {i + 1:2d}/{len(tasks):2d}] "
+
+        # create tuner
+        if tuner in ("xgb", "sgb-rank"):
+            tuner_obj = XGBTuner(tsk, loss_type="rank")
+        elif tuner == "ga":
+            tuner_obj = GATuner(tsk, pop_size=100)
+        elif tuner == "random":
+            tuner_obj = RandomTuner(tsk)
+        elif tuner == "gridsearch":
+            tuner_obj = GridSearchTuner(tsk)
+        else:
+            raise ValueError("Invalid tuner: " + tuner)
+
+        if use_transfer_learning:
+            if os.path.isfile(tmp_log_file):
+                
tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))
+
+        # do tuning
+        tsk_trial = min(n_trial, len(tsk.config_space))
+        tuner_obj.tune(
+            n_trial=tsk_trial,
+            early_stopping=early_stopping,
+            measure_option=measure_option,
+            callbacks=[
+                autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
+                autotvm.callback.log_to_file(tmp_log_file),
+            ],
+        )
+
+    # pick best records to a cache file
+    autotvm.record.pick_best(tmp_log_file, log_filename)
+    os.remove(tmp_log_file)
+
+
+def get_tuning_opt(log_file="tuning.log", n_trial=200):
+    """Returns tuning options"""
+    tuning_opt = {
+        "log_filename": log_file,
+        "tuner": "random",
+        "n_trial": n_trial,
+        "early_stopping": 60,
+        "measure_option": autotvm.measure_option(
+            builder=autotvm.LocalBuilder(timeout=10),
+            runner=autotvm.LocalRunner(number=20, repeat=3, timeout=4, 
min_repeat_ms=150),
+        ),
+    }
+    return tuning_opt
+
+
+TVM_ASSETS = ["mod.so", "graph.json", "params"]
+
+
+class PyTorchTVMModule:
+    """Helper class for compiling pytorch module to tvm module"""
+
+    def __init__(self) -> None:
+        self.script_module = None
+        self.input_infos = None
+        self.default_dtype = "float32"
+        self.mod = None
+        self.params = None
+        self.tasks = None
+        self.target = "cuda"
+        self.dev = tvm.cuda(0)

Review comment:
       target and device are hard-coded. Is this OK? What happens if a user 
does `model.to("cpu")`?

##########
File path: python/tvm/contrib/torch/__init__.py
##########
@@ -0,0 +1,50 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Module container of Pytorch custom class"""
+import os
+import platform
+import torch
+from tvm._ffi import libinfo
+from tvm.relay.frontend import pytorch
+
+
+def _load_platform_specific_library(lib_name="libpt_tvmdsoop"):
+    system = platform.system()
+    if system == "Darwin":
+        lib_file_name = lib_name + ".dylib"
+    elif system == "Windows":
+        lib_file_name = lib_name + ".dll"
+    else:
+        lib_file_name = lib_name + ".so"
+    lib_path = libinfo.find_lib_path()[0]
+    lib_dir = os.path.dirname(lib_path)
+    lib_file_path = os.path.join(lib_dir, lib_file_name)
+    torch.classes.load_library(lib_file_path)
+
+
+_load_platform_specific_library()
+
+from . import module  # nopep8, pylint: disable=wrong-import-position
+
+GraphModule = module.GraphModule
+VMModule = module.VMModule
+TraceTvmModule = module.TraceTvmModule
+
+from . import pytorch_tvm  # nopep8, pylint: disable=wrong-import-position
+
+PyTorchTVMModule = pytorch_tvm.PyTorchTVMModule
+compile = pytorch_tvm.compile  # pylint: disable=redefined-builtin,invalid-name

Review comment:
       Better to put all pylint disable directives below the license (see other 
files) 

##########
File path: python/tvm/contrib/torch/pytorch_tvm.py
##########
@@ -0,0 +1,226 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""`compile` api that convert torch module to torch tvm module"""
+import os
+import tvm
+import tvm.testing
+from tvm import relay, autotvm
+from tvm.runtime import load_module
+from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
+from tvm.contrib import graph_executor
+from tvm.contrib.debugger import debug_executor
+from . import GraphModule
+
+
+def tune_tasks(
+    tasks,
+    measure_option,
+    tuner="xgb",
+    n_trial=1000,
+    early_stopping=None,
+    log_filename="tuning.log",
+    use_transfer_learning=True,
+):
+    """Tune tasks and generate tuning log to file"""
+    # create tmp log file
+    tmp_log_file = log_filename + ".tmp"
+    if os.path.exists(tmp_log_file):
+        os.remove(tmp_log_file)
+
+    for i, tsk in enumerate(reversed(tasks)):
+        prefix = f"[Task {i + 1:2d}/{len(tasks):2d}] "
+
+        # create tuner
+        if tuner in ("xgb", "sgb-rank"):
+            tuner_obj = XGBTuner(tsk, loss_type="rank")
+        elif tuner == "ga":
+            tuner_obj = GATuner(tsk, pop_size=100)
+        elif tuner == "random":
+            tuner_obj = RandomTuner(tsk)
+        elif tuner == "gridsearch":
+            tuner_obj = GridSearchTuner(tsk)
+        else:
+            raise ValueError("Invalid tuner: " + tuner)
+
+        if use_transfer_learning:
+            if os.path.isfile(tmp_log_file):
+                
tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))
+
+        # do tuning
+        tsk_trial = min(n_trial, len(tsk.config_space))
+        tuner_obj.tune(
+            n_trial=tsk_trial,
+            early_stopping=early_stopping,
+            measure_option=measure_option,
+            callbacks=[
+                autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
+                autotvm.callback.log_to_file(tmp_log_file),
+            ],
+        )
+
+    # pick best records to a cache file
+    autotvm.record.pick_best(tmp_log_file, log_filename)
+    os.remove(tmp_log_file)
+
+
+def get_tuning_opt(log_file="tuning.log", n_trial=200):
+    """Returns tuning options"""
+    tuning_opt = {
+        "log_filename": log_file,
+        "tuner": "random",
+        "n_trial": n_trial,
+        "early_stopping": 60,
+        "measure_option": autotvm.measure_option(
+            builder=autotvm.LocalBuilder(timeout=10),
+            runner=autotvm.LocalRunner(number=20, repeat=3, timeout=4, 
min_repeat_ms=150),
+        ),
+    }
+    return tuning_opt
+
+
+TVM_ASSETS = ["mod.so", "graph.json", "params"]
+
+
+class PyTorchTVMModule:
+    """Helper class for compiling pytorch module to tvm module"""
+
+    def __init__(self) -> None:
+        self.script_module = None
+        self.input_infos = None
+        self.default_dtype = "float32"
+        self.mod = None
+        self.params = None
+        self.tasks = None
+        self.target = "cuda"
+        self.dev = tvm.cuda(0)
+        self.log_file = None
+        self.tvm_module = None
+        self.tvm_graph = None
+        self.tvm_lib = None
+        self.tvm_params = None
+
+    def from_pytorch(self, script_module, input_infos, 
default_dtype="float32"):
+        self.script_module = script_module
+        self.input_infos = input_infos
+        self.default_dtype = default_dtype
+        self.mod, self.params = relay.frontend.from_pytorch(
+            script_module, input_infos, default_dtype=default_dtype
+        )
+
+    def tune_tvm(self, log_file="tuning.log", n_trial=200):
+        self.tasks = autotvm.task.extract_from_program(
+            self.mod["main"],
+            target=self.target,
+            params=self.params,
+        )
+        self.log_file = log_file
+        tuning_opt = get_tuning_opt(log_file, n_trial)
+        tune_tasks(self.tasks, **tuning_opt)
+
+    def build_tvm(self, export_dir, debug_runtime=False):
+        tvm_mod = self._build_tvm(debug_runtime)
+        self._export_tvm(export_dir)
+        return tvm_mod
+
+    def _build_tvm(self, debug_runtime=False):
+        # compile kernels with history best records
+        with autotvm.apply_history_best(self.log_file):
+            with tvm.transform.PassContext(opt_level=3):
+                self.tvm_graph, self.tvm_lib, self.tvm_params = relay.build(
+                    self.mod, target=self.target, params=self.params
+                )
+
+        if not debug_runtime:
+            self.tvm_module = graph_executor.create(self.tvm_graph, 
self.tvm_lib, device=self.dev)
+        else:
+            self.tvm_module = debug_executor.create(self.tvm_graph, 
self.tvm_lib, device=self.dev)
+        self.tvm_module.set_input(**self.tvm_params)
+        return self.tvm_module
+
+    def _export_tvm(self, export_dir):
+        if not os.path.isdir(export_dir):
+            os.makedirs(export_dir)
+        self.export_dir = export_dir
+        self.tvm_lib.export_library(os.path.join(export_dir, TVM_ASSETS[0]))
+        with open(os.path.join(export_dir, TVM_ASSETS[1]), "w", 
encoding="utf8") as fout:
+            fout.write(self.tvm_graph)
+        with open(os.path.join(export_dir, TVM_ASSETS[2]), "wb") as fout:
+            fout.write(relay.save_param_dict(self.tvm_params))
+
+    def load_tvm(self, export_dir):
+        """Load tvm module from export directory"""
+        self.export_dir = export_dir
+        self.tvm_lib = load_module(os.path.join(export_dir, TVM_ASSETS[0]))
+        with open(os.path.join(export_dir, TVM_ASSETS[1]), "r", 
encoding="utf8") as f:
+            self.tvm_graph = f.read()
+        with open(os.path.join(export_dir, TVM_ASSETS[2]), "rb") as f:
+            self.tvm_params = relay.load_param_dict(f.read())
+
+        self.tvm_module = graph_executor.create(self.tvm_graph, self.tvm_lib, 
device=self.dev)
+        self.tvm_module.set_input(**self.tvm_params)
+        return self.tvm_module
+
+    def build_pytorch_op(self, num_inputs, num_outputs, input_infos=None):
+        assert self.export_dir, "you must build_tvm or load_tvm before"
+        input_infos = input_infos or self.input_infos
+        assert input_infos
+        assert len(input_infos) == num_inputs
+        assets = [os.path.join(self.export_dir, i) for i in TVM_ASSETS]
+        input_shapes = [i[1] for i in input_infos]
+        mod = GraphModule(num_inputs=num_inputs, 
num_outputs=num_outputs).to(self.target)
+        mod.init(input_shapes, *assets)
+        return mod
+
+
+def compile(script_module, option):  # pylint: disable=redefined-builtin
+    """
+    option = {
+        "input_infos": [
+            ("x", (1, 3, 244, 244)),
+        ],
+        "default_dtype": "float16",
+        "export_dir": "pytorch_compiled",
+        "num_outputs": 1,
+        "tuning_n_trials": 20,  # set zero to skip tuning
+        "tuning_log_file": "tuning.log",
+    }
+    script_module = torch.jit.script(model)
+    pytorch_tvm_module = compile(script_module, option)
+    pytorch_tvm_module("model_tvm.pt")
+    """
+    mod = PyTorchTVMModule()
+    print("Converting...")
+    input_infos = option["input_infos"]
+    default_dtype = option.get("default_dtype", "float32")
+    export_dir = option.get("export_dir", "pytorch_compiled")
+    tuning_log_file = option.get("tuning_log_file", "tuning.log")
+    tuning_n_trials = option.get("tuning_n_trials", 20)
+    num_outputs = option.get("num_outputs", 1)
+

Review comment:
       I think it is worth adding an option to enable fp16 quantization and 
NHWC layout conversion on TVM side, to enable using Tensor cores.

##########
File path: python/tvm/contrib/torch/__init__.py
##########
@@ -0,0 +1,50 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Module container of Pytorch custom class"""
+import os
+import platform
+import torch
+from tvm._ffi import libinfo
+from tvm.relay.frontend import pytorch
+
+
+def _load_platform_specific_library(lib_name="libpt_tvmdsoop"):
+    system = platform.system()
+    if system == "Darwin":
+        lib_file_name = lib_name + ".dylib"
+    elif system == "Windows":
+        lib_file_name = lib_name + ".dll"

Review comment:
       From the CMake config file I'm assuming that you only support Linux. 

##########
File path: python/tvm/contrib/torch/module.py
##########
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Module container of PyTorch custom class"""
+from typing import List
+import torch
+
+
+class GraphModule(torch.nn.Module):
+    r"""Module container of Pytorch class which wraps exported
+    TVM op implementation library to be called on Pytorch side"""
+
+    @classmethod
+    def shape_repr(cls, input_shapes):
+        return torch.ops.tvm_dsoop.tvm_shape_repr(input_shapes)
+
+    def __init__(self, num_inputs, num_outputs, device=None):
+        super().__init__()
+        self.dummy_param = torch.nn.Parameter(torch.empty(0))
+        self.engine = None
+
+        if device is not None:
+            self.to(device)
+        self.engine = torch.classes.tvm_dsoop.TvmGraphModule(num_inputs, 
num_outputs, self.device)
+
+    def init(self, input_shapes, lib_path, graph_path, params_path):
+        r"""Load tvm module"""
+        self.engine.load_tvm_module(input_shapes, lib_path, graph_path, 
params_path)
+
+    def forward(self, inputs: List[torch.Tensor]):
+        r"""Call tvm module to forward"""
+        return self.engine.forward(inputs)
+
+    @property
+    def device(self):
+        r"""Get the device string"""
+        return str(self.dummy_param.device)
+
+    def _apply(self, func):
+        r"""Override to device function, manually move tvm module to desired 
device"""
+        super()._apply(func)
+        if self.engine is not None:
+            self.engine.to(self.device)
+        return self
+
+
+class VMModule(torch.nn.Module):
+    r"""Module container of Pytorch class which wraps exported
+    TVM op implementation library to be called on Pytorch side"""
+
+    @classmethod
+    def shape_repr(cls, input_shapes):
+        return torch.ops.tvm_dsoop.tvm_shape_repr(input_shapes)
+
+    def __init__(self, num_inputs, num_outputs, device=None):
+        super().__init__()
+        self.dummy_param = torch.nn.Parameter(torch.empty(0))
+        self.engine = None
+
+        if device is not None:
+            self.to(device)
+        self.engine = torch.classes.tvm_dsoop.TvmVMModule(num_inputs, 
num_outputs, self.device)
+
+    def init(self, input_shapes, lib_path, code_path):
+        r"""Load tvm module"""
+        self.engine.load_tvm_module(input_shapes, lib_path, code_path)
+
+    def forward(self, inputs: List[torch.Tensor]):
+        r"""Call tvm module to forward"""
+        return self.engine.forward(inputs)
+
+    @property
+    def device(self):
+        r"""Get the device string"""
+        return str(self.dummy_param.device)
+
+    def _apply(self, func):
+        r"""Override to device function, manually move tvm module to desired 
device"""
+        super()._apply(func)
+        if self.engine is not None:
+            self.engine.to(self.device)
+        return self
+
+
+class TraceTvmModule(torch.nn.Module):

Review comment:
       This is also not used anywhere

##########
File path: python/tvm/contrib/torch/module.py
##########
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Module container of PyTorch custom class"""
+from typing import List
+import torch
+
+
+class GraphModule(torch.nn.Module):
+    r"""Module container of Pytorch class which wraps exported
+    TVM op implementation library to be called on Pytorch side"""
+
+    @classmethod
+    def shape_repr(cls, input_shapes):
+        return torch.ops.tvm_dsoop.tvm_shape_repr(input_shapes)
+
+    def __init__(self, num_inputs, num_outputs, device=None):
+        super().__init__()
+        self.dummy_param = torch.nn.Parameter(torch.empty(0))
+        self.engine = None
+
+        if device is not None:
+            self.to(device)

Review comment:
       Should we make sure that the `device` param and the target the `.so` file is 
compiled for are consistent? 




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to