This is an automated email from the ASF dual-hosted git repository.
ekalda pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 563ef9587c [SVE] Add support for scalable data type strings (#16612)
563ef9587c is described below
commit 563ef9587cfa913cf96f9ec061cdab43ce744b70
Author: Luke Hutton <[email protected]>
AuthorDate: Tue Feb 27 09:24:36 2024 +0000
[SVE] Add support for scalable data type strings (#16612)
This commit adds support for representing scalable vectors using the
string data type format. For example, "float32xvscalex4" may be used
to represent the following scalable type:
`DataType(kDLFloat, 32, /*lanes=*/4, /*is_scalable=*/true)`.
---------
Co-authored-by: Elen Kalda <[email protected]>
Co-authored-by: Neil Hickey <[email protected]>
---
include/tvm/runtime/data_type.h | 17 +++--
python/tvm/_ffi/runtime_ctypes.py | 11 +++-
src/tir/op/op.cc | 2 +-
tests/cpp/tir_scalable_datatype.cc | 76 +++++++++++++++++++---
tests/python/tir-base/test_tir_nodes.py | 15 +----
.../python/tir-base/test_tir_scalable_datatype.py | 60 +++++++++++++++++
6 files changed, 153 insertions(+), 28 deletions(-)
diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index 5efa5f3b90..f6a7d424ed 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -27,6 +27,7 @@
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/logging.h>
+#include <cstring>
#include <string>
#include <type_traits>
@@ -110,7 +111,7 @@ class DataType {
return -lanes_as_int;
}
/*! \return whether type is a scalar type. */
- bool is_scalar() const { return lanes() == 1; }
+ bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; }
/*! \return whether type is a scalar type. */
bool is_bool() const { return code() == DataType::kUInt && bits() == 1; }
/*! \return whether type is a float type. */
@@ -389,9 +390,12 @@ inline std::ostream& operator<<(std::ostream& os,
DLDataType t) { // NOLINT(*)
os << "custom[" << GetCustomTypeName(t.code) << "]";
}
if (t.code == kTVMOpaqueHandle) return os;
+ int16_t lanes = static_cast<int16_t>(t.lanes);
os << static_cast<int>(t.bits);
- if (t.lanes != 1) {
- os << 'x' << static_cast<int>(t.lanes);
+ if (lanes > 1) {
+ os << 'x' << lanes;
+ } else if (lanes < -1) {
+ os << "xvscalex" << -lanes;
}
return os;
}
@@ -456,9 +460,14 @@ inline DLDataType String2DLDataType(std::string s) {
char* xdelim; // emulate sscanf("%ux%u", bits, lanes)
uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
if (bits != 0) t.bits = bits;
+ int scalable_multiplier = 1;
+ if (strncmp(xdelim, "xvscale", 7) == 0) {
+ scalable_multiplier = -1;
+ xdelim += 7;
+ }
char* endpt = xdelim;
if (*xdelim == 'x') {
- t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, &endpt, 10));
+ t.lanes = static_cast<uint16_t>(scalable_multiplier * strtoul(xdelim + 1,
&endpt, 10));
}
ICHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
return t;
diff --git a/python/tvm/_ffi/runtime_ctypes.py
b/python/tvm/_ffi/runtime_ctypes.py
index 54e4d8f205..06f2d4c7e6 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -135,7 +135,11 @@ class DataType(ctypes.Structure):
arr = type_str.split("x")
head = arr[0]
- self.lanes = int(arr[1]) if len(arr) > 1 else 1
+ if len(arr) == 3:
+ assert arr[1] == "vscale", f"Invalid data type. Expected 'vscale'
but got '{arr[1]}'"
+ self.lanes = ctypes.c_uint16(-int(arr[2]))
+ elif len(arr) > 1:
+ self.lanes = ctypes.c_uint16(int(arr[1]))
bits = 32
if head.startswith("int"):
@@ -188,8 +192,11 @@ class DataType(ctypes.Structure):
type_name = "custom[%s]" %
tvm.runtime._ffi_api._datatype_get_type_name(self.type_code)
x = "%s%d" % (type_name, self.bits)
- if self.lanes != 1:
+ lanes_as_int = ctypes.c_int16(self.lanes).value
+ if lanes_as_int > 1:
x += "x%d" % self.lanes
+ elif lanes_as_int < -1:
+ x += "xvscalex%d" % -lanes_as_int
return x
def __eq__(self, other):
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index b329d25b54..c46a8c2643 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -342,7 +342,7 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span)
{
using tir::FloatImmNode;
if (value.dtype() == t) return value;
// const fold IntImm as they are used in index computations
- if (t.lanes() == 1) {
+ if (t.is_scalar()) {
if (const IntImmNode* op = value.as<IntImmNode>()) {
return make_const(t, op->value, op->span);
} else if (const FloatImmNode* op = value.as<FloatImmNode>()) {
diff --git a/tests/cpp/tir_scalable_datatype.cc
b/tests/cpp/tir_scalable_datatype.cc
index daa4dfe729..23decef69e 100644
--- a/tests/cpp/tir_scalable_datatype.cc
+++ b/tests/cpp/tir_scalable_datatype.cc
@@ -24,12 +24,14 @@
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
+#include "../../src/script/printer/utils.h"
+
using ::testing::HasSubstr;
// ---------
// Data Type
// ---------
-TEST(TIR, TestCreateScalableType) {
+TEST(ScalableDataType, TestCreateScalableType) {
tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true);
ASSERT_EQ(scalable_type.code(), kDLInt);
ASSERT_EQ(scalable_type.bits(), 32);
@@ -38,7 +40,7 @@ TEST(TIR, TestCreateScalableType) {
ASSERT_TRUE(scalable_type.is_scalable_or_fixed_length_vector());
}
-TEST(TIR, TestScalableWithBits) {
+TEST(ScalableDataType, TestScalableWithBits) {
tvm::DataType scalable_type = tvm::DataType(kDLInt, 1, 8, true);
scalable_type = scalable_type.with_bits(32);
ASSERT_EQ(scalable_type.bits(), 32);
@@ -46,7 +48,7 @@ TEST(TIR, TestScalableWithBits) {
ASSERT_TRUE(scalable_type.is_scalable_or_fixed_length_vector());
}
-TEST(TIR, TestScalableWithVscaleFactor) {
+TEST(ScalableDataType, TestScalableWithVscaleFactor) {
tvm::DataType type = tvm::DataType(kDLInt, 32, 1);
tvm::DataType scalable_type = type.with_scalable_vscale_factor(4);
ASSERT_EQ(scalable_type.vscale_factor(), 4);
@@ -54,18 +56,54 @@ TEST(TIR, TestScalableWithVscaleFactor) {
ASSERT_TRUE(scalable_type.is_scalable_or_fixed_length_vector());
}
-TEST(TIR, TestAssignScalableDataType) {
+TEST(ScalableDataType, TestAssignScalableDataType) {
tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 2, true);
tvm::DataType scalable_type_copy = scalable_type;
ASSERT_TRUE(scalable_type_copy.is_scalable_vector());
ASSERT_TRUE(scalable_type_copy.is_scalable_or_fixed_length_vector());
}
-TEST(TIR, TestScalableDataTypeAndNonScalableDataTypeInequality) {
+TEST(ScalableDataType, TestScalableDataTypeEquality) {
+ ASSERT_TRUE(tvm::DataType(kDLInt, 32, 4, true) == tvm::DataType(kDLInt, 32,
4, true));
+}
+
+TEST(ScalableDataType, TestScalableDataTypeAndNonScalableDataTypeInequality) {
ASSERT_FALSE(tvm::DataType(kDLInt, 32, 4, true) == tvm::DataType(kDLInt, 32,
4));
}
-TEST(TIR, TestGetScalableVectorBytesError) {
+TEST(ScalableDataType, TestIsScalar) {
+ ASSERT_FALSE(tvm::DataType(kDLInt, 32, 4, true).is_scalar());
+ ASSERT_TRUE(tvm::DataType(kDLInt, 32, 1, false).is_scalar());
+ ASSERT_FALSE(tvm::DataType(kDLInt, 32, 4, false).is_scalar());
+ ASSERT_FALSE(tvm::DataType(kDLOpaqueHandle, 1, 0, false).is_scalar());
+}
+
+TEST(ScalableDataType, TestScalableDataTypeToString) {
+ tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true);
+ EXPECT_EQ(tvm::runtime::DLDataType2String(scalable_type), "int32xvscalex4");
+}
+
+TEST(ScalableDataType, TestStringToScalableDataType) {
+ std::string scalable_type_str = "int32xvscalex4";
+ EXPECT_EQ(tvm::DataType(tvm::runtime::String2DLDataType(scalable_type_str)),
+ tvm::DataType(kDLInt, 32, 4, true));
+}
+
+TEST(ScalableDataType, TestInvalidStringToScalableDataType) {
+ std::string scalable_type_str = "int32x4xvscale";
+ EXPECT_THROW(
+ {
+ try {
+ tvm::runtime::String2DLDataType(scalable_type_str);
+ } catch (const tvm::InternalError& e) {
+ EXPECT_THAT(e.what(), HasSubstr("unknown type int32x4xvscale"));
+ throw;
+ }
+ },
+ tvm::InternalError);
+}
+
+TEST(ScalableDataType, TestGetScalableVectorBytes) {
tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true);
EXPECT_THROW(
{
@@ -80,7 +118,7 @@ TEST(TIR, TestGetScalableVectorBytesError) {
tvm::InternalError);
}
-TEST(TIR, TestScalableDataTypeInvalidLanesError) {
+TEST(ScalableDataType, TestScalableDataTypeInvalidLanesError) {
EXPECT_THROW(
{
try {
@@ -93,7 +131,7 @@ TEST(TIR, TestScalableDataTypeInvalidLanesError) {
tvm::InternalError);
}
-TEST(TIR, TestScalableDataTypeInvalidVscaleFactorAccess) {
+TEST(ScalableDataType, TestScalableDataTypeInvalidVscaleFactorAccess) {
tvm::DataType fixed_length_type = tvm::DataType(kDLFloat, 32, 4);
ASSERT_TRUE(fixed_length_type.is_fixed_length_vector());
ASSERT_TRUE(fixed_length_type.is_scalable_or_fixed_length_vector());
@@ -109,7 +147,7 @@ TEST(TIR, TestScalableDataTypeInvalidVscaleFactorAccess) {
tvm::InternalError);
}
-TEST(TIR, TestScalableDataTypeInvalidLanesAccess) {
+TEST(ScalableDataType, TestScalableDataTypeInvalidLanesAccess) {
tvm::DataType scalable_type = tvm::DataType(kDLFloat, 32, 4, true);
EXPECT_THROW(
{
@@ -123,3 +161,23 @@ TEST(TIR, TestScalableDataTypeInvalidLanesAccess) {
},
tvm::InternalError);
}
+
+// -----------
+// Integration
+// -----------
+#if TVM_LLVM_VERSION >= 130
+TEST(ScalableDataType, TestScalableIntrinCall) {
+ tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true);
+ tvm::tir::Call call = tvm::tir::Call(
+ scalable_type, tvm::tir::builtin::call_llvm_intrin(),
+ {tvm::IntImm(tvm::DataType::Int(32),
::llvm::Intrinsic::experimental_stepvector)});
+ ASSERT_EQ(call->dtype, scalable_type);
+ ASSERT_EQ(call->Script(),
+ "T.call_llvm_intrin(\"int32xvscalex4\",
\"llvm.experimental.stepvector\")");
+}
+#endif
+
+TEST(ScalableDataType, TestTIRScriptScalableDtype2Str) {
+ tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true);
+ ASSERT_EQ(tvm::script::printer::DType2Str(scalable_type), "int32xvscalex4");
+}
diff --git a/tests/python/tir-base/test_tir_nodes.py
b/tests/python/tir-base/test_tir_nodes.py
index 5b55c432b0..f3498f8ec7 100644
--- a/tests/python/tir-base/test_tir_nodes.py
+++ b/tests/python/tir-base/test_tir_nodes.py
@@ -439,21 +439,15 @@ def test_broadcast_to_scalable_vec():
assert broadcast.lanes.b == 4
[email protected](
- reason="Support for scalable data type string will be added in P3 of
https://github.com/apache/tvm/issues/16455"
-)
def test_buffer_load_scalable_vec():
buf = tvm.tir.decl_buffer((24,), "float32")
index = tvm.tir.expr.Ramp(1, 1, 8 * tvm.tir.vscale())
load = tvm.tir.BufferLoad(buf, [index])
assert isinstance(load, tvm.tir.BufferLoad)
- assert load.dtype == "float32x8xvscale"
+ assert load.dtype == "float32xvscalex8"
[email protected](
- reason="Support for scalable data type string will be added in P3 of
https://github.com/apache/tvm/issues/16455"
-)
def test_buffer_store_scalable_vec():
b = tvm.tir.decl_buffer((24,), "int32")
value = tvm.tir.expr.Broadcast(1, 4 * tvm.tir.vscale())
@@ -461,15 +455,12 @@ def test_buffer_store_scalable_vec():
store = tvm.tir.BufferStore(b, value, [index])
assert isinstance(store, tvm.tir.BufferStore)
- assert store.value.dtype == "int32x4xvscale"
+ assert store.value.dtype == "int32xvscalex4"
[email protected](
- reason="Support for scalable data type string will be added in P3 of
https://github.com/apache/tvm/issues/16455"
-)
def test_scalable_vec_cast():
b = tvm.tir.decl_buffer((24,), "float32")
- value = tvm.tir.expr.Broadcast(1, 12 *
tvm.tir.vscale()).astype("float32x12xvscale")
+ value = tvm.tir.expr.Broadcast(1, 12 *
tvm.tir.vscale()).astype("float32xvscalex12")
index = tvm.tir.expr.Ramp(0, 1, 12 * tvm.tir.vscale())
store = tvm.tir.BufferStore(b, value, [index])
diff --git a/tests/python/tir-base/test_tir_scalable_datatype.py
b/tests/python/tir-base/test_tir_scalable_datatype.py
new file mode 100644
index 0000000000..41a367e6e5
--- /dev/null
+++ b/tests/python/tir-base/test_tir_scalable_datatype.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+import tvm
+from tvm import tir
+from tvm.script import tir as T
+from tvm.target.codegen import llvm_version_major
+
+"""
+Tests for scalable data types.
+"""
+
+
+def test_create_scalable_data_type_python_api():
+ dtype = tvm.DataType("float32xvscalex4")
+ assert str(dtype) == "float32xvscalex4"
+
+
[email protected](llvm_version_major() < 13, reason="Stepvector intrinsic
was added in LLVM 13.")
+def test_create_scalable_tir_intrin():
+ intrin = tir.call_llvm_intrin("int32xvscalex4",
"llvm.experimental.stepvector")
+ assert intrin.dtype == "int32xvscalex4"
+ assert str(intrin) == 'T.call_llvm_intrin("int32xvscalex4",
"llvm.experimental.stepvector")'
+
+
[email protected](llvm_version_major() < 13, reason="Stepvector intrinsic
was added in LLVM 13.")
+def test_tvm_script_create_scalable_tir_intrin():
+ @T.prim_func
+ def my_func():
+ T.call_llvm_intrin("int32xvscalex4", "llvm.experimental.stepvector")
+
+ assert (
+ 'T.call_llvm_intrin("int32xvscalex4", "llvm.experimental.stepvector")'
in my_func.script()
+ )
+
+
+def test_invalid_data_type():
+ err_msg = "Invalid data type. Expected 'vscale' but got '4'"
+ with pytest.raises(AssertionError, match=err_msg):
+ tvm.DataType("float32x4xvscale")
+
+
+if __name__ == "__main__":
+ tvm.testing.main()