This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 876ed385ce [Unity][Op] Symbolic shape support of take grad (#14559)
876ed385ce is described below

commit 876ed385ce95dd0823f203603e8e61eaac126202
Author: Chaofan Lin <[email protected]>
AuthorDate: Tue Apr 11 02:28:28 2023 +0800

    [Unity][Op] Symbolic shape support of take grad (#14559)
    
    Prior to this PR, the legalization of R.grad.take_backward did not
    support symbolic shapes. This PR enables that support with a few small
    fixes and adds tests.
---
 python/tvm/relax/transform/legalize_ops/grad.py    |  9 ++--
 .../relax/test_transform_legalize_ops_grad.py      | 60 +++++++++++++++++++---
 2 files changed, 57 insertions(+), 12 deletions(-)
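
The core of the change is how shape dimensions are handled: a static
dimension is a tir.IntImm whose `.value` yields a Python int, while a
symbolic dimension is a tir.Var that has no `.value` at all, so the
legalization must keep the PrimExpr itself. A minimal sketch (mine, not
part of the patch) of the distinction:

    import tvm
    from tvm import tir

    static_dim = tir.IntImm("int64", 2)    # static shape dimension
    symbolic_dim = tir.Var("n", "int64")   # symbolic shape dimension

    print(static_dim.value)                    # 2; only IntImm has .value
    assert not hasattr(symbolic_dim, "value")  # the old `.value` access
                                               # would raise here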

diff --git a/python/tvm/relax/transform/legalize_ops/grad.py b/python/tvm/relax/transform/legalize_ops/grad.py
index 7fb9b0864d..f5c295afc6 100644
--- a/python/tvm/relax/transform/legalize_ops/grad.py
+++ b/python/tvm/relax/transform/legalize_ops/grad.py
@@ -162,7 +162,8 @@ def _grad_take_backward(bb: BlockBuilder, call: Call) -> Expr:
             with ib.for_range(0, fused_shape) as i:
                 out[i] = tir.const(0, dtype=x_ptr.dtype)
 
-            indices_len = indices_ptr.shape[0].value  # must be 1-dim
+            assert len(indices_ptr.shape) == 1  # indices in take must be 1-dim Tensor
+            indices_len = indices_ptr.shape[0]
 
             if axis is not None:
                 fused_output_grad_shape_pre = 1
@@ -173,14 +174,14 @@ def _grad_take_backward(bb: BlockBuilder, call: Call) -> Expr:
                     elif i > axis:
                         fused_output_grad_shape_nxt *= output_grad_ptr.shape[i]
 
-                x_axis_len = x_ptr.shape[axis].value
+                x_axis_len = x_ptr.shape[axis]
 
                 with ib.for_range(
                     0, fused_output_grad_shape_pre * fused_output_grad_shape_nxt, "parallel"
                 ) as fused:
                     i = fused // fused_output_grad_shape_nxt
                     j = fused % fused_output_grad_shape_nxt
-                    for l in reversed(range(indices_len)):
+                    with ib.for_range(0, indices_len, "serial") as l:
                         out[
                             i * fused_output_grad_shape_nxt * x_axis_len
                             + indices[l] * fused_output_grad_shape_nxt
@@ -191,7 +192,7 @@ def _grad_take_backward(bb: BlockBuilder, call: Call) -> Expr:
                             + j
                         ]
             else:
-                for l in reversed(range(indices_len)):
+                with ib.for_range(0, indices_len, "serial") as l:
                     out[indices[l]] += output_grad[l]
 
             return ib.get()
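
The other half of the fix is the loop construct: Python's
`reversed(range(indices_len))` needs a concrete int and unrolls the loop
while tracing, whereas `ib.for_range` emits a single TIR loop whose extent
can remain a symbolic PrimExpr. (Dropping the reversed order is harmless
here: each iteration only accumulates into out, so ordering matters only
up to floating-point rounding.) A minimal sketch (mine, with illustrative
names) of a symbolic-extent loop:

    import tvm
    from tvm import tir

    n = tir.Var("n", "int64")  # symbolic extent, unknown until runtime
    ib = tvm.tir.ir_builder.create()
    buf = ib.allocate("float32", (n,), name="buf", scope="global")
    # range(n) would fail on a tir.Var; for_range keeps the extent symbolic.
    with ib.for_range(0, n, "l") as l:
        buf[l] = tir.const(0, "float32")
    print(ib.get())  # a tir.For node with extent n
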
diff --git a/tests/python/relax/test_transform_legalize_ops_grad.py b/tests/python/relax/test_transform_legalize_ops_grad.py
index a92537f0d1..14ac96bb4d 100644
--- a/tests/python/relax/test_transform_legalize_ops_grad.py
+++ b/tests/python/relax/test_transform_legalize_ops_grad.py
@@ -302,28 +302,28 @@ def test_take_backward():
     @tvm.script.ir_module
     class TakeBackward:
         @R.function
-        def main(output_grad: R.Tensor((3, 2, 4), "float32"), x: R.Tensor((3, 4, 5), "float32"), indices: R.Tensor((2,), "int32")):
-            gv = R.grad.take_backward(output_grad, x, indices)
+        def main(output_grad: R.Tensor((3, 2, 5), "float32"), x: R.Tensor((3, 4, 5), "float32"), indices: R.Tensor((2,), "int32")):
+            gv = R.grad.take_backward(output_grad, x, indices, axis=1)
             return gv
 
     @I.ir_module
     class Expected:
         @T.prim_func
         def take_backward(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, out_buf: T.Buffer((T.int64(3), T.int64(4), T.int64(5)), "float32")):
-            T.func_attr({"tir.noalias": True})
-            rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(3), T.int64(2), T.int64(4)), offset_factor=1)
+            T.func_attr({"tir.noalias": T.bool(True)})
+            rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(3), T.int64(2), T.int64(5)), offset_factor=1)
             rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(3), T.int64(4), T.int64(5)), offset_factor=1)
             rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (T.int64(2),), "int32", offset_factor=1)
             with T.block("take_backward"):
-                T.reads(rxplaceholder[T.int64(0):T.int64(3), T.int64(0):T.int64(2), T.int64(0):T.int64(4)], rxplaceholder_1[T.int64(0):T.int64(3), T.int64(0):T.int64(4), T.int64(0):T.int64(5)], rxplaceholder_2[T.int64(0):T.int64(2)])
+                T.reads(rxplaceholder[T.int64(0):T.int64(3), T.int64(0):T.int64(2), T.int64(0):T.int64(5)], rxplaceholder_1[T.int64(0):T.int64(3), T.int64(0):T.int64(4), T.int64(0):T.int64(5)], rxplaceholder_2[T.int64(0):T.int64(2)])
                 T.writes(out_buf[T.int64(0):T.int64(3), T.int64(0):T.int64(4), T.int64(0):T.int64(5)])
                 for i in range(T.int64(60)):
                     out_buf[i // T.int64(5) // T.int64(4), i // T.int64(5) % T.int64(4), i % T.int64(5)] = T.float32(0)
-                out_buf[T.Cast("int64", rxplaceholder_2[T.int64(1)]) // T.int64(5) // T.int64(4), T.Cast("int64", rxplaceholder_2[T.int64(1)]) // T.int64(5) % T.int64(4), T.Cast("int64", rxplaceholder_2[T.int64(1)]) % T.int64(5)] = out_buf[T.Cast("int64", rxplaceholder_2[T.int64(1)]) // T.int64(5) // T.int64(4), T.Cast("int64", rxplaceholder_2[T.int64(1)]) // T.int64(5) % T.int64(4), T.Cast("int64", rxplaceholder_2[T.int64(1)]) % T.int64(5)] + rxplaceholder[T.int64(0), T.int64(0), T.int64(1)]
-                out_buf[T.Cast("int64", rxplaceholder_2[T.int64(0)]) // T.int64(5) // T.int64(4), T.Cast("int64", rxplaceholder_2[T.int64(0)]) // T.int64(5) % T.int64(4), T.Cast("int64", rxplaceholder_2[T.int64(0)]) % T.int64(5)] = out_buf[T.Cast("int64", rxplaceholder_2[T.int64(0)]) // T.int64(5) // T.int64(4), T.Cast("int64", rxplaceholder_2[T.int64(0)]) // T.int64(5) % T.int64(4), T.Cast("int64", rxplaceholder_2[T.int64(0)]) % T.int64(5)] + rxplaceholder[T.int64(0), T.int64(0), T.int64(0)]
+                for parallel, serial in T.grid(T.int64(15), T.int64(2)):
+                    out_buf[(parallel // T.int64(5) * T.int64(5) * T.int64(4) + T.Cast("int64", rxplaceholder_2[serial]) * T.int64(5) + parallel % T.int64(5)) // T.int64(5) // T.int64(4), (parallel // T.int64(5) * T.int64(5) * T.int64(4) + T.Cast("int64", rxplaceholder_2[serial]) * T.int64(5) + parallel % T.int64(5)) // T.int64(5) % T.int64(4), (parallel // T.int64(5) * T.int64(5) * T.int64(4) + T.Cast("int64", rxplaceholder_2[serial]) * T.int64(5) + parallel % T.int64(5)) % T.int64(5)]  [...]
 
         @R.function
-        def main(output_grad: R.Tensor((3, 2, 4), dtype="float32"), x: R.Tensor((3, 4, 5), dtype="float32"), indices: R.Tensor((2,), dtype="int32")) -> R.Tensor((3, 4, 5), dtype="float32"):
+        def main(output_grad: R.Tensor((3, 2, 5), dtype="float32"), x: R.Tensor((3, 4, 5), dtype="float32"), indices: R.Tensor((2,), dtype="int32")) -> R.Tensor((3, 4, 5), dtype="float32"):
             cls = Expected
             gv = R.call_tir(cls.take_backward, (output_grad, x, indices), out_sinfo=R.Tensor((3, 4, 5), dtype="float32"))
             return gv
@@ -333,5 +333,49 @@ def test_take_backward():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_take_backward_symbolic():
+    # fmt: off
+    @tvm.script.ir_module
+    class TakeBackward:
+        @R.function
+        def main(output_grad: R.Tensor(("m", "i"), "float32"), x: 
R.Tensor(("m", "n"), "float32"), indices: R.Tensor(("i",), "int32")):
+            m = T.int64()
+            i = T.int64()
+            gv = R.grad.take_backward(output_grad, x, indices, axis=1)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def take_backward(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_take_backward: T.handle):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            m, i = T.int64(), T.int64()
+            rxplaceholder = T.match_buffer(var_rxplaceholder, (m, i), offset_factor=1)
+            n = T.int64()
+            rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (m, n), offset_factor=1)
+            rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (i,), "int32", offset_factor=1)
+            out_buf = T.match_buffer(var_take_backward, (m, n))
+            with T.block("take_backward"):
+                T.reads(rxplaceholder[T.int64(0):m, T.int64(0):i], rxplaceholder_1[T.int64(0):m, T.int64(0):n], rxplaceholder_2[T.int64(0):i])
+                T.writes(out_buf[T.int64(0):m, T.int64(0):n])
+                for i_1 in range(m * n):
+                    out_buf[i_1 // n, i_1 % n] = T.float32(0)
+                for parallel, serial in T.grid(m, i):
+                    out_buf[(parallel * n + T.Cast("int64", rxplaceholder_2[serial])) // n, (parallel * n + T.Cast("int64", rxplaceholder_2[serial])) % n] = out_buf[(parallel * n + T.Cast("int64", rxplaceholder_2[serial])) // n, (parallel * n + T.Cast("int64", rxplaceholder_2[serial])) % n] + rxplaceholder[(parallel * i + serial) // i, (parallel * i + serial) % i]
+
+        @R.function
+        def main(output_grad: R.Tensor(("m", "i"), dtype="float32"), x: 
R.Tensor(("m", "n"), dtype="float32"), indices: R.Tensor(("i",), 
dtype="int32")) -> R.Tensor(("m", "n"), dtype="float32"):
+            m = T.int64()
+            n = T.int64()
+            i = T.int64()
+            cls = Expected
+            gv = R.call_tir(cls.take_backward, (output_grad, x, indices), out_sinfo=R.Tensor((m, n), dtype="float32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(TakeBackward)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
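
For reference, the gradient these kernels compute is a scatter-add of
output_grad back into the positions that take gathered from. A NumPy
sketch (mine, assuming standard take semantics; duplicates in indices
accumulate):

    import numpy as np

    def take_backward_ref(output_grad, x, indices, axis=None):
        grad_x = np.zeros_like(x)
        if axis is None:  # take read from flattened x
            flat = grad_x.reshape(-1)
            np.add.at(flat, indices, output_grad.reshape(-1))
            return flat.reshape(x.shape)
        for l, ind in enumerate(indices):  # scatter slice l back to slice ind
            dst = [slice(None)] * x.ndim
            src = [slice(None)] * x.ndim
            dst[axis], src[axis] = ind, l
            grad_x[tuple(dst)] += output_grad[tuple(src)]
        return grad_x

With the static test above (axis=1, x of shape (3, 4, 5), 2 indices) this
matches the Expected PrimFunc: zero-initialize out_buf, then scatter-add
over T.grid(15, 2), where 15 = 3 * 5 is the fused extent of the non-axis
dimensions and 2 is the number of indices.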
