This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 22642a1bdd [SYSTEMDS-3829] BERT layer backward pass
22642a1bdd is described below

commit 22642a1bddf6a2755069f80413b04216b2dd1a89
Author: MaximilianSchreff <[email protected]>
AuthorDate: Wed Feb 5 22:26:45 2025 +0100

    [SYSTEMDS-3829] BERT layer backward pass
    
    Closes #2213
---
 scripts/nn/layers/bert_layer.dml                   | 182 +++++++++++++++++++
 .../nn/transformers/BertLayerTest.java             | 116 +++++++++++-
 .../transformers/bert_layer/input_W_K_test3.csv    |   8 +
 .../transformers/bert_layer/input_W_K_test4.csv    |   4 +
 .../transformers/bert_layer/input_W_Q_test3.csv    |   8 +
 .../transformers/bert_layer/input_W_Q_test4.csv    |   4 +
 .../transformers/bert_layer/input_W_V_test3.csv    |   8 +
 .../transformers/bert_layer/input_W_V_test4.csv    |   4 +
 .../bert_layer/input_W_context_test3.csv           |   8 +
 .../bert_layer/input_W_context_test4.csv           |   4 +
 .../bert_layer/input_W_intermediate_test3.csv      |   8 +
 .../bert_layer/input_W_intermediate_test4.csv      |   4 +
 .../transformers/bert_layer/input_W_out_test3.csv  |   5 +
 .../transformers/bert_layer/input_W_out_test4.csv  |   4 +
 .../transformers/bert_layer/input_b_K_test3.csv    |   8 +
 .../transformers/bert_layer/input_b_K_test4.csv    |   4 +
 .../transformers/bert_layer/input_b_Q_test3.csv    |   8 +
 .../transformers/bert_layer/input_b_Q_test4.csv    |   4 +
 .../transformers/bert_layer/input_b_V_test3.csv    |   8 +
 .../transformers/bert_layer/input_b_V_test4.csv    |   4 +
 .../bert_layer/input_b_context_test3.csv           |   8 +
 .../bert_layer/input_b_context_test4.csv           |   4 +
 .../bert_layer/input_b_intermediate_test3.csv      |   5 +
 .../bert_layer/input_b_intermediate_test4.csv      |   4 +
 .../transformers/bert_layer/input_b_out_test3.csv  |   8 +
 .../transformers/bert_layer/input_b_out_test4.csv  |   4 +
 .../bert_layer/input_beta_ln1_test3.csv            |   8 +
 .../bert_layer/input_beta_ln1_test4.csv            |   4 +
 .../bert_layer/input_beta_ln2_test3.csv            |   8 +
 .../bert_layer/input_beta_ln2_test4.csv            |   4 +
 .../bert_layer/input_dstates_test3.csv             |   2 +
 .../bert_layer/input_dstates_test4.csv             |   4 +
 .../bert_layer/input_gamma_ln1_test3.csv           |   8 +
 .../bert_layer/input_gamma_ln1_test4.csv           |   4 +
 .../bert_layer/input_gamma_ln2_test3.csv           |   8 +
 .../bert_layer/input_gamma_ln2_test4.csv           |   4 +
 .../transformers/bert_layer/input_states_test3.csv |   2 +
 .../transformers/bert_layer/input_states_test4.csv |   4 +
 .../bert_layer/output_attention_test3.csv          |   2 +
 .../bert_layer/output_attention_test4.csv          |   4 +
 .../transformers/bert_layer/output_dW_K_test3.csv  |   8 +
 .../transformers/bert_layer/output_dW_K_test4.csv  |   4 +
 .../transformers/bert_layer/output_dW_Q_test3.csv  |   8 +
 .../transformers/bert_layer/output_dW_Q_test4.csv  |   4 +
 .../transformers/bert_layer/output_dW_V_test3.csv  |   8 +
 .../transformers/bert_layer/output_dW_V_test4.csv  |   4 +
 .../bert_layer/output_dW_context_test3.csv         |   8 +
 .../bert_layer/output_dW_context_test4.csv         |   4 +
 .../bert_layer/output_dW_intermediate_test3.csv    |   8 +
 .../bert_layer/output_dW_intermediate_test4.csv    |   4 +
 .../bert_layer/output_dW_out_test3.csv             |   5 +
 .../bert_layer/output_dW_out_test4.csv             |   4 +
 .../transformers/bert_layer/output_db_K_test3.csv  |   8 +
 .../transformers/bert_layer/output_db_K_test4.csv  |   4 +
 .../transformers/bert_layer/output_db_Q_test3.csv  |   8 +
 .../transformers/bert_layer/output_db_Q_test4.csv  |   4 +
 .../transformers/bert_layer/output_db_V_test3.csv  |   8 +
 .../transformers/bert_layer/output_db_V_test4.csv  |   4 +
 .../bert_layer/output_db_context_test3.csv         |   8 +
 .../bert_layer/output_db_context_test4.csv         |   4 +
 .../bert_layer/output_db_intermediate_test3.csv    |   5 +
 .../bert_layer/output_db_intermediate_test4.csv    |   4 +
 .../bert_layer/output_db_out_test3.csv             |   8 +
 .../bert_layer/output_db_out_test4.csv             |   4 +
 .../bert_layer/output_dbeta_ln1_test3.csv          |   8 +
 .../bert_layer/output_dbeta_ln1_test4.csv          |   4 +
 .../bert_layer/output_dbeta_ln2_test3.csv          |   8 +
 .../bert_layer/output_dbeta_ln2_test4.csv          |   4 +
 .../bert_layer/output_dgamma_ln1_test3.csv         |   8 +
 .../bert_layer/output_dgamma_ln1_test4.csv         |   4 +
 .../bert_layer/output_dgamma_ln2_test3.csv         |   8 +
 .../bert_layer/output_dgamma_ln2_test4.csv         |   4 +
 .../bert_layer/output_dstates_test3.csv            |   2 +
 .../bert_layer/output_dstates_test4.csv            |   4 +
 .../bert_layer/output_states_test3.csv             |   2 +
 .../bert_layer/output_states_test4.csv             |   4 +
 .../nn/component/bert_layer_backward.dml           | 198 +++++++++++++++++++++
 77 files changed, 892 insertions(+), 6 deletions(-)

