This is an automated email from the ASF dual-hosted git repository.
He-Pin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/pekko-grpc.git
The following commit(s) were added to refs/heads/main by this push:
new 0dabeb39 Optimize unary gRPC handler fast paths (#686)
0dabeb39 is described below
commit 0dabeb39291df116f814a8d5d623faa4d16440f9
Author: He-Pin(kerr) <[email protected]>
AuthorDate: Sat May 16 17:01:56 2026 +0800
Optimize unary gRPC handler fast paths (#686)
* perf: optimize unary gRPC fast paths
Motivation:
The grpc_bench Scala Pekko server uses the Scala generated unary handler
path. For strict unary requests and already-completed service futures, the old
generated handler still went through Future flatMap/map/recoverWith scheduling
and the apply path went through partial/orElse dispatch.
Modification:
Add a Scala unary handler helper (`GrpcMarshalling.handleUnary`) that continues
synchronously when the strict-entity unmarshal and the service future are already
completed, while preserving the asynchronous path for incomplete futures. Generate single-service Scala
handlers through the direct handler path and use the unary helper for unary
methods. Add a Scala unary handler JMH benchmark comparing the generated
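handler with the old-style Future chain. Keep the Java DSL strict unmarshal
parity benchmark, and allow JMH to run on JDKs where `-XX:-UseBiasedLocking` is
no longer a recognized VM option (by adding `-XX:+IgnoreUnrecognizedVMOptions`).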
Result:
JDK 25 JMH with gc profiler: generatedUnaryStrictRequestProcessing 243,468
ops/s and 2,088 B/op; oldStyleUnaryStrictRequestProcessing 83,123 ops/s and
2,625 B/op. The generated strict unary path is 2.9x faster with about 20
percent lower allocation in this focused benchmark.
* perf: reduce gRPC protobuf frame allocations
* refactor: inline Scala unary response handling
* perf: reduce unary success-path allocations
---
.../org/apache/pekko/grpc/CommonBenchmark.scala | 1 +
.../pekko/grpc/GrpcMarshallingBenchmark.scala | 50 ++++++-
.../pekko/grpc/JavaUnaryHandlerBenchmark.scala | 162 +++++++++++++++++++++
.../pekko/grpc/ScalaUnaryHandlerBenchmark.scala | 155 ++++++++++++++++++++
build.sbt | 1 +
.../twirl/templates/JavaServer/Handler.scala.txt | 31 +++-
.../twirl/templates/ScalaServer/Handler.scala.txt | 72 +++++++--
.../org/apache/pekko/grpc/ProtobufSerializer.scala | 4 +
.../pekko/grpc/internal/AbstractGrpcProtocol.scala | 36 +++--
.../pekko/grpc/internal/GrpcEntityHelpers.scala | 18 ++-
.../pekko/grpc/internal/GrpcResponseHelpers.scala | 49 ++++++-
.../grpc/javadsl/GoogleProtobufSerializer.scala | 17 ++-
.../pekko/grpc/javadsl/GrpcMarshalling.scala | 58 +++++++-
.../pekko/grpc/scaladsl/GrpcMarshalling.scala | 82 ++++++++++-
.../grpc/scaladsl/ScalapbProtobufSerializer.scala | 17 ++-
.../pekko/grpc/javadsl/GrpcMarshallingSpec.scala | 86 +++++++++++
16 files changed, 792 insertions(+), 47 deletions(-)
diff --git
a/benchmarks/src/main/scala/org/apache/pekko/grpc/CommonBenchmark.scala
b/benchmarks/src/main/scala/org/apache/pekko/grpc/CommonBenchmark.scala
index 4babe8e0..fc829141 100644
--- a/benchmarks/src/main/scala/org/apache/pekko/grpc/CommonBenchmark.scala
+++ b/benchmarks/src/main/scala/org/apache/pekko/grpc/CommonBenchmark.scala
@@ -38,6 +38,7 @@ import org.openjdk.jmh.annotations.Warmup
"-XX:InitialCodeCacheSize=512m",
"-XX:ReservedCodeCacheSize=512m",
"-XX:+UseParallelGC",
+ "-XX:+IgnoreUnrecognizedVMOptions",
"-XX:-UseBiasedLocking",
"-XX:+AlwaysPreTouch"))
@BenchmarkMode(Array(Mode.Throughput))
diff --git
a/benchmarks/src/main/scala/org/apache/pekko/grpc/GrpcMarshallingBenchmark.scala
b/benchmarks/src/main/scala/org/apache/pekko/grpc/GrpcMarshallingBenchmark.scala
index c81c9c28..4304cc13 100644
---
a/benchmarks/src/main/scala/org/apache/pekko/grpc/GrpcMarshallingBenchmark.scala
+++
b/benchmarks/src/main/scala/org/apache/pekko/grpc/GrpcMarshallingBenchmark.scala
@@ -13,17 +13,22 @@
package org.apache.pekko.grpc
+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
+
+import com.google.protobuf.{ Any => JavaAny, ByteString => JavaByteString }
import org.apache.pekko
import pekko.actor.ActorSystem
-import pekko.grpc.internal.{ GrpcProtocolNative, Identity }
+import pekko.grpc.internal.{ AbstractGrpcProtocol, GrpcProtocolNative,
Identity }
import pekko.grpc.scaladsl.{ GrpcMarshalling, ScalapbProtobufSerializer }
-import pekko.http.scaladsl.model.HttpResponse
+import pekko.http.scaladsl.model.{ HttpEntity, HttpResponse }
+import pekko.stream.SystemMaterializer
import pekko.stream.scaladsl.Source
import io.grpc.reflection.v1.reflection._
import org.openjdk.jmh.annotations._
// Microbenchmarks for GrpcMarshalling.
-// Does not actually benchmarks the actual marshalling because we dont consume
the HttpResponse
+// Does not actually benchmark response marshalling because we don't consume
the HttpResponse.
class GrpcMarshallingBenchmark extends CommonBenchmark {
implicit val system: ActorSystem = ActorSystem("bench")
implicit val writer: GrpcProtocol.GrpcProtocolWriter =
GrpcProtocolNative.newWriter(Identity)
@@ -31,14 +36,49 @@ class GrpcMarshallingBenchmark extends CommonBenchmark {
implicit val serializer: ScalapbProtobufSerializer[ServerReflectionRequest] =
ServerReflection.Serializers.ServerReflectionRequestSerializer
+ val request = ServerReflectionRequest()
+ val entity: HttpEntity.Strict =
+ HttpEntity.Strict(
+ GrpcProtocolNative.contentType,
+ AbstractGrpcProtocol.encodeFrameData(serializer.serialize(request),
isCompressed = false, isTrailer = false))
+
+ val javaSerializer = new
pekko.grpc.javadsl.GoogleProtobufSerializer(JavaAny.parser())
+ val javaRequest: JavaAny =
+
JavaAny.newBuilder().setTypeUrl("benchmark").setValue(JavaByteString.copyFromUtf8("payload")).build()
+ val javaEntity: pekko.http.javadsl.model.HttpEntity =
+ HttpEntity.Strict(
+ GrpcProtocolNative.contentType,
+
AbstractGrpcProtocol.encodeFrameData(javaSerializer.serialize(javaRequest),
isCompressed = false,
+ isTrailer = false))
+
+ val mat = SystemMaterializer(system).materializer
+
@Benchmark
def marshall(): HttpResponse = {
- GrpcMarshalling.marshal(ServerReflectionRequest())
+ GrpcMarshalling.marshal(request)
}
@Benchmark
def marshallStream(): HttpResponse = {
-
GrpcMarshalling.marshalStream(Source.repeat(ServerReflectionRequest()).take(10000))
+ GrpcMarshalling.marshalStream(Source.repeat(request).take(10000))
+ }
+
+ @Benchmark
+ def unmarshallStrict(): ServerReflectionRequest = {
+ Await.result(GrpcMarshalling.unmarshal(entity), Duration.Inf)
+ }
+
+ @Benchmark
+ def unmarshallJavaStrict(): JavaAny = {
+ pekko.grpc.javadsl.GrpcMarshalling.unmarshal(javaEntity, javaSerializer,
mat, reader).toCompletableFuture.get()
+ }
+
+ @Benchmark
+ def unmarshallJavaStrictStreamed(): JavaAny = {
+ pekko.grpc.javadsl.GrpcMarshalling
+ .unmarshal(javaEntity.getDataBytes, javaSerializer, mat, reader)
+ .toCompletableFuture
+ .get()
}
@TearDown
diff --git
a/benchmarks/src/main/scala/org/apache/pekko/grpc/JavaUnaryHandlerBenchmark.scala
b/benchmarks/src/main/scala/org/apache/pekko/grpc/JavaUnaryHandlerBenchmark.scala
new file mode 100644
index 00000000..d94f93f3
--- /dev/null
+++
b/benchmarks/src/main/scala/org/apache/pekko/grpc/JavaUnaryHandlerBenchmark.scala
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.pekko.grpc
+
+import java.util.concurrent.CompletableFuture
+import java.util.concurrent.CompletionStage
+
+import scala.annotation.nowarn
+import scala.collection.immutable
+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
+
+import com.google.protobuf.{ Any => JavaAny }
+import com.google.protobuf.{ ByteString => JavaByteString }
+import org.openjdk.jmh.annotations.Benchmark
+import org.openjdk.jmh.annotations.TearDown
+import org.openjdk.jmh.infra.Blackhole
+
+import org.apache.pekko
+import pekko.actor.ActorSystem
+import pekko.grpc.internal.AbstractGrpcProtocol
+import pekko.grpc.internal.Codecs
+import pekko.grpc.internal.GrpcProtocolNative
+import pekko.grpc.internal.Identity
+import pekko.grpc.javadsl.{ GoogleProtobufSerializer, GrpcExceptionHandler =>
JGrpcExceptionHandler }
+import pekko.grpc.javadsl.{ GrpcMarshalling => JGrpcMarshalling }
+import pekko.grpc.scaladsl.headers.`Message-Accept-Encoding`
+import pekko.grpc.scaladsl.headers.`Message-Encoding`
+import pekko.http.javadsl.model.{ HttpRequest => JHttpRequest }
+import pekko.http.javadsl.model.{ HttpResponse => JHttpResponse }
+import pekko.http.scaladsl.model.HttpEntity
+import pekko.http.scaladsl.model.HttpMethods
+import pekko.http.scaladsl.model.HttpRequest
+import pekko.http.scaladsl.model.HttpResponse
+import pekko.http.scaladsl.model.TransferEncodings
+import pekko.http.scaladsl.model.Uri
+import pekko.stream.Materializer
+import pekko.stream.SystemMaterializer
+import pekko.stream.scaladsl.Sink
+
+class JavaUnaryHandlerBenchmark extends CommonBenchmark {
+ private implicit val system: ActorSystem = ActorSystem("bench")
+ private val mat: Materializer = SystemMaterializer(system).materializer
+
+ private val writer = GrpcProtocolNative.newWriter(Identity)
+ private val serializer = new GoogleProtobufSerializer(JavaAny.parser())
+ private val requestMessage: JavaAny =
+
JavaAny.newBuilder().setTypeUrl("benchmark").setValue(JavaByteString.copyFromUtf8("request")).build()
+ private val responseMessage: JavaAny =
+
JavaAny.newBuilder().setTypeUrl("benchmark").setValue(JavaByteString.copyFromUtf8("response")).build()
+ private val eHandler = JGrpcExceptionHandler.defaultMapper
+
+ private val request: JHttpRequest = {
+ val data =
+ AbstractGrpcProtocol.encodeFrameData(
+ serializer.serialize(requestMessage),
+ isCompressed = false,
+ isTrailer = false)
+
+ HttpRequest(
+ method = HttpMethods.POST,
+ uri = Uri("https://unused.example/benchmark/Unary"),
+ headers = immutable.Seq(
+ `Message-Encoding`(writer.messageEncoding.name),
+
`Message-Accept-Encoding`(Codecs.supportedCodecs.map(_.name).mkString(",")),
+ pekko.http.scaladsl.model.headers.TE(TransferEncodings.trailers)),
+ entity = HttpEntity.Strict(writer.contentType, data))
+ }
+
+ private val unsupportedMediaType: CompletionStage[JHttpResponse] =
+ CompletableFuture.completedFuture(
+ pekko.http.javadsl.model.HttpResponse
+ .create()
+
.withStatus(pekko.http.javadsl.model.StatusCodes.UNSUPPORTED_MEDIA_TYPE))
+
+ private val generatedStyleHandler: JHttpRequest =>
CompletionStage[JHttpResponse] =
+ request =>
+ JGrpcMarshalling
+ .negotiated[JHttpResponse](
+ request,
+ (reader, writer) => {
+ request.entity() match {
+ case strict: pekko.http.scaladsl.model.HttpEntity.Strict =>
+ try {
+ JGrpcMarshalling.handleUnaryResponse(
+
invoke(serializer.deserialize(reader.decodeSingleFrame(strict.data))),
+ serializer,
+ writer,
+ system,
+ eHandler)
+ } catch {
+ case error: Throwable =>
JGrpcMarshalling.handleUnaryFailure(error, writer, system, eHandler)
+ }
+ case _ =>
+ JGrpcMarshalling.handleUnaryResponse(
+ JGrpcMarshalling
+ .unmarshal(request.entity(), serializer, mat, reader)
+ .thenCompose(in => invoke(in)),
+ serializer,
+ writer,
+ system,
+ eHandler)
+ }
+ })
+ .orElseGet(() => unsupportedMediaType)
+
+ private val oldStyleHandler: JHttpRequest => CompletionStage[JHttpResponse] =
+ request =>
+ JGrpcMarshalling
+ .negotiated[JHttpResponse](
+ request,
+ (reader, writer) =>
+ JGrpcMarshalling
+ .unmarshal(request.entity(), serializer, mat, reader)
+ .thenCompose(in => invoke(in))
+ .thenApply(out => JGrpcMarshalling.marshal(out, serializer,
writer, system, eHandler))
+ .exceptionally(error => JGrpcExceptionHandler.standard(error,
eHandler, writer, system)))
+ .orElseGet(() => unsupportedMediaType)
+
+ @Benchmark
+ def generatedStyleUnaryStrictRequestProcessing(blackhole: Blackhole): Unit =
+ consumeResponse(generatedStyleHandler(request).toCompletableFuture.get(),
blackhole)
+
+ @Benchmark
+ def oldStyleUnaryStrictRequestProcessing(blackhole: Blackhole): Unit =
+ consumeResponse(oldStyleHandler(request).toCompletableFuture.get(),
blackhole)
+
+ private def invoke(@nowarn("msg=never used") in: JavaAny):
CompletionStage[JavaAny] =
+ CompletableFuture.completedFuture(responseMessage)
+
+ private def consumeResponse(response: JHttpResponse, blackhole: Blackhole):
Unit = {
+ val scalaResponse = response.asInstanceOf[HttpResponse]
+ blackhole.consume(scalaResponse.status)
+ scalaResponse.entity match {
+ case HttpEntity.Strict(_, data) =>
+ blackhole.consume(data)
+ case _ =>
+ Await.result(scalaResponse.entity.dataBytes.runWith(Sink.ignore)(mat),
Duration.Inf)
+ }
+ }
+
+ @TearDown
+ def tearDown(): Unit =
+ system.terminate()
+}
diff --git
a/benchmarks/src/main/scala/org/apache/pekko/grpc/ScalaUnaryHandlerBenchmark.scala
b/benchmarks/src/main/scala/org/apache/pekko/grpc/ScalaUnaryHandlerBenchmark.scala
new file mode 100644
index 00000000..1384ecce
--- /dev/null
+++
b/benchmarks/src/main/scala/org/apache/pekko/grpc/ScalaUnaryHandlerBenchmark.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.pekko.grpc
+
+import scala.collection.immutable
+import scala.concurrent.Await
+import scala.concurrent.ExecutionContext
+import scala.concurrent.Future
+import scala.concurrent.duration.Duration
+
+import org.openjdk.jmh.annotations.Benchmark
+import org.openjdk.jmh.annotations.TearDown
+import org.openjdk.jmh.infra.Blackhole
+
+import org.apache.pekko
+import pekko.NotUsed
+import pekko.actor.ActorSystem
+import pekko.grpc.internal.AbstractGrpcProtocol
+import pekko.grpc.internal.Codecs
+import pekko.grpc.internal.GrpcProtocolNative
+import pekko.grpc.internal.Identity
+import pekko.grpc.internal.TelemetryExtension
+import pekko.grpc.scaladsl.GrpcExceptionHandler
+import pekko.grpc.scaladsl.GrpcMarshalling
+import pekko.grpc.scaladsl.headers.`Message-Accept-Encoding`
+import pekko.grpc.scaladsl.headers.`Message-Encoding`
+import pekko.http.scaladsl.model.HttpEntity
+import pekko.http.scaladsl.model.HttpMethods
+import pekko.http.scaladsl.model.HttpRequest
+import pekko.http.scaladsl.model.HttpResponse
+import pekko.http.scaladsl.model.StatusCodes
+import pekko.http.scaladsl.model.TransferEncodings
+import pekko.http.scaladsl.model.Uri
+import pekko.stream.Materializer
+import pekko.stream.SystemMaterializer
+import pekko.stream.scaladsl.Sink
+import pekko.stream.scaladsl.Source
+
+import example.myapp.helloworld.grpc.GreeterService
+import example.myapp.helloworld.grpc.GreeterServiceHandler
+import example.myapp.helloworld.grpc.HelloReply
+import example.myapp.helloworld.grpc.HelloRequest
+
+class ScalaUnaryHandlerBenchmark extends CommonBenchmark {
+ implicit val system: ActorSystem = ActorSystem("bench")
+ private implicit val mat: Materializer =
SystemMaterializer(system).materializer
+ private implicit val ec: ExecutionContext = mat.executionContext
+
+ private val writer = GrpcProtocolNative.newWriter(Identity)
+ private val requestMessage = HelloRequest("Alice")
+ private val responseMessage = HelloReply("Hello, Alice")
+ private val implementation = new BenchmarkGreeterService(responseMessage)
+
+ private val request: HttpRequest = {
+ val data =
+ AbstractGrpcProtocol.encodeFrameData(
+
GreeterService.Serializers.HelloRequestSerializer.serialize(requestMessage),
+ isCompressed = false,
+ isTrailer = false)
+
+ HttpRequest(
+ method = HttpMethods.POST,
+ uri = Uri("https://unused.example/" + GreeterService.name + "/SayHello"),
+ headers = immutable.Seq(
+ `Message-Encoding`(writer.messageEncoding.name),
+
`Message-Accept-Encoding`(Codecs.supportedCodecs.map(_.name).mkString(",")),
+ pekko.http.scaladsl.model.headers.TE(TransferEncodings.trailers)),
+ entity = HttpEntity.Strict(writer.contentType, data))
+ }
+
+ private val generatedHandler: HttpRequest => Future[HttpResponse] =
+ GreeterServiceHandler(implementation)
+
+ private val oldStyleHandler: HttpRequest => Future[HttpResponse] = {
+ val notFound = Future.successful(HttpResponse(StatusCodes.NotFound))
+ val unsupportedMediaType =
Future.successful(HttpResponse(StatusCodes.UnsupportedMediaType))
+ val spi = TelemetryExtension(system).spi
+ val eHandler = GrpcExceptionHandler.defaultMapper _
+
+ import GreeterService.Serializers._
+
+ request =>
+ request.uri.path match {
+ case Uri.Path.Slash(
+ Uri.Path.Segment(
+ GreeterService.name,
+ Uri.Path.Slash(Uri.Path.Segment("SayHello", Uri.Path.Empty))))
=>
+ val requestWithTelemetry = spi.onRequest(GreeterService.name,
"SayHello", request)
+ GrpcMarshalling
+ .negotiated(requestWithTelemetry,
+ (reader, writer) =>
+ GrpcMarshalling
+
.unmarshal(requestWithTelemetry.entity)(HelloRequestSerializer, mat, reader)
+ .flatMap(implementation.sayHello)
+ .map(e => GrpcMarshalling.marshal(e)(HelloReplySerializer,
writer, system))
+
.recoverWith(GrpcExceptionHandler.from(eHandler(system))(system, writer)))
+ .getOrElse(unsupportedMediaType)
+ case _ =>
+ notFound
+ }
+ }
+
+ @Benchmark
+ def generatedUnaryStrictRequestProcessing(blackhole: Blackhole): Unit =
+ consumeResponse(Await.result(generatedHandler(request), Duration.Inf),
blackhole)
+
+ @Benchmark
+ def oldStyleUnaryStrictRequestProcessing(blackhole: Blackhole): Unit =
+ consumeResponse(Await.result(oldStyleHandler(request), Duration.Inf),
blackhole)
+
+ private def consumeResponse(response: HttpResponse, blackhole: Blackhole):
Unit = {
+ blackhole.consume(response.status)
+ response.entity match {
+ case HttpEntity.Strict(_, data) =>
+ blackhole.consume(data)
+ case _ =>
+ Await.result(response.entity.dataBytes.runWith(Sink.ignore),
Duration.Inf)
+ }
+ }
+
+ @TearDown
+ def tearDown(): Unit =
+ system.terminate()
+
+ private final class BenchmarkGreeterService(response: HelloReply) extends
GreeterService {
+ override def sayHello(in: HelloRequest): Future[HelloReply] =
+ Future.successful(response)
+
+ override def itKeepsTalking(in: Source[HelloRequest, NotUsed]):
Future[HelloReply] =
+ throw new UnsupportedOperationException("itKeepsTalking")
+
+ override def itKeepsReplying(in: HelloRequest): Source[HelloReply,
NotUsed] =
+ throw new UnsupportedOperationException("itKeepsReplying")
+
+ override def streamHellos(in: Source[HelloRequest, NotUsed]):
Source[HelloReply, NotUsed] =
+ throw new UnsupportedOperationException("streamHellos")
+ }
+}
diff --git a/build.sbt b/build.sbt
index df92c294..d89bbbdd 100644
--- a/build.sbt
+++ b/build.sbt
@@ -232,6 +232,7 @@ lazy val interopTests = Project(id = "interop-tests", base
= file("interop-tests
lazy val benchmarks = Project(id = "benchmarks", base = file("benchmarks"))
.dependsOn(runtime)
+ .dependsOn(pluginTesterScala)
.enablePlugins(JmhPlugin)
.disablePlugins(MimaPlugin)
.settings(
diff --git a/codegen/src/main/twirl/templates/JavaServer/Handler.scala.txt
b/codegen/src/main/twirl/templates/JavaServer/Handler.scala.txt
index d9fe956b..d8796868 100644
--- a/codegen/src/main/twirl/templates/JavaServer/Handler.scala.txt
+++ b/codegen/src/main/twirl/templates/JavaServer/Handler.scala.txt
@@ -164,15 +164,42 @@ public class @{serviceName}HandlerFactory {
private static
CompletionStage<org.apache.pekko.http.javadsl.model.HttpResponse>
handle(org.apache.pekko.http.javadsl.model.HttpRequest request, String method,
@serviceName implementation, Materializer mat,
org.apache.pekko.japi.function.Function<ActorSystem,
org.apache.pekko.japi.function.Function<Throwable, Trailers>> eHandler,
ClassicActorSystemProvider system) {
return GrpcMarshalling.negotiated(request, (reader, writer) -> {
- final
CompletionStage<org.apache.pekko.http.javadsl.model.HttpResponse> response;
+ CompletionStage<org.apache.pekko.http.javadsl.model.HttpResponse>
response;
@{if(powerApis) { "Metadata metadata =
MetadataBuilder.fromHeaders(request.getHeaders());" } else { "" }}
switch(method) {
@for(method <- service.methods) {
- case "@method.grpcName":
+ case "@method.grpcName": {
+ @if(method.methodType == org.apache.pekko.grpc.gen.Unary) {
+ final org.apache.pekko.http.javadsl.model.HttpEntity entity =
request.entity();
+ if (entity instanceof
org.apache.pekko.http.scaladsl.model.HttpEntity.Strict) {
+ final org.apache.pekko.http.scaladsl.model.HttpEntity.Strict
strictEntity =
+ (org.apache.pekko.http.scaladsl.model.HttpEntity.Strict)
entity;
+ try {
+ return GrpcMarshalling.handleUnaryResponse(
+
implementation.@{method.name}(@{method.deserializer.name}.deserialize(reader.decodeSingleFrame().apply(strictEntity.data()))@{if(powerApis)
{ ", metadata" } else { "" }}),
+ @method.serializer.name,
+ writer,
+ system,
+ eHandler);
+ } catch (Throwable e) {
+ return GrpcMarshalling.handleUnaryFailure(e, writer, system,
eHandler);
+ }
+ } else {
+ return GrpcMarshalling.handleUnaryResponse(
+ GrpcMarshalling.unmarshal(request.entity(),
@method.deserializer.name, mat, reader)
+ .thenCompose(e ->
implementation.@{method.name}(e@{if(powerApis) { ", metadata" } else { "" }})),
+ @method.serializer.name,
+ writer,
+ system,
+ eHandler);
+ }
+ } else {
response = @{method.unmarshal}(request.entity(),
@method.deserializer.name, mat, reader)
.@{if(method.outputStreaming) { "thenApply" } else {
"thenCompose" }}(e -> implementation.@{method.name}(e@{if(powerApis) { ",
metadata" } else { "" }}))
.thenApply(e -> @{method.marshal}(e, @method.serializer.name,
writer, system, eHandler));
break;
+ }
+ }
}
default:
CompletableFuture<org.apache.pekko.http.javadsl.model.HttpResponse> result =
new CompletableFuture<>();
diff --git a/codegen/src/main/twirl/templates/ScalaServer/Handler.scala.txt
b/codegen/src/main/twirl/templates/ScalaServer/Handler.scala.txt
index 42c46584..e8c6af53 100644
--- a/codegen/src/main/twirl/templates/ScalaServer/Handler.scala.txt
+++ b/codegen/src/main/twirl/templates/ScalaServer/Handler.scala.txt
@@ -55,7 +55,7 @@ object @{serviceName}Handler {
* several services.
*/
def apply(implementation: @serviceName)(implicit system:
ClassicActorSystemProvider): model.HttpRequest =>
scala.concurrent.Future[model.HttpResponse] =
- partial(implementation).orElse { case _ => notFound }
+ handler(implementation, @{service.name}.name,
GrpcExceptionHandler.defaultMapper)
/**
* Creates a `HttpRequest` to `HttpResponse` handler that can be used in
for example `Http().bindAndHandleAsync`
@@ -65,7 +65,7 @@ object @{serviceName}Handler {
* several services.
*/
def apply(implementation: @serviceName, eHandler: ActorSystem =>
PartialFunction[Throwable, Trailers])(implicit system:
ClassicActorSystemProvider): model.HttpRequest =>
scala.concurrent.Future[model.HttpResponse] =
- partial(implementation, @{service.name}.name, eHandler).orElse { case _
=> notFound }
+ handler(implementation, @{service.name}.name, eHandler)
/**
* Creates a `HttpRequest` to `HttpResponse` handler that can be used in
for example `Http().bindAndHandleAsync`
@@ -77,7 +77,7 @@ object @{serviceName}Handler {
* Registering a gRPC service under a custom prefix is not widely
supported and strongly discouraged by the specification.
*/
def apply(implementation: @serviceName, prefix: String)(implicit system:
ClassicActorSystemProvider): model.HttpRequest =>
scala.concurrent.Future[model.HttpResponse] =
- partial(implementation, prefix).orElse { case _ => notFound }
+ handler(implementation, prefix, GrpcExceptionHandler.defaultMapper)
/**
* Creates a `HttpRequest` to `HttpResponse` handler that can be used in
for example `Http().bindAndHandleAsync`
@@ -89,7 +89,7 @@ object @{serviceName}Handler {
* Registering a gRPC service under a custom prefix is not widely
supported and strongly discouraged by the specification.
*/
def apply(implementation: @serviceName, prefix: String, eHandler:
ActorSystem => PartialFunction[Throwable, Trailers])(implicit system:
ClassicActorSystemProvider): model.HttpRequest =>
scala.concurrent.Future[model.HttpResponse] =
- partial(implementation, prefix, eHandler).orElse { case _ => notFound }
+ handler(implementation, prefix, eHandler)
@if(serviceName != "ServerReflection") {
@@ -107,6 +107,48 @@ object @{serviceName}Handler {
pekko.grpc.scaladsl.ServerReflection.partial(List(@{service.name})))
}
+ private def methodName(request: model.HttpRequest, prefix: String): String
=
+ request.uri.path match {
+ case model.Uri.Path.Slash(model.Uri.Path.Segment(`prefix`,
model.Uri.Path.Slash(model.Uri.Path.Segment(method, model.Uri.Path.Empty)))) =>
+ method
+ case _ =>
+ null
+ }
+
+ private def handler(implementation: @serviceName, prefix: String,
eHandler: ActorSystem => PartialFunction[Throwable, Trailers])(implicit system:
ClassicActorSystemProvider): model.HttpRequest =>
scala.concurrent.Future[model.HttpResponse] = {
+ implicit val mat: Materializer = SystemMaterializer(system).materializer
+ implicit val ec: ExecutionContext = mat.executionContext
+ val spi = TelemetryExtension(system).spi
+
+ import
@{service.name}.Serializers.@{service.scalaCompatConstants.WildcardImport}
+
+ def handle(request: model.HttpRequest, method: String):
scala.concurrent.Future[model.HttpResponse] =
+ GrpcMarshalling.negotiated(request, (reader, writer) =>
+ method match {
+ @for(method <- service.methods) {
+ case "@method.grpcName" =>
+ @{if(powerApis) { "val metadata =
MetadataBuilder.fromHeaders(request.headers)" } else { "" }}
+ @if(method.methodType == org.apache.pekko.grpc.gen.Unary) {
+ GrpcMarshalling.handleUnary(request.entity, (e:
@method.parameterType) => implementation.@{method.nameSafe}(e@{if(powerApis) {
", metadata" } else { "" }}), eHandler)(
+
@{service.scalaCompatConstants.ImplicitUsing}@method.deserializer.name,
@method.serializer.name, mat, reader, writer, system, ec)
+ } else {
+
@{method.unmarshal}(request.entity)(@{service.scalaCompatConstants.ImplicitUsing}@method.deserializer.name,
mat, reader)
+ .@{if(method.outputStreaming) { "map" } else { "flatMap"
}}(implementation.@{method.nameSafe}(_@{if(powerApis) { ", metadata" } else {
"" }}))
+ .map(e => @{method.marshal}(e,
eHandler)(@{service.scalaCompatConstants.ImplicitUsing}@method.serializer.name,
writer, system))
+
.recoverWith(GrpcExceptionHandler.from(eHandler(system.classicSystem))(system,
writer))
+ }
+ }
+ case m =>
GrpcExceptionHandler.from(eHandler(system.classicSystem))(system, writer)(new
NotImplementedError(s"Not implemented: $m"))
+ }
+ ).getOrElse(unsupportedMediaType)
+
+ request => {
+ val method = methodName(request, prefix)
+ if (method eq null) notFound
+ else handle(spi.onRequest(prefix, method, request), method)
+ }
+ }
+
/**
* Creates a partial `HttpRequest` to `HttpResponse` handler that can be
combined with handlers of other
* services with
`org.apache.pekko.grpc.scaladsl.ServiceHandler.concatOrNotFound` and then used
in for example
@@ -125,24 +167,28 @@ object @{serviceName}Handler {
def handle(request: model.HttpRequest, method: String):
scala.concurrent.Future[model.HttpResponse] =
GrpcMarshalling.negotiated(request, (reader, writer) =>
- (method match {
+ method match {
@for(method <- service.methods) {
case "@method.grpcName" =>
@{if(powerApis) { "val metadata =
MetadataBuilder.fromHeaders(request.headers)" } else { "" }}
+ @if(method.methodType == org.apache.pekko.grpc.gen.Unary) {
+ GrpcMarshalling.handleUnary(request.entity, (e:
@method.parameterType) => implementation.@{method.nameSafe}(e@{if(powerApis) {
", metadata" } else { "" }}), eHandler)(
+
@{service.scalaCompatConstants.ImplicitUsing}@method.deserializer.name,
@method.serializer.name, mat, reader, writer, system, ec)
+ } else {
@{method.unmarshal}(request.entity)(@{service.scalaCompatConstants.ImplicitUsing}@method.deserializer.name,
mat, reader)
.@{if(method.outputStreaming) { "map" } else { "flatMap"
}}(implementation.@{method.nameSafe}(_@{if(powerApis) { ", metadata" } else {
"" }}))
.map(e => @{method.marshal}(e,
eHandler)(@{service.scalaCompatConstants.ImplicitUsing}@method.serializer.name,
writer, system))
+
.recoverWith(GrpcExceptionHandler.from(eHandler(system.classicSystem))(system,
writer))
+ }
}
- case m => scala.concurrent.Future.failed(new
NotImplementedError(s"Not implemented: $m"))
- })
-
.recoverWith(GrpcExceptionHandler.from(eHandler(system.classicSystem))(system,
writer))
+ case m =>
GrpcExceptionHandler.from(eHandler(system.classicSystem))(system, writer)(new
NotImplementedError(s"Not implemented: $m"))
+ }
).getOrElse(unsupportedMediaType)
- Function.unlift((req: model.HttpRequest) => req.uri.path match {
- case model.Uri.Path.Slash(model.Uri.Path.Segment(`prefix`,
model.Uri.Path.Slash(model.Uri.Path.Segment(method, model.Uri.Path.Empty)))) =>
- Some(handle(spi.onRequest(prefix, method, req), method))
- case _ =>
- None
+ Function.unlift((req: model.HttpRequest) => {
+ val method = methodName(req, prefix)
+ if (method eq null) None
+ else Some(handle(spi.onRequest(prefix, method, req), method))
})
}
}
diff --git
a/runtime/src/main/scala/org/apache/pekko/grpc/ProtobufSerializer.scala
b/runtime/src/main/scala/org/apache/pekko/grpc/ProtobufSerializer.scala
index 64538749..e5392960 100644
--- a/runtime/src/main/scala/org/apache/pekko/grpc/ProtobufSerializer.scala
+++ b/runtime/src/main/scala/org/apache/pekko/grpc/ProtobufSerializer.scala
@@ -24,3 +24,7 @@ trait ProtobufSerializer[T] {
def deserialize(bytes: ByteString): T
def deserialize(stream: InputStream): T =
deserialize(ByteStringUtils.fromInputStream(stream))
}
+
+private[grpc] trait ProtobufFrameSerializer[T] extends ProtobufSerializer[T] {
+ private[grpc] def serializeDataFrame(t: T): ByteString
+}
diff --git
a/runtime/src/main/scala/org/apache/pekko/grpc/internal/AbstractGrpcProtocol.scala
b/runtime/src/main/scala/org/apache/pekko/grpc/internal/AbstractGrpcProtocol.scala
index 16ef234c..8cd4653c 100644
---
a/runtime/src/main/scala/org/apache/pekko/grpc/internal/AbstractGrpcProtocol.scala
+++
b/runtime/src/main/scala/org/apache/pekko/grpc/internal/AbstractGrpcProtocol.scala
@@ -24,10 +24,9 @@ import pekko.stream.impl.io.ByteStringParser
import pekko.stream.impl.io.ByteStringParser.{ ByteReader, ParseResult,
ParseStep }
import pekko.stream.scaladsl.Flow
import pekko.stream.stage.GraphStageLogic
-import pekko.util.{ ByteString, ByteStringBuilder }
+import pekko.util.ByteString
import io.grpc.StatusException
-import java.nio.ByteOrder
import scala.collection.immutable
abstract class AbstractGrpcProtocol(subType: String) extends GrpcProtocol {
@@ -69,6 +68,24 @@ object AbstractGrpcProtocol {
def fieldType(codec: Codec) = if (codec == Identity) notCompressed else
compressed
+ private[grpc] final val FrameHeaderSize = 5
+
+ private[grpc] def writeFrameHeader(
+ frame: Array[Byte],
+ offset: Int,
+ dataLength: Int,
+ isCompressed: Boolean,
+ isTrailer: Boolean): Unit = {
+ val flags =
+ (if (isCompressed) 1 else 0) | (if (isTrailer) 0x80 else 0)
+
+ frame(offset) = flags.toByte
+ frame(offset + 1) = (dataLength >>> 24).toByte
+ frame(offset + 2) = (dataLength >>> 16).toByte
+ frame(offset + 3) = (dataLength >>> 8).toByte
+ frame(offset + 4) = dataLength.toByte
+ }
+
/**
* Adjusts the compressibility of a content type to suit a message encoding.
* @param contentType the content type for the gRPC protocol.
@@ -84,18 +101,9 @@ object AbstractGrpcProtocol {
.toContentType
def encodeFrameData(data: ByteString, isCompressed: Boolean, isTrailer:
Boolean): ByteString = {
- implicit val byteOrder = ByteOrder.BIG_ENDIAN
- val length = data.length
- val builder = new ByteStringBuilder()
- builder.sizeHint(5)
- val flags =
- (if (isCompressed) 1 else 0) | (if (isTrailer) 0x80 else 0)
-
- builder // ...
- .putByte(flags.toByte)
- .putInt(length)
- .++=(data)
- .result()
+ val header = new Array[Byte](FrameHeaderSize)
+ writeFrameHeader(header, 0, data.length, isCompressed, isTrailer)
+ ByteString.fromArrayUnsafe(header, 0, FrameHeaderSize) ++ data
}
def writer(
diff --git
a/runtime/src/main/scala/org/apache/pekko/grpc/internal/GrpcEntityHelpers.scala
b/runtime/src/main/scala/org/apache/pekko/grpc/internal/GrpcEntityHelpers.scala
index 128b53f8..d39f0364 100644
---
a/runtime/src/main/scala/org/apache/pekko/grpc/internal/GrpcEntityHelpers.scala
+++
b/runtime/src/main/scala/org/apache/pekko/grpc/internal/GrpcEntityHelpers.scala
@@ -17,10 +17,10 @@ import org.apache.pekko
import pekko.NotUsed
import pekko.actor.{ ActorSystem, ClassicActorSystemProvider }
import pekko.annotation.InternalApi
-import pekko.grpc.{ GrpcServiceException, ProtobufSerializer, Trailers }
+import pekko.grpc.{ GrpcServiceException, ProtobufFrameSerializer,
ProtobufSerializer, Trailers }
import pekko.grpc.GrpcProtocol.{ DataFrame, Frame, GrpcProtocolWriter,
TrailerFrame }
import pekko.grpc.scaladsl.{ headers, BytesEntry, Metadata, MetadataEntry,
StringEntry }
-import pekko.http.scaladsl.model.HttpEntity.ChunkStreamPart
+import pekko.http.scaladsl.model.HttpEntity.{ Chunk, ChunkStreamPart }
import pekko.http.scaladsl.model.HttpHeader
import pekko.http.scaladsl.model.headers.RawHeader
import pekko.stream.scaladsl.Source
@@ -68,6 +68,20 @@ object GrpcEntityHelpers {
private def chunks[T](e: Source[T, NotUsed], trail: Source[Frame, NotUsed])(
implicit m: ProtobufSerializer[T],
writer: GrpcProtocolWriter): Source[ChunkStreamPart, NotUsed] =
+ if ((writer.messageEncoding eq Identity) && writer.contentType ==
GrpcProtocolNative.contentType) {
+ m match {
+ case frameSerializer: ProtobufFrameSerializer[T @unchecked] =>
+ e.map { msg => Chunk(frameSerializer.serializeDataFrame(msg)):
ChunkStreamPart }
+ .via(concatCheap(trail.map(writer.encodeFrame)))
+ case _ => frameEncodedChunks(e, trail)
+ }
+ } else {
+ frameEncodedChunks(e, trail)
+ }
+
+ private def frameEncodedChunks[T](e: Source[T, NotUsed], trail:
Source[Frame, NotUsed])(
+ implicit m: ProtobufSerializer[T],
+ writer: GrpcProtocolWriter): Source[ChunkStreamPart, NotUsed] =
e.map { msg => DataFrame(m.serialize(msg))
}.via(concatCheap(trail)).via(writer.frameEncoder)
def trailer(status: Status): TrailerFrame =
diff --git
a/runtime/src/main/scala/org/apache/pekko/grpc/internal/GrpcResponseHelpers.scala
b/runtime/src/main/scala/org/apache/pekko/grpc/internal/GrpcResponseHelpers.scala
index 58fef5a6..f3440aa2 100644
---
a/runtime/src/main/scala/org/apache/pekko/grpc/internal/GrpcResponseHelpers.scala
+++
b/runtime/src/main/scala/org/apache/pekko/grpc/internal/GrpcResponseHelpers.scala
@@ -19,11 +19,21 @@ import pekko.actor.{ ActorSystem,
ClassicActorSystemProvider }
import pekko.annotation.InternalApi
import pekko.grpc.GrpcProtocol.{ GrpcProtocolWriter, TrailerFrame }
import pekko.grpc.scaladsl.{ headers, GrpcExceptionHandler }
-import pekko.grpc.{ ProtobufSerializer, Trailers }
+import pekko.grpc.{ ProtobufFrameSerializer, ProtobufSerializer, Trailers }
import pekko.http.scaladsl.model.HttpEntity.ChunkStreamPart
-import pekko.http.scaladsl.model.{ HttpEntity, HttpResponse, Trailer }
+import pekko.http.scaladsl.model.{
+ AttributeKey,
+ AttributeKeys,
+ HttpEntity,
+ HttpHeader,
+ HttpProtocols,
+ HttpResponse,
+ StatusCodes,
+ Trailer
+}
import pekko.stream.Materializer
import pekko.stream.scaladsl.Source
+import pekko.util.ByteString
import io.grpc.Status
import scala.collection.immutable
@@ -39,6 +49,10 @@ import scala.util.control.NonFatal
object GrpcResponseHelpers {
private val TrailerOk = GrpcEntityHelpers.trailer(Status.OK)
private val TrailerOkAttribute = Trailer(TrailerOk.trailers)
+ private val TrailerOkAttributes =
+ Map.empty[AttributeKey[_], Any].updated(AttributeKeys.trailer,
TrailerOkAttribute)
+ private val IdentityResponseHeaders: immutable.Seq[HttpHeader] =
+ headers.`Message-Encoding`(Identity.name) :: Nil
def apply[T](e: Source[T, NotUsed])(
implicit m: ProtobufSerializer[T],
@@ -56,13 +70,38 @@ object GrpcResponseHelpers {
implicit m: ProtobufSerializer[T],
writer: GrpcProtocolWriter,
system: ClassicActorSystemProvider): HttpResponse = {
- val responseHeaders =
headers.`Message-Encoding`(writer.messageEncoding.name) :: Nil
- try writer.encodeDataToResponse(m.serialize(e), responseHeaders,
TrailerOkAttribute)
- catch {
+ val responseHeaders = responseHeadersFor(writer)
+ try {
+ if ((writer.messageEncoding eq Identity) && writer.contentType ==
GrpcProtocolNative.contentType) {
+ m match {
+ case frameSerializer: ProtobufFrameSerializer[T @unchecked] =>
+ nativeResponse(writer, frameSerializer.serializeDataFrame(e),
responseHeaders)
+ case _ =>
+ writer.encodeDataToResponse(m.serialize(e), responseHeaders,
TrailerOkAttribute)
+ }
+ } else {
+ writer.encodeDataToResponse(m.serialize(e), responseHeaders,
TrailerOkAttribute)
+ }
+ } catch {
case NonFatal(ex) => status(GrpcEntityHelpers.handleException(ex,
eHandler))
}
}
+ private def responseHeadersFor(writer: GrpcProtocolWriter):
immutable.Seq[HttpHeader] =
+ if (writer.messageEncoding eq Identity) IdentityResponseHeaders
+ else headers.`Message-Encoding`(writer.messageEncoding.name) :: Nil
+
+ private def nativeResponse(
+ writer: GrpcProtocolWriter,
+ encodedData: ByteString,
+ responseHeaders: immutable.Seq[HttpHeader]): HttpResponse =
+ new HttpResponse(
+ status = StatusCodes.OK,
+ headers = responseHeaders,
+ entity = HttpEntity(writer.contentType, encodedData),
+ protocol = HttpProtocols.`HTTP/1.1`,
+ attributes = TrailerOkAttributes)
+
def apply[T](e: Source[T, NotUsed], status: Future[Status])(
implicit m: ProtobufSerializer[T],
mat: Materializer,
diff --git
a/runtime/src/main/scala/org/apache/pekko/grpc/javadsl/GoogleProtobufSerializer.scala
b/runtime/src/main/scala/org/apache/pekko/grpc/javadsl/GoogleProtobufSerializer.scala
index 45d28860..c71b6ed1 100644
---
a/runtime/src/main/scala/org/apache/pekko/grpc/javadsl/GoogleProtobufSerializer.scala
+++
b/runtime/src/main/scala/org/apache/pekko/grpc/javadsl/GoogleProtobufSerializer.scala
@@ -15,17 +15,30 @@ package org.apache.pekko.grpc.javadsl
import org.apache.pekko
import pekko.annotation.ApiMayChange
-import pekko.grpc.ProtobufSerializer
+import pekko.grpc.ProtobufFrameSerializer
+import pekko.grpc.internal.AbstractGrpcProtocol
import pekko.util.ByteString
+import com.google.protobuf.CodedOutputStream
import com.google.protobuf.Parser
import java.io.InputStream
@ApiMayChange
-class GoogleProtobufSerializer[T <: com.google.protobuf.Message](parser:
Parser[T]) extends ProtobufSerializer[T] {
+class GoogleProtobufSerializer[T <: com.google.protobuf.Message](parser:
Parser[T]) extends ProtobufFrameSerializer[T] {
override def serialize(t: T): ByteString =
ByteString.fromArrayUnsafe(t.toByteArray)
+ override private[grpc] def serializeDataFrame(t: T): ByteString = {
+ val dataLength = t.getSerializedSize
+ val frame = new Array[Byte](AbstractGrpcProtocol.FrameHeaderSize +
dataLength)
+ AbstractGrpcProtocol.writeFrameHeader(frame, 0, dataLength, isCompressed =
false, isTrailer = false)
+
+ val output = CodedOutputStream.newInstance(frame,
AbstractGrpcProtocol.FrameHeaderSize, dataLength)
+ t.writeTo(output)
+ output.checkNoSpaceLeft()
+
+ ByteString.fromArrayUnsafe(frame)
+ }
override def deserialize(bytes: ByteString): T = {
val inputStream = bytes.asInputStream
try parser.parseFrom(inputStream)
diff --git
a/runtime/src/main/scala/org/apache/pekko/grpc/javadsl/GrpcMarshalling.scala
b/runtime/src/main/scala/org/apache/pekko/grpc/javadsl/GrpcMarshalling.scala
index 07fdee0f..b4317eee 100644
--- a/runtime/src/main/scala/org/apache/pekko/grpc/javadsl/GrpcMarshalling.scala
+++ b/runtime/src/main/scala/org/apache/pekko/grpc/javadsl/GrpcMarshalling.scala
@@ -20,6 +20,7 @@ import org.apache.pekko
import pekko.NotUsed
import pekko.actor.ActorSystem
import pekko.actor.ClassicActorSystemProvider
+import pekko.annotation.InternalApi
import pekko.grpc._
import pekko.grpc.internal._
import pekko.grpc.GrpcProtocol.{ GrpcProtocolReader, GrpcProtocolWriter }
@@ -30,6 +31,7 @@ import pekko.stream.javadsl.Source
import pekko.util.ByteString
import scala.annotation.nowarn
+import scala.util.control.NonFatal
object GrpcMarshalling {
@@ -56,7 +58,12 @@ object GrpcMarshalling {
u: ProtobufSerializer[T],
mat: Materializer,
reader: GrpcProtocolReader): CompletionStage[T] =
- unmarshal(entity.getDataBytes, u, mat, reader)
+ entity match {
+ case strict: pekko.http.scaladsl.model.HttpEntity.Strict =>
+ completedOrFailed(u.deserialize(reader.decodeSingleFrame(strict.data)))
+ case _ =>
+ unmarshal(entity.getDataBytes, u, mat, reader)
+ }
def unmarshalStream[T](
data: Source[ByteString, AnyRef],
@@ -98,9 +105,58 @@ object GrpcMarshalling {
: HttpResponse =
GrpcResponseHelpers(e.asScala, scalaAnonymousPartialFunction(eHandler))(m,
writer, system)
+ @InternalApi
+ def handleUnaryResponse[Out](
+ response: CompletionStage[Out],
+ m: ProtobufSerializer[Out],
+ writer: GrpcProtocolWriter,
+ system: ClassicActorSystemProvider,
+ eHandler: JFunction[ActorSystem, JFunction[Throwable, Trailers]]):
CompletionStage[HttpResponse] =
+ try {
+ response match {
+ case future: CompletableFuture[_] if future.isDone =>
+ try completedResponse(marshal(completedValue[Out](future), m,
writer, system, eHandler))
+ catch {
+ case NonFatal(error) => handleUnaryFailure(error, writer, system,
eHandler)
+ }
+ case _ =>
+ response
+ .thenApply(out => marshal(out, m, writer, system, eHandler))
+ .exceptionally(error => GrpcExceptionHandler.standard(error,
eHandler, writer, system))
+ }
+ } catch {
+ case NonFatal(error) => handleUnaryFailure(error, writer, system,
eHandler)
+ }
+
+ @InternalApi
+ def handleUnaryFailure(error: Throwable): CompletionStage[HttpResponse] =
+ if (NonFatal(error)) failure(error)
+ else throw error
+
+ @InternalApi
+ def handleUnaryFailure(
+ error: Throwable,
+ writer: GrpcProtocolWriter,
+ system: ClassicActorSystemProvider,
+ eHandler: JFunction[ActorSystem, JFunction[Throwable, Trailers]]):
CompletionStage[HttpResponse] =
+ if (NonFatal(error))
completedResponse(GrpcExceptionHandler.standard(error, eHandler, writer,
system))
+ else throw error
+
+ private def completedResponse(response: HttpResponse):
CompletableFuture[HttpResponse] =
+ CompletableFuture.completedFuture(response)
+
private def failure[R](error: Throwable): CompletableFuture[R] = {
val future: CompletableFuture[R] = new CompletableFuture()
future.completeExceptionally(error)
future
}
+
+ private def completedOrFailed[R](value: => R): CompletionStage[R] =
+ try CompletableFuture.completedFuture(value)
+ catch {
+ case NonFatal(error) => failure(error)
+ }
+
+ private def completedValue[T](future: CompletableFuture[_]): T =
+ future.asInstanceOf[CompletableFuture[T]].getNow(null.asInstanceOf[T])
}
diff --git
a/runtime/src/main/scala/org/apache/pekko/grpc/scaladsl/GrpcMarshalling.scala
b/runtime/src/main/scala/org/apache/pekko/grpc/scaladsl/GrpcMarshalling.scala
index e99285e3..42b248fe 100644
---
a/runtime/src/main/scala/org/apache/pekko/grpc/scaladsl/GrpcMarshalling.scala
+++
b/runtime/src/main/scala/org/apache/pekko/grpc/scaladsl/GrpcMarshalling.scala
@@ -16,8 +16,9 @@ package org.apache.pekko.grpc.scaladsl
import io.grpc.Status
import scala.annotation.nowarn
-import scala.concurrent.Future
+import scala.concurrent.{ ExecutionContext, Future }
import scala.util.{ Failure, Success, Try }
+import scala.util.control.NonFatal
import org.apache.pekko
import pekko.NotUsed
@@ -28,6 +29,7 @@ import pekko.grpc._
import pekko.grpc.GrpcProtocol.{ GrpcProtocolReader, GrpcProtocolWriter }
import pekko.grpc.internal._
import pekko.http.scaladsl.model.{ HttpEntity, HttpRequest, HttpResponse, Uri }
+import pekko.http.scaladsl.util.FastFuture
import pekko.stream.Materializer
import pekko.stream.scaladsl.Source
import pekko.util.ByteString
@@ -108,6 +110,84 @@ object GrpcMarshalling {
GrpcResponseHelpers(e, eHandler)
}
+ @InternalApi
+ def handleUnary[In, Out](
+ entity: HttpEntity,
+ implementation: In => Future[Out],
+ eHandler: ActorSystem => PartialFunction[Throwable, Trailers])(
+ implicit u: ProtobufSerializer[In],
+ m: ProtobufSerializer[Out],
+ mat: Materializer,
+ reader: GrpcProtocolReader,
+ writer: GrpcProtocolWriter,
+ system: ClassicActorSystemProvider,
+ ec: ExecutionContext): Future[HttpResponse] = {
+ entity match {
+ case HttpEntity.Strict(_, data) =>
+ try {
+ val in = u.deserialize(reader.decodeSingleFrame(data))
+ invokeUnary(in, implementation, eHandler)
+ } catch {
+ case NonFatal(ex) => unaryExceptionHandler(eHandler)(system,
writer)(ex)
+ }
+ case _ =>
+ val requestFuture = unmarshal[In](entity)(u, mat, reader)
+ requestFuture.value match {
+ case Some(Success(in)) => invokeUnary(in, implementation, eHandler)
+ case Some(Failure(ex)) => unaryExceptionHandler(eHandler)(system,
writer)(ex)
+ case None =>
+ val exceptionHandler = unaryExceptionHandler(eHandler)
+ requestFuture
+ .flatMap(in => invokeUnary(in, implementation, eHandler))
+ .recoverWith(exceptionHandler)
+ }
+ }
+ }
+
+ @inline private def invokeUnary[In, Out](
+ in: In,
+ implementation: In => Future[Out],
+ eHandler: ActorSystem => PartialFunction[Throwable, Trailers])(
+ implicit m: ProtobufSerializer[Out],
+ writer: GrpcProtocolWriter,
+ system: ClassicActorSystemProvider,
+ ec: ExecutionContext): Future[HttpResponse] =
+ try handleUnaryResponse(implementation(in), eHandler)
+ catch {
+ case NonFatal(ex) => unaryExceptionHandler(eHandler)(system, writer)(ex)
+ }
+
+ @inline private def handleUnaryResponse[Out](
+ responseFuture: Future[Out],
+ eHandler: ActorSystem => PartialFunction[Throwable, Trailers])(
+ implicit m: ProtobufSerializer[Out],
+ writer: GrpcProtocolWriter,
+ system: ClassicActorSystemProvider,
+ ec: ExecutionContext): Future[HttpResponse] =
+ responseFuture.value match {
+ case Some(Success(out)) => marshalUnaryResponse(out, eHandler)
+ case Some(Failure(ex)) => unaryExceptionHandler(eHandler)(system,
writer)(ex)
+ case None =>
+ val exceptionHandler = unaryExceptionHandler(eHandler)
+ responseFuture.map(out => marshal[Out](out, eHandler)(m, writer,
system)).recoverWith(exceptionHandler)
+ }
+
+ @inline private def marshalUnaryResponse[Out](
+ out: Out,
+ eHandler: ActorSystem => PartialFunction[Throwable, Trailers])(
+ implicit m: ProtobufSerializer[Out],
+ writer: GrpcProtocolWriter,
+ system: ClassicActorSystemProvider): Future[HttpResponse] =
+ try FastFuture.successful(marshal[Out](out, eHandler)(m, writer, system))
+ catch {
+ case NonFatal(ex) => unaryExceptionHandler(eHandler)(system, writer)(ex)
+ }
+
+ @inline private def unaryExceptionHandler(eHandler: ActorSystem =>
PartialFunction[Throwable, Trailers])(
+ implicit system: ClassicActorSystemProvider,
+ writer: GrpcProtocolWriter): PartialFunction[Throwable,
Future[HttpResponse]] =
+ GrpcExceptionHandler.from(eHandler(system.classicSystem))
+
@InternalApi
def marshalRequest[T](
uri: Uri,
diff --git
a/runtime/src/main/scala/org/apache/pekko/grpc/scaladsl/ScalapbProtobufSerializer.scala
b/runtime/src/main/scala/org/apache/pekko/grpc/scaladsl/ScalapbProtobufSerializer.scala
index 1e997c66..a9884852 100644
---
a/runtime/src/main/scala/org/apache/pekko/grpc/scaladsl/ScalapbProtobufSerializer.scala
+++
b/runtime/src/main/scala/org/apache/pekko/grpc/scaladsl/ScalapbProtobufSerializer.scala
@@ -15,8 +15,10 @@ package org.apache.pekko.grpc.scaladsl
import org.apache.pekko
import pekko.annotation.ApiMayChange
-import pekko.grpc.ProtobufSerializer
+import pekko.grpc.ProtobufFrameSerializer
+import pekko.grpc.internal.AbstractGrpcProtocol
import pekko.util.ByteString
+import com.google.protobuf.CodedOutputStream
import com.google.protobuf.CodedInputStream
import scalapb.{ GeneratedMessage, GeneratedMessageCompanion }
@@ -24,9 +26,20 @@ import java.io.InputStream
@ApiMayChange
class ScalapbProtobufSerializer[T <: GeneratedMessage](companion:
GeneratedMessageCompanion[T])
- extends ProtobufSerializer[T] {
+ extends ProtobufFrameSerializer[T] {
override def serialize(t: T): ByteString =
ByteString.fromArrayUnsafe(t.toByteArray)
+ override private[grpc] def serializeDataFrame(t: T): ByteString = {
+ val dataLength = t.serializedSize
+ val frame = new Array[Byte](AbstractGrpcProtocol.FrameHeaderSize +
dataLength)
+ AbstractGrpcProtocol.writeFrameHeader(frame, 0, dataLength, isCompressed =
false, isTrailer = false)
+
+ val output = CodedOutputStream.newInstance(frame,
AbstractGrpcProtocol.FrameHeaderSize, dataLength)
+ t.writeTo(output)
+ output.checkNoSpaceLeft()
+
+ ByteString.fromArrayUnsafe(frame)
+ }
override def deserialize(bytes: ByteString): T =
companion.parseFrom(CodedInputStream.newInstance(bytes.asByteBuffer))
override def deserialize(data: InputStream): T =
diff --git
a/runtime/src/test/scala/org/apache/pekko/grpc/javadsl/GrpcMarshallingSpec.scala
b/runtime/src/test/scala/org/apache/pekko/grpc/javadsl/GrpcMarshallingSpec.scala
new file mode 100644
index 00000000..c35c50f4
--- /dev/null
+++
b/runtime/src/test/scala/org/apache/pekko/grpc/javadsl/GrpcMarshallingSpec.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.pekko.grpc.javadsl
+
+import java.util.concurrent.{ CompletableFuture, TimeUnit }
+
+import scala.concurrent.Await
+import scala.concurrent.duration._
+
+import com.google.protobuf.{ Any => ProtobufAny, ByteString =>
ProtobufByteString }
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.wordspec.AnyWordSpec
+
+import org.apache.pekko
+import pekko.actor.ActorSystem
+import pekko.grpc.internal.{ AbstractGrpcProtocol, GrpcProtocolNative,
Identity }
+import pekko.http.scaladsl.model.HttpEntity
+import pekko.stream.SystemMaterializer
+
+class GrpcMarshallingSpec extends AnyWordSpec with Matchers {
+ "The javadsl GrpcMarshalling" should {
+ "unmarshal a strict unary entity" in {
+ val system = ActorSystem("GrpcMarshallingSpec")
+ try {
+ val mat = SystemMaterializer(system).materializer
+ val serializer = new GoogleProtobufSerializer(ProtobufAny.parser())
+ val message =
+
ProtobufAny.newBuilder().setTypeUrl("benchmark").setValue(ProtobufByteString.copyFromUtf8("payload")).build()
+ val entity =
+ HttpEntity.Strict(
+ GrpcProtocolNative.contentType,
+
AbstractGrpcProtocol.encodeFrameData(serializer.serialize(message),
isCompressed = false,
+ isTrailer = false))
+
+ val result =
+ GrpcMarshalling
+ .unmarshal(entity, serializer, mat,
GrpcProtocolNative.newReader(Identity))
+ .toCompletableFuture
+ .get(10, TimeUnit.SECONDS)
+
+ result should be(message)
+ } finally {
+ Await.result(system.terminate(), 10.seconds)
+ }
+ }
+
+ "recover a failed unary response stage" in {
+ val system = ActorSystem("GrpcMarshallingSpec")
+ try {
+ val serializer = new GoogleProtobufSerializer(ProtobufAny.parser())
+ val responseFuture = new CompletableFuture[ProtobufAny]()
+ responseFuture.completeExceptionally(new RuntimeException("boom"))
+
+ val response =
+ GrpcMarshalling
+ .handleUnaryResponse(
+ responseFuture,
+ serializer,
+ GrpcProtocolNative.newWriter(Identity),
+ system,
+ GrpcExceptionHandler.defaultMapper)
+ .toCompletableFuture
+ .get(10, TimeUnit.SECONDS)
+
+ response.getHeader("grpc-status").get().value() should be("13")
+ } finally {
+ Await.result(system.terminate(), 10.seconds)
+ }
+ }
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]