This is an automated email from the ASF dual-hosted git repository.
mousius pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 095b639 Revert "[Frontend] Add Span filling for frontends to Relay (#9723)" (#10072)
095b639 is described below
commit 095b63935efdd42effc0298998d67838086b7b26
Author: Chun-I Tsai <[email protected]>
AuthorDate: Wed Jan 26 22:35:38 2022 +0800
Revert "[Frontend] Add Span filling for frontends to Relay (#9723)" (#10072)
Reverted due to a failure in LSTM conversion from PyTorch.
---
python/tvm/relay/expr.py | 7 +--
python/tvm/relay/frontend/common.py | 53 ---------------------
python/tvm/relay/frontend/pytorch.py | 19 --------
python/tvm/relay/frontend/tensorflow.py | 17 ++++++-
python/tvm/relay/frontend/tensorflow2.py | 17 ++++++-
python/tvm/relay/frontend/tflite.py | 16 ++-----
src/printer/relay_text_printer.cc | 23 +++------
src/printer/text_printer.h | 2 +-
src/relay/ir/expr.cc | 4 +-
tests/python/frontend/pytorch/test_forward.py | 47 -------------------
tests/python/frontend/tensorflow/test_forward.py | 54 ----------------------
.../frontend/tensorflow2/test_sequential_models.py | 24 +---------
tests/python/frontend/tflite/test_forward.py | 54 ----------------------
13 files changed, 48 insertions(+), 289 deletions(-)
diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index 598354e..811e205 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -316,13 +316,10 @@ class TupleGetItem(ExprWithOp):
index: int
The index.
-
- span: Optional[tvm.relay.Span]
- Span that points to original source code
"""
- def __init__(self, tuple_value, index, span=None):
- self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index, span)
+ def __init__(self, tuple_value, index):
+ self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index)
@tvm._ffi.register_object("relay.RefCreate")
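For context, the user-visible effect of this hunk is that relay.TupleGetItem no
longer accepts a span argument. A minimal sketch (illustrative, assuming a TVM
build at this commit):

    from tvm import relay

    x = relay.var("x", shape=(2, 2))
    t = relay.Tuple([x, x])
    item = relay.TupleGetItem(t, 0)  # after the revert: (tuple_value, index) only
    # before the revert, relay.TupleGetItem(t, 0, span) also took an optional Span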
diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index f8c12ff..eeede18 100755
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -25,7 +25,6 @@ from tvm.ir import IRModule
from tvm.topi.utils import get_const_tuple
from .. import expr as _expr
-from ..expr_functor import ExprMutator
from .. import function as _function
from .. import transform as _transform
from .. import op as _op
@@ -955,55 +954,3 @@ def try_resolve_var_to_const(x, graph_params):
return _op.const(value, dtype)
return x
-
-
-def set_span(sym, node_name):
- """Set up the span of relay expression(s) while converting OP"""
-
- class SpanFiller(ExprMutator):
- """SpanFiller"""
-
- def __init__(self, node_name, suffix_str="_PART_"):
- ExprMutator.__init__(self)
- self.node_name = node_name
- self.suffix_str = suffix_str
- self.counter = 0
- self.distance_from_leaf = -1
-
- def _create_span(self):
- if self.distance_from_leaf == 0:
- return tvm.relay.Span(tvm.relay.SourceName(self.node_name), 0, 0, 0, 0)
- self.distance_from_leaf -= 1
- span_str = "{}{}{}".format(self.node_name, self.suffix_str,
str(self.counter))
- self.counter += 1
- return tvm.relay.Span(tvm.relay.SourceName(span_str), 0, 0, 0, 0)
-
- def visit_call(self, call):
- if call.span is None:
- self.distance_from_leaf += 1
- new_args = [self.visit(arg) for arg in call.args]
- return _expr.Call(
- call.op, new_args, call.attrs, call.type_args, self._create_span()
- )
- return call
-
- def visit_tuple(self, tup):
- if tup.span is None:
- self.distance_from_leaf += 1
- return _expr.Tuple([self.visit(field) for field in tup.fields], self._create_span())
- return tup
-
- def visit_tuple_getitem(self, op):
- if op.span is None:
- self.distance_from_leaf += 1
- return _expr.TupleGetItem(self.visit(op.tuple_value), op.index, self._create_span())
- return op
-
- def fill(self, sym):
- if isinstance(sym, _expr.TupleWrapper):
- return _expr.TupleWrapper(self.visit(sym.tuple_value), sym.size)
- if isinstance(sym, _expr.RelayExpr):
- return self.visit(sym)
- return sym
-
- return SpanFiller(node_name).fill(sym)
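For reference, a rough sketch of what the removed helper produced (span names
follow the deleted code above; nothing here remains in the API after this
revert):

    from tvm import relay

    x = relay.var("x", shape=(1, 4))
    y = relay.nn.relu(relay.nn.softmax(x))
    # set_span(y, "my_node") walked the expression with the SpanFiller above,
    # tagging the outermost relu call "my_node" and the inner softmax call
    # "my_node_PART_0" (each as Span(SourceName(name), 0, 0, 0, 0))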
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index b718837..f7538f0 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -45,7 +45,6 @@ from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
from .common import infer_value_simulated as _infer_value_simulated
from .common import lstm_cell, try_infer_value, unbind
-from .common import set_span
from .pytorch_utils import is_version_greater_than
__all__ = ["from_pytorch"]
@@ -3276,9 +3275,6 @@ class PyTorchOpConverter:
def convert_operators(self, operators, outputs, ret_names):
"""Convert each Torch IR operators to Relay equivalent"""
- # an op node might not belong to any scope in the trace info natively
- # use a counter to avoid messing up its scope in the span
- empty_counter = 0
for node_name, op_node in operators:
operator = op_node.kind()
inputs = _get_op_inputs(op_node, outputs)
@@ -3339,9 +3335,6 @@ class PyTorchOpConverter:
relay_out = relay_op(
inputs, _get_input_types(op_node, outputs, default_dtype=self.default_dtype)
)
- span_str, empty_counter = self._get_torch_span(op_node, empty_counter)
- relay_out = set_span(relay_out, span_str)
-
self.record_output_type(relay_out)
if isinstance(relay_out, tuple):
@@ -3355,18 +3348,6 @@ class PyTorchOpConverter:
return [_wrap_const(outputs[ret_name]) for ret_name in ret_names]
- def _get_torch_span(self, node, empty_counter):
- # torch span looks like
- # %input.5 : Float(...) = aten::relu_(%input.3), scope: __module.relu # ${torch}/nn file
- # the scope part might not exist
- if node.scopeName():
- scope_name_str = "jit._trace.TopLevelTracedModule: " + node.scopeName()
- else:
- scope_name_str = "warning: no trace info " + str(empty_counter)
- empty_counter += 1
- span_str = "C.graph: {}, {}".format(node.kind(), scope_name_str)
- return span_str, empty_counter
-
def _pytorch_result_type(dtypes, non_tensor_inputs):
"""This promotes TVM dtypes like PyTorch would"""
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index c2aa5a1..d35e0e1 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -37,7 +37,6 @@ from .common import get_relay_op
from .common import infer_type as _infer_type
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
-from .common import set_span
from .tensorflow_ops import _convert_map
from .tensorflow_ops import _need_prelude_for_shape_inference
@@ -1029,10 +1028,24 @@ class GraphProto(object):
else:
raise NotImplementedError("Operator {} not implemented.".format(op_name))
- sym = set_span(sym, node_name)
+ sym = self._set_span(sym, node_name)
return sym
+ @staticmethod
+ def _set_span(sym, node_name):
+ span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0)
+ if isinstance(sym, _expr.Call) and sym.span is None:
+ sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span)
+ elif isinstance(sym, _expr.TupleWrapper):
+ tuple_value = sym.tuple_value
+ if isinstance(tuple_value, _expr.Call) and tuple_value.span is None:
+ tuple_value = _expr.Call(
+ tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span
+ )
+ sym = _expr.TupleWrapper(tuple_value, sym.size)
+ return sym
+
def _licm_construct(self, loop_name, node_name):
"""Construct a node by considering whether it is
loop invariant with the given while loop. If yes, we
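A minimal sketch of the restored single-node behavior (illustrative only; the
restored _set_span rebuilds just the outermost Call, whereas the reverted
SpanFiller recursed into nested expressions):

    import tvm
    from tvm import relay

    call = relay.nn.relu(relay.var("x", shape=(1, 4)))
    span = tvm.relay.Span(tvm.relay.SourceName("my_node"), 0, 0, 0, 0)
    tagged = relay.Call(call.op, call.args, call.attrs, call.type_args, span)
    # nested calls keep span=None under the restored behavior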
diff --git a/python/tvm/relay/frontend/tensorflow2.py b/python/tvm/relay/frontend/tensorflow2.py
index 2c8b7d4..465f530 100644
--- a/python/tvm/relay/frontend/tensorflow2.py
+++ b/python/tvm/relay/frontend/tensorflow2.py
@@ -36,7 +36,6 @@ from .. import analysis
from .. import function as _function
from ..loops import while_loop as _while_loop
from .common import infer_type as _infer_type
-from .common import set_span
from .tensorflow_ops import _convert_map as _convert_map_common
from .tensorflow_ops import _get_more_static_shape_rank
@@ -59,6 +58,22 @@ def _infer_type_with_prelude(val, prelude):
return body.checked_type
+def set_span(sym, node_name):
+ """set span of symbol"""
+
+ span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0)
+ if isinstance(sym, _expr.Call):
+ sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span)
+ elif isinstance(sym, _expr.TupleWrapper):
+ tuple_value = sym.tuple_value
+ if isinstance(tuple_value, _expr.Call):
+ tuple_value = _expr.Call(
+ tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span
+ )
+ sym = _expr.TupleWrapper(tuple_value, sym.size)
+ return sym
+
+
def is_tensor_list_constuctor(tf_node):
"""Check whether is tensor list constructor node."""
return tf_node.op == "TensorListReserve"
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 12296bd..b675dd5 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -32,7 +32,6 @@ from .. import op as _op
from .. import qnn as _qnn
from .common import ExprTable
from .common import infer_shape as _infer_shape
-from .common import set_span
from .common import to_int_list
from .tflite_flexbuffer import FlexBufferDecoder
@@ -240,17 +239,12 @@ class OperatorConverter(object):
if len(output_tensors) == 1:
tensor_idx = output_tensors[0].tensor_idx
- curr_output = get_tensor_name(self.subgraph, tensor_idx)
- ret = set_span(ret, "location: {}, output_name: {}".format(op_idx, curr_output))
- self.exp_tab.set_expr(curr_output, ret)
+ self.exp_tab.set_expr(get_tensor_name(self.subgraph, tensor_idx), ret)
else:
- out_names = []
- for output_tensor in output_tensors:
- out_names.append(get_tensor_name(self.subgraph, output_tensor.tensor_idx))
- curr_output = ", ".join(out_names)
- ret = set_span(ret, "location: {}, output_name: {}".format(op_idx, curr_output))
- for idx, out_name in enumerate(out_names):
- self.exp_tab.set_expr(out_name, ret[idx])
+ for idx, output_tensor in enumerate(output_tensors):
+ self.exp_tab.set_expr(
+ get_tensor_name(self.subgraph, output_tensor.tensor_idx), ret[idx]
+ )
def get_op_code_str(self, op):
"""Get TFLite ops string representation"""
diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc
index 7654ef1..fdc6c37 100644
--- a/src/printer/relay_text_printer.cc
+++ b/src/printer/relay_text_printer.cc
@@ -389,21 +389,12 @@ Doc RelayTextPrinter::VisitExpr_(const TupleNode* op) {
if (op->fields.size() == 1) {
doc << ",";
}
- doc << ")";
- if (op->span.defined()) {
- doc << " /* " << PrintSpan(op->span) << " */";
- }
- return doc;
+ return doc << ")";
}
Doc RelayTextPrinter::VisitExpr_(const TupleGetItemNode* op) {
Doc doc;
- doc << Print(op->tuple) << "." << op->index;
-
- if (op->span.defined()) {
- doc << " /* " << PrintSpan(op->span) << " */";
- }
- return doc;
+ return doc << Print(op->tuple) << "." << op->index;
}
Doc RelayTextPrinter::VisitExpr_(const IfNode* op) {
@@ -977,13 +968,11 @@ Doc RelayTextPrinter::PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>&
return doc;
}
-Doc RelayTextPrinter::PrintSpan(const Span& span, bool include_spans) {
+Doc RelayTextPrinter::PrintSpan(const Span& span) {
Doc doc;
- if (include_spans) {
- const auto* span_node = span.as<SpanNode>();
- ICHECK(span_node);
- doc << span_node->source_name->name;
- }
+ const auto* span_node = span.as<SpanNode>();
+ ICHECK(span_node);
+ doc << span_node->source_name->name;
return doc;
}
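The printer change is observable from Python. A hedged sketch of the expected
output difference (behavior inferred from the C++ above):

    from tvm import relay

    x = relay.var("x", shape=(2,))
    span = relay.Span(relay.SourceName("src"), 0, 0, 0, 0)
    tup = relay.Tuple([x, x], span)
    print(tup)
    # before this revert the tuple printed with a trailing " /* src */";
    # after it, Tuple and TupleGetItem print without span comments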
diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h
index ca46700..a4d0ff3 100644
--- a/src/printer/text_printer.h
+++ b/src/printer/text_printer.h
@@ -113,7 +113,7 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
*/
Doc PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>& map);
- Doc PrintSpan(const Span& span, bool include_spans = true);
+ Doc PrintSpan(const Span& span);
Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false);
diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc
index 64d921e..73ae3fa 100644
--- a/src/relay/ir/expr.cc
+++ b/src/relay/ir/expr.cc
@@ -375,8 +375,8 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional<Expr> opt_tuple,
TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
-TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index, Span span) {
- return TupleGetItem(tuple, index, span);
+TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index) {
+ return TupleGetItem(tuple, index);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 2c07094..3fbef49 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -247,53 +247,6 @@ def verify_model(
torch.cuda.empty_cache()
-def verify_span(model_name, input_data=[], custom_convert_map={}):
- if isinstance(model_name, str):
- baseline_model, baseline_input = load_model(model_name)
- elif isinstance(input_data, list):
- baseline_model = model_name
- baseline_input = input_data
- elif isinstance(input_data, torch.Tensor) or len(input_data.shape) == 0:
- baseline_model = model_name
- baseline_input = [input_data]
- else:
- assert False, "Unexpected input format"
-
- trace = torch.jit.trace(baseline_model, [input.clone() for input in baseline_input])
- if isinstance(baseline_model, torch.nn.Module):
- trace = trace.float().eval()
-
- if torch.cuda.is_available():
- trace = trace.cuda()
- else:
- trace = trace.cpu()
-
- input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)]
- input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input]))
- mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map)
-
- # collect fail cases for the convenience of further improvement
- fail_cases = []
- mod_main_start = False
- for line in str(mod).split("\n"):
- if "@main" in line:
- mod_main_start = True
- continue
-
- if mod_main_start == True:
- if "}" == line:
- break
- elif not ("/*" in line and "*/" in line):
- fail_cases.append(line)
-
- print(fail_cases)
- assert len(fail_cases) == 0
-
-
-def test_span():
- verify_span("resnet18")
-
-
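The removed check can be read as the following predicate: every statement line
inside the printed "@main" body must carry a "/* ... */" span comment. A
condensed sketch of that idea (not part of the test suite after this revert):

    def all_main_lines_have_span_comment(mod_text):
        in_main = False
        for line in mod_text.split("\n"):
            if "@main" in line:
                in_main = True
                continue
            if in_main:
                if line == "}":
                    break
                if not ("/*" in line and "*/" in line):
                    return False
        return True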
# Single operator tests
@tvm.testing.uses_gpu
def test_forward_pixel_shuffle():
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index c76803b..a5a67e1 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -298,60 +298,6 @@ def is_gpu_available():
return False
-def verify_span(mod):
- # collect fail cases for the convenience of further improvement
- fail_cases = []
- mod_main_start = False
- for line in str(mod).split("\n"):
- if "@main" in line:
- mod_main_start = True
- continue
-
- if mod_main_start == True:
- if "}" == line:
- break
- elif not ("/*" in line and "*/" in line):
- fail_cases.append(line)
-
- print(fail_cases)
- assert len(fail_cases) == 0
-
-
-def simple_model():
- input_node = tf.placeholder(shape=[None, None, 3, 1], dtype=np.float32, name="input")
-
- shape = tf.shape(input_node)
- stack = tf.stack([shape[0], 3, 3], axis=0)
- output_node = tf.reshape(input_node, stack, name="output")
- return output_node
-
-
-#######################################################################
-# Span fill up
-# -------
-def test_span_complement_simple_model():
- with tf.Graph().as_default() as graph:
- model_graph = simple_model()
- graph_def = graph.as_graph_def()
-
- graph_def = tf_testing.ProcessGraphDefParam(graph_def)
-
- mod, params = relay.frontend.from_tensorflow(graph_def, shape={"input:0": (1, 3, 3, 1)})
- verify_span(mod)
-
-
-def test_span_complement_big_model():
- with tf.Graph().as_default() as graph:
- graph_def = tf_testing.get_workload("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb")
- # Call the utility to import the graph definition into default graph.
- graph_def = tf_testing.ProcessGraphDefParam(graph_def)
-
- mod, params = relay.frontend.from_tensorflow(
- graph_def, shape={"input_tensor:0": (128, 224, 224, 3)}
- )
- verify_span(mod)
-
-
#######################################################################
# Pooling
# -------
diff --git a/tests/python/frontend/tensorflow2/test_sequential_models.py b/tests/python/frontend/tensorflow2/test_sequential_models.py
index b76b4a7..1b5a634 100644
--- a/tests/python/frontend/tensorflow2/test_sequential_models.py
+++ b/tests/python/frontend/tensorflow2/test_sequential_models.py
@@ -26,25 +26,6 @@ from tensorflow.python.framework.convert_to_constants import convert_variables_t
from common import compare_tf_tvm
from common import run_tf_code
-from tvm.relay.frontend.tensorflow2 import from_tensorflow
-
-
-def verify_span(mod):
- fail_cases = []
- mod_main_start = False
- for line in str(mod).split("\n"):
- if "@main" in line:
- mod_main_start = True
- continue
-
- if mod_main_start == True:
- if "}" == line:
- break
- elif not ("/*" in line and "*/" in line):
- fail_cases.append(line)
-
- print(fail_cases)
- assert len(fail_cases) == 0
def run_sequential_model(model_fn, input_shape):
@@ -67,10 +48,7 @@ def run_sequential_model(model_fn, input_shape):
gdef = f.graph.as_graph_def(add_shapes=True)
return gdef, _input, _output
- gdef, _input, _output = model_graph(model_fn, input_shape)
- mod, _ = from_tensorflow(gdef)
- compare_tf_tvm(gdef, _input, _output, runtime="vm")
- verify_span(mod)
+ compare_tf_tvm(*model_graph(model_fn, input_shape), runtime="vm")
def test_dense_model():
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 77acce4..60af94b 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -259,59 +259,6 @@ def run_tflite_graph(tflite_model_buf, input_data):
return tflite_output
-def run_span_verification(
- tflite_model_buf,
- input_data,
- input_node,
- num_output=1,
- target="llvm",
- out_names=None,
- mode="graph_executor",
-):
- """Generic function to compile on relay and execute on tvm"""
- # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
- try:
- import tflite.Model
-
- tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
- except AttributeError:
- import tflite
-
- tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
- except ImportError:
- raise ImportError("The tflite package must be installed")
-
- input_data = convert_to_list(input_data)
- input_node = convert_to_list(input_node)
-
- shape_dict = {}
- dtype_dict = {}
- for i, e in enumerate(input_node):
- shape_dict[e] = input_data[i].shape
- dtype_dict[e] = input_data[i].dtype.name
-
- mod, _ = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict)
- verify_span(mod)
-
-
-def verify_span(mod):
- fail_cases = []
- mod_main_start = False
- for line in str(mod).split("\n"):
- if "@main" in line:
- mod_main_start = True
- continue
-
- if mod_main_start == True:
- if "}" == line:
- break
- elif not ("/*" in line and "*/" in line):
- fail_cases.append(line)
-
- print(fail_cases)
- assert len(fail_cases) == 0
-
-
def compare_tflite_with_tvm(
in_data,
in_name,
@@ -4620,7 +4567,6 @@ def test_forward_tflite2_qnn_resnet50():
tflite_output = run_tflite_graph(tflite_model_buf, data)
tflite_predictions = np.squeeze(tflite_output)
tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
- run_span_verification(tflite_model_buf, np.array(data), "input_1")
tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), "input_1")
tvm_predictions = np.squeeze(tvm_output)
tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]