llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-sve

Author: Momchil Velikov (momchil-velikov)

<details>
<summary>Changes</summary>



---

Patch is 30.80 KiB, truncated to 20.00 KiB below, full version: 
https://github.com/llvm/llvm-project/pull/140573.diff


5 Files Affected:

- (added) 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir 
(+117) 
- (added) 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
 (+159) 
- (added) 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir 
(+118) 
- (added) 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir 
(+119) 
- (added) 
mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir 
(+117) 


``````````diff
diff --git 
a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir 
b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
new file mode 100644
index 0000000000000..88534dd2aab1e
--- /dev/null
+++ 
b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
@@ -0,0 +1,117 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf  
--convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm 
--finalize-memref-to-llvm  --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} 
-entry-point-result=void  --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE:    
-shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+  %c128 = arith.constant 128 : i32
+  func.call @setArmVLBits(%c128) : (i32) -> ()
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+// Accumulator test data
+  %acc_cst = arith.constant dense<[[-44,  20,  44, -46],
+                                   [ -8,  25, -34,  26],
+                                   [-20, -36,  -3,  39],
+                                   [-48, -31, -25, -21]]> : vector<4x4xi32>
+  %acc_m = memref.alloca() : memref<4x4xi32>
+  vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, 
memref<4x4xi32>
+
+  %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into 
memref<16xi32>
+  %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} 
: memref<16xi32>, vector<[16]xi32>
+  %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+  vector.print str "ACC:\n"
+  %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %acc0 : vector<[4]xi32>
+  vector.print %acc1 : vector<[4]xi32>
+  vector.print %acc2 : vector<[4]xi32>
+  vector.print %acc3 : vector<[4]xi32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[-35, -27, -36, -31,  23, -34,  -8, -33],
+                                   [-20,  17, -32, -47,  37,  22,  -7, -21],
+                                   [ -7, -35,  20,  -4,  39,  46, -23,  40],
+                                   [ 40,  27,  37,  43,  38,  -6,  37,  49]]> 
: vector<4x8xi8>
+
+  %lhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, 
memref<4x8xi8>
+  %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, 
vector<4x8xi8>
+
+  vector.print str "LHS:\n"
+  %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+  %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+  %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+  %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+  vector.print %lhs0 : vector<8xi8>
+  vector.print %lhs1 : vector<8xi8>
+  vector.print %lhs2 : vector<8xi8>
+  vector.print %lhs3 : vector<8xi8>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[-17, -50,  -1,  48, -13,  22,  39,  33],
+                                   [-35, -24,  37, -32,  33,  30, -11, -17],
+                                   [-28,  31,   3, -44, -15, -27,  22,  35],
+                                   [-23,  39,  48,  26, -23,  32, -39, -38]]> 
: vector<4x8xi8>
+
+  %rhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, 
memref<4x8xi8>
+
+  %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into 
memref<32xi8>
+  %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : 
memref<32xi8>, vector<[32]xi8>
+
+  vector.print str "RHS:\n"
+  %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from 
vector<[32]xi8>
+  %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from 
vector<[32]xi8>
+  vector.print %rhs0 : vector<[16]xi8>
+  vector.print %rhs1 : vector<[16]xi8>
+
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+  // Matrix multiplication
+  %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  // Display the result of the multiplication
+  vector.print str "Result:\n"
+  %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %u0 : vector<[4]xi32>
+  vector.print %u1 : vector<[4]xi32>
+  vector.print %u2 : vector<[4]xi32>
+  vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -1999,  1941,   685, -2879 )
+// CHECK: ( -3705,  2952,   987,  -685 )
+// CHECK: (  2565,  4157, -1589,  -357 )
+// CHECK: (  2383, -2252,    32, -1365 )
+  return
+}
diff --git 
a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
 
