This is an automated email from the ASF dual-hosted git repository.

He-Pin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/pekko.git


The following commit(s) were added to refs/heads/main by this push:
     new 52747cdfee fix: harden async DNS request id handling (#2960)
52747cdfee is described below

commit 52747cdfeed71304a32ad660008051a3e15473ff
Author: He-Pin(kerr) <[email protected]>
AuthorDate: Sun May 10 03:51:52 2026 +0800

    fix: harden async DNS request id handling (#2960)
    
    Motivation:
    The async DNS resolver should avoid reusing request ids that are still
    active, should release the id of a failed request only after the DNS
    client confirms that the request was dropped, and should ignore UDP DNS
    responses that do not come from the configured nameserver.
    
    Modification:
    Track active request ids in the request-id injector, and keep ids that the
    DNS client reported as duplicates reserved. Add an explicit Dropped
    acknowledgement for DNS client drops, and match drops by the full DNS
    question (name, record type and class). Ignore UDP responses from
    addresses other than the configured nameserver, logging them with the
    Security log marker.
    
    Result:
    Concurrent async DNS requests no longer reuse ids that are still active.
    Failed requests release their ids only once the DNS client confirms the
    drop. Spoofed UDP responses from unexpected addresses are ignored.
    
    Tests:
    - scalafmt --mode diff-ref=port-ef33e71-async-dns-id-injector / passed
    - scalafmt --list --mode diff-ref=port-ef33e71-async-dns-id-injector / passed
    - git diff --check / passed
    - sbt "actor-tests / Test / testOnly org.apache.pekko.io.dns.internal.AsyncDnsResolverSpec org.apache.pekko.io.dns.internal.DnsClientSpec org.apache.pekko.io.dns.DnsSettingsSpec" / passed
    
    References:
    Upstream commit:
    https://github.com/akka/akka-core/commit/7434794bab6fec95cb52d810d95126f90efb8c34
    (this upstream commit is now Apache licensed)
    Refs #31926
---
 .../io/dns/internal/AsyncDnsResolverSpec.scala     | 37 +++++++++++--
 .../pekko/io/dns/internal/DnsClientSpec.scala      | 55 +++++++++++++++++--
 .../pekko/io/dns/internal/AsyncDnsResolver.scala   | 64 ++++++++++++++++++++--
 .../apache/pekko/io/dns/internal/DnsClient.scala   | 39 +++++++++++--
 4 files changed, 175 insertions(+), 20 deletions(-)

diff --git 
a/actor-tests/src/test/scala/org/apache/pekko/io/dns/internal/AsyncDnsResolverSpec.scala
 
b/actor-tests/src/test/scala/org/apache/pekko/io/dns/internal/AsyncDnsResolverSpec.scala
index 13c40e872d..1f520a6d65 100644
--- 
a/actor-tests/src/test/scala/org/apache/pekko/io/dns/internal/AsyncDnsResolverSpec.scala
+++ 
b/actor-tests/src/test/scala/org/apache/pekko/io/dns/internal/AsyncDnsResolverSpec.scala
@@ -26,7 +26,7 @@ import pekko.io.dns.{ AAAARecord, ARecord, DnsSettings, 
IdGenerator, SRVRecord }
 import pekko.io.dns.CachePolicy.Ttl
 import pekko.io.dns.DnsProtocol._
 import pekko.io.dns.internal.AsyncDnsResolver.ResolveFailedException
-import pekko.io.dns.internal.DnsClient.{ Answer, DropRequest, DuplicateId, 
Question4, Question6, SrvQuestion }
+import pekko.io.dns.internal.DnsClient.{ Answer, DropRequest, Dropped, 
DuplicateId, Question4, Question6, SrvQuestion }
 import pekko.testkit.{ PekkoSpec, TestProbe, WithLogCapturing }
 
 import com.typesafe.config.{ Config, ConfigFactory, ConfigValueFactory }
@@ -322,7 +322,7 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
       override val r = resolver(
         List(dnsClient1.ref, dnsClient2.ref),
         defaultConfig,
-        deterministicIds((1 to 10).map(_.toShort): _*))
+        deterministicIds(1, 2, 1, 3, 4, 5, 6, 7, 8, 9, 10))
 
       // Send multiple resolves for different names so no in-flight 
deduplication applies
       val resolveCount = 10
@@ -347,8 +347,35 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
       }
     }
 
