[SYSTEMML-2240] Fix min/max correctness over inputs w/ NaNs, incl tests

We generally have limited support for NaNs in SystemML because any
sparse operation would be invalid, since NaN*0=NaN. However, while sum
and sumSq compute correct results, this is not the case for min and max
due to the use of Java's basic comparison operators. Instead, we now
consistently use the NaN-aware Math library functions, which so far were
only used by a subset of codegen primitives.

Furthermore, this patch also includes a cleanup of the Builtin function
object that removes the custom execute2 method, which is no longer
necessary now that the exception hierarchy has been simplified.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ca615ca1
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ca615ca1
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ca615ca1

Branch: refs/heads/master
Commit: ca615ca10afdd793487c28afb9d5d937a9935eb4
Parents: 6fa83d3
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Fri Apr 13 18:00:55 2018 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Fri Apr 13 18:00:55 2018 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/cplan/CNodeBinary.java   |   4 +-
 .../sysml/runtime/compress/ColGroupDDC.java     |   2 +-
 .../sysml/runtime/compress/ColGroupOLE.java     |   2 +-
 .../sysml/runtime/compress/ColGroupOffset.java  |   2 +-
 .../sysml/runtime/compress/ColGroupRLE.java     |   2 +-
 .../sysml/runtime/compress/ColGroupValue.java   |  12 +-
 .../runtime/compress/CompressedMatrixBlock.java |   2 +-
 .../sysml/runtime/functionobjects/Builtin.java  | 160 ++++++++-----------
 .../sysml/runtime/matrix/data/LibMatrixAgg.java |  26 +--
 .../functions/aggregate/AggregateNaNTest.java   | 104 ++++++++++++
 .../functions/aggregate/ZPackageSuite.java      |   1 +
 11 files changed, 194 insertions(+), 123 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/ca615ca1/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java 
b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
index d1343c2..4aee712 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
@@ -197,9 +197,9 @@ public class CNodeBinary extends CNode
                                        return "    double %TMP% = (%IN1% != 
%IN2%) ? 1 : 0;\n";
                                
                                case MIN:
-                                       return "    double %TMP% = (%IN1% <= 
%IN2%) ? %IN1% : %IN2%;\n";
+                                       return "    double %TMP% = 
Math.min(%IN1%, %IN2%);\n";
                                case MAX:
-                                       return "    double %TMP% = (%IN1% >= 
%IN2%) ? %IN1% : %IN2%;\n";
+                                       return "    double %TMP% = 
Math.max(%IN1%, %IN2%);\n";
                                case LOG:
                                        return "    double %TMP% = 
Math.log(%IN1%)/Math.log(%IN2%);\n";
                                case LOG_NZ:

http://git-wip-us.apache.org/repos/asf/systemml/blob/ca615ca1/src/main/java/org/apache/sysml/runtime/compress/ColGroupDDC.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/compress/ColGroupDDC.java 
b/src/main/java/org/apache/sysml/runtime/compress/ColGroupDDC.java
index 40eef8e..24105ab 100644
--- a/src/main/java/org/apache/sysml/runtime/compress/ColGroupDDC.java
+++ b/src/main/java/org/apache/sysml/runtime/compress/ColGroupDDC.java
@@ -197,7 +197,7 @@ public abstract class ColGroupDDC extends ColGroupValue
                
                for( int i=rl; i<ru; i++ )
                        for( int j=0; j<ncol; j++ )
-                               c[i] = builtin.execute2(c[i], getData(i, j));
+                               c[i] = builtin.execute(c[i], getData(i, j));
        }
        
        protected final void postScaling(double[] vals, double[] c) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/ca615ca1/src/main/java/org/apache/sysml/runtime/compress/ColGroupOLE.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/compress/ColGroupOLE.java 
b/src/main/java/org/apache/sysml/runtime/compress/ColGroupOLE.java
index 011e3f0..4684a43 100644
--- a/src/main/java/org/apache/sysml/runtime/compress/ColGroupOLE.java
+++ b/src/main/java/org/apache/sysml/runtime/compress/ColGroupOLE.java
@@ -642,7 +642,7 @@ public class ColGroupOLE extends ColGroupOffset
                                slen = _data[boff+bix];
                                for (int i = 1; i <= slen; i++) {
                                        int rix = off + _data[boff+bix + i];
-                                       c[rix] = builtin.execute2(c[rix], val);
+                                       c[rix] = builtin.execute(c[rix], val);
                                }
                        }
                }

