Hzfengsy commented on a change in pull request #7630:
URL: https://github.com/apache/tvm/pull/7630#discussion_r593893931



##########
File path: python/tvm/script/scope_handler.py
##########
@@ -185,92 +226,365 @@ def let(var, value, span):
         super().__init__(let, concise_scope=False, def_symbol=False)
 
 
+@register
+class Block(WithScopeHandler):
+    """ With scope handler tir.block(extents, name) as iter_vars"""
+
+    def __init__(self):
+        def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
+            assert self.node
+            assert self.context
+            assert self.body
+            block_info = self.context.block_info_stack[-1]
+            if axes is None:
+                axes = []
+            if len(axes) != len(self.block_vars):
+                self.context.report_error(
+                    "Inconsistent number of block vars, "
+                    + f"gets {len(axes)} axes but {len(self.block_vars)} block 
vars.",
+                    self.node.span,
+                )
+            block_iters: List[IterVar] = []
+            for i, axis in enumerate(axes):
+                axis = tvm.runtime.convert(axis)
+                if isinstance(axis, tvm.tir.PrimExpr):
+                    block_var_dom = Range.from_min_extent(0, axis)
+                    block_iters.append(IterVar(block_var_dom, 
self.block_vars[i], 0))
+                elif isinstance(axis, Range):
+                    block_iters.append(IterVar(axis, self.block_vars[i], 0))
+                elif isinstance(axis, IterVar):
+                    block_iters.append(IterVar(axis.dom, self.block_vars[i], 
axis.iter_type))
+                else:
+                    self.context.report_error(
+                        "Invalid argument of tir.block(), "
+                        + f"expects PrimExpr, Range or IterVar, but gets 
{type(axis)}",
+                        self.node.span,
+                    )
+
+            # create block read/write regions
+
+            reads: List[BufferRegion] = (
+                [buffer_slice_to_region(read) for read in block_info.reads]
+                if block_info.reads
+                else []
+            )
+            writes: List[BufferRegion] = (
+                [buffer_slice_to_region(write) for write in block_info.writes]
+                if block_info.writes
+                else []
+            )
+            inner = tvm.tir.Block(
+                block_iters,
+                reads,
+                writes,
+                name_hint,
+                self.body,
+                block_info.init,
+                block_info.alloc_buffers,
+                block_info.match_buffers,
+                block_info.annotations,
+                span,
+            )
+            # create block var iter binding
+            values: List[PrimExpr]
+            if not block_info.iter_bindings:
+                values = self.context.loop_stack[-2].copy()

Review comment:
       Yes we do




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to