This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 0f42fd3  [SYSTEMDS-3218] Fix BIN release parameter server local 
execution
0f42fd3 is described below

commit 0f42fd38389b9b100668261cc342acdf9278652b
Author: baunsgaard <[email protected]>
AuthorDate: Mon Nov 15 17:59:10 2021 +0100

    [SYSTEMDS-3218] Fix BIN release parameter server local execution
    
    This commit fixes the parameter server execution so that it does not
    crash when used with our BIN release artifact.
    The reason for the crash was the import of the scala tuple in the
    LocalCP worker; the import was the result of having spark methods in
    this class instead of isolating them in other files.
---
 .../controlprogram/paramserv/LocalPSWorker.java    |  12 ++-
 .../controlprogram/paramserv/ParamservUtils.java   | 104 ++----------------
 .../paramserv/SparkParamservUtils.java             | 118 +++++++++++++++++++++
 .../cp/ParamservBuiltinCPInstruction.java          |  39 +++----
 .../paramserv/SparkDataPartitionerTest.java        |   7 +-
 5 files changed, 160 insertions(+), 120 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
index f1848ad..93207f3 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -31,8 +31,10 @@ import org.apache.sysds.parser.Statement;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysds.runtime.instructions.cp.ListObject;
+import org.apache.sysds.runtime.util.CommonThreadPool;
 import org.apache.sysds.utils.Statistics;
 
 public class LocalPSWorker extends PSWorker implements Callable<Void> {
@@ -88,6 +90,8 @@ public class LocalPSWorker extends PSWorker implements 
Callable<Void> {
                        // Pull the global parameters from ps
                        ListObject params = pullModel();
                        Future<ListObject> accGradients = 
ConcurrentUtils.constantFuture(null);
+                       if(_tpool == null)
+                               _tpool = 
CommonThreadPool.get(InfrastructureAnalyzer.getLocalParallelism());
 
                        try {
                                for (int j = 0; j < batchIter; j++) {
@@ -98,9 +102,13 @@ public class LocalPSWorker extends PSWorker implements 
Callable<Void> {
                                        // Accumulate the intermediate 
gradients (async for overlap w/ model updates
                                        // and gradient computation, sequential 
over gradient matrices to avoid deadlocks)
                                        ListObject accGradientsPrev = 
accGradients.get();
-                                       accGradients = _modelAvg ? 
ConcurrentUtils.constantFuture(null) :
-                                               _tpool.submit(() -> 
ParamservUtils.accrueGradients(
+                                       if(_modelAvg){
+                                               accGradients = 
ConcurrentUtils.constantFuture(null);
+                                       }
+                                       else{
+                                               accGradients = _tpool.submit(() 
-> ParamservUtils.accrueGradients(
                                                        accGradientsPrev, 
gradients, false, !localUpdate));
+                                       }
                                        
                                        // Update the local model with gradients
                                        if(localUpdate | _modelAvg)
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
index 5b416d7..cfc3a20 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -19,11 +19,16 @@
 
 package org.apache.sysds.runtime.controlprogram.paramserv;
 
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map.Entry;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
-import org.apache.spark.Partitioner;
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.common.Types.FileFormat;
 import org.apache.sysds.common.Types.ValueType;
@@ -34,7 +39,6 @@ import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.hops.recompile.Recompiler;
 import org.apache.sysds.parser.DMLProgram;
 import org.apache.sysds.parser.DMLTranslator;
-import org.apache.sysds.parser.Statement;
 import org.apache.sysds.parser.StatementBlock;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.ForProgramBlock;
@@ -51,30 +55,14 @@ import 
org.apache.sysds.runtime.controlprogram.caching.FrameObject;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
-import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
-import 
org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionerSparkAggregator;
-import 
org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionerSparkMapper;
-import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysds.runtime.functionobjects.Plus;
 import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.meta.MetaDataFormat;
 import org.apache.sysds.runtime.util.ProgramConverter;
-import org.apache.sysds.utils.Statistics;
-import scala.Tuple2;
-
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.HashSet;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Map.Entry;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
 
 public class ParamservUtils {
 
@@ -373,82 +361,6 @@ public class ParamservUtils {
                return left.append(right, new MatrixBlock());
        }
 
-       /**
-        * Assemble the matrix of features and labels according to the rowID
-        *
-        * @param featuresRDD indexed features matrix block
-        * @param labelsRDD indexed labels matrix block
-        * @return Assembled rdd with rowID as key while matrix of features and 
labels as value (rowID {@literal ->} features, labels)
-        */
-       public static JavaPairRDD<Long, Tuple2<MatrixBlock, MatrixBlock>> 
assembleTrainingData(JavaPairRDD<MatrixIndexes, MatrixBlock> featuresRDD, 
JavaPairRDD<MatrixIndexes, MatrixBlock> labelsRDD) {
-               JavaPairRDD<Long, MatrixBlock> fRDD = groupMatrix(featuresRDD);
-               JavaPairRDD<Long, MatrixBlock> lRDD = groupMatrix(labelsRDD);
-               //TODO Add an additional physical operator which broadcasts the 
labels directly (broadcast join with features) if certain memory budgets are 
satisfied
-               return fRDD.join(lRDD);
-       }
-
-       private static JavaPairRDD<Long, MatrixBlock> 
groupMatrix(JavaPairRDD<MatrixIndexes, MatrixBlock> rdd) {
-               //TODO could use join and aggregation to avoid unnecessary 
shuffle introduced by reduceByKey
-               return rdd.mapToPair(input -> new 
Tuple2<>(input._1.getRowIndex(), new Tuple2<>(input._1.getColumnIndex(), 
input._2)))
-                       .aggregateByKey(new LinkedList<Tuple2<Long, 
MatrixBlock>>(),
-                               (list, input) -> {
-                                       list.add(input);
-                                       return list;
-                               }, 
-                               (l1, l2) -> {
-                                       l1.addAll(l2);
-                                       l1.sort((o1, o2) -> 
o1._1.compareTo(o2._1));
-                                       return l1;
-                               })
-                       .mapToPair(input -> {
-                               LinkedList<Tuple2<Long, MatrixBlock>> list = 
input._2;
-                               MatrixBlock result = list.get(0)._2;
-                               for (int i = 1; i < list.size(); i++) {
-                                       result = 
ParamservUtils.cbindMatrix(result, list.get(i)._2);
-                               }
-                               return new Tuple2<>(input._1, result);
-                       });
-       }
-
-       @SuppressWarnings("unchecked")
-       public static JavaPairRDD<Integer, Tuple2<MatrixBlock, MatrixBlock>> 
doPartitionOnSpark(SparkExecutionContext sec, MatrixObject features, 
MatrixObject labels, Statement.PSScheme scheme, int workerNum) {
-               Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
-               // Get input RDD
-               JavaPairRDD<MatrixIndexes, MatrixBlock> featuresRDD = 
(JavaPairRDD<MatrixIndexes, MatrixBlock>)
-                       sec.getRDDHandleForMatrixObject(features, 
FileFormat.BINARY);
-               JavaPairRDD<MatrixIndexes, MatrixBlock> labelsRDD = 
(JavaPairRDD<MatrixIndexes, MatrixBlock>)
-                       sec.getRDDHandleForMatrixObject(labels, 
FileFormat.BINARY);
-
-               DataPartitionerSparkMapper mapper = new 
DataPartitionerSparkMapper(scheme, workerNum, sec, (int) features.getNumRows());
-               JavaPairRDD<Integer, Tuple2<MatrixBlock, MatrixBlock>> result = 
ParamservUtils
-                       .assembleTrainingData(featuresRDD, labelsRDD) // 
Combine features and labels into a pair (rowBlockID => (features, labels))
-                       .flatMapToPair(mapper) // Do the data partitioning on 
spark (workerID => (rowBlockID, (single row features, single row labels))
-                       // Aggregate the partitioned matrix according to rowID 
for each worker
-                       // i.e. (workerID => ordered list[(rowBlockID, (single 
row features, single row labels)]
-                       .aggregateByKey(new LinkedList<Tuple2<Long, 
Tuple2<MatrixBlock, MatrixBlock>>>(), new Partitioner() {
-                               private static final long serialVersionUID = 
-7937781374718031224L;
-                               @Override
-                               public int getPartition(Object workerID) {
-                                       return (int) workerID;
-                               }
-                               @Override
-                               public int numPartitions() {
-                                       return workerNum;
-                               }
-                       }, (list, input) -> {
-                               list.add(input);
-                               return list;
-                       }, (l1, l2) -> {
-                               l1.addAll(l2);
-                               l1.sort((o1, o2) -> o1._1.compareTo(o2._1));
-                               return l1;
-                       })
-                       .mapToPair(new 
DataPartitionerSparkAggregator(features.getNumColumns(), 
labels.getNumColumns()));
-
-               if (DMLScript.STATISTICS)
-                       Statistics.accPSSetupTime((long) tSetup.stop());
-               return result;
-       }
 
        /**
         * Accumulate the given gradients into the accrued gradients
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkParamservUtils.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkParamservUtils.java
new file mode 100644
index 0000000..3a40ad5
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkParamservUtils.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv;
+
+import java.util.LinkedList;
+
+import org.apache.spark.Partitioner;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
+import 
org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionerSparkAggregator;
+import 
org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionerSparkMapper;
+import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysds.utils.Statistics;
+
+import scala.Tuple2;
+
+public class SparkParamservUtils {
+       
+       /**
+        * Assemble the matrix of features and labels according to the rowID
+        *
+        * @param featuresRDD indexed features matrix block
+        * @param labelsRDD indexed labels matrix block
+        * @return Assembled rdd with rowID as key while matrix of features and 
labels as value (rowID {@literal ->} features, labels)
+        */
+       public static JavaPairRDD<Long, Tuple2<MatrixBlock, MatrixBlock>> 
assembleTrainingData(JavaPairRDD<MatrixIndexes, MatrixBlock> featuresRDD, 
JavaPairRDD<MatrixIndexes, MatrixBlock> labelsRDD) {
+               JavaPairRDD<Long, MatrixBlock> fRDD = groupMatrix(featuresRDD);
+               JavaPairRDD<Long, MatrixBlock> lRDD = groupMatrix(labelsRDD);
+               //TODO Add an additional physical operator which broadcasts the 
labels directly (broadcast join with features) if certain memory budgets are 
satisfied
+               return fRDD.join(lRDD);
+       }
+
+       private static JavaPairRDD<Long, MatrixBlock> 
groupMatrix(JavaPairRDD<MatrixIndexes, MatrixBlock> rdd) {
+               //TODO could use join and aggregation to avoid unnecessary 
shuffle introduced by reduceByKey
+               return rdd.mapToPair(input -> new 
Tuple2<>(input._1.getRowIndex(), new Tuple2<>(input._1.getColumnIndex(), 
input._2)))
+                       .aggregateByKey(new LinkedList<Tuple2<Long, 
MatrixBlock>>(),
+                               (list, input) -> {
+                                       list.add(input);
+                                       return list;
+                               }, 
+                               (l1, l2) -> {
+                                       l1.addAll(l2);
+                                       l1.sort((o1, o2) -> 
o1._1.compareTo(o2._1));
+                                       return l1;
+                               })
+                       .mapToPair(input -> {
+                               LinkedList<Tuple2<Long, MatrixBlock>> list = 
input._2;
+                               MatrixBlock result = list.get(0)._2;
+                               for (int i = 1; i < list.size(); i++) {
+                                       result = 
ParamservUtils.cbindMatrix(result, list.get(i)._2);
+                               }
+                               return new Tuple2<>(input._1, result);
+                       });
+       }
+
+       @SuppressWarnings("unchecked")
+       public static JavaPairRDD<Integer, Tuple2<MatrixBlock, MatrixBlock>> 
doPartitionOnSpark(SparkExecutionContext sec, MatrixObject features, 
MatrixObject labels, Statement.PSScheme scheme, int workerNum) {
+               Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
+               // Get input RDD
+               JavaPairRDD<MatrixIndexes, MatrixBlock> featuresRDD = 
(JavaPairRDD<MatrixIndexes, MatrixBlock>)
+                       sec.getRDDHandleForMatrixObject(features, 
FileFormat.BINARY);
+               JavaPairRDD<MatrixIndexes, MatrixBlock> labelsRDD = 
(JavaPairRDD<MatrixIndexes, MatrixBlock>)
+                       sec.getRDDHandleForMatrixObject(labels, 
FileFormat.BINARY);
+
+               DataPartitionerSparkMapper mapper = new 
DataPartitionerSparkMapper(scheme, workerNum, sec, (int) features.getNumRows());
+               JavaPairRDD<Integer, Tuple2<MatrixBlock, MatrixBlock>> result = 
+                       assembleTrainingData(featuresRDD, labelsRDD) // Combine 
features and labels into a pair (rowBlockID => (features, labels))
+                       .flatMapToPair(mapper) // Do the data partitioning on 
spark (workerID => (rowBlockID, (single row features, single row labels))
+                       // Aggregate the partitioned matrix according to rowID 
for each worker
+                       // i.e. (workerID => ordered list[(rowBlockID, (single 
row features, single row labels)]
+                       .aggregateByKey(new LinkedList<Tuple2<Long, 
Tuple2<MatrixBlock, MatrixBlock>>>(), new Partitioner() {
+                               private static final long serialVersionUID = 
-7937781374718031224L;
+                               @Override
+                               public int getPartition(Object workerID) {
+                                       return (int) workerID;
+                               }
+                               @Override
+                               public int numPartitions() {
+                                       return workerNum;
+                               }
+                       }, (list, input) -> {
+                               list.add(input);
+                               return list;
+                       }, (l1, l2) -> {
+                               l1.addAll(l2);
+                               l1.sort((o1, o2) -> o1._1.compareTo(o2._1));
+                               return l1;
+                       })
+                       .mapToPair(new 
DataPartitionerSparkAggregator(features.getNumColumns(), 
labels.getNumColumns()));
+
+               if (DMLScript.STATISTICS)
+                       Statistics.accPSSetupTime((long) tSetup.stop());
+               return result;
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index f59a1b6..1408c66 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -19,38 +19,38 @@
 
 package org.apache.sysds.runtime.instructions.cp;
 
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.LinkedHashMap;
-import java.util.List;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-
 import static org.apache.sysds.parser.Statement.PS_AGGREGATION_FUN;
 import static org.apache.sysds.parser.Statement.PS_BATCH_SIZE;
 import static org.apache.sysds.parser.Statement.PS_EPOCHS;
 import static org.apache.sysds.parser.Statement.PS_FEATURES;
+import static org.apache.sysds.parser.Statement.PS_FED_RUNTIME_BALANCING;
+import static org.apache.sysds.parser.Statement.PS_FED_WEIGHTING;
 import static org.apache.sysds.parser.Statement.PS_FREQUENCY;
 import static org.apache.sysds.parser.Statement.PS_HYPER_PARAMS;
 import static org.apache.sysds.parser.Statement.PS_LABELS;
 import static org.apache.sysds.parser.Statement.PS_MODE;
 import static org.apache.sysds.parser.Statement.PS_MODEL;
-import static org.apache.sysds.parser.Statement.PS_NBATCHES;
 import static org.apache.sysds.parser.Statement.PS_MODELAVG;
+import static org.apache.sysds.parser.Statement.PS_NBATCHES;
 import static org.apache.sysds.parser.Statement.PS_PARALLELISM;
 import static org.apache.sysds.parser.Statement.PS_SCHEME;
+import static org.apache.sysds.parser.Statement.PS_SEED;
 import static org.apache.sysds.parser.Statement.PS_UPDATE_FUN;
 import static org.apache.sysds.parser.Statement.PS_UPDATE_TYPE;
-import static org.apache.sysds.parser.Statement.PS_FED_RUNTIME_BALANCING;
-import static org.apache.sysds.parser.Statement.PS_FED_WEIGHTING;
-import static org.apache.sysds.parser.Statement.PS_SEED;
 import static org.apache.sysds.parser.Statement.PS_VAL_FEATURES;
-import static org.apache.sysds.parser.Statement.PS_VAL_LABELS;
 import static org.apache.sysds.parser.Statement.PS_VAL_FUN;
+import static org.apache.sysds.parser.Statement.PS_VAL_LABELS;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
 import org.apache.commons.lang3.concurrent.BasicThreadFactory;
 import org.apache.commons.logging.Log;
@@ -60,12 +60,12 @@ import org.apache.spark.util.LongAccumulator;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.parser.Statement.FederatedPSScheme;
 import org.apache.sysds.parser.Statement.PSFrequency;
 import org.apache.sysds.parser.Statement.PSModeType;
+import org.apache.sysds.parser.Statement.PSRuntimeBalancing;
 import org.apache.sysds.parser.Statement.PSScheme;
-import org.apache.sysds.parser.Statement.FederatedPSScheme;
 import org.apache.sysds.parser.Statement.PSUpdateType;
-import org.apache.sysds.parser.Statement.PSRuntimeBalancing;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -78,6 +78,7 @@ import 
org.apache.sysds.runtime.controlprogram.paramserv.ParamServer;
 import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
 import org.apache.sysds.runtime.controlprogram.paramserv.SparkPSBody;
 import org.apache.sysds.runtime.controlprogram.paramserv.SparkPSWorker;
+import org.apache.sysds.runtime.controlprogram.paramserv.SparkParamservUtils;
 import 
org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionFederatedScheme;
 import 
org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme;
 import 
org.apache.sysds.runtime.controlprogram.paramserv.dp.FederatedDataPartitioner;
@@ -279,7 +280,7 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                MatrixObject features = 
sec.getMatrixObject(getParam(PS_FEATURES));
                MatrixObject labels = sec.getMatrixObject(getParam(PS_LABELS));
                try {
-                       ParamservUtils.doPartitionOnSpark(sec, features, 
labels, getScheme(), workerNum) // Do data partitioning
+                       SparkParamservUtils.doPartitionOnSpark(sec, features, 
labels, getScheme(), workerNum) // Do data partitioning
                                .foreach(worker); // Run remote workers
                } catch (Exception e) {
                        throw new DMLRuntimeException("Paramserv function 
failed: ", e);
diff --git 
a/src/test/java/org/apache/sysds/test/functions/paramserv/SparkDataPartitionerTest.java
 
b/src/test/java/org/apache/sysds/test/functions/paramserv/SparkDataPartitionerTest.java
index 4b985a3..c4301c2 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/paramserv/SparkDataPartitionerTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/paramserv/SparkDataPartitionerTest.java
@@ -22,16 +22,17 @@ package org.apache.sysds.test.functions.paramserv;
 import java.util.Map;
 import java.util.stream.IntStream;
 
-import org.junit.Assert;
-import org.junit.Test;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.parser.Statement;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysds.runtime.controlprogram.paramserv.SparkParamservUtils;
 import 
org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.junit.Assert;
+import org.junit.Test;
 
 import scala.Tuple2;
 
@@ -48,7 +49,7 @@ public class SparkDataPartitionerTest extends 
BaseDataPartitionerTest {
 
        private Map<Integer, Tuple2<MatrixBlock, MatrixBlock>> 
doPartitioning(Statement.PSScheme scheme) {
                MatrixBlock[] mbs = generateData();
-               return ParamservUtils.doPartitionOnSpark(_sec, 
ParamservUtils.newMatrixObject(mbs[0]), ParamservUtils.newMatrixObject(mbs[1]), 
scheme, WORKER_NUM).collectAsMap();
+               return SparkParamservUtils.doPartitionOnSpark(_sec, 
ParamservUtils.newMatrixObject(mbs[0]), ParamservUtils.newMatrixObject(mbs[1]), 
scheme, WORKER_NUM).collectAsMap();
        }
 
        @Test

Reply via email to