-    "retry request ids that duplicate an in-flight request" in new Setup {
-      override val r = resolver(List(dnsClient1.ref), defaultConfig, 
deterministicIds(1, 1, 2))
+    "reuse confirmed dropped request ids" in new Setup {
+      val config = defaultConfig.withValue("resolve-timeout", 
ConfigValueFactory.fromAnyRef("200 ms"))
+      override val r = resolver(List(dnsClient1.ref), config, 
deterministicIds(1, 1, 2))
+
+      r ! Resolve("host1.cats.com", Ip(ipv4 = true, ipv6 = false))
+      val timedOutQuestion = dnsClient1.expectMsgPF() {
+        case q: Question4 if q.name == "host1.cats.com" => q
+      }
+
+      val dropped = dnsClient1.expectMsgPF(remainingOrDefault) {
+        case DropRequest(question) if question == timedOutQuestion => question
+      }
+      dnsClient1.reply(Dropped(dropped.id))
+      senderProbe.expectMsgPF(remainingOrDefault) {
+        case Failure(ResolveFailedException(_)) =>
+      }
+
+      r ! Resolve("host2.cats.com", Ip(ipv4 = true, ipv6 = false))
+      val reusedQuestion = dnsClient1.expectMsgPF() {
+        case q: Question4 if q.name == "host2.cats.com" => q
+      }
+      reusedQuestion.id shouldBe timedOutQuestion.id
+      dnsClient1.reply(Answer(reusedQuestion.id, im.Seq.empty))
+
+      senderProbe.expectMsg(Resolved("host2.cats.com", im.Seq.empty))
+    }
+
+    "retry request ids that duplicate an untracked in-flight request" in new 
Setup {
+      override val r = resolver(List(dnsClient1.ref), defaultConfig, 
deterministicIds(1, 1, 2, 2, 3))
 
       val asker1 = TestProbe()
       val asker2 = TestProbe()
@@ -361,7 +388,7 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
 
       r.tell(Resolve("host2.cats.com", Ip(ipv4 = true, ipv6 = false)), 
asker2.ref)
       val duplicatedQuestion = dnsClient1.expectMsgPF() {
-        case q: Question4 if q.name == "host2.cats.com" && q.id == 
firstQuestion.id => q
+        case q: Question4 if q.name == "host2.cats.com" && q.id != 
firstQuestion.id => q
       }
       dnsClient1.reply(DuplicateId(duplicatedQuestion.id))
 
diff --git 
a/actor-tests/src/test/scala/org/apache/pekko/io/dns/internal/DnsClientSpec.scala
 
b/actor-tests/src/test/scala/org/apache/pekko/io/dns/internal/DnsClientSpec.scala
index 8c8e9bb4de..c84dc96375 100644
--- 
a/actor-tests/src/test/scala/org/apache/pekko/io/dns/internal/DnsClientSpec.scala
+++ 
b/actor-tests/src/test/scala/org/apache/pekko/io/dns/internal/DnsClientSpec.scala
@@ -23,7 +23,7 @@ import org.apache.pekko.actor.Status.Failure
 import pekko.actor.Props
 import pekko.io.Udp
 import pekko.io.dns.{ ARecord, CachePolicy, RecordClass, RecordType }
-import pekko.io.dns.internal.DnsClient.{ Answer, DropRequest, DuplicateId, 
Question4 }
+import pekko.io.dns.internal.DnsClient.{ Answer, DropRequest, Dropped, 
DuplicateId, Question4, Question6 }
 import pekko.testkit.{ ImplicitSender, PekkoSpec, TestProbe }
 
 class DnsClientSpec extends PekkoSpec with ImplicitSender {
@@ -150,9 +150,9 @@ class DnsClientSpec extends PekkoSpec with ImplicitSender {
         answerRecs = im.Seq(goodRecord))
       val goodAnswer = exampleRequestMessage.copy(flags = flags, answerRecs = 
im.Seq(goodRecord))
 
-      udpSocketProbe.reply(Udp.Received(internal.ByteResponse(badId), socket))
-      udpSocketProbe.reply(Udp.Received(internal.ByteResponse(badQuestion), 
socket))
-      udpSocketProbe.reply(Udp.Received(internal.ByteResponse(goodAnswer), 
socket))
+      udpSocketProbe.reply(Udp.Received(internal.ByteResponse(badId), 
dnsServerAddress))
+      udpSocketProbe.reply(Udp.Received(internal.ByteResponse(badQuestion), 
dnsServerAddress))
+      udpSocketProbe.reply(Udp.Received(internal.ByteResponse(goodAnswer), 
dnsServerAddress))
 
       val answer = Answer(exampleRequest.id, im.Seq(goodRecord), im.Seq())
       goodSenderProbe.expectMsg(answer)
@@ -182,6 +182,7 @@ class DnsClientSpec extends PekkoSpec with ImplicitSender {
       udpSend.payload shouldBe exampleRequestMessage.write()
 
       goodSenderProbe.send(client, DropRequest(exampleRequest.copy(id = 999)))
+      goodSenderProbe.expectMsg(Dropped(999))
 
       // duplicate shows inflight message not deleted
       goodSenderProbe.send(client, exampleRequest)
@@ -193,13 +194,59 @@ class DnsClientSpec extends PekkoSpec with ImplicitSender 
{
       goodSenderProbe.send(client, exampleRequest)
       goodSenderProbe.expectMsg(DuplicateId(exampleRequest.id))
 
+      goodSenderProbe.send(client, DropRequest(Question6(exampleRequest.id, 
exampleRequest.name)))
+
+      // duplicate shows inflight message not deleted
+      goodSenderProbe.send(client, exampleRequest)
+      goodSenderProbe.expectMsg(DuplicateId(exampleRequest.id))
+
       goodSenderProbe.send(client, DropRequest(exampleRequest))
+      goodSenderProbe.expectMsg(Dropped(exampleRequest.id))
 
       // no duplicate shows inflight message was deleted
       goodSenderProbe.send(client, exampleRequest)
       goodSenderProbe.expectNoMessage()
     }
 
+    "Ignore DNS responses from unexpected remotes" in {
+      val udpExtensionProbe = TestProbe()
+      val udpSocketProbe = TestProbe()
+      val tcpClientProbe = TestProbe()
+      val goodSenderProbe = TestProbe()
+
+      val socket = InetSocketAddress.createUnresolved("localhost", 41325)
+      val unexpectedAddress = InetSocketAddress.createUnresolved("bar", 53)
+      val goodRecord = ARecord(exampleRequest.name, CachePolicy.Ttl.never, 
InetAddress.getLocalHost())
+
+      val client = system.actorOf(Props(new DnsClient(dnsServerAddress) {
+        override val udp = udpExtensionProbe.ref
+
+        override def createTcpClient() = tcpClientProbe.ref
+      }))
+
+      udpExtensionProbe.expectMsgType[Udp.Bind]
+      udpSocketProbe.send(udpExtensionProbe.lastSender, Udp.Bound(socket))
+
+      goodSenderProbe.send(client, exampleRequest)
+
+      val udpSend = udpSocketProbe.expectMsgType[Udp.Send]
+      udpSend.payload shouldBe exampleRequestMessage.write()
+
+      val flags = MessageFlags(true, authoritativeAnswer = true)
+      val goodAnswer = exampleRequestMessage.copy(flags = flags, answerRecs = 
im.Seq(goodRecord))
+
+      udpSocketProbe.reply(Udp.Received(internal.ByteResponse(goodAnswer), 
unexpectedAddress))
+
+      // duplicate shows unexpected remote did not complete or delete the 
inflight request
+      goodSenderProbe.send(client, exampleRequest)
+      goodSenderProbe.expectMsg(DuplicateId(exampleRequest.id))
+
+      udpSocketProbe.reply(Udp.Received(internal.ByteResponse(goodAnswer), 
dnsServerAddress))
+
+      val answer = Answer(exampleRequest.id, im.Seq(goodRecord), im.Seq())
+      goodSenderProbe.expectMsg(answer)
+    }
+
     "Verify original question when processing UDP Failures" in {
       val udpExtensionProbe = TestProbe()
       val udpSocketProbe = TestProbe()
diff --git 
a/actor/src/main/scala/org/apache/pekko/io/dns/internal/AsyncDnsResolver.scala 
b/actor/src/main/scala/org/apache/pekko/io/dns/internal/AsyncDnsResolver.scala
index b3ce2f2cfe..89b2059bf5 100644
--- 
a/actor/src/main/scala/org/apache/pekko/io/dns/internal/AsyncDnsResolver.scala
+++ 
b/actor/src/main/scala/org/apache/pekko/io/dns/internal/AsyncDnsResolver.scala
@@ -15,10 +15,12 @@ package org.apache.pekko.io.dns.internal
 
 import java.net.{ Inet4Address, Inet6Address, InetAddress, InetSocketAddress }
 
+import scala.annotation.tailrec
 import scala.collection.immutable
 import scala.concurrent.{ ExecutionContextExecutor, Future, Promise }
 import scala.concurrent.ExecutionContext.parasitic
 import scala.util.{ Failure, Success, Try }
+import scala.util.control.NonFatal
 
 import org.apache.pekko
 import pekko.actor.{ Actor, ActorLogging, ActorRef, ActorRefFactory, 
NoSerializationVerificationNeeded, Props, Status }
@@ -218,41 +220,91 @@ private[pekko] object AsyncDnsResolver {
   private final case class InjectedDnsQuestionAnswer(requestId: Long, result: 
Try[Answer])
       extends NoSerializationVerificationNeeded
 
+  private final case class DidntDrop(id: Short) extends 
NoSerializationVerificationNeeded
+
   private object RequestIdInjector {
+    private val MaxIdGenerationAttempts = 1 << 16
+
     def props(idGenerator: IdGenerator): Props = Props(new 
RequestIdInjector(idGenerator))
   }
 
-  private class RequestIdInjector(idGenerator: IdGenerator) extends Actor {
+  private class RequestIdInjector(idGenerator: IdGenerator) extends Actor with 
ActorLogging {
+    import RequestIdInjector._
+
     private implicit val ec: ExecutionContextExecutor = context.dispatcher
+    private var activeRequestIds = Set.empty[Short]
 
     override def receive: Receive = {
       case question: DnsQuestionPreInjection =>
-        sendQuestion(sender(), question, question.withId(idGenerator.nextId()))
+        sendQuestionWithNewId(sender(), question)
 
-      case DnsQuestionAnswer(replyTo, request, _, Success(result: Answer)) =>
+      case DnsQuestionAnswer(replyTo, request, question, Success(result: 
Answer)) =>
+        activeRequestIds -= question.id
         replyTo ! InjectedDnsQuestionAnswer(request.requestId, Success(result))
 
       case DnsQuestionAnswer(replyTo, request, question, 
Success(DuplicateId(_))) =>
-        sendQuestion(replyTo, request, question.withId(idGenerator.nextId()))
+        sendQuestionWithNewId(replyTo, request)
 
       case DnsQuestionAnswer(replyTo, request, question, Failure(t)) =>
-        request.resolver ! DropRequest(question)
         replyTo ! InjectedDnsQuestionAnswer(request.requestId, Failure(t))
+        dropQuestion(request.resolver, request.timeout, question)
 
       case DnsQuestionAnswer(replyTo, request, question, Success(a)) =>
-        request.resolver ! DropRequest(question)
         replyTo ! InjectedDnsQuestionAnswer(
           request.requestId,
           Failure(
             new IllegalArgumentException("Unexpected response " + a.toString + 
" of type " + a.getClass.toString)))
+        dropQuestion(request.resolver, request.timeout, question)
+
+      case Dropped(id) =>
+        activeRequestIds -= id
+
+      case DidntDrop(id) =>
+        log.warning("DNS request id [{}] could not be confirmed dropped, 
keeping it reserved", id)
     }
 
+    private def sendQuestionWithNewId(replyTo: ActorRef, request: 
DnsQuestionPreInjection): Unit =
+      nextAvailableRequestId() match {
+        case Success(id) =>
+          sendQuestion(replyTo, request, request.withId(id))
+        case Failure(t) =>
+          replyTo ! InjectedDnsQuestionAnswer(request.requestId, Failure(t))
+      }
+
     private def sendQuestion(replyTo: ActorRef, request: 
DnsQuestionPreInjection, question: DnsQuestion): Unit = {
+      activeRequestIds += question.id
       implicit val askTimeout: Timeout = request.timeout
       (request.resolver ? question).onComplete { result =>
         self ! DnsQuestionAnswer(replyTo, request, question, result)
       }
     }
+
+    @tailrec
+    private def nextAvailableRequestId(attemptsLeft: Int = 
MaxIdGenerationAttempts): Try[Short] =
+      if (attemptsLeft == 0) {
+        Failure(new IllegalStateException("No non-active DNS request id could 
be generated"))
+      } else {
+        Try(idGenerator.nextId()) match {
+          case Failure(t)  => Failure(t)
+          case Success(id) =>
+            if (activeRequestIds.contains(id)) 
nextAvailableRequestId(attemptsLeft - 1)
+            else Success(id)
+        }
+      }
+
+    private def dropQuestion(resolver: ActorRef, timeout: Timeout, question: 
DnsQuestion): Unit = {
+      implicit val askTimeout: Timeout = timeout
+      (resolver ? DropRequest(question)).map {
+        case dropped: Dropped => dropped
+        case other            =>
+          log.warning("Unexpected response [{}] when dropping DNS request id 
[{}]", other, question.id)
+          DidntDrop(question.id)
+      }.recover {
+        case NonFatal(t) =>
+          log.warning("Drop request for DNS request id [{}] failed: {}", 
question.id, t.getMessage)
+          DidntDrop(question.id)
+      }.foreach(self ! _)
+    }
   }
 
   private object DnsResolutionActor {
diff --git 
a/actor/src/main/scala/org/apache/pekko/io/dns/internal/DnsClient.scala 
b/actor/src/main/scala/org/apache/pekko/io/dns/internal/DnsClient.scala
index 07290c8d6a..72b6619c53 100644
--- a/actor/src/main/scala/org/apache/pekko/io/dns/internal/DnsClient.scala
+++ b/actor/src/main/scala/org/apache/pekko/io/dns/internal/DnsClient.scala
@@ -24,6 +24,7 @@ import org.apache.pekko
 import pekko.actor.{ Actor, ActorLogging, ActorRef, 
NoSerializationVerificationNeeded, Props, Stash }
 import pekko.actor.Status.Failure
 import pekko.annotation.InternalApi
+import pekko.event.{ LogMarker, Logging }
 import pekko.io.{ IO, Tcp, Udp }
 import pekko.io.dns.{ RecordClass, RecordType, ResourceRecord }
 import pekko.pattern.{ BackoffOpts, BackoffSupervisor }
@@ -51,6 +52,7 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor }
 
   final case class DuplicateId(id: Short) extends 
NoSerializationVerificationNeeded
   final case class DropRequest(question: DnsQuestion)
+  final case class Dropped(id: Short) extends NoSerializationVerificationNeeded
 }
 
 /**
@@ -64,6 +66,8 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor }
   val udp = IO(Udp)
   val tcp = IO(Tcp)
 
+  private val securityLog = Logging.withMarker(this)
+
   private[internal] var inflightRequests: Map[Short, (ActorRef, Message)] = 
Map.empty
 
   lazy val tcpDnsClient: ActorRef = createTcpClient()
@@ -95,14 +99,17 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor }
   @nowarn()
   def ready(socket: ActorRef): Receive = {
     case DropRequest(msg) =>
-      inflightRequests.get(msg.id).foreach {
-        case (_, orig) if Seq(msg.name) == orig.questions.map(_.name) =>
+      inflightRequests.get(msg.id) match {
+        case Some((_, orig)) if isSameQuestion(Seq(question(msg)), 
orig.questions) =>
           log.debug("Dropping request [{}]", msg.id)
           inflightRequests -= msg.id
-        case (_, orig) =>
+          sender() ! Dropped(msg.id)
+        case Some((_, orig)) =>
           log.warning("Cannot drop inflight DNS request the question [{}] does 
not match [{}]",
-            msg.name,
-            orig.questions.map(_.name).mkString(","))
+            question(msg),
+            orig.questions.mkString(","))
+        case None =>
+          sender() ! Dropped(msg.id)
       }
 
     case Question4(id, name) =>
@@ -149,6 +156,13 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor }
         case _ =>
           log.warning("Dns client failed to send {}", cmd)
       }
+    case Udp.Received(_, remote) if !isExpectedRemote(remote) =>
+      securityLog.warning(
+        LogMarker.Security,
+        "Ignoring DNS response from [{}], expected [{}]",
+        remote,
+        ns)
+
     case Udp.Received(data, remote) =>
       log.debug("Received message from [{}]: [{}]", remote, data)
       val msg = Message.parse(data)
@@ -219,6 +233,21 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor }
     impl(q1s.sortBy(_.name).toList, q2s.sortBy(_.name).toList)
   }
 
+  private def question(msg: DnsQuestion): Question =
+    msg match {
+      case Question4(_, name)   => Question(name, RecordType.A, RecordClass.IN)
+      case Question6(_, name)   => Question(name, RecordType.AAAA, 
RecordClass.IN)
+      case SrvQuestion(_, name) => Question(name, RecordType.SRV, 
RecordClass.IN)
+    }
+
+  private def isExpectedRemote(remote: InetSocketAddress): Boolean =
+    remote == ns || {
+      remote.getPort == ns.getPort &&
+      remote.getAddress != null &&
+      ns.getAddress != null &&
+      remote.getAddress == ns.getAddress
+    }
+
   def createTcpClient() = {
     context.actorOf(
       BackoffSupervisor.props(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to