This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6258fae  [Fix] Tensor core type issue for dense (#7187)
6258fae is described below

commit 6258fae6d1e9ab77b8065d4ffb81a5033665e0cc
Author: Leyuan Wang <[email protected]>
AuthorDate: Fri Jan 1 16:07:41 2021 -0800

    [Fix] Tensor core type issue for dense (#7187)
    
    * fix tensor core dtype issue for dense
    
    * fix lint
    
    * remove float32
    
    Co-authored-by: Leyuan Wang <[email protected]>
---
 python/tvm/relay/op/strategy/cuda.py | 23 ++++++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 9d8420c..37946c0 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -678,9 +678,26 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
         if target.kind.name == "cuda":
             if nvcc.have_tensorcore(target=target):
                 if (
-                    (i % 16 == 0 and b % 16 == 0 and o % 16 == 0)
-                    or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0)
-                    or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0)
+                    (
+                        data.dtype in ["float16", "int8", "uint8"]
+                        and (
+                            (i % 16 == 0 and b % 16 == 0 and o % 16 == 0)
+                            or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0)
+                            or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0)
+                        )
+                    )
+                    or (
+                        data.dtype in ["int4", "uint4"]
+                        and i % 32 == 0
+                        and b % 8 == 0
+                        and o % 8 == 0
+                    )
+                    or (
+                        data.dtype in ["int1", "uint1"]
+                        and i % 128 == 0
+                        and b % 8 == 0
+                        and o % 8 == 0
+                    )
                 ):
                     strategy.add_implementation(
                         wrap_compute_dense(topi.cuda.dense_tensorcore),

Reply via email to