This is an automated email from the ASF dual-hosted git repository.
He-Pin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/pekko.git
The following commit(s) were added to refs/heads/main by this push:
new 5a4f236db0 feat(stream): add TLS GraphStage engine (#2878)
5a4f236db0 is described below
commit 5a4f236db07af095f9fe82cc1a6f23e9648ae0b6
Author: He-Pin(kerr) <[email protected]>
AuthorDate: Fri May 8 19:57:42 2026 +0800
feat(stream): add TLS GraphStage engine (#2878)
* feat(stream): add TLS GraphStage engine
Motivation:
The stream TLS path still depends on the legacy actor/FanoutProcessor
infrastructure. A GraphStage engine is needed for the stream internals while
preserving the existing Pekko TLSActor SSLEngine state machine semantics
without changing the legacy actor implementation.
Modification:
- Add TlsGraphStage as a GraphStage adapter for the existing Pekko TLS pump
phases.
- Reuse the Pekko TCP direct BufferPool for TLS transport buffers and
allocate application buffers from SSLEngine session sizes.
- Add a pekko.stream.materializer.tls.engine selector with legacy-actor as
the default and graph-stage as the opt-in engine.
- Run the shared TLS regression matrix against both legacy and GraphStage
paths and add focused GraphStage edge-case coverage.
- Add TLS JMH benchmarks for cold handshake and warm round-trip scenarios.
Result:
The GraphStage path is opt-in, the legacy TLSActor remains untouched, and
TLS close, truncation, renegotiation, failure-alert, and TLS 1.3 behavior are
covered by regression tests.
Tests:
- stream / scalafmtCheck
- stream-tests / scalafmtCheck
- bench-jmh / scalafmtCheck
- stream / Test / compile
- stream-tests / Test / testOnly org.apache.pekko.stream.io.TlsSpec
org.apache.pekko.stream.io.TlsGraphStageSpec
org.apache.pekko.stream.io.TlsGraphStageEdgeCasesSpec
org.apache.pekko.stream.io.TlsGraphStageIsolatedSpec
- git diff --check
- red-flag rg scan for prior suspicious port markers
References:
- https://github.com/apache/pekko/pull/2878
- https://github.com/apache/pekko/issues/2860
* fix(stream): address TLS GraphStage review feedback
---
.../scala/org/apache/pekko/util/ByteString.scala | 45 +-
bench-jmh/src/main/resources/keystore | Bin 0 -> 2397 bytes
bench-jmh/src/main/resources/truststore | Bin 0 -> 857 bytes
.../org/apache/pekko/stream/io/TlsBenchmark.scala | 216 ++++
.../stream/io/TlsGraphStageEdgeCasesSpec.scala | 331 ++++++
.../stream/io/TlsGraphStageIsolatedSpec.scala | 238 +++++
.../scala/org/apache/pekko/stream/io/TlsSpec.scala | 39 +-
.../stream/scaladsl/TlsEngineSelectionSpec.scala | 36 +
stream/src/main/resources/reference.conf | 12 +
.../org/apache/pekko/stream/impl/Stages.scala | 1 +
.../org/apache/pekko/stream/impl/io/TLSActor.scala | 36 +-
.../pekko/stream/impl/io/TlsEngineHelpers.scala | 75 ++
.../pekko/stream/impl/io/TlsGraphStage.scala | 1103 ++++++++++++++++++++
.../org/apache/pekko/stream/scaladsl/TLS.scala | 50 +-
14 files changed, 2142 insertions(+), 40 deletions(-)
diff --git a/actor/src/main/scala/org/apache/pekko/util/ByteString.scala
b/actor/src/main/scala/org/apache/pekko/util/ByteString.scala
index ab9b8f4a5f..bf91ab24fd 100644
--- a/actor/src/main/scala/org/apache/pekko/util/ByteString.scala
+++ b/actor/src/main/scala/org/apache/pekko/util/ByteString.scala
@@ -415,6 +415,10 @@ object ByteString {
override def copyToBuffer(buffer: ByteBuffer): Int =
writeToBuffer(buffer, offset = 0)
+ /** INTERNAL API: Specialized for internal use, copying from an offset
without slicing. */
+ private[pekko] override def copyToBuffer(buffer: ByteBuffer, offset: Int):
Int =
+ writeToBuffer(buffer, offset)
+
/** INTERNAL API: Specialized for internal use, writing multiple
ByteString1C into the same ByteBuffer. */
private[pekko] def writeToBuffer(buffer: ByteBuffer, offset: Int): Int = {
val copyLength = Math.max(0, Math.min(buffer.remaining, length - offset))
@@ -552,13 +556,17 @@ object ByteString {
}
override def copyToBuffer(buffer: ByteBuffer): Int =
- writeToBuffer(buffer)
+ writeToBuffer(buffer, offset = 0)
+
+ /** INTERNAL API: Specialized for internal use, copying from an offset
without slicing. */
+ private[pekko] override def copyToBuffer(buffer: ByteBuffer, offset: Int):
Int =
+ writeToBuffer(buffer, offset)
/** INTERNAL API: Specialized for internal use, writing multiple
ByteString1C into the same ByteBuffer. */
- private[pekko] def writeToBuffer(buffer: ByteBuffer): Int = {
- val copyLength = Math.min(buffer.remaining, length)
+ private[pekko] def writeToBuffer(buffer: ByteBuffer, offset: Int): Int = {
+ val copyLength = Math.max(0, Math.min(buffer.remaining, length - offset))
if (copyLength > 0) {
- buffer.put(bytes, startIndex, copyLength)
+ buffer.put(bytes, startIndex + offset, copyLength)
}
copyLength
}
@@ -946,12 +954,28 @@ object ByteString {
def isCompact: Boolean = if (bytestrings.length == 1)
bytestrings.head.isCompact else false
- override def copyToBuffer(buffer: ByteBuffer): Int = {
- val it = bytestrings.iterator
+ override def copyToBuffer(buffer: ByteBuffer): Int =
+ copyToBuffer(buffer, offset = 0)
+
+ /** INTERNAL API: Specialized for internal use, copying from an offset
without slicing. */
+ private[pekko] override def copyToBuffer(buffer: ByteBuffer, offset: Int):
Int = {
+ var remainingOffset = offset
var written = 0
- while (it.hasNext && buffer.hasRemaining) {
- written += it.next().writeToBuffer(buffer)
+ var i = 0
+ val count = bytestrings.length
+
+ while (i < count && buffer.hasRemaining) {
+ val fragment = bytestrings(i)
+ val fragmentLength = fragment.length
+ if (remainingOffset >= fragmentLength) {
+ remainingOffset -= fragmentLength
+ } else {
+ written += fragment.writeToBuffer(buffer, remainingOffset)
+ remainingOffset = 0
+ }
+ i += 1
}
+
written
}
@@ -1772,6 +1796,11 @@ sealed abstract class ByteString
*/
def copyToBuffer(@nowarn("msg=never used") buffer: ByteBuffer): Int
+ /** INTERNAL API: Copy bytes to a ByteBuffer from a ByteString offset
without allocating a slice. */
+ private[pekko] def copyToBuffer(buffer: ByteBuffer, offset: Int): Int =
+ if (offset <= 0) copyToBuffer(buffer)
+ else drop(offset).copyToBuffer(buffer)
+
/**
* Create a new ByteString with all contents compacted into a single,
* full byte array.
diff --git a/bench-jmh/src/main/resources/keystore
b/bench-jmh/src/main/resources/keystore
new file mode 100644
index 0000000000..2b0237562b
Binary files /dev/null and b/bench-jmh/src/main/resources/keystore differ
diff --git a/bench-jmh/src/main/resources/truststore
b/bench-jmh/src/main/resources/truststore
new file mode 100644
index 0000000000..3cc1983600
Binary files /dev/null and b/bench-jmh/src/main/resources/truststore differ
diff --git
a/bench-jmh/src/main/scala/org/apache/pekko/stream/io/TlsBenchmark.scala
b/bench-jmh/src/main/scala/org/apache/pekko/stream/io/TlsBenchmark.scala
new file mode 100644
index 0000000000..e02ee3ada9
--- /dev/null
+++ b/bench-jmh/src/main/scala/org/apache/pekko/stream/io/TlsBenchmark.scala
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.pekko.stream.io
+
+import java.security.{ KeyStore, SecureRandom }
+import java.util.concurrent.TimeUnit
+import javax.net.ssl.{ KeyManagerFactory, SSLContext, SSLEngine, SSLSession,
TrustManagerFactory }
+
+import scala.concurrent.Await
+import scala.concurrent.duration._
+import scala.util.{ Success, Try }
+
+import com.typesafe.config.{ Config, ConfigFactory }
+import org.openjdk.jmh.annotations._
+
+import org.apache.pekko
+import pekko.NotUsed
+import pekko.actor.ActorSystem
+import pekko.stream._
+import pekko.stream.TLSProtocol._
+import pekko.stream.impl.io.{ TlsGraphStage, TlsModule }
+import pekko.stream.scaladsl._
+import pekko.util.ByteString
+
+/**
+ * JMH benchmark comparing the legacy actor-based TLS path (`TlsModule`) to the
+ * GraphStage path (`TlsGraphStage`).
+ *
+ * - `warmRoundTrip` materializes a fresh client+server echo loop per invocation and
+ *   pushes `TlsBenchmark.WarmRoundTripRecords` payloads through it. The single
+ *   handshake is amortized over those records, so the score is dominated by
+ *   per-record encrypt/decrypt overhead.
+ * - `coldHandshake` measures the cost of materializing a fresh client+server pair
+ *   and completing the TLS handshake before transferring a single payload. This
+ *   represents short-lived connections (e.g. HTTPS request/response).
+ *
+ * Both benchmarks fail the iteration if the echoed byte count is never reached, so a
+ * stream that completes early cannot be reported as a (bogus) fast result.
+ *
+ * Run with:
+ * {{{
+ * sbt "bench-jmh/Jmh/run -i 5 -wi 3 -f1 -t1 .*TlsBenchmark.*"
+ * }}}
+ */
+@State(Scope.Benchmark)
+@OutputTimeUnit(TimeUnit.MILLISECONDS)
+@BenchmarkMode(Array(Mode.Throughput))
+@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
+@Measurement(iterations = 10, time = 1, timeUnit = TimeUnit.SECONDS)
+@Fork(1)
+class TlsBenchmark {
+
+  private val config: Config = ConfigFactory.parseString("""
+    pekko {
+      log-config-on-start = off
+      log-dead-letters-during-shutdown = off
+      stdout-loglevel = "OFF"
+      loglevel = "OFF"
+      actor.default-dispatcher {
+        throughput = 1024
+      }
+      actor.default-mailbox {
+        mailbox-type = "org.apache.pekko.dispatch.SingleConsumerOnlyUnboundedMailbox"
+      }
+    }""".stripMargin).withFallback(ConfigFactory.load())
+
+  implicit var system: ActorSystem = _
+  private var sslContext: SSLContext = _
+  private var ciphers: Array[String] = _
+
+  /** Selects the TLS engine under test: "legacy" (`TlsModule`) or "graphstage" (`TlsGraphStage`). */
+  @Param(Array("legacy", "graphstage"))
+  var implementation: String = _
+
+  // 256 B = control message; 4 KiB = typical HTTP request; 64 KiB = streaming chunk
+  @Param(Array("256", "4096", "65536"))
+  var payloadSize: Int = _
+
+  private var payload: ByteString = _
+  private var payloads: scala.collection.immutable.IndexedSeq[SslTlsOutbound] = _
+
+  @Setup
+  def setup(): Unit = {
+    system = ActorSystem("TlsBenchmark", config)
+    // Force materializer creation outside the measured region.
+    SystemMaterializer(system).materializer
+
+    sslContext = TlsBenchmark.initSslContext("TLSv1.2")
+    ciphers = TlsBenchmark.TLS12Ciphers.toArray
+
+    payload = ByteString(Array.fill[Byte](payloadSize)('a'.toByte))
+    payloads = (0 until TlsBenchmark.WarmRoundTripRecords).map(_ => SendBytes(payload))
+  }
+
+  @TearDown
+  def shutdown(): Unit = {
+    Await.result(system.terminate(), 10.seconds)
+  }
+
+  /** Creates a fresh TLSv1.2 engine restricted to the benchmark cipher suites. */
+  private def engine(role: TLSRole): SSLEngine = {
+    val e = sslContext.createSSLEngine()
+    e.setUseClientMode(role == Client)
+    e.setEnabledCipherSuites(ciphers)
+    e.setEnabledProtocols(Array("TLSv1.2"))
+    e
+  }
+
+  /**
+   * Builds the TLS BidiFlow for `role` using the engine selected by `implementation`.
+   *
+   * @throws IllegalArgumentException if `implementation` is not a known engine name
+   */
+  private def makeBidi(
+      role: TLSRole,
+      closing: TLSClosing,
+      verifySession: SSLSession => Try[Unit] = _ => Success(()))
+      : BidiFlow[SslTlsOutbound, ByteString, ByteString, SslTlsInbound, NotUsed] =
+    implementation match {
+      case "legacy" =>
+        BidiFlow.fromGraph(TlsModule(Attributes.none, () => engine(role), verifySession, closing))
+      case "graphstage" =>
+        graphStageBidi(role, closing, verifySession)
+      case other =>
+        throw new IllegalArgumentException(s"Unknown TLS implementation [$other]")
+    }
+
+  private def graphStageBidi(
+      role: TLSRole,
+      closing: TLSClosing,
+      verifySession: SSLSession => Try[Unit])
+      : BidiFlow[SslTlsOutbound, ByteString, ByteString, SslTlsInbound, NotUsed] =
+    BidiFlow
+      .fromGraph(new TlsGraphStage(() => engine(role), verifySession, closing))
+      .withAttributes(TlsGraphStage.StreamTlsAttributes)
+
+  /**
+   * Runs `source` through a fresh client+server pair joined with an echo flow and blocks
+   * until `expectedBytes` plaintext bytes have come back to the client.
+   *
+   * Only a running byte count is kept, so ByteString concatenation is not charged to the
+   * TLS implementation under test.
+   *
+   * @throws IllegalStateException if the stream completes before all bytes are echoed
+   */
+  private def roundTrip(source: Source[SslTlsOutbound, NotUsed], expectedBytes: Int): Unit = {
+    val client = makeBidi(Client, IgnoreComplete)
+    val server = makeBidi(Server, IgnoreComplete)
+    val echo = Flow[SslTlsInbound].collect { case SessionBytes(_, b) => SendBytes(b) }
+
+    val done = source
+      .via(client.atop(server.reversed).join(echo))
+      .collect { case SessionBytes(_, b) => b }
+      .scan(0)((acc, b) => acc + b.size)
+      .dropWhile(_ < expectedBytes)
+      .runWith(Sink.headOption)
+
+    // headOption yields None when the stream completes early; never score that as a success.
+    if (Await.result(done, 30.seconds).isEmpty)
+      throw new IllegalStateException(
+        s"TLS [$implementation] stream completed before echoing $expectedBytes bytes")
+  }
+
+  /**
+   * Warm round-trip: `TlsBenchmark.WarmRoundTripRecords` payloads through a fresh
+   * client+server pair. The handshake is amortized over the records.
+   */
+  @Benchmark
+  @OperationsPerInvocation(TlsBenchmark.WarmRoundTripRecords)
+  def warmRoundTrip(): Unit =
+    roundTrip(Source(payloads), payload.size * TlsBenchmark.WarmRoundTripRecords)
+
+  /**
+   * Cold handshake: each invocation builds a fresh client+server pair and completes
+   * the handshake by exchanging one configured payload.
+   */
+  @Benchmark
+  def coldHandshake(): Unit =
+    roundTrip(Source.single[SslTlsOutbound](SendBytes(payload)), payload.size)
+}
+
+object TlsBenchmark {
+
+  /** Plaintext records sent per `warmRoundTrip` invocation; also its JMH operation count. */
+  final val WarmRoundTripRecords = 1000
+
+  /** TLS 1.2 ECDHE-RSA AES-GCM suites enabled on both benchmark engines. */
+  val TLS12Ciphers: Set[String] = Set(
+    "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
+    "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384")
+
+  /**
+   * Builds an SSLContext whose key and trust material come from the `/keystore` and
+   * `/truststore` classpath resources (password `changeme`).
+   *
+   * @param protocol protocol name passed to `SSLContext.getInstance`, e.g. "TLSv1.2"
+   * @throws IllegalStateException if either resource is missing from the classpath
+   */
+  def initSslContext(protocol: String): SSLContext = {
+    val password = "changeme".toCharArray
+
+    val keyStore = loadKeyStore("/keystore", password)
+    val trustStore = loadKeyStore("/truststore", password)
+
+    val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm)
+    keyManagerFactory.init(keyStore, password)
+
+    val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
+    trustManagerFactory.init(trustStore)
+
+    val context = SSLContext.getInstance(protocol)
+    context.init(keyManagerFactory.getKeyManagers, trustManagerFactory.getTrustManagers, new SecureRandom)
+    context
+  }
+
+  /**
+   * Loads a key store of the platform default type from a classpath resource and closes
+   * the resource stream afterwards.
+   *
+   * `getResourceAsStream` returns null for a missing resource, and `KeyStore.load(null, ...)`
+   * silently initializes an empty store. Failing here avoids an opaque handshake error later.
+   */
+  private def loadKeyStore(resource: String, password: Array[Char]): KeyStore = {
+    val in = getClass.getResourceAsStream(resource)
+    if (in eq null) throw new IllegalStateException(s"Missing classpath resource [$resource]")
+    try {
+      val store = KeyStore.getInstance(KeyStore.getDefaultType)
+      store.load(in, password)
+      store
+    } finally in.close()
+  }
+}
diff --git
a/stream-tests/src/test/scala/org/apache/pekko/stream/io/TlsGraphStageEdgeCasesSpec.scala
b/stream-tests/src/test/scala/org/apache/pekko/stream/io/TlsGraphStageEdgeCasesSpec.scala
new file mode 100644
index 0000000000..2e9e3216bb
--- /dev/null
+++
b/stream-tests/src/test/scala/org/apache/pekko/stream/io/TlsGraphStageEdgeCasesSpec.scala
@@ -0,0 +1,331 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.pekko.stream.io
+
+import javax.net.ssl.{ SSLContext, SSLEngine, SSLSession }
+
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.concurrent.Await
+import scala.concurrent.duration._
+import scala.util.{ Success, Try }
+
+import com.typesafe.config.ConfigFactory
+
+import org.apache.pekko
+import pekko.NotUsed
+import pekko.stream._
+import pekko.stream.TLSProtocol._
+import pekko.stream.impl.io.TlsGraphStage
+import pekko.stream.scaladsl._
+import pekko.stream.testkit.StreamSpec
+import pekko.testkit.TestDuration
+import pekko.util.ByteString
+
+/**
+ * Edge cases for the [[TlsGraphStage]] path that are awkward to express in the
+ * shared [[TlsGraphStageSpec]] matrix: port naming and production attributes,
+ * fragmented TLS records, plaintext write batching and flushing, repeated
+ * benchmark-shaped round trips, user-side backpressure, materialization
+ * isolation, and empty-source shutdown.
+ */
+class TlsGraphStageEdgeCasesSpec extends StreamSpec(TlsGraphStageEdgeCasesSpec.config) {
+
+  private val sslContext: SSLContext = TlsSpec.initSslContext("TLSv1.2")
+  private val ciphers = TlsSpec.TLS12Ciphers.toArray
+
+  /** Echoes every decrypted plaintext chunk back as an outbound send; a reusable blueprint. */
+  private val echo: Flow[SslTlsInbound, SslTlsOutbound, NotUsed] =
+    Flow[SslTlsInbound].collect { case SessionBytes(_, b) => SendBytes(b) }
+
+  private def engine(role: TLSRole): SSLEngine = {
+    val e = sslContext.createSSLEngine()
+    e.setUseClientMode(role == Client)
+    e.setEnabledCipherSuites(ciphers)
+    e.setEnabledProtocols(Array("TLSv1.2"))
+    e
+  }
+
+  private def graphStageBidi(
+      role: TLSRole,
+      closing: TLSClosing = IgnoreComplete,
+      verifySession: SSLSession => Try[Unit] = _ => Success(()))
+      : BidiFlow[SslTlsOutbound, ByteString, ByteString, SslTlsInbound, NotUsed] =
+    BidiFlow
+      .fromGraph(new TlsGraphStage(() => engine(role), verifySession, closing))
+      .withAttributes(TlsGraphStage.StreamTlsAttributes)
+
+  /**
+   * Drive a roundtrip and stop once the expected number of plain bytes arrives.
+   * Returns an empty ByteString if the stream completes before that point.
+   */
+  private def collectExactly(
+      stream: Source[SslTlsInbound, NotUsed],
+      expectedBytes: Int,
+      timeout: FiniteDuration = 30.seconds): ByteString =
+    Await.result(
+      stream
+        .collect { case SessionBytes(_, b) => b }
+        .scan(ByteString.empty)(_ ++ _)
+        .dropWhile(_.size < expectedBytes)
+        .runWith(Sink.headOption),
+      timeout.dilated).getOrElse(ByteString.empty)
+
+  "TlsGraphStage" must {
+
+    "expose async boundary and input buffer attributes for the production path" in {
+      val attributes = TlsGraphStage.StreamTlsAttributes
+
+      attributes.contains(Attributes.AsyncBoundary) shouldBe true
+      attributes.get[Attributes.InputBuffer] shouldBe Some(
+        Attributes.InputBuffer(TlsGraphStage.InputBufferInitialSize, TlsGraphStage.InputBufferMaxSize))
+    }
+
+    "name ports according to the plaintext and cipher sides" in {
+      val graph = new TlsGraphStage(() => engine(Client), _ => Success(()), IgnoreComplete)
+
+      graph.shape.inlets.map(_.s) shouldBe Seq("TlsGraphStage.plainIn", "TlsGraphStage.cipherIn")
+      graph.shape.outlets.map(_.s) shouldBe Seq("TlsGraphStage.cipherOut", "TlsGraphStage.plainOut")
+    }
+
+    "round-trip a payload through stages built with the production async-boundary attributes" in {
+      val client = graphStageBidi(Client)
+      val server = graphStageBidi(Server)
+      val payload = ByteString("ping")
+      val received = collectExactly(
+        Source.single[SslTlsOutbound](SendBytes(payload))
+          .via(client.atop(server.reversed).join(echo)),
+        payload.size)
+      received shouldBe payload
+    }
+
+    "decode cipher payloads delivered as one-byte fragments (BUFFER_UNDERFLOW recovery)" in {
+      // Fragmenting cipher chunks to one byte each is the most aggressive way to force
+      // BUFFER_UNDERFLOW: every TLS record header (5 bytes) and the entire payload arrive
+      // across many separate `onPush` events. The stage must keep pulling more bytes
+      // instead of deadlocking once it sees its first underflow.
+      val client = graphStageBidi(Client)
+      val server = graphStageBidi(Server)
+      val fragmenter = Flow[ByteString].mapConcat(_.toIndexedSeq.map(b => ByteString(b)))
+      val transport = BidiFlow.fromFlows(fragmenter, fragmenter)
+      val payload = ByteString("0123456789" * 200) // 2000 bytes of varied content
+      val received = collectExactly(
+        Source.single[SslTlsOutbound](SendBytes(payload))
+          .via(client.atop(transport).atop(server.reversed).join(echo)),
+        payload.size,
+        timeout = 60.seconds)
+      received shouldBe payload
+    }
+
+    "batch many small plaintext writes into fewer transport chunks" in {
+      val client = graphStageBidi(Client)
+      val server = graphStageBidi(Server)
+      val clientToServerChunks = new AtomicInteger
+      val serverToClientChunks = new AtomicInteger
+      val transport = BidiFlow.fromFlows(
+        Flow[ByteString].map { bytes =>
+          clientToServerChunks.incrementAndGet()
+          bytes
+        },
+        Flow[ByteString].map { bytes =>
+          serverToClientChunks.incrementAndGet()
+          bytes
+        })
+      val payloads = List.fill(160)(ByteString("a" * 64))
+      val expected = payloads.foldLeft(ByteString.empty)(_ ++ _)
+
+      val received = collectExactly(
+        Source(payloads).map[SslTlsOutbound](SendBytes.apply)
+          .via(client.atop(transport).atop(server.reversed).join(echo)),
+        expected.size,
+        timeout = 60.seconds)
+
+      received shouldBe expected
+      clientToServerChunks.get should be < (payloads.size / 4)
+      serverToClientChunks.get should be < (payloads.size / 4)
+    }
+
+    "flush batched small plaintext writes without waiting for upstream completion" in {
+      val client = graphStageBidi(Client)
+      val server = graphStageBidi(Server)
+      val payloads = Vector.fill(64)(ByteString("b" * 64))
+      val expected = payloads.foldLeft(ByteString.empty)(_ ++ _)
+      // Source.maybe keeps plainIn open so only an eager flush can deliver the bytes.
+      val (tail, done) =
+        Source(payloads)
+          .map[SslTlsOutbound](SendBytes.apply)
+          .concatMat(Source.maybe[SslTlsOutbound])(Keep.right)
+          .via(client.atop(server.reversed).join(echo))
+          .collect { case SessionBytes(_, b) => b }
+          .scan(ByteString.empty)(_ ++ _)
+          .dropWhile(_.size < expected.size)
+          .toMat(Sink.headOption)(Keep.both)
+          .run()
+
+      try Await.result(done, 10.seconds.dilated).getOrElse(ByteString.empty) shouldBe expected
+      finally tail.trySuccess(None)
+    }
+
+    "flush large plaintext writes without waiting for upstream completion" in {
+      val client = graphStageBidi(Client)
+      val server = graphStageBidi(Server)
+      val payload = ByteString(Array.fill[Byte](64 * 1024)('c'.toByte))
+      // Source.maybe keeps plainIn open so only an eager flush can deliver the bytes.
+      val (tail, done) =
+        Source
+          .single[SslTlsOutbound](SendBytes(payload))
+          .concatMat(Source.maybe[SslTlsOutbound])(Keep.right)
+          .via(client.atop(server.reversed).join(echo))
+          .collect { case SessionBytes(_, b) => b }
+          .scan(ByteString.empty)(_ ++ _)
+          .dropWhile(_.size < payload.size)
+          .toMat(Sink.headOption)(Keep.both)
+          .run()
+
+      try Await.result(done, 10.seconds.dilated).getOrElse(ByteString.empty) shouldBe payload
+      finally tail.trySuccess(None)
+    }
+
+    "deliver the benchmark burst without losing bytes" in {
+      val client = graphStageBidi(Client)
+      val server = graphStageBidi(Server)
+      val payload = ByteString(Array.fill[Byte](4096)('a'.toByte))
+      val payloads = Vector.fill(100)(SendBytes(payload))
+      val expectedBytes = payload.size * payloads.size
+
+      val received = collectExactly(
+        Source(payloads)
+          .via(client.atop(server.reversed).join(echo)),
+        expectedBytes,
+        timeout = 60.seconds)
+
+      received.size shouldBe expectedBytes
+    }
+
+    "deliver the cold-handshake benchmark payload without waiting for stream completion" in {
+      val client = graphStageBidi(Client)
+      val server = graphStageBidi(Server)
+      val payload = ByteString(Array.fill[Byte](4096)('a'.toByte))
+
+      val received = collectExactly(
+        Source.single[SslTlsOutbound](SendBytes(payload))
+          .via(client.atop(server.reversed).join(echo)),
+        payload.size,
+        timeout = 10.seconds)
+
+      received shouldBe payload
+    }
+
+    "complete repeated cold-handshake benchmark payloads after downstream early cancellation" in {
+      val payload = ByteString(Array.fill[Byte](4096)('a'.toByte))
+
+      // collectExactly cancels upstream via headOption once the payload arrives.
+      (1 to 1000).foreach { _ =>
+        val client = graphStageBidi(Client)
+        val server = graphStageBidi(Server)
+        val received = collectExactly(
+          Source.single[SslTlsOutbound](SendBytes(payload))
+            .via(client.atop(server.reversed).join(echo)),
+          payload.size,
+          timeout = 10.seconds)
+
+        received shouldBe payload
+      }
+    }
+
+    "complete repeated benchmark bursts after downstream early cancellation" in {
+      val payload = ByteString(Array.fill[Byte](4096)('a'.toByte))
+      val payloads = Vector.fill(100)(SendBytes(payload))
+      val expectedBytes = payload.size * payloads.size
+
+      (1 to 50).foreach { _ =>
+        val client = graphStageBidi(Client)
+        val server = graphStageBidi(Server)
+        val received = collectExactly(
+          Source(payloads)
+            .via(client.atop(server.reversed).join(echo)),
+          expectedBytes,
+          timeout = 10.seconds)
+
+        received.size shouldBe expectedBytes
+      }
+    }
+
+    "deliver all bytes when plainOut downstream is slow (backpressure)" in {
+      val client = graphStageBidi(Client)
+      val server = graphStageBidi(Server)
+      // 10x ~16 KiB blocks crosses several TLS records and sustained plainOut
+      // backpressure forces the stage to buffer/wait without losing data.
+      val blocks = (0 until 10).map(_ => ByteString("a" * 16384))
+      val totalSize = blocks.map(_.length).sum
+      val received = Await.result(
+        Source(blocks)
+          .map[SslTlsOutbound](b => SendBytes(b))
+          .via(client.atop(server.reversed).join(echo))
+          .collect { case SessionBytes(_, b) => b }
+          .throttle(1, 10.millis)
+          .scan(ByteString.empty)(_ ++ _)
+          .dropWhile(_.size < totalSize)
+          .runWith(Sink.headOption),
+        60.seconds.dilated).getOrElse(ByteString.empty)
+      received.length shouldBe totalSize
+      received shouldBe blocks.foldLeft(ByteString.empty)(_ ++ _)
+    }
+
+    "give each materialization its own SSLEngine (no shared state)" in {
+      // Reusing a closed SSLEngine would make the second or third materialization fail.
+      def round(payload: ByteString): ByteString = {
+        val client = graphStageBidi(Client)
+        val server = graphStageBidi(Server)
+        collectExactly(
+          Source.single[SslTlsOutbound](SendBytes(payload))
+            .via(client.atop(server.reversed).join(echo)),
+          payload.size)
+      }
+
+      round(ByteString("first")) shouldBe ByteString("first")
+      round(ByteString("second")) shouldBe ByteString("second")
+      round(ByteString("third")) shouldBe ByteString("third")
+    }
+
+    "complete cleanly when source is empty (CompletedImmediately variant)" in {
+      // Both sides use EagerClose: an empty source completes immediately, the
+      // close cascade fires before any payload travels, and both stages must
+      // finalize without errors. No bytes should be delivered.
+      val client = graphStageBidi(Client, EagerClose)
+      val server = graphStageBidi(Server, EagerClose)
+      val received = Await.result(
+        Source.empty[SslTlsOutbound]
+          .via(client.atop(server.reversed).join(echo))
+          .collect { case SessionBytes(_, b) => b }
+          .runFold(ByteString.empty)(_ ++ _),
+        30.seconds.dilated)
+      received shouldBe ByteString.empty
+    }
+  }
+}
+
+object TlsGraphStageEdgeCasesSpec {
+
+  /** Dispatcher/mailbox tuning applied on top of the shared TLS test overrides. */
+  private val dispatcherTuning = ConfigFactory.parseString("""
+      pekko.actor.default-dispatcher.throughput = 1024
+      pekko.actor.default-mailbox.mailbox-type = "org.apache.pekko.dispatch.SingleConsumerOnlyUnboundedMailbox"
+      """)
+
+  /** Spec configuration: `dispatcherTuning` with `TlsSpec.configOverrides` as fallback. */
+  val config = dispatcherTuning.withFallback(ConfigFactory.parseString(TlsSpec.configOverrides))
+}
diff --git
a/stream-tests/src/test/scala/org/apache/pekko/stream/io/TlsGraphStageIsolatedSpec.scala
b/stream-tests/src/test/scala/org/apache/pekko/stream/io/TlsGraphStageIsolatedSpec.scala
new file mode 100644
index 0000000000..53390bef9f
--- /dev/null
+++
b/stream-tests/src/test/scala/org/apache/pekko/stream/io/TlsGraphStageIsolatedSpec.scala
@@ -0,0 +1,238 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.pekko.stream.io
+
+import javax.net.ssl._
+
+import scala.concurrent.Await
+import scala.concurrent.duration._
+import scala.util.Success
+
+import org.apache.pekko
+import pekko.NotUsed
+import pekko.stream._
+import pekko.stream.TLSProtocol._
+import pekko.stream.impl.io.TlsGraphStage
+import pekko.stream.scaladsl._
+import pekko.stream.scaladsl.GraphDSL.Implicits._
+import pekko.stream.testkit.{ StreamSpec, TestPublisher }
+import pekko.testkit.TestDuration
+import pekko.testkit.WithLogCapturing
+import pekko.util.ByteString
+
+/**
+ * Focused [[TlsGraphStage]] tests for early failures, empty inputs, large
+ * fragmented inputs, and TLS 1.2 renegotiation.
+ *
+ * All `Await` timeouts are dilated by `pekko.test.timefactor` so the spec scales on slow
+ * CI machines instead of failing on fixed wall-clock limits.
+ */
+class TlsGraphStageIsolatedSpec extends StreamSpec(TlsSpec.configOverrides) with WithLogCapturing {
+
+  import TlsSpec._
+
+  /** Constructs a [[BidiFlow]] backed by a single [[TlsGraphStage]] instance. */
+  private def stageFlow(
+      ctx: SSLContext,
+      ciphers: Set[String],
+      clientMode: Boolean,
+      closing: TLSClosing,
+      asyncBoundary: Boolean = true): BidiFlow[SslTlsOutbound, ByteString, ByteString, SslTlsInbound, NotUsed] = {
+    val stage = new TlsGraphStage(
+      () => {
+        val engine = ctx.createSSLEngine()
+        engine.setUseClientMode(clientMode)
+        if (ciphers.nonEmpty) engine.setEnabledCipherSuites(ciphers.toArray)
+        engine
+      },
+      _ => Success(()),
+      closing)
+    val flow = BidiFlow.fromGraph(stage)
+    if (asyncBoundary) flow.addAttributes(TlsGraphStage.StreamTlsAttributes) else flow
+  }
+
+  /** Connects a client [[TlsGraphStage]] to a server [[TlsGraphStage]] in memory. */
+  private def loopbackFlow(
+      ctx: SSLContext,
+      ciphers: Set[String],
+      clientClosing: TLSClosing,
+      serverClosing: TLSClosing,
+      flow: Flow[SslTlsInbound, SslTlsOutbound, NotUsed]): Flow[SslTlsOutbound, SslTlsInbound, NotUsed] = {
+    val client = stageFlow(ctx, ciphers, clientMode = true, clientClosing)
+    val server = stageFlow(ctx, ciphers, clientMode = false, serverClosing)
+    client.atop(server.reversed).join(flow)
+  }
+
+  private val echoApplicationFlow = Flow[SslTlsInbound].collect {
+    case SessionBytes(_, b) => SendBytes(b)
+  }
+
+  /**
+   * Sends `inputs` from a client stage to an echoing server stage (TLS 1.2 ciphers) and returns
+   * the bytes the client receives back.
+   *
+   * With a non-empty payload the result is the first accumulated output that reaches the
+   * expected size. With an empty payload the output is collected until the stream completes,
+   * so at least one side must close (e.g. `EagerClose`); with `IgnoreComplete` on both sides
+   * nothing closes the session and the call runs into `timeout`.
+   *
+   * @param timeout undilated wait limit; it is scaled by the test time factor
+   */
+  private def roundTrip(
+      ctx: SSLContext,
+      inputs: Seq[ByteString],
+      clientClosing: TLSClosing = IgnoreComplete,
+      serverClosing: TLSClosing = IgnoreComplete,
+      timeout: FiniteDuration = 20.seconds): ByteString = {
+    val expectedBytes = inputs.foldLeft(0)(_ + _.size)
+    val received =
+      Source(inputs.map(SendBytes.apply).toList)
+        .via(loopbackFlow(ctx, TLS12Ciphers, clientClosing, serverClosing, echoApplicationFlow))
+        .collect { case SessionBytes(_, b) => b }
+
+    if (expectedBytes == 0) {
+      val outputs = Await.result(received.runWith(Sink.seq), timeout.dilated)
+      outputs.foldLeft(ByteString.empty)(_ ++ _)
+    } else {
+      Await.result(
+        received
+          .scan(ByteString.empty)(_ ++ _)
+          .drop(1)
+          .filter(_.size >= expectedBytes)
+          .runWith(Sink.head),
+        timeout.dilated)
+    }
+  }
+
+  "TlsGraphStage isolated cases" must {
+
+    "reliably cancel subscriptions when cipherIn (TransportIn) fails early" in {
+      val ex = new Exception("transport-in-failure")
+      // asyncBoundary = false: this test exercises stage-level error propagation in
+      // isolation and does not need a separate async island.
+      val client =
+        stageFlow(initSslContext("TLSv1.2"), TLS12Ciphers, clientMode = true, EagerClose, asyncBoundary = false)
+
+      val (sub, out1, out2) =
+        RunnableGraph
+          .fromGraph(
+            GraphDSL.createGraph(
+              Source.asSubscriber[SslTlsOutbound],
+              Sink.head[ByteString],
+              Sink.head[SslTlsInbound])((_, _, _)) { implicit b => (s, o1, o2) =>
+              val tls = b.add(client)
+              s ~> tls.in1
+              tls.out1 ~> o1
+              o2 <~ tls.out2
+              tls.in2 <~ Source.failed(ex)
+              ClosedShape
+            })
+          .run()
+
+      the[Exception] thrownBy Await.result(out1, 3.seconds.dilated) should be(ex)
+      the[Exception] thrownBy Await.result(out2, 3.seconds.dilated) should be(ex)
+      val pub = TestPublisher.probe()
+      pub.subscribe(sub)
+      pub.expectSubscription().expectCancellation()
+    }
+
+    "reliably cancel subscriptions when plainIn (UserIn) fails early" in {
+      val ex = new Exception("user-in-failure")
+      // asyncBoundary = false: this test exercises stage-level error propagation in
+      // isolation and does not need a separate async island.
+      val client =
+        stageFlow(initSslContext("TLSv1.2"), TLS12Ciphers, clientMode = true, EagerClose, asyncBoundary = false)
+
+      val (sub, out1, out2) =
+        RunnableGraph
+          .fromGraph(
+            GraphDSL.createGraph(
+              Source.asSubscriber[ByteString],
+              Sink.head[ByteString],
+              Sink.head[SslTlsInbound])((_, _, _)) { implicit b => (s, o1, o2) =>
+              val tls = b.add(client)
+              Source.failed[SslTlsOutbound](ex) ~> tls.in1
+              tls.out1 ~> o1
+              o2 <~ tls.out2
+              tls.in2 <~ s
+              ClosedShape
+            })
+          .run()
+
+      the[Exception] thrownBy Await.result(out1, 3.seconds.dilated) should be(ex)
+      the[Exception] thrownBy Await.result(out2, 3.seconds.dilated) should be(ex)
+      val pub = TestPublisher.probe()
+      pub.subscribe(sub)
+      pub.expectSubscription().expectCancellation()
+    }
+
+    "round-trip alternating empty and non-empty ByteString inputs exactly" in {
+      val input = List(
+        ByteString.empty,
+        ByteString("A"),
+        ByteString.empty,
+        ByteString("BC"),
+        ByteString.empty,
+        ByteString("DEF"),
+        ByteString.empty)
+      val expected = input.foldLeft(ByteString.empty)(_ ++ _)
+
+      roundTrip(initSslContext("TLSv1.2"), input) shouldEqual expected
+    }
+
+    "round-trip a fragmented large payload exactly" in {
+      val payloadSize = (64 * 1024) + 1
+      val payload = ByteString(Array.tabulate[Byte](payloadSize)(i => (i % 251).toByte))
+
+      roundTrip(initSslContext("TLSv1.2"), List(payload), timeout = 30.seconds) shouldEqual payload
+    }
+
+    "complete without payload when both sides are EagerClose" in {
+      roundTrip(
+        initSslContext("TLSv1.2"),
+        Nil,
+        clientClosing = EagerClose,
+        serverClosing = EagerClose,
+        timeout = 10.seconds) shouldEqual ByteString.empty
+    }
+
+    "complete when plainIn finishes immediately under EagerClose/IgnoreComplete" in {
+      roundTrip(
+        initSslContext("TLSv1.2"),
+        Nil,
+        clientClosing = EagerClose,
+        serverClosing = IgnoreComplete,
+        timeout = 10.seconds) shouldEqual ByteString.empty
+    }
+
+    "pass data before and after NegotiateNewSession (TLS 1.2 renegotiation)" in {
+      val renegotiationContext = initSslContext("TLSv1.2")
+      val markNewSession = Flow[SslTlsInbound].map {
+        var session: SSLSession = null
+
+        {
+          case SessionTruncated => SendBytes(ByteString("TRUNCATED"))
+          case SessionBytes(s, b) if session eq null =>
+            session = s
+            SendBytes(b)
+          case SessionBytes(s, b) if s != session =>
+            session = s
+            SendBytes(ByteString("NEWSESSION") ++ b)
+          case SessionBytes(_, b) => SendBytes(b)
+        }
+      }
+
+      val expected = ByteString("helloNEWSESSIONworld")
+      // Both sides IgnoreComplete: the stream never completes, so takeWithin bounds the
+      // collection window and the (dilated) Await must outlast it.
+      val outputs = Await.result(
+        Source(List[SslTlsOutbound](SendBytes(ByteString("hello")), NegotiateNewSession, SendBytes(ByteString("world"))))
+          .via(loopbackFlow(renegotiationContext, TLS12Ciphers, IgnoreComplete, IgnoreComplete, markNewSession))
+          .collect { case SessionBytes(_, b) => b }
+          .takeWithin(10.seconds)
+          .runWith(Sink.seq),
+        20.seconds.dilated)
+
+      outputs.foldLeft(ByteString.empty)(_ ++ _) shouldEqual expected
+    }
+  }
+}
diff --git
a/stream-tests/src/test/scala/org/apache/pekko/stream/io/TlsSpec.scala
b/stream-tests/src/test/scala/org/apache/pekko/stream/io/TlsSpec.scala
index 10f3655f81..32c152bf13 100644
--- a/stream-tests/src/test/scala/org/apache/pekko/stream/io/TlsSpec.scala
+++ b/stream-tests/src/test/scala/org/apache/pekko/stream/io/TlsSpec.scala
@@ -23,7 +23,7 @@ import scala.collection.immutable
import scala.concurrent.Await
import scala.concurrent.Future
import scala.concurrent.duration._
-import scala.util.Random
+import scala.util.{ Random, Success }
import org.apache.pekko
import pekko.NotUsed
@@ -31,6 +31,7 @@ import pekko.pattern.{ after => later }
import pekko.stream._
import pekko.stream.TLSProtocol._
import pekko.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
+import pekko.stream.impl.io.{ TlsGraphStage, TlsModule }
import pekko.stream.scaladsl._
import pekko.stream.stage._
import pekko.stream.testkit._
@@ -111,12 +112,34 @@ object TlsSpec {
"""
}
-class TlsSpec extends StreamSpec(TlsSpec.configOverrides) with
WithLogCapturing {
+class TlsSpec extends AbstractTlsSpec(useLegacyActor = true) // shared TLS matrix on the legacy TLSActor (TlsModule) path
+class TlsGraphStageSpec extends AbstractTlsSpec(useLegacyActor = false) // same matrix on the TlsGraphStage path
+
+abstract class AbstractTlsSpec(useLegacyActor: Boolean)
+ extends StreamSpec(TlsSpec.configOverrides)
+ with WithLogCapturing {
import GraphDSL.Implicits._
import TlsSpec._
import system.dispatcher
- "SslTls" must {
+  /**
+   * Builds a TLS [[BidiFlow]] directly instead of going through the global TLS engine setting,
+   * so each concrete subclass exercises either the legacy actor path or the GraphStage path,
+   * and both can run in the same JVM.
+   *
+   * The GraphStage variant adds `TlsGraphStage.StreamTlsAttributes` on top of the stage's own
+   * initial attributes (its `TlsGraphStage` name) instead of replacing them, the same way
+   * `TlsGraphStageIsolatedSpec` builds the stage.
+   *
+   * @param createSSLEngine factory invoked once per materialization
+   * @param verifySession   session check, accepting every session by default
+   * @param closing         close behaviour for both directions
+   */
+  protected def tlsBidi(
+      createSSLEngine: () => SSLEngine,
+      verifySession: SSLSession => scala.util.Try[Unit] = _ => Success(()),
+      closing: TLSClosing): BidiFlow[SslTlsOutbound, ByteString, ByteString, SslTlsInbound, NotUsed] =
+    if (useLegacyActor)
+      BidiFlow.fromGraph(TlsModule(Attributes.none, () => createSSLEngine(), verifySession, closing))
+    else
+      BidiFlow
+        .fromGraph(new TlsGraphStage(createSSLEngine, verifySession, closing))
+        .addAttributes(TlsGraphStage.StreamTlsAttributes)
+
+ s"SslTls (${if (useLegacyActor) "legacy actor path" else "GraphStage
path"})" must {
"work for TLSv1.2" must { workFor("TLSv1.2", TLS12Ciphers) }
"work for TLSv1.3" must { workFor("TLSv1.3", TLS13Ciphers) }
@@ -163,13 +186,13 @@ class TlsSpec extends StreamSpec(TlsSpec.configOverrides)
with WithLogCapturing
}
def clientTls(closing: TLSClosing) =
- TLS(() => createSSLEngine(sslContext, Client), closing)
+ tlsBidi(() => createSSLEngine(sslContext, Client), closing = closing)
def badClientTls(closing: TLSClosing) =
- TLS(() => createSSLEngine(initWithTrust("/badtruststore", protocol),
Client), closing)
+ tlsBidi(() => createSSLEngine(initWithTrust("/badtruststore",
protocol), Client), closing = closing)
def serverTls(closing: TLSClosing) =
- TLS(() => createSSLEngine(sslContext, Server), closing)
+ tlsBidi(() => createSSLEngine(sslContext, Server), closing = closing)
trait Named {
def name: String =
@@ -567,9 +590,9 @@ class TlsSpec extends StreamSpec(TlsSpec.configOverrides)
with WithLogCapturing
case SessionTruncated => SendBytes(ByteString.empty)
case SessionBytes(_, b) => SendBytes(b)
}
- val clientTls = TLS(
+ val clientTls = tlsBidi(
() => createSSLEngine2(sslContext, Client, hostnameVerification =
true, hostInfo = Some((hostName, 80))),
- EagerClose)
+ closing = EagerClose)
val flow = clientTls.atop(serverTls(EagerClose).reversed).join(rhs)
diff --git
a/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/TlsEngineSelectionSpec.scala
b/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/TlsEngineSelectionSpec.scala
new file mode 100644
index 0000000000..bf55ae6e6a
--- /dev/null
+++
b/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/TlsEngineSelectionSpec.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.pekko.stream.scaladsl
+
+import org.apache.pekko.testkit.PekkoSpec
+
+class TlsEngineSelectionSpec extends PekkoSpec {
+
+  import TLS.{ configuredEngineName, GraphStageEngineName, LegacyActorEngineName }
+
+  "TLS engine selection" should {
+
+    "ignore blank system properties and use the configured value" in {
+      // (blank system property value, configured engine) pairs: the configured engine must win
+      val blankCases = List("" -> GraphStageEngineName, " " -> LegacyActorEngineName)
+      blankCases.foreach {
+        case (blankProperty, configured) =>
+          configuredEngineName(Some(blankProperty), configured) shouldBe configured
+      }
+    }
+
+    "let non-blank system properties override the configured value" in {
+      val selected = configuredEngineName(Some(GraphStageEngineName), LegacyActorEngineName)
+      selected shouldBe GraphStageEngineName
+    }
+  }
+}
diff --git a/stream/src/main/resources/reference.conf
b/stream/src/main/resources/reference.conf
index bba71993ec..f4e4bd3193 100644
--- a/stream/src/main/resources/reference.conf
+++ b/stream/src/main/resources/reference.conf
@@ -140,6 +140,18 @@ pekko {
# Time to wait for async materializer creation before throwing an
exception
creation-timeout = 20 seconds
+ # Stream TLS engine.
+ #
+ # Valid values:
+ # - "legacy-actor": route TLS through the existing actor-backed engine (the default).
+ # - "graph-stage": route TLS through the GraphStage engine, which adapts the
+ #   same TLS pump state machine to the GraphStage lifecycle and ports.
+ #
+ # This setting is read only once, when `org.apache.pekko.stream.scaladsl.TLS` is
+ # initialized, so changing it later in the same JVM has no effect. The GraphStage
+ # engine remains opt-in until it has enough production and benchmark evidence to
+ # become the default.
+ tls.engine = "legacy-actor"
+
//#stream-ref
# configure defaults for SourceRef and SinkRef
stream-ref {
diff --git a/stream/src/main/scala/org/apache/pekko/stream/impl/Stages.scala
b/stream/src/main/scala/org/apache/pekko/stream/impl/Stages.scala
index 6de53600bd..78584949e4 100755
--- a/stream/src/main/scala/org/apache/pekko/stream/impl/Stages.scala
+++ b/stream/src/main/scala/org/apache/pekko/stream/impl/Stages.scala
@@ -179,6 +179,7 @@ import pekko.stream.Attributes._
val inputBoundary = name("input-boundary")
val outputBoundary = name("output-boundary")
val dropRepeated = name("dropRepeated")
+ val tlsGraphStage = name("TlsGraphStage") // initial attributes of TlsGraphStage
}
}
diff --git
a/stream/src/main/scala/org/apache/pekko/stream/impl/io/TLSActor.scala
b/stream/src/main/scala/org/apache/pekko/stream/impl/io/TLSActor.scala
index 02cdeda571..de8ffca080 100644
--- a/stream/src/main/scala/org/apache/pekko/stream/impl/io/TLSActor.scala
+++ b/stream/src/main/scala/org/apache/pekko/stream/impl/io/TLSActor.scala
@@ -107,7 +107,7 @@ import pekko.util.ByteString
* with these characteristics, use `prepare()`.
*/
def chopInto(b: ByteBuffer): Unit = {
- b.compact()
+ TlsEngineHelpers.prepareForAppend(b)
if (buffer.isEmpty) {
buffer = inputBunch.dequeue(idx) match {
// this class handles both UserIn and TransportIn
@@ -143,10 +143,7 @@ import pekko.util.ByteString
/**
* Prepare a fresh ByteBuffer for receiving a chop of data.
*/
- def prepare(b: ByteBuffer): Unit = {
- b.clear()
- b.limit(0)
- }
+ def prepare(b: ByteBuffer): Unit = TlsEngineHelpers.emptyReadBuffer(b) // same position = limit = 0 reset, shared with TlsGraphStage
}
// These are Netty's default values
@@ -283,7 +280,8 @@ import pekko.util.ByteString
}
 def completeOrFlush(): Unit =
- if (engine.isOutboundDone || (engine.isInboundDone &&
userInChoppingBlock.isEmpty)) nextPhase(completedPhase)
+ if (engine.isOutboundDone || (engine.isInboundDone &&
userInChoppingBlock.isEmpty && !userInBuffer.hasRemaining))
+ nextPhase(completedPhase) // bytes already chopped into userInBuffer now also block completion
 else nextPhase(flushingOutbound)
private def doInbound(isOutboundClosed: Boolean, inboundState:
TransferState): Boolean =
@@ -308,20 +306,22 @@ import pekko.util.ByteString
} else if (inboundState.isReady) {
transportInChoppingBlock.chopInto(transportInBuffer)
try {
- doUnwrap(ignoreOutput = false)
+ doUnwrap(ignoreOutput = (inboundState eq inboundHalfClosed) ||
outputBunch.isCancelled(UserOut))
true
} catch {
case ex: SSLException =>
if (tracing) log.debug(s"SSLException during doUnwrap: $ex")
fail(ex, closeTransport = false)
- engine.closeInbound() // we don't need to add lastHandshakeStatus
check here because
+ try engine.closeInbound()
+ catch { case _: SSLException => () }
completeOrFlush() // it doesn't make any sense to write anything to
the network anymore
false
}
} else true
private def doOutbound(isInboundClosed: Boolean): Unit =
- if (inputBunch.isDepleted(UserIn) && userInChoppingBlock.isEmpty &&
mayCloseOutbound) {
+ if (inputBunch.isDepleted(UserIn) && userInChoppingBlock.isEmpty &&
!userInBuffer.hasRemaining &&
+ mayCloseOutbound) {
if (!isInboundClosed && closing.ignoreComplete) {
if (tracing) log.debug("ignoring closeOutbound")
} else {
@@ -444,7 +444,7 @@ import pekko.util.ByteString
transportInBuffer.position() == oldInPosition =>
throw new IllegalStateException("SSLEngine trying to loop
NEED_UNWRAP without producing output")
case _ =>
- if (transportInBuffer.hasRemaining) doUnwrap(ignoreOutput = false)
+ if (transportInBuffer.hasRemaining) doUnwrap(ignoreOutput)
else flushToUser()
}
case CLOSED =>
@@ -459,18 +459,12 @@ import pekko.util.ByteString
}
}
- @tailrec
 private def runDelegatedTasks(): Unit = {
- val task = engine.getDelegatedTask
- if (task ne null) {
- if (tracing) log.debug("running task")
- task.run()
- runDelegatedTasks()
- } else {
- val st = lastHandshakeStatus
- lastHandshakeStatus = engine.getHandshakeStatus
- if (tracing && st != lastHandshakeStatus) log.debug(s"handshake status
after tasks: $lastHandshakeStatus")
- }
+    // Run every pending delegated task inline, then refresh the cached handshake status.
+    val statusBefore = lastHandshakeStatus
+    val taskCount = TlsEngineHelpers.runDelegatedTasks(engine)
+    lastHandshakeStatus = engine.getHandshakeStatus
+    if (tracing) {
+      if (taskCount > 0) log.debug(s"ran $taskCount delegated TLS task(s)")
+      if (statusBefore != lastHandshakeStatus) log.debug(s"handshake status after tasks: $lastHandshakeStatus")
+    }
 }
private def handshakeFinished(): Unit = {
diff --git
a/stream/src/main/scala/org/apache/pekko/stream/impl/io/TlsEngineHelpers.scala
b/stream/src/main/scala/org/apache/pekko/stream/impl/io/TlsEngineHelpers.scala
new file mode 100644
index 0000000000..c8fe38b6c1
--- /dev/null
+++
b/stream/src/main/scala/org/apache/pekko/stream/impl/io/TlsEngineHelpers.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.pekko.stream.impl.io
+
+import java.nio.ByteBuffer
+import javax.net.ssl.SSLEngine
+
+import org.apache.pekko.annotation.InternalApi
+
+/**
+ * INTERNAL API.
+ *
+ * Small allocation-free helpers shared by [[TLSActor]] and [[TlsGraphStage]] for handling the
+ * SSLEngine's ByteBuffers and delegated tasks.
+ */
+@InternalApi private[stream] object TlsEngineHelpers {
+
+  // TLS record header: content type (1 byte), protocol version (2 bytes), length (2 bytes).
+  private final val TlsRecordHeaderSize = 5
+  // SSL 3.0 and every TLS 1.x version put major version 3 on the record layer.
+  private final val TlsMajorVersion = 3
+  // Content types change_cipher_spec(20) through application_data(23).
+  private final val TlsMinContentType = 20
+  private final val TlsMaxContentType = 23
+
+  /** Makes `buffer` read as empty: position and limit are both reset to 0. */
+  @inline
+  def emptyReadBuffer(buffer: ByteBuffer): Unit =
+    buffer.clear().flip()
+
+  /**
+   * Switches a read-mode `buffer` to write mode so more bytes can be appended after the unread
+   * ones: unread bytes are compacted to the front, or the buffer is simply cleared when none are left.
+   */
+  @inline
+  def prepareForAppend(buffer: ByteBuffer): Unit =
+    if (!buffer.hasRemaining) buffer.clear()
+    else buffer.compact()
+
+  /**
+   * Runs all of `engine`'s pending delegated tasks on the calling thread.
+   *
+   * @return how many tasks were run
+   */
+  @inline
+  def runDelegatedTasks(engine: SSLEngine): Int = {
+    var ran = 0
+    var pending: Runnable = engine.getDelegatedTask
+    while (pending ne null) {
+      pending.run()
+      ran += 1
+      pending = engine.getDelegatedTask
+    }
+    ran
+  }
+
+  /**
+   * Decides whether the read-mode `buffer` is worth handing to `SSLEngine.unwrap`.
+   *
+   * Returns `false` only when fewer than 5 bytes are buffered, or when the bytes start with a
+   * recognizable TLS record header whose record is not yet fully buffered. Anything this check
+   * cannot judge (an unrecognized header, or a record longer than the buffer's capacity)
+   * yields `true`, so the engine itself can handle or reject it.
+   */
+  @inline
+  def hasCompleteTlsRecord(buffer: ByteBuffer): Boolean = {
+    val available = buffer.remaining()
+    available >= TlsRecordHeaderSize && {
+      val start = buffer.position()
+      val contentType = buffer.get(start) & 0xFF
+      val majorVersion = buffer.get(start + 1) & 0xFF
+      val recognizedHeader =
+        majorVersion == TlsMajorVersion && contentType >= TlsMinContentType && contentType <= TlsMaxContentType
+      !recognizedHeader || {
+        val payloadLength = ((buffer.get(start + 3) & 0xFF) << 8) | (buffer.get(start + 4) & 0xFF)
+        val recordLength = TlsRecordHeaderSize + payloadLength
+        recordLength > buffer.capacity || available >= recordLength
+      }
+    }
+  }
+}
diff --git
a/stream/src/main/scala/org/apache/pekko/stream/impl/io/TlsGraphStage.scala
b/stream/src/main/scala/org/apache/pekko/stream/impl/io/TlsGraphStage.scala
new file mode 100644
index 0000000000..f700db110b
--- /dev/null
+++ b/stream/src/main/scala/org/apache/pekko/stream/impl/io/TlsGraphStage.scala
@@ -0,0 +1,1103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.pekko.stream.impl.io
+
+import java.nio.ByteBuffer
+import javax.net.ssl._
+import javax.net.ssl.SSLEngineResult.HandshakeStatus
+import javax.net.ssl.SSLEngineResult.HandshakeStatus._
+import javax.net.ssl.SSLEngineResult.Status._
+
+import scala.annotation.{ switch, tailrec }
+import scala.concurrent.duration.Duration
+import scala.util.{ Failure, Success, Try }
+import scala.util.control.NonFatal
+
+import org.apache.pekko
+import pekko.annotation.InternalApi
+import pekko.io.{ BufferPool, Tcp }
+import pekko.stream._
+import pekko.stream.TLSProtocol._
+import pekko.stream.impl.Stages.DefaultAttributes
+import pekko.stream.stage.{ GraphStage, InHandler, OutHandler,
TimerGraphStageLogic }
+import pekko.util.ByteString
+
+/**
+ * INTERNAL API.
+ *
+ * GraphStage adapter for Pekko's existing TLS pump state machine. The phases
+ * intentionally mirror [[TLSActor]] so the new implementation keeps the legacy
+ * close, renegotiation, and SSLEngine workaround semantics while avoiding the
+ * actor/FanoutProcessor substrate.
+ */
+@InternalApi private[stream] final class TlsGraphStage(
+ createSSLEngine: () => SSLEngine,
+ verifySession: SSLSession => Try[Unit],
+ closing: TLSClosing)
+ extends GraphStage[BidiShape[SslTlsOutbound, ByteString, ByteString,
SslTlsInbound]] {
+
+ private val plainIn = Inlet[SslTlsOutbound]("TlsGraphStage.plainIn") // user -> TLS (UserIn in TLSActor terms)
+ private val plainOut = Outlet[SslTlsInbound]("TlsGraphStage.plainOut") // TLS -> user (UserOut): SessionBytes / SessionTruncated
+ private val cipherIn = Inlet[ByteString]("TlsGraphStage.cipherIn") // network -> TLS (TransportIn)
+ private val cipherOut = Outlet[ByteString]("TlsGraphStage.cipherOut") // TLS -> network
+
+ override val shape: BidiShape[SslTlsOutbound, ByteString, ByteString,
SslTlsInbound] =
+ BidiShape(plainIn, cipherOut, cipherIn, plainOut) // in1 = plainIn, out1 = cipherOut, in2 = cipherIn, out2 = plainOut
+
+ override def initialAttributes: Attributes = DefaultAttributes.tlsGraphStage // names the stage "TlsGraphStage"
+
+ override def createLogic(inheritedAttributes: Attributes):
TimerGraphStageLogic =
+ new TimerGraphStageLogic(shape) { logic =>
+ import TlsGraphStage._
+
+ private var engine: SSLEngine = _ // created in preStart via createSSLEngine
+ private var currentSession: SSLSession = _
+ private var lastHandshakeStatus: HandshakeStatus = NOT_HANDSHAKING // cached; refreshed after engine calls
+ private var inlineHandshakeReentries = 0
+ private var phase: Byte = BidirectionalPhase.toByte // current pump phase, one of the *Phase constants
+ private var deferredTransportFlushes = 0
+ private var stateBits = 0 // bit set of the *Flag constants (e.g. StageClosedFlag)
+
+ private var transportOutBuffer: ByteBuffer = _ // all buffers below are assigned in allocateBuffers
+ private var userOutBuffer: ByteBuffer = _
+ private var transportInBuffer: ByteBuffer = _
+ private var transportUnreadBuffer: ByteBuffer = _
+ private var userInBuffer: ByteBuffer = _
+ private var bufferPool: BufferPool = _ // Tcp(system).bufferPool, set in preStart
+ private var pooledBufferSize = 0 // Tcp Settings.DirectBufferSize, set in preStart
+ private var transportPacketSize = 0
+ private var applicationRecordSize = 0
+ private var maxUserInBufferSize = 0
+ private var maxUserOutBufferSize = 0
+
+ private val userInput = new UserInputSlot(plainIn) // pump-side view of plainIn (UserIn)
+ private val transportInput = new TransportInputSlot(cipherIn) // pump-side view of cipherIn (TransportIn)
+ private val userOutput =
+ new OutputSlot[SslTlsInbound](plainOut, UserOutputCancelledFlag,
UserOutputCompletedFlag, UserOutputErroredFlag)
+ private val transportOutput =
+ new OutputSlot[ByteString](
+ cipherOut,
+ TransportOutputCancelledFlag,
+ TransportOutputCompletedFlag,
+ TransportOutputErroredFlag,
+ deferInitialPumpOutput = true) // pending output is pushed when InitialOutputFlushTimer fires
+
+ setHandler(plainIn, userInput) // each slot also serves as its port's handler
+ setHandler(cipherIn, transportInput)
+ setHandler(cipherOut, transportOutput)
+ setHandler(plainOut, userOutput)
+
+ override def preStart(): Unit = // acquire the pool, create engine and buffers, begin the handshake, then pull inputs
+ try {
+ val tcp = Tcp(materializer.system)
+ bufferPool = tcp.bufferPool // transport buffers reuse Pekko TCP's buffer pool
+ pooledBufferSize = tcp.Settings.DirectBufferSize
+
+ engine = createSSLEngine()
+ currentSession = engine.getSession
+ allocateBuffers(currentSession) // buffer sizes come from the initial session
+ engine.beginHandshake()
+ lastHandshakeStatus = engine.getHandshakeStatus
+ nextPhase(BidirectionalPhase)
+ pullInputs()
+ } catch {
+ case NonFatal(ex) => failTls(ex) // any setup failure is routed through failTls
+ }
+
+    /**
+     * Handles the stage's timers. Each branch only runs while its guard still holds, so a timer
+     * that fires after its work is no longer pending (or after the stage closed) does nothing.
+     */
+    override protected def onTimer(timerKey: Any): Unit = {
+      def flagSet(flag: Int): Boolean = (stateBits & flag) != 0
+
+      if (timerKey == InitialOutputFlushTimer && !flagSet(StageClosedFlag)) {
+        // push the output the transport slot held back (deferInitialPumpOutput = true)
+        transportOutput.pushPending()
+        runPump()
+      } else if (timerKey == TransportFlushTimer && flagSet(TransportFlushPendingFlag)) {
+        stateBits &= ~TransportFlushPendingFlag // consume the pending-flush request
+        flushToTransport()
+        runPump()
+      } else if (timerKey == UserInputBatchTimer && flagSet(UserInputBatchPendingFlag)) {
+        stateBits &= ~UserInputBatchPendingFlag
+        runPump()
+      }
+    }
+
+ private def allocateBuffers(session: SSLSession): Unit = { // derive all buffer sizes from the session and allocate the buffers
+ val packetSize = math.max(1, session.getPacketBufferSize) // guard against non-positive sizes
+ val applicationSize = math.max(1, session.getApplicationBufferSize)
+ val applicationBatchSize = applicationBufferSize(applicationSize,
packetSize)
+ transportPacketSize = packetSize
+ applicationRecordSize = applicationSize
+ maxUserInBufferSize = applicationBatchSize // NOTE(review): upper bounds; buffers below start at one record, presumably grown elsewhere
+ maxUserOutBufferSize = applicationBatchSize
+
+ transportOutBuffer = acquireTransportBuffer(packetSize) // transport side: acquireTransportBuffer (presumably backed by bufferPool)
+ transportInBuffer = acquireTransportBuffer(packetSize)
+ transportUnreadBuffer = acquireTransportBuffer(packetSize)
+ userInBuffer = ByteBuffer.allocate(applicationSize) // application side: plain heap buffers
+ userOutBuffer = ByteBuffer.allocate(applicationSize)
+
+ TlsEngineHelpers.emptyReadBuffer(userInBuffer) // read-side buffers start out empty
+ TlsEngineHelpers.emptyReadBuffer(transportInBuffer)
+ TlsEngineHelpers.emptyReadBuffer(transportUnreadBuffer)
+ }
+
+    /**
+     * Size of an application-side batch: one application record for each packet that fits into
+     * a pooled transport buffer (counting at least one packet), capped at
+     * `MaxApplicationRecordsPerEngineCall` records.
+     */
+    private def applicationBufferSize(applicationSize: Int, packetSize: Int): Int = {
+      val packetsPerPooledBuffer = (pooledBufferSize / packetSize).max(1)
+      val records = packetsPerPooledBuffer.min(MaxApplicationRecordsPerEngineCall)
+      applicationSize * records
+    }
+
+ private def runPump(): Unit = // runs only once the initial pump is enabled and while the stage is neither closed nor completing
+ if ((stateBits & (StageClosedFlag | StageCompletionPendingFlag)) == 0
&&
+ (stateBits & InitialPumpEnabledFlag) != 0) {
+ pump()
+ if ((stateBits & (StageClosedFlag | StageCompletionPendingFlag)) ==
0) { // re-check: the pump may have closed or started completing the stage
+ pullInputs()
+ schedulePendingTransportFlush()
+ }
+ }
+
+ private def pump(): Unit = { // run phases while the current one is executable; NonFatal errors go to failTls
+ try {
+ while (phaseExecutable)
+ runPhase()
+ } catch {
+ case NonFatal(ex) => failTls(ex)
+ }
+
+ if (phaseCompleted) // the current phase has no work left
+ pumpFinished()
+ }
+
+ private def nextPhase(next: Int): Unit = // phases are Int constants stored compactly as a Byte
+ phase = next.toByte
+
+ private def phaseExecutable: Boolean = // can the current phase make progress right now?
+ (phase.toInt: @switch) match {
+ case BidirectionalPhase => bidirectionalReady &&
!bidirectionalCompleted
+ case FlushingOutboundPhase => outboundHalfClosedReady &&
!outboundHalfClosedCompleted
+ case AwaitingClosePhase => awaitingCloseReady &&
!awaitingCloseCompleted
+ case OutboundClosedPhase => (outboundHalfClosedReady ||
inboundReady) && !outboundClosedCompleted
+ case InboundClosedPhase => (outboundReady ||
inboundHalfClosedReady) && !inboundClosedCompleted
+ case _ => false // CompletedPhase (or any unknown value) never runs
+ }
+
+ private def phaseCompleted: Boolean = // has the current phase finished? checked by pump after its loop
+ (phase.toInt: @switch) match {
+ case BidirectionalPhase => bidirectionalCompleted
+ case FlushingOutboundPhase => outboundHalfClosedCompleted
+ case AwaitingClosePhase => awaitingCloseCompleted
+ case OutboundClosedPhase => outboundClosedCompleted
+ case InboundClosedPhase => inboundClosedCompleted
+ case CompletedPhase => true
+ case _ => false
+ }
+
+ private def runPhase(): Unit = // one step of the current phase; the phases mirror TLSActor's
+ (phase.toInt: @switch) match {
+ case BidirectionalPhase => // inbound first, then outbound only if doInbound returned true
+ val continue = doInbound(isOutboundClosed = false,
inboundHalfClosedMode = false)
+ if (continue) doOutbound(isInboundClosed = false)
+
+ case FlushingOutboundPhase => // wrap only; an engine error ends the session
+ try doWrap()
+ catch { case _: SSLException => nextPhase(CompletedPhase) }
+
+ case AwaitingClosePhase => // read only; decrypted output is ignored
+ transportInput.chopInto(transportInBuffer)
+ if (TlsEngineHelpers.hasCompleteTlsRecord(transportInBuffer)) // skip unwrap while only a partial record is buffered
+ try doUnwrap(ignoreOutput = true)
+ catch { case _: SSLException => nextPhase(CompletedPhase) }
+
+ case OutboundClosedPhase => // read, and wrap only what the engine itself needs to send
+ val continue = doInbound(isOutboundClosed = true,
inboundHalfClosedMode = false)
+ if (continue && outboundHalfClosedReady) {
+ try doWrap()
+ catch { case _: SSLException => nextPhase(CompletedPhase) }
+ }
+
+ case InboundClosedPhase => // keep sending user data; inbound output is ignored
+ val continue = doInbound(isOutboundClosed = false,
inboundHalfClosedMode = true)
+ if (continue) doOutbound(isInboundClosed = true)
+
+ case _ => // CompletedPhase: nothing left to run
+ }
+
+ private def userHasDataReady: Boolean = (stateBits &
PlainDataAllowedFlag) != 0 &&
+ (userInBufferHasRemaining || userInput.isReady) &&
+ lastHandshakeStatus != NEED_UNWRAP // user data waits while the engine needs to unwrap first
+
+ private def userHasDataCompleted: Boolean = // no further user input can arrive
+ userInput.isCancelled || userInput.isDepleted
+
+ private def engineNeedsWrapReady: Boolean = // engine must send data to the peer
+ lastHandshakeStatus == NEED_WRAP
+
+ private def engineNeedsWrapCompleted: Boolean = // wrap will produce no more outbound data
+ engine.isOutboundDone
+
+ private def userOrEngineOutboundReady: Boolean = // something to wrap, from the user or the engine
+ userHasDataReady || engineNeedsWrapReady
+
+ private def userOrEngineOutboundCompleted: Boolean = // neither source will produce more
+ userHasDataCompleted && engineNeedsWrapCompleted
+
+ private def bidirectionalReady: Boolean = // BidirectionalPhase can run
+ outboundReady || inboundReady
+
+ private def bidirectionalCompleted: Boolean = // BidirectionalPhase is done
+ outboundCompleted && inboundCompleted
+
+ private def outboundReady: Boolean = // something to wrap and demand on cipherOut
+ userOrEngineOutboundReady && transportOutput.isDemandAvailable
+
+ private def outboundCompleted: Boolean = // nothing more to wrap, or cipherOut is gone
+ userOrEngineOutboundCompleted || transportOutput.isClosed
+
+ private def inboundReady: Boolean = (transportInput.isReady &&
userOutput.isDemandAvailable) ||
+ userOutput.isCancelled // a cancelled plainOut must also be processed (see doInbound)
+
+ private def inboundCompleted: Boolean = (transportInput.isCompleted ||
userOutput.isClosed) &&
+ (engine.isInboundDone || userOutput.isErrored) // and the engine accepts no more inbound data (or output errored)
+
+ private def outboundHalfClosedReady: Boolean = // engine-initiated wraps only, no user data
+ engineNeedsWrapReady && transportOutput.isDemandAvailable
+
+ private def outboundHalfClosedCompleted: Boolean =
+ engineNeedsWrapCompleted || transportOutput.isClosed
+
+ private def inboundHalfClosedReady: Boolean = // reading without requiring demand on plainOut
+ transportInput.isReady && !engine.isInboundDone
+
+ private def inboundHalfClosedCompleted: Boolean =
+ transportInput.isCompleted || engine.isInboundDone
+
+ private def outboundClosedCompleted: Boolean = // OutboundClosedPhase is done
+ outboundHalfClosedCompleted && inboundCompleted
+
+ private def inboundClosedCompleted: Boolean = // InboundClosedPhase is done
+ outboundCompleted && inboundHalfClosedCompleted
+
+ private def awaitingCloseReady: Boolean = // used by AwaitingClosePhase
+ transportInput.isPending && !engine.isInboundDone
+
+ private def awaitingCloseCompleted: Boolean =
+ transportInput.isDepleted || transportInput.isCancelled ||
engine.isInboundDone // transport finished/cancelled, or engine inbound done
+
+ /** Requests more data on both inlets once their local buffers have been drained. */
+ private def pullInputs(): Unit = {
+   val userBufferDrained = !userInBufferHasRemaining
+   userInput.pullIfNeeded(userBufferDrained)
+   val transportBufferDrained = transportInput.isEmpty
+   transportInput.pullIfNeeded(transportBufferDrained)
+ }
+
+ /**
+  * Moves to CompletedPhase when the engine's outbound side is done, or when its inbound side is
+  * done and no user plaintext is left. Otherwise moves to FlushingOutboundPhase.
+  */
+ private def completeOrFlush(): Unit = {
+   val finished =
+     engine.isOutboundDone ||
+       (engine.isInboundDone && userInput.isEmpty && !userInBufferHasRemaining)
+   nextPhase(if (finished) CompletedPhase else FlushingOutboundPhase)
+ }
+
+ private def doInbound(isOutboundClosed: Boolean, inboundHalfClosedMode:
Boolean): Boolean = // One inbound pump step; returns false once the inbound side finished (EOF or unwrap failure), true otherwise.
+ if (transportInput.isDepleted && transportInput.isEmpty) { // transport EOF with nothing left buffered: close the engine's inbound side
+ try engine.closeInbound()
+ catch {
+ case _: SSLException => userOutput.enqueue(SessionTruncated) // closeInbound throws when no close_notify was received: report truncation
+ }
+ lastHandshakeStatus = engine.getHandshakeStatus
+ completeOrFlush()
+ false
+ } else if (!inboundHalfClosedMode && userOutput.isCancelled) { // user outlet cancelled (ignored in half-closed mode)
+ if (!isOutboundClosed && closing.ignoreCancel) { // TLSClosing asked to ignore cancellation: keep the outbound side running
+ nextPhase(InboundClosedPhase)
+ } else {
+ engine.closeOutbound() // start closing and flush the resulting close records
+ lastHandshakeStatus = engine.getHandshakeStatus
+ nextPhase(FlushingOutboundPhase)
+ }
+ true
+ } else if (if (inboundHalfClosedMode) inboundHalfClosedReady else
inboundReady) { // ready to feed more ciphertext to the engine
+ transportInput.chopInto(transportInBuffer)
+ if (TlsEngineHelpers.hasCompleteTlsRecord(transportInBuffer)) // presumably checks that a whole TLS record is buffered — see TlsEngineHelpers
+ try {
+ doUnwrap(ignoreOutput = inboundHalfClosedMode ||
userOutput.isCancelled) // decrypted bytes are discarded when they cannot be delivered
+ true
+ } catch {
+ case ex: SSLException =>
+ failTls(ex, closeTransport = false) // fail the user outlet, keep the transport for the failure alert
+ refreshHandshakeStatus()
+ try engine.closeInbound()
+ catch { case _: SSLException => () } // best-effort: the failure has already been reported
+ completeOrFlush()
+ false
+ }
+ else true // incomplete record: wait for more transport bytes
+ } else true
+
+ private def doOutbound(isInboundClosed: Boolean): Unit = // One outbound pump step: close the outbound side, stop, or wrap the next chunk.
+ if (userInput.isDepleted && userInput.isEmpty &&
!userInBufferHasRemaining && mayCloseOutbound) { // all user input consumed and no handshake in progress
+ if (transportOutBuffer.position() > 0) // bytes may still be held back by flushToTransportIfNeeded
+ flushToTransport()
+ if (!isInboundClosed && closing.ignoreComplete) { // TLSClosing asked to keep the outbound side open while inbound is open
+ // keep the outbound TLS side open
+ } else {
+ engine.closeOutbound()
+ lastHandshakeStatus = engine.getHandshakeStatus
+ }
+ nextPhase(OutboundClosedPhase)
+ } else if (transportOutput.isCancelled) { // nobody consumes ciphertext any more
+ nextPhase(CompletedPhase)
+ } else if (outboundReady) {
+ if (userHasDataReady && !userInBufferHasRemaining) // refill only after the previous chunk has been fully wrapped
+ userInput.chopInto(userInBuffer)
+ try doWrap()
+ catch {
+ case ex: SSLException =>
+ failTls(ex, closeTransport = false) // fail the user outlet, keep the transport for the failure alert
+ refreshHandshakeStatus()
+ completeOrFlush()
+ }
+ }
+
+ /** The outbound side may only be closed while the engine is not in the middle of a handshake. */
+ private def mayCloseOutbound: Boolean =
+   lastHandshakeStatus == HandshakeStatus.NOT_HANDSHAKING ||
+     lastHandshakeStatus == HandshakeStatus.FINISHED
+
+ private def doWrap(): Unit = { // wraps buffered user plaintext (or engine handshake/close data) into transportOutBuffer
+ val result = wrapIntoTransportBuffer() // may call engine.wrap several times while it keeps making progress
+ lastHandshakeStatus = result.getHandshakeStatus
+
+ if (lastHandshakeStatus == FINISHED) handshakeFinished() // verify the session and allow plaintext
+ runDelegatedTasks()
+
+ result.getStatus match {
+ case OK =>
+ if (transportOutBuffer.position() == 0 && lastHandshakeStatus ==
NEED_WRAP) // guard against an engine spinning on NEED_WRAP without output
+ throw new IllegalStateException("SSLEngine trying to loop
NEED_WRAP without producing output")
+
+ flushToTransportIfNeeded(result) // flush now or defer the bytes for batching
+
+ case CLOSED =>
+ flushToTransport()
+ if (engine.isInboundDone) nextPhase(CompletedPhase) // both sides done
+ else nextPhase(AwaitingClosePhase) // wait for the peer's side to finish
+
+ case BUFFER_OVERFLOW =>
+ growTransportOutBuffer() // NOTE(review): throws IllegalStateException if pooled buffers are smaller than twice the current capacity
+
+ case status =>
+ failTls(new IllegalStateException(s"unexpected status $status in
doWrap()"))
+ }
+ }
+
+ /**
+  * Calls `engine.wrap` once, then keeps wrapping while `canContinueWrapping` allows it and each
+  * call consumes input or produces output. Returns the result of the last call.
+  */
+ private def wrapIntoTransportBuffer(): SSLEngineResult = {
+   var lastResult = engine.wrap(userInBuffer, transportOutBuffer)
+   var progressing = true
+   while (progressing && canContinueWrapping(lastResult)) {
+     val consumedBefore = userInBuffer.position()
+     val producedBefore = transportOutBuffer.position()
+     lastResult = engine.wrap(userInBuffer, transportOutBuffer)
+     // stop as soon as a call neither consumed plaintext nor produced ciphertext
+     progressing =
+       userInBuffer.position() != consumedBefore ||
+         transportOutBuffer.position() != producedBefore
+   }
+   lastResult
+ }
+
+ /**
+  * Another wrap is worthwhile only after a successful, non-handshaking wrap that left plaintext
+  * behind, with room for at least one more transport packet.
+  */
+ private def canContinueWrapping(result: SSLEngineResult): Boolean = {
+   val plainSuccess = result.getStatus == OK && result.getHandshakeStatus == NOT_HANDSHAKING
+   plainSuccess &&
+     userInBuffer.hasRemaining &&
+     transportOutBuffer.remaining >= transportPacketSize
+ }
+
+ @tailrec
+ private def doUnwrap(ignoreOutput: Boolean): Unit = { // unwraps ciphertext from transportInBuffer into userOutBuffer, recursing while input remains
+ if (!ignoreOutput && transportInBuffer.remaining > transportPacketSize
&&
+ userOutBuffer.capacity < maxUserOutBufferSize) // more than one packet buffered: enlarge the plaintext buffer first
+ growUserOutBuffer()
+
+ val oldInPosition = transportInBuffer.position()
+ val result = engine.unwrap(transportInBuffer, userOutBuffer)
+ if (ignoreOutput) userOutBuffer.clear() // decrypted bytes are dropped when output is ignored
+ lastHandshakeStatus = result.getHandshakeStatus
+ runDelegatedTasks()
+
+ result.getStatus match {
+ case OK =>
+ result.getHandshakeStatus match {
+ case NEED_WRAP => // engine must send something first: park the remaining ciphertext
+ inlineHandshakeReentries += 1
+ if (inlineHandshakeReentries > MaxInlineEngineProgressAttempts) // bounded; the counter is reset in flushToUser
+ throw new IllegalStateException(
+ s"Stuck in unwrap loop, bailing out, last handshake status
[$lastHandshakeStatus], " +
+ s"remaining=${transportInBuffer.remaining},
out=${userOutBuffer.position()}, " +
+ "(https://github.com/apache/pekko/issues/442)")
+
+ transportInput.putBackUnreadBuffer(transportInBuffer) // re-read on a later chopInto
+
+ case FINISHED => // handshake completed by this record
+ flushToUser()
+ handshakeFinished()
+ transportInput.putBackUnreadBuffer(transportInBuffer)
+
+ case NEED_UNWRAP
+ if transportInBuffer.hasRemaining &&
+ userOutBuffer.position() == 0 &&
+ transportInBuffer.position() == oldInPosition => // nothing consumed and nothing produced: the engine is stuck
+ throw new IllegalStateException("SSLEngine trying to loop
NEED_UNWRAP without producing output")
+
+ case _ =>
+ if (transportInBuffer.hasRemaining) doUnwrap(ignoreOutput) // keep unwrapping the buffered records
+ else flushToUser()
+ }
+
+ case CLOSED => // peer closed its side
+ flushToUser()
+ completeOrFlush()
+
+ case BUFFER_UNDERFLOW => // partial record: bytes stay in transportInBuffer (assumes chopInto's prepareForAppend keeps them)
+ flushToUser()
+
+ case BUFFER_OVERFLOW => // plaintext buffer full: deliver it and retry the remaining ciphertext later
+ flushToUser()
+ transportInput.putBackUnreadBuffer(transportInBuffer)
+
+ case null =>
+ failTls(new IllegalStateException("unexpected status null in
doUnwrap()"))
+ }
+ }
+
+ /**
+  * Runs pending SSLEngine delegated tasks via TlsEngineHelpers. The handshake status is re-read
+  * when tasks ran (the helper is assumed to return the number of tasks run) or when the last
+  * status was NEED_TASK.
+  */
+ private def runDelegatedTasks(): Unit = {
+   val tasksRun = TlsEngineHelpers.runDelegatedTasks(engine)
+   if (tasksRun > 0 || lastHandshakeStatus == NEED_TASK)
+     lastHandshakeStatus = engine.getHandshakeStatus
+ }
+
+ /**
+  * Verifies the freshly negotiated session. On success it becomes the current session, plaintext
+  * is allowed and buffered plaintext is flushed. On failure the verification error is rethrown.
+  */
+ private def handshakeFinished(): Unit = {
+   val negotiated = engine.getSession
+   verifySession(negotiated) match {
+     case Failure(verificationError) =>
+       throw verificationError
+     case Success(_) =>
+       currentSession = negotiated
+       stateBits |= PlainDataAllowedFlag
+       flushToUser()
+   }
+ }
+
+ /**
+  * Applies a NegotiateNewSession request: flushes bytes wrapped under the old session, invalidates
+  * it, applies the new parameters and starts a new handshake. Plaintext is held back until
+  * handshakeFinished allows it again.
+  */
+ private def setNewSessionParameters(params: NegotiateNewSession): Unit = {
+   if (transportOutBuffer.position() != 0) flushToTransport()
+   currentSession.invalidate()
+   TlsUtils.applySessionParameters(engine, params)
+   engine.beginHandshake()
+   lastHandshakeStatus = engine.getHandshakeStatus
+   stateBits &= ~PlainDataAllowedFlag
+ }
+
+ /** Clears a pending deferred-flush marker (cancelling its timer) and resets the deferral counter. */
+ private def resetDeferredTransportFlush(): Unit = {
+   if ((stateBits & TransportFlushPendingFlag) != 0) {
+     stateBits &= ~TransportFlushPendingFlag
+     cancelTimer(TransportFlushTimer)
+   }
+   deferredTransportFlushes = 0
+ }
+
+ /**
+  * Emits everything written to transportOutBuffer as a single ByteString, unless the transport
+  * outlet is closed, and leaves the buffer empty.
+  */
+ private def flushToTransport(): Unit = {
+   resetDeferredTransportFlush()
+   transportOutBuffer.flip()
+   if (transportOutBuffer.hasRemaining && !transportOutput.isClosed)
+     transportOutput.enqueue(ByteString(transportOutBuffer))
+   transportOutBuffer.clear()
+ }
+
+ /**
+  * After a successful wrap, flushes immediately when required. Otherwise the bytes are kept for
+  * batching and marked pending, until DeferredTransportFlushLimit consecutive deferrals force a flush.
+  */
+ private def flushToTransportIfNeeded(result: SSLEngineResult): Unit =
+   if (transportOutBuffer.position() == 0) resetDeferredTransportFlush()
+   else if (shouldFlushTransportNow(result)) flushToTransport()
+   else {
+     deferredTransportFlushes += 1
+     if (deferredTransportFlushes < DeferredTransportFlushLimit) stateBits |= TransportFlushPendingFlag
+     else flushToTransport()
+   }
+
+ /**
+  * Arms a zero-delay TransportFlushTimer for deferred bytes. If the transport outlet is already
+  * closed, the deferred bytes are discarded instead.
+  */
+ private def schedulePendingTransportFlush(): Unit =
+   if ((stateBits & TransportFlushPendingFlag) != 0) {
+     if (transportOutput.isClosed) {
+       stateBits &= ~TransportFlushPendingFlag
+       deferredTransportFlushes = 0
+       transportOutBuffer.clear()
+     } else if (!isTimerActive(TransportFlushTimer))
+       scheduleOnce(TransportFlushTimer, Duration.Zero)
+   }
+
+ /**
+  * Wrapped bytes must be flushed immediately in these cases:
+  *  - handshake activity or an outbound close;
+  *  - plaintext not yet allowed or the transport outlet cancelled;
+  *  - at least one transport packet buffered;
+  *  - the application write was not small enough to batch.
+  */
+ private def shouldFlushTransportNow(result: SSLEngineResult): Boolean = {
+   val controlTraffic =
+     result.getHandshakeStatus != NOT_HANDSHAKING ||
+       engine.isOutboundDone ||
+       (stateBits & PlainDataAllowedFlag) == 0 ||
+       transportOutput.isCancelled
+   controlTraffic ||
+     transportOutBuffer.position() >= transportPacketSize ||
+     shouldFlushApplicationDataNow(result.bytesConsumed())
+ }
+
+ /** Only positive writes up to MaxBatchedApplicationBytes and below one application record may be batched. */
+ private def shouldFlushApplicationDataNow(bytesConsumed: Int): Boolean =
+   !(bytesConsumed > 0 &&
+     bytesConsumed <= MaxBatchedApplicationBytes &&
+     bytesConsumed < applicationRecordSize)
+
+ /**
+  * Emits the decrypted bytes in userOutBuffer, tagged with the current session, and clears the
+  * buffer. Also resets the inline NEED_WRAP re-entry counter used by doUnwrap.
+  */
+ private def flushToUser(): Unit = {
+   if (inlineHandshakeReentries > 0) inlineHandshakeReentries = 0
+   userOutBuffer.flip()
+   if (userOutBuffer.hasRemaining) {
+     val decrypted = ByteString(userOutBuffer)
+     userOutput.enqueue(SessionBytes(currentSession, decrypted))
+   }
+   userOutBuffer.clear()
+ }
+
+ /**
+  * Replaces transportOutBuffer with a pooled buffer of twice the capacity, keeping the bytes
+  * written so far.
+  *
+  * @throws IllegalStateException if doubling would overflow an Int, or if
+  *   acquireTransportBuffer cannot provide a buffer that large
+  */
+ private def growTransportOutBuffer(): Unit = {
+   val current = transportOutBuffer
+   val capacity = current.capacity()
+   if (capacity > Int.MaxValue / 2)
+     throw new IllegalStateException(s"Cannot grow TLS transport output buffer beyond $capacity bytes")
+
+   val doubled = acquireTransportBuffer(capacity * 2)
+   current.flip()
+   doubled.put(current)
+   releaseTransportBuffer(current)
+   transportOutBuffer = doubled
+ }
+
+ /** Fails the TLS stage, including the transport side. */
+ private def failTls(ex: Throwable): Unit = failTls(ex, closeTransport = true)
+
+ /**
+  * Cancels both inlets. With `closeTransport` the whole stage fails. Without it, a failure alert
+  * is written and only the user outlet fails. Does nothing once the stage has been closed.
+  */
+ private def failTls(ex: Throwable, closeTransport: Boolean): Unit =
+   if ((stateBits & StageClosedFlag) == 0) {
+     userInput.cancel()
+     transportInput.cancel()
+     if (!closeTransport) {
+       writeFailureAlert()
+       userOutput.failOutput(ex)
+     } else {
+       stateBits |= StageClosedFlag
+       failStage(ex)
+     }
+   }
+
+ /** Zero-length plaintext source used to wrap engine-generated records only. */
+ private val EmptyApplicationBuffer = ByteBuffer.allocate(0)
+
+ /**
+  * Best-effort: closes the engine's outbound side (if still open) and wraps and flushes whatever
+  * it then produces (presumably the close/alert record). Skipped when there is no engine or the
+  * transport outlet is cancelled; non-fatal errors are swallowed.
+  */
+ private def writeFailureAlert(): Unit =
+   if ((engine ne null) && !transportOutput.isCancelled) {
+     try {
+       if (!engine.isOutboundDone) engine.closeOutbound()
+       lastHandshakeStatus = engine.wrap(EmptyApplicationBuffer, transportOutBuffer).getHandshakeStatus
+       flushToTransport()
+       EmptyApplicationBuffer.clear()
+     } catch {
+       case NonFatal(_) => EmptyApplicationBuffer.clear()
+     }
+   }
+
+ /** Best-effort re-read of the engine's handshake status; non-fatal errors leave the old value in place. */
+ private def refreshHandshakeStatus(): Unit =
+   if (engine ne null) {
+     try lastHandshakeStatus = engine.getHandshakeStatus
+     catch { case NonFatal(_) => () }
+   }
+
+ /**
+  * Shuts the stage down in an orderly way:
+  *  - cancels both inlets;
+  *  - completes both outlets (an outlet with a parked element completes after pushing it);
+  *  - requests stage completion.
+  * Does nothing once the stage has been closed.
+  */
+ private def pumpFinished(): Unit =
+   if ((stateBits & StageClosedFlag) != 0) ()
+   else {
+     userInput.cancel()
+     transportInput.cancel()
+     transportOutput.completeOutput()
+     userOutput.completeOutput()
+     stateBits |= StageCompletionPendingFlag
+     completeStageIfReady()
+   }
+
+ /** Completes the stage once completion was requested and neither outlet still holds a parked element. */
+ private def completeStageIfReady(): Unit = {
+   val completionRequested = (stateBits & StageCompletionPendingFlag) != 0
+   val stillOpen = (stateBits & StageClosedFlag) == 0
+   if (completionRequested && stillOpen && !userOutput.hasPending && !transportOutput.hasPending) {
+     stateBits |= StageClosedFlag
+     completeStage()
+   }
+ }
+
+ /** Returns pooled transport buffers before the default postStop runs. */
+ override def postStop(): Unit = {
+   releaseBuffers()
+   super.postStop()
+ }
+
+ /** True when userInBuffer exists and still holds unwrapped plaintext. */
+ private def userInBufferHasRemaining: Boolean =
+   if (userInBuffer eq null) false else userInBuffer.hasRemaining
+
+ /**
+  * Installs and returns a fresh heap buffer for user plaintext. It is sized to hold
+  * `availableBytes` (at least one application record), rounded up to whole application records
+  * and capped at maxUserInBufferSize. The previous buffer's content is not copied.
+  */
+ private def growUserInBuffer(availableBytes: Int): ByteBuffer = {
+   val wanted = math.max(applicationRecordSize, availableBytes) min maxUserInBufferSize
+   val wholeRecords = (wanted + applicationRecordSize - 1) / applicationRecordSize
+   userInBuffer = ByteBuffer.allocate(math.min(maxUserInBufferSize, wholeRecords * applicationRecordSize))
+   userInBuffer
+ }
+
+ /** Replaces userOutBuffer with a heap buffer of maxUserOutBufferSize, carrying over the bytes already written. */
+ private def growUserOutBuffer(): Unit = {
+   val enlarged = ByteBuffer.allocate(maxUserOutBufferSize)
+   userOutBuffer.flip()
+   userOutBuffer = enlarged.put(userOutBuffer)
+ }
+
+ /** Returns the three pooled transport buffers to the pool and drops every buffer reference. */
+ private def releaseBuffers(): Unit = {
+   // outbound ciphertext (pooled) and its plaintext counterpart
+   releaseTransportBuffer(transportOutBuffer)
+   transportOutBuffer = null
+   userOutBuffer = null
+   // inbound ciphertext buffers (pooled)
+   releaseTransportBuffer(transportInBuffer)
+   transportInBuffer = null
+   releaseTransportBuffer(transportUnreadBuffer)
+   transportUnreadBuffer = null
+   // user plaintext buffer is heap-allocated; just drop it
+   userInBuffer = null
+ }
+
+ /**
+  * Takes a buffer from the shared pool.
+  *
+  * @throws IllegalStateException (after returning the buffer to the pool) when the pooled buffer
+  *   is smaller than `requiredCapacity`
+  */
+ private def acquireTransportBuffer(requiredCapacity: Int): ByteBuffer = {
+   val pooled = bufferPool.acquire()
+   if (pooled.capacity < requiredCapacity) {
+     releaseTransportBuffer(pooled)
+     throw new IllegalStateException(
+       s"TLS packet buffer requires $requiredCapacity bytes but pekko.io.tcp.direct-buffer-size is $pooledBufferSize")
+   }
+   pooled
+ }
+
+ /** Clears a (possibly null) buffer and hands it back to the pool. */
+ private def releaseTransportBuffer(buffer: ByteBuffer): Unit =
+   Option(buffer).foreach { b =>
+     b.clear()
+     bufferPool.release(b)
+   }
+
+ private final class UserInputSlot(inlet: Inlet[SslTlsOutbound]) extends
InHandler { // Buffers user elements in a bounded ring queue and copies SendBytes payloads into userInBuffer.
+ private val pendingQueue = new Array[AnyRef](UserInputQueueSize) // ring buffer; index & UserInputQueueMask needs a power-of-two size
+ private var pendingHead = 0
+ private var pendingCount = 0
+ private var bytes: ByteString = ByteString.empty // payload currently being copied out
+ private var bytesOffset = 0
+ private var bytesRemaining = 0
+
+ def isReady: Boolean = bytesRemaining > 0 || pendingCount != 0 ||
isDepleted // data available, or input finished and fully consumed
+ def isCompleted: Boolean = (stateBits & UserInputCancelledFlag) != 0 // same condition as isCancelled
+ def isEmpty: Boolean = bytesRemaining == 0 && pendingCount == 0
+ def isPending: Boolean = pendingCount != 0
+ def isDepleted: Boolean = (stateBits & UserInputFinishedFlag) != 0 &&
pendingCount == 0 && bytesRemaining == 0 // upstream finished (or cancelled) and nothing buffered
+ def isCancelled: Boolean = (stateBits & UserInputCancelledFlag) != 0
+
+ override def onPush(): Unit = {
+ stateBits |= InitialPumpEnabledFlag
+ val elem = grab(inlet)
+ enqueue(elem)
+ pullIfNeeded(bytesDrained = true) // keep reading ahead while the queue has room
+
+ if (shouldDefer(elem) && pendingCount < UserInputQueueSize) { // small write: batch it; presumably the UserInputBatchTimer handler runs the pump
+ stateBits |= UserInputBatchPendingFlag
+ if (!isTimerActive(UserInputBatchTimer))
+ scheduleOnce(UserInputBatchTimer, Duration.Zero)
+ } else {
+ cancelPendingBatchTimer()
+ runPump()
+ }
+ }
+
+ override def onUpstreamFinish(): Unit = {
+ stateBits |= InitialPumpEnabledFlag
+ stateBits |= UserInputFinishedFlag
+ cancelPendingBatchTimer()
+ runPump()
+ }
+
+ override def onUpstreamFailure(ex: Throwable): Unit =
+ failTls(ex)
+
+ private def shouldDefer(elem: SslTlsOutbound): Boolean = // only small SendBytes in steady state (no handshake, nothing half-copied, still open)
+ elem match {
+ case SendBytes(bs) =>
+ bs.size <= MaxBatchedApplicationBytes &&
+ (stateBits & PlainDataAllowedFlag) != 0 &&
+ (lastHandshakeStatus == NOT_HANDSHAKING || lastHandshakeStatus
== FINISHED) &&
+ !userInBufferHasRemaining &&
+ !engine.isOutboundDone &&
+ !transportOutput.isClosed
+
+ case _ => false
+ }
+
+ private def cancelPendingBatchTimer(): Unit =
+ if ((stateBits & UserInputBatchPendingFlag) != 0) {
+ stateBits &= ~UserInputBatchPendingFlag
+ if (isTimerActive(UserInputBatchTimer))
+ cancelTimer(UserInputBatchTimer)
+ }
+
+ private def enqueue(elem: SslTlsOutbound): Unit = {
+ if (pendingCount >= UserInputQueueSize)
+ throw new IllegalStateException("TLS user input queue is full")
+
+ pendingQueue((pendingHead + pendingCount) & UserInputQueueMask) =
elem.asInstanceOf[AnyRef] // append at the tail
+ pendingCount += 1
+ stateBits |= UserInputHasPendingFlag
+ }
+
+ private def peek(): SslTlsOutbound =
+ pendingQueue(pendingHead).asInstanceOf[SslTlsOutbound]
+
+ private def dequeue(): SslTlsOutbound = {
+ val elem = pendingQueue(pendingHead).asInstanceOf[SslTlsOutbound]
+ pendingQueue(pendingHead) = null // drop the reference
+ pendingHead = (pendingHead + 1) & UserInputQueueMask
+ pendingCount -= 1
+ if (pendingCount == 0) { // empty: rewind the head
+ pendingHead = 0
+ stateBits &= ~UserInputHasPendingFlag
+ }
+ elem
+ }
+
+ def pullIfNeeded(bytesDrained: Boolean): Unit = // pulls only when the current payload is consumed, the queue has room and the inlet is open
+ if (bytesDrained &&
+ bytesRemaining == 0 &&
+ pendingCount < UserInputQueueSize &&
+ (stateBits & UserInputFinishedFlag) == 0 &&
+ (stateBits & UserInputCancelledFlag) == 0 &&
+ !hasBeenPulled(inlet))
+ pull(inlet)
+
+ def cancel(): Unit = // idempotent; drops everything buffered and also marks the input finished
+ if ((stateBits & UserInputCancelledFlag) == 0) {
+ stateBits |= UserInputCancelledFlag
+ cancelPendingBatchTimer()
+ clearPending()
+ clearBytes()
+ if ((stateBits & UserInputFinishedFlag) == 0 &&
!logic.isClosed(inlet)) logic.cancel(inlet)
+ stateBits |= UserInputFinishedFlag
+ }
+
+ private def clearPending(): Unit = {
+ var i = 0
+ while (i < pendingCount) {
+ pendingQueue((pendingHead + i) & UserInputQueueMask) = null
+ i += 1
+ }
+ pendingHead = 0
+ pendingCount = 0
+ stateBits &= ~UserInputHasPendingFlag
+ }
+
+ private def setBytes(next: ByteString): Unit = {
+ bytes = next
+ bytesOffset = 0
+ bytesRemaining = next.size
+ }
+
+ private def clearBytes(): Unit = {
+ bytes = ByteString.empty
+ bytesOffset = 0
+ bytesRemaining = 0
+ }
+
+ def chopInto(buffer: ByteBuffer): Unit = { // copies queued plaintext into buffer (may install a larger userInBuffer); stops at NegotiateNewSession or when a payload does not fit
+ var target = buffer
+ TlsEngineHelpers.prepareForAppend(target)
+
+ var continue = true
+ while (continue && target.hasRemaining && (bytesRemaining > 0 ||
pendingCount != 0)) {
+ if (bytesRemaining == 0 && pendingCount != 0) { // current payload done: take the next queued element
+ peek() match {
+ case SendBytes(bs) =>
+ if (bs.isEmpty) dequeue()
+ else if (bs.size > target.remaining && target.position() !=
0) continue = false // does not fit behind already copied bytes: wrap those first
+ else {
+ if (bs.size > target.remaining && target.capacity <
maxUserInBufferSize) // empty but too small buffer: grow it
+ target = growUserInBuffer(bs.size)
+ dequeue()
+ setBytes(bs)
+ }
+
+ case n: NegotiateNewSession => // applied only before any bytes were copied; always ends this chop
+ if (target.position() == 0) {
+ dequeue()
+ setNewSessionParameters(n)
+ }
+ continue = false
+
+ case other =>
+ throw new IllegalArgumentException(s"Unexpected TLS input
element: $other")
+ }
+ }
+
+ if (continue && bytesRemaining > 0) {
+ if (bytesRemaining > target.remaining && target.position() == 0
&& target.capacity < maxUserInBufferSize)
+ target = growUserInBuffer(bytesRemaining)
+
+ val copied = bytes.copyToBuffer(target, bytesOffset)
+ if (copied == bytesRemaining) clearBytes()
+ else if (copied > 0) { // partial copy: remember the offset and stop, the buffer is full
+ bytesOffset += copied
+ bytesRemaining -= copied
+ continue = false
+ } else continue = false
+ }
+ }
+
+ target.flip() // switch the (possibly replaced) buffer to read mode for wrapping
+ }
+ }
+
+ private final class TransportInputSlot(inlet: Inlet[ByteString]) extends
InHandler { // One element of lookahead plus the ByteString being copied; unconsumed ciphertext is parked in transportUnreadBuffer.
+ private var pending: ByteString = _
+ private var bytes: ByteString = ByteString.empty
+ private var bytesOffset = 0
+ private var bytesRemaining = 0
+
+ private def unreadBufferHasRemaining: Boolean = (transportUnreadBuffer
ne null) &&
+ transportUnreadBuffer.hasRemaining
+
+ def isReady: Boolean =
+ unreadBufferHasRemaining || bytesRemaining > 0 || (stateBits &
TransportInputHasPendingFlag) != 0 ||
+ isDepleted // data available anywhere, or upstream finished and everything consumed
+ def isCompleted: Boolean = (stateBits & TransportInputCancelledFlag)
!= 0
+ def isEmpty: Boolean = bytesRemaining == 0 && !unreadBufferHasRemaining // ignores the pending element
+ def isPending: Boolean =
+ unreadBufferHasRemaining || bytesRemaining > 0 || (stateBits &
TransportInputHasPendingFlag) != 0
+ def isDepleted: Boolean = (stateBits & TransportInputFinishedFlag) !=
0 &&
+ (stateBits & TransportInputHasPendingFlag) == 0 &&
+ bytesRemaining == 0 &&
+ !unreadBufferHasRemaining
+ def isCancelled: Boolean = (stateBits & TransportInputCancelledFlag)
!= 0
+
+ override def onPush(): Unit = {
+ stateBits |= InitialPumpEnabledFlag
+ pending = grab(inlet)
+ stateBits |= TransportInputHasPendingFlag
+ runPump()
+ }
+
+ override def onUpstreamFinish(): Unit = {
+ stateBits |= InitialPumpEnabledFlag
+ stateBits |= TransportInputFinishedFlag
+ runPump()
+ }
+
+ override def onUpstreamFailure(ex: Throwable): Unit =
+ failTls(ex)
+
+ private def dequeue(): ByteString = {
+ val elem = pending
+ pending = null
+ stateBits &= ~TransportInputHasPendingFlag
+ elem
+ }
+
+ def pullIfNeeded(bytesDrained: Boolean): Unit = // pulls only when no element is pending and the inlet is still open
+ if (bytesDrained &&
+ (stateBits & TransportInputHasPendingFlag) == 0 &&
+ (stateBits & TransportInputFinishedFlag) == 0 &&
+ (stateBits & TransportInputCancelledFlag) == 0 &&
+ !hasBeenPulled(inlet))
+ pull(inlet)
+
+ def cancel(): Unit = // idempotent; drops buffered data and also marks the input finished
+ if ((stateBits & TransportInputCancelledFlag) == 0) {
+ stateBits |= TransportInputCancelledFlag
+ pending = null
+ clearBytes()
+ stateBits &= ~TransportInputHasPendingFlag
+ if ((stateBits & TransportInputFinishedFlag) == 0 &&
!logic.isClosed(inlet)) logic.cancel(inlet)
+ stateBits |= TransportInputFinishedFlag
+ }
+
+ private def setBytes(next: ByteString): Unit = {
+ bytes = next
+ bytesOffset = 0
+ bytesRemaining = next.size
+ }
+
+ private def clearBytes(): Unit = {
+ bytes = ByteString.empty
+ bytesOffset = 0
+ bytesRemaining = 0
+ }
+
+ def chopInto(buffer: ByteBuffer): Unit = { // fills buffer: put-back bytes first, then the current/pending element; leaves it in read mode
+ TlsEngineHelpers.prepareForAppend(buffer)
+ drainUnreadBufferInto(buffer)
+ if (bytesRemaining == 0 && (stateBits &
TransportInputHasPendingFlag) != 0)
+ setBytes(dequeue())
+
+ val copied = bytes.copyToBuffer(buffer, bytesOffset)
+ if (copied == bytesRemaining) clearBytes()
+ else if (copied > 0) {
+ bytesOffset += copied
+ bytesRemaining -= copied
+ }
+ buffer.flip()
+ }
+
+ private def drainUnreadBufferInto(buffer: ByteBuffer): Unit = // copies as much parked ciphertext as fits
+ if (transportUnreadBuffer.hasRemaining) {
+ val unreadLimit = transportUnreadBuffer.limit()
+ if (transportUnreadBuffer.remaining > buffer.remaining)
+ transportUnreadBuffer.limit(transportUnreadBuffer.position() +
buffer.remaining) // temporarily cap the copy at the free space
+
+ buffer.put(transportUnreadBuffer)
+ transportUnreadBuffer.limit(unreadLimit)
+
+ if (!transportUnreadBuffer.hasRemaining)
+ TlsEngineHelpers.emptyReadBuffer(transportUnreadBuffer)
+ }
+
+ def putBackUnreadBuffer(buffer: ByteBuffer): Unit = { // parks buffer's unconsumed bytes and empties buffer
+ if (buffer.hasRemaining) {
+ TlsEngineHelpers.prepareForAppend(transportUnreadBuffer) // NOTE(review): new bytes go after any still-unread bytes — confirm the unread buffer is always fully drained first
+ if (buffer.remaining > transportUnreadBuffer.remaining)
+ throw new IllegalStateException(
+ s"TLS unread transport data requires ${buffer.remaining} bytes
but only " +
+ s"${transportUnreadBuffer.remaining} bytes are available")
+ transportUnreadBuffer.put(buffer)
+ transportUnreadBuffer.flip()
+ }
+ TlsEngineHelpers.emptyReadBuffer(buffer)
+ }
+ }
+
+ /**
+  * Single-element output buffer in front of `outlet`.
+  *
+  * `enqueue` pushes directly when the outlet is available. Otherwise it parks the element in
+  * `pending` until the next pull. The slot's cancelled/completed/errored state lives in
+  * `stateBits` under the flags passed in.
+  *
+  * @param deferInitialPumpOutput when true, the first pull that arrives before anything else
+  *   enabled the pump runs the pump with output deferral switched on. Anything it produced is
+  *   then flushed from a zero-delay InitialOutputFlushTimer instead of being pushed inline.
+  */
+ private final class OutputSlot[T](
+     outlet: Outlet[T],
+     cancelledFlag: Int,
+     completedFlag: Int,
+     erroredFlag: Int,
+     deferInitialPumpOutput: Boolean = false)
+     extends OutHandler {
+   // at most one element waiting for demand; null when empty
+   private var pending: T = _
+
+   def isCancelled: Boolean = (stateBits & cancelledFlag) != 0
+   def isErrored: Boolean = (stateBits & erroredFlag) != 0
+   def hasPending: Boolean = pending != null
+   def hasDemand: Boolean = !isClosed && isAvailable(outlet)
+   def isClosed: Boolean =
+     (stateBits & (cancelledFlag | completedFlag | erroredFlag)) != 0 || logic.isClosed(outlet)
+   def isDemandAvailable: Boolean = pending == null && !isClosed && isAvailable(outlet)
+
+   override def onPull(): Unit = {
+     if (deferInitialPumpOutput && (stateBits & InitialPumpEnabledFlag) == 0) {
+       // first pull starts the pump; its output is parked and flushed from a timer
+       stateBits |= InitialPumpEnabledFlag
+       stateBits |= InitialPumpDeferringOutputFlag
+       runPump()
+       stateBits &= ~InitialPumpDeferringOutputFlag
+       if (pending != null && isAvailable(outlet))
+         scheduleOnce(InitialOutputFlushTimer, Duration.Zero)
+     } else {
+       pushPending()
+       completeStageIfReady()
+       runPump()
+     }
+   }
+
+   /** Pushes the parked element if the outlet can take it; completes the outlet afterwards if completion was deferred. */
+   def pushPending(): Unit =
+     if (pending != null &&
+       (stateBits & (cancelledFlag | erroredFlag)) == 0 &&
+       !logic.isClosed(outlet) &&
+       isAvailable(outlet)) {
+       val elem = pending
+       pending = null.asInstanceOf[T]
+       push(outlet, elem)
+       if ((stateBits & completedFlag) != 0 && !logic.isClosed(outlet))
+         logic.complete(outlet)
+     }
+
+   override def onDownstreamFinish(cause: Throwable): Unit = {
+     onCancel()
+     completeStageIfReady()
+     runPump()
+   }
+
+   /**
+    * Pushes `elem` immediately when possible, otherwise parks it. Elements for a closed slot are
+    * dropped.
+    *
+    * @throws IllegalStateException if an element is already parked
+    */
+   def enqueue(elem: T): Unit =
+     if (!isClosed) {
+       if (pending == null && isAvailable(outlet) && !shouldDeferInitialOutput) push(outlet, elem)
+       else if (pending == null) pending = elem
+       else throw new IllegalStateException("TLS output slot already has pending data")
+     }
+
+   private def shouldDeferInitialOutput: Boolean =
+     deferInitialPumpOutput && (stateBits & InitialPumpDeferringOutputFlag) != 0
+
+   def onCancel(): Unit = {
+     stateBits |= cancelledFlag
+     pending = null.asInstanceOf[T]
+   }
+
+   /** Completes the outlet, or defers completion until the parked element has been pushed. */
+   def completeOutput(): Unit =
+     if (!isClosed) {
+       stateBits |= completedFlag
+       if (pending == null)
+         logic.complete(outlet)
+     }
+
+   /** Fails the outlet. A parked element can no longer be delivered, so it is dropped. */
+   def failOutput(ex: Throwable): Unit =
+     if (!isClosed) {
+       stateBits |= erroredFlag
+       // pushPending never delivers once erroredFlag is set; keeping the element would leave
+       // hasPending true forever and block completeStageIfReady()
+       pending = null.asInstanceOf[T]
+       logic.fail(outlet, ex)
+     }
+ }
+ }
+}
+
+@InternalApi private[pekko] object TlsGraphStage { // timer keys, tuning constants, stateBits flags and pump phase ids for the TLS GraphStage
+ private[pekko] final val TransportFlushTimer = "TlsGraphStageTransportFlush" // timer key for deferred transport flushes
+ private[pekko] final val UserInputBatchTimer = "TlsGraphStageUserInputBatch" // timer key for batched small user writes
+ private[pekko] final val InitialOutputFlushTimer =
"TlsGraphStageInitialOutputFlush" // timer key for output parked during the first pull-triggered pump
+ private[pekko] final val MaxInlineEngineProgressAttempts = 1 << 10 // cap on consecutive NEED_WRAP re-entries in doUnwrap
+ private[pekko] final val MaxApplicationRecordsPerEngineCall = 4
+ private[pekko] final val MaxBatchedApplicationBytes = 1024 // SendBytes/wraps up to this size may be batched
+ private[pekko] final val DeferredTransportFlushLimit = 16 // consecutive deferred flushes before one is forced
+ private[pekko] final val InputBufferInitialSize = 16
+ private[pekko] final val InputBufferMaxSize = 64
+ private[pekko] final val UserInputQueueSize = InputBufferMaxSize // must stay a power of two: UserInputQueueMask is used as a modulo
+ private[pekko] final val UserInputQueueMask = UserInputQueueSize - 1
+ private[pekko] val StreamTlsAttributes: Attributes =
+ Attributes.asyncBoundary and
Attributes.inputBuffer(InputBufferInitialSize, InputBufferMaxSize) // applied by TLS.apply for the graph-stage engine
+
+ private[pekko] final val PlainDataAllowedFlag = 1 << 0 // bit flags packed into the stage logic's stateBits
+ private[pekko] final val StageClosedFlag = 1 << 1
+ private[pekko] final val InitialPumpEnabledFlag = 1 << 2
+ private[pekko] final val TransportFlushPendingFlag = 1 << 3
+ private[pekko] final val UserInputHasPendingFlag = 1 << 4
+ private[pekko] final val UserInputFinishedFlag = 1 << 5
+ private[pekko] final val UserInputCancelledFlag = 1 << 6
+ private[pekko] final val TransportInputHasPendingFlag = 1 << 7
+ private[pekko] final val TransportInputFinishedFlag = 1 << 8
+ private[pekko] final val TransportInputCancelledFlag = 1 << 9
+ private[pekko] final val UserOutputCancelledFlag = 1 << 10
+ private[pekko] final val UserOutputCompletedFlag = 1 << 11
+ private[pekko] final val UserOutputErroredFlag = 1 << 12
+ private[pekko] final val TransportOutputCancelledFlag = 1 << 13
+ private[pekko] final val TransportOutputCompletedFlag = 1 << 14
+ private[pekko] final val TransportOutputErroredFlag = 1 << 15
+ private[pekko] final val StageCompletionPendingFlag = 1 << 16
+ private[pekko] final val UserInputBatchPendingFlag = 1 << 17
+ private[pekko] final val InitialPumpDeferringOutputFlag = 1 << 18
+
+ private[pekko] final val BidirectionalPhase = 1 // pump phase identifiers passed to nextPhase
+ private[pekko] final val FlushingOutboundPhase = 2
+ private[pekko] final val AwaitingClosePhase = 3
+ private[pekko] final val OutboundClosedPhase = 4
+ private[pekko] final val InboundClosedPhase = 5
+ private[pekko] final val CompletedPhase = 6
+}
diff --git a/stream/src/main/scala/org/apache/pekko/stream/scaladsl/TLS.scala
b/stream/src/main/scala/org/apache/pekko/stream/scaladsl/TLS.scala
index 88c2585d3d..57e74f1f26 100644
--- a/stream/src/main/scala/org/apache/pekko/stream/scaladsl/TLS.scala
+++ b/stream/src/main/scala/org/apache/pekko/stream/scaladsl/TLS.scala
@@ -17,11 +17,13 @@ import javax.net.ssl.{ SSLContext, SSLEngine, SSLSession }
import scala.util.{ Success, Try }
+import com.typesafe.config.{ ConfigException, ConfigFactory }
+
import org.apache.pekko
import pekko.NotUsed
import pekko.stream._
import pekko.stream.TLSProtocol._
-import pekko.stream.impl.io.TlsModule
+import pekko.stream.impl.io.{ TlsGraphStage, TlsModule }
import pekko.util.ByteString
/**
@@ -61,6 +63,40 @@ import pekko.util.ByteString
*/
object TLS {
+ /**
+ * INTERNAL API.
+ *
+ * Stream TLS engine selection. The engine is chosen once, when the [[TLS]] object is
+ * initialized, from the `pekko.stream.materializer.tls.engine` setting. A non-blank JVM system
+ * property with that key takes precedence over the loaded configuration. Any value other than
+ * `legacy-actor` or `graph-stage` makes initialization fail with an IllegalArgumentException.
+ */
+ private sealed trait StreamTlsEngine // the engine implementations TLS.apply can build
+ private case object LegacyActorEngine extends StreamTlsEngine
+ private case object GraphStageEngine extends StreamTlsEngine
+
+ private[scaladsl] final val LegacyActorEngineName = "legacy-actor" // accepted setting values; also used as stable-identifier patterns
+ private[scaladsl] final val GraphStageEngineName = "graph-stage"
+ private val TlsEngineKey = "pekko.stream.materializer.tls.engine" // config path, also read as a JVM system property
+
+ /**
+  * Resolves the engine name. A non-blank `systemProperty` (trimmed) wins. Otherwise the trimmed
+  * `configValue` is used; it is only evaluated in that case.
+  */
+ private[scaladsl] def configuredEngineName(systemProperty: Option[String], configValue: => String): String =
+   systemProperty.map(_.trim) match {
+     case Some(name) if name.nonEmpty => name
+     case _                           => configValue.trim
+   }
+
+ /**
+  * Reads the engine name from the default configuration. A missing or blank value yields
+  * `legacy-actor`, the default engine.
+  *
+  * ConfigFactory.load() also applies JVM system properties as overrides. A blank
+  * `-Dpekko.stream.materializer.tls.engine=` is supposed to be ignored, but it would otherwise
+  * surface here as "" and make TLS initialization fail with "Unsupported TLS engine []".
+  */
+ private def configuredEngineNameFromConfig: String =
+   try {
+     val configured = ConfigFactory.load().getString(TlsEngineKey).trim
+     if (configured.isEmpty) LegacyActorEngineName else configured
+   } catch { case _: ConfigException.Missing => LegacyActorEngineName }
+
+ /** Engine chosen once at object initialization; an unsupported name fails initialization. */
+ private val selectedEngine: StreamTlsEngine = {
+   val fromSystemProperty = Option(System.getProperty(TlsEngineKey))
+   configuredEngineName(fromSystemProperty, configuredEngineNameFromConfig) match {
+     case LegacyActorEngineName => LegacyActorEngine
+     case GraphStageEngineName  => GraphStageEngine
+     case unsupported =>
+       throw new IllegalArgumentException(
+         s"Unsupported TLS engine [$unsupported]. Expected one of [$LegacyActorEngineName, $GraphStageEngineName] " +
+         s"for [$TlsEngineKey].")
+   }
+ }
+
/**
* Create a StreamTls [[pekko.stream.scaladsl.BidiFlow]].
*
@@ -76,8 +112,16 @@ object TLS {
createSSLEngine: () => SSLEngine,
verifySession: SSLSession => Try[Unit],
closing: TLSClosing): scaladsl.BidiFlow[SslTlsOutbound, ByteString,
ByteString, SslTlsInbound, NotUsed] =
- scaladsl.BidiFlow.fromGraph(
- TlsModule(Attributes.none, () => createSSLEngine(), session =>
verifySession(session), closing))
+ selectedEngine match {
+ case LegacyActorEngine =>
+ scaladsl.BidiFlow.fromGraph(
+ TlsModule(Attributes.none, () => createSSLEngine(), session =>
verifySession(session), closing))
+
+ case GraphStageEngine =>
+ scaladsl.BidiFlow
+ .fromGraph(new TlsGraphStage(createSSLEngine, verifySession,
closing))
+ .withAttributes(TlsGraphStage.StreamTlsAttributes)
+ }
/**
* Create a StreamTls [[pekko.stream.scaladsl.BidiFlow]].
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]