This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fc6775e770 [REFACTOR] move build flow from C++ to Python (#17665)
fc6775e770 is described below
commit fc6775e770f22f03c7e0d396d3bc51af6807735f
Author: Bohan Hou <[email protected]>
AuthorDate: Thu Feb 20 14:47:06 2025 -0500
[REFACTOR] move build flow from C++ to Python (#17665)
This PR moves the build flow from C++ to Python, improving developer
productivity and readability.
---
docs/how_to/tutorials/cross_compilation_and_rpc.py | 2 +-
docs/reference/api/python/driver.rst | 2 -
include/tvm/driver/driver_api.h | 124 -----
python/tvm/__init__.py | 2 +-
python/tvm/driver/__init__.py | 2 +-
python/tvm/driver/build_module.py | 123 +----
python/tvm/tir/__init__.py | 2 +
python/tvm/tir/build.py | 179 +++++++
python/tvm/tir/pipeline.py | 175 ++++++
python/tvm/tir/transform/transform.py | 35 +-
src/driver/driver_api.cc | 595 ---------------------
src/driver/internal_driver_api.h | 48 --
src/relax/backend/vm/codegen_vm.cc | 1 -
src/relax/backend/vm/codegen_vm_tir.cc | 1 -
src/relax/transform/bind_params.cc | 1 -
src/relax/transform/fold_constant.cc | 7 +-
src/tir/ir/transform.cc | 20 +
src/tir/transforms/primfunc_utils.cc | 1 -
tests/python/codegen/test_target_codegen_cuda.py | 1 +
tests/python/codegen/test_target_codegen_llvm.py | 47 +-
.../test_hexagon/test_2d_physical_buffers.py | 12 -
.../test_hexagon/test_benchmark_elemwise_add.py | 6 -
.../contrib/test_hexagon/test_meta_schedule.py | 2 +-
tests/python/contrib/test_hexagon/test_sigmoid.py | 2 +-
.../test_hexagon/test_software_pipeline_async.py | 1 -
tests/python/ir/test_pass_instrument.py | 2 +-
tests/python/tir-base/test_lower_build.py | 133 -----
.../python/tir-base/test_tir_te_extern_primfunc.py | 1 -
..._tir_schedule_tensorize_ldmatrix_mma_numeric.py | 2 +-
.../test_tir_schedule_tensorize_mfma_numeric.py | 2 +-
.../test_tir_transform_convert_ssa.py | 35 --
.../test_tir_transform_extract_constants.py | 2 -
.../test_tir_transform_flatten_buffer.py | 31 --
.../test_tir_transform_lower_tvm_builtin.py | 2 +-
.../test_tir_transform_narrow_datatype.py | 27 -
.../test_tir_transform_storage_rewrite.py | 82 ---
36 files changed, 431 insertions(+), 1279 deletions(-)
diff --git a/docs/how_to/tutorials/cross_compilation_and_rpc.py
b/docs/how_to/tutorials/cross_compilation_and_rpc.py
index 81c73fd051..94a6f48b4b 100644
--- a/docs/how_to/tutorials/cross_compilation_and_rpc.py
+++ b/docs/how_to/tutorials/cross_compilation_and_rpc.py
@@ -119,7 +119,7 @@ if local_demo:
else:
target = "llvm -mtriple=armv7l-linux-gnueabihf"
-func = tvm.build(mod, target=target, name="add_one")
+func = tvm.build(mod, target=target)
# save the lib at a local temp folder
temp = utils.tempdir()
path = temp.relpath("lib.tar")
diff --git a/docs/reference/api/python/driver.rst
b/docs/reference/api/python/driver.rst
index 1f1bc8c7cf..97c30ec2d2 100644
--- a/docs/reference/api/python/driver.rst
+++ b/docs/reference/api/python/driver.rst
@@ -19,6 +19,4 @@ tvm.driver
----------
.. automodule:: tvm.driver
-.. autofunction:: tvm.lower
-
.. autofunction:: tvm.build
diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h
deleted file mode 100644
index 39444d1629..0000000000
--- a/include/tvm/driver/driver_api.h
+++ /dev/null
@@ -1,124 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/driver/driver_api.h
- * \brief Compiler driver APIs to drive the compilation.
- *
- * This module provides end-to-end utils to drive the compilation process.
- * We adopt the term "compiler driver" in common compiler infrastructures.
- * Note that a compiler driver is different from "runtime drivers".
- * Most of runtime related code are defined in the runtime folder instead.
- */
-#ifndef TVM_DRIVER_DRIVER_API_H_
-#define TVM_DRIVER_DRIVER_API_H_
-
-#include <tvm/ir/global_var_supply.h>
-#include <tvm/ir/module.h>
-#include <tvm/ir/transform.h>
-#include <tvm/runtime/packed_func.h>
-#include <tvm/support/with.h>
-#include <tvm/target/target.h>
-#include <tvm/tir/function.h>
-
-#include <string>
-
-namespace tvm {
-using tvm::transform::Pass;
-
-/*!
- * \brief Configures and returns the composite Pass for the fused module (pre
split) that contains
- * device and host code.
- * \param mixed_mod The original mixed module.
- * \param target The device Target.
- * \return The composite Pass for the fused module.
-// */
-TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod,
Target target);
-
-/*!
- * \brief Configures and returns the composite Pass for the device Target
after device/host from
- * mixed module.
- * \param mixed_mod The optimized mixed module.
- * \param target The device Target.
- * \return The composite Pass for the device module.
- */
-TVM_DLL transform::Sequential DeviceModulePassManager(IRModule mixed_mod,
Target target);
-
-/*!
- * \brief Configures and returns the composite Pass for the host Target after
device/host from mixed
- * module.
- * \param mixed_mod The optimized mixed module.
- * \param target_host The host Target.
- * \return The composite Pass for the host module.
- */
-TVM_DLL transform::Sequential HostModulePassManager(IRModule mixed_mod, Target
target_host);
-
-/*!
- * \brief Lower an IRModule (optimize with it with the pass list defined in
CreatePassList)
- * \param mod The IRmodule to lower
- * \param simple_mode Disables the loop partition pass. Defaults to false.
- * \return The result module.
- */
-TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false);
-
-/*!
- * \brief Lower a primfunc and name (convert to IRModule, and optimize it with
the pass list
- * defined in CreatePassList)
- * \param func The PrimFunc to lower
- * \param name The name of the lowered function.
- * \param simple_mode Disables the loop partition pass. Defaults to false.
- * \return The result module.
- */
-TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string&
name,
- bool simple_mode = false);
-
-/*!
- * \brief Build a device and host module for a specific target from an
IRModule.
- * \param funcs The functions to be built.
- * \param target The target device to build for.
- * \param target_host The target for building host code. To use the default,
pass Target()
- * \return The built module.
- */
-TVM_DLL runtime::Module build(const IRModule& funcs, const Target& target,
- const Target& target_host);
-
-/*!
- * \brief Build a device and host module for a specific target from a map
- * contains target to IRModule. This function is used
- * for heterogeneous build.
- * \param input The map contains target to an IRModule.
- * \param target_host The target for building host code. To use the default,
- * pass Target().
- * \return The built module that contains code for different processors.
- */
-TVM_DLL runtime::Module build(const Map<Target, IRModule>& input, const
Target& target_host);
-
-/*!
- * \brief Build a device and host module for a specific target from a map
- * contains target to IRModule. This function is used
- * for heterogeneous build.
- * \param input The map contains target string to an IRModule.
- * \param target_host The target for building host code. To use the default,
- * pass Target().
- * \return The built module that contains code for different processors.
- */
-TVM_DLL runtime::Module build(const Map<String, IRModule>& input, const
Target& target_host);
-} // namespace tvm
-
-#endif // TVM_DRIVER_DRIVER_API_H_
diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index abbab3ad6d..f4519f834d 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -55,7 +55,7 @@ from . import target
from . import te
# tvm.driver
-from .driver import build, lower
+from .driver import build
# others
from . import arith
diff --git a/python/tvm/driver/__init__.py b/python/tvm/driver/__init__.py
index 75e94cc91c..b97375c3a3 100644
--- a/python/tvm/driver/__init__.py
+++ b/python/tvm/driver/__init__.py
@@ -15,4 +15,4 @@
# specific language governing permissions and limitations
# under the License.
"""Namespace for driver APIs"""
-from .build_module import lower, build
+from .build_module import build
diff --git a/python/tvm/driver/build_module.py
b/python/tvm/driver/build_module.py
index 94006111ff..8d6a2a5343 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -18,130 +18,15 @@
# pylint: disable=invalid-name
"""The build utils in python."""
from typing import Union, Optional
-
-
-import tvm.tir
-
-
-from tvm.runtime import ndarray
+import tvm
from tvm.tir import PrimFunc
from tvm.ir.module import IRModule
from tvm.target import Target
-from tvm.driver import _ffi_api as _driver_ffi
-
-from . import _ffi_api as ffi
-
-
-def lower(
- inp: Union[PrimFunc, IRModule],
- name: str = "main",
- simple_mode: bool = False,
-) -> IRModule:
- """Lowering step before build into target.
-
- Parameters
- ----------
- inp : Union[tvm.tir.PrimFunc, IRModule]
- The TE schedule or TensorIR PrimFunc/IRModule to be built
-
- name : str
- The name of the result function.
-
- simple_mode : bool
- Whether only output simple and compact statement, this will skip
- LoopPartition, api wrapper generation and Unrolling.
-
- Returns
- -------
- m : IRModule
- The result IRModule
- """
- if isinstance(inp, IRModule):
- return ffi.lower_module(inp, simple_mode)
- if isinstance(inp, PrimFunc):
- return ffi.lower_primfunc(inp, name, simple_mode)
- raise ValueError(
- f"Expected input to be an IRModule, PrimFunc or te.Schedule, but got
{type(inp)}"
- )
def build(
- inputs: Union[PrimFunc, IRModule],
+ mod: Union[PrimFunc, IRModule],
target: Optional[Union[str, Target]] = None,
- name: str = "main",
+ pipeline: Optional[Union[str, tvm.transform.Pass]] = "default_tir",
):
- """Build a function with arguments as signature. Code will be generated
- for devices coupled with target information.
-
- Parameters
- ----------
- input : Union[tvm.tir.PrimFunc, IRModule]
- The input to be built
-
- target : Optional[Union[str, Target]]
- The target and option of the compilation.
-
- name : str
- The name of result function.
-
- Returns
- -------
- ret : tvm.module
- A module that combines both host and device code.
-
- Note
- ----
- See the note on :any:`tvm.target` on target string format.
- """
- if isinstance(inputs, PrimFunc):
- input_mod = lower(inputs, name=name)
- elif isinstance(inputs, tvm.IRModule):
- assert (
- len(inputs.get_global_vars()) > 0
- ), "Expected a non-empty IRModule, but the IRModule contained no
functions."
- input_mod = lower(inputs)
- else:
- raise ValueError("Inputs must be IRModule or PrimFunc")
-
- target = Target.current() if target is None else target
- if target is None and isinstance(input_mod, tvm.IRModule):
- target_mod = {}
- for gvar, func in input_mod.functions.items():
- tgt = func.attrs["target"] if "target" in func.attrs else "llvm"
- if tgt not in target_mod:
- target_mod[tgt] = {}
- target_mod[tgt][gvar] = func
-
- target_input_mod = {}
- for tgt in target_mod.keys():
- tir_mod = tvm.IRModule(target_mod[tgt])
- tir_mod = tir_mod.with_attrs(input_mod.attrs)
- target_input_mod[tgt] = tir_mod
- else:
- target_input_mod = {target: input_mod}
-
- # Because modules can be created from a variety of sources, we annotate
them
- # with the relevant attributes here to ensure they propagate
- annotated_mods = {}
- for tgt, mod in target_input_mod.items():
- if not isinstance(tgt, (str, Target)):
- raise ValueError("The key of inputs must be str or " "Target when
inputs is dict.")
- if not isinstance(mod, tvm.IRModule):
- raise ValueError("inputs must be IRModule, " "or dict of str to
IRModule.")
- annotated_mods[tgt] = mod
-
- annotated_mods, target_host =
Target.canon_target_map_and_host(annotated_mods)
- if not target_host:
- for tar, mod in annotated_mods.items():
- device_type = ndarray.device(tar.kind.name, 0).device_type
- if device_type == ndarray.cpu(0).device_type:
- target_host = tar
- break
- if not target_host:
- target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"
-
- annotated_mods, target_host =
Target.canon_target_map_and_host(annotated_mods, target_host)
-
- rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
-
- return rt_mod_host
+ return tvm.tir.build(mod, target, pipeline)
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 1d7352f665..9ff5bff5f1 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -108,3 +108,5 @@ from . import ir_builder
from . import transform
from . import analysis
from . import stmt_functor
+from .build import build
+from .pipeline import get_pipeline
diff --git a/python/tvm/tir/build.py b/python/tvm/tir/build.py
new file mode 100644
index 0000000000..cd44ed881b
--- /dev/null
+++ b/python/tvm/tir/build.py
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=invalid-name
+"""The build utils in python."""
+from typing import Union, Optional, Dict
+import enum
+
+import tvm
+from tvm import ir
+from tvm.runtime import ndarray
+from tvm.tir import PrimFunc
+from tvm.ir.module import IRModule
+from tvm.target import Target
+
+
+def split_host_device_mods(mod):
+ """Split an IRModule into host and device modules.
+
+ Parameters
+ ----------
+ mod : tvm.IRModule
+ The input module to split
+
+ Returns
+ -------
+ host_mod : tvm.IRModule
+ The module containing host functions
+ device_mod_dict : Dict[Target, tvm.IRModule]
+ A dict mapping targets to device modules
+ """
+
+ class CallConv(enum.IntEnum):
+ """Enum representing different calling conventions.
+ Corresponds to the C++ tvm::ir::CallingConv enum.
+ """
+
+ kDefault = 0
+ kCPackedFunc = 1
+ kDeviceKernelLaunch = 2
+
+ host_mod = tvm.tir.transform.Filter(
+ lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault))
+ != int(CallConv.kDeviceKernelLaunch)
+ )(mod)
+ device_mod = tvm.tir.transform.Filter(
+ lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault))
+ == int(CallConv.kDeviceKernelLaunch)
+ )(mod)
+ device_mod_dict = {}
+ for gv, func in device_mod.functions.items():
+ device_mod_dict.setdefault(func.attrs.get("target", None),
dict()).update({gv: func})
+ for target, funcs in device_mod_dict.items():
+ device_mod_dict[target] = tvm.IRModule(funcs, attrs=device_mod.attrs)
+ return host_mod, device_mod_dict
+
+
+def codegen_build(mod: IRModule, target: Target) -> tvm.runtime.Module:
+ """Build a runtime module from an IRModule and a Target."""
+ if tvm.ir.transform.PassContext.current().config.get("tir.disable_assert",
False):
+ mod = tvm.tir.transform.SkipAssert()(mod)
+ build_f_name = "target.build." + target.kind.name
+ bf = tvm.get_global_func(build_f_name)
+ if bf is None:
+ raise ValueError(f"{build_f_name} is not enabled")
+ return bf(mod, target)
+
+
+def tir_to_runtime(
+ host_mod: IRModule, device_mod_dict: Dict[Target, IRModule], target_host:
Target
+):
+ """Convert a collection of TIR IRModules (keyed by Target) into a single
runtime Module."""
+
+ # Get the first module to get the attributes
+ # necessary for
tests/python/codegen/test_target_codegen_blob.py::test_cuda_multi_lib
+ mhost_all = ir.IRModule({}, attrs=host_mod.attrs)
+
+ mhost_all.update(host_mod)
+ device_modules = []
+ for target, device_mod in device_mod_dict.items():
+ if len(device_mod.functions) != 0:
+ device_modules.append(codegen_build(device_mod, target))
+
+ mhost = codegen_build(mhost_all, target_host)
+ for dev_mod in device_modules:
+ if dev_mod is not None:
+ mhost.import_module(dev_mod)
+ return mhost
+
+
+def build(
+ mod: Union[PrimFunc, IRModule],
+ target: Optional[Union[str, Target]] = None,
+ pipeline: Union[None, str, tvm.transform.Pass] = "default_tir",
+):
+ """Build a function with a signature, generating code for devices
+ coupled with target information.
+
+ Parameters
+ ----------
+ mod : Union[PrimFunc, IRModule]
+ The input to be built.
+ target : Optional[Union[str, Target]]
+ The target for compilation.
+ pipeline : Union[None, str, tvm.transform.Pass]
+ The pipeline to use for compilation.
+
+ Returns
+ -------
+ tvm.runtime.Module
+ A module combining both host and device code.
+ """
+ # Convert PrimFunc to IRModule
+ if isinstance(mod, PrimFunc):
+ mod = tvm.IRModule.from_expr(mod)
+ else:
+ assert isinstance(mod, tvm.IRModule)
+
+ # Step 0: Determine the target in environment
+ target = Target.current() if target is None else target
+ if target is None:
+ target = "llvm"
+ assert target is not None
+ target = Target.canon_target(target)
+
+ # Step 1: Determine the host
+ target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"
+ if target is not None:
+ if target.host is not None:
+ target_host = target.host
+ elif ndarray.device(target.kind.name, 0).device_type ==
ndarray.cpu(0).device_type:
+ target_host = target
+ else:
+ for func in mod.functions.values():
+ f_target = func.attrs.get("target", None)
+ if f_target is not None and f_target.host is not None:
+ target_host = f_target.host
+ assert target_host is not None
+ target_host = Target.canon_target(target_host)
+ target = target.with_host(target_host)
+
+ # Step 2: Bind the target to the input module
+ mod = tvm.tir.transform.BindTarget(target)(mod)
+
+ # Step 3: Apply the pipeline
+ if pipeline is not None:
+ if isinstance(pipeline, str):
+ pipeline = tvm.tir.get_pipeline(pipeline)
+ mod = pipeline(mod)
+
+ # Step 4: Get host and device modules
+ host_mod, device_mod_dict = split_host_device_mods(mod)
+
+ # Step 5: Apply finalization passes
+ host_mod = tvm.tir.pipeline.finalize_host_passes()(host_mod)
+ device_mod_dict = {
+ target: tvm.tir.pipeline.finalize_device_passes()(device_mod)
+ for target, device_mod in device_mod_dict.items()
+ }
+
+ # Convert TIR IRModules to runtime Module by calling target.build
+ return tir_to_runtime(host_mod, device_mod_dict, target_host)
+
+
+tvm.register_func("tir.build", build)
diff --git a/python/tvm/tir/pipeline.py b/python/tvm/tir/pipeline.py
new file mode 100644
index 0000000000..0b6d622c90
--- /dev/null
+++ b/python/tvm/tir/pipeline.py
@@ -0,0 +1,175 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=invalid-name
+"""The TIR backend compilation pipeline."""
+
+import tvm
+from tvm import tir
+
+
+def default_tir_pipeline():
+ """The default tir pipeline used in tvm.tir.build"""
+
+ @tvm.transform.module_pass(opt_level=0)
+ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) ->
tvm.ir.IRModule:
+ """The default lowering passes for TIR backend."""
+ pass_ctx = tvm.transform.PassContext.current()
+ config = pass_ctx.config
+ passes = [
+ tir.transform.InjectPrefetch(),
+ tir.transform.TextureFlatten(),
+ tir.transform.StorageFlatten(
+ 64, bool(config.get("tir.instrument_bound_checkers", False))
+ ),
+ tir.transform.LowerCrossThreadReduction(),
+ tir.transform.LowerInitBlock(),
+ tir.transform.PlanAndUpdateBufferAllocationLocation(),
+ tir.transform.ConvertBlocksToOpaque(),
+ tir.transform.LiftThreadBinding(),
+ tir.transform.ManifestSharedMemoryLocalStage(),
+ tir.transform.CompactBufferAllocation(),
+ tir.transform.LowerAutoCopy(),
+ tir.transform.UnifyThreadBinding(),
+ tir.transform.LowerMatchBuffer(),
+ tir.transform.Simplify(),
+ tir.transform.InjectPermutedLayout(),
+ tir.transform.InjectSoftwarePipeline(),
+ tir.transform.TransformMmaBufferLayout(),
+ tir.transform.LowerOpaqueBlock(),
+ tir.transform.FlattenBuffer(),
+ tir.transform.BF16ComputeLegalize(),
+ tir.transform.NarrowDataType(32),
+ tir.transform.LoopPartition(),
+ tir.transform.VectorizeLoop(not
bool(config.get("tir.disable_vectorize", False))),
+ tir.transform.InjectVirtualThread(),
+ tir.transform.InjectDoubleBuffer(),
+ ]
+ if not bool(config.get("tir.disable_storage_rewrite", False)):
+ passes.append(tir.transform.StorageRewrite())
+ if config.get("tir.use_async_copy", False):
+ passes.append(tir.transform.LowerAsyncDMA())
+ passes.extend(
+ [
+ tir.transform.HoistIfThenElse(),
+ tir.transform.UnrollLoop(),
+ tir.transform.RenormalizeSplitPattern(),
+ tir.transform.Simplify(),
+ tir.transform.RemoveNoOp(),
+ tir.transform.RewriteUnsafeSelect(),
+ ]
+ )
+ # Additional passes based on configuration.
+ if bool(config.get("tir.instrument_bound_checkers", False)):
+ passes.append(tir.transform.InstrumentBoundCheckers())
+ if bool(config.get("tir.ptx_ldg32", False)):
+ passes.append(tir.transform.InjectPTXLDG32(True))
+ passes.append(
+ tir.transform.CommonSubexprElimTIR(
+ not bool(config.get("tir.disable_cse_tir", False)),
+ bool(config.get("tir.enable_equiv_terms_in_cse_tir", False)),
+ )
+ )
+ if bool(config.get("tir.instrument_lwp", False)):
+ passes.append(tir.transform.InstrumentProfileIntrinsics())
+ passes.extend(
+ [
+ # Bind the target first so that target-specific attributes are
available.
+ tir.transform.FP8ComputeLegalize(),
+ # VerifyVTCMLimit must occur before LowerVtcmAlloc.
+ tir.transform.VerifyVTCMLimit(),
+ tir.transform.LowerVtcmAlloc(),
+ tir.transform.VerifyMemory(),
+ tir.transform.AnnotateEntryFunc(),
+ ]
+ )
+ if bool(config.get("tir.detect_global_barrier", False)):
+ passes.append(tir.transform.ThreadSync("global"))
+ passes.extend(
+ [
+ tir.transform.ThreadSync("shared"),
+ tir.transform.ThreadSync("shared.dyn"),
+ tir.transform.ThreadSync("warp"),
+ tir.transform.InferFragment(),
+ tir.transform.LowerThreadAllreduce(),
+ ]
+ )
+ if bool(config.get("tir.use_async_copy", False)):
+ passes.append(tir.transform.InjectPTXAsyncCopy())
+ if bool(config.get("tir.ptx_ldg32", False)):
+ passes.append(tir.transform.InjectPTXLDG32())
+ passes.extend(
+ [
+ tir.transform.AnnotateDeviceRegions(),
+ tir.transform.SplitHostDevice(),
+ # MergeSharedMemoryAllocations must follow SplitHostDevice.
+ tir.transform.MergeSharedMemoryAllocations(),
+ tir.transform.MakePackedAPI(),
+ tir.transform.FP8StorageLegalize(),
+ tir.transform.BF16StorageLegalize(),
+ tir.transform.LowerDeviceKernelLaunch(),
+ ]
+ )
+ mod = tvm.ir.transform.Sequential(passes)(mod)
+ return mod
+
+ return _pipeline
+
+
+def finalize_host_passes(): # pylint: disable=unused-argument
+ """The default finalization passes for TIR backend."""
+ host_pass_list = [
+ tir.transform.LowerTVMBuiltin(),
+ tir.transform.LowerCustomDatatypes(),
+ tir.transform.LowerIntrin(),
+ tir.transform.LowerDeviceStorageAccessInfo(),
+ tir.transform.CombineContextCall(),
+ ]
+ return tvm.ir.transform.Sequential(host_pass_list)
+
+
+def finalize_device_passes(): # pylint: disable=unused-argument
+ """The default finalization passes for TIR backend."""
+ device_pass_list = [
+ tir.transform.LowerWarpMemory(),
+ tir.transform.Simplify(),
+ tir.transform.LowerCustomDatatypes(),
+ tir.transform.LowerDeviceStorageAccessInfo(),
+ tir.transform.LowerIntrin(),
+ ]
+ return tvm.ir.transform.Sequential(device_pass_list)
+
+
+# global map of pre-built pipelines
+PIPELINE_MAP = {
+ "default_tir": default_tir_pipeline,
+}
+
+
+def get_pipeline(name: str = "default_tir", **kwargs) -> tvm.transform.Pass:
+ """Get pre-build pipeline by name
+
+ Parameters
+ ----------
+ name : Optional[str]
+ Name of the pipeline
+ """
+ if name not in PIPELINE_MAP:
+ raise ValueError(
+ f"Unknown pre-built pipeline {name}," f"candidates are
{list(PIPELINE_MAP.keys())}"
+ )
+ return PIPELINE_MAP[name](**kwargs)
diff --git a/python/tvm/tir/transform/transform.py
b/python/tvm/tir/transform/transform.py
index b08659e1c7..99a2e1e664 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -713,7 +713,7 @@ def VerifyMemory():
return _ffi_api.VerifyMemory() # type: ignore
-def VerifyVTCMLimit(limit: int):
+def VerifyVTCMLimit(limit=None):
"""Verify if the size of the allocated vtcm memory satisfies the limit.
Returns
@@ -1200,3 +1200,36 @@ def UseAssumeToReduceBranches():
The result pass
"""
return _ffi_api.UseAssumeToReduceBranches() # type: ignore
+
+
+def LowerAsyncDMA():
+ """Lower async DMA to DMA.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.LowerAsyncDMA() # type: ignore
+
+
+def InjectPTXLDG32(enable_inject_ptx_intrin: bool = True):
+ """Inject ptx.ldg.32 intrinsics.
+
+ Parameters
+ ----------
+ enable_inject_ptx_intrin : bool
+ If True, inject ptx.ldg.32 intrinsics.
+ """
+ return _ffi_api.InjectPTXLDG32(enable_inject_ptx_intrin) # type: ignore
+
+
+def LowerVtcmAlloc():
+ """Lower vtcm allocation.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.LowerVtcmAlloc() # type: ignore
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
deleted file mode 100644
index 5b12f13d96..0000000000
--- a/src/driver/driver_api.cc
+++ /dev/null
@@ -1,595 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Compile executable modules.
- * \file driver_api.cc
- */
-#include <dmlc/thread_local.h>
-#include <tvm/driver/driver_api.h>
-#include <tvm/ir/transform.h>
-#include <tvm/runtime/registry.h>
-#include <tvm/target/codegen.h>
-#include <tvm/te/operation.h>
-#include <tvm/tir/analysis.h>
-#include <tvm/tir/transform.h>
-
-#include <algorithm>
-
-namespace tvm {
-
-// Register build pipeline related options
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.noalias", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_buffer_level_predication", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_static_smem", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool);
-
-// WARNING: May cause coherency issues resulting data miscompares
-// Experimental feature that, when enabled by the runtime, bypasses the cache
when using DMA. When
-// bypassing the cache TVM must manage cache coherency in software. Software
managed cache coherency
-// can be tricky e.g. it is yet to be proven out in the Hexagon runtime. Hence
the warning above and
-// the "experimental" notation for this feature.
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.experimental_dma_bypass_cache", Bool);
-
-using tvm::Array;
-using tvm::transform::Pass;
-
-bool LLVMEnabled() {
- const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm");
- return pf != nullptr;
-}
-
-/*! \return The default host target for a given device target */
-Target DefaultTargetHost(Target target) {
- if (target.defined() && target->GetTargetDeviceType() == kDLCPU) {
- return target;
- } else {
- if (LLVMEnabled()) {
- return Target("llvm");
- } else {
- return Target("stackvm");
- }
- }
-}
-
-void GetBinds(const Array<ObjectRef>& args, bool compact,
- const std::unordered_map<te::Tensor, tir::Buffer>& binds,
- Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>*
out_arg_list) {
- *out_binds = binds;
-
- for (const ObjectRef& x : args) {
- if (auto tensor_node = x.as<te::Tensor>()) {
- te::Tensor x_ref = tensor_node.value();
- if (out_binds->find(x_ref) == out_binds->end()) {
- tir::Buffer buf = tir::BufferWithOffsetAlignment(x_ref->shape,
x_ref->dtype,
- x_ref->op->name, -1,
0, compact);
- out_binds->Set(x_ref, buf);
- out_arg_list->push_back(buf);
- } else {
- out_arg_list->push_back((*out_binds)[x_ref]);
- }
- } else if (x.as<te::BufferNode>() || x.as<tir::VarNode>()) {
- out_arg_list->push_back(x);
- } else {
- LOG(FATAL)
- << "Expected type of the elements of args to be te::Tensor,
te::Buffer or tir::Var, "
- << "but got a " << x->GetTypeKey();
- }
- }
-}
-
-void GetBinds(const Array<te::Tensor>& args, bool compact,
- const std::unordered_map<te::Tensor, tir::Buffer>& binds,
- Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>*
out_arg_list) {
- Array<ObjectRef> ref_args;
- for (ObjectRef x : args) {
- ref_args.push_back(x);
- }
- GetBinds(ref_args, compact, binds, out_binds, out_arg_list);
-}
-
-TVM_REGISTER_GLOBAL("driver.get_binds")
- .set_body_typed([](const Array<ObjectRef>& args, bool compact,
- const Map<te::Tensor, tir::Buffer>& binds) {
- std::unordered_map<te::Tensor, tir::Buffer> c_binds;
- // Check to make sure binds is not null before doing the conversion;
- if (binds.get() != nullptr) {
- for (auto kv : binds) {
- c_binds.insert({kv.first, kv.second});
- }
- }
- Map<te::Tensor, tir::Buffer> out_binds;
- Array<ObjectRef> out_arg_list;
- GetBinds(args, compact, c_binds, &out_binds, &out_arg_list);
-
- // TVM object system doesn't have a pair object, so we'll put both ret
values in an array
- // and return that.
- Array<ObjectRef> out_arr = {out_binds, out_arg_list};
- return out_arr;
- });
-
-Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
- transform::PassContext pass_ctx = transform::PassContext::Current();
-
- bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize",
Bool(false)).value();
- bool disable_storage_rewrite =
- pass_ctx->GetConfig<Bool>("tir.disable_storage_rewrite",
Bool(false)).value();
- bool instrument_bound_checkers =
- pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers",
Bool(false)).value();
- bool disable_cse_tir = pass_ctx->GetConfig<Bool>("tir.disable_cse_tir",
Bool(false)).value();
- bool enable_equiv_terms_in_cse_tir =
- pass_ctx->GetConfig<Bool>("tir.enable_equiv_terms_in_cse_tir",
Bool(false)).value();
-
- bool ptx_ldg32 = pass_ctx->GetConfig<Bool>("tir.ptx_ldg32",
Bool(false)).value();
-
- // Get any user-added passes
- Array<Array<ObjectRef>> add_lower_pass =
- pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass",
Array<Array<ObjectRef>>())
- .value();
-
- bool instrument_lwp = pass_ctx->GetConfig<Bool>("tir.instrument_lwp",
Bool(false)).value();
-
- Array<transform::Pass> user_lower_phase0 = Array<transform::Pass>();
- Array<transform::Pass> user_lower_phase1 = Array<transform::Pass>();
- Array<transform::Pass> user_lower_phase2 = Array<transform::Pass>();
- Array<transform::Pass> user_lower_phase3 = Array<transform::Pass>();
-
- // phase passes is of the form
- // [[phase_number, pass], [phase_number, pass]... ]
- for (Array<ObjectRef> phase_pass : add_lower_pass) {
- auto phase_num = phase_pass[0].as<runtime::Int::ContainerType>();
- ICHECK(phase_num)
- << "Expected the first entry in the inner Array of tir.add_lower_pass
to be an integer, "
- << "but instead received " << phase_pass[0] << " with type " <<
phase_pass[0]->GetTypeKey();
- int phase_num_val = phase_num->value;
-
- CHECK_GE(phase_num_val, 0);
-
- auto pass = Downcast<tvm::transform::Pass>(phase_pass[1]);
- // Copy the pass into the correct phase
- if (phase_num_val == 0) {
- user_lower_phase0.push_back(pass);
- } else if (phase_num_val == 1) {
- user_lower_phase1.push_back(pass);
- } else if (phase_num_val == 2) {
- user_lower_phase2.push_back(pass);
- } else if (phase_num_val >= 3) {
- user_lower_phase3.push_back(pass);
- }
- }
-
- // Construct the pass list, inserting the user provided passes at the end of
the phase
-
- // PHASE 0
- Array<tvm::transform::Pass> pass_list = user_lower_phase0;
-
- // PHASE 1
- pass_list.push_back(tir::transform::InjectPrefetch());
- pass_list.push_back(tir::transform::TextureFlatten());
- pass_list.push_back(tir::transform::StorageFlatten(64,
instrument_bound_checkers));
- pass_list.push_back(tir::transform::LowerCrossThreadReduction());
- pass_list.push_back(tir::transform::LowerInitBlock());
- pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
- pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
- pass_list.push_back(tir::transform::LiftThreadBinding());
- pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage());
- pass_list.push_back(tir::transform::CompactBufferAllocation());
- pass_list.push_back(tir::transform::LowerAutoCopy());
- pass_list.push_back(tir::transform::UnifyThreadBinding());
- pass_list.push_back(tir::transform::LowerMatchBuffer());
- pass_list.push_back(tir::transform::Simplify());
- pass_list.push_back(tir::transform::InjectPermutedLayout());
- pass_list.push_back(tir::transform::Simplify());
- pass_list.push_back(tir::transform::InjectSoftwarePipeline());
- pass_list.push_back(tir::transform::TransformMmaBufferLayout());
- pass_list.push_back(tir::transform::LowerOpaqueBlock());
- pass_list.push_back(tir::transform::FlattenBuffer());
- pass_list.push_back(tir::transform::BF16ComputeLegalize());
- pass_list.push_back(tir::transform::NarrowDataType(32));
- pass_list.push_back(tir::transform::Simplify());
-
- // Add user-defined phase-1 passes
- pass_list.insert(pass_list.end(), user_lower_phase1.begin(),
user_lower_phase1.end());
-
- // PHASE 2
- if (!disable_loop_partition) {
- pass_list.push_back(tir::transform::LoopPartition());
- }
-
- pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
- pass_list.push_back(tir::transform::InjectVirtualThread());
- pass_list.push_back(tir::transform::InjectDoubleBuffer());
- if (!disable_storage_rewrite) {
- pass_list.push_back(tir::transform::StorageRewrite());
- }
- bool use_async_copy = pass_ctx->GetConfig<Bool>("tir.use_async_copy",
Bool(false)).value();
-
- if (use_async_copy) {
- pass_list.push_back(tir::transform::LowerAsyncDMA());
- }
- // HoistIfThenElse must be applied before UnrollLoop
- // because HoistIfThenElse could utilize for loop structure
- // which might be unrolled in UnrollLoop
- pass_list.push_back(tir::transform::HoistIfThenElse());
- pass_list.push_back(tir::transform::UnrollLoop());
-
- // Add user-defined phase-2 passes
- pass_list.insert(pass_list.end(), user_lower_phase2.begin(),
user_lower_phase2.end());
-
- // PHASE 3
- pass_list.push_back(tir::transform::RenormalizeSplitPattern());
- pass_list.push_back(tir::transform::Simplify());
- pass_list.push_back(tir::transform::RemoveNoOp());
- pass_list.push_back(tir::transform::RewriteUnsafeSelect());
-
- // Add user-defined phase-3 passes
- pass_list.insert(pass_list.end(), user_lower_phase3.begin(),
user_lower_phase3.end());
-
- if (instrument_bound_checkers) {
- pass_list.push_back(tir::transform::InstrumentBoundCheckers());
- }
-
- if (ptx_ldg32) {
- pass_list.push_back(tir::transform::InjectPTXLDG32(true));
- }
-
- pass_list.push_back(
- tir::transform::CommonSubexprElimTIR(!disable_cse_tir,
enable_equiv_terms_in_cse_tir));
-
- // This pass instruments the loops with the profile builtin calls to capture
the runtime
- // performance data (only enabled for Hexagon at the moment). To ensure that
no other
- // optimizations are performed on the instrumented code, this pass must be
added at the end
- // of the list.
- if (instrument_lwp) {
- pass_list.push_back(tir::transform::InstrumentProfileIntrinsics());
- }
-
- return pass_list;
-}
-
-IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass>
pass_list) {
- auto optimize = tvm::transform::Sequential(pass_list);
- mod = optimize(std::move(mod));
- return mod;
-}
-
-IRModule ApplyPasses(IRModule mod, transform::Sequential seq) {
- mod = seq(std::move(mod));
- return mod;
-}
-
-IRModule LowerModule(IRModule mod, bool simple_mode) {
- Array<transform::Pass> pass_list = CreatePassList(simple_mode);
- return LowerWithPassList(std::move(mod), pass_list);
-}
-
-TVM_REGISTER_GLOBAL("driver.lower_module").set_body_typed([](IRModule mod,
bool simple_mode) {
- return LowerModule(std::move(mod), simple_mode);
-});
-
-IRModule LowerPrimFunc(tir::PrimFunc func, const std::string& name, bool
simple_mode) {
- transform::PassContext pass_ctx = transform::PassContext::Current();
- tir::PrimFunc f = WithAttr(std::move(func), "global_symbol",
runtime::String(name));
-
- bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
-
- if (noalias) {
- f = WithAttr(std::move(f), "tir.noalias", Bool(true));
- }
- IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-
- // Get the pass list
- Array<transform::Pass> pass_list = CreatePassList(simple_mode);
- return LowerWithPassList(std::move(mod), pass_list);
-}
-
-TVM_REGISTER_GLOBAL("driver.lower_primfunc")
- .set_body_typed([](te::PrimFunc func, const String& name, bool
simple_mode) {
- return LowerPrimFunc(std::move(func), name, simple_mode);
- });
-
-/**
- * This function takes the input module that contains both the device and host
opts.
- * Then, it applies transformation on the original module before splitting
into separate modules for
- * device and host. Then it also applies transformations on the new splitted
modules.
- */
-std::pair<IRModule, IRModule> SplitMixedModule(IRModule mod_mixed, const
Target& target_arg,
- const Target& target_host_arg) {
- Target target = target_arg, target_host = target_host_arg;
- CheckAndUpdateHostConsistency(&target, &target_host);
-
- ICHECK(mod_mixed.defined()) << "This module must be defined";
-
- mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed,
target));
-
- IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed,
target_host));
-
- IRModule device_mod = ApplyPasses(mod_mixed,
DeviceModulePassManager(mod_mixed, target));
-
- auto keys = target->GetKeys();
-
- CheckAndUpdateHostConsistency(&target, &target_host);
-
- bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") !=
keys.end();
- if (target_is_gpu && device_mod->functions.size() == 0) {
- DLOG(WARNING) << "Specified target " << target->str()
- << " but cannot find device code. Did you forget to bind?";
- }
-
- return {host_mod, device_mod};
-}
-
-/*!
- * \brief Check and update host field of the given legacy heterogeneous
targets and
- * target host.Note that this function is for legacy target api compatibility
issue only,
- * not recommended for other use.
- * \param ir_modules The pointer to a Map objects with keys being Target
objects
- * \param host The Target typed object for target host to be updated
- */
-void CheckAndUpdateHostConsistency(Map<Target, IRModule>* targets, Target*
host) {
- Map<Target, IRModule> new_targets;
- for (auto& it : *targets) {
- auto target = it.first;
- CheckAndUpdateHostConsistency(&target, host);
- new_targets.Set(target, it.second);
- }
- *targets = new_targets;
-}
-
-runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg,
- const Target& target_host_arg) {
- CHECK(inputs_arg.size()) << "TIRToRuntime expects at least one IRModule as
input.";
- std::vector<runtime::Module> device_modules;
- Map<Target, IRModule> inputs = inputs_arg;
- Target target_host = target_host_arg;
-
- // Fetch previous defined target host in targets
- CheckAndUpdateHostConsistency(&inputs, &target_host);
-
- if (!target_host.defined()) {
- for (const auto& it : inputs) {
- if (it.first->GetTargetDeviceType() == kDLCPU) {
- target_host = it.first;
- break;
- }
- }
- }
-
- if (!target_host.defined()) {
- target_host = DefaultTargetHost(target_host);
- }
-
- // Update target host for all targets
- CheckAndUpdateHostConsistency(&inputs, &target_host);
-
- // Take the attrs from the first module so the eventual modules have them.
- // Ideally this would just be one unified module all the way through;
- IRModule first_module = (*inputs.begin()).second;
- IRModule mhost_all = IRModule(Map<GlobalVar, BaseFunc>(), {},
first_module->attrs);
-
- ICHECK(mhost_all.defined()) << "The host module must be defined";
-
- for (const auto& it : inputs) {
- if (it.second.defined()) {
- const Target& target = it.first;
- const IRModule& ir_module = it.second;
- auto pair = SplitMixedModule(ir_module, target, target_host);
- auto& host_mod = pair.first;
- auto& device_mod = pair.second;
-
- ICHECK(host_mod.defined()) << "The split host module must be defined";
-
- ICHECK(mhost_all.defined()) << "The host module must be defined";
-
- // We don't want library modules going back into host codegen
- // unless they're supposed to. Here if we overrode the target host
- // to allow lowering previously we check that it's meant to be placed
- // back into the host Module.
- bool overrides_host_target =
- target->GetTargetDeviceType() == target_host->GetTargetDeviceType();
- bool non_host_target_kind = target->kind != target_host->kind;
- if (overrides_host_target && non_host_target_kind) {
- device_modules.push_back(codegen::Build(host_mod, it.first));
- } else {
- mhost_all->Update(host_mod);
- }
-
- if (device_mod->functions.size() != 0) {
- device_modules.push_back(codegen::Build(device_mod, it.first));
- }
- }
- }
-
- runtime::Module mhost = codegen::Build(mhost_all, target_host);
- for (const auto& it : device_modules) {
- if (it.operator->()) {
- mhost.Import(it);
- }
- }
-
- return mhost;
-}
-
-TVM_REGISTER_GLOBAL("driver.tir_to_runtime")
- .set_body_typed([](const Map<Target, IRModule>& inputs_arg, Target
host_target) {
- return TIRToRuntime(inputs_arg, host_target);
- });
-
-// Build for heterogeneous execution when targets are specified as
-// objects. This wrapper around the internal API is maintained for
-// backwards compatibility.
-runtime::Module build(const Map<Target, IRModule>& input, const Target&
target_host) {
- return TIRToRuntime(input, target_host);
-}
-
-// Build for heterogeneous execution when target is a string.
-runtime::Module build(const Map<String, IRModule>& inputs_arg, const Target&
target_host_arg) {
- Map<Target, IRModule> updated_inputs;
- Target target_host = target_host_arg;
- for (const auto& it : inputs_arg) {
- Target target = Target(it.first);
- CheckAndUpdateHostConsistency(&target, &target_host);
- Optional<String> device = target->GetAttr<String>("device");
- if (device.defined() && device.value() == "vta") {
- target = Target("ext_dev");
- }
- updated_inputs.Set(target, it.second);
- }
- return TIRToRuntime(updated_inputs, target_host);
-}
-
-// Build for homogeneous execution.
-runtime::Module build(const IRModule& funcs, const Target& target_arg,
- const Target& target_host_arg) {
- auto target = target_arg, target_host = target_host_arg;
- CheckAndUpdateHostConsistency(&target, &target_host);
- // More maps of target and target host
- Map<Target, IRModule> inputs = {{target, funcs}};
- return TIRToRuntime(inputs, target_host);
-}
-
-transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target
target) {
- transform::PassContext pass_ctx = transform::PassContext::Current();
-
- Array<Pass> mixed_pass_list;
-
- // FPComputeLegalize uses the target attrs added by BindTarget, so it must
come first
- mixed_pass_list.push_back(tir::transform::BindTarget(target));
- mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize());
-
- // VerifyVTCMLimit must occur before LowerVtcmAlloc
- mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target));
- // LowerVtcmAlloc must occur after any transformations that modify memory
allocation locations
- mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc());
-
- mixed_pass_list.push_back(tir::transform::VerifyMemory());
-
- mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc());
-
- bool detect_global_barrier =
- pass_ctx->GetConfig<Bool>("tir.detect_global_barrier",
Bool(false)).value();
- if (detect_global_barrier) {
- mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
- }
-
- mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
- mixed_pass_list.push_back(tir::transform::ThreadSync("shared.dyn"));
- mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
- mixed_pass_list.push_back(tir::transform::InferFragment());
- mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
-
- bool use_async_copy = pass_ctx->GetConfig<Bool>("tir.use_async_copy",
Bool(false)).value();
-
- if (use_async_copy) {
- mixed_pass_list.push_back(tir::transform::InjectPTXAsyncCopy());
- }
-
- bool ptx_ldg32 = pass_ctx->GetConfig<Bool>("tir.ptx_ldg32",
Bool(false)).value();
- if (ptx_ldg32) {
- mixed_pass_list.push_back(tir::transform::InjectPTXLDG32());
- }
-
- mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
- mixed_pass_list.push_back(tir::transform::SplitHostDevice());
- // MergeSharedMemoryAllocations must be applied after SplitHostDevice
- // because the merged allocation site is at the beginning of each device
function
- mixed_pass_list.push_back(tir::transform::MergeSharedMemoryAllocations());
-
- mixed_pass_list.push_back(tir::transform::MakePackedAPI());
- mixed_pass_list.push_back(tir::transform::FP8StorageLegalize());
- mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());
-
- mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());
-
- return transform::Sequential(mixed_pass_list);
-}
-
-TVM_REGISTER_GLOBAL("driver.mixed_mod_passes")
- .set_body_typed([](IRModule mixed_mod, Target target) {
- return MixedModulePassManager(mixed_mod, target);
- });
-
-transform::Sequential HostModulePassManager(IRModule mixed_mod, Target
target_host) {
- transform::PassContext pass_ctx = transform::PassContext::Current();
-
- Array<tvm::transform::Pass> host_pass_list;
-
- runtime::TypedPackedFunc<bool(tir::PrimFunc)> fcond = [](const
tir::PrimFunc& f) {
- return f->GetAttr<Integer>(tvm::attr::kCallingConv,
Integer(CallingConv::kDefault)) !=
- CallingConv::kDeviceKernelLaunch;
- };
- host_pass_list.push_back(tir::transform::Filter(fcond));
-
- ICHECK(mixed_mod.defined()) << "This module must be defined";
-
- host_pass_list.push_back(tir::transform::BindTarget(target_host));
-
- host_pass_list.push_back(tir::transform::LowerTVMBuiltin());
- host_pass_list.push_back(tir::transform::LowerCustomDatatypes());
- host_pass_list.push_back(tir::transform::LowerIntrin());
- host_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo());
- host_pass_list.push_back(tir::transform::CombineContextCall());
-
- return transform::Sequential(host_pass_list);
-}
-
-TVM_REGISTER_GLOBAL("driver.host_mod_passes")
- .set_body_typed([](IRModule mixed_mod, Target target_host) {
- return HostModulePassManager(mixed_mod, target_host);
- });
-
-transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target
target) {
- Array<Pass> device_pass_list;
- runtime::TypedPackedFunc<bool(tir::PrimFunc)> fcond = [](const
tir::PrimFunc& f) {
- return f->GetAttr<Integer>(tvm::attr::kCallingConv,
Integer(CallingConv::kDefault)) ==
- CallingConv::kDeviceKernelLaunch;
- };
- device_pass_list.push_back(tir::transform::Filter(fcond));
-
- device_pass_list.push_back(tir::transform::BindTarget(target));
-
- device_pass_list.push_back(tir::transform::LowerWarpMemory());
- device_pass_list.push_back(tir::transform::Simplify());
- device_pass_list.push_back(tir::transform::LowerCustomDatatypes());
- device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo());
- device_pass_list.push_back(tir::transform::LowerIntrin());
-
- return transform::Sequential(device_pass_list);
-}
-
-TVM_REGISTER_GLOBAL("driver.device_mod_passes")
- .set_body_typed([](IRModule mixed_mod, Target target_host) {
- return DeviceModulePassManager(mixed_mod, target_host);
- });
-
-} // namespace tvm
diff --git a/src/driver/internal_driver_api.h b/src/driver/internal_driver_api.h
deleted file mode 100644
index 3b7cc7c7f7..0000000000
--- a/src/driver/internal_driver_api.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file src/driver/driver_api.h
- * \brief Internal compiler driver APIs to drive the compilation.
- *
- * This module provides functionality that may be called internally
- * within TVM, but is not part of the public-facing API.
- */
-#ifndef TVM_DRIVER_INTERNAL_DRIVER_API_H_
-#define TVM_DRIVER_INTERNAL_DRIVER_API_H_
-
-#include <tvm/ir/module.h>
-#include <tvm/target/target.h>
-
-namespace tvm {
-
-/*!
- * \brief Build a device and host module for a specific target from a map
- * contains target to IRModule. This function is used
- * for heterogeneous build.
- * \param input The map contains target to an IRModule.
- * \param target_host The target for building host code. To use the default,
- * pass Target().
- * \return The built module that contains code for different processors.
- */
-runtime::Module TIRToRuntime(const Map<Target, IRModule>& input, const Target&
target_host);
-
-} // namespace tvm
-
-#endif // TVM_DRIVER_INTERNAL_DRIVER_API_H_
diff --git a/src/relax/backend/vm/codegen_vm.cc
b/src/relax/backend/vm/codegen_vm.cc
index 8c0ddeb6c3..18da88be80 100644
--- a/src/relax/backend/vm/codegen_vm.cc
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -21,7 +21,6 @@
* \file src/relax/backend/vm/codegen_vm.cc
* \brief A codegen to generate VM executable from a Relax IRModule.
*/
-#include <tvm/driver/driver_api.h>
#include <tvm/relax/exec_builder.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/op_attr_types.h>
diff --git a/src/relax/backend/vm/codegen_vm_tir.cc
b/src/relax/backend/vm/codegen_vm_tir.cc
index a92cf7c749..e3812ea8c1 100644
--- a/src/relax/backend/vm/codegen_vm_tir.cc
+++ b/src/relax/backend/vm/codegen_vm_tir.cc
@@ -21,7 +21,6 @@
* \file src/relax/backend/vm/codegen_tir.cc
* \brief A codegen to generate VMTIR function(that can be compiled) from
executable.
*/
-#include <tvm/driver/driver_api.h>
#include <tvm/ir/module.h>
#include <tvm/relax/exec_builder.h>
#include <tvm/relax/expr_functor.h>
diff --git a/src/relax/transform/bind_params.cc
b/src/relax/transform/bind_params.cc
index 27931b6017..14f68da3e4 100644
--- a/src/relax/transform/bind_params.cc
+++ b/src/relax/transform/bind_params.cc
@@ -17,7 +17,6 @@
* under the License.
*/
-#include <tvm/driver/driver_api.h>
#include <tvm/ir/function.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
diff --git a/src/relax/transform/fold_constant.cc
b/src/relax/transform/fold_constant.cc
index ff193acf14..fb6a01a19d 100644
--- a/src/relax/transform/fold_constant.cc
+++ b/src/relax/transform/fold_constant.cc
@@ -17,7 +17,6 @@
* under the License.
*/
-#include <tvm/driver/driver_api.h>
#include <tvm/ir/function.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr_functor.h>
@@ -116,8 +115,10 @@ class ConstantFolder : public ExprMutator {
// already scheduled to only work on GPU, we will need to skip this in
the const folder for
// now
// TODO(Hongyi): further check and narrow the scope of foldable function
- runtime::Module rt_module =
- build(LowerPrimFunc(func, "tir_function"), eval_cpu_target,
eval_cpu_target);
+ auto* pf = runtime::Registry::Get("tir.build");
+ ICHECK(pf != nullptr) << "Cannot find tir.build in registry";
+ func = WithAttr(func, tvm::attr::kGlobalSymbol, String("tir_function"));
+ runtime::Module rt_module = (*pf)(func, eval_cpu_target);
build_func = rt_module.GetFunction("tir_function");
} catch (const tvm::Error& err) {
// build failure may happen in which case we skip
diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc
index fbc43a00ca..1c77219d45 100644
--- a/src/tir/ir/transform.cc
+++ b/src/tir/ir/transform.cc
@@ -29,6 +29,26 @@ namespace tvm {
namespace tir {
namespace transform {
+// Register build pipeline related options
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.noalias", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_buffer_level_predication", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_static_smem", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool);
+
/*!
* \brief Function level pass that applies transformations to all
* TIR functions within the module.
diff --git a/src/tir/transforms/primfunc_utils.cc
b/src/tir/transforms/primfunc_utils.cc
index 7f45fee9a2..d5946fda21 100644
--- a/src/tir/transforms/primfunc_utils.cc
+++ b/src/tir/transforms/primfunc_utils.cc
@@ -22,7 +22,6 @@
* \brief Passes that serve as helper functions.
*/
-#include <tvm/driver/driver_api.h>
#include <tvm/tir/transform.h>
namespace tvm {
diff --git a/tests/python/codegen/test_target_codegen_cuda.py
b/tests/python/codegen/test_target_codegen_cuda.py
index ae3173a14d..b3cad9acd3 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -766,6 +766,7 @@ def test_cuda_thread_sync_inside_condition():
tvm.build(mod, target="cuda")
[email protected]_cuda
def test_invalid_reinterpret():
@T.prim_func
def func(A: T.Buffer((4,), "uint32"), B: T.Buffer((4,), "uint8")) -> None:
diff --git a/tests/python/codegen/test_target_codegen_llvm.py
b/tests/python/codegen/test_target_codegen_llvm.py
index e3ccff49ba..304c79559c 100644
--- a/tests/python/codegen/test_target_codegen_llvm.py
+++ b/tests/python/codegen/test_target_codegen_llvm.py
@@ -42,7 +42,7 @@ def test_llvm_intrin():
body = ib.get()
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A],
body).with_attr("global_symbol", "prefetch"))
- fcode = tvm.build(mod, None, "llvm")
+ fcode = tvm.build(mod, None)
@tvm.testing.requires_llvm
@@ -54,7 +54,7 @@ def test_llvm_void_intrin():
ib.emit(x)
body = ib.get()
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A],
body).with_attr("global_symbol", "main"))
- fcode = tvm.build(mod, None, "llvm")
+ fcode = tvm.build(mod, None)
@tvm.testing.requires_llvm
@@ -106,7 +106,7 @@ def test_llvm_lookup_intrin():
ib.emit(x)
body = ib.get()
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A],
body).with_attr("global_symbol", "main"))
- fcode = tvm.build(mod, None, "llvm")
+ fcode = tvm.build(mod, None)
@tvm.testing.requires_llvm
@@ -557,9 +557,6 @@ def test_llvm_div():
print("dtype: {}".format(dtype))
print("dividend range: [{}, {}]".format(start, end))
print("divisor range: [{}, {}]".format(dstart, dend))
- lowered = tvm.lower(sch.mod, simple_mode=True)
- print("Lowered code:")
- print(lowered)
# Check that the computed values are correct
for i in range(start, end + 1):
@@ -764,44 +761,6 @@ def test_dwarf_debug_information():
check_llvm_ir()
[email protected]_llvm
-def test_llvm_shuffle():
- a = te.placeholder((8,), "int32")
- b = te.placeholder((8,), "int32")
- c = te.compute((8,), lambda x: a[x] + b[7 - x])
-
- # Convert to TIR and create schedule
- mod = te.create_prim_func([a, b, c])
- sch = tir.Schedule(mod)
-
- def my_vectorize():
- def vectorizer(op):
- store = op.body
- idx = tvm.tir.Ramp(tvm.tir.const(0, "int32"), tvm.tir.const(1,
"int32"), 8)
- value = store.value
- b_idx = tvm.tir.Shuffle([idx], [tvm.tir.const(i, "int32") for i in
range(7, -1, -1)])
- new_a = tvm.tir.BufferLoad(value.a.buffer, [idx])
- new_b = tvm.tir.BufferLoad(value.b.buffer, [b_idx])
- value = new_a + new_b
- return tvm.tir.BufferStore(store.buffer, new_a + new_b, [idx])
-
- def _transform(f, *_):
- return f.with_body(
- tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer,
["tir.For"])
- )
-
- return tvm.tir.transform.prim_func_pass(_transform, opt_level=0,
name="my_vectorize")
-
- with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1,
my_vectorize())]}):
- ir = tvm.lower(sch.mod, simple_mode=True)
- module = tvm.build(sch.mod)
- a_ = tvm.nd.array(np.arange(1, 9, dtype="int32"))
- b_ = tvm.nd.array(np.arange(8, 0, -1, dtype="int32"))
- c_ = tvm.nd.array(np.zeros((8,), dtype="int32"))
- module(a_, b_, c_)
- tvm.testing.assert_allclose(c_.numpy(), (a_.numpy() *
2).astype("int32"))
-
-
def np_float2np_bf16(arr):
"""Convert a numpy array of float to a numpy array
of bf16 in uint16"""
diff --git a/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py
b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py
index 99fc6ac074..169d868b54 100644
--- a/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py
+++ b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py
@@ -244,15 +244,6 @@ class TestElementWise:
return [sch.mod]
- @tvm.testing.fixture
- def ir_module(self, schedule_args):
- # If the two buffers are accessed with the same indices, CSE
- # will replace them with a Let binding. Since this makes it
- # harder to test what the transformed indices are, disabling
- # the CSE pass for this test.
- with
tvm.transform.PassContext(disabled_pass=["tir.CommonSubexprElimTIR"]):
- return tvm.lower(*schedule_args)
-
@tvm.testing.fixture
def uses_unsupported_physical_dimensions( # pylint: disable=invalid-name
self, target_host, input_layout, working_layout, output_layout
@@ -291,9 +282,6 @@ class TestElementWise:
assert len(buffer.shape) == expected_physical_dimensions
- def test_lower(self, schedule_args):
- assert tvm.lower(*schedule_args)
-
@requires_hexagon_toolchain
def test_build(self, schedule_args, target_host, input_layout,
working_layout, output_layout):
"""Testing build success/failure
diff --git a/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py
b/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py
index a927532c8f..f0cefa3fe2 100644
--- a/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py
+++ b/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py
@@ -199,12 +199,6 @@ def _benchmark_hexagon_elementwise_add_kernel(
try:
ns_tir_module = _get_irmod_elemwise_add(shape, dtype, mem_scope)
- # Dump the primfunc NS-TIR (as text) to the log file...
- lowered_mod = tvm.lower(ns_tir_module, _PRIMFUNC_NAME)
- log_file.write("LOWERED IR MODULE:\n")
- log_file.write(str(lowered_mod))
- log_file.write("\n")
-
# Lower the primfunc's IRModule to Hexagon object code...
input1 = tvm.te.placeholder(shape, dtype=dtype)
input2 = tvm.te.placeholder(shape, dtype=dtype)
diff --git a/tests/python/contrib/test_hexagon/test_meta_schedule.py
b/tests/python/contrib/test_hexagon/test_meta_schedule.py
index 26acedb88e..c0c7355a9a 100644
--- a/tests/python/contrib/test_hexagon/test_meta_schedule.py
+++ b/tests/python/contrib/test_hexagon/test_meta_schedule.py
@@ -156,7 +156,7 @@ def schedule_dense(sch, block, m_size, do_tune):
def verify_dense(sch, target, m_size, n_size, k_size, hexagon_session):
"""Verify dense operator."""
- f = tvm.build(sch.mod["main"], target=target, name="dense")
+ f = tvm.build(sch.mod["main"], target=target)
mod = hexagon_session.load_module(f)
dev = hexagon_session.device
diff --git a/tests/python/contrib/test_hexagon/test_sigmoid.py
b/tests/python/contrib/test_hexagon/test_sigmoid.py
index cc633795c2..1247d90759 100644
--- a/tests/python/contrib/test_hexagon/test_sigmoid.py
+++ b/tests/python/contrib/test_hexagon/test_sigmoid.py
@@ -92,7 +92,7 @@ class TestSigmoid(BaseSigmoid):
func_name = "sigmoid"
with tvm.transform.PassContext(opt_level=3):
- runtime_module = tvm.build(tir_s.mod,
target=get_hexagon_target("v69"), name=func_name)
+ runtime_module = tvm.build(tir_s.mod,
target=get_hexagon_target("v69"))
assert "hvx_sigmoid" in runtime_module.get_source("asm")
assert "vmin" in runtime_module.get_source("asm")
diff --git a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py
b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py
index 498e29e407..d45b35befd 100644
--- a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py
+++ b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py
@@ -183,7 +183,6 @@ class TestAsyncSoftwarePipeline:
"tir.experimental_dma_bypass_cache": 1,
}
):
- # tvm.lower(schedule.mod["main"]).show()
func = tvm.build(schedule.mod["main"],
target=get_hexagon_target("v68"))
with hexagon_launcher.create_session() as hexagon_session:
diff --git a/tests/python/ir/test_pass_instrument.py
b/tests/python/ir/test_pass_instrument.py
index cfeb70b963..718cf3a663 100644
--- a/tests/python/ir/test_pass_instrument.py
+++ b/tests/python/ir/test_pass_instrument.py
@@ -38,7 +38,7 @@ def test_tir_print_all_passes(capsys):
B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0
with tvm.transform.PassContext(opt_level=3, instruments=[PrintBeforeAll(),
PrintAfterAll()]):
- tvm.lower(func)
+ tvm.build(func)
all_passes_output = capsys.readouterr().out
assert "Before Running Pass:" in all_passes_output
assert "After Running Pass:" in all_passes_output
diff --git a/tests/python/tir-base/test_lower_build.py
b/tests/python/tir-base/test_lower_build.py
deleted file mode 100644
index edb3ed351e..0000000000
--- a/tests/python/tir-base/test_lower_build.py
+++ /dev/null
@@ -1,133 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import numpy as np
-
-import tvm
-from tvm.ir.module import IRModule
-from tvm.script import tir as T
-import tvm.testing
-
-
-def _check_module_with_numpy(mod, shape=(128, 128, 128)):
- m, n, k = shape
- a = tvm.nd.array(np.random.rand(m, k).astype("float32"))
- b = tvm.nd.array(np.random.rand(n, k).astype("float32"))
- c = tvm.nd.array(np.zeros((m, n), dtype="float32"))
- c_np = np.dot(a.numpy(), b.numpy().transpose())
- mod(a, b, c)
- tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
-
-
-# pylint: disable=no-self-argument, missing-class-docstring,
missing-function-docstring
[email protected]_func
-def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
- A = T.match_buffer(a, [128, 128])
- B = T.match_buffer(b, [128, 128])
- C = T.match_buffer(c, [128, 128])
- for i, j in T.grid(128, 128):
- with T.block("init"):
- vi, vj = T.axis.remap("SS", [i, j])
- C[vi, vj] = T.float32(0)
- for k in range(128):
- with T.block("update"):
- vi, vj, vk = T.axis.remap("SSR", [i, j, k])
- C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
-
-
[email protected]_module
-class LoweredModule:
- @T.prim_func
- def main(
- A: T.Buffer((128, 128), "float32"),
- B: T.Buffer((128, 128), "float32"),
- C: T.Buffer((128, 128), "float32"),
- ) -> None:
- # function attr dict
- T.func_attr({"global_symbol": "main", "from_legacy_te_schedule": True,
"tir.noalias": True})
- A_flat = T.Buffer([16384], data=A.data)
- B_flat = T.Buffer([16384], data=B.data)
- C_flat = T.Buffer([16384], data=C.data)
- # body
- for x, y in T.grid(128, 128):
- C_flat[x * 128 + y] = 0.0
- for k in T.serial(0, 128):
- C_flat[x * 128 + y] = (
- C_flat[x * 128 + y] + A_flat[x * 128 + k] * B_flat[y * 128
+ k]
- )
-
-
[email protected]_module
-class LoweredTIRModule:
- @T.prim_func
- def main(
- A: T.Buffer((128, 128), "float32"),
- B: T.Buffer((128, 128), "float32"),
- C: T.Buffer((128, 128), "float32"),
- ) -> None:
- # function attr dict
- T.func_attr({"global_symbol": "main", "tir.noalias": True})
- A_flat = T.Buffer([16384], data=A.data)
- B_flat = T.Buffer([16384], data=B.data)
- C_flat = T.Buffer([16384], data=C.data)
- # body
- for x, y in T.grid(128, 128):
- C_flat[x * 128 + y] = 0.0
- for k in T.serial(0, 128):
- C_flat[x * 128 + y] = (
- C_flat[x * 128 + y] + A_flat[x * 128 + k] * B_flat[y * 128
+ k]
- )
-
-
-def test_lower_build_tir_func():
- # check lowering with the CSE pass disabled as otherwise it would do some
commoning
- with tvm.transform.PassContext(opt_level=3,
disabled_pass=["tir.CommonSubexprElimTIR"]):
- ir_mod = tvm.lower(matmul)
- tvm.ir.assert_structural_equal(ir_mod, LoweredTIRModule)
- # check building
- mod = tvm.build(matmul, target="llvm")
- _check_module_with_numpy(mod)
-
-
-def test_lower_build_tir_module():
- func = matmul.with_attr("global_symbol", "main")
- func = func.with_attr("tir.noalias", T.bool(True))
- ir_mod = IRModule({"main": func})
- # check lowering with the CSE pass disabled as otherwise it would do some
commoning
- with tvm.transform.PassContext(opt_level=3,
disabled_pass=["tir.CommonSubexprElimTIR"]):
- lowered_mod = tvm.lower(ir_mod)
- tvm.ir.assert_structural_equal(lowered_mod, LoweredTIRModule)
- # check building
- mod = tvm.build(ir_mod, target="llvm")
- _check_module_with_numpy(mod)
-
-
-def test_lower_build_lowered_module():
- # check lowering with the CSE pass disabled as otherwise it would do some
commoning
- with tvm.transform.PassContext(opt_level=3,
disabled_pass=["tir.CommonSubexprElimTIR"]):
- ir_mod = tvm.lower(LoweredTIRModule)
- tvm.ir.assert_structural_equal(ir_mod, LoweredTIRModule)
- # check building
- mod = tvm.build(ir_mod, target="llvm")
- _check_module_with_numpy(mod)
-
-
-if __name__ == "__main__":
- test_lower_build_te_schedule()
- test_lower_build_tir_func()
- test_lower_build_tir_module()
- test_lower_build_lowered_module()
diff --git a/tests/python/tir-base/test_tir_te_extern_primfunc.py
b/tests/python/tir-base/test_tir_te_extern_primfunc.py
index 45ca7a1c72..16bc0b0ae2 100644
--- a/tests/python/tir-base/test_tir_te_extern_primfunc.py
+++ b/tests/python/tir-base/test_tir_te_extern_primfunc.py
@@ -192,7 +192,6 @@ class TestPrimFuncs:
input_tensors = [te.placeholder(buf_name_map[name].shape) for name in
params]
output = te.extern_primfunc(input_tensors, prim_func)
rt_prim_func = te.create_prim_func(tensors_from_extern_op(output,
prim_func))
- tvm.ir.assert_structural_equal(tvm.lower(prim_func),
tvm.lower(rt_prim_func))
target = tvm.target.Target("llvm")
func = tvm.build(rt_prim_func, target=target)
diff --git
a/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
b/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
index 390745fe9d..fe9998bc79 100644
---
a/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
+++
b/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
@@ -117,7 +117,7 @@ def run_test(
mma_store_intrin,
)
- f = tvm.build(sch.mod["main"], target="cuda", name="dense")
+ f = tvm.build(sch.mod["main"], target="cuda")
dev = tvm.device("cuda", 0)
diff --git
a/tests/python/tir-schedule/test_tir_schedule_tensorize_mfma_numeric.py
b/tests/python/tir-schedule/test_tir_schedule_tensorize_mfma_numeric.py
index 8077a603bc..2b3e6ce39b 100644
--- a/tests/python/tir-schedule/test_tir_schedule_tensorize_mfma_numeric.py
+++ b/tests/python/tir-schedule/test_tir_schedule_tensorize_mfma_numeric.py
@@ -109,7 +109,7 @@ def run_test(
mma_store_intrin,
)
- f = tvm.build(sch.mod["main"], target="rocm", name="dense")
+ f = tvm.build(sch.mod["main"], target="rocm")
dev = tvm.device("rocm", 0)
if in_dtype == "float32":
diff --git a/tests/python/tir-transform/test_tir_transform_convert_ssa.py
b/tests/python/tir-transform/test_tir_transform_convert_ssa.py
index ec768ba74f..b93747c84a 100644
--- a/tests/python/tir-transform/test_tir_transform_convert_ssa.py
+++ b/tests/python/tir-transform/test_tir_transform_convert_ssa.py
@@ -234,41 +234,6 @@ def test_no_change_if_already_ssa():
assert before.same_as(after)
-class TestDedupAutoBroadcastBuffer(BaseBeforeAfter):
- """De-dup auto-broadcast buffers
-
- Auto-broadcast buffers can define additional variables during the
- `Buffer::Buffer` constructor for the strides. This is intended to
- be used for match buffers, where these variables are defined based
- on the argument being passed in.
-
- These additional variables can cause errors when copying a buffer
- with the `Buffer::Buffer` constructor. If a buffer has non-empty
- shape, empty strides, and kAutoBroadcast type, then the resulting
- buffer will have additional strides defined. Such a buffer can
- result from lowering of a scalar buffer, which will be flattened
- to a shape of [1].
-
- Previous implementations of ConvertSSA incorrectly handled this
- case, resulting in undefined stride variables.
- """
-
- def _make_func(self):
- @T.prim_func
- def func(a: T.handle):
- A = T.match_buffer(a, shape=(), dtype="float32",
buffer_type="auto")
- A[()] = 1.0
-
- return tvm.lower(func)["main"]
-
- def before(self):
- func = self._make_func()
- return tvm.IRModule({"func_a": func, "func_b": func})
-
- def expected(self):
- return tvm.IRModule({"func_a": self._make_func(), "func_b":
self._make_func()})
-
-
class TestKeepDuplicateThreadIdxInSameFunction(BaseBeforeAfter):
"""Environment threads are treated as being at function scope
diff --git a/tests/python/tir-transform/test_tir_transform_extract_constants.py
b/tests/python/tir-transform/test_tir_transform_extract_constants.py
index b3e0aa74f9..cbfb6d39bc 100644
--- a/tests/python/tir-transform/test_tir_transform_extract_constants.py
+++ b/tests/python/tir-transform/test_tir_transform_extract_constants.py
@@ -63,8 +63,6 @@ def test_const_extraction():
for n, f in mod.functions.items():
tvm.tir.stmt_functor.post_order_visit(f.body, _visit)
- tvm.lower(mod)
-
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/tir-transform/test_tir_transform_flatten_buffer.py
b/tests/python/tir-transform/test_tir_transform_flatten_buffer.py
index b215398622..925f004cc5 100644
--- a/tests/python/tir-transform/test_tir_transform_flatten_buffer.py
+++ b/tests/python/tir-transform/test_tir_transform_flatten_buffer.py
@@ -322,36 +322,5 @@ class TestFlattenDeclBufferWithAxisSeparators(BaseCompare):
T.evaluate(A[i0 * 15 + i1 * 5 + i2, i3 * 143 + i4 * 13 + i5])
-def test_lower_2d_physical_memory():
- """Axis separators should preserve 2-d buffers through lowering.
-
- A catch-all test to ensure that defining axis_separators is
- sufficient to maintain non-flat buffer descriptions through all
- lowering steps.
- """
-
- # This test doesn't use CompareBeforeAfter, because the after step
- # is not currently expressible in TVMScript. This test can be
- # re-written after https://github.com/apache/tvm/pull/12412.
-
- @T.prim_func
- def func():
- buf = T.alloc_buffer(
- [1, 1],
- dtype="int32",
- scope="global",
- axis_separators=[1],
- )
- buf[0, 0] = 0
-
- lowered = tvm.lower(func)["main"]
- assert isinstance(lowered.body, tvm.tir.Allocate)
- assert list(lowered.body.extents) == [1, 1], (
- "Non-flat buffer allocations, "
- "marked by axis_separators, "
- "flattened to flat memory allocation."
- )
-
-
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
index 754ce03240..0a040b0eea 100644
--- a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
+++ b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
@@ -179,7 +179,7 @@ def test_call_packed_return_non_i32():
)
mod = build_tir()
- f = tvm.build(mod, None, "llvm")
+ f = tvm.build(mod, None)
a = tvm.nd.array(np.zeros(2, dtype="float32"))
f(a)
tvm.testing.assert_allclose(a.numpy(), expected_value)
diff --git a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py
b/tests/python/tir-transform/test_tir_transform_narrow_datatype.py
index 93c680c846..a7b5280939 100644
--- a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py
+++ b/tests/python/tir-transform/test_tir_transform_narrow_datatype.py
@@ -111,33 +111,6 @@ def test_thread_axis():
check(2**14, 32, target_bits=16, target_dtype="int32")
-def test_thread_axis_2():
- # fmt: off
- @tvm.script.ir_module
- class Before:
- @T.prim_func
- def main(T_reshape: T.Buffer((1, 12, 384, 384), "float32"),
placeholder_1: T.Buffer((T.int64(1), T.int64(12), T.int64(384), 384), "bool"),
T_where: T.Buffer((T.int64(1), T.int64(12), T.int64(384), 384), "float32")) ->
None:
- # function attr dict
- T.func_attr({"global_symbol": "main", "tir.noalias": True})
- # body
- # with T.block("root")
- for i0_i1_i2_i3_fused_1 in T.thread_binding(T.int64(256),
thread="blockIdx.x"):
- for i0_i1_i2_i3_fused_2 in T.thread_binding(T.int64(1024),
thread="threadIdx.x"):
- for i0_i1_i2_i3_fused_0 in T.serial(T.int64(7)):
- with T.block("T_where"):
- ax0 = T.axis.spatial(T.int64(1), T.int64(0))
- ax1 = T.axis.spatial(T.int64(12),
((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) +
i0_i1_i2_i3_fused_2) % T.int64(1769472) // T.int64(147456))
- ax2 = T.axis.spatial(T.int64(384),
((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) +
i0_i1_i2_i3_fused_2) % T.int64(147456) // T.int64(384))
- ax3 = T.axis.spatial(384,
T.cast(((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) *
T.int64(1024) + i0_i1_i2_i3_fused_2) % T.int64(384), "int32"))
- T.where((i0_i1_i2_i3_fused_0 * T.int64(256) +
i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2 < T.int64(1769472))
- T.reads(placeholder_1[ax0, ax1, ax2, ax3],
T_reshape[ax0, ax1, ax2, ax3])
- T.writes(T_where[ax0, ax1, ax2, ax3])
- T_where[ax0, ax1, ax2, ax3] =
T.Select(T.cast(placeholder_1[ax0, ax1, ax2, ax3], "int32") != 0,
T.float32(-1000000000), T_reshape[ax0, ax1, ax2, ax3])
- # fmt: on
- # TODO(@junrushao1994): make this test more "unit" after the new TVMScript
printer/parser lands
- tvm.lower(Before)
-
-
def test_multilanes():
def check(m, lanes, target_bits, target_dtype):
ib = tvm.tir.ir_builder.create()
diff --git a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py
b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py
index ab91c6c7b3..548b199a94 100644
--- a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py
+++ b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py
@@ -610,87 +610,5 @@ class TestNoOrphanedDeclBuffer(BaseCompare):
D[i] = C[i]
-def test_vulkan_smem_reuse():
- target = tvm.target.Target(
- {
- "keys": ["vulkan", "gpu"],
- "kind": "vulkan",
- "max_num_threads": 256,
- "max_threads_per_block": 256,
- "supports_float32": True,
- "supports_int32": True,
- "tag": "",
- "thread_warp_size": 1,
- }
- )
-
- @T.prim_func(private=True)
- def func(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")):
- T.func_attr({"tir.noalias": T.bool(True)})
- A_shared = T.allocate([4], "float32", "shared")
- A_local = T.allocate([4], "float32", "local")
- B_shared = T.allocate([4], "float16", "shared")
- A_shared_1 = T.Buffer((4,), data=A_shared, scope="shared")
- with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
- A_1 = T.Buffer((4,), data=A.data)
- A_shared_1[threadIdx_x] = A_1[threadIdx_x]
- A_local_1 = T.Buffer((4,), data=A_local, scope="local")
- with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
- A_local_1[threadIdx_x] = A_shared_1[threadIdx_x]
- B_shared_1 = T.Buffer((4,), "float16", data=B_shared, scope="shared")
- with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
- B_shared_1[threadIdx_x] = T.Cast("float16", A_local_1[threadIdx_x])
- threadIdx_x = T.launch_thread("threadIdx.x", 4)
- B_1 = T.Buffer((4,), "float16", data=B.data)
- B_1[threadIdx_x] = B_shared_1[threadIdx_x]
-
- @T.prim_func(private=True)
- def normal_lowering(A: T.Buffer((4,), "float32"), B: T.Buffer((4,),
"float16")):
- T.func_attr({"tir.noalias": T.bool(True)})
- A_shared = T.allocate([4], "float32", "shared")
- A_local = T.allocate([4], "float32", "local")
- A_shared_1 = T.Buffer((4,), data=A_shared, scope="shared")
- with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
- A_1 = T.Buffer((4,), data=A.data)
- A_shared_1[threadIdx_x] = A_1[threadIdx_x]
- A_local_1 = T.Buffer((4,), data=A_local, scope="local")
- with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
- A_local_1[threadIdx_x] = A_shared_1[threadIdx_x]
- A_shared_2 = T.Buffer((4,), "float16", data=A_shared, scope="shared")
- with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
- A_shared_2[threadIdx_x] = T.Cast("float16", A_local_1[threadIdx_x])
- threadIdx_x = T.launch_thread("threadIdx.x", 4)
- B_1 = T.Buffer((4,), "float16", data=B.data)
- B_1[threadIdx_x] = A_shared_2[threadIdx_x]
-
- @T.prim_func(private=True)
- def no_reuse_lowering(A: T.Buffer((4,), "float32"), B: T.Buffer((4,),
"float16")):
- T.func_attr({"target": target, "tir.noalias": T.bool(True)})
- A_shared_1 = T.allocate([4], "float32", "shared")
- A_local_1 = T.allocate([4], "float32", "local")
- B_shared_1 = T.allocate([4], "float16", "shared")
- A_shared_1_1 = T.Buffer((4,), data=A_shared_1, scope="shared")
- with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
- A_1 = T.Buffer((4,), data=A.data)
- A_shared_1_1[threadIdx_x] = A_1[threadIdx_x]
- A_local_1_1 = T.Buffer((4,), data=A_local_1, scope="local")
- with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
- A_local_1_1[threadIdx_x] = A_shared_1_1[threadIdx_x]
- B_shared_1_1 = T.Buffer((4,), "float16", data=B_shared_1,
scope="shared")
- with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
- B_shared_1_1[threadIdx_x] = T.Cast("float16",
A_local_1_1[threadIdx_x])
- threadIdx_x = T.launch_thread("threadIdx.x", 4)
- B_1 = T.Buffer((4,), "float16", data=B.data)
- B_1[threadIdx_x] = B_shared_1_1[threadIdx_x]
-
- # Reuse shared memory when lowering without target.
- mod = tvm.IRModule({"main": func})
- tvm.ir.assert_structural_equal(tvm.lower(mod)["main"], normal_lowering)
-
- # No shared memory reuse when lowering with target Vulkan.
- mod = tvm.tir.transform.BindTarget(target)(mod)
- tvm.ir.assert_structural_equal(tvm.lower(mod)["main"], no_reuse_lowering)
-
-
if __name__ == "__main__":
tvm.testing.main()