sebwrede commented on a change in pull request #1075:
URL: https://github.com/apache/systemds/pull/1075#discussion_r498664209
##########
File path:
src/test/java/org/apache/sysds/test/component/paramserv/SerializationTest.java
##########
@@ -29,40 +35,90 @@
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.ProgramConverter;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+@RunWith(Parameterized.class)
public class SerializationTest {
+ private int _named;
+
+ @Parameterized.Parameters
+ public static Collection named() {
+ return Arrays.asList(new Object[][] {
+ { 0 },
+ { 1 }
+ });
+ }
+
+ public SerializationTest(Integer named) {
+ this._named = named;
+ }
@Test
- public void serializeUnnamedListObject() {
+ public void serializeListObject() {
MatrixObject mo1 = generateDummyMatrix(10);
MatrixObject mo2 = generateDummyMatrix(20);
IntObject io = new IntObject(30);
- ListObject lo = new ListObject(Arrays.asList(mo1, mo2, io));
- String serial = ProgramConverter.serializeDataObject("key", lo);
- Object[] obj = ProgramConverter.parseDataObject(serial);
- ListObject actualLO = (ListObject) obj[1];
- MatrixObject actualMO1 = (MatrixObject) actualLO.slice(0);
- MatrixObject actualMO2 = (MatrixObject) actualLO.slice(1);
- IntObject actualIO = (IntObject) actualLO.slice(2);
-
Assert.assertArrayEquals(mo1.acquireRead().getDenseBlockValues(),
actualMO1.acquireRead().getDenseBlockValues(), 0);
-
Assert.assertArrayEquals(mo2.acquireRead().getDenseBlockValues(),
actualMO2.acquireRead().getDenseBlockValues(), 0);
- Assert.assertEquals(io.getLongValue(), actualIO.getLongValue());
+ ListObject lot = new ListObject(Arrays.asList(mo2));
+ ListObject lo;
+
+ if (_named == 1)
+ lo = new ListObject(Arrays.asList(mo1, lot, io),
Arrays.asList("e1", "e2", "e3"));
+ else
+ lo = new ListObject(Arrays.asList(mo1, lot, io));
+
+ ListObject loDeserialized = null;
+
+ // serialize and back
+ try {
+ ByteArrayOutputStream bos = new ByteArrayOutputStream();
+ ObjectOutputStream out = new ObjectOutputStream(bos);
+ out.writeObject(lo);
+ out.flush();
+ byte[] loBytes = bos.toByteArray();
+
+ ByteArrayInputStream bis = new
ByteArrayInputStream(loBytes);
+ ObjectInput in = new ObjectInputStream(bis);
+ loDeserialized = (ListObject) in.readObject();
Review comment:
Nice. This is a different approach from the one I used in
Privacy/ReadWriteTest.java, but it does the job well.
##########
File path:
src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
##########
@@ -280,4 +301,124 @@ public String toString() {
sb.append(")");
return sb.toString();
}
+
+ /**
+ * Redirects the default java serialization via externalizable to our
default
+ * hadoop writable serialization for efficient broadcast/rdd
serialization.
+ *
+ * @param out object output
+ * @throws IOException if IOException occurs
+ */
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+ // write out length
+ out.writeInt(getLength());
+ // write out num cacheable
+ out.writeInt(_nCacheable);
+
+ // write out names for named list
+ out.writeBoolean(getNames() != null);
+ if(getNames() != null) {
+ for (int i = 0; i < getLength(); i++) {
+ out.writeObject(_names.get(i));
+ }
+ }
+
+ // write out data
+ for(int i = 0; i < getLength(); i++) {
+ Data d = getData(i);
+ out.writeObject(d.getDataType());
+ out.writeObject(d.getValueType());
+ switch(d.getDataType()) {
+ case LIST:
+ ListObject lo = (ListObject) d;
+ out.writeObject(lo);
+ break;
+ case MATRIX:
+ MatrixObject mo = (MatrixObject) d;
+ MetaDataFormat md = (MetaDataFormat)
mo.getMetaData();
+ DataCharacteristics dc =
md.getDataCharacteristics();
+
+ out.writeObject(dc.getRows());
+ out.writeObject(dc.getCols());
+ out.writeObject(dc.getBlocksize());
+ out.writeObject(dc.getNonZeros());
+ out.writeObject(md.getFileFormat());
+
out.writeObject(mo.acquireReadAndRelease());
+ break;
+ case SCALAR:
+ ScalarObject so = (ScalarObject) d;
+ out.writeObject(so.getStringValue());
Review comment:
By writing only the string value, you capture the most essential
information, but you risk discarding other relevant information. At the
moment, only the privacy constraints are not included. This is okay for now,
but you can easily add them by getting the privacy constraint from the
ScalarObject:
`so.getPrivacyConstraint()`
This can then be written as an object.
##########
File path:
src/test/java/org/apache/sysds/test/component/paramserv/SerializationTest.java
##########
@@ -29,40 +35,90 @@
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.ProgramConverter;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+@RunWith(Parameterized.class)
public class SerializationTest {
+ private int _named;
+
+ @Parameterized.Parameters
+ public static Collection named() {
+ return Arrays.asList(new Object[][] {
+ { 0 },
+ { 1 }
+ });
+ }
+
+ public SerializationTest(Integer named) {
+ this._named = named;
+ }
@Test
- public void serializeUnnamedListObject() {
+ public void serializeListObject() {
MatrixObject mo1 = generateDummyMatrix(10);
MatrixObject mo2 = generateDummyMatrix(20);
IntObject io = new IntObject(30);
- ListObject lo = new ListObject(Arrays.asList(mo1, mo2, io));
- String serial = ProgramConverter.serializeDataObject("key", lo);
- Object[] obj = ProgramConverter.parseDataObject(serial);
- ListObject actualLO = (ListObject) obj[1];
- MatrixObject actualMO1 = (MatrixObject) actualLO.slice(0);
- MatrixObject actualMO2 = (MatrixObject) actualLO.slice(1);
- IntObject actualIO = (IntObject) actualLO.slice(2);
-
Assert.assertArrayEquals(mo1.acquireRead().getDenseBlockValues(),
actualMO1.acquireRead().getDenseBlockValues(), 0);
-
Assert.assertArrayEquals(mo2.acquireRead().getDenseBlockValues(),
actualMO2.acquireRead().getDenseBlockValues(), 0);
- Assert.assertEquals(io.getLongValue(), actualIO.getLongValue());
+ ListObject lot = new ListObject(Arrays.asList(mo2));
+ ListObject lo;
+
+ if (_named == 1)
+ lo = new ListObject(Arrays.asList(mo1, lot, io),
Arrays.asList("e1", "e2", "e3"));
+ else
+ lo = new ListObject(Arrays.asList(mo1, lot, io));
+
+ ListObject loDeserialized = null;
+
+ // serialize and back
+ try {
+ ByteArrayOutputStream bos = new ByteArrayOutputStream();
+ ObjectOutputStream out = new ObjectOutputStream(bos);
+ out.writeObject(lo);
+ out.flush();
+ byte[] loBytes = bos.toByteArray();
+
+ ByteArrayInputStream bis = new
ByteArrayInputStream(loBytes);
+ ObjectInput in = new ObjectInputStream(bis);
+ loDeserialized = (ListObject) in.readObject();
+ }
+ catch(Exception e){
+ assert(false);
Review comment:
You may want to print the exception in some way so that the log shows
what went wrong in the test. Otherwise the test might just report an
AssertionError without explaining which error actually occurred.
##########
File path:
src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
##########
@@ -280,4 +301,124 @@ public String toString() {
sb.append(")");
return sb.toString();
}
+
+ /**
+ * Redirects the default java serialization via externalizable to our
default
+ * hadoop writable serialization for efficient broadcast/rdd
serialization.
+ *
+ * @param out object output
+ * @throws IOException if IOException occurs
+ */
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+ // write out length
+ out.writeInt(getLength());
+ // write out num cacheable
+ out.writeInt(_nCacheable);
+
+ // write out names for named list
+ out.writeBoolean(getNames() != null);
+ if(getNames() != null) {
+ for (int i = 0; i < getLength(); i++) {
+ out.writeObject(_names.get(i));
+ }
+ }
+
+ // write out data
+ for(int i = 0; i < getLength(); i++) {
+ Data d = getData(i);
+ out.writeObject(d.getDataType());
+ out.writeObject(d.getValueType());
+ switch(d.getDataType()) {
+ case LIST:
+ ListObject lo = (ListObject) d;
+ out.writeObject(lo);
+ break;
+ case MATRIX:
+ MatrixObject mo = (MatrixObject) d;
+ MetaDataFormat md = (MetaDataFormat)
mo.getMetaData();
+ DataCharacteristics dc =
md.getDataCharacteristics();
+
+ out.writeObject(dc.getRows());
+ out.writeObject(dc.getCols());
+ out.writeObject(dc.getBlocksize());
+ out.writeObject(dc.getNonZeros());
+ out.writeObject(md.getFileFormat());
+
out.writeObject(mo.acquireReadAndRelease());
+ break;
+ case SCALAR:
+ ScalarObject so = (ScalarObject) d;
+ out.writeObject(so.getStringValue());
+ break;
+ default:
+ throw new DMLRuntimeException("Unable
to serialize datatype " + dataType);
+ }
+ }
+ }
+
+ /**
+ * Redirects the default java serialization via externalizable to our
default
+ * hadoop writable serialization for efficient broadcast/rdd
deserialization.
+ *
+ * @param in object input
+ * @throws IOException if IOException occurs
+ */
+ @Override
+ public void readExternal(ObjectInput in) throws IOException,
ClassNotFoundException {
+ // read in length
+ int length = in.readInt();
+ // read in num cacheable
+ _nCacheable = in.readInt();
+
+ // read in names
+ Boolean names = in.readBoolean();
+ if(names) {
+ _names = new ArrayList<>();
+ for (int i = 0; i < length; i++) {
+ _names.add((String) in.readObject());
+ }
+ }
+
+ // read in data
+ for(int i = 0; i < length; i++) {
+ DataType dataType = (DataType) in.readObject();
+ ValueType valueType = (ValueType) in.readObject();
+ Data d;
+ switch(dataType) {
+ case LIST:
+ d = (ListObject) in.readObject();
+ break;
+ case MATRIX:
+ long rows = (long) in.readObject();
+ long cols = (long) in.readObject();
+ int blockSize = (int) in.readObject();
+ long nonZeros = (long) in.readObject();
+ Types.FileFormat fileFormat =
(Types.FileFormat) in.readObject();
+
+ // construct objects and set meta data
+ MatrixCharacteristics
matrixCharacteristics = new MatrixCharacteristics(rows, cols, blockSize,
nonZeros);
+ MetaDataFormat metaDataFormat = new
MetaDataFormat(matrixCharacteristics, fileFormat);
+ MatrixBlock matrixBlock = (MatrixBlock)
in.readObject();
+
+ d = new MatrixObject(valueType,
Dag.getNextUniqueVarname(Types.DataType.MATRIX), metaDataFormat, matrixBlock);
Review comment:
I think that this will also ignore the privacy constraints. I have a
different PR where the privacy constraints are set in the
MatrixObject(MatrixObject mo) constructor.
##########
File path:
src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
##########
@@ -91,16 +95,87 @@ public ParamservBuiltinCPInstruction(Operator op,
LinkedHashMap<String, String>
@Override
public void processInstruction(ExecutionContext ec) {
- PSModeType mode = getPSMode();
- switch (mode) {
- case LOCAL:
- runLocally(ec, mode);
- break;
- case REMOTE_SPARK:
- runOnSpark((SparkExecutionContext) ec, mode);
- break;
- default:
- throw new
DMLRuntimeException(String.format("Paramserv func: not support mode %s", mode));
+ // check if the input is federated
+ if(ec.getMatrixObject(getParam(PS_FEATURES)).isFederated() ||
+
ec.getMatrixObject(getParam(PS_LABELS)).isFederated()) {
+ runFederated(ec);
+ }
+ // if not federated check mode
+ else {
+ PSModeType mode = getPSMode();
+ switch (mode) {
+ case LOCAL:
+ runLocally(ec, mode);
+ break;
+ case REMOTE_SPARK:
+ runOnSpark((SparkExecutionContext) ec,
mode);
+ break;
+ default:
+ throw new
DMLRuntimeException(String.format("Paramserv func: not support mode %s", mode));
+ }
+ }
+ }
+
+ private void runFederated(ExecutionContext ec) {
+ System.out.println("PARAMETER SERVER");
+ System.out.println("[+] Running in federated mode");
+
+ // get inputs
+ PSFrequency freq = getFrequency();
+ PSUpdateType updateType = getUpdateType();
+ String updFunc = getParam(PS_UPDATE_FUN);
+ String aggFunc = getParam(PS_AGGREGATION_FUN);
+
+ // partition federated data
+ DataPartitionFederatedScheme.Result result = new
FederatedDataPartitioner(Statement.FederatedPSScheme.KEEP_DATA_ON_WORKER)
+
.doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)),
ec.getMatrixObject(getParam(PS_LABELS)));
+ List<MatrixObject> pFeatures = result.pFeatures;
+ List<MatrixObject> pLabels = result.pLabels;
+ int workerNum = result.workerNum;
+
+ // setup threading
+ BasicThreadFactory factory = new BasicThreadFactory.Builder()
+
.namingPattern("workers-pool-thread-%d").build();
+ ExecutorService es = Executors.newFixedThreadPool(workerNum,
factory);
+
+ // Get the compiled execution context
+ LocalVariableMap newVarsMap = createVarsMap(ec);
+ // Level of par is 1 because one worker will be launched per
task
+ // TODO: Fix recompilation
+ ExecutionContext newEC =
ParamservUtils.createExecutionContext(ec, newVarsMap, updFunc, aggFunc, 1,
true);
+ // Create workers' execution context
+ List<ExecutionContext> federatedWorkerECs =
ParamservUtils.copyExecutionContext(newEC, workerNum);
Review comment:
Will this copy the entire execution context or only the part that is
required by the workers to run the instruction on the federated data?
##########
File path:
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
##########
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.lops.compile.Dag;
+import org.apache.sysds.runtime.meta.MetaData;
+import org.apache.sysds.runtime.meta.MetaDataFormat;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+
+public class KeepDataOnWorkerFederatedScheme extends
DataPartitionFederatedScheme {
+ @Override
+ public Result doPartitioning(MatrixObject features, MatrixObject
labels) {
+ List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
+ List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+ return new Result(pFeatures, pLabels, pFeatures.size());
+ }
+
+ /**
+ * Takes a row federated Matrix and slices it into a matrix for each
worker
+ *
+ * @param fedMatrix the federated input matrix
+ */
+ private List<MatrixObject> sliceFederatedMatrix(MatrixObject fedMatrix)
{
+ if (fedMatrix.isFederated(FederationMap.FType.ROW)
+ ||
fedMatrix.isFederated(FederationMap.FType.ROW)) {
Review comment:
Why is this done twice?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]