This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new e705f893f7 [MINOR] Code cleanups in rewrites and tests
e705f893f7 is described below
commit e705f893f719632ef4afd990a908f2c51fbe0a3d
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Dec 13 08:50:54 2024 +0100
[MINOR] Code cleanups in rewrites and tests
---
.../RewriteAlgebraicSimplificationDynamic.java | 68 +++++++++++-----------
...iteElementwiseMultChainOptimizationAllTest.java | 15 +----
...ewriteElementwiseMultChainOptimizationTest.java | 15 +----
3 files changed, 38 insertions(+), 60 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index c9a9745091..15207e87b5 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -243,7 +243,7 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
{
if( HopRewriteUtils.isUnnecessaryRightIndexing(hi) &&
!hi.isScalar() ) {
//remove unnecessary right indexing
- Hop input = hi.getInput().get(0);
+ Hop input = hi.getInput(0);
HopRewriteUtils.replaceChildReference(parent, hi,
input, pos);
HopRewriteUtils.cleanupUnreferenced(hi);
hi = input;
@@ -258,8 +258,8 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
{
if( hi instanceof LeftIndexingOp && hi.getDataType() ==
DataType.MATRIX ) //left indexing op
{
- Hop input1 = hi.getInput().get(0); //lhs matrix
- Hop input2 = hi.getInput().get(1); //rhs matrix
+ Hop input1 = hi.getInput(0); //lhs matrix
+ Hop input2 = hi.getInput(1); //rhs matrix
if( input1.getNnz()==0 //nnz original known and empty
&& input2.getNnz()==0 ) //nnz input known and empty
@@ -271,7 +271,7 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
hi = hnew;
LOG.debug("Applied removeEmptyLeftIndexing");
- }
+ }
}
return hi;
@@ -281,19 +281,19 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
{
if( hi instanceof LeftIndexingOp ) //left indexing op
{
- Hop input = hi.getInput().get(1); //rhs matrix/frame
+ Hop input = hi.getInput(1); //rhs matrix/frame
if( HopRewriteUtils.isEqualSize(hi, input) ) //equal
dims
{
//equal dims of left indexing input and output
-> no need for indexing
- //remove unnecessary right indexing
+ //remove unnecessary right indexing
HopRewriteUtils.replaceChildReference(parent,
hi, input, pos);
HopRewriteUtils.cleanupUnreferenced(hi);
hi = input;
LOG.debug("Applied
removeUnnecessaryLeftIndexing");
- }
+ }
}
return hi;
@@ -306,15 +306,15 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//pattern1: X[,1]=A; X[,2]=B -> X=cbind(A,B); matrix / frame
if( hi instanceof LeftIndexingOp //first
lix
&&
HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp)hi)
- && hi.getInput().get(0) instanceof LeftIndexingOp
//second lix
+ && hi.getInput(0) instanceof LeftIndexingOp //second
lix
&&
HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp)hi.getInput().get(0))
- && hi.getInput().get(0).getParent().size()==1
//first lix is single consumer
- && hi.getInput().get(0).getInput().get(0).getDim2() ==
2 ) //two column matrix
+ && hi.getInput(0).getParent().size()==1 //first lix
is single consumer
+ && hi.getInput(0).getInput(0).getDim2() == 2 ) //two
column matrix
{
- Hop input2 = hi.getInput().get(1); //rhs matrix
- Hop pred2 = hi.getInput().get(4); //cl=cu
- Hop input1 = hi.getInput().get(0).getInput().get(1);
//lhs matrix
- Hop pred1 = hi.getInput().get(0).getInput().get(4);
//cl=cu
+ Hop input2 = hi.getInput(1); //rhs matrix
+ Hop pred2 = hi.getInput(4); //cl=cu
+ Hop input1 = hi.getInput(0).getInput(1); //lhs matrix
+ Hop pred1 = hi.getInput(0).getInput(4); //cl=cu
if( pred1 instanceof LiteralOp &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred1)==1
&& pred2 instanceof LiteralOp &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred2)==2
@@ -332,15 +332,15 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//pattern1: X[1,]=A; X[2,]=B -> X=rbind(A,B)
if( !applied && hi instanceof LeftIndexingOp //first
lix
&& HopRewriteUtils.isFullRowIndexing((LeftIndexingOp)hi)
- && hi.getInput().get(0) instanceof LeftIndexingOp
//second lix
+ && hi.getInput(0) instanceof LeftIndexingOp //second
lix
&&
HopRewriteUtils.isFullRowIndexing((LeftIndexingOp)hi.getInput().get(0))
- && hi.getInput().get(0).getParent().size()==1
//first lix is single consumer
- && hi.getInput().get(0).getInput().get(0).getDim1() ==
2 ) //two column matrix
+ && hi.getInput(0).getParent().size()==1 //first lix
is single consumer
+ && hi.getInput(0).getInput(0).getDim1() == 2 ) //two
column matrix
{
- Hop input2 = hi.getInput().get(1); //rhs matrix
- Hop pred2 = hi.getInput().get(2); //rl=ru
- Hop input1 = hi.getInput().get(0).getInput().get(1);
//lhs matrix
- Hop pred1 = hi.getInput().get(0).getInput().get(2);
//rl=ru
+ Hop input2 = hi.getInput(1); //rhs matrix
+ Hop pred2 = hi.getInput(2); //rl=ru
+ Hop input1 = hi.getInput(0).getInput(1); //lhs matrix
+ Hop pred1 = hi.getInput(0).getInput(2); //rl=ru
if( pred1 instanceof LiteralOp &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred1)==1
&& pred2 instanceof LiteralOp &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred2)==2
@@ -364,19 +364,19 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
{
if( hi instanceof UnaryOp &&
((UnaryOp)hi).isCumulativeUnaryOperation() )
{
- Hop input = hi.getInput().get(0); //input matrix
+ Hop input = hi.getInput(0); //input matrix
if( HopRewriteUtils.isDimsKnown(input) //dims input
known
&& input.getDim1()==1 ) //1 row
{
OpOp1 op = ((UnaryOp)hi).getOp();
- //remove unnecessary unary cumsum operator
+ //remove unnecessary unary cumsum operator
HopRewriteUtils.replaceChildReference(parent,
hi, input, pos);
hi = input;
LOG.debug("Applied
removeUnnecessaryCumulativeOp: "+op);
- }
+ }
}
return hi;
@@ -413,27 +413,27 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
if( hi instanceof BinaryOp ) //binary cell operation
{
OpOp2 bop = ((BinaryOp)hi).getOp();
- Hop left = hi.getInput().get(0);
- Hop right = hi.getInput().get(1);
+ Hop left = hi.getInput(0);
+ Hop right = hi.getInput(1);
//check for matrix-vector column replication: (A + b
%*% ones) -> (A + b)
if( HopRewriteUtils.isMatrixMultiply(right) //matrix
mult with datagen
&&
HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput().get(1), 1)
- && right.getInput().get(0).getDim2() == 1 )
//column vector for mv binary
+ && right.getInput(0).getDim2() == 1 ) //column
vector for mv binary
{
//remove unnecessary outer product
- HopRewriteUtils.replaceChildReference(hi,
right, right.getInput().get(0), 1 );
+ HopRewriteUtils.replaceChildReference(hi,
right, right.getInput(0), 1 );
HopRewriteUtils.cleanupUnreferenced(right);
LOG.debug("Applied
removeUnnecessaryOuterProduct1 (line "+right.getBeginLine()+")");
}
//check for matrix-vector row replication: (A + ones
%*% b) -> (A + b)
else if( HopRewriteUtils.isMatrixMultiply(right)
//matrix mult with datagen
- &&
HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput().get(0), 1)
- && right.getInput().get(1).getDim1() == 1 )
//row vector for mv binary
+ &&
HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput(0), 1)
+ && right.getInput(1).getDim1() == 1 ) //row
vector for mv binary
{
//remove unnecessary outer product
- HopRewriteUtils.replaceChildReference(hi,
right, right.getInput().get(1), 1 );
+ HopRewriteUtils.replaceChildReference(hi,
right, right.getInput(1), 1 );
HopRewriteUtils.cleanupUnreferenced(right);
LOG.debug("Applied
removeUnnecessaryOuterProduct2 (line "+right.getBeginLine()+")");
@@ -442,11 +442,11 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
else if(HopRewriteUtils.isValidOuterBinaryOp(bop)
&& HopRewriteUtils.isMatrixMultiply(left)
&&
HopRewriteUtils.isDataGenOpWithConstantValue(left.getInput().get(1), 1)
- && (left.getInput().get(0).getDim2() == 1
//outer product
- || left.getInput().get(1).getDim1() ==
1)
+ && (left.getInput(0).getDim2() == 1 //outer
product
+ || left.getInput(1).getDim1() == 1)
&& left.getDim1() != 1 && right.getDim1() == 1
) //outer vector binary
{
- Hop hnew =
HopRewriteUtils.createBinary(left.getInput().get(0), right, bop, true);
+ Hop hnew =
HopRewriteUtils.createBinary(left.getInput(0), right, bop, true);
HopRewriteUtils.replaceChildReference(parent,
hi, hnew, pos);
HopRewriteUtils.cleanupUnreferenced(hi);
diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationAllTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationAllTest.java
index 78728d9a71..15b24534c1 100644
--- a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationAllTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationAllTest.java
@@ -23,7 +23,6 @@ import java.util.HashMap;
import org.junit.Assert;
import org.junit.Test;
-import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.common.Types.ExecType;
@@ -74,16 +73,7 @@ public class RewriteElementwiseMultChainOptimizationAllTest
extends AutomatedTes
private void testRewriteMatrixMultChainOp(String testname, boolean
rewrites, ExecType et)
{
- ExecMode platformOld = rtplatform;
- switch( et ){
- case SPARK: rtplatform = ExecMode.SPARK; break;
- default: rtplatform = ExecMode.HYBRID; break;
- }
-
- boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
- if( rtplatform == ExecMode.SPARK || rtplatform ==
ExecMode.HYBRID )
- DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-
+ ExecMode platformOld = setExecMode(et);
boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
@@ -126,8 +116,7 @@ public class RewriteElementwiseMultChainOptimizationAllTest
extends AutomatedTes
}
finally {
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
- rtplatform = platformOld;
- DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ resetExecMode(platformOld);
}
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationTest.java
index d60df3f665..6c6ede61d7 100644
--- a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteElementwiseMultChainOptimizationTest.java
@@ -23,7 +23,6 @@ import java.util.HashMap;
import org.junit.Assert;
import org.junit.Test;
-import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.common.Types.ExecType;
@@ -73,16 +72,7 @@ public class RewriteElementwiseMultChainOptimizationTest
extends AutomatedTestBa
private void testRewriteMatrixMultChainOp(String testname, boolean
rewrites, ExecType et)
{
- ExecMode platformOld = rtplatform;
- switch( et ){
- case SPARK: rtplatform = ExecMode.SPARK; break;
- default: rtplatform = ExecMode.HYBRID; break;
- }
-
- boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
- if( rtplatform == ExecMode.SPARK || rtplatform ==
ExecMode.HYBRID )
- DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-
+ ExecMode platformOld = setExecMode(et);
boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
@@ -119,8 +109,7 @@ public class RewriteElementwiseMultChainOptimizationTest
extends AutomatedTestBa
}
finally {
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
- rtplatform = platformOld;
- DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ resetExecMode(platformOld);
}
}
}