This is an automated email from the ASF dual-hosted git repository.
ekalda pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new af0c038f2e [SVE] Add codegen support for scalable buffer accesses (#16696)
af0c038f2e is described below
commit af0c038f2ec36d1762e7f500bb000d945b01e326
Author: Luke Hutton <[email protected]>
AuthorDate: Thu Mar 14 11:48:21 2024 +0000
[SVE] Add codegen support for scalable buffer accesses (#16696)
This commit adds support for generating code for scalable loads and
stores. It also adds support for the creation of scalable broadcast
operations.
Co-authored-by: Elen Kalda <[email protected]>
Co-authored-by: Neil Hickey <[email protected]>
---
include/tvm/runtime/data_type.h | 16 ++-
python/tvm/testing/utils.py | 7 ++
src/target/llvm/codegen_llvm.cc | 66 ++++++-----
src/target/llvm/codegen_llvm.h | 1 -
src/tir/ir/data_type_rewriter.cc | 2 +-
src/tir/ir/expr.cc | 7 +-
src/tir/transforms/storage_rewrite.cc | 7 ++
tests/cpp/tir_scalable_datatype.cc | 16 +++
.../python/codegen/test_target_codegen_aarch64.py | 41 +++++++
tests/python/target/test_arm_target.py | 125 +++++++++++++++++++++
10 files changed, 249 insertions(+), 39 deletions(-)
diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index f6a7d424ed..8f3ae9b424 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -110,6 +110,8 @@ class DataType {
}
return -lanes_as_int;
}
+ /*! \return the vscale factor for a scalable vector type, otherwise the number of lanes. */
+ int get_lanes_or_vscale_factor() const { return is_scalable_vector() ? vscale_factor() : lanes(); }
/*! \return whether type is a scalar type. */
bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; }
/*! \return whether type is a scalar type. */
@@ -211,10 +213,13 @@ class DataType {
/*!
* \brief Construct an uint type.
* \param bits The number of bits in the type.
- * \param lanes The number of lanes
+ * \param lanes The number of lanes.
+ * \param is_scalable Whether the data type is scalable.
* \return The constructed data type.
*/
- static DataType UInt(int bits, int lanes = 1) { return DataType(kDLUInt, bits, lanes); }
+ static DataType UInt(int bits, int lanes = 1, bool is_scalable = false) {
+ return DataType(kDLUInt, bits, lanes, is_scalable);
+ }
/*!
* \brief Construct an float type.
* \param bits The number of bits in the type.
@@ -243,10 +248,13 @@ class DataType {
static DataType NVFloat8E5M2(int lanes = 1) { return DataType(kE5M2Float, 8, lanes); }
/*!
* \brief Construct a bool type.
- * \param lanes The number of lanes
+ * \param lanes The number of lanes.
+ * \param is_scalable Whether the data type is scalable.
* \return The constructed data type.
*/
- static DataType Bool(int lanes = 1) { return DataType::UInt(1, lanes); }
+ static DataType Bool(int lanes = 1, bool is_scalable = false) {
+ return DataType::UInt(1, lanes, is_scalable);
+ }
/*!
* \brief Construct a handle type.
* \param bits The number of bits in the type.
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 6e23a84bc2..e1b1c65457 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -1045,6 +1045,13 @@ requires_arm_dot = Feature(
)
+requires_aarch64_sve = Feature(
+ "arm_sve",
+ "AArch64 SVE",
+ run_time_check=lambda: _has_cpu_feat("sve"),
+)
+
+
requires_x86_vnni = Feature(
"x86_vnni",
"x86 VNNI Extensions",
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index eae26e5cac..bba1488274 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -587,10 +587,17 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
LOG(FATAL) << "do not support " << dtype;
}
}
- if (dtype.lanes() != 1) {
+ if (!dtype.is_scalar()) {
#if TVM_LLVM_VERSION >= 110
- return llvm::FixedVectorType::get(etype, dtype.lanes());
+ if (dtype.is_scalable_vector()) {
+ return llvm::VectorType::get(etype, dtype.vscale_factor(), true);
+ } else {
+ return llvm::FixedVectorType::get(etype, dtype.lanes());
+ }
#else
+ ICHECK(!dtype.is_scalable_vector())
+ << "Versions of LLVM < 11 do not support scalable vectors. Please upgrade to a later "
+ "version.";
return llvm::VectorType::get(etype, dtype.lanes());
#endif
} else {
@@ -749,26 +756,6 @@ std::unique_ptr<CodeGenLLVM::DebugInfo> CodeGenLLVM::CreateDebugInfo(llvm::Modul
return debug_info;
}
-llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
-#if TVM_LLVM_VERSION >= 110
- llvm::Type* type = llvm::FixedVectorType::get(value->getType(), lanes);
-#else
- llvm::Type* type = llvm::VectorType::get(value->getType(), lanes);
-#endif
- llvm::Constant* undef = llvm::UndefValue::get(type);
- llvm::Constant* zero = ConstInt32(0);
- value = builder_->CreateInsertElement(undef, value, zero);
-#if TVM_LLVM_VERSION >= 120
- llvm::Constant* mask = llvm::ConstantVector::getSplat(llvm::ElementCount::getFixed(lanes), zero);
-#elif TVM_LLVM_VERSION >= 110
- llvm::Constant* mask =
- llvm::ConstantVector::getSplat(llvm::ElementCount(lanes, /*Scalable=*/false), zero);
-#else
- llvm::Constant* mask = llvm::ConstantVector::getSplat(lanes, zero);
-#endif
- return builder_->CreateShuffleVector(value, undef, mask);
-}
-
llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
int num_elems = GetVectorNumElements(vec);
if (extent == num_elems && begin == 0) return vec;
@@ -1693,7 +1680,8 @@ void CodeGenLLVM::BufferAccessHelper(
}
PrimExpr last_index = indices[indices.size() - 1];
- ICHECK_EQ(value_dtype.lanes(), last_index.dtype().lanes() * buffer_element_dtype.lanes());
+ ICHECK_EQ(value_dtype.get_lanes_or_vscale_factor(),
+ last_index.dtype().get_lanes_or_vscale_factor() * buffer_element_dtype.lanes());
// Record index and elemtype in original form used for alias info
PrimExpr last_index_origin = last_index;
@@ -1736,8 +1724,6 @@ void CodeGenLLVM::BufferAccessHelper(
llvm::Value* last_index_value;
int subelement_i = i;
if (const RampNode* ramp = last_index.as<RampNode>()) {
- // TODO(ekalda): P4 in https://github.com/apache/tvm/issues/16455
- ICHECK(!last_index.dtype().is_scalable_vector());
PrimExpr offset = ramp->base + (ramp->stride * i);
last_index_value = MakeValue(offset);
} else if (last_index.dtype().lanes() > 1) {
@@ -1754,8 +1740,13 @@ void CodeGenLLVM::BufferAccessHelper(
all_index_values.push_back(last_index_value);
TypedPointer buffer_ptr =
- CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, all_index_values,
- value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes()));
+ value_dtype.is_scalable_vector()
+ ? CreateBufferPtr(MakeValue(buffer->data), buffer_element_dtype, all_index_values,
+ value_dtype.with_scalable_vscale_factor(value_dtype.vscale_factor() /
+ last_index.dtype().lanes()))
+ : CreateBufferPtr(
+ MakeValue(buffer->data), buffer_element_dtype, all_index_values,
+ value_dtype.with_lanes(value_dtype.lanes() / last_index.dtype().lanes()));
auto instruction = make_instruction(buffer_ptr, subelement_i, alignment, is_volatile);
AddAliasInfo(instruction, buffer->data.get(), last_index_origin, buffer_element_dtype_origin);
}
@@ -1870,10 +1861,23 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) {
}
llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) {
- // TODO(ekalda): P4 in https://github.com/apache/tvm/issues/16455
- ICHECK(!op->dtype.is_scalable_vector());
- int lanes = op->dtype.lanes();
- return CreateBroadcast(MakeValue(op->value), lanes);
+ DataType dtype = op->dtype;
+ llvm::Value* value = MakeValue(op->value);
+ llvm::Type* type = DTypeToLLVMType(dtype);
+ llvm::Constant* undef = llvm::UndefValue::get(type);
+ llvm::Constant* zero = ConstInt32(0);
+ value = builder_->CreateInsertElement(undef, value, zero);
+#if TVM_LLVM_VERSION >= 110
+ llvm::ElementCount ec =
+ llvm::ElementCount::get(dtype.get_lanes_or_vscale_factor(), dtype.is_scalable_vector());
+ llvm::Constant* mask = llvm::ConstantVector::getSplat(ec, zero);
+#else
+ ICHECK(!dtype.is_scalable_vector())
+ << "Versions of LLVM < 11 do not support scalable vectors. Please upgrade to a later "
+ "version.";
+ llvm::Constant* mask = llvm::ConstantVector::getSplat(dtype.lanes(), zero);
+#endif
+ return builder_->CreateShuffleVector(value, undef, mask);
}
void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) {
diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h
index 2efac03073..0f7aa847ec 100644
--- a/src/target/llvm/codegen_llvm.h
+++ b/src/target/llvm/codegen_llvm.h
@@ -468,7 +468,6 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
llvm::Value* CreateAdd(DataType t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateSub(DataType t, llvm::Value* a, llvm::Value* b);
llvm::Value* CreateMul(DataType t, llvm::Value* a, llvm::Value* b);
- llvm::Value* CreateBroadcast(llvm::Value* value, int lanes);
virtual TypedPointer CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype,
llvm::ArrayRef<llvm::Value*> indices, DataType value_dtype);
// Vector concatenation.
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index 2bd1e06083..2d2c097be4 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -451,7 +451,7 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BufferStoreNode* op) {
Buffer new_buffer = GetRemappedBuffer(op->buffer);
auto value = this->VisitExpr(op->value);
- if (new_buffer->dtype != value->dtype && value->dtype.lanes() == 1) {
+ if (new_buffer->dtype != value->dtype && value->dtype.is_scalar()) {
value = cast(new_buffer->dtype, value);
}
auto indices = VisitIndices(op->indices);
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 1b611d4534..c2baad2096 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -58,7 +58,9 @@ namespace tir {
CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types. " << a.dtype() << " vs. " \
<< b.dtype() << "\n"; \
ObjectPtr<T> node = make_object<T>(); \
- node->dtype = DataType::Bool(a.dtype().lanes()); \
+ DataType a_dtype = a.dtype(); \
+ node->dtype = \
+ DataType::Bool(a_dtype.get_lanes_or_vscale_factor(), a_dtype.is_scalable_vector()); \
node->a = std::move(a); \
node->b = std::move(b); \
node->span = std::move(span); \
@@ -393,7 +395,8 @@ Not::Not(PrimExpr a, Span span) {
ICHECK(a.dtype().is_bool());
ObjectPtr<NotNode> node = make_object<NotNode>();
- node->dtype = DataType::Bool(a.dtype().lanes());
+ DataType a_dtype = a.dtype();
+ node->dtype = DataType::Bool(a_dtype.get_lanes_or_vscale_factor(), a_dtype.is_scalable_vector());
node->a = std::move(a);
node->span = std::move(span);
data_ = std::move(node);
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index e40f683e21..3f34f2e870 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -1275,6 +1275,13 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
auto it = info_map_.find(buffer);
ICHECK(it != info_map_.end()) << "Load/Store of buffer " << buffer->name_hint << " (" << buffer
<< ") occurred before its declaration.";
+
+ if (value_dtype.is_scalable_vector()) {
+ // Scalable types are not currently supported in storage_rewrite. Scalable buffer
+ // accesses are not currently checked and therefore are not rewritten.
+ return;
+ }
+
BufferVarInfo& var_info = it->second;
if (value_dtype.element_of() == DataType::Bool()) {
diff --git a/tests/cpp/tir_scalable_datatype.cc b/tests/cpp/tir_scalable_datatype.cc
index 23decef69e..4b4764555f 100644
--- a/tests/cpp/tir_scalable_datatype.cc
+++ b/tests/cpp/tir_scalable_datatype.cc
@@ -162,6 +162,22 @@ TEST(ScalableDataType, TestScalableDataTypeInvalidLanesAccess) {
tvm::InternalError);
}
+TEST(ScalableDataType, TestScalableBool) {
+ tvm::DataType scalable_type = tvm::DataType::Bool(4, true);
+ ASSERT_EQ(scalable_type.code(), kDLUInt);
+ ASSERT_EQ(scalable_type.bits(), 1);
+ ASSERT_EQ(scalable_type.vscale_factor(), 4);
+ ASSERT_TRUE(scalable_type.is_scalable_vector());
+}
+
+TEST(ScalableDataType, TestScalableUInt) {
+ tvm::DataType scalable_type = tvm::DataType::UInt(32, 4, true);
+ ASSERT_EQ(scalable_type.code(), kDLUInt);
+ ASSERT_EQ(scalable_type.bits(), 32);
+ ASSERT_EQ(scalable_type.vscale_factor(), 4);
+ ASSERT_TRUE(scalable_type.is_scalable_vector());
+}
+
// -----------
// Integration
// -----------
diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py
index 4e75f916d9..773c113f4a 100644
--- a/tests/python/codegen/test_target_codegen_aarch64.py
+++ b/tests/python/codegen/test_target_codegen_aarch64.py
@@ -492,5 +492,46 @@ def test_codegen_vscale():
assert re.findall(r"llvm.vscale.i32", llvm), "No vscale in generated LLVM."
+@pytest.mark.skipif(
+ llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
+)
+def test_scalable_buffer_load_store():
+ target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+
+ @T.prim_func
+ def my_func(a: T.handle, b: T.handle):
+ A = T.match_buffer(a, (128,), "float32")
+ B = T.match_buffer(b, (128,), "float32")
+ T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+ B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())]
+
+ mod = tvm.build(my_func, target=target)
+ llvm = mod.get_source("ll")
+
+ assert re.findall(r"load <vscale x 4 x float>", llvm), "No scalable load in generated LLVM."
+ assert re.findall(r" store <vscale x 4 x float>", llvm), "No scalable store in generated LLVM."
+
+
+@pytest.mark.skipif(
+ llvm_version_major() < 11, reason="Vscale is not supported in earlier versions of LLVM"
+)
+def test_scalable_broadcast():
+ target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+
+ @T.prim_func
+ def my_func(a: T.handle):
+ A = T.match_buffer(a, (128,), "float32")
+ T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+ A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(1, 4 * T.vscale())
+
+ mod = tvm.build(my_func, target=target)
+ llvm = mod.get_source("ll")
+
+ assert re.findall(
+ r"shufflevector \(<vscale x 4 x float> insertelement \(<vscale x 4 x float>", llvm
+ ), "No scalable broadcast in generated LLVM."
+ assert re.findall(r" store <vscale x 4 x float>", llvm), "No scalable store in generated LLVM."
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/target/test_arm_target.py b/tests/python/target/test_arm_target.py
index dc8452710a..158d941073 100644
--- a/tests/python/target/test_arm_target.py
+++ b/tests/python/target/test_arm_target.py
@@ -14,9 +14,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
+import subprocess
+import tempfile
+import re
+
import pytest
+import numpy as np
import tvm
+from tvm.script import tir as T
from tvm.topi.arm_cpu.conv2d_int8 import is_int8_hw_support
from tvm.target import codegen
@@ -61,3 +68,121 @@ def test_arm_conv2d_int8_support(
with tvm.target.Target(arm_target):
monkeypatch.setattr(codegen, "llvm_version_major", lambda: llvm_version)
assert is_int8_hw_support(input_dtype, kernel_dtype) == is_supported
+
+
+@pytest.fixture(scope="session")
+def sve_device_vector_length():
+ c_code = r"""
+ #include <stdio.h>
+ #include <arm_sve.h>
+
+ int main() {
+ printf("%ld\n", svcntb() * 8);
+ }
+ """
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ c_path = f"{tmp_dir}/vl.c"
+ o_path = f"{tmp_dir}/out.o"
+ with open(c_path, "w") as f:
+ f.write(c_code)
+ tvm.contrib.cc.create_executable(o_path, c_path, ["-march=native"])
+ out = subprocess.check_output(o_path, shell=True).strip().decode()
+
+ return int(out)
+
+
+@tvm.testing.requires_aarch64_sve
+def test_scalable_div(sve_device_vector_length):
+ np.random.seed(0)
+ target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+ dev = tvm.cpu(0)
+
+ @T.prim_func
+ def my_func(a: T.handle):
+ A = T.match_buffer(a, (1,), "int32")
+ T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+ A[0] = T.Div(10000, 4 * T.vscale())
+
+ mod = tvm.build(my_func, target=target)
+
+ A_nd = tvm.nd.array(np.empty((1,), dtype="int32"), device=dev)
+ mod(A_nd)
+
+ ref = 10000 // (sve_device_vector_length // 32)
+ tvm.testing.assert_allclose(A_nd.numpy()[0], ref)
+
+
+@tvm.testing.requires_aarch64_sve
+def test_scalable_buffer_load_store(sve_device_vector_length):
+ np.random.seed(0)
+ target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+ num_elements = sve_device_vector_length // 32
+ dev = tvm.cpu(0)
+
+ @T.prim_func
+ def my_func(a: T.handle, b: T.handle):
+ A = T.match_buffer(a, (num_elements,), "float32")
+ B = T.match_buffer(b, (num_elements,), "float32")
+ T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+ B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())]
+
+ mod = tvm.build(my_func, target=target)
+
+ A_np = np.random.uniform(size=(num_elements,)).astype("float32")
+ B_np = np.zeros((num_elements,)).astype("float32")
+ A_nd = tvm.nd.array(A_np, device=dev)
+ B_nd = tvm.nd.array(B_np, device=dev)
+ mod(A_nd, B_nd)
+
+ tvm.testing.assert_allclose(B_nd.numpy(), A_np)
+
+
+@tvm.testing.requires_aarch64_sve
+def test_scalable_loop_bound(sve_device_vector_length):
+ np.random.seed(0)
+
+ dtype = "float32"
+ num_elements = sve_device_vector_length // 32
+ target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+ dev = tvm.cpu(0)
+
+ @T.prim_func
+ def my_func(a: T.handle, b: T.handle):
+ A = T.match_buffer(a, (num_elements,), "float32")
+ B = T.match_buffer(b, (num_elements,), "float32")
+ T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+ for i in T.serial(0, 4 * T.vscale()):
+ B[i] = A[i]
+
+ mod = tvm.build(my_func, target=target)
+
+ A_np = np.random.uniform(size=(num_elements,)).astype(dtype)
+ B_np = np.zeros((num_elements,)).astype(dtype)
+ A_nd = tvm.nd.array(A_np, device=dev)
+ B_nd = tvm.nd.array(B_np, device=dev)
+ mod(A_nd, B_nd)
+
+ tvm.testing.assert_allclose(B_nd.numpy(), A_np)
+
+
+@tvm.testing.requires_aarch64_sve
+def test_scalable_broadcast(sve_device_vector_length):
+ target = "llvm -mtriple=aarch64-linux-gnu -mattr=+sve"
+ num_elements = sve_device_vector_length // 32
+ dev = tvm.cpu(0)
+
+ @T.prim_func
+ def my_func(a: T.handle):
+ A = T.match_buffer(a, (num_elements,), "float32")
+ T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+ A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(1, 4 * T.vscale())
+
+ mod = tvm.build(my_func, target=target)
+
+ A_np = np.zeros((num_elements,)).astype("float32")
+ A_nd = tvm.nd.array(A_np, device=dev)
+ mod(A_nd)
+
+ ref = np.ones((num_elements,))
+ tvm.testing.assert_allclose(A_nd.numpy(), ref)