This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new e022eafdd1 [SYSTEMDS-3829] BERT layer forward pass
e022eafdd1 is described below

commit e022eafdd1d91f6b5be69c1395196d248e03f670
Author: MaximilianSchreff <[email protected]>
AuthorDate: Sun Jan 19 18:44:52 2025 +0100

    [SYSTEMDS-3829] BERT layer forward pass
    
    This patch introduces the forward pass of the BERT layer from the
    BERT transformer architecture as a built-in.
    
    Closes #2184
---
 scripts/nn/layers/bert_layer.dml                   | 186 +++++++++++++++++++++
 .../nn/transformers/BertLayerTest.java             | 117 +++++++++++++
 .../nn/transformers/MultiAttentionLayerTest.java   |   2 +-
 .../transformers/bert_layer/input_W_K_test1.csv    |   6 +
 .../transformers/bert_layer/input_W_K_test2.csv    |   8 +
 .../transformers/bert_layer/input_W_Q_test1.csv    |   6 +
 .../transformers/bert_layer/input_W_Q_test2.csv    |   8 +
 .../transformers/bert_layer/input_W_V_test1.csv    |   6 +
 .../transformers/bert_layer/input_W_V_test2.csv    |   8 +
 .../bert_layer/input_W_context_test1.csv           |   6 +
 .../bert_layer/input_W_context_test2.csv           |   8 +
 .../bert_layer/input_W_intermediate_test1.csv      |   6 +
 .../bert_layer/input_W_intermediate_test2.csv      |   8 +
 .../transformers/bert_layer/input_W_out_test1.csv  |   7 +
 .../transformers/bert_layer/input_W_out_test2.csv  |   7 +
 .../transformers/bert_layer/input_b_K_test1.csv    |   6 +
 .../transformers/bert_layer/input_b_K_test2.csv    |   8 +
 .../transformers/bert_layer/input_b_Q_test1.csv    |   6 +
 .../transformers/bert_layer/input_b_Q_test2.csv    |   8 +
 .../transformers/bert_layer/input_b_V_test1.csv    |   6 +
 .../transformers/bert_layer/input_b_V_test2.csv    |   8 +
 .../bert_layer/input_b_context_test1.csv           |   6 +
 .../bert_layer/input_b_context_test2.csv           |   8 +
 .../bert_layer/input_b_intermediate_test1.csv      |   7 +
 .../bert_layer/input_b_intermediate_test2.csv      |   7 +
 .../transformers/bert_layer/input_b_out_test1.csv  |   6 +
 .../transformers/bert_layer/input_b_out_test2.csv  |   8 +
 .../bert_layer/input_beta_ln1_test1.csv            |   6 +
 .../bert_layer/input_beta_ln1_test2.csv            |   8 +
 .../bert_layer/input_beta_ln2_test1.csv            |   6 +
 .../bert_layer/input_beta_ln2_test2.csv            |   8 +
 .../bert_layer/input_gamma_ln1_test1.csv           |   6 +
 .../bert_layer/input_gamma_ln1_test2.csv           |   8 +
 .../bert_layer/input_gamma_ln2_test1.csv           |   6 +
 .../bert_layer/input_gamma_ln2_test2.csv           |   8 +
 .../transformers/bert_layer/input_states_test1.csv |   5 +
 .../transformers/bert_layer/input_states_test2.csv |   4 +
 .../bert_layer/output_attention_test1.csv          |   5 +
 .../bert_layer/output_attention_test2.csv          |   4 +
 .../bert_layer/output_states_test1.csv             |   5 +
 .../bert_layer/output_states_test2.csv             |   4 +
 .../nn/component/bert_layer_forward.dml            |  92 ++++++++++
 42 files changed, 647 insertions(+), 1 deletion(-)

