This is an automated email from the ASF dual-hosted git repository.

ekalda pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new af0c038f2e [SVE] Add codegen support for scalable buffer accesses 
(#16696)
af0c038f2e is described below

commit af0c038f2ec36d1762e7f500bb000d945b01e326
Author: Luke Hutton <[email protected]>
AuthorDate: Thu Mar 14 11:48:21 2024 +0000

    [SVE] Add codegen support for scalable buffer accesses (#16696)
    
    This commit adds support for generating code for scalable loads and
    stores. It also adds support for the creation of scalable broadcast
    operations.
    
    
    Co-authored-by: Elen Kalda <[email protected]>
    Co-authored-by: Neil Hickey <[email protected]>
---
 include/tvm/runtime/data_type.h                    |  16 ++-
 python/tvm/testing/utils.py                        |   7 ++
 src/target/llvm/codegen_llvm.cc                    |  66 ++++++-----
 src/target/llvm/codegen_llvm.h                     |   1 -
 src/tir/ir/data_type_rewriter.cc                   |   2 +-
 src/tir/ir/expr.cc                                 |   7 +-
 src/tir/transforms/storage_rewrite.cc              |   7 ++
 tests/cpp/tir_scalable_datatype.cc                 |  16 +++
 .../python/codegen/test_target_codegen_aarch64.py  |  41 +++++++
 tests/python/target/test_arm_target.py             | 125 +++++++++++++++++++++
 10 files changed, 249 insertions(+), 39 deletions(-)

diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index f6a7d424ed..8f3ae9b424 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -110,6 +110,8 @@ class DataType {
     }
     return -lanes_as_int;
   }
+  /*! \return get vscale factor or lanes depending on scalability of the 
vector. */
+  int get_lanes_or_vscale_factor() { return is_scalable_vector() ? 
vscale_factor() : lanes(); }
   /*! \return whether type is a scalar type. */
   bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; }
   /*! \return whether type is a scalar type. */
@@ -211,10 +213,13 @@ class DataType {
   /*!
    * \brief Construct an uint type.
    * \param bits The number of bits in the type.
-   * \param lanes The number of lanes
+   * \param lanes The number of lanes.
+   * \param is_scalable Whether the data type is scalable.
    * \return The constructed data type.
    */
-  static DataType UInt(int bits, int lanes = 1) { return DataType(kDLUInt, 
bits, lanes); }
+  static DataType UInt(int bits, int lanes = 1, bool is_scalable = false) {
+    return DataType(kDLUInt, bits, lanes, is_scalable);
+  }
   /*!
    * \brief Construct an float type.
    * \param bits The number of bits in the type.
@@ -243,10 +248,13 @@ class DataType {
   static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8, 
lanes); }
   /*!
    * \brief Construct a bool type.
-   * \param lanes The number of lanes
+   * \param lanes The number of lanes.
+   * \param is_scalable Whether the data type is scalable.
    * \return The constructed data type.
    */
-  static DataType Bool(int lanes = 1) { return DataType::UInt(1, lanes); }
+  static DataType Bool(int lanes = 1, bool is_scalable = false) {
+    return DataType::UInt(1, lanes, is_scalable);
+  }
   /*!
    * \brief Construct a handle type.
    * \param bits The number of bits in the type.
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 6e23a84bc2..e1b1c65457 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -1045,6 +1045,13 @@ requires_arm_dot = Feature(
 )
 
 
+requires_aarch64_sve = Feature(
+    "arm_sve",
+    "AArch64 SVE",
+    run_time_check=lambda: _has_cpu_feat("sve"),
+)
+
+
 requires_x86_vnni = Feature(
     "x86_vnni",
     "x86 VNNI Extensions",
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index eae26e5cac..bba1488274 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -587,10 +587,17 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& 
dtype) const {
         LOG(FATAL) << "do not support " << dtype;
     }
   }
-  if (dtype.lanes() != 1) {
+  if (!dtype.is_scalar()) {
 #if TVM_LLVM_VERSION >= 110
-    return llvm::FixedVectorType::get(etype, dtype.lanes());
+    if (dtype.is_scalable_vector()) {
+      return llvm::VectorType::get(etype, dtype.vscale_factor(), true);
+    } else {
+      return llvm::FixedVectorType::get(etype, dtype.lanes());
+    }
 #else
+    ICHECK(!dtype.is_scalable_vector())
+        << "Versions of LLVM < 11 do not support scalable vectors. Please 
upgrade to a later "
+           "version.";
     return llvm::VectorType::get(etype, dtype.lanes());
 #endif
   } else {
@@ -749,26 +756,6 @@ std::unique_ptr<CodeGenLLVM::DebugInfo> 
CodeGenLLVM::CreateDebugInfo(llvm::Modul
   return debug_info;
 }
 
-llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
-#if TVM_LLVM_VERSION >= 110
-  llvm::Type* type = llvm::FixedVectorType::get(value->getType(), lanes);
-#else
-  llvm::Type* type = llvm::VectorType::get(value->getType(), lanes);
-#endif
-  llvm::Constant* undef = llvm::UndefValue::get(type);
-  llvm::Constant* zero = ConstInt32(0);
-  value = builder_->CreateInsertElement(undef, value, zero);
-#if TVM_LLVM_VERSION >= 120
-  llvm::Constant* mask = 
llvm::ConstantVector::getSplat(llvm::ElementCount::getFixed(lanes), zero);
-#elif TVM_LLVM_VERSION >= 110
-  llvm::Constant* mask =
-      llvm::ConstantVector::getSplat(llvm::ElementCount(lanes, 
/*Scalable=*/false), zero);
-#else
-  llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero);
-#endif
-  return builder_->CreateShuffleVector(value, undef, mask);
-}
-
 llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int 
