This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 854b4e9  [SYSTEMDS-2549,2624] Fix federated binary matrix-vector, var cleanup
854b4e9 is described below

commit 854b4e94f0e8f4c8b8e0f2867558cb90e4e8e552
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun Aug 16 17:30:03 2020 +0200

    [SYSTEMDS-2549,2624] Fix federated binary matrix-vector, var cleanup
    
    This patch fixes two correctness issues related to (1) the cleanup of
    federated matrices, and (2) federated binary matrix-row vector
    operators. Furthermore, it includes a new federated Kmeans test,
    some minor fixes for row aggregates, and improvements to federated
    matrix multiplications.
---
 scripts/builtin/kmeans.dml                         |  2 +-
 .../controlprogram/context/ExecutionContext.java   |  9 +---
 .../controlprogram/federated/FederatedRange.java   |  8 +++-
 .../federated/FederatedWorkerHandler.java          |  2 +-
 .../controlprogram/federated/FederationMap.java    |  8 ++++
 .../fed/AggregateBinaryFEDInstruction.java         | 23 +++++++---
 .../fed/BinaryMatrixMatrixFEDInstruction.java      | 24 ++++++++---
 .../sysds/runtime/meta/DataCharacteristics.java    |  2 +-
 ...eratedPCATest.java => FederatedKmeansTest.java} | 50 ++++++++++++----------
 .../test/functions/federated/FederatedPCATest.java |  5 +++
 .../functions/federated/FederatedKmeansTest.dml    | 25 +++++++++++
 .../federated/FederatedKmeansTestReference.dml     | 24 +++++++++++
 12 files changed, 132 insertions(+), 50 deletions(-)

diff --git a/scripts/builtin/kmeans.dml b/scripts/builtin/kmeans.dml
index f18466d..90a7222 100644
--- a/scripts/builtin/kmeans.dml
+++ b/scripts/builtin/kmeans.dml
@@ -160,7 +160,7 @@ m_kmeans = function(Matrix[Double] X, Integer k = 10, Integer runs = 10, Integer
         C_old = C; C = C_new;
     }
 
-    if(is_verbose == TRUE)
+    if(is_verbose)
       print ("Run " + run_index + ", Iteration " + iter_count + ":  Terminated with code = "
         + term_code + ",  Centroid WCSS = " + wcss);
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index 31a467f..fcb5db3 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -59,9 +59,7 @@ import org.apache.sysds.utils.Statistics;
 
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.HashSet;
 import java.util.List;
