hvanhovell commented on code in PR #55657: URL: https://github.com/apache/spark/pull/55657#discussion_r3260439677
########## udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/EchoProtocolSuite.scala: ########## @@ -0,0 +1,938 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.udf.worker.core + +import java.util.concurrent.{CountDownLatch, LinkedBlockingQueue, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean + +import com.google.protobuf.ByteString +import io.grpc.stub.StreamObserver +import io.grpc.{ManagedChannel, Server, Status} +import io.grpc.inprocess.{InProcessChannelBuilder, InProcessServerBuilder} +import org.apache.spark.udf.worker.UdfWorkerGrpc + +import org.apache.spark.udf.worker.{ + Cancel, CancelResponse, DataRequest, DataResponse, + ErrorResponse, ExecutionError, UserError, WorkerError, ProtocolError, + Finish, FinishResponse, Heartbeat, HeartbeatResponse, + Init, InitResponse, PayloadChunk, ShutdownRequest, ShutdownResponse, + UDFWorkerDataFormat, UdfControlRequest, UdfControlResponse, + UdfPayload, UdfRequest, UdfResponse, WorkerRequest, WorkerResponse +} + +// scalastyle:off funsuite +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.BeforeAndAfterEach + +/** + * Protocol validation test for the UDF gRPC execution protocol. 
+ * + * Implements a minimal echo worker (gRPC server) and engine client to verify + * the full Execute stream lifecycle: init, data streaming, finish, cancel, + * error handling, and the Manage RPC. The worker echoes each DataRequest + * batch back as a DataResponse; error paths are triggered by a sentinel + * payload value. + */ +class EchoProtocolSuite extends AnyFunSuite with BeforeAndAfterEach { +// scalastyle:on funsuite + + private val SUPPORTED_VERSION: Int = 1 + // A DataRequest whose payload equals this value triggers an ErrorResponse. + private val ERROR_TRIGGER: ByteString = ByteString.copyFromUtf8("ERROR") + // An init payload whose value equals this triggers an init failure + // (InitResponse with error set). + private val INIT_ERROR_TRIGGER: ByteString = ByteString.copyFromUtf8("INIT_ERROR") + + private var server: Server = _ + private var channel: ManagedChannel = _ + private var stub: UdfWorkerGrpc.UdfWorkerStub = _ + + override def beforeEach(): Unit = { + val serverName = InProcessServerBuilder.generateName() + server = InProcessServerBuilder.forName(serverName) + .directExecutor() + .addService(new EchoWorkerService) + .build() + .start() + channel = InProcessChannelBuilder.forName(serverName).directExecutor().build() + stub = UdfWorkerGrpc.newStub(channel) + } + + override def afterEach(): Unit = { + channel.shutdownNow() + server.shutdownNow() + } + + // =========================================================================== + // WORKER SIDE (gRPC server) + // =========================================================================== + + /** + * Worker state machine for one Execute stream. + * + * AwaitingInit --> AwaitingChunks? --> Data --> Draining --> Drained --> Done + * | + * +--> PostError --> Cancelling --> Cancelled --> Done + * + * `[process done]` marks an event (not a state): the asynchronous + * completion notification of in-flight work. 
+ * + * AwaitingInit + * Init(inline) --> Data (send InitResponse) + * Init(chunked) --> AwaitingChunks + * Init(failed) --> PostError (send InitResponse with error) + * Cancel --> Cancelling --[process done]--> Cancelled --> Done (send CR) + * + * AwaitingChunks + * PayloadChunk(last=false) --> AwaitingChunks (accumulate) + * PayloadChunk(last=true) --> Data (send InitResponse) + * Cancel --> Cancelling --[process done]--> Cancelled --> Done + * Finish --> protocol error (engine must wait for + * InitResponse first) + * + * Data + * ErrorResponse sent --> PostError + * Finish --> Draining --[process done]--> Drained --> Done (send FR) + * Cancel --> Cancelling --[process done]--> Cancelled --> Done (send CR) + * + * PostError + * Finish --> PostError (no-op; engine MUST follow up with Cancel) + * Cancel --> Cancelling --[process done]--> Cancelled --> Done + * + * Draining (in-flight work running) + * [process done] --> Drained + * Cancel --> Cancelling (the pending [process done] then sees + * Cancelling and routes to CancelResponse) + * + * Drained (work done; optional post-work cleanup hook may run here; + * any error from the hook is reported via FinishResponse.error) + * --> Done (send FinishResponse) + * + * Cancelling (in-flight work being cancelled) + * [process done] --> Cancelled + * + * Cancelled (cleanup done; optional post-work cleanup hook may run here; + * any error from the hook is reported via CancelResponse.error) + * --> Done (send CancelResponse) + * + * Cross-cutting: + * - Protocol violation in any active state: send ErrorResponse(ProtocolError) + * followed by CancelResponse, transition to Done. + * - gRPC transport error (onError): transition to Done, no response sent. + */ + private sealed trait WorkerState + private case object AwaitingInit extends WorkerState + // Chunked init handshake in progress; `accumulated` holds the inline + // portion of Init.udf.payload plus all chunks received so far. 
+ private case class AwaitingChunks(accumulated: ByteString) extends WorkerState + private case object Data extends WorkerState + private case object PostError extends WorkerState + // Finish received; in-flight finish-callback / drain work is running. + private case object Draining extends WorkerState + // Drain complete; FinishResponse not yet sent. Post-work cleanup hook + // (if any) runs in this state before the terminator is emitted. + private case object Drained extends WorkerState + // Cancel received; cancel-callback / cleanup work is running. + private case object Cancelling extends WorkerState + // Cleanup complete; CancelResponse not yet sent. Post-work cleanup hook + // (if any) runs in this state before the terminator is emitted. + private case object Cancelled extends WorkerState + private case object Done extends WorkerState + + private class EchoWorkerService extends UdfWorkerGrpc.UdfWorkerImplBase { + + override def execute( + responseObserver: StreamObserver[UdfResponse]): StreamObserver[UdfRequest] = + new ExecuteStreamHandler(responseObserver) + + override def manage( + request: WorkerRequest, + responseObserver: StreamObserver[WorkerResponse]): Unit = { + request.getManageCase match { + case WorkerRequest.ManageCase.HEARTBEAT => + responseObserver.onNext(WorkerResponse.newBuilder() + .setHeartbeat(HeartbeatResponse.getDefaultInstance) + .build()) + responseObserver.onCompleted() + + case WorkerRequest.ManageCase.SHUTDOWN => + responseObserver.onNext(WorkerResponse.newBuilder() + .setShutdown(ShutdownResponse.newBuilder().setSessionsSettled(true).build()) + .build()) + responseObserver.onCompleted() + + case _ => + responseObserver.onError( + Status.INVALID_ARGUMENT.withDescription("empty manage request") + .asRuntimeException()) + } + } + } + + private class ExecuteStreamHandler( + responseObserver: StreamObserver[UdfResponse]) extends StreamObserver[UdfRequest] { + + // State mutations go through `matchUpdateThen`: under stateLock, the + // 
caller-supplied function inspects the current state, returns the + // next state and a non-blocking follow-up callback; the helper writes + // the new state and releases the lock before invoking the callback, + // so I/O does not extend the critical section. + @volatile private var state: WorkerState = AwaitingInit + private val stateLock = new Object + + private def matchUpdateThen( + transition: WorkerState => (WorkerState, () => Unit)): Unit = { + val followUp = stateLock.synchronized { + val (next, callback) = transition(state) + state = next + callback + } + followUp() + } + + // gRPC does not permit concurrent calls to the response StreamObserver; + // all writes are serialized through this lock. + private val responseLock = new Object + + override def onNext(request: UdfRequest): Unit = { + request.getRequestCase match { + case UdfRequest.RequestCase.CONTROL => handleControl(request.getControl) + case UdfRequest.RequestCase.DATA => handleDataRequest(request.getData) + case _ => closeWithProtocolError("empty request oneof") + } + } + + private def handleControl(ctrl: UdfControlRequest): Unit = { + ctrl.getControlCase match { + case UdfControlRequest.ControlCase.INIT => handleInit(ctrl.getInit) + case UdfControlRequest.ControlCase.PAYLOAD => handleChunk(ctrl.getPayload) + case UdfControlRequest.ControlCase.FINISH => handleFinish() + case UdfControlRequest.ControlCase.CANCEL => handleCancel(ctrl.getCancel) + case _ => closeWithProtocolError("empty control oneof") + } + } + + private def handleInit(init: Init): Unit = matchUpdateThen { + case AwaitingInit => + if (init.hasProtocolVersion && + init.getProtocolVersion != SUPPORTED_VERSION) { + val err = ExecutionError.newBuilder() + .setProtocol(ProtocolError.newBuilder() + .setMessage(s"unsupported protocol version: ${init.getProtocolVersion}") + .build()) + .build() + (PostError, () => sendControl(UdfControlResponse.newBuilder() + .setInit(InitResponse.newBuilder().setError(err).build()) + .build())) + } else if 
(init.getIsChunkingPayload) { + // Payload arrives via PayloadChunk messages; defer init + // processing until the last chunk has been received. + (AwaitingChunks(init.getUdf.getPayload), () => ()) + } else { + // Payload is fully inline; process init outside the lock. + // finalizeInit performs its own CAS on entry. + val payload = init.getUdf.getPayload + (AwaitingInit, () => finalizeInit(payload)) + } + case other => + (other, () => closeWithProtocolError(s"Init received in state $other")) + } + + private def handleChunk(chunk: PayloadChunk): Unit = matchUpdateThen { + case AwaitingChunks(existing) => + val updated = existing.concat(chunk.getData) + if (chunk.hasLast && chunk.getLast) { + // Stay in AwaitingChunks until finalizeInit's CAS transitions us. + (AwaitingChunks(existing), () => finalizeInit(updated)) + } else { + (AwaitingChunks(updated), () => ()) + } + case other => + (other, () => closeWithProtocolError(s"PayloadChunk received in state $other")) + } + + // Init processing hook: invoked once with the complete assembled UDF + // payload (inline + all chunks, if any). A real worker would deserialize + // the UDF, run validation, set up runtime resources here. The echo worker + // succeeds for any payload other than INIT_ERROR_TRIGGER, which simulates + // an init-time failure (e.g. deserialization error, missing dependency). 
+ private def finalizeInit(payload: ByteString): Unit = { + val initError: Option[ExecutionError] = if (payload == INIT_ERROR_TRIGGER) { + Some(ExecutionError.newBuilder() + .setWorker(WorkerError.newBuilder() + .setMessage("simulated init failure") + .build()) + .build()) + } else { + None + } + matchUpdateThen { + case AwaitingInit | AwaitingChunks(_) => + initError match { + case Some(err) => + (PostError, () => sendControl(UdfControlResponse.newBuilder() + .setInit(InitResponse.newBuilder().setError(err).build()) + .build())) + case None => + (Data, () => sendInitResponse()) + } + // Concurrent Cancel / transport error moved state past the init + // phase; the cancel path owns the terminator. + case other @ (Cancelling | Cancelled | Done) => + (other, () => ()) + case other => + (other, () => closeWithProtocolError(s"finalizeInit invoked in state $other")) + } + } + + private def handleDataRequest(data: DataRequest): Unit = state match { + case Data => processEcho(data) + + case _ => closeWithProtocolError(s"DataRequest received in state $state") + } + + // Echo "processing" runs inline on the gRPC callback thread for test + // simplicity. Workers that offload to a thread pool (the typical + // approach for non-trivial UDFs) must apply back-pressure via a + // bounded queue and serialize state mutations across threads. + private def processEcho(data: DataRequest): Unit = { + if (data.getData == ERROR_TRIGGER) { + // Data-phase error: emit ErrorResponse and enter PostError so the + // terminator becomes CancelResponse after the engine's Cancel. + // Only transition if we are still in Data: a concurrent Cancel + // may have moved us to Cancelling, in which case the cancel path + // owns the terminator. 
+ val errEnvelope = UdfControlResponse.newBuilder() + .setError(ErrorResponse.newBuilder() + .setError(ExecutionError.newBuilder() + .setUser(UserError.newBuilder() + .setMessage("simulated user-code error") + .setErrorClass("SimulatedError") + .build()) + .build()) + .build()) + .build() + matchUpdateThen { + case Data => (PostError, () => sendControl(errEnvelope)) + // Concurrent Cancel / transport error already moved past data + // phase; the cancel path owns the terminator. + case other @ (Cancelling | Cancelled | Done) => (other, () => ()) + case other => + (other, () => closeWithProtocolError(s"processEcho invoked in state $other")) + } + } else { + responseLock.synchronized { + responseObserver.onNext(UdfResponse.newBuilder() + .setData(DataResponse.newBuilder().setData(data.getData).build()) + .build()) + } + } + } + + private def handleFinish(): Unit = matchUpdateThen { + case Data => + (Draining, () => onWorkComplete()) + case PostError => + // ErrorResponse already sent; this Finish was in flight before the + // engine learned about the error. The engine MUST follow up with + // Cancel; wait for it. + (PostError, () => ()) + // Finish in AwaitingInit or AwaitingChunks is a protocol error: + // the engine MUST wait for InitResponse before sending Finish. + case other => + (other, () => closeWithProtocolError(s"Finish received in state $other")) + } + + // Lazy-cancel: transition to Cancelling and let any in-flight work run + // to natural completion; the pending onWorkComplete (or this method's + // own follow-up call when no work is in flight) sees Cancelling and + // routes to CancelResponse. + private def handleCancel(cancel: Cancel): Unit = matchUpdateThen { + case AwaitingInit | AwaitingChunks(_) | Data | PostError | Draining | Drained => + (Cancelling, () => onWorkComplete()) + case other @ (Cancelling | Cancelled | Done) => + // Already cancelling or terminated; ignore duplicate Cancel. 
+ (other, () => ()) + } + + // Called when in-flight work (finish callback, cancel cleanup, or + // batch processing) completes. The current state decides the + // terminator: + // Draining -> Drained -> send FinishResponse + // Cancelling -> Cancelled -> send CancelResponse + // + // An optional post-work cleanup hook (release file handles, flush + // metrics) belongs between the state transition and the terminator + // send. Any error from the hook is reported via FinishResponse.error + // or CancelResponse.error. + private def onWorkComplete(): Unit = matchUpdateThen { + case Draining => + (Drained, () => sendFinishResponseAndFinalize()) + case Cancelling => + (Cancelled, () => sendCancelResponseAndFinalize()) + // Stream already finalized (e.g. onError fired before this + // completion notification arrived) -- nothing to do. + case Done => (Done, () => ()) + case other => + (other, () => closeWithProtocolError(s"onWorkComplete invoked in state $other")) + } + + private def sendFinishResponseAndFinalize(): Unit = { + sendControl(UdfControlResponse.newBuilder() + .setFinish(FinishResponse.newBuilder() + .putMetrics("status", "ok") + .build()) + .build()) + matchUpdateThen { _ => + (Done, () => responseLock.synchronized { responseObserver.onCompleted() }) + } + } + + private def sendCancelResponseAndFinalize(): Unit = { + sendControl(UdfControlResponse.newBuilder() + .setCancel(CancelResponse.getDefaultInstance) + .build()) + matchUpdateThen { _ => + (Done, () => responseLock.synchronized { responseObserver.onCompleted() }) + } + } + + // gRPC transport error: the connection dropped. The stream is dead, + // so no response can be sent. The worker MUST still run the cleanup + // it would perform on an explicit Cancel (stop in-progress work, + // release resources, free buffers). The echo worker has nothing to + // release; only the state is updated. 
+ override def onError(t: Throwable): Unit = matchUpdateThen { _ => + (Done, () => ()) + } + + override def onCompleted(): Unit = state match { + case Done => // normal: engine half-closed after session terminated + case _ => + closeWithProtocolError( + s"request stream closed by engine in unexpected state $state") + } + + private def sendInitResponse(): Unit = + sendControl(UdfControlResponse.newBuilder() + .setInit(InitResponse.getDefaultInstance) + .build()) + + private def sendControl(ctrl: UdfControlResponse): Unit = + responseLock.synchronized { + responseObserver.onNext( + UdfResponse.newBuilder().setControl(ctrl).build()) + } + + // Emit ErrorResponse(ProtocolError) followed immediately by + // CancelResponse. No in-flight work to drain, so the Cancelling / + // Cancelled intermediate states are bypassed. + private def closeWithProtocolError(msg: String): Unit = { + sendControl(UdfControlResponse.newBuilder() + .setError(ErrorResponse.newBuilder() + .setError(ExecutionError.newBuilder() + .setProtocol(ProtocolError.newBuilder().setMessage(msg).build()) + .build()) + .build()) + .build()) + sendCancelResponseAndFinalize() + } + } + + // =========================================================================== + // ENGINE SIDE (gRPC client) + // =========================================================================== + + /** + * Minimal engine client that drives the Execute stream and collects results. + * + * The request stream is half-closed (onCompleted) only after the session + * outcome is known from the server: on receiving FinishResponse, + * CancelResponse, or a gRPC error. This keeps the stream open long enough + * for Cancel to follow Finish when needed. 
+ */ + private class EngineClient(stub: UdfWorkerGrpc.UdfWorkerStub) { + private val results = new LinkedBlockingQueue[Array[Byte]]() + private val done = new CountDownLatch(1) + @volatile var executionError: Option[ExecutionError] = None + @volatile var streamError: Option[Throwable] = None + private val requestCompleted = new AtomicBoolean(false) + // Counted down on InitResponse (success or failure) or on terminal error. + // The engine MUST wait for this before sending any DataRequest or Finish. + private val initResponseLatch = new CountDownLatch(1) + + private val responseObserver = new StreamObserver[UdfResponse] { + override def onNext(response: UdfResponse): Unit = { + response.getResponseCase match { + case UdfResponse.ResponseCase.DATA => + results.add(response.getData.getData.toByteArray) + + case UdfResponse.ResponseCase.CONTROL => + val ctrl = response.getControl + ctrl.getControlCase match { + case UdfControlResponse.ControlCase.INIT => + // InitResponse received. If error is set, init failed. + val resp = ctrl.getInit + if (resp.hasError) { + executionError = Some(resp.getError) + if (!requestCompleted.get()) sendCancel("aborting after init error") + } + initResponseLatch.countDown() + // Data phase begins only on success (no error). + + case UdfControlResponse.ControlCase.ERROR => + // Data-phase error. Send Cancel so the worker can abort cleanly; + // the error is surfaced after CancelResponse arrives. + executionError = Some(ctrl.getError.getError) + if (!requestCompleted.get()) { + sendCancel("aborting after ErrorResponse") + } + + case UdfControlResponse.ControlCase.FINISH => + completeRequestStream() + done.countDown() + + case UdfControlResponse.ControlCase.CANCEL => + completeRequestStream() + done.countDown() + + case _ => Review Comment: Given that this is a reference implementation, let's make sure we are very strict and throw exceptions whenever we encounter an unknown/unhandled situation.
-- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