http://git-wip-us.apache.org/repos/asf/systemml/blob/ca615ca1/src/main/java/org/apache/sysml/runtime/compress/ColGroupOffset.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/compress/ColGroupOffset.java 
b/src/main/java/org/apache/sysml/runtime/compress/ColGroupOffset.java
index 2aa23bf..606d8cc 100644
--- a/src/main/java/org/apache/sysml/runtime/compress/ColGroupOffset.java
+++ b/src/main/java/org/apache/sysml/runtime/compress/ColGroupOffset.java
@@ -244,7 +244,7 @@ public abstract class ColGroupOffset extends ColGroupValue
                double val = (builtin.getBuiltinCode()==BuiltinCode.MAX) ?
                        Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY;
                for( int i = 0; i < numCols; i++ )
-                       val = builtin.execute2(val, _values[valOff+i]);
+                       val = builtin.execute(val, _values[valOff+i]);
                
                return val;
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/ca615ca1/src/main/java/org/apache/sysml/runtime/compress/ColGroupRLE.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/compress/ColGroupRLE.java 
b/src/main/java/org/apache/sysml/runtime/compress/ColGroupRLE.java
index 865e09e..dc2b8b7 100644
--- a/src/main/java/org/apache/sysml/runtime/compress/ColGroupRLE.java
+++ b/src/main/java/org/apache/sysml/runtime/compress/ColGroupRLE.java
@@ -667,7 +667,7 @@ public class ColGroupRLE extends ColGroupOffset
                                curRunStartOff = curRunEnd + _data[boff+bix];
                                curRunEnd = curRunStartOff + _data[boff+bix+1];
                                for (int rix=curRunStartOff; rix<curRunEnd && 
rix<ru; rix++)
-                                       c[rix] = builtin.execute2(c[rix], val);
+                                       c[rix] = builtin.execute(c[rix], val);
                        }
                }
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/ca615ca1/src/main/java/org/apache/sysml/runtime/compress/ColGroupValue.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/compress/ColGroupValue.java 
b/src/main/java/org/apache/sysml/runtime/compress/ColGroupValue.java
index a09512c..e3a4184 100644
--- a/src/main/java/org/apache/sysml/runtime/compress/ColGroupValue.java
+++ b/src/main/java/org/apache/sysml/runtime/compress/ColGroupValue.java
@@ -278,17 +278,17 @@ public abstract class ColGroupValue extends ColGroup
                double val = (builtin.getBuiltinCode()==BuiltinCode.MAX) ?
                        Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY;
                if( zeros )
-                       val = builtin.execute2(val, 0);
+                       val = builtin.execute(val, 0);
                
                //iterate over all values only
                final int numVals = getNumValues();
-               final int numCols = getNumCols();               
+               final int numCols = getNumCols();
                for (int k = 0; k < numVals; k++)
                        for( int j=0, valOff = k*numCols; j<numCols; j++ )
-                               val = builtin.execute2(val, _values[ valOff+j 
]);
+                               val = builtin.execute(val, _values[ valOff+j ]);
                
                //compute new partial aggregate
-               val = builtin.execute2(val, result.quickGetValue(0, 0));
+               val = builtin.execute(val, result.quickGetValue(0, 0));
                result.quickSetValue(0, 0, val);
        }
        
@@ -310,13 +310,13 @@ public abstract class ColGroupValue extends ColGroup
                        Double.NEGATIVE_INFINITY : Double.POSITIVE_INFINITY);
                if( zeros ) {
                        for( int j = 0; j < numCols; j++ )
-                               vals[j] = builtin.execute2(vals[j], 0);
+                               vals[j] = builtin.execute(vals[j], 0);
                }
                
                //iterate over all values only
                for (int k = 0; k < numVals; k++) 
                        for( int j=0, valOff=k*numCols; j<numCols; j++ )
-                               vals[j] = builtin.execute2(vals[j], _values[ 
valOff+j ]);
+                               vals[j] = builtin.execute(vals[j], _values[ 
valOff+j ]);
                
                //copy results to output
                for( int j=0; j<numCols; j++ )

