Repository: incubator-systemml
Updated Branches:
  refs/heads/master 97b136601 -> 873bae76b


[SYSTEMML-562] Fix serialization/partitioning of partitioned broadcasts

This patch fixes an issue with incorrect class references on
deserialization as well as various smaller issues related to
incompatible serialization/deserialization (long vs int rlen/clen) and
partitioning (destroyed block references).

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/42906ba1
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/42906ba1
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/42906ba1

Branch: refs/heads/master
Commit: 42906ba126ed25b3dfee3ab254c9afbb6caf421e
Parents: 97b1366
Author: Matthias Boehm <[email protected]>
Authored: Wed Jul 27 20:59:15 2016 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Wed Jul 27 20:59:15 2016 -0700

----------------------------------------------------------------------
 .../caching/CacheBlockFactory.java              |  57 ++++++
 .../context/SparkExecutionContext.java          |   2 +-
 .../spark/data/PartitionedBlock.java            | 180 +++++++++----------
 3 files changed, 140 insertions(+), 99 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/42906ba1/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheBlockFactory.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheBlockFactory.java b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheBlockFactory.java
new file mode 100644
index 0000000..3bf86b5
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheBlockFactory.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.caching;
+
+import org.apache.sysml.runtime.matrix.data.FrameBlock;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+
+/**
+ * Factory to create instances of matrix/frame blocks given
+ * internal codes.
+ * 
+ */
+public class CacheBlockFactory 
+{
+       /**
+        * 
+        * @param code
+        * @return
+        */
+       public static CacheBlock newInstance(int code) {
+               switch( code ) {
+                       case 0: return new MatrixBlock();
+                       case 1: return new FrameBlock();
+               }
+               throw new RuntimeException("Unsupported cache block type: "+code);
+       }
+       
+       /**
+        * 
+        * @param block
+        * @return
+        */
+       public static int getCode(CacheBlock block) {
+               if( block instanceof MatrixBlock )
+                       return 0;
+               else if( block instanceof FrameBlock )
+                       return 1;
+               throw new RuntimeException("Unsupported cache block type: "+block.getClass().getName());
+       }
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/42906ba1/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
index 99614f2..0eea221 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
@@ -910,7 +910,7 @@ public class SparkExecutionContext extends ExecutionContext
                
                long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
 
-               PartitionedBlock<MatrixBlock> out = new PartitionedBlock<MatrixBlock>(rlen, clen, brlen, bclen, new MatrixBlock());
+               PartitionedBlock<MatrixBlock> out = new PartitionedBlock<MatrixBlock>(rlen, clen, brlen, bclen);
                List<Tuple2<MatrixIndexes,MatrixBlock>> list = rdd.collect();
                
                //copy blocks one-at-a-time into output matrix block

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/42906ba1/src/main/java/org/apache/sysml/runtime/instructions/spark/data/PartitionedBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/PartitionedBlock.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/PartitionedBlock.java
index 20fcd0b..465fdd5 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/data/PartitionedBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/data/PartitionedBlock.java
@@ -27,28 +27,27 @@ import java.io.ObjectInput;
 import java.io.ObjectInputStream;
 import java.io.ObjectOutput;
 import java.io.ObjectOutputStream;
-import java.lang.reflect.Array;
 import java.lang.reflect.Constructor;
 import java.util.ArrayList;
 
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.caching.CacheBlock;
+import org.apache.sysml.runtime.controlprogram.caching.CacheBlockFactory;
 import org.apache.sysml.runtime.matrix.data.Pair;
 import org.apache.sysml.runtime.util.FastBufferedDataInputStream;
 import org.apache.sysml.runtime.util.FastBufferedDataOutputStream;
 import org.apache.sysml.runtime.util.IndexRange;
 
 /**
- * This class is for partitioned matrix/frame blocks, to be used
- * as broadcasts. Distributed tasks require block-partitioned broadcasts but a lazy partitioning per
- * task would create instance-local copies and hence replicate broadcast variables which are shared
- * by all tasks within an executor.  
+ * This class is for partitioned matrix/frame blocks, to be used as broadcasts. 
+ * Distributed tasks require block-partitioned broadcasts but a lazy partitioning 
+ * per task would create instance-local copies and hence replicate broadcast 
+ * variables which are shared by all tasks within an executor.  
  * 
  */
 public class PartitionedBlock<T extends CacheBlock> implements Externalizable
 {
-
-       protected T[] _partBlocks = null; 
+       protected CacheBlock[] _partBlocks = null; 
        protected long _rlen = -1;
        protected long _clen = -1;
        protected int _brlen = -1;
@@ -60,40 +59,6 @@ public class PartitionedBlock<T extends CacheBlock> implements Externalizable
        }
        
        
-       public long getNumRows() {
-               return _rlen;
-       }
-       
-       public long getNumCols() {
-               return _clen;
-       }
-       
-       public long getNumRowsPerBlock() {
-               return _brlen;
-       }
-       
-       public long getNumColumnsPerBlock() {
-               return _bclen;
-       }
-       
-       /**
-        * 
-        * @return
-        */
-       public int getNumRowBlocks() 
-       {
-               return (int)Math.ceil((double)_rlen/_brlen);
-       }
-       
-       /**
-        * 
-        * @return
-        */
-       public int getNumColumnBlocks() 
-       {
-               return (int)Math.ceil((double)_clen/_bclen);
-       }
-       
        @SuppressWarnings("unchecked")
        public PartitionedBlock(T block, int brlen, int bclen) 
        {
@@ -106,17 +71,16 @@ public class PartitionedBlock<T extends CacheBlock> implements Externalizable
                _clen = clen;
                _brlen = brlen;
                _bclen = bclen;
-
                int nrblks = getNumRowBlocks();
                int ncblks = getNumColumnBlocks();
+               int code = CacheBlockFactory.getCode(block);
                
                try
                {
-                       _partBlocks = (T[])Array.newInstance((block.getClass()), nrblks * ncblks);
+                       _partBlocks = new CacheBlock[nrblks * ncblks];
                        for( int i=0, ix=0; i<nrblks; i++ )
-                               for( int j=0; j<ncblks; j++, ix++ )
-                               {
-                                       T tmp = (T) block.getClass().newInstance();
+                               for( int j=0; j<ncblks; j++, ix++ ) {
+                                       T tmp = (T) CacheBlockFactory.newInstance(code);
                                        block.sliceOperations(i*_brlen, Math.min((i+1)*_brlen, rlen)-1, 
                                                                   j*_bclen, Math.min((j+1)*_bclen, clen)-1, tmp);
                                        _partBlocks[ix] = tmp;
@@ -129,8 +93,7 @@ public class PartitionedBlock<T extends CacheBlock> implements Externalizable
                _offset = 0;
        }
 
-       @SuppressWarnings("unchecked")
-       public PartitionedBlock(int rlen, int clen, int brlen, int bclen, T block) 
+       public PartitionedBlock(int rlen, int clen, int brlen, int bclen) 
        {
                //partitioning input broadcast
                _rlen = rlen;
@@ -140,8 +103,60 @@ public class PartitionedBlock<T extends CacheBlock> implements Externalizable
                
                int nrblks = getNumRowBlocks();
                int ncblks = getNumColumnBlocks();
-               _partBlocks = (T[])Array.newInstance((block.getClass()), nrblks * ncblks);
-
+               _partBlocks = new CacheBlock[nrblks * ncblks];
+       }
+       
+       
+       /**
+        * 
+        * @param offset
+        * @param numBlks
+        * @return
+        */
+       public PartitionedBlock<T> createPartition( int offset, int numBlks, T block )
+       {
+               PartitionedBlock<T> ret = new PartitionedBlock<T>();
+               ret._rlen = _rlen;
+               ret._clen = _clen;
+               ret._brlen = _brlen;
+               ret._bclen = _bclen;
+               ret._partBlocks = new CacheBlock[numBlks];
+               ret._offset = offset;
+               System.arraycopy(_partBlocks, offset, ret._partBlocks, 0, numBlks);
+               
+               return ret;
+       }
+       
+       public long getNumRows() {
+               return _rlen;
+       }
+       
+       public long getNumCols() {
+               return _clen;
+       }
+       
+       public long getNumRowsPerBlock() {
+               return _brlen;
+       }
+       
+       public long getNumColumnsPerBlock() {
+               return _bclen;
+       }
+       
+       /**
+        * 
+        * @return
+        */
+       public int getNumRowBlocks() {
+               return (int)Math.ceil((double)_rlen/_brlen);
+       }
+       
+       /**
+        * 
+        * @return
+        */
+       public int getNumColumnBlocks() {
+               return (int)Math.ceil((double)_clen/_bclen);
        }
        
        /**
@@ -151,6 +166,7 @@ public class PartitionedBlock<T extends CacheBlock> implements Externalizable
         * @return
         * @throws DMLRuntimeException 
         */
+       @SuppressWarnings("unchecked")
        public T getBlock(int rowIndex, int colIndex) 
                throws DMLRuntimeException 
        {
@@ -165,7 +181,7 @@ public class PartitionedBlock<T extends CacheBlock> implements Externalizable
                int rix = rowIndex - 1;
                int cix = colIndex - 1;
                int ix = rix*ncblks+cix - _offset;
-               return _partBlocks[ix];
+               return (T)_partBlocks[ix];
        }
        
        /**
@@ -189,43 +205,19 @@ public class PartitionedBlock<T extends CacheBlock> implements Externalizable
                int rix = rowIndex - 1;
                int cix = colIndex - 1;
                int ix = rix*ncblks+cix - _offset;
-               _partBlocks[ ix ] = block;
-               
-       }
-       
-       /**
-        * 
-        * @param offset
-        * @param numBlks
-        * @return
-        */
-       @SuppressWarnings("unchecked")
-       public PartitionedBlock<T> createPartition( int offset, int numBlks, T block )
-       {
-               PartitionedBlock<T> ret = new PartitionedBlock<T>();
-               ret._rlen = _rlen;
-               ret._clen = _clen;
-               ret._brlen = _brlen;
-               ret._bclen = _bclen;
-
-               _partBlocks = (T[])Array.newInstance(block.getClass(), numBlks);
-               ret._offset = offset;
-               System.arraycopy(_partBlocks, offset, ret._partBlocks, 0, numBlks);
-               
-               return ret;
+               _partBlocks[ ix ] = block;      
        }
 
        /**
         * 
         * @return
         */     
-       public long getInMemorySize()
-       {
+       public long getInMemorySize() {
                long ret = 24; //header
                ret += 32;    //block array
                
                if( _partBlocks != null )
-                       for( T block : _partBlocks )
+                       for( CacheBlock block : _partBlocks )
                                ret += block.getInMemorySize();
                
                return ret;
@@ -236,12 +228,11 @@ public class PartitionedBlock<T extends CacheBlock> implements Externalizable
         * @return
         */
        
-       public long getExactSerializedSize()
-       {
+       public long getExactSerializedSize() {
                long ret = 24; //header
                
                if( _partBlocks != null )
-                       for( T block :  _partBlocks )
+                       for( CacheBlock block : _partBlocks )
                                ret += block.getExactSerializedSize();
                
                return ret;
@@ -361,8 +352,9 @@ public class PartitionedBlock<T extends CacheBlock> implements Externalizable
                dos.writeInt(_bclen);
                dos.writeInt(_offset);
                dos.writeInt(_partBlocks.length);
+               dos.writeByte(CacheBlockFactory.getCode(_partBlocks[0]));
                
-               for( T block : _partBlocks )
+               for( CacheBlock block : _partBlocks )
                        block.write(dos);
        }
 
@@ -371,29 +363,21 @@ public class PartitionedBlock<T extends CacheBlock> implements Externalizable
         * @param din
         * @throws IOException 
         */
-       @SuppressWarnings("unchecked")
        private void readHeaderAndPayload(DataInput dis) 
                throws IOException
        {
-               _rlen = dis.readInt();
-               _clen = dis.readInt();
+               _rlen = dis.readLong();
+               _clen = dis.readLong();
                _brlen = dis.readInt();
                _bclen = dis.readInt();
-               _offset = dis.readInt();
-               
+               _offset = dis.readInt();                
                int len = dis.readInt();
+               int code = dis.readByte();
                
-               try
-               {
-                       _partBlocks = (T[])Array.newInstance(getClass(), len);
-                       for( int i=0; i<len; i++ ) {
-                               _partBlocks[i].readFields(dis);
-                       }
+               _partBlocks = new CacheBlock[len];
+               for( int i=0; i<len; i++ ) {
+                       _partBlocks[i] = CacheBlockFactory.newInstance(code);
+                       _partBlocks[i].readFields(dis);
                }
-               catch(Exception ex) {
-                       throw new RuntimeException("Failed partitioning of broadcast variable input.", ex);
-               }
-               
        }
-       
 }

Reply via email to