vinx13 commented on a change in pull request #9492:
URL: https://github.com/apache/tvm/pull/9492#discussion_r764256370
##########
File path: python/tvm/script/parser.py
##########
@@ -276,6 +280,22 @@ def parse_arg_list(self, func, node_call):
reader = CallArgumentReader(func_name, args, kw_args, self, node_call)
pos_only, kwargs, varargs = param_list
internal_args = list()
+ # check the default value of "name" for TypeCall and TypeApply
+ # if none, change it to the variable name
+ # Note that here the variable name will always be the last
+ # item in self.context.func_params. See transform_Function
+ # for more details of adding function parameters in self.context
+ buf_name: str = self.context.func_params[-1].name
+ if isinstance(node_call, ast.TypeApply):
+ if len(args) == 2:
Review comment:
Rather than checking `len(args)` here, we should use a more explicit and
clearer way to check whether the buffer name is present. One possible way is to force
users to pass parameters other than `shape` and `dtype` as kwargs, so that you
can check and set `kw_args` here.
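A minimal sketch of the keyword-only approach, assuming a hypothetical `buffer_sugar` signature and a hypothetical `set_default_name` helper (neither is TVM API). The bare `*` makes everything after `shape` and `dtype` keyword-only, so the parser can test for `name` by key instead of inferring it from `len(args)`:

```python
from typing import Any, Dict, Optional, Tuple


def buffer_sugar(shape: Tuple[int, ...], dtype: str = "float32", *,
                 name: Optional[str] = None, **kwargs: Any) -> Dict[str, Any]:
    """Hypothetical signature for the buffer annotation sugar.

    The bare ``*`` makes every parameter after ``shape`` and ``dtype``
    keyword-only, so a buffer name can never arrive positionally.
    """
    return {"shape": shape, "dtype": dtype, "name": name, **kwargs}


def set_default_name(kw_args: Dict[str, Any], var_name: str) -> Dict[str, Any]:
    """Fill in ``name`` from the parameter's variable name when it is absent.

    This is the explicit check: look up the keyword rather than
    counting positional arguments.
    """
    if kw_args.get("name") is None:
        return {**kw_args, "name": var_name}
    return kw_args
```

With this signature, `buffer_sugar((16,), "float32", "A")` raises `TypeError`, while `set_default_name({}, "A")` returns `{'name': 'A'}` and an explicit `{"name": "B"}` is left unchanged.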
##########
File path: python/tvm/script/parser.py
##########
@@ -440,8 +463,23 @@ def check_decorator(decorators: List[ast.Expr]) -> bool:
# add parameters of function
for arg in node.params:
arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg))
Review comment:
Since this is not used in the match buffer case, move it to the `else` branch.
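A minimal control-flow sketch of that restructuring. `add_params`, `make_plain_var`, and `make_handle_var` are hypothetical names standing in for the parameter loop, the `tvm.te.var(arg.name, ...)` call, and the handle-var construction. It only shows that the plain var is built and registered for non-buffer parameters; it does not model how `parse_arg_list` reads the default buffer name from `func_params[-1]`, which would likely need adjusting if the plain var is no longer appended first.

```python
from typing import Any, Callable, Dict, List, Tuple


def add_params(
    params: List[Tuple[str, bool]],
    make_plain_var: Callable[[str], Any],
    make_handle_var: Callable[[str], Any],
) -> Tuple[List[Any], Dict[str, Any]]:
    """Hypothetical outline of the parameter loop in transform_Function.

    Each entry in ``params`` is ``(name, is_buffer_sugar)``, where
    ``is_buffer_sugar`` stands for "the annotation is a TypeCall/TypeApply".
    """
    func_params: List[Any] = []
    symbols: Dict[str, Any] = {}
    for name, is_buffer_sugar in params:
        if is_buffer_sugar:
            # T.match_buffer sugar: only the handle var is created
            # (the diff names it "<name>_handle").
            func_params.append(make_handle_var(name))
        else:
            # Ordinary parameter: the plain var is created only in this branch.
            var = make_plain_var(name)
            symbols[name] = var
            func_params.append(var)
    return func_params, symbols
```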
##########
File path: python/tvm/script/parser.py
##########
@@ -440,8 +463,23 @@ def check_decorator(decorators: List[ast.Expr]) -> bool:
# add parameters of function
for arg in node.params:
arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg))
- self.context.update_symbol(arg.name, arg_var, node)
self.context.func_params.append(arg_var)
+ # Note that this case is for T.match_buffer syntax sugar
+ if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)):
+ result = self.transform(arg.ty)
+ if not isinstance(result, buffer.Buffer):
+ self.report_error(
+ "The result type of evaluating TypeCall and TypeApply stmt"
+ f" is wrong: {type(result)}. It should be a Buffer",
+ node.span,
+ )
+ arg_name_with_handle = arg.name + "_handle"
+ new_arg_var = tvm.te.var(arg_name_with_handle, self.parse_type(arg.ty, arg))
Review comment:
`self.transform(arg.ty)` and `parse_type` duplicate each other here: the dtype
of the `te.var` is actually the same as `buffer.dtype`.
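A minimal sketch of the deduplication, assuming hypothetical `Buffer` and `Var` dataclasses in place of the `buffer.Buffer` returned by `self.transform(arg.ty)` and the object returned by `tvm.te.var`, plus a hypothetical `make_handle_var` helper. The annotation is evaluated once, and the handle var takes its dtype from the resulting buffer, following the observation that the two dtypes coincide:

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Buffer:
    # Hypothetical stand-in for the evaluated buffer annotation.
    name: str
    shape: Tuple[int, ...]
    dtype: str


@dataclass(frozen=True)
class Var:
    # Hypothetical stand-in for the object returned by tvm.te.var.
    name: str
    dtype: str


def make_handle_var(arg_name: str, buf: Buffer) -> Var:
    """Build the handle parameter from an already-evaluated annotation.

    The dtype is read from ``buf.dtype`` instead of being re-derived
    with a second ``parse_type`` call on the same annotation.
    """
    return Var(arg_name + "_handle", buf.dtype)
```

For example, `make_handle_var("A", Buffer("A", (128,), "float32"))` yields `Var(name='A_handle', dtype='float32')`.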
--
This is an automated message from the Apache Git Service.