Lunderberg commented on issue #17211: URL: https://github.com/apache/tvm/issues/17211#issuecomment-2256741566
Hmm. I think this is something that should be catchable by propagating the known struct info, but it currently isn't caught.

1. In `main_2`, `cls.func()` returns a tuple with known dtype and static shapes, but it is assigned to a variable with unknown dtype and shape. This is legal, because the set of all `R.Tuple(R.Tensor, R.Tensor, R.Tensor)` is a superset of the set of all `R.Tuple(R.Tensor((16,16), "int32"), R.Tensor((16,16), "int32"), R.Tensor((32,32), "int32"))`.
2. In `main`, even if the return type of `cls.main_2()` isn't explicitly specified, it gets inferred as `R.Tuple(R.Tensor, R.Tensor)`.
3. The return type of `main` may be more specific than the struct info of its body. This is intended to keep the return type stable: even if an optimization prevents shape inference from reaching all the way to the end of the function, the function still has accurate annotations. However, this means that the return struct info may be a sub-type of the body's struct info.
4. Whenever the return type is a sub-type of the body's struct info, a runtime assert is inserted. This is the assert that produces the error message.

I think this is a limitation of the `StructInfo` inference, which should catch the IRModule as ill-formed at compile time rather than at runtime. However, that would first require a few extra steps of `StructInfo` inference that aren't currently performed:

1. If an expression has more specific StructInfo than the variable it is bound to, propagate from the expression to the variable.
2. If the body of a function has more specific StructInfo than the current return type, propagate from the body to the return type.
3. If a function has more specific StructInfo than the GlobalVar used to represent it, propagate from the function to the GlobalVar.

For the example, this would let the `"int32"` type returned by `cls.func` be propagated through `main_2` and into `main`.
At that point, it could be recognized as an error to return `"int32"` from a function that is marked as returning `"float32"`.
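To make the sub-typing argument and the proposed propagation concrete, here is a toy model in plain Python. It is not TVM's API: `TensorInfo`, `TupleInfo`, `is_base_of`, `may_overlap`, `check_return`, and `refine` are hypothetical names, and the three-way outcome ("ok", "runtime_assert", "compile_time_error") is a simplification of how a return annotation can be accepted, checked at runtime, or rejected. The module layout is also an assumption, since the issue's code isn't reproduced here: `main_2` is assumed to return the first two elements of `cls.func()`'s result, and `main` is assumed to be declared as returning two `(16, 16)` `"float32"` tensors. Without propagation the toy reports a runtime assert; with steps 1 to 3 applied it reports the mismatch statically, which is the behavior proposed above rather than what TVM does today.

```python
from dataclasses import dataclass
from typing import Optional, Tuple, Union


@dataclass(frozen=True)
class TensorInfo:
    """Toy stand-in for R.Tensor; None means unknown (like a bare R.Tensor)."""
    shape: Optional[Tuple[int, ...]] = None
    dtype: Optional[str] = None


@dataclass(frozen=True)
class TupleInfo:
    """Toy stand-in for R.Tuple."""
    fields: Tuple["Info", ...]


Info = Union[TensorInfo, TupleInfo]
UNKNOWN = TensorInfo()


def _field_ok(base, derived) -> bool:
    # An unknown base field accepts anything; a known one must match exactly.
    return base is None or base == derived


def _compatible(a, b) -> bool:
    # Two fields can describe the same value unless both are known and differ.
    return a is None or b is None or a == b


def is_base_of(base: Info, derived: Info) -> bool:
    """True if every value described by `derived` is also described by `base`."""
    if isinstance(base, TensorInfo) and isinstance(derived, TensorInfo):
        return _field_ok(base.shape, derived.shape) and _field_ok(base.dtype, derived.dtype)
    if isinstance(base, TupleInfo) and isinstance(derived, TupleInfo):
        return len(base.fields) == len(derived.fields) and all(
            is_base_of(b, d) for b, d in zip(base.fields, derived.fields))
    return False


def may_overlap(a: Info, b: Info) -> bool:
    """True if at least one value could satisfy both annotations."""
    if isinstance(a, TensorInfo) and isinstance(b, TensorInfo):
        return _compatible(a.shape, b.shape) and _compatible(a.dtype, b.dtype)
    if isinstance(a, TupleInfo) and isinstance(b, TupleInfo):
        return len(a.fields) == len(b.fields) and all(
            may_overlap(x, y) for x, y in zip(a.fields, b.fields))
    return False


def check_return(declared: Info, body: Info) -> str:
    """Classify a declared return annotation against the body's inferred info."""
    if is_base_of(declared, body):
        return "ok"                  # body provably satisfies the annotation
    if may_overlap(declared, body):
        return "runtime_assert"      # can't prove it statically, so check at runtime
    return "compile_time_error"      # no value can satisfy both


def refine(annotated: Info, inferred: Info) -> Info:
    """Propagation: keep the inferred info when it is more specific."""
    return inferred if is_base_of(annotated, inferred) else annotated


# Hypothetical struct info, loosely modeled on the issue.
FUNC_RET = TupleInfo((TensorInfo((16, 16), "int32"),
                      TensorInfo((16, 16), "int32"),
                      TensorInfo((32, 32), "int32")))
MAIN_DECLARED = TupleInfo((TensorInfo((16, 16), "float32"),
                           TensorInfo((16, 16), "float32")))


def main_body_info(propagate: bool) -> Info:
    # main_2: `x: R.Tuple(R.Tensor, R.Tensor, R.Tensor) = cls.func()`
    x_info = TupleInfo((UNKNOWN, UNKNOWN, UNKNOWN))
    if propagate:
        x_info = refine(x_info, FUNC_RET)      # step 1: expression -> variable
    # Assumption: main_2 returns (x[0], x[1]). Steps 2 and 3 carry this info to
    # main_2's return type and its GlobalVar, so main sees it unchanged.
    return TupleInfo(x_info.fields[:2])


print(check_return(MAIN_DECLARED, main_body_info(propagate=False)))  # runtime_assert
print(check_return(MAIN_DECLARED, main_body_info(propagate=True)))   # compile_time_error
```

The asymmetry in `is_base_of` is what makes point 1 legal: the all-unknown tuple is a base of `FUNC_RET`, but not the other way round. Keeping `refine` separate from `check_return` mirrors the idea that propagation only tightens struct info, while the decision to assert or reject happens afterwards.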