diff --git a/scripts/nn/layers/bert_layer.dml b/scripts/nn/layers/bert_layer.dml
new file mode 100644
index 0000000000..75b33fb263
--- /dev/null
+++ b/scripts/nn/layers/bert_layer.dml
@@ -0,0 +1,186 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("nn/layers/affine.dml") as affine
+source("nn/layers/multi_attention.dml") as attention
+source("nn/layers/dropout.dml") as dropout
+source("nn/layers/batch_norm1d.dml") as batch_norm
+source("nn/layers/tanh.dml") as tanh
+source("nn/layers/gelu.dml") as gelu
+
+linear_tensor_forward = function(matrix[double] X, matrix[double] W, matrix[double] b, int B, int C)
+  return (matrix[double] out) {
+  /*
+   * Applies one affine (linear) layer to every length-C vector of a
+   * flattened 3D tensor.
+   *
+   * Inputs:
+   * - X: Input, of shape (A, B*C), i.e. A rows each holding B vectors of length C.
+   * - W: Weights, of shape (C, M).
+   * - b: Biases, of shape (1, M).
+   * - B: Number of vectors per row of X.
+   * - C: Length of each input vector.
+   *
+   * Outputs:
+   * - out: Result, of shape (A, B*M) with M = ncol(W).
+   */
+  n_rows = nrow(X)
+  n_out = ncol(W)
+  # Stack all vectors on top of each other so a single affine call covers them
+  stacked = matrix(X, rows=n_rows*B, cols=C)
+  projected = affine::forward(stacked, W, b)
+  # Fold the projected vectors back into the original row layout
+  out = matrix(projected, rows=n_rows, cols=B*n_out)
+}
+
+layer_norm_forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta, double epsilon, int B, int C)
+  return (matrix[double] out, matrix[double] cache_mean, matrix[double] cache_var, matrix[double] cache_norm) {
+  /*
+   * Layer normalization over the length-C vectors of a flattened 3D tensor,
+   * implemented on top of 1D batch normalization.
+   *
+   * Each of the A*B vectors of length C becomes one column of the batch-norm
+   * input (shape (C, A*B)). NOTE(review): this relies on batch_norm::forward
+   * computing its statistics per column; confirm against batch_norm1d.dml.
+   *
+   * Inputs:
+   * - X: Input, of shape (A, B*C).
+   * - gamma: Scale parameters, of shape (1, C); passed transposed.
+   * - beta: Shift parameters, of shape (1, C); passed transposed.
+   * - epsilon: Epsilon forwarded to batch_norm::forward.
+   * - B: Number of vectors per row of X.
+   * - C: Length of each vector.
+   *
+   * Outputs:
+   * - out: Normalized output, of shape (A, B*C).
+   * - cache_mean, cache_var, cache_norm: Caches returned unchanged by
+   *     batch_norm::forward for the (C, A*B) input.
+   */
+  n_vectors = nrow(X) * B
+  # One column per vector, one row per vector component
+  bn_in = t(matrix(X, rows=n_vectors, cols=C))
+  # batch_norm::forward requires running-average (EMA) inputs; their updated
+  # values are discarded here, so all-zero (1, A*B) placeholders are passed
+  ema_zeros = matrix(0, rows=1, cols=n_vectors)
+  [bn_out, ema_mean_unused, ema_var_unused, cache_mean, cache_var, cache_norm] = batch_norm::forward(
+    bn_in, t(gamma), t(beta), "train", ema_zeros, ema_zeros, 0.0, epsilon)
+  # Transpose back and restore the (A, B*C) layout
+  out = matrix(t(bn_out), rows=nrow(X), cols=B*C)
+}
+
+forward = function(matrix[double] states,
+      int H, int T, int d, int I,
+      matrix[double] W_Q, matrix[double] b_Q,
+      matrix[double] W_K, matrix[double] b_K,
+      matrix[double] W_V, matrix[double] b_V,
+      matrix[double] W_context, matrix[double] b_context,
+      matrix[double] W_intermediate, matrix[double] b_intermediate,
+      matrix[double] W_out, matrix[double] b_out,
+      double dropout_p_attention,
+      double dropout_p_output,
+      double epsilon_ln,
+      matrix[double] gamma_ln1, matrix[double] beta_ln1,
+      matrix[double] gamma_ln2, matrix[double] beta_ln2,
+      string activation)
+    return (matrix[double] out_states, matrix[double] attention,
+      list[unknown] outputs,
+      matrix[double] dropout_mask_attention,
+      matrix[double] dropout_mask_output_1,
+      matrix[double] dropout_mask_output_2,
+      matrix[double] cache_mean_ln1, matrix[double] cache_var_ln1, matrix[double] cache_norm_ln1,
+      matrix[double] cache_mean_ln2, matrix[double] cache_var_ln2, matrix[double] cache_norm_ln2) {
+  /*
+   * Computes the forward pass for a layer of the BERT transformer architecture:
+   * Q/K/V projections and multi-head self attention, a context projection with
+   * dropout, residual connection and layer norm 1, then an intermediate
+   * projection, activation, output projection with dropout, residual
+   * connection and layer norm 2.
+   *
+   * Inputs (B: Batch size, T: Sequence length, D: Embedding length, H: Heads,
+   * I: Intermediate embedding length):
+   * - states: Hidden states, of shape (B, T*D).
+   * - H: Head count.
+   * - T: Sequence length.
+   * - d: Embedding length of single token per head with d*H = D.
+   * - I: Intermediate embedding length.
+   * - W_Q: Weights for linear query layer, of shape (D, D).
+   * - b_Q: Biases for linear query layer, of shape (1, D).
+   * - W_K: Weights for linear key layer, of shape (D, D).
+   * - b_K: Biases for linear key layer, of shape (1, D).
+   * - W_V: Weights for linear value layer, of shape (D, D).
+   * - b_V: Biases for linear value layer, of shape (1, D).
+   * - W_context: Weights for linear output layer on context, of shape (D, D).
+   * - b_context: Biases for linear output layer on context, of shape (1, D).
+   * - W_intermediate: Weights for intermediate linear layer, of shape (D, I).
+   * - b_intermediate: Biases for intermediate linear layer, of shape (1, I).
+   * - W_out: Weights for last linear output layer, of shape (I, D).
+   * - b_out: Biases for last linear output layer, of shape (1, D).
+   * - dropout_p_attention: Dropout parameter passed to the attention layer.
+   * - dropout_p_output: Dropout parameter for both output dropouts; a value
+   *     of 0.0 (or less) disables them.
+   * - epsilon_ln: Epsilon value for layer norm.
+   * - gamma_ln1: Gamma params for layer norm 1, of shape (1, D).
+   * - beta_ln1: Beta params for layer norm 1, of shape (1, D).
+   * - gamma_ln2: Gamma params for layer norm 2, of shape (1, D).
+   * - beta_ln2: Beta params for layer norm 2, of shape (1, D).
+   * - activation: Activation to use, either "gelu" or "tanh"; any other
+   *     value stops execution with an error.
+   *
+   * Outputs:
+   * - out_states: Token output states, of shape (B, T*D).
+   * - attention: Attention values for keys & queries, of shape (B, H*T*T).
+   * - outputs: List of intermediate results kept for the backward pass, in
+   *     the following order:
+   *   -> 1: Output of linear query layer, of shape (B, T*D).
+   *   -> 2: Output of linear key layer, of shape (B, T*D).
+   *   -> 3: Output of linear value layer, of shape (B, T*D).
+   *   -> 4: Output context of attention layer, of shape (B, T*D).
+   *   -> 5: Output attention of attention layer, of shape (B, H*T*T).
+   *   -> 6: Output of residual connection 1, of shape (B, T*D).
+   *   -> 7: Output of layer norm 1, of shape (B, T*D).
+   *   -> 8: Output of intermediate linear layer, of shape (B, T*I).
+   *   -> 9: Output of activation layer, of shape (B, T*I).
+   *   -> 10: Output of residual connection 2, of shape (B, T*D).
+   * - dropout_mask_attention: Dropout mask used on attention, of shape (B, H*T*T).
+   * - dropout_mask_output_1: Dropout mask used on the context projection
+   *     output, of shape (B, T*D); a 1x1 zero matrix if dropout_p_output <= 0.
+   * - dropout_mask_output_2: Dropout mask used on the final output projection,
+   *     of shape (B, T*D); a 1x1 zero matrix if dropout_p_output <= 0.
+   * - cache_mean_ln1: Cached mean from layer norm 1, of shape (1, B*T).
+   * - cache_var_ln1: Cached variance from layer norm 1, of shape (1, B*T).
+   * - cache_norm_ln1: Cached normalized values from layer norm 1, as returned
+   *     by batch_norm::forward (NOTE(review): presumably (D, B*T); confirm).
+   * - cache_mean_ln2: Cached mean from layer norm 2, of shape (1, B*T).
+   * - cache_var_ln2: Cached variance from layer norm 2, of shape (1, B*T).
+   * - cache_norm_ln2: Cached normalized values from layer norm 2, as returned
+   *     by batch_norm::forward (NOTE(review): presumably (D, B*T); confirm).
+   */
+  # Embedding dim
+  D = d * H
+
+  # Linear layers for Q, K, V
+  Q = linear_tensor_forward(states, W_Q, b_Q, T, D)  # Shape (B, T*D)
+  K = linear_tensor_forward(states, W_K, b_K, T, D)  # Shape (B, T*D)
+  V = linear_tensor_forward(states, W_V, b_V, T, D)  # Shape (B, T*D)
+
+  # Multi-head self attention
+  [context, attention, dropout_mask_attention] = attention::forward(Q, K, V, H, T, d, dropout_p_attention)
+  # Shapes (B, T*D), (B, H*T*T), (B, H*T*T)
+  outputs = list(Q, K, V, context, attention)
+
+  # Linear layer on attention output (output layer)
+  out_states = linear_tensor_forward(context, W_context, b_context, T, D)  # Shape (B, T*D)
+  # Dropout on output 1; the mask stays a 1x1 zero placeholder when disabled.
+  # NOTE(review): confirm whether dropout::forward interprets p as a keep or
+  # a drop probability; this layer documents dropout_p_output as the latter.
+  dropout_mask_output_1 = matrix(0, 1, 1)
+  if (dropout_p_output > 0.0) {
+    [out_states, dropout_mask_output_1] = dropout::forward(out_states, dropout_p_output, -1)
+  }
+
+  # Residual connection 1
+  out_states = out_states + states  # Shape (B, T*D)
+  outputs = append(outputs, out_states)
+  # Layer norm 1 for each token
+  [out_states, cache_mean_ln1, cache_var_ln1, cache_norm_ln1] = layer_norm_forward(
+    out_states, gamma_ln1, beta_ln1, epsilon_ln, T, D)
+  outputs = append(outputs, out_states)
+
+  # Keep layer norm 1 output for residual connection 2
+  out_states_identity = out_states
+  # Linear layer of intermediate part
+  out_states = linear_tensor_forward(out_states, W_intermediate, b_intermediate, T, D)  # Shape (B, T*I)
+  outputs = append(outputs, out_states)
+  # Activation; unsupported names are rejected instead of silently being
+  # treated as the identity
+  if (activation == "gelu") {
+    out_states = gelu::forward(out_states)
+  } else if (activation == "tanh") {
+    out_states = tanh::forward(out_states)
+  } else {
+    stop("bert_layer::forward: unsupported activation '" + activation + "', expected 'gelu' or 'tanh'")
+  }
+  outputs = append(outputs, out_states)
+
+  # Final linear output layer
+  out_states = linear_tensor_forward(out_states, W_out, b_out, T, I)  # Shape (B, T*D)
+  # Dropout on output 2; the mask stays a 1x1 zero placeholder when disabled
+  dropout_mask_output_2 = matrix(0, 1, 1)
+  if (dropout_p_output > 0.0) {
+    [out_states, dropout_mask_output_2] = dropout::forward(out_states, dropout_p_output, -1)
+  }
+  # Residual connection 2
+  out_states = out_states + out_states_identity
+  outputs = append(outputs, out_states)
+  # Layer norm 2 for each token
+  [out_states, cache_mean_ln2, cache_var_ln2, cache_norm_ln2] = layer_norm_forward(
+    out_states, gamma_ln2, beta_ln2, epsilon_ln, T, D)
+}
diff --git a/src/test/java/org/apache/sysds/test/applications/nn/transformers/BertLayerTest.java b/src/test/java/org/apache/sysds/test/applications/nn/transformers/BertLayerTest.java
new file mode 100644
index 0000000000..583288dfbc
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/applications/nn/transformers/BertLayerTest.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysds.test.applications.nn.transformers;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+/**
+ * Component tests for the BERT layer DML built-in ({@code nn/layers/bert_layer.dml}).
+ * Each test runs a DML script against fixed CSV inputs and compares its outputs
+ * with reference CSV outputs; the script reports the maximum absolute errors.
+ */
+public class BertLayerTest extends AutomatedTestBase {
+	private static final String TEST_NAME_FORWARD = "bert_layer_forward";
+	private static final String TEST_DIR = "applications/nn/component/";
+	private static final String RESOURCE_DIR = "src/test/resources/component/transformers/bert_layer/";
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME_FORWARD, new TestConfiguration(TEST_DIR, TEST_NAME_FORWARD));
+	}
+
+	@Test
+	public void testBertLayerForwardNormalTanh() {
+		runBertLayerTest("test1", 5, 4, 6, 2, 3, 7, "tanh", 0, TEST_NAME_FORWARD, 1e-5, true);
+	}
+
+	@Test
+	public void testBertLayerForwardNormalGelu() {
+		runBertLayerTest("test2", 4, 4, 8, 2, 4, 7, "gelu", 0, TEST_NAME_FORWARD, 1e-5, true);
+	}
+
+	/**
+	 * Runs the DML test script for one fixture set and checks the reported errors.
+	 *
+	 * @param testSuffix               suffix of the CSV fixture files (e.g. "test1")
+	 * @param batchSize                batch size B
+	 * @param seqLength                sequence length T
+	 * @param embeddingDim             embedding length D
+	 * @param numHeads                 number of attention heads H
+	 * @param perHeadEmbeddingDim      per-head embedding length d
+	 * @param intermediateEmbeddingDim intermediate embedding length I
+	 * @param activation               activation name passed to the script
+	 * @param debug                    debug flag passed as first script argument
+	 * @param testname                 name of the registered test configuration
+	 * @param precision                exclusive upper bound for every reported max error
+	 * @param isForward                must be true; only the forward pass is tested so far
+	 */
+	private void runBertLayerTest(String testSuffix, int batchSize, int seqLength, int embeddingDim, int numHeads,
+		int perHeadEmbeddingDim, int intermediateEmbeddingDim, String activation, int debug, String testname,
+		double precision, boolean isForward) {
+		// Only a forward-pass script and reference outputs exist for the BERT layer.
+		if (!isForward)
+			throw new UnsupportedOperationException("BERT layer backward pass test is not implemented");
+
+		// Set execution platform
+		Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);
+
+		try {
+			// Load test configuration
+			getAndLoadTestConfiguration(testname);
+			fullDMLScriptName = getScript();
+
+			// Program arguments
+			programArgs = new String[] {
+				"-stats", "-args",
+				String.valueOf(debug), String.valueOf(batchSize),
+				String.valueOf(seqLength), String.valueOf(embeddingDim),
+				String.valueOf(numHeads), String.valueOf(perHeadEmbeddingDim),
+				String.valueOf(intermediateEmbeddingDim), activation,
+				RESOURCE_DIR + "input_states_" + testSuffix + ".csv",
+				RESOURCE_DIR + "input_W_Q_" + testSuffix + ".csv",
+				RESOURCE_DIR + "input_b_Q_" + testSuffix + ".csv",
+				RESOURCE_DIR + "input_W_K_" + testSuffix + ".csv",
+				RESOURCE_DIR + "input_b_K_" + testSuffix + ".csv",
+				RESOURCE_DIR + "input_W_V_" + testSuffix + ".csv",
+				RESOURCE_DIR + "input_b_V_" + testSuffix + ".csv",
+				RESOURCE_DIR + "input_W_context_" + testSuffix + ".csv",
+				RESOURCE_DIR + "input_b_context_" + testSuffix + ".csv",
+				RESOURCE_DIR + "input_W_intermediate_" + testSuffix + ".csv",
+				RESOURCE_DIR + "input_b_intermediate_" + testSuffix + ".csv",
+				RESOURCE_DIR + "input_W_out_" + testSuffix + ".csv",
+				RESOURCE_DIR + "input_b_out_" + testSuffix + ".csv",
+				RESOURCE_DIR + "input_gamma_ln1_" + testSuffix + ".csv",
+				RESOURCE_DIR + "input_beta_ln1_" + testSuffix + ".csv",
+				RESOURCE_DIR + "input_gamma_ln2_" + testSuffix + ".csv",
+				RESOURCE_DIR + "input_beta_ln2_" + testSuffix + ".csv",
+				RESOURCE_DIR + "output_states_" + testSuffix + ".csv",
+				RESOURCE_DIR + "output_attention_" + testSuffix + ".csv",
+				output("states_error"),
+				output("attention_error"),
+			};
+
+			// Run the test
+			runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+
+			// Compare results
+			assertMaxErrorBelow("states_error", precision);
+			assertMaxErrorBelow("attention_error", precision);
+		} catch (AssertionError e) {
+			// Let JUnit report comparison failures as failures, not wrapped errors
+			throw e;
+		} catch (Throwable ex) {
+			ex.printStackTrace(System.out); // Log or debug all exceptions or errors
+			throw new RuntimeException(ex);
+		} finally {
+			resetExecMode(platformOld);
+		}
+	}
+
+	/**
+	 * Reads a scalar error written by the DML script and fails if it is not below
+	 * {@code precision}. Uses an explicit {@link AssertionError} rather than the
+	 * {@code assert} statement, which is skipped when the JVM runs without -ea.
+	 * A NaN error also fails, since NaN compares false to everything.
+	 */
+	private void assertMaxErrorBelow(String outputName, double precision) {
+		double maxError = (Double) readDMLScalarFromOutputDir(outputName).values().toArray()[0];
+		if (!(maxError < precision))
+			throw new AssertionError(outputName + " = " + maxError + " is not below " + precision);
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/applications/nn/transformers/MultiAttentionLayerTest.java b/src/test/java/org/apache/sysds/test/applications/nn/transformers/MultiAttentionLayerTest.java
index ba10f50e89..225d8983aa 100644
--- a/src/test/java/org/apache/sysds/test/applications/nn/transformers/MultiAttentionLayerTest.java
+++ b/src/test/java/org/apache/sysds/test/applications/nn/transformers/MultiAttentionLayerTest.java
@@ -119,7 +119,7 @@ public class MultiAttentionLayerTest extends AutomatedTestBase {
                        if (isForward) {
                                double contextMaxError = (Double) 
readDMLScalarFromOutputDir("context_error").values().toArray()[0];
                                assert contextMaxError < precision;
-                               double attentionMaxError = (Double) 
readDMLScalarFromOutputDir("context_error").values().toArray()[0];
+                               double attentionMaxError = (Double) 
readDMLScalarFromOutputDir("attention_error").values().toArray()[0];
                                assert attentionMaxError < precision;
                        } else {
                                double dqueryMaxError = (Double) 
readDMLScalarFromOutputDir("dquery_error").values().toArray()[0];
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_W_K_test1.csv 
b/src/test/resources/component/transformers/bert_layer/input_W_K_test1.csv
new file mode 100644
index 0000000000..a4aeec399f
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_W_K_test1.csv
@@ -0,0 +1,6 @@
+-0.366342,-0.367447,-0.368428,0.043807,-0.098173,-0.287969
+-0.352497,-0.027552,0.258298,-0.085474,-0.085857,0.018208
+-0.063882,0.359021,-0.047110,0.291535,-0.336430,-0.287788
+0.005280,-0.166521,-0.182245,0.113960,0.221207,-0.224734
+-0.185457,0.368649,0.326457,0.196166,0.324140,-0.237889
+0.153787,0.147849,-0.329904,0.144177,0.279334,0.139517
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_W_K_test2.csv 
b/src/test/resources/component/transformers/bert_layer/input_W_K_test2.csv
new file mode 100644
index 0000000000..717ab3e11e
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_W_K_test2.csv
@@ -0,0 +1,8 @@
+0.014873,0.127848,-0.276551,-0.306393,0.044137,0.116741,0.004873,-0.350424
+0.227909,0.044769,-0.185308,0.175143,0.316675,0.265246,-0.060110,0.159592
+-0.267258,-0.002632,0.285492,-0.251829,0.216273,-0.113814,-0.186207,-0.169799
+-0.242719,-0.069891,-0.286925,-0.100361,-0.223521,0.000566,0.046730,-0.235940
+-0.205295,0.044359,-0.025387,-0.118623,0.158570,0.182018,0.292360,-0.203683
+0.247464,-0.080732,0.349749,-0.052357,-0.249925,-0.341919,-0.103351,0.203278
+-0.127090,-0.002484,0.127717,0.003867,-0.149845,0.255612,-0.209903,0.187233
+0.298218,0.045111,0.010010,0.291613,0.103988,-0.292361,-0.130758,0.271360
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_W_Q_test1.csv 
b/src/test/resources/component/transformers/bert_layer/input_W_Q_test1.csv
new file mode 100644
index 0000000000..38999e5b0d
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_W_Q_test1.csv
@@ -0,0 +1,6 @@
+0.160208,-0.102197,-0.265960,0.082622,-0.366738,-0.382060
+0.024454,-0.198863,-0.020951,0.259580,-0.193541,-0.344565
+-0.199196,-0.142819,0.292245,0.386712,0.277978,-0.082808
+0.193179,-0.334609,-0.041968,0.259260,-0.002646,0.223886
+-0.391612,-0.086841,0.011346,0.387596,-0.202918,0.220716
+-0.241971,0.087266,-0.035219,-0.029525,-0.312845,-0.393728
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_W_Q_test2.csv 
b/src/test/resources/component/transformers/bert_layer/input_W_Q_test2.csv
new file mode 100644
index 0000000000..28e781c7ef
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_W_Q_test2.csv
@@ -0,0 +1,8 @@
+-0.123685,0.009826,-0.317605,-0.071714,-0.068100,-0.318219,-0.040798,0.169884
+-0.289780,-0.030501,-0.167612,0.193891,-0.069418,-0.023860,-0.157829,0.124861
+-0.075206,0.071553,0.240736,0.191146,-0.317261,0.310922,0.282720,-0.085021
+0.075574,0.224803,-0.002292,-0.340978,-0.305271,-0.144212,-0.285705,-0.074354
+-0.230328,0.334902,-0.175732,0.220540,-0.055324,0.319260,0.037938,-0.291357
+-0.018144,0.224526,-0.270932,-0.276659,0.004572,0.128041,-0.074023,0.191571
+0.253091,0.335668,-0.330874,-0.074745,-0.160610,-0.319068,0.252477,0.280714
+-0.036345,-0.025570,-0.298402,-0.143356,0.133183,0.223692,0.098693,0.241910
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_W_V_test1.csv 
b/src/test/resources/component/transformers/bert_layer/input_W_V_test1.csv
new file mode 100644
index 0000000000..aebde16a95
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_W_V_test1.csv
@@ -0,0 +1,6 @@
+-0.237054,-0.003039,-0.319334,0.147474,-0.136974,0.249731
+0.285747,-0.080703,-0.213976,0.011559,-0.060456,-0.258100
+-0.146751,0.051221,0.329658,-0.353792,0.004466,0.183101
+0.344353,-0.093221,-0.331313,0.202237,0.336726,-0.288589
+0.147626,-0.002869,-0.029315,-0.290787,0.050965,-0.173026
+0.051695,0.052090,0.403855,-0.115887,0.365665,0.120075
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_W_V_test2.csv 
b/src/test/resources/component/transformers/bert_layer/input_W_V_test2.csv
new file mode 100644
index 0000000000..66d9a2d249
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_W_V_test2.csv
@@ -0,0 +1,8 @@
+-0.334274,0.011264,-0.080849,0.244498,0.069431,0.122792,0.029533,0.165114
+0.076385,-0.237900,-0.015723,-0.263194,-0.262510,-0.129004,0.044147,-0.171997
+-0.198408,-0.285785,-0.215330,0.144839,0.058866,0.134202,-0.277945,-0.292986
+-0.315220,0.281811,0.119572,-0.118884,0.150589,0.235453,0.027785,-0.304028
+0.310023,0.057572,0.111782,-0.170578,0.139947,-0.184608,0.244825,0.352708
+-0.229602,0.293317,-0.007293,0.063514,-0.044505,0.003487,0.318592,0.224432
+-0.040221,-0.118525,-0.079515,-0.183656,-0.289839,0.146194,0.207801,-0.244388
+0.101291,0.104141,-0.217941,0.081460,-0.054502,0.027711,0.047377,0.138325
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_W_context_test1.csv
 
b/src/test/resources/component/transformers/bert_layer/input_W_context_test1.csv
new file mode 100644
index 0000000000..2475b868e9
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/input_W_context_test1.csv
@@ -0,0 +1,6 @@
+0.295156,0.337588,-0.196067,0.148081,-0.193161,0.357983
+-0.337590,-0.119339,-0.272441,-0.136338,-0.191908,-0.265121
+0.005627,-0.242375,-0.235192,-0.114084,-0.385986,-0.046443
+-0.069409,-0.150986,0.234725,0.120609,0.088201,0.116961
+-0.215013,-0.404635,0.216198,0.335594,-0.229102,0.013006
+0.053959,0.184281,0.313339,0.111000,-0.363984,-0.274703
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_W_context_test2.csv
 
b/src/test/resources/component/transformers/bert_layer/input_W_context_test2.csv
new file mode 100644
index 0000000000..828e3cff62
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/input_W_context_test2.csv
@@ -0,0 +1,8 @@
+-0.212761,-0.063619,-0.239211,-0.241276,0.290790,0.116962,-0.199302,0.226094
+0.317851,0.137643,-0.348794,-0.057128,0.289475,0.353439,-0.198385,0.026281
+0.170822,0.062799,-0.283923,-0.229609,0.253087,0.183375,-0.272053,-0.166925
+0.192735,0.150426,0.279120,0.245506,0.272988,0.219786,0.237351,0.324932
+-0.221599,-0.120147,0.191285,-0.267289,0.314375,-0.123741,0.251352,0.144582
+0.101434,0.172383,0.331709,-0.172499,-0.090532,0.169645,-0.040239,-0.268398
+-0.123943,-0.246947,0.283239,-0.341565,0.155564,0.040626,-0.204596,0.338380
+0.276251,0.079852,-0.315739,-0.200728,0.314991,-0.084435,0.273263,0.268479
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_W_intermediate_test1.csv
 
b/src/test/resources/component/transformers/bert_layer/input_W_intermediate_test1.csv
new file mode 100644
index 0000000000..517e834185
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/input_W_intermediate_test1.csv
@@ -0,0 +1,6 @@
+-0.093357,-0.091816,-0.196967,0.067973,0.141788,0.168810,0.282700
+-0.018155,-0.251656,0.073340,0.173885,-0.148960,0.031998,0.367878
+-0.248641,0.282322,-0.212067,0.161597,0.154963,0.034102,0.239948
+0.138070,-0.303910,0.094062,-0.051390,0.271878,0.050976,0.054707
+0.129074,0.167245,0.080172,-0.334677,-0.213168,-0.320943,0.190658
+-0.008422,-0.137275,-0.303120,-0.062933,0.004026,0.032084,-0.198604
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_W_intermediate_test2.csv
 
b/src/test/resources/component/transformers/bert_layer/input_W_intermediate_test2.csv
new file mode 100644
index 0000000000..336e2aa2e8
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/input_W_intermediate_test2.csv
@@ -0,0 +1,8 @@
+-0.095970,0.175377,0.300088,-0.339869,0.016123,0.019002,0.112700
+-0.068211,-0.323126,0.023189,-0.343534,-0.318806,0.167661,0.189790
+0.033835,-0.063266,-0.235639,-0.071723,0.293212,-0.283489,0.049253
+0.326977,-0.262741,-0.126673,0.237741,0.190369,-0.101691,-0.236557
+0.018929,-0.150856,0.077203,-0.334631,0.351431,-0.347146,-0.274117
+-0.218298,0.127383,-0.269520,0.293869,0.178619,-0.137706,-0.109077
+0.018121,-0.251069,0.175649,-0.141429,-0.233370,0.076272,0.155195
+0.169524,0.131425,-0.320980,0.103550,0.295070,-0.277597,0.348744
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_W_out_test1.csv 
b/src/test/resources/component/transformers/bert_layer/input_W_out_test1.csv
new file mode 100644
index 0000000000..b3a89218fb
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_W_out_test1.csv
@@ -0,0 +1,7 @@
+0.377819,-0.227451,0.295325,-0.263997,0.354612,-0.285744
+0.330490,0.339797,-0.068011,0.085365,0.302795,-0.184409
+0.292829,0.182617,0.147146,-0.255728,-0.337539,-0.365148
+-0.086652,0.206042,0.067135,-0.372876,-0.257935,-0.214588
+-0.132643,-0.236899,0.160812,-0.303527,-0.061072,0.310867
+0.310327,0.108438,-0.128443,0.298392,-0.245462,0.309462
+0.211804,-0.132500,0.184285,0.204492,0.262457,0.270561
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_W_out_test2.csv 
b/src/test/resources/component/transformers/bert_layer/input_W_out_test2.csv
new file mode 100644
index 0000000000..247929a979
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_W_out_test2.csv
@@ -0,0 +1,7 @@
+0.181837,0.233912,-0.264798,-0.350101,0.154676,0.095152,-0.091160,-0.001032
+-0.266817,0.209144,0.317217,0.266578,-0.111686,0.134062,0.132476,-0.183122
+-0.187210,0.012154,-0.041096,0.189420,0.126487,-0.184663,-0.273795,0.377318
+-0.311325,-0.116858,-0.316714,0.223724,-0.108745,0.033411,-0.222254,0.369147
+0.197237,-0.082181,-0.204501,0.319958,0.233681,0.219044,-0.191851,-0.285056
+-0.038513,0.050232,0.334430,-0.203704,-0.104868,-0.037606,0.347353,-0.306404
+0.290882,0.187361,0.345659,0.119352,-0.140904,0.115027,-0.101704,-0.286494
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_b_K_test1.csv 
b/src/test/resources/component/transformers/bert_layer/input_b_K_test1.csv
new file mode 100644
index 0000000000..2cf9a2086e
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_b_K_test1.csv
@@ -0,0 +1,6 @@
+-0.243281
+-0.008907
+0.017174
+0.263166
+-0.308603
+-0.280267
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_b_K_test2.csv 
b/src/test/resources/component/transformers/bert_layer/input_b_K_test2.csv
new file mode 100644
index 0000000000..6a3c2d6ae1
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_b_K_test2.csv
@@ -0,0 +1,8 @@
+0.128242
+-0.118073
+-0.098800
+0.104450
+0.290633
+0.096128
+-0.167282
+-0.166197
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_b_Q_test1.csv 
b/src/test/resources/component/transformers/bert_layer/input_b_Q_test1.csv
new file mode 100644
index 0000000000..4d6662d8cf
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_b_Q_test1.csv
@@ -0,0 +1,6 @@
+0.254658
+-0.319458
+-0.086308
+-0.165534
+-0.078635
+-0.080157
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_b_Q_test2.csv 
b/src/test/resources/component/transformers/bert_layer/input_b_Q_test2.csv
new file mode 100644
index 0000000000..3d64bbb0c3
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_b_Q_test2.csv
@@ -0,0 +1,8 @@
+-0.249389
+0.015768
+-0.249232
+-0.194626
+-0.206017
+0.120825
+-0.210687
+-0.007714
diff --git a/src/test/resources/component/transformers/bert_layer/input_b_V_test1.csv b/src/test/resources/component/transformers/bert_layer/input_b_V_test1.csv
new file mode 100644
index 0000000000..5060bfdf87
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_b_V_test1.csv
@@ -0,0 +1,6 @@
+0.134800
+0.306279
+-0.131422
+0.000654
+0.210176
+-0.394814
diff --git a/src/test/resources/component/transformers/bert_layer/input_b_V_test2.csv b/src/test/resources/component/transformers/bert_layer/input_b_V_test2.csv
new file mode 100644
index 0000000000..7d725d1af5
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_b_V_test2.csv
@@ -0,0 +1,8 @@
+0.266992
+0.353417
+0.309145
+0.273917
+-0.081055
+-0.124077
+0.290285
+0.198124
diff --git a/src/test/resources/component/transformers/bert_layer/input_b_context_test1.csv b/src/test/resources/component/transformers/bert_layer/input_b_context_test1.csv
new file mode 100644
index 0000000000..26b70e9390
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_b_context_test1.csv
@@ -0,0 +1,6 @@
+-0.329996
+0.325408
+0.066479
+0.338693
+-0.136861
+0.120252
diff --git a/src/test/resources/component/transformers/bert_layer/input_b_context_test2.csv b/src/test/resources/component/transformers/bert_layer/input_b_context_test2.csv
new file mode 100644
index 0000000000..a5e465902a
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_b_context_test2.csv
@@ -0,0 +1,8 @@
+-0.128867
+0.198752
+-0.200885
+-0.055407
+0.300203
+0.014609
+-0.250039
+-0.118169
diff --git a/src/test/resources/component/transformers/bert_layer/input_b_intermediate_test1.csv b/src/test/resources/component/transformers/bert_layer/input_b_intermediate_test1.csv
new file mode 100644
index 0000000000..219a580d32
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_b_intermediate_test1.csv
@@ -0,0 +1,7 @@
+-0.338311
+-0.351061
+0.407272
+0.259152
+-0.282195
+0.159724
+0.308296
diff --git a/src/test/resources/component/transformers/bert_layer/input_b_intermediate_test2.csv b/src/test/resources/component/transformers/bert_layer/input_b_intermediate_test2.csv
new file mode 100644
index 0000000000..ff877c864d
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_b_intermediate_test2.csv
@@ -0,0 +1,7 @@
+0.203303
+-0.039813
+0.123962
+-0.346858
+-0.301970
+0.164971
+-0.200258
diff --git a/src/test/resources/component/transformers/bert_layer/input_b_out_test1.csv b/src/test/resources/component/transformers/bert_layer/input_b_out_test1.csv
new file mode 100644
index 0000000000..4636941a76
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_b_out_test1.csv
@@ -0,0 +1,6 @@
+0.291837
+0.336081
+-0.096783
+0.166305
+0.336740
+0.125038
diff --git a/src/test/resources/component/transformers/bert_layer/input_b_out_test2.csv b/src/test/resources/component/transformers/bert_layer/input_b_out_test2.csv
new file mode 100644
index 0000000000..a5b08ae460
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_b_out_test2.csv
@@ -0,0 +1,8 @@
+-0.001823
+-0.096345
+-0.247395
+-0.135565
+0.071409
+-0.197483
+0.083750
+-0.086673
diff --git a/src/test/resources/component/transformers/bert_layer/input_beta_ln1_test1.csv b/src/test/resources/component/transformers/bert_layer/input_beta_ln1_test1.csv
new file mode 100644
index 0000000000..5a180c18a8
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_beta_ln1_test1.csv
@@ -0,0 +1,6 @@
+0.000000
+0.000000
+0.000000
+0.000000
+0.000000
+0.000000
diff --git a/src/test/resources/component/transformers/bert_layer/input_beta_ln1_test2.csv b/src/test/resources/component/transformers/bert_layer/input_beta_ln1_test2.csv
new file mode 100644
index 0000000000..69fba5d887
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_beta_ln1_test2.csv
@@ -0,0 +1,8 @@
+0.000000
+0.000000
+0.000000
+0.000000
+0.000000
+0.000000
+0.000000
+0.000000
diff --git a/src/test/resources/component/transformers/bert_layer/input_beta_ln2_test1.csv b/src/test/resources/component/transformers/bert_layer/input_beta_ln2_test1.csv
new file mode 100644
index 0000000000..5a180c18a8
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_beta_ln2_test1.csv
@@ -0,0 +1,6 @@
+0.000000
+0.000000
+0.000000
+0.000000
+0.000000
+0.000000
diff --git a/src/test/resources/component/transformers/bert_layer/input_beta_ln2_test2.csv b/src/test/resources/component/transformers/bert_layer/input_beta_ln2_test2.csv
new file mode 100644
index 0000000000..69fba5d887
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_beta_ln2_test2.csv
@@ -0,0 +1,8 @@
+0.000000
+0.000000
+0.000000
+0.000000
+0.000000
+0.000000
+0.000000
+0.000000
diff --git a/src/test/resources/component/transformers/bert_layer/input_gamma_ln1_test1.csv b/src/test/resources/component/transformers/bert_layer/input_gamma_ln1_test1.csv
new file mode 100644
index 0000000000..abb00d5bfa
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_gamma_ln1_test1.csv
@@ -0,0 +1,6 @@
+1.000000
+1.000000
+1.000000
+1.000000
+1.000000
+1.000000
diff --git a/src/test/resources/component/transformers/bert_layer/input_gamma_ln1_test2.csv b/src/test/resources/component/transformers/bert_layer/input_gamma_ln1_test2.csv
new file mode 100644
index 0000000000..8d61340f4e
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_gamma_ln1_test2.csv
@@ -0,0 +1,8 @@
+1.000000
+1.000000
+1.000000
+1.000000
+1.000000
+1.000000
+1.000000
+1.000000
diff --git a/src/test/resources/component/transformers/bert_layer/input_gamma_ln2_test1.csv b/src/test/resources/component/transformers/bert_layer/input_gamma_ln2_test1.csv
new file mode 100644
index 0000000000..abb00d5bfa
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_gamma_ln2_test1.csv
@@ -0,0 +1,6 @@
+1.000000
+1.000000
+1.000000
+1.000000
+1.000000
+1.000000
diff --git a/src/test/resources/component/transformers/bert_layer/input_gamma_ln2_test2.csv b/src/test/resources/component/transformers/bert_layer/input_gamma_ln2_test2.csv
new file mode 100644
index 0000000000..8d61340f4e
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_gamma_ln2_test2.csv
@@ -0,0 +1,8 @@
+1.000000
+1.000000
+1.000000
+1.000000
+1.000000
+1.000000
+1.000000
+1.000000
diff --git a/src/test/resources/component/transformers/bert_layer/input_states_test1.csv b/src/test/resources/component/transformers/bert_layer/input_states_test1.csv
new file mode 100644
index 0000000000..d173363309
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_states_test1.csv
@@ -0,0 +1,5 @@
+0.496257,0.768222,0.088477,0.132030,0.307423,0.634079,0.490093,0.896445,0.455628,0.632306,0.348893,0.401717,0.022326,0.168859,0.293888,0.518522,0.697668,0.800011,0.161029,0.282269,0.681609,0.915194,0.397100,0.874156
+0.419408,0.552907,0.952738,0.036165,0.185231,0.373417,0.305100,0.932000,0.175910,0.269834,0.150680,0.031720,0.208130,0.929799,0.723109,0.742336,0.526296,0.243658,0.584592,0.033153,0.138717,0.242235,0.815469,0.793161
+0.278252,0.481959,0.819780,0.997067,0.698441,0.567546,0.835243,0.205599,0.593172,0.112347,0.153457,0.241708,0.726237,0.701080,0.203824,0.651054,0.774486,0.436891,0.519091,0.615852,0.810188,0.980097,0.114688,0.316765
+0.696505,0.914275,0.935104,0.941178,0.599507,0.065209,0.545996,0.187197,0.034023,0.944246,0.880180,0.001236,0.593586,0.415770,0.417719,0.271122,0.692278,0.203848,0.683296,0.752854,0.857936,0.686956,0.005132,0.175652
+0.749658,0.604651,0.109958,0.212090,0.970375,0.836909,0.281987,0.374158,0.023701,0.491013,0.123471,0.114322,0.472450,0.575073,0.295235,0.796689,0.195730,0.953685,0.842650,0.078359,0.375558,0.522561,0.572951,0.618587
diff --git a/src/test/resources/component/transformers/bert_layer/input_states_test2.csv b/src/test/resources/component/transformers/bert_layer/input_states_test2.csv
new file mode 100644
index 0000000000..3d1e77719b
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_states_test2.csv
@@ -0,0 +1,4 @@
+0.496257,0.768222,0.088477,0.132030,0.307423,0.634079,0.490093,0.896445,0.455628,0.632306,0.348893,0.401717,0.022326,0.168859,0.293888,0.518522,0.697668,0.800011,0.161029,0.282269,0.681609,0.915194,0.397100,0.874156,0.419408,0.552907,0.952738,0.036165,0.185231,0.373417,0.305100,0.932000
+0.175910,0.269834,0.150680,0.031720,0.208130,0.929799,0.723109,0.742336,0.526296,0.243658,0.584592,0.033153,0.138717,0.242235,0.815469,0.793161,0.278252,0.481959,0.819780,0.997067,0.698441,0.567546,0.835243,0.205599,0.593172,0.112347,0.153457,0.241708,0.726237,0.701080,0.203824,0.651054
+0.774486,0.436891,0.519091,0.615852,0.810188,0.980097,0.114688,0.316765,0.696505,0.914275,0.935104,0.941178,0.599507,0.065209,0.545996,0.187197,0.034023,0.944246,0.880180,0.001236,0.593586,0.415770,0.417719,0.271122,0.692278,0.203848,0.683296,0.752854,0.857936,0.686956,0.005132,0.175652
+0.749658,0.604651,0.109958,0.212090,0.970375,0.836909,0.281987,0.374158,0.023701,0.491013,0.123471,0.114322,0.472450,0.575073,0.295235,0.796689,0.195730,0.953685,0.842650,0.078359,0.375558,0.522561,0.572951,0.618587,0.696214,0.529950,0.256036,0.736594,0.020376,0.203647,0.374835,0.256443
diff --git a/src/test/resources/component/transformers/bert_layer/output_attention_test1.csv b/src/test/resources/component/transformers/bert_layer/output_attention_test1.csv
new file mode 100644
index 0000000000..9612f1aa89
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/output_attention_test1.csv
@@ -0,0 +1,5 @@
+0.260437,0.256136,0.235072,0.248356,0.266629,0.260266,0.230285,0.242820,0.266394,0.266508,0.229160,0.237938,0.269346,0.266699,0.227771,0.236184,0.236945,0.274874,0.228124,0.260057,0.230361,0.266480,0.235668,0.267491,0.241364,0.260988,0.239111,0.258537,0.233467,0.259916,0.241311,0.265306
+0.239317,0.272557,0.243360,0.244766,0.241516,0.266844,0.235070,0.256570,0.234964,0.281846,0.237928,0.245263,0.244160,0.263001,0.241525,0.251314,0.265337,0.225869,0.267915,0.240879,0.269381,0.235652,0.262392,0.232575,0.271319,0.222565,0.265873,0.240242,0.282794,0.240248,0.260288,0.216670
+0.216586,0.264298,0.259379,0.259736,0.229986,0.262832,0.248209,0.258972,0.218763,0.268013,0.250437,0.262787,0.219160,0.267811,0.251059,0.261971,0.263463,0.246347,0.231249,0.258941,0.253863,0.250784,0.235230,0.260123,0.252160,0.259235,0.222052,0.266554,0.263342,0.243720,0.234733,0.258205
+0.234500,0.263956,0.234049,0.267495,0.235018,0.259855,0.236050,0.269078,0.241771,0.256273,0.238365,0.263591,0.232219,0.264606,0.238087,0.265088,0.284496,0.222616,0.234091,0.258797,0.264414,0.222925,0.244467,0.268194,0.275078,0.222873,0.237621,0.264428,0.281372,0.229373,0.234067,0.255188
+0.234040,0.252889,0.257620,0.255452,0.224066,0.263633,0.260296,0.252005,0.225525,0.262858,0.259129,0.252488,0.228838,0.258643,0.258804,0.253715,0.238112,0.242119,0.243021,0.276748,0.245645,0.248420,0.247481,0.258454,0.243159,0.239284,0.243943,0.273614,0.242616,0.242325,0.245512,0.269546
diff --git a/src/test/resources/component/transformers/bert_layer/output_attention_test2.csv b/src/test/resources/component/transformers/bert_layer/output_attention_test2.csv
new file mode 100644
index 0000000000..b6e5f467e1
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/output_attention_test2.csv
@@ -0,0 +1,4 @@
+0.235963,0.286636,0.247663,0.229737,0.238065,0.274944,0.247426,0.239566,0.232993,0.292855,0.246338,0.227815,0.238436,0.277112,0.245687,0.238766,0.263707,0.244682,0.250010,0.241600,0.260119,0.247386,0.250597,0.241898,0.263318,0.247697,0.249780,0.239204,0.262353,0.253734,0.254431,0.229482
+0.200993,0.247095,0.287640,0.264272,0.209195,0.248364,0.281387,0.261054,0.203293,0.250675,0.284920,0.261113,0.196908,0.246377,0.297310,0.259405,0.275405,0.250326,0.237604,0.236665,0.275410,0.243066,0.237995,0.243529,0.268634,0.238802,0.252534,0.240029,0.260009,0.251079,0.252050,0.236862
+0.233736,0.309380,0.193616,0.263268,0.234876,0.298143,0.208577,0.258404,0.233731,0.292249,0.219675,0.254344,0.236114,0.298053,0.203767,0.262066,0.252833,0.254184,0.246234,0.246750,0.267395,0.239596,0.231539,0.261470,0.250207,0.258886,0.242438,0.248469,0.253909,0.258431,0.234523,0.253137
+0.260840,0.206205,0.213076,0.319878,0.260534,0.215267,0.221396,0.302804,0.256224,0.219768,0.224008,0.300000,0.263226,0.215017,0.223917,0.297839,0.241882,0.252050,0.249805,0.256263,0.244105,0.250428,0.252672,0.252795,0.247877,0.255111,0.241012,0.256000,0.235736,0.265862,0.248117,0.250285
diff --git a/src/test/resources/component/transformers/bert_layer/output_states_test1.csv b/src/test/resources/component/transformers/bert_layer/output_states_test1.csv
new file mode 100644
index 0000000000..6cee7e5a41
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/output_states_test1.csv
@@ -0,0 +1,5 @@
+-0.664551,1.349482,-1.209526,0.404443,-0.965478,1.085631,-0.936040,1.337401,-0.616370,1.112321,-1.259653,0.362341,-1.642397,0.013962,-0.763702,1.117529,0.063756,1.210851,-1.415662,0.038215,-0.195357,1.344599,-0.916016,1.144221
+-1.229007,1.304426,0.646766,0.037673,-1.399634,0.639775,-0.692394,1.919591,-0.705923,0.614814,-0.964332,-0.171756,-1.462572,1.483334,-0.087375,1.049626,-0.676030,-0.306984,-0.724915,-0.395985,-1.570939,0.557762,0.705374,1.428703
+-1.760529,0.281079,0.029176,1.602762,-0.422803,0.270315,-0.195682,0.251327,0.200025,0.764210,-2.051764,1.031883,-0.751686,0.987462,-1.608172,1.368675,-0.012139,0.015861,-0.905140,0.669941,0.167975,1.464571,-1.588176,0.190830
+-0.998164,1.258126,0.129792,1.359412,-0.760246,-0.988922,-0.510680,-0.110131,-1.239414,1.743648,0.814538,-0.697961,-1.430896,1.009113,-1.063847,0.907095,0.926561,-0.348027,-0.644755,1.156432,0.229766,1.029308,-1.770506,-0.000245
+-0.750959,0.645399,-1.840763,0.233233,0.622690,1.090399,-0.896813,0.906920,-0.829986,1.581897,-1.048696,0.286679,-0.901825,0.554419,-0.745103,1.129790,-1.252949,1.215668,-0.462463,-0.840367,-0.949719,1.467577,-0.533816,1.318788
diff --git a/src/test/resources/component/transformers/bert_layer/output_states_test2.csv b/src/test/resources/component/transformers/bert_layer/output_states_test2.csv
new file mode 100644
index 0000000000..ea1dc320a0
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/output_states_test2.csv
@@ -0,0 +1,4 @@
+0.107880,1.019378,-1.493256,-1.389399,0.794790,0.638421,-0.760268,1.082454,0.502346,1.475414,-1.059900,-0.800042,0.679010,-0.141650,-1.574906,0.919728,0.197951,0.675718,-1.499572,-1.111723,1.262451,0.866343,-1.089658,0.698491,0.020563,0.650392,0.510717,-1.720302,0.648202,0.237416,-1.555443,1.208454
+-0.479673,0.021814,-1.219189,-1.467430,0.830685,1.627373,-0.183904,0.870324,0.499331,-0.030822,-0.247241,-2.259303,0.867174,0.090776,-0.247555,1.327639,-0.698871,-0.044917,-0.524429,0.206017,2.304515,0.538566,-0.838074,-0.942806,0.366909,-0.345850,-1.177876,-0.736904,1.781455,0.944009,-1.253939,0.422195
+0.448434,-0.028957,-0.718420,-0.310317,1.695408,1.122331,-1.657827,-0.550652,0.321436,1.200561,0.028985,0.019767,1.668159,-1.221590,-1.277247,-0.740072,-0.742072,1.290652,0.293041,-1.603472,1.570833,0.167195,-0.766135,-0.210041,0.479908,-0.365131,-0.354259,0.115244,1.883001,0.681843,-1.724637,-0.715968
+0.389834,0.344840,-1.333194,-0.933634,1.855805,0.720664,-1.056654,0.012339,-0.525639,0.504899,-1.274403,-1.064770,1.492330,0.777126,-0.923244,1.013701,-0.584799,1.352766,-0.084635,-1.859024,1.122414,0.390217,-0.837819,0.500880,1.150180,0.986524,-1.630600,0.097320,0.872525,-0.111189,-1.498034,0.133275
diff --git a/src/test/scripts/applications/nn/component/bert_layer_forward.dml b/src/test/scripts/applications/nn/component/bert_layer_forward.dml
new file mode 100644
index 0000000000..6801d4d1df
--- /dev/null
+++ b/src/test/scripts/applications/nn/component/bert_layer_forward.dml
@@ -0,0 +1,92 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("scripts/nn/layers/bert_layer.dml") as bert_layer
+
+# Test driver for bert_layer::forward; all inputs are positional arguments:
+#   $1        debug flag (integer; non-zero prints outputs and errors)
+#   $2..$7    sizes B, T, D, H, d, I (states is B x T*D, W_intermediate is D x I)
+#   $8        activation name, passed through to bert_layer::forward
+#   $9..$25   input CSV paths, in the order they are read below
+#   $26, $27  expected out_states / attention CSV paths
+#   $28, $29  output paths for the out_states / attention max-abs errors
+debug = as.logical(as.integer($1))
+B = as.integer($2)
+T = as.integer($3)
+D = as.integer($4)
+H = as.integer($5)
+d = as.integer($6)
+I = as.integer($7)
+activation = $8
+
+# Dropout probabilities are fixed to 0 so the result can be compared against
+# the reference CSVs.
+dropout_p_attention = 0.0
+dropout_p_output = 0.0
+epsilon_ln = 1e-012
+
+states = matrix(read($9, format="csv"), rows=B, cols=T*D)
+W_Q = matrix(read($10, format="csv"), rows=D, cols=D)
+b_Q = matrix(read($11, format="csv"), rows=1, cols=D)
+W_K = matrix(read($12, format="csv"), rows=D, cols=D)
+b_K = matrix(read($13, format="csv"), rows=1, cols=D)
+W_V = matrix(read($14, format="csv"), rows=D, cols=D)
+b_V = matrix(read($15, format="csv"), rows=1, cols=D)
+W_context = matrix(read($16, format="csv"), rows=D, cols=D)
+b_context = matrix(read($17, format="csv"), rows=1, cols=D)
+W_intermediate = matrix(read($18, format="csv"), rows=D, cols=I)
+b_intermediate = matrix(read($19, format="csv"), rows=1, cols=I)
+W_out = matrix(read($20, format="csv"), rows=I, cols=D)
+b_out = matrix(read($21, format="csv"), rows=1, cols=D)
+gamma_ln1 = matrix(read($22, format="csv"), rows=1, cols=D)
+beta_ln1 = matrix(read($23, format="csv"), rows=1, cols=D)
+gamma_ln2 = matrix(read($24, format="csv"), rows=1, cols=D)
+beta_ln2 = matrix(read($25, format="csv"), rows=1, cols=D)
+
+expected_out_states = read($26, format="csv")
+expected_attention = read($27, format="csv")
+
+[out_states, attention, outputs, dropout_mask_attention, dropout_mask_output_1,
+  dropout_mask_output_2, cache_mean_ln1, cache_var_ln1, cache_norm_ln1,
+  cache_mean_ln2, cache_var_ln2, cache_norm_ln2] = bert_layer::forward(states,
+    H, T, d, I,
+    W_Q, b_Q, W_K, b_K, W_V, b_V,
+    W_context, b_context,
+    W_intermediate, b_intermediate,
+    W_out, b_out,
+    dropout_p_attention, dropout_p_output, epsilon_ln,
+    gamma_ln1, beta_ln1, gamma_ln2, beta_ln2,
+    activation)
+
+# Only out_states and attention are checked; the other returned values are unused.
+# Each error is the largest absolute element-wise deviation from the reference.
+states_error = max(abs(expected_out_states - out_states))
+attention_error = max(abs(expected_attention - attention))
+
+if (debug) {
+    print(toString(out_states))
+    print(toString(attention))
+    print(states_error)
+    print(attention_error)
+}
+
+write(states_error, $28, format="text")
+write(attention_error, $29, format="text")

Reply via email to