Repository: systemml Updated Branches: refs/heads/master c10e509a7 -> 97f684d2b
[SYSTEMML-2383] Fix robustness rowIndexMin/Max for entire rows of NaNs This patch fixes the unary aggregate operations rowIndexMin and rowIndexMax for correct handling of rows of all NaNs which so far resulted in runtime exceptions due to resulting index of -1. We now define the result as 1 in these cases and also properly reject inputs with zero columns which would similarly fail. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/97f684d2 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/97f684d2 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/97f684d2 Branch: refs/heads/master Commit: 97f684d2b47030f2992853bf7b695b4c4e355d0f Parents: c10e509 Author: Matthias Boehm <[email protected]> Authored: Mon Jun 11 18:45:33 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Mon Jun 11 18:45:54 2018 -0700 ---------------------------------------------------------------------- .../sysml/runtime/matrix/data/LibMatrixAgg.java | 22 ++++++++---- .../functions/aggregate/AggregateNaNTest.java | 38 ++++++++++++++++++++ 2 files changed, 54 insertions(+), 6 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/97f684d2/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java index db73ce2..174f2a5 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java @@ -1318,7 +1318,7 @@ public class LibMatrixAgg case MAX_INDEX: { double init = Double.NEGATIVE_INFINITY; if( ixFn instanceof ReduceCol ) //ROWINDEXMAX - d_uarimxx(a, c, n, init, (Builtin)vFn, rl, ru); + d_uarimax(a, c, n, init, (Builtin)vFn, rl, ru); break; } case MIN_INDEX: { @@ -1424,7 +1424,7 @@ public class LibMatrixAgg case MAX_INDEX: { double init = Double.NEGATIVE_INFINITY; if( ixFn instanceof ReduceCol ) //ROWINDEXMAX - s_uarimxx(a, c, n, init, (Builtin)vFn, rl, ru); + s_uarimax(a, c, n, init, (Builtin)vFn, rl, ru); break; } case MIN_INDEX: { @@ -1899,7 +1899,9 @@ public class LibMatrixAgg * @param rl row lower index * @param ru row upper index */ - private static void d_uarimxx( DenseBlock a, DenseBlock c, int n, double init, Builtin builtin, int rl, int ru ) { + private static void d_uarimax( DenseBlock a, DenseBlock c, int n, double init, Builtin builtin, int rl, int ru ) { + if( n <= 0 ) + throw new DMLRuntimeException("rowIndexMax undefined for ncol="+n); for( int i=rl; i<ru; i++ ) { int maxindex = indexmax(a.values(i), a.pos(i), init, n, builtin); c.set(i, 0, (double)maxindex + 1); @@ -1919,6 +1921,8 @@ public class LibMatrixAgg * @param ru row upper index */ private static void d_uarimin( DenseBlock a, DenseBlock c, int n, double init, Builtin builtin, int rl, int ru ) { + if( n <= 0 ) + throw new DMLRuntimeException("rowIndexMin undefined for ncol="+n); for( int i=rl; i<ru; i++ ) { int minindex = indexmin(a.values(i), a.pos(i), init, n, builtin); c.set(i, 0, (double)minindex + 1); @@ -2533,7 +2537,9 @@ public class LibMatrixAgg * @param rl row lower index * @param ru row upper index */ - private static void s_uarimxx( SparseBlock a, DenseBlock c, int n, double init, Builtin builtin, int rl, int ru ) { + private static void s_uarimax( SparseBlock a, DenseBlock c, int n, double init, Builtin builtin, int rl, int ru ) { + if( n <= 0 ) + throw new DMLRuntimeException("rowIndexMax is undefined for ncol="+n); for( int i=rl; i<ru; i++ ) { if( !a.isEmpty(i) ) { int apos = a.pos(i); @@ -2574,6 +2580,8 @@ public class LibMatrixAgg * @param ru row upper index */ private static void s_uarimin( SparseBlock a, DenseBlock c, int n, double init, Builtin builtin, int rl, int ru ) { + if( n <= 0 ) + throw new DMLRuntimeException("rowIndexMin is undefined for ncol="+n); for( int i=rl; i<ru; i++ ) { if( !a.isEmpty(i) ) { int apos = a.pos(i); @@ -3085,7 +3093,8 @@ public class LibMatrixAgg maxindex = (a[i]>=maxval) ? i-ai : maxindex; maxval = (a[i]>=maxval) ? a[i] : maxval; } - return maxindex; + //note: robustness for all-NaN rows + return Math.max(maxindex, 0); } private static int indexmin( double[] a, int ai, final double init, final int len, Builtin aggop ) { @@ -3095,7 +3104,8 @@ public class LibMatrixAgg minindex = (a[i]<=minval) ? i-ai : minindex; minval = (a[i]<=minval) ? a[i] : minval; } - return minindex; + //note: robustness for all-NaN rows + return Math.max(minindex, 0); } public static void countAgg( double[] a, int[] c, int[] aix, int ai, final int len ) { http://git-wip-us.apache.org/repos/asf/systemml/blob/97f684d2/src/test/java/org/apache/sysml/test/integration/functions/aggregate/AggregateNaNTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/aggregate/AggregateNaNTest.java b/src/test/java/org/apache/sysml/test/integration/functions/aggregate/AggregateNaNTest.java index 4499214..cdfa3cb 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/aggregate/AggregateNaNTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/aggregate/AggregateNaNTest.java @@ -22,7 +22,11 @@ package org.apache.sysml.test.integration.functions.aggregate; import org.junit.Assert; import org.junit.Test; +import java.util.Arrays; + +import org.apache.sysml.runtime.instructions.InstructionUtils; import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.matrix.data.MatrixIndexes; import org.apache.sysml.runtime.util.DataConverter; import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.integration.TestConfiguration; @@ -84,6 +88,26 @@ public class AggregateNaNTest extends AutomatedTestBase runNaNAggregateTest(3, true); } + @Test + public void testRowIndexMaxDenseNaN() { + runNaNRowIndexMxxTest("uarimax", false); + } + + @Test + public void testRowIndexMaxSparseNaN() { + runNaNRowIndexMxxTest("uarimax", true); + } + + @Test + public void testRowIndexMinDenseNaN() { + runNaNRowIndexMxxTest("uarimin", false); + } + + @Test + public void testRowIndexMinSparseNaN() { + runNaNRowIndexMxxTest("uarimin", true); + } + private void runNaNAggregateTest(int type, boolean sparse) { //generate input double sparsity = sparse ? sparsity1 : sparsity2; @@ -101,4 +125,18 @@ public class AggregateNaNTest extends AutomatedTestBase Assert.assertTrue(Double.isNaN(ret)); } + + private void runNaNRowIndexMxxTest(String type, boolean sparse) { + //generate input + double sparsity = sparse ? sparsity1 : sparsity2; + double[][] A = getRandomMatrix(rows, cols, -0.05, 1, sparsity, 7); + Arrays.fill(A[7], Double.NaN); + MatrixBlock mb = DataConverter.convertToMatrixBlock(A); + + double ret = mb.aggregateUnaryOperations( + InstructionUtils.parseBasicAggregateUnaryOperator(type), + new MatrixBlock(), -1, -1, new MatrixIndexes(1, 1), true).getValue(7, 0); + + Assert.assertTrue(ret == 1); + } }
