This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 13ea9dc104 [TIR] Add step attribute to ForNode (Initial codes) (#18421)
13ea9dc104 is described below

commit 13ea9dc10436836e9654a897cf6f8f87813dc8a4
Author: wrongtest <[email protected]>
AuthorDate: Mon Nov 24 21:30:16 2025 +0800

    [TIR] Add step attribute to ForNode (Initial codes) (#18421)
    
    An initial change to add `ForNode::step`.
    
    - Add an `Optional<PrimExpr>`-typed `step` attribute to ForNode, then add
      the minimal code needed for:
        - Roundtrip support for the TIR TVMScript grammar
        - Correctness of the TIR lowering pipeline:
            - Canonicalize the loop in the default pipeline
            - Ensure the original `ForNode::step` is not dropped by mutations
              on `ForNode`.
        - CodeGen support for non-zero min and non-trivial step.
    
    - TODOs for the future (hopefully):
        - For **all transformations and analysis tools**, adapt them to
          non-consecutive loop iteration indices.
        - Correctness of TensorIR schedule and MetaSchedule
    
    ---------
    
    Co-authored-by: baoxinqi <[email protected]>
---
 include/tvm/script/ir_builder/tir/frame.h          |   8 +-
 include/tvm/script/ir_builder/tir/ir.h             |  16 +++-
 include/tvm/tir/stmt.h                             |  17 +++-
 python/tvm/script/ir_builder/tir/ir.py             |  44 +++++++--
 python/tvm/script/parser/tir/parser.py             |  27 +++++-
 python/tvm/tir/ir_builder.py                       |   8 +-
 python/tvm/tir/pipeline.py                         |   1 +
 python/tvm/tir/stmt.py                             |   7 ++
 python/tvm/tir/transform/transform.py              |  11 +++
 .../transform/lower_global_view_to_local_view.cc   |   4 +-
 src/script/ir_builder/tir/frame.cc                 |   2 +-
 src/script/ir_builder/tir/ir.cc                    |  20 +++-
 src/script/printer/tir/for_loop.cc                 |  15 ++-
 src/target/llvm/codegen_cpu.cc                     |  16 ++--
 src/target/llvm/codegen_llvm.cc                    |   8 +-
 src/target/source/codegen_c.cc                     |  14 ++-
 src/target/source/codegen_cuda.cc                  |   1 -
 src/target/source/codegen_webgpu.cc                |  14 ++-
 src/target/spirv/codegen_spirv.cc                  |  23 +++--
 src/tir/ir/data_type_rewriter.cc                   |   9 +-
 src/tir/ir/stmt.cc                                 |  30 ++++--
 src/tir/ir/stmt_functor.cc                         |  11 ++-
 src/tir/schedule/primitive/blockize_tensorize.cc   |   2 +-
 src/tir/schedule/primitive/decompose_padding.cc    |   2 +-
 src/tir/schedule/primitive/loop_transformation.cc  |   4 +-
 src/tir/schedule/primitive/reduction.cc            |  13 ++-
 src/tir/transforms/canonicalize_loop.cc            | 102 +++++++++++++++++++++
 src/tir/transforms/common_subexpr_elim.cc          |   2 +-
 src/tir/transforms/convert_for_loops_serial.cc     |   2 +-
 src/tir/transforms/inject_software_pipeline.cc     |   2 +-
 src/tir/transforms/ir_utils.cc                     |   6 +-
 src/tir/transforms/lift_thread_binding.cc          |   2 +-
 src/tir/transforms/loop_partition.cc               |   8 +-
 src/tir/transforms/lower_cross_thread_reduction.cc |   4 +-
 src/tir/transforms/lower_opaque_block.cc           |   2 +-
 src/tir/transforms/memhammer_coalesce.cc           |   3 +-
 src/tir/transforms/memhammer_tensorcore_rewrite.cc |  55 ++++++-----
 src/tir/transforms/storage_rewrite.cc              |   2 +-
 src/tir/transforms/unify_thread_binding.cc         |   6 +-
 src/tir/transforms/unroll_loop.cc                  |   5 +-
 src/tir/transforms/vectorize_loop.cc               |   6 +-
 tests/python/codegen/test_target_codegen.py        |  44 ++++++++-
 tests/python/codegen/test_target_codegen_cuda.py   |  32 +++++++
 tests/python/tir-base/test_tir_nodes.py            |   1 +
 .../test_tir_transform_canonicalize_loop.py        |  88 ++++++++++++++++++
 .../python/tvmscript/test_tvmscript_parser_tir.py  |  26 ++++++
 tests/python/tvmscript/test_tvmscript_roundtrip.py |  20 ++++
 47 files changed, 619 insertions(+), 126 deletions(-)

diff --git a/include/tvm/script/ir_builder/tir/frame.h 
b/include/tvm/script/ir_builder/tir/frame.h
index 827e4e0329..db5776890a 100644
--- a/include/tvm/script/ir_builder/tir/frame.h
+++ b/include/tvm/script/ir_builder/tir/frame.h
@@ -251,13 +251,15 @@ class ForFrameNode : public TIRFrameNode {
    * \param loop_body The loop body
    * \return A stmt, the loop nest
    */
-  using FMakeForLoop =
-      ffi::TypedFunction<tvm::tir::Stmt(ffi::Array<tvm::tir::Var> loop_vars,
-                                        ffi::Array<Range> loop_extents, 
tvm::tir::Stmt loop_body)>;
+  using FMakeForLoop = ffi::TypedFunction<tvm::tir::Stmt(
+      ffi::Array<tvm::tir::Var> loop_vars, ffi::Array<Range> loop_extents,
+      ffi::Array<ffi::Optional<PrimExpr>> loop_steps, tvm::tir::Stmt 
loop_body)>;
   /*! \brief The loop variable. */
   ffi::Array<tvm::tir::Var> vars;
   /*! \brief The domains of iteration. */
   ffi::Array<Range> doms;
+  /*! \brief The optional steps of iteration. */
+  ffi::Array<ffi::Optional<PrimExpr>> steps;
   /*! \brief The for loop generating function. */
   FMakeForLoop f_make_for_loop;
 
diff --git a/include/tvm/script/ir_builder/tir/ir.h 
b/include/tvm/script/ir_builder/tir/ir.h
index 24ce8fdf99..07c7fe262b 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -228,37 +228,45 @@ ffi::Array<Var> Remap(ffi::String kinds, 
ffi::Array<PrimExpr> bindings,
  * \param start The minimum value of iteration.
  * \param stop The maximum value of iteration.
  * \param annotations The optional annotations of the For statement.
+ * \param step The optional step value of iteration.
  * \return The ForFrame.
  */
 ForFrame Serial(PrimExpr start, PrimExpr stop,
-                ffi::Optional<ffi::Map<ffi::String, Any>> annotations = 
std::nullopt);
+                ffi::Optional<ffi::Map<ffi::String, Any>> annotations = 
std::nullopt,
+                ffi::Optional<PrimExpr> step = std::nullopt);
 /*!
  * \brief The parallel For statement.
  * \param start The minimum value of iteration.
  * \param stop The maximum value of iteration.
  * \param annotations The optional annotations of the For statement.
+ * \param step The optional step value of iteration.
  * \return The ForFrame.
  */
 ForFrame Parallel(PrimExpr start, PrimExpr stop,
-                  ffi::Optional<ffi::Map<ffi::String, Any>> annotations = 
std::nullopt);
+                  ffi::Optional<ffi::Map<ffi::String, Any>> annotations = 
std::nullopt,
+                  ffi::Optional<PrimExpr> step = std::nullopt);
 /*!
  * \brief The vectorized For statement.
  * \param start The minimum value of iteration.
  * \param stop The maximum value of iteration.
  * \param annotations The optional annotations of the For statement.
+ * \param step The optional step value of iteration.
  * \return The ForFrame.
  */
 ForFrame Vectorized(PrimExpr start, PrimExpr stop,
-                    ffi::Optional<ffi::Map<ffi::String, Any>> annotations = 
std::nullopt);
+                    ffi::Optional<ffi::Map<ffi::String, Any>> annotations = 
std::nullopt,
+                    ffi::Optional<PrimExpr> step = std::nullopt);
 /*!
  * \brief The unrolled For statement.
  * \param start The minimum value of iteration.
  * \param stop The maximum value of iteration.
  * \param annotations The optional annotations of the For statement.
+ * \param step The optional step value of iteration.
  * \return The ForFrame.
  */
 ForFrame Unroll(PrimExpr start, PrimExpr stop,
-                ffi::Optional<ffi::Map<ffi::String, Any>> annotations = 
std::nullopt);
+                ffi::Optional<ffi::Map<ffi::String, Any>> annotations = 
std::nullopt,
+                ffi::Optional<PrimExpr> step = std::nullopt);
 /*!
  * \brief The thread-binding For statement.
  * \param start The minimum value of iteration.
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 1b8041e36c..0831b84cf6 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -717,7 +717,7 @@ enum class ForKind : int {
  *
  * \code
  *
- *  for (loop_var = min; loop_var < min + extent; ++loop_var) {
+ *  for (loop_var = min; loop_var < min + extent; loop_var += step) {
  *    // body
  *  }
  * \endcode
@@ -748,6 +748,10 @@ class ForNode : public StmtNode {
    *  and can be ignored in most passes.
    */
   ffi::Map<ffi::String, ffi::Any> annotations;
