This is an automated email from the ASF dual-hosted git repository.

hepin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/pekko.git


The following commit(s) were added to refs/heads/main by this push:
     new 2939de02db fix: Fix recoverWith with add missing attempt back. (#2712)
2939de02db is described below

commit 2939de02db2a1af7b82e2763534024979bcbd25d
Author: He-Pin(kerr) <[email protected]>
AuthorDate: Thu Mar 5 04:12:23 2026 +0800

    fix: Fix recoverWith with add missing attempt back. (#2712)
    
    * Revert "Revert recent changes to recoverWith (#2674)"
    
    This reverts commit c9906c60fcfac696cf307ca0edc8a80da202dde1.
    
    * fix: Fix recoverWith with add missing attempt back.
    
    * chore: add attempt first in RecoverWith
---
 .../stream/scaladsl/FlowRecoverWithSpec.scala      | 396 ++++++++++++++++++++-
 .../org/apache/pekko/stream/impl/fusing/Ops.scala  |  42 ++-
 2 files changed, 430 insertions(+), 8 deletions(-)

diff --git a/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowRecoverWithSpec.scala b/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowRecoverWithSpec.scala
index 47f129a329..a0cafc337c 100644
--- a/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowRecoverWithSpec.scala
+++ b/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/FlowRecoverWithSpec.scala
@@ -14,7 +14,8 @@
 package org.apache.pekko.stream.scaladsl
 
 import scala.annotation.nowarn
-import scala.concurrent.Future
+import scala.concurrent.{ Future, Promise }
+import scala.concurrent.duration._
 import scala.util.control.NoStackTrace
 
 import org.apache.pekko
@@ -251,6 +252,102 @@ class FlowRecoverWithSpec extends StreamSpec {
         .expectError(ex)
     }
 
+    "terminate with exception after set number of retries with failed source" 
in { // the recovery source always fails, so the stream must still end with ex once the 3 retries run out
+      Source
+        .failed[Int](ex)
+        .recoverWithRetries(3,
+          {
+            case _: Throwable => Source.failed(ex)
+          })
+        .runWith(TestSink[Int]())
+        .request(1)
+        .expectError(ex)
+    }
+
+    "terminate with exception after set number of retries with failed future 
source" in { // same as above, but the recovery source wraps an already-failed Future
+      Source
+        .failed[Int](ex)
+        .recoverWithRetries(3,
+          {
+            case _: Throwable => Source.future(Future.failed(ex))
+          })
+        .runWith(TestSink[Int]())
+        .request(1)
+        .expectError(ex)
+    }
+
+    "count retries correctly with failed source" in { // pf runs 5 times: 4 calls return Source.failed, the 5th returns Source.single(42)
+      val counter = new java.util.concurrent.atomic.AtomicInteger(0)
+      Source
+        .failed[Int](ex)
+        .recoverWithRetries(5,
+          {
+            case _: Throwable =>
+              if (counter.incrementAndGet() < 5) {
+                Source.failed(ex)
+              } else {
+                Source.single(42)
+              }
+          })
+        .runWith(TestSink[Int]())
+        .request(1)
+        .expectNext(42)
+        .expectComplete()
+      counter.get() shouldBe 5
+    }
+
+    "count retries correctly with failed future source" in { // as above, but the 4 failing recoveries are Source.future of a failed Future
+      val counter = new java.util.concurrent.atomic.AtomicInteger(0)
+      Source
+        .failed[Int](ex)
+        .recoverWithRetries(5,
+          {
+            case _: Throwable =>
+              if (counter.incrementAndGet() < 5) {
+                Source.future(Future.failed(ex))
+              } else {
+                Source.single(42)
+              }
+          })
+        .runWith(TestSink[Int]())
+        .request(1)
+        .expectNext(42)
+        .expectComplete()
+      counter.get() shouldBe 5
+    }
+
+    "exhaust retries with failed source and then fail" in { // every recovery fails, so the pf must run exactly 3 times before ex surfaces
+      val counter = new java.util.concurrent.atomic.AtomicInteger(0)
+      Source
+        .failed[Int](ex)
+        .recoverWithRetries(3,
+          {
+            case _: Throwable =>
+              counter.incrementAndGet()
+              Source.failed(ex)
+          })
+        .runWith(TestSink[Int]())
+        .request(1)
+        .expectError(ex)
+      counter.get() shouldBe 3
+    }
+
+    "exhaust retries with failed future source and then fail" in { // as above, with Source.future of a failed Future as the recovery
+      val counter = new java.util.concurrent.atomic.AtomicInteger(0)
+      Source
+        .failed[Int](ex)
+        .recoverWithRetries(3,
+          {
+            case _: Throwable =>
+              counter.incrementAndGet()
+              Source.future(Future.failed(ex))
+          })
+        .runWith(TestSink[Int]())
+        .request(1)
+        .expectError(ex)
+      counter.get() shouldBe 3
+    }
+
     "not attempt recovering when attempts is zero" in {
       Source(1 to 3)
         .map { a =>
@@ -310,5 +407,302 @@ class FlowRecoverWithSpec extends StreamSpec {
       result.failed.futureValue should ===(matFail)
 
     }
+
+    "fail when failed source carries an exception not matched by the partial function" in {
+      // ex2 must NOT be a RuntimeException: the pf below matches RuntimeException, and a
+      // RuntimeException subtype (e.g. IllegalArgumentException) would be recovered again until
+      // the retries run out, so the "not matched" path would never be exercised.
+      val ex2 = new Exception("ex2") with NoStackTrace
+      val counter = new java.util.concurrent.atomic.AtomicInteger(0)
+      Source
+        .failed[Int](ex)
+        .recoverWithRetries(3,
+          {
+            case _: RuntimeException =>
+              counter.incrementAndGet()
+              Source.failed(ex2)
+          })
+        .runWith(TestSink[Int]())
+        .request(1)
+        .expectError(ex2)
+      // the pf runs once (for the original ex); ex2 is not matched, so no further retry happens
+      counter.get() shouldBe 1
+    }
+
+    "allow exactly one retry with recoverWithRetries(1, ...)" in { // with maximumRetries = 1, one successful recovery must still be allowed
+      val counter = new java.util.concurrent.atomic.AtomicInteger(0)
+      Source
+        .failed[Int](ex)
+        .recoverWithRetries(1,
+          {
+            case _: Throwable =>
+              counter.incrementAndGet()
+              Source.single(42)
+          })
+        .runWith(TestSink[Int]())
+        .request(1)
+        .expectNext(42)
+        .expectComplete()
+      counter.get() shouldBe 1
+    }
+
+    "fail after exactly one retry with failed source and recoverWithRetries(1, 
...)" in { // the failing recovery uses up the only retry, so ex propagates and the pf ran once
+      val counter = new java.util.concurrent.atomic.AtomicInteger(0)
+      Source
+        .failed[Int](ex)
+        .recoverWithRetries(1,
+          {
+            case _: Throwable =>
+              counter.incrementAndGet()
+              Source.failed(ex)
+          })
+        .runWith(TestSink[Int]())
+        .request(1)
+        .expectError(ex)
+      counter.get() shouldBe 1
+    }
+
+    "recover with iterable source from failed source" in { // Source(List(...)) as the recovery must emit all its elements, then complete
+      Source
+        .failed[Int](ex)
+        .recoverWith { case _: Throwable => Source(List(1, 2, 3)) }
+        .runWith(TestSink[Int]())
+        .request(4)
+        .expectNextN(List(1, 2, 3))
+        .expectComplete()
+    }
+
+    "terminate after set number of retries with iterable source" in { // NOTE(review): the recovery is Source(List(...)).map(...), not a bare iterable source, so it probably takes the generic switchTo path rather than an iterable fast path
+      Source(1 to 3)
+        .map { a =>
+          if (a == 3) throw ex else a
+        }
+        .recoverWithRetries(2, { case _: Throwable => Source(List(11, 22, 
33)).map(m => if (m == 33) throw ex else m) })
+        .runWith(TestSink[Int]())
+        .request(100)
+        .expectNextN(List(1, 2))
+        .expectNextN(List(11, 22))
+        .expectNextN(List(11, 22))
+        .expectError(ex)
+    }
+
+    "recover with a pending future source" in { // the recovery Future is not yet completed, so nothing may be emitted until the promise is fulfilled
+      val promise = Promise[Int]()
+      val probe = Source
+        .failed[Int](ex)
+        .recoverWith { case _: Throwable => Source.future(promise.future) }
+        .runWith(TestSink[Int]())
+      probe.request(1)
+      probe.expectNoMessage(200.millis)
+      promise.success(42)
+      probe
+        .expectNext(42)
+        .expectComplete()
+    }
+
+    "recover infinitely with failed source when negative (-1) number of 
attempts given" in { // -1 means no retry limit: 99 failed sources precede the successful Source.single(42)
+      val counter = new java.util.concurrent.atomic.AtomicInteger(0)
+      Source
+        .failed[Int](ex)
+        .recoverWithRetries(-1,
+          {
+            case _: Throwable =>
+              if (counter.incrementAndGet() < 100) {
+                Source.failed(ex)
+              } else {
+                Source.single(42)
+              }
+          })
+        .runWith(TestSink[Int]())
+        .request(1)
+        .expectNext(42)
+        .expectComplete()
+      counter.get() shouldBe 100
+    }
+
+    "recover infinitely with failed future source when negative (-1) number of 
attempts given" in { // as above, with Source.future of a failed Future for the 99 failing recoveries
+      val counter = new java.util.concurrent.atomic.AtomicInteger(0)
+      Source
+        .failed[Int](ex)
+        .recoverWithRetries(-1,
+          {
+            case _: Throwable =>
+              if (counter.incrementAndGet() < 100) {
+                Source.future(Future.failed(ex))
+              } else {
+                Source.single(42)
+              }
+          })
+        .runWith(TestSink[Int]())
+        .request(1)
+        .expectNext(42)
+        .expectComplete()
+      counter.get() shouldBe 100
+    }
+
+    "count retries correctly with mixed failed source and failed future 
source" in { // odd calls return Source.failed, even calls a failed Source.future; the 5th call succeeds
+      val counter = new java.util.concurrent.atomic.AtomicInteger(0)
+      Source
+        .failed[Int](ex)
+        .recoverWithRetries(5,
+          {
+            case _: Throwable =>
+              val count = counter.incrementAndGet()
+              if (count < 5) {
+                if (count % 2 == 1) Source.failed(ex)
+                else Source.future(Future.failed(ex))
+              } else {
+                Source.single(42)
+              }
+          })
+        .runWith(TestSink[Int]())
+        .request(1)
+        .expectNext(42)
+        .expectComplete()
+      counter.get() shouldBe 5
+    }
+
+    "recover with failed source then iterable source" in { // two failed recoveries, then Source(List(10, 20, 30)) on the third pf call
+      val counter = new java.util.concurrent.atomic.AtomicInteger(0)
+      Source
+        .failed[Int](ex)
+        .recoverWith {
+          case _: Throwable =>
+            if (counter.incrementAndGet() < 3) {
+              Source.failed(ex)
+            } else {
+              Source(List(10, 20, 30))
+            }
+        }
+        .runWith(TestSink[Int]())
+        .request(4)
+        .expectNextN(List(10, 20, 30))
+        .expectComplete()
+      counter.get() shouldBe 3
+    }
+
+    "recover with failed source then java stream source" in { // two failed recoveries, then Source.fromJavaStream; NOTE(review): does not check that the Java stream is closed -- the Ops.scala fast path calls open() with no visible close, confirm
+      val counter = new java.util.concurrent.atomic.AtomicInteger(0)
+      Source
+        .failed[Int](ex)
+        .recoverWith {
+          case _: Throwable =>
+            if (counter.incrementAndGet() < 3) {
+              Source.failed(ex)
+            } else {
+              Source.fromJavaStream(() => java.util.stream.Stream.of(10, 20))
+            }
+        }
+        .runWith(TestSink[Int]())
+        .request(3)
+        .expectNextN(List(10, 20))
+        .expectComplete()
+      counter.get() shouldBe 3
+    }
+
+    "recover with failed source then empty source" in { // an empty recovery source must complete the stream without emitting anything
+      val counter = new java.util.concurrent.atomic.AtomicInteger(0)
+      Source
+        .failed[Int](ex)
+        .recoverWith {
+          case _: Throwable =>
+            if (counter.incrementAndGet() < 3) {
+              Source.failed(ex)
+            } else {
+              Source.empty
+            }
+        }
+        .runWith(TestSink[Int]())
+        .request(1)
+        .expectComplete()
+      counter.get() shouldBe 3
+    }
+
+    "recover with failed source then single source" in { // two failed recoveries, then Source.single(99) on the third pf call
+      val counter = new java.util.concurrent.atomic.AtomicInteger(0)
+      Source
+        .failed[Int](ex)
+        .recoverWith {
+          case _: Throwable =>
+            if (counter.incrementAndGet() < 3) {
+              Source.failed(ex)
+            } else {
+              Source.single(99)
+            }
+        }
+        .runWith(TestSink[Int]())
+        .request(1)
+        .expectNext(99)
+        .expectComplete()
+      counter.get() shouldBe 3
+    }
+
+    "recover with failed source then completed future source" in { // two failed recoveries, then Source.future of an already-successful Future
+      val counter = new java.util.concurrent.atomic.AtomicInteger(0)
+      Source
+        .failed[Int](ex)
+        .recoverWith {
+          case _: Throwable =>
+            if (counter.incrementAndGet() < 3) {
+              Source.failed(ex)
+            } else {
+              Source.future(Future.successful(77))
+            }
+        }
+        .runWith(TestSink[Int]())
+        .request(1)
+        .expectNext(77)
+        .expectComplete()
+      counter.get() shouldBe 3
+    }
+
+    "recover with failed source carrying different exception types" in { // ex2 is an IllegalStateException (still a RuntimeException), so the pf keeps matching it
+      val ex2 = new IllegalStateException("ex2") with NoStackTrace
+      val counter = new java.util.concurrent.atomic.AtomicInteger(0)
+      Source
+        .failed[Int](ex)
+        .recoverWithRetries(5,
+          {
+            case _: RuntimeException =>
+              if (counter.incrementAndGet() < 3) {
+                Source.failed(ex2)
+              } else {
+                Source.single(42)
+              }
+          })
+        .runWith(TestSink[Int]())
+        .request(1)
+        .expectNext(42)
+        .expectComplete()
+      counter.get() shouldBe 3
+    }
+
+    "recover on a Flow" in { // recoverWith applied inside a Flow: elements before the failure pass through, then the recovery element
+      Source(1 to 4)
+        .via(
+          Flow[Int]
+            .map { a =>
+              if (a == 3) throw ex else a
+            }
+            .recoverWith { case _: Throwable => Source.single(99) })
+        .runWith(TestSink[Int]())
+        .request(3)
+        .expectNextN(List(1, 2, 99))
+        .expectComplete()
+    }
+
+    "recover on a Flow with failed source retries" in { // failed-source retries are counted the same way when recoverWithRetries is part of a Flow
+      val counter = new java.util.concurrent.atomic.AtomicInteger(0)
+      Source
+        .failed[Int](ex)
+        .via(
+          Flow[Int]
+            .recoverWithRetries(5,
+              {
+                case _: Throwable =>
+                  if (counter.incrementAndGet() < 3) {
+                    Source.failed(ex)
+                  } else {
+                    Source.single(42)
+                  }
+              }))
+        .runWith(TestSink[Int]())
+        .request(1)
+        .expectNext(42)
+        .expectComplete()
+      counter.get() shouldBe 3
+    }
   }
 }
