masahi commented on a change in pull request #6546:
URL: https://github.com/apache/incubator-tvm/pull/6546#discussion_r494144957
##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -2876,45 +2877,49 @@ def get_relay_ty(ishape, pt_type):
msg = "Shapes of input list and information in the graph do
not match"
raise RuntimeError(msg)
pt_dtype = pt_type.scalarType()
+ if not pt_dtype and itype:
+ pt_dtype = itype
dtype = _convert_data_type(pt_dtype, default_dtype=default_dtype)
return TensorType(ishape, dtype)
elif pt_type.kind() == "TupleType":
if not isinstance(ishape, tuple):
msg = "Shapes for tuples must be tuples"
raise RuntimeError(msg)
return TupleType(
- [get_relay_ty(elem, pt_t) for elem, pt_t in zip(ishape,
pt_type.elements())]
+ [get_relay_ty(elem, itype, pt_t)
+ for elem, pt_t in zip(ishape, pt_type.elements())]
)
elif pt_type.kind() == "ListType":
if not isinstance(ishape, list):
msg = "Shapes for lists must be lists"
raise RuntimeError(msg)
pt_elemtype = pt_type.getElementType()
- elem_tys = [get_relay_ty(s, pt_elemtype) for s in ishape]
+ elem_tys = [get_relay_ty(s, itype, pt_elemtype) for s in ishape]
if len(elem_tys) > 0 and not all(map(lambda ty: ty == elem_tys[0],
elem_tys)):
msg = "List elements need have identical types"
raise RuntimeError(msg)
return prelude.l(elem_tys[0])
elif pt_type.kind() == "OptionalType":
# we do not support None yet, so we fill in the type
- return get_relay_ty(ishape, pt_type.getElementType())
+ return get_relay_ty(ishape, itype, pt_type.getElementType())
# TODO: scalar inputs
raise NotImplementedError("unsupported input type")
input_vars = {}
- for num, inp in enumerate(input_shapes):
+ for num, inp in enumerate(input_infos):
if not isinstance(inp, tuple):
msg = "Graph input {} is not a tuple".format(num)
raise RuntimeError(msg)
if len(inp) != 2 or not isinstance(inp[0], str):
- msg = "Graph input {} is not valid, expected ('name',
shape)".format(inp)
+ msg = "Graph input {} is not valid,"
+ " expected ('name', shape) or ('name', (shape,
dtype))".format(inp)
raise RuntimeError(msg)
+ if not isinstance(inp[1], tuple):
+ inp[1] = (inp[1], None)
Review comment:
Use `default_dtype` instead of `None` as the fallback dtype here.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]