+  /*!
+   * \brief The loop step. It is one if not specified.
+   */
+  ffi::Optional<PrimExpr> step;
 
   static void RegisterReflection() {
     namespace refl = tvm::ffi::reflection;
@@ -758,8 +762,13 @@ class ForNode : public StmtNode {
         .def_ro("kind", &ForNode::kind)
         .def_ro("body", &ForNode::body)
         .def_ro("thread_binding", &ForNode::thread_binding)
-        .def_ro("annotations", &ForNode::annotations);
+        .def_ro("annotations", &ForNode::annotations)
+        .def_ro("step", &ForNode::step);
   }
+
+  /*! \brief Check it is a loop without nontrivial loop step. */
+  bool HasTrivialStep() const;
+
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.For", ForNode, StmtNode);
 };
 
@@ -771,8 +780,8 @@ class For : public Stmt {
  public:
   TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt 
body,
               ffi::Optional<IterVar> thread_binding = std::nullopt,
-              ffi::Map<ffi::String, ffi::Any> annotations = 
ffi::Map<ffi::String, ffi::Any>(),
-              Span span = Span());
+              ffi::Map<ffi::String, ffi::Any> annotations = {},
+              ffi::Optional<PrimExpr> step = std::nullopt, Span span = Span());
 
   TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(For, Stmt, ForNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode);
diff --git a/python/tvm/script/ir_builder/tir/ir.py 
b/python/tvm/script/ir_builder/tir/ir.py
index 6d746d73b1..31e48260f5 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -677,7 +677,11 @@ class axis:  # pylint: disable=invalid-name
 
 
 def serial(
-    start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = 
None
+    start: PrimExpr,
+    stop: PrimExpr = None,
+    *,
+    annotations: Dict[str, Any] = None,
+    step: Optional[PrimExpr] = None,
 ) -> frame.ForFrame:
     """The serial For statement.
 
@@ -692,6 +696,9 @@ def serial(
     annotations : Dict[str, Any]
         The optional annotations of the For statement.
 
+    step : PrimExpr
+        The optional step value of iteration.
+
     Returns
     -------
     res : frame.ForFrame
@@ -703,11 +710,15 @@ def serial(
             start = IntImm(start.dtype, 0)
         else:
             start = 0
-    return _ffi_api.Serial(start, stop, annotations)  # type: 
ignore[attr-defined] # pylint: disable=no-member
+    return _ffi_api.Serial(start, stop, annotations, step)  # type: 
ignore[attr-defined] # pylint: disable=no-member
 
 
 def parallel(
-    start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = 
None
+    start: PrimExpr,
+    stop: PrimExpr = None,
+    *,
+    annotations: Dict[str, Any] = None,
+    step: Optional[PrimExpr] = None,
 ) -> frame.ForFrame:
     """The parallel For statement.
 
