rohanmukh commented on a change in pull request #8142: URL: https://github.com/apache/tvm/pull/8142#discussion_r644271159
# Quoted diff from apache/tvm PR #8142, file: python/tvm/relay/frontend/tensorflow2.py
# (reconstructed from the collapsed review-comment text; diff "+" markers removed).
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except
"""Tensorflow2.x graph to relay converter.

If model is constructed using tf2.x API, then use this converter:
    from tvm.relay.frontend.tensorflow2 import from_tensorflow
Otherwise use the tf1.x converter:
    from tvm.relay.frontend.tensorflow import from_tensorflow

"""

import numpy as np
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import dtypes


import tvm
from tvm.relay.transform import InferType
from tvm.relay.prelude import Prelude
from tvm.ir import IRModule
from .. import expr as _expr
from .. import analysis
from .. import function as _function
from ..loops import while_loop as _while_loop
from .common import infer_type as _infer_type

from .tensorflow import _convert_map as _convert_map_tf1
from .tensorflow import _need_prelude_for_shape_inference

from ..ty import Any

__all__ = ["from_tensorflow"]


def _infer_type_with_prelude(val, prelude):
    # Run type inference against the prelude-augmented module and return
    # the checked type of the expression.
    body = _infer_type(val, prelude.mod)
    return body.checked_type


def set_span(sym, node_name):
    """set span of symbol"""

    span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0)
    if isinstance(sym, _expr.Call):
        sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span)
    elif isinstance(sym, _expr.TupleWrapper):
        tuple_value = sym.tuple_value
        if isinstance(tuple_value, _expr.Call):
            tuple_value = _expr.Call(
                tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span
            )
            sym = _expr.TupleWrapper(tuple_value, sym.size)
    return sym


def convert_const_node(node, shape):
    """convert tf const node into relay const or var"""

    # get the value of the constant
    tensor_value = node.attr["value"].tensor
    np_array = tensor_util.MakeNdarray(tensor_value)

    if np_array.dtype == np.dtype(object):
        if shape and node.name in shape:
            var_shape = shape[node.name]
        else:
            var_shape = tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
        param = None
        sym = [_expr.var(node.name, shape=var_shape, dtype="uint8")]
        return sym, param

    if len(np_array.shape) == 0:
        param = None
        sym = [tvm.relay.const(np_array, np_array.dtype)]
    else:
        param = tvm.nd.array(np_array)
        sym = [_expr.var(node.name, shape=param.shape, dtype=param.dtype)]

    return sym, param


def get_attr(buf):
    """convert value of a node attribute. node attribute is part of a node in a graph.

    // tensorflow/core/framework/attr_value.proto
    message AttrValue {
        oneof value {
            bytes s = 2; // "string"
            int64 i = 3; // "int"
            float f = 4; // "float"
            bool b = 5; // "bool"
            DataType type = 6; // "type"
            TensorShapeProto shape = 7; // "shape"
            TensorProto tensor = 8; // "tensor"
            ListValue list = 1; // any "list(...)"
        }
    }

    Parameters
    ----------
    buf: attrvalue protobuf.  <class 'tensorflow.core.framework.attr_value_pb2.AttrValue'>

    Returns
    -------
    The value of the attr, as a Python object.
    """
    fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"]

    x = buf

    ret = []

    if not x.WhichOneof("value"):
        return ret

    if x.HasField("list"):
        for f in fields:
            if getattr(x.list, f):
                if f == "type":
                    ret += [dtypes.as_dtype(x) for x in list(getattr(x.list, f))]
                else:
                    ret += list(getattr(x.list, f))
    else:
        for f in fields:
            if x.HasField(f):
                if f == "type":
                    ret = dtypes.as_dtype(getattr(x, f))
                else:
                    ret = getattr(x, f)
    return ret


def parse_attr(attr_proto):
    """Convert node attributes (a serialized map of key-value pairs) in a node to a dict

    Parameters
    ----------
    attr_proto: <class 'google.protobuf.pyext._message.MessageMapContainer'>
        attributes of a tf node
        protobuf message format:
        // tensorflow/core/framework/node_def.proto
        message NodeDef {
            map<string, AttrValue> attr = 5;
        }

    Returns
    -------
    Dict {string: python object}

    Examples
    --------
    attributes in following node converted to {'_user_specified_name': b'x', 'dtype': tf.float32 }
    node {
        name: "x"
        op: "Placeholder"
        attr {
            key: "_user_specified_name"
            value {
                s: "x"
            }
        }
        attr {
            key: "dtype"
            value {
                type: DT_FLOAT
            }
        }
    }
    """
    attrs = {}
    for key, value in attr_proto.items():
        attrs[key] = get_attr(value)

    return attrs


def convert_place_holder(shape, node, in_type=None):
    """convert tf place holder into relay var.

    Examples
    --------
    a tf place holder with name "x" is converted to [Var(x, ty=TensorType([], float32))]
    """

    if shape and node.name in shape:
        input_shape = list(shape[node.name])
    else:
        input_shape = tensor_util.TensorShapeProtoToList(node.attr["shape"].shape)
        # TF encodes unknown dimensions as negative values; map them to relay Any().
        for idx, dim in enumerate(input_shape):
            if dim < 0:
                input_shape[idx] = Any()
    attr = parse_attr(node.attr)
    if in_type is not None:
        sym = [_expr.var(node.name, type_annotation=in_type)]
    else:
        sym = [_expr.var(node.name, shape=input_shape, dtype=attr["dtype"].name)]
    return input_shape, sym


# Review comment (rohanmukh, https://github.com/apache/tvm/pull/8142#discussion_r644271159):
#     The utility functions on top are imported from `tensorflow.py`. Since they were
#     methods inside a class they had to be extracted. A future refactor PR can address
#     this to allow better code reuse. However all TF1 unit tests and models need to be
#     tested alongside for such a refactor.
#
# -- This is an automated message from the Apache Git Service. To respond to the
# message, please log on to GitHub and use the URL above to go to the specific
# comment. For queries about this service, please contact Infrastructure at:
# [email protected]
