[SYSTEMML-2254] Robustness transformdecode for ncol(data) < ncol(meta)

This patch improves the robustness of transformdecode. So far, we created
the output frame with as many columns as the given meta data frame.
However, this led to index-out-of-bounds exceptions if the given input
data had fewer columns than the meta data frame, because unnecessary
pass-through decoders tried to access non-existing columns.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/36a06ab6
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/36a06ab6
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/36a06ab6

Branch: refs/heads/master
Commit: 36a06ab6876d3d92e01b8c94a46c77841d654792
Parents: 2838cef
Author: Matthias Boehm <[email protected]>
Authored: Wed Apr 18 00:14:27 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Wed Apr 18 00:14:27 2018 -0700

----------------------------------------------------------------------
 .../cp/ParameterizedBuiltinCPInstruction.java   |  3 ++-
 .../transform/decode/DecoderFactory.java        | 25 ++++++++++++--------
 .../transform/decode/DecoderPassThrough.java    |  3 ++-
 3 files changed, 19 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/36a06ab6/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index fe18dbc..06aeef3 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -268,7 +268,8 @@ public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction
                        String[] colnames = meta.getColumnNames();
                        
                        //compute transformdecode
-                       Decoder decoder = 
DecoderFactory.createDecoder(getParameterMap().get("spec"), colnames, null, 
meta);
+                       Decoder decoder = DecoderFactory.createDecoder(
+                               getParameterMap().get("spec"), colnames, null, 
meta, data.getNumColumns());
                        FrameBlock fbout = decoder.decode(data, new 
FrameBlock(decoder.getSchema()));
                        fbout.setColumnNames(Arrays.copyOfRange(colnames, 0, 
fbout.getNumColumns()));
                        

http://git-wip-us.apache.org/repos/asf/systemml/blob/36a06ab6/src/main/java/org/apache/sysml/runtime/transform/decode/DecoderFactory.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/transform/decode/DecoderFactory.java 
b/src/main/java/org/apache/sysml/runtime/transform/decode/DecoderFactory.java
index d8b9f3a..c2788cb 100644
--- 
a/src/main/java/org/apache/sysml/runtime/transform/decode/DecoderFactory.java
+++ 
b/src/main/java/org/apache/sysml/runtime/transform/decode/DecoderFactory.java
@@ -36,45 +36,50 @@ import org.apache.wink.json4j.JSONObject;
 
 public class DecoderFactory 
 {
+       public static Decoder createDecoder(String spec, String[] colnames, 
ValueType[] schema, FrameBlock meta) {
+               return createDecoder(spec, colnames, schema, meta, 
meta.getNumColumns());
+       }
+       
        @SuppressWarnings("unchecked")
-       public static Decoder createDecoder(String spec, String[] colnames, 
ValueType[] schema, FrameBlock meta) 
+       public static Decoder createDecoder(String spec, String[] colnames, 
ValueType[] schema, FrameBlock meta, int clen) 
        {
                Decoder decoder = null;
                
-               try 
+               try
                {
                        //parse transform specification
                        JSONObject jSpec = new JSONObject(spec);
                        List<Decoder> ldecoders = new ArrayList<>();
-               
+                       
                        //create decoders 'recode', 'dummy' and 'pass-through'
                        List<Integer> rcIDs = Arrays.asList(ArrayUtils.toObject(
                                        TfMetaUtils.parseJsonIDList(jSpec, 
colnames, TfUtils.TXMETHOD_RECODE)));
                        List<Integer> dcIDs = Arrays.asList(ArrayUtils.toObject(
                                        TfMetaUtils.parseJsonIDList(jSpec, 
colnames, TfUtils.TXMETHOD_DUMMYCODE))); 
                        rcIDs = new 
ArrayList<Integer>(CollectionUtils.union(rcIDs, dcIDs));
+                       int len = dcIDs.isEmpty() ? 
Math.min(meta.getNumColumns(), clen) : meta.getNumColumns();
                        List<Integer> ptIDs = new 
ArrayList<Integer>(CollectionUtils
-                                       .subtract(UtilFunctions.getSeqList(1, 
meta.getNumColumns(), 1), rcIDs)); 
-
+                               .subtract(UtilFunctions.getSeqList(1, len, 1), 
rcIDs));
+                       
                        //create default schema if unspecified (with double 
columns for pass-through)
                        if( schema == null ) {
-                               schema = 
UtilFunctions.nCopies(meta.getNumColumns(), ValueType.STRING);
+                               schema = UtilFunctions.nCopies(len, 
ValueType.STRING);
                                for( Integer col : ptIDs )
                                        schema[col-1] = ValueType.DOUBLE;
                        }
                        
                        if( !dcIDs.isEmpty() ) {
                                ldecoders.add(new DecoderDummycode(schema, 
-                                               
ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0]))));
+                                       
ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0]))));
                        }
                        if( !rcIDs.isEmpty() ) {
                                ldecoders.add(new DecoderRecode(schema, 
!dcIDs.isEmpty(),
-                                               
ArrayUtils.toPrimitive(rcIDs.toArray(new Integer[0]))));
+                                       
ArrayUtils.toPrimitive(rcIDs.toArray(new Integer[0]))));
                        }
                        if( !ptIDs.isEmpty() ) {
                                ldecoders.add(new DecoderPassThrough(schema, 
-                                               
ArrayUtils.toPrimitive(ptIDs.toArray(new Integer[0])),
-                                               
ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0]))));        
+                                       
ArrayUtils.toPrimitive(ptIDs.toArray(new Integer[0])),
+                                       
ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0]))));
                        }
                        
                        //create composite decoder of all created decoders

http://git-wip-us.apache.org/repos/asf/systemml/blob/36a06ab6/src/main/java/org/apache/sysml/runtime/transform/decode/DecoderPassThrough.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/transform/decode/DecoderPassThrough.java
 
b/src/main/java/org/apache/sysml/runtime/transform/decode/DecoderPassThrough.java
index 8224b5c..79cf3d4 100644
--- 
a/src/main/java/org/apache/sysml/runtime/transform/decode/DecoderPassThrough.java
+++ 
b/src/main/java/org/apache/sysml/runtime/transform/decode/DecoderPassThrough.java
@@ -44,8 +44,9 @@ public class DecoderPassThrough extends Decoder
        @Override
        public FrameBlock decode(MatrixBlock in, FrameBlock out) {
                out.ensureAllocatedColumns(in.getNumRows());
+               int clen = Math.min(_colList.length, out.getNumColumns());
                for( int i=0; i<in.getNumRows(); i++ ) {
-                       for( int j=0; j<_colList.length; j++ ) {
+                       for( int j=0; j<clen; j++ ) {
                                int srcColID = _srcCols[j];
                                int tgtColID = _colList[j];
                                double val = in.quickGetValue(i, srcColID-1);

Reply via email to