This is an automated email from the ASF dual-hosted git repository.

bharathkk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/samza.git


The following commit(s) were added to refs/heads/master by this push:
     new 66495b677 SAMZA-2796: Introduce config knob for framework thread sub 
DAG execution (#1691)
66495b677 is described below

commit 66495b677a728ff75a8674b217672cd51aece640
Author: Bharath Kumarasubramanian <[email protected]>
AuthorDate: Tue Nov 21 11:56:08 2023 -0800

    SAMZA-2796: Introduce config knob for framework thread sub DAG execution 
(#1691)
    
    Description
    As part of SAMZA-2781, we use the framework thread pool to perform message
hand-offs and sub-DAG execution. This change adds a config knob that lets users
opt in to the feature instead of having it enabled by default.
    
    Changes
    Introduce a config knob to opt in to using the framework executor
    
    Tests
    Added unit tests
    
    Usage Instructions
    Refer to the configuration documentation. To enable the framework thread pool
for sub-DAG execution and message hand-off, set
job.operator.framework.executor.enabled to true.
---
 .../versioned/jobs/configuration-table.html        |  10 ++
 .../java/org/apache/samza/config/JobConfig.java    |   8 ++
 .../apache/samza/operators/impl/OperatorImpl.java  |  92 ++++++++++-------
 .../org/apache/samza/container/TaskInstance.scala  |   6 +-
 .../samza/operators/impl/TestOperatorImpl.java     | 112 ++++++++++++++++++---
 5 files changed, 175 insertions(+), 53 deletions(-)

diff --git a/docs/learn/documentation/versioned/jobs/configuration-table.html 
b/docs/learn/documentation/versioned/jobs/configuration-table.html
index e00c983d8..390be0376 100644
--- a/docs/learn/documentation/versioned/jobs/configuration-table.html
+++ b/docs/learn/documentation/versioned/jobs/configuration-table.html
@@ -494,6 +494,16 @@
                     </td>
                 </tr>
 
+                <tr>
+                    <td class="property" 
id="job.operator.framework.executor.enabled">job.operator.framework.executor.enabled</td>
+                    <td class="default">false</td>
+                    <td class="description">
+                        If enabled, framework thread pool will be used for 
message hand off and sub DAG execution. Otherwise, the
+                        execution will fall back to using caller thread or 
java fork join pool depending on the type of work
+                        chained as part of message hand off.
+                    </td>
+                </tr>
+
                 <tr>
                                               <!-- change link to StandAlone 
design/tutorial doc. SAMZA-1299 -->
                 <th colspan="3" class="section" id="ZkBasedJobCoordination"><a 
href="../index.html">Zookeeper-based job configuration</a></th>
diff --git a/samza-core/src/main/java/org/apache/samza/config/JobConfig.java 
b/samza-core/src/main/java/org/apache/samza/config/JobConfig.java
index 3d0b53262..17f527252 100644
--- a/samza-core/src/main/java/org/apache/samza/config/JobConfig.java
+++ b/samza-core/src/main/java/org/apache/samza/config/JobConfig.java
@@ -197,6 +197,10 @@ public class JobConfig extends MapConfig {
   public static final String JOB_ELASTICITY_FACTOR = "job.elasticity.factor";
   public static final int DEFAULT_JOB_ELASTICITY_FACTOR = 1;
 
+  public static final String JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED = 
"job.operator.framework.executor.enabled";
+
+  public static final boolean DEFAULT_JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED 
= false;
+
   public JobConfig(Config config) {
     super(config);
   }
@@ -528,4 +532,8 @@ public class JobConfig extends MapConfig {
   public String getCoordinatorExecuteCommand() {
     return get(COORDINATOR_EXECUTE_COMMAND, 
DEFAULT_COORDINATOR_EXECUTE_COMMAND);
   }
+
+  public boolean getOperatorFrameworkExecutorEnabled() {
+    return getBoolean(JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED, 
DEFAULT_JOB_OPERATOR_FRAMEWORK_EXECUTOR_ENABLED);
+  }
 }
\ No newline at end of file
diff --git 
a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java 
b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
index 8b477d42d..c870264e9 100644
--- a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
+++ b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
@@ -22,6 +22,8 @@ import com.google.common.annotations.VisibleForTesting;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CompletionStage;
 import java.util.concurrent.ExecutorService;
+import java.util.function.Consumer;
+import java.util.function.Function;
 import org.apache.samza.SamzaException;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.JobConfig;
@@ -95,6 +97,7 @@ public abstract class OperatorImpl<M, RM> {
   private ControlMessageSender controlMessageSender;
   private int elasticityFactor;
   private ExecutorService operatorExecutor;
+  private boolean operatorExecutorEnabled;
 
   /**
    * Initialize this {@link OperatorImpl} and its user-defined functions.
@@ -136,7 +139,9 @@ public abstract class OperatorImpl<M, RM> {
     this.taskModel = taskContext.getTaskModel();
     this.callbackScheduler = taskContext.getCallbackScheduler();
     handleInit(context);
-    this.elasticityFactor = new JobConfig(config).getElasticityFactor();
+    JobConfig jobConfig = new JobConfig(config);
+    this.elasticityFactor = jobConfig.getElasticityFactor();
+    this.operatorExecutorEnabled = 
jobConfig.getOperatorFrameworkExecutorEnabled();
     this.operatorExecutor = context.getTaskContext().getOperatorExecutor();
 
     initialized = true;
@@ -192,21 +197,20 @@ public abstract class OperatorImpl<M, RM> {
               getOpImplId(), getOperatorSpec().getSourceLocation(), 
expectedType, actualType), e);
     }
 
-    CompletionStage<Void> result = 
completableResultsFuture.thenComposeAsync(results -> {
+    CompletionStage<Void> result = 
composeFutureWithExecutor(completableResultsFuture, results -> {
       long endNs = this.highResClock.nanoTime();
       this.handleMessageNs.update(endNs - startNs);
 
       return CompletableFuture.allOf(results.stream()
-          .flatMap(r -> this.registeredOperators.stream()
-            .map(op -> op.onMessageAsync(r, collector, coordinator)))
+          .flatMap(r -> this.registeredOperators.stream().map(op -> 
op.onMessageAsync(r, collector, coordinator)))
           .toArray(CompletableFuture[]::new));
-    }, operatorExecutor);
+    });
 
     WatermarkFunction watermarkFn = getOperatorSpec().getWatermarkFn();
     if (watermarkFn != null) {
       // check whether there is new watermark emitted from the user function
       Long outputWm = watermarkFn.getOutputWatermark();
-      return result.thenComposeAsync(ignored -> propagateWatermark(outputWm, 
collector, coordinator), operatorExecutor);
+      return composeFutureWithExecutor(result, ignored -> 
propagateWatermark(outputWm, collector, coordinator));
     }
 
     return result;
@@ -245,11 +249,9 @@ public abstract class OperatorImpl<M, RM> {
                 .map(op -> op.onMessageAsync(r, collector, coordinator)))
             .toArray(CompletableFuture[]::new));
 
-    return resultFuture.thenComposeAsync(x ->
-        CompletableFuture.allOf(this.registeredOperators
-            .stream()
-            .map(op -> op.onTimer(collector, coordinator))
-            .toArray(CompletableFuture[]::new)), operatorExecutor);
+    return composeFutureWithExecutor(resultFuture, x -> 
CompletableFuture.allOf(this.registeredOperators.stream()
+        .map(op -> op.onTimer(collector, coordinator))
+        .toArray(CompletableFuture[]::new)));
   }
 
   /**
@@ -315,15 +317,14 @@ public abstract class OperatorImpl<M, RM> {
       }
 
       // populate the end-of-stream through the dag
-      endOfStreamFuture = onEndOfStream(collector, coordinator)
-          .thenAcceptAsync(result -> {
-            if (eosStates.allEndOfStream()) {
-              // all inputs have been end-of-stream, shut down the task
-              LOG.info("All input streams have reached the end for task {}", 
taskName.getTaskName());
-              coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
-              coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
-            }
-          }, operatorExecutor);
+      endOfStreamFuture = acceptFutureWithExecutor(onEndOfStream(collector, 
coordinator), result -> {
+        if (eosStates.allEndOfStream()) {
+          // all inputs have been end-of-stream, shut down the task
+          LOG.info("All input streams have reached the end for task {}", 
taskName.getTaskName());
+          coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+          coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        }
+      });
     }
 
     return endOfStreamFuture;
@@ -347,10 +348,10 @@ public abstract class OperatorImpl<M, RM> {
                   .map(op -> op.onMessageAsync(r, collector, coordinator)))
               .toArray(CompletableFuture[]::new));
 
-      endOfStreamFuture = resultFuture.thenComposeAsync(x ->
-          CompletableFuture.allOf(this.registeredOperators.stream()
+      endOfStreamFuture = composeFutureWithExecutor(resultFuture, x -> 
CompletableFuture.allOf(
+          this.registeredOperators.stream()
               .map(op -> op.onEndOfStream(collector, coordinator))
-              .toArray(CompletableFuture[]::new)), operatorExecutor);
+              .toArray(CompletableFuture[]::new)));
     }
 
     return endOfStreamFuture;
@@ -406,15 +407,14 @@ public abstract class OperatorImpl<M, RM> {
         controlMessageSender.broadcastToOtherPartitions(new 
DrainMessage(drainMessage.getRunId()), ssp, collector);
       }
 
-      drainFuture = onDrainOfStream(collector, coordinator)
-          .thenAcceptAsync(result -> {
-            if (drainStates.areAllStreamsDrained()) {
-              // All input streams have been drained, shut down the task
-              LOG.info("All input streams have been drained for task {}. 
Requesting shutdown.", taskName.getTaskName());
-              coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
-              coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
-            }
-          }, operatorExecutor);
+      drainFuture = acceptFutureWithExecutor(onDrainOfStream(collector, 
coordinator), result -> {
+        if (drainStates.areAllStreamsDrained()) {
+          // All input streams have been drained, shut down the task
+          LOG.info("All input streams have been drained for task {}. 
Requesting shutdown.", taskName.getTaskName());
+          coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+          coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        }
+      });
     }
 
     return drainFuture;
@@ -439,10 +439,10 @@ public abstract class OperatorImpl<M, RM> {
               .toArray(CompletableFuture[]::new));
 
       // propagate DrainMessage to downstream operators
-      drainFuture = resultFuture.thenComposeAsync(x ->
-          CompletableFuture.allOf(this.registeredOperators.stream()
+      drainFuture = composeFutureWithExecutor(resultFuture, x -> 
CompletableFuture.allOf(
+          this.registeredOperators.stream()
               .map(op -> op.onDrainOfStream(collector, coordinator))
-              .toArray(CompletableFuture[]::new)), operatorExecutor);
+              .toArray(CompletableFuture[]::new)));
     }
 
     return drainFuture;
@@ -474,8 +474,8 @@ public abstract class OperatorImpl<M, RM> {
         controlMessageSender.broadcastToOtherPartitions(new 
WatermarkMessage(watermark), ssp, collector);
       }
       // populate the watermark through the dag
-      watermarkFuture = onWatermark(watermark, collector, coordinator)
-          .thenAcceptAsync(ignored -> 
watermarkStates.updateAggregateMetric(ssp, watermark), operatorExecutor);
+      watermarkFuture = acceptFutureWithExecutor(onWatermark(watermark, 
collector, coordinator),
+        ignored -> watermarkStates.updateAggregateMetric(ssp, watermark));
     }
 
     return watermarkFuture;
@@ -530,8 +530,8 @@ public abstract class OperatorImpl<M, RM> {
                 .toArray(CompletableFuture[]::new));
       }
 
-      watermarkFuture = watermarkFuture.thenComposeAsync(res -> 
propagateWatermark(outputWm, collector, coordinator),
-          operatorExecutor);
+      watermarkFuture =
+          composeFutureWithExecutor(watermarkFuture, res -> 
propagateWatermark(outputWm, collector, coordinator));
     }
 
     return watermarkFuture;
@@ -679,6 +679,20 @@ public abstract class OperatorImpl<M, RM> {
         .toCompletableFuture().join();
   }
 
+  @VisibleForTesting
+  final <T, U> CompletionStage<U> composeFutureWithExecutor(CompletionStage<T> 
futureToChain,
+      Function<? super T, ? extends CompletionStage<U>> fn) {
+    return operatorExecutorEnabled ? futureToChain.thenComposeAsync(fn, 
operatorExecutor)
+        : futureToChain.thenCompose(fn);
+  }
+
+  @VisibleForTesting
+  final <T> CompletionStage<Void> acceptFutureWithExecutor(CompletionStage<T> 
futureToChain,
+      Consumer<? super T> consumer) {
+    return operatorExecutorEnabled ? futureToChain.thenAcceptAsync(consumer, 
operatorExecutor)
+        : futureToChain.thenAccept(consumer);
+  }
+
   private HighResolutionClock createHighResClock(Config config) {
     MetricsConfig metricsConfig = new MetricsConfig(config);
     // The timer metrics calculation here is only enabled for debugging
diff --git 
a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala 
b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
index 285e7c877..89738e2de 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
@@ -93,9 +93,13 @@ class TaskInstance(
     val jobConfig = new JobConfig(jobContext.getConfig)
     val taskExecutorFactory = 
ReflectionUtil.getObj(jobConfig.getTaskExecutorFactory, 
classOf[TaskExecutorFactory])
 
+    var operatorExecutor = 
Option.empty[java.util.concurrent.ExecutorService].orNull
+    if (jobConfig.getOperatorFrameworkExecutorEnabled) {
+      operatorExecutor = taskExecutorFactory.getOperatorExecutor(taskName, 
jobContext.getConfig)
+    }
     new TaskContextImpl(taskModel, metrics.registry, kvStoreSupplier, 
tableManager,
       new CallbackSchedulerImpl(epochTimeScheduler), offsetManager, jobModel, 
streamMetadataCache,
-      systemStreamPartitions, 
taskExecutorFactory.getOperatorExecutor(taskName, jobContext.getConfig))
+      systemStreamPartitions, operatorExecutor)
   }
   // need separate field for this instead of using it through Context, since 
Context throws an exception if it is null
   private val applicationTaskContextOption = 
applicationTaskContextFactoryOption
diff --git 
a/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpl.java
 
b/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpl.java
index 9cb307d57..6709417b9 100644
--- 
a/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpl.java
+++ 
b/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpl.java
@@ -18,16 +18,24 @@
  */
 package org.apache.samza.operators.impl;
 
+import com.google.common.collect.ImmutableMap;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CompletionStage;
+import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
+import java.util.function.Consumer;
+import java.util.function.Function;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.MapConfig;
+import org.apache.samza.context.ContainerContext;
 import org.apache.samza.context.Context;
 import org.apache.samza.context.InternalTaskContext;
-import org.apache.samza.context.MockContext;
+import org.apache.samza.context.JobContext;
+import org.apache.samza.context.TaskContext;
 import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.metrics.Counter;
 import org.apache.samza.metrics.MetricsRegistryMap;
@@ -44,33 +52,111 @@ import org.junit.Test;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
-import static org.mockito.Matchers.anyLong;
-import static org.mockito.Matchers.anyObject;
-import static org.mockito.Matchers.anyString;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
+import static org.mockito.Matchers.*;
+import static org.mockito.Mockito.*;
 
 
 public class TestOperatorImpl {
   private Context context;
   private InternalTaskContext internalTaskContext;
 
+  private JobContext jobContext;
+
+  private TaskContext taskContext;
+
+  private ContainerContext containerContext;
+
   @Before
   public void setup() {
-    this.context = new MockContext();
+    this.context = mock(Context.class);
     this.internalTaskContext = mock(InternalTaskContext.class);
+    this.jobContext = mock(JobContext.class);
+    this.taskContext = mock(TaskContext.class);
+    this.containerContext = mock(ContainerContext.class);
     when(this.internalTaskContext.getContext()).thenReturn(this.context);
     // might be necessary in the future
     
when(this.internalTaskContext.fetchObject(EndOfStreamStates.class.getName())).thenReturn(mock(EndOfStreamStates.class));
     
when(this.internalTaskContext.fetchObject(WatermarkStates.class.getName())).thenReturn(mock(WatermarkStates.class));
-    
when(this.context.getTaskContext().getTaskMetricsRegistry()).thenReturn(new 
MetricsRegistryMap());
-    
when(this.context.getTaskContext().getTaskModel()).thenReturn(mock(TaskModel.class));
-    
when(this.context.getTaskContext().getOperatorExecutor()).thenReturn(Executors.newSingleThreadExecutor());
-    
when(this.context.getContainerContext().getContainerMetricsRegistry()).thenReturn(new
 MetricsRegistryMap());
+    when(this.context.getJobContext()).thenReturn(jobContext);
+    when(this.context.getTaskContext()).thenReturn(taskContext);
+    when(this.taskContext.getTaskMetricsRegistry()).thenReturn(new 
MetricsRegistryMap());
+    when(this.taskContext.getTaskModel()).thenReturn(mock(TaskModel.class));
+    
when(this.taskContext.getOperatorExecutor()).thenReturn(Executors.newSingleThreadExecutor());
+    when(this.context.getContainerContext()).thenReturn(containerContext);
+    when(containerContext.getContainerMetricsRegistry()).thenReturn(new 
MetricsRegistryMap());
+  }
+
+  @Test
+  public void testComposeFutureWithExecutorWithFrameworkExecutorEnabled() {
+    OperatorImpl<Object, Object> opImpl = new TestOpImpl(mock(Object.class));
+    ExecutorService mockExecutor = mock(ExecutorService.class);
+    CompletionStage<Object> mockFuture = mock(CompletionStage.class);
+    Function<Object, CompletionStage<Object>> mockFunction = 
mock(Function.class);
+
+    Config config = new 
MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "true"));
+
+    when(this.taskContext.getOperatorExecutor()).thenReturn(mockExecutor);
+    when(this.jobContext.getConfig()).thenReturn(config);
+
+    opImpl.init(this.internalTaskContext);
+    opImpl.composeFutureWithExecutor(mockFuture, mockFunction);
+
+    verify(mockFuture).thenComposeAsync(eq(mockFunction), eq(mockExecutor));
+  }
+
+  @Test
+  public void testComposeFutureWithExecutorWithFrameworkExecutorDisabled() {
+    OperatorImpl<Object, Object> opImpl = new TestOpImpl(mock(Object.class));
+    ExecutorService mockExecutor = mock(ExecutorService.class);
+    CompletionStage<Object> mockFuture = mock(CompletionStage.class);
+    Function<Object, CompletionStage<Object>> mockFunction = 
mock(Function.class);
+
+    Config config = new 
MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "false"));
+
+    when(this.taskContext.getOperatorExecutor()).thenReturn(mockExecutor);
+    when(this.jobContext.getConfig()).thenReturn(config);
+
+    opImpl.init(this.internalTaskContext);
+    opImpl.composeFutureWithExecutor(mockFuture, mockFunction);
+
+    verify(mockFuture).thenCompose(eq(mockFunction));
   }
 
+  @Test
+  public void testAcceptFutureWithExecutorWithFrameworkExecutorDisabled() {
+    OperatorImpl<Object, Object> opImpl = new TestOpImpl(mock(Object.class));
+    ExecutorService mockExecutor = mock(ExecutorService.class);
+    CompletionStage<Object> mockFuture = mock(CompletionStage.class);
+    Consumer<Object> mockConsumer = mock(Consumer.class);
+
+    Config config = new 
MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "false"));
+
+    when(this.taskContext.getOperatorExecutor()).thenReturn(mockExecutor);
+    when(this.jobContext.getConfig()).thenReturn(config);
+
+    opImpl.init(this.internalTaskContext);
+    opImpl.acceptFutureWithExecutor(mockFuture, mockConsumer);
+
+    verify(mockFuture).thenAccept(eq(mockConsumer));
+  }
+
+  @Test
+  public void testAcceptFutureWithExecutorWithFrameworkExecutorEnabled() {
+    OperatorImpl<Object, Object> opImpl = new TestOpImpl(mock(Object.class));
+    ExecutorService mockExecutor = mock(ExecutorService.class);
+    CompletionStage<Object> mockFuture = mock(CompletionStage.class);
+    Consumer<Object> mockConsumer = mock(Consumer.class);
+
+    Config config = new 
MapConfig(ImmutableMap.of("job.operator.framework.executor.enabled", "true"));
+
+    when(this.taskContext.getOperatorExecutor()).thenReturn(mockExecutor);
+    when(this.jobContext.getConfig()).thenReturn(config);
+
+    opImpl.init(this.internalTaskContext);
+    opImpl.acceptFutureWithExecutor(mockFuture, mockConsumer);
+
+    verify(mockFuture).thenAcceptAsync(eq(mockConsumer), eq(mockExecutor));
+  }
   @Test(expected = IllegalStateException.class)
   public void testMultipleInitShouldThrow() {
     OperatorImpl<Object, Object> opImpl = new TestOpImpl(mock(Object.class));

Reply via email to