masahi commented on code in PR #11911: URL: https://github.com/apache/tvm/pull/11911#discussion_r917466929
########## apps/pt_tvmdsoop/tests/test_optimize_torch.py: ########## @@ -0,0 +1,161 @@ +# pylint: disable=missing-class-docstring +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test script for tvm torch module""" +import tempfile + +import torch +from torch.utils import benchmark +from torchvision.models import resnet18 + +import tvm +import tvm.testing +from tvm.contrib.torch import optimize_torch +from tvm.meta_schedule import TuneConfig + +# default config for testing +config = TuneConfig( Review Comment: not used? ########## src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc: ########## @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include <ATen/DLConvertor.h> +#include <dlpack/dlpack.h> +#include <dmlc/memory_io.h> +#include <torch/custom_class.h> +#include <torch/script.h> +#include <tvm/runtime/module.h> +#include <tvm/runtime/registry.h> +#include <tvm/target/codegen.h> +#include <tvm/target/target.h> + +#include <cstdio> +#include <map> +#include <string> +#include <vector> + +#include "../../../runtime/graph_executor/graph_executor_factory.h" +#include "../base64.h" + +namespace tvm { +namespace contrib { + +/** + * We pass the TVM module by TVM's FFI because Torch's FFI cannot recognize such TVM objects + */ +struct ThreadLocalStore { + tvm::runtime::Module mod; + static ThreadLocalStore* ThreadLocal() { + thread_local ThreadLocalStore tls; + return &tls; + } +}; + +using SerializationType = std::string; // base64 stream + +SerializationType serialize(tvm::runtime::Module module) { + static const runtime::PackedFunc* f_to_str = + runtime::Registry::Get("script_torch.save_to_base64"); + ICHECK(f_to_str) << "IndexError: Cannot find the packed function " + "`script_torch.save_to_tar` in the global registry"; + return (*f_to_str)(module); +} + +struct Deleter { // deleter + explicit Deleter(std::string file_name) { this->file_name = file_name; } + void operator()(FILE* p) const { + fclose(p); + ICHECK(remove(file_name.c_str()) == 0) + << "remove temporary file (" << file_name << ") unsuccessfully"; + } + std::string file_name; +}; + +tvm::runtime::Module deserialize(SerializationType state) { + auto length = tvm::support::b64strlen(state); + + std::vector<u_char> 
bytes(length); + tvm::support::b64decode(state, bytes.data()); + + const std::string name = tmpnam(NULL); + auto file_name = name + ".so"; + std::unique_ptr<FILE, Deleter> pFile(fopen(file_name.c_str(), "wb"), Deleter(file_name)); + fwrite(bytes.data(), sizeof(u_char), length, pFile.get()); + fflush(pFile.get()); + + std::string load_f_name = "runtime.module.loadfile_so"; + const PackedFunc* f = runtime::Registry::Get(load_f_name); + ICHECK(f != nullptr) << "Loader for `.so` files is not registered," + << " resolved to (" << load_f_name << ") in the global registry." + << "Ensure that you have loaded the correct runtime code, and" + << "that you are on the correct hardware architecture."; + + tvm::runtime::Module ret = (*f)(file_name, ""); + + return ret; +} + +/** + * @brief A Torch's module which wraps TVM's OperatorModule Class. + * The basic forward function calling TVM's runtime is provided. + * The TVM module can be serialized/deserialized as a Torch module. + */ +class OperatorModuleWrapper : public torch::jit::CustomClassHolder { + public: + OperatorModuleWrapper() { runtime_module = ThreadLocalStore::ThreadLocal()->mod; } + + void forward(const c10::List<at::Tensor>& inputs) { + int input_length = inputs.size(); + + std::vector<DLManagedTensor*> tensors; + + for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPack(inputs[i])); + + tvm::runtime::PackedFunc run = runtime_module.GetFunction("__tvm_main__"); + + std::vector<TVMValue> tvm_values(input_length); + std::vector<int> tvm_type_codes(input_length); + tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data()); + for (int k = 0; k < input_length; ++k) { + setter(k, &tensors[k]->dl_tensor); + } + + run.CallPacked(tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), input_length), + nullptr); + + for (int k = 0; k < input_length; ++k) { + tensors[k]->deleter(tensors[k]); + } + } + + SerializationType Serialize() { return serialize(runtime_module); } + + explicit 
OperatorModuleWrapper(SerializationType state) { runtime_module = deserialize(state); } + + private: + tvm::runtime::Module runtime_module; +}; + +tvm::Device getDevice(const at::Tensor& tensor) { + tvm::Device dev; + dev.device_id = tensor.get_device(); + switch (tensor.device().type()) { + case at::DeviceType::CPU: + dev.device_type = DLDeviceType::kDLCPU; + if (dev.device_id == -1) { + /* + * In PyTorch the device ID for cpu is -1, sometimes causing error during tuning + * Thus we manually set the device ID as 0 for avoiding potentially error of index out of + * bounds + */ + dev.device_id = 0; + } + break; + case at::DeviceType::CUDA: + dev.device_type = DLDeviceType::kDLCUDA; + break; + default: + TORCH_CHECK(false, "PyTorch TVM integration doesn't support device " + tensor.device().str()); + } + return dev; +} + +/** + * @brief A Torch's module which wraps TVM's GraphExecutorFactory Class. + * The basic forward function calling TVM's runtime is provided. + * The TVM module can be serialized/deserialized as a Torch module. + */ +class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { Review Comment: This looks similar to `TvmGraphModulePack`: https://github.com/apache/tvm/blob/e7024fb39ea27494fa5618102dae42e7e5551986/src/contrib/torch/pt_call_tvm/tvm_class.cc#L40 Why do we need this? ########## src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc: ########## @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include <ATen/DLConvertor.h> +#include <dlpack/dlpack.h> +#include <dmlc/memory_io.h> +#include <torch/custom_class.h> +#include <torch/script.h> +#include <tvm/runtime/module.h> +#include <tvm/runtime/registry.h> +#include <tvm/target/codegen.h> +#include <tvm/target/target.h> + +#include <cstdio> +#include <map> +#include <string> +#include <vector> + +#include "../../../runtime/graph_executor/graph_executor_factory.h" +#include "../base64.h" + +namespace tvm { +namespace contrib { + +/** + * We pass the TVM module by TVM's FFI because Torch's FFI cannot recognize such TVM objects + */ +struct ThreadLocalStore { + tvm::runtime::Module mod; + static ThreadLocalStore* ThreadLocal() { + thread_local ThreadLocalStore tls; + return &tls; + } +}; + +using SerializationType = std::string; // base64 stream + +SerializationType serialize(tvm::runtime::Module module) { + static const runtime::PackedFunc* f_to_str = + runtime::Registry::Get("script_torch.save_to_base64"); + ICHECK(f_to_str) << "IndexError: Cannot find the packed function " + "`script_torch.save_to_tar` in the global registry"; + return (*f_to_str)(module); +} + +struct Deleter { // deleter + explicit Deleter(std::string file_name) { this->file_name = file_name; } + void operator()(FILE* p) const { + fclose(p); + ICHECK(remove(file_name.c_str()) == 0) + << "remove temporary file (" << file_name << ") unsuccessfully"; + } + std::string file_name; +}; + +tvm::runtime::Module deserialize(SerializationType state) { Review Comment: Are `serialize` / `deserialize` tested in this PR? 
########## python/tvm/contrib/torch/optimize_torch.py: ########## @@ -0,0 +1,143 @@ +# pylint: disable=inconsistent-return-statements +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring +""" +optimize_torch: aa function similar to `torch.jit.trace`, +which is used to optimize the `torch.nn.module` by TVM metaSchedule, +and returns a custom TorchScript operator +""" +import base64 +import contextlib +import tempfile +from typing import Tuple + +import torch +import torch.utils.dlpack + +import tvm +from tvm import relay +from tvm._ffi import get_global_func, register_func +from tvm.meta_schedule import TuneConfig +from tvm.meta_schedule.tune import tune_relay + + +# The python wrapper for GraphExecutorFactory +class GraphExecutorFactoryWrapper(torch.nn.Module): + def __init__(self, module: tvm.runtime.Module): + super().__init__() + self.inner_module = module + + def forward(self, *torch_inputs: Tuple[torch.Tensor]): + ret = self.inner_module.forward(torch_inputs) + if len(ret) == 1: + return ret[0] + return ret + + +def llvm_target(): + return "llvm -num-cores" + + 
+@register_func("script_torch.save_to_base64") +def save_to_base64(obj) -> bytes: + with tempfile.NamedTemporaryFile(suffix=".so") as tmpfile: + obj.export_library(tmpfile.name) + with open(tmpfile.name, "rb") as tfile: + return base64.b64encode(tfile.read()) + + +def optimize_torch( + func, + example_inputs, + tuning_config=None, + target=None, + work_dir=None, +): + """Load PyTorch model that could be traced by TorchScript, then optimize it via MetaSchedule. + + Parameters + ---------- + func : callable or torch.nn.Module + A Python function or nn.Module that could run by TorchScript's trace. + (ie: torch.jit.trace(model, input)) + + example_inputs : tuple or torch.Tensor + A tuple of example inputs that + will run together with `func` by providing the shape information. + + tuning_config : tvm.meta_schedule.TuneConfig + The configuration of tuning by MetaSchedule. + We suggest users to provide their own setting, + otherwise by default setting a tuning process could be very slow, + sometimes costs a few hours. Review Comment: Need to improve this doc ########## python/tvm/contrib/torch/optimize_torch.py: ########## @@ -0,0 +1,143 @@ +# pylint: disable=inconsistent-return-statements +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring +""" +optimize_torch: aa function similar to `torch.jit.trace`, +which is used to optimize the `torch.nn.module` by TVM metaSchedule, +and returns a custom TorchScript operator +""" +import base64 +import contextlib +import tempfile +from typing import Tuple + +import torch +import torch.utils.dlpack + +import tvm +from tvm import relay +from tvm._ffi import get_global_func, register_func +from tvm.meta_schedule import TuneConfig +from tvm.meta_schedule.tune import tune_relay + + +# The python wrapper for GraphExecutorFactory +class GraphExecutorFactoryWrapper(torch.nn.Module): + def __init__(self, module: tvm.runtime.Module): + super().__init__() + self.inner_module = module + + def forward(self, *torch_inputs: Tuple[torch.Tensor]): + ret = self.inner_module.forward(torch_inputs) + if len(ret) == 1: + return ret[0] + return ret + + +def llvm_target(): + return "llvm -num-cores" + + +@register_func("script_torch.save_to_base64") +def save_to_base64(obj) -> bytes: + with tempfile.NamedTemporaryFile(suffix=".so") as tmpfile: + obj.export_library(tmpfile.name) + with open(tmpfile.name, "rb") as tfile: + return base64.b64encode(tfile.read()) + + +def optimize_torch( + func, + example_inputs, + tuning_config=None, + target=None, + work_dir=None, +): + """Load PyTorch model that could be traced by TorchScript, then optimize it via MetaSchedule. + + Parameters + ---------- + func : callable or torch.nn.Module + A Python function or nn.Module that could run by TorchScript's trace. + (ie: torch.jit.trace(model, input)) + + example_inputs : tuple or torch.Tensor + A tuple of example inputs that + will run together with `func` by providing the shape information. 
+ + tuning_config : tvm.meta_schedule.TuneConfig + The configuration of tuning by MetaSchedule. + We suggest users to provide their own setting, + otherwise by default setting a tuning process could be very slow, + sometimes costs a few hours. + + target : Optional[Union[str, Target]] + The target of the compilation. + If user doesn't set the target, the module is built upon the LLVM. Review Comment: will be built for the CPU target ########## apps/pt_tvmdsoop/tests/test_as_torch.py: ########## @@ -0,0 +1,148 @@ +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Test script for tvm torch module""" +import numpy as np + +import torch +import torch.nn + +import tvm +import tvm.testing +from tvm.contrib.torch import as_torch +from tvm.script import tir as T + + +@as_torch +def matmul(M: int, N: int, K: int, dtype: str): + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [M, K], dtype=dtype) + B = T.match_buffer(b, [N, K], dtype=dtype) + C = T.match_buffer(c, [M, N], dtype=dtype) + for i, j, k in T.grid(M, N, K): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + return main + + +@as_torch +@tvm.script.ir_module +class MyModule: + @T.prim_func + def main(a: T.handle, b: T.handle): + # We exchange data between function by handles, which are similar to pointer. + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # Create buffer from handles. + A = T.match_buffer(a, (8,), dtype="float32") + B = T.match_buffer(b, (8,), dtype="float32") + for i in range(8): + # A block is an abstraction for computation. + with T.block("B"): + # Define a spatial block iterator and bind it to value i. 
+ vi = T.axis.spatial(8, i) + B[vi] = A[vi] + 1.0 + + +@as_torch +@tvm.script.ir_module +class ModuleGPU: + @T.prim_func + def main(A: T.Buffer[8, "float32"], B: T.Buffer[8, "float32"]) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i_0 in T.thread_binding(2, thread="blockIdx.x"): + for i_2 in T.thread_binding(2, thread="threadIdx.x"): + for i_1 in T.serial(2): + with T.block("B"): + vi = T.axis.spatial(8, i_0 * 4 + i_1 * 2 + i_2) + T.reads(A[vi]) + T.writes(B[vi]) + B[vi] = A[vi] + T.float32(1) + + +class MinuesOnes(torch.nn.Module): + def __init__(self): + super(MinuesOnes, self).__init__() + self.engine = MyModule + + def forward(self, *input): + self.engine.forward(*input) + return input[-1] - 1 + + +def test_tvmscript_torch_matmul(): + s1 = np.ones((128, 128)).astype("float32") + s2 = np.ones((128, 128)).astype("float32") Review Comment: use random matrices for inputs. `matmul` is computing rhs-transposed matmul, which is not the same as `np.matmul`. ########## python/tvm/contrib/torch/optimize_torch.py: ########## @@ -0,0 +1,143 @@ +# pylint: disable=inconsistent-return-statements +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring +""" +optimize_torch: aa function similar to `torch.jit.trace`, +which is used to optimize the `torch.nn.module` by TVM metaSchedule, +and returns a custom TorchScript operator +""" +import base64 +import contextlib +import tempfile +from typing import Tuple + +import torch +import torch.utils.dlpack + +import tvm +from tvm import relay +from tvm._ffi import get_global_func, register_func +from tvm.meta_schedule import TuneConfig +from tvm.meta_schedule.tune import tune_relay + + +# The python wrapper for GraphExecutorFactory +class GraphExecutorFactoryWrapper(torch.nn.Module): + def __init__(self, module: tvm.runtime.Module): + super().__init__() + self.inner_module = module + + def forward(self, *torch_inputs: Tuple[torch.Tensor]): + ret = self.inner_module.forward(torch_inputs) + if len(ret) == 1: + return ret[0] + return ret + + +def llvm_target(): + return "llvm -num-cores" + + +@register_func("script_torch.save_to_base64") +def save_to_base64(obj) -> bytes: + with tempfile.NamedTemporaryFile(suffix=".so") as tmpfile: + obj.export_library(tmpfile.name) + with open(tmpfile.name, "rb") as tfile: + return base64.b64encode(tfile.read()) + + +def optimize_torch( + func, + example_inputs, + tuning_config=None, + target=None, + work_dir=None, +): + """Load PyTorch model that could be traced by TorchScript, then optimize it via MetaSchedule. + + Parameters + ---------- + func : callable or torch.nn.Module + A Python function or nn.Module that could run by TorchScript's trace. + (ie: torch.jit.trace(model, input)) + + example_inputs : tuple or torch.Tensor + A tuple of example inputs that + will run together with `func` by providing the shape information. 
Review Comment: Just say "inputs to `torch.jit.trace`" ########## src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc: ########## @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include <ATen/DLConvertor.h> +#include <dlpack/dlpack.h> +#include <dmlc/memory_io.h> +#include <torch/custom_class.h> +#include <torch/script.h> +#include <tvm/runtime/module.h> +#include <tvm/runtime/registry.h> +#include <tvm/target/codegen.h> +#include <tvm/target/target.h> + +#include <cstdio> +#include <map> +#include <string> +#include <vector> + +#include "../../../runtime/graph_executor/graph_executor_factory.h" +#include "../base64.h" + +namespace tvm { +namespace contrib { + +/** + * We pass the TVM module by TVM's FFI because Torch's FFI cannot recognize such TVM objects + */ +struct ThreadLocalStore { + tvm::runtime::Module mod; + static ThreadLocalStore* ThreadLocal() { + thread_local ThreadLocalStore tls; + return &tls; + } +}; + +using SerializationType = std::string; // base64 stream + +SerializationType serialize(tvm::runtime::Module module) { + static const runtime::PackedFunc* f_to_str = + runtime::Registry::Get("script_torch.save_to_base64"); + ICHECK(f_to_str) << "IndexError: Cannot find 
the packed function " + "`script_torch.save_to_tar` in the global registry"; + return (*f_to_str)(module); +} + +struct Deleter { // deleter + explicit Deleter(std::string file_name) { this->file_name = file_name; } + void operator()(FILE* p) const { + fclose(p); + ICHECK(remove(file_name.c_str()) == 0) + << "remove temporary file (" << file_name << ") unsuccessfully"; Review Comment: "Failed to remove temporary file (" << file_name << ")"; ########## python/tvm/contrib/torch/as_torch.py: ########## @@ -0,0 +1,89 @@ +# pylint: disable=inconsistent-return-statements +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring +""" +as_torch: a decorator, which is used to wrap the TVMscript code to `torch.nn.module`. 
+""" +from typing import Callable, List, Union + +import torch +import torch.utils.dlpack + +import tvm + + +# python wrapper for OperatorModule +class OperatorModuleWrapper(torch.nn.Module): + def __init__( + self, + module: Union[ + tvm.ir.module.IRModule, + tvm.tir.function.PrimFunc, + tvm.contrib.graph_executor.GraphModule, Review Comment: I think it's better to remove `tvm.contrib.graph_executor.GraphModule` from this list. Otherwise it's not clear what `OperatorModuleWrapper` is supposed to do. ########## python/tvm/contrib/torch/optimize_torch.py: ########## @@ -0,0 +1,143 @@ +# pylint: disable=inconsistent-return-statements +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring +""" +optimize_torch: aa function similar to `torch.jit.trace`, +which is used to optimize the `torch.nn.module` by TVM metaSchedule, +and returns a custom TorchScript operator +""" +import base64 +import contextlib +import tempfile +from typing import Tuple + +import torch +import torch.utils.dlpack + +import tvm +from tvm import relay +from tvm._ffi import get_global_func, register_func +from tvm.meta_schedule import TuneConfig +from tvm.meta_schedule.tune import tune_relay + + +# The python wrapper for GraphExecutorFactory +class GraphExecutorFactoryWrapper(torch.nn.Module): + def __init__(self, module: tvm.runtime.Module): + super().__init__() + self.inner_module = module + + def forward(self, *torch_inputs: Tuple[torch.Tensor]): + ret = self.inner_module.forward(torch_inputs) + if len(ret) == 1: + return ret[0] + return ret + + +def llvm_target(): + return "llvm -num-cores" + + +@register_func("script_torch.save_to_base64") +def save_to_base64(obj) -> bytes: + with tempfile.NamedTemporaryFile(suffix=".so") as tmpfile: + obj.export_library(tmpfile.name) + with open(tmpfile.name, "rb") as tfile: + return base64.b64encode(tfile.read()) + + +def optimize_torch( + func, + example_inputs, + tuning_config=None, + target=None, + work_dir=None, +): + """Load PyTorch model that could be traced by TorchScript, then optimize it via MetaSchedule. + + Parameters + ---------- + func : callable or torch.nn.Module + A Python function or nn.Module that could run by TorchScript's trace. + (ie: torch.jit.trace(model, input)) + + example_inputs : tuple or torch.Tensor + A tuple of example inputs that + will run together with `func` by providing the shape information. + + tuning_config : tvm.meta_schedule.TuneConfig + The configuration of tuning by MetaSchedule. 
+ We suggest users to provide their own setting, + otherwise by default setting a tuning process could be very slow, + sometimes costs a few hours. + + target : Optional[Union[str, Target]] + The target of the compilation. + If user doesn't set the target, the module is built upon the LLVM. + + work_dir : Optional[str] + The working directory to save intermediate results. + + Returns + ------- + mod : GraphExecutorFactoryWrapper + It will return an object of GraphExecutorFactoryWrapper, + which is the subclass of the original nn.Module. + """ + + if target: + pass + else: + target = llvm_target() Review Comment: ``` if target is None: target = llvm_target() ``` ########## src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc: ########## @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include <ATen/DLConvertor.h> +#include <dlpack/dlpack.h> +#include <dmlc/memory_io.h> +#include <torch/custom_class.h> +#include <torch/script.h> +#include <tvm/runtime/module.h> +#include <tvm/runtime/registry.h> +#include <tvm/target/codegen.h> +#include <tvm/target/target.h> + +#include <cstdio> +#include <map> +#include <string> +#include <vector> + +#include "../../../runtime/graph_executor/graph_executor_factory.h" +#include "../base64.h" + +namespace tvm { +namespace contrib { + +/** + * We pass the TVM module by TVM's FFI because Torch's FFI cannot recognize such TVM objects + */ +struct ThreadLocalStore { + tvm::runtime::Module mod; + static ThreadLocalStore* ThreadLocal() { + thread_local ThreadLocalStore tls; + return &tls; + } +}; + +using SerializationType = std::string; // base64 stream + +SerializationType serialize(tvm::runtime::Module module) { + static const runtime::PackedFunc* f_to_str = + runtime::Registry::Get("script_torch.save_to_base64"); + ICHECK(f_to_str) << "IndexError: Cannot find the packed function " + "`script_torch.save_to_tar` in the global registry"; + return (*f_to_str)(module); +} + +struct Deleter { // deleter + explicit Deleter(std::string file_name) { this->file_name = file_name; } + void operator()(FILE* p) const { + fclose(p); + ICHECK(remove(file_name.c_str()) == 0) Review Comment: where does this `remove` function come from? ########## python/tvm/contrib/torch/optimize_torch.py: ########## @@ -0,0 +1,143 @@ +# pylint: disable=inconsistent-return-statements +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring +""" +optimize_torch: aa function similar to `torch.jit.trace`, +which is used to optimize the `torch.nn.module` by TVM metaSchedule, +and returns a custom TorchScript operator +""" +import base64 +import contextlib +import tempfile +from typing import Tuple + +import torch +import torch.utils.dlpack + +import tvm +from tvm import relay +from tvm._ffi import get_global_func, register_func +from tvm.meta_schedule import TuneConfig +from tvm.meta_schedule.tune import tune_relay + + +# The python wrapper for GraphExecutorFactory +class GraphExecutorFactoryWrapper(torch.nn.Module): + def __init__(self, module: tvm.runtime.Module): + super().__init__() + self.inner_module = module + + def forward(self, *torch_inputs: Tuple[torch.Tensor]): + ret = self.inner_module.forward(torch_inputs) + if len(ret) == 1: + return ret[0] + return ret + + +def llvm_target(): + return "llvm -num-cores" + + +@register_func("script_torch.save_to_base64") +def save_to_base64(obj) -> bytes: + with tempfile.NamedTemporaryFile(suffix=".so") as tmpfile: + obj.export_library(tmpfile.name) + with open(tmpfile.name, "rb") as tfile: + return base64.b64encode(tfile.read()) + + +def optimize_torch( + func, + example_inputs, + tuning_config=None, + target=None, + work_dir=None, +): + """Load PyTorch model that could be traced by TorchScript, then optimize it via MetaSchedule. 
+ + Parameters + ---------- + func : callable or torch.nn.Module + A Python function or nn.Module that could run by TorchScript's trace. + (ie: torch.jit.trace(model, input)) + + example_inputs : tuple or torch.Tensor + A tuple of example inputs that + will run together with `func` by providing the shape information. + + tuning_config : tvm.meta_schedule.TuneConfig + The configuration of tuning by MetaSchedule. + We suggest users to provide their own setting, + otherwise by default setting a tuning process could be very slow, + sometimes costs a few hours. + + target : Optional[Union[str, Target]] + The target of the compilation. + If user doesn't set the target, the module is built upon the LLVM. + + work_dir : Optional[str] + The working directory to save intermediate results. + + Returns + ------- + mod : GraphExecutorFactoryWrapper + It will return an object of GraphExecutorFactoryWrapper, + which is the subclass of the original nn.Module. + """ + + if target: + pass + else: + target = llvm_target() + + if tuning_config: + pass + else: Review Comment: ``` if tuning_config is None: tuning_config = ... ``` ########## src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc: ########## @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include <ATen/DLConvertor.h> +#include <dlpack/dlpack.h> +#include <dmlc/memory_io.h> +#include <torch/custom_class.h> +#include <torch/script.h> +#include <tvm/runtime/module.h> +#include <tvm/runtime/registry.h> +#include <tvm/target/codegen.h> +#include <tvm/target/target.h> + +#include <cstdio> +#include <map> +#include <string> +#include <vector> + +#include "../../../runtime/graph_executor/graph_executor_factory.h" +#include "../base64.h" + +namespace tvm { +namespace contrib { + +/** + * We pass the TVM module by TVM's FFI because Torch's FFI cannot recognize such TVM objects + */ +struct ThreadLocalStore { + tvm::runtime::Module mod; + static ThreadLocalStore* ThreadLocal() { + thread_local ThreadLocalStore tls; + return &tls; + } +}; + +using SerializationType = std::string; // base64 stream + +SerializationType serialize(tvm::runtime::Module module) { + static const runtime::PackedFunc* f_to_str = + runtime::Registry::Get("script_torch.save_to_base64"); + ICHECK(f_to_str) << "IndexError: Cannot find the packed function " + "`script_torch.save_to_tar` in the global registry"; Review Comment: typo `save_to_tar` ########## python/tvm/contrib/torch/optimize_torch.py: ########## @@ -0,0 +1,143 @@ +# pylint: disable=inconsistent-return-statements +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring +# pylint: disable=missing-class-docstring +# pylint: disable=missing-function-docstring +""" +optimize_torch: aa function similar to `torch.jit.trace`, +which is used to optimize the `torch.nn.module` by TVM metaSchedule, +and returns a custom TorchScript operator +""" +import base64 +import contextlib +import tempfile +from typing import Tuple + +import torch +import torch.utils.dlpack + +import tvm +from tvm import relay +from tvm._ffi import get_global_func, register_func +from tvm.meta_schedule import TuneConfig +from tvm.meta_schedule.tune import tune_relay + + +# The python wrapper for GraphExecutorFactory +class GraphExecutorFactoryWrapper(torch.nn.Module): + def __init__(self, module: tvm.runtime.Module): + super().__init__() + self.inner_module = module + + def forward(self, *torch_inputs: Tuple[torch.Tensor]): + ret = self.inner_module.forward(torch_inputs) + if len(ret) == 1: + return ret[0] + return ret + + +def llvm_target(): + return "llvm -num-cores" + + +@register_func("script_torch.save_to_base64") +def save_to_base64(obj) -> bytes: + with tempfile.NamedTemporaryFile(suffix=".so") as tmpfile: + obj.export_library(tmpfile.name) + with open(tmpfile.name, "rb") as tfile: + return base64.b64encode(tfile.read()) + + +def optimize_torch( + func, + example_inputs, + tuning_config=None, + target=None, + work_dir=None, +): + """Load PyTorch model that could be traced by TorchScript, then optimize it via MetaSchedule. 
+ + Parameters + ---------- + func : callable or torch.nn.Module + A Python function or nn.Module that could run by TorchScript's trace. + (ie: torch.jit.trace(model, input)) + + example_inputs : tuple or torch.Tensor + A tuple of example inputs that + will run together with `func` by providing the shape information. + + tuning_config : tvm.meta_schedule.TuneConfig + The configuration of tuning by MetaSchedule. + We suggest users to provide their own setting, + otherwise by default setting a tuning process could be very slow, + sometimes costs a few hours. + + target : Optional[Union[str, Target]] + The target of the compilation. + If user doesn't set the target, the module is built upon the LLVM. + + work_dir : Optional[str] + The working directory to save intermediate results. + + Returns + ------- + mod : GraphExecutorFactoryWrapper + It will return an object of GraphExecutorFactoryWrapper, + which is the subclass of the original nn.Module. + """ + + if target: + pass + else: + target = llvm_target() + + if tuning_config: + pass + else: + # Default setting. For a better tuning result the number could be set large. + tuning_config = TuneConfig( + strategy="evolutionary", + num_trials_per_iter=64, + max_trials_per_task=2000, + max_trials_global=2000, + ) Review Comment: I highly doubt this default config would be useful. For an e2e model this is probably not enough, and for a single op it can be too much, resulting in unnecessarily long tuning. It also depends on the target being tuned for. I think we need to come up with a more intelligent default config, or make the config a required argument. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
