This is an automated email from the ASF dual-hosted git repository.

ekalda pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 563ef9587c [SVE] Add support for scalable data type strings (#16612)
563ef9587c is described below

commit 563ef9587cfa913cf96f9ec061cdab43ce744b70
Author: Luke Hutton <[email protected]>
AuthorDate: Tue Feb 27 09:24:36 2024 +0000

    [SVE] Add support for scalable data type strings (#16612)
    
    This commit adds support for representing scalable vectors using the
    string data type format. For example, "float32xvscalex4" may be used
    to represent the following scalable type:
    `DataType(kDLFloat, 32, /*lanes=*/4, /*is_scalable=*/true)`.
    
    
    ---------
    
    
    Co-authored-by: Elen Kalda <[email protected]>
    Co-authored-by: Neil Hickey <[email protected]>
---
 include/tvm/runtime/data_type.h                    | 17 +++--
 python/tvm/_ffi/runtime_ctypes.py                  | 11 +++-
 src/tir/op/op.cc                                   |  2 +-
 tests/cpp/tir_scalable_datatype.cc                 | 76 +++++++++++++++++++---
 tests/python/tir-base/test_tir_nodes.py            | 15 +----
 .../python/tir-base/test_tir_scalable_datatype.py  | 60 +++++++++++++++++
 6 files changed, 153 insertions(+), 28 deletions(-)

diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index 5efa5f3b90..f6a7d424ed 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -27,6 +27,7 @@
 #include <tvm/runtime/c_runtime_api.h>
 #include <tvm/runtime/logging.h>
 
+#include <cstring>
 #include <string>
 #include <type_traits>
 
@@ -110,7 +111,7 @@ class DataType {
     return -lanes_as_int;
   }
   /*! \return whether type is a scalar type. */
-  bool is_scalar() const { return lanes() == 1; }
+  bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; }
   /*! \return whether type is a scalar type. */
   bool is_bool() const { return code() == DataType::kUInt && bits() == 1; }
   /*! \return whether type is a float type. */
@@ -389,9 +390,12 @@ inline std::ostream& operator<<(std::ostream& os, DLDataType t) {  // NOLINT(*)
     os << "custom[" << GetCustomTypeName(t.code) << "]";
   }
   if (t.code == kTVMOpaqueHandle) return os;
+  int16_t lanes = static_cast<int16_t>(t.lanes);
   os << static_cast<int>(t.bits);
-  if (t.lanes != 1) {
-    os << 'x' << static_cast<int>(t.lanes);
+  if (lanes > 1) {
+    os << 'x' << lanes;
+  } else if (lanes < -1) {
+    os << "xvscalex" << -lanes;
   }
   return os;
 }
@@ -456,9 +460,14 @@ inline DLDataType String2DLDataType(std::string s) {
   char* xdelim;  // emulate sscanf("%ux%u", bits, lanes)
   uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
   if (bits != 0) t.bits = bits;
+  int scalable_multiplier = 1;
+  if (strncmp(xdelim, "xvscale", 7) == 0) {
+    scalable_multiplier = -1;
+    xdelim += 7;
+  }
   char* endpt = xdelim;
   if (*xdelim == 'x') {
-    t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, &endpt, 10));
+    t.lanes = static_cast<uint16_t>(scalable_multiplier * strtoul(xdelim + 1, &endpt, 10));
   }
   ICHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
   return t;
diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py
index 54e4d8f205..06f2d4c7e6 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -135,7 +135,11 @@ class DataType(ctypes.Structure):
 
         arr = type_str.split("x")
         head = arr[0]
-        self.lanes = int(arr[1]) if len(arr) > 1 else 1
+        if len(arr) == 3:
+            assert arr[1] == "vscale", f"Invalid data type. Expected 'vscale' but got '{arr[1]}'"
+            self.lanes = ctypes.c_uint16(-int(arr[2]))
+        elif len(arr) > 1:
+            self.lanes = ctypes.c_uint16(int(arr[1]))
         bits = 32
 
         if head.startswith("int"):
@@ -188,8 +192,11 @@ class DataType(ctypes.Structure):
 
             type_name = "custom[%s]" % tvm.runtime._ffi_api._datatype_get_type_name(self.type_code)
         x = "%s%d" % (type_name, self.bits)