@@ -722,6 +733,9 @@ def parallel(
     annotations : Dict[str, Any]
         The optional annotations of the For statement.
 
+    step : PrimExpr
+        The optional step value of iteration.
+
     Returns
     -------
     res : frame.ForFrame
@@ -733,11 +747,15 @@ def parallel(
             start = IntImm(start.dtype, 0)
         else:
             start = 0
-    return _ffi_api.Parallel(start, stop, annotations)  # type: 
ignore[attr-defined] # pylint: disable=no-member
+    return _ffi_api.Parallel(start, stop, annotations, step)  # type: 
ignore[attr-defined] # pylint: disable=no-member
 
 
 def vectorized(
-    start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = 
None
+    start: PrimExpr,
+    stop: PrimExpr = None,
+    *,
+    annotations: Dict[str, Any] = None,
+    step: Optional[PrimExpr] = None,
 ) -> frame.ForFrame:
     """The vectorized For statement.
 
@@ -752,6 +770,9 @@ def vectorized(
     annotations : Dict[str, Any]
         The optional annotations of the For statement.
 
+    step : PrimExpr
+        The optional step value of iteration.
+
     Returns
     -------
     res : frame.ForFrame
@@ -763,11 +784,15 @@ def vectorized(
             start = IntImm(start.dtype, 0)
         else:
             start = 0
-    return _ffi_api.Vectorized(start, stop, annotations)  # type: 
ignore[attr-defined] # pylint: disable=no-member
+    return _ffi_api.Vectorized(start, stop, annotations, step)  # type: 
ignore[attr-defined] # pylint: disable=no-member
 
 
 def unroll(
-    start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = 
None
+    start: PrimExpr,
+    stop: PrimExpr = None,
+    *,
+    annotations: Dict[str, Any] = None,
+    step: Optional[PrimExpr] = None,
 ) -> frame.ForFrame:
     """The unrolled For statement.
 
@@ -782,6 +807,9 @@ def unroll(
     annotations : Dict[str, Any]
         The optional annotations of the For statement.
 
+    step : PrimExpr
+        The optional step value of iteration.
+
     Returns
     -------
     res : frame.ForFrame
@@ -793,7 +821,7 @@ def unroll(
             start = IntImm(start.dtype, 0)
         else:
             start = 0
-    return _ffi_api.Unroll(start, stop, annotations)  # type: 
ignore[attr-defined] # pylint: disable=no-member
+    return _ffi_api.Unroll(start, stop, annotations, step)  # type: 
ignore[attr-defined] # pylint: disable=no-member
 
 
 def thread_binding(
diff --git a/python/tvm/script/parser/tir/parser.py 
b/python/tvm/script/parser/tir/parser.py
index 85ab1982f3..f8cbc0b4f5 100644
--- a/python/tvm/script/parser/tir/parser.py
+++ b/python/tvm/script/parser/tir/parser.py
@@ -18,7 +18,7 @@
 
 import contextlib
 from functools import partial
-from typing import Any
+from typing import Any, Dict, Optional
 
 import tvm
 from tvm.ir import GlobalVar, PrimType
@@ -168,6 +168,28 @@ def find_decorator_annotation(node: doc.FunctionDef, 
annotation: str, default: b
     return default
 
 
+def range_sugar(
+    start: PrimExpr,
+    stop: PrimExpr = None,
+    step: Optional[PrimExpr] = None,
+    *,
+    annotations: Dict[str, Any] = None,
+) -> T.frame.ForFrame:
+    """The sugar for python range builtin."""
+
+    # Since `tir.For` do not support reversed iteration semantic,
+    # the step must be checked to be positive integer when use range sugar
+    if step is not None:
+        try:
+            step = int(step)
+            if step <= 0:
+                raise ValueError(f"Only support positive step in range(), get 
{step}")
+        except TypeError:  # pylint: disable=broad-except
+            raise ValueError(f"Only support literal step in range(), get 
{step}")
+
+    return T.serial(start, stop, annotations=annotations, step=step)
+
+
 @dispatch.register(token="tir", type_name="For")
 def visit_for(self: Parser, node: doc.For) -> None:
     """The for visiting method for tir.
@@ -379,7 +401,8 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) 
-> None:
     privacy = find_decorator_annotation(node, "private", default=False)
     self.function_annotations = None
     with self.var_table.with_frame():
-        self.var_table.add("range", T.serial)
+
+        self.var_table.add("range", range_sugar)
         with T.prim_func(is_private=privacy):
             T.func_name(node.name)
             if node.returns is not None:
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index a6313ae3bc..1e9cb07830 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -202,7 +202,7 @@ class IRBuilder(object):
             value = op.max(1, value)
         self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x))
 
-    def for_range(self, begin, end, name="i", dtype=None, kind="serial"):
+    def for_range(self, begin, end, name="i", dtype=None, kind="serial", 
step=None):
         """Create a for iteration scope.
 
         Parameters
@@ -223,6 +223,10 @@ class IRBuilder(object):
         kind : str, optional
             The special tag on the for loop.
 
+        step : PrimExpr
+            The loop step. Default to none which
+            represent one.
+
         Returns
         -------
         loop_scope : With.Scope of Var
@@ -275,7 +279,7 @@ class IRBuilder(object):
                 kind_id = _stmt.ForKind.UNROLLED
             else:
                 raise ValueError("Unknown kind")
-            self.emit(_stmt.For(loop_var, begin, extent, kind_id, 
self._pop_seq()))
+            self.emit(_stmt.For(loop_var, begin, extent, kind_id, 
self._pop_seq(), step=step))
 
         return WithScope(loop_var, _exit_cb)
 
diff --git a/python/tvm/tir/pipeline.py b/python/tvm/tir/pipeline.py
index 22cec30334..96ed9dfdbc 100644
--- a/python/tvm/tir/pipeline.py
+++ b/python/tvm/tir/pipeline.py
@@ -31,6 +31,7 @@ def default_tir_pipeline():
         pass_ctx = tvm.transform.PassContext.current()
         config = pass_ctx.config
         passes = [
+            tir.transform.CanonicalizeLoop(),
             tir.transform.LowerCrossThreadReduction(),
             tir.transform.LowerInitBlock(),
             tir.transform.PlanAndUpdateBufferAllocationLocation(),
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index bd90d52574..448ace3ade 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -145,6 +145,10 @@ class For(Stmt):
         The thread this loop binds to. Only valid
         if kind is ThreadBinding
 
+    step : PrimExpr
+        The loop step. Default to none which
+        represent one.
+
     annotations: Optional[Mapping[str, Object]]
         Additional annotation hints.
 
@@ -159,6 +163,7 @@ class For(Stmt):
     body: Stmt
     thread_binding: Optional[IterVar]
     annotations: Mapping[str, Object]
+    step: Optional[PrimExpr]
     span: Optional[Span]
 
     def __init__(
@@ -170,6 +175,7 @@ class For(Stmt):
         body: Stmt,
         thread_binding: Optional[IterVar] = None,
         annotations: Optional[Mapping[str, Object]] = None,
+        step: Optional[PrimExpr] = None,
         span: Optional[Span] = None,
     ) -> None:
         self.__init_handle_by_constructor__(
@@ -181,6 +187,7 @@ class For(Stmt):
             body,
             thread_binding,
             annotations,
+            step,
             span,
         )
 
diff --git a/python/tvm/tir/transform/transform.py 
b/python/tvm/tir/transform/transform.py
index 39105f21a2..88cf4720d3 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -1171,3 +1171,14 @@ def LowerVtcmAlloc():
         The result pass
     """
     return _ffi_api.LowerVtcmAlloc()  # type: ignore
+
+
+def CanonicalizeLoop():
+    """Canonicalize the loop to start from zero and use trivial step
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.CanonicalizeLoop()  # type: ignore
diff --git a/src/relax/distributed/transform/lower_global_view_to_local_view.cc 
b/src/relax/distributed/transform/lower_global_view_to_local_view.cc
index f83edb3e90..837f2f0a5d 100644
--- a/src/relax/distributed/transform/lower_global_view_to_local_view.cc
+++ b/src/relax/distributed/transform/lower_global_view_to_local_view.cc
@@ -330,8 +330,8 @@ class DistributedBufferCompactor : StmtExprMutator {
       if (shard > 1) {
         arith::Analyzer analyzer;
         ICHECK(analyzer.CanProve(floormod(new_loop->extent, shard) == 0));
-        return For(new_loop->loop_var, new_loop->min, 
floordiv(new_loop->extent, shard),
-                   new_loop->kind, new_loop->body, new_loop->thread_binding, 
new_loop->annotations);
+        new_loop.CopyOnWrite()->extent = floordiv(new_loop->extent, shard);
+        return new_loop;
       }
     }
     return new_loop;
diff --git a/src/script/ir_builder/tir/frame.cc 
b/src/script/ir_builder/tir/frame.cc
index 94eef40f59..7c10b6cdc8 100644
--- a/src/script/ir_builder/tir/frame.cc
+++ b/src/script/ir_builder/tir/frame.cc
@@ -123,7 +123,7 @@ void BlockInitFrameNode::ExitWithScope() {
 
 void ForFrameNode::ExitWithScope() {
   TIRFrameNode::ExitWithScope();
-  AddToParent(this->f_make_for_loop(vars, doms, AsStmt(stmts)));
+  AddToParent(this->f_make_for_loop(vars, doms, steps, AsStmt(stmts)));
 }
 
 void AssertFrameNode::ExitWithScope() {
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index b981b90bd8..00f9c28475 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -362,19 +362,23 @@ ffi::Array<Var> Remap(ffi::String kinds, 
ffi::Array<PrimExpr> bindings, DataType
 
 #define TVM_TIR_IR_BUILDER_FOR_FRAME(Method, Kind)                             
              \
   ForFrame Method(PrimExpr start, PrimExpr stop,                               
              \
-                  ffi::Optional<ffi::Map<ffi::String, Any>> annotations) {     
              \
+                  ffi::Optional<ffi::Map<ffi::String, Any>> annotations,       
              \
+                  ffi::Optional<PrimExpr> step) {                              
              \
     PrimExpr min = start;                                                      
              \
     PrimExpr extent = arith::Analyzer().Simplify(stop - start);                
              \
     ObjectPtr<ForFrameNode> n = ffi::make_object<ForFrameNode>();              
              \
     int bits = std::max(min.dtype().bits(), extent.dtype().bits());            
              \
     n->vars = {Var("v", DataType(min.dtype().code(), bits, 1))};               
              \
     n->doms = {Range::FromMinExtent(min, extent)};                             
              \
+    n->steps = {step};                                                         
              \
     n->f_make_for_loop = [annotations](ffi::Array<Var> vars, ffi::Array<Range> 
doms,         \
+                                       ffi::Array<ffi::Optional<PrimExpr>> 
steps,            \
                                        tvm::tir::Stmt body) {                  
              \
       ICHECK_EQ(vars.size(), 1);                                               
              \
       ICHECK_EQ(doms.size(), 1);                                               
              \
+      ICHECK_EQ(steps.size(), 1);                                              
              \
       return tvm::tir::For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, 
std::nullopt, \
-                           annotations.value_or(ffi::Map<ffi::String, 
Any>()));              \
+                           annotations.value_or(ffi::Map<ffi::String, Any>()), 
steps[0]);    \
     };                                                                         
              \
     return ForFrame(n);                                                        
              \
   }
@@ -396,13 +400,16 @@ ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, 
ffi::String thread,
   DataType dtype = DataType(min.dtype().code(), bits, 1);
   n->vars = {Var("v", dtype)};
   n->doms = {Range::FromMinExtent(min, extent)};
+  n->steps = {std::nullopt};
   n->f_make_for_loop = [annotations, thread, dtype](ffi::Array<Var> vars, 
ffi::Array<Range> doms,
+                                                    
ffi::Array<ffi::Optional<PrimExpr>> steps,
                                                     Stmt body) -> For {
     ICHECK_EQ(vars.size(), 1);
     ICHECK_EQ(doms.size(), 1);
+    ICHECK(steps.size() == 1 && (!steps[0].has_value() || is_one(*steps[0])));
     IterVar iter_var(Range(nullptr), Var("iter", dtype), 
IterVarType::kThreadIndex, thread);
     return For(vars[0], doms[0]->min, doms[0]->extent, 
ForKind::kThreadBinding, body, iter_var,
-               annotations.value_or(ffi::Map<ffi::String, ffi::Any>()));
+               annotations.value_or(ffi::Map<ffi::String, ffi::Any>()), 
std::nullopt);
   };
   return ForFrame(n);
 }
@@ -412,19 +419,22 @@ ForFrame Grid(ffi::Array<PrimExpr> extents) {
   ObjectPtr<ForFrameNode> n = ffi::make_object<ForFrameNode>();
   n->vars.reserve(extents.size());
   n->doms.reserve(extents.size());
+  n->steps.resize(extents.size());
   for (const auto& extent : extents) {
     DataType dtype = extent.dtype();
     n->vars.push_back(Var("v", extent.dtype()));
     n->doms.push_back(Range(make_const(dtype, 0), extent));
   }
-  n->f_make_for_loop = [](ffi::Array<Var> vars, ffi::Array<Range> doms, Stmt 
body) -> Stmt {
+  n->f_make_for_loop = [](ffi::Array<Var> vars, ffi::Array<Range> doms,
+                          ffi::Array<ffi::Optional<PrimExpr>> steps, Stmt 
body) -> Stmt {
     ICHECK_EQ(vars.size(), doms.size());
+    ICHECK_EQ(vars.size(), steps.size());
     int n = vars.size();
     for (int i = n - 1; i >= 0; --i) {
       Range dom = doms[i];
       Var var = vars[i];
       body = For(var, dom->min, dom->extent, ForKind::kSerial, std::move(body),
-                 /*thread_binding=*/std::nullopt, /*annotations=*/{});
+                 /*thread_binding=*/std::nullopt, /*annotations=*/{}, 
/*step=*/steps[i]);
     }
     return body;
   };
diff --git a/src/script/printer/tir/for_loop.cc 
b/src/script/printer/tir/for_loop.cc
index 742d23f69c..b2e091f380 100644
--- a/src/script/printer/tir/for_loop.cc
+++ b/src/script/printer/tir/for_loop.cc
@@ -39,7 +39,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
           if (l->kind != tir::ForKind::kSerial ||  //
               !tir::is_zero(l->min) ||             //
               !l->annotations.empty() ||           //
-              f_var_dep(l->extent)) {
+              !l->HasTrivialStep() || f_var_dep(l->extent)) {
             break;
           }
           grid.push_back(l);
@@ -69,7 +69,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
       ffi::Optional<ExprDoc> max = std::nullopt;
       ffi::Optional<ExprDoc> annotations = std::nullopt;
       ffi::Optional<ExprDoc> thread = std::nullopt;
-      if (tir::is_zero(loop->min)) {
+      if (tir::is_zero(loop->min) && loop->HasTrivialStep()) {
         max = d->AsDoc<ExprDoc>(loop->extent, loop_p->Attr("extent"));
       } else {
         min = d->AsDoc<ExprDoc>(loop->min, loop_p->Attr("min"));
@@ -78,10 +78,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
       if (!loop->annotations.empty()) {
         annotations = d->AsDoc<ExprDoc>(loop->annotations, 
loop_p->Attr("annotations"));
       }
+      bool use_range_sugar = false;
       ExprDoc prefix{ffi::UnsafeInit()};
       if (loop->kind == tir::ForKind::kSerial) {
         if (loop->annotations.empty()) {
           prefix = IdDoc("range");
+          use_range_sugar = true;
         } else {
           prefix = TIR(d, "serial");
         }
@@ -115,6 +117,15 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
         kwargs_keys.push_back("annotations");
         kwargs_values.push_back(annotations.value());
       }
+      if (!loop->HasTrivialStep()) {
+        ExprDoc step = d->AsDoc<ExprDoc>(*loop->step, loop_p->Attr("step"));
+        if (use_range_sugar) {
+          args.push_back(step);
+        } else {
+          kwargs_keys.push_back("step");
+          kwargs_values.push_back(step);
+        }
+      }
       ExprDoc rhs = prefix->Call(args, kwargs_keys, kwargs_values);
       AsDocBody(loop->body, loop_p->Attr("body"), (*f).get(), d);
       return ForDoc(lhs, rhs, (*f)->stmts);
diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index d9ee972321..bc67cdad2f 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -1152,14 +1152,15 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) {
 
 void CodeGenCPU::VisitStmt_(const ForNode* op) {
   EmitDebugLocation(op);
-  ICHECK(is_zero(op->min));
   if (op->kind == ForKind::kSerial || op->kind == ForKind::kUnrolled) {
     CodeGenLLVM::VisitStmt_(op);
   } else if (op->kind == ForKind::kParallel) {
+    ICHECK(is_zero(op->min)) << "Parallel launch require canonical loop with 
zero start index";
+    ICHECK(op->HasTrivialStep()) << "Parallel launch require canonical loop 
with trivial loop step";
     if (parallel_env_.penv == nullptr) {
-      CreateParallelLaunch(For(op->loop_var, op->min, op->extent, op->kind, 
op->body,
-                               op->thread_binding, op->annotations),
-                           0, std::string("loop_parallel_") + 
op->loop_var->name_hint.c_str());
+      auto copy_node = For(ffi::make_object<ForNode>(*op));
+      CreateParallelLaunch(copy_node, 0,
+                           std::string("loop_parallel_") + 
op->loop_var->name_hint.c_str());
     } else {
       // already in parallel env.
       ICHECK(parallel_env_.task_id.defined());
@@ -1171,13 +1172,14 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) {
       ICHECK(!parallel_env_.in_parallel_loop)
           << "Nested parallel loop is not supported by threadpool, try fuse 
them instead";
       parallel_env_.in_parallel_loop = true;
+      PrimExpr end = is_zero(op->min) ? op->extent : 
analyzer_->Simplify(op->min + op->extent);
       if (parallel_env_.stride_pattern) {
-        CreateSerialFor(MakeValue(task_id), MakeValue(op->extent), 
MakeValue(num_task),
-                        op->loop_var, op->body);
+        CreateSerialFor(MakeValue(task_id), MakeValue(end), 
MakeValue(num_task), op->loop_var,
+                        op->body);
       } else {
         PrimExpr step = (op->extent + num_task - make_const(t, 1)) / num_task;
         PrimExpr begin = min(task_id * step, op->extent);
-        PrimExpr end = min((task_id + make_const(t, 1)) * step, op->extent);
+        end = min((task_id + make_const(t, 1)) * step, end);
         CreateSerialFor(MakeValue(begin), MakeValue(end),
                         llvm::ConstantInt::getSigned(GetLLVMType(end), 1), 
op->loop_var, op->body);
       }
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 5f8b599a3b..131c8212c5 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -2023,7 +2023,6 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) {
 
 void CodeGenLLVM::VisitStmt_(const ForNode* op) {
   EmitDebugLocation(op);
-  ICHECK(is_zero(op->min));
   analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
   if (op->kind == ForKind::kUnrolled) {
     LOG(WARNING) << "Unroll hint get ignore at CodeGenLLVM backend, "
@@ -2031,8 +2030,11 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) {
   } else {
     ICHECK(op->kind == ForKind::kSerial);
   }
-  CreateSerialFor(MakeValue(op->min), MakeValue(op->extent),
-                  llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), 
op->loop_var, op->body);
+  PrimExpr step = op->step.value_or(make_const(op->extent->dtype, 1));
+  PrimExpr end = is_zero(op->min) ? op->extent : analyzer_->Simplify(op->min + 
op->extent);
+  llvm::Value* begin_value = MakeValue(op->min);
+  llvm::Value* end_value = MakeValue(end);
+  CreateSerialFor(begin_value, end_value, MakeValue(step), op->loop_var, 
op->body);
 }
 
 void CodeGenLLVM::VisitStmt_(const WhileNode* op) {
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 8ebd41645a..52ad781669 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -1120,13 +1120,21 @@ void CodeGenC::VisitStmt_(const AssertStmtNode* op) {
 }
 
 void CodeGenC::VisitStmt_(const ForNode* op) {
-  std::string extent = PrintExpr(op->extent);
+  std::string begin_str = PrintExpr(op->min);
+  PrimExpr end = is_zero(op->min) ? op->extent : 
arith::Analyzer().Simplify(op->min + op->extent);
+  std::string end_str = PrintExpr(end);
+  std::string step_str = op->step.has_value() ? PrintExpr(*op->step) : "";
   PrintIndent();
   std::string vid = AllocVarID(op->loop_var.get());
-  ICHECK(is_zero(op->min));
   stream << "for (";
   PrintType(op->loop_var.dtype(), stream);
-  stream << ' ' << vid << " = 0; " << vid << " < " << extent << "; ++" << vid 
<< ") {\n";
+  stream << ' ' << vid << " = " << begin_str << "; " << vid << " < " << 
end_str << "; ";
+  if (step_str.empty()) {
+    stream << "++" << vid;
+  } else {
+    stream << vid << " += " << step_str;
+  }
+  stream << ") {\n";
   int for_scope = BeginScope();
   PrintStmt(op->body);
   this->EndScope(for_scope);
diff --git a/src/target/source/codegen_cuda.cc 
b/src/target/source/codegen_cuda.cc
index 9565eba5d4..a9cfad9ab6 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -319,7 +319,6 @@ std::string CodeGenCUDA::Finish() {
 }
 
 void CodeGenCUDA::VisitStmt_(const tir::ForNode* op) {
-  ICHECK(is_const_int(op->min, 0));
   if (op->kind == tir::ForKind::kUnrolled) {
     PrintIndent();
     stream << "#pragma unroll\n";
diff --git a/src/target/source/codegen_webgpu.cc 
b/src/target/source/codegen_webgpu.cc
index 330a54563f..cf8176001a 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -667,13 +667,21 @@ void CodeGenWebGPU::VisitStmt_(const AllocateNode* op) {
 }
 
 void CodeGenWebGPU::VisitStmt_(const ForNode* op) {
-  std::string extent = PrintExpr(op->extent);
+  std::string begin_str = PrintExpr(op->min);
+  PrimExpr end = is_zero(op->min) ? op->extent : 
arith::Analyzer().Simplify(op->min + op->extent);
+  std::string end_str = PrintExpr(end);
+  std::string step_str = op->step.has_value() ? PrintExpr(*op->step) : "";
   std::string vid = AllocVarID(op->loop_var.get());
-  ICHECK(is_zero(op->min));
   PrintIndent();
   stream << "for (var " << vid << " : ";
   PrintType(op->loop_var.dtype(), stream);
-  stream << " = 0; " << vid << " < " << extent << "; " << vid << "++) {\n";
+  stream << " = " << begin_str << "; " << vid << " < " << end_str << "; " << 
vid;
+  if (step_str.empty()) {
+    stream << "++";
+  } else {
+    stream << " += " << step_str;
+  }
+  stream << ") {\n";
   int for_scope = BeginScope();
   PrintStmt(op->body);
   this->EndScope(for_scope);
diff --git a/src/target/spirv/codegen_spirv.cc 
b/src/target/spirv/codegen_spirv.cc
index c062926cc2..136f969896 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -672,10 +672,21 @@ void CodeGenSPIRV::VisitStmt_(const BufferStoreNode* op) {
 }
 
 void CodeGenSPIRV::VisitStmt_(const ForNode* op) {
-  ICHECK(is_zero(op->min));
   analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
   spirv::Value init_value = MakeValue(op->min);
-  spirv::Value extent_value = MakeValue(op->extent);
+  PrimExpr end = is_zero(op->min) ? op->extent : analyzer_->Simplify(op->min + 
op->extent);
+  spirv::Value end_value = MakeValue(end);
+  spirv::PhiValue loop_var = builder_->MakePhi(init_value.stype, 2);
+
+  // loop step
+  spirv::Value step;
+  if (op->HasTrivialStep()) {
+    step = op->loop_var.dtype().is_int() ? builder_->IntImm(loop_var.stype, 1)
+                                         : builder_->UIntImm(loop_var.stype, 
1);
+  } else {
+    step = MakeValue(tvm::cast(end->dtype, *op->step));
+  }
+
   // Must get init label after making value(to make sure they are correct)
   spirv::Label init_label = builder_->CurrentLabel();
   spirv::Label head_label = builder_->NewLabel();
@@ -690,9 +701,8 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) {
 
   // Loop head
   builder_->StartLabel(head_label);
-  spirv::PhiValue loop_var = builder_->MakePhi(init_value.stype, 2);
   loop_var.SetIncoming(0, init_value, init_label);
-  spirv::Value loop_cond = builder_->LT(loop_var, extent_value);
+  spirv::Value loop_cond = builder_->LT(loop_var, end_value);
   uint32_t control =
       (op->kind == ForKind::kUnrolled ? spv::LoopControlUnrollMask : 
spv::LoopControlMaskNone);
   builder_->MakeInst(spv::OpLoopMerge, merge_label, continue_label, control);
@@ -707,9 +717,8 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) {
 
   // loop continue
   builder_->StartLabel(continue_label);
-  spirv::Value one = op->loop_var.dtype().is_int() ? 
builder_->IntImm(loop_var.stype, 1)
-                                                   : 
builder_->UIntImm(loop_var.stype, 1);
-  spirv::Value next_value = builder_->Add(loop_var, one);
+
+  spirv::Value next_value = builder_->Add(loop_var, step);
   loop_var.SetIncoming(1, next_value, builder_->CurrentLabel());
   builder_->MakeInst(spv::OpBranch, head_label);
   // loop merge
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index d6dcae6540..393ac7ee57 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -41,8 +41,13 @@ Stmt DataTypeLegalizer::VisitStmt_(const ForNode* op) {
   ICHECK(op != nullptr) << "Expected type to be ForNode, but get " << 
s->GetTypeKey();
   PrimExpr e = VisitExpr(op->loop_var);
   Var var = Downcast<Var>(e);
-  return For(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), 
op->kind, op->body,
-             op->thread_binding, op->annotations);
+  auto n = CopyOnWrite(op);
+  n->min = cast(var.dtype(), op->min);
+  n->extent = cast(var.dtype(), op->extent);
+  if (op->step.has_value()) {
+    n->step = cast(var.dtype(), *op->step);
+  }
+  return For(n);
 }
 
 Stmt DataTypeLegalizer::VisitStmt_(const BlockRealizeNode* op) {
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 47622757e5..b7e28e84e7 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -132,7 +132,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
 
 // For
 For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
-         ffi::Optional<IterVar> thread_binding, ffi::Map<ffi::String, Any> 
annotations, Span span) {
+         ffi::Optional<IterVar> thread_binding, ffi::Map<ffi::String, Any> 
annotations,
+         ffi::Optional<PrimExpr> step, Span span) {
   ICHECK(loop_var.defined());
   ICHECK(min.defined());
   ICHECK(extent.defined());
@@ -148,8 +149,8 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, 
ForKind kind, Stmt body,
   require_scalar_int_dtype(min, "min");
   require_scalar_int_dtype(extent, "extent");
 
-  // When extent or min is an IntImm but has narrower dtype than loop_var, we 
directly promote them
-  // without raising errors.
+  // When extent, min or step is an IntImm but has narrower dtype than loop_var
+  // we directly promote them without raising errors.
   auto try_promote_imm_dtype = [&](const PrimExpr& e) {
     ICHECK(e.dtype().bits() <= loop_var.dtype().bits())
         << " Loop variable's dtype (" << loop_var.dtype()
@@ -168,6 +169,12 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, 
ForKind kind, Stmt body,
   ICHECK(loop_var.dtype() == min.dtype()) << loop_var.dtype() << " vs " << 
min.dtype();
   ICHECK(loop_var.dtype() == extent.dtype()) << loop_var.dtype() << " vs " << 
extent.dtype();
 
+  if (step.has_value()) {
+    require_scalar_int_dtype(*step, "step");
+    step = try_promote_imm_dtype(*step);
+    ICHECK(loop_var.dtype() == (*step).dtype()) << loop_var.dtype() << " vs " 
<< (*step).dtype();
+  }
+
   ObjectPtr<ForNode> node = ffi::make_object<ForNode>();
   node->loop_var = std::move(loop_var);
   node->min = std::move(min);
@@ -176,19 +183,22 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, 
ForKind kind, Stmt body,
   node->body = std::move(body);
   node->thread_binding = std::move(thread_binding);
   node->annotations = std::move(annotations);
+  node->step = std::move(step);
   node->span = std::move(span);
   data_ = std::move(node);
 }
 
+bool ForNode::HasTrivialStep() const { return !step.has_value() || 
is_one(*step); }
+
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def(
-      "tir.For", [](Var loop_var, PrimExpr min, PrimExpr extent, int kind, 
Stmt body,
-                    ffi::Optional<IterVar> thread_binding,
-                    ffi::Optional<ffi::Map<ffi::String, Any>> annotations, 
Span span) {
-        return For(loop_var, min, extent, static_cast<ForKind>(kind), body, 
thread_binding,
-                   annotations.value_or(ffi::Map<ffi::String, Any>()), span);
-      });
+  refl::GlobalDef().def("tir.For", [](Var loop_var, PrimExpr min, PrimExpr 
extent, int kind,
+                                      Stmt body, ffi::Optional<IterVar> 
thread_binding,
+                                      ffi::Optional<ffi::Map<ffi::String, 
Any>> annotations,
+                                      ffi::Optional<PrimExpr> step, Span span) 
{
+    return For(loop_var, min, extent, static_cast<ForKind>(kind), body, 
thread_binding,
+               annotations.value_or(ffi::Map<ffi::String, Any>()), step, span);
+  });
 }
 
 std::ostream& operator<<(std::ostream& out, ForKind type) {  // NOLINT(*)
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index 80c787b114..e6666cc638 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -46,6 +46,9 @@ void StmtVisitor::VisitStmt_(const AttrStmtNode* op) {
 void StmtVisitor::VisitStmt_(const ForNode* op) {
   this->VisitExpr(op->min);
   this->VisitExpr(op->extent);
+  if (op->step.has_value()) {
+    this->VisitExpr(*op->step);
+  }
   this->VisitStmt(op->body);
 }
 
@@ -260,13 +263,19 @@ Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) {
 Stmt StmtMutator::VisitStmt_(const ForNode* op) {
   PrimExpr min = this->VisitExpr(op->min);
   PrimExpr extent = this->VisitExpr(op->extent);
+  ffi::Optional<PrimExpr> step{std::nullopt};
+  if (op->step.has_value()) {
+    step = this->VisitExpr(*op->step);
+  }
   Stmt body = this->VisitStmt(op->body);
-  if (min.same_as(op->min) && extent.same_as(op->extent) && 
body.same_as(op->body)) {
+  if (min.same_as(op->min) && extent.same_as(op->extent) && 
body.same_as(op->body) &&
+      step.same_as(op->step)) {
     return ffi::GetRef<Stmt>(op);
   } else {
     auto n = CopyOnWrite(op);
     n->min = std::move(min);
     n->extent = std::move(extent);
+    n->step = std::move(step);
     n->body = std::move(body);
     return Stmt(n);
   }
diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc 
b/src/tir/schedule/primitive/blockize_tensorize.cc
index fbc569ece6..2ae32ea66a 100644
--- a/src/tir/schedule/primitive/blockize_tensorize.cc
+++ b/src/tir/schedule/primitive/blockize_tensorize.cc
@@ -703,7 +703,7 @@ class BlockizeRewriter : public StmtMutator {
   Stmt VisitStmt_(const ForNode* loop) final {
     if (loop == lca_->stmt) {
       return For(loop->loop_var, loop->min, loop->extent, loop->kind, 
RewriteSeq(loop->body),
-                 loop->thread_binding, loop->annotations, loop->span);
+                 loop->thread_binding, loop->annotations, loop->step, 
loop->span);
     }
     return StmtMutator::VisitStmt_(loop);
   }
diff --git a/src/tir/schedule/primitive/decompose_padding.cc 
b/src/tir/schedule/primitive/decompose_padding.cc
index 5499ab9c58..7e61fd4eb2 100644
--- a/src/tir/schedule/primitive/decompose_padding.cc
+++ b/src/tir/schedule/primitive/decompose_padding.cc
@@ -343,7 +343,7 @@ static std::pair<Stmt, BlockRealize> 
CreateInBoundBlock(const BlockRealizeNode*
     PrimExpr min = it == new_loop_ranges.end() ? loop->min : (*it).second->min;
     PrimExpr extent = it == new_loop_ranges.end() ? loop->extent : 
(*it).second->extent;
     nest_stmt_root = For(loop->loop_var, min, extent, loop->kind, 
nest_stmt_root,
-                         loop->thread_binding, loop->annotations, loop->span);
+                         loop->thread_binding, loop->annotations, loop->step, 
loop->span);
     if (loop.same_as(highest_pos_inclusive)) {
       break;
     }
diff --git a/src/tir/schedule/primitive/loop_transformation.cc 
b/src/tir/schedule/primitive/loop_transformation.cc
index b2c64e65e5..3cd364b0fd 100644
--- a/src/tir/schedule/primitive/loop_transformation.cc
+++ b/src/tir/schedule/primitive/loop_transformation.cc
@@ -1137,8 +1137,8 @@ void Reorder(ScheduleState self, const 
ffi::Array<StmtSRef>& ordered_loop_srefs)
 
 StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref) {
   if (sref->stmt->IsInstance<ForNode>()) {
-    For new_loop(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial,
-                 ffi::GetRef<Stmt>(sref->stmt));
+    For new_loop =
+        For(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial, 
ffi::GetRef<Stmt>(sref->stmt));
     self->Replace(sref, new_loop, {});
     return self->stmt2ref.at(new_loop.get());
   }
diff --git a/src/tir/schedule/primitive/reduction.cc 
b/src/tir/schedule/primitive/reduction.cc
index 49dc31e6f6..0629757a13 100644
--- a/src/tir/schedule/primitive/reduction.cc
+++ b/src/tir/schedule/primitive/reduction.cc
@@ -268,7 +268,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const 
StmtSRef& block_sref,
   std::unordered_map<Var, Var> loop_var_map;
   Stmt body = BlockRealize(init_realize);
   for (int i : chosen_loops) {
-    const ForNode* old_loop = TVM_SREF_TO_FOR(loops[i]);
+    For old_loop = ffi::GetRef<For>(TVM_SREF_TO_FOR(loops[i]));
     // Create a new equivalent to the chosen loop
     Var old_loop_var = old_loop->loop_var;
     Var new_loop_var = old_loop_var.copy_with_suffix("_init");
@@ -280,12 +280,11 @@ StmtSRef DecomposeReduction(ScheduleState self, const 
StmtSRef& block_sref,
       thread_binding.CopyOnWrite()->var = new_var;
       opt_thread_binding = thread_binding;
     }
-    body = For(/*loop_var=*/new_loop_var,
-               /*min=*/old_loop->min,
-               /*extent=*/old_loop->extent,
-               /*kind=*/old_loop->kind,
-               /*body=*/body,
-               /*thread_binding=*/opt_thread_binding);
+    auto new_loop = old_loop.CopyOnWrite();
+    new_loop->loop_var = new_loop_var;
+    new_loop->thread_binding = opt_thread_binding;
+    new_loop->body = body;
+    body = ffi::GetRef<For>(new_loop);
   }
   body = Substitute(body, loop_var_map);
   // Step 6. Mutate IR
diff --git a/src/tir/transforms/canonicalize_loop.cc 
b/src/tir/transforms/canonicalize_loop.cc
new file mode 100644
index 0000000000..93511bf84b
--- /dev/null
+++ b/src/tir/transforms/canonicalize_loop.cc
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/transforms/canonicalize_loop.cc
+ * \brief Canonicalize all loops to start from zero and step one.
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/ffi/reflection/registry.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <utility>
+
+namespace tvm {
+namespace tir {
+
+class LoopCanonicalizer : public StmtExprMutator {
+ public:
+  LoopCanonicalizer() = default;
+
+ private:
+  Stmt VisitStmt_(const ForNode* op) final {
+    if (is_zero(op->min) && op->HasTrivialStep()) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    arith::Analyzer analyzer;
+    const auto* loop_var = op->loop_var.get();
+    PrimExpr step = op->step.value_or(make_const(loop_var->dtype, 1));
+
+    // report warning for negative step, since it would be a forever loop
+    if (!analyzer.CanProveGreaterEqual(step, 1)) {
+      // TODO(tvm): prove dynamic shaped step
+      LOG(FATAL) << "Loop step for " << op->loop_var << " may not be positive: 
" << step;
+    }
+
+    new_iter_info_[loop_var] = std::make_pair(step, op->min);
+    auto n = CopyOnWrite(op);
+    n->body = VisitStmt(op->body);
+    n->min = make_zero(loop_var->dtype);
+    n->extent = analyzer.Simplify(ceildiv(op->extent, step));
+    n->step = std::nullopt;
+    new_iter_info_.erase(loop_var);
+    return For(n);
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = new_iter_info_.find(op);
+    if (it != new_iter_info_.end()) {
+      const auto& [stride, offset] = it->second;
+      return ffi::GetRef<Var>(op) * stride + offset;
+    }
+    return ffi::GetRef<Var>(op);
+  }
+
+  /*! \brief Map iter variable `x` to `x * stride + offset`. */
+  std::unordered_map<const VarNode*, std::pair<PrimExpr, PrimExpr>> 
new_iter_info_;
+};
+
+PrimFunc CanonicalizeLoop(PrimFunc func) {
+  PrimFuncNode* fptr = func.CopyOnWrite();
+  fptr->body = LoopCanonicalizer()(func->body);
+  return func;
+}
+
+namespace transform {
+
+Pass CanonicalizeLoop() {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    return CanonicalizeLoop(std::move(f));
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.CanonicalizeLoop", {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("tir.transform.CanonicalizeLoop", CanonicalizeLoop);
+}
+
+}  // namespace transform
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/transforms/common_subexpr_elim.cc 
b/src/tir/transforms/common_subexpr_elim.cc
index dfeb7fe2e2..9b9619fae9 100644
--- a/src/tir/transforms/common_subexpr_elim.cc
+++ b/src/tir/transforms/common_subexpr_elim.cc
@@ -602,7 +602,7 @@ Stmt CommonSubexpressionEliminator::VisitStmt_(const 
ForNode* op) {
     // Otherwise return a for node built with the new `min_new`, `extent_new` 
and `body_new`
     // that have just been obtained
     return For(op->loop_var, min_new, extent_new, op->kind, body_new, 
op->thread_binding,
-               op->annotations, op->span);
+               op->annotations, op->step, op->span);
   }
 }
 
diff --git a/src/tir/transforms/convert_for_loops_serial.cc 
b/src/tir/transforms/convert_for_loops_serial.cc
index a8b30ebf91..691d8b885c 100644
--- a/src/tir/transforms/convert_for_loops_serial.cc
+++ b/src/tir/transforms/convert_for_loops_serial.cc
@@ -43,7 +43,7 @@ class ForLoopSerialConverter : public StmtExprMutator {
 Stmt ForLoopSerialConverter::VisitStmt_(const ForNode* op) {
   if (op->kind == ForKind::kParallel) {
     return For(op->loop_var, op->min, op->extent, ForKind::kSerial, op->body, 
op->thread_binding,
-               op->annotations, op->span);
+               op->annotations, op->step, op->span);
   }
   return StmtExprMutator::VisitStmt_(op);
 }
diff --git a/src/tir/transforms/inject_software_pipeline.cc 
b/src/tir/transforms/inject_software_pipeline.cc
index af1b7c8bdf..f4258fc479 100644
--- a/src/tir/transforms/inject_software_pipeline.cc
+++ b/src/tir/transforms/inject_software_pipeline.cc
@@ -943,7 +943,7 @@ class PipelineRewriter : public StmtExprMutator {
     if (!is_unit_loop) {
       new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
                      unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, 
std::move(new_loop),
-                     std::nullopt, preserved_annotations_);
+                     std::nullopt, preserved_annotations_, std::nullopt);
     }
 
     // Update producer heads in the global async states.
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index dba13cfbbc..8bcb2077c6 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -362,9 +362,9 @@ class IRConvertSSA final : public StmtExprMutator {
     if (defined_.count(v.get())) {
       ScopedRedefine redefine(this, v);
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
-      op = stmt.as<ForNode>();
-      return For(redefine.new_var, op->min, op->extent, op->kind, op->body, 
op->thread_binding,
-                 op->annotations);
+      auto n = ffi::make_object<ForNode>(*stmt.as<ForNode>());
+      n->loop_var = redefine.new_var;
+      return For(n);
     } else {
       defined_.insert(v.get());
       return StmtExprMutator::VisitStmt_(op);
diff --git a/src/tir/transforms/lift_thread_binding.cc 
b/src/tir/transforms/lift_thread_binding.cc
index 2dffc11b72..45bbf4af52 100644
--- a/src/tir/transforms/lift_thread_binding.cc
+++ b/src/tir/transforms/lift_thread_binding.cc
@@ -133,7 +133,7 @@ class ThreadBindingLifter : public StmtExprMutator {
                    ForKind::kThreadBinding, std::move(body),
                    IterVar(Range(nullptr), Var(iter_var->thread_tag, 
iter_var->var->dtype),
                            kThreadIndex, iter_var->thread_tag),
-                   annotation);
+                   annotation, std::nullopt);
       }
     }
     if (is_kernel_root) {
diff --git a/src/tir/transforms/loop_partition.cc 
b/src/tir/transforms/loop_partition.cc
index e644c387cf..fd9bd2d653 100644
--- a/src/tir/transforms/loop_partition.cc
+++ b/src/tir/transforms/loop_partition.cc
@@ -760,14 +760,18 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var 
var, PrimExpr min, Prim
 inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt 
body) {
   const ForNode* for_node = static_cast<const ForNode*>(node);
   ICHECK(for_node);
+
   if (analyzer_.CanProve(extent == make_const(DataType::Int(32), 1)) &&
       !no_unroll_loop_with_extent_one_ && for_node->annotations.empty()) {
     // If the loop extent is 1, do not create the loop anymore
     return Substitute(body, {{Var{for_node->loop_var}, 
make_const(DataType::Int(32), 0)}});
   } else {
     ICHECK(for_node->kind != ForKind::kThreadBinding);
-    return For(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent, 
for_node->kind, body,
-               for_node->thread_binding, for_node->annotations);
+    auto new_loop = ffi::make_object<ForNode>(*for_node);
+    new_loop->min = IntImm(for_node->min.dtype(), 0);
+    new_loop->extent = extent;
+    new_loop->body = body;
+    return For(new_loop);
   }
 }
 
diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc 
b/src/tir/transforms/lower_cross_thread_reduction.cc
index 25e8734ff1..2f7ac3ddb1 100644
--- a/src/tir/transforms/lower_cross_thread_reduction.cc
+++ b/src/tir/transforms/lower_cross_thread_reduction.cc
@@ -878,7 +878,9 @@ class CrossThreadReductionTransformer : public StmtMutator {
           /*body=*/body,                                      //
           /*thread_binding=*/
           IterVar(NullValue<Range>(), Var("", loop_vars[i]->dtype), 
IterVarType::kThreadIndex,
-                  "threadIdx." + dim_index));
+                  "threadIdx." + dim_index),
+          /*annotations=*/{},
+          /*step=*/std::nullopt);
     }
     return body;
   }
diff --git a/src/tir/transforms/lower_opaque_block.cc 
b/src/tir/transforms/lower_opaque_block.cc
index 2e53e89667..c0363dd898 100644
--- a/src/tir/transforms/lower_opaque_block.cc
+++ b/src/tir/transforms/lower_opaque_block.cc
@@ -111,7 +111,7 @@ class OpaqueBlockLower : public StmtExprMutator {
     } else {
       // Case 3. An ordinary loop
       body = For(op->loop_var, std::move(min), std::move(extent), op->kind, 
std::move(body),
-                 std::nullopt, new_annotations);
+                 std::nullopt, new_annotations, op->step);
     }
     // Step 5. Insert nested attrs
     for (auto it = pragma_attrs.rbegin(); it != pragma_attrs.rend(); ++it) {
diff --git a/src/tir/transforms/memhammer_coalesce.cc 
b/src/tir/transforms/memhammer_coalesce.cc
index 094f48e321..0d5b270442 100644
--- a/src/tir/transforms/memhammer_coalesce.cc
+++ b/src/tir/transforms/memhammer_coalesce.cc
@@ -128,7 +128,8 @@ Stmt SplitBindVectorize(const Stmt& stmt, const 
ConstraintSet& constraints) {
   body = For(new_loop_vars.back(), 0, vector_len, ForKind::kVectorized, 
std::move(body));
   for (int i = n - 2; i >= 1; i--) {
     body = For(new_loop_vars[i], 0, factors[i], ForKind::kThreadBinding, 
std::move(body),
-               IterVar(Range(nullptr), Var(thread_axis[i - 1]), kThreadIndex, 
thread_axis[i - 1]));
+               IterVar(Range(nullptr), Var(thread_axis[i - 1]), kThreadIndex, 
thread_axis[i - 1]),
+               {}, std::nullopt);
   }
   return For(new_loop_vars[0], 0, factors[0], ForKind::kSerial, 
std::move(body));
 }
diff --git a/src/tir/transforms/memhammer_tensorcore_rewrite.cc 
b/src/tir/transforms/memhammer_tensorcore_rewrite.cc
index e16c518771..e69ac30366 100644
--- a/src/tir/transforms/memhammer_tensorcore_rewrite.cc
+++ b/src/tir/transforms/memhammer_tensorcore_rewrite.cc
@@ -70,8 +70,9 @@ std::pair<Stmt, ffi::Optional<For>> TileWmmaBlock(Stmt stmt) {
   }
   For compute_location = Downcast<For>(body);
   for (int i = n - 3; i >= 0; i--) {
-    body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, 
loops[i]->kind, std::move(body),
-               loops[i]->thread_binding, loops[i]->annotations);
+    auto new_loop = ffi::GetRef<For>(loops[i]);
+    new_loop.CopyOnWrite()->body = std::move(body);
+    body = new_loop;
   }
   return {body, compute_location};
 }
@@ -187,8 +188,9 @@ Stmt RewriteWmmaLoad(Stmt stmt) {
           },
           /*annotations=*/{}));
   for (int i = n - 3; i >= 0; i--) {
-    wmma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, 
loops[i]->kind,
-                    std::move(wmma_body), loops[i]->thread_binding, 
loops[i]->annotations);
+    auto new_loop = ffi::GetRef<For>(loops[i]);
+    new_loop.CopyOnWrite()->body = std::move(wmma_body);
+    wmma_body = new_loop;
   }
   return wmma_body;
 }
@@ -290,8 +292,9 @@ Stmt RewriteWmmaStore(Stmt stmt) {
             },
             /*annotations=*/{}));
   for (int i = n - 3; i >= 0; i--) {
-    wmma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, 
loops[i]->kind,
-                    std::move(wmma_body), loops[i]->thread_binding, 
loops[i]->annotations);
+    auto new_loop = ffi::GetRef<For>(loops[i]);
+    new_loop.CopyOnWrite()->body = std::move(wmma_body);
+    wmma_body = new_loop;
   }
   return wmma_body;
 }
@@ -395,8 +398,9 @@ std::pair<Stmt, ffi::Optional<For>> 
TileMmaToGlobalBlock(Stmt stmt) {
   }
   For compute_location = Downcast<For>(body);
   for (int i = n - 3; i >= 0; i--) {
-    body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, 
loops[i]->kind, std::move(body),
-               loops[i]->thread_binding, loops[i]->annotations);
+    auto new_loop = ffi::GetRef<For>(loops[i]);
+    new_loop.CopyOnWrite()->body = std::move(body);
+    body = new_loop;
   }
   return {body, compute_location};
 }
@@ -484,21 +488,21 @@ Stmt RewriteMmaStore(Stmt stmt) {
             /*reads=*/{BufferRegion(src_buffer, read_region)},
             /*writes=*/{BufferRegion(tgt_buffer, write_region)},
             /*name_hint=*/"mma_store",
-            AttrStmt(/*node=*/IterVar(
-                         /*dom=*/Range::FromMinExtent(0, 32),
-                         /*var=*/tx,
-                         /*iter_type=*/IterVarType::kThreadIndex,
-                         /*thread_tag=*/"threadIdx.x"),
-                     /*attr_key=*/"thread_extent",
-                     /*value=*/Integer(32),
-                     /*body=*/
-                     For(vec, 0, 2, ForKind::kVectorized,
-                         /*body=*/
-                         BufferStore(new_tgt_buffer,
-                                     BufferLoad(new_src_buffer,
-                                                {floordiv(tx, 4), floormod(tx, 
4) * 2 + vec}),
-                                     {floordiv(tx, 4), floormod(tx, 4) * 2 + 
vec}),
-                         /*annotations=*/{})),
+            AttrStmt(
+                /*node=*/IterVar(
+                    /*dom=*/Range::FromMinExtent(0, 32),
+                    /*var=*/tx,
+                    /*iter_type=*/IterVarType::kThreadIndex,
+                    /*thread_tag=*/"threadIdx.x"),
+                /*attr_key=*/"thread_extent",
+                /*value=*/Integer(32),
+                /*body=*/
+                For(vec, 0, 2, ForKind::kVectorized,
+                    /*body=*/
+                    BufferStore(
+                        new_tgt_buffer,
+                        BufferLoad(new_src_buffer, {floordiv(tx, 4), 
floormod(tx, 4) * 2 + vec}),
+                        {floordiv(tx, 4), floormod(tx, 4) * 2 + vec}))),
             /*init=*/std::nullopt,
             /*alloc_buffers=*/{},
             /*match_buffers=*/
@@ -510,8 +514,9 @@ Stmt RewriteMmaStore(Stmt stmt) {
 
   // Step 3.4. wrap outer loops
   for (int i = n - 3; i >= 0; i--) {
-    mma_body = For(loops[i]->loop_var, loops[i]->min, loops[i]->extent, 
loops[i]->kind,
-                   std::move(mma_body), loops[i]->thread_binding, 
loops[i]->annotations);
+    auto new_loop = ffi::GetRef<For>(loops[i]);
+    new_loop.CopyOnWrite()->body = std::move(mma_body);
+    mma_body = new_loop;
   }
   return mma_body;
 }
diff --git a/src/tir/transforms/storage_rewrite.cc 
b/src/tir/transforms/storage_rewrite.cc
index 4af12c69a3..830364788c 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -510,7 +510,7 @@ class StoragePlanRewriter : public StmtExprMutator {
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
       op = stmt.as<ForNode>();
       return For(op->loop_var, op->min, op->extent, op->kind, MakeAttach(svec, 
op->body),
-                 op->thread_binding, op->annotations);
+                 op->thread_binding, op->annotations, op->step);
     } else {
       return StmtExprMutator::VisitStmt_(op);
     }
diff --git a/src/tir/transforms/unify_thread_binding.cc 
b/src/tir/transforms/unify_thread_binding.cc
index fa1e221459..502acd5a46 100644
--- a/src/tir/transforms/unify_thread_binding.cc
+++ b/src/tir/transforms/unify_thread_binding.cc
@@ -79,7 +79,8 @@ class ThreadBindingUnifier : public StmtExprMutator {
                  /*extent=*/IntImm(dtype, 1),      //
                  /*kind=*/ForKind::kSerial, stmt,  //
                  /*thread_binding=*/std::nullopt,  //
-                 /*annotation=*/std::move(annotations));
+                 /*annotation=*/std::move(annotations),
+                 /*step=*/std::nullopt);
     }
   }
 
