This is an automated email from the ASF dual-hosted git repository.
He-Pin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/pekko.git
The following commit(s) were added to refs/heads/main by this push:
new 52747cdfee fix: harden async DNS request id handling (#2960)
52747cdfee is described below
commit 52747cdfeed71304a32ad660008051a3e15473ff
Author: He-Pin(kerr) <[email protected]>
AuthorDate: Sun May 10 03:51:52 2026 +0800
fix: harden async DNS request id handling (#2960)
Motivation:
The async DNS resolver should avoid reusing request ids that are still
active, should only release failed ids after the DNS client confirms the drop,
and should ignore UDP DNS responses from unexpected remotes.
Modification:
Track active request ids in the request-id injector, keep duplicate ids
reserved, add an explicit Dropped acknowledgement for DNS client drops, match
drops by the full DNS question including record type, and reject UDP responses
from non-nameserver remotes with a security marker.
Result:
Concurrent async DNS requests avoid active id reuse, failed requests clean
up ids only when confirmed safe, and spoofed UDP responses from unexpected
remotes are ignored.
Tests:
- scalafmt --mode diff-ref=port-ef33e71-async-dns-id-injector / passed
- scalafmt --list --mode diff-ref=port-ef33e71-async-dns-id-injector / passed
- git diff --check / passed
- sbt "actor-tests / Test / testOnly org.apache.pekko.io.dns.internal.AsyncDnsResolverSpec org.apache.pekko.io.dns.internal.DnsClientSpec org.apache.pekko.io.dns.DnsSettingsSpec" / passed
References:
Upstream commit:
https://github.com/akka/akka-core/commit/7434794bab6fec95cb52d810d95126f90efb8c34,
which is now Apache licensed
Refs #31926
---
.../io/dns/internal/AsyncDnsResolverSpec.scala | 37 +++++++++++--
.../pekko/io/dns/internal/DnsClientSpec.scala | 55 +++++++++++++++++--
.../pekko/io/dns/internal/AsyncDnsResolver.scala | 64 ++++++++++++++++++++--
.../apache/pekko/io/dns/internal/DnsClient.scala | 39 +++++++++++--
4 files changed, 175 insertions(+), 20 deletions(-)
diff --git a/actor-tests/src/test/scala/org/apache/pekko/io/dns/internal/AsyncDnsResolverSpec.scala b/actor-tests/src/test/scala/org/apache/pekko/io/dns/internal/AsyncDnsResolverSpec.scala
index 13c40e872d..1f520a6d65 100644
--- a/actor-tests/src/test/scala/org/apache/pekko/io/dns/internal/AsyncDnsResolverSpec.scala
+++ b/actor-tests/src/test/scala/org/apache/pekko/io/dns/internal/AsyncDnsResolverSpec.scala
@@ -26,7 +26,7 @@ import pekko.io.dns.{ AAAARecord, ARecord, DnsSettings, IdGenerator, SRVRecord }
import pekko.io.dns.CachePolicy.Ttl
import pekko.io.dns.DnsProtocol._
import pekko.io.dns.internal.AsyncDnsResolver.ResolveFailedException
-import pekko.io.dns.internal.DnsClient.{ Answer, DropRequest, DuplicateId,
Question4, Question6, SrvQuestion }
+import pekko.io.dns.internal.DnsClient.{ Answer, DropRequest, Dropped,
DuplicateId, Question4, Question6, SrvQuestion }
import pekko.testkit.{ PekkoSpec, TestProbe, WithLogCapturing }
import com.typesafe.config.{ Config, ConfigFactory, ConfigValueFactory }
@@ -322,7 +322,7 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
override val r = resolver(
List(dnsClient1.ref, dnsClient2.ref),
defaultConfig,
- deterministicIds((1 to 10).map(_.toShort): _*))
+ deterministicIds(1, 2, 1, 3, 4, 5, 6, 7, 8, 9, 10))
// Send multiple resolves for different names so no in-flight
deduplication applies
val resolveCount = 10
@@ -347,8 +347,35 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
}
}
- "retry request ids that duplicate an in-flight request" in new Setup {
- override val r = resolver(List(dnsClient1.ref), defaultConfig,
deterministicIds(1, 1, 2))
+ "reuse confirmed dropped request ids" in new Setup {
+ val config = defaultConfig.withValue("resolve-timeout",
ConfigValueFactory.fromAnyRef("200 ms"))
+ override val r = resolver(List(dnsClient1.ref), config,
deterministicIds(1, 1, 2))
+
+ r ! Resolve("host1.cats.com", Ip(ipv4 = true, ipv6 = false))
+ val timedOutQuestion = dnsClient1.expectMsgPF() {
+ case q: Question4 if q.name == "host1.cats.com" => q
+ }
+
+ val dropped = dnsClient1.expectMsgPF(remainingOrDefault) {
+ case DropRequest(question) if question == timedOutQuestion => question
+ }
+ dnsClient1.reply(Dropped(dropped.id))
+ senderProbe.expectMsgPF(remainingOrDefault) {
+ case Failure(ResolveFailedException(_)) =>
+ }
+
+ r ! Resolve("host2.cats.com", Ip(ipv4 = true, ipv6 = false))
+ val reusedQuestion = dnsClient1.expectMsgPF() {
+ case q: Question4 if q.name == "host2.cats.com" => q
+ }
+ reusedQuestion.id shouldBe timedOutQuestion.id
+ dnsClient1.reply(Answer(reusedQuestion.id, im.Seq.empty))
+
+ senderProbe.expectMsg(Resolved("host2.cats.com", im.Seq.empty))
+ }
+
+ "retry request ids that duplicate an untracked in-flight request" in new
Setup {
+ override val r = resolver(List(dnsClient1.ref), defaultConfig,
deterministicIds(1, 1, 2, 2, 3))
val asker1 = TestProbe()
val asker2 = TestProbe()
@@ -361,7 +388,7 @@ class AsyncDnsResolverSpec extends PekkoSpec("""
r.tell(Resolve("host2.cats.com", Ip(ipv4 = true, ipv6 = false)),
asker2.ref)
val duplicatedQuestion = dnsClient1.expectMsgPF() {
- case q: Question4 if q.name == "host2.cats.com" && q.id ==
firstQuestion.id => q
+ case q: Question4 if q.name == "host2.cats.com" && q.id !=
firstQuestion.id => q
}
dnsClient1.reply(DuplicateId(duplicatedQuestion.id))
diff --git a/actor-tests/src/test/scala/org/apache/pekko/io/dns/internal/DnsClientSpec.scala b/actor-tests/src/test/scala/org/apache/pekko/io/dns/internal/DnsClientSpec.scala
index 8c8e9bb4de..c84dc96375 100644
--- a/actor-tests/src/test/scala/org/apache/pekko/io/dns/internal/DnsClientSpec.scala
+++ b/actor-tests/src/test/scala/org/apache/pekko/io/dns/internal/DnsClientSpec.scala
@@ -23,7 +23,7 @@ import org.apache.pekko.actor.Status.Failure
import pekko.actor.Props
import pekko.io.Udp
import pekko.io.dns.{ ARecord, CachePolicy, RecordClass, RecordType }
-import pekko.io.dns.internal.DnsClient.{ Answer, DropRequest, DuplicateId,
Question4 }
+import pekko.io.dns.internal.DnsClient.{ Answer, DropRequest, Dropped,
DuplicateId, Question4, Question6 }
import pekko.testkit.{ ImplicitSender, PekkoSpec, TestProbe }
class DnsClientSpec extends PekkoSpec with ImplicitSender {
@@ -150,9 +150,9 @@ class DnsClientSpec extends PekkoSpec with ImplicitSender {
answerRecs = im.Seq(goodRecord))
val goodAnswer = exampleRequestMessage.copy(flags = flags, answerRecs =
im.Seq(goodRecord))
- udpSocketProbe.reply(Udp.Received(internal.ByteResponse(badId), socket))
- udpSocketProbe.reply(Udp.Received(internal.ByteResponse(badQuestion),
socket))
- udpSocketProbe.reply(Udp.Received(internal.ByteResponse(goodAnswer),
socket))
+ udpSocketProbe.reply(Udp.Received(internal.ByteResponse(badId),
dnsServerAddress))
+ udpSocketProbe.reply(Udp.Received(internal.ByteResponse(badQuestion),
dnsServerAddress))
+ udpSocketProbe.reply(Udp.Received(internal.ByteResponse(goodAnswer),
dnsServerAddress))
val answer = Answer(exampleRequest.id, im.Seq(goodRecord), im.Seq())
goodSenderProbe.expectMsg(answer)
@@ -182,6 +182,7 @@ class DnsClientSpec extends PekkoSpec with ImplicitSender {
udpSend.payload shouldBe exampleRequestMessage.write()
goodSenderProbe.send(client, DropRequest(exampleRequest.copy(id = 999)))
+ goodSenderProbe.expectMsg(Dropped(999))
// duplicate shows inflight message not deleted
goodSenderProbe.send(client, exampleRequest)
@@ -193,13 +194,59 @@ class DnsClientSpec extends PekkoSpec with ImplicitSender {
goodSenderProbe.send(client, exampleRequest)
goodSenderProbe.expectMsg(DuplicateId(exampleRequest.id))
+ goodSenderProbe.send(client, DropRequest(Question6(exampleRequest.id,
exampleRequest.name)))
+
+ // duplicate shows inflight message not deleted
+ goodSenderProbe.send(client, exampleRequest)
+ goodSenderProbe.expectMsg(DuplicateId(exampleRequest.id))
+
goodSenderProbe.send(client, DropRequest(exampleRequest))
+ goodSenderProbe.expectMsg(Dropped(exampleRequest.id))
// no duplicate shows inflight message was deleted
goodSenderProbe.send(client, exampleRequest)
goodSenderProbe.expectNoMessage()
}
+ "Ignore DNS responses from unexpected remotes" in {
+ val udpExtensionProbe = TestProbe()
+ val udpSocketProbe = TestProbe()
+ val tcpClientProbe = TestProbe()
+ val goodSenderProbe = TestProbe()
+
+ val socket = InetSocketAddress.createUnresolved("localhost", 41325)
+ val unexpectedAddress = InetSocketAddress.createUnresolved("bar", 53)
+ val goodRecord = ARecord(exampleRequest.name, CachePolicy.Ttl.never,
InetAddress.getLocalHost())
+
+ val client = system.actorOf(Props(new DnsClient(dnsServerAddress) {
+ override val udp = udpExtensionProbe.ref
+
+ override def createTcpClient() = tcpClientProbe.ref
+ }))
+
+ udpExtensionProbe.expectMsgType[Udp.Bind]
+ udpSocketProbe.send(udpExtensionProbe.lastSender, Udp.Bound(socket))
+
+ goodSenderProbe.send(client, exampleRequest)
+
+ val udpSend = udpSocketProbe.expectMsgType[Udp.Send]
+ udpSend.payload shouldBe exampleRequestMessage.write()
+
+ val flags = MessageFlags(true, authoritativeAnswer = true)
+ val goodAnswer = exampleRequestMessage.copy(flags = flags, answerRecs =
im.Seq(goodRecord))
+
+ udpSocketProbe.reply(Udp.Received(internal.ByteResponse(goodAnswer),
unexpectedAddress))
+
+ // duplicate shows unexpected remote did not complete or delete the
inflight request
+ goodSenderProbe.send(client, exampleRequest)
+ goodSenderProbe.expectMsg(DuplicateId(exampleRequest.id))
+
+ udpSocketProbe.reply(Udp.Received(internal.ByteResponse(goodAnswer),
dnsServerAddress))
+
+ val answer = Answer(exampleRequest.id, im.Seq(goodRecord), im.Seq())
+ goodSenderProbe.expectMsg(answer)
+ }
+
"Verify original question when processing UDP Failures" in {
val udpExtensionProbe = TestProbe()
val udpSocketProbe = TestProbe()
diff --git a/actor/src/main/scala/org/apache/pekko/io/dns/internal/AsyncDnsResolver.scala b/actor/src/main/scala/org/apache/pekko/io/dns/internal/AsyncDnsResolver.scala
index b3ce2f2cfe..89b2059bf5 100644
--- a/actor/src/main/scala/org/apache/pekko/io/dns/internal/AsyncDnsResolver.scala
+++ b/actor/src/main/scala/org/apache/pekko/io/dns/internal/AsyncDnsResolver.scala
@@ -15,10 +15,12 @@ package org.apache.pekko.io.dns.internal
import java.net.{ Inet4Address, Inet6Address, InetAddress, InetSocketAddress }
+import scala.annotation.tailrec
import scala.collection.immutable
import scala.concurrent.{ ExecutionContextExecutor, Future, Promise }
import scala.concurrent.ExecutionContext.parasitic
import scala.util.{ Failure, Success, Try }
+import scala.util.control.NonFatal
import org.apache.pekko
import pekko.actor.{ Actor, ActorLogging, ActorRef, ActorRefFactory,
NoSerializationVerificationNeeded, Props, Status }
@@ -218,41 +220,91 @@ private[pekko] object AsyncDnsResolver {
private final case class InjectedDnsQuestionAnswer(requestId: Long, result:
Try[Answer])
extends NoSerializationVerificationNeeded
+ private final case class DidntDrop(id: Short) extends
NoSerializationVerificationNeeded
+
private object RequestIdInjector {
+ private val MaxIdGenerationAttempts = 1 << 16
+
def props(idGenerator: IdGenerator): Props = Props(new
RequestIdInjector(idGenerator))
}
- private class RequestIdInjector(idGenerator: IdGenerator) extends Actor {
+ private class RequestIdInjector(idGenerator: IdGenerator) extends Actor with
ActorLogging {
+ import RequestIdInjector._
+
private implicit val ec: ExecutionContextExecutor = context.dispatcher
+ private var activeRequestIds = Set.empty[Short]
override def receive: Receive = {
case question: DnsQuestionPreInjection =>
- sendQuestion(sender(), question, question.withId(idGenerator.nextId()))
+ sendQuestionWithNewId(sender(), question)
- case DnsQuestionAnswer(replyTo, request, _, Success(result: Answer)) =>
+ case DnsQuestionAnswer(replyTo, request, question, Success(result:
Answer)) =>
+ activeRequestIds -= question.id
replyTo ! InjectedDnsQuestionAnswer(request.requestId, Success(result))
case DnsQuestionAnswer(replyTo, request, question,
Success(DuplicateId(_))) =>
- sendQuestion(replyTo, request, question.withId(idGenerator.nextId()))
+ sendQuestionWithNewId(replyTo, request)
case DnsQuestionAnswer(replyTo, request, question, Failure(t)) =>
- request.resolver ! DropRequest(question)
replyTo ! InjectedDnsQuestionAnswer(request.requestId, Failure(t))
+ dropQuestion(request.resolver, request.timeout, question)
case DnsQuestionAnswer(replyTo, request, question, Success(a)) =>
- request.resolver ! DropRequest(question)
replyTo ! InjectedDnsQuestionAnswer(
request.requestId,
Failure(
new IllegalArgumentException("Unexpected response " + a.toString +
" of type " + a.getClass.toString)))
+ dropQuestion(request.resolver, request.timeout, question)
+
+ case Dropped(id) =>
+ activeRequestIds -= id
+
+ case DidntDrop(id) =>
+ log.warning("DNS request id [{}] could not be confirmed dropped,
keeping it reserved", id)
}
+ private def sendQuestionWithNewId(replyTo: ActorRef, request:
DnsQuestionPreInjection): Unit =
+ nextAvailableRequestId() match {
+ case Success(id) =>
+ sendQuestion(replyTo, request, request.withId(id))
+ case Failure(t) =>
+ replyTo ! InjectedDnsQuestionAnswer(request.requestId, Failure(t))
+ }
+
private def sendQuestion(replyTo: ActorRef, request:
DnsQuestionPreInjection, question: DnsQuestion): Unit = {
+ activeRequestIds += question.id
implicit val askTimeout: Timeout = request.timeout
(request.resolver ? question).onComplete { result =>
self ! DnsQuestionAnswer(replyTo, request, question, result)
}
}
+
+ @tailrec
+ private def nextAvailableRequestId(attemptsLeft: Int =
MaxIdGenerationAttempts): Try[Short] =
+ if (attemptsLeft == 0) {
+ Failure(new IllegalStateException("No non-active DNS request id could
be generated"))
+ } else {
+ Try(idGenerator.nextId()) match {
+ case Failure(t) => Failure(t)
+ case Success(id) =>
+ if (activeRequestIds.contains(id))
nextAvailableRequestId(attemptsLeft - 1)
+ else Success(id)
+ }
+ }
+
+ private def dropQuestion(resolver: ActorRef, timeout: Timeout, question:
DnsQuestion): Unit = {
+ implicit val askTimeout: Timeout = timeout
+ (resolver ? DropRequest(question)).map {
+ case dropped: Dropped => dropped
+ case other =>
+ log.warning("Unexpected response [{}] when dropping DNS request id
[{}]", other, question.id)
+ DidntDrop(question.id)
+ }.recover {
+ case NonFatal(t) =>
+ log.warning("Drop request for DNS request id [{}] failed: {}",
question.id, t.getMessage)
+ DidntDrop(question.id)
+ }.foreach(self ! _)
+ }
}
private object DnsResolutionActor {
diff --git a/actor/src/main/scala/org/apache/pekko/io/dns/internal/DnsClient.scala b/actor/src/main/scala/org/apache/pekko/io/dns/internal/DnsClient.scala
index 07290c8d6a..72b6619c53 100644
--- a/actor/src/main/scala/org/apache/pekko/io/dns/internal/DnsClient.scala
+++ b/actor/src/main/scala/org/apache/pekko/io/dns/internal/DnsClient.scala
@@ -24,6 +24,7 @@ import org.apache.pekko
import pekko.actor.{ Actor, ActorLogging, ActorRef,
NoSerializationVerificationNeeded, Props, Stash }
import pekko.actor.Status.Failure
import pekko.annotation.InternalApi
+import pekko.event.{ LogMarker, Logging }
import pekko.io.{ IO, Tcp, Udp }
import pekko.io.dns.{ RecordClass, RecordType, ResourceRecord }
import pekko.pattern.{ BackoffOpts, BackoffSupervisor }
@@ -51,6 +52,7 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor }
final case class DuplicateId(id: Short) extends
NoSerializationVerificationNeeded
final case class DropRequest(question: DnsQuestion)
+ final case class Dropped(id: Short) extends NoSerializationVerificationNeeded
}
/**
@@ -64,6 +66,8 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor }
val udp = IO(Udp)
val tcp = IO(Tcp)
+ private val securityLog = Logging.withMarker(this)
+
private[internal] var inflightRequests: Map[Short, (ActorRef, Message)] =
Map.empty
lazy val tcpDnsClient: ActorRef = createTcpClient()
@@ -95,14 +99,17 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor }
@nowarn()
def ready(socket: ActorRef): Receive = {
case DropRequest(msg) =>
- inflightRequests.get(msg.id).foreach {
- case (_, orig) if Seq(msg.name) == orig.questions.map(_.name) =>
+ inflightRequests.get(msg.id) match {
+ case Some((_, orig)) if isSameQuestion(Seq(question(msg)),
orig.questions) =>
log.debug("Dropping request [{}]", msg.id)
inflightRequests -= msg.id
- case (_, orig) =>
+ sender() ! Dropped(msg.id)
+ case Some((_, orig)) =>
log.warning("Cannot drop inflight DNS request the question [{}] does
not match [{}]",
- msg.name,
- orig.questions.map(_.name).mkString(","))
+ question(msg),
+ orig.questions.mkString(","))
+ case None =>
+ sender() ! Dropped(msg.id)
}
case Question4(id, name) =>
@@ -149,6 +156,13 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor }
case _ =>
log.warning("Dns client failed to send {}", cmd)
}
+ case Udp.Received(_, remote) if !isExpectedRemote(remote) =>
+ securityLog.warning(
+ LogMarker.Security,
+ "Ignoring DNS response from [{}], expected [{}]",
+ remote,
+ ns)
+
case Udp.Received(data, remote) =>
log.debug("Received message from [{}]: [{}]", remote, data)
val msg = Message.parse(data)
@@ -219,6 +233,21 @@ import pekko.pattern.{ BackoffOpts, BackoffSupervisor }
impl(q1s.sortBy(_.name).toList, q2s.sortBy(_.name).toList)
}
+ private def question(msg: DnsQuestion): Question =
+ msg match {
+ case Question4(_, name) => Question(name, RecordType.A, RecordClass.IN)
+ case Question6(_, name) => Question(name, RecordType.AAAA,
RecordClass.IN)
+ case SrvQuestion(_, name) => Question(name, RecordType.SRV,
RecordClass.IN)
+ }
+
+ private def isExpectedRemote(remote: InetSocketAddress): Boolean =
+ remote == ns || {
+ remote.getPort == ns.getPort &&
+ remote.getAddress != null &&
+ ns.getAddress != null &&
+ remote.getAddress == ns.getAddress
+ }
+
def createTcpClient() = {
context.actorOf(
BackoffSupervisor.props(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]