This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push:
new 833750f [SYSTEMDS-224] Support for federated frames (construction,
collect)
833750f is described below
commit 833750fa5c09e32518ba19196dcee2d321836d74
Author: Kevin Innerebner <[email protected]>
AuthorDate: Sat May 2 22:45:17 2020 +0200
[SYSTEMDS-224] Support for federated frames (construction, collect)
This PR adds federated frames. To support them, the federated builtin
now takes an additional parameter "type", which can be set to either
"matrix" or "frame" (case-insensitive) and defaults to "matrix". SPARK
execution mode is not supported (as with many other federated commands).
Once we have a better integration of federation during instruction
building (i.e., always knowing whether a hop/lop is federated), both CP
and SPARK should be supported.
Closes #893.
---
docs/Tasks.txt | 5 +-
src/main/java/org/apache/sysds/lops/Federated.java | 10 +-
.../org/apache/sysds/parser/DataExpression.java | 77 ++++++++----
.../controlprogram/caching/CacheableData.java | 29 +++++
.../controlprogram/caching/FrameObject.java | 60 +++++++--
.../controlprogram/caching/MatrixObject.java | 103 ++++++----------
.../controlprogram/federated/FederatedData.java | 24 +++-
.../controlprogram/federated/FederatedRequest.java | 2 +-
.../federated/FederatedResponse.java | 15 ++-
.../controlprogram/federated/FederatedWorker.java | 6 +-
.../federated/FederatedWorkerHandler.java | 54 ++++++---
.../controlprogram/federated/LibFederatedAgg.java | 2 +-
.../fed/AggregateBinaryFEDInstruction.java | 10 +-
.../fed/BinaryMatrixScalarFEDInstruction.java | 8 +-
.../instructions/fed/InitFEDInstruction.java | 125 ++++++++++++++++---
.../apache/sysds/runtime/util/UtilFunctions.java | 46 +++++--
.../org/apache/sysds/test/AutomatedTestBase.java | 47 +++++---
src/test/java/org/apache/sysds/test/TestUtils.java | 134 ++++++++++++++++-----
.../federated/FederatedConstructionTest.java | 88 ++++++++++----
...ence.dml => FederatedFrameConstructionTest.dml} | 4 +-
...=> FederatedFrameConstructionTestReference.dml} | 0
...est.dml => FederatedMatrixConstructionTest.dml} | 0
...> FederatedMatrixConstructionTestReference.dml} | 0
23 files changed, 609 insertions(+), 240 deletions(-)
diff --git a/docs/Tasks.txt b/docs/Tasks.txt
index 86ab8fe..209a01c 100644
--- a/docs/Tasks.txt
+++ b/docs/Tasks.txt
@@ -178,13 +178,14 @@ SYSTEMDS-210 Extended lists Operations
SYSTEMDS-220 Federated Tensors and Instructions
* 221 Initial infrastructure federated operations OK
- * 222 Federated matrix-vector multiplication
+ * 222 Federated matrix-vector multiplication OK
* 223 Federated unary aggregates OK
- * 224 Federated transform functionality
+ * 224 Federated frames OK
* 225 Federated elementwise operations OK
* 226 Federated rbind and cbind OK
* 227 Federated worker setup and infrastructure OK
* 228 Federated matrix-matrix multiplication OK
+ * 229 Federated transform functionality
SYSTEMDS-230 Lineage Integration
* 231 Use lineage in buffer pool
diff --git a/src/main/java/org/apache/sysds/lops/Federated.java
b/src/main/java/org/apache/sysds/lops/Federated.java
index 3d508e3..8aacbd7 100644
--- a/src/main/java/org/apache/sysds/lops/Federated.java
+++ b/src/main/java/org/apache/sysds/lops/Federated.java
@@ -26,15 +26,19 @@ import static org.apache.sysds.common.Types.DataType;
import static org.apache.sysds.common.Types.ValueType;
import static org.apache.sysds.parser.DataExpression.FED_ADDRESSES;
import static org.apache.sysds.parser.DataExpression.FED_RANGES;
+import static org.apache.sysds.parser.DataExpression.FED_TYPE;
public class Federated extends Lop {
- private Lop _addresses, _ranges;
+ private Lop _type, _addresses, _ranges;
public Federated(HashMap<String, Lop> inputLops, DataType dataType,
ValueType valueType) {
super(Type.Federated, dataType, valueType);
+ _type = inputLops.get(FED_TYPE);
_addresses = inputLops.get(FED_ADDRESSES);
_ranges = inputLops.get(FED_RANGES);
+ addInput(_type);
+ _type.addOutput(this);
addInput(_addresses);
_addresses.addOutput(this);
addInput(_ranges);
@@ -42,11 +46,13 @@ public class Federated extends Lop {
}
@Override
- public String getInstructions(String addresses, String ranges, String
output) {
+ public String getInstructions(String type, String addresses, String
ranges, String output) {
StringBuilder sb = new StringBuilder("FED");
sb.append(OPERAND_DELIMITOR);
sb.append("fedinit");
sb.append(OPERAND_DELIMITOR);
+ sb.append(_type.prepScalarInputOperand(type));
+ sb.append(OPERAND_DELIMITOR);
sb.append(_addresses.prepScalarInputOperand(addresses));
sb.append(OPERAND_DELIMITOR);
sb.append(_ranges.prepScalarInputOperand(ranges));
diff --git a/src/main/java/org/apache/sysds/parser/DataExpression.java
b/src/main/java/org/apache/sysds/parser/DataExpression.java
index baa2b48..202947b 100644
--- a/src/main/java/org/apache/sysds/parser/DataExpression.java
+++ b/src/main/java/org/apache/sysds/parser/DataExpression.java
@@ -49,8 +49,10 @@ import java.util.HashSet;
import java.util.Set;
import java.util.Map.Entry;
+import static
org.apache.sysds.runtime.instructions.fed.InitFEDInstruction.FED_FRAME_IDENTIFIER;
+import static
org.apache.sysds.runtime.instructions.fed.InitFEDInstruction.FED_MATRIX_IDENTIFIER;
-public class DataExpression extends DataIdentifier
+public class DataExpression extends DataIdentifier
{
public static final String RAND_DIMS = "dims";
@@ -81,6 +83,7 @@ public class DataExpression extends DataIdentifier
public static final String FED_ADDRESSES = "addresses";
public static final String FED_RANGES = "ranges";
+ public static final String FED_TYPE = "type";
public static final String FORMAT_TYPE = "format";
public static final String FORMAT_TYPE_VALUE_TEXT = "text";
@@ -122,7 +125,7 @@ public class DataExpression extends DataIdentifier
Arrays.asList(SQL_CONN, SQL_USER, SQL_PASS, SQL_QUERY));
public static final Set<String> FEDERATED_VALID_PARAM_NAMES = new
HashSet<>(
- Arrays.asList(FED_ADDRESSES, FED_RANGES));
+ Arrays.asList(FED_ADDRESSES, FED_RANGES, FED_TYPE));
// Valid parameter names in a metadata file
public static final Set<String> READ_VALID_MTD_PARAM_NAMES =new
HashSet<>(
@@ -432,30 +435,43 @@ public class DataExpression extends DataIdentifier
else
namedParamCount++;
}
- if (passedParamExprs.size() != 2){
- errorListener.validationError(parseInfo, "for
federated statement, must specify exactly 2 argument: addresses, ranges");
+ if(passedParamExprs.size() < 2) {
+ errorListener.validationError(parseInfo,
+ "for federated statement, must specify
at least 2 arguments: addresses, ranges");
return null;
}
- if (unnamedParamCount > 0) {
- if (namedParamCount > 0) {
-
errorListener.validationError(parseInfo, "for federated statement, cannot mix
named and unnamed parameters");
+ if(unnamedParamCount > 0) {
+ if(namedParamCount > 0) {
+ errorListener.validationError(parseInfo,
+ "for federated statement,
cannot mix named and unnamed parameters");
return null;
}
- ParameterExpression param =
passedParamExprs.get(0);
-
dataExpr.addFederatedExprParam(DataExpression.FED_ADDRESSES, param.getExpr());
- param = passedParamExprs.get(1);
-
dataExpr.addFederatedExprParam(DataExpression.FED_RANGES, param.getExpr());
+ if(unnamedParamCount == 2) {
+ // first parameter addresses second are
the ranges (type defaults to Matrix)
+ ParameterExpression param =
passedParamExprs.get(0);
+
dataExpr.addFederatedExprParam(DataExpression.FED_ADDRESSES, param.getExpr());
+ param = passedParamExprs.get(1);
+
dataExpr.addFederatedExprParam(DataExpression.FED_RANGES, param.getExpr());
+ }
+ else if(unnamedParamCount == 3) {
+ ParameterExpression param =
passedParamExprs.get(0);
+
dataExpr.addFederatedExprParam(DataExpression.FED_ADDRESSES, param.getExpr());
+ param = passedParamExprs.get(1);
+
dataExpr.addFederatedExprParam(DataExpression.FED_RANGES, param.getExpr());
+ param = passedParamExprs.get(2);
+
dataExpr.addFederatedExprParam(DataExpression.FED_TYPE, param.getExpr());
+ }
+ else {
+ errorListener.validationError(parseInfo,
+ "for federated statement, at
most 3 arguments are supported: addresses, ranges, type");
+ }
}
else {
- ParameterExpression firstParam =
passedParamExprs.get(0);
- if (firstParam.getName() != null &&
!firstParam.getName().equals(DataExpression.FED_ADDRESSES)){
-
errorListener.validationError(parseInfo, "federated method must have addresses
parameter as first parameter or unnamed parameter");
- return null;
- }
for (ParameterExpression passedParamExpr :
passedParamExprs) {
dataExpr.addFederatedExprParam(passedParamExpr.getName(),
passedParamExpr.getExpr());
}
}
+ dataExpr.setFederatedDefault();
}
if (dataExpr != null) {
@@ -569,7 +585,7 @@ public class DataExpression extends DataIdentifier
if (!found)
raiseValidateError("unexpected parameter \"" +
paramName + "\". Legal parameters for federated statement are "
- + "(capitalization-sensitive): " +
FED_ADDRESSES + ", " + FED_RANGES);
+ + "(capitalization-sensitive): " +
FED_ADDRESSES + ", " + FED_RANGES + ", " + FED_TYPE);
if (getVarParam(paramName) != null)
raiseValidateError("attempted to add federated
statement parameter " + paramValue + " more than once");
@@ -620,6 +636,11 @@ public class DataExpression extends DataIdentifier
addVarParam(RAND_BY_ROW, new BooleanIdentifier(true,
this));
}
+ public void setFederatedDefault(){
+ if (getVarParam(FED_TYPE) == null)
+ addVarParam(FED_TYPE, new
StringIdentifier(FED_MATRIX_IDENTIFIER, this));
+ }
+
private void setSqlDefault() {
if (getVarParam(SQL_USER) == null)
addVarParam(SQL_USER, new StringIdentifier("", this));
@@ -1881,8 +1902,8 @@ public class DataExpression extends DataIdentifier
case FEDERATED:
validateParams(conditional, FEDERATED_VALID_PARAM_NAMES,
- "Legal parameters for federated
statement are (case-sensitive): " + FED_ADDRESSES + ", " +
- FED_RANGES);
+ "Legal parameters for federated statement are
(case-sensitive): "
+ + FED_TYPE + ", " + FED_ADDRESSES + ", " +
FED_RANGES);
exp = getVarParam(FED_ADDRESSES);
if( !(exp instanceof DataIdentifier) ) {
raiseValidateError("for federated statement " +
FED_ADDRESSES + " has incorrect value type", conditional);
@@ -1893,14 +1914,26 @@ public class DataExpression extends DataIdentifier
raiseValidateError("for federated statement " +
FED_RANGES + " has incorrect value type", conditional);
}
getVarParam(FED_RANGES).validateExpression(ids,
currConstVars, conditional);
+ exp = getVarParam(FED_TYPE);
+ if( !(exp instanceof StringIdentifier) ) {
+ raiseValidateError("for federated statement " +
FED_TYPE + " has incorrect value type", conditional);
+ }
+ getVarParam(FED_TYPE).validateExpression(ids,
currConstVars, conditional);
// TODO format type?
getOutput().setFormatType(FormatType.BINARY);
- getOutput().setDataType(DataType.MATRIX);
- // TODO value type for federated object
- getOutput().setValueType(ValueType.FP64);
+ StringIdentifier fedType = (StringIdentifier) exp;
+
if(fedType.getValue().equalsIgnoreCase(FED_MATRIX_IDENTIFIER)) {
+ getOutput().setDataType(DataType.MATRIX);
+ // TODO value type for federated object
+ getOutput().setValueType(ValueType.FP64);
+ }
+ else
if(fedType.getValue().equalsIgnoreCase(FED_FRAME_IDENTIFIER)) {
+ getOutput().setDataType(DataType.FRAME);
+ }
getOutput().setDimensions(-1, -1);
break;
+
default:
raiseValidateError("Unsupported Data expression "+
this.getOpCode(), false, LanguageErrorCodes.INVALID_PARAMETERS); //always
unconditional
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index 32b6162..ef1967c 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -31,6 +31,8 @@ import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.LazyWriteBuffer.RPolicy;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.instructions.cp.Data;
@@ -168,6 +170,9 @@ public abstract class CacheableData<T extends CacheBlock>
extends Data
*/
protected PrivacyConstraint _privacyConstraint = null;
+ protected Map<FederatedRange, FederatedData> _fedMapping = null;
+
+
/** The name of HDFS file in which the data is backed up. */
protected String _hdfsFileName = null; // file name and path
@@ -326,6 +331,30 @@ public abstract class CacheableData<T extends CacheBlock>
extends Data
public abstract void refreshMetaData();
+ /**
+ * Check if object is federated.
+ * @return true if federated else false
+ */
+ public boolean isFederated() {
+ return _fedMapping != null;
+ }
+
+ /**
+ * Gets the mapping of indices ranges to federated objects.
+ * @return fedMapping mapping
+ */
+ public Map<FederatedRange, FederatedData> getFedMapping() {
+ return _fedMapping;
+ }
+
+ /**
+ * Sets the mapping of indices ranges to federated objects.
+ * @param fedMapping mapping
+ */
+ public void setFedMapping(Map<FederatedRange, FederatedData>
fedMapping) {
+ _fedMapping = fedMapping;
+ }
+
public RDDObject getRDDHandle() {
return _rddHandle;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java
index 7cdd998..3808289 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java
@@ -22,14 +22,16 @@ package org.apache.sysds.runtime.controlprogram.caching;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang.mutable.MutableBoolean;
+import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
import org.apache.sysds.runtime.io.FileFormatProperties;
-import org.apache.sysds.runtime.io.FrameReader;
import org.apache.sysds.runtime.io.FrameReaderFactory;
import org.apache.sysds.runtime.io.FrameWriter;
import org.apache.sysds.runtime.io.FrameWriterFactory;
@@ -43,13 +45,17 @@ import org.apache.sysds.runtime.util.UtilFunctions;
import java.io.IOException;
import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.Future;
+
+import static org.apache.sysds.runtime.util.UtilFunctions.requestFederatedData;
public class FrameObject extends CacheableData<FrameBlock>
{
private static final long serialVersionUID = 1755082174281927785L;
private ValueType[] _schema = null;
-
+
protected FrameObject() {
super(DataType.FRAME, ValueType.STRING);
}
@@ -97,7 +103,7 @@ public class FrameObject extends CacheableData<FrameBlock>
return (_schema!=null && _schema.length>cu) ?
Arrays.copyOfRange(_schema, cl, cu+1) :
UtilFunctions.nCopies(cu-cl+1, ValueType.STRING);
}
-
+
/**
* Creates a new collection which contains the schema of the current
* frame object concatenated with the schema of the passed frame object.
@@ -156,13 +162,50 @@ public class FrameObject extends CacheableData<FrameBlock>
}
@Override
+ public FrameBlock acquireRead() {
+ // forward call for non-federated objects
+ if( !isFederated() )
+ return super.acquireRead();
+
+ FrameBlock result = new FrameBlock(_schema);
+ // provide long support?
+ result.ensureAllocatedColumns((int)
_metaData.getDataCharacteristics().getRows());
+ List<Pair<FederatedRange, Future<FederatedResponse>>>
readResponses = requestFederatedData(_fedMapping);
+ try {
+ for(Pair<FederatedRange, Future<FederatedResponse>>
readResponse : readResponses) {
+ FederatedRange range = readResponse.getLeft();
+ FederatedResponse response =
readResponse.getRight().get();
+ // add result
+ if(!response.isSuccessful())
+ throw new
DMLRuntimeException("Federated matrix read failed: " +
response.getErrorMessage());
+ FrameBlock multRes = (FrameBlock)
response.getData()[0];
+ for (int r = 0; r < multRes.getNumRows(); r++) {
+ for (int c = 0; c <
multRes.getNumColumns(); c++) {
+ int destRow =
range.getBeginDimsInt()[0] + r;
+ int destCol =
range.getBeginDimsInt()[1] + c;
+ result.set(destRow, destCol,
multRes.get(r, c));
+ }
+ }
+ }
+ }
+ catch(Exception e) {
+ throw new DMLRuntimeException("Federated Frame read
failed.", e);
+ }
+
+ //keep returned object for future use
+ acquireModify(result);
+
+ return result;
+ }
+
+ @Override
protected FrameBlock readBlobFromCache(String fname) throws IOException
{
return (FrameBlock)LazyWriteBuffer.readBlock(fname, false);
}
@Override
protected FrameBlock readBlobFromHDFS(String fname, long[] dims)
- throws IOException
+ throws IOException
{
long clen = dims[1];
MetaDataFormat iimd = (MetaDataFormat) _metaData;
@@ -175,17 +218,18 @@ public class FrameObject extends CacheableData<FrameBlock>
//read the frame block
FrameBlock data = null;
try {
- FrameReader reader =
FrameReaderFactory.createFrameReader(iimd.getInputInfo(),
getFileFormatProperties());
- data = reader.readFrameFromHDFS(fname, lschema,
dc.getRows(), dc.getCols());
+ data = isFederated() ? acquireReadAndRelease() :
+
FrameReaderFactory.createFrameReader(iimd.getInputInfo(),
getFileFormatProperties())
+ .readFrameFromHDFS(fname, lschema,
dc.getRows(), dc.getCols());
}
catch( DMLRuntimeException ex ) {
throw new IOException(ex);
}
-
+
//sanity check correct output
if( data == null )
throw new IOException("Unable to load frame from file:
"+fname);
-
+
return data;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
index 2e3c55f..9ca4c66 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
@@ -20,7 +20,6 @@
package org.apache.sysds.runtime.controlprogram.caching;
import org.apache.commons.lang.mutable.MutableBoolean;
-import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
@@ -32,9 +31,7 @@ import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.DMLRuntimeException;
import
org.apache.sysds.runtime.controlprogram.ParForProgramBlock.PDataPartitionFormat;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.instructions.spark.data.RDDObject;
import org.apache.sysds.runtime.io.FileFormatProperties;
@@ -51,11 +48,10 @@ import org.apache.sysds.runtime.util.IndexRange;
import java.io.IOException;
import java.lang.ref.SoftReference;
-import java.util.ArrayList;
import java.util.List;
-import java.util.Map;
import java.util.concurrent.Future;
+import static org.apache.sysds.runtime.util.UtilFunctions.requestFederatedData;
/**
* Represents a matrix in control program. This class contains method to read
@@ -92,8 +88,6 @@ public class MatrixObject extends CacheableData<MatrixBlock>
private String _partitionCacheName = null; //name of cache block
private MatrixBlock _partitionInMemory = null;
- private Map<FederatedRange, FederatedData> _fedMapping = null; //
mapping for federated matrixobject
-
/**
* Constructor that takes the value type and the HDFS filename.
*
@@ -142,66 +136,6 @@ public class MatrixObject extends
CacheableData<MatrixBlock>
_markForLinCache = mo._markForLinCache;
}
- public boolean isFederated() {
- return _fedMapping != null;
- }
-
- public Map<FederatedRange, FederatedData> getFedMapping() {
- return _fedMapping;
- }
-
- public void setFedMapping(Map<FederatedRange, FederatedData>
fedMapping) {
- _fedMapping = fedMapping;
- }
-
- @Override
- public MatrixBlock acquireRead() {
- // forward call for non-federated objects
- if( !isFederated() )
- return super.acquireRead();
-
- long[] dims = getDataCharacteristics().getDims();
- // TODO sparse optimization
- MatrixBlock result = new MatrixBlock((int) dims[0], (int)
dims[1], false);
- List<Pair<FederatedRange, Future<FederatedResponse>>>
readResponses = new ArrayList<>();
- for (Map.Entry<FederatedRange, FederatedData> entry :
_fedMapping.entrySet()) {
- FederatedRange range = entry.getKey();
- FederatedData fd = entry.getValue();
-
- if( fd.isInitialized() ) {
- FederatedRequest request = new
FederatedRequest(FederatedRequest.FedMethod.TRANSFER);
- Future<FederatedResponse> readResponse =
fd.executeFederatedOperation(request, true);
- readResponses.add(new ImmutablePair<>(range,
readResponse));
- }
- else {
- throw new DMLRuntimeException("Federated matrix
read only supported on initialized FederatedData");
- }
- }
- try {
- for (Pair<FederatedRange, Future<FederatedResponse>>
readResponse : readResponses) {
- FederatedRange range = readResponse.getLeft();
- FederatedResponse response =
readResponse.getRight().get();
- // add result
- int[] beginDimsInt = range.getBeginDimsInt();
- int[] endDimsInt = range.getEndDimsInt();
- if( !response.isSuccessful() )
- throw new
DMLRuntimeException("Federated matrix read failed: " +
response.getErrorMessage());
- MatrixBlock multRes = (MatrixBlock)
response.getData();
- result.copy(beginDimsInt[0], endDimsInt[0] - 1,
- beginDimsInt[1], endDimsInt[1] - 1,
multRes, false);
- result.setNonZeros(result.getNonZeros() +
multRes.getNonZeros());
- }
- }
- catch (Exception e) {
- throw new DMLRuntimeException("Federated matrix read
failed.", e);
- }
-
- //keep returned object for future use
- acquireModify(result);
-
- return result;
- }
-
public void setUpdateType(UpdateType flag) {
_updateType = flag;
}
@@ -472,8 +406,41 @@ public class MatrixObject extends
CacheableData<MatrixBlock>
return sb.toString();
}
+ @Override
+ public MatrixBlock acquireRead() {
+ // forward call for non-federated objects
+ if( !isFederated() )
+ return super.acquireRead();
+
+ long[] dims = getDataCharacteristics().getDims();
+ // TODO sparse optimization
+ MatrixBlock result = new MatrixBlock((int) dims[0], (int)
dims[1], false);
+ List<Pair<FederatedRange, Future<FederatedResponse>>>
readResponses = requestFederatedData(_fedMapping);
+ try {
+ for (Pair<FederatedRange, Future<FederatedResponse>>
readResponse : readResponses) {
+ FederatedRange range = readResponse.getLeft();
+ FederatedResponse response =
readResponse.getRight().get();
+ // add result
+ int[] beginDimsInt = range.getBeginDimsInt();
+ int[] endDimsInt = range.getEndDimsInt();
+ if( !response.isSuccessful() )
+ throw new
DMLRuntimeException("Federated matrix read failed: " +
response.getErrorMessage());
+ MatrixBlock multRes = (MatrixBlock)
response.getData()[0];
+ result.copy(beginDimsInt[0], endDimsInt[0] - 1,
+ beginDimsInt[1], endDimsInt[1] - 1,
multRes, false);
+ result.setNonZeros(result.getNonZeros() +
multRes.getNonZeros());
+ }
+ }
+ catch (Exception e) {
+ throw new DMLRuntimeException("Federated matrix read
failed.", e);
+ }
+
+ //keep returned object for future use
+ acquireModify(result);
+
+ return result;
+ }
-
// *********************************************
// *** ***
// *** LOW-LEVEL PROTECTED METHODS ***
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
index 9adf089..32a3457 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
@@ -32,7 +32,7 @@ import io.netty.handler.codec.serialization.ClassResolvers;
import io.netty.handler.codec.serialization.ObjectDecoder;
import io.netty.handler.codec.serialization.ObjectEncoder;
import io.netty.util.concurrent.Promise;
-
+import org.apache.sysds.common.Types;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -41,6 +41,7 @@ import java.util.concurrent.Future;
public class FederatedData {
+ private Types.DataType _dataType;
private InetSocketAddress _address;
private String _filepath;
/**
@@ -50,7 +51,8 @@ public class FederatedData {
private int _nrThreads =
Integer.parseInt(DMLConfig.DEFAULT_NUMBER_OF_FEDERATED_WORKER_THREADS);
- public FederatedData(InetSocketAddress address, String filepath) {
+ public FederatedData(Types.DataType dataType, InetSocketAddress
address, String filepath) {
+ _dataType = dataType;
_address = address;
_filepath = filepath;
}
@@ -61,7 +63,7 @@ public class FederatedData {
* @param varID the varID of the variable we refer to
*/
public FederatedData(FederatedData other, long varID) {
- this(other._address, other._filepath);
+ this(other._dataType, other._address, other._filepath);
_varID = varID;
}
@@ -82,12 +84,22 @@ public class FederatedData {
}
public synchronized Future<FederatedResponse> initFederatedData() {
- if( isInitialized() )
+ if(isInitialized())
throw new DMLRuntimeException("Tried to init already
initialized data");
- FederatedRequest request = new
FederatedRequest(FederatedRequest.FedMethod.READ);
+ FederatedRequest.FedMethod fedMethod;
+ switch(_dataType) {
+ case MATRIX:
+ fedMethod =
FederatedRequest.FedMethod.READ_MATRIX;
+ break;
+ case FRAME:
+ fedMethod =
FederatedRequest.FedMethod.READ_FRAME;
+ break;
+ default:
+ throw new DMLRuntimeException("Federated
datatype \"" + _dataType.toString() + "\" is not supported.");
+ }
+ FederatedRequest request = new FederatedRequest(fedMethod);
request.appendParam(_filepath);
return executeFederatedOperation(request);
-
}
/**
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
index 17b9180..8e59431 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
@@ -28,7 +28,7 @@ public class FederatedRequest implements Serializable {
private static final long serialVersionUID = 5946781306963870394L;
public enum FedMethod {
- READ, MATVECMULT, TRANSFER, AGGREGATE, SCALAR
+ READ_MATRIX, READ_FRAME, MATVECMULT, TRANSFER, AGGREGATE, SCALAR
}
private FedMethod _method;
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
index acd730b..6032984 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
@@ -31,28 +31,35 @@ public class FederatedResponse implements Serializable {
}
private FederatedResponse.Type _status;
- private Object _data;
+ private Object[] _data;
public FederatedResponse(FederatedResponse.Type status) {
this(status, null);
}
- public FederatedResponse(FederatedResponse.Type status, Object data) {
+ public FederatedResponse(FederatedResponse.Type status, Object[] data) {
_status = status;
_data = data;
if( _status == FederatedResponse.Type.SUCCESS && data == null )
_status = FederatedResponse.Type.SUCCESS_EMPTY;
}
+ public FederatedResponse(FederatedResponse.Type status, Object data) {
+ _status = status;
+ _data = new Object[] {data};
+ if(_status == FederatedResponse.Type.SUCCESS && data == null)
+ _status = FederatedResponse.Type.SUCCESS_EMPTY;
+ }
+
public boolean isSuccessful() {
return _status != FederatedResponse.Type.ERROR;
}
public String getErrorMessage() {
- return (String) _data;
+ return (String) _data[0];
}
- public Object getData() {
+ public Object[] getData() {
return _data;
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
index 13a1a91..afed54b 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
@@ -31,9 +31,9 @@ import io.netty.handler.codec.serialization.ClassResolvers;
import io.netty.handler.codec.serialization.ObjectDecoder;
import io.netty.handler.codec.serialization.ObjectEncoder;
import org.apache.log4j.Logger;
-import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
-import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
+import org.apache.sysds.runtime.instructions.cp.Data;
import java.util.HashMap;
import java.util.Map;
@@ -44,7 +44,7 @@ public class FederatedWorker {
private int _port;
private int _nrThreads =
Integer.parseInt(DMLConfig.DEFAULT_NUMBER_OF_FEDERATED_WORKER_THREADS);
private IDSequence _seq = new IDSequence();
- private Map<Long, CacheableData<?>> _vars = new HashMap<>();
+ private Map<Long, Data> _vars = new HashMap<>();
public FederatedWorker(int port) {
_port = (port == -1) ?
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 36bd258..81c383e 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -27,13 +27,13 @@ import org.apache.commons.lang.exception.ExceptionUtils;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.log4j.Logger;
-import org.apache.wink.json4j.JSONObject;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.caching.TensorObject;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
@@ -53,6 +53,7 @@ import
org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.utils.JSONHelper;
+import org.apache.wink.json4j.JSONObject;
import java.io.BufferedReader;
import java.io.InputStreamReader;
@@ -63,9 +64,9 @@ public class FederatedWorkerHandler extends
ChannelInboundHandlerAdapter {
protected static Logger log =
Logger.getLogger(FederatedWorkerHandler.class);
private final IDSequence _seq;
- private Map<Long, CacheableData<?>> _vars;
+ private Map<Long, Data> _vars;
- public FederatedWorkerHandler(IDSequence seq, Map<Long,
CacheableData<?>> _vars2) {
+ public FederatedWorkerHandler(IDSequence seq, Map<Long, Data> _vars2) {
_seq = seq;
_vars = _vars2;
}
@@ -93,8 +94,10 @@ public class FederatedWorkerHandler extends
ChannelInboundHandlerAdapter {
FederatedRequest.FedMethod method = request.getMethod();
try {
switch (method) {
- case READ:
- return readMatrix(request);
+ case READ_MATRIX:
+ return readData(request,
Types.DataType.MATRIX);
+ case READ_FRAME:
+ return readData(request,
Types.DataType.FRAME);
case MATVECMULT:
return executeMatVecMult(request);
case TRANSFER:
@@ -113,18 +116,30 @@ public class FederatedWorkerHandler extends
ChannelInboundHandlerAdapter {
}
}
- private FederatedResponse readMatrix(FederatedRequest request) {
+ private FederatedResponse readData(FederatedRequest request,
Types.DataType dataType) {
checkNumParams(request.getNumParams(), 1);
String filename = (String) request.getParam(0);
- return readMatrix(filename);
+ return readData(filename, dataType);
}
- private FederatedResponse readMatrix(String filename) {
+ private FederatedResponse readData(String filename, Types.DataType
dataType) {
MatrixCharacteristics mc = new MatrixCharacteristics();
mc.setBlocksize(ConfigurationManager.getBlocksize());
- MatrixObject mo = new MatrixObject(Types.ValueType.FP64,
filename);
- OutputInfo oi = null;
- InputInfo ii = null;
+ CacheableData<?> cd;
+ switch (dataType) {
+ case MATRIX:
+ cd = new MatrixObject(Types.ValueType.FP64,
filename);
+ break;
+ case FRAME:
+ cd = new FrameObject(filename);
+ break;
+ default:
+ // should NEVER happen (if we keep request
codes in sync with actual behaviour)
+ return new
FederatedResponse(FederatedResponse.Type.ERROR, "Could not recognize datatype");
+ }
+
+ OutputInfo oi;
+ InputInfo ii;
// read metadata
try {
String mtdname =
DataExpression.getMTDFileName(filename);
@@ -146,13 +161,17 @@ public class FederatedWorkerHandler extends
ChannelInboundHandlerAdapter {
throw new DMLRuntimeException(ex);
}
MetaDataFormat mdf = new MetaDataFormat(mc, oi, ii);
- mo.setMetaData(mdf);
- mo.acquireRead();
- mo.refreshMetaData();
- mo.release();
+ cd.setMetaData(mdf);
+ cd.acquireRead();
+ cd.refreshMetaData();
+ cd.release();
long id = _seq.getNextID();
- _vars.put(id, mo);
+ _vars.put(id, cd);
+ if (dataType == Types.DataType.FRAME) {
+ FrameObject frameObject = (FrameObject) cd;
+ return new
FederatedResponse(FederatedResponse.Type.SUCCESS, new Object[] {id,
frameObject.getSchema()});
+ }
return new FederatedResponse(FederatedResponse.Type.SUCCESS,
id);
}
@@ -192,6 +211,9 @@ public class FederatedWorkerHandler extends
ChannelInboundHandlerAdapter {
case MATRIX:
return new
FederatedResponse(FederatedResponse.Type.SUCCESS,
((MatrixObject)
dataObject).acquireReadAndRelease());
+ case FRAME:
+ return new
FederatedResponse(FederatedResponse.Type.SUCCESS,
+ ((FrameObject)
dataObject).acquireReadAndRelease());
case LIST:
return new
FederatedResponse(FederatedResponse.Type.SUCCESS, ((ListObject)
dataObject).getData());
// TODO rest of the possible datatypes
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/LibFederatedAgg.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/LibFederatedAgg.java
index 7e72eee..1f40221 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/LibFederatedAgg.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/LibFederatedAgg.java
@@ -68,7 +68,7 @@ public class LibFederatedAgg
int[] beginDims = range.getBeginDimsInt();
if (!federatedResponse.isSuccessful())
throw new
DMLRuntimeException("Federated aggregation failed: " +
federatedResponse.getErrorMessage());
- MatrixBlock mb = (MatrixBlock)
federatedResponse.getData();
+ MatrixBlock mb = (MatrixBlock)
federatedResponse.getData()[0];
// TODO performance optimizations
MatrixValue.CellIndex cellIndex = new
MatrixValue.CellIndex(0, 0);
ValueFunction valueFn =
operator.aggOp.increOp.fn;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 632e507..caed372 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -211,7 +211,7 @@ public class AggregateBinaryFEDInstruction extends
BinaryFEDInstruction {
FederatedResponse federatedResponse, MatrixBlock resultBlock,
boolean matrixVectorOp)
{
int[] beginDims = range.getBeginDimsInt();
- MatrixBlock mb = (MatrixBlock) federatedResponse.getData();
+ MatrixBlock mb = (MatrixBlock) federatedResponse.getData()[0];
// TODO performance optimizations
// TODO Improve Vector Matrix multiplication accuracy: An idea
would be to make use of kahan plus here,
// this should improve accuracy a bit, although we still lose
out on the small error lost on the worker
@@ -323,8 +323,8 @@ public class AggregateBinaryFEDInstruction extends
BinaryFEDInstruction {
// TODO experiment if sending multiple requests
at the same time to the same worker increases
// performance (remove get and do
multithreaded?)
FederatedResponse response =
executeMVMultiply(_range, _data, vec, _distributeCols).get();
- if (response.isSuccessful()) {
- result.copy(r, r, 0, endDims[1] -
beginDims[1] - 1, (MatrixBlock) response.getData(), true);
+ if(response.isSuccessful()) {
+ result.copy(r, r, 0, endDims[1] -
beginDims[1] - 1, (MatrixBlock) response.getData()[0], true);
}
else
throw new DMLRuntimeException(
@@ -366,8 +366,8 @@ public class AggregateBinaryFEDInstruction extends
BinaryFEDInstruction {
// TODO experiment if sending multiple requests
at the same time to the same worker increases
// performance
FederatedResponse response =
executeMVMultiply(_range, _data, vec, _distributeCols).get();
- if (response.isSuccessful()) {
- result.copy(0, endDims[0] -
beginDims[0] - 1, c, c, (MatrixBlock) response.getData(), true);
+ if(response.isSuccessful()) {
+ result.copy(0, endDims[0] -
beginDims[0] - 1, c, c, (MatrixBlock) response.getData()[0], true);
}
else
throw new DMLRuntimeException(
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
index aaca557..74fd1ab 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
@@ -24,10 +24,12 @@ import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.controlprogram.federated.*;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
-import org.apache.sysds.runtime.instructions.fed.BinaryFEDInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
@@ -77,7 +79,7 @@ public class BinaryMatrixScalarFEDInstruction extends
BinaryFEDInstruction
if (!federatedResponse.isSuccessful())
throw new
DMLRuntimeException("Federated binary operation failed: " +
federatedResponse.getErrorMessage());
- MatrixBlock shard = (MatrixBlock)
federatedResponse.getData();
+ MatrixBlock shard = (MatrixBlock)
federatedResponse.getData()[0];
ret.copy(range.getBeginDimsInt()[0],
range.getEndDimsInt()[0]-1,
range.getBeginDimsInt()[1],
range.getEndDimsInt()[1]-1, shard, false);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
index 024a554..5e0eb9c 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
@@ -24,6 +24,7 @@ import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
@@ -42,16 +43,22 @@ import java.net.MalformedURLException;
import java.net.URL;
import java.net.UnknownHostException;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.Future;
public class InitFEDInstruction extends FEDInstruction {
- private CPOperand _addresses, _ranges, _output;
+
+ public static final String FED_MATRIX_IDENTIFIER = "matrix";
+ public static final String FED_FRAME_IDENTIFIER = "frame";
- public InitFEDInstruction(CPOperand addresses, CPOperand ranges,
CPOperand out, String opcode, String instr) {
+ private CPOperand _type, _addresses, _ranges, _output;
+
+ public InitFEDInstruction(CPOperand type, CPOperand addresses,
CPOperand ranges, CPOperand out, String opcode, String instr) {
super(FEDType.Init, opcode, instr);
+ _type = type;
_addresses = addresses;
_ranges = ranges;
_output = out;
@@ -59,21 +66,23 @@ public class InitFEDInstruction extends FEDInstruction {
public static InitFEDInstruction parseInstruction(String str) {
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
- // We need 4 parts: Opcode, Addresses (list of Strings with
+ // We need 5 parts: Opcode, Type (Frame/Matrix), Addresses
(list of Strings with
// url/ip:port/filepath), ranges and the output Operand
- if (parts.length != 4)
+ if (parts.length != 5)
throw new DMLRuntimeException("Invalid number of
operands in federated instruction: " + str);
String opcode = parts[0];
- CPOperand addresses, ranges, out;
- addresses = new CPOperand(parts[1]);
- ranges = new CPOperand(parts[2]);
- out = new CPOperand(parts[3]);
- return new InitFEDInstruction(addresses, ranges, out, opcode,
str);
+ CPOperand type, addresses, ranges, out;
+ type = new CPOperand(parts[1]);
+ addresses = new CPOperand(parts[2]);
+ ranges = new CPOperand(parts[3]);
+ out = new CPOperand(parts[4]);
+ return new InitFEDInstruction(type, addresses, ranges, out,
opcode, str);
}
@Override
public void processInstruction(ExecutionContext ec) {
+ String type = ec.getScalarInput(_type).getStringValue();
ListObject addresses = ec.getListObject(_addresses.getName());
ListObject ranges = ec.getListObject(_ranges.getName());
List<Pair<FederatedRange, FederatedData>> feds = new
ArrayList<>();
@@ -81,7 +90,15 @@ public class InitFEDInstruction extends FEDInstruction {
if (addresses.getLength() * 2 != ranges.getLength())
throw new DMLRuntimeException("Federated read needs
twice the amount of addresses as ranges "
+ "(begin and end): addresses=" +
addresses.getLength() + " ranges=" + ranges.getLength());
-
+
+ Types.DataType fedDataType;
+ if (type.equalsIgnoreCase(FED_MATRIX_IDENTIFIER))
+ fedDataType = Types.DataType.MATRIX;
+ else if (type.equalsIgnoreCase(FED_FRAME_IDENTIFIER))
+ fedDataType = Types.DataType.FRAME;
+ else
+ throw new DMLRuntimeException("type \"" + type + "\"
non valid federated type");
+
long[] usedDims = new long[] { 0, 0 };
for (int i = 0; i < addresses.getLength(); i++) {
Data addressData = addresses.getData().get(i);
@@ -112,7 +129,7 @@ public class InitFEDInstruction extends FEDInstruction {
usedDims[0] = Math.max(usedDims[0], endDims[0]);
usedDims[1] = Math.max(usedDims[1], endDims[1]);
try {
- FederatedData federatedData = new
FederatedData(
+ FederatedData federatedData = new
FederatedData(fedDataType,
new
InetSocketAddress(InetAddress.getByName(host), port), filePath);
feds.add(new ImmutablePair<>(new
FederatedRange(beginDims, endDims), federatedData));
}
@@ -125,9 +142,22 @@ public class InitFEDInstruction extends FEDInstruction {
throw new DMLRuntimeException("federated
instruction only takes strings as addresses");
}
}
- MatrixObject output = ec.getMatrixObject(_output);
-
output.getDataCharacteristics().setRows(usedDims[0]).setCols(usedDims[1]);
- federate(output, feds);
+ if (type.equalsIgnoreCase(FED_MATRIX_IDENTIFIER)) {
+ MatrixObject output = ec.getMatrixObject(_output);
+
output.getDataCharacteristics().setRows(usedDims[0]).setCols(usedDims[1]);
+ federateMatrix(output, feds);
+ }
+ else if (type.equalsIgnoreCase(FED_FRAME_IDENTIFIER)) {
+ if (usedDims[1] > Integer.MAX_VALUE)
+ throw new DMLRuntimeException("federated Frame
can not have more than max int columns, because the " +
+ "schema can only be max int
length");
+ FrameObject output = ec.getFrameObject(_output);
+
output.getDataCharacteristics().setRows(usedDims[0]).setCols(usedDims[1]);
+ federateFrame(output, feds);
+ }
+ else {
+ throw new DMLRuntimeException("type \"" + type + "\"
non valid federated type");
+ }
}
public static String[] parseURL(String input) {
@@ -170,10 +200,9 @@ public class InitFEDInstruction extends FEDInstruction {
}
}
- public void federate(MatrixObject output, List<Pair<FederatedRange,
FederatedData>> workers) {
+ public void federateMatrix(MatrixObject output,
List<Pair<FederatedRange, FederatedData>> workers) {
Map<FederatedRange, FederatedData> fedMapping = new TreeMap<>();
for (Pair<FederatedRange, FederatedData> t : workers) {
- // TODO support all value types
fedMapping.put(t.getLeft(), t.getRight());
}
List<Pair<FederatedData, Future<FederatedResponse>>>
idResponses = new ArrayList<>();
@@ -187,6 +216,7 @@ public class InitFEDInstruction extends FEDInstruction {
for (int i = 0; i < dims.length; i++) {
dims[i] = endDims[i] - beginDims[i];
}
+ // TODO check if all matrices have the same
DataType (currently only double is supported)
idResponses.add(new ImmutablePair<>(value,
value.initFederatedData()));
}
}
@@ -194,7 +224,7 @@ public class InitFEDInstruction extends FEDInstruction {
for (Pair<FederatedData, Future<FederatedResponse>>
idResponse : idResponses) {
FederatedResponse response =
idResponse.getRight().get();
if (response.isSuccessful())
- idResponse.getLeft().setVarID((Long)
response.getData());
+ idResponse.getLeft().setVarID((Long)
response.getData()[0]);
else
throw new
DMLRuntimeException(response.getErrorMessage());
}
@@ -205,4 +235,65 @@ public class InitFEDInstruction extends FEDInstruction {
output.getDataCharacteristics().setNonZeros(output.getNumColumns() *
output.getNumRows());
output.setFedMapping(fedMapping);
}
+
+ public void federateFrame(FrameObject output, List<Pair<FederatedRange,
FederatedData>> workers) {
+ Map<FederatedRange, FederatedData> fedMapping = new TreeMap<>();
+ for (Pair<FederatedRange, FederatedData> t : workers) {
+ fedMapping.put(t.getLeft(), t.getRight());
+ }
+ // we want to wait for the futures with the response containing
varIDs and the schemas of the frames
+ // on the distributed workers. We need the FederatedData, the
starting column of the sub frame (for the schema)
+ // and the future for the response
+ List<Pair<FederatedData, Pair<Integer,
Future<FederatedResponse>>>> idResponses = new ArrayList<>();
+ for (Map.Entry<FederatedRange, FederatedData> entry :
fedMapping.entrySet()) {
+ FederatedRange range = entry.getKey();
+ FederatedData value = entry.getValue();
+ if (!value.isInitialized()) {
+ long[] beginDims = range.getBeginDims();
+ long[] endDims = range.getEndDims();
+ long[] dims =
output.getDataCharacteristics().getDims();
+ for (int i = 0; i < dims.length; i++) {
+ dims[i] = endDims[i] - beginDims[i];
+ }
+ idResponses.add(new ImmutablePair<>(value, new
ImmutablePair<>((int) beginDims[1], value.initFederatedData())));
+ }
+ }
+ // columns are definitely in int range, because we throw an
DMLRuntime Exception in `processInstruction` else
+ Types.ValueType[] schema = new Types.ValueType[(int)
output.getNumColumns()];
+ Arrays.fill(schema, Types.ValueType.UNKNOWN);
+ try {
+ for (Pair<FederatedData, Pair<Integer,
Future<FederatedResponse>>> idResponse : idResponses) {
+ FederatedData fedData = idResponse.getLeft();
+ FederatedResponse response =
idResponse.getRight().getRight().get();
+ int startCol = idResponse.getRight().getLeft();
+ handleFedFrameResponse(schema, fedData,
response, startCol);
+ }
+ }
+ catch (Exception e) {
+ throw new DMLRuntimeException("Federation
initialization failed", e);
+ }
+
output.getDataCharacteristics().setNonZeros(output.getNumColumns() *
output.getNumRows());
+ output.setSchema(schema);
+ output.setFedMapping(fedMapping);
+ }
+
+ private static void handleFedFrameResponse(Types.ValueType[] schema,
FederatedData federatedData,
+ FederatedResponse response, int startColumn) {
+ if(response.isSuccessful()) {
+ // Index 0 is the varID, Index 1 is the schema of the
frame
+ federatedData.setVarID((Long) response.getData()[0]);
+ // copy the
+ Types.ValueType[] range_schema = (Types.ValueType[])
response.getData()[1];
+ for(int i = 0; i < range_schema.length; i++) {
+ Types.ValueType vType = range_schema[i];
+ int schema_index = startColumn + i;
+ if(schema[schema_index] !=
Types.ValueType.UNKNOWN && schema[schema_index] != vType)
+ throw new
DMLRuntimeException("federated Frame schemas mismatch");
+ else
+ schema[schema_index] = vType;
+ }
+ }
+ else
+ throw new
DMLRuntimeException(response.getErrorMessage());
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
index 1185fd2..8b803dc 100644
--- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
+++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
@@ -19,20 +19,15 @@
package org.apache.sysds.runtime.util;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.BitSet;
-import java.util.HashSet;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Set;
-import java.util.stream.Stream;
-import java.util.stream.StreamSupport;
-
import org.apache.commons.lang.ArrayUtils;
+import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.math3.random.RandomDataGenerator;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.TensorIndexes;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
@@ -41,6 +36,18 @@ import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.Pair;
import org.apache.sysds.runtime.meta.TensorCharacteristics;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.BitSet;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.Future;
+import java.util.stream.Stream;
+import java.util.stream.StreamSupport;
+
public class UtilFunctions
{
//for accurate cast of double values to int and long
@@ -819,4 +826,23 @@ public class UtilFunctions
break;
}
}
+
+ public static List<org.apache.commons.lang3.tuple.Pair<FederatedRange,
Future<FederatedResponse>>> requestFederatedData(
+ Map<FederatedRange, FederatedData> fedMapping) {
+ List<org.apache.commons.lang3.tuple.Pair<FederatedRange,
Future<FederatedResponse>>> readResponses = new ArrayList<>();
+ for(Map.Entry<FederatedRange, FederatedData> entry :
fedMapping.entrySet()) {
+ FederatedRange range = entry.getKey();
+ FederatedData fd = entry.getValue();
+
+ if(fd.isInitialized()) {
+ FederatedRequest request = new
FederatedRequest(FederatedRequest.FedMethod.TRANSFER);
+ Future<FederatedResponse> readResponse =
fd.executeFederatedOperation(request, true);
+ readResponses.add(new ImmutablePair<>(range,
readResponse));
+ }
+ else {
+ throw new DMLRuntimeException("Federated matrix
read only supported on initialized FederatedData");
+ }
+ }
+ return readResponses;
+ }
}
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 6d1b396..88b6c0f 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -19,20 +19,6 @@
package org.apache.sysds.test;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
-
-import java.io.File;
-import java.io.IOException;
-import java.io.OutputStream;
-import java.io.PrintStream;
-import java.net.ServerSocket;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.HashMap;
-
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.spark.sql.SparkSession;
@@ -43,14 +29,14 @@ import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.parser.DataExpression;
-import org.apache.sysds.common.Types.DataType;
-import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
@@ -59,8 +45,8 @@ import org.apache.sysds.runtime.io.FrameReaderFactory;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.InputInfo;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import org.apache.sysds.runtime.matrix.data.OutputInfo;
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.runtime.matrix.data.OutputInfo;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.privacy.PrivacyConstraint;
import org.apache.sysds.runtime.util.DataConverter;
@@ -68,6 +54,20 @@ import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.utils.ParameterBuilder;
import org.apache.sysds.utils.Statistics;
+import java.io.File;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.PrintStream;
+import java.net.ServerSocket;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
/**
* <p>
* Extend this class to easily
@@ -1335,6 +1335,19 @@ public abstract class AutomatedTestBase {
}
}
}
+
+ /**
+ * <p>
+ * Compares the results of the computation of the frame with the
expected ones.
+ * </p>
+ *
+ * @param schema the frame schema
+ */
+ protected void compareResults(ValueType[] schema) {
+ for (int i = 0; i < comparisonFiles.length; i++) {
+ TestUtils.compareDMLFrameWithJavaFrame(schema,
comparisonFiles[i], outputDirectories[i]);
+ }
+ }
/**
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java
b/src/test/java/org/apache/sysds/test/TestUtils.java
index 80e0328..5383d6e 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -238,7 +238,7 @@ public class TestUtils
fail("unable to read file: " + e.getMessage());
}
}
-
+
/**
* Read doubles from the input stream and put them into the given
hashmap of values.
* @param inputStream input stream of doubles with related indices
@@ -273,6 +273,60 @@ public class TestUtils
/**
* <p>
+ * Read the cell values of the expected file and actual files. Schema
is used for correct parsing if the file is a
+ * frame and if it is null FP64 will be used for all values (useful for
Matrices).
+ * </p>
+ *
+ * @param schema the schema of the frame, can be null (for FP64)
+ * @param expectedFile the file with expected values
+ * @param actualDir the directory where the actual values were
written
+ * @param expectedValues the HashMap where the expected values will be
written to
+ * @param actualValues the HashMap where the actual values will be
written to
+ */
+ private static void readActualAndExpectedFile(ValueType[] schema,
String expectedFile, String actualDir,
+ HashMap<CellIndex, Object> expectedValues, HashMap<CellIndex,
Object> actualValues) {
+ try {
+ Path outDirectory = new Path(actualDir);
+ Path compareFile = new Path(expectedFile);
+ FileSystem fs =
IOUtilFunctions.getFileSystem(outDirectory, conf);
+ FSDataInputStream fsin = fs.open(compareFile);
+
+ try(BufferedReader compareIn = new BufferedReader(new
InputStreamReader(fsin))) {
+ String line;
+ while((line = compareIn.readLine()) != null) {
+ StringTokenizer st = new
StringTokenizer(line, " ");
+ int i =
Integer.parseInt(st.nextToken());
+ int j =
Integer.parseInt(st.nextToken());
+ ValueType vt = (schema != null) ?
schema[j - 1] : ValueType.FP64;
+ Object obj =
UtilFunctions.stringToObject(vt, st.nextToken());
+ expectedValues.put(new CellIndex(i, j),
obj);
+ }
+ }
+
+ FileStatus[] outFiles = fs.listStatus(outDirectory);
+
+ for(FileStatus file : outFiles) {
+ FSDataInputStream fsout =
fs.open(file.getPath());
+ try(BufferedReader outIn = new
BufferedReader(new InputStreamReader(fsout))) {
+ String line;
+ while((line = outIn.readLine()) !=
null) {
+ StringTokenizer st = new
StringTokenizer(line, " ");
+ int i =
Integer.parseInt(st.nextToken());
+ int j =
Integer.parseInt(st.nextToken());
+ ValueType vt = (schema != null)
? schema[j - 1] : ValueType.FP64;
+ Object obj =
UtilFunctions.stringToObject(vt, st.nextToken());
+ actualValues.put(new
CellIndex(i, j), obj);
+ }
+ }
+ }
+ }
+ catch(IOException e) {
+ fail("unable to read file: " + e.getMessage());
+ }
+ }
+
+ /**
+ * <p>
* Compares the expected values calculated in Java by testcase and
which are
* in the normal filesystem, with those calculated by SystemDS located
in
* HDFS
@@ -287,41 +341,61 @@ public class TestUtils
*/
@SuppressWarnings("resource")
public static void compareDMLMatrixWithJavaMatrix(String expectedFile,
String actualDir, double epsilon) {
- try {
- Path outDirectory = new Path(actualDir);
- Path compareFile = new Path(expectedFile);
- FileSystem fs =
IOUtilFunctions.getFileSystem(outDirectory, conf);
-
- FSDataInputStream fsin = fs.open(compareFile);
- HashMap<CellIndex, Double> expectedValues = new
HashMap<>();
- readValuesFromFileStream(fsin, expectedValues);
-
- HashMap<CellIndex, Double> actualValues = new
HashMap<>();
- FileStatus[] outFiles = fs.listStatus(outDirectory);
+ HashMap<CellIndex, Object> expectedValues = new HashMap<>();
+ HashMap<CellIndex, Object> actualValues = new HashMap<>();
- for (FileStatus file : outFiles) {
- FSDataInputStream fsout =
fs.open(file.getPath());
- readValuesFromFileStream(fsout, actualValues);
+ readActualAndExpectedFile(null, expectedFile, actualDir,
expectedValues, actualValues);
+
+ int countErrors = 0;
+ for(CellIndex index : expectedValues.keySet()) {
+ Double expectedValue = (Double)
expectedValues.get(index);
+ Double actualValue = (Double) actualValues.get(index);
+ if(expectedValue == null)
+ expectedValue = 0.0;
+ if(actualValue == null)
+ actualValue = 0.0;
+
+ if(!compareCellValue(expectedValue, actualValue,
epsilon, false)) {
+ System.out.println(
+ expectedFile + ": " + index + "
mismatch: expected " + expectedValue + ", actual " + actualValue);
+ countErrors++;
}
+ }
+ assertEquals("for file " + actualDir + " " + countErrors + "
values are not equal", 0, countErrors);
+ }
+
+ /**
+ * <p>
+ * Compares the expected values calculated in Java by testcase and
which are
+ * in the normal filesystem, with those calculated by SystemDS located
in
+ * HDFS
+ * </p>
+ *
+ * @param expectedFile
+ * file with expected values, which is located in OS
filesystem
+ * @param actualDir
+ * file with actual values, which is located in HDFS
+ */
+ @SuppressWarnings("resource")
+ public static void compareDMLFrameWithJavaFrame(ValueType[] schema,
String expectedFile, String actualDir) {
+ HashMap<CellIndex, Object> expectedValues = new HashMap<>();
+ HashMap<CellIndex, Object> actualValues = new HashMap<>();
- int countErrors = 0;
- for (CellIndex index : expectedValues.keySet()) {
- Double expectedValue =
expectedValues.get(index);
- Double actualValue = actualValues.get(index);
- if (expectedValue == null)
- expectedValue = 0.0;
- if (actualValue == null)
- actualValue = 0.0;
+ readActualAndExpectedFile(schema, expectedFile, actualDir,
expectedValues, actualValues);
- if (!compareCellValue(expectedValue,
actualValue, epsilon, false)) {
- System.out.println(expectedFile+":
"+index+" mismatch: expected " + expectedValue + ", actual " + actualValue);
- countErrors++;
- }
+ int countErrors = 0;
+ for(CellIndex index : expectedValues.keySet()) {
+ Object expectedValue = expectedValues.get(index);
+ Object actualValue = actualValues.get(index);
+
+ int j = index.column;
+ if(UtilFunctions.compareTo(schema[j - 1],
expectedValue, actualValue) != 0) {
+ System.out.println(
+ expectedFile + ": " + index + "
mismatch: expected " + expectedValue + ", actual " + actualValue);
+ countErrors++;
}
- assertTrue("for file " + actualDir + " " + countErrors
+ " values are not equal", countErrors == 0);
- } catch (IOException e) {
- fail("unable to read file: " + e.getMessage());
}
+ assertEquals("for file " + actualDir + " " + countErrors + "
values are not equal", 0, countErrors);
}
public static void compareTensorBlocks(TensorBlock tb1, TensorBlock
tb2) {
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedConstructionTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedConstructionTest.java
index 878e2d8..75a41f5 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedConstructionTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedConstructionTest.java
@@ -19,27 +19,34 @@
package org.apache.sysds.test.functions.federated;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.matrix.data.OutputInfo;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import java.io.IOException;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
public class FederatedConstructionTest extends AutomatedTestBase {
-
+
private final static String TEST_DIR = "functions/federated/";
private final static String TEST_NAME = "FederatedConstructionTest";
private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedConstructionTest.class.getSimpleName() + "/";
-
+ public static final String MATRIX_TEST_FILE_NAME =
"FederatedMatrixConstructionTest";
+ public static final String FRAME_TEST_FILE_NAME =
"FederatedFrameConstructionTest";
+
private int blocksize = 1024;
private int rows, cols;
@@ -51,7 +58,8 @@ public class FederatedConstructionTest extends
AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
- Object[][] data = new Object[][] {{1, 1000}, {10, 100}, {100,
10}, {1000, 1}, {10, 2000}, {2000, 10}};
+ // cols have to be divisible by 4 for Frame tests: the frame schema uses cols/4 columns for each of four value types
+ Object[][] data = new Object[][] {{1, 1024}, {8, 256}, {256,
8}, {1024, 4}, {16, 2048}, {2048, 32}};
return Arrays.asList(data);
}
@@ -62,28 +70,59 @@ public class FederatedConstructionTest extends
AutomatedTestBase {
}
@Test
- public void federatedConstructionCP() {
- federatedConstruction(Types.ExecMode.SINGLE_NODE);
+ public void federatedMatrixConstructionCP() {
+ federatedMatrixConstruction(Types.ExecMode.SINGLE_NODE);
+ }
+
+ @Test
+ public void federatedMatrixConstructionSP() {
+ federatedMatrixConstruction(Types.ExecMode.SPARK);
}
+ public void federatedMatrixConstruction(Types.ExecMode execMode) {
+ getAndLoadTestConfiguration(TEST_NAME);
+ // write input matrix
+ double[][] A = getRandomMatrix(rows, cols, -1, 1, 1, 1234);
+ writeInputMatrixWithMTD("A", A, false, new
MatrixCharacteristics(rows, cols, blocksize, rows * cols));
+ federatedConstruction(execMode, MATRIX_TEST_FILE_NAME, "A",
null);
+ }
+
@Test
- public void federatedConstructionSP() {
- federatedConstruction(Types.ExecMode.SPARK);
+ public void federatedFrameConstructionCP() throws IOException {
+ federatedFrameConstruction(Types.ExecMode.SINGLE_NODE);
+ }
+
+ /* like other federated functionality, SPARK execution mode is not yet
working (waiting for better integration
+ of federated instruction building, e.g. propagating the information that
the object is federated)
+ @Test
+ public void federatedFrameConstructionSP() throws IOException {
+ federatedFrameConstruction(Types.ExecMode.SPARK);
+ }*/
+
+ public void federatedFrameConstruction(Types.ExecMode execMode) throws
IOException {
+ getAndLoadTestConfiguration(TEST_NAME);
+ // generate random input data; it is written below as a frame using the mixed schema built here
+ double[][] A = getRandomMatrix(rows, cols, -1, 1, 1, 1234);
+
+ List<Types.ValueType> schemaList = new
ArrayList<>(Collections.nCopies(cols/4, Types.ValueType.STRING));
+ schemaList.addAll(Collections.nCopies(cols/4,
Types.ValueType.FP64));
+ schemaList.addAll(Collections.nCopies(cols/4,
Types.ValueType.INT64));
+ schemaList.addAll(Collections.nCopies(cols/4,
Types.ValueType.BOOLEAN));
+
+ Types.ValueType[] schema = new Types.ValueType[cols];
+ schemaList.toArray(schema);
+ writeInputFrameWithMTD("A", A, false, schema,
OutputInfo.BinaryBlockOutputInfo);
+ federatedConstruction(execMode, FRAME_TEST_FILE_NAME, "A",
schema);
}
- public void federatedConstruction(Types.ExecMode execMode) {
+ public void federatedConstruction(Types.ExecMode execMode, String
testFile, String inputIdentifier, Types.ValueType[] schema) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
Types.ExecMode platformOld = rtplatform;
- Thread t = null;
+ Thread t;
- getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
- // write input matrix
- double[][] A = getRandomMatrix(rows, cols, -1, 1, 1, 1234);
- writeInputMatrixWithMTD("A", A, false, new
MatrixCharacteristics(rows, cols, blocksize, rows * cols));
-
int port = getRandomAvailablePort();
t = startLocalFedWorker(port);
@@ -93,8 +132,8 @@ public class FederatedConstructionTest extends
AutomatedTestBase {
// we need the reference file to not be written to hdfs, so we
get the correct format
rtplatform = Types.ExecMode.SINGLE_NODE;
// Run reference dml script with normal matrix
- fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
- programArgs = new String[] {"-args", input("A"), expected("B")};
+ fullDMLScriptName = HOME + testFile + "Reference.dml";
+ programArgs = new String[] {"-args", input(inputIdentifier),
expected("B")};
runTest(true, false, null, -1);
// reference file should not be written to hdfs
@@ -102,13 +141,16 @@ public class FederatedConstructionTest extends
AutomatedTestBase {
if(rtplatform == Types.ExecMode.SPARK) {
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
}
- fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-args", "\"localhost:" + port +
"/" + input("A") + "\"", Integer.toString(rows),
- Integer.toString(cols), Integer.toString(rows * 2),
output("B")};
+ fullDMLScriptName = HOME + testFile + ".dml";
+ programArgs = new String[] {"-args", "\"localhost:" + port +
"/" + input(inputIdentifier) + "\"",
+ Integer.toString(rows), Integer.toString(cols),
Integer.toString(rows * 2), output("B")};
runTest(true, false, null, -1);
// compare via files
- compareResults(1e-12);
+ if (schema != null)
+ compareResults(schema);
+ else
+ compareResults(1e-12);
TestUtils.shutdownThread(t);
rtplatform = platformOld;
diff --git
a/src/test/scripts/functions/federated/FederatedConstructionTestReference.dml
b/src/test/scripts/functions/federated/FederatedFrameConstructionTest.dml
similarity index 87%
copy from
src/test/scripts/functions/federated/FederatedConstructionTestReference.dml
copy to src/test/scripts/functions/federated/FederatedFrameConstructionTest.dml
index ba8a201..df3078d 100644
---
a/src/test/scripts/functions/federated/FederatedConstructionTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedFrameConstructionTest.dml
@@ -19,5 +19,5 @@
#
#-------------------------------------------------------------
-A = rbind(read($1), read($1))
-write(A, $2)
+A = federated(type="Frame", addresses=list($1, $1), ranges=list(list(0, 0),
list($2, $3), list($2, 0), list($4, $3))) # two partitions of the same input: rows 0..$2 and $2..$4 over all $3 columns, mirroring the rbind reference
+write(A, $5) # args: $1 worker address/path, $2 rows, $3 cols, $4 = 2*rows, $5 output path
diff --git
a/src/test/scripts/functions/federated/FederatedConstructionTestReference.dml
b/src/test/scripts/functions/federated/FederatedFrameConstructionTestReference.dml
similarity index 100%
copy from
src/test/scripts/functions/federated/FederatedConstructionTestReference.dml
copy to
src/test/scripts/functions/federated/FederatedFrameConstructionTestReference.dml
diff --git a/src/test/scripts/functions/federated/FederatedConstructionTest.dml
b/src/test/scripts/functions/federated/FederatedMatrixConstructionTest.dml
similarity index 100%
rename from src/test/scripts/functions/federated/FederatedConstructionTest.dml
rename to
src/test/scripts/functions/federated/FederatedMatrixConstructionTest.dml
diff --git
a/src/test/scripts/functions/federated/FederatedConstructionTestReference.dml
b/src/test/scripts/functions/federated/FederatedMatrixConstructionTestReference.dml
similarity index 100%
rename from
src/test/scripts/functions/federated/FederatedConstructionTestReference.dml
rename to
src/test/scripts/functions/federated/FederatedMatrixConstructionTestReference.dml