This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 641949da67 [SYSTEMDS-3374] Federation primitive for local to federated data
641949da67 is described below

commit 641949da67a2abfdbbdab0164359f9b6e387622a
Author: OlgaOvcharenko <[email protected]>
AuthorDate: Sun Jun 5 17:11:55 2022 +0200

    [SYSTEMDS-3374] Federation primitive for local to federated data
    
    Closes #1609.
---
 src/main/java/org/apache/sysds/lops/Federated.java |  27 ++-
 .../org/apache/sysds/parser/DataExpression.java    |  24 ++-
 .../controlprogram/federated/FederatedData.java    |  14 ++
 .../federated/FederatedWorkerHandler.java          |  35 ++--
 .../instructions/fed/InitFEDInstruction.java       | 212 +++++++++++++++++++--
 .../sysds/runtime/matrix/data/MatrixBlock.java     |   9 +
 src/test/java/org/apache/sysds/test/TestUtils.java |   5 +-
 .../primitives/FederatedTransferLocalDataTest.java | 125 ++++++++++++
 .../federated/FederatedTransferLocalDataTest.dml   |  34 ++++
 .../FederatedTransferLocalDataTestReference.dml    |  23 +++
 10 files changed, 467 insertions(+), 41 deletions(-)

diff --git a/src/main/java/org/apache/sysds/lops/Federated.java 
b/src/main/java/org/apache/sysds/lops/Federated.java
index 52b52be544..2ed1de2fdb 100644
--- a/src/main/java/org/apache/sysds/lops/Federated.java
+++ b/src/main/java/org/apache/sysds/lops/Federated.java
@@ -25,11 +25,12 @@ import java.util.HashMap;
 import static org.apache.sysds.common.Types.DataType;
 import static org.apache.sysds.common.Types.ValueType;
 import static org.apache.sysds.parser.DataExpression.FED_ADDRESSES;
