commit:     ca2e68ab5527af67bf95684068758b4efbd5b5a5
Author:     Sv. Lockal <lockalsash <AT> gmail <DOT> com>
AuthorDate: Thu Jul 25 09:27:07 2024 +0000
Commit:     Alfredo Tupone <tupone <AT> gentoo <DOT> org>
CommitDate: Sat Jul 27 19:08:42 2024 +0000
URL:        https://gitweb.gentoo.org/repo/gentoo.git/commit/?id=ca2e68ab

sci-libs/caffe2: update dependencies to fix rocm flag

pytorch 2.3.0 introduced two new direct dependencies: hipBLASLt and aotriton.

pytorch uses hipBLASLt to perform GEMM operations on datacenter AMD Instinct
GPUs. For other GPUs, pytorch falls back to hipBLAS.
The caffe2-2.3.x ebuilds now carry a patch that makes this dependency optional:
it is enabled only when at least one of AMDGPU_TARGETS="gfx90a gfx940 gfx941
gfx942" is selected.
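
For illustration, a minimal sketch of how a user with a supported Instinct GPU
would opt in (gfx90a is only an example target; the file locations are the
usual Gentoo ones and are not part of this commit):

    # /etc/portage/make.conf -- example only, assuming an MI200-class (gfx90a) GPU
    AMDGPU_TARGETS="gfx90a"
    # /etc/portage/package.use/caffe2 -- enable the ROCm backend
    sci-libs/caffe2 rocm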

pytorch uses aotriton to perform the FlashAttention operation.
The caffe2-2.3.x ebuilds now contain a patch which fully disables the aotriton
dependency, as there is no such package in the tree yet.
Technically aotriton can be compiled (with minor patches), but I suggest
waiting for the next releases.
Packaging it is a massive burden, as it depends on a forked triton and a forked
clang (a merge with upstream is not expected anytime soon).
aotriton is usually distributed as a huge static (!) library (but in the next
release the library will be shared).

Minor fixes were added for compatibility with libc++ (used in the experimental
llvm Gentoo profile); however, other ebuilds also require minor patches
(in other words: right now the ROCm ecosystem can be compiled with libc++, but
only by people with experience in C++).

Closes: https://bugs.gentoo.org/931046
Signed-off-by: Sv. Lockal <lockalsash <AT> gmail.com>
Signed-off-by: Alfredo Tupone <tupone <AT> gentoo.org>

 sci-libs/caffe2/caffe2-2.3.0-r3.ebuild             |  45 ++--
 sci-libs/caffe2/caffe2-2.3.1.ebuild                |  46 ++--
 .../files/caffe2-2.3.0-exclude-aotriton.patch      |  35 +++
 .../caffe2-2.3.0-fix-gcc-clang-abi-compat.patch    |  17 ++
 .../caffe2/files/caffe2-2.3.0-fix-libcpp.patch     |  24 +++
 .../files/caffe2-2.3.0-fix-rocm-gcc14-clamp.patch  |  18 ++
 .../files/caffe2-2.3.0-optional-hipblaslt.patch    | 235 +++++++++++++++++++++
 7 files changed, 393 insertions(+), 27 deletions(-)

diff --git a/sci-libs/caffe2/caffe2-2.3.0-r3.ebuild b/sci-libs/caffe2/caffe2-2.3.0-r3.ebuild
index c01e904d8eb0..666800d8f4b6 100644
--- a/sci-libs/caffe2/caffe2-2.3.0-r3.ebuild
+++ b/sci-libs/caffe2/caffe2-2.3.0-r3.ebuild
@@ -4,7 +4,7 @@
 EAPI=8
 
 PYTHON_COMPAT=( python3_{10..12} )
-ROCM_VERSION=5.7
+ROCM_VERSION=6.1
 inherit python-single-r1 cmake cuda flag-o-matic prefix rocm
 
 MYPN=pytorch
@@ -65,18 +65,23 @@ RDEPEND="
        opencv? ( media-libs/opencv:= )
        qnnpack? ( sci-libs/QNNPACK )
        rocm? (
-               >=dev-util/hip-5.7
-               >=dev-libs/rccl-5.7[${ROCM_USEDEP}]
-               >=sci-libs/rocThrust-5.7[${ROCM_USEDEP}]
-               >=sci-libs/rocPRIM-5.7[${ROCM_USEDEP}]
-               >=sci-libs/hipBLAS-5.7[${ROCM_USEDEP}]
-               >=sci-libs/hipFFT-5.7[${ROCM_USEDEP}]
-               >=sci-libs/hipSPARSE-5.7[${ROCM_USEDEP}]
-               >=sci-libs/hipRAND-5.7[${ROCM_USEDEP}]
-               >=sci-libs/hipCUB-5.7[${ROCM_USEDEP}]
-               >=sci-libs/hipSOLVER-5.7[${ROCM_USEDEP}]
-               >=sci-libs/miopen-5.7[${ROCM_USEDEP}]
-               >=dev-util/roctracer-5.7[${ROCM_USEDEP}]
+               =dev-util/hip-6.1*
+               =dev-libs/rccl-6.1*[${ROCM_USEDEP}]
+               =sci-libs/rocThrust-6.1*[${ROCM_USEDEP}]
+               =sci-libs/rocPRIM-6.1*[${ROCM_USEDEP}]
+               =sci-libs/hipBLAS-6.1*[${ROCM_USEDEP}]
+               =sci-libs/hipFFT-6.1*[${ROCM_USEDEP}]
+               =sci-libs/hipSPARSE-6.1*[${ROCM_USEDEP}]
+               =sci-libs/hipRAND-6.1*[${ROCM_USEDEP}]
+               =sci-libs/hipCUB-6.1*[${ROCM_USEDEP}]
+               =sci-libs/hipSOLVER-6.1*[${ROCM_USEDEP}]
+               =sci-libs/miopen-6.1*[${ROCM_USEDEP}]
+               =dev-util/roctracer-6.1*[${ROCM_USEDEP}]
+
+               amdgpu_targets_gfx90a? ( =sci-libs/hipBLASLt-6.1*[amdgpu_targets_gfx90a] )
+               amdgpu_targets_gfx940? ( =sci-libs/hipBLASLt-6.1*[amdgpu_targets_gfx940] )
+               amdgpu_targets_gfx941? ( =sci-libs/hipBLASLt-6.1*[amdgpu_targets_gfx941] )
+               amdgpu_targets_gfx942? ( =sci-libs/hipBLASLt-6.1*[amdgpu_targets_gfx942] )
        )
        distributed? ( sci-libs/tensorpipe[cuda?] )
        xnnpack? ( >=sci-libs/XNNPACK-2022.12.22 )
@@ -111,6 +116,11 @@ PATCHES=(
        "${FILESDIR}"/${P}-rocm-fix-std-cpp17.patch
        "${FILESDIR}"/${PN}-2.2.2-musl.patch
        "${FILESDIR}"/${P}-CMakeFix.patch
+       "${FILESDIR}"/${PN}-2.3.0-exclude-aotriton.patch
+       "${FILESDIR}"/${PN}-2.3.0-fix-rocm-gcc14-clamp.patch
+       "${FILESDIR}"/${PN}-2.3.0-optional-hipblaslt.patch
+       "${FILESDIR}"/${PN}-2.3.0-fix-libcpp.patch
+       "${FILESDIR}"/${PN}-2.3.0-fix-gcc-clang-abi-compat.patch
 )
 
 src_prepare() {
@@ -235,11 +245,20 @@ src_configure() {
                )
        elif use rocm; then
                export PYTORCH_ROCM_ARCH="$(get_amdgpu_flags)"
+               local use_hipblaslt="OFF"
+               if use amdgpu_targets_gfx90a || use amdgpu_targets_gfx940 || use amdgpu_targets_gfx941 \
+                       || use amdgpu_targets_gfx942; then
+                       use_hipblaslt="ON"
+               fi
 
                mycmakeargs+=(
                        -DUSE_NCCL=ON
                        -DUSE_SYSTEM_NCCL=ON
+                       -DUSE_HIPBLASLT=${use_hipblaslt}
                )
+
+               # ROCm libraries produce too many warnings
+               append-cxxflags -Wno-deprecated-declarations -Wno-unused-result
        fi
 
        if use onednn; then
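
A quick way to check how the conditional hipBLASLt dependency above resolves on
a given system is a pretend merge (illustrative command; the result depends on
the local profile and flags):

    # sci-libs/hipBLASLt should be pulled in only for the gfx90a/gfx940/gfx941/gfx942 targets
    USE="rocm" AMDGPU_TARGETS="gfx90a" emerge -pv sci-libs/caffe2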

diff --git a/sci-libs/caffe2/caffe2-2.3.1.ebuild b/sci-libs/caffe2/caffe2-2.3.1.ebuild
index 6355d0083336..ee1da28aa12f 100644
--- a/sci-libs/caffe2/caffe2-2.3.1.ebuild
+++ b/sci-libs/caffe2/caffe2-2.3.1.ebuild
@@ -4,7 +4,7 @@
 EAPI=8
 
 PYTHON_COMPAT=( python3_{10..12} )
-ROCM_VERSION=5.7
+ROCM_VERSION=6.1
 inherit python-single-r1 cmake cuda flag-o-matic prefix rocm
 
 MYPN=pytorch
@@ -65,19 +65,23 @@ RDEPEND="
        opencv? ( media-libs/opencv:= )
        qnnpack? ( sci-libs/QNNPACK )
        rocm? (
-               =dev-util/hip-5.7*
-               =dev-libs/rccl-5.7*[${ROCM_USEDEP}]
-               =sci-libs/rocThrust-5.7*[${ROCM_USEDEP}]
-               =sci-libs/rocPRIM-5.7*[${ROCM_USEDEP}]
-               =sci-libs/hipBLAS-5.7*[${ROCM_USEDEP}]
-               sci-libs/hipBLASLt
-               =sci-libs/hipFFT-5.7*[${ROCM_USEDEP}]
-               =sci-libs/hipSPARSE-5.7*[${ROCM_USEDEP}]
-               =sci-libs/hipRAND-5.7*[${ROCM_USEDEP}]
-               =sci-libs/hipCUB-5.7*[${ROCM_USEDEP}]
-               =sci-libs/hipSOLVER-5.7*[${ROCM_USEDEP}]
-               =sci-libs/miopen-5.7*[${ROCM_USEDEP}]
-               =dev-util/roctracer-5.7*[${ROCM_USEDEP}]
+               =dev-util/hip-6.1*
+               =dev-libs/rccl-6.1*[${ROCM_USEDEP}]
+               =sci-libs/rocThrust-6.1*[${ROCM_USEDEP}]
+               =sci-libs/rocPRIM-6.1*[${ROCM_USEDEP}]
+               =sci-libs/hipBLAS-6.1*[${ROCM_USEDEP}]
+               =sci-libs/hipFFT-6.1*[${ROCM_USEDEP}]
+               =sci-libs/hipSPARSE-6.1*[${ROCM_USEDEP}]
+               =sci-libs/hipRAND-6.1*[${ROCM_USEDEP}]
+               =sci-libs/hipCUB-6.1*[${ROCM_USEDEP}]
+               =sci-libs/hipSOLVER-6.1*[${ROCM_USEDEP}]
+               =sci-libs/miopen-6.1*[${ROCM_USEDEP}]
+               =dev-util/roctracer-6.1*[${ROCM_USEDEP}]
+
+               amdgpu_targets_gfx90a? ( =sci-libs/hipBLASLt-6.1*[amdgpu_targets_gfx90a] )
+               amdgpu_targets_gfx940? ( =sci-libs/hipBLASLt-6.1*[amdgpu_targets_gfx940] )
+               amdgpu_targets_gfx941? ( =sci-libs/hipBLASLt-6.1*[amdgpu_targets_gfx941] )
+               amdgpu_targets_gfx942? ( =sci-libs/hipBLASLt-6.1*[amdgpu_targets_gfx942] )
        )
        distributed? ( sci-libs/tensorpipe[cuda?] )
        xnnpack? ( >=sci-libs/XNNPACK-2022.12.22 )
@@ -112,6 +116,11 @@ PATCHES=(
        "${FILESDIR}"/${PN}-2.3.0-rocm-fix-std-cpp17.patch
        "${FILESDIR}"/${PN}-2.2.2-musl.patch
        "${FILESDIR}"/${PN}-2.3.0-CMakeFix.patch
+       "${FILESDIR}"/${PN}-2.3.0-exclude-aotriton.patch
+       "${FILESDIR}"/${PN}-2.3.0-fix-rocm-gcc14-clamp.patch
+       "${FILESDIR}"/${PN}-2.3.0-optional-hipblaslt.patch
+       "${FILESDIR}"/${PN}-2.3.0-fix-libcpp.patch
+       "${FILESDIR}"/${PN}-2.3.0-fix-gcc-clang-abi-compat.patch
 )
 
 src_prepare() {
@@ -236,11 +245,20 @@ src_configure() {
                )
        elif use rocm; then
                export PYTORCH_ROCM_ARCH="$(get_amdgpu_flags)"
+               local use_hipblaslt="OFF"
+               if use amdgpu_targets_gfx90a || use amdgpu_targets_gfx940 || use amdgpu_targets_gfx941 \
+                       || use amdgpu_targets_gfx942; then
+                       use_hipblaslt="ON"
+               fi
 
                mycmakeargs+=(
                        -DUSE_NCCL=ON
                        -DUSE_SYSTEM_NCCL=ON
+                       -DUSE_HIPBLASLT=${use_hipblaslt}
                )
+
+               # ROCm libraries produce too many warnings
+               append-cxxflags -Wno-deprecated-declarations -Wno-unused-result
        fi
 
        if use onednn; then

diff --git a/sci-libs/caffe2/files/caffe2-2.3.0-exclude-aotriton.patch b/sci-libs/caffe2/files/caffe2-2.3.0-exclude-aotriton.patch
new file mode 100644
index 000000000000..2c65987acd85
--- /dev/null
+++ b/sci-libs/caffe2/files/caffe2-2.3.0-exclude-aotriton.patch
@@ -0,0 +1,35 @@
+Disables aotriton download when both USE_FLASH_ATTENTION and USE_MEM_EFF_ATTENTION cmake flags are OFF
+Backports upstream PR to 2.3.0: https://github.com/pytorch/pytorch/pull/130197
+--- a/cmake/Dependencies.cmake
++++ b/cmake/Dependencies.cmake
+@@ -1334,7 +1334,9 @@ if(USE_ROCM)
+       message(STATUS "Disabling Kernel Assert for ROCm")
+     endif()
+ 
+-    include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
++    if(USE_FLASH_ATTENTION)
++      include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
++    endif()
+     if(USE_CUDA)
+       caffe2_update_option(USE_MEM_EFF_ATTENTION OFF)
+     endif()
+--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
++++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
+@@ -21,7 +21,7 @@
+ #include <cmath>
+ #include <functional>
+ 
+-#if USE_ROCM
++#if defined(USE_ROCM) && defined(USE_FLASH_ATTENTION)
+ #include <aotriton/flash.h>
+ #endif
+ 
+@@ -186,7 +186,7 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
+   // Check that the gpu is capable of running flash attention
+   using sm80 = SMVersion<8, 0>;
+   using sm90 = SMVersion<9, 0>;
+-#if USE_ROCM
++#if defined(USE_ROCM) && defined(USE_FLASH_ATTENTION)
+   auto stream = at::cuda::getCurrentCUDAStream().stream();
+   if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
+       auto dprops = at::cuda::getCurrentDeviceProperties();
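
With the patch above applied, aotriton is only needed when flash attention is
requested; a minimal sketch of a configure invocation that keeps both attention
paths off (flag names come from the patch description, the rest of the command
line is illustrative):

    # illustrative only: configure the pytorch/caffe2 CMake build without aotriton
    cmake -DUSE_ROCM=ON -DUSE_FLASH_ATTENTION=OFF -DUSE_MEM_EFF_ATTENTION=OFF /path/to/pytorch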

diff --git a/sci-libs/caffe2/files/caffe2-2.3.0-fix-gcc-clang-abi-compat.patch b/sci-libs/caffe2/files/caffe2-2.3.0-fix-gcc-clang-abi-compat.patch
new file mode 100644
index 000000000000..a6f981b7e054
--- /dev/null
+++ b/sci-libs/caffe2/files/caffe2-2.3.0-fix-gcc-clang-abi-compat.patch
@@ -0,0 +1,17 @@
+
+When gcc builds libtorch_cpu.so and hipcc (clang-18) builds libtorch_hip.so,
+the resulting binary fails at runtime due to different mangling.
+Related issue in LLVM: https://github.com/llvm/llvm-project/issues/85656
+Fixed in pytorch-2.4.0 in https://github.com/pytorch/pytorch/commit/a89f442f0b103fa6f38103784a2dfedbd147f863
+--- a/cmake/Dependencies.cmake
++++ b/cmake/Dependencies.cmake
+@@ -1314,6 +1314,9 @@ if(USE_ROCM)
+        list(APPEND HIP_HIPCC_FLAGS -fdebug-info-for-profiling)
+     endif(CMAKE_BUILD_TYPE MATCHES Debug)
+ 
++    # needed for compat with newer versions of hip-clang that introduced C++20 mangling rules
++    list(APPEND HIP_HIPCC_FLAGS -fclang-abi-compat=17)
++
+     set(HIP_CLANG_FLAGS ${HIP_CXX_FLAGS})
+     # Ask hcc to generate device code during compilation so we can use
+     # host linker to link.
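
The same workaround can be reproduced outside the build system; a minimal
sketch of compiling a single HIP translation unit with the GCC-compatible
mangling (the file name is hypothetical):

    # illustrative only: keep clang-17 mangling so the object links against gcc-built libtorch_cpu.so
    hipcc -fclang-abi-compat=17 -c some_torch_hip_unit.cpp -o some_torch_hip_unit.o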

diff --git a/sci-libs/caffe2/files/caffe2-2.3.0-fix-libcpp.patch b/sci-libs/caffe2/files/caffe2-2.3.0-fix-libcpp.patch
new file mode 100644
index 000000000000..75808fd7ec50
--- /dev/null
+++ b/sci-libs/caffe2/files/caffe2-2.3.0-fix-libcpp.patch
@@ -0,0 +1,24 @@
+Workaround for libc++ issue https://github.com/llvm/llvm-project/issues/100802
+"reference to __host__ function 'memcpy' in __device__ function"
+--- a/c10/util/Half.h
++++ b/c10/util/Half.h
+@@ -227,7 +227,7 @@ C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {
+   // const float exp_scale = 0x1.0p-112f;
+   constexpr uint32_t scale_bits = (uint32_t)15 << 23;
+   float exp_scale_val = 0;
+-  std::memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val));
++  memcpy(&exp_scale_val, &scale_bits, sizeof(exp_scale_val));
+   const float exp_scale = exp_scale_val;
+   const float normalized_value =
+       fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
+@@ -298,8 +298,8 @@ inline uint16_t fp16_ieee_from_fp32_value(float f) {
+   constexpr uint32_t scale_to_inf_bits = (uint32_t)239 << 23;
+   constexpr uint32_t scale_to_zero_bits = (uint32_t)17 << 23;
+   float scale_to_inf_val = 0, scale_to_zero_val = 0;
+-  std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val));
+-  std::memcpy(
++  memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val));
++  memcpy(
+       &scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val));
+   const float scale_to_inf = scale_to_inf_val;
+   const float scale_to_zero = scale_to_zero_val;

diff --git a/sci-libs/caffe2/files/caffe2-2.3.0-fix-rocm-gcc14-clamp.patch b/sci-libs/caffe2/files/caffe2-2.3.0-fix-rocm-gcc14-clamp.patch
new file mode 100644
index 000000000000..81ae075c67cc
--- /dev/null
+++ b/sci-libs/caffe2/files/caffe2-2.3.0-fix-rocm-gcc14-clamp.patch
@@ -0,0 +1,18 @@
+Fix hip compilation with gcc-14
+Upstream commit: https://github.com/pytorch/pytorch/commit/8c2c3a03fb87c3568a22362d83b00d82b9fb3db2
+--- a/aten/src/ATen/native/cuda/IndexKernel.cu
++++ b/aten/src/ATen/native/cuda/IndexKernel.cu
+@@ -259,7 +259,13 @@ void index_put_kernel_quantized_cuda(TensorIterator& iter, const IntArrayRef ind
+ 
+     gpu_index_kernel(iter, index_size, index_stride, [inv_scale, zero_point, qmin, qmax]C10_DEVICE(char* const out_data, const char* const in_data, const int64_t offset) {
+       int64_t qvalue = static_cast<int64_t>(zero_point + nearbyintf(*(float*)in_data * inv_scale));
++      // See https://github.com/pytorch/pytorch/issues/127666
++      // hip-clang std::clamp __glibcxx_assert_fail host function when building on Fedora40/gcc14
++#ifndef USE_ROCM
+       qvalue = std::clamp(qvalue, qmin, qmax);
++#else
++      qvalue = (qvalue < qmin) ? qmin : (qmax < qvalue) ? qmax : qvalue;
++#endif
+       *(scalar_t*)(out_data + offset) = static_cast<scalar_t>(qvalue);
+     });
+   });

diff --git a/sci-libs/caffe2/files/caffe2-2.3.0-optional-hipblaslt.patch b/sci-libs/caffe2/files/caffe2-2.3.0-optional-hipblaslt.patch
new file mode 100644
index 000000000000..dc544255c2bd
--- /dev/null
+++ b/sci-libs/caffe2/files/caffe2-2.3.0-optional-hipblaslt.patch
@@ -0,0 +1,235 @@
+Makes hipblaslt optional to simplify build for non-datacenter GPUs.
+Based on https://github.com/pytorch/pytorch/pull/120551 with added USE_HIPBLASLT cmake option.
+--- a/CMakeLists.txt
++++ b/CMakeLists.txt
+@@ -225,6 +225,9 @@ option(USE_FAKELOWP "Use FakeLowp operators" OFF)
+ option(USE_FFMPEG "Use ffmpeg" OFF)
+ option(USE_GFLAGS "Use GFLAGS" OFF)
+ option(USE_GLOG "Use GLOG" OFF)
++cmake_dependent_option(
++    USE_HIPBLASLT "Use hipBLASLt" ON
++    "USE_ROCM" OFF)
+ option(USE_LEVELDB "Use LEVELDB" OFF)
+ option(USE_LITE_PROTO "Use lite protobuf instead of full." OFF)
+ option(USE_LMDB "Use LMDB" OFF)
+--- a/aten/src/ATen/cuda/CUDABlas.cpp
++++ b/aten/src/ATen/cuda/CUDABlas.cpp
+@@ -14,7 +14,7 @@
+ #include <c10/util/irange.h>
+ 
+ #ifdef USE_ROCM
+-#if ROCM_VERSION >= 60000
++#ifdef USE_HIPBLASLT
+ #include <hipblaslt/hipblaslt-ext.hpp>
+ #endif
+ // until hipblas has an API to accept flags, we must use rocblas here
+@@ -781,7 +781,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
+   }
+ }
+ 
+-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
++#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
+ 
+ #if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
+ // only for rocm 5.7 where we first supported hipblaslt, it was difficult
+@@ -912,6 +912,7 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
+ };
+ } // namespace
+ 
++#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
+ template <typename Dtype>
+ void gemm_and_bias(
+     bool transpose_mat1,
+@@ -1124,7 +1125,7 @@ template void gemm_and_bias(
+     at::BFloat16* result_ptr,
+     int64_t result_ld,
+     GEMMAndBiasActivationEpilogue activation);
+-
++#endif
+ void scaled_gemm(
+     char transa,
+     char transb,
+--- a/aten/src/ATen/cuda/CUDABlas.h
++++ b/aten/src/ATen/cuda/CUDABlas.h
+@@ -82,7 +82,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
+ template <>
+ void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
+ 
+-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
++#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
+ enum GEMMAndBiasActivationEpilogue {
+   None,
+   RELU,
+--- a/aten/src/ATen/cuda/CUDAContextLight.h
++++ b/aten/src/ATen/cuda/CUDAContextLight.h
+@@ -9,7 +9,7 @@
+ 
+ // cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
+ // added bf16 support
+-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
++#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
+ #include <cublasLt.h>
+ #endif
+ 
+@@ -82,7 +82,7 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
+ /* Handles */
+ TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
+ TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
+-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
++#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
+ TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
+ #endif
+ 
+--- a/aten/src/ATen/cuda/CublasHandlePool.cpp
++++ b/aten/src/ATen/cuda/CublasHandlePool.cpp
+@@ -29,7 +29,7 @@ namespace at::cuda {
+ 
+ namespace {
+ 
+-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
++#if defined(USE_ROCM) && defined(USE_HIPBLASLT)
+ void createCublasLtHandle(cublasLtHandle_t *handle) {
+   TORCH_CUDABLAS_CHECK(cublasLtCreate(handle));
+ }
+@@ -190,7 +190,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
+   return handle;
+ }
+ 
+-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
++#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
+ cublasLtHandle_t getCurrentCUDABlasLtHandle() {
+ #ifdef USE_ROCM
+   c10::DeviceIndex device = 0;
+--- a/aten/src/ATen/cuda/tunable/TunableGemm.h
++++ b/aten/src/ATen/cuda/tunable/TunableGemm.h
+@@ -11,7 +11,7 @@
+ 
+ #include <ATen/cuda/tunable/GemmCommon.h>
+ #ifdef USE_ROCM
+-#if ROCM_VERSION >= 50700
++#ifdef USE_HIPBLASLT
+ #include <ATen/cuda/tunable/GemmHipblaslt.h>
+ #endif
+ #include <ATen/cuda/tunable/GemmRocblas.h>
+@@ -166,7 +166,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
+     }
+ #endif
+ 
+-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
++#if defined(USE_ROCM) && defined(USE_HIPBLASLT)
+     static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
+     if (env == nullptr || strcmp(env, "1") == 0) {
+       // disallow tuning of hipblaslt with c10::complex
+@@ -240,7 +240,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
+     }
+ #endif
+ 
+-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
++#if defined(USE_ROCM) && defined(USE_HIPBLASLT)
+     static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
+     if (env == nullptr || strcmp(env, "1") == 0) {
+       // disallow tuning of hipblaslt with c10::complex
+--- a/aten/src/ATen/native/cuda/Blas.cpp
++++ b/aten/src/ATen/native/cuda/Blas.cpp
+@@ -155,7 +155,7 @@ enum class Activation {
+   GELU,
+ };
+ 
+-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
++#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
+ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
+   switch (a) {
+     case Activation::None:
+@@ -193,6 +193,7 @@ static bool getDisableAddmmCudaLt() {
+ 
+ #ifdef USE_ROCM
+ static bool isSupportedHipLtROCmArch(int index) {
++#if defined(USE_HIPBLASLT)
+     hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
+     std::string device_arch = prop->gcnArchName;
+     static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
+@@ -203,6 +204,7 @@ static bool isSupportedHipLtROCmArch(int index) {
+         }
+     }
+     TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!");
++#endif
+     return false;
+ }
+ #endif
+@@ -228,7 +230,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
+   at::ScalarType scalar_type = self.scalar_type();
+   c10::MaybeOwned<Tensor> self_;
+   if (&result != &self) {
+-#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && ROCM_VERSION >= 50700
++#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && defined(USE_HIPBLASLT)
+     // Strangely, if mat2 has only 1 row or column, we get
+     // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
+     // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
+@@ -271,7 +273,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
+     }
+     self__sizes = self_->sizes();
+   } else {
+-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
++#if defined(USE_ROCM) && defined(USE_HIPBLASLT)
+     useLtInterface = !disable_addmm_cuda_lt &&
+         result.dim() == 2 && result.is_contiguous() &&
+         isSupportedHipLtROCmArch(self.device().index()) &&
+@@ -322,7 +324,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
+ 
+   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
+ 
+-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
++#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
+   if (useLtInterface) {
+     AT_DISPATCH_FLOATING_TYPES_AND2(
+         at::ScalarType::Half,
+@@ -876,7 +878,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
+   at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
+   at::native::resize_output(amax, {});
+ 
+-#if !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
++#if !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
+   cublasCommonArgs args(mat1, mat2, out);
+   const auto out_dtype_ = args.result->scalar_type();
+   TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
+@@ -906,7 +908,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
+   TORCH_CHECK(false, "_scaled_mm_out_cuda is not compiled for this platform.");
+ #endif
+ 
+-#if defined(USE_ROCM) && ROCM_VERSION >= 60000
++#if defined(USE_ROCM) && defined(USE_HIPBLASLT)
+   // rocm's hipblaslt does not yet support amax, so calculate separately
+   auto out_float32 = out.to(kFloat);
+   out_float32.abs_();
+--- a/cmake/Dependencies.cmake
++++ b/cmake/Dependencies.cmake
+@@ -1282,6 +1282,9 @@ if(USE_ROCM)
+     if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "6.0.0")
+       list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2)
+     endif()
++    if(hipblast_FOUND)
++      list(APPEND HIP_CXX_FLAGS -DHIPBLASLT)
++    endif()
+     if(HIPBLASLT_CUSTOM_DATA_TYPE)
+       list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_CUSTOM_DATA_TYPE)
+     endif()
+--- a/cmake/public/LoadHIP.cmake
++++ b/cmake/public/LoadHIP.cmake
+@@ -155,7 +155,7 @@ if(HIP_FOUND)
+   find_package_and_print_version(hiprand REQUIRED)
+   find_package_and_print_version(rocblas REQUIRED)
+   find_package_and_print_version(hipblas REQUIRED)
+-  if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
++  if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0" AND USE_HIPBLASLT)
+     find_package_and_print_version(hipblaslt REQUIRED)
+   endif()
+   find_package_and_print_version(miopen REQUIRED)
+@@ -191,7 +191,7 @@ if(HIP_FOUND)
+   # roctx is part of roctracer
+   find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)
+ 
+-  if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
++  if(hipblastlt_FOUND)
+     # check whether hipblaslt is using its own datatype
+     set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_data_type.cc")
+     file(WRITE ${file} ""
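
Outside Portage, the option added by this patch is a plain CMake flag; a sketch
of a consumer-GPU configuration without hipBLASLt (gfx1030 is only an example
architecture, the rest of the command line is illustrative):

    # illustrative only: ROCm build for a consumer GPU, hipBLASLt disabled
    PYTORCH_ROCM_ARCH=gfx1030 cmake -DUSE_ROCM=ON -DUSE_HIPBLASLT=OFF /path/to/pytorch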