@@ -155,7 +156,8 @@ class ThreadBindingUnifier : public StmtExprMutator {
       result = For(thread_binding->var, thread_binding->dom->min, 
thread_binding->dom->extent,
                    ForKind::kThreadBinding, result,
                    IterVar(NullValue<Range>(), Var(""), 
IterVarType::kThreadIndex,
-                           thread_binding->thread_tag));
+                           thread_binding->thread_tag),
+                   {}, std::nullopt);
       launch_threads_.pop_back();
     }
     return result;
diff --git a/src/tir/transforms/unroll_loop.cc 
b/src/tir/transforms/unroll_loop.cc
index d1269634ab..74abea57ba 100644
--- a/src/tir/transforms/unroll_loop.cc
+++ b/src/tir/transforms/unroll_loop.cc
@@ -156,8 +156,9 @@ class LoopUnroller : public StmtExprMutator {
     } else {
       if (auto_unroll) {
         if (op->kind != ForKind::kUnrolled) {
-          return For(op->loop_var, op->min, op->extent, ForKind::kUnrolled, op->body,
-                     op->thread_binding, op->annotations);
+          auto n = CopyOnWrite(op);
+          n->kind = ForKind::kUnrolled;
+          return For(n);
         }
       }
       return stmt;
diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc
index 857f0b4cea..068903baa8 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -752,8 +752,10 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
     if (extent.same_as(op->extent) && body.same_as(op->body)) {
       return ffi::GetRef<Stmt>(op);
     } else {
-      return For(op->loop_var, op->min, extent, op->kind, body, op->thread_binding,
-                 op->annotations);
+      auto n = CopyOnWrite(op);
+      n->extent = extent;
+      n->body = body;
+      return For(n);
     }
   }
   // IfThenElse