extent) {
   int num_elems = GetVectorNumElements(vec);
   if (extent == num_elems && begin == 0) return vec;
@@ -1693,7 +1680,8 @@ void CodeGenLLVM::BufferAccessHelper(
   }
 
   PrimExpr last_index = indices[indices.size() - 1];
-  ICHECK_EQ(value_dtype.lanes(), last_index.dtype().lanes() * 
buffer_element_dtype.lanes());
+  ICHECK_EQ(value_dtype.get_lanes_or_vscale_factor(),
+            last_index.dtype().get_lanes_or_vscale_factor() * 
buffer_element_dtype.lanes());
 
   // Record index and elemtype in original form used for alias info
   PrimExpr last_index_origin = last_index;
@@ -1736,8 +1724,6 @@ void CodeGenLLVM::BufferAccessHelper(
     llvm::Value* last_index_value;
     int subelement_i = i;
     if (const RampNode* ramp = last_index.as<RampNode>()) {
-      // TODO(ekalda): P4 in https://github.com/apache/tvm/issues/16455
-      ICHECK(!last_index.dtype().is_scalable_vector());
       PrimExpr offset = ramp->base + (ramp->stride * i);
       last_index_value = MakeValue(offset);
     } else if (last_index.dtype().lanes() > 1) {
@@ -1754,8 +1740,13 @@ void CodeGenLLVM::BufferAccessHelper(
     all_index_values.push_back(last_index_value);
 
     TypedPointer buffer_ptr =
-        CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, 
all_index_values,
-                        value_dtype.with_lanes(value_dtype.lanes() / 
last_index.dtype().lanes()));
+        value_dtype.is_scalable_vector()
+            ? CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, 
all_index_values,
+                              
value_dtype.with_scalable_vscale_factor(value_dtype.vscale_factor() /
+                                                                      
last_index.dtype().lanes()))
+            : CreateBufferPtr(
+                  MakeValue(buffer->data), buffer_element_dtype, 
all_index_values,
+                  value_dtype.with_lanes(value_dtype.lanes() / 
last_index.dtype().lanes()));
     auto instruction = make_instruction(buffer_ptr, subelement_i, alignment, 
is_volatile);
     AddAliasInfo(instruction, buffer->data.get(), last_index_origin, 
buffer_element_dtype_origin);
   }
@@ -1870,10 +1861,23 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* 
op) {
 }
 
 llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) {
-  // TODO(ekalda): P4 in https://github.com/apache/tvm/issues/16455
-  ICHECK(!op->dtype.is_scalable_vector());
-  int lanes = op->dtype.lanes();
-  return CreateBroadcast(MakeValue(op->value), lanes);
+  DataType dtype = op->dtype;
+  llvm::Value* value = MakeValue(op->value);
+  llvm::Type* type = DTypeToLLVMType(dtype);
+  llvm::Constant* undef = llvm::UndefValue::get(type);
+  llvm::Constant* zero = ConstInt32(0);
+  value = builder_->CreateInsertElement(undef, value, zero);
+#if TVM_LLVM_VERSION >= 110
+  llvm::ElementCount ec =
+      llvm::ElementCount::get(dtype.get_lanes_or_vscale_factor(), 
dtype.is_scalable_vector());
+  llvm::Constant* mask = llvm::ConstantVector::getSplat(ec, zero);
+#else
+  ICHECK(!dtype.is_scalable_vector())
+      << "Versions of LLVM < 11 do not support scalable vectors. Please 
upgrade to a later "
+         "version.";
+  llvm::Constant* mask = llvm::ConstantVector::getSplat(dtype.lanes(), zero);
+#endif
+  return builder_->CreateShuffleVector(value, undef, mask);
 }
 
 void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) {
diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h
index 2efac03073..0f7aa847ec 100644
--- a/src/target/llvm/codegen_llvm.h
+++ b/src/target/llvm/codegen_llvm.h
@@ -468,7 +468,6 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const 
PrimExpr&)>,
   llvm::Value* CreateAdd(DataType t, llvm::Value* a, llvm::Value* b);
   llvm::Value* CreateSub(DataType t, llvm::Value* a, llvm::Value* b);
   llvm::Value* CreateMul(DataType t, llvm::Value* a, llvm::Value* b);
