This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 7eccbfe  [SYSTEMDS-2554,2558,2561] Federated transform decode 
(recoding)
7eccbfe is described below

commit 7eccbfeb047dd029f8e110e65d208543c3d60ee5
Author: Kevin Innerebner <[email protected]>
AuthorDate: Thu Aug 20 22:53:56 2020 +0200

    [SYSTEMDS-2554,2558,2561] Federated transform decode (recoding)
    
    Closes #1027.
---
 .../controlprogram/caching/CacheableData.java      |  12 +-
 .../controlprogram/federated/FederatedData.java    |   1 +
 .../federated/FederatedWorkerHandler.java          |  14 +-
 .../cp/ParameterizedBuiltinCPInstruction.java      |  10 +-
 .../instructions/fed/FEDInstructionUtils.java      |   5 +-
 .../fed/ParameterizedBuiltinFEDInstruction.java    | 174 +++++++++++++++++----
 .../runtime/transform/decode/DecoderFactory.java   |  21 ++-
 .../sysds/runtime/transform/encode/Encoder.java    |   4 +-
 .../runtime/transform/encode/EncoderFactory.java   |   2 +-
 .../runtime/transform/encode/EncoderRecode.java    |  13 +-
 .../sysds/runtime/transform/meta/TfMetaUtils.java  |   2 +-
 .../functions/federated/FederatedNegativeTest.java |  32 +---
 .../TransformFederatedEncodeDecodeTest.java        |  21 ++-
 .../transform/TransformFederatedEncodeDecode.dml   |  11 +-
 14 files changed, 228 insertions(+), 94 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index f7e893f..720534a 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -768,7 +768,7 @@ public abstract class CacheableData<T extends CacheBlock> 
extends Data
                        new Path(_hdfsFileName), new Path(fName));
                
                //actual export (note: no direct transfer of local copy in 
