Repository: systemml
Updated Branches:
  refs/heads/master c10e509a7 -> 97f684d2b


[SYSTEMML-2383] Fix robustness of rowIndexMin/Max for rows consisting entirely of NaNs

This patch fixes the unary aggregate operations rowIndexMin and
rowIndexMax to correctly handle rows consisting entirely of NaNs, which
previously caused runtime exceptions because the computed index was -1.
We now define the result as 1 in these cases, and we also explicitly
reject inputs with zero columns, which would otherwise fail similarly.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/97f684d2
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/97f684d2
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/97f684d2

Branch: refs/heads/master
Commit: 97f684d2b47030f2992853bf7b695b4c4e355d0f
Parents: c10e509
Author: Matthias Boehm <[email protected]>
Authored: Mon Jun 11 18:45:33 2018 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Mon Jun 11 18:45:54 2018 -0700

----------------------------------------------------------------------
 .../sysml/runtime/matrix/data/LibMatrixAgg.java | 22 ++++++++----
 .../functions/aggregate/AggregateNaNTest.java   | 38 ++++++++++++++++++++
 2 files changed, 54 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/97f684d2/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
index db73ce2..174f2a5 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
@@ -1318,7 +1318,7 @@ public class LibMatrixAgg
                        case MAX_INDEX: {
                                double init = Double.NEGATIVE_INFINITY;
                                if( ixFn instanceof ReduceCol ) //ROWINDEXMAX
-                                       d_uarimxx(a, c, n, init, (Builtin)vFn, 
rl, ru);
+                                       d_uarimax(a, c, n, init, (Builtin)vFn, 
rl, ru);
                                break;
                        }
                        case MIN_INDEX: {
@@ -1424,7 +1424,7 @@ public class LibMatrixAgg
                        case MAX_INDEX: {
                                double init = Double.NEGATIVE_INFINITY;
                                if( ixFn instanceof ReduceCol ) //ROWINDEXMAX
-                                       s_uarimxx(a, c, n, init, (Builtin)vFn, 
rl, ru);
+                                       s_uarimax(a, c, n, init, (Builtin)vFn, 
rl, ru);
                                break;
                        }
                        case MIN_INDEX: {
@@ -1899,7 +1899,9 @@ public class LibMatrixAgg
         * @param rl row lower index
         * @param ru row upper index
         */
-       private static void d_uarimxx( DenseBlock a, DenseBlock c, int n, 
double init, Builtin builtin, int rl, int ru ) {
+       private static void d_uarimax( DenseBlock a, DenseBlock c, int n, 
double init, Builtin builtin, int rl, int ru ) {
+               if( n <= 0 )
+                       throw new DMLRuntimeException("rowIndexMax undefined 
for ncol="+n);
                for( int i=rl; i<ru; i++ ) {
                        int maxindex = indexmax(a.values(i), a.pos(i), init, n, 
builtin);
                        c.set(i, 0, (double)maxindex + 1);
@@ -1919,6 +1921,8 @@ public class LibMatrixAgg
         * @param ru row upper index
         */
        private static void d_uarimin( DenseBlock a, DenseBlock c, int n, 
double init, Builtin builtin, int rl, int ru ) {
+               if( n <= 0 )
+                       throw new DMLRuntimeException("rowIndexMin undefined 
for ncol="+n);
                for( int i=rl; i<ru; i++ ) {
                        int minindex = indexmin(a.values(i), a.pos(i), init, n, 
builtin);
                        c.set(i, 0, (double)minindex + 1);
@@ -2533,7 +2537,9 @@ public class LibMatrixAgg
         * @param rl row lower index
         * @param ru row upper index
         */
-       private static void s_uarimxx( SparseBlock a, DenseBlock c, int n, 
double init, Builtin builtin, int rl, int ru ) {
+       private static void s_uarimax( SparseBlock a, DenseBlock c, int n, 
double init, Builtin builtin, int rl, int ru ) {
+               if( n <= 0 )
+                       throw new DMLRuntimeException("rowIndexMax is undefined 
for ncol="+n);
                for( int i=rl; i<ru; i++ ) {
                        if( !a.isEmpty(i) ) {
                                int apos = a.pos(i);
@@ -2574,6 +2580,8 @@ public class LibMatrixAgg
         * @param ru row upper index
         */
        private static void s_uarimin( SparseBlock a, DenseBlock c, int n, 
double init, Builtin builtin, int rl, int ru ) {
+               if( n <= 0 )
+                       throw new DMLRuntimeException("rowIndexMin is undefined 
for ncol="+n);
                for( int i=rl; i<ru; i++ ) {
                        if( !a.isEmpty(i) ) {
                                int apos = a.pos(i);
@@ -3085,7 +3093,8 @@ public class LibMatrixAgg
                        maxindex = (a[i]>=maxval) ? i-ai : maxindex;
                        maxval = (a[i]>=maxval) ? a[i] : maxval;
                }
-               return maxindex;
+               //note: robustness for all-NaN rows
+               return Math.max(maxindex, 0);
        }
 
        private static int indexmin( double[] a, int ai, final double init, 
final int len, Builtin aggop ) {
@@ -3095,7 +3104,8 @@ public class LibMatrixAgg
                        minindex = (a[i]<=minval) ? i-ai : minindex;
                        minval = (a[i]<=minval) ? a[i] : minval;
                }
-               return minindex;
+               //note: robustness for all-NaN rows
+               return Math.max(minindex, 0);
        }
 
        public static void countAgg( double[] a, int[] c, int[] aix, int ai, 
final int len ) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/97f684d2/src/test/java/org/apache/sysml/test/integration/functions/aggregate/AggregateNaNTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/aggregate/AggregateNaNTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/aggregate/AggregateNaNTest.java
index 4499214..cdfa3cb 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/aggregate/AggregateNaNTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/aggregate/AggregateNaNTest.java
@@ -22,7 +22,11 @@ package 
org.apache.sysml.test.integration.functions.aggregate;
 import org.junit.Assert;
 import org.junit.Test;
 
+import java.util.Arrays;
+
+import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysml.runtime.util.DataConverter;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
@@ -84,6 +88,26 @@ public class AggregateNaNTest extends AutomatedTestBase
                runNaNAggregateTest(3, true);
        }
        
+       @Test
+       public void testRowIndexMaxDenseNaN() {
+               runNaNRowIndexMxxTest("uarimax", false);
+       }
+       
+       @Test
+       public void testRowIndexMaxSparseNaN() {
+               runNaNRowIndexMxxTest("uarimax", true);
+       }
+       
+       @Test
+       public void testRowIndexMinDenseNaN() {
+               runNaNRowIndexMxxTest("uarimin", false);
+       }
+       
+       @Test
+       public void testRowIndexMinSparseNaN() {
+               runNaNRowIndexMxxTest("uarimin", true);
+       }
+       
        private void runNaNAggregateTest(int type, boolean sparse) {
                //generate input
                double sparsity = sparse ? sparsity1 : sparsity2;
@@ -101,4 +125,18 @@ public class AggregateNaNTest extends AutomatedTestBase
                
                Assert.assertTrue(Double.isNaN(ret));
        }
+       
+       private void runNaNRowIndexMxxTest(String type, boolean sparse) {
+               //generate input
+               double sparsity = sparse ? sparsity1 : sparsity2;
+               double[][] A = getRandomMatrix(rows, cols, -0.05, 1, sparsity, 
7);
+               Arrays.fill(A[7], Double.NaN);
+               MatrixBlock mb = DataConverter.convertToMatrixBlock(A);
+               
+               double ret = mb.aggregateUnaryOperations(
+               InstructionUtils.parseBasicAggregateUnaryOperator(type),
+               new MatrixBlock(), -1, -1, new MatrixIndexes(1, 1), 
true).getValue(7, 0);
+
+               Assert.assertTrue(ret == 1);
+       }
 }

Reply via email to