This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 8e832ac085 [SYSTEMDS-3185] Caching of serialized federated responses
8e832ac085 is described below
commit 8e832ac085b14aa63ecd8a5baee463ac9dfa53bc
Author: ywcb00 <[email protected]>
AuthorDate: Sun May 15 17:05:53 2022 +0200
[SYSTEMDS-3185] Caching of serialized federated responses
Closes #1611.
---
.../federated/FederatedResponse.java | 21 ++++-
.../federated/FederatedStatistics.java | 72 ++++++++++++++---
.../controlprogram/federated/FederatedWorker.java | 38 +++++++++
.../federated/FederatedWorkerHandler.java | 3 +-
.../runtime/instructions/FEDInstructionParser.java | 1 -
.../instructions/fed/TsmmFEDInstruction.java | 2 -
.../apache/sysds/runtime/lineage/LineageCache.java | 64 +++++++++++++++
.../sysds/runtime/lineage/LineageCacheEntry.java | 35 +++++++-
.../sysds/runtime/lineage/LineageItemUtils.java | 7 ++
.../java/org/apache/sysds/utils/Statistics.java | 3 +-
.../FederatedLineageTraceReuseTest.java | 6 +-
...t.java => FederatedSerializationReuseTest.java} | 93 +++++++++++-----------
.../FederatedSerializationReuseTest.dml | 57 +++++++++++++
13 files changed, 329 insertions(+), 73 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
index b8cb55851c..ff059a460a 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
@@ -27,6 +27,7 @@ import java.util.concurrent.atomic.LongAdder;
import org.apache.commons.lang.exception.ExceptionUtils;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.privacy.CheckedConstraintsLog;
import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
@@ -42,23 +43,35 @@ public class FederatedResponse implements Serializable {
private ResponseType _status;
private Object[] _data;
private Map<PrivacyLevel,LongAdder> checkedConstraints;
+
+ private transient LineageItem _linItem = null; // not included in serialized object
public FederatedResponse(ResponseType status) {
- this(status, null);
+ this(status, null, null);
}
public FederatedResponse(ResponseType status, Object[] data) {
+ this(status, data, null);
+ }
+
+ public FederatedResponse(ResponseType status, Object[] data,
LineageItem linItem) {
_status = status;
_data = data;
if( _status == ResponseType.SUCCESS && data == null )
_status = ResponseType.SUCCESS_EMPTY;
+ _linItem = linItem;
}
-
+
public FederatedResponse(FederatedResponse.ResponseType status, Object
data) {
+ this(status, data, null);
+ }
+
+ public FederatedResponse(FederatedResponse.ResponseType status, Object
data, LineageItem linItem) {
_status = status;
_data = new Object[] {data};
if(_status == ResponseType.SUCCESS && data == null)
_status = ResponseType.SUCCESS_EMPTY;
+ _linItem = linItem;
}
public boolean isSuccessful() {
@@ -126,4 +139,8 @@ public class FederatedResponse implements Serializable {
if ( checkedConstraints != null &&
!checkedConstraints.isEmpty() )
CheckedConstraintsLog.addCheckedConstraints(checkedConstraints);
}
+
+ public LineageItem getLineageItem() {
+ return _linItem;
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
index 20dfe27e66..58a9480266 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
@@ -46,6 +46,7 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics.Fed
import
org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics.FedStatsCollection.GCStatsCollection;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics.FedStatsCollection.LineageCacheStatsCollection;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics.FedStatsCollection.MultiTenantStatsCollection;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
@@ -81,6 +82,8 @@ public class FederatedStatistics {
private static final LongAdder fedReuseReadBytesCount = new LongAdder();
private static final LongAdder fedPutLineageCount = new LongAdder();
private static final LongAdder fedPutLineageItems = new LongAdder();
+ private static final LongAdder fedSerializationReuseCount = new
LongAdder();
+ private static final LongAdder fedSerializationReuseBytes = new
LongAdder();
public static synchronized void incFederated(RequestType rqt,
List<Object> data){
switch (rqt) {
@@ -159,6 +162,8 @@ public class FederatedStatistics {
fedReuseReadBytesCount.reset();
fedPutLineageCount.reset();
fedPutLineageItems.reset();
+ fedSerializationReuseCount.reset();
+ fedSerializationReuseBytes.reset();
}
public static String displayFedIOExecStatistics() {
@@ -204,6 +209,15 @@ public class FederatedStatistics {
return sb.toString();
}
+ public static String displayFedWorkerStats() {
+ StringBuilder sb = new StringBuilder();
+ sb.append(displayFedLookupTableStats());
+ sb.append(displayFedReuseReadStats());
+ sb.append(displayFedPutLineageStats());
+ sb.append(displayFedSerializationReuseStats());
+ return sb.toString();
+ }
+
public static String displayStatistics(int numHeavyHitters) {
StringBuilder sb = new StringBuilder();
FedStatsCollection fedStats = collectFedStats();
@@ -251,6 +265,7 @@ public class FederatedStatistics {
sb.append(displayFedLookupTableStats(mtsc.fLTGetCount,
mtsc.fLTEntryCount, mtsc.fLTGetTime));
sb.append(displayFedReuseReadStats(mtsc.reuseReadHits,
mtsc.reuseReadBytes));
sb.append(displayFedPutLineageStats(mtsc.putLineageCount,
mtsc.putLineageItems));
+
sb.append(displayFedSerializationReuseStats(mtsc.serializationReuseCount,
mtsc.serializationReuseBytes));
return sb.toString();
}
@@ -385,6 +400,14 @@ public class FederatedStatistics {
return fedPutLineageItems.longValue();
}
+ public static long getFedSerializationReuseCount() {
+ return fedSerializationReuseCount.longValue();
+ }
+
+ public static long getFedSerializationReuseBytes() {
+ return fedSerializationReuseBytes.longValue();
+ }
+
public static void incFedLookupTableGetCount() {
fedLookupTableGetCount.increment();
}
@@ -414,6 +437,11 @@ public class FederatedStatistics {
fedPutLineageItems.add(serializedLineage.lines().count());
}
+ public static void aggFedSerializationReuse(long bytes) {
+ fedSerializationReuseCount.increment();
+ fedSerializationReuseBytes.add(bytes);
+ }
+
public static String displayFedLookupTableStats() {
return
displayFedLookupTableStats(fedLookupTableGetCount.longValue(),
fedLookupTableEntryCount.longValue(),
fedLookupTableGetTime.doubleValue() / 1000000000);
@@ -421,25 +449,24 @@ public class FederatedStatistics {
public static String displayFedLookupTableStats(long fltGetCount, long
fltEntryCount, double fltGetTime) {
if(fltGetCount > 0) {
- StringBuilder sb = new StringBuilder();
- sb.append("Fed LookupTable (Get, Entries):\t" +
- fltGetCount + "/" + fltEntryCount + ".\n");
- return sb.toString();
+ return InstructionUtils.concatStrings(
+ "Fed LookupTable (Get, Entries):\t",
+ String.valueOf(fltGetCount), "/",
String.valueOf(fltEntryCount),".\n");
}
return "";
}
public static String displayFedReuseReadStats() {
- return
displayFedReuseReadStats(fedReuseReadHitCount.longValue(),
+ return displayFedReuseReadStats(
+ fedReuseReadHitCount.longValue(),
fedReuseReadBytesCount.longValue());
}
public static String displayFedReuseReadStats(long rrHits, long
rrBytes) {
if(rrHits > 0) {
- StringBuilder sb = new StringBuilder();
- sb.append("Fed ReuseRead (Hits, Bytes):\t" +
- rrHits + "/" + rrBytes + ".\n");
- return sb.toString();
+ return InstructionUtils.concatStrings(
+ "Fed ReuseRead (Hits, Bytes):\t",
+ String.valueOf(rrHits), "/",
String.valueOf(rrBytes), ".\n");
}
return "";
}
@@ -451,10 +478,23 @@ public class FederatedStatistics {
public static String displayFedPutLineageStats(long plCount, long
plItems) {
if(plCount > 0) {
- StringBuilder sb = new StringBuilder();
- sb.append("Fed PutLineage (Count, Items):\t" +
- plCount + "/" + plItems + ".\n");
- return sb.toString();
+ return InstructionUtils.concatStrings(
+ "Fed PutLineage (Count, Items):\t",
+ String.valueOf(plCount), "/",
String.valueOf(plItems), ".\n");
+ }
+ return "";
+ }
+
+ public static String displayFedSerializationReuseStats() {
+ return
displayFedSerializationReuseStats(fedSerializationReuseCount.longValue(),
+ fedSerializationReuseBytes.longValue());
+ }
+
+ public static String displayFedSerializationReuseStats(long srCount,
long srBytes) {
+ if(srCount > 0) {
+ return InstructionUtils.concatStrings(
+ "Fed SerialReuse (Count, Bytes):\t",
+ String.valueOf(srCount), "/",
String.valueOf(srBytes), ".\n");
}
return "";
}
@@ -619,6 +659,8 @@ public class FederatedStatistics {
reuseReadBytes = getFedReuseReadBytesCount();
putLineageCount = getFedPutLineageCount();
putLineageItems = getFedPutLineageItems();
+ serializationReuseCount =
getFedSerializationReuseCount();
+ serializationReuseBytes =
getFedSerializationReuseBytes();
}
private void aggregate(MultiTenantStatsCollection that)
{
@@ -629,6 +671,8 @@ public class FederatedStatistics {
reuseReadBytes += that.reuseReadBytes;
putLineageCount += that.putLineageCount;
putLineageItems += that.putLineageItems;
+ serializationReuseCount +=
that.serializationReuseCount;
+ serializationReuseBytes +=
that.serializationReuseBytes;
}
private long fLTGetCount = 0;
@@ -638,6 +682,8 @@ public class FederatedStatistics {
private long reuseReadBytes = 0;
private long putLineageCount = 0;
private long putLineageItems = 0;
+ private long serializationReuseCount = 0;
+ private long serializationReuseBytes = 0;
}
private CacheStatsCollection cacheStats = new
CacheStatsCollection();
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
index d090d0553c..a41f656524 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
@@ -32,8 +32,12 @@ import org.apache.sysds.api.DMLScript;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysds.runtime.lineage.LineageCache;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
+import org.apache.sysds.runtime.lineage.LineageItem;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
@@ -133,6 +137,40 @@ public class FederatedWorker {
else
return ctx.alloc().heapBuffer(initCapacity);
}
+
+ @Override
+ protected void encode(ChannelHandlerContext ctx, Serializable
msg, ByteBuf out) throws Exception {
+ LineageItem objLI = null;
+ boolean linReusePossible = (!ReuseCacheType.isNone() &&
msg instanceof FederatedResponse);
+ if(linReusePossible) {
+ FederatedResponse response =
(FederatedResponse)msg;
+ if(response.getData() != null &&
response.getData().length != 0
+ && response.getData()[0] instanceof
CacheBlock) {
+ objLI = response.getLineageItem();
+
+ byte[] cachedBytes =
LineageCache.reuseSerialization(objLI);
+ if(cachedBytes != null) {
+ out.writeBytes(cachedBytes);
+ return;
+ }
+ }
+ }
+
+ linReusePossible &= (objLI != null);
+
+ int startIdx = linReusePossible ? out.writerIndex() : 0;
+ long t0 = linReusePossible ? System.nanoTime() : 0;
+ super.encode(ctx, msg, out);
+ long t1 = linReusePossible ? System.nanoTime() : 0;
+
+ if(linReusePossible) {
+ out.readerIndex(startIdx);
+ byte[] dst = new byte[out.readableBytes()];
+ out.readBytes(dst);
+ LineageCache.putSerializedObject(dst, objLI,
(t1 - t0));
+ out.resetReaderIndex();
+ }
+ }
}
private ChannelInitializer<SocketChannel> createChannel(boolean ssl) {
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 769dabc173..4c90c74b1b 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -443,7 +443,8 @@ public class FederatedWorkerHandler extends
ChannelInboundHandlerAdapter {
case TENSOR:
case MATRIX:
case FRAME:
- return new
FederatedResponse(ResponseType.SUCCESS, ((CacheableData<?>)
dataObject).acquireReadAndRelease());
+ return new
FederatedResponse(ResponseType.SUCCESS, ((CacheableData<?>)
dataObject).acquireReadAndRelease(),
+ ReuseCacheType.isNone() ? null
: ec.getLineage().get(String.valueOf(request.getID())));
case LIST:
return new
FederatedResponse(ResponseType.SUCCESS, ((ListObject) dataObject).getData());
case SCALAR:
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index 58ab43daba..8e5e673e1d 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -21,7 +21,6 @@ package org.apache.sysds.runtime.instructions;
import org.apache.sysds.lops.Append;
import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.instructions.fed.AggregateBinaryFEDInstruction;
import
org.apache.sysds.runtime.instructions.fed.AggregateTernaryFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.AggregateUnaryFEDInstruction;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
index 7ad48cfa25..5ebe7f6295 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
@@ -31,9 +31,7 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.Reques
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
-import org.apache.sysds.runtime.instructions.CPInstructionParser;
import org.apache.sysds.runtime.instructions.InstructionUtils;
-import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index 3ea7d3d143..c1135cdb54 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -398,6 +398,38 @@ public class LineageCache
return false;
}
+ public static byte[] reuseSerialization(LineageItem objLI) {
+ if (ReuseCacheType.isNone() || objLI == null)
+ return null;
+
+ LineageItem li =
LineageItemUtils.getSerializedFedResponseLineageItem(objLI);
+
+ LineageCacheEntry e = null;
+ synchronized(_cache) {
+ if(LineageCache.probe(li)) {
+ e = LineageCache.getIntern(li);
+ }
+ else {
+ putIntern(li, DataType.UNKNOWN, null, null, 0);
+ return null; // direct return after placing the placeholder
+ }
+ }
+
+ if(e != null && e.isSerializedBytes()) {
+ byte[] sBytes = e.getSerializedBytes(); // waiting if the value is not set yet
+ if (sBytes == null && e.getCacheStatus() ==
LineageCacheStatus.NOTCACHED)
+ return null; // the executing thread removed this entry from cache
+
+ if (DMLScript.STATISTICS) { // increment statistics
+
LineageCacheStatistics.incrementSavedComputeTime(e._computeTime);
+
FederatedStatistics.aggFedSerializationReuse(sBytes.length);
+ }
+
+ return sBytes;
+ }
+ return null;
+ }
+
public static boolean probe(LineageItem key) {
//TODO problematic as after probe the matrix might be kicked out of cache
boolean p = _cache.containsKey(key); // in cache or in disk
@@ -695,6 +727,38 @@ public class LineageCache
}
}
+ public static void putSerializedObject(byte[] serialBytes, LineageItem
objLI, long computetime) {
+ if(ReuseCacheType.isNone())
+ return;
+
+ LineageItem li =
LineageItemUtils.getSerializedFedResponseLineageItem(objLI);
+
+ LineageCacheEntry entry = getIntern(li);
+
+ if(entry != null && serialBytes != null) {
+ synchronized(_cache) {
+ long size = serialBytes.length;
+
+ // remove the placeholder if the entry is bigger than the cache.
+ if (size >
LineageCacheEviction.getCacheLimit()) {
+ removePlaceholder(li);
+ }
+
+ // make space for the data
+ if
(!LineageCacheEviction.isBelowThreshold(size))
+ LineageCacheEviction.makeSpace(_cache,
size);
+ LineageCacheEviction.updateSize(size, true);
+
+ entry.setValue(serialBytes, computetime);
+ }
+ }
+ else {
+ synchronized(_cache) {
+ removePlaceholder(li);
+ }
+ }
+ }
+
public static void resetCache() {
synchronized (_cache) {
_cache.clear();
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
index 962c7d5307..b8e30cb4c3 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
@@ -33,6 +33,7 @@ public class LineageCacheEntry {
protected final DataType _dt;
protected MatrixBlock _MBval;
protected ScalarObject _SOval;
+ protected byte[] _serialBytes; // serialized bytes of a federated response
protected long _computeTime;
protected long _timestamp = 0;
protected LineageCacheStatus _status;
@@ -88,7 +89,22 @@ public class LineageCacheEntry {
throw new DMLRuntimeException(ex);
}
}
-
+
+ public synchronized byte[] getSerializedBytes() {
+ try {
+ // wait until other thread completes operation
+ // in order to avoid redundant computation
+ while(_status == LineageCacheStatus.EMPTY) {
+ wait();
+ }
+ // comes here if data is placed or the entry is removed by the running thread
+ return _serialBytes;
+ }
+ catch( InterruptedException ex ) {
+ throw new DMLRuntimeException(ex);
+ }
+ }
+
public synchronized LineageCacheStatus getCacheStatus() {
return _status;
}
@@ -113,7 +129,7 @@ public class LineageCacheEntry {
}
public boolean isNullVal() {
- return(_MBval == null && _SOval == null && _gpuObject == null);
+ return(_MBval == null && _SOval == null && _gpuObject == null
&& _serialBytes == null);
}
public boolean isMatrixValue() {
@@ -123,7 +139,11 @@ public class LineageCacheEntry {
public boolean isScalarValue() {
return _dt.isScalar();
}
-
+
+ public boolean isSerializedBytes() {
+ return _dt.isUnknown() &&
_key.getOpcode().equals(LineageItemUtils.SERIALIZATION_OPCODE);
+ }
+
public synchronized void setValue(MatrixBlock val, long computetime) {
_MBval = val;
_gpuObject = null; //Matrix block and gpu object cannot coexist
@@ -154,6 +174,14 @@ public class LineageCacheEntry {
//resume all threads waiting for val
notifyAll();
}
+
+ public synchronized void setValue(byte[] serialBytes, long computetime)
{
+ _serialBytes = serialBytes;
+ _computeTime = computetime;
+ _status = isNullVal() ? LineageCacheStatus.EMPTY :
LineageCacheStatus.CACHED;
+ // resume all threads waiting for val
+ notifyAll();
+ }
public synchronized GPUObject getGPUObject() {
return _gpuObject;
@@ -162,6 +190,7 @@ public class LineageCacheEntry {
protected synchronized void setNullValues() {
_MBval = null;
_SOval = null;
+ _serialBytes = null;
_status = LineageCacheStatus.EMPTY;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
index c5e60ffa01..8372d711f2 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
@@ -76,6 +76,9 @@ import java.util.stream.Collectors;
public class LineageItemUtils {
public static final String LPLACEHOLDER = "IN#";
+
+ // opcode to represent the serialized bytes of a federated response in lineage cache
+ public static final String SERIALIZATION_OPCODE = "serialize";
public static LineageItemType getType(String str) {
if (str.length() == 1) {
@@ -541,4 +544,8 @@ public class LineageItemUtils {
sb.append(true); //isLiteral = true
return new LineageItem(sb.toString());
}
+
+ public static LineageItem
getSerializedFedResponseLineageItem(LineageItem li) {
+ return new LineageItem(SERIALIZATION_OPCODE, new
LineageItem[]{li});
+ }
}
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java
b/src/main/java/org/apache/sysds/utils/Statistics.java
index b105ffaeeb..77b63a9921 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -654,8 +654,7 @@ public class Statistics
sb.append(ParForStatistics.displayStatistics());
sb.append(FederatedStatistics.displayFedIOExecStatistics());
-
sb.append(FederatedStatistics.displayFedLookupTableStats());
-
sb.append(FederatedStatistics.displayFedReuseReadStats());
+ sb.append(FederatedStatistics.displayFedWorkerStats());
sb.append(TransformStatistics.displayStatistics());
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedLineageTraceReuseTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedLineageTraceReuseTest.java
index ac4c861941..f01c196a63 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedLineageTraceReuseTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedLineageTraceReuseTest.java
@@ -235,12 +235,14 @@ public class FederatedLineageTraceReuseTest extends
MultiTenantTestBase {
boolean retVal = false;
int multiplier = 1;
int numInst = -1;
+ int serializationWrites = 0;
switch(opType) {
case EW_PLUS:
numInst = (execMode == ExecMode.SPARK) ? 1 : 2;
break;
case MM:
numInst = rowPartitioned ? 2 : 3;
+ serializationWrites = rowPartitioned ? 1 : 0;
break;
case PARFOR_ADD: // number of instructions times number of iterations of the parfor loop
multiplier = 3;
@@ -249,8 +251,8 @@ public class FederatedLineageTraceReuseTest extends
MultiTenantTestBase {
}
retVal = outputLog.contains(LINCACHE_MULTILVL
+ Integer.toString(numInst *
(coordinatorProcesses.size()-1) * workerProcesses.size()) + "/");
- retVal &= outputLog.contains(LINCACHE_WRITES
- + Integer.toString((1 + numInst) *
workerProcesses.size()) + "/"); // read + instructions
+ retVal &= outputLog.contains(LINCACHE_WRITES // read + instructions + serializations
+ + Integer.toString((1 + numInst + serializationWrites)
* workerProcesses.size()) + "/");
retVal &= outputLog.contains(FED_LINEAGEPUT
+ Integer.toString(coordinatorProcesses.size() *
workerProcesses.size() * multiplier) + "/");
return retVal;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedLineageTraceReuseTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedSerializationReuseTest.java
similarity index 75%
copy from
src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedLineageTraceReuseTest.java
copy to
src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedSerializationReuseTest.java
index ac4c861941..7fef68f000 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedLineageTraceReuseTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedSerializationReuseTest.java
@@ -40,11 +40,11 @@ import org.junit.runners.Parameterized;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
-public class FederatedLineageTraceReuseTest extends MultiTenantTestBase {
- private final static String TEST_NAME =
"FederatedLineageTraceReuseTest";
+public class FederatedSerializationReuseTest extends MultiTenantTestBase {
+ private final static String TEST_NAME =
"FederatedSerializationReuseTest";
private final static String TEST_DIR =
"functions/federated/multitenant/";
- private static final String TEST_CLASS_DIR = TEST_DIR +
FederatedLineageTraceReuseTest.class.getSimpleName() + "/";
+ private static final String TEST_CLASS_DIR = TEST_DIR +
FederatedSerializationReuseTest.class.getSimpleName() + "/";
private final static double TOLERANCE = 0;
@@ -70,9 +70,9 @@ public class FederatedLineageTraceReuseTest extends
MultiTenantTestBase {
}
private enum OpType {
- EW_PLUS,
- MM,
- PARFOR_ADD,
+ EW_DIV,
+ ROWSUMS,
+ PARFOR_MULT,
}
@Override
@@ -82,39 +82,39 @@ public class FederatedLineageTraceReuseTest extends
MultiTenantTestBase {
}
@Test
- public void testElementWisePlusCP() {
- runLineageTraceReuseTest(OpType.EW_PLUS, 4,
ExecMode.SINGLE_NODE);
+ @Ignore
+ public void testElementWiseDivCP() {
+ runSerializationReuseTest(OpType.EW_DIV, 4,
ExecMode.SINGLE_NODE);
}
@Test
- @Ignore
- public void testElementWisePlusSP() {
- runLineageTraceReuseTest(OpType.EW_PLUS, 4, ExecMode.SPARK);
+ public void testElementWiseDivSP() {
+ runSerializationReuseTest(OpType.EW_DIV, 4, ExecMode.SPARK);
}
@Test
- public void testMatrixMultCP() {
- runLineageTraceReuseTest(OpType.MM, 4, ExecMode.SINGLE_NODE);
+ @Ignore
+ public void testRowSumsCP() {
+ runSerializationReuseTest(OpType.ROWSUMS, 4,
ExecMode.SINGLE_NODE);
}
@Test
- @Ignore // TODO: allow for reuse of respective spark instructions
- public void testMatrixMultSP() {
- runLineageTraceReuseTest(OpType.MM, 4, ExecMode.SPARK);
+ public void testRowSumsSP() {
+ runSerializationReuseTest(OpType.ROWSUMS, 4, ExecMode.SPARK);
}
@Test
- public void testParforAddCP() {
- runLineageTraceReuseTest(OpType.PARFOR_ADD, 3,
ExecMode.SINGLE_NODE);
+ public void testParforMultCP() {
+ runSerializationReuseTest(OpType.PARFOR_MULT, 3,
ExecMode.SINGLE_NODE);
}
@Test
@Ignore
- public void testParforAddSP() {
- runLineageTraceReuseTest(OpType.PARFOR_ADD, 3, ExecMode.SPARK);
+ public void testParforMultSP() {
+ runSerializationReuseTest(OpType.PARFOR_MULT, 3,
ExecMode.SPARK);
}
- private void runLineageTraceReuseTest(OpType opType, int
numCoordinators, ExecMode execMode) {
+ private void runSerializationReuseTest(OpType opType, int
numCoordinators, ExecMode execMode) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;
@@ -203,21 +203,15 @@ public class FederatedLineageTraceReuseTest extends
MultiTenantTestBase {
private boolean checkForHeavyHitter(OpType opType, String outputLog,
ExecMode execMode) {
boolean retVal = false;
switch(opType) {
- case EW_PLUS:
- retVal = checkForHeavyHitter(outputLog,
"fed_+");
- if(execMode == ExecMode.SINGLE_NODE)
- retVal &=
checkForHeavyHitter(outputLog, "fed_uak+");
+ case EW_DIV:
+ retVal = checkForHeavyHitter(outputLog,
"fed_/");
break;
- case MM:
- retVal = checkForHeavyHitter(outputLog,
(execMode == ExecMode.SPARK) ? "fed_mapmm" : "fed_ba+*");
- retVal &= checkForHeavyHitter(outputLog,
"fed_r'");
- if(!rowPartitioned)
- retVal &=
checkForHeavyHitter(outputLog, (execMode == ExecMode.SPARK) ? "fed_rblk" :
"fed_uak+");
+ case ROWSUMS:
+ retVal = checkForHeavyHitter(outputLog,
"fed_uark+");
break;
- case PARFOR_ADD:
- retVal = checkForHeavyHitter(outputLog,
"fed_-");
- retVal &= checkForHeavyHitter(outputLog,
"fed_+");
- retVal &= checkForHeavyHitter(outputLog,
(execMode == ExecMode.SPARK) ? "fed_rblk" : "fed_uak+");
+ case PARFOR_MULT:
+ retVal = checkForHeavyHitter(outputLog,
"fed_*");
+ retVal &= checkForHeavyHitter(outputLog,
"fed_uack+");
break;
}
return retVal;
@@ -231,28 +225,33 @@ public class FederatedLineageTraceReuseTest extends
MultiTenantTestBase {
private boolean checkForReuses(OpType opType, String outputLog,
ExecMode execMode) {
final String LINCACHE_MULTILVL = "LinCache MultiLvl
(Ins/SB/Fn):\t";
final String LINCACHE_WRITES = "LinCache writes
(Mem/FS/Del):\t";
- final String FED_LINEAGEPUT = "Fed PutLineage (Count,
Items):\t";
+ final String SERIAL_REUSE = "Fed SerialReuse (Count, Bytes):\t";
boolean retVal = false;
- int multiplier = 1;
int numInst = -1;
+ int multiplier = 1;
+ int serializationWrites = 0;
switch(opType) {
- case EW_PLUS:
- numInst = (execMode == ExecMode.SPARK) ? 1 : 2;
+ case EW_DIV:
+ numInst = 1;
+ serializationWrites = 1;
break;
- case MM:
- numInst = rowPartitioned ? 2 : 3;
+ case ROWSUMS:
+ numInst = (execMode == ExecMode.SPARK) ? 0 : 1;
+ serializationWrites = 1;
break;
- case PARFOR_ADD: // number of instructions times number of iterations of the parfor loop
- multiplier = 3;
- numInst = ((execMode == ExecMode.SPARK) ? 2 :
3) * multiplier;
+ case PARFOR_MULT: // number of instructions times number of iterations of the parfor loop
+ multiplier = 3; // number of parfor iterations
+ numInst = (execMode == ExecMode.SPARK) ? 1 *
multiplier : 2 * multiplier;
+ serializationWrites = multiplier;
break;
}
retVal = outputLog.contains(LINCACHE_MULTILVL
+ Integer.toString(numInst *
(coordinatorProcesses.size()-1) * workerProcesses.size()) + "/");
- retVal &= outputLog.contains(LINCACHE_WRITES
- + Integer.toString((1 + numInst) *
workerProcesses.size()) + "/"); // read + instructions
- retVal &= outputLog.contains(FED_LINEAGEPUT
- + Integer.toString(coordinatorProcesses.size() *
workerProcesses.size() * multiplier) + "/");
+ retVal &= outputLog.contains(LINCACHE_WRITES // read + instructions + serializations
+ + Integer.toString((1 + numInst + serializationWrites)
* workerProcesses.size()) + "/");
+ retVal &= outputLog.contains(SERIAL_REUSE
+ + Integer.toString(serializationWrites *
(coordinatorProcesses.size()-1)
+ * workerProcesses.size()) + "/");
return retVal;
}
}
diff --git
a/src/test/scripts/functions/federated/multitenant/FederatedSerializationReuseTest.dml
b/src/test/scripts/functions/federated/multitenant/FederatedSerializationReuseTest.dml
new file mode 100644
index 0000000000..b38da937b1
--- /dev/null
+++
b/src/test/scripts/functions/federated/multitenant/FederatedSerializationReuseTest.dml
@@ -0,0 +1,57 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+rowPart = $rP;
+
+if (rowPart) {
+ X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4,
0), list($rows, $cols)));
+} else {
+ X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4),
list($rows, $cols/2),
+ list(0,$cols/2), list($rows, 3*($cols/4)), list(0,
3*($cols/4)), list($rows, $cols)));
+}
+
+testnum = $testnum;
+
+if(testnum == 0) { # EW_DIV
+ S = X / 2;
+}
+else if(testnum == 1) { # ROWSUMS
+ S = rowSums(X);
+}
+else if(testnum == 2) { # PARFOR_MULT
+ Y = rand(rows=$rows, cols=$cols, seed=1234);
+ while(FALSE) { }
+ numiter = 3;
+ Z = matrix(0, rows=numiter, cols=ncol(X));
+ parfor(i in 1:numiter) {
+ Y_vec = rowMeans(Y + i);
+ while(FALSE) { }
+ Z_tmp = X * Y_vec;
+ while(FALSE) { }
+ Z[i, ] = colSums(Z_tmp);
+ }
+ S = Z;
+}
+
+write(S, $out_S);