+import static org.apache.sysds.parser.DataExpression.FED_LOCAL_OBJECT;
 import static org.apache.sysds.parser.DataExpression.FED_RANGES;
 import static org.apache.sysds.parser.DataExpression.FED_TYPE;
 
 public class Federated extends Lop {
-       private Lop _type, _addresses, _ranges;
+       private Lop _type, _addresses, _ranges, _localObject;
        
        public Federated(HashMap<String, Lop> inputLops, DataType dataType, 
ValueType valueType) {
                super(Type.Federated, dataType, valueType);
@@ -43,6 +44,12 @@ public class Federated extends Lop {
                _addresses.addOutput(this);
                addInput(_ranges);
                _ranges.addOutput(this);
+
+               if(inputLops.size() == 4) {
+                       _localObject = inputLops.get(FED_LOCAL_OBJECT);
+                       addInput(_localObject);
+                       _localObject.addOutput(this);
+               }
        }
        
        @Override
@@ -60,6 +67,24 @@ public class Federated extends Lop {
                sb.append(prepOutputOperand(output));
                return sb.toString();
        }
+
+       @Override
+       public String getInstructions(String type, String addresses, String 
ranges, String object, String output) {
+               StringBuilder sb = new StringBuilder("FED");
+               sb.append(OPERAND_DELIMITOR);
+               sb.append("fedinit");
+               sb.append(OPERAND_DELIMITOR);
+               sb.append(_type.prepScalarInputOperand(type));
+               sb.append(OPERAND_DELIMITOR);
+               sb.append(_addresses.prepScalarInputOperand(addresses));
+               sb.append(OPERAND_DELIMITOR);
+               sb.append(_ranges.prepScalarInputOperand(ranges));
+               sb.append(OPERAND_DELIMITOR);
+               sb.append(_localObject.prepScalarInputOperand(object));
+               sb.append(OPERAND_DELIMITOR);
+               sb.append(prepOutputOperand(output));
+               return sb.toString();
+       }
        
        @Override
        public String toString() {
diff --git a/src/main/java/org/apache/sysds/parser/DataExpression.java 
b/src/main/java/org/apache/sysds/parser/DataExpression.java
index 2f25809762..e2e3996cea 100644
--- a/src/main/java/org/apache/sysds/parser/DataExpression.java
+++ b/src/main/java/org/apache/sysds/parser/DataExpression.java
@@ -87,6 +87,7 @@ public class DataExpression extends DataIdentifier
        public static final String FED_ADDRESSES = "addresses";
        public static final String FED_RANGES = "ranges";
        public static final String FED_TYPE = "type";
+       public static final String FED_LOCAL_OBJECT = "local_matrix";
        
        public static final String FORMAT_TYPE = "format";
        
@@ -132,7 +133,7 @@ public class DataExpression extends DataIdentifier
                Arrays.asList(SQL_CONN, SQL_USER, SQL_PASS, SQL_QUERY));
        
        public static final Set<String> FEDERATED_VALID_PARAM_NAMES = new 
HashSet<>(
-               Arrays.asList(FED_ADDRESSES, FED_RANGES, FED_TYPE));
+               Arrays.asList(FED_ADDRESSES, FED_RANGES, FED_TYPE, 
FED_LOCAL_OBJECT));
 
        /** Valid parameter names in metadata file */
        public static final Set<String> READ_VALID_MTD_PARAM_NAMES =new 
HashSet<>(
@@ -540,6 +541,16 @@ public class DataExpression extends DataIdentifier
                                param = passedParamExprs.get(2);
                                
dataExpr.addFederatedExprParam(DataExpression.FED_TYPE, param.getExpr());
                        }
+                       else if(unnamedParamCount == 4) {
+                               ParameterExpression param = 
passedParamExprs.get(0);
+                               
dataExpr.addFederatedExprParam(DataExpression.FED_LOCAL_OBJECT, 
param.getExpr());
+                               param = passedParamExprs.get(1);
+                               
dataExpr.addFederatedExprParam(DataExpression.FED_ADDRESSES, param.getExpr());
+                               param = passedParamExprs.get(2);
+                               
dataExpr.addFederatedExprParam(DataExpression.FED_RANGES, param.getExpr());
+                               param = passedParamExprs.get(3);
+                               
dataExpr.addFederatedExprParam(DataExpression.FED_TYPE, param.getExpr());
+                       }
                        else {
                                errorListener.validationError(parseInfo,
                                        "for federated statement, at most 3 
arguments are supported: addresses, ranges, type");
@@ -888,7 +899,7 @@ public class DataExpression extends DataIdentifier
                                raiseValidateError("UDF function call not 
supported as parameter to built-in function call", 
false,LanguageErrorCodes.INVALID_PARAMETERS);
                        }
                        inputParamExpr.validateExpression(ids, currConstVars, 
conditional);
-                       if (s != null && !s.equals(RAND_DATA) && 
!s.equals(RAND_DIMS) && !s.equals(FED_ADDRESSES) && !s.equals(FED_RANGES)
+                       if (s != null && !s.equals(RAND_DATA) && 
!s.equals(RAND_DIMS) && !s.equals(FED_ADDRESSES) && !s.equals(FED_RANGES) && 
!s.equals(FED_LOCAL_OBJECT)
                                        && !s.equals(DELIM_NA_STRINGS) && 
!s.equals(SCHEMAPARAM) && getVarParam(s).getOutput().getDataType() != 
DataType.SCALAR ) {
                                raiseValidateError("Non-scalar data types are 
not supported for data expression.", 
conditional,LanguageErrorCodes.INVALID_PARAMETERS);
                        }
@@ -2195,7 +2206,16 @@ public class DataExpression extends DataIdentifier
                        else 
if(fedType.getValue().equalsIgnoreCase(FED_FRAME_IDENTIFIER)) {
                                getOutput().setDataType(DataType.FRAME);
                        }
+
+                       if(_varParams.size() == 4) {
+                               exp = getVarParam(FED_LOCAL_OBJECT);
+                               if( !(exp instanceof DataIdentifier) ) {
+                                       raiseValidateError("for federated 
statement " + FED_LOCAL_OBJECT + " has incorrect value type", conditional);
+                               }
+                               
getVarParam(FED_LOCAL_OBJECT).validateExpression(ids, currConstVars, 
conditional);
+                       }
                        getOutput().setDimensions(-1, -1);
+
                        break;
                        
                default:
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
index 1fb1e8b1ec..370163aaf2 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
@@ -34,6 +34,7 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
 import org.apache.sysds.runtime.controlprogram.paramserv.NetworkTrafficCounter;
 import org.apache.sysds.runtime.DMLRuntimeException;
@@ -150,6 +151,19 @@ public class FederatedData {
                return executeFederatedOperation(request);
        }
 
+       public synchronized Future<FederatedResponse> 
initFederatedDataFromLocal(long id, CacheBlock block) {
+               if(isInitialized())
+                       throw new DMLRuntimeException("Tried to init already 
initialized data");
+               if(!_dataType.isMatrix() && !_dataType.isFrame())
+                       throw new DMLRuntimeException("Federated datatype \"" + 
_dataType.toString() + "\" is not supported.");
+               _varID = id;
+               FederatedRequest request = new 
FederatedRequest(RequestType.READ_VAR, id);
+               request.appendParam(_filepath);
+               request.appendParam(_dataType.name());
+               request.appendParam(block);
+               return executeFederatedOperation(request);
+       }
+
        public Future<FederatedResponse> 
executeFederatedOperation(FederatedRequest... request) {
                return executeFederatedOperation(_address, request);
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index d0865df120..592f77ccce 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -89,15 +89,14 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
        private static final Log LOG = 
LogFactory.getLog(FederatedWorkerHandler.class.getName());
 
        /** The Federated Lookup Table of the current Federated Worker. */
-       private FederatedLookupTable _flt;
+       private final FederatedLookupTable _flt;
 
        /** Read cache shared by all worker handlers */
-       private FederatedReadCache _frc;
+       private final FederatedReadCache _frc;
        private Timing _timing = null;
-
-
+       
        /** Federated workload analyzer */
-       private FederatedWorkloadAnalyzer _fan;
+       private final FederatedWorkloadAnalyzer _fan;
 
        /**
         * Create a Federated Worker Handler.
@@ -114,12 +113,12 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                _frc = frc;
                _fan = fan;
        }
-
+       
        public FederatedWorkerHandler(FederatedLookupTable flt, 
FederatedReadCache frc, FederatedWorkloadAnalyzer fan, Timing timing) {
                this(flt, frc, fan);
                _timing = timing;
        }
-
+       
        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg) {
                ctx.writeAndFlush(createResponse(msg, 
ctx.channel().remoteAddress()))
@@ -138,7 +137,7 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                } catch (RuntimeException ignored) {
                        // ignore timing if it wasn't started yet
                }
-
+               
                String host;
                if(remoteAddress instanceof InetSocketAddress) {
                        host = ((InetSocketAddress) 
remoteAddress).getHostString();
@@ -183,7 +182,7 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
 
        private FederatedResponse createResponse(FederatedRequest[] requests, 
String remoteHost)
                throws DMLPrivacyException, FederatedWorkerHandlerException, 
Exception {
-
+                       
                FederatedResponse response = null; // last response
                boolean containsCLEAR = false;
                for(int i = 0; i < requests.length; i++) {
@@ -272,14 +271,15 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
        }
 
        private FederatedResponse readData(FederatedRequest request, 
ExecutionContextMap ecm) {
-               checkNumParams(request.getNumParams(), 2);
+               checkNumParams(request.getNumParams(), 2, 3);
                String filename = (String) request.getParam(0);
                DataType dt = DataType.valueOf((String) request.getParam(1));
-               return readData(filename, dt, request.getID(), 
request.getTID(), ecm);
+               return readData(filename, dt, request.getID(), 
request.getTID(), ecm,
+                       request.getNumParams() == 2 ? null : 
(CacheBlock)request.getParam(2));
        }
 
        private FederatedResponse readData(String filename, DataType dataType,
-               long id, long tid, ExecutionContextMap ecm) {
+               long id, long tid, ExecutionContextMap ecm, CacheBlock 
localBlock) {
                MatrixCharacteristics mc = new MatrixCharacteristics();
                mc.setBlocksize(ConfigurationManager.getBlocksize());
 
@@ -299,7 +299,7 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                        cd = _frc.get(filename, !linReuse);
                        try {
                                if(cd == null) { // data is neither in lineage 
cache nor in read cache
-                                       cd = readDataNoReuse(filename, 
dataType, mc); // actual read of the data
+                                       cd = localBlock == null ? 
readDataNoReuse(filename, dataType, mc) : 
ExecutionContext.createCacheableData(localBlock); // actual read of the data
                                        if(linReuse) // put the object into the 
lineage cache
                                                
LineageCache.putFedReadObject(cd, linItem, ec);
                                        else
@@ -315,7 +315,7 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                                throw ex;
                        }
                }
-
+               
                if(shouldTryAsyncCompress()) // TODO: replace the reused object
                        CompressedMatrixBlockFactory.compressAsync(ec, sId);
 
@@ -426,7 +426,7 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                        throw new FederatedWorkerHandlerException(
                                "Unsupported object type, has to be of type 
CacheBlock or ScalarObject");
 
-
+                               
                // set variable and construct empty response
                ec.setVariable(varName, data);
 
@@ -450,12 +450,13 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
 
        private FederatedResponse getVariable(FederatedRequest request, 
ExecutionContextMap ecm) {
                try{
+
                        checkNumParams(request.getNumParams(), 0);
                        ExecutionContext ec = ecm.get(request.getTID());
                        
if(!ec.containsVariable(String.valueOf(request.getID())))
                                throw new FederatedWorkerHandlerException(
                                        "Variable " + request.getID() + " does 
not exist at federated worker.");
-
+       
                        // get variable and construct response
                        Data dataObject = 
ec.getVariable(String.valueOf(request.getID()));
                        dataObject = PrivacyMonitor.handlePrivacy(dataObject);
@@ -487,7 +488,7 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                adaptToWorkload(ec, _fan, tid, ins);
                return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
        }
-
+       
        private static ExecutionContext getContextForInstruction(long id, 
Instruction ins, ExecutionContextMap ecm){
                final ExecutionContext ec = ecm.get(id);
                //handle missing spark execution context
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
index 6e18115835..3e648bbe3b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
@@ -33,24 +33,25 @@ import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 
-import org.apache.sysds.api.DMLScript;
 import org.apache.commons.lang3.tuple.ImmutablePair;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.conf.DMLConfig;
 import org.apache.sysds.hops.fedplanner.FTypes.FType;
 import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
 import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
 import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
-import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -60,6 +61,8 @@ import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.instructions.cp.StringObject;
 import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.lineage.LineageTraceable;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 
 public class InitFEDInstruction extends FEDInstruction implements 
LineageTraceable {
@@ -69,7 +72,7 @@ public class InitFEDInstruction extends FEDInstruction 
implements LineageTraceab
        public static final String FED_MATRIX_IDENTIFIER = "matrix";
        public static final String FED_FRAME_IDENTIFIER = "frame";
 
-       private CPOperand _type, _addresses, _ranges, _output;
+       private CPOperand _type, _addresses, _ranges, _localObject, _output;
 
        public InitFEDInstruction(CPOperand type, CPOperand addresses, 
CPOperand ranges, CPOperand out, String opcode,
                String instr) {
@@ -80,44 +83,66 @@ public class InitFEDInstruction extends FEDInstruction 
implements LineageTraceab
                _output = out;
        }
 
+       public InitFEDInstruction(CPOperand type, CPOperand addresses, 
CPOperand ranges, CPOperand object, CPOperand out, String opcode,
+               String instr) {
+               this(type, addresses, ranges, out, opcode, instr);
+               _localObject = object;
+       }
+
        public static InitFEDInstruction parseInstruction(String str) {
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
                // We need 5 parts: Opcode, Type (Frame/Matrix), Addresses 
(list of Strings with
                // url/ip:port/filepath), ranges and the output Operand
-               if(parts.length != 5)
+               if(parts.length != 5 && parts.length != 6)
                        throw new DMLRuntimeException("Invalid number of 
operands in federated instruction: " + str);
                String opcode = parts[0];
 
-               CPOperand type, addresses, ranges, out;
-               type = new CPOperand(parts[1]);
-               addresses = new CPOperand(parts[2]);
-               ranges = new CPOperand(parts[3]);
-               out = new CPOperand(parts[4]);
-               return new InitFEDInstruction(type, addresses, ranges, out, 
opcode, str);
+               if(parts.length == 5) {
+                       CPOperand type, addresses, ranges, out;
+                       type = new CPOperand(parts[1]);
+                       addresses = new CPOperand(parts[2]);
+                       ranges = new CPOperand(parts[3]);
+                       out = new CPOperand(parts[4]);
+                       return new InitFEDInstruction(type, addresses, ranges, 
out, opcode, str);
+               } else {
+                       CPOperand type, addresses, object, ranges, out;
+                       type = new CPOperand(parts[1]);
+                       addresses = new CPOperand(parts[2]);
+                       ranges = new CPOperand(parts[3]);
+                       object = new CPOperand(parts[4]);
+                       out = new CPOperand(parts[5]);
+                       return new InitFEDInstruction(type, addresses, ranges, 
object, out, opcode, str);
+               }
        }
 
        @Override
        public void processInstruction(ExecutionContext ec) {
+               if(_localObject == null)
+                       processFedInit(ec);
+               else
+                       processFromLocalFedInit(ec);
+       }
+
+       private void processFedInit(ExecutionContext ec){
                String type = ec.getScalarInput(_type).getStringValue();
                ListObject addresses = ec.getListObject(_addresses.getName());
                ListObject ranges = ec.getListObject(_ranges.getName());
                List<Pair<FederatedRange, FederatedData>> feds = new 
ArrayList<>();
 
                if(addresses.getLength() * 2 != ranges.getLength())
-                       throw new DMLRuntimeException("Federated read needs 
twice the amount of addresses as ranges "
-                               + "(begin and end): addresses=" + 
addresses.getLength() + " ranges=" + ranges.getLength());
+                       throw new DMLRuntimeException("Federated read needs 
twice the amount of addresses as ranges " + "(begin and end): addresses=" + 
addresses.getLength() + " ranges=" + ranges.getLength());
 
                //check for duplicate addresses (would lead to overwrite with 
common variable names)
                // TODO relax requirement by using different execution contexts 
per federated data?
                Set<String> addCheck = new HashSet<>();
                for( Data dat : addresses.getData() )
                        if( dat instanceof StringObject ) {
-                               String address = 
((StringObject)dat).getStringValue();
+                               String address = ((StringObject) 
dat).getStringValue();
                                if(addCheck.contains(address))
                                        LOG.warn("Federated data contains 
address duplicates: " + addresses);
                                addCheck.add(address);
                        }
-               
+
                Types.DataType fedDataType;
                if(type.equalsIgnoreCase(FED_MATRIX_IDENTIFIER))
                        fedDataType = Types.DataType.MATRIX;
@@ -136,6 +161,103 @@ public class InitFEDInstruction extends FEDInstruction 
implements LineageTraceab
                                int port = Integer.parseInt(parsedValues[1]);
                                String filePath = parsedValues[2];
 
+                               if(DMLScript.FED_STATISTICS)
+                                       // register the federated worker for 
federated statistics creation
+                                       
FederatedStatistics.registerFedWorker(host, port);
+
+                               // get beginning and end of data ranges
+                               List<Data> rangesData = ranges.getData();
+                               Data beginData = rangesData.get(i * 2);
+                               Data endData = rangesData.get(i * 2 + 1);
+                               if(beginData.getDataType() != 
Types.DataType.LIST || endData.getDataType() != Types.DataType.LIST)
+                                       throw new 
DMLRuntimeException("Federated read ranges (lower, upper) have to be lists of 
dimensions");
+                               List<Data> beginDimsData = ((ListObject) 
beginData).getData();
+                               List<Data> endDimsData = ((ListObject) 
endData).getData();
+
+                               // fill begin and end dims
+                               long[] beginDims = new 
long[beginDimsData.size()];
+                               long[] endDims = new long[beginDims.length];
+                               for(int d = 0; d < beginDims.length; d++) {
+                                       beginDims[d] = ((ScalarObject) 
beginDimsData.get(d)).getLongValue();
+                                       endDims[d] = ((ScalarObject) 
endDimsData.get(d)).getLongValue();
+                               }
+
+                               usedDims[0] = Math.max(usedDims[0], endDims[0]);
+                               usedDims[1] = Math.max(usedDims[1], endDims[1]);
+                               try {
+                                       FederatedData federatedData = new 
FederatedData(fedDataType,
+                                               new 
InetSocketAddress(InetAddress.getByName(host), port), filePath);
+                                       feds.add(new ImmutablePair<>(new 
FederatedRange(beginDims, endDims), federatedData));
+                               }
+                               catch(UnknownHostException e) {
+                                       throw new 
DMLRuntimeException("federated host was unknown: " + host);
+                               }
+                       }
+                       else {
+                               throw new DMLRuntimeException("federated 
instruction only takes strings as addresses");
+                       }
+               }
+
+               if(type.equalsIgnoreCase(FED_MATRIX_IDENTIFIER)) {
+                       CacheableData<?> output = ec.getCacheableData(_output);
+                       
output.getDataCharacteristics().setRows(usedDims[0]).setCols(usedDims[1]);
+                       federateMatrix(output, feds, null);
+               }
+               else if(type.equalsIgnoreCase(FED_FRAME_IDENTIFIER)) {
+                       if(usedDims[1] > Integer.MAX_VALUE)
+                               throw new DMLRuntimeException("federated Frame 
can not have more than max int columns, because the "
+                                       + "schema can only be max int length");
+                       FrameObject output = ec.getFrameObject(_output);
+                       
output.getDataCharacteristics().setRows(usedDims[0]).setCols(usedDims[1]);
+                       federateFrame(output, feds, null);
+               }
+               else {
+                       throw new DMLRuntimeException("type \"" + type + "\" 
non valid federated type");
+               }
+       }
+
+       public void processFromLocalFedInit(ExecutionContext ec) {
+               String type = ec.getScalarInput(_type).getStringValue();
+               ListObject addresses = ec.getListObject(_addresses.getName());
+               ListObject ranges = ec.getListObject(_ranges.getName());
+               List<Pair<FederatedRange, FederatedData>> feds = new 
ArrayList<>();
+
+               CacheableData<?> co = ec.getCacheableData(_localObject);
+               CacheBlock cb =  co.acquireReadAndRelease();
+
+               if(addresses.getLength() * 2 != ranges.getLength())
+                       throw new DMLRuntimeException("Federated read needs 
twice the amount of addresses as ranges "
+                               + "(begin and end): addresses=" + 
addresses.getLength() + " ranges=" + ranges.getLength());
+
+               //check for duplicate addresses (would lead to overwrite with 
common variable names)
+               Set<String> addCheck = new HashSet<>();
+               for(Data dat : addresses.getData())
+                       if(dat instanceof StringObject) {
+                               String address = ((StringObject) 
dat).getStringValue();
+                               if(addCheck.contains(address))
+                                       LOG.warn("Federated data contains 
address duplicates: " + addresses);
+                               addCheck.add(address);
+                       }
+
+               Types.DataType fedDataType;
+               if(type.equalsIgnoreCase(FED_MATRIX_IDENTIFIER))
+                       fedDataType = Types.DataType.MATRIX;
+               else if(type.equalsIgnoreCase(FED_FRAME_IDENTIFIER))
+                       fedDataType = Types.DataType.FRAME;
+               else
+                       throw new DMLRuntimeException("type \"" + type + "\" 
non valid federated type");
+
+               long[] usedDims = new long[] {0, 0};
+               CacheBlock[] cbs = new CacheBlock[addresses.getLength()];
+               for(int i = 0; i < addresses.getLength(); i++) {
+                       Data addressData = addresses.getData().get(i);
+                       if(addressData instanceof StringObject) {
+                               // We split address into url/ip, the port and 
file path of file to read
+                               String[] parsedValues = 
parseURLNoFilePath(((StringObject) addressData).getStringValue());
+                               String host = parsedValues[0];
+                               int port = Integer.parseInt(parsedValues[1]);
+                               String filePath = co.getFileName();
+
                                if(DMLScript.FED_STATISTICS)
                                        // register the federated worker for 
federated statistics creation
                                        
FederatedStatistics.registerFedWorker(host, port);
@@ -159,6 +281,11 @@ public class InitFEDInstruction extends FEDInstruction 
implements LineageTraceab
                                }
                                usedDims[0] = Math.max(usedDims[0], endDims[0]);
                                usedDims[1] = Math.max(usedDims[1], endDims[1]);
+
+                               CacheBlock slice = cb instanceof MatrixBlock ? 
((MatrixBlock)cb).slice((int) beginDims[0], (int) endDims[0]-1, (int) 
beginDims[1], (int) endDims[1]-1, true) :
+                                       ((FrameBlock)cb).slice((int) 
beginDims[0], (int) endDims[0]-1, (int) beginDims[1], (int) endDims[1]-1, true, 
new FrameBlock());
+                               cbs[i] = slice;
+
                                try {
                                        FederatedData federatedData = new 
FederatedData(fedDataType,
                                                new 
InetSocketAddress(InetAddress.getByName(host), port), filePath);
@@ -172,10 +299,11 @@ public class InitFEDInstruction extends FEDInstruction 
implements LineageTraceab
                                throw new DMLRuntimeException("federated 
instruction only takes strings as addresses");
                        }
                }
+
                if(type.equalsIgnoreCase(FED_MATRIX_IDENTIFIER)) {
                        CacheableData<?> output = ec.getCacheableData(_output);
                        
output.getDataCharacteristics().setRows(usedDims[0]).setCols(usedDims[1]);
-                       federateMatrix(output, feds);
+                       federateMatrix(output, feds, cbs);
                }
                else if(type.equalsIgnoreCase(FED_FRAME_IDENTIFIER)) {
                        if(usedDims[1] > Integer.MAX_VALUE)
@@ -183,13 +311,44 @@ public class InitFEDInstruction extends FEDInstruction 
implements LineageTraceab
                                        + "schema can only be max int length");
                        FrameObject output = ec.getFrameObject(_output);
                        
output.getDataCharacteristics().setRows(usedDims[0]).setCols(usedDims[1]);
-                       federateFrame(output, feds);
+                       federateFrame(output, feds, cbs);
                }
                else {
                        throw new DMLRuntimeException("type \"" + type + "\" 
non valid federated type");
                }
        }
 
+       public static String[] parseURLNoFilePath(String input) {
+               try {
+                       // Artificially making it http protocol.
+                       // This is to avoid malformed address error in the URL 
parsing.
+                       // TODO: Construct new protocol name for Federated 
communication
+                       URL address = new URL("http://" + input);
+                       String host = address.getHost();
+                       if(host.length() == 0)
+                               throw new IllegalArgumentException("Missing 
Host name for federated address");
+                       // The current system does not support ipv6, only ipv4.
+                       // TODO: Support IPV6 address for Federated 
communication
+                       String ipRegex = 
"^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$";
+                       if(host.matches("^\\d+\\.\\d+\\.\\d+\\.\\d+$") && 
!host.matches(ipRegex))
+                               throw new IllegalArgumentException("Input Host 
address looks like an IP address but is outside range");
+                       int port = address.getPort();
+                       if(port == -1)
+                               port = DMLConfig.DEFAULT_FEDERATED_PORT;
+                       if(address.getQuery() != null)
+                               throw new IllegalArgumentException("Query is 
not supported");
+
+                       if(address.getRef() != null)
+                               throw new IllegalArgumentException("Reference 
is not supported");
+
+                       return new String[] {host, String.valueOf(port)};
+               }
+               catch(MalformedURLException e) {
+                       throw new IllegalArgumentException(
+                               "federated address `" + input + "` does not fit 
required URL pattern of \"host:port/directory\"", e);
+               }
+       }
+
        public static String[] parseURL(String input) {
                try {
                        // Artificially making it http protocol.
@@ -231,6 +390,10 @@ public class InitFEDInstruction extends FEDInstruction 
implements LineageTraceab
        }
 
        public static void federateMatrix(CacheableData<?> output, 
List<Pair<FederatedRange, FederatedData>> workers) {
+               federateMatrix(output, workers, null);
+       }
+
+       public static void federateMatrix(CacheableData<?> output, 
List<Pair<FederatedRange, FederatedData>> workers, CacheBlock[] blocks) {
 
                List<Pair<FederatedRange, FederatedData>> fedMapping = new 
ArrayList<>();
                for(Pair<FederatedRange, FederatedData> e : workers)
@@ -239,6 +402,7 @@ public class InitFEDInstruction extends FEDInstruction 
implements LineageTraceab
                long id = FederationUtils.getNextFedDataID();
                boolean rowPartitioned = true;
                boolean colPartitioned = true;
+               int k = 0;
                for(Pair<FederatedRange, FederatedData> entry : fedMapping) {
                        FederatedRange range = entry.getKey();
                        FederatedData value = entry.getValue();
@@ -248,7 +412,10 @@ public class InitFEDInstruction extends FEDInstruction 
implements LineageTraceab
                                long[] dims = 
output.getDataCharacteristics().getDims();
                                for(int i = 0; i < dims.length; i++)
                                        dims[i] = endDims[i] - beginDims[i];
-                               idResponses.add(new ImmutablePair<>(value, 
value.initFederatedData(id)));
+                               if(blocks == null || blocks.length == 0)
+                                       idResponses.add(new 
ImmutablePair<>(value, value.initFederatedData(id)));
+                               else
+                                       idResponses.add(new 
ImmutablePair<>(value, value.initFederatedDataFromLocal(id, blocks[k++])));
                        }
                        rowPartitioned &= (range.getSize(1) == 
output.getNumColumns());
                        colPartitioned &= (range.getSize(0) == 
output.getNumRows());
@@ -284,7 +451,7 @@ public class InitFEDInstruction extends FEDInstruction 
implements LineageTraceab
                        LOG.debug("Fed map Inited:" + output.getFedMapping());
        }
 
-       public static void federateFrame(FrameObject output, 
List<Pair<FederatedRange, FederatedData>> workers) {
+       public static void federateFrame(FrameObject output, 
List<Pair<FederatedRange, FederatedData>> workers, CacheBlock[] blocks) {
                List<Pair<FederatedRange, FederatedData>> fedMapping = new 
ArrayList<>();
                for(Pair<FederatedRange, FederatedData> e : workers)
                        fedMapping.add(e);
@@ -295,6 +462,7 @@ public class InitFEDInstruction extends FEDInstruction 
implements LineageTraceab
                long id = FederationUtils.getNextFedDataID();
                boolean rowPartitioned = true;
                boolean colPartitioned = true;
+               int k = 0;
                for(Pair<FederatedRange, FederatedData> entry : fedMapping) {
                        FederatedRange range = entry.getKey();
                        FederatedData value = entry.getValue();
@@ -305,8 +473,12 @@ public class InitFEDInstruction extends FEDInstruction 
implements LineageTraceab
                                for(int i = 0; i < dims.length; i++) {
                                        dims[i] = endDims[i] - beginDims[i];
                                }
-                               idResponses.add(
-                                       new ImmutablePair<>(value, new 
ImmutablePair<>((int) beginDims[1], value.initFederatedData(id))));
+                               if(blocks == null || blocks.length == 0)
+                                       idResponses.add(
+                                               new ImmutablePair<>(value, new 
ImmutablePair<>((int) beginDims[1], value.initFederatedData(id))));
+                               else
+                                       idResponses.add(
+                                               new ImmutablePair<>(value, new 
ImmutablePair<>((int) beginDims[1], value.initFederatedDataFromLocal(id, 
blocks[k++]))));
                        }
                        rowPartitioned &= (range.getSize(1) == 
output.getNumColumns());
                        colPartitioned &= (range.getSize(0) == 
output.getNumRows());
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 315871ef50..ee4e668c55 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -4147,6 +4147,15 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                return slice(rl, ru, cl, cu, true, ret);
        }
 
+       /**
+        * Slice out a sub-block given by a row range and a column range
+        * @param rl The row lower bound to start from
+        * @param ru The row upper bound to end at
+        * @param cl The column lower bound to start from
+        * @param cu The column upper bound to end at
+        * @param deep Deep copy or not
+        * @return The sliced out matrix block.
+        */
        public final MatrixBlock slice(int rl, int ru, int cl, int cu, boolean 
deep){
                return slice(rl, ru, cl, cu, deep, null);
        }
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java 
b/src/test/java/org/apache/sysds/test/TestUtils.java
index 64785c4ed8..1096a5147e 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -1839,7 +1839,6 @@ public class TestUtils
        public static double[][] generateTestMatrix(int rows, int cols, double 
min, double max, double sparsity, long seed) {
                double[][] matrix = new double[rows][cols];
                Random random = (seed == -1) ? TestUtils.random : new 
Random(seed);
-
                for (int i = 0; i < rows; i++) {
                        for (int j = 0; j < cols; j++) {
                                if (random.nextDouble() > sparsity)
@@ -3022,6 +3021,10 @@ public class TestUtils
                return host + ':' + port + '/' + input;
        }
 
+       public static String federatedAddressNoInput(String host, int port) {
+               return host + ':' + port;
+       }
+
        public static double gaussian_probability (double point)
        //  "Handbook of Mathematical Functions", ed. by M. Abramowitz and I.A. 
Stegun,
        //  U.S. Nat-l Bureau of Standards, 10th print (Dec 1972), Sec. 7.1.26, 
p. 299
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedTransferLocalDataTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedTransferLocalDataTest.java
new file mode 100644
index 0000000000..f612ca14e4
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedTransferLocalDataTest.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedTransferLocalDataTest extends AutomatedTestBase {
+       private final static String TEST_DIR = "functions/federated/";
+       private final static String TEST_NAME1 = 
"FederatedTransferLocalDataTest";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedTransferLocalDataTest.class.getSimpleName() + "/";
+
+       private final static int blocksize = 1024;
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+       @Parameterized.Parameter(2)
+       public boolean rowPartitioned;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S"}));
+       }
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               return Arrays.asList(new Object[][] {
+                       {12, 4, true}, {12, 4, false},
+               });
+       }
+
+       @Test
+       public void federatedTransferCP() { 
runTransferTest(Types.ExecMode.SINGLE_NODE); }
+
+       @Test
+       public void federatedTransferSP() { 
runTransferTest(Types.ExecMode.SPARK); }
+
+       private void runTransferTest(Types.ExecMode execMode) {
+               String TEST_NAME = TEST_NAME1;
+               ExecMode platformOld = setExecMode(execMode);
+
+               getAndLoadTestConfiguration(TEST_NAME);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               // write input matrices
+               double[][] X = getRandomMatrix(rows, cols, 1, 5, 1, 3);
+
+               MatrixCharacteristics mc = new MatrixCharacteristics(rows, 
cols, blocksize, (long) rows * cols);
+               writeInputMatrixWithMTD("X", X, false, mc);
+
+               // empty script name because we don't execute any script, just 
start the worker
+               fullDMLScriptName = "";
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               int port3 = getRandomAvailablePort();
+               int port4 = getRandomAvailablePort();
+               Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+               Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+               Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
+               Thread t4 = startLocalFedWorkerThread(port4);
+
+               rtplatform = execMode;
+               if(rtplatform == Types.ExecMode.SPARK) {
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+               }
+               TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
+               loadTestConfiguration(config);
+
+               // Run reference dml script with normal matrix
+               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+               programArgs = new String[] {"-stats", "100", "-args", 
input("X"), expected("S")};
+
+               runTest(null);
+
+               fullDMLScriptName = HOME + TEST_NAME + ".dml";
+               programArgs = new String[] {"-stats", "100", "-nvargs",
+                       "in_X=" + input("X"),
+                       "in_X1=" + 
TestUtils.federatedAddressNoInput("localhost", port1),
+                       "in_X2=" + 
TestUtils.federatedAddressNoInput("localhost", port2),
+                       "in_X3=" + 
TestUtils.federatedAddressNoInput("localhost", port3),
+                       "in_X4=" + 
TestUtils.federatedAddressNoInput("localhost", port4), "rows=" + rows, "cols=" 
+ cols,
+                       "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), 
"out_S=" + output("S")};
+
+               runTest(null);
+
+               // compare via files
+               compareResults(1e-9, "Stat-DML1", "Stat-DML2");
+
+               TestUtils.shutdownThreads(t1, t2, t3, t4);
+
+               resetExecMode(platformOld);
+       }
+}
diff --git 
a/src/test/scripts/functions/federated/FederatedTransferLocalDataTest.dml 
b/src/test/scripts/functions/federated/FederatedTransferLocalDataTest.dml
new file mode 100644
index 0000000000..e4dd5180db
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedTransferLocalDataTest.dml
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A1 = read($in_X);
+
+if ($rP) {
+  A = federated(local_matrix=A1, addresses=list($in_X1, $in_X2, $in_X3, 
$in_X4),
+    ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+    list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
+} else {
+  A = federated(local_matrix=A1, addresses=list($in_X1, $in_X2, $in_X3, 
$in_X4),
+    ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, 
$cols/2),
+    list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), 
list($rows, $cols)));
+}
+print(toString(A))
+write(A, $out_S);
diff --git 
a/src/test/scripts/functions/federated/FederatedTransferLocalDataTestReference.dml
 
b/src/test/scripts/functions/federated/FederatedTransferLocalDataTestReference.dml
new file mode 100644
index 0000000000..dcce9747e8
--- /dev/null
+++ 
b/src/test/scripts/functions/federated/FederatedTransferLocalDataTestReference.dml
@@ -0,0 +1,23 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1);
+write(A, $2);

Reply via email to