http://git-wip-us.apache.org/repos/asf/systemml/blob/ca615ca1/src/main/java/org/apache/sysml/runtime/compress/CompressedMatrixBlock.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/compress/CompressedMatrixBlock.java 
b/src/main/java/org/apache/sysml/runtime/compress/CompressedMatrixBlock.java
index 2051f34..d1df033 100644
--- a/src/main/java/org/apache/sysml/runtime/compress/CompressedMatrixBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/compress/CompressedMatrixBlock.java
@@ -1301,7 +1301,7 @@ public class CompressedMatrixBlock extends MatrixBlock 
implements Externalizable
                        Builtin builtin = (Builtin)op.aggOp.increOp.fn;
                        for( int i=0; i<rlen; i++ )
                                if( rnnz[i] < clen )
-                                       ret.quickSetValue(i, 0, 
builtin.execute2(ret.quickGetValue(i, 0), 0));
+                                       ret.quickSetValue(i, 0, 
builtin.execute(ret.quickGetValue(i, 0), 0));
                }
                
                //drop correction if necessary

http://git-wip-us.apache.org/repos/asf/systemml/blob/ca615ca1/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java 
b/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java
index e96a195..6a04a87 100644
--- a/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java
+++ b/src/main/java/org/apache/sysml/runtime/functionobjects/Builtin.java
@@ -341,117 +341,83 @@ public class Builtin extends ValueFunction
 
        /*
         * Builtin functions with two inputs
-        */     
+        */
        @Override
        public double execute (double in1, double in2) {
                switch(bFunc) {
-               
-               /*
-                * Arithmetic relational operators (==, !=, <=, >=) must be 
instead of
-                * <code>Double.compare()</code> due to the inconsistencies in 
the way
-                * NaN and -0.0 are handled. The behavior of methods in
-                * <code>Double</code> class are designed mainly to make Java
-                * collections work properly. For more details, see the help for
-                * <code>Double.equals()</code> and 
<code>Double.comapreTo()</code>.
-                */
-               case MAX:
-               case CUMMAX:
-                       //return (Double.compare(in1, in2) >= 0 ? in1 : in2);
-                       return (in1 >= in2 ? in1 : in2);
-               case MIN:
-               case CUMMIN:
-                       //return (Double.compare(in1, in2) <= 0 ? in1 : in2);
-                       return (in1 <= in2 ? in1 : in2);
-                       
-                       // *** HACK ALERT *** HACK ALERT *** HACK ALERT ***
-                       // rowIndexMax() and its siblings require comparing 
four values, but
-                       // the aggregation API only allows two values. So the 
execute()
-                       // method receives as its argument the two cell values 
to be
-                       // compared and performs just the value part of the 
comparison. We
-                       // return an integer cast down to a double, since the 
aggregation
-                       // API doesn't have any way to return anything but a 
double. The
-                       // integer returned takes on three posssible values: //
-                       // .     0 => keep the index associated with in1 //
-                       // .     1 => use the index associated with in2 //
-                       // .     2 => use whichever index is higher (tie in 
value) //
-               case MAXINDEX:
-                       if (in1 == in2) {
-                               return 2;
-                       } else if (in1 > in2) {
-                               return 1;
-                       } else { // in1 < in2
-                               return 0;
-                       }
-               case MININDEX:
-                       if (in1 == in2) {
-                               return 2;
-                       } else if (in1 < in2) {
-                               return 1;
-                       } else { // in1 > in2
-                               return 0;
-                       }
-                       // *** END HACK ***
-               case LOG:
-                       //faster in Math
-                       return (Math.log(in1)/Math.log(in2)); 
-               case LOG_NZ:
-                       //faster in Math
-                       return (in1==0) ? 0 : (Math.log(in1)/Math.log(in2));
-               default:
-                       throw new DMLRuntimeException("Builtin.execute(): 
Unknown operation: " + bFunc);
-               }
-       }
-       
-       /**
-        * Simplified version without exception handling
-        * 
-        * @param in1 double 1
-        * @param in2 double 2
-        * @return result
-        */
-       public double execute2(double in1, double in2) 
-       {
-               switch(bFunc) {
+                       /*
+                        * Arithmetic relational operators (==, !=, <=, >=) 
must be used instead of
+                        * <code>Double.compare()</code> due to the 
inconsistencies in the way
+                        * NaN and -0.0 are handled. The behavior of methods in
+                        * <code>Double</code> class are designed mainly to 
make Java
+                        * collections work properly. For more details, see the 
help for
+                        * <code>Double.equals()</code> and 
<code>Double.compareTo()</code>.
+                        */
                        case MAX:
                        case CUMMAX:
-                               //return (Double.compare(in1, in2) >= 0 ? in1 : 
in2); 
-                               return (in1 >= in2 ? in1 : in2);
+                               return Math.max(in1, in2);
                        case MIN:
                        case CUMMIN:
-                               //return (Double.compare(in1, in2) <= 0 ? in1 : 
in2); 
-                               return (in1 <= in2 ? in1 : in2);
-                       case MAXINDEX: 
-                               return (in1 >= in2) ? 1 : 0;
-                       case MININDEX: 
-                               return (in1 <= in2) ? 1 : 0;
+                               return Math.min(in1, in2);
+                               
+                               // *** HACK ALERT *** HACK ALERT *** HACK ALERT 
***
+                               // rowIndexMax() and its siblings require 
comparing four values, but
+                               // the aggregation API only allows two values. 
So the execute()
+                               // method receives as its argument the two cell 
values to be
+                               // compared and performs just the value part of 
the comparison. We
+                               // return an integer cast down to a double, 
since the aggregation
+                               // API doesn't have any way to return anything 
but a double. The
+                               // integer returned takes on three possible 
values: //
+                               // .     0 => keep the index associated with 
in1 //
+                               // .     1 => use the index associated with in2 
//
+                               // .     2 => use whichever index is higher 
(tie in value) //
+                       case MAXINDEX:
+                               if (in1 == in2) {
+                                       return 2;
+                               } else if (in1 > in2) {
+                                       return 1;
+                               } else { // in1 < in2
+                                       return 0;
+                               }
+                       case MININDEX:
+                               if (in1 == in2) {
+                                       return 2;
+                               } else if (in1 < in2) {
+                                       return 1;
+                               } else { // in1 > in2
+                                       return 0;
+                               }
+                               // *** END HACK ***
+                       case LOG://faster in Math
+                               return (Math.log(in1)/Math.log(in2)); 
+                       case LOG_NZ: //faster in Math
+                               return (in1==0) ? 0 : 
(Math.log(in1)/Math.log(in2));
                        default:
