This is an automated email from the ASF dual-hosted git repository.
zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 75e936e [REFACTOR][TIR] Migrate most of low-level build to use the
Pass Manager. (#5225)
75e936e is described below
commit 75e936e1b5db305864c76277e2ba47c453c4c6a8
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Apr 3 15:50:11 2020 -0700
[REFACTOR][TIR] Migrate most of low-level build to use the Pass Manager.
(#5225)
* [REFACTOR][TIR] Migrate most of low-level build to use the Pass Manager.
- SplitHostDevice
- ThreadSync
- BindDevice
- LowerThreadAllreduce
- Provide a temp fix for printing IRModule with PrimFunc before the formal
text printer.
* Address comments, fix tests.
* Fix relay tests
* Explicit move
---
include/tvm/ir/function.h | 14 +-
include/tvm/ir/module.h | 2 +
include/tvm/tir/analysis.h | 10 ++
include/tvm/tir/ir_pass.h | 78 ----------
include/tvm/tir/transform.h | 15 ++
python/tvm/driver/build_module.py | 79 ++++------
python/tvm/ir/__init__.py | 2 +-
python/tvm/ir/function.py | 8 +
python/tvm/ir/module.py | 1 -
python/tvm/tir/transform/function_pass.py | 3 +-
python/tvm/tir/transform/transform.py | 64 ++++++++
src/driver/driver_api.cc | 109 +++++++------
src/printer/relay_text_printer.cc | 10 +-
src/target/codegen.cc | 5 +-
src/target/llvm/codegen_cpu.cc | 1 +
src/tir/ir/transform.cc | 7 +-
src/tir/pass/ffi_api.cc | 5 -
src/tir/pass/make_api.cc | 63 --------
src/tir/transforms/bind_device_type.cc | 112 ++++++++++++++
src/tir/transforms/lower_thread_allreduce.cc | 12 --
src/tir/{pass => transforms}/split_host_device.cc | 169 ++++++++++++++-------
src/tir/transforms/tensorcore_infer_fragment.cc | 15 +-
src/tir/transforms/thread_storage_sync.cc | 7 -
..._host_device.py => test_tir_analysis_usedef.py} | 2 +-
.../unittest/test_tir_pass_inject_double_buffer.py | 4 +-
.../unittest/test_tir_pass_storage_flatten.py | 5 +-
.../test_tir_transform_lower_warp_memory.py | 15 +-
.../unittest/test_tir_transform_thread_sync.py | 10 +-
28 files changed, 465 insertions(+), 362 deletions(-)
diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h
index ecf7c19..dc7a2b2 100644
--- a/include/tvm/ir/function.h
+++ b/include/tvm/ir/function.h
@@ -47,19 +47,19 @@ enum class CallingConv : int {
*/
kDefault = 0,
/*!
+ * \brief PackedFunc that exposes a CPackedFunc signature.
+ *
+ * - Calling by PackedFunc calling convention.
+ * - Implementation: Expose a function with the CPackedFunc signature.
+ */
+ kCPackedFunc = 1,
+ /*!
* \brief Device kernel launch
*
* - Call by PackedFunc calling convention.
* - Implementation: defined by device runtime(e.g. runtime/cuda)
*/
kDeviceKernelLaunch = 2,
- /*!
- * \brief PackedFunc that exposes a CPackedFunc signature.
- *
- * - Calling by PackedFunc calling convention.
- * - Implementation: Expose a function with the CPackedFunc signature.
- */
- kCPackedFunc = 3,
};
/*!
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index f6ea918..f63bf96 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -324,6 +324,8 @@ class IRModule : public ObjectRef {
/*! \brief Declare the container type. */
using ContainerType = IRModuleNode;
+ // allow copy on write.
+ TVM_DEFINE_OBJECT_REF_COW_METHOD(IRModuleNode);
};
/*!
diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index 6bab44e..fe74a96 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -49,6 +49,16 @@ struct ExprDeepEqual {
public:
TVM_DLL bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const;
};
+
+
+/*!
+ * \brief Find undefined vars in the statement.
+ * \param stmt The statement to be checked.
+ * \param defs The vars that are defined.
+ * \return Array of undefined vars.
+ */
+Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
+
} // namespace tir
} // namespace tvm
#endif // TVM_TIR_ANALYSIS_H_
diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h
index 6a1a178..8ba008b 100644
--- a/include/tvm/tir/ir_pass.h
+++ b/include/tvm/tir/ir_pass.h
@@ -407,56 +407,6 @@ LoweredFunc MakeAPI(Stmt body,
bool is_restricted);
/*!
- * \brief Bind the device type of host function to be device_type.
- * \param func The function to be binded.
- * \param device_type The device type to be binded.
- * \return The binded function.
- */
-LoweredFunc BindDeviceType(LoweredFunc func,
- int device_type);
-/*!
- * \brief Find undefined vars in the statment.
- * \param stmt The function to be checked.
- * \param defs The vars that is defined.
- * \return Array of undefined vars.
- */
-Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& defs);
-
-/*!
- * \brief Split the function into a host function and device functions.
- * \param func The function to be splitted.
- *
- * \return Array of functions, the first one is host function,
- * the others are device functions.
- */
-Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
-
-/*!
- * \brief Insert sync between parallel read/write of shared buffers.
- *
- * \param stmt The stmt to be trasnformed.
- * \param storage_scope The storage scope considered.
- */
-LoweredFunc ThreadSync(LoweredFunc stmt, std::string storage_scope);
-
-/*!
- * \brief Lower cross thread alleduce in the stmt.
- * \param f The device function to be lowered.
- * \param warp_size the size of warp where no sync is needed.
- * \return Transformed function.
- */
-LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);
-
-/*!
- * \brief Lower warp memory in stmt.
- * \param f The device function to be lowered.
- * \param warp_size the size of warp where no sync is needed.
- * this function will only take in effect if warp_size is bigger than
one.
- * \return Transformed function.
- */
-LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size);
-
-/*!
* \brief Remap the thread axis
*
* This can be used to get equivalent program which uses
@@ -471,26 +421,6 @@ LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size);
LoweredFunc RemapThreadAxis(LoweredFunc f, Map<PrimExpr, IterVar> axis_map);
/*!
- * \brief Lower packed function call.
- * \param f The function to be lowered.
- * \return Transformed function.
- */
-LoweredFunc LowerTVMBuiltin(LoweredFunc f);
-
-
-/*!
- * \brief Rewrite the pointer content type of arguments,
- * as well as Alloc internal to the function to use
- * the most frequently accessed type for load/store
- * to avoid pointer casting in backend when possible.
- *
- * \note implemeneted in storage_rewrite.cc
- * \param f The function to be trasnformed
- * \return Transformed function.
- */
-LoweredFunc PointerValueTypeRewrite(LoweredFunc f);
-
-/*!
* \brief Rewrite the pointer content type of arguments,
* as well as Alloc internal to the function to use
* the most frequently accessed type for load/store
@@ -514,14 +444,6 @@ PrimFunc PointerValueTypeRewrite(PrimFunc f);
LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target);
/*!
- * \brief Infer the TensorCore fragment infomation using tensor intrinsics
- *
- * \param f The device function to be lowered.
- * \return Transformed function.
- */
-LoweredFunc InferFragment(LoweredFunc f);
-
-/*!
* \brief Verify if memory accesses are legal for a specific target device
type.
*
* In the case that tgt is cuda, if not all workload is bound with
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index d809e07..211e344 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -59,6 +59,21 @@ TVM_DLL Pass CreatePrimFuncPass(const
runtime::TypedPackedFunc<
const tvm::Array<tvm::PrimExpr>& required);
/*!
+ * \brief Bind the device type of the function to be
+ * the device_type specified in the target attribute.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass BindDeviceType();
+
+/*!
+ * \brief Split the function into a host function and device functions.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass SplitHostDevice();
+
+/*!
* \brief skip assert stmt.
*
* \return The pass.
diff --git a/python/tvm/driver/build_module.py
b/python/tvm/driver/build_module.py
index 7eda40d..e4bd200 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
+# pylint: disable=invalid-name
"""The build utils in python.
This module provides the functions to transform schedule to
@@ -25,6 +27,7 @@ import tvm.tir
from tvm.runtime import ndarray
from tvm.ir import container
+from tvm.ir import CallingConv
from tvm.target import codegen, BuildConfig
from tvm.tir import ir_pass
from tvm.tir.stmt import LoweredFunc
@@ -222,75 +225,59 @@ def _build_for_device(flist, target, target_host):
mdev : tvm.module
A module that contains device code.
"""
- @tvm.tir.transform.prim_func_pass(opt_level=0)
- class BindTarget:
- def __init__(self, target):
- self.target = target
-
- # pylint: disable=unused-argument
- def transform_function(self, func, mod, ctx):
- return func.with_attr("target", self.target)
-
target = _target.create(target)
+ target_host = _target.create(target_host)
device_type = ndarray.context(target.target_name, 0).device_type
- fhost = []
- fdevice = []
+
for func in flist:
if not ir_pass.VerifyMemory(func, device_type):
raise ValueError(
"Direct host side access to device memory is detected in %s. "
"Did you forget to bind?" % func.name)
- if func.func_type == LoweredFunc.MixedFunc:
- if BuildConfig.current().detect_global_barrier:
- func = ir_pass.ThreadSync(func, "global")
- func = ir_pass.ThreadSync(func, "shared")
- func = ir_pass.ThreadSync(func, "warp")
- func = ir_pass.InferFragment(func)
- warp_size = target.thread_warp_size
- func = ir_pass.LowerThreadAllreduce(func, warp_size)
- fsplits = list(ir_pass.SplitHostDevice(func))
- fhost.append(fsplits[0])
- for x in fsplits[1:]:
- fdevice.append(x)
- elif func.func_type == LoweredFunc.HostFunc:
- fhost.append(func)
- elif func.func_type == LoweredFunc.DeviceFunc:
- fdevice.append(func)
- else:
- raise ValueError("unknown function type %d" % func.func_type)
-
- if "gpu" in target.keys and not fdevice:
- warnings.warn(
- "Specified target %s, but cannot find device code, did you do "
- "bind?" % target)
- fhost = [ir_pass.BindDeviceType(x, device_type) for x in fhost]
+ mod_mixed = tvm.testing.LoweredFuncsToIRModule(flist)
+ opt_mixed = [tvm.tir.transform.Apply(lambda f: f.with_attr("target",
target))]
+ if BuildConfig.current().detect_global_barrier:
+ opt_mixed += [tvm.tir.transform.ThreadSync("global")]
+ opt_mixed += [tvm.tir.transform.ThreadSync("shared"),
+ tvm.tir.transform.ThreadSync("warp"),
+ tvm.tir.transform.InferFragment(),
+ tvm.tir.transform.LowerThreadAllreduce(),
+ tvm.tir.transform.BindDeviceType(),
+ tvm.tir.transform.SplitHostDevice()]
+ mod_mixed = tvm.ir.transform.Sequential(opt_mixed)(mod_mixed)
- if device_type == ndarray.cpu(0).device_type and target_host == target:
- assert not fdevice
-
- target_host = _target.create(target_host)
# device optimizations
- mod_dev = tvm.testing.LoweredFuncsToIRModule(fdevice)
opt_device = tvm.ir.transform.Sequential(
- [BindTarget(target),
+ [tvm.tir.transform.Filter(
+ lambda f: "calling_conv" in f.attrs and
+ f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH),
tvm.tir.transform.LowerWarpMemory(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerIntrin()])
- mod_dev = opt_device(mod_dev)
+ mod_dev = opt_device(mod_mixed)
# host optimizations
- mod_host = tvm.testing.LoweredFuncsToIRModule(fhost)
opt_host = tvm.ir.transform.Sequential(
- [BindTarget(target_host),
+ [tvm.tir.transform.Filter(
+ lambda f: "calling_conv" not in f.attrs or
+ f.attrs["calling_conv"].value != CallingConv.DEVICE_KERNEL_LAUNCH),
+ tvm.tir.transform.Apply(lambda f: f.with_attr("target", target)),
tvm.tir.transform.LowerTVMBuiltin(),
tvm.tir.transform.LowerDeviceStorageAccessInfo(),
tvm.tir.transform.LowerIntrin(),
tvm.tir.transform.CombineContextCall()])
- mod_host = opt_host(mod_host)
+ mod_host = opt_host(mod_mixed)
+
+ if device_type == ndarray.cpu(0).device_type and target_host == target:
+ assert len(mod_dev.functions) == 0
+ if "gpu" in target.keys and len(mod_dev.functions) == 0:
+ warnings.warn(
+ "Specified target %s, but cannot find device code, did you do "
+ "bind?" % target)
- rt_mod_dev = codegen.build_module(mod_dev, target) if fdevice else None
+ rt_mod_dev = codegen.build_module(mod_dev, target) if len(mod_dev.functions) != 0 else None
return mod_host, rt_mod_dev
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index b3efd6b..1aabf3e 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -23,7 +23,7 @@ from .type import TypeConstraint, FuncType, IncompleteType,
RelayRefType
from .tensor_type import TensorType
from .type_relation import TypeCall, TypeRelation
from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range
-from .function import BaseFunc
+from .function import CallingConv, BaseFunc
from .adt import Constructor, TypeData
from .module import IRModule
from .attrs import Attrs, DictAttrs, make_node
diff --git a/python/tvm/ir/function.py b/python/tvm/ir/function.py
index 70eb51a..afc8c10 100644
--- a/python/tvm/ir/function.py
+++ b/python/tvm/ir/function.py
@@ -15,10 +15,18 @@
# specific language governing permissions and limitations
# under the License.
"""Function defintiions."""
+from enum import IntEnum
from .expr import RelayExpr
from . import _ffi_api
+class CallingConv(IntEnum):
+ """Possible kinds of calling conventions."""
+ DEFAULT = 0
+ C_PACKED_FUNC = 1
+ DEVICE_KERNEL_LAUNCH = 2
+
+
class BaseFunc(RelayExpr):
"""Base class of all functions."""
@property
diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py
index 24f5211..8d75d8e 100644
--- a/python/tvm/ir/module.py
+++ b/python/tvm/ir/module.py
@@ -60,7 +60,6 @@ class IRModule(Node):
type_definitions = mapped_type_defs
self.__init_handle_by_constructor__(_ffi_api.IRModule, functions,
type_definitions)
-
def __setitem__(self, var, val):
"""Add a mapping to the module.
diff --git a/python/tvm/tir/transform/function_pass.py
b/python/tvm/tir/transform/function_pass.py
index 93bb996..a19cc2f 100644
--- a/python/tvm/tir/transform/function_pass.py
+++ b/python/tvm/tir/transform/function_pass.py
@@ -16,6 +16,7 @@
# under the License.
"""TIR specific function pass support."""
import inspect
+import types
import functools
import tvm._ffi
@@ -142,7 +143,7 @@ def prim_func_pass(pass_func=None, opt_level=None,
name=None, required=None):
return _wrap_class_function_pass(pass_arg, info)
if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")
- return _ffi_api.MakeFunctionPass(pass_arg, info)
+ return _ffi_api.CreatePrimFuncPass(pass_arg, info)
if pass_func:
return create_function_pass(pass_func)
diff --git a/python/tvm/tir/transform/transform.py
b/python/tvm/tir/transform/transform.py
index 6be4a38..c823c1a 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -17,6 +17,70 @@
"""Wrapping existing transformations."""
# pylint: disable=invalid-name
from . import _ffi_api
+from . import function_pass as _fpass
+
+
+def Apply(ftransform):
+ """Apply ftransform to each function in the Module.
+
+ This function is a thin wrapper around tvm.tir.transform.prim_func_pass
+
+ Parameters
+ ----------
+ ftransform: tvm.tir.PrimFunc -> tvm.tir.PrimFunc
+ The transformation pass.
+
+ Returns
+ -------
+ fpass : tvm.ir.transform.Pass
+ The result pass
+ """
+ # pylint: disable=unused-argument
+ def _transform(func, mod, ctx):
+ return ftransform(func)
+ return _fpass.prim_func_pass(_transform, opt_level=0)
+
+
+def Filter(fcond):
+ """Filter functions by the calling convention attribute.
+
+ Parameters
+ ----------
+ fcond : tvm.tir.PrimFunc -> bool
+ The condition of the filtering.
+
+ Returns
+ -------
+ fpass : tvm.ir.transform.Pass
+ The result pass
+ """
+ # pylint: disable=unused-argument
+ def _transform(func, mod, ctx):
+ return func if fcond(func) else None
+ return _fpass.prim_func_pass(_transform, opt_level=0)
+
+
+def BindDeviceType():
+ """Bind the device type of the function to be
+ the device_type specified in the target attribute.
+
+ Returns
+ -------
+ fpass : tvm.ir.transform.Pass
+ The result pass
+ """
+ return _ffi_api.BindDeviceType()
+
+
+def SplitHostDevice():
+ """Split the function into a host function and device functions.
+
+ Returns
+ -------
+ fpass : tvm.ir.transform.Pass
+ The result pass
+ """
+ return _ffi_api.SplitHostDevice()
def SkipAssert():
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index f59e764..d54d6f8 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -185,75 +185,50 @@ transform::Pass BindTarget(Target target) {
}
+template<typename FCond>
+transform::Pass FilterBy(FCond fcond) {
+ auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext
ctx) {
+ if (fcond(f)) {
+ return f;
+ } else {
+ return tir::PrimFunc(nullptr);
+ }
+ };
+ return tir::transform::CreatePrimFuncPass(fpass, 0, "FilterBy", {});
+}
+
+
std::pair<IRModule, IRModule>
split_dev_host_funcs(const Array<LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config) {
- std::unordered_set<std::string> all_names;
- for (const auto& x : funcs) {
- CHECK(all_names.count(x->name) == 0)
- << "Duplicate function name " << x->name;
- all_names.insert(x->name);
- }
-
- Array<LoweredFunc> fhost;
- Array<LoweredFunc> fdevice;
-
for (const auto& x : funcs) {
CHECK(tir::VerifyMemory(x, target->device_type))
<< "Direct host side access to device memory is detected in "
<< x->func_name() << ". Did you forget to bind?";
-
- if (x->func_type == tir::kMixedFunc) {
- auto func = x;
- if (config->detect_global_barrier) {
- func = tir::ThreadSync(func, "global");
- }
-
- func = tir::ThreadSync(func, "shared");
- func = tir::ThreadSync(func, "warp");
- func = tir::InferFragment(func);
- func = tir::LowerThreadAllreduce(func, target->thread_warp_size);
- auto fsplits = tir::SplitHostDevice(func);
- fhost.push_back(fsplits[0]);
- for (auto f = fsplits.begin() + 1; f != fsplits.end(); ++f) {
- fdevice.push_back(*f);
- }
- } else if (x->func_type == tir::kHostFunc) {
- fhost.push_back(x);
- } else if (x->func_type == tir::kDeviceFunc) {
- fdevice.push_back(x);
- } else {
- LOG(FATAL) << "unknown function type " << x->func_type;
- }
- }
-
- auto keys = target->keys();
- bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") !=
keys.end();
- if (target_is_gpu && fdevice.size() == 0) {
- LOG(WARNING) << "Specified target "
- << target->str()
- << " but cannot find device code. Did you forget to bind?";
}
+ IRModule mod_mixed = codegen::ToIRModule(funcs);
- if (target->device_type == target::llvm()->device_type &&
- target_host == target) {
- CHECK(fdevice.empty()) << "No device code should be generated when target "
- << "and host_target are both llvm target."
- << "\n";
- }
-
- for (size_t i = 0; i < fhost.size(); ++i) {
- auto func = fhost[i];
- func = tir::BindDeviceType(func, target->device_type);
- fhost.Set(i, func);
+ Array<tvm::transform::Pass> mixed_pass_list = {BindTarget(target)};
+ if (config->detect_global_barrier) {
+ mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
}
+ mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
+ mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
+ mixed_pass_list.push_back(tir::transform::InferFragment());
+ mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
+ mixed_pass_list.push_back(tir::transform::BindDeviceType());
+ mixed_pass_list.push_back(tir::transform::SplitHostDevice());
+ auto opt_mixed = transform::Sequential(mixed_pass_list);
+ mod_mixed = opt_mixed(std::move(mod_mixed));
- // host pipeline
- auto mhost = codegen::ToIRModule(fhost);
auto host_pass_list = {
+ FilterBy([](const tir::PrimFunc& f) {
+ int64_t value = f->GetAttr<Integer>(tvm::attr::kCallingConv, 0)->value;
+ return value != static_cast<int>(CallingConv::kDeviceKernelLaunch);
+ }),
BindTarget(target_host),
tir::transform::LowerTVMBuiltin(),
tir::transform::LowerIntrin(),
@@ -261,18 +236,38 @@ split_dev_host_funcs(const Array<LoweredFunc>& funcs,
tir::transform::CombineContextCall(),
};
auto opt_host = transform::Sequential(host_pass_list);
- mhost = opt_host(mhost);
+ auto mhost = opt_host(mod_mixed);
// device pipeline
- auto mdevice = codegen::ToIRModule(fdevice);
auto device_pass_list = {
+ FilterBy([](const tir::PrimFunc& f) {
+ int64_t value = f->GetAttr<Integer>(tvm::attr::kCallingConv, 0)->value;
+ return value == static_cast<int>(CallingConv::kDeviceKernelLaunch);
+ }),
BindTarget(target),
tir::transform::LowerWarpMemory(),
tir::transform::LowerIntrin(),
tir::transform::LowerDeviceStorageAccessInfo(),
};
auto opt_device = transform::Sequential(device_pass_list);
- mdevice = opt_device(mdevice);
+ auto mdevice = opt_device(mod_mixed);
+
+ // some final misc checks.
+ auto keys = target->keys();
+ bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") !=
keys.end();
+ if (target_is_gpu && mdevice->functions.size() == 0) {
+ LOG(WARNING) << "Specified target "
+ << target->str()
+ << " but cannot find device code. Did you forget to bind?";
+ }
+
+ if (target->device_type == target::llvm()->device_type &&
+ target_host == target) {
+ CHECK(mdevice->functions.empty())
+ << "No device code should be generated when target "
+ << "and host_target are both llvm target."
+ << "\n";
+ }
return {mhost, mdevice};
}
diff --git a/src/printer/relay_text_printer.cc
b/src/printer/relay_text_printer.cc
index 56e77b7..bda997a 100644
--- a/src/printer/relay_text_printer.cc
+++ b/src/printer/relay_text_printer.cc
@@ -34,6 +34,7 @@
*/
#include <tvm/ir/type_functor.h>
#include <tvm/ir/module.h>
+#include <tvm/tir/function.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include "doc.h"
@@ -434,6 +435,10 @@ class RelayTextPrinter :
Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func) {
if (auto* n = base_func.as<relay::FunctionNode>()) {
return PrintFunc(prefix, GetRef<relay::Function>(n));
+ } else if (auto* n = base_func.as<tir::PrimFuncNode>()) {
+ std::ostringstream os;
+ os << GetRef<tir::PrimFunc>(n);
+ return Doc::RawText(os.str());
} else {
// def @xyz = meta['ExternalFunc'][id]
Doc doc;
@@ -455,8 +460,9 @@ class RelayTextPrinter :
}
// functions
for (const auto& kv : mod->functions) {
- dg_ = DependencyGraph::Create(&arena_, kv.second);
-
+ if (kv.second.as<relay::FunctionNode>()) {
+ dg_ = DependencyGraph::Create(&arena_, kv.second);
+ }
if (counter++ != 0) {
doc << Doc::NewLine();
}
diff --git a/src/target/codegen.cc b/src/target/codegen.cc
index a977d35..703328f 100644
--- a/src/target/codegen.cc
+++ b/src/target/codegen.cc
@@ -50,9 +50,10 @@ tir::PrimFunc ToPrimFunc(tir::LoweredFunc from) {
Map<tir::Var, PrimExpr> remap_vars;
for (auto var : from->args) {
- if (from->handle_data_type.count(var)) {
+ auto it = from->handle_data_type.find(var);
+ if (it != from->handle_data_type.end()) {
tir::Var new_var(var->name_hint,
- PointerType(PrimType(var->dtype)));
+ PointerType(PrimType((*it).second->dtype)));
args.push_back(new_var);
remap_vars.Set(var, new_var);
} else {
diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index 70bcfe8..33a3e17 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -24,6 +24,7 @@
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/analysis.h>
#include <memory>
#include <unordered_map>
#include "codegen_cpu.h"
diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc
index f991e90..773c67d 100644
--- a/src/tir/ir/transform.cc
+++ b/src/tir/ir/transform.cc
@@ -108,8 +108,13 @@ IRModule PrimFuncPassNode::operator()(const IRModule& mod,
updates.push_back({it.first, updated_func});
}
}
+ // automatic removal of None
for (const auto& pair : updates) {
- updated_mod->Add(pair.first, pair.second, true);
+ if (pair.second.defined()) {
+ updated_mod->Add(pair.first, pair.second, true);
+ } else {
+ updated_mod->Remove(pair.first);
+ }
}
pass_ctx.Trace(updated_mod, pass_info, false);
return updated_mod;
diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc
index ff821fe..83db1a9 100644
--- a/src/tir/pass/ffi_api.cc
+++ b/src/tir/pass/ffi_api.cc
@@ -128,10 +128,7 @@ REGISTER_PASS(VectorizeLoop);
REGISTER_PASS(SkipVectorize);
REGISTER_PASS(UnrollLoop);
REGISTER_PASS(InjectCopyIntrin);
-REGISTER_PASS(ThreadSync);
REGISTER_PASS(MakeAPI);
-REGISTER_PASS(BindDeviceType);
-REGISTER_PASS(SplitHostDevice);
REGISTER_PASS(StorageRewrite);
REGISTER_PASS(CoProcSync);
REGISTER_PASS(LowerStorageAccessInfo);
@@ -141,7 +138,6 @@ REGISTER_PASS(InjectDoubleBuffer);
REGISTER_PASS(LoopPartition);
REGISTER_PASS(RemoveNoOp);
REGISTER_PASS(LiftAttrScope);
-REGISTER_PASS(LowerThreadAllreduce);
REGISTER_PASS(RemapThreadAxis);
REGISTER_PASS(LowerCustomDatatypes);
REGISTER_PASS(VerifyMemory);
@@ -150,7 +146,6 @@ REGISTER_PASS(DecorateDeviceScope);
REGISTER_PASS(InstrumentBoundCheckers);
REGISTER_PASS(VerifyCompactBuffer);
REGISTER_PASS(HoistIfThenElse);
-REGISTER_PASS(InferFragment)
REGISTER_PASS(NarrowDataType);
} // namespace tir
} // namespace tvm
diff --git a/src/tir/pass/make_api.cc b/src/tir/pass/make_api.cc
index f8eae64..861cd43 100644
--- a/src/tir/pass/make_api.cc
+++ b/src/tir/pass/make_api.cc
@@ -218,69 +218,6 @@ LoweredFunc MakeAPI(Stmt body,
return f;
}
-class DeviceTypeBinder: public StmtExprMutator {
- public:
- explicit DeviceTypeBinder(int device_type)
- : device_type_(device_type) {}
-
- Stmt VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == attr::device_context_type) {
- if (const VarNode* var = op->value.as<VarNode>()) {
- var_ = var;
- PrimExpr value = make_const(op->value.dtype(), device_type_);
- Stmt body = StmtExprMutator::VisitStmt_(op);
- var_ = nullptr;
- std::ostringstream os;
- os << "device_type need to be " << device_type_;
- return AssertStmtNode::make(op->value == value, os.str(), body);
- }
- }
- return StmtExprMutator::VisitStmt_(op);
- }
-
- Stmt VisitStmt_(const IfThenElseNode* op) final {
- // eager simplify if guard.
- Stmt res = StmtExprMutator::VisitStmt_(op);
- op = res.as<IfThenElseNode>();
- if (is_zero(op->condition)) {
- if (op->else_case.defined()) return op->else_case;
- return EvaluateNode::make(0);
- }
- if (is_one(op->condition)) {
- return op->then_case;
- }
- return res;
- }
-
- PrimExpr VisitExpr_(const NENode* op) final {
- // eager check NE for device check
- PrimExpr res = StmtExprMutator::VisitExpr_(op);
- op = res.as<NENode>();
- if (tir::ExprDeepEqual()(op->a, op->b)) {
- return make_const(op->dtype, false);
- }
- return res;
- }
-
- PrimExpr VisitExpr_(const VarNode* op) final {
- if (op == var_) {
- return make_const(op->dtype, device_type_);
- } else {
- return GetRef<PrimExpr>(op);
- }
- }
-
- public:
- const VarNode* var_{nullptr};
- int device_type_;
-};
-
-LoweredFunc BindDeviceType(LoweredFunc f,
- int device_type) {
- auto n = make_object<LoweredFuncNode>(*f.operator->());
- n->body = DeviceTypeBinder(device_type)(n->body);
- return LoweredFunc(n);
-}
} // namespace tir
} // namespace tvm
diff --git a/src/tir/transforms/bind_device_type.cc
b/src/tir/transforms/bind_device_type.cc
new file mode 100644
index 0000000..486f21c
--- /dev/null
+++ b/src/tir/transforms/bind_device_type.cc
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file bind_device_type.cc
+ * \brief Bind the device type according to the target field.
+ */
+#include <tvm/ir/transform.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/transform.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/target/target.h>
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace tir {
+
+class DeviceTypeBinder: public StmtExprMutator {
+ public:
+ explicit DeviceTypeBinder(int device_type)
+ : device_type_(device_type) {}
+
+ Stmt VisitStmt_(const AttrStmtNode* op) final {
+ if (op->attr_key == attr::device_context_type) {
+ if (const VarNode* var = op->value.as<VarNode>()) {
+ var_ = var;
+ PrimExpr value = make_const(op->value.dtype(), device_type_);
+ Stmt body = StmtExprMutator::VisitStmt_(op);
+ var_ = nullptr;
+ std::ostringstream os;
+ os << "device_type need to be " << device_type_;
+ return AssertStmtNode::make(op->value == value, os.str(), body);
+ }
+ }
+ return StmtExprMutator::VisitStmt_(op);
+ }
+
+ Stmt VisitStmt_(const IfThenElseNode* op) final {
+ // eager simplify if guard.
+ Stmt res = StmtExprMutator::VisitStmt_(op);
+ op = res.as<IfThenElseNode>();
+ if (is_zero(op->condition)) {
+ if (op->else_case.defined()) return op->else_case;
+ return EvaluateNode::make(0);
+ }
+ if (is_one(op->condition)) {
+ return op->then_case;
+ }
+ return res;
+ }
+
+ PrimExpr VisitExpr_(const NENode* op) final {
+ // eager check NE for device check
+ PrimExpr res = StmtExprMutator::VisitExpr_(op);
+ op = res.as<NENode>();
+ if (tir::ExprDeepEqual()(op->a, op->b)) {
+ return make_const(op->dtype, false);
+ }
+ return res;
+ }
+
+ PrimExpr VisitExpr_(const VarNode* op) final {
+ if (op == var_) {
+ return make_const(op->dtype, device_type_);
+ } else {
+ return GetRef<PrimExpr>(op);
+ }
+ }
+
+ public:
+ const VarNode* var_{nullptr};
+ int device_type_;
+};
+
+namespace transform {
+
+Pass BindDeviceType() {
+ auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+ auto* n = f.CopyOnWrite();
+ auto target = f->GetAttr<Target>(tvm::attr::kTarget);
+ CHECK(target.defined())
+ << "BindDeviceType: Require the target attribute";
+ n->body = DeviceTypeBinder(target->device_type)(std::move(n->body));
+ return f;
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.BindDeviceType", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.BindDeviceType")
+.set_body_typed(BindDeviceType);
+
+} // namespace transform
+} // namespace tir
+} // namespace tvm
diff --git a/src/tir/transforms/lower_thread_allreduce.cc
b/src/tir/transforms/lower_thread_allreduce.cc
index e7e89f8..c4df2dc 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -340,14 +340,6 @@ class ThreadAllreduceBuilder final : public
StmtExprMutator {
std::unordered_map<const VarNode *, Stmt> alloc_remap_;
};
-LoweredFunc
-LowerThreadAllreduce(LoweredFunc f, int warp_size) {
- CHECK_NE(f->func_type, kHostFunc);
- auto n = make_object<LoweredFuncNode>(*f.operator->());
- n->body = ThreadAllreduceBuilder(warp_size)(n->body);
- return LoweredFunc(n);
-}
-
namespace transform {
Pass LowerThreadAllreduce() {
@@ -356,10 +348,6 @@ Pass LowerThreadAllreduce() {
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined())
<< "LowerThreadAllreduce: Require the target attribute";
- auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
- CHECK(calling_conv.defined() &&
- calling_conv->value ==
static_cast<int>(CallingConv::kDeviceKernelLaunch))
- << "LowerThreadAllreeduce: expect calling_conv equals
CallingConv::kDeviceKernelLaunch";
n->body = ThreadAllreduceBuilder(target->thread_warp_size)(n->body);
return f;
};
diff --git a/src/tir/pass/split_host_device.cc
b/src/tir/transforms/split_host_device.cc
similarity index 61%
rename from src/tir/pass/split_host_device.cc
rename to src/tir/transforms/split_host_device.cc
index 519101f..838ad82 100644
--- a/src/tir/pass/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -21,18 +21,22 @@
* \file split_host_device.cc
* \brief Split device function from host.
*/
+#include <tvm/ir/transform.h>
#include <tvm/tir/expr.h>
-#include <tvm/tir/lowered_func.h>
+#include <tvm/tir/transform.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/runtime/module.h>
+#include <tvm/target/target.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/runtime/container.h>
+
#include <unordered_map>
namespace tvm {
namespace tir {
// use/def analysis, also delete unreferenced lets
-class IRUseDefAnalysis : public StmtExprMutator {
+class VarUseDefAnalysis : public StmtExprMutator {
public:
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
@@ -156,8 +160,27 @@ class IRUseDefAnalysis : public StmtExprMutator {
std::unordered_map<const VarNode*, int> def_count_;
};
+
+Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
+ VarUseDefAnalysis m;
+ for (Var arg : args) {
+ m.use_count_[arg.get()] = 0;
+ }
+ m(stmt);
+ return m.undefined_;
+}
+
+
class HostDeviceSplitter : public StmtMutator {
public:
+ explicit HostDeviceSplitter(IRModuleNode* device_mod,
+ Target device_target,
+ std::string name_prefix)
+ : device_mod_(device_mod),
+ device_target_(device_target),
+ name_prefix_(name_prefix) {
+ }
+
Stmt VisitStmt_(const AllocateNode* op) final {
handle_data_type_[op->buffer_var.get()] = make_const(op->dtype, 0);
return StmtMutator::VisitStmt_(op);
@@ -172,86 +195,128 @@ class HostDeviceSplitter : public StmtMutator {
return StmtMutator::VisitStmt_(op);
}
- Array<LoweredFunc> Split(LoweredFunc f) {
- CHECK_EQ(f->func_type, kMixedFunc);
- for (auto kv : f->handle_data_type) {
- handle_data_type_[kv.first.get()] = kv.second;
- }
- name_ = f->name;
- ObjectPtr<LoweredFuncNode> n =
- make_object<LoweredFuncNode>(*f.operator->());
- n->body = operator()(f->body);
- n->func_type = kHostFunc;
- Array<LoweredFunc> ret{LoweredFunc(n)};
- for (LoweredFunc x : device_funcs_) {
- ret.push_back(x);
- }
- return ret;
- }
-
private:
Stmt SplitDeviceFunc(Stmt body) {
std::ostringstream os;
- os << name_ << "_kernel" << device_funcs_.size();
- ObjectPtr<LoweredFuncNode> n = make_object<LoweredFuncNode>();
+ os << name_prefix_ << "_kernel" << device_func_counter_++;
+ std::string kernel_symbol = os.str();
// isolate the device function.
- IRUseDefAnalysis m;
+ VarUseDefAnalysis m;
m.visit_thread_extent_ = false;
- n->body = m(std::move(body));
- n->name = os.str();
- n->func_type = kDeviceFunc;
- n->thread_axis = m.thread_axis_;
+ body = m(std::move(body));
+
+ Array<Var> params;
+ Array<PrimExpr> arguments;
+ Map<tir::Var, PrimExpr> remap_vars;
+
// Strictly order the arguments: Var pointers, positional arguments.
- for (Var v : m.undefined_) {
- if (v.dtype().is_handle()) {
- n->args.push_back(v);
- // mark handle data type.
- auto it = handle_data_type_.find(v.get());
+ for (Var var : m.undefined_) {
+ if (var.dtype().is_handle()) {
+ // Create a new version of var with an explicit pointer type.
+ auto it = handle_data_type_.find(var.get());
if (it != handle_data_type_.end()) {
- n->handle_data_type.Set(v, it->second);
+ tir::Var new_var(var->name_hint,
+ PointerType(PrimType((*it).second->dtype)));
+ params.push_back(new_var);
+ remap_vars.Set(var, new_var);
+ } else {
+ params.push_back(var);
}
+ arguments.push_back(var);
}
}
- for (Var v : m.undefined_) {
- if (!v.dtype().is_handle()) {
- n->args.push_back(v);
+ // positional arguments
+ for (Var var : m.undefined_) {
+ if (!var.dtype().is_handle()) {
+ params.push_back(var);
+ arguments.push_back(var);
}
}
- LoweredFunc f_device(n);
+ PrimFunc device_func(params, Substitute(body, remap_vars));
+ device_func = WithAttr(std::move(device_func),
tir::attr::kDeviceThreadAxis, m.thread_axis_);
+ device_func = WithAttr(std::move(device_func), tvm::attr::kCallingConv,
+ Integer(CallingConv::kDeviceKernelLaunch));
+ device_func = WithAttr(std::move(device_func), tvm::attr::kGlobalSymbol,
+ runtime::String(kernel_symbol));
+ device_func = WithAttr(std::move(device_func), tir::attr::kNoAlias,
Integer(1));
+ device_func = WithAttr(std::move(device_func), tvm::attr::kTarget,
device_target_);
+ device_mod_->Add(GlobalVar(kernel_symbol), device_func);
+
+ // generate calls to the device function
Array<PrimExpr> call_args;
- call_args.push_back(StringImmNode::make(f_device->name));
- for (Var arg : n->args) {
+ call_args.push_back(StringImmNode::make(kernel_symbol));
+ for (PrimExpr arg : arguments) {
call_args.push_back(arg);
}
for (PrimExpr ext : m.thread_extent_) {
call_args.push_back(ext);
}
- device_funcs_.emplace_back(f_device);
return EvaluateNode::make(CallNode::make(
DataType::Int(32), intrinsic::tvm_call_packed,
call_args, CallNode::Intrinsic));
}
- // function name
- std::string name_;
- // the device functions
+ // The IRModule that collects the split-out device kernels.
+ IRModuleNode* device_mod_;
+ // Device target
+ Target device_target_;
+ // function name hint
+ std::string name_prefix_;
+ // Number of device functions.
+ int device_func_counter_{0};
std::vector<LoweredFunc> device_funcs_;
std::unordered_map<const VarNode*, PrimExpr> handle_data_type_;
};
-Array<Var> UndefinedVars(const Stmt& stmt, const Array<Var>& args) {
- IRUseDefAnalysis m;
- for (Var arg : args) {
- m.use_count_[arg.get()] = 0;
- }
- m(stmt);
- return m.undefined_;
+PrimFunc SplitHostDevice(PrimFunc&& func, IRModuleNode* device_mod) {
+ auto target = func->GetAttr<Target>(tvm::attr::kTarget);
+ CHECK(target.defined())
+ << "SplitHostDevice: Require the target attribute";
+ auto global_symbol =
func->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+ CHECK(global_symbol.defined())
+ << "SplitHostDevice: Expect PrimFunc to have the global_symbol
attribute";
+
+ HostDeviceSplitter splitter(
+ device_mod, target, static_cast<std::string>(global_symbol));
+
+ auto* n = func.CopyOnWrite();
+ n->body = splitter(std::move(n->body));
+ // set the host target to None.
+ func = WithAttr(std::move(func), tvm::attr::kTarget, Target(nullptr));
+ return std::move(func);
}
-Array<LoweredFunc> SplitHostDevice(LoweredFunc func) {
- return HostDeviceSplitter().Split(func);
+
+
+namespace transform {
+
+Pass SplitHostDevice() {
+ auto pass_func = [](IRModule m, PassContext ctx) {
+ IRModuleNode* mptr = m.CopyOnWrite();
+ std::vector<std::pair<GlobalVar, PrimFunc> > updates;
+
+ for (const auto& kv : mptr->functions) {
+ if (auto* n = kv.second.as<PrimFuncNode>()) {
+ PrimFunc func = GetRef<PrimFunc>(n);
+ auto updated_func = SplitHostDevice(std::move(func), mptr);
+ updates.push_back({kv.first, updated_func});
+ }
+ }
+
+ for (const auto& pair : updates) {
+ mptr->Add(pair.first, pair.second, true);
+ }
+ return m;
+ };
+
+ return tvm::transform::CreateModulePass(
+ pass_func, 0, "tir.SplitHostDevice", {});
}
+TVM_REGISTER_GLOBAL("tir.transform.SplitHostDevice")
+.set_body_typed(SplitHostDevice);
+
+} // namespace transform
} // namespace tir
} // namespace tvm
diff --git a/src/tir/transforms/tensorcore_infer_fragment.cc
b/src/tir/transforms/tensorcore_infer_fragment.cc
index fad4233..1ece078 100644
--- a/src/tir/transforms/tensorcore_infer_fragment.cc
+++ b/src/tir/transforms/tensorcore_infer_fragment.cc
@@ -218,26 +218,19 @@ Stmt InferFragment(Stmt stmt) {
return stmt;
}
-LoweredFunc InferFragment(LoweredFunc f) {
- CHECK_NE(f->func_type, kHostFunc);
- auto n = make_object<LoweredFuncNode>(*f.operator->());
- n->body = InferFragment(f->body);
- return LoweredFunc(n);
-}
-
namespace transform {
-Pass InferFragement() {
+Pass InferFragment() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = InferFragment(std::move(n->body));
return f;
};
- return CreatePrimFuncPass(pass_func, 0, "tir.InferFragement", {});
+ return CreatePrimFuncPass(pass_func, 0, "tir.InferFragment", {});
}
-TVM_REGISTER_GLOBAL("tir.transform.InferFragement")
-.set_body_typed(InferFragement);
+TVM_REGISTER_GLOBAL("tir.transform.InferFragment")
+.set_body_typed(InferFragment);
} // namespace transform
} // namespace tir
diff --git a/src/tir/transforms/thread_storage_sync.cc
b/src/tir/transforms/thread_storage_sync.cc
index b631a62..f464af6 100644
--- a/src/tir/transforms/thread_storage_sync.cc
+++ b/src/tir/transforms/thread_storage_sync.cc
@@ -374,13 +374,6 @@ Stmt ThreadSync(Stmt stmt, std::string storage_scope) {
return ThreadSyncInserter(sync_scope,
planner.syncs_inserted_)(std::move(stmt));
}
-LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) {
- CHECK_NE(f->func_type, kHostFunc);
- auto n = make_object<LoweredFuncNode>(*f.operator->());
- n->body = ThreadSync(f->body, storage_scope);
- return LoweredFunc(n);
-}
-
namespace transform {
Pass ThreadSync(std::string storage_scope) {
diff --git a/tests/python/unittest/test_tir_pass_split_host_device.py
b/tests/python/unittest/test_tir_analysis_usedef.py
similarity index 98%
rename from tests/python/unittest/test_tir_pass_split_host_device.py
rename to tests/python/unittest/test_tir_analysis_usedef.py
index 09f7740..449a462 100644
--- a/tests/python/unittest/test_tir_pass_split_host_device.py
+++ b/tests/python/unittest/test_tir_analysis_usedef.py
@@ -28,7 +28,7 @@ def test_loop_dependent_allocate():
s[AA].compute_at(s[C], s[C].op.axis[0])
# this line should fail due to IRUseDefAnalysis sees an allocate statement
# referencing undefined variable
- tvm.lower(s, [A,C])
+ tvm.lower(s, [A, C])
if __name__ == "__main__":
test_loop_dependent_allocate()
diff --git a/tests/python/unittest/test_tir_pass_inject_double_buffer.py
b/tests/python/unittest/test_tir_pass_inject_double_buffer.py
index 0fe3f61..94e29c6 100644
--- a/tests/python/unittest/test_tir_pass_inject_double_buffer.py
+++ b/tests/python/unittest/test_tir_pass_inject_double_buffer.py
@@ -41,7 +41,9 @@ def test_double_buffer():
assert isinstance(stmt.body.body, tvm.tir.Allocate)
assert stmt.body.body.extents[0].value == 2
f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2,
True)
- f = tvm.tir.ir_pass.ThreadSync(f, "shared")
+ mod = tvm.testing.LoweredFuncsToIRModule([f])
+ f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
+
count = [0]
def count_sync(op):
if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync":
diff --git a/tests/python/unittest/test_tir_pass_storage_flatten.py
b/tests/python/unittest/test_tir_pass_storage_flatten.py
index e8a78cb..dbfcd20 100644
--- a/tests/python/unittest/test_tir_pass_storage_flatten.py
+++ b/tests/python/unittest/test_tir_pass_storage_flatten.py
@@ -93,7 +93,10 @@ def test_flatten_double_buffer():
assert isinstance(stmt.body.body, tvm.tir.Allocate)
assert stmt.body.body.extents[0].value == 2
f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2,
True)
- f = tvm.tir.ir_pass.ThreadSync(f, "shared")
+ f = tvm.tir.ir_pass.MakeAPI(stmt, "db", [A.asobject(), C.asobject()], 2,
True)
+ mod = tvm.testing.LoweredFuncsToIRModule([f])
+ f = tvm.tir.transform.ThreadSync("shared")(mod)["db"]
+
count = [0]
def count_sync(op):
if isinstance(op, tvm.tir.Call) and op.name == "tvm_storage_sync":
diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py
b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
index 66d3cfb..167899a 100644
--- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py
+++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -33,16 +33,15 @@ def test_lower_warp_mem():
xo, xi = s[AA].split(s[AA].op.axis[0], 32)
s[AA].bind(xi, tx)
- f = tvm.lower(s, [A, B])
- fhost, fdevice = tvm.tir.ir_pass.SplitHostDevice(f)
-
- # temp adapter to convert loweredFunc to IRModule
- # to test passes in the new style.
- fname = fdevice.name
- mod = tvm.testing.LoweredFuncsToIRModule([fdevice])
cuda_target = tvm.target.create("cuda")
assert cuda_target.thread_warp_size == 32
- mod = tvm.IRModule.from_expr(mod[fname].with_attr("target", cuda_target))
+ f = tvm.lower(s, [A, B], name="f")
+
+
+ mod = tvm.testing.LoweredFuncsToIRModule([f])
+ mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target",
cuda_target))(mod)
+ fdevice = tvm.tir.transform.SplitHostDevice()(mod)["f_kernel0"]
+ mod = tvm.IRModule.from_expr(fdevice)
fdevice = tvm.tir.transform.LowerWarpMemory()(mod)["main"]
assert(fdevice.body.body.value.value == "local")
assert(fdevice.body.body.body.extents[0].value == 2)
diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py
b/tests/python/unittest/test_tir_transform_thread_sync.py
index e692e23..6c9e7f9 100644
--- a/tests/python/unittest/test_tir_transform_thread_sync.py
+++ b/tests/python/unittest/test_tir_transform_thread_sync.py
@@ -38,13 +38,13 @@ def test_thread_storage_sync():
A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
f = tvm.tir.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0, True)
- flist = tvm.tir.ir_pass.SplitHostDevice(f)
- f = flist[1]
- fname = f.name
- mod = tvm.testing.LoweredFuncsToIRModule([f])
+ cuda_target = tvm.target.create("cuda")
+ mod = tvm.testing.LoweredFuncsToIRModule([f])
+ mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target",
cuda_target))(mod)
+ fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"]
+ mod = tvm.IRModule.from_expr(fdevice)
cuda_target = tvm.target.create("cuda")
- mod = tvm.IRModule.from_expr(mod[fname].with_attr("target", cuda_target))
f = tvm.tir.transform.ThreadSync("shared")(mod)["main"]
body_list = tvm.tir.stmt_list(f.body.body.body.body)
assert(body_list[1].value.name == "tvm_storage_sync")