This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 9f41108cc498e13f03095d4ebb1b903cde9010ec Author: Matthias Boehm <[email protected]> AuthorDate: Sat Oct 31 21:34:27 2020 +0100 [SYSTEMDS-2709] Fix missing federated unary aggregate for scalar mean With the fixed missing size propagation for federated init statements, now rewrites trigger, which expose operations we don't support yet. This patch adds, besides the existing row means and columns means, also support for full mean aggregates. --- .../controlprogram/federated/FederationUtils.java | 16 ++++++--- .../fed/AggregateUnaryFEDInstruction.java | 2 +- .../federated/FederatedTestObjectConstructor.java | 40 +++++++++++----------- 3 files changed, 32 insertions(+), 26 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java index c8da781..37cb7d5 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java @@ -98,7 +98,9 @@ public class FederationUtils { MatrixBlock ret = null; long size = 0; for(int i=0; i<ffr.length; i++) { - MatrixBlock tmp = (MatrixBlock)ffr[i].get().getData()[0]; + Object input = ffr[i].get().getData()[0]; + MatrixBlock tmp = (input instanceof ScalarObject) ? + new MatrixBlock(((ScalarObject)input).getDoubleValue()) : (MatrixBlock) input; size += ranges[i].getSize(0); sop1 = sop1.setConstant(ranges[i].getSize(0)); tmp = tmp.scalarOperations(sop1, new MatrixBlock()); @@ -167,10 +169,11 @@ public class FederationUtils { } } - public static ScalarObject aggScalar(AggregateUnaryOperator aop, Future<FederatedResponse>[] ffr) { + public static ScalarObject aggScalar(AggregateUnaryOperator aop, Future<FederatedResponse>[] ffr, FederationMap map) { if(!(aop.aggOp.increOp.fn instanceof KahanFunction || (aop.aggOp.increOp.fn instanceof Builtin && - (((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN || - ((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAX)))) { + (((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN + || ((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAX) + || aop.aggOp.increOp.fn instanceof Mean ))) { throw new DMLRuntimeException("Unsupported aggregation operator: " + aop.aggOp.increOp.getClass().getSimpleName()); } @@ -181,7 +184,10 @@ public class FederationUtils { boolean isMin = ((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN; return new DoubleObject(aggMinMax(ffr, isMin, true, Optional.empty()).getValue(0,0)); } - else { + else if( aop.aggOp.increOp.fn instanceof Mean ) { + return new DoubleObject(aggMean(ffr, map).getValue(0,0)); + } + else { //if (aop.aggOp.increOp.fn instanceof KahanFunction) double sum = 0; //uak+ for( Future<FederatedResponse> fr : ffr ) sum += ((ScalarObject)fr.get().getData()[0]).getDoubleValue(); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java index 60fe40b..d06dfaa 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java @@ -66,7 +66,7 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction { //execute federated commands and cleanups Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, fr2, fr3); if( output.isScalar() ) - ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, tmp)); + ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, tmp, map)); else ec.setMatrixOutput(output.getName(), FederationUtils.aggMatrix(aop, tmp, map)); } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedTestObjectConstructor.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedTestObjectConstructor.java index af55b95..a970479 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/FederatedTestObjectConstructor.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedTestObjectConstructor.java @@ -37,26 +37,26 @@ import org.apache.sysds.runtime.meta.MetaData; import org.junit.Assert; public class FederatedTestObjectConstructor { - public static MatrixObject constructFederatedInput(int rows, int cols, int blocksize, String host, long[][] begin, - long[][] end, int[] ports, String[] inputs, String file) { - MatrixObject fed = new MatrixObject(ValueType.FP64, file); - try { - fed.setMetaData(new MetaData(new MatrixCharacteristics(rows, cols, blocksize, rows * cols))); - List<Pair<FederatedRange, FederatedData>> d = new ArrayList<>(); - for(int i = 0; i < ports.length; i++) { - FederatedRange X1r = new FederatedRange(begin[i], end[i]); - FederatedData X1d = new FederatedData(Types.DataType.MATRIX, - new InetSocketAddress(InetAddress.getByName(host), ports[i]), inputs[i]); - d.add(new ImmutablePair<>(X1r, X1d)); - } + public static MatrixObject constructFederatedInput(int rows, int cols, int blocksize, String host, long[][] begin, + long[][] end, int[] ports, String[] inputs, String file) { + MatrixObject fed = new MatrixObject(ValueType.FP64, file); + try { + fed.setMetaData(new MetaData(new MatrixCharacteristics(rows, cols, blocksize, rows * cols))); + List<Pair<FederatedRange, FederatedData>> d = new ArrayList<>(); + for(int i = 0; i < ports.length; i++) { + FederatedRange X1r = new FederatedRange(begin[i], end[i]); + FederatedData X1d = new FederatedData(Types.DataType.MATRIX, + new InetSocketAddress(InetAddress.getByName(host), ports[i]), inputs[i]); + d.add(new ImmutablePair<>(X1r, X1d)); + } - InitFEDInstruction.federateMatrix(fed, d); - } - catch(Exception e) { - e.printStackTrace(); - Assert.assertTrue(false); - } - return fed; + InitFEDInstruction.federateMatrix(fed, d); + } + catch(Exception e) { + e.printStackTrace(); + Assert.assertTrue(false); + } + return fed; - } + } }
