masahi commented on a change in pull request #6546:
URL: https://github.com/apache/incubator-tvm/pull/6546#discussion_r494117103
##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -3244,7 +3267,8 @@ def get_all_op_names(graph):
return set(node.kind() for node in nodes)
-def from_pytorch(script_module, input_shapes, custom_convert_map=None,
default_dtype="float32"):
+def from_pytorch(script_module, input_shapes, custom_convert_map=None,
default_dtype="float32",
+ input_types=None):
Review comment:
Yes, I also prefer including the type in the tuple, as (name, (shape, type)),
and renaming input_shapes to something like input_info.
##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -3244,7 +3267,8 @@ def get_all_op_names(graph):
return set(node.kind() for node in nodes)
-def from_pytorch(script_module, input_shapes, custom_convert_map=None,
default_dtype="float32"):
+def from_pytorch(script_module, input_shapes, custom_convert_map=None,
default_dtype="float32",
+ input_types=None):
Review comment:
Yes, I also prefer including the type in the tuple, as (name, (shape, type)),
and renaming input_shapes to something like input_info. Also, please fall back
to the default type when no type is provided.
##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -3244,7 +3267,8 @@ def get_all_op_names(graph):
return set(node.kind() for node in nodes)
-def from_pytorch(script_module, input_shapes, custom_convert_map=None,
default_dtype="float32"):
+def from_pytorch(script_module, input_shapes, custom_convert_map=None,
default_dtype="float32",
+ input_types=None):
Review comment:
I think it is fine to change the API slightly, as long as it does not break
existing code.
##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -2876,45 +2877,49 @@ def get_relay_ty(ishape, pt_type):
msg = "Shapes of input list and information in the graph do
not match"
raise RuntimeError(msg)
pt_dtype = pt_type.scalarType()
+ if not pt_dtype and itype:
+ pt_dtype = itype
dtype = _convert_data_type(pt_dtype, default_dtype=default_dtype)
return TensorType(ishape, dtype)
elif pt_type.kind() == "TupleType":
if not isinstance(ishape, tuple):
msg = "Shapes for tuples must be tuples"
raise RuntimeError(msg)
return TupleType(
- [get_relay_ty(elem, pt_t) for elem, pt_t in zip(ishape,
pt_type.elements())]
+ [get_relay_ty(elem, itype, pt_t)
+ for elem, pt_t in zip(ishape, pt_type.elements())]
)
elif pt_type.kind() == "ListType":
if not isinstance(ishape, list):
msg = "Shapes for lists must be lists"
raise RuntimeError(msg)
pt_elemtype = pt_type.getElementType()
- elem_tys = [get_relay_ty(s, pt_elemtype) for s in ishape]
+ elem_tys = [get_relay_ty(s, itype, pt_elemtype) for s in ishape]
if len(elem_tys) > 0 and not all(map(lambda ty: ty == elem_tys[0],
elem_tys)):
msg = "List elements need have identical types"
raise RuntimeError(msg)
return prelude.l(elem_tys[0])
elif pt_type.kind() == "OptionalType":
# we do not support None yet, so we fill in the type
- return get_relay_ty(ishape, pt_type.getElementType())
+ return get_relay_ty(ishape, itype, pt_type.getElementType())
# TODO: scalar inputs
raise NotImplementedError("unsupported input type")
input_vars = {}
- for num, inp in enumerate(input_shapes):
+ for num, inp in enumerate(input_infos):
if not isinstance(inp, tuple):
msg = "Graph input {} is not a tuple".format(num)
raise RuntimeError(msg)
if len(inp) != 2 or not isinstance(inp[0], str):
- msg = "Graph input {} is not valid, expected ('name',
shape)".format(inp)
+ msg = "Graph input {} is not valid,"
+ " expected ('name', shape) or ('name', (shape,
dtype))".format(inp)
raise RuntimeError(msg)
+ if not isinstance(inp[1], tuple):
+ inp[1] = (inp[1], None)
Review comment:
None -> `default_dtype`
##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -3254,10 +3259,14 @@ def from_pytorch(script_module, input_shapes,
custom_convert_map=None, default_d
TorchScripted PyTorch graph
Note: We currently only support traces (ie: torch.jit.trace(model,
input))
- input_shapes : List of tuples of input name and input dimensions
- Graph level input shape list
+ input_infos: List of tuples of input name and (input dimensions, input
types)
+ Graph level input shape and type list
Review comment:
Can you also document that the `(name, shape)` form is allowed, since that is
the more common use case?
##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -3254,10 +3259,14 @@ def from_pytorch(script_module, input_shapes,
custom_convert_map=None, default_d
TorchScripted PyTorch graph
Note: We currently only support traces (ie: torch.jit.trace(model,
input))
- input_shapes : List of tuples of input name and input dimensions
- Graph level input shape list
+ input_infos: List of tuples of input name and (input dimensions, input
types)
+ Graph level input shape and type list
The same input names need to be used for deployment, so choose easy to
remember names (such as: input0, input1)
+ e.g. [('input0', ((1, 2), 'int')), ('input1', ((3, 4), 'float'))]
Review comment:
Please also add an example for the `(name, shape)` case.
##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -3254,10 +3259,14 @@ def from_pytorch(script_module, input_shapes,
custom_convert_map=None, default_d
TorchScripted PyTorch graph
Note: We currently only support traces (ie: torch.jit.trace(model,
input))
- input_shapes : List of tuples of input name and input dimensions
- Graph level input shape list
+ input_infos: List of tuples of input name and (input dimensions, input
types)
+ Graph level input shape and type list
The same input names need to be used for deployment, so choose easy to
remember names (such as: input0, input1)
+ e.g. [('input0', ((1, 2), 'int')), ('input1', ((3, 4), 'float'))]
+ supported data types: ['half', 'float', 'double', 'bool', 'char',
+ 'byte', 'short', 'int', 'long',
+ 'quint8', 'qint8', 'qint32', 'str']
Review comment:
I don't think we support that many types.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]