MasterJH5574 commented on code in PR #16757:
URL: https://github.com/apache/tvm/pull/16757#discussion_r1536871492


##########
python/tvm/relax/frontend/nn/exporter.py:
##########
@@ -176,35 +183,26 @@ def _unwrap_ret(expr: typing.Any) -> typing.Any:
     def _convert_input(arg):
         if isinstance(arg, tir.Var):
             return rx.Var(arg.name, struct_info=ShapeStructInfo(values=[arg]))
-        if isinstance(arg, (core.Tensor, core.Object)):
+        elif isinstance(arg, (core.Tensor, core.Object)):
             return arg._expr  # pylint: disable=protected-access
-        if isinstance(arg, _spec.Tuple):
+        elif isinstance(arg, _spec.Tuple):
             return rx.Var(
                 arg.name,
                 struct_info=TupleStructInfo(
                     [_convert_input(arg_i).struct_info for arg_i in 
arg.elements]
                 ),
             )
-        raise TypeError(f"Unsupported input type: {type(arg)}")
+        elif isinstance(arg, rx.Expr):
+            return arg
+        else:
+            raise TypeError(f"Unsupported input type: {type(arg)}")
 
     def _params(mode: str) -> typing.List[rx.Var]:
         inputs: typing.List[rx.Var] = []
 
-        def _get_var(shape_var: tir.Var) -> tir.Var:
-            name = shape_var.name
-            if name in str2var_params:
-                return str2var_params[name]

Review Comment:
   Hi @Lunderberg, this PR causes the compilation failure in MLC LLM CI for the 
Phi model https://ci.mlc.ai/blue/organizations/jenkins/mlc-llm/detail/main/373/
   
   I dug a bit into this. To describe the issue, say the Phi model has hidden 
size 2560, the final linear layer of the Phi model has weight shape 
`("vocab_size", 2560)`, and bias shape `("vocab_size",)`.
   
   Prior to this PR, the `"vocab_size"` vars in both the weight and bias shapes 
were remapped to the same tir.Var, so that for any `x`, `op.matmul(x, 
op.transpose(weight)) + bias` could have its output shape successfully 
inferred, expressed in terms of `vocab_size`.
   
   Since this PR, the `"vocab_size"` vars in the weight and bias shapes no 
longer match. Therefore, the result shape cannot be symbolically inferred and 
falls back to `R.Tensor(dtype="float16", ndim=3)`, which causes the assertion 
failure you can see in the CI:
   
   ```
   In matmul, x.sinfo = R.Tensor((1, 1, vocab_size), dtype="float16"), b.sinfo 
= R.Tensor((vocab_size,), dtype="float16")
   In wrap nested, expr = R.add(matmul128, lm_head_linear_bias1)
   after emitting, expr = add192, expr.sinfo = R.Tensor(dtype="float16", ndim=3)
   
     File 
"/home/ruihang/Workspace/tvm/python/tvm/relax/frontend/nn/modules.py", line 
139, in forward
       x = x + self.bias
           ~~^~~~~~~~~~~
     File 
"/home/ruihang/Workspace/tvm/python/tvm/relax/frontend/nn/_tensor_op.py", line 
44, in __add__
       return _op().add(self, other)
              ^^^^^^^^^^^^^^^^^^^^^^
     File "/home/ruihang/Workspace/tvm/python/tvm/relax/frontend/nn/op.py", 
line 105, in add
       return wrap_nested(_op.add(a._expr, b._expr), name)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
     File "/home/ruihang/Workspace/tvm/python/tvm/relax/frontend/nn/core.py", 
line 614, in wrap_nested
       return Tensor(_expr=expr)
              ^^^^^^^^^^^^^^^^^^
     File "/home/ruihang/Workspace/tvm/python/tvm/relax/frontend/nn/core.py", 
line 117, in __init__
       _check_tensor(_expr)
     File "/home/ruihang/Workspace/tvm/python/tvm/relax/frontend/nn/core.py", 
line 112, in _check_tensor
       assert expr.struct_info.shape is not None
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   AssertionError  File 
"/home/ruihang/Workspace/tvm/python/tvm/relax/frontend/nn/modules.py", line 
139, in forward
       x = x + self.bias
           ~~^~~~~~~~~~~
     File 
"/home/ruihang/Workspace/tvm/python/tvm/relax/frontend/nn/_tensor_op.py", line 
44, in __add__
       return _op().add(self, other)
              ^^^^^^^^^^^^^^^^^^^^^^
     File "/home/ruihang/Workspace/tvm/python/tvm/relax/frontend/nn/op.py", 
line 105, in add
       return wrap_nested(_op.add(a._expr, b._expr), name)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
     File "/home/ruihang/Workspace/tvm/python/tvm/relax/frontend/nn/core.py", 
line 614, in wrap_nested
       return Tensor(_expr=expr)
              ^^^^^^^^^^^^^^^^^^
     File "/home/ruihang/Workspace/tvm/python/tvm/relax/frontend/nn/core.py", 
line 117, in __init__
       _check_tensor(_expr)
     File "/home/ruihang/Workspace/tvm/python/tvm/relax/frontend/nn/core.py", 
line 112, in _check_tensor
       assert expr.struct_info.shape is not None
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   AssertionError
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to