Hzfengsy opened a new pull request, #14154:
URL: https://github.com/apache/tvm/pull/14154
`BindParam` replaces function params with constant nodes. However, it drops the shape information of those params. Consider the following case:
```python
@R.function
def main(
    x: R.Tensor(("batch", "m"), dtype="float32"),
    w0: R.Tensor(("n", "m"), dtype="float32"),
    b0: R.Tensor(("n",), dtype="float32"),
    w1: R.Tensor(("k", "n"), dtype="float32"),
    b1: R.Tensor(("k",), dtype="float32"),
) -> R.Tensor(("batch", "k"), dtype="float32"):
    batch = T.Var("batch", "int64")
    k = T.Var("k", "int64")
    m = T.Var("m", "int64")
    n = T.Var("n", "int64")
    with R.dataflow():
        lv0 = R.call_tir("linear0", (x, w0, b0), out_sinfo=R.Tensor((batch, n), dtype="float32"))
        out = R.call_tir("linear1", (lv0, w1, b1), out_sinfo=R.Tensor((batch, k), dtype="float32"))
        R.output(out)
    return out
```
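The current pass simply drops the symbolic vars `n` and `k`: once the params that define them are removed, nothing binds them, so they become undefined vars and the build fails. The pass currently produces: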
```python
@R.function
def main(x: R.Tensor((1, "m"), dtype="float32")) -> R.Tensor(dtype="float32", ndim=2):
    m = T.Var("m", "int64")
    n = T.Var("n", "int64")
    k = T.Var("k", "int64")
    with R.dataflow():
        lv0 = R.call_tir(
            "linear0",
            (x, metadata["relax.expr.Constant"][0], metadata["relax.expr.Constant"][1]),
            out_sinfo=R.Tensor((1, n), dtype="float32"),
        )
        out = R.call_tir(
            "linear1",
            (lv0, metadata["relax.expr.Constant"][2], metadata["relax.expr.Constant"][3]),
            out_sinfo=R.Tensor((1, k), dtype="float32"),
        )
        R.output(out)
    return out
```
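To illustrate why `n` and `k` end up undefined, here is a toy, TVM-free model. It assumes, as a simplification, that a symbolic var is defined only when it appears in the shape annotation of a remaining function parameter; dimensions are modeled as ints (static) or strings (symbolic var names). The helper names are hypothetical and are not part of the TVM API.

```python
from typing import Iterable, Sequence, Set, Union

# Toy model (assumption): a dim is a static int or a symbolic var name (str).
Dim = Union[int, str]


def symbolic_vars(shapes: Iterable[Sequence[Dim]]) -> Set[str]:
    """Collect every symbolic var name appearing in the given shapes."""
    return {d for shape in shapes for d in shape if isinstance(d, str)}


def undefined_vars(
    param_shapes: Iterable[Sequence[Dim]],
    used_shapes: Iterable[Sequence[Dim]],
) -> Set[str]:
    """Vars used in the body that no remaining parameter shape defines."""
    return symbolic_vars(used_shapes) - symbolic_vars(param_shapes)


# Before binding: w0/b0/w1/b1 are still parameters, so n and k are defined.
before_params = [("batch", "m"), ("n", "m"), ("n",), ("k", "n"), ("k",)]
before_used = [("batch", "n"), ("batch", "k")]
print(undefined_vars(before_params, before_used))  # set()

# After naively replacing w0/b0/w1/b1 with constants, only x remains.
after_params = [(1, "m")]
after_used = [(1, "n"), (1, "k")]
print(sorted(undefined_vars(after_params, after_used)))  # ['k', 'n']
```

In this model, the only shapes that mention `n` and `k` belong to params that the pass removes, which mirrors the undefined vars in the output shown above.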
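This PR updates the pass to also bind the symbolic shape vars during param binding, so that vars such as `n` and `k` are resolved from the shapes of the constants that replace the params.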
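The idea of binding symbolic shapes from constants can be sketched as follows: match each bound param's annotated shape against the concrete shape of its constant, collect a var-to-value map, then substitute that map into the remaining shapes. This is a minimal pure-Python sketch under stated assumptions (string dims are symbolic vars, the constant shapes `(64, 32)` etc. are made up, and `bind_symbolic_shape`/`substitute` are hypothetical helpers); it is not the pass's actual implementation.

```python
from typing import Dict, Sequence, Tuple, Union

# Toy model (assumption): a dim is a static int or a symbolic var name (str).
Dim = Union[int, str]


def bind_symbolic_shape(
    param_shape: Sequence[Dim],
    const_shape: Sequence[int],
    bindings: Dict[str, int],
) -> None:
    """Match an annotated param shape against a constant's concrete shape,
    recording each symbolic var's value in `bindings` (mutated in place)."""
    if len(param_shape) != len(const_shape):
        raise ValueError(f"rank mismatch: {param_shape} vs {const_shape}")
    for dim, value in zip(param_shape, const_shape):
        if isinstance(dim, int):
            # Static dim: the constant must agree exactly.
            if dim != value:
                raise ValueError(f"static dim {dim} != {value}")
        elif dim in bindings and bindings[dim] != value:
            # The same var seen in several params must get one value.
            raise ValueError(f"conflicting values for {dim}: {bindings[dim]} vs {value}")
        else:
            bindings[dim] = value


def substitute(shape: Sequence[Dim], bindings: Dict[str, int]) -> Tuple[Dim, ...]:
    """Replace bound vars with their values; unbound vars stay symbolic."""
    return tuple(bindings.get(d, d) if isinstance(d, str) else d for d in shape)


# Hypothetical (annotated shape, constant shape) pairs for w0, b0, w1, b1.
params = {
    "w0": (("n", "m"), (64, 32)),
    "b0": (("n",), (64,)),
    "w1": (("k", "n"), (16, 64)),
    "b1": (("k",), (16,)),
}
bindings: Dict[str, int] = {}
for annotated, concrete in params.values():
    bind_symbolic_shape(annotated, concrete, bindings)

print(bindings)                              # {'n': 64, 'm': 32, 'k': 16}
print(substitute(("batch", "k"), bindings))  # ('batch', 16)
```

Static dims are checked rather than bound, and a var that appears in several params (like `n` in `w0`, `b0`, and `w1`) must resolve to a single consistent value; vars not covered by any constant (like `batch` here) stay symbolic.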
This issue was reported by @l1nkr.
cc @tqchen @YuchenJin