Repository: systemml
Updated Branches:
  refs/heads/master 68a7b44b5 -> a0b0e80e9


[SYSTEMML-2236] Improved spark cpmm (partitioning-preserving case)

This patch adds a special case to the Spark cpmm matrix multiply
operator for matrix-vector multiplication where the matrix input is
already hash-partitioned. In this case, we use a different approach
that retains the original matrix keys and thus the existing
partitioning, which avoids an unnecessary shuffle and additional stages.

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/41526805
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/41526805
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/41526805

Branch: refs/heads/master
Commit: 41526805241eafa1c454df830f1512b20d98dd2a
Parents: 68a7b44
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Fri Apr 6 22:35:26 2018 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Fri Apr 6 22:35:26 2018 -0700

----------------------------------------------------------------------
 .../instructions/spark/CpmmSPInstruction.java   | 97 +++++++++++++++-----
 .../functions/FilterNonEmptyBlocksFunction.java |  8 +-
 .../FilterNonEmptyBlocksFunction2.java          | 34 +++++++
 3 files changed, 107 insertions(+), 32 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/41526805/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java
index 770f6fb..5c98964 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java
@@ -20,6 +20,8 @@
 package org.apache.sysml.runtime.instructions.spark;
 
 import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.PairFunction;
 
 import scala.Tuple2;
@@ -30,9 +32,12 @@ import 
org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.runtime.functionobjects.Multiply;
 import org.apache.sysml.runtime.functionobjects.Plus;
+import org.apache.sysml.runtime.functionobjects.SwapIndex;
 import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
 import 
org.apache.sysml.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction;
+import 
org.apache.sysml.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction2;
+import org.apache.sysml.runtime.instructions.spark.functions.ReorgMapFunction;
 import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
 import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
@@ -42,6 +47,7 @@ import 
org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
 import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator;
 import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
 import org.apache.sysml.runtime.matrix.operators.Operator;
