Hzfengsy commented on a change in pull request #7630:
URL: https://github.com/apache/tvm/pull/7630#discussion_r592439797
##########
File path: python/tvm/script/scope_handler.py
##########
@@ -185,92 +226,365 @@ def let(var, value, span):
super().__init__(let, concise_scope=False, def_symbol=False)
+@register
+class Block(WithScopeHandler):
+ """ With scope handler tir.block(extents, name) as iter_vars"""
+
+ def __init__(self):
+ def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
+ assert self.node
+ assert self.context
+ assert self.body
+ block_info = self.context.block_info_stack[-1]
+ if axes is None:
+ axes = []
+ if len(axes) != len(self.block_vars):
+ self.context.report_error(
+ "Inconsistent number of block vars, "
+ + f"gets {len(axes)} axes but {len(self.block_vars)} block
vars.",
+ self.node.span,
+ )
+ block_iters: List[IterVar] = []
+ for i, axis in enumerate(axes):
+ axis = tvm.runtime.convert(axis)
+ if isinstance(axis, tvm.tir.PrimExpr):
+ block_var_dom = Range.from_min_extent(0, axis)
+ block_iters.append(IterVar(block_var_dom,
self.block_vars[i], 0))
+ elif isinstance(axis, Range):
+ block_iters.append(IterVar(axis, self.block_vars[i], 0))
+ elif isinstance(axis, IterVar):
+ block_iters.append(IterVar(axis.dom, self.block_vars[i],
axis.iter_type))
+ else:
+ self.context.report_error(
+ "Invalid argument of tir.block(), "
+ + f"expects PrimExpr, Range or IterVar, but gets
{type(axis)}",
+ self.node.span,
+ )
+
+ # create block read/write regions
+
+ reads: List[BufferRegion] = (
+ [buffer_slice_to_region(read) for read in block_info.reads]
+ if block_info.reads
+ else []
+ )
+ writes: List[BufferRegion] = (
+ [buffer_slice_to_region(write) for write in block_info.writes]
+ if block_info.writes
+ else []
+ )
+ inner = tvm.tir.Block(
+ block_iters,
+ reads,
+ writes,
+ name_hint,
+ self.body,
+ block_info.init,
+ block_info.alloc_buffers,
+ block_info.match_buffers,
+ block_info.annotations,
+ span,
+ )
+ # create block var iter binding
+ values: List[PrimExpr]
+ if not block_info.iter_bindings:
+ values = self.context.loop_stack[-2].copy()
+ if len(values) == 0:
+ values = [tvm.tir.const(float("nan"), dtype="float32")] *
len(block_iters)
+ elif len(values) != len(block_iters):
+ self.context.report_error(
+ "Autocomplete block iter var binding expects larger
number of loops",
+ self.node.span,
+ )
+ else:
+ for block_var in self.block_vars:
+ if block_var not in block_info.iter_bindings:
+ self.context.report_error(
+ "Missing block iter var binding for " +
block_var.name,
+ self.node.span,
+ )
+ values = [block_info.iter_bindings[block_var] for block_var in
self.block_vars]
+ predicate = (
+ tvm.tir.const(True, "bool")
+ if block_info.predicate is None
+ else block_info.predicate
+ )
+ body = tvm.tir.BlockRealize(values, predicate, inner, span)
+ return body
+
+ super().__init__(func=block, concise_scope=False, def_symbol=True)
+ self.block_vars = None
+
+ def enter_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
+ # define block vars
+ assert isinstance(node, ast.With)
+
+ var_names = WithScopeHandler.get_optional_var_names(node, context)
+ self.block_vars = [tvm.te.var(name) for name in var_names]
+ for block_var in self.block_vars:
+ context.update_symbol(block_var.name, block_var, node)
+
+
+@register
+class InitBlock(WithScopeHandler):
+ """ With scope handler tir.init()"""
+
+ def __init__(self):
+ def init(span: Span = None):
+ assert self.context
+ if self.context.block_info_stack[-2].init is not None:
+ self.context.report_error("Duplicate init block declaration",
span)
Review comment:
It is a bit tricky to show two spans (e.g. both the duplicate and the original `tir.init()` declarations) in a single `report_error` call.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org