This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 9b940f7051 [SYSTEMDS-3797] Fix rewrite for trace on reorg operations
9b940f7051 is described below
commit 9b940f7051ad1b8b216be130cd785ae3165da0b3
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Nov 28 10:39:56 2024 +0100
[SYSTEMDS-3797] Fix rewrite for trace on reorg operations
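This patch fixes the rewrite that removes unnecessary reorg operations
below full aggregations (e.g., sum(t(X)) -> sum(X), sum(rev(X)) -> sum(X)),
which was incorrectly also applied to trace aggregations. Since trace
only consumes a subset of values (the diagonal), a reorg below it cannot
be removed in general (e.g., trace(rev(X)) != trace(X)). Furthermore, we
generalize this rewrite to eliminate all reorg operations that are
guaranteed to preserve all values (transpose/reshape/rev/roll). Diag and
sort are conservatively excluded entirely: diagM2V drops values, but
DIAG_V2M and DIAG_M2V cannot be distinguished if sizes are unknown, and
sort may return indexes instead of values.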
Thanks to Jannik Lindemann for catching this issue.
---
src/main/java/org/apache/sysds/common/Types.java | 4 ++++
.../hops/rewrite/RewriteAlgebraicSimplificationStatic.java | 12 +++++-------
.../rewrite/RewriteSimplifyTraceMatrixMultTest.java | 12 +++---------
.../functions/rewrite/RewriteSimplifyTraceMatrixMult.R | 5 +++++
.../functions/rewrite/RewriteSimplifyTraceMatrixMult.dml | 2 ++
5 files changed, 19 insertions(+), 16 deletions(-)
diff --git a/src/main/java/org/apache/sysds/common/Types.java
b/src/main/java/org/apache/sysds/common/Types.java
index e7274b25c4..ba264dea7f 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -751,6 +751,10 @@ public interface Types {
DIAG, //DIAG_V2M and DIAG_M2V could not be distinguished if
sizes unknown
RESHAPE, REV, ROLL, SORT, TRANS;
+ public boolean preservesValues() {
+ return this != DIAG && this != SORT;
+ }
+
@Override
public String toString() {
switch(this) {
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 056770dceb..8053ddc78a 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -980,23 +980,21 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
private static Hop simplifyUnaryAggReorgOperation( Hop parent, Hop hi,
int pos )
{
- if( hi instanceof AggUnaryOp &&
((AggUnaryOp)hi).getDirection()==Direction.RowCol //full uagg
- && hi.getInput().get(0) instanceof ReorgOp )
//reorg operation
+ if( hi instanceof AggUnaryOp &&
((AggUnaryOp)hi).getDirection()==Direction.RowCol
+ && ((AggUnaryOp)hi).getOp() != AggOp.TRACE //full
uagg
+ && hi.getInput().get(0) instanceof ReorgOp ) //reorg
operation
{
ReorgOp rop = (ReorgOp)hi.getInput().get(0);
- if( (rop.getOp()==ReOrgOp.TRANS ||
rop.getOp()==ReOrgOp.RESHAPE
- || rop.getOp() == ReOrgOp.REV )
//valid reorg
- && rop.getParent().size()==1 )
//uagg only reorg consumer
+ if( rop.getOp().preservesValues() //valid reorg
+ && rop.getParent().size()==1 ) //uagg only
reorg consumer
{
Hop input = rop.getInput().get(0);
HopRewriteUtils.removeAllChildReferences(hi);
HopRewriteUtils.removeAllChildReferences(rop);
HopRewriteUtils.addChildReference(hi, input);
-
LOG.debug("Applied
simplifyUnaryAggReorgOperation");
}
}
-
return hi;
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceMatrixMultTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceMatrixMultTest.java
index 4a81609d3f..c2ae90eec7 100644
---
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceMatrixMultTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyTraceMatrixMultTest.java
@@ -85,18 +85,12 @@ public class RewriteSimplifyTraceMatrixMultTest extends
AutomatedTestBase {
TestUtils.compareMatrices(dmlfile, rfile, eps,
"Stat-DML", "Stat-R");
//check trace operator existence
- String uaktrace = "uaktrace";
- long numTrace =
Statistics.getCPHeavyHitterCount(uaktrace);
-
- if(rewrites)
- Assert.assertTrue(numTrace == 0);
- else
- Assert.assertTrue(numTrace == 1);
-
+ long numTrace =
Statistics.getCPHeavyHitterCount("uaktrace");
+ Assert.assertTrue(numTrace == (rewrites ? 1 : 2));
+ Assert.assertTrue(heavyHittersContainsString("rev"));
}
finally {
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
}
-
}
}
diff --git
a/src/test/scripts/functions/rewrite/RewriteSimplifyTraceMatrixMult.R
b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceMatrixMult.R
index 2153b2dafd..3bb323986d 100644
--- a/src/test/scripts/functions/rewrite/RewriteSimplifyTraceMatrixMult.R
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceMatrixMult.R
@@ -36,6 +36,11 @@ B = as.matrix(readMM(paste(args[1], "B.mtx", sep="")))
# Perform the matrix operation
R = sum(diag(A %*% B))
+rA = A;
+for(i in 1:nrow(rA)) {
+ rA[,i] = rev(rA[,i])
+}
+R = R + sum(diag(rA))
# Write the result scalar R
write(R, paste(args[2], "R" ,sep=""))
diff --git
a/src/test/scripts/functions/rewrite/RewriteSimplifyTraceMatrixMult.dml
b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceMatrixMult.dml
index 315af97843..7189a7f4e6 100644
--- a/src/test/scripts/functions/rewrite/RewriteSimplifyTraceMatrixMult.dml
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyTraceMatrixMult.dml
@@ -26,6 +26,8 @@ B = read($2)
# Perform the operation
R = trace(A %*% B)
+R = R + trace(rev(A))
# Write the result R
write(R, $3)
+