diff --git a/tests/python/codegen/test_target_codegen.py b/tests/python/codegen/test_target_codegen.py
index 3332d015a8..7530786a38 100644
--- a/tests/python/codegen/test_target_codegen.py
+++ b/tests/python/codegen/test_target_codegen.py
@@ -16,7 +16,7 @@
 # under the License.
 
 import pytest
-
+import numpy as np
 import tvm
 from tvm.script import tir as T
 
@@ -88,5 +88,47 @@ def test_buffer_load_predicate_not_supported_gpu(target):
             tvm.compile(func)
 
 
+@tvm.testing.parametrize_targets("c", "llvm")
+def test_codegen_loop_step(target):
+    @T.prim_func
+    def test_loop_step(
+        A: T.Buffer((1024,), "float32"),
+        B: T.Buffer((1024,), "float32"),
+        C: T.Buffer((1024,), "float32"),
+    ):
+        for i in T.serial(3, 1024, step=96):
+            C[i] = A[i] + B[i]
+
+    with tvm.transform.PassContext(disabled_pass=["tir.CanonicalizeLoop"]):
+        lib = tvm.compile(test_loop_step, target=target)
+
+    src = lib.mod.inspect_source()
+    if target == "c":
+        assert src.find("for (int32_t i = 3; i < 1024; i += 96)") >= 0
+
+    dev = tvm.device(target, 0)
+    a_np = np.random.rand(1024).astype("float32")
+    b_np = np.random.rand(1024).astype("float32")
+    c_np = np.zeros(1024, dtype="float32")
+    a_tvm = tvm.runtime.tensor(a_np, dev)
+    b_tvm = tvm.runtime.tensor(b_np, dev)
+    c_tvm = tvm.runtime.tensor(c_np, dev)
+
+    lib(a_tvm, b_tvm, c_tvm)
+
+    c_result = c_tvm.numpy()
+
+    # Check that the loop executes at positions 3, 99, 195, 291, 387, 483, 579, 675, 771, 867, 963
+    for i in range(3, 1024, 96):
+        np.testing.assert_allclose(c_result[i], a_np[i] + b_np[i], rtol=1e-5)
+
+    # Assert non-touched positions remain zero
+    for i in range(0, 3):
+        assert c_result[i] == 0.0
+    for i in range(4, 1024):
+        if (i - 3) % 96 != 0:
+            assert c_result[i] == 0.0
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py
index 0841d0f545..1b31e64414 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -877,5 +877,37 @@ def test_thread_return():
     assert "return;" in cuda_code
 
 
