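[SYSTEMML-2254] Robustness transformdecode for ncol(data) < ncol(meta)

This patch fixes the robustness of transformdecode. So far, we created the output frame with the number of columns of the given metadata frame. However, this leads to index-out-of-bounds exceptions if the given input data has fewer columns than the metadata frame, because unnecessary pass-through decoders try to access non-existing columns. The decoder factory now receives the number of input columns and, if no dummy coding is specified, restricts the pass-through columns and the default schema to min(ncol(meta), ncol(data)). The pass-through decoder additionally bounds its loop by the number of output columns.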
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/36a06ab6 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/36a06ab6 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/36a06ab6 Branch: refs/heads/master Commit: 36a06ab6876d3d92e01b8c94a46c77841d654792 Parents: 2838cef Author: Matthias Boehm <[email protected]> Authored: Wed Apr 18 00:14:27 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Wed Apr 18 00:14:27 2018 -0700 ---------------------------------------------------------------------- .../cp/ParameterizedBuiltinCPInstruction.java | 3 ++- .../transform/decode/DecoderFactory.java | 25 ++++++++++++-------- .../transform/decode/DecoderPassThrough.java | 3 ++- 3 files changed, 19 insertions(+), 12 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/36a06ab6/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java index fe18dbc..06aeef3 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java @@ -268,7 +268,8 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction String[] colnames = meta.getColumnNames(); //compute transformdecode - Decoder decoder = DecoderFactory.createDecoder(getParameterMap().get("spec"), colnames, null, meta); + Decoder decoder = DecoderFactory.createDecoder( + getParameterMap().get("spec"), colnames, null, meta, data.getNumColumns()); FrameBlock fbout = 
decoder.decode(data, new FrameBlock(decoder.getSchema())); fbout.setColumnNames(Arrays.copyOfRange(colnames, 0, fbout.getNumColumns())); http://git-wip-us.apache.org/repos/asf/systemml/blob/36a06ab6/src/main/java/org/apache/sysml/runtime/transform/decode/DecoderFactory.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/transform/decode/DecoderFactory.java b/src/main/java/org/apache/sysml/runtime/transform/decode/DecoderFactory.java index d8b9f3a..c2788cb 100644 --- a/src/main/java/org/apache/sysml/runtime/transform/decode/DecoderFactory.java +++ b/src/main/java/org/apache/sysml/runtime/transform/decode/DecoderFactory.java @@ -36,45 +36,50 @@ import org.apache.wink.json4j.JSONObject; public class DecoderFactory { + public static Decoder createDecoder(String spec, String[] colnames, ValueType[] schema, FrameBlock meta) { + return createDecoder(spec, colnames, schema, meta, meta.getNumColumns()); + } + @SuppressWarnings("unchecked") - public static Decoder createDecoder(String spec, String[] colnames, ValueType[] schema, FrameBlock meta) + public static Decoder createDecoder(String spec, String[] colnames, ValueType[] schema, FrameBlock meta, int clen) { Decoder decoder = null; - try + try { //parse transform specification JSONObject jSpec = new JSONObject(spec); List<Decoder> ldecoders = new ArrayList<>(); - + //create decoders 'recode', 'dummy' and 'pass-through' List<Integer> rcIDs = Arrays.asList(ArrayUtils.toObject( TfMetaUtils.parseJsonIDList(jSpec, colnames, TfUtils.TXMETHOD_RECODE))); List<Integer> dcIDs = Arrays.asList(ArrayUtils.toObject( TfMetaUtils.parseJsonIDList(jSpec, colnames, TfUtils.TXMETHOD_DUMMYCODE))); rcIDs = new ArrayList<Integer>(CollectionUtils.union(rcIDs, dcIDs)); + int len = dcIDs.isEmpty() ? 
Math.min(meta.getNumColumns(), clen) : meta.getNumColumns(); List<Integer> ptIDs = new ArrayList<Integer>(CollectionUtils - .subtract(UtilFunctions.getSeqList(1, meta.getNumColumns(), 1), rcIDs)); - + .subtract(UtilFunctions.getSeqList(1, len, 1), rcIDs)); + //create default schema if unspecified (with double columns for pass-through) if( schema == null ) { - schema = UtilFunctions.nCopies(meta.getNumColumns(), ValueType.STRING); + schema = UtilFunctions.nCopies(len, ValueType.STRING); for( Integer col : ptIDs ) schema[col-1] = ValueType.DOUBLE; } if( !dcIDs.isEmpty() ) { ldecoders.add(new DecoderDummycode(schema, - ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])))); + ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])))); } if( !rcIDs.isEmpty() ) { ldecoders.add(new DecoderRecode(schema, !dcIDs.isEmpty(), - ArrayUtils.toPrimitive(rcIDs.toArray(new Integer[0])))); + ArrayUtils.toPrimitive(rcIDs.toArray(new Integer[0])))); } if( !ptIDs.isEmpty() ) { ldecoders.add(new DecoderPassThrough(schema, - ArrayUtils.toPrimitive(ptIDs.toArray(new Integer[0])), - ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])))); + ArrayUtils.toPrimitive(ptIDs.toArray(new Integer[0])), + ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])))); } //create composite decoder of all created decoders http://git-wip-us.apache.org/repos/asf/systemml/blob/36a06ab6/src/main/java/org/apache/sysml/runtime/transform/decode/DecoderPassThrough.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/transform/decode/DecoderPassThrough.java b/src/main/java/org/apache/sysml/runtime/transform/decode/DecoderPassThrough.java index 8224b5c..79cf3d4 100644 --- a/src/main/java/org/apache/sysml/runtime/transform/decode/DecoderPassThrough.java +++ b/src/main/java/org/apache/sysml/runtime/transform/decode/DecoderPassThrough.java @@ -44,8 +44,9 @@ public class DecoderPassThrough extends Decoder @Override public FrameBlock 
decode(MatrixBlock in, FrameBlock out) { out.ensureAllocatedColumns(in.getNumRows()); + int clen = Math.min(_colList.length, out.getNumColumns()); for( int i=0; i<in.getNumRows(); i++ ) { - for( int j=0; j<_colList.length; j++ ) { + for( int j=0; j<clen; j++ ) { int srcColID = _srcCols[j]; int tgtColID = _colList[j]; double val = in.quickGetValue(i, srcColID-1);
