This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 1e86da37df [SYSTEMDS-3806] Robustness simplifyDotProductSum rewrite
1e86da37df is described below
commit 1e86da37dfed87b94a1484c5a41612ad0787270b
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Dec 13 07:41:27 2024 +0100
[SYSTEMDS-3806] Robustness simplifyDotProductSum rewrite
This patch fixes an incorrect application of the
simplifyDotProductSum rewrite. Specifically, sum(s*V) was rewritten to
t(s) %*% V because s was assumed to be a vector of the same size as V,
but it was actually a scalar. The root cause was incorrect size
propagation for the new scalar right indexing; for robustness, the
rewrite now also checks that both inputs are actually matrices.
---
src/main/java/org/apache/sysds/hops/IndexingOp.java | 8 ++++++++
.../sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java | 1 +
src/test/scripts/functions/unary/matrix/eigen.dml | 2 --
3 files changed, 9 insertions(+), 2 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/IndexingOp.java
b/src/main/java/org/apache/sysds/hops/IndexingOp.java
index 1756724e74..457c5b44e0 100644
--- a/src/main/java/org/apache/sysds/hops/IndexingOp.java
+++ b/src/main/java/org/apache/sysds/hops/IndexingOp.java
@@ -370,6 +370,14 @@ public class IndexingOp extends Hop
@Override
public void refreshSizeInformation()
{
+ // early abort for scalar right indexing
+ // (important to prevent incorrect dynamic rewrites)
+ if( isScalar() ) {
+ setDim1(0);
+ setDim2(0);
+ return;
+ }
+
Hop input1 = getInput().get(0); //matrix
Hop input2 = getInput().get(1); //inpRowL
Hop input3 = getInput().get(2); //inpRowU
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 787b87716b..c9a9745091 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -2312,6 +2312,7 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//check for sum(v1*v2), but prevent to rewrite
sum(v1*v2*v3) which is later compiled into a ta+* lop
else if( HopRewriteUtils.isBinary(hi2, OpOp2.MULT, 1)
//no other consumer than sum
&& hi2.getInput().get(0).getDim2()==1
&& hi2.getInput().get(1).getDim2()==1
+ && hi2.getInput().get(0).isMatrix() &&
hi2.getInput().get(1).isMatrix()
&&
!HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.MULT)
&&
!HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.MULT)
&& ( !ALLOW_SUM_PRODUCT_REWRITES
diff --git a/src/test/scripts/functions/unary/matrix/eigen.dml
b/src/test/scripts/functions/unary/matrix/eigen.dml
index f863a0a8e6..a2a262f4cb 100644
--- a/src/test/scripts/functions/unary/matrix/eigen.dml
+++ b/src/test/scripts/functions/unary/matrix/eigen.dml
@@ -33,9 +33,7 @@ numEval = $2;
D = matrix(1, numEval, 1);
for ( i in 1:numEval ) {
Av = A %*% evec[,i];
- while(FALSE){} #fix incorrect rewrite sequence
rhs = as.scalar(eval[i,1]) * evec[,i];
- while(FALSE){} #fix incorrect rewrite sequence
diff = sum(Av-rhs);
D[i,1] = diff;
}