This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 13ea9dc104 [TIR] Add step attribute to ForNode (Initial codes) (#18421)
13ea9dc104 is described below
commit 13ea9dc10436836e9654a897cf6f8f87813dc8a4
Author: wrongtest <[email protected]>
AuthorDate: Mon Nov 24 21:30:16 2025 +0800
[TIR] Add step attribute to ForNode (Initial codes) (#18421)
An initial change to add `ForNode::step`.
- Add an `Optional<PrimExpr>`-typed `step` attribute to `ForNode`, then add minimal code for (see the sketch below):
    - Round-trip support in the TIR TVMScript grammar
    - Correctness of the TIR lowering pipeline:
        - Canonicalize loops in the default pipeline
        - Ensure the original `ForNode::step` is not dropped by mutations on `ForNode`.
        - CodeGen support for non-zero min and non-trivial step.
- Future TODOs (hopefully):
    - Adapt **all transformations and analysis tools** to non-consecutive loop iteration indices
    - Correctness of TensorIR schedule and MetaSchedule
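
A minimal sketch of the new syntax (illustrative only, not part of this patch; it uses the `range()` sugar and `step=` keyword introduced here, so it needs a build containing this change):

```python
import tvm
from tvm.script import tir as T


@T.prim_func
def strided(A: T.Buffer((1024,), "float32"), B: T.Buffer((1024,), "float32")):
    # Parsed via the range() sugar; the step must be a positive literal.
    for i in range(3, 1024, 96):
        B[i] = A[i] + T.float32(1.0)
    # Equivalent explicit form using the new step keyword.
    for j in T.serial(0, 1024, step=2):
        B[j] = A[j]


# The default pipeline first runs tir.transform.CanonicalizeLoop, which
# rewrites each loop to start at zero with a unit step; schematically the
# first loop becomes:
#   for i0 in range(T.ceildiv(1021, 96)):  # i == i0 * 96 + 3
mod = tvm.tir.transform.CanonicalizeLoop()(tvm.IRModule({"main": strided}))
```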
---------
Co-authored-by: baoxinqi <[email protected]>
---
include/tvm/script/ir_builder/tir/frame.h | 8 +-
include/tvm/script/ir_builder/tir/ir.h | 16 +++-
include/tvm/tir/stmt.h | 17 +++-
python/tvm/script/ir_builder/tir/ir.py | 44 +++++++--
python/tvm/script/parser/tir/parser.py | 27 +++++-
python/tvm/tir/ir_builder.py | 8 +-
python/tvm/tir/pipeline.py | 1 +
python/tvm/tir/stmt.py | 7 ++
python/tvm/tir/transform/transform.py | 11 +++
.../transform/lower_global_view_to_local_view.cc | 4 +-
src/script/ir_builder/tir/frame.cc | 2 +-
src/script/ir_builder/tir/ir.cc | 20 +++-
src/script/printer/tir/for_loop.cc | 15 ++-
src/target/llvm/codegen_cpu.cc | 16 ++--
src/target/llvm/codegen_llvm.cc | 8 +-
src/target/source/codegen_c.cc | 14 ++-
src/target/source/codegen_cuda.cc | 1 -
src/target/source/codegen_webgpu.cc | 14 ++-
src/target/spirv/codegen_spirv.cc | 23 +++--
src/tir/ir/data_type_rewriter.cc | 9 +-
src/tir/ir/stmt.cc | 30 ++++--
src/tir/ir/stmt_functor.cc | 11 ++-
src/tir/schedule/primitive/blockize_tensorize.cc | 2 +-
src/tir/schedule/primitive/decompose_padding.cc | 2 +-
src/tir/schedule/primitive/loop_transformation.cc | 4 +-
src/tir/schedule/primitive/reduction.cc | 13 ++-
src/tir/transforms/canonicalize_loop.cc | 102 +++++++++++++++++++++
src/tir/transforms/common_subexpr_elim.cc | 2 +-
src/tir/transforms/convert_for_loops_serial.cc | 2 +-
src/tir/transforms/inject_software_pipeline.cc | 2 +-
src/tir/transforms/ir_utils.cc | 6 +-
src/tir/transforms/lift_thread_binding.cc | 2 +-
src/tir/transforms/loop_partition.cc | 8 +-
src/tir/transforms/lower_cross_thread_reduction.cc | 4 +-
src/tir/transforms/lower_opaque_block.cc | 2 +-
src/tir/transforms/memhammer_coalesce.cc | 3 +-
src/tir/transforms/memhammer_tensorcore_rewrite.cc | 55 ++++++-----
src/tir/transforms/storage_rewrite.cc | 2 +-
src/tir/transforms/unify_thread_binding.cc | 6 +-
src/tir/transforms/unroll_loop.cc | 5 +-
src/tir/transforms/vectorize_loop.cc | 6 +-
tests/python/codegen/test_target_codegen.py | 44 ++++++++-
tests/python/codegen/test_target_codegen_cuda.py | 32 +++++++
tests/python/tir-base/test_tir_nodes.py | 1 +
.../test_tir_transform_canonicalize_loop.py | 88 ++++++++++++++++++
.../python/tvmscript/test_tvmscript_parser_tir.py | 26 ++++++
tests/python/tvmscript/test_tvmscript_roundtrip.py | 20 ++++
47 files changed, 619 insertions(+), 126 deletions(-)
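For reference, the canonicalization added in `canonicalize_loop.cc` maps `for (v = min; v < min + extent; v += step)` to `for (v' = 0; v' < ceildiv(extent, step); ++v')` with `v` rewritten to `v' * step + min`. A plain-Python sanity check of that index arithmetic (illustrative only, not part of the patch):

```python
import math

def canonical_indices(min_, extent, step):
    # Indices visited after canonicalization: v' in [0, ceildiv(extent, step)),
    # with the original variable recovered as v = v' * step + min_.
    return [v * step + min_ for v in range(math.ceil(extent / step))]

# Matches the original strided loop, e.g. T.serial(3, 1024, step=96)
# has min = 3 and extent = 1021:
assert canonical_indices(3, 1021, 96) == list(range(3, 1024, 96))
```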
diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h
index 827e4e0329..db5776890a 100644
--- a/include/tvm/script/ir_builder/tir/frame.h
+++ b/include/tvm/script/ir_builder/tir/frame.h
@@ -251,13 +251,15 @@ class ForFrameNode : public TIRFrameNode {
* \param loop_body The loop body
* \return A stmt, the loop nest
*/
- using FMakeForLoop =
- ffi::TypedFunction<tvm::tir::Stmt(ffi::Array<tvm::tir::Var> loop_vars,
- ffi::Array<Range> loop_extents, tvm::tir::Stmt loop_body)>;
+ using FMakeForLoop = ffi::TypedFunction<tvm::tir::Stmt(
+ ffi::Array<tvm::tir::Var> loop_vars, ffi::Array<Range> loop_extents,
+ ffi::Array<ffi::Optional<PrimExpr>> loop_steps, tvm::tir::Stmt loop_body)>;
/*! \brief The loop variable. */
ffi::Array<tvm::tir::Var> vars;
/*! \brief The domains of iteration. */
ffi::Array<Range> doms;
+ /*! \brief The optional steps of iteration. */
+ ffi::Array<ffi::Optional<PrimExpr>> steps;
/*! \brief The for loop generating function. */
FMakeForLoop f_make_for_loop;
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index 24ce8fdf99..07c7fe262b 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -228,37 +228,45 @@ ffi::Array<Var> Remap(ffi::String kinds, ffi::Array<PrimExpr> bindings,
* \param start The minimum value of iteration.
* \param stop The maximum value of iteration.
* \param annotations The optional annotations of the For statement.
+ * \param step The optional step value of iteration.
* \return The ForFrame.
*/
ForFrame Serial(PrimExpr start, PrimExpr stop,
- ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt);
+ ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
+ ffi::Optional<PrimExpr> step = std::nullopt);
/*!
* \brief The parallel For statement.
* \param start The minimum value of iteration.
* \param stop The maximum value of iteration.
* \param annotations The optional annotations of the For statement.
+ * \param step The optional step value of iteration.
* \return The ForFrame.
*/
ForFrame Parallel(PrimExpr start, PrimExpr stop,
- ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt);
+ ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
+ ffi::Optional<PrimExpr> step = std::nullopt);
/*!
* \brief The vectorized For statement.
* \param start The minimum value of iteration.
* \param stop The maximum value of iteration.
* \param annotations The optional annotations of the For statement.
+ * \param step The optional step value of iteration.
* \return The ForFrame.
*/
ForFrame Vectorized(PrimExpr start, PrimExpr stop,
- ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt);
+ ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
+ ffi::Optional<PrimExpr> step = std::nullopt);
/*!
* \brief The unrolled For statement.
* \param start The minimum value of iteration.
* \param stop The maximum value of iteration.
* \param annotations The optional annotations of the For statement.
+ * \param step The optional step value of iteration.
* \return The ForFrame.
*/
ForFrame Unroll(PrimExpr start, PrimExpr stop,
- ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt);
+ ffi::Optional<ffi::Map<ffi::String, Any>> annotations = std::nullopt,
+ ffi::Optional<PrimExpr> step = std::nullopt);
/*!
* \brief The thread-binding For statement.
* \param start The minimum value of iteration.
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 1b8041e36c..0831b84cf6 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -717,7 +717,7 @@ enum class ForKind : int {
*
* \code
*
- * for (loop_var = min; loop_var < min + extent; ++loop_var) {
+ * for (loop_var = min; loop_var < min + extent; loop_var += step) {
* // body
* }
* \endcode
@@ -748,6 +748,10 @@ class ForNode : public StmtNode {
* and can be ignored in most passes.
*/
ffi::Map<ffi::String, ffi::Any> annotations;
+ /*!
+ * \brief The loop step. It is one if not specified.
+ */
+ ffi::Optional<PrimExpr> step;
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
@@ -758,8 +762,13 @@ class ForNode : public StmtNode {
.def_ro("kind", &ForNode::kind)
.def_ro("body", &ForNode::body)
.def_ro("thread_binding", &ForNode::thread_binding)
- .def_ro("annotations", &ForNode::annotations);
+ .def_ro("annotations", &ForNode::annotations)
+ .def_ro("step", &ForNode::step);
}
+
+ /*! \brief Check whether this loop has a trivial (unspecified or unit) step. */
+ bool HasTrivialStep() const;
+
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.For", ForNode, StmtNode);
};
@@ -771,8 +780,8 @@ class For : public Stmt {
public:
TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
ffi::Optional<IterVar> thread_binding = std::nullopt,
- ffi::Map<ffi::String, ffi::Any> annotations = ffi::Map<ffi::String, ffi::Any>(),
- Span span = Span());
+ ffi::Map<ffi::String, ffi::Any> annotations = {},
+ ffi::Optional<PrimExpr> step = std::nullopt, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(For, Stmt, ForNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode);
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 6d746d73b1..31e48260f5 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -677,7 +677,11 @@ class axis: # pylint: disable=invalid-name
def serial(
- start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
+ start: PrimExpr,
+ stop: PrimExpr = None,
+ *,
+ annotations: Dict[str, Any] = None,
+ step: Optional[PrimExpr] = None,
) -> frame.ForFrame:
"""The serial For statement.
@@ -692,6 +696,9 @@ def serial(
annotations : Dict[str, Any]
The optional annotations of the For statement.
+ step : PrimExpr
+ The optional step value of iteration.
+
Returns
-------
res : frame.ForFrame
@@ -703,11 +710,15 @@ def serial(
start = IntImm(start.dtype, 0)
else:
start = 0
- return _ffi_api.Serial(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
+ return _ffi_api.Serial(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member
def parallel(
- start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
+ start: PrimExpr,
+ stop: PrimExpr = None,
+ *,
+ annotations: Dict[str, Any] = None,
+ step: Optional[PrimExpr] = None,
) -> frame.ForFrame:
"""The parallel For statement.
@@ -722,6 +733,9 @@ def parallel(
annotations : Dict[str, Any]
The optional annotations of the For statement.
+ step : PrimExpr
+ The optional step value of iteration.
+
Returns
-------
res : frame.ForFrame
@@ -733,11 +747,15 @@ def parallel(
start = IntImm(start.dtype, 0)
else:
start = 0
- return _ffi_api.Parallel(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
+ return _ffi_api.Parallel(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member
def vectorized(
- start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
+ start: PrimExpr,
+ stop: PrimExpr = None,
+ *,
+ annotations: Dict[str, Any] = None,
+ step: Optional[PrimExpr] = None,
) -> frame.ForFrame:
"""The vectorized For statement.
@@ -752,6 +770,9 @@ def vectorized(
annotations : Dict[str, Any]
The optional annotations of the For statement.
+ step : PrimExpr
+ The optional step value of iteration.
+
Returns
-------
res : frame.ForFrame
@@ -763,11 +784,15 @@ def vectorized(
start = IntImm(start.dtype, 0)
else:
start = 0
- return _ffi_api.Vectorized(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
+ return _ffi_api.Vectorized(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member
def unroll(
- start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
+ start: PrimExpr,
+ stop: PrimExpr = None,
+ *,
+ annotations: Dict[str, Any] = None,
+ step: Optional[PrimExpr] = None,
) -> frame.ForFrame:
"""The unrolled For statement.
@@ -782,6 +807,9 @@ def unroll(
annotations : Dict[str, Any]
The optional annotations of the For statement.
+ step : PrimExpr
+ The optional step value of iteration.
+
Returns
-------
res : frame.ForFrame
@@ -793,7 +821,7 @@ def unroll(
start = IntImm(start.dtype, 0)
else:
start = 0
- return _ffi_api.Unroll(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
+ return _ffi_api.Unroll(start, stop, annotations, step) # type: ignore[attr-defined] # pylint: disable=no-member
def thread_binding(
diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py
index 85ab1982f3..f8cbc0b4f5 100644
--- a/python/tvm/script/parser/tir/parser.py
+++ b/python/tvm/script/parser/tir/parser.py
@@ -18,7 +18,7 @@
import contextlib
from functools import partial
-from typing import Any
+from typing import Any, Dict, Optional
import tvm
from tvm.ir import GlobalVar, PrimType
@@ -168,6 +168,28 @@ def find_decorator_annotation(node: doc.FunctionDef, annotation: str, default: b
return default
+def range_sugar(
+ start: PrimExpr,
+ stop: PrimExpr = None,
+ step: Optional[PrimExpr] = None,
+ *,
+ annotations: Dict[str, Any] = None,
+) -> T.frame.ForFrame:
+ """The sugar for python range builtin."""
+
+ # Since `tir.For` does not support reversed iteration semantics, the
+ # step must be checked to be a positive integer when using the range sugar.
+ if step is not None:
+ try:
+ step = int(step)
+ if step <= 0:
+ raise ValueError(f"Only support positive step in range(), got {step}")
+ except TypeError: # pylint: disable=broad-except
+ raise ValueError(f"Only support literal step in range(), got {step}")
+
+ return T.serial(start, stop, annotations=annotations, step=step)
+
+
@dispatch.register(token="tir", type_name="For")
def visit_for(self: Parser, node: doc.For) -> None:
"""The for visiting method for tir.
@@ -379,7 +401,8 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
privacy = find_decorator_annotation(node, "private", default=False)
self.function_annotations = None
with self.var_table.with_frame():
- self.var_table.add("range", T.serial)
+
+ self.var_table.add("range", range_sugar)
with T.prim_func(is_private=privacy):
T.func_name(node.name)
if node.returns is not None:
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index a6313ae3bc..1e9cb07830 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -202,7 +202,7 @@ class IRBuilder(object):
value = op.max(1, value)
self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x))
- def for_range(self, begin, end, name="i", dtype=None, kind="serial"):
+ def for_range(self, begin, end, name="i", dtype=None, kind="serial", step=None):
"""Create a for iteration scope.
Parameters
@@ -223,6 +223,10 @@ class IRBuilder(object):
kind : str, optional
The special tag on the for loop.
+ step : PrimExpr
+ The loop step. Defaults to None, which
+ represents a step of one.
+
Returns
-------
loop_scope : With.Scope of Var
@@ -275,7 +279,7 @@ class IRBuilder(object):
kind_id = _stmt.ForKind.UNROLLED
else:
raise ValueError("Unknown kind")
- self.emit(_stmt.For(loop_var, begin, extent, kind_id, self._pop_seq()))
+ self.emit(_stmt.For(loop_var, begin, extent, kind_id, self._pop_seq(), step=step))
return WithScope(loop_var, _exit_cb)
diff --git a/python/tvm/tir/pipeline.py b/python/tvm/tir/pipeline.py
index 22cec30334..96ed9dfdbc 100644
--- a/python/tvm/tir/pipeline.py
+++ b/python/tvm/tir/pipeline.py
@@ -31,6 +31,7 @@ def default_tir_pipeline():
pass_ctx = tvm.transform.PassContext.current()
config = pass_ctx.config
passes = [
+ tir.transform.CanonicalizeLoop(),
tir.transform.LowerCrossThreadReduction(),
tir.transform.LowerInitBlock(),
tir.transform.PlanAndUpdateBufferAllocationLocation(),
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index bd90d52574..448ace3ade 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -145,6 +145,10 @@ class For(Stmt):
The thread this loop binds to. Only valid
if kind is ThreadBinding
+ step : PrimExpr
+ The loop step. Defaults to None, which
+ represents a step of one.
+
annotations: Optional[Mapping[str, Object]]
Additional annotation hints.
@@ -159,6 +163,7 @@ class For(Stmt):
body: Stmt
thread_binding: Optional[IterVar]
annotations: Mapping[str, Object]
+ step: Optional[PrimExpr]
span: Optional[Span]
def __init__(
@@ -170,6 +175,7 @@ class For(Stmt):
body: Stmt,
thread_binding: Optional[IterVar] = None,
annotations: Optional[Mapping[str, Object]] = None,
+ step: Optional[PrimExpr] = None,
span: Optional[Span] = None,
) -> None:
self.__init_handle_by_constructor__(
@@ -181,6 +187,7 @@ class For(Stmt):
body,
thread_binding,
annotations,
+ step,
span,
)
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 39105f21a2..88cf4720d3 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -1171,3 +1171,14 @@ def LowerVtcmAlloc():
The result pass
"""
return _ffi_api.LowerVtcmAlloc() # type: ignore
+
+
+def CanonicalizeLoop():
+ """Canonicalize the loop to start from zero and use trivial step
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.CanonicalizeLoop() # type: ignore
diff --git a/src/relax/distributed/transform/lower_global_view_to_local_view.cc b/src/relax/distributed/transform/lower_global_view_to_local_view.cc
index f83edb3e90..837f2f0a5d 100644
--- a/src/relax/distributed/transform/lower_global_view_to_local_view.cc
+++ b/src/relax/distributed/transform/lower_global_view_to_local_view.cc
@@ -330,8 +330,8 @@ class DistributedBufferCompactor : StmtExprMutator {
if (shard > 1) {
arith::Analyzer analyzer;
ICHECK(analyzer.CanProve(floormod(new_loop->extent, shard) == 0));
- return For(new_loop->loop_var, new_loop->min, floordiv(new_loop->extent, shard),
- new_loop->kind, new_loop->body, new_loop->thread_binding, new_loop->annotations);
+ new_loop.CopyOnWrite()->extent = floordiv(new_loop->extent, shard);
+ return new_loop;
+ new_loop.CopyOnWrite()->extent = floordiv(new_loop->extent, shard);
+ return new_loop;
}
}
return new_loop;
diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc
index 94eef40f59..7c10b6cdc8 100644
--- a/src/script/ir_builder/tir/frame.cc
+++ b/src/script/ir_builder/tir/frame.cc
@@ -123,7 +123,7 @@ void BlockInitFrameNode::ExitWithScope() {
void ForFrameNode::ExitWithScope() {
TIRFrameNode::ExitWithScope();
- AddToParent(this->f_make_for_loop(vars, doms, AsStmt(stmts)));
+ AddToParent(this->f_make_for_loop(vars, doms, steps, AsStmt(stmts)));
}
void AssertFrameNode::ExitWithScope() {
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index b981b90bd8..00f9c28475 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -362,19 +362,23 @@ ffi::Array<Var> Remap(ffi::String kinds, ffi::Array<PrimExpr> bindings, DataType
#define TVM_TIR_IR_BUILDER_FOR_FRAME(Method, Kind) \
ForFrame Method(PrimExpr start, PrimExpr stop, \
- ffi::Optional<ffi::Map<ffi::String, Any>> annotations) { \
+ ffi::Optional<ffi::Map<ffi::String, Any>> annotations, \
+ ffi::Optional<PrimExpr> step) { \
PrimExpr min = start; \
PrimExpr extent = arith::Analyzer().Simplify(stop - start); \
ObjectPtr<ForFrameNode> n = ffi::make_object<ForFrameNode>(); \
int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \
n->vars = {Var("v", DataType(min.dtype().code(), bits, 1))}; \
n->doms = {Range::FromMinExtent(min, extent)}; \
+ n->steps = {step}; \
n->f_make_for_loop = [annotations](ffi::Array<Var> vars, ffi::Array<Range> doms, \
+ ffi::Array<ffi::Optional<PrimExpr>> steps, \
tvm::tir::Stmt body) { \
ICHECK_EQ(vars.size(), 1); \
ICHECK_EQ(doms.size(), 1); \
+ ICHECK_EQ(steps.size(), 1); \
return tvm::tir::For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, std::nullopt, \
- annotations.value_or(ffi::Map<ffi::String, Any>())); \
+ annotations.value_or(ffi::Map<ffi::String, Any>()), steps[0]); \
}; \
return ForFrame(n); \
}
@@ -396,13 +400,16 @@ ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, ffi::String thread,
DataType dtype = DataType(min.dtype().code(), bits, 1);
n->vars = {Var("v", dtype)};
n->doms = {Range::FromMinExtent(min, extent)};
+ n->steps = {std::nullopt};
n->f_make_for_loop = [annotations, thread, dtype](ffi::Array<Var> vars, ffi::Array<Range> doms,
+ ffi::Array<ffi::Optional<PrimExpr>> steps,
Stmt body) -> For {
ICHECK_EQ(vars.size(), 1);
ICHECK_EQ(doms.size(), 1);
+ ICHECK(steps.size() == 1 && (!steps[0].has_value() || is_one(*steps[0])));
IterVar iter_var(Range(nullptr), Var("iter", dtype), IterVarType::kThreadIndex, thread);
return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kThreadBinding, body, iter_var,
- annotations.value_or(ffi::Map<ffi::String, ffi::Any>()));
+ annotations.value_or(ffi::Map<ffi::String, ffi::Any>()), std::nullopt);
};
return ForFrame(n);
}
@@ -412,19 +419,22 @@ ForFrame Grid(ffi::Array<PrimExpr> extents) {
ObjectPtr<ForFrameNode> n = ffi::make_object<ForFrameNode>();
n->vars.reserve(extents.size());
n->doms.reserve(extents.size());
+ n->steps.resize(extents.size());
for (const auto& extent : extents) {
DataType dtype = extent.dtype();
n->vars.push_back(Var("v", extent.dtype()));
n->doms.push_back(Range(make_const(dtype, 0), extent));
}
- n->f_make_for_loop = [](ffi::Array<Var> vars, ffi::Array<Range> doms, Stmt body) -> Stmt {
+ n->f_make_for_loop = [](ffi::Array<Var> vars, ffi::Array<Range> doms,
+ ffi::Array<ffi::Optional<PrimExpr>> steps, Stmt body) -> Stmt {
ICHECK_EQ(vars.size(), doms.size());
+ ICHECK_EQ(vars.size(), steps.size());
int n = vars.size();
for (int i = n - 1; i >= 0; --i) {
Range dom = doms[i];
Var var = vars[i];
body = For(var, dom->min, dom->extent, ForKind::kSerial, std::move(body),
- /*thread_binding=*/std::nullopt, /*annotations=*/{});
+ /*thread_binding=*/std::nullopt, /*annotations=*/{}, /*step=*/steps[i]);
}
return body;
};
diff --git a/src/script/printer/tir/for_loop.cc b/src/script/printer/tir/for_loop.cc
index 742d23f69c..b2e091f380 100644
--- a/src/script/printer/tir/for_loop.cc
+++ b/src/script/printer/tir/for_loop.cc
@@ -39,7 +39,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
if (l->kind != tir::ForKind::kSerial || //
!tir::is_zero(l->min) || //
!l->annotations.empty() || //
- f_var_dep(l->extent)) {
+ !l->HasTrivialStep() || f_var_dep(l->extent)) {
break;
}
grid.push_back(l);
@@ -69,7 +69,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
ffi::Optional<ExprDoc> max = std::nullopt;
ffi::Optional<ExprDoc> annotations = std::nullopt;
ffi::Optional<ExprDoc> thread = std::nullopt;
- if (tir::is_zero(loop->min)) {
+ if (tir::is_zero(loop->min) && loop->HasTrivialStep()) {
max = d->AsDoc<ExprDoc>(loop->extent, loop_p->Attr("extent"));
} else {
min = d->AsDoc<ExprDoc>(loop->min, loop_p->Attr("min"));
@@ -78,10 +78,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
if (!loop->annotations.empty()) {
annotations = d->AsDoc<ExprDoc>(loop->annotations, loop_p->Attr("annotations"));
}
+ bool use_range_sugar = false;
ExprDoc prefix{ffi::UnsafeInit()};
if (loop->kind == tir::ForKind::kSerial) {
if (loop->annotations.empty()) {
prefix = IdDoc("range");
+ use_range_sugar = true;
} else {
prefix = TIR(d, "serial");
}
@@ -115,6 +117,15 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
kwargs_keys.push_back("annotations");
kwargs_values.push_back(annotations.value());
}
+ if (!loop->HasTrivialStep()) {
+ ExprDoc step = d->AsDoc<ExprDoc>(*loop->step, loop_p->Attr("step"));
+ if (use_range_sugar) {
+ args.push_back(step);
+ } else {
+ kwargs_keys.push_back("step");
+ kwargs_values.push_back(step);
+ }
+ }
ExprDoc rhs = prefix->Call(args, kwargs_keys, kwargs_values);
AsDocBody(loop->body, loop_p->Attr("body"), (*f).get(), d);
return ForDoc(lhs, rhs, (*f)->stmts);
diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index d9ee972321..bc67cdad2f 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -1152,14 +1152,15 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) {
void CodeGenCPU::VisitStmt_(const ForNode* op) {
EmitDebugLocation(op);
- ICHECK(is_zero(op->min));
if (op->kind == ForKind::kSerial || op->kind == ForKind::kUnrolled) {
CodeGenLLVM::VisitStmt_(op);
} else if (op->kind == ForKind::kParallel) {
+ ICHECK(is_zero(op->min)) << "Parallel launch requires a canonical loop with zero start index";
+ ICHECK(op->HasTrivialStep()) << "Parallel launch requires a canonical loop with trivial loop step";
if (parallel_env_.penv == nullptr) {
- CreateParallelLaunch(For(op->loop_var, op->min, op->extent, op->kind, op->body,
- op->thread_binding, op->annotations),
- 0, std::string("loop_parallel_") + op->loop_var->name_hint.c_str());
+ auto copy_node = For(ffi::make_object<ForNode>(*op));
+ CreateParallelLaunch(copy_node, 0,
+ std::string("loop_parallel_") + op->loop_var->name_hint.c_str());
} else {
// already in parallel env.
ICHECK(parallel_env_.task_id.defined());
@@ -1171,13 +1172,14 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) {
ICHECK(!parallel_env_.in_parallel_loop)
<< "Nested parallel loop is not supported by threadpool, try fuse
them instead";
parallel_env_.in_parallel_loop = true;
+ PrimExpr end = is_zero(op->min) ? op->extent : analyzer_->Simplify(op->min + op->extent);
if (parallel_env_.stride_pattern) {
- CreateSerialFor(MakeValue(task_id), MakeValue(op->extent), MakeValue(num_task),
- op->loop_var, op->body);
+ CreateSerialFor(MakeValue(task_id), MakeValue(end), MakeValue(num_task), op->loop_var,
+ op->body);
} else {
PrimExpr step = (op->extent + num_task - make_const(t, 1)) / num_task;
PrimExpr begin = min(task_id * step, op->extent);
- PrimExpr end = min((task_id + make_const(t, 1)) * step, op->extent);
+ end = min((task_id + make_const(t, 1)) * step, end);
CreateSerialFor(MakeValue(begin), MakeValue(end),
llvm::ConstantInt::getSigned(GetLLVMType(end), 1), op->loop_var, op->body);
}
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 5f8b599a3b..131c8212c5 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -2023,7 +2023,6 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) {
void CodeGenLLVM::VisitStmt_(const ForNode* op) {
EmitDebugLocation(op);
- ICHECK(is_zero(op->min));
analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
if (op->kind == ForKind::kUnrolled) {
LOG(WARNING) << "Unroll hint get ignore at CodeGenLLVM backend, "
@@ -2031,8 +2030,11 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) {
} else {
ICHECK(op->kind == ForKind::kSerial);
}
- CreateSerialFor(MakeValue(op->min), MakeValue(op->extent),
- llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body);
+ PrimExpr step = op->step.value_or(make_const(op->extent->dtype, 1));
+ PrimExpr end = is_zero(op->min) ? op->extent : analyzer_->Simplify(op->min + op->extent);
+ llvm::Value* begin_value = MakeValue(op->min);
+ llvm::Value* end_value = MakeValue(end);
+ CreateSerialFor(begin_value, end_value, MakeValue(step), op->loop_var, op->body);
}
void CodeGenLLVM::VisitStmt_(const WhileNode* op) {
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 8ebd41645a..52ad781669 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -1120,13 +1120,21 @@ void CodeGenC::VisitStmt_(const AssertStmtNode* op) {
}
void CodeGenC::VisitStmt_(const ForNode* op) {
- std::string extent = PrintExpr(op->extent);
+ std::string begin_str = PrintExpr(op->min);
+ PrimExpr end = is_zero(op->min) ? op->extent : arith::Analyzer().Simplify(op->min + op->extent);
+ std::string end_str = PrintExpr(end);
+ std::string step_str = op->step.has_value() ? PrintExpr(*op->step) : "";
PrintIndent();
std::string vid = AllocVarID(op->loop_var.get());
- ICHECK(is_zero(op->min));
stream << "for (";
PrintType(op->loop_var.dtype(), stream);
- stream << ' ' << vid << " = 0; " << vid << " < " << extent << "; ++" << vid << ") {\n";
+ stream << ' ' << vid << " = " << begin_str << "; " << vid << " < " << end_str << "; ";
+ if (step_str.empty()) {
+ stream << "++" << vid;
+ } else {
+ stream << vid << " += " << step_str;
+ }
+ stream << ") {\n";
int for_scope = BeginScope();
PrintStmt(op->body);
this->EndScope(for_scope);
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 9565eba5d4..a9cfad9ab6 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -319,7 +319,6 @@ std::string CodeGenCUDA::Finish() {
}
void CodeGenCUDA::VisitStmt_(const tir::ForNode* op) {
- ICHECK(is_const_int(op->min, 0));
if (op->kind == tir::ForKind::kUnrolled) {
PrintIndent();
stream << "#pragma unroll\n";
diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc
index 330a54563f..cf8176001a 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -667,13 +667,21 @@ void CodeGenWebGPU::VisitStmt_(const AllocateNode* op) {
}
void CodeGenWebGPU::VisitStmt_(const ForNode* op) {
- std::string extent = PrintExpr(op->extent);
+ std::string begin_str = PrintExpr(op->min);
+ PrimExpr end = is_zero(op->min) ? op->extent : arith::Analyzer().Simplify(op->min + op->extent);
+ std::string end_str = PrintExpr(end);
+ std::string step_str = op->step.has_value() ? PrintExpr(*op->step) : "";
std::string vid = AllocVarID(op->loop_var.get());
- ICHECK(is_zero(op->min));
PrintIndent();
stream << "for (var " << vid << " : ";
PrintType(op->loop_var.dtype(), stream);
- stream << " = 0; " << vid << " < " << extent << "; " << vid << "++) {\n";
+ stream << " = " << begin_str << "; " << vid << " < " << end_str << "; " <<
vid;
+ if (step_str.empty()) {
+ stream << "++";
+ } else {
+ stream << " += " << step_str;
+ }
+ stream << ") {\n";
int for_scope = BeginScope();
PrintStmt(op->body);
this->EndScope(for_scope);
diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index c062926cc2..136f969896 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -672,10 +672,21 @@ void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) {
}
void CodeGenSPIRV::VisitStmt_(const ForNode* op) {
- ICHECK(is_zero(op->min));
analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
spirv::Value init_value = MakeValue(op->min);
- spirv::Value extent_value = MakeValue(op->extent);
+ PrimExpr end = is_zero(op->min) ? op->extent : analyzer_->Simplify(op->min + op->extent);
+ spirv::Value end_value = MakeValue(end);
+ spirv::PhiValue loop_var = builder_->MakePhi(init_value.stype, 2);
+
+ // loop step
+ spirv::Value step;
+ if (op->HasTrivialStep()) {
+ step = op->loop_var.dtype().is_int() ? builder_->IntImm(loop_var.stype, 1)
+ : builder_->UIntImm(loop_var.stype, 1);
+ } else {
+ step = MakeValue(tvm::cast(end->dtype, *op->step));
+ }
+
// Must get init label after making value(to make sure they are correct)
spirv::Label init_label = builder_->CurrentLabel();
spirv::Label head_label = builder_->NewLabel();
@@ -690,9 +701,8 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) {
// Loop head
builder_->StartLabel(head_label);
- spirv::PhiValue loop_var = builder_->MakePhi(init_value.stype, 2);
loop_var.SetIncoming(0, init_value, init_label);
- spirv::Value loop_cond = builder_->LT(loop_var, extent_value);
+ spirv::Value loop_cond = builder_->LT(loop_var, end_value);
uint32_t control =
(op->kind == ForKind::kUnrolled ? spv::LoopControlUnrollMask : spv::LoopControlMaskNone);
builder_->MakeInst(spv::OpLoopMerge, merge_label, continue_label, control);
@@ -707,9 +717,8 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) {
// loop continue
builder_->StartLabel(continue_label);
- spirv::Value one = op->loop_var.dtype().is_int() ? builder_->IntImm(loop_var.stype, 1)
- : builder_->UIntImm(loop_var.stype, 1);
- spirv::Value next_value = builder_->Add(loop_var, one);
+
+ spirv::Value next_value = builder_->Add(loop_var, step);
loop_var.SetIncoming(1, next_value, builder_->CurrentLabel());
builder_->MakeInst(spv::OpBranch, head_label);
// loop merge
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index d6dcae6540..393ac7ee57 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -41,8 +41,13 @@ Stmt DataTypeLegalizer::VisitStmt_(const ForNode* op) {
ICHECK(op != nullptr) << "Expected type to be ForNode, but get " << s->GetTypeKey();
PrimExpr e = VisitExpr(op->loop_var);
Var var = Downcast<Var>(e);
- return For(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), op->kind, op->body,
- op->thread_binding, op->annotations);
+ auto n = CopyOnWrite(op);
+ n->min = cast(var.dtype(), op->min);
+ n->extent = cast(var.dtype(), op->extent);
+ if (op->step.has_value()) {
+ n->step = cast(var.dtype(), *op->step);
+ }
+ return For(n);
}
Stmt DataTypeLegalizer::VisitStmt_(const BlockRealizeNode* op) {
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 47622757e5..b7e28e84e7 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -132,7 +132,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
// For
For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
- ffi::Optional<IterVar> thread_binding, ffi::Map<ffi::String, Any> annotations, Span span) {
+ ffi::Optional<IterVar> thread_binding, ffi::Map<ffi::String, Any> annotations,
+ ffi::Optional<PrimExpr> step, Span span) {
ICHECK(loop_var.defined());
ICHECK(min.defined());
ICHECK(extent.defined());
@@ -148,8 +149,8 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
require_scalar_int_dtype(min, "min");
require_scalar_int_dtype(extent, "extent");
- // When extent or min is an IntImm but has narrower dtype than loop_var, we directly promote them
- // without raising errors.
+ // When extent, min or step is an IntImm but has narrower dtype than loop_var
+ // we directly promote them without raising errors.
auto try_promote_imm_dtype = [&](const PrimExpr& e) {
ICHECK(e.dtype().bits() <= loop_var.dtype().bits())
<< " Loop variable's dtype (" << loop_var.dtype()
@@ -168,6 +169,12 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
ICHECK(loop_var.dtype() == min.dtype()) << loop_var.dtype() << " vs " << min.dtype();
ICHECK(loop_var.dtype() == extent.dtype()) << loop_var.dtype() << " vs " << extent.dtype();
+ if (step.has_value()) {
+ require_scalar_int_dtype(*step, "step");
+ step = try_promote_imm_dtype(*step);
+ ICHECK(loop_var.dtype() == (*step).dtype()) << loop_var.dtype() << " vs " << (*step).dtype();
+ }
+
ObjectPtr<ForNode> node = ffi::make_object<ForNode>();
node->loop_var = std::move(loop_var);
node->min = std::move(min);
@@ -176,19 +183,22 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
node->body = std::move(body);
node->thread_binding = std::move(thread_binding);
node->annotations = std::move(annotations);
+ node->step = std::move(step);
node->span = std::move(span);
data_ = std::move(node);
}
+bool ForNode::HasTrivialStep() const { return !step.has_value() || is_one(*step); }
+
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def(
- "tir.For", [](Var loop_var, PrimExpr min, PrimExpr extent, int kind, Stmt body,
- ffi::Optional<IterVar> thread_binding,
- ffi::Optional<ffi::Map<ffi::String, Any>> annotations, Span span) {
- return For(loop_var, min, extent, static_cast<ForKind>(kind), body, thread_binding,
- annotations.value_or(ffi::Map<ffi::String, Any>()), span);
- });
+ refl::GlobalDef().def("tir.For", [](Var loop_var, PrimExpr min, PrimExpr extent, int kind,
+ Stmt body, ffi::Optional<IterVar> thread_binding,
+ ffi::Optional<ffi::Map<ffi::String, Any>> annotations,
+ ffi::Optional<PrimExpr> step, Span span) {
+ return For(loop_var, min, extent, static_cast<ForKind>(kind), body, thread_binding,
+ annotations.value_or(ffi::Map<ffi::String, Any>()), step, span);
+ });
}
std::ostream& operator<<(std::ostream& out, ForKind type) { // NOLINT(*)
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index 80c787b114..e6666cc638 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -46,6 +46,9 @@ void StmtVisitor::VisitStmt_(const AttrStmtNode* op) {
void StmtVisitor::VisitStmt_(const ForNode* op) {
this->VisitExpr(op->min);
this->VisitExpr(op->extent);
+ if (op->step.has_value()) {
+ this->VisitExpr(*op->step);
+ }
this->VisitStmt(op->body);
}
@@ -260,13 +263,19 @@ Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) {
Stmt StmtMutator::VisitStmt_(const ForNode* op) {
PrimExpr min = this->VisitExpr(op->min);
PrimExpr extent = this->VisitExpr(op->extent);
+ ffi::Optional<PrimExpr> step{std::nullopt};
+ if (op->step.has_value()) {
+ step = this->VisitExpr(*op->step);
+ }
Stmt body = this->VisitStmt(op->body);
- if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body)) {
+ if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body) &&
+ step.same_as(op->step)) {
return ffi::GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->min = std::move(min);
n->extent = std::move(extent);
+ n->step = std::move(step);
n->body = std::move(body);
return Stmt(n);
}
diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc
index fbc569ece6..2ae32ea66a 100644
--- a/src/tir/schedule/primitive/blockize_tensorize.cc
+++ b/src/tir/schedule/primitive/blockize_tensorize.cc
@@ -703,7 +703,7 @@ class BlockizeRewriter : public StmtMutator {
Stmt VisitStmt_(const ForNode* loop) final {
if (loop == lca_->stmt) {
return For(loop->loop_var, loop->min, loop->extent, loop->kind, RewriteSeq(loop->body),
- loop->thread_binding, loop->annotations, loop->span);
+ loop->thread_binding, loop->annotations, loop->step, loop->span);
}
return StmtMutator::VisitStmt_(loop);
}
diff --git a/src/tir/schedule/primitive/decompose_padding.cc b/src/tir/schedule/primitive/decompose_padding.cc
index 5499ab9c58..7e61fd4eb2 100644
--- a/src/tir/schedule/primitive/decompose_padding.cc
+++ b/src/tir/schedule/primitive/decompose_padding.cc
@@ -343,7 +343,7 @@ static std::pair<Stmt, BlockRealize> CreateInBoundBlock(const BlockRealizeNode*
PrimExpr min = it == new_loop_ranges.end() ? loop->min : (*it).second->min;
PrimExpr extent = it == new_loop_ranges.end() ? loop->extent : (*it).second->extent;
nest_stmt_root = For(loop->loop_var, min, extent, loop->kind, nest_stmt_root,
- loop->thread_binding, loop->annotations, loop->span);
+ loop->thread_binding, loop->annotations, loop->step, loop->span);
if (loop.same_as(highest_pos_inclusive)) {
break;
}
diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc
index b2c64e65e5..3cd364b0fd 100644
--- a/src/tir/schedule/primitive/loop_transformation.cc
+++ b/src/tir/schedule/primitive/loop_transformation.cc
@@ -1137,8 +1137,8 @@ void Reorder(ScheduleState self, const ffi::Array<StmtSRef>& ordered_loop_srefs)
StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref) {
if (sref->stmt->IsInstance<ForNode>()) {
- For new_loop(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial,
- ffi::GetRef<Stmt>(sref->stmt));
+ For new_loop =
+ For(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial,
ffi::GetRef<Stmt>(sref->stmt));
self->Replace(sref, new_loop, {});
return self->stmt2ref.at(new_loop.get());
}
diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc
index 49dc31e6f6..0629757a13 100644
--- a/src/tir/schedule/primitive/reduction.cc
+++ b/src/tir/schedule/primitive/reduction.cc
@@ -268,7 +268,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
std::unordered_map<Var, Var> loop_var_map;
Stmt body = BlockRealize(init_realize);
for (int i : chosen_loops) {
- const ForNode* old_loop = TVM_SREF_TO_FOR(loops[i]);
+ For old_loop = ffi::GetRef<For>(TVM_SREF_TO_FOR(loops[i]));
// Create a new equivalent to the chosen loop
Var old_loop_var = old_loop->loop_var;
Var new_loop_var = old_loop_var.copy_with_suffix("_init");
@@ -280,12 +280,11 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
thread_binding.CopyOnWrite()->var = new_var;
opt_thread_binding = thread_binding;
}
- body = For(/*loop_var=*/new_loop_var,
- /*min=*/old_loop->min,
- /*extent=*/old_loop->extent,
- /*kind=*/old_loop->kind,
- /*body=*/body,
- /*thread_binding=*/opt_thread_binding);
+ auto new_loop = old_loop.CopyOnWrite();
+ new_loop->loop_var = new_loop_var;
+ new_loop->thread_binding = opt_thread_binding;
+ new_loop->body = body;
+ body = ffi::GetRef<For>(new_loop);
}
body = Substitute(body, loop_var_map);
// Step 6. Mutate IR
diff --git a/src/tir/transforms/canonicalize_loop.cc b/src/tir/transforms/canonicalize_loop.cc
new file mode 100644
index 0000000000..93511bf84b
--- /dev/null
+++ b/src/tir/transforms/canonicalize_loop.cc
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/transforms/canonicalize_loop.cc
+ * \brief Canonicalize all loops to start at zero with a unit step.
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/ffi/reflection/registry.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <utility>
+
+namespace tvm {
+namespace tir {
+
+class LoopCanonicalizer : public StmtExprMutator {
+ public:
+ LoopCanonicalizer() = default;
+
+ private:
+ Stmt VisitStmt_(const ForNode* op) final {
+ if (is_zero(op->min) && op->HasTrivialStep()) {
+ return StmtExprMutator::VisitStmt_(op);
+ }
+ arith::Analyzer analyzer;
+ const auto* loop_var = op->loop_var.get();
+ PrimExpr step = op->step.value_or(make_const(loop_var->dtype, 1));
+
+ // Error out for a step that cannot be proven positive, since a
+ // non-positive step could yield an infinite loop.
+ if (!analyzer.CanProveGreaterEqual(step, 1)) {
+ // TODO(tvm): prove dynamic shaped step
+ LOG(FATAL) << "Loop step for " << op->loop_var << " may not be positive: " << step;
+ }
+
+ new_iter_info_[loop_var] = std::make_pair(step, op->min);
+ auto n = CopyOnWrite(op);
+ n->body = VisitStmt(op->body);
+ n->min = make_zero(loop_var->dtype);
+ n->extent = analyzer.Simplify(ceildiv(op->extent, step));
+ n->step = std::nullopt;
+ new_iter_info_.erase(loop_var);
+ return For(n);
+ }
+
+ PrimExpr VisitExpr_(const VarNode* op) final {
+ auto it = new_iter_info_.find(op);
+ if (it != new_iter_info_.end()) {
+ const auto& [stride, offset] = it->second;
+ return ffi::GetRef<Var>(op) * stride + offset;
+ }
+ return ffi::GetRef<Var>(op);
+ }
+
+ /*! \brief Map iter variable `x` to `x * stride + offset`. */
+ std::unordered_map<const VarNode*, std::pair<PrimExpr, PrimExpr>> new_iter_info_;
+};
+
+PrimFunc CanonicalizeLoop(PrimFunc func) {
+ PrimFuncNode* fptr = func.CopyOnWrite();
+ fptr->body = LoopCanonicalizer()(func->body);
+ return func;
+}
+
+namespace transform {
+
+Pass CanonicalizeLoop() {
+ auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+ return CanonicalizeLoop(std::move(f));
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.CanonicalizeLoop", {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+ namespace refl = tvm::ffi::reflection;
+ refl::GlobalDef().def("tir.transform.CanonicalizeLoop", CanonicalizeLoop);
+}
+
+} // namespace transform
+
+} // namespace tir
+} // namespace tvm
diff --git a/src/tir/transforms/common_subexpr_elim.cc b/src/tir/transforms/common_subexpr_elim.cc
index dfeb7fe2e2..9b9619fae9 100644
--- a/src/tir/transforms/common_subexpr_elim.cc
+++ b/src/tir/transforms/common_subexpr_elim.cc
@@ -602,7 +602,7 @@ Stmt CommonSubexpressionEliminator::VisitStmt_(const ForNode* op) {
// Otherwise return a for node built with the new `min_new`, `extent_new` and `body_new`
// that have just been obtained
return For(op->loop_var, min_new, extent_new, op->kind, body_new, op->thread_binding,
- op->annotations, op->span);
+ op->annotations, op->step, op->span);
}
}
diff --git a/src/tir/transforms/convert_for_loops_serial.cc b/src/tir/transforms/convert_for_loops_serial.cc
index a8b30ebf91..691d8b885c 100644
--- a/src/tir/transforms/convert_for_loops_serial.cc
+++ b/src/tir/transforms/convert_for_loops_serial.cc
@@ -43,7 +43,7 @@ class ForLoopSerialConverter : public StmtExprMutator {
Stmt ForLoopSerialConverter::VisitStmt_(const ForNode* op) {
if (op->kind == ForKind::kParallel) {
return For(op->loop_var, op->min, op->extent, ForKind::kSerial, op->body, op->thread_binding,
- op->annotations, op->span);
+ op->annotations, op->step, op->span);
}
return StmtExprMutator::VisitStmt_(op);
}
diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc
index af1b7c8bdf..f4258fc479 100644
--- a/src/tir/transforms/inject_software_pipeline.cc
+++ b/src/tir/transforms/inject_software_pipeline.cc
@@ -943,7 +943,7 @@ class PipelineRewriter : public StmtExprMutator {
if (!is_unit_loop) {
new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, std::move(new_loop),
- std::nullopt, preserved_annotations_);
+ std::nullopt, preserved_annotations_, std::nullopt);
}
// Update producer heads in the global async states.
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index dba13cfbbc..8bcb2077c6 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -362,9 +362,9 @@ class IRConvertSSA final : public StmtExprMutator {
if (defined_.count(v.get())) {
ScopedRedefine redefine(this, v);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<ForNode>();
- return For(redefine.new_var, op->min, op->extent, op->kind, op->body, op->thread_binding,
- op->annotations);
+ auto n = ffi::make_object<ForNode>(*stmt.as<ForNode>());
+ n->loop_var = redefine.new_var;
+ return For(n);
} else {
defined_.insert(v.get());
return StmtExprMutator::VisitStmt_(op);
diff --git a/src/tir/transforms/lift_thread_binding.cc b/src/tir/transforms/lift_thread_binding.cc
index 2dffc11b72..45bbf4af52 100644
--- a/src/tir/transforms/lift_thread_binding.cc
+++ b/src/tir/transforms/lift_thread_binding.cc
@@ -133,7 +133,7 @@ class ThreadBindingLifter : public StmtExprMutator {
ForKind::kThreadBinding, std::move(body),
IterVar(Range(nullptr), Var(iter_var->thread_tag, iter_var->var->dtype),
kThreadIndex, iter_var->thread_tag),
- annotation);
+ annotation, std::nullopt);
}
}
if (is_kernel_root) {
diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc
index e644c387cf..fd9bd2d653 100644
--- a/src/tir/transforms/loop_partition.cc
+++ b/src/tir/transforms/loop_partition.cc
@@ -760,14 +760,18 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim
inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt body) {
const ForNode* for_node = static_cast<const ForNode*>(node);
ICHECK(for_node);
+
if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1)) &&
!no_unroll_loop_with_extent_one_ && for_node->annotations.empty()) {
// If the loop extent is 1, do not create the loop anymore
return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}});
} else {
ICHECK(for_node->kind != ForKind::kThreadBinding);
- return For(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent, for_node->kind, body,
- for_node->thread_binding, for_node->annotations);
+ auto new_loop = ffi::make_object<ForNode>(*for_node);
+ new_loop->min = IntImm(for_node->min.dtype(), 0);
+ new_loop->extent = extent;
+ new_loop->body = body;
+ return For(new_loop);
}
}
diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc
index 25e8734ff1..2f7ac3ddb1 100644
--- a/src/tir/transforms/lower_cross_thread_reduction.cc
+++ b/src/tir/transforms/lower_cross_thread_reduction.cc
@@ -878,7 +878,9 @@ class CrossThreadReductionTransformer : public StmtMutator {
/*body=*/body, //
/*thread_binding=*/
IterVar(NullValue<Range>(), Var("", loop_vars[i]->dtype),
IterVarType::kThreadIndex,
- "threadIdx." + dim_index));
+ "threadIdx." + dim_index),
+ /*annotations=*/{},
+ /*step=*/std::nullopt);
}
return body;
}
diff --git a/src/tir/transforms/lower_opaque_block.cc b/src/tir/transforms/lower_opaque_block.cc
index 2e53e89667..c0363dd898 100644
--- a/src/tir/transforms/lower_opaque_block.cc
+++ b/src/tir/transforms/lower_opaque_block.cc
@@ -111,7 +111,7 @@ class OpaqueBlockLower : public StmtExprMutator {
} else {
// Case 3. An ordinary loop
body = For(op->loop_var, std::move(min), std::move(extent), op->kind, std::move(body),
- std::nullopt, new_annotations);
+ std::nullopt, new_annotations, op->step);
}
// Step 5. Insert nested attrs
for (auto it = pragma_attrs.rbegin(); it != pragma_attrs.rend(); ++it) {
diff --git a/src/tir/transforms/memhammer_coalesce.cc b/src/tir/transforms/memhammer_coalesce.cc
index 094f48e321..0d5b270442 100644
--- a/src/tir/transforms/memhammer_coalesce.cc
+++ b/src/tir/transforms/memhammer_coalesce.cc
@@ -128,7 +128,8 @@ Stmt SplitBindVectorize(const Stmt& stmt, const ConstraintSet& constraints) {
body = For(new_loop_vars.back(), 0, vector_len, ForKind::kVectorized, std::move(body));
for (int i = n - 2; i >= 1; i--) {
body = For(new_loop_vars[i], 0, factors[i], ForKind::kThreadBinding, std::move(body),
- IterVar(Range(nullptr), Var(thread_axis[i - 1]), kThreadIndex, thread_axis[i - 1]));
+ IterVar(Range(nullptr), Var(thread_axis[i - 1]), kThreadIndex, thread_axis[i - 1]),
+ {}, std::nullopt);
}
return For(new_loop_vars[0], 0, factors[0], ForKind::kSerial, std::move(body));
}
diff --git a/src/tir/transforms/memhammer_tensorcore_rewrite.cc b/src/tir/transforms/memhammer_tensorcore_rewrite.cc
index e16c518771..e69ac30366 100644
--- a/src/tir/transforms/memhammer_tensorcore_rewrite.cc
+++ b/src/tir/transforms/memhammer_tensorcore_rewrite.cc
@@ -70,8 +70,9 @@ std::pair<Stmt, ffi::Optional<For>> TileWmmaBlock(Stmt stmt) {
}
For compute_location = Downcast<For>(body);
for (int i = n - 3; i >= 0; i--) {
- body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, std::move(body),
- loops[i]->thread_binding, loops[i]->annotations);
+ auto new_loop = ffi::GetRef<For>(loops[i]);
+ new_loop.CopyOnWrite()->body = std::move(body);
+ body = new_loop;
}
return {body, compute_location};
}
@@ -187,8 +188,9 @@ Stmt RewriteWmmaLoad(Stmt stmt) {
},
/*annotations=*/{}));
for (int i = n - 3; i >= 0; i--) {
- wmma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind,
- std::move(wmma_body), loops[i]->thread_binding, loops[i]->annotations);
+ auto new_loop = ffi::GetRef<For>(loops[i]);
+ new_loop.CopyOnWrite()->body = std::move(wmma_body);
+ wmma_body = new_loop;
}
return wmma_body;
}
@@ -290,8 +292,9 @@ Stmt RewriteWmmaStore(Stmt stmt) {
},
/*annotations=*/{}));
for (int i = n - 3; i >= 0; i--) {
- wmma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind,
- std::move(wmma_body), loops[i]->thread_binding, loops[i]->annotations);
+ auto new_loop = ffi::GetRef<For>(loops[i]);
+ new_loop.CopyOnWrite()->body = std::move(wmma_body);
+ wmma_body = new_loop;
}
return wmma_body;
}
@@ -395,8 +398,9 @@ std::pair<Stmt, ffi::Optional<For>> TileMmaToGlobalBlock(Stmt stmt) {
}
For compute_location = Downcast<For>(body);
for (int i = n - 3; i >= 0; i--) {
- body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind, std::move(body),
- loops[i]->thread_binding, loops[i]->annotations);
+ auto new_loop = ffi::GetRef<For>(loops[i]);
+ new_loop.CopyOnWrite()->body = std::move(body);
+ body = new_loop;
}
return {body, compute_location};
}
@@ -484,21 +488,21 @@ Stmt RewriteMmaStore(Stmt stmt) {
/*reads=*/{BufferRegion(src_buffer, read_region)},
/*writes=*/{BufferRegion(tgt_buffer, write_region)},
/*name_hint=*/"mma_store",
- AttrStmt(/*node=*/IterVar(
- /*dom=*/Range::FromMinExtent(0, 32),
- /*var=*/tx,
- /*iter_type=*/IterVarType::kThreadIndex,
- /*thread_tag=*/"threadIdx.x"),
- /*attr_key=*/"thread_extent",
- /*value=*/Integer(32),
- /*body=*/
- For(vec, 0, 2, ForKind::kVectorized,
- /*body=*/
- BufferStore(new_tgt_buffer,
- BufferLoad(new_src_buffer,
- {floordiv(tx, 4), floormod(tx, 4) * 2 + vec}),
- {floordiv(tx, 4), floormod(tx, 4) * 2 + vec}),
- /*annotations=*/{})),
+ AttrStmt(
+ /*node=*/IterVar(
+ /*dom=*/Range::FromMinExtent(0, 32),
+ /*var=*/tx,
+ /*iter_type=*/IterVarType::kThreadIndex,
+ /*thread_tag=*/"threadIdx.x"),
+ /*attr_key=*/"thread_extent",
+ /*value=*/Integer(32),
+ /*body=*/
+ For(vec, 0, 2, ForKind::kVectorized,
+ /*body=*/
+ BufferStore(
+ new_tgt_buffer,
+ BufferLoad(new_src_buffer, {floordiv(tx, 4), floormod(tx, 4) * 2 + vec}),
+ {floordiv(tx, 4), floormod(tx, 4) * 2 + vec}))),
/*init=*/std::nullopt,
/*alloc_buffers=*/{},
/*match_buffers=*/
@@ -510,8 +514,9 @@ Stmt RewriteMmaStore(Stmt stmt) {
// Step 3.4. wrap outer loops
for (int i = n - 3; i >= 0; i--) {
- mma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, loops[i]->kind,
- std::move(mma_body), loops[i]->thread_binding, loops[i]->annotations);
+ auto new_loop = ffi::GetRef<For>(loops[i]);
+ new_loop.CopyOnWrite()->body = std::move(mma_body);
+ mma_body = new_loop;
}
return mma_body;
}
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index 4af12c69a3..830364788c 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -510,7 +510,7 @@ class StoragePlanRewriter : public StmtExprMutator {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<ForNode>();
return For(op->loop_var, op->min, op->extent, op->kind, MakeAttach(svec, op->body),
- op->thread_binding, op->annotations);
+ op->thread_binding, op->annotations, op->step);
} else {
return StmtExprMutator::VisitStmt_(op);
}
diff --git a/src/tir/transforms/unify_thread_binding.cc b/src/tir/transforms/unify_thread_binding.cc
index fa1e221459..502acd5a46 100644
--- a/src/tir/transforms/unify_thread_binding.cc
+++ b/src/tir/transforms/unify_thread_binding.cc
@@ -79,7 +79,8 @@ class ThreadBindingUnifier : public StmtExprMutator {
/*extent=*/IntImm(dtype, 1), //
/*kind=*/ForKind::kSerial, stmt, //
/*thread_binding=*/std::nullopt, //
- /*annotation=*/std::move(annotations));
+ /*annotation=*/std::move(annotations),
+ /*step=*/std::nullopt);
}
}
@@ -155,7 +156,8 @@ class ThreadBindingUnifier : public StmtExprMutator {
result = For(thread_binding->var, thread_binding->dom->min, thread_binding->dom->extent,
ForKind::kThreadBinding, result,
IterVar(NullValue<Range>(), Var(""),
IterVarType::kThreadIndex,
- thread_binding->thread_tag));
+ thread_binding->thread_tag),
+ {}, std::nullopt);
launch_threads_.pop_back();
}
return result;
diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc
index d1269634ab..74abea57ba 100644
--- a/src/tir/transforms/unroll_loop.cc
+++ b/src/tir/transforms/unroll_loop.cc
@@ -156,8 +156,9 @@ class LoopUnroller : public StmtExprMutator {
} else {
if (auto_unroll) {
if (op->kind != ForKind::kUnrolled) {
- return For(op->loop_var, op->min, op->extent, ForKind::kUnrolled, op->body,
- op->thread_binding, op->annotations);
+ auto n = CopyOnWrite(op);
+ n->kind = ForKind::kUnrolled;
+ return For(n);
}
}
return stmt;
diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc
index 857f0b4cea..068903baa8 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -752,8 +752,10 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
if (extent.same_as(op->extent) && body.same_as(op->body)) {
return ffi::GetRef<Stmt>(op);
} else {
- return For(op->loop_var, op->min, extent, op->kind, body, op->thread_binding,
- op->annotations);
+ auto n = CopyOnWrite(op);
+ n->extent = extent;
+ n->body = body;
+ return For(n);
}
}
// IfThenElse
diff --git a/tests/python/codegen/test_target_codegen.py b/tests/python/codegen/test_target_codegen.py
index 3332d015a8..7530786a38 100644
--- a/tests/python/codegen/test_target_codegen.py
+++ b/tests/python/codegen/test_target_codegen.py
@@ -16,7 +16,7 @@
# under the License.
import pytest
-
+import numpy as np
import tvm
from tvm.script import tir as T
@@ -88,5 +88,47 @@ def test_buffer_load_predicate_not_supported_gpu(target):
tvm.compile(func)
+@tvm.testing.parametrize_targets("c", "llvm")
+def test_codegen_loop_step(target):
+ @T.prim_func
+ def test_loop_step(
+ A: T.Buffer((1024,), "float32"),
+ B: T.Buffer((1024,), "float32"),
+ C: T.Buffer((1024,), "float32"),
+ ):
+ for i in T.serial(3, 1024, step=96):
+ C[i] = A[i] + B[i]
+
+ with tvm.transform.PassContext(disabled_pass=["tir.CanonicalizeLoop"]):
+ lib = tvm.compile(test_loop_step, target=target)
+
+ src = lib.mod.inspect_source()
+ if target == "c":
+ assert src.find("for (int32_t i = 3; i < 1024; i += 96)") >= 0
+
+ dev = tvm.device(target, 0)
+ a_np = np.random.rand(1024).astype("float32")
+ b_np = np.random.rand(1024).astype("float32")
+ c_np = np.zeros(1024, dtype="float32")
+ a_tvm = tvm.runtime.tensor(a_np, dev)
+ b_tvm = tvm.runtime.tensor(b_np, dev)
+ c_tvm = tvm.runtime.tensor(c_np, dev)
+
+ lib(a_tvm, b_tvm, c_tvm)
+
+ c_result = c_tvm.numpy()
+
+ # Check that the loop executes at positions 3, 99, 195, 291, 387, 483, 579, 675, 771, 867, 963
+ for i in range(3, 1024, 96):
+ np.testing.assert_allclose(c_result[i], a_np[i] + b_np[i], rtol=1e-5)
+
+ # Assert non-touched positions remain zero
+ for i in range(0, 3):
+ assert c_result[i] == 0.0
+ for i in range(4, 1024):
+ if (i - 3) % 96 != 0:
+ assert c_result[i] == 0.0
+
+
if __name__ == "__main__":
tvm.testing.main()
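The touched positions asserted in test_codegen_loop_step follow from the trip count ceil((1024 - 3) / 96) = 11. A standalone sketch of that arithmetic, plain Python with no TVM dependency:

    # Indices visited by (min=3, bound=1024, step=96): 3, 99, ..., 963.
    touched = list(range(3, 1024, 96))
    assert len(touched) == -(-(1024 - 3) // 96) == 11  # ceildiv
    assert touched == [3 + 96 * j for j in range(11)]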
diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py
index 0841d0f545..1b31e64414 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -877,5 +877,37 @@ def test_thread_return():
assert "return;" in cuda_code
+@tvm.testing.requires_gpu
+@tvm.testing.requires_cuda
+def test_cuda_loop_step():
+ @T.prim_func
+ def cuda_loop_step(
+ A: T.Buffer((1024,), "float32"),
+ B: T.Buffer((1024,), "float32"),
+ C: T.Buffer((1024,), "float32"),
+ ):
+ # Each thread computes a strided subset of the i loop: start = tx, step = 96 (the thread count)
+ for bx in T.thread_binding(1, "blockIdx.x"):
+ for tx in T.thread_binding(96, "threadIdx.x"):
+ for i in T.serial(tx, 1024, step=96):
+ C[i] = A[i] + B[i]
+
+ target = tvm.target.Target({"kind": "cuda"})
+ with tvm.transform.PassContext(disabled_pass=["tir.CanonicalizeLoop"]):
+ lib = tvm.compile(cuda_loop_step, target=target)
+
+ cuda_src = lib.mod.imports[0].inspect_source()
+ assert "i += 96" in cuda_src
+ dev = tvm.cuda(0)
+ a_np = np.random.uniform(1, 100, (1024,)).astype("float32")
+ b_np = np.random.uniform(1, 100, (1024,)).astype("float32")
+ c_np = np.zeros((1024,), dtype="float32")
+ a_nd = tvm.runtime.tensor(a_np, dev)
+ b_nd = tvm.runtime.tensor(b_np, dev)
+ c_nd = tvm.runtime.tensor(c_np, dev)
+ lib["main"](a_nd, b_nd, c_nd)
+ tvm.testing.assert_allclose(c_nd.numpy(), a_np + b_np)
+
+
if __name__ == "__main__":
tvm.testing.main()
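The whole-array assert_allclose at the end of test_cuda_loop_step holds because the 96 strided loops partition [0, 1024) by residue class modulo 96, so each element is written exactly once. A quick standalone check of that coverage argument:

    # start = tx, step = 96 for tx in [0, 96): the union is all of [0, 1024).
    covered = sorted(i for tx in range(96) for i in range(tx, 1024, 96))
    assert covered == list(range(1024))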
diff --git a/tests/python/tir-base/test_tir_nodes.py b/tests/python/tir-base/test_tir_nodes.py
index bc7cfeae17..85cd726dda 100644
--- a/tests/python/tir-base/test_tir_nodes.py
+++ b/tests/python/tir-base/test_tir_nodes.py
@@ -134,6 +134,7 @@ def test_basic():
def test_stmt():
x = tvm.tir.Evaluate(0)
tvm.tir.For(te.var("i"), 0, 1, tvm.tir.ForKind.SERIAL, x)
+ tvm.tir.For(te.var("i"), 0, 1, tvm.tir.ForKind.UNROLLED, x, step=2)
def test_dir():
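The added constructor call can also be exercised interactively. A minimal sketch that builds a loop with an explicit step and reads the attribute back; treating the stored literal as an IntImm with a .value field is an assumption about the usual integer conversion:

    import tvm
    from tvm import te

    # Build a serial loop with an explicit step and inspect the attribute.
    loop = tvm.tir.For(te.var("i"), 0, 10, tvm.tir.ForKind.SERIAL,
                       tvm.tir.Evaluate(0), step=2)
    assert loop.step.value == 2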
diff --git a/tests/python/tir-transform/test_tir_transform_canonicalize_loop.py b/tests/python/tir-transform/test_tir_transform_canonicalize_loop.py
new file mode 100644
index 0000000000..6f6d88137c
--- /dev/null
+++ b/tests/python/tir-transform/test_tir_transform_canonicalize_loop.py
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import tvm
+from tvm import tir
+from tvm.script import tir as T
+
+
+def test_canonicalize_loop():
+ @T.prim_func
+ def before(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"]):
+ T.func_attr({"global_symbol": "main"})
+ for i in range(1, 128, 5):
+ B[i] = A[i] + 1.0
+
+ @T.prim_func
+ def expected(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"]):
+ T.func_attr({"global_symbol": "main"})
+ for i in T.serial(0, 26):
+ B[i * 5 + 1] = A[i * 5 + 1] + 1.0
+
+ mod = tvm.IRModule.from_expr(before)
+ mod = tir.transform.CanonicalizeLoop()(mod)
+ tvm.ir.assert_structural_equal(mod["main"], expected)
+
+
+def test_canonicalize_nested_loop():
+ @T.prim_func
+ def before(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]):
+ T.func_attr({"global_symbol": "main"})
+ for i in range(1, 128, 5):
+ for j in range(2, 128, 3):
+ B[i, j] = A[i, j] + 1.0
+
+ @T.prim_func
+ def expected(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]):
+ T.func_attr({"global_symbol": "main"})
+ for i in T.serial(0, 26):
+ for j in T.serial(0, 42):
+ B[i * 5 + 1, j * 3 + 2] = A[i * 5 + 1, j * 3 + 2] + 1.0
+
+ mod = tvm.IRModule.from_expr(before)
+ mod = tir.transform.CanonicalizeLoop()(mod)
+ tvm.ir.assert_structural_equal(mod["main"], expected)
+
+
+def test_canonicalize_negative_step():
+ @T.prim_func
+ def before(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"]):
+ T.func_attr({"global_symbol": "main"})
+ for i in T.serial(0, 127, step=-3):
+ B[i] = A[i] + 1.0
+
+ mod = tvm.IRModule.from_expr(before)
+ with pytest.raises(tvm.error.InternalError):
+ mod = tir.transform.CanonicalizeLoop()(mod)
+
+
+def test_canonicalize_dynamic_step():
+ """Currently we report error for dynamic step since we could not prove it
is positive"""
+
+ @T.prim_func
+ def before(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"], step: T.int32):
+ T.func_attr({"global_symbol": "main"})
+ for i in T.serial(0, 128, step=step):
+ B[i] = A[i] + 1.0
+
+ mod = tvm.IRModule.from_expr(before)
+ with pytest.raises(tvm.error.InternalError):
+ mod = tir.transform.CanonicalizeLoop()(mod)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
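Together these tests pin down the change of variables the pass performs: with min m, bound b, and a step s that is provably positive, the loop visits m + j*s for j in [0, ceildiv(b - m, s)), which is also why the negative and dynamic cases must be rejected, since the trip count is undefined unless s > 0 can be shown. A standalone sketch of the equivalence, using the numbers from the first test:

    # CanonicalizeLoop's change of variables, positive step assumed:
    # i_old = min + j * step for j in range(ceildiv(bound - min, step)).
    m, bound, s = 1, 128, 5
    trip = -(-(bound - m) // s)  # ceildiv, the new loop extent: 26
    assert trip == 26
    assert list(range(m, bound, s)) == [m + j * s for j in range(trip)]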
diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py
index f1569be5b1..3b84e919c8 100644
--- a/tests/python/tvmscript/test_tvmscript_parser_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py
@@ -327,6 +327,32 @@ def test_tir_starred_for_loop():
tvm.ir.assert_structural_equal(starred, non_starred)
+def test_tir_loop_steps():
+ N = T.Var("N", "int32")
+
+ @T.prim_func(private=True)
+ def loop_with_steps(
+ A: T.Buffer((N,)), B: T.Buffer((N,)), C: T.Buffer((N,)), tid: T.int32, v: T.int32
+ ):
+ for i in T.serial(tid, N, step=2):
+ C[i] = A[i] + B[i]
+ for i in T.unroll(tid, N, step=3):
+ C[i] = A[i] + B[i]
+ for i in T.vectorized(tid, N, step=4):
+ C[i] = A[i] + B[i]
+ for i in T.parallel(tid, N, step=5):
+ C[i] = A[i] + B[i]
+ for i in T.serial(tid, N, step=v):
+ C[i] = A[i] + B[i]
+
+ stmts = loop_with_steps.body.seq
+ assert stmts[0].step == 2
+ assert stmts[1].step == 3
+ assert stmts[2].step == 4
+ assert stmts[3].step == 5
+ assert stmts[4].step.name == "v"
+
+
def test_tir_empty_tuple_index():
@T.macro
def bar(val):
diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py
index 1954ca773f..b3d459b2e6 100644
--- a/tests/python/tvmscript/test_tvmscript_roundtrip.py
+++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py
@@ -4018,6 +4018,25 @@ def func_with_loop_jumps():
return func
+def func_with_loop_steps():
+ @T.prim_func
+ def func(
+ A: T.Buffer((1024,)), B: T.Buffer((1024,)), C: T.Buffer((1024,)), tid: T.int32, v: T.int32
+ ):
+ for i in T.serial(tid, 1024, step=2):
+ C[i] = A[i] + B[i]
+ for i in T.unroll(tid, 1024, step=3):
+ C[i] = A[i] + B[i]
+ for i in T.vectorized(tid, 1024, step=4):
+ C[i] = A[i] + B[i]
+ for i in T.parallel(tid, 1024, step=5):
+ C[i] = A[i] + B[i]
+ for i in range(tid, 1024, 6):
+ C[i] = A[i] + B[i]
+
+ return func
+
+
def op_of_literal():
op_list = [
(T.exp, 0),
@@ -4237,6 +4256,7 @@ ir_generator = tvm.testing.parameter(
return_zero_private_with_attr,
func_attr_with_list,
func_with_loop_jumps,
+ func_with_loop_steps,
*op_of_literal(),
*relax_match_cast_struct_info_proxy(),
relax_symbolic_size_var,