This is an automated email from the ASF dual-hosted git repository.
andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 811584992c Infer the value of shape expr to avoid dynamic (#12313)
811584992c is described below
commit 811584992c9f86b8a46596baf6bde247446b8c8a
Author: Black <[email protected]>
AuthorDate: Wed Aug 10 08:09:17 2022 +0800
Infer the value of shape expr to avoid dynamic (#12313)
---
python/tvm/relay/frontend/tflite.py | 18 ++++++++++++++++--
1 file changed, 16 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 239d72055b..c38191b389 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -33,7 +33,7 @@ from .. import qnn as _qnn
from ..backend.name_transforms import sanitize_name
from .common import ExprTable
from .common import infer_shape as _infer_shape
-from .common import lstm_cell, to_int_list, shape_of
+from .common import lstm_cell, to_int_list, shape_of, try_infer_value
from .tflite_flexbuffer import FlexBufferDecoder
__all__ = ["from_tflite"]
@@ -599,7 +599,21 @@ class OperatorConverter(object):
if len(input_tensors) == 2:
shape_tensor = input_tensors[1]
if self.has_expr(shape_tensor.tensor_idx):
- target_shape = self.get_expr(shape_tensor.tensor_idx)
+ target_expr = self.get_expr(shape_tensor.tensor_idx)
+ target_value, success = try_infer_value(
+ target_expr,
+ parameters={k: _nd.array(np.array(v)) for k, v in self.exp_tab.params.items()},
+ )
+ if success:
+ # convert to flattened list
+ from itertools import chain
+
+ try:
+ target_shape = list(chain(*target_value))
+ except TypeError:
+ target_shape = list(chain(target_value))
+ else:
+ target_shape = target_expr
else:
target_shape = self.get_tensor_value(shape_tensor)
# convert to flattened list