-  llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
   virtual TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, DataType 
buffer_element_dtype,
                                        llvm::ArrayRef<llvm::Value*> indices, 
DataType value_dtype);
   // Vector concatenation.
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index 2bd1e06083..2d2c097be4 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -451,7 +451,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const 
BufferStoreNode* op) {
 
   Buffer new_buffer = GetRemappedBuffer(op->buffer);
   auto value = this->VisitExpr(op->value);
-  if (new_buffer->dtype != value->dtype && value->dtype.lanes() == 1) {
+  if (new_buffer->dtype != value->dtype && value->dtype.is_scalar()) {
     value = cast(new_buffer->dtype, value);
   }
   auto indices = VisitIndices(op->indices);
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 1b611d4534..c2baad2096 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -58,7 +58,9 @@ namespace tir {
     CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types. " << 
a.dtype() << " vs. " \
                                   << b.dtype() << "\n";                        
              \
     ObjectPtr<T> node = make_object<T>();                                      
              \
-    node->dtype = DataType::Bool(a.dtype().lanes());                           
              \
+    DataType a_dtype = a.dtype();                                              
              \
+    node->dtype =                                                              
              \
+        DataType::Bool(a_dtype.get_lanes_or_vscale_factor(), 
a_dtype.is_scalable_vector());  \
     node->a = std::move(a);                                                    
              \
     node->b = std::move(b);                                                    
              \
     node->span = std::move(span);                                              
              \
@@ -393,7 +395,8 @@ Not::Not(PrimExpr a, Span span) {
   ICHECK(a.dtype().is_bool());
 
   ObjectPtr<NotNode> node = make_object<NotNode>();
-  node->dtype = DataType::Bool(a.dtype().lanes());
+  DataType a_dtype = a.dtype();
+  node->dtype = DataType::Bool(a_dtype.get_lanes_or_vscale_factor(), 
a_dtype.is_scalable_vector());
   node->a = std::move(a);
   node->span = std::move(span);
   data_ = std::move(node);
diff --git a/src/tir/transforms/storage_rewrite.cc 
b/src/tir/transforms/storage_rewrite.cc
index e40f683e21..3f34f2e870 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -1275,6 +1275,13 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
     auto it = info_map_.find(buffer);
     ICHECK(it != info_map_.end()) << "Load/Store of buffer " << 
buffer->name_hint << " (" << buffer
                                   << ") occurred before its declaration.";
+
+    if (value_dtype.is_scalable_vector()) {
+      // Scalable types are not currently supported in storage_rewrite. 
Scalable buffer
+      // accesses are not currently checked and therefore are not rewritten.
+      return;
+    }
+
     BufferVarInfo& var_info = it->second;
 
     if (value_dtype.element_of() == DataType::Bool()) {
diff --git a/tests/cpp/tir_scalable_datatype.cc 
b/tests/cpp/tir_scalable_datatype.cc
index 23decef69e..4b4764555f 100644
--- a/tests/cpp/tir_scalable_datatype.cc
+++ b/tests/cpp/tir_scalable_datatype.cc
@@ -162,6 +162,22 @@ TEST(ScalableDataType, 
TestScalableDataTypeInvalidLanesAccess) {
       tvm::InternalError);
 }
 
+TEST(ScalableDataType, TestScalableBool) {
+  tvm::DataType scalable_type = tvm::DataType::Bool(4, true);
+  ASSERT_EQ(scalable_type.code(), kDLUInt);
+  ASSERT_EQ(scalable_type.bits(), 1);
+  ASSERT_EQ(scalable_type.vscale_factor(), 4);
+  ASSERT_TRUE(scalable_type.is_scalable_vector());
+}
+
+TEST(ScalableDataType, TestScalableUInt) {
+  tvm::DataType scalable_type = tvm::DataType::UInt(1, 4, true);
+  ASSERT_EQ(scalable_type.code(), kDLUInt);
+  ASSERT_EQ(scalable_type.bits(), 1);
+  ASSERT_EQ(scalable_type.vscale_factor(), 4);
+  ASSERT_TRUE(scalable_type.is_scalable_vector());
+}
+
 // -----------
 // Integration
 // -----------
diff --git a/tests/python/codegen/test_target_codegen_aarch64.py 
b/tests/python/codegen/test_target_codegen_aarch64.py
index 4e75f916d9..773c113f4a 100644
--- a/tests/python/codegen/test_target_codegen_aarch64.py
+++ b/tests/python/codegen/test_target_codegen_aarch64.py
@@ -492,5 +492,46 @@ def test_codegen_vscale():
     assert re.findall(r"llvm.vscale.i32", llvm), "No vscale in generated LLVM."
 
 
+@pytest.mark.skipif(
+    llvm_version_major() < 11, reason="Vscale is not supported in earlier 
versions of LLVM"
+)
+def test_scalable_buffer_load_store():
+    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+
+    @T.prim_func
+    def my_func(a: T.handle, b: T.handle):
+        A = T.match_buffer(a, (128,), "float32")
+        B = T.match_buffer(b, (128,), "float32")
+        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+        B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())]
+
+    mod = tvm.build(my_func, target=target)
+    llvm = mod.get_source("ll")
+
+    assert re.findall(r"load <vscale x 4 x float>", llvm), "No scalable load 
in generated LLVM."
+    assert re.findall(r" store <vscale x 4 x float>", llvm), "No scalable 
store in generated LLVM."
+
+
+@pytest.mark.skipif(
+    llvm_version_major() < 11, reason="Vscale is not supported in earlier 
versions of LLVM"
+)
+def test_scalable_broadcast():
+    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+
+    @T.prim_func
+    def my_func(a: T.handle):
+        A = T.match_buffer(a, (128,), "float32")
+        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+        A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(1, 4 * T.vscale())
+
+    mod = tvm.build(my_func, target=target)
+    llvm = mod.get_source("ll")
+
+    assert re.findall(
+        r"shufflevector \(<vscale x 4 x float> insertelement \(<vscale x 4 x 
float>", llvm
+    ), "No scalable broadcast in generated LLVM."
+    assert re.findall(r" store <vscale x 4 x float>", llvm), "No scalable 
store in generated LLVM."
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/target/test_arm_target.py 
b/tests/python/target/test_arm_target.py
index dc8452710a..158d941073 100644
--- a/tests/python/target/test_arm_target.py
+++ b/tests/python/target/test_arm_target.py
@@ -14,9 +14,16 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+import subprocess
+import tempfile
+import re
+
 import pytest
+import numpy as np
 
 import tvm
+from tvm.script import tir as T
 from tvm.topi.arm_cpu.conv2d_int8 import is_int8_hw_support
 from tvm.target import codegen
 
@@ -61,3 +68,121 @@ def test_arm_conv2d_int8_support(
     with tvm.target.Target(arm_target):
         monkeypatch.setattr(codegen, "llvm_version_major", lambda: 
llvm_version)
         assert is_int8_hw_support(input_dtype, kernel_dtype) == is_supported
+
+
+@pytest.fixture(scope="session")
+def sve_device_vector_length():
+    c_code = r"""
+    #include <stdio.h>
+    #include <arm_sve.h>
+
+    int main() {
+        printf("%ld\n", svcntb() * 8);
+    }
+    """
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        c_path = f"{tmp_dir}/vl.c"
+        o_path = f"{tmp_dir}/out.o"
+        with open(c_path, "w") as f:
+            f.write(c_code)
+        tvm.contrib.cc.create_executable(o_path, c_path, ["-march=native"])
+        out = subprocess.check_output(o_path, shell=True).strip().decode()
+
+    return int(out)
+
+
+@tvm.testing.requires_aarch64_sve
+def test_scalable_div(sve_device_vector_length):
+    np.random.seed(0)
+    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+    dev = tvm.cpu(0)
+
+    @T.prim_func
+    def my_func(a: T.handle):
+        A = T.match_buffer(a, (1,), "int32")
+        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+        A[0] = T.Div(10000, 4 * T.vscale())
+
+    mod = tvm.build(my_func, target=target)
+
+    A_nd = tvm.nd.array(np.empty((1,), dtype="int32"), device=dev)
+    mod(A_nd)
+
+    ref = 10000 // (sve_device_vector_length // 32)
+    tvm.testing.assert_allclose(A_nd.numpy()[0], ref)
+
+
+@tvm.testing.requires_aarch64_sve
+def test_scalable_buffer_load_store(sve_device_vector_length):
+    np.random.seed(0)
+    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+    num_elements = sve_device_vector_length // 32
+    dev = tvm.cpu(0)
+
+    @T.prim_func
+    def my_func(a: T.handle, b: T.handle):
+        A = T.match_buffer(a, (num_elements,), "float32")
+        B = T.match_buffer(b, (num_elements,), "float32")
+        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+        B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())]
+
+    mod = tvm.build(my_func, target=target)
+
+    A_np = np.random.uniform(size=(num_elements,)).astype("float32")
+    B_np = np.zeros((num_elements,)).astype("float32")
+    A_nd = tvm.nd.array(A_np, device=dev)
+    B_nd = tvm.nd.array(B_np, device=dev)
+    mod(A_nd, B_nd)
+
+    tvm.testing.assert_allclose(B_nd.numpy(), A_np)
+
+
+@tvm.testing.requires_aarch64_sve
+def test_scalable_loop_bound(sve_device_vector_length):
+    np.random.seed(0)
+
+    dtype = "float32"
+    num_elements = sve_device_vector_length // 32
+    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+    dev = tvm.cpu(0)
+
+    @T.prim_func
+    def my_func(a: T.handle, b: T.handle):
+        A = T.match_buffer(a, (num_elements,), "float32")
+        B = T.match_buffer(b, (num_elements,), "float32")
+        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+        for i in T.serial(0, 4 * T.vscale()):
+            B[i] = A[i]
+
+    mod = tvm.build(my_func, target=target)
+
+    A_np = np.random.uniform(size=(num_elements,)).astype(dtype)
+    B_np = np.zeros((num_elements,)).astype(dtype)
+    A_nd = tvm.nd.array(A_np, device=dev)
+    B_nd = tvm.nd.array(B_np, device=dev)
+    mod(A_nd, B_nd)
+
+    tvm.testing.assert_allclose(B_nd.numpy(), A_np)
+
+
+@tvm.testing.requires_aarch64_sve
+def test_scalable_broadcast(sve_device_vector_length):
+    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+    num_elements = sve_device_vector_length // 32
+    dev = tvm.cpu(0)
+
+    @T.prim_func
+    def my_func(a: T.handle):
+        A = T.match_buffer(a, (num_elements,), "float32")
+        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+        A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(1, 4 * T.vscale())
+
+    mod = tvm.build(my_func, target=target)
+
+    A_np = np.zeros((num_elements,)).astype("float32")
+    A_nd = tvm.nd.array(A_np, device=dev)
+    mod(A_nd)
+
+    ref = np.ones((num_elements,))
+    tvm.testing.assert_allclose(A_nd.numpy(), ref)

Reply via email to