+import org.apache.sysml.runtime.matrix.operators.ReorgOperator;
 
 /**
  * Cpmm: cross-product matrix multiplication operation (distributed matrix 
multiply
@@ -93,39 +99,57 @@ public class CpmmSPInstruction extends BinarySPInstruction {
                        in2 = in2.filter(new FilterNonEmptyBlocksFunction());
                }
                
-               //compute preferred join degree of parallelism
-               int numPreferred = getPreferredParJoin(mc1, mc2, 
in1.getNumPartitions(), in2.getNumPartitions());
-               int numPartJoin = Math.min(getMaxParJoin(mc1, mc2), 
numPreferred);
-               
-               //process core cpmm matrix multiply 
-               JavaPairRDD<Long, IndexedMatrixValue> tmp1 = in1.mapToPair(new 
CpmmIndexFunction(true));
-               JavaPairRDD<Long, IndexedMatrixValue> tmp2 = in2.mapToPair(new 
CpmmIndexFunction(false));
-               JavaPairRDD<MatrixIndexes,MatrixBlock> out = tmp1
-                       .join(tmp2, numPartJoin)                // join over 
common dimension
-                       .mapToPair(new CpmmMultiplyFunction()); // compute 
block multiplications
-               
-               //process cpmm aggregation and handle outputs
-               if( _aggtype == SparkAggType.SINGLE_BLOCK ) {
-                       //prune empty blocks and aggregate all results
-                       out = out.filter(new FilterNonEmptyBlocksFunction());
+               if( SparkUtils.isHashPartitioned(in1) //ZIPMM-like CPMM
+                       && mc1.getNumRowBlocks()==1 && mc2.getCols()==1 ) {
+                       //note: if the major input is hash-partitioned and it's 
a matrix-vector
+                       //multiply, avoid the index mapping to preserve the 
partitioning similar
+                       //to a ZIPMM but with different transpose 
characteristics
+                       JavaRDD<MatrixBlock> out = in1
+                               .join(in2.mapToPair(new ReorgMapFunction("r'")))
+                               .values().map(new Cpmm2MultiplyFunction())
+                               .filter(new FilterNonEmptyBlocksFunction2());
                        MatrixBlock out2 = RDDAggregateUtils.sumStable(out);
                        
                        //put output block into symbol table (no lineage 
because single block)
                        //this also includes implicit maintenance of matrix 
characteristics
                        sec.setMatrixOutput(output.getName(), out2, 
getExtendedOpcode());
                }
-               else { //DEFAULT: MULTI_BLOCK
-                       if( !_outputEmptyBlocks )
-                               out = out.filter(new 
FilterNonEmptyBlocksFunction());
-                       out = RDDAggregateUtils.sumByKeyStable(out, false);
+               else //GENERAL CPMM
+               {
+                       //compute preferred join degree of parallelism
+                       int numPreferred = getPreferredParJoin(mc1, mc2, 
in1.getNumPartitions(), in2.getNumPartitions());
+                       int numPartJoin = Math.min(getMaxParJoin(mc1, mc2), 
numPreferred);
                        
-                       //put output RDD handle into symbol table
-                       sec.setRDDHandleForVariable(output.getName(), out);
-                       sec.addLineageRDD(output.getName(), input1.getName());
-                       sec.addLineageRDD(output.getName(), input2.getName());
+                       //process core cpmm matrix multiply 
+                       JavaPairRDD<Long, IndexedMatrixValue> tmp1 = 
in1.mapToPair(new CpmmIndexFunction(true));
+                       JavaPairRDD<Long, IndexedMatrixValue> tmp2 = 
in2.mapToPair(new CpmmIndexFunction(false));
+                       JavaPairRDD<MatrixIndexes,MatrixBlock> out = tmp1
+                               .join(tmp2, numPartJoin)                // join 
over common dimension
+                               .mapToPair(new CpmmMultiplyFunction()); // 
compute block multiplications
                        
-                       //update output statistics if not inferred
-                       updateBinaryMMOutputMatrixCharacteristics(sec, true);
+                       //process cpmm aggregation and handle outputs
+                       if( _aggtype == SparkAggType.SINGLE_BLOCK ) {
+                               //prune empty blocks and aggregate all results
+                               out = out.filter(new 
FilterNonEmptyBlocksFunction());
+                               MatrixBlock out2 = 
RDDAggregateUtils.sumStable(out);
+                               
+                               //put output block into symbol table (no 
lineage because single block)
+                               //this also includes implicit maintenance of 
matrix characteristics
+                               sec.setMatrixOutput(output.getName(), out2, 
getExtendedOpcode());
+                       }
+                       else { //DEFAULT: MULTI_BLOCK
+                               if( !_outputEmptyBlocks )
+                                       out = out.filter(new 
FilterNonEmptyBlocksFunction());
+                               out = RDDAggregateUtils.sumByKeyStable(out, 
false);
+                               
+                               //put output RDD handle into symbol table
+                               sec.setRDDHandleForVariable(output.getName(), 
out);
+                               sec.addLineageRDD(output.getName(), 
input1.getName());
+                               sec.addLineageRDD(output.getName(), 
input2.getName());
+                               
+                               //update output statistics if not inferred
+                               updateBinaryMMOutputMatrixCharacteristics(sec, 
true);
+                       }
                }
        }
        
@@ -190,4 +214,27 @@ public class CpmmSPInstruction extends BinarySPInstruction 
{
                        return new Tuple2<>( ixOut, blkOut );
                }
        }
+       
+       private static class Cpmm2MultiplyFunction implements 
Function<Tuple2<MatrixBlock,MatrixBlock>, MatrixBlock>
+       {
+               private static final long serialVersionUID = 
-3718880362385713416L;
+               private AggregateBinaryOperator _op = null;
+               private ReorgOperator _rop = null;
+               
+               @Override
+               public MatrixBlock call(Tuple2<MatrixBlock, MatrixBlock> arg0) 
throws Exception {
+                        //lazy operator construction
+                       if( _op == null ) {
+                               AggregateOperator agg = new 
AggregateOperator(0, Plus.getPlusFnObject());
+                               _op = new 
AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
+                               _rop = new 
ReorgOperator(SwapIndex.getSwapIndexFnObject());
+                       }
+                       //prepare inputs, including transpose of right-hand-side
+                       MatrixBlock in1 = arg0._1();
+                       MatrixBlock in2 = (MatrixBlock)arg0._2()
+                               .reorgOperations(_rop, new MatrixBlock(), 0, 0, 
0);
+                       //core block matrix multiplication
+                       return in1.aggregateBinaryOperations(in1, in2, new 
MatrixBlock(), _op);
+               }
+       }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/41526805/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/FilterNonEmptyBlocksFunction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/FilterNonEmptyBlocksFunction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/FilterNonEmptyBlocksFunction.java
index 49355c9..4f545c2 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/FilterNonEmptyBlocksFunction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/FilterNonEmptyBlocksFunction.java
@@ -28,20 +28,14 @@ import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
 
 public class FilterNonEmptyBlocksFunction implements 
Function<Tuple2<MatrixIndexes,MatrixBlock>, Boolean> 
 {
-       
        private static final long serialVersionUID = -8856829325565589854L;
 
        @Override
-       public Boolean call(Tuple2<MatrixIndexes, MatrixBlock> arg0)
-               throws Exception 
-       {
+       public Boolean call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws 
Exception {
                //always keep 1-1 block in order to prevent empty rdds
                boolean ix1 = (arg0._1().getRowIndex()==1 
                                && arg0._1().getColumnIndex()==1);
-               
                //returns true for non-empty matrix blocks
                return !arg0._2().isEmptyBlock(false) || ix1;
        }
-       
-
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/41526805/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/FilterNonEmptyBlocksFunction2.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/FilterNonEmptyBlocksFunction2.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/FilterNonEmptyBlocksFunction2.java
new file mode 100644
index 0000000..531f7f6
--- /dev/null
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/FilterNonEmptyBlocksFunction2.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.instructions.spark.functions;
+
+import org.apache.spark.api.java.function.Function;
+
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+
+public class FilterNonEmptyBlocksFunction2 implements Function<MatrixBlock, 
Boolean> 
+{
+       private static final long serialVersionUID = -8435900761521598692L;
+
+       @Override
+       public Boolean call(MatrixBlock arg0) throws Exception {
+               return !arg0.isEmptyBlock(false);
+       }
+}

Reply via email to