This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a34731b7fc [ROCM] DP4A intrinsic support for TE/TIR (#11009)
a34731b7fc is described below
commit a34731b7fcdd41f381e94b53f2279b97b75f7bbd
Author: Masahiro Masuda <[email protected]>
AuthorDate: Fri Apr 15 02:13:13 2022 +0900
[ROCM] DP4A intrinsic support for TE/TIR (#11009)
* [ROCM] Support dp4a on AMDGPU by sdot4 intrinsic
commit 0225f2bfe3f413cd4764c2dba6c922af2520146b
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 14 08:56:10 2022 +0900
share op strategy between cuda and rocm
commit 762c7e8611c9ec3cca3321428e2362c81fe89b9b
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 14 08:28:34 2022 +0900
fixed rocm batch_matmul strategy for mixed i8i8i32
commit ce53e8d141f7f901303ec6a91674337cbf2b2384
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 14 06:17:30 2022 +0900
add rocm sdot4 TIR intrin
commit f4562b991f9180b61be7339b2890de1584656c10
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 14 06:03:44 2022 +0900
rocm sdot4 works
commit 6cc62805f82dd884a18a1c4c0e9bae5866e00da0
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 14 05:32:07 2022 +0900
more wip
commit 0602f4a3157d4cb5a3f280a3a3c514bb6535aac8
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 14 03:47:37 2022 +0900
Squashed commit of the following:
commit 65b8bcf955f44540d6a52c8416e60f3047c8366c
Author: Masahiro Masuda <[email protected]>
Date: Wed Apr 13 20:36:49 2022 +0900
[WIP] adding DP4A support to rocm
commit 4f8f308ab6bb85ef3bdcc2b8e846c2eea15f2167
Author: Masahiro Masuda <[email protected]>
Date: Wed Apr 13 14:03:25 2022 +0900
Squashed commit of the following:
commit 1711be38a17e3b6171350009f1da05824cd0b340
Author: Masahiro Masuda <[email protected]>
Date: Wed Apr 13 13:11:40 2022 +0900
fixed condition for real
commit 8a48fb5262e80e318cd81d5ff51bf95fd5eb576e
Author: Masahiro Masuda <[email protected]>
Date: Wed Apr 13 09:57:42 2022 +0900
Revert "Skip applying sch_rule when both ann and sch_rule are defined"
This reverts commit 4915c6a5a91ff87038e71f8aff9f31db684b4a95.
commit daea033d2cb06388ef27ddadb80fc5bce72181d2
Author: Masahiro Masuda <[email protected]>
Date: Mon Apr 11 09:31:05 2022 +0900
[Metaschedule] Support rocm and spirv
commit eb0cae2c779808cced074d189e8f487bf46ea89f
Author: Masahiro Masuda <[email protected]>
Date: Wed Apr 13 07:25:04 2022 +0900
dp4a works
commit 4915c6a5a91ff87038e71f8aff9f31db684b4a95
Author: Masahiro Masuda <[email protected]>
Date: Wed Apr 13 06:13:45 2022 +0900
Skip applying sch_rule when both ann and sch_rule are defined
commit 7b3d71c6b21a9c5de9ef2b89d0a7db2800a5f3a2
Author: Masahiro Masuda <[email protected]>
Date: Wed Apr 13 04:40:31 2022 +0900
fixed intrin description
commit 7666cd7a5b0ce182791662673fbe45944c84d0ae
Author: Masahiro Masuda <[email protected]>
Date: Tue Apr 12 19:59:47 2022 +0900
add DP4A intrin
commit 7086bdb75546a2680d12dc8f80c040cea23f729a
Author: Masahiro Masuda <[email protected]>
Date: Tue Apr 12 19:03:44 2022 +0900
works
commit db343974bfae86e51078e40e6170022a782d8e0a
Author: Masahiro Masuda <[email protected]>
Date: Tue Apr 12 12:49:52 2022 +0900
more hack to tensorize loop mapping to make resnet50 e2e work
commit 2409674a7884a60beb50d7aa3345c4b907b8cd13
Author: Masahiro Masuda <[email protected]>
Date: Mon Apr 11 13:40:59 2022 +0900
wip support pad + qnn.conv2d folding
commit 613cb7ec33b6df41f1ebe0f0a0ac8eca7c73cff1
Author: Masahiro Masuda <[email protected]>
Date: Sun Apr 10 12:04:08 2022 +0900
hack to tensorize loop mapping to make conv2d work
commit 9e4f9df6a409396a8a4a20d967c4f51accf5d210
Author: Masahiro Masuda <[email protected]>
Date: Sun Apr 10 11:34:13 2022 +0900
wrap tensorize with try/catch
commit d4b496d858da0ae43063d47cb03a28b803d0269f
Author: Masahiro Masuda <[email protected]>
Date: Sun Apr 10 11:33:39 2022 +0900
revert change in task_scheduler.cc
commit 476129be7b286f5d109402280aea585e89f6dc1d
Author: Masahiro Masuda <[email protected]>
Date: Sat Apr 9 05:54:10 2022 +0900
try / catch in ThreadedApply
commit d8226ff26f25eba17d4000f25131822874bdc2cc
Author: Masahiro Masuda <[email protected]>
Date: Fri Apr 8 17:17:59 2022 +0900
filter out invalid candidate
commit 2632899a2759885d338e25f2a25ba0b2c555f0c3
Author: Masahiro Masuda <[email protected]>
Date: Fri Apr 8 10:09:48 2022 +0900
try graceful exit in parallel_for_dynamic
commit 9d6741c3dd29c4dde861aa1d3b2ca85f560f5ac6
Author: Masahiro Masuda <[email protected]>
Date: Fri Apr 8 09:35:51 2022 +0900
[QNN] Fix broadcast for invalid axis
commit 6ccde0959343ce4246ef99505b4f54de469a1a5c
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 20:51:15 2022 +0900
refactor rewrite_tensorize
commit 2ce206699f10b03b9611c4683018f7e0c70c7eb5
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 20:48:17 2022 +0900
allow missing schedule_rule in post order apply
commit 3a69353a29abfc454e28d4e530d22a3e2043712e
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 19:42:48 2022 +0900
refactor rewrite_tensorize
commit 43e0b2f7f98299679807aaf1ffb13cce2b5f5ce3
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 18:25:14 2022 +0900
rewrite_vnni -> rewrite_tensorize
commit 823797e2627a9bfa812b72019468569ee79eb4c6
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 18:12:12 2022 +0900
VNNI -> WithIntrin
commit 4284a47e5933aa89c1c3362b15ad53b14782fc81
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 17:45:41 2022 +0900
introduce TileForIntrin
commit b87ef32e30e1e71b3f39789f7289976a8cba4ab4
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 17:34:04 2022 +0900
move TilingwithTensorIntrin to auto_tensorize.cc
commit 2fc118b3726586ba13f7de950beaa299b83a0af3
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 17:28:45 2022 +0900
clean up headers
commit d8b2aa325c91b524bec22dc1ec2fc52c9f060fce
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 17:09:32 2022 +0900
clean up using namespace
commit eb05d25e2b71f4a1232a8796d1413011ec7629d3
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 17:03:05 2022 +0900
refactored init
commit 5e6b0a08d447c0470c2c8a993e4bd62673e34fe3
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 16:57:14 2022 +0900
compiled
commit 2b8c430e2fec7ceb285eed7bc7aa73bb9a74a997
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 12:51:55 2022 +0900
wip MultiLevelTiling refactor
commit 7c21a9fea0511c88bd82f49f799b5198252df40a
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 11:58:33 2022 +0900
function doc string not supported by tvmscript
commit 40f9742bc9c3aa11e8c2c0551d1827ad47fc0f39
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 11:56:45 2022 +0900
update vnni intrin name
commit 4814f825a5315efd2a3da8c36d2ce6b5df5447cd
Merge: e0c5eb84b 07bbb38f7
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 11:44:47 2022 +0900
Merge branch 'tir-tensor-intrin' into auto-tensorize-vnni
commit 07bbb38f7fb52db4a2ecde3d5c87cf4d5cd000a1
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 11:24:56 2022 +0900
more lint fix
commit 15e60b42362cc64b1428b219c8eada414d1b8372
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 11:16:08 2022 +0900
black
commit 7a757fe53758e06418ea1367b348b47c8cd2dcf9
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 11:12:54 2022 +0900
pylint
commit 9a3e508b6f4529158e703b4617f2ddaa351a89eb
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 10:58:52 2022 +0900
simplify import
commit d8e43ecf1c0a79a2c195ff31e1e699a447a11335
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 10:52:50 2022 +0900
use vectorlow/high in arm intrin
commit 625cd2774ec455307646b0c26bb3971d89613d1e
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 10:34:57 2022 +0900
fixed offset factor
commit 69e72b6b612588e670937e003435afa647030ceb
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 10:12:02 2022 +0900
Add ARM intrin
commit 1351fdea6b22f231a290a6c28e06732c9cf993cf
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 08:27:27 2022 +0900
use buffer syntax sugar
commit 0ced85fd097ed48aad8714912718d8735791e1fb
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 08:17:43 2022 +0900
rename vnni.py to x86.py
commit 38a5aca87ec438446593a3af17760339211f5ad9
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 07:24:44 2022 +0900
add VNNI unittest
commit 88b763ec48c20cf68db8bc3bae3fa3ae78996ee8
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 07:10:06 2022 +0900
refactored existing test using VNNI intrin
commit 711a0076d9be2b9aa80ada67e1edda5ba1fdf1fd
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 07:04:58 2022 +0900
[TIR] Add VNNI dot product intrinsic for TIR
commit e0c5eb84bf6a0ad2ba0cddc4bdf22a799dc4b8a0
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 11:42:26 2022 +0900
merge fix
commit b171748139e53f0cf75ff4b6fde436f9d8a5fe91
Merge: 71fe3bdf0 82e152a3c
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 11:33:59 2022 +0900
Merge branch 'tir-tensor-intrin' into auto-tensorize-vnni
commit 71fe3bdf02ae10ddbe090a4fd1020f545a05bb41
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 06:57:38 2022 +0900
move tensor intrin under tir
commit 0c51badef45af2a1025ab42fe38d1b3f07ab493e
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 06:12:39 2022 +0900
remove log
commit fed910e03eb94c169d4a160b8f3cad406d04c6aa
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 06:11:22 2022 +0900
more revert
commit 7150aff9fba167d88dbfb40d48727de8a144b9c0
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 06:10:44 2022 +0900
revert stmt_functor change
commit 155107b98b09c5e5cc7f19afbd327b0557a02843
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 06:10:09 2022 +0900
refactored RewriteVNNI a bit
commit ca15255e3a882b89b05bb83079640c929fb63096
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 05:41:13 2022 +0900
add RewriteVNNI
commit dc9f71d5e3122b50fa8ae6a4462f959f13870b05
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 05:38:56 2022 +0900
vectorized init loop
commit fcc31ee20ddfafd47f566bf98ff40a9f684d12eb
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 04:55:36 2022 +0900
tensorize worked
commit 2b534377a45b9ab84bf35c3d7c03ecae7616d17f
Author: Masahiro Masuda <[email protected]>
Date: Wed Apr 6 19:11:05 2022 +0900
TilingwithTensorIntrin works
commit 86baa31e773fc864f77dc113bc9a93b79f3fc652
Author: Masahiro Masuda <[email protected]>
Date: Wed Apr 6 08:58:27 2022 +0900
Ported auto-tensorization code
commit 82e152a3c91144041ade783116a50565ebb48b89
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 11:24:56 2022 +0900
more lint fix
commit 88d9bdd3b21302bc2dd068a990df15c375a1a8ef
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 11:16:08 2022 +0900
black
commit 31fe7eb8075445161d804d170772eac8e90d3425
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 11:12:54 2022 +0900
pylint
commit 7876754effc40ad089349534dacd75df19d38fc4
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 10:58:52 2022 +0900
simplify import
commit 56f2e9a85069426021e2872eb1da95bf134ac7e0
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 10:52:50 2022 +0900
use vectorlow/high in arm intrin
commit 995cc8d6fcec70a3fadcfb1c6fee7b9f0b5a0951
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 10:34:57 2022 +0900
fixed offset factor
commit 86bbd4955b34257d68d957cb4a2536aea3ef9bac
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 10:12:02 2022 +0900
Add ARM intrin
commit 120fd96e80307b4301ee3fc93e6793e0b40485f0
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 08:27:27 2022 +0900
use buffer syntax sugar
commit 0f0682d00c3961afd1f492ae55f180c5b5502767
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 08:17:43 2022 +0900
rename vnni.py to x86.py
commit f88c31ead1fa6db4bfd2c88eeaf5f665e4c6dddb
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 07:24:44 2022 +0900
add VNNI unittest
commit 6cc80094adac398762924b0b31a4c741417ba9dc
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 07:10:06 2022 +0900
refactored existing test using VNNI intrin
commit 11a29c704cdaad96aeeca39c9c753ef006d27a50
Author: Masahiro Masuda <[email protected]>
Date: Thu Apr 7 07:04:58 2022 +0900
[TIR] Add VNNI dot product intrinsic for TIR
* cleanup
* black
* update dot prod intrin
* add mattr kind
* conv2d topi test working
* add dense and bmm test
* add conv2d relay test
* add tir intrin test
* pylint
---
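For readers unfamiliar with DP4A: the computation that dp4a_desc in
dot_product_common.py describes (four int8 products accumulated into an
int32) can be sketched in plain Python. This is only an illustrative
reference, not code from this commit. The name dp4a_reference is
hypothetical, and the sketch does not model int32 overflow.

```python
def dp4a_reference(a, b, acc=0):
    """Reference semantics of a 4-lane int8 dot product with accumulate.

    Hypothetical helper for illustration: returns acc + sum(a[i] * b[i])
    for i in 0..3, mirroring the update loop in dp4a_desc. Overflow of the
    int32 accumulator is not modeled.
    """
    if len(a) != 4 or len(b) != 4:
        raise ValueError("dp4a operates on exactly four int8 lanes per operand")
    for x in (*a, *b):
        if not -128 <= x <= 127:
            raise ValueError(f"{x} is out of int8 range")
    # Widen each lane to a Python int (standing in for int32) before multiplying.
    for x, y in zip(a, b):
        acc += x * y
    return acc


if __name__ == "__main__":
    print(dp4a_reference([1, 2, 3, 4], [5, 6, 7, 8]))
```

Both the "__dp4a" extern call and the llvm.amdgcn.sdot4 intrinsic registered
in this commit are tensorized against the same description, so a scalar
reference like this can serve as a mental model for either.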
python/tvm/relay/op/strategy/cuda.py | 8 +-
python/tvm/relay/op/strategy/rocm.py | 172 +++------------------
python/tvm/relay/qnn/op/legalizations.py | 22 +--
python/tvm/tir/tensor_intrin/__init__.py | 2 +
python/tvm/tir/tensor_intrin/dot_product_common.py | 55 +++++++
python/tvm/tir/tensor_intrin/rocm.py | 47 ++++++
python/tvm/topi/cuda/batch_matmul.py | 7 +-
python/tvm/topi/cuda/conv2d_alter_op.py | 12 +-
python/tvm/topi/cuda/conv2d_int8.py | 4 +-
python/tvm/topi/cuda/dense.py | 5 +-
python/tvm/topi/cuda/tensor_intrin.py | 23 ++-
python/tvm/topi/rocm/dense.py | 79 +---------
python/tvm/topi/utils.py | 7 +
src/target/target_kind.cc | 1 +
tests/python/relay/test_op_level1.py | 38 +++++
tests/python/relay/test_op_level10.py | 35 +++++
tests/python/relay/test_op_level2.py | 50 ++++++
tests/python/topi/python/test_topi_conv2d_int8.py | 13 +-
tests/python/topi/python/test_topi_dense.py | 1 -
.../python/unittest/test_tir_schedule_tensorize.py | 50 ++++++
20 files changed, 358 insertions(+), 273 deletions(-)
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 08da62e640..4253d93f65 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -145,7 +145,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
if layout == "NCHW":
assert kernel_layout == "OIHW"
if (
- (target.kind.name in ["cuda", "vulkan"])
+ (target.kind.name in ["cuda", "vulkan", "rocm"])
and data.dtype in ("int8", "uint8")
and kernel.dtype in ("int8", "uint8")
):
@@ -297,7 +297,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
Need to satisfy tensor core schedule."
)
elif (
- (target.kind.name in ["cuda", "vulkan"])
+ (target.kind.name in ["cuda", "vulkan", "rocm"])
and layout == "NCHW4c"
and data.dtype in ["int8", "uint8"]
):
@@ -376,7 +376,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
ic_chunk = in_channels // 4
if (
- (target.kind.name in ["cuda", "vulkan"])
+ (target.kind.name in ["cuda", "vulkan", "rocm"])
and data.dtype in ["int8", "uint8"]
and kernel.dtype in ["int8", "uint8"]
and channels % groups == 0
@@ -836,7 +836,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
b, i = get_const_tuple(data.shape)
o, _ = get_const_tuple(weights.shape)
if (
- target.kind.name in ["cuda", "vulkan"]
+ target.kind.name in ["cuda", "vulkan", "rocm"]
and data.dtype == "int8"
and weights.dtype == "int8"
and out_type.dtype == "int32"
diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py
index 1453128eeb..6e91101826 100644
--- a/python/tvm/relay/op/strategy/rocm.py
+++ b/python/tvm/relay/op/strategy/rocm.py
@@ -17,162 +17,39 @@
"""Definition of ROCm operator strategy."""
# pylint: disable=invalid-name,unused-argument,unused-wildcard-import,wildcard-import
from tvm import topi
-from tvm.auto_scheduler import is_auto_scheduler_enabled
from tvm.te import SpecializedCondition
from tvm.contrib.thrust import can_use_rocthrust
from tvm.contrib import miopen
from .generic import *
from .. import op as _op
-from .cuda import judge_winograd, naive_schedule
+from .cuda import batch_matmul_strategy_cuda, conv2d_strategy_cuda, dense_strategy_cuda
@conv2d_strategy.register("rocm")
def conv2d_strategy_rocm(attrs, inputs, out_type, target):
"""conv2d rocm strategy"""
- strategy = _op.OpStrategy()
- data, kernel = inputs
- dilation_h, dilation_w = attrs.get_int_tuple("dilation")
groups = attrs.groups
layout = attrs.data_layout
- stride_h, stride_w = attrs.get_int_tuple("strides")
- kernel_layout = attrs.kernel_layout
padding = attrs.get_int_tuple("padding")
- if dilation_h < 1 or dilation_w < 1:
- raise ValueError("dilation should be positive value")
-
- if groups == 1:
- if layout == "NCHW":
- # TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8.
- assert kernel_layout == "OIHW"
- strategy.add_implementation(
- wrap_compute_conv2d(topi.cuda.conv2d_nchw),
- wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
- name="conv2d_nchw.cuda",
- )
- _, _, kh, kw = get_const_tuple(kernel.shape)
- if (
- 2 < kh < 8
- and 2 < kw < 8
- and kh == kw
- and stride_h == 1
- and stride_w == 1
- and dilation_h == 1
- and dilation_w == 1
- ):
- strategy.add_implementation(
- wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd),
- wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd),
- name="conv2d_nchw_winograd.cuda",
- plevel=5,
- )
- elif layout == "NHWC":
- assert kernel_layout == "HWIO"
- strategy.add_implementation(
- wrap_compute_conv2d(topi.gpu.conv2d_nhwc),
- wrap_topi_schedule(topi.gpu.schedule_conv2d_nhwc),
- name="conv2d_nhwc.gpu",
- )
- N, H, W, _ = get_const_tuple(data.shape)
- KH, KW, CI, CO = get_const_tuple(kernel.shape)
- (_, judge_winograd_autotvm, judge_winograd_auto_scheduler,) = judge_winograd(
- N,
- H,
- W,
- KH,
- KW,
- CI,
- CO,
- padding,
- stride_h,
- stride_w,
- dilation_h,
- dilation_w,
- data.dtype,
- kernel.dtype,
- pre_flag=False,
- )
+ strategy = conv2d_strategy_cuda(attrs, inputs, out_type, target)
- if judge_winograd_autotvm:
- strategy.add_implementation(
- wrap_compute_conv2d(topi.cuda.conv2d_nhwc_winograd_direct),
- wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_winograd_direct),
- name="conv2d_nhwc_winograd_direct.cuda",
- plevel=5,
- )
+ # add miopen implementation
+ if (
+ "miopen" in target.libs
+ and groups == 1
+ and layout == "NCHW"
+ and padding[0] == padding[2]
+ and padding[1] == padding[3]
+ ):
+ strategy.add_implementation(
+ wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
+ wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
+ name="conv2d_nchw_miopen.rocm",
+ plevel=50,
+ )
- if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler:
- strategy.add_implementation(
- wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc),
- naive_schedule, # this implementation should never be picked by autotvm
- name="conv2d_nhwc.winograd",
- plevel=15,
- )
- elif layout == "HWCN":
- assert kernel_layout == "HWIO"
- strategy.add_implementation(
- wrap_compute_conv2d(topi.cuda.conv2d_hwcn),
- wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
- name="conv2d_hwcn.cuda",
- )
- elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
- assert kernel_layout == "OIHW4o4i"
- strategy.add_implementation(
- wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
- wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
- name="conv2d_NCHWc_int8.cuda",
- )
- else:
- raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
- # add miopen implementation
- if (
- "miopen" in target.libs
- and layout == "NCHW"
- and padding[0] == padding[2]
- and padding[1] == padding[3]
- ):
- strategy.add_implementation(
- wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
- wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
- name="conv2d_nchw_miopen.rocm",
- plevel=15,
- )
- elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
- if layout == "NCHW":
- assert kernel_layout == "OIHW"
- strategy.add_implementation(
- wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw),
- wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw),
- name="depthwise_conv2d_nchw.cuda",
- )
- elif layout == "NHWC":
- assert kernel_layout == "HWOI"
- strategy.add_implementation(
- wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
- wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nhwc),
- name="depthwise_conv2d_nhwc.cuda",
- )
- else:
- raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
- else: # group_conv2d
- if layout == "NCHW":
- # TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8.
- assert kernel_layout == "OIHW"
- strategy.add_implementation(
- wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True),
- wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw),
- name="group_conv2d_nchw.cuda",
- )
- elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
- assert kernel_layout == "OIHW4o4i"
- strategy.add_implementation(
- wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True),
- wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8),
- name="group_conv2d_NCHWc_int8.cuda",
- )
- else:
- raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
return strategy
@@ -180,12 +57,8 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
def dense_strategy_rocm(attrs, inputs, out_type, target):
"""Dense strategy for ROCM"""
assert len(inputs[0].shape) == 2 and len(inputs[1].shape) == 2, "Only support 2-dim dense"
- strategy = _op.OpStrategy()
- strategy.add_implementation(
- wrap_compute_dense(topi.rocm.dense),
- wrap_topi_schedule(topi.rocm.schedule_dense),
- name="dense.rocm",
- )
+ strategy = dense_strategy_cuda(attrs, inputs, out_type, target)
+
if target.kind.name == "rocm" and "rocblas" in target.libs:
assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
strategy.add_implementation(
@@ -200,13 +73,8 @@ def dense_strategy_rocm(attrs, inputs, out_type, target):
@batch_matmul_strategy.register("rocm")
def batch_matmul_strategy_rocm(attrs, inputs, out_type, target):
"""Batch matmul strategy for ROCM"""
- strategy = _op.OpStrategy()
- strategy.add_implementation(
- wrap_compute_batch_matmul(topi.cuda.batch_matmul),
- wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
- name="batch_matmul.cuda",
- plevel=10,
- )
+ strategy = batch_matmul_strategy_cuda(attrs, inputs, out_type, target)
+
if target.kind.name == "rocm" and "rocblas" in target.libs:
assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
strategy.add_implementation(
diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py
index 93b1ad7a44..e669e14032 100644
--- a/python/tvm/relay/qnn/op/legalizations.py
+++ b/python/tvm/relay/qnn/op/legalizations.py
@@ -24,6 +24,7 @@ from tvm._ffi.base import TVMError
from tvm.relay.qnn.op.canonicalizations import create_integer_lookup_op
from ....topi.x86.utils import target_has_sse42
+from ....topi.utils import is_target
from .. import op as reg
#################################################
@@ -387,18 +388,6 @@ def is_aarch64_arm():
return "aarch64" in target.attrs.get("mtriple", "")
-def is_vulkan():
- """Checks whether we are compiling for a vulkan/spirv target."""
- target = tvm.target.Target.current(allow_none=False)
- return "vulkan" in target.keys
-
-
-def is_cuda():
- """Checks whether we are compiling for a cuda target."""
- target = tvm.target.Target.current(allow_none=False)
- return "cuda" in target.keys
-
-
########################
# ARM CPU legalizations.
########################
@@ -456,10 +445,10 @@ def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
@qnn_conv2d_legalize.register(["cuda", "gpu"])
def _qnn_conv2d_legalize_cuda(attrs, inputs, types):
- if is_vulkan():
+ if is_target("vulkan"):
# prefers the dtypes to be same. Mixed type is not yet supported.
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
- if is_cuda():
+ if is_target(["cuda", "rocm"]):
# CUDA prefers both datatypes to be int8.
return helper_change_dtypes_to_int8(attrs, inputs, types, relay.qnn.op.conv2d)
return None
@@ -467,11 +456,10 @@ def _qnn_conv2d_legalize_cuda(attrs, inputs, types):
@qnn_dense_legalize.register(["cuda", "gpu"])
def _qnn_dense_legalize_cuda(attrs, inputs, types):
- if is_vulkan():
+ if is_target("vulkan"):
# prefers the dtypes to be same. Mixed type is not yet supported.
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
- if is_cuda():
+ if is_target(["cuda", "rocm"]):
# CUDA prefers both datatypes to be the int8.
return helper_change_dtypes_to_int8(attrs, inputs, types, relay.qnn.op.dense)
-
return None
diff --git a/python/tvm/tir/tensor_intrin/__init__.py b/python/tvm/tir/tensor_intrin/__init__.py
index 62159851b3..4115c3b900 100644
--- a/python/tvm/tir/tensor_intrin/__init__.py
+++ b/python/tvm/tir/tensor_intrin/__init__.py
@@ -18,3 +18,5 @@
"""Intrinsics for tensorization."""
from .x86 import *
from .arm_cpu import *
+from .dot_product_common import *
+from .rocm import *
diff --git a/python/tvm/tir/tensor_intrin/dot_product_common.py b/python/tvm/tir/tensor_intrin/dot_product_common.py
new file mode 100644
index 0000000000..c531b80380
--- /dev/null
+++ b/python/tvm/tir/tensor_intrin/dot_product_common.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,missing-function-docstring
+"""Dot product related intrinsics."""
+from tvm.script import tir as T
+from .. import TensorIntrin
+
+
+@T.prim_func
+def dp4a_desc(
+ A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
+ B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
+ C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
+) -> None:
+ with T.block("root"):
+ T.reads(C[0], A[0:4], B[0:4])
+ T.writes(C[0])
+ for i in range(0, 4):
+ with T.block("update"):
+ vi = T.axis.remap("R", [i])
+ C[0] = C[0] + T.cast(A[vi], "int32") * T.cast(B[vi], "int32")
+
+
+@T.prim_func
+def dp4a_impl(
+ A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
+ B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
+ C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
+) -> None:
+ with T.block("root"):
+ T.reads(C[0], A[0:4], B[0:4])
+ T.writes(C[0])
+
+ C[0] += T.call_pure_extern(
+ "__dp4a", A.vload([0], "int8x4"), B.vload([0], "int8x4"), T.int32(0), dtype="int32"
+ )
+
+
+DP4A_INTRIN = "dp4a"
+
+TensorIntrin.register(DP4A_INTRIN, dp4a_desc, dp4a_impl)
diff --git a/python/tvm/tir/tensor_intrin/rocm.py b/python/tvm/tir/tensor_intrin/rocm.py
new file mode 100644
index 0000000000..7a989d0bcc
--- /dev/null
+++ b/python/tvm/tir/tensor_intrin/rocm.py
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,missing-function-docstring
+"""Intrinsics for AMDGPU tensorization."""
+from tvm.script import tir as T
+from .. import TensorIntrin
+from .dot_product_common import dp4a_desc
+
+
+@T.prim_func
+def sdot4(
+ A: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
+ B: T.Buffer((4,), "int8", offset_factor=1, align=4, scope="shared"),
+ C: T.Buffer((1,), "int32", offset_factor=1, align=4, scope="local"),
+) -> None:
+ with T.block("root"):
+ T.reads(C[0], A[0:4], B[0:4])
+ T.writes(C[0])
+
+ C[0] += T.call_llvm_pure_intrin(
+ T.llvm_lookup_intrinsic_id("llvm.amdgcn.sdot4"),
+ T.uint32(4),
+ T.reinterpret(A.vload([0], "int8x4"), dtype="int32"),
+ T.reinterpret(B.vload([0], "int8x4"), dtype="int32"),
+ T.int32(0),
+ T.bool(1),
+ dtype="int32",
+ )
+
+
+AMDGPU_SDOT4_INTRIN = "sdot4"
+
+TensorIntrin.register(AMDGPU_SDOT4_INTRIN, dp4a_desc, sdot4)
diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py
index 5fce9d7a3f..ff625d6d71 100644
--- a/python/tvm/topi/cuda/batch_matmul.py
+++ b/python/tvm/topi/cuda/batch_matmul.py
@@ -22,7 +22,7 @@ from tvm import te
from tvm.contrib import cublas
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from .. import nn, generic
-from ..utils import traverse_inline, get_const_tuple, get_max_power2_factor
+from ..utils import traverse_inline, get_const_tuple, get_max_power2_factor, is_target
from .tensor_intrin import dp4a
@@ -333,9 +333,6 @@ def schedule_batch_matmul_int8(cfg, outs):
return s
-_dp4a = dp4a("shared", "shared", "local")
-
-
def _schedule_batch_matmul_int8(cfg, s, output):
input_x, input_y = s[output].op.input_tensors
if len(input_y.op.input_tensors) == 1 and input_y.op.input_tensors[0] == input_x:
@@ -372,7 +369,7 @@ def _schedule_batch_matmul_int8(cfg, s, output):
target = tvm.target.Target.current(allow_none=False)
do_tensorize = True
- if "vulkan" in target.keys:
+ if is_target(["vulkan", "rocm"]):
do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product
if do_tensorize:
diff --git a/python/tvm/topi/cuda/conv2d_alter_op.py b/python/tvm/topi/cuda/conv2d_alter_op.py
index eaafe15e96..35d50eb367 100644
--- a/python/tvm/topi/cuda/conv2d_alter_op.py
+++ b/python/tvm/topi/cuda/conv2d_alter_op.py
@@ -22,7 +22,7 @@ import tvm
from tvm import te, relay, autotvm
from .. import nn
-from ..utils import get_const_tuple
+from ..utils import get_const_tuple, is_target
from .conv2d_winograd import _infer_tile_size
from .tensorcore_alter_op import pad_to_tensorcore
from ..nn import conv2d_legalize
@@ -34,8 +34,7 @@ logger = logging.getLogger("topi")
@nn.conv2d_alter_layout.register(["cuda", "gpu"])
def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
target = tvm.target.Target.current(allow_none=False)
- doit = "vulkan" in target.keys or "cuda" in target.keys
- if not doit:
+ if not is_target(["vulkan", "rocm", "cuda"]):
return None
dispatch_ctx = autotvm.task.DispatchContext.current
@@ -87,7 +86,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
if cfg.is_fallback: # if is fallback, clear query cache and return None
autotvm.task.clear_fallback_cache(target, workload)
do_new_layout = False
- if "vulkan" in target.keys:
+ if is_target(["vulkan", "rocm"]):
do_new_layout = "+dotprod" in target.mattr or target.supports_integer_dot_product
if not do_new_layout:
return None
@@ -349,10 +348,7 @@ def _conv2d_legalize(attrs, inputs, arg_types):
result : tvm.relay.Expr
The legalized expr
"""
-
- target = tvm.target.Target.current(allow_none=False)
- doit = "vulkan" in target.keys or "cuda" in target.keys
- if not doit:
+ if not is_target(["vulkan", "rocm", "cuda"]):
return None
# Dilation not supported yet. Return None if dilation is not (1, 1)
dilation = attrs.get_int_tuple("dilation")
diff --git a/python/tvm/topi/cuda/conv2d_int8.py b/python/tvm/topi/cuda/conv2d_int8.py
index 15120f6a25..a8b21a1dec 100644
--- a/python/tvm/topi/cuda/conv2d_int8.py
+++ b/python/tvm/topi/cuda/conv2d_int8.py
@@ -26,7 +26,7 @@ from .tensor_intrin import dp4a
from ..nn.pad import pad
from ..nn.conv2d import unpack_NCHWc_to_nchw
from ..nn.utils import get_pad_tuple
-from ..utils import get_const_tuple, traverse_inline
+from ..utils import get_const_tuple, traverse_inline, is_target
def conv2d_nchw_int8(data, kernel, strides, padding, dilation, out_dtype="int32"):
@@ -312,7 +312,7 @@ def _schedule_conv2d_NCHWc_int8(cfg, s, output):
_, rc_block = s[conv].split(rc_block, factor=4)
target = tvm.target.Target.current(allow_none=False)
do_tensorize = True
- if "vulkan" in target.keys:
+ if is_target(["vulkan", "rocm"]):
do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product
if do_tensorize:
dtypes = (pad_data.dtype, packed_kernel.dtype)
diff --git a/python/tvm/topi/cuda/dense.py b/python/tvm/topi/cuda/dense.py
index 862e7b5bc5..859f6c1097 100644
--- a/python/tvm/topi/cuda/dense.py
+++ b/python/tvm/topi/cuda/dense.py
@@ -24,7 +24,7 @@ from tvm.contrib import cublas
from .tensor_intrin import dp4a
from .. import tag
from .. import generic
-from ..utils import traverse_inline, get_const_tuple
+from ..utils import traverse_inline, get_const_tuple, is_target
logger = logging.getLogger("topi")
@@ -173,8 +173,9 @@ def _schedule_dense_int8(cfg, s, output):
ko, kt = cfg["tile_k"].apply(s, CC, ko)
target = tvm.target.Target.current(allow_none=False)
do_tensorize = True
- if "vulkan" in target.keys:
+ if is_target(["vulkan", "rocm"]):
do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product
+
if do_tensorize:
dtypes = (data.dtype, weight.dtype)
s[CC].tensorize(ki, dp4a("shared", "shared", "local", dtypes))
diff --git a/python/tvm/topi/cuda/tensor_intrin.py b/python/tvm/topi/cuda/tensor_intrin.py
index c0596fc432..0a504906c0 100644
--- a/python/tvm/topi/cuda/tensor_intrin.py
+++ b/python/tvm/topi/cuda/tensor_intrin.py
@@ -18,6 +18,7 @@
"""Tensor intrinsics on CUDA."""
import tvm
from tvm import te
+from ..utils import is_target
def dp4a(x_scope="local", y_scope="local", z_scope="local", dtypes=("int8", "int8")):
@@ -71,7 +72,27 @@ def dp4a(x_scope="local", y_scope="local", z_scope="local", dtypes=("int8", "int
vec_y = yy.vload(0, dtype=vec_y_dtype)
prev_z = 0 if index == 0 else zz.vload(0)
- new_z = tvm.tir.call_pure_extern(zz_dtype, "__dp4a", vec_x, vec_y, prev_z)
+ if is_target("rocm"):
+ # TODO(masahi): Here we are assuming that we are compiling for gfx10 or later
+ # We can refine the specification for dot product on rocm if needed later.
+
+ # We can just use "llvm.amdgcn.udot4" for u8u8u32, but it is not tested.
+ assert (
+ dtypes[0] == "int8" and dtypes[0] == "int8"
+ ), "u8u8u32 dot product for rocm not supported yet"
+
+ new_z = tvm.tir.call_llvm_pure_intrin(
+ zz_dtype,
+ "llvm.amdgcn.sdot4",
+ tvm.tir.const(4, "uint32"),
+ tvm.tir.call_intrin("int32", "tir.reinterpret", vec_x),
+ tvm.tir.call_intrin("int32", "tir.reinterpret", vec_y),
+ prev_z,
+ True,
+ )
+ else:
+ new_z = tvm.tir.call_pure_extern(zz_dtype, "__dp4a", vec_x, vec_y, prev_z)
+
ib.emit(zz.vstore(0, new_z))
return ib.get()
diff --git a/python/tvm/topi/rocm/dense.py b/python/tvm/topi/rocm/dense.py
index 2f3ce77cc7..983f235f0e 100644
--- a/python/tvm/topi/rocm/dense.py
+++ b/python/tvm/topi/rocm/dense.py
@@ -19,85 +19,8 @@
from tvm import te
from tvm import autotvm
from tvm.contrib import rocblas
-from .. import generic, nn
+from .. import generic
from .. import tag
-from ..utils import traverse_inline
-
-
-@autotvm.register_topi_compute("dense.rocm")
-def dense(cfg, data, weight, bias=None, out_dtype=None):
- """Dense operator for rocm backend.
-
- Parameters
- ----------
- data : tvm.te.Tensor
- 2-D with shape [batch, in_dim]
-
- weight : tvm.te.Tensor
- 2-D with shape [out_dim, in_dim]
-
- bias : tvm.te.Tensor, optional
- 1-D with shape [out_dim]
-
- out_dtype : str
- The output type. This is used for mixed precision.
-
- Returns
- -------
- output : tvm.te.Tensor
- 2-D with shape [batch, out_dim]
- """
- assert len(data.shape) == 2 and len(weight.shape) == 2, "only support 2-dim dense"
- if bias is not None:
- assert len(bias.shape) == 1
- if out_dtype is None:
- out_dtype = data.dtype
- return nn.dense(data, weight, bias, out_dtype)
-
-
-@autotvm.register_topi_schedule("dense.rocm")
-def schedule_dense(cfg, outs):
- """Schedule for dense operator.
-
- Parameters
- ----------
- outs: Array of Tensor
- The computation graph description of dense
- in the format of an array of tensors.
-
- Returns
- -------
- s: Schedule
- The computation schedule for dense.
- """
- outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
- s = te.create_schedule([x.op for x in outs])
-
- def _callback(op):
- if op.tag == "dense":
- Dense = op.output(0)
- num_thread = 64
- k = Dense.op.reduce_axis[0]
- ko, kf = s[Dense].split(k, factor=num_thread)
- DenseF = s.rfactor(Dense, kf)
-
- if Dense.op in s.outputs:
- Out = Dense
- else:
- Out = outs[0].op.output(0)
- s[Dense].compute_at(s[Out], s[Out].op.axis[1])
- s[Out].bind(s[Out].op.axis[0], te.thread_axis("blockIdx.y"))
- s[Out].bind(s[Out].op.axis[1], te.thread_axis("blockIdx.x"))
-
- tx = s[Dense].op.reduce_axis[0]
- thread_x = te.thread_axis("threadIdx.x")
- s[Dense].bind(tx, thread_x)
- s[DenseF].compute_at(s[Dense], tx)
- s[Dense].set_store_predicate(thread_x.var.equal(0))
- s[Out].set_store_predicate(thread_x.var.equal(0))
-
- traverse_inline(s, outs[0].op, _callback)
- return s
@autotvm.register_topi_compute("dense_rocblas.rocm")
diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py
index af68ee905e..f1c6fb5aa4 100644
--- a/python/tvm/topi/utils.py
+++ b/python/tvm/topi/utils.py
@@ -524,3 +524,10 @@ def ceil_div(a, b):
def swap(arr, axis):
"""swap arr[axis] and arr[-1]"""
return arr[:axis] + [arr[-1]] + arr[axis + 1 : -1] + [arr[axis]]
+
+
+def is_target(names):
+ """Return True if the name of the current target is one of provided names"""
+ names = [names] if isinstance(names, str) else names
+ target = tvm.target.Target.current(allow_none=False)
+ return any(name in target.keys for name in names)
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 6fef8b48c3..96c193d34a 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -306,6 +306,7 @@ TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA)
TVM_REGISTER_TARGET_KIND("rocm", kDLROCM)
.add_attr_option<String>("mcpu")
.add_attr_option<String>("mtriple")
+ .add_attr_option<Array<String>>("mattr")
.add_attr_option<Bool>("system-lib")
.add_attr_option<Integer>("max_num_threads", Integer(256))
.add_attr_option<Integer>("thread_warp_size", Integer(64))
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index c7aceb685b..d4238f81e0 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -676,5 +676,43 @@ def test_dense_vnni():
np.testing.assert_equal(out, ref)
+@pytest.mark.skip("Requires GFX10 AMDGPU")
+def test_dense_rocm_sdot4():
+ data_shape = (32, 96)
+ weight_shape = (128, 96)
+
+ data_dtype = "int8"
+ data = relay.var("data", shape=data_shape, dtype=data_dtype)
+ weight = relay.var("weight", shape=weight_shape, dtype="int8")
+ bias = relay.var("bias", shape=(weight_shape[0],), dtype="int32")
+ dense = relay.nn.dense(data, weight, out_dtype="int32")
+ out = relay.nn.bias_add(dense, bias)
+ mod = tvm.IRModule.from_expr(out)
+
+ target = "rocm -mattr=+dotprod"
+ with tvm.transform.PassContext(opt_level=3):
+ lib = relay.build(mod, target=target)
+
+ asm = lib.lib.imported_modules[0].get_source("asm")
+ assert "v_dot4_i32_i8" in asm
+
+ dev = tvm.device(target, 0)
+ runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+
+ a = np.random.uniform(1, 10, size=data_shape).astype(data_dtype)
+ b = np.random.uniform(1, 10, size=weight_shape).astype("int8")
+ c = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32")
+
+ runtime.set_input("data", a)
+ runtime.set_input("weight", b)
+ runtime.set_input("bias", c)
+ runtime.run()
+
+ out = runtime.get_output(0).numpy()
+ ref = np.dot(a.astype("int32"), b.transpose().astype("int32")) + c
+
+ np.testing.assert_equal(out, ref)
+
+
if __name__ == "__main__":
pytest.main([__file__])
diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py
index 85a3dd5636..8ee5adbb31 100644
--- a/tests/python/relay/test_op_level10.py
+++ b/tests/python/relay/test_op_level10.py
@@ -447,6 +447,41 @@ def test_batch_matmul_vnni():
np.testing.assert_equal(out, ref)
+@pytest.mark.skip("Requires GFX10 AMDGPU")
+def test_batch_matmul_rocm_sdot4():
+ x_shape = (16, 32, 96)
+ y_shape = (16, 128, 96)
+
+ lhs_dtype = "int8"
+ x = relay.var("x", shape=x_shape, dtype=lhs_dtype)
+ y = relay.var("y", shape=y_shape, dtype="int8")
+ bmm = relay.nn.batch_matmul(x, y, out_dtype="int32")
+
+ mod = tvm.IRModule.from_expr(bmm)
+
+ target = "rocm -mattr=+dotprod"
+ with tvm.transform.PassContext(opt_level=3):
+ lib = relay.build(mod, target=target)
+
+ asm = lib.lib.imported_modules[0].get_source("asm")
+ assert "v_dot4_i32_i8" in asm
+
+ dev = tvm.device(target, 0)
+ runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+
+ x_np = np.random.uniform(1, 10, size=x_shape).astype(lhs_dtype)
+ y_np = np.random.uniform(1, 10, size=y_shape).astype("int8")
+
+ runtime.set_input("x", x_np)
+ runtime.set_input("y", y_np)
+ runtime.run()
+
+ out = runtime.get_output(0).numpy()
+ ref = tvm.topi.testing.batch_matmul(x_np, y_np, out_dtype="int32")
+
+ np.testing.assert_equal(out, ref)
+
+
@tvm.testing.uses_gpu
def test_shape_of():
shape = (10, 5, 12)
diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index bd9536742a..7b261b0eb7 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -1944,5 +1944,55 @@ def test_correlation():
)
+@pytest.mark.skip("Requires GFX10 AMDGPU")
+def test_conv2d_rocm_sdot4():
+ d_shape = (1, 64, 56, 56)
+ w_shape = (64, 64, 3, 3)
+ padding = (1, 1)
+ strides = (1, 1)
+ data_dtype = "int8"
+ weight_dtype = "int8"
+ out_dtype = "int32"
+
+ data = relay.var("data", shape=d_shape, dtype=data_dtype)
+ weight = relay.var("weight", shape=w_shape, dtype=weight_dtype)
+ out_channel = w_shape[0]
+ conv2d = relay.nn.conv2d(
+ data=data,
+ weight=weight,
+ kernel_size=w_shape[2:],
+ channels=out_channel,
+ padding=padding,
+ strides=strides,
+ out_dtype=out_dtype,
+ )
+
+ mod = tvm.IRModule.from_expr(conv2d)
+
+ data_np = np.random.uniform(1, 10, d_shape).astype("int8")
+ weight_np = np.random.uniform(1, 10, size=w_shape).astype("int8")
+
+ target = "rocm -mattr=+dotprod"
+ with tvm.transform.PassContext(opt_level=3):
+ lib = relay.build(mod, target=target, params={"weight": weight_np})
+
+ asm = lib.lib.imported_modules[0].get_source("asm")
+ assert "v_dot4_i32_i8" in asm
+
+ dev = tvm.device(target, 0)
+ runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+
+ runtime.set_input("data", data_np)
+ runtime.run()
+
+ out = runtime.get_output(0).numpy()
+
+ ref = tvm.topi.testing.conv2d_nchw_python(
+ data_np.astype("int32"), weight_np.astype("int32"), strides, padding
+ )
+
+ np.testing.assert_equal(out, ref)
+
+
if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))
diff --git a/tests/python/topi/python/test_topi_conv2d_int8.py b/tests/python/topi/python/test_topi_conv2d_int8.py
index 860118531e..17c5573b2c 100644
--- a/tests/python/topi/python/test_topi_conv2d_int8.py
+++ b/tests/python/topi/python/test_topi_conv2d_int8.py
@@ -376,15 +376,22 @@ def verify_conv2d_NCHWc_int8(
)
if in_dtype == "int8":
- targets.append(
+ targets += [
(
"llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon",
topi.arm_cpu.conv2d_NCHWc_int8,
topi.arm_cpu.schedule_conv2d_NCHWc_int8,
8,
build_only_aarch64,
- )
- )
+ ),
+ (
+ "rocm -mattr=+dotprod",
+ lambda a, w, s, p, d, l, ol, o: topi.cuda.conv2d_NCHWc_int8(a, w, s, p, d, l, o),
+ topi.cuda.schedule_conv2d_NCHWc_int8,
+ 4,
+ False,
+ ),
+ ]
for target, compute, schedule, oc_block_factor, build_only in targets:
check_target(target, compute, schedule, oc_block_factor, build_only)
diff --git a/tests/python/topi/python/test_topi_dense.py b/tests/python/topi/python/test_topi_dense.py
index 8f58415da3..2826d70ba0 100644
--- a/tests/python/topi/python/test_topi_dense.py
+++ b/tests/python/topi/python/test_topi_dense.py
@@ -52,7 +52,6 @@ _dense_implementations = {
],
"mali": [(topi.mali.dense, topi.mali.schedule_dense)],
"bifrost": [(topi.bifrost.dense, topi.bifrost.schedule_dense)],
- "rocm": [(topi.rocm.dense, topi.rocm.schedule_dense)],
"hls": [(topi.nn.dense, topi.hls.schedule_dense)],
}
diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py
index 482d6f3db5..65dfa06eb6 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize.py
@@ -26,6 +26,8 @@ from tvm.tir.tensor_intrin import (
VNNI_DOT_16x4_INTRIN,
ARM_DOT_4x4_i8_NEON_INTRIN,
ARM_DOT_4x4_i8_SDOT_INTRIN,
+ AMDGPU_SDOT4_INTRIN,
+ DP4A_INTRIN,
)
# fmt: off
@@ -595,5 +597,53 @@ def test_tensorize_arm_dot():
verify_trace_roundtrip(sch=sch, mod=func)
+def test_tensorize_dpa4():
+ m, n, k = 128, 128, 128
+
+ X = te.placeholder((m, k), name="X", dtype="int8")
+ W = te.placeholder((n, k), name="W", dtype="int8")
+ ak = te.reduce_axis((0, k), name="k")
+
+ matmul = te.compute(
+ (m, n),
+ lambda i, j: te.sum(
+ X[i, ak].astype("int32")
+ * W[j, ak].astype("int32"),
+ axis=ak,
+ ),
+ name="compute",
+ )
+
+ func = te.create_prim_func([X, W, matmul])
+
+ for intrin in [AMDGPU_SDOT4_INTRIN, DP4A_INTRIN]:
+ sch = tir.Schedule(func, debug_mask="all")
+ block = sch.get_block("compute")
+ i, j, k = sch.get_loops(block)
+
+ by, ty, yi = sch.split(i, factors=sch.sample_perfect_tile(i, n=3))
+ bx, tx, xi = sch.split(j, factors=sch.sample_perfect_tile(j, n=3))
+ ko, ki = sch.split(k, [None, 4])
+ ko, kt = sch.split(ko, factors=sch.sample_perfect_tile(ko, n=2))
+
+ sch.reorder(by, bx, ty, tx, yi, xi)
+
+ CC = sch.cache_write(block, 0, "local")
+ sch.reverse_compute_at(CC, tx)
+
+ def fetch_to_shared(block, idx):
+ block_read = sch.cache_read(block, idx, "shared")
+ sch.compute_at(block_read, ko, True)
+ return block_read
+
+ fetch_to_shared(block, 0)
+ fetch_to_shared(block, 1)
+
+ sch.decompose_reduction(block, ko)
+ sch.tensorize(ki, intrin)
+
+ verify_trace_roundtrip(sch=sch, mod=func)
+
+
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))