spectrometerHBH commented on a change in pull request #6227: URL: https://github.com/apache/incubator-tvm/pull/6227#discussion_r467524516
########## File path: python/tvm/hybrid/registry.py ########## @@ -0,0 +1,240 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hybrid Script Parser Function Registry """ +# pylint: disable=inconsistent-return-statements +import inspect +from enum import IntEnum +from typed_ast import ast3 as ast + + +class Category(IntEnum): + """Categories of registered functions""" + INTRIN = 0 + WITH_SCOPE = 1 + FOR_SCOPE = 2 + SPECIAL_STMT = 3 + + +class Registry(object): + """Registration map + All these maps are static + """ + intrin = dict() + with_scope = dict() + for_scope = dict() + special_stmt = dict() + + host_dict = { + Category.INTRIN: intrin, + Category.WITH_SCOPE: with_scope, + Category.FOR_SCOPE: for_scope, + Category.SPECIAL_STMT: special_stmt + } + + +class CallArgumentReader(object): + """A helper class which read required argument from passed arguments""" + + def __init__(self, func_name, args, kwargs, parser): + self.func_name = func_name + self.args = args + self.kwargs = kwargs + self.parser = parser + + def get_func_compulsory_arg(self, pos, name): + """Get corresponding function argument from argument list which is compulsory""" + + if len(self.args) >= pos: + arg = self.args[pos - 1] + elif name not in 
self.kwargs.keys(): + self.parser.report_error(self.func_name + " misses argument " + name) + else: + arg = self.kwargs[name] + + return arg + + def get_func_optional_arg(self, pos, name, default): + """Get corresponding function argument from argument list which is optional. + If user doesn't provide the argument, set it to default value + """ + + if len(self.args) >= pos: + arg = self.args[pos - 1] + elif name in self.kwargs.keys(): + arg = self.kwargs[name] + else: + return default + + return arg + + +def func_wrapper(func_name, func_to_register, arg_list, need_parser_and_node, need_body, concise): + """Helper function to wrap a function to be registered """ + + def wrap_func(parser, node, args, kwargs): + reader = CallArgumentReader(func_name, args, kwargs, parser) + internal_args = list() + + if need_body and not isinstance(node, ast.For): + # automatically parse body for with scope handlers + if isinstance(node, ast.With): + # the with scope handler is used inside with context + parser.scope_emitter.new_scope() + parser.scope_emitter.node_stack[-1].extend(reversed(node.body)) + body = parser.get_body() + parser.scope_emitter.pop_scope() + else: + # the with scope handler is used in concise scoping + if not concise: + parser.report_error("Concise scoping is not allowed here") + body = parser.get_body() + + if need_parser_and_node: + internal_args.append(parser) + internal_args.append(node) + + for i, arg_info in enumerate(arg_list): + if len(arg_info) == 1: + arg_name, = arg_info + if need_body and arg_name == "body": + internal_args.append(body) + else: + internal_args.append(reader.get_func_compulsory_arg(i + 1, arg_name)) + else: + arg_name, default = arg_info + internal_args.append(reader.get_func_optional_arg(i + 1, arg_name, default=default)) + + return func_to_register(*internal_args) + + return wrap_func + + +def register_func(category, origin_func, need_parser_and_node, need_body, concise): Review comment: I'll check it. 
---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org
