This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push: new afc239a [REFACTOR][PY] relay.op.Op -> tvm.ir.Op (#5705) afc239a is described below commit afc239aeb870d5c0a25a3e3e8e8c838f7122d9cf Author: Tianqi Chen <tqc...@users.noreply.github.com> AuthorDate: Mon Jun 1 15:35:06 2020 -0700 [REFACTOR][PY] relay.op.Op -> tvm.ir.Op (#5705) * [REFACTOR][PY] relay.op.Op -> tvm.ir.Op * Improve the error check --- include/tvm/ir/op.h | 6 +- python/tvm/autotvm/graph_tuner/base_graph_tuner.py | 2 +- .../autotvm/graph_tuner/utils/traverse_graph.py | 4 +- python/tvm/autotvm/task/relay_integration.py | 4 +- python/tvm/autotvm/task/topi_integration.py | 2 +- python/tvm/ir/__init__.py | 1 + python/tvm/ir/json_compact.py | 2 +- python/tvm/ir/op.py | 114 +++++++++++++++++++++ python/tvm/relay/__init__.py | 1 - python/tvm/relay/_parser.py | 2 +- python/tvm/relay/analysis/annotated_regions.py | 4 +- python/tvm/relay/backend/compile_engine.py | 7 +- python/tvm/relay/expr.py | 2 +- python/tvm/relay/expr_functor.py | 2 +- python/tvm/relay/op/__init__.py | 4 +- python/tvm/relay/op/contrib/dnnl.py | 4 +- python/tvm/relay/op/op.py | 100 ++---------------- python/tvm/relay/qnn/op/legalizations.py | 3 +- python/tvm/relay/qnn/op/op.py | 4 +- python/tvm/relay/quantize/_annotate.py | 9 +- python/tvm/relay/quantize/_partition.py | 5 +- python/tvm/relay/testing/py_converter.py | 2 +- src/ir/op.cc | 76 ++++++-------- tests/cpp/relay_build_module_test.cc | 2 +- tests/python/relay/test_ir_op.py | 11 +- tests/python/relay/test_ir_parser.py | 2 +- tests/python/relay/test_pass_annotate_target.py | 13 ++- tests/python/relay/test_pass_partition_graph.py | 12 +-- tests/scripts/task_python_docs.sh | 5 +- tests/scripts/task_sphinx_precheck.sh | 6 +- vta/python/vta/top/graphpack.py | 2 +- 31 files changed, 215 insertions(+), 198 deletions(-) diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index f86aeba..8fc96a4 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -121,7 +121,7 @@ class OpNode : public RelayExprNode { return is_primitive_ != 0; } - static constexpr const char* _type_key = "relay.Op"; + static constexpr const char* _type_key = "Op"; TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelayExprNode); private: @@ -180,7 +180,7 @@ class Op : public RelayExpr { * \tparam ValueType The type of the attribute. */ template <typename ValueType> - inline static OpAttrMap<ValueType> GetAttrMap(const std::string& attr_name); + inline static OpAttrMap<ValueType> GetAttrMap(const String& attr_name); /*! * \brief Checks if an attr map is present in the registry. * \param attr_name The name of the attribute. @@ -374,7 +374,7 @@ class OpAttrMap : public AttrRegistryMap<Op, ValueType> { inline const OpNode* Op::operator->() const { return static_cast<const OpNode*>(get()); } template <typename ValueType> -inline OpAttrMap<ValueType> Op::GetAttrMap(const std::string& key) { +inline OpAttrMap<ValueType> Op::GetAttrMap(const String& key) { return OpAttrMap<ValueType>(Op::GetAttrMapContainer(key)); } diff --git a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py index e7b4694..1cc4f39 100644 --- a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py +++ b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py @@ -81,7 +81,7 @@ class BaseGraphTuner(object): Each row of this file is an encoded record pair. Otherwise, it is an iterator. - target_ops : List of relay.op.Op + target_ops : List of tvm.ir.Op Target tuning operators. target : str or tvm.target diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py index 8470fb6..b85c562 100644 --- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py +++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py @@ -38,7 +38,7 @@ def expr2graph(expr, target_ops, node_dict, node_list): expr : tvm.relay.Expr.Function Input relay function expression. - target_ops: List of relay.op.Op + target_ops: List of tvm.ir.Op List of target relay ops node_dict : dictionary from tvm.relay.Expr to int @@ -157,7 +157,7 @@ def _expr2graph_impl(expr, target_ops, node_dict, node_list): elif isinstance(node, Constant): node_entry["name"] = "Constant_" + str(node_index) node_entry["types"] = [node.checked_type] - elif isinstance(node, relay.op.op.Op): + elif isinstance(node, tvm.ir.Op): return else: raise RuntimeError("Not supported relay node type in graph tuning: %s" diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py index ecc0112..b7d8fac 100644 --- a/python/tvm/autotvm/task/relay_integration.py +++ b/python/tvm/autotvm/task/relay_integration.py @@ -78,7 +78,7 @@ def extract_from_program(mod, params, target, target_host=None, ops=None): The compilation target target_host: tvm.target.Target The host compilation target - ops: List[relay.op.Op] or None + ops: List[tvm.ir.Op] or None List of relay ops to be tuned. If not specified, all tunable ops will be extracted. Returns @@ -105,7 +105,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No The compilation target target_host: tvm.target.Target The host compilation target - ops: List[relay.op.Op] or None + ops: List[tvm.ir.Op] or None List of relay ops to be tuned. If not specified, all tunable ops will be extracted. Returns diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py index 67f9780..59e77f7 100644 --- a/python/tvm/autotvm/task/topi_integration.py +++ b/python/tvm/autotvm/task/topi_integration.py @@ -61,7 +61,7 @@ class TaskExtractEnv: Parameters ---------- - wanted_relay_ops: List of relay.op.Op + wanted_relay_ops: List of tvm.ir.Op The relay ops to be extracted """ self.task_collection = [] diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index 1aabf3e..f1d1d50 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -23,6 +23,7 @@ from .type import TypeConstraint, FuncType, IncompleteType, RelayRefType from .tensor_type import TensorType from .type_relation import TypeCall, TypeRelation from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range +from .op import Op, register_op_attr from .function import CallingConv, BaseFunc from .adt import Constructor, TypeData from .module import IRModule diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index 9d90685..6fc24c0 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -109,7 +109,7 @@ def create_updater_06_to_07(): # Base IR "SourceName": _update_global_key, "EnvFunc": _update_global_key, - "relay.Op": _update_global_key, + "relay.Op": [_update_global_key, _rename("Op")], "relay.TypeVar": [_ftype_var, _update_from_std_str("name_hint")], "relay.Id": [_update_from_std_str("name_hint")], "relay.GlobalTypeVar": [_ftype_var, _update_from_std_str("name_hint")], diff --git a/python/tvm/ir/op.py b/python/tvm/ir/op.py new file mode 100644 index 0000000..da546ce --- /dev/null +++ b/python/tvm/ir/op.py @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Primitive operators in the TVM IR.""" +import tvm._ffi +from . expr import RelayExpr +from . import _ffi_api + + +@tvm._ffi.register_object("Op") +class Op(RelayExpr): + """Primitive operator in the IR.""" + def __init__(self): + raise RuntimeError("Cannot create op, use get instead") + + @staticmethod + def get(op_name): + """Get the Op for a given name + + Parameters + ---------- + op_name : str + The operator name + + Returns + ------- + op : Op + The op of the corresponding name + """ + return _ffi_api.GetOp(op_name) + + def get_attr(self, attr_name): + """Get additional attribute about the operator. + + Parameters + ---------- + attr_name : str + The attribute name. + + Returns + ------- + value : object + The attribute value + """ + return _ffi_api.OpGetAttr(self, attr_name) + + def set_attr(self, attr_name, value, plevel=10): + """Set attribute about the operator. + + Parameters + ---------- + attr_name : str + The attribute name + + value : object + The attribute value + + plevel : int + The priority level + """ + _ffi_api.OpSetAttr(self, attr_name, value, plevel) + + def reset_attr(self, attr_name): + """Reset attribute about the operator. + + Parameters + ---------- + attr_name : str + The attribute name + """ + _ffi_api.OpResetAttr(self, attr_name) + + +def register_op_attr(op_name, attr_key, value=None, level=10): + """Register an operator property of an operator by name. + + Parameters + ---------- + op_name : str + The name of operator + + attr_key : str + The attribute name. + + value : object, optional + The value to set + + level : int, optional + The priority level + + Returns + ------- + fregister : function + Register function if value is not specified. + """ + def _register(v): + """internal register function""" + _ffi_api.RegisterOpAttr(op_name, attr_key, v, level) + return v + return _register(value) if value is not None else _register diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 8e48e50..9c56540 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -40,7 +40,6 @@ from . import param_dict from .backend import vm # Root operators -from .op import Op from .op import nn from .op import image from .op import annotation diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 49f2d4d..8c050fa 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -378,7 +378,7 @@ class ParseTreeToRelayIR(RelayVisitor): return self.module # Exprs - def visitOpIdent(self, ctx) -> op.Op: + def visitOpIdent(self, ctx) -> tvm.ir.Op: op_name = ".".join([name.getText() for name in ctx.CNAME()]) if op_name in FUNC_OPS: return FuncOp(FUNC_OPS[op_name]) diff --git a/python/tvm/relay/analysis/annotated_regions.py b/python/tvm/relay/analysis/annotated_regions.py index fc8e85a..f29b726 100644 --- a/python/tvm/relay/analysis/annotated_regions.py +++ b/python/tvm/relay/analysis/annotated_regions.py @@ -31,9 +31,9 @@ class AnnotatedRegionSet(Object): ---------- expr : tvm.relay.Expr The expression from which to construct the regions. - region_begin_op : tvm.relay.Op + region_begin_op : tvm.ir.Op The region begin annotation. - region_end_op : tvm.relay.Op + region_end_op : tvm.ir.Op The region end annotation. """ diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py index 3e35bd2..eb5c2b3 100644 --- a/python/tvm/relay/backend/compile_engine.py +++ b/python/tvm/relay/backend/compile_engine.py @@ -26,7 +26,6 @@ from tvm.runtime import Object from ... import target as _target from ... import autotvm from .. import function as _function -from .. import op as _op from .. import ty as _ty from . import _backend @@ -98,7 +97,7 @@ def get_valid_implementations(op, attrs, inputs, out_type, target): Parameters ---------- - op : relay.op.Op + op : tvm.ir.Op Relay operator. attrs : object @@ -157,7 +156,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) Parameters ---------- - op : relay.op.Op + op : tvm.ir.Op Relay operator. attrs : object @@ -215,7 +214,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True) @tvm._ffi.register_func("relay.backend.lower_call") def lower_call(call, inputs, target): """Lower the call expression to op implementation and tensor outputs.""" - assert isinstance(call.op, _op.Op) + assert isinstance(call.op, tvm.ir.Op) op = call.op # Prepare the call_node->checked_type(). For the call node inputs, we ensure that diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index a428e1b..fbb98fc 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -234,7 +234,7 @@ class Call(ExprWithOp): Parameters ---------- - op: tvm.relay.Op or any tvm.relay.Expr with function type. + op: tvm.ir.Op or any tvm.relay.Expr with function type. The operation to be called. args: List[tvm.relay.Expr] diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py index 874a3a7..fd9b253 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -16,13 +16,13 @@ # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name """The expression functor of Relay.""" +from tvm.ir import Op from .function import Function from .expr import Call, Let, Var, GlobalVar from .expr import If, Tuple, TupleGetItem, Constant from .expr import RefCreate, RefRead, RefWrite from .adt import Constructor, Match, Clause -from .op import Op class ExprFunctor: """ diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index b3054d6..ce0df95 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -17,9 +17,9 @@ #pylint: disable=wildcard-import, redefined-builtin """Relay core operators.""" # operator defs -from .op import get, register, register_compute, register_gradient, \ +from .op import get, register_compute, register_gradient, \ register_pattern, register_alter_op_layout, register_legalize, \ - Op, OpPattern, OpStrategy, debug, register_external_compiler + OpPattern, OpStrategy, debug, register_external_compiler from . import strategy # Operators diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index bd3dd83..27574a8 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -32,7 +32,7 @@ it is supported. For example: - The other way is to implement the function by themselves to check the attributes of the op and decide if it should be offloaded to DNNL. """ -from ... import op as _op +import tvm.ir from ...dataflow_pattern import wildcard, is_op from .register import register_pattern_table @@ -51,7 +51,7 @@ def _register_external_op_helper(op_name, supported=True): f : callable A function that returns if the operator is supported by DNNL. """ - @_op.register(op_name, "target.dnnl") + @tvm.ir.register_op_attr(op_name, "target.dnnl") def _func_wrapper(attrs, args): return supported diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index e6bd6bf..7fad9a2 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -17,61 +17,13 @@ #pylint: disable=unused-argument,invalid-name """The base node types for the Relay language.""" import tvm._ffi +import tvm.ir from tvm.driver import lower, build -from ..expr import RelayExpr from ...target import get_native_generic_func, GenericFunc from ...runtime import Object from . import _make -@tvm._ffi.register_object("relay.Op") -class Op(RelayExpr): - """A Relay operator definition.""" - - def __init__(self): - raise RuntimeError("Cannot create op, use get instead") - - def get_attr(self, attr_name): - """Get additional attribute about the operator. - - Parameters - ---------- - attr_name : str - The attribute name. - - Returns - ------- - value : object - The attribute value - """ - return _OpGetAttr(self, attr_name) - - def set_attr(self, attr_name, value, plevel=10): - """Set attribute about the operator. - - Parameters - ---------- - attr_name : str - The attribute name - - value : object - The attribute value - - plevel : int - The priority level - """ - _OpSetAttr(self, attr_name, value, plevel) - - def reset_attr(self, attr_name): - """Reset attribute about the operator. - - Parameters - ---------- - attr_name : str - The attribute name - """ - _OpResetAttr(self, attr_name) - def get(op_name): """Get the Op for a given name @@ -86,37 +38,7 @@ def get(op_name): op : Op The op of the corresponding name """ - return _GetOp(op_name) - - -def register(op_name, attr_key, value=None, level=10): - """Register an operator property of an operator. - - - Parameters - ---------- - op_name : str - The name of operator - - attr_key : str - The attribute name. - - value : object, optional - The value to set - - level : int, optional - The priority level - - Returns - ------- - fregister : function - Register function if value is not specified. - """ - def _register(v): - """internal register function""" - _Register(op_name, attr_key, v, level) - return v - return _register(value) if value is not None else _register + return tvm.ir.Op.get(op_name) class OpPattern(object): @@ -258,7 +180,7 @@ def register_compute(op_name, compute=None, level=10): level : int The priority level """ - return register(op_name, "FTVMCompute", compute, level) + return tvm.ir.register_op_attr(op_name, "FTVMCompute", compute, level) def register_strategy(op_name, fstrategy=None, level=10): @@ -279,7 +201,7 @@ def register_strategy(op_name, fstrategy=None, level=10): if not isinstance(fstrategy, GenericFunc): assert hasattr(fstrategy, "generic_func_node") fstrategy = fstrategy.generic_func_node - return register(op_name, "FTVMStrategy", fstrategy, level) + return tvm.ir.register_op_attr(op_name, "FTVMStrategy", fstrategy, level) def register_schedule(op_name, schedule, level=10): @@ -360,7 +282,7 @@ def register_alter_op_layout(op_name, alter_layout=None, level=10): level : int The priority level """ - return register(op_name, "FTVMAlterOpLayout", alter_layout, level) + return tvm.ir.register_op_attr(op_name, "FTVMAlterOpLayout", alter_layout, level) def register_convert_op_layout(op_name, convert_layout=None, level=10): @@ -377,7 +299,7 @@ def register_convert_op_layout(op_name, convert_layout=None, level=10): level : int The priority level """ - return register(op_name, "FTVMConvertOpLayout", convert_layout, level) + return tvm.ir.register_op_attr(op_name, "FTVMConvertOpLayout", convert_layout, level) def register_legalize(op_name, legal_op=None, level=10): @@ -394,7 +316,7 @@ def register_legalize(op_name, legal_op=None, level=10): level : int The priority level """ - return register(op_name, "FTVMLegalize", legal_op, level) + return tvm.ir.register_op_attr(op_name, "FTVMLegalize", legal_op, level) def register_pattern(op_name, pattern, level=10): @@ -411,7 +333,7 @@ def register_pattern(op_name, pattern, level=10): level : int The priority level """ - return register(op_name, "TOpPattern", pattern, level) + return tvm.ir.register_op_attr(op_name, "TOpPattern", pattern, level) def register_gradient(op_name, fgradient=None, level=10): @@ -428,7 +350,7 @@ def register_gradient(op_name, fgradient=None, level=10): level : int The priority level """ - return register(op_name, "FPrimalGradient", fgradient, level) + return tvm.ir.register_op_attr(op_name, "FPrimalGradient", fgradient, level) def register_shape_func(op_name, data_dependant, shape_func=None, level=10): @@ -450,7 +372,7 @@ def register_shape_func(op_name, data_dependant, shape_func=None, level=10): The priority level """ get(op_name).set_attr("TShapeDataDependant", data_dependant, level) - return register(op_name, "FShapeFunc", shape_func, level) + return tvm.ir.register_op_attr(op_name, "FShapeFunc", shape_func, level) def register_external_compiler(op_name, fexternal=None, level=10): @@ -469,7 +391,7 @@ def register_external_compiler(op_name, fexternal=None, level=10): level : int The priority level """ - return register(op_name, "FTVMExternalCompiler", fexternal, level) + return tvm.ir.register_op_attr(op_name, "FTVMExternalCompiler", fexternal, level) @tvm._ffi.register_func("relay.op.compiler._lower") diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py index c96a730..d3b0e44 100644 --- a/python/tvm/relay/qnn/op/legalizations.py +++ b/python/tvm/relay/qnn/op/legalizations.py @@ -16,11 +16,10 @@ # under the License. # pylint: disable=invalid-name, unused-argument """Backend QNN related feature registration""" -from __future__ import absolute_import +import numpy as np import tvm from tvm import relay -import numpy as np from .. import op as reg ################################################# diff --git a/python/tvm/relay/qnn/op/op.py b/python/tvm/relay/qnn/op/op.py index 6da15eb..720bac4 100644 --- a/python/tvm/relay/qnn/op/op.py +++ b/python/tvm/relay/qnn/op/op.py @@ -16,7 +16,7 @@ # under the License. #pylint: disable=unused-argument """The register functions for the QNN dialect.""" -from tvm.relay.op.op import register +import tvm.ir def register_qnn_legalize(op_name, legal_op=None, level=10): """Register legal transformation function for a QNN op @@ -32,4 +32,4 @@ def register_qnn_legalize(op_name, legal_op=None, level=10): level : int The priority level """ - return register(op_name, "FTVMQnnLegalize", legal_op, level) + return tvm.ir.register_op_attr(op_name, "FTVMQnnLegalize", legal_op, level) diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index 2658a0a..5954e07 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -19,17 +19,16 @@ import warnings import topi import tvm._ffi - +from tvm.relay.op import op as _reg from .. import expr as _expr from .. import analysis as _analysis from .. import op as _op -from ..op import op as _reg from . import _quantize from .quantize import QAnnotateKind, current_qconfig, quantize_context from .quantize import _forward_op -@_reg.register_compute("relay.op.annotation.simulated_quantize") +@_op.register_compute("relay.op.annotation.simulated_quantize") def simulated_quantize_compute(attrs, inputs, out_type): """Compiler for simulated_quantize.""" assert len(inputs) == 4 @@ -106,8 +105,8 @@ def register_annotate_function(op_name, frewrite=None, level=10): if not current_qconfig().guard(ref_call): return default_rewrite(ref_call, new_args, ctx) return func(ref_call, new_args, ctx) - _reg._Register(op_name, "FQAnnotateRewrite", frewrite_with_guard, level) - return frewrite_with_guard + + return tvm.ir.register_op_attr(op_name, "FQAnnotateRewrite", frewrite_with_guard, level) return _register(frewrite) if frewrite is not None else _register diff --git a/python/tvm/relay/quantize/_partition.py b/python/tvm/relay/quantize/_partition.py index bb3db99..a607f4e 100644 --- a/python/tvm/relay/quantize/_partition.py +++ b/python/tvm/relay/quantize/_partition.py @@ -19,14 +19,11 @@ import tvm from .. import expr as _expr from .. import analysis as _analysis -from ..op import op as _reg from . import _quantize from .quantize import _forward_op def register_partition_function(op_name, frewrite=None, level=10): - def _register(func): - return _reg._Register(op_name, "FQPartitionRewrite", func, level) - return _register(frewrite) if frewrite is not None else _register + return tvm.ir.register_op_attr(op_name, "FQPartitionRewrite", frewrite, level) @tvm._ffi.register_object("relay.QPartitionExpr") diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index 89c3393..351f153 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -493,7 +493,7 @@ class PythonConverter(ExprFunctor): func = call.op fields, field_defs = self.convert_fields(call.args) - if isinstance(func, relay.Op): + if isinstance(func, tvm.ir.Op): raise Exception('Operators should have been lowered and eliminated') if isinstance(func, relay.Constructor): diff --git a/src/ir/op.cc b/src/ir/op.cc index b81e358..2c802b6 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -74,63 +74,53 @@ void OpRegEntry::UpdateAttr(const String& key, TVMRetValue value, int plevel) { } // Frontend APIs -TVM_REGISTER_GLOBAL("relay.op._ListOpNames").set_body_typed([]() { +TVM_REGISTER_GLOBAL("ir.ListOpNames").set_body_typed([]() { return OpRegistry::Global()->ListAllNames(); }); -TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed([](String name) -> Op { - return Op::Get(name); -}); +TVM_REGISTER_GLOBAL("ir.GetOp").set_body_typed([](String name) -> Op { return Op::Get(name); }); -TVM_REGISTER_GLOBAL("relay.op._OpGetAttr").set_body([](TVMArgs args, TVMRetValue* rv) { - Op op = args[0]; - std::string attr_name = args[1]; +TVM_REGISTER_GLOBAL("ir.OpGetAttr").set_body_typed([](Op op, String attr_name) -> TVMRetValue { auto op_map = Op::GetAttrMap<TVMRetValue>(attr_name); + TVMRetValue rv; if (op_map.count(op)) { - *rv = op_map[op]; + rv = op_map[op]; } + return rv; }); -TVM_REGISTER_GLOBAL("relay.op._OpSetAttr").set_body([](TVMArgs args, TVMRetValue* rv) { - Op op = args[0]; - std::string attr_name = args[1]; - runtime::TVMArgValue value = args[2]; - int plevel = args[3]; - auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); - reg.set_attr(attr_name, value, plevel); -}); +TVM_REGISTER_GLOBAL("ir.OpSetAttr") + .set_body_typed([](Op op, String attr_name, runtime::TVMArgValue value, int plevel) { + auto& reg = OpRegistry::Global()->RegisterOrGet(op->name).set_name(); + reg.set_attr(attr_name, value, plevel); + }); -TVM_REGISTER_GLOBAL("relay.op._OpResetAttr").set_body([](TVMArgs args, TVMRetValue* rv) { - Op op = args[0]; - std::string attr_name = args[1]; +TVM_REGISTER_GLOBAL("ir.OpResetAttr").set_body_typed([](Op op, String attr_name) { auto& reg = OpRegistry::Global()->RegisterOrGet(op->name); reg.reset_attr(attr_name); }); -TVM_REGISTER_GLOBAL("relay.op._Register").set_body([](TVMArgs args, TVMRetValue* rv) { - std::string op_name = args[0]; - std::string attr_key = args[1]; - runtime::TVMArgValue value = args[2]; - int plevel = args[3]; - auto& reg = OpRegistry::Global()->RegisterOrGet(op_name).set_name(); - // enable resgiteration and override of certain properties - if (attr_key == "num_inputs" && plevel > 128) { - reg.set_num_inputs(value); - } else if (attr_key == "attrs_type_key" && plevel > 128) { - LOG(FATAL) << "attrs type key no longer supported"; - } else { - // normal attr table override. - if (args[2].type_code() == kTVMPackedFuncHandle) { - // do an eager copy of the PackedFunc - PackedFunc f = args[2]; - // If we get a function from frontend, avoid deleting it. - auto* fcopy = new PackedFunc(f); - reg.set_attr(attr_key, *fcopy, plevel); - } else { - reg.set_attr(attr_key, args[2], plevel); - } - } -}); +TVM_REGISTER_GLOBAL("ir.RegisterOpAttr") + .set_body_typed([](String op_name, String attr_key, runtime::TVMArgValue value, int plevel) { + auto& reg = OpRegistry::Global()->RegisterOrGet(op_name).set_name(); + // enable resgiteration and override of certain properties + if (attr_key == "num_inputs" && plevel > 128) { + reg.set_num_inputs(value); + } else if (attr_key == "attrs_type_key" && plevel > 128) { + LOG(FATAL) << "attrs type key no longer supported"; + } else { + // normal attr table override. + if (value.type_code() == kTVMPackedFuncHandle) { + // do an eager copy of the PackedFunc + PackedFunc f = value; + // If we get a function from frontend, avoid deleting it. + auto* fcopy = new PackedFunc(f); + reg.set_attr(attr_key, *fcopy, plevel); + } else { + reg.set_attr(attr_key, value, plevel); + } + } + }); // helper to get internal dev function in objectref. struct Op2ObjectPtr : public ObjectRef { diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc index 3a2b2d9..636593f 100644 --- a/tests/cpp/relay_build_module_test.cc +++ b/tests/cpp/relay_build_module_test.cc @@ -95,7 +95,7 @@ TEST(Relay, BuildModule) { pC[i] = i + 2; } // get schedule - auto reg = tvm::runtime::Registry::Get("relay.op._Register"); + auto reg = tvm::runtime::Registry::Get("ir.RegisterOpAttr"); if (!reg) { LOG(FATAL) << "no _Register"; } diff --git a/tests/python/relay/test_ir_op.py b/tests/python/relay/test_ir_op.py index 1fd68b3..46e4b02 100644 --- a/tests/python/relay/test_ir_op.py +++ b/tests/python/relay/test_ir_op.py @@ -14,13 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import tvm from tvm import relay from tvm.relay.testing.temp_op_attr import TempOpAttr def test_op_attr(): log_op = relay.op.get("log") - @relay.op.register("exp", "ftest") + @tvm.ir.register_op_attr("exp", "ftest") def test(x): return x + 1 @@ -37,9 +38,9 @@ def test_op_reset_attr(): return x + 2 # Register fadd1 and fadd2 attributes. - relay.op.register("exp", "fadd1", add1) - relay.op.register("log", "fadd1", add1) - relay.op.register("log", "fadd2", add2) + tvm.ir.register_op_attr("exp", "fadd1", add1) + tvm.ir.register_op_attr("log", "fadd1", add1) + tvm.ir.register_op_attr("log", "fadd2", add2) # Reset log fadd1 attr. log_op = relay.op.get("log") @@ -63,7 +64,7 @@ def test_op_temp_attr(): return x + 2 # Set original attr value is add1. - relay.op.register("sqrt", "ftest", add1) + tvm.ir.register_op_attr("sqrt", "ftest", add1) with TempOpAttr("sqrt", "ftest", add2): # Check that the attr value is updated to add2. diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py index 9e62491..c4ac042 100644 --- a/tests/python/relay/test_ir_parser.py +++ b/tests/python/relay/test_ir_parser.py @@ -219,7 +219,7 @@ def test_vars(): # operator id op = parse_text("foo") - assert isinstance(op, relay.Op) + assert isinstance(op, tvm.ir.Op) assert op.name == "foo" diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index 0583946..273c27b 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -22,7 +22,6 @@ import pytest import tvm import tvm.relay.testing -import tvm.relay.op as reg import tvm.relay.transform as transform from tvm import relay from tvm import runtime @@ -187,7 +186,7 @@ def test_extern_dnnl_mobilenet(): def test_multiple_ends(): - @reg.register("nn.relu", "target.test") + @tvm.ir.register_op_attr("nn.relu", "target.test") def relu(attrs, args): # pylint: disable=unused-variable return True @@ -229,7 +228,7 @@ def test_multiple_ends(): def test_type_propagation(): target = "test_type_propagation" - @reg.register("nn.relu", "target." + target) + @tvm.ir.register_op_attr("nn.relu", "target." + target) def relu(attrs, args): # pylint: disable=unused-variable return args[0].checked_type.dtype == "float32" @@ -248,11 +247,11 @@ def test_type_propagation(): def test_tuple(): target = "test_tuple" - @reg.register("nn.relu", "target." + target) + @tvm.ir.register_op_attr("nn.relu", "target." + target) def relu(attrs, args): # pylint: disable=unused-variable return True - @reg.register("concatenate", "target." + target) + @tvm.ir.register_op_attr("concatenate", "target." + target) def concatenate(attrs, args): # pylint: disable=unused-variable return True @@ -338,11 +337,11 @@ def test_composite_function(): def test_multiple_runs(): - @reg.register("nn.relu", "target.A") + @tvm.ir.register_op_attr("nn.relu", "target.A") def relu(attrs, args): # pylint: disable=unused-variable return True - @reg.register("add", "target.B") + @tvm.ir.register_op_attr("add", "target.B") def add(attrs, args): # pylint: disable=unused-variable return True diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 23bf618..473ca9d 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -1034,7 +1034,7 @@ def test_multiple_use_of_an_output(): def test_duplicate_outputs(): target = "test_duplicate_outputs" - @reg.register("abs", "target." + target) + @tvm.ir.register_op_attr("abs", "target." + target) def abs(attrs, args): # pylint: disable=unused-variable return True @@ -1090,11 +1090,11 @@ def test_duplicate_outputs(): def test_duplicate_merge_and_tuplegetitem(): target = "test_duplicate_merge_and_tuplegetitem" - @reg.register("nn.batch_norm", "target." + target) + @tvm.ir.register_op_attr("nn.batch_norm", "target." + target) def batch_norm(attrs, args): # pylint: disable=unused-variable return True - @reg.register("nn.relu", "target." + target) + @tvm.ir.register_op_attr("nn.relu", "target." + target) def relu(attrs, args): # pylint: disable=unused-variable return True @@ -1165,7 +1165,7 @@ def test_duplicate_merge_and_tuplegetitem(): assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True) def test_constant_tuples(): - @reg.register("qnn.concatenate", "target.const_tuples") + @tvm.ir.register_op_attr("qnn.concatenate", "target.const_tuples") def add(attrs, args): # pylint: disable=unused-variable return True @@ -1203,11 +1203,11 @@ def test_constant_tuples(): def test_flatten_tuple_output(): target = "test_flatten_tuple_output" - @reg.register("split", "target." + target) + @tvm.ir.register_op_attr("split", "target." + target) def split(attrs, args): # pylint: disable=unused-variable return True - @reg.register("abs", "target." + target) + @tvm.ir.register_op_attr("abs", "target." + target) def abs(attrs, args): # pylint: disable=unused-variable return True diff --git a/tests/scripts/task_python_docs.sh b/tests/scripts/task_python_docs.sh index be2d792..4c52e1e 100755 --- a/tests/scripts/task_python_docs.sh +++ b/tests/scripts/task_python_docs.sh @@ -43,10 +43,9 @@ find . -type f -path "*.pyc" | xargs rm -f make cython3 cd docs -PYTHONPATH=`pwd`/../python make html 2>/tmp/$$.log.txt +PYTHONPATH=`pwd`/../python make html |& tee /tmp/$$.log.txt if grep -E "failed to execute" < /tmp/$$.log.txt; then echo "Some of sphinx-gallery item example failed to execute." - cat /tmp/$$.log.txt exit 1 fi cd .. @@ -78,3 +77,5 @@ echo "Start creating the docs tarball.." tar -C _docs -czf docs.tgz . echo "Finish creating the docs tarball" du -h docs.tgz + +echo "Finish everything" diff --git a/tests/scripts/task_sphinx_precheck.sh b/tests/scripts/task_sphinx_precheck.sh index 0328b9e..fd67b0a 100755 --- a/tests/scripts/task_sphinx_precheck.sh +++ b/tests/scripts/task_sphinx_precheck.sh @@ -23,10 +23,6 @@ set -o pipefail cleanup() { - # cat error log if non zero exit - if [ $? ]; then - cat /tmp/$$.log.txt - fi rm -rf /tmp/$$.* } trap cleanup 0 @@ -40,7 +36,7 @@ make cython3 echo "PreCheck sphinx doc generation WARNINGS.." cd docs make clean -TVM_TUTORIAL_EXEC_PATTERN=none make html 2>/tmp/$$.log.txt +TVM_TUTORIAL_EXEC_PATTERN=none make html |& tee /tmp/$$.log.txt grep -v -E "__mro__|UserWarning|FutureWarning|tensorflow|Keras|pytorch|TensorFlow|403" < /tmp/$$.log.txt > /tmp/$$.logclean.txt || true echo "---------Sphinx Log----------" diff --git a/vta/python/vta/top/graphpack.py b/vta/python/vta/top/graphpack.py index e1fdfcb..231d400 100644 --- a/vta/python/vta/top/graphpack.py +++ b/vta/python/vta/top/graphpack.py @@ -376,7 +376,7 @@ def get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, cou if isinstance(anf, relay.expr.Let): value = anf.value if isinstance(value, relay.expr.Call): - if isinstance(value.op, relay.op.Op): + if isinstance(value.op, tvm.ir.Op): if value.op.name == start_name and not start_found: if operator_current_idx == start_name_idx or start_name_idx is None: value = relay.expr.Call(bitpack_start, [value])