+@tvm.testing.requires_gpu
+@tvm.testing.requires_cuda
+def test_cuda_loop_step():
+    @T.prim_func
+    def cuda_loop_step(
+        A: T.Buffer((1024,), "float32"),
+        B: T.Buffer((1024,), "float32"),
+        C: T.Buffer((1024,), "float32"),
+    ):
+        # Each thread computes a strided subset of the i loop: start = tx*3, step = 96 (3 * 32 threads)
+        for bx in T.thread_binding(1, "blockIdx.x"):
+            for tx in T.thread_binding(96, "threadIdx.x"):
+                for i in T.serial(tx, 1024, step=96):
+                    C[i] = A[i] + B[i]
+
+    target = tvm.target.Target({"kind": "cuda"})
+    with tvm.transform.PassContext(disabled_pass=["tir.CanonicalizeLoop"]):
+        lib = tvm.compile(cuda_loop_step, target=target)
+
+    cuda_src = lib.mod.imports[0].inspect_source()
+    assert "i += 96" in cuda_src
+    dev = tvm.cuda(0)
+    a_np = np.random.uniform(1, 100, (1024,)).astype("float32")
+    b_np = np.random.uniform(1, 100, (1024,)).astype("float32")
+    c_np = np.zeros((1024,), dtype="float32")
+    a_nd = tvm.runtime.tensor(a_np, dev)
+    b_nd = tvm.runtime.tensor(b_np, dev)
+    c_nd = tvm.runtime.tensor(c_np, dev)
+    lib["main"](a_nd, b_nd, c_nd)
+    tvm.testing.assert_allclose(c_nd.numpy(), a_np + b_np)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/tir-base/test_tir_nodes.py b/tests/python/tir-base/test_tir_nodes.py
index bc7cfeae17..85cd726dda 100644
--- a/tests/python/tir-base/test_tir_nodes.py
+++ b/tests/python/tir-base/test_tir_nodes.py
@@ -134,6 +134,7 @@ def test_basic():
 def test_stmt():
     x = tvm.tir.Evaluate(0)
     tvm.tir.For(te.var("i"), 0, 1, tvm.tir.ForKind.SERIAL, x)
