This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 998d82e27b8add5a0ca55ac687f0bfd9abe54c8b
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Oct 31 20:25:59 2020 +0100

    [SYSTEMDS-2549] Extended federated binary element-wise operations
    
    This patch generalizes the existing federated binary element-wise
    operations to cover scenarios that previously raised an error.
    Specifically, if the right-hand-side matrix (instead of the
    left-hand-side matrix) is federated, both inputs have equal known
    dimensions, and the operation is commutative (e.g., mult/add), we
    canonicalize the inputs by swapping them so that the federated
    matrix becomes the left-hand side.
---
 .../fed/BinaryMatrixMatrixFEDInstruction.java           | 17 +++++++++++++----
 .../sysds/runtime/matrix/operators/BinaryOperator.java  |  7 +++++++
 .../apache/sysds/runtime/meta/DataCharacteristics.java  |  4 +++-
 .../sysds/runtime/meta/MatrixCharacteristics.java       | 12 +++++++++++-
 .../sysds/runtime/meta/TensorCharacteristics.java       |  9 +++++++++
 .../federated/algorithms/FederatedGLMTest.java          |  2 +-
 .../federated/algorithms/FederatedKmeansTest.java       |  4 +++-
 7 files changed, 47 insertions(+), 8 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index bceb6ae..ea34df1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -25,6 +25,7 @@ import 
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 
 public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
@@ -39,8 +40,16 @@ public class BinaryMatrixMatrixFEDInstruction extends 
BinaryFEDInstruction
                MatrixObject mo1 = ec.getMatrixObject(input1);
                MatrixObject mo2 = ec.getMatrixObject(input2);
                
+               //canonicalization for federated lhs
+               if( !mo1.isFederated() && mo2.isFederated() 
+                       && 
mo1.getDataCharacteristics().equalDims(mo2.getDataCharacteristics()) 
+                       && ((BinaryOperator)_optr).isCommutative() ) {
+                       mo1 = ec.getMatrixObject(input2);
+                       mo2 = ec.getMatrixObject(input1);
+               }
+               
+               //execute federated operation on mo1 or mo2
                FederatedRequest fr2 = null;
