This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 34a6571038 [SYSTEMDS-3803] DML-bodied Util Function for Transposing ABCD to ACBD
34a6571038 is described below
commit 34a6571038f88517b27d5f218a2618cb29e28c7a
Author: Maximilian.S <[email protected]>
AuthorDate: Thu Dec 5 15:29:22 2024 +0100
[SYSTEMDS-3803] DML-bodied Util Function for Transposing ABCD to ACBD
This patch adds a simple util function that swaps the 2nd and 3rd axes of a
4D tensor (ABCD to ACBD), which is required for the multi-head attention
implementation.
Closes #2151
---
scripts/nn/util.dml | 25 +++++++
.../test/applications/nn/NNComponentTest.java | 5 ++
.../nn/component/transpose_ABCD_to_ACBD.dml | 82 ++++++++++++++++++++++
3 files changed, 112 insertions(+)
diff --git a/scripts/nn/util.dml b/scripts/nn/util.dml
index 807d7baad9..a3d0f84c5c 100644
--- a/scripts/nn/util.dml
+++ b/scripts/nn/util.dml
@@ -380,3 +380,28 @@ top_k2d = function(matrix[double] X, int k, int C, int Hin, int Win)
indices = transpose_NCHW_to_CNHW(indices_K_NHW, N)
}
+transpose_ABCD_to_ACBD = function(matrix[double] X, int B, int C)
+ return (matrix[double] out) {
+ /*
+ * Reshape util for 4D tensors in ABCD format, flattened to a 2D matrix.
+ * Swaps the 2nd and 3rd axes, i.e. (A, B, C, D) -> (A, C, B, D).
+ *
+ * Inputs:
+ * - X: Inputs, of shape (A, B*C*D); A is read from nrow(X), D is implied.
+ * - B: Dimension of 2nd axis.
+ * - C: Dimension of 3rd axis.
+ *
+ * Outputs:
+ * - out: Outputs with the 2nd and 3rd axes swapped, of
+ * shape (A, C*B*D).
+ */
+ A = nrow(X)
+ BCD = ncol(X)
+
+ # 1st NCHW_to_CNHW pass (N=A, C=B, HW=C*D): (A, B*C*D) -> (B, A*C*D)
+ X_BACD = transpose_NCHW_to_CNHW(X, B)
+ # 2nd NCHW_to_CNHW pass (N=B, C=A*C, HW=D): (B, A*C*D) -> (A*C, B*D)
+ X_ACBD = transpose_NCHW_to_CNHW(X_BACD, A*C)
+ # reshape back to one row per A: (A*C, B*D) -> (A, C*B*D)
+ out = matrix(X_ACBD, rows=A, cols=BCD)
+}
diff --git a/src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java b/src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java
index 86b2f64bb7..3b002871d7 100644
--- a/src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java
+++ b/src/test/java/org/apache/sysds/test/applications/nn/NNComponentTest.java
@@ -108,6 +108,11 @@ public class NNComponentTest extends TestFolder {
run("transpose_NCHW_to_CNHW.dml");
}
+ @Test
+ public void transpose_ABCD_to_ACBD() {
+ run("transpose_ABCD_to_ACBD.dml"); // DML test script; a file of this name lives in src/test/scripts/applications/nn/component/
+ }
+
@Test
public void logcosh(){
run("logcosh.dml");
diff --git a/src/test/scripts/applications/nn/component/transpose_ABCD_to_ACBD.dml b/src/test/scripts/applications/nn/component/transpose_ABCD_to_ACBD.dml
new file mode 100644
index 0000000000..1fa49c9b87
--- /dev/null
+++ b/src/test/scripts/applications/nn/component/transpose_ABCD_to_ACBD.dml
@@ -0,0 +1,82 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("src/test/scripts/applications/nn/util.dml") as test_util
+source("scripts/nn/util.dml") as util
+
+
+transpose_ABCD_to_ACBD = function() {
+ /*
+ * Checks `transpose_ABCD_to_ACBD` on a (2, 3*4*5) input against a hand-written target.
+ */
+ print("Testing transpose_ABCD_to_ACBD function.")
+
+ # Generate data: X holds the values 1..A*B*C*D
+ A = 2
+ B = 3
+ C = 4
+ D = 5
+ X = matrix(seq(1, A*B*C*D), rows=A, cols=B*C*D)
+
+ out = util::transpose_ABCD_to_ACBD(X, B, C)
+ # Expected layout: the two text blocks are a = 1, 2; each text row is one c, holding the D values for b = 1, 2, 3
+ target =
+ matrix("1 2 3 4 5 21 22 23 24 25 41 42 43 44 45
+ 6 7 8 9 10 26 27 28 29 30 46 47 48 49 50
+ 11 12 13 14 15 31 32 33 34 35 51 52 53 54 55
+ 16 17 18 19 20 36 37 38 39 40 56 57 58 59 60
+
+ 61 62 63 64 65 81 82 83 84 85 101 102 103 104 105
+ 66 67 68 69 70 86 87 88 89 90 106 107 108 109 110
+ 71 72 73 74 75 91 92 93 94 95 111 112 113 114 115
+ 76 77 78 79 80 96 97 98 99 100 116 117 118 119 120",
+ rows=A, cols=C*B*D)
+
+ # Equivalency check; 1e-10 is passed to check_all_close (presumably its tolerance)
+ test_util::check_all_close(out, target, 1e-10)
+}
+
+
+transpose_ABCD_to_ACBD_single_val = function() {
+ /*
+ * Test for `transpose_ABCD_to_ACBD` function on a 1x1 input
+ * (A = B = C = D = 1); the output is expected to equal the input.
+ */
+ print("Testing transpose_ABCD_to_ACBD function with single value.")
+
+ # Generate data: all four dimensions are 1, so X is a 1x1 matrix
+ A = 1
+ B = 1
+ C = 1
+ D = 1
+ X = matrix(seq(1, A*B*C*D), rows=A, cols=B*C*D)
+
+ out = util::transpose_ABCD_to_ACBD(X, B, C)
+
+ target = X
+
+ # Equivalency check; 1e-10 is passed to check_all_close (presumably its tolerance)
+ test_util::check_all_close(out, target, 1e-10)
+}
+
+
+transpose_ABCD_to_ACBD() # A=2, B=3, C=4, D=5 case
+transpose_ABCD_to_ACBD_single_val() # A=B=C=D=1 case