Repository: systemml Updated Branches: refs/heads/master 68a7b44b5 -> a0b0e80e9
[SYSTEMML-2236] Improved spark cpmm (partitioning-preserving case) This patch adds a special case to the spark cpmm matrix multiply operator for matrix-vector multiplication where the matrix input is already hash-partitioned. In this case, we use a different approach that retains the original matrix keys and thus the partitioning, which avoids an unnecessary shuffle and additional stages. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/41526805 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/41526805 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/41526805 Branch: refs/heads/master Commit: 41526805241eafa1c454df830f1512b20d98dd2a Parents: 68a7b44 Author: Matthias Boehm <mboe...@gmail.com> Authored: Fri Apr 6 22:35:26 2018 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Fri Apr 6 22:35:26 2018 -0700 ---------------------------------------------------------------------- .../instructions/spark/CpmmSPInstruction.java | 97 +++++++++++++++----- .../functions/FilterNonEmptyBlocksFunction.java | 8 +- .../FilterNonEmptyBlocksFunction2.java | 34 +++++++ 3 files changed, 107 insertions(+), 32 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/41526805/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java index 770f6fb..5c98964 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java @@ -20,6 +20,8 @@ package org.apache.sysml.runtime.instructions.spark; import org.apache.spark.api.java.JavaPairRDD; +import 
org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.PairFunction; import scala.Tuple2; @@ -30,9 +32,12 @@ import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.functionobjects.Multiply; import org.apache.sysml.runtime.functionobjects.Plus; +import org.apache.sysml.runtime.functionobjects.SwapIndex; import org.apache.sysml.runtime.instructions.InstructionUtils; import org.apache.sysml.runtime.instructions.cp.CPOperand; import org.apache.sysml.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction; +import org.apache.sysml.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction2; +import org.apache.sysml.runtime.instructions.spark.functions.ReorgMapFunction; import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils; import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; @@ -42,6 +47,7 @@ import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue; import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator; import org.apache.sysml.runtime.matrix.operators.AggregateOperator; import org.apache.sysml.runtime.matrix.operators.Operator; +import org.apache.sysml.runtime.matrix.operators.ReorgOperator; /** * Cpmm: cross-product matrix multiplication operation (distributed matrix multiply @@ -93,39 +99,57 @@ public class CpmmSPInstruction extends BinarySPInstruction { in2 = in2.filter(new FilterNonEmptyBlocksFunction()); } - //compute preferred join degree of parallelism - int numPreferred = getPreferredParJoin(mc1, mc2, in1.getNumPartitions(), in2.getNumPartitions()); - int numPartJoin = Math.min(getMaxParJoin(mc1, mc2), numPreferred); - - //process core cpmm matrix multiply - JavaPairRDD<Long, IndexedMatrixValue> tmp1 = in1.mapToPair(new 
CpmmIndexFunction(true)); - JavaPairRDD<Long, IndexedMatrixValue> tmp2 = in2.mapToPair(new CpmmIndexFunction(false)); - JavaPairRDD<MatrixIndexes,MatrixBlock> out = tmp1 - .join(tmp2, numPartJoin) // join over common dimension - .mapToPair(new CpmmMultiplyFunction()); // compute block multiplications - - //process cpmm aggregation and handle outputs - if( _aggtype == SparkAggType.SINGLE_BLOCK ) { - //prune empty blocks and aggregate all results - out = out.filter(new FilterNonEmptyBlocksFunction()); + if( SparkUtils.isHashPartitioned(in1) //ZIPMM-like CPMM + && mc1.getNumRowBlocks()==1 && mc2.getCols()==1 ) { + //note: if the major input is hash-partitioned and it's a matrix-vector + //multiply, avoid the index mapping to preserve the partitioning similar + //to a ZIPMM but with different transpose characteristics + JavaRDD<MatrixBlock> out = in1 + .join(in2.mapToPair(new ReorgMapFunction("r'"))) + .values().map(new Cpmm2MultiplyFunction()) + .filter(new FilterNonEmptyBlocksFunction2()); MatrixBlock out2 = RDDAggregateUtils.sumStable(out); //put output block into symbol table (no lineage because single block) //this also includes implicit maintenance of matrix characteristics sec.setMatrixOutput(output.getName(), out2, getExtendedOpcode()); } - else { //DEFAULT: MULTI_BLOCK - if( !_outputEmptyBlocks ) - out = out.filter(new FilterNonEmptyBlocksFunction()); - out = RDDAggregateUtils.sumByKeyStable(out, false); + else //GENERAL CPMM + { + //compute preferred join degree of parallelism + int numPreferred = getPreferredParJoin(mc1, mc2, in1.getNumPartitions(), in2.getNumPartitions()); + int numPartJoin = Math.min(getMaxParJoin(mc1, mc2), numPreferred); - //put output RDD handle into symbol table - sec.setRDDHandleForVariable(output.getName(), out); - sec.addLineageRDD(output.getName(), input1.getName()); - sec.addLineageRDD(output.getName(), input2.getName()); + //process core cpmm matrix multiply + JavaPairRDD<Long, IndexedMatrixValue> tmp1 = in1.mapToPair(new 
CpmmIndexFunction(true)); + JavaPairRDD<Long, IndexedMatrixValue> tmp2 = in2.mapToPair(new CpmmIndexFunction(false)); + JavaPairRDD<MatrixIndexes,MatrixBlock> out = tmp1 + .join(tmp2, numPartJoin) // join over common dimension + .mapToPair(new CpmmMultiplyFunction()); // compute block multiplications - //update output statistics if not inferred - updateBinaryMMOutputMatrixCharacteristics(sec, true); + //process cpmm aggregation and handle outputs + if( _aggtype == SparkAggType.SINGLE_BLOCK ) { + //prune empty blocks and aggregate all results + out = out.filter(new FilterNonEmptyBlocksFunction()); + MatrixBlock out2 = RDDAggregateUtils.sumStable(out); + + //put output block into symbol table (no lineage because single block) + //this also includes implicit maintenance of matrix characteristics + sec.setMatrixOutput(output.getName(), out2, getExtendedOpcode()); + } + else { //DEFAULT: MULTI_BLOCK + if( !_outputEmptyBlocks ) + out = out.filter(new FilterNonEmptyBlocksFunction()); + out = RDDAggregateUtils.sumByKeyStable(out, false); + + //put output RDD handle into symbol table + sec.setRDDHandleForVariable(output.getName(), out); + sec.addLineageRDD(output.getName(), input1.getName()); + sec.addLineageRDD(output.getName(), input2.getName()); + + //update output statistics if not inferred + updateBinaryMMOutputMatrixCharacteristics(sec, true); + } } } @@ -190,4 +214,27 @@ public class CpmmSPInstruction extends BinarySPInstruction { return new Tuple2<>( ixOut, blkOut ); } } + + private static class Cpmm2MultiplyFunction implements Function<Tuple2<MatrixBlock,MatrixBlock>, MatrixBlock> + { + private static final long serialVersionUID = -3718880362385713416L; + private AggregateBinaryOperator _op = null; + private ReorgOperator _rop = null; + + @Override + public MatrixBlock call(Tuple2<MatrixBlock, MatrixBlock> arg0) throws Exception { + //lazy operator construction + if( _op == null ) { + AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); + _op = 
new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg); + _rop = new ReorgOperator(SwapIndex.getSwapIndexFnObject()); + } + //prepare inputs, including transpose of right-hand-side + MatrixBlock in1 = arg0._1(); + MatrixBlock in2 = (MatrixBlock)arg0._2() + .reorgOperations(_rop, new MatrixBlock(), 0, 0, 0); + //core block matrix multiplication + return in1.aggregateBinaryOperations(in1, in2, new MatrixBlock(), _op); + } + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/41526805/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/FilterNonEmptyBlocksFunction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/FilterNonEmptyBlocksFunction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/FilterNonEmptyBlocksFunction.java index 49355c9..4f545c2 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/FilterNonEmptyBlocksFunction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/FilterNonEmptyBlocksFunction.java @@ -28,20 +28,14 @@ import org.apache.sysml.runtime.matrix.data.MatrixIndexes; public class FilterNonEmptyBlocksFunction implements Function<Tuple2<MatrixIndexes,MatrixBlock>, Boolean> { - private static final long serialVersionUID = -8856829325565589854L; @Override - public Boolean call(Tuple2<MatrixIndexes, MatrixBlock> arg0) - throws Exception - { + public Boolean call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception { //always keep 1-1 block in order to prevent empty rdds boolean ix1 = (arg0._1().getRowIndex()==1 && arg0._1().getColumnIndex()==1); - //returns true for non-empty matrix blocks return !arg0._2().isEmptyBlock(false) || ix1; } - - } http://git-wip-us.apache.org/repos/asf/systemml/blob/41526805/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/FilterNonEmptyBlocksFunction2.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/FilterNonEmptyBlocksFunction2.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/FilterNonEmptyBlocksFunction2.java new file mode 100644 index 0000000..531f7f6 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/FilterNonEmptyBlocksFunction2.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.instructions.spark.functions; + +import org.apache.spark.api.java.function.Function; + +import org.apache.sysml.runtime.matrix.data.MatrixBlock; + +public class FilterNonEmptyBlocksFunction2 implements Function<MatrixBlock, Boolean> +{ + private static final long serialVersionUID = -8435900761521598692L; + + @Override + public Boolean call(MatrixBlock arg0) throws Exception { + return !arg0.isEmptyBlock(false); + } +}