This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
     new 3743282  [SYSTEMDS-2554-7] Federated frame transformapply, incl binning, omit
3743282 is described below

commit 37432820067175b8b49f21754bc4df7959971d7b
Author: Kevin Innerebner <kevin.innereb...@yahoo.com>
AuthorDate: Sat Aug 29 22:33:58 2020 +0200

    [SYSTEMDS-2554-7] Federated frame transformapply, incl binning, omit
    
    Closes #1032.
---
 .../controlprogram/federated/FederatedRange.java   |   6 +
 .../federated/FederatedWorkerHandler.java          |  10 +-
 .../controlprogram/federated/FederationMap.java    |   7 +-
 .../cp/ParameterizedBuiltinCPInstruction.java      |   7 +-
 .../instructions/fed/FEDInstructionUtils.java      |   3 +-
 ...tiReturnParameterizedBuiltinFEDInstruction.java |  58 +++--
 .../fed/ParameterizedBuiltinFEDInstruction.java    | 152 ++++++++++--
 .../sysds/runtime/io/FileFormatPropertiesCSV.java  |   8 +-
 .../sysds/runtime/transform/encode/Encoder.java    |  27 +-
 .../sysds/runtime/transform/encode/EncoderBin.java | 120 ++++++++-
 .../runtime/transform/encode/EncoderComposite.java |  14 +-
 .../runtime/transform/encode/EncoderDummycode.java |  25 +-
 .../runtime/transform/encode/EncoderFactory.java   |  10 +-
 .../transform/encode/EncoderFeatureHash.java       |  38 ++-
 .../runtime/transform/encode/EncoderMVImpute.java  |   2 +-
 .../runtime/transform/encode/EncoderOmit.java      | 145 ++++++++---
 .../transform/encode/EncoderPassThrough.java       |  14 +-
 .../runtime/transform/encode/EncoderRecode.java    |  15 +-
 .../sysds/runtime/transform/meta/TfMetaUtils.java  |  79 +++---
 .../org/apache/sysds/runtime/util/HDFSTool.java    |   7 +
 .../org/apache/sysds/runtime/util/IndexRange.java  |  18 +-
 .../TransformFederatedEncodeApplyTest.java         | 273 +++++++++++++++++++++
 .../TransformFederatedEncodeDecodeTest.java        |  15 +-
 .../transform/TransformFederatedEncodeApply.dml    |  36 +++
 24 files changed, 904 insertions(+), 185 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
index 23d0269..4289cfe 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
@@ -21,6 +21,8 @@ package org.apache.sysds.runtime.controlprogram.federated;
 
 import java.util.Arrays;
 
+import org.apache.sysds.runtime.util.IndexRange;
+
 public class FederatedRange implements Comparable<FederatedRange> {
 	private long[] _beginDims;
 	private long[] _endDims;
@@ -119,4 +121,8 @@ public class FederatedRange implements Comparable<FederatedRange> {
 		_endDims[1] = tmpEnd;
 		return this;
 	}
+
+	public IndexRange asIndexRange() {
+		return new IndexRange(_beginDims[0], _endDims[0], _beginDims[1], _endDims[1]);
+	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index a2e62fe..b5f0ec8 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -45,6 +45,7 @@ import org.apache.sysds.runtime.instructions.InstructionParser;
 import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
 import
org.apache.sysds.runtime.io.IOUtilFunctions; import org.apache.sysds.runtime.meta.MatrixCharacteristics; import org.apache.sysds.runtime.meta.MetaDataFormat; @@ -184,10 +185,10 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { // read metadata FileFormat fmt = null; + boolean header = false; try { String mtdname = DataExpression.getMTDFileName(filename); Path path = new Path(mtdname); - FileSystem fs = IOUtilFunctions.getFileSystem(mtdname); //no auto-close try (BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(path)))) { JSONObject mtd = JSONHelper.parse(br); @@ -198,7 +199,8 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { mc.setCols(mtd.getLong(DataExpression.READCOLPARAM)); if(mtd.containsKey(DataExpression.READNNZPARAM)) mc.setNonZeros(mtd.getLong(DataExpression.READNNZPARAM)); - + if (mtd.has(DataExpression.DELIM_HAS_HEADER_ROW)) + header = mtd.getBoolean(DataExpression.DELIM_HAS_HEADER_ROW); cd = (CacheableData<?>) PrivacyPropagator.parseAndSetPrivacyConstraint(cd, mtd); fmt = FileFormat.safeValueOf(mtd.getString(DataExpression.FORMAT_TYPE)); } @@ -209,6 +211,10 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { //put meta data object in symbol table, read on first operation cd.setMetaData(new MetaDataFormat(mc, fmt)); + // TODO send FileFormatProperties with request and use them for CSV, this is currently a workaround so reading + // of CSV files works + cd.setFileFormatProperties(new FileFormatPropertiesCSV(header, DataExpression.DEFAULT_DELIM_DELIMITER, + DataExpression.DEFAULT_DELIM_SPARSE)); cd.enableCleanup(false); //guard against deletion _ecm.get(tid).setVariable(String.valueOf(id), cd); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java index ea8aa29..7d537c9 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java @@ -242,10 +242,9 @@ public class FederationMap public long getMaxIndexInRange(int dim) { - long maxIx = 0; - for(FederatedRange range : _fedMap.keySet()) - maxIx = Math.max(range.getEndDims()[dim], maxIx); - return maxIx; + return _fedMap.keySet().stream() + .mapToLong(range -> range.getEndDims()[dim]).max() + .orElse(-1L); } /** diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java index cfb20e3..5c71780 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java @@ -37,6 +37,7 @@ import org.apache.sysds.parser.ParameterizedBuiltinFunctionExpression; import org.apache.sysds.parser.Statement; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; +import org.apache.sysds.runtime.controlprogram.caching.CacheableData; import org.apache.sysds.runtime.controlprogram.caching.FrameObject; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.caching.TensorObject; @@ -304,7 +305,7 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction //get input spec and path String spec 
= getParameterMap().get("spec"); String path = getParameterMap().get(ParameterizedBuiltinFunctionExpression.TF_FN_PARAM_MTD); - String delim = getParameterMap().containsKey("sep") ? getParameterMap().get("sep") : TfUtils.TXMTD_SEP; + String delim = getParameterMap().getOrDefault("sep", TfUtils.TXMTD_SEP); //execute transform meta data read FrameBlock meta = null; @@ -457,8 +458,8 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction } } - public MatrixObject getTarget(ExecutionContext ec) { - return ec.getMatrixObject(params.get("target")); + public CacheableData<?> getTarget(ExecutionContext ec) { + return ec.getCacheableData(params.get("target")); } private CPOperand getTargetOperand() { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java index a1b0a08..2e41aa5 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java @@ -79,7 +79,8 @@ public class FEDInstructionUtils { if(pinst.getOpcode().equals("replace") && pinst.getTarget(ec).isFederated()) { fedinst = ParameterizedBuiltinFEDInstruction.parseInstruction(pinst.getInstructionString()); } - else if(pinst.getOpcode().equals("transformdecode") && pinst.getTarget(ec).isFederated()) { + else if((pinst.getOpcode().equals("transformdecode") || pinst.getOpcode().equals("transformapply")) && + pinst.getTarget(ec).isFederated()) { return ParameterizedBuiltinFEDInstruction.parseInstruction(pinst.getInstructionString()); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java index b9b6203..0fe12b9 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java @@ -43,11 +43,15 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.Operator; import org.apache.sysds.runtime.privacy.PrivacyMonitor; import org.apache.sysds.runtime.transform.encode.Encoder; +import org.apache.sysds.runtime.transform.encode.EncoderBin; import org.apache.sysds.runtime.transform.encode.EncoderComposite; import org.apache.sysds.runtime.transform.encode.EncoderDummycode; import org.apache.sysds.runtime.transform.encode.EncoderFactory; +import org.apache.sysds.runtime.transform.encode.EncoderFeatureHash; +import org.apache.sysds.runtime.transform.encode.EncoderOmit; import org.apache.sysds.runtime.transform.encode.EncoderPassThrough; import org.apache.sysds.runtime.transform.encode.EncoderRecode; +import org.apache.sysds.runtime.util.IndexRange; public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFEDInstruction { protected final ArrayList<CPOperand> _outputs; @@ -86,10 +90,19 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE // obtain and pin input frame FrameObject fin = ec.getFrameObject(input1.getName()); String spec = ec.getScalarInput(input2).getStringValue(); + + String[] colNames = new String[(int) fin.getNumColumns()]; + Arrays.fill(colNames, ""); // the encoder in which the complete encoding information will be aggregated EncoderComposite 
globalEncoder = new EncoderComposite( - Arrays.asList(new EncoderRecode(), new EncoderPassThrough(), new EncoderDummycode())); + // IMPORTANT: Encoder order matters + Arrays.asList(new EncoderRecode(), + new EncoderFeatureHash(), + new EncoderPassThrough(), + new EncoderBin(), + new EncoderDummycode(), + new EncoderOmit(true))); // first create encoders at the federated workers, then collect them and aggregate them to a single large // encoder FederationMap fedMapping = fin.getFedMapping(); @@ -98,39 +111,55 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE // create an encoder with the given spec. The columnOffset (which is 1 based) has to be used to // tell the federated worker how much the indexes in the spec have to be offset. - Future<FederatedResponse> response = data.executeFederatedOperation( - new FederatedRequest(RequestType.EXEC_UDF, data.getVarID(), + Future<FederatedResponse> responseFuture = data.executeFederatedOperation( + new FederatedRequest(RequestType.EXEC_UDF, -1, new CreateFrameEncoder(data.getVarID(), spec, columnOffset))); // collect responses with encoders try { - Encoder encoder = (Encoder) response.get().getData()[0]; + FederatedResponse response = responseFuture.get(); + Encoder encoder = (Encoder) response.getData()[0]; // merge this encoder into a composite encoder synchronized(globalEncoder) { globalEncoder.mergeAt(encoder, columnOffset); } + // no synchronization necessary since names should anyway match + String[] subRangeColNames = (String[]) response.getData()[1]; + System.arraycopy(subRangeColNames, 0, colNames, (int) range.getBeginDims()[1], subRangeColNames.length); } catch(Exception e) { throw new DMLRuntimeException("Federated encoder creation failed: " + e.getMessage()); } return null; }); + FrameBlock meta = new FrameBlock((int) fin.getNumColumns(), Types.ValueType.STRING); + meta.setColumnNames(colNames); + globalEncoder.getMetaData(meta); + globalEncoder.initMetaData(meta); + + encodeFederatedFrames(fedMapping, globalEncoder, ec.getMatrixObject(getOutput(0))); + + // release input and outputs + ec.setFrameOutput(getOutput(1).getName(), meta); + } + + public static void encodeFederatedFrames(FederationMap fedMapping, Encoder globalEncoder, + MatrixObject transformedMat) { long varID = FederationUtils.getNextFedDataID(); FederationMap transformedFedMapping = fedMapping.mapParallel(varID, (range, data) -> { // copy because we reuse it long[] beginDims = range.getBeginDims(); long[] endDims = range.getEndDims(); - int colStart = (int) beginDims[1] + 1; - int colEnd = (int) endDims[1] + 1; + IndexRange ixRange = new IndexRange(beginDims[0], endDims[0], beginDims[1], endDims[1]).add(1);// make 1-based // update begin end dims (column part) considering columns added by dummycoding globalEncoder.updateIndexRanges(beginDims, endDims); // get the encoder segment that is relevant for this federated worker - Encoder encoder = globalEncoder.subRangeEncoder(colStart, colEnd); + Encoder encoder = globalEncoder.subRangeEncoder(ixRange); try { FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(RequestType.EXEC_UDF, - varID, new ExecuteFrameEncoder(data.getVarID(), varID, encoder))).get(); + -1, new ExecuteFrameEncoder(data.getVarID(), varID, encoder))).get(); if(!response.isSuccessful()) response.throwExceptionFromResponse(); } @@ -141,18 +170,11 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE }); // construct a federated matrix with the encoded data - MatrixObject 
transformedMat = ec.getMatrixObject(getOutput(0)); - transformedMat.getDataCharacteristics().setRows(transformedFedMapping.getMaxIndexInRange(0)); - transformedMat.getDataCharacteristics().setCols(transformedFedMapping.getMaxIndexInRange(1)); - // set the federated mapping for the matrix + transformedMat.getDataCharacteristics().setDimension( + transformedFedMapping.getMaxIndexInRange(0), transformedFedMapping.getMaxIndexInRange(1)); transformedMat.setFedMapping(transformedFedMapping); - - // release input and outputs - ec.setFrameOutput(getOutput(1).getName(), - globalEncoder.getMetaData(new FrameBlock((int) fin.getNumColumns(), Types.ValueType.STRING))); } - public static class CreateFrameEncoder extends FederatedUDF { private static final long serialVersionUID = 2376756757742169692L; private final String _spec; @@ -179,7 +201,7 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE fo.release(); // create federated response - return new FederatedResponse(ResponseType.SUCCESS, encoder); + return new FederatedResponse(ResponseType.SUCCESS, new Object[] {encoder, fb.getColumnNames()}); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java index 47f912d..204019f 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java @@ -23,17 +23,20 @@ import java.util.Arrays; import java.util.HashMap; import java.util.LinkedHashMap; +import java.util.List; import org.apache.sysds.common.Types; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.lops.Lop; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.CacheableData; import org.apache.sysds.runtime.controlprogram.caching.FrameObject; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; +import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse.ResponseType; import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; import org.apache.sysds.runtime.controlprogram.federated.FederationMap; import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; @@ -51,6 +54,10 @@ import org.apache.sysds.runtime.meta.MetaDataFormat; import org.apache.sysds.runtime.privacy.PrivacyMonitor; import org.apache.sysds.runtime.transform.decode.Decoder; import org.apache.sysds.runtime.transform.decode.DecoderFactory; +import org.apache.sysds.runtime.transform.encode.Encoder; +import org.apache.sysds.runtime.transform.encode.EncoderComposite; +import org.apache.sysds.runtime.transform.encode.EncoderFactory; +import org.apache.sysds.runtime.transform.encode.EncoderOmit; public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstruction { protected final LinkedHashMap<String, String> params; @@ -113,7 +120,7 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio if(opcode.equalsIgnoreCase("replace")) { // similar to unary federated instructions, get federated 
input // execute instruction, and derive federated output matrix - MatrixObject mo = getTarget(ec); + MatrixObject mo = (MatrixObject) getTarget(ec); FederatedRequest fr1 = FederationUtils.callInstruction(instString, output, new CPOperand[] {getTargetOperand()}, new long[] {mo.getFedMapping().getID()}); mo.getFedMapping().execute(getTID(), true, fr1); @@ -125,22 +132,24 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio } else if(opcode.equalsIgnoreCase("transformdecode")) transformDecode(ec); + else if(opcode.equalsIgnoreCase("transformapply")) + transformApply(ec); else { throw new DMLRuntimeException("Unknown opcode : " + opcode); } } - + private void transformDecode(ExecutionContext ec) { // acquire locks MatrixObject mo = ec.getMatrixObject(params.get("target")); FrameBlock meta = ec.getFrameInput(params.get("meta")); String spec = params.get("spec"); - + Decoder globalDecoder = DecoderFactory - .createDecoder(spec, meta.getColumnNames(), null, meta, (int) mo.getNumColumns()); - + .createDecoder(spec, meta.getColumnNames(), null, meta, (int) mo.getNumColumns()); + FederationMap fedMapping = mo.getFedMapping(); - + ValueType[] schema = new ValueType[(int) mo.getNumColumns()]; long varID = FederationUtils.getNextFedDataID(); FederationMap decodedMapping = fedMapping.mapParallel(varID, (range, data) -> { @@ -153,22 +162,21 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio // get the decoder segment that is relevant for this federated worker Decoder decoder = globalDecoder - .subRangeDecoder((int) beginDims[1] + 1, (int) endDims[1] + 1, colStartBefore); + .subRangeDecoder((int) beginDims[1] + 1, (int) endDims[1] + 1, colStartBefore); FrameBlock metaSlice = new FrameBlock(); synchronized(meta) { meta.slice(0, meta.getNumRows() - 1, (int) beginDims[1], (int) endDims[1] - 1, metaSlice); } - - + FederatedResponse response; try { response = data.executeFederatedOperation( - new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, varID, + new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, new DecodeMatrix(data.getVarID(), varID, metaSlice, decoder))).get(); if(!response.isSuccessful()) response.throwExceptionFromResponse(); - + ValueType[] subSchema = (ValueType[]) response.getData()[0]; synchronized(schema) { // It would be possible to assert that different federated workers don't give different value @@ -181,7 +189,7 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio } return null; }); - + // construct a federated matrix with the encoded data FrameObject decodedFrame = ec.getFrameObject(output); decodedFrame.setSchema(globalDecoder.getSchema()); @@ -189,19 +197,94 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio decodedFrame.getDataCharacteristics().setCols(globalDecoder.getSchema().length); // set the federated mapping for the matrix decodedFrame.setFedMapping(decodedMapping); - + // release locks ec.releaseFrameInput(params.get("meta")); } - public MatrixObject getTarget(ExecutionContext ec) { - return ec.getMatrixObject(params.get("target")); + private void transformApply(ExecutionContext ec) { + // acquire locks + FrameObject fo = ec.getFrameObject(params.get("target")); + FrameBlock meta = ec.getFrameInput(params.get("meta")); + String spec = params.get("spec"); + + FederationMap fedMapping = fo.getFedMapping(); + + // get column names for the EncoderFactory + String[] colNames = new String[(int) fo.getNumColumns()]; + 
Arrays.fill(colNames, ""); + + fedMapping.forEachParallel((range, data) -> { + try { + FederatedResponse response = data + .executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, + new GetColumnNames(data.getVarID()))).get(); + + // no synchronization necessary since names should anyway match + String[] subRangeColNames = (String[]) response.getData()[0]; + System.arraycopy(subRangeColNames, 0, colNames, (int) range.getBeginDims()[1], subRangeColNames.length); + } + catch(Exception e) { + throw new DMLRuntimeException(e); + } + return null; + }); + + Encoder globalEncoder = EncoderFactory.createEncoder(spec, colNames, colNames.length, meta); + + // check if EncoderOmit exists + List<Encoder> encoders = ((EncoderComposite) globalEncoder).getEncoders(); + int omitIx = -1; + for(int i = 0; i < encoders.size(); i++) { + if(encoders.get(i) instanceof EncoderOmit) { + omitIx = i; + break; + } + } + if(omitIx != -1) { + // extra step, build the omit encoder: we need information about all the rows to omit, if our federated + // ranges are split up row-wise we need to build the encoder separately and combine it + buildOmitEncoder(fedMapping, encoders, omitIx); + } + + MultiReturnParameterizedBuiltinFEDInstruction + .encodeFederatedFrames(fedMapping, globalEncoder, ec.getMatrixObject(getOutputVariableName())); + + // release locks + ec.releaseFrameInput(params.get("meta")); + } + + private static void buildOmitEncoder(FederationMap fedMapping, List<Encoder> encoders, int omitIx) { + Encoder omitEncoder = encoders.get(omitIx); + EncoderOmit newOmit = new EncoderOmit(true); + fedMapping.forEachParallel((range, data) -> { + try { + EncoderOmit subRangeEncoder = (EncoderOmit) omitEncoder.subRangeEncoder(range.asIndexRange().add(1)); + FederatedResponse response = data + .executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, + new InitRowsToRemoveOmit(data.getVarID(), subRangeEncoder))).get(); + + // no synchronization necessary since names should anyway match + Encoder builtEncoder = (Encoder) response.getData()[0]; + newOmit.mergeAt(builtEncoder, (int) (range.getBeginDims()[1] + 1)); + } + catch(Exception e) { + throw new DMLRuntimeException(e); + } + return null; + }); + encoders.remove(omitIx); + encoders.add(omitIx, newOmit); + } + + public CacheableData<?> getTarget(ExecutionContext ec) { + return ec.getCacheableData(params.get("target")); } private CPOperand getTargetOperand() { return new CPOperand(params.get("target"), ValueType.FP64, DataType.MATRIX); } - + public static class DecodeMatrix extends FederatedUDF { private static final long serialVersionUID = 2376756757742169692L; private final long _outputID; @@ -235,7 +318,42 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio // add it to the list of variables ec.setVariable(String.valueOf(_outputID), fo); // return schema - return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new Object[] {fo.getSchema()}); + return new FederatedResponse(ResponseType.SUCCESS, new Object[] {fo.getSchema()}); + } + } + + private static class GetColumnNames extends FederatedUDF { + private static final long serialVersionUID = -7831469841164270004L; + + public GetColumnNames(long varID) { + super(new long[] {varID}); + } + + @Override + public FederatedResponse execute(ExecutionContext ec, Data... 
data) { + FrameObject fo = (FrameObject) PrivacyMonitor.handlePrivacy(data[0]); + FrameBlock fb = fo.acquireReadAndRelease(); + // return column names + return new FederatedResponse(ResponseType.SUCCESS, new Object[] {fb.getColumnNames()}); + } + } + + private static class InitRowsToRemoveOmit extends FederatedUDF { + private static final long serialVersionUID = -8196730717390438411L; + + EncoderOmit _encoder; + + public InitRowsToRemoveOmit(long varID, EncoderOmit encoder) { + super(new long[] {varID}); + _encoder = encoder; + } + + @Override + public FederatedResponse execute(ExecutionContext ec, Data... data) { + FrameObject fo = (FrameObject) PrivacyMonitor.handlePrivacy(data[0]); + FrameBlock fb = fo.acquireReadAndRelease(); + _encoder.build(fb); + return new FederatedResponse(ResponseType.SUCCESS, new Object[] {_encoder}); } } } diff --git a/src/main/java/org/apache/sysds/runtime/io/FileFormatPropertiesCSV.java b/src/main/java/org/apache/sysds/runtime/io/FileFormatPropertiesCSV.java index 7049918..7b20e38 100644 --- a/src/main/java/org/apache/sysds/runtime/io/FileFormatPropertiesCSV.java +++ b/src/main/java/org/apache/sysds/runtime/io/FileFormatPropertiesCSV.java @@ -54,6 +54,7 @@ public class FileFormatPropertiesCSV extends FileFormatProperties implements Ser } public FileFormatPropertiesCSV(boolean hasHeader, String delim, boolean fill, double fillValue, String naStrings) { + this(); this.header = hasHeader; this.delim = delim; this.fill = fill; @@ -68,6 +69,7 @@ public class FileFormatPropertiesCSV extends FileFormatProperties implements Ser } public FileFormatPropertiesCSV(boolean hasHeader, String delim, boolean sparse) { + this(); this.header = hasHeader; this.delim = delim; this.sparse = sparse; @@ -88,7 +90,11 @@ public class FileFormatPropertiesCSV extends FileFormatProperties implements Ser public String getDelim() { return delim; } - + + public void setNAStrings(HashSet<String> naStrings) { + this.naStrings = naStrings; + } + public HashSet<String> getNAStrings() { return naStrings; } diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java index 19271f8..7f47192 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java @@ -20,17 +20,20 @@ package org.apache.sysds.runtime.transform.encode; import java.io.Serializable; +import java.util.ArrayList; import java.util.Arrays; - import java.util.HashSet; +import java.util.List; import java.util.Set; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.wink.json4j.JSONArray; import org.apache.sysds.runtime.matrix.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.IndexRange; import org.apache.sysds.runtime.util.UtilFunctions; +import org.apache.wink.json4j.JSONArray; /** * Base class for all transform encoders providing both a row and block @@ -125,14 +128,26 @@ public abstract class Encoder implements Serializable */ public abstract MatrixBlock apply(FrameBlock in, MatrixBlock out); + protected int[] subRangeColList(IndexRange ixRange) { + List<Integer> cols = new ArrayList<>(); + for(int col : _colList) { + if(ixRange.inColRange(col)) { + // add the correct column, removed columns before start + // colStart - 1 because colStart is 1-based + int corrColumn = (int) (col - 
(ixRange.colStart - 1)); + cols.add(corrColumn); + } + } + return cols.stream().mapToInt(i -> i).toArray(); + } + /** * Returns a new Encoder that only handles a sub range of columns. * - * @param colStart the start index of the sub-range (1-based, inclusive) - * @param colEnd the end index of the sub-range (1-based, exclusive) + * @param ixRange the range (1-based, begin inclusive, end exclusive) * @return an encoder of the same type, just for the sub-range */ - public Encoder subRangeEncoder(int colStart, int colEnd) { + public Encoder subRangeEncoder(IndexRange ixRange) { throw new DMLRuntimeException( this.getClass().getSimpleName() + " does not support the creation of a sub-range encoder"); } @@ -166,7 +181,7 @@ public abstract class Encoder implements Serializable */ public void mergeAt(Encoder other, int col) { throw new DMLRuntimeException( - this.getClass().getName() + " does not support merging with " + other.getClass().getName()); + this.getClass().getSimpleName() + " does not support merging with " + other.getClass().getSimpleName()); } /** diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java index 3be9ed9..351f68d 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java @@ -20,19 +20,24 @@ package org.apache.sysds.runtime.transform.encode; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.apache.commons.lang.ArrayUtils; -import org.apache.wink.json4j.JSONArray; -import org.apache.wink.json4j.JSONException; -import org.apache.wink.json4j.JSONObject; +import org.apache.commons.lang3.tuple.MutableTriple; import org.apache.sysds.lops.Lop; import org.apache.sysds.runtime.matrix.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.transform.TfUtils.TfMethod; import org.apache.sysds.runtime.transform.meta.TfMetaUtils; +import org.apache.sysds.runtime.util.IndexRange; import org.apache.sysds.runtime.util.UtilFunctions; +import org.apache.wink.json4j.JSONArray; +import org.apache.wink.json4j.JSONException; +import org.apache.wink.json4j.JSONObject; public class EncoderBin extends Encoder { @@ -49,7 +54,7 @@ public class EncoderBin extends Encoder private double[][] _binMins = null; private double[][] _binMaxs = null; - public EncoderBin(JSONObject parsedSpec, String[] colnames, int clen) + public EncoderBin(JSONObject parsedSpec, String[] colnames, int clen, int minCol, int maxCol) throws JSONException, IOException { super( null, clen ); @@ -57,22 +62,35 @@ public class EncoderBin extends Encoder return; //parse column names or column ids - List<Integer> collist = TfMetaUtils.parseBinningColIDs(parsedSpec, colnames); + List<Integer> collist = TfMetaUtils.parseBinningColIDs(parsedSpec, colnames, minCol, maxCol); initColList(ArrayUtils.toPrimitive(collist.toArray(new Integer[0]))); //parse number of bins per column boolean ids = parsedSpec.containsKey("ids") && parsedSpec.getBoolean("ids"); JSONArray group = (JSONArray) parsedSpec.get(TfMethod.BIN.toString()); _numBins = new int[collist.size()]; - for(int i=0; i < _numBins.length; i++) { - JSONObject colspec = (JSONObject) group.get(i); - int pos = collist.indexOf(ids ? 
colspec.getInt("id") : - ArrayUtils.indexOf(colnames, colspec.get("name"))+1); - _numBins[pos] = colspec.containsKey("numbins") ? - colspec.getInt("numbins"): 1; + for (Object o : group) { + JSONObject colspec = (JSONObject) o; + int ixOffset = minCol == -1 ? 0 : minCol - 1; + int pos = collist.indexOf(ids ? colspec.getInt("id") - ixOffset : + ArrayUtils.indexOf(colnames, colspec.get("name")) + 1); + if(pos >= 0) + _numBins[pos] = colspec.containsKey("numbins") ? colspec.getInt("numbins") : 1; } } + public EncoderBin() { + super(new int[0], 0); + _numBins = new int[0]; + } + + private EncoderBin(int[] colList, int clen, int[] numBins, double[][] binMins, double[][] binMaxs) { + super(colList, clen); + _numBins = numBins; + _binMins = binMins; + _binMaxs = binMaxs; + } + @Override public MatrixBlock encode(FrameBlock in, MatrixBlock out) { build(in); @@ -121,7 +139,87 @@ public class EncoderBin extends Encoder } return out; } + + @Override + public Encoder subRangeEncoder(IndexRange ixRange) { + List<Integer> colsList = new ArrayList<>(); + List<Integer> numBinsList = new ArrayList<>(); + List<double[]> binMinsList = new ArrayList<>(); + List<double[]> binMaxsList = new ArrayList<>(); + for(int i = 0; i < _colList.length; i++) { + int col = _colList[i]; + if(col >= ixRange.colStart && col < ixRange.colEnd) { + // add the correct column, removed columns before start + // colStart - 1 because colStart is 1-based + int corrColumn = (int) (col - (ixRange.colStart - 1)); + colsList.add(corrColumn); + numBinsList.add(_numBins[i]); + binMinsList.add(_binMins[i]); + binMaxsList.add(_binMaxs[i]); + } + } + if(colsList.isEmpty()) + // empty encoder -> sub range encoder does not exist + return null; + int[] colList = colsList.stream().mapToInt(i -> i).toArray(); + return new EncoderBin(colList, (int) (ixRange.colEnd - ixRange.colStart), + numBinsList.stream().mapToInt((i) -> i).toArray(), binMinsList.toArray(new double[0][0]), + binMaxsList.toArray(new double[0][0])); + } + + @Override + public void mergeAt(Encoder other, int col) { + if(other instanceof EncoderBin) { + EncoderBin otherBin = (EncoderBin) other; + + // save the min, max as well as the number of bins for the column indexes + Map<Integer, MutableTriple<Integer, Double, Double>> ixBinsMap = new HashMap<>(); + for(int i = 0; i < _colList.length; i++) { + ixBinsMap.put(_colList[i], + new MutableTriple<>(_numBins[i], _binMins[i][0], _binMaxs[i][_binMaxs[i].length - 1])); + } + for(int i = 0; i < otherBin._colList.length; i++) { + int column = otherBin._colList[i] + (col - 1); + MutableTriple<Integer, Double, Double> entry = ixBinsMap.get(column); + if(entry == null) { + ixBinsMap.put(column, + new MutableTriple<>(otherBin._numBins[i], otherBin._binMins[i][0], + otherBin._binMaxs[i][otherBin._binMaxs[i].length - 1])); + } + else { + // num bins will match + entry.middle = Math.min(entry.middle, otherBin._binMins[i][0]); + entry.right = Math.max(entry.right, otherBin._binMaxs[i][otherBin._binMaxs[i].length - 1]); + } + } + + mergeColumnInfo(other, col); + + // use the saved values to fill the arrays again + _numBins = new int[_colList.length]; + _binMins = new double[_colList.length][]; + _binMaxs = new double[_colList.length][]; + + for(int i = 0; i < _colList.length; i++) { + int column = _colList[i]; + MutableTriple<Integer, Double, Double> entry = ixBinsMap.get(column); + _numBins[i] = entry.left; + + double min = entry.middle; + double max = entry.right; + _binMins[i] = new double[_numBins[i]]; + _binMaxs[i] = new double[_numBins[i]]; + 
for(int j = 0; j < _numBins[i]; j++) { + _binMins[i][j] = min + j * (max - min) / _numBins[i]; + _binMaxs[i][j] = min + (j + 1) * (max - min) / _numBins[i]; + } + } + return; + } + super.mergeAt(other, col); + } + @Override public FrameBlock getMetaData(FrameBlock meta) { //allocate frame if necessary diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java index cd21f45..c494676 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java @@ -27,6 +27,7 @@ import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.matrix.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.IndexRange; /** * Simple composite encoder that applies a list of encoders @@ -104,10 +105,10 @@ public class EncoderComposite extends Encoder } @Override - public Encoder subRangeEncoder(int colStart, int colEnd) { + public Encoder subRangeEncoder(IndexRange ixRange) { List<Encoder> subRangeEncoders = new ArrayList<>(); for (Encoder encoder : _encoders) { - Encoder subEncoder = encoder.subRangeEncoder(colStart, colEnd); + Encoder subEncoder = encoder.subRangeEncoder(ixRange); if (subEncoder != null) { subRangeEncoders.add(subEncoder); } @@ -131,7 +132,7 @@ public class EncoderComposite extends Encoder } if(!mergedIn) { throw new DMLRuntimeException("Tried to merge in encoder of class that is not present in " - + "CompositeEncoder: " + otherEnc.getClass().getSimpleName()); + + "EncoderComposite: " + otherEnc.getClass().getSimpleName()); } } // update dummycode encoder domain sizes based on distinctness information from other encoders @@ -147,8 +148,11 @@ public class EncoderComposite extends Encoder if (encoder.getClass() == other.getClass()) { encoder.mergeAt(other, col); // update dummycode encoder domain sizes based on distinctness information from other encoders - if (encoder instanceof EncoderDummycode) { - ((EncoderDummycode) encoder).updateDomainSizes(_encoders); + for (Encoder encDummy : _encoders) { + if (encDummy instanceof EncoderDummycode) { + ((EncoderDummycode) encDummy).updateDomainSizes(_encoders); + return; + } } return; } diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java index 8ff5e57..19d41ea 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java @@ -29,6 +29,7 @@ import org.apache.sysds.runtime.matrix.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.transform.TfUtils.TfMethod; import org.apache.sysds.runtime.transform.meta.TfMetaUtils; +import org.apache.sysds.runtime.util.IndexRange; import org.apache.wink.json4j.JSONException; import org.apache.wink.json4j.JSONObject; @@ -54,6 +55,12 @@ public class EncoderDummycode extends Encoder super(new int[0], 0); } + public EncoderDummycode(int[] colList, int clen, int[] domainSizes, long dummycodedLength) { + super(colList, clen); + _domainSizes = domainSizes; + _dummycodedLength = dummycodedLength; + } + @Override public int getNumCols() { return (int)_dummycodedLength; @@ -97,16 +104,16 @@ public class 
EncoderDummycode extends Encoder } @Override - public Encoder subRangeEncoder(int colStart, int colEnd) { + public Encoder subRangeEncoder(IndexRange ixRange) { List<Integer> cols = new ArrayList<>(); List<Integer> domainSizes = new ArrayList<>(); - int newDummycodedLength = colEnd - colStart; - for(int i = 0; i < _colList.length; i++){ + int newDummycodedLength = (int) ixRange.colSpan(); + for(int i = 0; i < _colList.length; i++) { int col = _colList[i]; - if(col >= colStart && col < colEnd) { + if(ixRange.inColRange(col)) { // add the correct column, removed columns before start // colStart - 1 because colStart is 1-based - int corrColumn = col - (colStart - 1); + int corrColumn = (int) (col - (ixRange.colStart - 1)); cols.add(corrColumn); domainSizes.add(_domainSizes[i]); newDummycodedLength += _domainSizes[i] - 1; @@ -116,12 +123,8 @@ public class EncoderDummycode extends Encoder // empty encoder -> sub range encoder does not exist return null; - EncoderDummycode subRangeEncoder = new EncoderDummycode(); - subRangeEncoder._clen = colEnd - colStart; - subRangeEncoder._colList = cols.stream().mapToInt(i -> i).toArray(); - subRangeEncoder._domainSizes = domainSizes.stream().mapToInt(i -> i).toArray(); - subRangeEncoder._dummycodedLength = newDummycodedLength; - return subRangeEncoder; + return new EncoderDummycode(cols.stream().mapToInt(i -> i).toArray(), (int) ixRange.colSpan(), + domainSizes.stream().mapToInt(i -> i).toArray(), newDummycodedLength); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java index 57f7102..313e5b2 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java @@ -73,7 +73,7 @@ public class EncoderFactory TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.HASH.toString(), minCol, maxCol))); List<Integer> dcIDs = Arrays.asList(ArrayUtils.toObject( TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.DUMMYCODE.toString(), minCol, maxCol))); - List<Integer> binIDs = TfMetaUtils.parseBinningColIDs(jSpec, colnames); + List<Integer> binIDs = TfMetaUtils.parseBinningColIDs(jSpec, colnames, minCol, maxCol); //note: any dummycode column requires recode as preparation, unless it follows binning rcIDs = except(unionDistinct(rcIDs, except(dcIDs, binIDs)), haIDs); List<Integer> ptIDs = except(except(UtilFunctions.getSeqList(1, clen, 1), @@ -81,7 +81,7 @@ public class EncoderFactory List<Integer> oIDs = Arrays.asList(ArrayUtils.toObject( TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.OMIT.toString(), minCol, maxCol))); List<Integer> mvIDs = Arrays.asList(ArrayUtils.toObject( - TfMetaUtils.parseJsonObjectIDList(jSpec, colnames, TfMethod.IMPUTE.toString()))); + TfMetaUtils.parseJsonObjectIDList(jSpec, colnames, TfMethod.IMPUTE.toString(), minCol, maxCol))); //create individual encoders if( !rcIDs.isEmpty() ) { @@ -90,7 +90,7 @@ public class EncoderFactory lencoders.add(ra); } if( !haIDs.isEmpty() ) { - EncoderFeatureHash ha = new EncoderFeatureHash(jSpec, colnames, clen); + EncoderFeatureHash ha = new EncoderFeatureHash(jSpec, colnames, clen, minCol, maxCol); ha.setColList(ArrayUtils.toPrimitive(haIDs.toArray(new Integer[0]))); lencoders.add(ha); } @@ -98,11 +98,11 @@ public class EncoderFactory lencoders.add(new EncoderPassThrough( ArrayUtils.toPrimitive(ptIDs.toArray(new Integer[0])), clen)); if( !binIDs.isEmpty() ) - 
lencoders.add(new EncoderBin(jSpec, colnames, schema.length)); + lencoders.add(new EncoderBin(jSpec, colnames, schema.length, minCol, maxCol)); if( !dcIDs.isEmpty() ) lencoders.add(new EncoderDummycode(jSpec, colnames, schema.length, minCol, maxCol)); if( !oIDs.isEmpty() ) - lencoders.add(new EncoderOmit(jSpec, colnames, schema.length)); + lencoders.add(new EncoderOmit(jSpec, colnames, schema.length, minCol, maxCol)); if( !mvIDs.isEmpty() ) { EncoderMVImpute ma = new EncoderMVImpute(jSpec, colnames, schema.length); ma.initRecodeIDList(rcIDs); diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFeatureHash.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFeatureHash.java index 85c408b..9317dfb 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFeatureHash.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFeatureHash.java @@ -19,6 +19,7 @@ package org.apache.sysds.runtime.transform.encode; +import org.apache.sysds.runtime.util.IndexRange; import org.apache.wink.json4j.JSONException; import org.apache.wink.json4j.JSONObject; import org.apache.sysds.runtime.matrix.data.FrameBlock; @@ -35,10 +36,21 @@ public class EncoderFeatureHash extends Encoder private static final long serialVersionUID = 7435806042138687342L; private long _K; - public EncoderFeatureHash(JSONObject parsedSpec, String[] colnames, int clen) throws JSONException { + public EncoderFeatureHash(JSONObject parsedSpec, String[] colnames, int clen, int minCol, int maxCol) + throws JSONException { super(null, clen); - _colList = TfMetaUtils.parseJsonIDList(parsedSpec, colnames, TfMethod.HASH.toString()); - _K = getK(parsedSpec); + _colList = TfMetaUtils.parseJsonIDList(parsedSpec, colnames, TfMethod.HASH.toString(), minCol, maxCol); + _K = getK(parsedSpec); + } + + public EncoderFeatureHash(int[] colList, int clen, long K) { + super(colList, clen); + _K = K; + } + + public EncoderFeatureHash() { + super(new int[0], 0); + _K = 0; } /** @@ -89,6 +101,26 @@ public class EncoderFeatureHash extends Encoder } @Override + public Encoder subRangeEncoder(IndexRange ixRange) { + int[] colList = subRangeColList(ixRange); + if(colList.length == 0) + // empty encoder -> sub range encoder does not exist + return null; + return new EncoderFeatureHash(colList, (int) ixRange.colSpan(), _K); + } + + @Override + public void mergeAt(Encoder other, int col) { + if(other instanceof EncoderFeatureHash) { + mergeColumnInfo(other, col); + if (((EncoderFeatureHash) other)._K != 0 && _K == 0) + _K = ((EncoderFeatureHash) other)._K; + return; + } + super.mergeAt(other, col); + } + + @Override public FrameBlock getMetaData(FrameBlock meta) { if( !isApplicable() ) return meta; diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java index deba22f..56749a2 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java @@ -82,7 +82,7 @@ public class EncoderMVImpute extends Encoder super(null, clen); //handle column list - int[] collist = TfMetaUtils.parseJsonObjectIDList(parsedSpec, colnames, TfMethod.IMPUTE.toString()); + int[] collist = TfMetaUtils.parseJsonObjectIDList(parsedSpec, colnames, TfMethod.IMPUTE.toString(), -1, -1); initColList(collist); //handle method list diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderOmit.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderOmit.java index 283c196..26ba4e4 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderOmit.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderOmit.java @@ -19,33 +19,57 @@ package org.apache.sysds.runtime.transform.encode; -import org.apache.wink.json4j.JSONException; -import org.apache.wink.json4j.JSONObject; +import java.util.TreeSet; +import java.util.stream.Collectors; + +import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.matrix.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.transform.TfUtils; import org.apache.sysds.runtime.transform.TfUtils.TfMethod; import org.apache.sysds.runtime.transform.meta.TfMetaUtils; +import org.apache.sysds.runtime.util.IndexRange; import org.apache.sysds.runtime.util.UtilFunctions; +import org.apache.wink.json4j.JSONException; +import org.apache.wink.json4j.JSONObject; public class EncoderOmit extends Encoder { private static final long serialVersionUID = 1978852120416654195L; - private int _rmRows = 0; + private boolean _federated = false; + //TODO perf replace with boolean[rlen] similar to removeEmpty + private TreeSet<Integer> _rmRows = new TreeSet<>(); - public EncoderOmit(JSONObject parsedSpec, String[] colnames, int clen) + public EncoderOmit(JSONObject parsedSpec, String[] colnames, int clen, int minCol, int maxCol) throws JSONException { super(null, clen); if (!parsedSpec.containsKey(TfMethod.OMIT.toString())) return; - int[] collist = TfMetaUtils.parseJsonIDList(parsedSpec, colnames, TfMethod.OMIT.toString()); + int[] collist = TfMetaUtils.parseJsonIDList(parsedSpec, colnames, TfMethod.OMIT.toString(), minCol, maxCol); initColList(collist); + _federated = minCol != -1 || maxCol != -1; + } + + public EncoderOmit() { + super(new int[0], 0); + } + + public EncoderOmit(boolean federated) { + this(); + _federated = federated; + } + + + private EncoderOmit(int[] colList, int clen, TreeSet<Integer> rmRows) { + super(colList, clen); + _rmRows = rmRows; + _federated = true; } public int getNumRemovedRows() { - return _rmRows; + return _rmRows.size(); } public boolean omit(String[] words, TfUtils agents) @@ -67,45 +91,97 @@ public class EncoderOmit extends Encoder } @Override - public void build(FrameBlock in) { - //do nothing + public void build(FrameBlock in) { + if(_federated) + _rmRows = computeRmRows(in); } - + @Override - public MatrixBlock apply(FrameBlock in, MatrixBlock out) - { - //determine output size - int numRows = 0; - for(int i=0; i<out.getNumRows(); i++) { - boolean valid = true; - for(int j=0; j<_colList.length; j++) - valid &= !Double.isNaN(out.quickGetValue(i, _colList[j]-1)); - numRows += valid ? 
1 : 0; - } - - //copy over valid rows into the output + public MatrixBlock apply(FrameBlock in, MatrixBlock out) { + // local rmRows for broadcasting encoder in spark + TreeSet<Integer> rmRows; + if(_federated) + rmRows = _rmRows; + else + rmRows = computeRmRows(in); + + // determine output size + int numRows = out.getNumRows() - rmRows.size(); + + // copy over valid rows into the output MatrixBlock ret = new MatrixBlock(numRows, out.getNumColumns(), false); int pos = 0; - for(int i=0; i<in.getNumRows(); i++) { - //determine if valid row or omit - boolean valid = true; - for(int j=0; j<_colList.length; j++) - valid &= !Double.isNaN(out.quickGetValue(i, _colList[j]-1)); - //copy row if necessary - if( valid ) { - for(int j=0; j<out.getNumColumns(); j++) + for(int i = 0; i < in.getNumRows(); i++) { + // copy row if necessary + if(!rmRows.contains(i)) { + for(int j = 0; j < out.getNumColumns(); j++) ret.quickSetValue(pos, j, out.quickGetValue(i, j)); pos++; } } - - //keep info an remove rows - _rmRows = out.getNumRows() - pos; - - return ret; + + _rmRows = rmRows; + + return ret; + } + + private TreeSet<Integer> computeRmRows(FrameBlock in) { + TreeSet<Integer> rmRows = new TreeSet<>(); + ValueType[] schema = in.getSchema(); + for(int i = 0; i < in.getNumRows(); i++) { + boolean valid = true; + for(int colID : _colList) { + Object val = in.get(i, colID - 1); + valid &= !(val == null || (schema[colID - 1] == ValueType.STRING && val.toString().isEmpty())); + } + if(!valid) + rmRows.add(i); + } + return rmRows; } @Override + public Encoder subRangeEncoder(IndexRange ixRange) { + int[] colList = subRangeColList(ixRange); + if(colList.length == 0) + // empty encoder -> sub range encoder does not exist + return null; + + TreeSet<Integer> rmRows = _rmRows.stream().filter((row) -> ixRange.inRowRange(row + 1)) + .map((row) -> (int) (row - (ixRange.rowStart - 1))).collect(Collectors.toCollection(TreeSet::new)); + + return new EncoderOmit(colList, (int) (ixRange.colSpan()), rmRows); + } + + @Override + public void mergeAt(Encoder other, int col) { + if(other instanceof EncoderOmit) { + mergeColumnInfo(other, col); + _rmRows.addAll(((EncoderOmit) other)._rmRows); + return; + } + super.mergeAt(other, col); + } + + @Override + public void updateIndexRanges(long[] beginDims, long[] endDims) { + // first update begin dims + int numRowsToRemove = 0; + Integer removedRow = _rmRows.ceiling(0); + while(removedRow != null && removedRow < beginDims[0]) { + numRowsToRemove++; + removedRow = _rmRows.ceiling(removedRow + 1); + } + beginDims[0] -= numRowsToRemove; + // update end dims + while(removedRow != null && removedRow < endDims[0]) { + numRowsToRemove++; + removedRow = _rmRows.ceiling(removedRow + 1); + } + endDims[0] -= numRowsToRemove; + } + + @Override public FrameBlock getMetaData(FrameBlock out) { //do nothing return out; @@ -116,4 +192,3 @@ public class EncoderOmit extends Encoder //do nothing } } - \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java index d6ceb15..ccd235d 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java @@ -25,6 +25,7 @@ import java.util.List; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.matrix.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import 
org.apache.sysds.runtime.util.IndexRange; import org.apache.sysds.runtime.util.UtilFunctions; /** @@ -72,17 +73,18 @@ public class EncoderPassThrough extends Encoder } @Override - public Encoder subRangeEncoder(int colStart, int colEnd) { + public Encoder subRangeEncoder(IndexRange ixRange) { List<Integer> colList = new ArrayList<>(); - for (int col : _colList) { - if (col >= colStart && col < colEnd) + for(int col : _colList) { + if(col >= ixRange.colStart && col < ixRange.colEnd) // add the correct column, removed columns before start - colList.add(col - (colStart - 1)); + colList.add((int) (col - (ixRange.colStart - 1))); } - if (colList.isEmpty()) + if(colList.isEmpty()) // empty encoder -> return null return null; - return new EncoderPassThrough(colList.stream().mapToInt(i -> i).toArray(), colEnd - colStart); + return new EncoderPassThrough(colList.stream().mapToInt(i -> i).toArray(), + (int) (ixRange.colEnd - ixRange.colStart)); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java index be1cba9..e195835 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java @@ -27,6 +27,7 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import org.apache.sysds.runtime.util.IndexRange; import org.apache.wink.json4j.JSONException; import org.apache.wink.json4j.JSONObject; import org.apache.sysds.lops.Lop; @@ -164,25 +165,25 @@ public class EncoderRecode extends Encoder } @Override - public Encoder subRangeEncoder(int colStart, int colEnd) { + public Encoder subRangeEncoder(IndexRange ixRange) { List<Integer> cols = new ArrayList<>(); HashMap<Integer, HashMap<String, Long>> rcdMaps = new HashMap<>(); - for (int col : _colList) { - if (col >= colStart && col < colEnd) { + for(int col : _colList) { + if(ixRange.inColRange(col)) { // add the correct column, removed columns before start // colStart - 1 because colStart is 1-based - int corrColumn = col - (colStart - 1); + int corrColumn = (int) (col - (ixRange.colStart - 1)); cols.add(corrColumn); // copy rcdMap for column rcdMaps.put(corrColumn, new HashMap<>(_rcdMaps.get(col))); } } - if (cols.isEmpty()) + if(cols.isEmpty()) // empty encoder -> sub range encoder does not exist return null; - + int[] colList = cols.stream().mapToInt(i -> i).toArray(); - return new EncoderRecode(colList, colEnd - colStart, rcdMaps); + return new EncoderRecode(colList, (int) ixRange.colSpan(), rcdMaps); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java b/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java index 72fab7a..3f1a37b 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java +++ b/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java @@ -133,16 +133,11 @@ public class TfMetaUtils else { ix = ArrayUtils.indexOf(colnames, attrs.get(i)) + 1; } - if(ix <= 0) { - if (minCol == -1 && maxCol == -1) { - // only if we remove some columns, ix -1 is expected - throw new RuntimeException("Specified column '" - + attrs.get(i)+"' does not exist."); - } - else // ignore column - continue; - } - colList.add(ix); + if(ix > 0) + colList.add(ix); + else if(minCol == -1 && maxCol == -1) + // only if we remove some columns, ix -1 is expected + throw new RuntimeException("Specified column '" + attrs.get(i) + "' does not 
exist."); } //ensure ascending order of column IDs @@ -152,33 +147,41 @@ public class TfMetaUtils return arr; } - public static int[] parseJsonObjectIDList(JSONObject spec, String[] colnames, String group) - throws JSONException - { - int[] colList = new int[0]; + public static int[] parseJsonObjectIDList(JSONObject spec, String[] colnames, String group, int minCol, int maxCol) + throws JSONException { + List<Integer> colList = new ArrayList<>(); + int[] arr = new int[0]; boolean ids = spec.containsKey("ids") && spec.getBoolean("ids"); - - if( spec.containsKey(group) && spec.get(group) instanceof JSONArray ) - { - JSONArray colspecs = (JSONArray)spec.get(group); - colList = new int[colspecs.size()]; - for(int j=0; j<colspecs.size(); j++) { - JSONObject colspec = (JSONObject) colspecs.get(j); - colList[j] = ids ? colspec.getInt("id") : - (ArrayUtils.indexOf(colnames, colspec.get("name")) + 1); - if( colList[j] <= 0 ) { - throw new RuntimeException("Specified column '" + - colspec.get(ids?"id":"name")+"' does not exist."); + + if(spec.containsKey(group) && spec.get(group) instanceof JSONArray) { + JSONArray colspecs = (JSONArray) spec.get(group); + for(Object o : colspecs) { + JSONObject colspec = (JSONObject) o; + int ix; + if(ids) { + ix = colspec.getInt("id"); + if(maxCol != -1 && ix >= maxCol) + ix = -1; + if(minCol != -1 && ix >= 0) + ix -= minCol - 1; + } + else { + ix = ArrayUtils.indexOf(colnames, colspec.get("name")) + 1; } + if(ix > 0) + colList.add(ix); + else if(minCol == -1 && maxCol == -1) + throw new RuntimeException( + "Specified column '" + colspec.get(ids ? "id" : "name") + "' does not exist."); } - - //ensure ascending order of column IDs - Arrays.sort(colList); + + // ensure ascending order of column IDs + arr = colList.stream().mapToInt((i) -> i).sorted().toArray(); } - - return colList; + + return arr; } - + /** * Reads transform meta data from an HDFS file path and converts it into an in-memory * FrameBlock object. 
@@ -227,7 +230,7 @@ public class TfMetaUtils //get list of recode ids List<Integer> recodeIDs = parseRecodeColIDs(spec, colnames); - List<Integer> binIDs = parseBinningColIDs(spec, colnames); + List<Integer> binIDs = parseBinningColIDs(spec, colnames, -1, -1); //create frame block from in-memory strings return convertToTransformMetaDataFrame(rows, colnames, recodeIDs, binIDs, meta, mvmeta); @@ -282,7 +285,7 @@ public class TfMetaUtils //get list of recode ids List<Integer> recodeIDs = parseRecodeColIDs(spec, colnames); - List<Integer> binIDs = parseBinningColIDs(spec, colnames); + List<Integer> binIDs = parseBinningColIDs(spec, colnames, -1, -1); //create frame block from in-memory strings return convertToTransformMetaDataFrame(rows, colnames, recodeIDs, binIDs, meta, mvmeta); @@ -390,26 +393,26 @@ public class TfMetaUtils return specRecodeIDs; } - public static List<Integer> parseBinningColIDs(String spec, String[] colnames) + public static List<Integer> parseBinningColIDs(String spec, String[] colnames, int minCol, int maxCol) throws IOException { try { JSONObject jSpec = new JSONObject(spec); - return parseBinningColIDs(jSpec, colnames); + return parseBinningColIDs(jSpec, colnames, minCol, maxCol); } catch(JSONException ex) { throw new IOException(ex); } } - public static List<Integer> parseBinningColIDs(JSONObject jSpec, String[] colnames) + public static List<Integer> parseBinningColIDs(JSONObject jSpec, String[] colnames, int minCol, int maxCol) throws IOException { try { String binKey = TfMethod.BIN.toString(); if( jSpec.containsKey(binKey) && jSpec.get(binKey) instanceof JSONArray ) { return Arrays.asList(ArrayUtils.toObject( - TfMetaUtils.parseJsonObjectIDList(jSpec, colnames, binKey))); + TfMetaUtils.parseJsonObjectIDList(jSpec, colnames, binKey, minCol, maxCol))); } else { //internally generates return Arrays.asList(ArrayUtils.toObject( diff --git a/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java b/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java index 8b1e42e..af7471c 100644 --- a/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java +++ b/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java @@ -31,6 +31,7 @@ import org.apache.hadoop.fs.FileUtil; import org.apache.hadoop.fs.Path; import org.apache.hadoop.fs.permission.FsPermission; import org.apache.hadoop.mapred.JobConf; +import org.apache.sysds.runtime.io.FileFormatPropertiesCSV; import org.apache.wink.json4j.JSONException; import org.apache.wink.json4j.OrderedJSONObject; import org.apache.sysds.common.Types.DataType; @@ -467,6 +468,12 @@ public class HDFSTool } else { mtd.put(DataExpression.AUTHORPARAM, "SystemDS"); } + + if (formatProperties instanceof FileFormatPropertiesCSV) { + FileFormatPropertiesCSV csvProps = (FileFormatPropertiesCSV) formatProperties; + mtd.put(DataExpression.DELIM_HAS_HEADER_ROW, csvProps.hasHeader()); + mtd.put(DataExpression.DELIM_DELIMITER, csvProps.getDelim()); + } SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss z"); mtd.put(DataExpression.CREATEDPARAM, sdf.format(new Date())); diff --git a/src/main/java/org/apache/sysds/runtime/util/IndexRange.java b/src/main/java/org/apache/sysds/runtime/util/IndexRange.java index 69ada3b..4a8d999 100644 --- a/src/main/java/org/apache/sysds/runtime/util/IndexRange.java +++ b/src/main/java/org/apache/sysds/runtime/util/IndexRange.java @@ -51,7 +51,23 @@ public class IndexRange implements Serializable rowStart + delta, rowEnd + delta, colStart + delta, colEnd + delta); } - + + public boolean inColRange(long col) { + 
return col >= colStart && col < colEnd; + } + + public boolean inRowRange(long row) { + return row >= rowStart && row < rowEnd; + } + + public long colSpan() { + return colEnd - colStart; + } + + public long rowSpan() { + return rowEnd - rowStart; + } + @Override public String toString() { return "["+rowStart+":"+rowEnd+","+colStart+":"+colEnd+"]"; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java new file mode 100644 index 0000000..622e6e0 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.federated.transform; + +import org.apache.sysds.common.Types; +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.common.Types.FileFormat; +import org.apache.sysds.parser.DataExpression; +import org.apache.sysds.runtime.io.FileFormatPropertiesCSV; +import org.apache.sysds.runtime.io.FrameReaderFactory; +import org.apache.sysds.runtime.io.FrameWriter; +import org.apache.sysds.runtime.io.FrameWriterFactory; +import org.apache.sysds.runtime.io.MatrixReaderFactory; +import org.apache.sysds.runtime.matrix.data.FrameBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.runtime.util.UtilFunctions; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +public class TransformFederatedEncodeApplyTest extends AutomatedTestBase { + private final static String TEST_NAME1 = "TransformFederatedEncodeApply"; + private final static String TEST_DIR = "functions/transform/"; + private final static String TEST_CLASS_DIR = TEST_DIR + TransformFederatedEncodeApplyTest.class.getSimpleName() + + "/"; + + // dataset and transform tasks without missing values + private final static String DATASET1 = "homes3/homes.csv"; + private final static String SPEC1 = "homes3/homes.tfspec_recode.json"; + private final static String SPEC1b = "homes3/homes.tfspec_recode2.json"; + private final static String SPEC2 = "homes3/homes.tfspec_dummy.json"; + private final static String SPEC2b = "homes3/homes.tfspec_dummy2.json"; + private final static String SPEC3 = "homes3/homes.tfspec_bin.json"; // recode + private final static String SPEC3b = "homes3/homes.tfspec_bin2.json"; // recode + private final static String SPEC6 = 
"homes3/homes.tfspec_recode_dummy.json"; + private final static String SPEC6b = "homes3/homes.tfspec_recode_dummy2.json"; + private final static String SPEC7 = "homes3/homes.tfspec_binDummy.json"; // recode+dummy + private final static String SPEC7b = "homes3/homes.tfspec_binDummy2.json"; // recode+dummy + private final static String SPEC8 = "homes3/homes.tfspec_hash.json"; + private final static String SPEC8b = "homes3/homes.tfspec_hash2.json"; + private final static String SPEC9 = "homes3/homes.tfspec_hash_recode.json"; + private final static String SPEC9b = "homes3/homes.tfspec_hash_recode2.json"; + + // dataset and transform tasks with missing values + private final static String DATASET2 = "homes/homes.csv"; + // private final static String SPEC4 = "homes3/homes.tfspec_impute.json"; + // private final static String SPEC4b = "homes3/homes.tfspec_impute2.json"; + private final static String SPEC5 = "homes3/homes.tfspec_omit.json"; + private final static String SPEC5b = "homes3/homes.tfspec_omit2.json"; + + private static final int[] BIN_col3 = new int[] {1, 4, 2, 3, 3, 2, 4}; + private static final int[] BIN_col8 = new int[] {1, 2, 2, 2, 2, 2, 3}; + + public enum TransformType { + RECODE, DUMMY, RECODE_DUMMY, BIN, BIN_DUMMY, + // IMPUTE, + OMIT, + HASH, + HASH_RECODE, + } + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"y"})); + } + + @Test + public void testHomesRecodeIDsCSV() { + runTransformTest(TransformType.RECODE, false); + } + + @Test + public void testHomesDummycodeIDsCSV() { + runTransformTest(TransformType.DUMMY, false); + } + + @Test + public void testHomesRecodeDummycodeIDsCSV() { + runTransformTest(TransformType.RECODE_DUMMY, false); + } + + @Test + public void testHomesBinningIDsCSV() { + runTransformTest(TransformType.BIN, false); + } + + @Test + public void testHomesBinningDummyIDsCSV() { + runTransformTest(TransformType.BIN_DUMMY, false); + } + + @Test + public void testHomesOmitIDsCSV() { + runTransformTest(TransformType.OMIT, false); + } + + // @Test + // public void testHomesImputeIDsCSV() { + // runTransformTest(TransformType.IMPUTE, false); + // } + + @Test + public void testHomesRecodeColnamesCSV() { + runTransformTest(TransformType.RECODE, true); + } + + @Test + public void testHomesDummycodeColnamesCSV() { + runTransformTest(TransformType.DUMMY, true); + } + + @Test + public void testHomesRecodeDummycodeColnamesCSV() { + runTransformTest(TransformType.RECODE_DUMMY, true); + } + + @Test + public void testHomesBinningColnamesCSV() { + runTransformTest(TransformType.BIN, true); + } + + @Test + public void testHomesBinningDummyColnamesCSV() { + runTransformTest(TransformType.BIN_DUMMY, true); + } + + @Test + public void testHomesOmitColnamesCSV() { + runTransformTest(TransformType.OMIT, true); + } + + // @Test + // public void testHomesImputeColnamesCSV() { + // runTransformTest(TransformType.IMPUTE, true); + // } + + @Test + public void testHomesHashColnamesCSV() { + runTransformTest(TransformType.HASH, true); + } + + @Test + public void testHomesHashIDsCSV() { + runTransformTest(TransformType.HASH, false); + } + + @Test + public void testHomesHashRecodeColnamesCSV() { + runTransformTest(TransformType.HASH_RECODE, true); + } + + @Test + public void testHomesHashRecodeIDsCSV() { + runTransformTest(TransformType.HASH_RECODE, false); + } + + private void runTransformTest(TransformType type, boolean colnames) { + ExecMode rtold = 
setExecMode(ExecMode.SINGLE_NODE); + + // set transform specification + String SPEC = null; + String DATASET = null; + switch(type) { + case RECODE: SPEC = colnames ? SPEC1b : SPEC1; DATASET = DATASET1; break; + case DUMMY: SPEC = colnames ? SPEC2b : SPEC2; DATASET = DATASET1; break; + case BIN: SPEC = colnames ? SPEC3b : SPEC3; DATASET = DATASET1; break; + // case IMPUTE: SPEC = colnames ? SPEC4b : SPEC4; DATASET = DATASET2; break; + case OMIT: SPEC = colnames ? SPEC5b : SPEC5; DATASET = DATASET2; break; + case RECODE_DUMMY: SPEC = colnames ? SPEC6b : SPEC6; DATASET = DATASET1; break; + case BIN_DUMMY: SPEC = colnames ? SPEC7b : SPEC7; DATASET = DATASET1; break; + case HASH: SPEC = colnames ? SPEC8b : SPEC8; DATASET = DATASET1; break; + case HASH_RECODE: SPEC = colnames ? SPEC9b : SPEC9; DATASET = DATASET1; break; + } + + Thread t1 = null, t2 = null; + try { + getAndLoadTestConfiguration(TEST_NAME1); + + int port1 = getRandomAvailablePort(); + t1 = startLocalFedWorkerThread(port1); + int port2 = getRandomAvailablePort(); + t2 = startLocalFedWorkerThread(port2); + + FileFormatPropertiesCSV ffpCSV = new FileFormatPropertiesCSV(true, DataExpression.DEFAULT_DELIM_DELIMITER, + DataExpression.DEFAULT_DELIM_FILL, DataExpression.DEFAULT_DELIM_FILL_VALUE, + DATASET.equals(DATASET1) ? DataExpression.DEFAULT_NA_STRINGS : "NA" + DataExpression.DELIM_NA_STRING_SEP + + ""); + String HOME = SCRIPT_DIR + TEST_DIR; + // split up dataset + FrameBlock dataset = FrameReaderFactory.createFrameReader(FileFormat.CSV, ffpCSV) + .readFrameFromHDFS(HOME + "input/" + DATASET, -1, -1); + + // default for write + ffpCSV.setNAStrings(UtilFunctions.defaultNaString); + FrameWriter fw = FrameWriterFactory.createFrameWriter(FileFormat.CSV, ffpCSV); + + FrameBlock A = new FrameBlock(); + dataset.slice(0, dataset.getNumRows() - 1, 0, dataset.getNumColumns() / 2 - 1, A); + fw.writeFrameToHDFS(A, input("A"), A.getNumRows(), A.getNumColumns()); + HDFSTool.writeMetaDataFile(input("A.mtd"), null, A.getSchema(), Types.DataType.FRAME, + new MatrixCharacteristics(A.getNumRows(), A.getNumColumns()), FileFormat.CSV, ffpCSV); + + FrameBlock B = new FrameBlock(); + dataset.slice(0, dataset.getNumRows() - 1, dataset.getNumColumns() / 2, dataset.getNumColumns() - 1, B); + fw.writeFrameToHDFS(B, input("B"), B.getNumRows(), B.getNumColumns()); + HDFSTool.writeMetaDataFile(input("B.mtd"), null, B.getSchema(), Types.DataType.FRAME, + new MatrixCharacteristics(B.getNumRows(), B.getNumColumns()), FileFormat.CSV, ffpCSV); + + fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; + programArgs = new String[] {"-nvargs", "in_A=" + TestUtils.federatedAddress(port1, input("A")), + "in_B=" + TestUtils.federatedAddress(port2, input("B")), "rows=" + dataset.getNumRows(), + "cols_A=" + A.getNumColumns(), "cols_B=" + B.getNumColumns(), "TFSPEC=" + HOME + "input/" + SPEC, + "TFDATA1=" + output("tfout1"), "TFDATA2=" + output("tfout2"), "OFMT=csv"}; + + runTest(true, false, null, -1); + + // read input/output and compare + double[][] R1 = DataConverter.convertToDoubleMatrix(MatrixReaderFactory.createMatrixReader(FileFormat.CSV) + .readMatrixFromHDFS(output("tfout1"), -1L, -1L, 1000, -1)); + double[][] R2 = DataConverter.convertToDoubleMatrix(MatrixReaderFactory.createMatrixReader(FileFormat.CSV) + .readMatrixFromHDFS(output("tfout2"), -1L, -1L, 1000, -1)); + TestUtils.compareMatrices(R1, R2, R1.length, R1[0].length, 0); + + // additional checks for binning as encode-decode impossible + if(type == TransformType.BIN) { + for(int i = 0; i < 7; i++) { + 
Assert.assertEquals(BIN_col3[i], R1[i][2], 1e-8); + Assert.assertEquals(BIN_col8[i], R1[i][7], 1e-8); + } + } + else if(type == TransformType.BIN_DUMMY) { + Assert.assertEquals(14, R1[0].length); + for(int i = 0; i < 7; i++) { + for(int j = 0; j < 4; j++) { // check dummy coded + Assert.assertEquals((j == BIN_col3[i] - 1) ? 1 : 0, R1[i][2 + j], 1e-8); + } + for(int j = 0; j < 3; j++) { // check dummy coded + Assert.assertEquals((j == BIN_col8[i] - 1) ? 1 : 0, R1[i][10 + j], 1e-8); + } + } + } + } + catch(Exception ex) { + throw new RuntimeException(ex); + } + finally { + TestUtils.shutdownThreads(t1, t2); + resetExecMode(rtold); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java index 29afa5b..c45be72 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java @@ -115,11 +115,9 @@ public class TransformFederatedEncodeDecodeTest extends AutomatedTestBase { runTransformEncodeDecodeTest(false, true, Types.FileFormat.BINARY); } - private void runTransformEncodeDecodeTest(boolean recode, boolean sparse, - Types.FileFormat format) { - ExecMode platformOld = rtplatform; - rtplatform = ExecMode.SINGLE_NODE; - + private void runTransformEncodeDecodeTest(boolean recode, boolean sparse, Types.FileFormat format) { + ExecMode rtold = setExecMode(ExecMode.SINGLE_NODE); + Thread t1 = null, t2 = null, t3 = null, t4 = null; try { getAndLoadTestConfiguration(TEST_NAME_RECODE); @@ -197,11 +195,8 @@ public class TransformFederatedEncodeDecodeTest extends AutomatedTestBase { Assert.fail(ex.getMessage()); } finally { - TestUtils.shutdownThread(t1); - TestUtils.shutdownThread(t2); - TestUtils.shutdownThread(t3); - TestUtils.shutdownThread(t4); - rtplatform = platformOld; + TestUtils.shutdownThreads(t1, t2, t3, t4); + resetExecMode(rtold); } } diff --git a/src/test/scripts/functions/transform/TransformFederatedEncodeApply.dml b/src/test/scripts/functions/transform/TransformFederatedEncodeApply.dml new file mode 100644 index 0000000..921242b --- /dev/null +++ b/src/test/scripts/functions/transform/TransformFederatedEncodeApply.dml @@ -0,0 +1,36 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +F1 = federated(type="frame", addresses=list($in_A, $in_B), ranges= + list(list(0,0), list($rows, $cols_A), # A range + list(0, $cols_A), list($rows, $cols_A + $cols_B))); # B range + +jspec = read($TFSPEC, data_type="scalar", value_type="string"); + +[X, M] = transformencode(target=F1, spec=jspec); + +while(FALSE){} + +X2 = transformapply(target=F1, spec=jspec, meta=M); + +write(X, $TFDATA1, format="csv"); +write(X2, $TFDATA2, format="csv"); +
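The recurring pattern in the encoder changes above is the subRangeEncoder(IndexRange) projection: each federated worker keeps only the encoder columns that fall inside its partition and shifts them to worker-local, 1-based indices (IndexRange bounds are begin-inclusive, end-exclusive). Below is a minimal standalone sketch of just that index arithmetic; the class and method names (SubRangeSketch, subRangeCols) are hypothetical illustrations, not the SystemDS API itself.

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;

    public class SubRangeSketch {
        // colStart/colEnd are 1-based and end-exclusive, mirroring IndexRange
        static int[] subRangeCols(int[] globalCols, long colStart, long colEnd) {
            List<Integer> local = new ArrayList<>();
            for(int col : globalCols) {
                if(col >= colStart && col < colEnd)
                    // shift to a worker-local 1-based index:
                    // the column at global position colStart becomes column 1
                    local.add((int) (col - (colStart - 1)));
            }
            // an empty result corresponds to the "return null" case in the
            // subRangeEncoder implementations (worker holds no encoded columns)
            return local.stream().mapToInt(i -> i).toArray();
        }

        public static void main(String[] args) {
            // global columns 3 and 8 are encoded; this worker holds columns 6..10
            System.out.println(Arrays.toString(
                subRangeCols(new int[] {3, 8}, 6, 11))); // prints [3]
        }
    }

The same shift explains the minCol/maxCol parameters threaded through TfMetaUtils above: when parsing a spec for one worker, column IDs outside [minCol, maxCol) are dropped rather than treated as errors, and the remaining IDs are rebased by minCol - 1.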