This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 1230b4009b [Unity] Dtype check in legalization of R.matmul (#15825)
1230b4009b is described below

commit 1230b4009bac859b28853edaf9e7c478c3aa76a8
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Sep 27 14:47:49 2023 -0500

    [Unity] Dtype check in legalization of R.matmul (#15825)
    
    * [Unity] Dtype check in legalization of R.matmul
    
    Prior to this commit, if `R.matmul` had arguments with an unknown
    datatype, legalization would fail with an error raised from within
    the implementation of `BlockBuilder.call_te`, stating
    `TVMError: cannot make const for type handle`.  While accurate, this
    error message is not easy to interpret.
    
    This error occurs because a datatype that is unknown in Relax is
    used as a concrete datatype in TE/TIR.  While Relax supports unknown
    datatypes, TIR does not.  When the Relax representation of an unknown
    datatype, `DataType::Void()`, is passed into TIR, it is interpreted
    as an opaque handle, for which no scalar constant can be produced.
    
    This commit validates the operand datatypes in Relax before
    generating the legalized implementation of `R.matmul`.  By catching
    the error earlier, the error message can refer to the Relax objects
    involved (the operands and their struct info).
    
    * Add a unit test to verify where the error is caught.
---
 .../relax/transform/legalize_ops/linear_algebra.py |  8 ++++++++
 tests/python/relax/test_transform_legalize_ops.py  | 24 +++++++++++++++++++++-
 2 files changed, 31 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/transform/legalize_ops/linear_algebra.py 
b/python/tvm/relax/transform/legalize_ops/linear_algebra.py
index 7cc75bab1c..318c9521f3 100644
--- a/python/tvm/relax/transform/legalize_ops/linear_algebra.py
+++ b/python/tvm/relax/transform/legalize_ops/linear_algebra.py
@@ -89,6 +89,14 @@ def _matmul(bb: BlockBuilder, call: Call) -> Expr:
             name="matmul",
         )
 
+    lhs, rhs = call.args
+    lhs_sinfo = call.args[0].struct_info
+    rhs_sinfo = call.args[1].struct_info
+    assert lhs_sinfo.dtype and rhs_sinfo.dtype, (
+        f"To legalize R.matmul into R.call_tir, the dtype of both operands 
must be known.  "
+        f"However, the LHS {lhs} has struct info {lhs_sinfo} 
(dtype='{lhs_sinfo.dtype}') "
+        f"and the RHS {rhs} has struct info {rhs_sinfo} 
(dtype='{rhs_sinfo.dtype}')."
+    )
     return bb.call_te(te_matmul, call.args[0], call.args[1], 
primfunc_name_hint="matmul")
 
 
diff --git a/tests/python/relax/test_transform_legalize_ops.py 
b/tests/python/relax/test_transform_legalize_ops.py
index 146e2e0cea..af6004bd0a 100644
--- a/tests/python/relax/test_transform_legalize_ops.py
+++ b/tests/python/relax/test_transform_legalize_ops.py
@@ -15,11 +15,13 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import pytest
+
 import tvm
 from tvm import relax
 from tvm.relax.transform import LegalizeOps
 from tvm.relax.transform.legalize_ops.common import register_legalize
-from tvm.script import relax as R, tir as T
+from tvm.script import relax as R, tir as T, ir as I
 import tvm.testing
 
 
@@ -260,5 +262,25 @@ def test_legalize_scalar_data_type_preserve():
     tvm.ir.assert_structural_equal(After2, Expected2)
 
 
+def test_matmul_legalization_requires_known_dtype():
+    @I.ir_module
+    class ArbitraryDtype:
+        @R.function
+        def main(A: R.Tensor([16, 32]), B: R.Tensor([32, 8])) -> R.Tensor([16, 
8]):
+            return R.matmul(A, B)
+
+    with pytest.raises(AssertionError) as err:
+        LegalizeOps()(ArbitraryDtype)
+
+    # This error should be caught while attempting to legalize the
+    # R.matmul, where we can present a user-friendly error.
+    # Otherwise, the error isn't caught until the implementation of
+    # `BlockBuilder.call_te`, when attempting to create a numeric
+    # constant of type kHandle, which produces a much less
+    # user-friendly error.
+    err_message = err.value.args[0]
+    assert err_message.startswith("To legalize R.matmul")
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to