gemini-code-assist[bot] commented on code in PR #18663:
URL: https://github.com/apache/tvm/pull/18663#discussion_r2694182595


##########
tests/python/relax/test_transform_static_plan_block_memory.py:
##########
@@ -1018,6 +1018,245 @@ def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dty
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_lower_bound_only():
+    # fmt: off
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def add(rxplaceholder: T.handle, rxplaceholder_1: T.handle, T_add: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def reshape(rxplaceholder: T.handle, T_reshape: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def relu(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def log(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def exp(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def pad(rxplaceholder: T.handle, PadInput: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
+            R.func_attr({"tir_var_lower_bound": {"n": 2}, "relax.force_pure": True})
+            n = T.int64()
+            cls = Module
+            alloc: R.Tensor((2, n), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, n]), dtype="float32", runtime_device_index=0)
+            _: R.Tuple() = cls.exp(x, alloc)
+            lv: R.Tensor((2, n), dtype="float32") = alloc
+            lv1: R.Tensor((2 * n,), dtype="float32") = R.reshape(lv, (2 * n,))
+            alloc1: R.Tensor((2 * n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n]), dtype="float32", runtime_device_index=0)
+            _1: R.Tuple() = cls.relu(lv1, alloc1)
+            lv2: R.Tensor((2 * n,), dtype="float32") = alloc1
+            alloc2: R.Tensor((2 * n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n]), dtype="float32", runtime_device_index=0)
+            _2: R.Tuple() = cls.add(lv2, R.const(1, "float32"), alloc2)
+            lv3: R.Tensor((2 * n,), dtype="float32") = alloc2
+            alloc3: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n + 2]), dtype="float32", runtime_device_index=0)
+            _3: R.Tuple() = cls.pad(lv3, alloc3)
+            lv4: R.Tensor((2 * n + 2,), dtype="float32") = alloc3
+            alloc4: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), dtype="float32", runtime_device_index=0)
+            _4: R.Tuple() = cls.log(lv4, alloc4)
+            gv: R.Tensor((2 * n + 2,), dtype="float32") = alloc4
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def add(rxplaceholder: T.handle, rxplaceholder_1: T.handle, T_add: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def exp(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def log(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def pad(rxplaceholder: T.handle, PadInput: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def relu(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def reshape(rxplaceholder: T.handle, T_reshape: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
+            n = T.int64()
+            R.func_attr({"tir_var_lower_bound": {"n": 2}, "relax.force_pure": True})
+            cls = Expected
+            storage: R.Object = R.memory.alloc_storage(R.shape([8 * n]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc: R.Tensor((2, n), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, n]), R.dtype("float32"), R.prim_value(0))
+            _: R.Tuple = cls.exp(x, alloc)
+            lv: R.Tensor((2, n), dtype="float32") = alloc
+            lv1: R.Tensor((2 * n,), dtype="float32") = R.reshape(lv, R.shape([2 * n]))
+            storage1: R.Object = R.memory.alloc_storage(R.shape([4 * (2 * n)]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc1: R.Tensor((2 * n,), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2 * n]), R.dtype("float32"))
+            _1: R.Tuple = cls.relu(lv1, alloc1)
+            lv2: R.Tensor((2 * n,), dtype="float32") = alloc1
+            alloc2: R.Tensor((2 * n,), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2 * n]), R.dtype("float32"))
+            _2: R.Tuple = cls.add(lv2, R.const(1, "float32"), alloc2)
+            lv3: R.Tensor((2 * n,), dtype="float32") = alloc2
+            storage2: R.Object = R.memory.alloc_storage(R.shape([4 * (2 * n + 2)]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc3: R.Tensor((2 * n + 2,), dtype="float32") = R.memory.alloc_tensor(storage2, R.prim_value(0), R.shape([2 * n + 2]), R.dtype("float32"), R.prim_value(0))
+            _3: R.Tuple = cls.pad(lv3, alloc3)
+            lv4: R.Tensor((2 * n + 2,), dtype="float32") = alloc3
+            alloc4: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), R.dtype("float32"), R.prim_value(0))
+            _4: R.Tuple = cls.log(lv4, alloc4)
+            gv: R.Tensor((2 * n + 2,), dtype="float32") = alloc4
+            return gv
+    # fmt: on
+
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The test functions `test_lower_bound_only` and `test_upper_and_lower_bounds` are very similar and contain a lot of duplicated code. Consider refactoring them into a single parameterized test using `pytest.mark.parametrize` to improve maintainability.
   
   You could parameterize the function attributes and the expected memory allocation logic (static vs. dynamic). Here is a conceptual example:
   
   ```python
   import pytest
   
   @pytest.mark.parametrize(
       "func_attr, is_static",
       [
           ({"tir_var_lower_bound": {"n": 2}, "relax.force_pure": True}, False),
           (
               {"tir_var_upper_bound": {"n": 4}, "tir_var_lower_bound": {"n": 2}, "relax.force_pure": True},
               True,
           ),
       ],
   )
   def test_bounds(func_attr, is_static):
       @tvm.script.ir_module
       class Module:
           # ... (common prim funcs)
           @R.function
           def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
               R.func_attr(func_attr)
               # ... (common function body)
   
       # TVMScript cannot branch on a Python bool inside an R.function body,
       # so define one expected module per case and select between them below.
       @I.ir_module
       class ExpectedStatic:
           # ... (common prim funcs)
           @R.function
           def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
               R.func_attr(func_attr)
               storage: R.Object = R.memory.alloc_storage(R.shape([32]), ...)
               # ... static allocation logic
   
       @I.ir_module
       class ExpectedDynamic:
           # ... (common prim funcs)
           @R.function
           def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
               n = T.int64()
               R.func_attr(func_attr)
               storage: R.Object = R.memory.alloc_storage(R.shape([8 * n]), ...)
               # ... dynamic allocation logic
   
       Expected = ExpectedStatic if is_static else ExpectedDynamic
       mod = relax.transform.StaticPlanBlockMemory()(Module)
       tvm.ir.assert_structural_equal(mod, Expected)
   ```
   
   This approach would make the tests more concise and easier to maintain.
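   
   As a side note, `pytest.param(..., id=...)` keeps the two cases distinguishable in test output. Below is a minimal, self-contained sketch of just the parametrize mechanics, using plain Python stand-ins rather than the real IRModules; the names `plan_memory`, `bounds`, and `expected` are illustrative only and not part of the TVM API:
   
   ```python
   import pytest
   
   # Illustrative stand-in for the pass under test: planning is "static" only
   # when an upper bound on the shape variable is known.
   def plan_memory(bounds):
       return "static" if "upper" in bounds else "dynamic"
   
   @pytest.mark.parametrize(
       "bounds, expected",
       [
           pytest.param({"lower": 2}, "dynamic", id="lower-bound-only"),
           pytest.param({"lower": 2, "upper": 4}, "static", id="upper-and-lower-bounds"),
       ],
   )
   def test_bounds(bounds, expected):
       assert plan_memory(bounds) == expected
   ```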



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

