This is an automated email from the ASF dual-hosted git repository.

kevinthesun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 93f3010  Move infer_value to _get_list_param (#8051)
93f3010 is described below

commit 93f301059308d8d9fc85e892cd018a08c1a5d2d5
Author: Trevor Morris <[email protected]>
AuthorDate: Tue May 18 14:25:25 2021 -0700

    Move infer_value to _get_list_param (#8051)
---
 python/tvm/relay/frontend/tensorflow.py | 83 ++++++++++++---------------------
 1 file changed, 31 insertions(+), 52 deletions(-)

diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 4bd332f..5b0507b 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -113,8 +113,11 @@ def _get_num_param(params, input_node):
     return _get_param(params, input_node).item()
 
 
-def _get_list_param(params, input_node):
-    return _get_param(params, input_node).tolist()
+def _get_list_param(params, input_node, mod):
+    try:
+        return _get_param(params, input_node).tolist()
+    except (IndexError, KeyError, AttributeError):
+        return _infer_value(input_node, params, mod).asnumpy().tolist()
 
 
 def _get_tuple_param(params, input_node):
@@ -913,10 +916,7 @@ def _crop_and_resize():
     def _impl(inputs, attr, params, mod):
         # input image is a 4-D tensor of shape [batch, image_height, image_width, depth]
         # boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2]
-        try:
-            crop_size = _get_list_param(params, inputs[3])
-        except (IndexError, KeyError):
-            crop_size = _infer_value(inputs[3], params, mod).asnumpy().tolist()
+        crop_size = _get_list_param(params, inputs[3], mod)
 
         method = attr["method"].decode()
         method = "nearest_neighbor" if method == "nearest" else method
@@ -1658,7 +1658,7 @@ def _tile():
             np_reps = _infer_value(reps_input, params, mod).asnumpy()
             reps = [np_reps.flatten()[i] for i in 
range(np_reps.flatten().shape[0])]
         else:
-            reps = _get_list_param(params, reps_input)
+            reps = _get_list_param(params, reps_input, mod)
         new_input = [inputs.pop(0)]
 
         return AttrCvt(op_name="tile", extras={"reps": tuple(reps)}, 
ignores=["Tmultiples"])(
@@ -1671,21 +1671,15 @@ def _tile():
 def _slice():
     def _impl(inputs, attr, params, mod):
         try:
-            begin = _get_list_param(params, inputs[1])
-        except (IndexError, KeyError, AttributeError):
+            begin = _get_list_param(params, inputs[1], mod)
+        except Exception:
             # Handle symbolic begin
-            try:
-                begin = _infer_value(inputs[1], params, mod).asnumpy().tolist()
-            except Exception:
-                begin = inputs[1]
+            begin = inputs[1]
         try:
-            size = _get_list_param(params, inputs[2])
-        except (IndexError, KeyError, AttributeError):
+            size = _get_list_param(params, inputs[2], mod)
+        except Exception:
             # Handle symbolic size
-            try:
-                size = _infer_value(inputs[2], params, mod).asnumpy().tolist()
-            except Exception:
-                size = inputs[2]
+            size = inputs[2]
 
         # Align begin and strides for dynamic shape.
         data_dim = len(_infer_shape(inputs[0], mod))
@@ -1962,7 +1956,7 @@ def _sum():
 
 def _reduce(op):
     def _impl(inputs, attr, params, mod):
-        axis = _get_list_param(params, inputs[1])
+        axis = _get_list_param(params, inputs[1], mod)
         axis = tuple(axis)
         if not axis:
             axis = None
@@ -1978,7 +1972,7 @@ def _reduce(op):
 
 def _euclidean_norm():
     def _impl(inputs, attr, params, mod):
-        axis = tuple(_get_list_param(params, inputs[1]))
+        axis = tuple(_get_list_param(params, inputs[1], mod))
         keep_dims = bool(attr.get("keep_dims", False))
         return _op.sqrt(
             _op.cast(_op.reduce.sum(_op.multiply(inputs[0], inputs[0]), axis, 
keep_dims), "float32")
@@ -2039,9 +2033,9 @@ def _stridedSlice():
         Tensorflow mask validation: 
https://github.com/tensorflow/tensorflow/blob/master/
         tensorflow/core/util/strided_slice_op.cc#L147-L368
         """
-        begin = _get_list_param(params, inputs[1])
-        end = _get_list_param(params, inputs[2])
-        stride = _get_list_param(params, inputs[3])
+        begin = _get_list_param(params, inputs[1], mod)
+        end = _get_list_param(params, inputs[2], mod)
+        stride = _get_list_param(params, inputs[3], mod)
 
         begin_mask = int(attr.get("begin_mask", 0))
         end_mask = int(attr.get("end_mask", 0))
@@ -2243,10 +2237,7 @@ def _transpose():
     def _impl(inputs, attr, params, mod):
         # If perm is not specified, axes is left empty,
         # otherwise its value is get from params
-        try:
-            axes = _get_list_param(params, inputs[1])
-        except (IndexError, KeyError, AttributeError):
-            axes = _infer_value(inputs[1], params, mod).asnumpy().tolist()
+        axes = _get_list_param(params, inputs[1], mod)
         return _op.transpose(inputs[0], axes=axes)
 
     return _impl
@@ -2536,19 +2527,13 @@ def _logical(name):
 
 def _space_to_batch_nd():
     def _impl(inputs, attr, params, mod):
-        try:
-            block_shape = _get_list_param(params, inputs[1])
-        except (IndexError, KeyError, AttributeError):
-            block_shape = _infer_value(inputs[1], params, 
mod).asnumpy().tolist()
+        block_shape = _get_list_param(params, inputs[1], mod)
 
-        try:
-            paddings = _get_list_param(params, inputs[2])
-        except (IndexError, KeyError, AttributeError):
-            paddings = _infer_value(inputs[2], params, mod).asnumpy()
-            paddings = np.squeeze(paddings)
-            if len(paddings.shape) == 1:
-                paddings = np.expand_dims(paddings, axis=0)
-            paddings = paddings.tolist()
+        paddings = _get_list_param(params, inputs[2], mod)
+        paddings = np.squeeze(paddings)
+        if len(paddings.shape) == 1:
+            paddings = np.expand_dims(paddings, axis=0)
+        paddings = paddings.tolist()
 
         attr["block_shape"] = block_shape
         attr["paddings"] = paddings
@@ -2561,19 +2546,13 @@ def _space_to_batch_nd():
 
 def _batch_to_space_nd():
     def _impl(inputs, attr, params, mod):
-        try:
-            block_shape = _get_list_param(params, inputs[1])
-        except (IndexError, KeyError, AttributeError):
-            block_shape = _infer_value(inputs[1], params, 
mod).asnumpy().tolist()
+        block_shape = _get_list_param(params, inputs[1], mod)
 
-        try:
-            crops = _get_list_param(params, inputs[2])
-        except (IndexError, KeyError, AttributeError):
-            crops = _infer_value(inputs[2], params, mod).asnumpy()
-            crops = np.squeeze(crops)
-            if len(crops.shape) == 1:
-                crops = np.expand_dims(crops, axis=0)
-            crops = crops.tolist()
+        crops = _get_list_param(params, inputs[2], mod)
+        crops = np.squeeze(crops)
+        if len(crops.shape) == 1:
+            crops = np.expand_dims(crops, axis=0)
+        crops = crops.tolist()
 
         attr["block_shape"] = block_shape
         attr["crops"] = crops

Reply via email to