junrushao commented on code in PR #15670:
URL: https://github.com/apache/tvm/pull/15670#discussion_r1316271852


##########
python/tvm/relax/frontend/nn/core.py:
##########
@@ -39,6 +39,7 @@
 
 import numpy as np
 
+import tvm

Review Comment:
   Don't import `tvm` directly inside the `tvm` package, as this can cause unexpected behavior. Instead, import the specific item you need, for example, `from tvm.transform import PassContext`.



##########
python/tvm/relax/frontend/nn/core.py:
##########
@@ -411,6 +412,10 @@ def jit(  # pylint: disable=too-many-arguments
 
         # Compile mod and feed it to VM
         mod = relax.pipeline.get_pipeline(pipeline)(mod)  # pylint: disable=no-value-for-parameter
+
+        if device != "cpu":
+            with target, tvm.transform.PassContext(opt_level=3):
+                mod = tvm.tir.transform.DefaultGPUSchedule()(mod)

Review Comment:
   To be clear, the system is designed to work through `relax.pipeline` (see line 414 above): users are expected to provide the name of a pipeline that covers this set of transformations, instead of the code hardcoding a `DefaultGPUSchedule` pass.
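For illustration only, a pure-Python sketch of the pattern described above, in which device-specific scheduling lives inside a named pipeline and `jit` merely looks that pipeline up by name. Every name here (`PIPELINES`, `register_pipeline`, `get_pipeline`, `jit`, and the toy passes) is a hypothetical stand-in, not TVM's actual API; a "module" is modeled as a plain list of applied pass names.

```python
from typing import Callable, Dict, List

# Hypothetical stand-in: a "module" is just the list of pass names applied so far.
Module = List[str]
Pass = Callable[[Module], Module]

# Hypothetical registry mapping a pipeline name to an ordered list of passes.
PIPELINES: Dict[str, List[Pass]] = {}


def register_pipeline(name: str, passes: List[Pass]) -> None:
    """Register a named pipeline (hypothetical helper)."""
    PIPELINES[name] = passes


def get_pipeline(name: str) -> Pass:
    """Return a callable that runs every pass of the named pipeline in order."""
    passes = PIPELINES[name]

    def run(mod: Module) -> Module:
        for p in passes:
            mod = p(mod)
        return mod

    return run


def lower(mod: Module) -> Module:
    # Toy pass shared by all pipelines.
    return mod + ["lower"]


def default_gpu_schedule(mod: Module) -> Module:
    # Toy stand-in for a GPU scheduling pass such as DefaultGPUSchedule.
    return mod + ["default_gpu_schedule"]


register_pipeline("cpu", [lower])
register_pipeline("gpu", [lower, default_gpu_schedule])


def jit(mod: Module, pipeline: str) -> Module:
    # No device-specific branching: the caller chooses the pipeline by name.
    return get_pipeline(pipeline)(mod)
```

With this layout, `jit([], "gpu")` returns `["lower", "default_gpu_schedule"]` while `jit([], "cpu")` returns `["lower"]`; supporting another device means registering another pipeline rather than adding an `if device != ...` branch inside `jit`.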



##########
python/tvm/relax/frontend/nn/spec.py:
##########
@@ -59,6 +59,32 @@ def __repr__(self) -> str:
         return f"Tensor([{shape}], '{self.dtype}')"
 
 
+class TupleList:

Review Comment:
   Technically, a tuple and a list are two different things. Intuitively, you can see a `Tuple` as a fixed-length array whose elements each have a known type; for example, `Tuple[int, str, float]` means a tuple of three elements of types `int`, `str`, and `float`. A `List`, if it's homogeneous, means a variable-length array of a single type; for example, `List[int]` means a list of integers.
   
   Therefore, you may use a more indicative name, i.e. `Tuple`, instead of `TupleList`. If the name conflicts with `typing.Tuple`, just avoid importing `Tuple` from the `typing` package.
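A small illustration of the distinction using only the standard `typing` module. The `Tuple` class at the end is a hypothetical empty placeholder for the renamed spec class, not the PR's code; importing the `typing` module itself, rather than `from typing import Tuple`, is one way to keep the bare name free.

```python
import typing

# Fixed length, per-position types: exactly (int, str, float).
FixedTriple = typing.Tuple[int, str, float]
# Variable length, one element type: any number of ints.
IntList = typing.List[int]


# Importing the `typing` module itself (instead of `from typing import Tuple`)
# leaves the bare name `Tuple` free for the spec class.
class Tuple:
    """Hypothetical minimal placeholder for the renamed spec class."""
```

Here `typing.get_args(FixedTriple)` is `(int, str, float)` and `typing.get_args(IntList)` is `(int,)`, while annotations elsewhere in the module keep writing `typing.Tuple[...]`.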



##########
python/tvm/relax/frontend/nn/spec.py:
##########
@@ -59,6 +59,32 @@ def __repr__(self) -> str:
         return f"Tensor([{shape}], '{self.dtype}')"
 
 
+class TupleList:
+    """A tuple input or a list input"""
+
+    name: str
+    elements: Union[List[Union[core.Tensor, "TupleList"]], Tuple[Union[core.Tensor, "TupleList"]]]
+
+    def __init__(
+        self,
+        name: str,
+        elements: Union[
+            List[Union[core.Tensor, "TupleList"]], Tuple[Union[core.Tensor, "TupleList"]]
+        ],
+    ) -> None:
+        assert type(elements) in [tuple, list], f"Unsupported container type: {type(elements)}"
+        for i, e in enumerate(elements):  # pylint: disable=invalid-name
+            assert isinstance(e, (core.Tensor, TupleList)), (
+                f"Expected all elements in the {name} tuple/list to be of type Tensor/tuple/list, "
+                f"but found a {type(e)} at index {i}."
+            )
+        self.name = name
+        self.elements = elements

Review Comment:
   Always normalize `elements` by casting it to a Python `list`.
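A minimal sketch of that normalization, assuming a hypothetical empty `Tensor` stand-in in place of `core.Tensor`; the PR's per-element type check is omitted here to keep the focus on the cast.

```python
from typing import List, Sequence, Union


class Tensor:
    """Hypothetical stand-in for core.Tensor."""


class TupleList:
    """Tuple/list spec that always stores its elements as a Python list."""

    def __init__(self, name: str, elements: Sequence[Union[Tensor, "TupleList"]]) -> None:
        assert isinstance(elements, (tuple, list)), f"Unsupported container type: {type(elements)}"
        self.name = name
        # Normalize: tuple and list inputs both become a fresh list, so
        # downstream code only ever handles one container type.
        self.elements: List[Union[Tensor, "TupleList"]] = list(elements)
```

Copying with `list(...)` also means later mutation of the caller's list does not leak into the spec.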



##########
python/tvm/relax/frontend/nn/spec.py:
##########
@@ -105,17 +131,23 @@ def from_raw(spec: MethodSpecType, method: Callable) -> "MethodSpec":
         method_signature = inspect.signature(method)
         arg_names = list(method_signature.parameters.keys())
         arg_specs = []
+
+        def _convert_arg_spec(arg_spec):
+            if arg_spec is Int or arg_spec is int:
+                return Int()
+            elif isinstance(arg_spec, str) and arg_spec == "int":
+                return Int()
+            elif isinstance(arg_spec, (Int, Tensor)):
+                return arg_spec
+            elif isinstance(arg_spec, (tuple, list)):
+                return type(arg_spec)([_convert_arg_spec(arg_spec_i) for arg_spec_i in arg_spec])
+            else:
+                raise TypeError(f"Invalid spec for argument {arg_name}: {arg_spec}")

Review Comment:
   Make this a module-level function, e.g. `parse_spec(arg_spec, arg_name)`.
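A sketch of the suggested refactor, with hypothetical empty `Int` and `Tensor` stand-ins so it runs on its own. The logic mirrors the nested `_convert_arg_spec`; the main difference is that `arg_name`, which the nested version reads from its enclosing scope rather than taking as a parameter, becomes an explicit argument.

```python
from typing import Any


class Int:
    """Hypothetical stand-in for spec.Int."""


class Tensor:
    """Hypothetical stand-in for spec.Tensor."""


def parse_spec(arg_spec: Any, arg_name: str) -> Any:
    """Convert a raw user-provided spec into spec objects.

    Same logic as the nested ``_convert_arg_spec``, lifted to module level;
    ``arg_name`` is an explicit parameter, used for error messages.
    """
    if arg_spec is Int or arg_spec is int:
        return Int()
    if isinstance(arg_spec, str) and arg_spec == "int":
        return Int()
    if isinstance(arg_spec, (Int, Tensor)):
        return arg_spec
    if isinstance(arg_spec, (tuple, list)):
        # Recurse into nested containers, preserving tuple vs. list.
        return type(arg_spec)([parse_spec(elem, arg_name) for elem in arg_spec])
    raise TypeError(f"Invalid spec for argument {arg_name}: {arg_spec}")
```

`from_raw` could then call `parse_spec(arg_spec, arg_name)` once per argument, and the function can be unit-tested without constructing a `MethodSpec`.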



##########
python/tvm/relax/frontend/nn/spec.py:
##########
@@ -59,6 +59,32 @@ def __repr__(self) -> str:
         return f"Tensor([{shape}], '{self.dtype}')"
 
 
+class TupleList:
+    """A tuple input or a list input"""
+
+    name: str
+    elements: Union[List[Union[core.Tensor, "TupleList"]], Tuple[Union[core.Tensor, "TupleList"]]]
+
+    def __init__(
+        self,
+        name: str,
+        elements: Union[
+            List[Union[core.Tensor, "TupleList"]], Tuple[Union[core.Tensor, "TupleList"]]
+        ],
+    ) -> None:
+        assert type(elements) in [tuple, list], f"Unsupported container type: {type(elements)}"
+        for i, e in enumerate(elements):  # pylint: disable=invalid-name
+            assert isinstance(e, (core.Tensor, TupleList)), (
+                f"Expected all elements in the {name} tuple/list to be of type Tensor/tuple/list, "
+                f"but found a {type(e)} at index {i}."
+            )

Review Comment:
   There isn't much point in limiting its element types to `Tuple`/`Tensor` while excluding `int`; it's possible to make `int` work as well.
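A sketch of that relaxation, assuming a spec-level `Int` class exists alongside `Tensor` (both are empty hypothetical stand-ins here; the PR's version checks `core.Tensor`). The only change from the PR's check is that `Int` joins the accepted element types.

```python
from typing import List, Union


class Int:
    """Hypothetical stand-in for spec.Int."""


class Tensor:
    """Hypothetical stand-in for spec.Tensor."""


class TupleList:
    """Tuple/list spec whose elements may be Int, Tensor, or a nested TupleList."""

    def __init__(self, name: str, elements: List[Union[Int, Tensor, "TupleList"]]) -> None:
        assert isinstance(elements, (tuple, list)), f"Unsupported container type: {type(elements)}"
        for i, e in enumerate(elements):  # pylint: disable=invalid-name
            # Int is accepted alongside Tensor and nested tuples/lists.
            assert isinstance(e, (Int, Tensor, TupleList)), (
                f"Expected all elements in the {name} tuple/list to be of type "
                f"Int/Tensor/tuple/list, but found a {type(e)} at index {i}."
            )
        self.name = name
        self.elements = list(elements)
```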



##########
python/tvm/relax/frontend/nn/spec.py:
##########
@@ -59,6 +59,32 @@ def __repr__(self) -> str:
         return f"Tensor([{shape}], '{self.dtype}')"
 
 
+class TupleList:
+    """A tuple input or a list input"""
+
+    name: str
+    elements: Union[List[Union[core.Tensor, "TupleList"]], Tuple[Union[core.Tensor, "TupleList"]]]

Review Comment:
   Define a global type `SpecAny`:
   
   ```
   SpecAny = Union["Int", "Tensor", "Tuple"]
   ```
   
   Then define the elements of `Tuple` as:
   
   ```
   elements: List[SpecAny]
   ```
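A runnable sketch of that layout, with hypothetical empty `Int`/`Tensor` stand-ins. Because `Tuple` is defined after `SpecAny`, the union uses string forward references, which `typing.get_type_hints` can resolve once the class exists; referring to the standard one only as `typing.Tuple` keeps the spec class from clashing with it.

```python
import typing


class Int:
    """Hypothetical stand-in for spec.Int."""


class Tensor:
    """Hypothetical stand-in for spec.Tensor."""


# String forward references: "Tuple" is only defined further below.
SpecAny = typing.Union["Int", "Tensor", "Tuple"]


class Tuple:  # no clash: the standard one is only referenced as typing.Tuple
    """Spec for a tuple input whose elements are themselves specs."""

    name: str
    elements: typing.List[SpecAny]

    def __init__(self, name: str, elements: typing.Sequence[SpecAny]) -> None:
        self.name = name
        self.elements = list(elements)
```

Nesting then falls out naturally, e.g. `Tuple("outer", (Int(), Tuple("inner", [Tensor()])))`.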



##########
tests/python/relax/test_frontend_nn_modules.py:
##########
@@ -529,5 +531,83 @@ def forward(
     assert_structural_equal(tvm_mod["forward"], forward, True)
 
 
+def test_nn_module_tuple_input_output():
+    class Layer(nn.Module):
+        def __init__(self):
+            pass
+
+        def forward(self, x: Tuple[nn.Tensor]):

Review Comment:
   `Tuple[nn.Tensor]` means a tuple of length 1 that contains a single tensor.
   
   ```suggestion
           def forward(self, x: Tuple[nn.Tensor, nn.Tensor]):
   ```
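How the standard `typing` module reads these annotations; `Tensor` is a hypothetical stand-in for `nn.Tensor`. For completeness, `Tuple[X, ...]` (with a literal ellipsis) is the standard spelling of a variable-length tuple whose elements all have type `X`.

```python
import typing


class Tensor:
    """Hypothetical stand-in for nn.Tensor."""


One = typing.Tuple[Tensor]          # exactly one Tensor
Two = typing.Tuple[Tensor, Tensor]  # exactly two Tensors
Many = typing.Tuple[Tensor, ...]    # any number of Tensors
```

`typing.get_args` returns `(Tensor,)`, `(Tensor, Tensor)`, and `(Tensor, Ellipsis)` for these three, respectively.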



##########
python/tvm/relax/frontend/nn/spec.py:
##########
@@ -105,17 +131,23 @@ def from_raw(spec: MethodSpecType, method: Callable) -> "MethodSpec":
         method_signature = inspect.signature(method)
         arg_names = list(method_signature.parameters.keys())
         arg_specs = []
+
+        def _convert_arg_spec(arg_spec):
+            if arg_spec is Int or arg_spec is int:
+                return Int()
+            elif isinstance(arg_spec, str) and arg_spec == "int":
+                return Int()
+            elif isinstance(arg_spec, (Int, Tensor)):
+                return arg_spec
+            elif isinstance(arg_spec, (tuple, list)):

Review Comment:
   ```suggestion
               elif isinstance(arg_spec, (tuple, list, TupleList)):
   ```
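One caveat when adding `TupleList` to this branch: the existing body, `type(arg_spec)([...])`, would call `TupleList(...)` with a single argument, while the constructor shown in this PR also takes `name`, so a `TupleList` likely needs its own handling. A hedged sketch of a separate branch, using hypothetical minimal stand-ins for `Int`, `Tensor`, and `TupleList`:

```python
from typing import Any, List


class Int:
    """Hypothetical stand-in for spec.Int."""


class Tensor:
    """Hypothetical stand-in for spec.Tensor."""


class TupleList:
    """Hypothetical stand-in mirroring the PR's (name, elements) constructor."""

    def __init__(self, name: str, elements: List[Any]) -> None:
        self.name = name
        self.elements = list(elements)


def parse_spec(arg_spec: Any, arg_name: str) -> Any:
    if arg_spec is Int or arg_spec is int:
        return Int()
    if isinstance(arg_spec, str) and arg_spec == "int":
        return Int()
    if isinstance(arg_spec, (Int, Tensor)):
        return arg_spec
    if isinstance(arg_spec, TupleList):
        # Already a spec object: convert its elements and keep its name.
        return TupleList(arg_spec.name, [parse_spec(e, arg_name) for e in arg_spec.elements])
    if isinstance(arg_spec, (tuple, list)):
        # Plain Python containers keep their original container type.
        return type(arg_spec)([parse_spec(e, arg_name) for e in arg_spec])
    raise TypeError(f"Invalid spec for argument {arg_name}: {arg_spec}")
```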



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
