This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new f33be5b1d9 [MINOR] Support Non-literals in Federated Reshape Instructions
f33be5b1d9 is described below
commit f33be5b1d9100e781dfa8f0ebf63390817b606bb
Author: ywcb00 <[email protected]>
AuthorDate: Sat Jul 15 15:32:56 2023 +0200
[MINOR] Support Non-literals in Federated Reshape Instructions
AMLS project SoSe'23, part I
Closes #1862.
---
.../instructions/fed/ReshapeFEDInstruction.java | 18 ++++++++++--------
.../functions/federated/io/FederatedReaderTest.java | 12 +++++-------
.../federated/primitives/FederatedMisAlignedTest.java | 8 ++++----
.../functions/federated/FederatedReshapeTest.dml | 7 ++++++-
4 files changed, 25 insertions(+), 20 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java
index 3d355cd1dd..521dbe8e51 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions.fed;
import java.util.Arrays;
import java.util.stream.Collectors;
+import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.lops.Lop;
@@ -119,7 +120,7 @@ public class ReshapeFEDInstruction extends UnaryFEDInstruction {
mo1.getFedMapping().execute(getTID(), true, fr1, new
FederatedRequest[0]);
// set new fed map
- FederationMap reshapedFedMap = mo1.getFedMapping();
+ FederationMap reshapedFedMap =
mo1.getFedMapping().copyWithNewID(fr1[0].getID());
for(int i = 0; i <
reshapedFedMap.getFederatedRanges().length; i++) {
long cells =
reshapedFedMap.getFederatedRanges()[i].getSize();
long row = byRow.getBooleanValue() ? cells /
cols : rows;
@@ -140,7 +141,7 @@ public class ReshapeFEDInstruction extends UnaryFEDInstruction {
//derive output federated mapping
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(rows, cols, (int)
mo1.getBlocksize(), mo1.getNnz());
-
out.setFedMapping(reshapedFedMap.copyWithNewID(fr1[0].getID()));
+ out.setFedMapping(reshapedFedMap);
}
else {
// TODO support tensor out, frame and list
@@ -156,14 +157,15 @@ public class ReshapeFEDInstruction extends UnaryFEDInstruction {
.collect(Collectors.toSet()).size();
sameFedSize = sameFedSize == 1 ? 1 :
mo1.getFedMapping().getSize();
+ String execTypeName =
InstructionUtils.getExecType(instString).name();
+ String[] instParts =
InstructionUtils.getInstructionPartsWithValueType(instString);
for(int i = 0; i < sameFedSize; i++) {
- String[] instParts =
instString.split(Lop.OPERAND_DELIMITOR);
long size =
mo1.getFedMapping().getFederatedRanges()[i].getSize();
- String oldInstStringPart = byRow ? instParts[3] :
instParts[4];
- String newInstStringPart = byRow ?
- oldInstStringPart.replace(String.valueOf(rows),
String.valueOf(size/cols)) :
- oldInstStringPart.replace(String.valueOf(cols),
String.valueOf(size/rows));
- instStrings[i] = instString.replace(oldInstStringPart,
newInstStringPart);
+ instParts[2] = InstructionUtils.createLiteralOperand(
+ String.valueOf((int)(byRow ? size/cols :
rows)), Types.ValueType.INT64);
+ instParts[3] = InstructionUtils.createLiteralOperand(
+ String.valueOf((int)(byRow ? cols :
size/rows)), Types.ValueType.INT64);
+ instStrings[i] =
InstructionUtils.concatOperands(ArrayUtils.addFirst(instParts, execTypeName));
}
if(sameFedSize == 1)
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
index ff68c8328e..295fe54770 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
@@ -40,7 +40,7 @@ import org.junit.runners.Parameterized;
public class FederatedReaderTest extends AutomatedTestBase {
private static final Log LOG =
LogFactory.getLog(FederatedReaderTest.class.getName());
- private final static String TEST_DIR = "functions/federated/ioR/";
+ private final static String TEST_DIR = "functions/federated/io/";
private final static String TEST_NAME = "FederatedReaderTest";
private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedReaderTest.class.getSimpleName() + "/";
private final static int blocksize = 1024;
@@ -50,8 +50,6 @@ public class FederatedReaderTest extends AutomatedTestBase {
public int cols;
@Parameterized.Parameter(2)
public boolean rowPartitioned;
- @Parameterized.Parameter(3)
- public int fedCount;
@Override
public void setUp() {
@@ -62,7 +60,7 @@ public class FederatedReaderTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
// number of rows or cols has to be >= number of federated
locations.
- return Arrays.asList(new Object[][] {{10, 13, true, 2}});
+ return Arrays.asList(new Object[][] {{10, 13, true}});
}
@Test
@@ -111,11 +109,11 @@ public class FederatedReaderTest extends AutomatedTestBase {
// Run reference dml script with normal matrix
if(workerCount == 1) {
- fullDMLScriptName = SCRIPT_DIR +
"functions/federated/io/" + TEST_NAME + "1Reference.dml";
+ fullDMLScriptName = SCRIPT_DIR + TEST_DIR +
TEST_NAME + "1Reference.dml";
programArgs = new String[] {"-stats", "-args",
input("X1")};
}
else {
- fullDMLScriptName = SCRIPT_DIR +
"functions/federated/io/" + TEST_NAME
+ fullDMLScriptName = SCRIPT_DIR + TEST_DIR +
TEST_NAME
+ (rowPartitioned ? "Row" : "Col") +
"2Reference.dml";
programArgs = new String[] {"-stats", "-args",
input("X1"), input("X2")};
}
@@ -125,7 +123,7 @@ public class FederatedReaderTest extends AutomatedTestBase {
LOG.debug(refOut);
// Run federated
- fullDMLScriptName = SCRIPT_DIR +
"functions/federated/io/" + TEST_NAME + ".dml";
+ fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME +
".dml";
programArgs = new String[] {"-stats", "-args",
input("X.json")};
String out = runTest(null).toString();
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMisAlignedTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMisAlignedTest.java
index ecc8a7b90f..5b4b350b08 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMisAlignedTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMisAlignedTest.java
@@ -205,10 +205,10 @@ public class FederatedMisAlignedTest extends AutomatedTestBase {
c = cols;
}
- double[][] X1 = getRandomMatrix(r, c, 3, 3, 1, 3);
- double[][] X2 = getRandomMatrix(r, c, 3, 3, 1, 7);
- double[][] X3 = getRandomMatrix(r, c, 3, 3, 1, 8);
- double[][] X4 = getRandomMatrix(r, c, 3, 3, 1, 9);
+ double[][] X1 = getRandomMatrix(r, c, 3, 4, 1, 3);
+ double[][] X2 = getRandomMatrix(r, c, 3, 4, 1, 7);
+ double[][] X3 = getRandomMatrix(r, c, 3, 4, 1, 8);
+ double[][] X4 = getRandomMatrix(r, c, 3, 4, 1, 9);
MatrixCharacteristics mc = new MatrixCharacteristics(r, c,
blocksize, r * c);
writeInputMatrixWithMTD("X1", X1, false, mc);
diff --git a/src/test/scripts/functions/federated/FederatedReshapeTest.dml
b/src/test/scripts/functions/federated/FederatedReshapeTest.dml
index 6aa8a165b5..f133bcff17 100644
--- a/src/test/scripts/functions/federated/FederatedReshapeTest.dml
+++ b/src/test/scripts/functions/federated/FederatedReshapeTest.dml
@@ -27,5 +27,10 @@ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
ranges=list(list(0, 0), list(2, 12), list(2, 0), list(4, $cols),
list(4, 0), list(10, $cols), list(10, 0), list(12, $cols)));
-s = matrix(A, rows=$r_rows, cols=$r_cols);
+# materialize the scalar input (non-literal)
+reshape_cols = $r_cols;
+while(FALSE) {}
+reshape_cols = reshape_cols;
+
+s = matrix(A, rows=$r_rows, cols=reshape_cols);
write(s, $out_S);