This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
commit d61c3bffc677443c28d0fca27364267c1ca41111 Author: baunsgaard <[email protected]> AuthorDate: Thu Nov 12 18:13:40 2020 +0100 [SYSTEMDS-2724] Cast to matrix Federated Closes #1100 --- .../instructions/fed/FEDInstructionUtils.java | 26 ++-- .../instructions/fed/VariableFEDInstruction.java | 65 ++++++--- .../primitives/FederetedCastToFrameTest.java | 4 +- .../primitives/FederetedCastToMatrixTest.java | 160 +++++++++++++++++++++ .../test/functions/frame/DetectSchemaTest.java | 6 +- .../test/functions/lineage/CacheEvictionTest.java | 1 + .../primitives/FederatedCastToMatrixTest.dml | 26 ++++ .../FederatedCastToMatrixTestReference.dml | 25 ++++ 8 files changed, 274 insertions(+), 39 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java index 68e1cee..ef66b66 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java @@ -36,8 +36,6 @@ import org.apache.sysds.runtime.instructions.cp.MatrixIndexingCPInstruction; import org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction; import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction; import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction; -import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction; -import org.apache.sysds.runtime.instructions.cp.UnaryMatrixCPInstruction; import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction; import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction.VariableOperationCode; import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction; @@ -84,15 +82,13 @@ public class FEDInstructionUtils { if( mo.isFederated() ) fedinst = TsmmFEDInstruction.parseInstruction(linst.getInstructionString()); } - else if(inst instanceof UnaryCPInstruction){ - if (inst instanceof AggregateUnaryCPInstruction) { - AggregateUnaryCPInstruction instruction = (AggregateUnaryCPInstruction) inst; - if( instruction.input1.isMatrix() && ec.containsVariable(instruction.input1) ) { - MatrixObject mo1 = ec.getMatrixObject(instruction.input1); - if (mo1.isFederated() && instruction.getAUType() == AggregateUnaryCPInstruction.AUType.DEFAULT){ - LOG.debug("Federated UnaryAggregate"); - fedinst = AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString()); - } + else if (inst instanceof AggregateUnaryCPInstruction) { + AggregateUnaryCPInstruction instruction = (AggregateUnaryCPInstruction) inst; + if( instruction.input1.isMatrix() && ec.containsVariable(instruction.input1) ) { + MatrixObject mo1 = ec.getMatrixObject(instruction.input1); + if (mo1.isFederated() && instruction.getAUType() == AggregateUnaryCPInstruction.AUType.DEFAULT){ + LOG.debug("Federated UnaryAggregate"); + fedinst = AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString()); } } } @@ -154,11 +150,13 @@ public class FEDInstructionUtils { && ec.getCacheableData(ins.getInput1()).isFederated()){ fedinst = VariableFEDInstruction.parseInstruction(ins); } - + else if(ins.getVariableOpcode() == VariableOperationCode.CastAsMatrixVariable + && ins.getInput1().isFrame() + && ec.getCacheableData(ins.getInput1()).isFederated()){ + fedinst = VariableFEDInstruction.parseInstruction(ins); + } } - - //set thread id for federated context management if( fedinst != null ) { fedinst.setTID(ec.getTID()); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java index 7d39e9d..134a2e3 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java @@ -87,37 +87,62 @@ public class VariableFEDInstruction extends FEDInstruction implements LineageTra _in.processInstruction(ec); } - private void processCastAsMatrixVariableInstruction(ExecutionContext ec){ - LOG.error("Not Implemented"); - throw new DMLRuntimeException("Not Implemented Cast as Matrix"); + private void processCastAsMatrixVariableInstruction(ExecutionContext ec) { + FrameObject mo1 = ec.getFrameObject(_in.getInput1()); + + if(!mo1.isFederated()) + throw new DMLRuntimeException( + "Federated Reorg: " + "Federated input expected, but invoked w/ " + mo1.isFederated()); + + // execute function at federated site. + FederatedRequest fr1 = FederationUtils.callInstruction(_in.getInstructionString(), + _in.getOutput(), + new CPOperand[] {_in.getInput1()}, + new long[] {mo1.getFedMapping().getID()}); + mo1.getFedMapping().execute(getTID(), true, fr1); + + // Construct output local. + + MatrixObject out = ec.getMatrixObject(_in.getOutput()); + FederationMap outMap = mo1.getFedMapping().copyWithNewID(fr1.getID()); + Map<FederatedRange, FederatedData> newMap = new HashMap<>(); + for(Map.Entry<FederatedRange, FederatedData> pair : outMap.getFedMapping().entrySet()) { + FederatedData om = pair.getValue(); + FederatedData nf = new FederatedData(Types.DataType.MATRIX, om.getAddress(), om.getFilepath(), + om.getVarID()); + newMap.put(pair.getKey(), nf); + } + out.setFedMapping(outMap); } - private void processCastAsFrameVariableInstruction(ExecutionContext ec){ + private void processCastAsFrameVariableInstruction(ExecutionContext ec) { MatrixObject mo1 = ec.getMatrixObject(_in.getInput1()); - - if( !mo1.isFederated() ) - throw new DMLRuntimeException("Federated Reorg: " - + "Federated input expected, but invoked w/ "+mo1.isFederated()); - - //execute transpose at federated site - FederatedRequest fr1 = FederationUtils.callInstruction(_in.getInstructionString(), _in.getOutput(), - new CPOperand[]{_in.getInput1()}, new long[]{mo1.getFedMapping().getID()}); + + if(!mo1.isFederated()) + throw new DMLRuntimeException( + "Federated Reorg: " + "Federated input expected, but invoked w/ " + mo1.isFederated()); + + // execute function at federated site. + FederatedRequest fr1 = FederationUtils.callInstruction(_in.getInstructionString(), + _in.getOutput(), + new CPOperand[] {_in.getInput1()}, + new long[] {mo1.getFedMapping().getID()}); mo1.getFedMapping().execute(getTID(), true, fr1); - - //drive output federated mapping + + // Construct output local. FrameObject out = ec.getFrameObject(_in.getOutput()); - out.getDataCharacteristics().set(mo1.getNumColumns(), - mo1.getNumRows(), (int)mo1.getBlocksize(), mo1.getNnz()); - FederationMap outMap = mo1.getFedMapping().copyWithNewID(fr1.getID()); + out.getDataCharacteristics().set(mo1.getNumColumns(), mo1.getNumRows(), (int) mo1.getBlocksize(), mo1.getNnz()); + FederationMap outMap = mo1.getFedMapping().copyWithNewID(fr1.getID()); Map<FederatedRange, FederatedData> newMap = new HashMap<>(); - for(Map.Entry<FederatedRange, FederatedData> pair : outMap.getFedMapping().entrySet()){ + for(Map.Entry<FederatedRange, FederatedData> pair : outMap.getFedMapping().entrySet()) { FederatedData om = pair.getValue(); - FederatedData nf = new FederatedData(Types.DataType.FRAME, om.getAddress(),om.getFilepath(),om.getVarID()); + FederatedData nf = new FederatedData(Types.DataType.FRAME, om.getAddress(), om.getFilepath(), + om.getVarID()); newMap.put(pair.getKey(), nf); } - ValueType[] schema = new ValueType[(int)mo1.getDataCharacteristics().getCols()]; + ValueType[] schema = new ValueType[(int) mo1.getDataCharacteristics().getCols()]; Arrays.fill(schema, ValueType.FP64); out.setSchema(schema); out.setFedMapping(outMap); diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToFrameTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToFrameTest.java index bbef96e..5e05bf5 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToFrameTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToFrameTest.java @@ -114,12 +114,12 @@ public class FederetedCastToFrameTest extends AutomatedTestBase { "X2=" + TestUtils.federatedAddress(port2, input("X2")), "r=" + rows, "c=" + cols}; String fedOut = runTest(null).toString(); - LOG.error(fedOut); + LOG.debug(fedOut); fedOut = fedOut.split("SystemDS Statistics:")[0]; Assert.assertTrue("Equal Printed Output", out.equals(fedOut)); Assert.assertTrue("Contains federated Cast to frame", heavyHittersContainsString("fed_castdtf")); TestUtils.shutdownThreads(t1, t2); - + rtplatform = platformOld; DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToMatrixTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToMatrixTest.java new file mode 100644 index 0000000..b075e47 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToMatrixTest.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.federated.primitives; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types; +import org.apache.sysds.common.Types.DataType; +import org.apache.sysds.common.Types.FileFormat; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.io.FrameWriter; +import org.apache.sysds.runtime.io.FrameWriterFactory; +import org.apache.sysds.runtime.matrix.data.FrameBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.apache.sysds.test.functions.frame.DetectSchemaTest; +import org.junit.Assert; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(value = Parameterized.class) [email protected] +public class FederetedCastToMatrixTest extends AutomatedTestBase { + private static final Log LOG = LogFactory.getLog(FederetedCastToMatrixTest.class.getName()); + + private final static String TEST_DIR = "functions/federated/primitives/"; + private final static String TEST_NAME = "FederatedCastToMatrixTest"; + private final static String TEST_CLASS_DIR = TEST_DIR + FederetedCastToMatrixTest.class.getSimpleName() + "/"; + + @Parameterized.Parameter() + public int rows; + @Parameterized.Parameter(1) + public int cols; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME)); + } + + @Parameterized.Parameters + public static Collection<Object[]> data() { + // rows have to be even and > 1 + return Arrays.asList(new Object[][] {{10, 32}}); + } + + @Test + public void federatedMultiplyCP() { + federatedMultiply(Types.ExecMode.SINGLE_NODE); + } + + @Test + @Ignore + public void federatedMultiplySP() { + // TODO Fix me Spark execution error + federatedMultiply(Types.ExecMode.SPARK); + } + + public void federatedMultiply(Types.ExecMode execMode) { + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + Types.ExecMode platformOld = rtplatform; + rtplatform = execMode; + if(rtplatform == Types.ExecMode.SPARK) { + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + } + try { + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + + ValueType[] schema = new ValueType[cols]; + Arrays.fill(schema, ValueType.FP64); + FrameBlock frame1 = new FrameBlock(schema); + FrameBlock frame2 = new FrameBlock(schema); + FrameWriter writer = FrameWriterFactory.createFrameWriter(FileFormat.BINARY); + + // write input matrices + int halfRows = rows / 2; + // We have two matrices handled by a single federated worker + double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42); + double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340); + + DetectSchemaTest.initFrameDataString(frame1, X1, schema, halfRows, cols); + writer.writeFrameToHDFS(frame1.slice(0, halfRows - 1, 0, schema.length - 1, new FrameBlock()), + input("X1"), + halfRows, + schema.length); + + DetectSchemaTest.initFrameDataString(frame2, X2, schema, halfRows, cols); + writer.writeFrameToHDFS(frame2.slice(0, halfRows - 1, 0, schema.length - 1, new FrameBlock()), + input("X2"), + halfRows, + schema.length); + + MatrixCharacteristics mc = new MatrixCharacteristics(X1.length, X1[0].length, + OptimizerUtils.DEFAULT_BLOCKSIZE, -1); + HDFSTool.writeMetaDataFile(input("X1") + ".mtd", null, schema, DataType.FRAME, mc, FileFormat.BINARY); + HDFSTool.writeMetaDataFile(input("X2") + ".mtd", null, schema, DataType.FRAME, mc, FileFormat.BINARY); + + int port1 = getRandomAvailablePort(); + int port2 = getRandomAvailablePort(); + Thread t1 = startLocalFedWorkerThread(port1); + Thread t2 = startLocalFedWorkerThread(port2); + + TestConfiguration config = availableTestConfigurations.get(TEST_NAME); + loadTestConfiguration(config); + + // Run reference dml script with normal matrix + fullDMLScriptName = HOME + TEST_NAME + "Reference.dml"; + programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2")}; + String out = runTest(null).toString().split("SystemDS Statistics:")[0]; + + // Run actual dml script with federated matrix + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-stats", "100", "-nvargs", + "X1=" + TestUtils.federatedAddress(port1, input("X1")), + "X2=" + TestUtils.federatedAddress(port2, input("X2")), "r=" + rows, "c=" + cols}; + String fedOut = runTest(null).toString(); + + LOG.debug(fedOut); + fedOut = fedOut.split("SystemDS Statistics:")[0]; + Assert.assertTrue("Equal Printed Output", out.equals(fedOut)); + Assert.assertTrue("Contains federated Cast to frame", heavyHittersContainsString("fed_castdtm")); + TestUtils.shutdownThreads(t1, t2); + + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + catch(IOException e) { + Assert.fail("Error writing input frame."); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/frame/DetectSchemaTest.java b/src/test/java/org/apache/sysds/test/functions/frame/DetectSchemaTest.java index 67d5626..69d3dc5 100644 --- a/src/test/java/org/apache/sysds/test/functions/frame/DetectSchemaTest.java +++ b/src/test/java/org/apache/sysds/test/functions/frame/DetectSchemaTest.java @@ -117,7 +117,7 @@ public class DetectSchemaTest extends AutomatedTestBase { } else { double[][] A = getRandomMatrix(rows, 3, -Float.MAX_VALUE, Float.MAX_VALUE, 0.7, 2373); - initFrameDataString(frame1, A, schema); + initFrameDataString(frame1, A, schema, rows, 3); writer.writeFrameToHDFS(frame1.slice(0, rows-1, 0, schema.length-1, new FrameBlock()), input("A"), rows, schema.length); schema[schema.length-2] = Types.ValueType.FP64; } @@ -143,8 +143,8 @@ public class DetectSchemaTest extends AutomatedTestBase { } } - private static void initFrameDataString(FrameBlock frame1, double[][] data, Types.ValueType[] lschema) { - for (int j = 0; j < 3; j++) { + public static void initFrameDataString(FrameBlock frame1, double[][] data, Types.ValueType[] lschema, int rows, int cols) { + for (int j = 0; j < cols; j++) { Types.ValueType vt = lschema[j]; switch (vt) { case STRING: diff --git a/src/test/java/org/apache/sysds/test/functions/lineage/CacheEvictionTest.java b/src/test/java/org/apache/sysds/test/functions/lineage/CacheEvictionTest.java index 4f4d4a7..ac23d2e 100644 --- a/src/test/java/org/apache/sysds/test/functions/lineage/CacheEvictionTest.java +++ b/src/test/java/org/apache/sysds/test/functions/lineage/CacheEvictionTest.java @@ -17,6 +17,7 @@ * under the License. */ + package org.apache.sysds.test.functions.lineage; import java.util.ArrayList; diff --git a/src/test/scripts/functions/federated/primitives/FederatedCastToMatrixTest.dml b/src/test/scripts/functions/federated/primitives/FederatedCastToMatrixTest.dml new file mode 100644 index 0000000..52b9889 --- /dev/null +++ b/src/test/scripts/functions/federated/primitives/FederatedCastToMatrixTest.dml @@ -0,0 +1,26 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = federated(type="frame", addresses=list($X1, $X2), + ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c))) + +Z = as.matrix(X) +print(toString(Z[1])) diff --git a/src/test/scripts/functions/federated/primitives/FederatedCastToMatrixTestReference.dml b/src/test/scripts/functions/federated/primitives/FederatedCastToMatrixTestReference.dml new file mode 100644 index 0000000..a0db27d --- /dev/null +++ b/src/test/scripts/functions/federated/primitives/FederatedCastToMatrixTestReference.dml @@ -0,0 +1,25 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = rbind(read($X1), read($X2)) + +Z = as.matrix(X) +print(toString(Z[1]))
