This is an automated email from the ASF dual-hosted git repository.
kevinthesun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 93f3010 Move infer_value to _get_list_param (#8051)
93f3010 is described below
commit 93f301059308d8d9fc85e892cd018a08c1a5d2d5
Author: Trevor Morris <[email protected]>
AuthorDate: Tue May 18 14:25:25 2021 -0700
Move infer_value to _get_list_param (#8051)
---
python/tvm/relay/frontend/tensorflow.py | 83 ++++++++++++---------------------
1 file changed, 31 insertions(+), 52 deletions(-)
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 4bd332f..5b0507b 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -113,8 +113,11 @@ def _get_num_param(params, input_node):
return _get_param(params, input_node).item()
-def _get_list_param(params, input_node):
- return _get_param(params, input_node).tolist()
+def _get_list_param(params, input_node, mod):
+ try:
+ return _get_param(params, input_node).tolist()
+ except (IndexError, KeyError, AttributeError):
+ return _infer_value(input_node, params, mod).asnumpy().tolist()
def _get_tuple_param(params, input_node):
@@ -913,10 +916,7 @@ def _crop_and_resize():
def _impl(inputs, attr, params, mod):
# input image is a 4-D tensor of shape [batch, image_height,
image_width, depth]
# boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2,
x2]
- try:
- crop_size = _get_list_param(params, inputs[3])
- except (IndexError, KeyError):
- crop_size = _infer_value(inputs[3], params, mod).asnumpy().tolist()
+ crop_size = _get_list_param(params, inputs[3], mod)
method = attr["method"].decode()
method = "nearest_neighbor" if method == "nearest" else method
@@ -1658,7 +1658,7 @@ def _tile():
np_reps = _infer_value(reps_input, params, mod).asnumpy()
reps = [np_reps.flatten()[i] for i in
range(np_reps.flatten().shape[0])]
else:
- reps = _get_list_param(params, reps_input)
+ reps = _get_list_param(params, reps_input, mod)
new_input = [inputs.pop(0)]
return AttrCvt(op_name="tile", extras={"reps": tuple(reps)},
ignores=["Tmultiples"])(
@@ -1671,21 +1671,15 @@ def _tile():
def _slice():
def _impl(inputs, attr, params, mod):
try:
- begin = _get_list_param(params, inputs[1])
- except (IndexError, KeyError, AttributeError):
+ begin = _get_list_param(params, inputs[1], mod)
+ except Exception:
# Handle symbolic begin
- try:
- begin = _infer_value(inputs[1], params, mod).asnumpy().tolist()
- except Exception:
- begin = inputs[1]
+ begin = inputs[1]
try:
- size = _get_list_param(params, inputs[2])
- except (IndexError, KeyError, AttributeError):
+ size = _get_list_param(params, inputs[2], mod)
+ except Exception:
# Handle symbolic size
- try:
- size = _infer_value(inputs[2], params, mod).asnumpy().tolist()
- except Exception:
- size = inputs[2]
+ size = inputs[2]
# Align begin and strides for dynamic shape.
data_dim = len(_infer_shape(inputs[0], mod))
@@ -1962,7 +1956,7 @@ def _sum():
def _reduce(op):
def _impl(inputs, attr, params, mod):
- axis = _get_list_param(params, inputs[1])
+ axis = _get_list_param(params, inputs[1], mod)
axis = tuple(axis)
if not axis:
axis = None
@@ -1978,7 +1972,7 @@ def _reduce(op):
def _euclidean_norm():
def _impl(inputs, attr, params, mod):
- axis = tuple(_get_list_param(params, inputs[1]))
+ axis = tuple(_get_list_param(params, inputs[1], mod))
keep_dims = bool(attr.get("keep_dims", False))
return _op.sqrt(
_op.cast(_op.reduce.sum(_op.multiply(inputs[0], inputs[0]), axis,
keep_dims), "float32")
@@ -2039,9 +2033,9 @@ def _stridedSlice():
Tensorflow mask validation:
https://github.com/tensorflow/tensorflow/blob/master/
tensorflow/core/util/strided_slice_op.cc#L147-L368
"""
- begin = _get_list_param(params, inputs[1])
- end = _get_list_param(params, inputs[2])
- stride = _get_list_param(params, inputs[3])
+ begin = _get_list_param(params, inputs[1], mod)
+ end = _get_list_param(params, inputs[2], mod)
+ stride = _get_list_param(params, inputs[3], mod)
begin_mask = int(attr.get("begin_mask", 0))
end_mask = int(attr.get("end_mask", 0))
@@ -2243,10 +2237,7 @@ def _transpose():
def _impl(inputs, attr, params, mod):
# If perm is not specified, axes is left empty,
# otherwise its value is get from params
- try:
- axes = _get_list_param(params, inputs[1])
- except (IndexError, KeyError, AttributeError):
- axes = _infer_value(inputs[1], params, mod).asnumpy().tolist()
+ axes = _get_list_param(params, inputs[1], mod)
return _op.transpose(inputs[0], axes=axes)
return _impl
@@ -2536,19 +2527,13 @@ def _logical(name):
def _space_to_batch_nd():
def _impl(inputs, attr, params, mod):
- try:
- block_shape = _get_list_param(params, inputs[1])
- except (IndexError, KeyError, AttributeError):
- block_shape = _infer_value(inputs[1], params,
mod).asnumpy().tolist()
+ block_shape = _get_list_param(params, inputs[1], mod)
- try:
- paddings = _get_list_param(params, inputs[2])
- except (IndexError, KeyError, AttributeError):
- paddings = _infer_value(inputs[2], params, mod).asnumpy()
- paddings = np.squeeze(paddings)
- if len(paddings.shape) == 1:
- paddings = np.expand_dims(paddings, axis=0)
- paddings = paddings.tolist()
+ paddings = _get_list_param(params, inputs[2], mod)
+ paddings = np.squeeze(paddings)
+ if len(paddings.shape) == 1:
+ paddings = np.expand_dims(paddings, axis=0)
+ paddings = paddings.tolist()
attr["block_shape"] = block_shape
attr["paddings"] = paddings
@@ -2561,19 +2546,13 @@ def _space_to_batch_nd():
def _batch_to_space_nd():
def _impl(inputs, attr, params, mod):
- try:
- block_shape = _get_list_param(params, inputs[1])
- except (IndexError, KeyError, AttributeError):
- block_shape = _infer_value(inputs[1], params,
mod).asnumpy().tolist()
+ block_shape = _get_list_param(params, inputs[1], mod)
- try:
- crops = _get_list_param(params, inputs[2])
- except (IndexError, KeyError, AttributeError):
- crops = _infer_value(inputs[2], params, mod).asnumpy()
- crops = np.squeeze(crops)
- if len(crops.shape) == 1:
- crops = np.expand_dims(crops, axis=0)
- crops = crops.tolist()
+ crops = _get_list_param(params, inputs[2], mod)
+ crops = np.squeeze(crops)
+ if len(crops.shape) == 1:
+ crops = np.expand_dims(crops, axis=0)
+ crops = crops.tolist()
attr["block_shape"] = block_shape
attr["crops"] = crops