This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push: new 9efd321 [SYSTEMDS-2555] Federated transform dummycode encode/decode 9efd321 is described below commit 9efd32125ac0142fb108f8b25f5b8cbecea1c06c Author: Kevin Innerebner <kevin.innereb...@yahoo.com> AuthorDate: Sat Aug 29 00:17:51 2020 +0200 [SYSTEMDS-2555] Federated transform dummycode encode/decode Closes #1031. --- .../controlprogram/federated/FederationMap.java | 7 ++ ...tiReturnParameterizedBuiltinFEDInstruction.java | 25 +++-- .../fed/ParameterizedBuiltinFEDInstruction.java | 62 ++++++----- .../sysds/runtime/transform/decode/Decoder.java | 27 ++++- .../runtime/transform/decode/DecoderComposite.java | 27 ++++- .../runtime/transform/decode/DecoderDummycode.java | 59 +++++++++- .../transform/decode/DecoderPassThrough.java | 40 ++++++- .../runtime/transform/decode/DecoderRecode.java | 29 +++++ .../sysds/runtime/transform/encode/Encoder.java | 10 ++ .../sysds/runtime/transform/encode/EncoderBin.java | 2 +- .../runtime/transform/encode/EncoderComposite.java | 20 +++- .../runtime/transform/encode/EncoderDummycode.java | 122 +++++++++++++++++++-- .../runtime/transform/encode/EncoderFactory.java | 2 +- .../transform/encode/EncoderPassThrough.java | 3 - .../runtime/transform/encode/EncoderRecode.java | 13 ++- .../TransformFederatedEncodeDecodeTest.java | 96 +++++++++++----- ...dml => TransformDummyFederatedEncodeDecode.dml} | 5 - .../transform/TransformEncodeDecodeDummySpec.json | 5 + ...ml => TransformRecodeFederatedEncodeDecode.dml} | 0 19 files changed, 457 insertions(+), 97 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java index 72d1196..ea8aa29 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java @@ -240,6 +240,13 @@ public class FederationMap return this; } + + public long getMaxIndexInRange(int dim) { + long maxIx = 0; + for(FederatedRange range : _fedMap.keySet()) + maxIx = Math.max(range.getEndDims()[dim], maxIx); + return maxIx; + } /** * Execute a function for each <code>FederatedRange</code> + <code>FederatedData</code> pair. The function should diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java index 5d25729..b9b6203 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java @@ -44,6 +44,7 @@ import org.apache.sysds.runtime.matrix.operators.Operator; import org.apache.sysds.runtime.privacy.PrivacyMonitor; import org.apache.sysds.runtime.transform.encode.Encoder; import org.apache.sysds.runtime.transform.encode.EncoderComposite; +import org.apache.sysds.runtime.transform.encode.EncoderDummycode; import org.apache.sysds.runtime.transform.encode.EncoderFactory; import org.apache.sysds.runtime.transform.encode.EncoderPassThrough; import org.apache.sysds.runtime.transform.encode.EncoderRecode; @@ -88,7 +89,7 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE // the encoder in which the complete encoding information will be aggregated EncoderComposite globalEncoder = new EncoderComposite( - Arrays.asList(new EncoderRecode(), new EncoderPassThrough())); + Arrays.asList(new EncoderRecode(), new EncoderPassThrough(), new EncoderDummycode())); // first create encoders at the federated workers, then collect them and aggregate them to a single large // encoder FederationMap fedMapping = fin.getFedMapping(); @@ -115,14 +116,21 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE }); long varID = FederationUtils.getNextFedDataID(); FederationMap transformedFedMapping = fedMapping.mapParallel(varID, (range, data) -> { - int colStart = (int) range.getBeginDims()[1] + 1; - int colEnd = (int) range.getEndDims()[1] + 1; + // copy because we reuse it + long[] beginDims = range.getBeginDims(); + long[] endDims = range.getEndDims(); + int colStart = (int) beginDims[1] + 1; + int colEnd = (int) endDims[1] + 1; + + // update begin end dims (column part) considering columns added by dummycoding + globalEncoder.updateIndexRanges(beginDims, endDims); + // get the encoder segment that is relevant for this federated worker Encoder encoder = globalEncoder.subRangeEncoder(colStart, colEnd); + try { - FederatedResponse response = data.executeFederatedOperation( - new FederatedRequest(RequestType.EXEC_UDF, varID, - new ExecuteFrameEncoder(data.getVarID(), varID, encoder))).get(); + FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(RequestType.EXEC_UDF, + varID, new ExecuteFrameEncoder(data.getVarID(), varID, encoder))).get(); if(!response.isSuccessful()) response.throwExceptionFromResponse(); } @@ -134,13 +142,14 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE // construct a federated matrix with the encoded data MatrixObject transformedMat = ec.getMatrixObject(getOutput(0)); - transformedMat.getDataCharacteristics().set(fin.getDataCharacteristics()); + transformedMat.getDataCharacteristics().setRows(transformedFedMapping.getMaxIndexInRange(0)); + transformedMat.getDataCharacteristics().setCols(transformedFedMapping.getMaxIndexInRange(1)); // set the federated mapping for the matrix transformedMat.setFedMapping(transformedFedMapping); // release input and outputs ec.setFrameOutput(getOutput(1).getName(), - globalEncoder.getMetaData(new FrameBlock(globalEncoder.getNumCols(), Types.ValueType.STRING))); + globalEncoder.getMetaData(new FrameBlock((int) fin.getNumColumns(), Types.ValueType.STRING))); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java index e3523ed..47f912d 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java @@ -102,7 +102,8 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio return new ParameterizedBuiltinFEDInstruction(null, paramsMap, out, opcode, str); } else { - throw new DMLRuntimeException("Unsupported opcode (" + opcode + ") for ParameterizedBuiltinFEDInstruction."); + throw new DMLRuntimeException( + "Unsupported opcode (" + opcode + ") for ParameterizedBuiltinFEDInstruction."); } } @@ -135,22 +136,36 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio FrameBlock meta = ec.getFrameInput(params.get("meta")); String spec = params.get("spec"); + Decoder globalDecoder = DecoderFactory + .createDecoder(spec, meta.getColumnNames(), null, meta, (int) mo.getNumColumns()); + FederationMap fedMapping = mo.getFedMapping(); ValueType[] schema = new ValueType[(int) mo.getNumColumns()]; long varID = FederationUtils.getNextFedDataID(); FederationMap decodedMapping = fedMapping.mapParallel(varID, (range, data) -> { - int columnOffset = (int) range.getBeginDims()[1] + 1; + long[] beginDims = range.getBeginDims(); + long[] endDims = range.getEndDims(); + int colStartBefore = (int) beginDims[1]; + + // update begin end dims (column part) considering columns added by dummycoding + globalDecoder.updateIndexRanges(beginDims, endDims); - FrameBlock subMeta = new FrameBlock(); + // get the decoder segment that is relevant for this federated worker + Decoder decoder = globalDecoder + .subRangeDecoder((int) beginDims[1] + 1, (int) endDims[1] + 1, colStartBefore); + + FrameBlock metaSlice = new FrameBlock(); synchronized(meta) { - meta.slice(0, meta.getNumRows() - 1, columnOffset - 1, (int) range.getEndDims()[1] - 1, subMeta); + meta.slice(0, meta.getNumRows() - 1, (int) beginDims[1], (int) endDims[1] - 1, metaSlice); } + FederatedResponse response; try { - response = data.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, - varID, new DecodeMatrix(data.getVarID(), varID, subMeta, spec, columnOffset))).get(); + response = data.executeFederatedOperation( + new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, varID, + new DecodeMatrix(data.getVarID(), varID, metaSlice, decoder))).get(); if(!response.isSuccessful()) response.throwExceptionFromResponse(); @@ -158,7 +173,7 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio synchronized(schema) { // It would be possible to assert that different federated workers don't give different value // types for the same columns, but the performance impact is not worth the effort - System.arraycopy(subSchema, 0, schema, columnOffset - 1, subSchema.length); + System.arraycopy(subSchema, 0, schema, colStartBefore, subSchema.length); } } catch(Exception e) { @@ -169,8 +184,9 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio // construct a federated matrix with the encoded data FrameObject decodedFrame = ec.getFrameObject(output); - decodedFrame.setSchema(schema); + decodedFrame.setSchema(globalDecoder.getSchema()); decodedFrame.getDataCharacteristics().set(mo.getDataCharacteristics()); + decodedFrame.getDataCharacteristics().setCols(globalDecoder.getSchema().length); // set the federated mapping for the matrix decodedFrame.setFedMapping(decodedMapping); @@ -185,34 +201,28 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio private CPOperand getTargetOperand() { return new CPOperand(params.get("target"), ValueType.FP64, DataType.MATRIX); } - + public static class DecodeMatrix extends FederatedUDF { private static final long serialVersionUID = 2376756757742169692L; - private final long _outputID; + private final long _outputID; private final FrameBlock _meta; - private final String _spec; - private final int _globalOffset; - - public DecodeMatrix(long input, long outputID, FrameBlock meta, String spec, int globalOffset) { - super(new long[]{input}); + private final Decoder _decoder; + + public DecodeMatrix(long input, long outputID, FrameBlock meta, Decoder decoder) { + super(new long[] {input}); _outputID = outputID; _meta = meta; - _spec = spec; - _globalOffset = globalOffset; + _decoder = decoder; } - - @Override + public FederatedResponse execute(ExecutionContext ec, Data... data) { MatrixObject mo = (MatrixObject) PrivacyMonitor.handlePrivacy(data[0]); MatrixBlock mb = mo.acquireRead(); String[] colNames = _meta.getColumnNames(); - - // compute transformdecode - Decoder decoder = DecoderFactory.createDecoder(_spec, colNames, null, - _meta, mb.getNumColumns(), _globalOffset, _globalOffset + mb.getNumColumns()); - FrameBlock fbout = decoder.decode(mb, new FrameBlock(decoder.getSchema())); + + FrameBlock fbout = _decoder.decode(mb, new FrameBlock(_decoder.getSchema())); fbout.setColumnNames(Arrays.copyOfRange(colNames, 0, fbout.getNumColumns())); - + // copy characteristics MatrixCharacteristics mc = new MatrixCharacteristics(mo.getDataCharacteristics()); FrameObject fo = new FrameObject(OptimizerUtils.getUniqueTempFileName(), @@ -221,7 +231,7 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio fo.acquireModify(fbout); fo.release(); mo.release(); - + // add it to the list of variables ec.setVariable(String.valueOf(_outputID), fo); // return schema diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java b/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java index 2aeda0f..4417387 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java @@ -22,6 +22,7 @@ package org.apache.sysds.runtime.transform.decode; import java.io.Serializable; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.matrix.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -64,6 +65,30 @@ public abstract class Decoder implements Serializable * @return returns given output frame block for convenience */ public abstract FrameBlock decode(MatrixBlock in, FrameBlock out); - + + /** + * Returns a new Decoder that only handles a sub range of columns. The sub-range refers to the columns after + * decoding. + * + * @param colStart the start index of the sub-range (1-based, inclusive) + * @param colEnd the end index of the sub-range (1-based, exclusive) + * @param dummycodedOffset the offset of dummycoded segments before colStart + * @return a decoder of the same type, just for the sub-range + */ + public Decoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) { + throw new DMLRuntimeException( + getClass().getSimpleName() + " does not support the creation of a sub-range decoder"); + } + + /** + * Update index-ranges to after decoding. Note that only Dummycoding changes the ranges. + * + * @param beginDims the begin indexes before encoding + * @param endDims the end indexes before encoding + */ + public void updateIndexRanges(long[] beginDims, long[] endDims) { + // do nothing - default + } + public abstract void initMetaData(FrameBlock meta); } diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java index 69fcb41..263e064 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java @@ -19,6 +19,8 @@ package org.apache.sysds.runtime.transform.decode; +import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import org.apache.sysds.common.Types.ValueType; @@ -26,9 +28,9 @@ import org.apache.sysds.runtime.matrix.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; /** - * Simple composite decoder that applies a list of decoders + * Simple composite decoder that applies a list of decoders * in specified order. By implementing the default decoder API - * it can be used as a drop-in replacement for any other decoder. + * it can be used as a drop-in replacement for any other decoder. * */ public class DecoderComposite extends Decoder @@ -45,13 +47,30 @@ public class DecoderComposite extends Decoder @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { for( Decoder decoder : _decoders ) - out = decoder.decode(in, out); + out = decoder.decode(in, out); return out; } @Override + public Decoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) { + List<Decoder> subRangeDecoders = new ArrayList<>(); + for (Decoder decoder : _decoders) { + Decoder subDecoder = decoder.subRangeDecoder(colStart, colEnd, dummycodedOffset); + if (subDecoder != null) + subRangeDecoders.add(subDecoder); + } + return new DecoderComposite(Arrays.copyOfRange(_schema, colStart-1, colEnd-1), subRangeDecoders); + } + + @Override + public void updateIndexRanges(long[] beginDims, long[] endDims) { + for(Decoder dec : _decoders) + dec.updateIndexRanges(beginDims, endDims); + } + + @Override public void initMetaData(FrameBlock meta) { for( Decoder decoder : _decoders ) - decoder.initMetaData(meta); + decoder.initMetaData(meta); } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java index 0ad2187..ab1fbc8 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java @@ -19,6 +19,9 @@ package org.apache.sysds.runtime.transform.decode; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.matrix.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -43,6 +46,7 @@ public class DecoderDummycode extends Decoder @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { + //TODO perf (exploit sparse representation for better asymptotic behavior) out.ensureAllocatedColumns(in.getNumRows()); for( int i=0; i<in.getNumRows(); i++ ) for( int j=0; j<_colList.length; j++ ) @@ -50,11 +54,60 @@ public class DecoderDummycode extends Decoder if( in.quickGetValue(i, k-1) != 0 ) { int col = _colList[j] - 1; out.set(i, col, UtilFunctions.doubleToObject( - out.getSchema()[col], k-_clPos[j]+1)); - } + out.getSchema()[col], k-_clPos[j]+1)); + } return out; } - + + @Override + public Decoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) { + List<Integer> dcList = new ArrayList<>(); + List<Integer> clPosList = new ArrayList<>(); + List<Integer> cuPosList = new ArrayList<>(); + + // get the column IDs for the sub range of the dummycode columns and their destination positions, + // where they will be decoded to + for( int j=0; j<_colList.length; j++ ) { + int colID = _colList[j]; + if (colID >= colStart && colID < colEnd) { + dcList.add(colID - (colStart - 1)); + clPosList.add(_clPos[j] - dummycodedOffset); + cuPosList.add(_cuPos[j] - dummycodedOffset); + } + } + if (dcList.isEmpty()) + return null; + // create sub-range decoder + int[] colList = dcList.stream().mapToInt(i -> i).toArray(); + DecoderDummycode subRangeDecoder = new DecoderDummycode( + Arrays.copyOfRange(_schema, colStart - 1, colEnd - 1), colList); + subRangeDecoder._clPos = clPosList.stream().mapToInt(i -> i).toArray(); + subRangeDecoder._cuPos = cuPosList.stream().mapToInt(i -> i).toArray(); + return subRangeDecoder; + } + + @Override + public void updateIndexRanges(long[] beginDims, long[] endDims) { + if(_colList == null) + return; + + long lowerColDest = beginDims[1]; + long upperColDest = endDims[1]; + for(int i = 0; i < _colList.length; i++) { + long numDistinct = _cuPos[i] - _clPos[i]; + + if(_cuPos[i] <= beginDims[1] + 1) + if(numDistinct > 0) + lowerColDest -= numDistinct - 1; + + if(_cuPos[i] <= endDims[1] + 1) + if(numDistinct > 0) + upperColDest -= numDistinct - 1; + } + beginDims[1] = lowerColDest; + endDims[1] = upperColDest; + } + @Override public void initMetaData(FrameBlock meta) { _clPos = new int[_colList.length]; //col lower pos diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java index 206ac74..753c666 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java @@ -19,6 +19,10 @@ package org.apache.sysds.runtime.transform.decode; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.matrix.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -50,14 +54,43 @@ public class DecoderPassThrough extends Decoder int srcColID = _srcCols[j]; int tgtColID = _colList[j]; double val = in.quickGetValue(i, srcColID-1); - out.set(i, tgtColID-1, UtilFunctions.doubleToObject( - _schema[tgtColID-1], val)); + out.set(i, tgtColID-1, + UtilFunctions.doubleToObject(_schema[tgtColID-1], val)); } } return out; } @Override + public Decoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) { + List<Integer> colList = new ArrayList<>(); + List<Integer> dcList = new ArrayList<>(); + List<Integer> srcList = new ArrayList<>(); + + for (int i = 0; i < _colList.length; i++) { + int colID = _colList[i]; + if (colID >= colStart && colID < colEnd) { + colList.add(colID - (colStart - 1)); + srcList.add(_srcCols[i] - dummycodedOffset); + } + } + + Arrays.stream(_dcCols) + .filter(c -> c >= colStart && c < colEnd) + .forEach(c -> dcList.add(c)); + + if (colList.isEmpty()) + // empty decoder -> return null + return null; + + DecoderPassThrough decoder = new DecoderPassThrough(Arrays.copyOfRange(_schema, colStart - 1, colEnd - 1), + colList.stream().mapToInt(i -> i).toArray(), + dcList.stream().mapToInt(i -> i).toArray()); + decoder._srcCols = srcList.stream().mapToInt(i -> i).toArray(); + return decoder; + } + + @Override public void initMetaData(FrameBlock meta) { if( _dcCols.length > 0 ) { //prepare source column id mapping w/ dummy coding @@ -69,8 +102,7 @@ public class DecoderPassThrough extends Decoder ix1 ++; } else { //_colList[ix1] > _dcCols[ix2] - off += (int)meta.getColumnMetadata()[_dcCols[ix2]-1] - .getNumDistinct() - 1; + off += (int)meta.getColumnMetadata()[_dcCols[ix2]-1].getNumDistinct() - 1; ix2 ++; } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java index 5ebb8cc..9ae315f 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java @@ -19,8 +19,11 @@ package org.apache.sysds.runtime.transform.decode; +import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; +import java.util.List; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.matrix.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -74,6 +77,32 @@ public class DecoderRecode extends Decoder @Override @SuppressWarnings("unchecked") + public Decoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) { + List<Integer> cols = new ArrayList<>(); + List<HashMap<Long, Object>> rcMaps = new ArrayList<>(); + for(int i = 0; i < _colList.length; i++) { + int col = _colList[i]; + if(col >= colStart && col < colEnd) { + // add the correct column, removed columns before start + // colStart - 1 because colStart is 1-based + int corrColumn = col - (colStart - 1); + cols.add(corrColumn); + rcMaps.add(new HashMap<>(_rcMaps[i])); + } + } + if(cols.isEmpty()) + // empty encoder -> sub range encoder does not exist + return null; + + int[] colList = cols.stream().mapToInt(i -> i).toArray(); + DecoderRecode subRangeDecoder = new DecoderRecode( + Arrays.copyOfRange(_schema, colStart - 1, colEnd - 1), _onOut, colList); + subRangeDecoder._rcMaps = rcMaps.toArray(new HashMap[0]); + return subRangeDecoder; + } + + @Override + @SuppressWarnings("unchecked") public void initMetaData(FrameBlock meta) { //initialize recode maps according to schema _rcMaps = new HashMap[_colList.length]; diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java index 5945e27..19271f8 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java @@ -168,6 +168,16 @@ public abstract class Encoder implements Serializable throw new DMLRuntimeException( this.getClass().getName() + " does not support merging with " + other.getClass().getName()); } + + /** + * Update index-ranges to after encoding. Note that only Dummycoding changes the ranges. + * + * @param beginDims the begin indexes before encoding + * @param endDims the end indexes before encoding + */ + public void updateIndexRanges(long[] beginDims, long[] endDims) { + // do nothing - default + } /** * Construct a frame block out of the transform meta data. diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java index b170e22..3be9ed9 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java @@ -42,7 +42,7 @@ public class EncoderBin extends Encoder public static final String MAX_PREFIX = "max"; public static final String NBINS_PREFIX = "nbins"; - private int[] _numBins = null; + protected int[] _numBins = null; //frame transform-apply attributes //TODO binMins is redundant and could be removed diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java index e653307..cd21f45 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java @@ -134,17 +134,35 @@ public class EncoderComposite extends Encoder + "CompositeEncoder: " + otherEnc.getClass().getSimpleName()); } } + // update dummycode encoder domain sizes based on distinctness information from other encoders + for (Encoder encoder : _encoders) { + if (encoder instanceof EncoderDummycode) { + ((EncoderDummycode) encoder).updateDomainSizes(_encoders); + return; + } + } return; } for (Encoder encoder : _encoders) { if (encoder.getClass() == other.getClass()) { encoder.mergeAt(other, col); + // update dummycode encoder domain sizes based on distinctness information from other encoders + if (encoder instanceof EncoderDummycode) { + ((EncoderDummycode) encoder).updateDomainSizes(_encoders); + } return; } } super.mergeAt(other, col); } - + + @Override + public void updateIndexRanges(long[] beginDims, long[] endDims) { + for(Encoder enc : _encoders) { + enc.updateIndexRanges(beginDims, endDims); + } + } + @Override public FrameBlock getMetaData(FrameBlock out) { if( _meta != null ) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java index ea66479..8ff5e57 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java @@ -19,28 +19,40 @@ package org.apache.sysds.runtime.transform.encode; -import org.apache.wink.json4j.JSONException; -import org.apache.wink.json4j.JSONObject; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + import org.apache.sysds.runtime.matrix.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.transform.TfUtils.TfMethod; import org.apache.sysds.runtime.transform.meta.TfMetaUtils; +import org.apache.wink.json4j.JSONException; +import org.apache.wink.json4j.JSONObject; public class EncoderDummycode extends Encoder { private static final long serialVersionUID = 5832130477659116489L; - private int[] _domainSizes = null; // length = #of dummycoded columns + public int[] _domainSizes = null; // length = #of dummycoded columns private long _dummycodedLength = 0; // #of columns after dummycoded - public EncoderDummycode(JSONObject parsedSpec, String[] colnames, int clen) throws JSONException { + public EncoderDummycode(JSONObject parsedSpec, String[] colnames, int clen, int minCol, int maxCol) + throws JSONException { super(null, clen); - - if ( parsedSpec.containsKey(TfMethod.DUMMYCODE.toString()) ) { - int[] collist = TfMetaUtils.parseJsonIDList(parsedSpec, colnames, TfMethod.DUMMYCODE.toString()); + + if(parsedSpec.containsKey(TfMethod.DUMMYCODE.toString())) { + int[] collist = TfMetaUtils + .parseJsonIDList(parsedSpec, colnames, TfMethod.DUMMYCODE.toString(), minCol, maxCol); initColList(collist); } } + + public EncoderDummycode() { + super(new int[0], 0); + } @Override public int getNumCols() { @@ -85,6 +97,102 @@ public class EncoderDummycode extends Encoder } @Override + public Encoder subRangeEncoder(int colStart, int colEnd) { + List<Integer> cols = new ArrayList<>(); + List<Integer> domainSizes = new ArrayList<>(); + int newDummycodedLength = colEnd - colStart; + for(int i = 0; i < _colList.length; i++){ + int col = _colList[i]; + if(col >= colStart && col < colEnd) { + // add the correct column, removed columns before start + // colStart - 1 because colStart is 1-based + int corrColumn = col - (colStart - 1); + cols.add(corrColumn); + domainSizes.add(_domainSizes[i]); + newDummycodedLength += _domainSizes[i] - 1; + } + } + if(cols.isEmpty()) + // empty encoder -> sub range encoder does not exist + return null; + + EncoderDummycode subRangeEncoder = new EncoderDummycode(); + subRangeEncoder._clen = colEnd - colStart; + subRangeEncoder._colList = cols.stream().mapToInt(i -> i).toArray(); + subRangeEncoder._domainSizes = domainSizes.stream().mapToInt(i -> i).toArray(); + subRangeEncoder._dummycodedLength = newDummycodedLength; + return subRangeEncoder; + } + + @Override + public void mergeAt(Encoder other, int col) { + if(other instanceof EncoderDummycode) { + mergeColumnInfo(other, col); + + _domainSizes = new int[_colList.length]; + _dummycodedLength = _clen; + // temporary, will be updated later + Arrays.fill(_domainSizes, 0, _colList.length, 1); + return; + } + super.mergeAt(other, col); + } + + @Override + public void updateIndexRanges(long[] beginDims, long[] endDims) { + long[] initialBegin = Arrays.copyOf(beginDims, beginDims.length); + long[] initialEnd = Arrays.copyOf(endDims, endDims.length); + for(int i = 0; i < _colList.length; i++) { + // 1-based vs 0-based + if(_colList[i] < initialBegin[1] + 1) { + // new columns inserted left of the columns of this partial (federated) block + beginDims[1] += _domainSizes[i] - 1; + endDims[1] += _domainSizes[i] - 1; + } + else if(_colList[i] < initialEnd[1] + 1) { + // new columns inserted in this (federated) block + endDims[1] += _domainSizes[i] - 1; + } + } + } + + public void updateDomainSizes(List<Encoder> encoders) { + if(_colList == null) + return; + + // maps the column ids of the columns encoded by this Dummycode Encoder to their respective indexes + // in the _colList + Map<Integer, Integer> colIDToIxMap = new HashMap<>(); + for (int i = 0; i < _colList.length; i++) + colIDToIxMap.put(_colList[i], i); + + _dummycodedLength = _clen; + for (Encoder encoder : encoders) { + int[] distinct = null; + if (encoder instanceof EncoderRecode) { + EncoderRecode encoderRecode = (EncoderRecode) encoder; + distinct = encoderRecode.numDistinctValues(); + } + else if (encoder instanceof EncoderBin) { + distinct = ((EncoderBin) encoder)._numBins; + } + + if (distinct != null) { + // search for match of encoded columns + for (int i = 0; i < encoder._colList.length; i++) { + Integer ix = colIDToIxMap.get(encoder._colList[i]); + + if (ix != null) { + // set size + _domainSizes[ix] = distinct[i]; + _dummycodedLength += _domainSizes[ix] - 1; + } + } + } + } + } + + @Override public FrameBlock getMetaData(FrameBlock out) { return out; } diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java index 2070485..57f7102 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java @@ -100,7 +100,7 @@ public class EncoderFactory if( !binIDs.isEmpty() ) lencoders.add(new EncoderBin(jSpec, colnames, schema.length)); if( !dcIDs.isEmpty() ) - lencoders.add(new EncoderDummycode(jSpec, colnames, schema.length)); + lencoders.add(new EncoderDummycode(jSpec, colnames, schema.length, minCol, maxCol)); if( !oIDs.isEmpty() ) lencoders.add(new EncoderOmit(jSpec, colnames, schema.length)); if( !mvIDs.isEmpty() ) { diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java index 8b3d36a..d6ceb15 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java @@ -73,9 +73,6 @@ public class EncoderPassThrough extends Encoder @Override public Encoder subRangeEncoder(int colStart, int colEnd) { - if (colStart - 1 >= _clen) - return null; - List<Integer> colList = new ArrayList<>(); for (int col : _colList) { if (col >= colStart && col < colEnd) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java index d4b201e..be1cba9 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java @@ -165,9 +165,6 @@ public class EncoderRecode extends Encoder @Override public Encoder subRangeEncoder(int colStart, int colEnd) { - if (colStart - 1 >= _clen) - return null; - List<Integer> cols = new ArrayList<>(); HashMap<Integer, HashMap<String, Long>> rcdMaps = new HashMap<>(); for (int col : _colList) { @@ -216,6 +213,16 @@ public class EncoderRecode extends Encoder } super.mergeAt(other, col); } + + public int[] numDistinctValues() { + int[] numDistinct = new int[_colList.length]; + + for( int j=0; j<_colList.length; j++ ) { + int colID = _colList[j]; //1-based + numDistinct[j] = _rcdMaps.get(colID).size(); + } + return numDistinct; + } @Override public FrameBlock getMetaData(FrameBlock meta) { diff --git a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java index ceee7da..29afa5b 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java @@ -35,11 +35,13 @@ import org.junit.Assert; import org.junit.Test; public class TransformFederatedEncodeDecodeTest extends AutomatedTestBase { - private static final String TEST_NAME1 = "TransformFederatedEncodeDecode"; + private static final String TEST_NAME_RECODE = "TransformRecodeFederatedEncodeDecode"; + private static final String TEST_NAME_DUMMY = "TransformDummyFederatedEncodeDecode"; private static final String TEST_DIR = "functions/transform/"; private static final String TEST_CLASS_DIR = TEST_DIR+TransformFederatedEncodeDecodeTest.class.getSimpleName()+"/"; - private static final String SPEC = "TransformEncodeDecodeSpec.json"; + private static final String SPEC_RECODE = "TransformEncodeDecodeSpec.json"; + private static final String SPEC_DUMMYCODE = "TransformEncodeDecodeDummySpec.json"; private static final int rows = 1234; private static final int cols = 2; @@ -49,47 +51,78 @@ public class TransformFederatedEncodeDecodeTest extends AutomatedTestBase { @Override public void setUp() { TestUtils.clearAssertionInformation(); - addTestConfiguration(TEST_NAME1, - new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"FO1", "FO2"})); + addTestConfiguration(TEST_NAME_RECODE, + new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_RECODE, new String[] {"FO1", "FO2"})); } @Test - public void runTestCSVDenseCP() { - runTransformEncodeDecodeTest(false, Types.FileFormat.CSV); + public void runComplexRecodeTestCSVDenseCP() { + runTransformEncodeDecodeTest(true, false, Types.FileFormat.CSV); } @Test - public void runTestCSVSparseCP() { - runTransformEncodeDecodeTest(true, Types.FileFormat.CSV); + public void runComplexRecodeTestCSVSparseCP() { + runTransformEncodeDecodeTest(true, true, Types.FileFormat.CSV); } @Test - public void runTestTextcellDenseCP() { - runTransformEncodeDecodeTest(false, Types.FileFormat.TEXT); + public void runComplexRecodeTestTextcellDenseCP() { + runTransformEncodeDecodeTest(true, false, Types.FileFormat.TEXT); } @Test - public void runTestTextcellSparseCP() { - runTransformEncodeDecodeTest(true, Types.FileFormat.TEXT); + public void runComplexRecodeTestTextcellSparseCP() { + runTransformEncodeDecodeTest(true, true, Types.FileFormat.TEXT); } @Test - public void runTestBinaryDenseCP() { - runTransformEncodeDecodeTest(false, Types.FileFormat.BINARY); + public void runComplexRecodeTestBinaryDenseCP() { + runTransformEncodeDecodeTest(true, false, Types.FileFormat.BINARY); } @Test - public void runTestBinarySparseCP() { - runTransformEncodeDecodeTest(true, Types.FileFormat.BINARY); + public void runComplexRecodeTestBinarySparseCP() { + runTransformEncodeDecodeTest(true, true, Types.FileFormat.BINARY); + } + + @Test + public void runSimpleDummycodeTestCSVDenseCP() { + runTransformEncodeDecodeTest(false, false, Types.FileFormat.CSV); + } + + @Test + public void runSimpleDummycodeTestCSVSparseCP() { + runTransformEncodeDecodeTest(false, true, Types.FileFormat.CSV); + } + + @Test + public void runSimpleDummycodeTestTextDenseCP() { + runTransformEncodeDecodeTest(false, false, Types.FileFormat.TEXT); + } + + @Test + public void runSimpleDummycodeTestTextSparseCP() { + runTransformEncodeDecodeTest(false, true, Types.FileFormat.TEXT); + } + + @Test + public void runSimpleDummycodeTestBinaryDenseCP() { + runTransformEncodeDecodeTest(false, false, Types.FileFormat.BINARY); + } + + @Test + public void runSimpleDummycodeTestBinarySparseCP() { + runTransformEncodeDecodeTest(false, true, Types.FileFormat.BINARY); } - private void runTransformEncodeDecodeTest(boolean sparse, Types.FileFormat format) { + private void runTransformEncodeDecodeTest(boolean recode, boolean sparse, + Types.FileFormat format) { ExecMode platformOld = rtplatform; rtplatform = ExecMode.SINGLE_NODE; Thread t1 = null, t2 = null, t3 = null, t4 = null; try { - getAndLoadTestConfiguration(TEST_NAME1); + getAndLoadTestConfiguration(TEST_NAME_RECODE); int port1 = getRandomAvailablePort(); t1 = startLocalFedWorkerThread(port1); @@ -120,14 +153,15 @@ public class TransformFederatedEncodeDecodeTest extends AutomatedTestBase { writeInputFrameWithMTD("BU", BUpper, false, schema, format); writeInputFrameWithMTD("BL", BLower, false, schema, format); - fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME1 + ".dml"; + fullDMLScriptName = SCRIPT_DIR + TEST_DIR + (recode ? TEST_NAME_RECODE : TEST_NAME_DUMMY) + ".dml"; + String spec_file = recode ? SPEC_RECODE : SPEC_DUMMYCODE; programArgs = new String[] {"-nvargs", "in_AU=" + TestUtils.federatedAddress("localhost", port1, input("AU")), "in_AL=" + TestUtils.federatedAddress("localhost", port2, input("AL")), "in_BU=" + TestUtils.federatedAddress("localhost", port3, input("BU")), "in_BL=" + TestUtils.federatedAddress("localhost", port4, input("BL")), "rows=" + rows, "cols=" + cols, - "spec_file=" + SCRIPT_DIR + TEST_DIR + SPEC, "out1=" + output("FO1"), "out2=" + output("FO2"), + "spec_file=" + SCRIPT_DIR + TEST_DIR + spec_file, "out1=" + output("FO1"), "out2=" + output("FO2"), "format=" + format.toString()}; // run test @@ -144,16 +178,18 @@ public class TransformFederatedEncodeDecodeTest extends AutomatedTestBase { + val, expected, val); } } - // TODO federate the aggregated result so that the decode is applied in a federated environment - // compare matrices (values recoded to identical codes) - FrameBlock FO = reader.readFrameFromHDFS(output("FO1"), 15, 2); - HashMap<String, Long> cFA = getCounts(A, B); - Iterator<String[]> iterFO = FO.getStringRowIterator(); - while(iterFO.hasNext()) { - String[] row = iterFO.next(); - Double expected = (double) cFA.get(row[1]); - Double val = (row[0] != null) ? Double.parseDouble(row[0]) : 0; - Assert.assertEquals("Output aggregates don't match: " + expected + " vs " + val, expected, val); + if(recode) { + // TODO federate the aggregated result so that the decode is applied in a federated environment + // compare matrices (values recoded to identical codes) + FrameBlock FO = reader.readFrameFromHDFS(output("FO1"), 15, 2); + HashMap<String, Long> cFA = getCounts(A, B); + Iterator<String[]> iterFO = FO.getStringRowIterator(); + while(iterFO.hasNext()) { + String[] row = iterFO.next(); + Double expected = (double) cFA.get(row[1]); + Double val = (row[0] != null) ? Double.parseDouble(row[0]) : 0; + Assert.assertEquals("Output aggregates don't match: " + expected + " vs " + val, expected, val); + } } } catch(Exception ex) { diff --git a/src/test/scripts/functions/transform/TransformFederatedEncodeDecode.dml b/src/test/scripts/functions/transform/TransformDummyFederatedEncodeDecode.dml similarity index 89% copy from src/test/scripts/functions/transform/TransformFederatedEncodeDecode.dml copy to src/test/scripts/functions/transform/TransformDummyFederatedEncodeDecode.dml index 50174d7..f029719 100644 --- a/src/test/scripts/functions/transform/TransformFederatedEncodeDecode.dml +++ b/src/test/scripts/functions/transform/TransformDummyFederatedEncodeDecode.dml @@ -28,11 +28,6 @@ jspec = read($spec_file, data_type="scalar", value_type="string"); [X, M] = transformencode(target=F, spec=jspec); -A = aggregate(target=X[,1], groups=X[,2], fn="count"); -Ag = cbind(A, seq(1,nrow(A))); - -FO1 = transformdecode(target=Ag, spec=jspec, meta=M); FO2 = transformdecode(target=X, spec=jspec, meta=M); -write(FO1, $out1, format=$format); write(FO2, $out2, format=$format); diff --git a/src/test/scripts/functions/transform/TransformEncodeDecodeDummySpec.json b/src/test/scripts/functions/transform/TransformEncodeDecodeDummySpec.json new file mode 100644 index 0000000..5f4aa12 --- /dev/null +++ b/src/test/scripts/functions/transform/TransformEncodeDecodeDummySpec.json @@ -0,0 +1,5 @@ +{ + "ids": true + ,"dummycode": [ 2 ] + +} diff --git a/src/test/scripts/functions/transform/TransformFederatedEncodeDecode.dml b/src/test/scripts/functions/transform/TransformRecodeFederatedEncodeDecode.dml similarity index 100% rename from src/test/scripts/functions/transform/TransformFederatedEncodeDecode.dml rename to src/test/scripts/functions/transform/TransformRecodeFederatedEncodeDecode.dml