This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 60d16c474b76ecb4d45d3cd6e36580672fc6f1da Author: Matthias Boehm <[email protected]> AuthorDate: Fri Sep 3 00:12:53 2021 +0200 [SYSTEMDS-3118] Extended parfor parser/runtime (frame result variables) This patch extends parfor with support for frame result variables during dependency analysis and merge of worker result variables. So far, this covers only the in-memory frame result merge. --- .../apache/sysds/parser/ParForStatementBlock.java | 3 +- .../runtime/controlprogram/ParForProgramBlock.java | 92 +++++++++-------- .../controlprogram/caching/FrameObject.java | 7 ++ .../runtime/controlprogram/parfor/ResultMerge.java | 90 ++-------------- .../parfor/ResultMergeFrameLocalMemory.java | 114 +++++++++++++++++++++ .../parfor/ResultMergeLocalAutomatic.java | 4 +- .../parfor/ResultMergeLocalFile.java | 2 +- .../parfor/ResultMergeLocalMemory.java | 2 +- .../{ResultMerge.java => ResultMergeMatrix.java} | 50 ++------- .../parfor/ResultMergeRemoteSpark.java | 2 +- .../parfor/ResultMergeRemoteSparkWCompare.java | 2 +- .../parfor/ParForDependencyAnalysisTest.java | 10 +- ...est.java => ParForListFrameResultVarsTest.java} | 22 +++- src/test/scripts/component/parfor/parfor54e.dml | 26 +++++ src/test/scripts/component/parfor/parfor54f.dml | 26 +++++ .../functions/parfor/parfor_frameResults.dml | 32 ++++++ 16 files changed, 306 insertions(+), 178 deletions(-) diff --git a/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java b/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java index 74c55c5..607641c 100644 --- a/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java +++ b/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java @@ -677,7 +677,7 @@ public class ParForStatementBlock extends ForStatementBlock for(DataIdentifier write : datsUpdated) { if( !c._var.equals( write.getName() ) ) continue; - if( cdt != DataType.MATRIX && cdt != DataType.LIST ) { + if( cdt != DataType.MATRIX && cdt != DataType.FRAME && cdt != DataType.LIST ) { 
//cannot infer type, need to exit (conservative approach) throw new LanguageException("PARFOR loop dependency analysis: cannot check " + "for dependencies due to unknown datatype of var '"+c._var+"': "+cdt.name()+"."); @@ -716,6 +716,7 @@ public class ParForStatementBlock extends ForStatementBlock return; } else if( (cdt == DataType.MATRIX && dat2dt == DataType.MATRIX) + || (cdt == DataType.FRAME && dat2dt == DataType.FRAME ) || (cdt == DataType.LIST && dat2dt == DataType.LIST ) ) { boolean invalid = false; diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java index 25d49bb..42ab8bc 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java @@ -38,6 +38,7 @@ import org.apache.sysds.parser.StatementBlock; import org.apache.sysds.parser.VariableSet; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.CacheableData; +import org.apache.sysds.runtime.controlprogram.caching.FrameObject; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; @@ -51,6 +52,7 @@ import org.apache.sysds.runtime.controlprogram.parfor.RemoteDPParForSpark; import org.apache.sysds.runtime.controlprogram.parfor.RemoteParForJobReturn; import org.apache.sysds.runtime.controlprogram.parfor.RemoteParForSpark; import org.apache.sysds.runtime.controlprogram.parfor.ResultMerge; +import org.apache.sysds.runtime.controlprogram.parfor.ResultMergeFrameLocalMemory; import org.apache.sysds.runtime.controlprogram.parfor.ResultMergeLocalAutomatic; import org.apache.sysds.runtime.controlprogram.parfor.ResultMergeLocalFile; import 
org.apache.sysds.runtime.controlprogram.parfor.ResultMergeLocalMemory; @@ -1056,9 +1058,9 @@ public class ParForProgramBlock extends ForProgramBlock * @param out output matrix * @param in array of input matrix objects */ - private static void cleanWorkerResultVariables(ExecutionContext ec, MatrixObject out, MatrixObject[] in, boolean parallel) { + private static void cleanWorkerResultVariables(ExecutionContext ec, CacheableData<?> out, CacheableData<?>[] in, boolean parallel) { //check for empty inputs (no iterations executed) - Stream<MatrixObject> results = Arrays.stream(in).filter(m -> m!=null && m!=out); + Stream<CacheableData<?>> results = Arrays.stream(in).filter(m -> m!=null && m!=out); //perform cleanup (parallel to mitigate file deletion bottlenecks) (parallel ? results.parallel() : results) .forEach(m -> ec.cleanupCacheableData(m)); @@ -1307,33 +1309,41 @@ public class ParForProgramBlock extends ForProgramBlock return dp; } - private ResultMerge createResultMerge( PResultMerge prm, MatrixObject out, MatrixObject[] in, String fname, boolean accum, ExecutionContext ec ) + private ResultMerge<?> createResultMerge( PResultMerge prm, + CacheableData<?> out, CacheableData<?>[] in, String fname, boolean accum, ExecutionContext ec ) { - ResultMerge rm = null; + ResultMerge<?> rm = null; - //create result merge implementation (determine degree of parallelism - //only for spark to avoid unnecessary spark context creation) - switch( prm ) - { - case LOCAL_MEM: - rm = new ResultMergeLocalMemory( out, in, fname, accum ); - break; - case LOCAL_FILE: - rm = new ResultMergeLocalFile( out, in, fname, accum ); - break; - case LOCAL_AUTOMATIC: - rm = new ResultMergeLocalAutomatic( out, in, fname, accum ); - break; - case REMOTE_SPARK: - int numMap = Math.max(_numThreads, - SparkExecutionContext.getDefaultParallelism(true)); - int numRed = numMap; //equal map/reduce - rm = new ResultMergeRemoteSpark( out, in, - fname, accum, ec, numMap, numRed ); - break; - - default: - 
throw new DMLRuntimeException("Undefined result merge: '" +prm.toString()+"'."); + if( out instanceof FrameObject ) { + rm = new ResultMergeFrameLocalMemory((FrameObject)out, (FrameObject[])in, fname, accum); + } + else if(out instanceof MatrixObject) { + //create result merge implementation (determine degree of parallelism + //only for spark to avoid unnecessary spark context creation) + switch( prm ) + { + case LOCAL_MEM: + rm = new ResultMergeLocalMemory( (MatrixObject)out, (MatrixObject[])in, fname, accum ); + break; + case LOCAL_FILE: + rm = new ResultMergeLocalFile( (MatrixObject)out, (MatrixObject[])in, fname, accum ); + break; + case LOCAL_AUTOMATIC: + rm = new ResultMergeLocalAutomatic( (MatrixObject)out, (MatrixObject[])in, fname, accum ); + break; + case REMOTE_SPARK: + int numMap = Math.max(_numThreads, + SparkExecutionContext.getDefaultParallelism(true)); + int numRed = numMap; //equal map/reduce + rm = new ResultMergeRemoteSpark( (MatrixObject)out, + (MatrixObject[])in, fname, accum, ec, numMap, numRed ); + break; + default: + throw new DMLRuntimeException("Undefined result merge: '" +prm.toString()+"'."); + } + } + else { + throw new DMLRuntimeException("Unsupported result merge data: "+out.getClass().getSimpleName()); } return rm; @@ -1437,14 +1447,15 @@ public class ParForProgramBlock extends ForProgramBlock { Data dat = ec.getVariable(var._name); - if( dat instanceof MatrixObject ) //robustness scalars + if( dat instanceof MatrixObject | dat instanceof FrameObject ) { - MatrixObject out = (MatrixObject) dat; - MatrixObject[] in = Arrays.stream(results).map(vars -> - vars.get(var._name)).toArray(MatrixObject[]::new); + CacheableData<?> out = (CacheableData<?>) dat; + Stream<Object> tmp = Arrays.stream(results).map(vars -> vars.get(var._name)); + CacheableData<?>[] in = (dat instanceof MatrixObject) ? 
+ tmp.toArray(MatrixObject[]::new) : tmp.toArray(FrameObject[]::new); String fname = constructResultMergeFileName(); - ResultMerge rm = createResultMerge(_resultMerge, out, in, fname, var._isAccum, ec); - MatrixObject outNew = USE_PARALLEL_RESULT_MERGE ? + ResultMerge<?> rm = createResultMerge(_resultMerge, out, in, fname, var._isAccum, ec); + CacheableData<?> outNew = USE_PARALLEL_RESULT_MERGE ? rm.executeParallelMerge(_numThreads) : rm.executeSerialMerge(); @@ -1653,18 +1664,19 @@ public class ParForProgramBlock extends ForProgramBlock if( var == LocalTaskQueue.NO_MORE_TASKS ) // task queue closed (no more tasks) break; - MatrixObject out = null; + CacheableData<?> out = null; synchronized( _ec.getVariables() ){ - out = _ec.getMatrixObject(var._name); + out = _ec.getCacheableData(var._name); } - MatrixObject[] in = new MatrixObject[ _refVars.length ]; - for( int i=0; i< _refVars.length; i++ ) - in[i] = (MatrixObject) _refVars[i].get( var._name ); + Stream<Object> tmp = Arrays.stream(_refVars).map(vars -> vars.get(var._name)); + CacheableData<?>[] in = (out instanceof MatrixObject) ? 
+ tmp.toArray(MatrixObject[]::new) : tmp.toArray(FrameObject[]::new); + String fname = constructResultMergeFileName(); - ResultMerge rm = createResultMerge(_resultMerge, out, in, fname, var._isAccum, _ec); - MatrixObject outNew = null; + ResultMerge<?> rm = createResultMerge(_resultMerge, out, in, fname, var._isAccum, _ec); + CacheableData<?> outNew = null; if( USE_PARALLEL_RESULT_MERGE ) outNew = rm.executeParallelMerge( _numThreads ); else diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java index 5eae986..4485388 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java @@ -41,6 +41,7 @@ import org.apache.sysds.runtime.lineage.LineageItem; import org.apache.sysds.runtime.lineage.LineageRecomputeUtils; import org.apache.sysds.runtime.matrix.data.FrameBlock; import org.apache.sysds.runtime.meta.DataCharacteristics; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; import org.apache.sysds.runtime.meta.MetaData; import org.apache.sysds.runtime.meta.MetaDataFormat; import org.apache.sysds.runtime.util.UtilFunctions; @@ -86,6 +87,12 @@ public class FrameObject extends CacheableData<FrameBlock> */ public FrameObject(FrameObject fo) { super(fo); + + MetaDataFormat metaOld = (MetaDataFormat) fo.getMetaData(); + _metaData = new MetaDataFormat( + new MatrixCharacteristics(metaOld.getDataCharacteristics()), + metaOld.getFileFormat()); + _schema = fo._schema.clone(); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMerge.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMerge.java index 18b09a1..b69ba96 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMerge.java +++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMerge.java @@ -21,42 +21,33 @@ package org.apache.sysds.runtime.controlprogram.parfor; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; -import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.controlprogram.caching.CacheableData; import org.apache.sysds.runtime.instructions.InstructionUtils; -import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import java.io.Serializable; -import java.util.List; -/** - * Due to independence of all iterations, any result has the following properties: - * (1) non local var, (2) matrix object, and (3) completely independent. - * These properties allow us to realize result merging in parallel without any synchronization. - * - */ -public abstract class ResultMerge implements Serializable +public abstract class ResultMerge<T extends CacheableData<?>> implements Serializable { //note: this class needs to be serializable to ensure that all attributes of //ResultMergeRemoteSparkWCompare are included in the task closure - private static final long serialVersionUID = 2620430969346516677L; + private static final long serialVersionUID = -6756689640511059030L; protected static final Log LOG = LogFactory.getLog(ResultMerge.class.getName()); protected static final String NAME_SUFFIX = "_rm"; protected static final BinaryOperator PLUS = InstructionUtils.parseBinaryOperator("+"); //inputs to result merge - protected MatrixObject _output = null; - protected MatrixObject[] _inputs = null; - protected String _outputFName = null; - protected boolean _isAccum = false; + protected T _output = null; + protected T[] _inputs = null; + protected String _outputFName = null; + protected boolean _isAccum = false; protected ResultMerge( ) { //do nothing } - public ResultMerge( MatrixObject out, 
MatrixObject[] in, String outputFilename, boolean accum ) { + public ResultMerge( T out, T[] in, String outputFilename, boolean accum ) { _output = out; _inputs = in; _outputFName = outputFilename; @@ -70,7 +61,7 @@ public abstract class ResultMerge implements Serializable * * @return output (merged) matrix */ - public abstract MatrixObject executeSerialMerge(); + public abstract T executeSerialMerge(); /** * Merge all given input matrices in parallel into the given output matrix. @@ -80,67 +71,6 @@ public abstract class ResultMerge implements Serializable * @param par degree of parallelism * @return output (merged) matrix */ - public abstract MatrixObject executeParallelMerge( int par ); - - protected void mergeWithoutComp( MatrixBlock out, MatrixBlock in, boolean appendOnly ) { - mergeWithoutComp(out, in, appendOnly, false); - } + public abstract T executeParallelMerge(int par); - protected void mergeWithoutComp( MatrixBlock out, MatrixBlock in, boolean appendOnly, boolean par ) { - //pass through to matrix block operations - if( _isAccum ) - out.binaryOperationsInPlace(PLUS, in); - else - out.merge(in, appendOnly, par); - } - - /** - * NOTE: append only not applicable for wiht compare because output must be populated with - * initial state of matrix - with append, this would result in duplicates. - * - * @param out output matrix block - * @param in input matrix block - * @param compare ? 
- */ - protected void mergeWithComp( MatrixBlock out, MatrixBlock in, DenseBlock compare ) - { - //Notes for result correctness: - // * Always iterate over entire block in order to compare all values - // (using sparse iterator would miss values set to 0) - // * Explicit NaN awareness because for cases were original matrix contains - // NaNs, since NaN != NaN, otherwise we would potentially overwrite results - // * For the case of accumulation, we add out += (new-old) to ensure correct results - // because all inputs have the old values replicated - - if( in.isEmptyBlock(false) ) { - if( _isAccum ) return; //nothing to do - for( int i=0; i<in.getNumRows(); i++ ) - for( int j=0; j<in.getNumColumns(); j++ ) - if( compare.get(i, j) != 0 ) - out.quickSetValue(i, j, 0); - } - else { //SPARSE/DENSE - int rows = in.getNumRows(); - int cols = in.getNumColumns(); - for( int i=0; i<rows; i++ ) - for( int j=0; j<cols; j++ ) { - double valOld = compare.get(i,j); - double valNew = in.quickGetValue(i,j); //input value - if( (valNew != valOld && !Double.isNaN(valNew) ) //for changed values - || Double.isNaN(valNew) != Double.isNaN(valOld) ) //NaN awareness - { - double value = !_isAccum ? 
valNew : - (out.quickGetValue(i, j) + (valNew - valOld)); - out.quickSetValue(i, j, value); - } - } - } - } - - protected long computeNonZeros( MatrixObject out, List<MatrixObject> in ) { - //sum of nnz of input (worker result) - output var existing nnz - long outNNZ = out.getDataCharacteristics().getNonZeros(); - return outNNZ - in.size() * outNNZ + in.stream() - .mapToLong(m -> m.getDataCharacteristics().getNonZeros()).sum(); - } } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeFrameLocalMemory.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeFrameLocalMemory.java new file mode 100644 index 0000000..cd2d99f --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeFrameLocalMemory.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.controlprogram.parfor; + +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.FrameObject; +import org.apache.sysds.runtime.matrix.data.FrameBlock; +import org.apache.sysds.runtime.util.UtilFunctions; + +public class ResultMergeFrameLocalMemory extends ResultMerge<FrameObject> +{ + private static final long serialVersionUID = 549739254879310540L; + + public ResultMergeFrameLocalMemory(FrameObject out, FrameObject[] in, String outputFilename, boolean accum) { + super( out, in, outputFilename, accum ); + } + + @Override + public FrameObject executeSerialMerge() + { + FrameObject foNew = null; //always create new matrix object (required for nested parallelism) + + if( LOG.isTraceEnabled() ) + LOG.trace("ResultMerge (local, in-memory): Execute serial merge for output " + +_output.hashCode()+" (fname="+_output.getFileName()+")"); + + try + { + //get old and new output frame blocks + FrameBlock outFB = _output.acquireRead(); + FrameBlock outFBNew = new FrameBlock(outFB); + + //create compare matrix if required (existing data in result) + FrameBlock compare = outFB; + int rlen = compare.getNumRows(); + int clen = compare.getNumColumns(); + + //serial merge all inputs + boolean flagMerged = false; + for( FrameObject in : _inputs ) + { + //check for empty inputs (no iterations executed) + if( in != null && in != _output ) + { + if( LOG.isTraceEnabled() ) + LOG.trace("ResultMergeFrame (local, in-memory): Merge input "+in.hashCode()+" (fname="+in.getFileName()+")"); + + //read/pin input_i + FrameBlock inMB = in.acquireRead(); + + //core merge + for(int j=0; j<clen; j++) { + ValueType vt = compare.getSchema()[j]; + for(int i=0; i<rlen; i++) { + Object val1 = compare.get(i, j); + Object val2 = inMB.get(i, j); + if( UtilFunctions.compareTo(vt, val1, val2) != 0 ) + outFBNew.set(i, j, val2); + } + } + + //unpin and clear in-memory 
input_i + in.release(); + in.clearData(); + flagMerged = true; + } + } + + //create output and release old output + foNew = flagMerged ? createNewFrameObject(_output, outFBNew) : _output; + _output.release(); + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + + //LOG.trace("ResultMerge (local, in-memory): Executed serial merge for output "+_output.getVarName()+" (fname="+_output.getFileName()+") in "+time.stop()+"ms"); + + return foNew; + } + + @Override + public FrameObject executeParallelMerge( int par ) { + if( LOG.isTraceEnabled() ) + LOG.trace("ResultMerge (local, in-memory): Execute parallel (par="+par+") " + + "merge for output "+_output.hashCode()+" (fname="+_output.getFileName()+")"); + return executeSerialMerge(); + } + + private static FrameObject createNewFrameObject( FrameObject foOld, FrameBlock dataNew ) { + FrameObject ret = new FrameObject(foOld); + ret.acquireModify(dataNew); + ret.release(); + return ret; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalAutomatic.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalAutomatic.java index 92ec8f9..ea5195d 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalAutomatic.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalAutomatic.java @@ -26,11 +26,11 @@ import org.apache.sysds.runtime.controlprogram.parfor.opt.OptimizerRuleBased; import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing; import org.apache.sysds.runtime.meta.DataCharacteristics; -public class ResultMergeLocalAutomatic extends ResultMerge +public class ResultMergeLocalAutomatic extends ResultMergeMatrix { private static final long serialVersionUID = 1600893100602101732L; - private ResultMerge _rm = null; + private ResultMergeMatrix _rm = null; public ResultMergeLocalAutomatic( MatrixObject out, MatrixObject[] in, String outputFilename, boolean accum ) { 
super( out, in, outputFilename, accum ); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalFile.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalFile.java index db3d741..441ba3e 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalFile.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalFile.java @@ -67,7 +67,7 @@ import java.util.Map.Entry; * NOTE: file merge typically used due to memory constraints - parallel merge would increase the memory * consumption again. */ -public class ResultMergeLocalFile extends ResultMerge +public class ResultMergeLocalFile extends ResultMergeMatrix { private static final long serialVersionUID = -6905893742840020489L; diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalMemory.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalMemory.java index 5c604dd..f422423 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalMemory.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalMemory.java @@ -39,7 +39,7 @@ import java.util.ArrayList; * * */ -public class ResultMergeLocalMemory extends ResultMerge +public class ResultMergeLocalMemory extends ResultMergeMatrix { private static final long serialVersionUID = -3543612508601511701L; diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMerge.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeMatrix.java similarity index 67% copy from src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMerge.java copy to src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeMatrix.java index 18b09a1..7d0776c 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMerge.java +++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeMatrix.java @@ -19,13 +19,9 @@ package org.apache.sysds.runtime.controlprogram.parfor; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.data.DenseBlock; -import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import java.io.Serializable; import java.util.List; @@ -36,52 +32,18 @@ import java.util.List; * These properties allow us to realize result merging in parallel without any synchronization. * */ -public abstract class ResultMerge implements Serializable +public abstract class ResultMergeMatrix extends ResultMerge<MatrixObject> implements Serializable { - //note: this class needs to be serializable to ensure that all attributes of - //ResultMergeRemoteSparkWCompare are included in the task closure - private static final long serialVersionUID = 2620430969346516677L; + private static final long serialVersionUID = 5319002218804570071L; - protected static final Log LOG = LogFactory.getLog(ResultMerge.class.getName()); - protected static final String NAME_SUFFIX = "_rm"; - protected static final BinaryOperator PLUS = InstructionUtils.parseBinaryOperator("+"); - - //inputs to result merge - protected MatrixObject _output = null; - protected MatrixObject[] _inputs = null; - protected String _outputFName = null; - protected boolean _isAccum = false; - - protected ResultMerge( ) { - //do nothing + public ResultMergeMatrix() { + super(); } - public ResultMerge( MatrixObject out, MatrixObject[] in, String outputFilename, boolean accum ) { - _output = out; - _inputs = in; - _outputFName = outputFilename; - _isAccum = accum; + public ResultMergeMatrix(MatrixObject out, MatrixObject[] in, String outputFilename, boolean accum) { + 
super(out, in, outputFilename, accum); } - /** - * Merge all given input matrices sequentially into the given output matrix. - * The required space in-memory is the size of the output matrix plus the size - * of one input matrix at a time. - * - * @return output (merged) matrix - */ - public abstract MatrixObject executeSerialMerge(); - - /** - * Merge all given input matrices in parallel into the given output matrix. - * The required space in-memory is the size of the output matrix plus the size - * of all input matrices. - * - * @param par degree of parallelism - * @return output (merged) matrix - */ - public abstract MatrixObject executeParallelMerge( int par ); - protected void mergeWithoutComp( MatrixBlock out, MatrixBlock in, boolean appendOnly ) { mergeWithoutComp(out, in, appendOnly, false); } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeRemoteSpark.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeRemoteSpark.java index 8a70ecf..6f33225 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeRemoteSpark.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeRemoteSpark.java @@ -44,7 +44,7 @@ import org.apache.sysds.utils.Statistics; import java.util.Arrays; -public class ResultMergeRemoteSpark extends ResultMerge +public class ResultMergeRemoteSpark extends ResultMergeMatrix { private static final long serialVersionUID = -6924566953903424820L; diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeRemoteSparkWCompare.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeRemoteSparkWCompare.java index a152c52..6b8d424 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeRemoteSparkWCompare.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeRemoteSparkWCompare.java @@ -31,7 +31,7 @@ import 
org.apache.sysds.runtime.util.DataConverter; import scala.Tuple2; -public class ResultMergeRemoteSparkWCompare extends ResultMerge implements PairFunction<Tuple2<MatrixIndexes,Tuple2<Iterable<MatrixBlock>,MatrixBlock>>, MatrixIndexes, MatrixBlock> +public class ResultMergeRemoteSparkWCompare extends ResultMergeMatrix implements PairFunction<Tuple2<MatrixIndexes,Tuple2<Iterable<MatrixBlock>,MatrixBlock>>, MatrixIndexes, MatrixBlock> { private static final long serialVersionUID = -5970805069405942836L; diff --git a/src/test/java/org/apache/sysds/test/component/parfor/ParForDependencyAnalysisTest.java b/src/test/java/org/apache/sysds/test/component/parfor/ParForDependencyAnalysisTest.java index 04f575a..cf7c71a 100644 --- a/src/test/java/org/apache/sysds/test/component/parfor/ParForDependencyAnalysisTest.java +++ b/src/test/java/org/apache/sysds/test/component/parfor/ParForDependencyAnalysisTest.java @@ -66,8 +66,8 @@ import org.apache.sysds.test.TestConfiguration; * 49a: dep, 49b: dep * * accumulators * 53a: no, 53b dep, 53c dep, 53d dep, 53e dep - * * lists - * 54a: no, 54b: no, 54c: dep, 54d: dep + * * lists/frames + * 54a: no, 54b: no, 54c: dep, 54d: dep, 54e: no-dep, 54f: dep * * negative loop increment * 55a: no, 55b: yes */ @@ -328,6 +328,12 @@ public class ParForDependencyAnalysisTest extends AutomatedTestBase public void testDependencyAnalysis54d() { runTest("parfor54d.dml", true); } @Test + public void testDependencyAnalysis54e() { runTest("parfor54e.dml", false); } + + @Test + public void testDependencyAnalysis54f() { runTest("parfor54f.dml", true); } + + @Test public void testDependencyAnalysis55a() { runTest("parfor55a.dml", false); } @Test diff --git a/src/test/java/org/apache/sysds/test/functions/parfor/misc/ParForListResultVarsTest.java b/src/test/java/org/apache/sysds/test/functions/parfor/misc/ParForListFrameResultVarsTest.java similarity index 75% rename from src/test/java/org/apache/sysds/test/functions/parfor/misc/ParForListResultVarsTest.java 
rename to src/test/java/org/apache/sysds/test/functions/parfor/misc/ParForListFrameResultVarsTest.java index fc952e1..a206781 100644 --- a/src/test/java/org/apache/sysds/test/functions/parfor/misc/ParForListResultVarsTest.java +++ b/src/test/java/org/apache/sysds/test/functions/parfor/misc/ParForListFrameResultVarsTest.java @@ -25,16 +25,18 @@ import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; -public class ParForListResultVarsTest extends AutomatedTestBase +public class ParForListFrameResultVarsTest extends AutomatedTestBase { private final static String TEST_DIR = "functions/parfor/"; private final static String TEST_NAME1 = "parfor_listResults"; - private final static String TEST_CLASS_DIR = TEST_DIR + ParForListResultVarsTest.class.getSimpleName() + "/"; + private final static String TEST_NAME2 = "parfor_frameResults"; + + private final static String TEST_CLASS_DIR = TEST_DIR + ParForListFrameResultVarsTest.class.getSimpleName() + "/"; @Override public void setUp() { - addTestConfiguration(TEST_NAME1, - new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"R"})); + addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{"R"})); } @Test @@ -47,11 +49,21 @@ public class ParForListResultVarsTest extends AutomatedTestBase runListResultVarTest(TEST_NAME1, 35, 10); } + @Test + public void testParForFrameResult1a() { + runListResultVarTest(TEST_NAME2, 2, 1); + } + + @Test + public void testParForFrameResult1b() { + runListResultVarTest(TEST_NAME2, 35, 10); + } + private void runListResultVarTest(String testName, int rows, int cols) { loadTestConfiguration(getTestConfiguration(testName)); String HOME = SCRIPT_DIR + TEST_DIR; - fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; + fullDMLScriptName = HOME + 
testName + ".dml"; programArgs = new String[]{"-explain","-args", String.valueOf(rows), String.valueOf(cols), output("R") }; diff --git a/src/test/scripts/component/parfor/parfor54e.dml b/src/test/scripts/component/parfor/parfor54e.dml new file mode 100644 index 0000000..70837e9 --- /dev/null +++ b/src/test/scripts/component/parfor/parfor54e.dml @@ -0,0 +1,26 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +A = rbind(as.frame("a"), as.frame("b"), as.frame("c")); +parfor( i in 1:nrow(A) ) + A[i,1] = as.frame(as.scalar(A[i,1])+"-"+i); +print(toString(A)); diff --git a/src/test/scripts/component/parfor/parfor54f.dml b/src/test/scripts/component/parfor/parfor54f.dml new file mode 100644 index 0000000..23bcf44 --- /dev/null +++ b/src/test/scripts/component/parfor/parfor54f.dml @@ -0,0 +1,26 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +A = rbind(as.frame("a"), as.frame("b"), as.frame("c")); +parfor( i in 1:nrow(A) ) + A[i,1] = as.frame(as.scalar(A[1,1])+"-"+i); +print(toString(A)); diff --git a/src/test/scripts/functions/parfor/parfor_frameResults.dml b/src/test/scripts/functions/parfor/parfor_frameResults.dml new file mode 100644 index 0000000..b1a54be --- /dev/null +++ b/src/test/scripts/functions/parfor/parfor_frameResults.dml @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +F = as.frame(matrix(0,7,1)); + +parfor(i in 1:nrow(F)) + F[i,1] = as.frame(rowMeans(as.matrix(F[i]))+i); + +R1 = matrix(0,0,1) +for(i in 1:length(F)) + R1 = rbind(R1, as.matrix(F[i,1])); + +R = as.matrix(sum(R1==seq(1,7))); +write(R, $3);
