This is an automated email from the ASF dual-hosted git repository.

glauesppen pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/incubator-wayang.git
commit c8a2e260d254cfbe9a69bc7e170637954fa77e66
Author: composer <[email protected]>
AuthorDate: Tue Oct 24 09:51:54 2023 +0800

    feat: naive kmeans operator on sparkml
---
 .../wayang/basic/operators/KMeansOperator.java     |  38 +++++++
 .../main/java/org/apache/wayang/spark/Spark.java   |  12 +++
 .../org/apache/wayang/spark/mapping/Mappings.java  |   5 +
 .../wayang/spark/mapping/ml/KMeansMapping.java     |  38 +++++++
 .../spark/operators/ml/SparkKMeansOperator.java    | 119 +++++++++++++++++++++
 .../apache/wayang/spark/plugin/SparkMLPlugin.java  |  35 ++++++
 .../spark/operators/SparkKMeansOperatorTest.java   |  40 +++++++
 .../wayang-spark/wayang-spark_2.12/pom.xml         |   5 +
 .../apache/wayang/tests/SparkIntegrationIT.java    |  53 +++++++--
 9 files changed, 335 insertions(+), 10 deletions(-)

diff --git a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/KMeansOperator.java b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/KMeansOperator.java
new file mode 100644
index 00000000..91a048d1
--- /dev/null
+++ b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/operators/KMeansOperator.java
@@ -0,0 +1,38 @@
+package org.apache.wayang.basic.operators;
+
+import org.apache.wayang.basic.data.Tuple2;
+import org.apache.wayang.core.api.Configuration;
+import org.apache.wayang.core.optimizer.cardinality.CardinalityEstimator;
+import org.apache.wayang.core.plan.wayangplan.UnaryToUnaryOperator;
+import org.apache.wayang.core.types.DataSetType;
+
+import java.util.Optional;
+
+public class KMeansOperator extends UnaryToUnaryOperator<double[], Tuple2<double[], Integer>> {
+    // TODO other parameters
+    protected int k;
+
+    public KMeansOperator(int k) {
+        super(DataSetType.createDefaultUnchecked(double[].class),
+                DataSetType.createDefaultUnchecked(Tuple2.class),
+                false);
+        this.k = k;
+    }
+
+    public KMeansOperator(KMeansOperator that) {
+        super(that);
+        this.k = that.k;
+    }
+
+    public int getK() {
+        return k;
+    }
+
+    // TODO support fit and transform
+
+    @Override
+    public Optional<CardinalityEstimator> createCardinalityEstimator(int outputIndex, Configuration configuration) {
+        // TODO
+        return super.createCardinalityEstimator(outputIndex, configuration);
+    }
+}
diff --git a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/Spark.java b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/Spark.java
index 3bd3d911..aefffed9 100644
--- a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/Spark.java
+++ b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/Spark.java
@@ -22,6 +22,7 @@ import org.apache.wayang.spark.platform.SparkPlatform;
 import org.apache.wayang.spark.plugin.SparkBasicPlugin;
 import org.apache.wayang.spark.plugin.SparkConversionPlugin;
 import org.apache.wayang.spark.plugin.SparkGraphPlugin;
+import org.apache.wayang.spark.plugin.SparkMLPlugin;
 
 /**
  * Register for relevant components of this module.
@@ -34,6 +35,8 @@ public class Spark {
 
     private final static SparkConversionPlugin CONVERSION_PLUGIN = new SparkConversionPlugin();
 
+    private final static SparkMLPlugin ML_PLUGIN = new SparkMLPlugin();
+
     /**
      * Retrieve the {@link SparkBasicPlugin}.
      *
@@ -61,6 +64,15 @@ public class Spark {
         return CONVERSION_PLUGIN;
     }
 
+    /**
+     * Retrieve the {@link SparkMLPlugin}.
+     *
+     * @return the {@link SparkMLPlugin}
+     */
+    public static SparkMLPlugin mlPlugin() {
+        return ML_PLUGIN;
+    }
+
     /**
      * Retrieve the {@link SparkPlatform}.
      *
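A minimal registration sketch (not part of this commit; the testKMeans
integration test at the end of this patch does the same thing): once
Spark.mlPlugin() exists, the new ML mappings are activated by registering
the plugin on a WayangContext.

    WayangContext wayangContext = new WayangContext();
    wayangContext.register(Spark.basicPlugin()); // regular Spark operators
    wayangContext.register(Spark.mlPlugin());    // enables the KMeans mapping added below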
diff --git a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/mapping/Mappings.java b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/mapping/Mappings.java
index 046fb280..484b6c30 100644
--- a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/mapping/Mappings.java
+++ b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/mapping/Mappings.java
@@ -20,6 +20,7 @@ package org.apache.wayang.spark.mapping;
 
 import org.apache.wayang.core.mapping.Mapping;
 import org.apache.wayang.spark.mapping.graph.PageRankMapping;
+import org.apache.wayang.spark.mapping.ml.KMeansMapping;
 
 import java.util.Arrays;
 import java.util.Collection;
@@ -63,4 +64,8 @@ public class Mappings {
             new PageRankMapping()
     );
 
+    public static Collection<Mapping> ML_MAPPINGS = Arrays.asList(
+            new KMeansMapping()
+    );
+
 }
diff --git a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/mapping/ml/KMeansMapping.java b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/mapping/ml/KMeansMapping.java
new file mode 100644
index 00000000..3da37d89
--- /dev/null
+++ b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/mapping/ml/KMeansMapping.java
@@ -0,0 +1,38 @@
+package org.apache.wayang.spark.mapping.ml;
+
+import org.apache.wayang.basic.operators.KMeansOperator;
+import org.apache.wayang.core.mapping.*;
+import org.apache.wayang.spark.operators.ml.SparkKMeansOperator;
+import org.apache.wayang.spark.platform.SparkPlatform;
+
+import java.util.Collection;
+import java.util.Collections;
+
+/**
+ * Mapping from {@link KMeansOperator} to {@link SparkKMeansOperator}.
+ */
+@SuppressWarnings("unchecked")
+public class KMeansMapping implements Mapping {
+
+    @Override
+    public Collection<PlanTransformation> getTransformations() {
+        return Collections.singleton(new PlanTransformation(
+                this.createSubplanPattern(),
+                this.createReplacementSubplanFactory(),
+                SparkPlatform.getInstance()
+        ));
+    }
+
+    private SubplanPattern createSubplanPattern() {
+        final OperatorPattern operatorPattern = new OperatorPattern(
+                "kMeans", new KMeansOperator(0), false
+        );
+        return SubplanPattern.createSingleton(operatorPattern);
+    }
+
+    private ReplacementSubplanFactory createReplacementSubplanFactory() {
+        return new ReplacementSubplanFactory.OfSingleOperators<KMeansOperator>(
+                (matchedOperator, epoch) -> new SparkKMeansOperator(matchedOperator).at(epoch)
+        );
+    }
+}
diff --git a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/ml/SparkKMeansOperator.java b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/ml/SparkKMeansOperator.java
new file mode 100644
index 00000000..f4084c03
--- /dev/null
+++ b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/operators/ml/SparkKMeansOperator.java
@@ -0,0 +1,119 @@
+package org.apache.wayang.spark.operators.ml;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.ml.clustering.KMeans;
+import org.apache.spark.ml.clustering.KMeansModel;
+import org.apache.spark.ml.linalg.Vector;
+import org.apache.spark.ml.linalg.Vectors;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.wayang.basic.data.Tuple2;
+import org.apache.wayang.basic.operators.KMeansOperator;
+import org.apache.wayang.core.optimizer.OptimizationContext;
+import org.apache.wayang.core.plan.wayangplan.ExecutionOperator;
+import org.apache.wayang.core.platform.ChannelDescriptor;
+import org.apache.wayang.core.platform.ChannelInstance;
+import org.apache.wayang.core.platform.lineage.ExecutionLineageNode;
+import org.apache.wayang.core.util.Tuple;
+import org.apache.wayang.spark.channels.RddChannel;
+import org.apache.wayang.spark.execution.SparkExecutor;
+import org.apache.wayang.spark.operators.SparkExecutionOperator;
+
+import java.util.*;
+
+public class SparkKMeansOperator extends KMeansOperator implements SparkExecutionOperator {
+
+    public SparkKMeansOperator(int k) {
+        super(k);
+    }
+
+    public SparkKMeansOperator(KMeansOperator that) {
+        super(that);
+    }
+
+    @Override
+    public List<ChannelDescriptor> getSupportedInputChannels(int index) {
+        // TODO need DataFrameChannel?
+        return Arrays.asList(RddChannel.UNCACHED_DESCRIPTOR, RddChannel.CACHED_DESCRIPTOR);
+    }
+
+    @Override
+    public List<ChannelDescriptor> getSupportedOutputChannels(int index) {
+        // TODO need DataFrameChannel?
+        return Collections.singletonList(RddChannel.UNCACHED_DESCRIPTOR);
+    }
+
+    @Override
+    public Tuple<Collection<ExecutionLineageNode>, Collection<ChannelInstance>> evaluate(
+            ChannelInstance[] inputs,
+            ChannelInstance[] outputs,
+            SparkExecutor sparkExecutor,
+            OptimizationContext.OperatorContext operatorContext) {
+        assert inputs.length == this.getNumInputs();
+        assert outputs.length == this.getNumOutputs();
+
+        final RddChannel.Instance input = (RddChannel.Instance) inputs[0];
+        final RddChannel.Instance output = (RddChannel.Instance) outputs[0];
+
+        final JavaRDD<double[]> inputRdd = input.provideRdd();
+        final JavaRDD<Data> dataRdd = inputRdd.map(Data::new);
+        final Dataset<Row> df = SparkSession.builder().getOrCreate().createDataFrame(dataRdd, Data.class);
+        final KMeansModel model = new KMeans()
+                .setK(this.k)
+                .fit(df);
+
+        final Dataset<Row> transform = model.transform(df);
+        final JavaRDD<Tuple2<double[], Integer>> outputRdd = transform.toJavaRDD()
+                .map(row -> new Tuple2<>(((Vector) row.get(0)).toArray(), (Integer) row.get(1)));
+
+        this.name(outputRdd);
+        output.accept(outputRdd, sparkExecutor);
+
+        return ExecutionOperator.modelLazyExecution(inputs, outputs, operatorContext);
+    }
+
+    // TODO support fit and transform
+
+    @Override
+    public boolean containsAction() {
+        return false;
+    }
+
+    public static class Data {
+        private final Vector features;
+
+
+        public Data(Vector features) {
+            this.features = features;
+        }
+
+        public Data(double[] features) {
+            this.features = Vectors.dense(features);
+        }
+
+        public Vector getFeatures() {
+            return features;
+        }
+
+        @Override
+        public String toString() {
+            return "Data{" +
+                    "features=" + features +
+                    '}';
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (!(o instanceof Data)) return false;
+            Data data = (Data) o;
+            return Objects.equals(features, data.features);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(features);
+        }
+    }
+}
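For orientation: evaluate(...) above is a thin wrapper around the standard
Spark ML fit/transform cycle. The following self-contained sketch shows that
cycle in isolation; the class name KMeansSketch and the local[*] master are
illustrative assumptions, everything else is taken from the operator above.

    import org.apache.spark.ml.clustering.KMeans;
    import org.apache.spark.ml.clustering.KMeansModel;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;
    import org.apache.wayang.spark.operators.ml.SparkKMeansOperator;

    import java.util.Arrays;

    public class KMeansSketch {
        public static void main(String[] args) {
            SparkSession spark = SparkSession.builder()
                    .master("local[*]")       // assumption: local experimentation
                    .appName("kmeans-sketch")
                    .getOrCreate();
            // Spark ML reads feature vectors from a "features" column; the
            // Data bean above exposes exactly that bean property.
            Dataset<Row> df = spark.createDataFrame(Arrays.asList(
                    new SparkKMeansOperator.Data(new double[]{1, 2, 3}),
                    new SparkKMeansOperator.Data(new double[]{-1, -2, -3}),
                    new SparkKMeansOperator.Data(new double[]{2, 4, 6})),
                    SparkKMeansOperator.Data.class);
            KMeansModel model = new KMeans().setK(2).fit(df); // fit: learn k centroids
            model.transform(df).show(); // transform: appends a "prediction" column
            spark.stop();
        }
    }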
diff --git a/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/plugin/SparkMLPlugin.java b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/plugin/SparkMLPlugin.java
new file mode 100644
index 00000000..55923040
--- /dev/null
+++ b/wayang-platforms/wayang-spark/src/main/java/org/apache/wayang/spark/plugin/SparkMLPlugin.java
@@ -0,0 +1,35 @@
+package org.apache.wayang.spark.plugin;
+
+import org.apache.wayang.core.api.Configuration;
+import org.apache.wayang.core.mapping.Mapping;
+import org.apache.wayang.core.optimizer.channels.ChannelConversion;
+import org.apache.wayang.core.platform.Platform;
+import org.apache.wayang.core.plugin.Plugin;
+import org.apache.wayang.spark.mapping.Mappings;
+import org.apache.wayang.spark.platform.SparkPlatform;
+
+import java.util.Collection;
+import java.util.Collections;
+
+public class SparkMLPlugin implements Plugin {
+
+    @Override
+    public Collection<Mapping> getMappings() {
+        return Mappings.ML_MAPPINGS;
+    }
+
+    @Override
+    public Collection<ChannelConversion> getChannelConversions() {
+        return Collections.emptyList();
+    }
+
+    @Override
+    public Collection<Platform> getRequiredPlatforms() {
+        return Collections.singletonList(SparkPlatform.getInstance());
+    }
+
+    @Override
+    public void setProperties(Configuration configuration) {
+        // Nothing to do, because we already configured the properties in #configureDefaults(...).
+    }
+}
diff --git a/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkKMeansOperatorTest.java b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkKMeansOperatorTest.java
new file mode 100644
index 00000000..71c51129
--- /dev/null
+++ b/wayang-platforms/wayang-spark/src/test/java/org/apache/wayang/spark/operators/SparkKMeansOperatorTest.java
@@ -0,0 +1,40 @@
+package org.apache.wayang.spark.operators;
+
+import org.apache.wayang.basic.data.Tuple2;
+import org.apache.wayang.core.platform.ChannelInstance;
+import org.apache.wayang.spark.channels.RddChannel;
+import org.apache.wayang.spark.operators.ml.SparkKMeansOperator;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.List;
+
+public class SparkKMeansOperatorTest extends SparkOperatorTestBase {
+    @Test
+    public void testExecution() {
+        // Prepare test data.
+        RddChannel.Instance input = this.createRddChannelInstance(Arrays.asList(
+                new double[]{1, 2, 3},
+                new double[]{-1, -2, -3},
+                new double[]{2, 4, 6}));
+        RddChannel.Instance output = this.createRddChannelInstance();
+
+        SparkKMeansOperator kMeansOperator = new SparkKMeansOperator(2);
+
+        // Set up the ChannelInstances.
+        ChannelInstance[] inputs = new ChannelInstance[]{input};
+        ChannelInstance[] outputs = new ChannelInstance[]{output};
+
+        // Execute.
+        this.evaluate(kMeansOperator, inputs, outputs);
+
+        // Verify the outcome.
+        final List<Tuple2<double[], Integer>> results = output.<Tuple2<double[], Integer>>provideRdd().collect();
+        Assert.assertEquals(3, results.size());
+        Assert.assertEquals(
+                results.get(0).field1,
+                results.get(2).field1
+        );
+    }
+}
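A quick sanity check on the assertion above (added arithmetic, not from the
commit): with k = 2 the expected grouping follows from plain Euclidean
distances, so rows 0 and 2 must receive the same cluster id:

    // d((1,2,3), (2,4,6))    = sqrt(1 + 4 + 9)   ~ 3.74
    // d((1,2,3), (-1,-2,-3)) = sqrt(4 + 16 + 36) ~ 7.48
    // => {1,2,3} and {2,4,6} share a centroid; {-1,-2,-3} gets its own.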
diff --git a/wayang-platforms/wayang-spark/wayang-spark_2.12/pom.xml b/wayang-platforms/wayang-spark/wayang-spark_2.12/pom.xml
index 28731aa2..a706f28f 100644
--- a/wayang-platforms/wayang-spark/wayang-spark_2.12/pom.xml
+++ b/wayang-platforms/wayang-spark/wayang-spark_2.12/pom.xml
@@ -48,6 +48,11 @@
             <artifactId>spark-graphx_2.12</artifactId>
             <version>${spark.version}</version>
         </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-mllib_2.12</artifactId>
+            <version>${spark.version}</version>
+        </dependency>
         <!--Error of ArrayIndexOutOfBoundsException-->
         <dependency>
             <groupId>com.thoughtworks.paranamer</groupId>
diff --git a/wayang-tests-integration/src/test/java/org/apache/wayang/tests/SparkIntegrationIT.java b/wayang-tests-integration/src/test/java/org/apache/wayang/tests/SparkIntegrationIT.java
index 7323f8da..77a6ff45 100644
--- a/wayang-tests-integration/src/test/java/org/apache/wayang/tests/SparkIntegrationIT.java
+++ b/wayang-tests-integration/src/test/java/org/apache/wayang/tests/SparkIntegrationIT.java
@@ -18,12 +18,11 @@
 package org.apache.wayang.tests;
 
-import org.junit.Assert;
-import org.junit.Test;
 
 import org.apache.wayang.basic.WayangBasics;
 import org.apache.wayang.basic.data.Tuple2;
 import org.apache.wayang.basic.operators.CollectionSource;
 import org.apache.wayang.basic.operators.FilterOperator;
+import org.apache.wayang.basic.operators.KMeansOperator;
 import org.apache.wayang.basic.operators.LocalCallbackSink;
 import org.apache.wayang.core.api.Configuration;
 import org.apache.wayang.core.api.Job;
@@ -34,21 +33,17 @@
 import org.apache.wayang.core.function.PredicateDescriptor;
 import org.apache.wayang.core.plan.wayangplan.WayangPlan;
 import org.apache.wayang.core.types.DataSetType;
 import org.apache.wayang.core.util.WayangCollections;
+import org.apache.wayang.java.Java;
 import org.apache.wayang.spark.Spark;
 import org.apache.wayang.tests.platform.MyMadeUpPlatform;
+import org.junit.Assert;
+import org.junit.Test;
 
 import java.io.IOException;
 import java.net.URISyntaxException;
 import java.nio.file.Files;
 import java.nio.file.Paths;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Set;
+import java.util.*;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
@@ -457,6 +452,44 @@ public class SparkIntegrationIT {
         Assert.assertEquals(expectedValues, collectedValues);
     }
 
+    @Test
+    public void testKMeans() {
+        CollectionSource<double[]> collectionSource = new CollectionSource<>(
+                Arrays.asList(
+                        new double[]{1, 2, 3},
+                        new double[]{-1, -2, -3},
+                        new double[]{2, 4, 6}),
+                double[].class
+        );
+        collectionSource.addTargetPlatform(Java.platform());
+        collectionSource.addTargetPlatform(Spark.platform());
+
+        KMeansOperator kMeansOperator = new KMeansOperator(2);
+
+        // write results to a sink
+        List<Tuple2> results = new ArrayList<>();
+        LocalCallbackSink<Tuple2> sink = LocalCallbackSink.createCollectingSink(results, DataSetType.createDefault(Tuple2.class));
+
+        // Build Wayang plan by connecting operators
+        collectionSource.connectTo(0, kMeansOperator, 0);
+        kMeansOperator.connectTo(0, sink, 0);
+        WayangPlan wayangPlan = new WayangPlan(sink);
+
+        // Have Wayang execute the plan.
+        WayangContext wayangContext = new WayangContext();
+        wayangContext.register(Java.basicPlugin());
+        wayangContext.register(Spark.basicPlugin());
+        wayangContext.register(Spark.mlPlugin());
+        wayangContext.execute(wayangPlan);
+
+        // Verify the outcome.
+        Assert.assertEquals(3, results.size());
+        Assert.assertEquals(
+                ((Tuple2<double[], Integer>) results.get(0)).field1,
+                ((Tuple2<double[], Integer>) results.get(2)).field1
+        );
+    }
+
     private static class SemijoinFunction implements PredicateDescriptor.ExtendedSerializablePredicate<Integer> {
 
         private Set<Integer> allowedInts;
