This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new c3511c5 [TIR][REFACTOR] Remove te::Tensor dependencies from TIR
passes. (#5372)
c3511c5 is described below
commit c3511c5e2c2f903606209c9eb6d56c2221570a24
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Apr 19 15:26:51 2020 -0700
[TIR][REFACTOR] Remove te::Tensor dependencies from TIR passes. (#5372)
* [TIR][REFACTOR] Remove te::Tensor dependencies from TIR passes.
te::Tensor is a useful object for tensor expressions, but it brings an
unnecessary reverse dependency into TIR nodes such as Provide and Realize.
This PR is a first step toward removing this dependency. We will use Buffer in
all the places where te::Tensor was used. The rough correspondences are:
- Provide -> BufferStore
- Realize -> BufferRealize
- HalideCall -> BufferLoad.
After this change, we can now use an IRModule of PrimFuncs cleanly to
represent TIR at any point of the optimizations. Buffer will serve as the
abstraction for the TIR data model to represent intermediate storage and its
constraints.
We still keep Realize/HalideCall and Provide as TIR nodes for now to keep
the change minimal.
Right after ScheduleOps, we call SchedulePostProcToPrimFunc to canonicalize
the temporary IR generated by TE (which contains these nodes) into TIR.
The TIR optimizations are now mostly migrated to the pass manager.
Followup PRs are needed to migrate the remaining few passes.
* Fix dev tutorial
---
include/tvm/arith/bound.h | 14 +-
include/tvm/runtime/memory.h | 2 +-
include/tvm/te/schedule_pass.h | 21 +++
include/tvm/tir/expr.h | 15 +-
include/tvm/tir/ir_pass.h | 23 ---
include/tvm/tir/stmt.h | 119 +++++++++---
include/tvm/tir/stmt_functor.h | 5 +
include/tvm/tir/transform.h | 21 +++
python/tvm/autotvm/feature.py | 9 +-
python/tvm/driver/build_module.py | 102 ++++++----
python/tvm/ir/transform.py | 8 +-
python/tvm/tir/__init__.py | 3 +-
python/tvm/tir/stmt.py | 37 +++-
python/tvm/tir/transform/transform.py | 32 ++++
src/arith/domain_touched.cc | 38 ++--
src/driver/driver_api.cc | 53 ++----
src/te/operation/op_util.cc | 4 +-
src/te/schedule/schedule_postproc_to_primfunc.cc | 194 +++++++++++++++++++
src/tir/ir/expr.cc | 13 ++
src/tir/ir/stmt.cc | 87 +++++++--
src/tir/ir/stmt_functor.cc | 34 +++-
src/tir/pass/ffi_api.cc | 10 -
src/tir/{pass => transforms}/inject_prefetch.cc | 30 ++-
src/tir/{pass => transforms}/storage_flatten.cc | 209 +++++++++++++--------
tests/python/unittest/test_arith_domain_touched.py | 18 +-
tests/python/unittest/test_te_build_lower.py | 4 +-
tests/python/unittest/test_te_hybrid_script.py | 8 +-
tests/python/unittest/test_te_schedule.py | 2 +-
tests/python/unittest/test_te_schedule_ops.py | 19 +-
tests/python/unittest/test_te_tensor.py | 4 +-
.../unittest/test_tir_analysis_verify_memory.py | 29 +--
tests/python/unittest/test_tir_constructor.py | 4 +-
tests/python/unittest/test_tir_ir_builder.py | 6 +-
tests/python/unittest/test_tir_nodes.py | 4 +
.../test_tir_transform_inject_copy_intrin.py | 39 ++--
.../unittest/test_tir_transform_make_packed_api.py | 18 +-
.../unittest/test_tir_transform_narrow_datatype.py | 7 +-
...en.py => test_tir_transform_storage_flatten.py} | 40 ++--
.../unittest/test_tir_transform_storage_rewrite.py | 57 +++---
.../unittest/test_tir_transform_thread_sync.py | 12 +-
tutorials/dev/low_level_custom_pass.py | 12 +-
41 files changed, 935 insertions(+), 431 deletions(-)
diff --git a/include/tvm/arith/bound.h b/include/tvm/arith/bound.h
index 6165a2a..b1cb779 100644
--- a/include/tvm/arith/bound.h
+++ b/include/tvm/arith/bound.h
@@ -78,15 +78,15 @@ IntSet DeduceBound(PrimExpr v, PrimExpr cond,
/*!
* \brief Infer a regular domain that covers all the calls or provides within
the given statement.
* \param body The given statement.
- * \param tensor The name of the calls or provides.
- * \param consider_calls If calls (read) are considered.
- * \param consider_provides If provides (write) are considered.
+ * \param buffer The buffer to check the access info.
+ * \param consider_loads If loads are considered.
+ * \param consider_stores If stores are considered.
* \return The domain that covers all the calls or provides within the given
statement.
*/
-Domain DomainTouched(Stmt body,
- const te::Tensor &tensor,
- bool consider_calls,
- bool consider_provides);
+Domain DomainTouched(const Stmt& body,
+ const tir::Buffer& buffer,
+ bool consider_loads,
+ bool consider_stores);
} // namespace arith
} // namespace tvm
diff --git a/include/tvm/runtime/memory.h b/include/tvm/runtime/memory.h
index 121dbdd..b9b420a 100644
--- a/include/tvm/runtime/memory.h
+++ b/include/tvm/runtime/memory.h
@@ -70,7 +70,7 @@ class ObjAllocatorBase {
static_assert(std::is_base_of<Object, T>::value,
"make can only be used to create Object");
T* ptr = Handler::New(static_cast<Derived*>(this),
- std::forward<Args>(args)...);
+ std::forward<Args>(args)...);
ptr->type_index_ = T::RuntimeTypeIndex();
ptr->deleter_ = Handler::Deleter();
return ObjectPtr<T>(ptr);
diff --git a/include/tvm/te/schedule_pass.h b/include/tvm/te/schedule_pass.h
index b3ecbf8..e64ea21 100644
--- a/include/tvm/te/schedule_pass.h
+++ b/include/tvm/te/schedule_pass.h
@@ -29,6 +29,7 @@
#define TVM_TE_SCHEDULE_PASS_H_
#include <tvm/te/schedule.h>
+#include <tvm/tir/function.h>
namespace tvm {
namespace te {
@@ -55,6 +56,26 @@ Map<IterVar, Range> InferBound(const Schedule& sch);
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool
debug_keep_trivial_loop);
/*!
+ * \brief Postprocessing the Stmt generated by ScheduleOps to create
+ * a PrimFunc that can then be used for further TIR optimizations.
+ *
+ * Perform this translation before running any TIR optimizations.
+ *
+ * List of actions taken by the function:
+ * - Remove occurrences of te::Tensor, te::Operation in the IR
+ * and replace them by corresponding IR nodes via tir::Buffer.
+ * - Add annotation of extern buffers using the buffer_map field
+ * in the PrimFunc type.
+ *
+ * \param arg_list Array of Tensor/Var/Buffer arguments to the function.
+ * \param body The body of the function.
+ * \param bindings potential Tensor to Buffer bindings for the Tensors in the
body.
+ */
+PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list,
+ Stmt body,
+ Optional<Map<Tensor, Buffer>> bindings);
+
+/*!
* \brief To automatically inline the element-wise operations.
*
* \param sch The schedule to be inlined.
diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h
index 6764178..bf0d4f9 100644
--- a/include/tvm/tir/expr.h
+++ b/include/tvm/tir/expr.h
@@ -694,7 +694,10 @@ class CallNode : public PrimExprNode {
ExternCPlusPlus = 1,
/*! \brief Extern "C" without side-effect. */
PureExtern = 2,
- /*! \brief Halide-style call, evaluates func(args). */
+ /*!
+ * \brief Halide-style call, evaluates func(args).
+ * \note Deprecated, move to BufferLoad in the future.
+ */
Halide = 3,
/*! \brief Intrinsic functions. */
Intrinsic = 4,
@@ -707,9 +710,15 @@ class CallNode : public PrimExprNode {
Array<PrimExpr> args;
/*! \brief Type of calls. */
CallType call_type;
- /*! \brief The function to be called. */
+ /*!
+ * \brief The function to be called.
+ * \note Deprecated, move to BufferLoad in the future.
+ */
FunctionRef func;
- /*! \brief The output value index if func's value is a tuple. */
+ /*!
+ * \brief The output value index if func's value is a tuple.
+ * \note Deprecated, move to BufferLoad in the future.
+ */
int value_index{0};
void VisitAttrs(AttrVisitor* v) {
diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h
index f3d447e..e6e2de6 100644
--- a/include/tvm/tir/ir_pass.h
+++ b/include/tvm/tir/ir_pass.h
@@ -165,22 +165,6 @@ Stmt Inline(Stmt stmt,
PrimExpr body);
/*!
- * \brief Flatten the multi-dimensional read/write
- * to single dimensional Load/Store
- *
- * \param stmt The stmt to be trasnformed.
- * \param extern_buffer Map specifies external
- * buffer assignment of input and outputs.
- * \param cache_line_size The size of CPU cache line.
- * \param create_bound_attribute Whether to create bound attributes.
- * \return Transformed stmt.
- */
-Stmt StorageFlatten(Stmt stmt,
- Map<te::Tensor, Buffer> extern_buffer,
- int cache_line_size,
- bool create_bound_attribute = false);
-
-/*!
* \brief Try to modify the AST to support TensorCore
*
* \param stmt The stmt to be trasnformed.
@@ -203,13 +187,6 @@ Stmt RewriteForTensorCore(Stmt stmt,
bool VerifyCompactBuffer(Stmt stmt);
/*!
- * \brief Inject prefetch instructions into stmt.
- * \param stmt The statement to be transformed.
- * \return Transformed stmt.
- */
-Stmt InjectPrefetch(Stmt stmt);
-
-/*!
* \brief Decorate the stmt with a device scope, this is helpful for
* hardware accelerator without thread blocks.
*
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 5bc492f..20c2d00 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -248,7 +248,6 @@ class StoreNode : public StmtNode {
* \endcode
* \sa BufferLoad
*/
-class BufferStore;
class BufferStoreNode : public StmtNode {
public:
/*! \brief The buffer variable. */
@@ -281,6 +280,10 @@ class BufferStoreNode : public StmtNode {
TVM_DECLARE_FINAL_OBJECT_INFO(BufferStoreNode, StmtNode);
};
+/*!
+ * \brief Managed reference to BufferStoreNode.
+ * \sa BufferStoreNode
+ */
class BufferStore : public Stmt {
public:
TVM_DLL explicit BufferStore(Buffer buffer,
@@ -290,7 +293,79 @@ class BufferStore : public Stmt {
};
/*!
+ * \brief Annotate the region where the buffer need to
+ * be read and write in the body.
+ * We only need to allocate the space for the corresponding region.
+ *
+ * \note There should be at most one BufferRealize for each buffer.
+ * BufferRealize is not necessary for external buffers,
+ * since they are assumed to be fully allocated.
+ *
+ * \sa BufferLoad, BufferStore
+ */
+class BufferRealizeNode : public StmtNode {
+ public:
+ /*! \brief The buffer variable. */
+ Buffer buffer;
+ /*! \brief Bounds to be realized */
+ Array<Range> bounds;
+ /*! \brief Only realize if condition holds. */
+ PrimExpr condition;
+ /*! \brief The body of realization. */
+ Stmt body;
+
+ void VisitAttrs(AttrVisitor* v) {
+ v->Visit("buffer", &buffer);
+ v->Visit("bounds", &bounds);
+ v->Visit("condition", &condition);
+ v->Visit("body", &body);
+ }
+
+ bool SEqualReduce(const BufferRealizeNode* other, SEqualReducer equal) const
{
+ return
+ equal(buffer, other->buffer) &&
+ equal(bounds, other->bounds) &&
+ equal(condition, other->condition) &&
+ equal(body, other->body);
+ }
+
+ void SHashReduce(SHashReducer hash_reduce) const {
+ hash_reduce(buffer);
+ hash_reduce(bounds);
+ hash_reduce(condition);
+ hash_reduce(body);
+ }
+
+ BufferRealizeNode() = default;
+ BufferRealizeNode(Buffer buffer,
+ Array<Range> bounds,
+ PrimExpr condition,
+ Stmt body)
+ : buffer(buffer), bounds(bounds),
+ condition(condition), body(body) {}
+
+ static constexpr const char* _type_key = "BufferRealize";
+ TVM_DECLARE_FINAL_OBJECT_INFO(BufferRealizeNode, StmtNode);
+};
+
+/*!
+ * \brief Managed reference to BufferRealizeNode.
+ * \sa BufferRealizeNode
+ */
+class BufferRealize : public Stmt {
+ public:
+ TVM_DLL explicit BufferRealize(Buffer buffer,
+ Array<Range> bounds,
+ PrimExpr condition,
+ Stmt body);
+
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BufferRealize, Stmt,
BufferRealizeNode);
+};
+
+/*!
* \brief Store value into mult-dimensional array defined by func.
+ *
+ * \note Deprecated, move to BufferStore in the future.
*/
class ProvideNode : public StmtNode {
public:
@@ -430,6 +505,8 @@ class FreeNode : public StmtNode {
/*!
* \brief Annotate the bounds where func need to be written and read in body.
* We will need to allocate space for the corresponding regions.
+ *
+ * \note Deprecated, move to BufferRealize in the future.
*/
class RealizeNode : public StmtNode {
public:
@@ -747,51 +824,51 @@ class ForNode : public StmtNode {
};
/*!
- * \brief A prefetch hint of func.
+ * \brief A prefetch hint for a buffer.
*/
class PrefetchNode : public StmtNode {
public:
/*! \brief The function to be prefetched. */
- FunctionRef func;
- /*! \brief The output value index if func's value is a tuple. */
- int value_index;
- /*! \brief The data type of the array. */
- DataType dtype;
+ Buffer buffer;
/*! \brief Bounds to be prefetched. */
- Region bounds;
+ Array<Range> bounds;
void VisitAttrs(AttrVisitor* v) {
- v->Visit("func", &func);
- v->Visit("value_index", &value_index);
- v->Visit("dtype", &dtype);
+ v->Visit("buffer", &buffer);
v->Visit("bounds", &bounds);
}
bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const {
return
- equal(func, other->func) &&
- equal(value_index, other->value_index) &&
- equal(dtype, other->dtype) &&
+ equal(buffer, other->buffer) &&
equal(bounds, other->bounds);
}
void SHashReduce(SHashReducer hash_reduce) const {
- hash_reduce(func);
- hash_reduce(value_index);
- hash_reduce(dtype);
+ hash_reduce(buffer);
hash_reduce(bounds);
}
- TVM_DLL static Stmt make(FunctionRef func,
- int value_index,
- DataType dtype,
- Region bounds);
+ PrefetchNode() = default;
+ PrefetchNode(Buffer buffer, Array<Range> bounds)
+ : buffer(buffer), bounds(bounds) {}
static constexpr const char* _type_key = "Prefetch";
TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode);
};
/*!
+ * \brief Managed reference to PrefetchNode.
+ * \sa PrefetchNode
+ */
+class Prefetch : public Stmt {
+ public:
+ TVM_DLL explicit Prefetch(Buffer buffer, Array<Range> bounds);
+
+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode);
+};
+
+/*!
* \brief Auxiliary data structure used in IR Pass to indicate a tensor.
*/
struct TensorKey {
diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h
index f93e908..a87ff97 100644
--- a/include/tvm/tir/stmt_functor.h
+++ b/include/tvm/tir/stmt_functor.h
@@ -92,6 +92,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const AllocateNode* op, Args... args)
STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferStoreNode* op, Args... args)
STMT_FUNCTOR_DEFAULT;
+ virtual R VisitStmt_(const BufferRealizeNode* op, Args... args)
STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AssertStmtNode* op, Args... args)
STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ProvideNode* op, Args... args)
STMT_FUNCTOR_DEFAULT;
@@ -121,6 +122,8 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(PrefetchNode);
IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode);
IR_STMT_FUNCTOR_DISPATCH(EvaluateNode);
+ IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode);
+ IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode);
return vtable;
}
};
@@ -154,6 +157,7 @@ class TVM_DLL StmtVisitor :
void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const BufferStoreNode* op) override;
+ void VisitStmt_(const BufferRealizeNode* op) override;
void VisitStmt_(const FreeNode* op) override;
void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const ProvideNode* op) override;
@@ -248,6 +252,7 @@ class TVM_DLL StmtMutator :
Stmt VisitStmt_(const AllocateNode* op) override;
Stmt VisitStmt_(const StoreNode* op) override;
Stmt VisitStmt_(const BufferStoreNode* op) override;
+ Stmt VisitStmt_(const BufferRealizeNode* op) override;
Stmt VisitStmt_(const FreeNode* op) override;
Stmt VisitStmt_(const AssertStmtNode* op) override;
Stmt VisitStmt_(const ProvideNode* op) override;
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index e593e1b..09ea097 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -58,6 +58,27 @@ TVM_DLL Pass CreatePrimFuncPass(const
runtime::TypedPackedFunc<
const std::string& name,
const tvm::Array<runtime::String>& required);
+
+/*!
+ * \brief Inject prefetch instructions into stmt.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass InjectPrefetch();
+
+// TODO(tvm-team): consolidate configs to the PassContext
+/*!
+ * \brief Flatten the multi-dimensional read/write
+ * to single dimensional Load/Store
+ *
+ * \param cache_line_size The size of CPU cache line.
+ * \param create_bound_attribute Whether to create bound attributes.
+ *
+ * \return The Pass
+ */
+TVM_DLL Pass StorageFlatten(int cache_line_size,
+ bool create_bound_attribute = false);
+
/*!
* \brief Inject copy intrinsics with optional pad.
*
diff --git a/python/tvm/autotvm/feature.py b/python/tvm/autotvm/feature.py
index c576ffd..0c0591c 100644
--- a/python/tvm/autotvm/feature.py
+++ b/python/tvm/autotvm/feature.py
@@ -31,7 +31,6 @@ import numpy as np
import tvm._ffi
from tvm import target as _target
-from tvm.tir import ir_pass
from tvm.te import schedule
from tvm.driver import build_module
@@ -46,10 +45,12 @@ def ana_lower(sch, args,
# Phase 0
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds, True)
- stmt = ir_pass.StorageFlatten(stmt, binds, 64)
- stmt = ir_pass.CanonicalSimplify(stmt)
+ func = schedule.SchedulePostProcToPrimFunc(args, stmt, None)
+ mod = tvm.IRModule.from_expr(func._move())
+ mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
+ mod = tvm.tir.transform.Simplify()(mod._move())
assert simple_mode
- return stmt
+ return mod["main"].body
try:
_get_buffer_curve_sample_flatten = tvm._ffi.get_global_func(
diff --git a/python/tvm/driver/build_module.py
b/python/tvm/driver/build_module.py
index 18a8a47..eea3727 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -85,7 +85,8 @@ def get_binds(args, compact=False, binds=None):
def form_body(sch):
- """According to the given schedule, form the raw body
+ """According to the given schedule, form a function.
+
Parameters
----------
sch : tvm.te.schedule.Schedule
@@ -99,13 +100,31 @@ def form_body(sch):
sch = sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
- stmt = ir_pass.InjectPrefetch(stmt)
return stmt
+def _wrap_as_prim_func_pass(flist, name):
+ """Wrap flist as a function pass.
+
+ This is a temporary adapter before we fully
+ migrate to the new pass manager.
+ """
+ def _transform(func, *_):
+ stmt = func.body
+ for f in flist:
+ stmt = f(stmt)
+ # create a new function with updated body.
+ return tvm.tir.PrimFunc(func.params,
+ stmt,
+ func.ret_type,
+ func.buffer_map,
+ func.attrs)
+ return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name=name)
+
+
def lower(sch,
args,
- name="default_function",
+ name="main",
binds=None,
simple_mode=False):
"""Lowering step before build into target.
@@ -154,56 +173,57 @@ def lower(sch,
compact = ir_pass.VerifyCompactBuffer(stmt)
binds, arg_list = get_binds(args, compact, binds)
+ stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
+
+ # Start the new style pass manager.
+ func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
+ func = func.with_attr("global_symbol", name)
+ if cfg.restricted_func:
+ func = func.with_attr("tir.noalias", True)
+ mod = tvm.IRModule({name: func})
# Phase 1
- stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
- stmt = ir_pass.StorageFlatten(stmt, binds, 64,
cfg.instrument_bound_checkers)
- stmt = ir_pass.NarrowDataType(stmt, 32)
- stmt = ir_pass.CanonicalSimplify(stmt)
- for f in lower_phase1:
- stmt = f(stmt)
+ pass_list = [
+ tvm.tir.transform.InjectPrefetch(),
+ tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers),
+ tvm.tir.transform.NarrowDataType(32),
+ tvm.tir.transform.Simplify(),
+ _wrap_as_prim_func_pass(lower_phase1, "Custom-Phase1"),
+ ]
# Phase 2
if not simple_mode:
- stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
- if cfg.disable_vectorize:
- stmt = ir_pass.SkipVectorize(stmt)
- else:
- stmt = ir_pass.VectorizeLoop(stmt)
- stmt = ir_pass.InjectVirtualThread(stmt)
- stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
- stmt = ir_pass.StorageRewrite(stmt)
- stmt = ir_pass.UnrollLoop(
- stmt,
- cfg.auto_unroll_max_step,
- cfg.auto_unroll_max_depth,
- cfg.auto_unroll_max_extent,
- cfg.unroll_explicit)
-
- for f in lower_phase2:
- stmt = f(stmt)
+ pass_list +=
[(tvm.tir.transform.LoopPartition(cfg.partition_const_loop))]
+
+ pass_list += [
+ tvm.tir.transform.VectorizeLoop(not cfg.disable_vectorize),
+ tvm.tir.transform.InjectVirtualThread(),
+ tvm.tir.transform.InjectDoubleBuffer(cfg.double_buffer_split_loop),
+ tvm.tir.transform.StorageRewrite(),
+ tvm.tir.transform.UnrollLoop(
+ cfg.auto_unroll_max_step,
+ cfg.auto_unroll_max_depth,
+ cfg.auto_unroll_max_extent,
+ cfg.unroll_explicit),
+ _wrap_as_prim_func_pass(lower_phase2, "Custom-Phase2"),
+ ]
# Phase 3
- stmt = ir_pass.Simplify(stmt)
- stmt = ir_pass.RemoveNoOp(stmt)
- if not cfg.disable_select_rewriting:
- stmt = ir_pass.RewriteUnsafeSelect(stmt)
+ pass_list += [
+ tvm.tir.transform.Simplify(),
+ tvm.tir.transform.RemoveNoOp(),
+ ]
- for f in lower_phase3:
- stmt = f(stmt)
+ if not cfg.disable_select_rewriting:
+ pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
+ pass_list += [_wrap_as_prim_func_pass(lower_phase3, "Custom-Phase3")]
# Instrument BoundCheckers
if cfg.instrument_bound_checkers:
- stmt = ir_pass.InstrumentBoundCheckers(stmt)
+ pass_list += [tvm.tir.transform.InstrumentBoundCheckers()]
- if simple_mode:
- return stmt
-
- f = tvm.tir.PrimFunc(arg_list, stmt).with_attr(
- "global_symbol", tvm.runtime.String(name))
- if cfg.restricted_func:
- f = f.with_attr("tir.noalias", True)
- mod = tvm.IRModule({name: f})
+ optimize = tvm.transform.Sequential(pass_list)
+ mod = optimize(mod)
return mod
diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py
index 614f969..af0be45 100644
--- a/python/tvm/ir/transform.py
+++ b/python/tvm/ir/transform.py
@@ -157,11 +157,6 @@ class Sequential(Pass):
"""A pass that works on a sequence of pass objects. Multiple passes can be
executed sequentially using this class.
- Some typical usage of the sequential pass are:
- 1. Users provide a list of passes for optimization.
- 2. Only an optimization level is provided so that the backend system has
- to glob all passes at this level and below to perform the optimizations.
-
Note that users can also provide a series of passes that they don't want to
apply when running a sequential pass. Pass dependency will be resolved in
the backend as well.
@@ -173,6 +168,9 @@ class Sequential(Pass):
opt_level : Optional[int]
The optimization level of this sequential pass.
+ The opt_level of a default sequential pass is set to 0.
+ Note that some of the passes within the Sequential may still not be
executed
+ if their opt_level is higher than the provided opt_level.
name : Optional[str]
The name of the sequential pass.
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index d2238ad..ddfb6a5 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -28,7 +28,8 @@ from .expr import Select, BufferLoad, Load, Ramp, Broadcast,
Shuffle, Call, Let
from .expr import IterVar, Any
from .stmt import Stmt, LetStmt, AssertStmt, For
-from .stmt import BufferStore, Store, Provide, Allocate, AttrStmt, Free,
Realize, SeqStmt
+from .stmt import BufferStore, BufferRealize, Store, Provide, Allocate,
AttrStmt
+from .stmt import Free, Realize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
from .function import PrimFunc
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index c5b2a79..eee5b0b 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -161,6 +161,29 @@ class BufferStore(Stmt):
@tvm._ffi.register_object
+class BufferRealize(Stmt):
+ """Buffer realize node.
+
+ Parameters
+ ----------
+ buffer : Buffer
+ The buffer.
+
+ bounds : List[Range]
+ The bounds of the region to be realized.
+
+ condition : PrimExpr
+ The realize condition.
+
+ body : Stmt
+ The body of the statement.
+ """
+ def __init__(self, buffer, bounds, condition, body):
+ self.__init_handle_by_constructor__(
+ _ffi_api.BufferRealize, buffer, bounds, condition, body)
+
+
+@tvm._ffi.register_object
class Provide(Stmt):
"""Provide node.
@@ -348,21 +371,15 @@ class Prefetch(Stmt):
Parameters
----------
- func : Operation
- The operation to create the function.
-
- value_index : int
- The output value index
-
- dtype : str
- The data type to be prefetched.
+ buffer : Buffer
+ The buffer to be prefetched.
bounds : list of Range
The bounds to be prefetched.
"""
- def __init__(self, func, value_index, dtype, bounds):
+ def __init__(self, buffer, bounds):
self.__init_handle_by_constructor__(
- _ffi_api.Prefetch, func, value_index, dtype, bounds)
+ _ffi_api.Prefetch, buffer, bounds)
def stmt_seq(*args):
diff --git a/python/tvm/tir/transform/transform.py
b/python/tvm/tir/transform/transform.py
index f83bb11..bb39c1f 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -60,6 +60,38 @@ def Filter(fcond):
return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter")
+def InjectPrefetch():
+ """Inject prefetch instructions into stmt.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.InjectPrefetch()
+
+
+def StorageFlatten(cache_line_size, create_bound_attribute=False):
+ """Flatten the multi-dimensional read/write to 1D.
+
+
+ Parameters
+ ----------
+ cache_line_size: int
+ The size of CPU cache line.
+
+ create_bound_attribute:
+ Whether to create bound attributes.
+
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.StorageFlatten(cache_line_size, create_bound_attribute)
+
+
def InjectCopyIntrin(pragma_key, fintrin):
"""Inject virtual thread loops.
diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc
index bda70fb..2467e75 100644
--- a/src/arith/domain_touched.cc
+++ b/src/arith/domain_touched.cc
@@ -36,10 +36,14 @@ namespace arith {
using namespace tir;
// Find Read region of the tensor in the stmt.
-class FuncTouchedDomain final : public StmtExprVisitor {
+class BufferTouchedDomain final : public StmtExprVisitor {
public:
- FuncTouchedDomain(const te::Tensor &tensor, bool consider_calls, bool
consider_provides)
- : tensor_(tensor), consider_calls_(consider_calls),
consider_provides_(consider_provides) {}
+ BufferTouchedDomain(const Buffer &buffer,
+ bool consider_loads,
+ bool consider_stores)
+ : buffer_(buffer),
+ consider_loads_(consider_loads),
+ consider_stores_(consider_stores) {}
Domain Find(const Stmt& stmt) {
operator()(stmt);
@@ -80,18 +84,16 @@ class FuncTouchedDomain final : public StmtExprVisitor {
}
}
- void VisitExpr_(const CallNode* op) final {
- if (consider_calls_ && tensor_->op.same_as(op->func)
- && tensor_->value_index == op->value_index) {
- Touch(op->args);
+ void VisitExpr_(const BufferLoadNode* op) final {
+ if (consider_loads_ && buffer_.same_as(op->buffer)) {
+ Touch(op->indices);
}
StmtExprVisitor::VisitExpr_(op);
}
- void VisitStmt_(const ProvideNode* op) final {
- if (consider_provides_ && tensor_->op.same_as(op->func)
- && tensor_->value_index == op->value_index) {
- Touch(op->args);
+ void VisitStmt_(const BufferStoreNode* op) final {
+ if (consider_stores_ && buffer_.same_as(op->buffer)) {
+ Touch(op->indices);
}
StmtExprVisitor::VisitStmt_(op);
}
@@ -106,17 +108,17 @@ class FuncTouchedDomain final : public StmtExprVisitor {
}
}
- const te::Tensor &tensor_;
- bool consider_calls_, consider_provides_;
+ const Buffer &buffer_;
+ bool consider_loads_, consider_stores_;
std::vector<std::vector<IntSet> > bounds_;
std::unordered_map<const VarNode*, IntSet> dom_map_;
};
-Domain DomainTouched(Stmt stmt,
- const te::Tensor &tensor,
- bool consider_calls,
- bool consider_provides) {
- return FuncTouchedDomain(tensor, consider_calls,
consider_provides).Find(stmt);
+Domain DomainTouched(const Stmt& stmt,
+ const Buffer& buffer,
+ bool consider_loads,
+ bool consider_stores) {
+ return BufferTouchedDomain(buffer, consider_loads,
consider_stores).Find(stmt);
}
TVM_REGISTER_GLOBAL("arith.DomainTouched")
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index e38179e..c3802b1 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -130,35 +130,6 @@ transform::Pass Filter(FCond fcond) {
}
-IRModule BuildIRModule(const Array<ObjectRef>& out_arg_list,
- tir::Stmt stmt,
- const std::string& name,
- const BuildConfig& config) {
- Array<tir::Var> params;
- Map<tir::Var, tir::Buffer> buffer_map;
-
- for (auto var : out_arg_list) {
- if (auto* n = var.as<tir::VarNode>()) {
- params.push_back(GetRef<tir::Var>(n));
- } else {
- tir::Buffer buffer = Downcast<tir::Buffer>(var);
- tir::Var bptr(buffer->name, DataType::Handle());
- params.push_back(bptr);
- buffer_map.Set(bptr, buffer);
- }
- }
-
- auto f = tir::PrimFunc(params, stmt, VoidType(), buffer_map);
- f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
- if (config->restricted_func) {
- f = WithAttr(std::move(f), "tir.noalias", Integer(1));
- }
-
- return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-}
-
-
IRModule lower(te::Schedule sch,
const Array<te::Tensor>& args,
const std::string& name,
@@ -168,23 +139,31 @@ IRModule lower(te::Schedule sch,
sch = sch.normalize();
- // Phase 0
+ // Before TIR transformation.
auto bounds = te::InferBound(sch);
auto stmt = te::ScheduleOps(sch, bounds, false);
- stmt = tir::InjectPrefetch(stmt);
-
bool compact = tir::VerifyCompactBuffer(stmt);
+
Map<te::Tensor, tir::Buffer> out_binds;
GetBinds(args, compact, binds, &out_binds, &out_arg_list, config);
- // Phase 1
- stmt = tir::StorageFlatten(stmt, out_binds, 64,
- config->instrument_bound_checkers);
+ // build the function
+ tir::PrimFunc f = te::SchedulePostProcToPrimFunc(
+ out_arg_list, std::move(stmt), out_binds);
+ f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
+ if (config->restricted_func) {
+ f = WithAttr(std::move(f), "tir.noalias", Integer(1));
+ }
- // convert to IRModule.
- auto mod = BuildIRModule(out_arg_list, stmt, name, config);
+ auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
auto pass_list = Array<tvm::transform::Pass>();
+ // Phase 0
+ pass_list.push_back(tir::transform::InjectPrefetch());
+ pass_list.push_back(
+ tir::transform::StorageFlatten(64, config->instrument_bound_checkers));
+ // Phase 1
+ pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::LoopPartition(config->partition_const_loop));
pass_list.push_back(tir::transform::VectorizeLoop(!config->disable_vectorize));
diff --git a/src/te/operation/op_util.cc b/src/te/operation/op_util.cc
index 4ecfe94..d022134 100644
--- a/src/te/operation/op_util.cc
+++ b/src/te/operation/op_util.cc
@@ -132,8 +132,8 @@ MakeLoopNest(const Stage& stage,
for (size_t j = 0; j < it_attr->prefetch_data.size(); ++j) {
nest[i + 1].emplace_back(
AttrStmtNode::make(it_attr->prefetch_data[j],
- tir::attr::prefetch_scope,
- it_attr->prefetch_offset[j], no_op));
+ tir::attr::prefetch_scope,
+ it_attr->prefetch_offset[j], no_op));
}
}
} else if (bind_iv->thread_tag == "vthread" ||
diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc
b/src/te/schedule/schedule_postproc_to_primfunc.cc
new file mode 100644
index 0000000..bb52be4
--- /dev/null
+++ b/src/te/schedule/schedule_postproc_to_primfunc.cc
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file schedule_postproc_to_primfunc.cc
+ *
+ * \brief Translate the body generated by ScheduleOps, whose
+ * TE-specific dialect nodes embed te::Tensor references
+ * in the Stmts, into a PrimFunc.
+ *
+ * Perform this translation before running any TIR optimizations.
+ *
+ * Rationale: The body generated by ScheduleOps is not
+ * a formal PrimFunc and cannot be used for further optimization.
+ * This function canonicalizes that body and creates a formal PrimFunc.
+ *
+ * List of actions taken by the function:
+ * - Remove occurrences of te::Tensor, te::Operation in the IR
+ * and replace them by corresponding IR nodes via tir::Buffer.
+ * - Add annotation of extern buffers using the buffer_map field
+ * in the PrimFunc type.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/runtime/container.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/te/operation.h>
+#include <utility>
+#include <unordered_map>
+
+namespace tvm {
+namespace te {
+
+// Create a fresh buffer for `tensor`; outputs of multi-output ops get a ".v<index>" suffix.
+Buffer CreateBufferFor(const Tensor& tensor) {
+  std::string name = tensor->op->name;
+  if (tensor->op->num_outputs() != 1) {
+    name += ".v" + std::to_string(tensor->value_index);
+  }
+  Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name);
+  return buffer;
+}
+
+// Rewrites te::Tensor references (Realize, Provide, Halide Call, attrs) into tir::Buffer ones.
+class TensorToBufferMapper : public StmtExprMutator {
+ public:
+  explicit TensorToBufferMapper(std::unordered_map<Tensor, Buffer> buffer_map)
+      : buffer_map_(std::move(buffer_map)) {
+  }
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    auto ret = StmtExprMutator::VisitStmt_(op);
+    op = ret.as<AttrStmtNode>();
+    // TODO(tvm-team): remove realize_scope, turn the info into
+    // Buffer's scope field in this pass.
+    if (op->attr_key == tir::attr::realize_scope ||
+        op->attr_key == tir::attr::double_buffer_scope) {
+      Stmt body = op->body;
+      Operation operation = Downcast<Operation>(op->node);
+      for (int i = operation->num_outputs(); i != 0; --i) {  // output 0 ends up outermost
+        Buffer buffer = GetOrAllocBuffer(operation.output(i - 1));
+        body = AttrStmtNode::make(
+            buffer, op->attr_key, op->value, body);
+      }
+      return body;
+    } else if (op->attr_key == tir::attr::buffer_bind_scope) {
+      Array<ObjectRef> tuple = Downcast<Array<ObjectRef> >(op->node);
+      Tensor tensor = Downcast<Tensor>(tuple[1]);
+      return AttrStmtNode::make(
+          Array<ObjectRef>{tuple[0], GetOrAllocBuffer(tensor)},
+          op->attr_key, op->value, op->body);
+    } else if (op->attr_key == tir::attr::buffer_dim_align ||
+               op->attr_key == tir::attr::prefetch_scope) {
+      Tensor tensor = Downcast<Tensor>(op->node);
+      Buffer buffer = GetOrAllocBuffer(tensor);
+      return AttrStmtNode::make(
+          buffer, op->attr_key, op->value, op->body);
+    } else {
+      return ret;
+    }
+  }
+
+  Stmt VisitStmt_(const RealizeNode* op) final {
+    Tensor tensor = Downcast<Operation>(op->func).output(op->value_index);
+    Buffer buffer = GetOrAllocBuffer(tensor);
+    // Allocated before visiting the body so nested Provide/Call nodes can look it up.
+    auto ret = StmtExprMutator::VisitStmt_(op);
+    op = ret.as<RealizeNode>();
+
+    return BufferRealize(buffer, op->bounds, op->condition, op->body);
+  }
+
+  Stmt VisitStmt_(const ProvideNode* op) final {
+    Tensor tensor = Downcast<Operation>(op->func).output(op->value_index);
+    Buffer buffer = GetBuffer(tensor);  // CHECK-fails unless realized or bound as an argument
+
+    auto ret = StmtExprMutator::VisitStmt_(op);
+    op = ret.as<ProvideNode>();
+
+    return BufferStore(buffer, op->value, op->args);
+  }
+
+  PrimExpr VisitExpr_(const CallNode* op) final {
+    auto ret = StmtExprMutator::VisitExpr_(op);
+    op = ret.as<CallNode>();
+
+    if (op->call_type == CallNode::Halide) {
+      Tensor tensor = Downcast<Operation>(op->func).output(op->value_index);
+      Buffer buffer = GetBuffer(tensor);
+      return tir::BufferLoad(buffer, op->args);
+    } else {
+      return ret;
+    }
+  }
+
+ private:
+  Buffer GetOrAllocBuffer(const Tensor& tensor) {
+    return GetBuffer(tensor, true);
+  }
+
+  Buffer GetBuffer(const Tensor& tensor, bool allow_alloc = false) {
+    auto it = buffer_map_.find(tensor);
+    if (it != buffer_map_.end()) return it->second;
+    CHECK(allow_alloc) << "Cannot find the Realization point of tensor " << tensor;
+
+    auto buffer = CreateBufferFor(tensor);
+    buffer_map_[tensor] = buffer;
+    return buffer;
+  }
+
+  // Maps each tensor to the buffer that replaces it in the rewritten body.
+  std::unordered_map<Tensor, Buffer> buffer_map_;
+};
+
+// Entry point: canonicalize the ScheduleOps body into a PrimFunc (see file header).
+PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list,
+                                    Stmt body,
+                                    Optional<Map<Tensor, Buffer>> extern_buffer_opt) {
+  std::unordered_map<Tensor, Buffer> extern_buffer;
+  // Seed the tensor->buffer map with the caller-provided bindings, if any.
+  if (extern_buffer_opt.defined()) {
+    auto v = extern_buffer_opt.value();
+    extern_buffer = std::unordered_map<Tensor, Buffer>(v.begin(), v.end());
+  }
+
+  Array<tir::Var> params;
+  Map<tir::Var, tir::Buffer> buffer_map;
+  // Vars pass through; Tensor/Buffer args become handle params recorded in buffer_map.
+  for (const ObjectRef& var : arg_list) {
+    if (auto* n = var.as<tir::VarNode>()) {
+      params.push_back(GetRef<tir::Var>(n));
+    } else if (auto* n = var.as<te::TensorNode>()) {
+      te::Tensor tensor = GetRef<te::Tensor>(n);
+      CHECK(!extern_buffer.count(tensor));
+      // Tensor args get a fresh buffer that the body rewrite below also uses.
+      tir::Buffer buffer = CreateBufferFor(tensor);
+      tir::Var bptr(buffer->name, DataType::Handle());
+      params.push_back(bptr);
+      buffer_map.Set(bptr, buffer);
+      extern_buffer[tensor] = buffer;
+    } else {
+      tir::Buffer buffer = Downcast<tir::Buffer>(var);
+      tir::Var bptr(buffer->name, DataType::Handle());
+      params.push_back(bptr);
+      buffer_map.Set(bptr, buffer);
+    }
+  }
+  // Rewrite every remaining Tensor reference in the body into its Buffer.
+  body = TensorToBufferMapper(std::move(extern_buffer))(std::move(body));
+  return tir::PrimFunc(params, body, VoidType(), buffer_map);
+}
+
+TVM_REGISTER_GLOBAL("schedule.SchedulePostProcToPrimFunc")
+.set_body_typed(SchedulePostProcToPrimFunc);
+
+}  // namespace te
+}  // namespace tvm
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 65d424e..03925ec 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -646,6 +646,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+.set_dispatch<BufferLoadNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const BufferLoadNode*>(node.get());
+ p->stream << op->buffer->name << "[";
+ for (size_t i = 0; i < op->indices.size(); ++i) {
+ p->Print(op->indices[i]);
+ if (i < op->indices.size() - 1) {
+ p->stream << ", ";
+ }
+ }
+ p->stream << "]";
+ });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LetNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const LetNode*>(node.get());
p->stream << "(let " << op->var << " = ";
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 1f6a7dd..f8e82ea 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -253,24 +253,14 @@ TVM_REGISTER_GLOBAL("tir.Realize")
.set_body_typed(RealizeNode::make);
-Stmt PrefetchNode::make(FunctionRef func, int value_index, DataType dtype, Region bounds) {
- for (size_t i = 0; i < bounds.size(); ++i) {
- CHECK(bounds[i]->min.defined());
- CHECK(bounds[i]->extent.defined());
- CHECK(bounds[i]->min.dtype().is_scalar());
- CHECK(bounds[i]->extent.dtype().is_scalar());
- }
-
- ObjectPtr<PrefetchNode> node = make_object<PrefetchNode>();
- node->func = std::move(func);
- node->value_index = value_index;
- node->dtype = dtype;
- node->bounds = std::move(bounds);
- return Stmt(node);
+Prefetch::Prefetch(Buffer buffer, Array<Range> bounds) {
+ data_ = make_object<PrefetchNode>(buffer, bounds);
}
TVM_REGISTER_GLOBAL("tir.Prefetch")
-.set_body_typed(PrefetchNode::make);
+.set_body_typed([](Buffer buffer, Array<Range> bounds) {
+ return Prefetch(buffer, bounds);
+});
SeqStmt::SeqStmt(Array<Stmt> seq) {
@@ -326,6 +316,25 @@ TVM_REGISTER_GLOBAL("tir.BufferStore")
TVM_REGISTER_NODE_TYPE(BufferStoreNode);
+
+BufferRealize::BufferRealize(Buffer buffer,
+ Array<Range> bounds,
+ PrimExpr condition,
+ Stmt body) {
+ data_ = make_object<BufferRealizeNode>(
+ buffer, bounds, condition, body);
+}
+
+TVM_REGISTER_GLOBAL("tir.BufferRealize")
+.set_body_typed([](Buffer buffer,
+ Array<Range> bounds,
+ PrimExpr condition,
+ Stmt body) {
+ return BufferRealize(buffer, bounds, condition, body);
+});
+
+TVM_REGISTER_NODE_TYPE(BufferRealizeNode);
+
// Printers
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
@@ -433,6 +442,21 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+.set_dispatch<BufferStoreNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const BufferStoreNode*>(node.get());
+ p->PrintIndent();
+ p->stream << op->buffer->name << "[";
+ for (size_t i = 0; i < op->indices.size(); ++i) {
+ p->Print(op->indices[i]);
+ if (i < op->indices.size() - 1) p->stream << ", ";
+ }
+ p->stream << "]";
+ p->stream << " = ";
+ p->Print(op->value);
+ p->stream << '\n';
+ });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<AllocateNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const AllocateNode*>(node.get());
p->PrintIndent();
@@ -459,6 +483,34 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+.set_dispatch<BufferRealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
+ auto* op = static_cast<const BufferRealizeNode*>(node.get());
+ p->PrintIndent();
+ p->stream << "buffer_realize " << op->buffer->name << "(";
+ for (size_t i = 0; i < op->bounds.size(); ++i) {
+ p->stream << "[";
+ p->Print(op->bounds[i]->min);
+ p->stream << ", ";
+ p->Print(op->bounds[i]->extent);
+ p->stream << "]";
+ if (i < op->bounds.size() - 1) p->stream << ", ";
+ }
+ p->stream << ")";
+ if (!is_one(op->condition)) {
+ p->stream << " if ";
+ p->Print(op->condition);
+ }
+ p->stream << " {\n";
+
+ p->indent += 2;
+ p->Print(op->body);
+ p->indent -= 2;
+
+ p->PrintIndent();
+ p->stream << "}\n";
+ });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RealizeNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const RealizeNode*>(node.get());
p->PrintIndent();
@@ -493,7 +545,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PrefetchNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const PrefetchNode*>(node.get());
p->PrintIndent();
- p->stream << "prefetch " << op->func->func_name() << "(";
+ p->stream << "prefetch " << op->buffer << "(";
for (size_t i = 0; i < op->bounds.size(); ++i) {
p->stream << "[";
p->Print(op->bounds[i]->min);
@@ -503,9 +555,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
if (i < op->bounds.size() - 1) p->stream << ", ";
}
p->stream << ")";
- if (op->func->num_outputs() != 1) {
- p->stream << ".value[" << op->value_index << "]";
- }
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index ed3c2c7..5e584eb 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -158,9 +158,19 @@ void StmtVisitor::VisitStmt_(const StoreNode* op) {
}
void StmtVisitor::VisitStmt_(const BufferStoreNode* op) {
+ this->VisitExpr(op->value);
VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
}
+void StmtVisitor::VisitStmt_(const BufferRealizeNode* op) {
+ VisitArray(op->bounds, [this](const Range& r) {
+ this->VisitExpr(r->min);
+ this->VisitExpr(r->extent);
+ });
+ this->VisitExpr(op->condition);
+ this->VisitStmt(op->body);
+}
+
void StmtVisitor::VisitStmt_(const IfThenElseNode* op) {
this->VisitExpr(op->condition);
this->VisitStmt(op->then_case);
@@ -336,16 +346,38 @@ Stmt StmtMutator::VisitStmt_(const StoreNode* op) {
}
Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) {
+ PrimExpr value = this->VisitExpr(op->value);
Array<PrimExpr> indices = Internal::Mutate(this, op->indices);
- if (indices.same_as(op->indices)) {
+
+ if (value.same_as(op->value) &&
+ indices.same_as(op->indices)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
+ n->value = std::move(value);
n->indices = std::move(indices);
return Stmt(n);
}
}
+Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) {
+ Region bounds = Internal::Mutate(this, op->bounds);
+ PrimExpr condition = this->VisitExpr(op->condition);
+ Stmt body = this->VisitStmt(op->body);
+
+ if (bounds.same_as(op->bounds) &&
+ condition.same_as(op->condition) &&
+ body.same_as(op->body)) {
+ return GetRef<Stmt>(op);
+ } else {
+ auto n = CopyOnWrite(op);
+ n->bounds = std::move(bounds);
+ n->condition = std::move(condition);
+ n->body = std::move(body);
+ return Stmt(n);
+ }
+}
+
Stmt StmtMutator::VisitStmt_(const ProvideNode* op) {
Array<PrimExpr> args = Internal::Mutate(this, op->args);
PrimExpr value = this->VisitExpr(op->value);
diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc
index 65981b9..4d7ed5d 100644
--- a/src/tir/pass/ffi_api.cc
+++ b/src/tir/pass/ffi_api.cc
@@ -75,15 +75,6 @@ TVM_REGISTER_GLOBAL("ir_pass.Substitute")
}
});
-TVM_REGISTER_GLOBAL("ir_pass.StorageFlatten")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
- if (args.size() <= 3) {
- *ret = StorageFlatten(args[0], args[1], args[2]);
- } else {
- *ret = StorageFlatten(args[0], args[1], args[2], args[3]);
- }
- });
-
TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore")
.set_body_typed
([](const Stmt& stmt,
@@ -116,7 +107,6 @@ REGISTER_PASS(ConvertSSA);
REGISTER_PASS(VerifySSA);
REGISTER_PASS(Inline);
REGISTER_PASS(IRTransform);
-REGISTER_PASS(InjectPrefetch);
REGISTER_PASS(VerifyGPUCode);
REGISTER_PASS(DecorateDeviceScope);
REGISTER_PASS(VerifyCompactBuffer);
diff --git a/src/tir/pass/inject_prefetch.cc b/src/tir/transforms/inject_prefetch.cc
similarity index 79%
rename from src/tir/pass/inject_prefetch.cc
rename to src/tir/transforms/inject_prefetch.cc
index 894ff38..e9dae0a 100644
--- a/src/tir/pass/inject_prefetch.cc
+++ b/src/tir/transforms/inject_prefetch.cc
@@ -21,9 +21,12 @@
* \file inject_prefetch.cc
*/
// Inject prefetch op in HalideIR
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
+#include <tvm/arith/bound.h>
#include <tvm/arith/analyzer.h>
#include <unordered_set>
@@ -39,9 +42,9 @@ class PrefetchInjector : public StmtMutator {
Stmt ret = StmtMutator::VisitStmt_(op);
op = ret.as<AttrStmtNode>();
if (op && op->attr_key == attr::prefetch_scope) {
- te::Tensor ts = Downcast<te::Tensor>(op->node);
+ Buffer buffer = Downcast<Buffer>(op->node);
CHECK_NE(loop_nest_.size(), 0U);
- Domain domain = DomainTouched(op->body, ts, true, false);
+ Domain domain = DomainTouched(op->body, buffer, true, false);
Region region;
auto iter_var = loop_nest_.back().get();
@@ -49,7 +52,7 @@ class PrefetchInjector : public StmtMutator {
for (Range r : domain) {
if (!r.defined()) {
- LOG(WARNING) << "Cannot decide prefetch region for " << ts;
+ LOG(WARNING) << "Cannot decide prefetch region for " << buffer;
return op->body;
}
Range res(EvalSet(r, vectorized_).cover_range(none));
@@ -58,7 +61,7 @@ class PrefetchInjector : public StmtMutator {
vectorized_.erase(iter_var);
- Stmt prefetch = PrefetchNode::make(ts->op, ts->value_index, ts->dtype, region);
+ Stmt prefetch = Prefetch(buffer, region);
return SeqStmt({prefetch, op->body});
}
return ret;
@@ -90,5 +93,22 @@ Stmt InjectPrefetch(Stmt stmt) {
return PrefetchInjector()(std::move(stmt));
}
+
+namespace transform {
+
+Pass InjectPrefetch() {
+ auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+ auto* n = f.CopyOnWrite();
+ n->body = PrefetchInjector()(std::move(n->body));
+ return f;
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.InjectPrefetch", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.InjectPrefetch")
+.set_body_typed(InjectPrefetch);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
diff --git a/src/tir/pass/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc
similarity index 77%
rename from src/tir/pass/storage_flatten.cc
rename to src/tir/transforms/storage_flatten.cc
index f9533fa..99d437d 100644
--- a/src/tir/pass/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -19,22 +19,24 @@
/*!
* \file storage_flatten.cc
+ * \brief Flattens storage from multi-dimensional array to 1D buffer access
*/
-// Flattens storage from multi-dimensional array to 1D
-// buffer access as in Halide pipeline.
+// The pass definition originates from Halide pipeline.
+
+#include <tvm/runtime/registry.h>
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/te/operation.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/op.h>
-#include <tvm/tir/ir_pass.h>
+#include <tvm/tir/transform.h>
#include <tvm/tir/buffer.h>
#include <tvm/target/target_info.h>
#include <tvm/runtime/device_api.h>
#include <unordered_map>
-#include "ir_util.h"
-#include "arg_binder.h"
+#include "../pass/ir_util.h"
+#include "../pass/arg_binder.h"
#include "../../arith/compute_expr.h"
#include "../../arith/ir_visitor_with_analyzer.h"
#include "../../runtime/thread_storage_scope.h"
@@ -49,16 +51,17 @@ using intrinsic::tvm_address_of;
class StorageFlattener : public StmtExprMutator {
public:
- explicit StorageFlattener(Map<te::Tensor, Buffer> extern_buffer,
- int cache_line_size, bool create_bound_attributes,
- IRVisitorWithAnalyzer* bounded_analyzer)
- : bounded_analyzer_(bounded_analyzer),
+ explicit StorageFlattener(const Map<Var, Buffer>& extern_buffer_map,
+ int cache_line_size,
+ bool create_bound_attributes,
+ IRVisitorWithAnalyzer* bound_analyzer)
+ : bound_analyzer_(bound_analyzer),
create_bound_attributes_(create_bound_attributes) {
- for (auto kv : extern_buffer) {
+ for (auto kv : extern_buffer_map) {
BufferEntry e;
e.buffer = kv.second;
e.external = true;
- buf_map_[TensorKey{kv.first->op, kv.first->value_index}] = e;
+ buf_map_[kv.second] = e;
}
cache_line_size_ = cache_line_size;
}
@@ -82,17 +85,14 @@ class StorageFlattener : public StmtExprMutator {
storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
return this->VisitStmt(op->body);
} else if (op->attr_key == attr::double_buffer_scope &&
- op->node->IsInstance<te::OperationNode>()) {
- auto func = Downcast<te::Operation>(op->node);
+ op->node->IsInstance<tir::BufferNode>()) {
+ auto buffer = Downcast<tir::Buffer>(op->node);
Stmt body = this->VisitStmt(op->body);
- for (int i = 0; i < func->num_outputs(); ++i) {
- TensorKey key{func, i};
- auto it = buf_map_.find(key);
- CHECK(it != buf_map_.end())
- << "Cannot find allocated buffer for " << key.f;
- body = AttrStmtNode::make(
- it->second.buffer->data, op->attr_key, op->value, body);
- }
+ auto it = buf_map_.find(buffer);
+ CHECK(it != buf_map_.end())
+ << "Cannot find allocated buffer for " << buffer;
+ body = AttrStmtNode::make(
+ it->second.buffer->data, op->attr_key, op->value, std::move(body));
return body;
} else if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
@@ -104,11 +104,10 @@ class StorageFlattener : public StmtExprMutator {
} else if (op->attr_key == attr::buffer_bind_scope) {
return HandleBufferBindScope(op);
} else if (op->attr_key == attr::buffer_dim_align) {
- auto tensor = Downcast<te::Tensor>(op->node);
+ auto buffer = Downcast<tir::Buffer>(op->node);
const CallNode* tuple = op->value.as<CallNode>();
CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
- TensorKey key{tensor->op, tensor->value_index};
- auto& vinfo = dim_align_[key];
+ auto& vinfo = dim_align_[buffer];
int dim = tuple->args[0].as<IntImmNode>()->value;
if (static_cast<size_t>(dim) >= vinfo.size()) {
vinfo.resize(dim + 1);
@@ -122,18 +121,21 @@ class StorageFlattener : public StmtExprMutator {
return StmtExprMutator::VisitStmt_(op);
}
- Stmt VisitStmt_(const ProvideNode* op) final {
- if (create_bound_attributes_)
- shape_collector_.clear();
+ Stmt VisitStmt_(const BufferStoreNode* op) final {
+ if (create_bound_attributes_) shape_collector_.clear();
Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<ProvideNode>();
- TensorKey key{op->func, op->value_index};
+ op = stmt.as<BufferStoreNode>();
+
+ const auto& key = op->buffer;
+
auto it = buf_map_.find(key);
CHECK(it != buf_map_.end())
- << "Cannot find allocated buffer for " << key.f;
+ << "Cannot find allocated buffer for " << key;
+
const BufferEntry& e = it->second;
CHECK(!e.released)
<< "Read a buffer that is already out of scope";
+
if (is_opengl_) {
return EvaluateNode::make(CallNode::make(
DataType(),
@@ -141,7 +143,7 @@ class StorageFlattener : public StmtExprMutator {
{e.buffer->data, op->value},
CallNode::Intrinsic));
} else {
- Stmt body = e.buffer.vstore(e.RelIndex(op->args), op->value);
+ Stmt body = e.buffer.vstore(e.RelIndex(op->indices), op->value);
if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
shape_collector_.push_back(
std::make_pair(e.buffer->data, e.buffer->shape));
@@ -158,8 +160,9 @@ class StorageFlattener : public StmtExprMutator {
}
}
- Stmt VisitStmt_(const RealizeNode* op) final {
- TensorKey key{op->func, op->value_index};
+ Stmt VisitStmt_(const BufferRealizeNode* op) final {
+ const auto& key = op->buffer;
+
if (buf_map_.count(key)) {
CHECK(buf_map_.at(key).external);
return this->VisitStmt(op->body);
@@ -172,10 +175,9 @@ class StorageFlattener : public StmtExprMutator {
shape.push_back(r->extent);
}
// deduce current storage scope.
- auto it = storage_scope_.find(op->func.get());
+ auto it = storage_scope_.find(op->buffer.get());
CHECK(it != storage_scope_.end())
- << "Cannot find storage scope of " << op->func
- << " value_index=" << op->value_index;
+ << "Cannot find storage scope of " << op->buffer;
StorageScope skey;
const std::string& strkey = it->second;
if (strkey.length() == 0) {
@@ -188,13 +190,14 @@ class StorageFlattener : public StmtExprMutator {
}
// use small alignment for small arrays
+ auto dtype = op->buffer->dtype;
int32_t const_size = AllocateNode::constant_allocation_size(shape);
- int align = GetTempAllocaAlignment(op->dtype, const_size);
+ int align = GetTempAllocaAlignment(dtype, const_size);
if (skey.tag.length() != 0) {
MemoryInfo info = GetMemoryInfo(skey.to_string());
if (info.defined()) {
- align = (info->max_simd_bits + op->dtype.bits() - 1) / op->dtype.bits();
- CHECK_LE(const_size * op->dtype.bits(), info->max_num_bits)
+ align = (info->max_simd_bits + dtype.bits() - 1) / dtype.bits();
+ CHECK_LE(const_size * dtype.bits(), info->max_num_bits)
<< "Allocation exceed bound of memory tag " << skey.to_string();
}
}
@@ -210,7 +213,7 @@ class StorageFlattener : public StmtExprMutator {
 PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor);
 PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset);
 stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
- stride = tir::Simplify(stride);
+ stride = bound_analyzer_->Simplify(stride);
}
rstrides.push_back(stride);
stride = stride * shape[dim];
@@ -219,9 +222,9 @@ class StorageFlattener : public StmtExprMutator {
}
e.buffer = BufferNode::make(
- Var(key.GetName(), DataType::Handle()),
- op->dtype, shape, strides, PrimExpr(),
- key.GetName(), skey.to_string(),
+ Var(op->buffer->data->name_hint, DataType::Handle()),
+ op->buffer->dtype, shape, strides, PrimExpr(),
+ op->buffer->name, skey.to_string(),
align, 0, kDefault);
buf_map_[key] = e;
@@ -285,36 +288,36 @@ class StorageFlattener : public StmtExprMutator {
}
}
- PrimExpr VisitExpr_(const CallNode* op) final {
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final {
PrimExpr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<CallNode>();
- if (op != nullptr && op->call_type == CallNode::Halide) {
- TensorKey key{op->func, op->value_index};
- auto it = buf_map_.find(key);
- CHECK(it != buf_map_.end())
- << "Cannot find allocated buffer for " << key.f;
- const BufferEntry& e = it->second;
- CHECK(!e.released)
- << "Read a buffer that is already out of scope";
+ op = expr.as<BufferLoadNode>();
- if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
+ const auto& key = op->buffer;
+
+ auto it = buf_map_.find(key);
+ CHECK(it != buf_map_.end())
+ << "Cannot find allocated buffer for " << key;
+ const BufferEntry& e = it->second;
+ CHECK(!e.released)
+ << "Read a buffer that is already out of scope";
+
+ if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
shape_collector_.push_back(
std::make_pair(e.buffer->data, e.buffer->shape));
- }
- return e.buffer.vload(e.RelIndex(op->args), e.buffer->dtype);
- } else {
- return expr;
}
+ return e.buffer.vload(e.RelIndex(op->indices), e.buffer->dtype);
}
+
Stmt VisitStmt_(const PrefetchNode *op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<PrefetchNode>();
CHECK(op != nullptr);
- TensorKey key{op->func, op->value_index};
+
+ const auto& key = op->buffer;
auto it = buf_map_.find(key);
CHECK(it != buf_map_.end())
- << "Cannot find allocated buffer for " << key.f;
+ << "Cannot find allocated buffer for " << key;
const BufferEntry& e = it->second;
CHECK(!e.released)
@@ -340,7 +343,7 @@ class StorageFlattener : public StmtExprMutator {
for (int i = op->bounds.size() - 1; i > starts; --i) {
args.push_back(op->bounds[i]->min);
}
- auto &func_name = op->func->func_name();
+ auto &func_name = op->buffer->name;
vars.push_back(Var(
"prefetch." + func_name + "." + std::to_string(starts),
DataType::Int(32)));
args.push_back(op->bounds[starts]->min + stride * vars.back());
@@ -358,7 +361,7 @@ class StorageFlattener : public StmtExprMutator {
PrimExpr address = CallNode::make(
DataType::Handle(), tvm_address_of, {load},
CallNode::PureIntrinsic);
PrimExpr prefetch = CallNode::make(
- op->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic);
+ op->buffer->dtype, CallNode::prefetch, {address, 0, 3, 1}, CallNode::Intrinsic);
stmt = EvaluateNode::make(prefetch);
PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1;
stmt = ForNode::make(vars[i], 0, extent, ForType::Serial,
DeviceAPI::None, stmt);
@@ -367,6 +370,26 @@ class StorageFlattener : public StmtExprMutator {
return stmt;
}
+  PrimExpr VisitExpr_(const CallNode* op) final {  // Halide calls must be BufferLoads by now
+    CHECK(op->call_type != CallNode::Halide)
+        << "Cannot handle Halide calls, "
+        << "please run SchedulePostProcToPrimFunc first";
+    return StmtExprMutator::VisitExpr_(op);
+  }
+
+  Stmt VisitStmt_(const ProvideNode* op) final {  // Provide must be a BufferStore by now
+    LOG(FATAL) << "Cannot handle Provide, "
+               << "please run SchedulePostProcToPrimFunc first";
+    return Stmt();
+  }
+
+  Stmt VisitStmt_(const RealizeNode* op) final {  // Realize must be a BufferRealize by now
+    LOG(FATAL) << "Cannot handle Realize, "
+               << "please run SchedulePostProcToPrimFunc first";
+    return Stmt();
+  }
+
+
private:
// The specific tensor data layout is not determined before
// StorageFlatten pass. We use buffer_bind_scope
@@ -406,14 +429,16 @@ class StorageFlattener : public StmtExprMutator {
Array<ObjectRef> arr = Downcast<Array<ObjectRef> > (op->node);
CHECK_EQ(arr.size(), 2U);
const BufferNode* buffer = arr[0].as<BufferNode>();
- const te::TensorNode* tensor = arr[1].as<te::TensorNode>();
+ const BufferNode* target = arr[1].as<BufferNode>();
const CallNode* tuple = op->value.as<CallNode>();
- CHECK(buffer && tensor);
+ CHECK(buffer && target);
CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
- TensorKey key{tensor->op, tensor->value_index};
- CHECK(buf_map_.count(key))
- << "Cannot find buffer of " << tensor->op << " value=" << tensor->value_index;
- const BufferEntry& be = buf_map_.at(key);
+ auto key = GetRef<Buffer>(target);
+
+ auto it = buf_map_.find(key);
+ CHECK(it != buf_map_.end())
+ << "Cannot find buffer of " << key;
+ const BufferEntry& be = it->second;
CHECK(!be.released);
CHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2);
Array<PrimExpr> begins, extents;
@@ -426,7 +451,7 @@ class StorageFlattener : public StmtExprMutator {
} else {
for (size_t i = 0; i < tuple->args.size(); i += 2) {
begins.push_back(tuple->args[i]);
- auto new_extent = bounded_analyzer_->Simplify(tuple->args[i+1]);
+ auto new_extent = bound_analyzer_->Simplify(tuple->args[i+1]);
extents.push_back(new_extent);
}
}
@@ -451,6 +476,7 @@ class StorageFlattener : public StmtExprMutator {
}
return body;
}
+
// The buffer entry in the flatten map
struct DimAlignInfo {
int align_factor{0};
@@ -509,9 +535,10 @@ class StorageFlattener : public StmtExprMutator {
// Variable remap
std::unordered_map<const VarNode*, PrimExpr> var_remap_;
// Buffer map
- std::unordered_map<TensorKey, BufferEntry> buf_map_;
+ std::unordered_map<Buffer, BufferEntry, ObjectHash, ObjectEqual> buf_map_;
// Dimension alignment
- std::unordered_map<TensorKey, std::vector<DimAlignInfo> > dim_align_;
+ std::unordered_map<Buffer, std::vector<DimAlignInfo>,
+ ObjectHash, ObjectEqual> dim_align_;
// Storage scope
std::unordered_map<const Object*, std::string> storage_scope_;
// The current thread scope.
@@ -520,7 +547,7 @@ class StorageFlattener : public StmtExprMutator {
std::vector<std::pair<Var, Array<PrimExpr>>> shape_collector_;
// bounds populator. We really need the analyzer from it.
// However
- IRVisitorWithAnalyzer* bounded_analyzer_;
+ IRVisitorWithAnalyzer* bound_analyzer_;
// The size of cacheline
int cache_line_size_;
// The current stage is an OpenGL shader.
@@ -529,15 +556,37 @@ class StorageFlattener : public StmtExprMutator {
bool create_bound_attributes_{false};
};
-Stmt StorageFlatten(Stmt stmt, Map<te::Tensor, Buffer> extern_buffer,
- int cache_line_size, bool create_bound_attributes) {
- IRVisitorWithAnalyzer bounded_analyzer;
- bounded_analyzer(stmt);
- stmt =
- StorageFlattener(extern_buffer, cache_line_size,
- create_bound_attributes, &bounded_analyzer)(std::move(stmt));
- return stmt;
+PrimFunc StorageFlatten(PrimFunc func,
+ int cache_line_size,
+ bool create_bound_attributes) {
+ auto fptr = func.CopyOnWrite();
+
+ IRVisitorWithAnalyzer bound_analyzer;
+ bound_analyzer(fptr->body);
+ fptr->body = StorageFlattener(fptr->buffer_map,
+ cache_line_size,
+ create_bound_attributes,
+ &bound_analyzer)(std::move(fptr->body));
+ return func;
}
+
+namespace transform {
+
+// TODO(tvm-team): consolidate configs to the PassContext
+Pass StorageFlatten(int cache_line_size,
+ bool create_bound_attributes) {
+ auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+ return StorageFlatten(
+ std::move(f), cache_line_size, create_bound_attributes);
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.StorageFlatten", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.StorageFlatten")
+.set_body_typed(StorageFlatten);
+
+} // namespace transform
+
} // namespace tir
} // namespace tvm
diff --git a/tests/python/unittest/test_arith_domain_touched.py b/tests/python/unittest/test_arith_domain_touched.py
index 0d769aa..1033721 100644
--- a/tests/python/unittest/test_arith_domain_touched.py
+++ b/tests/python/unittest/test_arith_domain_touched.py
@@ -22,21 +22,25 @@ def test_domain_touched():
j = te.var('j')
n = tvm.runtime.convert(100)
m = te.var('m')
- a = te.placeholder((n, m), name = 'a')
- b = te.placeholder((n, m), name = 'b')
+
+ a = tvm.tir.decl_buffer((n, m), name='a')
+ b = tvm.tir.decl_buffer((n, m), name='b')
+
+
ir = tvm.tir.For(
i, 0, n, 0, 0,
tvm.tir.For(j, 0, m, 0, 0,
- tvm.tir.Provide(
- a.op,
- 0,
- tvm.tir.Call(b.dtype, 'b', [i - 1, j + 1], 3, b.op, 0) +
- tvm.tir.Call(a.dtype, 'a', [i - 1, j - 1], 3, a.op, 0),
+ tvm.tir.BufferStore(
+ a,
+ tvm.tir.BufferLoad(b, [i - 1, j + 1]) +
+ tvm.tir.BufferLoad(a, [i - 1, j - 1]),
[i, j]
)
)
)
+
a_domain_r = tvm.arith._ffi_api.DomainTouched(ir, a, True, False)
+
assert a_domain_r[0].min.value == -1
assert a_domain_r[0].extent.value == 100
assert a_domain_r[1].min.value == -1
diff --git a/tests/python/unittest/test_te_build_lower.py b/tests/python/unittest/test_te_build_lower.py
index 442c4fe..b1d7546 100644
--- a/tests/python/unittest/test_te_build_lower.py
+++ b/tests/python/unittest/test_te_build_lower.py
@@ -48,9 +48,9 @@ def test_split_uneven_unique_likely():
x, y = c.op.axis
sch = te.create_schedule(c.op)
xo, xi = sch[c].split(x, 5)
- stmt = tvm.lower(sch, [a, b, c], simple_mode=True)
+ stmt = tvm.lower(sch, [a, b, c])["main"].body
assert isinstance(stmt.body.body.body, tvm.tir.stmt.IfThenElse)
- assert str(stmt.body.body.body).count("likely") == 1
+
if __name__ == "__main__":
test_lower_rfactor()
diff --git a/tests/python/unittest/test_te_hybrid_script.py b/tests/python/unittest/test_te_hybrid_script.py
index b525d01..5b4a1c9 100644
--- a/tests/python/unittest/test_te_hybrid_script.py
+++ b/tests/python/unittest/test_te_hybrid_script.py
@@ -365,7 +365,7 @@ def test_bind():
a = te.placeholder((8, 4), 'float32')
c = foo(a)
s = te.create_schedule(c.op)
- ir = tvm.lower(s, [a, c], simple_mode=True)
+ ir = tvm.lower(s, [a, c])
func, ins, outs = run_and_check(foo, [a], target='cuda')
run_and_check(func, ins, outs=outs, target='cuda')
@@ -517,7 +517,7 @@ def test_upstream():
c = te.compute((20, ), lambda x: a[x] + b[x])
d = upstream(c)
sch = te.create_schedule([c.op, d.op])
- ir = tvm.lower(sch, [a, b, d], simple_mode=True)
+ ir = tvm.lower(sch, [a, b, d])
func = tvm.build(sch, [a, b, d])
assert(func)
@@ -730,7 +730,7 @@ def test_schedule():
joo, joi = sch[c].split(jo, 4)
sch[c].vectorize(ji)
sch[c].reorder(ii, io, joo, joi, ji)
- ir = tvm.lower(sch, [a, b, c], simple_mode=True)
+ ir = tvm.lower(sch, [a, b, c])["main"].body
assert isinstance(ir, tvm.tir.AttrStmt)
ir = ir.body
assert isinstance(ir, tvm.tir.For)
@@ -751,7 +751,7 @@ def test_schedule():
# Test fuse
sch = te.create_schedule(c.op)
sch[c].fuse(c.op.axis[0], c.op.axis[1])
- ir = tvm.lower(sch, [a, b, c], simple_mode=True)
+ ir = tvm.lower(sch, [a, b, c])["main"].body
assert isinstance(ir, tvm.tir.AttrStmt)
ir = ir.body
assert isinstance(ir, tvm.tir.For)
diff --git a/tests/python/unittest/test_te_schedule.py b/tests/python/unittest/test_te_schedule.py
index c9b422f..9e4d45e 100644
--- a/tests/python/unittest/test_te_schedule.py
+++ b/tests/python/unittest/test_te_schedule.py
@@ -283,7 +283,7 @@ def test_tensor_intrin_scalar_params():
# Pass scalar inputs to the TensorIntrin, interleaved with tensor inputs
C = te.compute((10,10), lambda i, j: intrin(i*i, A[i, j], i+j), name="C")
s = te.create_schedule(C.op)
- stmt = tvm.lower(s, [A, C], simple_mode=True)
+ stmt = tvm.lower(s, [A, C])["main"].body
assert isinstance(stmt.body.body, tvm.tir.Evaluate)
assert len(stmt.body.body.value.args) == 5
assert str(stmt.body.body.value.args[3]) == "(i*i)"
diff --git a/tests/python/unittest/test_te_schedule_ops.py b/tests/python/unittest/test_te_schedule_ops.py
index 3e521ab..2a0c6c1 100644
--- a/tests/python/unittest/test_te_schedule_ops.py
+++ b/tests/python/unittest/test_te_schedule_ops.py
@@ -28,6 +28,9 @@ def test_schedule0():
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
+ func = tvm.te.schedule.SchedulePostProcToPrimFunc(
+ [A, A1], stmt, None)
+ assert isinstance(func, tvm.tir.PrimFunc)
def test_schedule1():
@@ -43,6 +46,10 @@ def test_schedule1():
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
+ func = tvm.te.schedule.SchedulePostProcToPrimFunc(
+ [A, A1], stmt, None)
+ assert isinstance(func, tvm.tir.PrimFunc)
+
def test_schedule2():
m = te.var('m')
@@ -57,6 +64,9 @@ def test_schedule2():
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
+ func = tvm.te.schedule.SchedulePostProcToPrimFunc(
+ [A, A2], stmt, None)
+ assert isinstance(func, tvm.tir.PrimFunc)
def test_schedule_scan():
@@ -77,6 +87,7 @@ def test_schedule_scan():
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
+
def test_inline_multi_reduce():
def argmax_comp(x, y):
idx = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
@@ -510,19 +521,19 @@ def test_local_stage_predicate():
return ret
# local vs. threadIdx
s = schedule(tx, "local")
- lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
+ lowered_body = tvm.lower(s, [A, C])["main"].body
assert (not any(
collect_visit(lowered_body,
lambda x: isinstance(x, tvm.tir.IfThenElse))))
# local vs. vthread
s = schedule(vx, "local")
- lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
+ lowered_body = tvm.lower(s, [A, C])["main"].body
assert (not any(
collect_visit(lowered_body,
lambda x: isinstance(x, tvm.tir.IfThenElse))))
# shared vs. blockIdx
s = schedule(by, "shared")
- lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
+ lowered_body = tvm.lower(s, [A, C])["main"].body
assert (not any(
collect_visit(lowered_body,
lambda x: isinstance(x, tvm.tir.IfThenElse))))
@@ -548,7 +559,7 @@ def test_local_stage_predicate2():
s[AA].compute_at(s[C], ooc)
oaa, iaa = s[AA].split(s[AA].op.axis[0], factor=32)
s[AA].bind(iaa, thread_x)
- lowered_body = tvm.lower(s, [A, C], simple_mode=True).body
+ lowered_body = tvm.lower(s, [A, C])["main"].body
def collect_visit(stmt, f):
ret = []
diff --git a/tests/python/unittest/test_te_tensor.py b/tests/python/unittest/test_te_tensor.py
index 55edd1c..4528086 100644
--- a/tests/python/unittest/test_te_tensor.py
+++ b/tests/python/unittest/test_te_tensor.py
@@ -128,7 +128,7 @@ def test_tensor_compute1():
lambda i: vadd(A[i, 0:factor], B[i, 0:factor]))
s = te.create_schedule(C.op)
- stmt = tvm.lower(s, [A, B, C], simple_mode=True)
+ stmt = tvm.lower(s, [A, B, C])["main"].body
assert isinstance(stmt.body, tvm.tir.Evaluate)
def test_tensor_compute2():
@@ -171,7 +171,7 @@ def test_tensor_compute2():
lambda i, j: vgemm(A[i, k, 0:factor1, 0:factor], B[j, k, 0:factor2,
0:factor], reduce_axis=k))
s = te.create_schedule(C.op)
- stmt = tvm.lower(s, [A, B, C], simple_mode=True)
+ stmt = tvm.lower(s, [A, B, C])["main"].body
assert isinstance(stmt.body.body[0], tvm.tir.Evaluate)
assert isinstance(stmt.body.body[1].body, tvm.tir.Evaluate)
diff --git a/tests/python/unittest/test_tir_analysis_verify_memory.py b/tests/python/unittest/test_tir_analysis_verify_memory.py
index b362508..b0de91b 100644
--- a/tests/python/unittest/test_tir_analysis_verify_memory.py
+++ b/tests/python/unittest/test_tir_analysis_verify_memory.py
@@ -24,29 +24,6 @@ gpu_devices = ["cuda", "opencl", "metal", "vulkan"]
other_devices = ["llvm", "ext_dev"]
-def lower(sch, args):
- binds = {}
- arg_list = []
- for x in args:
- if isinstance(x, te.tensor.Tensor):
- buf = tvm.tir.decl_buffer(x.shape, dtype=x.dtype, name=x.name)
- assert x not in binds
- binds[x] = buf
- arg_list.append(buf)
- else:
- raise ValueError("args must be Tensor, Buffer or Var")
- sch = sch.normalize()
- bounds = tvm.te.schedule.InferBound(sch)
- stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
- stmt = tvm.tir.ir_pass.LoopPartition(stmt, False)
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64)
-
- f = tvm.tir.PrimFunc(arg_list, stmt).with_attr(
- "global_symbol", tvm.runtime.String("test"))
- mod = tvm.IRModule({"test": f})
- return mod
-
-
# All computations are bound.
# So VerifyMemory pass is expected to succeed.
#
@@ -61,7 +38,7 @@ def test_verify_memory_all_bind():
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))
- mod = lower(s, [A, B])
+ mod = tvm.lower(s, [A, B])
for dev_type in gpu_devices + other_devices:
binded_mod = tvm.tir.transform.Apply(
@@ -81,7 +58,7 @@ def test_verify_memory_not_bind():
# B is not bound to threads.
s = te.create_schedule(B.op)
- mod = lower(s, [A, B])
+ mod = tvm.lower(s, [A, B])
for dev_type in gpu_devices:
binded_mod = tvm.tir.transform.Apply(
@@ -111,7 +88,7 @@ def test_verify_memory_partially_bind():
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
- mod = lower(s, [A, B, C, D])
+ mod = tvm. lower(s, [A, B, C, D])
for dev_type in gpu_devices:
binded_mod = tvm.tir.transform.Apply(
diff --git a/tests/python/unittest/test_tir_constructor.py b/tests/python/unittest/test_tir_constructor.py
index 7a03e48..4af93fd 100644
--- a/tests/python/unittest/test_tir_constructor.py
+++ b/tests/python/unittest/test_tir_constructor.py
@@ -194,9 +194,9 @@ def test_stmt_constructor():
assert x.then_case.value.value == 11
assert x.else_case == nop
- x = tvm.tir.Prefetch(None, 1, "float32", [])
+ b = tvm.tir.decl_buffer((1, 2))
+ x = tvm.tir.Prefetch(b, [])
assert isinstance(x, tvm.tir.Prefetch)
- assert x.value_index == 1
if __name__ == "__main__":
diff --git a/tests/python/unittest/test_tir_ir_builder.py b/tests/python/unittest/test_tir_ir_builder.py
index 9106be8..090acda 100644
--- a/tests/python/unittest/test_tir_ir_builder.py
+++ b/tests/python/unittest/test_tir_ir_builder.py
@@ -28,7 +28,6 @@ def test_for():
A[j] = A[j] + 2
body = ib.get()
- print(body)
assert isinstance(body, tvm.tir.AttrStmt)
body = body.body
assert isinstance(body, tvm.tir.Allocate)
@@ -59,14 +58,13 @@ def test_if():
assert body.else_case.index.value == 0
def test_prefetch():
- A = te.placeholder((10, 20), name="A")
+ A = tvm.tir.decl_buffer((10, 20), name="A")
ib = tvm.tir.ir_builder.create()
n = te.size_var("n")
with ib.for_range(0, n, name="i") as i:
ib.emit(
- tvm.tir.Prefetch(
- A.op, A.value_index, A.dtype,
+ tvm.tir.Prefetch(A,
[tvm.ir.Range.make_by_min_extent(i+1, 2),
tvm.ir.Range.make_by_min_extent(0, 20)]))
body = ib.get()
diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py
index 9f4ccad..468ab1d 100644
--- a/tests/python/unittest/test_tir_nodes.py
+++ b/tests/python/unittest/test_tir_nodes.py
@@ -301,6 +301,10 @@ def test_buffer_load_store():
s = tvm.tir.BufferStore(b, 0.1, [0])
assert isinstance(s, tvm.tir.BufferStore)
+ s = tvm.tir.BufferRealize(b, [tvm.ir.Range(0, 1)],
+ True, tvm.tir.Evaluate(0))
+ assert isinstance(s, tvm.tir.BufferRealize)
+
def test_intimm_cond():
x = tvm.runtime.convert(1)
diff --git a/tests/python/unittest/test_tir_transform_inject_copy_intrin.py b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py
index 7ec2e48..9d16413 100644
--- a/tests/python/unittest/test_tir_transform_inject_copy_intrin.py
+++ b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py
@@ -26,9 +26,10 @@ def test_copy2d():
s[B].pragma(B.op.axis[0], "memcpy")
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
- Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
+ func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
+ mod = tvm.IRModule.from_expr(func)
+ mod = tvm.tir.transform.StorageFlatten(64)(mod)
+
def cb(src, dst, pad_before, pad_after, pad_value):
assert dst.strides[0] == l
assert dst.strides[1].value == 1
@@ -36,7 +37,6 @@ def test_copy2d():
assert tuple(src.shape) == (m, l)
return tvm.tir.Evaluate(0)
- mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
@@ -51,9 +51,11 @@ def test_copy_pad():
s[B].pragma(B.op.axis[0], "memcpy")
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
- Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
+
+ func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
+ mod = tvm.IRModule.from_expr(func)
+ mod = tvm.tir.transform.StorageFlatten(64)(mod)
+
def cb(src, dst, pad_before, pad_after, pad_value):
assert tvm.tir.ir_pass.Simplify(src.elem_offset).value == 0
assert pad_before[0].value == 1
@@ -63,7 +65,6 @@ def test_copy_pad():
assert pad_value.value == 1.0
return tvm.tir.Evaluate(0)
- mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
@@ -75,9 +76,11 @@ def test_single_point_test():
s[B].pragma(B.op.axis[0], "memcpy")
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
- Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
+
+ func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
+ mod = tvm.IRModule.from_expr(func)
+ mod = tvm.tir.transform.StorageFlatten(64)(mod)
+
def cb(src, dst, pad_before, pad_after, pad_value):
assert tvm.tir.ir_pass.Simplify(src.elem_offset).value == 0
assert tvm.tir.ir_pass.Simplify(dst.elem_offset).value == 0
@@ -85,7 +88,6 @@ def test_single_point_test():
assert tvm.tir.ir_pass.Simplify(dst.strides[0]).value == 1
return tvm.tir.Evaluate(0)
- mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
@@ -105,11 +107,12 @@ def test_copy_pad_split():
s[Apad].pragma(s[Apad].op.axis[0], "memcpy")
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
- Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
- stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt)
+
+ func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
+ mod = tvm.IRModule.from_expr(func)
+ mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
+ mod = tvm.tir.transform.Simplify()(mod._move())
+
def cb(src, dst, pad_before, pad_after, pad_value):
assert(dst.elem_offset.value == 0)
assert_expr_equal(src.elem_offset, tvm.te.max(xo * 4, 1) - 1)
@@ -121,12 +124,10 @@ def test_copy_pad_split():
assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after)
return tvm.tir.Evaluate(0)
- mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body
-
if __name__ == "__main__":
test_copy2d()
test_copy_pad()
diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py
index fb76597..760cf47 100644
--- a/tests/python/unittest/test_tir_transform_make_packed_api.py
+++ b/tests/python/unittest/test_tir_transform_make_packed_api.py
@@ -28,18 +28,16 @@ def test_makeapi():
bounds = tvm.te.schedule.InferBound(s)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-
- Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
- Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
- Cb = tvm.tir.decl_buffer(C.shape, C.dtype, name='C')
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}, 64)
+ func = tvm.te.schedule.SchedulePostProcToPrimFunc([n, A, B, C], stmt, None)
+ mod = tvm.IRModule.from_expr(func)
+ mod = tvm.tir.transform.StorageFlatten(64)(mod)
+ mod = tvm.tir.transform.Apply(
+ lambda f: f.with_attr({
+ "target": tvm.target.create("llvm"),
+ "global_symbol": "main",
+ }))(mod)
num_unpacked_args = 2
- mod = tvm.IRModule.from_expr(
- tvm.tir.PrimFunc([n, Ab, Bb, Cb], stmt).with_attr({
- "global_symbol": "main",
- "target": tvm.target.create("llvm")
- }))
f = tvm.tir.transform.MakePackedAPI(num_unpacked_args)(mod)["main"]
assert(len(f.params) == 7)
diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py
index dbf2267..6179bbb 100644
--- a/tests/python/unittest/test_tir_transform_narrow_datatype.py
+++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py
@@ -40,8 +40,11 @@ def lower_sch(sch, args, target_bits):
raise ValueError("args must be Tensor, Buffer or Var")
bounds = te.schedule.InferBound(sch)
stmt = te.schedule.ScheduleOps(sch, bounds)
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False)
- return lower_stmt(arg_list, stmt, target_bits)
+
+ func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None)
+ mod = tvm.IRModule.from_expr(func)
+ mod = tvm.tir.transform.StorageFlatten(64)(mod)
+ return tvm.tir.transform.NarrowDataType(target_bits)(mod)["main"].body
def test_basic():
diff --git a/tests/python/unittest/test_tir_pass_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py
similarity index 82%
rename from tests/python/unittest/test_tir_pass_storage_flatten.py
rename to tests/python/unittest/test_tir_transform_storage_flatten.py
index 1eaadb3..e2bfeb0 100644
--- a/tests/python/unittest/test_tir_pass_storage_flatten.py
+++ b/tests/python/unittest/test_tir_transform_storage_flatten.py
@@ -30,11 +30,14 @@ def test_flatten2():
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
-
Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2')
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
- stmt = tvm.tir.ir_pass.Simplify(stmt)
+
+ func = tvm.te.schedule.SchedulePostProcToPrimFunc(
+ [Ab, A2b], stmt, {A: Ab, A2: A2b})
+ mod = tvm.IRModule.from_expr(func)
+ mod = tvm.tir.transform.StorageFlatten(64)(mod)
+
def test_flatten_prefetch():
A = te.placeholder((25, 100, 4), name = 'A')
@@ -42,8 +45,14 @@ def test_flatten_prefetch():
i = te.size_var('i')
j = te.size_var('j')
region = [tvm.ir.Range.make_by_min_extent(i[0], i[1]) for i in [(i, 2),
(j, 8), (0, 4)]]
- stmt = tvm.tir.Prefetch(A.op, 0, A.dtype, region)
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: _A}, 64)
+ stmt = tvm.tir.Prefetch(_A, region)
+
+ func = tvm.te.schedule.SchedulePostProcToPrimFunc(
+ [_A], stmt, {A: _A})
+
+ mod = tvm.IRModule.from_expr(func)
+ mod = tvm.tir.transform.StorageFlatten(64)(mod)
+ stmt = mod["main"].body
stmt = tvm.tir.ir_pass.Simplify(stmt)
assert stmt.extent.value == 2
assert isinstance(stmt.body, tvm.tir.For)
@@ -62,12 +71,15 @@ def test_flatten_storage_align():
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
- A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2')
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
+
+ func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
+ mod = tvm.IRModule.from_expr(func)
+ mod = tvm.tir.transform.StorageFlatten(64)(mod)
+ stmt = mod["main"].body
stmt = tvm.tir.ir_pass.Simplify(stmt)
assert(stmt.body.extents[0].value == 17 * 8)
+
def test_flatten_double_buffer():
dtype = 'int64'
n = 100
@@ -87,7 +99,13 @@ def test_flatten_double_buffer():
C[j] = B[j] + 1
stmt = ib.get()
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {}, 64)
+
+ mod = tvm.IRModule.from_expr(
+ tvm.tir.PrimFunc([A, C], stmt))
+
+ mod = tvm.tir.transform.StorageFlatten(64)(mod)
+ stmt = mod["main"].body
+
stmt = tvm.tir.ir_pass.InjectDoubleBuffer(stmt, 2)
stmt = tvm.tir.ir_pass.Simplify(stmt)
assert isinstance(stmt.body.body, tvm.tir.Allocate)
@@ -105,7 +123,7 @@ def test_flatten_double_buffer():
assert count[0] == 4
if __name__ == "__main__":
- test_flatten_storage_align()
test_flatten2()
- test_flatten_prefetch()
+ test_flatten_storage_align()
test_flatten_double_buffer()
+ test_flatten_prefetch()
diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py
index e4e1b31..85f856d 100644
--- a/tests/python/unittest/test_tir_transform_storage_rewrite.py
+++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py
@@ -30,11 +30,11 @@ def test_storage_share():
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
- Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
- mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+ func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
+ mod = tvm.IRModule.from_expr(func)
+ mod = tvm.tir.transform.StorageFlatten(64)(mod)
+
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body
@@ -166,11 +166,11 @@ def test_inplace_rule():
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
- Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
- mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+ func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
+ mod = tvm.IRModule.from_expr(func)
+ mod = tvm.tir.transform.StorageFlatten(64)(mod)
+
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body
@@ -201,11 +201,10 @@ def test_storage_combine():
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
- Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
+ func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B], stmt, None)
+ mod = tvm.IRModule.from_expr(func)
+ mod = tvm.tir.transform.StorageFlatten(64)(mod)
- mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body
@@ -238,11 +237,9 @@ def test_storage_share_gpu():
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- Ab = tvm.tir.decl_buffer(A[0].shape, A[0].dtype, name='A')
- Bb = tvm.tir.decl_buffer(A[0].shape, A[0].dtype, name='B')
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A[0]: Ab, A[-1]: Bb}, 64)
-
- mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb], stmt))
+ func = tvm.te.schedule.SchedulePostProcToPrimFunc([A[0], A[-1]], stmt,
None)
+ mod = tvm.IRModule.from_expr(func)
+ mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body
@@ -306,13 +303,11 @@ def test_inplace_rule2(scope_tb = "local_TB2", max_bits =
1024 * 1024 * 1024):
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
- Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
- Cc = tvm.tir.decl_buffer(C.shape, B.dtype, name='C')
- Dd = tvm.tir.decl_buffer(D.shape, B.dtype, name='D')
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, C: Cc, D:Dd},
64)
- mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab, Bb, Cc, Dd], stmt))
+ func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, B, C, D], stmt, None)
+ mod = tvm.IRModule.from_expr(func)
+ mod = tvm.tir.transform.StorageFlatten(64)(mod)
+
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body
@@ -398,17 +393,11 @@ def test_inplace_rule3():
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- B0a = tvm.tir.decl_buffer(B0.shape, B0.dtype, name='B0')
- B1a = tvm.tir.decl_buffer(B1.shape, B1.dtype, name='B1')
- B2a = tvm.tir.decl_buffer(B2.shape, B2.dtype, name='B2')
- B3a = tvm.tir.decl_buffer(B3.shape, B3.dtype, name='B3')
- B4a = tvm.tir.decl_buffer(B4.shape, B4.dtype, name='B4')
- B5a = tvm.tir.decl_buffer(B5.shape, B5.dtype, name='B5')
-
- Bb = tvm.tir.decl_buffer(B.shape, B.dtype, name='B')
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {B0: B0a, B1: B1a, B2: B2a,
B3: B3a, B4: B4a, B5: B5a, B: Bb}, 64)
+ func = tvm.te.schedule.SchedulePostProcToPrimFunc(
+ [B0, B1, B2, B3, B4, B5, B], stmt, None)
+ mod = tvm.IRModule.from_expr(func)
+ mod = tvm.tir.transform.StorageFlatten(64)(mod)
- mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([B0a, B1a, B2a, B3a, B4a,
B5a, Bb], stmt))
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
stmt = mod["main"].body
@@ -547,7 +536,7 @@ def test_large_input():
c = te.compute(shape, lambda i, j: compute(a, b)[i, j])
c = te.compute(shape, lambda i, j: 1 + c[i, j])
s = te.create_schedule(c.op)
- stmt = tvm.lower(s, [a, b, c], simple_mode=True)
+ stmt = tvm.lower(s, [a, b, c])["main"].body
def verify(n):
if isinstance(n, tvm.tir.Allocate):
assert n.extents[0].value == 268435456
diff --git a/tests/python/unittest/test_tir_transform_thread_sync.py b/tests/python/unittest/test_tir_transform_thread_sync.py
index 9257f6c..783b669 100644
--- a/tests/python/unittest/test_tir_transform_thread_sync.py
+++ b/tests/python/unittest/test_tir_transform_thread_sync.py
@@ -34,15 +34,15 @@ def test_thread_storage_sync():
bounds = tvm.te.schedule.InferBound(s)
assert isinstance(bounds, tvm.container.Map)
stmt = tvm.te.schedule.ScheduleOps(s, bounds)
- Ab = tvm.tir.decl_buffer(A.shape, A.dtype, name='A')
- A2b = tvm.tir.decl_buffer(A2.shape, A2.dtype, name='A2')
- stmt = tvm.tir.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b}, 64)
+
+ func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None)
+ mod = tvm.IRModule.from_expr(func)
+ mod = tvm.tir.transform.StorageFlatten(64)(mod._move())
cuda_target = tvm.target.create("cuda")
- mod = tvm.IRModule.from_expr(
- tvm.tir.PrimFunc([Ab, A2b], stmt).with_attr({
- "global_symbol": "test", "target": cuda_target}))
+ mod = tvm.tir.transform.Apply(lambda f: f.with_attr({
+ "global_symbol": "test", "target": cuda_target}))(mod._move())
fdevice = tvm.tir.transform.SplitHostDevice()(mod)["test_kernel0"]
mod = tvm.IRModule.from_expr(fdevice)
diff --git a/tutorials/dev/low_level_custom_pass.py b/tutorials/dev/low_level_custom_pass.py
index 25ca279..d35913b 100644
--- a/tutorials/dev/low_level_custom_pass.py
+++ b/tutorials/dev/low_level_custom_pass.py
@@ -40,8 +40,6 @@ Before reading this tutorial, we assume readers have already
known these topics
take a look at ``python/tvm/build_module.py`` to get some basics.
"""
-
-from __future__ import absolute_import, print_function
import tvm
from tvm import te
import numpy as np
@@ -57,7 +55,7 @@ b = te.placeholder((n, ), name="b")
c = te.compute((n, ), lambda i: a[i] + b[i], name='c')
sch = te.create_schedule(c.op)
-ir = tvm.lower(sch, [a, b, c], simple_mode=True)
+ir = tvm.lower(sch, [a, b, c])
print(ir)
######################################################################
@@ -137,12 +135,8 @@ def vectorize(stmt):
# Glue to Lowering
# ----------------
# So far, we are done with writing this IR transformation pass. What we need
to do next is to glue
-# this pass to TVM's lower pass. We can first call this function directly as a
sanity check.
+# this pass to TVM's lower pass.
#
-
-print(vectorize(ir))
-
-#####################################################################
# In TVM, there is a property called ``BuildConfig``. You can use this
property to customize your
# own lowering options. In this case, we inject the pass written above into
the TVM standard lowering
# pass by feeding **a list of tuple** as argument to ``add_lower_pass``.
"Tuple" indicates different
@@ -160,7 +154,7 @@ print(vectorize(ir))
#
with tvm.target.build_config(add_lower_pass=[(1, vectorize)]) as cfg:
- print(tvm.lower(sch, [a, b, c], simple_mode=True))
+ print(tvm.lower(sch, [a, b, c]))
#####################################################################
# Quick View