Make JAXBCoder Thread Safe. JAXB Marshaller and Unmarshaller instances are not thread safe, but coders are required to be.
Create context only once, and create Marshaller and Unmarshaller locally on-demand. Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/071d9c7e Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/071d9c7e Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/071d9c7e Branch: refs/heads/master Commit: 071d9c7ee8ed3a9628924de624d4e3cb03c5f20b Parents: 6f52ff9 Author: Thomas Groh <[email protected]> Authored: Mon Jun 20 09:40:31 2016 -0700 Committer: Dan Halperin <[email protected]> Committed: Mon Jun 20 11:26:32 2016 -0700 ---------------------------------------------------------------------- .../org/apache/beam/sdk/coders/JAXBCoder.java | 28 +++++--- .../apache/beam/sdk/coders/JAXBCoderTest.java | 69 +++++++++++++++++++- 2 files changed, 84 insertions(+), 13 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/071d9c7e/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/JAXBCoder.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/JAXBCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/JAXBCoder.java index 6fc8fcf..f90eb54 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/JAXBCoder.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/JAXBCoder.java @@ -46,8 +46,7 @@ import javax.xml.bind.Unmarshaller; public class JAXBCoder<T> extends AtomicCoder<T> { private final Class<T> jaxbClass; - private transient Marshaller jaxbMarshaller = null; - private transient Unmarshaller jaxbUnmarshaller = null; + private transient JAXBContext jaxbContext; public Class<T> getJAXBClass() { return jaxbClass; @@ -70,10 +69,9 @@ public class JAXBCoder<T> extends AtomicCoder<T> { public void encode(T value, OutputStream outStream, Context context) 
throws CoderException, IOException { try { - if (jaxbMarshaller == null) { - JAXBContext jaxbContext = JAXBContext.newInstance(jaxbClass); - jaxbMarshaller = jaxbContext.createMarshaller(); - } + JAXBContext jaxbContext = getContext(); + // TODO: Consider caching in a ThreadLocal if this impacts performance + Marshaller jaxbMarshaller = jaxbContext.createMarshaller(); if (!context.isWholeStream) { try { long size = getEncodedElementByteSize(value, Context.OUTER); @@ -95,10 +93,9 @@ public class JAXBCoder<T> extends AtomicCoder<T> { @Override public T decode(InputStream inStream, Context context) throws CoderException, IOException { try { - if (jaxbUnmarshaller == null) { - JAXBContext jaxbContext = JAXBContext.newInstance(jaxbClass); - jaxbUnmarshaller = jaxbContext.createUnmarshaller(); - } + JAXBContext jaxbContext = getContext(); + // TODO: Consider caching in a ThreadLocal if this impacts performance + Unmarshaller jaxbUnmarshaller = jaxbContext.createUnmarshaller(); InputStream stream = inStream; if (!context.isWholeStream) { @@ -113,6 +110,17 @@ public class JAXBCoder<T> extends AtomicCoder<T> { } } + private final JAXBContext getContext() throws JAXBException { + if (jaxbContext == null) { + synchronized (this) { + if (jaxbContext == null) { + jaxbContext = JAXBContext.newInstance(jaxbClass); + } + } + } + return jaxbContext; + } + @Override public String getEncodingId() { return getJAXBClass().getName(); http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/071d9c7e/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/JAXBCoderTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/JAXBCoderTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/JAXBCoderTest.java index 1a00417..6b59e52 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/JAXBCoderTest.java +++ 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/JAXBCoderTest.java @@ -17,12 +17,15 @@ */ package org.apache.beam.sdk.coders; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + import org.apache.beam.sdk.testing.CoderProperties; import org.apache.beam.sdk.util.CoderUtils; +import org.apache.beam.sdk.util.SerializableUtils; import com.google.common.collect.ImmutableList; -import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -31,6 +34,11 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import javax.xml.bind.annotation.XmlRootElement; @@ -91,7 +99,15 @@ public class JAXBCoderTest { JAXBCoder<TestType> coder = JAXBCoder.of(TestType.class); byte[] encoded = CoderUtils.encodeToByteArray(coder, new TestType("abc", 9999)); - Assert.assertEquals(new TestType("abc", 9999), CoderUtils.decodeFromByteArray(coder, encoded)); + assertEquals(new TestType("abc", 9999), CoderUtils.decodeFromByteArray(coder, encoded)); + } + + @Test + public void testEncodeDecodeAfterClone() throws Exception { + JAXBCoder<TestType> coder = SerializableUtils.clone(JAXBCoder.of(TestType.class)); + + byte[] encoded = CoderUtils.encodeToByteArray(coder, new TestType("abc", 9999)); + assertEquals(new TestType("abc", 9999), CoderUtils.decodeFromByteArray(coder, encoded)); } @Test @@ -100,10 +116,57 @@ public class JAXBCoderTest { TestCoder nesting = new TestCoder(jaxbCoder); byte[] encoded = CoderUtils.encodeToByteArray(nesting, new TestType("abc", 9999)); - Assert.assertEquals( + assertEquals( new TestType("abc", 9999), CoderUtils.decodeFromByteArray(nesting, encoded)); } + @Test + public void testEncodeDecodeMultithreaded() 
throws Throwable { + final JAXBCoder<TestType> coder = JAXBCoder.of(TestType.class); + int numThreads = 1000; + + final CountDownLatch ready = new CountDownLatch(numThreads); + final CountDownLatch start = new CountDownLatch(1); + final CountDownLatch done = new CountDownLatch(numThreads); + + final AtomicReference<Throwable> thrown = new AtomicReference<>(); + + Executor executor = Executors.newCachedThreadPool(); + for (int i = 0; i < numThreads; i++) { + final TestType elem = new TestType("abc", i); + final int index = i; + executor.execute( + new Runnable() { + @Override + public void run() { + ready.countDown(); + try { + start.await(); + } catch (InterruptedException e) { + } + + try { + byte[] encoded = CoderUtils.encodeToByteArray(coder, elem); + assertEquals( + new TestType("abc", index), CoderUtils.decodeFromByteArray(coder, encoded)); + } catch (Throwable e) { + thrown.compareAndSet(null, e); + } + done.countDown(); + } + }); + } + ready.await(); + start.countDown(); + + if (!done.await(10L, TimeUnit.SECONDS)) { + fail("Should be able to clone " + numThreads + " elements in 10 seconds"); + } + if (thrown.get() != null) { + throw thrown.get(); + } + } + /** * A coder that surrounds the value with two values, to demonstrate nesting. */
