[SYSTEMML] Distributed Spark nary min/max operations (compiler/runtime)

This patch adds the compiler and runtime support needed for distributed Spark nary min/max operations. The basic idea is to join all matrix inputs (similar to codegen Spark operations), ship all scalars via task closures, and then run partitioning-preserving block operations.
Furthermore, this also includes a fix of sparse runtime operations (robustness for empty rows) and refactors some existing code to avoid duplications. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/5f580f02 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/5f580f02 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/5f580f02 Branch: refs/heads/master Commit: 5f580f02ec0181f09806140f54b613ca139151d7 Parents: 231a364 Author: Matthias Boehm <[email protected]> Authored: Thu Jun 7 21:34:02 2018 -0700 Committer: Matthias Boehm <[email protected]> Committed: Thu Jun 7 21:34:02 2018 -0700 ---------------------------------------------------------------------- .../context/ExecutionContext.java | 19 +++- .../instructions/SPInstructionParser.java | 4 + .../cp/MatrixBuiltinNaryCPInstruction.java | 23 +--- .../spark/BuiltinNarySPInstruction.java | 108 ++++++++++++++----- .../instructions/spark/SpoofSPInstruction.java | 24 +---- .../spark/functions/MapInputSignature.java | 33 ++++++ .../spark/functions/MapJoinSignature.java | 36 +++++++ .../sysml/runtime/matrix/data/MatrixBlock.java | 1 + .../functions/nary/NaryMinMaxTest.java | 20 ++++ 9 files changed, 201 insertions(+), 67 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/5f580f02/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java index 6807848..d87f9d9 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/ExecutionContext.java @@ -20,7 +20,9 @@ package 
org.apache.sysml.runtime.controlprogram.context; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; +import java.util.stream.Collectors; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -72,7 +74,7 @@ public class ExecutionContext { /** * List of {@link GPUContext}s owned by this {@link ExecutionContext} */ - protected List<GPUContext> _gpuContexts = new ArrayList<>(); + protected List<GPUContext> _gpuContexts = new ArrayList<>(); protected ExecutionContext() { @@ -494,6 +496,21 @@ public class ExecutionContext { setVariable(varName, fo); } + public List<MatrixBlock> getMatrixInputs(CPOperand[] inputs) { + return Arrays.stream(inputs).filter(in -> in.isMatrix()) + .map(in -> getMatrixInput(in.getName())).collect(Collectors.toList()); + } + + public List<ScalarObject> getScalarInputs(CPOperand[] inputs) { + return Arrays.stream(inputs).filter(in -> in.isScalar()) + .map(in -> getScalarInput(in)).collect(Collectors.toList()); + } + + public void releaseMatrixInputs(CPOperand[] inputs) { + Arrays.stream(inputs).filter(in -> in.isMatrix()) + .forEach(in -> releaseMatrixInput(in.getName())); + } + /** * Pin a given list of variables i.e., set the "clean up" state in * corresponding matrix objects, so that the cached data inside these http://git-wip-us.apache.org/repos/asf/systemml/blob/5f580f02/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java index 5a201cb..dd91b9f 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java @@ -262,8 +262,12 @@ public class SPInstructionParser extends InstructionParser String2SPInstructionType.put( 
"rappend", SPType.RAppend); String2SPInstructionType.put( "gappend", SPType.GAppend); String2SPInstructionType.put( "galignedappend", SPType.GAlignedAppend); + String2SPInstructionType.put( "cbind", SPType.BuiltinNary); String2SPInstructionType.put( "rbind", SPType.BuiltinNary); + String2SPInstructionType.put( "nmin", SPType.BuiltinNary); + String2SPInstructionType.put( "nmax", SPType.BuiltinNary); + String2SPInstructionType.put( DataGen.RAND_OPCODE , SPType.Rand); String2SPInstructionType.put( DataGen.SEQ_OPCODE , SPType.Rand); http://git-wip-us.apache.org/repos/asf/systemml/blob/5f580f02/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixBuiltinNaryCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixBuiltinNaryCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixBuiltinNaryCPInstruction.java index 50e3721..6f644ec 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixBuiltinNaryCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/MatrixBuiltinNaryCPInstruction.java @@ -19,9 +19,7 @@ package org.apache.sysml.runtime.instructions.cp; -import java.util.Arrays; import java.util.List; -import java.util.stream.Collectors; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; @@ -37,8 +35,8 @@ public class MatrixBuiltinNaryCPInstruction extends BuiltinNaryCPInstruction { @Override public void processInstruction(ExecutionContext ec) { //separate scalars and matrices and pin all input matrices - List<MatrixBlock> matrices = getMatrices(ec); - List<ScalarObject> scalars = getScalars(ec); + List<MatrixBlock> matrices = ec.getMatrixInputs(inputs); + List<ScalarObject> scalars = ec.getScalarInputs(inputs); MatrixBlock outBlock = null; if( "cbind".equals(getOpcode()) || "rbind".equals(getOpcode()) ) { @@ -56,7 
+54,7 @@ public class MatrixBuiltinNaryCPInstruction extends BuiltinNaryCPInstruction { } //release inputs and set output matrix or scalar - releaseInputs(ec); + ec.releaseMatrixInputs(inputs); if( output.getDataType().isMatrix() ) { ec.setMatrixOutput(output.getName(), outBlock); } @@ -65,19 +63,4 @@ public class MatrixBuiltinNaryCPInstruction extends BuiltinNaryCPInstruction { output.getValueType(), outBlock.quickGetValue(0, 0))); } } - - private List<MatrixBlock> getMatrices(ExecutionContext ec) { - return Arrays.stream(inputs).filter(in -> in.getDataType().isMatrix()) - .map(in -> ec.getMatrixInput(in.getName())).collect(Collectors.toList()); - } - - private List<ScalarObject> getScalars(ExecutionContext ec) { - return Arrays.stream(inputs).filter(in -> in.getDataType().isScalar()) - .map(in -> ec.getScalarInput(in)).collect(Collectors.toList()); - } - - private void releaseInputs(ExecutionContext ec) { - Arrays.stream(inputs).filter(in -> in.getDataType().isMatrix()) - .forEach(in -> ec.releaseMatrixInput(in.getName())); - } } http://git-wip-us.apache.org/repos/asf/systemml/blob/5f580f02/src/main/java/org/apache/sysml/runtime/instructions/spark/BuiltinNarySPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/BuiltinNarySPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/BuiltinNarySPInstruction.java index 583fe96..b0a77fb 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/BuiltinNarySPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/BuiltinNarySPInstruction.java @@ -19,18 +19,26 @@ package org.apache.sysml.runtime.instructions.spark; +import java.util.List; + import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.PairFunction; import 
org.apache.sysml.runtime.controlprogram.context.ExecutionContext; import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; +import org.apache.sysml.runtime.functionobjects.Builtin; import org.apache.sysml.runtime.instructions.InstructionUtils; import org.apache.sysml.runtime.instructions.cp.CPOperand; +import org.apache.sysml.runtime.instructions.cp.ScalarObject; import org.apache.sysml.runtime.instructions.spark.AppendGSPInstruction.ShiftMatrix; +import org.apache.sysml.runtime.instructions.spark.functions.MapInputSignature; +import org.apache.sysml.runtime.instructions.spark.functions.MapJoinSignature; import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils; import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; +import org.apache.sysml.runtime.matrix.operators.SimpleOperator; import org.apache.sysml.runtime.util.UtilFunctions; import scala.Tuple2; @@ -60,53 +68,88 @@ public class BuiltinNarySPInstruction extends SPInstruction @Override public void processInstruction(ExecutionContext ec) { SparkExecutionContext sec = (SparkExecutionContext)ec; - boolean cbind = getOpcode().equals("cbind"); - - //compute output characteristics - MatrixCharacteristics mcOut = computeOutputMatrixCharacteristics(sec, inputs, cbind); - - //get consolidated input via union over shifted and padded inputs - MatrixCharacteristics off = new MatrixCharacteristics( - 0, 0, mcOut.getRowsPerBlock(), mcOut.getColsPerBlock(), 0); JavaPairRDD<MatrixIndexes,MatrixBlock> out = null; - for( CPOperand input : inputs ) { - MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(input.getName()); - JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec - .getBinaryBlockRDDHandleForVariable( input.getName() ) - .flatMapToPair(new ShiftMatrix(off, mcIn, cbind)) - .mapToPair(new 
PadBlocksFunction(mcOut)); //just padding - out = (out != null) ? out.union(in) : in; - updateMatrixCharacteristics(mcIn, off, cbind); - } + MatrixCharacteristics mcOut = null; - //aggregate partially overlapping blocks w/ single shuffle - int numPartOut = SparkUtils.getNumPreferredPartitions(mcOut); - out = RDDAggregateUtils.mergeByKey(out, numPartOut, false); + if( getOpcode().equals("cbind") || getOpcode().equals("rbind") ) { + //compute output characteristics + boolean cbind = getOpcode().equals("cbind"); + mcOut = computeAppendOutputMatrixCharacteristics(sec, inputs, cbind); + + //get consolidated input via union over shifted and padded inputs + MatrixCharacteristics off = new MatrixCharacteristics( + 0, 0, mcOut.getRowsPerBlock(), mcOut.getColsPerBlock(), 0); + for( CPOperand input : inputs ) { + MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(input.getName()); + JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec + .getBinaryBlockRDDHandleForVariable( input.getName() ) + .flatMapToPair(new ShiftMatrix(off, mcIn, cbind)) + .mapToPair(new PadBlocksFunction(mcOut)); //just padding + out = (out != null) ? out.union(in) : in; + updateAppendMatrixCharacteristics(mcIn, off, cbind); + } + + //aggregate partially overlapping blocks w/ single shuffle + int numPartOut = SparkUtils.getNumPreferredPartitions(mcOut); + out = RDDAggregateUtils.mergeByKey(out, numPartOut, false); + } + else if( getOpcode().equals("nmin") || getOpcode().equals("nmax") ) { + //compute output characteristics + mcOut = computeMinMaxOutputMatrixCharacteristics(sec, inputs); + + //get scalars and consolidated input via join + List<ScalarObject> scalars = sec.getScalarInputs(inputs); + JavaPairRDD<MatrixIndexes, MatrixBlock[]> in = null; + for( CPOperand input : inputs ) { + if( !input.getDataType().isMatrix() ) continue; + JavaPairRDD<MatrixIndexes, MatrixBlock> tmp = sec + .getBinaryBlockRDDHandleForVariable(input.getName()); + in = (in == null) ? 
tmp.mapValues(new MapInputSignature()) : + in.join(tmp).mapValues(new MapJoinSignature()); + } + + //compute nary min/max (partitioning-preserving) + out = in.mapValues(new MinMaxFunction(getOpcode(), scalars)); + } //set output RDD and add lineage sec.getMatrixCharacteristics(output.getName()).set(mcOut); sec.setRDDHandleForVariable(output.getName(), out); for( CPOperand input : inputs ) - sec.addLineageRDD(output.getName(), input.getName()); + if( !input.isScalar() ) + sec.addLineageRDD(output.getName(), input.getName()); } - private static MatrixCharacteristics computeOutputMatrixCharacteristics(SparkExecutionContext sec, CPOperand[] inputs, boolean cbind) { + private static MatrixCharacteristics computeAppendOutputMatrixCharacteristics(SparkExecutionContext sec, CPOperand[] inputs, boolean cbind) { MatrixCharacteristics mcIn1 = sec.getMatrixCharacteristics(inputs[0].getName()); MatrixCharacteristics mcOut = new MatrixCharacteristics( 0, 0, mcIn1.getRowsPerBlock(), mcIn1.getColsPerBlock(), 0); for( CPOperand input : inputs ) { MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(input.getName()); - updateMatrixCharacteristics(mcIn, mcOut, cbind); + updateAppendMatrixCharacteristics(mcIn, mcOut, cbind); } return mcOut; } - private static void updateMatrixCharacteristics(MatrixCharacteristics in, MatrixCharacteristics out, boolean cbind) { + private static void updateAppendMatrixCharacteristics(MatrixCharacteristics in, MatrixCharacteristics out, boolean cbind) { out.setDimension(cbind ? Math.max(out.getRows(), in.getRows()) : out.getRows()+in.getRows(), cbind ? out.getCols()+in.getCols() : Math.max(out.getCols(), in.getCols())); out.setNonZeros((out.getNonZeros()!=-1 && in.dimsKnown(true)) ? 
out.getNonZeros()+in.getNonZeros() : -1); } + private static MatrixCharacteristics computeMinMaxOutputMatrixCharacteristics(SparkExecutionContext sec, CPOperand[] inputs) { + MatrixCharacteristics mcOut = new MatrixCharacteristics(); + for( CPOperand input : inputs ) { + if( !input.getDataType().isMatrix() ) continue; + MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(input.getName()); + mcOut.setRows(Math.max(mcOut.getRows(), mcIn.getRows())); + mcOut.setCols(Math.max(mcOut.getCols(), mcIn.getCols())); + mcOut.setRowsPerBlock(mcIn.getRowsPerBlock()); + mcOut.setColsPerBlock(mcIn.getColsPerBlock()); + } + return mcOut; + } + public static class PadBlocksFunction implements PairFunction<Tuple2<MatrixIndexes,MatrixBlock>,MatrixIndexes,MatrixBlock> { private static final long serialVersionUID = 1291358959908299855L; @@ -135,4 +178,21 @@ public class BuiltinNarySPInstruction extends SPInstruction return new Tuple2<>(ix, mb); } } + + private static class MinMaxFunction implements Function<MatrixBlock[], MatrixBlock> { + private static final long serialVersionUID = -4227447915387484397L; + + private final SimpleOperator _op; + private final ScalarObject[] _scalars; + + public MinMaxFunction(String opcode, List<ScalarObject> scalars) { + _scalars = scalars.toArray(new ScalarObject[0]); + _op = new SimpleOperator(Builtin.getBuiltinFnObject(opcode.substring(1))); + } + + @Override + public MatrixBlock call(MatrixBlock[] v1) throws Exception { + return MatrixBlock.naryOperations(_op, v1, _scalars, new MatrixBlock()); + } + } } http://git-wip-us.apache.org/repos/asf/systemml/blob/5f580f02/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java index 15d4de7..8f63427 100644 --- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java @@ -27,9 +27,7 @@ import java.util.LinkedList; import java.util.List; import java.util.stream.IntStream; -import org.apache.commons.lang3.ArrayUtils; import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.api.java.function.PairFunction; @@ -59,6 +57,8 @@ import org.apache.sysml.runtime.instructions.cp.CPOperand; import org.apache.sysml.runtime.instructions.cp.DoubleObject; import org.apache.sysml.runtime.instructions.cp.ScalarObject; import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast; +import org.apache.sysml.runtime.instructions.spark.functions.MapInputSignature; +import org.apache.sysml.runtime.instructions.spark.functions.MapJoinSignature; import org.apache.sysml.runtime.instructions.spark.functions.ReplicateBlockFunction; import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils; import org.apache.sysml.runtime.matrix.MatrixCharacteristics; @@ -352,26 +352,6 @@ public class SpoofSPInstruction extends SPInstruction { } } - private static class MapInputSignature implements Function<MatrixBlock, MatrixBlock[]> - { - private static final long serialVersionUID = -816443970067626102L; - - @Override - public MatrixBlock[] call(MatrixBlock v1) throws Exception { - return new MatrixBlock[]{ v1 }; - } - } - - private static class MapJoinSignature implements Function<Tuple2<MatrixBlock[],MatrixBlock>, MatrixBlock[]> - { - private static final long serialVersionUID = -704403012606821854L; - - @Override - public MatrixBlock[] call(Tuple2<MatrixBlock[], MatrixBlock> v1) throws Exception { - return ArrayUtils.add(v1._1(), v1._2()); - } - } - private static class SpoofFunction 
implements Serializable { private static final long serialVersionUID = 2953479427746463003L; http://git-wip-us.apache.org/repos/asf/systemml/blob/5f580f02/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/MapInputSignature.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/MapInputSignature.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/MapInputSignature.java new file mode 100644 index 0000000..c7fdad6 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/MapInputSignature.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.runtime.instructions.spark.functions; + +import org.apache.spark.api.java.function.Function; + +import org.apache.sysml.runtime.matrix.data.MatrixBlock; + +public class MapInputSignature implements Function<MatrixBlock, MatrixBlock[]> { + private static final long serialVersionUID = -816443970067626102L; + + @Override + public MatrixBlock[] call(MatrixBlock v1) throws Exception { + return new MatrixBlock[]{ v1 }; + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/5f580f02/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/MapJoinSignature.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/MapJoinSignature.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/MapJoinSignature.java new file mode 100644 index 0000000..acfa687 --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/MapJoinSignature.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.runtime.instructions.spark.functions; + +import org.apache.commons.lang3.ArrayUtils; +import org.apache.spark.api.java.function.Function; + +import org.apache.sysml.runtime.matrix.data.MatrixBlock; + +import scala.Tuple2; + +public class MapJoinSignature implements Function<Tuple2<MatrixBlock[],MatrixBlock>, MatrixBlock[]> { + private static final long serialVersionUID = -704403012606821854L; + + @Override + public MatrixBlock[] call(Tuple2<MatrixBlock[], MatrixBlock> v1) throws Exception { + return ArrayUtils.add(v1._1(), v1._2()); + } +} http://git-wip-us.apache.org/repos/asf/systemml/blob/5f580f02/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java index 755bca1..5c285f8 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java @@ -3531,6 +3531,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab continue; if( in.isInSparseFormat() ) { SparseBlock a = in.sparseBlock; + if( a.isEmpty(i) ) continue; int alen = a.size(i); int apos = a.pos(i); int[] aix = a.indexes(i); http://git-wip-us.apache.org/repos/asf/systemml/blob/5f580f02/src/test/java/org/apache/sysml/test/integration/functions/nary/NaryMinMaxTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/nary/NaryMinMaxTest.java b/src/test/java/org/apache/sysml/test/integration/functions/nary/NaryMinMaxTest.java index c5a026a..3ef29c9 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/nary/NaryMinMaxTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/nary/NaryMinMaxTest.java @@ -71,6 
+71,26 @@ public class NaryMinMaxTest extends AutomatedTestBase runMinMaxTest(false, true, ExecType.CP); } + @Test + public void testNaryMinDenseSP() { + runMinMaxTest(true, false, ExecType.SPARK); + } + + @Test + public void testNaryMinSparseSP() { + runMinMaxTest(true, true, ExecType.SPARK); + } + + @Test + public void testNaryMaxDenseSP() { + runMinMaxTest(false, false, ExecType.SPARK); + } + + @Test + public void testNaryMaxSparseSP() { + runMinMaxTest(false, true, ExecType.SPARK); + } + public void runMinMaxTest(boolean min, boolean sparse, ExecType et) { RUNTIME_PLATFORM platformOld = rtplatform;