-
                if( mo2.isFederated() ) {
                        if(mo1.isFederated() && 
mo1.getFedMapping().isAligned(mo2.getFedMapping(), false)) {
                                fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, 
input2},
@@ -48,12 +57,12 @@ public class BinaryMatrixMatrixFEDInstruction extends 
BinaryFEDInstruction
                                mo1.getFedMapping().execute(getTID(), true, 
fr2);
                        }
                        else {
-                               throw new DMLRuntimeException("Matrix-matrix 
binary operations "
-                                       + " with a federated right input are 
not supported yet.");
+                               throw new DMLRuntimeException("Matrix-matrix 
binary operations with a "
+                                       + "federated right input are only 
supported for special cases yet.");
                        }
                }
                else {
-                       //matrix-matrix binary oFederatedRequest fr2 = 
null;perations -> lhs fed input -> fed output
+                       //matrix-matrix binary operations -> lhs fed input -> 
fed output
                        if(mo2.getNumRows() > 1 && mo2.getNumColumns() == 1 ) { 
//MV row vector
                                FederatedRequest[] fr1 = 
mo1.getFedMapping().broadcastSliced(mo2, false);
                                fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, 
input2},
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
index beca629..bc4cdd0 100644
--- 
a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
@@ -56,6 +56,7 @@ public class BinaryOperator  extends Operator implements 
Serializable
        private static final long serialVersionUID = -2547950181558989209L;
 
        public final ValueFunction fn;
+       public final boolean commutative;
        
        public BinaryOperator(ValueFunction p) {
                //binaryop is sparse-safe iff (0 op 0) == 0
@@ -65,6 +66,8 @@ public class BinaryOperator  extends Operator implements 
Serializable
                        || p instanceof BitwAnd || p instanceof BitwOr || p 
instanceof BitwXor
                        || p instanceof BitwShiftL || p instanceof BitwShiftR);
                fn = p;
+               commutative = p instanceof Plus || p instanceof Multiply 
+                       || p instanceof And || p instanceof Or || p instanceof 
Xor;
        }
        
        /**
@@ -111,6 +114,10 @@ public class BinaryOperator  extends Operator implements 
Serializable
                return null;
        }
        
+       public boolean isCommutative() {
+               return commutative;
+       }
+       
        @Override
        public String toString() {
                return "BinaryOperator("+fn.getClass().getSimpleName()+")";
diff --git 
a/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java 
b/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java
index d71ce9d..a28d98d 100644
--- a/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java
+++ b/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java
@@ -188,9 +188,11 @@ public abstract class DataCharacteristics implements 
Serializable {
                dimOut.set(dim1.getRows(), dim2.getCols(), dim1.getBlocksize());
        }
 
+       public abstract boolean equalDims(Object anObject);
+
        @Override
        public abstract boolean equals(Object anObject);
-
+       
        @Override
        public abstract int hashCode();
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/meta/MatrixCharacteristics.java 
b/src/main/java/org/apache/sysds/runtime/meta/MatrixCharacteristics.java
index 0b29cce..bdc4b21 100644
--- a/src/main/java/org/apache/sysds/runtime/meta/MatrixCharacteristics.java
+++ b/src/main/java/org/apache/sysds/runtime/meta/MatrixCharacteristics.java
@@ -229,7 +229,17 @@ public class MatrixCharacteristics extends 
DataCharacteristics
                return !nnzKnown() || numRows==0 || numColumns==0
                        || (nonZero < numRows*numColumns - singleBlk);
        }
-
+       
+       @Override
+       public boolean equalDims(Object anObject) {
+               if( !(anObject instanceof MatrixCharacteristics) )
+                       return false;
+               MatrixCharacteristics mc = (MatrixCharacteristics) anObject;
+               return dimsKnown() && mc.dimsKnown()
+                       && numRows == mc.numRows
+                       && numColumns == mc.numColumns;
+       }
+       
        @Override
        public boolean equals (Object anObject) {
                if( !(anObject instanceof MatrixCharacteristics) )
diff --git 
a/src/main/java/org/apache/sysds/runtime/meta/TensorCharacteristics.java 
b/src/main/java/org/apache/sysds/runtime/meta/TensorCharacteristics.java
index 449cc2d..2b554a2 100644
--- a/src/main/java/org/apache/sysds/runtime/meta/TensorCharacteristics.java
+++ b/src/main/java/org/apache/sysds/runtime/meta/TensorCharacteristics.java
@@ -157,6 +157,15 @@ public class TensorCharacteristics extends 
DataCharacteristics
        }
        
        @Override
+       public boolean equalDims(Object anObject) {
+               if( !(anObject instanceof TensorCharacteristics) )
+                       return false;
+               TensorCharacteristics tc = (TensorCharacteristics) anObject;
+               return dimsKnown() && tc.dimsKnown()
+                       && Arrays.equals(_dims, tc._dims);
+       }
+       
+       @Override
        public boolean equals (Object anObject) {
                if( !(anObject instanceof TensorCharacteristics) )
                        return false;
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
index 2b9d287..44de28f 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
@@ -123,7 +123,7 @@ public class FederatedGLMTest extends AutomatedTestBase {
                Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
                
Assert.assertTrue(heavyHittersContainsString("fed_uark+","fed_uarsqk+"));
                Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
-               Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
+               //Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
                Assert.assertTrue(heavyHittersContainsString("fed_mmchain"));
                
                //check that federated input files are still existing
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
index eb70a4b..0dd339f 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
@@ -128,8 +128,10 @@ public class FederatedKmeansTest extends AutomatedTestBase 
{
                
                        // check for federated operations
                        
Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
-                       
Assert.assertTrue(heavyHittersContainsString("fed_uasqk+"));
+                       
//Assert.assertTrue(heavyHittersContainsString("fed_uasqk+"));
                        
Assert.assertTrue(heavyHittersContainsString("fed_uarmin"));
+                       
Assert.assertTrue(heavyHittersContainsString("fed_uark+"));
+                       
Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
                        Assert.assertTrue(heavyHittersContainsString("fed_*"));
                        Assert.assertTrue(heavyHittersContainsString("fed_+"));
                        Assert.assertTrue(heavyHittersContainsString("fed_<="));

Reply via email to