order to ensure blocking (and hence, parallelism))
-               if( isDirty() || !eqScheme ||
+               if( isDirty() || !eqScheme || isFederated() ||
                        (pWrite && !isEqualOutputFormat(outputFormat)) ) 
                {
                        // CASE 1: dirty in-mem matrix or pWrite w/ different 
format (write matrix to fname; load into memory if evicted)
@@ -781,13 +781,15 @@ public abstract class CacheableData<T extends CacheBlock> 
extends Data
                                {
                                        if( getRDDHandle()==null || 
getRDDHandle().allowsShortCircuitRead() )
                                                _data = readBlobFromHDFS( 
_hdfsFileName );
-                                       else
+                                       else if( getRDDHandle() != null )
                                                _data = readBlobFromRDD( 
getRDDHandle(), new MutableBoolean() );
+                                       else 
+                                               _data = readBlobFromFederated( 
getFedMapping() );
+                                       
                                        setDirty(false);
                                }
-                               catch (IOException e)
-                               {
-                                   throw new DMLRuntimeException("Reading of " 
+ _hdfsFileName + " ("+hashCode()+") failed.", e);
+                               catch (IOException e) {
+                                       throw new DMLRuntimeException("Reading 
of " + _hdfsFileName + " ("+hashCode()+") failed.", e);
                                }
                        }
                        //get object from cache
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
index d161522..9f5f942 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
@@ -119,6 +119,7 @@ public class FederatedData {
        /**
         * Executes an federated operation on a federated worker.
         *
+        * @param address socket address (incl host and port)
         * @param request the requested operation
         * @return the response
         */
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 6dd0abc..f4af303 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -256,7 +256,8 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                        pb.execute(ec); //execute single instruction
                }
                catch(Exception ex) {
-                       return new FederatedResponse(ResponseType.ERROR, 
ex.getMessage());
+                       return new FederatedResponse(ResponseType.ERROR, new 
FederatedWorkerHandlerException(
+                               "Exception of type " + ex.getClass() + " thrown 
when processing EXEC_INST request", ex));
                }
                return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
        }
@@ -276,12 +277,19 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                        return udf.execute(ec, inputs);
                }
                catch(Exception ex) {
-                       return new FederatedResponse(ResponseType.ERROR, 
ex.getMessage());
+                       return new FederatedResponse(ResponseType.ERROR, new 
FederatedWorkerHandlerException(
+                               "Exception of type " + ex.getClass() + " thrown 
when processing EXEC_UDF request", ex));
                }
        }
 
        private FederatedResponse execClear() {
-               _ecm.clear();
+               try {
+                       _ecm.clear();
+               }
+               catch(Exception ex) {
+                       return new FederatedResponse(ResponseType.ERROR, new 
FederatedWorkerHandlerException(
+                               "Exception of type " + ex.getClass() + " thrown 
when processing CLEAR request", ex));
+               }
                return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
        }
        
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index 5e62475..cfb20e3 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -445,7 +445,7 @@ public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction
                }
                else if (opcode.equalsIgnoreCase("transformdecode")) {
                        CPOperand target = getTargetOperand();
-                       CPOperand meta = getLiteral(params.get("meta"), 
DataType.FRAME);
+                       CPOperand meta = getLiteral("meta", ValueType.UNKNOWN, 
DataType.FRAME);
                        CPOperand spec = getStringLiteral("spec");
                        return Pair.of(output.getName(), new 
LineageItem(getOpcode(),
                                LineageItemUtils.getLineage(ec, target, meta, 
spec)));
@@ -476,12 +476,12 @@ public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction
        private CPOperand getBoolLiteral(String name) {
                return getLiteral(name, ValueType.BOOLEAN);
        }
-
-       private CPOperand getLiteral(String name, DataType dt) {
-               return new CPOperand(name, ValueType.UNKNOWN, DataType.FRAME);
-       }
        
        private CPOperand getLiteral(String name, ValueType vt) {
                return new CPOperand(params.get(name), vt, DataType.SCALAR, 
true);
        }
+       
+       private CPOperand getLiteral(String name, ValueType vt, DataType dt) {
+               return new CPOperand(params.get(name), vt, dt);
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index c8cf729..a1b0a08 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -75,10 +75,13 @@ public class FEDInstructionUtils {
                        }
                }
                else if( inst instanceof ParameterizedBuiltinCPInstruction ) {
-                       ParameterizedBuiltinCPInstruction pinst = 
(ParameterizedBuiltinCPInstruction)inst;
+                       ParameterizedBuiltinCPInstruction pinst = 
(ParameterizedBuiltinCPInstruction) inst;
                        if(pinst.getOpcode().equals("replace") && 
pinst.getTarget(ec).isFederated()) {
                                fedinst = 
ParameterizedBuiltinFEDInstruction.parseInstruction(pinst.getInstructionString());
                        }
+                       else if(pinst.getOpcode().equals("transformdecode") && 
pinst.getTarget(ec).isFederated()) {
+                               return 
ParameterizedBuiltinFEDInstruction.parseInstruction(pinst.getInstructionString());
+                       }
                }
                else if (inst instanceof 
MultiReturnParameterizedBuiltinCPInstruction) {
                        MultiReturnParameterizedBuiltinCPInstruction minst = 
(MultiReturnParameterizedBuiltinCPInstruction) inst;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index ec28965..e3523ed 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -19,103 +19,213 @@
 
 package org.apache.sysds.runtime.instructions.fed;
 
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.LinkedHashMap;
+
+import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.lops.Lop;
 import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.functionobjects.ParameterizedBuiltin;
 import org.apache.sysds.runtime.functionobjects.ValueFunction;
-import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.meta.MetaDataFormat;
+import org.apache.sysds.runtime.privacy.PrivacyMonitor;
+import org.apache.sysds.runtime.transform.decode.Decoder;
+import org.apache.sysds.runtime.transform.decode.DecoderFactory;
 
 public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstruction {
-
        protected final LinkedHashMap<String, String> params;
-       
-       protected ParameterizedBuiltinFEDInstruction(Operator op,
-               LinkedHashMap<String, String> paramsMap, CPOperand out, String 
opcode, String istr)
-       {
+
+       protected ParameterizedBuiltinFEDInstruction(Operator op, 
LinkedHashMap<String, String> paramsMap, CPOperand out,
+               String opcode, String istr) {
                super(FEDType.ParameterizedBuiltin, op, null, null, out, 
opcode, istr);
                params = paramsMap;
        }
-       
-       public HashMap<String,String> getParameterMap() { 
-               return params; 
+
+       public HashMap<String, String> getParameterMap() {
+               return params;
        }
-       
+
        public String getParam(String key) {
                return getParameterMap().get(key);
        }
-       
+
        public static LinkedHashMap<String, String> 
constructParameterMap(String[] params) {
                // process all elements in "params" except first(opcode) and 
last(output)
-               LinkedHashMap<String,String> paramMap = new LinkedHashMap<>();
-               
+               LinkedHashMap<String, String> paramMap = new LinkedHashMap<>();
+
                // all parameters are of form <name=value>
                String[] parts;
-               for ( int i=1; i <= params.length-2; i++ ) {
+               for(int i = 1; i <= params.length - 2; i++) {
                        parts = params[i].split(Lop.NAME_VALUE_SEPARATOR);
                        paramMap.put(parts[0], parts[1]);
                }
-               
+
                return paramMap;
        }
-       
-       public static ParameterizedBuiltinFEDInstruction parseInstruction ( 
String str ) {
+
+       public static ParameterizedBuiltinFEDInstruction 
parseInstruction(String str) {
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
                // first part is always the opcode
                String opcode = parts[0];
                // last part is always the output
-               CPOperand out = new CPOperand( parts[parts.length-1] ); 
-       
+               CPOperand out = new CPOperand(parts[parts.length - 1]);
+
                // process remaining parts and build a hash map
-               LinkedHashMap<String,String> paramsMap = 
constructParameterMap(parts);
-       
+               LinkedHashMap<String, String> paramsMap = 
constructParameterMap(parts);
+
                // determine the appropriate value function
-               ValueFunction func = null;
                if( opcode.equalsIgnoreCase("replace") ) {
-                       func = 
ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
+                       ValueFunction func = 
ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
                        return new ParameterizedBuiltinFEDInstruction(new 
SimpleOperator(func), paramsMap, out, opcode, str);
                }
+               else if(opcode.equals("transformapply") || 
opcode.equals("transformdecode")) {
+                       return new ParameterizedBuiltinFEDInstruction(null, 
paramsMap, out, opcode, str);
+               }
                else {
                        throw new DMLRuntimeException("Unsupported opcode (" + 
opcode + ") for ParameterizedBuiltinFEDInstruction.");
                }
        }
-       
-       @Override 
+
+       @Override
        public void processInstruction(ExecutionContext ec) {
                String opcode = getOpcode();
-               if ( opcode.equalsIgnoreCase("replace") ) {
-                       //similar to unary federated instructions, get 
federated input
-                       //execute instruction, and derive federated output 
matrix
+               if(opcode.equalsIgnoreCase("replace")) {
+                       // similar to unary federated instructions, get 
federated input
+                       // execute instruction, and derive federated output 
matrix
                        MatrixObject mo = getTarget(ec);
                        FederatedRequest fr1 = 
FederationUtils.callInstruction(instString, output,
-                               new CPOperand[]{getTargetOperand()}, new 
long[]{mo.getFedMapping().getID()});
+                               new CPOperand[] {getTargetOperand()}, new 
long[] {mo.getFedMapping().getID()});
                        mo.getFedMapping().execute(getTID(), true, fr1);
-                       
-                       //derive new fed mapping for output
+
+                       // derive new fed mapping for output
                        MatrixObject out = ec.getMatrixObject(output);
                        
out.getDataCharacteristics().set(mo.getDataCharacteristics());
                        
out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
                }
+               else if(opcode.equalsIgnoreCase("transformdecode"))
+                       transformDecode(ec);
                else {
                        throw new DMLRuntimeException("Unknown opcode : " + 
opcode);
                }
        }
        
+       private void transformDecode(ExecutionContext ec) {
+               // acquire locks
+               MatrixObject mo = ec.getMatrixObject(params.get("target"));
+               FrameBlock meta = ec.getFrameInput(params.get("meta"));
+               String spec = params.get("spec");
+               
+               FederationMap fedMapping = mo.getFedMapping();
+               
+               ValueType[] schema = new ValueType[(int) mo.getNumColumns()];
+               long varID = FederationUtils.getNextFedDataID();
+               FederationMap decodedMapping = fedMapping.mapParallel(varID, 
(range, data) -> {
+                       int columnOffset = (int) range.getBeginDims()[1] + 1;
+                       
+                       FrameBlock subMeta = new FrameBlock();
+                       synchronized(meta) {
+                               meta.slice(0, meta.getNumRows() - 1, 
columnOffset - 1, (int) range.getEndDims()[1] - 1, subMeta);
+                       }
+                       
+                       FederatedResponse response;
+                       try {
+                               response = data.executeFederatedOperation(new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
+                                       varID, new 
DecodeMatrix(data.getVarID(), varID, subMeta, spec, columnOffset))).get();
+                               if(!response.isSuccessful())
+                                       response.throwExceptionFromResponse();
+                               
+                               ValueType[] subSchema = (ValueType[]) 
response.getData()[0];
+                               synchronized(schema) {
+                                       // It would be possible to assert that 
different federated workers don't give different value
+                                       // types for the same columns, but the 
performance impact is not worth the effort
+                                       System.arraycopy(subSchema, 0, schema, 
columnOffset - 1, subSchema.length);
+                               }
+                       }
+                       catch(Exception e) {
+                               throw new DMLRuntimeException(e);
+                       }
+                       return null;
+               });
+               
+               // construct a federated matrix with the encoded data
+               FrameObject decodedFrame = ec.getFrameObject(output);
+               decodedFrame.setSchema(schema);
+               
decodedFrame.getDataCharacteristics().set(mo.getDataCharacteristics());
+               // set the federated mapping for the matrix
+               decodedFrame.setFedMapping(decodedMapping);
+               
+               // release locks
+               ec.releaseFrameInput(params.get("meta"));
+       }
+
        public MatrixObject getTarget(ExecutionContext ec) {
                return ec.getMatrixObject(params.get("target"));
        }
-       
+
        private CPOperand getTargetOperand() {
                return new CPOperand(params.get("target"), ValueType.FP64, 
DataType.MATRIX);
        }
+       
+       public static class DecodeMatrix extends FederatedUDF {
+               private static final long serialVersionUID = 
2376756757742169692L;
+               private final long _outputID;
+               private final FrameBlock _meta;
+               private final String _spec;
+               private final int _globalOffset;
+               
+               public DecodeMatrix(long input, long outputID, FrameBlock meta, 
String spec, int globalOffset) {
+                       super(new long[]{input});
+                       _outputID = outputID;
+                       _meta = meta;
+                       _spec = spec;
+                       _globalOffset = globalOffset;
+               }
+               
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... 
data) {
+                       MatrixObject mo = (MatrixObject) 
PrivacyMonitor.handlePrivacy(data[0]);
+                       MatrixBlock mb = mo.acquireRead();
+                       String[] colNames = _meta.getColumnNames();
+                       
+                       // compute transformdecode
+                       Decoder decoder = DecoderFactory.createDecoder(_spec, 
colNames, null,
+                               _meta, mb.getNumColumns(), _globalOffset, 
_globalOffset + mb.getNumColumns());
+                       FrameBlock fbout = decoder.decode(mb, new 
FrameBlock(decoder.getSchema()));
+                       fbout.setColumnNames(Arrays.copyOfRange(colNames, 0, 
fbout.getNumColumns()));
+                       
+                       // copy characteristics
+                       MatrixCharacteristics mc = new 
MatrixCharacteristics(mo.getDataCharacteristics());
+                       FrameObject fo = new 
FrameObject(OptimizerUtils.getUniqueTempFileName(),
+                               new MetaDataFormat(mc, 
Types.FileFormat.BINARY));
+                       // set the encoded data
+                       fo.acquireModify(fbout);
+                       fo.release();
+                       mo.release();
+                       
+                       // add it to the list of variables
+                       ec.setVariable(String.valueOf(_outputID), fo);
+                       // return schema
+                       return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new Object[] 
{fo.getSchema()});
+               }
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java 
b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java
index 977d494..b51547d 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java
@@ -37,24 +37,33 @@ import static 
org.apache.sysds.runtime.util.CollectionUtils.unionDistinct;
 public class DecoderFactory 
 {
        public static Decoder createDecoder(String spec, String[] colnames, 
ValueType[] schema, FrameBlock meta) {
-               return createDecoder(spec, colnames, schema, meta, 
meta.getNumColumns());
+               return createDecoder(spec, colnames, schema, meta, 
meta.getNumColumns(), -1, -1);
        }
        
-       public static Decoder createDecoder(String spec, String[] colnames, 
ValueType[] schema, FrameBlock meta, int clen) 
+       public static Decoder createDecoder(String spec, String[] colnames, 
ValueType[] schema, FrameBlock meta, int clen) {
+               return createDecoder(spec, colnames, schema, meta, clen, -1, 
-1);
+       }
+
+       public static Decoder createDecoder(String spec, String[] colnames, 
ValueType[] schema, FrameBlock meta, int minCol,
+               int maxCol) {
+               return createDecoder(spec, colnames, schema, meta, 
meta.getNumColumns(), minCol, maxCol);
+       }
+
+       public static Decoder createDecoder(String spec, String[] colnames, 
ValueType[] schema,
+               FrameBlock meta, int clen, int minCol, int maxCol)
        {
                Decoder decoder = null;
                
-               try
-               {
+               try {
                        //parse transform specification
                        JSONObject jSpec = new JSONObject(spec);
                        List<Decoder> ldecoders = new ArrayList<>();
                        
                        //create decoders 'recode', 'dummy' and 'pass-through'
                        List<Integer> rcIDs = Arrays.asList(ArrayUtils.toObject(
-                                       TfMetaUtils.parseJsonIDList(jSpec, 
colnames, TfMethod.RECODE.toString())));
+                                       TfMetaUtils.parseJsonIDList(jSpec, 
colnames, TfMethod.RECODE.toString(), minCol, maxCol)));
                        List<Integer> dcIDs = Arrays.asList(ArrayUtils.toObject(
-                                       TfMetaUtils.parseJsonIDList(jSpec, 
colnames, TfMethod.DUMMYCODE.toString()))); 
+                                       TfMetaUtils.parseJsonIDList(jSpec, 
colnames, TfMethod.DUMMYCODE.toString(), minCol, maxCol)));
                        rcIDs = unionDistinct(rcIDs, dcIDs);
                        int len = dcIDs.isEmpty() ? 
Math.min(meta.getNumColumns(), clen) : meta.getNumColumns();
                        List<Integer> ptIDs = 
except(UtilFunctions.getSeqList(1, len, 1), rcIDs);
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
index 4c0ad9e..5945e27 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
@@ -134,7 +134,7 @@ public abstract class Encoder implements Serializable
         */
        public Encoder subRangeEncoder(int colStart, int colEnd) {
                throw new DMLRuntimeException(
-                       this.getClass().getName() + " does not support the 
creation of a sub-range encoder");
+                       this.getClass().getSimpleName() + " does not support 
the creation of a sub-range encoder");
        }
 
        /**
@@ -145,7 +145,7 @@ public abstract class Encoder implements Serializable
         */
        protected void mergeColumnInfo(Encoder other, int col) {
                // update number of columns
-               _clen = Math.max(_colList.length, col - 1 + other.getNumCols());
+               _clen = Math.max(_clen, col - 1 + other._clen);
 
                // update the new columns that this encoder operates on
                Set<Integer> colListAgg = new HashSet<>(); // for dedup
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
index dcd2b1c..2070485 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
@@ -85,7 +85,7 @@ public class EncoderFactory
                        
                        //create individual encoders
                        if( !rcIDs.isEmpty() ) {
-                               EncoderRecode ra = new EncoderRecode(jSpec, 
colnames, clen);
+                               EncoderRecode ra = new EncoderRecode(jSpec, 
colnames, clen, minCol, maxCol);
                                
ra.setColList(ArrayUtils.toPrimitive(rcIDs.toArray(new Integer[0])));
                                lencoders.add(ra);
                        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
index fe3f5b1..d4b201e 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
@@ -43,11 +43,11 @@ public class EncoderRecode extends Encoder
        private HashMap<Integer, HashMap<String, Long>> _rcdMaps  = new 
HashMap<>();
        private HashMap<Integer, HashSet<Object>> _rcdMapsPart = null;
        
-       public EncoderRecode(JSONObject parsedSpec, String[] colnames, int clen)
+       public EncoderRecode(JSONObject parsedSpec, String[] colnames, int 
clen, int minCol, int maxCol)
                throws JSONException 
        {
                super(null, clen);
-               _colList = TfMetaUtils.parseJsonIDList(parsedSpec, colnames, 
TfMethod.RECODE.toString());
+               _colList = TfMetaUtils.parseJsonIDList(parsedSpec, colnames, 
TfMethod.RECODE.toString(), minCol, maxCol);
        }
        
        private EncoderRecode(int[] colList, int clen) {
@@ -58,6 +58,11 @@ public class EncoderRecode extends Encoder
                this(new int[0], 0);
        }
        
+       private EncoderRecode(int[] colList, int clen, HashMap<Integer, 
HashMap<String, Long>> rcdMaps) {
+               super(colList, clen);
+               _rcdMaps = rcdMaps;
+       }
+       
        public HashMap<Integer, HashMap<String,Long>> getCPRecodeMaps() { 
                return _rcdMaps; 
        }
@@ -180,9 +185,7 @@ public class EncoderRecode extends Encoder
                        return null;
                
                int[] colList = cols.stream().mapToInt(i -> i).toArray();
-               EncoderRecode subRangeEncoder = new EncoderRecode(colList, 
colEnd - colStart);
-               subRangeEncoder._rcdMaps = rcdMaps;
-               return subRangeEncoder;
+               return new EncoderRecode(colList, colEnd - colStart, rcdMaps);
        }
 
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java 
b/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
index 39f5650..72fab7a 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
@@ -135,7 +135,7 @@ public class TfMetaUtils
                                }
                                if(ix <= 0) {
                                        if (minCol == -1 && maxCol == -1) {
-                                               // only if we cut of some 
columns, ix -1 is expected
+                                               // only if we remove some 
columns, ix -1 is expected
                                                throw new 
RuntimeException("Specified column '"
                                                        + attrs.get(i)+"' does 
not exist.");
                                        }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedNegativeTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedNegativeTest.java
index a355275..8c60cec 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedNegativeTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedNegativeTest.java
@@ -19,8 +19,6 @@
 
 package org.apache.sysds.test.functions.federated;
 
-import org.apache.log4j.Level;
-import org.apache.log4j.Logger;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.runtime.controlprogram.federated.*;
 import org.apache.sysds.test.AutomatedTestBase;
@@ -30,7 +28,6 @@ import org.junit.Test;
 import java.net.InetSocketAddress;
 import java.util.HashMap;
 import java.util.Map;
-import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Future;
 
 import static org.junit.Assert.assertFalse;
@@ -38,23 +35,16 @@ import static org.junit.Assert.assertTrue;
 
 @net.jcip.annotations.NotThreadSafe
 public class FederatedNegativeTest {
-       protected static Logger log = Logger.getLogger(FederatedNegativeTest.class);
-
-       static {
-               Logger.getLogger("org.apache.sysds").setLevel(Level.OFF);
-       }
-
        @Test
        public void NegativeTest1() {
-        int port = AutomatedTestBase.getRandomAvailablePort();
+               int port = AutomatedTestBase.getRandomAvailablePort();
                String[] args = {"-w", Integer.toString(port)};
-        Thread t = AutomatedTestBase.startLocalFedWorkerWithArgs(args);
+               Thread t = AutomatedTestBase.startLocalFedWorkerWithArgs(args);
+               FederationUtils.resetFedDataID(); //ensure expected ID when tests run in single JVM
                Map<FederatedRange, FederatedData> fedMap = new HashMap<>();
                FederatedRange r = new FederatedRange(new long[]{0,0}, new long[]{1,1});
-               FederatedData d = new FederatedData(
-                               Types.DataType.SCALAR,
-                               new InetSocketAddress("localhost", port),
-                               "Nowhere");
+               FederatedData d = new FederatedData(Types.DataType.SCALAR,
+                       new InetSocketAddress("localhost", port), "Nowhere");
                fedMap.put(r,d);
                FederationMap fedM = new FederationMap(fedMap);
                FederatedRequest fr = new FederatedRequest(FederatedRequest.RequestType.GET_VAR);
@@ -62,17 +52,11 @@ public class FederatedNegativeTest {
                try {
                        FederatedResponse fres = res[0].get();
                        assertFalse(fres.isSuccessful());
-                       assertTrue(fres.getErrorMessage().contains("Variable 0 does not exist at federated worker"));
-
-               } catch (InterruptedException e) {
-                       e.printStackTrace();
-               } catch (ExecutionException e) {
-                       e.printStackTrace();
-               } catch (Exception e) {
+                       assertTrue(fres.getErrorMessage().contains("Variable 1 does not exist at federated worker"));
+               }
+               catch (Exception e) {
                        e.printStackTrace();
                }
-
                TestUtils.shutdownThread(t);
        }
-
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformFederatedEncodeDecodeTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformFederatedEncodeDecodeTest.java
index df9def8..1f9c87d 100644
--- a/src/test/java/org/apache/sysds/test/functions/transform/TransformFederatedEncodeDecodeTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformFederatedEncodeDecodeTest.java
@@ -49,7 +49,8 @@ public class TransformFederatedEncodeDecodeTest extends AutomatedTestBase {
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
-               addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"FO"}));
+               addTestConfiguration(TEST_NAME1,
+                       new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"FO1", "FO2"}));
        }
 
        @Test
@@ -126,14 +127,26 @@ public class TransformFederatedEncodeDecodeTest extends AutomatedTestBase {
                                "in_AL=" + TestUtils.federatedAddress("localhost", port2, input("AL")),
                                "in_BU=" + TestUtils.federatedAddress("localhost", port3, input("BU")),
                                "in_BL=" + TestUtils.federatedAddress("localhost", port4, input("BL")), "rows=" + rows, "cols=" + cols,
-                               "spec_file=" + SCRIPT_DIR + TEST_DIR + SPEC, "out=" + output("FO"), "format=" + format.toString()};
+                               "spec_file=" + SCRIPT_DIR + TEST_DIR + SPEC, "out1=" + output("FO1"), "out2=" + output("FO2"),
+                               "format=" + format.toString()};
 
                        // run test
                        runTest(true, false, null, -1);
 
-                       // compare matrices (values recoded to identical codes)
+                       // compare frame before and after encode and decode
                        FrameReader reader = FrameReaderFactory.createFrameReader(format);
-                       FrameBlock FO = reader.readFrameFromHDFS(output("FO"), 15, 2);
+                       FrameBlock OUT = reader.readFrameFromHDFS(output("FO2"), rows, cols);
+                       for(int r = 0; r < rows; r++) {
+                               for(int c = 0; c < cols; c++) {
+                                       String expected = c < cols / 2 ? Double.toString(A[r][c]) : "Str" + B[r][c - cols / 2];
+                                       String val = (String) OUT.get(r, c);
+                                       Assert.assertEquals("Enc- and Decoded frame does not match the source frame: " + expected + " vs "
+                                               + val, expected, val);
+                               }
+                       }
+                       // TODO federate the aggregated result so that the decode is applied in a federated environment
+                       // compare matrices (values recoded to identical codes)
+                       FrameBlock FO = reader.readFrameFromHDFS(output("FO1"), 15, 2);
                        HashMap<String, Long> cFA = getCounts(A, B);
                        Iterator<String[]> iterFO = FO.getStringRowIterator();
                        while(iterFO.hasNext()) {
diff --git a/src/test/scripts/functions/transform/TransformFederatedEncodeDecode.dml b/src/test/scripts/functions/transform/TransformFederatedEncodeDecode.dml
index 1ff5446..50174d7 100644
--- a/src/test/scripts/functions/transform/TransformFederatedEncodeDecode.dml
+++ b/src/test/scripts/functions/transform/TransformFederatedEncodeDecode.dml
@@ -19,19 +19,20 @@
 #
 #-------------------------------------------------------------
 
-F1 = federated(type="frame", addresses=list($in_AU, $in_AL, $in_BU, $in_BL), ranges=
+F = federated(type="frame", addresses=list($in_AU, $in_AL, $in_BU, $in_BL), ranges=
   list(list(0,0), list($rows / 2, $cols / 2), # AUpper range
     list($rows / 2, 0), list($rows, $cols / 2), # ALower range
     list(0, $cols / 2), list($rows / 2, $cols), # BUpper range
     list($rows / 2, $cols / 2), list($rows, $cols))); # BLower range
 jspec = read($spec_file, data_type="scalar", value_type="string");
 
-[X, M] = transformencode(target=F1, spec=jspec);
+[X, M] = transformencode(target=F, spec=jspec);
 
 A = aggregate(target=X[,1], groups=X[,2], fn="count");
 Ag = cbind(A, seq(1,nrow(A)));
 
-F2 = transformdecode(target=Ag, spec=jspec, meta=M);
-
-write(F2, $out, format=$format);
+FO1 = transformdecode(target=Ag, spec=jspec, meta=M);
+FO2 = transformdecode(target=X, spec=jspec, meta=M);
 
+write(FO1, $out1, format=$format);
+write(FO2, $out2, format=$format);

Reply via email to