This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 8e832ac085 [SYSTEMDS-3185] Caching of serialized federated responses
8e832ac085 is described below

commit 8e832ac085b14aa63ecd8a5baee463ac9dfa53bc
Author: ywcb00 <[email protected]>
AuthorDate: Sun May 15 17:05:53 2022 +0200

    [SYSTEMDS-3185] Caching of serialized federated responses
    
    Closes #1611.
---
 .../federated/FederatedResponse.java               | 21 ++++-
 .../federated/FederatedStatistics.java             | 72 ++++++++++++++---
 .../controlprogram/federated/FederatedWorker.java  | 38 +++++++++
 .../federated/FederatedWorkerHandler.java          |  3 +-
 .../runtime/instructions/FEDInstructionParser.java |  1 -
 .../instructions/fed/TsmmFEDInstruction.java       |  2 -
 .../apache/sysds/runtime/lineage/LineageCache.java | 64 +++++++++++++++
 .../sysds/runtime/lineage/LineageCacheEntry.java   | 35 +++++++-
 .../sysds/runtime/lineage/LineageItemUtils.java    |  7 ++
 .../java/org/apache/sysds/utils/Statistics.java    |  3 +-
 .../FederatedLineageTraceReuseTest.java            |  6 +-
 ...t.java => FederatedSerializationReuseTest.java} | 93 +++++++++++-----------
 .../FederatedSerializationReuseTest.dml            | 57 +++++++++++++
 13 files changed, 329 insertions(+), 73 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
index b8cb55851c..ff059a460a 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
@@ -27,6 +27,7 @@ import java.util.concurrent.atomic.LongAdder;
 import org.apache.commons.lang.exception.ExceptionUtils;
 import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
 import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.privacy.CheckedConstraintsLog;
 import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
 
