Lunderberg commented on code in PR #16602:
URL: https://github.com/apache/tvm/pull/16602#discussion_r1497800531


##########
tests/python/relax/test_transform_lazy_transform_params.py:
##########
@@ -602,5 +602,77 @@ def main_transform_params() -> R.Tuple:
     tvm.ir.assert_structural_equal(after, Expected)
 
 
+def test_params_without_tuple():
+    @I.ir_module
+    class Before:
+        @R.function
+        def transform_params(A: R.Tensor([16, 16], "float32"), B: 
R.Tensor([16, 16], "float32")):
+            C = R.multiply(A, R.const(2, "float32"))
+            D = R.add(C, B)
+            return (D, B)
+
+    @I.ir_module
+    class Expected:
+        @R.function(pure=False)
+        def transform_params():
+            A = R.call_packed("get_item", R.prim_value(0), 
sinfo_args=[R.Object])
+            A = R.match_cast(A, R.Tensor([16, 16], "float32"))
+            C = R.multiply(A, R.const(2, "float32"))
+
+            B = R.call_packed("get_item", R.prim_value(1), 
sinfo_args=[R.Object])
+            B = R.match_cast(B, R.Tensor([16, 16], "float32"))
+            D = R.add(C, B)
+            return (D, B)
+
+    After = LazyTransformParams(fset_item=None)(Before)
+    tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_retain_before_num_input():
+    """Only lazily load parameters after num_input"""
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def transform_params(
+            relax_rank: R.Prim(value="rank"),
+            A: R.Tensor([16, 16], "float32"),
+            B: R.Tensor([16, 16], "float32"),
+        ):
+            R.func_attr({"num_input": 1})
+            rank = T.int64()
+            A_sharded = R.strided_slice(
+                A, axes=[0], begin=[rank * 8], end=[(rank + 1) * 8], 
assume_inbound=True
+            )
+            B_sharded = R.strided_slice(
+                B, axes=[1], begin=[rank * 8], end=[(rank + 1) * 8], 
assume_inbound=True
+            )
+            return (A_sharded, B_sharded)
+
+    @I.ir_module
+    class Expected:
+        @R.function(pure=False)
+        def transform_params(relax_rank: R.Prim(value="rank")):
+            R.func_attr({"num_input": 1})
+            rank = T.int64()
+
+            A = R.call_packed("get_item", R.prim_value(0), 
sinfo_args=[R.Object])
+            A = R.match_cast(A, R.Tensor([16, 16], "float32"))
+            A_sharded = R.strided_slice(
+                A, axes=[0], begin=[rank * 8], end=[(rank + 1) * 8], 
assume_inbound=True
+            )
+
+            B = R.call_packed("get_item", R.prim_value(1), 
sinfo_args=[R.Object])
+            B = R.match_cast(B, R.Tensor([16, 16], "float32"))
+            B_sharded = R.strided_slice(
+                B, axes=[1], begin=[rank * 8], end=[(rank + 1) * 8], 
assume_inbound=True
+            )
+
+            return (A_sharded, B_sharded)
+
+    After = LazyTransformParams(fset_item=None)(Before)
+    tvm.ir.assert_structural_equal(After, Expected)
+
+

Review Comment:
   There should already be an existing unit test, `test_extra_params`, that 
validates `extra_get_item_params`.  However, no unit tests validate 
`extra_set_item_params`, either in general or in the code path that handles 
`R.const(...)` outputs.
   
   I've updated the unit tests to cover those additional cases.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to