This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7340c02 [TIR][REFACTOR] ForNode introduce thread binding and remove
legacy field (#7306)
7340c02 is described below
commit 7340c02d0efe0f5eb5692fb9f4cc7573c5d056cb
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Jan 19 02:00:11 2021 -0500
[TIR][REFACTOR] ForNode introduce thread binding and remove legacy field
(#7306)
[TIR][REFACTOR] ForNode update
- Remove deprecated device_api.
- Add a ThreadBinding loop kind (for_type).
- Add an annotations field for additional loop hints.
More style-consistency refactoring to make ForNode
consistent with the rest of the codebase:
- ForType => ForKind
- Add the constant prefix k to enum constants, per the Google C++ style guide
- Introduce ForKind on the Python side.
---
include/tvm/tir/stmt.h | 85 ++++++++++++++--------
python/tvm/script/scope_handler.py | 8 +-
python/tvm/te/hybrid/calls.py | 20 ++---
python/tvm/te/hybrid/parser.py | 14 ++--
python/tvm/tir/__init__.py | 2 +-
python/tvm/tir/ir_builder.py | 24 +++---
python/tvm/tir/stmt.py | 58 ++++++++++++---
python/tvm/topi/cuda/nms.py | 8 +-
python/tvm/topi/cuda/rcnn/proposal.py | 6 +-
python/tvm/topi/cuda/sparse.py | 22 +++---
python/tvm/topi/nn/sparse.py | 12 +--
python/tvm/topi/sparse/csrmm.py | 4 +-
python/tvm/topi/sparse/csrmv.py | 2 +-
python/tvm/topi/sparse/dense.py | 8 +-
python/tvm/topi/vision/rcnn/proposal.py | 6 +-
python/tvm/topi/x86/scatter.py | 2 +-
src/auto_scheduler/feature.cc | 14 ++--
src/autotvm/feature_visitor.cc | 14 ++--
src/printer/tir_text_printer.cc | 19 +++--
src/printer/tvmscript_printer.cc | 21 +++---
src/target/llvm/codegen_cpu.cc | 11 +--
src/target/llvm/codegen_llvm.cc | 4 +-
src/target/source/codegen_cuda.cc | 2 +-
src/target/spirv/codegen_spirv.cc | 2 +-
src/te/operation/hybrid_op.cc | 27 +++----
src/te/operation/op_utils.cc | 36 ++++-----
src/te/operation/op_utils.h | 10 +--
.../schedule_postproc_rewrite_for_tensor_core.cc | 3 +-
src/tir/ir/stmt.cc | 36 +++++----
src/tir/transforms/combine_context_call.cc | 2 +-
src/tir/transforms/inject_double_buffer.cc | 3 +-
src/tir/transforms/inject_prefetch.cc | 4 +-
src/tir/transforms/inject_virtual_thread.cc | 7 +-
src/tir/transforms/ir_utils.cc | 3 +-
src/tir/transforms/loop_partition.cc | 4 +-
src/tir/transforms/make_packed_api.cc | 4 +-
src/tir/transforms/narrow_datatype.cc | 4 +-
src/tir/transforms/storage_flatten.cc | 4 +-
src/tir/transforms/storage_rewrite.cc | 8 +-
src/tir/transforms/unroll_loop.cc | 10 +--
src/tir/transforms/vectorize_loop.cc | 13 ++--
tests/python/unittest/test_arith_domain_touched.py | 6 +-
.../test_runtime_module_based_interface.py | 3 +-
tests/python/unittest/test_runtime_module_load.py | 6 +-
tests/python/unittest/test_target_codegen_cuda.py | 2 +-
tests/python/unittest/test_target_codegen_llvm.py | 2 +-
.../unittest/test_target_codegen_static_init.py | 2 +-
.../unittest/test_target_codegen_vm_basic.py | 2 +-
tests/python/unittest/test_te_hybrid_script.py | 6 +-
tests/python/unittest/test_tir_constructor.py | 2 +-
tests/python/unittest/test_tir_nodes.py | 2 +-
.../unittest/test_tir_transform_remove_no_op.py | 11 +--
.../unittest/test_tir_transform_storage_rewrite.py | 4 +-
.../unittest/test_tir_transform_unroll_loop.py | 8 +-
.../unittest/test_tir_transform_vectorize.py | 16 ++--
tutorials/dev/low_level_custom_pass.py | 4 +-
vta/python/vta/transform.py | 12 ++-
57 files changed, 359 insertions(+), 275 deletions(-)
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 2b7f1e6..093d49c 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -752,23 +752,34 @@ class Evaluate : public Stmt {
TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode);
};
-/*! \brief Additional annotation of for loop. */
-enum class ForType : int {
- /*! \brief serial execution. */
- Serial = 0,
- /*! \brief parallel execution on CPU. */
- Parallel = 1,
- /*! \brief Vector SIMD loop annotaion. */
- Vectorized = 2,
- /*! \brief Unroll annotation. */
- Unrolled = 3
+/*!
+ * \brief The kind of the loop.
+ *
+ * ForKind can change the control flow semantics
+ * of the loop. So the kind field needs to be considered
+ * in all TIR passes.
+ */
+enum class ForKind : int {
+ /*! \brief default semantics -- serial execution. */
+ kSerial = 0,
+ /*! \brief Parallel execution on CPU. */
+ kParallel = 1,
+ /*!
+ * \brief Vector SIMD loop.
+ * The loop body will be vectorized.
+ */
+ kVectorized = 2,
+ /*! \brief The loop body must be unrolled. */
+ kUnrolled = 3,
+ /*!
+ * \brief The loop variable is bound to a thread in
+ * an environment. In the final stage of lowering,
+ * the loop is simply removed and the loop variable is
+ * mapped to the corresponding context thread.
+ */
+ kThreadBinding = 4
};
-// Kevice api of for loop
-// kept for backward compatibility
-// consider refactor and remove later.
-enum class DeviceAPI : int { None = 0 };
-
/*!
* \brief A for loop, with poissible type annotations.
*
@@ -787,39 +798,50 @@ class ForNode : public StmtNode {
PrimExpr min;
/*! \brief The extent of the iteration. */
PrimExpr extent;
- /*! \brief The type of the for loop. */
- ForType for_type;
- /*!
- * \brief Deprecated, reserved for backward compatibility.
- * Consider refactor and remove later.
- */
- DeviceAPI device_api;
+ /*! \brief The kind of the for loop. */
+ ForKind kind;
/*! \brief The body of the for loop. */
Stmt body;
+ /*!
+ * \brief Only valid when kind == ForKind::kThreadBinding
+ * The context thread that this loop variable bounds to.
+ */
+ Optional<IterVar> thread_binding;
+ /*!
+ * \brief Additional annotations about the loop.
+ *
+ * These annotations can be used as auxiliary hint
+ * to future transformations. An annotation should
+ * not change the control flow semantics of the loop
+ * and can be ignored in most passes.
+ */
+ Map<String, ObjectRef> annotations;
void VisitAttrs(AttrVisitor* v) {
v->Visit("loop_var", &loop_var);
v->Visit("min", &min);
v->Visit("extent", &extent);
- v->Visit("for_type", &for_type);
- v->Visit("device_api", &device_api);
+ v->Visit("kind", &kind);
v->Visit("body", &body);
+ v->Visit("thread_binding", &thread_binding);
+ v->Visit("annotations", &annotations);
v->Visit("span", &span);
}
bool SEqualReduce(const ForNode* other, SEqualReducer equal) const {
return equal.DefEqual(loop_var, other->loop_var) && equal(min, other->min)
&&
- equal(extent, other->extent) && equal(for_type, other->for_type) &&
- equal(device_api, other->device_api) && equal(body, other->body);
+ equal(extent, other->extent) && equal(kind, other->kind) &&
equal(body, other->body) &&
+ equal(thread_binding, other->thread_binding) && equal(annotations,
other->annotations);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(loop_var);
hash_reduce(min);
hash_reduce(extent);
- hash_reduce(for_type);
- hash_reduce(device_api);
+ hash_reduce(kind);
hash_reduce(body);
+ hash_reduce(thread_binding);
+ hash_reduce(annotations);
}
static constexpr const char* _type_key = "tir.For";
@@ -832,8 +854,9 @@ class ForNode : public StmtNode {
*/
class For : public Stmt {
public:
- TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type,
DeviceAPI device_api,
- Stmt body, Span span = Span());
+ TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt
body,
+ Optional<IterVar> thread_binding = NullOpt,
+ Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode);
};
@@ -1015,7 +1038,7 @@ inline bool IsPragmaKey(const std::string& attr_key) {
TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span());
// overload printing of for type.
-TVM_DLL std::ostream& operator<<(std::ostream& os, ForType for_type);
+TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind);
} // namespace tir
} // namespace tvm
diff --git a/python/tvm/script/scope_handler.py
b/python/tvm/script/scope_handler.py
index 21ed7f6..9449cbd 100644
--- a/python/tvm/script/scope_handler.py
+++ b/python/tvm/script/scope_handler.py
@@ -226,7 +226,7 @@ class Serial(ForScopeHandler):
self.context.report_error("Expect exact 1 loop var", span)
ana = tvm.arith.Analyzer()
extent = end if begin == 0 else ana.simplify(end - begin)
- return tvm.tir.For(self.loop_vars[0], begin, extent, 0, 0,
self.body, span=span)
+ return tvm.tir.For(self.loop_vars[0], begin, extent, 0, self.body,
span=span)
super().__init__(serial)
@@ -241,7 +241,7 @@ class Parallel(ForScopeHandler):
self.context.report_error("Expect exact 1 loop var")
ana = tvm.arith.Analyzer()
extent = end if begin == 0 else ana.simplify(end - begin)
- return tvm.tir.For(self.loop_vars[0], begin, extent, 1, 0,
self.body, span=span)
+ return tvm.tir.For(self.loop_vars[0], begin, extent, 1, self.body,
span=span)
super().__init__(parallel)
@@ -256,7 +256,7 @@ class Vectorized(ForScopeHandler):
self.context.report_error("Expect exact 1 loop var")
ana = tvm.arith.Analyzer()
extent = end if begin == 0 else ana.simplify(end - begin)
- return tvm.tir.For(self.loop_vars[0], begin, extent, 2, 0,
self.body, span=span)
+ return tvm.tir.For(self.loop_vars[0], begin, extent, 2, self.body,
span=span)
super().__init__(vectorized)
@@ -271,6 +271,6 @@ class Unroll(ForScopeHandler):
self.context.report_error("Expect exact 1 loop var")
ana = tvm.arith.Analyzer()
extent = end if begin == 0 else ana.simplify(end - begin)
- return tvm.tir.For(self.loop_vars[0], begin, extent, 3, 0,
self.body, span=span)
+ return tvm.tir.For(self.loop_vars[0], begin, extent, 3, self.body,
span=span)
super().__init__(unroll)
diff --git a/python/tvm/te/hybrid/calls.py b/python/tvm/te/hybrid/calls.py
index 7611891..6785457 100644
--- a/python/tvm/te/hybrid/calls.py
+++ b/python/tvm/te/hybrid/calls.py
@@ -23,18 +23,18 @@ from tvm.ir.container import Array
from tvm.target import Target
from tvm.tir import expr as _expr
from tvm.tir import call_intrin
-from tvm.tir.stmt import For
+from tvm.tir.stmt import ForKind
from .utils import _internal_assert
# pylint: disable=redefined-builtin,invalid-name
LOOP_INTRIN = {
- "range": For.Serial,
- "unroll": For.Unrolled,
- "parallel": For.Parallel,
- "vectorize": For.Vectorized,
- "const_range": (For.Unrolled,),
+ "range": ForKind.SERIAL,
+ "unroll": ForKind.UNROLLED,
+ "parallel": ForKind.PARALLEL,
+ "vectorize": ForKind.VECTORIZED,
+ "const_range": (ForKind.UNROLLED,),
}
@@ -48,9 +48,9 @@ def _range(annotation, args):
low, ext = args[0], args[1]
if not tvm.tir.analysis.expr_deep_equal(low, const(0, dtype="int32")):
ext = ext - low
- for_type = LOOP_INTRIN[annotation]
+ kind = LOOP_INTRIN[annotation]
iter_var = None
- return iter_var, low, ext, for_type
+ return iter_var, low, ext, kind
range = unroll = vectorize = parallel = const_range = _range # pylint:
disable=invalid-name
@@ -63,8 +63,8 @@ def bind(func_id, args):
_internal_assert(isinstance(args[0], str), "A loop bind's first argument
should be a string!")
low, ext = const(0, "int32"), args[1]
iter_var = tvm.te.thread_axis((low, ext), args[0])
- for_type = None
- return iter_var, low, ext, for_type
+ kind = None
+ return iter_var, low, ext, kind
def _math_intrin(func_id, args):
diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py
index d47b2ee..7bb85e3 100644
--- a/python/tvm/te/hybrid/parser.py
+++ b/python/tvm/te/hybrid/parser.py
@@ -480,14 +480,14 @@ class HybridParser(ast.NodeVisitor):
return op
def visit_For(self, node):
- iter_var, low, ext, for_type = self.visit(node.iter)
+ iter_var, low, ext, kind = self.visit(node.iter)
_internal_assert(
isinstance(node.target, ast.Name), "The loop iterator should be a
variable!"
)
_name = node.target.id
- if isinstance(for_type, tuple):
+ if isinstance(kind, tuple):
low = self.analyzer.simplify(low)
ext = self.analyzer.simplify(ext)
_internal_assert(
@@ -511,14 +511,14 @@ class HybridParser(ast.NodeVisitor):
return concat_list_to_block(bodies)
if iter_var is None:
- _internal_assert(for_type is not None, "The loop iterating
function parse error!")
+ _internal_assert(kind is not None, "The loop iterating function
parse error!")
offset = iter_var = tvm.te.var(_name)
if not tvm.tir.analysis.expr_deep_equal(low, tvm.runtime.const(0,
"int32")):
offset = iter_var + low
self.add_symbol(_name, Symbol.LoopVar, offset)
_body = visit_list_to_block(self.visit, node.body)
else:
- _internal_assert(for_type is None, "The loop bind function parse
error!")
+ _internal_assert(kind is None, "The loop bind function parse
error!")
self.add_symbol(_name, Symbol.ThreadBind, iter_var)
self.device += 1
_body = visit_list_to_block(self.visit, node.body)
@@ -526,13 +526,13 @@ class HybridParser(ast.NodeVisitor):
_body = self.wrap_up_realize(node, _body)
- if for_type is None:
+ if kind is None:
res = _body
else:
_internal_assert(
- not isinstance(for_type, tuple), "Micro expansion should be
handled before!"
+ not isinstance(kind, tuple), "Micro expansion should be
handled before!"
)
- res = tvm.tir.For(iter_var, tvm.runtime.const(0, "int32"), ext,
for_type, 0, _body)
+ res = tvm.tir.For(iter_var, tvm.runtime.const(0, "int32"), ext,
kind, _body)
self.symbols.pop(_name)
return res
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 901c89e..324c4da 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -27,7 +27,7 @@ from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or,
Not
from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast,
Shuffle
from .expr import Call, CallEffectKind, Let, IterVar, Any
-from .stmt import Stmt, LetStmt, AssertStmt, For
+from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For
from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate,
AttrStmt
from .stmt import ProducerRealize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index 6dcc858..437e8f6 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -206,7 +206,7 @@ class IRBuilder(object):
value = op.max(1, value)
self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x))
- def for_range(self, begin, end, name="i", dtype="int32",
for_type="serial"):
+ def for_range(self, begin, end, name="i", dtype="int32", kind="serial"):
"""Create a for iteration scope.
Parameters
@@ -224,7 +224,7 @@ class IRBuilder(object):
dtype : str, optional
The data type of iteration variable.
- for_type : str, optional
+ kind : str, optional
The special tag on the for loop.
Returns
@@ -249,17 +249,17 @@ class IRBuilder(object):
extent = end if begin == 0 else (end - begin)
def _exit_cb():
- if for_type == "serial":
- for_type_id = 0
- elif for_type == "parallel":
- for_type_id = 1
- elif for_type == "vectorize":
- for_type_id = 2
- elif for_type == "unroll":
- for_type_id = 3
+ if kind == "serial":
+ kind_id = _stmt.ForKind.SERIAL
+ elif kind == "parallel":
+ kind_id = _stmt.ForKind.PARALLEL
+ elif kind == "vectorize":
+ kind_id = _stmt.ForKind.VECTORIZED
+ elif kind == "unroll":
+ kind_id = _stmt.ForKind.UNROLLED
else:
- raise ValueError("Unknown for_type")
- self.emit(_stmt.For(loop_var, begin, extent, for_type_id, 0,
self._pop_seq()))
+ raise ValueError("Unknown kind")
+ self.emit(_stmt.For(loop_var, begin, extent, kind_id,
self._pop_seq()))
return WithScope(loop_var, _exit_cb)
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index 6857b68..9e1ef56 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -26,6 +26,7 @@ Each statement node have subfields that can be visited from
python side.
assert isinstance(st, tvm.tir.stmt.Store)
assert(st.buffer_var == a)
"""
+from enum import IntEnum
import tvm._ffi
from tvm.runtime import Object
@@ -82,6 +83,22 @@ class AssertStmt(Stmt):
self.__init_handle_by_constructor__(_ffi_api.AssertStmt, condition,
message, body, span)
+class ForKind(IntEnum):
+ """The kind of the for loop.
+
+ note
+ ----
+ ForKind can change the control flow semantics
+ of the loop and need to be considered in all TIR passes.
+ """
+
+ SERIAL = 0
+ PARALLEL = 1
+ VECTORIZED = 2
+ UNROLLED = 3
+ THREAD_BINDING = 4
+
+
@tvm._ffi.register_object("tir.For")
class For(Stmt):
"""For node.
@@ -97,27 +114,44 @@ class For(Stmt):
extent : PrimExpr
The length of the loop.
- for_type : int
- The for type.
-
- device_api : int
- The device api type.
+ kind : ForKind
+ The type of the for.
body : Stmt
The body statement.
+ thread_binding: Optional[tir.IterVar]
+ The thread this loop binds to. Only valid
+ if kind is ThreadBinding
+
+ annotations: tvm.ir.Map
+ Additional annotation hints.
+
span : Optional[Span]
The location of this itervar in the source code.
"""
- Serial = 0
- Parallel = 1
- Vectorized = 2
- Unrolled = 3
-
- def __init__(self, loop_var, min_val, extent, for_type, device_api, body,
span=None):
+ def __init__(
+ self,
+ loop_var,
+ min_val,
+ extent,
+ kind,
+ body,
+ thread_binding=None,
+ annotations=None,
+ span=None,
+ ):
self.__init_handle_by_constructor__(
- _ffi_api.For, loop_var, min_val, extent, for_type, device_api,
body, span
+ _ffi_api.For,
+ loop_var,
+ min_val,
+ extent,
+ kind,
+ body,
+ thread_binding,
+ annotations,
+ span,
)
diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py
index 6f3ed78..0c01cc9 100644
--- a/python/tvm/topi/cuda/nms.py
+++ b/python/tvm/topi/cuda/nms.py
@@ -580,7 +580,7 @@ def nms_ir(
j = bx * max_threads + tx
with ib.if_scope(j < nkeep):
src_idx = base_src_idx + sorted_index[i * num_anchors + j] *
box_data_length
- with ib.for_range(0, 4, for_type="unroll") as k:
+ with ib.for_range(0, 4, kind="unroll") as k:
out_bboxes[(base_bbox_idx + j * 4 + k)] = data[src_idx +
coord_start + k]
out_scores[i * num_anchors + j] = data[src_idx + score_index]
@@ -593,7 +593,7 @@ def nms_ir(
# Only needed for return_indices = False case
if return_indices is False:
with ib.if_scope(j < num_anchors):
- with ib.for_range(0, 4, for_type="unroll") as k:
+ with ib.for_range(0, 4, kind="unroll") as k:
out_bboxes[(base_bbox_idx + j * 4 + k)] = -1.0
out_scores[i, j] = -1.0
@@ -609,7 +609,7 @@ def nms_ir(
with ib.if_scope(j < valid_count[i]):
src_offset = base_src_idx + j * box_data_length
- with ib.for_range(0, 4, for_type="unroll") as k:
+ with ib.for_range(0, 4, kind="unroll") as k:
out_bboxes[base_bbox_idx + j * 4 + k] = data[src_offset +
coord_start + k]
out_scores[i * num_anchors + j] = data[src_offset +
score_index]
@@ -855,7 +855,7 @@ def _concatenate_outputs(
i = by
with ib.if_scope(tid < num_anchors):
- with ib.for_range(0, 4, for_type="unroll") as j:
+ with ib.for_range(0, 4, kind="unroll") as j:
out[i, tid, coord_start + j] = out_bboxes[i, tid, j]
out[i, tid, score_index] = out_scores[i, tid]
if id_index >= 0:
diff --git a/python/tvm/topi/cuda/rcnn/proposal.py
b/python/tvm/topi/cuda/rcnn/proposal.py
index 5b7884c..e5e83b4 100644
--- a/python/tvm/topi/cuda/rcnn/proposal.py
+++ b/python/tvm/topi/cuda/rcnn/proposal.py
@@ -181,7 +181,7 @@ def argsort_ir(data_buf, out_index_buf):
idxm = tvm.tir.indexmod
- with ib.for_range(0, batch, for_type="unroll") as b:
+ with ib.for_range(0, batch, kind="unroll") as b:
start = b * num_bbox
for i in range(2):
bbox_id = tid * 2 + i
@@ -259,7 +259,7 @@ def nms_ir(sorted_bbox_buf, out_buf, nms_threshold):
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
i = bx * max_threads + tx
- with ib.for_range(0, batch, for_type="unroll", name="n") as b:
+ with ib.for_range(0, batch, kind="unroll", name="n") as b:
base_idx = b * num_bbox
with ib.if_scope(i < num_bbox):
p_out[base_idx + i] = False
@@ -323,7 +323,7 @@ def prepare_output_ir(sorted_bbox_buf, remove_mask_buf,
out_buf):
tvm.tir.all(i[0] < rpn_post_nms_top_n, p_remove[(b *
num_bbox + j)] == False)
):
p_out[offset_i] = tvm.tir.Cast("float32", b)
- with ib.for_range(0, 4, for_type="unroll") as k:
+ with ib.for_range(0, 4, kind="unroll") as k:
p_out[offset_i + k + 1] = p_sorted_bbox[offset_j + k]
i[0] = i[0] + 1
diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py
index f2cecac..cb61d96 100644
--- a/python/tvm/topi/cuda/sparse.py
+++ b/python/tvm/topi/cuda/sparse.py
@@ -228,8 +228,8 @@ def sparse_dense_tir(data, w_data, w_indices, w_indptr):
)
# zero block
- with ib.for_range(0, bs_m, name="x", for_type="unroll") as x:
- with ib.for_range(0, bs_n, name="y", for_type="unroll") as y:
+ with ib.for_range(0, bs_m, name="x", kind="unroll") as x:
+ with ib.for_range(0, bs_n, name="y", kind="unroll") as y:
block[x, y] = 0.0
# compute into thread local storage using warp_size chunks
with ib.for_range(0, rowlength_bo, name="bb") as bb:
@@ -240,26 +240,26 @@ def sparse_dense_tir(data, w_data, w_indices, w_indptr):
# each thread has a row
# TODO: ideally we could vectorize this
with ib.for_range(0, rowlength_bi, name="bi") as bi:
- with ib.for_range(0, bs_m, name="x", for_type="unroll") as x:
- with ib.for_range(0, bs_k, name="z", for_type="unroll") as
z:
+ with ib.for_range(0, bs_m, name="x", kind="unroll") as x:
+ with ib.for_range(0, bs_k, name="z", kind="unroll") as z:
# This memory acces should be out of bounds when
# m_index >= mb (which occurs when the dense matrix
# rows % 32 != 0), but it seems to work just fine...
data_cache[bi, x, z] = data_ptr[indices[bi] * bs_k +
z, m_index * bs_m + x]
# cache w_data
elem_idx = bb * rowlength_bi + tx
- with ib.for_range(0, bs_n, name="y", for_type="unroll") as y:
- with ib.for_range(0, bs_k, name="z", for_type="unroll") as z:
+ with ib.for_range(0, bs_n, name="y", kind="unroll") as y:
+ with ib.for_range(0, bs_k, name="z", kind="unroll") as z:
w_data_cache[tx, y, z] = w_data_ptr[row_start + elem_idx,
y, z]
with ib.for_range(0, mi, name="i") as i:
# thread local block matmul
- with ib.for_range(0, bs_m, name="x", for_type="unroll") as x:
- with ib.for_range(0, bs_n, name="y", for_type="unroll") as
y:
- with ib.for_range(0, bs_k, name="z",
for_type="unroll") as z:
+ with ib.for_range(0, bs_m, name="x", kind="unroll") as x:
+ with ib.for_range(0, bs_n, name="y", kind="unroll") as y:
+ with ib.for_range(0, bs_k, name="z", kind="unroll") as
z:
block[x, y] += data_cache[i, x, z] *
w_data_cache[i, y, z]
# store results
- with ib.for_range(0, bs_m, name="x", for_type="unroll") as x:
- with ib.for_range(0, bs_n, name="y", for_type="unroll") as y:
+ with ib.for_range(0, bs_m, name="x", kind="unroll") as x:
+ with ib.for_range(0, bs_n, name="y", kind="unroll") as y:
with ib.if_scope(m_index < mb):
with ib.if_scope(n_index < nb):
# It doesn't seem like we would be getting coelesced
diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py
index cdccc80..8145ed8 100644
--- a/python/tvm/topi/nn/sparse.py
+++ b/python/tvm/topi/nn/sparse.py
@@ -294,26 +294,26 @@ def _csr_transpose_ir(data, indices, indptr, out_data,
out_indices, out_indptr):
n = get_const_tuple(indptr.shape)[0] - 1
nnz = get_const_tuple(data.shape)[0]
- with irb.for_range(0, n, for_type="parallel", name="col") as col:
+ with irb.for_range(0, n, kind="parallel", name="col") as col:
out_indptr_ptr[col] = 0
- with irb.for_range(0, nnz, for_type="serial", name="nz_idx") as nz_idx:
+ with irb.for_range(0, nnz, kind="serial", name="nz_idx") as nz_idx:
out_indptr_ptr[indices_ptr[nz_idx]] += 1
cumsum = irb.allocate("int32", (1,), name="cumsum", scope="local")
temp = irb.allocate("int32", (1,), name="temp", scope="local")
cumsum[0] = 0
- with irb.for_range(0, n, for_type="serial", name="col") as col:
+ with irb.for_range(0, n, kind="serial", name="col") as col:
temp[0] = out_indptr_ptr[col]
out_indptr_ptr[col] = cumsum[0]
cumsum[0] += temp[0]
out_indptr_ptr[n] = nnz
- with irb.for_range(0, n, for_type="serial", name="row") as row:
+ with irb.for_range(0, n, kind="serial", name="row") as row:
offset = indptr_ptr[row]
diff = indptr_ptr[row + 1] - indptr_ptr[row]
- with irb.for_range(0, diff, for_type="serial", name="idx") as idx:
+ with irb.for_range(0, diff, kind="serial", name="idx") as idx:
real_idx = offset + idx
col = indices_ptr[real_idx]
dest = out_indptr_ptr[col]
@@ -325,7 +325,7 @@ def _csr_transpose_ir(data, indices, indptr, out_data,
out_indices, out_indptr):
last = irb.allocate("int32", (1,), name="last", scope="local")
temp2 = irb.allocate("int32", (1,), name="temp2", scope="local")
last[0] = 0
- with irb.for_range(0, n, for_type="serial", name="col") as col:
+ with irb.for_range(0, n, kind="serial", name="col") as col:
temp2[0] = out_indptr_ptr[col]
out_indptr_ptr[col] = last[0]
last[0] = temp2[0]
diff --git a/python/tvm/topi/sparse/csrmm.py b/python/tvm/topi/sparse/csrmm.py
index f578e60..39ba333 100644
--- a/python/tvm/topi/sparse/csrmm.py
+++ b/python/tvm/topi/sparse/csrmm.py
@@ -72,8 +72,8 @@ def csrmm_default(data, indices, indptr, weight, bias=None):
out_ptr = irb.buffer_ptr(out)
M = simplify(indptr.shape[0] - 1)
_, N = weight.shape
- with irb.for_range(0, N, for_type="vectorize", name="n") as n:
- with irb.for_range(0, M, for_type="parallel", name="row") as row:
+ with irb.for_range(0, N, kind="vectorize", name="n") as n:
+ with irb.for_range(0, M, kind="parallel", name="row") as row:
dot = irb.allocate("float32", (1,), name="dot", scope="local")
out_ptr[row * N + n] = 0.0
dot[0] = 0.0
diff --git a/python/tvm/topi/sparse/csrmv.py b/python/tvm/topi/sparse/csrmv.py
index afe3bc7..a2d22af 100644
--- a/python/tvm/topi/sparse/csrmv.py
+++ b/python/tvm/topi/sparse/csrmv.py
@@ -63,7 +63,7 @@ def csrmv_default(data, indices, indptr, weight, bias=None):
weight_ptr = irb.buffer_ptr(weight)
out_ptr = irb.buffer_ptr(out)
num_rows = indptr.shape[0] - 1
- with irb.for_range(0, num_rows, for_type="parallel", name="row") as
row:
+ with irb.for_range(0, num_rows, kind="parallel", name="row") as row:
dot = irb.allocate("float32", (1,), name="dot", scope="local")
out_ptr[row] = 0.0
dot[0] = 0.0
diff --git a/python/tvm/topi/sparse/dense.py b/python/tvm/topi/sparse/dense.py
index d1516d0..5c63e44 100644
--- a/python/tvm/topi/sparse/dense.py
+++ b/python/tvm/topi/sparse/dense.py
@@ -74,8 +74,8 @@ def dense_si(data, indices, indptr, weight, bias=None):
out_ptr = irb.buffer_ptr(out)
M = simplify(indptr.shape[0] - 1)
N, K = weight.shape
- with irb.for_range(0, N, for_type="vectorize", name="n") as n:
- with irb.for_range(0, M, for_type="parallel", name="m") as m:
+ with irb.for_range(0, N, kind="vectorize", name="n") as n:
+ with irb.for_range(0, M, kind="parallel", name="m") as m:
dot = irb.allocate(dtype, (1,), name="dot", scope="local")
out_ptr[m * N + n] = tvm.tir.const(0, dtype)
dot[0] = tvm.tir.const(0, dtype)
@@ -153,8 +153,8 @@ def dense_sw(data, w_data, w_indices, w_indptr, bias=None):
out_ptr = irb.buffer_ptr(out)
M, K = data.shape
N = simplify(w_indptr.shape[0] - 1)
- with irb.for_range(0, M, for_type="vectorize", name="m") as m:
- with irb.for_range(0, N, for_type="parallel", name="n") as n:
+ with irb.for_range(0, M, kind="vectorize", name="m") as m:
+ with irb.for_range(0, N, kind="parallel", name="n") as n:
dot = irb.allocate(dtype, (1,), name="dot", scope="local")
out_ptr[m * N + n] = tvm.tir.const(0, dtype)
dot[0] = tvm.tir.const(0, dtype)
diff --git a/python/tvm/topi/vision/rcnn/proposal.py
b/python/tvm/topi/vision/rcnn/proposal.py
index 89726ef..e15ba8c 100644
--- a/python/tvm/topi/vision/rcnn/proposal.py
+++ b/python/tvm/topi/vision/rcnn/proposal.py
@@ -208,7 +208,7 @@ def argsort_ir(data_buf, out_index_buf):
temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")
idxm = tvm.tir.indexmod
- with ib.for_range(0, batch, for_type="unroll") as b:
+ with ib.for_range(0, batch, kind="unroll") as b:
start = b * num_bbox
for i in range(2):
with ib.for_range(0, (num_bbox + 1) // 2) as tid:
@@ -279,7 +279,7 @@ def nms_ir(sorted_bbox_buf, out_buf, nms_threshold):
ib = tvm.tir.ir_builder.create()
p_data = ib.buffer_ptr(sorted_bbox_buf)
p_out = ib.buffer_ptr(out_buf)
- with ib.for_range(0, batch, for_type="unroll", name="n") as b:
+ with ib.for_range(0, batch, kind="unroll", name="n") as b:
base_idx = b * num_bbox
for i in range(num_bbox):
p_out[base_idx + i] = False
@@ -345,7 +345,7 @@ def prepare_output_ir(sorted_bbox_buf, remove_mask_buf,
out_buf):
)
):
p_out[offset_i] = tvm.tir.Cast("float32", b)
- with ib.for_range(0, 4, for_type="unroll") as k:
+ with ib.for_range(0, 4, kind="unroll") as k:
p_out[offset_i + k + 1] = p_sorted_bbox[offset_j +
k]
i[b] = i[b] + 1
diff --git a/python/tvm/topi/x86/scatter.py b/python/tvm/topi/x86/scatter.py
index 8147d3a..8bb3f57 100644
--- a/python/tvm/topi/x86/scatter.py
+++ b/python/tvm/topi/x86/scatter.py
@@ -84,7 +84,7 @@ def scatter_nd(data, indices, shape):
out[i] = tvm.tir.Cast(data_ptr.dtype, 0)
with ib.for_range(0, fused_indices_dimension) as i:
- with ib.for_range(0, fused_data_dimension, for_type="parallel") as
j:
+ with ib.for_range(0, fused_data_dimension, kind="parallel") as j:
offset = fused_data_dimension
index = j # This is x_M, .. x_{N-1} part of the index into
out.
# Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1,
y_0, .. y_{K-1}] part
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
index 1b10cd5..cf516d8 100755
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -618,7 +618,7 @@ class PerStoreFeatureExtractor : public StmtExprVisitor {
is_gpu_ = true;
// make a fake for node for blockIdx.x or threadIdx.x
- Stmt fake_for_node = For(var, 0, extent, ForType::Parallel,
DeviceAPI::None, node->body);
+ Stmt fake_for_node = For(var, 0, extent, ForKind::kParallel, node->body);
outer_loop_prod_ *= extent;
for_loop_stack_.push_back(fake_for_node.as<ForNode>());
@@ -642,11 +642,11 @@ class PerStoreFeatureExtractor : public StmtExprVisitor {
void VisitStmt_(const ForNode* node) final {
int64_t loop_extent = GetLoopExtent(node);
- if (node->for_type == ForType::Vectorized) {
+ if (node->kind == ForKind::kVectorized) {
vec_for_stack_.push_back(node);
- } else if (node->for_type == ForType::Unrolled) {
+ } else if (node->kind == ForKind::kUnrolled) {
unroll_for_stack_.push_back(node);
- } else if (node->for_type == ForType::Parallel) {
+ } else if (node->kind == ForKind::kParallel) {
parallel_for_stack_.push_back(node);
}
@@ -656,11 +656,11 @@ class PerStoreFeatureExtractor : public StmtExprVisitor {
for_loop_stack_.pop_back();
outer_loop_prod_ /= loop_extent;
- if (node->for_type == ForType::Vectorized) {
+ if (node->kind == ForKind::kVectorized) {
vec_for_stack_.pop_back();
- } else if (node->for_type == ForType::Unrolled) {
+ } else if (node->kind == ForKind::kUnrolled) {
unroll_for_stack_.pop_back();
- } else if (node->for_type == ForType::Parallel) {
+ } else if (node->kind == ForKind::kParallel) {
parallel_for_stack_.pop_back();
}
}
diff --git a/src/autotvm/feature_visitor.cc b/src/autotvm/feature_visitor.cc
index 15e0975..59cac9c 100644
--- a/src/autotvm/feature_visitor.cc
+++ b/src/autotvm/feature_visitor.cc
@@ -34,19 +34,23 @@ void FeatureVisitor::VisitStmt_(const ForNode* op) {
int64_t loop_extent = -1;
if (extent != nullptr) loop_extent = extent->value;
AnnotationType ann = kSerial;
- switch (op->for_type) {
- case ForType ::Parallel:
+ switch (op->kind) {
+ case ForKind ::kParallel:
ann = kParallel;
break;
- case ForType::Unrolled:
+ case ForKind::kUnrolled:
ann = kUnrolled;
break;
- case ForType::Vectorized:
+ case ForKind::kVectorized:
ann = kVectorized;
break;
- case ForType::Serial:
+ case ForKind::kSerial:
ann = kSerial;
break;
+ case ForKind::kThreadBinding:
+ LOG(FATAL) << "Loop ThreadBinding is reserved for future used and "
+ << "not yet supported in TIR";
+ break;
}
if (EnterItervar_(op->loop_var, loop_extent, ann)) {
diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc
index 107817d..4b0871a 100644
--- a/src/printer/tir_text_printer.cc
+++ b/src/printer/tir_text_printer.cc
@@ -465,18 +465,21 @@ Doc TIRTextPrinter::VisitStmt_(const EvaluateNode* op) {
return doc;
}
-inline const char* ForType2String(ForType t) {
+inline const char* ForKind2String(ForKind t) {
switch (t) {
- case ForType::Serial:
+ case ForKind::kSerial:
return "serial";
- case ForType::Parallel:
+ case ForKind::kParallel:
return "parallel";
- case ForType::Vectorized:
+ case ForKind::kVectorized:
return "vectorized";
- case ForType::Unrolled:
+ case ForKind::kUnrolled:
return "unroll";
+ case ForKind::kThreadBinding:
+ LOG(FATAL) << "Loop ThreadBinding is reserved for future used and "
+ << "not yet supported in TIR";
}
- LOG(FATAL) << "Unknown ForType";
+ LOG(FATAL) << "Unknown ForKind";
return "Unknown";
}
@@ -484,8 +487,8 @@ Doc TIRTextPrinter::VisitStmt_(const ForNode* op) {
Doc doc;
doc << "for (" << Print(op->loop_var) << ", " << Print(op->min) << ", "
<< Print(op->min + op->extent) << ")";
- if (op->for_type != ForType::Serial) {
- doc << " " << Doc::StrLiteral(ForType2String(op->for_type));
+ if (op->kind != ForKind::kSerial) {
+ doc << " " << Doc::StrLiteral(ForKind2String(op->kind));
}
doc << PrintBody(op->body);
return doc;
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index 09f95e4..86b175e 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -649,27 +649,30 @@ Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) {
return doc;
}
-inline const char* ForType2String(ForType t) {
+inline const char* ForKind2String(ForKind t) {
switch (t) {
- case ForType::Serial:
+ case ForKind::kSerial:
return "serial";
- case ForType::Parallel:
+ case ForKind::kParallel:
return "parallel";
- case ForType::Vectorized:
+ case ForKind::kVectorized:
return "vectorized";
- case ForType::Unrolled:
+ case ForKind::kUnrolled:
return "unroll";
+ case ForKind::kThreadBinding:
+ LOG(FATAL) << "Loop ThreadBinding is reserved for future used and "
+ << "not yet supported in TIR";
+ return "threadbinding";
}
- LOG(FATAL) << "Unknown ForType";
+ LOG(FATAL) << "Unknown ForKind";
return "Unknown";
}
Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) {
Doc doc;
var_not_in_headers.insert(op->loop_var.get());
- doc << "for " << Print(op->loop_var)
- << " in tir." + std::string(ForType2String(op->for_type)) + "(" <<
Print(op->min) << ", "
- << Print(op->min + op->extent)
+ doc << "for " << Print(op->loop_var) << " in tir." +
std::string(ForKind2String(op->kind)) + "("
+ << Print(op->min) << ", " << Print(op->min + op->extent)
<< "):" << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
return doc;
}
diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index 6143e70..e2a8553 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -976,12 +976,13 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) {
void CodeGenCPU::VisitStmt_(const ForNode* op) {
ICHECK(is_zero(op->min));
- if (op->for_type == ForType::Serial || op->for_type == ForType::Unrolled) {
+ if (op->kind == ForKind::kSerial || op->kind == ForKind::kUnrolled) {
CodeGenLLVM::VisitStmt_(op);
- } else if (op->for_type == ForType::Parallel) {
+ } else if (op->kind == ForKind::kParallel) {
if (parallel_env_.penv == nullptr) {
- CreateParallelLaunch(
- For(op->loop_var, op->min, op->extent, op->for_type, op->device_api,
op->body), 0);
+ CreateParallelLaunch(For(op->loop_var, op->min, op->extent, op->kind,
op->body,
+ op->thread_binding, op->annotations),
+ 0);
} else {
// already in parallel env.
ICHECK(parallel_env_.task_id.defined());
@@ -1007,7 +1008,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) {
++parallel_env_.parallel_loop_count;
}
} else {
- LOG(FATAL) << "cannot handle for type " << op->for_type;
+ LOG(FATAL) << "cannot handle for type " << op->kind;
}
}
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 34f3897..1dd76f6 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -1318,11 +1318,11 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) {
void CodeGenLLVM::VisitStmt_(const ForNode* op) {
ICHECK(is_zero(op->min));
analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
- if (op->for_type == ForType::Unrolled) {
+ if (op->kind == ForKind::kUnrolled) {
LOG(WARNING) << "Unroll hint get ignore at CodeGenLLVM backend, "
<< " consider set unroll_explicit=True";
} else {
- ICHECK(op->for_type == ForType::Serial);
+ ICHECK(op->kind == ForKind::kSerial);
}
CreateSerialFor(MakeValue(op->min), MakeValue(op->extent),
llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1),
op->loop_var, op->body);
diff --git a/src/target/source/codegen_cuda.cc
b/src/target/source/codegen_cuda.cc
index 6c73716..e554731 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -84,7 +84,7 @@ std::string CodeGenCUDA::Finish() {
void CodeGenCUDA::VisitStmt_(const tir::ForNode* op) {
ICHECK(is_const_int(op->min, 0));
- if (op->for_type == tir::ForType::Unrolled) {
+ if (op->kind == tir::ForKind::kUnrolled) {
PrintIndent();
stream << "#pragma unroll\n";
}
diff --git a/src/target/spirv/codegen_spirv.cc
b/src/target/spirv/codegen_spirv.cc
index c3b12ab..51d136d 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -492,7 +492,7 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) {
loop_var.SetIncoming(0, init_value, init_label);
spirv::Value loop_cond = builder_->LT(loop_var, extent_value);
uint32_t control =
- (op->for_type == ForType::Unrolled ? spv::LoopControlUnrollMask :
spv::LoopControlMaskNone);
+ (op->kind == ForKind::kUnrolled ? spv::LoopControlUnrollMask :
spv::LoopControlMaskNone);
builder_->MakeInst(spv::OpLoopMerge, merge_label, continue_label, control);
builder_->MakeInst(spv::OpBranchConditional, loop_cond, body_label,
merge_label,
weight_likely_branch_, 1);
diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc
index 94e06d2..65b8660 100644
--- a/src/te/operation/hybrid_op.cc
+++ b/src/te/operation/hybrid_op.cc
@@ -234,9 +234,9 @@ Stmt ApplyLoopShapes(const Stage& stage, const
std::unordered_map<IterVar, Range
PrimExpr cond = likely(outer * factor < (op->extent - inner));
ret = IfThenElse(cond, ret);
ret = For(inner->var, PrimExpr(0), inner->dom->extent,
- IterVarTypeToForType(inner->iter_type), op->device_api, ret);
+ IterVarTypeToForKind(inner->iter_type), ret);
ret = For(outer->var, PrimExpr(0), outer->dom->extent,
- IterVarTypeToForType(outer->iter_type), op->device_api, ret);
+ IterVarTypeToForKind(outer->iter_type), ret);
splitted = true;
return ret;
}
@@ -277,8 +277,8 @@ Stmt ApplyLoopShapes(const Stage& stage, const
std::unordered_map<IterVar, Range
rmap[op->loop_var.get()] = indexdiv(parent, extent);
body = tir::Substitute(body, rmap);
under_outer = false;
- return For(parent->var, PrimExpr(0), extent * op->extent,
op->for_type, op->device_api,
- body);
+ return For(parent->var, PrimExpr(0), extent * op->extent, op->kind,
body,
+ op->thread_binding, op->annotations);
} else if (under_outer) {
Stmt body = this->VisitStmt(op->body);
std::unordered_map<const VarNode*, PrimExpr> rmap;
@@ -331,8 +331,8 @@ Stmt ApplyLoopAnnotations(const Stage& stage, const
std::unordered_map<IterVar,
Stmt body = tir::Substitute(op->body, rmap);
return AttrStmt(iter_var, "thread_extent", op->extent, body);
} else {
- return For(op->loop_var, op->min, op->extent,
IterVarTypeToForType(attr->iter_type),
- op->device_api, op->body);
+ return For(op->loop_var, op->min, op->extent,
IterVarTypeToForKind(attr->iter_type),
+ op->body, op->thread_binding, op->annotations);
}
}
return StmtMutator::VisitStmt_(op);
@@ -345,18 +345,18 @@ Stmt ApplyLoopAnnotations(const Stage& stage, const
std::unordered_map<IterVar,
const IterVar& actual = rebased.count(iter_var) ?
rebased.find(iter_var)->second : iter_var;
const VarNode* var = actual->var.get();
- ForType expected = IterVarTypeToForType(iter_var->iter_type);
+ ForKind expected = IterVarTypeToForKind(iter_var->iter_type);
IterVarAttr attr;
if (stage->iter_var_attrs.count(iter_var)) {
attr = stage->iter_var_attrs[iter_var];
- expected = IterVarTypeToForType(attr->iter_type);
+ expected = IterVarTypeToForKind(attr->iter_type);
}
PostOrderVisit(stmt, [&found, &var, &attr, &expected, &need_change](const
ObjectRef& node) {
if (const ForNode* op = node.as<ForNode>()) {
if (op->loop_var.get() == var) {
++found;
- need_change = expected != op->for_type || (attr.defined() &&
attr->bind_thread.defined());
+ need_change = expected != op->kind || (attr.defined() &&
attr->bind_thread.defined());
}
}
});
@@ -409,12 +409,13 @@ Stmt ApplyLoopOrder(const Stage& stage, const
std::unordered_map<IterVar, Range>
if (body_.same_as(op->body) && op->loop_var.get() == target->var.get())
return GetRef<Stmt>(op);
const Stmt& body = op->body.same_as(body_) ? op->body : body_;
- ForType for_type = IterVarTypeToForType(target->iter_type);
+ ForKind kind = IterVarTypeToForKind(target->iter_type);
if (stage->iter_var_attrs.count(target)) {
- for_type =
IterVarTypeToForType(stage->iter_var_attrs[target]->iter_type);
+ kind = IterVarTypeToForKind(stage->iter_var_attrs[target]->iter_type);
}
const Range& range = target->dom.defined() ? target->dom :
dom_map.find(target)->second;
- return For(target->var, range->min, range->extent, for_type,
DeviceAPI::None, body);
+ return For(target->var, range->min, range->extent, kind, body,
op->thread_binding,
+ op->annotations);
}
};
@@ -448,7 +449,7 @@ std::vector<IterVar> GatherLoopVars(Stmt stmt) {
if (const ForNode* op = node.as<ForNode>()) {
Var loop_var(op->loop_var);
Range dom = Range::FromMinExtent(op->min, op->extent);
- res_.push_back(IterVar(dom, loop_var,
ForTypeToIterVarType(op->for_type)));
+ res_.push_back(IterVar(dom, loop_var, ForKindToIterVarType(op->kind)));
}
});
std::reverse(res_.begin(), res_.end());
diff --git a/src/te/operation/op_utils.cc b/src/te/operation/op_utils.cc
index f1991c1..32ffccb 100644
--- a/src/te/operation/op_utils.cc
+++ b/src/te/operation/op_utils.cc
@@ -77,7 +77,7 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage&
stage,
var = Var(iv->var->name_hint + ".init", bind_iv->var.dtype());
}
- ForType for_type = ForType::Serial;
+ ForKind kind = ForKind::kSerial;
IterVarAttr it_attr;
if (stage->iter_var_attrs.count(iv)) {
it_attr = stage->iter_var_attrs[iv];
@@ -85,13 +85,13 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage&
stage,
if (it_attr.defined()) {
switch (it_attr->iter_type) {
case kUnrolled:
- for_type = ForType::Unrolled;
+ kind = ForKind::kUnrolled;
break;
case kVectorized:
- for_type = ForType::Vectorized;
+ kind = ForKind::kVectorized;
break;
case kParallelized:
- for_type = ForType::Parallel;
+ kind = ForKind::kParallel;
break;
case kDataPar:
break;
@@ -115,11 +115,11 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage&
stage,
nest[i + 1].emplace_back(LetStmt(var, cast(var.dtype(), dom->min),
no_op));
value_map[iv] = cast(var.dtype(), dom->min);
} else if (is_zero(dom->min)) {
- nest[i + 1].emplace_back(For(var, 0, dom->extent, for_type,
DeviceAPI::None, no_op));
+ nest[i + 1].emplace_back(For(var, 0, dom->extent, kind, no_op));
value_map[iv] = var;
} else {
Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.dtype());
- nest[i + 1].emplace_back(For(idx, 0, dom->extent, for_type,
DeviceAPI::None, no_op));
+ nest[i + 1].emplace_back(For(idx, 0, dom->extent, kind, no_op));
PrimExpr new_value = dom->min + idx;
value_map[iv] = new_value;
nest[i + 1].emplace_back(LetStmt(var, new_value, no_op));
@@ -243,33 +243,33 @@ Stmt Substitute(Stmt s, const std::unordered_map<IterVar,
PrimExpr>& value_map)
return tir::Substitute(s, init);
}
-IterVarType ForTypeToIterVarType(tir::ForType for_type) {
- switch (for_type) {
- case ForType::Serial:
+IterVarType ForKindToIterVarType(tir::ForKind kind) {
+ switch (kind) {
+ case ForKind::kSerial:
return kDataPar;
- case ForType::Parallel:
+ case ForKind::kParallel:
return kParallelized;
- case ForType::Vectorized:
+ case ForKind::kVectorized:
return kVectorized;
- case ForType::Unrolled:
+ case ForKind::kUnrolled:
return kUnrolled;
default:
return kDataPar;
}
}
-tir::ForType IterVarTypeToForType(IterVarType iter_type) {
+tir::ForKind IterVarTypeToForKind(IterVarType iter_type) {
switch (iter_type) {
case kDataPar:
- return ForType::Serial;
+ return ForKind::kSerial;
case kParallelized:
- return ForType::Parallel;
+ return ForKind::kParallel;
case kVectorized:
- return ForType::Vectorized;
+ return ForKind::kVectorized;
case kUnrolled:
- return ForType::Unrolled;
+ return ForKind::kUnrolled;
default:
- return ForType::Serial;
+ return ForKind::kSerial;
}
}
diff --git a/src/te/operation/op_utils.h b/src/te/operation/op_utils.h
index 16f7d96..e6bf2ca 100644
--- a/src/te/operation/op_utils.h
+++ b/src/te/operation/op_utils.h
@@ -88,16 +88,16 @@ PrimExpr ReplaceTensor(PrimExpr expr, const
std::unordered_map<Tensor, Tensor>&
Stmt Substitute(Stmt stmt, const std::unordered_map<IterVar, PrimExpr>&
value_map);
/*!
- * \brief Converts Halide ForType to its corresponding IterVarType
- * \param for_type The ForType to be converted
+ * \brief Converts Halide ForKind to its corresponding IterVarType
+ * \param kind The ForKind to be converted
*/
-IterVarType ForTypeToIterVarType(tir::ForType for_type);
+IterVarType ForKindToIterVarType(tir::ForKind kind);
/*!
- * \brief Converts IterVarType to its corresponding Halide ForType
+ * \brief Converts IterVarType to its corresponding Halide ForKind
* \param iter_type The IterVarType to be converted
*/
-tir::ForType IterVarTypeToForType(IterVarType iter_type);
+tir::ForKind IterVarTypeToForKind(IterVarType iter_type);
} // namespace te
} // namespace tvm
diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
index f81d72e..74d1a19 100644
--- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
+++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
@@ -968,7 +968,8 @@ class TensorCoreIRMutator : public StmtExprMutator {
scaled_extent_value = ori_extent_value / scale_factor;
}
PrimExpr scaled_extent = make_const(op->extent.dtype(),
scaled_extent_value);
- stmt = For(op->loop_var, op->min, scaled_extent, op->for_type,
op->device_api, op->body);
+ stmt = For(op->loop_var, op->min, scaled_extent, op->kind, op->body,
op->thread_binding,
+ op->annotations);
}
}
return stmt;
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index fd03046..92dc387 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -128,8 +128,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});
// For
-For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type,
DeviceAPI device_api,
- Stmt body, Span span) {
+For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
+ Optional<IterVar> thread_binding, Map<String, ObjectRef> annotations,
Span span) {
ICHECK(min.defined());
ICHECK(extent.defined());
ICHECK(min.dtype().is_scalar());
@@ -141,36 +141,40 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent,
ForType for_type, DeviceAP
node->loop_var = std::move(loop_var);
node->min = std::move(min);
node->extent = std::move(extent);
- node->for_type = for_type;
- node->device_api = device_api;
+ node->kind = kind;
node->body = std::move(body);
+ node->thread_binding = std::move(thread_binding);
+ node->annotations = std::move(annotations);
node->span = std::move(span);
data_ = std::move(node);
}
-TVM_REGISTER_GLOBAL("tir.For").set_body_typed([](Var loop_var, PrimExpr min,
PrimExpr extent,
- int for_type, int device_api,
Stmt body,
- Span span) {
- return For(loop_var, min, extent, static_cast<ForType>(for_type),
- static_cast<DeviceAPI>(device_api), body, span);
-});
+TVM_REGISTER_GLOBAL("tir.For").set_body_typed(
+ [](Var loop_var, PrimExpr min, PrimExpr extent, int kind, Stmt body,
+ Optional<IterVar> thread_binding, Optional<Map<String, ObjectRef>>
annotations, Span span) {
+ return For(loop_var, min, extent, static_cast<ForKind>(kind), body,
thread_binding,
+ annotations.value_or(Map<String, ObjectRef>()), span);
+ });
TVM_REGISTER_NODE_TYPE(ForNode);
-std::ostream& operator<<(std::ostream& out, ForType type) { // NOLINT(*)
+std::ostream& operator<<(std::ostream& out, ForKind type) { // NOLINT(*)
switch (type) {
- case ForType::Serial:
+ case ForKind::kSerial:
out << "for";
break;
- case ForType::Parallel:
+ case ForKind::kParallel:
out << "parallel";
break;
- case ForType::Unrolled:
+ case ForKind::kUnrolled:
out << "unrolled";
break;
- case ForType::Vectorized:
+ case ForKind::kVectorized:
out << "vectorized";
break;
+ case ForKind::kThreadBinding:
+ out << "launch_thread";
+ break;
}
return out;
}
@@ -179,7 +183,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ForNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ForNode*>(node.get());
p->PrintIndent();
- p->stream << op->for_type << " (" << op->loop_var << ", ";
+ p->stream << op->kind << " (" << op->loop_var << ", ";
p->Print(op->min);
p->stream << ", ";
p->Print(op->extent);
diff --git a/src/tir/transforms/combine_context_call.cc
b/src/tir/transforms/combine_context_call.cc
index 03a0d5e..4a39864 100644
--- a/src/tir/transforms/combine_context_call.cc
+++ b/src/tir/transforms/combine_context_call.cc
@@ -72,7 +72,7 @@ class ContextCallCombiner final : public StmtExprMutator {
}
Stmt VisitStmt_(const ForNode* op) final {
- if (op->for_type == ForType::Parallel) {
+ if (op->kind == ForKind::kParallel) {
// Map of comparison expression to variable
std::unordered_map<PrimExpr, Var, StructuralHash, StructuralEqual> temp;
std::swap(temp, ctx_map_);
diff --git a/src/tir/transforms/inject_double_buffer.cc
b/src/tir/transforms/inject_double_buffer.cc
index 22a6ca2..7a16c06 100644
--- a/src/tir/transforms/inject_double_buffer.cc
+++ b/src/tir/transforms/inject_double_buffer.cc
@@ -158,8 +158,7 @@ class DoubleBufferInjector : public StmtExprMutator {
vmap[old_loop->loop_var.get()] = outer_var * factor +
make_const(factor.dtype(), i);
loop_seq.emplace_back(Substitute(old_loop->body, vmap));
}
- Stmt loop = For(outer_var, zero, outer_ext, old_loop->for_type,
old_loop->device_api,
- SeqStmt::Flatten(loop_seq));
+ Stmt loop = For(outer_var, zero, outer_ext, old_loop->kind,
SeqStmt::Flatten(loop_seq));
// tail
std::vector<Stmt> tail_seq;
Stmt tail_body = StripDoubleBufferWrite()(old_loop->body);
diff --git a/src/tir/transforms/inject_prefetch.cc
b/src/tir/transforms/inject_prefetch.cc
index b5c4cf5..4ce9c76 100644
--- a/src/tir/transforms/inject_prefetch.cc
+++ b/src/tir/transforms/inject_prefetch.cc
@@ -71,11 +71,11 @@ class PrefetchInjector : public StmtMutator {
Stmt VisitStmt_(const ForNode* op) final {
auto& var = op->loop_var;
loop_nest_.push_back(var);
- if (op->for_type == ForType::Vectorized) {
+ if (op->kind == ForKind::kVectorized) {
vectorized_[var.get()] = IntSet::Interval(op->min, (op->min +
op->extent) - 1);
}
Stmt ret = StmtMutator::VisitStmt_(op);
- if (op->for_type == ForType::Vectorized) {
+ if (op->kind == ForKind::kVectorized) {
vectorized_.erase(var.get());
}
loop_nest_.pop_back();
diff --git a/src/tir/transforms/inject_virtual_thread.cc
b/src/tir/transforms/inject_virtual_thread.cc
index 5622d14..b24a0e9 100644
--- a/src/tir/transforms/inject_virtual_thread.cc
+++ b/src/tir/transforms/inject_virtual_thread.cc
@@ -303,7 +303,10 @@ class VTInjector : public StmtExprMutator {
if (extent.same_as(op->extent) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
- return For(op->loop_var, op->min, extent, op->for_type, op->device_api,
body);
+ auto n = CopyOnWrite(op);
+ n->extent = std::move(extent);
+ n->body = std::move(body);
+ return Stmt(n);
}
}
// IfThenElse
@@ -417,7 +420,7 @@ class VTInjector : public StmtExprMutator {
Map<Var, PrimExpr> values{{var_, idx}};
stmt = Substitute(stmt, values);
return For(idx, make_zero(idx.dtype()), make_const(idx.dtype(),
num_threads_),
- ForType::Serial, DeviceAPI::None, stmt);
+ ForKind::kSerial, stmt);
}
}
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index 033a2e0..cbae3f9 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -149,7 +149,8 @@ class IRConvertSSA final : public StmtExprMutator {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
scope_[v.get()].pop_back();
op = stmt.as<ForNode>();
- return For(new_var, op->min, op->extent, op->for_type, op->device_api,
op->body);
+ return For(new_var, op->min, op->extent, op->kind, op->body,
op->thread_binding,
+ op->annotations);
} else {
defined_.insert(v.get());
return StmtExprMutator::VisitStmt_(op);
diff --git a/src/tir/transforms/loop_partition.cc
b/src/tir/transforms/loop_partition.cc
index a104dbb..f1d816f 100644
--- a/src/tir/transforms/loop_partition.cc
+++ b/src/tir/transforms/loop_partition.cc
@@ -607,8 +607,8 @@ inline Stmt LoopPartitioner::MakeFor(const Object* node,
PrimExpr extent, Stmt b
// If the loop extent is 1, do not create the loop anymore
return Substitute(body, {{Var{for_node->loop_var},
make_const(DataType::Int(32), 0)}});
} else {
- return For(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent,
for_node->for_type,
- for_node->device_api, body);
+ ICHECK(for_node->kind != ForKind::kThreadBinding);
+ return For(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent,
for_node->kind, body);
}
}
diff --git a/src/tir/transforms/make_packed_api.cc
b/src/tir/transforms/make_packed_api.cc
index adbe78a..0946af6 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -46,9 +46,9 @@ class ReturnRewriter : public StmtMutator {
explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var),
ret_tcode_(ret_tcode) {}
Stmt VisitStmt_(const ForNode* node) override {
- if (node->for_type == ForType::Parallel) in_parallel_ += 1;
+ if (node->kind == ForKind::kParallel) in_parallel_ += 1;
Stmt ret = StmtMutator::VisitStmt_(node);
- if (node->for_type == ForType::Parallel) in_parallel_ -= 1;
+ if (node->kind == ForKind::kParallel) in_parallel_ -= 1;
return ret;
}
diff --git a/src/tir/transforms/narrow_datatype.cc
b/src/tir/transforms/narrow_datatype.cc
index 0b24895..dc34626 100644
--- a/src/tir/transforms/narrow_datatype.cc
+++ b/src/tir/transforms/narrow_datatype.cc
@@ -220,8 +220,8 @@ class DataTypeRewriter : public StmtExprMutator {
<< ", but get " << s->GetTypeKey();
PrimExpr e = VisitExpr(op->loop_var);
Var var = Downcast<Var>(e);
- return For(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent),
op->for_type,
- op->device_api, op->body);
+ return For(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent),
op->kind, op->body,
+ op->thread_binding, op->annotations);
}
Stmt VisitStmt_(const AttrStmtNode* op) final {
diff --git a/src/tir/transforms/storage_flatten.cc
b/src/tir/transforms/storage_flatten.cc
index d392866..43fc1f1 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -318,14 +318,14 @@ class StorageFlattener : public StmtExprMutator {
}
for (int i = starts; i >= 0; --i) {
if (i < starts) {
- stmt = For(vars[i], 0, op->bounds[i]->extent, ForType::Serial,
DeviceAPI::None, stmt);
+ stmt = For(vars[i], 0, op->bounds[i]->extent, ForKind::kSerial, stmt);
} else {
PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
PrimExpr address = Call(DataType::Handle(), builtin::address_of(),
{load});
PrimExpr prefetch = Call(op->buffer->dtype, builtin::prefetch(),
{address, 0, 3, 1});
stmt = Evaluate(prefetch);
PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1;
- stmt = For(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt);
+ stmt = For(vars[i], 0, extent, ForKind::kSerial, stmt);
}
}
return stmt;
diff --git a/src/tir/transforms/storage_rewrite.cc
b/src/tir/transforms/storage_rewrite.cc
index d4c5ca0..0b1429c 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -438,14 +438,14 @@ class StoragePlanRewriter : public StmtExprMutator {
}
}
Stmt VisitStmt_(const ForNode* op) final {
- ICHECK(op->for_type != ForType::Vectorized) << "VectorizeLoop before
LiftStorageAlloc";
+ ICHECK(op->kind != ForKind::kVectorized) << "VectorizeLoop before
LiftStorageAlloc";
// remake all the allocation at the attach scope.
if (attach_map_.count(op)) {
auto& svec = attach_map_[op];
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<ForNode>();
- return For(op->loop_var, op->min, op->extent, op->for_type,
op->device_api,
- MakeAttach(svec, op->body));
+ return For(op->loop_var, op->min, op->extent, op->kind, MakeAttach(svec,
op->body),
+ op->thread_binding, op->annotations);
} else {
return StmtExprMutator::VisitStmt_(op);
}
@@ -765,7 +765,7 @@ class StoragePlanRewriter : public StmtExprMutator {
}
} else if (s.stmt->IsInstance<ForNode>()) {
const auto* op = static_cast<const ForNode*>(s.stmt);
- if (op->for_type == ForType::Parallel) {
+ if (op->kind == ForKind::kParallel) {
if (thread_scope_ == nullptr || thread_scope_ == op) {
PlanNewScope(op);
}
diff --git a/src/tir/transforms/unroll_loop.cc
b/src/tir/transforms/unroll_loop.cc
index 71ad899..c6e0b5c 100644
--- a/src/tir/transforms/unroll_loop.cc
+++ b/src/tir/transforms/unroll_loop.cc
@@ -100,13 +100,13 @@ class LoopUnroller : public StmtExprMutator {
op = stmt.as<ForNode>();
int value = GetExtent(op);
// condition for auto unroll
- bool auto_unroll = (op->for_type == ForType::Serial && value >= 0 &&
normal_loop_depth_ == 0 &&
+ bool auto_unroll = (op->kind == ForKind::kSerial && value >= 0 &&
normal_loop_depth_ == 0 &&
unroll_depth_ <= auto_max_depth_);
auto_unroll =
auto_unroll && (value * step_count_ <= auto_max_step_ || value <=
auto_max_extent_);
- if (op->for_type == ForType::Unrolled) {
+ if (op->kind == ForKind::kUnrolled) {
ICHECK_GE(value, 0) << "Cannot unroll non-constant loop";
auto_unroll = true;
}
@@ -124,9 +124,9 @@ class LoopUnroller : public StmtExprMutator {
return Unroll(op);
} else {
if (auto_unroll) {
- if (op->for_type != ForType::Unrolled) {
- return For(op->loop_var, op->min, op->extent, ForType::Unrolled,
op->device_api,
- op->body);
+ if (op->kind != ForKind::kUnrolled) {
+ return For(op->loop_var, op->min, op->extent, ForKind::kUnrolled,
op->body,
+ op->thread_binding, op->annotations);
}
}
return stmt;
diff --git a/src/tir/transforms/vectorize_loop.cc
b/src/tir/transforms/vectorize_loop.cc
index 239f422..66f4ae3 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -352,7 +352,7 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
}
// For
Stmt VisitStmt_(const ForNode* op) final {
- if (op->for_type == ForType::Vectorized) {
+ if (op->kind == ForKind::kVectorized) {
LOG(WARNING) << "Detect vectorize inside vectorized loop, ignoring...";
}
ICHECK(is_zero(op->min));
@@ -365,7 +365,8 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
if (extent.same_as(op->extent) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
} else {
- return For(op->loop_var, op->min, extent, op->for_type, op->device_api,
body);
+ return For(op->loop_var, op->min, extent, op->kind, body,
op->thread_binding,
+ op->annotations);
}
}
// IfThenElse
@@ -436,7 +437,7 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
Var idx(var_->name_hint + ".s", var_->dtype);
Map<Var, PrimExpr> values{{var_, idx}};
stmt = Substitute(stmt, values);
- return For(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt);
+ return For(idx, 0, var_lanes_, ForKind::kSerial, stmt);
}
// ProducerStore
Stmt VisitStmt_(const ProducerStoreNode* op) final {
@@ -525,7 +526,7 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
class LoopVectorizer : public StmtMutator {
public:
Stmt VisitStmt_(const ForNode* op) final {
- if (op->for_type == ForType::Vectorized) {
+ if (op->kind == ForKind::kVectorized) {
ICHECK(is_zero(op->min));
auto* extent_as_int = op->extent.as<IntImmNode>();
if (!extent_as_int || extent_as_int->value < 1) {
@@ -545,8 +546,8 @@ class VectorizeSkipper : public StmtMutator {
Stmt VisitStmt_(const ForNode* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<ForNode>();
- if (op->for_type == ForType::Vectorized) {
- return For(op->loop_var, op->min, op->extent, ForType::Serial,
op->device_api, op->body);
+ if (op->kind == ForKind::kVectorized) {
+ return For(op->loop_var, op->min, op->extent, ForKind::kSerial,
op->body);
} else {
return stmt;
}
diff --git a/tests/python/unittest/test_arith_domain_touched.py
b/tests/python/unittest/test_arith_domain_touched.py
index ca5df4a..af06a03 100644
--- a/tests/python/unittest/test_arith_domain_touched.py
+++ b/tests/python/unittest/test_arith_domain_touched.py
@@ -31,14 +31,12 @@ def test_domain_touched():
i,
0,
n,
- 0,
- 0,
+ tvm.tir.ForKind.SERIAL,
tvm.tir.For(
j,
0,
m,
- 0,
- 0,
+ tvm.tir.ForKind.SERIAL,
tvm.tir.BufferStore(
a,
tvm.tir.BufferLoad(b, [i - 1, j + 1]) + tvm.tir.BufferLoad(a,
[i - 1, j - 1]),
diff --git a/tests/python/unittest/test_runtime_module_based_interface.py
b/tests/python/unittest/test_runtime_module_based_interface.py
index 64f87fb..51a5872 100644
--- a/tests/python/unittest/test_runtime_module_based_interface.py
+++ b/tests/python/unittest/test_runtime_module_based_interface.py
@@ -547,8 +547,7 @@ def test_multiple_imported_modules():
i,
0,
n - 1,
- 0,
- 0,
+ tvm.tir.ForKind.SERIAL,
tvm.tir.Store(Ab.data, tvm.tir.Load("float32", Ab.data, i) + 1, i
+ 1),
)
return tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", symbol)
diff --git a/tests/python/unittest/test_runtime_module_load.py
b/tests/python/unittest/test_runtime_module_load.py
index 7befed3..38800e8 100644
--- a/tests/python/unittest/test_runtime_module_load.py
+++ b/tests/python/unittest/test_runtime_module_load.py
@@ -55,7 +55,11 @@ def test_dso_module_load():
i = te.var("i")
# for i in 0 to n-1:
stmt = tvm.tir.For(
- i, 0, n - 1, 0, 0, tvm.tir.Store(Ab.data, tvm.tir.Load(dtype,
Ab.data, i) + 1, i + 1)
+ i,
+ 0,
+ n - 1,
+ tvm.tir.ForKind.SERIAL,
+ tvm.tir.Store(Ab.data, tvm.tir.Load(dtype, Ab.data, i) + 1, i + 1),
)
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "main")
diff --git a/tests/python/unittest/test_target_codegen_cuda.py
b/tests/python/unittest/test_target_codegen_cuda.py
index e877674..a22fe10 100644
--- a/tests/python/unittest/test_target_codegen_cuda.py
+++ b/tests/python/unittest/test_target_codegen_cuda.py
@@ -200,7 +200,7 @@ def test_cuda_shuffle():
def MyVectorize():
def vectorizer(op):
- if op.for_type == tvm.tir.For.Vectorized:
+ if op.kind == tvm.tir.ForKind.VECTORIZED:
four = tvm.tir.const(4, "int32")
idx = tvm.tir.Ramp(thrx.var * four, tvm.tir.const(1, "int32"),
4)
all_ones = tvm.tir.const(1, "int32x4")
diff --git a/tests/python/unittest/test_target_codegen_llvm.py
b/tests/python/unittest/test_target_codegen_llvm.py
index 4b67752..67c1f6b 100644
--- a/tests/python/unittest/test_target_codegen_llvm.py
+++ b/tests/python/unittest/test_target_codegen_llvm.py
@@ -761,7 +761,7 @@ def test_llvm_lower_atomic():
atomic_add_return = ib.allocate(A.dtype, (1,),
name="atomic_add_return", scope="local")
one = tvm.tir.const(1, A.dtype)
A_ptr = ib.buffer_ptr(A)
- with ib.for_range(0, n, name="i", for_type="parallel") as i:
+ with ib.for_range(0, n, name="i", kind="parallel") as i:
atomic_add_return[0] = atomic_add(
tvm.tir.call_intrin("handle", "tir.address_of", A_ptr[0]), one
)
diff --git a/tests/python/unittest/test_target_codegen_static_init.py
b/tests/python/unittest/test_target_codegen_static_init.py
index 179e302..b0c19df 100644
--- a/tests/python/unittest/test_target_codegen_static_init.py
+++ b/tests/python/unittest/test_target_codegen_static_init.py
@@ -30,7 +30,7 @@ def test_static_callback():
cp = te.thread_axis((0, 1), "cop")
finit = tvm.tir.StringImm("TVMBackendRunOnce")
ib.scope_attr(cp, "coproc_uop_scope", finit)
- with ib.for_range(0, n, "i", for_type="parallel") as i:
+ with ib.for_range(0, n, "i", kind="parallel") as i:
A[i] = A[i] + 1
stmt = ib.get()
diff --git a/tests/python/unittest/test_target_codegen_vm_basic.py b/tests/python/unittest/test_target_codegen_vm_basic.py
index 26f1493..9bbee76 100644
--- a/tests/python/unittest/test_target_codegen_vm_basic.py
+++ b/tests/python/unittest/test_target_codegen_vm_basic.py
@@ -109,7 +109,7 @@ def test_vm_parallel():
i = te.size_var("i")
ib = tvm.tir.ir_builder.create()
A = ib.buffer_ptr(Ab)
- with ib.for_range(0, n, "i", for_type="parallel") as i:
+ with ib.for_range(0, n, "i", kind="parallel") as i:
A[i] = A[i] + 1
stmt = ib.get()
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab],
stmt).with_attr("global_symbol", "test"))
diff --git a/tests/python/unittest/test_te_hybrid_script.py b/tests/python/unittest/test_te_hybrid_script.py
index 06d4099..be99565 100644
--- a/tests/python/unittest/test_te_hybrid_script.py
+++ b/tests/python/unittest/test_te_hybrid_script.py
@@ -267,9 +267,9 @@ def test_looptype():
iloop = ir[0]
jloop = ir[1]
kloop = ir[2]
- assert iloop.for_type == tvm.tir.For.Parallel
- assert jloop.for_type == tvm.tir.For.Vectorized
- assert kloop.for_type == tvm.tir.For.Unrolled
+ assert iloop.kind == tvm.tir.ForKind.PARALLEL
+ assert jloop.kind == tvm.tir.ForKind.VECTORIZED
+ assert kloop.kind == tvm.tir.ForKind.UNROLLED
func, ins, outs = run_and_check(looptype, [a, b, c])
run_and_check(func, ins, outs=outs)
diff --git a/tests/python/unittest/test_tir_constructor.py b/tests/python/unittest/test_tir_constructor.py
index 2bf4ba5..2cc21db 100644
--- a/tests/python/unittest/test_tir_constructor.py
+++ b/tests/python/unittest/test_tir_constructor.py
@@ -142,7 +142,7 @@ def test_stmt_constructor():
assert isinstance(x, tvm.tir.AssertStmt)
assert x.body == nop
- x = tvm.tir.For(te.var("x"), 0, 10, 0, 0, nop)
+ x = tvm.tir.For(te.var("x"), 0, 10, tvm.tir.ForKind.SERIAL, nop)
assert isinstance(x, tvm.tir.For)
assert x.min.value == 0
assert x.extent.value == 10
diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py
index 4d57ed8..bff60f7 100644
--- a/tests/python/unittest/test_tir_nodes.py
+++ b/tests/python/unittest/test_tir_nodes.py
@@ -129,7 +129,7 @@ def test_basic():
def test_stmt():
x = tvm.tir.Evaluate(0)
- tvm.tir.For(te.var("i"), 0, 1, tvm.tir.For.Serial, 0, x)
+ tvm.tir.For(te.var("i"), 0, 1, tvm.tir.ForKind.SERIAL, x)
def test_dir():
diff --git a/tests/python/unittest/test_tir_transform_remove_no_op.py b/tests/python/unittest/test_tir_transform_remove_no_op.py
index 2edb8cf..8b7a169 100644
--- a/tests/python/unittest/test_tir_transform_remove_no_op.py
+++ b/tests/python/unittest/test_tir_transform_remove_no_op.py
@@ -34,20 +34,17 @@ def test_remove_no_op():
i,
0,
4,
- 0,
- 0,
+ tvm.tir.ForKind.SERIAL,
tvm.tir.For(
j,
0,
n,
- 0,
- 0,
+ tvm.tir.ForKind.SERIAL,
tvm.tir.For(
k,
0,
m,
- 0,
- 0,
+ tvm.tir.ForKind.SERIAL,
tvm.tir.IfThenElse((i * m + j + k < n), tvm.tir.Evaluate(m),
tvm.tir.Evaluate(n)),
),
),
@@ -65,7 +62,7 @@ def test_remove_no_op():
assert ret == store
# remove zero extent loop
- stmt3 = tvm.tir.For(i, 0, 0, 0, 0, store)
+ stmt3 = tvm.tir.For(i, 0, 0, tvm.tir.ForKind.SERIAL, store)
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt3))
ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body
assert isinstance(ret, tvm.tir.Evaluate)
diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py
index cc2b427..49adcfb 100644
--- a/tests/python/unittest/test_tir_transform_storage_rewrite.py
+++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py
@@ -269,7 +269,7 @@ def test_storage_share_gpu():
def test_parallel_alloc():
ib = tvm.tir.ir_builder.create()
n = te.var("n")
- with ib.for_range(0, n, name="i", for_type="parallel") as i:
+ with ib.for_range(0, n, name="i", kind="parallel") as i:
with ib.for_range(0, 10, name="j") as j:
A = ib.allocate("float32", n, name="A", scope="global")
A[j] = A[j] + 2
@@ -286,7 +286,7 @@ def test_parallel_alloc():
ib.scope_attr(
tvm.tir.const(1, "int32"), "pragma_scope",
tvm.tir.StringImm("parallel_launch_point")
)
- with ib.for_range(0, n, name="i", for_type="parallel") as i:
+ with ib.for_range(0, n, name="i", kind="parallel") as i:
with ib.for_range(0, 10, name="j") as j:
A = ib.allocate("float32", n, name="A", scope="global")
A[j] = A[j] + 2
diff --git a/tests/python/unittest/test_tir_transform_unroll_loop.py b/tests/python/unittest/test_tir_transform_unroll_loop.py
index 57b7810..b511118 100644
--- a/tests/python/unittest/test_tir_transform_unroll_loop.py
+++ b/tests/python/unittest/test_tir_transform_unroll_loop.py
@@ -27,7 +27,7 @@ def test_unroll_loop():
Aptr = ib.buffer_ptr(Ab)
# for i in 0 to n-1:
with ib.for_range(n, n + 2, name="i") as i:
- with ib.for_range(0, 8, name="i", for_type="unroll") as j:
+ with ib.for_range(0, 8, name="i", kind="unroll") as j:
Aptr[j + 1] = Aptr[i] + 1
stmt = ib.get()
@@ -48,7 +48,7 @@ def test_unroll_loop():
):
ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
assert isinstance(ret, tvm.tir.For)
- assert ret.for_type == tvm.tir.For.Unrolled
+ assert ret.kind == tvm.tir.ForKind.UNROLLED
ib = tvm.tir.ir_builder.create()
ib.scope_attr(tvm.tir.const(0, "int32"), "pragma_auto_unroll_max_step", 16)
@@ -63,9 +63,9 @@ def test_unroll_loop():
):
ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body
assert isinstance(ret[0], tvm.tir.For)
- assert ret[0].for_type == tvm.tir.For.Unrolled
+ assert ret[0].kind == tvm.tir.ForKind.UNROLLED
assert isinstance(ret[1], tvm.tir.For)
- assert ret[1].for_type != tvm.tir.For.Unrolled
+ assert ret[1].kind != tvm.tir.ForKind.UNROLLED
def test_unroll_fake_loop():
diff --git a/tests/python/unittest/test_tir_transform_vectorize.py b/tests/python/unittest/test_tir_transform_vectorize.py
index 204e26f..5ae47e0 100644
--- a/tests/python/unittest/test_tir_transform_vectorize.py
+++ b/tests/python/unittest/test_tir_transform_vectorize.py
@@ -24,7 +24,7 @@ def test_vectorize_loop():
ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32", name="A")
with ib.for_range(0, n) as i:
- with ib.for_range(0, 4, for_type="vectorize") as j:
+ with ib.for_range(0, 4, kind="vectorize") as j:
A[j] = tvm.tir.const(1, A.dtype)
stmt = ib.get()
@@ -45,7 +45,7 @@ def test_vectorize_vector():
ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32x4", name="A")
with ib.for_range(0, n) as i:
- with ib.for_range(0, 4, for_type="vectorize") as j:
+ with ib.for_range(0, 4, kind="vectorize") as j:
A[j] = tvm.tir.const(1, A.dtype)
stmt = ib.get()
assert isinstance(stmt.body, tvm.tir.For)
@@ -64,7 +64,7 @@ def test_vectorize_with_if():
x = te.var("x")
ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32", name="A")
- with ib.for_range(0, 4, for_type="vectorize") as i:
+ with ib.for_range(0, 4, kind="vectorize") as i:
with ib.if_scope(x < n):
A[i] = A[i] + 1
with ib.else_scope():
@@ -86,7 +86,7 @@ def test_vectorize_let():
v = tvm.tir.Var("v", "float32")
ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32", name="A")
- with ib.for_range(0, 4, for_type="vectorize") as i:
+ with ib.for_range(0, 4, kind="vectorize") as i:
ib.emit(lambda body: tvm.tir.LetStmt(v, A[i] + 1, body))
A[i] = v + 2
@@ -100,7 +100,7 @@ def test_vectorize_with_le_cond():
n = te.var("n")
ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32", name="A")
- with ib.for_range(0, 4, for_type="vectorize") as i:
+ with ib.for_range(0, 4, kind="vectorize") as i:
with ib.if_scope(i <= n):
A[i] = A[i] + 1
stmt = ib.get()
@@ -115,7 +115,7 @@ def test_vectorize_with_ge_cond():
n = te.var("n")
ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32", name="A")
- with ib.for_range(0, 4, for_type="vectorize") as i:
+ with ib.for_range(0, 4, kind="vectorize") as i:
with ib.if_scope(i >= n):
A[i] = A[i] + 1
stmt = ib.get()
@@ -131,7 +131,7 @@ def test_vectorize_if_then_else():
x = te.var("x")
ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32", name="A")
- with ib.for_range(0, 4, for_type="vectorize") as i:
+ with ib.for_range(0, 4, kind="vectorize") as i:
A[i] = tvm.tir.call_intrin("float32", "tir.if_then_else", i > 0, A[i]
+ 1, A[i])
stmt = ib.get()
@@ -143,7 +143,7 @@ def test_vectorize_if_then_else():
ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32", name="A")
with ib.for_range(0, n) as k:
- with ib.for_range(0, 4, for_type="vectorize") as i:
+ with ib.for_range(0, 4, kind="vectorize") as i:
A[k * 4 + i] = tvm.tir.call_intrin(
"float32", "tir.if_then_else", k > 0, A[k * 4 + i], 0
)
diff --git a/tutorials/dev/low_level_custom_pass.py b/tutorials/dev/low_level_custom_pass.py
index 44fe59f..0bd656d 100644
--- a/tutorials/dev/low_level_custom_pass.py
+++ b/tutorials/dev/low_level_custom_pass.py
@@ -116,8 +116,8 @@ def vectorize8(op):
name = op.loop_var.name
lo, li = te.var(name + ".outer"), te.var(name + ".inner")
body = tvm.tir.stmt_functor.substitute(op.body, {op.loop_var: lo * 8 +
li})
- body = tvm.tir.For(li, 0, 8, tvm.tir.For.Vectorized, 0, body)
- body = tvm.tir.For(lo, 0, extent // 8, tvm.tir.For.Serial, 0, body)
+ body = tvm.tir.For(li, 0, 8, tvm.tir.ForKind.VECTORIZED, body)
+ body = tvm.tir.For(lo, 0, extent // 8, tvm.tir.ForKind.SERIAL, body)
return body
return None
diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py
index a485d2c..9770857 100644
--- a/vta/python/vta/transform.py
+++ b/vta/python/vta/transform.py
@@ -231,7 +231,13 @@ def LiftAllocToScopeBegin():
body = tvm.tir.AttrStmt(op.node, op.attr_key, op.value,
body)
elif isinstance(op, tvm.tir.For):
body = tvm.tir.For(
- op.loop_var, op.min, op.extent, op.for_type,
op.device_api, body
+ op.loop_var,
+ op.min,
+ op.extent,
+ op.kind,
+ body,
+ op.thread_binding,
+ op.annotations,
)
else:
raise RuntimeError("unexpected op")
@@ -314,7 +320,9 @@ def InjectCoProcSync():
if _match_pragma(stmt, "trim_loop"):
op = stmt.body
assert isinstance(op, tvm.tir.For)
- return tvm.tir.For(op.loop_var, op.min, 2, op.for_type,
op.device_api, op.body)
+ return tvm.tir.For(
+ op.loop_var, op.min, 2, op.kind, op.body,
op.thread_binding, op.annotations
+ )
return None
return f.with_body(