This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 4962a3a  [SYSTEMDS-3098] Asynchronous early broadcast for Spark Operations
4962a3a is described below
commit 4962a3acab6947ecbde521da66ba8713a600f679
Author: arnabp <[email protected]>
AuthorDate: Wed Aug 18 23:08:12 2021 +0200
[SYSTEMDS-3098] Asynchronous early broadcast for Spark Operations
This patch enables asynchronous early broadcast of CP-produced
matrices to the Spark cluster whenever a downstream Spark
operation consumes them as a broadcast input (e.g. mapmm). The
new Broadcast instruction is placed immediately after the
instruction that produces the intermediate to be broadcast later.
Closes #1364
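For context, a minimal DML sketch of the pattern this change targets
(a condensed version of the new BroadcastVar1.dml test below; the
sizes and operations are illustrative only):

    X = rand(rows=10000, cols=200);  # large input, evaluated as a Spark RDD
    v = rand(rows=200, cols=1);      # small vector, produced by CP instructions
    v = (v + v) / 2;                 # CP op creates the intermediate;
                                     # the Broadcast instruction is placed right after it
    sp = X + ceil(X);                # Spark transformation on the large side
    R = sum(sp %*% v);               # mapmm consumes the already-broadcast v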
---
src/main/java/org/apache/sysds/common/Types.java | 2 +-
src/main/java/org/apache/sysds/lops/AppendM.java | 7 +++
src/main/java/org/apache/sysds/lops/Binary.java | 18 ++++++++
src/main/java/org/apache/sysds/lops/LeftIndex.java | 7 +++
src/main/java/org/apache/sysds/lops/Lop.java | 9 ++++
src/main/java/org/apache/sysds/lops/MapMult.java | 9 ++++
.../java/org/apache/sysds/lops/MapMultChain.java | 7 +++
src/main/java/org/apache/sysds/lops/PMMJ.java | 7 +++
.../java/org/apache/sysds/lops/compile/Dag.java | 45 ++++++++++++++++--
.../context/SparkExecutionContext.java | 4 ++
.../runtime/instructions/CPInstructionParser.java | 5 ++
.../instructions/cp/BroadcastCPInstruction.java | 51 ++++++++++++++++++++
.../runtime/instructions/cp/CPInstruction.java | 2 +-
.../instructions/cp/TriggerBroadcastTask.java | 54 ++++++++++++++++++++++
.../java/org/apache/sysds/utils/Statistics.java | 14 +++++-
...refetchRDDTest.java => AsyncBroadcastTest.java} | 42 ++++++++---------
.../test/functions/async/PrefetchRDDTest.java | 3 +-
src/test/scripts/functions/async/BroadcastVar1.dml | 37 +++++++++++++++
src/test/scripts/functions/async/BroadcastVar2.dml | 36 +++++++++++++++
19 files changed, 325 insertions(+), 34 deletions(-)
diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java
index df9e9ef..4e2cef7 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -226,7 +226,7 @@ public class Types
// Operations that require 1 operand
public enum OpOp1 {
- ABS, ACOS, ASIN, ASSERT, ATAN, CAST_AS_SCALAR, CAST_AS_MATRIX,
+	ABS, ACOS, ASIN, ASSERT, ATAN, BROADCAST, CAST_AS_SCALAR, CAST_AS_MATRIX,
CAST_AS_FRAME, CAST_AS_DOUBLE, CAST_AS_INT, CAST_AS_BOOLEAN,
CEIL, CHOLESKY, COS, COSH, CUMMAX, CUMMIN, CUMPROD, CUMSUM,
	CUMSUMPROD, DETECTSCHEMA, COLNAMES, EIGEN, EXISTS, EXP, FLOOR, INVERSE,
diff --git a/src/main/java/org/apache/sysds/lops/AppendM.java b/src/main/java/org/apache/sysds/lops/AppendM.java
index 6836cf8..de8e479 100644
--- a/src/main/java/org/apache/sysds/lops/AppendM.java
+++ b/src/main/java/org/apache/sysds/lops/AppendM.java
@@ -54,6 +54,13 @@ public class AppendM extends Lop
input3.addOutput(this);
lps.setProperties(inputs, ExecType.SPARK);
}
+
+ @Override
+ public Lop getBroadcastInput() {
+ if (getExecType() != ExecType.SPARK)
+ return null;
+ return getInputs().get(1); //frame or matrix
+ }
@Override
public String toString() {
diff --git a/src/main/java/org/apache/sysds/lops/Binary.java b/src/main/java/org/apache/sysds/lops/Binary.java
index 79346ff..202dce7 100644
--- a/src/main/java/org/apache/sysds/lops/Binary.java
+++ b/src/main/java/org/apache/sysds/lops/Binary.java
@@ -21,6 +21,9 @@ package org.apache.sysds.lops;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.runtime.instructions.InstructionUtils;
+
+import java.util.ArrayList;
+
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.ValueType;
@@ -71,6 +74,21 @@ public class Binary extends Lop
return " Operation: " + operation;
}
+ @Override
+ public Lop getBroadcastInput() {
+ if (getExecType() != ExecType.SPARK)
+ return null;
+
+ ArrayList<Lop> inputs = getInputs();
+		if (operation == OpOp2.MAP && inputs.get(0).getDataType() == DataType.MATRIX
+			&& inputs.get(1).getDataType() == DataType.MATRIX)
+			return inputs.get(1);
+		else if (inputs.get(0).getDataType() == DataType.FRAME && inputs.get(1).getDataType() == DataType.MATRIX)
+ return inputs.get(1);
+ else
+ return null;
+ }
+
public OpOp2 getOperationType() {
return operation;
}
diff --git a/src/main/java/org/apache/sysds/lops/LeftIndex.java b/src/main/java/org/apache/sysds/lops/LeftIndex.java
index ae49f59..3c6bf0a 100644
--- a/src/main/java/org/apache/sysds/lops/LeftIndex.java
+++ b/src/main/java/org/apache/sysds/lops/LeftIndex.java
@@ -87,6 +87,13 @@ public class LeftIndex extends Lop
colU.addOutput(this);
lps.setProperties(inputs, et);
}
+
+ @Override
+ public Lop getBroadcastInput() {
+		if (getExecType() != ExecType.SPARK || _type == LixCacheType.NONE)
+ return null;
+ return _type == LixCacheType.LEFT ? getInputs().get(0) : null;
+ }
private String getOpcode() {
if( _type != LixCacheType.NONE )
diff --git a/src/main/java/org/apache/sysds/lops/Lop.java b/src/main/java/org/apache/sysds/lops/Lop.java
index febe07a..e014d3c 100644
--- a/src/main/java/org/apache/sysds/lops/Lop.java
+++ b/src/main/java/org/apache/sysds/lops/Lop.java
@@ -425,6 +425,15 @@ public abstract class Lop
return SparkAggType.NONE;
}
+ /**
+ * Method to get the input to be broadcast.
+	 * This method is overridden by the Lops that require broadcasts (e.g. AppendM).
+	 * @return an input Lop, or null if no broadcast input is needed
+ */
+ public Lop getBroadcastInput() {
+ return null;
+ }
+
/** Method should be overridden if needed
*
diff --git a/src/main/java/org/apache/sysds/lops/MapMult.java b/src/main/java/org/apache/sysds/lops/MapMult.java
index 57b76ef..2e7c72d 100644
--- a/src/main/java/org/apache/sysds/lops/MapMult.java
+++ b/src/main/java/org/apache/sysds/lops/MapMult.java
@@ -92,6 +92,15 @@ public class MapMult extends Lop
public SparkAggType getAggType() {
return _aggtype;
}
+
+ @Override
+ public Lop getBroadcastInput() {
+ if (getExecType() != ExecType.SPARK)
+ return null;
+
+		return _cacheType.isRight() ? getInputs().get(1) : getInputs().get(0);
+ //Note: rdd and broadcast inputs can flip during runtime
+ }
@Override
public String toString() {
diff --git a/src/main/java/org/apache/sysds/lops/MapMultChain.java b/src/main/java/org/apache/sysds/lops/MapMultChain.java
index 32284b9..9f68f3e 100644
--- a/src/main/java/org/apache/sysds/lops/MapMultChain.java
+++ b/src/main/java/org/apache/sysds/lops/MapMultChain.java
@@ -100,6 +100,13 @@ public class MapMultChain extends Lop
}
@Override
+ public Lop getBroadcastInput() {
+ if (getExecType() != ExecType.SPARK)
+ return null;
+ return getInputs().get(1);
+ }
+
+ @Override
	public String getInstructions(String input1, String input2, String output) {
return getInstructions(input1, input2, null, output);
}
diff --git a/src/main/java/org/apache/sysds/lops/PMMJ.java b/src/main/java/org/apache/sysds/lops/PMMJ.java
index 9e44201..e0c8dcf 100644
--- a/src/main/java/org/apache/sysds/lops/PMMJ.java
+++ b/src/main/java/org/apache/sysds/lops/PMMJ.java
@@ -69,6 +69,13 @@ public class PMMJ extends Lop
public String toString() {
return "Operation = PMMJ";
}
+
+ @Override
+ public Lop getBroadcastInput() {
+ if (getExecType() != ExecType.SPARK)
+ return null;
+ return getInputs().get(1);
+ }
@Override
	public String getInstructions(String input_index1, String input_index2, String input_index3, String output_index)
diff --git a/src/main/java/org/apache/sysds/lops/compile/Dag.java b/src/main/java/org/apache/sysds/lops/compile/Dag.java
index 823c14d..9b7f1e5 100644
--- a/src/main/java/org/apache/sysds/lops/compile/Dag.java
+++ b/src/main/java/org/apache/sysds/lops/compile/Dag.java
@@ -199,14 +199,14 @@ public class Dag<N extends Lop>
//doTopologicalSortStrictOrder(nodes) :
doTopologicalSortTwoLevelOrder(nodes);
- // add Prefetch lops to the list, if necessary
- //List<Lop> node_pf = addPrefetchLop(node_v);
+ // add Prefetch and broadcast lops, if necessary
		List<Lop> node_pf = OptimizerUtils.ASYNC_TRIGGER_RDD_OPERATIONS ? addPrefetchLop(node_v) : node_v;
+		List<Lop> node_bc = OptimizerUtils.ASYNC_TRIGGER_RDD_OPERATIONS ? addBroadcastLop(node_pf) : node_pf;
+ // TODO: Merge via a single traversal of the nodes
// do greedy grouping of operations
ArrayList<Instruction> inst =
- //doGreedyGrouping(sb, node_v) :
- doPlainInstructionGen(sb, node_pf);
+ doPlainInstructionGen(sb, node_bc);
// cleanup instruction (e.g., create packed rmvar instructions)
return cleanupInstructions(inst);
@@ -234,6 +234,7 @@ public class Dag<N extends Lop>
nodesWithPrefetch.add(l);
if (isPrefetchNeeded(l)) {
			//TODO: No prefetch if the parent is placed right after the spark OP
+			//or push the parent further to increase parallelism
			List<Lop> oldOuts = new ArrayList<>(l.getOutputs());
			//Construct a Prefetch lop that takes this Spark node as an input
			UnaryCP prefetch = new UnaryCP(l, OpOp1.PREFETCH, l.getDataType(), l.getValueType(), ExecType.CP);
@@ -251,6 +252,30 @@ public class Dag<N extends Lop>
return nodesWithPrefetch;
}
+ private static List<Lop> addBroadcastLop(List<Lop> nodes) {
+ List<Lop> nodesWithBroadcast = new ArrayList<>();
+
+ for (Lop l : nodes) {
+ nodesWithBroadcast.add(l);
+ if (isBroadcastNeeded(l)) {
+			List<Lop> oldOuts = new ArrayList<>(l.getOutputs());
+			//Construct a Broadcast lop that takes this Spark node as an input
+			UnaryCP bc = new UnaryCP(l, OpOp1.BROADCAST, l.getDataType(), l.getValueType(), ExecType.CP);
+			//FIXME: Wire Broadcast only with the necessary outputs
+			for (Lop outCP : oldOuts) {
+				//Rewire l -> outCP to l -> Broadcast -> outCP
+				bc.addOutput(outCP);
+				outCP.replaceInput(l, bc);
+				l.removeOutput(outCP);
+				//FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate)
+			}
+			//Place it immediately after the Spark lop in the node list
+ nodesWithBroadcast.add(bc);
+ }
+ }
+ return nodesWithBroadcast;
+ }
+
	private ArrayList<Instruction> doPlainInstructionGen(StatementBlock sb, List<Lop> nodes)
{
//prepare basic instruction sets
@@ -308,7 +333,17 @@ public class Dag<N extends Lop>
return transformOP && !hasParameterizedOut
			&& lop.isAllOutputsCP() && lop.getDataType() == DataType.MATRIX;
}
-
+
+ private static boolean isBroadcastNeeded(Lop lop) {
+		// Asynchronously broadcast a matrix if it is produced by a CP instruction,
+		// and at least one Spark parent needs to broadcast this intermediate (e.g. mapmm)
+		boolean isBc = lop.getOutputs().stream()
+			.anyMatch(out -> (out.getBroadcastInput() == lop));
+		//TODO: Early broadcast objects that are bigger than a single block
+ //return isCP && isBc && lop.getDataTypes() == DataType.Matrix;
+ return isBc && lop.getDataType() == DataType.MATRIX;
+ }
+
	private static List<Instruction> deleteUpdatedTransientReadVariables(StatementBlock sb, List<Lop> nodeV) {
List<Instruction> insts = new ArrayList<>();
if ( sb == null ) //return modifiable list
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
index 331e54f..bb95fe0 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
@@ -701,6 +701,10 @@ public class SparkExecutionContext extends ExecutionContext
return bret;
}
+
+ public void setBroadcastHandle(MatrixObject mo) {
+ getBroadcastForMatrixObject(mo);
+ }
@SuppressWarnings("unchecked")
	public PartitionedBroadcast<TensorBlock> getBroadcastForTensorObject(TensorObject to) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index 9018aa7..d207a6c 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -34,6 +34,7 @@ import org.apache.sysds.runtime.instructions.cp.AggregateTernaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.AppendCPInstruction;
import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.BroadcastCPInstruction;
import org.apache.sysds.runtime.instructions.cp.BuiltinNaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
@@ -317,6 +318,7 @@ public class CPInstructionParser extends InstructionParser
		String2CPInstructionType.put( "decompress", CPType.DeCompression);
String2CPInstructionType.put( "spoof", CPType.SpoofFused);
String2CPInstructionType.put( "prefetch", CPType.Prefetch);
+ String2CPInstructionType.put( "broadcast", CPType.Broadcast);
String2CPInstructionType.put( "sql", CPType.Sql);
}
@@ -459,6 +461,9 @@ public class CPInstructionParser extends InstructionParser
			case Prefetch:
				return PrefetchCPInstruction.parseInstruction(str);
+
+			case Broadcast:
+				return BroadcastCPInstruction.parseInstruction(str);
			default:
				throw new DMLRuntimeException("Invalid CP Instruction Type: " + cptype );
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
new file mode 100644
index 0000000..51b8ba5
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.cp;
+
+import java.util.concurrent.Executors;
+
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+public class BroadcastCPInstruction extends UnaryCPInstruction {
+	private BroadcastCPInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String istr) {
+ super(CPType.Broadcast, op, in, out, opcode, istr);
+ }
+
+ public static BroadcastCPInstruction parseInstruction (String str) {
+ InstructionUtils.checkNumFields(str, 2);
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+ String opcode = parts[0];
+ CPOperand in = new CPOperand(parts[1]);
+ CPOperand out = new CPOperand(parts[2]);
+ return new BroadcastCPInstruction(null, in, out, opcode, str);
+ }
+
+ @Override
+ public void processInstruction(ExecutionContext ec) {
+ ec.setVariable(output.getName(), ec.getMatrixObject(input1));
+
+ if (SparkUtils.triggerRDDPool == null)
+			SparkUtils.triggerRDDPool = Executors.newCachedThreadPool();
+		SparkUtils.triggerRDDPool.submit(new TriggerBroadcastTask(ec, ec.getMatrixObject(output)));
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
index 3a3840f..47a14b1 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
@@ -46,7 +46,7 @@ public abstract class CPInstruction extends Instruction
MultiReturnParameterizedBuiltin, ParameterizedBuiltin,
MultiReturnBuiltin,
Builtin, Reorg, Variable, FCall, Append, Rand, QSort, QPick,
MatrixIndexing, MMTSJ, PMMJ, MMChain, Reshape, Partition,
Compression, DeCompression, SpoofFused,
-		StringInit, CentralMoment, Covariance, UaggOuterChain, Dnn, Sql, Prefetch }
+		StringInit, CentralMoment, Covariance, UaggOuterChain, Dnn, Sql, Prefetch, Broadcast }
protected final CPType _cptype;
protected final boolean _requiresLabelUpdate;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerBroadcastTask.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerBroadcastTask.java
new file mode 100644
index 0000000..cc1187b
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerBroadcastTask.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.cp;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysds.utils.Statistics;
+
+public class TriggerBroadcastTask implements Runnable {
+ ExecutionContext _ec;
+ MatrixObject _broadcastMO;
+
+ public TriggerBroadcastTask(ExecutionContext ec, MatrixObject mo) {
+ _ec = ec;
+ _broadcastMO = mo;
+ }
+
+ @Override
+ public void run() {
+		// TODO: Synchronization. It is harmless if two threads create separate broadcast
+		// handles, as only one will stay with the MatrixObject. However, redundant
+		// partitioning increases untraced memory usage.
+ try {
+ SparkExecutionContext sec = (SparkExecutionContext)_ec;
+ sec.setBroadcastHandle(_broadcastMO);
+ }
+ catch (Exception e) {
+ e.printStackTrace();
+ }
+
+ if (DMLScript.STATISTICS)
+ Statistics.incSparkAsyncBroadcastCount(1);
+
+ }
+}
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java
index e7850e5..cacd2f7 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -120,6 +120,7 @@ public class Statistics
private static final LongAdder sparkBroadcast = new LongAdder();
private static final LongAdder sparkBroadcastCount = new LongAdder();
	private static final LongAdder sparkAsyncPrefetchCount = new LongAdder();
+	private static final LongAdder sparkAsyncBroadcastCount = new LongAdder();
// Paramserv function stats (time is in milli sec)
private static final Timing psExecutionTimer = new Timing(false);
@@ -490,6 +491,7 @@ public class Statistics
sparkCtxCreateTime = 0;
sparkAsyncPrefetchCount.reset();
+ sparkAsyncBroadcastCount.reset();
lTotalLix.reset();
lTotalLixUIP.reset();
@@ -573,6 +575,10 @@ public class Statistics
sparkAsyncPrefetchCount.add(c);
}
+ public static void incSparkAsyncBroadcastCount(long c) {
+ sparkAsyncBroadcastCount.add(c);
+ }
+
public static void incWorkerNumber() {
psNumWorkers.increment();
}
@@ -972,6 +978,10 @@ public class Statistics
return sparkAsyncPrefetchCount.longValue();
}
+ public static long getAsyncBroadcastCount() {
+ return sparkAsyncBroadcastCount.longValue();
+ }
+
/**
* Returns statistics of the DML program that was recently completed as
a string
* @return statistics as a string
@@ -1078,8 +1088,8 @@ public class Statistics
sparkBroadcast.longValue()*1e-9,
sparkCollect.longValue()*1e-9));
if (OptimizerUtils.ASYNC_TRIGGER_RDD_OPERATIONS)
-				sb.append("Spark async. prefetch count: \t" +
-					String.format("%d.\n", sparkAsyncPrefetchCount.longValue()));
+				sb.append("Spark async. count (pf,bc): \t" +
+					String.format("%d/%d.\n", getAsyncPrefetchCount(), getAsyncBroadcastCount()));
}
if (psNumWorkers.longValue() > 0) {
			sb.append(String.format("Paramserv total execution time:\t%.3f secs.\n", psExecutionTime.doubleValue() / 1000));
diff --git a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java b/src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java
similarity index 73%
copy from src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
copy to src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java
index db9c2f1..088773e 100644
--- a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java
@@ -35,12 +35,12 @@ import org.apache.sysds.utils.Statistics;
import org.junit.Assert;
import org.junit.Test;
-public class PrefetchRDDTest extends AutomatedTestBase {
+public class AsyncBroadcastTest extends AutomatedTestBase {
protected static final String TEST_DIR = "functions/async/";
- protected static final String TEST_NAME = "PrefetchRDD";
- protected static final int TEST_VARIANTS = 3;
-	protected static String TEST_CLASS_DIR = TEST_DIR + PrefetchRDDTest.class.getSimpleName() + "/";
+	protected static final String TEST_NAME = "BroadcastVar";
+	protected static final int TEST_VARIANTS = 2;
+	protected static String TEST_CLASS_DIR = TEST_DIR + AsyncBroadcastTest.class.getSimpleName() + "/";
@Override
public void setUp() {
@@ -50,23 +50,15 @@ public class PrefetchRDDTest extends AutomatedTestBase {
}
@Test
- public void testAsyncSparkOPs1() {
- //Single CP consumer. Prefetch Lop has one output.
+ public void testAsyncBroadcast1() {
runTest(TEST_NAME+"1");
}
@Test
- public void testAsyncSparkOPs2() {
- //Two CP consumers. Prefetch Lop has two outputs.
+ public void testAsyncBroadcast2() {
runTest(TEST_NAME+"2");
}
- @Test
- public void testAsyncSparkOPs3() {
- //SP action type consumer. No Prefetch.
- runTest(TEST_NAME+"3");
- }
-
public void runTest(String testname) {
		boolean old_simplification = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
		boolean old_sum_product = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
@@ -86,6 +78,7 @@ public class PrefetchRDDTest extends AutomatedTestBase {
List<String> proArgs = new ArrayList<>();
+ //proArgs.add("-explain");
proArgs.add("-stats");
proArgs.add("-args");
proArgs.add(output("R"));
@@ -97,17 +90,18 @@ public class PrefetchRDDTest extends AutomatedTestBase {
OptimizerUtils.ASYNC_TRIGGER_RDD_OPERATIONS = true;
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
OptimizerUtils.ASYNC_TRIGGER_RDD_OPERATIONS = false;
-			HashMap<MatrixValue.CellIndex, Double> R_pf = readDMLScalarFromOutputDir("R");
+			HashMap<MatrixValue.CellIndex, Double> R_bc = readDMLScalarFromOutputDir("R");
			//compare matrices
-			TestUtils.compareMatrices(R, R_pf, 1e-6, "Origin", "Reused");
-			//assert Prefetch instructions and number of success.
-			long expected_numPF = !testname.equalsIgnoreCase(TEST_NAME+"3") ? 1 : 0;
-			long expected_successPF = !testname.equalsIgnoreCase(TEST_NAME+"3") ? 1 : 0;
-			long numPF = Statistics.getCPHeavyHitterCount("prefetch");
-			Assert.assertTrue("Violated Prefetch instruction count: "+numPF, numPF == expected_numPF);
-			long successPF = Statistics.getAsyncPrefetchCount();
-			Assert.assertTrue("Violated successful Prefetch count: "+successPF, successPF == expected_successPF);
+			TestUtils.compareMatrices(R, R_bc, 1e-6, "Origin", "withBroadcast");
+
+			//assert called and successful early broadcast counts
+			long expected_numBC = 1;
+			long expected_successBC = 1;
+			long numBC = Statistics.getCPHeavyHitterCount("broadcast");
+			Assert.assertTrue("Violated Broadcast instruction count: "+numBC, numBC == expected_numBC);
+			long successBC = Statistics.getAsyncBroadcastCount();
+			Assert.assertTrue("Violated successful Broadcast count: "+successBC, successBC == expected_successBC);
} finally {
			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = old_simplification;
			OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = old_sum_product;
@@ -117,4 +111,4 @@ public class PrefetchRDDTest extends AutomatedTestBase {
Recompiler.reinitRecompiler();
}
}
-}
\ No newline at end of file
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
index db9c2f1..b625139 100644
--- a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
@@ -86,6 +86,7 @@ public class PrefetchRDDTest extends AutomatedTestBase {
List<String> proArgs = new ArrayList<>();
+ proArgs.add("-explain");
proArgs.add("-stats");
proArgs.add("-args");
proArgs.add(output("R"));
@@ -100,7 +101,7 @@ public class PrefetchRDDTest extends AutomatedTestBase {
			HashMap<MatrixValue.CellIndex, Double> R_pf = readDMLScalarFromOutputDir("R");
			//compare matrices
-			TestUtils.compareMatrices(R, R_pf, 1e-6, "Origin", "Reused");
+			TestUtils.compareMatrices(R, R_pf, 1e-6, "Origin", "withPrefetch");
//assert Prefetch instructions and number of success.
			long expected_numPF = !testname.equalsIgnoreCase(TEST_NAME+"3") ? 1 : 0;
			long expected_successPF = !testname.equalsIgnoreCase(TEST_NAME+"3") ? 1 : 0;
diff --git a/src/test/scripts/functions/async/BroadcastVar1.dml b/src/test/scripts/functions/async/BroadcastVar1.dml
new file mode 100644
index 0000000..777e332
--- /dev/null
+++ b/src/test/scripts/functions/async/BroadcastVar1.dml
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+X = rand(rows=10000, cols=200, seed=42); #sp_rand
+v = rand(rows=200, cols=1, seed=42); #cp_rand
+
+# CP operations
+v = ((v + v) * 1 - v) / (1+1);
+v1 = t(v);
+
+# Spark transformation operations
+sp = X + ceil(X);
+sp = ((sp + sp) * 1 - sp) / (1+1);
+
+# mapmm - broadcast v
+sp2 = sp %*% v;
+
+while(FALSE){}
+R = sum(sp2);
+write(R, $1, format="text");
diff --git a/src/test/scripts/functions/async/BroadcastVar2.dml b/src/test/scripts/functions/async/BroadcastVar2.dml
new file mode 100644
index 0000000..26939b1
--- /dev/null
+++ b/src/test/scripts/functions/async/BroadcastVar2.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+X = rand(rows=10000, cols=200, seed=42); #sp_rand
+v = rand(rows=200, cols=1, seed=42); #cp_rand
+
+# CP operations
+v = ((v + v) * 1 - v) / (1+1);
+
+# Spark transformation operations
+sp = X + ceil(X);
+sp = ((sp + sp) * 1 - sp) / (1+1);
+
+# mappend - broadcast v
+sp2 = rbind(sp, t(v));
+
+while(FALSE){}
+R = sum(sp2);
+write(R, $1, format="text");