+    tvm.tir.For(te.var("i"), 0, 1, tvm.tir.ForKind.UNROLLED, x, step=2)
 
 
 def test_dir():
diff --git a/tests/python/tir-transform/test_tir_transform_canonicalize_loop.py b/tests/python/tir-transform/test_tir_transform_canonicalize_loop.py
new file mode 100644
index 0000000000..6f6d88137c
--- /dev/null
+++ b/tests/python/tir-transform/test_tir_transform_canonicalize_loop.py
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import tvm
+from tvm import tir
+from tvm.script import tir as T
+
+
+def test_canonicalize_loop():
+    @T.prim_func
+    def before(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"]):
+        T.func_attr({"global_symbol": "main"})
+        for i in range(1, 128, 5):
+            B[i] = A[i] + 1.0
+
+    @T.prim_func
+    def expected(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"]):
+        T.func_attr({"global_symbol": "main"})
+        for i in T.serial(0, 26):
+            B[i * 5 + 1] = A[i * 5 + 1] + 1.0
+
+    mod = tvm.IRModule.from_expr(before)
+    mod = tir.transform.CanonicalizeLoop()(mod)
+    tvm.ir.assert_structural_equal(mod["main"], expected)
+
+
+def test_canonicalize_nested_loop():
+    @T.prim_func
+    def before(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]):
+        T.func_attr({"global_symbol": "main"})
+        for i in range(1, 128, 5):
+            for j in range(2, 128, 3):
+                B[i, j] = A[i, j] + 1.0
+
+    @T.prim_func
+    def expected(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"]):
+        T.func_attr({"global_symbol": "main"})
+        for i in T.serial(0, 26):
+            for j in T.serial(0, 42):
+                B[i * 5 + 1, j * 3 + 2] = A[i * 5 + 1, j * 3 + 2] + 1.0
+
+    mod = tvm.IRModule.from_expr(before)
+    mod = tir.transform.CanonicalizeLoop()(mod)
+    tvm.ir.assert_structural_equal(mod["main"], expected)
+
+
+def test_canonicalize_negative_step():
+    @T.prim_func
+    def before(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"]):
+        T.func_attr({"global_symbol": "main"})
+        for i in T.serial(0, 127, step=-3):
+            B[i] = A[i] + 1.0
+
+    mod = tvm.IRModule.from_expr(before)
+    with pytest.raises(tvm.error.InternalError):
+        mod = tir.transform.CanonicalizeLoop()(mod)
+
+
+def test_canonicalize_dynamic_step():
+    """Currently we report error for dynamic step since we could not prove it is positive"""
+
+    @T.prim_func
+    def before(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"], step: T.int32):
+        T.func_attr({"global_symbol": "main"})
+        for i in T.serial(0, 128, step=step):
+            B[i] = A[i] + 1.0
+
+    mod = tvm.IRModule.from_expr(before)
+    with pytest.raises(tvm.error.InternalError):
+        mod = tir.transform.CanonicalizeLoop()(mod)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py
index f1569be5b1..3b84e919c8 100644
--- a/tests/python/tvmscript/test_tvmscript_parser_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py
@@ -327,6 +327,32 @@ def test_tir_starred_for_loop():
     tvm.ir.assert_structural_equal(starred, non_starred)
 
 
+def test_tir_loop_steps():
+    N = T.Var("N", "int32")
+
+    @T.prim_func(private=True)
+    def loop_with_steps(
+        A: T.Buffer((N,)), B: T.Buffer((N,)), C: T.Buffer((N,)), tid: T.int32, v: T.int32
+    ):
+        for i in T.serial(tid, N, step=2):
+            C[i] = A[i] + B[i]
+        for i in T.unroll(tid, N, step=3):
+            C[i] = A[i] + B[i]
+        for i in T.vectorized(tid, N, step=4):
+            C[i] = A[i] + B[i]
+        for i in T.parallel(tid, N, step=5):
+            C[i] = A[i] + B[i]
+        for i in T.serial(tid, N, step=v):
+            C[i] = A[i] + B[i]
+
+    stmts = loop_with_steps.body.seq
+    assert stmts[0].step == 2
+    assert stmts[1].step == 3
+    assert stmts[2].step == 4
+    assert stmts[3].step == 5
+    assert stmts[4].step.name == "v"
+
+
 def test_tir_empty_tuple_index():
     @T.macro
     def bar(val):
diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py
index 1954ca773f..b3d459b2e6 100644
--- a/tests/python/tvmscript/test_tvmscript_roundtrip.py
+++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py
@@ -4018,6 +4018,25 @@ def func_with_loop_jumps():
     return func
 
 
+def func_with_loop_steps():
+    @T.prim_func
+    def func(
+        A: T.Buffer((1024,)), B: T.Buffer((1024,)), C: T.Buffer((1024,)), tid: T.int32, v: T.int32
+    ):
+        for i in T.serial(tid, 1024, step=2):
+            C[i] = A[i] + B[i]
+        for i in T.unroll(tid, 1024, step=3):
+            C[i] = A[i] + B[i]
+        for i in T.vectorized(tid, 1024, step=4):
+            C[i] = A[i] + B[i]
+        for i in T.parallel(tid, 1024, step=5):
+            C[i] = A[i] + B[i]
+        for i in range(tid, 1024, 6):
+            C[i] = A[i] + B[i]
+
+    return func
+
+
 def op_of_literal():
     op_list = [
         (T.exp, 0),
@@ -4237,6 +4256,7 @@ ir_generator = tvm.testing.parameter(
     return_zero_private_with_attr,
     func_attr_with_list,
     func_with_loop_jumps,
+    func_with_loop_steps,
     *op_of_literal(),
     *relax_match_cast_struct_info_proxy(),
     relax_symbolic_size_var,

Reply via email to