tkonolige commented on a change in pull request #7630:
URL: https://github.com/apache/tvm/pull/7630#discussion_r591880328
##########
File path: python/tvm/script/context_maintainer.py
##########
@@ -16,59 +16,138 @@
# under the License.
"""TVM Script Context Maintainer for TIR"""
-from tvm.te import schedule
+from typing import List, Mapping, Union, Optional, Dict, Callable
+import synr
+
+
+import tvm
+from tvm.ir import Span
+from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
+from tvm.runtime import Object
+from .node import BufferSlice
+
+
+class BlockInfo:
+ """Information for block and block_realize signature"""
+
+ alloc_buffers: List[Buffer]
+ match_buffers: List[MatchBufferRegion]
+ iter_bindings: Mapping[Var, PrimExpr]
+ reads: Optional[List[BufferSlice]]
+ writes: Optional[List[BufferSlice]]
+ annotations: Optional[Mapping[str, Object]]
+ predicate: Optional[PrimExpr]
+ init: Optional[Stmt]
+
+ def __init__(self):
+ self.alloc_buffers = []
+ self.match_buffers = []
+ self.iter_bindings = {}
+ self.reads = None
+ self.writes = None
+ self.annotations = None
+ self.predicate = None
+ self.init = None
class ContextMaintainer:
"""Maintain all the necessary context info"""
- def __init__(self, parser):
+ # scope context
+ node_stack: List[List[synr.ast.Node]]
+ block_info_stack: List[BlockInfo]
+ loop_stack: List[List[Var]]
Review comment:
I think a little more documentation on what these stacks hold could be
useful.
##########
File path: python/tvm/script/scope_handler.py
##########
@@ -185,92 +226,365 @@ def let(var, value, span):
super().__init__(let, concise_scope=False, def_symbol=False)
+@register
+class Block(WithScopeHandler):
+ """ With scope handler tir.block(extents, name) as iter_vars"""
+
+ def __init__(self):
+ def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
+ assert self.node
+ assert self.context
+ assert self.body
+ block_info = self.context.block_info_stack[-1]
+ if axes is None:
+ axes = []
+ if len(axes) != len(self.block_vars):
+ self.context.report_error(
+ "Inconsistent number of block vars, "
+ + f"gets {len(axes)} axes but {len(self.block_vars)} block
vars.",
+ self.node.span,
+ )
+ block_iters: List[IterVar] = []
+ for i, axis in enumerate(axes):
+ axis = tvm.runtime.convert(axis)
+ if isinstance(axis, tvm.tir.PrimExpr):
+ block_var_dom = Range.from_min_extent(0, axis)
+ block_iters.append(IterVar(block_var_dom,
self.block_vars[i], 0))
+ elif isinstance(axis, Range):
+ block_iters.append(IterVar(axis, self.block_vars[i], 0))
+ elif isinstance(axis, IterVar):
+ block_iters.append(IterVar(axis.dom, self.block_vars[i],
axis.iter_type))
+ else:
+ self.context.report_error(
+ "Invalid argument of tir.block(), "
+ + f"expects PrimExpr, Range or IterVar, but gets
{type(axis)}",
+ self.node.span,
+ )
+
+ # create block read/write regions
+
+ reads: List[BufferRegion] = (
+ [buffer_slice_to_region(read) for read in block_info.reads]
+ if block_info.reads
+ else []
+ )
+ writes: List[BufferRegion] = (
+ [buffer_slice_to_region(write) for write in block_info.writes]
+ if block_info.writes
+ else []
+ )
+ inner = tvm.tir.Block(
+ block_iters,
+ reads,
+ writes,
+ name_hint,
+ self.body,
+ block_info.init,
+ block_info.alloc_buffers,
+ block_info.match_buffers,
+ block_info.annotations,
+ span,
+ )
+ # create block var iter binding
+ values: List[PrimExpr]
+ if not block_info.iter_bindings:
+ values = self.context.loop_stack[-2].copy()
+ if len(values) == 0:
+ values = [tvm.tir.const(float("nan"), dtype="float32")] *
len(block_iters)
+ elif len(values) != len(block_iters):
+ self.context.report_error(
+ "Autocomplete block iter var binding expects larger
number of loops",
+ self.node.span,
+ )
+ else:
+ for block_var in self.block_vars:
+ if block_var not in block_info.iter_bindings:
+ self.context.report_error(
+ "Missing block iter var binding for " +
block_var.name,
+ self.node.span,
+ )
+ values = [block_info.iter_bindings[block_var] for block_var in
self.block_vars]
+ predicate = (
+ tvm.tir.const(True, "bool")
+ if block_info.predicate is None
+ else block_info.predicate
+ )
+ body = tvm.tir.BlockRealize(values, predicate, inner, span)
+ return body
+
+ super().__init__(func=block, concise_scope=False, def_symbol=True)
+ self.block_vars = None
+
+ def enter_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
+ # define block vars
+ assert isinstance(node, ast.With)
Review comment:
An error message here would be helpful.
##########
File path: python/tvm/script/special_stmt.py
##########
@@ -142,12 +181,289 @@ def buffer_decl(
buffer_type,
span=span,
)
- self.context.update_symbol(self.node.lhs.id.name, buffer)
+ self.context.update_symbol(self.node.lhs.id.name, buffer,
self.node)
return buffer
super().__init__(buffer_decl, def_symbol=True)
+@register
+class AllocBuffer(SpecialStmt):
+ """Special function alloc_buffer(shape, dtype, data, strides, elem_offset,
scope, align,
+ offset_factor, buffer_type)
+
+ Example
+ -------
+ .. code-block:: python
+
+ A = tir.alloc_buffer((128, 128), dtype="float32")
+
+ """
+
+ def __init__(self):
+ def alloc_buffer(
+ shape,
+ dtype="float32",
+ data=None,
+ strides=None,
+ elem_offset=None,
+ scope="",
+ align=-1,
+ offset_factor=0,
+ buffer_type="default",
+ span=None,
+ ):
+ if not isinstance(self.node, ast.Assign):
+ self.context.report_error(
+ "Need assign the alloc_buffer to a buffer, e.g. A =
alloc_buffer(...)",
+ self.node.span,
+ )
+
+ if strides is None:
+ strides = []
+ align = convert_to_int(align, "align", self.context.report_error,
self.node.span)
+ offset_factor = convert_to_int(
+ offset_factor, "offset_factor", self.context.report_error,
self.node.span
+ )
+ buffer = tvm.tir.decl_buffer(
+ shape,
+ dtype,
+ self.node.lhs.id.name,
+ data,
+ strides,
+ elem_offset,
+ scope,
+ align,
+ offset_factor,
+ buffer_type,
+ span=span,
+ )
+ self.context.current_block_scope().alloc_buffers.append(buffer)
+ self.context.update_symbol(self.node.lhs.id.name, buffer,
self.node)
+
+ super().__init__(alloc_buffer, def_symbol=True)
+
+
+@register
+class BlockVarBind(SpecialStmt):
+ """Special function bind(block_iter, binding_value)
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.bind(vx, i)
+
+ """
+
+ def __init__(self):
+ def bind(iter_var, values, span=None):
+ block_scope = self.context.current_block_scope()
+ if iter_var in block_scope.iter_bindings:
+ self.context.report_error("Duplicate iter_var bindings of " +
str(iter_var), span)
+ block_scope.iter_bindings[iter_var] = values
+
+ super().__init__(bind, def_symbol=False)
+
+
+@register
+class BlockReads(SpecialStmt):
+ """Special function reads([read_buffer_regions])
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]])
+
+ """
+
+ def __init__(self):
+ def reads(read_regions: Union[BufferSlice, List[BufferSlice]], span:
Span = None):
+ assert self.context
+ block_scope = self.context.current_block_scope()
+ if block_scope.reads is not None:
+ self.context.report_error(
+ "Duplicate write region declaration, "
+ + "previous one is "
+ + str(", ".join(str(x) for x in block_scope.reads)),
+ span,
+ )
+ if isinstance(read_regions, list):
+ pass
+ elif isinstance(read_regions, BufferSlice):
+ read_regions = [read_regions]
+ else:
+ self.context.report_error(
+ "Error input type. "
+ + f"Expects BufferSlice or List[BufferSlice], but gets
{type(read_regions)}",
+ span,
+ )
+ block_scope.reads = read_regions
+
+ super().__init__(reads, def_symbol=False)
+
+
+@register
+class BlockWrites(SpecialStmt):
+ """Special function writes([write_buffer_regions])
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.writes([C[vi: vi + 4, vj])
+
+ """
+
+ def __init__(self):
+ def writes(write_region: Union[BufferSlice, List[BufferSlice]], span:
Span = None):
+ assert self.context
+ block_scope = self.context.current_block_scope()
+ if block_scope.writes is not None:
+ self.context.report_error(
+ "Duplicate write region declaration, "
+ + "previous one is "
+ + str(", ".join(str(x) for x in block_scope.writes)),
+ span,
+ )
+ if isinstance(write_region, list):
+ pass
+ elif isinstance(write_region, BufferSlice):
+ write_region = [write_region]
+ else:
+ self.context.report_error(
+ "Error input type. "
+ + f"Expects BufferSlice or List[BufferSlice], but gets
{type(write_region)}",
+ span,
+ )
+ block_scope.writes = write_region
+
+ super().__init__(writes, def_symbol=False)
+
+
+@register
+class BlockAttr(SpecialStmt):
+ """Special function block_attr({attr_key: attr_value})
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.block_attr({"double_buffer_scope": 1})
+
+ """
+
+ def __init__(self):
+ def block_attr(attrs: Mapping[str, Object], span: Span = None):
+ assert self.context
+ block_scope = self.context.current_block_scope()
+ if block_scope.annotations is not None:
+ self.context.report_error(
+ "Duplicate block annotations declaration, "
+ + "previous one is "
+ + str(block_scope.annotations),
+ span,
+ )
+ attrs = {
+ key: tvm.tir.StringImm(val) if isinstance(val, str) else val
+ for key, val in attrs.items()
+ }
+ block_scope.annotations = attrs
+
+ super().__init__(block_attr, def_symbol=False)
+
+
+@register
+class BlockPredicate(SpecialStmt):
+ """Special function where(predicate)
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.where(i < 4)
+
+ """
+
+ def __init__(self):
+ def where(predicate, span=None):
+ block_scope = self.context.current_block_scope()
+ if block_scope.predicate is not None:
+ self.context.report_error(
+ "Duplicate block predicate declaration, "
+ + "previous one is "
+ + str(block_scope.predicate),
+ span,
+ )
+
+ block_scope.predicate = predicate
+
+ super().__init__(where, def_symbol=False)
+
+
+@register
+class BlockMatchBufferRegion(SpecialStmt):
+ """Special function match_buffer_region(source, strides, elem_offset,
align, offset_factor)
+
+ Example
+ -------
+ .. code-block:: python
+
+ B = tir.match_buffer_region(A[0: 4])
+
+ """
+
+ def __init__(self):
+ def match_buffer_region(
+ source,
+ strides=None,
+ elem_offset=None,
+ align=-1,
+ offset_factor=0,
+ span=None,
+ ):
+ if not isinstance(self.node, ast.Assign):
+ self.context.report_error(
+ "Need assign the match_buffer_region to a buffer, "
+ + "e.g. A = match_buffer_region(...)",
+ self.node.span,
+ )
+
+ if strides is None:
+ strides = []
+ align = convert_to_int(align, "align", self.context.report_error,
self.node.span)
+ offset_factor = convert_to_int(
+ offset_factor, "offset_factor", self.context.report_error,
self.node.span
+ )
+
+ if not isinstance(source, BufferSlice):
+ self.context.report_error(
+ "match_buffer_region needs a buffer region as source",
Review comment:
Could you add a small example to this error message, e.g. `B = tir.match_buffer_region(A[0: 4])`?
##########
File path: src/tir/analysis/block_access_region_detector.cc
##########
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/analysis/block_region_detector.cc
+ * \brief Detect block read/write regions by visiting its body
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Auto detect the block read write region
+ * It will detect the read/write region as an array in order of
appearance in AST
+ * \note This detector only accepts to visit a block and will not visit child
blocks recursively
+ */
+class BlockReadWriteDetector : public StmtExprVisitor {
+ public:
+ explicit BlockReadWriteDetector(const Map<Var, Buffer>& buffer_var_map)
+ : buffer_var_map_(buffer_var_map) {}
+
+ /*! \brief Return read regions of the block */
+ Array<BufferRegion> CollectReads();
+ /*! \brief Return write regions of the block */
+ Array<BufferRegion> CollectWrites();
+ /*!
+ * \brief Return opaque buffer regions of the block
+ * \note The buffer accessed by load/store or call with buffer.data will
+ * be marked as opaque.
+ */
+ Array<BufferRegion> CollectOpaques();
+ /*! \brief overload operator() to make sure it accepts a block node */
+ void operator()(const Stmt& stmt);
+
+ private:
+ /*! \brief Iteration range for loop_vars */
+ std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
+ /*! \brief The buffers that the current block reads */
+ std::vector<Buffer> read_buffers_;
+ /*! \brief The buffers that the current block writes */
+ std::vector<Buffer> writes_buffers_;
+ /*! \brief The opaque buffer which is access by buffer.data */
+ std::vector<Buffer> opaque_buffers_;
+ /*! \brief The read regions of the current block */
+ std::vector<std::vector<tvm::arith::IntSet>> read_regions_;
+ /*! \brief The write regions of the current block */
+ std::vector<std::vector<tvm::arith::IntSet>> write_regions_;
+ /*! \brief The outside buffer data mapping to its buffer */
+ Map<Var, Buffer> buffer_var_map_;
+ /*! \brief The analyzer for simplifying*/
+ arith::Analyzer analyzer_;
+
+ /*!
+ * \brief Update read/write buffers and regions with provided buffer and
region
+ * \param buffers The buffers should be updated
+ * \param regions The access regions should be updated
+ * \param buffer The provided buffer
+ * \param region The provided region
+ */
+ void Update(std::vector<Buffer>* buffers,
std::vector<std::vector<arith::IntSet>>* regions,
+ const Buffer& buffer, const std::vector<arith::IntSet>& region);
+
+ /*! \brief Helper function to collect access regions. */
+ Array<BufferRegion> CollectRegions(const std::vector<Buffer>& buffers,
+ const
std::vector<std::vector<tvm::arith::IntSet>>& regions);
+
+ /*! \brief Helper function to add a opaque buffer. */
+ void AddOpaque(const Var& buffer_var);
+
+ void VisitStmt_(const ForNode* op) override;
+ void VisitStmt_(const BlockRealizeNode* op) override;
+ void VisitStmt_(const BufferStoreNode* op) override;
+ void VisitStmt_(const StoreNode* op) override;
+ void VisitExpr_(const BufferLoadNode* op) override;
+ void VisitExpr_(const LoadNode* op) override;
+ void VisitExpr_(const VarNode* op) override;
+};
+
+void BlockReadWriteDetector::operator()(const Stmt& stmt) {
+ ICHECK(stmt.as<BlockNode>() != nullptr) << "Only allow to visit a block";
Review comment:
Could you change this to 'Only visiting Blocks is allowed, but got ...'?
##########
File path: python/tvm/script/parser.py
##########
@@ -716,47 +767,50 @@ def transform_Slice(self, node):
end = self.transform(node.end)
if not (isinstance(node.step, ast.Constant) and node.step.value == 1):
self.report_error("Only step size 1 is supported for slices.",
node.step.span)
- extent = end - start
- if isinstance(extent, tvm.tir.PrimExpr):
- ana = tvm.arith.Analyzer()
- extent = ana.simplify(extent)
- return tvm.ir.Range.from_min_extent(start, extent,
span=from_synr_span(node.span))
+ return Slice(start, end)
def transform_Subscript(self, node):
"""Array access visitor.
By now only 2 types of Subscript are supported:
1. Buffer[index, index, ...], Buffer element access(BufferLoad &
BufferStore)
Var[index] Buffer element access()
- 2. meta[type_key][index], Meta info access
Review comment:
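Is meta access (`meta[type_key][index]`) no longer allowed? It was removed from the list of supported subscript forms in this docstring.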
##########
File path: python/tvm/script/context_maintainer.py
##########
@@ -16,59 +16,138 @@
# under the License.
"""TVM Script Context Maintainer for TIR"""
-from tvm.te import schedule
+from typing import List, Mapping, Union, Optional, Dict, Callable
+import synr
+
+
+import tvm
+from tvm.ir import Span
+from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
+from tvm.runtime import Object
+from .node import BufferSlice
+
+
+class BlockInfo:
+ """Information for block and block_realize signature"""
+
+ alloc_buffers: List[Buffer]
+ match_buffers: List[MatchBufferRegion]
+ iter_bindings: Mapping[Var, PrimExpr]
+ reads: Optional[List[BufferSlice]]
+ writes: Optional[List[BufferSlice]]
+ annotations: Optional[Mapping[str, Object]]
+ predicate: Optional[PrimExpr]
+ init: Optional[Stmt]
+
+ def __init__(self):
+ self.alloc_buffers = []
+ self.match_buffers = []
+ self.iter_bindings = {}
+ self.reads = None
+ self.writes = None
+ self.annotations = None
+ self.predicate = None
+ self.init = None
class ContextMaintainer:
"""Maintain all the necessary context info"""
- def __init__(self, parser):
+ # scope context
+ node_stack: List[List[synr.ast.Node]]
+ block_info_stack: List[BlockInfo]
+ loop_stack: List[List[Var]]
+ symbols: List[Dict[str, Union[Var, Buffer]]]
+ # function context
+ func_params: List[Var]
+ func_buffer_map: Mapping[Var, Buffer]
+ func_dict_attr: Mapping[str, Object]
+ func_var_env_dict: Mapping[Var, str]
+ # parser and analyzer
+ _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
+ analyzer: tvm.arith.Analyzer
+
+ def __init__(self, _report_error: Callable[[str, Union[Span,
synr.ast.Span]], None]):
# scope context
self.node_stack = [] # AST nodes of scopes
+ self.block_info_stack = [] # Block info of scopes
+ self.loop_stack = [] # stack of loop vars
self.symbols = [] # symbols of scopes
# function context
self.func_params = [] # parameter list of function
self.func_buffer_map = {} # buffer_map of function
self.func_dict_attr = {} # func_attr of function
self.func_var_env_dict = {} # map from var to env_name
- # parser
- self.parser = parser
+ # parser and analyzer
+ self._report_error = _report_error
+ self.analyzer = tvm.arith.Analyzer()
- def pop_scope(self):
- """Pop the inner most scope"""
- self.symbols.pop()
- self.node_stack.pop()
+ def enter_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
+ """Creating a new scope
- def new_scope(self, nodes=None):
- """Creating a new scope"""
+ Parameters
+ ----------
+ nodes : Optional[List[synr.ast.Node]]
+ The synr AST nodes in new scope
+ """
if nodes is None:
nodes = []
self.node_stack.append(list(reversed(nodes)))
self.symbols.append(dict())
- def update_symbol(self, name, symbol):
+ def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
Review comment:
Documenting the difference between a block scope and a regular scope
would be useful.
##########
File path: python/tvm/script/scope_handler.py
##########
@@ -185,92 +226,365 @@ def let(var, value, span):
super().__init__(let, concise_scope=False, def_symbol=False)
+@register
+class Block(WithScopeHandler):
+ """ With scope handler tir.block(extents, name) as iter_vars"""
+
+ def __init__(self):
+ def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
+ assert self.node
+ assert self.context
+ assert self.body
+ block_info = self.context.block_info_stack[-1]
+ if axes is None:
+ axes = []
+ if len(axes) != len(self.block_vars):
+ self.context.report_error(
+ "Inconsistent number of block vars, "
+ + f"gets {len(axes)} axes but {len(self.block_vars)} block
vars.",
+ self.node.span,
+ )
+ block_iters: List[IterVar] = []
+ for i, axis in enumerate(axes):
+ axis = tvm.runtime.convert(axis)
+ if isinstance(axis, tvm.tir.PrimExpr):
+ block_var_dom = Range.from_min_extent(0, axis)
+ block_iters.append(IterVar(block_var_dom,
self.block_vars[i], 0))
+ elif isinstance(axis, Range):
+ block_iters.append(IterVar(axis, self.block_vars[i], 0))
+ elif isinstance(axis, IterVar):
+ block_iters.append(IterVar(axis.dom, self.block_vars[i],
axis.iter_type))
+ else:
+ self.context.report_error(
+ "Invalid argument of tir.block(), "
+ + f"expects PrimExpr, Range or IterVar, but gets
{type(axis)}",
+ self.node.span,
+ )
+
+ # create block read/write regions
+
+ reads: List[BufferRegion] = (
+ [buffer_slice_to_region(read) for read in block_info.reads]
+ if block_info.reads
+ else []
+ )
+ writes: List[BufferRegion] = (
+ [buffer_slice_to_region(write) for write in block_info.writes]
+ if block_info.writes
+ else []
+ )
+ inner = tvm.tir.Block(
+ block_iters,
+ reads,
+ writes,
+ name_hint,
+ self.body,
+ block_info.init,
+ block_info.alloc_buffers,
+ block_info.match_buffers,
+ block_info.annotations,
+ span,
+ )
+ # create block var iter binding
+ values: List[PrimExpr]
+ if not block_info.iter_bindings:
+ values = self.context.loop_stack[-2].copy()
Review comment:
Are we guaranteed that `loop_stack` has at least two items in it? Otherwise `loop_stack[-2]` raises an IndexError instead of a helpful parse error.
##########
File path: python/tvm/script/scope_handler.py
##########
@@ -185,92 +226,365 @@ def let(var, value, span):
super().__init__(let, concise_scope=False, def_symbol=False)
+@register
+class Block(WithScopeHandler):
+ """ With scope handler tir.block(extents, name) as iter_vars"""
+
+ def __init__(self):
+ def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
+ assert self.node
+ assert self.context
+ assert self.body
+ block_info = self.context.block_info_stack[-1]
+ if axes is None:
+ axes = []
+ if len(axes) != len(self.block_vars):
+ self.context.report_error(
+ "Inconsistent number of block vars, "
+ + f"gets {len(axes)} axes but {len(self.block_vars)} block
vars.",
Review comment:
```suggestion
+ f"there are {len(axes)} axes but
{len(self.block_vars)} block vars. The number of block vars should match the
number of axes.",
```
##########
File path: python/tvm/script/node.py
##########
@@ -0,0 +1,150 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+"""TVM Script nodes."""
+
+from typing import Optional, Union, List, Callable
+import synr
+
+from tvm.runtime import ObjectGeneric
+from tvm.tir import PrimExpr, Buffer, BufferLoad
+from tvm.ir import Span
+
+
+class Slice:
+ """A helper class to present slice information for BufferSlice
+
+ Parameters
+ ----------
+ start : Union[PrimExpr, int]
+ The start index.
+
+ stop : Optional[Union[PrimExpr, int]]
+ The stop index, None means the Slice is an element-wise index
+
+ span : Optional[Span]
+ The location of the slice in the source.
+ """
+
+ start: Union[PrimExpr, int]
+ stop: Optional[Union[PrimExpr, int]]
+ span: Optional[Span]
+
+ def __init__(
+ self,
+ start: Union[PrimExpr, int],
+ stop: Optional[Union[PrimExpr, int]] = None,
+ span: Optional[Span] = None,
+ ):
+ self.start = start
+ self.stop = stop
+ self.span = span
+
+
+class BufferSlice(ObjectGeneric):
+ """A generic object for representing general buffer access. Following
cases are supported:
+ - element wise access buffer[i, j], which can be convert to BufferLoad
if necessary
+ - slice access buffer[i: i + 1, j : j + 2]
+ - union of element and slice buffer[i, j: j + 2]
+
+ This node is used in TVMScript to parse BufferLoad, BufferRegion and
Realize
+
+ Parameters
+ ----------
+ buffer : Buffer
+ The buffer.
+
+ indices : List[Union[Slice, PrimExpr, int]]
+ The access indexes can be slice, PrimExpr or int.
+
+ report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
+ The error report func
+
+ span : Optional[Span]
+ The location of the buffer access in the source.
+ """
+
+ buffer: Buffer
+ slices: List[Slice]
+ report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
+ span: Optional[Span]
+
+ def __init__(
+ self,
+ buffer: Buffer,
+ indices: List[Union[Slice, PrimExpr, int]],
+ report_error: Callable[[str, Union[Span, synr.ast.Span]], None],
+ span: Optional[Span] = None,
+ ):
+ def check_index(index: Union[int, PrimExpr]):
+ """ Check input index is non-negative integer or PrimExpr"""
+ if isinstance(index, int):
+ if index < 0:
+ report_error("Negative index is not allowed during buffer
access", span)
+ elif isinstance(index, PrimExpr):
+ if index.dtype != "int32":
+ report_error(
+ "index expects an int32 type PrimExpr but gets " +
str(index.dtype),
+ index.span,
+ )
+ else:
+ report_error(
+ "Unsupported index type, expects int or tvm.tir.PrimExpr,
but gets "
Review comment:
Use 'expected ..., but got ...' instead of 'expects ..., but gets ...'.
##########
File path: python/tvm/script/scope_handler.py
##########
@@ -136,8 +166,19 @@ class Realize(WithScopeHandler):
""" With scope handler tir.realize(buffer_bounds, scope, condition) """
def __init__(self):
- def realize(buffer_bounds, scope, condition=True, span=None):
- buffer, bounds = buffer_bounds
+ def realize(
+ buffer_slice: BufferSlice, scope: str, condition: bool = True,
span: bool = None
+ ):
+ assert self.context
Review comment:
When can context be `None`? Maybe an error message here would be helpful.
##########
File path: python/tvm/script/parser.py
##########
@@ -716,47 +767,50 @@ def transform_Slice(self, node):
end = self.transform(node.end)
if not (isinstance(node.step, ast.Constant) and node.step.value == 1):
self.report_error("Only step size 1 is supported for slices.",
node.step.span)
- extent = end - start
- if isinstance(extent, tvm.tir.PrimExpr):
- ana = tvm.arith.Analyzer()
- extent = ana.simplify(extent)
- return tvm.ir.Range.from_min_extent(start, extent,
span=from_synr_span(node.span))
+ return Slice(start, end)
def transform_Subscript(self, node):
"""Array access visitor.
By now only 2 types of Subscript are supported:
1. Buffer[index, index, ...], Buffer element access(BufferLoad &
BufferStore)
Var[index] Buffer element access()
- 2. meta[type_key][index], Meta info access
+ 2. Buffer[start: stop, start: stop, ...],
BufferRealize(realize(buffer[...]))
"""
symbol = self.transform(node.params[0])
if symbol is None:
self.report_error(f"Variable {node.value.id} is not defined.",
node.params[0].span)
indexes = [self.transform(x) for x in node.params[1].values]
- if isinstance(indexes[0], tvm.ir.Range):
- return symbol, indexes
-
if isinstance(symbol, tvm.tir.expr.Var):
- return tvm.tir.Load("float32", symbol, indexes, True,
span=from_synr_span(node.span))
- if isinstance(symbol, tvm.tir.Buffer):
- return tvm.tir.BufferLoad(symbol, indexes,
span=from_synr_span(node.span))
-
- self.report_error(
- f"Cannot subscript from a {type(symbol).__name__}. Only variables
and "
- "buffers are supported.",
- node.params[0].span,
- )
+ for index in indexes:
+ if not isinstance(index, (tvm.tir.PrimExpr, int)):
+ self.report_error(
+ "Buffer load indexes expect int or PrimExpr, but get "
+ type(index),
Review comment:
```suggestion
"Buffer load indexes expect int or PrimExpr, but got " + str(type(index)),
```
##########
File path: python/tvm/script/scope_handler.py
##########
@@ -185,92 +226,365 @@ def let(var, value, span):
super().__init__(let, concise_scope=False, def_symbol=False)
+@register
+class Block(WithScopeHandler):
+ """ With scope handler tir.block(extents, name) as iter_vars"""
+
+ def __init__(self):
+ def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
+ assert self.node
+ assert self.context
+ assert self.body
Review comment:
I know this isn't part of the work on this PR, but I do want to say that
the control flow for `ScopeHandler` is incredibly confusing. It should
definitely be reworked.
Error messages here would be helpful for debugging the cases when we forget
to set the members correctly.
##########
File path: python/tvm/script/scope_handler.py
##########
@@ -57,22 +84,19 @@ def get_optional_var_names(node, context):
"""Get list of names from ast.With's optional_vars"""
assert isinstance(node, ast.With)
- var_names = None
- if isinstance(node.items[0].optional_vars, ast.Name):
- var_names = [node.items[0].optional_vars.id]
- elif isinstance(node.items[0].optional_vars, (ast.List, ast.Tuple)):
- for var in node.items[0].optional_vars.elts:
- if not isinstance(var, ast.Name):
- context.report_error("Invalid optional var definition")
- var_names = [var.id for var in node.items[0].optional_vars.elts]
+ if isinstance(node.lhs, list):
+ for var in node.lhs:
+ if not isinstance(var, ast.Var):
+ context.report_error("Invalid optional var definition",
node.span)
Review comment:
I know this error message is not part of this PR, but do you think you
could improve it? Specifically, say what is allowed.
##########
File path: python/tvm/script/parser.py
##########
@@ -401,25 +413,47 @@ def my_function(x: ty.handle): # 1. Argument types
"""
self.init_function_parsing_env()
- self.context.new_scope(nodes=node.body.stmts)
+ self.context.enter_scope(nodes=node.body.stmts)
# add parameters of function
for arg in node.params:
arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg))
- self.context.update_symbol(arg.name, arg_var)
+ self.context.update_symbol(arg.name, arg_var, node)
self.context.func_params.append(arg_var)
- # fetch the body and return a tir.PrimFunc
+ # New Scope : Implicit root block
Review comment:
I think we should document the difference between scopes and block
scopes in this file. It will help future readers understand the code.
##########
File path: python/tvm/script/scope_handler.py
##########
@@ -185,92 +226,365 @@ def let(var, value, span):
super().__init__(let, concise_scope=False, def_symbol=False)
+@register
+class Block(WithScopeHandler):
+ """ With scope handler tir.block(extents, name) as iter_vars"""
+
+ def __init__(self):
+ def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
+ assert self.node
+ assert self.context
+ assert self.body
+ block_info = self.context.block_info_stack[-1]
+ if axes is None:
+ axes = []
+ if len(axes) != len(self.block_vars):
+ self.context.report_error(
+ "Inconsistent number of block vars, "
+ + f"gets {len(axes)} axes but {len(self.block_vars)} block
vars.",
+ self.node.span,
+ )
+ block_iters: List[IterVar] = []
+ for i, axis in enumerate(axes):
+ axis = tvm.runtime.convert(axis)
+ if isinstance(axis, tvm.tir.PrimExpr):
+ block_var_dom = Range.from_min_extent(0, axis)
+ block_iters.append(IterVar(block_var_dom,
self.block_vars[i], 0))
+ elif isinstance(axis, Range):
+ block_iters.append(IterVar(axis, self.block_vars[i], 0))
+ elif isinstance(axis, IterVar):
+ block_iters.append(IterVar(axis.dom, self.block_vars[i],
axis.iter_type))
+ else:
+ self.context.report_error(
+ "Invalid argument of tir.block(), "
+ + f"expects PrimExpr, Range or IterVar, but gets
{type(axis)}",
+ self.node.span,
+ )
+
+ # create block read/write regions
+
+ reads: List[BufferRegion] = (
+ [buffer_slice_to_region(read) for read in block_info.reads]
+ if block_info.reads
+ else []
+ )
+ writes: List[BufferRegion] = (
+ [buffer_slice_to_region(write) for write in block_info.writes]
+ if block_info.writes
+ else []
+ )
+ inner = tvm.tir.Block(
+ block_iters,
+ reads,
+ writes,
+ name_hint,
+ self.body,
+ block_info.init,
+ block_info.alloc_buffers,
+ block_info.match_buffers,
+ block_info.annotations,
+ span,
+ )
+ # create block var iter binding
+ values: List[PrimExpr]
+ if not block_info.iter_bindings:
+ values = self.context.loop_stack[-2].copy()
+ if len(values) == 0:
+ values = [tvm.tir.const(float("nan"), dtype="float32")] *
len(block_iters)
+ elif len(values) != len(block_iters):
+ self.context.report_error(
+ "Autocomplete block iter var binding expects larger
number of loops",
Review comment:
Could you explain a little more in this error message? How many loops
are expected? And what do you mean by autocomplete? Is this an inferred or
automatically inserted iter var?
##########
File path: python/tvm/script/scope_handler.py
##########
@@ -185,92 +226,365 @@ def let(var, value, span):
super().__init__(let, concise_scope=False, def_symbol=False)
+@register
+class Block(WithScopeHandler):
+ """ With scope handler tir.block(extents, name) as iter_vars"""
+
+ def __init__(self):
+ def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
+ assert self.node
+ assert self.context
+ assert self.body
+ block_info = self.context.block_info_stack[-1]
+ if axes is None:
+ axes = []
+ if len(axes) != len(self.block_vars):
+ self.context.report_error(
+ "Inconsistent number of block vars, "
+ + f"gets {len(axes)} axes but {len(self.block_vars)} block
vars.",
+ self.node.span,
+ )
+ block_iters: List[IterVar] = []
+ for i, axis in enumerate(axes):
+ axis = tvm.runtime.convert(axis)
+ if isinstance(axis, tvm.tir.PrimExpr):
+ block_var_dom = Range.from_min_extent(0, axis)
+ block_iters.append(IterVar(block_var_dom,
self.block_vars[i], 0))
+ elif isinstance(axis, Range):
+ block_iters.append(IterVar(axis, self.block_vars[i], 0))
+ elif isinstance(axis, IterVar):
+ block_iters.append(IterVar(axis.dom, self.block_vars[i],
axis.iter_type))
+ else:
+ self.context.report_error(
+ "Invalid argument of tir.block(), "
+ + f"expects PrimExpr, Range or IterVar, but gets
{type(axis)}",
+ self.node.span,
+ )
+
+ # create block read/write regions
+
+ reads: List[BufferRegion] = (
+ [buffer_slice_to_region(read) for read in block_info.reads]
+ if block_info.reads
+ else []
+ )
+ writes: List[BufferRegion] = (
+ [buffer_slice_to_region(write) for write in block_info.writes]
+ if block_info.writes
+ else []
+ )
+ inner = tvm.tir.Block(
+ block_iters,
+ reads,
+ writes,
+ name_hint,
+ self.body,
+ block_info.init,
+ block_info.alloc_buffers,
+ block_info.match_buffers,
+ block_info.annotations,
+ span,
+ )
+ # create block var iter binding
+ values: List[PrimExpr]
+ if not block_info.iter_bindings:
+ values = self.context.loop_stack[-2].copy()
+ if len(values) == 0:
+ values = [tvm.tir.const(float("nan"), dtype="float32")] *
len(block_iters)
+ elif len(values) != len(block_iters):
+ self.context.report_error(
+ "Autocomplete block iter var binding expects larger
number of loops",
+ self.node.span,
+ )
+ else:
+ for block_var in self.block_vars:
+ if block_var not in block_info.iter_bindings:
+ self.context.report_error(
+ "Missing block iter var binding for " +
block_var.name,
+ self.node.span,
+ )
+ values = [block_info.iter_bindings[block_var] for block_var in
self.block_vars]
+ predicate = (
+ tvm.tir.const(True, "bool")
+ if block_info.predicate is None
+ else block_info.predicate
+ )
+ body = tvm.tir.BlockRealize(values, predicate, inner, span)
+ return body
+
+ super().__init__(func=block, concise_scope=False, def_symbol=True)
+ self.block_vars = None
+
+ def enter_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
+ # define block vars
+ assert isinstance(node, ast.With)
+
+ var_names = WithScopeHandler.get_optional_var_names(node, context)
+ self.block_vars = [tvm.te.var(name) for name in var_names]
+ for block_var in self.block_vars:
+ context.update_symbol(block_var.name, block_var, node)
+
+
+@register
+class InitBlock(WithScopeHandler):
+ """ With scope handler tir.init()"""
+
+ def __init__(self):
+ def init(span: Span = None):
+ assert self.context
+ if self.context.block_info_stack[-2].init is not None:
+ self.context.report_error("Duplicate init block declaration",
span)
Review comment:
Can you add the previous init span to this error message?
##########
File path: python/tvm/script/scope_handler.py
##########
@@ -185,92 +226,365 @@ def let(var, value, span):
super().__init__(let, concise_scope=False, def_symbol=False)
+@register
+class Block(WithScopeHandler):
+ """ With scope handler tir.block(extents, name) as iter_vars"""
+
+ def __init__(self):
+ def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
+ assert self.node
+ assert self.context
+ assert self.body
+ block_info = self.context.block_info_stack[-1]
+ if axes is None:
+ axes = []
+ if len(axes) != len(self.block_vars):
+ self.context.report_error(
+ "Inconsistent number of block vars, "
+ + f"gets {len(axes)} axes but {len(self.block_vars)} block
vars.",
+ self.node.span,
+ )
+ block_iters: List[IterVar] = []
+ for i, axis in enumerate(axes):
+ axis = tvm.runtime.convert(axis)
+ if isinstance(axis, tvm.tir.PrimExpr):
+ block_var_dom = Range.from_min_extent(0, axis)
+ block_iters.append(IterVar(block_var_dom,
self.block_vars[i], 0))
+ elif isinstance(axis, Range):
+ block_iters.append(IterVar(axis, self.block_vars[i], 0))
+ elif isinstance(axis, IterVar):
+ block_iters.append(IterVar(axis.dom, self.block_vars[i],
axis.iter_type))
+ else:
+ self.context.report_error(
+ "Invalid argument of tir.block(), "
+ + f"expects PrimExpr, Range or IterVar, but gets
{type(axis)}",
+ self.node.span,
+ )
+
+ # create block read/write regions
+
+ reads: List[BufferRegion] = (
+ [buffer_slice_to_region(read) for read in block_info.reads]
+ if block_info.reads
+ else []
+ )
+ writes: List[BufferRegion] = (
+ [buffer_slice_to_region(write) for write in block_info.writes]
+ if block_info.writes
+ else []
+ )
+ inner = tvm.tir.Block(
+ block_iters,
+ reads,
+ writes,
+ name_hint,
+ self.body,
+ block_info.init,
+ block_info.alloc_buffers,
+ block_info.match_buffers,
+ block_info.annotations,
+ span,
+ )
+ # create block var iter binding
+ values: List[PrimExpr]
+ if not block_info.iter_bindings:
+ values = self.context.loop_stack[-2].copy()
+ if len(values) == 0:
+ values = [tvm.tir.const(float("nan"), dtype="float32")] *
len(block_iters)
+ elif len(values) != len(block_iters):
+ self.context.report_error(
+ "Autocomplete block iter var binding expects larger
number of loops",
+ self.node.span,
+ )
+ else:
+ for block_var in self.block_vars:
+ if block_var not in block_info.iter_bindings:
+ self.context.report_error(
+ "Missing block iter var binding for " +
block_var.name,
+ self.node.span,
+ )
+ values = [block_info.iter_bindings[block_var] for block_var in
self.block_vars]
+ predicate = (
+ tvm.tir.const(True, "bool")
+ if block_info.predicate is None
+ else block_info.predicate
+ )
+ body = tvm.tir.BlockRealize(values, predicate, inner, span)
+ return body
+
+ super().__init__(func=block, concise_scope=False, def_symbol=True)
+ self.block_vars = None
+
+ def enter_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
+ # define block vars
+ assert isinstance(node, ast.With)
+
+ var_names = WithScopeHandler.get_optional_var_names(node, context)
+ self.block_vars = [tvm.te.var(name) for name in var_names]
+ for block_var in self.block_vars:
+ context.update_symbol(block_var.name, block_var, node)
+
+
+@register
+class InitBlock(WithScopeHandler):
+ """ With scope handler tir.init()"""
+
+ def __init__(self):
+ def init(span: Span = None):
+ assert self.context
+ if self.context.block_info_stack[-2].init is not None:
+ self.context.report_error("Duplicate init block declaration",
span)
+ self.context.block_info_stack[-2].init = self.body
+
+ super().__init__(func=init, concise_scope=False, def_symbol=True)
+
+
class ForScopeHandler(ScopeHandler):
"""Base class for all for scope handlers"""
def __init__(self, func):
super().__init__(func)
- self.loop_vars = None
-
- def enter_scope(self, node, context, arg_list, span):
+ self.loop_vars: Optional[List[Var]] = None
+
+ def enter_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
assert isinstance(node, ast.For)
loop_var_names = list()
spans = list()
if isinstance(node.lhs, ast.Var):
loop_var_names.append(node.lhs.id.name)
- spans.append(from_synr_span(node.lhs.id.span))
- elif isinstance(node.lhs, ast.Tuple):
- for elt in node.lhs.values:
+ spans.append(tvm_span_from_synr(node.lhs.id.span))
+ elif isinstance(node.lhs, list):
+ for elt in node.lhs:
if not isinstance(elt, ast.Var):
context.report_error("Invalid loop var", elt.span)
Review comment:
Please improve this error message too: say what was expected (an `ast.Var`) and what was actually found.
##########
File path: python/tvm/script/scope_handler.py
##########
@@ -185,92 +226,365 @@ def let(var, value, span):
super().__init__(let, concise_scope=False, def_symbol=False)
+@register
+class Block(WithScopeHandler):
+ """ With scope handler tir.block(extents, name) as iter_vars"""
+
+ def __init__(self):
+ def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
+ assert self.node
+ assert self.context
+ assert self.body
+ block_info = self.context.block_info_stack[-1]
+ if axes is None:
+ axes = []
+ if len(axes) != len(self.block_vars):
+ self.context.report_error(
+ "Inconsistent number of block vars, "
+ + f"gets {len(axes)} axes but {len(self.block_vars)} block
vars.",
+ self.node.span,
+ )
+ block_iters: List[IterVar] = []
+ for i, axis in enumerate(axes):
+ axis = tvm.runtime.convert(axis)
+ if isinstance(axis, tvm.tir.PrimExpr):
+ block_var_dom = Range.from_min_extent(0, axis)
+ block_iters.append(IterVar(block_var_dom,
self.block_vars[i], 0))
+ elif isinstance(axis, Range):
+ block_iters.append(IterVar(axis, self.block_vars[i], 0))
+ elif isinstance(axis, IterVar):
+ block_iters.append(IterVar(axis.dom, self.block_vars[i],
axis.iter_type))
+ else:
+ self.context.report_error(
+ "Invalid argument of tir.block(), "
+ + f"expects PrimExpr, Range or IterVar, but gets
{type(axis)}",
+ self.node.span,
+ )
+
+ # create block read/write regions
+
+ reads: List[BufferRegion] = (
+ [buffer_slice_to_region(read) for read in block_info.reads]
+ if block_info.reads
+ else []
+ )
+ writes: List[BufferRegion] = (
+ [buffer_slice_to_region(write) for write in block_info.writes]
+ if block_info.writes
+ else []
+ )
+ inner = tvm.tir.Block(
+ block_iters,
+ reads,
+ writes,
+ name_hint,
+ self.body,
+ block_info.init,
+ block_info.alloc_buffers,
+ block_info.match_buffers,
+ block_info.annotations,
+ span,
+ )
+ # create block var iter binding
+ values: List[PrimExpr]
+ if not block_info.iter_bindings:
+ values = self.context.loop_stack[-2].copy()
+ if len(values) == 0:
+ values = [tvm.tir.const(float("nan"), dtype="float32")] *
len(block_iters)
+ elif len(values) != len(block_iters):
+ self.context.report_error(
+ "Autocomplete block iter var binding expects larger
number of loops",
+ self.node.span,
+ )
+ else:
+ for block_var in self.block_vars:
+ if block_var not in block_info.iter_bindings:
+ self.context.report_error(
+ "Missing block iter var binding for " +
block_var.name,
+ self.node.span,
+ )
+ values = [block_info.iter_bindings[block_var] for block_var in
self.block_vars]
+ predicate = (
+ tvm.tir.const(True, "bool")
+ if block_info.predicate is None
+ else block_info.predicate
+ )
+ body = tvm.tir.BlockRealize(values, predicate, inner, span)
+ return body
+
+ super().__init__(func=block, concise_scope=False, def_symbol=True)
+ self.block_vars = None
+
+ def enter_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
+ # define block vars
+ assert isinstance(node, ast.With)
+
+ var_names = WithScopeHandler.get_optional_var_names(node, context)
+ self.block_vars = [tvm.te.var(name) for name in var_names]
+ for block_var in self.block_vars:
+ context.update_symbol(block_var.name, block_var, node)
+
+
+@register
+class InitBlock(WithScopeHandler):
+ """ With scope handler tir.init()"""
+
+ def __init__(self):
+ def init(span: Span = None):
+ assert self.context
+ if self.context.block_info_stack[-2].init is not None:
+ self.context.report_error("Duplicate init block declaration",
span)
+ self.context.block_info_stack[-2].init = self.body
+
+ super().__init__(func=init, concise_scope=False, def_symbol=True)
+
+
class ForScopeHandler(ScopeHandler):
"""Base class for all for scope handlers"""
def __init__(self, func):
super().__init__(func)
- self.loop_vars = None
-
- def enter_scope(self, node, context, arg_list, span):
+ self.loop_vars: Optional[List[Var]] = None
+
+ def enter_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
assert isinstance(node, ast.For)
loop_var_names = list()
spans = list()
if isinstance(node.lhs, ast.Var):
loop_var_names.append(node.lhs.id.name)
- spans.append(from_synr_span(node.lhs.id.span))
- elif isinstance(node.lhs, ast.Tuple):
- for elt in node.lhs.values:
+ spans.append(tvm_span_from_synr(node.lhs.id.span))
+ elif isinstance(node.lhs, list):
+ for elt in node.lhs:
if not isinstance(elt, ast.Var):
context.report_error("Invalid loop var", elt.span)
loop_var_names.append(elt.id.name)
- spans.append(from_synr_span(elt.id.span))
+ spans.append(tvm_span_from_synr(elt.id.span))
else:
- context.report_error("Invalid loop var", node.lhs.span)
+ context.report_error("Invalid loop var in loop", span)
Review comment:
Could you improve this error message? Say what was expected and what was
actually found.
##########
File path: python/tvm/script/scope_handler.py
##########
@@ -185,92 +226,365 @@ def let(var, value, span):
super().__init__(let, concise_scope=False, def_symbol=False)
+@register
+class Block(WithScopeHandler):
+ """ With scope handler tir.block(extents, name) as iter_vars"""
+
+ def __init__(self):
+ def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
+ assert self.node
+ assert self.context
+ assert self.body
+ block_info = self.context.block_info_stack[-1]
+ if axes is None:
+ axes = []
+ if len(axes) != len(self.block_vars):
+ self.context.report_error(
+ "Inconsistent number of block vars, "
+ + f"gets {len(axes)} axes but {len(self.block_vars)} block
vars.",
+ self.node.span,
+ )
+ block_iters: List[IterVar] = []
+ for i, axis in enumerate(axes):
+ axis = tvm.runtime.convert(axis)
+ if isinstance(axis, tvm.tir.PrimExpr):
+ block_var_dom = Range.from_min_extent(0, axis)
+ block_iters.append(IterVar(block_var_dom,
self.block_vars[i], 0))
+ elif isinstance(axis, Range):
+ block_iters.append(IterVar(axis, self.block_vars[i], 0))
+ elif isinstance(axis, IterVar):
+ block_iters.append(IterVar(axis.dom, self.block_vars[i],
axis.iter_type))
+ else:
+ self.context.report_error(
+ "Invalid argument of tir.block(), "
+ + f"expects PrimExpr, Range or IterVar, but gets
{type(axis)}",
+ self.node.span,
+ )
+
+ # create block read/write regions
+
+ reads: List[BufferRegion] = (
+ [buffer_slice_to_region(read) for read in block_info.reads]
+ if block_info.reads
+ else []
+ )
+ writes: List[BufferRegion] = (
+ [buffer_slice_to_region(write) for write in block_info.writes]
+ if block_info.writes
+ else []
+ )
+ inner = tvm.tir.Block(
+ block_iters,
+ reads,
+ writes,
+ name_hint,
+ self.body,
+ block_info.init,
+ block_info.alloc_buffers,
+ block_info.match_buffers,
+ block_info.annotations,
+ span,
+ )
+ # create block var iter binding
+ values: List[PrimExpr]
+ if not block_info.iter_bindings:
+ values = self.context.loop_stack[-2].copy()
+ if len(values) == 0:
+ values = [tvm.tir.const(float("nan"), dtype="float32")] *
len(block_iters)
+ elif len(values) != len(block_iters):
+ self.context.report_error(
+ "Autocomplete block iter var binding expects larger
number of loops",
+ self.node.span,
+ )
+ else:
+ for block_var in self.block_vars:
+ if block_var not in block_info.iter_bindings:
+ self.context.report_error(
+ "Missing block iter var binding for " +
block_var.name,
+ self.node.span,
+ )
+ values = [block_info.iter_bindings[block_var] for block_var in
self.block_vars]
+ predicate = (
+ tvm.tir.const(True, "bool")
+ if block_info.predicate is None
+ else block_info.predicate
+ )
+ body = tvm.tir.BlockRealize(values, predicate, inner, span)
+ return body
+
+ super().__init__(func=block, concise_scope=False, def_symbol=True)
+ self.block_vars = None
+
+ def enter_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
+ # define block vars
+ assert isinstance(node, ast.With)
+
+ var_names = WithScopeHandler.get_optional_var_names(node, context)
+ self.block_vars = [tvm.te.var(name) for name in var_names]
+ for block_var in self.block_vars:
+ context.update_symbol(block_var.name, block_var, node)
+
+
+@register
+class InitBlock(WithScopeHandler):
+ """ With scope handler tir.init()"""
+
+ def __init__(self):
+ def init(span: Span = None):
+ assert self.context
+ if self.context.block_info_stack[-2].init is not None:
+ self.context.report_error("Duplicate init block declaration",
span)
+ self.context.block_info_stack[-2].init = self.body
+
+ super().__init__(func=init, concise_scope=False, def_symbol=True)
+
+
class ForScopeHandler(ScopeHandler):
"""Base class for all for scope handlers"""
def __init__(self, func):
super().__init__(func)
- self.loop_vars = None
-
- def enter_scope(self, node, context, arg_list, span):
+ self.loop_vars: Optional[List[Var]] = None
+
+ def enter_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
assert isinstance(node, ast.For)
loop_var_names = list()
spans = list()
if isinstance(node.lhs, ast.Var):
loop_var_names.append(node.lhs.id.name)
- spans.append(from_synr_span(node.lhs.id.span))
- elif isinstance(node.lhs, ast.Tuple):
- for elt in node.lhs.values:
+ spans.append(tvm_span_from_synr(node.lhs.id.span))
+ elif isinstance(node.lhs, list):
+ for elt in node.lhs:
if not isinstance(elt, ast.Var):
context.report_error("Invalid loop var", elt.span)
loop_var_names.append(elt.id.name)
- spans.append(from_synr_span(elt.id.span))
+ spans.append(tvm_span_from_synr(elt.id.span))
else:
- context.report_error("Invalid loop var", node.lhs.span)
+ context.report_error("Invalid loop var in loop", span)
self.loop_vars = [
tvm.te.var(name, dtype="int32", span=span) for name, span in
zip(loop_var_names, spans)
]
for loop_var in self.loop_vars:
- context.update_symbol(loop_var.name, loop_var)
+ context.update_symbol(loop_var.name, loop_var, node)
+ context.loop_stack[-1].append(loop_var)
+
+ def exit_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
+ assert self.loop_vars
+ for _ in self.loop_vars:
+ context.loop_stack[-1].pop()
+ return super().exit_scope(node, context, arg_list, span)
+
+ def create_loop(
+ self,
+ begin: PrimExpr,
+ end: PrimExpr,
+ kind: int,
+ thread_binding: Optional[str] = None,
+ annotations: Optional[Mapping[str, Object]] = None,
+ span: Optional[Span] = None,
+ ) -> tvm.tir.For:
+ """
+ Helper function for creating For in TVM Script parser.
+
+ Parameters
+ ----------
+ begin : PrimExpr
+ The beginning value.
+
+ end : PrimExpr
+ The endding value.
+
+ kind : ForKind
+ The type of the for.
+
+ thread_binding: Optional[str]
+ The thread this loop binds to.
+
+ annotations : Optional[Mapping[str, Object]]
+ Additional annotation hints.
+
+ span : Optional[Span]
+ The location of this for in the source code.
+
+ Returns
+ -------
+ for : For
+ The constructed For.
+ """
+ assert self.node
+ assert self.context
+ assert self.loop_vars
+ if len(self.loop_vars) != 1:
+ self.context.report_error(
+ f"Expect exact only one loop var, but get {self.loop_vars}",
self.node.span
+ )
+ extent = end if begin == 0 else self.context.analyzer.simplify(end -
begin)
+ annos: Mapping[str, Object]
+ if annotations is None:
+ annos = {}
+ else:
+ annos = {
+ key: tvm.tir.StringImm(val) if isinstance(val, str) else val
+ for key, val in annotations.items()
+ }
+ return tvm.tir.For(
+ self.loop_vars[0],
+ begin,
+ extent,
+ kind,
+ self.body,
+ thread_binding=thread_binding,
+ annotations=annos,
+ span=span,
+ )
@register
class Serial(ForScopeHandler):
- """ For scope handler tir.serial(begin, end)"""
+ """ For scope handler tir.serial(begin, end, annotations)"""
def __init__(self):
- def serial(begin, end, span):
- if len(self.loop_vars) != 1:
- self.context.report_error("Expect exact 1 loop var", span)
- ana = tvm.arith.Analyzer()
- extent = end if begin == 0 else ana.simplify(end - begin)
- return tvm.tir.For(self.loop_vars[0], begin, extent, 0, self.body,
span=span)
+ def serial(
+ begin: PrimExpr,
+ end: PrimExpr,
+ annotations: Optional[Mapping[str, Object]] = None,
+ span: Optional[Span] = None,
+ ):
+ return self.create_loop(begin, end, 0, annotations=annotations,
span=span)
super().__init__(serial)
@register
class Parallel(ForScopeHandler):
- """ For scope handler tir.parallel(begin, end)"""
+ """ For scope handler tir.parallel(begin, end, annotations)"""
def __init__(self):
- def parallel(begin, end, span):
- if len(self.loop_vars) != 1:
- self.context.report_error("Expect exact 1 loop var")
- ana = tvm.arith.Analyzer()
- extent = end if begin == 0 else ana.simplify(end - begin)
- return tvm.tir.For(self.loop_vars[0], begin, extent, 1, self.body,
span=span)
+ def parallel(
+ begin: PrimExpr,
+ end: PrimExpr,
+ annotations: Optional[Mapping[str, Object]] = None,
+ span: Optional[Span] = None,
+ ):
+ return self.create_loop(begin, end, 1, annotations=annotations,
span=span)
super().__init__(parallel)
@register
class Vectorized(ForScopeHandler):
- """ For scope handler tir.vectorized(begin, end)"""
+ """ For scope handler tir.vectorized(begin, end, annotations)"""
def __init__(self):
- def vectorized(begin, end, span):
- if len(self.loop_vars) != 1:
- self.context.report_error("Expect exact 1 loop var")
- ana = tvm.arith.Analyzer()
- extent = end if begin == 0 else ana.simplify(end - begin)
- return tvm.tir.For(self.loop_vars[0], begin, extent, 2, self.body,
span=span)
+ def vectorized(
+ begin: PrimExpr,
+ end: PrimExpr,
+ annotations: Optional[Mapping[str, Object]] = None,
+ span: Optional[Span] = None,
+ ):
+ return self.create_loop(begin, end, 2, annotations=annotations,
span=span)
super().__init__(vectorized)
@register
class Unroll(ForScopeHandler):
- """ For scope handler tir.unroll(begin, end)"""
+ """ For scope handler tir.unroll(begin, end, annotations)"""
def __init__(self):
- def unroll(begin, end, span):
- if len(self.loop_vars) != 1:
- self.context.report_error("Expect exact 1 loop var")
- ana = tvm.arith.Analyzer()
- extent = end if begin == 0 else ana.simplify(end - begin)
- return tvm.tir.For(self.loop_vars[0], begin, extent, 3, self.body,
span=span)
+ def unroll(
+ begin: PrimExpr,
+ end: PrimExpr,
+ annotations: Optional[Mapping[str, Object]] = None,
+ span: Optional[Span] = None,
+ ):
+ return self.create_loop(begin, end, 3, annotations=annotations,
span=span)
super().__init__(unroll)
+
+
+@register
+class ThreadBinding(ForScopeHandler):
+ """ For scope handler tir.thread_binding(begin, end, thread,
annotations)"""
+
+ def __init__(self):
+ def thread_binding(
+ begin: PrimExpr,
+ end: PrimExpr,
+ thread: str,
+ annotations: Optional[Mapping[str, Object]] = None,
+ span: Optional[Span] = None,
+ ):
+ thread_iter_var = IterVar(None, None, 1, thread, span=span)
+ return self.create_loop(
+ begin,
+ end,
+ 4,
+ thread_binding=thread_iter_var,
+ annotations=annotations,
+ span=span,
+ )
+
+ super().__init__(thread_binding)
+
+
+@register
+class RangeHandler(ForScopeHandler):
+ """For scope handler range(begin, end, annotations)
+ Note that tir.range is totally the same as tir.serial
+ """
+
+ def __init__(self):
+ def for_range(
+ begin: PrimExpr,
+ end: PrimExpr,
+ annotations: Optional[Mapping[str, Object]] = None,
+ span: Optional[Span] = None,
+ ):
+ return self.create_loop(begin, end, 0, annotations=annotations,
span=span)
+
+ super().__init__(for_range)
+
+ def signature(self):
+ return "range", get_param_list(self.func)
+
+
+@register
+class Grid(ForScopeHandler):
+ """ For scope handler tir.grid(extents)"""
+
+ def __init__(self):
+ def grid(*extents: List[PrimExpr], span: Span):
+ assert self.node
+ assert self.context
+ assert self.loop_vars
+ if len(self.loop_vars) != len(extents):
+ self.context.report_error(
+ "Inconsistent number of loop vars and extents",
self.node.span
Review comment:
Please add the actual numbers of loop vars and extents to this message.
##########
File path: python/tvm/script/scope_handler.py
##########
@@ -185,92 +226,365 @@ def let(var, value, span):
super().__init__(let, concise_scope=False, def_symbol=False)
+@register
+class Block(WithScopeHandler):
+ """ With scope handler tir.block(extents, name) as iter_vars"""
+
+ def __init__(self):
+ def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
+ assert self.node
+ assert self.context
+ assert self.body
+ block_info = self.context.block_info_stack[-1]
+ if axes is None:
+ axes = []
+ if len(axes) != len(self.block_vars):
+ self.context.report_error(
+ "Inconsistent number of block vars, "
+ + f"gets {len(axes)} axes but {len(self.block_vars)} block
vars.",
+ self.node.span,
+ )
+ block_iters: List[IterVar] = []
+ for i, axis in enumerate(axes):
+ axis = tvm.runtime.convert(axis)
+ if isinstance(axis, tvm.tir.PrimExpr):
+ block_var_dom = Range.from_min_extent(0, axis)
+ block_iters.append(IterVar(block_var_dom,
self.block_vars[i], 0))
+ elif isinstance(axis, Range):
+ block_iters.append(IterVar(axis, self.block_vars[i], 0))
+ elif isinstance(axis, IterVar):
+ block_iters.append(IterVar(axis.dom, self.block_vars[i],
axis.iter_type))
+ else:
+ self.context.report_error(
+ "Invalid argument of tir.block(), "
+ + f"expects PrimExpr, Range or IterVar, but gets
{type(axis)}",
+ self.node.span,
+ )
+
+ # create block read/write regions
+
+ reads: List[BufferRegion] = (
+ [buffer_slice_to_region(read) for read in block_info.reads]
+ if block_info.reads
+ else []
+ )
+ writes: List[BufferRegion] = (
+ [buffer_slice_to_region(write) for write in block_info.writes]
+ if block_info.writes
+ else []
+ )
+ inner = tvm.tir.Block(
+ block_iters,
+ reads,
+ writes,
+ name_hint,
+ self.body,
+ block_info.init,
+ block_info.alloc_buffers,
+ block_info.match_buffers,
+ block_info.annotations,
+ span,
+ )
+ # create block var iter binding
+ values: List[PrimExpr]
+ if not block_info.iter_bindings:
+ values = self.context.loop_stack[-2].copy()
+ if len(values) == 0:
+ values = [tvm.tir.const(float("nan"), dtype="float32")] *
len(block_iters)
+ elif len(values) != len(block_iters):
+ self.context.report_error(
+ "Autocomplete block iter var binding expects larger
number of loops",
+ self.node.span,
+ )
+ else:
+ for block_var in self.block_vars:
+ if block_var not in block_info.iter_bindings:
+ self.context.report_error(
+ "Missing block iter var binding for " +
block_var.name,
+ self.node.span,
+ )
+ values = [block_info.iter_bindings[block_var] for block_var in
self.block_vars]
+ predicate = (
+ tvm.tir.const(True, "bool")
+ if block_info.predicate is None
+ else block_info.predicate
+ )
+ body = tvm.tir.BlockRealize(values, predicate, inner, span)
+ return body
+
+ super().__init__(func=block, concise_scope=False, def_symbol=True)
+ self.block_vars = None
+
+ def enter_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
+ # define block vars
+ assert isinstance(node, ast.With)
+
+ var_names = WithScopeHandler.get_optional_var_names(node, context)
+ self.block_vars = [tvm.te.var(name) for name in var_names]
+ for block_var in self.block_vars:
+ context.update_symbol(block_var.name, block_var, node)
+
+
+@register
+class InitBlock(WithScopeHandler):
+ """ With scope handler tir.init()"""
+
+ def __init__(self):
+ def init(span: Span = None):
+ assert self.context
+ if self.context.block_info_stack[-2].init is not None:
+ self.context.report_error("Duplicate init block declaration",
span)
+ self.context.block_info_stack[-2].init = self.body
+
+ super().__init__(func=init, concise_scope=False, def_symbol=True)
+
+
class ForScopeHandler(ScopeHandler):
"""Base class for all for scope handlers"""
def __init__(self, func):
super().__init__(func)
- self.loop_vars = None
-
- def enter_scope(self, node, context, arg_list, span):
+ self.loop_vars: Optional[List[Var]] = None
+
+ def enter_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
assert isinstance(node, ast.For)
loop_var_names = list()
spans = list()
if isinstance(node.lhs, ast.Var):
loop_var_names.append(node.lhs.id.name)
- spans.append(from_synr_span(node.lhs.id.span))
- elif isinstance(node.lhs, ast.Tuple):
- for elt in node.lhs.values:
+ spans.append(tvm_span_from_synr(node.lhs.id.span))
+ elif isinstance(node.lhs, list):
+ for elt in node.lhs:
if not isinstance(elt, ast.Var):
context.report_error("Invalid loop var", elt.span)
loop_var_names.append(elt.id.name)
- spans.append(from_synr_span(elt.id.span))
+ spans.append(tvm_span_from_synr(elt.id.span))
else:
- context.report_error("Invalid loop var", node.lhs.span)
+ context.report_error("Invalid loop var in loop", span)
self.loop_vars = [
tvm.te.var(name, dtype="int32", span=span) for name, span in
zip(loop_var_names, spans)
]
for loop_var in self.loop_vars:
- context.update_symbol(loop_var.name, loop_var)
+ context.update_symbol(loop_var.name, loop_var, node)
+ context.loop_stack[-1].append(loop_var)
+
+ def exit_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
+ assert self.loop_vars
+ for _ in self.loop_vars:
+ context.loop_stack[-1].pop()
+ return super().exit_scope(node, context, arg_list, span)
+
+ def create_loop(
+ self,
+ begin: PrimExpr,
+ end: PrimExpr,
+ kind: int,
Review comment:
```suggestion
kind: ForKind,
```
##########
File path: python/tvm/script/special_stmt.py
##########
@@ -67,17 +99,20 @@ def match_buffer(
buffer_type="default",
span=None,
):
- assert isinstance(self.node, ast.Assign)
-
+ if not isinstance(self.node, ast.Assign):
+ self.context.report_error(
+ "Need assign the match_buffer to a buffer, e.g. A =
match_buffer(...)",
Review comment:
```suggestion
"match_buffer must be assigned to a buffer, e.g. A =
match_buffer(...)",
```
##########
File path: python/tvm/script/scope_handler.py
##########
@@ -185,92 +226,365 @@ def let(var, value, span):
super().__init__(let, concise_scope=False, def_symbol=False)
+@register
+class Block(WithScopeHandler):
+ """ With scope handler tir.block(extents, name) as iter_vars"""
+
+ def __init__(self):
+ def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
+ assert self.node
+ assert self.context
+ assert self.body
+ block_info = self.context.block_info_stack[-1]
+ if axes is None:
+ axes = []
+ if len(axes) != len(self.block_vars):
+ self.context.report_error(
+ "Inconsistent number of block vars, "
+ + f"gets {len(axes)} axes but {len(self.block_vars)} block
vars.",
+ self.node.span,
+ )
+ block_iters: List[IterVar] = []
+ for i, axis in enumerate(axes):
+ axis = tvm.runtime.convert(axis)
+ if isinstance(axis, tvm.tir.PrimExpr):
+ block_var_dom = Range.from_min_extent(0, axis)
+ block_iters.append(IterVar(block_var_dom,
self.block_vars[i], 0))
+ elif isinstance(axis, Range):
+ block_iters.append(IterVar(axis, self.block_vars[i], 0))
+ elif isinstance(axis, IterVar):
+ block_iters.append(IterVar(axis.dom, self.block_vars[i],
axis.iter_type))
+ else:
+ self.context.report_error(
+ "Invalid argument of tir.block(), "
+ + f"expects PrimExpr, Range or IterVar, but gets
{type(axis)}",
+ self.node.span,
+ )
+
+ # create block read/write regions
+
+ reads: List[BufferRegion] = (
+ [buffer_slice_to_region(read) for read in block_info.reads]
+ if block_info.reads
+ else []
+ )
+ writes: List[BufferRegion] = (
+ [buffer_slice_to_region(write) for write in block_info.writes]
+ if block_info.writes
+ else []
+ )
+ inner = tvm.tir.Block(
+ block_iters,
+ reads,
+ writes,
+ name_hint,
+ self.body,
+ block_info.init,
+ block_info.alloc_buffers,
+ block_info.match_buffers,
+ block_info.annotations,
+ span,
+ )
+ # create block var iter binding
+ values: List[PrimExpr]
+ if not block_info.iter_bindings:
+ values = self.context.loop_stack[-2].copy()
+ if len(values) == 0:
+ values = [tvm.tir.const(float("nan"), dtype="float32")] *
len(block_iters)
+ elif len(values) != len(block_iters):
+ self.context.report_error(
+ "Autocomplete block iter var binding expects larger
number of loops",
+ self.node.span,
+ )
+ else:
+ for block_var in self.block_vars:
+ if block_var not in block_info.iter_bindings:
+ self.context.report_error(
+ "Missing block iter var binding for " +
block_var.name,
+ self.node.span,
+ )
+ values = [block_info.iter_bindings[block_var] for block_var in
self.block_vars]
+ predicate = (
+ tvm.tir.const(True, "bool")
+ if block_info.predicate is None
+ else block_info.predicate
+ )
+ body = tvm.tir.BlockRealize(values, predicate, inner, span)
+ return body
+
+ super().__init__(func=block, concise_scope=False, def_symbol=True)
+ self.block_vars = None
+
+ def enter_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
+ # define block vars
+ assert isinstance(node, ast.With)
+
+ var_names = WithScopeHandler.get_optional_var_names(node, context)
+ self.block_vars = [tvm.te.var(name) for name in var_names]
+ for block_var in self.block_vars:
+ context.update_symbol(block_var.name, block_var, node)
+
+
+@register
+class InitBlock(WithScopeHandler):
+ """ With scope handler tir.init()"""
+
+ def __init__(self):
+ def init(span: Span = None):
+ assert self.context
+ if self.context.block_info_stack[-2].init is not None:
+ self.context.report_error("Duplicate init block declaration",
span)
+ self.context.block_info_stack[-2].init = self.body
+
+ super().__init__(func=init, concise_scope=False, def_symbol=True)
+
+
class ForScopeHandler(ScopeHandler):
"""Base class for all for scope handlers"""
def __init__(self, func):
super().__init__(func)
- self.loop_vars = None
-
- def enter_scope(self, node, context, arg_list, span):
+ self.loop_vars: Optional[List[Var]] = None
+
+ def enter_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
assert isinstance(node, ast.For)
loop_var_names = list()
spans = list()
if isinstance(node.lhs, ast.Var):
loop_var_names.append(node.lhs.id.name)
- spans.append(from_synr_span(node.lhs.id.span))
- elif isinstance(node.lhs, ast.Tuple):
- for elt in node.lhs.values:
+ spans.append(tvm_span_from_synr(node.lhs.id.span))
+ elif isinstance(node.lhs, list):
+ for elt in node.lhs:
if not isinstance(elt, ast.Var):
context.report_error("Invalid loop var", elt.span)
loop_var_names.append(elt.id.name)
- spans.append(from_synr_span(elt.id.span))
+ spans.append(tvm_span_from_synr(elt.id.span))
else:
- context.report_error("Invalid loop var", node.lhs.span)
+ context.report_error("Invalid loop var in loop", span)
self.loop_vars = [
tvm.te.var(name, dtype="int32", span=span) for name, span in
zip(loop_var_names, spans)
]
for loop_var in self.loop_vars:
- context.update_symbol(loop_var.name, loop_var)
+ context.update_symbol(loop_var.name, loop_var, node)
+ context.loop_stack[-1].append(loop_var)
+
+ def exit_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
+ assert self.loop_vars
+ for _ in self.loop_vars:
+ context.loop_stack[-1].pop()
+ return super().exit_scope(node, context, arg_list, span)
+
+ def create_loop(
+ self,
+ begin: PrimExpr,
+ end: PrimExpr,
+ kind: int,
+ thread_binding: Optional[str] = None,
+ annotations: Optional[Mapping[str, Object]] = None,
+ span: Optional[Span] = None,
+ ) -> tvm.tir.For:
+ """
+ Helper function for creating For in TVM Script parser.
+
+ Parameters
+ ----------
+ begin : PrimExpr
+ The beginning value.
+
+ end : PrimExpr
+ The endding value.
+
+ kind : ForKind
+ The type of the for.
+
+ thread_binding: Optional[str]
+ The thread this loop binds to.
+
+ annotations : Optional[Mapping[str, Object]]
+ Additional annotation hints.
+
+ span : Optional[Span]
+ The location of this for in the source code.
+
+ Returns
+ -------
+ for : For
+ The constructed For.
+ """
+ assert self.node
+ assert self.context
+ assert self.loop_vars
+ if len(self.loop_vars) != 1:
+ self.context.report_error(
+ f"Expect exact only one loop var, but get {self.loop_vars}",
self.node.span
+ )
+ extent = end if begin == 0 else self.context.analyzer.simplify(end -
begin)
+ annos: Mapping[str, Object]
+ if annotations is None:
+ annos = {}
+ else:
+ annos = {
+ key: tvm.tir.StringImm(val) if isinstance(val, str) else val
+ for key, val in annotations.items()
+ }
+ return tvm.tir.For(
+ self.loop_vars[0],
+ begin,
+ extent,
+ kind,
+ self.body,
+ thread_binding=thread_binding,
+ annotations=annos,
+ span=span,
+ )
@register
class Serial(ForScopeHandler):
- """ For scope handler tir.serial(begin, end)"""
+ """ For scope handler tir.serial(begin, end, annotations)"""
def __init__(self):
- def serial(begin, end, span):
- if len(self.loop_vars) != 1:
- self.context.report_error("Expect exact 1 loop var", span)
- ana = tvm.arith.Analyzer()
- extent = end if begin == 0 else ana.simplify(end - begin)
- return tvm.tir.For(self.loop_vars[0], begin, extent, 0, self.body,
span=span)
+ def serial(
+ begin: PrimExpr,
+ end: PrimExpr,
+ annotations: Optional[Mapping[str, Object]] = None,
+ span: Optional[Span] = None,
+ ):
+ return self.create_loop(begin, end, 0, annotations=annotations,
span=span)
Review comment:
Can you use `ForKind.SERIAL` instead of the integer value? (There are a
couple more cases below.)
##########
File path: python/tvm/script/special_stmt.py
##########
@@ -121,13 +156,17 @@ def buffer_decl(
buffer_type="default",
span=None,
):
- assert isinstance(self.node, ast.Assign)
+ if not isinstance(self.node, ast.Assign):
+ self.context.report_error(
+ "Need assign the buffer_decl to a buffer, e.g. A =
buffer_decl(...)",
Review comment:
```suggestion
"buffer_decl must be assigned to a buffer, e.g. A =
buffer_decl(...)",
```
##########
File path: python/tvm/script/special_stmt.py
##########
@@ -142,12 +181,289 @@ def buffer_decl(
buffer_type,
span=span,
)
- self.context.update_symbol(self.node.lhs.id.name, buffer)
+ self.context.update_symbol(self.node.lhs.id.name, buffer,
self.node)
return buffer
super().__init__(buffer_decl, def_symbol=True)
+@register
+class AllocBuffer(SpecialStmt):
+ """Special function alloc_buffer(shape, dtype, data, strides, elem_offset,
scope, align,
+ offset_factor, buffer_type)
+
+ Example
+ -------
+ .. code-block:: python
+
+ A = tir.alloc_buffer((128, 128), dtype="float32")
+
+ """
+
+ def __init__(self):
+ def alloc_buffer(
+ shape,
+ dtype="float32",
+ data=None,
+ strides=None,
+ elem_offset=None,
+ scope="",
+ align=-1,
+ offset_factor=0,
+ buffer_type="default",
+ span=None,
+ ):
+ if not isinstance(self.node, ast.Assign):
+ self.context.report_error(
+ "Need assign the alloc_buffer to a buffer, e.g. A =
alloc_buffer(...)",
+ self.node.span,
+ )
+
+ if strides is None:
+ strides = []
+ align = convert_to_int(align, "align", self.context.report_error,
self.node.span)
+ offset_factor = convert_to_int(
+ offset_factor, "offset_factor", self.context.report_error,
self.node.span
+ )
+ buffer = tvm.tir.decl_buffer(
+ shape,
+ dtype,
+ self.node.lhs.id.name,
+ data,
+ strides,
+ elem_offset,
+ scope,
+ align,
+ offset_factor,
+ buffer_type,
+ span=span,
+ )
+ self.context.current_block_scope().alloc_buffers.append(buffer)
+ self.context.update_symbol(self.node.lhs.id.name, buffer,
self.node)
+
+ super().__init__(alloc_buffer, def_symbol=True)
+
+
+@register
+class BlockVarBind(SpecialStmt):
+ """Special function bind(block_iter, binding_value)
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.bind(vx, i)
+
+ """
+
+ def __init__(self):
+ def bind(iter_var, values, span=None):
+ block_scope = self.context.current_block_scope()
+ if iter_var in block_scope.iter_bindings:
+ self.context.report_error("Duplicate iter_var bindings of " +
str(iter_var), span)
+ block_scope.iter_bindings[iter_var] = values
+
+ super().__init__(bind, def_symbol=False)
+
+
+@register
+class BlockReads(SpecialStmt):
+ """Special function reads([read_buffer_regions])
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]])
+
+ """
+
+ def __init__(self):
+ def reads(read_regions: Union[BufferSlice, List[BufferSlice]], span:
Span = None):
+ assert self.context
+ block_scope = self.context.current_block_scope()
+ if block_scope.reads is not None:
+ self.context.report_error(
+ "Duplicate write region declaration, "
+ + "previous one is "
+ + str(", ".join(str(x) for x in block_scope.reads)),
+ span,
+ )
+ if isinstance(read_regions, list):
+ pass
+ elif isinstance(read_regions, BufferSlice):
+ read_regions = [read_regions]
+ else:
+ self.context.report_error(
+ "Error input type. "
+ + f"Expects BufferSlice or List[BufferSlice], but gets
{type(read_regions)}",
+ span,
+ )
+ block_scope.reads = read_regions
+
+ super().__init__(reads, def_symbol=False)
+
+
+@register
+class BlockWrites(SpecialStmt):
+ """Special function writes([write_buffer_regions])
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.writes([C[vi: vi + 4, vj])
+
+ """
+
+ def __init__(self):
+ def writes(write_region: Union[BufferSlice, List[BufferSlice]], span:
Span = None):
+ assert self.context
+ block_scope = self.context.current_block_scope()
+ if block_scope.writes is not None:
+ self.context.report_error(
+ "Duplicate write region declaration, "
+ + "previous one is "
+ + str(", ".join(str(x) for x in block_scope.writes)),
+ span,
+ )
+ if isinstance(write_region, list):
+ pass
+ elif isinstance(write_region, BufferSlice):
+ write_region = [write_region]
+ else:
+ self.context.report_error(
+ "Error input type. "
+ + f"Expects BufferSlice or List[BufferSlice], but gets
{type(write_region)}",
+ span,
+ )
+ block_scope.writes = write_region
+
+ super().__init__(writes, def_symbol=False)
+
+
+@register
+class BlockAttr(SpecialStmt):
+ """Special function block_attr({attr_key: attr_value})
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.block_attr({"double_buffer_scope": 1})
+
+ """
+
+ def __init__(self):
+ def block_attr(attrs: Mapping[str, Object], span: Span = None):
+ assert self.context
+ block_scope = self.context.current_block_scope()
+ if block_scope.annotations is not None:
+ self.context.report_error(
+ "Duplicate block annotations declaration, "
+ + "previous one is "
+ + str(block_scope.annotations),
+ span,
+ )
+ attrs = {
+ key: tvm.tir.StringImm(val) if isinstance(val, str) else val
+ for key, val in attrs.items()
+ }
+ block_scope.annotations = attrs
+
+ super().__init__(block_attr, def_symbol=False)
+
+
+@register
+class BlockPredicate(SpecialStmt):
+ """Special function where(predicate)
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.where(i < 4)
+
+ """
+
+ def __init__(self):
+ def where(predicate, span=None):
+ block_scope = self.context.current_block_scope()
+ if block_scope.predicate is not None:
+ self.context.report_error(
+ "Duplicate block predicate declaration, "
+ + "previous one is "
+ + str(block_scope.predicate),
+ span,
+ )
+
+ block_scope.predicate = predicate
+
+ super().__init__(where, def_symbol=False)
+
+
+@register
+class BlockMatchBufferRegion(SpecialStmt):
+ """Special function match_buffer_region(source, strides, elem_offset,
align, offset_factor)
+
+ Example
+ -------
+ .. code-block:: python
+
+ B = tir.match_buffer_region(A[0: 4])
+
+ """
+
+ def __init__(self):
+ def match_buffer_region(
+ source,
+ strides=None,
+ elem_offset=None,
+ align=-1,
+ offset_factor=0,
+ span=None,
+ ):
+ if not isinstance(self.node, ast.Assign):
+ self.context.report_error(
+ "Need assign the match_buffer_region to a buffer, "
Review comment:
```suggestion
"match_buffer_region must be assigned to a buffer, "
```
##########
File path: python/tvm/script/utils.py
##########
@@ -57,3 +107,32 @@ def from_synr_span(span):
span.start_column,
span.end_column,
)
+
+
+def synr_span_from_tvm(span: Span) -> synr.ast.Span:
+ """Convert a TVM span to a synr span"""
+ return synr.ast.Span(
+ span.source_name.name,
+ span.line,
+ span.column,
+ span.end_line,
+ span.end_column,
+ )
+
+
+def call_with_error_reporting(
+ report_error,
+ node_span,
+ func,
+ *args,
+ **kwargs,
+):
+ """Call function with exception handling and report error using
node_span"""
+ try:
+ return func(*args, **kwargs)
+ except DiagnosticError as err:
+ raise err
Review comment:
```suggestion
except DiagnosticError:
raise
```
This gives a cleaner stack trace.
##########
File path: python/tvm/script/special_stmt.py
##########
@@ -142,12 +181,289 @@ def buffer_decl(
buffer_type,
span=span,
)
- self.context.update_symbol(self.node.lhs.id.name, buffer)
+ self.context.update_symbol(self.node.lhs.id.name, buffer,
self.node)
return buffer
super().__init__(buffer_decl, def_symbol=True)
+@register
+class AllocBuffer(SpecialStmt):
+ """Special function alloc_buffer(shape, dtype, data, strides, elem_offset,
scope, align,
+ offset_factor, buffer_type)
+
+ Example
+ -------
+ .. code-block:: python
+
+ A = tir.alloc_buffer((128, 128), dtype="float32")
+
+ """
+
+ def __init__(self):
+ def alloc_buffer(
+ shape,
+ dtype="float32",
+ data=None,
+ strides=None,
+ elem_offset=None,
+ scope="",
+ align=-1,
+ offset_factor=0,
+ buffer_type="default",
+ span=None,
+ ):
+ if not isinstance(self.node, ast.Assign):
+ self.context.report_error(
+ "Need assign the alloc_buffer to a buffer, e.g. A =
alloc_buffer(...)",
Review comment:
```suggestion
"alloc_buffer must be assigned to a buffer, e.g. A =
alloc_buffer(...)",
```
##########
File path: src/tir/analysis/block_access_region_detector.cc
##########
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/analysis/block_region_detector.cc
+ * \brief Detect block read/write regions by visiting its body
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Auto detect the block read write region
+ * It will detect the read/write region as an array in order of
appearance in AST
+ * \note This detector only accepts to visit a block and will not visit child
blocks recursively
Review comment:
```suggestion
* \note This detector can only visit blocks and will not visit child blocks
recursively
```
##########
File path: python/tvm/script/scope_handler.py
##########
@@ -185,92 +226,365 @@ def let(var, value, span):
super().__init__(let, concise_scope=False, def_symbol=False)
+@register
+class Block(WithScopeHandler):
+ """ With scope handler tir.block(extents, name) as iter_vars"""
Review comment:
I think we could really use more documentation on the usage and
parameters of `tir.block`. I'm not sure where it should go though.
##########
File path: tests/python/unittest/test_tir_analysis_get_block_access_region.py
##########
@@ -0,0 +1,57 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import tir, script
+from tvm.ir import Range
+
+
[email protected]
+def func() -> None:
+ A = tir.alloc_buffer((128, 128), "float32")
+ B = tir.alloc_buffer((128, 128), "float32")
+ C = tir.alloc_buffer((128, 128), "float32")
+ D = tir.alloc_buffer((128, 128), "float32")
+ with tir.block([]):
Review comment:
Is the empty array (`[]`) required? I think it would be cleaner to
default to an empty array, i.e. `tir.block()`.
##########
File path: src/tir/analysis/block_access_region_detector.cc
##########
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/analysis/block_region_detector.cc
+ * \brief Detect block read/write regions by visiting its body
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Auto detect the block read write region
+ * It will detect the read/write region as an array in order of
appearance in AST
+ * \note This detector only accepts to visit a block and will not visit child
blocks recursively
+ */
+class BlockReadWriteDetector : public StmtExprVisitor {
+ public:
+ explicit BlockReadWriteDetector(const Map<Var, Buffer>& buffer_var_map)
+ : buffer_var_map_(buffer_var_map) {}
+
+ /*! \brief Return read regions of the block */
+ Array<BufferRegion> CollectReads();
+ /*! \brief Return write regions of the block */
+ Array<BufferRegion> CollectWrites();
+ /*!
+ * \brief Return opaque buffer regions of the block
+ * \note The buffer accessed by load/store or call with buffer.data will
+ * be marked as opaque.
+ */
+ Array<BufferRegion> CollectOpaques();
+ /*! \brief overload operator() to make sure it accepts a block node */
+ void operator()(const Stmt& stmt);
+
+ private:
+ /*! \brief Iteration range for loop_vars */
+ std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
+ /*! \brief The buffers that the current block reads */
+ std::vector<Buffer> read_buffers_;
+ /*! \brief The buffers that the current block writes */
+ std::vector<Buffer> writes_buffers_;
+ /*! \brief The opaque buffer which is access by buffer.data */
+ std::vector<Buffer> opaque_buffers_;
+ /*! \brief The read regions of the current block */
+ std::vector<std::vector<tvm::arith::IntSet>> read_regions_;
+ /*! \brief The write regions of the current block */
+ std::vector<std::vector<tvm::arith::IntSet>> write_regions_;
+ /*! \brief The outside buffer data mapping to its buffer */
+ Map<Var, Buffer> buffer_var_map_;
+ /*! \brief The analyzer for simplifying*/
+ arith::Analyzer analyzer_;
+
+ /*!
+ * \brief Update read/write buffers and regions with provided buffer and
region
+ * \param buffers The buffers should be updated
+ * \param regions The access regions should be updated
+ * \param buffer The provided buffer
+ * \param region The provided region
+ */
+ void Update(std::vector<Buffer>* buffers,
std::vector<std::vector<arith::IntSet>>* regions,
+ const Buffer& buffer, const std::vector<arith::IntSet>& region);
+
+ /*! \brief Helper function to collect access regions. */
+ Array<BufferRegion> CollectRegions(const std::vector<Buffer>& buffers,
+ const
std::vector<std::vector<tvm::arith::IntSet>>& regions);
+
+ /*! \brief Helper function to add a opaque buffer. */
+ void AddOpaque(const Var& buffer_var);
+
+ void VisitStmt_(const ForNode* op) override;
+ void VisitStmt_(const BlockRealizeNode* op) override;
+ void VisitStmt_(const BufferStoreNode* op) override;
+ void VisitStmt_(const StoreNode* op) override;
+ void VisitExpr_(const BufferLoadNode* op) override;
+ void VisitExpr_(const LoadNode* op) override;
+ void VisitExpr_(const VarNode* op) override;
+};
+
+void BlockReadWriteDetector::operator()(const Stmt& stmt) {
+ ICHECK(stmt.as<BlockNode>() != nullptr) << "Only allow to visit a block";
+ StmtExprVisitor::operator()(stmt);
+}
+
+Array<BufferRegion> BlockReadWriteDetector::CollectReads() {
+ return CollectRegions(read_buffers_, read_regions_);
+}
+
+Array<BufferRegion> BlockReadWriteDetector::CollectWrites() {
+ return CollectRegions(writes_buffers_, write_regions_);
+}
+
+Array<BufferRegion> BlockReadWriteDetector::CollectOpaques() {
+ Array<BufferRegion> res;
+ res.reserve(opaque_buffers_.size());
+ for (const Buffer& buffer : opaque_buffers_) {
+ res.push_back(BufferRegion::FullRegion(buffer));
+ }
+ return res;
+}
+
+void BlockReadWriteDetector::VisitExpr_(const VarNode* op) {
AddOpaque(GetRef<Var>(op)); }
+
+void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) {
+ AddOpaque(op->buffer_var);
+ ExprVisitor::VisitExpr_(op);
+}
+
+void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) {
+ std::vector<arith::IntSet> relaxed_region;
+ for (const PrimExpr& index : op->indices) {
+ relaxed_region.push_back(arith::EvalSet(index, dom_map_));
+ }
+ Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region);
+ ExprVisitor::VisitExpr_(op);
+}
+
+void BlockReadWriteDetector::VisitStmt_(const ForNode* op) {
+ Range range = Range::FromMinExtent(op->min, op->extent);
+ dom_map_[op->loop_var.get()] = arith::IntSet::FromRange(range);
+ StmtVisitor::VisitStmt_(op);
+ dom_map_.erase(op->loop_var.get());
+}
+
+void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) {
+ AddOpaque(op->buffer_var);
+ StmtVisitor::VisitStmt_(op);
+}
+
+void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) {
+ std::vector<arith::IntSet> relaxed_region;
+ for (const PrimExpr& index : op->indices) {
+ relaxed_region.push_back(arith::EvalSet(index, dom_map_));
+ }
+ Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region);
+ StmtVisitor::VisitStmt_(op);
+}
+
+void BlockReadWriteDetector::VisitStmt_(const BlockRealizeNode* op) {
+ /*! \note detector will not visit child block recursively, so that it will
stop here */
+ std::unordered_map<const VarNode*, PrimExpr> vmap;
+ for (size_t i = 0; i < op->block->iter_vars.size(); ++i) {
+ vmap[op->block->iter_vars[i]->var.get()] = op->iter_values[i];
+ }
+ for (const auto& read : op->block->reads) {
+ std::vector<arith::IntSet> relaxed_region;
+ for (const auto& range : read->region) {
+ relaxed_region.push_back(
+ arith::EvalSet(arith::IntSet::FromRange(Range::FromMinExtent(
+ Substitute(range->min, vmap),
Substitute(range->extent, vmap))),
+ dom_map_));
+ }
+ Update(&read_buffers_, &read_regions_, read->buffer, relaxed_region);
+ }
+ for (const auto& write : op->block->writes) {
+ std::vector<arith::IntSet> relaxed_region;
+ for (const auto& range : write->region) {
+ relaxed_region.push_back(
+ arith::EvalSet(arith::IntSet::FromRange(Range::FromMinExtent(
+ Substitute(range->min, vmap),
Substitute(range->extent, vmap))),
+ dom_map_));
+ }
+ Update(&writes_buffers_, &write_regions_, write->buffer, relaxed_region);
+ }
+}
+
+void BlockReadWriteDetector::Update(std::vector<Buffer>* buffers,
+ std::vector<std::vector<arith::IntSet>>*
regions,
+ const Buffer& buffer,
+ const std::vector<arith::IntSet>& region) {
+ if (buffer_var_map_.find(buffer->data) == buffer_var_map_.end()) return;
+ ICHECK_EQ(buffers->size(), regions->size())
+ << " Expect the buffer and regions to have the same size ";
Review comment:
```suggestion
<< " Expected the buffer and regions to have the same size ";
```
##########
File path: src/printer/tvmscript_printer.cc
##########
@@ -713,6 +782,88 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode*
op) {
return doc;
}
+Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) {
+ const auto* block_op = op->block.as<BlockNode>();
+ // print block name and block vars
+ Doc doc;
+ doc << "with tir.block([";
+ std::vector<Doc> block_var_docs;
+ for (const auto& iter_var : block_op->iter_vars) {
+ Doc block_var_doc;
+ if (is_zero(iter_var->dom->min) && iter_var->iter_type == kDataPar) {
+ block_var_doc << Print(iter_var->dom->extent);
+ } else {
+ block_var_doc << "tir.";
+ switch (iter_var->iter_type) {
+ case kDataPar:
+ block_var_doc << "range";
+ break;
+ case kCommReduce:
+ block_var_doc << "reduce_axis";
+ break;
+ case kOrdered:
+ block_var_doc << "scan_axis";
+ break;
+ case kOpaque:
+ block_var_doc << "opaque_axis";
+ break;
+ default:
+ LOG(FATAL) << "Unknown block var iter type";
Review comment:
Can you print `iter_var->iter_type` here?
##########
File path: src/tir/analysis/block_access_region_detector.cc
##########
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/analysis/block_region_detector.cc
+ * \brief Detect block read/write regions by visiting its body
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Auto detect the block read write region
+ * It will detect the read/write region as an array in order of
appearance in AST
+ * \note This detector only accepts to visit a block and will not visit child
blocks recursively
+ */
+class BlockReadWriteDetector : public StmtExprVisitor {
+ public:
+ explicit BlockReadWriteDetector(const Map<Var, Buffer>& buffer_var_map)
+ : buffer_var_map_(buffer_var_map) {}
+
+ /*! \brief Return read regions of the block */
+ Array<BufferRegion> CollectReads();
+ /*! \brief Return write regions of the block */
+ Array<BufferRegion> CollectWrites();
+ /*!
+ * \brief Return opaque buffer regions of the block
+ * \note The buffer accessed by load/store or call with buffer.data will
+ * be marked as opaque.
+ */
+ Array<BufferRegion> CollectOpaques();
+ /*! \brief overload operator() to make sure it accepts a block node */
+ void operator()(const Stmt& stmt);
+
+ private:
+ /*! \brief Iteration range for loop_vars */
+ std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
+ /*! \brief The buffers that the current block reads */
+ std::vector<Buffer> read_buffers_;
+ /*! \brief The buffers that the current block writes */
+ std::vector<Buffer> writes_buffers_;
+ /*! \brief The opaque buffer which is access by buffer.data */
+ std::vector<Buffer> opaque_buffers_;
+ /*! \brief The read regions of the current block */
+ std::vector<std::vector<tvm::arith::IntSet>> read_regions_;
+ /*! \brief The write regions of the current block */
+ std::vector<std::vector<tvm::arith::IntSet>> write_regions_;
+ /*! \brief The outside buffer data mapping to its buffer */
+ Map<Var, Buffer> buffer_var_map_;
+ /*! \brief The analyzer for simplifying*/
+ arith::Analyzer analyzer_;
+
+ /*!
+ * \brief Update read/write buffers and regions with provided buffer and
region
+ * \param buffers The buffers should be updated
+ * \param regions The access regions should be updated
+ * \param buffer The provided buffer
+ * \param region The provided region
+ */
+ void Update(std::vector<Buffer>* buffers,
std::vector<std::vector<arith::IntSet>>* regions,
+ const Buffer& buffer, const std::vector<arith::IntSet>& region);
+
+ /*! \brief Helper function to collect access regions. */
+ Array<BufferRegion> CollectRegions(const std::vector<Buffer>& buffers,
+ const
std::vector<std::vector<tvm::arith::IntSet>>& regions);
+
+ /*! \brief Helper function to add a opaque buffer. */
+ void AddOpaque(const Var& buffer_var);
+
+ void VisitStmt_(const ForNode* op) override;
+ void VisitStmt_(const BlockRealizeNode* op) override;
+ void VisitStmt_(const BufferStoreNode* op) override;
+ void VisitStmt_(const StoreNode* op) override;
+ void VisitExpr_(const BufferLoadNode* op) override;
+ void VisitExpr_(const LoadNode* op) override;
+ void VisitExpr_(const VarNode* op) override;
+};
+
+void BlockReadWriteDetector::operator()(const Stmt& stmt) {
+ ICHECK(stmt.as<BlockNode>() != nullptr) << "Only allow to visit a block";
+ StmtExprVisitor::operator()(stmt);
+}
+
+Array<BufferRegion> BlockReadWriteDetector::CollectReads() {
+ return CollectRegions(read_buffers_, read_regions_);
+}
+
+Array<BufferRegion> BlockReadWriteDetector::CollectWrites() {
+ return CollectRegions(writes_buffers_, write_regions_);
+}
+
+Array<BufferRegion> BlockReadWriteDetector::CollectOpaques() {
+ Array<BufferRegion> res;
+ res.reserve(opaque_buffers_.size());
+ for (const Buffer& buffer : opaque_buffers_) {
+ res.push_back(BufferRegion::FullRegion(buffer));
+ }
+ return res;
+}
+
+void BlockReadWriteDetector::VisitExpr_(const VarNode* op) {
AddOpaque(GetRef<Var>(op)); }
+
+void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) {
+ AddOpaque(op->buffer_var);
+ ExprVisitor::VisitExpr_(op);
+}
+
+void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) {
+ std::vector<arith::IntSet> relaxed_region;
+ for (const PrimExpr& index : op->indices) {
+ relaxed_region.push_back(arith::EvalSet(index, dom_map_));
+ }
+ Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region);
+ ExprVisitor::VisitExpr_(op);
+}
+
+void BlockReadWriteDetector::VisitStmt_(const ForNode* op) {
+ Range range = Range::FromMinExtent(op->min, op->extent);
+ dom_map_[op->loop_var.get()] = arith::IntSet::FromRange(range);
+ StmtVisitor::VisitStmt_(op);
+ dom_map_.erase(op->loop_var.get());
+}
+
+void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) {
+ AddOpaque(op->buffer_var);
+ StmtVisitor::VisitStmt_(op);
+}
+
+void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) {
+ std::vector<arith::IntSet> relaxed_region;
+ for (const PrimExpr& index : op->indices) {
+ relaxed_region.push_back(arith::EvalSet(index, dom_map_));
+ }
+ Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region);
+ StmtVisitor::VisitStmt_(op);
+}
+
+void BlockReadWriteDetector::VisitStmt_(const BlockRealizeNode* op) {
+ /*! \note detector will not visit child block recursively, so that it will
stop here */
Review comment:
```suggestion
/*! \note detector will not visit child block recursively, so it will stop
here */
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]