This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 58490d3acf [SYSTEMDS-3799] Fix parfor result merge (all combinations)
58490d3acf is described below
commit 58490d3acf70baab18db2c37b083f589a08288eb
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun Dec 1 17:46:18 2024 +0100
[SYSTEMDS-3799] Fix parfor result merge (all combinations)
This patch fixes issues in the parfor result merge that were recently
discovered through code coverage. The fixes cover the combinations of
different result merge implementations, dense/sparse inputs, and
dense/sparse compare blocks, and most importantly the += accumulation
into the output.
---
.../runtime/controlprogram/parfor/ResultMerge.java | 1 +
.../parfor/ResultMergeLocalFile.java | 7 +-
.../parfor/ResultMergeLocalMemory.java | 17 +--
.../controlprogram/parfor/ResultMergeMatrix.java | 118 ++++++++-------------
.../parfor/ResultMergeRemoteSparkWCompare.java | 8 +-
.../test/component/parfor/ResultMergeTest.java | 94 +++++++++-------
6 files changed, 122 insertions(+), 123 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMerge.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMerge.java
index b69ba96514..d441d02cc0 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMerge.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMerge.java
@@ -36,6 +36,7 @@ public abstract class ResultMerge<T extends CacheableData<?>>
implements Seriali
protected static final Log LOG =
LogFactory.getLog(ResultMerge.class.getName());
protected static final String NAME_SUFFIX = "_rm";
protected static final BinaryOperator PLUS =
InstructionUtils.parseBinaryOperator("+");
+ protected static final BinaryOperator MINUS =
InstructionUtils.parseBinaryOperator("-");
//inputs to result merge
protected T _output = null;
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalFile.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalFile.java
index ce683e455b..0ba0178657 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalFile.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalFile.java
@@ -258,7 +258,10 @@ public class ResultMergeLocalFile extends ResultMergeMatrix
DenseBlock compare =
DataConverter.convertToDenseBlock(mb, true);
for( String lname :
dir.list() ) {
MatrixBlock tmp
= LocalFileUtils.readMatrixBlockFromLocal( dir+"/"+lname );
-
mergeWithComp(mb, tmp, compare);
+ if( _isAccum )
+
mergeWithoutComp(mb, tmp, compare, appendOnly);
+ else
+
mergeWithComp(mb, tmp, compare);
}
//sort sparse due to
append-only
@@ -279,7 +282,7 @@ public class ResultMergeLocalFile extends ResultMergeMatrix
}
else {
MatrixBlock tmp = LocalFileUtils.readMatrixBlockFromLocal( dir+"/"+lname );
-
mergeWithoutComp(mb, tmp, appendOnly);
+
mergeWithoutComp(mb, tmp, null, appendOnly);
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalMemory.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalMemory.java
index a64fbc3492..a546c5edc2 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalMemory.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalMemory.java
@@ -73,7 +73,7 @@ public class ResultMergeLocalMemory extends ResultMergeMatrix
//create compare matrix if required (existing data in
result)
_compare = getCompareMatrix(outMB);
- if( _compare != null )
+ if( _compare != null || _isAccum )
outMBNew.copy(outMB);
//serial merge all inputs
@@ -90,7 +90,7 @@ public class ResultMergeLocalMemory extends ResultMergeMatrix
MatrixBlock inMB = in.acquireRead();
//core merge
- merge( outMBNew, inMB, appendOnly );
+ merge( outMBNew, inMB, _compare,
appendOnly );
//unpin and clear in-memory input_i
in.release();
@@ -169,7 +169,7 @@ public class ResultMergeLocalMemory extends
ResultMergeMatrix
//create compare matrix if required (existing
data in result)
_compare = getCompareMatrix(outMB);
- if( _compare != null )
+ if( _compare != null || _isAccum )
outMBNew.copy(outMB);
//parallel merge of all inputs
@@ -215,7 +215,7 @@ public class ResultMergeLocalMemory extends
ResultMergeMatrix
return moNew;
}
- private static DenseBlock getCompareMatrix( MatrixBlock output ) {
+ private DenseBlock getCompareMatrix( MatrixBlock output ) {
//create compare matrix only if required
if( !output.isEmptyBlock(false) )
return DataConverter.convertToDenseBlock(output, false);
@@ -253,11 +253,12 @@ public class ResultMergeLocalMemory extends
ResultMergeMatrix
*
* @param out output matrix block
* @param in input matrix block
+ * @param compare initialized output
* @param appendOnly ?
*/
- private void merge( MatrixBlock out, MatrixBlock in, boolean appendOnly
) {
- if( _compare == null )
- mergeWithoutComp(out, in, appendOnly, true);
+ private void merge( MatrixBlock out, MatrixBlock in, DenseBlock
compare, boolean appendOnly ) {
+ if( _compare == null || _isAccum )
+ mergeWithoutComp(out, in, _compare, appendOnly, true);
else
mergeWithComp(out, in, _compare);
}
@@ -304,7 +305,7 @@ public class ResultMergeLocalMemory extends
ResultMergeMatrix
LOG.trace("ResultMerge (local, in-memory):
Merge input "+_inMO.hashCode()+" (fname="+_inMO.getFileName()+")");
MatrixBlock inMB = _inMO.acquireRead(); //incl.
implicit read from HDFS
- merge( _outMB, inMB, false );
+ merge( _outMB, inMB, _compare, false );
_inMO.release();
_inMO.clearData();
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeMatrix.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeMatrix.java
index 181899d954..dadc172114 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeMatrix.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeMatrix.java
@@ -22,7 +22,6 @@ package org.apache.sysds.runtime.controlprogram.parfor;
import java.util.List;
import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.compress.utils.Util;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
@@ -58,17 +57,21 @@ public abstract class ResultMergeMatrix extends
ResultMerge<MatrixObject> {
super(out, in, outputFilename, accum);
}
- protected void mergeWithoutComp(MatrixBlock out, MatrixBlock in,
boolean appendOnly) {
- mergeWithoutComp(out, in, appendOnly, false);
+ protected void mergeWithoutComp(MatrixBlock out, MatrixBlock in,
DenseBlock compare, boolean appendOnly) {
+ mergeWithoutComp(out, in, compare, appendOnly, false);
}
- protected void mergeWithoutComp(MatrixBlock out, MatrixBlock in,
boolean appendOnly, boolean par) {
+ protected void mergeWithoutComp(MatrixBlock out, MatrixBlock in,
DenseBlock compare, boolean appendOnly, boolean par) {
// pass through to matrix block operations
- if(_isAccum)
+ if(_isAccum) {
out.binaryOperationsInPlace(PLUS, in);
+ //compare block used for compensation here
+ if( compare != null )
+ out.binaryOperationsInPlace(MINUS,
+ new
MatrixBlock(out.getNumRows(),out.getNumColumns(), compare));
+ }
else {
MatrixBlock out2 = out.merge(in, appendOnly, par);
-
if(out2 != out)
throw new DMLRuntimeException("Failed merge
need to allow returned MatrixBlock to be used");
}
@@ -90,18 +93,13 @@ public abstract class ResultMergeMatrix extends
ResultMerge<MatrixObject> {
// NaNs, since NaN != NaN, otherwise we would potentially
overwrite results
// * For the case of accumulation, we add out += (new-old) to
ensure correct results
// because all inputs have the old values replicated
- final int rows = in.getNumRows();
- final int cols = in.getNumColumns();
- if(in.isEmptyBlock(false)) {
- if(_isAccum)
- return; // nothing to do
+ int rows = in.getNumRows();
+ int cols = in.getNumColumns();
+ if(in.isEmptyBlock(false))
mergeWithCompEmpty(out, rows, cols, compare);
- }
- else if(in.isInSparseFormat() && _isAccum)
- mergeSparseAccumulative(out, in, rows, cols, compare);
else if(in.isInSparseFormat())
mergeSparse(out, in, rows, cols, compare);
- else // SPARSE/DENSE
+ else // DENSE
mergeGeneric(out, in, rows, cols, compare);
}
@@ -111,76 +109,44 @@ public abstract class ResultMergeMatrix extends
ResultMerge<MatrixObject> {
}
private void mergeWithCompEmptyRow(MatrixBlock out, int m, int n,
DenseBlock compare, int i) {
-
for(int j = 0; j < n; j++) {
final double valOld = compare.get(i, j);
- if(!Util.eq(0.0, valOld)) // NaN awareness
+ if(!equals(0.0, valOld)) // NaN awareness
out.set(i, j, 0);
}
}
- private void mergeSparseAccumulative(MatrixBlock out, MatrixBlock in,
int m, int n, DenseBlock compare) {
- final SparseBlock a = in.getSparseBlock();
- for(int i = 0; i < m; i++) {
- if(a.isEmpty(i))
- continue;
- final int apos = a.pos(i);
- final int alen = a.size(i) + apos;
- final int[] aix = a.indexes(i);
- final double[] aval = a.values(i);
- mergeSparseRowAccumulative(out, apos, alen, aix, aval,
compare, n, i);
- }
- }
-
- private void mergeSparseRowAccumulative(MatrixBlock out, int apos, int
alen, int[] aix, double[] aval,
- DenseBlock compare, int n, int i) {
- for(; apos < alen; apos++) { // inside
- final double valOld = compare.get(i, aix[apos]);
- final double valNew = aval[apos];
- if(!Util.eq(valNew, valOld)) { // NaN awareness
- double value = out.get(i, aix[apos]) + (valNew
- valOld);
- out.set(i, aix[apos], value);
- }
- }
- }
-
private void mergeSparse(MatrixBlock out, MatrixBlock in, int m, int n,
DenseBlock compare) {
final SparseBlock a = in.getSparseBlock();
for(int i = 0; i < m; i++) {
if(a.isEmpty(i))
mergeWithCompEmptyRow(out, m, n, compare, i);
else {
- final int apos = a.pos(i);
- final int alen = a.size(i) + apos;
- final int[] aix = a.indexes(i);
- final double[] aval = a.values(i);
- mergeSparseRow(out, apos, alen, aix, aval,
compare, n, i);
- }
- }
- }
-
- private void mergeSparseRow(MatrixBlock out, int apos, int alen, int[]
aix, double[] aval, DenseBlock compare, int n,
- int i) {
- int j = 0;
- for(; j < n && apos < alen; j++) { // inside
- final boolean aposValid = aix[apos] == j;
- final double valOld = compare.get(i, j);
- final double valNew = aix[apos] == j ? aval[apos] : 0.0;
- if(!Util.eq(valNew, valOld)) { // NaN awareness
- double value = !_isAccum ? valNew : (out.get(i,
j) + (valNew - valOld));
- out.set(i, j, value);
- }
- if(aposValid)
- apos++;
- }
- for(; j < n; j++) {
- final double valOld = compare.get(i, j);
- if(valOld != 0) {
- double value = (out.get(i, j) - valOld);
- out.set(i, j, value);
+ int apos = a.pos(i);
+ int alen = a.size(i) + apos;
+ int[] aix = a.indexes(i);
+ double[] avals = a.values(i);
+ int j = 0;
+ for(; j < n && apos < alen; j++) { // inside
+ boolean aposValid = (aix[apos] == j);
+ double valOld = compare.get(i, j);
+ double valNew = aposValid ? avals[apos]
: 0.0;
+ if(!equals(valNew, valOld)) { // NaN
awareness
+ double value = !_isAccum ?
valNew : (out.get(i, j) + (valNew - valOld));
+ out.set(i, j, value);
+ }
+ if(aposValid)
+ apos++;
+ }
+ for(; j < n; j++) {
+ double valOld = compare.get(i, j);
+ if(valOld != 0) {
+ double value = (out.get(i, j) -
valOld);
+ out.set(i, j, value);
+ }
+ }
}
}
-
}
private void mergeGeneric(MatrixBlock out, MatrixBlock in, int m, int
n, DenseBlock compare) {
@@ -188,13 +154,17 @@ public abstract class ResultMergeMatrix extends
ResultMerge<MatrixObject> {
for(int j = 0; j < n; j++) {
final double valOld = compare.get(i, j);
final double valNew = in.get(i, j); // input
value
- if(!Util.eq(valNew, valOld)) { // NaN awareness
- double value = !_isAccum ? valNew :
(out.get(i, j) + (valNew - valOld));
- out.set(i, j, value);
+ if(!equals(valNew, valOld)) { // NaN awareness
+ out.set(i, j, valNew);
}
}
}
}
+
+ private boolean equals(double valNew, double valOld) {
+ return (valNew == valOld && !Double.isNaN(valNew)) //for
changed values
+ || (Double.isNaN(valNew) && Double.isNaN(valOld));
//NaN awareness
+ }
protected long computeNonZeros(MatrixObject out, List<MatrixObject> in)
{
// sum of nnz of input (worker result) - output var existing nnz
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeRemoteSparkWCompare.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeRemoteSparkWCompare.java
index 6b8d424b05..da98115c82 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeRemoteSparkWCompare.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeRemoteSparkWCompare.java
@@ -52,8 +52,12 @@ public class ResultMergeRemoteSparkWCompare extends
ResultMergeMatrix implements
//merge all blocks into compare block
MatrixBlock out = new MatrixBlock(cin);
- while( din.hasNext() )
- mergeWithComp(out, din.next(), compare);
+ while( din.hasNext() ) {
+ if( _isAccum )
+ mergeWithoutComp(out, din.next(), compare,
false);
+ else
+ mergeWithComp(out, din.next(), compare);
+ }
//create output tuple
return new Tuple2<>(new MatrixIndexes(ixin), out);
diff --git
a/src/test/java/org/apache/sysds/test/component/parfor/ResultMergeTest.java
b/src/test/java/org/apache/sysds/test/component/parfor/ResultMergeTest.java
index 23e94e809c..e2ca770795 100644
--- a/src/test/java/org/apache/sysds/test/component/parfor/ResultMergeTest.java
+++ b/src/test/java/org/apache/sysds/test/component/parfor/ResultMergeTest.java
@@ -30,7 +30,9 @@ import
org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysds.runtime.controlprogram.parfor.ResultMerge;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
@@ -48,38 +50,43 @@ public class ResultMergeTest extends AutomatedTestBase{
}
@Test
- public void testLocalMem() {
- testResultMergeAll(PResultMerge.LOCAL_MEM);
+ public void testLocalMemDenseCompare() {
+ testResultMergeAll(PResultMerge.LOCAL_MEM, false);
}
@Test
- public void testLocalFile() {
- testResultMergeAll(PResultMerge.LOCAL_FILE);
+ public void testLocalFileDenseCompare() {
+ testResultMergeAll(PResultMerge.LOCAL_FILE, false);
}
@Test
- public void testLocalAutomatic() {
- testResultMergeAll(PResultMerge.LOCAL_AUTOMATIC);
+ public void testLocalAutomaticDenseCompare() {
+ testResultMergeAll(PResultMerge.LOCAL_AUTOMATIC, false);
}
- private void testResultMergeAll(PResultMerge mtype) {
- testResultMerge(false, false, false, false, mtype);
- testResultMerge(false, true, false, false, mtype);
- testResultMerge(true, false, false, false, mtype);
- testResultMerge(false, false, true, false, mtype);
- if( mtype != PResultMerge.LOCAL_FILE ) //FIXME
- testResultMerge(false, true, true, false, mtype);
- testResultMerge(true, false, true, false, mtype);
- //testResultMerge(true, true, false, false, mtype); invalid
+ @Test
+ public void testLocalMemSparseCompare() {
+ testResultMergeAll(PResultMerge.LOCAL_MEM, true);
+ }
+
+ @Test
+ public void testLocalFileSparseCompare() {
+ testResultMergeAll(PResultMerge.LOCAL_FILE, true);
+ }
+
+ @Test
+ public void testLocalAutomaticSparseCompare() {
+ testResultMergeAll(PResultMerge.LOCAL_AUTOMATIC, true);
+ }
- /* FIXME sparse compare
- testResultMerge(false, false, false, true, mtype);
- testResultMerge(false, true, false, true, mtype);
- testResultMerge(true, false, false, true, mtype);
- testResultMerge(false, false, true, true, mtype);
- testResultMerge(false, true, true, true, mtype);
- testResultMerge(true, false, true, true, mtype);
- */
+ private void testResultMergeAll(PResultMerge mtype, boolean
sparseCompare) {
+ testResultMerge(false, false, false, sparseCompare, mtype);
+ testResultMerge(false, true, false, sparseCompare, mtype);
+ testResultMerge(true, false, false, sparseCompare, mtype);
+ testResultMerge(false, false, true, sparseCompare, mtype);
+ testResultMerge(false, true, true, sparseCompare, mtype);
+ testResultMerge(true, false, true, sparseCompare, mtype);
+ //testResultMerge(true, true, false, false, mtype); invalid
}
private void testResultMerge(boolean par, boolean accum, boolean
compare, boolean sparseCompare, PResultMerge mtype) {
@@ -87,19 +94,29 @@ public class ResultMergeTest extends AutomatedTestBase{
loadTestConfiguration(getTestConfiguration(TEST_NAME));
//create input and output objects
- MatrixBlock A = MatrixBlock.randOperations(1200, 1100,
0.1);
- CacheableData<?> Cobj = compare ?
- toMatrixObject(new MatrixBlock(1200,1100,1d),
output("C")) :
- toMatrixObject(new MatrixBlock(1200,1100,true),
output("C"));
+ int rows = 1200, cols = 1100;
+ MatrixBlock A =
MatrixBlock.randOperations(rows,cols,0.1,0,1,"uniform",7);
+ A.checkSparseRows();
MatrixBlock rest = compare ?
- new MatrixBlock(400,1100,sparseCompare?0.2:1.0)
: //constant
- new MatrixBlock(400,1100,true); //empty (also
sparse)
- MatrixObject[] Bobj = new MatrixObject[4];
- Bobj[0] =
toMatrixObject(A.slice(0,399).rbind(rest).rbind(rest), output("B0"));
- Bobj[1] =
toMatrixObject(rest.rbind(A.slice(400,799)).rbind(rest), output("B1"));
- Bobj[2] =
toMatrixObject(rest.rbind(rest).rbind(A.slice(800,1199)), output("B2"));
- Bobj[3] = toMatrixObject(rest.rbind(rest).rbind(rest),
output("B3"));
-
+
MatrixBlock.randOperations(rows/3,cols,sparseCompare?0.2:1.0,1,1,"uniform",3):
//constant
+ new MatrixBlock(rows/3,cols,true); //empty
(also sparse)
+ CacheableData<?> Cobj = compare ?
+ toMatrixObject(rest.rbind(rest).rbind(rest),
output("C")) :
+ toMatrixObject(new MatrixBlock(rows,cols,true),
output("C"));
+ MatrixObject[] Bobj = new MatrixObject[3];
+ Bobj[0] =
toMatrixObject(A.slice(0,rows/3-1).rbind(rest).rbind(rest), output("B0"));
+ Bobj[1] =
toMatrixObject(rest.rbind(A.slice(rows/3,2*rows/3-1)).rbind(rest),
output("B1"));
+ Bobj[2] =
toMatrixObject(rest.rbind(rest).rbind(A.slice(2*rows/3,rows-1)), output("B2"));
+ BinaryOperator PLUS =
InstructionUtils.parseBinaryOperator("+");
+ BinaryOperator MINUS =
InstructionUtils.parseBinaryOperator("-");
+ MatrixBlock aggAll =
((MatrixBlock)Cobj.acquireReadAndRelease())
+ .binaryOperations(PLUS,
Bobj[0].acquireReadAndRelease())
+ .binaryOperations(PLUS,
Bobj[1].acquireReadAndRelease())
+ .binaryOperations(PLUS,
Bobj[2].acquireReadAndRelease())
+ .binaryOperations(MINUS,
(MatrixBlock)Cobj.acquireReadAndRelease())
+ .binaryOperations(MINUS,
(MatrixBlock)Cobj.acquireReadAndRelease())
+ .binaryOperations(MINUS,
(MatrixBlock)Cobj.acquireReadAndRelease());
+
//create result merge
ExecutionContext ec =
ExecutionContextFactory.createContext();
int numThreads = 4;
@@ -113,8 +130,11 @@ public class ResultMergeTest extends AutomatedTestBase{
Cobj = rm.executeSerialMerge();
//check results
- TestUtils.compareMatrices(A,
- (MatrixBlock)Cobj.acquireReadAndRelease(),
1e-14);
+ MatrixBlock out =
(MatrixBlock)Cobj.acquireReadAndRelease();
+ if(!accum)
+ TestUtils.compareMatrices(A, out, 1e-14);
+ else
+ TestUtils.compareMatrices(aggAll, out, 1e-14);
}
catch(Exception e){
e.printStackTrace();