[SYSTEMML-2222] Empty block filtering for spark cpmm, rshape, nnz

This patch improves the performance of distributed operations over
ultra-sparse matrices by compiling safe filtering conditions for empty
blocks into Spark cpmm and reshape operations, and by always filtering
empty blocks when computing nnz. On an end-to-end application over
ultra-sparse graph datasets, this patch reduced the total application
runtime from 1139s to 451s.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/696e7921
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/696e7921
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/696e7921

Branch: refs/heads/master
Commit: 696e79218db8e20a8a2c7ea62cb3558037c9fc0e
Parents: 022e046
Author: Matthias Boehm <[email protected]>
Authored: Sat Mar 31 00:55:09 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Sat Mar 31 00:59:20 2018 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/AggBinaryOp.java |  6 ++-
 .../java/org/apache/sysml/hops/ReorgOp.java     |  6 ++-
 src/main/java/org/apache/sysml/lops/MMCJ.java   | 10 +++--
 .../java/org/apache/sysml/lops/Transform.java   | 11 +++++-
 .../mr/MatrixReshapeMRInstruction.java          |  2 +-
 .../instructions/spark/CpmmSPInstruction.java   | 17 ++++++---
 .../spark/MatrixReshapeSPInstruction.java       | 40 ++++++++++++--------
 .../spark/functions/RecomputeNnzFunction.java   | 39 +++++++++++++++++++
 .../instructions/spark/utils/SparkUtils.java    |  6 ++-
 .../runtime/matrix/data/LibMatrixReorg.java     | 12 +++---
 10 files changed, 111 insertions(+), 38 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/696e7921/src/main/java/org/apache/sysml/hops/AggBinaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggBinaryOp.java 
b/src/main/java/org/apache/sysml/hops/AggBinaryOp.java
index 3c5f157..fbbb2dd 100644
--- a/src/main/java/org/apache/sysml/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggBinaryOp.java
@@ -813,9 +813,10 @@ public class AggBinaryOp extends Hop implements 
MultiThreadedHop
                else
                {
                        SparkAggType aggtype = getSparkMMAggregationType(true);
+                       _outputEmptyBlocks = 
!OptimizerUtils.allowsToFilterEmptyBlockOutputs(this); 
                        
                        Lop cpmm = new MMCJ(getInput().get(0).constructLops(), 
getInput().get(1).constructLops(), 
-                                                               getDataType(), 
getValueType(), aggtype, ExecType.SPARK);
+                               getDataType(), getValueType(), 
_outputEmptyBlocks, aggtype, ExecType.SPARK);
                        setOutputDimensions( cpmm );
                        setLineNumbers( cpmm );
                        setLops( cpmm );
@@ -834,7 +835,8 @@ public class AggBinaryOp extends Hop implements 
MultiThreadedHop
                setLineNumbers(tY);
                
                //matrix multiply
-               MMCJ mmcj = new MMCJ(tY, X.constructLops(), getDataType(), 
getValueType(), aggtype, ExecType.SPARK);
+               _outputEmptyBlocks = 
!OptimizerUtils.allowsToFilterEmptyBlockOutputs(this); 
+               MMCJ mmcj = new MMCJ(tY, X.constructLops(), getDataType(), 
getValueType(), _outputEmptyBlocks, aggtype, ExecType.SPARK);
                mmcj.getOutputParameters().setDimensions(getDim1(), getDim2(), 
getRowsInBlock(), getColsInBlock(), getNnz());
                setLineNumbers(mmcj);
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/696e7921/src/main/java/org/apache/sysml/hops/ReorgOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ReorgOp.java 
b/src/main/java/org/apache/sysml/hops/ReorgOp.java
index b8b0139..0d3863b 100644
--- a/src/main/java/org/apache/sysml/hops/ReorgOp.java
+++ b/src/main/java/org/apache/sysml/hops/ReorgOp.java
@@ -232,7 +232,7 @@ public class ReorgOp extends Hop implements MultiThreadedHop
                                if( et==ExecType.MR )
                                {
                                        Transform transform1 = new Transform( 
linputs,
-                                               HopsTransf2Lops.get(op), 
getDataType(), getValueType(), et);
+                                               HopsTransf2Lops.get(op), 
getDataType(), getValueType(), true, et);
                                        setOutputDimensions(transform1);
                                        setLineNumbers(transform1);
                                        
