This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new c83d2fc  [SYSTEMDS-2857] Add missing federated unary matrix operations
c83d2fc is described below

commit c83d2fcaed0c858b450c0fd730c494a2db2bd723
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Feb 11 17:49:02 2021 +0100

    [SYSTEMDS-2857] Add missing federated unary matrix operations
    
    This patch adds support for unary matrix operations such as isNaN,
    round, ceil, floor, etc. So far, we only supported specific subclasses
    of unary instructions but not the general case. The mapping into
    federated operations is simple: it executes the given operation on
    each partition independently, which assumes that the federated
    partitions cover the entire matrix.
---
 .../fed/AggregateUnaryFEDInstruction.java          | 15 +++--
 .../runtime/instructions/fed/FEDInstruction.java   |  3 +-
 .../instructions/fed/FEDInstructionUtils.java      | 15 +++--
 .../fed/UnaryMatrixFEDInstruction.java             | 65 ++++++++++++++++++++++
 .../org/apache/sysds/test/AutomatedTestBase.java   |  2 +-
 .../federated/algorithms/FederatedLogRegTest.java  | 11 ++--
 6 files changed, 92 insertions(+), 19 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index 097b678..5745ccd 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -36,18 +36,21 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
 
 public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
        
-       private AggregateUnaryFEDInstruction(AggregateUnaryOperator auop, 
CPOperand in,
-                       CPOperand out, String opcode, String istr) {
+       private AggregateUnaryFEDInstruction(AggregateUnaryOperator auop,
+               CPOperand in, CPOperand out, String opcode, String istr)
+       {
                super(FEDType.AggregateUnary, auop, in, out, opcode, istr);
        }
 
-       protected AggregateUnaryFEDInstruction(Operator op, CPOperand in1, 
CPOperand in2, CPOperand out,
-                                                                               
   String opcode, String istr) {
+       protected AggregateUnaryFEDInstruction(Operator op,
+               CPOperand in1, CPOperand in2, CPOperand out, String opcode, 
String istr)
+       {
                super(FEDType.AggregateUnary, op, in1, in2, out, opcode, istr);
        }
 
-       protected AggregateUnaryFEDInstruction(Operator op, CPOperand in1, 
CPOperand in2, CPOperand in3, CPOperand out,
-                                                                               
   String opcode, String istr) {
+       protected AggregateUnaryFEDInstruction(Operator op, CPOperand in1,
+               CPOperand in2, CPOperand in3, CPOperand out, String opcode, 
String istr)
+       {
                super(FEDType.AggregateUnary, op, in1, in2, in3, out, opcode, 
istr);
        }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index 0cf1fac..dafd723 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -42,7 +42,8 @@ public abstract class FEDInstruction extends Instruction {
                MatrixIndexing,
                Quaternary,
                QSort,
-               QPick
+               QPick,
+               Unary
        }
 
        protected final FEDType _fedType;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 6c0e3ba..2a608f3 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -39,6 +39,7 @@ import 
org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstructio
 import org.apache.sysds.runtime.instructions.cp.QuaternaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.UnaryMatrixCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
 import 
org.apache.sysds.runtime.instructions.cp.VariableCPInstruction.VariableOperationCode;
 import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
@@ -109,16 +110,20 @@ public class FEDInstructionUtils {
                                && ec.containsVariable(instruction.input1)) {
 
                                MatrixObject mo1 = 
ec.getMatrixObject(instruction.input1);
-                               
if(instruction.getOpcode().equalsIgnoreCase("cm") && mo1.isFederated()) {
+                               
if(instruction.getOpcode().equalsIgnoreCase("cm") && mo1.isFederated())
                                        fedinst = 
CentralMomentFEDInstruction.parseInstruction(inst.getInstructionString());
-                               } else 
if(inst.getOpcode().equalsIgnoreCase("qsort") && mo1.isFederated()) {
+                               else 
if(inst.getOpcode().equalsIgnoreCase("qsort") && mo1.isFederated()) {
                                        
if(mo1.getFedMapping().getFederatedRanges().length == 1)
                                                fedinst = 
QuantileSortFEDInstruction.parseInstruction(inst.getInstructionString());
-                               } else 
if(inst.getOpcode().equalsIgnoreCase("rshape") && mo1.isFederated()) {
+                               }
+                               else 
if(inst.getOpcode().equalsIgnoreCase("rshape") && mo1.isFederated())
                                        fedinst = 
ReshapeFEDInstruction.parseInstruction(inst.getInstructionString());
-                               } else if(inst instanceof 
AggregateUnaryCPInstruction  && mo1.isFederated() &&
-                                       ((AggregateUnaryCPInstruction) 
instruction).getAUType() == AggregateUnaryCPInstruction.AUType.DEFAULT) {
+                               else if(inst instanceof 
AggregateUnaryCPInstruction  && mo1.isFederated() &&
+                                       ((AggregateUnaryCPInstruction) 
instruction).getAUType() == AggregateUnaryCPInstruction.AUType.DEFAULT)
                                        fedinst = 
AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString());
+                               else if(inst instanceof 
UnaryMatrixCPInstruction && mo1.isFederated()) {
+                                       
if(UnaryMatrixFEDInstruction.isValidOpcode(inst.getOpcode()))
+                                               fedinst = 
UnaryMatrixFEDInstruction.parseInstruction(inst.getInstructionString());
                                }
                        }
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
new file mode 100644
index 0000000..8e5104c
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.data.LibCommonsMath;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+public class UnaryMatrixFEDInstruction extends UnaryFEDInstruction {
+       protected UnaryMatrixFEDInstruction(Operator op, CPOperand in, 
CPOperand out, String opcode, String instr) {
+               super(FEDType.Unary, op, in, out, opcode, instr);
+       }
+       
+       public static boolean isValidOpcode(String opcode) {
+               return !LibCommonsMath.isSupportedUnaryOperation(opcode)
+                       && !opcode.startsWith("ucum"); //ucumk+ ucum* ucumk+* 
ucummin ucummax
+       }
+
+       public static UnaryMatrixFEDInstruction parseInstruction(String str) {
+               CPOperand in = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
+               CPOperand out = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
+               String opcode = parseUnaryInstruction(str, in, out);
+               return new 
UnaryMatrixFEDInstruction(InstructionUtils.parseUnaryOperator(opcode), in, out, 
opcode, str);
+       }
+       
+       @Override 
+       public void processInstruction(ExecutionContext ec) {
+               MatrixObject mo1 = ec.getMatrixObject(input1);
+               
+               //federated execution on arbitrary row/column partitions
+               //(only assumption for sparse-unsafe: fed mapping covers entire 
matrix)
+               FederatedRequest fr1 = 
FederationUtils.callInstruction(instString, output,
+                       new CPOperand[]{input1}, new 
long[]{mo1.getFedMapping().getID()});
+               mo1.getFedMapping().execute(getTID(), true, fr1);
+               
+               //set characteristics and fed mapp
+               MatrixObject out = ec.getMatrixObject(output);
+               out.getDataCharacteristics().set(mo1.getNumRows(), 
mo1.getNumColumns(), (int)mo1.getBlocksize());
+               
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
+       }
+}
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java 
b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 5b7127b..74bf15a 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -116,7 +116,7 @@ public abstract class AutomatedTestBase {
        public static final double GPU_TOLERANCE = 1e-9;
 
        public static final int FED_WORKER_WAIT = 1000; // in ms
-       public static final int FED_WORKER_WAIT_S = 40; // in ms
+       public static final int FED_WORKER_WAIT_S = 50; // in ms
        
 
        // With OpenJDK 8u242 on Windows, the new changes in JDK are not 
allowing
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
index fe67bc2..7abb1a8 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
@@ -120,12 +120,11 @@ public class FederatedLogRegTest extends 
AutomatedTestBase {
                TestUtils.shutdownThreads(t1, t2);
 
                // check for federated operations
-               Assert.assertTrue("contains federated matrix mult", 
heavyHittersContainsString("fed_ba+*"));
-               Assert.assertTrue("contains federated row unary aggregate",
-                       heavyHittersContainsString("fed_uark+", "fed_uarsqk+"));
-               Assert.assertTrue("contains federated matrix mult chain or 
transpose",
-                       heavyHittersContainsString("fed_mmchain", "fed_r'"));
-
+               Assert.assertTrue("contains fed_ba+*", 
heavyHittersContainsString("fed_ba+*"));
+               Assert.assertTrue("contains fed_uar", 
heavyHittersContainsString("fed_uark+", "fed_uarsqk+"));
+               Assert.assertTrue("contains fed_mmchain & r'", 
heavyHittersContainsString("fed_mmchain", "fed_r'"));
+               Assert.assertTrue("contains fed_isnan", 
heavyHittersContainsString("fed_isnan"));
+               
                // check that federated input files are still existing
                Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
                Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));

Reply via email to