-import java.util.Set;
 import java.util.stream.Collectors;
 
 public class ExecutionContext {
@@ -73,7 +71,6 @@ public class ExecutionContext {
        //symbol table
        protected LocalVariableMap _variables;
        protected boolean _autoCreateVars;
-       protected Set<String> _guardedFiles = new HashSet<>();
        
        //lineage map, cache, prepared dedup blocks
        protected Lineage _lineage;
@@ -134,10 +131,6 @@ public class ExecutionContext {
        public void setAutoCreateVars(boolean flag) {
                _autoCreateVars = flag;
        }
-       
-       public void addGuardedFilename(String fname) {
-               _guardedFiles.add(fname);
-       }
 
        /**
         * Get the i-th GPUContext
@@ -758,7 +751,7 @@ public class ExecutionContext {
                        //compute ref count only if matrix cleanup actually 
necessary
                        if ( mo.isCleanupEnabled() && 
!getVariables().hasReferences(mo) )  {
                                mo.clearData(); //clean cached data
-                               if( fileExists && 
!_guardedFiles.contains(mo.getFileName()) ) {
+                               if( fileExists ) {
                                        
HDFSTool.deleteFileIfExistOnHDFS(mo.getFileName());
                                        
HDFSTool.deleteFileIfExistOnHDFS(mo.getFileName()+".mtd");
                                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
index 6571666..46ebce2 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
@@ -41,8 +41,12 @@ public class FederatedRange implements 
Comparable<FederatedRange> {
         * @param other the <code>FederatedRange</code> to copy
         */
        public FederatedRange(FederatedRange other) {
-               _beginDims = other._beginDims.clone();
-               _endDims = other._endDims.clone();
+               this(other._beginDims.clone(), other._endDims.clone());
+       }
+       
+       public FederatedRange(FederatedRange other, long clen) {
+               this(other._beginDims.clone(), other._endDims.clone());
+               _endDims[1] = clen;
        }
        
        public void setBeginDim(int dim, long value) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 47ca43c..1afbfb1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -181,7 +181,7 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                
                //TODO spawn async load of data, otherwise on first access
                _ec.setVariable(String.valueOf(id), cd);
-               _ec.addGuardedFilename(filename);
+               cd.enableCleanup(false); //guard against deletion
                
                if (dataType == Types.DataType.FRAME) {
                        FrameObject frameObject = (FrameObject) cd;
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 04532fd..d323bad 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -149,6 +149,14 @@ public class FederationMap
                        map.put(new FederatedRange(e.getKey()), new 
FederatedData(e.getValue(), id));
                return new FederationMap(id, map);
        }
+       
+       public FederationMap copyWithNewID(long id, long clen) {
+               Map<FederatedRange, FederatedData> map = new TreeMap<>();
+               //TODO handling of file path, but no danger as never written
+               for( Entry<FederatedRange, FederatedData> e : 
_fedMap.entrySet() )
+                       map.put(new FederatedRange(e.getKey(), clen), new 
FederatedData(e.getValue(), id));
+               return new FederationMap(id, map);
+       }
 
        public FederationMap rbind(long offset, FederationMap that) {
                for( Entry<FederatedRange, FederatedData> e : 
that._fedMap.entrySet() ) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 3fe1004..14f81bf 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -66,13 +66,22 @@ public class AggregateBinaryFEDInstruction extends 
BinaryFEDInstruction {
                        FederatedRequest fr1 = 
mo1.getFedMapping().broadcast(mo2);
                        FederatedRequest fr2 = 
FederationUtils.callInstruction(instString, output,
                                new CPOperand[]{input1, input2}, new 
long[]{mo1.getFedMapping().getID(), fr1.getID()});
-                       FederatedRequest fr3 = new 
FederatedRequest(RequestType.GET_VAR, fr2.getID());
-                       //execute federated operations and aggregate
-                       Future<FederatedResponse>[] tmp = 
mo1.getFedMapping().execute(fr1, fr2, fr3);
-                       MatrixBlock ret = FederationUtils.rbind(tmp);
-                       mo1.getFedMapping().cleanup(fr1.getID(), fr2.getID());
-                       ec.setMatrixOutput(output.getName(), ret);
-                       //TODO should remain federated matrix (no need for agg)
+                       if( mo2.getNumColumns() == 1 ) { //MV
+                               FederatedRequest fr3 = new 
FederatedRequest(RequestType.GET_VAR, fr2.getID());
+                               //execute federated operations and aggregate
+                               Future<FederatedResponse>[] tmp = 
mo1.getFedMapping().execute(fr1, fr2, fr3);
+                               MatrixBlock ret = FederationUtils.rbind(tmp);
+                               mo1.getFedMapping().cleanup(fr1.getID(), 
fr2.getID());
+                               ec.setMatrixOutput(output.getName(), ret);
+                       }
+                       else { //MM
+                               //execute federated operations and aggregate
+                               mo1.getFedMapping().execute(fr1, fr2);
+                               mo1.getFedMapping().cleanup(fr1.getID());
+                               MatrixObject out = ec.getMatrixObject(output);
+                               
out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), 
(int)mo1.getBlocksize());
+                               
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID(), 
mo2.getNumColumns()));
+                       }
                }
                //#2 vector - federated matrix multiplication
                else if (mo2.isFederated()) {// VM + MM
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index d124c76..7813f6a 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -45,13 +45,23 @@ public class BinaryMatrixMatrixFEDInstruction extends 
BinaryFEDInstruction
                }
                
                //matrix-matrix binary operations -> lhs fed input -> fed output
