This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 0de6fab40f [MINOR] Code cleanups in rewrites and tests, part II
0de6fab40f is described below
commit 0de6fab40fcca5a53eff15ebe9014ae6bfe7facf
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Dec 13 10:02:08 2024 +0100
[MINOR] Code cleanups in rewrites and tests, part II
---
.github/workflows/javaTests.yml | 10 +-
.github/workflows/python.yml | 8 +-
.../RewriteAlgebraicSimplificationDynamic.java | 768 ++++++++++-----------
.../RewriteAlgebraicSimplificationStatic.java | 462 ++++++-------
.../rewrite/RewriteFuseBinaryOpChainTest.java | 17 +-
.../RewriteHoistingLoopInvariantOpsTest.java | 17 +-
.../test/functions/rewrite/RewriteIfElseTest.java | 16 +-
7 files changed, 635 insertions(+), 663 deletions(-)
diff --git a/.github/workflows/javaTests.yml b/.github/workflows/javaTests.yml
index f58b4975c2..c2cab87c22 100644
--- a/.github/workflows/javaTests.yml
+++ b/.github/workflows/javaTests.yml
@@ -116,12 +116,16 @@ jobs:
determine_test_coverage:
name: Determine Test Coverage
- needs: [java_tests]
+ runs-on: ${{ matrix.os }}
+ needs: [
+ java_tests
+ ]
strategy:
+ fail-fast: false
matrix:
os: [ubuntu-latest]
- java: [11]
-
+ java: ['11']
+ javadist: ['adopt']
steps:
- name: Checkout Repository
uses: actions/checkout@v4
diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index d448645d94..9f39f07ecb 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -64,7 +64,7 @@ jobs:
distribution: ${{ matrix.javadist }}
java-version: ${{ matrix.java }}
cache: 'maven'
-
+
- name: Cache Pip Dependencies
uses: actions/cache@v4
with:
@@ -93,11 +93,11 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
architecture: 'x64'
-
+
- name: Install pip Dependencies
run: |
# Install pip twice to update past the versions.
- pip install --upgrade pip
+ pip install --upgrade pip
pip install --upgrade pip
pip install wheel
pip install \
@@ -133,7 +133,7 @@ jobs:
unittest-parallel -t . -s tests -v
# python -m unittest discover -s tests -p 'test_*.py'
echo "Exit Status: " $?
-
+
- name: Run all python tests no environment
run: |
export LOG4JPROP=$(pwd)/src/test/resources/log4j.properties
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 15207e87b5..d73f8489b6 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -307,7 +307,7 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
if( hi instanceof LeftIndexingOp //first
lix
&&
HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp)hi)
&& hi.getInput(0) instanceof LeftIndexingOp //second
lix
- &&
HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp)hi.getInput().get(0))
+ &&
HopRewriteUtils.isFullColumnIndexing((LeftIndexingOp)hi.getInput(0))
&& hi.getInput(0).getParent().size()==1 //first lix
is single consumer
&& hi.getInput(0).getInput(0).getDim2() == 2 ) //two
column matrix
{
@@ -333,7 +333,7 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
if( !applied && hi instanceof LeftIndexingOp //first
lix
&& HopRewriteUtils.isFullRowIndexing((LeftIndexingOp)hi)
&& hi.getInput(0) instanceof LeftIndexingOp //second
lix
- &&
HopRewriteUtils.isFullRowIndexing((LeftIndexingOp)hi.getInput().get(0))
+ &&
HopRewriteUtils.isFullRowIndexing((LeftIndexingOp)hi.getInput(0))
&& hi.getInput(0).getParent().size()==1 //first lix
is single consumer
&& hi.getInput(0).getInput(0).getDim1() == 2 ) //two
column matrix
{
@@ -387,7 +387,7 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
if( hi instanceof ReorgOp )
{
ReorgOp rop = (ReorgOp) hi;
- Hop input = hi.getInput().get(0);
+ Hop input = hi.getInput(0);
boolean apply = false;
//equal dims of reshape input and output -> no need for
reshape because
@@ -418,7 +418,7 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
//check for matrix-vector column replication: (A + b
%*% ones) -> (A + b)
if( HopRewriteUtils.isMatrixMultiply(right) //matrix
mult with datagen
- &&
HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput().get(1), 1)
+ &&
HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput(1), 1)
&& right.getInput(0).getDim2() == 1 ) //column
vector for mv binary
{
//remove unnecessary outer product
@@ -441,7 +441,7 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
//check for vector-vector column replication: (a %*%
ones) == b) -> outer(a, b, "==")
else if(HopRewriteUtils.isValidOuterBinaryOp(bop)
&& HopRewriteUtils.isMatrixMultiply(left)
- &&
HopRewriteUtils.isDataGenOpWithConstantValue(left.getInput().get(1), 1)
+ &&
HopRewriteUtils.isDataGenOpWithConstantValue(left.getInput(1), 1)
&& (left.getInput(0).getDim2() == 1 //outer
product
|| left.getInput(1).getDim1() == 1)
&& left.getDim1() != 1 && right.getDim1() == 1
) //outer vector binary
@@ -463,9 +463,9 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
if( !HopRewriteUtils.isTernary(hi, OpOp3.IFELSE) )
return hi;
- Hop expr = hi.getInput().get(0);
- Hop first = hi.getInput().get(1);
- Hop second = hi.getInput().get(2);
+ Hop expr = hi.getInput(0);
+ Hop first = hi.getInput(1);
+ Hop second = hi.getInput(2);
boolean applied = false;
//pattern 1: ifelse(TRUE/FALSE, A, B) -> A/B (constant scalar
predicate)
@@ -506,27 +506,27 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//pattern 1: X = t(rbind(A,B,C)) %*% rbind(A,B,C) -> t(A)%*%A +
t(B)%*%B + t(C)%*%C
int branch = -1;
if( HopRewriteUtils.isTsmm(hi)
- &&
HopRewriteUtils.isTransposeOperation(hi.getInput().get(0))
- && HopRewriteUtils.isNary(hi.getInput().get(1),
OpOpN.RBIND) )
+ && HopRewriteUtils.isTransposeOperation(hi.getInput(0))
+ && HopRewriteUtils.isNary(hi.getInput(1), OpOpN.RBIND) )
{
- List<Hop> inputs = hi.getInput().get(1).getInput();
+ List<Hop> inputs = hi.getInput(1).getInput();
if( HopRewriteUtils.checkAvgRowsGteCols(inputs) ) {
Hop[] tsmms = inputs.stream()
.map(h -> HopRewriteUtils.createTsmm(h,
true)).toArray(Hop[]::new);
hnew = HopRewriteUtils.createNary(OpOpN.PLUS,
tsmms);
//cleanup parent references from rbind
-
//HopRewriteUtils.removeAllChildReferences(hi.getInput().get(1));
+
//HopRewriteUtils.removeAllChildReferences(hi.getInput(1));
branch = 1;
}
}
//pattern 2: X = t(rbind(A,B,C)) %*% rbind(D,E,F) -> t(A)%*%D
+ t(B)%*%E + t(C)%*%F
else if( HopRewriteUtils.isMatrixMultiply(hi)
- &&
HopRewriteUtils.isTransposeOperation(hi.getInput().get(0))
- &&
HopRewriteUtils.isNary(hi.getInput().get(0).getInput().get(0), OpOpN.RBIND)
- && HopRewriteUtils.isNary(hi.getInput().get(1),
OpOpN.RBIND) )
+ && HopRewriteUtils.isTransposeOperation(hi.getInput(0))
+ && HopRewriteUtils.isNary(hi.getInput(0).getInput(0),
OpOpN.RBIND)
+ && HopRewriteUtils.isNary(hi.getInput(1), OpOpN.RBIND) )
{
- List<Hop> inputs1 =
hi.getInput().get(0).getInput().get(0).getInput();
- List<Hop> inputs2 = hi.getInput().get(1).getInput();
+ List<Hop> inputs1 =
hi.getInput(0).getInput(0).getInput();
+ List<Hop> inputs2 = hi.getInput(1).getInput();
if( HopRewriteUtils.checkAvgRowsGteCols(inputs1)
&& HopRewriteUtils.checkAvgRowsGteCols(inputs2)
&& HopRewriteUtils.checkConsistentRows(inputs1,
inputs2) )
@@ -537,18 +537,18 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
HopRewriteUtils.createTranspose(inputs1.get(i)), inputs2.get(i));
hnew = HopRewriteUtils.createNary(OpOpN.PLUS,
mms);
//cleanup parent references from rbind
left/right
-
//HopRewriteUtils.removeAllChildReferences(hi.getInput().get(0).getInput().get(0));
-
//HopRewriteUtils.removeAllChildReferences(hi.getInput().get(1));
+
//HopRewriteUtils.removeAllChildReferences(hi.getInput(0).getInput(0));
+
//HopRewriteUtils.removeAllChildReferences(hi.getInput(1));
branch = 2;
}
}
//pattern 3: X = t(cbind(A, B)) %*% cbind(A, B), w/ one cbind
consumer (twice in tsmm)
- else if( HopRewriteUtils.isTsmm(hi) &&
hi.getInput().get(1).getParent().size()==2
- &&
HopRewriteUtils.isTransposeOperation(hi.getInput().get(0))
- && HopRewriteUtils.isBinary(hi.getInput().get(1),
OpOp2.CBIND) )
+ else if( HopRewriteUtils.isTsmm(hi) &&
hi.getInput(1).getParent().size()==2
+ && HopRewriteUtils.isTransposeOperation(hi.getInput(0))
+ && HopRewriteUtils.isBinary(hi.getInput(1),
OpOp2.CBIND) )
{
- Hop input1 = hi.getInput().get(1).getInput().get(0);
- Hop input2 = hi.getInput().get(1).getInput().get(1);
+ Hop input1 = hi.getInput(1).getInput(0);
+ Hop input2 = hi.getInput(1).getInput(1);
if( input1.getDim1() > input1.getDim2() &&
input2.getDim2() == 1 ) {
hnew = HopRewriteUtils.createPartialTsmmCbind(
input1, input2,
HopRewriteUtils.createTsmm(input1, true));
@@ -571,10 +571,10 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
private static Hop fuseDatagenAndReorgOperation(Hop parent, Hop hi, int
pos)
{
if( HopRewriteUtils.isTransposeOperation(hi)
- && hi.getInput().get(0) instanceof DataGenOp
//datagen
- && hi.getInput().get(0).getParent().size()==1 )
//transpose only consumer
+ && hi.getInput(0) instanceof DataGenOp //datagen
+ && hi.getInput(0).getParent().size()==1 ) //transpose
only consumer
{
- DataGenOp dop = (DataGenOp)hi.getInput().get(0);
+ DataGenOp dop = (DataGenOp)hi.getInput(0);
if( (dop.getOp() == OpOpDG.RAND || dop.getOp() ==
OpOpDG.SINIT)
&& (dop.getDim1()==1 || dop.getDim2()==1 ))
{
@@ -609,7 +609,7 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
if( hi instanceof AggUnaryOp )
{
AggUnaryOp uhi = (AggUnaryOp)hi;
- Hop input = uhi.getInput().get(0);
+ Hop input = uhi.getInput(0);
if( HopRewriteUtils.isValidOp(uhi.getOp(),
LOOKUP_VALID_ROW_COL_AGGREGATE) ) {
if( uhi.getDirection() == Direction.Col )
@@ -670,7 +670,7 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
if( hi instanceof AggUnaryOp )
{
AggUnaryOp uhi = (AggUnaryOp)hi;
- Hop input = uhi.getInput().get(0);
+ Hop input = uhi.getInput(0);
if( HopRewriteUtils.isValidOp(uhi.getOp(),
LOOKUP_VALID_ROW_COL_AGGREGATE) ) {
if( uhi.getDirection() == Direction.Row )
@@ -761,13 +761,13 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
if( hi instanceof AggUnaryOp )
{
AggUnaryOp uhi = (AggUnaryOp)hi;
- Hop input = uhi.getInput().get(0);
+ Hop input = uhi.getInput(0);
if( uhi.getOp() == AggOp.SUM && uhi.getDirection() ==
Direction.Col //colsums
&& HopRewriteUtils.isBinary(input, OpOp2.MULT) )
//b(*)
{
- Hop left = input.getInput().get(0);
- Hop right = input.getInput().get(1);
+ Hop left = input.getInput(0);
+ Hop right = input.getInput(1);
if( left.getDim1()>1 && left.getDim2()>1
&& right.getDim1()>1 &&
right.getDim2()==1 ) // MV (col vector)
@@ -796,13 +796,13 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
if( hi instanceof AggUnaryOp )
{
AggUnaryOp uhi = (AggUnaryOp)hi;
- Hop input = uhi.getInput().get(0);
+ Hop input = uhi.getInput(0);
if( uhi.getOp() == AggOp.SUM && uhi.getDirection() ==
Direction.Row //rowsums
&& HopRewriteUtils.isBinary(input, OpOp2.MULT)
) //b(*)
{
- Hop left = input.getInput().get(0);
- Hop right = input.getInput().get(1);
+ Hop left = input.getInput(0);
+ Hop right = input.getInput(1);
if( left.getDim1()>1 && left.getDim2()>1
&& right.getDim1()==1 &&
right.getDim2()>1 ) // MV (row vector)
@@ -831,7 +831,7 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
if( hi instanceof AggUnaryOp &&
((AggUnaryOp)hi).getDirection()==Direction.RowCol )
{
AggUnaryOp uhi = (AggUnaryOp)hi;
- Hop input = uhi.getInput().get(0);
+ Hop input = uhi.getInput(0);
if( HopRewriteUtils.isValidOp(uhi.getOp(),
LOOKUP_VALID_UNNECESSARY_AGGREGATE) ){
@@ -856,7 +856,7 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
if( hi instanceof AggUnaryOp )
{
AggUnaryOp uhi = (AggUnaryOp)hi;
- Hop input = uhi.getInput().get(0);
+ Hop input = uhi.getInput(0);
//check for valid empty aggregates, except for matrices
with zero rows/cols
if( HopRewriteUtils.isValidOp(uhi.getOp(),
LOOKUP_VALID_EMPTY_AGGREGATE)
@@ -909,7 +909,7 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
if( hi instanceof UnaryOp )
{
UnaryOp uhi = (UnaryOp)hi;
- Hop input = uhi.getInput().get(0);
+ Hop input = uhi.getInput(0);
if( HopRewriteUtils.isValidOp(uhi.getOp(),
LOOKUP_VALID_EMPTY_UNARY) ){
@@ -933,7 +933,7 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
if( hi instanceof ReorgOp )
{
ReorgOp rhi = (ReorgOp)hi;
- Hop input = rhi.getInput().get(0);
+ Hop input = rhi.getInput(0);
if( HopRewriteUtils.isEmpty(input) ) //empty input
{
@@ -954,8 +954,8 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
}
}
else if( rhi.getOp() == ReOrgOp.RESHAPE )
- hnew =
HopRewriteUtils.createDataGenOpByVal(rhi.getInput().get(1),
rhi.getInput().get(2),
- rhi.getInput().get(3),
rhi.getDataType(), rhi.getValueType(), 0);
+ hnew =
HopRewriteUtils.createDataGenOpByVal(rhi.getInput(1), rhi.getInput(2),
+ rhi.getInput(3),
rhi.getDataType(), rhi.getValueType(), 0);
//modify dag if one of the above rules applied
if( hnew != null ){
@@ -978,7 +978,7 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
if( hi instanceof ReorgOp &&
((ReorgOp)hi).getOp()==ReOrgOp.SORT )
{
ReorgOp rhi = (ReorgOp)hi;
- Hop input = rhi.getInput().get(0);
+ Hop input = rhi.getInput(0);
if( HopRewriteUtils.isEmpty(input) ) //empty input
{
@@ -986,9 +986,9 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
Hop hnew = null;
boolean ixret = false;
- if( rhi.getInput().get(3) instanceof LiteralOp
) //index return known
+ if( rhi.getInput(3) instanceof LiteralOp )
//index return known
{
- ixret =
HopRewriteUtils.getBooleanValue((LiteralOp)rhi.getInput().get(3));
+ ixret =
HopRewriteUtils.getBooleanValue((LiteralOp)rhi.getInput(3));
if( ixret )
hnew =
HopRewriteUtils.createSeqDataGenOp(input);
else
@@ -1012,8 +1012,8 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
private static Hop simplifyEmptyMatrixMult(Hop parent, Hop hi, int pos)
{
if( HopRewriteUtils.isMatrixMultiply(hi) ) //X%*%Y -> matrix(0,
)
{
- Hop left = hi.getInput().get(0);
- Hop right = hi.getInput().get(1);
+ Hop left = hi.getInput(0);
+ Hop right = hi.getInput(1);
if( HopRewriteUtils.isEmpty(left) //one input empty
|| HopRewriteUtils.isEmpty(right) )
@@ -1034,8 +1034,8 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
{
if( HopRewriteUtils.isMatrixMultiply(hi) ) //X%*%Y -> X, if y
is matrix(1,1,1)
{
- Hop left = hi.getInput().get(0);
- Hop right = hi.getInput().get(1);
+ Hop left = hi.getInput(0);
+ Hop right = hi.getInput(1);
// X %*% y -> X
if( HopRewriteUtils.isDimsKnown(right) &&
right.getDim1()==1 && right.getDim2()==1 && //scalar right
@@ -1056,8 +1056,8 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
{
if( HopRewriteUtils.isMatrixMultiply(hi) ) //X%*%Y
{
- Hop left = hi.getInput().get(0);
- Hop right = hi.getInput().get(1);
+ Hop left = hi.getInput(0);
+ Hop right = hi.getInput(1);
// y %*% X -> as.scalar(y) * X
if( HopRewriteUtils.isDimsKnown(left) &&
left.getDim1()==1 && left.getDim2()==1 ) //scalar left
@@ -1097,8 +1097,8 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
if( HopRewriteUtils.isMatrixMultiply(hi) ) //X%*%Y
{
- Hop left = hi.getInput().get(0);
- Hop right = hi.getInput().get(1);
+ Hop left = hi.getInput(0);
+ Hop right = hi.getInput(1);
// diag(X) %*% Y -> X * Y / diag(X) %*% Y -> Y * X
// previously rep required for the second case: diag(X)
%*% Y -> (X%*%ones) * Y
@@ -1108,7 +1108,7 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
if( right.getDim2()==1 ) //right column vector
{
//create binary operation over input
and right
- Hop input = left.getInput().get(0);
//diag input
+ Hop input = left.getInput(0); //diag
input
hnew =
HopRewriteUtils.createBinary(input, right, OpOp2.MULT);
LOG.debug("Applied
simplifyMatrixMultDiag1");
@@ -1117,7 +1117,7 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
{
//create binary operation over input
and right; in contrast to above rewrite,
//we need to switch the order because
MV binary cell operations require vector on the right
- Hop input = left.getInput().get(0);
//diag input
+ Hop input = left.getInput(0); //diag
input
hnew =
HopRewriteUtils.createBinary(right, input, OpOp2.MULT);
//NOTE: previously to MV binary cell
operations we replicated the left
@@ -1145,11 +1145,11 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
private static Hop simplifyDiagMatrixMult(Hop parent, Hop hi, int pos) {
if(hi instanceof ReorgOp && ((ReorgOp) hi).getOp() ==
ReOrgOp.DIAG && hi.getDim2() == 1) //diagM2V
{
- Hop hi2 = hi.getInput().get(0);
+ Hop hi2 = hi.getInput(0);
if(HopRewriteUtils.isMatrixMultiply(hi2)) //X%*%Y
{
- Hop left = hi2.getInput().get(0);
- Hop right = hi2.getInput().get(1);
+ Hop left = hi2.getInput(0);
+ Hop right = hi2.getInput(1);
//create new operators (incl refresh size
inside for transpose)
ReorgOp trans =
HopRewriteUtils.createTranspose(right);
@@ -1209,10 +1209,10 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
AggUnaryOp au = (AggUnaryOp) hi;
if(au.getOp() == AggOp.SUM && au.getDirection() ==
Direction.RowCol) //sum
{
- Hop hi2 = au.getInput().get(0);
+ Hop hi2 = au.getInput(0);
if(hi2 instanceof ReorgOp && ((ReorgOp)
hi2).getOp() == ReOrgOp.DIAG && hi2.getDim2() == 1) //diagM2V
{
- Hop hi3 = hi2.getInput().get(0);
+ Hop hi3 = hi2.getInput(0);
//remove diag operator
HopRewriteUtils.replaceChildReference(au, hi2, hi3, 0);
@@ -1233,12 +1233,12 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//pattern: X * cumsum(diag(matrix(1,nrow(X),1))) -> lower.tri
(only right)
if( HopRewriteUtils.isBinary(hi, OpOp2.MULT)
&& hi.getDim1() == hi.getDim2() && hi.getDim1() > 1 ) {
- Hop left = hi.getInput().get(0);
- Hop right = hi.getInput().get(1);
+ Hop left = hi.getInput(0);
+ Hop right = hi.getInput(1);
if( HopRewriteUtils.isUnary(right, OpOp1.CUMSUM) &&
right.getParent().size()==1
- &&
HopRewriteUtils.isReorg(right.getInput().get(0), ReOrgOp.DIAG)
- &&
HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput().get(0).getInput().get(0),
1d))
+ && HopRewriteUtils.isReorg(right.getInput(0),
ReOrgOp.DIAG)
+ &&
HopRewriteUtils.isDataGenOpWithConstantValue(right.getInput(0).getInput(0), 1d))
{
LinkedHashMap<String,Hop> args = new
LinkedHashMap<>();
args.put("target", left);
@@ -1279,8 +1279,8 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//(2) in order to make the binary operation more efficient
(dense vector vs sparse matrix)
if( HopRewriteUtils.isBinary(hi, OpOp2.MULT) )
{
- Hop left = hi.getInput().get(0);
- Hop right = hi.getInput().get(1);
+ Hop left = hi.getInput(0);
+ Hop right = hi.getInput(1);
boolean applyLeft = false;
boolean applyRight = false;
@@ -1288,14 +1288,14 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//left input is diag
if( left instanceof ReorgOp &&
((ReorgOp)left).getOp()==ReOrgOp.DIAG
&& left.getParent().size()==1 //binary op only
parent
- && left.getInput().get(0).getDim2()==1 //col
vector
+ && left.getInput(0).getDim2()==1 //col vector
&& right.getDataType() == DataType.SCALAR )
{
applyLeft = true;
}
else if( right instanceof ReorgOp &&
((ReorgOp)right).getOp()==ReOrgOp.DIAG
&& right.getParent().size()==1 //binary
op only parent
- && right.getInput().get(0).getDim2()==1
//col vector
+ && right.getInput(0).getDim2()==1 //col
vector
&& left.getDataType() ==
DataType.SCALAR )
{
applyRight = true;
@@ -1316,7 +1316,7 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//rewire binop-diag-input into diag-binop-input
if( applyLeft ) {
- Hop input = left.getInput().get(0);
+ Hop input = left.getInput(0);
HopRewriteUtils.removeChildReferenceByPos(hi, left, 0);
HopRewriteUtils.removeChildReferenceByPos(left, input, 0);
HopRewriteUtils.addChildReference(left,
hi, 0);
@@ -1325,7 +1325,7 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
hi = left;
}
else if ( applyRight ) {
- Hop input = right.getInput().get(0);
+ Hop input = right.getInput(0);
HopRewriteUtils.removeChildReferenceByPos(hi, right, 1);
HopRewriteUtils.removeChildReferenceByPos(right, input, 0);
HopRewriteUtils.addChildReference(right, hi, 0);
@@ -1362,12 +1362,12 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
if( hi instanceof AggUnaryOp //full sum root over binaryop
&& ((AggUnaryOp)hi).getDirection()==Direction.RowCol
&& ((AggUnaryOp)hi).getOp() == AggOp.SUM
- && hi.getInput().get(0) instanceof BinaryOp
- && hi.getInput().get(0).getParent().size()==1 )
//single parent
+ && hi.getInput(0) instanceof BinaryOp
+ && hi.getInput(0).getParent().size()==1 ) //single
parent
{
- BinaryOp bop = (BinaryOp) hi.getInput().get(0);
- Hop left = bop.getInput().get(0);
- Hop right = bop.getInput().get(1);
+ BinaryOp bop = (BinaryOp) hi.getInput(0);
+ Hop left = bop.getInput(0);
+ Hop right = bop.getInput(1);
if( HopRewriteUtils.isEqualSize(left, right) //dims(A)
== dims(B)
&& left.getDataType() == DataType.MATRIX
@@ -1427,43 +1427,43 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
boolean appliedPattern = false;
if( HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM,
Direction.RowCol) //all patterns rooted by sum()
- && hi.getInput().get(0) instanceof BinaryOp //all
patterns subrooted by binary op
- && hi.getInput().get(0).getDim2() > 1 ) //not
applied for vector-vector mult
+ && hi.getInput(0) instanceof BinaryOp //all patterns
subrooted by binary op
+ && hi.getInput(0).getDim2() > 1 ) //not applied
for vector-vector mult
{
- BinaryOp bop = (BinaryOp) hi.getInput().get(0);
+ BinaryOp bop = (BinaryOp) hi.getInput(0);
//Pattern 1) sum (W * (X - U %*% t(V)) ^ 2) (post
weighting)
//alternative pattern: sum (W * (U %*% t(V) - X) ^ 2)
- if( bop.getOp()==OpOp2.MULT &&
HopRewriteUtils.isBinary(bop.getInput().get(1), OpOp2.POW)
- &&
bop.getInput().get(0).getDataType()==DataType.MATRIX
- &&
HopRewriteUtils.isEqualSize(bop.getInput().get(0), bop.getInput().get(1))
//prevent mv
- &&
HopRewriteUtils.isLiteralOfValue(bop.getInput().get(1).getInput().get(1), 2) )
+ if( bop.getOp()==OpOp2.MULT &&
HopRewriteUtils.isBinary(bop.getInput(1), OpOp2.POW)
+ &&
bop.getInput(0).getDataType()==DataType.MATRIX
+ && HopRewriteUtils.isEqualSize(bop.getInput(0),
bop.getInput(1)) //prevent mv
+ &&
HopRewriteUtils.isLiteralOfValue(bop.getInput(1).getInput(1), 2) )
{
- Hop W = bop.getInput().get(0);
- Hop tmp =
bop.getInput().get(1).getInput().get(0); //(X - U %*% t(V))
+ Hop W = bop.getInput(0);
+ Hop tmp = bop.getInput(1).getInput(0); //(X - U
%*% t(V))
if( HopRewriteUtils.isBinary(tmp, OpOp2.MINUS)
- &&
HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1))
//prevent mv
- && tmp.getInput().get(0).getDataType()
== DataType.MATRIX )
+ &&
HopRewriteUtils.isEqualSize(tmp.getInput(0), tmp.getInput(1)) //prevent mv
+ && tmp.getInput(0).getDataType() ==
DataType.MATRIX )
{
//a) sum (W * (X - U %*% t(V)) ^ 2)
int uvIndex = -1;
- if( tmp.getInput().get(1) instanceof
AggBinaryOp //ba gurantees matrices
- &&
HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0),true)) {
//BLOCKSIZE CONSTRAINT
+ if( tmp.getInput(1) instanceof
AggBinaryOp //ba gurantees matrices
+ &&
HopRewriteUtils.isSingleBlock(tmp.getInput(1).getInput(0),true)) { //BLOCKSIZE
CONSTRAINT
uvIndex = 1;
}
//b) sum (W * (U %*% t(V) - X) ^ 2)
- else if(tmp.getInput().get(0)
instanceof AggBinaryOp //ba gurantees matrices
- &&
HopRewriteUtils.isSingleBlock(tmp.getInput().get(0).getInput().get(0),true)) {
//BLOCKSIZE CONSTRAINT
+ else if(tmp.getInput(0) instanceof
AggBinaryOp //ba gurantees matrices
+ &&
HopRewriteUtils.isSingleBlock(tmp.getInput(0).getInput(0),true)) { //BLOCKSIZE
CONSTRAINT
uvIndex = 0;
}
if( uvIndex >= 0 ) { //rewrite match
Hop X =
tmp.getInput().get((uvIndex==0)?1:0);
- Hop U =
tmp.getInput().get(uvIndex).getInput().get(0);
- Hop V =
tmp.getInput().get(uvIndex).getInput().get(1);
+ Hop U =
tmp.getInput().get(uvIndex).getInput(0);
+ Hop V =
tmp.getInput().get(uvIndex).getInput(1);
V =
!HopRewriteUtils.isTransposeOperation(V) ?
-
HopRewriteUtils.createTranspose(V) : V.getInput().get(0);
+
HopRewriteUtils.createTranspose(V) : V.getInput(0);
//handle special case of post_nz
if(
HopRewriteUtils.isNonZeroIndicator(W, X) ){
@@ -1484,38 +1484,38 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//Pattern 2) sum ((X - W * (U %*% t(V))) ^ 2) (pre
weighting)
//alternative pattern: sum ((W * (U %*% t(V)) - X) ^ 2)
if( !appliedPattern
- && bop.getOp()==OpOp2.POW &&
HopRewriteUtils.isLiteralOfValue(bop.getInput().get(1), 2)
- &&
HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS)
- &&
HopRewriteUtils.isEqualMatrixSize((BinaryOp)bop.getInput().get(0)))
+ && bop.getOp()==OpOp2.POW &&
HopRewriteUtils.isLiteralOfValue(bop.getInput(1), 2)
+ && HopRewriteUtils.isBinary(bop.getInput(0),
OpOp2.MINUS)
+ &&
HopRewriteUtils.isEqualMatrixSize((BinaryOp)bop.getInput(0)))
{
- Hop lleft =
bop.getInput().get(0).getInput().get(0);
- Hop lright =
bop.getInput().get(0).getInput().get(1);
+ Hop lleft = bop.getInput(0).getInput(0);
+ Hop lright = bop.getInput(0).getInput(1);
//a) sum ((X - W * (U %*% t(V))) ^ 2)
int wuvIndex = -1;
- if( lright instanceof BinaryOp &&
lright.getInput().get(1) instanceof AggBinaryOp ){
+ if( lright instanceof BinaryOp &&
lright.getInput(1) instanceof AggBinaryOp ){
wuvIndex = 1;
}
//b) sum ((W * (U %*% t(V)) - X) ^ 2)
- else if( lleft instanceof BinaryOp &&
lleft.getInput().get(1) instanceof AggBinaryOp ){
+ else if( lleft instanceof BinaryOp &&
lleft.getInput(1) instanceof AggBinaryOp ){
wuvIndex = 0;
}
if( wuvIndex >= 0 ) //rewrite match
{
- Hop X =
bop.getInput().get(0).getInput().get((wuvIndex==0)?1:0);
- Hop tmp =
bop.getInput().get(0).getInput().get(wuvIndex); //(W * (U %*% t(V)))
+ Hop X =
bop.getInput(0).getInput().get((wuvIndex==0)?1:0);
+ Hop tmp =
bop.getInput(0).getInput().get(wuvIndex); //(W * (U %*% t(V)))
if( ((BinaryOp)tmp).getOp()==OpOp2.MULT
- &&
tmp.getInput().get(0).getDataType() == DataType.MATRIX
- &&
HopRewriteUtils.isEqualSize(tmp.getInput().get(0), tmp.getInput().get(1))
//prevent mv
- &&
HopRewriteUtils.isSingleBlock(tmp.getInput().get(1).getInput().get(0),true))
//BLOCKSIZE CONSTRAINT
+ &&
tmp.getInput(0).getDataType() == DataType.MATRIX
+ &&
HopRewriteUtils.isEqualSize(tmp.getInput(0), tmp.getInput(1)) //prevent mv
+ &&
HopRewriteUtils.isSingleBlock(tmp.getInput(1).getInput(0),true)) //BLOCKSIZE
CONSTRAINT
{
- Hop W = tmp.getInput().get(0);
- Hop U =
tmp.getInput().get(1).getInput().get(0);
- Hop V =
tmp.getInput().get(1).getInput().get(1);
+ Hop W = tmp.getInput(0);
+ Hop U =
tmp.getInput(1).getInput(0);
+ Hop V =
tmp.getInput(1).getInput(1);
V =
!HopRewriteUtils.isTransposeOperation(V) ?
-
HopRewriteUtils.createTranspose(V) : V.getInput().get(0);
+
HopRewriteUtils.createTranspose(V) : V.getInput(0);
hnew = new
QuaternaryOp(hi.getName(), DataType.SCALAR,
ValueType.FP64,
OpOp4.WSLOSS, X, U, V, W, false);
HopRewriteUtils.setOutputParametersForScalar(hnew);
@@ -1528,33 +1528,33 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//Pattern 3) sum ((X - (U %*% t(V))) ^ 2) (no weighting)
//alternative pattern: sum (((U %*% t(V)) - X) ^ 2)
if( !appliedPattern
- && bop.getOp()==OpOp2.POW &&
HopRewriteUtils.isLiteralOfValue(bop.getInput().get(1), 2)
- &&
HopRewriteUtils.isBinary(bop.getInput().get(0), OpOp2.MINUS)
- &&
HopRewriteUtils.isEqualMatrixSize((BinaryOp)bop.getInput().get(0))) //prevent mv
+ && bop.getOp()==OpOp2.POW &&
HopRewriteUtils.isLiteralOfValue(bop.getInput(1), 2)
+ && HopRewriteUtils.isBinary(bop.getInput(0),
OpOp2.MINUS)
+ &&
HopRewriteUtils.isEqualMatrixSize((BinaryOp)bop.getInput(0))) //prevent mv
{
- Hop lleft =
bop.getInput().get(0).getInput().get(0);
- Hop lright =
bop.getInput().get(0).getInput().get(1);
+ Hop lleft = bop.getInput(0).getInput(0);
+ Hop lright = bop.getInput(0).getInput(1);
//a) sum ((X - (U %*% t(V))) ^ 2)
int uvIndex = -1;
if( lright instanceof AggBinaryOp //ba
guarantees matrices
- &&
HopRewriteUtils.isSingleBlock(lright.getInput().get(0),true) ) { //BLOCKSIZE
CONSTRAINT
+ &&
HopRewriteUtils.isSingleBlock(lright.getInput(0),true) ) { //BLOCKSIZE
CONSTRAINT
uvIndex = 1;
}
//b) sum (((U %*% t(V)) - X) ^ 2)
else if( lleft instanceof AggBinaryOp //ba
guarantees matrices
- &&
HopRewriteUtils.isSingleBlock(lleft.getInput().get(0),true) ) { //BLOCKSIZE
CONSTRAINT
+ &&
HopRewriteUtils.isSingleBlock(lleft.getInput(0),true) ) { //BLOCKSIZE CONSTRAINT
uvIndex = 0;
}
if( uvIndex >= 0 ) { //rewrite match
- Hop X =
bop.getInput().get(0).getInput().get((uvIndex==0)?1:0);
- Hop tmp =
bop.getInput().get(0).getInput().get(uvIndex); //(U %*% t(V))
+ Hop X =
bop.getInput(0).getInput().get((uvIndex==0)?1:0);
+ Hop tmp =
bop.getInput(0).getInput().get(uvIndex); //(U %*% t(V))
Hop W = new LiteralOp(1); //no
weighting
- Hop U = tmp.getInput().get(0);
- Hop V = tmp.getInput().get(1);
+ Hop U = tmp.getInput(0);
+ Hop V = tmp.getInput(1);
V =
!HopRewriteUtils.isTransposeOperation(V) ?
-
HopRewriteUtils.createTranspose(V) : V.getInput().get(0);
+
HopRewriteUtils.createTranspose(V) : V.getInput(0);
hnew = new QuaternaryOp(hi.getName(),
DataType.SCALAR,
ValueType.FP64, OpOp4.WSLOSS,
X, U, V, W, false);
HopRewriteUtils.setOutputParametersForScalar(hnew);
@@ -1569,32 +1569,32 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//alternative pattern: sumSq (U %*% t(V) - X)
if( !appliedPattern
&& HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM_SQ,
Direction.RowCol)
- && HopRewriteUtils.isBinary(hi.getInput().get(0),
OpOp2.MINUS)
- &&
HopRewriteUtils.isEqualMatrixSize((BinaryOp)hi.getInput().get(0))) //prevent mv
+ && HopRewriteUtils.isBinary(hi.getInput(0),
OpOp2.MINUS)
+ &&
HopRewriteUtils.isEqualMatrixSize((BinaryOp)hi.getInput(0))) //prevent mv
{
- Hop lleft = hi.getInput().get(0).getInput().get(0);
- Hop lright = hi.getInput().get(0).getInput().get(1);
+ Hop lleft = hi.getInput(0).getInput(0);
+ Hop lright = hi.getInput(0).getInput(1);
//a) sumSq (X - U %*% t(V))
int uvIndex = -1;
if( lright instanceof AggBinaryOp //ba guarantees
matrices
- &&
HopRewriteUtils.isSingleBlock(lright.getInput().get(0),true) ) { //BLOCKSIZE
CONSTRAINT
+ &&
HopRewriteUtils.isSingleBlock(lright.getInput(0),true) ) { //BLOCKSIZE
CONSTRAINT
uvIndex = 1;
}
//b) sumSq (U %*% t(V) - X)
else if( lleft instanceof AggBinaryOp //ba guarantees
matrices
- &&
HopRewriteUtils.isSingleBlock(lleft.getInput().get(0),true) ) { //BLOCKSIZE
CONSTRAINT
+ &&
HopRewriteUtils.isSingleBlock(lleft.getInput(0),true) ) { //BLOCKSIZE CONSTRAINT
uvIndex = 0;
}
if( uvIndex >= 0 ) { //rewrite match
- Hop X =
hi.getInput().get(0).getInput().get((uvIndex==0)?1:0);
- Hop tmp =
hi.getInput().get(0).getInput().get(uvIndex); //(U %*% t(V))
+ Hop X =
hi.getInput(0).getInput().get((uvIndex==0)?1:0);
+ Hop tmp =
hi.getInput(0).getInput().get(uvIndex); //(U %*% t(V))
Hop W = new LiteralOp(1); //no weighting
- Hop U = tmp.getInput().get(0);
- Hop V = tmp.getInput().get(1);
+ Hop U = tmp.getInput(0);
+ Hop V = tmp.getInput(1);
V = !HopRewriteUtils.isTransposeOperation(V) ?
- HopRewriteUtils.createTranspose(V) :
V.getInput().get(0);
+ HopRewriteUtils.createTranspose(V) :
V.getInput(0);
hnew = new QuaternaryOp(hi.getName(),
DataType.SCALAR,
ValueType.FP64, OpOp4.WSLOSS, X, U, V,
W, false);
HopRewriteUtils.setOutputParametersForScalar(hnew);
@@ -1619,27 +1619,27 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
if( HopRewriteUtils.isBinary(hi, OpOp2.MULT) //all patterns
subrooted by W *
&& hi.getDim2() > 1 //not applied for
vector-vector mult
- && HopRewriteUtils.isEqualSize(hi.getInput().get(0),
hi.getInput().get(1)) //prevent mv
- && hi.getInput().get(0).getDataType()==DataType.MATRIX
- && hi.getInput().get(1) instanceof UnaryOp )
//sigmoid/log
+ && HopRewriteUtils.isEqualSize(hi.getInput(0),
hi.getInput(1)) //prevent mv
+ && hi.getInput(0).getDataType()==DataType.MATRIX
+ && hi.getInput(1) instanceof UnaryOp ) //sigmoid/log
{
- UnaryOp uop = (UnaryOp) hi.getInput().get(1);
+ UnaryOp uop = (UnaryOp) hi.getInput(1);
boolean appliedPattern = false;
//Pattern 1) W * sigmoid(Y%*%t(X)) (basic)
if( uop.getOp() == OpOp1.SIGMOID
- && uop.getInput().get(0) instanceof AggBinaryOp
- &&
HopRewriteUtils.isSingleBlock(uop.getInput().get(0).getInput().get(0),true) )
+ && uop.getInput(0) instanceof AggBinaryOp
+ &&
HopRewriteUtils.isSingleBlock(uop.getInput(0).getInput(0),true) )
{
- Hop W = hi.getInput().get(0);
- Hop Y = uop.getInput().get(0).getInput().get(0);
- Hop tX =
uop.getInput().get(0).getInput().get(1);
+ Hop W = hi.getInput(0);
+ Hop Y = uop.getInput(0).getInput(0);
+ Hop tX = uop.getInput(0).getInput(1);
if( !HopRewriteUtils.isTransposeOperation(tX) )
{
tX =
HopRewriteUtils.createTranspose(tX);
}
else
- tX = tX.getInput().get(0);
+ tX = tX.getInput(0);
hnew = new QuaternaryOp(hi.getName(),
DataType.MATRIX, ValueType.FP64,
OpOp4.WSIGMOID, W, Y, tX,
false, false);
@@ -1653,22 +1653,22 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//Pattern 2) W * sigmoid(-(Y%*%t(X))) (minus)
if( !appliedPattern
&& uop.getOp() == OpOp1.SIGMOID
- &&
HopRewriteUtils.isBinary(uop.getInput().get(0), OpOp2.MINUS)
- && uop.getInput().get(0).getInput().get(0)
instanceof LiteralOp
+ && HopRewriteUtils.isBinary(uop.getInput(0),
OpOp2.MINUS)
+ && uop.getInput(0).getInput(0) instanceof
LiteralOp
&& HopRewriteUtils.getDoubleValueSafe(
-
(LiteralOp)uop.getInput().get(0).getInput().get(0))==0
- && uop.getInput().get(0).getInput().get(1)
instanceof AggBinaryOp
- &&
HopRewriteUtils.isSingleBlock(uop.getInput().get(0).getInput().get(1).getInput().get(0),true))
+ (LiteralOp)uop.getInput(0).getInput(0))==0
+ && uop.getInput(0).getInput(1) instanceof
AggBinaryOp
+ &&
HopRewriteUtils.isSingleBlock(uop.getInput(0).getInput(1).getInput(0),true))
{
- Hop W = hi.getInput().get(0);
- Hop Y =
uop.getInput().get(0).getInput().get(1).getInput().get(0);
- Hop tX =
uop.getInput().get(0).getInput().get(1).getInput().get(1);
+ Hop W = hi.getInput(0);
+ Hop Y = uop.getInput(0).getInput(1).getInput(0);
+ Hop tX =
uop.getInput(0).getInput(1).getInput(1);
if( !HopRewriteUtils.isTransposeOperation(tX) )
{
tX =
HopRewriteUtils.createTranspose(tX);
}
else
- tX = tX.getInput().get(0);
+ tX = tX.getInput(0);
hnew = new QuaternaryOp(hi.getName(),
DataType.MATRIX, ValueType.FP64,
OpOp4.WSIGMOID, W, Y, tX,
false, true);
@@ -1682,19 +1682,19 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//Pattern 3) W * log(sigmoid(Y%*%t(X))) (log)
if( !appliedPattern
&& uop.getOp() == OpOp1.LOG
- &&
HopRewriteUtils.isUnary(uop.getInput().get(0), OpOp1.SIGMOID)
- && uop.getInput().get(0).getInput().get(0)
instanceof AggBinaryOp
- &&
HopRewriteUtils.isSingleBlock(uop.getInput().get(0).getInput().get(0).getInput().get(0),true)
)
+ && HopRewriteUtils.isUnary(uop.getInput(0),
OpOp1.SIGMOID)
+ && uop.getInput(0).getInput(0) instanceof
AggBinaryOp
+ &&
HopRewriteUtils.isSingleBlock(uop.getInput(0).getInput(0).getInput(0),true) )
{
- Hop W = hi.getInput().get(0);
- Hop Y =
uop.getInput().get(0).getInput().get(0).getInput().get(0);
- Hop tX =
uop.getInput().get(0).getInput().get(0).getInput().get(1);
+ Hop W = hi.getInput(0);
+ Hop Y = uop.getInput(0).getInput(0).getInput(0);
+ Hop tX =
uop.getInput(0).getInput(0).getInput(1);
if( !HopRewriteUtils.isTransposeOperation(tX) )
{
tX =
HopRewriteUtils.createTranspose(tX);
}
else
- tX = tX.getInput().get(0);
+ tX = tX.getInput(0);
hnew = new QuaternaryOp(hi.getName(),
DataType.MATRIX, ValueType.FP64,
OpOp4.WSIGMOID, W, Y, tX,
true, false);
@@ -1708,25 +1708,25 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//Pattern 4) W * log(sigmoid(-(Y%*%t(X)))) (log_minus)
if( !appliedPattern
&& uop.getOp() == OpOp1.LOG
- &&
HopRewriteUtils.isUnary(uop.getInput().get(0), OpOp1.SIGMOID)
- &&
HopRewriteUtils.isBinary(uop.getInput().get(0).getInput().get(0), OpOp2.MINUS) )
+ && HopRewriteUtils.isUnary(uop.getInput(0),
OpOp1.SIGMOID)
+ &&
HopRewriteUtils.isBinary(uop.getInput(0).getInput(0), OpOp2.MINUS) )
{
- BinaryOp bop = (BinaryOp)
uop.getInput().get(0).getInput().get(0);
+ BinaryOp bop = (BinaryOp)
uop.getInput(0).getInput(0);
- if( bop.getInput().get(0) instanceof
LiteralOp
- &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)bop.getInput().get(0))==0
- && bop.getInput().get(1) instanceof
AggBinaryOp
- &&
HopRewriteUtils.isSingleBlock(bop.getInput().get(1).getInput().get(0),true))
+ if( bop.getInput(0) instanceof LiteralOp
+ &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)bop.getInput(0))==0
+ && bop.getInput(1) instanceof
AggBinaryOp
+ &&
HopRewriteUtils.isSingleBlock(bop.getInput(1).getInput(0),true))
{
- Hop W = hi.getInput().get(0);
- Hop Y =
bop.getInput().get(1).getInput().get(0);
- Hop tX =
bop.getInput().get(1).getInput().get(1);
+ Hop W = hi.getInput(0);
+ Hop Y = bop.getInput(1).getInput(0);
+ Hop tX = bop.getInput(1).getInput(1);
if(
!HopRewriteUtils.isTransposeOperation(tX) ) {
tX =
HopRewriteUtils.createTranspose(tX);
}
else
- tX = tX.getInput().get(0);
+ tX = tX.getInput(0);
hnew = new QuaternaryOp(hi.getName(),
DataType.MATRIX, ValueType.FP64,
OpOp4.WSIGMOID, W, Y,
tX, true, true);
@@ -1755,32 +1755,32 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//left/right patterns rooted by 'ab - b(div)' or 'ab - b(mult)'
//note: we do not rewrite t(X)%*%(w*(X%*%v)) where w and v are
vectors (see mmchain ops)
if( HopRewriteUtils.isMatrixMultiply(hi)
- && (hi.getInput().get(0) instanceof BinaryOp
- &&
HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput().get(0)).getOp(),
LOOKUP_VALID_WDIVMM_BINARY)
- || hi.getInput().get(1) instanceof BinaryOp
+ && (hi.getInput(0) instanceof BinaryOp
+ &&
HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput(0)).getOp(),
LOOKUP_VALID_WDIVMM_BINARY)
+ || hi.getInput(1) instanceof BinaryOp
&& hi.getDim2() > 1 //not applied for vector-vector mult
- &&
HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput().get(1)).getOp(),
LOOKUP_VALID_WDIVMM_BINARY)) )
+ &&
HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput(1)).getOp(),
LOOKUP_VALID_WDIVMM_BINARY)) )
{
- Hop left = hi.getInput().get(0);
- Hop right = hi.getInput().get(1);
+ Hop left = hi.getInput(0);
+ Hop right = hi.getInput(1);
//Pattern 1) t(U) %*% (W/(U%*%t(V)))
//alternative pattern: t(U) %*% (W*(U%*%t(V)))
if( right instanceof BinaryOp &&
HopRewriteUtils.isValidOp(((BinaryOp)right).getOp(),LOOKUP_VALID_WDIVMM_BINARY)
- &&
HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1))
//prevent mv
- &&
HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1))
- &&
HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0),true) )
//BLOCKSIZE CONSTRAINT
+ &&
HopRewriteUtils.isEqualSize(right.getInput(0), right.getInput(1)) //prevent mv
+ &&
HopRewriteUtils.isOuterProductLikeMM(right.getInput(1))
+ &&
HopRewriteUtils.isSingleBlock(right.getInput(1).getInput(0),true) ) //BLOCKSIZE
CONSTRAINT
{
- Hop W = right.getInput().get(0);
- Hop U =
right.getInput().get(1).getInput().get(0);
- Hop V =
right.getInput().get(1).getInput().get(1);
+ Hop W = right.getInput(0);
+ Hop U = right.getInput(1).getInput(0);
+ Hop V = right.getInput(1).getInput(1);
if( HopRewriteUtils.isTransposeOfItself(left,
U) )
{
if(
!HopRewriteUtils.isTransposeOperation(V) )
V =
HopRewriteUtils.createTranspose(V);
else
- V = V.getInput().get(0);
+ V = V.getInput(0);
boolean mult =
((BinaryOp)right).getOp() == OpOp2.MULT;
hnew = new QuaternaryOp(hi.getName(),
DataType.MATRIX, ValueType.FP64,
@@ -1799,23 +1799,23 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//Pattern 1e) t(U) %*% (W/(U%*%t(V) + x))
if( !appliedPattern
&& HopRewriteUtils.isBinary(right,
LOOKUP_VALID_WDIVMM_BINARY[1]) //DIV
- &&
HopRewriteUtils.isEqualSize(right.getInput().get(0), right.getInput().get(1))
//prevent mv
- &&
HopRewriteUtils.isBinary(right.getInput().get(1), OpOp2.PLUS)
- &&
right.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR
- &&
HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
- &&
HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true)
) //BLOCKSIZE CONSTRAINT
+ &&
HopRewriteUtils.isEqualSize(right.getInput(0), right.getInput(1)) //prevent mv
+ && HopRewriteUtils.isBinary(right.getInput(1),
OpOp2.PLUS)
+ && right.getInput(1).getInput(1).getDataType()
== DataType.SCALAR
+ &&
HopRewriteUtils.isOuterProductLikeMM(right.getInput(1).getInput(0))
+ &&
HopRewriteUtils.isSingleBlock(right.getInput(1).getInput(0).getInput(0),true) )
//BLOCKSIZE CONSTRAINT
{
- Hop W = right.getInput().get(0);
- Hop U =
right.getInput().get(1).getInput().get(0).getInput().get(0);
- Hop V =
right.getInput().get(1).getInput().get(0).getInput().get(1);
- Hop X =
right.getInput().get(1).getInput().get(1);
+ Hop W = right.getInput(0);
+ Hop U =
right.getInput(1).getInput(0).getInput(0);
+ Hop V =
right.getInput(1).getInput(0).getInput(1);
+ Hop X = right.getInput(1).getInput(1);
if( HopRewriteUtils.isTransposeOfItself(left,
U) )
{
if(
!HopRewriteUtils.isTransposeOperation(V) )
V =
HopRewriteUtils.createTranspose(V);
else
- V = V.getInput().get(0);
+ V = V.getInput(0);
hnew = new QuaternaryOp(hi.getName(),
DataType.MATRIX, ValueType.FP64,
OpOp4.WDIVMM, W, U,
V, X, 3, false, false); // 3=>DIV_LEFT_EPS
@@ -1834,20 +1834,20 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//alternative pattern: (W*(U%*%t(V))) %*% V
if( !appliedPattern
&& left instanceof BinaryOp &&
HopRewriteUtils.isValidOp(((BinaryOp)left).getOp(), LOOKUP_VALID_WDIVMM_BINARY)
- &&
HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1))
//prevent mv
- &&
HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1))
- &&
HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0),true) )
//BLOCKSIZE CONSTRAINT
+ &&
HopRewriteUtils.isEqualSize(left.getInput(0), left.getInput(1)) //prevent mv
+ &&
HopRewriteUtils.isOuterProductLikeMM(left.getInput(1))
+ &&
HopRewriteUtils.isSingleBlock(left.getInput(1).getInput(0),true) ) //BLOCKSIZE
CONSTRAINT
{
- Hop W = left.getInput().get(0);
- Hop U =
left.getInput().get(1).getInput().get(0);
- Hop V =
left.getInput().get(1).getInput().get(1);
+ Hop W = left.getInput(0);
+ Hop U = left.getInput(1).getInput(0);
+ Hop V = left.getInput(1).getInput(1);
if( HopRewriteUtils.isTransposeOfItself(right,
V) )
{
if(
!HopRewriteUtils.isTransposeOperation(V) )
V = right;
else
- V = V.getInput().get(0);
+ V = V.getInput(0);
boolean mult = ((BinaryOp)left).getOp()
== OpOp2.MULT;
hnew = new QuaternaryOp(hi.getName(),
DataType.MATRIX, ValueType.FP64,
@@ -1863,23 +1863,23 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//Pattern 2e) (W/(U%*%t(V) + x)) %*% V
if( !appliedPattern
&& HopRewriteUtils.isBinary(left,
LOOKUP_VALID_WDIVMM_BINARY[1]) //DIV
- &&
HopRewriteUtils.isEqualSize(left.getInput().get(0), left.getInput().get(1))
//prevent mv
- &&
HopRewriteUtils.isBinary(left.getInput().get(1), OpOp2.PLUS)
- &&
left.getInput().get(1).getInput().get(1).getDataType() == DataType.SCALAR
- &&
HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
- &&
HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true)
) //BLOCKSIZE CONSTRAINT
+ &&
HopRewriteUtils.isEqualSize(left.getInput(0), left.getInput(1)) //prevent mv
+ && HopRewriteUtils.isBinary(left.getInput(1),
OpOp2.PLUS)
+ && left.getInput(1).getInput(1).getDataType()
== DataType.SCALAR
+ &&
HopRewriteUtils.isOuterProductLikeMM(left.getInput(1).getInput(0))
+ &&
HopRewriteUtils.isSingleBlock(left.getInput(1).getInput(0).getInput(0),true) )
//BLOCKSIZE CONSTRAINT
{
- Hop W = left.getInput().get(0);
- Hop U =
left.getInput().get(1).getInput().get(0).getInput().get(0);
- Hop V =
left.getInput().get(1).getInput().get(0).getInput().get(1);
- Hop X =
left.getInput().get(1).getInput().get(1);
+ Hop W = left.getInput(0);
+ Hop U =
left.getInput(1).getInput(0).getInput(0);
+ Hop V =
left.getInput(1).getInput(0).getInput(1);
+ Hop X = left.getInput(1).getInput(1);
if( HopRewriteUtils.isTransposeOfItself(right,
V) )
{
if(
!HopRewriteUtils.isTransposeOperation(V) )
V = right;
else
- V = V.getInput().get(0);
+ V = V.getInput(0);
hnew = new QuaternaryOp(hi.getName(),
DataType.MATRIX, ValueType.FP64,
OpOp4.WDIVMM, W, U,
V, X, 4, false, false); // 4=>DIV_RIGHT_EPS
@@ -1894,15 +1894,15 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//Pattern 3) t(U) %*% ((X!=0)*(U%*%t(V)-X))
if( !appliedPattern
&& HopRewriteUtils.isBinary(right,
LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT
- &&
HopRewriteUtils.isBinary(right.getInput().get(1), OpOp2.MINUS)
- &&
HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
- &&
right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
- &&
HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true)
) //BLOCKSIZE CONSTRAINT
+ && HopRewriteUtils.isBinary(right.getInput(1),
OpOp2.MINUS)
+ &&
HopRewriteUtils.isOuterProductLikeMM(right.getInput(1).getInput(0))
+ && right.getInput(1).getInput(1).getDataType()
== DataType.MATRIX
+ &&
HopRewriteUtils.isSingleBlock(right.getInput(1).getInput(0).getInput(0),true) )
//BLOCKSIZE CONSTRAINT
{
- Hop W = right.getInput().get(0);
- Hop U =
right.getInput().get(1).getInput().get(0).getInput().get(0);
- Hop V =
right.getInput().get(1).getInput().get(0).getInput().get(1);
- Hop X =
right.getInput().get(1).getInput().get(1);
+ Hop W = right.getInput(0);
+ Hop U =
right.getInput(1).getInput(0).getInput(0);
+ Hop V =
right.getInput(1).getInput(0).getInput(1);
+ Hop X = right.getInput(1).getInput(1);
if( HopRewriteUtils.isNonZeroIndicator(W, X)
//W-X constraint
&&
HopRewriteUtils.isTransposeOfItself(left, U) ) //t(U)-U constraint
@@ -1910,7 +1910,7 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
if(
!HopRewriteUtils.isTransposeOperation(V) )
V =
HopRewriteUtils.createTranspose(V);
else
- V = V.getInput().get(0);
+ V = V.getInput(0);
hnew = new QuaternaryOp(hi.getName(),
DataType.MATRIX, ValueType.FP64,
OpOp4.WDIVMM, X, U,
V, new LiteralOp(-1), 1, true, true);
@@ -1928,15 +1928,15 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//Pattern 4) ((X!=0)*(U%*%t(V)-X)) %*% V
if( !appliedPattern
&& HopRewriteUtils.isBinary(left,
LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT
- &&
HopRewriteUtils.isBinary(left.getInput().get(1), OpOp2.MINUS)
- &&
HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
- &&
left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
- &&
HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true)
) //BLOCKSIZE CONSTRAINT
+ && HopRewriteUtils.isBinary(left.getInput(1),
OpOp2.MINUS)
+ &&
HopRewriteUtils.isOuterProductLikeMM(left.getInput(1).getInput(0))
+ && left.getInput(1).getInput(1).getDataType()
== DataType.MATRIX
+ &&
HopRewriteUtils.isSingleBlock(left.getInput(1).getInput(0).getInput(0),true) )
//BLOCKSIZE CONSTRAINT
{
- Hop W = left.getInput().get(0);
- Hop U =
left.getInput().get(1).getInput().get(0).getInput().get(0);
- Hop V =
left.getInput().get(1).getInput().get(0).getInput().get(1);
- Hop X =
left.getInput().get(1).getInput().get(1);
+ Hop W = left.getInput(0);
+ Hop U =
left.getInput(1).getInput(0).getInput(0);
+ Hop V =
left.getInput(1).getInput(0).getInput(1);
+ Hop X = left.getInput(1).getInput(1);
if( HopRewriteUtils.isNonZeroIndicator(W, X)
//W-X constraint
&&
HopRewriteUtils.isTransposeOfItself(right, V) ) //V-t(V) constraint
@@ -1944,7 +1944,7 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
if(
!HopRewriteUtils.isTransposeOperation(V) )
V = right;
else
- V = V.getInput().get(0);
+ V = V.getInput(0);
hnew = new QuaternaryOp(hi.getName(),
DataType.MATRIX, ValueType.FP64,
OpOp4.WDIVMM, X, U,
V, new LiteralOp(-1), 2, true, true);
@@ -1959,22 +1959,22 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//Pattern 5) t(U) %*% (W*(U%*%t(V)-X))
if( !appliedPattern
&& HopRewriteUtils.isBinary(right,
LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT
- &&
HopRewriteUtils.isBinary(right.getInput().get(1), OpOp2.MINUS)
- &&
HopRewriteUtils.isOuterProductLikeMM(right.getInput().get(1).getInput().get(0))
- &&
right.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
- &&
HopRewriteUtils.isSingleBlock(right.getInput().get(1).getInput().get(0).getInput().get(0),true)
) //BLOCKSIZE CONSTRAINT
+ && HopRewriteUtils.isBinary(right.getInput(1),
OpOp2.MINUS)
+ &&
HopRewriteUtils.isOuterProductLikeMM(right.getInput(1).getInput(0))
+ && right.getInput(1).getInput(1).getDataType()
== DataType.MATRIX
+ &&
HopRewriteUtils.isSingleBlock(right.getInput(1).getInput(0).getInput(0),true) )
//BLOCKSIZE CONSTRAINT
{
- Hop W = right.getInput().get(0);
- Hop U =
right.getInput().get(1).getInput().get(0).getInput().get(0);
- Hop V =
right.getInput().get(1).getInput().get(0).getInput().get(1);
- Hop X =
right.getInput().get(1).getInput().get(1);
+ Hop W = right.getInput(0);
+ Hop U =
right.getInput(1).getInput(0).getInput(0);
+ Hop V =
right.getInput(1).getInput(0).getInput(1);
+ Hop X = right.getInput(1).getInput(1);
if( HopRewriteUtils.isTransposeOfItself(left,
U) ) //t(U)-U constraint
{
if(
!HopRewriteUtils.isTransposeOperation(V) )
V =
HopRewriteUtils.createTranspose(V);
else
- V = V.getInput().get(0);
+ V = V.getInput(0);
//note: x and w exchanged compared to
patterns 1-4, 7
hnew = new QuaternaryOp(hi.getName(),
DataType.MATRIX, ValueType.FP64,
@@ -1993,22 +1993,22 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//Pattern 6) (W*(U%*%t(V)-X)) %*% V
if( !appliedPattern
&& HopRewriteUtils.isBinary(left,
LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT
- &&
HopRewriteUtils.isBinary(left.getInput().get(1), OpOp2.MINUS)
- &&
HopRewriteUtils.isOuterProductLikeMM(left.getInput().get(1).getInput().get(0))
- &&
left.getInput().get(1).getInput().get(1).getDataType() == DataType.MATRIX
- &&
HopRewriteUtils.isSingleBlock(left.getInput().get(1).getInput().get(0).getInput().get(0),true)
) //BLOCKSIZE CONSTRAINT
+ && HopRewriteUtils.isBinary(left.getInput(1),
OpOp2.MINUS)
+ &&
HopRewriteUtils.isOuterProductLikeMM(left.getInput(1).getInput(0))
+ && left.getInput(1).getInput(1).getDataType()
== DataType.MATRIX
+ &&
HopRewriteUtils.isSingleBlock(left.getInput(1).getInput(0).getInput(0),true) )
//BLOCKSIZE CONSTRAINT
{
- Hop W = left.getInput().get(0);
- Hop U =
left.getInput().get(1).getInput().get(0).getInput().get(0);
- Hop V =
left.getInput().get(1).getInput().get(0).getInput().get(1);
- Hop X =
left.getInput().get(1).getInput().get(1);
+ Hop W = left.getInput(0);
+ Hop U =
left.getInput(1).getInput(0).getInput(0);
+ Hop V =
left.getInput(1).getInput(0).getInput(1);
+ Hop X = left.getInput(1).getInput(1);
if( HopRewriteUtils.isTransposeOfItself(right,
V) ) //V-t(V) constraint
{
if(
!HopRewriteUtils.isTransposeOperation(V) )
V = right;
else
- V = V.getInput().get(0);
+ V = V.getInput(0);
//note: x and w exchanged compared to
patterns 1-4, 7
hnew = new QuaternaryOp(hi.getName(),
DataType.MATRIX, ValueType.FP64,
@@ -2025,24 +2025,24 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//Pattern 7) (W*(U%*%t(V)))
if( !appliedPattern
&& HopRewriteUtils.isBinary(hi,
LOOKUP_VALID_WDIVMM_BINARY[0]) //MULT
- && HopRewriteUtils.isEqualSize(hi.getInput().get(0),
hi.getInput().get(1)) //prevent mv
+ && HopRewriteUtils.isEqualSize(hi.getInput(0),
hi.getInput(1)) //prevent mv
&& hi.getDim2() > 1 //not applied for vector-vector mult
- && hi.getInput().get(0).getDataType() ==
DataType.MATRIX
- && hi.getInput().get(0).getDim2() >
hi.getInput().get(0).getBlocksize()
- &&
HopRewriteUtils.isOuterProductLikeMM(hi.getInput().get(1))
- && (((AggBinaryOp)
hi.getInput().get(1)).checkMapMultChain() == ChainType.NONE ||
hi.getInput().get(1).getInput().get(1).getDim2() > 1) //no mmchain
- &&
HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0),true) )
//BLOCKSIZE CONSTRAINT
+ && hi.getInput(0).getDataType() == DataType.MATRIX
+ && hi.getInput(0).getDim2() >
hi.getInput(0).getBlocksize()
+ && HopRewriteUtils.isOuterProductLikeMM(hi.getInput(1))
+ && (((AggBinaryOp) hi.getInput(1)).checkMapMultChain()
== ChainType.NONE || hi.getInput(1).getInput(1).getDim2() > 1) //no mmchain
+ &&
HopRewriteUtils.isSingleBlock(hi.getInput(1).getInput(0),true) ) //BLOCKSIZE
CONSTRAINT
{
- Hop W = hi.getInput().get(0);
- Hop U = hi.getInput().get(1).getInput().get(0);
- Hop V = hi.getInput().get(1).getInput().get(1);
+ Hop W = hi.getInput(0);
+ Hop U = hi.getInput(1).getInput(0);
+ Hop V = hi.getInput(1).getInput(1);
//for this basic pattern, we're more conservative and
only apply wdivmm if
//W is sparse and U/V unknown or dense; or if U/V are
dense
if( (HopRewriteUtils.isSparse(W) &&
!HopRewriteUtils.isSparse(U) && !HopRewriteUtils.isSparse(V))
|| (HopRewriteUtils.isDense(U) &&
HopRewriteUtils.isDense(V)) ) {
V = !HopRewriteUtils.isTransposeOperation(V) ?
- HopRewriteUtils.createTranspose(V) :
V.getInput().get(0);
+ HopRewriteUtils.createTranspose(V) :
V.getInput(0);
hnew = new QuaternaryOp(hi.getName(),
DataType.MATRIX, ValueType.FP64,
OpOp4.WDIVMM, W, U, V, new
LiteralOp(-1), 0, true, false);
hnew.setBlocksize(W.getBlocksize());
@@ -2068,28 +2068,28 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
if( hi instanceof AggUnaryOp &&
((AggUnaryOp)hi).getDirection()==Direction.RowCol
&& ((AggUnaryOp)hi).getOp() == AggOp.SUM //pattern
rooted by sum()
- && hi.getInput().get(0) instanceof BinaryOp //pattern
subrooted by binary op
- && hi.getInput().get(0).getDim2() > 1 ) //not
applied for vector-vector mult
+ && hi.getInput(0) instanceof BinaryOp //pattern
subrooted by binary op
+ && hi.getInput(0).getDim2() > 1 ) //not applied
for vector-vector mult
{
- BinaryOp bop = (BinaryOp) hi.getInput().get(0);
- Hop left = bop.getInput().get(0);
- Hop right = bop.getInput().get(1);
+ BinaryOp bop = (BinaryOp) hi.getInput(0);
+ Hop left = bop.getInput(0);
+ Hop right = bop.getInput(1);
//Pattern 1) sum( X * log(U %*% t(V)))
if( bop.getOp()==OpOp2.MULT &&
left.getDataType()==DataType.MATRIX
&& HopRewriteUtils.isEqualSize(left, right)
//prevent mb
&& HopRewriteUtils.isUnary(right, OpOp1.LOG)
- && right.getInput().get(0) instanceof
AggBinaryOp //ba gurantees matrices
- &&
HopRewriteUtils.isSingleBlock(right.getInput().get(0).getInput().get(0),true))
//BLOCKSIZE CONSTRAINT
+ && right.getInput(0) instanceof AggBinaryOp
//ba gurantees matrices
+ &&
HopRewriteUtils.isSingleBlock(right.getInput(0).getInput(0),true)) //BLOCKSIZE
CONSTRAINT
{
Hop X = left;
- Hop U =
right.getInput().get(0).getInput().get(0);
- Hop V =
right.getInput().get(0).getInput().get(1);
+ Hop U = right.getInput(0).getInput(0);
+ Hop V = right.getInput(0).getInput(1);
if( !HopRewriteUtils.isTransposeOperation(V) )
V = HopRewriteUtils.createTranspose(V);
else
- V = V.getInput().get(0);
+ V = V.getInput(0);
hnew = new QuaternaryOp(hi.getName(),
DataType.SCALAR, ValueType.FP64, OpOp4.WCEMM, X, U, V,
new LiteralOp(0.0), 0, false,
false);
@@ -2104,21 +2104,21 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
&& bop.getOp()==OpOp2.MULT &&
left.getDataType()==DataType.MATRIX
&& HopRewriteUtils.isEqualSize(left, right)
&& HopRewriteUtils.isUnary(right, OpOp1.LOG)
- &&
HopRewriteUtils.isBinary(right.getInput().get(0), OpOp2.PLUS)
- && right.getInput().get(0).getInput().get(0)
instanceof AggBinaryOp
- && right.getInput().get(0).getInput().get(1)
instanceof LiteralOp
- &&
right.getInput().get(0).getInput().get(1).getDataType() == DataType.SCALAR
- &&
HopRewriteUtils.isSingleBlock(right.getInput().get(0).getInput().get(0).getInput().get(0),true))
+ && HopRewriteUtils.isBinary(right.getInput(0),
OpOp2.PLUS)
+ && right.getInput(0).getInput(0) instanceof
AggBinaryOp
+ && right.getInput(0).getInput(1) instanceof
LiteralOp
+ && right.getInput(0).getInput(1).getDataType()
== DataType.SCALAR
+ &&
HopRewriteUtils.isSingleBlock(right.getInput(0).getInput(0).getInput(0),true))
{
Hop X = left;
- Hop U =
right.getInput().get(0).getInput().get(0).getInput().get(0);
- Hop V =
right.getInput().get(0).getInput().get(0).getInput().get(1);
- Hop eps =
right.getInput().get(0).getInput().get(1);
+ Hop U =
right.getInput(0).getInput(0).getInput(0);
+ Hop V =
right.getInput(0).getInput(0).getInput(1);
+ Hop eps = right.getInput(0).getInput(1);
if( !HopRewriteUtils.isTransposeOperation(V) )
V = HopRewriteUtils.createTranspose(V);
else
- V = V.getInput().get(0);
+ V = V.getInput(0);
hnew = new QuaternaryOp(hi.getName(),
DataType.SCALAR, ValueType.FP64,
OpOp4.WCEMM, X, U, V, eps, 1,
false, false); // 1 => BASIC_EPS
@@ -2143,25 +2143,25 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//Pattern 1) (W*uop(U%*%t(V)))
if( hi instanceof BinaryOp &&
HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(),LOOKUP_VALID_WDIVMM_BINARY)
- && HopRewriteUtils.isEqualSize(hi.getInput().get(0),
hi.getInput().get(1)) //prevent mv
+ && HopRewriteUtils.isEqualSize(hi.getInput(0),
hi.getInput(1)) //prevent mv
&& hi.getDim2() > 1 //not applied for vector-vector mult
- && hi.getInput().get(0).getDataType() ==
DataType.MATRIX
- && hi.getInput().get(0).getDim2() >
hi.getInput().get(0).getBlocksize()
- && hi.getInput().get(1) instanceof UnaryOp
- &&
HopRewriteUtils.isValidOp(((UnaryOp)hi.getInput().get(1)).getOp(),
LOOKUP_VALID_WUMM_UNARY)
- && hi.getInput().get(1).getInput().get(0) instanceof
AggBinaryOp
- &&
HopRewriteUtils.isSingleBlock(hi.getInput().get(1).getInput().get(0).getInput().get(0),true)
) //BLOCKSIZE CONSTRAINT
+ && hi.getInput(0).getDataType() == DataType.MATRIX
+ && hi.getInput(0).getDim2() >
hi.getInput(0).getBlocksize()
+ && hi.getInput(1) instanceof UnaryOp
+ &&
HopRewriteUtils.isValidOp(((UnaryOp)hi.getInput(1)).getOp(),
LOOKUP_VALID_WUMM_UNARY)
+ && hi.getInput(1).getInput(0) instanceof AggBinaryOp
+ &&
HopRewriteUtils.isSingleBlock(hi.getInput(1).getInput(0).getInput(0),true) )
//BLOCKSIZE CONSTRAINT
{
- Hop W = hi.getInput().get(0);
- Hop U =
hi.getInput().get(1).getInput().get(0).getInput().get(0);
- Hop V =
hi.getInput().get(1).getInput().get(0).getInput().get(1);
+ Hop W = hi.getInput(0);
+ Hop U = hi.getInput(1).getInput(0).getInput(0);
+ Hop V = hi.getInput(1).getInput(0).getInput(1);
boolean mult = ((BinaryOp)hi).getOp()==OpOp2.MULT;
- OpOp1 op = ((UnaryOp)hi.getInput().get(1)).getOp();
+ OpOp1 op = ((UnaryOp)hi.getInput(1)).getOp();
if( !HopRewriteUtils.isTransposeOperation(V) )
V = HopRewriteUtils.createTranspose(V);
else
- V = V.getInput().get(0);
+ V = V.getInput(0);
hnew = new QuaternaryOp(hi.getName(), DataType.MATRIX,
ValueType.FP64,
OpOp4.WUMM, W, U, V, mult, op, null);
@@ -2175,33 +2175,33 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//Pattern 2.7) (W*(U%*%t(V))*2 or 2*(W*(U%*%t(V))
if( !appliedPattern
&& hi instanceof BinaryOp &&
HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(), OpOp2.MULT)
- &&
(HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0), 2)
- ||
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 2)))
+ &&
(HopRewriteUtils.isLiteralOfValue(hi.getInput(0), 2)
+ ||
HopRewriteUtils.isLiteralOfValue(hi.getInput(1), 2)))
{
final Hop nl; // non-literal
- if( hi.getInput().get(0) instanceof LiteralOp ) {
- nl = hi.getInput().get(1);
+ if( hi.getInput(0) instanceof LiteralOp ) {
+ nl = hi.getInput(1);
} else {
- nl = hi.getInput().get(0);
+ nl = hi.getInput(0);
}
if ( HopRewriteUtils.isBinary(nl, OpOp2.MULT)
&& nl.getParent().size()==1 // ensure
no foreign parents
- &&
HopRewriteUtils.isEqualSize(nl.getInput().get(0), nl.getInput().get(1))
//prevent mv
+ &&
HopRewriteUtils.isEqualSize(nl.getInput(0), nl.getInput(1)) //prevent mv
&& nl.getDim2() > 1 //not applied for
vector-vector mult
- && nl.getInput().get(0).getDataType()
== DataType.MATRIX
- && nl.getInput().get(0).getDim2() >
nl.getInput().get(0).getBlocksize()
- &&
HopRewriteUtils.isOuterProductLikeMM(nl.getInput().get(1))
- && (((AggBinaryOp)
nl.getInput().get(1)).checkMapMultChain() == ChainType.NONE ||
nl.getInput().get(1).getInput().get(1).getDim2() > 1) //no mmchain
- &&
HopRewriteUtils.isSingleBlock(nl.getInput().get(1).getInput().get(0),true) )
+ && nl.getInput(0).getDataType() ==
DataType.MATRIX
+ && nl.getInput(0).getDim2() >
nl.getInput(0).getBlocksize()
+ &&
HopRewriteUtils.isOuterProductLikeMM(nl.getInput(1))
+ && (((AggBinaryOp)
nl.getInput(1)).checkMapMultChain() == ChainType.NONE ||
nl.getInput(1).getInput(1).getDim2() > 1) //no mmchain
+ &&
HopRewriteUtils.isSingleBlock(nl.getInput(1).getInput(0),true) )
{
- final Hop W = nl.getInput().get(0);
- final Hop U =
nl.getInput().get(1).getInput().get(0);
- Hop V = nl.getInput().get(1).getInput().get(1);
+ final Hop W = nl.getInput(0);
+ final Hop U = nl.getInput(1).getInput(0);
+ Hop V = nl.getInput(1).getInput(1);
if( !HopRewriteUtils.isTransposeOperation(V) )
V = HopRewriteUtils.createTranspose(V);
else
- V = V.getInput().get(0);
+ V = V.getInput(0);
hnew = new QuaternaryOp(hi.getName(),
DataType.MATRIX, ValueType.FP64,
OpOp4.WUMM, W, U, V, true,
null, OpOp2.MULT);
@@ -2216,46 +2216,46 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//Pattern 2) (W*sop(U%*%t(V),c)) for known sop translating to
unary ops
if( !appliedPattern
&& hi instanceof BinaryOp &&
HopRewriteUtils.isValidOp(((BinaryOp)hi).getOp(),LOOKUP_VALID_WDIVMM_BINARY)
- && HopRewriteUtils.isEqualSize(hi.getInput().get(0),
hi.getInput().get(1)) //prevent mv
+ && HopRewriteUtils.isEqualSize(hi.getInput(0),
hi.getInput(1)) //prevent mv
&& hi.getDim2() > 1 //not applied for vector-vector mult
- && hi.getInput().get(0).getDataType() == DataType.MATRIX
- && hi.getInput().get(0).getDim2() >
hi.getInput().get(0).getBlocksize()
- && hi.getInput().get(1) instanceof BinaryOp
- &&
HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput().get(1)).getOp(),
LOOKUP_VALID_WUMM_BINARY) )
+ && hi.getInput(0).getDataType() == DataType.MATRIX
+ && hi.getInput(0).getDim2() >
hi.getInput(0).getBlocksize()
+ && hi.getInput(1) instanceof BinaryOp
+ &&
HopRewriteUtils.isValidOp(((BinaryOp)hi.getInput(1)).getOp(),
LOOKUP_VALID_WUMM_BINARY) )
{
- Hop left = hi.getInput().get(1).getInput().get(0);
- Hop right = hi.getInput().get(1).getInput().get(1);
+ Hop left = hi.getInput(1).getInput(0);
+ Hop right = hi.getInput(1).getInput(1);
Hop abop = null;
//pattern 2a) matrix-scalar operations
if( right.getDataType()==DataType.SCALAR && right
instanceof LiteralOp
&&
HopRewriteUtils.getDoubleValue((LiteralOp)right)==2 //pow2, mult2
&& left instanceof AggBinaryOp
- &&
HopRewriteUtils.isSingleBlock(left.getInput().get(0),true) ) //BLOCKSIZE
CONSTRAINT
+ &&
HopRewriteUtils.isSingleBlock(left.getInput(0),true) ) //BLOCKSIZE CONSTRAINT
{
abop = left;
}
//pattern 2b) scalar-matrix operations
else if( left.getDataType()==DataType.SCALAR && left
instanceof LiteralOp
&&
HopRewriteUtils.getDoubleValue((LiteralOp)left)==2 //mult2
- && ((BinaryOp)hi.getInput().get(1)).getOp() ==
OpOp2.MULT
+ && ((BinaryOp)hi.getInput(1)).getOp() ==
OpOp2.MULT
&& right instanceof AggBinaryOp
- &&
HopRewriteUtils.isSingleBlock(right.getInput().get(0),true) ) //BLOCKSIZE
CONSTRAINT
+ &&
HopRewriteUtils.isSingleBlock(right.getInput(0),true) ) //BLOCKSIZE CONSTRAINT
{
abop = right;
}
if( abop != null ) {
- Hop W = hi.getInput().get(0);
- Hop U = abop.getInput().get(0);
- Hop V = abop.getInput().get(1);
+ Hop W = hi.getInput(0);
+ Hop U = abop.getInput(0);
+ Hop V = abop.getInput(1);
boolean mult =
((BinaryOp)hi).getOp()==OpOp2.MULT;
- OpOp2 op =
((BinaryOp)hi.getInput().get(1)).getOp();
+ OpOp2 op = ((BinaryOp)hi.getInput(1)).getOp();
if( !HopRewriteUtils.isTransposeOperation(V) )
V = HopRewriteUtils.createTranspose(V);
else
- V = V.getInput().get(0);
+ V = V.getInput(0);
hnew = new QuaternaryOp(hi.getName(),
DataType.MATRIX, ValueType.FP64,
OpOp4.WUMM, W, U, V, mult,
null, op);
@@ -2293,40 +2293,40 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//w/o materialization of intermediates
if( hi instanceof AggUnaryOp &&
((AggUnaryOp)hi).getOp()==AggOp.SUM //sum
&& ((AggUnaryOp)hi).getDirection()==Direction.RowCol
//full aggregate
- && hi.getInput().get(0).getDim2() == 1 ) //vector (for
correctness)
+ && hi.getInput(0).getDim2() == 1 ) //vector (for
correctness)
{
Hop baLeft = null;
Hop baRight = null;
- Hop hi2 = hi.getInput().get(0); //check for ^2 w/o
multiple consumers
+ Hop hi2 = hi.getInput(0); //check for ^2 w/o multiple
consumers
//check for sum(v^2), might have been rewritten from
sum(v*v)
if( HopRewriteUtils.isBinary(hi2, OpOp2.POW)
- && hi2.getInput().get(1) instanceof LiteralOp
- &&
HopRewriteUtils.getDoubleValue((LiteralOp)hi2.getInput().get(1))==2
+ && hi2.getInput(1) instanceof LiteralOp
+ &&
HopRewriteUtils.getDoubleValue((LiteralOp)hi2.getInput(1))==2
&& hi2.getParent().size() == 1 ) //no other
consumer than sum
{
- Hop input = hi2.getInput().get(0);
+ Hop input = hi2.getInput(0);
baLeft = input;
baRight = input;
}
//check for sum(v1*v2), but prevent to rewrite
sum(v1*v2*v3) which is later compiled into a ta+* lop
else if( HopRewriteUtils.isBinary(hi2, OpOp2.MULT, 1)
//no other consumer than sum
- && hi2.getInput().get(0).getDim2()==1
&& hi2.getInput().get(1).getDim2()==1
- && hi2.getInput().get(0).isMatrix() &&
hi2.getInput().get(1).isMatrix()
- &&
!HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.MULT)
- &&
!HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.MULT)
+ && hi2.getInput(0).getDim2()==1 &&
hi2.getInput(1).getDim2()==1
+ && hi2.getInput(0).isMatrix() &&
hi2.getInput(1).isMatrix()
+ &&
!HopRewriteUtils.isBinary(hi2.getInput(0), OpOp2.MULT)
+ &&
!HopRewriteUtils.isBinary(hi2.getInput(1), OpOp2.MULT)
&& ( !ALLOW_SUM_PRODUCT_REWRITES
- || !(
HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.POW) // do not
rewrite (A^2)*B
- &&
hi2.getInput().get(0).getInput().get(1) instanceof LiteralOp // let tak+*
handle it
- &&
((LiteralOp)hi2.getInput().get(0).getInput().get(1)).getLongValue() == 2 ))
+ || !(
HopRewriteUtils.isBinary(hi2.getInput(0), OpOp2.POW) // do not rewrite
(A^2)*B
+ &&
hi2.getInput(0).getInput(1) instanceof LiteralOp // let tak+* handle it
+ &&
((LiteralOp)hi2.getInput(0).getInput(1)).getLongValue() == 2 ))
&& ( !ALLOW_SUM_PRODUCT_REWRITES
- || !(
HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.POW) // do not
rewrite B*(A^2)
- &&
hi2.getInput().get(1).getInput().get(1) instanceof LiteralOp // let tak+*
handle it
- &&
((LiteralOp)hi2.getInput().get(1).getInput().get(1)).getLongValue() == 2 ))
+ || !(
HopRewriteUtils.isBinary(hi2.getInput(1), OpOp2.POW) // do not rewrite
B*(A^2)
+ &&
hi2.getInput(1).getInput(1) instanceof LiteralOp // let tak+* handle it
+ &&
((LiteralOp)hi2.getInput(1).getInput(1)).getLongValue() == 2 ))
)
{
- baLeft = hi2.getInput().get(0);
- baRight = hi2.getInput().get(1);
+ baLeft = hi2.getInput(0);
+ baRight = hi2.getInput(1);
}
//perform actual rewrite (if necessary)
@@ -2362,14 +2362,14 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
private static Hop fuseSumSquared(Hop parent, Hop hi, int pos) {
// if SUM
if (hi instanceof AggUnaryOp && ((AggUnaryOp) hi).getOp() ==
AggOp.SUM) {
- Hop sumInput = hi.getInput().get(0);
+ Hop sumInput = hi.getInput(0);
// if input to SUM is POW(X,2), and no other consumers
of the POW(X,2) HOP
if( HopRewriteUtils.isBinary(sumInput, OpOp2.POW)
- && sumInput.getInput().get(1)
instanceof LiteralOp
- &&
HopRewriteUtils.getDoubleValue((LiteralOp) sumInput.getInput().get(1)) == 2
+ && sumInput.getInput(1) instanceof
LiteralOp
+ &&
HopRewriteUtils.getDoubleValue((LiteralOp) sumInput.getInput(1)) == 2
&& sumInput.getParent().size() == 1) {
- Hop x = sumInput.getInput().get(0);
+ Hop x = sumInput.getInput(0);
// if X is NOT a column vector
if (x.getDim2() > 1) {
@@ -2394,8 +2394,8 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
&& (((BinaryOp)hi).getOp()==OpOp2.PLUS ||
((BinaryOp)hi).getOp()==OpOp2.MINUS) )
{
BinaryOp bop = (BinaryOp) hi;
- Hop left = bop.getInput().get(0);
- Hop right = bop.getInput().get(1);
+ Hop left = bop.getInput(0);
+ Hop right = bop.getInput(1);
Hop ternop = null;
//pattern (a) X + s*Y -> X +* sY
@@ -2404,8 +2404,8 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
&& HopRewriteUtils.isEqualSize(left, right)
&& right.getParent().size() == 1 )
//single consumer s*Y
{
- Hop smid = right.getInput().get(
(right.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1);
- Hop mright = right.getInput().get(
(right.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0);
+ Hop smid = right.getInput().get(
(right.getInput(0).getDataType()==DataType.SCALAR) ? 0 : 1);
+ Hop mright = right.getInput().get(
(right.getInput(0).getDataType()==DataType.SCALAR) ? 1 : 0);
ternop = (smid instanceof LiteralOp &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)smid)==0) ?
left :
HopRewriteUtils.createTernary(left, smid, mright, OpOp3.PLUS_MULT);
LOG.debug("Applied
fuseAxpyBinaryOperationChain1. (line " +hi.getBeginLine()+")");
@@ -2416,8 +2416,8 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
&& HopRewriteUtils.isEqualSize(left, right)
&& left.getParent().size() == 1 )
//single consumer s*Y
{
- Hop smid = left.getInput().get(
(left.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1);
- Hop mright = left.getInput().get(
(left.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0);
+ Hop smid = left.getInput().get(
(left.getInput(0).getDataType()==DataType.SCALAR) ? 0 : 1);
+ Hop mright = left.getInput().get(
(left.getInput(0).getDataType()==DataType.SCALAR) ? 1 : 0);
ternop = (smid instanceof LiteralOp &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)smid)==0) ?
right :
HopRewriteUtils.createTernary(right, smid, mright, OpOp3.PLUS_MULT);
LOG.debug("Applied
fuseAxpyBinaryOperationChain2. (line " +hi.getBeginLine()+")");
@@ -2428,8 +2428,8 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
&& HopRewriteUtils.isEqualSize(left, right)
&& right.getParent().size() == 1 )
//single consumer s*Y
{
- Hop smid = right.getInput().get(
(right.getInput().get(0).getDataType()==DataType.SCALAR) ? 0 : 1);
- Hop mright = right.getInput().get(
(right.getInput().get(0).getDataType()==DataType.SCALAR) ? 1 : 0);
+ Hop smid = right.getInput().get(
(right.getInput(0).getDataType()==DataType.SCALAR) ? 0 : 1);
+ Hop mright = right.getInput().get(
(right.getInput(0).getDataType()==DataType.SCALAR) ? 1 : 0);
ternop = (smid instanceof LiteralOp &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)smid)==0) ?
left :
HopRewriteUtils.createTernary(left, smid, mright, OpOp3.MINUS_MULT);
LOG.debug("Applied
fuseAxpyBinaryOperationChain3. (line " +hi.getBeginLine()+")");
@@ -2452,8 +2452,8 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
if( hi instanceof BinaryOp ) //b(?) X Y
{
BinaryOp bop = (BinaryOp) hi;
- Hop left = hi.getInput().get(0);
- Hop right = hi.getInput().get(1);
+ Hop left = hi.getInput(0);
+ Hop right = hi.getInput(1);
if( left.getDataType()==DataType.MATRIX &&
right.getDataType()==DataType.MATRIX )
{
@@ -2532,16 +2532,16 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
{
if( HopRewriteUtils.isMatrixMultiply(hi) ) //X%*%Y
{
- Hop hileft = hi.getInput().get(0);
- Hop hiright = hi.getInput().get(1);
+ Hop hileft = hi.getInput(0);
+ Hop hiright = hi.getInput(1);
if( HopRewriteUtils.isBinary(hileft, OpOp2.MINUS)
//X=-Z
- && hileft.getInput().get(0) instanceof
LiteralOp
- &&
HopRewriteUtils.getDoubleValue((LiteralOp)hileft.getInput().get(0))==0.0
- && hi.dimsKnown() &&
hileft.getInput().get(1).dimsKnown() //size comparison
- && HopRewriteUtils.compareSize(hi,
hileft.getInput().get(1)) < 0 )
+ && hileft.getInput(0) instanceof LiteralOp
+ &&
HopRewriteUtils.getDoubleValue((LiteralOp)hileft.getInput(0))==0.0
+ && hi.dimsKnown() &&
hileft.getInput(1).dimsKnown() //size comparison
+ && HopRewriteUtils.compareSize(hi,
hileft.getInput(1)) < 0 )
{
- Hop hi2 = hileft.getInput().get(1);
+ Hop hi2 = hileft.getInput(1);
//remove link from matrixmult to minus
HopRewriteUtils.removeChildReference(hi,
hileft);
@@ -2570,12 +2570,12 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
LOG.debug("Applied reorderMinusMatrixMult (line
"+hi.getBeginLine()+").");
}
else if( HopRewriteUtils.isBinary(hiright, OpOp2.MINUS)
//X=-Z
- && hiright.getInput().get(0) instanceof
LiteralOp
- &&
HopRewriteUtils.getDoubleValue((LiteralOp)hiright.getInput().get(0))==0.0
- && hi.dimsKnown() &&
hiright.getInput().get(1).dimsKnown() //size comparison
- && HopRewriteUtils.compareSize(hi,
hiright.getInput().get(1)) < 0 )
+ && hiright.getInput(0) instanceof
LiteralOp
+ &&
HopRewriteUtils.getDoubleValue((LiteralOp)hiright.getInput(0))==0.0
+ && hi.dimsKnown() &&
hiright.getInput(1).dimsKnown() //size comparison
+ && HopRewriteUtils.compareSize(hi,
hiright.getInput(1)) < 0 )
{
- Hop hi2 = hiright.getInput().get(1);
+ Hop hi2 = hiright.getInput(1);
//remove link from matrixmult to minus
HopRewriteUtils.removeChildReference(hi,
hiright);
@@ -2617,13 +2617,13 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//-- if not dot product, not applied since aggregate removed
//-- if sum not the only consumer, not applied to prevent
redundancy
if( hi instanceof AggUnaryOp &&
((AggUnaryOp)hi).getOp()==AggOp.SUM //sum
- && hi.getInput().get(0) instanceof AggBinaryOp
//A%*%B
- && (hi.getInput().get(0).getDim1()>1 ||
hi.getInput().get(0).getDim2()>1) //not dot product
- && hi.getInput().get(0).getParent().size()==1 )
//not multiple consumers of matrix mult
+ && hi.getInput(0) instanceof AggBinaryOp
//A%*%B
+ && (hi.getInput(0).getDim1()>1 ||
hi.getInput(0).getDim2()>1) //not dot product
+ && hi.getInput(0).getParent().size()==1 ) //not
multiple consumers of matrix mult
{
- Hop hi2 = hi.getInput().get(0);
- Hop left = hi2.getInput().get(0);
- Hop right = hi2.getInput().get(1);
+ Hop hi2 = hi.getInput(0);
+ Hop left = hi2.getInput(0);
+ Hop right = hi2.getInput(1);
//remove link from parent to matrix mult
HopRewriteUtils.removeChildReference(hi, hi2);
@@ -2665,10 +2665,10 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
private static Hop simplifyScalarMVBinaryOperation(Hop hi)
{
if( hi instanceof BinaryOp &&
((BinaryOp)hi).supportsMatrixScalarOperations() //e.g., X * s
- && hi.getInput().get(0).getDataType()==DataType.MATRIX
- && hi.getInput().get(1).getDataType()==DataType.MATRIX
)
+ && hi.getInput(0).getDataType()==DataType.MATRIX
+ && hi.getInput(1).getDataType()==DataType.MATRIX )
{
- Hop right = hi.getInput().get(1);
+ Hop right = hi.getInput(1);
//X * s -> X * as.scalar(s)
if( HopRewriteUtils.isDimsKnown(right) &&
right.getDim1()==1 && right.getDim2()==1 ) //scalar right
@@ -2689,19 +2689,19 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//sum(ppred(X,0,"!=")) -> literal(nnz(X)), if nnz known
if( hi instanceof AggUnaryOp &&
((AggUnaryOp)hi).getOp()==AggOp.SUM //sum
&& ((AggUnaryOp)hi).getDirection() == Direction.RowCol
//full aggregate
- && HopRewriteUtils.isBinary(hi.getInput().get(0),
OpOp2.NOTEQUAL) )
+ && HopRewriteUtils.isBinary(hi.getInput(0),
OpOp2.NOTEQUAL) )
{
- Hop ppred = hi.getInput().get(0);
+ Hop ppred = hi.getInput(0);
Hop X = null;
- if( ppred.getInput().get(0) instanceof LiteralOp
- &&
HopRewriteUtils.getDoubleValue((LiteralOp)ppred.getInput().get(0))==0 )
+ if( ppred.getInput(0) instanceof LiteralOp
+ &&
HopRewriteUtils.getDoubleValue((LiteralOp)ppred.getInput(0))==0 )
{
- X = ppred.getInput().get(1);
+ X = ppred.getInput(1);
}
- else if( ppred.getInput().get(1) instanceof LiteralOp
- &&
HopRewriteUtils.getDoubleValue((LiteralOp)ppred.getInput().get(1))==0 )
+ else if( ppred.getInput(1) instanceof LiteralOp
+ &&
HopRewriteUtils.getDoubleValue((LiteralOp)ppred.getInput(1))==0 )
{
- X = ppred.getInput().get(0);
+ X = ppred.getInput(0);
}
//apply rewrite if known nnz
@@ -2725,16 +2725,16 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//even if the intermediate is otherwise not required, e.g.,
when part of a fused operator)
if( hi instanceof UnaryOp )
{
- if( ((UnaryOp)hi).getOp()==OpOp1.NROW &&
hi.getInput().get(0).rowsKnown() ) {
- Hop hnew = new
LiteralOp(hi.getInput().get(0).getDim1());
+ if( ((UnaryOp)hi).getOp()==OpOp1.NROW &&
hi.getInput(0).rowsKnown() ) {
+ Hop hnew = new
LiteralOp(hi.getInput(0).getDim1());
HopRewriteUtils.replaceChildReference(parent,
hi, hnew, pos, false);
HopRewriteUtils.cleanupUnreferenced(hi);
LOG.debug("Applied simplifyNrowComputation
nrow("+hi.getHopID()+") -> "
+ hnew.getName()+" (line
"+hi.getBeginLine()+").");
hi = hnew;
}
- else if( ((UnaryOp)hi).getOp()==OpOp1.NCOL &&
hi.getInput().get(0).colsKnown() ) {
- Hop hnew = new
LiteralOp(hi.getInput().get(0).getDim2());
+ else if( ((UnaryOp)hi).getOp()==OpOp1.NCOL &&
hi.getInput(0).colsKnown() ) {
+ Hop hnew = new
LiteralOp(hi.getInput(0).getDim2());
HopRewriteUtils.replaceChildReference(parent,
hi, hnew, pos, false);
HopRewriteUtils.cleanupUnreferenced(hi);
LOG.debug("Applied simplifyNcolComputation
ncol("+hi.getHopID()+") -> "
@@ -2751,14 +2751,14 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//pattern: table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v,
max=m, dir=row, ignore=false, cast=true)
//note: this rewrite supports both left/right sequence
if( hi instanceof TernaryOp && hi.getInput().size()==6
//table without weights
- &&
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(2), 1) ) //i.e., weight of 1
+ && HopRewriteUtils.isLiteralOfValue(hi.getInput(2), 1)
) //i.e., weight of 1
{
- Hop first = hi.getInput().get(0);
- Hop second = hi.getInput().get(1);
+ Hop first = hi.getInput(0);
+ Hop second = hi.getInput(1);
//pattern a: table(seq(1,nrow(v)), v, nrow(v), m, 1)
if( HopRewriteUtils.isBasic1NSequence(first, second,
true)
- &&
HopRewriteUtils.isSizeExpressionOf(hi.getInput().get(3), second, true) )
+ &&
HopRewriteUtils.isSizeExpressionOf(hi.getInput(3), second, true) )
{
//setup input parameter hops
LinkedHashMap<String,Hop> args = new
LinkedHashMap<>();
@@ -2784,7 +2784,7 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
//setup input parameter hops
LinkedHashMap<String,Hop> args = new
LinkedHashMap<>();
args.put("target", first);
- args.put("max", hi.getInput().get(3));
+ args.put("max", hi.getInput(3));
args.put("dir", new LiteralOp("rows"));
args.put("ignore", new LiteralOp(false));
args.put("cast", new LiteralOp(true));
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index d3c8757c13..72aa05ad16 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -214,8 +214,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
&&
((BinaryOp)hi).supportsMatrixScalarOperations() )
{
BinaryOp bop = (BinaryOp)hi;
- Hop left = bop.getInput().get(0);
- Hop right = bop.getInput().get(1);
+ Hop left = bop.getInput(0);
+ Hop right = bop.getInput(1);
//NOTE: these rewrites of binary cell operations need
to be aware that right is
//potentially a vector but the result is of the size of
left
@@ -279,8 +279,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
if( hi instanceof BinaryOp )
{
BinaryOp bop = (BinaryOp)hi;
- Hop left = bop.getInput().get(0);
- Hop right = bop.getInput().get(1);
+ Hop left = bop.getInput(0);
+ Hop right = bop.getInput(1);
//X/1 or X*1 -> X
if( left.getDataType()==DataType.MATRIX
&& right instanceof LiteralOp &&
right.getValueType().isNumeric()
@@ -359,13 +359,13 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
public static Hop simplifyConstantConjunction(Hop parent, Hop hi, int
pos) {
if (hi instanceof BinaryOp) {
BinaryOp bop = (BinaryOp) hi;
- Hop left = hi.getInput().get(0);
- Hop right = hi.getInput().get(1);
+ Hop left = hi.getInput(0);
+ Hop right = hi.getInput(1);
// Patterns: a & !a --> FALSE / !a & a --> FALSE
if (bop.getOp() == OpOp2.AND
- && ((HopRewriteUtils.isUnary(right, OpOp1.NOT)
&& left == right.getInput().get(0))
- || (HopRewriteUtils.isUnary(left, OpOp1.NOT) &&
left.getInput().get(0) == right)))
+ && ((HopRewriteUtils.isUnary(right, OpOp1.NOT)
&& left == right.getInput(0))
+ || (HopRewriteUtils.isUnary(left, OpOp1.NOT) &&
left.getInput(0) == right)))
{
LiteralOp falseOp = new LiteralOp(false);
@@ -380,8 +380,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
}
// Pattern: a | !a --> TRUE
else if (bop.getOp() == OpOp2.OR
- && ((HopRewriteUtils.isUnary(right, OpOp1.NOT)
&& left == right.getInput().get(0))
- || (HopRewriteUtils.isUnary(left, OpOp1.NOT) &&
left.getInput().get(0) == right)))
+ && ((HopRewriteUtils.isUnary(right, OpOp1.NOT)
&& left == right.getInput(0))
+ || (HopRewriteUtils.isUnary(left, OpOp1.NOT) &&
left.getInput(0) == right)))
{
LiteralOp trueOp = new LiteralOp(true);
@@ -415,8 +415,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
if( hi instanceof BinaryOp )
{
BinaryOp bop = (BinaryOp)hi;
- Hop left = bop.getInput().get(0);
- Hop right = bop.getInput().get(1);
+ Hop left = bop.getInput(0);
+ Hop right = bop.getInput(1);
//NOTE: rewrite not applied if more than one datagen
consumer because this would lead to
//the creation of multiple datagen ops and thus
potentially different results if seed not specified)
@@ -542,8 +542,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
if( hi instanceof BinaryOp )
{
BinaryOp bop = (BinaryOp)hi;
- Hop left = bop.getInput().get(0);
- Hop right = bop.getInput().get(1);
+ Hop left = bop.getInput(0);
+ Hop right = bop.getInput(1);
if( right instanceof DataGenOp &&
((DataGenOp)right).getOp()==OpOpDG.RAND &&
left instanceof LiteralOp &&
((LiteralOp)left).getDoubleValue()==0.0 )
@@ -649,8 +649,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
if( hi instanceof BinaryOp )
{
BinaryOp bop = (BinaryOp)hi;
- Hop left = hi.getInput().get(0);
- Hop right = hi.getInput().get(1);
+ Hop left = hi.getInput(0);
+ Hop right = hi.getInput(1);
//patterns: X+X -> X*2, X*X -> X^2,
if( left == right &&
left.getDataType()==DataType.MATRIX )
@@ -676,13 +676,13 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
else if( bop.getOp() == OpOp2.MINUS
&& HopRewriteUtils.isBinary(left,
OpOp2.GREATER)
&& HopRewriteUtils.isBinary(right,
OpOp2.LESS)
- && left.getInput().get(0) ==
right.getInput().get(0)
- && left.getInput().get(1) instanceof
LiteralOp
- &&
HopRewriteUtils.getDoubleValue((LiteralOp)left.getInput().get(1))==0
- && right.getInput().get(1) instanceof
LiteralOp
- &&
HopRewriteUtils.getDoubleValue((LiteralOp)right.getInput().get(1))==0 )
+ && left.getInput(0) == right.getInput(0)
+ && left.getInput(1) instanceof LiteralOp
+ &&
HopRewriteUtils.getDoubleValue((LiteralOp)left.getInput(1))==0
+ && right.getInput(1) instanceof
LiteralOp
+ &&
HopRewriteUtils.getDoubleValue((LiteralOp)right.getInput(1))==0 )
{
- UnaryOp uop =
HopRewriteUtils.createUnary(left.getInput().get(0), OpOp1.SIGN);
+ UnaryOp uop =
HopRewriteUtils.createUnary(left.getInput(0), OpOp1.SIGN);
HopRewriteUtils.replaceChildReference(parent,
hi, uop, pos);
HopRewriteUtils.cleanupUnreferenced(hi, left,
right);
hi = uop;
@@ -708,8 +708,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
if( hi instanceof BinaryOp )
{
BinaryOp bop = (BinaryOp)hi;
- Hop left = hi.getInput().get(0);
- Hop right = hi.getInput().get(1);
+ Hop left = hi.getInput(0);
+ Hop right = hi.getInput(1);
//pattern: (eps + U%*%V) -> (U%*%V+eps)
if( left.getDataType().isScalar() && right instanceof
AggBinaryOp
@@ -754,14 +754,14 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
private static Hop removeUnnecessaryCTable( Hop parent, Hop hi, int pos
) {
if ( HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM,
Direction.RowCol)
- &&
HopRewriteUtils.isTernary(hi.getInput().get(0), OpOp3.CTABLE)
- &&
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(0).getInput().get(2), 1.0))
+ && HopRewriteUtils.isTernary(hi.getInput(0),
OpOp3.CTABLE)
+ &&
HopRewriteUtils.isLiteralOfValue(hi.getInput(0).getInput(2), 1.0))
{
- Hop matrixInput =
hi.getInput().get(0).getInput().get(0);
+ Hop matrixInput = hi.getInput(0).getInput(0);
OpOp1 opcode = matrixInput.getDim2() == 1 ? OpOp1.NROW
: OpOp1.LENGTH;
Hop newOpLength = new UnaryOp("tmp", DataType.SCALAR,
ValueType.INT64, opcode, matrixInput);
HopRewriteUtils.replaceChildReference(parent, hi,
newOpLength, pos);
- HopRewriteUtils.cleanupUnreferenced(hi,
hi.getInput().get(0));
+ HopRewriteUtils.cleanupUnreferenced(hi, hi.getInput(0));
hi = newOpLength;
}
return hi;
@@ -780,16 +780,16 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
private static Hop simplifyReverseOperation( Hop parent, Hop hi, int
pos )
{
if( hi instanceof AggBinaryOp
- && hi.getInput().get(0) instanceof TernaryOp )
+ && hi.getInput(0) instanceof TernaryOp )
{
- TernaryOp top = (TernaryOp) hi.getInput().get(0);
+ TernaryOp top = (TernaryOp) hi.getInput(0);
if( top.getOp()==OpOp3.CTABLE
- &&
HopRewriteUtils.isBasic1NSequence(top.getInput().get(0))
- &&
HopRewriteUtils.isBasicN1Sequence(top.getInput().get(1))
- &&
top.getInput().get(0).getDim1()==top.getInput().get(1).getDim1())
+ &&
HopRewriteUtils.isBasic1NSequence(top.getInput(0))
+ &&
HopRewriteUtils.isBasicN1Sequence(top.getInput(1))
+ &&
top.getInput(0).getDim1()==top.getInput(1).getDim1())
{
- ReorgOp rop =
HopRewriteUtils.createReorg(hi.getInput().get(1), ReOrgOp.REV);
+ ReorgOp rop =
HopRewriteUtils.createReorg(hi.getInput(1), ReOrgOp.REV);
HopRewriteUtils.replaceChildReference(parent,
hi, rop, pos);
HopRewriteUtils.cleanupUnreferenced(hi, top);
hi = rop;
@@ -828,14 +828,14 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
//pattern: 1-(X*Y) --> X 1-* Y (avoid intermediate)
if( HopRewriteUtils.isBinary(hi, OpOp2.MINUS)
&& hi.getDataType() == DataType.MATRIX
- && hi.getInput().get(0) instanceof LiteralOp
- &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)hi.getInput().get(0))==1
- &&
HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT)
- && hi.getInput().get(1).getParent().size() == 1
) //single consumer
+ && hi.getInput(0) instanceof LiteralOp
+ &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)hi.getInput(0))==1
+ && HopRewriteUtils.isBinary(hi.getInput(1),
OpOp2.MULT)
+ && hi.getInput(1).getParent().size() == 1 )
//single consumer
{
BinaryOp bop = (BinaryOp)hi;
- Hop left = hi.getInput().get(1).getInput().get(0);
- Hop right = hi.getInput().get(1).getInput().get(1);
+ Hop left = hi.getInput(1).getInput(0);
+ Hop right = hi.getInput(1).getInput(1);
//set new binaryop type and rewire inputs
bop.setOp(OpOp2.MINUS1_MULT);
@@ -864,8 +864,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
if( hi instanceof BinaryOp )
{
BinaryOp bop = (BinaryOp)hi;
- Hop left = bop.getInput().get(0);
- Hop right = bop.getInput().get(1);
+ Hop left = bop.getInput(0);
+ Hop right = bop.getInput(1);
//(X+Y*X) -> (1+Y)*X, (Y*X+X) -> (Y+1)*X
//(X-Y*X) -> (1-Y)*X, (Y*X-X) -> (Y-1)*X
@@ -876,8 +876,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
Hop X = null; Hop Y = null;
if( HopRewriteUtils.isBinary(left, OpOp2.MULT)
) //(Y*X-X) -> (Y-1)*X
{
- Hop leftC1 = left.getInput().get(0);
- Hop leftC2 = left.getInput().get(1);
+ Hop leftC1 = left.getInput(0);
+ Hop leftC2 = left.getInput(1);
if(
leftC1.getDataType()==DataType.MATRIX && leftC2.getDataType()==DataType.MATRIX
&&
(right == leftC1 ||
right == leftC2) && leftC1 !=leftC2 ){ //any mult order
@@ -899,8 +899,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
if( !applied && HopRewriteUtils.isBinary(right,
OpOp2.MULT) ) //(X-Y*X) -> (1-Y)*X
{
- Hop rightC1 = right.getInput().get(0);
- Hop rightC2 = right.getInput().get(1);
+ Hop rightC1 = right.getInput(0);
+ Hop rightC2 = right.getInput(1);
if(
rightC1.getDataType()==DataType.MATRIX &&
rightC2.getDataType()==DataType.MATRIX &&
(left == rightC1 ||
left == rightC2) && rightC1 !=rightC2 ){ //any mult order
X = left;
@@ -941,8 +941,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
if( hi instanceof BinaryOp && parent instanceof AggBinaryOp )
{
BinaryOp bop = (BinaryOp)hi;
- Hop left = bop.getInput().get(0);
- Hop right = bop.getInput().get(1);
+ Hop left = bop.getInput(0);
+ Hop right = bop.getInput(1);
OpOp2 op = bop.getOp();
if( left.getDataType()==DataType.MATRIX &&
right.getDataType()==DataType.MATRIX &&
@@ -953,8 +953,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
if( right instanceof BinaryOp )
{
BinaryOp bop2 = (BinaryOp)right;
- Hop left2 = bop2.getInput().get(0);
- Hop right2 = bop2.getInput().get(1);
+ Hop left2 = bop2.getInput(0);
+ Hop right2 = bop2.getInput(1);
OpOp2 op2 = bop2.getOp();
if( op==op2 &&
right2.getDataType()==DataType.MATRIX
@@ -976,8 +976,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
if( !applied && left instanceof BinaryOp )
{
BinaryOp bop2 = (BinaryOp)left;
- Hop left2 = bop2.getInput().get(0);
- Hop right2 = bop2.getInput().get(1);
+ Hop left2 = bop2.getInput(0);
+ Hop right2 = bop2.getInput(1);
OpOp2 op2 = bop2.getOp();
if( op==op2 &&
left2.getDataType()==DataType.MATRIX
@@ -1006,13 +1006,13 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
{
if( hi instanceof AggUnaryOp &&
((AggUnaryOp)hi).getDirection()==Direction.RowCol
&& ((AggUnaryOp)hi).getOp() != AggOp.TRACE //full
uagg
- && hi.getInput().get(0) instanceof ReorgOp ) //reorg
operation
+ && hi.getInput(0) instanceof ReorgOp ) //reorg
operation
{
- ReorgOp rop = (ReorgOp)hi.getInput().get(0);
+ ReorgOp rop = (ReorgOp)hi.getInput(0);
if( rop.getOp().preservesValues() //valid reorg
&& rop.getParent().size()==1 ) //uagg only
reorg consumer
{
- Hop input = rop.getInput().get(0);
+ Hop input = rop.getInput(0);
HopRewriteUtils.removeAllChildReferences(hi);
HopRewriteUtils.removeAllChildReferences(rop);
HopRewriteUtils.addChildReference(hi, input);
@@ -1028,17 +1028,17 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
//min(rowMins(X)) -> min(X), min(colMins(X)) -> min(X)
//max(rowMaxs(X)) -> max(X), max(colMaxs(X)) -> max(X)
//sum(rowSums(X^2)) -> sum(X), sum(colSums(X^2)) -> sum(X)
- if( hi instanceof AggUnaryOp && hi.getInput().get(0) instanceof
AggUnaryOp
+ if( hi instanceof AggUnaryOp && hi.getInput(0) instanceof
AggUnaryOp
&&
((AggUnaryOp)hi).getDirection()==Direction.RowCol
- && hi.getInput().get(0).getParent().size()==1 )
+ && hi.getInput(0).getParent().size()==1 )
{
AggUnaryOp au1 = (AggUnaryOp) hi;
- AggUnaryOp au2 = (AggUnaryOp) hi.getInput().get(0);
+ AggUnaryOp au2 = (AggUnaryOp) hi.getInput(0);
if( (au1.getOp()==AggOp.SUM && (au2.getOp()==AggOp.SUM
|| au2.getOp()==AggOp.SUM_SQ))
|| (au1.getOp()==AggOp.MIN &&
au2.getOp()==AggOp.MIN)
|| (au1.getOp()==AggOp.MAX &&
au2.getOp()==AggOp.MAX) )
{
- Hop input = au2.getInput().get(0);
+ Hop input = au2.getInput(0);
HopRewriteUtils.removeAllChildReferences(au2);
HopRewriteUtils.replaceChildReference(au1, au2,
input);
if( au2.getOp() == AggOp.SUM_SQ )
@@ -1058,28 +1058,28 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
// operations; other operations like cbind/rbind will never
occur as matrix-scalar operations.
if( HopRewriteUtils.isUnary(hi, OpOp1.CAST_AS_SCALAR)
- && hi.getInput().get(0) instanceof BinaryOp
- &&
HopRewriteUtils.isBinary(hi.getInput().get(0), LOOKUP_VALID_SCALAR_BINARY))
+ && hi.getInput(0) instanceof BinaryOp
+ && HopRewriteUtils.isBinary(hi.getInput(0),
LOOKUP_VALID_SCALAR_BINARY))
{
- BinaryOp bin = (BinaryOp) hi.getInput().get(0);
+ BinaryOp bin = (BinaryOp) hi.getInput(0);
BinaryOp bout = null;
//as.scalar(X*Y) -> as.scalar(X) * as.scalar(Y)
- if( bin.getInput().get(0).getDataType()==DataType.MATRIX
- &&
bin.getInput().get(1).getDataType()==DataType.MATRIX ) {
- UnaryOp cast1 =
HopRewriteUtils.createUnary(bin.getInput().get(0), OpOp1.CAST_AS_SCALAR);
- UnaryOp cast2 =
HopRewriteUtils.createUnary(bin.getInput().get(1), OpOp1.CAST_AS_SCALAR);
+ if( bin.getInput(0).getDataType()==DataType.MATRIX
+ &&
bin.getInput(1).getDataType()==DataType.MATRIX ) {
+ UnaryOp cast1 =
HopRewriteUtils.createUnary(bin.getInput(0), OpOp1.CAST_AS_SCALAR);
+ UnaryOp cast2 =
HopRewriteUtils.createUnary(bin.getInput(1), OpOp1.CAST_AS_SCALAR);
bout = HopRewriteUtils.createBinary(cast1,
cast2, bin.getOp());
}
//as.scalar(X*s) -> as.scalar(X) * s
- else if(
bin.getInput().get(0).getDataType()==DataType.MATRIX ) {
- UnaryOp cast =
HopRewriteUtils.createUnary(bin.getInput().get(0), OpOp1.CAST_AS_SCALAR);
- bout = HopRewriteUtils.createBinary(cast,
bin.getInput().get(1), bin.getOp());
+ else if( bin.getInput(0).getDataType()==DataType.MATRIX
) {
+ UnaryOp cast =
HopRewriteUtils.createUnary(bin.getInput(0), OpOp1.CAST_AS_SCALAR);
+ bout = HopRewriteUtils.createBinary(cast,
bin.getInput(1), bin.getOp());
}
//as.scalar(s*X) -> s * as.scalar(X)
- else if (
bin.getInput().get(1).getDataType()==DataType.MATRIX ) {
- UnaryOp cast =
HopRewriteUtils.createUnary(bin.getInput().get(1), OpOp1.CAST_AS_SCALAR);
- bout =
HopRewriteUtils.createBinary(bin.getInput().get(0), cast, bin.getOp());
+ else if (
bin.getInput(1).getDataType()==DataType.MATRIX ) {
+ UnaryOp cast =
HopRewriteUtils.createUnary(bin.getInput(1), OpOp1.CAST_AS_SCALAR);
+ bout =
HopRewriteUtils.createBinary(bin.getInput(0), cast, bin.getOp());
}
if( bout != null ) {
@@ -1096,14 +1096,14 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
{
if( hi instanceof AggUnaryOp && hi.getParent().size()==1
&& (((AggUnaryOp)
hi).getDirection()==Direction.Row || ((AggUnaryOp)
hi).getDirection()==Direction.Col)
- &&
HopRewriteUtils.isTransposeOperation(hi.getInput().get(0), 1)
+ &&
HopRewriteUtils.isTransposeOperation(hi.getInput(0), 1)
&& HopRewriteUtils.isValidOp(((AggUnaryOp)
hi).getOp(), LOOKUP_VALID_ROW_COL_AGGREGATE) )
{
AggUnaryOp uagg = (AggUnaryOp) hi;
//get input rewire existing operators (remove inner
transpose)
- Hop input = uagg.getInput().get(0).getInput().get(0);
-
HopRewriteUtils.removeAllChildReferences(hi.getInput().get(0));
+ Hop input = uagg.getInput(0).getInput(0);
+
HopRewriteUtils.removeAllChildReferences(hi.getInput(0));
HopRewriteUtils.removeAllChildReferences(hi);
HopRewriteUtils.removeChildReferenceByPos(parent, hi,
pos);
@@ -1135,12 +1135,12 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
// probed at root node of b in above example
// (with support for left or right scalar operations)
if( HopRewriteUtils.isTransposeOperation(hi, 1)
- &&
HopRewriteUtils.isBinaryMatrixScalarOperation(hi.getInput().get(0))
- && hi.getInput().get(0).getParent().size()==1)
+ &&
HopRewriteUtils.isBinaryMatrixScalarOperation(hi.getInput(0))
+ && hi.getInput(0).getParent().size()==1)
{
- int Xpos =
hi.getInput().get(0).getInput().get(0).getDataType().isMatrix() ? 0 : 1;
- Hop X = hi.getInput().get(0).getInput().get(Xpos);
- BinaryOp binary = (BinaryOp) hi.getInput().get(0);
+ int Xpos =
hi.getInput(0).getInput(0).getDataType().isMatrix() ? 0 : 1;
+ Hop X = hi.getInput(0).getInput().get(Xpos);
+ BinaryOp binary = (BinaryOp) hi.getInput(0);
if(
HopRewriteUtils.containsTransposeOperation(X.getParent())
&&
!HopRewriteUtils.isValidOp(binary.getOp(), new OpOp2[]{OpOp2.MOMENT,
OpOp2.QUANTILE}))
@@ -1168,12 +1168,12 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
//pattern: sum(lamda*X) -> lamda*sum(X)
if( hi instanceof AggUnaryOp &&
((AggUnaryOp)hi).getDirection()==Direction.RowCol
&& ((AggUnaryOp)hi).getOp()==AggOp.SUM // only
one parent which is the sum
- &&
HopRewriteUtils.isBinary(hi.getInput().get(0), OpOp2.MULT, 1)
- &&
((hi.getInput().get(0).getInput().get(0).getDataType()==DataType.SCALAR &&
hi.getInput().get(0).getInput().get(1).getDataType()==DataType.MATRIX)
-
||(hi.getInput().get(0).getInput().get(0).getDataType()==DataType.MATRIX &&
hi.getInput().get(0).getInput().get(1).getDataType()==DataType.SCALAR)))
+ && HopRewriteUtils.isBinary(hi.getInput(0),
OpOp2.MULT, 1)
+ &&
((hi.getInput(0).getInput(0).getDataType()==DataType.SCALAR &&
hi.getInput(0).getInput(1).getDataType()==DataType.MATRIX)
+
||(hi.getInput(0).getInput(0).getDataType()==DataType.MATRIX &&
hi.getInput(0).getInput(1).getDataType()==DataType.SCALAR)))
{
- Hop operand1 = hi.getInput().get(0).getInput().get(0);
- Hop operand2 = hi.getInput().get(0).getInput().get(1);
+ Hop operand1 = hi.getInput(0).getInput(0);
+ Hop operand2 = hi.getInput(0).getInput(1);
//check which operand is the Scalar and which is the
matrix
Hop lamda = (operand1.getDataType()==DataType.SCALAR) ?
operand1 : operand2;
@@ -1212,15 +1212,15 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
private static Hop simplifyUnaryPPredOperation( Hop parent, Hop hi, int
pos )
{
if( hi instanceof UnaryOp && hi.getDataType()==DataType.MATRIX
//unaryop
- && hi.getInput().get(0) instanceof BinaryOp
//binaryop - ppred
- &&
((BinaryOp)hi.getInput().get(0)).isPPredOperation() )
+ && hi.getInput(0) instanceof BinaryOp
//binaryop - ppred
+ &&
((BinaryOp)hi.getInput(0)).isPPredOperation() )
{
UnaryOp uop = (UnaryOp) hi; //valid unary op
if( uop.getOp()==OpOp1.ABS || uop.getOp()==OpOp1.SIGN
|| uop.getOp()==OpOp1.CEIL ||
uop.getOp()==OpOp1.FLOOR || uop.getOp()==OpOp1.ROUND )
{
//clear link unary-binary
- Hop input = uop.getInput().get(0);
+ Hop input = uop.getInput(0);
HopRewriteUtils.replaceChildReference(parent,
hi, input, pos);
HopRewriteUtils.cleanupUnreferenced(hi);
hi = input;
@@ -1236,18 +1236,18 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
{
//e.g., t(cbind(t(A),t(B))) --> rbind(A,B), t(rbind(t(A),t(B)))
--> cbind(A,B)
if( HopRewriteUtils.isTransposeOperation(hi) //t() rooted
- && hi.getInput().get(0) instanceof BinaryOp
- &&
(((BinaryOp)hi.getInput().get(0)).getOp()==OpOp2.CBIND //append (cbind/rbind)
- ||
((BinaryOp)hi.getInput().get(0)).getOp()==OpOp2.RBIND)
- && hi.getInput().get(0).getParent().size() == 1
) //single consumer of append
+ && hi.getInput(0) instanceof BinaryOp
+ &&
(((BinaryOp)hi.getInput(0)).getOp()==OpOp2.CBIND //append (cbind/rbind)
+ ||
((BinaryOp)hi.getInput(0)).getOp()==OpOp2.RBIND)
+ && hi.getInput(0).getParent().size() == 1 )
//single consumer of append
{
- BinaryOp bop = (BinaryOp)hi.getInput().get(0);
+ BinaryOp bop = (BinaryOp)hi.getInput(0);
//both inputs transpose ops, where transpose is single
consumer
- if(
HopRewriteUtils.isTransposeOperation(bop.getInput().get(0), 1)
- &&
HopRewriteUtils.isTransposeOperation(bop.getInput().get(1), 1) )
+ if(
HopRewriteUtils.isTransposeOperation(bop.getInput(0), 1)
+ &&
HopRewriteUtils.isTransposeOperation(bop.getInput(1), 1) )
{
- Hop left =
bop.getInput().get(0).getInput().get(0);
- Hop right =
bop.getInput().get(1).getInput().get(0);
+ Hop left = bop.getInput(0).getInput(0);
+ Hop right = bop.getInput(1).getInput(0);
//create new subdag (no in-place dag update to
prevent anomalies with
//multiple consumers during rewrite process)
@@ -1279,8 +1279,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
if( hi instanceof BinaryOp )
{
BinaryOp bop = (BinaryOp)hi;
- Hop left = hi.getInput().get(0);
- Hop right = hi.getInput().get(1);
+ Hop left = hi.getInput(0);
+ Hop right = hi.getInput(1);
boolean applied = false;
//sample proportion (sprop) operator
@@ -1294,8 +1294,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
if( left instanceof BinaryOp ) //(1-X)*X
{
BinaryOp bleft = (BinaryOp)left;
- Hop left1 = bleft.getInput().get(0);
- Hop left2 = bleft.getInput().get(1);
+ Hop left1 = bleft.getInput(0);
+ Hop left2 = bleft.getInput(1);
if( left1 instanceof LiteralOp &&
HopRewriteUtils.getDoubleValue((LiteralOp)left1)==1 &&
@@ -1313,8 +1313,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
if( !applied && right instanceof BinaryOp )
//X*(1-X)
{
BinaryOp bright = (BinaryOp)right;
- Hop right1 = bright.getInput().get(0);
- Hop right2 = bright.getInput().get(1);
+ Hop right1 = bright.getInput(0);
+ Hop right2 = bright.getInput(1);
if( right1 instanceof LiteralOp &&
HopRewriteUtils.getDoubleValue((LiteralOp)right1)==1 &&
@@ -1340,14 +1340,14 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
//i.e., we still fuse but leave the
intermediate for the other consumers
BinaryOp bop2 = (BinaryOp)right;
- Hop left2 = bop2.getInput().get(0);
- Hop right2 = bop2.getInput().get(1);
+ Hop left2 = bop2.getInput(0);
+ Hop right2 = bop2.getInput(1);
if( bop2.getOp() == OpOp2.PLUS &&
left2.getDataType()==DataType.SCALAR && right2.getDataType()==DataType.MATRIX
&& left2 instanceof LiteralOp
&& HopRewriteUtils.getDoubleValue((LiteralOp)left2)==1 && right2 instanceof
UnaryOp)
{
UnaryOp uop = (UnaryOp) right2;
- Hop uopin = uop.getInput().get(0);
+ Hop uopin = uop.getInput(0);
if( uop.getOp()==OpOp1.EXP )
{
@@ -1356,8 +1356,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
//Pattern 1: (1/(1 + exp(-X))
if(
HopRewriteUtils.isBinary(uopin, OpOp2.MINUS) ) {
BinaryOp bop3 =
(BinaryOp) uopin;
- Hop left3 =
bop3.getInput().get(0);
- Hop right3 =
bop3.getInput().get(1);
+ Hop left3 =
bop3.getInput(0);
+ Hop right3 =
bop3.getInput(1);
if( left3 instanceof
LiteralOp && HopRewriteUtils.getDoubleValue((LiteralOp)left3)==0 )
unary =
HopRewriteUtils.createUnary(right3, OpOp1.SIGMOID);
@@ -1390,8 +1390,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
if( left instanceof BinaryOp ) //(X>0)*X
{
BinaryOp bleft = (BinaryOp)left;
- Hop left1 = bleft.getInput().get(0);
- Hop left2 = bleft.getInput().get(1);
+ Hop left1 = bleft.getInput(0);
+ Hop left2 = bleft.getInput(1);
if( left2 instanceof LiteralOp &&
HopRewriteUtils.getDoubleValue((LiteralOp)left2)==0 &&
@@ -1409,8 +1409,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
if( !applied && right instanceof BinaryOp )
//X*(X>0)
{
BinaryOp bright = (BinaryOp)right;
- Hop right1 = bright.getInput().get(0);
- Hop right2 = bright.getInput().get(1);
+ Hop right1 = bright.getInput(0);
+ Hop right2 = bright.getInput(1);
if( right2 instanceof LiteralOp &&
HopRewriteUtils.getDoubleValue((LiteralOp)right2)==0 &&
@@ -1435,11 +1435,11 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
{
if( hi instanceof AggUnaryOp &&
((AggUnaryOp)hi).getOp()==AggOp.TRACE ) //trace()
{
- Hop hi2 = hi.getInput().get(0);
+ Hop hi2 = hi.getInput(0);
if( HopRewriteUtils.isMatrixMultiply(hi2) ) //X%*%Y
{
- Hop left = hi2.getInput().get(0);
- Hop right = hi2.getInput().get(1);
+ Hop left = hi2.getInput(0);
+ Hop right = hi2.getInput(1);
//create new operators (incl refresh size
inside for transpose)
ReorgOp trans =
HopRewriteUtils.createTranspose(right);
@@ -1464,14 +1464,14 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
if( hi instanceof IndexingOp
&& ((IndexingOp)hi).isRowLowerEqualsUpper()
&& ((IndexingOp)hi).isColLowerEqualsUpper()
- && hi.getInput().get(0).getParent().size()==1
//rix is single mm consumer
- &&
HopRewriteUtils.isMatrixMultiply(hi.getInput().get(0)) )
+ && hi.getInput(0).getParent().size()==1 //rix
is single mm consumer
+ &&
HopRewriteUtils.isMatrixMultiply(hi.getInput(0)) )
{
- Hop mm = hi.getInput().get(0);
- Hop X = mm.getInput().get(0);
- Hop Y = mm.getInput().get(1);
- Hop rowExpr = hi.getInput().get(1); //rl==ru
- Hop colExpr = hi.getInput().get(3); //cl==cu
+ Hop mm = hi.getInput(0);
+ Hop X = mm.getInput(0);
+ Hop Y = mm.getInput(1);
+ Hop rowExpr = hi.getInput(1); //rl==ru
+ Hop colExpr = hi.getInput(3); //cl==cu
HopRewriteUtils.removeAllChildReferences(mm);
@@ -1520,7 +1520,7 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
&& hi.getInput(0).isMatrix() //no frame support yet
&& !HopRewriteUtils.isData(parent,
OpOpData.TRANSIENTWRITE))
{
- Hop hi2 = hi.getInput().get(0);
+ Hop hi2 = hi.getInput(0);
hi2.setDataType(DataType.SCALAR);
hi2.setDim1(0); hi2.setDim2(0);
HopRewriteUtils.replaceChildReference(parent, hi, hi2,
pos);
@@ -1537,13 +1537,13 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
//order(matrix(7), indexreturn=TRUE) -> seq(1,nrow(X),1)
if( hi instanceof ReorgOp &&
((ReorgOp)hi).getOp()==ReOrgOp.SORT ) //order
{
- Hop hi2 = hi.getInput().get(0);
+ Hop hi2 = hi.getInput(0);
if( hi2 instanceof DataGenOp &&
((DataGenOp)hi2).getOp()==OpOpDG.RAND
&& ((DataGenOp)hi2).hasConstantValue()
- && hi.getInput().get(3) instanceof
LiteralOp ) //known indexreturn
+ && hi.getInput(3) instanceof LiteralOp
) //known indexreturn
{
- if(
HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(3)) )
+ if(
HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput(3)) )
{
//order(matrix(7), indexreturn=TRUE) ->
seq(1,nrow(X),1)
Hop seq =
HopRewriteUtils.createSeqDataGenOp(hi2);
@@ -1575,20 +1575,20 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
//order(seq(2,N+1,1), indexreturn=TRUE) ->
seq(1,N,1)/seq(N,1,-1)
if( hi instanceof ReorgOp &&
((ReorgOp)hi).getOp()==ReOrgOp.SORT ) //order
{
- Hop hi2 = hi.getInput().get(0);
+ Hop hi2 = hi.getInput(0);
if( hi2 instanceof DataGenOp &&
((DataGenOp)hi2).getOp()==OpOpDG.SEQ )
{
Hop incr =
hi2.getInput().get(((DataGenOp)hi2).getParamIndex(Statement.SEQ_INCR));
//check for known ascending ordering and known
indexreturn
if( incr instanceof LiteralOp &&
HopRewriteUtils.getDoubleValue((LiteralOp)incr)==1
- && hi.getInput().get(2)
instanceof LiteralOp //decreasing
- && hi.getInput().get(3)
instanceof LiteralOp ) //indexreturn
+ && hi.getInput(2) instanceof
LiteralOp //decreasing
+ && hi.getInput(3) instanceof
LiteralOp ) //indexreturn
{
- if(
HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(3)) ) //IXRET,
ASC/DESC
+ if(
HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput(3)) ) //IXRET, ASC/DESC
{
//order(seq(2,N+1,1),
indexreturn=TRUE) -> seq(1,N,1)/seq(N,1,-1)
- boolean desc =
HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(2));
+ boolean desc =
HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput(2));
Hop seq =
HopRewriteUtils.createSeqDataGenOp(hi2, !desc);
seq.refreshSizeInformation();
HopRewriteUtils.replaceChildReference(parent, hi, seq, pos);
@@ -1597,7 +1597,7 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
LOG.debug("Applied
simplifyOrderedSort1.");
}
- else if(
!HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(2)) ) //DATA, ASC
+ else if(
!HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput(2)) ) //DATA, ASC
{
//order(seq(2,N+1,1),
indexreturn=FALSE) -> seq(2,N+1,1)
HopRewriteUtils.replaceChildReference(parent, hi, hi2, pos);
@@ -1617,28 +1617,28 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
{
//order(order(X,2),1) -> order(X, (12)),
if( HopRewriteUtils.isReorg(hi, ReOrgOp.SORT)
- && hi.getInput().get(1) instanceof LiteralOp
//scalar by
- && hi.getInput().get(2) instanceof LiteralOp
//scalar desc
- &&
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(3), false) ) //not ixret
+ && hi.getInput(1) instanceof LiteralOp //scalar
by
+ && hi.getInput(2) instanceof LiteralOp //scalar
desc
+ &&
HopRewriteUtils.isLiteralOfValue(hi.getInput(3), false) ) //not ixret
{
- LiteralOp by = (LiteralOp) hi.getInput().get(1);
- boolean desc =
HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput().get(2));
+ LiteralOp by = (LiteralOp) hi.getInput(1);
+ boolean desc =
HopRewriteUtils.getBooleanValue((LiteralOp)hi.getInput(2));
//find chain of order operations with same desc/ixret
configuration and single consumers
Set<String> probe = new HashSet<>();
ArrayList<LiteralOp> byList = new ArrayList<>();
byList.add(by); probe.add(by.getStringValue());
- Hop input = hi.getInput().get(0);
+ Hop input = hi.getInput(0);
while( HopRewriteUtils.isReorg(input, ReOrgOp.SORT)
- && input.getInput().get(1) instanceof
LiteralOp //scalar by
- &&
!probe.contains(input.getInput().get(1).getName())
- &&
HopRewriteUtils.isLiteralOfValue(input.getInput().get(2), desc)
- &&
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(3), false)
+ && input.getInput(1) instanceof
LiteralOp //scalar by
+ &&
!probe.contains(input.getInput(1).getName())
+ &&
HopRewriteUtils.isLiteralOfValue(input.getInput(2), desc)
+ &&
HopRewriteUtils.isLiteralOfValue(hi.getInput(3), false)
&& input.getParent().size() == 1 )
{
- byList.add((LiteralOp)input.getInput().get(1));
- probe.add(input.getInput().get(1).getName());
- input = input.getInput().get(0);
+ byList.add((LiteralOp)input.getInput(1));
+ probe.add(input.getInput(1).getName());
+ input = input.getInput(0);
}
//merge order chain if at least two instances
@@ -1654,7 +1654,7 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
//cleanup references recursively
Hop current = hi;
while(current != input ) {
- Hop tmp = current.getInput().get(0);
+ Hop tmp = current.getInput(0);
HopRewriteUtils.removeAllChildReferences(current);
current = tmp;
}
@@ -1683,21 +1683,21 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
private static Hop simplifyTransposeAggBinBinaryChains(Hop parent, Hop
hi, int pos)
{
if( HopRewriteUtils.isTransposeOperation(hi)
- && hi.getInput().get(0) instanceof BinaryOp
//basic binary
- &&
((BinaryOp)hi.getInput().get(0)).supportsMatrixScalarOperations())
+ && hi.getInput(0) instanceof BinaryOp
//basic binary
+ &&
((BinaryOp)hi.getInput(0)).supportsMatrixScalarOperations())
{
- Hop left = hi.getInput().get(0).getInput().get(0);
- Hop C = hi.getInput().get(0).getInput().get(1);
+ Hop left = hi.getInput(0).getInput(0);
+ Hop C = hi.getInput(0).getInput(1);
//check matrix mult and both inputs transposes w/
single consumer
if( left instanceof AggBinaryOp &&
C.getDataType().isMatrix()
- &&
HopRewriteUtils.isTransposeOperation(left.getInput().get(0))
- &&
left.getInput().get(0).getParent().size()==1
- &&
HopRewriteUtils.isTransposeOperation(left.getInput().get(1))
- &&
left.getInput().get(1).getParent().size()==1 )
+ &&
HopRewriteUtils.isTransposeOperation(left.getInput(0))
+ &&
left.getInput(0).getParent().size()==1
+ &&
HopRewriteUtils.isTransposeOperation(left.getInput(1))
+ &&
left.getInput(1).getParent().size()==1 )
{
- Hop A =
left.getInput().get(0).getInput().get(0);
- Hop B =
left.getInput().get(1).getInput().get(0);
+ Hop A = left.getInput(0).getInput(0);
+ Hop B = left.getInput(1).getInput(0);
AggBinaryOp abop =
HopRewriteUtils.createMatrixMultiply(B, A);
ReorgOp rop =
HopRewriteUtils.createTranspose(C);
@@ -1716,18 +1716,18 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
// Patterns: X + (X==0) * s -> replace(X, 0, s)
private static Hop simplifyReplaceZeroOperation(Hop parent, Hop hi, int
pos)
{
- if( HopRewriteUtils.isBinary(hi, OpOp2.PLUS) &&
hi.getInput().get(0).isMatrix()
- &&
HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT)
- &&
hi.getInput().get(1).getInput().get(1).isScalar()
- &&
HopRewriteUtils.isBinaryMatrixScalar(hi.getInput().get(1).getInput().get(0),
OpOp2.EQUAL, 0)
- &&
hi.getInput().get(1).getInput().get(0).getInput().contains(hi.getInput().get(0))
)
+ if( HopRewriteUtils.isBinary(hi, OpOp2.PLUS) &&
hi.getInput(0).isMatrix()
+ && HopRewriteUtils.isBinary(hi.getInput(1),
OpOp2.MULT)
+ && hi.getInput(1).getInput(1).isScalar()
+ &&
HopRewriteUtils.isBinaryMatrixScalar(hi.getInput(1).getInput(0), OpOp2.EQUAL, 0)
+ &&
hi.getInput(1).getInput(0).getInput().contains(hi.getInput(0)) )
{
LinkedHashMap<String, Hop> args = new LinkedHashMap<>();
- args.put("target", hi.getInput().get(0));
+ args.put("target", hi.getInput(0));
args.put("pattern", new LiteralOp(0));
- args.put("replacement",
hi.getInput().get(1).getInput().get(1));
+ args.put("replacement", hi.getInput(1).getInput(1));
Hop replace =
HopRewriteUtils.createParameterizedBuiltinOp(
- hi.getInput().get(0), args,
ParamBuiltinOp.REPLACE);
+ hi.getInput(0), args,
ParamBuiltinOp.REPLACE);
HopRewriteUtils.replaceChildReference(parent, hi,
replace, pos);
hi = replace;
LOG.debug("Applied simplifyReplaceZeroOperation (line
"+hi.getBeginLine()+").");
@@ -1750,10 +1750,10 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
if( hi instanceof ReorgOp &&
HopRewriteUtils.isValidOp(((ReorgOp)hi).getOp(), lookup) ) //first reorg
{
ReOrgOp firstOp = ((ReorgOp)hi).getOp();
- Hop hi2 = hi.getInput().get(0);
+ Hop hi2 = hi.getInput(0);
if( hi2 instanceof ReorgOp &&
((ReorgOp)hi2).getOp()==firstOp ) //second reorg w/ same type
{
- Hop hi3 = hi2.getInput().get(0);
+ Hop hi3 = hi2.getInput(0);
//remove unnecessary chain of t(t())
HopRewriteUtils.replaceChildReference(parent,
hi, hi3, pos);
HopRewriteUtils.cleanupUnreferenced(hi, hi2);
@@ -1777,11 +1777,11 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
//rowSums(removeEmpty(target=X,margin="cols")) -> rowSums(X)
//colSums(removeEmpty(target=X,margin="rows")) -> colSums(X)
if( (HopRewriteUtils.isSum(hi) || HopRewriteUtils.isSumSq(hi))
- &&
HopRewriteUtils.isRemoveEmpty(hi.getInput().get(0))
- && hi.getInput().get(0).getParent().size() == 1
)
+ && HopRewriteUtils.isRemoveEmpty(hi.getInput(0))
+ && hi.getInput(0).getParent().size() == 1 )
{
AggUnaryOp agg = (AggUnaryOp)hi;
- ParameterizedBuiltinOp rmEmpty =
(ParameterizedBuiltinOp) hi.getInput().get(0);
+ ParameterizedBuiltinOp rmEmpty =
(ParameterizedBuiltinOp) hi.getInput(0);
boolean needRmEmpty = (agg.getDirection() ==
Direction.Row && HopRewriteUtils.isRemoveEmpty(rmEmpty, true))
|| (agg.getDirection() == Direction.Col
&& HopRewriteUtils.isRemoveEmpty(rmEmpty, false));
@@ -1796,10 +1796,10 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
//check if nrow is called on the output of removeEmpty
if( HopRewriteUtils.isUnary(hi, OpOp1.NROW)
- &&
HopRewriteUtils.isRemoveEmpty(hi.getInput().get(0), true)
- && hi.getInput().get(0).getParent().size() == 1
)
+ &&
HopRewriteUtils.isRemoveEmpty(hi.getInput(0), true)
+ && hi.getInput(0).getParent().size() == 1 )
{
- ParameterizedBuiltinOp rm = (ParameterizedBuiltinOp)
hi.getInput().get(0);
+ ParameterizedBuiltinOp rm = (ParameterizedBuiltinOp)
hi.getInput(0);
//obtain optional select vector or input if col vector
//(nnz will be the same as the select vector if
// the select vector is provided and it will be the same
@@ -1831,15 +1831,15 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
{
if( hi.getDataType() == DataType.MATRIX && hi instanceof
BinaryOp
&& ((BinaryOp)hi).getOp()==OpOp2.MINUS
//first minus
- && hi.getInput().get(0) instanceof LiteralOp &&
((LiteralOp)hi.getInput().get(0)).getDoubleValue()==0 )
+ && hi.getInput(0) instanceof LiteralOp &&
((LiteralOp)hi.getInput(0)).getDoubleValue()==0 )
{
- Hop hi2 = hi.getInput().get(1);
+ Hop hi2 = hi.getInput(1);
if( hi2.getDataType() == DataType.MATRIX && hi2
instanceof BinaryOp
&& ((BinaryOp)hi2).getOp()==OpOp2.MINUS
//second minus
- && hi2.getInput().get(0) instanceof
LiteralOp && ((LiteralOp)hi2.getInput().get(0)).getDoubleValue()==0 )
+ && hi2.getInput(0) instanceof LiteralOp
&& ((LiteralOp)hi2.getInput(0)).getDoubleValue()==0 )
{
- Hop hi3 = hi2.getInput().get(1);
+ Hop hi3 = hi2.getInput(1);
//remove unnecessary chain of -(-())
HopRewriteUtils.replaceChildReference(parent,
hi, hi3, pos);
HopRewriteUtils.cleanupUnreferenced(hi, hi2);
@@ -1887,19 +1887,19 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
//note: this is done as a hop rewrite in order to significantly
reduce the
//memory estimate for X - tmp if X is sparse
if( HopRewriteUtils.isBinary(hi, OpOp2.MINUS)
- &&
hi.getInput().get(0).getDataType()==DataType.MATRIX
- &&
hi.getInput().get(1).getDataType()==DataType.MATRIX
- &&
HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.MULT) )
+ && hi.getInput(0).getDataType()==DataType.MATRIX
+ && hi.getInput(1).getDataType()==DataType.MATRIX
+ && HopRewriteUtils.isBinary(hi.getInput(1),
OpOp2.MULT) )
{
- Hop X = hi.getInput().get(0);
- Hop s = hi.getInput().get(1).getInput().get(0);
- Hop pred = hi.getInput().get(1).getInput().get(1);
+ Hop X = hi.getInput(0);
+ Hop s = hi.getInput(1).getInput(0);
+ Hop pred = hi.getInput(1).getInput(1);
if( s.getDataType()==DataType.SCALAR &&
pred.getDataType()==DataType.MATRIX
&& HopRewriteUtils.isBinary(pred,
OpOp2.NOTEQUAL)
- && pred.getInput().get(0) == X //depend
on common subexpression elimination
- && pred.getInput().get(1) instanceof
LiteralOp
- &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 )
+ && pred.getInput(0) == X //depend on
common subexpression elimination
+ && pred.getInput(1) instanceof LiteralOp
+ &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput(1))==0 )
{
Hop hnew = HopRewriteUtils.createBinary(X, s,
OpOp2.MINUS_NZ);
@@ -1920,17 +1920,17 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
//note: this is done as a hop rewrite in order to significantly
reduce the
//memory estimate and to prevent dense intermediates if X is
ultra sparse
if( HopRewriteUtils.isBinary(hi, OpOp2.MULT)
- &&
hi.getInput().get(0).getDataType()==DataType.MATRIX
- &&
hi.getInput().get(1).getDataType()==DataType.MATRIX
- &&
HopRewriteUtils.isUnary(hi.getInput().get(1), OpOp1.LOG) )
+ && hi.getInput(0).getDataType()==DataType.MATRIX
+ && hi.getInput(1).getDataType()==DataType.MATRIX
+ && HopRewriteUtils.isUnary(hi.getInput(1),
OpOp1.LOG) )
{
- Hop pred = hi.getInput().get(0);
- Hop X = hi.getInput().get(1).getInput().get(0);
+ Hop pred = hi.getInput(0);
+ Hop X = hi.getInput(1).getInput(0);
if( HopRewriteUtils.isBinary(pred, OpOp2.NOTEQUAL)
- && pred.getInput().get(0) == X //depend
on common subexpression elimination
- && pred.getInput().get(1) instanceof
LiteralOp
- &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 )
+ && pred.getInput(0) == X //depend on
common subexpression elimination
+ && pred.getInput(1) instanceof LiteralOp
+ &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput(1))==0 )
{
Hop hnew = HopRewriteUtils.createUnary(X,
OpOp1.LOG_NZ);
@@ -1951,18 +1951,18 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
//note: this is done as a hop rewrite in order to significantly
reduce the
//memory estimate and to prevent dense intermediates if X is
ultra sparse
if( HopRewriteUtils.isBinary(hi, OpOp2.MULT)
- &&
hi.getInput().get(0).getDataType()==DataType.MATRIX
- &&
hi.getInput().get(1).getDataType()==DataType.MATRIX
- &&
HopRewriteUtils.isBinary(hi.getInput().get(1), OpOp2.LOG) )
+ && hi.getInput(0).getDataType()==DataType.MATRIX
+ && hi.getInput(1).getDataType()==DataType.MATRIX
+ && HopRewriteUtils.isBinary(hi.getInput(1),
OpOp2.LOG) )
{
- Hop pred = hi.getInput().get(0);
- Hop X = hi.getInput().get(1).getInput().get(0);
- Hop log = hi.getInput().get(1).getInput().get(1);
+ Hop pred = hi.getInput(0);
+ Hop X = hi.getInput(1).getInput(0);
+ Hop log = hi.getInput(1).getInput(1);
if( HopRewriteUtils.isBinary(pred, OpOp2.NOTEQUAL)
- && pred.getInput().get(0) == X //depend
on common subexpression elimination
- && pred.getInput().get(1) instanceof
LiteralOp
- &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput().get(1))==0 )
+ && pred.getInput(0) == X //depend on
common subexpression elimination
+ && pred.getInput(1) instanceof LiteralOp
+ &&
HopRewriteUtils.getDoubleValueSafe((LiteralOp)pred.getInput(1))==0 )
{
Hop hnew = HopRewriteUtils.createBinary(X, log,
OpOp2.LOG_NZ);
@@ -1984,20 +1984,20 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
if( HopRewriteUtils.isBinary(hi, OpOp2.EQUAL) &&
((BinaryOp)hi).isOuter() )
{
- if( (
HopRewriteUtils.isTransposeOperation(hi.getInput().get(1)) //pattern a:
outer(v, t(seq(1,m)), "==")
- &&
HopRewriteUtils.isBasic1NSequence(hi.getInput().get(1).getInput().get(0)))
- ||
HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0))) //pattern b:
outer(seq(1,m), t(v) "==")
+ if( (
HopRewriteUtils.isTransposeOperation(hi.getInput(1)) //pattern a: outer(v,
t(seq(1,m)), "==")
+ &&
HopRewriteUtils.isBasic1NSequence(hi.getInput(1).getInput(0)))
+ ||
HopRewriteUtils.isBasic1NSequence(hi.getInput(0))) //pattern b: outer(seq(1,m),
t(v) "==")
{
//determine variable parameters for pattern a/b
- boolean isPatternB =
HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0));
- boolean isTransposeRight =
HopRewriteUtils.isTransposeOperation(hi.getInput().get(1));
+ boolean isPatternB =
HopRewriteUtils.isBasic1NSequence(hi.getInput(0));
+ boolean isTransposeRight =
HopRewriteUtils.isTransposeOperation(hi.getInput(1));
Hop trgt = isPatternB ? (isTransposeRight ?
-
hi.getInput().get(1).getInput().get(0) : //get v from t(v)
-
HopRewriteUtils.createTranspose(hi.getInput().get(1)) ) : //create v via t(v')
- hi.getInput().get(0);
//get v directly
+ hi.getInput(1).getInput(0) :
//get v from t(v)
+
HopRewriteUtils.createTranspose(hi.getInput(1)) ) : //create v via t(v')
+ hi.getInput(0);
//get v directly
Hop seq = isPatternB ?
- hi.getInput().get(0) :
hi.getInput().get(1).getInput().get(0);
- String direction =
HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0)) ? "rows" : "cols";
+ hi.getInput(0) :
hi.getInput(1).getInput(0);
+ String direction =
HopRewriteUtils.isBasic1NSequence(hi.getInput(0)) ? "rows" : "cols";
//setup input parameter hops
LinkedHashMap<String,Hop> inputargs = new
LinkedHashMap<>();
@@ -2024,12 +2024,12 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
private static Hop simplifyBinaryComparisonChain(Hop parent, Hop hi,
int pos) {
if( HopRewriteUtils.isBinaryPPred(hi)
- &&
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 0d, 1d)
- &&
HopRewriteUtils.isBinaryPPred(hi.getInput().get(0)) )
+ &&
HopRewriteUtils.isLiteralOfValue(hi.getInput(1), 0d, 1d)
+ &&
HopRewriteUtils.isBinaryPPred(hi.getInput(0)) )
{
BinaryOp bop = (BinaryOp) hi;
- BinaryOp bop2 = (BinaryOp) hi.getInput().get(0);
- boolean one =
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(1), 1);
+ BinaryOp bop2 = (BinaryOp) hi.getInput(0);
+ boolean one =
HopRewriteUtils.isLiteralOfValue(hi.getInput(1), 1);
//pattern: outer(v1,v2,"!=") == 1 -> outer(v1,v2,"!=")
if( (one && bop.getOp() == OpOp2.EQUAL)
@@ -2043,8 +2043,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
//pattern: outer(v1,v2,"!=") == 0 -> outer(v1,v2,"==")
else if( !one && bop.getOp() == OpOp2.EQUAL ) {
OpOp2 optr = bop2.getComplementPPredOperation();
- BinaryOp tmp =
HopRewriteUtils.createBinary(bop2.getInput().get(0),
- bop2.getInput().get(1), optr,
bop2.isOuter());
+ BinaryOp tmp =
HopRewriteUtils.createBinary(bop2.getInput(0),
+ bop2.getInput(1), optr,
bop2.isOuter());
HopRewriteUtils.replaceChildReference(parent,
bop, tmp, pos);
HopRewriteUtils.cleanupUnreferenced(bop, bop2);
hi = tmp;
@@ -2059,11 +2059,11 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
//pattern: colSums(cumsum(X)) -> cumSums(X*seq(nrow(X),1))
if( (HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, Direction.Col)
|| HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM,
Direction.RowCol))
- &&
HopRewriteUtils.isUnary(hi.getInput().get(0), OpOp1.CUMSUM)
- && hi.getInput().get(0).getParent().size()==1)
+ && HopRewriteUtils.isUnary(hi.getInput(0),
OpOp1.CUMSUM)
+ && hi.getInput(0).getParent().size()==1)
{
- Hop cumsumX = hi.getInput().get(0);
- Hop X = cumsumX.getInput().get(0);
+ Hop cumsumX = hi.getInput(0);
+ Hop X = cumsumX.getInput(0);
Hop mult = HopRewriteUtils.createBinary(X,
HopRewriteUtils.createSeqDataGenOp(X,
false), OpOp2.MULT);
HopRewriteUtils.replaceChildReference(hi, cumsumX,
mult);
@@ -2076,14 +2076,14 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
private static Hop simplifyCumsumReverse(Hop parent, Hop hi, int pos) {
//pattern: rev(cumsum(rev(X))) -> X + colSums(X) - cumsum(X)
if( HopRewriteUtils.isReorg(hi, ReOrgOp.REV)
- &&
HopRewriteUtils.isUnary(hi.getInput().get(0), OpOp1.CUMSUM)
- && hi.getInput().get(0).getParent().size()==1
- &&
HopRewriteUtils.isReorg(hi.getInput().get(0).getInput().get(0), ReOrgOp.REV)
- &&
hi.getInput().get(0).getInput().get(0).getParent().size()==1)
+ && HopRewriteUtils.isUnary(hi.getInput(0),
OpOp1.CUMSUM)
+ && hi.getInput(0).getParent().size()==1
+ &&
HopRewriteUtils.isReorg(hi.getInput(0).getInput(0), ReOrgOp.REV)
+ &&
hi.getInput(0).getInput(0).getParent().size()==1)
{
- Hop cumsumX = hi.getInput().get(0);
- Hop revX = cumsumX.getInput().get(0);
- Hop X = revX.getInput().get(0);
+ Hop cumsumX = hi.getInput(0);
+ Hop revX = cumsumX.getInput(0);
+ Hop X = revX.getInput(0);
Hop plus = HopRewriteUtils.createBinary(X,
HopRewriteUtils
.createAggUnaryOp(X, AggOp.SUM,
Direction.Col), OpOp2.PLUS);
Hop minus = HopRewriteUtils.createBinary(plus,
@@ -2160,8 +2160,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
if( hi instanceof BinaryOp )
{
BinaryOp bop = (BinaryOp)hi;
- Hop left = bop.getInput().get(0);
- Hop right = bop.getInput().get(1);
+ Hop left = bop.getInput(0);
+ Hop right = bop.getInput(1);
Hop datagen = null;
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteFuseBinaryOpChainTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteFuseBinaryOpChainTest.java
index 722ac0dc1a..8b697e28cf 100644
---
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteFuseBinaryOpChainTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteFuseBinaryOpChainTest.java
@@ -23,7 +23,6 @@ import java.util.HashMap;
import org.junit.Assert;
import org.junit.Test;
-import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.common.Types.ExecType;
@@ -143,17 +142,8 @@ public class RewriteFuseBinaryOpChainTest extends
AutomatedTestBase
}
private void testFuseBinaryChain( String testname, boolean rewrites,
ExecType instType )
- {
- ExecMode platformOld = rtplatform;
- switch( instType ){
- case SPARK: rtplatform = ExecMode.SPARK; break;
- default: rtplatform = ExecMode.HYBRID; break;
- }
-
- boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
- if( rtplatform == ExecMode.SPARK )
- DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-
+ {
+ ExecMode platformOld = setExecMode(instType);
boolean rewritesOld =
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
@@ -184,8 +174,7 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase
}
finally {
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION =
rewritesOld;
- rtplatform = platformOld;
- DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ resetExecMode(platformOld);
}
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteHoistingLoopInvariantOpsTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteHoistingLoopInvariantOpsTest.java
index 6a6f5105e4..f52a0c7cde 100644
--- a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteHoistingLoopInvariantOpsTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteHoistingLoopInvariantOpsTest.java
@@ -23,7 +23,6 @@ import java.util.HashMap;
import org.junit.Assert;
import org.junit.Test;
-import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.common.Types.ExecType;
@@ -75,17 +74,8 @@ public class RewriteHoistingLoopInvariantOpsTest extends AutomatedTestBase
}
private void testRewriteCodeMotion(String testname, boolean rewrites,
ExecType et)
- {
- ExecMode platformOld = rtplatform;
- switch( et ){
- case SPARK: rtplatform = ExecMode.SPARK; break;
- default: rtplatform = ExecMode.HYBRID; break;
- }
-
- boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
- if( rtplatform == ExecMode.SPARK || rtplatform ==
ExecMode.HYBRID )
- DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-
+ {
+ ExecMode platformOld = setExecMode(et);
boolean rewritesOld = OptimizerUtils.ALLOW_CODE_MOTION;
OptimizerUtils.ALLOW_CODE_MOTION = rewrites;
@@ -119,8 +109,7 @@ public class RewriteHoistingLoopInvariantOpsTest extends AutomatedTestBase
}
finally {
OptimizerUtils.ALLOW_CODE_MOTION = rewritesOld;
- rtplatform = platformOld;
- DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ resetExecMode(platformOld);
}
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteIfElseTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteIfElseTest.java
index f074f4f7d2..087fc49d98 100644
--- a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteIfElseTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteIfElseTest.java
@@ -134,17 +134,8 @@ public class RewriteIfElseTest extends AutomatedTestBase
}
private void testRewriteIfElse(String testname, boolean pred, boolean
rewrites, ExecType et)
- {
- ExecMode platformOld = rtplatform;
- switch( et ){
- case SPARK: rtplatform = ExecMode.SPARK; break;
- default: rtplatform = ExecMode.HYBRID; break;
- }
-
- boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
- if( rtplatform == ExecMode.SPARK || rtplatform ==
ExecMode.HYBRID )
- DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-
+ {
+ ExecMode platformOld = setExecMode(et);
boolean rewritesOld =
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
@@ -184,8 +175,7 @@ public class RewriteIfElseTest extends AutomatedTestBase
}
finally {
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION =
rewritesOld;
- rtplatform = platformOld;
- DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ resetExecMode(platformOld);
}
}
}