This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6258fae [Fix] Tensor core type issue for dense (#7187)
6258fae is described below
commit 6258fae6d1e9ab77b8065d4ffb81a5033665e0cc
Author: Leyuan Wang <[email protected]>
AuthorDate: Fri Jan 1 16:07:41 2021 -0800
[Fix] Tensor core type issue for dense (#7187)
* fix tc type issue for dense
* fix lint
* rm float 32
Co-authored-by: Leyuan Wang <[email protected]>
---
python/tvm/relay/op/strategy/cuda.py | 23 ++++++++++++++++++++---
1 file changed, 20 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 9d8420c..37946c0 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -678,9 +678,26 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
if target.kind.name == "cuda":
if nvcc.have_tensorcore(target=target):
if (
- (i % 16 == 0 and b % 16 == 0 and o % 16 == 0)
- or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0)
- or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0)
+ (
+ data.dtype in ["float16", "int8", "uint8"]
+ and (
+ (i % 16 == 0 and b % 16 == 0 and o % 16 == 0)
+ or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0)
+ or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0)
+ )
+ )
+ or (
+ data.dtype in ["int4", "uint4"]
+ and i % 32 == 0
+ and b % 8 == 0
+ and o % 8 == 0
+ )
+ or (
+ data.dtype in ["int1", "uint1"]
+ and i % 128 == 0
+ and b % 8 == 0
+ and o % 8 == 0
+ )
):
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_tensorcore),