@@ -250,8 +250,10 @@ public class ReorgOp extends Hop implements 
MultiThreadedHop
                                }
                                else //CP/SPARK
                                {
+                                       _outputEmptyBlocks = 
(et==ExecType.SPARK &&
+                                               
!OptimizerUtils.allowsToFilterEmptyBlockOutputs(this)); 
                                        Transform transform1 = new Transform( 
linputs,
-                                               HopsTransf2Lops.get(op), 
getDataType(), getValueType(), et);
+                                               HopsTransf2Lops.get(op), 
getDataType(), getValueType(), _outputEmptyBlocks, et);
                                        setOutputDimensions(transform1);
                                        setLineNumbers(transform1);
                                        

http://git-wip-us.apache.org/repos/asf/systemml/blob/696e7921/src/main/java/org/apache/sysml/lops/MMCJ.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/MMCJ.java 
b/src/main/java/org/apache/sysml/lops/MMCJ.java
index 429520e..bd0dbc3 100644
--- a/src/main/java/org/apache/sysml/lops/MMCJ.java
+++ b/src/main/java/org/apache/sysml/lops/MMCJ.java
@@ -37,6 +37,7 @@ public class MMCJ extends Lop
        }
        
        //optional attribute for mr exec type
+       private boolean _outputEmptyBlocks = true;
        private MMCJType _type = MMCJType.AGG;
        
        //optional attribute for spark exec type
@@ -81,12 +82,12 @@ public class MMCJ extends Lop
                }
        }
 