-                               // For performance reasons, avoid throwing an 
exception 
-                               return -1;
+                               throw new 
DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
                }
        }
        
        @Override
        public double execute (long in1, long in2) {
                switch(bFunc) {
-               
-               case MAX:    
-               case CUMMAX:   return (in1 >= in2 ? in1 : in2); 
-               
-               case MIN:    
-               case CUMMIN:   return (in1 <= in2 ? in1 : in2); 
-               
-               case MAXINDEX: return (in1 >= in2) ? 1 : 0;
-               case MININDEX: return (in1 <= in2) ? 1 : 0;
-               
-               case LOG:
-                       //faster in Math
-                       return Math.log(in1)/Math.log(in2);
-               case LOG_NZ:
-                       //faster in Math
-                       return (in1==0) ? 0 : Math.log(in1)/Math.log(in2);
-
-               default:
-                       throw new DMLRuntimeException("Builtin.execute(): 
Unknown operation: " + bFunc);
+                       case MAX:
+                       case CUMMAX:   return Math.max(in1, in2);
+                       
+                       case MIN:
+                       case CUMMIN:   return Math.min(in1, in2);
+                       
+                       case MAXINDEX: return (in1 >= in2) ? 1 : 0;
+                       case MININDEX: return (in1 <= in2) ? 1 : 0;
+                       
+                       case LOG:
+                               //faster in Math
+                               return Math.log(in1)/Math.log(in2);
+                       case LOG_NZ:
+                               //faster in Math
+                               return (in1==0) ? 0 : 
Math.log(in1)/Math.log(in2);
+       
+                       default:
+                               throw new 
DMLRuntimeException("Builtin.execute(): Unknown operation: " + bFunc);
                }
        }
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/ca615ca1/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
index 72b90fa..5d66ff8 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
@@ -2388,7 +2388,7 @@ public class LibMatrixAgg
                        //note: we need to determine if there are only nnz in a 
column
                        for( int j=0; j<n; j++ )
                                if( cnt[j] < i+1 ) //no dense column
-                                       cmxx[j] = builtin.execute2(cmxx[j], 0);
+                                       cmxx[j] = builtin.execute(cmxx[j], 0);
                        
                        //always copy current sum (not sparse-safe)
                        System.arraycopy(cmxx, 0, c, ix, n);
@@ -2431,19 +2431,19 @@ public class LibMatrixAgg
                if( a.isContiguous() ) {
                        int alen = (int) a.size(rl, ru);
                        double val = builtin(a.values(rl), a.pos(rl), init, 
alen, builtin);
-                       ret = builtin.execute2(ret, val);
+                       ret = builtin.execute(ret, val);
                        //correction (not sparse-safe)
-                       ret = (alen<(ru-rl)*n) ? builtin.execute2(ret, 0) : ret;
+                       ret = (alen<(ru-rl)*n) ? builtin.execute(ret, 0) : ret;
                }
                else {
                        for( int i=rl; i<ru; i++ ) {
                                if( !a.isEmpty(i) ) {
                                        double lval = builtin(a.values(i), 
a.pos(i), init, a.size(i), builtin);
-                                       ret = builtin.execute2(ret, lval);
+                                       ret = builtin.execute(ret, lval);
                                }               
                                //correction (not sparse-safe)
                                if( a.size(i) < n )
-                                       ret = builtin.execute2(ret, 0); 
+                                       ret = builtin.execute(ret, 0); 
                        }
                }
                c.set(0, 0, ret);
@@ -2469,7 +2469,7 @@ public class LibMatrixAgg
                                c.set(i, 0, builtin(a.values(i), a.pos(i), 
init, a.size(i), builtin));
                        //correction (not sparse-safe)
                        if( a.size(i) < n )
-                               c.set(i, 0, builtin.execute2(c.get(i, 0), 0));
+                               c.set(i, 0, builtin.execute(c.get(i, 0), 0));
                }
        }
        