diff --git a/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/Ops.scala b/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/Ops.scala
index 0635e52d51..d214479d4e 100644
--- a/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/Ops.scala
+++ b/stream/src/main/scala/org/apache/pekko/stream/impl/fusing/Ops.scala
@@ -36,9 +36,16 @@ import pekko.stream.Attributes.{ InputBuffer, LogLevels }
 import pekko.stream.Attributes.SourceLocation
 import pekko.stream.OverflowStrategies._
 import pekko.stream.Supervision.Decider
-import pekko.stream.impl.{ Buffer => BufferImpl, ContextPropagation, 
ReactiveStreamsCompliance, TraversalBuilder }
+import pekko.stream.impl.{
+  Buffer => BufferImpl,
+  ContextPropagation,
+  FailedSource,
+  JavaStreamSource,
+  ReactiveStreamsCompliance,
+  TraversalBuilder
+}
 import pekko.stream.impl.Stages.DefaultAttributes
-import pekko.stream.impl.fusing.GraphStages.SimpleLinearGraphStage
+import pekko.stream.impl.fusing.GraphStages.{ FutureSource, 
SimpleLinearGraphStage, SingleSource }
 import pekko.stream.scaladsl.{
   DelayStrategy,
   Source,
@@ -2162,6 +2169,7 @@ private[pekko] object TakeWithin {
       override def onPull(): Unit = pull(in)
 
       @nowarn("msg=Any")
+      @tailrec
       def onFailure(ex: Throwable): Unit = {
         import Collect.NotApplied
         if (maximumRetries < 0 || attempt < maximumRetries) {
@@ -2169,13 +2177,33 @@ private[pekko] object TakeWithin {
             case _: NotApplied.type                                            
                                   => failStage(ex)
             case source: Graph[SourceShape[T] @unchecked, M @unchecked] if 
TraversalBuilder.isEmptySource(source) =>
               completeStage()
-            case other: Graph[SourceShape[T] @unchecked, M @unchecked] =>
-              TraversalBuilder.getSingleSource(other) match {
-                case OptionVal.Some(singleSource) =>
-                  emit(out, singleSource.elem.asInstanceOf[T], () => 
completeStage())
+            case source: Graph[SourceShape[T] @unchecked, M @unchecked] =>
+              TraversalBuilder.getValuePresentedSource(source) match {
+                case OptionVal.Some(graph) => graph match {
+                    case singleSource: SingleSource[T @unchecked] => emit(out, 
singleSource.elem, () => completeStage())
+                    case failed: FailedSource[T @unchecked]       =>
+                      attempt += 1
+                      onFailure(failed.failure)
+                    case futureSource: FutureSource[T @unchecked] => 
futureSource.future.value match {
+                        case Some(Success(elem)) => emit(out, elem, () => 
completeStage())
+                        case Some(Failure(ex))   =>
+                          attempt += 1
+                          onFailure(ex)
+                        case None =>
+                          attempt += 1
+                          switchTo(source)
+                      }
+                    case iterableSource: IterableSource[T @unchecked] =>
+                      emitMultiple(out, iterableSource.elements, () => 
completeStage())
+                    case javaStreamSource: JavaStreamSource[T @unchecked, _] =>
+                      emitMultiple(out, javaStreamSource.open().spliterator(), 
() => completeStage())
+                    case _ =>
+                      attempt += 1
+                      switchTo(source)
+                  }
                 case _ =>
-                  switchTo(other)
                   attempt += 1
+                  switchTo(source)
               }
             case _ => throw new IllegalStateException() // won't happen, 
compiler exhaustiveness check pleaser
           }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to