This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 33de453e87523253c6a1e853804a2bbc40021f04 Author: Matthias Boehm <[email protected]> AuthorDate: Thu Aug 10 17:32:45 2023 +0200 [SYSTEMDS-3606] Performance shuffle-based spark quaternary operations This patch significantly improves the performance of shuffle-based Spark quaternary operations, where more than one input is an RDD (too large to broadcast). Instead of replicating the factor blocks, we now use custom join keys, enabling Spark to perform more efficient 1:M joins. With appropriate function abstractions, the implementation also became simpler and thus easier to maintain. In the scenario mentioned in the JIRA task, the original implementation had not finished any task of the first shuffle phase after >9000s, while with the new implementation the entire script (with two shuffle-based quaternary operators) finishes in 1276s. Here are the stats: SystemDS Statistics: Total elapsed time: 1276.917 sec. Total compilation time: 2.338 sec. Total execution time: 1274.578 sec. Number of compiled Spark inst: 4. Number of executed Spark inst: 4. Cache hits (Mem/Li/WB/FS/HDFS): 13/2/0/1/0. Cache writes (Li/WB/FS/HDFS): 4/6/4/1. Cache times (ACQr/m, RLS, EXP): 1209.517/0.001/10.926/8.589 sec. HOP DAGs recompiled (PRED, SB): 0/1. HOP DAGs recompile time: 0.006 sec. Functions recompiled: 1. Functions recompile time: 0.011 sec. Spark ctx create time (lazy): 19.302 sec. Spark trans counts (par,bc,col):0/3/1. Spark trans times (par,bc,col): 0.000/13.671/644.719 secs. Spark async. count (pf,bc,op): 0/0/0. Total JIT compile time: 73.677 sec. Total JVM GC count: 188. Total JVM GC time: 23.182 sec. 
Heavy hitter instructions: 1 m_pnmf 714.304 1 2 r' 653.012 5 3 uak+ 560.027 2 4 sp_redwdivmm 42.446 2 5 rand 9.414 4 6 * 3.544 1 7 / 3.491 1 8 uack+ 3.466 1 9 uark+ 2.146 1 10 rmvar 0.246 15 --- .../spark/QuaternarySPInstruction.java | 174 +++++++-------------- 1 file changed, 59 insertions(+), 115 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/QuaternarySPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/QuaternarySPInstruction.java index 9c9a063d31..dbb71c724f 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/QuaternarySPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/QuaternarySPInstruction.java @@ -19,6 +19,7 @@ package org.apache.sysds.runtime.instructions.spark; +import org.apache.commons.lang3.ArrayUtils; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.api.java.function.PairFunction; @@ -44,7 +45,6 @@ import org.apache.sysds.runtime.instructions.cp.DoubleObject; import org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator; import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast; import org.apache.sysds.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction; -import org.apache.sysds.runtime.instructions.spark.functions.ReplicateBlockFunction; import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; @@ -198,11 +198,6 @@ public class QuaternarySPInstruction extends ComputationSPInstruction { JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable( input1.getName() ); JavaPairRDD<MatrixIndexes, MatrixBlock> out = null; - DataCharacteristics inMc = sec.getDataCharacteristics( input1.getName() ); - long rlen = inMc.getRows(); - long clen = 
inMc.getCols(); - int blen = inMc.getBlocksize(); - //pre-filter empty blocks (ultra-sparse matrices) for full aggregates //(map/redwsloss, map/redwcemm); safe because theses ops produce a scalar if( qop.wtype1 != null || qop.wtype4 != null ) { @@ -237,42 +232,25 @@ public class QuaternarySPInstruction extends ComputationSPInstruction { JavaPairRDD<MatrixIndexes,MatrixBlock> inW = (qop.hasFourInputs() && !_input4.isLiteral()) ? sec.getBinaryMatrixBlockRDDHandleForVariable( _input4.getName() ) : null; - //preparation of transposed and replicated U - if( inU != null ) - inU = inU.flatMapToPair(new ReplicateBlockFunction(clen, blen, true)); - - //preparation of transposed and replicated V - if( inV != null ) - inV = inV.mapToPair(new TransposeFactorIndexesFunction()) - .flatMapToPair(new ReplicateBlockFunction(rlen, blen, false)); + //join X and W on original indexes if W existing + JavaPairRDD<MatrixIndexes,MatrixBlock[]> tmp = (inW != null) ? + in.join(inW).mapToPair(new ToArray()) : + in.mapValues(mb -> new MatrixBlock[]{mb, null}); - //functions calls w/ two rdd inputs - if( inU != null && inV == null && inW == null ) - out = in.join(inU) - .mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2)); - else if( inU == null && inV != null && inW == null ) - out = in.join(inV) - .mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2)); - else if( inU == null && inV == null && inW != null ) - out = in.join(inW) - .mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2)); - //function calls w/ three rdd inputs - else if( inU != null && inV != null && inW == null ) - out = in.join(inU).join(inV) - .mapToPair(new RDDQuaternaryFunction3(qop, bc1, bc2)); - else if( inU != null && inV == null && inW != null ) - out = in.join(inU).join(inW) - .mapToPair(new RDDQuaternaryFunction3(qop, bc1, bc2)); - else if( inU == null && inV != null && inW != null ) - out = in.join(inV).join(inW) - .mapToPair(new RDDQuaternaryFunction3(qop, bc1, bc2)); - else if( inU == null && inV == null && inW == 
null ) { - out = in.mapPartitionsToPair(new RDDQuaternaryFunction1(qop, bc1, bc2), false); - } - //function call w/ four rdd inputs - else //need keys in case of wdivmm - out = in.join(inU).join(inV).join(inW) - .mapToPair(new RDDQuaternaryFunction4(qop)); + //join lhs U on row-block indexes of X/W + tmp = ( inU != null ) ? + tmp.mapToPair(new ExtractIndexWith(true)) + .join(inU.mapToPair(new ExtractIndex(true))).mapToPair(new Unpack()) : + tmp.mapValues(mb -> ArrayUtils.add(mb, null)); + + //join rhs V on column-block indexes X/W (note V transposed input, so rows) + tmp = ( inV != null ) ? + tmp.mapToPair(new ExtractIndexWith(false)) + .join(inV.mapToPair(new ExtractIndex(true))).mapToPair(new Unpack()) : + tmp.mapValues(mb -> ArrayUtils.add(mb, null)); + + //execute quaternary block operations on joined inputs + out = tmp.mapToPair(new RDDQuaternaryFunction2(qop, bc1, bc2)); //keep variable names for lineage maintenance if( inU == null ) bcVars.add(input2.getName()); else rddVars.add(input2.getName()); @@ -374,12 +352,11 @@ public class QuaternarySPInstruction extends ComputationSPInstruction { protected Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> arg) { MatrixIndexes ixIn = arg._1(); MatrixBlock blkIn = arg._2(); - MatrixBlock blkOut = new MatrixBlock(); MatrixBlock mbU = _pmU.getBlock((int)ixIn.getRowIndex(), 1); MatrixBlock mbV = _pmV.getBlock((int)ixIn.getColumnIndex(), 1); //execute core operation - blkIn.quaternaryOperations(_qop, mbU, mbV, null, blkOut); + MatrixBlock blkOut = blkIn.quaternaryOperations(_qop, mbU, mbV, null, new MatrixBlock()); //create return tuple MatrixIndexes ixOut = createOutputIndexes(ixIn); @@ -389,7 +366,7 @@ public class QuaternarySPInstruction extends ComputationSPInstruction { } private static class RDDQuaternaryFunction2 extends RDDQuaternaryBaseFunction //two rdd input - implements PairFunction<Tuple2<MatrixIndexes, Tuple2<MatrixBlock,MatrixBlock>>, MatrixIndexes, MatrixBlock> + implements 
PairFunction<Tuple2<MatrixIndexes, MatrixBlock[]>, MatrixIndexes, MatrixBlock> { private static final long serialVersionUID = 7493974462943080693L; @@ -398,17 +375,15 @@ public class QuaternarySPInstruction extends ComputationSPInstruction { } @Override - public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> arg0) { + public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock[]> arg0) { MatrixIndexes ixIn = arg0._1(); - MatrixBlock blkIn1 = arg0._2()._1(); - MatrixBlock blkIn2 = arg0._2()._2(); - MatrixBlock blkOut = new MatrixBlock(); - MatrixBlock mbU = (_pmU!=null)?_pmU.getBlock((int)ixIn.getRowIndex(), 1) : blkIn2; - MatrixBlock mbV = (_pmV!=null)?_pmV.getBlock((int)ixIn.getColumnIndex(), 1) : blkIn2; - MatrixBlock mbW = (_qop.hasFourInputs()) ? blkIn2 : null; + MatrixBlock[] blks = arg0._2(); + MatrixBlock mbU = (_pmU!=null)?_pmU.getBlock((int)ixIn.getRowIndex(), 1) : blks[2]; + MatrixBlock mbV = (_pmV!=null)?_pmV.getBlock((int)ixIn.getColumnIndex(), 1) : blks[3]; + MatrixBlock mbW = (_qop.hasFourInputs()) ? 
blks[1] : null; //execute core operation - blkIn1.quaternaryOperations(_qop, mbU, mbV, mbW, blkOut); + MatrixBlock blkOut = blks[0].quaternaryOperations(_qop, mbU, mbV, mbW, new MatrixBlock()); //create return tuple MatrixIndexes ixOut = createOutputIndexes(ixIn); @@ -416,82 +391,51 @@ public class QuaternarySPInstruction extends ComputationSPInstruction { } } - private static class RDDQuaternaryFunction3 extends RDDQuaternaryBaseFunction //three rdd input - implements PairFunction<Tuple2<MatrixIndexes, Tuple2<Tuple2<MatrixBlock,MatrixBlock>,MatrixBlock>>, MatrixIndexes, MatrixBlock> - { - private static final long serialVersionUID = -2294086455843773095L; - - public RDDQuaternaryFunction3( QuaternaryOperator qop, PartitionedBroadcast<MatrixBlock> bcU, PartitionedBroadcast<MatrixBlock> bcV ) { - super(qop, bcU, bcV); + private static class ExtractIndex implements PairFunction<Tuple2<MatrixIndexes,MatrixBlock>, Long, MatrixBlock> { + private static final long serialVersionUID = -6542246824481788376L; + private final boolean _row; + public ExtractIndex(boolean row) { + _row = row; } - @Override - public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock>> arg0) { - MatrixIndexes ixIn = arg0._1(); - MatrixBlock blkIn1 = arg0._2()._1()._1(); - MatrixBlock blkIn2 = arg0._2()._1()._2(); - MatrixBlock blkIn3 = arg0._2()._2(); - MatrixBlock blkOut = new MatrixBlock(); - MatrixBlock mbU = (_pmU!=null)?_pmU.getBlock((int)ixIn.getRowIndex(), 1) : blkIn2; - MatrixBlock mbV = (_pmV!=null)?_pmV.getBlock((int)ixIn.getColumnIndex(), 1) : - (_pmU!=null)? blkIn2 : blkIn3; - MatrixBlock mbW = (_qop.hasFourInputs())? 
blkIn3 : null; - - //execute core operation - blkIn1.quaternaryOperations(_qop, mbU, mbV, mbW, blkOut); - - //create return tuple - MatrixIndexes ixOut = createOutputIndexes(ixIn); - return new Tuple2<>(ixOut, blkOut); + public Tuple2<Long, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg) throws Exception { + return new Tuple2<>(_row?arg._1().getRowIndex():arg._1().getColumnIndex(), arg._2()); } } - /** - * Note: never called for wsigmoid/wdivmm (only wsloss) - */ - private static class RDDQuaternaryFunction4 extends RDDQuaternaryBaseFunction //four rdd input - implements PairFunction<Tuple2<MatrixIndexes,Tuple2<Tuple2<Tuple2<MatrixBlock,MatrixBlock>,MatrixBlock>,MatrixBlock>>,MatrixIndexes,MatrixBlock> - { - private static final long serialVersionUID = 7328911771600289250L; - - public RDDQuaternaryFunction4( QuaternaryOperator qop ) { - super(qop, null, null); + private static class ExtractIndexWith implements PairFunction<Tuple2<MatrixIndexes,MatrixBlock[]>, Long, Tuple2<MatrixIndexes,MatrixBlock[]>> { + private static final long serialVersionUID = -966212318512764461L; + private final boolean _row; + public ExtractIndexWith(boolean row) { + _row = row; } + @Override + public Tuple2<Long, Tuple2<MatrixIndexes, MatrixBlock[]>> call(Tuple2<MatrixIndexes, MatrixBlock[]> arg) + throws Exception + { + return new Tuple2<>(_row?arg._1().getRowIndex():arg._1().getColumnIndex(), arg); + } + } + + private static class ToArray implements PairFunction<Tuple2<MatrixIndexes,Tuple2<MatrixBlock,MatrixBlock>>, MatrixIndexes, MatrixBlock[]> { + private static final long serialVersionUID = -4856316007590144978L; @Override - public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, Tuple2<Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock>, MatrixBlock>> arg0) + public Tuple2<MatrixIndexes, MatrixBlock[]> call(Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> arg) + throws Exception { - MatrixIndexes ixIn1 = arg0._1(); - MatrixBlock blkIn1 = 
arg0._2()._1()._1()._1(); - MatrixBlock mbU = arg0._2()._1()._1()._2(); - MatrixBlock mbV = arg0._2()._1()._2(); - MatrixBlock mbW = arg0._2()._2(); - MatrixBlock blkOut = new MatrixBlock(); - - //execute core operation - blkIn1.quaternaryOperations(_qop, mbU, mbV, mbW, blkOut); - - //create return tuple - MatrixIndexes ixOut = createOutputIndexes(ixIn1); - return new Tuple2<>(ixOut, blkOut); + return new Tuple2<>(arg._1(), new MatrixBlock[]{arg._2()._1(),arg._2()._2()}); } } - private static class TransposeFactorIndexesFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> - { - private static final long serialVersionUID = -2571724736131823708L; - + private static class Unpack implements PairFunction<Tuple2<Long, Tuple2<Tuple2<MatrixIndexes,MatrixBlock[]>,MatrixBlock>>, MatrixIndexes, MatrixBlock[]> { + private static final long serialVersionUID = 3812660351709830714L; @Override - public Tuple2<MatrixIndexes, MatrixBlock> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 ) { - MatrixIndexes ixIn = arg0._1(); - MatrixBlock blkIn = arg0._2(); - - //swap the matrix indexes - MatrixIndexes ixOut = new MatrixIndexes(ixIn.getColumnIndex(), ixIn.getRowIndex()); - MatrixBlock blkOut = new MatrixBlock(blkIn); - - //output new tuple - return new Tuple2<>(ixOut,blkOut); + public Tuple2<MatrixIndexes, MatrixBlock[]> call( + Tuple2<Long, Tuple2<Tuple2<MatrixIndexes, MatrixBlock[]>, MatrixBlock>> arg) throws Exception + { + return new Tuple2<>(arg._2()._1()._1(), //matrix indexes + ArrayUtils.addAll(arg._2()._1()._2(), arg._2()._2())); //array of matrix blocks } } }
