This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e7c04f554b [Refactor] Introduce base Executable class and `tvm.compile` interface (#17710)
e7c04f554b is described below
commit e7c04f554b81d9f7059269b78b6c645cf43b5540
Author: Siyuan Feng <[email protected]>
AuthorDate: Fri Mar 7 21:00:48 2025 +0800
[Refactor] Introduce base Executable class and `tvm.compile` interface (#17710)
This refactor introduces a base Executable class and a `tvm.compile`
interface that can be used to compile both TIR and Relax programs.
`tvm.compile` will return an Executable object that can be used to call
either TIR or Relax functions.
---
docs/how_to/tutorials/optimize_llm.py | 2 +-
include/tvm/relax/exec_builder.h | 6 +-
include/tvm/runtime/relax_vm/executable.h | 30 +--
include/tvm/runtime/relax_vm/vm.h | 2 +-
python/tvm/__init__.py | 2 +-
python/tvm/contrib/hexagon/session.py | 38 +--
python/tvm/driver/__init__.py | 4 +-
python/tvm/driver/build_module.py | 85 ++++++-
python/tvm/meta_schedule/relax_integration.py | 6 +-
.../meta_schedule/testing/custom_builder_runner.py | 17 +-
python/tvm/meta_schedule/testing/tune_utils.py | 4 +-
python/tvm/relax/__init__.py | 2 +-
python/tvm/relax/exec_builder.py | 6 +-
python/tvm/relax/frontend/nn/core.py | 2 +-
python/tvm/relax/pipeline.py | 1 +
.../transform/tuning_api/default_functions.py | 2 +-
python/tvm/relax/vm_build.py | 129 ++--------
python/tvm/runtime/__init__.py | 5 +-
python/tvm/runtime/executable.py | 166 +++++++++++++
python/tvm/runtime/relax_vm.py | 17 +-
python/tvm/tir/build.py | 2 +-
python/tvm/tir/pipeline.py | 4 +-
src/relax/backend/vm/codegen_vm.cc | 6 +-
src/relax/backend/vm/exec_builder.cc | 10 +-
src/runtime/relax_vm/executable.cc | 56 ++---
src/runtime/relax_vm/vm.cc | 6 +-
tests/python/driver/test_compile.py | 121 ++++++++++
tests/python/runtime/test_executable.py | 263 +++++++++++++++++++++
tests/scripts/task_python_unittest.sh | 7 +-
29 files changed, 764 insertions(+), 237 deletions(-)
diff --git a/docs/how_to/tutorials/optimize_llm.py
b/docs/how_to/tutorials/optimize_llm.py
index 49855910fc..8cc674920d 100644
--- a/docs/how_to/tutorials/optimize_llm.py
+++ b/docs/how_to/tutorials/optimize_llm.py
@@ -426,7 +426,7 @@ def _pipeline( # pylint: disable=too-many-arguments
with target:
- ex = relax.build(mod, target, pipeline=relax.get_pipeline("opt_llm"))
+ ex = tvm.compile(mod, target, relax_pipeline=relax.get_pipeline("opt_llm"))
vm = relax.VirtualMachine(ex, dev)
diff --git a/include/tvm/relax/exec_builder.h b/include/tvm/relax/exec_builder.h
index 03e58392c2..8940408f80 100644
--- a/include/tvm/relax/exec_builder.h
+++ b/include/tvm/relax/exec_builder.h
@@ -125,12 +125,12 @@ class ExecBuilderNode : public Object {
/*!
* \brief Raw access to underlying executable build in progress.
*/
- vm::Executable* exec() const;
+ vm::VMExecutable* exec() const;
/*!
* \brief Finalize the build, run formalize and get the final result.
* \note This function should not be called during construction.
*/
- ObjectPtr<vm::Executable> Get();
+ ObjectPtr<vm::VMExecutable> Get();
/*!
* \brief Create an ExecBuilder.
* \return The ExecBuilder.
@@ -165,7 +165,7 @@ class ExecBuilderNode : public Object {
void Formalize();
/*! \brief The mutable internal executable. */
- ObjectPtr<vm::Executable> exec_; // mutable
+ ObjectPtr<vm::VMExecutable> exec_; // mutable
/*! \brief internal dedup map when creating index for a new constant */
std::unordered_map<ObjectRef, vm::Index, StructuralHash, StructuralEqual>
const_dedup_map_;
};
diff --git a/include/tvm/runtime/relax_vm/executable.h
b/include/tvm/runtime/relax_vm/executable.h
index e953d94eb9..7028a62329 100644
--- a/include/tvm/runtime/relax_vm/executable.h
+++ b/include/tvm/runtime/relax_vm/executable.h
@@ -80,12 +80,12 @@ struct VMFuncInfo {
};
/*!
- * \brief The executable emitted by the VM compiler.
+ * \brief The virtual machine executable emitted by the VM compiler.
*
* The executable contains information (e.g. data in different memory regions)
* to run in a virtual machine.
*/
-class Executable : public runtime::ModuleNode {
+class VMExecutable : public runtime::ModuleNode {
public:
/*! \brief Get the property of the runtime module .*/
int GetPropertyMask() const final { return
ModulePropertyMask::kBinarySerializable; };
@@ -120,18 +120,18 @@ class Executable : public runtime::ModuleNode {
*/
String AsPython() const;
/*!
- * \brief Write the Executable to the binary stream in serialized form.
+ * \brief Write the VMExecutable to the binary stream in serialized form.
* \param stream The binary stream to save the executable to.
*/
void SaveToBinary(dmlc::Stream* stream) final;
/*!
- * \brief Load Executable from the binary stream in serialized form.
+ * \brief Load VMExecutable from the binary stream in serialized form.
* \param stream The binary stream that load the executable from.
* \return The loaded executable, in the form of a `runtime::Module`.
*/
static Module LoadFromBinary(void* stream);
/*!
- * \brief Write the Executable to the provided path as a file containing its
serialized content.
+ * \brief Write the VMExecutable to the provided path as a file containing
its serialized content.
* \param file_name The name of the file to write the serialized data to.
* \param format The target format of the saved file.
*/
@@ -140,10 +140,10 @@ class Executable : public runtime::ModuleNode {
Module VMLoadExecutable() const;
/*! \brief Create a Relax virtual machine with profiler and load `this` as
the executable. */
Module VMProfilerLoadExecutable() const;
- /*! \brief Check if the Executable contains a specific function. */
+ /*! \brief Check if the VMExecutable contains a specific function. */
bool HasFunction(const String& name) const;
/*!
- * \brief Load Executable from the file.
+ * \brief Load VMExecutable from the file.
* \param file_name The path of the file that load the executable from.
* \return The loaded executable, in the form of a `runtime::Module`.
*/
@@ -160,15 +160,15 @@ class Executable : public runtime::ModuleNode {
/*! \brief The byte data of instruction. */
std::vector<ExecWord> instr_data;
- virtual ~Executable() {}
+ virtual ~VMExecutable() {}
- TVM_MODULE_VTABLE_BEGIN("relax.Executable");
- TVM_MODULE_VTABLE_ENTRY("stats", &Executable::Stats);
- TVM_MODULE_VTABLE_ENTRY("as_text", &Executable::AsText);
- TVM_MODULE_VTABLE_ENTRY("as_python", &Executable::AsPython);
- TVM_MODULE_VTABLE_ENTRY("vm_load_executable", &Executable::VMLoadExecutable);
- TVM_MODULE_VTABLE_ENTRY("vm_profiler_load_executable",
&Executable::VMProfilerLoadExecutable);
- TVM_MODULE_VTABLE_ENTRY("has_function", &Executable::HasFunction);
+ TVM_MODULE_VTABLE_BEGIN("relax.VMExecutable");
+ TVM_MODULE_VTABLE_ENTRY("stats", &VMExecutable::Stats);
+ TVM_MODULE_VTABLE_ENTRY("as_text", &VMExecutable::AsText);
+ TVM_MODULE_VTABLE_ENTRY("as_python", &VMExecutable::AsPython);
+ TVM_MODULE_VTABLE_ENTRY("vm_load_executable",
&VMExecutable::VMLoadExecutable);
+ TVM_MODULE_VTABLE_ENTRY("vm_profiler_load_executable",
&VMExecutable::VMProfilerLoadExecutable);
+ TVM_MODULE_VTABLE_ENTRY("has_function", &VMExecutable::HasFunction);
TVM_MODULE_VTABLE_END();
private:
diff --git a/include/tvm/runtime/relax_vm/vm.h
b/include/tvm/runtime/relax_vm/vm.h
index 607269d812..7bf716ae50 100644
--- a/include/tvm/runtime/relax_vm/vm.h
+++ b/include/tvm/runtime/relax_vm/vm.h
@@ -143,7 +143,7 @@ class VirtualMachine : public runtime::ModuleNode {
* \brief Load the executable for the virtual machine.
* \param exec The executable.
*/
- virtual void LoadExecutable(ObjectPtr<Executable> exec) = 0;
+ virtual void LoadExecutable(ObjectPtr<VMExecutable> exec) = 0;
/*!
* \brief Get global function in the VM.
* \param func_name The name of the function.
diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index f4519f834d..b853c4fa61 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -55,7 +55,7 @@ from . import target
from . import te
# tvm.driver
-from .driver import build
+from .driver import build, compile
# others
from . import arith
diff --git a/python/tvm/contrib/hexagon/session.py
b/python/tvm/contrib/hexagon/session.py
index 2456aa244e..4f017abcf6 100644
--- a/python/tvm/contrib/hexagon/session.py
+++ b/python/tvm/contrib/hexagon/session.py
@@ -24,11 +24,12 @@ import tempfile
from typing import Union
import tvm
-from tvm import relax
+import tvm.contrib.hexagon as hexagon
from tvm import rpc as _rpc
+from tvm import runtime
from tvm.contrib import utils
-import tvm.contrib.hexagon as hexagon
-from .tools import export_module, HEXAGON_SIMULATOR_NAME
+
+from .tools import HEXAGON_SIMULATOR_NAME, export_module
class Session:
@@ -202,26 +203,26 @@ class Session:
return
self._rpc.get_function("tvm.hexagon.load_module")(str(remote_file_path))
def get_executor_from_factory(
- self, module: Union[ExecutorFactoryModule, relax.Executable, str], hexagon_arch: str = "v68"
+ self, module: Union[runtime.Executable, str], hexagon_arch: str = "v68"
):
"""Create a local GraphModule which consumes a remote libmod.
Parameters
----------
- module : Union[relax.Executable]
+ module : Union[runtime.Executable, str]
The module to upload to the remote
session and load.
hexagon_arch : str
The hexagon arch to be used
"""
- if isinstance(module, (relax.Executable, str)):
+ if isinstance(module, (runtime.Executable, str)):
return self._relax_vm_executable_executor(module, hexagon_arch=hexagon_arch)
raise TypeError(f"Unsupported executor type: {type(module)}")
- def _set_device_type(self, module: Union[str, pathlib.Path,
GraphExecutorFactoryModule]):
+ def _set_device_type(self, module: Union[str, pathlib.Path]):
"""Set session device type(hexagon, cpu) based on target in module.
Parameters
@@ -244,18 +245,19 @@ class Session:
self._requires_cpu_device = False
def _relax_vm_executable_executor(
- self, vm_exec: Union[relax.Executable, str], hexagon_arch: str
+ self, executable: Union[runtime.Executable, str], hexagon_arch: str
):
"""Create a local TVM module which consumes a remote vm executable.
- Paramters
- ---------
+ Parameters
+ ----------
- vm_exec : relax.Executable
- The Relax VM Executable to upload to the remote and load. This
will typically be the
- output of `relax.build` or the path to an already built and
exported shared library
+ executable : runtime.Executable
+ The Executable to upload to the remote and load. This will
typically be the
+ output of `tvm.compile` or the path to an already built and
exported shared library
hexagon_arch : str
The hexagon arch to be used
+
Returns
-------
TVMModule :
@@ -263,21 +265,21 @@ class Session:
"""
assert self._rpc is not None, "Hexagon session must be started using
__enter__ prior to use"
- if isinstance(vm_exec, relax.Executable):
+ if isinstance(executable, runtime.Executable):
temp_dir = utils.tempdir()
path_exec = temp_dir.relpath("exec.so")
- vm_exec.mod.export_library(
+ executable.export_library(
path_exec,
fcompile=hexagon.create_aot_shared,
hexagon_arch=hexagon_arch,
)
path = self.upload(path_exec, "exec.so")
- elif isinstance(vm_exec, str):
- path_exec = vm_exec
+ elif isinstance(executable, str):
+ path_exec = executable
else:
- raise TypeError(f"Unsupported executor type: {type(vm_exec)}")
+ raise TypeError(f"Unsupported executor type: {type(executable)}")
path = self.upload(path_exec, "exec.so")
return self._rpc.get_function("tvm.hexagon.load_module")(str(path))
diff --git a/python/tvm/driver/__init__.py b/python/tvm/driver/__init__.py
index b97375c3a3..6a4a2ba9f9 100644
--- a/python/tvm/driver/__init__.py
+++ b/python/tvm/driver/__init__.py
@@ -14,5 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+# pylint: disable=redefined-builtin
+
"""Namespace for driver APIs"""
-from .build_module import build
+from .build_module import build, compile
diff --git a/python/tvm/driver/build_module.py
b/python/tvm/driver/build_module.py
index 8d6a2a5343..ea923aae9a 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -17,16 +17,95 @@
# pylint: disable=invalid-name
"""The build utils in python."""
-from typing import Union, Optional
+import warnings
+from typing import Callable, Optional, Union
+
import tvm
-from tvm.tir import PrimFunc
from tvm.ir.module import IRModule
+from tvm.runtime import Executable
from tvm.target import Target
+from tvm.tir import PrimFunc
def build(
mod: Union[PrimFunc, IRModule],
target: Optional[Union[str, Target]] = None,
- pipeline: Optional[Union[str, tvm.transform.Pass]] = "default_tir",
+ pipeline: Optional[Union[str, tvm.transform.Pass]] = "default",
):
+ """
+ Build a function with a signature, generating code for devices
+ coupled with target information.
+
+ This function is deprecated. Use `tvm.compile` or `tvm.tir.build` instead.
+
+ Parameters
+ ----------
+ mod : Union[PrimFunc, IRModule]
+ The input to be built.
+ target : Optional[Union[str, Target]]
+ The target for compilation.
+ pipeline : Optional[Union[str, tvm.transform.Pass]]
+ The pipeline to use for compilation.
+
+ Returns
+ -------
+ tvm.runtime.Module
+ A module combining both host and device code.
+ """
+ warnings.warn(
+ "build is deprecated. Use `tvm.compile` or `tvm.tir.build` instead.",
+ DeprecationWarning,
+ )
return tvm.tir.build(mod, target, pipeline)
+
+
+def _contains_relax(mod: Union[PrimFunc, IRModule]) -> bool:
+ if isinstance(mod, PrimFunc):
+ return False
+ if isinstance(mod, IRModule):
+ return any(isinstance(func, tvm.relax.Function) for _, func in
mod.functions_items())
+
+ raise ValueError(f"Function input must be a PrimFunc or IRModule, but got
{type(mod)}")
+
+
+def compile(  # pylint: disable=redefined-builtin
+    mod: Union[PrimFunc, IRModule],
+    target: Optional[Union[str, Target]] = None,
+    *,
+    relax_pipeline: Optional[Union[tvm.transform.Pass, Callable, str]] = "default",
+    tir_pipeline: Optional[Union[tvm.transform.Pass, Callable, str]] = "default",
+) -> Executable:
+    """
+    Compile an IRModule to a runtime executable.
+
+    This function serves as a unified entry point for compiling both TIR and Relax modules.
+    It automatically detects the module type and routes to the appropriate build function.
+
+    Parameters
+    ----------
+    mod : Union[PrimFunc, IRModule]
+        The input module to be compiled. Can be a PrimFunc or an IRModule containing
+        TIR or Relax functions.
+    target : Optional[Union[str, Target]]
+        The target platform to compile for, as a Target object or a target string.
+    relax_pipeline : Optional[Union[tvm.transform.Pass, Callable, str]]
+        The compilation pipeline to use for Relax functions.
+        Only used if the module contains Relax functions.
+    tir_pipeline : Optional[Union[tvm.transform.Pass, Callable, str]]
+        The compilation pipeline to use for TIR functions.
+
+    Returns
+    -------
+    Executable
+        A runtime executable that can be loaded and executed.
+    """
+    # TODO(tvm-team): combine the two paths into a unified one
+    if _contains_relax(mod):
+        return tvm.relax.build(
+            mod,
+            target,
+            relax_pipeline=relax_pipeline,
+            tir_pipeline=tir_pipeline,
+        )
+    lib = tvm.tir.build(mod, target, pipeline=tir_pipeline)
+    return Executable(lib)
diff --git a/python/tvm/meta_schedule/relax_integration.py
b/python/tvm/meta_schedule/relax_integration.py
index c3c24aa631..9c293a1654 100644
--- a/python/tvm/meta_schedule/relax_integration.py
+++ b/python/tvm/meta_schedule/relax_integration.py
@@ -382,7 +382,7 @@ def compile_relax(
target: Union[Target, str],
params: Optional[Dict[str, NDArray]],
enable_warning: bool = False,
-) -> "relax.Executable":
+) -> "relax.VMExecutable":
"""Compile a relax program with a MetaSchedule database.
Parameters
@@ -401,8 +401,8 @@ def compile_relax(
Returns
-------
- lib : relax.Executable
- The built runtime module or vm Executable for the given relax workload.
+ lib : relax.VMExecutable
+ The built runtime module or vm VMExecutable for the given relax
workload.
"""
# pylint: disable=import-outside-toplevel
from tvm.relax.transform import BindParams, MetaScheduleApplyDatabase
diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py
b/python/tvm/meta_schedule/testing/custom_builder_runner.py
index 7e7a3a1d9d..2da672b405 100644
--- a/python/tvm/meta_schedule/testing/custom_builder_runner.py
+++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py
@@ -17,21 +17,18 @@
"""Customized builder and runner methods"""
# pylint: disable=import-outside-toplevel
-from typing import TYPE_CHECKING, Dict, Union, Callable
+from typing import Dict, Union, Callable
-if TYPE_CHECKING:
- import numpy as np # type: ignore
- from tvm.ir import IRModule
- from tvm.meta_schedule.runner import EvaluatorConfig, RPCConfig
- from tvm.runtime import Device, Module, NDArray
- from tvm.target import Target
+import numpy as np # type: ignore
+from tvm.meta_schedule.runner import RPCConfig
+from tvm.runtime import Module, Executable
def run_module_via_rpc(
- rpc_config: "RPCConfig",
- lib: Union["Module", "Executable"],
+ rpc_config: RPCConfig,
+ lib: Union[Module, Executable],
dev_type: str,
- args: Union[Dict[int, "np.ndarray"], Dict[str, "np.ndarray"]],
+ args: Union[Dict[int, np.ndarray], Dict[str, np.ndarray]],
continuation: Callable,
):
"""Execute a tvm.runtime.Module on RPC remote"""
diff --git a/python/tvm/meta_schedule/testing/tune_utils.py
b/python/tvm/meta_schedule/testing/tune_utils.py
index cb97b221b2..08618a289d 100644
--- a/python/tvm/meta_schedule/testing/tune_utils.py
+++ b/python/tvm/meta_schedule/testing/tune_utils.py
@@ -87,8 +87,8 @@ def create_calculator(backend: str) -> Callable:
Parameters
----------
- rt_mod : Union[tvm.runtime.Module, tvm.runtime.vm.Executable]
- The runtime module or vm executable.
+ rt_mod : tvm.runtime.Module
+ The runtime module.
dev : tvm.device
The device type to run workload.
input_data : Dict[str, np.ndarray]
diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py
index 8494bd8e58..da288942ed 100644
--- a/python/tvm/relax/__init__.py
+++ b/python/tvm/relax/__init__.py
@@ -114,6 +114,6 @@ from . import frontend
from . import utils
# VM
-from .vm_build import build, Executable
+from .vm_build import build, VMExecutable
from .binding_rewrite import DataflowBlockRewrite
diff --git a/python/tvm/relax/exec_builder.py b/python/tvm/relax/exec_builder.py
index 140c497eb9..6998607860 100644
--- a/python/tvm/relax/exec_builder.py
+++ b/python/tvm/relax/exec_builder.py
@@ -21,7 +21,7 @@ from typing import Optional, Union, List
import tvm
from tvm.runtime import Object
from tvm.runtime.container import ShapeTuple
-from .vm_build import Executable
+from .vm_build import VMExecutable
from . import _ffi_api
@@ -142,6 +142,6 @@ class ExecBuilder(Object):
self._check_scope()
_ffi_api.ExecBuilderEmitIf(self, cond, false_offset) # type: ignore
- def get(self) -> Executable:
+ def get(self) -> VMExecutable:
"""return the executable"""
- return Executable(_ffi_api.ExecBuilderGet(self)) # type: ignore
+ return VMExecutable(_ffi_api.ExecBuilderGet(self)) # type: ignore
diff --git a/python/tvm/relax/frontend/nn/core.py
b/python/tvm/relax/frontend/nn/core.py
index 21118b1cb8..c25b683844 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -522,7 +522,7 @@ class Module(SubroutineMixin):
relax_build(
mod,
target=Target.from_device(device),
- pipeline=pipeline,
+ relax_pipeline=pipeline,
),
device,
)
diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py
index ffb38cdd93..ddf88e8bd0 100644
--- a/python/tvm/relax/pipeline.py
+++ b/python/tvm/relax/pipeline.py
@@ -194,6 +194,7 @@ def static_shape_tuning_pipeline(
# global map of pre-built pipelines
PIPELINE_MAP = {
"zero": zero_pipeline,
+ "default": default_build_pipeline,
"default_build": default_build_pipeline,
"static_shape_tuning": static_shape_tuning_pipeline,
}
diff --git a/python/tvm/relax/transform/tuning_api/default_functions.py
b/python/tvm/relax/transform/tuning_api/default_functions.py
index 7cdb211bd3..efdc9b13e3 100644
--- a/python/tvm/relax/transform/tuning_api/default_functions.py
+++ b/python/tvm/relax/transform/tuning_api/default_functions.py
@@ -185,7 +185,7 @@ def default_evaluate(
if runner is None:
def relax_eval_func(rt_mod, device, evaluator_config, repeated_args):
- relax_exec = tvm.relax.Executable(rt_mod)
+ relax_exec = tvm.relax.VMExecutable(rt_mod)
relax_vm = tvm.relax.VirtualMachine(relax_exec, device=device)
evaluator = relax_vm.module.time_evaluator(
diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py
index ac4d9698a0..f44fcb9c22 100644
--- a/python/tvm/relax/vm_build.py
+++ b/python/tvm/relax/vm_build.py
@@ -16,22 +16,22 @@
# under the License.
# pylint: disable=invalid-name, no-member
"""VM build logics"""
-from typing import Any, Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union
import tvm
from tvm import relax
-from tvm.contrib import utils as _utils
from tvm.ir.module import IRModule
from tvm.tir.function import PrimFunc
+from tvm.runtime import Executable
from . import _ffi_api
-class Executable:
- """The executable object emitted by the VM compiler or the ExecBuilder."""
+class VMExecutable(Executable):
+ """The virtual machine executable object emitted by the VM compiler or the
ExecBuilder."""
def __init__(self, mod: tvm.runtime.Module):
- self.mod = mod
+ super().__init__(mod)
self._stats = self.mod["stats"]
self._as_text = self.mod["as_text"]
self._as_python = self.mod["as_python"]
@@ -48,105 +48,6 @@ class Executable:
"""print the instructions as python program."""
return self._as_python()
- def jit(self, fcompile=None, addons=None, **kwargs) -> tvm.runtime.Module:
- """Just-in-time compile and link the modules.
-
- The Executable returned by relax.build may not be directly
- runnable as they may contain cuda source files and objects that
- are yet to be compiled and linked.
- This function helps to create a runtime.Module for these cases.
-
- Parameters
- ----------
- fcompile : function(target, file_list, kwargs), optional
- The compilation function to use create the final library object
during
-
- kwargs : dict, optional
- Additional arguments passed to fcompile
-
- Returns
- -------
- rt_mod: tvm.runtime.Module
- A runnable runtime module that can be passed to VirtualMachine.
-
- Examples
- --------
- .. code:: python
-
- ex = relax.build(mod, target)
- # build a runnable module using nvcc to link everything
- rt_mod = ex.jit()
- vm = tvm.relax.VirtualMachine(rt_mod, tvm.cuda())
- """
-
- # TODO(tvm-team): Update runtime.Module interface
- # to query these properties as bitmask.
- def _not_runnable(x):
- return x.type_key in ("c", "static_library")
-
- # pylint:disable = protected-access
- not_runnable_list = self.mod._collect_from_import_tree(_not_runnable)
-
- # everything is runnable, directly return mod.
- if len(not_runnable_list) == 0:
- return self.mod
-
- # found source module, or other not runnable modules
- # need to be export and load
- # TODO(tvm-team): Support runnable but not exportable module.
- # by collecting the link and allow export_library skip those modules.
- workspace_dir = _utils.tempdir()
- dso_path = workspace_dir.relpath("exported.so")
- self.mod.export_library(dso_path, fcompile=fcompile, addons=addons,
**kwargs)
- return tvm.runtime.load_module(dso_path)
-
- def export_library(
- self,
- file_name: str,
- fcompile: Optional[Union[str, callable]] = None,
- workspace_dir: Optional[str] = None,
- **kwargs,
- ) -> Any:
- """Export the executable to a library which can then be loaded back.
-
- Parameters
- ----------
- file_name : str
- The name of the shared library.
-
- fcompile : function(target, file_list, kwargs), optional
- The compilation function to use create the final library object
during
-
- workspace_dir : str, optional
- The path of the directory used to create the intermediate
- artifacts when exporting the module.
- If this is not provided a temporary dir will be created.
-
- kwargs : dict, optional
- Additional arguments passed to fcompile
-
- Returns
- -------
- result of fcompile() : unknown, optional
- If the compilation function returns an artifact it would be
returned via
- export_library, if any.
-
- Examples
- --------
- .. code:: python
-
- ex = relax.build(mod, target)
- # export the library
- ex.export_library("exported.so")
-
- # load it back for future uses.
- rt_mod = tvm.runtime.load_module("exported.so")
- vm = tvm.relax.VirtualMachine(rt_mod, tvm.cuda())
- """
- return self.mod.export_library(
- file_name=file_name, fcompile=fcompile,
workspace_dir=workspace_dir, **kwargs
- )
-
def _vmcodegen(
builder: "relax.ExecBuilder",
@@ -202,6 +103,7 @@ def _vmlink(
builder: "relax.ExecBuilder",
target: Optional[Union[str, tvm.target.Target]],
tir_mod: Optional[tvm.IRModule] = None,
+ tir_pipeline: Optional[Union[str, tvm.transform.Pass]] = "default",
ext_libs: List[tvm.runtime.Module] = None,
params: Optional[Dict[str, list]] = None,
*,
@@ -249,7 +151,7 @@ def _vmlink(
tir_ext_libs = []
if tir_mod is not None and len(tir_mod.get_global_vars()) > 0:
tir_mod = _auto_attach_system_lib_prefix(tir_mod, target, system_lib)
- lib = tvm.build(tir_mod, target=target)
+ lib = tvm.tir.build(tir_mod, target=target, pipeline=tir_pipeline)
for ext_mod in ext_libs:
if ext_mod.is_device_module:
tir_ext_libs.append(ext_mod)
@@ -260,14 +162,16 @@ def _vmlink(
lib.import_module(mod)
elif len(tir_ext_libs) > 0:
print("Warning: No TIR module is found, but external modules for TIR
are provided.")
- return Executable(_ffi_api.VMLink(builder, target, lib, relax_ext_libs,
params)) # type: ignore
+ lib = _ffi_api.VMLink(builder, target, lib, relax_ext_libs, params) #
type: ignore
+ return VMExecutable(lib)
def build(
mod: tvm.IRModule,
target: Optional[Union[str, tvm.target.Target]] = None,
params: Optional[Dict[str, list]] = None,
- pipeline: Union[None, str, tvm.transform.Pass] = "default_build",
+ relax_pipeline: Union[None, str, tvm.transform.Pass] = "default",
+ tir_pipeline: Union[None, str, tvm.transform.Pass] = "default",
exec_mode: str = "bytecode",
*,
system_lib: Optional[bool] = None,
@@ -336,14 +240,14 @@ def build(
if not params:
params = {}
- if pipeline is not None:
- if isinstance(pipeline, str):
- pipeline = relax.get_pipeline(pipeline)
+ if relax_pipeline is not None:
+ if isinstance(relax_pipeline, str):
+ relax_pipeline = relax.get_pipeline(relax_pipeline)
if target is None:
- mod = pipeline(mod)
+ mod = relax_pipeline(mod)
else:
with target:
- mod = pipeline(mod)
+ mod = relax_pipeline(mod)
ext_libs, constants = _extract_attrs(mod)
params.update(dict(constants))
@@ -353,6 +257,7 @@ def build(
builder=builder,
target=target,
tir_mod=_filter_tir(mod),
+ tir_pipeline=tir_pipeline,
ext_libs=ext_libs,
params=params,
system_lib=system_lib,
diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py
index b748f84bec..c7e407f028 100644
--- a/python/tvm/runtime/__init__.py
+++ b/python/tvm/runtime/__init__.py
@@ -23,13 +23,14 @@ from .object_path import ObjectPath, ObjectPathPair
from .script_printer import Scriptable
from .object_generic import ObjectGeneric, ObjectTypes
from .ndarray import NDArray, DataType, DataTypeCode, Device
-from .module import Module, num_threads
+from .module import Module
from .profiling import Report
+from .executable import Executable
# function exposures
from .ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, ext_dev
-from .module import load_module, enabled, system_lib, load_static_library
+from .module import load_module, enabled, system_lib, load_static_library,
num_threads
from .container import String, ShapeTuple # , BoxBool
from .object_generic import convert_to_object, convert, const
from .params import (
diff --git a/python/tvm/runtime/executable.py b/python/tvm/runtime/executable.py
new file mode 100644
index 0000000000..cf4e5b0587
--- /dev/null
+++ b/python/tvm/runtime/executable.py
@@ -0,0 +1,166 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, no-member
+
+"""Executable object for TVM Runtime"""
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import tvm
+from tvm.contrib import utils as _utils
+from . import PackedFunc, Module
+
+
+class Executable:
+    """The executable object generated by `tvm.compile`."""
+
+    def __init__(self, mod: Module):
+        """Initialize the Executable object."""
+        self.mod: Module = mod
+        self._jitted_mod: Optional[Module] = None
+        self.entry_name = mod.entry_name
+
+    def __getitem__(self, name: str) -> PackedFunc:
+        """Get the PackedFunc from the jitted module."""
+        return self.jit().get_function(name, query_imports=True)
+
+    def __call__(self, *args, **kwargs) -> Any:
+        """Call the executable."""
+        return self.jit().get_function(self.entry_name, query_imports=True)(*args, **kwargs)
+
+    def jit(
+        self,
+        *,
+        fcompile: Optional[Callable[[str, List[str], Dict[str, Any]], None]] = None,
+        addons: Optional[List[str]] = None,
+        force_recompile: bool = False,
+        **kwargs,
+    ) -> Module:
+        """Just-in-time compile and link the modules.
+
+        The Executable returned by tvm.compile may not be directly
+        runnable as it may contain CUDA source files and objects that
+        are yet to be compiled and linked.
+        This function helps to create a runtime.Module for these cases.
+
+        Parameters
+        ----------
+        fcompile : function(target, file_list, kwargs), optional
+            The compilation function used to create the final library object.
+
+        addons : list of str, optional
+            Additional object files to link against.
+
+        force_recompile : bool, optional
+            If True, ignore any cached module and recompile.
+
+        kwargs : dict, optional
+            Additional arguments passed to fcompile
+
+        Returns
+        -------
+        rt_mod : tvm.runtime.Module
+            A runnable runtime module that can be passed to VirtualMachine.
+
+        Examples
+        --------
+        .. code:: python
+
+            ex = tvm.compile(mod, target)
+            rt_mod = ex.jit()
+
+        """
+
+        # Reuse the cached module unless a recompile is explicitly requested.
+        if self._jitted_mod is not None and not force_recompile:
+            return self._jitted_mod
+
+        # TODO(tvm-team): Update runtime.Module interface
+        # to query these properties as bitmask.
+        def _not_runnable(x):
+            return x.type_key in ("c", "static_library")
+
+        # pylint:disable = protected-access
+        not_runnable_list = self.mod._collect_from_import_tree(_not_runnable)
+
+        # everything is runnable: no export needed, cache and return mod itself.
+        if len(not_runnable_list) == 0:
+            self._jitted_mod = self.mod
+            return self._jitted_mod
+
+        # found source modules or other non-runnable modules: export them and load the result back
+        # TODO(tvm-team): Support runnable but not exportable module.
+        # by collecting the link and allowing export_library to skip those modules.
+        workspace_dir = _utils.tempdir()
+        dso_path = workspace_dir.relpath("exported.so")
+        self.mod.export_library(dso_path, fcompile=fcompile, addons=addons, **kwargs)
+        self._jitted_mod = tvm.runtime.load_module(dso_path)
+        return self._jitted_mod
+
+    def export_library(
+        self,
+        file_name: str,
+        *,
+        fcompile: Optional[Union[str, Callable[[str, List[str], Dict[str, Any]], None]]] = None,
+        addons: Optional[List[str]] = None,
+        workspace_dir: Optional[str] = None,
+        **kwargs,
+    ) -> Any:
+        """Export the executable to a library which can then be loaded back.
+
+        Parameters
+        ----------
+        file_name : str
+            The name of the shared library.
+
+        fcompile : function(target, file_list, kwargs), optional
+            The compilation function used to create the final library object.
+
+        addons : list of str, optional
+            Additional object files to link against.
+
+        workspace_dir : str, optional
+            The path of the directory used to create the intermediate
+            artifacts when exporting the module.
+            If this is not provided a temporary dir will be created.
+
+        kwargs : dict, optional
+            Additional arguments passed to fcompile
+
+        Returns
+        -------
+        result of fcompile() : unknown, optional
+            If the compilation function returns an artifact it would be returned via
+            export_library, if any.
+
+        Examples
+        --------
+        .. code:: python
+
+            ex = tvm.compile(mod, target)
+            # export the library
+            ex.export_library("exported.so")
+
+            # load it back for future uses.
+            rt_mod = tvm.runtime.load_module("exported.so")
+        """
+        return self.mod.export_library(
+            file_name=file_name,
+            fcompile=fcompile,
+            addons=addons,
+            workspace_dir=workspace_dir,
+            **kwargs,
+        )
diff --git a/python/tvm/runtime/relax_vm.py b/python/tvm/runtime/relax_vm.py
index 5b8bbe6d33..1dc4aef6f4 100644
--- a/python/tvm/runtime/relax_vm.py
+++ b/python/tvm/runtime/relax_vm.py
@@ -44,7 +44,7 @@ class VirtualMachine(object):
def __init__(
self,
- rt_mod: Union[tvm.runtime.Module, "tvm.relax.Executable"],
+ rt_mod: Union[tvm.runtime.Module, tvm.runtime.Executable],
device: Union[Device, List[Device]],
memory_cfg: Optional[Union[str, Dict[Device, str]]] = None,
profile: bool = False,
@@ -54,7 +54,7 @@ class VirtualMachine(object):
Parameters
----------
- rt_mod: Union[tvm.runtime.Module, tvm.relax.Executable]
+ rt_mod: Union[tvm.runtime.Module, tvm.runtime.Executable]
Runtime module exported by the result of build.
device : Union[Device, List[Device]]
@@ -72,13 +72,7 @@ class VirtualMachine(object):
Whether or not to enable profiling.
"""
if not isinstance(rt_mod, tvm.runtime.Module):
- # important to keep this import local
- # as the relax_vm needs to be isolated from compiler
- # if we do not use the jit feature
- # pylint:disable=import-outside-toplevel
- from tvm import relax
-
- if isinstance(rt_mod, relax.Executable):
+ if isinstance(rt_mod, tvm.runtime.Executable):
rt_mod = rt_mod.jit()
else:
raise ValueError("Expect the rt_mod to be an runtime.Module")
@@ -101,10 +95,7 @@ class VirtualMachine(object):
devs = dev
if not isinstance(dev, (list, tuple)):
if not isinstance(dev, tvm.runtime.Device):
- raise TypeError(
- "dev is expected to be Device or \
- List[Device]"
- )
+ raise TypeError("dev is expected to be Device or List[Device]")
devs = [dev]
# CPU is required for executing shape functions
diff --git a/python/tvm/tir/build.py b/python/tvm/tir/build.py
index ee6280b740..14bc189b9f 100644
--- a/python/tvm/tir/build.py
+++ b/python/tvm/tir/build.py
@@ -105,7 +105,7 @@ def tir_to_runtime(
def build(
mod: Union[PrimFunc, IRModule],
target: Optional[Union[str, Target]] = None,
- pipeline: Union[None, str, tvm.transform.Pass] = "default_tir",
+ pipeline: Union[None, str, tvm.transform.Pass] = "default",
):
"""Build a function with a signature, generating code for devices
coupled with target information.
diff --git a/python/tvm/tir/pipeline.py b/python/tvm/tir/pipeline.py
index c8019c9229..ae78b05738 100644
--- a/python/tvm/tir/pipeline.py
+++ b/python/tvm/tir/pipeline.py
@@ -151,11 +151,11 @@ def finalize_device_passes(): # pylint:
disable=unused-argument
# global map of pre-built pipelines
PIPELINE_MAP = {
- "default_tir": default_tir_pipeline,
+ "default": default_tir_pipeline,
}
-def get_tir_pipeline(name: str = "default_tir", **kwargs) ->
tvm.transform.Pass:
+def get_tir_pipeline(name: str = "default", **kwargs) -> tvm.transform.Pass:
"""Get pre-build pipeline by name
Parameters
diff --git a/src/relax/backend/vm/codegen_vm.cc
b/src/relax/backend/vm/codegen_vm.cc
index 18da88be80..bd56b0fd7b 100644
--- a/src/relax/backend/vm/codegen_vm.cc
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -170,7 +170,7 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const
Expr&)> {
builder_->EmitCall("vm.builtin.read_if_cond", {cond_value}, cond_reg);
// obtain the temp exec in progress.
- vm::Executable* exec = builder_->exec();
+ vm::VMExecutable* exec = builder_->exec();
// Record the offset of If instruction
size_t if_offset = exec->instr_offset.size();
@@ -436,7 +436,7 @@
TVM_REGISTER_GLOBAL("relax.VMCodeGen").set_body_typed(VMCodeGen);
* module(s).
* \return The created module.
*/
-void LinkModules(ObjectPtr<Executable> exec, const Map<String,
runtime::NDArray>& params,
+void LinkModules(ObjectPtr<VMExecutable> exec, const Map<String,
runtime::NDArray>& params,
const tvm::runtime::Module& lib, const
Array<runtime::Module>& ext_libs) {
// query if we need const loader for ext_modules
// Wrap all submodules in the initialization wrapper.
@@ -482,7 +482,7 @@ void LinkModules(ObjectPtr<Executable> exec, const
Map<String, runtime::NDArray>
*/
Module VMLink(ExecBuilder builder, Target target, Optional<Module> lib,
Array<Module> ext_libs,
Map<String, runtime::NDArray> params) {
- ObjectPtr<Executable> executable = builder->Get();
+ ObjectPtr<VMExecutable> executable = builder->Get();
if (!lib.defined()) {
lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
}
diff --git a/src/relax/backend/vm/exec_builder.cc
b/src/relax/backend/vm/exec_builder.cc
index 0e6f59b460..36bfa7e242 100644
--- a/src/relax/backend/vm/exec_builder.cc
+++ b/src/relax/backend/vm/exec_builder.cc
@@ -33,13 +33,13 @@ TVM_REGISTER_NODE_TYPE(ExecBuilderNode);
ExecBuilder ExecBuilderNode::Create() {
ExecBuilder ret(make_object<ExecBuilderNode>());
- ret->exec_ = make_object<Executable>();
+ ret->exec_ = make_object<VMExecutable>();
return ret;
}
-Executable* ExecBuilderNode::exec() const { return exec_.get(); }
+VMExecutable* ExecBuilderNode::exec() const { return exec_.get(); }
-ObjectPtr<Executable> ExecBuilderNode::Get() {
+ObjectPtr<VMExecutable> ExecBuilderNode::Get() {
this->Formalize();
this->CheckExecutable();
return exec_;
@@ -270,7 +270,7 @@ void ExecBuilderNode::CheckExecutable() {
void ExecBuilderNode::Formalize() {
// a pass to formalize user-specified register indexes in the order of use
- // and decide the number of registers to allocate for each VMFunction in the Executable
+ // and decide the number of registers to allocate for each VMFunction in the VMExecutable
for (auto it = this->exec_->func_table.begin(); it !=
this->exec_->func_table.end(); ++it) {
if (it->kind == VMFuncInfo::FuncKind::kPackedFunc) continue;
if (it->kind == VMFuncInfo::FuncKind::kVMTIRFunc) continue;
@@ -395,7 +395,7 @@
TVM_REGISTER_GLOBAL("relax.ExecBuilderF").set_body_typed([](ExecBuilder builder,
});
TVM_REGISTER_GLOBAL("relax.ExecBuilderGet").set_body_typed([](ExecBuilder
builder) {
- ObjectPtr<Executable> p_exec = builder->Get();
+ ObjectPtr<VMExecutable> p_exec = builder->Get();
return runtime::Module(p_exec);
});
diff --git a/src/runtime/relax_vm/executable.cc
b/src/runtime/relax_vm/executable.cc
index f45786c3da..bf122cc04b 100644
--- a/src/runtime/relax_vm/executable.cc
+++ b/src/runtime/relax_vm/executable.cc
@@ -52,7 +52,7 @@ enum ConstantType : int {
ICHECK(val) << "Invalid VM file format in the " << section << " section." \
<< "\n";
-std::string Executable::Stats() const {
+std::string VMExecutable::Stats() const {
std::ostringstream oss;
oss << "Relax VM executable statistics:" << std::endl;
@@ -116,14 +116,14 @@ std::string Executable::Stats() const {
return oss.str();
}
-void Executable::SetInstructionData(Index i, Index j, ExecWord val) {
+void VMExecutable::SetInstructionData(Index i, Index j, ExecWord val) {
ICHECK_LT(i, instr_offset.size());
Index instr_idx = instr_offset[i];
ICHECK_LT(instr_idx + j, instr_data.size());
instr_data[instr_idx + j] = val;
}
-Instruction Executable::GetInstruction(Index i) const {
+Instruction VMExecutable::GetInstruction(Index i) const {
Index offset = instr_offset[i];
Opcode op = static_cast<Opcode>(instr_data[offset]);
switch (op) {
@@ -173,7 +173,7 @@ void LoadHeader(dmlc::Stream* strm) {
STREAM_CHECK(version == RELAX_VM_VERSION, "version");
}
-void Executable::SaveToBinary(dmlc::Stream* stream) {
+void VMExecutable::SaveToBinary(dmlc::Stream* stream) {
std::string code;
// Initialize the stream object.
dmlc::MemoryStringStream strm(&code);
@@ -193,20 +193,20 @@ void Executable::SaveToBinary(dmlc::Stream* stream) {
stream->Write(code);
}
-void Executable::SaveToFile(const String& file_name, const String& format) {
+void VMExecutable::SaveToFile(const String& file_name, const String& format) {
std::string data;
dmlc::MemoryStringStream writer(&data);
dmlc::SeekStream* strm = &writer;
- Executable::SaveToBinary(strm);
+ VMExecutable::SaveToBinary(strm);
runtime::SaveBinaryToFile(file_name, data);
}
-Module Executable::LoadFromBinary(void* stream) {
+Module VMExecutable::LoadFromBinary(void* stream) {
std::string code;
static_cast<dmlc::Stream*>(stream)->Read(&code);
dmlc::MemoryStringStream strm(&code);
- ObjectPtr<Executable> exec = make_object<Executable>();
+ ObjectPtr<VMExecutable> exec = make_object<VMExecutable>();
// Load header.
LoadHeader(&strm);
@@ -223,19 +223,19 @@ Module Executable::LoadFromBinary(void* stream) {
return Module(exec);
}
-TVM_REGISTER_GLOBAL("runtime.module.loadbinary_relax.Executable")
- .set_body_typed(Executable::LoadFromBinary);
+TVM_REGISTER_GLOBAL("runtime.module.loadbinary_relax.VMExecutable")
+ .set_body_typed(VMExecutable::LoadFromBinary);
-Module Executable::LoadFromFile(const String& file_name) {
+Module VMExecutable::LoadFromFile(const String& file_name) {
std::string data;
runtime::LoadBinaryFromFile(file_name, &data);
dmlc::MemoryStringStream reader(&data);
dmlc::Stream* strm = &reader;
- return Executable::LoadFromBinary(reinterpret_cast<void*>(strm));
+ return VMExecutable::LoadFromBinary(reinterpret_cast<void*>(strm));
}
-TVM_REGISTER_GLOBAL("runtime.module.loadfile_relax.Executable")
- .set_body_typed(Executable::LoadFromFile);
+TVM_REGISTER_GLOBAL("runtime.module.loadfile_relax.VMExecutable")
+ .set_body_typed(VMExecutable::LoadFromFile);
void VMFuncInfo::Save(dmlc::Stream* strm) const {
int32_t temp_kind = static_cast<int32_t>(kind);
@@ -261,9 +261,9 @@ bool VMFuncInfo::Load(dmlc::Stream* strm) {
return true;
}
-void Executable::SaveGlobalSection(dmlc::Stream* strm) {
strm->Write(func_table); }
+void VMExecutable::SaveGlobalSection(dmlc::Stream* strm) {
strm->Write(func_table); }
-void Executable::SaveConstantSection(dmlc::Stream* strm) {
+void VMExecutable::SaveConstantSection(dmlc::Stream* strm) {
strm->Write(static_cast<uint64_t>(this->constants.size()));
for (const auto& it : this->constants) {
if (it.IsObjectRef<runtime::NDArray>()) {
@@ -301,12 +301,12 @@ void Executable::SaveConstantSection(dmlc::Stream* strm) {
}
}
-void Executable::SaveCodeSection(dmlc::Stream* strm) {
+void VMExecutable::SaveCodeSection(dmlc::Stream* strm) {
strm->Write(instr_offset);
strm->Write(instr_data);
}
-void Executable::LoadGlobalSection(dmlc::Stream* strm) {
+void VMExecutable::LoadGlobalSection(dmlc::Stream* strm) {
STREAM_CHECK(strm->Read(&func_table), "Global Section");
// setup func map
for (size_t i = 0; i < func_table.size(); ++i) {
@@ -314,7 +314,7 @@ void Executable::LoadGlobalSection(dmlc::Stream* strm) {
}
}
-void Executable::LoadConstantSection(dmlc::Stream* strm) {
+void VMExecutable::LoadConstantSection(dmlc::Stream* strm) {
uint64_t sz;
// Load the number of constants.
STREAM_CHECK(strm->Read(&sz, sizeof(sz)), "constant");
@@ -375,7 +375,7 @@ void Executable::LoadConstantSection(dmlc::Stream* strm) {
}
}
-void Executable::LoadCodeSection(dmlc::Stream* strm) {
+void VMExecutable::LoadCodeSection(dmlc::Stream* strm) {
STREAM_CHECK(strm->Read(&(this->instr_offset)), "instr offset");
STREAM_CHECK(strm->Read(&(this->instr_data)), "instr data");
}
@@ -404,21 +404,21 @@ std::string RegNameToStr(RegName reg) {
return "%" + std::to_string(reg);
}
-Module Executable::VMLoadExecutable() const {
+Module VMExecutable::VMLoadExecutable() const {
ObjectPtr<VirtualMachine> vm = VirtualMachine::Create();
- vm->LoadExecutable(GetObjectPtr<Executable>(const_cast<Executable*>(this)));
+
vm->LoadExecutable(GetObjectPtr<VMExecutable>(const_cast<VMExecutable*>(this)));
return Module(vm);
}
-Module Executable::VMProfilerLoadExecutable() const {
+Module VMExecutable::VMProfilerLoadExecutable() const {
ObjectPtr<VirtualMachine> vm = VirtualMachine::CreateProfiler();
- vm->LoadExecutable(GetObjectPtr<Executable>(const_cast<Executable*>(this)));
+
vm->LoadExecutable(GetObjectPtr<VMExecutable>(const_cast<VMExecutable*>(this)));
return Module(vm);
}
-bool Executable::HasFunction(const String& name) const { return
func_map.count(name); }
+bool VMExecutable::HasFunction(const String& name) const { return
func_map.count(name); }
-String Executable::AsText() const {
+String VMExecutable::AsText() const {
auto get_func_name = [&](Index index) -> std::string {
if (static_cast<size_t>(index) < func_table.size()) {
return func_table[index].name;
@@ -495,7 +495,7 @@ String Executable::AsText() const {
return String(os.str());
}
-String Executable::AsPython() const {
+String VMExecutable::AsPython() const {
auto get_func_name = [&](Index index) -> std::string {
if (static_cast<size_t>(index) < func_table.size()) {
return "\"" + func_table[index].name + "\"";
@@ -573,7 +573,7 @@ String Executable::AsPython() const {
return String(os.str());
}
-TVM_REGISTER_GLOBAL("relax.ExecutableLoadFromFile").set_body_typed(Executable::LoadFromFile);
+TVM_REGISTER_GLOBAL("relax.ExecutableLoadFromFile").set_body_typed(VMExecutable::LoadFromFile);
} // namespace relax_vm
} // namespace runtime
diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc
index 1442f6cd06..5c45024fc8 100644
--- a/src/runtime/relax_vm/vm.cc
+++ b/src/runtime/relax_vm/vm.cc
@@ -198,7 +198,7 @@ class VirtualMachineImpl : public VirtualMachine {
//---------------------------------------------------
// Public facing functions overloading
//---------------------------------------------------
- void LoadExecutable(ObjectPtr<Executable> exec) final;
+ void LoadExecutable(ObjectPtr<VMExecutable> exec) final;
void Init(const std::vector<Device>& devices,
const std::vector<AllocatorType>& alloc_types) final;
VMClosure GetClosure(const String& func_name) final {
@@ -425,7 +425,7 @@ class VirtualMachineImpl : public VirtualMachine {
// Internal states for execution.
//--------------------------------------------------------
/*! \brief The loaded executable. */
- ObjectPtr<Executable> exec_;
+ ObjectPtr<VMExecutable> exec_;
/*! \brief The global constant pool */
std::vector<TVMRetValue> const_pool_;
/*!
@@ -462,7 +462,7 @@ class VirtualMachineImpl : public VirtualMachine {
PackedFunc instrument_ = nullptr;
};
-void VirtualMachineImpl::LoadExecutable(ObjectPtr<Executable> exec) {
+void VirtualMachineImpl::LoadExecutable(ObjectPtr<VMExecutable> exec) {
this->exec_ = exec;
this->imports_ = exec_->imports();
}
diff --git a/tests/python/driver/test_compile.py
b/tests/python/driver/test_compile.py
new file mode 100644
index 0000000000..e66bd7c290
--- /dev/null
+++ b/tests/python/driver/test_compile.py
@@ -0,0 +1,121 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+
+import tvm
+import tvm.testing
+from tvm import relax, te
+from tvm.runtime import Executable
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+
+def test_compile_tir():
+ """Test tvm.compile with TIR input."""
+ n = te.var("n")
+ A = te.placeholder((n,), name="A")
+ B = te.placeholder((n,), name="B")
+ C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
+ func = te.create_prim_func([A, B, C])
+
+ # Test compile with PrimFunc
+ exec_prim = tvm.compile(func)
+ assert isinstance(exec_prim, Executable)
+
+ # Test compile with IRModule containing PrimFunc
+ mod = tvm.IRModule.from_expr(func)
+ exec_mod = tvm.compile(mod)
+ assert isinstance(exec_mod, Executable)
+
+ # Verify the compiled module works
+ dev = tvm.cpu(0)
+ a_np = np.random.uniform(size=10).astype(np.float32)
+ b_np = np.random.uniform(size=10).astype(np.float32)
+ a = tvm.nd.array(a_np, dev)
+ b = tvm.nd.array(b_np, dev)
+ c = tvm.nd.array(np.zeros(10, dtype=np.float32), dev)
+
+ exec_prim(a, b, c)
+ np.testing.assert_allclose(c.numpy(), a_np + b_np)
+ exec_mod(a, b, c)
+ np.testing.assert_allclose(c.numpy(), a_np + b_np)
+
+
+def test_compile_relax():
+ """Test tvm.compile with Relax input."""
+ # Define a simple Relax program
+ @I.ir_module
+ class MyModule:
+ @R.function
+ def main(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4),
"float32")) -> R.Tensor:
+ z = R.add(x, y)
+ return z
+
+ # Test compile with Relax module
+ target = tvm.target.Target("llvm")
+ exec_relax = tvm.compile(MyModule, target)
+ assert isinstance(exec_relax, Executable)
+
+ # Verify the compiled module works
+ dev = tvm.cpu(0)
+ x_np = np.random.uniform(size=(3, 4)).astype(np.float32)
+ y_np = np.random.uniform(size=(3, 4)).astype(np.float32)
+ x = tvm.nd.array(x_np, dev)
+ y = tvm.nd.array(y_np, dev)
+
+ vm = relax.VirtualMachine(exec_relax, dev)
+ z = vm["main"](x, y)
+ np.testing.assert_allclose(z.numpy(), x_np + y_np)
+
+
[email protected]_if_32bit(reason="skipping test for i386.")
+def test_compile_mixed_module():
+ @tvm.script.ir_module
+ class MyModule:
+ @T.prim_func
+ def add_one(X: T.Buffer((4,), "float32"), Y: T.Buffer((4,),
"float32")):
+ for i in range(4):
+ Y[i] = X[i] + 1
+
+ @R.function
+ def main(x: R.Tensor((4,), "float32")):
+ cls = MyModule
+ with R.dataflow():
+ y = R.call_tir(cls.add_one, [x], R.Tensor((4,), "float32"))
+ return y
+
+ # Compile the mixed TIR/Relax module for the C source target (default pipeline)
+ target = tvm.target.Target("c")
+ ex = tvm.compile(MyModule, target)
+ assert isinstance(ex, Executable)
+
+ dev = tvm.cpu(0)
+ x = tvm.nd.array(np.array([1, 2, 3, 4], dtype=np.float32), dev)
+ y = tvm.nd.array(np.zeros(4, dtype=np.float32), dev)
+ # For tir function, we can directly call the function
+ ex["add_one"](x, y)
+ np.testing.assert_allclose(y.numpy(), x.numpy() + 1)
+ # For relax function, we need to use the vm to call the function
+ vm = relax.VirtualMachine(ex, dev)
+ z = vm["main"](x)
+ np.testing.assert_allclose(z.numpy(), x.numpy() + 1)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/runtime/test_executable.py
b/tests/python/runtime/test_executable.py
new file mode 100644
index 0000000000..571ce7adb2
--- /dev/null
+++ b/tests/python/runtime/test_executable.py
@@ -0,0 +1,263 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests for the Executable class."""
+
+import os
+import tempfile
+
+import numpy as np
+
+import tvm
+import tvm.testing
+from tvm.runtime import Executable
+from tvm.script import tir as T
+
+
[email protected]_module
+class MyModule:
+ @T.prim_func
+ def add(
+ A: T.Buffer((10,), "float32"),
+ B: T.Buffer((10,), "float32"),
+ C: T.Buffer((10,), "float32"),
+ ):
+ for i in range(10):
+ C[i] = A[i] + B[i]
+
+
+def test_executable_init():
+ """Test initialization of Executable class."""
+ lib = tvm.tir.build(MyModule, target="llvm")
+ executable = Executable(lib)
+
+ assert executable.mod is lib
+ assert executable._jitted_mod is None
+
+
+def test_executable_getitem():
+ """Test __getitem__ method of Executable class."""
+ lib = tvm.tir.build(MyModule, target="llvm")
+ executable = Executable(lib)
+
+ # Jit the module first
+ executable.jit()
+
+ # Test __getitem__
+ add_func = executable["add"]
+
+ # Verify the function works
+ a = tvm.nd.array(np.array([1.0] * 10, dtype="float32"))
+ b = tvm.nd.array(np.array([2.0] * 10, dtype="float32"))
+ c = tvm.nd.array(np.array([0.0] * 10, dtype="float32"))
+
+ add_func(a, b, c)
+
+ # Check results
+ tvm.testing.assert_allclose(c.numpy(), np.array([3.0] * 10,
dtype="float32"))
+
+
+def test_executable_jit_already_jitted():
+ """Test jit method when module is already jitted."""
+ lib = tvm.tir.build(MyModule, target="llvm")
+ executable = Executable(lib)
+
+ # First jit call
+ jitted_mod1 = executable.jit()
+
+ # Second jit call should return the cached jitted module
+ jitted_mod2 = executable.jit()
+ assert jitted_mod2 is jitted_mod1
+
+ # Test with force_recompile
+ jitted_mod3 = executable.jit(force_recompile=True)
+ # The module might be different after force recompilation
+
+ # Verify both modules work correctly
+ a = tvm.nd.array(np.array([1.0] * 10, dtype="float32"))
+ b = tvm.nd.array(np.array([2.0] * 10, dtype="float32"))
+ c1 = tvm.nd.array(np.array([0.0] * 10, dtype="float32"))
+ c2 = tvm.nd.array(np.array([0.0] * 10, dtype="float32"))
+
+ jitted_mod1["add"](a, b, c1)
+ jitted_mod3["add"](a, b, c2)
+
+ tvm.testing.assert_allclose(c1.numpy(), np.array([3.0] * 10,
dtype="float32"))
+ tvm.testing.assert_allclose(c2.numpy(), np.array([3.0] * 10,
dtype="float32"))
+
+
+def test_executable_export_library():
+ """Test export_library method."""
+ lib = tvm.tir.build(MyModule, target="llvm")
+ executable = Executable(lib)
+
+ # Create a temporary directory for the library
+ temp_dir = tempfile.mkdtemp()
+ try:
+ lib_path = os.path.join(temp_dir, "test_lib.so")
+ executable.export_library(lib_path)
+
+ # Verify the library was created
+ assert os.path.exists(lib_path)
+
+ # Load the library back
+ loaded_mod = tvm.runtime.load_module(lib_path)
+ assert loaded_mod is not None
+
+ # Test the loaded module
+ a = tvm.nd.array(np.array([1.0] * 10, dtype="float32"))
+ b = tvm.nd.array(np.array([2.0] * 10, dtype="float32"))
+ c = tvm.nd.array(np.array([0.0] * 10, dtype="float32"))
+
+ loaded_mod["add"](a, b, c)
+
+ # Check results
+ tvm.testing.assert_allclose(c.numpy(), np.array([3.0] * 10,
dtype="float32"))
+ finally:
+ # Clean up
+ if os.path.exists(temp_dir):
+ import shutil
+
+ shutil.rmtree(temp_dir)
+
+
+def test_executable_export_library_with_workspace():
+ """Test export_library method with workspace_dir."""
+ lib = tvm.tir.build(MyModule, target="llvm")
+ executable = Executable(lib)
+
+ # Create temporary directories
+ temp_dir = tempfile.mkdtemp()
+ workspace_dir = tempfile.mkdtemp()
+
+ try:
+ lib_path = os.path.join(temp_dir, "test_lib.so")
+ executable.export_library(lib_path, workspace_dir=workspace_dir)
+
+ # Verify the library was created
+ assert os.path.exists(lib_path)
+
+ # Load the library back
+ loaded_mod = tvm.runtime.load_module(lib_path)
+ assert loaded_mod is not None
+
+ # Test the loaded module
+ a = tvm.nd.array(np.array([1.0] * 10, dtype="float32"))
+ b = tvm.nd.array(np.array([2.0] * 10, dtype="float32"))
+ c = tvm.nd.array(np.array([0.0] * 10, dtype="float32"))
+
+ loaded_mod["add"](a, b, c)
+
+ # Check results
+ tvm.testing.assert_allclose(c.numpy(), np.array([3.0] * 10,
dtype="float32"))
+ finally:
+ # Clean up
+ for directory in [temp_dir, workspace_dir]:
+ if os.path.exists(directory):
+ import shutil
+
+ shutil.rmtree(directory)
+
+
+def test_executable_integration():
+ """Integration test for Executable with a simple TVM module."""
+ # Create target and build
+ target = tvm.target.Target("llvm")
+ lib = tvm.tir.build(MyModule, target=target)
+
+ # Create an executable
+ executable = Executable(lib)
+
+ # Test jit
+ jitted_mod = executable.jit()
+ assert jitted_mod is not None
+
+ # Test __getitem__
+ add_func = executable["add"]
+ assert add_func is not None
+
+ # Test the function works
+ a = tvm.nd.array(np.array([1.0] * 10, dtype="float32"))
+ b = tvm.nd.array(np.array([2.0] * 10, dtype="float32"))
+ c = tvm.nd.array(np.array([0.0] * 10, dtype="float32"))
+
+ add_func(a, b, c)
+
+ # Check results
+ tvm.testing.assert_allclose(c.numpy(), np.array([3.0] * 10,
dtype="float32"))
+
+ # Test export_library
+ temp_dir = tempfile.mkdtemp()
+ try:
+ lib_path = os.path.join(temp_dir, "test_lib.so")
+ executable.export_library(lib_path)
+
+ # Verify the library was created
+ assert os.path.exists(lib_path)
+
+ # Load the library back
+ loaded_mod = tvm.runtime.load_module(lib_path)
+ assert loaded_mod is not None
+
+ # Test the loaded module
+ loaded_add = loaded_mod["add"]
+ c_loaded = tvm.nd.array(np.array([0.0] * 10, dtype="float32"))
+ loaded_add(a, b, c_loaded)
+
+ # Check results
+ tvm.testing.assert_allclose(c_loaded.numpy(), np.array([3.0] * 10,
dtype="float32"))
+
+ finally:
+ # Clean up
+ if os.path.exists(temp_dir):
+ import shutil
+
+ shutil.rmtree(temp_dir)
+
+
+def test_executable_jit_force_recompile():
+ """Test jit method with force_recompile=True."""
+ # Create target and build
+ target = tvm.target.Target("c")
+ lib = tvm.tir.build(MyModule, target=target)
+
+ # Create an executable
+ executable = Executable(lib)
+
+ # First jit call
+ jitted_mod1 = executable.jit()
+
+ # Second jit call without force_recompile should return the same module
+ jitted_mod2 = executable.jit()
+ assert jitted_mod1 is jitted_mod2
+
+ # Third jit call with force_recompile should return a new module
+ jitted_mod3 = executable.jit(force_recompile=True)
+ assert jitted_mod3 is not jitted_mod1
+
+ # Test the function works
+ a = tvm.nd.array(np.array([1.0] * 10, dtype="float32"))
+ b = tvm.nd.array(np.array([2.0] * 10, dtype="float32"))
+ c = tvm.nd.array(np.array([0.0] * 10, dtype="float32"))
+
+ jitted_mod3["add"](a, b, c)
+
+ # Check results
+ tvm.testing.assert_allclose(c.numpy(), np.array([3.0] * 10,
dtype="float32"))
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/scripts/task_python_unittest.sh
b/tests/scripts/task_python_unittest.sh
index b074af3147..ddc775933c 100755
--- a/tests/scripts/task_python_unittest.sh
+++ b/tests/scripts/task_python_unittest.sh
@@ -37,21 +37,20 @@ run_pytest
${TVM_UNITTEST_TESTSUITE_NAME}-platform-minimal-test tests/python/all
# Then run all unittests on both ctypes and cython.
TEST_FILES=(
"arith"
+ "ci"
"codegen"
+ "driver"
"ir"
"meta_schedule"
"runtime"
+ "target"
"te"
"testing"
"tir-analysis"
"tir-base"
"tir-schedule"
"tir-transform"
- "tir-usmp"
"tvmscript"
- "usmp"
- "ci"
- "target"
)
for TEST_FILE in ${TEST_FILES[@]}; do