This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new e8b65fe [MINOR] Performance spark NaN checks and ultra-sparse
aggregation
e8b65fe is described below
commit e8b65fe6e97e8d52dc1b0150592090b281dca507
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Feb 4 14:28:07 2021 +0100
[MINOR] Performance spark NaN checks and ultra-sparse aggregation
This patch adds two early-abort conditions: one for the unary isNaN and
isNA operations, and one for unary aggregate operations (Spark RDD
aggregation). Both avoid the unnecessary temporary allocation of dense
outputs and correction blocks.
On the ultra-sparse Criteo dataset in one-hot encoded representation,
this patch reduces the runtime of the isNaN check from over 1 h to 2.6 min.
---
.../org/apache/sysds/runtime/functionobjects/Builtin.java | 7 +++++--
.../runtime/instructions/spark/utils/RDDAggregateUtils.java | 4 ++++
.../org/apache/sysds/runtime/matrix/data/MatrixBlock.java | 5 +++++
.../apache/sysds/runtime/matrix/operators/UnaryOperator.java | 11 +++++++++++
4 files changed, 25 insertions(+), 2 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
index 33aeae0..46acba4 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
@@ -118,8 +118,11 @@ public class Builtin extends ValueFunction
return bFunc;
}
- public static boolean isBuiltinCode(ValueFunction fn, BuiltinCode code)
{
- return (fn instanceof Builtin && ((Builtin)fn).getBuiltinCode()
== code);
+	/**
+	 * Checks whether the given value function is a {@link Builtin} whose
+	 * builtin code equals any of the given candidate codes.
+	 *
+	 * @param fn    value function to test (non-Builtin objects yield false)
+	 * @param codes candidate builtin codes
+	 * @return true if fn is a Builtin with one of the given codes
+	 */
+	public static boolean isBuiltinCode(ValueFunction fn, BuiltinCode... codes) {
+		//type check and code lookup are loop-invariant, do them once
+		if( !(fn instanceof Builtin) )
+			return false;
+		BuiltinCode fcode = ((Builtin)fn).getBuiltinCode();
+		for( BuiltinCode code : codes )
+			if( fcode == code )
+				return true;
+		return false;
}
public static boolean isBuiltinFnObject(String str) {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/RDDAggregateUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/RDDAggregateUtils.java
index d94fcce..3df024d 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/RDDAggregateUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/RDDAggregateUtils.java
@@ -631,6 +631,10 @@ public class RDDAggregateUtils
return arg0;
}
+ //early-abort (without dense correction allocation)
+ if( _op.sparseSafe && (arg0.isEmpty() | arg1.isEmpty())
)
+ return arg1.isEmpty() ? arg0 : arg1;
+
//create correction block (on demand)
if( _op.existsCorrection() && _corr == null ) {
_corr = new MatrixBlock(arg0.getNumRows(),
arg0.getNumColumns(), false);
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 42e65da..aa0bb45 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -2677,6 +2677,11 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock, Externalizab
else
ret.reset(rlen, n, sp);
+ //early abort for comparisons w/ special values
+ if( Builtin.isBuiltinCode(op.fn, BuiltinCode.ISNAN,
BuiltinCode.ISNA))
+ if( !containsValue(op.getPattern()) )
+ return ret; //avoid unnecessary allocation
+
//core execute
if( LibMatrixAgg.isSupportedUnaryOperator(op) ) {
//e.g., cumsum/cumprod/cummin/cumax/cumsumprod
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/operators/UnaryOperator.java
b/src/main/java/org/apache/sysds/runtime/matrix/operators/UnaryOperator.java
index b82e35a..69deeef 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/UnaryOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/UnaryOperator.java
@@ -20,6 +20,7 @@
package org.apache.sysds.runtime.matrix.operators;
+import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
@@ -55,4 +56,14 @@ public class UnaryOperator extends Operator
public boolean isInplace() {
return inplace;
}
+
+	/**
+	 * Returns the special value that a pattern-checking builtin (ISNAN, ISNA)
+	 * tests for. Callers such as the isNaN/isNA early abort can skip processing
+	 * when the input does not contain this value.
+	 *
+	 * @return the pattern value (NaN for ISNAN and ISNA)
+	 * @throws DMLRuntimeException if the function object is not a Builtin
+	 *         or the builtin has no associated pattern
+	 */
+	public double getPattern() {
+		//guard the cast: fn is a general ValueFunction, not necessarily a Builtin
+		if( !(fn instanceof Builtin) )
+			throw new DMLRuntimeException(
+				"No pattern existing for non-builtin function: "+fn);
+		switch( ((Builtin)fn).bFunc ) {
+			case ISNAN:
+			case ISNA: return Double.NaN;
+			default:
+				throw new DMLRuntimeException(
+					"No pattern existing for "+((Builtin)fn).bFunc.name());
+		}
+	}
}