b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
new file mode 100644
index 0000000000000..ce57be91fa540
--- /dev/null
+++ 
b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
@@ -0,0 +1,159 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf  
--convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm 
--finalize-memref-to-llvm  --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} 
-entry-point-result=void  --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE:    
-shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+  %c256 = arith.constant 256 : i32
+  func.call @setArmVLBits(%c256) : (i32) -> ()
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+
+  // Accumulator test data
+  %acc_cst = arith.constant dense<[[-44,  20,  44, -46,  -8,  25, -34,  26],
+                                   [-20, -36,  -3,  39, -48, -31, -25, -21],
+                                   [-35, -27, -36, -31,  23, -34,  -8, -33],
+                                   [-20,  17, -32, -47,  37,  22,  -7, -21],
+                                   [ -7, -35,  20,  -4,  39,  46, -23,  40],
+                                   [ 40,  27,  37,  43,  38,  -6,  37,  49],
+                                   [-17, -50,  -1,  48, -13,  22,  39,  33],
+                                   [-35, -24,  37, -32,  33,  30, -11, -17]]> 
: vector<8x8xi32>
+  %acc_m = memref.alloca() : memref<8x8xi32>
+  vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<8x8xi32>, 
memref<8x8xi32>
+
+  %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<8x8xi32> into 
memref<64xi32>
+  %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} 
: memref<64xi32>, vector<[32]xi32>
+  %acc = vector.shape_cast %acc_flat : vector<[32]xi32> to vector<8x[4]xi32>
+
+  vector.print str "ACC:\n"
+  %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc4 = vector.extract %acc[4] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc5 = vector.extract %acc[5] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc6 = vector.extract %acc[6] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc7 = vector.extract %acc[7] : vector<[4]xi32> from vector<8x[4]xi32>
+  vector.print %acc0 : vector<[4]xi32>
+  vector.print %acc1 : vector<[4]xi32>
+  vector.print %acc2 : vector<[4]xi32>
+  vector.print %acc3 : vector<[4]xi32>
+  vector.print %acc4 : vector<[4]xi32>
+  vector.print %acc5 : vector<[4]xi32>
+  vector.print %acc6 : vector<[4]xi32>
+  vector.print %acc7 : vector<[4]xi32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[-28,  31,   3, -44, -15, -27,  22,  35],
+                                   [-23,  39,  48,  26, -23,  32, -39, -38],
+                                   [ -3,   9,  43, -30, -32,  39,  41, -39],
+                                   [-13, -21, -25,  27,  47, -36, -11, -11],
+                                   [ -4, -20,  36,  11,  13, -23,  24, -13],
+                                   [-20,  30,  -5,   1,  42, -37, -22,  35],
+                                   [-22,  38,  -4,  44,  25, -31,  23, -39],
+                                   [-45,  -4, -31, -24,  14, -41, -47,  22]]> 
: vector<8x8xi8>
+
+  %lhs_m = memref.alloca() : memref<8x8xi8>
+  vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<8x8xi8>, 
memref<8x8xi8>
+  %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<8x8xi8>, 
vector<8x8xi8>
+
+  vector.print str "LHS:\n"
+  %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<8x8xi8>
+  %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<8x8xi8>
+  %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<8x8xi8>
+  %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<8x8xi8>
+  %lhs4 = vector.extract %lhs[4] : vector<8xi8> from vector<8x8xi8>
+  %lhs5 = vector.extract %lhs[5] : vector<8xi8> from vector<8x8xi8>
+  %lhs6 = vector.extract %lhs[6] : vector<8xi8> from vector<8x8xi8>
+  %lhs7 = vector.extract %lhs[7] : vector<8xi8> from vector<8x8xi8>
+  vector.print %lhs0 : vector<8xi8>
+  vector.print %lhs1 : vector<8xi8>
+  vector.print %lhs2 : vector<8xi8>
+  vector.print %lhs3 : vector<8xi8>
+  vector.print %lhs4 : vector<8xi8>
+  vector.print %lhs5 : vector<8xi8>
+  vector.print %lhs6 : vector<8xi8>
+  vector.print %lhs7 : vector<8xi8>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[-40, -11, -36,  36,  -1,  20,  14, -32],
+                                   [ 46, -45, -48, -46, -24,  31, -36,  22],
+                                   [  2,  36,  45, -29, -37, -49, -20, -35],
+                                   [ -6,  23,  23,  15,  20,   4,  -8,  -2],
+                                   [-35,  -6,  16,  49, -50,   9, -44,  13],
+                                   [ 24,   1,  -4, -44,  41,  15, -43,  44],
+                                   [ 44,   0, -10,  41,  22,  44, -40,   0],
+                                   [-33,  19,  27,  22,  38, -17,  23,  -9]]> 
: vector<8x8xi8>
+
+  %rhs_m = memref.alloca() : memref<8x8xi8>
+  vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<8x8xi8>, 
memref<8x8xi8>
+
+  %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<8x8xi8> into 
memref<64xi8>
+  %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : 
memref<64xi8>, vector<[32]xi8>
+
+  vector.print str "RHS:\n"
+  %rhs0 = vector.scalable.extract %rhs_flat[ 0] : vector<[16]xi8> from 
vector<[32]xi8>
+  %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from 
vector<[32]xi8>
+  vector.print %rhs0 : vector<[16]xi8>
+  vector.print %rhs1 : vector<[16]xi8>
+
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+  // Matrix multiplication
+  %0 = arith.extsi %lhs : vector<8x8xi8> to vector<8x8xi32>
+  %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<8x8xi32>, vector<[4]x8xi32> into vector<8x[4]xi32>
+
+  // Display the result of the multiplication
+  vector.print str "Result:\n"
+  %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u4 = vector.extract %2[4] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u5 = vector.extract %2[5] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u6 = vector.extract %2[6] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u7 = vector.extract %2[7] : vector<[4]xi32> from vector<8x[4]xi32>
+  vector.print %u0 : vector<[4]xi32>
+  vector.print %u1 : vector<[4]xi32>
+  vector.print %u2 : vector<[4]xi32>
+  vector.print %u3 : vector<[4]xi32>
+  vector.print %u4 : vector<[4]xi32>
+  vector.print %u5 : vector<[4]xi32>
+  vector.print %u6 : vector<[4]xi32>
+  vector.print %u7 : vector<[4]xi32>
+
+
+// CHECK: ( -2294, -1282,  2728,  -410, -1328,   882, -5498,   732 )
+// CHECK: (  1012, -4237,  4154,  2624,  5225, -2338,  2011,  1374 )
+// CHECK: (    -8, -1611,  2905,    -1, -1068, -3155, -2428,   153 )
+// CHECK: (  2034, -1768, -2092,   284,  -792,   -23,   668,  2172 )
+// CHECK: (  -248, -3728,  1214,   555,  -668, -2114, -1794,  2560 )
+// CHECK: ( -1484, -2642,   297,  1551,  -483,  3173,  -576,  2570 )
+// CHECK: (  3098, -7851,  1366,  1892,  -427, -4533,  -819,  4698 )
+// CHECK: (  -135,  1247,   765,  -479,  1245,  3074, -2281,   -23 )
+  return
+}
diff --git 
a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir 
b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
new file mode 100644
index 0000000000000..f1f311ddb0c18
--- /dev/null
+++ 
b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
@@ -0,0 +1,118 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf  
--convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm 
--finalize-memref-to-llvm  --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} 
-entry-point-result=void  --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE:    
-shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+  %c128 = arith.constant 128 : i32
+  func.call @setArmVLBits(%c128) : (i32) -> ()
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+// Accumulator test data
+  %acc_cst = arith.constant dense<[[-44,  20,  44, -46],
+                                   [ -8,  25, -34,  26],
+                                   [-20, -36,  -3,  39],
+                                   [-48, -31, -25, -21]]> : vector<4x4xi32>
+  %acc_m = memref.alloca() : memref<4x4xi32>
+  vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, 
memref<4x4xi32>
+
+  %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into 
memref<16xi32>
+  %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} 
: memref<16xi32>, vector<[16]xi32>
+  %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+  vector.print str "ACC:\n"
+  %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %acc0 : vector<[4]xi32>
+  vector.print %acc1 : vector<[4]xi32>
+  vector.print %acc2 : vector<[4]xi32>
+  vector.print %acc3 : vector<[4]xi32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[-35, -27, -36, -31,  23, -34,  -8, -33],
+                                   [-20,  17, -32, -47,  37,  22,  -7, -21],
+                                   [ -7, -35,  20,  -4,  39,  46, -23,  40],
+                                   [ 40,  27,  37,  43,  38,  -6,  37,  49]]> 
: vector<4x8xi8>
+
+  %lhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, 
memref<4x8xi8>
+  %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, 
vector<4x8xi8>
+
+  vector.print str "LHS:\n"
+  %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+  %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+  %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+  %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+  vector.print %lhs0 : vector<8xi8>
+  vector.print %lhs1 : vector<8xi8>
+  vector.print %lhs2 : vector<8xi8>
+  vector.print %lhs3 : vector<8xi8>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[125, 171, 138, 187, 108, 175,  82,  99],
+                                   [221,  25, 164,  97, 156, 221, 218, 177],
+                                   [171, 160, 219, 191, 144,  45, 161, 210],
+                                   [223, 165, 123,  99, 108,  86,  37,  92]]> 
: vector<4x8xi8>
+
+  %rhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, 
memref<4x8xi8>
+
+  %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into 
memref<32xi8>
+  %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : 
memref<32xi8>, vector<[32]xi8>
+
+  vector.print str "RHS:\n"
+  %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from 
vector<[32]xi8>
+  %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from 
vector<[32]xi8>
+  vector.print %rhs0 : vector<[16]xi8>
+  vector.print %rhs1 : vector<[16]xi8>
+
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+  // Matrix multiplication
+  %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  // Display the result of the multiplication
+  vector.print str "Result:\n"
+  %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %u0 : vector<[4]xi32>
+  vector.print %u1 : vector<[4]xi32>
+  vector.print %u2 : vector<[4]xi32>
+  vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -27190, -28812, -30502, -23575 )
+// CHECK: (  -7613,  -8386, -15938,  -6521 )
+// CHECK: (   9468,  18750,   9199,   5764 )
+// CHECK: (  33655,  41064,  48900,  31627 )
+  return
+}
+
diff --git 
a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir 
b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
new file mode 100644
index 0000000000000..7af0b2c3f1054
--- /dev/null
+++ 
b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
@@ -0,0 +1,119 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/140573
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to