This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 61a385fc9d [SYSTEMDS-2985] Fix nested list cache management
61a385fc9d is described below
commit 61a385fc9d82f74642bc0fe2392b05cf556537ee
Author: MaximilianTUB <[email protected]>
AuthorDate: Wed Dec 6 17:09:21 2023 +0100
[SYSTEMDS-2985] Fix nested list cache management
SystemDS previously did not support nested lists correctly,
since the data of CacheableData objects within nested lists
was always deleted after a function call.
Normally, there are rmvar statements after function calls to
remove all variables used within the function. To protect
CacheableData objects (e.g. matrices) from having their data
removed by the rmvar statements we use a cleanup-enabled flag.
This flag was not correctly set for variables that were within
a nested list. These commits fix this problem by flagging all
elements, also within nested lists.
Automated tests have been added to test the changes.
Closes #1956
---
.../runtime/controlprogram/ParForProgramBlock.java | 7 +-
.../controlprogram/context/ExecutionContext.java | 69 ++++------
.../instructions/cp/FunctionCallCPInstruction.java | 3 +-
.../sysds/runtime/instructions/cp/ListObject.java | 58 +++++++-
.../test/functions/caching/PinVariablesTest.java | 153 +++++++++++++++++++++
5 files changed, 242 insertions(+), 48 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
index 790a92de58..06a548a753 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
@@ -31,6 +31,7 @@ import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
+import java.util.Queue;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -652,7 +653,7 @@ public class ParForProgramBlock extends ForProgramBlock {
//preserve shared input/result variables of cleanup
ArrayList<String> varList = ec.getVarList();
- boolean[] varState = ec.pinVariables(varList);
+ Queue<Boolean> varState = ec.pinVariables(varList);
try
{
@@ -677,7 +678,7 @@ public class ParForProgramBlock extends ForProgramBlock {
catch(Exception ex) {
throw new DMLRuntimeException("PARFOR: Failed to
execute loop in parallel.",ex);
}
-
+
//reset state of shared input/result variables
ec.unpinVariables(varList, varState);
@@ -1198,7 +1199,7 @@ public class ParForProgramBlock extends ForProgramBlock {
}
}
- private void cleanupSharedVariables( ExecutionContext ec, boolean[]
varState ) {
+ private void cleanupSharedVariables( ExecutionContext ec,
Queue<Boolean> varState ) {
//TODO needs as precondition a systematic treatment of
persistent read information.
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index d98827a24e..0903b5abca 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -65,12 +65,14 @@ import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.utils.Statistics;
import java.util.ArrayList;
+import java.util.LinkedList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
+import java.util.Queue;
public class ExecutionContext {
protected static final Log LOG =
LogFactory.getLog(ExecutionContext.class.getName());
@@ -753,45 +755,28 @@ public class ExecutionContext {
* @param varList variable list
* @return indicator vector of old cleanup state of matrix objects
*/
- public boolean[] pinVariables(List<String> varList)
+ public Queue<Boolean> pinVariables(List<String> varList)
{
- //analyze list variables
- int nlist = 0;
- int nlistItems = 0;
- for( int i=0; i<varList.size(); i++ ) {
- Data dat = _variables.get(varList.get(i));
- if( dat instanceof ListObject ) {
- nlistItems +=
((ListObject)dat).getNumCacheableData();
- nlist++;
- }
- }
-
- //2-pass approach since multiple vars might refer to same
matrix object
- boolean[] varsState = new
boolean[varList.size()-nlist+nlistItems];
-
- //step 1) get current information
- for( int i=0, pos=0; i<varList.size(); i++ ) {
- Data dat = _variables.get(varList.get(i));
- if( dat instanceof CacheableData<?> )
- varsState[pos++] =
((CacheableData<?>)dat).isCleanupEnabled();
- else if( dat instanceof ListObject )
- for( Data dat2 : ((ListObject)dat).getData() )
- if( dat2 instanceof CacheableData<?> )
- varsState[pos++] =
((CacheableData<?>)dat2).isCleanupEnabled();
+ // step 1) get current cleanupFlag status information
+ Queue<Boolean> varsStates = new LinkedList<>();
+ for (String varName : varList) {
+ Data dat = _variables.get(varName);
+ if (dat instanceof CacheableData<?>)
+
varsStates.add(((CacheableData<?>)dat).isCleanupEnabled());
+ else if (dat instanceof ListObject)
+
varsStates.addAll(((ListObject)dat).getCleanupStates());
}
-
- //step 2) pin variables
- for( int i=0; i<varList.size(); i++ ) {
- Data dat = _variables.get(varList.get(i));
- if( dat instanceof CacheableData<?> )
+
+ // step 2) pin variables
+ for (String varName : varList) {
+ Data dat = _variables.get(varName);
+ if (dat instanceof CacheableData<?>)
((CacheableData<?>)dat).enableCleanup(false);
- else if( dat instanceof ListObject )
- for( Data dat2 : ((ListObject)dat).getData() )
- if( dat2 instanceof CacheableData<?> )
-
((CacheableData<?>)dat2).enableCleanup(false);
+ else if (dat instanceof ListObject)
+ ((ListObject)dat).enableCleanup(false);
}
- return varsState;
+ return varsStates;
}
/**
@@ -810,15 +795,13 @@ public class ExecutionContext {
* @param varList variable list
* @param varsState variable state
*/
- public void unpinVariables(List<String> varList, boolean[] varsState) {
- for( int i=0, pos=0; i<varList.size(); i++ ) {
- Data dat = _variables.get(varList.get(i));
- if( dat instanceof CacheableData<?> )
-
((CacheableData<?>)dat).enableCleanup(varsState[pos++]);
- else if( dat instanceof ListObject )
- for( Data dat2 : ((ListObject)dat).getData() )
- if( dat2 instanceof CacheableData<?> )
-
((CacheableData<?>)dat2).enableCleanup(varsState[pos++]);
+ public void unpinVariables(List<String> varList, Queue<Boolean>
varsState) {
+ for (String varName : varList) {
+ Data dat = _variables.get(varName);
+ if (dat instanceof CacheableData<?>)
+
((CacheableData<?>)dat).enableCleanup(varsState.poll());
+ else if (dat instanceof ListObject)
+ ((ListObject)dat).enableCleanup(varsState);
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
index 6aa4552a97..4d0649e1e9 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
@@ -24,6 +24,7 @@ import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
+import java.util.Queue;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -182,7 +183,7 @@ public class FunctionCallCPInstruction extends
CPInstruction {
// Pin the input variables so that they do not get deleted
// from pb's symbol table at the end of execution of function
- boolean[] pinStatus = ec.pinVariables(_boundInputNames);
+ Queue<Boolean> pinStatus = ec.pinVariables(_boundInputNames);
// Create a symbol table under a new execution context for the
function invocation,
// and copy the function arguments into the created table.
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
index 5c302fe80a..bf33a7e298 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
@@ -25,7 +25,9 @@ import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.LinkedList;
import java.util.List;
+import java.util.Queue;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.DataType;
@@ -96,7 +98,7 @@ public class ListObject extends Data implements
Externalizable {
for(int i=0; i<_data.size(); i++) {
Data dat = _data.get(i);
if( dat instanceof CacheableData<?> )
- _dataState[i] = ((CacheableData<?>)
dat).isCleanupEnabled();
+ _dataState[i] = ((CacheableData<?>)
dat).isCleanupEnabled();
}
}
@@ -497,4 +499,58 @@ public class ListObject extends Data implements
Externalizable {
_data.add(d);
}
}
+
+ /**
+ * Gets list of current cleanupFlag values recursively for every element
+ * in the list and in its sublists of type CacheableData. The order is
+ * as CacheableData elements are discovered during DFS. Elements that
+ * are not of type CacheableData are skipped.
+ *
+ * @return list of booleans containing the _cleanupFlag values.
+ */
+ public List<Boolean> getCleanupStates() {
+ List<Boolean> varsState = new LinkedList<>();
+ for (Data dat : this.getData()) {
+ if (dat instanceof CacheableData<?>)
+
varsState.add(((CacheableData<?>)dat).isCleanupEnabled());
+ else if (dat instanceof ListObject)
+
varsState.addAll(((ListObject)dat).getCleanupStates());
+ }
+ return varsState;
+ }
+
+ /**
+ * Sets the cleanupFlag values recursively for every element of type
+ * CacheableData in the list and in its sublists to the provided flag
+ * value.
+ *
+ * @param flag New value for every CacheableData element.
+ */
+ public void enableCleanup(boolean flag) {
+ for (Data dat : this.getData()) {
+ if (dat instanceof CacheableData<?>)
+ ((CacheableData<?>)dat).enableCleanup(flag);
+ if (dat instanceof ListObject)
+ ((ListObject)dat).enableCleanup(flag);
+ }
+ }
+
+ /**
+ * Sets the cleanupFlag values recursively for every element of type
+ * CacheableData in the list and in its sublists to the provided values
+ * in flags. The cleanupFlag value of the i-th CacheableData element
+ * in the list (counted in the order of DFS) is set to the i-th value
+ * in flags.
+ *
+ * @param flags Queue of values in the same order as its corresponding
+ * elements occur in DFS.
+ */
+ public void enableCleanup(Queue<Boolean> flags) {
+ for (Data dat : this.getData()) {
+ if (dat instanceof CacheableData<?>)
+
((CacheableData<?>)dat).enableCleanup(flags.poll());
+ else if (dat instanceof ListObject)
+ ((ListObject)dat).enableCleanup(flags);
+ }
+ }
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/caching/PinVariablesTest.java
b/src/test/java/org/apache/sysds/test/functions/caching/PinVariablesTest.java
new file mode 100644
index 0000000000..b88b036615
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/caching/PinVariablesTest.java
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.caching;
+
+import org.apache.oro.util.Cache;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.instructions.cp.ListObject;
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.instructions.cp.Data;
+
+import java.util.LinkedList;
+import java.util.Queue;
+import java.util.List;
+
+public class PinVariablesTest extends AutomatedTestBase {
+ private final static String TEST_NAME = "PinVariables";
+ private final static String TEST_DIR = "functions/caching/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
PinVariablesTest.class.getSimpleName() + "/";
+
+ @Override
+ public void setUp() {
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR,
TEST_NAME));
+ }
+
+ @Test
+ public void testPinNoLists() {
+ createMockDataAndCall(true, false, false);
+ }
+
+ @Test
+ public void testPinShallowLists() {
+ createMockDataAndCall(true, true, false);
+ }
+
+ @Test
+ public void testPinNestedLists() {
+ createMockDataAndCall(true, true, true);
+ }
+
+ private void createMockDataAndCall(boolean matrices, boolean list, boolean
nestedList) {
+ LocalVariableMap vars = new LocalVariableMap();
+ List<String> varList = new LinkedList<>();
+ Queue<Boolean> varStates = new LinkedList<>();
+
+ if (matrices) {
+ MatrixObject mat1 = new MatrixObject(Types.ValueType.FP64,
"SomeFile1");
+ mat1.enableCleanup(true);
+ MatrixObject mat2 = new MatrixObject(Types.ValueType.FP64,
"SomeFile2");
+ mat2.enableCleanup(true);
+ MatrixObject mat3 = new MatrixObject(Types.ValueType.FP64,
"SomeFile3");
+ mat3.enableCleanup(false);
+ vars.put("mat1", mat1);
+ vars.put("mat2", mat2);
+ vars.put("mat3", mat3);
+
+ varList.add("mat2");
+ varList.add("mat3");
+
+ varStates.add(true);
+ varStates.add(false);
+ }
+ if (list) {
+ MatrixObject mat4 = new MatrixObject(Types.ValueType.FP64,
"SomeFile4");
+ mat4.enableCleanup(true);
+ MatrixObject mat5 = new MatrixObject(Types.ValueType.FP64,
"SomeFile5");
+ mat5.enableCleanup(false);
+ List<Data> l1_data = new LinkedList<>();
+ l1_data.add(mat4);
+ l1_data.add(mat5);
+
+ if (nestedList) {
+ MatrixObject mat6 = new MatrixObject(Types.ValueType.FP64,
"SomeFile6");
+ mat6.enableCleanup(true);
+ List<Data> l2_data = new LinkedList<>();
+ l2_data.add(mat6);
+ ListObject l2 = new ListObject(l2_data);
+ l1_data.add(l2);
+ }
+
+ ListObject l1 = new ListObject(l1_data);
+ vars.put("l1", l1);
+
+ varList.add("l1");
+
+ // expected cleanup flags of the list elements in DFS order: mat4, mat5 (, mat6 if nested)
+ varStates.add(true);
+ varStates.add(false);
+ if (nestedList)
+ varStates.add(true);
+ }
+
+ ExecutionContext ec = new ExecutionContext(vars);
+
+ commonPinVariablesTest(ec, varList, varStates);
+ }
+
+ private void commonPinVariablesTest(ExecutionContext ec, List<String>
varList, Queue<Boolean> varStatesExp) {
+ Queue<Boolean> varStates = ec.pinVariables(varList);
+
+ // check returned cleanupEnabled flags
+ Assert.assertEquals(varStatesExp, varStates);
+
+ // assert updated cleanupEnabled flag to false
+ for (String varName : varList) {
+ Data dat = ec.getVariable(varName);
+
+ if (dat instanceof CacheableData<?>)
+ Assert.assertFalse(((CacheableData<?>)dat).isCleanupEnabled());
+ else if (dat instanceof ListObject) {
+ assertListFlagsDisabled((ListObject)dat);
+ }
+ }
+
+ ec.unpinVariables(varList, varStates);
+
+ // check returned flags after unpinVariables()
+ Queue<Boolean> varStates2 = ec.pinVariables(varList);
+ Assert.assertEquals(varStatesExp, varStates2);
+ }
+
+ private void assertListFlagsDisabled(ListObject l) {
+ for (Data dat : l.getData()) {
+ if (dat instanceof CacheableData<?>)
+ Assert.assertFalse(((CacheableData<?>)dat).isCleanupEnabled());
+ else if (dat instanceof ListObject)
+ assertListFlagsDisabled((ListObject)dat);
+ }
+ }
+}