This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fffed0f [TensorIR] TVMScript Parser/Printer (#7630)
fffed0f is described below
commit fffed0ff91c46f5c45070b52794f4f2bf4d1b8a5
Author: Siyuan Feng <[email protected]>
AuthorDate: Sun Mar 21 04:22:53 2021 +0800
[TensorIR] TVMScript Parser/Printer (#7630)
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Tianqi Chen <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
Co-authored-by: Tristan Konolige <[email protected]>
Co-authored-by: Cody Yu <[email protected]>
---
include/tvm/tir/analysis.h | 15 +
python/tvm/script/context_maintainer.py | 210 +++++++--
python/tvm/script/intrin.py | 20 +-
python/tvm/script/node.py | 150 +++++++
python/tvm/script/parser.py | 179 +++++---
python/tvm/script/registry.py | 20 +-
python/tvm/script/scope_handler.py | 473 ++++++++++++++++++---
python/tvm/script/special_stmt.py | 380 +++++++++++++++--
python/tvm/script/utils.py | 95 ++++-
python/tvm/tir/analysis/analysis.py | 23 +
src/printer/tir_text_printer.cc | 3 +-
src/printer/tvmscript_printer.cc | 232 +++++++++-
src/tir/analysis/block_access_region_detector.cc | 246 +++++++++++
src/tir/ir/script/script_complete.cc | 122 ++++++
.../test_tir_analysis_get_block_access_region.py | 57 +++
.../python/unittest/test_tvmscript_error_report.py | 205 +++++++++
tests/python/unittest/test_tvmscript_roundtrip.py | 170 ++++++++
tests/scripts/task_ci_python_setup.sh | 2 +-
tests/scripts/task_ci_setup.sh | 2 +-
19 files changed, 2395 insertions(+), 209 deletions(-)
diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index 1ad7859..1692a8c 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -157,6 +157,21 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func);
*/
TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr>
constraints);
+/*!
+ * \brief Auto detect the block read/write region according to body stmt
+ * It will detect the read/write region as an array in order of appearance in AST
+ * \param block The block to be detected
+ * \param buffer_var_map The outside buffers which may be accessed by the block.
+ * It is a map from buffer var to the buffer.
+ * \return Array of access regions.
+ * There are three arrays of BufferRegion:
+ * - first: read regions
+ * - second: write regions
+ * - third: opaque regions
+ */
+Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
+ const Map<Var, Buffer>&
buffer_var_map);
+
// Pass variants of verification analysis
// directly throws RuntimeError when verification fails.
namespace transform {
diff --git a/python/tvm/script/context_maintainer.py
b/python/tvm/script/context_maintainer.py
index 955266c..ae3e9d8 100644
--- a/python/tvm/script/context_maintainer.py
+++ b/python/tvm/script/context_maintainer.py
@@ -16,59 +16,217 @@
# under the License.
"""TVM Script Context Maintainer for TIR"""
-from tvm.te import schedule
+from typing import List, Mapping, Union, Optional, Dict, Callable
+import synr
+
+
+import tvm
+from tvm.ir import Span
+from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
+from tvm.runtime import Object
+from .node import BufferSlice
+
+
+class BlockInfo:
+ """Information for block and block_realize signature
+
+ Examples
+ ----------
+ .. code-block:: python
+
+ @tvm.script.tir
+ def example_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+ A = tir.match_buffer(a, (16, 16), "float32")
+ B = tir.match_buffer(b, (16, 16), "float32")
+ C = tir.match_buffer(c, (16, 16), "float32")
+
+ for i, j, k in tir.grid(16, 16, 16):
+ with tir.block([16, 16, tir.reduce_axis(16)], "matmul") as
[vi, vj, vk]:
+ tir.bind(vi, i)
+ tir.bind(vj, j)
+ tir.bind(vk, k) # iter_bindings = {vi: i, vj: j, vk: k}
+
+ tir.where(True) # predicate of the block_realize
+
+ tir.reads(A[0:16, 0:16], B[0: 16, 0: 16]) # reads
region of the block
+ tir.writes(C[0: 16, 0: 16]) # writes
region of the block
+ tir.block_attr({"attr_key": "attr_value"}) # block
annotations
+
+ # alloc_buffers inside the block
+ CC = tir.alloc_buffer((1, 1), dtype="float32")
+
+ # match_buffers of the block,
+ # which bind a sub-region of source buffer into a new
buffer
+ D = tir.match_buffer_region(C[vi, vj])
+
+ # init part of the block, executed when all reduce axes are at their beginning value
+ with tir.init():
+ C[vi, vj] = tir.float32(0)
+
+ # block body
+ CC[0, 0] = A[vi, vk] * B[vj, vk]
+ D[0, 0] += CC[0, 0] # The same as C[vi, vj] +=
CC[0, 0]
+ """
+
+ alloc_buffers: List[Buffer] = []
+ """List[Buffer]: list of tir.alloc_buffer statements in the block
signature"""
+ match_buffers: List[MatchBufferRegion] = []
+ """List[MatchBufferRegion]: list of tir.match_buffer_region statements in
the block signature"""
+ iter_bindings: Mapping[Var, PrimExpr] = {}
+ """Mapping[Var, PrimExpr]: map of block iter var to its values"""
+ reads: Optional[List[BufferSlice]] = None
+ """Optional[List[BufferSlice]]:
+ list of tir.reads statements in the block signature, None for
not-visited"""
+ writes: Optional[List[BufferSlice]] = None
+ """Optional[List[BufferSlice]]:
+ list of tir.writes statements in the block signature, None for
not-visited"""
+ annotations: Optional[Mapping[str, Object]] = None
+ """Optional[Mapping[str, Object]]:
+ list of tir.block_attr statements in the block signature, None for
not-visited"""
+ predicate: Optional[PrimExpr] = None
+ """Optional[PrimExpr]: block realize predicate, None for not-visited"""
+ init: Optional[Stmt] = None
+ """Optional[Stmt]: init part of the block, None for not-visited"""
+
+ def __init__(self):
+ self.alloc_buffers = []
+ self.match_buffers = []
+ self.iter_bindings = {}
+ self.reads = None
+ self.writes = None
+ self.annotations = None
+ self.predicate = None
+ self.init = None
class ContextMaintainer:
- """Maintain all the necessary context info"""
+ """Maintain all the necessary context info
+ Parameters
+ ----------
+ _report_error : Callable[[str, Union[Span, synr.ast.Span]], None]
+ The report error function handle
+ """
+
+ # scope context
+ node_stack: List[List[synr.ast.Node]] = []
+ """List[List[synr.ast.Node]]: The ast nodes inside the current scope"""
+ block_info_stack: List[BlockInfo] = []
+ """List[BlockInfo]: The block info for the current block scope"""
+ loop_stack: List[List[Var]] = []
+ """List[List[Var]]: List of loop vars inside the current block scope"""
+ symbols: List[Dict[str, Union[Var, Buffer]]] = []
+ """List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for
the current scope"""
- def __init__(self, parser):
+ # function context
+ func_params: List[Var] = []
+ """List[Var]: The function parameters"""
+ func_buffer_map: Mapping[Var, Buffer] = {}
+ """Mapping[Var, Buffer]: The function buffer map"""
+ func_dict_attr: Mapping[str, Object] = {}
+ """Mapping[str, Object]: The function attrs"""
+ func_var_env_dict: Mapping[Var, str] = {}
+ """Mapping[Var, str]: The map from var to env thread"""
+
+ # parser and analyzer
+ analyzer: tvm.arith.Analyzer = tvm.arith.Analyzer()
+ """tvm.arith.Analyzer: The analyzer for simplifying"""
+ _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
+ """Callable[[str, Union[Span, synr.ast.Span]], None]: The report error
function handle"""
+
+ def __init__(self, _report_error: Callable[[str, Union[Span,
synr.ast.Span]], None]):
# scope context
- self.node_stack = [] # AST nodes of scopes
- self.symbols = [] # symbols of scopes
+ self.node_stack = []
+ self.block_info_stack = []
+ self.loop_stack = []
+ self.symbols = []
# function context
- self.func_params = [] # parameter list of function
- self.func_buffer_map = {} # buffer_map of function
- self.func_dict_attr = {} # func_attr of function
- self.func_var_env_dict = {} # map from var to env_name
- # parser
- self.parser = parser
-
- def pop_scope(self):
- """Pop the inner most scope"""
- self.symbols.pop()
- self.node_stack.pop()
+ self.func_params = []
+ self.func_buffer_map = {}
+ self.func_dict_attr = {}
+ self.func_var_env_dict = {}
+ # parser and analyzer
+ self._report_error = _report_error
+ self.analyzer = tvm.arith.Analyzer()
+
+ def enter_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
+ """Creates a new scope
- def new_scope(self, nodes=None):
- """Creating a new scope"""
+ Note
+ ----
+ This function is used for normal scopes that do not involve
+ a `with block` scope. Use `enter_block_scope`
+ for block scope cases.
+
+ Parameters
+ ----------
+ nodes : Optional[List[synr.ast.Node]]
+ The synr AST nodes in new scope
+ """
if nodes is None:
nodes = []
self.node_stack.append(list(reversed(nodes)))
self.symbols.append(dict())
- def update_symbol(self, name, symbol):
+ def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
+ """Creates a new block scope, the function will call `enter_scope`
implicitly
+ Besides the behaviors of `enter_scope`, it will update loop_stack and
block_info_stack
+ to maintain block info.
+
+ Note
+ ----
+ This function should be used to handle a block scope,
+ aka the blocks that involve a `with block` scope.
+
+ Parameters
+ ----------
+ nodes : Optional[List[synr.ast.Node]]
+ The synr AST nodes in new scope
+ """
+ self.enter_scope(nodes)
+ # Create a new loop stack for the new block
+ self.loop_stack.append([])
+ # Create a new BlockInfo for the new block
+ self.block_info_stack.append(BlockInfo())
+
+ def exit_scope(self):
+ """Pop the inner most scope"""
+ self.symbols.pop()
+ self.node_stack.pop()
+
+ def exit_block_scope(self):
+ """Pop the inner most block scope, the function will call `exit_scope`
implicitly"""
+ self.exit_scope()
+ # Pop loop stack
+ self.loop_stack.pop()
+ # Pop block_info
+ self.block_info_stack.pop()
+
+ def update_symbol(self, name: str, symbol: Union[Buffer, Var], node:
synr.ast.Node):
"""Append a symbol into current scope"""
- if isinstance(symbol, schedule.Buffer):
+ if isinstance(symbol, Buffer):
if name in self.symbols[0]:
- self.parser.report_error("Duplicate Buffer name")
+ self.report_error("Duplicate Buffer name: " + symbol.name,
node.span)
self.symbols[0][name] = symbol
else:
self.symbols[-1][name] = symbol
- def remove_symbol(self, name):
+ def remove_symbol(self, name: str):
"""Remove a symbol"""
for symbols in reversed(self.symbols):
if name in symbols:
symbols.pop(name)
return
- raise RuntimeError("Internal error of tvm script parser: no symbol
named" + name)
+ raise RuntimeError("Internal error of tvm script parser: no symbol
named " + name)
- def lookup_symbol(self, name):
+ def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var]]:
"""Look up symbol by name"""
for symbols in reversed(self.symbols):
if name in symbols:
return symbols[name]
return None
- def report_error(self, message, span):
- self.parser.report_error(message, span)
+ def report_error(self, message: str, span: Union[Span, synr.ast.Span]):
+ self._report_error(message, span)
+
+ def current_block_scope(self) -> BlockInfo:
+ return self.block_info_stack[-1]
diff --git a/python/tvm/script/intrin.py b/python/tvm/script/intrin.py
index 053cd4a..48f50a2 100644
--- a/python/tvm/script/intrin.py
+++ b/python/tvm/script/intrin.py
@@ -16,9 +16,11 @@
# under the License.
"""TVM Script Parser Intrinsic Classes"""
# pylint: disable=redefined-builtin, relative-beyond-top-level
+from typing import List, Any
+
import tvm.tir
from .registry import register
-from .utils import get_param_list, from_synr_span
+from .utils import get_param_list, tvm_span_from_synr
class Intrin:
@@ -29,8 +31,8 @@ class Intrin:
def signature(self):
return "tir." + self.intrin.__name__, get_param_list(self.intrin)
- def handle(self, arg_list, span):
- return self.intrin(*arg_list, span=from_synr_span(span))
+ def handle(self, arg_list: List[Any], span: tvm.ir.Span):
+ return self.intrin(*arg_list, span=tvm_span_from_synr(span))
@register
@@ -99,6 +101,16 @@ def float64(imm, span):
@register
+def min_value(dtype, span):
+ return tvm.tir.min_value(dtype, span)
+
+
+@register
+def max_value(dtype, span):
+ return tvm.tir.max_value(dtype, span)
+
+
+@register
def floordiv(x, y, span):
return tvm.tir.floordiv(x, y, span)
@@ -145,7 +157,7 @@ def get_axis(begin, end, iter_type, span):
block_var_dom = tvm.ir.Range.from_min_extent(begin, extent)
iter_type_dict = {"data_par": 0, "reduce": 2, "scan": 3, "opaque": 4}
- return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type],
span)
+ return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type],
span=span)
@register
diff --git a/python/tvm/script/node.py b/python/tvm/script/node.py
new file mode 100644
index 0000000..039eeb4
--- /dev/null
+++ b/python/tvm/script/node.py
@@ -0,0 +1,150 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
+"""TVM Script nodes."""
+
+from typing import Optional, Union, List, Callable
+import synr
+
+from tvm.runtime import ObjectGeneric
+from tvm.tir import PrimExpr, Buffer, BufferLoad
+from tvm.ir import Span
+
+
+class Slice:
+ """A helper class to present slice information for BufferSlice
+
+ Parameters
+ ----------
+ start : Union[PrimExpr, int]
+ The start index.
+
+ stop : Optional[Union[PrimExpr, int]]
+ The stop index, None means the Slice is an element-wise index
+
+ span : Optional[Span]
+ The location of the slice in the source.
+ """
+
+ start: Union[PrimExpr, int]
+ stop: Optional[Union[PrimExpr, int]]
+ span: Optional[Span]
+
+ def __init__(
+ self,
+ start: Union[PrimExpr, int],
+ stop: Optional[Union[PrimExpr, int]] = None,
+ span: Optional[Span] = None,
+ ):
+ self.start = start
+ self.stop = stop
+ self.span = span
+
+
+class BufferSlice(ObjectGeneric):
+ """A generic object for representing general buffer access. Following cases are supported:
+ - element wise access buffer[i, j], which can be converted to BufferLoad if necessary
+ - slice access buffer[i: i + 1, j : j + 2]
+ - union of element and slice buffer[i, j: j + 2]
+
+ This node is used in TVMScript to parse BufferLoad, BufferRegion and Realize
+
+ Parameters
+ ----------
+ buffer : Buffer
+ The buffer.
+
+ indices : List[Union[Slice, PrimExpr, int]]
+ The access indexes can be slice, PrimExpr or int.
+
+ report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
+ The error report func
+
+ span : Optional[Span]
+ The location of the buffer access in the source.
+ """
+
+ buffer: Buffer
+ slices: List[Slice]
+ report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
+ span: Optional[Span]
+
+ def __init__(
+ self,
+ buffer: Buffer,
+ indices: List[Union[Slice, PrimExpr, int]],
+ report_error: Callable[[str, Union[Span, synr.ast.Span]], None],
+ span: Optional[Span] = None,
+ ):
+ def check_index(index: Union[int, PrimExpr]):
+ """ Check input index is non-negative integer or PrimExpr"""
+ if isinstance(index, int):
+ if index < 0:
+ report_error("Negative index is not allowed during buffer
access", span)
+ elif isinstance(index, PrimExpr):
+ if index.dtype != "int32":
+ report_error(
+ "index expected an int32 type PrimExpr but got " +
str(index.dtype),
+ index.span,
+ )
+ else:
+ report_error(
+ "Unsupported index type, expected int or tvm.tir.PrimExpr,
but got "
+ + str(type(index)),
+ span,
+ )
+
+ slices: List[Slice] = []
+ for index in indices:
+ if isinstance(index, Slice):
+ check_index(index.start)
+ check_index(index.stop)
+ slices.append(index)
+ elif isinstance(index, (PrimExpr, int)):
+ check_index(index)
+ slices.append(Slice(index))
+ else:
+ report_error(
+ "Unsupported index type for BufferSlice, "
+ + "expected int, tvm.tir.PrimExpr, tvm.tir.Slice, but got "
+ + str(type(index)),
+ span,
+ )
+
+ self.buffer = buffer
+ self.slices = slices
+ self.report_error = report_error
+ self.span = span
+
+ def __str__(self):
+ regions: List[str] = []
+ for s in self.slices:
+ if s.stop is None:
+ regions.append(str(s.start))
+ else:
+ regions.append(str(s.start) + ": " + str(s.stop))
+
+ return self.buffer.name + "[" + ", ".join(regions) + "]"
+
+ def asobject(self) -> BufferLoad:
+ """Convert object."""
+ for s in self.slices:
+ if s.stop is not None:
+ self.report_error("BufferLoad only accepts elementwise
access", self.span)
+
+ indices = [s.start for s in self.slices]
+ return BufferLoad(self.buffer, indices, span=self.span)
diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index 33b0bab..8f6d338 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -24,6 +24,7 @@ use for error reporting.
import json
import operator
import inspect
+from typing import Union
from synr import ast, Transformer, to_ast
import tvm
@@ -32,6 +33,7 @@ from tvm._ffi.base import TVMError
from tvm.ir import GlobalVar
from . import context_maintainer, ty
+from .context_maintainer import BlockInfo
from .meta_unparser import MetaUnparser
from .registry import Registry
from .intrin import Intrin
@@ -39,7 +41,8 @@ from .special_stmt import SpecialStmt
from .scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler
from . import _ffi_api
from .diagnostics import TVMDiagnosticCtx
-from .utils import from_synr_span
+from .utils import tvm_span_from_synr, synr_span_from_tvm,
call_with_error_reporting
+from .node import Slice, BufferSlice
class CallArgumentReader(object):
@@ -158,7 +161,7 @@ class TVMScriptParser(Transformer):
def init_function_parsing_env(self):
"""Initialize function parsing environment"""
- self.context = context_maintainer.ContextMaintainer(self) # scope
emitter
+ self.context = context_maintainer.ContextMaintainer(self.report_error)
# scope emitter
def init_meta(self, meta_dict):
if meta_dict is not None:
@@ -182,7 +185,7 @@ class TVMScriptParser(Transformer):
return transform_res
- def report_error(self, message, span):
+ def report_error(self, message: str, span: Union[ast.Span, tvm.ir.Span]):
"""Report an error occuring at a location.
This just dispatches to synr's DiagnosticContext.
@@ -191,9 +194,11 @@ class TVMScriptParser(Transformer):
----------
message : str
Error message
- span : synr.ast.Span
+ span : Union[synr.ast.Span, tvm.ir.Span]
Location of the error
"""
+ if isinstance(span, tvm.ir.Span):
+ span = synr_span_from_tvm(span)
self.error(message, span)
def parse_body(self, parent):
@@ -221,7 +226,7 @@ class TVMScriptParser(Transformer):
)
else:
return (
- tvm.tir.SeqStmt(body, from_synr_span(ast.Span.union(spans)))
+ tvm.tir.SeqStmt(body,
tvm_span_from_synr(ast.Span.union(spans)))
if len(body) > 1
else body[0]
)
@@ -270,6 +275,13 @@ class TVMScriptParser(Transformer):
internal_args.append(reader.get_kwarg(i + 1 + len(pos_only),
arg_name, default=default))
if varargs is not None:
internal_args.extend(reader.get_varargs(len(pos_only) +
len(kwargs) + 1))
+ elif len(args) + len(kw_args) > len(pos_only) + len(kwargs):
+ self.report_error(
+ "Arguments mismatched. "
+ + f"Expected {len(pos_only) + len(kwargs)} args but got "
+ + f"{len(args) + len(kw_args)}",
+ node_call.span,
+ )
return internal_args
def parse_type(self, type_node, parent):
@@ -401,25 +413,52 @@ class TVMScriptParser(Transformer):
"""
self.init_function_parsing_env()
- self.context.new_scope(nodes=node.body.stmts)
+ self.context.enter_scope(nodes=node.body.stmts)
# add parameters of function
for arg in node.params:
arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg))
- self.context.update_symbol(arg.name, arg_var)
+ self.context.update_symbol(arg.name, arg_var, node)
self.context.func_params.append(arg_var)
- # fetch the body and return a tir.PrimFunc
+ # New Scope : Implicit root block
+ # Each function contains an implicit root block in TensorIR,
+ # so here we need a block scope for it. Please note that
`enter_block_scope`
+ # will not create a block directly but just stores some information.
+ # If the PrimFunc is not a TensorIR func (e.g. TE scheduled func or
low-level func),
+ # the root block will not be added. The logic to add root block is in
`_ffi_api.Complete`
+ self.context.enter_block_scope(nodes=node.body.stmts)
+
+ # fetch the body of root block
+ body = self.parse_body(node.body)
+ # Emit Scope : Implicit root block
+ root_info: BlockInfo = self.context.current_block_scope()
+ self.context.exit_block_scope()
+
+ # return a tir.PrimFunc
+ dict_attr = self.context.func_dict_attr
func = tvm.tir.PrimFunc(
self.context.func_params,
- self.parse_body(node.body),
+ body,
ret_type=self.parse_type(node.ret_type, node),
buffer_map=self.context.func_buffer_map,
- attrs=tvm.ir.make_node("DictAttrs", **self.context.func_dict_attr),
- span=from_synr_span(node.span),
+ attrs=tvm.ir.make_node("DictAttrs", **dict_attr) if dict_attr else
None,
+ span=tvm_span_from_synr(node.span),
+ )
+
+ # Fix the PrimFunc
+ # 1. generate root block if necessary
+ # 2. generate surrounding loops for blocks if necessary
+
+ func = call_with_error_reporting(
+ self.report_error,
+ node.span,
+ _ffi_api.Complete,
+ func,
+ root_info.alloc_buffers,
)
- self.context.pop_scope()
+ self.context.exit_scope()
return func
def transform_Assign(self, node):
@@ -470,12 +509,12 @@ class TVMScriptParser(Transformer):
var = tvm.te.var(
node.lhs.id.name,
self.parse_type(node.ty, node.lhs),
- span=from_synr_span(node.lhs.span),
+ span=tvm_span_from_synr(node.lhs.span),
)
- self.context.update_symbol(var.name, var)
+ self.context.update_symbol(var.name, var, node)
body = self.parse_body(node)
self.context.remove_symbol(var.name)
- return tvm.tir.LetStmt(var, value, body,
span=from_synr_span(node.span))
+ return tvm.tir.LetStmt(var, value, body,
span=tvm_span_from_synr(node.span))
self.report_error("Unsupported Assign stmt", node.span)
@@ -484,28 +523,28 @@ class TVMScriptParser(Transformer):
symbol = self.transform(node.params[0])
indexes = self.transform(node.params[1])
rhs = self.transform(node.params[2])
- rhs_span = from_synr_span(node.params[2].span)
+ rhs_span = tvm_span_from_synr(node.params[2].span)
if isinstance(symbol, tvm.tir.Buffer):
# BufferStore
return tvm.tir.BufferStore(
symbol,
tvm.runtime.convert(rhs, span=rhs_span),
indexes,
- span=from_synr_span(node.span),
+ span=tvm_span_from_synr(node.span),
)
else:
if len(indexes) != 1:
self.report_error(
f"Store is only allowed with one index, but {len(indexes)}
were provided.",
- Span.union([x.span for x in indexes]),
+ tvm.ir.Span.union([x.span for x in indexes]),
)
# Store
return tvm.tir.Store(
symbol,
tvm.runtime.convert(rhs, span=rhs_span),
indexes[0],
- tvm.runtime.convert(True, span=from_synr_span(node.span)),
- span=from_synr_span(node.span),
+ tvm.runtime.convert(True, span=tvm_span_from_synr(node.span)),
+ span=tvm_span_from_synr(node.span),
)
def transform_Assert(self, node):
@@ -520,7 +559,7 @@ class TVMScriptParser(Transformer):
message = self.transform(node.msg)
body = self.parse_body(node)
return tvm.tir.AssertStmt(
- condition, tvm.runtime.convert(message), body,
span=from_synr_span(node.span)
+ condition, tvm.runtime.convert(message), body,
span=tvm_span_from_synr(node.span)
)
def transform_For(self, node):
@@ -529,7 +568,8 @@ class TVMScriptParser(Transformer):
For(expr target, expr iter, stmt* body, stmt* orelse, string?
type_comment)
By now 1 pattern of For is supported:
1. for scope handler
- for name in
tir.serial()/tir.parallel()/tir.vectorized()/tir.unroll()
+ for name in
tir.serial()/tir.parallel()/tir.vectorized()/tir.unroll()/tir.range()/
+ tir.grid()/tir.thread_binding()
"""
if not isinstance(node.rhs, ast.Call):
@@ -543,14 +583,14 @@ class TVMScriptParser(Transformer):
old_lineno, old_col_offset = self.current_lineno,
self.current_col_offset
self.current_lineno = node.span.start_line
self.current_col_offset = node.span.start_column
- self.context.new_scope(nodes=node.body.stmts)
+ self.context.enter_scope(nodes=node.body.stmts)
# for scope handler process the scope
arg_list = self.parse_arg_list(func, node.rhs)
func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
func.body = self.parse_body(node)
res = func.exit_scope(node, self.context, arg_list,
node.rhs.func_name.span)
# exit the scope
- self.context.pop_scope()
+ self.context.exit_scope()
self.current_lineno, self.current_col_offset = old_lineno,
old_col_offset
return res
@@ -561,9 +601,9 @@ class TVMScriptParser(Transformer):
withitem = (expr context_expr, expr? optional_vars)
By now 2 patterns of With is supported:
1. with scope handler with symbol def
- with tir.allocate() as targets:
+ with tir.block(*axes)/tir.allocate() as targets:
2. with scope handler without symbol def
- with tir.let()/tir.Assert()/tir.attr()//tir.realize()
+ with tir.let()/tir.Assert()/tir.attr()/tir.realize()
"""
if not isinstance(node.rhs, ast.Call):
@@ -582,14 +622,14 @@ class TVMScriptParser(Transformer):
old_lineno, old_col_offset = self.current_lineno,
self.current_col_offset
self.current_lineno = node.body.span.start_line
self.current_col_offset = node.body.span.start_column
- self.context.new_scope(nodes=node.body.stmts)
+ self.context.enter_block_scope(nodes=node.body.stmts)
# with scope handler process the scope
arg_list = self.parse_arg_list(func, node.rhs)
func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span)
func.body = self.parse_body(node)
res = func.exit_scope(node, self.context, arg_list,
node.rhs.func_name.span)
# exit the scope
- self.context.pop_scope()
+ self.context.exit_block_scope()
self.current_lineno, self.current_col_offset = old_lineno,
old_col_offset
return res
@@ -601,19 +641,21 @@ class TVMScriptParser(Transformer):
condition = self.transform(node.condition)
# then body
- self.context.new_scope(nodes=node.true.stmts)
+ self.context.enter_scope(nodes=node.true.stmts)
then_body = self.parse_body(node)
- self.context.pop_scope()
+ self.context.exit_scope()
# else body
if len(node.false.stmts) > 0:
- self.context.new_scope(nodes=node.false.stmts)
+ self.context.enter_scope(nodes=node.false.stmts)
else_body = self.parse_body(node)
- self.context.pop_scope()
+ self.context.exit_scope()
else:
else_body = None
- return tvm.tir.IfThenElse(condition, then_body, else_body,
span=from_synr_span(node.span))
+ return tvm.tir.IfThenElse(
+ condition, then_body, else_body, span=tvm_span_from_synr(node.span)
+ )
def transform_Call(self, node):
"""Call visitor
@@ -633,18 +675,26 @@ class TVMScriptParser(Transformer):
lhs = self.transform(node.params[0])
rhs = self.transform(node.params[1])
return self._binop_maker[node.func_name.name](
- lhs, rhs, span=from_synr_span(node.span)
+ lhs, rhs, span=tvm_span_from_synr(node.span)
)
if node.func_name.name in self._unaryop_maker:
rhs = self.transform(node.params[0])
- return self._unaryop_maker[node.func_name.name](rhs,
span=from_synr_span(node.span))
+ return self._unaryop_maker[node.func_name.name](
+ rhs, span=tvm_span_from_synr(node.span)
+ )
self.report_error(f"Unsupported operator {node.func_name.name}.",
node.func_name.span)
else:
func = self.transform(node.func_name)
if isinstance(func, Intrin) and not func.stmt:
# pattern 1
arg_list = self.parse_arg_list(func, node)
- return func.handle(arg_list, node.func_name.span)
+ return call_with_error_reporting(
+ self.report_error,
+ node.func_name.span,
+ func.handle,
+ arg_list,
+ node.func_name.span,
+ )
else:
args = [self.transform(arg) for arg in node.params]
kw_args = {
@@ -653,7 +703,7 @@ class TVMScriptParser(Transformer):
if isinstance(func, tvm.tir.op.Op):
# pattern 2
return tvm.tir.Call(
- kw_args["dtype"], func, args,
span=from_synr_span(node.span)
+ kw_args["dtype"], func, args,
span=tvm_span_from_synr(node.span)
)
elif callable(func):
# pattern 3
@@ -700,7 +750,13 @@ class TVMScriptParser(Transformer):
)
if isinstance(func, Intrin) and func.stmt:
- return func.handle(arg_list, node.call.func_name.span)
+ return call_with_error_reporting(
+ self.report_error,
+ node.call.func_name.span,
+ func.handle,
+ arg_list,
+ node.call.func_name.span,
+ )
elif isinstance(func, WithScopeHandler) and func.concise_scope and not
func.def_symbol:
func.enter_scope(node, self.context, arg_list,
node.call.func_name.span)
func.body = self.parse_body(node)
@@ -716,11 +772,7 @@ class TVMScriptParser(Transformer):
end = self.transform(node.end)
if not (isinstance(node.step, ast.Constant) and node.step.value == 1):
self.report_error("Only step size 1 is supported for slices.",
node.step.span)
- extent = end - start
- if isinstance(extent, tvm.tir.PrimExpr):
- ana = tvm.arith.Analyzer()
- extent = ana.simplify(extent)
- return tvm.ir.Range.from_min_extent(start, extent,
span=from_synr_span(node.span))
+ return Slice(start, end)
def transform_Subscript(self, node):
"""Array access visitor.
@@ -728,7 +780,7 @@ class TVMScriptParser(Transformer):
By now only 2 types of Subscript are supported:
1. Buffer[index, index, ...], Buffer element access(BufferLoad &
BufferStore)
Var[index] Buffer element access()
- 2. meta[type_key][index], Meta info access
+ 2. Buffer[start: stop, start: stop, ...],
BufferRealize(realize(buffer[...]))
"""
symbol = self.transform(node.params[0])
@@ -736,19 +788,27 @@ class TVMScriptParser(Transformer):
self.report_error(f"Variable {node.value.id} is not defined.",
node.params[0].span)
indexes = [self.transform(x) for x in node.params[1].values]
- if isinstance(indexes[0], tvm.ir.Range):
- return symbol, indexes
-
if isinstance(symbol, tvm.tir.expr.Var):
- return tvm.tir.Load("float32", symbol, indexes, True,
span=from_synr_span(node.span))
- if isinstance(symbol, tvm.tir.Buffer):
- return tvm.tir.BufferLoad(symbol, indexes,
span=from_synr_span(node.span))
-
- self.report_error(
- f"Cannot subscript from a {type(symbol).__name__}. Only variables
and "
- "buffers are supported.",
- node.params[0].span,
- )
+ for index in indexes:
+ if not isinstance(index, (tvm.tir.PrimExpr, int)):
+ self.report_error(
+ "Buffer load indexes should be int or PrimExpr, but
they are "
+ + type(index),
+ node.span,
+ )
+ return tvm.tir.Load(
+ "float32", symbol, indexes, True,
span=tvm_span_from_synr(node.span)
+ )
+ elif isinstance(symbol, tvm.tir.Buffer):
+ return BufferSlice(
+ symbol, indexes, self.report_error,
span=tvm_span_from_synr(node.span)
+ )
+ else:
+ self.report_error(
+ f"Cannot subscript from a {type(symbol).__name__}. Only
variables and "
+ "buffers are supported.",
+ node.params[0].span,
+ )
def transform_Attr(self, node):
"""Visitor for field access of the form `x.y`.
@@ -756,7 +816,7 @@ class TVMScriptParser(Transformer):
This visitor is used to lookup function and symbol names. We have two
cases to handle here:
1. If we have a statement of the form `tir.something`, then we lookup
- `tir.somthing` in the `Registry`. If the function is not in the
+ `tir.something` in the `Registry`. If the function is not in the
registry, then we try to find a `tvm.ir.op.Op` with the same name.
2. All other names `tvm.something` are lookup up in this current python
namespace.
@@ -875,7 +935,7 @@ class TVMScriptParser(Transformer):
Constant values include `None`, `"strings"`, `2` (integers), `4.2`
(floats), and `true` (booleans).
"""
- return tvm.runtime.convert(node.value, span=from_synr_span(node.span))
+ return tvm.runtime.convert(node.value,
span=tvm_span_from_synr(node.span))
def transform_TypeConstant(self, node):
"""Constant value visitor for types.
@@ -902,8 +962,7 @@ def from_source(src):
----------
src : [str, function, class]
Pruned source of original script
- func_lineno : Optional[int]
- The line number of the first line of the script to be parsed
+
Returns
-------
functions : PrimFunc or IRModule
diff --git a/python/tvm/script/registry.py b/python/tvm/script/registry.py
index 3895701..245cc01 100644
--- a/python/tvm/script/registry.py
+++ b/python/tvm/script/registry.py
@@ -16,7 +16,8 @@
# under the License.
"""TVM Script Parser Function Registry """
# pylint: disable=inconsistent-return-statements, relative-beyond-top-level,
import-outside-toplevel
-import inspect
+import types
+from typing import Union, Callable, Dict, Optional, Any
class Registry(object):
@@ -24,10 +25,10 @@ class Registry(object):
All these maps are static
"""
- registrations = dict()
+ registrations: Dict[str, type] = dict()
@staticmethod
- def lookup(name):
+ def lookup(name: str) -> Optional[Any]:
if name in Registry.registrations:
# every time we create a new handler
# since we may want to keep some local info inside it
@@ -35,12 +36,14 @@ class Registry(object):
return None
-def register(inputs):
+def register(inputs: Union[Callable, type]) -> type:
"""Register Intrin/ScopeHandler/SpecialStmt"""
- if inspect.isfunction(inputs):
+ registration: type
+ if isinstance(inputs, types.FunctionType):
+ # is function
from .intrin import Intrin
- def create_new_intrin(func):
+ def create_new_intrin(func) -> type:
class NewIntrin(Intrin):
def __init__(self):
super().__init__(func)
@@ -48,11 +51,12 @@ def register(inputs):
return NewIntrin
registration = create_new_intrin(inputs)
- elif inspect.isclass(inputs):
+ elif isinstance(inputs, type):
+ # is class
registration = inputs
else:
raise ValueError()
- key = registration().signature()[0]
+ key: str = registration().signature()[0]
Registry.registrations[key] = registration
return registration
diff --git a/python/tvm/script/scope_handler.py
b/python/tvm/script/scope_handler.py
index 9449cbd..c7d841a 100644
--- a/python/tvm/script/scope_handler.py
+++ b/python/tvm/script/scope_handler.py
@@ -16,32 +16,59 @@
# under the License.
"""TVM Script Parser Scope Handler Classes"""
# pylint: disable=redefined-builtin, unused-argument, invalid-name,
relative-beyond-top-level
+from typing import Tuple, Any, Callable, Optional, List, Union, Mapping
+import synr
from synr import ast
import tvm.tir
-from .utils import get_param_list, from_synr_span
+from tvm.runtime import Object
+from tvm.ir import Span, Range
+from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind
+
+from .context_maintainer import ContextMaintainer
+from .utils import (
+ get_param_list,
+ tvm_span_from_synr,
+ buffer_slice_to_region,
+ call_with_error_reporting,
+)
from .registry import register
+from .node import BufferSlice
class ScopeHandler:
"""Base class for all scope handlers"""
- def __init__(self, func):
- self.func = func
- self.body = None
- self.node = None
- self.context = None
+ def __init__(self, func: Callable):
+ self.func: Callable = func
+ self.body: Optional[Stmt] = None
+ self.node: Optional[synr.ast.Node] = None
+ self.context: Optional[ContextMaintainer] = None
- def signature(self):
+ def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
return "tir." + self.func.__name__, get_param_list(self.func)
- def enter_scope(self, node, context, arg_list, span):
+ def enter_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
pass
- def exit_scope(self, node, context, arg_list, span):
+ def exit_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
self.node = node
self.context = context
- return self.func(*arg_list, span=from_synr_span(span))
+ return call_with_error_reporting(
+ context.report_error, span, self.func, *arg_list,
span=tvm_span_from_synr(span)
+ )
class WithScopeHandler(ScopeHandler):
@@ -55,24 +82,29 @@ class WithScopeHandler(ScopeHandler):
@staticmethod
def get_optional_var_names(node, context):
"""Get list of names from ast.With's optional_vars"""
- assert isinstance(node, ast.With)
-
- var_names = None
- if isinstance(node.items[0].optional_vars, ast.Name):
- var_names = [node.items[0].optional_vars.id]
- elif isinstance(node.items[0].optional_vars, (ast.List, ast.Tuple)):
- for var in node.items[0].optional_vars.elts:
- if not isinstance(var, ast.Name):
- context.report_error("Invalid optional var definition")
- var_names = [var.id for var in node.items[0].optional_vars.elts]
+ assert isinstance(
+ node, ast.With
+ ), f"WithScopeHandler expected ast.With but got {type(node)}"
+
+ if isinstance(node.lhs, list):
+ for var in node.lhs:
+ if not isinstance(var, ast.Var):
+ context.report_error(
+ f"Invalid optional var definition, expected Var but
got {type(var)}",
+ node.span,
+ )
+ var_names = [var.id.name for var in node.lhs]
else:
- context.report_error("Invalid optional var definition")
+ context.report_error(
+ f"Invalid optional var definition, expected list of Var but
got {type(node.lhs)}",
+ node.span,
+ )
return var_names
@register
class Allocate(WithScopeHandler):
- """ With scope handler tir.alloc_with_scope(var, extents, dtype, scope,
condition) """
+ """ With scope handler tir.allocate(extents, dtype, scope, condition) """
def __init__(self):
def allocate(extents, dtype, scope, condition=True, span=None):
@@ -86,7 +118,13 @@ class Allocate(WithScopeHandler):
super().__init__(allocate, concise_scope=True, def_symbol=True)
self.buffer_var = None
- def enter_scope(self, node, context, arg_list, span):
+ def enter_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
# define buffer vars in symbol table
if isinstance(node, ast.With):
names = WithScopeHandler.get_optional_var_names(node, context)
@@ -98,13 +136,13 @@ class Allocate(WithScopeHandler):
else:
raise Exception("Internal Bug")
- def setup_buffer_var(extents, dtype, scope, condition=True, span=None):
+ def setup_buffer_var(extents, dtype, scope, condition=True, span: Span
= None):
"""Setup buffer var for a given type."""
buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype))
self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)
- setup_buffer_var(*arg_list, span=from_synr_span(node.lhs.id.span))
- context.update_symbol(name, self.buffer_var)
+ setup_buffer_var(*arg_list, span=tvm_span_from_synr(node.lhs.id.span))
+ context.update_symbol(name, self.buffer_var, node)
@register
@@ -115,10 +153,10 @@ class LaunchThread(WithScopeHandler):
def launch_thread(env_var, extent, span):
extent = tvm.runtime.convert(extent, span=span)
return tvm.tir.AttrStmt(
- tvm.tir.IterVar(
+ IterVar(
None,
env_var,
- getattr(tvm.tir.IterVar, "ThreadIndex"),
+ getattr(IterVar, "ThreadIndex"),
self.context.func_var_env_dict[env_var],
span=span,
),
@@ -136,8 +174,19 @@ class Realize(WithScopeHandler):
""" With scope handler tir.realize(buffer_bounds, scope, condition) """
def __init__(self):
- def realize(buffer_bounds, scope, condition=True, span=None):
- buffer, bounds = buffer_bounds
+ def realize(
+ buffer_slice: BufferSlice, scope: str, condition: bool = True,
span: bool = None
+ ):
+ assert self.context, "call 'exit_scope' before 'enter_scope'"
+ buffer: Buffer = buffer_slice.buffer
+ bounds: List[Range] = []
+ for s in buffer_slice.slices:
+ min: Union[PrimExpr, int] = s.start
+ extent: Union[PrimExpr, int] = 1 if s.stop is None else s.stop
- s.start
+ if isinstance(extent, PrimExpr):
+ extent = self.context.analyzer.simplify(extent)
+ bounds.append(Range.from_min_extent(min, extent, span=s.span))
+
scope = tvm.runtime.convert(scope, span=span)
return tvm.tir.AttrStmt(
buffer,
@@ -185,92 +234,380 @@ class Let(WithScopeHandler):
super().__init__(let, concise_scope=False, def_symbol=False)
+@register
+class Block(WithScopeHandler):
+ """ With scope handler tir.block(extents, name) as iter_vars"""
+
+ def __init__(self):
+ def block(axes=None, name_hint: str = "", span: Optional[Span] = None):
+ assert (
+ self.node and self.context and self.body
+ ), "call 'exit_scope' before 'enter_scope'"
+ block_info = self.context.block_info_stack[-1]
+ if axes is None:
+ axes = []
+ if len(axes) != len(self.block_vars):
+ self.context.report_error(
+ "Inconsistent number of block vars, "
+ + f"there are {len(axes)} axes but {len(self.block_vars)}
block vars. "
+ + "The number of block vars should match the number of
axes.",
+ self.node.span,
+ )
+ block_iters: List[IterVar] = []
+ for i, axis in enumerate(axes):
+ axis = tvm.runtime.convert(axis)
+ if isinstance(axis, tvm.tir.PrimExpr):
+ block_var_dom = Range.from_min_extent(0, axis)
+ block_iters.append(IterVar(block_var_dom,
self.block_vars[i], 0))
+ elif isinstance(axis, Range):
+ block_iters.append(IterVar(axis, self.block_vars[i], 0))
+ elif isinstance(axis, IterVar):
+ block_iters.append(IterVar(axis.dom, self.block_vars[i],
axis.iter_type))
+ else:
+ self.context.report_error(
+ "Invalid argument of tir.block(), "
+ + f"expected PrimExpr, Range or IterVar, but got
{type(axis)}",
+ self.node.span,
+ )
+
+ # create block read/write regions
+
+ reads: List[BufferRegion] = (
+ [buffer_slice_to_region(read) for read in block_info.reads]
+ if block_info.reads
+ else []
+ )
+ writes: List[BufferRegion] = (
+ [buffer_slice_to_region(write) for write in block_info.writes]
+ if block_info.writes
+ else []
+ )
+ inner = tvm.tir.Block(
+ block_iters,
+ reads,
+ writes,
+ name_hint,
+ self.body,
+ block_info.init,
+ block_info.alloc_buffers,
+ block_info.match_buffers,
+ block_info.annotations,
+ span,
+ )
+ # create block var iter binding
+ values: List[PrimExpr]
+ if not block_info.iter_bindings:
+ values = self.context.loop_stack[-2].copy()
+ if len(values) == 0:
+ values = [tvm.tir.const(float("nan"), dtype="float32")] *
len(block_iters)
+ elif len(values) != len(block_iters):
+ self.context.report_error(
+ "Number of block iter var and outer loop nesting
mismatch, "
+ + f"{len(block_iters)} block iter vars but
{len(values)} loops",
+ self.node.span,
+ )
+ else:
+ for block_var in self.block_vars:
+ if block_var not in block_info.iter_bindings:
+ self.context.report_error(
+ "Missing block iter var binding for " +
block_var.name,
+ self.node.span,
+ )
+ values = [block_info.iter_bindings[block_var] for block_var in
self.block_vars]
+ predicate = (
+ tvm.tir.const(True, "bool")
+ if block_info.predicate is None
+ else block_info.predicate
+ )
+ body = tvm.tir.BlockRealize(values, predicate, inner, span)
+ return body
+
+ super().__init__(func=block, concise_scope=False, def_symbol=True)
+ self.block_vars = None
+
+ def enter_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
+ # define block vars
+ assert isinstance(
+ node, ast.With
+ ), f"BlockScopeHandler expected to work on ast.With but got
{type(node)}"
+
+ var_names = WithScopeHandler.get_optional_var_names(node, context)
+ self.block_vars = [tvm.te.var(name) for name in var_names]
+ for block_var in self.block_vars:
+ context.update_symbol(block_var.name, block_var, node)
+
+
+@register
+class InitBlock(WithScopeHandler):
+ """ With scope handler tir.init()"""
+
+ def __init__(self):
+ def init(span: Span = None):
+ assert self.context, "call 'exit_scope' before 'enter_scope'"
+ if self.context.block_info_stack[-2].init is not None:
+ self.context.report_error("Duplicate init block declaration",
span)
+ self.context.block_info_stack[-2].init = self.body
+
+ super().__init__(func=init, concise_scope=False, def_symbol=True)
+
+
class ForScopeHandler(ScopeHandler):
"""Base class for all for scope handlers"""
def __init__(self, func):
super().__init__(func)
- self.loop_vars = None
+ self.loop_vars: Optional[List[Var]] = None
- def enter_scope(self, node, context, arg_list, span):
- assert isinstance(node, ast.For)
+ def enter_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
+ assert isinstance(node, ast.For), f"ForScopeHandler expected ast.For
but got {type(node)}"
loop_var_names = list()
spans = list()
if isinstance(node.lhs, ast.Var):
loop_var_names.append(node.lhs.id.name)
- spans.append(from_synr_span(node.lhs.id.span))
- elif isinstance(node.lhs, ast.Tuple):
- for elt in node.lhs.values:
+ spans.append(tvm_span_from_synr(node.lhs.id.span))
+ elif isinstance(node.lhs, list):
+ for elt in node.lhs:
if not isinstance(elt, ast.Var):
- context.report_error("Invalid loop var", elt.span)
+ context.report_error(
+ f"Invalid loop var. Expected a var, but got
{type(elt)}", elt.span
+ )
loop_var_names.append(elt.id.name)
- spans.append(from_synr_span(elt.id.span))
+ spans.append(tvm_span_from_synr(elt.id.span))
else:
- context.report_error("Invalid loop var", node.lhs.span)
+ context.report_error(
+ f"Invalid loop var. Expected var or list of vars as lhs, but
got {type(node.lhs)}",
+ span,
+ )
self.loop_vars = [
tvm.te.var(name, dtype="int32", span=span) for name, span in
zip(loop_var_names, spans)
]
for loop_var in self.loop_vars:
- context.update_symbol(loop_var.name, loop_var)
+ context.update_symbol(loop_var.name, loop_var, node)
+ context.loop_stack[-1].append(loop_var)
+
+ def exit_scope(
+ self,
+ node: synr.ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
+ assert self.loop_vars, "call 'exit_scope' before 'enter_scope'"
+ for _ in self.loop_vars:
+ context.loop_stack[-1].pop()
+ return super().exit_scope(node, context, arg_list, span)
+
+ def create_loop(
+ self,
+ begin: PrimExpr,
+ end: PrimExpr,
+ kind: ForKind,
+ thread_binding: Optional[str] = None,
+ annotations: Optional[Mapping[str, Object]] = None,
+ span: Optional[Span] = None,
+ ) -> tvm.tir.For:
+ """
+ Helper function for creating For in TVM Script parser.
+
+ Parameters
+ ----------
+ begin : PrimExpr
+ The beginning value.
+
+ end : PrimExpr
+ The ending value.
+
+ kind : ForKind
+ The type of the for.
+
+ thread_binding: Optional[str]
+ The thread this loop binds to.
+
+ annotations : Optional[Mapping[str, Object]]
+ Additional annotation hints.
+
+ span : Optional[Span]
+ The location of this for in the source code.
+
+ Returns
+ -------
+ for : For
+ The constructed For.
+ """
+ assert (
+ self.loop_vars and self.context and self.node
+ ), "call 'exit_scope' before 'enter_scope'"
+ if len(self.loop_vars) != 1:
+ self.context.report_error(
+ f"Expected exactly one loop var, but got {self.loop_vars}",
self.node.span
+ )
+ extent = end if begin == 0 else self.context.analyzer.simplify(end -
begin)
+ annos: Mapping[str, Object] = {}
+ if annotations is not None:
+ annos = {
+ key: tvm.tir.StringImm(val) if isinstance(val, str) else val
+ for key, val in annotations.items()
+ }
+ return tvm.tir.For(
+ self.loop_vars[0],
+ begin,
+ extent,
+ kind,
+ self.body,
+ thread_binding=thread_binding,
+ annotations=annos,
+ span=span,
+ )
@register
class Serial(ForScopeHandler):
- """ For scope handler tir.serial(begin, end)"""
+ """ For scope handler tir.serial(begin, end, annotations)"""
def __init__(self):
- def serial(begin, end, span):
- if len(self.loop_vars) != 1:
- self.context.report_error("Expect exact 1 loop var", span)
- ana = tvm.arith.Analyzer()
- extent = end if begin == 0 else ana.simplify(end - begin)
- return tvm.tir.For(self.loop_vars[0], begin, extent, 0, self.body,
span=span)
+ def serial(
+ begin: PrimExpr,
+ end: PrimExpr,
+ annotations: Optional[Mapping[str, Object]] = None,
+ span: Optional[Span] = None,
+ ):
+ return self.create_loop(begin, end, ForKind.SERIAL,
annotations=annotations, span=span)
super().__init__(serial)
@register
class Parallel(ForScopeHandler):
- """ For scope handler tir.parallel(begin, end)"""
+ """ For scope handler tir.parallel(begin, end, annotations)"""
def __init__(self):
- def parallel(begin, end, span):
- if len(self.loop_vars) != 1:
- self.context.report_error("Expect exact 1 loop var")
- ana = tvm.arith.Analyzer()
- extent = end if begin == 0 else ana.simplify(end - begin)
- return tvm.tir.For(self.loop_vars[0], begin, extent, 1, self.body,
span=span)
+ def parallel(
+ begin: PrimExpr,
+ end: PrimExpr,
+ annotations: Optional[Mapping[str, Object]] = None,
+ span: Optional[Span] = None,
+ ):
+ return self.create_loop(
+ begin, end, ForKind.PARALLEL, annotations=annotations,
span=span
+ )
super().__init__(parallel)
@register
class Vectorized(ForScopeHandler):
- """ For scope handler tir.vectorized(begin, end)"""
+ """ For scope handler tir.vectorized(begin, end, annotations)"""
def __init__(self):
- def vectorized(begin, end, span):
- if len(self.loop_vars) != 1:
- self.context.report_error("Expect exact 1 loop var")
- ana = tvm.arith.Analyzer()
- extent = end if begin == 0 else ana.simplify(end - begin)
- return tvm.tir.For(self.loop_vars[0], begin, extent, 2, self.body,
span=span)
+ def vectorized(
+ begin: PrimExpr,
+ end: PrimExpr,
+ annotations: Optional[Mapping[str, Object]] = None,
+ span: Optional[Span] = None,
+ ):
+ return self.create_loop(
+ begin, end, ForKind.VECTORIZED, annotations=annotations,
span=span
+ )
super().__init__(vectorized)
@register
class Unroll(ForScopeHandler):
- """ For scope handler tir.unroll(begin, end)"""
+ """ For scope handler tir.unroll(begin, end, annotations)"""
def __init__(self):
- def unroll(begin, end, span):
- if len(self.loop_vars) != 1:
- self.context.report_error("Expect exact 1 loop var")
- ana = tvm.arith.Analyzer()
- extent = end if begin == 0 else ana.simplify(end - begin)
- return tvm.tir.For(self.loop_vars[0], begin, extent, 3, self.body,
span=span)
+ def unroll(
+ begin: PrimExpr,
+ end: PrimExpr,
+ annotations: Optional[Mapping[str, Object]] = None,
+ span: Optional[Span] = None,
+ ):
+ return self.create_loop(
+ begin, end, ForKind.UNROLLED, annotations=annotations,
span=span
+ )
super().__init__(unroll)
+
+
+@register
+class ThreadBinding(ForScopeHandler):
+ """ For scope handler tir.thread_binding(begin, end, thread,
annotations)"""
+
+ def __init__(self):
+ def thread_binding(
+ begin: PrimExpr,
+ end: PrimExpr,
+ thread: str,
+ annotations: Optional[Mapping[str, Object]] = None,
+ span: Optional[Span] = None,
+ ):
+ thread_iter_var = IterVar(None, None, IterVar.ThreadIndex, thread,
span=span)
+ return self.create_loop(
+ begin,
+ end,
+ ForKind.THREAD_BINDING,
+ thread_binding=thread_iter_var,
+ annotations=annotations,
+ span=span,
+ )
+
+ super().__init__(thread_binding)
+
+
+@register
+class RangeHandler(ForScopeHandler):
+ """For scope handler range(begin, end, annotations)
+ Note that range is totally the same as tir.serial
+ """
+
+ def __init__(self):
+ def for_range(
+ begin: PrimExpr,
+ end: PrimExpr,
+ annotations: Optional[Mapping[str, Object]] = None,
+ span: Optional[Span] = None,
+ ):
+ return self.create_loop(begin, end, ForKind.SERIAL,
annotations=annotations, span=span)
+
+ super().__init__(for_range)
+
+ def signature(self):
+ return "range", get_param_list(self.func)
+
+
+@register
+class Grid(ForScopeHandler):
+ """ For scope handler tir.grid(extents)"""
+
+ def __init__(self):
+ def grid(*extents: List[PrimExpr], span: Span):
+ assert (
+ self.node and self.context and self.loop_vars
+ ), "call 'exit_scope' before 'enter_scope'"
+ if len(self.loop_vars) != len(extents):
+ self.context.report_error(
+ "Inconsistent number of loop vars and extents, "
+ + f"got {len(self.loop_vars)} vs {len(extents)}",
+ self.node.span,
+ )
+ body = self.body
+ for loop_var, extent in zip(reversed(self.loop_vars),
reversed(extents)):
+ body = tvm.tir.For(loop_var, 0, extent, ForKind.SERIAL, body,
span=span)
+ return body
+
+ super().__init__(grid)
diff --git a/python/tvm/script/special_stmt.py
b/python/tvm/script/special_stmt.py
index 62ce1ea..6aa1239 100644
--- a/python/tvm/script/special_stmt.py
+++ b/python/tvm/script/special_stmt.py
@@ -17,30 +17,81 @@
"""TVM Script Parser Special Stmt Classes"""
# pylint: disable=unused-argument, no-self-argument,
inconsistent-return-statements
# pylint: disable=relative-beyond-top-level
+from typing import Callable, List, Optional, Tuple, Any, Mapping, Union
+
+import synr
from synr import ast
import tvm.tir
+from tvm.runtime import Object
from tvm import te
-from .utils import get_param_list, from_synr_span
+from tvm.ir import Span
+from tvm.tir import IntImm
+from .utils import (
+ get_param_list,
+ tvm_span_from_synr,
+ buffer_slice_to_region,
+ call_with_error_reporting,
+)
from .registry import register
+from .context_maintainer import ContextMaintainer
+from .node import BufferSlice
+
+
+def convert_to_int(
+ value: Union[IntImm, int],
+ arg_name: str,
+ report_error: Callable,
+ span: Union[Span, synr.ast.Span],
+) -> int:
+ """convert a const int or TVM IntImm to Python int.
+ Reports an error when input cannot be converted to int.
+
+ Parameters
+ ----------
+ value : Union[tvm.tir.IntImm, int]
+ The input value to be converted.
+ arg_name : str
+ Function argument name for error reporting.
+ report_error: Callable
+ The report error function handle
+ span : Union[synr.ast.Span, tvm.ir.Span]
+ Location of the error
+ """
+ if isinstance(value, IntImm):
+ return value.value
+ if isinstance(value, int):
+ return value
+ report_error(
+ f"Expected int or IntImm for {arg_name}, but got {str(type(value))}",
+ span,
+ )
class SpecialStmt:
"""Base class for all Special Stmts"""
- def __init__(self, func, def_symbol):
- self.func = func
- self.def_symbol = def_symbol
- self.node = None
- self.context = None
+ def __init__(self, func: Callable, def_symbol: bool):
+ self.func: Callable = func
+ self.def_symbol: bool = def_symbol
+ self.node: Optional[synr.ast.Node] = None
+ self.context: Optional[ContextMaintainer] = None
- def signature(self):
+ def signature(self) -> Tuple[str, Tuple[list, list, Any]]:
return "tir." + self.func.__name__, get_param_list(self.func)
- def handle(self, node, context, arg_list, span):
+ def handle(
+ self,
+ node: ast.Node,
+ context: ContextMaintainer,
+ arg_list: List[Any],
+ span: synr.ast.Span,
+ ):
self.node = node
self.context = context
- return self.func(*arg_list, span=from_synr_span(span))
+ return call_with_error_reporting(
+ context.report_error, span, self.func, *arg_list,
span=tvm_span_from_synr(span)
+ )
@register
@@ -67,17 +118,20 @@ class MatchBuffer(SpecialStmt):
buffer_type="default",
span=None,
):
- assert isinstance(self.node, ast.Assign)
-
+ if not isinstance(self.node, ast.Assign):
+ self.context.report_error(
+ "match_buffer must be assigned to a buffer, e.g. A =
match_buffer(...)",
+ self.node.span,
+ )
if param not in self.context.func_params:
self.context.report_error(
"Can not bind non-input param to buffer",
self.node.rhs.params[0].span
)
if strides is None:
strides = []
- align = align.value if not isinstance(align, int) else align
- offset_factor = (
- offset_factor.value if not isinstance(offset_factor, int) else
offset_factor
+ align = convert_to_int(align, "align", self.context.report_error,
self.node.span)
+ offset_factor = convert_to_int(
+ offset_factor, "offset_factor", self.context.report_error,
self.node.span
)
buffer = tvm.tir.decl_buffer(
shape,
@@ -93,7 +147,7 @@ class MatchBuffer(SpecialStmt):
span=span,
)
self.context.func_buffer_map[param] = buffer
- self.context.update_symbol(self.node.lhs.id.name, buffer)
+ self.context.update_symbol(self.node.lhs.id.name, buffer,
self.node)
super().__init__(match_buffer, def_symbol=True)
@@ -121,13 +175,17 @@ class BufferDeclare(SpecialStmt):
buffer_type="default",
span=None,
):
- assert isinstance(self.node, ast.Assign)
+ if not isinstance(self.node, ast.Assign):
+ self.context.report_error(
+ "buffer_decl must be assigned to a buffer, e.g. A =
buffer_decl(...)",
+ self.node.span,
+ )
if strides is None:
strides = []
- align = align.value if not isinstance(align, int) else align
- offset_factor = (
- offset_factor.value if not isinstance(offset_factor, int) else
offset_factor
+ align = convert_to_int(align, "align", self.context.report_error,
self.node.span)
+ offset_factor = convert_to_int(
+ offset_factor, "offset_factor", self.context.report_error,
self.node.span
)
buffer = tvm.tir.decl_buffer(
shape,
@@ -142,21 +200,293 @@ class BufferDeclare(SpecialStmt):
buffer_type,
span=span,
)
- self.context.update_symbol(self.node.lhs.id.name, buffer)
+ self.context.update_symbol(self.node.lhs.id.name, buffer,
self.node)
return buffer
super().__init__(buffer_decl, def_symbol=True)
@register
+class AllocBuffer(SpecialStmt):
+ """Special function alloc_buffer(shape, dtype, data, strides, elem_offset,
scope, align,
+ offset_factor, buffer_type)
+
+ Example
+ -------
+ .. code-block:: python
+
+ A = tir.alloc_buffer((128, 128), dtype="float32")
+ """
+
+ def __init__(self):
+ def alloc_buffer(
+ shape,
+ dtype="float32",
+ data=None,
+ strides=None,
+ elem_offset=None,
+ scope="",
+ align=-1,
+ offset_factor=0,
+ buffer_type="default",
+ span=None,
+ ):
+ if not isinstance(self.node, ast.Assign):
+ self.context.report_error(
+ "alloc_buffer must be assigned to a buffer, e.g. A =
alloc_buffer(...)",
+ self.node.span,
+ )
+
+ if strides is None:
+ strides = []
+ align = convert_to_int(align, "align", self.context.report_error,
self.node.span)
+ offset_factor = convert_to_int(
+ offset_factor, "offset_factor", self.context.report_error,
self.node.span
+ )
+ buffer = tvm.tir.decl_buffer(
+ shape,
+ dtype,
+ self.node.lhs.id.name,
+ data,
+ strides,
+ elem_offset,
+ scope,
+ align,
+ offset_factor,
+ buffer_type,
+ span=span,
+ )
+ self.context.current_block_scope().alloc_buffers.append(buffer)
+ self.context.update_symbol(self.node.lhs.id.name, buffer,
self.node)
+
+ super().__init__(alloc_buffer, def_symbol=True)
+
+
+@register
+class BlockVarBind(SpecialStmt):
+ """Special function bind(block_iter, binding_value)
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.bind(vx, i)
+ """
+
+ def __init__(self):
+ def bind(iter_var, values, span=None):
+ block_scope = self.context.current_block_scope()
+ if iter_var in block_scope.iter_bindings:
+ self.context.report_error("Duplicate iter_var bindings of " +
str(iter_var), span)
+ block_scope.iter_bindings[iter_var] = values
+
+ super().__init__(bind, def_symbol=False)
+
+
+@register
+class BlockReads(SpecialStmt):
+ """Special function reads([read_buffer_regions])
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]])
+ """
+
+ def __init__(self):
+ def reads(read_regions: Union[BufferSlice, List[BufferSlice]], span:
Span = None):
+ assert self.context, "call 'exit_scope' before 'enter_scope'"
+ block_scope = self.context.current_block_scope()
+ if block_scope.reads is not None:
+ self.context.report_error(
+ "Duplicate write region declaration, "
+ + "previous one is "
+ + str(", ".join(str(x) for x in block_scope.reads)),
+ span,
+ )
+ if isinstance(read_regions, BufferSlice):
+ read_regions = [read_regions]
+ if not isinstance(read_regions, list):
+ self.context.report_error(
+ "Incorrect input type. "
+ + f"Expected BufferSlice or List[BufferSlice], but got
{type(read_regions)}",
+ span,
+ )
+ block_scope.reads = read_regions
+
+ super().__init__(reads, def_symbol=False)
+
+
+@register
+class BlockWrites(SpecialStmt):
+ """Special function writes([write_buffer_regions])
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.writes([C[vi: vi + 4, vj]])
+ """
+
+ def __init__(self):
+ def writes(write_region: Union[BufferSlice, List[BufferSlice]], span:
Span = None):
+ assert self.context, "call 'exit_scope' before 'enter_scope'"
+ block_scope = self.context.current_block_scope()
+ if block_scope.writes is not None:
+ self.context.report_error(
+ "Duplicate write region declaration, "
+ + "previous one is "
+ + str(", ".join(str(x) for x in block_scope.writes)),
+ span,
+ )
+ if isinstance(write_region, list):
+ pass
+ elif isinstance(write_region, BufferSlice):
+ write_region = [write_region]
+ else:
+ self.context.report_error(
+ "Incorrect input type. "
+ + f"Expected BufferSlice or List[BufferSlice], but got
{type(write_region)}",
+ span,
+ )
+ block_scope.writes = write_region
+
+ super().__init__(writes, def_symbol=False)
+
+
+@register
+class BlockAttr(SpecialStmt):
+ """Special function block_attr({attr_key: attr_value})
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.block_attr({"double_buffer_scope": 1})
+ """
+
+ def __init__(self):
+ def block_attr(attrs: Mapping[str, Object], span: Span = None):
+ assert self.context, "call 'exit_scope' before 'enter_scope'"
+ block_scope = self.context.current_block_scope()
+ if block_scope.annotations is not None:
+ self.context.report_error(
+ "Duplicate block annotations declaration, "
+ + "previous one is "
+ + str(block_scope.annotations),
+ span,
+ )
+ attrs = {
+ key: tvm.tir.StringImm(val) if isinstance(val, str) else val
+ for key, val in attrs.items()
+ }
+ block_scope.annotations = attrs
+
+ super().__init__(block_attr, def_symbol=False)
+
+
+@register
+class BlockPredicate(SpecialStmt):
+ """Special function where(predicate)
+
+ Example
+ -------
+ .. code-block:: python
+
+ tir.where(i < 4)
+ """
+
+ def __init__(self):
+ def where(predicate, span=None):
+ assert self.context, "call 'exit_scope' before 'enter_scope'"
+ block_scope = self.context.current_block_scope()
+ if block_scope.predicate is not None:
+ self.context.report_error(
+ "Duplicate block predicate declaration, "
+ + "previous one is "
+ + str(block_scope.predicate),
+ span,
+ )
+
+ block_scope.predicate = predicate
+
+ super().__init__(where, def_symbol=False)
+
+
+@register
+class BlockMatchBufferRegion(SpecialStmt):
+ """Special function match_buffer_region(source, strides, elem_offset,
align, offset_factor)
+
+ Example
+ -------
+ .. code-block:: python
+
+ B = tir.match_buffer_region(A[0: 4])
+ """
+
+ def __init__(self):
+ def match_buffer_region(
+ source,
+ strides=None,
+ elem_offset=None,
+ align=-1,
+ offset_factor=0,
+ span=None,
+ ):
+ assert self.context, "call 'exit_scope' before 'enter_scope'"
+ if not isinstance(self.node, ast.Assign):
+ self.context.report_error(
+ "match_buffer_region must be assigned to a buffer, "
+ + "e.g. A = match_buffer_region(...)",
+ self.node.span,
+ )
+
+ if strides is None:
+ strides = []
+ align = convert_to_int(align, "align", self.context.report_error,
self.node.span)
+ offset_factor = convert_to_int(
+ offset_factor, "offset_factor", self.context.report_error,
self.node.span
+ )
+
+ if not isinstance(source, BufferSlice):
+ self.context.report_error(
+ "match_buffer_region needs a buffer region as source",
+ span=span,
+ )
+ buffer_region = buffer_slice_to_region(source)
+ shape = [r.extent for r in buffer_region.region]
+ buffer = tvm.tir.decl_buffer(
+ shape,
+ buffer_region.buffer.dtype,
+ self.node.lhs.id.name,
+ data=None,
+ strides=strides,
+ elem_offset=elem_offset,
+ scope=buffer_region.buffer.scope,
+ data_alignment=align,
+ offset_factor=offset_factor,
+ span=span,
+ )
+ self.context.current_block_scope().match_buffers.append(
+ tvm.tir.MatchBufferRegion(buffer, buffer_region)
+ )
+ self.context.update_symbol(self.node.lhs.id.name, buffer,
self.node)
+
+ super().__init__(match_buffer_region, def_symbol=True)
+
+
+@register
class VarDef(SpecialStmt):
""" Special function for defining a Var"""
def __init__(self):
def var(dtype, span):
- assert isinstance(self.node, ast.Assign)
+ assert isinstance(
+ self.node, ast.Assign
+ ), f"VarDef expected ast.Assign but got {type(self.node)}"
v = te.var(self.node.lhs.id.name, dtype, span=span)
- self.context.update_symbol(v.name, v)
+ self.context.update_symbol(v.name, v, self.node)
super().__init__(var, def_symbol=True)
@@ -167,10 +497,12 @@ class EnvThread(SpecialStmt):
def __init__(self):
def env_thread(env_name, span):
- assert isinstance(self.node, ast.Assign)
+ assert isinstance(
+ self.node, ast.Assign
+ ), f"EnvThread expected ast.Assign but got {type(self.node)}"
v = te.var(self.node.lhs.id.name, span=span)
self.context.func_var_env_dict[v] = env_name
- self.context.update_symbol(v.name, v)
+ self.context.update_symbol(v.name, v, self.node)
super().__init__(env_thread, def_symbol=True)
diff --git a/python/tvm/script/utils.py b/python/tvm/script/utils.py
index a6ba9d0..f8a0f61 100644
--- a/python/tvm/script/utils.py
+++ b/python/tvm/script/utils.py
@@ -16,15 +16,32 @@
# under the License.
"""Helper functions in TVM Script Parser"""
+from typing import Callable, List, Any, Optional, Tuple, Union
+
import inspect
-from ..ir import Span, SourceName
+import synr
+
+from tvm.arith import Analyzer
+from tvm.ir import Range, Span, SourceName
+from tvm.tir import PrimExpr, BufferRegion
+from tvm.error import DiagnosticError
+from .node import BufferSlice
-def get_param_list(func):
+def get_param_list(
+ func: Callable,
+) -> Tuple[List[str], List[Tuple[str, Tuple[Any, ...]]], Optional[str]]:
"""Get the parameter list from definition of function"""
- full_arg_spec = inspect.getfullargspec(func)
+ full_arg_spec: inspect.FullArgSpec = inspect.getfullargspec(func)
- args, defaults = full_arg_spec.args, full_arg_spec.defaults
+ args: List[str]
+ defaults: Optional[Tuple[Any, ...]]
+ kwonlyargs: List[str]
+ args, defaults, kwonlyargs = (
+ full_arg_spec.args,
+ full_arg_spec.defaults,
+ full_arg_spec.kwonlyargs,
+ )
if defaults is None:
defaults = tuple()
@@ -33,14 +50,17 @@ def get_param_list(func):
raise RuntimeError(
"TVM Script register error : variable keyword argument is not
supported now"
)
- if not len(full_arg_spec.kwonlyargs) == 0:
+
+ if len(kwonlyargs) == 1 and kwonlyargs[0] == "span":
+ pass
+ elif not len(kwonlyargs) == 0:
raise RuntimeError("TVM Script register error : keyword only argument
is not supported now")
- pos_only = list()
+ pos_only: List[str] = list()
for arg in args[: len(args) - len(defaults)]:
if arg != "span":
pos_only.append(arg)
- kwargs = list()
+ kwargs: List[Tuple[str, Tuple[Any, ...]]] = list()
for default, arg in zip(defaults, args[len(args) - len(defaults) :]):
if arg != "span":
kwargs.append((arg, default))
@@ -48,7 +68,37 @@ def get_param_list(func):
return pos_only, kwargs, full_arg_spec.varargs
-def from_synr_span(span):
+def buffer_slice_to_region(
+ buffer_slice: BufferSlice, analyzer: Optional[Analyzer] = None
+) -> BufferRegion:
+ """Construct BufferRegion from BufferSlice
+
+ Parameters
+ ----------
+ buffer_slice : BufferSlice
+ The input BufferSlice
+
+ analyzer : Optional[tvm.arith.Analyzer]
+ The analyzer for simplifying. If not provided, the method will
construct a new one
+
+ Returns
+ -------
+ buffer_region : BufferRegion
+ The constructed BufferRegion.
+ """
+ region: List[Range] = []
+ for s in buffer_slice.slices:
+ start: Union[PrimExpr, int] = s.start
+ extent: Union[PrimExpr, int] = 1 if s.stop is None else s.stop -
s.start
+ if not analyzer:
+ analyzer = Analyzer()
+ if isinstance(extent, PrimExpr):
+ extent = analyzer.simplify(extent)
+ region.append(Range.from_min_extent(start, extent, span=s.span))
+ return BufferRegion(buffer_slice.buffer, region)
+
+
+def tvm_span_from_synr(span: synr.ast.Span) -> Span:
"""Convert a synr span to a TVM span"""
return Span(
SourceName(span.filename),
@@ -57,3 +107,32 @@ def from_synr_span(span):
span.start_column,
span.end_column,
)
+
+
+def synr_span_from_tvm(span: Span) -> synr.ast.Span:
+ """Convert a TVM span to a synr span"""
+ return synr.ast.Span(
+ span.source_name.name,
+ span.line,
+ span.column,
+ span.end_line,
+ span.end_column,
+ )
+
+
+def call_with_error_reporting(
+ report_error,
+ node_span,
+ func,
+ *args,
+ **kwargs,
+):
+ """Call function with exception handling and report error using
node_span"""
+ try:
+ return func(*args, **kwargs)
+ except DiagnosticError:
+ raise
+ except Exception as err: # pylint: disable=broad-except
+ # printing last non-empty row of error message.
+ error_msg = list(filter(None, str(err).split("\n")))[-1]
+ report_error(error_msg, node_span)
diff --git a/python/tvm/tir/analysis/analysis.py
b/python/tvm/tir/analysis/analysis.py
index 1a3eb48..829eb8b 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -106,3 +106,26 @@ def verify_gpu_code(func, constraints):
The result of verification.
"""
return _ffi_api.verify_gpu_code(func, constraints)
+
+
+def get_block_access_region(block, buffer_var_map):
+ """Detect which regions of tensors in this block are read or written to.
+ Regions are sorted by order of appearance in the AST.
+
+ Parameters
+ ----------
+ block: tvm.tir.Block
+ The block in which we are detecting read/write regions.
+
+ buffer_var_map : Dict[Var, Buffer]
+ The outside buffers which may access the block. Mapping from buffer
var to the buffer
+
+ Returns
+ -------
+ result : List[List[BufferRegion]]
+ Array of access regions. There are three arrays of BufferRegion:
+ - first: read regions
+ - second: write regions
+ - third: opaque regions
+ """
+ return _ffi_api.get_block_access_region(block, buffer_var_map)
diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc
index 8d5bba5..7880740 100644
--- a/src/printer/tir_text_printer.cc
+++ b/src/printer/tir_text_printer.cc
@@ -476,8 +476,7 @@ inline const char* ForKind2String(ForKind t) {
case ForKind::kUnrolled:
return "unroll";
case ForKind::kThreadBinding:
- LOG(FATAL) << "Loop ThreadBinding is reserved for future used and "
- << "not yet supported in TIR";
+ return "thread_binding";
}
LOG(FATAL) << "Unknown ForKind";
return "Unknown";
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index 86b175e..4380795 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -22,6 +22,7 @@
* \brief Printer class to print Tensor IR to python syntax script
*/
+#include <tvm/arith/analyzer.h>
#include <tvm/ir/module.h>
#include <tvm/node/serialization.h>
#include <tvm/runtime/registry.h>
@@ -66,7 +67,10 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const
Stmt&)>,
std::unordered_map<const BaseFuncNode*, GlobalVar> func2var_;
/*! \brief var collector (var defined by For/Loop/Block) */
std::unordered_set<const VarNode*> var_not_in_headers;
- /*! \brief buffer collector (buffer defined in BufferMap and
BufferAllocation)*/
+ /*!
+ * \brief buffer collector
+ * (buffer defined in BufferMap, BufferAllocation and
MatchBufferRegion)
+ */
std::unordered_set<const BufferNode*> buf_not_in_headers;
/*! \brief Map from Var to thread env name */
std::unordered_map<Var, String, ObjectPtrHash, ObjectPtrEqual> var_env_map_;
@@ -84,6 +88,8 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
int num_child_;
/*! \brief the number of current node */
int current_num_;
+ /*! \brief loop stack without annotations */
+ std::vector<For> loop_stack_;
Doc VisitExpr_(const CastNode* op) override;
Doc VisitExpr_(const VarNode* op) override;
@@ -131,6 +137,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const
Stmt&)>,
Doc VisitStmt_(const ForNode* op) override;
Doc VisitStmt_(const PrefetchNode* op) override;
Doc VisitStmt_(const EvaluateNode* op) override;
+ Doc VisitStmt_(const BlockRealizeNode* op) override;
Doc VisitStmtDefault_(const Object* op) override;
Doc VisitType_(const PrimTypeNode* node) override;
@@ -145,12 +152,24 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const
Stmt&)>,
Doc PrintArray(const ArrayNode* op);
Doc PrintBuffer(const BufferNode* op);
Doc AllocBufferDeclaration(const Buffer& buf);
+ Doc PrintBufferRegion(const BufferRegionNode* op);
+ Doc PrintMatchBufferRegion(const MatchBufferRegionNode* op);
+ Doc PrintAnnotations(const Map<String, ObjectRef>& annotations);
static Doc PrintString(const StringObj* op) { return
Doc::StrLiteral(op->data); }
Doc GetUniqueName(std::string prefix);
Doc AllocVar(const Var& var);
Doc AllocBuf(const Buffer& buffer);
+ /*! Helper functions for loop printing. */
+ /*!
+ * \brief Print a single for loop
+ * \param loop The for loop to be printed
+ */
+ Doc PrintLoop(const For& loop);
+ /*! \brief Print all simple loops in stack into one line using tir.grid(). */
+ Doc PrintLoopStack();
+
/*!
* \brief Print additional info about expr in comment.
* \param expr The expression.
@@ -308,6 +327,36 @@ Doc TVMScriptPrinter::AllocBuf(const Buffer& buffer) {
return val;
}
+Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) {
+ const Buffer& buf = op->buffer;
+ buf_not_in_headers.insert(buf.get());
+
+ Doc doc = Print(op->buffer) << " = tir.match_buffer_region(" <<
Print(op->source);
+ if (!buf->strides.empty()) {
+ doc << ", strides=" << Print(buf->strides);
+ }
+ if (buf->offset_factor != 0 && buf->elem_offset->IsInstance<VarNode>()) {
+ Var elem_offset = Downcast<Var>(buf->elem_offset);
+ if (memo_var_.find(elem_offset) != memo_var_.end()) {
+ doc << ", elem_offset=" << Print(buf->elem_offset);
+ } else {
+ // implicitly define elem_offset
+ memo_var_[elem_offset] = Doc::Text(memo_buf_[buf].str() +
".elem_offset");
+ var_not_in_headers.insert(elem_offset.get());
+ }
+ } else {
+ doc << ", elem_offset=" << Print(buf->elem_offset);
+ }
+ if (buf->data_alignment != -1) {
+ doc << ", align=" << buf->data_alignment;
+ }
+ if (buf->offset_factor != 0) {
+ doc << ", offset_factor=" << buf->offset_factor;
+ }
+ doc << ")";
+ return doc;
+}
+
Doc TVMScriptPrinter::Print(const ObjectRef& node) {
if (!node.defined()) return Doc::Text("None");
if (node->IsInstance<StmtNode>()) {
@@ -330,6 +379,10 @@ Doc TVMScriptPrinter::Print(const ObjectRef& node) {
return PrintIterVar(node.as<IterVarNode>());
} else if (node->IsInstance<RangeNode>()) {
return PrintRange(node.as<RangeNode>());
+ } else if (node->IsInstance<BufferRegionNode>()) {
+ return PrintBufferRegion(node.as<BufferRegionNode>());
+ } else if (node->IsInstance<MatchBufferRegionNode>()) {
+ return PrintMatchBufferRegion(node.as<MatchBufferRegionNode>());
} else {
meta_collector_.Collect(node);
return this->meta_.GetMetaNode(node);
@@ -660,9 +713,7 @@ inline const char* ForKind2String(ForKind t) {
case ForKind::kUnrolled:
return "unroll";
case ForKind::kThreadBinding:
- LOG(FATAL) << "Loop ThreadBinding is reserved for future used and "
- << "not yet supported in TIR";
- return "threadbinding";
+ return "thread_binding";
}
LOG(FATAL) << "Unknown ForKind";
return "Unknown";
@@ -671,9 +722,27 @@ inline const char* ForKind2String(ForKind t) {
Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) {
Doc doc;
var_not_in_headers.insert(op->loop_var.get());
- doc << "for " << Print(op->loop_var) << " in tir." +
std::string(ForKind2String(op->kind)) + "("
- << Print(op->min) << ", " << Print(op->min + op->extent)
- << "):" << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
+ const auto* body = op->body.as<ForNode>();
+ bool simple_loop = op->kind == ForKind::kSerial && op->annotations.empty()
&& is_zero(op->min);
+ if (simple_loop) loop_stack_.push_back(GetRef<For>(op));
+ // It is a loop that can be compressed, let the loops below print it out
+ if (simple_loop && body != nullptr) return Print(GetRef<For>(body));
+ // It is a loop that can not be compressed
+ bool print_above = !loop_stack_.empty();
+ // print loops above if needed
+ if (print_above) {
+ doc << PrintLoopStack();
+ loop_stack_.clear();
+ }
+ if (!simple_loop) {
+ // print current loop if needed
+ Doc current_loop;
+ current_loop << PrintLoop(GetRef<For>(op));
+ current_loop << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
+ doc << (print_above ? Doc::Indent(4, Doc::NewLine() << current_loop) :
current_loop);
+ } else {
+ doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
+ }
return doc;
}
@@ -713,6 +782,88 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode*
op) {
return doc;
}
+Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) {
+ const auto* block_op = op->block.as<BlockNode>();
+ // print block name and block vars
+ Doc doc;
+ doc << "with tir.block([";
+ std::vector<Doc> block_var_docs;
+ for (const auto& iter_var : block_op->iter_vars) {
+ Doc block_var_doc;
+ if (is_zero(iter_var->dom->min) && iter_var->iter_type == kDataPar) {
+ block_var_doc << Print(iter_var->dom->extent);
+ } else {
+ block_var_doc << "tir.";
+ switch (iter_var->iter_type) {
+ case kDataPar:
+ block_var_doc << "range";
+ break;
+ case kCommReduce:
+ block_var_doc << "reduce_axis";
+ break;
+ case kOrdered:
+ block_var_doc << "scan_axis";
+ break;
+ case kOpaque:
+ block_var_doc << "opaque_axis";
+ break;
+ default:
+ LOG(FATAL) << "Unknown block var iter type: " << iter_var->iter_type;
+ break;
+ }
+ block_var_doc << "(" << Print(iter_var->dom->min) << ", "
+ << Print(iter_var->dom->min + iter_var->dom->extent) <<
")";
+ }
+ block_var_docs.push_back(block_var_doc);
+ }
+ doc << PrintSep(block_var_docs, Doc::Text(", ")) << "], ";
+ doc << Doc::StrLiteral(block_op->name_hint) << ")";
+ std::vector<Doc> block_var_names;
+ for (const auto& iter_var : block_op->iter_vars) {
+ var_not_in_headers.insert(iter_var->var.get());
+ block_var_names.push_back(Print(iter_var->var));
+ }
+ if (!block_var_names.empty()) {
+ doc << " as [" << PrintSep(block_var_names, Doc::Text(", ")) << "]";
+ }
+ doc << ":";
+ Doc block_attr_doc;
+ // print predicate, binding, read/write tensor region, annotations
+ if (!is_one(op->predicate)) {
+ block_attr_doc << Doc::NewLine() << "tir.where(" << Print(op->predicate)
<< ")";
+ }
+ for (size_t i = 0; i < block_op->iter_vars.size(); ++i)
+ block_attr_doc << Doc::NewLine() << "tir.bind(" <<
Print(block_op->iter_vars[i]->var) << ", "
+ << Print(op->iter_values[i]) << ")";
+ block_attr_doc << Doc::NewLine() << "tir.reads(" << Print(block_op->reads)
<< ")";
+ block_attr_doc << Doc::NewLine() << "tir.writes(" << Print(block_op->writes)
<< ")";
+ if (!block_op->annotations.empty()) {
+ block_attr_doc << Doc::NewLine() << "tir.block_attr({";
+ block_attr_doc << PrintAnnotations(block_op->annotations);
+ block_attr_doc << "})";
+ }
+ // print body
+ Doc body;
+ body << Doc::NewLine();
+ for (const auto& alloc_buf : block_op->alloc_buffers) {
+ buf_not_in_headers.insert(alloc_buf.get());
+ body << Print(alloc_buf) << " = tir.alloc_buffer(" <<
memo_buf_decl_[alloc_buf] << ")"
+ << Doc::NewLine();
+ }
+ for (const auto& match_buf : block_op->match_buffers) {
+ body << Print(match_buf) << Doc::NewLine();
+ }
+ if (block_op->init.defined()) {
+ Doc init_block;
+ init_block << "with tir.init():";
+ init_block << Doc::Indent(4, Doc::NewLine() <<
PrintBody(block_op->init.value()));
+ body << init_block << Doc::NewLine();
+ }
+ body << PrintBody(block_op->body);
+ doc << Doc::Indent(4, block_attr_doc << body);
+ return doc;
+}
+
Doc TVMScriptPrinter::PrintBody(const Stmt& body) {
int memo_num_child, memo_current_num;
std::swap(memo_num_child, num_child_);
@@ -890,6 +1041,73 @@ Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) {
return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer);
}
+Doc TVMScriptPrinter::PrintBufferRegion(const BufferRegionNode* op) {
+ Doc doc;
+ doc << Print(op->buffer) << "[";
+ for (size_t i = 0; i < op->region.size(); ++i) {
+ if (i != 0) doc << ", ";
+ const auto& range = op->region[i];
+ if (!is_one(range->extent)) {
+ doc << Print(range->min) << ":" << Print(range->min + range->extent);
+ } else {
+ doc << Print(range->min);
+ }
+ }
+ doc << "]";
+ return doc;
+}
+
+Doc TVMScriptPrinter::PrintAnnotations(const Map<String, ObjectRef>&
annotations) {
+ Doc res;
+ std::vector<std::pair<String, ObjectRef>> anno_list;
+ anno_list.reserve(annotations.size());
+ for (const auto& pair : annotations) {
+ anno_list.emplace_back(pair);
+ }
+ sort(anno_list.begin(), anno_list.end());
+ for (size_t i = 0; i < anno_list.size(); ++i) {
+ if (i != 0) {
+ res << ", ";
+ }
+ res << "\"" << anno_list[i].first << "\":" << Print(anno_list[i].second);
+ }
+ return res;
+}
+
+Doc TVMScriptPrinter::PrintLoop(const For& loop) {
+ Doc res;
+ res << "for " << Print(loop->loop_var)
+ << " in tir." + std::string(ForKind2String(loop->kind)) + "(" <<
Print(loop->min) << ", "
+ << Print(loop->min + loop->extent);
+ if (loop->thread_binding.defined()) {
+ res << ", thread = ";
+ res << Print(loop->thread_binding.value()->thread_tag);
+ }
+ if (!loop->annotations.empty()) {
+ res << ", annotation = {";
+ res << PrintAnnotations(loop->annotations);
+ res << "}";
+ }
+ res << "):";
+ return res;
+}
+
+Doc TVMScriptPrinter::PrintLoopStack() {
+ Doc res;
+ if (loop_stack_.size() == 1) {
+ res << PrintLoop(loop_stack_[0]);
+ } else if (loop_stack_.size() > 1) {
+ std::vector<Doc> vars, extents;
+ for (const auto& loop : loop_stack_) {
+ vars.push_back(Print(loop->loop_var));
+ extents.push_back(Print(loop->extent));
+ }
+ res << "for " << PrintSep(vars, Doc::Text(", ")) << " in tir.grid("
+ << PrintSep(extents, Doc::Text(", ")) << "):";
+ }
+ return res;
+}
+
TVM_REGISTER_GLOBAL("script.AsTVMScript")
.set_body_typed<std::string(const ObjectRef&, bool)>([](const ObjectRef&
functions,
bool show_meta) {
diff --git a/src/tir/analysis/block_access_region_detector.cc
b/src/tir/analysis/block_access_region_detector.cc
new file mode 100644
index 0000000..b1da536
--- /dev/null
+++ b/src/tir/analysis/block_access_region_detector.cc
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/analysis/block_region_detector.cc
+ * \brief Detect block read/write regions by visiting its body
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Detect which regions of tensors in this block are read or written
to. Regions are sorted
+ * by order of appearance in the AST. \note This detector can only visit
blocks and will not visit
+ * child blocks recursively
+ */
+class BlockReadWriteDetector : public StmtExprVisitor {
+ public:
+ explicit BlockReadWriteDetector(const Map<Var, Buffer>& buffer_var_map)
+ : buffer_var_map_(buffer_var_map) {}
+
+ /*! \brief Return read regions of the block */
+ Array<BufferRegion> CollectReads();
+ /*! \brief Return write regions of the block */
+ Array<BufferRegion> CollectWrites();
+ /*!
+ * \brief Return opaque buffer regions of the block
+ * \note The buffer accessed by load/store or call with buffer.data will
+ * be marked as opaque.
+ */
+ Array<BufferRegion> CollectOpaques();
+ /*! \brief overload operator() to make sure it accepts a block node */
+ void operator()(const Stmt& stmt);
+
+ private:
+ /*! \brief Iteration range for loop_vars */
+ std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
+ /*! \brief The buffers that the current block reads */
+ std::vector<Buffer> read_buffers_;
+ /*! \brief The buffers that the current block writes */
+ std::vector<Buffer> writes_buffers_;
+ /*! \brief The opaque buffer which is access by buffer.data */
+ std::vector<Buffer> opaque_buffers_;
+ /*! \brief The read regions of the current block */
+ std::vector<std::vector<tvm::arith::IntSet>> read_regions_;
+ /*! \brief The write regions of the current block */
+ std::vector<std::vector<tvm::arith::IntSet>> write_regions_;
+ /*! \brief The outside buffer data mapping to its buffer */
+ Map<Var, Buffer> buffer_var_map_;
+ /*! \brief The analyzer for simplifying*/
+ arith::Analyzer analyzer_;
+
+ /*!
+ * \brief Update read/write buffers and regions with provided buffer and
region
+ * \param buffers The buffers should be updated
+ * \param regions The access regions should be updated
+ * \param buffer The provided buffer
+ * \param region The provided region
+ */
+ void Update(std::vector<Buffer>* buffers,
std::vector<std::vector<arith::IntSet>>* regions,
+ const Buffer& buffer, const std::vector<arith::IntSet>& region);
+
+ /*! \brief Helper function to collect access regions. */
+ Array<BufferRegion> CollectRegions(const std::vector<Buffer>& buffers,
+ const
std::vector<std::vector<tvm::arith::IntSet>>& regions);
+
+ /*! \brief Helper function to add a opaque buffer. */
+ void AddOpaque(const Var& buffer_var);
+
+ void VisitStmt_(const ForNode* op) override;
+ void VisitStmt_(const BlockRealizeNode* op) override;
+ void VisitStmt_(const BufferStoreNode* op) override;
+ void VisitStmt_(const StoreNode* op) override;
+ void VisitExpr_(const BufferLoadNode* op) override;
+ void VisitExpr_(const LoadNode* op) override;
+ void VisitExpr_(const VarNode* op) override;
+};
+
+void BlockReadWriteDetector::operator()(const Stmt& stmt) {
+ ICHECK(stmt.as<BlockNode>() != nullptr)
+ << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey();
+ StmtExprVisitor::operator()(stmt);
+}
+
+Array<BufferRegion> BlockReadWriteDetector::CollectReads() {
+ return CollectRegions(read_buffers_, read_regions_);
+}
+
+Array<BufferRegion> BlockReadWriteDetector::CollectWrites() {
+ return CollectRegions(writes_buffers_, write_regions_);
+}
+
+Array<BufferRegion> BlockReadWriteDetector::CollectOpaques() {
+ Array<BufferRegion> res;
+ res.reserve(opaque_buffers_.size());
+ for (const Buffer& buffer : opaque_buffers_) {
+ res.push_back(BufferRegion::FullRegion(buffer));
+ }
+ return res;
+}
+
+void BlockReadWriteDetector::VisitExpr_(const VarNode* op) {
AddOpaque(GetRef<Var>(op)); }
+
+void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) {
+ AddOpaque(op->buffer_var);
+ ExprVisitor::VisitExpr_(op);
+}
+
+void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) {
+ std::vector<arith::IntSet> relaxed_region;
+ for (const PrimExpr& index : op->indices) {
+ relaxed_region.push_back(arith::EvalSet(index, dom_map_));
+ }
+ Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region);
+ ExprVisitor::VisitExpr_(op);
+}
+
+void BlockReadWriteDetector::VisitStmt_(const ForNode* op) {
+ Range range = Range::FromMinExtent(op->min, op->extent);
+ dom_map_[op->loop_var.get()] = arith::IntSet::FromRange(range);
+ StmtVisitor::VisitStmt_(op);
+ dom_map_.erase(op->loop_var.get());
+}
+
+void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) {
+ AddOpaque(op->buffer_var);
+ StmtVisitor::VisitStmt_(op);
+}
+
+void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) {
+ std::vector<arith::IntSet> relaxed_region;
+ for (const PrimExpr& index : op->indices) {
+ relaxed_region.push_back(arith::EvalSet(index, dom_map_));
+ }
+ Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region);
+ StmtVisitor::VisitStmt_(op);
+}
+
+void BlockReadWriteDetector::VisitStmt_(const BlockRealizeNode* op) {
+ /*! \note detector will not visit child block recursively, so it will stop
here */
+ std::unordered_map<const VarNode*, PrimExpr> vmap;
+ for (size_t i = 0; i < op->block->iter_vars.size(); ++i) {
+ vmap[op->block->iter_vars[i]->var.get()] = op->iter_values[i];
+ }
+ for (const auto& read : op->block->reads) {
+ std::vector<arith::IntSet> relaxed_region;
+ for (const auto& range : read->region) {
+ relaxed_region.push_back(
+ arith::EvalSet(arith::IntSet::FromRange(Range::FromMinExtent(
+ Substitute(range->min, vmap),
Substitute(range->extent, vmap))),
+ dom_map_));
+ }
+ Update(&read_buffers_, &read_regions_, read->buffer, relaxed_region);
+ }
+ for (const auto& write : op->block->writes) {
+ std::vector<arith::IntSet> relaxed_region;
+ for (const auto& range : write->region) {
+ relaxed_region.push_back(
+ arith::EvalSet(arith::IntSet::FromRange(Range::FromMinExtent(
+ Substitute(range->min, vmap),
Substitute(range->extent, vmap))),
+ dom_map_));
+ }
+ Update(&writes_buffers_, &write_regions_, write->buffer, relaxed_region);
+ }
+}
+
+void BlockReadWriteDetector::Update(std::vector<Buffer>* buffers,
+ std::vector<std::vector<arith::IntSet>>*
regions,
+ const Buffer& buffer,
+ const std::vector<arith::IntSet>& region) {
+ if (buffer_var_map_.find(buffer->data) == buffer_var_map_.end()) return;
+ ICHECK_EQ(buffers->size(), regions->size())
+ << " Expected the buffer and regions to have the same size ";
+ for (size_t i = 0; i < regions->size(); ++i) {
+ if ((*buffers)[i].same_as(buffer)) {
+ ICHECK_EQ((*regions)[i].size(), region.size()) << "Inconsistent buffer
dimension";
+ for (size_t j = 0; j < region.size(); ++j) {
+ (*regions)[i][j] = arith::Union({(*regions)[i][j], region[j]});
+ }
+ return;
+ }
+ }
+ buffers->push_back(buffer);
+ regions->push_back(region);
+}
+
+Array<BufferRegion> BlockReadWriteDetector::CollectRegions(
+ const std::vector<Buffer>& buffers,
+ const std::vector<std::vector<tvm::arith::IntSet>>& regions) {
+ ICHECK_EQ(buffers.size(), regions.size());
+ Array<BufferRegion> res;
+ res.reserve(buffers.size());
+ for (size_t i = 0; i < regions.size(); ++i) {
+ Array<Range> region;
+ region.reserve(regions[i].size());
+ for (size_t j = 0; j < regions[i].size(); j++) {
+ tvm::arith::IntSet range = regions[i][j];
+ region.push_back(range.CoverRange(Range::FromMinExtent(0,
buffers[i]->shape[j])));
+ }
+ res.push_back(BufferRegion(buffers[i], region));
+ }
+ return res;
+}
+
+void BlockReadWriteDetector::AddOpaque(const Var& buffer_var) {
+ auto it = buffer_var_map_.find(buffer_var);
+ if (it != buffer_var_map_.end()) {
+ const Buffer& buffer = (*it).second;
+ for (const Buffer& opaque_buffer : opaque_buffers_) {
+ if (buffer.same_as(opaque_buffer)) return;
+ }
+ opaque_buffers_.push_back(buffer);
+ }
+}
+
+Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
+ const Map<Var, Buffer>&
buffer_var_map) {
+ BlockReadWriteDetector detector(buffer_var_map);
+ detector(block);
+ return {detector.CollectReads(), detector.CollectWrites(),
detector.CollectOpaques()};
+}
+
+TVM_REGISTER_GLOBAL("tir.analysis.get_block_access_region").set_body_typed(GetBlockAccessRegion);
+
+} // namespace tir
+} // namespace tvm
diff --git a/src/tir/ir/script/script_complete.cc
b/src/tir/ir/script/script_complete.cc
new file mode 100644
index 0000000..7c9fff7
--- /dev/null
+++ b/src/tir/ir/script/script_complete.cc
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/ir/script/script_complete.cc
+ * \brief Used by TVM Script parser to expand incomplete TIR input
+ */
+
+#include <tvm/arith/int_set.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/stmt.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <utility>
+
+namespace tvm {
+namespace tir {
+
+/*! \brief Generate surrounding loops automatically */
+class ScriptCompleter : public StmtMutator {
+ public:
+ explicit ScriptCompleter(Map<Var, Buffer>* buffer_var_map) :
buffer_var_map_(buffer_var_map) {}
+ /*! \brief Whether the stmt contains at least one block. */
+ bool contains_block = false;
+
+ private:
+ Map<Var, Buffer>* buffer_var_map_;
+ Stmt VisitStmt_(const BlockRealizeNode* op) override {
+ contains_block = true;
+ Stmt body = StmtMutator::VisitStmt_(op);
+ if (!op->iter_values.empty() && !op->iter_values[0].dtype().is_int()) {
+ auto block_with_binding =
CopyOnWrite(Downcast<BlockRealize>(body).get());
+ std::vector<PrimExpr> bindings;
+ for (size_t i = 0; i < op->iter_values.size(); ++i) {
+ bindings.push_back(Var("i" + std::to_string(i)));
+ }
+ block_with_binding->iter_values = bindings;
+ body = BlockRealize(block_with_binding);
+ for (int i = op->iter_values.size() - 1; i >= 0; --i) {
+ body = For(Downcast<Var>(bindings[i]),
op->block->iter_vars[i]->dom->min,
+ op->block->iter_vars[i]->dom->extent, {}, body);
+ }
+ }
+ return body;
+ }
+
+ Stmt VisitStmt_(const BlockNode* op) override {
+ // Buffers allocated in the block can be accessed by its body.
+ for (const auto& alloc_buffer : op->alloc_buffers) {
+ buffer_var_map_->Set(alloc_buffer->data, alloc_buffer);
+ }
+ Block block = Downcast<Block>(StmtMutator::VisitStmt_(op));
+ // Remove buffers allocated inside block to detect its access region
+ for (const auto& alloc_buffer : op->alloc_buffers) {
+ buffer_var_map_->erase(alloc_buffer->data);
+ }
+ if (block->reads.empty() || block->writes.empty()) {
+ auto access_region = GetBlockAccessRegion(block, *buffer_var_map_);
+ const Array<BufferRegion>& reads = access_region[0];
+ const Array<BufferRegion>& writes = access_region[1];
+ const Array<BufferRegion>& opaque = access_region[2];
+ CHECK(opaque.empty())
+ << "ValueError: Can not auto detect buffer access region from
tir.Load, tir.Store or "
+ "direct access by buffer data. Please annotation the access
region manually";
+ auto n = CopyOnWrite(block.operator->());
+ if (!n->reads.defined()) n->reads = reads;
+ if (!n->writes.defined()) n->writes = writes;
+ return Block(n);
+ } else {
+ return std::move(block);
+ }
+ }
+};
+
+PrimFunc ScriptComplete(PrimFunc func, const Array<Buffer>& root_allocates) {
+ Map<Var, Buffer> buffer_var_map;
+ for (const auto& pair : func->buffer_map) {
+ const Buffer& buffer = pair.second;
+ buffer_var_map.Set(buffer->data, buffer);
+ }
+ for (const auto& alloc : root_allocates) {
+ buffer_var_map.Set(alloc->data, alloc);
+ }
+ ScriptCompleter script_completer(&buffer_var_map);
+ // generate surrounding loops automatically
+ Stmt res = script_completer(func->body);
+ // generate root block automatically
+ if (script_completer.contains_block &&
+ (!res->IsInstance<BlockRealizeNode>() || !root_allocates.empty())) {
+ res = Block({}, {}, {}, "root", res, NullOpt, root_allocates);
+ res = BlockRealize({}, Bool(true), Downcast<Block>(res));
+ }
+ if (func->body.same_as(res)) {
+ return func;
+ } else {
+ auto fptr = func.CopyOnWrite();
+ fptr->body = res;
+ return func;
+ }
+}
+
+TVM_REGISTER_GLOBAL("script.Complete").set_body_typed(ScriptComplete);
+
+} // namespace tir
+} // namespace tvm
diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py
b/tests/python/unittest/test_tir_analysis_get_block_access_region.py
new file mode 100644
index 0000000..7e4d7d8
--- /dev/null
+++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py
@@ -0,0 +1,57 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import tir, script
+from tvm.ir import Range
+
+
+@tvm.script.tir
+def func() -> None:
+ A = tir.alloc_buffer((128, 128), "float32")
+ B = tir.alloc_buffer((128, 128), "float32")
+ C = tir.alloc_buffer((128, 128), "float32")
+ D = tir.alloc_buffer((128, 128), "float32")
+ with tir.block([]):
+ # Need add read/write region manually to avoid triggering block access region detector
+ tir.reads([B[0, 0], C[0:16, 0:16], A[4:12, 4:12]])
+ tir.writes([A[0:12, 0:12]])
+ for i, j in tir.grid(8, 8):
+ A[i, j] = B[0, 0] + C[0, 0]
+ with tir.block([2, 2]) as [vi, vj]:
+ tir.reads([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8], C[12:16, 12:16]])
+ tir.writes([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8]])
+ for i, j in tir.grid(4, 4):
+ A[vi * 4 + 4 + i, vj * 4 + 4 + j] += C[i + 12, j + 12]
+ tir.evaluate(D.data)
+
+
+def test_block_access_region_detector():
+ block = func.body.block.body.block
+ alloc_buffers = func.body.block.alloc_buffers
+ buffer_var_map = {buf.data: buf for buf in alloc_buffers}
+ ret = tir.analysis.get_block_access_region(block, buffer_var_map)
+
+ tvm.ir.assert_structural_equal(block.reads, ret[0])
+ tvm.ir.assert_structural_equal(block.writes, ret[1])
+ D = alloc_buffers[-1]
+ tvm.ir.assert_structural_equal(
+ [tvm.tir.BufferRegion(D, [Range(0, 128), Range(0, 128)])], ret[2]
+ )
+
+
+if __name__ == "__main__":
+ test_block_access_region_detector()
diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py
index 048a954..052217b 100644
--- a/tests/python/unittest/test_tvmscript_error_report.py
+++ b/tests/python/unittest/test_tvmscript_error_report.py
@@ -144,6 +144,197 @@ def test_no_body():
check_error(no_body, 3)
+def allocate_with_buffers() -> None:
+ with tir.allocate([1], "float32", "") as [A, B]: # error
+ tir.evaluate(1.0)
+
+
+def test_allocate_with_buffers():
+ check_error(allocate_with_buffers, 2)
+
+
+def inconsistent_binding() -> None:
+ with tir.block([128, 128]) as [vi]: # error
+ tir.evaluate(1.0)
+
+
+def test_inconsistent_binding():
+ check_error(inconsistent_binding, 2)
+
+
+def invalid_block_axes(a: ty.handle) -> None:
+ A = tir.match_buffer(a, (16, 16), "float32")
+ with tir.block([A]) as [vi]: # error
+ tir.evaluate(1.0)
+
+
+def test_invalid_block_axes():
+ check_error(invalid_block_axes, 3)
+
+
+def miss_block_bind() -> None:
+ with tir.block([16, 16]) as [vi, vj]: # error
+ tir.bind(vi, 1)
+ tir.evaluate(1.0)
+
+
+def test_miss_block_bind():
+ check_error(miss_block_bind, 2)
+
+
+def invalid_loop_var() -> None:
+ for i, j in range(0, 16): # error
+ tir.evaluate(1.0)
+
+
+def test_invalid_loop_var():
+ check_error(invalid_loop_var, 2)
+
+
+def inconsistent_grid() -> None:
+ for i in tir.grid(16, 16): # error
+ tir.evaluate(1.0)
+
+
+def test_inconsistent_grid():
+ check_error(inconsistent_grid, 2)
+
+
+def invalid_match_buffer_region() -> None:
+ with tir.block([16, 16]) as [vi, vj]:
+ A = tir.match_buffer_region(vi) # error
+ tir.evaluate(1.0)
+
+
+def test_invalid_match_buffer_region():
+ check_error(invalid_match_buffer_region, 3)
+
+
+def duplicate_buffer() -> None:
+ A = tir.alloc_buffer((128, 128), "float32")
+ with tir.block([16, 16]) as [vi, vj]:
+ A = tir.alloc_buffer((128, 128), "float32") # error
+ tir.evaluate(1.0)
+
+
+def test_duplicate_buffer():
+ check_error(duplicate_buffer, 4)
+
+
+def duplicate_reads() -> None:
+ A = tir.alloc_buffer((128, 128), "float32")
+ with tir.block([16, 16]) as [vi, vj]:
+ tir.reads(A[0:8, 0:8])
+ tir.reads(A[0:16, 0:16]) # error
+ tir.evaluate(1.0)
+
+
+def duplicate_writes() -> None:
+ A = tir.alloc_buffer((128, 128), "float32")
+ with tir.block([16, 16]) as [vi, vj]:
+ tir.writes(A[0:8, 0:8])
+ tir.writes(A[0:16, 0:16]) # error
+ tir.evaluate(1.0)
+
+
+def duplicate_predicate() -> None:
+ with tir.block([16, 16]) as [vi, vj]:
+ tir.where(1)
+ tir.where(0) # error
+
+
+def duplicate_annotations() -> None:
+ with tir.block([16, 16]) as [vi, vj]:
+ tir.block_attr({})
+ tir.block_attr({}) # error
+
+
+def duplicate_init() -> None:
+ with tir.block([16, 16]) as [vi, vj]:
+ with tir.init():
+ tir.evaluate(1.0)
+ with tir.init(): # error
+ tir.evaluate(1.0)
+
+
+def test_duplicate_block_signature():
+ check_error(duplicate_reads, 5)
+ check_error(duplicate_writes, 5)
+ check_error(duplicate_predicate, 4)
+ check_error(duplicate_annotations, 4)
+ check_error(duplicate_init, 5)
+
+
+def opaque_access_during_complete(a: ty.handle) -> None: # error
+ A = tir.match_buffer(a, (16, 16), "float32")
+ with tir.block([16, 16]) as [vi, vj]:
+ tir.evaluate(tir.load("float32", A.data, vi * 16 + vj))
+
+
+def test_opaque_access_during_complete():
+ check_error(opaque_access_during_complete, 1)
+
+
+def convert_slice_to_bufferload() -> None:
+ A = tir.alloc_buffer((128, 128), "float32")
+ with tir.block([16, 16]) as [vi, vj]:
+ A[vi, vj] = A[vi : vi + 2, vj] + 1 # error
+
+
+def test_convert_slice_to_bufferload():
+ check_error(convert_slice_to_bufferload, 4)
+
+
+def error_index_type() -> None:
+ A = tir.alloc_buffer((128, 128), "float32")
+ with tir.block([16, 16]) as [vi, vj]:
+ A[vi, vj] = A[vi, 0.0] + 1 # error
+
+
+def test_error_index_type():
+ check_error(error_index_type, 4)
+
+
+def mismatch_args() -> None:
+ A = tir.alloc_buffer((128, 128), "float32")
+ with tir.block([16, 16]) as [vi, vj]:
+ tir.reads(A[0, 0], A[1, 1]) # error
+ tir.evaluate(1.0)
+
+
+def test_mismatch_args():
+ check_error(mismatch_args, 4)
+
+
+def special_stmt_except() -> None:
+ A = tir.alloc_buffer("(128, 128)", "float32") # error
+ with tir.block([16, 16]) as [vi, vj]:
+ tir.evaluate(1.0)
+
+
+def scope_handler_except() -> None:
+ for i in tir.serial("1", "1"): # error
+ tir.evaluate(1)
+
+
+def intrin_except_unassign(a: ty.handle) -> None:
+ A = tir.match_buffer(a, (16, 16), "float32")
+ tir.evaluate(A) # error
+
+
+def intrin_except_assign(a: ty.handle) -> None:
+ A = tir.match_buffer(a, (16, 16), "float32")
+ A[0, 0] = tir.load(A, A, A) # error
+
+
+def test_tvm_exception_catch():
+ # test catching c++ side exception
+ check_error(special_stmt_except, 2)
+ check_error(scope_handler_except, 2)
+ check_error(intrin_except_unassign, 3)
+ check_error(intrin_except_assign, 3)
+
+
def check_error(module, rel_lineno):
# Override the default renderer to accumulate errors
_, start_line = inspect.getsourcelines(module)
@@ -180,3 +371,17 @@ if __name__ == "__main__":
test_return_not_allowed()
test_tir_assert()
test_no_body()
+ test_allocate_with_buffers()
+ test_inconsistent_binding()
+ test_invalid_block_axes()
+ test_miss_block_bind()
+ test_invalid_loop_var()
+ test_inconsistent_grid()
+ test_invalid_match_buffer_region()
+ test_duplicate_buffer()
+ test_duplicate_block_signature()
+ test_opaque_access_during_complete()
+ test_convert_slice_to_bufferload()
+ test_error_index_type()
+ test_mismatch_args()
+ test_tvm_exception_catch()
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py
index c7a38cc..a295908 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -2662,6 +2662,169 @@ def test_opt_conv_tensorcore_mod_host():
tvm.ir.assert_structural_equal(mod, rt_mod, True)
+@tvm.script.tir
+def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+ A = tir.match_buffer(a, [128, 128])
+ B = tir.match_buffer(b, [128, 128])
+ C = tir.match_buffer(c, [128, 128])
+
+ with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
+ with tir.init():
+ C[vi, vj] = tir.float32(0)
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@tvm.script.tir
+def matmul_original(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+ A = tir.match_buffer(a, [128, 128])
+ B = tir.match_buffer(b, [128, 128])
+ C = tir.match_buffer(c, [128, 128])
+
+ for i, j in tir.grid(128, 128):
+ with tir.block([128, 128], "init") as [vi, vj]:
+ C[vi, vj] = tir.float32(0)
+
+ for k in range(0, 128):
+ with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@tvm.script.tir
+def element_wise(a: ty.handle, c: ty.handle) -> None:
+ A = tir.match_buffer(a, (128, 128), "float32")
+ C = tir.match_buffer(c, (128, 128), "float32")
+ B = tir.alloc_buffer((128, 128), "float32")
+
+ with tir.block([128, 128], "B") as [vi, vj]:
+ B[vi, vj] = A[vi, vj] * tir.float32(2)
+
+ with tir.block([128, 128], "C") as [vi, vj]:
+ C[vi, vj] = B[vi, vj] + tir.float32(1)
+
+
+@tvm.script.tir
+def predicate(b: ty.handle, c: ty.handle) -> None:
+ B = tir.match_buffer(b, (16, 16), "float32")
+ C = tir.match_buffer(c, (16, 16), "float32")
+
+ for i, jo, ji in tir.grid(16, 4, 5):
+ with tir.block([16, 16], "update") as [vi, vj]:
+ tir.bind(vi, i)
+ tir.bind(vj, jo * 4 + ji)
+ tir.where(jo * 4 + ji < 16)
+ C[vi, vj] = B[vi, vj] + tir.float32(1)
+
+
+def test_module_define():
+ func1 = tvm.script.create_module({"matmul": matmul})["matmul"]
+ func2 = tvm.script.create_module({"element_wise": element_wise})["element_wise"]
+ func3 = tvm.script.create_module({"predicate": predicate})["predicate"]
+ mod1 = tvm.script.create_module({"func1": func1, "func2": func2, "func3": func3})
+ mod2 = tvm.script.create_module({"func1": matmul, "func2": element_wise, "func3": predicate})
+ tvm.ir.assert_structural_equal(mod1, mod2)
+
+
+def test_matmul():
+ func = matmul
+ rt_func = tvm.script.from_source(tvm.script.asscript(func, True))
+ tvm.ir.assert_structural_equal(func, rt_func)
+
+
+def test_matmul_original():
+ func = matmul_original
+ rt_func = tvm.script.from_source(tvm.script.asscript(func, True))
+ tvm.ir.assert_structural_equal(func, rt_func)
+
+ assert isinstance(rt_func.body.block, tir.stmt.Block)
+ assert isinstance(rt_func.body.block.body, tir.stmt.For)
+ assert isinstance(rt_func.body.block.body.body, tir.stmt.For)
+ assert isinstance(rt_func.body.block.body.body.body, tir.stmt.SeqStmt)
+ assert isinstance(rt_func.body.block.body.body.body[0].block, tir.stmt.Block)
+ assert isinstance(rt_func.body.block.body.body.body[1], tir.stmt.For)
+ assert isinstance(rt_func.body.block.body.body.body[1].body.block, tir.stmt.Block)
+
+
+def test_element_wise():
+ func = element_wise
+ rt_func = tvm.script.from_source(tvm.script.asscript(func, True))
+ tvm.ir.assert_structural_equal(func, rt_func)
+
+ assert isinstance(rt_func.body.block, tir.stmt.Block)
+ assert isinstance(rt_func.body.block.body, tir.stmt.SeqStmt)
+ assert isinstance(rt_func.body.block.body[0], tir.stmt.For)
+ assert isinstance(rt_func.body.block.body[0].body, tir.stmt.For)
+ assert isinstance(rt_func.body.block.body[0].body.body.block, tir.stmt.Block)
+
+ assert isinstance(rt_func.body.block.body[1], tir.stmt.For)
+ assert isinstance(rt_func.body.block.body[1].body, tir.stmt.For)
+ assert isinstance(rt_func.body.block.body[1].body.body.block, tir.stmt.Block)
+
+
+def test_predicate():
+ func = predicate
+ rt_func = tvm.script.from_source(tvm.script.asscript(func, True))
+ tvm.ir.assert_structural_equal(func, rt_func)
+
+ assert isinstance(rt_func.body.block, tir.stmt.Block)
+ assert isinstance(rt_func.body.block.body, tir.stmt.For)
+ assert isinstance(rt_func.body.block.body.body, tir.stmt.For)
+ assert isinstance(rt_func.body.block.body.body.body, tir.stmt.For)
+ assert isinstance(rt_func.body.block.body.body.body.body.block, tir.stmt.Block)
+
+
+@tvm.script.tir
+def for_thread_binding(a: ty.handle, b: ty.handle) -> None:
+ A = tir.match_buffer(a, (16, 16), "float32")
+ B = tir.match_buffer(b, (16, 16), "float32")
+
+ for i in tir.thread_binding(0, 16, thread="threadIdx.x"):
+ for j in tir.thread_binding(0, 16, thread="threadIdx.y"):
+ A[i, j] = B[i, j] + tir.float32(1)
+
+
+def test_for_thread_binding():
+ func = for_thread_binding
+ rt_func = tvm.script.from_source(tvm.script.asscript(func, True))
+ tvm.ir.assert_structural_equal(func, rt_func)
+
+ assert isinstance(rt_func.body, tir.stmt.For)
+ assert rt_func.body.kind == 4
+ assert rt_func.body.thread_binding.thread_tag == "threadIdx.x"
+ assert isinstance(rt_func.body.body, tir.stmt.For)
+ assert rt_func.body.body.kind == 4
+ assert rt_func.body.body.thread_binding.thread_tag == "threadIdx.y"
+
+
+@tvm.script.tir
+def block_elements(a: ty.handle, b: ty.handle) -> None:
+ A = tir.match_buffer(a, (16, 16), "float32")
+ B = tir.match_buffer(b, (1, 1), "float32")
+
+ with tir.block([1], "update") as [vi]:
+ tir.bind(vi, 0)
+ tir.where(True)
+ tir.reads(A[0:16, 0:16])
+ tir.writes(B[0, 0])
+ tir.block_attr({"attr_key": "attr_value"})
+ C = tir.alloc_buffer((4, 4), dtype="float32")
+ D = tir.match_buffer_region(A[0:4, 0])
+ with tir.init():
+ B[0, 0] = tir.float32(0)
+ B[0, 0] = A[0, 0] + B[0, 0] + C[1, 1] + D[2, 0]
+
+
+def test_block_elements():
+ func = block_elements
+ rt_func = tvm.script.from_source(tvm.script.asscript(func, True))
+ tvm.ir.assert_structural_equal(func, rt_func)
+
+ assert isinstance(rt_func.body.block, tir.stmt.Block)
+ assert isinstance(rt_func.body.block.body, tir.stmt.BufferStore)
+ assert isinstance(rt_func.body.block.init, tir.stmt.BufferStore)
+ assert len(rt_func.body.block.annotations) == 1
+ assert rt_func.body.block.annotations["attr_key"] == "attr_value"
+
+
if __name__ == "__main__":
test_opt_gemm_normalize()
test_opt_gemm_mod_host()
@@ -2669,3 +2832,10 @@ if __name__ == "__main__":
test_opt_conv_tensorcore_normalize()
test_opt_conv_tensorcore_lower()
test_opt_conv_tensorcore_mod_host()
+ test_module_define()
+ test_matmul()
+ test_matmul_original()
+ test_element_wise()
+ test_predicate()
+ test_for_thread_binding()
+ test_block_elements()
diff --git a/tests/scripts/task_ci_python_setup.sh b/tests/scripts/task_ci_python_setup.sh
index f48ed49..b880cb9 100755
--- a/tests/scripts/task_ci_python_setup.sh
+++ b/tests/scripts/task_ci_python_setup.sh
@@ -30,4 +30,4 @@ set -o pipefail
#
echo "Addtiional setup in" ${CI_IMAGE_NAME}
-python3 -m pip install --user tlcpack-sphinx-addon==0.1.4 synr==0.2.1
+python3 -m pip install --user tlcpack-sphinx-addon==0.1.4 synr==0.3.0
diff --git a/tests/scripts/task_ci_setup.sh b/tests/scripts/task_ci_setup.sh
index 17838c5..9dda54e 100755
--- a/tests/scripts/task_ci_setup.sh
+++ b/tests/scripts/task_ci_setup.sh
@@ -30,7 +30,7 @@ set -o pipefail
#
echo "Addtiional setup in" ${CI_IMAGE_NAME}
-python3 -m pip install --user tlcpack-sphinx-addon==0.1.4 synr==0.2.1
+python3 -m pip install --user tlcpack-sphinx-addon==0.1.4 synr==0.3.0
# Rebuild standalone_crt in build/ tree. This file is not currently archived by pack_lib() in
# Jenkinsfile. We expect config.cmake to be present from pack_lib().