-        if self.lanes != 1:
+        lanes_as_int = ctypes.c_int16(self.lanes).value
+        if lanes_as_int > 1:
             x += "x%d" % self.lanes
+        elif lanes_as_int < -1:
+            x += "xvscalex%d" % -lanes_as_int
         return x
 
     def __eq__(self, other):
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index b329d25b54..c46a8c2643 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -342,7 +342,7 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) {
   using tir::FloatImmNode;
   if (value.dtype() == t) return value;
   // const fold IntImm as they are used in index computations
-  if (t.lanes() == 1) {
+  if (t.is_scalar()) {
     if (const IntImmNode* op = value.as<IntImmNode>()) {
       return make_const(t, op->value, op->span);
     } else if (const FloatImmNode* op = value.as<FloatImmNode>()) {
diff --git a/tests/cpp/tir_scalable_datatype.cc b/tests/cpp/tir_scalable_datatype.cc
index daa4dfe729..23decef69e 100644
--- a/tests/cpp/tir_scalable_datatype.cc
+++ b/tests/cpp/tir_scalable_datatype.cc
@@ -24,12 +24,14 @@
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/expr.h>
 
+#include "../../src/script/printer/utils.h"
+
 using ::testing::HasSubstr;
 
 // ---------
 // Data Type
 // ---------
-TEST(TIR, TestCreateScalableType) {
+TEST(ScalableDataType, TestCreateScalableType) {
   tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true);
   ASSERT_EQ(scalable_type.code(), kDLInt);
   ASSERT_EQ(scalable_type.bits(), 32);
@@ -38,7 +40,7 @@ TEST(TIR, TestCreateScalableType) {
   ASSERT_TRUE(scalable_type.is_scalable_or_fixed_length_vector());
 }
 
-TEST(TIR, TestScalableWithBits) {
+TEST(ScalableDataType, TestScalableWithBits) {
   tvm::DataType scalable_type = tvm::DataType(kDLInt, 1, 8, true);
   scalable_type = scalable_type.with_bits(32);
   ASSERT_EQ(scalable_type.bits(), 32);
@@ -46,7 +48,7 @@ TEST(TIR, TestScalableWithBits) {
   ASSERT_TRUE(scalable_type.is_scalable_or_fixed_length_vector());
 }
 
-TEST(TIR, TestScalableWithVscaleFactor) {
+TEST(ScalableDataType, TestScalableWithVscaleFactor) {
   tvm::DataType type = tvm::DataType(kDLInt, 32, 1);
   tvm::DataType scalable_type = type.with_scalable_vscale_factor(4);
   ASSERT_EQ(scalable_type.vscale_factor(), 4);
@@ -54,18 +56,54 @@ TEST(TIR, TestScalableWithVscaleFactor) {
   ASSERT_TRUE(scalable_type.is_scalable_or_fixed_length_vector());
 }
 
-TEST(TIR, TestAssignScalableDataType) {
+TEST(ScalableDataType, TestAssignScalableDataType) {
   tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 2, true);
   tvm::DataType scalable_type_copy = scalable_type;
   ASSERT_TRUE(scalable_type_copy.is_scalable_vector());
   ASSERT_TRUE(scalable_type_copy.is_scalable_or_fixed_length_vector());
 }
 
-TEST(TIR, TestScalableDataTypeAndNonScalableDataTypeInequality) {
+TEST(ScalableDataType, TestScalableDataTypeEquality) {
+  ASSERT_TRUE(tvm::DataType(kDLInt, 32, 4, true) == tvm::DataType(kDLInt, 32, 4, true));
+}
+
+TEST(ScalableDataType, TestScalableDataTypeAndNonScalableDataTypeInequality) {
   ASSERT_FALSE(tvm::DataType(kDLInt, 32, 4, true) == tvm::DataType(kDLInt, 32, 4));
 }
 
-TEST(TIR, TestGetScalableVectorBytesError) {
+TEST(ScalableDataType, TestIsScalar) {
+  ASSERT_FALSE(tvm::DataType(kDLInt, 32, 4, true).is_scalar());
+  ASSERT_TRUE(tvm::DataType(kDLInt, 32, 1, false).is_scalar());
+  ASSERT_FALSE(tvm::DataType(kDLInt, 32, 4, false).is_scalar());
+  ASSERT_FALSE(tvm::DataType(kDLOpaqueHandle, 1, 0, false).is_scalar());
+}
+
+TEST(ScalableDataType, TestScalableDataTypeToString) {
+  tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true);
+  EXPECT_EQ(tvm::runtime::DLDataType2String(scalable_type), "int32xvscalex4");
+}
+
+TEST(ScalableDataType, TestStringToScalableDataType) {
+  std::string scalable_type_str = "int32xvscalex4";
+  EXPECT_EQ(tvm::DataType(tvm::runtime::String2DLDataType(scalable_type_str)),
+            tvm::DataType(kDLInt, 32, 4, true));
+}
+
+TEST(ScalableDataType, TestInvalidStringToScalableDataType) {
+  std::string scalable_type_str = "int32x4xvscale";
+  EXPECT_THROW(
+      {
+        try {
+          tvm::runtime::String2DLDataType(scalable_type_str);
+        } catch (const tvm::InternalError& e) {
+          EXPECT_THAT(e.what(), HasSubstr("unknown type int32x4xvscale"));
+          throw;
+        }
+      },
+      tvm::InternalError);
+}
+
+TEST(ScalableDataType, TestGetScalableVectorBytes) {
   tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true);
   EXPECT_THROW(
       {
@@ -80,7 +118,7 @@ TEST(TIR, TestGetScalableVectorBytesError) {
       tvm::InternalError);
 }
 
-TEST(TIR, TestScalableDataTypeInvalidLanesError) {
+TEST(ScalableDataType, TestScalableDataTypeInvalidLanesError) {
   EXPECT_THROW(
       {
         try {
@@ -93,7 +131,7 @@ TEST(TIR, TestScalableDataTypeInvalidLanesError) {
       tvm::InternalError);
 }
 
-TEST(TIR, TestScalableDataTypeInvalidVscaleFactorAccess) {
+TEST(ScalableDataType, TestScalableDataTypeInvalidVscaleFactorAccess) {
   tvm::DataType fixed_length_type = tvm::DataType(kDLFloat, 32, 4);
   ASSERT_TRUE(fixed_length_type.is_fixed_length_vector());
   ASSERT_TRUE(fixed_length_type.is_scalable_or_fixed_length_vector());
@@ -109,7 +147,7 @@ TEST(TIR, TestScalableDataTypeInvalidVscaleFactorAccess) {
       tvm::InternalError);
 }
 
-TEST(TIR, TestScalableDataTypeInvalidLanesAccess) {
+TEST(ScalableDataType, TestScalableDataTypeInvalidLanesAccess) {
   tvm::DataType scalable_type = tvm::DataType(kDLFloat, 32, 4, true);
   EXPECT_THROW(
       {
@@ -123,3 +161,23 @@ TEST(TIR, TestScalableDataTypeInvalidLanesAccess) {
       },
       tvm::InternalError);
 }
+
+// -----------
+// Integration
+// -----------
+#if TVM_LLVM_VERSION >= 130
+TEST(ScalableDataType, TestScalableIntrinCall) {
+  tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true);
+  tvm::tir::Call call = tvm::tir::Call(
+      scalable_type, tvm::tir::builtin::call_llvm_intrin(),
+      {tvm::IntImm(tvm::DataType::Int(32), ::llvm::Intrinsic::experimental_stepvector)});
+  ASSERT_EQ(call->dtype, scalable_type);
+  ASSERT_EQ(call->Script(),
+            "T.call_llvm_intrin(\"int32xvscalex4\", \"llvm.experimental.stepvector\")");
+}
+#endif
+
+TEST(ScalableDataType, TestTIRScriptScalableDtype2Str) {
+  tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true);
+  ASSERT_EQ(tvm::script::printer::DType2Str(scalable_type), "int32xvscalex4");
+}
diff --git a/tests/python/tir-base/test_tir_nodes.py b/tests/python/tir-base/test_tir_nodes.py
index 5b55c432b0..f3498f8ec7 100644
--- a/tests/python/tir-base/test_tir_nodes.py
+++ b/tests/python/tir-base/test_tir_nodes.py
@@ -439,21 +439,15 @@ def test_broadcast_to_scalable_vec():
     assert broadcast.lanes.b == 4
 
 
-@pytest.mark.skip(
-    reason="Support for scalable data type string will be added in P3 of https://github.com/apache/tvm/issues/16455"
-)
 def test_buffer_load_scalable_vec():
     buf = tvm.tir.decl_buffer((24,), "float32")
     index = tvm.tir.expr.Ramp(1, 1, 8 * tvm.tir.vscale())
     load = tvm.tir.BufferLoad(buf, [index])
 
     assert isinstance(load, tvm.tir.BufferLoad)
-    assert load.dtype == "float32x8xvscale"
+    assert load.dtype == "float32xvscalex8"
 
 
-@pytest.mark.skip(
-    reason="Support for scalable data type string will be added in P3 of https://github.com/apache/tvm/issues/16455"
-)
 def test_buffer_store_scalable_vec():
     b = tvm.tir.decl_buffer((24,), "int32")
     value = tvm.tir.expr.Broadcast(1, 4 * tvm.tir.vscale())
@@ -461,15 +455,12 @@ def test_buffer_store_scalable_vec():
     store = tvm.tir.BufferStore(b, value, [index])
 
     assert isinstance(store, tvm.tir.BufferStore)
-    assert store.value.dtype == "int32x4xvscale"
+    assert store.value.dtype == "int32xvscalex4"
 
 
-@pytest.mark.skip(
-    reason="Support for scalable data type string will be added in P3 of https://github.com/apache/tvm/issues/16455"
-)
 def test_scalable_vec_cast():
     b = tvm.tir.decl_buffer((24,), "float32")
-    value = tvm.tir.expr.Broadcast(1, 12 * tvm.tir.vscale()).astype("float32x12xvscale")
+    value = tvm.tir.expr.Broadcast(1, 12 * tvm.tir.vscale()).astype("float32xvscalex12")
     index = tvm.tir.expr.Ramp(0, 1, 12 * tvm.tir.vscale())
 
     store = tvm.tir.BufferStore(b, value, [index])
diff --git a/tests/python/tir-base/test_tir_scalable_datatype.py b/tests/python/tir-base/test_tir_scalable_datatype.py
new file mode 100644
index 0000000000..41a367e6e5
--- /dev/null
+++ b/tests/python/tir-base/test_tir_scalable_datatype.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+import tvm
+from tvm import tir
+from tvm.script import tir as T
+from tvm.target.codegen import llvm_version_major
+
+"""
+Tests for scalable data types.
+"""
+
+
+def test_create_scalable_data_type_python_api():
+    dtype = tvm.DataType("float32xvscalex4")
+    assert str(dtype) == "float32xvscalex4"
+
+
+@pytest.mark.skipif(llvm_version_major() < 13, reason="Stepvector intrinsic was added in LLVM 13.")
+def test_create_scalable_tir_intrin():
+    intrin = tir.call_llvm_intrin("int32xvscalex4", "llvm.experimental.stepvector")
+    assert intrin.dtype == "int32xvscalex4"
+    assert str(intrin) == 'T.call_llvm_intrin("int32xvscalex4", "llvm.experimental.stepvector")'
+
+
+@pytest.mark.skipif(llvm_version_major() < 13, reason="Stepvector intrinsic was added in LLVM 13.")
+def test_tvm_script_create_scalable_tir_intrin():
+    @T.prim_func
+    def my_func():
+        T.call_llvm_intrin("int32xvscalex4", "llvm.experimental.stepvector")
+
+    assert (
+        'T.call_llvm_intrin("int32xvscalex4", "llvm.experimental.stepvector")' in my_func.script()
+    )
+
+
+def test_invalid_data_type():
+    err_msg = "Invalid data type. Expected 'vscale' but got '4'"
+    with pytest.raises(AssertionError, match=err_msg):
+        tvm.DataType("float32x4xvscale")
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to