Repository: systemml Updated Branches: refs/heads/master e7d948f9c -> 1791c1a26
[SYSTEMML-2297] Fix spark left indexing w/ partitioned broadcasts This patch fixes the distributed spark left indexing operations, which failed with index-out-of-bounds errors for partitioned broadcasts and unaligned blocks crossing multiple partitions, while other partitions are not touched at all. We resolve this by simply moving the slice over 1k x 1k blocks from the broadcast partition (PartitionedBlock) into the partitioned broadcast (PartitionedBroadcast), because it works purely through the getBlock abstraction anyway. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/1791c1a2 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/1791c1a2 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/1791c1a2 Branch: refs/heads/master Commit: 1791c1a26279445726416c0cf8ce55257da19840 Parents: e7d948f Author: Matthias Boehm <[email protected]> Authored: Thu May 3 17:43:07 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Thu May 3 17:43:07 2018 -0700 ---------------------------------------------------------------------- .../spark/data/PartitionedBlock.java | 56 ----------------- .../spark/data/PartitionedBroadcast.java | 66 ++++++++++++++++---- 2 files changed, 55 insertions(+), 67 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/1791c1a2/src/main/java/org/apache/sysml/runtime/instructions/spark/data/PartitionedBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/PartitionedBlock.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/PartitionedBlock.java index 870d2a2..8a4999b 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/PartitionedBlock.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/PartitionedBlock.java @@ -27,18 +27,13 @@ import java.io.ObjectInput; 
import java.io.ObjectInputStream; import java.io.ObjectOutput; import java.io.ObjectOutputStream; -import java.lang.reflect.Constructor; -import java.util.ArrayList; import java.util.Arrays; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.caching.CacheBlock; import org.apache.sysml.runtime.controlprogram.caching.CacheBlockFactory; -import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues; -import org.apache.sysml.runtime.matrix.data.Pair; import org.apache.sysml.runtime.util.FastBufferedDataInputStream; import org.apache.sysml.runtime.util.FastBufferedDataOutputStream; -import org.apache.sysml.runtime.util.IndexRange; /** * This class is for partitioned matrix/frame blocks, to be used as broadcasts. @@ -195,57 +190,6 @@ public class PartitionedBlock<T extends CacheBlock> implements Externalizable return ret; } - - /** - * Utility for slice operations over partitioned matrices, where the index range can cover - * multiple blocks. The result is always a single result matrix block. All semantics are - * equivalent to the core matrix block slice operations. 
- * - * @param rl row lower bound - * @param ru row upper bound - * @param cl column lower bound - * @param cu column upper bound - * @param block block object - * @return block object - */ - @SuppressWarnings("unchecked") - public T slice(long rl, long ru, long cl, long cu, T block) { - int lrl = (int) rl; - int lru = (int) ru; - int lcl = (int) cl; - int lcu = (int) cu; - - ArrayList<Pair<?, ?>> allBlks = (ArrayList<Pair<?, ?>>) CacheBlockFactory.getPairList(block); - int start_iix = (lrl-1)/_brlen+1; - int end_iix = (lru-1)/_brlen+1; - int start_jix = (lcl-1)/_bclen+1; - int end_jix = (lcu-1)/_bclen+1; - - for( int iix = start_iix; iix <= end_iix; iix++ ) - for(int jix = start_jix; jix <= end_jix; jix++) { - IndexRange ixrange = new IndexRange(rl, ru, cl, cu); - allBlks.addAll(OperationsOnMatrixValues.performSlice( - ixrange, _brlen, _bclen, iix, jix, getBlock(iix, jix))); - } - - if(allBlks.size() == 1) { - return (T) allBlks.get(0).getValue(); - } - else { - //allocate output matrix - Constructor<?> constr; - try { - constr = block.getClass().getConstructor(int.class, int.class, boolean.class); - T ret = (T) constr.newInstance(lru-lrl+1, lcu-lcl+1, false); - for(Pair<?, ?> kv : allBlks) { - ret.merge((T)kv.getValue(), false); - } - return ret; - } catch (Exception e) { - throw new DMLRuntimeException(e); - } - } - } public void clearBlocks() { _partBlocks = null; http://git-wip-us.apache.org/repos/asf/systemml/blob/1791c1a2/src/main/java/org/apache/sysml/runtime/instructions/spark/data/PartitionedBroadcast.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/PartitionedBroadcast.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/PartitionedBroadcast.java index 94fcee4..a4c5173 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/PartitionedBroadcast.java +++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/PartitionedBroadcast.java @@ -20,11 +20,18 @@ package org.apache.sysml.runtime.instructions.spark.data; import java.io.Serializable; +import java.lang.reflect.Constructor; +import java.util.ArrayList; import org.apache.spark.broadcast.Broadcast; +import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.caching.CacheBlock; +import org.apache.sysml.runtime.controlprogram.caching.CacheBlockFactory; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; +import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues; +import org.apache.sysml.runtime.matrix.data.Pair; +import org.apache.sysml.runtime.util.IndexRange; /** * This class is a wrapper around an array of broadcasts of partitioned matrix/frame blocks, @@ -90,20 +97,57 @@ public class PartitionedBroadcast<T extends CacheBlock> implements Serializable return _pbc[pix].value().getBlock(rowIndex, colIndex); } + /** + * Utility for slice operations over partitioned matrices, where the index range can cover + * multiple blocks. The result is always a single result matrix block. All semantics are + * equivalent to the core matrix block slice operations. 
+ * + * @param rl row lower bound + * @param ru row upper bound + * @param cl column lower bound + * @param cu column upper bound + * @param block block object + * @return block object + */ + @SuppressWarnings("unchecked") public T slice(long rl, long ru, long cl, long cu, T block) { - T ret = null; - for( Broadcast<PartitionedBlock<T>> bc : _pbc ) { - PartitionedBlock<T> pm = bc.value(); - T tmp = pm.slice(rl, ru, cl, cu, block); - if( ret != null ) - ret.merge(tmp, false); - else - ret = tmp; - } + int lrl = (int) rl; + int lru = (int) ru; + int lcl = (int) cl; + int lcu = (int) cu; + + ArrayList<Pair<?, ?>> allBlks = (ArrayList<Pair<?, ?>>) CacheBlockFactory.getPairList(block); + int start_iix = (lrl-1)/_mc.getRowsPerBlock()+1; + int end_iix = (lru-1)/_mc.getRowsPerBlock()+1; + int start_jix = (lcl-1)/_mc.getColsPerBlock()+1; + int end_jix = (lcu-1)/_mc.getColsPerBlock()+1; + + for( int iix = start_iix; iix <= end_iix; iix++ ) + for(int jix = start_jix; jix <= end_jix; jix++) { + IndexRange ixrange = new IndexRange(rl, ru, cl, cu); + allBlks.addAll(OperationsOnMatrixValues.performSlice( + ixrange, _mc.getRowsPerBlock(), _mc.getColsPerBlock(), iix, jix, getBlock(iix, jix))); + } - return ret; + if(allBlks.size() == 1) { + return (T) allBlks.get(0).getValue(); + } + else { + //allocate output matrix + Constructor<?> constr; + try { + constr = block.getClass().getConstructor(int.class, int.class, boolean.class); + T ret = (T) constr.newInstance(lru-lrl+1, lcu-lcl+1, false); + for(Pair<?, ?> kv : allBlks) { + ret.merge((T)kv.getValue(), false); + } + return ret; + } catch (Exception e) { + throw new DMLRuntimeException(e); + } + } } - + /** * This method cleanups all underlying broadcasts of a partitioned broadcast, * by forward the calls to SparkExecutionContext.cleanupBroadcastVariable.