-               FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
-               FederatedRequest fr2 = 
FederationUtils.callInstruction(instString, output,
-                       new CPOperand[]{input1, input2}, new 
long[]{mo1.getFedMapping().getID(), fr1.getID()});
-               
-               //execute federated instruction and cleanup intermediates
-               mo1.getFedMapping().execute(fr1, fr2);
-               mo1.getFedMapping().cleanup(fr1.getID());
+               FederatedRequest fr2 = null;
+               if(mo2.getNumRows() > 1 && mo2.getNumColumns() == 1 ) { //MV 
row vector
+                       FederatedRequest[] fr1 = 
mo1.getFedMapping().broadcastSliced(mo2, false);
+                       fr2 = FederationUtils.callInstruction(instString, 
output, new CPOperand[]{input1, input2},
+                               new long[]{mo1.getFedMapping().getID(), 
fr1[0].getID()});
+                       //execute federated instruction and cleanup 
intermediates
+                       mo1.getFedMapping().execute(fr1, fr2);
+                       mo1.getFedMapping().cleanup(fr1[0].getID());
+               }
+               else { //MM or MV col vector
+                       FederatedRequest fr1 = 
mo1.getFedMapping().broadcast(mo2);
+                       fr2 = FederationUtils.callInstruction(instString, 
output, new CPOperand[]{input1, input2},
+                               new long[]{mo1.getFedMapping().getID(), 
fr1.getID()});
+                       //execute federated instruction and cleanup 
intermediates
+                       mo1.getFedMapping().execute(fr1, fr2);
+                       mo1.getFedMapping().cleanup(fr1.getID());
+               }
                
                //derive new fed mapping for output
                MatrixObject out = ec.getMatrixObject(output);