diff --git a/scripts/nn/layers/bert_layer.dml b/scripts/nn/layers/bert_layer.dml
index 75b33fb263..e1b1fdf362 100644
--- a/scripts/nn/layers/bert_layer.dml
+++ b/scripts/nn/layers/bert_layer.dml
@@ -37,6 +37,17 @@ linear_tensor_forward = function(matrix[double] X, 
matrix[double] W, matrix[doub
   out = matrix(out, rows=A, cols=B*C_new)
 }
 
+linear_tensor_backward = function(matrix[double] dout, matrix[double] X, 
matrix[double] W, matrix[double] b, int B,
+    int C_out, int C_in)
+  return (matrix[double] dX, matrix[double] dW, matrix[double] db) {
+  /*
+   * Helper function for the backward pass of a linear layer with tensor input X of shape (A, B*C_in) and upstream gradient dout of shape (A, B*C_out); returns dX of shape (A, B*C_in) together with dW and db from affine::backward.
+   */
+  A = nrow(X)
+  [dX, dW, db] = affine::backward(matrix(dout, rows=A*B, cols=C_out), 
matrix(X, rows=A*B, cols=C_in), W, b)  # flatten to one row per token (A*B rows)
+  dX = matrix(dX, rows=A, cols=B*C_in)  # restore the (A, B*C_in) tensor layout
+}
+
 layer_norm_forward = function(matrix[double] X, matrix[double] gamma, 
matrix[double] beta, double epsilon, int B, int C)
   return (matrix[double] out, matrix[double] cache_mean, matrix[double] 
cache_var, matrix[double] cache_norm) {
   /*
@@ -51,6 +62,27 @@ layer_norm_forward = function(matrix[double] X, 
matrix[double] gamma, matrix[dou
   out = matrix(t(batch_norm_out), rows=A, cols=B*C)
 }
 
+layer_norm_backward = function(matrix[double] dout, matrix[double] cache_mean, 
matrix[double] cache_var,
+    matrix[double] cache_norm, matrix[double] X, matrix[double] gamma, 
matrix[double] beta, double epsilon, int B, int C)
+  return (matrix[double] dX, matrix[double] dgamma, matrix[double] dbeta) {
+  /*
+   * Helper function for the backward pass of layer norm, computed via 1D batch norm on the transposed tensor input X of shape (A, B*C); returns dX of shape (A, B*C) and dgamma, dbeta of shape (1, C).
+   */
+  A = nrow(X)
+  batch_norm_input = t(matrix(X, rows=A*B, cols=C))  # (C, A*B): one column per token
+  batch_norm_doutput = t(matrix(dout, rows=A*B, cols=C))  # same (C, A*B) layout as the input
+  # The out matrix and the (updated) EMA matrices are unused, so 1x1 placeholder matrices are provided; the 2nd/3rd outputs of batch_norm::backward are discarded and dgamma/dbeta are recomputed per feature below
+  empty_mat = matrix(0, rows=1, cols=1)
+  [batch_norm_dX, unused1, unused2] = batch_norm::backward(
+    batch_norm_doutput,
+    empty_mat, empty_mat, empty_mat,
+    cache_mean, cache_var, cache_norm,
+    batch_norm_input, t(gamma), t(beta), "train", empty_mat, empty_mat, 0.0, 
epsilon)
+  dX = matrix(t(batch_norm_dX), rows=A, cols=B*C)  # back to the (A, B*C) layout
+  dgamma = t(rowSums(batch_norm_doutput * cache_norm))  # per-feature sum over all A*B tokens -> (1, C)
+  dbeta = t(rowSums(batch_norm_doutput))  # per-feature sum over all A*B tokens -> (1, C)
+}
+
 forward = function(matrix[double] states,
       int H, int T, int d, int I,
       matrix[double] W_Q, matrix[double] b_Q, 
@@ -184,3 +216,153 @@ forward = function(matrix[double] states,
   [out_states, cache_mean_ln2, cache_var_ln2, cache_norm_ln2] = 
layer_norm_forward(
     out_states, gamma_ln2, beta_ln2, epsilon_ln, T, D)
 }
+
+backward = function(matrix[double] dout_states,
+      matrix[double] dropout_mask_attention,
+      matrix[double] dropout_mask_output_1,
+      matrix[double] dropout_mask_output_2,
+      matrix[double] cache_mean_ln1, matrix[double] cache_var_ln1, 
matrix[double] cache_norm_ln1,
+      matrix[double] cache_mean_ln2, matrix[double] cache_var_ln2, 
matrix[double] cache_norm_ln2,
+      list[unknown] outputs,
+      matrix[double] states,
+      int H, int T, int d, int I,
+      matrix[double] W_Q, matrix[double] b_Q,
+      matrix[double] W_K, matrix[double] b_K,
+      matrix[double] W_V, matrix[double] b_V,
+      matrix[double] W_context, matrix[double] b_context,
+      matrix[double] W_intermediate, matrix[double] b_intermediate,
+      matrix[double] W_out, matrix[double] b_out,
+      double dropout_p_attention,
+      double dropout_p_output,
+      double epsilon_ln,
+      matrix[double] gamma_ln1, matrix[double] beta_ln1,
+      matrix[double] gamma_ln2, matrix[double] beta_ln2,
+      string activation)
+    return (matrix[double] din_states,
+      matrix[double] dW_Q, matrix[double] db_Q,
+      matrix[double] dW_K, matrix[double] db_K,
+      matrix[double] dW_V, matrix[double] db_V,
+      matrix[double] dW_context, matrix[double] db_context,
+      matrix[double] dW_intermediate, matrix[double] db_intermediate,
+      matrix[double] dW_out, matrix[double] db_out,
+      matrix[double] dgamma_ln1, matrix[double] dbeta_ln1,
+      matrix[double] dgamma_ln2, matrix[double] dbeta_ln2) {
+  /*
+   * Computes the backward pass for a layer of the BERT transformer architecture.
+   *
+   * Inputs (B: Batch size, T: Sequence length, D: Embedding length, H: Heads, I: Intermediate length):
+   * - dout_states: Gradients w.r.t. output states, of shape (B, T*D)
+   * - dropout_mask_attention: Dropout mask used on attention, of shape (B, H*T*T)
+   * - dropout_mask_output_1: Dropout mask used on the output of the linear context layer, of shape (B, T*D)
+   * - dropout_mask_output_2: Dropout mask used on the output of the last linear output layer, of shape (B, T*D)
+   * - cache_mean_ln1: Cached mean from layer norm 1, of shape (1, B*T)
+   * - cache_var_ln1: Cached variance from layer norm 1, of shape (1, B*T)
+   * - cache_norm_ln1: Cached normalized input from layer norm 1, of shape (D, B*T)
+   * - cache_mean_ln2: Cached mean from layer norm 2, of shape (1, B*T)
+   * - cache_var_ln2: Cached variance from layer norm 2, of shape (1, B*T)
+   * - cache_norm_ln2: Cached normalized input from layer norm 2, of shape (D, B*T)
+   * - outputs: list of relevant outputs from forward pass
+   *     with the following order/content:
+   *   -> 1: Output of linear query layer, of shape (B, T*D).
+   *   -> 2: Output of linear key layer, of shape (B, T*D).
+   *   -> 3: Output of linear value layer, of shape (B, T*D).
+   *   -> 4: Output context of attention layer, of shape (B, T*D).
+   *   -> 5: Output attention of attention layer, presumably of shape (B, H*T*T) like dropout_mask_attention (TODO confirm).
+   *   -> 6: Output of residual pass 1 (input to layer norm 1), of shape (B, T*D).
+   *   -> 7: Output of layer norm 1, of shape (B, T*D).
+   *   -> 8: Output of intermediate linear layer, of shape (B, T*I).
+   *   -> 9: Output of activation layer, of shape (B, T*I).
+   *   -> 10: Output of residual pass 2 (input to layer norm 2), of shape (B, T*D).
+   * - states: Hidden states, of shape (B, T*D).
+   * - H: Head count.
+   * - T: Sequence length.
+   * - d: Embedding length of single token per head with d*H = D.
+   * - I: Intermediate embedding length.
+   * - W_Q: Weights for linear query layer, of shape (D, D).
+   * - b_Q: Biases for linear query layer, of shape (1, D).
+   * - W_K: Weights for linear key layer, of shape (D, D).
+   * - b_K: Biases for linear key layer, of shape (1, D).
+   * - W_V: Weights for linear value layer, of shape (D, D).
+   * - b_V: Biases for linear value layer, of shape (1, D).
+   * - W_context: Weights for linear output layer on context, of shape (D, D).
+   * - b_context: Biases for linear output layer on context, of shape (1, D).
+   * - W_intermediate: Weights for intermediate linear layer, of shape (D, I).
+   * - b_intermediate: Biases for intermediate linear layer, of shape (1, I).
+   * - W_out: Weights for last linear output layer, of shape (I, D).
+   * - b_out: Biases for last linear output layer, of shape (1, D).
+   * - dropout_p_attention: Probability for dropout on attention.
+   * - dropout_p_output: Probability for dropout on output; both output dropout backward steps are skipped unless it is > 0.
+   * - epsilon_ln: Epsilon value for layer norm.
+   * - gamma_ln1: Gamma params for layer norm 1, of shape (1, D).
+   * - beta_ln1: Beta params for layer norm 1, of shape (1, D).
+   * - gamma_ln2: Gamma params for layer norm 2, of shape (1, D).
+   * - beta_ln2: Beta params for layer norm 2, of shape (1, D).
+   * - activation: String specifying type of activation to use.
+   *     Can be tanh or gelu; any other value passes the gradient through the activation step unchanged.
+   *
+   * Outputs:
+   * - din_states: Gradients w.r.t. hidden input states, of shape (B, T*D).
+   * - dW_Q, db_Q: Gradients w.r.t. weights (D, D) and biases (1, D) of linear query layer.
+   * - dW_K, db_K: Gradients w.r.t. weights (D, D) and biases (1, D) of linear key layer.
+   * - dW_V, db_V: Gradients w.r.t. weights (D, D) and biases (1, D) of linear value layer.
+   * - dW_context, db_context: Gradients w.r.t. weights (D, D) and biases (1, D) of linear output layer on context.
+   * - dW_intermediate: Gradients w.r.t. weights for intermediate linear layer, of shape (D, I).
+   * - db_intermediate: Gradients w.r.t. biases for intermediate linear layer, of shape (1, I).
+   * - dW_out: Gradients w.r.t. weights for last linear output layer, of shape (I, D).
+   * - db_out: Gradients w.r.t. biases for last linear output layer, of shape (1, D).
+   * - dgamma_ln1: Gradients w.r.t. gamma params for layer norm 1, of shape (1, D).
+   * - dbeta_ln1: Gradients w.r.t. beta params for layer norm 1, of shape (1, D).
+   * - dgamma_ln2: Gradients w.r.t. gamma params for layer norm 2, of shape (1, D).
+   * - dbeta_ln2: Gradients w.r.t. beta params for layer norm 2, of shape (1, D).
+   */
+  # Embedding dim
+  D = d * H
+
+  # Layer norm 2 for each token
+  [dout_states, dgamma_ln2, dbeta_ln2] = layer_norm_backward(
+    dout_states, cache_mean_ln2, cache_var_ln2, cache_norm_ln2, 
as.matrix(outputs[10]), gamma_ln2, beta_ln2, epsilon_ln, T, D)
+  # Keep the gradient of the identity branch of residual pass 2
+  dout_states_identity_2 = dout_states
+  # Dropout on output 2
+  if (dropout_p_output > 0.0) {
+    dout_states = dropout::backward(dout_states, matrix(0, 1, 1), 
dropout_p_output, dropout_mask_output_2)
+  }
+  # Final linear output layer
+  [dout_states, dW_out, db_out] = linear_tensor_backward(dout_states, 
as.matrix(outputs[9]), W_out, b_out, T, D, I)
+
+  # Activation (any other activation value leaves the gradient unchanged)
+  if (activation == "gelu") {
+    dout_states = gelu::backward(dout_states, as.matrix(outputs[8]))
+  } else if (activation == "tanh") {
+    dout_states = tanh::backward(dout_states, as.matrix(outputs[8]))
+  }
+  # Linear layer of intermediate part
+  [dout_states, dW_intermediate, db_intermediate] = 
linear_tensor_backward(dout_states, as.matrix(outputs[7]), W_intermediate,
+    b_intermediate, T, I, D)
+  # Residual pass 2: add the gradient from the identity branch
+  dout_states = dout_states + dout_states_identity_2
+
+  # Layer norm 1 for each token
+  [dout_states, dgamma_ln1, dbeta_ln1] = layer_norm_backward(
+    dout_states, cache_mean_ln1, cache_var_ln1, cache_norm_ln1, 
as.matrix(outputs[6]), gamma_ln1, beta_ln1, epsilon_ln, T, D)
+  # Keep the gradient of the identity branch of residual pass 1
+  dout_states_identity_1 = dout_states
+
+  # Dropout on output 1
+  if (dropout_p_output > 0.0) {
+    dout_states = dropout::backward(dout_states, matrix(0, 1, 1), 
dropout_p_output, dropout_mask_output_1)
+  }
+  # Linear layer on attention output (output layer)
+  [dcontext, dW_context, db_context] = linear_tensor_backward(dout_states, 
as.matrix(outputs[4]), W_context, b_context, T, D, D)
+
+  # Multi-head self attention
+  [dQ, dK, dV] = attention::backward(dcontext, dropout_mask_attention, 
as.matrix(outputs[5]), as.matrix(outputs[1]),
+    as.matrix(outputs[2]), as.matrix(outputs[3]), H, T, d, dropout_p_attention)
+
+  # Linear layers for Q, K, V
+  [dstates_Q, dW_Q, db_Q] = linear_tensor_backward(dQ, states, W_Q, b_Q, T, D, 
D)
+  [dstates_K, dW_K, db_K] = linear_tensor_backward(dK, states, W_K, b_K, T, D, 
D)
+  [dstates_V, dW_V, db_V] = linear_tensor_backward(dV, states, W_V, b_V, T, D, 
D)
+  # Sum the Q, K and V paths plus the identity branch of residual pass 1
+  din_states = dstates_Q + dstates_K + dstates_V + dout_states_identity_1
+}
diff --git 
a/src/test/java/org/apache/sysds/test/applications/nn/transformers/BertLayerTest.java
 
b/src/test/java/org/apache/sysds/test/applications/nn/transformers/BertLayerTest.java
index 583288dfbc..5d25c15928 100644
--- 
a/src/test/java/org/apache/sysds/test/applications/nn/transformers/BertLayerTest.java
+++ 
b/src/test/java/org/apache/sysds/test/applications/nn/transformers/BertLayerTest.java
@@ -26,6 +26,7 @@ import org.junit.Test;
 
 public class BertLayerTest extends AutomatedTestBase{
        private static final String TEST_NAME_FORWARD = "bert_layer_forward";
+       private static final String TEST_NAME_BACKWARD = "bert_layer_backward";
        private static final String TEST_DIR = "applications/nn/component/";
        private static final String RESOURCE_DIR = 
"src/test/resources/component/transformers/bert_layer/";
 
@@ -33,6 +34,7 @@ public class BertLayerTest extends AutomatedTestBase{
        public void setUp() {
                TestUtils.clearAssertionInformation();
                addTestConfiguration(TEST_NAME_FORWARD, new 
TestConfiguration(TEST_DIR, TEST_NAME_FORWARD));
+               addTestConfiguration(TEST_NAME_BACKWARD, new 
TestConfiguration(TEST_DIR, TEST_NAME_BACKWARD));
        }
 
        @Test
@@ -47,6 +49,18 @@ public class BertLayerTest extends AutomatedTestBase{
             1e-5, true);
        }
 
+       @Test
+       public void testBertLayerBackwardNormalGelu() { // backward pass, all dimensions distinct, GELU activation
+               runBertLayerTest("test3", 2, 3, 8, 2, 4, 5, "gelu", 0, 
TEST_NAME_BACKWARD, 
+            1e-4, false); // batch=2, seq=3, embed=8, heads=2, perHead=4, intermediate=5, debug off, tolerance 1e-4
+       }
+
+       @Test
+       public void testBertLayerBackwardSameDimsTanh() { // backward pass, batch/seq/embed/intermediate all equal, tanh activation
+               runBertLayerTest("test4", 4, 4, 4, 2, 2, 4, "tanh", 0, 
TEST_NAME_BACKWARD, 
+            1e-4, false); // batch=4, seq=4, embed=4, heads=2, perHead=2, intermediate=4, debug off, tolerance 1e-4
+       }
+
        private void runBertLayerTest(String testSuffix, int batchSize, int 
seqLength, int embeddingDim, int numHeads,
             int perHeadEmbeddingDim, int intermediateEmbeddingDim, String 
activation, int debug, String testname, double precision, 
             boolean isForward) {
@@ -88,6 +102,68 @@ public class BertLayerTest extends AutomatedTestBase{
                                        output("states_error"),
                                        output("attention_error"), 
                                };
+                       } else {
+                               programArgs = new String[] { 
+                                       "-stats", "-args",
+                                       String.valueOf(debug), 
String.valueOf(batchSize), 
+                    String.valueOf(seqLength), String.valueOf(embeddingDim),
+                                       String.valueOf(numHeads), 
String.valueOf(perHeadEmbeddingDim),
+                    String.valueOf(intermediateEmbeddingDim), activation,
+                                       RESOURCE_DIR + "input_states_" + 
testSuffix + ".csv",
+                                       RESOURCE_DIR + "input_W_Q_" + 
testSuffix + ".csv",
+                                       RESOURCE_DIR + "input_b_Q_" + 
testSuffix + ".csv",
+                    RESOURCE_DIR + "input_W_K_" + testSuffix + ".csv",
+                                       RESOURCE_DIR + "input_b_K_" + 
testSuffix + ".csv",
+                    RESOURCE_DIR + "input_W_V_" + testSuffix + ".csv",
+                                       RESOURCE_DIR + "input_b_V_" + 
testSuffix + ".csv",
+                    RESOURCE_DIR + "input_W_context_" + testSuffix + ".csv",
+                                       RESOURCE_DIR + "input_b_context_" + 
testSuffix + ".csv",
+                    RESOURCE_DIR + "input_W_intermediate_" + testSuffix + 
".csv",
+                                       RESOURCE_DIR + "input_b_intermediate_" 
+ testSuffix + ".csv",
+                    RESOURCE_DIR + "input_W_out_" + testSuffix + ".csv",
+                                       RESOURCE_DIR + "input_b_out_" + 
testSuffix + ".csv",
+                    RESOURCE_DIR + "input_gamma_ln1_" + testSuffix + ".csv",
+                                       RESOURCE_DIR + "input_beta_ln1_" + 
testSuffix + ".csv",
+                    RESOURCE_DIR + "input_gamma_ln2_" + testSuffix + ".csv",
+                                       RESOURCE_DIR + "input_beta_ln2_" + 
testSuffix + ".csv",
+                                       RESOURCE_DIR + "output_states_" + 
testSuffix + ".csv",
+                                       RESOURCE_DIR + "output_attention_" + 
testSuffix + ".csv",
+                                       RESOURCE_DIR + "input_dstates_" + 
testSuffix + ".csv",
+                                       RESOURCE_DIR + "output_dstates_" + 
testSuffix + ".csv",
+                                       RESOURCE_DIR + "output_dW_Q_" + 
testSuffix + ".csv",
+                                       RESOURCE_DIR + "output_db_Q_" + 
testSuffix + ".csv",
+                    RESOURCE_DIR + "output_dW_K_" + testSuffix + ".csv",
+                                       RESOURCE_DIR + "output_db_K_" + 
testSuffix + ".csv",
+                    RESOURCE_DIR + "output_dW_V_" + testSuffix + ".csv",
+                                       RESOURCE_DIR + "output_db_V_" + 
testSuffix + ".csv",
+                    RESOURCE_DIR + "output_dW_context_" + testSuffix + ".csv",
+                                       RESOURCE_DIR + "output_db_context_" + 
testSuffix + ".csv",
+                    RESOURCE_DIR + "output_dW_intermediate_" + testSuffix + 
".csv",
+                                       RESOURCE_DIR + 
"output_db_intermediate_" + testSuffix + ".csv",
+                    RESOURCE_DIR + "output_dW_out_" + testSuffix + ".csv",
+                                       RESOURCE_DIR + "output_db_out_" + 
testSuffix + ".csv",
+                    RESOURCE_DIR + "output_dgamma_ln1_" + testSuffix + ".csv",
+                                       RESOURCE_DIR + "output_dbeta_ln1_" + 
testSuffix + ".csv",
+                    RESOURCE_DIR + "output_dgamma_ln2_" + testSuffix + ".csv",
+                                       RESOURCE_DIR + "output_dbeta_ln2_" + 
testSuffix + ".csv",
+                                       output("din_error"),
+                                       output("dW_Q_error"),
+                                       output("db_Q_error"),
+                                       output("dW_K_error"),
+                                       output("db_K_error"),
+                                       output("dW_V_error"),
+                                       output("db_V_error"),
+                                       output("dW_context_error"),
+                                       output("db_context_error"),
+                                       output("dW_intermediate_error"),
+                                       output("db_intermediate_error"),
+                                       output("dW_out_error"),
+                                       output("db_out_error"),
+                                       output("dgamma_ln1_error"),
+                                       output("dbeta_ln1_error"),
+                                       output("dgamma_ln2_error"),
+                                       output("dbeta_ln2_error"),
+                               };
                        }
 
                        // Run the test
@@ -100,12 +176,40 @@ public class BertLayerTest extends AutomatedTestBase{
                                double attentionMaxError = (Double) 
readDMLScalarFromOutputDir("attention_error").values().toArray()[0];
                                assert attentionMaxError < precision;
                        } else {
-                               double dqueryMaxError = (Double) 
readDMLScalarFromOutputDir("dquery_error").values().toArray()[0];
-                               assert dqueryMaxError < precision;
-                               double dkeyMaxError = (Double) 
readDMLScalarFromOutputDir("dkey_error").values().toArray()[0];
-                               assert dkeyMaxError < precision;
-                               double dvalueMaxError = (Double) 
readDMLScalarFromOutputDir("dvalue_error").values().toArray()[0];
-                               assert dvalueMaxError < precision;
+                               double dinMaxError = (Double) 
readDMLScalarFromOutputDir("din_error").values().toArray()[0];
+                               assert dinMaxError < precision;
+                               double dWQMaxError = (Double) 
readDMLScalarFromOutputDir("dW_Q_error").values().toArray()[0];
+                               assert dWQMaxError < precision;
+                               double dbQMaxError = (Double) 
readDMLScalarFromOutputDir("db_Q_error").values().toArray()[0];
+                               assert dbQMaxError < precision;
+                               double dWKMaxError = (Double) 
readDMLScalarFromOutputDir("dW_K_error").values().toArray()[0];
+                               assert dWKMaxError < precision;
+                               double dbKMaxError = (Double) 
readDMLScalarFromOutputDir("db_K_error").values().toArray()[0];
+                               assert dbKMaxError < precision;
+                               double dWVMaxError = (Double) 
readDMLScalarFromOutputDir("dW_V_error").values().toArray()[0];
+                               assert dWVMaxError < precision;
+                               double dbVMaxError = (Double) 
readDMLScalarFromOutputDir("db_V_error").values().toArray()[0];
+                               assert dbVMaxError < precision;
+                               double dWContextMaxError = (Double) 
readDMLScalarFromOutputDir("dW_context_error").values().toArray()[0];
+                               assert dWContextMaxError < precision;
+                               double dbContextMaxError = (Double) 
readDMLScalarFromOutputDir("db_context_error").values().toArray()[0];
+                               assert dbContextMaxError < precision;
+                               double dWIntermediateMaxError = (Double) 
readDMLScalarFromOutputDir("dW_intermediate_error").values().toArray()[0];
+                               assert dWIntermediateMaxError < precision;
+                               double dbIntermediateMaxError = (Double) 
readDMLScalarFromOutputDir("db_intermediate_error").values().toArray()[0];
+                               assert dbIntermediateMaxError < precision;
+                               double dWOutMaxError = (Double) 
readDMLScalarFromOutputDir("dW_out_error").values().toArray()[0];
+                               assert dWOutMaxError < precision;
+                               double dbOutMaxError = (Double) 
readDMLScalarFromOutputDir("db_out_error").values().toArray()[0];
+                               assert dbOutMaxError < precision;
+                               double dgammaLn1MaxError = (Double) 
readDMLScalarFromOutputDir("dgamma_ln1_error").values().toArray()[0];
+                               assert dgammaLn1MaxError < precision;
+                               double dbetaLn1MaxError = (Double) 
readDMLScalarFromOutputDir("dbeta_ln1_error").values().toArray()[0];
+                               assert dbetaLn1MaxError < precision;
+                               double dgammaLn2MaxError = (Double) 
readDMLScalarFromOutputDir("dgamma_ln2_error").values().toArray()[0];
+                               assert dgammaLn2MaxError < precision;
+                               double dbetaLn2MaxError = (Double) 
readDMLScalarFromOutputDir("dbeta_ln2_error").values().toArray()[0];
+                               assert dbetaLn2MaxError < precision;
                        }
                } catch (Throwable ex) {
                        ex.printStackTrace(System.out); // Log or debug all 
exceptions or errors
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_W_K_test3.csv 
b/src/test/resources/component/transformers/bert_layer/input_W_K_test3.csv
new file mode 100644
index 0000000000..3430a1f849
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_W_K_test3.csv
@@ -0,0 +1,8 @@
+0.138744,-0.123685,0.009826,-0.317605,-0.071714,-0.068100,-0.318219,-0.040798
+0.021178,-0.289780,-0.030501,-0.167612,0.193891,-0.069418,-0.023860,-0.157829
+-0.172509,-0.075206,0.071553,0.240736,0.191146,-0.317261,0.310922,0.282720
+0.167298,0.075574,0.224803,-0.002292,-0.340978,-0.305271,-0.144212,-0.285705
+-0.339146,-0.230328,0.334902,-0.175732,0.220540,-0.055324,0.319260,0.037938
+-0.209553,-0.018144,0.224526,-0.270932,-0.276659,0.004572,0.128041,-0.074023
+-0.088505,0.253091,0.335668,-0.330874,-0.074745,-0.160610,-0.319068,0.252477
+-0.172221,-0.036345,-0.025570,-0.298402,-0.143356,0.133183,0.223692,0.098693
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_W_K_test4.csv 
b/src/test/resources/component/transformers/bert_layer/input_W_K_test4.csv
new file mode 100644
index 0000000000..951091fda8
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_W_K_test4.csv
@@ -0,0 +1,4 @@
+0.093586,0.192278,0.357936,0.249658
+-0.084230,-0.296152,0.186956,0.104651
+-0.082281,0.183296,-0.494868,-0.390042
+-0.228878,0.252854,-0.324348,-0.287910
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_W_Q_test3.csv 
b/src/test/resources/component/transformers/bert_layer/input_W_Q_test3.csv
new file mode 100644
index 0000000000..a2c9ca4c1a
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_W_Q_test3.csv
@@ -0,0 +1,8 @@
+-0.156799,0.065883,0.194091,0.138950,-0.329496,0.135961,0.176535,-0.336794
+-0.012757,-0.274112,-0.044625,0.292936,0.314130,-0.209411,0.073999,-0.006355
+0.226119,-0.245043,0.013499,0.307665,0.268828,0.129610,-0.275801,-0.266247
+0.351479,-0.182640,0.081920,0.311960,-0.352679,0.178795,-0.203583,-0.272716
+0.140319,0.159973,0.219336,0.070362,0.066175,0.253099,0.332605,-0.019481
+0.047763,0.142185,0.339480,-0.307444,-0.059560,0.132198,0.238231,0.053084
+0.237053,-0.209428,-0.272457,0.032524,-0.058181,-0.349924,-0.154158,-0.144791
+-0.208173,0.106811,-0.129567,-0.221185,-0.161841,-0.229349,-0.088984,0.209791
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_W_Q_test4.csv 
b/src/test/resources/component/transformers/bert_layer/input_W_Q_test4.csv
new file mode 100644
index 0000000000..f5fab7922e
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_W_Q_test4.csv
@@ -0,0 +1,4 @@
+0.274486,0.310188,0.196505,0.099507
+-0.063109,0.480097,0.414275,-0.434791
+0.019091,-0.385312,0.435104,0.045996
+0.115852,-0.183235,0.441178,-0.312803
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_W_V_test3.csv 
b/src/test/resources/component/transformers/bert_layer/input_W_V_test3.csv
new file mode 100644
index 0000000000..0910e03418
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_W_V_test3.csv
@@ -0,0 +1,8 @@
+-0.249389,0.014873,0.127848,-0.276551,-0.306393,0.044137,0.116741,0.004873
+0.015768,0.227909,0.044769,-0.185308,0.175143,0.316675,0.265246,-0.060110
+-0.249232,-0.267258,-0.002632,0.285492,-0.251829,0.216273,-0.113814,-0.186207
+-0.194626,-0.242719,-0.069891,-0.286925,-0.100361,-0.223521,0.000566,0.046730
+-0.206017,-0.205295,0.044359,-0.025387,-0.118623,0.158570,0.182018,0.292360
+0.120825,0.247464,-0.080732,0.349749,-0.052357,-0.249925,-0.341919,-0.103351
+-0.210687,-0.127090,-0.002484,0.127717,0.003867,-0.149845,0.255612,-0.209903
+-0.007714,0.298218,0.045111,0.010010,0.291613,0.103988,-0.292361,-0.130758
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_W_V_test4.csv 
b/src/test/resources/component/transformers/bert_layer/input_W_V_test4.csv
new file mode 100644
index 0000000000..33b3ab1341
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_W_V_test4.csv
@@ -0,0 +1,4 @@
+-0.476299,-0.027550,-0.304270,-0.124442
+-0.008987,0.075073,0.453685,0.022561
+-0.376529,-0.204765,0.342650,0.072951
+-0.385678,0.296689,-0.421641,0.118587
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_W_context_test3.csv
 
b/src/test/resources/component/transformers/bert_layer/input_W_context_test3.csv
new file mode 100644
index 0000000000..9b22b81c74
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/input_W_context_test3.csv
@@ -0,0 +1,8 @@
+0.128242,-0.334274,0.011264,-0.080849,0.244498,0.069431,0.122792,0.029533
+-0.118073,0.076385,-0.237900,-0.015723,-0.263194,-0.262510,-0.129004,0.044147
+-0.098800,-0.198408,-0.285785,-0.215330,0.144839,0.058866,0.134202,-0.277945
+0.104450,-0.315220,0.281811,0.119572,-0.118884,0.150589,0.235453,0.027785
+0.290633,0.310023,0.057572,0.111782,-0.170578,0.139947,-0.184608,0.244825
+0.096128,-0.229602,0.293317,-0.007293,0.063514,-0.044505,0.003487,0.318592
+-0.167282,-0.040221,-0.118525,-0.079515,-0.183656,-0.289839,0.146194,0.207801
+-0.166197,0.101291,0.104141,-0.217941,0.081460,-0.054502,0.027711,0.047377
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_W_context_test4.csv
 
b/src/test/resources/component/transformers/bert_layer/input_W_context_test4.csv
new file mode 100644
index 0000000000..df3243ba95
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/input_W_context_test4.csv
@@ -0,0 +1,4 @@
+-0.479624,-0.174917,-0.325733,0.013896
+-0.296353,-0.409811,-0.025660,-0.043134
+-0.125165,-0.106358,0.357925,0.101191
+-0.243557,0.106878,-0.051400,0.317920
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_W_intermediate_test3.csv
 
b/src/test/resources/component/transformers/bert_layer/input_W_intermediate_test3.csv
new file mode 100644
index 0000000000..7232251997
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/input_W_intermediate_test3.csv
@@ -0,0 +1,8 @@
+0.266992,-0.212761,-0.063619,-0.239211,-0.241276
+0.353417,0.317851,0.137643,-0.348794,-0.057128
+0.309145,0.170822,0.062799,-0.283923,-0.229609
+0.273917,0.192735,0.150426,0.279120,0.245506
+-0.081055,-0.221599,-0.120147,0.191285,-0.267289
+-0.124077,0.101434,0.172383,0.331709,-0.172499
+0.290285,-0.123943,-0.246947,0.283239,-0.341565
+0.198124,0.276251,0.079852,-0.315739,-0.200728
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_W_intermediate_test4.csv
 
b/src/test/resources/component/transformers/bert_layer/input_W_intermediate_test4.csv
new file mode 100644
index 0000000000..cd86f5ffa1
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/input_W_intermediate_test4.csv
@@ -0,0 +1,4 @@
+-0.449161,-0.248523,-0.101418,0.311891
+-0.237039,-0.383156,0.274203,-0.391255
+0.340453,-0.467926,0.270321,-0.105705
+-0.003241,-0.422004,-0.482216,-0.202736
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_W_out_test3.csv 
b/src/test/resources/component/transformers/bert_layer/input_W_out_test3.csv
new file mode 100644
index 0000000000..c207020fd7
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_W_out_test3.csv
@@ -0,0 +1,5 @@
+-0.114515,0.231953,-0.106802,0.317938,0.033243,0.428020,-0.070085,-0.121394
+0.196775,0.278010,-0.252099,-0.050899,-0.211145,0.339602,0.379730,-0.086280
+0.398436,-0.156521,-0.250939,-0.258796,0.411011,-0.163005,0.018478,0.042799
+0.147947,0.214585,-0.344122,0.345654,0.182884,0.251403,-0.316278,0.413596
+0.447069,0.051388,0.300228,0.285988,-0.339500,-0.254102,-0.149474,0.023944
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_W_out_test4.csv 
b/src/test/resources/component/transformers/bert_layer/input_W_out_test4.csv
new file mode 100644
index 0000000000..4b556cc6dd
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_W_out_test4.csv
@@ -0,0 +1,4 @@
+-0.078240,-0.450029,0.451502,-0.057697
+0.006466,-0.033744,0.181077,-0.223203
+-0.227137,0.439710,-0.451230,0.399827
+0.188350,-0.203946,0.316349,-0.404048
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_b_K_test3.csv 
b/src/test/resources/component/transformers/bert_layer/input_b_K_test3.csv
new file mode 100644
index 0000000000..3737690e1b
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_b_K_test3.csv
@@ -0,0 +1,8 @@
+0.169884
+0.124861
+-0.085021
+-0.074354
+-0.291357
+0.191571
+0.280714
+0.241910
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_b_K_test4.csv 
b/src/test/resources/component/transformers/bert_layer/input_b_K_test4.csv
new file mode 100644
index 0000000000..2f4932cf80
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_b_K_test4.csv
@@ -0,0 +1,4 @@
+0.470375
+0.336909
+-0.218013
+-0.125842
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_b_Q_test3.csv 
b/src/test/resources/component/transformers/bert_layer/input_b_Q_test3.csv
new file mode 100644
index 0000000000..1787a85fe5
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_b_Q_test3.csv
@@ -0,0 +1,8 @@
+-0.215151
+0.320804
+0.242290
+-0.298146
+-0.087994
+0.015953
+0.051584
+0.083854
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_b_Q_test4.csv 
b/src/test/resources/component/transformers/bert_layer/input_b_Q_test4.csv
new file mode 100644
index 0000000000..e1372696e9
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_b_Q_test4.csv
@@ -0,0 +1,4 @@
+-0.465977
+0.444246
+0.380180
+-0.498764
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_b_V_test3.csv 
b/src/test/resources/component/transformers/bert_layer/input_b_V_test3.csv
new file mode 100644
index 0000000000..0ed151a540
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_b_V_test3.csv
@@ -0,0 +1,8 @@
+-0.350424
+0.159592
+-0.169799
+-0.235940
+-0.203683
+0.203278
+0.187233
+0.271360
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_b_V_test4.csv 
b/src/test/resources/component/transformers/bert_layer/input_b_V_test4.csv
new file mode 100644
index 0000000000..3f2bdf9561
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_b_V_test4.csv
@@ -0,0 +1,4 @@
+0.196214
+0.029950
+-0.243964
+0.236594
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_b_context_test3.csv
 
b/src/test/resources/component/transformers/bert_layer/input_b_context_test3.csv
new file mode 100644
index 0000000000..9712d704da
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/input_b_context_test3.csv
@@ -0,0 +1,8 @@
+0.165114
+-0.171997
+-0.292986
+-0.304028
+0.352708
+0.224432
+-0.244388
+0.138325
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_b_context_test4.csv
 
b/src/test/resources/component/transformers/bert_layer/input_b_context_test4.csv
new file mode 100644
index 0000000000..baefdcee8b
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/input_b_context_test4.csv
@@ -0,0 +1,4 @@
+0.473623
+0.317528
+0.474707
+-0.036161
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_b_intermediate_test3.csv
 
b/src/test/resources/component/transformers/bert_layer/input_b_intermediate_test3.csv
new file mode 100644
index 0000000000..531a6d47bb
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/input_b_intermediate_test3.csv
@@ -0,0 +1,5 @@
+0.290790
+0.289475
+0.253087
+0.272988
+0.314375
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_b_intermediate_test4.csv
 
b/src/test/resources/component/transformers/bert_layer/input_b_intermediate_test4.csv
new file mode 100644
index 0000000000..24a3a16261
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/input_b_intermediate_test4.csv
@@ -0,0 +1,4 @@
+-0.096308
+-0.098171
+-0.448675
+-0.431719
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_b_out_test3.csv 
b/src/test/resources/component/transformers/bert_layer/input_b_out_test3.csv
new file mode 100644
index 0000000000..a75d6bf069
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_b_out_test3.csv
@@ -0,0 +1,8 @@
+-0.276128
+0.022922
+0.214433
+0.221836
+-0.408726
+-0.080025
+-0.332344
+-0.190820
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_b_out_test4.csv 
b/src/test/resources/component/transformers/bert_layer/input_b_out_test4.csv
new file mode 100644
index 0000000000..439f3dfd0f
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/input_b_out_test4.csv
@@ -0,0 +1,4 @@
+0.053652
+-0.104684
+0.357056
+0.139572
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_beta_ln1_test3.csv 
b/src/test/resources/component/transformers/bert_layer/input_beta_ln1_test3.csv
new file mode 100644
index 0000000000..69fba5d887
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/input_beta_ln1_test3.csv
@@ -0,0 +1,8 @@
+0.000000
+0.000000
+0.000000
+0.000000
+0.000000
+0.000000
+0.000000
+0.000000
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_beta_ln1_test4.csv 
b/src/test/resources/component/transformers/bert_layer/input_beta_ln1_test4.csv
new file mode 100644
index 0000000000..cec7b5178a
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/input_beta_ln1_test4.csv
@@ -0,0 +1,4 @@
+0.000000
+0.000000
+0.000000
+0.000000
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_beta_ln2_test3.csv 
b/src/test/resources/component/transformers/bert_layer/input_beta_ln2_test3.csv
new file mode 100644
index 0000000000..69fba5d887
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/input_beta_ln2_test3.csv
@@ -0,0 +1,8 @@
+0.000000
+0.000000
+0.000000
+0.000000
+0.000000
+0.000000
+0.000000
+0.000000
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_beta_ln2_test4.csv 
b/src/test/resources/component/transformers/bert_layer/input_beta_ln2_test4.csv
new file mode 100644
index 0000000000..cec7b5178a
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/input_beta_ln2_test4.csv
@@ -0,0 +1,4 @@
+0.000000
+0.000000
+0.000000
+0.000000
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_dstates_test3.csv 
b/src/test/resources/component/transformers/bert_layer/input_dstates_test3.csv
new file mode 100644
index 0000000000..30508d5072
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/input_dstates_test3.csv
@@ -0,0 +1,2 @@
+0.680147,0.144935,0.685863,0.924389,0.532794,0.166756,0.320858,0.609182,0.118841,0.748405,0.046065,0.019353,0.014170,0.398568,0.836216,0.026761,0.915594,0.299989,0.646442,0.522801,0.049140,0.914665,0.769222,0.996998
+0.752606,0.169966,0.917292,0.526872,0.737108,0.099085,0.356187,0.009061,0.305254,0.607866,0.107419,0.659382,0.768403,0.569655,0.165458,0.112340,0.345742,0.719479,0.993198,0.787515,0.443695,0.675308,0.009469,0.072949
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_dstates_test4.csv 
b/src/test/resources/component/transformers/bert_layer/input_dstates_test4.csv
new file mode 100644
index 0000000000..a93ead7d00
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/input_dstates_test4.csv
@@ -0,0 +1,4 @@
+0.740253,0.676579,0.379763,0.394847,0.087959,0.770922,0.896989,0.842112,0.147311,0.522300,0.147533,0.224758,0.208647,0.670873,0.202043,0.489091
+0.521034,0.822312,0.122040,0.156744,0.209669,0.849967,0.320267,0.921744,0.680804,0.563313,0.496278,0.401159,0.562733,0.385828,0.496487,0.563797
+0.108897,0.237934,0.903746,0.094227,0.464097,0.994619,0.680619,0.514157,0.066695,0.747689,0.143860,0.358068,0.332242,0.425956,0.505469,0.912404
+0.562419,0.947846,0.805856,0.183893,0.724252,0.146552,0.288087,0.647061,0.665096,0.875114,0.339042,0.500800,0.757412,0.016454,0.861490,0.086539
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_gamma_ln1_test3.csv
 
b/src/test/resources/component/transformers/bert_layer/input_gamma_ln1_test3.csv
new file mode 100644
index 0000000000..8d61340f4e
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/input_gamma_ln1_test3.csv
@@ -0,0 +1,8 @@
+1.000000
+1.000000
+1.000000
+1.000000
+1.000000
+1.000000
+1.000000
+1.000000
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_gamma_ln1_test4.csv
 
b/src/test/resources/component/transformers/bert_layer/input_gamma_ln1_test4.csv
new file mode 100644
index 0000000000..35bd33c42c
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/input_gamma_ln1_test4.csv
@@ -0,0 +1,4 @@
+1.000000
+1.000000
+1.000000
+1.000000
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_gamma_ln2_test3.csv
 
b/src/test/resources/component/transformers/bert_layer/input_gamma_ln2_test3.csv
new file mode 100644
index 0000000000..8d61340f4e
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/input_gamma_ln2_test3.csv
@@ -0,0 +1,8 @@
+1.000000
+1.000000
+1.000000
+1.000000
+1.000000
+1.000000
+1.000000
+1.000000
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_gamma_ln2_test4.csv
 
b/src/test/resources/component/transformers/bert_layer/input_gamma_ln2_test4.csv
new file mode 100644
index 0000000000..35bd33c42c
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/input_gamma_ln2_test4.csv
@@ -0,0 +1,4 @@
+1.000000
+1.000000
+1.000000
+1.000000
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_states_test3.csv 
b/src/test/resources/component/transformers/bert_layer/input_states_test3.csv
new file mode 100644
index 0000000000..d797a800e7
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/input_states_test3.csv
@@ -0,0 +1,2 @@
+0.496257,0.768222,0.088477,0.132030,0.307423,0.634079,0.490093,0.896445,0.455628,0.632306,0.348893,0.401717,0.022326,0.168859,0.293888,0.518522,0.697668,0.800011,0.161029,0.282269,0.681609,0.915194,0.397100,0.874156
+0.419408,0.552907,0.952738,0.036165,0.185231,0.373417,0.305100,0.932000,0.175910,0.269834,0.150680,0.031720,0.208130,0.929799,0.723109,0.742336,0.526296,0.243658,0.584592,0.033153,0.138717,0.242235,0.815469,0.793161
diff --git 
a/src/test/resources/component/transformers/bert_layer/input_states_test4.csv 
b/src/test/resources/component/transformers/bert_layer/input_states_test4.csv
new file mode 100644
index 0000000000..a15792d51b
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/input_states_test4.csv
@@ -0,0 +1,4 @@
+0.496257,0.768222,0.088477,0.132030,0.307423,0.634079,0.490093,0.896445,0.455628,0.632306,0.348893,0.401717,0.022326,0.168859,0.293888,0.518522
+0.697668,0.800011,0.161029,0.282269,0.681609,0.915194,0.397100,0.874156,0.419408,0.552907,0.952738,0.036165,0.185231,0.373417,0.305100,0.932000
+0.175910,0.269834,0.150680,0.031720,0.208130,0.929799,0.723109,0.742336,0.526296,0.243658,0.584592,0.033153,0.138717,0.242235,0.815469,0.793161
+0.278252,0.481959,0.819780,0.997067,0.698441,0.567546,0.835243,0.205599,0.593172,0.112347,0.153457,0.241708,0.726237,0.701080,0.203824,0.651054
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_attention_test3.csv
 
b/src/test/resources/component/transformers/bert_layer/output_attention_test3.csv
new file mode 100644
index 0000000000..1506f20100
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_attention_test3.csv
@@ -0,0 +1,2 @@
+0.339324,0.301625,0.359050,0.333347,0.324201,0.342453,0.337702,0.301196,0.361102,0.330164,0.329774,0.340062,0.330691,0.334149,0.335160,0.332653,0.322728,0.344619
+0.324605,0.342803,0.332592,0.304915,0.359697,0.335388,0.323544,0.342926,0.333530,0.332023,0.333600,0.334377,0.336369,0.333007,0.330623,0.327169,0.338724,0.334107
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_attention_test4.csv
 
b/src/test/resources/component/transformers/bert_layer/output_attention_test4.csv
new file mode 100644
index 0000000000..afe547864f
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_attention_test4.csv
@@ -0,0 +1,4 @@
+0.219257,0.275532,0.245886,0.259325,0.231218,0.266120,0.247480,0.255182,0.225803,0.270295,0.246807,0.257096,0.234147,0.264871,0.247535,0.253447,0.264908,0.242993,0.253594,0.238505,0.281723,0.232930,0.256558,0.228788,0.274176,0.237156,0.255232,0.233437,0.264776,0.242755,0.253492,0.238976
+0.226483,0.259603,0.243000,0.270914,0.230393,0.257917,0.244387,0.267304,0.234738,0.257874,0.243686,0.263703,0.236825,0.256908,0.244453,0.261814,0.272171,0.255864,0.237683,0.234281,0.286323,0.257591,0.231062,0.225024,0.286075,0.255542,0.232694,0.225689,0.273593,0.256246,0.236871,0.233289
+0.227024,0.255324,0.241590,0.276062,0.228867,0.254117,0.242903,0.274112,0.230974,0.253474,0.243860,0.271692,0.239153,0.257003,0.243018,0.260827,0.253968,0.250648,0.252840,0.242544,0.274183,0.243220,0.265539,0.217059,0.268435,0.243923,0.261816,0.225826,0.274571,0.242021,0.265637,0.217770
+0.263989,0.246212,0.242346,0.247454,0.270082,0.244850,0.241179,0.243890,0.270460,0.244742,0.240955,0.243842,0.273977,0.244024,0.240818,0.241181,0.204528,0.248260,0.269666,0.277547,0.205892,0.248446,0.269942,0.275720,0.230842,0.249643,0.257574,0.261941,0.218808,0.249099,0.262332,0.269761
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_dW_K_test3.csv 
b/src/test/resources/component/transformers/bert_layer/output_dW_K_test3.csv
new file mode 100644
index 0000000000..41c821ec39
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/output_dW_K_test3.csv
@@ -0,0 +1,8 @@
+-0.000800,0.002193,0.002716,-0.003409,0.002864,0.002029,-0.002455,-0.000706
+-0.000441,0.000948,0.001567,-0.001058,0.000648,-0.000865,-0.002136,-0.002325
+0.000401,0.001514,-0.000556,-0.004101,0.003605,0.000648,-0.005681,-0.005509
+0.000675,-0.001038,-0.002015,0.000915,0.000238,0.000047,-0.000703,-0.000156
+-0.001472,0.001714,0.004574,-0.000843,0.001128,0.000282,-0.001619,0.000514
+-0.001334,-0.000618,0.003660,0.004245,-0.003013,-0.002482,0.002207,0.003067
+-0.000674,0.000650,0.001727,-0.000331,-0.000450,0.001820,0.003364,0.003928
+-0.001270,0.002213,0.004020,-0.002282,0.000910,0.000003,-0.001218,-0.001111
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_dW_K_test4.csv 
b/src/test/resources/component/transformers/bert_layer/output_dW_K_test4.csv
new file mode 100644
index 0000000000..7fd3fe518b
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/output_dW_K_test4.csv
@@ -0,0 +1,4 @@
+0.007073,-0.017178,-0.001985,0.001731
+0.001372,-0.001396,-0.005469,0.000587
+-0.004756,0.021898,-0.032340,0.021663
+0.014448,-0.019484,0.040141,-0.030815
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_dW_Q_test3.csv 
b/src/test/resources/component/transformers/bert_layer/output_dW_Q_test3.csv
new file mode 100644
index 0000000000..cb27c6220e
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/output_dW_Q_test3.csv
@@ -0,0 +1,8 @@
+-0.004327,-0.001824,0.003820,-0.007255,0.003028,-0.000715,0.004038,0.000054
+-0.005571,-0.002385,0.004874,-0.009210,0.002710,-0.000481,0.002893,0.000198
+-0.001934,0.000165,0.002551,-0.004000,0.007030,-0.003440,0.009369,0.000679
+-0.002374,-0.001257,0.002088,-0.003690,0.000263,-0.000041,-0.000198,0.000005
+-0.002690,-0.001321,0.001997,-0.004416,0.000071,0.000868,0.000402,-0.000193
+-0.003213,-0.002986,0.001748,-0.004936,-0.004474,0.004202,-0.001206,-0.002735
+-0.001570,-0.002067,0.000775,-0.002456,-0.002100,0.002469,0.002623,-0.002737
+-0.004890,-0.002547,0.004074,-0.008366,0.002524,0.000428,0.006496,-0.001373
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_dW_Q_test4.csv 
b/src/test/resources/component/transformers/bert_layer/output_dW_Q_test4.csv
new file mode 100644
index 0000000000..3a0f444d3b
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/output_dW_Q_test4.csv
@@ -0,0 +1,4 @@
+0.001177,-0.002373,0.000195,-0.000035
+0.002378,-0.003385,-0.001492,-0.001461
+0.006328,-0.005630,0.002396,0.001547
+0.006038,-0.006465,0.000975,0.000445
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_dW_V_test3.csv 
b/src/test/resources/component/transformers/bert_layer/output_dW_V_test3.csv
new file mode 100644
index 0000000000..d82c28866e
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/output_dW_V_test3.csv
@@ -0,0 +1,8 @@
+-0.194254,0.011101,-0.261176,0.257742,0.316156,0.042150,-0.080481,-0.057790
+-0.286023,0.125485,-0.311945,0.337610,0.413694,0.016837,0.085687,-0.082430
+0.010286,-0.327077,-0.198318,0.106174,0.131656,0.122892,-0.607356,-0.004081
+-0.124912,0.120547,-0.087809,0.117464,0.148210,-0.020191,0.164380,-0.033625
+-0.128909,0.050774,-0.154180,0.164703,0.194505,0.015256,0.022213,-0.039221
+-0.176909,-0.089873,-0.306133,0.275754,0.340057,0.081202,-0.254454,-0.056162
+-0.071766,-0.270244,-0.267289,0.188806,0.244573,0.123532,-0.528974,-0.026357
+-0.210769,-0.223723,-0.436915,0.366907,0.453524,0.138179,-0.523960,-0.068175
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_dW_V_test4.csv 
b/src/test/resources/component/transformers/bert_layer/output_dW_V_test4.csv
new file mode 100644
index 0000000000..344ba4124a
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/output_dW_V_test4.csv
@@ -0,0 +1,4 @@
+0.276217,-0.191201,0.025388,0.344058
+0.921757,-0.135658,0.203731,1.106362
+0.399174,-0.041598,0.298835,0.472401
+0.667157,-0.103496,0.211956,0.733260
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_dW_context_test3.csv
 
b/src/test/resources/component/transformers/bert_layer/output_dW_context_test3.csv
new file mode 100644
index 0000000000..605ff3444d
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_dW_context_test3.csv
@@ -0,0 +1,8 @@
+-0.455333,-0.055141,-1.117865,0.130950,0.758799,-0.500184,0.700354,0.538421
+0.252326,0.072590,0.492536,-0.087689,-0.713985,0.302306,-0.168812,-0.149272
+-0.066676,-0.002103,-0.183949,0.014920,0.063023,-0.072432,0.140905,0.106312
+-0.036082,-0.120186,0.158029,0.066969,0.741043,-0.101635,-0.438162,-0.269976
+-0.130435,-0.004683,-0.346593,0.031118,0.144792,-0.138018,0.253627,0.190193
+0.227982,0.045455,0.517740,-0.070984,-0.480188,0.262608,-0.276156,-0.226458
+0.041401,0.055498,-0.008929,-0.035699,-0.368040,0.072268,0.154043,0.089457
+-0.035908,0.049622,-0.214338,-0.015427,-0.273679,-0.010559,0.299348,0.200941
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_dW_context_test4.csv
 
b/src/test/resources/component/transformers/bert_layer/output_dW_context_test4.csv
new file mode 100644
index 0000000000..06817f9dda
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_dW_context_test4.csv
@@ -0,0 +1,4 @@
+1.208036,-1.200712,-0.271654,0.264330
+-0.649720,0.530402,0.004314,0.115004
+0.552197,-0.526363,-0.158731,0.132896
+-1.486619,1.083122,0.335908,0.067589
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_dW_intermediate_test3.csv
 
b/src/test/resources/component/transformers/bert_layer/output_dW_intermediate_test3.csv
new file mode 100644
index 0000000000..1d1a4e30f8
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_dW_intermediate_test3.csv
@@ -0,0 +1,8 @@
+0.023180,0.136647,-0.164352,-0.066014,0.019287
+-0.047562,0.567295,-0.063649,-0.015320,0.026060
+-0.139436,-0.632434,0.239251,-0.300146,-0.094239
+0.201011,0.650226,0.032437,-0.433129,-0.349227
+-0.027887,0.030523,0.004450,-0.037535,0.130548
+-0.038294,-0.198302,0.133103,0.369786,0.139019
+0.206919,-0.236830,-0.041701,-0.056939,-0.143315
+-0.177930,-0.317124,-0.139540,0.539296,0.271868
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_dW_intermediate_test4.csv
 
b/src/test/resources/component/transformers/bert_layer/output_dW_intermediate_test4.csv
new file mode 100644
index 0000000000..718c48c8f9
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_dW_intermediate_test4.csv
@@ -0,0 +1,4 @@
+-0.034056,0.050829,-0.036211,0.112254
+0.086796,-0.070301,0.082565,-0.132312
+-0.151455,0.059607,0.276365,-0.037135
+0.098714,-0.040135,-0.322719,0.057194
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_dW_out_test3.csv 
b/src/test/resources/component/transformers/bert_layer/output_dW_out_test3.csv
new file mode 100644
index 0000000000..3d2430b31e
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_dW_out_test3.csv
@@ -0,0 +1,5 @@
+-0.048979,0.184899,0.169423,-0.143216,-0.079362,-0.035919,0.114787,-0.161634
+0.245306,0.042411,0.241224,-0.100651,-0.152223,-0.138499,0.103267,-0.240837
+0.192060,-0.049454,0.058602,-0.047356,-0.162173,-0.034218,0.116288,-0.073749
+-0.104499,0.090034,-0.244273,0.076709,0.230143,0.069535,-0.008323,-0.109325
+-0.031902,0.082104,-0.169339,-0.081263,-0.117729,0.023760,0.299454,-0.005086
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_dW_out_test4.csv 
b/src/test/resources/component/transformers/bert_layer/output_dW_out_test4.csv
new file mode 100644
index 0000000000..b34bde4760
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_dW_out_test4.csv
@@ -0,0 +1,4 @@
+0.038803,0.129503,-0.055618,-0.112688
+0.046289,-0.083861,0.022904,0.014668
+-0.150651,0.156957,0.021754,-0.028059
+0.216684,-0.123484,0.045174,-0.138374
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_db_K_test3.csv 
b/src/test/resources/component/transformers/bert_layer/output_db_K_test3.csv
new file mode 100644
index 0000000000..9590f6dbc6
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/output_db_K_test3.csv
@@ -0,0 +1,8 @@
+0.000000
+-0.000000
+-0.000000
+0.000000
+0.000000
+-0.000000
+-0.000000
+-0.000000
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_db_K_test4.csv 
b/src/test/resources/component/transformers/bert_layer/output_db_K_test4.csv
new file mode 100644
index 0000000000..c5a15205d6
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/output_db_K_test4.csv
@@ -0,0 +1,4 @@
+0.000000
+-0.000000
+-0.000000
+0.000000
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_db_Q_test3.csv 
b/src/test/resources/component/transformers/bert_layer/output_db_Q_test3.csv
new file mode 100644
index 0000000000..6d4da6e7f8
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/output_db_Q_test3.csv
@@ -0,0 +1,8 @@
+-0.006451
+-0.004007
+0.005343
+-0.010633
+0.001388
+0.001217
+0.006591
+-0.002476
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_db_Q_test4.csv 
b/src/test/resources/component/transformers/bert_layer/output_db_Q_test4.csv
new file mode 100644
index 0000000000..525024eba4
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/output_db_Q_test4.csv
@@ -0,0 +1,4 @@
+0.007216
+-0.010793
+0.000941
+-0.000706
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_db_V_test3.csv 
b/src/test/resources/component/transformers/bert_layer/output_db_V_test3.csv
new file mode 100644
index 0000000000..da5f313b04
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/output_db_V_test3.csv
@@ -0,0 +1,8 @@
+-0.290922
+-0.239237
+-0.544494
+0.467195
+0.589094
+0.158871
+-0.582568
+-0.090341
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_db_V_test4.csv 
b/src/test/resources/component/transformers/bert_layer/output_db_V_test4.csv
new file mode 100644
index 0000000000..ddcb98fd50
--- /dev/null
+++ b/src/test/resources/component/transformers/bert_layer/output_db_V_test4.csv
@@ -0,0 +1,4 @@
+1.377245
+-0.092954
+0.632183
+1.599774
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_db_context_test3.csv
 
b/src/test/resources/component/transformers/bert_layer/output_db_context_test3.csv
new file mode 100644
index 0000000000..330af9111a
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_db_context_test3.csv
@@ -0,0 +1,8 @@
+0.659098
+0.096211
+1.591416
+-0.193318
+-1.176293
+0.738551
+-0.963428
+-0.752237
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_db_context_test4.csv
 
b/src/test/resources/component/transformers/bert_layer/output_db_context_test4.csv
new file mode 100644
index 0000000000..d9ea8bf1f5
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_db_context_test4.csv
@@ -0,0 +1,4 @@
+-4.979834
+3.744857
+1.099222
+0.135753
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_db_intermediate_test3.csv
 
b/src/test/resources/component/transformers/bert_layer/output_db_intermediate_test3.csv
new file mode 100644
index 0000000000..2e9e4b1801
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_db_intermediate_test3.csv
@@ -0,0 +1,5 @@
+-0.092516
+-0.071798
+-0.090165
+0.333943
+0.161394
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_db_intermediate_test4.csv
 
b/src/test/resources/component/transformers/bert_layer/output_db_intermediate_test4.csv
new file mode 100644
index 0000000000..8f3226611f
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_db_intermediate_test4.csv
@@ -0,0 +1,4 @@
+-0.153303
+-0.000545
+0.388259
+-0.120641
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_db_out_test3.csv 
b/src/test/resources/component/transformers/bert_layer/output_db_out_test3.csv
new file mode 100644
index 0000000000..7b73704ecb
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_db_out_test3.csv
@@ -0,0 +1,8 @@
+0.449773
+-0.065906
+0.666695
+0.049729
+-0.530572
+0.132503
+-0.539876
+-0.162347
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_db_out_test4.csv 
b/src/test/resources/component/transformers/bert_layer/output_db_out_test4.csv
new file mode 100644
index 0000000000..4f0144d7cf
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_db_out_test4.csv
@@ -0,0 +1,4 @@
+-0.504401
+0.600470
+0.004841
+-0.100911
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_dbeta_ln1_test3.csv
 
b/src/test/resources/component/transformers/bert_layer/output_dbeta_ln1_test3.csv
new file mode 100644
index 0000000000..776bf0901c
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_dbeta_ln1_test3.csv
@@ -0,0 +1,8 @@
+0.327260
+-0.259531
+0.488296
+0.129820
+-0.475590
+0.204088
+-0.496108
+-0.345546
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_dbeta_ln1_test4.csv
 
b/src/test/resources/component/transformers/bert_layer/output_dbeta_ln1_test4.csv
new file mode 100644
index 0000000000..2a8e60f5fa
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_dbeta_ln1_test4.csv
@@ -0,0 +1,4 @@
+-0.512412
+0.790681
+0.070611
+-0.262950
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_dbeta_ln2_test3.csv
 
b/src/test/resources/component/transformers/bert_layer/output_dbeta_ln2_test3.csv
new file mode 100644
index 0000000000..749bb7e79b
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_dbeta_ln2_test3.csv
@@ -0,0 +1,8 @@
+3.118183
+2.690639
+3.396280
+3.440313
+2.545311
+2.824037
+2.457409
+1.827291
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_dbeta_ln2_test4.csv
 
b/src/test/resources/component/transformers/bert_layer/output_dbeta_ln2_test4.csv
new file mode 100644
index 0000000000..b00bddcfb8
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_dbeta_ln2_test4.csv
@@ -0,0 +1,4 @@
+6.839521
+9.654258
+7.589570
+7.291402
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_dgamma_ln1_test3.csv
 
b/src/test/resources/component/transformers/bert_layer/output_dgamma_ln1_test3.csv
new file mode 100644
index 0000000000..d60a4145db
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_dgamma_ln1_test3.csv
@@ -0,0 +1,8 @@
+0.055622
+0.046851
+0.785715
+-0.233341
+-0.365912
+0.387671
+-0.480873
+-0.785140
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_dgamma_ln1_test4.csv
 
b/src/test/resources/component/transformers/bert_layer/output_dgamma_ln1_test4.csv
new file mode 100644
index 0000000000..8ec2dcfc57
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_dgamma_ln1_test4.csv
@@ -0,0 +1,4 @@
+0.005107
+-0.093251
+-0.028160
+0.358046
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_dgamma_ln2_test3.csv
 
b/src/test/resources/component/transformers/bert_layer/output_dgamma_ln2_test3.csv
new file mode 100644
index 0000000000..3793192359
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_dgamma_ln2_test3.csv
@@ -0,0 +1,8 @@
+0.621295
+1.390660
+-1.203606
+-4.456588
+-1.122497
+1.470548
+-2.603842
+2.672874
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_dgamma_ln2_test4.csv
 
b/src/test/resources/component/transformers/bert_layer/output_dgamma_ln2_test4.csv
new file mode 100644
index 0000000000..73c85ba036
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_dgamma_ln2_test4.csv
@@ -0,0 +1,4 @@
+1.984963
+2.184791
+3.854489
+-4.108948
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_dstates_test3.csv 
b/src/test/resources/component/transformers/bert_layer/output_dstates_test3.csv
new file mode 100644
index 0000000000..3c97f5d29d
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_dstates_test3.csv
@@ -0,0 +1,2 @@
+0.502547,-0.504824,-0.150117,0.635026,0.126505,-0.583656,-0.532031,0.614263,-0.932515,1.796216,-0.484930,-0.852762,-1.178533,0.348883,1.891500,-0.506830,0.465477,-0.610949,0.030952,-0.183036,-1.450269,0.603921,0.374889,0.866466
+0.930711,-1.038537,1.419745,-0.211939,0.910450,-0.576544,-0.729507,-0.786905,-0.521563,0.138159,-0.852133,0.733773,0.942996,0.782001,-0.549455,-0.755850,-0.155300,0.148986,1.863989,-0.400224,-0.630028,0.420651,-1.421238,0.099039
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_dstates_test4.csv 
b/src/test/resources/component/transformers/bert_layer/output_dstates_test4.csv
new file mode 100644
index 0000000000..9840827fa3
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_dstates_test4.csv
@@ -0,0 +1,4 @@
+0.035862,-0.096464,-0.223688,-0.038358,-0.814778,-1.351295,0.739883,1.027551,-1.231988,1.321418,-0.487106,0.031419,-2.586746,2.184707,0.025841,0.007052
+-0.591683,0.563377,-0.289221,0.044621,-2.859349,2.553339,-0.853717,0.813563,0.289863,-0.027537,-0.367460,-0.186507,0.821949,-0.696216,-0.475074,0.011081
+-1.327160,-1.293002,2.709496,-0.142882,0.312226,0.603502,-0.534132,-0.438499,-0.493119,1.275077,-0.240618,-0.600009,-0.849704,0.228645,-0.145541,0.701722
+0.345673,1.157150,0.555622,-1.657624,1.029993,-0.893659,-0.142705,0.350081,0.209598,0.989423,-0.596598,-0.282764,1.683423,-2.474998,1.262697,-0.136837
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_states_test3.csv 
b/src/test/resources/component/transformers/bert_layer/output_states_test3.csv
new file mode 100644
index 0000000000..5465dc087d
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_states_test3.csv
@@ -0,0 +1,2 @@
+0.171729,1.207777,-1.369620,-1.140675,-0.256681,0.471505,-0.688117,1.604081,0.628548,1.626016,-0.923234,-0.315812,-0.843941,-0.334182,-1.178016,1.340621,0.367672,0.922365,-1.395176,-1.058890,0.378696,0.715008,-1.247252,1.317578
+0.117804,0.505541,0.710127,-1.514862,-0.635495,0.223551,-1.169896,1.763229,-0.406121,-0.149171,-1.123053,-1.124631,-0.271830,1.639958,-0.146873,1.581722,0.371575,-0.379796,0.150463,-1.654266,-0.710140,-0.126219,0.257467,2.090917
diff --git 
a/src/test/resources/component/transformers/bert_layer/output_states_test4.csv 
b/src/test/resources/component/transformers/bert_layer/output_states_test4.csv
new file mode 100644
index 0000000000..bdd33860a9
--- /dev/null
+++ 
b/src/test/resources/component/transformers/bert_layer/output_states_test4.csv
@@ -0,0 +1,4 @@
+0.593964,1.310086,-0.708939,-1.195111,-1.721363,0.688767,0.637501,0.395095,0.584127,1.146331,-0.214276,-1.516182,-0.637527,-0.867744,1.676665,-0.171395
+0.960430,0.986848,-0.651338,-1.295941,0.799327,1.180061,-1.058779,-0.920608,-0.001309,0.117444,1.352510,-1.468646,-0.578104,-1.301749,0.620299,1.259553
+0.595655,0.744104,0.377504,-1.717264,-1.455665,1.142275,0.670086,-0.356695,0.754210,-0.266670,1.029259,-1.516799,-0.633334,-1.016914,1.600725,0.049524
+-0.665618,-0.977035,1.612907,0.029747,0.708916,0.052460,0.885179,-1.646555,1.497432,-0.455357,0.188676,-1.230751,1.370018,0.527546,-0.799829,-1.097734
diff --git a/src/test/scripts/applications/nn/component/bert_layer_backward.dml 
b/src/test/scripts/applications/nn/component/bert_layer_backward.dml
new file mode 100644
index 0000000000..985b07de01
--- /dev/null
+++ b/src/test/scripts/applications/nn/component/bert_layer_backward.dml
@@ -0,0 +1,198 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Test driver: runs bert_layer::forward then bert_layer::backward on the inputs below and writes the max absolute error of each gradient against its expected CSV.
+source("scripts/nn/layers/bert_layer.dml") as bert_layer
+# $1: debug flag (integer 0/1); when true, intermediate results are printed.
+debug = as.logical(as.integer($1))
+
+B = as.integer($2)
+T = as.integer($3)
+D = as.integer($4)
+H = as.integer($5)
+d = as.integer($6)
+I = as.integer($7)
+# Dropout probabilities are fixed at 0.0, presumably so the run is deterministic and comparable to the fixed reference outputs.
+dropout_p_attention = 0.0
+dropout_p_output = 0.0
+epsilon_ln = 1e-012
+activation = $8
+# $9: input states, read as a B x (T*D) matrix.
+states = matrix(read($9, format="csv"), rows=B, cols=T*D)
+# $10-$21: weights and biases; weights are D x D except W_intermediate (D x I) and W_out (I x D); biases are single rows.
+W_Q = matrix(read($10, format="csv"), rows=D, cols=D)
+b_Q = matrix(read($11, format="csv"), rows=1, cols=D)
+W_K = matrix(read($12, format="csv"), rows=D, cols=D)
+b_K = matrix(read($13, format="csv"), rows=1, cols=D)
+W_V = matrix(read($14, format="csv"), rows=D, cols=D)
+b_V = matrix(read($15, format="csv"), rows=1, cols=D)
+W_context = matrix(read($16, format="csv"), rows=D, cols=D)
+b_context = matrix(read($17, format="csv"), rows=1, cols=D)
+W_intermediate = matrix(read($18, format="csv"), rows=D, cols=I)
+b_intermediate = matrix(read($19, format="csv"), rows=1, cols=I)
+W_out = matrix(read($20, format="csv"), rows=I, cols=D)
+b_out = matrix(read($21, format="csv"), rows=1, cols=D)
+# $22-$25: gamma/beta parameters (1 x D) for layer norms ln1 and ln2.
+gamma_ln1 = matrix(read($22, format="csv"), rows=1, cols=D)
+beta_ln1 = matrix(read($23, format="csv"), rows=1, cols=D)
+gamma_ln2 = matrix(read($24, format="csv"), rows=1, cols=D)
+beta_ln2 = matrix(read($25, format="csv"), rows=1, cols=D)
+# $26-$27: expected forward outputs; their errors are only printed in debug mode and are never written out.
+expected_out_states = read($26, format="csv")
+expected_attention = read($27, format="csv")
+# $28: gradient w.r.t. the layer's output states (B x (T*D)); passed as the first argument of backward.
+dout_states = matrix(read($28, format="csv"), rows=B, cols=T*D)
+# $29: expected gradient w.r.t. the input states.
+expected_din_states = matrix(read($29, format="csv"), rows=B, cols=T*D)
+# $30-$41: expected weight/bias gradients; weight gradients keep the dimensions stored in their CSV files.
+expected_dW_Q = read($30, format="csv")
+expected_db_Q = matrix(read($31, format="csv"), rows=1, cols=D)
+expected_dW_K = read($32, format="csv")
+expected_db_K = matrix(read($33, format="csv"), rows=1, cols=D)
+expected_dW_V = read($34, format="csv")
+expected_db_V = matrix(read($35, format="csv"), rows=1, cols=D)
+expected_dW_context = read($36, format="csv")
+expected_db_context = matrix(read($37, format="csv"), rows=1, cols=D)
+expected_dW_intermediate = read($38, format="csv")
+expected_db_intermediate = matrix(read($39, format="csv"), rows=1, cols=I)
+expected_dW_out = read($40, format="csv")
+expected_db_out = matrix(read($41, format="csv"), rows=1, cols=D)
+# $42-$45: expected gradients of the layer-norm gamma/beta parameters.
+expected_dgamma_ln1 = matrix(read($42, format="csv"), rows=1, cols=D)
+expected_dbeta_ln1 = matrix(read($43, format="csv"), rows=1, cols=D)
+expected_dgamma_ln2 = matrix(read($44, format="csv"), rows=1, cols=D)
+expected_dbeta_ln2 = matrix(read($45, format="csv"), rows=1, cols=D)
+
+# Forward pass; the dropout masks, layer-norm caches and outputs it returns are passed into backward below.
+[out_states, attention, outputs, dropout_mask_attention, 
dropout_mask_output_1, dropout_mask_output_2, cache_mean_ln1,
+  cache_var_ln1, cache_norm_ln1, cache_mean_ln2, cache_var_ln2, 
cache_norm_ln2] = bert_layer::forward(states,
+    H, T, d, I,
+    W_Q,  b_Q, 
+    W_K,  b_K, 
+    W_V,  b_V,
+    W_context,  b_context, 
+    W_intermediate,  b_intermediate, 
+    W_out,  b_out, 
+    dropout_p_attention, 
+    dropout_p_output,
+    epsilon_ln,
+    gamma_ln1,  beta_ln1,
+    gamma_ln2,  beta_ln2,
+    activation
+)
+# Backward pass: returns the input-state gradient followed by one gradient per parameter.
+[din_states, dW_Q, db_Q, dW_K, db_K, dW_V, db_V, dW_context, db_context, 
dW_intermediate, db_intermediate, dW_out, db_out,
+  dgamma_ln1, dbeta_ln1, dgamma_ln2, dbeta_ln2] = 
bert_layer::backward(dout_states, 
+    dropout_mask_attention, dropout_mask_output_1, dropout_mask_output_2, 
+    cache_mean_ln1, cache_var_ln1, cache_norm_ln1,
+    cache_mean_ln2, cache_var_ln2, cache_norm_ln2,
+    outputs,
+    states,
+    H, T, d, I,
+    W_Q, b_Q,
+    W_K, b_K,
+    W_V, b_V,
+    W_context, b_context,
+    W_intermediate, b_intermediate,
+    W_out, b_out,
+    dropout_p_attention,
+    dropout_p_output,
+    epsilon_ln,
+    gamma_ln1, beta_ln1,
+    gamma_ln2, beta_ln2,
+    activation
+)
+
+if (debug) {
+  print(toString(out_states))
+  print(toString(attention))
+}
+# NOTE(review): forward-pass errors are diagnostic only (printed under debug, never written), so the forward output is not checked by the written results.
+states_error = max(abs(expected_out_states - out_states))
+attention_error = max(abs(expected_attention - attention))
+
+if (debug) {
+  print(states_error)
+  print(attention_error)
+}
+# Max absolute deviation of each computed gradient from its expected value.
+din_error = max(abs(din_states - expected_din_states))
+dW_Q_error = max(abs(dW_Q - expected_dW_Q))
+db_Q_error = max(abs(db_Q - expected_db_Q))
+dW_K_error = max(abs(dW_K - expected_dW_K))
+db_K_error = max(abs(db_K - expected_db_K))
+dW_V_error = max(abs(dW_V - expected_dW_V))
+db_V_error = max(abs(db_V - expected_db_V))
+dW_context_error = max(abs(dW_context - expected_dW_context))
+db_context_error = max(abs(db_context - expected_db_context))
+dW_intermediate_error = max(abs(dW_intermediate - expected_dW_intermediate))
+db_intermediate_error = max(abs(db_intermediate - expected_db_intermediate))
+dW_out_error = max(abs(dW_out - expected_dW_out))
+db_out_error = max(abs(db_out - expected_db_out))
+dgamma_ln1_error = max(abs(dgamma_ln1 - expected_dgamma_ln1))
+dbeta_ln1_error = max(abs(dbeta_ln1 - expected_dbeta_ln1))
+dgamma_ln2_error = max(abs(dgamma_ln2 - expected_dgamma_ln2))
+dbeta_ln2_error = max(abs(dbeta_ln2 - expected_dbeta_ln2))
+
+if (debug) {
+  print(din_states)
+  
+  print(dW_Q)
+  print(db_Q)
+
+  print(dW_K)
+  print(db_K)
+
+  print(dW_V)
+  print(db_V)
+
+  print(dW_context)
+  print(db_context)
+
+  print(dW_intermediate)
+  print(db_intermediate)
+
+  print(dW_out)
+  print(db_out)
+
+  print(dgamma_ln1)
+  print(dbeta_ln1)
+
+  print(dgamma_ln2)
+  print(dbeta_ln2)
+}
+# $46-$62: each gradient error is written as a text scalar, in the order below (presumably read back positionally by the test harness).
+write(din_error, $46, format="text")
+write(dW_Q_error, $47, format="text")
+write(db_Q_error, $48, format="text")
+write(dW_K_error, $49, format="text")
+write(db_K_error, $50, format="text")
+write(dW_V_error, $51, format="text")
+write(db_V_error, $52, format="text")
+write(dW_context_error, $53, format="text")
+write(db_context_error, $54, format="text")
+write(dW_intermediate_error, $55, format="text")
+write(db_intermediate_error, $56, format="text")
+write(dW_out_error, $57, format="text")
+write(db_out_error, $58, format="text")
+write(dgamma_ln1_error, $59, format="text")
+write(dbeta_ln1_error, $60, format="text")
+write(dgamma_ln2_error, $61, format="text")
+write(dbeta_ln2_error, $62, format="text")

Reply via email to