This is an automated email from the ASF dual-hosted git repository. dblevins pushed a commit to branch TOMEE-4050 in repository https://gitbox.apache.org/repos/asf/tomee.git
commit 876bf3b9e123e485a1a3488e8de5e1bed0e000dc Author: David Blevins <dblev...@tomitribe.com> AuthorDate: Thu Sep 22 18:50:16 2022 -0700 Initialization and retry tests for CachingSupplier Required for TOMEE-4050: Retry and Refresh for MP JWT keys supplied via HTTP --- .../org/apache/openejb/util/CachedSupplier.java | 6 +- .../apache/openejb/util/CachedSupplierTest.java | 410 ++++++++++++++++++++- 2 files changed, 411 insertions(+), 5 deletions(-) diff --git a/container/openejb-core/src/main/java/org/apache/openejb/util/CachedSupplier.java b/container/openejb-core/src/main/java/org/apache/openejb/util/CachedSupplier.java index 8eec3d6444..45a47e8a35 100644 --- a/container/openejb-core/src/main/java/org/apache/openejb/util/CachedSupplier.java +++ b/container/openejb-core/src/main/java/org/apache/openejb/util/CachedSupplier.java @@ -80,8 +80,10 @@ public class CachedSupplier<T> implements Supplier<T> { @Override public T get() { try { - initialized.await(accessTimeout.getTime(), accessTimeout.getUnit()); - return value.get(); + if (initialized.await(accessTimeout.getTime(), accessTimeout.getUnit())){ + return value.get(); + } + throw new TimeoutException(); } catch (InterruptedException e) { throw new TimeoutException(); } diff --git a/container/openejb-core/src/test/java/org/apache/openejb/util/CachedSupplierTest.java b/container/openejb-core/src/test/java/org/apache/openejb/util/CachedSupplierTest.java index e179bc3acc..ecac615047 100644 --- a/container/openejb-core/src/test/java/org/apache/openejb/util/CachedSupplierTest.java +++ b/container/openejb-core/src/test/java/org/apache/openejb/util/CachedSupplierTest.java @@ -18,9 +18,23 @@ package org.apache.openejb.util; import org.junit.Test; +import java.util.Objects; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; 
+import java.util.stream.Stream; -import static org.junit.Assert.*; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; public class CachedSupplierTest { @@ -34,6 +48,17 @@ public class CachedSupplierTest { */ @Test public void happyPath() { + final AtomicInteger count = new AtomicInteger(); + final Supplier<Integer> supplier = count::incrementAndGet; + final CachedSupplier<Integer> cached = CachedSupplier.of(supplier); + + Runner.threads(100) + .run(() -> assertEquals(1, (int) cached.get())) + .assertNoExceptions() + .assertTimesLessThan(5, MILLISECONDS); + + // Assert the supplier was not called more than once + assertEquals(1, count.get()); } /** @@ -43,6 +68,32 @@ public class CachedSupplierTest { */ @Test public void delayedInitialization() { + final CountDownLatch causeSomeDelays = new CountDownLatch(1); + final AtomicInteger count = new AtomicInteger(); + final Supplier<Integer> supplier = () -> { + await(causeSomeDelays); + sleep(111); + return count.incrementAndGet(); + }; + final CachedSupplier<Integer> cached = CachedSupplier.of(supplier); + + final Runner runner = Runner.threads(100); + + // Run and expect at least 100 ms of delays + runner.pre(causeSomeDelays::countDown) + .run(() -> assertEquals(1, (int) cached.get())) + .assertNoExceptions() + .assertTimesGreaterThan(100, MILLISECONDS) + .assertTimesLessThan(200, MILLISECONDS); + + // Everything is cached now, so runs should be quick + runner.pre(null) + .run(() -> assertEquals(1, (int) cached.get())) + .assertNoExceptions() + .assertTimesLessThan(5, MILLISECONDS); + + // Assert the supplier was not called more than once + assertEquals(1, count.get()); } /** 
@@ -53,7 +104,51 @@ public class CachedSupplierTest { * or any blocking. */ @Test - public void delayedInitializationTimeout() { + public void delayedInitializationTimeout() throws InterruptedException { + final CountDownLatch causeSomeDelays = new CountDownLatch(1); + final CountDownLatch nearlyThere = new CountDownLatch(1); + final AtomicInteger count = new AtomicInteger(); + final Supplier<Integer> supplier = () -> { + await(causeSomeDelays); + sleep(150); + nearlyThere.countDown(); + sleep(50); + try { + return count.incrementAndGet(); + } finally { + nearlyThere.countDown(); + } + }; + + final CachedSupplier<Integer> cached = CachedSupplier.builder(supplier) + .accessTimeout(100, MILLISECONDS) + .build(); + + final Runner runner = Runner.threads(100); + + runner.pre(causeSomeDelays::countDown) + .run(() -> assertEquals(1, (int) cached.get())) + .assertExceptions(CachedSupplier.TimeoutException.class) + .assertTimesGreaterThan(99, MILLISECONDS) + .assertTimesLessThan(150, MILLISECONDS); + + // Wait for the supplier to get near completion + assertTrue(nearlyThere.await(1, MINUTES)); + + runner.pre(null); + + // Calls should now block a bit, but ultimately succeed with no issues + runner.run(() -> assertEquals(1, (int) cached.get())) + .assertNoExceptions() + .assertTimesGreaterThan(30, MILLISECONDS); + + // Calls should now succeed with no delay + runner.run(() -> assertEquals(1, (int) cached.get())) + .assertNoExceptions() + .assertTimesLessThan(5, MILLISECONDS); + + // Assert the supplier was not called more than once + assertEquals(1, count.get()); } /** @@ -64,6 +159,43 @@ public class CachedSupplierTest { */ @Test public void initializationRetry() { + final Long[] calls = new Long[10]; + final AtomicInteger count = new AtomicInteger(); + final Supplier<Integer> supplier = () -> { + final int i = count.incrementAndGet(); + if (i < calls.length) { + calls[i] = System.nanoTime(); + } + + // Return null for the first three calls + // Then return the actual 
value + return i < 4 ? null : i; + }; + + final CachedSupplier<Integer> cached = CachedSupplier.builder(supplier) + .initialRetryDelay(500, MILLISECONDS) + .accessTimeout(1, MINUTES) + .build(); + + final Runner runner = Runner.threads(100); + + runner.run(() -> assertEquals(4, (int) cached.get())) + .assertNoExceptions(); + + final Long[] tries = Stream.of(calls) + .filter(Objects::nonNull) + .toArray(Long[]::new); + + assertEquals(4, tries.length); + assertEquals(4, count.get()); + + long first = NANOSECONDS.toSeconds(tries[1] - tries[0]); + long second = NANOSECONDS.toSeconds(tries[2] - tries[1]); + long third = NANOSECONDS.toSeconds(tries[3] - tries[2]); + + assertEquals(1, first); + assertEquals(2, second); + assertEquals(4, third); } /** @@ -74,6 +206,50 @@ public class CachedSupplierTest { */ @Test public void initializationRetryTillMax() { + final Long[] calls = new Long[10]; + final AtomicInteger count = new AtomicInteger(); + final Supplier<Integer> supplier = () -> { + final int i = count.incrementAndGet(); + if (i < calls.length) { + calls[i] = System.nanoTime(); + } + + // Return null for the first six calls + // Then return the actual value + return i < 7 ? 
null : i; + }; + + final CachedSupplier<Integer> cached = CachedSupplier.builder(supplier) + .initialRetryDelay(500, MILLISECONDS) + .maxRetryDelay(10, SECONDS) + .accessTimeout(1, MINUTES) + .build(); + + final Runner runner = Runner.threads(100); + + runner.run(() -> assertEquals(7, (int) cached.get())) + .assertNoExceptions(); + + final Long[] tries = Stream.of(calls) + .filter(Objects::nonNull) + .toArray(Long[]::new); + + assertEquals(7, tries.length); + assertEquals(7, count.get()); + + long first = NANOSECONDS.toSeconds(tries[1] - tries[0]); + long second = NANOSECONDS.toSeconds(tries[2] - tries[1]); + long third = NANOSECONDS.toSeconds(tries[3] - tries[2]); + long fourth = NANOSECONDS.toSeconds(tries[4] - tries[3]); + long fifth = NANOSECONDS.toSeconds(tries[5] - tries[4]); + long sixth = NANOSECONDS.toSeconds(tries[6] - tries[5]); + + assertEquals(1, first); + assertEquals(2, second); + assertEquals(4, third); + assertEquals(8, fourth); + assertEquals(10, fifth); + assertEquals(10, sixth); } /** @@ -90,7 +266,6 @@ public class CachedSupplierTest { public void refreshReliablyCalled() { } - /** * On the first refresh the Supplier returns null indicating there is * no valid replacement. 
We assert that the previous valid value is @@ -109,4 +284,233 @@ public class CachedSupplierTest { public void refreshFailedWithException() { } + private void sleep(final int millis) { + try { + Thread.sleep(millis); + } catch (InterruptedException e) { + throw new IllegalStateException(e); + } + } + + private void await(final CountDownLatch latch) { + try { + latch.await(); + } catch (InterruptedException e) { + throw new IllegalStateException(e); + } + } + + static class Timer { + private final long start = System.nanoTime(); + + public static Timer start() { + return new Timer(); + } + + public Time time() { + return new Time(System.nanoTime() - start); + } + + public static class Time { + private final long time; + private final String description; + + public Time(final long timeInNanoseconds) { + this.time = timeInNanoseconds; + final long seconds = NANOSECONDS.toSeconds(this.time); + final long milliseconds = NANOSECONDS.toMillis(this.time) - SECONDS.toMillis(seconds); + final long nanoseconds = this.time - SECONDS.toNanos(seconds) - MILLISECONDS.toNanos(milliseconds); + this.description = String.format("%ss, %sms and %sns", seconds, milliseconds, nanoseconds); + } + + public long getTime() { + return time; + } + + public Time assertLessThan(final long time, final TimeUnit unit) { + final long expected = unit.toNanos(time); + final long actual = this.time; + assertTrue("Actual time: " + description, actual < expected); + return this; + } + + public Time assertGreaterThan(final long time, final TimeUnit unit) { + final long expected = unit.toNanos(time); + final long actual = this.time; + assertTrue("Actual time: " + description, actual > expected); + return this; + } + } + } + + public static class Runner { + private final int threads; + private final Executor executor; + private final Duration timeout = new Duration(1, MINUTES); + private Runnable before = null; + + public Runner(final int threads) { + this.threads = threads; + this.executor = 
Executors.newFixedThreadPool(threads, new DaemonThreadFactory(Runner.class)); + } + + public static Runner threads(final int threads) { + return new Runner(threads); + } + + public Runner pre(final Runnable runnable) { + this.before = runnable; + return this; + } + + public Run run(final Runnable runnable) { + final Throwable[] failures = new Throwable[threads]; + final Timer.Time[] times = new Timer.Time[threads]; + + /* + * You won't immediately understand these CountDownLatches (look down). + * + * Here's the deal: when you launch 100+ threads in a loop as we're + * about to do it can take 25+ milliseconds. By the time you get to + * your 99th thread, the previous 50 are all probably gone. The + * thread-creation overhead messes with all your timings and threads + * are executing somewhat serially with very little parallelism. + * + * The latches fix this by forcing all the threads to truly run + * at the same time. + * + * Imagine each thread is a runner in a race. What we want is + * each runner to get on the racetrack, into the starting + * position (ready.countDown) and wait diligently for the sound + * of the starting pistol (start.await) before they start running. + * + * When all runners are in position (ready.await), we fire the starting + * pistol (start.countDown). Awesome, they're all truly running at + * once and competing. + * + * As each runner finishes the race we have them call completed.countDown + * When all runners are finished the completed.await call will unblock + * and we exit this method with all results in hand. + * + * Seems like overkill, but after you've been burned by poor testing + * covering up thread safety issues you learn to do it right. 
+ */ + final CountDownLatch ready = new CountDownLatch(threads); + final CountDownLatch start = new CountDownLatch(1); + final CountDownLatch completed = new CountDownLatch(threads); + + for (int submitted = 0; submitted < threads; submitted++) { + final int id = submitted; + executor.execute(new Runnable() { + @Override + public void run() { + ready.countDown(); + try { + start.await(); + } catch (InterruptedException e) { + return; + } + + /* + * If there's anything we'd like to execute + * that shouldn't be included in the timings, + * do it now. + */ + if (before != null) before.run(); + + /* + * Run, Forrest! Run!! + */ + final Timer timer = Timer.start(); + try { + runnable.run(); + } catch (Throwable t) { + failures[id] = t; + } finally { + times[id] = timer.time(); + completed.countDown(); + } + } + }); + } + + // wait for the above threads to be ready + await(ready, "ready"); + + // fire the starting pistol + start.countDown(); + + // wait for them to finish the race + await(completed, "completed"); + + return new Run(threads, failures, times); + } + + private void await(final CountDownLatch latch, final String state) { + try { + if (!latch.await(timeout.getTime(), timeout.getUnit())) { + fail(String.format("%s of %s threads not %s after %s", + state, + threads - latch.getCount(), + threads, + timeout + )); + } + } catch (InterruptedException e) { + fail(String.format("Interrupted while waiting %s state", "ready")); + } + } + + public static class Run { + final int threads; + final Throwable[] exceptions; + final Timer.Time[] times; + + + public Run(final int threads, final Throwable[] exceptions, final Timer.Time[] times) { + this.threads = threads; + this.exceptions = exceptions; + this.times = times; + } + + public Run assertNoExceptions() { + final long failed = Stream.of(exceptions) + .filter(Objects::nonNull) + .peek(Throwable::printStackTrace) + .count(); + if (failed > 0) { + final long succeeded = threads - failed; + fail(String.format("Succeeded: 
%s, Failed: %s", succeeded, failed)); + } + return this; + } + + public Run assertExceptions(final Class<? extends Throwable> expected) { + for (final Throwable actual : exceptions) { + assertNotNull(actual); + try { + assertEquals(expected, actual.getClass()); + } catch (AssertionError e) { + actual.printStackTrace(); + throw e; + } + } + return this; + } + + public Run assertTimesLessThan(final long time, final TimeUnit unit) { + for (final Timer.Time t : times) { + t.assertLessThan(time, unit); + } + return this; + } + + public Run assertTimesGreaterThan(final long time, final TimeUnit unit) { + for (final Timer.Time t : times) { + t.assertGreaterThan(time, unit); + } + return this; + } + } + } }