This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 4962a3a  [SYSTEMDS-3098] Asynchronous early broadcast for Spark Operations
4962a3a is described below

commit 4962a3acab6947ecbde521da66ba8713a600f679
Author: arnabp <[email protected]>
AuthorDate: Wed Aug 18 23:08:12 2021 +0200

    [SYSTEMDS-3098] Asynchronous early broadcast for Spark Operations
    
    This patch enables asynchronous broadcast of matrices to
    Spark whenever necessary. The new Broadcast instruction is
    placed immediately after the instruction that produces
    the intermediate to be broadcast later.
    
    Closes #1364
---
 src/main/java/org/apache/sysds/common/Types.java   |  2 +-
 src/main/java/org/apache/sysds/lops/AppendM.java   |  7 +++
 src/main/java/org/apache/sysds/lops/Binary.java    | 18 ++++++++
 src/main/java/org/apache/sysds/lops/LeftIndex.java |  7 +++
 src/main/java/org/apache/sysds/lops/Lop.java       |  9 ++++
 src/main/java/org/apache/sysds/lops/MapMult.java   |  9 ++++
 .../java/org/apache/sysds/lops/MapMultChain.java   |  7 +++
 src/main/java/org/apache/sysds/lops/PMMJ.java      |  7 +++
 .../java/org/apache/sysds/lops/compile/Dag.java    | 45 ++++++++++++++++--
 .../context/SparkExecutionContext.java             |  4 ++
 .../runtime/instructions/CPInstructionParser.java  |  5 ++
 .../instructions/cp/BroadcastCPInstruction.java    | 51 ++++++++++++++++++++
 .../runtime/instructions/cp/CPInstruction.java     |  2 +-
 .../instructions/cp/TriggerBroadcastTask.java      | 54 ++++++++++++++++++++++
 .../java/org/apache/sysds/utils/Statistics.java    | 14 +++++-
 ...refetchRDDTest.java => AsyncBroadcastTest.java} | 42 ++++++++---------
 .../test/functions/async/PrefetchRDDTest.java      |  3 +-
 src/test/scripts/functions/async/BroadcastVar1.dml | 37 +++++++++++++++
 src/test/scripts/functions/async/BroadcastVar2.dml | 36 +++++++++++++++
 19 files changed, 325 insertions(+), 34 deletions(-)

diff --git a/src/main/java/org/apache/sysds/common/Types.java 
b/src/main/java/org/apache/sysds/common/Types.java
index df9e9ef..4e2cef7 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -226,7 +226,7 @@ public class Types
        
        // Operations that require 1 operand
        public enum OpOp1 {
-               ABS, ACOS, ASIN, ASSERT, ATAN, CAST_AS_SCALAR, CAST_AS_MATRIX,
+               ABS, ACOS, ASIN, ASSERT, ATAN, BROADCAST, CAST_AS_SCALAR, 
CAST_AS_MATRIX,
                CAST_AS_FRAME, CAST_AS_DOUBLE, CAST_AS_INT, CAST_AS_BOOLEAN,
                CEIL, CHOLESKY, COS, COSH, CUMMAX, CUMMIN, CUMPROD, CUMSUM,
                CUMSUMPROD, DETECTSCHEMA, COLNAMES, EIGEN, EXISTS, EXP, FLOOR, 
INVERSE,
diff --git a/src/main/java/org/apache/sysds/lops/AppendM.java 
b/src/main/java/org/apache/sysds/lops/AppendM.java
index 6836cf8..de8e479 100644
--- a/src/main/java/org/apache/sysds/lops/AppendM.java
+++ b/src/main/java/org/apache/sysds/lops/AppendM.java
@@ -54,6 +54,13 @@ public class AppendM extends Lop
                input3.addOutput(this);
                lps.setProperties(inputs, ExecType.SPARK);
        }
