This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new a46189ce7b [SYSTEMDS-3805] Rewrite and runtime for scalar right indexing
a46189ce7b is described below
commit a46189ce7b72992b0597c9cea819abdb390c7b66
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed Dec 11 18:07:02 2024 +0100
[SYSTEMDS-3805] Rewrite and runtime for scalar right indexing
This patch adds a new rewrite and modifies existing rewrites
and runtime instructions in order to perform scalar right indexing
for operations like as.scalar(X[i,1]), which avoids unnecessary
createvar and cast instructions. For a scenario running the baseline
(non-vectorized) exponential smoothing on 10M data points, the patch
improved end-to-end performance from 22.3s to 12.2s (6.7s without
statistics time measurements).
# Baseline (non-vectorized) exponential smoothing over the first column of X,
# processing one cell per loop iteration.
alpha = 0.05
r = as.scalar(X[1, 1])
for(i in 2:nrow(X)) {
# as.scalar(X[i, 1]) is the scalar right-indexing pattern this patch targets
r = alpha * as.scalar(X[i, 1]) + (1-alpha) * r
}
Before this patch:
Total elapsed time: 22.348 sec.
Total compilation time: 0.516 sec.
Total execution time: 21.832 sec.
Cache hits (Mem/Li/WB/FS/HDFS): 20000000/0/0/0/0.
Cache writes (Li/WB/FS/HDFS): 1/0/0/0.
Cache times (ACQr/m, RLS, EXP): 0.777/0.432/1.124/0.000 sec.
HOP DAGs recompiled (PRED, SB): 0/0.
HOP DAGs recompile time: 0.300 sec.
Functions recompiled: 1.
Functions recompile time: 0.002 sec.
Total JIT compile time: 2.608 sec.
Total JVM GC count: 1.
Total JVM GC time: 0.018 sec.
Heavy hitter instructions:
1 rightIndex 4.894 10000000
2 createvar 3.585 10000001
3 rmvar 2.848 30000000
4 castdts 2.242 10000000
5 * 1.742 19999998
6 + 0.898 9999999
7 mvvar 0.751 10000002
8 rand 0.213 1
9 - 0.016 1
10 print 0.000 1
11 assignvar 0.000 2
After this patch:
Total elapsed time: 12.589 sec.
Total compilation time: 0.520 sec.
Total execution time: 12.069 sec.
Cache hits (Mem/Li/WB/FS/HDFS): 10000000/0/0/0/0.
Cache writes (Li/WB/FS/HDFS): 1/0/0/0.
Cache times (ACQr/m, RLS, EXP): 0.455/0.000/0.463/0.000 sec.
HOP DAGs recompiled (PRED, SB): 0/0.
HOP DAGs recompile time: 0.313 sec.
Functions recompiled: 1.
Functions recompile time: 0.002 sec.
Total JIT compile time: 1.923 sec.
Total JVM GC count: 1.
Total JVM GC time: 0.011 sec.
Heavy hitter instructions:
1 rightIndex 3.046 10000000
2 * 1.876 19999998
3 rmvar 1.450 20000000
4 + 0.954 9999999
5 mvvar 0.801 10000002
6 rand 0.213 1
7 - 0.018 1
8 print 0.000 1
9 createvar 0.000 1
10 assignvar 0.000 2
---
.../java/org/apache/sysds/hops/IndexingOp.java | 4 +
.../apache/sysds/hops/rewrite/HopRewriteUtils.java | 2 +-
.../RewriteAlgebraicSimplificationDynamic.java | 2 +-
.../RewriteAlgebraicSimplificationStatic.java | 22 ++++++
.../hops/rewrite/RewriteIndexingVectorization.java | 6 +-
.../cp/MatrixIndexingCPInstruction.java | 53 +++++++------
.../instructions/cp/VariableCPInstruction.java | 5 ++
.../spark/MatrixIndexingSPInstruction.java | 40 ++++++----
.../rewrite/RewriteLoopVectorization.java | 2 +
.../rewrite/RewriteScalarRightIndexingTest.java | 92 ++++++++++++++++++++++
.../rewrite/RewriteScalarRightIndexing.dml | 34 ++++++++
11 files changed, 220 insertions(+), 42 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/IndexingOp.java
b/src/main/java/org/apache/sysds/hops/IndexingOp.java
index 35215fa843..1756724e74 100644
--- a/src/main/java/org/apache/sysds/hops/IndexingOp.java
+++ b/src/main/java/org/apache/sysds/hops/IndexingOp.java
@@ -73,6 +73,10 @@ public class IndexingOp extends Hop
setRowLowerEqualsUpper(passedRowsLEU);
setColLowerEqualsUpper(passedColsLEU);
}
+
+ public boolean isScalarOutput() {
+ return isRowLowerEqualsUpper() && isColLowerEqualsUpper();
+ }
public boolean isRowLowerEqualsUpper(){
return _rowLowerEqualsUpper;
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index 68167ac3ae..aae2787cd3 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -1332,7 +1332,7 @@ public class HopRewriteUtils {
}
public static boolean isUnnecessaryRightIndexing(Hop hop) {
- if( !(hop instanceof IndexingOp) )
+ if( !(hop instanceof IndexingOp) || hop.isScalar() )
return false;
//note: in addition to equal sizes, we also check a valid
//starting row and column ranges of 1 in order to guard against
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 396c40d114..9c1f2174d0 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -241,7 +241,7 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
private static Hop removeUnnecessaryRightIndexing(Hop parent, Hop hi,
int pos)
{
- if( HopRewriteUtils.isUnnecessaryRightIndexing(hi) ) {
+ if( HopRewriteUtils.isUnnecessaryRightIndexing(hi) &&
!hi.isScalar() ) {
//remove unnecessary right indexing
Hop input = hi.getInput().get(0);
HopRewriteUtils.replaceChildReference(parent, hi,
input, pos);
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 5a79bdee33..d06f89d72e 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -174,6 +174,7 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
hi = simplifyTraceMatrixMult(hop, hi, i);
//e.g., trace(X%*%Y)->sum(X*t(Y));
hi = simplifySlicedMatrixMult(hop, hi, i);
//e.g., (X%*%Y)[1,1] -> X[1,] %*% Y[,1];
hi = simplifyListIndexing(hi);
//e.g., L[i:i, 1:ncol(L)] -> L[i:i, 1:1]
+ hi = simplifyScalarIndexing(hop, hi, i);
//e.g., as.scalar(X[i,1])->X[i,1] w/ scalar output
hi = simplifyConstantSort(hop, hi, i);
//e.g., order(matrix())->matrix/seq;
hi = simplifyOrderedSort(hop, hi, i);
//e.g., order(matrix())->seq;
hi = fuseOrderOperationChain(hi);
//e.g., order(order(X,2),1) -> order(X,(12))
@@ -1508,6 +1509,27 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
return hi;
}
+ private static Hop simplifyScalarIndexing(Hop parent, Hop hi, int pos)
+ {
+ //as.scalar(X[i,1]) -> X[i,1] w/ scalar output
+ if( HopRewriteUtils.isUnary(hi, OpOp1.CAST_AS_SCALAR)
+ && hi.getInput(0).getParent().size() == 1 // only
consumer
+ && hi.getInput(0) instanceof IndexingOp
+ && ((IndexingOp)hi.getInput(0)).isScalarOutput()
+ && hi.getInput(0).isMatrix() //no frame support yet
+ && !HopRewriteUtils.isData(parent,
OpOpData.TRANSIENTWRITE))
+ {
+ Hop hi2 = hi.getInput().get(0);
+ hi2.setDataType(DataType.SCALAR);
+ hi2.setDim1(0); hi2.setDim2(0);
+ HopRewriteUtils.replaceChildReference(parent, hi, hi2,
pos);
+ HopRewriteUtils.cleanupUnreferenced(hi);
+ hi = hi2;
+ LOG.debug("Applied simplifyScalarIndexing (line
"+hi.getBeginLine()+").");
+ }
+ return hi;
+ }
+
private static Hop simplifyConstantSort(Hop parent, Hop hi, int pos)
{
//order(matrix(7), indexreturn=FALSE) -> matrix(7)
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteIndexingVectorization.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteIndexingVectorization.java
index 9c04959ed5..6da9e52132 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteIndexingVectorization.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteIndexingVectorization.java
@@ -186,8 +186,8 @@ public class RewriteIndexingVectorization extends
HopRewriteRule
ihops.add(ihop0);
for( Hop c : input.getParent() ){
if( c != ihop0 && c instanceof
IndexingOp && c.getInput().get(0) == input
- && ((IndexingOp)
c).isRowLowerEqualsUpper()
- &&
c.getInput().get(1)==ihop0.getInput().get(1) )
+ && ((IndexingOp)
c).isRowLowerEqualsUpper() && !c.isScalar()
+ &&
c.getInput().get(1)==ihop0.getInput().get(1) )
{
ihops.add( c );
}
@@ -225,7 +225,7 @@ public class RewriteIndexingVectorization extends
HopRewriteRule
ihops.add(ihop0);
for( Hop c : input.getParent() ){
if( c != ihop0 && c instanceof
IndexingOp && c.getInput().get(0) == input
- && ((IndexingOp)
c).isColLowerEqualsUpper()
+ && ((IndexingOp)
c).isColLowerEqualsUpper() && !c.isScalar()
&&
c.getInput().get(3)==ihop0.getInput().get(3) )
{
ihops.add( c );
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixIndexingCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixIndexingCPInstruction.java
index afbf7724ab..99473b7a49 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixIndexingCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixIndexingCPInstruction.java
@@ -52,39 +52,46 @@ public final class MatrixIndexingCPInstruction extends
IndexingCPInstruction {
String opcode = getOpcode();
IndexRange ix = getIndexRange(ec);
- //get original matrix
MatrixObject mo = ec.getMatrixObject(input1.getName());
+ boolean inRange = ix.rowStart < mo.getNumRows() && ix.colStart
< mo.getNumColumns();
//right indexing
if( opcode.equalsIgnoreCase(RightIndex.OPCODE) )
{
- MatrixBlock resultBlock = null;
-
- if( mo.isPartitioned() ) //via data partitioning
- resultBlock = mo.readMatrixPartition(ix.add(1));
- else if( ix.isScalar() && ix.rowStart < mo.getNumRows()
&& ix.colStart < mo.getNumColumns() ) {
+ if( output.isScalar() && inRange ) { //SCALAR out
MatrixBlock matBlock =
mo.acquireReadAndRelease();
- resultBlock = new MatrixBlock(
- matBlock.get((int)ix.rowStart,
(int)ix.colStart));
+ ec.setScalarOutput(output.getName(),
+ new
DoubleObject(matBlock.get((int)ix.rowStart, (int)ix.colStart)));
}
- else //via slicing the in-memory matrix
- {
- //execute right indexing operation (with
shallow row copies for range
- //of entire sparse rows, which is safe due to
copy on update)
- MatrixBlock matBlock = mo.acquireRead();
- resultBlock = matBlock.slice((int)ix.rowStart,
(int)ix.rowEnd,
- (int)ix.colStart, (int)ix.colEnd,
false, new MatrixBlock());
+ else { //MATRIX out
+ MatrixBlock resultBlock = null;
- //unpin rhs input
- ec.releaseMatrixInput(input1.getName());
+ if( mo.isPartitioned() ) //via data partitioning
+ resultBlock =
mo.readMatrixPartition(ix.add(1));
+ else if( ix.isScalar() && inRange ) {
+ MatrixBlock matBlock =
mo.acquireReadAndRelease();
+ resultBlock = new MatrixBlock(
+ matBlock.get((int)ix.rowStart,
(int)ix.colStart));
+ }
+ else //via slicing the in-memory matrix
+ {
+ //execute right indexing operation
(with shallow row copies for range
+ //of entire sparse rows, which is safe
due to copy on update)
+ MatrixBlock matBlock = mo.acquireRead();
+ resultBlock =
matBlock.slice((int)ix.rowStart, (int)ix.rowEnd,
+ (int)ix.colStart,
(int)ix.colEnd, false, new MatrixBlock());
+
+ //unpin rhs input
+ ec.releaseMatrixInput(input1.getName());
+
+ //ensure correct sparse/dense output
representation
+ if(
checkGuardedRepresentationChange(matBlock, resultBlock) )
+ resultBlock.examSparsity();
+ }
- //ensure correct sparse/dense output
representation
- if( checkGuardedRepresentationChange(matBlock,
resultBlock) )
- resultBlock.examSparsity();
+ //unpin output
+ ec.setMatrixOutput(output.getName(),
resultBlock);
}
-
- //unpin output
- ec.setMatrixOutput(output.getName(), resultBlock);
}
//left indexing
else if ( opcode.equalsIgnoreCase(LeftIndex.OPCODE))
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
index 8826c41b80..3ae0a96f0b 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
@@ -897,6 +897,11 @@ public class VariableCPInstruction extends CPInstruction
implements LineageTrace
ec.setVariable(output.getName(), list.slice(0));
break;
}
+ case SCALAR: {
+ //for robustness in case rewrites added
unnecessary as.scalars
+ ec.setScalarOutput(output.getName(),
ec.getScalarInput(getInput1()));
+ break;
+ }
default:
throw new DMLRuntimeException("Unsupported data
type "
+ "in as.scalar():
"+getInput1().getDataType().name());
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java
index ac2d8f4f22..ceaaea2ded 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java
@@ -35,6 +35,7 @@ import
org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
@@ -47,6 +48,7 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.runtime.util.UtilFunctions;
import scala.Function1;
@@ -103,26 +105,35 @@ public class MatrixIndexingSPInstruction extends
IndexingSPInstruction {
if( opcode.equalsIgnoreCase(RightIndex.OPCODE) )
{
//update and check output dimensions
- DataCharacteristics mcOut =
sec.getDataCharacteristics(output.getName());
+ DataCharacteristics mcOut = output.isScalar() ?
+ new MatrixCharacteristics(1,1) :
+ ec.getDataCharacteristics(output.getName());
mcOut.set(ru-rl+1, cu-cl+1, mcIn.getBlocksize(),
mcIn.getBlocksize());
mcOut.setNonZerosBound(Math.min(mcOut.getLength(),
mcIn.getNonZerosBound()));
checkValidOutputDimensions(mcOut);
//execute right indexing operation
(partitioning-preserving if possible)
JavaPairRDD<MatrixIndexes,MatrixBlock> in1 =
sec.getBinaryMatrixBlockRDDHandleForVariable( input1.getName() );
-
- if( isSingleBlockLookup(mcIn, ixrange) ) {
- sec.setMatrixOutput(output.getName(),
singleBlockIndexing(in1, mcIn, mcOut, ixrange));
- }
- else if( isMultiBlockLookup(in1, mcIn, mcOut, ixrange)
) {
- sec.setMatrixOutput(output.getName(),
multiBlockIndexing(in1, mcIn, mcOut, ixrange));
+
+ if( output.isScalar() ) { //SCALAR output
+ MatrixBlock ret = singleBlockIndexing(in1,
mcIn, mcOut, ixrange);
+ sec.setScalarOutput(output.getName(), new
DoubleObject(ret.get(0, 0)));
}
- else { //rdd output for general case
- JavaPairRDD<MatrixIndexes,MatrixBlock> out =
generalCaseRightIndexing(in1, mcIn, mcOut, ixrange, _aggType);
+ else { //MATRIX output
- //put output RDD handle into symbol table
- sec.setRDDHandleForVariable(output.getName(),
out);
- sec.addLineageRDD(output.getName(),
input1.getName());
+ if( isSingleBlockLookup(mcIn, ixrange) ) {
+ sec.setMatrixOutput(output.getName(),
singleBlockIndexing(in1, mcIn, mcOut, ixrange));
+ }
+ else if( isMultiBlockLookup(in1, mcIn, mcOut,
ixrange) ) {
+ sec.setMatrixOutput(output.getName(),
multiBlockIndexing(in1, mcIn, mcOut, ixrange));
+ }
+ else { //rdd output for general case
+ JavaPairRDD<MatrixIndexes,MatrixBlock>
out = generalCaseRightIndexing(in1, mcIn, mcOut, ixrange, _aggType);
+
+ //put output RDD handle into symbol
table
+
sec.setRDDHandleForVariable(output.getName(), out);
+ sec.addLineageRDD(output.getName(),
input1.getName());
+ }
}
}
//left indexing
@@ -178,12 +189,13 @@ public class MatrixIndexingSPInstruction extends
IndexingSPInstruction {
sec.addLineageRDD(output.getName(),
input2.getName());
}
else
- throw new DMLRuntimeException("Invalid opcode (" +
opcode +") encountered in MatrixIndexingSPInstruction.");
+ throw new DMLRuntimeException("Invalid opcode (" +
opcode +") encountered in MatrixIndexingSPInstruction.");
}
public static MatrixBlock
inmemoryIndexing(JavaPairRDD<MatrixIndexes,MatrixBlock> in1,
- DataCharacteristics mcIn,
DataCharacteristics mcOut, IndexRange ixrange) {
+ DataCharacteristics mcIn, DataCharacteristics mcOut, IndexRange
ixrange)
+ {
if( isSingleBlockLookup(mcIn, ixrange) ) {
return singleBlockIndexing(in1, mcIn, mcOut, ixrange);
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteLoopVectorization.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteLoopVectorization.java
index d9358fef30..927b0fd666 100644
---
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteLoopVectorization.java
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteLoopVectorization.java
@@ -22,6 +22,7 @@ package org.apache.sysds.test.functions.rewrite;
import java.util.HashMap;
import org.junit.Assert;
+import org.junit.Ignore;
import org.junit.Test;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
@@ -57,6 +58,7 @@ public class RewriteLoopVectorization extends
AutomatedTestBase
}
@Test
+ @Ignore //FIXME: extend loop vectorization rewrite
public void testLoopVectorizationSumRewrite() {
testRewriteLoopVectorizationSum( TEST_NAME1, true );
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteScalarRightIndexingTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteScalarRightIndexingTest.java
new file mode 100644
index 0000000000..9a3792d29d
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteScalarRightIndexingTest.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.utils.Statistics;
+
+public class RewriteScalarRightIndexingTest extends AutomatedTestBase
+{
+ private final static String TEST_DIR = "functions/rewrite/";
+ private final static String TEST_NAME = "RewriteScalarRightIndexing";
+
+ private final static String TEST_CLASS_DIR = TEST_DIR +
RewriteScalarRightIndexingTest.class.getSimpleName() + "/";
+
+ private final static int rows = 122;
+
+ @Override
+ public void setUp() {
+ addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"A"}));
+ }
+
+ @Test
+ public void testScalarRightIndexingCP() {
+ runScalarRightIndexing(true, ExecType.CP);
+ }
+
+ @Test
+ public void testScalarRightIndexingNoRewriteCP() {
+ runScalarRightIndexing(false, ExecType.CP);
+ }
+
+ @Test
+ public void testScalarRightIndexingSpark() {
+ runScalarRightIndexing(true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testScalarRightIndexingNoRewriteSpark() {
+ runScalarRightIndexing(false, ExecType.SPARK);
+ }
+
+ private void runScalarRightIndexing(boolean rewrite, ExecType instType)
{
+ ExecMode platformOld = setExecMode(instType);
+ boolean flagOld = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+ try {
+ TestConfiguration config =
getTestConfiguration(TEST_NAME);
+ loadTestConfiguration(config);
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrite;
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[]{"-explain", "-stats",
"-args",
+ Long.toString(rows), output("A")};
+ runTest(true, false, null, -1);
+
+ Double ret = readDMLScalarFromOutputDir("A").get(new
CellIndex(1,1));
+ Assert.assertEquals(Double.valueOf(103.0383), ret,
1e-4);
+ if(rewrite) //w/o rewrite 122 casts
+
Assert.assertTrue(Statistics.getCPHeavyHitterCount("castdts")<=1);
+ }
+ finally {
+ resetExecMode(platformOld);
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = flagOld;
+ }
+ }
+}
diff --git a/src/test/scripts/functions/rewrite/RewriteScalarRightIndexing.dml
b/src/test/scripts/functions/rewrite/RewriteScalarRightIndexing.dml
new file mode 100644
index 0000000000..d0b76cd2d8
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteScalarRightIndexing.dml
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+nrow = $1;
+X = seq(1, nrow);
+
+alpha = 0.05
+
+r = as.scalar(X[1, 1])
+for(i in 2:nrow(X)) {
+ r = alpha * as.scalar(X[i, 1]) + (1-alpha) * r
+}
+
+write(r, $2);
+