This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 2972d6df5f2453e091343b59708343a4c562f185 Author: Matthias Boehm <[email protected]> AuthorDate: Mon May 20 19:22:42 2024 +0200 [MINOR] Fix simplification rewrite binary ops (robustness for strings) --- .../rewrite/RewriteAlgebraicSimplificationStatic.java | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index f1065ea832..8fed2481ed 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -279,7 +279,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule Hop right = bop.getInput().get(1); //X/1 or X*1 -> X if( left.getDataType()==DataType.MATRIX - && right instanceof LiteralOp && ((LiteralOp)right).getDoubleValue()==1.0 ) + && right instanceof LiteralOp && right.getValueType().isNumeric() + && ((LiteralOp)right).getDoubleValue()==1.0 ) { if( bop.getOp()==OpOp2.DIV || bop.getOp()==OpOp2.MULT ) { @@ -291,7 +292,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule } //X-0 -> X else if( left.getDataType()==DataType.MATRIX - && right instanceof LiteralOp && ((LiteralOp)right).getDoubleValue()==0.0 ) + && right instanceof LiteralOp && right.getValueType().isNumeric() + && ((LiteralOp)right).getDoubleValue()==0.0 ) { if( bop.getOp()==OpOp2.MINUS ) { @@ -303,7 +305,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule } //1*X -> X else if( right.getDataType()==DataType.MATRIX - && left instanceof LiteralOp && ((LiteralOp)left).getDoubleValue()==1.0 ) + && left instanceof LiteralOp && left.getValueType().isNumeric() + && ((LiteralOp)left).getDoubleValue()==1.0 ) { if( bop.getOp()==OpOp2.MULT ) { @@ -317,7 +320,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule //note: this rewrite is necessary since the new antlr parser always converts //-X to -1*X due to mechanical reasons else if( right.getDataType()==DataType.MATRIX - && left instanceof LiteralOp && ((LiteralOp)left).getDoubleValue()==-1.0 ) + && left instanceof LiteralOp && left.getValueType().isNumeric() + && ((LiteralOp)left).getDoubleValue()==-1.0 ) { if( bop.getOp()==OpOp2.MULT ) { @@ -330,7 +334,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule } //X*-1 -> -X (see comment above) else if( left.getDataType()==DataType.MATRIX - && right instanceof LiteralOp && ((LiteralOp)right).getDoubleValue()==-1.0 ) + && right instanceof LiteralOp && right.getValueType().isNumeric() + && ((LiteralOp)right).getDoubleValue()==-1.0 ) { if( bop.getOp()==OpOp2.MULT ) {
