tqchen commented on a change in pull request #6227: URL: https://github.com/apache/incubator-tvm/pull/6227#discussion_r467312444
########## File path: python/tvm/hybrid/registry.py ########## @@ -0,0 +1,240 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hybrid Script Parser Function Registry """ +# pylint: disable=inconsistent-return-statements +import inspect +from enum import IntEnum +from typed_ast import ast3 as ast + + +class Category(IntEnum): + """Categories of registered functions""" + INTRIN = 0 + WITH_SCOPE = 1 + FOR_SCOPE = 2 + SPECIAL_STMT = 3 + + +class Registry(object): + """Registration map + All these maps are static + """ + intrin = dict() + with_scope = dict() + for_scope = dict() + special_stmt = dict() + + host_dict = { + Category.INTRIN: intrin, + Category.WITH_SCOPE: with_scope, + Category.FOR_SCOPE: for_scope, + Category.SPECIAL_STMT: special_stmt + } + + +class CallArgumentReader(object): + """A helper class which read required argument from passed arguments""" + + def __init__(self, func_name, args, kwargs, parser): + self.func_name = func_name + self.args = args + self.kwargs = kwargs + self.parser = parser + + def get_func_compulsory_arg(self, pos, name): + """Get corresponding function argument from argument list which is compulsory""" + + if len(self.args) >= pos: + arg = self.args[pos - 1] + elif name not in 
self.kwargs.keys(): + self.parser.report_error(self.func_name + " misses argument " + name) + else: + arg = self.kwargs[name] + + return arg + + def get_func_optional_arg(self, pos, name, default): + """Get corresponding function argument from argument list which is optional. + If user doesn't provide the argument, set it to default value + """ + + if len(self.args) >= pos: + arg = self.args[pos - 1] + elif name in self.kwargs.keys(): + arg = self.kwargs[name] + else: + return default + + return arg + + +def func_wrapper(func_name, func_to_register, arg_list, need_parser_and_node, need_body, concise): + """Helper function to wrap a function to be registered """ + + def wrap_func(parser, node, args, kwargs): + reader = CallArgumentReader(func_name, args, kwargs, parser) + internal_args = list() + + if need_body and not isinstance(node, ast.For): + # automatically parse body for with scope handlers + if isinstance(node, ast.With): + # the with scope handler is used inside with context + parser.scope_emitter.new_scope() + parser.scope_emitter.node_stack[-1].extend(reversed(node.body)) + body = parser.get_body() + parser.scope_emitter.pop_scope() + else: + # the with scope handler is used in concise scoping + if not concise: + parser.report_error("Concise scoping is not allowed here") + body = parser.get_body() + + if need_parser_and_node: + internal_args.append(parser) + internal_args.append(node) + + for i, arg_info in enumerate(arg_list): + if len(arg_info) == 1: + arg_name, = arg_info + if need_body and arg_name == "body": + internal_args.append(body) + else: + internal_args.append(reader.get_func_compulsory_arg(i + 1, arg_name)) + else: + arg_name, default = arg_info + internal_args.append(reader.get_func_optional_arg(i + 1, arg_name, default=default)) + + return func_to_register(*internal_args) + + return wrap_func + + +def register_func(category, origin_func, need_parser_and_node, need_body, concise): Review comment: Seems that we can refactor the code a bit to 
remove `Category`: (1) turn this function into `wrap_function`; (2) split `register_scope_handler` into `register_with_scope` and `register_for_scope`. Then we don't need to introduce the `Category` enum. ########## File path: python/tvm/hybrid/parser.py ########## @@ -0,0 +1,754 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hybrid Script Parser For TIR""" +# pylint: disable=invalid-name, missing-docstring, inconsistent-return-statements, no-else-return +# pylint: disable=unnecessary-comprehension, unused-argument, import-outside-toplevel +# pylint: disable=unused-import +import json +import numbers +import operator +from typed_ast import ast3 as ast + +import tvm._ffi +from tvm import tir +from tvm._ffi.base import TVMError +from tvm.ir import GlobalVar +from tvm.tir import all as _all +from tvm.tir import expr as _expr + +from . import scope_emitter, special_stmt, scope_handler, intrin +from .meta_unparser import MetaUnparser +from .registry import Registry + + +class HybridParserError(RuntimeError): + """Hybrid Parser Runtime Error""" + + +class HybridParser(ast.NodeVisitor): + """Python AST visitor pass which finally lowers it to TIR + Notes for extension: + 1. To support new types of AST nodes. Add a function visit_xxx(). + 2.
To support new functions + We divide allowed function calls in hybrid script into 3 categories, + which is scope_handler, intrin and special_stmt. + 1) scope_handler: scope_handler functions correspond to StmtNodes without body, which can be + further classified into 2 categories: with scope handler can for scope handlers + 2) intrin: intrin functions corresponds to the remaining IRNodes (StmtNodes without body, + PrimExprNodes and more) + 3) special_stmt: special_stmt functions don't correspond to an IRNode in the AST directly. + It is usually used for some information that is not suitable to be printed directly. + When visiting With node, we check with_scope registry. + When visiting For node, we check for_scope registry. + """ + + _binop_maker = { + ast.Add: tir.Add, + ast.Sub: tir.Sub, + ast.Mult: tir.Mul, + ast.Div: tir.Div, + ast.FloorDiv: tir.FloorDiv, + ast.Mod: tir.FloorMod, + ast.BitOr: operator.or_, + ast.BitAnd: operator.and_, + ast.BitXor: operator.xor, + ast.Gt: tir.GT, + ast.GtE: tir.GE, + ast.Lt: tir.LT, + ast.LtE: tir.LE, + ast.Eq: tir.EQ, + ast.NotEq: tir.NE, + ast.And: tir.And, + ast.Or: tir.Or, + } + + _unaryop_maker = { + ast.USub: operator.neg, + ast.Invert: operator.invert, + ast.Not: tir.Not + } + + def __init__(self, src, base_lienno): + self.params = None + self.buffer_map = None + self.dict_attr = None + self.scope_emitter = None + + self.src = src.split('\n') + self.base_lineno = base_lienno + self.current_lineno = 0 + self.current_col_offset = 0 + self.meta = None + + self.functions = {} + + self._in_with_func_arg = False + self._assign_target = None + + def init_function_parsing_env(self): + """Initialize function parsing environment""" + self.params = [] # parameter list + self.buffer_map = {} # buffer map + self.dict_attr = {} # dict attr + self.scope_emitter = scope_emitter.ScopeEmitter(self) # scope emitter + + @staticmethod + def is_meta(node): + """Judge whether an AST node is META""" + return isinstance(node, ast.Assign) and 
len(node.targets) == 1 \ + and isinstance(node.targets[0], ast.Name) and node.targets[0].id == "__tvm_meta__" + + def init_meta(self, meta_dict): + if meta_dict is not None: + self.meta = tvm.ir.load_json(json.dumps(meta_dict)) + + def visit(self, node): + """Override method in ast.NodeVisitor""" + old_lineno, old_col_offset = self.current_lineno, self.current_col_offset + + if hasattr(node, "lineno"): + self.current_lineno = self.base_lineno + node.lineno - 1 + if hasattr(node, "col_offset"): + self.current_col_offset = node.col_offset + + method = 'visit_' + node.__class__.__name__ + visitor = getattr(self, method, self.generic_visit) + visit_res = visitor(node) + + self.current_lineno, self.current_col_offset = old_lineno, old_col_offset + + return visit_res + + def wrap_line_col(self, message, lineno, col_offset): + """Wrap the message with line number and column offset""" + src_line = self.src[lineno - self.base_lineno] + leading_space = len(src_line) - len(src_line.lstrip(' ')) + col_offset = col_offset - leading_space + src_line = src_line[leading_space:] + return "\n " + src_line + "\n " + " " * col_offset + "^\n" + "ParserError in line " \ + + str(lineno) + " : " + message + + def report_error(self, message, lineno=None, col_offset=None): + """ Report an error occur in line lineno and column col_offset + Parameters + ---------- + message : str + Error message + lineno : int + Line number of error line + col_offset : int + Column offset of error line + """ + + if lineno is None: + lineno = self.current_lineno + if col_offset is None: + col_offset = self.current_col_offset + raise HybridParserError(self.wrap_line_col(message, lineno, col_offset)) + + def get_type_name(self, vtype): + if isinstance(vtype, ast.Attribute) \ + and isinstance(vtype.value, ast.Name) and vtype.value.id == 'ty': + return vtype.attr + self.report_error("invalid type annotation") + + def get_body(self): + body = [] + while len(self.scope_emitter.node_stack[-1]) > 0: + res = 
self.visit(self.scope_emitter.node_stack[-1].pop()) + if res is not None: + body.append(res) + return tvm.tir.SeqStmt(body) if len(body) > 1 else body[0] + + def parse_type(self, vtype): + """ Parse type annotation AST into Type object """ + if isinstance(vtype, ast.NameConstant) and vtype.value is None: + return tvm.ir.TupleType([]) + elif isinstance(vtype, ast.Attribute): + return tvm.ir.PrimType(self.get_type_name(vtype)) + elif isinstance(vtype, ast.Subscript) and isinstance(vtype.slice, ast.Index): + type_name = self.get_type_name(vtype.value) + if isinstance(vtype.slice.value, ast.Tuple): + args = [self.parse_type(element) for element in vtype.slice.value.elts] + else: + args = [self.parse_type(vtype.slice.value)] + if type_name == "Ptr": + return tvm.ir.PointerType(*args) + elif type_name == "Tuple": + return tvm.ir.TupleType(args) + + self.report_error("invalid type annotation") + + def generic_visit(self, node): + """ Override method in ast.NodeVisitor. + To directly filter out invalidate type of stmt. + """ + + self.report_error(type(node).__name__ + " stmt is not supported now") + + def visit_Module(self, node): + """ Module visitor + AST abstract grammar: + Module(stmt* body, type_ignore* type_ignore) + By now we support two format of hybrid script shown below. + + Example + ------- + 1. Generate a Function(If the code is printed, then it may bring meta) + .. code-block:: python + + import tvm + + @tvm.hybrid.script + def A(...): + ... + + # call hybrid parser when call this function, get a Function + func = A + + 2. Generate a Module + .. code-block:: python + + import tvm + + @tvm.hybrid.script + class MyMod(): + def A(...): + ... + + def B(...): + ... + + __tvm_meta__ = ... + + # call hybrid parser during construction, get a Module + mod = MyMod + """ Review comment: shall we update the comment? is bracket needed? 
########## File path: python/tvm/hybrid/__init__.py ########## @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hybrid Script APIs of TVM Python Package, aimed to support TIR""" + +from .utils import create_module, ashybrid, script +from .registry import register_intrin as register Review comment: Avoid exposing `register_intrin`. Instead, we could consider implicitly exposing `tir`-prefixed intrinsics as `tir.xyz`. ########## File path: python/tvm/hybrid/parser.py ########## @@ -0,0 +1,754 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied.
See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hybrid Script Parser For TIR""" +# pylint: disable=invalid-name, missing-docstring, inconsistent-return-statements, no-else-return +# pylint: disable=unnecessary-comprehension, unused-argument, import-outside-toplevel +# pylint: disable=unused-import +import json +import numbers +import operator +from typed_ast import ast3 as ast + +import tvm._ffi +from tvm import tir +from tvm._ffi.base import TVMError +from tvm.ir import GlobalVar +from tvm.tir import all as _all +from tvm.tir import expr as _expr + +from . import scope_emitter, special_stmt, scope_handler, intrin +from .meta_unparser import MetaUnparser +from .registry import Registry + + +class HybridParserError(RuntimeError): + """Hybrid Parser Runtime Error""" + + +class HybridParser(ast.NodeVisitor): + """Python AST visitor pass which finally lowers it to TIR + Notes for extension: + 1. To support new types of AST nodes. Add a function visit_xxx(). + 2. To support new functions + We divide allowed function calls in hybrid script into 3 categories, + which is scope_handler, intrin and special_stmt. + 1) scope_handler: scope_handler functions correspond to StmtNodes without body, which can be + further classified into 2 categories: with scope handler can for scope handlers + 2) intrin: intrin functions corresponds to the remaining IRNodes (StmtNodes without body, + PrimExprNodes and more) + 3) special_stmt: special_stmt functions don't correspond to an IRNode in the AST directly. + It is usually used for some information that is not suitable to be printed directly. + When visiting With node, we check with_scope registry. + When visiting For node, we check for_scope registry. 
+ """ + + _binop_maker = { + ast.Add: tir.Add, + ast.Sub: tir.Sub, + ast.Mult: tir.Mul, + ast.Div: tir.Div, + ast.FloorDiv: tir.FloorDiv, + ast.Mod: tir.FloorMod, + ast.BitOr: operator.or_, + ast.BitAnd: operator.and_, + ast.BitXor: operator.xor, + ast.Gt: tir.GT, + ast.GtE: tir.GE, + ast.Lt: tir.LT, + ast.LtE: tir.LE, + ast.Eq: tir.EQ, + ast.NotEq: tir.NE, + ast.And: tir.And, + ast.Or: tir.Or, + } + + _unaryop_maker = { + ast.USub: operator.neg, + ast.Invert: operator.invert, + ast.Not: tir.Not + } + + def __init__(self, src, base_lienno): + self.params = None + self.buffer_map = None + self.dict_attr = None + self.scope_emitter = None + + self.src = src.split('\n') + self.base_lineno = base_lienno + self.current_lineno = 0 + self.current_col_offset = 0 + self.meta = None + + self.functions = {} + + self._in_with_func_arg = False + self._assign_target = None + + def init_function_parsing_env(self): + """Initialize function parsing environment""" + self.params = [] # parameter list + self.buffer_map = {} # buffer map + self.dict_attr = {} # dict attr + self.scope_emitter = scope_emitter.ScopeEmitter(self) # scope emitter + + @staticmethod + def is_meta(node): + """Judge whether an AST node is META""" + return isinstance(node, ast.Assign) and len(node.targets) == 1 \ + and isinstance(node.targets[0], ast.Name) and node.targets[0].id == "__tvm_meta__" + + def init_meta(self, meta_dict): + if meta_dict is not None: + self.meta = tvm.ir.load_json(json.dumps(meta_dict)) + + def visit(self, node): + """Override method in ast.NodeVisitor""" + old_lineno, old_col_offset = self.current_lineno, self.current_col_offset + + if hasattr(node, "lineno"): + self.current_lineno = self.base_lineno + node.lineno - 1 + if hasattr(node, "col_offset"): + self.current_col_offset = node.col_offset + + method = 'visit_' + node.__class__.__name__ + visitor = getattr(self, method, self.generic_visit) + visit_res = visitor(node) + + self.current_lineno, self.current_col_offset = old_lineno, 
old_col_offset + + return visit_res + + def wrap_line_col(self, message, lineno, col_offset): + """Wrap the message with line number and column offset""" + src_line = self.src[lineno - self.base_lineno] + leading_space = len(src_line) - len(src_line.lstrip(' ')) + col_offset = col_offset - leading_space + src_line = src_line[leading_space:] + return "\n " + src_line + "\n " + " " * col_offset + "^\n" + "ParserError in line " \ + + str(lineno) + " : " + message + + def report_error(self, message, lineno=None, col_offset=None): + """ Report an error occur in line lineno and column col_offset + Parameters + ---------- + message : str + Error message + lineno : int + Line number of error line + col_offset : int + Column offset of error line + """ + + if lineno is None: + lineno = self.current_lineno + if col_offset is None: + col_offset = self.current_col_offset + raise HybridParserError(self.wrap_line_col(message, lineno, col_offset)) + + def get_type_name(self, vtype): + if isinstance(vtype, ast.Attribute) \ + and isinstance(vtype.value, ast.Name) and vtype.value.id == 'ty': + return vtype.attr + self.report_error("invalid type annotation") + + def get_body(self): + body = [] + while len(self.scope_emitter.node_stack[-1]) > 0: + res = self.visit(self.scope_emitter.node_stack[-1].pop()) + if res is not None: + body.append(res) + return tvm.tir.SeqStmt(body) if len(body) > 1 else body[0] + + def parse_type(self, vtype): + """ Parse type annotation AST into Type object """ + if isinstance(vtype, ast.NameConstant) and vtype.value is None: + return tvm.ir.TupleType([]) + elif isinstance(vtype, ast.Attribute): + return tvm.ir.PrimType(self.get_type_name(vtype)) + elif isinstance(vtype, ast.Subscript) and isinstance(vtype.slice, ast.Index): + type_name = self.get_type_name(vtype.value) + if isinstance(vtype.slice.value, ast.Tuple): + args = [self.parse_type(element) for element in vtype.slice.value.elts] + else: + args = [self.parse_type(vtype.slice.value)] + if type_name 
== "Ptr": + return tvm.ir.PointerType(*args) + elif type_name == "Tuple": + return tvm.ir.TupleType(args) + + self.report_error("invalid type annotation") + + def generic_visit(self, node): + """ Override method in ast.NodeVisitor. + To directly filter out invalidate type of stmt. + """ + + self.report_error(type(node).__name__ + " stmt is not supported now") + + def visit_Module(self, node): + """ Module visitor + AST abstract grammar: + Module(stmt* body, type_ignore* type_ignore) + By now we support two format of hybrid script shown below. + + Example + ------- + 1. Generate a Function(If the code is printed, then it may bring meta) + .. code-block:: python + + import tvm + + @tvm.hybrid.script + def A(...): + ... + + # call hybrid parser when call this function, get a Function + func = A + + 2. Generate a Module + .. code-block:: python + + import tvm + + @tvm.hybrid.script + class MyMod(): + def A(...): + ... + + def B(...): + ... + + __tvm_meta__ = ... + + # call hybrid parser during construction, get a Module + mod = MyMod + """ + + if len(node.body) == 1 and isinstance(node.body[0], (ast.ClassDef, ast.FunctionDef)): + # class or single function + return self.visit(node.body[0]) + elif len(node.body) == 2: + if isinstance(node.body[0], ast.Assign): + node.body[0], node.body[1] = node.body[1], node.body[0] + if isinstance(node.body[0], ast.FunctionDef) and HybridParser.is_meta(node.body[1]): + # function with meta + self.init_meta(MetaUnparser().visit(node.body[1].value)) + return self.visit(node.body[0]) + self.report_error( + "Only one-function, one-class or function-with-meta source code is allowed") + + def visit_ClassDef(self, node): + """ ClassDef visitor + AST abstract grammar: + ClassDef(identifier name, expr* bases, keyword* keywords, stmt* body, + expr* decorator_list) + """ + + # parse meta + count = False + for body_element in node.body: + if isinstance(body_element, ast.FunctionDef): + pass + elif HybridParser.is_meta(body_element) and not count: 
+ count = True + self.init_meta(MetaUnparser().visit(body_element.value)) + else: + self.report_error("invalid class member") + + # parse member functions + for body_element in node.body: + if isinstance(body_element, ast.FunctionDef): + self.visit(body_element) + from .utils import create_module + return create_module(self.functions) + + def visit_FunctionDef(self, node): + """ FunctionDef visitor + AST abstract grammar: + FunctionDef(identifier name, arguments args, stmt* body, expr* decorator_list, + expr? returns, string? type_comment) + arguments = (arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs, + expr* kw_defaults, arg? kwarg, expr* defaults) + arg = (identifier arg, expr? annotation, string? type_comment) + """ + + self.init_function_parsing_env() + # add parameters of function + for arg in node.args.args: + arg_var = tvm.te.var(arg.arg, self.parse_type(arg.annotation)) + self.scope_emitter.update_symbol(arg.arg, arg_var) + self.params.append(arg_var) + + # visit the body of function + self.scope_emitter.node_stack[-1].extend(reversed(node.body)) + + # fetch the body and return a tir.PrimFunc + func = tvm.tir.PrimFunc(self.params, self.get_body(), + ret_type=self.parse_type(node.returns), + buffer_map=self.buffer_map, + attrs=tvm.ir.make_node("DictAttrs", **self.dict_attr)) + self.functions[GlobalVar(node.name)] = func + return func + + def visit_Assign(self, node): + """ Assign visitor + AST abstract grammar: + Assign(expr* targets, expr value, string? type_comment) + By now only 2 types of Assign is supported: + 1. special stmts that appear as assign stmt + 1.1 Buffer = tir.buffer_bind()/tir.buffer_decl() + 1.2 Var = tir.var() + 2. (BufferStore) Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr + 3. 
(Store) Var[PrimExpr] = PrimExpr + """ + + if not len(node.targets) == 1: + self.report_error("Only one-valued assignment is supported now") + target = node.targets[0] + + if isinstance(target, ast.Name): + # scenario 1 + self._assign_target = target.id + rhs = self.visit(node.value) + if not isinstance(node.value, ast.Call): + self.report_error("Unsupported Assign stmt") + self.scope_emitter.update_symbol(target.id, rhs) + elif isinstance(target, ast.Subscript): + # scenario 2&3 + symbol, indexes = self.visit(target) + self._assign_target = (symbol, indexes) + rhs = self.visit(node.value) + if isinstance(symbol, tvm.tir.Buffer): + return tvm.tir.BufferStore(symbol, tvm.runtime.convert(rhs), indexes) + else: + if len(indexes) != 1: + self.report_error("Invalid Store stmt") + return tvm.tir.Store(symbol, tvm.runtime.convert(rhs), indexes[0], + tvm.runtime.convert(True)) + else: + self.report_error("Unsupported Assign stmt") + + def visit_AnnAssign(self, node): + """ AnnAssign visitor + AST abstract grammar: + AnnAssign(expr target, expr annotation, expr? value, int simple) + Corresponds to concise mode of with tir.let() + """ + + if isinstance(node.target, ast.Name): + value = self.visit(node.value) + var = tvm.te.var(node.target.id, self.parse_type(node.annotation)) + self.scope_emitter.update_symbol(var.name, var) + return tvm.tir.LetStmt(var, value, self.visit(self.scope_emitter.node_stack[-1].pop())) + else: + self.report_error("Unsupported AnnAssign stmt") + + def visit_Assert(self, node): + """ Assert visitor + AST abstract grammar: + Assert(expr test, expr? 
msg) + Corresponds to concise mode of with tir.assert() + """ + + condition = self.visit(node.test) + if node.msg is None: + self.report_error("Message of AssertStmt can't be None") + message = self.visit(node.msg) + return tvm.tir.AssertStmt(condition, tvm.runtime.convert(message), self.get_body()) + + def visit_For(self, node): + """ For visitor + AST abstract grammar: + For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment) + By now only 1 type of For is supported: + 1. for name in tir.range(begin, end, for_type) + """ + + if not isinstance(node.target, ast.Name): + self.report_error("The loop variable should be a name variable") + # check node.iter, which is a tir Call + if not isinstance(node.iter, ast.Call): + self.report_error("The loop iter should be a Call") + if not isinstance(node.iter.func, ast.Attribute) \ + or not isinstance(node.iter.func.value, ast.Name) \ + or node.iter.func.value.id != "tir": + self.report_error("The loop iter Call should be tir.name()") + + func_name = node.iter.func.attr + # collect arguments + args = [self.visit(arg) for arg in node.iter.args] + kw_args = [self.visit(keyword) for keyword in node.iter.keywords] + kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args} + # All the functions supported in For stmt are registered in scope_handler.ForScope + if func_name not in Registry.for_scope.keys(): + self.report_error("Function " + func_name + " used in For stmt is not supported now", + self.current_lineno, + node.iter.col_offset) + + old_lineno, old_col_offset = self.current_lineno, self.current_col_offset + self.current_lineno, self.current_col_offset = \ + self.base_lineno + node.iter.lineno - 1, node.iter.col_offset + res = Registry.for_scope.get(func_name)(self, node, args, kw_args) + self.current_lineno, self.current_col_offset = old_lineno, old_col_offset + return res + + def visit_With(self, node): + """ With visitor + AST abstract grammar: + With(withitem* items, stmt* body, string? 
type_comment) + withitem = (expr context_expr, expr? optional_vars) + By now only 1 type of With is supported: + 1. with tir.let/tir.Assert()/tir.attr()/tir.allocate()/tir.realize() + """ + + if len(node.items) != 1: + self.report_error("Only one with element is supported now") + if not isinstance(node.items[0].context_expr, ast.Call): + self.report_error("The context expression of with should be a Call") + func_call = node.items[0].context_expr + if not isinstance(func_call.func, ast.Attribute) \ + or not isinstance(func_call.func.value, ast.Name) \ + or func_call.func.value.id != "tir": + self.report_error("The context expression of with should be tir.name()") + + func_name = func_call.func.attr + # collect arguments + args = [self.visit(arg) for arg in func_call.args] + kw_args = [self.visit(keyword) for keyword in func_call.keywords] + kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args} + if func_name not in Registry.with_scope.keys(): Review comment: `Registry.with_scope.keys()` => `Registry.with_scope`. The `in` operator on a dict already tests its keys, so `.keys()` is redundant. The same applies to `Registry.for_scope.keys()` in `visit_For`. ########## File path: python/tvm/hybrid/parser.py ########## @@ -0,0 +1,754 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+"""Hybrid Script Parser For TIR""" +# pylint: disable=invalid-name, missing-docstring, inconsistent-return-statements, no-else-return +# pylint: disable=unnecessary-comprehension, unused-argument, import-outside-toplevel +# pylint: disable=unused-import +import json +import numbers +import operator +from typed_ast import ast3 as ast + +import tvm._ffi +from tvm import tir +from tvm._ffi.base import TVMError +from tvm.ir import GlobalVar +from tvm.tir import all as _all +from tvm.tir import expr as _expr + +from . import scope_emitter, special_stmt, scope_handler, intrin +from .meta_unparser import MetaUnparser +from .registry import Registry + + +class HybridParserError(RuntimeError): + """Hybrid Parser Runtime Error""" + + +class HybridParser(ast.NodeVisitor): + """Python AST visitor pass which finally lowers it to TIR + Notes for extension: + 1. To support new types of AST nodes. Add a function visit_xxx(). + 2. To support new functions + We divide allowed function calls in hybrid script into 3 categories, + which is scope_handler, intrin and special_stmt. + 1) scope_handler: scope_handler functions correspond to StmtNodes without body, which can be + further classified into 2 categories: with scope handler can for scope handlers + 2) intrin: intrin functions corresponds to the remaining IRNodes (StmtNodes without body, + PrimExprNodes and more) + 3) special_stmt: special_stmt functions don't correspond to an IRNode in the AST directly. + It is usually used for some information that is not suitable to be printed directly. + When visiting With node, we check with_scope registry. + When visiting For node, we check for_scope registry. 
+ """ + + _binop_maker = { + ast.Add: tir.Add, + ast.Sub: tir.Sub, + ast.Mult: tir.Mul, + ast.Div: tir.Div, + ast.FloorDiv: tir.FloorDiv, + ast.Mod: tir.FloorMod, + ast.BitOr: operator.or_, + ast.BitAnd: operator.and_, + ast.BitXor: operator.xor, + ast.Gt: tir.GT, + ast.GtE: tir.GE, + ast.Lt: tir.LT, + ast.LtE: tir.LE, + ast.Eq: tir.EQ, + ast.NotEq: tir.NE, + ast.And: tir.And, + ast.Or: tir.Or, + } + + _unaryop_maker = { + ast.USub: operator.neg, + ast.Invert: operator.invert, + ast.Not: tir.Not + } + + def __init__(self, src, base_lienno): + self.params = None + self.buffer_map = None + self.dict_attr = None + self.scope_emitter = None + + self.src = src.split('\n') + self.base_lineno = base_lienno + self.current_lineno = 0 + self.current_col_offset = 0 + self.meta = None + + self.functions = {} + + self._in_with_func_arg = False + self._assign_target = None + + def init_function_parsing_env(self): + """Initialize function parsing environment""" + self.params = [] # parameter list + self.buffer_map = {} # buffer map + self.dict_attr = {} # dict attr + self.scope_emitter = scope_emitter.ScopeEmitter(self) # scope emitter + + @staticmethod + def is_meta(node): + """Judge whether an AST node is META""" + return isinstance(node, ast.Assign) and len(node.targets) == 1 \ + and isinstance(node.targets[0], ast.Name) and node.targets[0].id == "__tvm_meta__" + + def init_meta(self, meta_dict): + if meta_dict is not None: + self.meta = tvm.ir.load_json(json.dumps(meta_dict)) + + def visit(self, node): + """Override method in ast.NodeVisitor""" + old_lineno, old_col_offset = self.current_lineno, self.current_col_offset + + if hasattr(node, "lineno"): + self.current_lineno = self.base_lineno + node.lineno - 1 + if hasattr(node, "col_offset"): + self.current_col_offset = node.col_offset + + method = 'visit_' + node.__class__.__name__ + visitor = getattr(self, method, self.generic_visit) + visit_res = visitor(node) + + self.current_lineno, self.current_col_offset = old_lineno, 
old_col_offset + + return visit_res + + def wrap_line_col(self, message, lineno, col_offset): + """Wrap the message with line number and column offset""" + src_line = self.src[lineno - self.base_lineno] + leading_space = len(src_line) - len(src_line.lstrip(' ')) + col_offset = col_offset - leading_space + src_line = src_line[leading_space:] + return "\n " + src_line + "\n " + " " * col_offset + "^\n" + "ParserError in line " \ + + str(lineno) + " : " + message + + def report_error(self, message, lineno=None, col_offset=None): + """ Report an error occur in line lineno and column col_offset + Parameters + ---------- + message : str + Error message + lineno : int + Line number of error line + col_offset : int + Column offset of error line + """ + + if lineno is None: + lineno = self.current_lineno + if col_offset is None: + col_offset = self.current_col_offset + raise HybridParserError(self.wrap_line_col(message, lineno, col_offset)) + + def get_type_name(self, vtype): + if isinstance(vtype, ast.Attribute) \ + and isinstance(vtype.value, ast.Name) and vtype.value.id == 'ty': + return vtype.attr + self.report_error("invalid type annotation") + + def get_body(self): + body = [] + while len(self.scope_emitter.node_stack[-1]) > 0: + res = self.visit(self.scope_emitter.node_stack[-1].pop()) + if res is not None: + body.append(res) + return tvm.tir.SeqStmt(body) if len(body) > 1 else body[0] + + def parse_type(self, vtype): + """ Parse type annotation AST into Type object """ + if isinstance(vtype, ast.NameConstant) and vtype.value is None: + return tvm.ir.TupleType([]) + elif isinstance(vtype, ast.Attribute): + return tvm.ir.PrimType(self.get_type_name(vtype)) + elif isinstance(vtype, ast.Subscript) and isinstance(vtype.slice, ast.Index): + type_name = self.get_type_name(vtype.value) + if isinstance(vtype.slice.value, ast.Tuple): + args = [self.parse_type(element) for element in vtype.slice.value.elts] + else: + args = [self.parse_type(vtype.slice.value)] + if type_name 
== "Ptr": + return tvm.ir.PointerType(*args) + elif type_name == "Tuple": + return tvm.ir.TupleType(args) + + self.report_error("invalid type annotation") + + def generic_visit(self, node): + """ Override method in ast.NodeVisitor. + To directly filter out invalidate type of stmt. + """ + + self.report_error(type(node).__name__ + " stmt is not supported now") + + def visit_Module(self, node): + """ Module visitor + AST abstract grammar: + Module(stmt* body, type_ignore* type_ignore) + By now we support two format of hybrid script shown below. + + Example + ------- + 1. Generate a Function(If the code is printed, then it may bring meta) + .. code-block:: python + + import tvm + + @tvm.hybrid.script + def A(...): + ... + + # call hybrid parser when call this function, get a Function + func = A + + 2. Generate a Module + .. code-block:: python + + import tvm + + @tvm.hybrid.script + class MyMod(): + def A(...): + ... + + def B(...): + ... + + __tvm_meta__ = ... + + # call hybrid parser during construction, get a Module + mod = MyMod + """ + + if len(node.body) == 1 and isinstance(node.body[0], (ast.ClassDef, ast.FunctionDef)): + # class or single function + return self.visit(node.body[0]) + elif len(node.body) == 2: + if isinstance(node.body[0], ast.Assign): + node.body[0], node.body[1] = node.body[1], node.body[0] + if isinstance(node.body[0], ast.FunctionDef) and HybridParser.is_meta(node.body[1]): + # function with meta + self.init_meta(MetaUnparser().visit(node.body[1].value)) + return self.visit(node.body[0]) + self.report_error( + "Only one-function, one-class or function-with-meta source code is allowed") + + def visit_ClassDef(self, node): + """ ClassDef visitor + AST abstract grammar: + ClassDef(identifier name, expr* bases, keyword* keywords, stmt* body, + expr* decorator_list) + """ + + # parse meta + count = False + for body_element in node.body: + if isinstance(body_element, ast.FunctionDef): + pass + elif HybridParser.is_meta(body_element) and not count: 
+ count = True + self.init_meta(MetaUnparser().visit(body_element.value)) + else: + self.report_error("invalid class member") + + # parse member functions + for body_element in node.body: + if isinstance(body_element, ast.FunctionDef): + self.visit(body_element) + from .utils import create_module + return create_module(self.functions) + + def visit_FunctionDef(self, node): + """ FunctionDef visitor + AST abstract grammar: + FunctionDef(identifier name, arguments args, stmt* body, expr* decorator_list, + expr? returns, string? type_comment) + arguments = (arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs, + expr* kw_defaults, arg? kwarg, expr* defaults) + arg = (identifier arg, expr? annotation, string? type_comment) + """ + + self.init_function_parsing_env() + # add parameters of function + for arg in node.args.args: + arg_var = tvm.te.var(arg.arg, self.parse_type(arg.annotation)) + self.scope_emitter.update_symbol(arg.arg, arg_var) + self.params.append(arg_var) + + # visit the body of function + self.scope_emitter.node_stack[-1].extend(reversed(node.body)) + + # fetch the body and return a tir.PrimFunc + func = tvm.tir.PrimFunc(self.params, self.get_body(), + ret_type=self.parse_type(node.returns), + buffer_map=self.buffer_map, + attrs=tvm.ir.make_node("DictAttrs", **self.dict_attr)) + self.functions[GlobalVar(node.name)] = func + return func + + def visit_Assign(self, node): + """ Assign visitor + AST abstract grammar: + Assign(expr* targets, expr value, string? type_comment) + By now only 2 types of Assign is supported: + 1. special stmts that appear as assign stmt + 1.1 Buffer = tir.buffer_bind()/tir.buffer_decl() + 1.2 Var = tir.var() + 2. (BufferStore) Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr + 3. 
(Store) Var[PrimExpr] = PrimExpr + """ + + if not len(node.targets) == 1: + self.report_error("Only one-valued assignment is supported now") + target = node.targets[0] + + if isinstance(target, ast.Name): + # scenario 1 + self._assign_target = target.id + rhs = self.visit(node.value) + if not isinstance(node.value, ast.Call): + self.report_error("Unsupported Assign stmt") + self.scope_emitter.update_symbol(target.id, rhs) + elif isinstance(target, ast.Subscript): + # scenario 2&3 + symbol, indexes = self.visit(target) + self._assign_target = (symbol, indexes) + rhs = self.visit(node.value) + if isinstance(symbol, tvm.tir.Buffer): + return tvm.tir.BufferStore(symbol, tvm.runtime.convert(rhs), indexes) + else: + if len(indexes) != 1: + self.report_error("Invalid Store stmt") + return tvm.tir.Store(symbol, tvm.runtime.convert(rhs), indexes[0], + tvm.runtime.convert(True)) + else: + self.report_error("Unsupported Assign stmt") + + def visit_AnnAssign(self, node): + """ AnnAssign visitor + AST abstract grammar: + AnnAssign(expr target, expr annotation, expr? value, int simple) + Corresponds to concise mode of with tir.let() + """ + + if isinstance(node.target, ast.Name): + value = self.visit(node.value) + var = tvm.te.var(node.target.id, self.parse_type(node.annotation)) + self.scope_emitter.update_symbol(var.name, var) + return tvm.tir.LetStmt(var, value, self.visit(self.scope_emitter.node_stack[-1].pop())) + else: + self.report_error("Unsupported AnnAssign stmt") + + def visit_Assert(self, node): + """ Assert visitor + AST abstract grammar: + Assert(expr test, expr? 
msg) + Corresponds to concise mode of with tir.assert() + """ + + condition = self.visit(node.test) + if node.msg is None: + self.report_error("Message of AssertStmt can't be None") + message = self.visit(node.msg) + return tvm.tir.AssertStmt(condition, tvm.runtime.convert(message), self.get_body()) + + def visit_For(self, node): + """ For visitor + AST abstract grammar: + For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment) + By now only 1 type of For is supported: + 1. for name in tir.range(begin, end, for_type) + """ + + if not isinstance(node.target, ast.Name): + self.report_error("The loop variable should be a name variable") + # check node.iter, which is a tir Call + if not isinstance(node.iter, ast.Call): + self.report_error("The loop iter should be a Call") + if not isinstance(node.iter.func, ast.Attribute) \ + or not isinstance(node.iter.func.value, ast.Name) \ + or node.iter.func.value.id != "tir": + self.report_error("The loop iter Call should be tir.name()") + + func_name = node.iter.func.attr + # collect arguments + args = [self.visit(arg) for arg in node.iter.args] + kw_args = [self.visit(keyword) for keyword in node.iter.keywords] + kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args} + # All the functions supported in For stmt are registered in scope_handler.ForScope + if func_name not in Registry.for_scope.keys(): + self.report_error("Function " + func_name + " used in For stmt is not supported now", + self.current_lineno, + node.iter.col_offset) + + old_lineno, old_col_offset = self.current_lineno, self.current_col_offset + self.current_lineno, self.current_col_offset = \ + self.base_lineno + node.iter.lineno - 1, node.iter.col_offset + res = Registry.for_scope.get(func_name)(self, node, args, kw_args) + self.current_lineno, self.current_col_offset = old_lineno, old_col_offset + return res + + def visit_With(self, node): + """ With visitor + AST abstract grammar: + With(withitem* items, stmt* body, string? 
type_comment) + withitem = (expr context_expr, expr? optional_vars) + By now only 1 type of With is supported: + 1. with tir.let/tir.Assert()/tir.attr()/tir.allocate()/tir.realize() + """ + + if len(node.items) != 1: + self.report_error("Only one with element is supported now") + if not isinstance(node.items[0].context_expr, ast.Call): + self.report_error("The context expression of with should be a Call") + func_call = node.items[0].context_expr + if not isinstance(func_call.func, ast.Attribute) \ + or not isinstance(func_call.func.value, ast.Name) \ + or func_call.func.value.id != "tir": + self.report_error("The context expression of with should be tir.name()") + + func_name = func_call.func.attr + # collect arguments + args = [self.visit(arg) for arg in func_call.args] + kw_args = [self.visit(keyword) for keyword in func_call.keywords] + kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args} + if func_name not in Registry.with_scope.keys(): + self.report_error("Function " + func_name + " used in With stmt is not supported now") + + # All the functions supported in With stmt are registered in scope_handler.WithScope + old_lineno, old_col_offset = self.current_lineno, self.current_col_offset + self.current_lineno, self.current_col_offset = \ + self.base_lineno + func_call.lineno - 1, func_call.col_offset + res = Registry.with_scope.get(func_name)(self, node, args, kw_args) + self.current_lineno, self.current_col_offset = old_lineno, old_col_offset + return res + + def visit_If(self, node): + """ If visitor + AST abstract grammar: + If(expr test, stmt* body, stmt* orelse) + """ + + condition = self.visit(node.test) + # then body + self.scope_emitter.new_scope() + self.scope_emitter.node_stack[-1].extend(reversed(node.body)) + then_body = self.get_body() + self.scope_emitter.pop_scope() + + # else body + if len(node.orelse) > 0: + self.scope_emitter.new_scope() + self.scope_emitter.node_stack[-1].extend(reversed(node.orelse)) + else_body = self.get_body() + 
self.scope_emitter.pop_scope() + else: + else_body = None + return tvm.tir.IfThenElse(condition, then_body, else_body) + + def visit_Call(self, node): + """ Call visitor + AST abstract grammar: + Call(expr func, expr* args, keyword* keywords) + keyword = (identifier? arg, expr value) + """ + + # collect arguments + args = [self.visit(arg) for arg in node.args] + kw_args = [self.visit(keyword) for keyword in node.keywords] + kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args} + + maybe_intrin = False + if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): + if node.func.value.id == "tir": + func_name = node.func.attr + maybe_intrin = True + else: + self.report_error("Unsupported Attribute typed function call") + else: + self.report_error("Unsupported function call") + + if func_name in Registry.special_stmt.keys(): + return Registry.special_stmt.get(func_name)(self, node, args, kw_args) + if func_name in Registry.intrin.keys(): + return Registry.intrin.get(func_name)(self, node, args, kw_args) + if func_name in Registry.with_scope.keys(): + return Registry.with_scope.get(func_name)(self, node, args, kw_args) + if maybe_intrin: + return tvm.tir.Call(kw_args["dtype"], tvm.ir.op.Op.get("tir." + func_name), args) + + self.report_error("Function " + func_name + " is not supported now") + + def visit_Expr(self, node): + """ Expr visitor + AST abstract grammar: + Expr(expr value) + + Now only 2 types of Expr stmt is allowed: + 1. Concise mode of with scope handlers + tir.attr()/tir.assert()/tir.allocate()/tir.realize() + 2. 
special stmts appear as a call + tir.set_func_attr() + """ + + if not isinstance(node.value, ast.Call): + self.report_error("Unsupported Expr stmt") + return self.visit(node.value) + + def visit_BinOp(self, node): + """ BinOp visitor + AST abstract grammar: + BinOp(expr left, operator op, expr right) + """ + + lhs = self.visit(node.left) + rhs = self.visit(node.right) + if not isinstance(node.op, tuple(HybridParser._binop_maker.keys())): + self.report_error("BinOp " + str(type(node.op)) + " is not supported now") + return HybridParser._binop_maker[type(node.op)](lhs, rhs) + + def visit_Compare(self, node): + """ Compare visitor + AST abstract grammar: + Compare(expr left, expr right, ops=) + """ + + ops = [self.visit(node.left)] + ops += [self.visit(comparator) for comparator in node.comparators] + res = [] + for i in range(len(node.ops)): + lhs = ops[i] + rhs = ops[i + 1] + res.append(HybridParser._binop_maker[type(node.ops[i])](lhs, rhs)) + return _all(*res) + + def visit_BoolOp(self, node): + """ BoolOp visitor + AST abstract grammar: + BoolOp(boolop op, expr* values) + """ + + values = [self.visit(value) for value in node.values] + return HybridParser._binop_maker[type(node.op)](*values) + + def visit_UnaryOp(self, node): + """ UnaryOp visitor + AST abstract grammar: + UnaryOp(unaryop op, expr operand) + """ + + operand = self.visit(node.operand) + if not isinstance(node.op, tuple(HybridParser._unaryop_maker.keys())): + self.report_error("UnaryOp " + str(type(node.op)) + " is not supported now") + return HybridParser._unaryop_maker[type(node.op)](operand) + + def visit_Subscript(self, node): + """ Subscript visitor + AST abstract grammar: + Subscript(expr value, slice slice, expr_context ctx) + slice = Slice(expr? lower, expr? upper, expr? step) + | ExtSlice(slice* dims) + | Index(expr value) + By now only 2 types of Subscript are supported: + 1. Buffer[index, index, ...], Buffer element access(BufferLoad & BufferStore) + Var[index] Buffer element access() + 2. 
meta[type_key][index], Meta info access + """ + + if isinstance(node.value, (ast.Name, ast.Attribute)): + symbol = self.visit(node.value) + if isinstance(node.slice, ast.Index): + # BufferLoad & BufferStore + if isinstance(node.slice.value, ast.Tuple): + # Buffer/Var[index, index, ...] + indexes = [self.visit(element) for element in node.slice.value.elts] + else: + # Buffer/Var[index] + indexes = [self.visit(node.slice.value)] + if isinstance(node.ctx, ast.Load): + if isinstance(symbol, tir.expr.Var): + return tvm.tir.Load("float32", symbol, indexes, True) + else: + return tvm.tir.BufferLoad(symbol, indexes) + else: + return symbol, indexes + else: + # Buffer Region, now used in tir.realize(buffer[bounds]) + doms = [] + slice_nodes = [] + if isinstance(node.slice, ast.Slice): + # Buffer[begin:end] + slice_nodes.append(node.slice) + elif isinstance(node.slice, ast.ExtSlice): + # Buffer[begin:end, begin:end] + slice_nodes.extend(node.slice.dims) + + for dim in slice_nodes: + if not hasattr(dim, "step"): + self.report_error("slice of Buffer Region ought to be begin:end") + if dim.step is not None: + self.report_error("step is not allowed in Buffer Region") + upper = self.visit(dim.upper) + lower = self.visit(dim.lower) + extent = upper - lower + if isinstance(extent, _expr.PrimExpr): + ana = tvm.arith.Analyzer() + extent = ana.simplify(extent) + doms.append(tvm.ir.Range.from_min_extent(lower, extent)) + return symbol, doms + + elif isinstance(node.value, ast.Subscript) and isinstance(node.value.value, ast.Name) \ + and node.value.value.id == 'meta': + # meta[type_key][index] + if not (isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Num)) \ + or not (isinstance(node.value.slice, ast.Index) \ + and isinstance(node.value.slice.value, ast.Name)): + self.report_error("The meta access format ought to be meta[type_key][index]") + type_key = node.value.slice.value.id + index = node.slice.value.n + node_list = self.meta[type_key] + if node_list is None: 
+ self.report_error("type_key " + type_key + " in meta not found") + if len(node_list) <= index: + self.report_error("index " + index + " out of range " + len(node_list)) + return node_list[index] + else: + self.report_error("Only buffer variable and meta can be subscriptable") + + def visit_Name(self, node): + """ Name visitor + AST abstract grammar: + Name(identifier id, expr_context ctx) + """ + + name = node.id + symbol = self.scope_emitter.lookup_symbol(name) + if symbol is None: + self.report_error("Unknown symbol %s" % name) + return symbol + + def visit_Attribute(self, node): + """ Attribute visitor + AST abstract grammar: + Attribute(expr value, identifier attr, expr_context ctx) + """ + + if not isinstance(node.value, ast.Name): + self.report_error("The value of Attribute ought to a Name") + name = node.value.id + symbol = self.scope_emitter.lookup_symbol(name) + if symbol is None or not isinstance(symbol, tvm.tir.Buffer): + self.report_error("Unsupported Attribute expression") + if not hasattr(symbol, node.attr): + self.report_error("Type " + type(symbol) + " has not attr " + node.attr) + return getattr(symbol, node.attr) + + def visit_Dict(self, node): + """ Dict visitor + AST abstract grammar: + Dict(expr* keys, expr* values) + """ + + keys = [self.visit(key) for key in node.keys] + values = [self.visit(value) for value in node.values] + + return {key: value for key, value in zip(keys, values)} + + def visit_Tuple(self, node): + """ Tuple visitor + AST abstract grammar: + Tuple(expr* elts, expr_context ctx) + """ + + return tuple(self.visit(element) for element in node.elts) + + def visit_List(self, node): + """ List visitor + AST abstract grammar: + List(expr* elts, expr_context ctx) + """ + + return [self.visit(element) for element in node.elts] + + def visit_keyword(self, node): + """ Keyword visitor + AST abstract grammar: + keyword = (identifier? 
arg, expr value) + """ + + return node.arg, self.visit(node.value) + + def visit_NameConstant(self, node): + return tvm.runtime.convert(node.value) + + def visit_Constant(self, node): + return tvm.runtime.convert(node.value) + + def visit_Num(self, node): + if isinstance(node.n, numbers.Integral): + dtype = "int32" + elif isinstance(node.n, float): + dtype = "float32" + else: + self.report_error("The data type should be one of (int, float)") + return tvm.tir.const(node.n, dtype) + + def visit_Str(self, node): + return node.s + + +def source_to_op(src, func_lineno=0): + """ Another level of wrapper + Parameters + ---------- + src : str + Pruned source of original script + func_lineno : Optional[int] + The line number of the first line of the script to be parsed + Returns + ------- + functions : PrimFunc or Module Review comment: Module->IRModule ########## File path: python/tvm/hybrid/scope_handler.py ########## @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hybrid Script Parser Scope Handler Functions + +This module provides the functions registered into parser under with_scope or for_scope category. +Scope handler nodes are StmtNodes with body, which are used to handle such scenarios. + +.. 
code-block:: python + + for x in tir.name(): + with tir.name(): + tir.name() # with scope handlers + concise scoping + +""" +# pylint: disable=redefined-builtin, unused-argument, invalid-name +import tvm.tir +from .registry import register_scope_handler +from .registry import Category + +# With scope handler + + +@register_scope_handler(Category.WITH_SCOPE, concise=False) +def Assert(parser, node, condition, message, body): + """ With scope handler function assert(condition, message, body) """ + + return tvm.tir.AssertStmt(condition, tvm.runtime.convert(message), body) + + +@register_scope_handler(Category.WITH_SCOPE, concise=False) +def let(parser, node, var, value, body): + """ With scope handler function let(var, value, body) """ + + return tvm.tir.LetStmt(var, value, body) + + +@register_scope_handler(Category.WITH_SCOPE, concise=True) +def realize(parser, node, buffer_bounds, body, condition=True): + """ With scope handler function realize(buffer_bounds, condition, body) """ + + buffer, bounds = buffer_bounds + return tvm.tir.BufferRealize(buffer, bounds, condition, body) + + +@register_scope_handler(Category.WITH_SCOPE, concise=True) +def attr(parser, node, attr_node, attr_key, value, body): + """ With scope handler function attr(attr_node, attr_key, value, body) """ + + return tvm.tir.AttrStmt(attr_node, attr_key, tvm.runtime.convert(value), body) + + +@register_scope_handler(Category.WITH_SCOPE, concise=True) +def allocate(parser, node, buffer_var, dtype, extents, body, condition=True): + """ With scope handler function allocate(buffer_var, dtype, extents, condition, body) """ + + return tvm.tir.Allocate(buffer_var, dtype, extents, tvm.runtime.convert(condition), body) + + +# For scope handler + + +@register_scope_handler(Category.FOR_SCOPE) +def range(parser, node, begin, end, for_type="serial"): Review comment: We can do it as a follow-up PR. Can we still support the normal `for x in range()` that translates to a serial range?
Is it possible to also register syntax via the scope handler? ########## File path: python/tvm/hybrid/parser.py ########## @@ -0,0 +1,754 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hybrid Script Parser For TIR""" +# pylint: disable=invalid-name, missing-docstring, inconsistent-return-statements, no-else-return +# pylint: disable=unnecessary-comprehension, unused-argument, import-outside-toplevel +# pylint: disable=unused-import +import json +import numbers +import operator +from typed_ast import ast3 as ast + +import tvm._ffi +from tvm import tir +from tvm._ffi.base import TVMError +from tvm.ir import GlobalVar +from tvm.tir import all as _all +from tvm.tir import expr as _expr + +from . import scope_emitter, special_stmt, scope_handler, intrin +from .meta_unparser import MetaUnparser +from .registry import Registry + + +class HybridParserError(RuntimeError): + """Hybrid Parser Runtime Error""" + + +class HybridParser(ast.NodeVisitor): + """Python AST visitor pass which finally lowers it to TIR + Notes for extension: + 1. To support new types of AST nodes. Add a function visit_xxx(). + 2. 
To support new functions + We divide allowed function calls in hybrid script into 3 categories, + which is scope_handler, intrin and special_stmt. + 1) scope_handler: scope_handler functions correspond to StmtNodes without body, which can be + further classified into 2 categories: with scope handler can for scope handlers + 2) intrin: intrin functions corresponds to the remaining IRNodes (StmtNodes without body, + PrimExprNodes and more) + 3) special_stmt: special_stmt functions don't correspond to an IRNode in the AST directly. + It is usually used for some information that is not suitable to be printed directly. + When visiting With node, we check with_scope registry. + When visiting For node, we check for_scope registry. + """ + + _binop_maker = { + ast.Add: tir.Add, + ast.Sub: tir.Sub, + ast.Mult: tir.Mul, + ast.Div: tir.Div, + ast.FloorDiv: tir.FloorDiv, + ast.Mod: tir.FloorMod, + ast.BitOr: operator.or_, + ast.BitAnd: operator.and_, + ast.BitXor: operator.xor, + ast.Gt: tir.GT, + ast.GtE: tir.GE, + ast.Lt: tir.LT, + ast.LtE: tir.LE, + ast.Eq: tir.EQ, + ast.NotEq: tir.NE, + ast.And: tir.And, + ast.Or: tir.Or, + } + + _unaryop_maker = { + ast.USub: operator.neg, + ast.Invert: operator.invert, + ast.Not: tir.Not + } + + def __init__(self, src, base_lienno): + self.params = None + self.buffer_map = None + self.dict_attr = None + self.scope_emitter = None + + self.src = src.split('\n') + self.base_lineno = base_lienno + self.current_lineno = 0 + self.current_col_offset = 0 + self.meta = None + + self.functions = {} + + self._in_with_func_arg = False + self._assign_target = None + + def init_function_parsing_env(self): + """Initialize function parsing environment""" + self.params = [] # parameter list + self.buffer_map = {} # buffer map + self.dict_attr = {} # dict attr + self.scope_emitter = scope_emitter.ScopeEmitter(self) # scope emitter + + @staticmethod + def is_meta(node): + """Judge whether an AST node is META""" + return isinstance(node, ast.Assign) and 
len(node.targets) == 1 \ + and isinstance(node.targets[0], ast.Name) and node.targets[0].id == "__tvm_meta__" + + def init_meta(self, meta_dict): + if meta_dict is not None: + self.meta = tvm.ir.load_json(json.dumps(meta_dict)) + + def visit(self, node): + """Override method in ast.NodeVisitor""" + old_lineno, old_col_offset = self.current_lineno, self.current_col_offset + + if hasattr(node, "lineno"): + self.current_lineno = self.base_lineno + node.lineno - 1 + if hasattr(node, "col_offset"): + self.current_col_offset = node.col_offset + + method = 'visit_' + node.__class__.__name__ + visitor = getattr(self, method, self.generic_visit) + visit_res = visitor(node) + + self.current_lineno, self.current_col_offset = old_lineno, old_col_offset + + return visit_res + + def wrap_line_col(self, message, lineno, col_offset): + """Wrap the message with line number and column offset""" + src_line = self.src[lineno - self.base_lineno] + leading_space = len(src_line) - len(src_line.lstrip(' ')) + col_offset = col_offset - leading_space + src_line = src_line[leading_space:] + return "\n " + src_line + "\n " + " " * col_offset + "^\n" + "ParserError in line " \ + + str(lineno) + " : " + message + + def report_error(self, message, lineno=None, col_offset=None): + """ Report an error occur in line lineno and column col_offset + Parameters + ---------- + message : str + Error message + lineno : int + Line number of error line + col_offset : int + Column offset of error line + """ + + if lineno is None: + lineno = self.current_lineno + if col_offset is None: + col_offset = self.current_col_offset + raise HybridParserError(self.wrap_line_col(message, lineno, col_offset)) + + def get_type_name(self, vtype): + if isinstance(vtype, ast.Attribute) \ + and isinstance(vtype.value, ast.Name) and vtype.value.id == 'ty': + return vtype.attr + self.report_error("invalid type annotation") + + def get_body(self): + body = [] + while len(self.scope_emitter.node_stack[-1]) > 0: + res = 
self.visit(self.scope_emitter.node_stack[-1].pop()) + if res is not None: + body.append(res) + return tvm.tir.SeqStmt(body) if len(body) > 1 else body[0] + + def parse_type(self, vtype): + """ Parse type annotation AST into Type object """ + if isinstance(vtype, ast.NameConstant) and vtype.value is None: + return tvm.ir.TupleType([]) + elif isinstance(vtype, ast.Attribute): + return tvm.ir.PrimType(self.get_type_name(vtype)) + elif isinstance(vtype, ast.Subscript) and isinstance(vtype.slice, ast.Index): + type_name = self.get_type_name(vtype.value) + if isinstance(vtype.slice.value, ast.Tuple): + args = [self.parse_type(element) for element in vtype.slice.value.elts] + else: + args = [self.parse_type(vtype.slice.value)] + if type_name == "Ptr": + return tvm.ir.PointerType(*args) + elif type_name == "Tuple": + return tvm.ir.TupleType(args) + + self.report_error("invalid type annotation") + + def generic_visit(self, node): + """ Override method in ast.NodeVisitor. + To directly filter out invalidate type of stmt. + """ + + self.report_error(type(node).__name__ + " stmt is not supported now") + + def visit_Module(self, node): + """ Module visitor + AST abstract grammar: + Module(stmt* body, type_ignore* type_ignore) + By now we support two format of hybrid script shown below. + + Example + ------- + 1. Generate a Function(If the code is printed, then it may bring meta) + .. code-block:: python + + import tvm + + @tvm.hybrid.script + def A(...): + ... + + # call hybrid parser when call this function, get a Function + func = A + + 2. Generate a Module + .. code-block:: python + + import tvm + + @tvm.hybrid.script + class MyMod(): + def A(...): + ... + + def B(...): + ... + + __tvm_meta__ = ... 
+ + # call hybrid parser during construction, get a Module + mod = MyMod + """ + + if len(node.body) == 1 and isinstance(node.body[0], (ast.ClassDef, ast.FunctionDef)): + # class or single function + return self.visit(node.body[0]) + elif len(node.body) == 2: + if isinstance(node.body[0], ast.Assign): + node.body[0], node.body[1] = node.body[1], node.body[0] + if isinstance(node.body[0], ast.FunctionDef) and HybridParser.is_meta(node.body[1]): + # function with meta + self.init_meta(MetaUnparser().visit(node.body[1].value)) + return self.visit(node.body[0]) + self.report_error( + "Only one-function, one-class or function-with-meta source code is allowed") + + def visit_ClassDef(self, node): + """ ClassDef visitor + AST abstract grammar: + ClassDef(identifier name, expr* bases, keyword* keywords, stmt* body, + expr* decorator_list) + """ + + # parse meta + count = False + for body_element in node.body: + if isinstance(body_element, ast.FunctionDef): + pass + elif HybridParser.is_meta(body_element) and not count: + count = True + self.init_meta(MetaUnparser().visit(body_element.value)) + else: + self.report_error("invalid class member") + + # parse member functions + for body_element in node.body: + if isinstance(body_element, ast.FunctionDef): + self.visit(body_element) + from .utils import create_module + return create_module(self.functions) + + def visit_FunctionDef(self, node): + """ FunctionDef visitor + AST abstract grammar: + FunctionDef(identifier name, arguments args, stmt* body, expr* decorator_list, + expr? returns, string? type_comment) + arguments = (arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs, + expr* kw_defaults, arg? kwarg, expr* defaults) + arg = (identifier arg, expr? annotation, string? 
type_comment) + """ + + self.init_function_parsing_env() + # add parameters of function + for arg in node.args.args: + arg_var = tvm.te.var(arg.arg, self.parse_type(arg.annotation)) + self.scope_emitter.update_symbol(arg.arg, arg_var) + self.params.append(arg_var) + + # visit the body of function + self.scope_emitter.node_stack[-1].extend(reversed(node.body)) + + # fetch the body and return a tir.PrimFunc + func = tvm.tir.PrimFunc(self.params, self.get_body(), + ret_type=self.parse_type(node.returns), + buffer_map=self.buffer_map, + attrs=tvm.ir.make_node("DictAttrs", **self.dict_attr)) + self.functions[GlobalVar(node.name)] = func + return func + + def visit_Assign(self, node): + """ Assign visitor + AST abstract grammar: + Assign(expr* targets, expr value, string? type_comment) + By now only 2 types of Assign is supported: + 1. special stmts that appear as assign stmt + 1.1 Buffer = tir.buffer_bind()/tir.buffer_decl() + 1.2 Var = tir.var() + 2. (BufferStore) Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr + 3. 
(Store) Var[PrimExpr] = PrimExpr + """ + + if not len(node.targets) == 1: + self.report_error("Only one-valued assignment is supported now") + target = node.targets[0] + + if isinstance(target, ast.Name): + # scenario 1 + self._assign_target = target.id + rhs = self.visit(node.value) + if not isinstance(node.value, ast.Call): + self.report_error("Unsupported Assign stmt") + self.scope_emitter.update_symbol(target.id, rhs) + elif isinstance(target, ast.Subscript): + # scenario 2&3 + symbol, indexes = self.visit(target) + self._assign_target = (symbol, indexes) + rhs = self.visit(node.value) + if isinstance(symbol, tvm.tir.Buffer): + return tvm.tir.BufferStore(symbol, tvm.runtime.convert(rhs), indexes) + else: + if len(indexes) != 1: + self.report_error("Invalid Store stmt") + return tvm.tir.Store(symbol, tvm.runtime.convert(rhs), indexes[0], + tvm.runtime.convert(True)) + else: + self.report_error("Unsupported Assign stmt") + + def visit_AnnAssign(self, node): + """ AnnAssign visitor + AST abstract grammar: + AnnAssign(expr target, expr annotation, expr? value, int simple) + Corresponds to concise mode of with tir.let() + """ + + if isinstance(node.target, ast.Name): + value = self.visit(node.value) + var = tvm.te.var(node.target.id, self.parse_type(node.annotation)) + self.scope_emitter.update_symbol(var.name, var) + return tvm.tir.LetStmt(var, value, self.visit(self.scope_emitter.node_stack[-1].pop())) + else: + self.report_error("Unsupported AnnAssign stmt") + + def visit_Assert(self, node): + """ Assert visitor + AST abstract grammar: + Assert(expr test, expr? 
msg) + Corresponds to concise mode of with tir.assert() + """ + + condition = self.visit(node.test) + if node.msg is None: + self.report_error("Message of AssertStmt can't be None") + message = self.visit(node.msg) + return tvm.tir.AssertStmt(condition, tvm.runtime.convert(message), self.get_body()) + + def visit_For(self, node): + """ For visitor + AST abstract grammar: + For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment) + By now only 1 type of For is supported: + 1. for name in tir.range(begin, end, for_type) + """ + + if not isinstance(node.target, ast.Name): + self.report_error("The loop variable should be a name variable") + # check node.iter, which is a tir Call + if not isinstance(node.iter, ast.Call): + self.report_error("The loop iter should be a Call") + if not isinstance(node.iter.func, ast.Attribute) \ + or not isinstance(node.iter.func.value, ast.Name) \ + or node.iter.func.value.id != "tir": + self.report_error("The loop iter Call should be tir.name()") + + func_name = node.iter.func.attr + # collect arguments + args = [self.visit(arg) for arg in node.iter.args] + kw_args = [self.visit(keyword) for keyword in node.iter.keywords] + kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args} + # All the functions supported in For stmt are registered in scope_handler.ForScope + if func_name not in Registry.for_scope.keys(): + self.report_error("Function " + func_name + " used in For stmt is not supported now", + self.current_lineno, + node.iter.col_offset) + + old_lineno, old_col_offset = self.current_lineno, self.current_col_offset + self.current_lineno, self.current_col_offset = \ + self.base_lineno + node.iter.lineno - 1, node.iter.col_offset + res = Registry.for_scope.get(func_name)(self, node, args, kw_args) + self.current_lineno, self.current_col_offset = old_lineno, old_col_offset + return res + + def visit_With(self, node): + """ With visitor + AST abstract grammar: + With(withitem* items, stmt* body, string? 
type_comment) + withitem = (expr context_expr, expr? optional_vars) + By now only 1 type of With is supported: + 1. with tir.let/tir.Assert()/tir.attr()/tir.allocate()/tir.realize() + """ + + if len(node.items) != 1: + self.report_error("Only one with element is supported now") + if not isinstance(node.items[0].context_expr, ast.Call): + self.report_error("The context expression of with should be a Call") + func_call = node.items[0].context_expr + if not isinstance(func_call.func, ast.Attribute) \ + or not isinstance(func_call.func.value, ast.Name) \ + or func_call.func.value.id != "tir": + self.report_error("The context expression of with should be tir.name()") + + func_name = func_call.func.attr + # collect arguments + args = [self.visit(arg) for arg in func_call.args] + kw_args = [self.visit(keyword) for keyword in func_call.keywords] + kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args} + if func_name not in Registry.with_scope.keys(): + self.report_error("Function " + func_name + " used in With stmt is not supported now") + + # All the functions supported in With stmt are registered in scope_handler.WithScope + old_lineno, old_col_offset = self.current_lineno, self.current_col_offset + self.current_lineno, self.current_col_offset = \ + self.base_lineno + func_call.lineno - 1, func_call.col_offset + res = Registry.with_scope.get(func_name)(self, node, args, kw_args) + self.current_lineno, self.current_col_offset = old_lineno, old_col_offset + return res + + def visit_If(self, node): + """ If visitor + AST abstract grammar: + If(expr test, stmt* body, stmt* orelse) + """ + + condition = self.visit(node.test) + # then body + self.scope_emitter.new_scope() + self.scope_emitter.node_stack[-1].extend(reversed(node.body)) + then_body = self.get_body() + self.scope_emitter.pop_scope() + + # else body + if len(node.orelse) > 0: + self.scope_emitter.new_scope() + self.scope_emitter.node_stack[-1].extend(reversed(node.orelse)) + else_body = self.get_body() + 
self.scope_emitter.pop_scope() + else: + else_body = None + return tvm.tir.IfThenElse(condition, then_body, else_body) + + def visit_Call(self, node): + """ Call visitor + AST abstract grammar: + Call(expr func, expr* args, keyword* keywords) + keyword = (identifier? arg, expr value) + """ + + # collect arguments + args = [self.visit(arg) for arg in node.args] + kw_args = [self.visit(keyword) for keyword in node.keywords] + kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args} + + maybe_intrin = False + if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Name): + if node.func.value.id == "tir": + func_name = node.func.attr + maybe_intrin = True + else: + self.report_error("Unsupported Attribute typed function call") + else: + self.report_error("Unsupported function call") + + if func_name in Registry.special_stmt.keys(): + return Registry.special_stmt.get(func_name)(self, node, args, kw_args) + if func_name in Registry.intrin.keys(): + return Registry.intrin.get(func_name)(self, node, args, kw_args) + if func_name in Registry.with_scope.keys(): + return Registry.with_scope.get(func_name)(self, node, args, kw_args) + if maybe_intrin: + return tvm.tir.Call(kw_args["dtype"], tvm.ir.op.Op.get("tir." + func_name), args) + + self.report_error("Function " + func_name + " is not supported now") + + def visit_Expr(self, node): + """ Expr visitor + AST abstract grammar: + Expr(expr value) + + Now only 2 types of Expr stmt is allowed: + 1. Concise mode of with scope handlers + tir.attr()/tir.assert()/tir.allocate()/tir.realize() + 2. 
special stmts appear as a call + tir.set_func_attr() + """ + + if not isinstance(node.value, ast.Call): + self.report_error("Unsupported Expr stmt") + return self.visit(node.value) + + def visit_BinOp(self, node): + """ BinOp visitor + AST abstract grammar: + BinOp(expr left, operator op, expr right) + """ + + lhs = self.visit(node.left) + rhs = self.visit(node.right) + if not isinstance(node.op, tuple(HybridParser._binop_maker.keys())): + self.report_error("BinOp " + str(type(node.op)) + " is not supported now") + return HybridParser._binop_maker[type(node.op)](lhs, rhs) + + def visit_Compare(self, node): + """ Compare visitor + AST abstract grammar: + Compare(expr left, expr right, ops=) + """ + + ops = [self.visit(node.left)] + ops += [self.visit(comparator) for comparator in node.comparators] + res = [] + for i in range(len(node.ops)): + lhs = ops[i] + rhs = ops[i + 1] + res.append(HybridParser._binop_maker[type(node.ops[i])](lhs, rhs)) + return _all(*res) + + def visit_BoolOp(self, node): + """ BoolOp visitor + AST abstract grammar: + BoolOp(boolop op, expr* values) + """ + + values = [self.visit(value) for value in node.values] + return HybridParser._binop_maker[type(node.op)](*values) + + def visit_UnaryOp(self, node): + """ UnaryOp visitor + AST abstract grammar: + UnaryOp(unaryop op, expr operand) + """ + + operand = self.visit(node.operand) + if not isinstance(node.op, tuple(HybridParser._unaryop_maker.keys())): + self.report_error("UnaryOp " + str(type(node.op)) + " is not supported now") + return HybridParser._unaryop_maker[type(node.op)](operand) + + def visit_Subscript(self, node): + """ Subscript visitor + AST abstract grammar: + Subscript(expr value, slice slice, expr_context ctx) + slice = Slice(expr? lower, expr? upper, expr? step) + | ExtSlice(slice* dims) + | Index(expr value) + By now only 2 types of Subscript are supported: + 1. Buffer[index, index, ...], Buffer element access(BufferLoad & BufferStore) + Var[index] Buffer element access() + 2. 
meta[type_key][index], Meta info access + """ + + if isinstance(node.value, (ast.Name, ast.Attribute)): + symbol = self.visit(node.value) + if isinstance(node.slice, ast.Index): + # BufferLoad & BufferStore + if isinstance(node.slice.value, ast.Tuple): + # Buffer/Var[index, index, ...] + indexes = [self.visit(element) for element in node.slice.value.elts] + else: + # Buffer/Var[index] + indexes = [self.visit(node.slice.value)] + if isinstance(node.ctx, ast.Load): + if isinstance(symbol, tir.expr.Var): + return tvm.tir.Load("float32", symbol, indexes, True) + else: + return tvm.tir.BufferLoad(symbol, indexes) + else: + return symbol, indexes + else: + # Buffer Region, now used in tir.realize(buffer[bounds]) + doms = [] + slice_nodes = [] + if isinstance(node.slice, ast.Slice): + # Buffer[begin:end] + slice_nodes.append(node.slice) + elif isinstance(node.slice, ast.ExtSlice): + # Buffer[begin:end, begin:end] + slice_nodes.extend(node.slice.dims) + + for dim in slice_nodes: + if not hasattr(dim, "step"): + self.report_error("slice of Buffer Region ought to be begin:end") + if dim.step is not None: + self.report_error("step is not allowed in Buffer Region") + upper = self.visit(dim.upper) + lower = self.visit(dim.lower) + extent = upper - lower + if isinstance(extent, _expr.PrimExpr): + ana = tvm.arith.Analyzer() + extent = ana.simplify(extent) + doms.append(tvm.ir.Range.from_min_extent(lower, extent)) + return symbol, doms + + elif isinstance(node.value, ast.Subscript) and isinstance(node.value.value, ast.Name) \ + and node.value.value.id == 'meta': + # meta[type_key][index] + if not (isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Num)) \ + or not (isinstance(node.value.slice, ast.Index) \ + and isinstance(node.value.slice.value, ast.Name)): + self.report_error("The meta access format ought to be meta[type_key][index]") + type_key = node.value.slice.value.id + index = node.slice.value.n + node_list = self.meta[type_key] + if node_list is None: 
+ self.report_error("type_key " + type_key + " in meta not found") + if len(node_list) <= index: + self.report_error("index " + index + " out of range " + len(node_list)) + return node_list[index] + else: + self.report_error("Only buffer variable and meta can be subscriptable") + + def visit_Name(self, node): + """ Name visitor + AST abstract grammar: + Name(identifier id, expr_context ctx) + """ + + name = node.id + symbol = self.scope_emitter.lookup_symbol(name) + if symbol is None: + self.report_error("Unknown symbol %s" % name) + return symbol + + def visit_Attribute(self, node): + """ Attribute visitor + AST abstract grammar: + Attribute(expr value, identifier attr, expr_context ctx) + """ + + if not isinstance(node.value, ast.Name): + self.report_error("The value of Attribute ought to a Name") + name = node.value.id + symbol = self.scope_emitter.lookup_symbol(name) + if symbol is None or not isinstance(symbol, tvm.tir.Buffer): + self.report_error("Unsupported Attribute expression") + if not hasattr(symbol, node.attr): + self.report_error("Type " + type(symbol) + " has not attr " + node.attr) + return getattr(symbol, node.attr) + + def visit_Dict(self, node): + """ Dict visitor + AST abstract grammar: + Dict(expr* keys, expr* values) + """ + + keys = [self.visit(key) for key in node.keys] + values = [self.visit(value) for value in node.values] + + return {key: value for key, value in zip(keys, values)} + + def visit_Tuple(self, node): + """ Tuple visitor + AST abstract grammar: + Tuple(expr* elts, expr_context ctx) + """ + + return tuple(self.visit(element) for element in node.elts) + + def visit_List(self, node): + """ List visitor + AST abstract grammar: + List(expr* elts, expr_context ctx) + """ + + return [self.visit(element) for element in node.elts] + + def visit_keyword(self, node): + """ Keyword visitor + AST abstract grammar: + keyword = (identifier? 
arg, expr value) + """ + + return node.arg, self.visit(node.value) + + def visit_NameConstant(self, node): + return tvm.runtime.convert(node.value) + + def visit_Constant(self, node): + return tvm.runtime.convert(node.value) + + def visit_Num(self, node): + if isinstance(node.n, numbers.Integral): + dtype = "int32" + elif isinstance(node.n, float): + dtype = "float32" + else: + self.report_error("The data type should be one of (int, float)") + return tvm.tir.const(node.n, dtype) + + def visit_Str(self, node): + return node.s + + +def source_to_op(src, func_lineno=0): Review comment: just to from_source, to avoid conflict in names ########## File path: tests/python/unittest/test_hybrid_roundtrip.py ########## @@ -0,0 +1,539 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import tvm +from tvm import tir +from tvm.hybrid import ty + + [email protected] +class Module1: + def mmult(A: ty.handle, B: ty.handle, C: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "mmult", "tir.noalias": True}) + # buffer definition + C_global = tir.buffer_decl([1024, 1024], elem_offset=0, align=128, offset_factor=1) + packedB = tir.buffer_decl([32, 1024, 32], elem_offset=0, align=128, offset_factor=1) + A_1 = tir.buffer_bind(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + B_1 = tir.buffer_bind(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + C_1 = tir.buffer_bind(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + # body + tir.attr(packedB, "realize_scope", "") + tir.realize(packedB[0:32, 0:1024, 0:32]) + for x in tir.range(0, 32, "parallel"): + for y in tir.range(0, 1024): + for z in tir.range(0, 32, "vectorized"): + packedB[x, y, z] = B_1[y, ((x*32) + z)] + tir.attr(C_1, "realize_scope", "") + tir.realize(C_1[0:1024, 0:1024]) + for x_outer in tir.range(0, 32, "parallel"): + for y_outer in tir.range(0, 32): + tir.attr(C_global, "realize_scope", "global") + tir.realize(C_global[(x_outer*32):((x_outer*32) + 32), (y_outer*32):((y_outer*32) + 32)]) + for x_c_init in tir.range(0, 32): + for y_c_init in tir.range(0, 32, "vectorized"): + C_global[(x_c_init + (x_outer*32)), (y_c_init + (y_outer*32))] = tir.float32(0) + for k_outer in tir.range(0, 256): + for x_c in tir.range(0, 32): + for k_inner in tir.range(0, 4, "unroll"): + for y_c in tir.range(0, 32, "vectorized"): + C_global[(x_c + (x_outer*32)), (y_c + (y_outer*32))] = (C_global[(x_c + (x_outer*32)), (y_c + (y_outer*32))] + (A_1[(x_c + (x_outer*32)), (k_inner + (k_outer*4))]*packedB[tir.floordiv((y_c + (y_outer*32)), 32), (k_inner + (k_outer*4)), tir.floormod((y_c + (y_outer*32)), 32)])) + for x_inner in tir.range(0, 32): + for y_inner in tir.range(0, 32): + C_1[(x_inner + (x_outer*32)), (y_inner + (y_outer*32))] = C_global[(x_inner + 
(x_outer*32)), (y_inner + (y_outer*32))] + + +def test_opt_gemm_normalize(): + mod = Module1() + rt_mod = tvm.hybrid.from_source(tvm.hybrid.ashybrid(mod, True)) + tvm.ir.assert_structural_equal(mod, rt_mod, True) + + [email protected] +class Module2: + def mmult(A: ty.handle, B: ty.handle, C: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "mmult", "tir.noalias": True}) + # var definition + C_global = tir.var("handle") + packedB = tir.var("handle") + A_1 = tir.buffer_bind(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + B_1 = tir.buffer_bind(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + C_1 = tir.buffer_bind(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1) + # body + tir.attr(packedB, "storage_scope", "global") + tir.allocate(packedB, "float32x32", [32768]) + tir.attr(C_global, "storage_scope", "global") + tir.allocate(C_global, "float32", [1024]) + for x in tir.range(0, 32, "parallel"): + for y in tir.range(0, 1024): + tir.store(packedB, tir.ramp(((x*32768) + (y*32)), 1, 32), tir.load("float32x32", B_1.data, tir.ramp(((y*1024) + (x*32)), 1, 32), tir.broadcast(True, 32)), tir.broadcast(True, 32)) + for x_outer in tir.range(0, 32): + for y_outer in tir.range(0, 32): + for x_c_init in tir.range(0, 32): + tir.store(C_global, tir.ramp((x_c_init*32), 1, 32), tir.broadcast(tir.float32(0), 32), tir.broadcast(True, 32)) + for k_outer in tir.range(0, 256): + for x_c in tir.range(0, 32): + tir.store(C_global, tir.ramp((x_c*32), 1, 32), (tir.load("float32x32", C_global, tir.ramp((x_c*32), 1, 32), tir.broadcast(True, 32)) + (tir.broadcast(tir.load("float32", A_1.data, (((x_outer*32768) + (x_c*1024)) + (k_outer*4))), 32)*tir.load("float32x32", packedB, tir.ramp(((y_outer*32768) + (k_outer*128)), 1, 32), tir.broadcast(True, 32)))), tir.broadcast(True, 32)) + tir.store(C_global, tir.ramp((x_c*32), 1, 32), (tir.load("float32x32", C_global, tir.ramp((x_c*32), 1, 32), tir.broadcast(True, 32)) + 
(tir.broadcast(tir.load("float32", A_1.data, ((((x_outer*32768) + (x_c*1024)) + (k_outer*4)) + 1)), 32)*tir.load("float32x32", packedB, tir.ramp((((y_outer*32768) + (k_outer*128)) + 32), 1, 32), tir.broadcast(True, 32)))), tir.broadcast(True, 32)) + tir.store(C_global, tir.ramp((x_c*32), 1, 32), (tir.load("float32x32", C_global, tir.ramp((x_c*32), 1, 32), tir.broadcast(True, 32)) + (tir.broadcast(tir.load("float32", A_1.data, ((((x_outer*32768) + (x_c*1024)) + (k_outer*4)) + 2)), 32)*tir.load("float32x32", packedB, tir.ramp((((y_outer*32768) + (k_outer*128)) + 64), 1, 32), tir.broadcast(True, 32)))), tir.broadcast(True, 32)) + tir.store(C_global, tir.ramp((x_c*32), 1, 32), (tir.load("float32x32", C_global, tir.ramp((x_c*32), 1, 32), tir.broadcast(True, 32)) + (tir.broadcast(tir.load("float32", A_1.data, ((((x_outer*32768) + (x_c*1024)) + (k_outer*4)) + 3)), 32)*tir.load("float32x32", packedB, tir.ramp((((y_outer*32768) + (k_outer*128)) + 96), 1, 32), tir.broadcast(True, 32)))), tir.broadcast(True, 32)) + for x_inner in tir.range(0, 32): + for y_inner in tir.range(0, 32): + C_1.data[((((x_outer*32768) + (x_inner*1024)) + (y_outer*32)) + y_inner)] = tir.load("float32", C_global, ((x_inner*32) + y_inner)) + + +def test_opt_gemm_lower(): + mod = Module2() + rt_mod = tvm.hybrid.from_source(tvm.hybrid.ashybrid(mod, True)) + tvm.ir.assert_structural_equal(mod, rt_mod, True) + + [email protected] +class Module3: + def mmult(args: ty.handle, arg_type_ids: ty.handle, num_args: ty.int32, out_ret_value: ty.handle, out_ret_tcode: ty.handle) -> ty.int32: + # function attr dict + tir.func_attr({"tir.noalias": True, "global_symbol": "mmult", "tir.is_entry_func": True, "calling_conv": 1}) + # var definition + C_global = tir.var("handle") + packedB = tir.var("handle") + # body + assert (num_args == 3), "mmult: num_args should be 3" + arg0: ty.handle = tir.tvm_struct_get(args, 0, 12, dtype="handle") + arg0_code: ty.int32 = tir.load("int32", arg_type_ids, 0) + arg1: ty.handle = 
tir.tvm_struct_get(args, 1, 12, dtype="handle") + arg1_code: ty.int32 = tir.load("int32", arg_type_ids, 1) + arg2: ty.handle = tir.tvm_struct_get(args, 2, 12, dtype="handle") + arg2_code: ty.int32 = tir.load("int32", arg_type_ids, 2) + A: ty.handle = tir.tvm_struct_get(arg0, 0, 1, dtype="handle") + tir.attr(A, "storage_alignment", 128) + arg0_shape: ty.handle = tir.tvm_struct_get(arg0, 0, 2, dtype="handle") + arg0_strides: ty.handle = tir.tvm_struct_get(arg0, 0, 3, dtype="handle") + dev_id: ty.int32 = tir.tvm_struct_get(arg0, 0, 9, dtype="int32") + B: ty.handle = tir.tvm_struct_get(arg1, 0, 1, dtype="handle") + tir.attr(B, "storage_alignment", 128) + arg1_shape: ty.handle = tir.tvm_struct_get(arg1, 0, 2, dtype="handle") + arg1_strides: ty.handle = tir.tvm_struct_get(arg1, 0, 3, dtype="handle") + C: ty.handle = tir.tvm_struct_get(arg2, 0, 1, dtype="handle") + tir.attr(C, "storage_alignment", 128) + arg2_shape: ty.handle = tir.tvm_struct_get(arg2, 0, 2, dtype="handle") + arg2_strides: ty.handle = tir.tvm_struct_get(arg2, 0, 3, dtype="handle") + assert ((((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or (arg0_code == 4)), "mmult: Expect arg[0] to be pointer" + assert ((((arg1_code == 3) or (arg1_code == 13)) or (arg1_code == 7)) or (arg1_code == 4)), "mmult: Expect arg[1] to be pointer" + assert ((((arg2_code == 3) or (arg2_code == 13)) or (arg2_code == 7)) or (arg2_code == 4)), "mmult: Expect arg[2] to be pointer" + assert (2 == tir.tvm_struct_get(arg0, 0, 4, dtype="int32")), "arg0.ndim is expected to equal 2" + assert (2 == tir.tvm_struct_get(arg0, 0, 4, dtype="int32")), "arg0.ndim is expected to equal 2" + assert (((tir.tvm_struct_get(arg0, 0, 5, dtype="uint8") == tir.uint8(2)) and (tir.tvm_struct_get(arg0, 0, 6, dtype="uint8") == tir.uint8(32))) and (tir.tvm_struct_get(arg0, 0, 7, dtype="uint16") == tir.uint16(1))), "arg0.dtype is expected to be float32" + assert (1024 == tir.cast("int32", tir.load("int64", arg0_shape, 0))), "Argument arg0.shape[0] 
has an unsatisfied constraint" + assert (1024 == tir.cast("int32", tir.load("int64", arg0_shape, 1))), "Argument arg0.shape[1] has an unsatisfied constraint" + if not (tir.isnullptr(arg0_strides, dtype="bool")): + assert ((1 == tir.cast("int32", tir.load("int64", arg0_strides, 1))) and (1024 == tir.cast("int32", tir.load("int64", arg0_strides, 0)))), "arg0.strides: expected to be compact array" + tir.evaluate(0) + assert (tir.uint64(0) == tir.tvm_struct_get(arg0, 0, 8, dtype="uint64")), "Argument arg0.byte_offset has an unsatisfied constraint" + assert (1 == tir.tvm_struct_get(arg0, 0, 10, dtype="int32")), "Argument arg0.device_type has an unsatisfied constraint" + assert (2 == tir.tvm_struct_get(arg1, 0, 4, dtype="int32")), "arg1.ndim is expected to equal 2" + assert (2 == tir.tvm_struct_get(arg1, 0, 4, dtype="int32")), "arg1.ndim is expected to equal 2" + assert (((tir.tvm_struct_get(arg1, 0, 5, dtype="uint8") == tir.uint8(2)) and (tir.tvm_struct_get(arg1, 0, 6, dtype="uint8") == tir.uint8(32))) and (tir.tvm_struct_get(arg1, 0, 7, dtype="uint16") == tir.uint16(1))), "arg1.dtype is expected to be float32" + assert (1024 == tir.cast("int32", tir.load("int64", arg1_shape, 0))), "Argument arg1.shape[0] has an unsatisfied constraint" + assert (1024 == tir.cast("int32", tir.load("int64", arg1_shape, 1))), "Argument arg1.shape[1] has an unsatisfied constraint" + if not (tir.isnullptr(arg1_strides, dtype="bool")): + assert ((1 == tir.cast("int32", tir.load("int64", arg1_strides, 1))) and (1024 == tir.cast("int32", tir.load("int64", arg1_strides, 0)))), "arg1.strides: expected to be compact array" + tir.evaluate(0) + assert (tir.uint64(0) == tir.tvm_struct_get(arg1, 0, 8, dtype="uint64")), "Argument arg1.byte_offset has an unsatisfied constraint" + assert (1 == tir.tvm_struct_get(arg1, 0, 10, dtype="int32")), "Argument arg1.device_type has an unsatisfied constraint" + assert (dev_id == tir.tvm_struct_get(arg1, 0, 9, dtype="int32")), "Argument arg1.device_id has an 
unsatisfied constraint" + assert (2 == tir.tvm_struct_get(arg2, 0, 4, dtype="int32")), "arg2.ndim is expected to equal 2" + assert (2 == tir.tvm_struct_get(arg2, 0, 4, dtype="int32")), "arg2.ndim is expected to equal 2" + assert (((tir.tvm_struct_get(arg2, 0, 5, dtype="uint8") == tir.uint8(2)) and (tir.tvm_struct_get(arg2, 0, 6, dtype="uint8") == tir.uint8(32))) and (tir.tvm_struct_get(arg2, 0, 7, dtype="uint16") == tir.uint16(1))), "arg2.dtype is expected to be float32" + assert (1024 == tir.cast("int32", tir.load("int64", arg2_shape, 0))), "Argument arg2.shape[0] has an unsatisfied constraint" + assert (1024 == tir.cast("int32", tir.load("int64", arg2_shape, 1))), "Argument arg2.shape[1] has an unsatisfied constraint" + if not (tir.isnullptr(arg2_strides, dtype="bool")): + assert ((1 == tir.cast("int32", tir.load("int64", arg2_strides, 1))) and (1024 == tir.cast("int32", tir.load("int64", arg2_strides, 0)))), "arg2.strides: expected to be compact array" + tir.evaluate(0) + assert (tir.uint64(0) == tir.tvm_struct_get(arg2, 0, 8, dtype="uint64")), "Argument arg2.byte_offset has an unsatisfied constraint" + assert (1 == tir.tvm_struct_get(arg2, 0, 10, dtype="int32")), "Argument arg2.device_type has an unsatisfied constraint" + assert (dev_id == tir.tvm_struct_get(arg2, 0, 9, dtype="int32")), "Argument arg2.device_id has an unsatisfied constraint" + tir.attr(0, "compute_scope", "mmult_compute_") + tir.attr(packedB, "storage_scope", "global") + tir.attr(packedB, "storage_alignment", 128) + with tir.let(packedB, tir.TVMBackendAllocWorkspace(1, dev_id, tir.uint64(4194304), 2, 32, dtype="handle")): + if tir.isnullptr(packedB, dtype="bool"): + tir.evaluate(tir.tvm_throw_last_error(dtype="int32")) + for x in tir.range(0, 32, "parallel"): + for y in tir.range(0, 1024): + tir.store(packedB, tir.ramp(((x*32768) + (y*32)), 1, 32), tir.load("float32x32", B, tir.ramp(((y*1024) + (x*32)), 1, 32), tir.broadcast(True, 32)), tir.broadcast(True, 32)) + for x_outer in tir.range(0, 32, 
"parallel"): + tir.attr(C_global, "storage_scope", "global") + tir.attr(C_global, "storage_alignment", 128) + with tir.let(C_global, tir.TVMBackendAllocWorkspace(1, dev_id, tir.uint64(4096), 2, 32, dtype="handle")): + if tir.isnullptr(C_global, dtype="bool"): + tir.evaluate(tir.tvm_throw_last_error(dtype="int32")) + for y_outer in tir.range(0, 32): + for x_c_init in tir.range(0, 32): + tir.store(C_global, tir.ramp((x_c_init*32), 1, 32), tir.broadcast(tir.float32(0), 32), tir.broadcast(True, 32)) + for k_outer in tir.range(0, 256): + for x_c in tir.range(0, 32): + tir.store(C_global, tir.ramp((x_c*32), 1, 32), tir.call_llvm_pure_intrin(tir.uint32(97), tir.uint32(3), tir.broadcast(tir.load("float32", A, (((x_outer*32768) + (x_c*1024)) + (k_outer*4))), 32), tir.load("float32x32", packedB, tir.ramp(((y_outer*32768) + (k_outer*128)), 1, 32), tir.broadcast(True, 32)), tir.load("float32x32", C_global, tir.ramp((x_c*32), 1, 32), tir.broadcast(True, 32)), dtype="float32x32"), tir.broadcast(True, 32)) + tir.store(C_global, tir.ramp((x_c*32), 1, 32), tir.call_llvm_pure_intrin(tir.uint32(97), tir.uint32(3), tir.broadcast(tir.load("float32", A, ((((x_outer*32768) + (x_c*1024)) + (k_outer*4)) + 1)), 32), tir.load("float32x32", packedB, tir.ramp((((y_outer*32768) + (k_outer*128)) + 32), 1, 32), tir.broadcast(True, 32)), tir.load("float32x32", C_global, tir.ramp((x_c*32), 1, 32), tir.broadcast(True, 32)), dtype="float32x32"), tir.broadcast(True, 32)) + tir.store(C_global, tir.ramp((x_c*32), 1, 32), tir.call_llvm_pure_intrin(tir.uint32(97), tir.uint32(3), tir.broadcast(tir.load("float32", A, ((((x_outer*32768) + (x_c*1024)) + (k_outer*4)) + 2)), 32), tir.load("float32x32", packedB, tir.ramp((((y_outer*32768) + (k_outer*128)) + 64), 1, 32), tir.broadcast(True, 32)), tir.load("float32x32", C_global, tir.ramp((x_c*32), 1, 32), tir.broadcast(True, 32)), dtype="float32x32"), tir.broadcast(True, 32)) + tir.store(C_global, tir.ramp((x_c*32), 1, 32), 
tir.call_llvm_pure_intrin(tir.uint32(97), tir.uint32(3), tir.broadcast(tir.load("float32", A, ((((x_outer*32768) + (x_c*1024)) + (k_outer*4)) + 3)), 32), tir.load("float32x32", packedB, tir.ramp((((y_outer*32768) + (k_outer*128)) + 96), 1, 32), tir.broadcast(True, 32)), tir.load("float32x32", C_global, tir.ramp((x_c*32), 1, 32), tir.broadcast(True, 32)), dtype="float32x32"), tir.broadcast(True, 32)) + for x_inner in tir.range(0, 32): + for y_inner in tir.range(0, 32): + C[((((x_outer*32768) + (x_inner*1024)) + (y_outer*32)) + y_inner)] = tir.load("float32", C_global, ((x_inner*32) + y_inner)) + if (tir.TVMBackendFreeWorkspace(1, dev_id, C_global, dtype="int32") != 0): + tir.evaluate(tir.tvm_throw_last_error(dtype="int32")) + if (tir.TVMBackendFreeWorkspace(1, dev_id, packedB, dtype="int32") != 0): + tir.evaluate(tir.tvm_throw_last_error(dtype="int32")) + + +def test_opt_gemm_mod_host(): + mod = Module3() + rt_mod = tvm.hybrid.from_source(tvm.hybrid.ashybrid(mod, True)) + tvm.ir.assert_structural_equal(mod, rt_mod, True) + + [email protected] +class Module4: + def default_function(A: ty.handle, W: ty.handle, Conv: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + # var definition + blockIdx_x = tir.var("int32") + blockIdx_y = tir.var("int32") + blockIdx_z = tir.var("int32") + threadIdx_x = tir.var("int32") + threadIdx_y = tir.var("int32") + threadIdx_z = tir.var("int32") + # buffer definition + Apad_shared = tir.buffer_decl([16, 16, 16, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1) + Apad_shared_wmma_matrix_a = tir.buffer_decl([16, 16, 16, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1) + BA = tir.buffer_decl([16, 16], dtype="float16", scope="wmma.matrix_a", align=32, offset_factor=256) + BB = tir.buffer_decl([16, 16], dtype="float16", scope="wmma.matrix_b", align=32, offset_factor=256) + BC = tir.buffer_decl([16, 16], scope="wmma.accumulator", 
align=32, offset_factor=256) + Conv_wmma_accumulator = tir.buffer_decl([16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1) + W_shared = tir.buffer_decl([3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1) + W_shared_wmma_matrix_b = tir.buffer_decl([3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1) + buffer = tir.buffer_decl([16, 16], dtype="float16", scope="shared", align=32, offset_factor=256) + buffer_1 = tir.buffer_decl([16, 16], dtype="float16", scope="wmma.matrix_a", align=32, offset_factor=256) + buffer_2 = tir.buffer_decl([16, 16], dtype="float16", scope="shared", align=32, offset_factor=256) + buffer_3 = tir.buffer_decl([16, 16], dtype="float16", scope="wmma.matrix_b", align=32, offset_factor=256) + buffer_4 = tir.buffer_decl([16, 16], scope="wmma.accumulator", align=32, offset_factor=256) + buffer_5 = tir.buffer_decl([16, 16], align=32, offset_factor=256) + A_1 = tir.buffer_bind(A, [16, 14, 14, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1) + W_1 = tir.buffer_bind(W, [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1) + Conv_1 = tir.buffer_bind(Conv, [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1) + # body + tir.attr(Conv_1, "realize_scope", "") + tir.realize(Conv_1[0:16, 0:14, 0:14, 0:32, 0:16, 0:16]) + tir.attr(tir.iter_var(blockIdx_z, None, "ThreadIndex", "blockIdx.z"), "thread_extent", 196) + tir.attr(tir.iter_var(blockIdx_x, None, "ThreadIndex", "blockIdx.x"), "thread_extent", 2) + tir.attr(tir.iter_var(blockIdx_y, None, "ThreadIndex", "blockIdx.y"), "thread_extent", 4) + tir.attr(tir.iter_var(threadIdx_y, None, "ThreadIndex", "threadIdx.y"), "thread_extent", 4) + tir.attr(tir.iter_var(threadIdx_z, None, "ThreadIndex", "threadIdx.z"), "thread_extent", 2) + tir.attr(Conv_wmma_accumulator, "realize_scope", "wmma.accumulator") + tir.realize(Conv_wmma_accumulator[((blockIdx_x*8) + 
(threadIdx_y*2)):(((blockIdx_x*8) + (threadIdx_y*2)) + 2), tir.floordiv(blockIdx_z, 14):(tir.floordiv(blockIdx_z, 14) + 1), tir.floormod(blockIdx_z, 14):(tir.floormod(blockIdx_z, 14) + 1), ((blockIdx_y*8) + (threadIdx_z*4)):(((blockIdx_y*8) + (threadIdx_z*4)) + 4), 0:16, 0:16]) + for n_c_init in tir.range(0, 2): + for o_c_init in tir.range(0, 4): + tir.attr([BC, Conv_wmma_accumulator], "buffer_bind_scope", tir.tvm_tuple((n_c_init + ((blockIdx_x*8) + (threadIdx_y*2))), 1, tir.floordiv(blockIdx_z, 14), 1, tir.floormod(blockIdx_z, 14), 1, (o_c_init + ((blockIdx_y*8) + (threadIdx_z*4))), 1, 0, 16, 0, 16, dtype="handle")) + tir.evaluate(tir.tvm_fill_fragment(BC.data, 16, 16, 16, tir.floordiv(BC.elem_offset, 256), tir.float32(0), dtype="handle")) + for ic_outer in tir.range(0, 8): + for kh in tir.range(0, 3): + tir.attr(Apad_shared, "realize_scope", "shared") + tir.realize(Apad_shared[(blockIdx_x*8):((blockIdx_x*8) + 8), (tir.floordiv(blockIdx_z, 14) + kh):((tir.floordiv(blockIdx_z, 14) + kh) + 1), tir.floormod(blockIdx_z, 14):(tir.floormod(blockIdx_z, 14) + 3), (ic_outer*2):((ic_outer*2) + 2), 0:16, 0:16]) + for ax2 in tir.range(0, 3): + for ax3 in tir.range(0, 2): + for ax4_ax5_fused_outer in tir.range(0, 8): + tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32) + Apad_shared[((threadIdx_z + (threadIdx_y*2)) + (blockIdx_x*8)), (tir.floordiv(blockIdx_z, 14) + kh), (ax2 + tir.floormod(blockIdx_z, 14)), (ax3 + (ic_outer*2)), tir.floordiv((threadIdx_x + (ax4_ax5_fused_outer*32)), 16), tir.floormod((threadIdx_x + (ax4_ax5_fused_outer*32)), 16)] = tir.if_then_else((((((tir.floordiv(blockIdx_z, 14) + kh) >= 1) and (((tir.floordiv(blockIdx_z, 14) + kh) - 1) < 14)) and ((ax2 + tir.floormod(blockIdx_z, 14)) >= 1)) and (((ax2 + tir.floormod(blockIdx_z, 14)) - 1) < 14)), A_1[((threadIdx_z + (threadIdx_y*2)) + (blockIdx_x*8)), ((tir.floordiv(blockIdx_z, 14) + kh) - 1), ((ax2 + tir.floormod(blockIdx_z, 14)) - 1), (ax3 + (ic_outer*2)), 
tir.floordiv((threadIdx_x + (ax4_ax5_fused_outer*32)), 16), tir.floormod((threadIdx_x + (ax4_ax5_fused_outer*32)), 16)], tir.float16(0), dtype="float16") + tir.attr(W_shared, "realize_scope", "shared") + tir.realize(W_shared[kh:(kh + 1), 0:3, (ic_outer*2):((ic_outer*2) + 2), (blockIdx_y*8):((blockIdx_y*8) + 8), 0:16, 0:16]) + for ax1 in tir.range(0, 3): + for ax2_1 in tir.range(0, 2): + tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32) + for ax4_ax5_fused_inner in tir.range(0, 8, "vectorized"): + W_shared[kh, ax1, (ax2_1 + (ic_outer*2)), ((threadIdx_z + (threadIdx_y*2)) + (blockIdx_y*8)), tir.floordiv((ax4_ax5_fused_inner + (threadIdx_x*8)), 16), tir.floormod((ax4_ax5_fused_inner + (threadIdx_x*8)), 16)] = W_1[kh, ax1, (ax2_1 + (ic_outer*2)), ((threadIdx_z + (threadIdx_y*2)) + (blockIdx_y*8)), tir.floordiv((ax4_ax5_fused_inner + (threadIdx_x*8)), 16), tir.floormod((ax4_ax5_fused_inner + (threadIdx_x*8)), 16)] + for ic_inner in tir.range(0, 2): + for kw in tir.range(0, 3): + tir.attr(Apad_shared_wmma_matrix_a, "realize_scope", "wmma.matrix_a") + tir.realize(Apad_shared_wmma_matrix_a[((blockIdx_x*8) + (threadIdx_y*2)):(((blockIdx_x*8) + (threadIdx_y*2)) + 2), (tir.floordiv(blockIdx_z, 14) + kh):((tir.floordiv(blockIdx_z, 14) + kh) + 1), (kw + tir.floormod(blockIdx_z, 14)):((kw + tir.floormod(blockIdx_z, 14)) + 1), ((ic_outer*2) + ic_inner):(((ic_outer*2) + ic_inner) + 1), 0:16, 0:16]) + for ax0 in tir.range(0, 2): + tir.attr([buffer, Apad_shared], "buffer_bind_scope", tir.tvm_tuple((ax0 + ((blockIdx_x*8) + (threadIdx_y*2))), 1, (tir.floordiv(blockIdx_z, 14) + kh), 1, (kw + tir.floormod(blockIdx_z, 14)), 1, ((ic_outer*2) + ic_inner), 1, 0, 16, 0, 16, dtype="handle")) + tir.attr([buffer_1, Apad_shared_wmma_matrix_a], "buffer_bind_scope", tir.tvm_tuple((ax0 + ((blockIdx_x*8) + (threadIdx_y*2))), 1, (tir.floordiv(blockIdx_z, 14) + kh), 1, (kw + tir.floormod(blockIdx_z, 14)), 1, ((ic_outer*2) + ic_inner), 1, 0, 16, 0, 16, 
dtype="handle")) + tir.evaluate(tir.tvm_load_matrix_sync(buffer_1.data, 16, 16, 16, tir.floordiv(buffer_1.elem_offset, 256), tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), buffer.data, buffer.elem_offset, 256, 1, dtype="handle"), 16, "row_major", dtype="handle")) + tir.attr(W_shared_wmma_matrix_b, "realize_scope", "wmma.matrix_b") + tir.realize(W_shared_wmma_matrix_b[kh:(kh + 1), kw:(kw + 1), ((ic_outer*2) + ic_inner):(((ic_outer*2) + ic_inner) + 1), ((blockIdx_y*8) + (threadIdx_z*4)):(((blockIdx_y*8) + (threadIdx_z*4)) + 4), 0:16, 0:16]) + for ax3_1 in tir.range(0, 4): + tir.attr([buffer_2, W_shared], "buffer_bind_scope", tir.tvm_tuple(kh, 1, kw, 1, ((ic_outer*2) + ic_inner), 1, (ax3_1 + ((blockIdx_y*8) + (threadIdx_z*4))), 1, 0, 16, 0, 16, dtype="handle")) + tir.attr([buffer_3, W_shared_wmma_matrix_b], "buffer_bind_scope", tir.tvm_tuple(kh, 1, kw, 1, ((ic_outer*2) + ic_inner), 1, (ax3_1 + ((blockIdx_y*8) + (threadIdx_z*4))), 1, 0, 16, 0, 16, dtype="handle")) + tir.evaluate(tir.tvm_load_matrix_sync(buffer_3.data, 16, 16, 16, tir.floordiv(buffer_3.elem_offset, 256), tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), buffer_2.data, buffer_2.elem_offset, 256, 1, dtype="handle"), 16, "row_major", dtype="handle")) + for n_c in tir.range(0, 2): + for o_c in tir.range(0, 4): + tir.attr([BA, Apad_shared_wmma_matrix_a], "buffer_bind_scope", tir.tvm_tuple((n_c + ((blockIdx_x*8) + (threadIdx_y*2))), 1, (tir.floordiv(blockIdx_z, 14) + kh), 1, (tir.floormod(blockIdx_z, 14) + kw), 1, ((ic_outer*2) + ic_inner), 1, 0, 16, 0, 16, dtype="handle")) + tir.attr([BB, W_shared_wmma_matrix_b], "buffer_bind_scope", tir.tvm_tuple(kh, 1, kw, 1, ((ic_outer*2) + ic_inner), 1, (o_c + ((blockIdx_y*8) + (threadIdx_z*4))), 1, 0, 16, 0, 16, dtype="handle")) + tir.attr([BC, Conv_wmma_accumulator], "buffer_bind_scope", tir.tvm_tuple((n_c + ((blockIdx_x*8) + (threadIdx_y*2))), 1, tir.floordiv(blockIdx_z, 14), 1, tir.floormod(blockIdx_z, 14), 1, (o_c + ((blockIdx_y*8) + 
(threadIdx_z*4))), 1, 0, 16, 0, 16, dtype="handle")) + tir.evaluate(tir.tvm_mma_sync(BC.data, tir.floordiv(BC.elem_offset, 256), BA.data, tir.floordiv(BA.elem_offset, 256), BB.data, tir.floordiv(BB.elem_offset, 256), BC.data, tir.floordiv(BC.elem_offset, 256), dtype="handle")) + for n_inner in tir.range(0, 2): + for o_inner in tir.range(0, 4): + tir.attr([buffer_4, Conv_wmma_accumulator], "buffer_bind_scope", tir.tvm_tuple(((((blockIdx_x*4) + threadIdx_y)*2) + n_inner), 1, tir.floordiv(blockIdx_z, 14), 1, tir.floormod(blockIdx_z, 14), 1, ((((blockIdx_y*2) + threadIdx_z)*4) + o_inner), 1, 0, 16, 0, 16, dtype="handle")) + tir.attr([buffer_5, Conv_1], "buffer_bind_scope", tir.tvm_tuple(((((blockIdx_x*4) + threadIdx_y)*2) + n_inner), 1, tir.floordiv(blockIdx_z, 14), 1, tir.floormod(blockIdx_z, 14), 1, ((((blockIdx_y*2) + threadIdx_z)*4) + o_inner), 1, 0, 16, 0, 16, dtype="handle")) + tir.evaluate(tir.tvm_store_matrix_sync(buffer_4.data, 16, 16, 16, tir.floordiv(buffer_4.elem_offset, 256), tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), buffer_5.data, buffer_5.elem_offset, 256, 2, dtype="handle"), 16, "row_major", dtype="handle")) + + +def test_opt_conv_tensorcore_normalize(): + mod = Module4() + rt_mod = tvm.hybrid.from_source(tvm.hybrid.ashybrid(mod, True)) + tvm.ir.assert_structural_equal(mod, rt_mod, True) + + [email protected] +class Module5: + def default_function(A: ty.handle, W: ty.handle, Conv: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + # var definition + Apad_shared = tir.var("handle") + Apad_shared_wmma_matrix_a = tir.var("handle") + Conv_wmma_accumulator = tir.var("handle") + W_shared = tir.var("handle") + W_shared_wmma_matrix_b = tir.var("handle") + blockIdx_x = tir.var("int32") + blockIdx_y = tir.var("int32") + blockIdx_z = tir.var("int32") + threadIdx_x = tir.var("int32") + threadIdx_y = tir.var("int32") + threadIdx_z = tir.var("int32") + A_1 = tir.buffer_bind(A, [16, 
14, 14, 16, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1) + W_1 = tir.buffer_bind(W, [3, 3, 16, 32, 16, 16], dtype="float16", elem_offset=0, align=128, offset_factor=1) + Conv_1 = tir.buffer_bind(Conv, [16, 14, 14, 32, 16, 16], elem_offset=0, align=128, offset_factor=1) + # body + tir.attr(tir.iter_var(blockIdx_z, None, "ThreadIndex", "blockIdx.z"), "thread_extent", 196) + tir.attr(Conv_wmma_accumulator, "storage_scope", "wmma.accumulator") + tir.allocate(Conv_wmma_accumulator, "float32", [2048]) + tir.attr(Apad_shared, "storage_scope", "shared") + tir.allocate(Apad_shared, "float16", [12288]) + tir.attr(W_shared, "storage_scope", "shared") + tir.allocate(W_shared, "float16", [12288]) + tir.attr(Apad_shared_wmma_matrix_a, "storage_scope", "wmma.matrix_a") + tir.allocate(Apad_shared_wmma_matrix_a, "float16", [512]) + tir.attr(W_shared_wmma_matrix_b, "storage_scope", "wmma.matrix_b") + tir.allocate(W_shared_wmma_matrix_b, "float16", [1024]) + tir.attr(tir.iter_var(blockIdx_x, None, "ThreadIndex", "blockIdx.x"), "thread_extent", 2) + tir.attr(tir.iter_var(blockIdx_y, None, "ThreadIndex", "blockIdx.y"), "thread_extent", 4) + tir.attr(tir.iter_var(threadIdx_y, None, "ThreadIndex", "threadIdx.y"), "thread_extent", 4) + tir.attr(tir.iter_var(threadIdx_z, None, "ThreadIndex", "threadIdx.z"), "thread_extent", 2) + tir.evaluate(tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 0, tir.float32(0), dtype="handle")) + tir.evaluate(tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 1, tir.float32(0), dtype="handle")) + tir.evaluate(tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 2, tir.float32(0), dtype="handle")) + tir.evaluate(tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 3, tir.float32(0), dtype="handle")) + tir.evaluate(tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 4, tir.float32(0), dtype="handle")) + tir.evaluate(tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 5, tir.float32(0), dtype="handle")) + 
tir.evaluate(tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 6, tir.float32(0), dtype="handle")) + tir.evaluate(tir.tvm_fill_fragment(Conv_wmma_accumulator, 16, 16, 16, 7, tir.float32(0), dtype="handle")) + for ic_outer in tir.range(0, 8): + for kh in tir.range(0, 3): + for ax2 in tir.range(0, 3): + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61440)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 32)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61408)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 64)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), 
tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61376)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 96)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61344)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 128)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61312)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 160)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + 
(threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61280)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 192)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61248)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 224)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61216)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 256)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + 
(blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61184)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 288)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61152)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 320)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61120)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 352)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 
61088)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 384)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61056)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 416)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 61024)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 448)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 60992)), tir.float16(0), dtype="float16") + 
tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32) + Apad_shared[(((((threadIdx_y*3072) + (threadIdx_z*1536)) + (ax2*512)) + threadIdx_x) + 480)] = tir.if_then_else(((((1 <= (tir.floordiv(blockIdx_z, 14) + kh)) and ((tir.floordiv(blockIdx_z, 14) + kh) < 15)) and (1 <= (ax2 + tir.floormod(blockIdx_z, 14)))) and ((ax2 + tir.floormod(blockIdx_z, 14)) < 15)), tir.load("float16", A_1.data, (((((((((blockIdx_x*6422528) + (threadIdx_y*1605632)) + (threadIdx_z*802816)) + (kh*57344)) + (blockIdx_z*4096)) + (ax2*4096)) + (ic_outer*512)) + threadIdx_x) - 60960)), tir.float16(0), dtype="float16") + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + tir.store(W_shared, tir.ramp((((threadIdx_y*512) + (threadIdx_z*256)) + (threadIdx_x*8)), 1, 8), tir.load("float16x8", W_1.data, tir.ramp(((((((kh*393216) + (ic_outer*16384)) + (blockIdx_y*2048)) + (threadIdx_y*512)) + (threadIdx_z*256)) + (threadIdx_x*8)), 1, 8), tir.broadcast(True, 8)), tir.broadcast(True, 8)) + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + tir.store(W_shared, tir.ramp(((((threadIdx_y*512) + (threadIdx_z*256)) + (threadIdx_x*8)) + 2048), 1, 8), tir.load("float16x8", W_1.data, tir.ramp((((((((kh*393216) + (ic_outer*16384)) + (blockIdx_y*2048)) + (threadIdx_y*512)) + (threadIdx_z*256)) + (threadIdx_x*8)) + 8192), 1, 8), tir.broadcast(True, 8)), tir.broadcast(True, 8)) + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + tir.store(W_shared, tir.ramp(((((threadIdx_y*512) + (threadIdx_z*256)) + (threadIdx_x*8)) + 4096), 1, 8), tir.load("float16x8", W_1.data, tir.ramp((((((((kh*393216) + (ic_outer*16384)) + (blockIdx_y*2048)) + (threadIdx_y*512)) + (threadIdx_z*256)) + (threadIdx_x*8)) + 131072), 1, 8), tir.broadcast(True, 8)), tir.broadcast(True, 8)) + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", 
"threadIdx.x"), "thread_extent", 32): + tir.store(W_shared, tir.ramp(((((threadIdx_y*512) + (threadIdx_z*256)) + (threadIdx_x*8)) + 6144), 1, 8), tir.load("float16x8", W_1.data, tir.ramp((((((((kh*393216) + (ic_outer*16384)) + (blockIdx_y*2048)) + (threadIdx_y*512)) + (threadIdx_z*256)) + (threadIdx_x*8)) + 139264), 1, 8), tir.broadcast(True, 8)), tir.broadcast(True, 8)) + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + tir.store(W_shared, tir.ramp(((((threadIdx_y*512) + (threadIdx_z*256)) + (threadIdx_x*8)) + 8192), 1, 8), tir.load("float16x8", W_1.data, tir.ramp((((((((kh*393216) + (ic_outer*16384)) + (blockIdx_y*2048)) + (threadIdx_y*512)) + (threadIdx_z*256)) + (threadIdx_x*8)) + 262144), 1, 8), tir.broadcast(True, 8)), tir.broadcast(True, 8)) + with tir.attr(tir.iter_var(threadIdx_x, None, "ThreadIndex", "threadIdx.x"), "thread_extent", 32): + tir.store(W_shared, tir.ramp(((((threadIdx_y*512) + (threadIdx_z*256)) + (threadIdx_x*8)) + 10240), 1, 8), tir.load("float16x8", W_1.data, tir.ramp((((((((kh*393216) + (ic_outer*16384)) + (blockIdx_y*2048)) + (threadIdx_y*512)) + (threadIdx_z*256)) + (threadIdx_x*8)) + 270336), 1, 8), tir.broadcast(True, 8)), tir.broadcast(True, 8)) + for ic_inner in tir.range(0, 2): + for kw in tir.range(0, 3): + tir.evaluate(tir.tvm_load_matrix_sync(Apad_shared_wmma_matrix_a, 16, 16, 16, 0, tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), Apad_shared, (((threadIdx_y*3072) + (kw*512)) + (ic_inner*256)), 256, 1, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_load_matrix_sync(Apad_shared_wmma_matrix_a, 16, 16, 16, 1, tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), Apad_shared, ((((threadIdx_y*3072) + (kw*512)) + (ic_inner*256)) + 1536), 256, 1, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_load_matrix_sync(W_shared_wmma_matrix_b, 16, 16, 16, 0, tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), W_shared, 
(((kw*4096) + (ic_inner*2048)) + (threadIdx_z*1024)), 256, 1, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_load_matrix_sync(W_shared_wmma_matrix_b, 16, 16, 16, 1, tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), W_shared, ((((kw*4096) + (ic_inner*2048)) + (threadIdx_z*1024)) + 256), 256, 1, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_load_matrix_sync(W_shared_wmma_matrix_b, 16, 16, 16, 2, tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), W_shared, ((((kw*4096) + (ic_inner*2048)) + (threadIdx_z*1024)) + 512), 256, 1, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_load_matrix_sync(W_shared_wmma_matrix_b, 16, 16, 16, 3, tir.tvm_access_ptr(tir.type_annotation(dtype="float16"), W_shared, ((((kw*4096) + (ic_inner*2048)) + (threadIdx_z*1024)) + 768), 256, 1, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_mma_sync(Conv_wmma_accumulator, 0, Apad_shared_wmma_matrix_a, 0, W_shared_wmma_matrix_b, 0, Conv_wmma_accumulator, 0, dtype="handle")) + tir.evaluate(tir.tvm_mma_sync(Conv_wmma_accumulator, 1, Apad_shared_wmma_matrix_a, 0, W_shared_wmma_matrix_b, 1, Conv_wmma_accumulator, 1, dtype="handle")) + tir.evaluate(tir.tvm_mma_sync(Conv_wmma_accumulator, 2, Apad_shared_wmma_matrix_a, 0, W_shared_wmma_matrix_b, 2, Conv_wmma_accumulator, 2, dtype="handle")) + tir.evaluate(tir.tvm_mma_sync(Conv_wmma_accumulator, 3, Apad_shared_wmma_matrix_a, 0, W_shared_wmma_matrix_b, 3, Conv_wmma_accumulator, 3, dtype="handle")) + tir.evaluate(tir.tvm_mma_sync(Conv_wmma_accumulator, 4, Apad_shared_wmma_matrix_a, 1, W_shared_wmma_matrix_b, 0, Conv_wmma_accumulator, 4, dtype="handle")) + tir.evaluate(tir.tvm_mma_sync(Conv_wmma_accumulator, 5, Apad_shared_wmma_matrix_a, 1, W_shared_wmma_matrix_b, 1, Conv_wmma_accumulator, 5, dtype="handle")) + tir.evaluate(tir.tvm_mma_sync(Conv_wmma_accumulator, 6, Apad_shared_wmma_matrix_a, 1, W_shared_wmma_matrix_b, 2, Conv_wmma_accumulator, 
6, dtype="handle")) + tir.evaluate(tir.tvm_mma_sync(Conv_wmma_accumulator, 7, Apad_shared_wmma_matrix_a, 1, W_shared_wmma_matrix_b, 3, Conv_wmma_accumulator, 7, dtype="handle")) + tir.evaluate(tir.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, 0, tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), Conv_1.data, (((((blockIdx_x*12845056) + (threadIdx_y*3211264)) + (blockIdx_z*8192)) + (blockIdx_y*2048)) + (threadIdx_z*1024)), 256, 2, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, 1, tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), Conv_1.data, ((((((blockIdx_x*12845056) + (threadIdx_y*3211264)) + (blockIdx_z*8192)) + (blockIdx_y*2048)) + (threadIdx_z*1024)) + 256), 256, 2, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, 2, tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), Conv_1.data, ((((((blockIdx_x*12845056) + (threadIdx_y*3211264)) + (blockIdx_z*8192)) + (blockIdx_y*2048)) + (threadIdx_z*1024)) + 512), 256, 2, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, 3, tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), Conv_1.data, ((((((blockIdx_x*12845056) + (threadIdx_y*3211264)) + (blockIdx_z*8192)) + (blockIdx_y*2048)) + (threadIdx_z*1024)) + 768), 256, 2, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, 4, tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), Conv_1.data, ((((((blockIdx_x*12845056) + (threadIdx_y*3211264)) + (blockIdx_z*8192)) + (blockIdx_y*2048)) + (threadIdx_z*1024)) + 1605632), 256, 2, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, 5, tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), Conv_1.data, 
((((((blockIdx_x*12845056) + (threadIdx_y*3211264)) + (blockIdx_z*8192)) + (blockIdx_y*2048)) + (threadIdx_z*1024)) + 1605888), 256, 2, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, 6, tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), Conv_1.data, ((((((blockIdx_x*12845056) + (threadIdx_y*3211264)) + (blockIdx_z*8192)) + (blockIdx_y*2048)) + (threadIdx_z*1024)) + 1606144), 256, 2, dtype="handle"), 16, "row_major", dtype="handle")) + tir.evaluate(tir.tvm_store_matrix_sync(Conv_wmma_accumulator, 16, 16, 16, 7, tir.tvm_access_ptr(tir.type_annotation(dtype="float32"), Conv_1.data, ((((((blockIdx_x*12845056) + (threadIdx_y*3211264)) + (blockIdx_z*8192)) + (blockIdx_y*2048)) + (threadIdx_z*1024)) + 1606400), 256, 2, dtype="handle"), 16, "row_major", dtype="handle")) + + +def test_opt_conv_tensorcore_lower(): + mod = Module5() + rt_mod = tvm.hybrid.from_source(tvm.hybrid.ashybrid(mod, True)) + tvm.ir.assert_structural_equal(mod, rt_mod, True) + + +@tvm.hybrid.script +class Module6: Review comment: Let us also test the hybrid script decorator applied directly to functions. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org