@@ -2520,7 +2520,7 @@ public class LibMatrixAgg
                // to be replaced with a 0 because there was a missing nonzero. 
                for( int i=0; i<n; i++ )
                        if( cnt[i] < m ) //no dense column
-                               c[i] = builtin.execute2(c[i], 0);
+                               c[i] = builtin.execute(c[i], 0);
        }
 
        /**
@@ -2546,7 +2546,7 @@ public class LibMatrixAgg
                                c.set(i, 0, (double)aix[apos+maxindex] + 1);
                                c.set(i, 1, maxvalue);
                                //correction (not sparse-safe)
-                               if( alen < n && builtin.execute2(0, maxvalue) 
== 1 ) {
+                               if( alen < n && builtin.execute(0, maxvalue) == 
1 ) {
                                        int ix = n-1; //find last 0 value
                                        for( int j=apos+alen-1; j>=apos; j--, 
ix-- )
                                                if( aix[j]!=ix )
@@ -2585,8 +2585,8 @@ public class LibMatrixAgg
                                double minvalue = avals[apos+minindex];
                                c.set(i, 0, (double)aix[apos+minindex] + 1);
                                c.set(i, 1, minvalue); //min value among 
non-zeros
-                               //correction (not sparse-safe)  
-                               if(alen < n && builtin.execute2(0, minvalue) == 
1) {
+                               //correction (not sparse-safe)
+                               if(alen < n && builtin.execute(0, minvalue) == 
1) {
                                        int ix = n-1; //find last 0 value
                                        for( int j=alen-1; j>=0; j--, ix-- )
                                                if( aix[apos+j]!=ix )
@@ -3065,18 +3065,18 @@ public class LibMatrixAgg
        private static double builtin( double[] a, int ai, final double init, 
final int len, Builtin aggop ) {
                double val = init;
                for( int i=0; i<len; i++, ai++ )
-                       val = aggop.execute2( val, a[ ai ] );
+                       val = aggop.execute( val, a[ ai ] );
                return val;
        }
 
        private static void builtinAgg( double[] a, double[] c, int ai, final 
int len, Builtin aggop ) {
                for( int i=0; i<len; i++ )
-                       c[ i ] = aggop.execute2( c[ i ], a[ ai+i ] );
+                       c[ i ] = aggop.execute( c[ i ], a[ ai+i ] );
        }
 
        private static void builtinAgg( double[] a, double[] c, int[] aix, int 
ai, final int len, Builtin aggop ) {
                for( int i=ai; i<ai+len; i++ )
-                       c[ aix[i] ] = aggop.execute2( c[ aix[i] ], a[ i ] );
+                       c[ aix[i] ] = aggop.execute( c[ aix[i] ], a[ i ] );
        }
 
        private static int indexmax( double[] a, int ai, final double init, 
final int len, Builtin aggop ) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/ca615ca1/src/test/java/org/apache/sysml/test/integration/functions/aggregate/AggregateNaNTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/aggregate/AggregateNaNTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/aggregate/AggregateNaNTest.java
new file mode 100644
index 0000000..4499214
--- /dev/null
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/aggregate/AggregateNaNTest.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.aggregate;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.util.DataConverter;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+
+public class AggregateNaNTest extends AutomatedTestBase 
+{
+       private final static String TEST_NAME = "NaNTest";
+       private final static String TEST_DIR = "functions/aggregate/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
AggregateNaNTest.class.getSimpleName() + "/";
+       private final static int rows = 120;
+       private final static int cols = 117;
+       private final static double sparsity1 = 0.1;
+       private final static double sparsity2 = 0.7;
+       
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_NAME,
+                       new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new 
String[]{"B"})); 
+       }
+
+       
+       @Test
+       public void testSumDenseNaN() {
+               runNaNAggregateTest(0, false);
+       }
+       
+       @Test
+       public void testSumSparseNaN() {
+               runNaNAggregateTest(0, true);
+       }
+       
+       @Test
+       public void testSumSqDenseNaN() {
+               runNaNAggregateTest(1, false);
+       }
+       
+       @Test
+       public void testSumSqSparseNaN() {
+               runNaNAggregateTest(1, true);
+       }
+       
+       @Test
+       public void testMinDenseNaN() {
+               runNaNAggregateTest(2, false);
+       }
+       
+       @Test
+       public void testMinSparseNaN() {
+               runNaNAggregateTest(2, true);
+       }
+       
+       @Test
+       public void testMaxDenseNaN() {
+               runNaNAggregateTest(3, false);
+       }
+       
+       @Test
+       public void testMaxSparseNaN() {
+               runNaNAggregateTest(3, true);
+       }
+       
+       private void runNaNAggregateTest(int type, boolean sparse) {
+               //generate input (sparsity1 if sparse, else sparsity2) with one NaN cell at [7][7]
+               double sparsity = sparse ? sparsity1 : sparsity2;
+               double[][] A = getRandomMatrix(rows, cols, -0.05, 1, sparsity, 7);
+               A[7][7] = Double.NaN;
+               MatrixBlock mb = DataConverter.convertToMatrixBlock(A);
+               //select aggregate by type: 0=sum, 1=sumSq, 2=min, 3=max (unknown type keeps -1 and fails)
+               double ret = -1;
+               switch(type) {
+                       case 0: ret = mb.sum(); break; //break is required: without it every
+                       case 1: ret = mb.sumSq(); break; //case fell through to max() and the
+                       case 2: ret = mb.min(); break; //sum/sumSq/min tests only checked max
+                       case 3: ret = mb.max(); break;
+               }
+               
+               Assert.assertTrue(Double.isNaN(ret));
+       }
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/ca615ca1/src/test_suites/java/org/apache/sysml/test/integration/functions/aggregate/ZPackageSuite.java
----------------------------------------------------------------------
diff --git 
a/src/test_suites/java/org/apache/sysml/test/integration/functions/aggregate/ZPackageSuite.java
 
b/src/test_suites/java/org/apache/sysml/test/integration/functions/aggregate/ZPackageSuite.java
index 4aedf4d..00629fb 100644
--- 
a/src/test_suites/java/org/apache/sysml/test/integration/functions/aggregate/ZPackageSuite.java
+++ 
b/src/test_suites/java/org/apache/sysml/test/integration/functions/aggregate/ZPackageSuite.java
@@ -27,6 +27,7 @@ import org.junit.runners.Suite;
 @RunWith(Suite.class)
 @Suite.SuiteClasses({
        AggregateInfTest.class,
+       AggregateNaNTest.class,
        ColStdDevsTest.class,
        ColSumsSqTest.class,
        ColSumTest.class,

Reply via email to