This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 1e86da37df [SYSTEMDS-3806] Robustness simplifyDotProductSum rewrite
1e86da37df is described below

commit 1e86da37dfed87b94a1484c5a41612ad0787270b
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Dec 13 07:41:27 2024 +0100

    [SYSTEMDS-3806] Robustness simplifyDotProductSum rewrite
    
    This patch fixes an incorrect application of the
    simplifyDotProductSum rewrite. Specifically, sum(s*V) was rewritten to
    t(s) %*% V because s was assumed to be a vector of the same size as V,
    but it was actually a scalar. The root cause was incorrect size
    propagation for the new scalar right indexing, which is now fixed by
    setting the dimensions of scalar indexing results to zero. For
    robustness, the rewrite now also checks that both inputs are actually
    matrices.
---
 src/main/java/org/apache/sysds/hops/IndexingOp.java               | 8 ++++++++
 .../sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java | 1 +
 src/test/scripts/functions/unary/matrix/eigen.dml                 | 2 --
 3 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/IndexingOp.java 
b/src/main/java/org/apache/sysds/hops/IndexingOp.java
index 1756724e74..457c5b44e0 100644
--- a/src/main/java/org/apache/sysds/hops/IndexingOp.java
+++ b/src/main/java/org/apache/sysds/hops/IndexingOp.java
@@ -370,6 +370,14 @@ public class IndexingOp extends Hop
        @Override
        public void refreshSizeInformation()
        {
+               // early abort for scalar right indexing
+               // (important to prevent incorrect dynamic rewrites)
+               if( isScalar() ) {
+                       setDim1(0); 
+                       setDim2(0);
+                       return;
+               }
+               
                Hop input1 = getInput().get(0); //matrix
                Hop input2 = getInput().get(1); //inpRowL
                Hop input3 = getInput().get(2); //inpRowU
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 787b87716b..c9a9745091 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -2312,6 +2312,7 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                        //check for sum(v1*v2), but prevent to rewrite 
sum(v1*v2*v3) which is later compiled into a ta+* lop
                        else if( HopRewriteUtils.isBinary(hi2, OpOp2.MULT, 1) 
//no other consumer than sum
                                        && hi2.getInput().get(0).getDim2()==1 
&& hi2.getInput().get(1).getDim2()==1
+                                       && hi2.getInput().get(0).isMatrix() && 
hi2.getInput().get(1).isMatrix()
                                        && 
!HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.MULT)
                                        && 
!HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.MULT)
                                        && ( !ALLOW_SUM_PRODUCT_REWRITES
diff --git a/src/test/scripts/functions/unary/matrix/eigen.dml 
b/src/test/scripts/functions/unary/matrix/eigen.dml
index f863a0a8e6..a2a262f4cb 100644
--- a/src/test/scripts/functions/unary/matrix/eigen.dml
+++ b/src/test/scripts/functions/unary/matrix/eigen.dml
@@ -33,9 +33,7 @@ numEval = $2;
 D = matrix(1, numEval, 1);
 for ( i in 1:numEval ) {
     Av = A %*% evec[,i];
-    while(FALSE){} #fix incorrect rewrite sequence
     rhs = as.scalar(eval[i,1]) * evec[,i];
-    while(FALSE){} #fix incorrect rewrite sequence
     diff = sum(Av-rhs);
     D[i,1] = diff;
 }

Reply via email to