+
+       @Override
+       public Lop getBroadcastInput() {
+               if (getExecType() != ExecType.SPARK)
+                       return null;
+               return getInputs().get(1); //frame or matrix
+       }
        
        @Override
        public String toString() {
diff --git a/src/main/java/org/apache/sysds/lops/Binary.java 
b/src/main/java/org/apache/sysds/lops/Binary.java
index 79346ff..202dce7 100644
--- a/src/main/java/org/apache/sysds/lops/Binary.java
+++ b/src/main/java/org/apache/sysds/lops/Binary.java
@@ -21,6 +21,9 @@ package org.apache.sysds.lops;
 
 import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
+
+import java.util.ArrayList;
+
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.OpOp2;
 import org.apache.sysds.common.Types.ValueType;
@@ -71,6 +74,21 @@ public class Binary extends Lop
                return " Operation: " + operation;
        }
        
+       @Override
+       public Lop getBroadcastInput() {
+               if (getExecType() != ExecType.SPARK)
+                       return null;
+
+               ArrayList<Lop> inputs = getInputs();
+               if (operation == OpOp2.MAP && inputs.get(0).getDataType() == 
DataType.MATRIX 
+                               && inputs.get(1).getDataType() == 
DataType.MATRIX)
+                       return inputs.get(1);
+               else if (inputs.get(0).getDataType() == DataType.FRAME && 
inputs.get(1).getDataType() == DataType.MATRIX)
+                       return inputs.get(1);
+               else
+                       return null;
+       }
+       
        public OpOp2 getOperationType() {
                return operation;
        }
diff --git a/src/main/java/org/apache/sysds/lops/LeftIndex.java 
b/src/main/java/org/apache/sysds/lops/LeftIndex.java
index ae49f59..3c6bf0a 100644
--- a/src/main/java/org/apache/sysds/lops/LeftIndex.java
+++ b/src/main/java/org/apache/sysds/lops/LeftIndex.java
@@ -87,6 +87,13 @@ public class LeftIndex extends Lop
                colU.addOutput(this);
                lps.setProperties(inputs, et);
        }
+
+       @Override
+       public Lop getBroadcastInput() {
+               if (getExecType() != ExecType.SPARK || _type == 
LixCacheType.NONE)
+                       return null;
+               return _type == LixCacheType.LEFT ? getInputs().get(0) : null;
+       }
        
        private String getOpcode() {
                if( _type != LixCacheType.NONE )
diff --git a/src/main/java/org/apache/sysds/lops/Lop.java 
b/src/main/java/org/apache/sysds/lops/Lop.java
index febe07a..e014d3c 100644
--- a/src/main/java/org/apache/sysds/lops/Lop.java
+++ b/src/main/java/org/apache/sysds/lops/Lop.java
@@ -425,6 +425,15 @@ public abstract class Lop
                return SparkAggType.NONE;
        }
        