-       public MMCJ(Lop input1, Lop input2, DataType dt, ValueType vt, 
SparkAggType aggtype, ExecType et) {
+       public MMCJ(Lop input1, Lop input2, DataType dt, ValueType vt, boolean 
outputEmptyBlocks, SparkAggType aggtype, ExecType et) {
                this(input1, input2, dt, vt, MMCJType.NO_AGG, et);
+               _outputEmptyBlocks = outputEmptyBlocks;
                _aggtype = aggtype;
        }
        
-       
        @Override
        public String toString() {
                return "Operation = MMCJ";
@@ -118,8 +119,11 @@ public class MMCJ extends Lop
                sb.append( prepOutputOperand(output) );
                
                sb.append( OPERAND_DELIMITOR );
-               if( getExecType() == ExecType.SPARK )
+               if( getExecType() == ExecType.SPARK ) {
+                       sb.append(_outputEmptyBlocks);
+                       sb.append(Lop.OPERAND_DELIMITOR);
                        sb.append(_aggtype.name());
+               }
                else
                        sb.append(_type.name());
                

http://git-wip-us.apache.org/repos/asf/systemml/blob/696e7921/src/main/java/org/apache/sysml/lops/Transform.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/Transform.java 
b/src/main/java/org/apache/sysml/lops/Transform.java
index 1de2a9c..511e525 100644
--- a/src/main/java/org/apache/sysml/lops/Transform.java
+++ b/src/main/java/org/apache/sysml/lops/Transform.java
@@ -42,14 +42,16 @@ public class Transform extends Lop
        
        private OperationTypes operation = null;
        private boolean _bSortIndInMem = false;
+       private boolean _outputEmptyBlock = true;
        private int _numThreads = 1;
        
        public Transform(Lop input, Transform.OperationTypes op, DataType dt, 
ValueType vt, ExecType et) {
                this(input, op, dt, vt, et, 1);
        }
        
-       public Transform(Lop[] inputs, Transform.OperationTypes op, DataType 
dt, ValueType vt, ExecType et) {
+       public Transform(Lop[] inputs, Transform.OperationTypes op, DataType 
dt, ValueType vt, boolean outputEmptyBlock, ExecType et) {
                this(inputs, op, dt, vt, et, 1);
+               _outputEmptyBlock = outputEmptyBlock;
        }
        
        public Transform(Lop input, Transform.OperationTypes op, DataType dt, 
ValueType vt, ExecType et, int k)  {
@@ -176,7 +178,7 @@ public class Transform extends Lop
                sb.append( getInputs().get(0).prepInputOperand(input1));
                sb.append( OPERAND_DELIMITOR );
                sb.append( this.prepOutputOperand(output));
-
+               
                if( getExecType()==ExecType.CP && operation == 
OperationTypes.Transpose ) {
                        sb.append( OPERAND_DELIMITOR );
                        sb.append( _numThreads );
@@ -209,6 +211,11 @@ public class Transform extends Lop
                sb.append( OPERAND_DELIMITOR );
                sb.append( this.prepOutputOperand(output));
                
+               if( getExecType()==ExecType.SPARK && operation == 
OperationTypes.Reshape ) {
+                       sb.append( OPERAND_DELIMITOR );
+                       sb.append( _outputEmptyBlock );
+               }
+               
                if( getExecType()==ExecType.SPARK && operation == 
OperationTypes.Sort ){
                        sb.append( OPERAND_DELIMITOR );
                        sb.append( _bSortIndInMem );

http://git-wip-us.apache.org/repos/asf/systemml/blob/696e7921/src/main/java/org/apache/sysml/runtime/instructions/mr/MatrixReshapeMRInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/mr/MatrixReshapeMRInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/mr/MatrixReshapeMRInstruction.java
index c252cc5..fc1e05f 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/mr/MatrixReshapeMRInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/mr/MatrixReshapeMRInstruction.java
@@ -83,7 +83,7 @@ public class MatrixReshapeMRInstruction extends 
UnaryInstruction {
        
                                //process instruction
                                _mcOut.setBlockSize(brlen, bclen);
-                               out = LibMatrixReorg.reshape(imv, _mcIn, out, 
_mcOut, _byrow);
+                               out = LibMatrixReorg.reshape(imv, _mcIn, out, 
_mcOut, _byrow, true);
                                
                                //put the output values in the output cache
                                for( IndexedMatrixValue outBlk : out )

http://git-wip-us.apache.org/repos/asf/systemml/blob/696e7921/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java
index 376cf5b..770f6fb 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/CpmmSPInstruction.java
@@ -53,10 +53,12 @@ import org.apache.sysml.runtime.matrix.operators.Operator;
  * 
  */
 public class CpmmSPInstruction extends BinarySPInstruction {
-       private SparkAggType _aggtype;
-
-       private CpmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, 
CPOperand out, SparkAggType aggtype, String opcode, String istr) {
+       private final boolean _outputEmptyBlocks;
+       private final SparkAggType _aggtype;
+       
+       private CpmmSPInstruction(Operator op, CPOperand in1, CPOperand in2, 
CPOperand out, boolean outputEmptyBlocks, SparkAggType aggtype, String opcode, 
String istr) {
                super(SPType.CPMM, op, in1, in2, out, opcode, istr);
+               _outputEmptyBlocks = outputEmptyBlocks;
                _aggtype = aggtype;
        }
 
@@ -70,8 +72,9 @@ public class CpmmSPInstruction extends BinarySPInstruction {
                CPOperand out = new CPOperand(parts[3]);
                AggregateOperator agg = new AggregateOperator(0, 
Plus.getPlusFnObject());
                AggregateBinaryOperator aggbin = new 
AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);
-               SparkAggType aggtype = SparkAggType.valueOf(parts[4]);
-               return new CpmmSPInstruction(aggbin, in1, in2, out, aggtype, 
opcode, str);
+               boolean outputEmptyBlocks = Boolean.parseBoolean(parts[4]);
+               SparkAggType aggtype = SparkAggType.valueOf(parts[5]);
+               return new CpmmSPInstruction(aggbin, in1, in2, out, 
outputEmptyBlocks, aggtype, opcode, str);
        }
        
        @Override
@@ -84,7 +87,7 @@ public class CpmmSPInstruction extends BinarySPInstruction {
                MatrixCharacteristics mc1 = 
sec.getMatrixCharacteristics(input1.getName());
                MatrixCharacteristics mc2 = 
sec.getMatrixCharacteristics(input2.getName());
                
-               if( _aggtype == SparkAggType.SINGLE_BLOCK ) {
+               if( !_outputEmptyBlocks || _aggtype == 
SparkAggType.SINGLE_BLOCK ) {
                        //prune empty blocks of ultra-sparse matrices
                        in1 = in1.filter(new FilterNonEmptyBlocksFunction());
                        in2 = in2.filter(new FilterNonEmptyBlocksFunction());
@@ -112,6 +115,8 @@ public class CpmmSPInstruction extends BinarySPInstruction {
                        sec.setMatrixOutput(output.getName(), out2, 
getExtendedOpcode());
                }
                else { //DEFAULT: MULTI_BLOCK
+                       if( !_outputEmptyBlocks )
+                               out = out.filter(new 
FilterNonEmptyBlocksFunction());
                        out = RDDAggregateUtils.sumByKeyStable(out, false);
                        
                        //put output RDD handle into symbol table

http://git-wip-us.apache.org/repos/asf/systemml/blob/696e7921/src/main/java/org/apache/sysml/runtime/instructions/spark/MatrixReshapeSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/MatrixReshapeSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/MatrixReshapeSPInstruction.java
index 42b4dd9..352f400 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/MatrixReshapeSPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/MatrixReshapeSPInstruction.java
@@ -33,6 +33,7 @@ import 
org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import 
org.apache.sysml.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction;
 import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
 import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
@@ -42,23 +43,25 @@ import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
 import org.apache.sysml.runtime.matrix.operators.Operator;
 
-public class MatrixReshapeSPInstruction extends UnarySPInstruction {
-
-       private CPOperand _opRows = null;
-       private CPOperand _opCols = null;
-       private CPOperand _opByRow = null;
+public class MatrixReshapeSPInstruction extends UnarySPInstruction
+{
+       private final CPOperand _opRows;
+       private final CPOperand _opCols;
+       private final CPOperand _opByRow;
+       private final boolean _outputEmptyBlocks;
 
        private MatrixReshapeSPInstruction(Operator op, CPOperand in1, 
CPOperand in2, CPOperand in3, CPOperand in4,
-                       CPOperand out, String opcode, String istr) {
+                       CPOperand out, boolean outputEmptyBlocks, String 
opcode, String istr) {
                super(SPType.MatrixReshape, op, in1, out, opcode, istr);
                _opRows = in2;
                _opCols = in3;
                _opByRow = in4;
+               _outputEmptyBlocks = outputEmptyBlocks;
        }
 
        public static MatrixReshapeSPInstruction parseInstruction ( String str 
) {
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
-               InstructionUtils.checkNumFields( parts, 5 );
+               InstructionUtils.checkNumFields( parts, 6 );
                
                String opcode = parts[0];
                CPOperand in1 = new CPOperand(parts[1]);
@@ -66,11 +69,12 @@ public class MatrixReshapeSPInstruction extends 
UnarySPInstruction {
                CPOperand in3 = new CPOperand(parts[3]);
                CPOperand in4 = new CPOperand(parts[4]);
                CPOperand out = new CPOperand(parts[5]);
+               boolean outputEmptyBlocks = Boolean.parseBoolean(parts[6]);
                 
                if(!opcode.equalsIgnoreCase("rshape"))
                        throw new DMLRuntimeException("Unknown opcode while 
parsing an MatrixReshapeInstruction: " + str);
                else
-                       return new MatrixReshapeSPInstruction(new 
Operator(true), in1, in2, in3, in4, out, opcode, str);
+                       return new MatrixReshapeSPInstruction(new 
Operator(true), in1, in2, in3, in4, out, outputEmptyBlocks, opcode, str);
        }
        
        @Override
@@ -96,9 +100,12 @@ public class MatrixReshapeSPInstruction extends 
UnarySPInstruction {
                                + mcIn.getRows()+"x"+mcIn.getCols()+" vs 
"+mcOut.getRows()+"x"+mcOut.getCols());
                }
                
+               if( !_outputEmptyBlocks )
+                       in1 = in1.filter(new FilterNonEmptyBlocksFunction());
+               
                //execute reshape instruction
                JavaPairRDD<MatrixIndexes,MatrixBlock> out = 
-                               in1.flatMapToPair(new RDDReshapeFunction(mcIn, 
mcOut, byRow));
+                       in1.flatMapToPair(new RDDReshapeFunction(mcIn, mcOut, 
byRow, _outputEmptyBlocks));
                out = RDDAggregateUtils.mergeByKey(out);
                
                //put output RDD handle into symbol table
@@ -110,15 +117,16 @@ public class MatrixReshapeSPInstruction extends 
UnarySPInstruction {
        {
                private static final long serialVersionUID = 
2819309412002224478L;
                
-               private MatrixCharacteristics _mcIn = null;
-               private MatrixCharacteristics _mcOut = null;
-               private boolean _byrow = true;
+               private final MatrixCharacteristics _mcIn;
+               private final MatrixCharacteristics _mcOut;
+               private final boolean _byrow;
+               private final boolean _outputEmptyBlocks;
                
-               public RDDReshapeFunction( MatrixCharacteristics mcIn, 
MatrixCharacteristics mcOut, boolean byrow)
-               {
+               public RDDReshapeFunction( MatrixCharacteristics mcIn, 
MatrixCharacteristics mcOut, boolean byrow, boolean outputEmptyBlocks) {
                        _mcIn = mcIn;
                        _mcOut = mcOut;
                        _byrow = byrow;
+                       _outputEmptyBlocks = outputEmptyBlocks;
                }
                
                @Override
@@ -129,8 +137,8 @@ public class MatrixReshapeSPInstruction extends 
UnarySPInstruction {
                        IndexedMatrixValue in = 
SparkUtils.toIndexedMatrixBlock(arg0);
                        
                        //execute actual reshape operation
-                       ArrayList<IndexedMatrixValue> out = new ArrayList<>();
-                       out = LibMatrixReorg.reshape(in, _mcIn, out, _mcOut, 
_byrow);
+                       ArrayList<IndexedMatrixValue> out = LibMatrixReorg
+                               .reshape(in, _mcIn, new ArrayList<>(), _mcOut, 
_byrow, _outputEmptyBlocks);
 
                        //output conversion (for compatibility w/ rdd schema)
                        return 
SparkUtils.fromIndexedMatrixBlock(out).iterator();

http://git-wip-us.apache.org/repos/asf/systemml/blob/696e7921/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/RecomputeNnzFunction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/RecomputeNnzFunction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/RecomputeNnzFunction.java
new file mode 100644
index 0000000..e591f83
--- /dev/null
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/RecomputeNnzFunction.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.instructions.spark.functions;
+
+import java.util.Arrays;
+import java.util.Iterator;
+
+import org.apache.spark.api.java.function.FlatMapFunction;
+
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+
+public class RecomputeNnzFunction implements 
FlatMapFunction<Iterator<MatrixBlock>, Long>
+{
+       private static final long serialVersionUID = -973429193604040011L;
+
+       public Iterator<Long> call(Iterator<MatrixBlock> iter) throws Exception 
{
+               long nnz = 0;
+               while( iter.hasNext() )
+                       nnz += iter.next().getNonZeros();
+               return Arrays.asList(nnz).iterator();
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/696e7921/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/SparkUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/SparkUtils.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/SparkUtils.java
index 546266b..cdd64f0 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/SparkUtils.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/SparkUtils.java
@@ -40,6 +40,8 @@ import 
org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyze
 import 
org.apache.sysml.runtime.instructions.spark.functions.CopyBinaryCellFunction;
 import org.apache.sysml.runtime.instructions.spark.functions.CopyBlockFunction;
 import 
org.apache.sysml.runtime.instructions.spark.functions.CopyBlockPairFunction;
+import 
org.apache.sysml.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction;
+import 
org.apache.sysml.runtime.instructions.spark.functions.RecomputeNnzFunction;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.FrameBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
@@ -233,7 +235,9 @@ public class SparkUtils
        }
        
        public static long getNonZeros(JavaPairRDD<MatrixIndexes, MatrixBlock> 
input) {
-               return input.values().map(b -> 
b.getNonZeros()).reduce((a,b)->a+b);
+               //note: avoid direct lambda expression to reduce unnecessary 
GC overhead
+               return input.filter(new FilterNonEmptyBlocksFunction())
+                       .values().mapPartitions(new 
RecomputeNnzFunction()).reduce((a,b)->a+b);
        }
 
        private static class AnalyzeCellMatrixCharacteristics implements 
Function<Tuple2<MatrixIndexes,MatrixCell>, MatrixCharacteristics> 

http://git-wip-us.apache.org/repos/asf/systemml/blob/696e7921/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixReorg.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixReorg.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixReorg.java
index b6fe8d7..8f7bda9 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixReorg.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixReorg.java
@@ -463,10 +463,11 @@ public class LibMatrixReorg
         * @param out list of indexed matrix values
         * @param mcOut output matrix characteristics
         * @param rowwise if true, reshape by row
+        * @param outputEmptyBlocks output blocks with nnz=0
         * @return list of indexed matrix values
         */
        public static ArrayList<IndexedMatrixValue> reshape( IndexedMatrixValue 
in, MatrixCharacteristics mcIn, 
-                       ArrayList<IndexedMatrixValue> out, 
MatrixCharacteristics mcOut, boolean rowwise ) {
+                       ArrayList<IndexedMatrixValue> out, 
MatrixCharacteristics mcOut, boolean rowwise, boolean outputEmptyBlocks ) {
                //prepare inputs
                MatrixIndexes ixIn = in.getIndexes();
                MatrixBlock mbIn = (MatrixBlock) in.getValue();
@@ -485,10 +486,11 @@ public class LibMatrixReorg
                
                //prepare output
                out = new ArrayList<>();
-               for( Entry<MatrixIndexes, MatrixBlock> e : rblk.entrySet() ) {
-                       e.getValue().examSparsity(); //ensure correct format
-                       out.add(new 
IndexedMatrixValue(e.getKey(),e.getValue()));
-               }
+               for( Entry<MatrixIndexes, MatrixBlock> e : rblk.entrySet() )
+                       if( outputEmptyBlocks || 
!e.getValue().isEmptyBlock(false) ) {
+                               e.getValue().examSparsity(); //ensure correct 
format
+                               out.add(new 
IndexedMatrixValue(e.getKey(),e.getValue()));
+                       }
                
                return out;
        }

Reply via email to