This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new d15fa6e  [SYSTEMDS-2556,2560] Federated transform impute, improved omit
d15fa6e is described below

commit d15fa6e55682b8c8a1e85740b8f7cbffab4dd151
Author: Kevin Innerebner <[email protected]>
AuthorDate: Sun Sep 6 19:21:57 2020 +0200

    [SYSTEMDS-2556,2560] Federated transform impute, improved omit
    
    Closes #1046.
---
 .../federated/FederatedWorkerHandler.java          |   4 +-
 ...tiReturnParameterizedBuiltinFEDInstruction.java |   6 +-
 .../fed/ParameterizedBuiltinFEDInstruction.java    |   2 +-
 .../sysds/runtime/transform/encode/Encoder.java    |   9 +-
 .../sysds/runtime/transform/encode/EncoderBin.java |   4 +-
 .../runtime/transform/encode/EncoderComposite.java |   8 +-
 .../runtime/transform/encode/EncoderDummycode.java |   4 +-
 .../runtime/transform/encode/EncoderFactory.java   |   2 +-
 .../transform/encode/EncoderFeatureHash.java       |   4 +-
 .../runtime/transform/encode/EncoderMVImpute.java  | 388 ++++++++++-----------
 .../runtime/transform/encode/EncoderOmit.java      |  74 ++--
 .../transform/encode/EncoderPassThrough.java       |   4 +-
 .../runtime/transform/encode/EncoderRecode.java    |   4 +-
 .../TransformFederatedEncodeApplyTest.java         | 100 ++++--
 .../transform/TransformFederatedEncodeApply.dml    |   8 +-
 15 files changed, 320 insertions(+), 301 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index b5f0ec8..2690ee6 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -100,9 +100,9 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 			
 			//select the response for the entire batch of requests
 			if (!tmp.isSuccessful()) {
-				log.error("Command " + request.getType() + " failed: " 
+				log.error("Command " + request.getType() + " failed: "
 					+ tmp.getErrorMessage() + "full command: \n" + request.toString());
-				response = (response == null || response.isSuccessful()) 
+				response = (response == null || response.isSuccessful())
 					? tmp : response; //return first error
 			}
 			else if( request.getType() == RequestType.GET_VAR ) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
index 0fe12b9..047aff3 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
@@ -48,6 +48,7 @@ import org.apache.sysds.runtime.transform.encode.EncoderComposite;
 import org.apache.sysds.runtime.transform.encode.EncoderDummycode;
 import org.apache.sysds.runtime.transform.encode.EncoderFactory;
 import org.apache.sysds.runtime.transform.encode.EncoderFeatureHash;
+import org.apache.sysds.runtime.transform.encode.EncoderMVImpute;
 import org.apache.sysds.runtime.transform.encode.EncoderOmit;
 import org.apache.sysds.runtime.transform.encode.EncoderPassThrough;
 import org.apache.sysds.runtime.transform.encode.EncoderRecode;
@@ -102,7 +103,8 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
 			new EncoderPassThrough(),
 			new EncoderBin(),
 			new EncoderDummycode(),
-			new EncoderOmit(true)));
+			new EncoderOmit(true),
+			new EncoderMVImpute()));
 		// first create encoders at the federated workers, then collect them and aggregate them to a single large
 		// encoder
 		FederationMap fedMapping = fin.getFedMapping();
@@ -120,7 +122,7 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
 			Encoder encoder = (Encoder) response.getData()[0];
 			// merge this encoder into a composite encoder
 			synchronized(globalEncoder) {
-				globalEncoder.mergeAt(encoder, columnOffset);
+				globalEncoder.mergeAt(encoder, (int) (range.getBeginDims()[0] + 1), columnOffset);
 			}
 			// no synchronization necessary since names should anyway match
 			String[] subRangeColNames = (String[]) response.getData()[1];
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index 204019f..4f31d4f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -266,7 +266,7 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
 
 				// no synchronization necessary since names should anyway match
 				Encoder builtEncoder = (Encoder) response.getData()[0];
-				newOmit.mergeAt(builtEncoder, (int) (range.getBeginDims()[1] + 1));
+				newOmit.mergeAt(builtEncoder, (int) (range.getBeginDims()[0] + 1), (int) (range.getBeginDims()[1] + 1));
 			}
 			catch(Exception e) {
 				throw new DMLRuntimeException(e);
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
index 7f47192..0758620 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
@@ -177,9 +177,10 @@ public abstract class Encoder implements Serializable
 	 * other <code>Encoder</code>.
 	 * 
 	 * @param other the encoder that should be merged in
-	 * @param col   the position where it should be placed (1-based)
+	 * @param row   the row where it should be placed (1-based)
+	 * @param col   the col where it should be placed (1-based)
 	 */
-	public void mergeAt(Encoder other, int col) {
+	public void mergeAt(Encoder other, int row, int col) {
 		throw new DMLRuntimeException(
 			this.getClass().getSimpleName() + " does not support merging with " + other.getClass().getSimpleName());
 	}
@@ -187,8 +188,8 @@ public abstract class Encoder implements Serializable
 	/**
 	 * Update index-ranges to after encoding. Note that only Dummycoding changes the ranges.
 	 *
-	 * @param beginDims the begin indexes before encoding
-	 * @param endDims   the end indexes before encoding
+	 * @param beginDims begin dimensions of range
+	 * @param endDims end dimensions of range
 	 */
 	public void updateIndexRanges(long[] beginDims, long[] endDims) {
 		// do nothing - default
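
The diff above generalizes the merge API from column-only to row-and-column placement.
For orientation, a minimal sketch of the coordinator-side call pattern implied by the
instruction changes in this commit (the variable names globalEncoder, workerEncoder,
and range are illustrative, not taken verbatim from the code):

    // Merge a worker's partial encoder into the global composite encoder.
    // Offsets are 1-based, per the new javadoc; they are derived from the
    // worker's federated index range: getBeginDims()[0] is its first row,
    // getBeginDims()[1] its first column (both 0-based in the range object).
    int rowOffset = (int) (range.getBeginDims()[0] + 1);
    int colOffset = (int) (range.getBeginDims()[1] + 1);
    synchronized(globalEncoder) {
        globalEncoder.mergeAt(workerEncoder, rowOffset, colOffset);
    }
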
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java
index 351f68d..4caee9b 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java
@@ -169,7 +169,7 @@ public class EncoderBin extends Encoder
 	}
 	
 	@Override
-	public void mergeAt(Encoder other, int col) {
+	public void mergeAt(Encoder other, int row, int col) {
 		if(other instanceof EncoderBin) {
 			EncoderBin otherBin = (EncoderBin) other;
 
@@ -217,7 +217,7 @@ public class EncoderBin extends Encoder
 			}
 			return;
 		}
-		super.mergeAt(other, col);
+		super.mergeAt(other, row, col);
 	}
 	
 	@Override
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
index c494676..cc59932 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
@@ -117,7 +117,7 @@ public class EncoderComposite extends Encoder
 	}
 
 	@Override
-	public void mergeAt(Encoder other, int col) {
+	public void mergeAt(Encoder other, int row, int col) {
 		if (other instanceof EncoderComposite) {
 			EncoderComposite otherComposite = (EncoderComposite) other;
 			// TODO maybe assert that the _encoders never have the same type of encoder twice or more
@@ -125,7 +125,7 @@ public class EncoderComposite extends Encoder
 				boolean mergedIn = false;
 				for (Encoder encoder : _encoders) {
 					if (encoder.getClass() == otherEnc.getClass()) {
-						encoder.mergeAt(otherEnc, col);
+						encoder.mergeAt(otherEnc, row, col);
 						mergedIn = true;
 						break;
 					}
@@ -146,7 +146,7 @@ public class EncoderComposite extends Encoder
 		}
 		for (Encoder encoder : _encoders) {
 			if (encoder.getClass() == other.getClass()) {
-				encoder.mergeAt(other, col);
+				encoder.mergeAt(other, row, col);
 				// update dummycode encoder domain sizes based on distinctness information from other encoders
 				for (Encoder encDummy : _encoders) {
 					if (encDummy instanceof EncoderDummycode) {
@@ -157,7 +157,7 @@ public class EncoderComposite extends Encoder
 				return;
 			}
 		}
-		super.mergeAt(other, col);
+		super.mergeAt(other, row, col);
 	}
 	
 	@Override
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java
index 19d41ea..f590a04 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java
@@ -128,7 +128,7 @@ public class EncoderDummycode extends Encoder
 	}
 
 	@Override
-	public void mergeAt(Encoder other, int col) {
+	public void mergeAt(Encoder other, int row, int col) {
 		if(other instanceof EncoderDummycode) {
 			mergeColumnInfo(other, col);
 
@@ -138,7 +138,7 @@ public class EncoderDummycode extends Encoder
 			Arrays.fill(_domainSizes, 0, _colList.length, 1);
 			return;
 		}
-		super.mergeAt(other, col);
+		super.mergeAt(other, row, col);
 	}
 	
 	@Override
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
index 313e5b2..af929be 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
@@ -104,7 +104,7 @@ public class EncoderFactory
 			if( !oIDs.isEmpty() )
 				lencoders.add(new EncoderOmit(jSpec, colnames, schema.length, minCol, maxCol));
 			if( !mvIDs.isEmpty() ) {
-				EncoderMVImpute ma = new EncoderMVImpute(jSpec, colnames, schema.length);
+				EncoderMVImpute ma = new EncoderMVImpute(jSpec, colnames, schema.length, minCol, maxCol);
 				ma.initRecodeIDList(rcIDs);
 				lencoders.add(ma);
 			}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFeatureHash.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFeatureHash.java
index 9317dfb..3b6503b 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFeatureHash.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFeatureHash.java
@@ -110,14 +110,14 @@ public class EncoderFeatureHash extends Encoder
 	}
 	
 	@Override
-	public void mergeAt(Encoder other, int col) {
+	public void mergeAt(Encoder other, int row, int col) {
 		if(other instanceof EncoderFeatureHash) {
 			mergeColumnInfo(other, col);
 			if (((EncoderFeatureHash) other)._K != 0 && _K == 0)
 				_K = ((EncoderFeatureHash) other)._K;
 			return;
 		}
-		super.mergeAt(other, col);
+		super.mergeAt(other, row, col);
 	}
 	
 	@Override
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
index 56749a2..534d16c 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
@@ -19,263 +19,119 @@
 
 package org.apache.sysds.runtime.transform.encode;
 
-import java.io.IOException;
-import java.util.BitSet;
+import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.Map.Entry;
+import java.util.Set;
+import java.util.stream.Collectors;
 
 import org.apache.wink.json4j.JSONArray;
 import org.apache.wink.json4j.JSONException;
 import org.apache.wink.json4j.JSONObject;
-import org.apache.sysds.runtime.functionobjects.CM;
+import org.apache.commons.lang.ArrayUtils;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.functionobjects.KahanPlus;
 import org.apache.sysds.runtime.functionobjects.Mean;
-import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
 import org.apache.sysds.runtime.instructions.cp.KahanObject;
 import org.apache.sysds.runtime.matrix.data.FrameBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import org.apache.sysds.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
-import org.apache.sysds.runtime.transform.TfUtils;
 import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
 import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
+import org.apache.sysds.runtime.util.IndexRange;
 import org.apache.sysds.runtime.util.UtilFunctions;
 
-public class EncoderMVImpute extends Encoder 
+public class EncoderMVImpute extends Encoder
 {
 	private static final long serialVersionUID = 9057868620144662194L;
 
 	public enum MVMethod { INVALID, GLOBAL_MEAN, GLOBAL_MODE, CONSTANT }
 	
 	private MVMethod[] _mvMethodList = null;
-	private MVMethod[] _mvscMethodList = null; // scaling methods for attributes that are imputed and also scaled
-	
-	private BitSet _isMVScaled = null;
-	private CM _varFn = CM.getCMFnObject(AggregateOperationTypes.VARIANCE); 	// function object that understands variance computation
 	
 	// objects required to compute mean and variance of all non-missing entries 
-	private Mean _meanFn = Mean.getMeanFnObject();  // function object that understands mean computation
+	private final Mean _meanFn = Mean.getMeanFnObject();  // function object that understands mean computation
 	private KahanObject[] _meanList = null;         // column-level means, computed so far
 	private long[] _countList = null;               // #of non-missing values
 	
-	private CM_COV_Object[] _varList = null;        // column-level variances, computed so far (for scaling)
-
-	private int[]           _scnomvList = null;        // List of attributes that are scaled but not imputed
-	private MVMethod[]      _scnomvMethodList = null;  // scaling methods: 0 for invalid; 1 for mean-subtraction; 2 for z-scoring
-	private KahanObject[]   _scnomvMeanList = null;    // column-level means, for attributes scaled but not imputed
-	private long[]          _scnomvCountList = null;   // #of non-missing values, for attributes scaled but not imputed
-	private CM_COV_Object[] _scnomvVarList = null;     // column-level variances, computed so far
-	
 	private String[] _replacementList = null; // replacements: for global_mean, mean; and for global_mode, recode id of mode category
-	private String[] _NAstrings = null;
 	private List<Integer> _rcList = null;
 	private HashMap<Integer,HashMap<String,Long>> _hist = null;
 	
 	public String[] getReplacements() { return _replacementList; }
 	public KahanObject[] getMeans()   { return _meanList; }
-	public CM_COV_Object[] getVars()  { return _varList; }
-	public KahanObject[] getMeans_scnomv()   { return _scnomvMeanList; }
-	public CM_COV_Object[] getVars_scnomv()  { return _scnomvVarList; }
 	
-	public EncoderMVImpute(JSONObject parsedSpec, String[] colnames, int clen) 
+	public EncoderMVImpute(JSONObject parsedSpec, String[] colnames, int clen, int minCol, int maxCol)
 		throws JSONException
 	{
 		super(null, clen);
 		
 		//handle column list
-		int[] collist = TfMetaUtils.parseJsonObjectIDList(parsedSpec, colnames, TfMethod.IMPUTE.toString(), -1, -1);
+		int[] collist = TfMetaUtils
+			.parseJsonObjectIDList(parsedSpec, colnames, TfMethod.IMPUTE.toString(), minCol, maxCol);
 		initColList(collist);
 	
 		//handle method list
-		parseMethodsAndReplacments(parsedSpec);
+		parseMethodsAndReplacements(parsedSpec, colnames, minCol);
 		
 		//create reuse histograms
 		_hist = new HashMap<>();
 	}
 	
-	public EncoderMVImpute(JSONObject parsedSpec, String[] colnames, String[] NAstrings, int clen)
-		throws JSONException 
-	{
-		super(null, clen);
-		boolean isMV = parsedSpec.containsKey(TfMethod.IMPUTE.toString());
-		boolean isSC = parsedSpec.containsKey(TfMethod.SCALE.toString());
-		_NAstrings = NAstrings;
-		
-		if(!isMV) {
-			// MV Impute is not applicable
-			_colList = null;
-			_mvMethodList = null;
-			_meanList = null;
-			_countList = null;
-			_replacementList = null;
-		}
-		else {
-			JSONObject mvobj = (JSONObject) parsedSpec.get(TfMethod.IMPUTE.toString());
-			JSONArray mvattrs = (JSONArray) mvobj.get(TfUtils.JSON_ATTRS);
-			JSONArray mvmthds = (JSONArray) mvobj.get(TfUtils.JSON_MTHD);
-			int mvLength = mvattrs.size();
-			
-			_colList = new int[mvLength];
-			_mvMethodList = new MVMethod[mvLength];
-			
-			_meanList = new KahanObject[mvLength];
-			_countList = new long[mvLength];
-			_varList = new CM_COV_Object[mvLength];
-			
-			_isMVScaled = new BitSet(_colList.length);
-			_isMVScaled.clear();
-			
-			for(int i=0; i < _colList.length; i++) {
-				_colList[i] = UtilFunctions.toInt(mvattrs.get(i));
-				_mvMethodList[i] = MVMethod.values()[UtilFunctions.toInt(mvmthds.get(i))]; 
-				_meanList[i] = new KahanObject(0, 0);
-			}
-			
-			_replacementList = new String[mvLength];	// contains replacements for all columns (scale and categorical)
-			
-			JSONArray constants = (JSONArray)mvobj.get(TfUtils.JSON_CONSTS);
-			for(int i=0; i < constants.size(); i++) {
-				if ( constants.get(i) == null )
-					_replacementList[i] = "NaN";
-				else
-					_replacementList[i] = constants.get(i).toString();
-			}
-		}
-		
-		// Handle scaled attributes
-		if ( !isSC )
-		{
-			// scaling is not applicable
-			_scnomvCountList = null;
-			_scnomvMeanList = null;
-			_scnomvVarList = null;
-		}
-		else
-		{
-			if ( _colList != null ) 
-				_mvscMethodList = new MVMethod[_colList.length];
-			
-			JSONObject scobj = (JSONObject) parsedSpec.get(TfMethod.SCALE.toString());
-			JSONArray scattrs = (JSONArray) scobj.get(TfUtils.JSON_ATTRS);
-			JSONArray scmthds = (JSONArray) scobj.get(TfUtils.JSON_MTHD);
-			int scLength = scattrs.size();
-			
-			int[] _allscaled = new int[scLength];
-			int scnomv = 0, colID;
-			byte mthd;
-			for(int i=0; i < scLength; i++)
-			{
-				colID = UtilFunctions.toInt(scattrs.get(i));
-				mthd = (byte) UtilFunctions.toInt(scmthds.get(i));
-				
-				_allscaled[i] = colID;
-				
-				// check if the attribute is also MV imputed
-				int mvidx = isApplicable(colID);
-				if(mvidx != -1)
-				{
-					_isMVScaled.set(mvidx);
-					_mvscMethodList[mvidx] = MVMethod.values()[mthd];
-					_varList[mvidx] = new CM_COV_Object();
-				}
-				else
-					scnomv++;	// count of scaled but not imputed 
-			}
-			
-			if(scnomv > 0)
-			{
-				_scnomvList = new int[scnomv];
-				_scnomvMethodList = new MVMethod[scnomv];
+	public EncoderMVImpute() {
+		super(new int[0], 0);
+	}
 	
-				_scnomvMeanList = new KahanObject[scnomv];
-				_scnomvCountList = new long[scnomv];
-				_scnomvVarList = new CM_COV_Object[scnomv];
-				
-				for(int i=0, idx=0; i < scLength; i++)
-				{
-					colID = UtilFunctions.toInt(scattrs.get(i));
-					mthd = (byte)UtilFunctions.toInt(scmthds.get(i));
-							
-					if(isApplicable(colID) == -1)
-					{	// scaled but not imputed
-						_scnomvList[idx] = colID;
-						_scnomvMethodList[idx] = MVMethod.values()[mthd];
-						_scnomvMeanList[idx] = new KahanObject(0, 0);
-						_scnomvVarList[idx] = new CM_COV_Object();
-						idx++;
-					}
-				}
-			}
-		}
+	
+	public EncoderMVImpute(int[] colList, MVMethod[] mvMethodList, String[] replacementList, KahanObject[] meanList,
+			long[] countList, List<Integer> rcList, int clen) {
+		super(colList, clen);
+		_mvMethodList = mvMethodList;
+		_replacementList = replacementList;
+		_meanList = meanList;
+		_countList = countList;
+		_rcList = rcList;
 	}
-
-	private void parseMethodsAndReplacments(JSONObject parsedSpec) throws JSONException {
+	
+	private void parseMethodsAndReplacements(JSONObject parsedSpec, String[] colnames, int offset) throws JSONException {
 		JSONArray mvspec = (JSONArray) parsedSpec.get(TfMethod.IMPUTE.toString());
+		boolean ids = parsedSpec.containsKey("ids") && parsedSpec.getBoolean("ids");
+		// make space for all elements
 		_mvMethodList = new MVMethod[mvspec.size()];
 		_replacementList = new String[mvspec.size()];
 		_meanList = new KahanObject[mvspec.size()];
 		_countList = new long[mvspec.size()];
-		for(int i=0; i < mvspec.size(); i++) {
-			JSONObject mvobj = (JSONObject)mvspec.get(i);
-			_mvMethodList[i] = MVMethod.valueOf(mvobj.get("method").toString().toUpperCase()); 
-			if( _mvMethodList[i] == MVMethod.CONSTANT ) {
-				_replacementList[i] = mvobj.getString("value").toString();
-			}
-			_meanList[i] = new KahanObject(0, 0);
-		}
-	}
-		
-	public void prepare(String[] words) throws IOException {
+		// sort for binary search
+		Arrays.sort(_colList);
 		
-		try {
-			String w = null;
-			if(_colList != null)
-			for(int i=0; i <_colList.length; i++) {
-				int colID = _colList[i];
-				w = UtilFunctions.unquote(words[colID-1].trim());
-				
-				try {
-				if(!TfUtils.isNA(_NAstrings, w)) {
-					_countList[i]++;
-					
-					boolean computeMean = (_mvMethodList[i] == MVMethod.GLOBAL_MEAN || _isMVScaled.get(i) );
-					if(computeMean) {
-						// global_mean
-						double d = UtilFunctions.parseToDouble(w, UtilFunctions.defaultNaString);
-						_meanFn.execute2(_meanList[i], d, _countList[i]);
-						
-						if (_isMVScaled.get(i) && _mvscMethodList[i] == MVMethod.GLOBAL_MODE)
-							_varFn.execute(_varList[i], d);
-					}
-					else {
-						// global_mode or constant
-						// Nothing to do here. Mode is computed using recode maps.
-					}
-				}
-				} catch (NumberFormatException e) 
-				{
-					throw new RuntimeException("Encountered \"" + w + "\" in column ID \"" + colID + "\", when expecting a numeric value. Consider adding \"" + w + "\" to na.strings, along with an appropriate imputation method.");
+		int listIx = 0;
+		for(Object o : mvspec) {
+			JSONObject mvobj = (JSONObject) o;
+			int ixOffset = offset == -1 ? 0 : offset - 1;
+			// check for position -> -1 if not present
+			int pos = Arrays.binarySearch(_colList,
+				ids ? mvobj.getInt("id") - ixOffset : ArrayUtils.indexOf(colnames, mvobj.get("name")) + 1);
+			if(pos >= 0) {
+				// add to arrays
+				_mvMethodList[listIx] = MVMethod.valueOf(mvobj.get("method").toString().toUpperCase());
+				if(_mvMethodList[listIx] == MVMethod.CONSTANT) {
+					_replacementList[listIx] = mvobj.getString("value");
 				}
+				_meanList[listIx++] = new KahanObject(0, 0);
 			}
-			
-			// Compute mean and variance for attributes that are scaled but not imputed
-			if(_scnomvList != null)
-			for(int i=0; i < _scnomvList.length; i++) 
-			{
-				int colID = _scnomvList[i];
-				w = UtilFunctions.unquote(words[colID-1].trim());
-				double d = UtilFunctions.parseToDouble(w, UtilFunctions.defaultNaString);
-				_scnomvCountList[i]++; 		// not required, this is always equal to total #records processed
-				_meanFn.execute2(_scnomvMeanList[i], d, _scnomvCountList[i]);
-				if(_scnomvMethodList[i] == MVMethod.GLOBAL_MODE)
-					_varFn.execute(_scnomvVarList[i], d);
-			}
-		} catch(Exception e) {
-			throw new IOException(e);
 		}
+		// make arrays required size
+		_mvMethodList = Arrays.copyOf(_mvMethodList, listIx);
+		_replacementList = Arrays.copyOf(_replacementList, listIx);
+		_meanList = Arrays.copyOf(_meanList, listIx);
+		_countList = Arrays.copyOf(_countList, listIx);
 	}
 	
 	public MVMethod getMethod(int colID) {
-		int idx = isApplicable(colID);		
+		int idx = isApplicable(colID);
 		if(idx == -1)
 			return MVMethod.INVALID;
 		else
@@ -287,8 +143,8 @@ public class EncoderMVImpute extends Encoder
 		return (idx == -1) ? 0 : _countList[idx];
 	}
 	
-	public String getReplacement(int colID)  {
-		int idx = isApplicable(colID);		
+	public String getReplacement(int colID) {
+		int idx = isApplicable(colID);
 		return (idx == -1) ? null : _replacementList[idx];
 	}
 	
@@ -321,7 +177,7 @@ public class EncoderMVImpute extends Encoder
 						if( key != null && !key.isEmpty() ) {
 							Long val = hist.get(key);
 							hist.put(key, (val!=null) ? val+1 : 1);
-						}	
+						}
 					}
 					_hist.put(colID, hist);
 					long max = Long.MIN_VALUE; 
@@ -349,12 +205,98 @@ public class EncoderMVImpute extends Encoder
 		}
 		return out;
 	}
+
+	@Override
+	public Encoder subRangeEncoder(IndexRange ixRange) {
+		Map<Integer, ColInfo> map = new HashMap<>();
+		for(int i = 0; i < _colList.length; i++) {
+			int col = _colList[i];
+			if(ixRange.inColRange(col))
+				map.put((int) (_colList[i] - (ixRange.colStart - 1)),
+					new ColInfo(_mvMethodList[i], _replacementList[i], _meanList[i], _countList[i], _hist.get(i)));
+		}
+		if(map.size() == 0)
+			// empty encoder -> sub range encoder does not exist
+			return null;
+
+		int[] colList = new int[map.size()];
+		MVMethod[] mvMethodList = new MVMethod[map.size()];
+		String[] replacementList = new String[map.size()];
+		KahanObject[] meanList = new KahanObject[map.size()];
+		long[] countList = new long[map.size()];
+
+		fillListsFromMap(map, colList, mvMethodList, replacementList, meanList, countList, _hist);
+
+		if(_rcList == null)
+			_rcList = new ArrayList<>();
+		List<Integer> rcList = _rcList.stream().filter(ixRange::inColRange).map(i -> (int) (i - (ixRange.colStart - 1)))
+			.collect(Collectors.toList());
+
+		return new EncoderMVImpute(colList, mvMethodList, replacementList, meanList, countList, rcList,
+			(int) ixRange.colSpan());
+	}
+
+	private static void fillListsFromMap(Map<Integer, ColInfo> map, int[] colList, MVMethod[] mvMethodList,
+		String[] replacementList, KahanObject[] meanList, long[] countList,
+		HashMap<Integer, HashMap<String, Long>> hist) {
+		int i = 0;
+		for(Entry<Integer, ColInfo> entry : map.entrySet()) {
+			colList[i] = entry.getKey();
+			mvMethodList[i] = entry.getValue()._method;
+			replacementList[i] = entry.getValue()._replacement;
+			meanList[i] = entry.getValue()._mean;
+			countList[i++] = entry.getValue()._count;
+
+			hist.put(entry.getKey(), entry.getValue()._hist);
+		}
+	}
+
+	@Override
+	public void mergeAt(Encoder other, int row, int col) {
+		if(other instanceof EncoderMVImpute) {
+			EncoderMVImpute otherImpute = (EncoderMVImpute) other;
+			Map<Integer, ColInfo> map = new HashMap<>();
+			for(int i = 0; i < _colList.length; i++) {
+				map.put(_colList[i],
+					new ColInfo(_mvMethodList[i], _replacementList[i], _meanList[i], _countList[i], _hist.get(i + 1)));
+			}
+			for(int i = 0; i < other._colList.length; i++) {
+				int column = other._colList[i] + (col - 1);
+				ColInfo otherColInfo = new ColInfo(otherImpute._mvMethodList[i], otherImpute._replacementList[i],
+					otherImpute._meanList[i], otherImpute._countList[i], otherImpute._hist.get(i + 1));
+				ColInfo colInfo = map.get(column);
+				if(colInfo == null)
+					map.put(column, otherColInfo);
+				else
+					colInfo.merge(otherColInfo);
+			}
+
+			_colList = new int[map.size()];
+			_mvMethodList = new MVMethod[map.size()];
+			_replacementList = new String[map.size()];
+			_meanList = new KahanObject[map.size()];
+			_countList = new long[map.size()];
+			_hist = new HashMap<>();
+
+			fillListsFromMap(map, _colList, _mvMethodList, _replacementList, _meanList, _countList, _hist);
+			// update number of columns
+			_clen = Math.max(_clen, col - 1 + other._clen);
+
+			if(_rcList == null)
+				_rcList = new ArrayList<>();
+			Set<Integer> rcSet = new HashSet<>(_rcList);
+			rcSet.addAll(otherImpute._rcList.stream().map(i -> i + (col - 1)).collect(Collectors.toSet()));
+			_rcList = new ArrayList<>(rcSet);
+			return;
+		}
+		super.mergeAt(other, row, col);
+	}
 	
 	@Override
 	public FrameBlock getMetaData(FrameBlock out) {
 		for( int j=0; j<_colList.length; j++ ) {
 			out.getColumnMetadata(_colList[j]-1)
-			   .setMvValue(_replacementList[j]);
+				.setMvValue(_replacementList[j]);
 		}
 		return out;
 	}
@@ -391,4 +333,48 @@ public class EncoderMVImpute extends Encoder
 	public HashMap<String,Long> getHistogram( int colID ) {
 		return _hist.get(colID);
 	}
+	
+	private static class ColInfo {
+		MVMethod _method;
+		String _replacement;
+		KahanObject _mean;
+		long _count;
+		HashMap<String, Long> _hist;
+
+		ColInfo(MVMethod method, String replacement, KahanObject mean, long count, HashMap<String, Long> hist) {
+			_method = method;
+			_replacement = replacement;
+			_mean = mean;
+			_count = count;
+			_hist = hist;
+		}
+
+		public void merge(ColInfo otherColInfo) {
+			if(_method != otherColInfo._method)
+				throw new DMLRuntimeException("Tried to merge two different impute methods: " + _method.name() + " vs. "
+					+ otherColInfo._method.name());
+			switch(_method) {
+				case CONSTANT:
+					assert _replacement.equals(otherColInfo._replacement);
+					break;
+				case GLOBAL_MEAN:
+					_mean._sum *= _count;
+					_mean._correction *= _count;
+					KahanPlus.getKahanPlusFnObject().execute(_mean, otherColInfo._mean._sum * otherColInfo._count);
+					KahanPlus.getKahanPlusFnObject().execute(_mean,
+						otherColInfo._mean._correction * otherColInfo._count);
+					_count += otherColInfo._count;
+					break;
+				case GLOBAL_MODE:
+					if (_hist == null)
+						_hist = new HashMap<>(otherColInfo._hist);
+					else
+						// add counts
+						_hist.replaceAll((key, count) -> count + otherColInfo._hist.getOrDefault(key, 0L));
+					break;
+				default:
+					throw new DMLRuntimeException("Method `" + _method.name() + "` not supported for federated impute");
+			}
+		}
+	}
 }
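
In the GLOBAL_MEAN branch of ColInfo.merge above, each partial mean is rescaled by its
non-missing count before the Kahan-compensated addition, so two workers' row partitions
can be combined from their (mean, count) pairs alone. A standalone sketch of the
underlying count-weighted merge, in plain doubles rather than SystemDS's
KahanObject/KahanPlus pair (method name hypothetical; the committed code keeps the
compensated weighted sum and the combined count rather than renormalizing in place):

    // Count-weighted merge of two partial means for global_mean imputation:
    // rescale each mean back to a sum, add, and normalize by the combined count.
    static double mergeMeans(double mean1, long count1, double mean2, long count2) {
        return (mean1 * count1 + mean2 * count2) / (count1 + count2);
    }

For example, mergeMeans(2.0, 3, 4.0, 1) = (6.0 + 4.0) / 4 = 2.5.
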
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderOmit.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderOmit.java
index 26ba4e4..bbc83e4 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderOmit.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderOmit.java
@@ -19,9 +19,7 @@
 
 package org.apache.sysds.runtime.transform.encode;
 
-import java.util.TreeSet;
-import java.util.stream.Collectors;
-
+import java.util.Arrays;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.matrix.data.FrameBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -38,8 +36,7 @@ public class EncoderOmit extends Encoder
 	private static final long serialVersionUID = 1978852120416654195L;
 
 	private boolean _federated = false;
-	//TODO perf replace with boolean[rlen] similar to removeEmpty
-	private TreeSet<Integer> _rmRows = new TreeSet<>();
+	private boolean[] _rmRows = new boolean[0];
 
 	public EncoderOmit(JSONObject parsedSpec, String[] colnames, int clen, int minCol, int maxCol)
 		throws JSONException 
@@ -61,19 +58,24 @@ public class EncoderOmit extends Encoder
 		_federated = federated;
 	}
 	
-	
-	private EncoderOmit(int[] colList, int clen, TreeSet<Integer> rmRows) {
+	private EncoderOmit(int[] colList, int clen, boolean[] rmRows) {
 		super(colList, clen);
 		_rmRows = rmRows;
 		_federated = true;
 	}
-	
+
+	public int getNumRemovedRows(boolean[] rmRows) {
+		int cnt = 0;
+		for(boolean v : rmRows)
+			cnt += v ? 1 : 0;
+		return cnt;
+	}
+
 	public int getNumRemovedRows() {
-		return _rmRows.size();
+		return getNumRemovedRows(_rmRows);
 	}
 	
-	public boolean omit(String[] words, TfUtils agents) 
-	{
+	public boolean omit(String[] words, TfUtils agents) {
 		if( !isApplicable() )
 			return false;
 		
@@ -99,21 +101,21 @@ public class EncoderOmit extends Encoder
 	@Override
 	public MatrixBlock apply(FrameBlock in, MatrixBlock out) {
 		// local rmRows for broadcasting encoder in spark
-		TreeSet<Integer> rmRows;
+		boolean[] rmRows;
 		if(_federated)
 			rmRows = _rmRows;
 		else
 			rmRows = computeRmRows(in);
 
 		// determine output size
-		int numRows = out.getNumRows() - rmRows.size();
+		int numRows = out.getNumRows() - getNumRemovedRows(rmRows);
 
 		// copy over valid rows into the output
 		MatrixBlock ret = new MatrixBlock(numRows, out.getNumColumns(), false);
 		int pos = 0;
 		for(int i = 0; i < in.getNumRows(); i++) {
 			// copy row if necessary
-			if(!rmRows.contains(i)) {
+			if(!rmRows[i]) {
 				for(int j = 0; j < out.getNumColumns(); j++)
 					ret.quickSetValue(pos, j, out.quickGetValue(i, j));
 				pos++;
@@ -125,17 +127,19 @@ public class EncoderOmit extends Encoder
 		return ret;
 	}
 
-	private TreeSet<Integer> computeRmRows(FrameBlock in) {
-		TreeSet<Integer> rmRows = new TreeSet<>();
+	private boolean[] computeRmRows(FrameBlock in) {
+		boolean[] rmRows = new boolean[in.getNumRows()];
 		ValueType[] schema = in.getSchema();
+		//TODO perf evaluate if column-wise scan more efficient
+		//  (sequential but less impact of early abort)
 		for(int i = 0; i < in.getNumRows(); i++) {
-			boolean valid = true;
 			for(int colID : _colList) {
 				Object val = in.get(i, colID - 1);
-				valid &= !(val == null || (schema[colID - 1] == ValueType.STRING && val.toString().isEmpty()));
+				if (val == null || (schema[colID - 1] == ValueType.STRING && val.toString().isEmpty())) {
+					rmRows[i] = true;
+					break; // early abort
+				}
 			}
-			if(!valid)
-				rmRows.add(i);
 		}
 		return rmRows;
 	}
@@ -146,38 +150,38 @@ public class EncoderOmit extends Encoder
 		if(colList.length == 0)
 			// empty encoder -> sub range encoder does not exist
 			return null;
-
-		TreeSet<Integer> rmRows = _rmRows.stream().filter((row) -> ixRange.inRowRange(row + 1))
-			.map((row) -> (int) (row - (ixRange.rowStart - 1))).collect(Collectors.toCollection(TreeSet::new));
+		boolean[] rmRows = _rmRows;
+		if (_rmRows.length > 0)
+			rmRows = Arrays.copyOfRange(rmRows, (int) ixRange.rowStart - 1, (int) ixRange.rowEnd - 1);
 
 		return new EncoderOmit(colList, (int) (ixRange.colSpan()), rmRows);
 	}
 
 	@Override
-	public void mergeAt(Encoder other, int col) {
+	public void mergeAt(Encoder other, int row, int col) {
 		if(other instanceof EncoderOmit) {
 			mergeColumnInfo(other, col);
-			_rmRows.addAll(((EncoderOmit) other)._rmRows);
+			EncoderOmit otherOmit = (EncoderOmit) other;
+			_rmRows = Arrays.copyOf(_rmRows, Math.max(_rmRows.length, (row - 1) + otherOmit._rmRows.length));
+			for (int i = 0; i < otherOmit._rmRows.length; i++)
+				_rmRows[(row - 1) + i] |= otherOmit._rmRows[i];
                        return;
                }
-               super.mergeAt(other, col);
+               super.mergeAt(other, row, col);
        }
        
        @Override
        public void updateIndexRanges(long[] beginDims, long[] endDims) {
                // first update begin dims
                int numRowsToRemove = 0;
-               Integer removedRow = _rmRows.ceiling(0);
-               while(removedRow != null && removedRow < beginDims[0]) {
-                       numRowsToRemove++;
-                       removedRow = _rmRows.ceiling(removedRow + 1);
-               }
+               for (int i = 0; i < beginDims[0] - 1 && i < _rmRows.length; i++)
+                       if (_rmRows[i])
+                               numRowsToRemove++;
                beginDims[0] -= numRowsToRemove;
                // update end dims
-               while(removedRow != null && removedRow < endDims[0]) {
-                       numRowsToRemove++;
-                       removedRow = _rmRows.ceiling(removedRow + 1);
-               }
+               for (int i = 0; i < endDims[0] - 1 && i < _rmRows.length; i++)
+                       if (_rmRows[i])
+                               numRowsToRemove++;
                endDims[0] -= numRowsToRemove;
        }
        
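With _rmRows now a dense boolean[] instead of a TreeSet, merging the omit flags of a
row partition reduces to growing the array and OR-ing at the partition's row offset,
mirroring the loop in EncoderOmit.mergeAt above. A small self-contained illustration
(hypothetical two-worker split; assumes java.util.Arrays):

    // Worker 1 covers rows 1-4 (1-based), worker 2 covers rows 5-8.
    boolean[] global = new boolean[0];
    boolean[] w1 = {false, true, false, false}; // row 2 has a missing value
    boolean[] w2 = {true, false, false, false}; // row 5 has a missing value
    // merge w1 at row=1 (offset 0), then w2 at row=5 (offset 4)
    global = Arrays.copyOf(global, Math.max(global.length, 0 + w1.length));
    for(int i = 0; i < w1.length; i++)
        global[0 + i] |= w1[i];
    global = Arrays.copyOf(global, Math.max(global.length, 4 + w2.length));
    for(int i = 0; i < w2.length; i++)
        global[4 + i] |= w2[i];
    // global == {false, true, false, false, true, false, false, false}
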
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java
index ccd235d..ac414e9 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java
@@ -88,12 +88,12 @@ public class EncoderPassThrough extends Encoder
 	}
 	
 	@Override
-	public void mergeAt(Encoder other, int col) {
+	public void mergeAt(Encoder other, int row, int col) {
 		if(other instanceof EncoderPassThrough) {
 			mergeColumnInfo(other, col);
 			return;
 		}
-		super.mergeAt(other, col);
+		super.mergeAt(other, row, col);
 	}
 
 	@Override
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
index d8d524a..6a1ea0b 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
@@ -189,7 +189,7 @@ public class EncoderRecode extends Encoder
 	}
 
 	@Override
-	public void mergeAt(Encoder other, int col) {
+	public void mergeAt(Encoder other, int row, int col) {
 		if(other instanceof EncoderRecode) {
 			mergeColumnInfo(other, col);
 			
@@ -214,7 +214,7 @@ public class EncoderRecode extends Encoder
 			}
 			return;
 		}
-		super.mergeAt(other, col);
+		super.mergeAt(other, row, col);
 	}
 	
 	public int[] numDistinctValues() {
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
index 622e6e0..3aa0981 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysds.test.functions.federated.transform;
 
+import java.io.IOException;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.common.Types.FileFormat;
@@ -42,8 +43,8 @@ import org.junit.Test;
 public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
 	private final static String TEST_NAME1 = "TransformFederatedEncodeApply";
 	private final static String TEST_DIR = "functions/transform/";
-	private final static String TEST_CLASS_DIR = TEST_DIR + TransformFederatedEncodeApplyTest.class.getSimpleName()
-		+ "/";
+	private final static String TEST_CLASS_DIR = TEST_DIR
+		+ TransformFederatedEncodeApplyTest.class.getSimpleName() + "/";
 
 	// dataset and transform tasks without missing values
 	private final static String DATASET1 = "homes3/homes.csv";
@@ -64,8 +65,8 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
 
 	// dataset and transform tasks with missing values
 	private final static String DATASET2 = "homes/homes.csv";
-	// private final static String SPEC4 = "homes3/homes.tfspec_impute.json";
-	// private final static String SPEC4b = "homes3/homes.tfspec_impute2.json";
+	private final static String SPEC4 = "homes3/homes.tfspec_impute.json";
+	private final static String SPEC4b = "homes3/homes.tfspec_impute2.json";
 	private final static String SPEC5 = "homes3/homes.tfspec_omit.json";
 	private final static String SPEC5b = "homes3/homes.tfspec_omit2.json";
 
@@ -73,11 +74,7 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
 	private static final int[] BIN_col8 = new int[] {1, 2, 2, 2, 2, 2, 3};
 
 	public enum TransformType {
-		RECODE, DUMMY, RECODE_DUMMY, BIN, BIN_DUMMY,
-		// IMPUTE,
-		OMIT,
-		HASH,
-		HASH_RECODE,
+		RECODE, DUMMY, RECODE_DUMMY, BIN, BIN_DUMMY, IMPUTE, OMIT, HASH, HASH_RECODE,
 	}
 
 	@Override
@@ -116,10 +113,10 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
 		runTransformTest(TransformType.OMIT, false);
 	}
 
-	// @Test
-	// public void testHomesImputeIDsCSV() {
-	// runTransformTest(TransformType.IMPUTE, false);
-	// }
+	@Test
+	public void testHomesImputeIDsCSV() {
+		runTransformTest(TransformType.IMPUTE, false);
+	}
 
 	@Test
 	public void testHomesRecodeColnamesCSV() {
@@ -151,10 +148,10 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
 		runTransformTest(TransformType.OMIT, true);
 	}
 
-	// @Test
-	// public void testHomesImputeColnamesCSV() {
-	// runTransformTest(TransformType.IMPUTE, true);
-	// }
+	@Test
+	public void testHomesImputeColnamesCSV() {
+		runTransformTest(TransformType.IMPUTE, true);
+	}
 
 	@Test
 	public void testHomesHashColnamesCSV() {
@@ -186,7 +183,7 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
 			case RECODE: SPEC = colnames ? SPEC1b : SPEC1; DATASET = DATASET1; break;
 			case DUMMY: SPEC = colnames ? SPEC2b : SPEC2; DATASET = DATASET1; break;
 			case BIN: SPEC = colnames ? SPEC3b : SPEC3; DATASET = DATASET1; break;
-			// case IMPUTE: SPEC = colnames ? SPEC4b : SPEC4; DATASET = DATASET2; break;
+			case IMPUTE: SPEC = colnames ? SPEC4b : SPEC4; DATASET = DATASET2; break;
 			case OMIT: SPEC = colnames ? SPEC5b : SPEC5; DATASET = DATASET2; break;
 			case RECODE_DUMMY: SPEC = colnames ? SPEC6b : SPEC6; DATASET = DATASET1; break;
 			case BIN_DUMMY: SPEC = colnames ? SPEC7b : SPEC7; DATASET = DATASET1; break;
@@ -194,7 +191,7 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
 			case HASH_RECODE: SPEC = colnames ? SPEC9b : SPEC9; DATASET = DATASET1; break;
 		}
 
-		Thread t1 = null, t2 = null;
+		Thread t1 = null, t2 = null, t3 = null, t4 = null;
 		try {
 			getAndLoadTestConfiguration(TEST_NAME1);
 
@@ -202,11 +199,14 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
 			t1 = startLocalFedWorkerThread(port1);
 			int port2 = getRandomAvailablePort();
 			t2 = startLocalFedWorkerThread(port2);
+			int port3 = getRandomAvailablePort();
+			t3 = startLocalFedWorkerThread(port3);
+			int port4 = getRandomAvailablePort();
+			t4 = startLocalFedWorkerThread(port4);
 
 			FileFormatPropertiesCSV ffpCSV = new FileFormatPropertiesCSV(true, DataExpression.DEFAULT_DELIM_DELIMITER,
-				DataExpression.DEFAULT_DELIM_FILL, DataExpression.DEFAULT_DELIM_FILL_VALUE,
-				DATASET.equals(DATASET1) ? DataExpression.DEFAULT_NA_STRINGS : "NA" + DataExpression.DELIM_NA_STRING_SEP
-					+ "");
+				DataExpression.DEFAULT_DELIM_FILL, DataExpression.DEFAULT_DELIM_FILL_VALUE, DATASET.equals(DATASET1) ?
+				DataExpression.DEFAULT_NA_STRINGS : "NA" + DataExpression.DELIM_NA_STRING_SEP + "");
 			String HOME = SCRIPT_DIR + TEST_DIR;
 			// split up dataset
 			FrameBlock dataset = FrameReaderFactory.createFrameReader(FileFormat.CSV, ffpCSV)
@@ -216,23 +216,37 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
 			ffpCSV.setNAStrings(UtilFunctions.defaultNaString);
 			FrameWriter fw = FrameWriterFactory.createFrameWriter(FileFormat.CSV, ffpCSV);
 
-			FrameBlock A = new FrameBlock();
-			dataset.slice(0, dataset.getNumRows() - 1, 0, dataset.getNumColumns() / 2 - 1, A);
-			fw.writeFrameToHDFS(A, input("A"), A.getNumRows(), A.getNumColumns());
-			HDFSTool.writeMetaDataFile(input("A.mtd"), null, A.getSchema(), Types.DataType.FRAME,
-				new MatrixCharacteristics(A.getNumRows(), A.getNumColumns()), FileFormat.CSV, ffpCSV);
-
-			FrameBlock B = new FrameBlock();
-			dataset.slice(0, dataset.getNumRows() - 1, dataset.getNumColumns() / 2, dataset.getNumColumns() - 1, B);
-			fw.writeFrameToHDFS(B, input("B"), B.getNumRows(), B.getNumColumns());
-			HDFSTool.writeMetaDataFile(input("B.mtd"), null, B.getSchema(), Types.DataType.FRAME,
-				new MatrixCharacteristics(B.getNumRows(), B.getNumColumns()), FileFormat.CSV, ffpCSV);
+			writeDatasetSlice(dataset, fw, ffpCSV, "AH",
+				0,
+				dataset.getNumRows() / 2 - 1,
+				0,
+				dataset.getNumColumns() / 2 - 1);
+
+			writeDatasetSlice(dataset, fw, ffpCSV, "AL",
+				dataset.getNumRows() / 2,
+				dataset.getNumRows() - 1,
+				0,
+				dataset.getNumColumns() / 2 - 1);
+
+			writeDatasetSlice(dataset, fw, ffpCSV, "BH",
+				0,
+				dataset.getNumRows() / 2 - 1,
+				dataset.getNumColumns() / 2,
+				dataset.getNumColumns() - 1);
+
+			writeDatasetSlice(dataset, fw, ffpCSV, "BL",
+				dataset.getNumRows() / 2,
+				dataset.getNumRows() - 1,
+				dataset.getNumColumns() / 2,
+				dataset.getNumColumns() - 1);
 
 			fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
-			programArgs = new String[] {"-nvargs", "in_A=" + TestUtils.federatedAddress(port1, input("A")),
-				"in_B=" + TestUtils.federatedAddress(port2, input("B")), "rows=" + dataset.getNumRows(),
-				"cols_A=" + A.getNumColumns(), "cols_B=" + B.getNumColumns(), "TFSPEC=" + HOME + "input/" + SPEC,
-				"TFDATA1=" + output("tfout1"), "TFDATA2=" + output("tfout2"), "OFMT=csv"};
+			programArgs = new String[] {"-nvargs", "in_AH=" + TestUtils.federatedAddress(port1, input("AH")),
+				"in_AL=" + TestUtils.federatedAddress(port2, input("AL")),
+				"in_BH=" + TestUtils.federatedAddress(port3, input("BH")),
+				"in_BL=" + TestUtils.federatedAddress(port4, input("BL")), "rows=" + dataset.getNumRows(),
+				"cols=" + dataset.getNumColumns(), "TFSPEC=" + HOME + "input/" + SPEC, "TFDATA1=" + output("tfout1"),
+				"TFDATA2=" + output("tfout2"), "OFMT=csv"};
 
 			runTest(true, false, null, -1);
 
@@ -266,8 +280,18 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
 			throw new RuntimeException(ex);
 		}
 		finally {
-			TestUtils.shutdownThreads(t1, t2);
+			TestUtils.shutdownThreads(t1, t2, t3, t4);
 			resetExecMode(rtold);
 		}
 	}
+
+	private void writeDatasetSlice(FrameBlock dataset, FrameWriter fw, FileFormatPropertiesCSV ffpCSV, String name,
+		int rl, int ru, int cl, int cu) throws IOException {
+		FrameBlock AH = new FrameBlock();
+		dataset.slice(rl, ru, cl, cu, AH);
+		fw.writeFrameToHDFS(AH, input(name), AH.getNumRows(), AH.getNumColumns());
+		HDFSTool.writeMetaDataFile(input(DataExpression.getMTDFileName(name)), null, AH.getSchema(),
+			Types.DataType.FRAME, new MatrixCharacteristics(AH.getNumRows(), AH.getNumColumns()),
+			FileFormat.CSV, ffpCSV);
+	}
 }
diff --git a/src/test/scripts/functions/transform/TransformFederatedEncodeApply.dml b/src/test/scripts/functions/transform/TransformFederatedEncodeApply.dml
index 921242b..28cdcda 100644
--- a/src/test/scripts/functions/transform/TransformFederatedEncodeApply.dml
+++ b/src/test/scripts/functions/transform/TransformFederatedEncodeApply.dml
@@ -19,9 +19,11 @@
 #
 #-------------------------------------------------------------
 
-F1 = federated(type="frame", addresses=list($in_A, $in_B), ranges=
-    list(list(0,0), list($rows, $cols_A), # A range
-    list(0, $cols_A), list($rows, $cols_A + $cols_B))); # B range
+F1 = federated(type="frame", addresses=list($in_AH, $in_AL, $in_BH, $in_BL), ranges=list(
+    list(0,0), list($rows / 2, $cols / 2), # AH range
+    list($rows / 2,0), list($rows, $cols / 2), # AL range
+    list(0,$cols / 2), list($rows / 2, $cols), # BH range
+    list($rows / 2,$cols / 2), list($rows, $cols))); # BL range
 
 jspec = read($TFSPEC, data_type="scalar", value_type="string");
 
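The four ranges in the updated script tile the frame into a 2x2 grid of quadrants,
one per federated worker; the tiling only works out if begins are inclusive and ends
exclusive (0-based), since adjacent quadrants share the rows/2 and cols/2 boundaries.
A quick sanity check of the quadrant arithmetic, with hypothetical dimensions:

    // hypothetical example: rows=100, cols=8; {begin, end} pairs per quadrant
    long rows = 100, cols = 8;
    long[][] ah = {{0, 0}, {rows / 2, cols / 2}};          // rows 0-49,  cols 0-3
    long[][] al = {{rows / 2, 0}, {rows, cols / 2}};       // rows 50-99, cols 0-3
    long[][] bh = {{0, cols / 2}, {rows / 2, cols}};       // rows 0-49,  cols 4-7
    long[][] bl = {{rows / 2, cols / 2}, {rows, cols}};    // rows 50-99, cols 4-7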
