Repository: incubator-systemml
Updated Branches:
  refs/heads/master 8c1e89e51 -> 0a61fe084


[SYSTEMML-1378] Native dataset support in parfor spark dp-execute, tests

See https://issues.apache.org/jira/browse/SYSTEMML-1378 for details.
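
In effect, a DataFrame-backed matrix registered via MLContext can now be
consumed directly by the fused parfor datapartition-execute job: the original
dataset is kept in the RDD lineage and, for ROW_WISE partitioning, its rows
are converted on the fly instead of first being shuffled into binary block
format. A minimal usage sketch (it mirrors the new MLContextParforDatasetTest
below; class and app names are illustrative, and the direct path only applies
when the fused dp-execute is chosen):

    import static org.apache.sysml.api.mlcontext.ScriptFactory.dml;

    import org.apache.spark.api.java.JavaPairRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;
    import org.apache.sysml.api.mlcontext.MLContext;
    import org.apache.sysml.api.mlcontext.MatrixFormat;
    import org.apache.sysml.api.mlcontext.MatrixMetadata;
    import org.apache.sysml.api.mlcontext.Script;
    import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
    import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
    import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
    import org.apache.sysml.runtime.matrix.data.MatrixBlock;
    import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
    import org.apache.sysml.runtime.util.DataConverter;

    public class ParforDatasetSketch {
        public static void main(String[] args) {
            SparkSession spark = SparkSession.builder()
                .appName("ParforDatasetSketch").master("local").getOrCreate();
            JavaSparkContext sc = new JavaSparkContext(spark.sparkContext());
            MLContext ml = new MLContext(sc);

            //build a small vector-based DataFrame with an ID column
            MatrixBlock mb = DataConverter.convertToMatrixBlock(
                new double[][]{{1,2,3},{4,5,6}});
            MatrixCharacteristics mc = new MatrixCharacteristics(2, 3, 1000, 1000);
            JavaPairRDD<MatrixIndexes,MatrixBlock> in =
                SparkExecutionContext.toMatrixJavaPairRDD(sc, mb, 1000, 1000);
            Dataset<Row> df = RDDConverterUtils.binaryBlockToDataFrame(spark, in, mc, true);

            //the DataFrame is bound as matrix input; with this change the parfor
            //workers can read it row-wise without a prior reblock shuffle
            MatrixMetadata mm = new MatrixMetadata(MatrixFormat.DF_VECTOR_WITH_INDEX);
            mm.setMatrixCharacteristics(mc);
            Script s = dml("v = matrix(0, rows=nrow(X), cols=1);"
                + "parfor(i in 1:nrow(X)) { v[i,] = sum(X[i,]); }"
                + "r = sum(v);").in("X", df, mm).out("r");
            System.out.println(ml.execute(s).getDouble("r"));

            ml.close();
            sc.stop();
        }
    }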

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/0a61fe08
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/0a61fe08
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/0a61fe08

Branch: refs/heads/master
Commit: 0a61fe084094c32493d7fd5a013fed8ecd461acc
Parents: 8c1e89e
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Wed Mar 8 23:06:31 2017 -0800
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Wed Mar 8 23:22:15 2017 -0800

----------------------------------------------------------------------
 .../api/mlcontext/MLContextConversionUtil.java  |   7 +-
 .../parfor/RemoteDPParForSpark.java             | 126 +++++++++++++--
 .../parfor/util/PairWritableBlock.java          |  10 +-
 .../spark/data/BroadcastObject.java             |   5 +-
 .../instructions/spark/data/DatasetObject.java  |  53 ++++++
 .../instructions/spark/data/LineageObject.java  |   5 +-
 .../instructions/spark/data/RDDObject.java      |   6 +-
 .../spark/utils/RDDConverterUtils.java          |   2 +-
 .../mlcontext/MLContextParforDatasetTest.java   | 161 +++++++++++++++++++
 9 files changed, 353 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a61fe08/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
index 0225ea8..c496325 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
@@ -49,6 +49,7 @@ import org.apache.sysml.runtime.controlprogram.caching.CacheException;
 import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.instructions.spark.data.DatasetObject;
 import org.apache.sysml.runtime.instructions.spark.data.RDDObject;
 import org.apache.sysml.runtime.instructions.spark.functions.ConvertStringToLongTextPair;
 import org.apache.sysml.runtime.instructions.spark.functions.CopyTextInputFunction;
@@ -328,7 +329,11 @@ public class MLContextConversionUtil {
        {
                matrixMetadata = (matrixMetadata!=null) ? matrixMetadata : new MatrixMetadata();
                JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlock = dataFrameToMatrixBinaryBlocks(dataFrame, matrixMetadata);
-               return binaryBlocksToMatrixObject(variableName, binaryBlock, matrixMetadata, false);
+               MatrixObject mo = binaryBlocksToMatrixObject(variableName, binaryBlock, matrixMetadata, false);
+               //keep lineage of original dataset to allow bypassing binary block conversion if possible
+               mo.getRDDHandle().addLineageChild(new DatasetObject(dataFrame, variableName,
+                               isVectorBasedDataFrame(matrixMetadata), isDataFrameWithIDColumn(matrixMetadata)));
+               return mo;
        }
 
        /**

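The attached DatasetObject sits one level below the matrix variable's
(checkpoint) RDD handle. A sketch of the traversal a consumer performs to get
the original DataFrame back (it mirrors hasInputDataSet/getPartitionedInput in
RemoteDPParForSpark below; the helper name is illustrative):

    //sketch: recover the original DataFrame from the lineage attached above,
    //assuming the shape rdd handle -> checkpoint rdd -> dataset
    static Dataset<Row> recoverDataset(MatrixObject mo) {
        RDDObject hnd = mo.getRDDHandle();
        if( hnd != null && hnd.isCheckpointRDD()
            && hnd.getLineageChilds().size() == 1
            && hnd.getLineageChilds().get(0).getLineageChilds().size() == 1
            && hnd.getLineageChilds().get(0).getLineageChilds().get(0)
               instanceof DatasetObject ) {
            return ((DatasetObject) hnd.getLineageChilds().get(0)
                .getLineageChilds().get(0)).getDataset();
        }
        return null; //no direct dataset available
    }
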
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a61fe08/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteDPParForSpark.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteDPParForSpark.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteDPParForSpark.java
index afc6c95..3a27b66 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteDPParForSpark.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/RemoteDPParForSpark.java
@@ -29,6 +29,11 @@ import org.apache.hadoop.io.Writable;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.ml.linalg.SparseVector;
+import org.apache.spark.ml.linalg.Vector;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
 import org.apache.spark.util.LongAccumulator;
 
 import scala.Tuple2;
@@ -40,13 +45,17 @@ import org.apache.sysml.runtime.controlprogram.ParForProgramBlock.PDataPartition
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.controlprogram.parfor.util.PairWritableBlock;
+import org.apache.sysml.runtime.instructions.spark.data.DatasetObject;
+import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
 import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
+import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils.DataFrameExtractIDFunction;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
-import org.apache.sysml.runtime.matrix.MatrixDimensionsMetaData;
 import org.apache.sysml.runtime.matrix.data.InputInfo;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysml.runtime.matrix.data.OutputInfo;
+import org.apache.sysml.runtime.util.UtilFunctions;
 import org.apache.sysml.utils.Statistics;
 
 /**
@@ -71,9 +80,8 @@ public class RemoteDPParForSpark
                JavaSparkContext sc = sec.getSparkContext();
                
                //prepare input parameters
-               MatrixDimensionsMetaData md = (MatrixDimensionsMetaData) input.getMetaData();
-               MatrixCharacteristics mc = md.getMatrixCharacteristics();
-               InputInfo ii = InputInfo.BinaryBlockInputInfo;
+               MatrixObject mo = sec.getMatrixObject(matrixvar);
+               MatrixCharacteristics mc = mo.getMatrixCharacteristics();
 
                //initialize accumulators for tasks/iterations, and inputs
                JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable(matrixvar);
@@ -86,11 +94,10 @@ public class RemoteDPParForSpark
                int numReducers2 = Math.max(numReducers, Math.min(numParts, numParts2));
                
                //core parfor datapartition-execute (w/ or w/o shuffle, depending on data characteristics)
-               DataPartitionerRemoteSparkMapper dpfun = new DataPartitionerRemoteSparkMapper(mc, ii, oi, dpf);
                RemoteDPParForSparkWorker efun = new RemoteDPParForSparkWorker(program, clsMap, 
                                matrixvar, itervar, enableCPCaching, mc, tSparseCol, dpf, oi, aTasks, aIters);
-               JavaPairRDD<Long,Writable> tmp = in.flatMapToPair(dpfun);
-               List<Tuple2<Long,String>> out = (requiresGrouping(dpf, mc) ?
+               JavaPairRDD<Long,Writable> tmp = getPartitionedInput(sec, matrixvar, oi, dpf);
+               List<Tuple2<Long,String>> out = (requiresGrouping(dpf, mo) ?
                                tmp.groupByKey(numReducers2) : tmp.map(new PseudoGrouping()) )
                                   .mapPartitionsToPair(efun)  //execute parfor tasks, incl cleanup
                           .collect();                 //get output handles 
@@ -113,10 +120,57 @@ public class RemoteDPParForSpark
                return ret;
        }
        
+       private static JavaPairRDD<Long, Writable> getPartitionedInput(SparkExecutionContext sec, 
+                       String matrixvar, OutputInfo oi, PDataPartitionFormat dpf) 
+               throws DMLRuntimeException 
+       {
+               InputInfo ii = InputInfo.BinaryBlockInputInfo;
+               MatrixObject mo = sec.getMatrixObject(matrixvar);
+               MatrixCharacteristics mc = mo.getMatrixCharacteristics();
+
+               //leverage existing dataset (w/o shuffling for reblock and data partitioning)
+               //NOTE: there will always be a checkpoint rdd on top of the input rdd and the dataset
+               if( hasInputDataSet(dpf, mo) )
+               {
+                       DatasetObject dsObj = (DatasetObject)mo.getRDDHandle()
+                                       .getLineageChilds().get(0).getLineageChilds().get(0);
+                       Dataset<Row> in = dsObj.getDataset();
+                       
+                       //construct or reuse row ids
+                       JavaPairRDD<Row, Long> prepinput = dsObj.containsID() ?
+                                       in.javaRDD().mapToPair(new DataFrameExtractIDFunction(
+                                               in.schema().fieldIndex(RDDConverterUtils.DF_ID_COLUMN))) :
+                                       in.javaRDD().zipWithIndex(); //zip row index
+                       
+                       //convert row to row in matrix block format
+                       return prepinput.mapToPair(new DataFrameToRowBinaryBlockFunction(
+                                       mc.getCols(), dsObj.containsID(), dsObj.isVectorBased()));
+               }
+               //default binary block input rdd
+               else
+               {
+                       //get input rdd and data partitioning
+                       JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable(matrixvar);
+                       DataPartitionerRemoteSparkMapper dpfun = new DataPartitionerRemoteSparkMapper(mc, ii, oi, dpf);
+                       return in.flatMapToPair(dpfun);
+               }
+       }
+       
        //determines if given input matrix requires grouping of partial partition slices
-       private static boolean requiresGrouping(PDataPartitionFormat dpf, MatrixCharacteristics mc) {
-               return (dpf == PDataPartitionFormat.ROW_WISE && mc.getNumColBlocks() > 1)
-                       || (dpf == PDataPartitionFormat.COLUMN_WISE && mc.getNumRowBlocks() > 1);
+       private static boolean requiresGrouping(PDataPartitionFormat dpf, MatrixObject mo) {
+               MatrixCharacteristics mc = mo.getMatrixCharacteristics();
+               return ((dpf == PDataPartitionFormat.ROW_WISE && mc.getNumColBlocks() > 1)
+                               || (dpf == PDataPartitionFormat.COLUMN_WISE && mc.getNumRowBlocks() > 1))
+                       && !hasInputDataSet(dpf, mo);
+       }
+       
+       //determines if given input matrix wraps input data set applicable to direct processing
+       private static boolean hasInputDataSet(PDataPartitionFormat dpf, MatrixObject mo) {
+               return (dpf == PDataPartitionFormat.ROW_WISE 
+                       && mo.getRDDHandle().isCheckpointRDD()
+                       && mo.getRDDHandle().getLineageChilds().size()==1
+                       && mo.getRDDHandle().getLineageChilds().get(0).getLineageChilds().size()==1
+                       && mo.getRDDHandle().getLineageChilds().get(0).getLineageChilds().get(0) instanceof DatasetObject);
        }
        
        //function to map data partition output to parfor input signature without grouping
@@ -128,4 +182,56 @@ public class RemoteDPParForSpark
                        return new Tuple2<Long, Iterable<Writable>>(arg0._1(), Collections.singletonList(arg0._2()));
                }
        }
+       
+       //function to map dataset rows to rows in binary block representation
+       private static class DataFrameToRowBinaryBlockFunction implements PairFunction<Tuple2<Row,Long>,Long,Writable> 
+       {
+               private static final long serialVersionUID = -3162404379379461523L;
+               
+               private final long _clen;
+               private final boolean _containsID;
+               private final boolean _isVector;
+               
+               public DataFrameToRowBinaryBlockFunction(long clen, boolean containsID, boolean isVector) {
+                       _clen = clen;
+                       _containsID = containsID;
+                       _isVector = isVector;
+               }
+               
+               @Override
+               public Tuple2<Long, Writable> call(Tuple2<Row, Long> arg0) 
+                       throws Exception 
+               {
+                       long rowix = arg0._2() + 1;
+                       
+                       //process row data
+                       int off = _containsID ? 1: 0;
+                       Object obj = _isVector ? arg0._1().get(off) : arg0._1();
+                       boolean sparse = (obj instanceof SparseVector);
+                       MatrixBlock mb = new MatrixBlock(1, (int)_clen, sparse);
+                       
+                       if( _isVector ) {
+                               Vector vect = (Vector) obj;
+                               if( vect instanceof SparseVector ) {
+                                       SparseVector svect = (SparseVector) vect;
+                                       int lnnz = svect.numNonzeros();
+                                       for( int k=0; k<lnnz; k++ )
+                                               mb.appendValue(0, svect.indices()[k], svect.values()[k]);
+                               }
+                               else { //dense
+                                       for( int j=0; j<_clen; j++ )
+                                               mb.appendValue(0, j, vect.apply(j));
+                               }
+                       }
+                       else { //row
+                               Row row = (Row) obj;
+                               for( int j=off; j<off+_clen; j++ )
+                                       mb.appendValue(0, j-off, UtilFunctions.getDouble(row.get(j)));
+                       }
+                       mb.examSparsity();
+                       
+                       return new Tuple2<Long, Writable>(rowix, 
+                                       new PairWritableBlock(new MatrixIndexes(1,1),mb));
+               }
+       }
 }
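
For reference, each dataset row becomes one complete ROW_WISE partition slice:
the key is the 1-based row index and the value a single 1 x clen block at
block index (1,1), which is why requiresGrouping above can skip the shuffle on
this path. A small sketch of the emitted pair for one sparse vector row
(made-up values, 0-based input index 6):

    //uses org.apache.spark.ml.linalg.{SparseVector, Vectors} as imported above
    SparseVector v = (SparseVector) Vectors.sparse(5,
        new int[]{0, 3}, new double[]{1.0, 2.0});
    MatrixBlock mb = new MatrixBlock(1, 5, true);         //1 x clen, sparse
    for( int k=0; k<v.numNonzeros(); k++ )                //copy nnz into row 0
        mb.appendValue(0, v.indices()[k], v.values()[k]);
    mb.examSparsity();                                    //choose final format
    Tuple2<Long,Writable> out = new Tuple2<Long,Writable>(7L, //rowix = 6 + 1
        new PairWritableBlock(new MatrixIndexes(1,1), mb));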

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a61fe08/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/util/PairWritableBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/util/PairWritableBlock.java b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/util/PairWritableBlock.java
index 1e3a992..40907d5 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/util/PairWritableBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/parfor/util/PairWritableBlock.java
@@ -36,12 +36,20 @@ import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
  */
 public class PairWritableBlock implements Writable, Serializable
 {
-
        private static final long serialVersionUID = -6022511967446089164L;
 
        public MatrixIndexes indexes;
        public MatrixBlock block;
        
+       public PairWritableBlock() {
+       
+       }
+       
+       public PairWritableBlock(MatrixIndexes ix, MatrixBlock mb) {
+               indexes = ix;
+               block = mb;
+       }
+       
        @Override
        public void readFields(DataInput in) throws IOException 
        {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a61fe08/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java
index 97b840d..4316b15 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/BroadcastObject.java
@@ -29,10 +29,9 @@ public class BroadcastObject<T extends CacheBlock> extends LineageObject
        //soft reference storage for graceful cleanup in case of memory pressure
        protected SoftReference<PartitionedBroadcast<T>> _bcHandle = null;
        
-       public BroadcastObject( PartitionedBroadcast<T> bvar, String varName )
-       {
+       public BroadcastObject( PartitionedBroadcast<T> bvar, String varName ) {
+               super(varName);
                _bcHandle = new SoftReference<PartitionedBroadcast<T>>(bvar);
-               _varName = varName;
        }
 
        @SuppressWarnings("rawtypes")

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a61fe08/src/main/java/org/apache/sysml/runtime/instructions/spark/data/DatasetObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/DatasetObject.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/DatasetObject.java
new file mode 100644
index 0000000..5030136
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/DatasetObject.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.instructions.spark.data;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+
+public class DatasetObject extends LineageObject
+{
+       private final Dataset<Row> _dsHandle;
+       private final boolean _isVector;
+       private final boolean _containsID;
+       
+       public DatasetObject( Dataset<Row> dsvar, String varName) {
+               this(dsvar, varName, true, true);
+       }
+       
+       public DatasetObject( Dataset<Row> dsvar, String varName, boolean isVector, boolean containsID) {
+               super(varName);
+               _dsHandle = dsvar;
+               _isVector = isVector;
+               _containsID = containsID;
+       }
+
+       public Dataset<Row> getDataset() {
+               return _dsHandle;
+       }
+       
+       public boolean isVectorBased() {
+               return _isVector;
+       }
+       
+       public boolean containsID() {
+               return _containsID;
+       }
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a61fe08/src/main/java/org/apache/sysml/runtime/instructions/spark/data/LineageObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/LineageObject.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/LineageObject.java
index 4b550cc..58629e6 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/LineageObject.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/LineageObject.java
@@ -29,12 +29,13 @@ public abstract class LineageObject
        //basic lineage information
        protected int _numRef = -1;
        protected List<LineageObject> _childs = null;
-       protected String _varName = null;
+       protected final String _varName;
        
        //N:1 back reference to matrix/frame object
        protected CacheableData<?> _cd = null;
        
-       protected LineageObject() {
+       protected LineageObject(String varName) {
+               _varName = varName;
                _numRef = 0;
                _childs = new ArrayList<LineageObject>();
        }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a61fe08/src/main/java/org/apache/sysml/runtime/instructions/spark/data/RDDObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/RDDObject.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/RDDObject.java
index 5b896bc..fd5dad3 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/RDDObject.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/RDDObject.java
@@ -23,7 +23,6 @@ import org.apache.spark.api.java.JavaPairRDD;
 
 public class RDDObject extends LineageObject
 {
-
        private JavaPairRDD<?,?> _rddHandle = null;
        
        //meta data on origin of given rdd handle
@@ -31,10 +30,9 @@ public class RDDObject extends LineageObject
        private boolean _hdfsfile = false;     //created from hdfs file
        private String  _hdfsFname = null;     //hdfs filename, if created from hdfs.
        
-       public RDDObject( JavaPairRDD<?,?> rddvar, String varName)
-       {
+       public RDDObject( JavaPairRDD<?,?> rddvar, String varName) {
+               super(varName);
                _rddHandle = rddvar;
-               _varName = varName;
        }
 
        public JavaPairRDD<?,?> getRDD()

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a61fe08/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtils.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtils.java
index f370890..134a071 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtils.java
@@ -1151,7 +1151,7 @@ public class RDDConverterUtils
                }
        }
 
-       protected static class DataFrameExtractIDFunction implements PairFunction<Row, Row,Long> 
+       public static class DataFrameExtractIDFunction implements PairFunction<Row, Row,Long> 
        {
                private static final long serialVersionUID = 7438855241666363433L;
 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/0a61fe08/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextParforDatasetTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextParforDatasetTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextParforDatasetTest.java
new file mode 100644
index 0000000..41b8d16
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextParforDatasetTest.java
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.mlcontext;
+
+import static org.apache.sysml.api.mlcontext.ScriptFactory.dml;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.sysml.api.mlcontext.MLContext;
+import org.apache.sysml.api.mlcontext.MLResults;
+import org.apache.sysml.api.mlcontext.MatrixFormat;
+import org.apache.sysml.api.mlcontext.MatrixMetadata;
+import org.apache.sysml.api.mlcontext.Script;
+import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysml.runtime.util.DataConverter;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+
+public class MLContextParforDatasetTest extends AutomatedTestBase 
+{
+       protected final static String TEST_DIR = "org/apache/sysml/api/mlcontext";
+       protected final static String TEST_NAME = "MLContext";
+
+       private final static int rows = 100;
+       private final static int cols = 1600;
+       private final static double sparsity = 0.7;
+       
+       private static SparkConf conf;
+       private static JavaSparkContext sc;
+       private static MLContext ml;
+
+       @BeforeClass
+       public static void setUpClass() {
+               if (conf == null)
+                       conf = SparkExecutionContext.createSystemMLSparkConf()
+                               .setAppName("MLContextTest").setMaster("local");
+               if (sc == null)
+                       sc = new JavaSparkContext(conf);
+               ml = new MLContext(sc);
+       }
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_DIR, TEST_NAME);
+               getAndLoadTestConfiguration(TEST_NAME);
+       }
+
+
+       @Test
+       public void testParforDatasetVector() {
+               runMLContextParforDatasetTest(true, false);
+       }
+       
+       @Test
+       public void testParforDatasetRow() {
+               runMLContextParforDatasetTest(false, false);
+       }
+       
+       @Test
+       public void testParforDatasetVectorUnknownDims() {
+               runMLContextParforDatasetTest(true, true);
+       }
+       
+       @Test
+       public void testParforDatasetRowUnknownDims() {
+               runMLContextParforDatasetTest(false, true);
+       }
+       
+       private void runMLContextParforDatasetTest(boolean vector, boolean unknownDims) 
+       {
+               //modify memory budget to trigger fused datapartition-execute
+               long oldmem = InfrastructureAnalyzer.getLocalMaxMemory();
+               InfrastructureAnalyzer.setLocalMaxMemory(1*1024*1024); //1MB
+               
+               try
+               {
+                       double[][] A = getRandomMatrix(rows, cols, -10, 10, sparsity, 76543); 
+                       MatrixBlock mbA = DataConverter.convertToMatrixBlock(A); 
+                       int blksz = ConfigurationManager.getBlocksize();
+                       MatrixCharacteristics mc1 = new MatrixCharacteristics(rows, cols, blksz, blksz, mbA.getNonZeros());
+                       MatrixCharacteristics mc2 = unknownDims ? new MatrixCharacteristics() : new MatrixCharacteristics(mc1);
+                       
+                       //create input dataset
+                       SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
+                       JavaPairRDD<MatrixIndexes,MatrixBlock> in = SparkExecutionContext.toMatrixJavaPairRDD(sc, mbA, blksz, blksz);
+                       Dataset<Row> df = RDDConverterUtils.binaryBlockToDataFrame(sparkSession, in, mc1, vector);
+                       MatrixMetadata mm = new MatrixMetadata(vector ? MatrixFormat.DF_VECTOR_WITH_INDEX : MatrixFormat.DF_DOUBLES_WITH_INDEX);
+                       mm.setMatrixCharacteristics(mc2);
+                       
+                       String s = "v = matrix(0, rows=nrow(X), cols=1)"
+                                       + "parfor(i in 1:nrow(X), log=DEBUG) {"
+                                       + "   v[i, ] = sum(X[i, ]);"
+                                       + "}"
+                                       + "r = sum(v);";
+                       Script script = dml(s).in("X", df, mm).out("r");
+                       MLResults results = ml.execute(script);
+                       
+                       //compare aggregation results
+                       double sum1 = results.getDouble("r");
+                       double sum2 = mbA.sum();
+                       
+                       TestUtils.compareScalars(sum2, sum1, 0.000001);
+               }
+               catch(Exception ex) {
+                       throw new RuntimeException(ex);
+               }
+               finally {
+                       InfrastructureAnalyzer.setLocalMaxMemory(oldmem);       
+               }
+       }
+
+       @After
+       public void tearDown() {
+               super.tearDown();
+       }
+
+       @AfterClass
+       public static void tearDownClass() {
+               // stop spark context to allow single jvm tests (otherwise the
+               // next test that tries to create a SparkContext would fail)
+               sc.stop();
+               sc = null;
+               conf = null;
+
+               // clear static mlcontext and spark exec context
+               ml.close();
+               ml = null;
+       }
+}
