comaniac commented on a change in pull request #7630:
URL: https://github.com/apache/tvm/pull/7630#discussion_r594049237
##########
File path: python/tvm/script/scope_handler.py
##########
@@ -16,32 +16,59 @@
# under the License.
"""TVM Script Parser Scope Handler Classes"""
# pylint: disable=redefined-builtin, unused-argument, invalid-name,
relative-beyond-top-level
+from typing import Tuple, Any, Callable, Optional, List, Union, Mapping
+import synr
from synr import ast
import tvm.tir
-from .utils import get_param_list, from_synr_span
+from tvm.runtime import Object
+from tvm.ir import Span, Range
+from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind
+
+from .context_maintainer import ContextMaintainer
+from .utils import (
+ get_param_list,
+ tvm_span_from_synr,
+ buffer_slice_to_region,
+ call_with_error_reporting,
+)
from .registry import register
+from .node import BufferSlice
class ScopeHandler:
"""Base class for all scope handlers"""
- def __init__(self, func):
- self.func = func
- self.body = None
- self.node = None
- self.context = None
+ def __init__(self, func: Callable):
Review comment:
I found that some classes added/changed by this PR have docstrings but
others don't. It would be good to add the missing ones for the classes
changed by this PR.
##########
File path: tests/python/unittest/test_tvmscript_error_report.py
##########
@@ -144,6 +144,197 @@ def test_no_body():
check_error(no_body, 3)
+def allocate_with_buffers() -> None:
Review comment:
nit: why do some tests have return annotations while others don't?
IMHO, it's fine to omit the `-> None` return annotation.
##########
File path: python/tvm/script/special_stmt.py
##########
@@ -142,21 +181,300 @@ def buffer_decl(
buffer_type,
span=span,
)
- self.context.update_symbol(self.node.lhs.id.name, buffer)
+ self.context.update_symbol(self.node.lhs.id.name, buffer,
self.node)
return buffer
super().__init__(buffer_decl, def_symbol=True)
+@register
+class AllocBuffer(SpecialStmt):
+ """Special function alloc_buffer(shape, dtype, data, strides, elem_offset,
scope, align,
+ offset_factor, buffer_type)
+
+ Example
+ -------
+ .. code-block:: python
+
+ A = tir.alloc_buffer((128, 128), dtype="float32")
+
+ """
+
+ def __init__(self):
+ def alloc_buffer(
+ shape,
+ dtype="float32",
+ data=None,
+ strides=None,
+ elem_offset=None,
+ scope="",
+ align=-1,
+ offset_factor=0,
+ buffer_type="default",
+ span=None,
+ ):
+ if not isinstance(self.node, ast.Assign):
+ self.context.report_error(
+ "alloc_buffer must be assigned to a buffer, e.g. A =
alloc_buffer(...)",
+ self.node.span,
+ )
+
+ if strides is None:
+ strides = []
+ align = convert_to_int(align, "align", self.context.report_error,
self.node.span)
+ offset_factor = convert_to_int(
+ offset_factor, "offset_factor", self.context.report_error,
self.node.span
+ )
+ buffer = tvm.tir.decl_buffer(
+ shape,
+ dtype,
+ self.node.lhs.id.name,
+ data,
+ strides,
+ elem_offset,
+ scope,
+ align,
+ offset_factor,
+ buffer_type,
+ span=span,
+ )
+ self.context.current_block_scope().alloc_buffers.append(buffer)
+ self.context.update_symbol(self.node.lhs.id.name, buffer,
self.node)
+
+ super().__init__(alloc_buffer, def_symbol=True)
+
+
+@register
+class BlockVarBind(SpecialStmt):
+ """Special function bind(block_iter, binding_value)
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.bind(vx, i)
+
+ """
+
+ def __init__(self):
+ def bind(iter_var, values, span=None):
+ block_scope = self.context.current_block_scope()
+ if iter_var in block_scope.iter_bindings:
+ self.context.report_error("Duplicate iter_var bindings of " +
str(iter_var), span)
+ block_scope.iter_bindings[iter_var] = values
+
+ super().__init__(bind, def_symbol=False)
+
+
+@register
+class BlockReads(SpecialStmt):
+ """Special function reads([read_buffer_regions])
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]])
+
+ """
+
+ def __init__(self):
+ def reads(read_regions: Union[BufferSlice, List[BufferSlice]], span:
Span = None):
+ assert self.context, "call 'exit_scope' before 'enter_scope'"
+ block_scope = self.context.current_block_scope()
+ if block_scope.reads is not None:
+ self.context.report_error(
+ "Duplicate write region declaration, "
+ + "previous one is "
+ + str(", ".join(str(x) for x in block_scope.reads)),
+ span,
+ )
+ if isinstance(read_regions, list):
+ pass
+ elif isinstance(read_regions, BufferSlice):
+ read_regions = [read_regions]
+ else:
+ self.context.report_error(
+ "Error input type. "
+ + f"Expected BufferSlice or List[BufferSlice], but got
{type(read_regions)}",
+ span,
+ )
+ block_scope.reads = read_regions
+
+ super().__init__(reads, def_symbol=False)
+
+
+@register
+class BlockWrites(SpecialStmt):
+ """Special function writes([write_buffer_regions])
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.writes([C[vi: vi + 4, vj])
+
+ """
+
+ def __init__(self):
+ def writes(write_region: Union[BufferSlice, List[BufferSlice]], span:
Span = None):
+ assert self.context, "call 'exit_scope' before 'enter_scope'"
Review comment:
It seems like all Stmts need this check? Even if a Stmt doesn't need any
information from the context, it may still need to report errors through it.
Could we use a wrapper in SpecialStmt to guarantee this, so that we don't need
to repeat this check in every statement? In case some Stmts don't need this
check, we could also add an optional flag to `__init__` to turn the check
off.
##########
File path: python/tvm/script/scope_handler.py
##########
@@ -185,92 +232,384 @@ def let(var, value, span):
super().__init__(let, concise_scope=False, def_symbol=False)
+@register
+class Block(WithScopeHandler):
+ """ With scope handler tir.block(extents, name) as iter_vars"""
+
+ def __init__(self):
+ def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
+ assert (
+ self.node and self.context and self.body
+ ), "call 'exit_scope' before 'enter_scope'"
+ block_info = self.context.block_info_stack[-1]
+ if axes is None:
+ axes = []
+ if len(axes) != len(self.block_vars):
+ self.context.report_error(
+ "Inconsistent number of block vars, "
+ + f"there are {len(axes)} axes but {len(self.block_vars)}
block vars. "
+ + "The number of block vars should match the number of
axes.",
+ self.node.span,
+ )
+ block_iters: List[IterVar] = []
+ for i, axis in enumerate(axes):
+ axis = tvm.runtime.convert(axis)
+ if isinstance(axis, tvm.tir.PrimExpr):
+ block_var_dom = Range.from_min_extent(0, axis)
+ block_iters.append(IterVar(block_var_dom,
self.block_vars[i], 0))
+ elif isinstance(axis, Range):
+ block_iters.append(IterVar(axis, self.block_vars[i], 0))
+ elif isinstance(axis, IterVar):
+ block_iters.append(IterVar(axis.dom, self.block_vars[i],
axis.iter_type))
+ else:
+ self.context.report_error(
+ "Invalid argument of tir.block(), "
+ + f"expected PrimExpr, Range or IterVar, but got
{type(axis)}",
+ self.node.span,
+ )
+
+ # create block read/write regions
+
+ reads: List[BufferRegion] = (
+ [buffer_slice_to_region(read) for read in block_info.reads]
+ if block_info.reads
+ else []
+ )
+ writes: List[BufferRegion] = (
+ [buffer_slice_to_region(write) for write in block_info.writes]
+ if block_info.writes
+ else []
+ )
+ inner = tvm.tir.Block(
+ block_iters,
+ reads,
+ writes,
+ name_hint,
+ self.body,
+ block_info.init,
+ block_info.alloc_buffers,
+ block_info.match_buffers,
+ block_info.annotations,
+ span,
+ )
+ # create block var iter binding
+ values: List[PrimExpr]
+ if not block_info.iter_bindings:
+ values = self.context.loop_stack[-2].copy()
+ if len(values) == 0:
+ values = [tvm.tir.const(float("nan"), dtype="float32")] *
len(block_iters)
+ elif len(values) != len(block_iters):
+ self.context.report_error(
+ "Number of block iter var and outer loop nesting
mismatch, "
+ + f"{len(block_iters)} block iter vars but
{len(values)} loops",
+ self.node.span,
+ )
+ else:
+ for block_var in self.block_vars:
+ if block_var not in block_info.iter_bindings:
+ self.context.report_error(
+ "Missing block iter var binding for " +
block_var.name,
+ self.node.span,
+ )
+ values = [block_info.iter_bindings[block_var] for block_var in
self.block_vars]
+ predicate = (
+ tvm.tir.const(True, "bool")
+ if block_info.predicate is None
+ else block_info.predicate
+ )
+ body = tvm.tir.BlockRealize(values, predicate, inner, span)
+ return body
+
+ super().__init__(func=block, concise_scope=False, def_symbol=True)
+ self.block_vars = None
+
+ def enter_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
+ # define block vars
+ assert isinstance(
+ node, ast.With
+ ), f"BlockScopeHandler expected to work on ast.With but got
{type(node)}"
+
+ var_names = WithScopeHandler.get_optional_var_names(node, context)
+ self.block_vars = [tvm.te.var(name) for name in var_names]
+ for block_var in self.block_vars:
+ context.update_symbol(block_var.name, block_var, node)
+
+
+@register
+class InitBlock(WithScopeHandler):
+ """ With scope handler tir.init()"""
+
+ def __init__(self):
+ def init(span: Span = None):
+ assert self.context, "call 'exit_scope' before 'enter_scope'"
+ if self.context.block_info_stack[-2].init is not None:
+ self.context.report_error("Duplicate init block declaration",
span)
+ self.context.block_info_stack[-2].init = self.body
+
+ super().__init__(func=init, concise_scope=False, def_symbol=True)
+
+
class ForScopeHandler(ScopeHandler):
"""Base class for all for scope handlers"""
def __init__(self, func):
super().__init__(func)
- self.loop_vars = None
-
- def enter_scope(self, node, context, arg_list, span):
- assert isinstance(node, ast.For)
+ self.loop_vars: Optional[List[Var]] = None
+
+ def enter_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
+ assert isinstance(
+ node, ast.For
+ ), f"ForScopeHandler expected to work on ast.For but got {type(node)}"
loop_var_names = list()
spans = list()
if isinstance(node.lhs, ast.Var):
loop_var_names.append(node.lhs.id.name)
- spans.append(from_synr_span(node.lhs.id.span))
- elif isinstance(node.lhs, ast.Tuple):
- for elt in node.lhs.values:
+ spans.append(tvm_span_from_synr(node.lhs.id.span))
+ elif isinstance(node.lhs, list):
+ for elt in node.lhs:
if not isinstance(elt, ast.Var):
- context.report_error("Invalid loop var", elt.span)
+ context.report_error(
+ f"Invalid loop var. Expected a var, but got
{type(elt)}", elt.span
+ )
loop_var_names.append(elt.id.name)
- spans.append(from_synr_span(elt.id.span))
+ spans.append(tvm_span_from_synr(elt.id.span))
else:
- context.report_error("Invalid loop var", node.lhs.span)
+ context.report_error(
+ f"Invalid loop var. Expected var or list of vars as lhs, but
got {type(node.lhs)}",
+ span,
+ )
self.loop_vars = [
tvm.te.var(name, dtype="int32", span=span) for name, span in
zip(loop_var_names, spans)
]
for loop_var in self.loop_vars:
- context.update_symbol(loop_var.name, loop_var)
+ context.update_symbol(loop_var.name, loop_var, node)
+ context.loop_stack[-1].append(loop_var)
+
+ def exit_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
+ assert self.loop_vars, "call 'exit_scope' before 'enter_scope'"
+ for _ in self.loop_vars:
+ context.loop_stack[-1].pop()
+ return super().exit_scope(node, context, arg_list, span)
+
+ def create_loop(
+ self,
+ begin: PrimExpr,
+ end: PrimExpr,
+ kind: ForKind,
+ thread_binding: Optional[str] = None,
+ annotations: Optional[Mapping[str, Object]] = None,
+ span: Optional[Span] = None,
+ ) -> tvm.tir.For:
+ """
+ Helper function for creating For in TVM Script parser.
+
+ Parameters
+ ----------
+ begin : PrimExpr
+ The beginning value.
+
+ end : PrimExpr
+ The endding value.
+
+ kind : ForKind
+ The type of the for.
+
+ thread_binding: Optional[str]
+ The thread this loop binds to.
+
+ annotations : Optional[Mapping[str, Object]]
+ Additional annotation hints.
+
+ span : Optional[Span]
+ The location of this for in the source code.
+
+ Returns
+ -------
+ for : For
+ The constructed For.
+ """
+ assert (
+ self.loop_vars and self.context and self.node
+ ), "call 'exit_scope' before 'enter_scope'"
+ if len(self.loop_vars) != 1:
+ self.context.report_error(
+ f"Expect exact only one loop var, but get {self.loop_vars}",
self.node.span
+ )
+ extent = end if begin == 0 else self.context.analyzer.simplify(end -
begin)
+ annos: Mapping[str, Object]
+ if annotations is None:
+ annos = {}
+ else:
Review comment:
```suggestion
annos: Mapping[str, Object] = {}
if annotations is not None:
```
##########
File path: python/tvm/script/special_stmt.py
##########
@@ -142,21 +181,300 @@ def buffer_decl(
buffer_type,
span=span,
)
- self.context.update_symbol(self.node.lhs.id.name, buffer)
+ self.context.update_symbol(self.node.lhs.id.name, buffer,
self.node)
return buffer
super().__init__(buffer_decl, def_symbol=True)
+@register
+class AllocBuffer(SpecialStmt):
+ """Special function alloc_buffer(shape, dtype, data, strides, elem_offset,
scope, align,
+ offset_factor, buffer_type)
+
+ Example
+ -------
+ .. code-block:: python
+
+ A = tir.alloc_buffer((128, 128), dtype="float32")
+
+ """
+
+ def __init__(self):
+ def alloc_buffer(
+ shape,
+ dtype="float32",
+ data=None,
+ strides=None,
+ elem_offset=None,
+ scope="",
+ align=-1,
+ offset_factor=0,
+ buffer_type="default",
+ span=None,
+ ):
+ if not isinstance(self.node, ast.Assign):
+ self.context.report_error(
+ "alloc_buffer must be assigned to a buffer, e.g. A =
alloc_buffer(...)",
+ self.node.span,
+ )
+
+ if strides is None:
+ strides = []
+ align = convert_to_int(align, "align", self.context.report_error,
self.node.span)
+ offset_factor = convert_to_int(
+ offset_factor, "offset_factor", self.context.report_error,
self.node.span
+ )
+ buffer = tvm.tir.decl_buffer(
+ shape,
+ dtype,
+ self.node.lhs.id.name,
+ data,
+ strides,
+ elem_offset,
+ scope,
+ align,
+ offset_factor,
+ buffer_type,
+ span=span,
+ )
+ self.context.current_block_scope().alloc_buffers.append(buffer)
+ self.context.update_symbol(self.node.lhs.id.name, buffer,
self.node)
+
+ super().__init__(alloc_buffer, def_symbol=True)
+
+
+@register
+class BlockVarBind(SpecialStmt):
+ """Special function bind(block_iter, binding_value)
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.bind(vx, i)
+
+ """
+
+ def __init__(self):
+ def bind(iter_var, values, span=None):
+ block_scope = self.context.current_block_scope()
+ if iter_var in block_scope.iter_bindings:
+ self.context.report_error("Duplicate iter_var bindings of " +
str(iter_var), span)
+ block_scope.iter_bindings[iter_var] = values
+
+ super().__init__(bind, def_symbol=False)
+
+
+@register
+class BlockReads(SpecialStmt):
+ """Special function reads([read_buffer_regions])
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]])
+
+ """
+
+ def __init__(self):
+ def reads(read_regions: Union[BufferSlice, List[BufferSlice]], span:
Span = None):
+ assert self.context, "call 'exit_scope' before 'enter_scope'"
+ block_scope = self.context.current_block_scope()
+ if block_scope.reads is not None:
+ self.context.report_error(
+ "Duplicate write region declaration, "
+ + "previous one is "
+ + str(", ".join(str(x) for x in block_scope.reads)),
+ span,
+ )
+ if isinstance(read_regions, list):
+ pass
+ elif isinstance(read_regions, BufferSlice):
+ read_regions = [read_regions]
+ else:
+ self.context.report_error(
+ "Error input type. "
+ + f"Expected BufferSlice or List[BufferSlice], but got
{type(read_regions)}",
+ span,
+ )
+ block_scope.reads = read_regions
+
+ super().__init__(reads, def_symbol=False)
+
+
+@register
+class BlockWrites(SpecialStmt):
+ """Special function writes([write_buffer_regions])
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.writes([C[vi: vi + 4, vj])
+
+ """
+
+ def __init__(self):
+ def writes(write_region: Union[BufferSlice, List[BufferSlice]], span:
Span = None):
+ assert self.context, "call 'exit_scope' before 'enter_scope'"
+ block_scope = self.context.current_block_scope()
+ if block_scope.writes is not None:
+ self.context.report_error(
+ "Duplicate write region declaration, "
+ + "previous one is "
+ + str(", ".join(str(x) for x in block_scope.writes)),
+ span,
+ )
+ if isinstance(write_region, list):
+ pass
+ elif isinstance(write_region, BufferSlice):
+ write_region = [write_region]
+ else:
+ self.context.report_error(
+ "Error input type. "
+ + f"Expected BufferSlice or List[BufferSlice], but got
{type(write_region)}",
+ span,
+ )
+ block_scope.writes = write_region
+
+ super().__init__(writes, def_symbol=False)
+
+
+@register
+class BlockAttr(SpecialStmt):
+ """Special function block_attr({attr_key: attr_value})
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.block_attr({"double_buffer_scope": 1})
+
+ """
+
+ def __init__(self):
+ def block_attr(attrs: Mapping[str, Object], span: Span = None):
+ assert self.context, "call 'exit_scope' before 'enter_scope'"
+ block_scope = self.context.current_block_scope()
+ if block_scope.annotations is not None:
+ self.context.report_error(
+ "Duplicate block annotations declaration, "
+ + "previous one is "
+ + str(block_scope.annotations),
+ span,
+ )
+ attrs = {
+ key: tvm.tir.StringImm(val) if isinstance(val, str) else val
+ for key, val in attrs.items()
+ }
+ block_scope.annotations = attrs
+
+ super().__init__(block_attr, def_symbol=False)
+
+
+@register
+class BlockPredicate(SpecialStmt):
+ """Special function where(predicate)
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.where(i < 4)
+
+ """
+
+ def __init__(self):
+ def where(predicate, span=None):
+ block_scope = self.context.current_block_scope()
Review comment:
Do we need to check `self.context` here (e.g., `assert self.context`, as the other statements do)?
##########
File path: python/tvm/script/special_stmt.py
##########
@@ -142,21 +181,300 @@ def buffer_decl(
buffer_type,
span=span,
)
- self.context.update_symbol(self.node.lhs.id.name, buffer)
+ self.context.update_symbol(self.node.lhs.id.name, buffer,
self.node)
return buffer
super().__init__(buffer_decl, def_symbol=True)
+@register
+class AllocBuffer(SpecialStmt):
+ """Special function alloc_buffer(shape, dtype, data, strides, elem_offset,
scope, align,
+ offset_factor, buffer_type)
+
+ Example
+ -------
+ .. code-block:: python
+
+ A = tir.alloc_buffer((128, 128), dtype="float32")
+
Review comment:
Remove this blank line (ditto for the rest).
##########
File path: python/tvm/script/parser.py
##########
@@ -191,9 +194,11 @@ def report_error(self, message, span):
----------
message : str
Error message
- span : synr.ast.Span
+ span : synr.ast.Span or tvm.ir.Span
Review comment:
It would be better to make the type annotation consistent, i.e., use
`Union[...]`.
##########
File path: python/tvm/script/parser.py
##########
@@ -401,25 +413,52 @@ def my_function(x: ty.handle): # 1. Argument types
"""
self.init_function_parsing_env()
- self.context.new_scope(nodes=node.body.stmts)
+ self.context.enter_scope(nodes=node.body.stmts)
# add parameters of function
for arg in node.params:
arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg))
- self.context.update_symbol(arg.name, arg_var)
+ self.context.update_symbol(arg.name, arg_var, node)
self.context.func_params.append(arg_var)
- # fetch the body and return a tir.PrimFunc
+ # New Scope : Implicit root block
+ # Each function contains an implicit root block in TensorIR,
+ # so here we need a block scope for it. Please note that
`enter_block_scope`
+ # will not create a block directly but just store some information.
+ # If the PrimFunc is not a TensorIR func (e.g. TE scheduled func or
low-level func),
+ # the roo block will not be added. The logic to add root block is in
`_ffi_api.Complete`
Review comment:
```suggestion
# the root block will not be added. The logic to add root block is
in `_ffi_api.Complete`
```
##########
File path: python/tvm/script/special_stmt.py
##########
@@ -142,21 +181,300 @@ def buffer_decl(
buffer_type,
span=span,
)
- self.context.update_symbol(self.node.lhs.id.name, buffer)
+ self.context.update_symbol(self.node.lhs.id.name, buffer,
self.node)
return buffer
super().__init__(buffer_decl, def_symbol=True)
+@register
+class AllocBuffer(SpecialStmt):
+ """Special function alloc_buffer(shape, dtype, data, strides, elem_offset,
scope, align,
+ offset_factor, buffer_type)
+
Review comment:
Could you document the parameters in the docstring? (Ditto for the rest.)
##########
File path: python/tvm/script/node.py
##########
@@ -0,0 +1,150 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+"""TVM Script nodes."""
+
+from typing import Optional, Union, List, Callable
+import synr
+
+from tvm.runtime import ObjectGeneric
+from tvm.tir import PrimExpr, Buffer, BufferLoad
+from tvm.ir import Span
+
+
+class Slice:
+ """A helper class to present slice information for BufferSlice
+
+ Parameters
+ ----------
+ start : Union[PrimExpr, int]
+ The start index.
+
+ stop : Optional[Union[PrimExpr, int]]
+ The stop index, None means the Slice is an element-wise index
+
+ span : Optional[Span]
+ The location of the slice in the source.
+ """
+
+ start: Union[PrimExpr, int]
+ stop: Optional[Union[PrimExpr, int]]
+ span: Optional[Span]
+
+ def __init__(
+ self,
+ start: Union[PrimExpr, int],
+ stop: Optional[Union[PrimExpr, int]] = None,
+ span: Optional[Span] = None,
+ ):
+ self.start = start
+ self.stop = stop
+ self.span = span
+
+
+class BufferSlice(ObjectGeneric):
+ """A generic object for representing general buffer access. Following
cases are supported:
+ - element wise access buffer[i, j], which can be convert to BufferLoad
if necessary
Review comment:
```suggestion
- element wise access buffer[i, j], which can be converted to
BufferLoad if necessary
```
##########
File path: python/tvm/script/special_stmt.py
##########
@@ -142,21 +181,300 @@ def buffer_decl(
buffer_type,
span=span,
)
- self.context.update_symbol(self.node.lhs.id.name, buffer)
+ self.context.update_symbol(self.node.lhs.id.name, buffer,
self.node)
return buffer
super().__init__(buffer_decl, def_symbol=True)
+@register
+class AllocBuffer(SpecialStmt):
+ """Special function alloc_buffer(shape, dtype, data, strides, elem_offset,
scope, align,
+ offset_factor, buffer_type)
+
+ Example
+ -------
+ .. code-block:: python
+
+ A = tir.alloc_buffer((128, 128), dtype="float32")
+
+ """
+
+ def __init__(self):
+ def alloc_buffer(
+ shape,
+ dtype="float32",
+ data=None,
+ strides=None,
+ elem_offset=None,
+ scope="",
+ align=-1,
+ offset_factor=0,
+ buffer_type="default",
+ span=None,
+ ):
+ if not isinstance(self.node, ast.Assign):
+ self.context.report_error(
+ "alloc_buffer must be assigned to a buffer, e.g. A =
alloc_buffer(...)",
+ self.node.span,
+ )
+
+ if strides is None:
+ strides = []
+ align = convert_to_int(align, "align", self.context.report_error,
self.node.span)
+ offset_factor = convert_to_int(
+ offset_factor, "offset_factor", self.context.report_error,
self.node.span
+ )
+ buffer = tvm.tir.decl_buffer(
+ shape,
+ dtype,
+ self.node.lhs.id.name,
+ data,
+ strides,
+ elem_offset,
+ scope,
+ align,
+ offset_factor,
+ buffer_type,
+ span=span,
+ )
+ self.context.current_block_scope().alloc_buffers.append(buffer)
+ self.context.update_symbol(self.node.lhs.id.name, buffer,
self.node)
+
+ super().__init__(alloc_buffer, def_symbol=True)
+
+
+@register
+class BlockVarBind(SpecialStmt):
+ """Special function bind(block_iter, binding_value)
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.bind(vx, i)
+
+ """
+
+ def __init__(self):
+ def bind(iter_var, values, span=None):
+ block_scope = self.context.current_block_scope()
+ if iter_var in block_scope.iter_bindings:
+ self.context.report_error("Duplicate iter_var bindings of " +
str(iter_var), span)
+ block_scope.iter_bindings[iter_var] = values
+
+ super().__init__(bind, def_symbol=False)
+
+
+@register
+class BlockReads(SpecialStmt):
+ """Special function reads([read_buffer_regions])
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]])
+
+ """
+
+ def __init__(self):
+ def reads(read_regions: Union[BufferSlice, List[BufferSlice]], span:
Span = None):
+ assert self.context, "call 'exit_scope' before 'enter_scope'"
+ block_scope = self.context.current_block_scope()
+ if block_scope.reads is not None:
+ self.context.report_error(
+ "Duplicate write region declaration, "
+ + "previous one is "
+ + str(", ".join(str(x) for x in block_scope.reads)),
+ span,
+ )
+ if isinstance(read_regions, list):
+ pass
+ elif isinstance(read_regions, BufferSlice):
+ read_regions = [read_regions]
+ else:
+ self.context.report_error(
+ "Error input type. "
+ + f"Expected BufferSlice or List[BufferSlice], but got
{type(read_regions)}",
+ span,
+ )
+ block_scope.reads = read_regions
+
+ super().__init__(reads, def_symbol=False)
+
+
+@register
+class BlockWrites(SpecialStmt):
+ """Special function writes([write_buffer_regions])
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.writes([C[vi: vi + 4, vj])
+
+ """
+
+ def __init__(self):
+ def writes(write_region: Union[BufferSlice, List[BufferSlice]], span:
Span = None):
+ assert self.context, "call 'exit_scope' before 'enter_scope'"
+ block_scope = self.context.current_block_scope()
+ if block_scope.writes is not None:
+ self.context.report_error(
+ "Duplicate write region declaration, "
+ + "previous one is "
+ + str(", ".join(str(x) for x in block_scope.writes)),
+ span,
+ )
+ if isinstance(write_region, list):
+ pass
+ elif isinstance(write_region, BufferSlice):
+ write_region = [write_region]
+ else:
+ self.context.report_error(
+ "Error input type. "
+ + f"Expected BufferSlice or List[BufferSlice], but got
{type(write_region)}",
+ span,
+ )
+ block_scope.writes = write_region
+
+ super().__init__(writes, def_symbol=False)
+
+
+@register
+class BlockAttr(SpecialStmt):
+ """Special function block_attr({attr_key: attr_value})
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.block_attr({"double_buffer_scope": 1})
+
+ """
+
+ def __init__(self):
+ def block_attr(attrs: Mapping[str, Object], span: Span = None):
+ assert self.context, "call 'exit_scope' before 'enter_scope'"
+ block_scope = self.context.current_block_scope()
+ if block_scope.annotations is not None:
+ self.context.report_error(
+ "Duplicate block annotations declaration, "
+ + "previous one is "
+ + str(block_scope.annotations),
+ span,
+ )
+ attrs = {
+ key: tvm.tir.StringImm(val) if isinstance(val, str) else val
+ for key, val in attrs.items()
+ }
+ block_scope.annotations = attrs
+
+ super().__init__(block_attr, def_symbol=False)
+
+
+@register
+class BlockPredicate(SpecialStmt):
+ """Special function where(predicate)
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.where(i < 4)
+
+ """
+
+ def __init__(self):
+ def where(predicate, span=None):
+ block_scope = self.context.current_block_scope()
+ if block_scope.predicate is not None:
+ self.context.report_error(
+ "Duplicate block predicate declaration, "
+ + "previous one is "
+ + str(block_scope.predicate),
+ span,
+ )
+
+ block_scope.predicate = predicate
+
+ super().__init__(where, def_symbol=False)
+
+
+@register
+class BlockMatchBufferRegion(SpecialStmt):
+ """Special function match_buffer_region(source, strides, elem_offset,
align, offset_factor)
+
+ Example
+ -------
+ .. code-block:: python
+
+ B = tir.match_buffer_region(A[0: 4])
+
+ """
+
+ def __init__(self):
+ def match_buffer_region(
+ source,
+ strides=None,
+ elem_offset=None,
+ align=-1,
+ offset_factor=0,
+ span=None,
+ ):
+ if not isinstance(self.node, ast.Assign):
+ self.context.report_error(
Review comment:
Do we need to check `self.context` here (e.g., `assert self.context`, as the other statements do)?
##########
File path: python/tvm/tir/analysis/analysis.py
##########
@@ -106,3 +106,26 @@ def verify_gpu_code(func, constraints):
The result of verification.
"""
return _ffi_api.verify_gpu_code(func, constraints)
+
+
+def get_block_access_region(block, buffer_var_map):
+ """Auto detect the block read/write region according to body stmt
+ It will detect the read/write region as an array in order of
appearance in AST
+
+ Parameters
+ ----------
+ block: tvm.tir.Block
+ The block to be detected.
+
+ buffer_var_map : Dict[Var, Buffer]
+ The outside buffers which may be accessed the block. Mapping from
buffer var to the buffer
Review comment:
```suggestion
The outside buffers which may access the block. Mapping from buffer
var to the buffer
```
##########
File path: python/tvm/script/special_stmt.py
##########
@@ -142,21 +181,300 @@ def buffer_decl(
buffer_type,
span=span,
)
- self.context.update_symbol(self.node.lhs.id.name, buffer)
+ self.context.update_symbol(self.node.lhs.id.name, buffer,
self.node)
return buffer
super().__init__(buffer_decl, def_symbol=True)
+@register
+class AllocBuffer(SpecialStmt):
+ """Special function alloc_buffer(shape, dtype, data, strides, elem_offset,
scope, align,
+ offset_factor, buffer_type)
+
+ Example
+ -------
+ .. code-block:: python
+
+ A = tir.alloc_buffer((128, 128), dtype="float32")
+
+ """
+
+ def __init__(self):
+ def alloc_buffer(
+ shape,
+ dtype="float32",
+ data=None,
+ strides=None,
+ elem_offset=None,
+ scope="",
+ align=-1,
+ offset_factor=0,
+ buffer_type="default",
+ span=None,
+ ):
+ if not isinstance(self.node, ast.Assign):
+ self.context.report_error(
+ "alloc_buffer must be assigned to a buffer, e.g. A =
alloc_buffer(...)",
+ self.node.span,
+ )
+
+ if strides is None:
+ strides = []
+ align = convert_to_int(align, "align", self.context.report_error,
self.node.span)
+ offset_factor = convert_to_int(
+ offset_factor, "offset_factor", self.context.report_error,
self.node.span
+ )
+ buffer = tvm.tir.decl_buffer(
+ shape,
+ dtype,
+ self.node.lhs.id.name,
+ data,
+ strides,
+ elem_offset,
+ scope,
+ align,
+ offset_factor,
+ buffer_type,
+ span=span,
+ )
+ self.context.current_block_scope().alloc_buffers.append(buffer)
+ self.context.update_symbol(self.node.lhs.id.name, buffer,
self.node)
+
+ super().__init__(alloc_buffer, def_symbol=True)
+
+
+@register
+class BlockVarBind(SpecialStmt):
+ """Special function bind(block_iter, binding_value)
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.bind(vx, i)
+
+ """
+
+ def __init__(self):
+ def bind(iter_var, values, span=None):
+ block_scope = self.context.current_block_scope()
+ if iter_var in block_scope.iter_bindings:
+ self.context.report_error("Duplicate iter_var bindings of " +
str(iter_var), span)
+ block_scope.iter_bindings[iter_var] = values
+
+ super().__init__(bind, def_symbol=False)
+
+
+@register
+class BlockReads(SpecialStmt):
+ """Special function reads([read_buffer_regions])
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]])
+
+ """
+
+ def __init__(self):
+ def reads(read_regions: Union[BufferSlice, List[BufferSlice]], span:
Span = None):
+ assert self.context, "call 'exit_scope' before 'enter_scope'"
+ block_scope = self.context.current_block_scope()
+ if block_scope.reads is not None:
+ self.context.report_error(
+ "Duplicate write region declaration, "
+ + "previous one is "
+ + str(", ".join(str(x) for x in block_scope.reads)),
+ span,
+ )
+ if isinstance(read_regions, list):
+ pass
+ elif isinstance(read_regions, BufferSlice):
+ read_regions = [read_regions]
+ else:
Review comment:
```suggestion
if isinstance(read_regions, BufferSlice):
read_regions = [read_regions]
if not isinstance(read_regions, list):
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]