manupak commented on code in PR #12447:
URL: https://github.com/apache/tvm/pull/12447#discussion_r950678863
##########
python/tvm/relay/backend/contrib/uma/api/lower.py:
##########
@@ -60,27 +60,7 @@ def _lower_relay_to_tir(self, relay_prim_func:
relay.Function) -> tvm.tir.PrimFu
"""
def _get_tensors(te_cached_func):
- outputs = list(te_cached_func.outputs)
- stack = []
- visited = set()
- for output_ in outputs:
- if output_ not in visited:
- visited.add(output_)
- stack.append(output_)
-
- args = []
- while len(stack) != 0:
- tensor = stack.pop()
- if isinstance(tensor.op, tvm.te.tensor.PlaceholderOp):
- args.append(tensor)
- elif isinstance(tensor.op, tvm.te.tensor.ComputeOp):
- inputs = tensor.op.input_tensors
- for input_ in inputs:
- if input_ not in visited:
- visited.add(input_)
- stack.append(input_)
-
- return args + outputs
+ return list(te_cached_func.inputs) + list(te_cached_func.outputs)
Review Comment:
This makes sense to me, because the input visitation done here to discover the
inputs might not produce the ordering that the call from relay main expects.
##########
python/tvm/testing/aot.py:
##########
@@ -931,20 +930,23 @@ def generate_ref_data(mod, input_data, params=None,
target="llvm"):
return dict(zip(output_tensor_names, out))
-def create_relay_module_and_inputs_from_tflite_file(tflite_model_file):
+def create_relay_module_and_inputs_from_tflite_file(tflite_model_file,
bind_params_by_name=True):
"""A helper function to create a Relay IRModule with inputs
and params from a tflite file"""
with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()
- mod, params = convert_to_relay(tflite_model_buf)
+ mod, params = convert_to_relay(tflite_model_buf, bind_params_by_name)
inputs = dict()
for param in mod["main"].params:
name = str(param.name_hint)
data_shape = [int(i) for i in param.type_annotation.shape]
dtype = str(param.type_annotation.dtype)
- in_min, in_max = (np.iinfo(dtype).min, np.iinfo(dtype).max)
- data = np.random.randint(in_min, high=in_max, size=data_shape,
dtype=dtype)
+ if dtype == "float32":
Review Comment:
I think we can use np.finfo to make this work "generally" for float dtypes
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]