+       /**
+        * Method to get the input to be broadcast.
+        * This method is overridden by the Lops which require broadcasts (e.g. 
AppendM)
+        * @return An input Lop or Null
+        */
+       public Lop getBroadcastInput() {
+               return null;
+       }
+       
 
        /** Method should be overridden if needed
         * 
diff --git a/src/main/java/org/apache/sysds/lops/MapMult.java 
b/src/main/java/org/apache/sysds/lops/MapMult.java
index 57b76ef..2e7c72d 100644
--- a/src/main/java/org/apache/sysds/lops/MapMult.java
+++ b/src/main/java/org/apache/sysds/lops/MapMult.java
@@ -92,6 +92,15 @@ public class MapMult extends Lop
        public SparkAggType getAggType() {
                return _aggtype;
        }
+       
+       @Override
+       public Lop getBroadcastInput() {
+               if (getExecType() != ExecType.SPARK)
+                       return null;
+               
+               return _cacheType.isRight() ? getInputs().get(1) : 
getInputs().get(0);
+               //Note: rdd and broadcast inputs can flip during runtime
+       }
 
        @Override
        public String toString() {
diff --git a/src/main/java/org/apache/sysds/lops/MapMultChain.java 
b/src/main/java/org/apache/sysds/lops/MapMultChain.java
index 32284b9..9f68f3e 100644
--- a/src/main/java/org/apache/sysds/lops/MapMultChain.java
+++ b/src/main/java/org/apache/sysds/lops/MapMultChain.java
@@ -100,6 +100,13 @@ public class MapMultChain extends Lop
        }
 
        @Override
+       public Lop getBroadcastInput() {
+               if (getExecType() != ExecType.SPARK)
+                       return null;
+               return getInputs().get(1);
+       }
+
+       @Override
        public String getInstructions(String input1, String input2, String 
output) {
                return getInstructions(input1, input2, null, output);
        }
diff --git a/src/main/java/org/apache/sysds/lops/PMMJ.java 
b/src/main/java/org/apache/sysds/lops/PMMJ.java
index 9e44201..e0c8dcf 100644
--- a/src/main/java/org/apache/sysds/lops/PMMJ.java
+++ b/src/main/java/org/apache/sysds/lops/PMMJ.java
@@ -69,6 +69,13 @@ public class PMMJ extends Lop
        public String toString() {
                return "Operation = PMMJ";
        }
+
+       @Override
+       public Lop getBroadcastInput() {
+               if (getExecType() != ExecType.SPARK)
+                       return null;
+               return getInputs().get(1);
+       }
        
        @Override
        public String getInstructions(String input_index1, String input_index2, 
String input_index3, String output_index) 
diff --git a/src/main/java/org/apache/sysds/lops/compile/Dag.java 
b/src/main/java/org/apache/sysds/lops/compile/Dag.java
index 823c14d..9b7f1e5 100644
--- a/src/main/java/org/apache/sysds/lops/compile/Dag.java
+++ b/src/main/java/org/apache/sysds/lops/compile/Dag.java
@@ -199,14 +199,14 @@ public class Dag<N extends Lop>
                        //doTopologicalSortStrictOrder(nodes) :
                        doTopologicalSortTwoLevelOrder(nodes);
                
-               // add Prefetch lops to the list, if necessary
-               //List<Lop> node_pf = addPrefetchLop(node_v);
+               // add Prefetch and broadcast lops, if necessary
                List<Lop> node_pf = OptimizerUtils.ASYNC_TRIGGER_RDD_OPERATIONS 
? addPrefetchLop(node_v) : node_v;
+               List<Lop> node_bc = OptimizerUtils.ASYNC_TRIGGER_RDD_OPERATIONS 
? addBroadcastLop(node_pf) : node_pf;
+               // TODO: Merge via a single traversal of the nodes
                
                // do greedy grouping of operations
                ArrayList<Instruction> inst =
-                       //doGreedyGrouping(sb, node_v) :
-                       doPlainInstructionGen(sb, node_pf);
+                       doPlainInstructionGen(sb, node_bc);
                
                // cleanup instruction (e.g., create packed rmvar instructions)
                return cleanupInstructions(inst);
@@ -234,6 +234,7 @@ public class Dag<N extends Lop>
                        nodesWithPrefetch.add(l);
                        if (isPrefetchNeeded(l)) {
                                //TODO: No prefetch if the parent is placed 
right after the spark OP
+                               //or push the parent further to increase 
parallelism
                                List<Lop> oldOuts = new 
ArrayList<>(l.getOutputs());
                                //Construct a Prefetch lop that takes this 
Spark node as a input
                                UnaryCP prefetch = new UnaryCP(l, 
OpOp1.PREFETCH, l.getDataType(), l.getValueType(), ExecType.CP);
@@ -251,6 +252,30 @@ public class Dag<N extends Lop>
                return nodesWithPrefetch;
        }
        
+       private static List<Lop> addBroadcastLop(List<Lop> nodes) {
+               List<Lop> nodesWithBroadcast = new ArrayList<>();
+               
+               for (Lop l : nodes) {
+                       nodesWithBroadcast.add(l);
+                       if (isBroadcastNeeded(l)) {
+                               List<Lop> oldOuts = new 
ArrayList<>(l.getOutputs());
+                               //Construct a Broadcast lop that takes this 
Spark node as an input
+                               UnaryCP bc = new UnaryCP(l, OpOp1.BROADCAST, 
l.getDataType(), l.getValueType(), ExecType.CP);
+                               //FIXME: Wire Broadcast only with the necessary 
outputs
+                               for (Lop outCP : oldOuts) {
+                                       //Rewire l -> outCP to l -> Broadcast 
-> outCP
+                                       bc.addOutput(outCP);
+                                       outCP.replaceInput(l, bc);
+                                       l.removeOutput(outCP);
+                                       //FIXME: Rewire _inputParams when 
needed (e.g. GroupedAggregate)
+                               }
+                               //Place it immediately after the Spark lop in 
the node list
+                               nodesWithBroadcast.add(bc);
+                       }
+               }
+               return nodesWithBroadcast;
+       }
+       
        private ArrayList<Instruction> doPlainInstructionGen(StatementBlock sb, 
List<Lop> nodes)
        {
                //prepare basic instruction sets
@@ -308,7 +333,17 @@ public class Dag<N extends Lop>
                return transformOP && !hasParameterizedOut 
                                && lop.isAllOutputsCP() && lop.getDataType() == 
DataType.MATRIX;
        }
-
+       
+       private static boolean isBroadcastNeeded(Lop lop) {
+               // Asynchronously broadcast a matrix if that is produced by a 
CP instruction,
+               // and at least one Spark parent needs to broadcast this 
intermediate (eg. mapmm)
+               boolean isBc = lop.getOutputs().stream()
+                               .anyMatch(out -> (out.getBroadcastInput() == 
lop));
+               //TODO: Early broadcast objects that are bigger than a single 
block
+               //return isCP && isBc && lop.getDataTypes() == DataType.Matrix;
+               return isBc && lop.getDataType() == DataType.MATRIX;
+       }
+       
        private static List<Instruction> 
deleteUpdatedTransientReadVariables(StatementBlock sb, List<Lop> nodeV) {
                List<Instruction> insts = new ArrayList<>();
                if ( sb == null ) //return modifiable list
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
index 331e54f..bb95fe0 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
@@ -701,6 +701,10 @@ public class SparkExecutionContext extends ExecutionContext
 
                return bret;
        }
+       
+       public void setBroadcastHandle(MatrixObject mo) {
+               getBroadcastForMatrixObject(mo);
+       }
 
        @SuppressWarnings("unchecked")
        public PartitionedBroadcast<TensorBlock> 
getBroadcastForTensorObject(TensorObject to) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index 9018aa7..d207a6c 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -34,6 +34,7 @@ import 
org.apache.sysds.runtime.instructions.cp.AggregateTernaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.AppendCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.BroadcastCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.BuiltinNaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.CPInstruction;
 import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
@@ -317,6 +318,7 @@ public class CPInstructionParser extends InstructionParser
                String2CPInstructionType.put( "decompress", 
CPType.DeCompression);
                String2CPInstructionType.put( "spoof",     CPType.SpoofFused);
                String2CPInstructionType.put( "prefetch",  CPType.Prefetch);
+               String2CPInstructionType.put( "broadcast",  CPType.Broadcast);
                
                String2CPInstructionType.put( "sql", CPType.Sql);
        }
@@ -459,6 +461,9 @@ public class CPInstructionParser extends InstructionParser
                                
                        case Prefetch:
                                return 
PrefetchCPInstruction.parseInstruction(str);
+                               
+                       case Broadcast:
+                               return 
BroadcastCPInstruction.parseInstruction(str);
                        
                        default:
                                throw new DMLRuntimeException("Invalid CP 
Instruction Type: " + cptype );
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
new file mode 100644
index 0000000..51b8ba5
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.cp;
+
+import java.util.concurrent.Executors;
+
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+public class BroadcastCPInstruction extends UnaryCPInstruction {
+       private BroadcastCPInstruction(Operator op, CPOperand in, CPOperand 
out, String opcode, String istr) {
+               super(CPType.Broadcast, op, in, out, opcode, istr);
+       }
+       
+       public static BroadcastCPInstruction parseInstruction (String str) {
+               InstructionUtils.checkNumFields(str, 2);
+               String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
+               String opcode = parts[0];
+               CPOperand in = new CPOperand(parts[1]);
+               CPOperand out = new CPOperand(parts[2]);
+               return new BroadcastCPInstruction(null, in, out, opcode, str);
+       }
+
+       @Override
+       public void processInstruction(ExecutionContext ec) {
+               ec.setVariable(output.getName(), ec.getMatrixObject(input1));
+
+               if (SparkUtils.triggerRDDPool == null)
+                       SparkUtils.triggerRDDPool = 
Executors.newCachedThreadPool();
+               SparkUtils.triggerRDDPool.submit(new TriggerBroadcastTask(ec, 
ec.getMatrixObject(output)));
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
index 3a3840f..47a14b1 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
@@ -46,7 +46,7 @@ public abstract class CPInstruction extends Instruction
                MultiReturnParameterizedBuiltin, ParameterizedBuiltin, 
MultiReturnBuiltin,
                Builtin, Reorg, Variable, FCall, Append, Rand, QSort, QPick,
                MatrixIndexing, MMTSJ, PMMJ, MMChain, Reshape, Partition, 
Compression, DeCompression, SpoofFused,
-               StringInit, CentralMoment, Covariance, UaggOuterChain, Dnn, 
Sql, Prefetch }
+               StringInit, CentralMoment, Covariance, UaggOuterChain, Dnn, 
Sql, Prefetch, Broadcast }
 
        protected final CPType _cptype;
        protected final boolean _requiresLabelUpdate;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerBroadcastTask.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerBroadcastTask.java
new file mode 100644
index 0000000..cc1187b
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerBroadcastTask.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.cp;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysds.utils.Statistics;
+
+public class TriggerBroadcastTask implements Runnable {
+       ExecutionContext _ec;
+       MatrixObject _broadcastMO;
+       
+       public TriggerBroadcastTask(ExecutionContext ec, MatrixObject mo) {
+               _ec = ec;
+               _broadcastMO = mo;
+       }
+
+       @Override
+       public void run() {
+               // TODO: Synchronization. Although it is harmless if to threads 
create separate
+               // broadcast handles as only one will stay with the 
MatrixObject. However, redundant
+               // partitioning increases untraced memory usage.
+               try {
+                       SparkExecutionContext sec = (SparkExecutionContext)_ec;
+                       sec.setBroadcastHandle(_broadcastMO);
+               }
+               catch (Exception e) {
+                       e.printStackTrace();
+               }
+
+               if (DMLScript.STATISTICS)
+                       Statistics.incSparkAsyncBroadcastCount(1);
+               
+       }
+}
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java 
b/src/main/java/org/apache/sysds/utils/Statistics.java
index e7850e5..cacd2f7 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -120,6 +120,7 @@ public class Statistics
        private static final LongAdder sparkBroadcast = new LongAdder();
        private static final LongAdder sparkBroadcastCount = new LongAdder();
        private static final LongAdder sparkAsyncPrefetchCount = new 
LongAdder();
+       private static final LongAdder sparkAsyncBroadcastCount = new 
LongAdder();
 
        // Paramserv function stats (time is in milli sec)
        private static final Timing psExecutionTimer = new Timing(false);
@@ -490,6 +491,7 @@ public class Statistics
                
                sparkCtxCreateTime = 0;
                sparkAsyncPrefetchCount.reset();
+               sparkAsyncBroadcastCount.reset();
                
                lTotalLix.reset();
                lTotalLixUIP.reset();
@@ -573,6 +575,10 @@ public class Statistics
                sparkAsyncPrefetchCount.add(c);
        }
 
+       public static void incSparkAsyncBroadcastCount(long c) {
+               sparkAsyncBroadcastCount.add(c);
+       }
+
        public static void incWorkerNumber() {
                psNumWorkers.increment();
        }
@@ -972,6 +978,10 @@ public class Statistics
                return sparkAsyncPrefetchCount.longValue();
        }
 
+       public static long getAsyncBroadcastCount() {
+               return sparkAsyncBroadcastCount.longValue();
+       }
+
        /**
         * Returns statistics of the DML program that was recently completed as 
a string
         * @return statistics as a string
@@ -1078,8 +1088,8 @@ public class Statistics
                                                                
sparkBroadcast.longValue()*1e-9,
                                                                
sparkCollect.longValue()*1e-9));
                                if (OptimizerUtils.ASYNC_TRIGGER_RDD_OPERATIONS)
-                                       sb.append("Spark async. prefetch count: 
\t" + 
-                                                       String.format("%d.\n", 
sparkAsyncPrefetchCount.longValue()));
+                                       sb.append("Spark async. count (pf,bc): 
\t" + 
+                                                       
String.format("%d/%d.\n", getAsyncPrefetchCount(), getAsyncBroadcastCount()));
                        }
                        if (psNumWorkers.longValue() > 0) {
                                sb.append(String.format("Paramserv total 
execution time:\t%.3f secs.\n", psExecutionTime.doubleValue() / 1000));
diff --git 
a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java 
b/src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java
similarity index 73%
copy from 
src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
copy to 
src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java
index db9c2f1..088773e 100644
--- a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java
@@ -35,12 +35,12 @@ import org.apache.sysds.utils.Statistics;
 import org.junit.Assert;
 import org.junit.Test;
 
-public class PrefetchRDDTest extends AutomatedTestBase {
+public class AsyncBroadcastTest extends AutomatedTestBase {
        
        protected static final String TEST_DIR = "functions/async/";
-       protected static final String TEST_NAME = "PrefetchRDD";
-       protected static final int TEST_VARIANTS = 3;
-       protected static String TEST_CLASS_DIR = TEST_DIR + 
PrefetchRDDTest.class.getSimpleName() + "/";
+       protected static final String TEST_NAME = "BroadcastVar";
+       protected static final int TEST_VARIANTS = 2;
+       protected static String TEST_CLASS_DIR = TEST_DIR + 
AsyncBroadcastTest.class.getSimpleName() + "/";
        
        @Override
        public void setUp() {
@@ -50,23 +50,15 @@ public class PrefetchRDDTest extends AutomatedTestBase {
        }
        
        @Test
-       public void testAsyncSparkOPs1() {
-               //Single CP consumer. Prefetch Lop has one output.
+       public void testAsyncBroadcast1() {
                runTest(TEST_NAME+"1");
        }
 
        @Test
-       public void testAsyncSparkOPs2() {
-               //Two CP consumers. Prefetch Lop has two outputs.
+       public void testAsyncBroadcast2() {
                runTest(TEST_NAME+"2");
        }
 
-       @Test
-       public void testAsyncSparkOPs3() {
-               //SP action type consumer. No Prefetch.
-               runTest(TEST_NAME+"3");
-       }
-       
        public void runTest(String testname) {
                boolean old_simplification = 
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
                boolean old_sum_product = 
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
@@ -86,6 +78,7 @@ public class PrefetchRDDTest extends AutomatedTestBase {
                        
                        List<String> proArgs = new ArrayList<>();
                        
+                       //proArgs.add("-explain");
                        proArgs.add("-stats");
                        proArgs.add("-args");
                        proArgs.add(output("R"));
@@ -97,17 +90,18 @@ public class PrefetchRDDTest extends AutomatedTestBase {
                        OptimizerUtils.ASYNC_TRIGGER_RDD_OPERATIONS = true;
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
                        OptimizerUtils.ASYNC_TRIGGER_RDD_OPERATIONS = false;
-                       HashMap<MatrixValue.CellIndex, Double> R_pf = 
readDMLScalarFromOutputDir("R");
+                       HashMap<MatrixValue.CellIndex, Double> R_bc = 
readDMLScalarFromOutputDir("R");
 
                        //compare matrices
-                       TestUtils.compareMatrices(R, R_pf, 1e-6, "Origin", 
"Reused");
-                       //assert Prefetch instructions and number of success.
-                       long expected_numPF = 
!testname.equalsIgnoreCase(TEST_NAME+"3") ? 1 : 0;
-                       long expected_successPF = 
!testname.equalsIgnoreCase(TEST_NAME+"3") ? 1 : 0;
-                       long numPF = 
Statistics.getCPHeavyHitterCount("prefetch");
-                       Assert.assertTrue("Violated Prefetch instruction count: 
"+numPF, numPF == expected_numPF);
-                       long successPF = Statistics.getAsyncPrefetchCount();
-                       Assert.assertTrue("Violated successful Prefetch count: 
"+successPF, successPF == expected_successPF);
+                       TestUtils.compareMatrices(R, R_bc, 1e-6, "Origin", 
"withBroadcast");
+
+                       //assert called and successful early broadcast counts
+                       long expected_numBC = 1;
+                       long expected_successBC = 1;
+                       long numBC = 
Statistics.getCPHeavyHitterCount("broadcast");
+                       Assert.assertTrue("Violated Broadcast instruction 
count: "+numBC, numBC == expected_numBC);
+                       long successBC = Statistics.getAsyncBroadcastCount();
+                       Assert.assertTrue("Violated successful Broadcast count: 
"+successBC, successBC == expected_successBC);
                } finally {
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
old_simplification;
                        OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = 
old_sum_product;
@@ -117,4 +111,4 @@ public class PrefetchRDDTest extends AutomatedTestBase {
                        Recompiler.reinitRecompiler();
                }
        }
-}
\ No newline at end of file
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java 
b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
index db9c2f1..b625139 100644
--- a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
@@ -86,6 +86,7 @@ public class PrefetchRDDTest extends AutomatedTestBase {
                        
                        List<String> proArgs = new ArrayList<>();
                        
+                       proArgs.add("-explain");
                        proArgs.add("-stats");
                        proArgs.add("-args");
                        proArgs.add(output("R"));
@@ -100,7 +101,7 @@ public class PrefetchRDDTest extends AutomatedTestBase {
                        HashMap<MatrixValue.CellIndex, Double> R_pf = 
readDMLScalarFromOutputDir("R");
 
                        //compare matrices
-                       TestUtils.compareMatrices(R, R_pf, 1e-6, "Origin", 
"Reused");
+                       TestUtils.compareMatrices(R, R_pf, 1e-6, "Origin", 
"withPrefetch");
                        //assert Prefetch instructions and number of success.
                        long expected_numPF = 
!testname.equalsIgnoreCase(TEST_NAME+"3") ? 1 : 0;
                        long expected_successPF = 
!testname.equalsIgnoreCase(TEST_NAME+"3") ? 1 : 0;
diff --git a/src/test/scripts/functions/async/BroadcastVar1.dml 
b/src/test/scripts/functions/async/BroadcastVar1.dml
new file mode 100644
index 0000000..777e332
--- /dev/null
+++ b/src/test/scripts/functions/async/BroadcastVar1.dml
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+X = rand(rows=10000, cols=200, seed=42); #sp_rand
+v = rand(rows=200, cols=1, seed=42); #cp_rand
+
+# CP operations
+v = ((v + v) * 1 - v) / (1+1);
+v1 = t(v);
+
+# Spark transformation operations 
+sp = X + ceil(X);
+sp = ((sp + sp) * 1 - sp) / (1+1);
+
+# mapmm - broadcast v
+sp2 = sp %*% v; 
+
+while(FALSE){}
+R = sum(sp2);
+write(R, $1, format="text");
diff --git a/src/test/scripts/functions/async/BroadcastVar2.dml 
b/src/test/scripts/functions/async/BroadcastVar2.dml
new file mode 100644
index 0000000..26939b1
--- /dev/null
+++ b/src/test/scripts/functions/async/BroadcastVar2.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+X = rand(rows=10000, cols=200, seed=42); #sp_rand
+v = rand(rows=200, cols=1, seed=42); #cp_rand
+
+# CP operations
+v = ((v + v) * 1 - v) / (1+1);
+
+# Spark transformation operations 
+sp = X + ceil(X);
+sp = ((sp + sp) * 1 - sp) / (1+1);
+
+# mappend - broadcast v
+sp2 = rbind(sp, t(v));
+
+while(FALSE){}
+R = sum(sp2);
+write(R, $1, format="text");

Reply via email to