Lunderberg commented on issue #17207:
URL: https://github.com/apache/tvm/issues/17207#issuecomment-2254182004

   Looks like `LiftTransformParams` is correctly lifting out a pre-processing
   step (the computation of `B_offset`) into a separate function, but the test
   script never runs that pre-processing step. In most cases this would result
   in a shape mismatch (e.g. if `matmul_weight = R.permute_dims(linear_weight)`
   is lifted out). Here, however, `B` and `B_offset` have the same shape, so
   passing the untransformed `B` silently produces a different result instead
   of raising an error.
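A plain-Python model (no TVM required) shows why this failure is silent. The functions `main_before`, `transform_params`, and `main_after` are hypothetical stand-ins for the original `main`, the lifted `main_transform_params`, and the rewritten `main`. `_check_shape` only loosely models Relax's shape checking by raising an exception.

```python
# Hypothetical plain-Python model of the two compiled variants; no TVM involved.

def _check_shape(a, b):
    # Loose stand-in for Relax shape checking: mismatched lengths raise.
    if len(a) != len(b):
        raise ValueError(f"shape mismatch: {len(a)} vs {len(b)}")

def add(a, b):
    _check_shape(a, b)
    return [x + y for x, y in zip(a, b)]

def multiply(a, b):
    _check_shape(a, b)
    return [x * y for x, y in zip(a, b)]

def main_before(A, B):
    # Original main: both operands are offset inside the function.
    offset = [1] * len(A)
    return multiply(add(A, offset), add(B, offset))

def transform_params(B):
    # The lifted pre-processing step: B_offset = B + 1.
    return add(B, [1] * len(B))

def main_after(A, B_offset):
    # After lifting, main expects the already-offset weight.
    return multiply(add(A, [1] * len(A)), B_offset)

A = [1, 2, 3, 4]
B = [5, 6, 7, 8]

expected = main_before(A, B)
wrong = main_after(A, B)                    # pre-processing step skipped
right = main_after(A, transform_params(B))  # pre-processing step run first

print(expected)  # [12, 21, 32, 45]
print(wrong)     # [10, 18, 28, 40]
print(right)     # [12, 21, 32, 45]
```

`B` and `transform_params(B)` have the same length, so `main_after(A, B)` passes the shape check and quietly returns `(A + 1) * B` instead of `(A + 1) * (B + 1)`.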
   
   If I update your test case to call `"main_transform_params"` first, then the 
results do match.
   
   ```python
   #!/usr/bin/env python3
   
   import tvm
   from tvm import relax
   import numpy as np
   
   from tvm.script import ir as I, relax as R
   
   
   @I.ir_module
   class Module:
       @R.function
       def main(
           A: R.Tensor((16,), dtype="int32"), B: R.Tensor((16,), dtype="int32")
       ) -> R.Tensor((16,), dtype="int32"):
           R.func_attr({"num_input": 1})
           cls = Module
           with R.dataflow():
               offset = R.ones(R.shape([16]), dtype="int32")
               A_offset = R.add(A, offset)
               B_offset = R.add(B, offset)
               output = R.multiply(A_offset, B_offset)
               R.output(output)
           return output
   
   
   def compile_mod(mod):
       mod = relax.transform.FuseTIR()(mod)
       mod = relax.transform.LambdaLift()(mod)
       ex = relax.build(mod, target="llvm")
       vm = relax.VirtualMachine(ex, tvm.cpu())
       return vm
   
   
   def main():
       mod = Module
       mod = tvm.relax.transform.LegalizeOps()(mod)
   
       input_0 = tvm.nd.array(np.random.randint(10, size=[16]).astype("int32"))
       input_1 = tvm.nd.array(np.random.randint(10, size=[16]).astype("int32"))
   
       compiled_before = compile_mod(mod)
       # Without LiftTransformParams, all parameters are directly
       # accepted as arguments to main.
       before_outputs = compiled_before["main"](input_0, input_1)
   
       compiled_after = compile_mod(relax.transform.LiftTransformParams()(mod))
       # With LiftTransformParams, preprocessing steps that only depend on
       # weight parameters (param index >= func.attrs["num_input"]) are
       # lifted out into `main_transform_params`.
       transformed_weights = compiled_after["main_transform_params"]([input_1])
       # The main function is now called using the transformed weights,
       # rather than the original `input_1`.
       after_outputs = compiled_after["main"](input_0, *transformed_weights)
   
       np.testing.assert_equal(before_outputs.numpy(), after_outputs.numpy())
   
   if __name__ == "__main__":
       main()
   
   ```
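The script uses a fixed calling convention: split the arguments at `num_input`, run the transform on the weights, then call `main` with the inputs followed by the transformed weights. That convention can be wrapped in a small helper so the pre-processing step is harder to forget. This is a hypothetical sketch, not a TVM API. `run_with_lifted_params` is an illustrative name. The sketch assumes the transform function accepts the weights as a single list and returns a sequence, which is how `main_transform_params` is called in the script. With a real VM, one would presumably pass `compiled_after["main_transform_params"]` and `compiled_after["main"]` as the two callables.

```python
from typing import Any, Callable, Sequence


def run_with_lifted_params(
    transform_fn: Callable[[list], Sequence[Any]],
    main_fn: Callable[..., Any],
    args: Sequence[Any],
    num_input: int,
) -> Any:
    # Hypothetical helper. Positions < num_input are runtime inputs,
    # passed straight through to main.
    inputs = list(args[:num_input])
    # Positions >= num_input are weights; the transform receives them as
    # one list, mirroring `main_transform_params([input_1])` in the script.
    weights = list(args[num_input:])
    transformed = transform_fn(weights)
    # main receives the inputs followed by the transformed weights.
    return main_fn(*inputs, *transformed)


# Usage with plain-Python fakes standing in for the compiled VM functions.
calls = []

def fake_transform(params):
    calls.append("transform")
    return [[w + 1 for w in p] for p in params]  # B_offset = B + 1

def fake_main(a, b_offset):
    calls.append("main")
    return [(x + 1) * y for x, y in zip(a, b_offset)]

out = run_with_lifted_params(fake_transform, fake_main, [[1, 2], [5, 6]], num_input=1)
print(out)  # [12, 21]
```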


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.