Repository: incubator-systemml Updated Branches: refs/heads/master 8c1e89e51 -> 0a61fe084
[SYSTEMML-1378] Native dataset support in parfor spark dp-execute, tests See https://issues.apache.org/jira/browse/SYSTEMML-1378 for details. Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/0a61fe08 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/0a61fe08 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/0a61fe08 Branch: refs/heads/master Commit: 0a61fe084094c32493d7fd5a013fed8ecd461acc Parents: 8c1e89e Author: Matthias Boehm <mboe...@gmail.com> Authored: Wed Mar 8 23:06:31 2017 -0800 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Wed Mar 8 23:22:15 2017 -0800 ---------------------------------------------------------------------- .../api/mlcontext/MLContextConversionUtil.java | 7 +- .../parfor/RemoteDPParForSpark.java | 126 +++++++++++++-- .../parfor/util/PairWritableBlock.java | 10 +- .../spark/data/BroadcastObject.java | 5 +- .../instructions/spark/data/DatasetObject.java | 53 ++++++ .../instructions/spark/data/LineageObject.java | 5 +- .../instructions/spark/data/RDDObject.java | 6 +- .../spark/utils/RDDConverterUtils.java | 2 +- .../mlcontext/MLContextParforDatasetTest.java | 161 +++++++++++++++++++ 9 files changed, 353 insertions(+), 22 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a61fe08/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java index 0225ea8..c496325 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java @@ -49,6 +49,7 @@ import 
org.apache.sysml.runtime.controlprogram.caching.CacheException; import org.apache.sysml.runtime.controlprogram.caching.FrameObject; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; +import org.apache.sysml.runtime.instructions.spark.data.DatasetObject; import org.apache.sysml.runtime.instructions.spark.data.RDDObject; import org.apache.sysml.runtime.instructions.spark.functions.ConvertStringToLongTextPair; import org.apache.sysml.runtime.instructions.spark.functions.CopyTextInputFunction; @@ -328,7 +329,11 @@ public class MLContextConversionUtil { { matrixMetadata = (matrixMetadata!=null) ? matrixMetadata : new MatrixMetadata(); JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlock = dataFrameToMatrixBinaryBlocks(dataFrame, matrixMetadata); - return binaryBlocksToMatrixObject(variableName, binaryBlock, matrixMetadata, false); + MatrixObject mo = binaryBlocksToMatrixObject(variableName, binaryBlock, matrixMetadata, false); + //keep lineage of original dataset to allow bypassing binary block conversion if possible (DatasetObject flag order: isVector, containsID) + mo.getRDDHandle().addLineageChild(new DatasetObject(dataFrame, variableName, + isVectorBasedDataFrame(matrixMetadata), isDataFrameWithIDColumn(matrixMetadata))); + return mo; } /** http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a61fe08/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteDPParForSpark.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteDPParForSpark.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteDPParForSpark.java index afc6c95..3a27b66 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteDPParForSpark.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteDPParForSpark.java @@ -29,6 +29,11 @@ import org.apache.hadoop.io.Writable; import
org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.ml.linalg.SparseVector; +import org.apache.spark.ml.linalg.Vector; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.util.LongAccumulator; import scala.Tuple2; @@ -40,13 +45,17 @@ import org.apache.sysml.runtime.controlprogram.ParForProgramBlock.PDataPartition import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; +import org.apache.sysml.runtime.controlprogram.parfor.util.PairWritableBlock; +import org.apache.sysml.runtime.instructions.spark.data.DatasetObject; +import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils; import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils; +import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils.DataFrameExtractIDFunction; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; -import org.apache.sysml.runtime.matrix.MatrixDimensionsMetaData; import org.apache.sysml.runtime.matrix.data.InputInfo; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; import org.apache.sysml.runtime.matrix.data.OutputInfo; +import org.apache.sysml.runtime.util.UtilFunctions; import org.apache.sysml.utils.Statistics; /** @@ -71,9 +80,8 @@ public class RemoteDPParForSpark JavaSparkContext sc = sec.getSparkContext(); //prepare input parameters - MatrixDimensionsMetaData md = (MatrixDimensionsMetaData) input.getMetaData(); - MatrixCharacteristics mc = md.getMatrixCharacteristics(); - InputInfo ii = InputInfo.BinaryBlockInputInfo; + MatrixObject mo = sec.getMatrixObject(matrixvar); + MatrixCharacteristics mc =
mo.getMatrixCharacteristics(); //initialize accumulators for tasks/iterations, and inputs JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable(matrixvar); @@ -86,11 +94,10 @@ public class RemoteDPParForSpark int numReducers2 = Math.max(numReducers, Math.min(numParts, numParts2)); //core parfor datapartition-execute (w/ or w/o shuffle, depending on data characteristics) - DataPartitionerRemoteSparkMapper dpfun = new DataPartitionerRemoteSparkMapper(mc, ii, oi, dpf); RemoteDPParForSparkWorker efun = new RemoteDPParForSparkWorker(program, clsMap, matrixvar, itervar, enableCPCaching, mc, tSparseCol, dpf, oi, aTasks, aIters); - JavaPairRDD<Long,Writable> tmp = in.flatMapToPair(dpfun); - List<Tuple2<Long,String>> out = (requiresGrouping(dpf, mc) ? + JavaPairRDD<Long,Writable> tmp = getPartitionedInput(sec, matrixvar, oi, dpf); + List<Tuple2<Long,String>> out = (requiresGrouping(dpf, mo) ? tmp.groupByKey(numReducers2) : tmp.map(new PseudoGrouping()) ) .mapPartitionsToPair(efun) //execute parfor tasks, incl cleanup .collect(); //get output handles @@ -113,10 +120,57 @@ public class RemoteDPParForSpark return ret; } + private static JavaPairRDD<Long, Writable> getPartitionedInput(SparkExecutionContext sec, + String matrixvar, OutputInfo oi, PDataPartitionFormat dpf) + throws DMLRuntimeException + { + InputInfo ii = InputInfo.BinaryBlockInputInfo; + MatrixObject mo = sec.getMatrixObject(matrixvar); + MatrixCharacteristics mc = mo.getMatrixCharacteristics(); + + //leverage existing dataset (w/o shuffling for reblock and data partitioning) + //NOTE: there will always be a checkpoint rdd on top of the input rdd and the dataset + if( hasInputDataSet(dpf, mo) ) + { + DatasetObject dsObj = (DatasetObject)mo.getRDDHandle() + .getLineageChilds().get(0).getLineageChilds().get(0); + Dataset<Row> in = dsObj.getDataset(); + + //construct or reuse row ids + JavaPairRDD<Row, Long> prepinput = dsObj.containsID() ?
+ in.javaRDD().mapToPair(new DataFrameExtractIDFunction( + in.schema().fieldIndex(RDDConverterUtils.DF_ID_COLUMN))) : + in.javaRDD().zipWithIndex(); //zip row index + + //convert row to row in matrix block format (ctor args: clen, containsID, isVector) + return prepinput.mapToPair(new DataFrameToRowBinaryBlockFunction( + mc.getCols(), dsObj.containsID(), dsObj.isVectorBased())); + } + //default binary block input rdd + else + { + //get input rdd and data partitioning + JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable(matrixvar); + DataPartitionerRemoteSparkMapper dpfun = new DataPartitionerRemoteSparkMapper(mc, ii, oi, dpf); + return in.flatMapToPair(dpfun); + } + } + //determines if given input matrix requires grouping of partial partition slices - private static boolean requiresGrouping(PDataPartitionFormat dpf, MatrixCharacteristics mc) { - return (dpf == PDataPartitionFormat.ROW_WISE && mc.getNumColBlocks() > 1) - || (dpf == PDataPartitionFormat.COLUMN_WISE && mc.getNumRowBlocks() > 1); + private static boolean requiresGrouping(PDataPartitionFormat dpf, MatrixObject mo) { + MatrixCharacteristics mc = mo.getMatrixCharacteristics(); + return ((dpf == PDataPartitionFormat.ROW_WISE && mc.getNumColBlocks() > 1) + || (dpf == PDataPartitionFormat.COLUMN_WISE && mc.getNumRowBlocks() > 1)) + && !hasInputDataSet(dpf, mo); + } + + //determines if given input matrix wraps input data set applicable to direct processing (requires an existing rdd handle) + private static boolean hasInputDataSet(PDataPartitionFormat dpf, MatrixObject mo) { + return (dpf == PDataPartitionFormat.ROW_WISE + && mo.getRDDHandle() != null && mo.getRDDHandle().isCheckpointRDD() + && mo.getRDDHandle().getLineageChilds().size()==1 + && mo.getRDDHandle().getLineageChilds().get(0).getLineageChilds().size()==1 + && mo.getRDDHandle().getLineageChilds().get(0).getLineageChilds().get(0) instanceof DatasetObject); } //function to map data partition output to parfor input signature without grouping @@ -128,4 +182,56 @@ public class RemoteDPParForSpark return new Tuple2<Long,
Iterable<Writable>>(arg0._1(), Collections.singletonList(arg0._2())); } } + + //function to map dataset rows to rows in binary block representation + private static class DataFrameToRowBinaryBlockFunction implements PairFunction<Tuple2<Row,Long>,Long,Writable> + { + private static final long serialVersionUID = -3162404379379461523L; + + private final long _clen; + private final boolean _containsID; + private final boolean _isVector; + + public DataFrameToRowBinaryBlockFunction(long clen, boolean containsID, boolean isVector) { + _clen = clen; + _containsID = containsID; + _isVector = isVector; + } + + @Override + public Tuple2<Long, Writable> call(Tuple2<Row, Long> arg0) + throws Exception + { + long rowix = arg0._2() + 1; + + //process row data + int off = _containsID ? 1: 0; + Object obj = _isVector ? arg0._1().get(off) : arg0._1(); + boolean sparse = (obj instanceof SparseVector); + MatrixBlock mb = new MatrixBlock(1, (int)_clen, sparse); + + if( _isVector ) { + Vector vect = (Vector) obj; + if( vect instanceof SparseVector ) { + SparseVector svect = (SparseVector) vect; + int lnnz = svect.numNonzeros(); + for( int k=0; k<lnnz; k++ ) + mb.appendValue(0, svect.indices()[k], svect.values()[k]); + } + else { //dense + for( int j=0; j<_clen; j++ ) + mb.appendValue(0, j, vect.apply(j)); + } + } + else { //row + Row row = (Row) obj; + for( int j=off; j<off+_clen; j++ ) + mb.appendValue(0, j-off, UtilFunctions.getDouble(row.get(j))); + } + mb.examSparsity(); + + return new Tuple2<Long, Writable>(rowix, + new PairWritableBlock(new MatrixIndexes(1,1),mb)); + } + } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a61fe08/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/util/PairWritableBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/util/PairWritableBlock.java 
b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/util/PairWritableBlock.java index 1e3a992..40907d5 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/util/PairWritableBlock.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/util/PairWritableBlock.java @@ -36,12 +36,20 @@ import org.apache.sysml.runtime.matrix.data.MatrixIndexes; */ public class PairWritableBlock implements Writable, Serializable { - private static final long serialVersionUID = -6022511967446089164L; public MatrixIndexes indexes; public MatrixBlock block; + public PairWritableBlock() { + + } + + public PairWritableBlock(MatrixIndexes ix, MatrixBlock mb) { + indexes = ix; + block = mb; + } + @Override public void readFields(DataInput in) throws IOException { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a61fe08/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java index 97b840d..4316b15 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java @@ -29,10 +29,9 @@ public class BroadcastObject<T extends CacheBlock> extends LineageObject //soft reference storage for graceful cleanup in case of memory pressure protected SoftReference<PartitionedBroadcast<T>> _bcHandle = null; - public BroadcastObject( PartitionedBroadcast<T> bvar, String varName ) - { + public BroadcastObject( PartitionedBroadcast<T> bvar, String varName ) { + super(varName); _bcHandle = new SoftReference<PartitionedBroadcast<T>>(bvar); - _varName = varName; } @SuppressWarnings("rawtypes") 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a61fe08/src/main/java/org/apache/sysml/runtime/instructions/spark/data/DatasetObject.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/DatasetObject.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/DatasetObject.java new file mode 100644 index 0000000..5030136 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/DatasetObject.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.runtime.instructions.spark.data; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; + +public class DatasetObject extends LineageObject +{ + private final Dataset<Row> _dsHandle; + private final boolean _isVector; + private final boolean _containsID; + + public DatasetObject( Dataset<Row> dsvar, String varName) { + this(dsvar, varName, true, true); + } + + public DatasetObject( Dataset<Row> dsvar, String varName, boolean isVector, boolean containsID) { + super(varName); + _dsHandle = dsvar; + _isVector = isVector; + _containsID = containsID; + } + + public Dataset<Row> getDataset() { + return _dsHandle; + } + + public boolean isVectorBased() { + return _isVector; + } + + public boolean containsID() { + return _containsID; + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a61fe08/src/main/java/org/apache/sysml/runtime/instructions/spark/data/LineageObject.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/LineageObject.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/LineageObject.java index 4b550cc..58629e6 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/LineageObject.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/LineageObject.java @@ -29,12 +29,13 @@ public abstract class LineageObject //basic lineage information protected int _numRef = -1; protected List<LineageObject> _childs = null; - protected String _varName = null; + protected final String _varName; //N:1 back reference to matrix/frame object protected CacheableData<?> _cd = null; - protected LineageObject() { + protected LineageObject(String varName) { + _varName = varName; _numRef = 0; _childs = new ArrayList<LineageObject>(); } 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a61fe08/src/main/java/org/apache/sysml/runtime/instructions/spark/data/RDDObject.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/RDDObject.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/RDDObject.java index 5b896bc..fd5dad3 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/RDDObject.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/RDDObject.java @@ -23,7 +23,6 @@ import org.apache.spark.api.java.JavaPairRDD; public class RDDObject extends LineageObject { - private JavaPairRDD<?,?> _rddHandle = null; //meta data on origin of given rdd handle @@ -31,10 +30,9 @@ public class RDDObject extends LineageObject private boolean _hdfsfile = false; //created from hdfs file private String _hdfsFname = null; //hdfs filename, if created from hdfs. - public RDDObject( JavaPairRDD<?,?> rddvar, String varName) - { + public RDDObject( JavaPairRDD<?,?> rddvar, String varName) { + super(varName); _rddHandle = rddvar; - _varName = varName; } public JavaPairRDD<?,?> getRDD() http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a61fe08/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtils.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtils.java index f370890..134a071 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtils.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtils.java @@ -1151,7 +1151,7 @@ public class RDDConverterUtils } } - protected static class DataFrameExtractIDFunction implements PairFunction<Row, Row,Long> + 
public static class DataFrameExtractIDFunction implements PairFunction<Row, Row,Long> { private static final long serialVersionUID = 7438855241666363433L; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a61fe08/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextParforDatasetTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextParforDatasetTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextParforDatasetTest.java new file mode 100644 index 0000000..41b8d16 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextParforDatasetTest.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.test.integration.mlcontext; + +import static org.apache.sysml.api.mlcontext.ScriptFactory.dml; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.sysml.api.mlcontext.MLContext; +import org.apache.sysml.api.mlcontext.MLResults; +import org.apache.sysml.api.mlcontext.MatrixFormat; +import org.apache.sysml.api.mlcontext.MatrixMetadata; +import org.apache.sysml.api.mlcontext.Script; +import org.apache.sysml.conf.ConfigurationManager; +import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; +import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; +import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils; +import org.apache.sysml.runtime.matrix.MatrixCharacteristics; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.matrix.data.MatrixIndexes; +import org.apache.sysml.runtime.util.DataConverter; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.utils.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + + +public class MLContextParforDatasetTest extends AutomatedTestBase +{ + protected final static String TEST_DIR = "org/apache/sysml/api/mlcontext"; + protected final static String TEST_NAME = "MLContext"; + + private final static int rows = 100; + private final static int cols = 1600; + private final static double sparsity = 0.7; + + private static SparkConf conf; + private static JavaSparkContext sc; + private static MLContext ml; + + @BeforeClass + public static void setUpClass() { + if (conf == null) + conf = SparkExecutionContext.createSystemMLSparkConf() + .setAppName("MLContextTest").setMaster("local"); 
+ if (sc == null) + sc = new JavaSparkContext(conf); + ml = new MLContext(sc); + } + + @Override + public void setUp() { + addTestConfiguration(TEST_DIR, TEST_NAME); + getAndLoadTestConfiguration(TEST_NAME); + } + + + @Test + public void testParforDatasetVector() { + runMLContextParforDatasetTest(true, false); + } + + @Test + public void testParforDatasetRow() { + runMLContextParforDatasetTest(false, false); + } + + @Test + public void testParforDatasetVectorUnknownDims() { + runMLContextParforDatasetTest(true, true); + } + + @Test + public void testParforDatasetRowUnknownDims() { + runMLContextParforDatasetTest(false, true); + } + + private void runMLContextParforDatasetTest(boolean vector, boolean unknownDims) + { + //modify memory budget to trigger fused datapartition-execute + long oldmem = InfrastructureAnalyzer.getLocalMaxMemory(); + InfrastructureAnalyzer.setLocalMaxMemory(1*1024*1024); //1MB + + try + { + double[][] A = getRandomMatrix(rows, cols, -10, 10, sparsity, 76543); + MatrixBlock mbA = DataConverter.convertToMatrixBlock(A); + int blksz = ConfigurationManager.getBlocksize(); + MatrixCharacteristics mc1 = new MatrixCharacteristics(rows, cols, blksz, blksz, mbA.getNonZeros()); + MatrixCharacteristics mc2 = unknownDims ? new MatrixCharacteristics() : new MatrixCharacteristics(mc1); + + //create input dataset + SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate(); + JavaPairRDD<MatrixIndexes,MatrixBlock> in = SparkExecutionContext.toMatrixJavaPairRDD(sc, mbA, blksz, blksz); + Dataset<Row> df = RDDConverterUtils.binaryBlockToDataFrame(sparkSession, in, mc1, vector); + MatrixMetadata mm = new MatrixMetadata(vector ?
MatrixFormat.DF_VECTOR_WITH_INDEX : MatrixFormat.DF_DOUBLES_WITH_INDEX); + mm.setMatrixCharacteristics(mc2); + + String s = "v = matrix(0, rows=nrow(X), cols=1)" + + "parfor(i in 1:nrow(X), log=DEBUG) {" + + " v[i, ] = sum(X[i, ]);" + + "}" + + "r = sum(v);"; + Script script = dml(s).in("X", df, mm).out("r"); + MLResults results = ml.execute(script); + + //compare aggregation results + double sum1 = results.getDouble("r"); + double sum2 = mbA.sum(); + + TestUtils.compareScalars(sum2, sum1, 0.000001); + } + catch(Exception ex) { + throw new RuntimeException(ex); + } + finally { + InfrastructureAnalyzer.setLocalMaxMemory(oldmem); + } + } + + @After + public void tearDown() { + super.tearDown(); + } + + @AfterClass + public static void tearDownClass() { + // stop spark context to allow single jvm tests (otherwise the + // next test that tries to create a SparkContext would fail) + sc.stop(); + sc = null; + conf = null; + + // clear status mlcontext and spark exec context + ml.close(); + ml = null; + } +}