This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 45de519829 [MINOR] Fix index bounds checks on recompilation literal replacement
45de519829 is described below
commit 45de519829860f515f6f40871e9f4a4ce92d5cec
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Apr 19 10:46:41 2024 +0200
[MINOR] Fix index bounds checks on recompilation literal replacement
---
.../org/apache/sysds/hops/recompile/LiteralReplacement.java | 4 ++++
.../instructions/spark/MatrixIndexingSPInstruction.java | 10 +++++++---
.../functions/indexing/UnboundedScalarRightIndexingTest.java | 8 ++++----
3 files changed, 15 insertions(+), 7 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/recompile/LiteralReplacement.java b/src/main/java/org/apache/sysds/hops/recompile/LiteralReplacement.java
index 1a40ada94e..441a80ceb6 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/LiteralReplacement.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/LiteralReplacement.java
@@ -253,6 +253,10 @@ public class LiteralReplacement
if( mo.getNumRows()*mo.getNumColumns() <
REPLACE_LITERALS_MAX_MATRIX_SIZE )
{
MatrixBlock mBlock = mo.acquireRead();
+ if( rlval>mo.getNumRows() ||
clval>mo.getNumColumns() ) {
+ throw new
DMLRuntimeException("Scalar indexing out-of-bounds:"
+ + " ["+rlval+",
"+clval+"] in "+mo.getDataCharacteristics());
+ }
double value =
mBlock.get((int)rlval-1,(int)clval-1);
mo.release();
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java
index 3c8583d34c..e97336a8a6 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/MatrixIndexingSPInstruction.java
@@ -90,11 +90,15 @@ public class MatrixIndexingSPInstruction extends IndexingSPInstruction {
long cu = ec.getScalarInput(colUpper).getLongValue();
IndexRange ixrange = new IndexRange(rl, ru, cl, cu);
+ //check bounds
+ DataCharacteristics mcIn =
sec.getDataCharacteristics(input1.getName());
+ if( mcIn.dimsKnown() && (ru>mcIn.getRows() ||
cu>mcIn.getCols()) )
+ throw new DMLRuntimeException("Index range out of
bounds: "+ixrange+" "+mcIn);
+
//right indexing
if( opcode.equalsIgnoreCase(RightIndex.OPCODE) )
{
//update and check output dimensions
- DataCharacteristics mcIn =
sec.getDataCharacteristics(input1.getName());
DataCharacteristics mcOut =
sec.getDataCharacteristics(output.getName());
mcOut.set(ru-rl+1, cu-cl+1, mcIn.getBlocksize(),
mcIn.getBlocksize());
mcOut.setNonZerosBound(Math.min(mcOut.getLength(),
mcIn.getNonZerosBound()));
@@ -114,7 +118,7 @@ public class MatrixIndexingSPInstruction extends IndexingSPInstruction {
//put output RDD handle into symbol table
sec.setRDDHandleForVariable(output.getName(),
out);
- sec.addLineageRDD(output.getName(),
input1.getName());
+ sec.addLineageRDD(output.getName(),
input1.getName());
}
}
//left indexing
@@ -129,7 +133,7 @@ public class MatrixIndexingSPInstruction extends IndexingSPInstruction {
//update and check output dimensions
DataCharacteristics mcOut =
sec.getDataCharacteristics(output.getName());
- DataCharacteristics mcLeft =
ec.getDataCharacteristics(input1.getName());
+ DataCharacteristics mcLeft = mcIn;
mcOut.set(mcLeft.getRows(), mcLeft.getCols(),
mcLeft.getBlocksize(), mcLeft.getBlocksize());
checkValidOutputDimensions(mcOut);
diff --git a/src/test/java/org/apache/sysds/test/functions/indexing/UnboundedScalarRightIndexingTest.java b/src/test/java/org/apache/sysds/test/functions/indexing/UnboundedScalarRightIndexingTest.java
index 8fc9bf6d9e..d32e7865ec 100644
--- a/src/test/java/org/apache/sysds/test/functions/indexing/UnboundedScalarRightIndexingTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/indexing/UnboundedScalarRightIndexingTest.java
@@ -73,10 +73,10 @@ public class UnboundedScalarRightIndexingTest extends AutomatedTestBase
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
try {
- TestConfiguration config = getTestConfiguration(TEST_NAME);
- loadTestConfiguration(config);
-
- String RI_HOME = SCRIPT_DIR + TEST_DIR;
+ TestConfiguration config =
getTestConfiguration(TEST_NAME);
+ loadTestConfiguration(config);
+
+ String RI_HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = RI_HOME + TEST_NAME + ".dml";
programArgs = new String[]{ "-args",
String.valueOf(val) };
fullRScriptName = RI_HOME + TEST_NAME + ".R";