This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 7ca3212 create function.py (#5087)
7ca3212 is described below
commit 7ca3212f06a56eb95420060aa822a56860d114fd
Author: Zhi <[email protected]>
AuthorDate: Wed Mar 18 08:56:55 2020 -0700
create function.py (#5087)
---
docs/api/python/relay/expr.rst | 3 -
docs/langref/relay_expr.rst | 2 +-
python/tvm/autotvm/graph_tuner/base_graph_tuner.py | 4 +-
.../autotvm/graph_tuner/utils/traverse_graph.py | 3 +-
python/tvm/autotvm/task/relay_integration.py | 6 +-
python/tvm/relay/__init__.py | 3 +-
python/tvm/relay/_parser.py | 7 +-
python/tvm/relay/analysis/analysis.py | 2 +-
python/tvm/relay/backend/compile_engine.py | 4 +-
python/tvm/relay/backend/interpreter.py | 3 +-
python/tvm/relay/build_module.py | 13 ++--
python/tvm/relay/expr.py | 66 +----------------
python/tvm/relay/expr_functor.py | 3 +-
python/tvm/relay/frontend/caffe2.py | 7 +-
python/tvm/relay/frontend/common.py | 5 +-
python/tvm/relay/frontend/coreml.py | 3 +-
python/tvm/relay/frontend/darknet.py | 3 +-
python/tvm/relay/frontend/keras.py | 3 +-
python/tvm/relay/frontend/mxnet.py | 5 +-
python/tvm/relay/frontend/onnx.py | 7 +-
python/tvm/relay/frontend/tensorflow.py | 3 +-
python/tvm/relay/frontend/tflite.py | 3 +-
python/tvm/relay/function.py | 86 ++++++++++++++++++++++
python/tvm/relay/loops.py | 3 +-
python/tvm/relay/prelude.py | 3 +-
python/tvm/relay/testing/nat.py | 3 +-
python/tvm/relay/testing/py_converter.py | 3 +-
src/relay/ir/function.cc | 9 +--
28 files changed, 152 insertions(+), 113 deletions(-)
diff --git a/docs/api/python/relay/expr.rst b/docs/api/python/relay/expr.rst
index 57a4a25..cfb6df0 100644
--- a/docs/api/python/relay/expr.rst
+++ b/docs/api/python/relay/expr.rst
@@ -35,9 +35,6 @@ tvm.relay.expr
.. autoclass:: tvm.relay.expr.Tuple
:members:
-.. autoclass:: tvm.relay.expr.Function
- :members:
-
.. autoclass:: tvm.relay.expr.Call
:members:
diff --git a/docs/langref/relay_expr.rst b/docs/langref/relay_expr.rst
index 66bfe43..3b93360 100644
--- a/docs/langref/relay_expr.rst
+++ b/docs/langref/relay_expr.rst
@@ -120,7 +120,7 @@ Additionally, functions in Relay are higher-order, which means that a function c
function or returned by a function, as function expressions evaluate to
closures (see the `Closures`_ subsection),
which are values like tensors and tuples.
-See :py:class:`~tvm.relay.expr.Function` for the definition and documentation of function nodes.
+See :py:class:`~tvm.relay.function.Function` for the definition and documentation of function nodes.
Syntax
~~~~~~
diff --git a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py
index f1a0756..e7b4694 100644
--- a/python/tvm/autotvm/graph_tuner/base_graph_tuner.py
+++ b/python/tvm/autotvm/graph_tuner/base_graph_tuner.py
@@ -69,7 +69,7 @@ class BaseGraphTuner(object):
target_op in the input graph and layout transformation benchmark need
to be
executed before initialization.
- graph : tvm.relay.Expr.Function
+ graph : tvm.relay.function.Function
Input graph
input_shapes : dict of str to tuple.
@@ -143,7 +143,7 @@ class BaseGraphTuner(object):
if isinstance(graph, tvm.IRModule):
graph = graph["main"]
- if isinstance(graph, relay.expr.Function):
+ if isinstance(graph, relay.function.Function):
node_dict = {}
graph = bind_inputs(graph, input_shapes, dtype)
expr2graph(graph, self._target_ops, node_dict, self._node_list)
diff --git a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
index f1dd404..8470fb6 100644
--- a/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
+++ b/python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
@@ -21,7 +21,8 @@ import threading
import tvm
from tvm import relay, autotvm
from tvm.relay import transform
-from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple
+from tvm.relay.expr import Call, TupleGetItem, Var, Constant, Tuple
+from tvm.relay.function import Function
from tvm.relay.ty import TupleType, TensorType
from tvm.autotvm.task import TaskExtractEnv
diff --git a/python/tvm/autotvm/task/relay_integration.py b/python/tvm/autotvm/task/relay_integration.py
index cd8d32f..a7cbef7 100644
--- a/python/tvm/autotvm/task/relay_integration.py
+++ b/python/tvm/autotvm/task/relay_integration.py
@@ -61,7 +61,7 @@ def extract_from_program(mod, params, target, target_host=None, ops=None):
Parameters
----------
- mod: tvm.IRModule or relay.expr.Function
+ mod: tvm.IRModule or relay.function.Function
The module or function to tune
params: dict of str to numpy array
The associated parameters of the program
@@ -88,7 +88,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No
Parameters
----------
- mods: List[tvm.IRModule] or List[relay.expr.Function]
+ mods: List[tvm.IRModule] or List[relay.function.Function]
The list of modules or functions to tune
params: List of dict of str to numpy array
The associated parameters of the programs
@@ -118,7 +118,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No
logger.disabled = True
for mod, param in zip(mods, params):
- if isinstance(mod, relay.expr.Function):
+ if isinstance(mod, relay.function.Function):
mod = tvm.IRModule.from_expr(mod)
assert isinstance(mod, tvm.IRModule), \
"only support relay Module or Function to be tuned"
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index b1aac3e..95545c8 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -22,6 +22,7 @@ from sys import setrecursionlimit
from . import base
from . import ty
from . import expr
+from . import function
from . import type_functor
from . import expr_functor
from . import adt
@@ -87,7 +88,7 @@ Constant = expr.Constant
Tuple = expr.Tuple
Var = expr.Var
GlobalVar = expr.GlobalVar
-Function = expr.Function
+Function = function.Function
Call = expr.Call
Let = expr.Let
If = expr.If
diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py
index 49bdbb3..4a73e57 100644
--- a/python/tvm/relay/_parser.py
+++ b/python/tvm/relay/_parser.py
@@ -43,6 +43,7 @@ from tvm.ir import IRModule
from .base import Span, SourceName
from . import adt
from . import expr
+from . import function
from . import ty
from . import op
@@ -481,7 +482,7 @@ class ParseTreeToRelayIR(RelayVisitor):
def mk_func(
self,
ctx: Union[RelayParser.FuncContext, RelayParser.DefnContext]) \
- -> expr.Function:
+ -> function.Function:
"""Construct a function from either a Func or Defn."""
# Enter var scope early to put params in scope.
self.enter_var_scope()
@@ -511,10 +512,10 @@ class ParseTreeToRelayIR(RelayVisitor):
self.exit_var_scope()
attrs = tvm.ir.make_node("DictAttrs", **attr_list) if attr_list is not
None else None
- return expr.Function(var_list, body, ret_type, type_params, attrs)
+ return function.Function(var_list, body, ret_type, type_params, attrs)
@spanify
- def visitFunc(self, ctx: RelayParser.FuncContext) -> expr.Function:
+ def visitFunc(self, ctx: RelayParser.FuncContext) -> function.Function:
return self.mk_func(ctx)
# TODO: how to set spans for definitions?
diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py
index beb3c65..722f3b0 100644
--- a/python/tvm/relay/analysis/analysis.py
+++ b/python/tvm/relay/analysis/analysis.py
@@ -421,7 +421,7 @@ def extract_fused_functions(mod):
Returns
-------
- ret : Dict[int, tvm.relay.ir.expr.Function]
+ ret : Dict[int, tvm.relay.function.Function]
A module containing only fused primitive functions
"""
ret_mod = _ffi_api.ExtractFusedFunctions()(mod)
diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py
index 03d91d5..3e35bd2 100644
--- a/python/tvm/relay/backend/compile_engine.py
+++ b/python/tvm/relay/backend/compile_engine.py
@@ -25,7 +25,7 @@ from tvm import te
from tvm.runtime import Object
from ... import target as _target
from ... import autotvm
-from .. import expr as _expr
+from .. import function as _function
from .. import op as _op
from .. import ty as _ty
from . import _backend
@@ -65,7 +65,7 @@ class CCacheValue(Object):
def _get_cache_key(source_func, target):
- if isinstance(source_func, _expr.Function):
+ if isinstance(source_func, _function.Function):
if isinstance(target, str):
target = _target.create(target)
if not target:
diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py
index ab39f7c..9c4be29 100644
--- a/python/tvm/relay/backend/interpreter.py
+++ b/python/tvm/relay/backend/interpreter.py
@@ -27,7 +27,8 @@ from tvm.ir import IRModule
from . import _backend
from .. import _make, analysis, transform
from ... import nd
-from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
+from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, const
+from ..function import Function
from ..scope_builder import ScopeBuilder
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index d1add27..30c5971 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -29,6 +29,7 @@ from ..contrib import graph_runtime as _graph_rt
from . import _build_module
from . import ty as _ty
from . import expr as _expr
+from . import function as _function
from .backend import interpreter as _interpreter
from .backend.vm import VMExecutor
@@ -218,16 +219,16 @@ def build(mod, target=None, target_host=None, params=None):
params : dict
The parameters of the final graph.
"""
- if not isinstance(mod, (IRModule, _expr.Function)):
+ if not isinstance(mod, (IRModule, _function.Function)):
raise ValueError("Type of input parameter mod must be tvm.IRModule")
- if isinstance(mod, _expr.Function):
+ if isinstance(mod, _function.Function):
if params:
mod = bind_params_by_name(mod, params)
mod = IRModule.from_expr(mod)
warnings.warn(
"Please use input parameter mod (tvm.IRModule) "
- "instead of deprecated parameter mod (tvm.relay.expr.Function)",
+ "instead of deprecated parameter mod
(tvm.relay.function.Function)",
DeprecationWarning)
target = _update_target(target)
@@ -276,16 +277,16 @@ def optimize(mod, target=None, params=None):
params : dict
The parameters of the final graph.
"""
- if not isinstance(mod, (IRModule, _expr.Function)):
+ if not isinstance(mod, (IRModule, _function.Function)):
raise ValueError("Type of input parameter mod must be tvm.IRModule")
- if isinstance(mod, _expr.Function):
+ if isinstance(mod, _function.Function):
if params:
mod = bind_params_by_name(mod, params)
mod = IRModule.from_expr(mod)
warnings.warn(
"Please use input parameter mod (tvm.IRModule) "
- "instead of deprecated parameter func (tvm.relay.expr.Function)",
+ "instead of deprecated parameter func
(tvm.relay.function.Function)",
DeprecationWarning)
target = _update_target(target)
diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index 380cdf7..ff13683 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -22,8 +22,8 @@ from numbers import Number as _Number
import numpy as _np
import tvm._ffi
from tvm._ffi import base as _base
-from tvm.runtime import NDArray, convert, ndarray as _nd
-from tvm.ir import RelayExpr, GlobalVar, BaseFunc
+from tvm.runtime import NDArray, ndarray as _nd
+from tvm.ir import RelayExpr, GlobalVar
from .base import RelayNode
from . import _ffi_api
@@ -225,68 +225,6 @@ class Var(ExprWithOp):
return name
-@tvm._ffi.register_object("relay.Function")
-class Function(BaseFunc):
- """A function declaration expression.
-
- Parameters
- ----------
- params: List[tvm.relay.Var]
- List of input parameters to the function.
-
- body: tvm.relay.Expr
- The body of the function.
-
- ret_type: Optional[tvm.relay.Type]
- The return type annotation of the function.
-
- type_params: Optional[List[tvm.relay.TypeParam]]
- The additional type parameters, this is only
- used in advanced usecase of template functions.
- """
- def __init__(self,
- params,
- body,
- ret_type=None,
- type_params=None,
- attrs=None):
- if type_params is None:
- type_params = convert([])
-
- self.__init_handle_by_constructor__(
- _ffi_api.Function, params, body, ret_type, type_params, attrs)
-
- def __call__(self, *args):
- """Invoke the global function.
-
- Parameters
- ----------
- args: List[relay.Expr]
- Arguments.
- """
- return Call(self, args, None, None)
-
- def with_attr(self, attr_key, attr_value):
- """Create a new copy of the function and update the attribute
-
- Parameters
- ----------
- attr_key : str
- The attribute key to use.
-
- attr_value : Object
- The new attribute value.
-
- Returns
- -------
- func : Function
- A new copy of the function
- """
- return _ffi_api.FunctionWithAttr(
- self, attr_key, convert(attr_value))
-
-
-
@tvm._ffi.register_object("relay.Call")
class Call(ExprWithOp):
"""Function call node in Relay.
diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py
index 8d69239..874a3a7 100644
--- a/python/tvm/relay/expr_functor.py
+++ b/python/tvm/relay/expr_functor.py
@@ -17,7 +17,8 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The expression functor of Relay."""
-from .expr import Function, Call, Let, Var, GlobalVar
+from .function import Function
+from .expr import Call, Let, Var, GlobalVar
from .expr import If, Tuple, TupleGetItem, Constant
from .expr import RefCreate, RefRead, RefWrite
from .adt import Constructor, Match, Clause
diff --git a/python/tvm/relay/frontend/caffe2.py b/python/tvm/relay/frontend/caffe2.py
index da0cc64..f4fcd92 100644
--- a/python/tvm/relay/frontend/caffe2.py
+++ b/python/tvm/relay/frontend/caffe2.py
@@ -21,6 +21,7 @@ from tvm.ir import IRModule
from .. import analysis
from .. import expr as _expr
+from .. import function as _function
from .. import op as _op
from ... import nd as _nd
from .common import AttrCvt, Renamer
@@ -451,7 +452,7 @@ class Caffe2NetDef(object):
else:
outputs = out[0]
- func = _expr.Function(analysis.free_vars(outputs), outputs)
+ func = _function.Function(analysis.free_vars(outputs), outputs)
self._mod["main"] = func
return self._mod, self._params
@@ -517,7 +518,7 @@ class Caffe2NetDef(object):
----------
op_type : str
Operator name, such as Convolution, FullyConnected
- inputs : list of tvm.relay.expr.Function
+ inputs : list of tvm.relay.function.Function
List of input inputs.
args : dict
Dict of operator attributes
@@ -530,7 +531,7 @@ class Caffe2NetDef(object):
Returns
-------
- func : tvm.relay.expr.Function
+ func : tvm.relay.function.Function
Converted relay function
"""
identity_list = identity_list if identity_list else _identity_list
diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index d427fe9..6185121 100644
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -24,6 +24,7 @@ from tvm.ir import IRModule
from topi.util import get_const_tuple
from .. import expr as _expr
+from .. import function as _function
from .. import transform as _transform
from .. import op as _op
from .. import analysis
@@ -459,7 +460,7 @@ def infer_type(node, mod=None):
new_mod.update(mod)
new_mod = _transform.InferType()(new_mod)
entry = new_mod["main"]
- return entry if isinstance(node, _expr.Function) else entry.body
+ return entry if isinstance(node, _function.Function) else entry.body
def infer_shape(inputs, mod=None):
"""A method to get the output type of an intermediate node in the graph."""
@@ -491,7 +492,7 @@ def infer_value(input_val, params):
# Check that all free variables have associated parameters.
assert all(var.name_hint in params.keys() for var in analysis.free_vars(
input_val)), "All inputs to infer must be available in params."
- func = _expr.Function(analysis.free_vars(input_val), input_val)
+ func = _function.Function(analysis.free_vars(input_val), input_val)
with tvm.relay.build_config(opt_level=0):
graph, lib, params = tvm.relay.build(func, target="llvm",
params=params)
ctx = tvm.cpu(0)
diff --git a/python/tvm/relay/frontend/coreml.py b/python/tvm/relay/frontend/coreml.py
index 0e5b64c..6658803 100644
--- a/python/tvm/relay/frontend/coreml.py
+++ b/python/tvm/relay/frontend/coreml.py
@@ -24,6 +24,7 @@ from tvm.ir import IRModule
from .. import analysis
from .. import expr as _expr
+from .. import function as _function
from .. import op as _op
from ... import nd as _nd
from ..._ffi import base as _base
@@ -503,6 +504,6 @@ def from_coreml(model, shape=None):
for o in spec.description.output]
# for now return first output
outexpr = outexpr[0]
- func = _expr.Function(analysis.free_vars(outexpr), outexpr)
+ func = _function.Function(analysis.free_vars(outexpr), outexpr)
params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in
etab.params.items()}
return IRModule.from_expr(func), params
diff --git a/python/tvm/relay/frontend/darknet.py b/python/tvm/relay/frontend/darknet.py
index 0dae645..936d7c0 100644
--- a/python/tvm/relay/frontend/darknet.py
+++ b/python/tvm/relay/frontend/darknet.py
@@ -26,6 +26,7 @@ from tvm.ir import IRModule
from .. import analysis
from .. import expr as _expr
+from .. import function as _function
from .common import get_relay_op, new_var
__all__ = ['from_darknet']
@@ -821,7 +822,7 @@ class GraphProto(object):
outputs = _as_list(sym) + self._outs
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
- sym = _expr.Function(analysis.free_vars(outputs), outputs)
+ sym = _function.Function(analysis.free_vars(outputs), outputs)
return IRModule.from_expr(sym), self._tvmparams
def from_darknet(net,
diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py
index adb28c4..090bd4c 100644
--- a/python/tvm/relay/frontend/keras.py
+++ b/python/tvm/relay/frontend/keras.py
@@ -23,6 +23,7 @@ from tvm.ir import IRModule
from .. import analysis
from .. import expr as _expr
+from .. import function as _function
from .. import op as _op
from ... import nd as _nd
from .common import ExprTable, new_var
@@ -914,6 +915,6 @@ def from_keras(model, shape=None, layout='NCHW'):
outexpr = [etab.get_expr(oc[0].name + ":" + str(oc[1]) + ":" + str(oc[2]))
\
for oc in model._output_coordinates]
outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr)
- func = _expr.Function(analysis.free_vars(outexpr), outexpr)
+ func = _function.Function(analysis.free_vars(outexpr), outexpr)
params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in
etab.params.items()}
return IRModule.from_expr(func), params
diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index ba93bb2..17be368 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -25,6 +25,7 @@ from tvm import relay
from topi.util import get_const_tuple
from .. import analysis
from .. import expr as _expr
+from .. import function as _function
from .. import op as _op
from .. import scope_builder as _scope_builder
from ... import nd as _nd
@@ -1096,7 +1097,7 @@ def _mx_cond(inputs, attrs, subgraphs):
else_arg_dtype_info = [arg.type_annotation.dtype for arg in else_args]
else_func = _from_mxnet_impl(subgraphs[2], else_arg_shapes,
else_arg_dtype_info)
sb.ret(_expr.Call(else_func, else_args))
- func = _expr.Function(input_args, sb.get())
+ func = _function.Function(input_args, sb.get())
ret = _expr.Call(func, inputs)
if num_outputs > 1:
ret = _expr.TupleWrapper(ret, num_outputs)
@@ -1969,7 +1970,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None):
outputs = [node_map[e[0]][e[1]] for e in jgraph["heads"]]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
- func = _expr.Function(analysis.free_vars(outputs), outputs)
+ func = _function.Function(analysis.free_vars(outputs), outputs)
return func
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 7f417d3..e1b0a7f 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -24,6 +24,7 @@ from tvm.ir import IRModule
from ... import nd as _nd
from .. import analysis
from .. import expr as _expr
+from .. import function as _function
from .. import op as _op
from .common import AttrCvt, Renamer
from .common import get_relay_op, new_var, infer_shape, infer_channels
@@ -1708,7 +1709,7 @@ class GraphProto(object):
# now return the outputs
outputs = [self._nodes[self._parse_value_proto(i)] for i in
graph.output]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
- func = _expr.Function(analysis.free_vars(outputs), outputs)
+ func = _function.Function(analysis.free_vars(outputs), outputs)
return IRModule.from_expr(func), self._params
def _parse_value_proto(self, value_proto):
@@ -1774,7 +1775,7 @@ class GraphProto(object):
----------
op_name : str
Operator name, such as Convolution, FullyConnected
- inputs : list of tvm.relay.expr.Function
+ inputs : list of tvm.relay.function.Function
List of inputs.
attrs : dict
Dict of operator attributes
@@ -1783,7 +1784,7 @@ class GraphProto(object):
Returns
-------
- sym : tvm.relay.expr.Function
+ sym : tvm.relay.function.Function
Converted relay function
"""
convert_map = _get_convert_map(opset)
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 3dca365..e0da863 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -31,6 +31,7 @@ from tvm.relay.prelude import Prelude
from .. import analysis
from .. import expr as _expr
+from .. import function as _function
from .. import op as _op
from ..expr_functor import ExprMutator
from .common import AttrCvt, get_relay_op
@@ -2461,7 +2462,7 @@ class GraphProto(object):
out.append(out_rnn)
out = out[0] if len(out) == 1 else _expr.Tuple(out)
- func = _expr.Function(analysis.free_vars(out), out)
+ func = _function.Function(analysis.free_vars(out), out)
self._mod["main"] = func
return self._mod, self._params
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 95f7579..aa51570 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -24,6 +24,7 @@ from tvm.ir import IRModule
from tvm import relay
from .. import analysis
from .. import expr as _expr
+from .. import function as _function
from .. import op as _op
from .. import qnn as _qnn
from ... import nd as _nd
@@ -2365,6 +2366,6 @@ def from_tflite(model, shape_dict, dtype_dict):
params = {k:_nd.array(np.array(v)) for k, v in exp_tab.params.items()}
outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in
model_outputs]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
- func = _expr.Function(analysis.free_vars(outputs), outputs)
+ func = _function.Function(analysis.free_vars(outputs), outputs)
mod = IRModule.from_expr(func)
return mod, params
diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py
new file mode 100644
index 0000000..786a7f4
--- /dev/null
+++ b/python/tvm/relay/function.py
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return, invalid-name, unused-import
+"""The Relay function node: the Python wrapper for ``relay.Function``."""
+from __future__ import absolute_import
+
+import tvm._ffi
+from tvm.runtime import convert
+from tvm.ir import BaseFunc
+
+from .expr import Call
+from . import _ffi_api
+
+@tvm._ffi.register_object("relay.Function")
+class Function(BaseFunc):
+    """A function declaration expression.
+
+    Registered with the FFI under the type key ``relay.Function``; the
+    underlying node is constructed through ``_ffi_api.Function``.
+
+    Parameters
+    ----------
+    params: List[tvm.relay.Var]
+        List of input parameters to the function.
+
+    body: tvm.relay.Expr
+        The body of the function.
+
+    ret_type: Optional[tvm.relay.Type]
+        The return type annotation of the function.
+
+    type_params: Optional[List[tvm.relay.TypeParam]]
+        The additional type parameters, this is only
+        used in advanced usecase of template functions.
+        An empty list is used when this is None.
+
+    attrs: Optional[tvm.ir.DictAttrs]
+        Attributes attached to the function. Passed through to the
+        constructor unchanged; may be None.
+    """
+    def __init__(self,
+                 params,
+                 body,
+                 ret_type=None,
+                 type_params=None,
+                 attrs=None):
+        # The C++ constructor requires a defined Array for type_params,
+        # so substitute an empty one rather than forwarding None.
+        if type_params is None:
+            type_params = convert([])
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.Function, params, body, ret_type, type_params, attrs)
+
+    def __call__(self, *args):
+        """Build a call expression that applies this function to ``args``.
+
+        Nothing is evaluated here; this only constructs a Relay ``Call``
+        node whose callee is this function.
+
+        Parameters
+        ----------
+        args: List[relay.Expr]
+            Arguments, passed positionally.
+
+        Returns
+        -------
+        call: tvm.relay.Call
+            The constructed call expression.
+        """
+        return Call(self, args, None, None)
+
+    def with_attr(self, attr_key, attr_value):
+        """Create a new copy of the function and update the attribute
+
+        Parameters
+        ----------
+        attr_key : str
+            The attribute key to use.
+
+        attr_value : Object
+            The new attribute value. It is converted with
+            ``tvm.runtime.convert`` before being handed to the FFI.
+
+        Returns
+        -------
+        func : Function
+            A new copy of the function
+        """
+        return _ffi_api.FunctionWithAttr(
+            self, attr_key, convert(attr_value))
diff --git a/python/tvm/relay/loops.py b/python/tvm/relay/loops.py
index 8e066ab..9af6811 100644
--- a/python/tvm/relay/loops.py
+++ b/python/tvm/relay/loops.py
@@ -20,6 +20,7 @@ Utilities for building Relay loops.
"""
from .scope_builder import ScopeBuilder
from . import expr as _expr
+from . import function as _function
def while_loop(cond, loop_vars, loop_bodies):
"""
@@ -60,6 +61,6 @@ def while_loop(cond, loop_vars, loop_bodies):
with sb.else_scope():
sb.ret(_expr.Tuple(fresh_vars))
- func = _expr.Function(fresh_vars, sb.get())
+ func = _function.Function(fresh_vars, sb.get())
let = _expr.Let(loop, func, loop)
return let
diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py
index 5288a2e..0e64a2f 100644
--- a/python/tvm/relay/prelude.py
+++ b/python/tvm/relay/prelude.py
@@ -19,7 +19,8 @@
from tvm.ir import IRModule
from .ty import GlobalTypeVar, TensorType, Any, scalar_type
-from .expr import Var, Function, GlobalVar, If, const
+from .expr import Var, GlobalVar, If, const
+from .function import Function
from .op.tensor import add, subtract, equal
from .adt import Constructor, TypeData, Clause, Match
from .adt import PatternConstructor, PatternVar, PatternWildcard
diff --git a/python/tvm/relay/testing/nat.py b/python/tvm/relay/testing/nat.py
index eb71120..4906eef 100644
--- a/python/tvm/relay/testing/nat.py
+++ b/python/tvm/relay/testing/nat.py
@@ -21,7 +21,8 @@ test cases for recursion and pattern matching."""
from tvm.relay.adt import Constructor, TypeData, Clause, Match,
PatternConstructor, PatternVar
from tvm.relay.backend.interpreter import ConstructorValue
-from tvm.relay.expr import Var, Function, GlobalVar
+from tvm.relay.expr import Var, GlobalVar
+from tvm.relay.function import Function
from tvm.relay.ty import GlobalTypeVar, TypeVar, FuncType
def define_nat_adt(prelude):
diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py
index eacfe37..e850000 100644
--- a/python/tvm/relay/testing/py_converter.py
+++ b/python/tvm/relay/testing/py_converter.py
@@ -23,7 +23,8 @@ import tvm
from tvm import relay
from tvm.relay.adt import Pattern
from tvm.relay.backend import compile_engine
-from tvm.relay.expr import Expr, Function, GlobalVar, Var
+from tvm.relay.expr import Expr, GlobalVar, Var
+from tvm.relay.function import Function
from tvm.relay.expr_functor import ExprFunctor
OUTPUT_VAR_NAME = '_py_out'
diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc
index b251645..48cb4d8 100644
--- a/src/relay/ir/function.cc
+++ b/src/relay/ir/function.cc
@@ -27,10 +27,10 @@ namespace tvm {
namespace relay {
Function::Function(tvm::Array<Var> params,
- Expr body,
- Type ret_type,
- tvm::Array<TypeVar> type_params,
- DictAttrs attrs) {
+ Expr body,
+ Type ret_type,
+ tvm::Array<TypeVar> type_params,
+ DictAttrs attrs) {
ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
CHECK(params.defined());
CHECK(type_params.defined());
@@ -66,7 +66,6 @@ TVM_REGISTER_GLOBAL("relay.ir.Function")
return Function(params, body, ret_type, ty_params, attrs);
});
-
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const FunctionNode*>(ref.get());