This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 876ed385ce [Unity][Op] Symbolic shape support of take grad (#14559)
876ed385ce is described below
commit 876ed385ce95dd0823f203603e8e61eaac126202
Author: Chaofan Lin <[email protected]>
AuthorDate: Tue Apr 11 02:28:28 2023 +0800
[Unity][Op] Symbolic shape support of take grad (#14559)
Prior to this PR, the legalization of R.grad.take_backward did not support
symbolic shapes. This PR enables that support with a few small fixes and adds tests.
---
python/tvm/relax/transform/legalize_ops/grad.py | 9 ++--
.../relax/test_transform_legalize_ops_grad.py | 60 +++++++++++++++++++---
2 files changed, 57 insertions(+), 12 deletions(-)
diff --git a/python/tvm/relax/transform/legalize_ops/grad.py
b/python/tvm/relax/transform/legalize_ops/grad.py
index 7fb9b0864d..f5c295afc6 100644
--- a/python/tvm/relax/transform/legalize_ops/grad.py
+++ b/python/tvm/relax/transform/legalize_ops/grad.py
@@ -162,7 +162,8 @@ def _grad_take_backward(bb: BlockBuilder, call: Call) ->
Expr:
with ib.for_range(0, fused_shape) as i:
out[i] = tir.const(0, dtype=x_ptr.dtype)
- indices_len = indices_ptr.shape[0].value # must be 1-dim
+ assert len(indices_ptr.shape) == 1 # indices in take must be
1-dim Tensor
+ indices_len = indices_ptr.shape[0]
if axis is not None:
fused_output_grad_shape_pre = 1
@@ -173,14 +174,14 @@ def _grad_take_backward(bb: BlockBuilder, call: Call) ->
Expr:
elif i > axis:
fused_output_grad_shape_nxt *= output_grad_ptr.shape[i]
- x_axis_len = x_ptr.shape[axis].value
+ x_axis_len = x_ptr.shape[axis]
with ib.for_range(
0, fused_output_grad_shape_pre *
fused_output_grad_shape_nxt, "parallel"
) as fused:
i = fused // fused_output_grad_shape_nxt
j = fused % fused_output_grad_shape_nxt
- for l in reversed(range(indices_len)):
+ with ib.for_range(0, indices_len, "serial") as l:
out[
i * fused_output_grad_shape_nxt * x_axis_len
+ indices[l] * fused_output_grad_shape_nxt
@@ -191,7 +192,7 @@ def _grad_take_backward(bb: BlockBuilder, call: Call) ->
Expr:
+ j
]
else:
- for l in reversed(range(indices_len)):
+ with ib.for_range(0, indices_len, "serial") as l:
out[indices[l]] += output_grad[l]
return ib.get()
diff --git a/tests/python/relax/test_transform_legalize_ops_grad.py
b/tests/python/relax/test_transform_legalize_ops_grad.py
index a92537f0d1..14ac96bb4d 100644
--- a/tests/python/relax/test_transform_legalize_ops_grad.py
+++ b/tests/python/relax/test_transform_legalize_ops_grad.py
@@ -302,28 +302,28 @@ def test_take_backward():
@tvm.script.ir_module
class TakeBackward:
@R.function
- def main(output_grad: R.Tensor((3, 2, 4), "float32"), x: R.Tensor((3,
4, 5), "float32"), indices: R.Tensor((2,), "int32")):
- gv = R.grad.take_backward(output_grad, x, indices)
+ def main(output_grad: R.Tensor((3, 2, 5), "float32"), x: R.Tensor((3,
4, 5), "float32"), indices: R.Tensor((2,), "int32")):
+ gv = R.grad.take_backward(output_grad, x, indices, axis=1)
return gv
@I.ir_module
class Expected:
@T.prim_func
def take_backward(var_rxplaceholder: T.handle, var_rxplaceholder_1:
T.handle, var_rxplaceholder_2: T.handle, out_buf: T.Buffer((T.int64(3),
T.int64(4), T.int64(5)), "float32")):
- T.func_attr({"tir.noalias": True})
- rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(3),
T.int64(2), T.int64(4)), offset_factor=1)
+ T.func_attr({"tir.noalias": T.bool(True)})
+ rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(3),
T.int64(2), T.int64(5)), offset_factor=1)
rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(3),
T.int64(4), T.int64(5)), offset_factor=1)
rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2,
(T.int64(2),), "int32", offset_factor=1)
with T.block("take_backward"):
- T.reads(rxplaceholder[T.int64(0):T.int64(3),
T.int64(0):T.int64(2), T.int64(0):T.int64(4)],
rxplaceholder_1[T.int64(0):T.int64(3), T.int64(0):T.int64(4),
T.int64(0):T.int64(5)], rxplaceholder_2[T.int64(0):T.int64(2)])
+ T.reads(rxplaceholder[T.int64(0):T.int64(3),
T.int64(0):T.int64(2), T.int64(0):T.int64(5)],
rxplaceholder_1[T.int64(0):T.int64(3), T.int64(0):T.int64(4),
T.int64(0):T.int64(5)], rxplaceholder_2[T.int64(0):T.int64(2)])
T.writes(out_buf[T.int64(0):T.int64(3), T.int64(0):T.int64(4),
T.int64(0):T.int64(5)])
for i in range(T.int64(60)):
out_buf[i // T.int64(5) // T.int64(4), i // T.int64(5) %
T.int64(4), i % T.int64(5)] = T.float32(0)
- out_buf[T.Cast("int64", rxplaceholder_2[T.int64(1)]) //
T.int64(5) // T.int64(4), T.Cast("int64", rxplaceholder_2[T.int64(1)]) //
T.int64(5) % T.int64(4), T.Cast("int64", rxplaceholder_2[T.int64(1)]) %
T.int64(5)] = out_buf[T.Cast("int64", rxplaceholder_2[T.int64(1)]) //
T.int64(5) // T.int64(4), T.Cast("int64", rxplaceholder_2[T.int64(1)]) //
T.int64(5) % T.int64(4), T.Cast("int64", rxplaceholder_2[T.int64(1)]) %
T.int64(5)] + rxplaceholder[T.int64(0), T.int64(0), T.int64(1)]
- out_buf[T.Cast("int64", rxplaceholder_2[T.int64(0)]) //
T.int64(5) // T.int64(4), T.Cast("int64", rxplaceholder_2[T.int64(0)]) //
T.int64(5) % T.int64(4), T.Cast("int64", rxplaceholder_2[T.int64(0)]) %
T.int64(5)] = out_buf[T.Cast("int64", rxplaceholder_2[T.int64(0)]) //
T.int64(5) // T.int64(4), T.Cast("int64", rxplaceholder_2[T.int64(0)]) //
T.int64(5) % T.int64(4), T.Cast("int64", rxplaceholder_2[T.int64(0)]) %
T.int64(5)] + rxplaceholder[T.int64(0), T.int64(0), T.int64(0)]
+ for parallel, serial in T.grid(T.int64(15), T.int64(2)):
+ out_buf[(parallel // T.int64(5) * T.int64(5) * T.int64(4)
+ T.Cast("int64", rxplaceholder_2[serial]) * T.int64(5) + parallel %
T.int64(5)) // T.int64(5) // T.int64(4), (parallel // T.int64(5) * T.int64(5) *
T.int64(4) + T.Cast("int64", rxplaceholder_2[serial]) * T.int64(5) + parallel %
T.int64(5)) // T.int64(5) % T.int64(4), (parallel // T.int64(5) * T.int64(5) *
T.int64(4) + T.Cast("int64", rxplaceholder_2[serial]) * T.int64(5) + parallel %
T.int64(5)) % T.int64(5)] [...]
@R.function
- def main(output_grad: R.Tensor((3, 2, 4), dtype="float32"), x:
R.Tensor((3, 4, 5), dtype="float32"), indices: R.Tensor((2,), dtype="int32"))
-> R.Tensor((3, 4, 5), dtype="float32"):
+ def main(output_grad: R.Tensor((3, 2, 5), dtype="float32"), x:
R.Tensor((3, 4, 5), dtype="float32"), indices: R.Tensor((2,), dtype="int32"))
-> R.Tensor((3, 4, 5), dtype="float32"):
cls = Expected
gv = R.call_tir(cls.take_backward, (output_grad, x, indices),
out_sinfo=R.Tensor((3, 4, 5), dtype="float32"))
return gv
@@ -333,5 +333,49 @@ def test_take_backward():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_take_backward_symbolic():
+ # fmt: off
+ @tvm.script.ir_module
+ class TakeBackward:
+ @R.function
+ def main(output_grad: R.Tensor(("m", "i"), "float32"), x:
R.Tensor(("m", "n"), "float32"), indices: R.Tensor(("i",), "int32")):
+ m = T.int64()
+ i = T.int64()
+ gv = R.grad.take_backward(output_grad, x, indices, axis=1)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def take_backward(var_rxplaceholder: T.handle, var_rxplaceholder_1:
T.handle, var_rxplaceholder_2: T.handle, var_take_backward: T.handle):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ m, i = T.int64(), T.int64()
+ rxplaceholder = T.match_buffer(var_rxplaceholder, (m, i),
offset_factor=1)
+ n = T.int64()
+ rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (m, n),
offset_factor=1)
+ rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (i,),
"int32", offset_factor=1)
+ out_buf = T.match_buffer(var_take_backward, (m, n))
+ with T.block("take_backward"):
+ T.reads(rxplaceholder[T.int64(0):m, T.int64(0):i],
rxplaceholder_1[T.int64(0):m, T.int64(0):n], rxplaceholder_2[T.int64(0):i])
+ T.writes(out_buf[T.int64(0):m, T.int64(0):n])
+ for i_1 in range(m * n):
+ out_buf[i_1 // n, i_1 % n] = T.float32(0)
+ for parallel, serial in T.grid(m, i):
+ out_buf[(parallel * n + T.Cast("int64",
rxplaceholder_2[serial])) // n, (parallel * n + T.Cast("int64",
rxplaceholder_2[serial])) % n] = out_buf[(parallel * n + T.Cast("int64",
rxplaceholder_2[serial])) // n, (parallel * n + T.Cast("int64",
rxplaceholder_2[serial])) % n] + rxplaceholder[(parallel * i + serial) // i,
(parallel * i + serial) % i]
+
+ @R.function
+ def main(output_grad: R.Tensor(("m", "i"), dtype="float32"), x:
R.Tensor(("m", "n"), dtype="float32"), indices: R.Tensor(("i",),
dtype="int32")) -> R.Tensor(("m", "n"), dtype="float32"):
+ m = T.int64()
+ n = T.int64()
+ i = T.int64()
+ cls = Expected
+ gv = R.call_tir(cls.take_backward, (output_grad, x, indices),
out_sinfo=R.Tensor((m, n), dtype="float32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(TakeBackward)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()