vinx13 commented on a change in pull request #9492: URL: https://github.com/apache/tvm/pull/9492#discussion_r756358992
########## File path: python/tvm/script/parser.py ########## @@ -1047,6 +1056,91 @@ def transform_TypeConstant(self, node): """ return node.value + def transform_TypeTuple(self, node): + return node.values + + def transform_TypeCall(self, node): + """Call value visitor for TypeCall. + + This method is for syntax sugar of T.match_buffer() + """ + + def parse_typecall_params(func, params, keyword_params): Review comment: This function is very similar to `parse_arg_list`. Is it possible to reuse it (and extend it with the needed functionality)? ########## File path: python/tvm/script/tir/ty.py ########## @@ -20,7 +20,10 @@ a wrapper for uniform Type system in IR """ # pylint: disable=invalid-name +from os import stat import tvm +from tvm import script +from tvm.script.tir.special_stmt import SpecialStmt Review comment: ```suggestion from .special_stmt import SpecialStmt ``` ########## File path: python/tvm/script/parser.py ########## @@ -1047,6 +1056,91 @@ def transform_TypeConstant(self, node): """ return node.value + def transform_TypeTuple(self, node): + return node.values + + def transform_TypeCall(self, node): + """Call value visitor for TypeCall. 
+ + This method is for syntax sugar of T.match_buffer() + """ + + def parse_typecall_params(func, params, keyword_params): + args = [] + for arg in params: + if isinstance(arg, ast.TypeTuple): + values = [] + for value in self.transform(arg): + values.append(self.transform(value)) Review comment: It would be better to recursively transform the elements of `TypeTuple` in `transform_TypeTuple` rather than doing it here; you can find examples in `transform_Tuple`. ########## File path: python/tvm/script/parser.py ########## @@ -440,7 +443,13 @@ def check_decorator(decorators: List[ast.Expr]) -> bool: # add parameters of function for arg in node.params: arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg)) - self.context.update_symbol(arg.name, arg_var, node) + # Note that this case is for T.match_buffer syntax sugar + if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)): + buf = self.transform(arg.ty) Review comment: If this `if` statement only handles `T.match_buffer`, it would be great to add some assertions (otherwise, naming the return value `buf` here is confusing). ########## File path: python/tvm/script/tir/ty.py ########## @@ -20,7 +20,10 @@ a wrapper for uniform Type system in IR """ # pylint: disable=invalid-name +from os import stat Review comment: Not needed. ########## File path: python/tvm/script/parser.py ########## @@ -1047,6 +1056,91 @@ def transform_TypeConstant(self, node): """ return node.value + def transform_TypeTuple(self, node): + return node.values + + def transform_TypeCall(self, node): + """Call value visitor for TypeCall. 
+ + This method is for syntax sugar of T.match_buffer() + """ + + def parse_typecall_params(func, params, keyword_params): + args = [] + for arg in params: + if isinstance(arg, ast.TypeTuple): + values = [] + for value in self.transform(arg): + values.append(self.transform(value)) + else: + values = self.transform(arg) + args.append(values) + kw_args = {} + for k, v in keyword_params.items(): + if isinstance(v, ast.TypeTuple): + values = [] + for value in self.transform(v): + values.append(self.transform(value)) Review comment: ditto -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org