This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 9f41108cc498e13f03095d4ebb1b903cde9010ec
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Oct 31 21:34:27 2020 +0100

    [SYSTEMDS-2709] Fix missing federated unary aggregate for scalar mean
    
    With the fixed missing size propagation for federated init statements,
    rewrites now trigger that expose operations we don't support yet. This
    patch adds support for full mean aggregates, in addition to the
    existing row means and column means.
---
 .../controlprogram/federated/FederationUtils.java  | 16 ++++++---
 .../fed/AggregateUnaryFEDInstruction.java          |  2 +-
 .../federated/FederatedTestObjectConstructor.java  | 40 +++++++++++-----------
 3 files changed, 32 insertions(+), 26 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index c8da781..37cb7d5 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -98,7 +98,9 @@ public class FederationUtils {
                        MatrixBlock ret = null;
                        long size = 0;
                        for(int i=0; i<ffr.length; i++) {
-                               MatrixBlock tmp = 
(MatrixBlock)ffr[i].get().getData()[0];
+                               Object input = ffr[i].get().getData()[0];
+                               MatrixBlock tmp = (input instanceof 
ScalarObject) ? 
+                                       new 
MatrixBlock(((ScalarObject)input).getDoubleValue()) : (MatrixBlock) input;
                                size += ranges[i].getSize(0);
                                sop1 = sop1.setConstant(ranges[i].getSize(0));
                                tmp = tmp.scalarOperations(sop1, new 
MatrixBlock());
@@ -167,10 +169,11 @@ public class FederationUtils {
                }
        }
 
-       public static ScalarObject aggScalar(AggregateUnaryOperator aop, 
Future<FederatedResponse>[] ffr) {
+       public static ScalarObject aggScalar(AggregateUnaryOperator aop, 
Future<FederatedResponse>[] ffr, FederationMap map) {
                if(!(aop.aggOp.increOp.fn instanceof KahanFunction || 
(aop.aggOp.increOp.fn instanceof Builtin &&
-                       (((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == 
BuiltinCode.MIN ||
-                               ((Builtin) 
aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAX)))) {
+                       (((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == 
BuiltinCode.MIN
+                       || ((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == 
BuiltinCode.MAX)
+                       || aop.aggOp.increOp.fn instanceof Mean ))) {
                        throw new DMLRuntimeException("Unsupported aggregation 
operator: "
                                + aop.aggOp.increOp.getClass().getSimpleName());
                }
@@ -181,7 +184,10 @@ public class FederationUtils {
                                boolean isMin = ((Builtin) 
aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN;
                                return new DoubleObject(aggMinMax(ffr, isMin, 
true,  Optional.empty()).getValue(0,0));
                        }
-                       else {
+                       else if( aop.aggOp.increOp.fn instanceof Mean ) {
+                               return new DoubleObject(aggMean(ffr, 
map).getValue(0,0));
+                       }
+                       else { //if (aop.aggOp.increOp.fn instanceof 
KahanFunction)
                                double sum = 0; //uak+
                                for( Future<FederatedResponse> fr : ffr )
                                        sum += 
((ScalarObject)fr.get().getData()[0]).getDoubleValue();
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index 60fe40b..d06dfaa 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -66,7 +66,7 @@ public class AggregateUnaryFEDInstruction extends 
UnaryFEDInstruction {
                //execute federated commands and cleanups
                Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, 
fr2, fr3);
                if( output.isScalar() )
-                       ec.setVariable(output.getName(), 
FederationUtils.aggScalar(aop, tmp));
+                       ec.setVariable(output.getName(), 
FederationUtils.aggScalar(aop, tmp, map));
                else
                        ec.setMatrixOutput(output.getName(), 
FederationUtils.aggMatrix(aop, tmp, map));
        }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedTestObjectConstructor.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedTestObjectConstructor.java
index af55b95..a970479 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedTestObjectConstructor.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedTestObjectConstructor.java
@@ -37,26 +37,26 @@ import org.apache.sysds.runtime.meta.MetaData;
 import org.junit.Assert;
 
 public class FederatedTestObjectConstructor {
-    public static MatrixObject constructFederatedInput(int rows, int cols, int 
blocksize, String host, long[][] begin,
-        long[][] end, int[] ports, String[] inputs, String file) {
-        MatrixObject fed = new MatrixObject(ValueType.FP64, file);
-        try {
-            fed.setMetaData(new MetaData(new MatrixCharacteristics(rows, cols, 
blocksize, rows * cols)));
-            List<Pair<FederatedRange, FederatedData>> d = new ArrayList<>();
-            for(int i = 0; i < ports.length; i++) {
-                FederatedRange X1r = new FederatedRange(begin[i], end[i]);
-                FederatedData X1d = new FederatedData(Types.DataType.MATRIX,
-                    new InetSocketAddress(InetAddress.getByName(host), 
ports[i]), inputs[i]);
-                d.add(new ImmutablePair<>(X1r, X1d));
-            }
+       public static MatrixObject constructFederatedInput(int rows, int cols, 
int blocksize, String host, long[][] begin,
+               long[][] end, int[] ports, String[] inputs, String file) {
+               MatrixObject fed = new MatrixObject(ValueType.FP64, file);
+               try {
+                       fed.setMetaData(new MetaData(new 
MatrixCharacteristics(rows, cols, blocksize, rows * cols)));
+                       List<Pair<FederatedRange, FederatedData>> d = new 
ArrayList<>();
+                       for(int i = 0; i < ports.length; i++) {
+                               FederatedRange X1r = new 
FederatedRange(begin[i], end[i]);
+                               FederatedData X1d = new 
FederatedData(Types.DataType.MATRIX,
+                                       new 
InetSocketAddress(InetAddress.getByName(host), ports[i]), inputs[i]);
+                               d.add(new ImmutablePair<>(X1r, X1d));
+                       }
 
-            InitFEDInstruction.federateMatrix(fed, d);
-        }
-        catch(Exception e) {
-            e.printStackTrace();
-            Assert.assertTrue(false);
-        }
-        return fed;
+                       InitFEDInstruction.federateMatrix(fed, d);
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       Assert.assertTrue(false);
+               }
+               return fed;
 
-    }
+       }
 }

Reply via email to