@@ -42,23 +43,35 @@ public class FederatedResponse implements Serializable {
        private ResponseType _status;
        private Object[] _data;
        private Map<PrivacyLevel,LongAdder> checkedConstraints;
+
+       private transient LineageItem _linItem = null; // not included in 
serialized object
        
        public FederatedResponse(ResponseType status) {
-               this(status, null);
+               this(status, null, null);
        }
        
        public FederatedResponse(ResponseType status, Object[] data) {
+               this(status, data, null);
+       }
+
+       public FederatedResponse(ResponseType status, Object[] data, 
LineageItem linItem) {
                _status = status;
                _data = data;
                if( _status == ResponseType.SUCCESS && data == null )
                        _status = ResponseType.SUCCESS_EMPTY;
+               _linItem = linItem;
        }
-       
+
        public FederatedResponse(FederatedResponse.ResponseType status, Object 
data) {
+               this(status, data, null);
+       }
+
+       public FederatedResponse(FederatedResponse.ResponseType status, Object 
data, LineageItem linItem) {
                _status = status;
                _data = new Object[] {data};
                if(_status == ResponseType.SUCCESS && data == null)
                        _status = ResponseType.SUCCESS_EMPTY;
+               _linItem = linItem;
        }
        
        public boolean isSuccessful() {
@@ -126,4 +139,8 @@ public class FederatedResponse implements Serializable {
                if ( checkedConstraints != null && 
!checkedConstraints.isEmpty() )
                        
CheckedConstraintsLog.addCheckedConstraints(checkedConstraints);
        }
+
+       public LineageItem getLineageItem() {
+               return _linItem;
+       }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
index 20dfe27e66..58a9480266 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
@@ -46,6 +46,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics.Fed
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics.FedStatsCollection.GCStatsCollection;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics.FedStatsCollection.LineageCacheStatsCollection;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics.FedStatsCollection.MultiTenantStatsCollection;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
@@ -81,6 +82,8 @@ public class FederatedStatistics {
        private static final LongAdder fedReuseReadBytesCount = new LongAdder();
        private static final LongAdder fedPutLineageCount = new LongAdder();
        private static final LongAdder fedPutLineageItems = new LongAdder();
+       private static final LongAdder fedSerializationReuseCount = new 
LongAdder();
+       private static final LongAdder fedSerializationReuseBytes = new 
LongAdder();
 
        public static synchronized void incFederated(RequestType rqt, 
List<Object> data){
                switch (rqt) {
@@ -159,6 +162,8 @@ public class FederatedStatistics {
                fedReuseReadBytesCount.reset();
                fedPutLineageCount.reset();
                fedPutLineageItems.reset();
+               fedSerializationReuseCount.reset();
+               fedSerializationReuseBytes.reset();
        }
 
        public static String displayFedIOExecStatistics() {
@@ -204,6 +209,15 @@ public class FederatedStatistics {
                return sb.toString();
        }
 
+       public static String displayFedWorkerStats() {
+               StringBuilder sb = new StringBuilder();
+               sb.append(displayFedLookupTableStats());
+               sb.append(displayFedReuseReadStats());
+               sb.append(displayFedPutLineageStats());
+               sb.append(displayFedSerializationReuseStats());
+               return sb.toString();
+       }
+
        public static String displayStatistics(int numHeavyHitters) {
                StringBuilder sb = new StringBuilder();
                FedStatsCollection fedStats = collectFedStats();
@@ -251,6 +265,7 @@ public class FederatedStatistics {
                sb.append(displayFedLookupTableStats(mtsc.fLTGetCount, 
mtsc.fLTEntryCount, mtsc.fLTGetTime));
                sb.append(displayFedReuseReadStats(mtsc.reuseReadHits, 
mtsc.reuseReadBytes));
                sb.append(displayFedPutLineageStats(mtsc.putLineageCount, 
mtsc.putLineageItems));
+               
sb.append(displayFedSerializationReuseStats(mtsc.serializationReuseCount, 
mtsc.serializationReuseBytes));
                return sb.toString();
        }
 
@@ -385,6 +400,14 @@ public class FederatedStatistics {
                return fedPutLineageItems.longValue();
        }
 
+       public static long getFedSerializationReuseCount() {
+               return fedSerializationReuseCount.longValue();
+       }
+
+       public static long getFedSerializationReuseBytes() {
+               return fedSerializationReuseBytes.longValue();
+       }
+
        public static void incFedLookupTableGetCount() {
                fedLookupTableGetCount.increment();
        }
@@ -414,6 +437,11 @@ public class FederatedStatistics {
                fedPutLineageItems.add(serializedLineage.lines().count());
        }
 
+       public static void aggFedSerializationReuse(long bytes) {
+               fedSerializationReuseCount.increment();
+               fedSerializationReuseBytes.add(bytes);
+       }
+
        public static String displayFedLookupTableStats() {
                return 
displayFedLookupTableStats(fedLookupTableGetCount.longValue(),
                        fedLookupTableEntryCount.longValue(), 
fedLookupTableGetTime.doubleValue() / 1000000000);
@@ -421,25 +449,24 @@ public class FederatedStatistics {
 
        public static String displayFedLookupTableStats(long fltGetCount, long 
fltEntryCount, double fltGetTime) {
                if(fltGetCount > 0) {
-                       StringBuilder sb = new StringBuilder();
-                       sb.append("Fed LookupTable (Get, Entries):\t" +
-                               fltGetCount + "/" + fltEntryCount + ".\n");
-                       return sb.toString();
+                       return InstructionUtils.concatStrings(
+                               "Fed LookupTable (Get, Entries):\t",
+                               String.valueOf(fltGetCount), "/", 
String.valueOf(fltEntryCount),".\n");
                }
                return "";
        }
 
        public static String displayFedReuseReadStats() {
-               return 
displayFedReuseReadStats(fedReuseReadHitCount.longValue(),
+               return displayFedReuseReadStats(
+                       fedReuseReadHitCount.longValue(),
                        fedReuseReadBytesCount.longValue());
        }
 
        public static String displayFedReuseReadStats(long rrHits, long 
rrBytes) {
                if(rrHits > 0) {
-                       StringBuilder sb = new StringBuilder();
-                       sb.append("Fed ReuseRead (Hits, Bytes):\t" +
-                               rrHits + "/" + rrBytes + ".\n");
-                       return sb.toString();
+                       return InstructionUtils.concatStrings(
+                               "Fed ReuseRead (Hits, Bytes):\t",
+                               String.valueOf(rrHits), "/", 
String.valueOf(rrBytes), ".\n");
                }
                return "";
        }
@@ -451,10 +478,23 @@ public class FederatedStatistics {
 
        public static String displayFedPutLineageStats(long plCount, long 
plItems) {
                if(plCount > 0) {
-                       StringBuilder sb = new StringBuilder();
-                       sb.append("Fed PutLineage (Count, Items):\t" +
-                               plCount + "/" + plItems + ".\n");
-                       return sb.toString();
+                       return InstructionUtils.concatStrings(
+                               "Fed PutLineage (Count, Items):\t",
+                               String.valueOf(plCount), "/", 
String.valueOf(plItems), ".\n");
+               }
+               return "";
+       }
+
+       public static String displayFedSerializationReuseStats() {
+               return 
displayFedSerializationReuseStats(fedSerializationReuseCount.longValue(),
+                       fedSerializationReuseBytes.longValue());
+       }
+
+       public static String displayFedSerializationReuseStats(long srCount, 
long srBytes) {
+               if(srCount > 0) {
+                       return InstructionUtils.concatStrings(
+                               "Fed SerialReuse (Count, Bytes):\t",
+                               String.valueOf(srCount), "/", 
String.valueOf(srBytes), ".\n");
                }
                return "";
        }
@@ -619,6 +659,8 @@ public class FederatedStatistics {
                                reuseReadBytes = getFedReuseReadBytesCount();
                                putLineageCount = getFedPutLineageCount();
                                putLineageItems = getFedPutLineageItems();
+                               serializationReuseCount = 
getFedSerializationReuseCount();
+                               serializationReuseBytes = 
getFedSerializationReuseBytes();
                        }
 
                        private void aggregate(MultiTenantStatsCollection that) 
{
@@ -629,6 +671,8 @@ public class FederatedStatistics {
                                reuseReadBytes += that.reuseReadBytes;
                                putLineageCount += that.putLineageCount;
                                putLineageItems += that.putLineageItems;
+                               serializationReuseCount += 
that.serializationReuseCount;
+                               serializationReuseBytes += 
that.serializationReuseBytes;
                        }
 
                        private long fLTGetCount = 0;
@@ -638,6 +682,8 @@ public class FederatedStatistics {
                        private long reuseReadBytes = 0;
                        private long putLineageCount = 0;
                        private long putLineageItems = 0;
+                       private long serializationReuseCount = 0;
+                       private long serializationReuseBytes = 0;
                }
 
                private CacheStatsCollection cacheStats = new 
CacheStatsCollection();
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
index d090d0553c..a41f656524 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
@@ -32,8 +32,12 @@ import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.conf.DMLConfig;
 import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
 import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysds.runtime.lineage.LineageCache;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
+import org.apache.sysds.runtime.lineage.LineageItem;
 
 import io.netty.bootstrap.ServerBootstrap;
 import io.netty.buffer.ByteBuf;
@@ -133,6 +137,40 @@ public class FederatedWorker {
                        else
                                return ctx.alloc().heapBuffer(initCapacity);
                }
+
+               @Override
+               protected void encode(ChannelHandlerContext ctx, Serializable 
msg, ByteBuf out) throws Exception {
+                       LineageItem objLI = null;
+                       boolean linReusePossible = (!ReuseCacheType.isNone() && 
msg instanceof FederatedResponse);
+                       if(linReusePossible) {
+                               FederatedResponse response = 
(FederatedResponse)msg;
+                               if(response.getData() != null && 
response.getData().length != 0
+                                       && response.getData()[0] instanceof 
CacheBlock) {
+                                       objLI = response.getLineageItem();
+
+                                       byte[] cachedBytes = 
LineageCache.reuseSerialization(objLI);
+                                       if(cachedBytes != null) {
+                                               out.writeBytes(cachedBytes);
+                                               return;
+                                       }
+                               }
+                       }
+
+                       linReusePossible &= (objLI != null);
+
+                       int startIdx = linReusePossible ? out.writerIndex() : 0;
+                       long t0 = linReusePossible ? System.nanoTime() : 0;
+                       super.encode(ctx, msg, out);
+                       long t1 = linReusePossible ? System.nanoTime() : 0;
+
+                       if(linReusePossible) {
+                               out.readerIndex(startIdx);
+                               byte[] dst = new byte[out.readableBytes()];
+                               out.readBytes(dst);
+                               LineageCache.putSerializedObject(dst, objLI, 
(t1 - t0));
+                               out.resetReaderIndex();
+                       }
+               }
        }
 
        private ChannelInitializer<SocketChannel> createChannel(boolean ssl) {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 769dabc173..4c90c74b1b 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -443,7 +443,8 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
                                case TENSOR:
                                case MATRIX:
                                case FRAME:
-                                       return new 
FederatedResponse(ResponseType.SUCCESS, ((CacheableData<?>) 
dataObject).acquireReadAndRelease());
+                                       return new 
FederatedResponse(ResponseType.SUCCESS, ((CacheableData<?>) 
dataObject).acquireReadAndRelease(),
+                                               ReuseCacheType.isNone() ? null 
: ec.getLineage().get(String.valueOf(request.getID())));
                                case LIST:
                                        return new 
FederatedResponse(ResponseType.SUCCESS, ((ListObject) dataObject).getData());
                                case SCALAR:
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index 58ab43daba..8e5e673e1d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -21,7 +21,6 @@ package org.apache.sysds.runtime.instructions;
 
 import org.apache.sysds.lops.Append;
 import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.instructions.cp.CPInstruction;
 import org.apache.sysds.runtime.instructions.fed.AggregateBinaryFEDInstruction;
 import 
org.apache.sysds.runtime.instructions.fed.AggregateTernaryFEDInstruction;
 import org.apache.sysds.runtime.instructions.fed.AggregateUnaryFEDInstruction;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
index 7ad48cfa25..5ebe7f6295 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
@@ -31,9 +31,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.Reques
 import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
 import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
-import org.apache.sysds.runtime.instructions.CPInstructionParser;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
-import org.apache.sysds.runtime.instructions.cp.CPInstruction;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index 3ea7d3d143..c1135cdb54 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -398,6 +398,38 @@ public class LineageCache
                return false;
        }
 
+       public static byte[] reuseSerialization(LineageItem objLI) {
+               if (ReuseCacheType.isNone() || objLI == null)
+                       return null;
+
+               LineageItem li = 
LineageItemUtils.getSerializedFedResponseLineageItem(objLI);
+
+               LineageCacheEntry e = null;
+               synchronized(_cache) {
+                       if(LineageCache.probe(li)) {
+                               e = LineageCache.getIntern(li);
+                       }
+                       else {
+                               putIntern(li, DataType.UNKNOWN, null, null, 0);
+                               return null; // direct return after placing the 
placeholder
+                       }
+               }
+
+               if(e != null && e.isSerializedBytes()) {
+                       byte[] sBytes = e.getSerializedBytes(); // waiting if 
the value is not set yet
+                       if (sBytes == null && e.getCacheStatus() == 
LineageCacheStatus.NOTCACHED)
+                               return null;  // the executing thread removed 
this entry from cache
+
+                       if (DMLScript.STATISTICS) { // increment statistics
+                               
LineageCacheStatistics.incrementSavedComputeTime(e._computeTime);
+                               
FederatedStatistics.aggFedSerializationReuse(sBytes.length);
+                       }
+
+                       return sBytes;
+               }
+               return null;
+       }
+
        public static boolean probe(LineageItem key) {
                //TODO problematic as after probe the matrix might be kicked 
out of cache
                boolean p = _cache.containsKey(key);  // in cache or in disk
@@ -695,6 +727,38 @@ public class LineageCache
                }
        }
 
+       public static void putSerializedObject(byte[] serialBytes, LineageItem 
objLI, long computetime) {
+               if(ReuseCacheType.isNone())
+                       return;
+
+               LineageItem li = 
LineageItemUtils.getSerializedFedResponseLineageItem(objLI);
+
+               LineageCacheEntry entry = getIntern(li);
+
+               if(entry != null && serialBytes != null) {
+                       synchronized(_cache) {
+                               long size = serialBytes.length;
+
+                               // remove the placeholder if the entry is 
bigger than the cache.
+                               if (size > 
LineageCacheEviction.getCacheLimit()) {
+                                       removePlaceholder(li);
+                               }
+
+                               // make space for the data
+                               if 
(!LineageCacheEviction.isBelowThreshold(size))
+                                       LineageCacheEviction.makeSpace(_cache, 
size);
+                               LineageCacheEviction.updateSize(size, true);
+
+                               entry.setValue(serialBytes, computetime);
+                       }
+               }
+               else {
+                       synchronized(_cache) {
+                               removePlaceholder(li);
+                       }
+               }
+       }
+
        public static void resetCache() {
                synchronized (_cache) {
                        _cache.clear();
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
index 962c7d5307..b8e30cb4c3 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheEntry.java
@@ -33,6 +33,7 @@ public class LineageCacheEntry {
        protected final DataType _dt;
        protected MatrixBlock _MBval;
        protected ScalarObject _SOval;
+       protected byte[] _serialBytes; // serialized bytes of a federated 
response
        protected long _computeTime;
        protected long _timestamp = 0;
        protected LineageCacheStatus _status;
@@ -88,7 +89,22 @@ public class LineageCacheEntry {
                        throw new DMLRuntimeException(ex);
                }
        }
-       
+
+       public synchronized byte[] getSerializedBytes() {
+               try {
+                       // wait until other thread completes operation
+                       // in order to avoid redundant computation
+                       while(_status == LineageCacheStatus.EMPTY) {
+                               wait();
+                       }
+                       // comes here if data is placed or the entry is removed 
by the running thread
+                       return _serialBytes;
+               }
+               catch( InterruptedException ex ) {
+                       throw new DMLRuntimeException(ex);
+               }
+       }
+
        public synchronized LineageCacheStatus getCacheStatus() {
                return _status;
        }
@@ -113,7 +129,7 @@ public class LineageCacheEntry {
        }
        
        public boolean isNullVal() {
-               return(_MBval == null && _SOval == null && _gpuObject == null);
+               return(_MBval == null && _SOval == null && _gpuObject == null 
&& _serialBytes == null);
        }
        
        public boolean isMatrixValue() {
@@ -123,7 +139,11 @@ public class LineageCacheEntry {
        public boolean isScalarValue() {
                return _dt.isScalar();
        }
-       
+
+       public boolean isSerializedBytes() {
+               return _dt.isUnknown() && 
_key.getOpcode().equals(LineageItemUtils.SERIALIZATION_OPCODE);
+       }
+
        public synchronized void setValue(MatrixBlock val, long computetime) {
                _MBval = val;
                _gpuObject = null;  //Matrix block and gpu object cannot coexist
@@ -154,6 +174,14 @@ public class LineageCacheEntry {
                //resume all threads waiting for val
                notifyAll();
        }
+
+       public synchronized void setValue(byte[] serialBytes, long computetime) 
{
+               _serialBytes = serialBytes;
+               _computeTime = computetime;
+               _status = isNullVal() ? LineageCacheStatus.EMPTY : 
LineageCacheStatus.CACHED;
+               // resume all threads waiting for val
+               notifyAll();
+       }
        
        public synchronized GPUObject getGPUObject() {
                return _gpuObject;
@@ -162,6 +190,7 @@ public class LineageCacheEntry {
        protected synchronized void setNullValues() {
                _MBval = null;
                _SOval = null;
+               _serialBytes = null;
                _status = LineageCacheStatus.EMPTY;
        }
        
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
index c5e60ffa01..8372d711f2 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
@@ -76,6 +76,9 @@ import java.util.stream.Collectors;
 public class LineageItemUtils {
        
        public static final String LPLACEHOLDER = "IN#";
+
+       // opcode to represent the serialized bytes of a federated response in 
lineage cache
+       public static final String SERIALIZATION_OPCODE = "serialize";
        
        public static LineageItemType getType(String str) {
                if (str.length() == 1) {
@@ -541,4 +544,8 @@ public class LineageItemUtils {
                sb.append(true); //isLiteral = true
                return new LineageItem(sb.toString());
        }
+
+       public static LineageItem 
getSerializedFedResponseLineageItem(LineageItem li) {
+               return new LineageItem(SERIALIZATION_OPCODE, new 
LineageItem[]{li});
+       }
 }
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java
index b105ffaeeb..77b63a9921 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -654,8 +654,7 @@ public class Statistics
                        sb.append(ParForStatistics.displayStatistics());
 
                        
sb.append(FederatedStatistics.displayFedIOExecStatistics());
-                       
sb.append(FederatedStatistics.displayFedLookupTableStats());
-                       
sb.append(FederatedStatistics.displayFedReuseReadStats());
+                       sb.append(FederatedStatistics.displayFedWorkerStats());
 
                        sb.append(TransformStatistics.displayStatistics());
 
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedLineageTraceReuseTest.java b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedLineageTraceReuseTest.java
index ac4c861941..f01c196a63 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedLineageTraceReuseTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedLineageTraceReuseTest.java
@@ -235,12 +235,14 @@ public class FederatedLineageTraceReuseTest extends MultiTenantTestBase {
                boolean retVal = false;
                int multiplier = 1;
                int numInst = -1;
+               int serializationWrites = 0;
                switch(opType) {
                        case EW_PLUS:
                                numInst = (execMode == ExecMode.SPARK) ? 1 : 2;
                                break;
                        case MM:
                                numInst = rowPartitioned ? 2 : 3;
+                               serializationWrites = rowPartitioned ? 1 : 0;
                                break;
                        case PARFOR_ADD: // number of instructions times number 
of iterations of the parfor loop
                                multiplier = 3;
@@ -249,8 +251,8 @@ public class FederatedLineageTraceReuseTest extends MultiTenantTestBase {
                }
                retVal = outputLog.contains(LINCACHE_MULTILVL
                        + Integer.toString(numInst * 
(coordinatorProcesses.size()-1) * workerProcesses.size()) + "/");
-               retVal &= outputLog.contains(LINCACHE_WRITES
-                       + Integer.toString((1 + numInst) * 
workerProcesses.size()) + "/"); // read + instructions
+               retVal &= outputLog.contains(LINCACHE_WRITES // read + 
instructions + serializations
+                       + Integer.toString((1 + numInst + serializationWrites) 
* workerProcesses.size()) + "/");
                retVal &= outputLog.contains(FED_LINEAGEPUT
                        + Integer.toString(coordinatorProcesses.size() * 
workerProcesses.size() * multiplier) + "/");
                return retVal;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedLineageTraceReuseTest.java b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedSerializationReuseTest.java
similarity index 75%
copy from src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedLineageTraceReuseTest.java
copy to src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedSerializationReuseTest.java
index ac4c861941..7fef68f000 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedLineageTraceReuseTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/multitenant/FederatedSerializationReuseTest.java
@@ -40,11 +40,11 @@ import org.junit.runners.Parameterized;
 
 @RunWith(value = Parameterized.class)
 @net.jcip.annotations.NotThreadSafe
-public class FederatedLineageTraceReuseTest extends MultiTenantTestBase {
-       private final static String TEST_NAME = 
"FederatedLineageTraceReuseTest";
+public class FederatedSerializationReuseTest extends MultiTenantTestBase {
+       private final static String TEST_NAME = 
"FederatedSerializationReuseTest";
 
        private final static String TEST_DIR = 
"functions/federated/multitenant/";
-       private static final String TEST_CLASS_DIR = TEST_DIR + 
FederatedLineageTraceReuseTest.class.getSimpleName() + "/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
FederatedSerializationReuseTest.class.getSimpleName() + "/";
 
        private final static double TOLERANCE = 0;
 
@@ -70,9 +70,9 @@ public class FederatedLineageTraceReuseTest extends MultiTenantTestBase {
        }
 
        private enum OpType {
-               EW_PLUS,
-               MM,
-               PARFOR_ADD,
+               EW_DIV,
+               ROWSUMS,
+               PARFOR_MULT,
        }
 
        @Override
@@ -82,39 +82,39 @@ public class FederatedLineageTraceReuseTest extends MultiTenantTestBase {
        }
 
        @Test
-       public void testElementWisePlusCP() {
-               runLineageTraceReuseTest(OpType.EW_PLUS, 4, 
ExecMode.SINGLE_NODE);
+       @Ignore
+       public void testElementWiseDivCP() {
+               runSerializationReuseTest(OpType.EW_DIV, 4, 
ExecMode.SINGLE_NODE);
        }
 
        @Test
-       @Ignore
-       public void testElementWisePlusSP() {
-               runLineageTraceReuseTest(OpType.EW_PLUS, 4, ExecMode.SPARK);
+       public void testElementWiseDivSP() {
+               runSerializationReuseTest(OpType.EW_DIV, 4, ExecMode.SPARK);
        }
 
        @Test
-       public void testMatrixMultCP() {
-               runLineageTraceReuseTest(OpType.MM, 4, ExecMode.SINGLE_NODE);
+       @Ignore
+       public void testRowSumsCP() {
+               runSerializationReuseTest(OpType.ROWSUMS, 4, 
ExecMode.SINGLE_NODE);
        }
 
        @Test
-       @Ignore // TODO: allow for reuse of respective spark instructions
-       public void testMatrixMultSP() {
-               runLineageTraceReuseTest(OpType.MM, 4, ExecMode.SPARK);
+       public void testRowSumsSP() {
+               runSerializationReuseTest(OpType.ROWSUMS, 4, ExecMode.SPARK);
        }
 
        @Test
-       public void testParforAddCP() {
-               runLineageTraceReuseTest(OpType.PARFOR_ADD, 3, 
ExecMode.SINGLE_NODE);
+       public void testParforMultCP() {
+               runSerializationReuseTest(OpType.PARFOR_MULT, 3, 
ExecMode.SINGLE_NODE);
        }
 
        @Test
        @Ignore
-       public void testParforAddSP() {
-               runLineageTraceReuseTest(OpType.PARFOR_ADD, 3, ExecMode.SPARK);
+       public void testParforMultSP() {
+               runSerializationReuseTest(OpType.PARFOR_MULT, 3, 
ExecMode.SPARK);
        }
 
-       private void runLineageTraceReuseTest(OpType opType, int 
numCoordinators, ExecMode execMode) {
+       private void runSerializationReuseTest(OpType opType, int 
numCoordinators, ExecMode execMode) {
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
                ExecMode platformOld = rtplatform;
 
@@ -203,21 +203,15 @@ public class FederatedLineageTraceReuseTest extends MultiTenantTestBase {
        private boolean checkForHeavyHitter(OpType opType, String outputLog, 
ExecMode execMode) {
                boolean retVal = false;
                switch(opType) {
-                       case EW_PLUS:
-                               retVal = checkForHeavyHitter(outputLog, 
"fed_+");
-                               if(execMode == ExecMode.SINGLE_NODE)
-                                       retVal &= 
checkForHeavyHitter(outputLog, "fed_uak+");
+                       case EW_DIV:
+                               retVal = checkForHeavyHitter(outputLog, 
"fed_/");
                                break;
-                       case MM:
-                               retVal = checkForHeavyHitter(outputLog, 
(execMode == ExecMode.SPARK) ? "fed_mapmm" : "fed_ba+*");
-                               retVal &= checkForHeavyHitter(outputLog, 
"fed_r'");
-                               if(!rowPartitioned)
-                                       retVal &= 
checkForHeavyHitter(outputLog, (execMode == ExecMode.SPARK) ? "fed_rblk" : 
"fed_uak+");
+                       case ROWSUMS:
+                               retVal = checkForHeavyHitter(outputLog, 
"fed_uark+");
                                break;
-                       case PARFOR_ADD:
-                               retVal = checkForHeavyHitter(outputLog, 
"fed_-");
-                               retVal &= checkForHeavyHitter(outputLog, 
"fed_+");
-                               retVal &= checkForHeavyHitter(outputLog, 
(execMode == ExecMode.SPARK) ? "fed_rblk" : "fed_uak+");
+                       case PARFOR_MULT:
+                               retVal = checkForHeavyHitter(outputLog, 
"fed_*");
+                               retVal &= checkForHeavyHitter(outputLog, 
"fed_uack+");
                                break;
                }
                return retVal;
@@ -231,28 +225,33 @@ public class FederatedLineageTraceReuseTest extends MultiTenantTestBase {
        private boolean checkForReuses(OpType opType, String outputLog, 
ExecMode execMode) {
                final String LINCACHE_MULTILVL = "LinCache MultiLvl 
(Ins/SB/Fn):\t";
                final String LINCACHE_WRITES = "LinCache writes 
(Mem/FS/Del):\t";
-               final String FED_LINEAGEPUT = "Fed PutLineage (Count, 
Items):\t";
+               final String SERIAL_REUSE = "Fed SerialReuse (Count, Bytes):\t";
                boolean retVal = false;
-               int multiplier = 1;
                int numInst = -1;
+               int multiplier = 1;
+               int serializationWrites = 0;
                switch(opType) {
-                       case EW_PLUS:
-                               numInst = (execMode == ExecMode.SPARK) ? 1 : 2;
+                       case EW_DIV:
+                               numInst = 1;
+                               serializationWrites = 1;
                                break;
-                       case MM:
-                               numInst = rowPartitioned ? 2 : 3;
+                       case ROWSUMS:
+                               numInst = (execMode == ExecMode.SPARK) ? 0 : 1;
+                               serializationWrites = 1;
                                break;
-                       case PARFOR_ADD: // number of instructions times number 
of iterations of the parfor loop
-                               multiplier = 3;
-                               numInst = ((execMode == ExecMode.SPARK) ? 2 : 
3) * multiplier;
+                       case PARFOR_MULT: // number of instructions times 
number of iterations of the parfor loop
+                               multiplier = 3; // number of parfor iterations
+                               numInst = (execMode == ExecMode.SPARK) ? 1 * 
multiplier : 2 * multiplier;
+                               serializationWrites = multiplier;
                                break;
                }
                retVal = outputLog.contains(LINCACHE_MULTILVL
                        + Integer.toString(numInst * 
(coordinatorProcesses.size()-1) * workerProcesses.size()) + "/");
-               retVal &= outputLog.contains(LINCACHE_WRITES
-                       + Integer.toString((1 + numInst) * 
workerProcesses.size()) + "/"); // read + instructions
-               retVal &= outputLog.contains(FED_LINEAGEPUT
-                       + Integer.toString(coordinatorProcesses.size() * 
workerProcesses.size() * multiplier) + "/");
+               retVal &= outputLog.contains(LINCACHE_WRITES // read + 
instructions + serializations
+                       + Integer.toString((1 + numInst + serializationWrites) 
* workerProcesses.size()) + "/");
+               retVal &= outputLog.contains(SERIAL_REUSE
+                       + Integer.toString(serializationWrites * 
(coordinatorProcesses.size()-1)
+                               * workerProcesses.size()) + "/");
                return retVal;
        }
 }
diff --git a/src/test/scripts/functions/federated/multitenant/FederatedSerializationReuseTest.dml b/src/test/scripts/functions/federated/multitenant/FederatedSerializationReuseTest.dml
new file mode 100644
index 0000000000..b38da937b1
--- /dev/null
+++ b/src/test/scripts/functions/federated/multitenant/FederatedSerializationReuseTest.dml
@@ -0,0 +1,57 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+rowPart = $rP;
+
+if (rowPart) {
+  X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+                 list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 
0), list($rows, $cols)));
+} else {
+  X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), 
list($rows, $cols/2),
+               list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 
3*($cols/4)), list($rows, $cols)));
+}
+
+testnum = $testnum;
+
+if(testnum == 0) { # EW_DIV
+  S = X / 2;
+}
+else if(testnum == 1) { # ROWSUMS
+  S = rowSums(X);
+}
+else if(testnum == 2) { # PARFOR_MULT
+  Y = rand(rows=$rows, cols=$cols, seed=1234);
+  while(FALSE) { }
+  numiter = 3;
+  Z = matrix(0, rows=numiter, cols=ncol(X));
+  parfor(i in 1:numiter) {
+    Y_vec = rowMeans(Y + i);
+    while(FALSE) { }
+    Z_tmp = X * Y_vec;
+    while(FALSE) { }
+    Z[i, ] = colSums(Z_tmp);
+  }
+  S = Z;
+}
+
+write(S, $out_S);

Reply via email to