http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/KafkaStreamingTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/KafkaStreamingTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/KafkaStreamingTest.java deleted file mode 100644 index 05340d6..0000000 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/KafkaStreamingTest.java +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.spark.streaming; - -import com.google.cloud.dataflow.sdk.Pipeline; -import com.google.cloud.dataflow.sdk.coders.KvCoder; -import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; -import com.google.cloud.dataflow.sdk.testing.DataflowAssert; -import com.google.cloud.dataflow.sdk.transforms.DoFn; -import com.google.cloud.dataflow.sdk.transforms.ParDo; -import com.google.cloud.dataflow.sdk.transforms.View; -import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; -import com.google.cloud.dataflow.sdk.transforms.windowing.Window; -import com.google.cloud.dataflow.sdk.values.KV; -import com.google.cloud.dataflow.sdk.values.PCollection; - -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; -import org.apache.beam.runners.spark.io.KafkaIO; -import org.apache.beam.runners.spark.EvaluationResult; -import org.apache.beam.runners.spark.SparkPipelineRunner; -import org.apache.beam.runners.spark.streaming.utils.DataflowAssertStreaming; -import org.apache.beam.runners.spark.streaming.utils.EmbeddedKafkaCluster; - -import org.apache.kafka.clients.producer.KafkaProducer; -import org.apache.kafka.clients.producer.ProducerRecord; -import org.apache.kafka.common.serialization.Serializer; -import org.apache.kafka.common.serialization.StringSerializer; -import org.joda.time.Duration; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; - -import java.io.IOException; -import java.util.Collections; -import java.util.Map; -import java.util.Properties; -import java.util.Set; - -import kafka.serializer.StringDecoder; - -/** - * Test Kafka as input. - */ -public class KafkaStreamingTest { - private static final EmbeddedKafkaCluster.EmbeddedZookeeper EMBEDDED_ZOOKEEPER = - new EmbeddedKafkaCluster.EmbeddedZookeeper(17001); - private static final EmbeddedKafkaCluster EMBEDDED_KAFKA_CLUSTER = - new EmbeddedKafkaCluster(EMBEDDED_ZOOKEEPER.getConnection(), - new Properties(), Collections.singletonList(6667)); - private static final String TOPIC = "kafka_dataflow_test_topic"; - private static final Map<String, String> KAFKA_MESSAGES = ImmutableMap.of( - "k1", "v1", "k2", "v2", "k3", "v3", "k4", "v4" - ); - private static final Set<String> EXPECTED = ImmutableSet.of( - "k1,v1", "k2,v2", "k3,v3", "k4,v4" - ); - private static final long TEST_TIMEOUT_MSEC = 1000L; - - @BeforeClass - public static void init() throws IOException { - EMBEDDED_ZOOKEEPER.startup(); - EMBEDDED_KAFKA_CLUSTER.startup(); - - // write to Kafka - Properties producerProps = new Properties(); - producerProps.putAll(EMBEDDED_KAFKA_CLUSTER.getProps()); - producerProps.put("request.required.acks", 1); - producerProps.put("bootstrap.servers", EMBEDDED_KAFKA_CLUSTER.getBrokerList()); - Serializer<String> stringSerializer = new StringSerializer(); - try (@SuppressWarnings("unchecked") KafkaProducer<String, String> kafkaProducer = - new KafkaProducer(producerProps, stringSerializer, stringSerializer)) { - for (Map.Entry<String, String> en : KAFKA_MESSAGES.entrySet()) { - kafkaProducer.send(new ProducerRecord<>(TOPIC, en.getKey(), en.getValue())); - } - } - } - - @Test - public void testRun() throws Exception { - // test read from Kafka - SparkStreamingPipelineOptions options = SparkStreamingPipelineOptionsFactory.create(); - options.setAppName(this.getClass().getSimpleName()); - options.setRunner(SparkPipelineRunner.class); - options.setTimeout(TEST_TIMEOUT_MSEC);// run for one interval - Pipeline p = Pipeline.create(options); - - Map<String, String> kafkaParams = ImmutableMap.of( - "metadata.broker.list", EMBEDDED_KAFKA_CLUSTER.getBrokerList(), - "auto.offset.reset", "smallest" - ); - - PCollection<KV<String, String>> kafkaInput = p.apply(KafkaIO.Read.from(StringDecoder.class, - StringDecoder.class, String.class, String.class, Collections.singleton(TOPIC), - kafkaParams)) - .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())); - PCollection<KV<String, String>> windowedWords = kafkaInput - .apply(Window.<KV<String, String>>into(FixedWindows.of(Duration.standardSeconds(1)))); - - PCollection<String> formattedKV = windowedWords.apply(ParDo.of(new FormatKVFn())); - - DataflowAssert.thatIterable(formattedKV.apply(View.<String>asIterable())) - .containsInAnyOrder(EXPECTED); - - EvaluationResult res = SparkPipelineRunner.create(options).run(p); - res.close(); - - DataflowAssertStreaming.assertNoFailures(res); - } - - @AfterClass - public static void tearDown() { - EMBEDDED_KAFKA_CLUSTER.shutdown(); - EMBEDDED_ZOOKEEPER.shutdown(); - } - - private static class FormatKVFn extends DoFn<KV<String, String>, String> { - @Override - public void processElement(ProcessContext c) { - c.output(c.element().getKey() + "," + c.element().getValue()); - } - } - -}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/SimpleStreamingWordCountTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/SimpleStreamingWordCountTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/SimpleStreamingWordCountTest.java deleted file mode 100644 index 16b145a..0000000 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/SimpleStreamingWordCountTest.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.spark.streaming; - -import com.google.cloud.dataflow.sdk.Pipeline; -import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; -import com.google.cloud.dataflow.sdk.testing.DataflowAssert; -import com.google.cloud.dataflow.sdk.transforms.View; -import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; -import com.google.cloud.dataflow.sdk.transforms.windowing.Window; -import com.google.cloud.dataflow.sdk.values.PCollection; -import com.google.common.collect.ImmutableSet; - -import org.apache.beam.runners.spark.io.CreateStream; -import org.apache.beam.runners.spark.EvaluationResult; -import org.apache.beam.runners.spark.SimpleWordCountTest; -import org.apache.beam.runners.spark.SparkPipelineRunner; -import org.apache.beam.runners.spark.streaming.utils.DataflowAssertStreaming; - -import org.joda.time.Duration; -import org.junit.Test; - -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Set; - -public class SimpleStreamingWordCountTest { - - private static final String[] WORDS_ARRAY = { - "hi there", "hi", "hi sue bob", "hi sue", "", "bob hi"}; - private static final List<Iterable<String>> WORDS_QUEUE = - Collections.<Iterable<String>>singletonList(Arrays.asList(WORDS_ARRAY)); - private static final Set<String> EXPECTED_COUNT_SET = - ImmutableSet.of("hi: 5", "there: 1", "sue: 2", "bob: 2"); - private static final long TEST_TIMEOUT_MSEC = 1000L; - - @Test - public void testRun() throws Exception { - SparkStreamingPipelineOptions options = SparkStreamingPipelineOptionsFactory.create(); - options.setAppName(this.getClass().getSimpleName()); - options.setRunner(SparkPipelineRunner.class); - options.setTimeout(TEST_TIMEOUT_MSEC);// run for one interval - Pipeline p = Pipeline.create(options); - - PCollection<String> inputWords = - p.apply(CreateStream.fromQueue(WORDS_QUEUE)).setCoder(StringUtf8Coder.of()); - PCollection<String> windowedWords = inputWords - .apply(Window.<String>into(FixedWindows.of(Duration.standardSeconds(1)))); - - PCollection<String> output = windowedWords.apply(new SimpleWordCountTest.CountWords()); - - DataflowAssert.thatIterable(output.apply(View.<String>asIterable())) - .containsInAnyOrder(EXPECTED_COUNT_SET); - - EvaluationResult res = SparkPipelineRunner.create(options).run(p); - res.close(); - - DataflowAssertStreaming.assertNoFailures(res); - } -} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/utils/DataflowAssertStreaming.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/utils/DataflowAssertStreaming.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/utils/DataflowAssertStreaming.java deleted file mode 100644 index 367a062..0000000 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/utils/DataflowAssertStreaming.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.spark.streaming.utils; - -import org.apache.beam.runners.spark.EvaluationResult; - -import org.junit.Assert; - -/** - * Since DataflowAssert doesn't propagate assert exceptions, use Aggregators to assert streaming - * success/failure counters. - */ -public final class DataflowAssertStreaming { - /** - * Copied aggregator names from {@link com.google.cloud.dataflow.sdk.testing.DataflowAssert} - */ - static final String SUCCESS_COUNTER = "DataflowAssertSuccess"; - static final String FAILURE_COUNTER = "DataflowAssertFailure"; - - private DataflowAssertStreaming() { - } - - public static void assertNoFailures(EvaluationResult res) { - int failures = res.getAggregatorValue(FAILURE_COUNTER, Integer.class); - Assert.assertEquals("Found " + failures + " failures, see the log for details", 0, failures); - } -} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/utils/EmbeddedKafkaCluster.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/utils/EmbeddedKafkaCluster.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/utils/EmbeddedKafkaCluster.java deleted file mode 100644 index 8273684..0000000 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/streaming/utils/EmbeddedKafkaCluster.java +++ /dev/null @@ -1,317 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.spark.streaming.utils; - -import org.apache.zookeeper.server.NIOServerCnxnFactory; -import org.apache.zookeeper.server.ServerCnxnFactory; -import org.apache.zookeeper.server.ZooKeeperServer; - -import java.io.File; -import java.io.FileNotFoundException; -import java.io.IOException; -import java.net.InetSocketAddress; -import java.net.ServerSocket; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Properties; -import java.util.Random; - -import kafka.server.KafkaConfig; -import kafka.server.KafkaServer; -import kafka.utils.Time; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * https://gist.github.com/fjavieralba/7930018 - */ -public class EmbeddedKafkaCluster { - - private static final Logger LOG = LoggerFactory.getLogger(EmbeddedKafkaCluster.class); - - private final List<Integer> ports; - private final String zkConnection; - private final Properties baseProperties; - - private final String brokerList; - - private final List<KafkaServer> brokers; - private final List<File> logDirs; - - public EmbeddedKafkaCluster(String zkConnection) { - this(zkConnection, new Properties()); - } - - public EmbeddedKafkaCluster(String zkConnection, Properties baseProperties) { - this(zkConnection, baseProperties, Collections.singletonList(-1)); - } - - public EmbeddedKafkaCluster(String zkConnection, Properties baseProperties, List<Integer> ports) { - this.zkConnection = zkConnection; - this.ports = resolvePorts(ports); - this.baseProperties = baseProperties; - - this.brokers = new ArrayList<>(); - this.logDirs = new ArrayList<>(); - - this.brokerList = constructBrokerList(this.ports); - } - - private static List<Integer> resolvePorts(List<Integer> ports) { - List<Integer> resolvedPorts = new ArrayList<>(); - for (Integer port : ports) { - resolvedPorts.add(resolvePort(port)); - } - return resolvedPorts; - } - - private static int resolvePort(int port) { - if (port == -1) { - return TestUtils.getAvailablePort(); - } - return port; - } - - private static String constructBrokerList(List<Integer> ports) { - StringBuilder sb = new StringBuilder(); - for (Integer port : ports) { - if (sb.length() > 0) { - sb.append(","); - } - sb.append("localhost:").append(port); - } - return sb.toString(); - } - - public void startup() { - for (int i = 0; i < ports.size(); i++) { - Integer port = ports.get(i); - File logDir = TestUtils.constructTempDir("kafka-local"); - - Properties properties = new Properties(); - properties.putAll(baseProperties); - properties.setProperty("zookeeper.connect", zkConnection); - properties.setProperty("broker.id", String.valueOf(i + 1)); - properties.setProperty("host.name", "localhost"); - properties.setProperty("port", Integer.toString(port)); - properties.setProperty("log.dir", logDir.getAbsolutePath()); - properties.setProperty("log.flush.interval.messages", String.valueOf(1)); - - KafkaServer broker = startBroker(properties); - - brokers.add(broker); - logDirs.add(logDir); - } - } - - - private static KafkaServer startBroker(Properties props) { - KafkaServer server = new KafkaServer(new KafkaConfig(props), new SystemTime()); - server.startup(); - return server; - } - - public Properties getProps() { - Properties props = new Properties(); - props.putAll(baseProperties); - props.put("metadata.broker.list", brokerList); - props.put("zookeeper.connect", zkConnection); - return props; - } - - public String getBrokerList() { - return brokerList; - } - - public List<Integer> getPorts() { - return ports; - } - - public String getZkConnection() { - return zkConnection; - } - - public void shutdown() { - for (KafkaServer broker : brokers) { - try { - broker.shutdown(); - } catch (Exception e) { - LOG.warn("{}", e.getMessage(), e); - } - } - for (File logDir : logDirs) { - try { - TestUtils.deleteFile(logDir); - } catch (FileNotFoundException e) { - LOG.warn("{}", e.getMessage(), e); - } - } - } - - @Override - public String toString() { - return "EmbeddedKafkaCluster{" + "brokerList='" + brokerList + "'}"; - } - - public static class EmbeddedZookeeper { - private int port = -1; - private int tickTime = 500; - - private ServerCnxnFactory factory; - private File snapshotDir; - private File logDir; - - public EmbeddedZookeeper() { - this(-1); - } - - public EmbeddedZookeeper(int port) { - this(port, 500); - } - - public EmbeddedZookeeper(int port, int tickTime) { - this.port = resolvePort(port); - this.tickTime = tickTime; - } - - private static int resolvePort(int port) { - if (port == -1) { - return TestUtils.getAvailablePort(); - } - return port; - } - - public void startup() throws IOException { - if (this.port == -1) { - this.port = TestUtils.getAvailablePort(); - } - this.factory = NIOServerCnxnFactory.createFactory(new InetSocketAddress("localhost", port), - 1024); - this.snapshotDir = TestUtils.constructTempDir("embedded-zk/snapshot"); - this.logDir = TestUtils.constructTempDir("embedded-zk/log"); - - try { - factory.startup(new ZooKeeperServer(snapshotDir, logDir, tickTime)); - } catch (InterruptedException e) { - throw new IOException(e); - } - } - - - public void shutdown() { - factory.shutdown(); - try { - TestUtils.deleteFile(snapshotDir); - } catch (FileNotFoundException e) { - // ignore - } - try { - TestUtils.deleteFile(logDir); - } catch (FileNotFoundException e) { - // ignore - } - } - - public String getConnection() { - return "localhost:" + port; - } - - public void setPort(int port) { - this.port = port; - } - - public void setTickTime(int tickTime) { - this.tickTime = tickTime; - } - - public int getPort() { - return port; - } - - public int getTickTime() { - return tickTime; - } - - @Override - public String toString() { - return "EmbeddedZookeeper{" + "connection=" + getConnection() + "}"; - } - } - - static class SystemTime implements Time { - @Override - public long milliseconds() { - return System.currentTimeMillis(); - } - - @Override - public long nanoseconds() { - return System.nanoTime(); - } - - @Override - public void sleep(long ms) { - try { - Thread.sleep(ms); - } catch (InterruptedException e) { - // Ignore - } - } - } - - static final class TestUtils { - private static final Random RANDOM = new Random(); - - private TestUtils() { - } - - static File constructTempDir(String dirPrefix) { - File file = new File(System.getProperty("java.io.tmpdir"), dirPrefix + RANDOM.nextInt - (10000000)); - if (!file.mkdirs()) { - throw new RuntimeException("could not create temp directory: " + file.getAbsolutePath()); - } - file.deleteOnExit(); - return file; - } - - static int getAvailablePort() { - try { - try (ServerSocket socket = new ServerSocket(0)) { - return socket.getLocalPort(); - } - } catch (IOException e) { - throw new IllegalStateException("Cannot find available port: " + e.getMessage(), e); - } - } - - static boolean deleteFile(File path) throws FileNotFoundException { - if (!path.exists()) { - throw new FileNotFoundException(path.getAbsolutePath()); - } - boolean ret = true; - if (path.isDirectory()) { - for (File f : path.listFiles()) { - ret = ret && deleteFile(f); - } - } - return ret && path.delete(); - } - } -} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/CombineGloballyTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/CombineGloballyTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/CombineGloballyTest.java new file mode 100644 index 0000000..6945d68 --- /dev/null +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/CombineGloballyTest.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.spark.translation; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.collect.Iterables; +import org.apache.beam.runners.spark.EvaluationResult; +import org.apache.beam.runners.spark.SparkPipelineOptions; +import org.apache.beam.runners.spark.SparkPipelineRunner; +import org.junit.Test; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +public class CombineGloballyTest { + + private static final String[] WORDS_ARRAY = { + "hi there", "hi", "hi sue bob", + "hi sue", "", "bob hi"}; + private static final List<String> WORDS = Arrays.asList(WORDS_ARRAY); + + @Test + public void test() throws Exception { + SparkPipelineOptions options = SparkPipelineOptionsFactory.create(); + Pipeline p = Pipeline.create(options); + PCollection<String> inputWords = p.apply(Create.of(WORDS)).setCoder(StringUtf8Coder.of()); + PCollection<String> output = inputWords.apply(Combine.globally(new WordMerger())); + + EvaluationResult res = SparkPipelineRunner.create().run(p); + assertEquals("hi there,hi,hi sue bob,hi sue,,bob hi", Iterables.getOnlyElement(res.get(output))); + res.close(); + } + + public static class WordMerger extends Combine.CombineFn<String, StringBuilder, String> { + + @Override + public StringBuilder createAccumulator() { + // return null to differentiate from an empty string + return null; + } + + @Override + public StringBuilder addInput(StringBuilder accumulator, String input) { + return combine(accumulator, input); + } + + @Override + public StringBuilder mergeAccumulators(Iterable<StringBuilder> accumulators) { + StringBuilder sb = new StringBuilder(); + for (StringBuilder accum : accumulators) { + if (accum != null) { + sb.append(accum); + } + } + return sb; + } + + @Override + public String extractOutput(StringBuilder accumulator) { + return accumulator != null ? accumulator.toString(): ""; + } + + private static StringBuilder combine(StringBuilder accum, String datum) { + if (accum == null) { + return new StringBuilder(datum); + } else { + accum.append(",").append(datum); + return accum; + } + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/CombinePerKeyTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/CombinePerKeyTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/CombinePerKeyTest.java new file mode 100644 index 0000000..0373968 --- /dev/null +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/CombinePerKeyTest.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.spark.translation; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.coders.VarLongCoder; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.transforms.*; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.collect.ImmutableList; +import org.apache.beam.runners.spark.EvaluationResult; +import org.apache.beam.runners.spark.SparkPipelineRunner; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class CombinePerKeyTest { + + private static final List<String> WORDS = + ImmutableList.of("the", "quick", "brown", "fox", "jumped", "over", "the", "lazy", "dog"); + @Test + public void testRun() { + Pipeline p = Pipeline.create(PipelineOptionsFactory.create()); + PCollection<String> inputWords = p.apply(Create.of(WORDS)).setCoder(StringUtf8Coder.of()); + PCollection<KV<String, Long>> cnts = inputWords.apply(new SumPerKey<String>()); + EvaluationResult res = SparkPipelineRunner.create().run(p); + Map<String, Long> actualCnts = new HashMap<>(); + for (KV<String, Long> kv : res.get(cnts)) { + actualCnts.put(kv.getKey(), kv.getValue()); + } + res.close(); + Assert.assertEquals(8, actualCnts.size()); + Assert.assertEquals(Long.valueOf(2L), actualCnts.get("the")); + } + + private static class SumPerKey<T> extends PTransform<PCollection<T>, PCollection<KV<T, Long>>> { + @Override + public PCollection<KV<T, Long>> apply(PCollection<T> pcol) { + PCollection<KV<T, Long>> withLongs = pcol.apply(ParDo.of(new DoFn<T, KV<T, Long>>() { + @Override + public void processElement(ProcessContext processContext) throws Exception { + processContext.output(KV.of(processContext.element(), 1L)); + } + })).setCoder(KvCoder.of(pcol.getCoder(), VarLongCoder.of())); + return withLongs.apply(Sum.<T>longsPerKey()); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/DoFnOutputTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/DoFnOutputTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/DoFnOutputTest.java new file mode 100644 index 0000000..a9779e6 --- /dev/null +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/DoFnOutputTest.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.spark.translation; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.values.PCollection; +import org.apache.beam.runners.spark.EvaluationResult; +import org.apache.beam.runners.spark.SparkPipelineOptions; +import org.apache.beam.runners.spark.SparkPipelineRunner; +import org.junit.Test; + +import java.io.Serializable; + +public class DoFnOutputTest implements Serializable { + @Test + public void test() throws Exception { + SparkPipelineOptions options = SparkPipelineOptionsFactory.create(); + options.setRunner(SparkPipelineRunner.class); + Pipeline pipeline = Pipeline.create(options); + + PCollection<String> strings = pipeline.apply(Create.of("a")); + // Test that values written from startBundle() and finishBundle() are written to + // the output + PCollection<String> output = strings.apply(ParDo.of(new DoFn<String, String>() { + @Override + public void startBundle(Context c) throws Exception { + c.output("start"); + } + @Override + public void processElement(ProcessContext c) throws Exception { + c.output(c.element()); + } + @Override + public void finishBundle(Context c) throws Exception { + c.output("finish"); + } + })); + + DataflowAssert.that(output).containsInAnyOrder("start", "a", "finish"); + + EvaluationResult res = SparkPipelineRunner.create().run(pipeline); + res.close(); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/MultiOutputWordCountTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/MultiOutputWordCountTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/MultiOutputWordCountTest.java new file mode 100644 index 0000000..8ab3798 --- /dev/null +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/MultiOutputWordCountTest.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.spark.translation; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.AggregatorValues; +import com.google.cloud.dataflow.sdk.transforms.*; +import com.google.cloud.dataflow.sdk.values.*; +import com.google.common.collect.Iterables; +import org.apache.beam.runners.spark.EvaluationResult; +import org.apache.beam.runners.spark.SparkPipelineRunner; +import org.junit.Assert; +import org.junit.Test; + +public class MultiOutputWordCountTest { + + private static final TupleTag<String> upper = new TupleTag<>(); + private static final TupleTag<String> lower = new TupleTag<>(); + private static final TupleTag<KV<String, Long>> lowerCnts = new TupleTag<>(); + private static final TupleTag<KV<String, Long>> upperCnts = new TupleTag<>(); + + @Test + public void testRun() throws Exception { + Pipeline p = Pipeline.create(PipelineOptionsFactory.create()); + PCollection<String> regex = p.apply(Create.of("[^a-zA-Z']+")); + PCollection<String> w1 = p.apply(Create.of("Here are some words to count", "and some others")); + PCollection<String> w2 = p.apply(Create.of("Here are some more words", "and even more words")); + PCollectionList<String> list = PCollectionList.of(w1).and(w2); + + PCollection<String> union = list.apply(Flatten.<String>pCollections()); + PCollectionView<String> regexView = regex.apply(View.<String>asSingleton()); + CountWords countWords = new CountWords(regexView); + PCollectionTuple luc = union.apply(countWords); + PCollection<Long> unique = luc.get(lowerCnts).apply( + ApproximateUnique.<KV<String, Long>>globally(16)); + + EvaluationResult res = SparkPipelineRunner.create().run(p); + Iterable<KV<String, Long>> actualLower = res.get(luc.get(lowerCnts)); + Assert.assertEquals("are", actualLower.iterator().next().getKey()); + Iterable<KV<String, Long>> actualUpper = res.get(luc.get(upperCnts)); + Assert.assertEquals("Here", actualUpper.iterator().next().getKey()); + Iterable<Long> actualUniqCount = res.get(unique); + Assert.assertEquals(9, (long) actualUniqCount.iterator().next()); + int actualTotalWords = res.getAggregatorValue("totalWords", Integer.class); + Assert.assertEquals(18, actualTotalWords); + int actualMaxWordLength = res.getAggregatorValue("maxWordLength", Integer.class); + Assert.assertEquals(6, actualMaxWordLength); + AggregatorValues<Integer> aggregatorValues = res.getAggregatorValues(countWords + .getTotalWordsAggregator()); + Assert.assertEquals(18, Iterables.getOnlyElement(aggregatorValues.getValues()).intValue()); + + res.close(); + } + + /** + * A DoFn that tokenizes lines of text into individual words. + */ + static class ExtractWordsFn extends DoFn<String, String> { + + private final Aggregator<Integer, Integer> totalWords = createAggregator("totalWords", + new Sum.SumIntegerFn()); + private final Aggregator<Integer, Integer> maxWordLength = createAggregator("maxWordLength", + new Max.MaxIntegerFn()); + private final PCollectionView<String> regex; + + ExtractWordsFn(PCollectionView<String> regex) { + this.regex = regex; + } + + @Override + public void processElement(ProcessContext c) { + String[] words = c.element().split(c.sideInput(regex)); + for (String word : words) { + totalWords.addValue(1); + if (!word.isEmpty()) { + maxWordLength.addValue(word.length()); + if (Character.isLowerCase(word.charAt(0))) { + c.output(word); + } else { + c.sideOutput(upper, word); + } + } + } + } + } + + public static class CountWords extends PTransform<PCollection<String>, PCollectionTuple> { + + private final PCollectionView<String> regex; + private final ExtractWordsFn extractWordsFn; + + public CountWords(PCollectionView<String> regex) { + this.regex = regex; + this.extractWordsFn = new ExtractWordsFn(regex); + } + + @Override + public PCollectionTuple apply(PCollection<String> lines) { + // Convert lines of text into individual words. + PCollectionTuple lowerUpper = lines + .apply(ParDo.of(extractWordsFn) + .withSideInputs(regex) + .withOutputTags(lower, TupleTagList.of(upper))); + lowerUpper.get(lower).setCoder(StringUtf8Coder.of()); + lowerUpper.get(upper).setCoder(StringUtf8Coder.of()); + PCollection<KV<String, Long>> lowerCounts = lowerUpper.get(lower).apply(Count + .<String>perElement()); + PCollection<KV<String, Long>> upperCounts = lowerUpper.get(upper).apply(Count + .<String>perElement()); + return PCollectionTuple + .of(lowerCnts, lowerCounts) + .and(upperCnts, upperCounts); + } + + Aggregator<Integer, Integer> getTotalWordsAggregator() { + return extractWordsFn.totalWords; + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SerializationTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SerializationTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SerializationTest.java new file mode 100644 index 0000000..b378795 --- /dev/null +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SerializationTest.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.spark.translation; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.AtomicCoder; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.transforms.*; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Function; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; +import org.apache.beam.runners.spark.EvaluationResult; +import org.apache.beam.runners.spark.SparkPipelineOptions; +import org.apache.beam.runners.spark.SparkPipelineRunner; +import org.junit.Test; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.Arrays; +import java.util.List; +import java.util.Set; +import java.util.regex.Pattern; + +public class SerializationTest { + + public static class StringHolder { // not serializable + private final String string; + + public StringHolder(String string) { + this.string = string; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) { + return false; + } + StringHolder that = (StringHolder) o; + return string.equals(that.string); + } + + @Override + public int hashCode() { + return string.hashCode(); + } + + @Override + public String toString() { + return string; + } + } + + public static class StringHolderUtf8Coder extends AtomicCoder<StringHolder> { + + private final StringUtf8Coder stringUtf8Coder = StringUtf8Coder.of(); + + @Override + public void encode(StringHolder value, OutputStream outStream, Context context) throws IOException { + stringUtf8Coder.encode(value.toString(), outStream, context); + } + + @Override + public StringHolder decode(InputStream inStream, Context context) throws IOException { + return new StringHolder(stringUtf8Coder.decode(inStream, context)); + } + + public static Coder<StringHolder> of() { + return new StringHolderUtf8Coder(); + } + } + + private static final String[] WORDS_ARRAY = { + "hi there", "hi", "hi sue bob", + "hi sue", "", "bob hi"}; + private static final List<StringHolder> WORDS = Lists.transform( + Arrays.asList(WORDS_ARRAY), new Function<String, StringHolder>() { + @Override public StringHolder apply(String s) { + return new StringHolder(s); + } + }); + private static final Set<StringHolder> EXPECTED_COUNT_SET = + ImmutableSet.copyOf(Lists.transform( + Arrays.asList("hi: 5", "there: 1", "sue: 2", "bob: 2"), + new Function<String, StringHolder>() { + @Override + public StringHolder apply(String s) { + return new StringHolder(s); + } + })); + + @Test + public void testRun() throws Exception { + SparkPipelineOptions options = SparkPipelineOptionsFactory.create(); + options.setRunner(SparkPipelineRunner.class); + Pipeline p = Pipeline.create(options); + PCollection<StringHolder> inputWords = + p.apply(Create.of(WORDS).withCoder(StringHolderUtf8Coder.of())); + PCollection<StringHolder> output = inputWords.apply(new CountWords()); + + DataflowAssert.that(output).containsInAnyOrder(EXPECTED_COUNT_SET); + + EvaluationResult res = SparkPipelineRunner.create().run(p); + res.close(); + } + + /** + * A DoFn that tokenizes lines of text into individual words. + */ + static class ExtractWordsFn extends DoFn<StringHolder, StringHolder> { + private static final Pattern WORD_BOUNDARY = Pattern.compile("[^a-zA-Z']+"); + private final Aggregator<Long, Long> emptyLines = + createAggregator("emptyLines", new Sum.SumLongFn()); + + @Override + public void processElement(ProcessContext c) { + // Split the line into words. + String[] words = WORD_BOUNDARY.split(c.element().toString()); + + // Keep track of the number of lines without any words encountered while tokenizing. + // This aggregator is visible in the monitoring UI when run using DataflowPipelineRunner. + if (words.length == 0) { + emptyLines.addValue(1L); + } + + // Output each word encountered into the output PCollection. + for (String word : words) { + if (!word.isEmpty()) { + c.output(new StringHolder(word)); + } + } + } + } + + /** + * A DoFn that converts a Word and Count into a printable string. + */ + private static class FormatCountsFn extends DoFn<KV<StringHolder, Long>, StringHolder> { + @Override + public void processElement(ProcessContext c) { + c.output(new StringHolder(c.element().getKey() + ": " + c.element().getValue())); + } + } + + private static class CountWords extends PTransform<PCollection<StringHolder>, PCollection<StringHolder>> { + @Override + public PCollection<StringHolder> apply(PCollection<StringHolder> lines) { + + // Convert lines of text into individual words. + PCollection<StringHolder> words = lines.apply( + ParDo.of(new ExtractWordsFn())); + + // Count the number of times each word occurs. + PCollection<KV<StringHolder, Long>> wordCounts = + words.apply(Count.<StringHolder>perElement()); + + // Format each word and count into a printable string. + + return wordCounts.apply(ParDo.of(new FormatCountsFn())); + } + + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SideEffectsTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SideEffectsTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SideEffectsTest.java new file mode 100644 index 0000000..fc14fc7 --- /dev/null +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/SideEffectsTest.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.spark.translation; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringDelegateCoder; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import org.apache.beam.runners.spark.SparkPipelineOptions; +import org.apache.beam.runners.spark.SparkPipelineRunner; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.Serializable; +import java.net.URI; + +import static org.junit.Assert.*; + +public class SideEffectsTest implements Serializable { + + static class UserException extends RuntimeException { + } + + @Test + public void test() throws Exception { + SparkPipelineOptions options = SparkPipelineOptionsFactory.create(); + options.setRunner(SparkPipelineRunner.class); + Pipeline pipeline = Pipeline.create(options); + + pipeline.getCoderRegistry().registerCoder(URI.class, StringDelegateCoder.of(URI.class)); + + pipeline.apply(Create.of("a")).apply(ParDo.of(new DoFn<String, String>() { + @Override + public void processElement(ProcessContext c) throws Exception { + throw new UserException(); + } + })); + + try { + pipeline.run(); + fail("Run should thrown an exception"); + } catch (RuntimeException e) { + assertNotNull(e.getCause()); + + // TODO: remove the version check (and the setup and teardown methods) when we no + // longer support Spark 1.3 or 1.4 + String version = SparkContextFactory.getSparkContext(options.getSparkMaster(), options.getAppName()).version(); + if (!version.startsWith("1.3.") && !version.startsWith("1.4.")) { + assertTrue(e.getCause() instanceof UserException); + } + } + } + + @Before + public void setup() { + System.setProperty(SparkContextFactory.TEST_REUSE_SPARK_CONTEXT, "true"); + } + + @After + public void teardown() { + System.setProperty(SparkContextFactory.TEST_REUSE_SPARK_CONTEXT, "false"); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TestSparkPipelineOptionsFactory.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TestSparkPipelineOptionsFactory.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TestSparkPipelineOptionsFactory.java new file mode 100644 index 0000000..9cace83 --- /dev/null +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TestSparkPipelineOptionsFactory.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.spark.translation; + +import org.apache.beam.runners.spark.SparkPipelineOptions; +import org.junit.Assert; +import org.junit.Test; + +public class TestSparkPipelineOptionsFactory { + @Test + public void testDefaultCreateMethod() { + SparkPipelineOptions actualOptions = SparkPipelineOptionsFactory.create(); + Assert.assertEquals("local[1]", actualOptions.getSparkMaster()); + } + + @Test + public void testSettingCustomOptions() { + SparkPipelineOptions actualOptions = SparkPipelineOptionsFactory.create(); + actualOptions.setSparkMaster("spark://207.184.161.138:7077"); + Assert.assertEquals("spark://207.184.161.138:7077", actualOptions.getSparkMaster()); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TransformTranslatorTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TransformTranslatorTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TransformTranslatorTest.java new file mode 100644 index 0000000..da30321 --- /dev/null +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/TransformTranslatorTest.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.spark.translation; + +import com.google.api.client.repackaged.com.google.common.base.Joiner; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.base.Charsets; +import org.apache.beam.runners.spark.SparkPipelineRunner; +import org.apache.commons.io.FileUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestName; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Collections; +import java.util.List; + +/** + * A test for the transforms registered in TransformTranslator. + * Builds a regular Dataflow pipeline with each of the mapped + * transforms, and makes sure that they work when the pipeline is + * executed in Spark. + */ +public class TransformTranslatorTest { + + @Rule + public TestName name = new TestName(); + + private DirectPipelineRunner directRunner; + private SparkPipelineRunner sparkRunner; + private String testDataDirName; + + @Before public void init() throws IOException { + sparkRunner = SparkPipelineRunner.create(); + directRunner = DirectPipelineRunner.createForTest(); + testDataDirName = Joiner.on(File.separator).join("target", "test-data", name.getMethodName()) + + File.separator; + FileUtils.deleteDirectory(new File(testDataDirName)); + new File(testDataDirName).mkdirs(); + } + + /** + * Builds a simple pipeline with TextIO.Read and TextIO.Write, runs the pipeline + * in DirectPipelineRunner and on SparkPipelineRunner, with the mapped dataflow-to-spark + * transforms. Finally it makes sure that the results are the same for both runs. + */ + @Test + public void testTextIOReadAndWriteTransforms() throws IOException { + String directOut = runPipeline("direct", directRunner); + String sparkOut = runPipeline("spark", sparkRunner); + + List<String> directOutput = + Files.readAllLines(Paths.get(directOut + "-00000-of-00001"), Charsets.UTF_8); + + List<String> sparkOutput = + Files.readAllLines(Paths.get(sparkOut + "-00000-of-00001"), Charsets.UTF_8); + + // sort output to get a stable result (PCollections are not ordered) + Collections.sort(directOutput); + Collections.sort(sparkOutput); + + Assert.assertArrayEquals(directOutput.toArray(), sparkOutput.toArray()); + } + + private String runPipeline(String name, PipelineRunner<?> runner) { + Pipeline p = Pipeline.create(PipelineOptionsFactory.create()); + String outFile = Joiner.on(File.separator).join(testDataDirName, "test_text_out_" + name); + PCollection<String> lines = p.apply(TextIO.Read.from("src/test/resources/test_text.txt")); + lines.apply(TextIO.Write.to(outFile)); + runner.run(p); + return outFile; + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/WindowedWordCountTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/WindowedWordCountTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/WindowedWordCountTest.java new file mode 100644 index 0000000..9f29a37 --- /dev/null +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/WindowedWordCountTest.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.beam.runners.spark.translation; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.options.PipelineOptionsFactory; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.collect.ImmutableList; +import java.util.Arrays; +import java.util.List; + +import org.apache.beam.runners.spark.EvaluationResult; +import org.apache.beam.runners.spark.SimpleWordCountTest; +import org.apache.beam.runners.spark.SparkPipelineOptions; +import org.apache.beam.runners.spark.SparkPipelineRunner; +import org.joda.time.Duration; +import org.junit.Test; + +public class WindowedWordCountTest { + private static final String[] WORDS_ARRAY = { + "hi there", "hi", "hi sue bob", + "hi sue", "", "bob hi"}; + private static final Long[] TIMESTAMPS_ARRAY = { + 60000L, 60000L, 60000L, + 120000L, 120000L, 120000L}; + private static final List<String> WORDS = Arrays.asList(WORDS_ARRAY); + private static final List<Long> TIMESTAMPS = Arrays.asList(TIMESTAMPS_ARRAY); + private static final List<String> EXPECTED_COUNT_SET = + ImmutableList.of("hi: 3", "there: 1", "sue: 1", "bob: 1", + "hi: 2", "sue: 1", "bob: 1"); + + @Test + public void testRun() throws Exception { + SparkPipelineOptions options = SparkPipelineOptionsFactory.create(); + options.setRunner(SparkPipelineRunner.class); + Pipeline p = Pipeline.create(PipelineOptionsFactory.create()); + PCollection<String> inputWords = p.apply(Create.timestamped(WORDS, TIMESTAMPS)) + .setCoder(StringUtf8Coder.of()); + PCollection<String> windowedWords = inputWords + .apply(Window.<String>into(FixedWindows.of(Duration.standardMinutes(1)))); + + PCollection<String> output = windowedWords.apply(new SimpleWordCountTest.CountWords()); + + DataflowAssert.that(output).containsInAnyOrder(EXPECTED_COUNT_SET); + + EvaluationResult res = SparkPipelineRunner.create().run(p); + res.close(); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/FlattenStreamingTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/FlattenStreamingTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/FlattenStreamingTest.java new file mode 100644 index 0000000..a3eb301 --- /dev/null +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/FlattenStreamingTest.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.translation.streaming; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionList; + +import org.apache.beam.runners.spark.SparkStreamingPipelineOptions; +import org.apache.beam.runners.spark.io.CreateStream; +import org.apache.beam.runners.spark.EvaluationResult; +import org.apache.beam.runners.spark.SparkPipelineRunner; +import org.apache.beam.runners.spark.translation.streaming.utils.DataflowAssertStreaming; + +import org.joda.time.Duration; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * Test Flatten (union) implementation for streaming. + */ +public class FlattenStreamingTest { + + private static final String[] WORDS_ARRAY_1 = { + "one", "two", "three", "four"}; + private static final List<Iterable<String>> WORDS_QUEUE_1 = + Collections.<Iterable<String>>singletonList(Arrays.asList(WORDS_ARRAY_1)); + private static final String[] WORDS_ARRAY_2 = { + "five", "six", "seven", "eight"}; + private static final List<Iterable<String>> WORDS_QUEUE_2 = + Collections.<Iterable<String>>singletonList(Arrays.asList(WORDS_ARRAY_2)); + private static final String[] EXPECTED_UNION = { + "one", "two", "three", "four", "five", "six", "seven", "eight"}; + private static final long TEST_TIMEOUT_MSEC = 1000L; + + @Test + public void testRun() throws Exception { + SparkStreamingPipelineOptions options = SparkStreamingPipelineOptionsFactory.create(); + options.setAppName(this.getClass().getSimpleName()); + options.setRunner(SparkPipelineRunner.class); + options.setTimeout(TEST_TIMEOUT_MSEC);// run for one interval + Pipeline p = Pipeline.create(options); + + PCollection<String> w1 = + p.apply(CreateStream.fromQueue(WORDS_QUEUE_1)).setCoder(StringUtf8Coder.of()); + PCollection<String> windowedW1 = + w1.apply(Window.<String>into(FixedWindows.of(Duration.standardSeconds(1)))); + PCollection<String> w2 = + p.apply(CreateStream.fromQueue(WORDS_QUEUE_2)).setCoder(StringUtf8Coder.of()); + PCollection<String> windowedW2 = + w2.apply(Window.<String>into(FixedWindows.of(Duration.standardSeconds(1)))); + PCollectionList<String> list = PCollectionList.of(windowedW1).and(windowedW2); + PCollection<String> union = list.apply(Flatten.<String>pCollections()); + + DataflowAssert.thatIterable(union.apply(View.<String>asIterable())) + .containsInAnyOrder(EXPECTED_UNION); + + EvaluationResult res = SparkPipelineRunner.create(options).run(p); + res.close(); + + DataflowAssertStreaming.assertNoFailures(res); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/KafkaStreamingTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/KafkaStreamingTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/KafkaStreamingTest.java new file mode 100644 index 0000000..628fe86 --- /dev/null +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/KafkaStreamingTest.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.translation.streaming; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.apache.beam.runners.spark.SparkStreamingPipelineOptions; +import org.apache.beam.runners.spark.io.KafkaIO; +import org.apache.beam.runners.spark.EvaluationResult; +import org.apache.beam.runners.spark.SparkPipelineRunner; +import org.apache.beam.runners.spark.translation.streaming.utils.DataflowAssertStreaming; +import org.apache.beam.runners.spark.translation.streaming.utils.EmbeddedKafkaCluster; + +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.joda.time.Duration; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.Properties; +import java.util.Set; + +import kafka.serializer.StringDecoder; + +/** + * Test Kafka as input. + */ +public class KafkaStreamingTest { + private static final EmbeddedKafkaCluster.EmbeddedZookeeper EMBEDDED_ZOOKEEPER = + new EmbeddedKafkaCluster.EmbeddedZookeeper(17001); + private static final EmbeddedKafkaCluster EMBEDDED_KAFKA_CLUSTER = + new EmbeddedKafkaCluster(EMBEDDED_ZOOKEEPER.getConnection(), + new Properties(), Collections.singletonList(6667)); + private static final String TOPIC = "kafka_dataflow_test_topic"; + private static final Map<String, String> KAFKA_MESSAGES = ImmutableMap.of( + "k1", "v1", "k2", "v2", "k3", "v3", "k4", "v4" + ); + private static final Set<String> EXPECTED = ImmutableSet.of( + "k1,v1", "k2,v2", "k3,v3", "k4,v4" + ); + private static final long TEST_TIMEOUT_MSEC = 1000L; + + @BeforeClass + public static void init() throws IOException { + EMBEDDED_ZOOKEEPER.startup(); + EMBEDDED_KAFKA_CLUSTER.startup(); + + // write to Kafka + Properties producerProps = new Properties(); + producerProps.putAll(EMBEDDED_KAFKA_CLUSTER.getProps()); + producerProps.put("request.required.acks", 1); + producerProps.put("bootstrap.servers", EMBEDDED_KAFKA_CLUSTER.getBrokerList()); + Serializer<String> stringSerializer = new StringSerializer(); + try (@SuppressWarnings("unchecked") KafkaProducer<String, String> kafkaProducer = + new KafkaProducer(producerProps, stringSerializer, stringSerializer)) { + for (Map.Entry<String, String> en : KAFKA_MESSAGES.entrySet()) { + kafkaProducer.send(new ProducerRecord<>(TOPIC, en.getKey(), en.getValue())); + } + } + } + + @Test + public void testRun() throws Exception { + // test read from Kafka + SparkStreamingPipelineOptions options = SparkStreamingPipelineOptionsFactory.create(); + options.setAppName(this.getClass().getSimpleName()); + options.setRunner(SparkPipelineRunner.class); + options.setTimeout(TEST_TIMEOUT_MSEC);// run for one interval + Pipeline p = Pipeline.create(options); + + Map<String, String> kafkaParams = ImmutableMap.of( + "metadata.broker.list", EMBEDDED_KAFKA_CLUSTER.getBrokerList(), + "auto.offset.reset", "smallest" + ); + + PCollection<KV<String, String>> kafkaInput = p.apply(KafkaIO.Read.from(StringDecoder.class, + StringDecoder.class, String.class, String.class, Collections.singleton(TOPIC), + kafkaParams)) + .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of())); + PCollection<KV<String, String>> windowedWords = kafkaInput + .apply(Window.<KV<String, String>>into(FixedWindows.of(Duration.standardSeconds(1)))); + + PCollection<String> formattedKV = windowedWords.apply(ParDo.of(new FormatKVFn())); + + DataflowAssert.thatIterable(formattedKV.apply(View.<String>asIterable())) + .containsInAnyOrder(EXPECTED); + + EvaluationResult res = SparkPipelineRunner.create(options).run(p); + res.close(); + + DataflowAssertStreaming.assertNoFailures(res); + } + + @AfterClass + public static void tearDown() { + EMBEDDED_KAFKA_CLUSTER.shutdown(); + EMBEDDED_ZOOKEEPER.shutdown(); + } + + private static class FormatKVFn extends DoFn<KV<String, String>, String> { + @Override + public void processElement(ProcessContext c) { + c.output(c.element().getKey() + "," + c.element().getValue()); + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/SimpleStreamingWordCountTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/SimpleStreamingWordCountTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/SimpleStreamingWordCountTest.java new file mode 100644 index 0000000..b591510 --- /dev/null +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/SimpleStreamingWordCountTest.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.translation.streaming; + +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder; +import com.google.cloud.dataflow.sdk.testing.DataflowAssert; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.transforms.windowing.FixedWindows; +import com.google.cloud.dataflow.sdk.transforms.windowing.Window; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.common.collect.ImmutableSet; + +import org.apache.beam.runners.spark.SparkStreamingPipelineOptions; +import org.apache.beam.runners.spark.io.CreateStream; +import org.apache.beam.runners.spark.EvaluationResult; +import org.apache.beam.runners.spark.SimpleWordCountTest; +import org.apache.beam.runners.spark.SparkPipelineRunner; +import org.apache.beam.runners.spark.translation.streaming.utils.DataflowAssertStreaming; + +import org.joda.time.Duration; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Set; + +public class SimpleStreamingWordCountTest { + + private static final String[] WORDS_ARRAY = { + "hi there", "hi", "hi sue bob", "hi sue", "", "bob hi"}; + private static final List<Iterable<String>> WORDS_QUEUE = + Collections.<Iterable<String>>singletonList(Arrays.asList(WORDS_ARRAY)); + private static final Set<String> EXPECTED_COUNT_SET = + ImmutableSet.of("hi: 5", "there: 1", "sue: 2", "bob: 2"); + private static final long TEST_TIMEOUT_MSEC = 1000L; + + @Test + public void testRun() throws Exception { + SparkStreamingPipelineOptions options = SparkStreamingPipelineOptionsFactory.create(); + options.setAppName(this.getClass().getSimpleName()); + options.setRunner(SparkPipelineRunner.class); + options.setTimeout(TEST_TIMEOUT_MSEC);// run for one interval + Pipeline p = Pipeline.create(options); + + PCollection<String> inputWords = + p.apply(CreateStream.fromQueue(WORDS_QUEUE)).setCoder(StringUtf8Coder.of()); + PCollection<String> windowedWords = inputWords + .apply(Window.<String>into(FixedWindows.of(Duration.standardSeconds(1)))); + + PCollection<String> output = windowedWords.apply(new SimpleWordCountTest.CountWords()); + + DataflowAssert.thatIterable(output.apply(View.<String>asIterable())) + .containsInAnyOrder(EXPECTED_COUNT_SET); + + EvaluationResult res = SparkPipelineRunner.create(options).run(p); + res.close(); + + DataflowAssertStreaming.assertNoFailures(res); + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/DataflowAssertStreaming.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/DataflowAssertStreaming.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/DataflowAssertStreaming.java new file mode 100644 index 0000000..30673dd --- /dev/null +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/DataflowAssertStreaming.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.translation.streaming.utils; + +import org.apache.beam.runners.spark.EvaluationResult; + +import org.junit.Assert; + +/** + * Since DataflowAssert doesn't propagate assert exceptions, use Aggregators to assert streaming + * success/failure counters. + */ +public final class DataflowAssertStreaming { + /** + * Copied aggregator names from {@link com.google.cloud.dataflow.sdk.testing.DataflowAssert} + */ + static final String SUCCESS_COUNTER = "DataflowAssertSuccess"; + static final String FAILURE_COUNTER = "DataflowAssertFailure"; + + private DataflowAssertStreaming() { + } + + public static void assertNoFailures(EvaluationResult res) { + int failures = res.getAggregatorValue(FAILURE_COUNTER, Integer.class); + Assert.assertEquals("Found " + failures + " failures, see the log for details", 0, failures); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/eb0341d4/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/EmbeddedKafkaCluster.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/EmbeddedKafkaCluster.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/EmbeddedKafkaCluster.java new file mode 100644 index 0000000..e967cdb --- /dev/null +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/streaming/utils/EmbeddedKafkaCluster.java @@ -0,0 +1,317 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.spark.translation.streaming.utils; + +import org.apache.zookeeper.server.NIOServerCnxnFactory; +import org.apache.zookeeper.server.ServerCnxnFactory; +import org.apache.zookeeper.server.ZooKeeperServer; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Properties; +import java.util.Random; + +import kafka.server.KafkaConfig; +import kafka.server.KafkaServer; +import kafka.utils.Time; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * https://gist.github.com/fjavieralba/7930018 + */ +public class EmbeddedKafkaCluster { + + private static final Logger LOG = LoggerFactory.getLogger(EmbeddedKafkaCluster.class); + + private final List<Integer> ports; + private final String zkConnection; + private final Properties baseProperties; + + private final String brokerList; + + private final List<KafkaServer> brokers; + private final List<File> logDirs; + + public EmbeddedKafkaCluster(String zkConnection) { + this(zkConnection, new Properties()); + } + + public EmbeddedKafkaCluster(String zkConnection, Properties baseProperties) { + this(zkConnection, baseProperties, Collections.singletonList(-1)); + } + + public EmbeddedKafkaCluster(String zkConnection, Properties baseProperties, List<Integer> ports) { + this.zkConnection = zkConnection; + this.ports = resolvePorts(ports); + this.baseProperties = baseProperties; + + this.brokers = new ArrayList<>(); + this.logDirs = new ArrayList<>(); + + this.brokerList = constructBrokerList(this.ports); + } + + private static List<Integer> resolvePorts(List<Integer> ports) { + List<Integer> resolvedPorts = new ArrayList<>(); + for (Integer port : ports) { + resolvedPorts.add(resolvePort(port)); + } + return resolvedPorts; + } + + private static int resolvePort(int port) { + if (port == -1) { + return TestUtils.getAvailablePort(); + } + return port; + } + + private static String constructBrokerList(List<Integer> ports) { + StringBuilder sb = new StringBuilder(); + for (Integer port : ports) { + if (sb.length() > 0) { + sb.append(","); + } + sb.append("localhost:").append(port); + } + return sb.toString(); + } + + public void startup() { + for (int i = 0; i < ports.size(); i++) { + Integer port = ports.get(i); + File logDir = TestUtils.constructTempDir("kafka-local"); + + Properties properties = new Properties(); + properties.putAll(baseProperties); + properties.setProperty("zookeeper.connect", zkConnection); + properties.setProperty("broker.id", String.valueOf(i + 1)); + properties.setProperty("host.name", "localhost"); + properties.setProperty("port", Integer.toString(port)); + properties.setProperty("log.dir", logDir.getAbsolutePath()); + properties.setProperty("log.flush.interval.messages", String.valueOf(1)); + + KafkaServer broker = startBroker(properties); + + brokers.add(broker); + logDirs.add(logDir); + } + } + + + private static KafkaServer startBroker(Properties props) { + KafkaServer server = new KafkaServer(new KafkaConfig(props), new SystemTime()); + server.startup(); + return server; + } + + public Properties getProps() { + Properties props = new Properties(); + props.putAll(baseProperties); + props.put("metadata.broker.list", brokerList); + props.put("zookeeper.connect", zkConnection); + return props; + } + + public String getBrokerList() { + return brokerList; + } + + public List<Integer> getPorts() { + return ports; + } + + public String getZkConnection() { + return zkConnection; + } + + public void shutdown() { + for (KafkaServer broker : brokers) { + try { + broker.shutdown(); + } catch (Exception e) { + LOG.warn("{}", e.getMessage(), e); + } + } + for (File logDir : logDirs) { + try { + TestUtils.deleteFile(logDir); + } catch (FileNotFoundException e) { + LOG.warn("{}", e.getMessage(), e); + } + } + } + + @Override + public String toString() { + return "EmbeddedKafkaCluster{" + "brokerList='" + brokerList + "'}"; + } + + public static class EmbeddedZookeeper { + private int port = -1; + private int tickTime = 500; + + private ServerCnxnFactory factory; + private File snapshotDir; + private File logDir; + + public EmbeddedZookeeper() { + this(-1); + } + + public EmbeddedZookeeper(int port) { + this(port, 500); + } + + public EmbeddedZookeeper(int port, int tickTime) { + this.port = resolvePort(port); + this.tickTime = tickTime; + } + + private static int resolvePort(int port) { + if (port == -1) { + return TestUtils.getAvailablePort(); + } + return port; + } + + public void startup() throws IOException { + if (this.port == -1) { + this.port = TestUtils.getAvailablePort(); + } + this.factory = NIOServerCnxnFactory.createFactory(new InetSocketAddress("localhost", port), + 1024); + this.snapshotDir = TestUtils.constructTempDir("embedded-zk/snapshot"); + this.logDir = TestUtils.constructTempDir("embedded-zk/log"); + + try { + factory.startup(new ZooKeeperServer(snapshotDir, logDir, tickTime)); + } catch (InterruptedException e) { + throw new IOException(e); + } + } + + + public void shutdown() { + factory.shutdown(); + try { + TestUtils.deleteFile(snapshotDir); + } catch (FileNotFoundException e) { + // ignore + } + try { + TestUtils.deleteFile(logDir); + } catch (FileNotFoundException e) { + // ignore + } + } + + public String getConnection() { + return "localhost:" + port; + } + + public void setPort(int port) { + this.port = port; + } + + public void setTickTime(int tickTime) { + this.tickTime = tickTime; + } + + public int getPort() { + return port; + } + + public int getTickTime() { + return tickTime; + } + + @Override + public String toString() { + return "EmbeddedZookeeper{" + "connection=" + getConnection() + "}"; + } + } + + static class SystemTime implements Time { + @Override + public long milliseconds() { + return System.currentTimeMillis(); + } + + @Override + public long nanoseconds() { + return System.nanoTime(); + } + + @Override + public void sleep(long ms) { + try { + Thread.sleep(ms); + } catch (InterruptedException e) { + // Ignore + } + } + } + + static final class TestUtils { + private static final Random RANDOM = new Random(); + + private TestUtils() { + } + + static File constructTempDir(String dirPrefix) { + File file = new File(System.getProperty("java.io.tmpdir"), dirPrefix + RANDOM.nextInt + (10000000)); + if (!file.mkdirs()) { + throw new RuntimeException("could not create temp directory: " + file.getAbsolutePath()); + } + file.deleteOnExit(); + return file; + } + + static int getAvailablePort() { + try { + try (ServerSocket socket = new ServerSocket(0)) { + return socket.getLocalPort(); + } + } catch (IOException e) { + throw new IllegalStateException("Cannot find available port: " + e.getMessage(), e); + } + } + + static boolean deleteFile(File path) throws FileNotFoundException { + if (!path.exists()) { + throw new FileNotFoundException(path.getAbsolutePath()); + } + boolean ret = true; + if (path.isDirectory()) { + for (File f : path.listFiles()) { + ret = ret && deleteFile(f); + } + } + return ret && path.delete(); + } + } +}
