This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 1230b4009b [Unity] Dtype check in legalization of R.matmul (#15825)
1230b4009b is described below
commit 1230b4009bac859b28853edaf9e7c478c3aa76a8
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Sep 27 14:47:49 2023 -0500
[Unity] Dtype check in legalization of R.matmul (#15825)
* [Unity] Dtype check in legalization of R.matmul
Prior to this commit, if the arguments of `R.matmul` had an unknown
datatype, legalization would fail from within the implementation of
`BlockBuilder.call_te` with the error `TVMError: cannot make const for
type handle`. While accurate, this error message is difficult to
interpret.
This error occurs because an unknown Relax datatype is used as a
concrete datatype in TE/TIR. While Relax supports unknown datatypes,
TIR does not. When the Relax representation of an unknown datatype,
`DataType::Void()`, is passed into TIR, it is interpreted as an opaque
handle, for which no scalar constant can be produced.
This commit validates the operand datatypes in Relax before
generating the legalized implementation of `R.matmul`. Because the
error is caught earlier, the error message can be phrased in terms of
the Relax objects being used.
* Add a unit test to validate where the error is caught.
---
.../relax/transform/legalize_ops/linear_algebra.py | 8 ++++++++
tests/python/relax/test_transform_legalize_ops.py | 24 +++++++++++++++++++++-
2 files changed, 31 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/transform/legalize_ops/linear_algebra.py
b/python/tvm/relax/transform/legalize_ops/linear_algebra.py
index 7cc75bab1c..318c9521f3 100644
--- a/python/tvm/relax/transform/legalize_ops/linear_algebra.py
+++ b/python/tvm/relax/transform/legalize_ops/linear_algebra.py
@@ -89,6 +89,14 @@ def _matmul(bb: BlockBuilder, call: Call) -> Expr:
name="matmul",
)
+ lhs, rhs = call.args
+ lhs_sinfo = call.args[0].struct_info
+ rhs_sinfo = call.args[1].struct_info
+ assert lhs_sinfo.dtype and rhs_sinfo.dtype, (
+ f"To legalize R.matmul into R.call_tir, the dtype of both operands
must be known. "
+ f"However, the LHS {lhs} has struct info {lhs_sinfo}
(dtype='{lhs_sinfo.dtype}') "
+ f"and the RHS {rhs} has struct info {rhs_sinfo}
(dtype='{rhs_sinfo.dtype}')."
+ )
return bb.call_te(te_matmul, call.args[0], call.args[1],
primfunc_name_hint="matmul")
diff --git a/tests/python/relax/test_transform_legalize_ops.py
b/tests/python/relax/test_transform_legalize_ops.py
index 146e2e0cea..af6004bd0a 100644
--- a/tests/python/relax/test_transform_legalize_ops.py
+++ b/tests/python/relax/test_transform_legalize_ops.py
@@ -15,11 +15,13 @@
# specific language governing permissions and limitations
# under the License.
+import pytest
+
import tvm
from tvm import relax
from tvm.relax.transform import LegalizeOps
from tvm.relax.transform.legalize_ops.common import register_legalize
-from tvm.script import relax as R, tir as T
+from tvm.script import relax as R, tir as T, ir as I
import tvm.testing
@@ -260,5 +262,25 @@ def test_legalize_scalar_data_type_preserve():
tvm.ir.assert_structural_equal(After2, Expected2)
+def test_matmul_legalization_requires_known_dtype():
+ @I.ir_module
+ class ArbitraryDtype:
+ @R.function
+ def main(A: R.Tensor([16, 32]), B: R.Tensor([32, 8])) -> R.Tensor([16,
8]):
+ return R.matmul(A, B)
+
+ with pytest.raises(AssertionError) as err:
+ LegalizeOps()(ArbitraryDtype)
+
+ # This error should be caught while attempting to legalize the
+ # R.matmul, where we can present a user-friendly error.
+ # Otherwise, the error isn't caught until the implementation of
+ # `BlockBuilder.call_te`, when attempting to create a numeric
+ # constant of type kHandle, which produces a much less
+ # user-friendly error.
+ err_message = err.value.args[0]
+ assert err_message.startswith("To legalize R.matmul")
+
+
if __name__ == "__main__":
tvm.testing.main()