This is an automated email from the ASF dual-hosted git repository.
zhaowu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ce108c1 [Frontend] Add Span filling for frontends to Relay (#9723)
ce108c1 is described below
commit ce108c1f53235a483eb11dffffb8770642907642
Author: Chun-I Tsai <[email protected]>
AuthorDate: Tue Dec 28 12:53:18 2021 +0800
[Frontend] Add Span filling for frontends to Relay (#9723)
* [Frontend] Add Span filling for frontends to Relay
* Add a common span filling feature for tf1/2, tflite and pytorch.
* Add test case for Span filling in each frontend.
* Expose Tuple and TupleGetItem to python end
* [Frontend] Add Span filling for frontends to Relay
* Fix lint errors
* Change default string of scope_part in Pytorch
* Reorder the span position for one to many conversion
* [Frontend] Add Span filling for frontends to Relay
* nit fixed
* Add a bool flag to control print span
* refactor pytorch get span to a briefer way
* [Frontend] Add Span filling for frontends to Relay
* Add one more condition for SpanFiller
* Refine the format for those pytorch node without scopeName
* [Frontend] Add Span filling for frontends to Relay
* Fix lint
---
python/tvm/relay/expr.py | 7 ++-
python/tvm/relay/frontend/common.py | 53 +++++++++++++++++++++
python/tvm/relay/frontend/pytorch.py | 19 ++++++++
python/tvm/relay/frontend/tensorflow.py | 17 +------
python/tvm/relay/frontend/tensorflow2.py | 17 +------
python/tvm/relay/frontend/tflite.py | 16 +++++--
src/printer/relay_text_printer.cc | 23 ++++++---
src/printer/text_printer.h | 2 +-
src/relay/ir/expr.cc | 4 +-
tests/python/frontend/pytorch/test_forward.py | 47 +++++++++++++++++++
tests/python/frontend/tensorflow/test_forward.py | 54 ++++++++++++++++++++++
.../frontend/tensorflow2/test_sequential_models.py | 24 +++++++++-
tests/python/frontend/tflite/test_forward.py | 54 ++++++++++++++++++++++
13 files changed, 289 insertions(+), 48 deletions(-)
diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index 811e205..598354e 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -316,10 +316,13 @@ class TupleGetItem(ExprWithOp):
index: int
The index.
+
+ span: Optional[tvm.relay.Span]
+ Span that points to original source code
"""
- def __init__(self, tuple_value, index):
- self.__init_handle_by_constructor__(_ffi_api.TupleGetItem,
tuple_value, index)
+ def __init__(self, tuple_value, index, span=None):
+ self.__init_handle_by_constructor__(_ffi_api.TupleGetItem,
tuple_value, index, span)
@tvm._ffi.register_object("relay.RefCreate")
diff --git a/python/tvm/relay/frontend/common.py
b/python/tvm/relay/frontend/common.py
index 407afc4..be3d5ae 100755
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -25,6 +25,7 @@ from tvm.ir import IRModule
from tvm.topi.utils import get_const_tuple
from .. import expr as _expr
+from ..expr_functor import ExprMutator
from .. import function as _function
from .. import transform as _transform
from .. import op as _op
@@ -954,3 +955,55 @@ def try_resolve_var_to_const(x, graph_params):
return _op.const(value, dtype)
return x
+
+
+def set_span(sym, node_name):
+ """Fill in spans derived from node_name on the span-less Call/Tuple/TupleGetItem nodes of sym"""
+
+ class SpanFiller(ExprMutator):
+ """Give span-less nodes node_name (outermost) or node_name + suffix_str + counter (nested)"""
+
+ def __init__(self, node_name, suffix_str="_PART_"):
+ ExprMutator.__init__(self)
+ self.node_name = node_name
+ self.suffix_str = suffix_str
+ self.counter = 0
+ self.distance_from_leaf = -1  # despite the name, tracks nesting depth of span-less nodes; 0 marks the outermost
+
+ def _create_span(self):  # called only after the node's children have been visited
+ if self.distance_from_leaf == 0:
+ return tvm.relay.Span(tvm.relay.SourceName(self.node_name), 0,
0, 0, 0)
+ self.distance_from_leaf -= 1
+ span_str = "{}{}{}".format(self.node_name, self.suffix_str,
str(self.counter))
+ self.counter += 1
+ return tvm.relay.Span(tvm.relay.SourceName(span_str), 0, 0, 0, 0)
+
+ def visit_call(self, call):  # a call that already has a span (and its subtree) is left untouched
+ if call.span is None:
+ self.distance_from_leaf += 1
+ new_args = [self.visit(arg) for arg in call.args]
+ return _expr.Call(
+ call.op, new_args, call.attrs, call.type_args,
self._create_span()
+ )
+ return call
+
+ def visit_tuple(self, tup):
+ if tup.span is None:
+ self.distance_from_leaf += 1
+ return _expr.Tuple([self.visit(field) for field in
tup.fields], self._create_span())
+ return tup
+
+ def visit_tuple_getitem(self, op):
+ if op.span is None:
+ self.distance_from_leaf += 1
+ return _expr.TupleGetItem(self.visit(op.tuple_value),
op.index, self._create_span())
+ return op
+
+ def fill(self, sym):  # TupleWrapper is rebuilt around its filled tuple_value; non-RelayExpr values pass through
+ if isinstance(sym, _expr.TupleWrapper):
+ return _expr.TupleWrapper(self.visit(sym.tuple_value),
sym.size)
+ if isinstance(sym, _expr.RelayExpr):
+ return self.visit(sym)
+ return sym
+
+ return SpanFiller(node_name).fill(sym)  # a fresh filler per call restarts the _PART_ counter
diff --git a/python/tvm/relay/frontend/pytorch.py
b/python/tvm/relay/frontend/pytorch.py
index 24ccad5..6e8ad68 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -45,6 +45,7 @@ from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
from .common import infer_value_simulated as _infer_value_simulated
from .common import lstm_cell, try_infer_value, unbind
+from .common import set_span
from .pytorch_utils import is_version_greater_than
__all__ = ["from_pytorch"]
@@ -3271,6 +3272,9 @@ class PyTorchOpConverter:
def convert_operators(self, operators, outputs, ret_names):
"""Convert each Torch IR operators to Relay equivalent"""
+ # an op node might not belong to any of scope in trace info natively
+ # use a cunter to prevent from messing up its scope in span
+ empty_counter = 0
for node_name, op_node in operators:
operator = op_node.kind()
inputs = _get_op_inputs(op_node, outputs)
@@ -3308,6 +3312,9 @@ class PyTorchOpConverter:
relay_out = relay_op(
inputs, _get_input_types(op_node, outputs,
default_dtype=self.default_dtype)
)
+ span_str, empty_counter = self._get_torch_span(op_node,
empty_counter)
+ relay_out = set_span(relay_out, span_str)
+
self.record_output_type(relay_out)
if isinstance(relay_out, tuple):
@@ -3321,6 +3328,18 @@ class PyTorchOpConverter:
return [_wrap_const(outputs[ret_name]) for ret_name in ret_names]
+ def _get_torch_span(self, node, empty_counter):  # returns (span_str, updated empty_counter)
+ # Build a span string from node.kind() and its trace scope; a traced node prints like
+ # %input.5 : Float(...) = aten::relu_(%input.3), scope: __module.relu
# ${torch}/nn file
+ # but scopeName() can be empty, so fall back to a numbered "no trace info" marker
+ if node.scopeName():
+ scope_name_str = "jit._trace.TopLevelTracedModule: " +
node.scopeName()
+ else:
+ scope_name_str = "warning: no trace info " + str(empty_counter)
+ empty_counter += 1  # number scope-less nodes so their spans stay distinct
+ span_str = "C.graph: {}, {}".format(node.kind(), scope_name_str)
+ return span_str, empty_counter
+
def _pytorch_result_type(dtypes, non_tensor_inputs):
"""This promotes TVM dtypes like PyTorch would"""
diff --git a/python/tvm/relay/frontend/tensorflow.py
b/python/tvm/relay/frontend/tensorflow.py
index d35e0e1..c2aa5a1 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -37,6 +37,7 @@ from .common import get_relay_op
from .common import infer_type as _infer_type
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
+from .common import set_span
from .tensorflow_ops import _convert_map
from .tensorflow_ops import _need_prelude_for_shape_inference
@@ -1028,24 +1029,10 @@ class GraphProto(object):
else:
raise NotImplementedError("Operator {} not
implemented.".format(op_name))
- sym = self._set_span(sym, node_name)
+ sym = set_span(sym, node_name)
return sym
- @staticmethod
- def _set_span(sym, node_name):
- span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0)
- if isinstance(sym, _expr.Call) and sym.span is None:
- sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span)
- elif isinstance(sym, _expr.TupleWrapper):
- tuple_value = sym.tuple_value
- if isinstance(tuple_value, _expr.Call) and tuple_value.span is
None:
- tuple_value = _expr.Call(
- tuple_value.op, tuple_value.args, tuple_value.attrs,
tuple_value.type_args, span
- )
- sym = _expr.TupleWrapper(tuple_value, sym.size)
- return sym
-
def _licm_construct(self, loop_name, node_name):
"""Construct a node by considering whether it is
loop invariant with the given while loop. If yes, we
diff --git a/python/tvm/relay/frontend/tensorflow2.py
b/python/tvm/relay/frontend/tensorflow2.py
index 465f530..2c8b7d4 100644
--- a/python/tvm/relay/frontend/tensorflow2.py
+++ b/python/tvm/relay/frontend/tensorflow2.py
@@ -36,6 +36,7 @@ from .. import analysis
from .. import function as _function
from ..loops import while_loop as _while_loop
from .common import infer_type as _infer_type
+from .common import set_span
from .tensorflow_ops import _convert_map as _convert_map_common
from .tensorflow_ops import _get_more_static_shape_rank
@@ -58,22 +59,6 @@ def _infer_type_with_prelude(val, prelude):
return body.checked_type
-def set_span(sym, node_name):
- """set span of symbol"""
-
- span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0)
- if isinstance(sym, _expr.Call):
- sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span)
- elif isinstance(sym, _expr.TupleWrapper):
- tuple_value = sym.tuple_value
- if isinstance(tuple_value, _expr.Call):
- tuple_value = _expr.Call(
- tuple_value.op, tuple_value.args, tuple_value.attrs,
tuple_value.type_args, span
- )
- sym = _expr.TupleWrapper(tuple_value, sym.size)
- return sym
-
-
def is_tensor_list_constuctor(tf_node):
"""Check whether is tensor list constructor node."""
return tf_node.op == "TensorListReserve"
diff --git a/python/tvm/relay/frontend/tflite.py
b/python/tvm/relay/frontend/tflite.py
index f0f20e1..b0b2bd3 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -32,6 +32,7 @@ from .. import op as _op
from .. import qnn as _qnn
from .common import ExprTable
from .common import infer_shape as _infer_shape
+from .common import set_span
from .common import to_int_list
from .tflite_flexbuffer import FlexBufferDecoder
@@ -239,12 +240,17 @@ class OperatorConverter(object):
if len(output_tensors) == 1:
tensor_idx = output_tensors[0].tensor_idx
- self.exp_tab.set_expr(get_tensor_name(self.subgraph,
tensor_idx), ret)
+ curr_output = get_tensor_name(self.subgraph, tensor_idx)
+ ret = set_span(ret, "location: {}, output_name:
{}".format(op_idx, curr_output))
+ self.exp_tab.set_expr(curr_output, ret)
else:
- for idx, output_tensor in enumerate(output_tensors):
- self.exp_tab.set_expr(
- get_tensor_name(self.subgraph,
output_tensor.tensor_idx), ret[idx]
- )
+ out_names = []
+ for output_tensor in output_tensors:
+ out_names.append(get_tensor_name(self.subgraph,
output_tensor.tensor_idx))
+ curr_output = ", ".join(out_names)
+ ret = set_span(ret, "location: {}, output_name:
{}".format(op_idx, curr_output))
+ for idx, out_name in enumerate(out_names):
+ self.exp_tab.set_expr(out_name, ret[idx])
def get_op_code_str(self, op):
"""Get TFLite ops string representation"""
diff --git a/src/printer/relay_text_printer.cc
b/src/printer/relay_text_printer.cc
index fdc6c37..7654ef1 100644
--- a/src/printer/relay_text_printer.cc
+++ b/src/printer/relay_text_printer.cc
@@ -389,12 +389,21 @@ Doc RelayTextPrinter::VisitExpr_(const TupleNode* op) {
if (op->fields.size() == 1) {
doc << ",";
}
- return doc << ")";
+ doc << ")";
+ if (op->span.defined()) {
+ doc << " /* " << PrintSpan(op->span) << " */";
+ }
+ return doc;
}
Doc RelayTextPrinter::VisitExpr_(const TupleGetItemNode* op) {
Doc doc;
- return doc << Print(op->tuple) << "." << op->index;
+ doc << Print(op->tuple) << "." << op->index;
+
+ if (op->span.defined()) {
+ doc << " /* " << PrintSpan(op->span) << " */";
+ }
+ return doc;
}
Doc RelayTextPrinter::VisitExpr_(const IfNode* op) {
@@ -968,11 +977,13 @@ Doc RelayTextPrinter::PrintMapAsAttributeValue(const
Map<ObjectRef, ObjectRef>&
return doc;
}
-Doc RelayTextPrinter::PrintSpan(const Span& span) {
+Doc RelayTextPrinter::PrintSpan(const Span& span, bool include_spans) {
Doc doc;
- const auto* span_node = span.as<SpanNode>();
- ICHECK(span_node);
- doc << span_node->source_name->name;
+ if (include_spans) {
+ const auto* span_node = span.as<SpanNode>();
+ ICHECK(span_node);
+ doc << span_node->source_name->name;
+ }
return doc;
}
diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h
index a4d0ff3..ca46700 100644
--- a/src/printer/text_printer.h
+++ b/src/printer/text_printer.h
@@ -113,7 +113,7 @@ class RelayTextPrinter : public ExprFunctor<Doc(const
Expr&)>,
*/
Doc PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>& map);
- Doc PrintSpan(const Span& span);
+ Doc PrintSpan(const Span& span, bool include_spans = true);
Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false);
diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc
index b680a49..f8cb4f0 100644
--- a/src/relay/ir/expr.cc
+++ b/src/relay/ir/expr.cc
@@ -362,8 +362,8 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item,
Optional<Expr> opt_tuple,
TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
-TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int
index) {
- return TupleGetItem(tuple, index);
+TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int
index, Span span) {
+ return TupleGetItem(tuple, index, span);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
diff --git a/tests/python/frontend/pytorch/test_forward.py
b/tests/python/frontend/pytorch/test_forward.py
index 86970bf..a64fa0b 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -247,6 +247,53 @@ def verify_model(
torch.cuda.empty_cache()
+def verify_span(model_name, input_data=[], custom_convert_map={}):
+ if isinstance(model_name, str):
+ baseline_model, baseline_input = load_model(model_name)
+ elif isinstance(input_data, list):
+ baseline_model = model_name
+ baseline_input = input_data
+ elif isinstance(input_data, torch.Tensor) or len(input_data.shape) == 0:
+ baseline_model = model_name
+ baseline_input = [input_data]
+ else:
+ assert False, "Unexpected input format"
+
+ trace = torch.jit.trace(baseline_model, [inp.clone() for inp in
baseline_input])
+ if isinstance(baseline_model, torch.nn.Module):
+ trace = trace.float().eval()
+
+ if torch.cuda.is_available():
+ trace = trace.cuda()
+ else:
+ trace = trace.cpu()
+
+ input_names = ["input{}".format(idx) for idx, _ in
enumerate(baseline_input)]
+ input_shapes = list(zip(input_names, [inp.shape for inp in
baseline_input]))
+ mod, _ = relay.frontend.from_pytorch(trace, input_shapes,
custom_convert_map)
+
+ # collect the lines of @main's body that lack a /* span */ annotation
+ fail_cases = []
+ mod_main_start = False
+ for line in str(mod).split("\n"):
+ if "@main" in line:
+ mod_main_start = True
+ continue
+
+ if mod_main_start:
+ if line == "}":
+ break
+ elif not ("/*" in line and "*/" in line):
+ fail_cases.append(line)
+
+ # report every offending line through the assertion message
+ assert not fail_cases, "lines without span: {}".format(fail_cases)
+
+
+def test_span():
+ verify_span("resnet18")
+
+
# Single operator tests
@tvm.testing.uses_gpu
def test_forward_pixel_shuffle():
diff --git a/tests/python/frontend/tensorflow/test_forward.py
b/tests/python/frontend/tensorflow/test_forward.py
index 338d219..be32ca3 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -298,6 +298,60 @@ def is_gpu_available():
return False
+def verify_span(mod):
+ # collect the lines of @main's body that lack a /* span */ annotation
+ fail_cases = []
+ mod_main_start = False
+ for line in str(mod).split("\n"):
+ if "@main" in line:
+ mod_main_start = True
+ continue
+
+ if mod_main_start:
+ if line == "}":
+ break
+ elif not ("/*" in line and "*/" in line):
+ fail_cases.append(line)
+
+ # report every offending line through the assertion message
+ assert not fail_cases, "lines without span: {}".format(fail_cases)
+
+
+def simple_model():
+ input_node = tf.placeholder(shape=[None, None, 3, 1], dtype=np.float32,
name="input")
+
+ shape = tf.shape(input_node)
+ stack = tf.stack([shape[0], 3, 3], axis=0)
+ output_node = tf.reshape(input_node, stack, name="output")
+ return output_node
+
+
+#######################################################################
+# Span filling
+# ------------
+def test_span_complement_simple_model():
+ with tf.Graph().as_default() as graph:
+ simple_model()
+ graph_def = graph.as_graph_def()
+
+ graph_def = tf_testing.ProcessGraphDefParam(graph_def)
+
+ mod, _ = relay.frontend.from_tensorflow(graph_def,
shape={"input": (1, 3, 3, 1)})  # keyed by placeholder node name (no ":0")
+ verify_span(mod)
+
+
+def test_span_complement_big_model():
+ with tf.Graph().as_default():
+ graph_def =
tf_testing.get_workload("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb")
+ # Call the utility to import the graph definition into default graph.
+ graph_def = tf_testing.ProcessGraphDefParam(graph_def)
+
+ mod, _ = relay.frontend.from_tensorflow(
+ graph_def, shape={"input_tensor": (128, 224, 224, 3)}  # keyed by placeholder node name
+ )
+ verify_span(mod)
+
+
#######################################################################
# Pooling
# -------
diff --git a/tests/python/frontend/tensorflow2/test_sequential_models.py
b/tests/python/frontend/tensorflow2/test_sequential_models.py
index 1b5a634..b76b4a7 100644
--- a/tests/python/frontend/tensorflow2/test_sequential_models.py
+++ b/tests/python/frontend/tensorflow2/test_sequential_models.py
@@ -26,6 +26,25 @@ from tensorflow.python.framework.convert_to_constants import
convert_variables_t
from common import compare_tf_tvm
from common import run_tf_code
+from tvm.relay.frontend.tensorflow2 import from_tensorflow
+
+
+def verify_span(mod):
+ fail_cases = []
+ mod_main_start = False
+ for line in str(mod).split("\n"):
+ if "@main" in line:
+ mod_main_start = True
+ continue
+
+ if mod_main_start:
+ if line == "}":
+ break
+ elif not ("/*" in line and "*/" in line):
+ fail_cases.append(line)
+
+ # fail_cases holds the lines of @main's body that lack a /* span */ annotation
+ assert not fail_cases, "lines without span: {}".format(fail_cases)
def run_sequential_model(model_fn, input_shape):
@@ -48,7 +67,10 @@ def run_sequential_model(model_fn, input_shape):
gdef = f.graph.as_graph_def(add_shapes=True)
return gdef, _input, _output
- compare_tf_tvm(*model_graph(model_fn, input_shape), runtime="vm")
+ gdef, _input, _output = model_graph(model_fn, input_shape)
+ mod, _ = from_tensorflow(gdef)
+ compare_tf_tvm(gdef, _input, _output, runtime="vm")
+ verify_span(mod)
def test_dense_model():
diff --git a/tests/python/frontend/tflite/test_forward.py
b/tests/python/frontend/tflite/test_forward.py
index 545315a..d234cd1 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -259,6 +259,59 @@ def run_tflite_graph(tflite_model_buf, input_data):
return tflite_output
+def run_span_verification(
+ tflite_model_buf,
+ input_data,
+ input_node,
+ num_output=1,
+ target="llvm",
+ out_names=None,
+ mode="graph_executor",
+):
+ """Import a TFLite model into Relay and verify its spans (num_output/target/out_names/mode are unused)"""
+ # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
+ try:
+ import tflite.Model
+
+ tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
+ except AttributeError:
+ import tflite
+
+ tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
+ except ImportError as err:
+ raise ImportError("The tflite package must be installed") from err
+
+ input_data = convert_to_list(input_data)
+ input_node = convert_to_list(input_node)
+
+ shape_dict = {}
+ dtype_dict = {}
+ for i, e in enumerate(input_node):
+ shape_dict[e] = input_data[i].shape
+ dtype_dict[e] = input_data[i].dtype.name
+
+ mod, _ = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict,
dtype_dict=dtype_dict)
+ verify_span(mod)
+
+
+def verify_span(mod):
+ fail_cases = []
+ mod_main_start = False
+ for line in str(mod).split("\n"):
+ if "@main" in line:
+ mod_main_start = True
+ continue
+
+ if mod_main_start:
+ if line == "}":
+ break
+ elif not ("/*" in line and "*/" in line):
+ fail_cases.append(line)
+
+ # fail_cases holds the lines of @main's body that lack a /* span */ annotation
+ assert not fail_cases, "lines without span: {}".format(fail_cases)
+
+
def compare_tflite_with_tvm(
in_data,
in_name,
@@ -4507,6 +4560,7 @@ def test_forward_tflite2_qnn_resnet50():
tflite_output = run_tflite_graph(tflite_model_buf, data)
tflite_predictions = np.squeeze(tflite_output)
tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
+ run_span_verification(tflite_model_buf, np.array(data), "input_1")
tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), "input_1")
tvm_predictions = np.squeeze(tvm_output)
tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]