Lunderberg commented on issue #17207:
URL: https://github.com/apache/tvm/issues/17207#issuecomment-2254182004
Looks like the test case is correctly lifting out a pre-processing step (the
computation of `B_offset`) into a separate function, but the test script
never runs that pre-processing step. In most cases, skipping it would
result in a shape mismatch (e.g. if `matmul_weight =
R.permute_dims(linear_weight)` is lifted out). In this case, though, the
shapes of `B` and `B_offset` match, so the un-transformed weight is accepted
without error and only the output values differ.
If I update your test case to call `"main_transform_params"` first, then the
results do match.
```python
#!/usr/bin/env python3

import tvm
from tvm import relax
import numpy as np
from tvm.script import ir as I, relax as R


@I.ir_module
class Module:
    @R.function
    def main(
        A: R.Tensor((16,), dtype="int32"), B: R.Tensor((16,), dtype="int32")
    ) -> R.Tensor((16,), dtype="int32"):
        R.func_attr({"num_input": 1})
        cls = Module
        with R.dataflow():
            offset = R.ones(R.shape([16]), dtype="int32")
            A_offset = R.add(A, offset)
            B_offset = R.add(B, offset)
            output = R.multiply(A_offset, B_offset)
            R.output(output)
        return output


def compile_mod(mod):
    mod = relax.transform.FuseTIR()(mod)
    mod = relax.transform.LambdaLift()(mod)
    ex = relax.build(mod, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
    return vm


def main():
    mod = Module
    mod = tvm.relax.transform.LegalizeOps()(mod)

    input_0 = tvm.nd.array(np.random.randint(10, size=[16]).astype("int32"))
    input_1 = tvm.nd.array(np.random.randint(10, size=[16]).astype("int32"))

    compiled_before = compile_mod(mod)
    # Without LiftTransformParams, all parameters are directly
    # accepted as arguments to main.
    before_outputs = compiled_before["main"](input_0, input_1)

    compiled_after = compile_mod(relax.transform.LiftTransformParams()(mod))
    # With LiftTransformParams, preprocessing steps that may be
    # performed on weight parameters (param index >=
    # func.attrs['num_input']) are lifted out.
    transformed_weights = compiled_after["main_transform_params"]([input_1])
    # The main function is now called using the transformed weights,
    # rather than the original `input_1`.
    after_outputs = compiled_after["main"](input_0, *transformed_weights)

    np.testing.assert_equal(before_outputs.numpy(), after_outputs.numpy())


if __name__ == "__main__":
    main()
```
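
For illustration, the same pitfall can be reproduced without TVM. The NumPy sketch below is only an analogy under stated assumptions: `main_before`, `transform_params`, and `main_after` are hypothetical stand-ins for the original `main`, the lifted `main_transform_params`, and the rewritten `main`. They are not TVM APIs and do not reflect how the pass is implemented.

```python
import numpy as np


def main_before(a, b):
    # Original computation: both offsets are applied inside main.
    offset = np.ones(16, dtype="int32")
    return (a + offset) * (b + offset)


def transform_params(weights):
    # Hypothetical stand-in for the lifted pre-processing function:
    # takes the raw weights and returns the pre-processed ones.
    (b,) = weights
    return [b + np.ones(16, dtype="int32")]


def main_after(a, b_offset):
    # Hypothetical stand-in for the rewritten main, which expects
    # weights that have already been pre-processed.
    return (a + np.ones(16, dtype="int32")) * b_offset


rng = np.random.default_rng(0)
a = rng.integers(1, 10, size=16).astype("int32")
b = rng.integers(1, 10, size=16).astype("int32")

expected = main_before(a, b)
# Raw weight passed by mistake: same shape, so no error, but wrong values.
skipped = main_after(a, b)
# Pre-processing run first: matches the original computation.
correct = main_after(a, *transform_params([b]))
```

Because `b` and `b + 1` have the same shape and dtype, `skipped` is computed without any error, and only a value comparison against `expected` reveals the missing step.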