This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new fb9a3f9875 [SYSTEMDS-3018] Add U-Net Model
fb9a3f9875 is described below

commit fb9a3f987568fb845f8d1498ad356f547135b1ed
Author: sebwrede <[email protected]>
AuthorDate: Fri Nov 11 17:43:18 2022 +0100

    [SYSTEMDS-3018] Add U-Net Model
    
    This adds a U-Net model to the nn examples. It can be trained with or without a parameter server setup.
    The model also supports homomorphic encryption of gradient updates.
    
    Closes #1764.
---
 scripts/nn/examples/u-net.dml                      | 949 +++++++++++++++++++++
 scripts/nn/layers/conv2d_transpose.dml             |  29 +
 scripts/utils/image_utils.dml                      |  22 +-
 .../paramserv/dp/ShuffleFederatedScheme.java       |   4 +-
 .../cp/ParamservBuiltinCPInstruction.java          |  51 +-
 .../paramserv/EncryptedFederatedParamservTest.java |  93 +-
 .../paramserv/FederatedParamservTest.java          |  80 +-
 .../federated/paramserv/ParamServTestUtils.java    | 126 +++
 .../paramserv/EncryptedFederatedParamservTest.dml  |  19 +
 .../federated/paramserv/FederatedParamservTest.dml |  21 +-
 10 files changed, 1268 insertions(+), 126 deletions(-)

diff --git a/scripts/nn/examples/u-net.dml b/scripts/nn/examples/u-net.dml
new file mode 100644
index 0000000000..202dda36e8
--- /dev/null
+++ b/scripts/nn/examples/u-net.dml
@@ -0,0 +1,949 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+*  This file contains functions used for training a U-Net model, with or without a parameter server setup.
+*/
+
+source("scripts/nn/layers/affine.dml") as affine
+source("scripts/nn/layers/conv2d_builtin.dml") as conv2d
+source("scripts/nn/layers/conv2d_transpose.dml") as conv2d_transpose
+source("scripts/nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("scripts/nn/layers/dropout.dml") as dropout
+source("scripts/nn/layers/l2_reg.dml") as l2_reg
+source("scripts/nn/layers/max_pool2d_builtin.dml") as max_pool2d
+source("scripts/nn/layers/relu.dml") as relu
+source("scripts/nn/layers/softmax.dml") as softmax
+source("scripts/nn/optim/sgd_momentum.dml") as sgd_momentum
+source("scripts/nn/layers/dropout.dml") as dropout
+source("scripts/utils/image_utils.dml") as img_utils
+
+/*
+* Pad input features X with extrapolated data by mirroring.
+* Only the height and width are padded, no extra channels are added.
+* Dimensions changed from (N,C*Hin*Win) to (N,C*(Hin+184)*(Win+184)).
+*
+*  Inputs:
+*  - X: Features to pad
+*  - N: Number of examples (rows) of X
+*  - C: Number of channels of X
+*  - Hin: Height of each element of X
+*  - Win: Width of each element of X
+*
+*  Outputs:
+*  - X_extrapolated: X padded with extrapolated data
+*  - input_HW: Height and Width of X_extrapolated
+*/
+extrapolate = function(matrix[double] X, int N, int C, int Hin, int Win) return (matrix[double] X_extrapolated, int input_HW){
+    input_HW = Hin + 184 # 184 = 2*92 border consumed by the network, assuming 3x3 filters and conv stride 1; assumes square inputs (Hin == Win)
+    pad_size = 92 # 184 / 2
+
+    X_extrapolated = matrix(0, rows=N, cols=C*input_HW*input_HW)
+
+    for ( i in 1:C ){
+        start_channel = ((i-1) * Hin * Win)+1
+        end_channel = i * Hin * Win
+        original_channel = X[,start_channel:end_channel]
+        # Iterate through the N rows of X, each holding one example's channel
+        for ( row in 1:N ){
+            img = matrix(original_channel[row], rows=Hin, cols=Win)
+
+            pad_left = t(rev(t(img[,1:pad_size])))
+            pad_right = t(rev(t(img[,(Win-(pad_size-1)):Win])))
+            pad_top = rev(img[1:(pad_size),])
+            pad_bottom = rev(img[(Hin-(pad_size-1)):Hin,])
+            pad_top_left = rev(pad_left[1:(pad_size),])
+            pad_top_right = rev(pad_right[1:(pad_size),])
+            pad_bottom_left = rev(pad_left[(Hin-(pad_size-1)):Hin,])
+            pad_bottom_right = rev(pad_right[(Hin-(pad_size-1)):Hin,])
+
+            pad_left_full = rbind(pad_top_left, pad_left, pad_bottom_left)
+            pad_right_full = rbind(pad_top_right, pad_right, pad_bottom_right)
+            pad_center_full = rbind(pad_top, img, pad_bottom)
+
+            modified_channel = cbind(pad_left_full, pad_center_full, pad_right_full)
+
+            flat_width = input_HW*input_HW
+            start_col = ((i-1)*flat_width)+1
+            end_col = i*flat_width
+            X_extrapolated[row,start_col:end_col] = matrix(modified_channel, rows=1, cols=flat_width)
+        }
+    }
+}
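+
+# Example (an illustrative sketch, dimensions are made up): mirror-pad a batch
+# of 8 single-channel 100x100 images to 284x284 before the forward pass.
+#
+#   X = rand(rows=8, cols=1*100*100)
+#   [X_pad, input_HW] = extrapolate(X, 8, 1, 100, 100)
+#   # X_pad: (8, 1*284*284), input_HW: 284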
+
+/*
+ * Trains a U-Net model with a parameter server setup.
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector.  The targets, Y, have K
+ * classes representing the segmentation map of the input.
+ *
+ * Inputs:
+ *  - X: Input data matrix, of shape (N, C*Hin*Win)
+ *  - y: Target matrix, of shape (N, K)
+ *  - X_val: Input validation data matrix, of shape (N, C*Hin*Win)
+ *  - y_val: Target validation matrix, of shape (N, K)
+ *  - C: Number of input channels
+ *  - Hin: Input height
+ *  - Win: Input width
+ *  - epochs: Total number of full training loops over the full data set
+ *  - workers: Number of federated workers
+ *  - utype: Update type (synchronous, asynchronous, etc.)
+ *  - freq: Frequency of weight updates ("BATCH" or "EPOCH")
+ *  - batch_size: Batch size
+ *  - scheme: Parameter server training scheme
+ *  - learning_rate: The learning rate for the SGD with momentum
+ *  - seed: Seed for the initialization of the convolution weights. Default is -1, meaning that the seeds are random.
+ *  - he: Whether homomorphic encryption of updates is activated (boolean)
+ *  - F1: Number of filters of the top layer of the U-Net model. Default is 64.
+ *
+ * Outputs:
+ *  - model_trained: List containing weights and biases
+ */
+train_paramserv = function(matrix[double] X, matrix[double] y,
+                 matrix[double] X_val, matrix[double] y_val,
+                 int C, int Hin, int Win, int epochs, int workers,
+                 string utype, string freq, int batch_size, string scheme, double learning_rate,
+                 int seed = -1, boolean he = FALSE, int F1 = 64)
+    return (list[unknown] model_trained) {
+  N = nrow(X) # Number of inputs
+  K = ncol(y) # Number of target classes
+
+  # Define model network constants
+  Hf = 3  # convolution filter height
+  Wf = 3  # convolution filter width
+  conv_stride = 1
+  pool_stride = 2
+  pool_HWf = 2
+  conv_t_HWf = 2
+  conv_t_stride = 2
+  pad = 0  # valid convolutions; for same output dimensions use (Hf - conv_stride) / 2
+  F2 = F1 * 2
+  F3 = F2 * 2
+  F4 = F3 * 2
+  F5 = F4 * 2
+  dropProb = 1.0  # keep probability; 1.0 disables dropout
+  dropSeed = -1
+
+  # Create different seeds for each layer, unless seed is -1
+  lseed = list()
+  for ( i in 1:23 ){
+    if (seed == -1){
+        lseed = rbind(lseed, -1)
+    } else {
+        lseed = rbind(lseed, seed+i)
+    }
+  }
+
+  # Initialize convolution weights
+
+  # First step
+  [W1, b1] = conv2d::init(F1, C, Hf, Wf, seed = as.integer(as.scalar(lseed[1])))  # inputs: (N, C*Hin*Win)
+  [W2, b2] = conv2d::init(F1, F1, Hf, Wf, seed = as.integer(as.scalar(lseed[2])))
+  # Second step
+  [W3, b3] = conv2d::init(F2, F1, Hf, Wf, seed = as.integer(as.scalar(lseed[3])))
+  [W4, b4] = conv2d::init(F2, F2, Hf, Wf, seed = as.integer(as.scalar(lseed[4])))
+  # Third step
+  [W5, b5] = conv2d::init(F3, F2, Hf, Wf, seed = as.integer(as.scalar(lseed[5])))
+  [W6, b6] = conv2d::init(F3, F3, Hf, Wf, seed = as.integer(as.scalar(lseed[6])))
+  # Fourth step
+  [W7, b7] = conv2d::init(F4, F3, Hf, Wf, seed = as.integer(as.scalar(lseed[7])))
+  [W8, b8] = conv2d::init(F4, F4, Hf, Wf, seed = as.integer(as.scalar(lseed[8])))
+  # Fifth step
+  [W9, b9] = conv2d::init(F5, F4, Hf, Wf, seed = as.integer(as.scalar(lseed[9])))
+  [W10, b10] = conv2d::init(F5, F5, Hf, Wf, seed = as.integer(as.scalar(lseed[10])))
+  # First Up-convolution
+  [W11, b11] = conv2d_transpose::init_seed(F4, F5, conv_t_HWf, conv_t_HWf, seed = as.integer(as.scalar(lseed[11])))
+  [W12, b12] = conv2d::init(F4, F5, Hf, Wf, seed = as.integer(as.scalar(lseed[12])))
+  [W13, b13] = conv2d::init(F4, F4, Hf, Wf, seed = as.integer(as.scalar(lseed[13])))
+  # Second Up-convolution
+  [W14, b14] = conv2d_transpose::init_seed(F3, F4, conv_t_HWf, conv_t_HWf, seed = as.integer(as.scalar(lseed[14])))
+  [W15, b15] = conv2d::init(F3, F4, Hf, Wf, seed = as.integer(as.scalar(lseed[15])))
+  [W16, b16] = conv2d::init(F3, F3, Hf, Wf, seed = as.integer(as.scalar(lseed[16])))
+  # Third Up-convolution
+  [W17, b17] = conv2d_transpose::init_seed(F2, F3, conv_t_HWf, conv_t_HWf, seed = as.integer(as.scalar(lseed[17])))
+  [W18, b18] = conv2d::init(F2, F3, Hf, Wf, seed = as.integer(as.scalar(lseed[18])))
+  [W19, b19] = conv2d::init(F2, F2, Hf, Wf, seed = as.integer(as.scalar(lseed[19])))
+  # Fourth Up-convolution
+  [W20, b20] = conv2d_transpose::init_seed(F1, F2, conv_t_HWf, conv_t_HWf, seed = as.integer(as.scalar(lseed[20])))
+  [W21, b21] = conv2d::init(F1, F2, Hf, Wf, seed = as.integer(as.scalar(lseed[21])))
+  [W22, b22] = conv2d::init(F1, F1, Hf, Wf, seed = as.integer(as.scalar(lseed[22])))
+  # Segmentation map
+  [W23, b23] = conv2d::init(C, F1, 1, 1, seed = as.integer(as.scalar(lseed[23])))
+
+  # Initialize SGD with momentum
+  vW1 = sgd_momentum::init(W1); vb1 = sgd_momentum::init(b1)
+  vW2 = sgd_momentum::init(W2); vb2 = sgd_momentum::init(b2)
+  vW3 = sgd_momentum::init(W3); vb3 = sgd_momentum::init(b3)
+  vW4 = sgd_momentum::init(W4); vb4 = sgd_momentum::init(b4)
+  vW5 = sgd_momentum::init(W5); vb5 = sgd_momentum::init(b5)
+  vW6 = sgd_momentum::init(W6); vb6 = sgd_momentum::init(b6)
+  vW7 = sgd_momentum::init(W7); vb7 = sgd_momentum::init(b7)
+  vW8 = sgd_momentum::init(W8); vb8 = sgd_momentum::init(b8)
+  vW9 = sgd_momentum::init(W9); vb9 = sgd_momentum::init(b9)
+  vW10 = sgd_momentum::init(W10); vb10 = sgd_momentum::init(b10)
+  vW11 = sgd_momentum::init(W11); vb11 = sgd_momentum::init(b11)
+  vW12 = sgd_momentum::init(W12); vb12 = sgd_momentum::init(b12)
+  vW13 = sgd_momentum::init(W13); vb13 = sgd_momentum::init(b13)
+  vW14 = sgd_momentum::init(W14); vb14 = sgd_momentum::init(b14)
+  vW15 = sgd_momentum::init(W15); vb15 = sgd_momentum::init(b15)
+  vW16 = sgd_momentum::init(W16); vb16 = sgd_momentum::init(b16)
+  vW17 = sgd_momentum::init(W17); vb17 = sgd_momentum::init(b17)
+  vW18 = sgd_momentum::init(W18); vb18 = sgd_momentum::init(b18)
+  vW19 = sgd_momentum::init(W19); vb19 = sgd_momentum::init(b19)
+  vW20 = sgd_momentum::init(W20); vb20 = sgd_momentum::init(b20)
+  vW21 = sgd_momentum::init(W21); vb21 = sgd_momentum::init(b21)
+  vW22 = sgd_momentum::init(W22); vb22 = sgd_momentum::init(b22)
+  vW23 = sgd_momentum::init(W23); vb23 = sgd_momentum::init(b23)
+
+  # Define optimizer constants
+  mu = 0.9  # momentum
+  decay = 0.95  # learning rate decay constant
+
+  # Regularization
+  lambda = 5e-04
+
+  # Create the model list
+  model_list = list(
+    W1, W2, W3, W4, W5, W6, W7, W8, W9, W10, W11, W12, W13, W14, W15, W16, W17, W18, W19, W20, W21, W22, W23,
+    b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15, b16, b17, b18, b19, b20, b21, b22, b23,
+    vW1, vW2, vW3, vW4, vW5, vW6, vW7, vW8, vW9, vW10, vW11, vW12, vW13, vW14, vW15, vW16, vW17, vW18, vW19, vW20, vW21, vW22, vW23,
+    vb1, vb2, vb3, vb4, vb5, vb6, vb7, vb8, vb9, vb10, vb11, vb12, vb13, vb14, vb15, vb16, vb17, vb18, vb19, vb20, vb21, vb22, vb23)
+
+  # Create the hyper parameter list
+  params = list(
+    learning_rate=learning_rate, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf,
+    conv_stride=conv_stride, pool_stride=pool_stride, pool_HWf=pool_HWf, conv_t_HWf=conv_t_HWf, conv_t_stride=conv_t_stride,
+    pad=pad, lambda=lambda, F1=F1, F2=F2, F3=F3, F4=F4, F5=F5, dropProb=dropProb, dropSeed=dropSeed)
+
+  # Use paramserv function
+  model_trained = paramserv(model=model_list, features=X, labels=y, val_features=X_val, val_labels=y_val,
+    upd="./scripts/nn/examples/u-net.dml::gradients",
+    agg="./scripts/nn/examples/u-net.dml::aggregation",
+    scheme=scheme, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
+    k=workers, hyperparams=params, checkpointing="NONE", he=he, modelAvg=TRUE)
+}
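+
+# A hedged usage sketch; the utype/freq/scheme values below are illustrative
+# only, see the paramserv builtin documentation for the supported options:
+#
+#   model = train_paramserv(X, y, X_val, y_val, C=1, Hin=284, Win=284,
+#     epochs=10, workers=2, utype="BSP", freq="BATCH", batch_size=2,
+#     scheme="DISJOINT_CONTIGUOUS", learning_rate=0.001, seed=42)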
+
+/*
+* Forward pass of U-Net model on X using specified batch size.
+*
+*  Inputs:
+*  - X: Input features of size C*Hin*Win.
+*       The features need to be padded with mirrored data, hence the actual input size before padding is C*(Hin-pad)*(Win-pad).
+*  - C: Number of input channels
+*  - Hin: Input height
+*  - Win: Input width
+*  - batch_size: Batch size
+*  - model: List of weights of the model (23 weights, 23 biases)
+*  - K: Size of the segmentation map (C*(Hin-pad)*(Win-pad))
+*  - F1: Number of filters of the top layer of the U-Net model. Default is 64.
+*
+*  Output:
+*  - probs: Segmentation map probabilities generated by the forward pass of the U-Net model
+*/
+predict = function(matrix[double] X, int C, int Hin, int Win, int batch_size, list[unknown] model, int K, int F1 = 64)
+    return (matrix[double] probs) {
+  W1 = as.matrix(model[1])
+  W2 = as.matrix(model[2])
+  W3 = as.matrix(model[3])
+  W4 = as.matrix(model[4])
+  W5 = as.matrix(model[5])
+  W6 = as.matrix(model[6])
+  W7 = as.matrix(model[7])
+  W8 = as.matrix(model[8])
+  W9 = as.matrix(model[9])
+  W10 = as.matrix(model[10])
+  W11 = as.matrix(model[11])
+  W12 = as.matrix(model[12])
+  W13 = as.matrix(model[13])
+  W14 = as.matrix(model[14])
+  W15 = as.matrix(model[15])
+  W16 = as.matrix(model[16])
+  W17 = as.matrix(model[17])
+  W18 = as.matrix(model[18])
+  W19 = as.matrix(model[19])
+  W20 = as.matrix(model[20])
+  W21 = as.matrix(model[21])
+  W22 = as.matrix(model[22])
+  W23 = as.matrix(model[23])
+  b1 = as.matrix(model[24])
+  b2 = as.matrix(model[25])
+  b3 = as.matrix(model[26])
+  b4 = as.matrix(model[27])
+  b5 = as.matrix(model[28])
+  b6 = as.matrix(model[29])
+  b7 = as.matrix(model[30])
+  b8 = as.matrix(model[31])
+  b9 = as.matrix(model[32])
+  b10 = as.matrix(model[33])
+  b11 = as.matrix(model[34])
+  b12 = as.matrix(model[35])
+  b13 = as.matrix(model[36])
+  b14 = as.matrix(model[37])
+  b15 = as.matrix(model[38])
+  b16 = as.matrix(model[39])
+  b17 = as.matrix(model[40])
+  b18 = as.matrix(model[41])
+  b19 = as.matrix(model[42])
+  b20 = as.matrix(model[43])
+  b21 = as.matrix(model[44])
+  b22 = as.matrix(model[45])
+  b23 = as.matrix(model[46])
+  N = nrow(X) # Number of inputs
+
+  Hf = 3  # convolution filter height
+  Wf = 3  # convolution filter width
+  conv_stride = 1
+  pool_stride = 2
+  pool_HWf = 2
+  conv_t_HWf = 2
+  conv_t_stride = 2
+  pad = 0  # valid convolutions; for same output dimensions use (Hf - conv_stride) / 2
+  F2 = F1 * 2
+  F3 = F2 * 2
+  F4 = F3 * 2
+  F5 = F4 * 2
+  dropProb = 1.0  # keep probability; 1.0 disables dropout
+  dropSeed = -1
+
+  # Compute predictions over mini-batches
+  probs = matrix(0, rows=N, cols=K)
+  iters = ceil(N / batch_size)
+  for(i in 1:iters, check=0) {
+    # Get next batch
+    beg = ((i-1) * batch_size) %% N + 1
+    end = min(N, beg + batch_size - 1)
+    X_batch = X[beg:end,]
+
+    # Down-Convolution
+    [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win, Hf, Wf, conv_stride, conv_stride, pad, pad)
+    outr1 = relu::forward(outc1)
+    [outc2, Houtc2, Woutc2] = conv2d::forward(outr1, W2, b2, F1, Houtc1, Woutc1, Hf, Wf, conv_stride, conv_stride, pad, pad)
+    outr2 = relu::forward(outc2)
+    [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr2, F1, Houtc2, Woutc2, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0)
+    [outc3, Houtc3, Woutc3] = conv2d::forward(outp1, W3, b3, F1, Houtp1, Woutp1, Hf, Wf, conv_stride, conv_stride, pad, pad)
+    outr3 = relu::forward(outc3)
+    [outc4, Houtc4, Woutc4] = conv2d::forward(outr3, W4, b4, F2, Houtc3, Woutc3, Hf, Wf, conv_stride, conv_stride, pad, pad)
+    outr4 = relu::forward(outc4)
+    [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr4, F2, Houtc4, Woutc4, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0)
+    [outc5, Houtc5, Woutc5] = conv2d::forward(outp2, W5, b5, F2, Houtp2, Woutp2, Hf, Wf, conv_stride, conv_stride, pad, pad)
+    outr5 = relu::forward(outc5)
+    [outc6, Houtc6, Woutc6] = conv2d::forward(outr5, W6, b6, F3, Houtc5, Woutc5, Hf, Wf, conv_stride, conv_stride, pad, pad)
+    outr6 = relu::forward(outc6)
+    [outp3, Houtp3, Woutp3] = max_pool2d::forward(outr6, F3, Houtc6, Woutc6, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0)
+    [outc7, Houtc7, Woutc7] = conv2d::forward(outp3, W7, b7, F3, Houtp3, Woutp3, Hf, Wf, conv_stride, conv_stride, pad, pad)
+    outr7 = relu::forward(outc7)
+    [outc8, Houtc8, Woutc8] = conv2d::forward(outr7, W8, b8, F4, Houtc7, Woutc7, Hf, Wf, conv_stride, conv_stride, pad, pad)
+    outr8 = relu::forward(outc8)
+    [outp4, Houtp4, Woutp4] = max_pool2d::forward(outr8, F4, Houtc8, Woutc8, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0)
+    [outd1, mask1]          = dropout::forward(outp4, dropProb, dropSeed)
+
+    # Bottom
+    [outc9, Houtc9, Woutc9] = conv2d::forward(outd1, W9, b9, F4, Houtp4, Woutp4, Hf, Wf, conv_stride, conv_stride, pad, pad)
+    outr9 = relu::forward(outc9)
+    [outc10, Houtc10, Woutc10] = conv2d::forward(outr9, W10, b10, F5, Houtc9, Woutc9, Hf, Wf, conv_stride, conv_stride, pad, pad)
+    outr10 = relu::forward(outc10)
+    [outc11, Houtc11, Woutc11] = conv2d_transpose::forward(outr10, W11, b11, F5, Houtc10, Woutc10, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad, 0, 0)
+
+    # Up-Convolution
+    outConcat1 = cbind(img_utils::crop_channel(outr8, Houtc8, Woutc8, Houtc11, Woutc11, F4), outc11)
+    [outc12, Houtc12, Woutc12] = conv2d::forward(outConcat1, W12, b12, F5, Houtc11, Woutc11, Hf, Wf, conv_stride, conv_stride, pad, pad)
+    outr11 = relu::forward(outc12)
+    [outc13, Houtc13, Woutc13] = conv2d::forward(outr11, W13, b13, F4, Houtc12, Woutc12, Hf, Wf, conv_stride, conv_stride, pad, pad)
+    outr12 = relu::forward(outc13)
+    [outc14, Houtc14, Woutc14] = conv2d_transpose::forward(outr12, W14, b14, F4, Houtc13, Woutc13, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad, 0, 0)
+    outConcat2 = cbind(img_utils::crop_channel(outr6, Houtc6, Woutc6, Houtc14, Woutc14, F3), outc14)
+    [outc15, Houtc15, Woutc15] = conv2d::forward(outConcat2, W15, b15, F4, Houtc14, Woutc14, Hf, Wf, conv_stride, conv_stride, pad, pad)
+    outr13 = relu::forward(outc15)
+    [outc16, Houtc16, Woutc16] = conv2d::forward(outr13, W16, b16, F3, Houtc15, Woutc15, Hf, Wf, conv_stride, conv_stride, pad, pad)
+    outr14 = relu::forward(outc16)
+    [outc17, Houtc17, Woutc17] = conv2d_transpose::forward(outr14, W17, b17, F3, Houtc16, Woutc16, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad, 0, 0)
+    outConcat3 = cbind(img_utils::crop_channel(outr4, Houtc4, Woutc4, Houtc17, Woutc17, F2), outc17)
+    [outc18, Houtc18, Woutc18] = conv2d::forward(outConcat3, W18, b18, F3, Houtc17, Woutc17, Hf, Wf, conv_stride, conv_stride, pad, pad)
+    outr15 = relu::forward(outc18)
+    [outc19, Houtc19, Woutc19] = conv2d::forward(outr15, W19, b19, F2, Houtc18, Woutc18, Hf, Wf, conv_stride, conv_stride, pad, pad)
+    outr16 = relu::forward(outc19)
+    [outc20, Houtc20, Woutc20] = conv2d_transpose::forward(outr16, W20, b20, F2, Houtc19, Woutc19, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad, 0, 0)
+    outConcat4 = cbind(img_utils::crop_channel(outr2, Houtc2, Woutc2, Houtc20, Woutc20, F1), outc20)
+    [outc21, Houtc21, Woutc21] = conv2d::forward(outConcat4, W21, b21, F2, Houtc20, Woutc20, Hf, Wf, conv_stride, conv_stride, pad, pad)
+    outr17 = relu::forward(outc21)
+    [outc22, Houtc22, Woutc22] = conv2d::forward(outr17, W22, b22, F1, Houtc21, Woutc21, Hf, Wf, conv_stride, conv_stride, pad, pad)
+    outr18 = relu::forward(outc22)
+
+    # This last conv2d needs to create the segmentation map (1x1 filter):
+    [outc23, Houtc23, Woutc23] = conv2d::forward(outr18, W23, b23, F1, Houtc22, Woutc22, 1, 1, conv_stride, conv_stride, pad, pad)
+
+    # Store predictions
+    probs[beg:end,] = softmax::forward(outc23)
+  }
+}
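+
+# Usage sketch (illustrative values): for an unpadded 100x100 single-channel
+# target, the padded input is 284x284 and K = 1*100*100.
+#
+#   probs = predict(X_pad, 1, 284, 284, 32, model_trained, 100*100)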
+
+/*
+*  Forward and backward pass of U-Net with gradients returned.
+*
+*  Inputs:
+*  - model: List of model weights
+*  - hyperparams: List of hyper parameters containing:
+*                 - (scalar[integer]) C: Number of input channels
+*                 - (scalar[integer]) Hin: Input height
+*                 - (scalar[integer]) Win: Input width
+*                 - (scalar[integer]) Hf: Filter height
+*                 - (scalar[integer]) Wf: Filter width
+*                 - (scalar[integer]) pool_stride: Stride of the max pool operations
+*                 - (scalar[integer]) pool_HWf: Filter height and width of the max pool operation
+*                 - (scalar[integer]) conv_stride: Stride of all convolutions
+*                 - (scalar[integer]) conv_t_HWf: Filter height and width of the transpose convolutions
+*                 - (scalar[integer]) conv_t_stride: Stride of the transpose convolutions
+*                 - (scalar[integer]) pad: Padding of all convolutions and transpose convolutions
+*                 - (scalar[double]) lambda: Regularization strength
+*                 - (scalar[integer]) F1, F2, F3, F4, F5: Number of filters of the convolutions in the five layers
+*                 - (scalar[double]) dropProb: Dropout probability
+*                 - (scalar[integer]) dropSeed: Dropout seed
+*  - features: Features of size C*Hin*Win. The features need to be padded with mirrored data.
+*              The input feature size should result in an output size of the U-Net equal to the label size.
+*              See the extrapolate function for how to pad the features by extrapolating.
+*  - labels: Labels of size C * (Hin-pad) * (Win-pad) representing a segmentation map of the input features.
+*  Output:
+*  - gradients: List of gradients
+*/
+gradients = function(list[unknown] model,
+                     list[unknown] hyperparams,
+                     matrix[double] features,
+                     matrix[double] labels)
+          return (list[unknown] gradients) {
+        C = as.integer(as.scalar(hyperparams["C"]))
+        Hin = as.integer(as.scalar(hyperparams["Hin"]))
+        Win = as.integer(as.scalar(hyperparams["Win"]))
+        Hf = as.integer(as.scalar(hyperparams["Hf"]))
+        Wf = as.integer(as.scalar(hyperparams["Wf"]))
+        pool_stride = as.integer(as.scalar(hyperparams["pool_stride"]))
+        pool_HWf = as.integer(as.scalar(hyperparams["pool_HWf"]))
+        conv_stride = as.integer(as.scalar(hyperparams["conv_stride"]))
+        conv_t_HWf = as.integer(as.scalar(hyperparams["conv_t_HWf"]))
+        conv_t_stride = as.integer(as.scalar(hyperparams["conv_t_stride"]))
+        pad = as.integer(as.scalar(hyperparams["pad"]))
+        lambda = as.double(as.scalar(hyperparams["lambda"]))
+        F1 = as.integer(as.scalar(hyperparams["F1"]))
+        F2 = as.integer(as.scalar(hyperparams["F2"]))
+        F3 = as.integer(as.scalar(hyperparams["F3"]))
+        F4 = as.integer(as.scalar(hyperparams["F4"]))
+        F5 = as.integer(as.scalar(hyperparams["F5"]))
+        dropProb = as.double(as.scalar(hyperparams["dropProb"]))
+        dropSeed = as.integer(as.scalar(hyperparams["dropSeed"]))
+        W1 = as.matrix(model[1])
+        W2 = as.matrix(model[2])
+        W3 = as.matrix(model[3])
+        W4 = as.matrix(model[4])
+        W5 = as.matrix(model[5])
+        W6 = as.matrix(model[6])
+        W7 = as.matrix(model[7])
+        W8 = as.matrix(model[8])
+        W9 = as.matrix(model[9])
+        W10 = as.matrix(model[10])
+        W11 = as.matrix(model[11])
+        W12 = as.matrix(model[12])
+        W13 = as.matrix(model[13])
+        W14 = as.matrix(model[14])
+        W15 = as.matrix(model[15])
+        W16 = as.matrix(model[16])
+        W17 = as.matrix(model[17])
+        W18 = as.matrix(model[18])
+        W19 = as.matrix(model[19])
+        W20 = as.matrix(model[20])
+        W21 = as.matrix(model[21])
+        W22 = as.matrix(model[22])
+        W23 = as.matrix(model[23])
+        b1 = as.matrix(model[24])
+        b2 = as.matrix(model[25])
+        b3 = as.matrix(model[26])
+        b4 = as.matrix(model[27])
+        b5 = as.matrix(model[28])
+        b6 = as.matrix(model[29])
+        b7 = as.matrix(model[30])
+        b8 = as.matrix(model[31])
+        b9 = as.matrix(model[32])
+        b10 = as.matrix(model[33])
+        b11 = as.matrix(model[34])
+        b12 = as.matrix(model[35])
+        b13 = as.matrix(model[36])
+        b14 = as.matrix(model[37])
+        b15 = as.matrix(model[38])
+        b16 = as.matrix(model[39])
+        b17 = as.matrix(model[40])
+        b18 = as.matrix(model[41])
+        b19 = as.matrix(model[42])
+        b20 = as.matrix(model[43])
+        b21 = as.matrix(model[44])
+        b22 = as.matrix(model[45])
+        b23 = as.matrix(model[46])
+
+        # Down-Convolution
+        [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        outr1 = relu::forward(outc1)
+        [outc2, Houtc2, Woutc2] = conv2d::forward(outr1, W2, b2, F1, Houtc1, Woutc1, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        outr2 = relu::forward(outc2)
+        [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr2, F1, Houtc2, Woutc2, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0)
+        [outc3, Houtc3, Woutc3] = conv2d::forward(outp1, W3, b3, F1, Houtp1, Woutp1, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        outr3 = relu::forward(outc3)
+        [outc4, Houtc4, Woutc4] = conv2d::forward(outr3, W4, b4, F2, Houtc3, Woutc3, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        outr4 = relu::forward(outc4)
+        [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr4, F2, Houtc4, Woutc4, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0)
+        [outc5, Houtc5, Woutc5] = conv2d::forward(outp2, W5, b5, F2, Houtp2, Woutp2, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        outr5 = relu::forward(outc5)
+        [outc6, Houtc6, Woutc6] = conv2d::forward(outr5, W6, b6, F3, Houtc5, Woutc5, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        outr6 = relu::forward(outc6)
+        [outp3, Houtp3, Woutp3] = max_pool2d::forward(outr6, F3, Houtc6, Woutc6, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0)
+        [outc7, Houtc7, Woutc7] = conv2d::forward(outp3, W7, b7, F3, Houtp3, Woutp3, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        outr7 = relu::forward(outc7)
+        [outc8, Houtc8, Woutc8] = conv2d::forward(outr7, W8, b8, F4, Houtc7, Woutc7, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        outr8 = relu::forward(outc8)
+        [outp4, Houtp4, Woutp4] = max_pool2d::forward(outr8, F4, Houtc8, Woutc8, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0)
+        [outd1, mask1]          = dropout::forward(outp4, dropProb, dropSeed)
+
+        # Bottom
+        [outc9, Houtc9, Woutc9] = conv2d::forward(outd1, W9, b9, F4, Houtp4, Woutp4, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        outr9 = relu::forward(outc9)
+        [outc10, Houtc10, Woutc10] = conv2d::forward(outr9, W10, b10, F5, Houtc9, Woutc9, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        outr10 = relu::forward(outc10)
+        [outc11, Houtc11, Woutc11] = conv2d_transpose::forward(outr10, W11, b11, F5, Houtc10, Woutc10, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad, 0, 0)
+
+        # Up-Convolution
+        outConcat1 = cbind(img_utils::crop_channel(outr8, Houtc8, Woutc8, Houtc11, Woutc11, F4), outc11)
+        [outc12, Houtc12, Woutc12] = conv2d::forward(outConcat1, W12, b12, F5, Houtc11, Woutc11, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        outr11 = relu::forward(outc12)
+        [outc13, Houtc13, Woutc13] = conv2d::forward(outr11, W13, b13, F4, Houtc12, Woutc12, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        outr12 = relu::forward(outc13)
+        [outc14, Houtc14, Woutc14] = conv2d_transpose::forward(outr12, W14, b14, F4, Houtc13, Woutc13, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad, 0, 0)
+        outConcat2 = cbind(img_utils::crop_channel(outr6, Houtc6, Woutc6, Houtc14, Woutc14, F3), outc14)
+        [outc15, Houtc15, Woutc15] = conv2d::forward(outConcat2, W15, b15, F4, Houtc14, Woutc14, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        outr13 = relu::forward(outc15)
+        [outc16, Houtc16, Woutc16] = conv2d::forward(outr13, W16, b16, F3, Houtc15, Woutc15, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        outr14 = relu::forward(outc16)
+        [outc17, Houtc17, Woutc17] = conv2d_transpose::forward(outr14, W17, b17, F3, Houtc16, Woutc16, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad, 0, 0)
+        outConcat3 = cbind(img_utils::crop_channel(outr4, Houtc4, Woutc4, Houtc17, Woutc17, F2), outc17)
+        [outc18, Houtc18, Woutc18] = conv2d::forward(outConcat3, W18, b18, F3, Houtc17, Woutc17, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        outr15 = relu::forward(outc18)
+        [outc19, Houtc19, Woutc19] = conv2d::forward(outr15, W19, b19, F2, Houtc18, Woutc18, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        outr16 = relu::forward(outc19)
+        [outc20, Houtc20, Woutc20] = conv2d_transpose::forward(outr16, W20, b20, F2, Houtc19, Woutc19, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad, 0, 0)
+        outConcat4 = cbind(img_utils::crop_channel(outr2, Houtc2, Woutc2, Houtc20, Woutc20, F1), outc20)
+        [outc21, Houtc21, Woutc21] = conv2d::forward(outConcat4, W21, b21, F2, Houtc20, Woutc20, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        outr17 = relu::forward(outc21)
+        [outc22, Houtc22, Woutc22] = conv2d::forward(outr17, W22, b22, F1, Houtc21, Woutc21, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        outr18 = relu::forward(outc22)
+
+        # This last conv2d needs to create the segmentation map (1x1 filter):
+        [outc23, Houtc23, Woutc23] = conv2d::forward(outr18, W23, b23, F1, Houtc22, Woutc22, 1, 1, conv_stride, conv_stride, pad, pad)
+        probs = softmax::forward(outc23)
+
+        # Compute loss & accuracy for training data
+        loss = cross_entropy_loss::forward(probs, labels)
+        accuracy = mean(rowIndexMax(probs) == rowIndexMax(labels))
+        # print("[+] Completed forward pass on batch: train loss: " + loss + 
", train accuracy: " + accuracy)
+
+        # Compute data backward pass
+
+        ## loss
+        dprobs = cross_entropy_loss::backward(probs, labels)
+        doutc23 = softmax::backward(dprobs, outc23)
+
+        # Up-Convolution
+        # conv2d::backward parameters: (previous gradient, output height, output width,
+        #   input to original layer, layer weight, layer bias, layer input channel number,
+        #   input height, input width, filter height, filter width,
+        #   stride height, stride width, pad height, pad width)
+        [doutc22, dW23, db23] = conv2d::backward(doutc23, Houtc23, Woutc23, outr18, W23, b23, F1, Houtc22, Woutc22, 1, 1, conv_stride, conv_stride, pad, pad)
+        doutr18 = relu::backward(doutc22, outc22)
+        [doutc21, dW22, db22] = conv2d::backward(doutr18, Houtc22, Woutc22, outr17, W22, b22, F1, Houtc21, Woutc21, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        doutr17 = relu::backward(doutc21, outc21)
+        [doutc20, dW21, db21] = conv2d::backward(doutr17, Houtc21, Woutc21, outConcat4, W21, b21, F2, Houtc20, Woutc20, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        doutc20_cropped = doutc20[,(F1*Houtc20*Woutc20+1):ncol(doutc20)] # Remove half of the gradients since they belong to the concatenated skip connection
+        [doutc19, dW20, db20] = conv2d_transpose::backward(doutc20_cropped, Houtc20, Woutc20, outr16, W20, b20, F2, Houtc19, Woutc19, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad)
+        doutr16 = relu::backward(doutc19, outc19)
+        [doutc18, dW19, db19] = conv2d::backward(doutr16, Houtc19, Woutc19, outr15, W19, b19, F2, Houtc18, Woutc18, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        doutr15 = relu::backward(doutc18, outc18)
+        [doutc17, dW18, db18] = conv2d::backward(doutr15, Houtc18, Woutc18, outConcat3, W18, b18, F3, Houtc17, Woutc17, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        doutc17_cropped = doutc17[,(F2*Houtc17*Woutc17+1):ncol(doutc17)]
+        [doutc16, dW17, db17] = conv2d_transpose::backward(doutc17_cropped, Houtc17, Woutc17, outr14, W17, b17, F3, Houtc16, Woutc16, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad)
+        doutr14 = relu::backward(doutc16, outc16)
+        [doutc15, dW16, db16] = conv2d::backward(doutr14, Houtc16, Woutc16, outr13, W16, b16, F3, Houtc15, Woutc15, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        doutr13 = relu::backward(doutc15, outc15)
+        [doutc14, dW15, db15] = conv2d::backward(doutr13, Houtc15, Woutc15, outConcat2, W15, b15, F4, Houtc14, Woutc14, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        doutc14_cropped = doutc14[,(F3*Houtc14*Woutc14+1):ncol(doutc14)]
+        [doutc13, dW14, db14] = conv2d_transpose::backward(doutc14_cropped, Houtc14, Woutc14, outr12, W14, b14, F4, Houtc13, Woutc13, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad)
+        doutr12 = relu::backward(doutc13, outc13)
+        [doutc12, dW13, db13] = conv2d::backward(doutr12, Houtc13, Woutc13, outr11, W13, b13, F4, Houtc12, Woutc12, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        doutr11 = relu::backward(doutc12, outc12)
+        [doutc11, dW12, db12] = conv2d::backward(doutr11, Houtc12, Woutc12, outConcat1, W12, b12, F5, Houtc11, Woutc11, Hf, Wf, conv_stride, conv_stride, pad, pad)
+
+        # Bottom
+        doutc11_cropped = doutc11[,(F4*Houtc11*Woutc11+1):ncol(doutc11)]
+        [doutc10, dW11, db11] = conv2d_transpose::backward(doutc11_cropped, Houtc11, Woutc11, outr10, W11, b11, F5, Houtc10, Woutc10, conv_t_HWf, conv_t_HWf, conv_t_stride, conv_t_stride, pad, pad)
+        doutr10 = relu::backward(doutc10, outc10)
+        [doutc9, dW10, db10] = conv2d::backward(doutr10, Houtc10, Woutc10, outr9, W10, b10, F5, Houtc9, Woutc9, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        doutr9 = relu::backward(doutc9, outc9)
+        [doutc8, dW9, db9] = conv2d::backward(doutr9, Houtc9, Woutc9, outd1, W9, b9, F4, Houtp4, Woutp4, Hf, Wf, conv_stride, conv_stride, pad, pad)
+
+        # Down-Convolution
+        doutd1 = dropout::backward(doutc8, outp4, dropProb, mask1)
+        doutp4 = max_pool2d::backward(doutd1, Houtp4, Woutp4, outr8, F4, Houtc8, Woutc8, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0)
+        doutr8 = relu::backward(doutp4, outc8)
+        [doutc7, dW8, db8] = conv2d::backward(doutr8, Houtc8, Woutc8, outr7, W8, b8, F4, Houtc7, Woutc7, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        doutr7 = relu::backward(doutc7, outc7)
+        [doutc6, dW7, db7] = conv2d::backward(doutr7, Houtc7, Woutc7, outp3, W7, b7, F3, Houtp3, Woutp3, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        doutp3 = max_pool2d::backward(doutc6, Houtp3, Woutp3, outr6, F3, Houtc6, Woutc6, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0)
+        doutr6 = relu::backward(doutp3, outc6)
+        [doutc5, dW6, db6] = conv2d::backward(doutr6, Houtc6, Woutc6, outr5, W6, b6, F3, Houtc5, Woutc5, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        doutr5 = relu::backward(doutc5, outc5)
+        [doutc4, dW5, db5] = conv2d::backward(doutr5, Houtc5, Woutc5, outp2, W5, b5, F2, Houtp2, Woutp2, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        doutp2 = max_pool2d::backward(doutc4, Houtp2, Woutp2, outr4, F2, Houtc4, Woutc4, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0)
+        doutr4 = relu::backward(doutp2, outc4)
+        [doutc3, dW4, db4] = conv2d::backward(doutr4, Houtc4, Woutc4, outr3, W4, b4, F2, Houtc3, Woutc3, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        doutr3 = relu::backward(doutc3, outc3)
+        [doutc2, dW3, db3] = conv2d::backward(doutr3, Houtc3, Woutc3, outp1, W3, b3, F1, Houtp1, Woutp1, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        doutp1 = max_pool2d::backward(doutc2, Houtp1, Woutp1, outr2, F1, Houtc2, Woutc2, pool_HWf, pool_HWf, pool_stride, pool_stride, 0, 0)
+        doutr2 = relu::backward(doutp1, outc2)
+        [doutc1, dW2, db2] = conv2d::backward(doutr2, Houtc2, Woutc2, outr1, W2, b2, F1, Houtc1, Woutc1, Hf, Wf, conv_stride, conv_stride, pad, pad)
+        doutr1 = relu::backward(doutc1, outc1)
+        [dx_batch, dW1, db1] = conv2d::backward(doutr1, Houtc1, Woutc1, features, W1, b1, C, Hin, Win, Hf, Wf, conv_stride, conv_stride, pad, pad)
+
+        # Compute regularization backward pass
+        dW1_reg = l2_reg::backward(W1, lambda)
+        dW2_reg = l2_reg::backward(W2, lambda)
+        dW3_reg = l2_reg::backward(W3, lambda)
+        dW4_reg = l2_reg::backward(W4, lambda)
+        dW5_reg = l2_reg::backward(W5, lambda)
+        dW6_reg = l2_reg::backward(W6, lambda)
+        dW7_reg = l2_reg::backward(W7, lambda)
+        dW8_reg = l2_reg::backward(W8, lambda)
+        dW9_reg = l2_reg::backward(W9, lambda)
+        dW10_reg = l2_reg::backward(W10, lambda)
+        dW11_reg = l2_reg::backward(W11, lambda)
+        dW12_reg = l2_reg::backward(W12, lambda)
+        dW13_reg = l2_reg::backward(W13, lambda)
+        dW14_reg = l2_reg::backward(W14, lambda)
+        dW15_reg = l2_reg::backward(W15, lambda)
+        dW16_reg = l2_reg::backward(W16, lambda)
+        dW17_reg = l2_reg::backward(W17, lambda)
+        dW18_reg = l2_reg::backward(W18, lambda)
+        dW19_reg = l2_reg::backward(W19, lambda)
+        dW20_reg = l2_reg::backward(W20, lambda)
+        dW21_reg = l2_reg::backward(W21, lambda)
+        dW22_reg = l2_reg::backward(W22, lambda)
+        dW23_reg = l2_reg::backward(W23, lambda)
+
+        dW1 = dW1 + dW1_reg
+        dW2 = dW2 + dW2_reg
+        dW3 = dW3 + dW3_reg
+        dW4 = dW4 + dW4_reg
+        dW5 = dW5 + dW5_reg
+        dW6 = dW6 + dW6_reg
+        dW7 = dW7 + dW7_reg
+        dW8 = dW8 + dW8_reg
+        dW9 = dW9 + dW9_reg
+        dW10 = dW10 + dW10_reg
+        dW11 = dW11 + dW11_reg
+        dW12 = dW12 + dW12_reg
+        dW13 = dW13 + dW13_reg
+        dW14 = dW14 + dW14_reg
+        dW15 = dW15 + dW15_reg
+        dW16 = dW16 + dW16_reg
+        dW17 = dW17 + dW17_reg
+        dW18 = dW18 + dW18_reg
+        dW19 = dW19 + dW19_reg
+        dW20 = dW20 + dW20_reg
+        dW21 = dW21 + dW21_reg
+        dW22 = dW22 + dW22_reg
+        dW23 = dW23 + dW23_reg
+
+        gradients = list(
+            dW1, dW2, dW3, dW4, dW5, dW6, dW7, dW8, dW9, dW10, dW11, dW12, dW13, dW14, dW15, dW16, dW17, dW18, dW19, dW20, dW21, dW22, dW23,
+            db1, db2, db3, db4, db5, db6, db7, db8, db9, db10, db11, db12, db13, db14, db15, db16, db17, db18, db19, db20, db21, db22, db23
+        )
+    }
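+
+# Note: paramserv calls this function internally, but for local debugging it
+# can also be invoked directly with the lists built in train_paramserv
+# (X_batch/y_batch are hypothetical mini-batch slices):
+#
+#   grads = gradients(model_list, params, X_batch, y_batch)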
+
+/*
+*  Updates the model weights based on gradients and hyperparameters (learning rate and mu).
+*
+*  Inputs:
+*  - model: List of model weights
+*  - hyperparams: List of hyper parameters containing (scalar[double]) learning_rate and (scalar[double]) mu
+*  - gradients: List of gradients
+*
+*  Outputs:
+*  - model_result: List of updated model weights
+*/
+aggregation = function(list[unknown] model,
+                       list[unknown] hyperparams,
+                       list[unknown] gradients)
+    return (list[unknown] model_result) {
+    W1 = as.matrix(model[1])
+    W2 = as.matrix(model[2])
+    W3 = as.matrix(model[3])
+    W4 = as.matrix(model[4])
+    W5 = as.matrix(model[5])
+    W6 = as.matrix(model[6])
+    W7 = as.matrix(model[7])
+    W8 = as.matrix(model[8])
+    W9 = as.matrix(model[9])
+    W10 = as.matrix(model[10])
+    W11 = as.matrix(model[11])
+    W12 = as.matrix(model[12])
+    W13 = as.matrix(model[13])
+    W14 = as.matrix(model[14])
+    W15 = as.matrix(model[15])
+    W16 = as.matrix(model[16])
+    W17 = as.matrix(model[17])
+    W18 = as.matrix(model[18])
+    W19 = as.matrix(model[19])
+    W20 = as.matrix(model[20])
+    W21 = as.matrix(model[21])
+    W22 = as.matrix(model[22])
+    W23 = as.matrix(model[23])
+
+    b1 = as.matrix(model[24])
+    b2 = as.matrix(model[25])
+    b3 = as.matrix(model[26])
+    b4 = as.matrix(model[27])
+    b5 = as.matrix(model[28])
+    b6 = as.matrix(model[29])
+    b7 = as.matrix(model[30])
+    b8 = as.matrix(model[31])
+    b9 = as.matrix(model[32])
+    b10 = as.matrix(model[33])
+    b11 = as.matrix(model[34])
+    b12 = as.matrix(model[35])
+    b13 = as.matrix(model[36])
+    b14 = as.matrix(model[37])
+    b15 = as.matrix(model[38])
+    b16 = as.matrix(model[39])
+    b17 = as.matrix(model[40])
+    b18 = as.matrix(model[41])
+    b19 = as.matrix(model[42])
+    b20 = as.matrix(model[43])
+    b21 = as.matrix(model[44])
+    b22 = as.matrix(model[45])
+    b23 = as.matrix(model[46])
+
+    dW1 = as.matrix(gradients[1])
+    dW2 = as.matrix(gradients[2])
+    dW3 = as.matrix(gradients[3])
+    dW4 = as.matrix(gradients[4])
+    dW5 = as.matrix(gradients[5])
+    dW6 = as.matrix(gradients[6])
+    dW7 = as.matrix(gradients[7])
+    dW8 = as.matrix(gradients[8])
+    dW9 = as.matrix(gradients[9])
+    dW10 = as.matrix(gradients[10])
+    dW11 = as.matrix(gradients[11])
+    dW12 = as.matrix(gradients[12])
+    dW13 = as.matrix(gradients[13])
+    dW14 = as.matrix(gradients[14])
+    dW15 = as.matrix(gradients[15])
+    dW16 = as.matrix(gradients[16])
+    dW17 = as.matrix(gradients[17])
+    dW18 = as.matrix(gradients[18])
+    dW19 = as.matrix(gradients[19])
+    dW20 = as.matrix(gradients[20])
+    dW21 = as.matrix(gradients[21])
+    dW22 = as.matrix(gradients[22])
+    dW23 = as.matrix(gradients[23])
+
+    db1 = as.matrix(gradients[24])
+    db2 = as.matrix(gradients[25])
+    db3 = as.matrix(gradients[26])
+    db4 = as.matrix(gradients[27])
+    db5 = as.matrix(gradients[28])
+    db6 = as.matrix(gradients[29])
+    db7 = as.matrix(gradients[30])
+    db8 = as.matrix(gradients[31])
+    db9 = as.matrix(gradients[32])
+    db10 = as.matrix(gradients[33])
+    db11 = as.matrix(gradients[34])
+    db12 = as.matrix(gradients[35])
+    db13 = as.matrix(gradients[36])
+    db14 = as.matrix(gradients[37])
+    db15 = as.matrix(gradients[38])
+    db16 = as.matrix(gradients[39])
+    db17 = as.matrix(gradients[40])
+    db18 = as.matrix(gradients[41])
+    db19 = as.matrix(gradients[42])
+    db20 = as.matrix(gradients[43])
+    db21 = as.matrix(gradients[44])
+    db22 = as.matrix(gradients[45])
+    db23 = as.matrix(gradients[46])
+
+    vW1 = as.matrix(model[47])
+    vW2 = as.matrix(model[48])
+    vW3 = as.matrix(model[49])
+    vW4 = as.matrix(model[50])
+    vW5 = as.matrix(model[51])
+    vW6 = as.matrix(model[52])
+    vW7 = as.matrix(model[53])
+    vW8 = as.matrix(model[54])
+    vW9 = as.matrix(model[55])
+    vW10 = as.matrix(model[56])
+    vW11 = as.matrix(model[57])
+    vW12 = as.matrix(model[58])
+    vW13 = as.matrix(model[59])
+    vW14 = as.matrix(model[60])
+    vW15 = as.matrix(model[61])
+    vW16 = as.matrix(model[62])
+    vW17 = as.matrix(model[63])
+    vW18 = as.matrix(model[64])
+    vW19 = as.matrix(model[65])
+    vW20 = as.matrix(model[66])
+    vW21 = as.matrix(model[67])
+    vW22 = as.matrix(model[68])
+    vW23 = as.matrix(model[69])
+
+    vb1 = as.matrix(model[70])
+    vb2 = as.matrix(model[71])
+    vb3 = as.matrix(model[72])
+    vb4 = as.matrix(model[73])
+    vb5 = as.matrix(model[74])
+    vb6 = as.matrix(model[75])
+    vb7 = as.matrix(model[76])
+    vb8 = as.matrix(model[77])
+    vb9 = as.matrix(model[78])
+    vb10 = as.matrix(model[79])
+    vb11 = as.matrix(model[80])
+    vb12 = as.matrix(model[81])
+    vb13 = as.matrix(model[82])
+    vb14 = as.matrix(model[83])
+    vb15 = as.matrix(model[84])
+    vb16 = as.matrix(model[85])
+    vb17 = as.matrix(model[86])
+    vb18 = as.matrix(model[87])
+    vb19 = as.matrix(model[88])
+    vb20 = as.matrix(model[89])
+    vb21 = as.matrix(model[90])
+    vb22 = as.matrix(model[91])
+    vb23 = as.matrix(model[92])
+
+    learning_rate = as.double(as.scalar(hyperparams["learning_rate"]))
+    mu = as.double(as.scalar(hyperparams["mu"]))
+
+    # Optimize with SGD with momentum
+    [W1, vW1] = sgd_momentum::update(W1, dW1, learning_rate, mu, vW1)
+    [b1, vb1] = sgd_momentum::update(b1, db1, learning_rate, mu, vb1)
+    [W2, vW2] = sgd_momentum::update(W2, dW2, learning_rate, mu, vW2)
+    [b2, vb2] = sgd_momentum::update(b2, db2, learning_rate, mu, vb2)
+    [W3, vW3] = sgd_momentum::update(W3, dW3, learning_rate, mu, vW3)
+    [b3, vb3] = sgd_momentum::update(b3, db3, learning_rate, mu, vb3)
+    [W4, vW4] = sgd_momentum::update(W4, dW4, learning_rate, mu, vW4)
+    [b4, vb4] = sgd_momentum::update(b4, db4, learning_rate, mu, vb4)
+    [W5, vW5] = sgd_momentum::update(W5, dW5, learning_rate, mu, vW5)
+    [b5, vb5] = sgd_momentum::update(b5, db5, learning_rate, mu, vb5)
+    [W6, vW6] = sgd_momentum::update(W6, dW6, learning_rate, mu, vW6)
+    [b6, vb6] = sgd_momentum::update(b6, db6, learning_rate, mu, vb6)
+    [W7, vW7] = sgd_momentum::update(W7, dW7, learning_rate, mu, vW7)
+    [b7, vb7] = sgd_momentum::update(b7, db7, learning_rate, mu, vb7)
+    [W8, vW8] = sgd_momentum::update(W8, dW8, learning_rate, mu, vW8)
+    [b8, vb8] = sgd_momentum::update(b8, db8, learning_rate, mu, vb8)
+    [W9, vW9] = sgd_momentum::update(W9, dW9, learning_rate, mu, vW9)
+    [b9, vb9] = sgd_momentum::update(b9, db9, learning_rate, mu, vb9)
+    [W10, vW10] = sgd_momentum::update(W10, dW10, learning_rate, mu, vW10)
+    [b10, vb10] = sgd_momentum::update(b10, db10, learning_rate, mu, vb10)
+    [W11, vW11] = sgd_momentum::update(W11, dW11, learning_rate, mu, vW11)
+    [b11, vb11] = sgd_momentum::update(b11, db11, learning_rate, mu, vb11)
+    [W12, vW12] = sgd_momentum::update(W12, dW12, learning_rate, mu, vW12)
+    [b12, vb12] = sgd_momentum::update(b12, db12, learning_rate, mu, vb12)
+    [W13, vW13] = sgd_momentum::update(W13, dW13, learning_rate, mu, vW13)
+    [b13, vb13] = sgd_momentum::update(b13, db13, learning_rate, mu, vb13)
+    [W14, vW14] = sgd_momentum::update(W14, dW14, learning_rate, mu, vW14)
+    [b14, vb14] = sgd_momentum::update(b14, db14, learning_rate, mu, vb14)
+    [W15, vW15] = sgd_momentum::update(W15, dW15, learning_rate, mu, vW15)
+    [b15, vb15] = sgd_momentum::update(b15, db15, learning_rate, mu, vb15)
+    [W16, vW16] = sgd_momentum::update(W16, dW16, learning_rate, mu, vW16)
+    [b16, vb16] = sgd_momentum::update(b16, db16, learning_rate, mu, vb16)
+    [W17, vW17] = sgd_momentum::update(W17, dW17, learning_rate, mu, vW17)
+    [b17, vb17] = sgd_momentum::update(b17, db17, learning_rate, mu, vb17)
+    [W18, vW18] = sgd_momentum::update(W18, dW18, learning_rate, mu, vW18)
+    [b18, vb18] = sgd_momentum::update(b18, db18, learning_rate, mu, vb18)
+    [W19, vW19] = sgd_momentum::update(W19, dW19, learning_rate, mu, vW19)
+    [b19, vb19] = sgd_momentum::update(b19, db19, learning_rate, mu, vb19)
+    [W20, vW20] = sgd_momentum::update(W20, dW20, learning_rate, mu, vW20)
+    [b20, vb20] = sgd_momentum::update(b20, db20, learning_rate, mu, vb20)
+    [W21, vW21] = sgd_momentum::update(W21, dW21, learning_rate, mu, vW21)
+    [b21, vb21] = sgd_momentum::update(b21, db21, learning_rate, mu, vb21)
+    [W22, vW22] = sgd_momentum::update(W22, dW22, learning_rate, mu, vW22)
+    [b22, vb22] = sgd_momentum::update(b22, db22, learning_rate, mu, vb22)
+    [W23, vW23] = sgd_momentum::update(W23, dW23, learning_rate, mu, vW23)
+    [b23, vb23] = sgd_momentum::update(b23, db23, learning_rate, mu, vb23)
+
+    model_result = list(
+            W1, W2, W3, W4, W5, W6, W7, W8, W9, W10, W11, W12, W13, W14, W15, W16, W17, W18, W19, W20, W21, W22, W23,
+            b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15, b16, b17, b18, b19, b20, b21, b22, b23,
+            vW1, vW2, vW3, vW4, vW5, vW6, vW7, vW8, vW9, vW10, vW11, vW12, vW13, vW14, vW15, vW16, vW17, vW18, vW19, vW20, vW21, vW22, vW23,
+            vb1, vb2, vb3, vb4, vb5, vb6, vb7, vb8, vb9, vb10, vb11, vb12, vb13, vb14, vb15, vb16, vb17, vb18, vb19, vb20, vb21, vb22, vb23
+        )
+}
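+
+# Sketch of one manual optimizer step outside of paramserv, reusing the lists
+# from train_paramserv and the gradients function above:
+#
+#   model_list = aggregation(model_list, params, grads)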
+
+/*
+ * Evaluates a U-Net architecture.
+ *
+ * The probs matrix contains the class probability predictions and y contains the target.
+ *
+ * Inputs:
+ *  - probs: Class probabilities, of shape (N, C*Hin*Win)
+ *  - y: Target matrix, of shape (N, C*Hin*Win)
+ *
+ * Outputs:
+ *  - loss: Scalar loss, of shape (1)
+ *  - accuracy: Scalar accuracy, of shape (1)
+ */
+eval = function(matrix[double] probs, matrix[double] y)
+    return (double loss, double accuracy) {
+
+  # Compute loss & accuracy
+  loss = cross_entropy_loss::forward(probs, y)
+  correct_pred = rowIndexMax(probs) == rowIndexMax(y)
+  accuracy = mean(correct_pred)
+}
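+
+# Example: evaluate the output of predict against the targets y:
+#
+#   [loss, accuracy] = eval(probs, y)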
+
+/*
+ * Computes the loss and accuracy for a model and given feature and label matrices.
+ *
+ * This function is a combination of the predict and eval functions, used for validation.
+ * For inputs see eval and predict.
+ *
+ * Inputs:
+ *  - val_features: Validation data features
+ *  - val_labels: Validation data labels
+ *  - model: List of weights of the trained model
+ *  - hyperparams: Hyperparameters including C, Hin, Win, and K
+ *  - F1: Number of filters of the top layer of the U-Net model. Default is 64.
+ *  - batch_size: Batch size of prediction
+ *
+ * Outputs:
+ *  - loss: Scalar loss, of shape (1).
+ *  - accuracy: Scalar accuracy, of shape (1).
+ */
+validate = function(matrix[double] val_features, matrix[double] val_labels,
+  list[unknown] model, list[unknown] hyperparams, int F1 = 64, int batch_size = 32)
+       return (double loss, double accuracy)
+{
+  C = as.integer(as.scalar(hyperparams["C"]))
+  Hin = as.integer(as.scalar(hyperparams["Hin"]))
+  Win = as.integer(as.scalar(hyperparams["Win"]))
+  K = as.integer(as.scalar(hyperparams["K"]))
+  predictions = predict(val_features, C, Hin, Win, batch_size, model, K, F1)
+  [loss, accuracy] = eval(predictions, val_labels)
+}
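+
+# Usage sketch: validate reads C, Hin, Win, and K from hyperparams, so unlike
+# the params list built in train_paramserv, the list passed here must also
+# contain K, e.g.:
+#
+#   hp = list(C=1, Hin=284, Win=284, K=100*100)
+#   [val_loss, val_acc] = validate(X_val, y_val, model_trained, hp)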
diff --git a/scripts/nn/layers/conv2d_transpose.dml b/scripts/nn/layers/conv2d_transpose.dml
index 6443b3ed46..c8731879d9 100644
--- a/scripts/nn/layers/conv2d_transpose.dml
+++ b/scripts/nn/layers/conv2d_transpose.dml
@@ -175,6 +175,35 @@ init = function(int F, int C, int Hf, int Wf)
   b = matrix(0, rows=F, cols=1)
 }
 
+init_seed = function(int F, int C, int Hf, int Wf, int seed = -1)
+    return (matrix[double] W, matrix[double] b) {
+  /*
+   * Initialize the parameters of this layer.
+   *
+   * Note: This is just a convenience function, and parameters
+   * may be initialized manually if needed.
+   *
+   * We use the heuristic by He et al., which limits the magnification
+   * of inputs/gradients during forward/backward passes by scaling
+   * unit-Gaussian weights by a factor of sqrt(2/n), under the
+   * assumption of relu neurons.
+   *  - http://arxiv.org/abs/1502.01852
+   *
+   * Inputs:
+   *  - F: Number of filters.
+   *  - C: Number of input channels (dimensionality of depth).
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - seed: Seed to initialize the weights; -1 generates a random seed.
+   *
+   * Outputs:
+   *  - W: Weights, of shape (C, F*Hf*Wf).
+   *  - b: Biases, of shape (F, 1).
+   */
+  W = rand(rows=C, cols=F*Hf*Wf, pdf="normal", seed=seed) * sqrt(2.0/(C*Hf*Wf))
+  b = matrix(0, rows=F, cols=1)
+}
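+
+# Example (sketch): reproducible initialization of a 2x2 up-convolution with
+# 128 input channels and 64 filters:
+#
+#   [W, b] = init_seed(64, 128, 2, 2, seed=42)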
+
 init_bilinear = function(int C, int K)
     return (matrix[double] W, matrix[double] b){
   /*
diff --git a/scripts/utils/image_utils.dml b/scripts/utils/image_utils.dml
index 76a0860e32..f494daac34 100644
--- a/scripts/utils/image_utils.dml
+++ b/scripts/utils/image_utils.dml
@@ -36,6 +36,26 @@ crop_rgb = function(matrix[double] input, int Hin, int Win, int Hout, int Wout)
        out = removeEmpty(target=(input+1), margin="cols", select=mask) - 1
 }
 
+/*
+ * Simple utility to crop image of shape [N, C * Hin * Win] into [N, C * Hout * Wout]
+ * Assumption: Hout < Hin, Wout < Win and input contains values [0, ..]
+ */
+crop_channel = function(matrix[double] input, int Hin, int Win, int Hout, int Wout, int C) return (matrix[double] out) {
+       start_h = ceil((Hin - Hout) / 2)
+       end_h = start_h + Hout - 1
+       start_w = ceil((Win - Wout) / 2)
+       end_w = start_w + Wout - 1
+       mask = matrix(0, rows=Hin, cols=Win)
+       temp_mask = matrix(1, rows=Hout, cols=Wout)
+       mask[start_h:end_h, start_w:end_w] = temp_mask
+       mask = matrix(mask, rows=1, cols=Hin*Win)
+       maskC = mask
+       for ( i in 1:(C-1) ){
+           maskC = cbind(maskC,mask)
+       }
+       out = removeEmpty(target=(input+1), margin="cols", select=maskC) - 1
+}
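+
+# Example (sketch): center-crop a batch of two-channel 10x10 images to 6x6,
+# i.e. shape [N, 2*10*10] -> [N, 2*6*6]:
+#
+#   out = crop_channel(input, 10, 10, 6, 6, 2)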
+
 /*
  * Simple utility to crop image of shape [N, Hin * Win] into [N, Hout * Wout]
  * Assumption: Hout < Hin, Wout < Win and input contains values [0, ..]
@@ -80,4 +100,4 @@ crop_grayscale = function(matrix[double] input, int Hin, int Win, int Hout, int
        mask[start_h:end_h, start_w:end_w] = temp_mask
        mask = matrix(mask, rows=1, cols=Hin*Win)
        out = removeEmpty(target=(input+1), margin="cols", select=mask) - 1
-}
\ No newline at end of file
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
index 8037611622..40ca96734a 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
@@ -64,7 +64,7 @@ public class ShuffleFederatedScheme extends DataPartitionFederatedScheme {
                        try {
                                FederatedResponse response = udfResponse.get();
                                if(!response.isSuccessful())
-                                       throw new DMLRuntimeException("FederatedDataPartitioner ShuffleFederatedScheme: shuffle UDF returned fail");
+                                       throw new DMLRuntimeException("FederatedDataPartitioner ShuffleFederatedScheme: shuffle UDF returned fail. Federated worker error message: " + response.getErrorMessage());
                        }
                        catch(Exception e) {
                                throw new DMLRuntimeException("FederatedDataPartitioner ShuffleFederatedScheme: executing shuffle UDF failed" + e.getMessage());
@@ -103,4 +103,4 @@ public class ShuffleFederatedScheme extends DataPartitionFederatedScheme {
                        return null;
                }
        }
-}
\ No newline at end of file
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index b4c3c64553..51db6ed751 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -144,12 +144,14 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
               DataPartitionFederatedScheme.Result result = new FederatedDataPartitioner(federatedPSScheme, seed)
                       .doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)), ec.getMatrixObject(getParam(PS_LABELS)));
               int workerNum = result._workerNum;
-               if (DMLScript.STATISTICS)
-                       ParamServStatistics.accFedDataPartitioningTime((long) tDataPartitioning.stop());
 
+               if (DMLScript.STATISTICS) {
+                       if (tDataPartitioning != null)
+                               ParamServStatistics.accFedDataPartitioningTime((long) tDataPartitioning.stop());
+                       if (tSetup != null)
+                               tSetup.start();
+               }
 
-               if (DMLScript.STATISTICS)
-                       tSetup.start();
                // setup threading
                BasicThreadFactory factory = new BasicThreadFactory.Builder()
                        .namingPattern("workers-pool-thread-%d").build();
@@ -168,20 +170,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 		MatrixObject val_labels = (getParam(PS_VAL_LABELS) != null) ? ec.getMatrixObject(getParam(PS_VAL_LABELS)) : null;
                boolean modelAvg = Boolean.parseBoolean(getParam(PS_MODELAVG));
 
-		// check if we need homomorphic encryption
-		boolean use_homomorphic_encryption_ = getHe();
-		for (int i = 0; i < workerNum; i++) {
-			use_homomorphic_encryption_ = use_homomorphic_encryption_ || checkIsPrivate(result._pFeatures.get(i));
-			use_homomorphic_encryption_ = use_homomorphic_encryption_ || checkIsPrivate(result._pLabels.get(i));
-		}
-		final boolean use_homomorphic_encryption = use_homomorphic_encryption_;
-		if (use_homomorphic_encryption && !modelAvg) {
-			throw new DMLRuntimeException("can't use homomorphic encryption without modelAvg");
-		}
-
-		if (use_homomorphic_encryption && weighting) {
-			throw new DMLRuntimeException("can't use homomorphic encryption with weighting");
-		}
+		final boolean use_homomorphic_encryption = useHomomorphicEncryption(result, workerNum, modelAvg, weighting);
 
 		LocalParamServer ps = (LocalParamServer) createPS(PSModeType.FEDERATED, aggFunc, updateType, freq, workerNum,
 			model, aggServiceEC, getValFunction(), getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics),
@@ -236,6 +225,32 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
                }
        }
 
+	/**
+	 * Check if homomorphic encryption is needed.
+	 * @param result data partitioning result
+	 * @param workerNum number of workers
+	 * @param modelAvg whether model averaging is enabled
+	 * @param weighting whether weighting is enabled
+	 * @return true if homomorphic encryption is needed
+	 */
+	private boolean useHomomorphicEncryption(DataPartitionFederatedScheme.Result result,
+		int workerNum, boolean modelAvg, boolean weighting) {
+		boolean use_homomorphic_encryption = getHe();
+		for (int i = 0; i < workerNum; i++) {
+			use_homomorphic_encryption = use_homomorphic_encryption || checkIsPrivate(result._pFeatures.get(i));
+			use_homomorphic_encryption = use_homomorphic_encryption || checkIsPrivate(result._pLabels.get(i));
+		}
+		if (use_homomorphic_encryption) {
+			if (!modelAvg)
+				throw new DMLRuntimeException("can't use homomorphic encryption without modelAvg");
+			if (weighting)
+				throw new DMLRuntimeException("can't use homomorphic encryption with weighting");
+			LOG.info("Homomorphic encryption activated for federated parameter server");
+		}
+
+		return use_homomorphic_encryption;
+	}
+
        private void runOnSpark(SparkExecutionContext sec, PSModeType mode) {
                Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
 
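
Extracting useHomomorphicEncryption() also centralizes the compatibility rules: HE is enabled either explicitly or by any private partition, and it then requires model averaging and forbids weighting. A simplified, self-contained restatement of that decision logic (hypothetical names, not the SystemDS classes):

    import java.util.List;

    public final class HeConfigSketch {
        // HE is on if requested or if any partition is private; it then
        // requires modelAvg == true and weighting == false.
        static boolean resolveHe(boolean requested, List<Boolean> partitionIsPrivate,
                boolean modelAvg, boolean weighting) {
            boolean useHe = requested || partitionIsPrivate.contains(Boolean.TRUE);
            if (useHe && !modelAvg)
                throw new IllegalStateException("can't use homomorphic encryption without modelAvg");
            if (useHe && weighting)
                throw new IllegalStateException("can't use homomorphic encryption with weighting");
            return useHe;
        }

        public static void main(String[] args) {
            System.out.println(resolveHe(false, List.of(true, false), true, false)); // prints true
        }
    }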
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java
index 437869209c..4ef5d0b45a 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/EncryptedFederatedParamservTest.java
@@ -23,8 +23,10 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.List;
+import java.util.Objects;
 
 import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.paramserv.NativeHEHelper;
 import org.apache.sysds.runtime.privacy.PrivacyConstraint;
 import org.apache.sysds.test.AutomatedTestBase;
@@ -36,6 +38,8 @@ import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
+import static org.junit.Assert.fail;
+
 @RunWith(value = Parameterized.class)
 @net.jcip.annotations.NotThreadSafe
 public class EncryptedFederatedParamservTest extends AutomatedTestBase {
@@ -67,6 +71,8 @@ public class EncryptedFederatedParamservTest extends AutomatedTestBase {
 				//{"TwoNN",     4, 60000, 32, 4, 0.01,  "BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "NONE" ,        "false","BALANCED",     200},
 
 				// One important point is that we do the model averaging in the case of BSP
+				{"UNet",        2, 4, 1, 1, 0.01,       "BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "BASELINE",     "false",        "BALANCED",     200},
+				//{"UNet",      2, 4, 1, 1, 0.01,       "BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "BASELINE",     "false",        "IMBALANCED",   200},
 				{"TwoNN",       2, 4, 1, 1, 0.01,       "BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "BASELINE",     "false",        "IMBALANCED",   200},
 				{"CNN",         2, 4, 1, 1, 0.01,       "BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "BASELINE",     "false",        "IMBALANCED",   200},
 				//{"TwoNN",     5, 1000, 100, 1, 0.01,  "BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "NONE",         "true", "BALANCED",     200},
@@ -99,11 +105,7 @@ public class EncryptedFederatedParamservTest extends AutomatedTestBase {
 		int dataSetSize, int batch_size, int epochs, double eta, String utype, String freq,
 		String scheme, String runtime_balancing, String weighting, String data_distribution, int seed)
        {
-               try {
-                       NativeHEHelper.initialize();
-               } catch (Exception e) {
-                       throw e;
-               }
+               NativeHEHelper.initialize();
                _networkType = networkType;
                _numFederatedWorkers = numFederatedWorkers;
                _dataSetSize = dataSetSize;
@@ -144,22 +146,30 @@ public class EncryptedFederatedParamservTest extends AutomatedTestBase {
 
                int C = 1, Hin = 28, Win = 28;
                int numLabels = 10;
+               if (Objects.equals(_networkType, "UNet")){
+                       C = 3; Hin = 340; Win = 340;
+                       numLabels = C * Hin * Win;
+               }
 
                ExecMode platformOld = setExecMode(mode);
-
+               // start threads
+               List<Integer> ports = new ArrayList<>();
+               List<Thread> threads = new ArrayList<>();
                try {
-                       // start threads
-                       List<Integer> ports = new ArrayList<>();
-                       List<Thread> threads = new ArrayList<>();
                        for(int i = 0; i < _numFederatedWorkers; i++) {
-				ports.add(getRandomAvailablePort());
-				threads.add(startLocalFedWorkerThread(ports.get(i),
+				int port = getRandomAvailablePort();
+				threads.add(startLocalFedWorkerThread(port,
 						i==(_numFederatedWorkers-1) ? FED_WORKER_WAIT : FED_WORKER_WAIT_S));
+				ports.add(port);
+				System.out.println("Worker with port " + port + " started!");
+
+				if (threads.get(i).isInterrupted() || !threads.get(i).isAlive())
+					throw new DMLRuntimeException("Federated worker thread dead or interrupted! Port " + port);
                        }
 
                        // generate test data
-			double[][] features = generateDummyMNISTFeatures(_dataSetSize, C, Hin, Win);
-			double[][] labels = generateDummyMNISTLabels(_dataSetSize, numLabels);
+			double[][] features = ParamServTestUtils.generateFeatures(_networkType, _dataSetSize, C, Hin, Win);
+			double[][] labels = ParamServTestUtils.generateLabels(_networkType, _dataSetSize, numLabels, C*Hin*Win, features);
                        String featuresName = "";
                        String labelsName = "";
 
@@ -181,13 +191,8 @@ public class EncryptedFederatedParamservTest extends AutomatedTestBase {
 				rowFederateLocallyAndWriteInputMatrixWithMTD(labelsName, labels, _numFederatedWorkers, ports, ranges, privacyConstraint);
                        }
 
-                       try {
-                               //wait for all workers to be setup
-                               Thread.sleep(FED_WORKER_WAIT);
-                       }
-                       catch(InterruptedException e) {
-                               e.printStackTrace();
-                       }
+                       //wait for all workers to be setup
+                       Thread.sleep(FED_WORKER_WAIT);
 
                        // dml name
                        fullDMLScriptName = HOME + TEST_NAME + ".dml";
@@ -213,45 +218,23 @@ public class EncryptedFederatedParamservTest extends 
AutomatedTestBase {
 
                        programArgs = programArgsList.toArray(new String[0]);
                        String log = runTest(null).toString();
+                       System.out.println(log);
+			if (!heavyHittersContainsAllString("paramserv"))
+				fail("The following expected heavy hitters are missing: "
+					+ Arrays.toString(missingHeavyHitters("paramserv")));
 			Assert.assertEquals("Test Failed \n" + log, 0, Statistics.getNoOfExecutedSPInst());
-
-                       // shut down threads
-                       for(int i = 0; i < _numFederatedWorkers; i++) {
-                               TestUtils.shutdownThreads(threads.get(i));
-                       }
+               }
+               catch(InterruptedException e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
                }
                finally {
+                       // shut down threads
+                       for ( Thread thread : threads ){
+                               TestUtils.shutdownThreads(thread);
+                       }
+
                        resetExecMode(platformOld);
                }
        }
-
-	/**
-	 * Generates an feature matrix that has the same format as the MNIST dataset,
-	 * but is completely random and normalized
-	 *
-	 *  @param numExamples Number of examples to generate
-	 *  @param C Channels in the input data
-	 *  @param Hin Height in Pixels of the input data
-	 *  @param Win Width in Pixels of the input data
-	 *  @return a dummy MNIST feature matrix
-	 */
-	private double[][] generateDummyMNISTFeatures(int numExamples, int C, int Hin, int Win) {
-		// Seed -1 takes the time in milliseconds as a seed
-		// Sparsity 1 means no sparsity
-		return getRandomMatrix(numExamples, C*Hin*Win, 0, 1, 1, -1);
-	}
-
-	/**
-	 * Generates an label matrix that has the same format as the MNIST dataset, but is completely random and consists
-	 * of one hot encoded vectors as rows
-	 *
-	 *  @param numExamples Number of examples to generate
-	 *  @param numLabels Number of labels to generate
-	 *  @return a dummy MNIST lable matrix
-	 */
-	private double[][] generateDummyMNISTLabels(int numExamples, int numLabels) {
-		// Seed -1 takes the time in milliseconds as a seed
-		// Sparsity 1 means no sparsity
-		return getRandomMatrix(numExamples, numLabels, 0, 1, 1, -1);
-	}
 }
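
Both test classes now use the same fail-fast startup: start each worker thread, record its port, and abort immediately if the thread is already dead or interrupted rather than waiting for a later timeout. As a standalone sketch (placeholder worker body, not the actual test harness):

    import java.util.ArrayList;
    import java.util.List;

    public final class WorkerStartupSketch {
        public static void main(String[] args) throws InterruptedException {
            List<Thread> workers = new ArrayList<>();
            for (int i = 0; i < 2; i++) {
                // Placeholder for startLocalFedWorkerThread(port, ...)
                Thread t = new Thread(() -> {
                    try { Thread.sleep(5000); } catch (InterruptedException ignored) { }
                });
                t.start();
                workers.add(t);
                if (t.isInterrupted() || !t.isAlive())
                    throw new IllegalStateException("worker thread dead or interrupted");
            }
            // Shutdown, analogous to TestUtils.shutdownThreads in the finally blocks.
            for (Thread t : workers) { t.interrupt(); t.join(); }
        }
    }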
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
index 81463b4c54..894c7ca548 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
@@ -25,6 +25,7 @@ import java.util.Collection;
 import java.util.List;
 
 import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
@@ -34,6 +35,8 @@ import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
+import static org.junit.Assert.fail;
+
 @RunWith(value = Parameterized.class)
 @net.jcip.annotations.NotThreadSafe
 public class FederatedParamservTest extends AutomatedTestBase {
@@ -63,6 +66,7 @@ public class FederatedParamservTest extends AutomatedTestBase {
 			// Network type, number of federated workers, data set size, batch size, epochs, learning rate, update type, update frequency
 			// basic functionality
 
+			{"UNet",        2, 4, 1, 1, 0.01,       "BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "BASELINE",     "false","BALANCED",     200},
 			{"TwoNN",       2, 4, 1, 4, 0.01,       "BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "BASELINE",     "true", "IMBALANCED",   200},
 			{"CNN",         2, 4, 1, 4, 0.01,       "BSP", "EPOCH", "SHUFFLE",              "NONE",         "true", "IMBALANCED",   200},
 			{"CNN",         2, 4, 1, 4, 0.01,       "ASP", "BATCH", "REPLICATE_TO_MAX",     "CYCLE_MIN",    "true", "IMBALANCED",   200},
@@ -136,21 +140,29 @@ public class FederatedParamservTest extends AutomatedTestBase {
 
                int C = 1, Hin = 28, Win = 28;
                int numLabels = 10;
+               if (_networkType.equals("UNet")){
+                       C = 3; Hin = 340; Win = 340;
+                       numLabels = C * Hin * Win;
+               }
 
                ExecMode platformOld = setExecMode(mode);
-
+               List<Integer> ports = new ArrayList<>();
+               List<Thread> threads = new ArrayList<>();
                try {
                        // start threads
-                       List<Integer> ports = new ArrayList<>();
-                       List<Thread> threads = new ArrayList<>();
                        for(int i = 0; i < _numFederatedWorkers; i++) {
-				ports.add(getRandomAvailablePort());
-				threads.add(startLocalFedWorkerThread(ports.get(i), FED_WORKER_WAIT_S));
+				int port = getRandomAvailablePort();
+				threads.add(startLocalFedWorkerThread(port, FED_WORKER_WAIT_S));
+				ports.add(port);
+				System.out.println("Worker with port " + port + " started!");
+
+				if (threads.get(i).isInterrupted() || !threads.get(i).isAlive())
+					throw new DMLRuntimeException("Federated worker thread dead or interrupted! Port " + port);
                        }
 
                        // generate test data
-			double[][] features = generateDummyMNISTFeatures(_dataSetSize, C, Hin, Win);
-			double[][] labels = generateDummyMNISTLabels(_dataSetSize, numLabels);
+			double[][] features = ParamServTestUtils.generateFeatures(_networkType, _dataSetSize, C, Hin, Win);
+			double[][] labels = ParamServTestUtils.generateLabels(_networkType, _dataSetSize, numLabels, C*Hin*Win, features);
                        String featuresName = "";
                        String labelsName = "";
 
@@ -170,13 +182,10 @@ public class FederatedParamservTest extends AutomatedTestBase {
 				rowFederateLocallyAndWriteInputMatrixWithMTD(labelsName, labels, _numFederatedWorkers, ports, ranges);
                        }
 
-                       try {
-                               //wait for all workers to be setup
-                               Thread.sleep(FED_WORKER_WAIT);
-                       }
-                       catch(InterruptedException e) {
-                               e.printStackTrace();
-                       }
+                       //wait for all workers to be setup
+                       Thread.sleep(FED_WORKER_WAIT);
+			if (threads.stream().anyMatch(t -> !t.isAlive()))
+				throw new DMLRuntimeException("Federated worker thread interrupted!");
 
                        // dml name
                        fullDMLScriptName = HOME + TEST_NAME + ".dml";
@@ -202,44 +211,17 @@ public class FederatedParamservTest extends AutomatedTestBase {
 			programArgs = programArgsList.toArray(new String[0]);
 			String log = runTest(null).toString();
 			Assert.assertEquals("Test Failed \n" + log, 0, Statistics.getNoOfExecutedSPInst());
-                       
-                       // shut down threads
-                       for(int i = 0; i < _numFederatedWorkers; i++) {
-                               TestUtils.shutdownThreads(threads.get(i));
-                       }
+               }
+               catch(InterruptedException e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
                }
                finally {
+                       // shut down threads
+                       for ( Thread thread : threads ){
+                               TestUtils.shutdownThreads(thread);
+                       }
                        resetExecMode(platformOld);
                }
        }
-
-	/**
-	 * Generates an feature matrix that has the same format as the MNIST dataset,
-	 * but is completely random and normalized
-	 *
-	 *  @param numExamples Number of examples to generate
-	 *  @param C Channels in the input data
-	 *  @param Hin Height in Pixels of the input data
-	 *  @param Win Width in Pixels of the input data
-	 *  @return a dummy MNIST feature matrix
-	 */
-	private double[][] generateDummyMNISTFeatures(int numExamples, int C, int Hin, int Win) {
-		// Seed -1 takes the time in milliseconds as a seed
-		// Sparsity 1 means no sparsity
-		return getRandomMatrix(numExamples, C*Hin*Win, 0, 1, 1, -1);
-	}
-
-	/**
-	 * Generates an label matrix that has the same format as the MNIST dataset, but is completely random and consists
-	 * of one hot encoded vectors as rows
-	 *
-	 *  @param numExamples Number of examples to generate
-	 *  @param numLabels Number of labels to generate
-	 *  @return a dummy MNIST lable matrix
-	 */
-	private double[][] generateDummyMNISTLabels(int numExamples, int numLabels) {
-		// Seed -1 takes the time in milliseconds as a seed
-		// Sparsity 1 means no sparsity
-		return getRandomMatrix(numExamples, numLabels, 0, 1, 1, -1);
-	}
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/ParamServTestUtils.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/ParamServTestUtils.java
new file mode 100644
index 0000000000..075eabcebc
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/ParamServTestUtils.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.paramserv;
+
+import org.apache.sysds.test.TestUtils;
+
+import java.util.Objects;
+
+/**
+ * Utility class with helper methods for generating
+ * features and labels for parameter server tests.
+ */
+public class ParamServTestUtils {
+       /**
+        * Generate features
+        * @param networkType network type
+        * @param numExamples number of input examples
+        * @param C number of channels
+        * @param Hin input height
+        * @param Win input width
+        * @return features
+        */
+	public static double[][] generateFeatures(String networkType, int numExamples, int C, int Hin, int Win) {
+		if (Objects.equals(networkType, "UNet"))
+			return generateDummyMedicalImageFeatures(numExamples, C, Hin, Win);
+		else
+			return generateDummyMNISTFeatures(numExamples, C, Hin, Win);
+	}
+
+	/**
+	 * Generates a feature matrix that has the same format as the MNIST dataset,
+	 * but is completely random and normalized
+	 *
+	 *  @param numExamples Number of examples to generate
+	 *  @param C Channels in the input data
+	 *  @param Hin Height in Pixels of the input data
+	 *  @param Win Width in Pixels of the input data
+	 *  @return a dummy MNIST feature matrix
+	 */
+	private static double[][] generateDummyMNISTFeatures(int numExamples, int C, int Hin, int Win) {
+		// Seed -1 takes the time in milliseconds as a seed
+		// Sparsity 1 means no sparsity
+		return TestUtils.generateTestMatrix(numExamples, C*Hin*Win, 0, 1, 1, -1);
+	}
+
+	/**
+	 * Generate dummy medical image features for training UNet.
+	 * Input height and width are padded so that the output
+	 * dimensions of UNet match the label dimensions.
+	 * @param numExamples number of input examples
+	 * @param C number of channels
+	 * @param Hin input height
+	 * @param Win input width
+	 * @return features
+	 */
+	private static double[][] generateDummyMedicalImageFeatures(int numExamples, int C, int Hin, int Win) {
+		// Pad height and width
+		Hin = Hin + 184;
+		Win = Win + 184;
+		return TestUtils.generateTestMatrix(numExamples, C*Hin*Win, -1024, 4096, 1, -1);
+	}
+
+	/**
+	 * Generate labels
+	 * @param networkType type of network
+	 * @param numExamples number of examples to generate labels for
+	 * @param numLabels number of labels to generate (except for UNet)
+	 * @param numFeatures number of features without padding (only used for UNet)
+	 * @param features features for which labels are generated (only used for UNet)
+	 * @return labels
+	 */
+	public static double[][] generateLabels(String networkType, int numExamples, int numLabels, int numFeatures, double[][] features) {
+		if (Objects.equals(networkType, "UNet"))
+			return generateDummyMedicalImageLabels(features, numFeatures);
+		else
+			return generateDummyMNISTLabels(numExamples, numLabels);
+	}
+
+	/**
+	 * Generates a label matrix that has the same format as the MNIST dataset, but is completely random and consists
+	 * of one-hot encoded vectors as rows
+	 *
+	 *  @param numExamples Number of examples to generate
+	 *  @param numLabels Number of labels to generate
+	 *  @return a dummy MNIST label matrix
+	 */
+	private static double[][] generateDummyMNISTLabels(int numExamples, int numLabels) {
+		// Seed -1 takes the time in milliseconds as a seed
+		// Sparsity 1 means no sparsity
+		return TestUtils.generateTestMatrix(numExamples, numLabels, 0, 1, 1, -1);
+	}
+
+       /**
+        * Return labels as 0 or 1 based on the values in features.
+        * @param features for which labels are generated
+        * @param numFeatures number of features without padding
+        * @return labels
+        */
+	private static double[][] generateDummyMedicalImageLabels(double[][] features, int numFeatures) {
+               double split = 1000;
+               double[][] labels = new double[features.length][numFeatures];
+               for ( int i = 0; i < labels.length; i++ ){
+                       for ( int j = 0; j < labels[0].length; j++ ){
+                               labels[i][j] = (features[i][j] > split) ? 1 : 0;
+                       }
+               }
+               return labels;
+       }
+}
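
The UNet-specific dimension arithmetic in this utility deserves spelling out: labels cover the unpadded C*Hin*Win pixels, while features are generated at (Hin+184) x (Win+184) because U-Net's valid convolutions remove 184 pixels per spatial dimension, bringing the output back to the label size. A small usage sketch with the constants from the tests (hypothetical class, assumes it lives in the same test package):

    public final class UNetDataSketch {
        public static void main(String[] args) {
            int n = 4, C = 3, Hin = 340, Win = 340;
            // Features are generated padded; labels threshold the unpadded positions.
            double[][] features = ParamServTestUtils.generateFeatures("UNet", n, C, Hin, Win);
            double[][] labels = ParamServTestUtils.generateLabels("UNet", n, C * Hin * Win, C * Hin * Win, features);
            System.out.println(features[0].length); // 3 * 524 * 524 = 823728 (padded input)
            System.out.println(labels[0].length);   // 3 * 340 * 340 = 346800 (unpadded output)
        }
    }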
diff --git a/src/test/scripts/functions/federated/paramserv/EncryptedFederatedParamservTest.dml b/src/test/scripts/functions/federated/paramserv/EncryptedFederatedParamservTest.dml
index b8021867dc..7b00f52eff 100644
--- a/src/test/scripts/functions/federated/paramserv/EncryptedFederatedParamservTest.dml
+++ b/src/test/scripts/functions/federated/paramserv/EncryptedFederatedParamservTest.dml
@@ -23,6 +23,7 @@ source("src/test/scripts/functions/federated/paramserv/TwoNN.dml") as TwoNN
 source("src/test/scripts/functions/federated/paramserv/TwoNNModelAvg.dml") as TwoNNModelAvg
 source("src/test/scripts/functions/federated/paramserv/CNN.dml") as CNN
 source("src/test/scripts/functions/federated/paramserv/CNNModelAvg.dml") as CNNModelAvg
+source("scripts/nn/examples/u-net.dml") as UNet
 
 
 # create federated input matrices
@@ -59,3 +60,21 @@ else if($network_type == "CNN") {
     print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test 
+ "\n")
   }
 }
+else if($network_type == "UNet") {
+    numFeatures = $channels*$hin*$win
+    numRows = nrow(features)
+    F1 = 2
+
+    x_hw = $hin + 184 # Padded input height and width
+    x_val = matrix(0, rows=numRows, cols=$channels*x_hw*x_hw)
+    y_val = matrix(0, rows=numRows, cols=numFeatures)
+
+    model = UNet::train_paramserv(features, labels, x_val, y_val, $channels, x_hw, x_hw, $epochs, 0, $utype, $freq, $batch_size, $scheme, $eta, $seed, TRUE, F1)
+    print("Test results:")
+    hyperparams = list(learning_rate=$eta, C=$channels, Hin=x_hw, Win=x_hw, K=numFeatures)
+    [loss_test, accuracy_test] = UNet::validate(matrix(0, rows=numRows, cols=$channels*x_hw*x_hw), matrix(0, rows=numRows, cols=numFeatures), model, hyperparams, F1, $batch_size)
+    print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test + "\n")
+}
+else {
+    print("Network type not recognized: " + $network_type)
+}
diff --git a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
index 7efd58842a..009d5d3dec 100644
--- a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
+++ b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
@@ -21,6 +21,7 @@
 
 source("src/test/scripts/functions/federated/paramserv/TwoNN.dml") as TwoNN
 source("src/test/scripts/functions/federated/paramserv/CNN.dml") as CNN
+source("scripts/nn/examples/u-net.dml") as UNet
 
 # create federated input matrices
 features = read($features)
@@ -32,10 +33,28 @@ if($network_type == "TwoNN") {
   [loss_test, accuracy_test] = TwoNN::validate(matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), model, list())
   print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test + "\n")
 }
-else {
+else if($network_type == "CNN"){
   model = CNN::train_paramserv(features, labels, matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighting, $eta, $channels, $hin, $win, $seed)
   print("Test results:")
   hyperparams = list(learning_rate=$eta, C=$channels, Hin=$hin, Win=$win)
   [loss_test, accuracy_test] = CNN::validate(matrix(0, rows=100, cols=784), matrix(0, rows=100, cols=10), model, hyperparams)
   print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test + "\n")
 }
+else if($network_type == "UNet"){
+  numFeatures = $channels*$hin*$win
+  numRows = nrow(features)
+  F1 = 2
+
+  x_hw = $hin + 184 # Padded input height and width
+  x_val = matrix(0, rows=numRows, cols=$channels*x_hw*x_hw)
+  y_val = matrix(0, rows=numRows, cols=numFeatures)
+
+  model = UNet::train_paramserv(features, labels, x_val, y_val, $channels, x_hw, x_hw, $epochs, 2, $utype, $freq, $batch_size, $scheme, $eta, $seed, FALSE, F1)
+  print("Test results:")
+  hyperparams = list(learning_rate=$eta, C=$channels, Hin=x_hw, Win=x_hw, K=numFeatures)
+  [loss_test, accuracy_test] = UNet::validate(matrix(0, rows=numRows, cols=$channels*x_hw*x_hw), matrix(0, rows=numRows, cols=numFeatures), model, hyperparams, F1, numRows)
+  print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test + "\n")
+}
+else {
+    print("Network type not recognized: " + $network_type)
+}
