euntaik commented on a change in pull request #7400:
URL: https://github.com/apache/tvm/pull/7400#discussion_r573533135
##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -3588,8 +3643,14 @@ def from_tflite(model, shape_dict, dtype_dict):
     exp_tab = ExprTable()
     for model_input in model_inputs:
         model_input_name = get_tensor_name(subgraph, model_input)
-        shape = shape_dict[model_input_name] if model_input_name in shape_dict else None
-        dtype = dtype_dict[model_input_name] if model_input_name in dtype_dict else "float32"
+        if shape_dict:
+            shape = shape_dict[model_input_name] if model_input_name in shape_dict else None
+        else:
+            shape = get_tensor_shape(subgraph, model_input)
+        if dtype_dict:
+            dtype = dtype_dict[model_input_name] if model_input_name in dtype_dict else "float32"
+        else:
+            dtype = get_tensor_type(subgraph, model_input)
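A minimal, framework-free sketch of the per-input selection in the `+` lines above. The helper name `select_input_info` is hypothetical (not part of the TVM API), and `graph_shape` / `graph_dtype` stand in for whatever `get_tensor_shape` and `get_tensor_type` read from the subgraph.

```python
from typing import Dict, Optional, Tuple

Shape = Tuple[int, ...]


def select_input_info(
    name: str,
    graph_shape: Shape,
    graph_dtype: str,
    shape_dict: Optional[Dict[str, Shape]] = None,
    dtype_dict: Optional[Dict[str, str]] = None,
) -> Tuple[Optional[Shape], str]:
    """Hypothetical helper mirroring the patched logic.

    graph_shape / graph_dtype stand in for get_tensor_shape() and
    get_tensor_type(); a non-empty user dict takes over entirely.
    """
    if shape_dict:
        # User supplied shapes: inputs missing from the dict get None.
        shape = shape_dict[name] if name in shape_dict else None
    else:
        # None or an empty dict: use the shape stored in the graph.
        shape = graph_shape
    if dtype_dict:
        # User supplied dtypes: inputs missing from the dict default to float32.
        dtype = dtype_dict[name] if name in dtype_dict else "float32"
    else:
        # None or an empty dict: use the dtype stored in the graph.
        dtype = graph_dtype
    return shape, dtype
```

One consequence of this structure: if a caller passes a non-empty `shape_dict` that omits some input, that input gets `None` rather than its graph shape; a per-input `shape_dict.get(name, graph_shape)` would fall back to the graph instead.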
Review comment:
> We have a similar function in TVMC that collects the same information being proposed here. I agree we should move what is in there over here, to unify functionality.
Oh, it was there all along. I think I missed your code because I was loading my models in a separate script to feed the Relay output into my compile passes.
> Can you have a look at the function I'm pointing to here (below) and spot why they are so different,
I don't see much difference, except that your code accounts for models with more than one subgraph.
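For comparison, a sketch of collecting input shapes and dtypes from the inputs of every subgraph rather than a single one, which is the difference noted above. `Tensor`, `Subgraph`, and `collect_input_info` are hypothetical stand-ins (real code would read these values through the parsed TFLite model's accessors); this is not the TVMC implementation.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


@dataclass
class Tensor:
    """Hypothetical stand-in for a TFLite tensor."""
    name: str
    shape: Shape
    dtype: str


@dataclass
class Subgraph:
    """Hypothetical stand-in for a TFLite subgraph."""
    tensors: List[Tensor]
    inputs: List[int] = field(default_factory=list)  # indices into tensors


def collect_input_info(subgraphs: List[Subgraph]) -> Tuple[Dict[str, Shape], Dict[str, str]]:
    """Build shape_dict and dtype_dict from the inputs of every subgraph."""
    if not subgraphs:
        raise ValueError("model has no subgraphs")
    shape_dict: Dict[str, Shape] = {}
    dtype_dict: Dict[str, str] = {}
    for subgraph in subgraphs:
        for tensor_index in subgraph.inputs:
            tensor = subgraph.tensors[tensor_index]
            # A later subgraph overwrites an earlier entry with the same name.
            shape_dict[tensor.name] = tuple(tensor.shape)
            dtype_dict[tensor.name] = tensor.dtype
    return shape_dict, dtype_dict
```

Restricting the loop to `subgraphs[0]` gives the single-subgraph behaviour of the patch.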
> and, in case you agree on what's the best approach, improve it here and remove it there?
My rationale for writing this code and putting it in tflite.py was:
1. Use the data in the graph, since it is already embedded there.
2. Place the code inside the frontend, since it depends on the frontend.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.