This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 811584992c Infer the value of shape expr to avoid dynamic (#12313)
811584992c is described below

commit 811584992c9f86b8a46596baf6bde247446b8c8a
Author: Black <[email protected]>
AuthorDate: Wed Aug 10 08:09:17 2022 +0800

    Infer the value of shape expr to avoid dynamic (#12313)
---
 python/tvm/relay/frontend/tflite.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 239d72055b..c38191b389 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -33,7 +33,7 @@ from .. import qnn as _qnn
 from ..backend.name_transforms import sanitize_name
 from .common import ExprTable
 from .common import infer_shape as _infer_shape
-from .common import lstm_cell, to_int_list, shape_of
+from .common import lstm_cell, to_int_list, shape_of, try_infer_value
 from .tflite_flexbuffer import FlexBufferDecoder
 
 __all__ = ["from_tflite"]
@@ -599,7 +599,21 @@ class OperatorConverter(object):
         if len(input_tensors) == 2:
             shape_tensor = input_tensors[1]
             if self.has_expr(shape_tensor.tensor_idx):
-                target_shape = self.get_expr(shape_tensor.tensor_idx)
+                target_expr = self.get_expr(shape_tensor.tensor_idx)
+                target_value, success = try_infer_value(
+                    target_expr,
+                    parameters={k: _nd.array(np.array(v)) for k, v in self.exp_tab.params.items()},
+                )
+                if success:
+                    # convert to flattened list
+                    from itertools import chain
+
+                    try:
+                        target_shape = list(chain(*target_value))
+                    except TypeError:
+                        target_shape = list(chain(target_value))
+                else:
+                    target_shape = target_expr
             else:
                 target_shape = self.get_tensor_value(shape_tensor)
                 # convert to flattened list

Reply via email to