diff --git 
a/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java 
b/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java
index 58bdcd0..d71ce9d 100644
--- a/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java
+++ b/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java
@@ -31,7 +31,7 @@ public abstract class DataCharacteristics implements 
Serializable {
 
        protected int _blocksize;
        
-       public DataCharacteristics set(long nr, long nc, int len) {
+       public DataCharacteristics set(long nr, long nc, int blen) {
                throw new DMLRuntimeException("DataCharacteristics.set(long, 
long, int): should never get called in the base class");
        }
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java 
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java
similarity index 73%
copy from 
src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
copy to 
src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java
index bf674a8..1ef2384 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java
@@ -27,6 +27,7 @@ import org.junit.runners.Parameterized;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
@@ -36,11 +37,11 @@ import java.util.Collection;
 
 @RunWith(value = Parameterized.class)
 @net.jcip.annotations.NotThreadSafe
-public class FederatedPCATest extends AutomatedTestBase {
+public class FederatedKmeansTest extends AutomatedTestBase {
 
        private final static String TEST_DIR = "functions/federated/";
-       private final static String TEST_NAME = "FederatedPCATest";
-       private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedPCATest.class.getSimpleName() + "/";
+       private final static String TEST_NAME = "FederatedKMeansTest";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedKmeansTest.class.getSimpleName() + "/";
 
        private final static int blocksize = 1024;
        @Parameterized.Parameter()
@@ -48,7 +49,7 @@ public class FederatedPCATest extends AutomatedTestBase {
        @Parameterized.Parameter(1)
        public int cols;
        @Parameterized.Parameter(2)
-       public boolean scaleAndShift;
+       public int runs;
 
        @Override
        public void setUp() {
@@ -60,22 +61,23 @@ public class FederatedPCATest extends AutomatedTestBase {
        public static Collection<Object[]> data() {
                // rows have to be even and > 1
                return Arrays.asList(new Object[][] {
-                       {10000, 10, false}, {2000, 50, false}, {1000, 100, 
false},
-                       {10000, 10, true}, {2000, 50, true}, {1000, 100, true}
+                       {10000, 10, 1}, {2000, 50, 1}, {1000, 100, 1},
+                       //TODO support for multi-threaded federated interactions
+                       //{10000, 10, 16}, {2000, 50, 16}, {1000, 100, 16}, 
//concurrent requests
                });
        }
 
        @Test
-       public void federatedPCASinglenode() {
-               federatedL2SVM(Types.ExecMode.SINGLE_NODE);
+       public void federatedKmeansSinglenode() {
+               federatedKmeans(Types.ExecMode.SINGLE_NODE);
        }
        
        @Test
-       public void federatedPCAHybrid() {
-               federatedL2SVM(Types.ExecMode.HYBRID);
+       public void federatedKmeansHybrid() {
+               federatedKmeans(Types.ExecMode.HYBRID);
        }
 
-       public void federatedL2SVM(Types.ExecMode execMode) {
+       public void federatedKmeans(Types.ExecMode execMode) {
                ExecMode platformOld = setExecMode(execMode);
 
                getAndLoadTestConfiguration(TEST_NAME);
@@ -98,11 +100,12 @@ public class FederatedPCATest extends AutomatedTestBase {
 
                TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
                loadTestConfiguration(config);
+               setOutputBuffering(false);
                
                // Run reference dml script with normal matrix
                fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
                programArgs = new String[] {"-args", input("X1"), input("X2"),
-                       String.valueOf(scaleAndShift).toUpperCase(), 
expected("Z")};
+                       String.valueOf(runs), expected("Z")};
                runTest(true, false, null, -1);
 
                // Run actual dml script with federated matrix
@@ -110,24 +113,25 @@ public class FederatedPCATest extends AutomatedTestBase {
                programArgs = new String[] {"-stats",
                        "-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
                        "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")), "rows=" + rows, "cols=" + cols,
-                       "scaleAndShift=" + 
String.valueOf(scaleAndShift).toUpperCase(), "out=" + output("Z")};
+                       "runs=" + String.valueOf(runs), "out=" + output("Z")};
                runTest(true, false, null, -1);
 
                // compare via files
-               compareResults(1e-9);
+               //compareResults(1e-9); --> randomized
                TestUtils.shutdownThreads(t1, t2);
                
                // check for federated operations
                Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
-               Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
-               Assert.assertTrue(heavyHittersContainsString("fed_tsmm"));
-               if( scaleAndShift ) {
-                       
Assert.assertTrue(heavyHittersContainsString("fed_uacsqk+"));
-                       
Assert.assertTrue(heavyHittersContainsString("fed_uacmean"));
-                       Assert.assertTrue(heavyHittersContainsString("fed_-"));
-                       Assert.assertTrue(heavyHittersContainsString("fed_/"));
-                       
Assert.assertTrue(heavyHittersContainsString("fed_replace"));
-               }
+               Assert.assertTrue(heavyHittersContainsString("fed_uasqk+"));
+               Assert.assertTrue(heavyHittersContainsString("fed_uarmin"));
+               Assert.assertTrue(heavyHittersContainsString("fed_*"));
+               Assert.assertTrue(heavyHittersContainsString("fed_+"));
+               Assert.assertTrue(heavyHittersContainsString("fed_<="));
+               Assert.assertTrue(heavyHittersContainsString("fed_/"));
+               
+               //check that federated input files are still existing
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
                
                resetExecMode(platformOld);
        }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java 
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
index bf674a8..53eac1e 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
@@ -27,6 +27,7 @@ import org.junit.runners.Parameterized;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
@@ -129,6 +130,10 @@ public class FederatedPCATest extends AutomatedTestBase {
                        
Assert.assertTrue(heavyHittersContainsString("fed_replace"));
                }
                
+               //check that federated input files are still existing
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+               
                resetExecMode(platformOld);
        }
 }
diff --git a/src/test/scripts/functions/federated/FederatedKmeansTest.dml 
b/src/test/scripts/functions/federated/FederatedKmeansTest.dml
new file mode 100644
index 0000000..95f136c
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedKmeansTest.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2),
+    ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), 
list($rows, $cols)))
+[C,Y] = kmeans(X=X, k=4, runs=$runs)
+write(C, $out)
diff --git 
a/src/test/scripts/functions/federated/FederatedKmeansTestReference.dml 
b/src/test/scripts/functions/federated/FederatedKmeansTestReference.dml
new file mode 100644
index 0000000..da32c8b
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedKmeansTestReference.dml
@@ -0,0 +1,24 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($1), read($2))
+[C,Y] = kmeans(X=X, k=4, runs=$3)
+write(C, $4)

Reply via email to