This is an automated email from the ASF dual-hosted git repository.
janniklinde pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 1f2e2eab00 [MINOR] OOC Bugfix Cache Reference Management + Return
Right BlockKey on Externally Managed Grouped Callbacks
1f2e2eab00 is described below
commit 1f2e2eab00a835f83a71e07b18f6b2b369ba1b97
Author: Jannik Lindemann <[email protected]>
AuthorDate: Fri Mar 27 10:29:30 2026 +0100
[MINOR] OOC Bugfix Cache Reference Management + Return Right BlockKey on
Externally Managed Grouped Callbacks
Closes #2454.
---
.../runtime/instructions/ooc/CachingStream.java | 49 +++++++++++++++++++++-
.../apache/sysds/runtime/ooc/cache/BlockEntry.java | 10 +++++
.../sysds/runtime/ooc/cache/OOCCacheScheduler.java | 8 ++++
.../runtime/ooc/cache/OOCLRUCacheScheduler.java | 39 +++++++++++++++--
4 files changed, 101 insertions(+), 5 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
index b3f5e57aaf..b7c4e2aa64 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java
@@ -130,6 +130,7 @@ public class CachingStream implements
OOCStreamable<IndexedMatrixValue> {
boolean ownsEntry =
true;
if(tmp instanceof
OOCCacheManager.CachedGroupCallback<?> cachedGroup) {
baseKey =
cachedGroup.getBlockKey();
+
ensureReferencedOrRematerialize(baseKey, cachedGroup);
ownsEntry =
false;
if(mSubscribers
!= null && mSubscribers.length > 0)
mCallback = tmp.keepOpen();
@@ -183,12 +184,14 @@ public class CachingStream implements
OOCStreamable<IndexedMatrixValue> {
if(tmp instanceof
OOCCacheManager.CachedQueueCallback<?> cachedQueue) {
blockKey =
cachedQueue.getBlockKey();
+
ensureReferencedOrRematerialize(blockKey, task);
ownsEntry =
false;
if(mSubscribers
!= null && mSubscribers.length > 0)
mCallback = tmp.keepOpen();
}
else if(tmp instanceof
OOCCacheManager.CachedSubCallback<?> cachedSub) {
BlockKey parent
= cachedSub.getParent().getBlockKey();
+
ensureReferencedOrRematerialize(parent, cachedSub.getParent());
blockKey = new
GroupedBlockKey(parent.getStreamId(), (int) parent.getSequenceNumber(),
cachedSub.getGroupIndex());
ownsEntry =
false;
@@ -297,6 +300,49 @@ public class CachingStream implements
OOCStreamable<IndexedMatrixValue> {
});
}
+
+ private void ensureReferencedOrRematerialize(BlockKey key,
IndexedMatrixValue value) { // Takes an extra cache reference on key; if that fails, re-inserts value under key.
+ try {
+ OOCCacheManager.getCache().addReference(key); // fast path: the cache already holds an entry for key
+ }
+ catch(IllegalArgumentException ex) { // OOCLRUCacheScheduler.addReference throws this when no entry exists for key
+ try {
+ OOCCacheManager.putRaw(key, value,
((MatrixBlock) value.getValue()).getExactSerializedSize()); // cast assumes value wraps a MatrixBlock; its exact serialized size is passed as the entry size
+ }
+ catch(IllegalStateException putEx) { // NOTE(review): presumably putRaw signals an already-present key this way - confirm
+ // Another downstream stream may have
re-materialized the same entry first.
+ OOCCacheManager.getCache().addReference(key); // not wrapped in a further try: an exception here reaches the caller
+ }
+ }
+ }
+
+ private void ensureReferencedOrRematerialize(BlockKey key,
OOCCacheManager.CachedGroupCallback<?> group) { // Group variant: takes an extra cache reference on key, or re-inserts the values of all sub-callbacks of group under key.
+ try {
+ OOCCacheManager.getCache().addReference(key); // fast path: the cache already holds an entry for key
+ }
+ catch(IllegalArgumentException ex) { // OOCLRUCacheScheduler.addReference throws this when no entry exists for key
+ try {
+ List<IndexedMatrixValue> values = new
ArrayList<>(group.size()); // presized to one slot per sub-callback
+ long totalSize = 0;
+ for(int gi = 0; gi < group.size(); gi++) {
+ @SuppressWarnings("unchecked")
+
OOCStream.QueueCallback<IndexedMatrixValue> sub =
+
(OOCStream.QueueCallback<IndexedMatrixValue>) group.getCallback(gi);
+ try(sub) { // each sub-callback is closed once its value has been read
+ IndexedMatrixValue imv =
sub.get();
+ values.add(imv);
+ totalSize += ((MatrixBlock)
imv.getValue()).getExactSerializedSize(); // cast assumes every group member wraps a MatrixBlock
+ }
+ }
+ OOCCacheManager.putRaw(key, values, totalSize); // the list is stored as a single entry with the summed size
+ }
+ catch(IllegalStateException putEx) { // NOTE(review): presumably putRaw signals an already-present key this way - confirm
+ // Another downstream stream may have
re-materialized the same entry first.
+ OOCCacheManager.getCache().addReference(key); // not wrapped in a further try: an exception here reaches the caller
+ }
+ }
+ }
+
private String getCtxMsg() {
StackTraceElement[] st = new Exception().getStackTrace();
// Skip the first few frames (constructor,
createWritableStream, etc.)
@@ -687,7 +733,7 @@ public class CachingStream implements
OOCStreamable<IndexedMatrixValue> {
if(groupIdx > 0)
continue; // only replay grouped blocks once at
the base index
- BlockKey replayKey = (groupSize > 1 && groupIdx == 0) ?
new BlockKey(_streamId, idx) : getBlockKey(i);
+ BlockKey replayKey = (groupSize > 1 && groupIdx == 0) ?
getEntryBlockKey(idx) : getBlockKey(i);
OOCCacheManager.requestBlock(replayKey).whenComplete((cb, r) -> {
if(r != null) {
subscriber.accept(OOCStream.eos(DMLRuntimeException.of(r)));
@@ -697,7 +743,6 @@ public class CachingStream implements
OOCStreamable<IndexedMatrixValue> {
synchronized(CachingStream.this) {
if(_index != null) {
if(cb instanceof
OOCStream.GroupQueueCallback<?> && groupSize > 1) {
-
@SuppressWarnings("unchecked")
OOCStream.GroupQueueCallback<IndexedMatrixValue> group =
(OOCStream.GroupQueueCallback<IndexedMatrixValue>) cb;
for(int gi = 0;
gi < groupSize; gi++) {
diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java
b/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java
index fbc7d64223..c0604d017d 100644
--- a/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java
+++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java
@@ -30,6 +30,7 @@ public final class BlockEntry {
private volatile BlockState _state;
private Object _data;
private int _retainHintCount;
+ private int _referenceCount; // The number of references from different
managing instances (e.g. CachingStream)
BlockEntry(BlockKey key, long size, Object data) {
this._key = key;
@@ -38,6 +39,7 @@ public final class BlockEntry {
this._state = BlockState.HOT;
this._data = data;
this._retainHintCount = 0;
+ this._referenceCount = 1;
}
public BlockKey getKey() {
@@ -84,6 +86,14 @@ public final class BlockEntry {
return _pinCount > 0;
}
+ synchronized int addReference() { // registers one more referencing instance and returns the updated count
+ return ++_referenceCount;
+ }
+
+ synchronized int forget() { // drops one reference and returns the remaining count; no lower bound is enforced here, so the result can become negative
+ return --_referenceCount;
+ }
+
synchronized void setState(BlockState state) {
_state = state;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java
b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java
index 9cc108db5e..dbbd73d53a 100644
--- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java
+++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheScheduler.java
@@ -103,6 +103,14 @@ public interface OOCCacheScheduler {
BlockEntry putAndPinSourceBacked(BlockKey key, Object data, long size,
OOCIOHandler.SourceBlockDescriptor descriptor);
+ /**
+ * Notifies the cache that there is another reference to the same block key.
+ * This prevents forget(key) from removing the block from the cache: a block is only forgotten after all referencing instances have called forget(key).
+ * NOTE(review): OOCLRUCacheScheduler throws IllegalArgumentException when no entry exists for the key; confirm whether every implementation must do so.
+ * @param key the key of the block that gains an additional reference
+ */
+ void addReference(BlockKey key);
+
/**
* Forgets a block from the cache.
* @param key the associated key of the block
diff --git
a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java
b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java
index a204cd16db..cc7aa7bcd1 100644
--- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java
+++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java
@@ -19,6 +19,7 @@
package org.apache.sysds.runtime.ooc.cache;
+import org.apache.commons.lang3.mutable.MutableObject;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
@@ -291,6 +292,18 @@ public class OOCLRUCacheScheduler implements
OOCCacheScheduler {
return put(key, data, size, true, descriptor);
}
+ @Override
+ public void addReference(BlockKey key) { // increments the reference count of the entry stored under key
+ synchronized(this) { // lookup and increment both happen while holding this scheduler's monitor
+ BlockEntry entry = _cache.get(key);
+ if(entry == null)
+ entry = _evictionCache.get(key); // not in _cache: fall back to _evictionCache
+ if(entry == null) // key found in neither lookup: reported via IllegalArgumentException
+ throw new IllegalArgumentException("Could not
find requested block with key " + key);
+ entry.addReference();
+ }
+ }
+
private BlockEntry put(BlockKey key, Object data, long size, boolean
pin, OOCIOHandler.SourceBlockDescriptor descriptor) {
if (!this._running)
throw new IllegalStateException();
@@ -324,14 +337,34 @@ public class OOCLRUCacheScheduler implements
OOCCacheScheduler {
public void forget(BlockKey key) {
if (!this._running)
return;
+ final MutableObject<BlockEntry> mEntry = new MutableObject<>();
BlockEntry entry;
boolean shouldScheduleDeletion = false;
long cacheSizeDelta = 0;
synchronized(this) {
- entry = _cache.remove(key);
+ _cache.compute(key, (k, e) -> {
+ if(e == null)
+ return null;
+ if(e.forget() == 0) {
+ mEntry.setValue(e);
+ return null;
+ }
+ return e;
+ });
- if (entry == null)
- entry = _evictionCache.remove(key);
+ if (mEntry.getValue() == null) {
+ _evictionCache.compute(key, (k, e) -> {
+ if(e == null)
+ return null;
+ if(e.forget() == 0) {
+ mEntry.setValue(e);
+ return null;
+ }
+ return e;
+ });
+ }
+
+ entry = mEntry.getValue();
if (entry != null) {
synchronized(entry) {