This is an automated email from the ASF dual-hosted git repository.

chamikara pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 99c87a2c68b Updates the Transform Service to accept Python extra 
packages through the Java API (#28783)
99c87a2c68b is described below

commit 99c87a2c68b09020d9d5fa40da18d432b501c39a
Author: Chamikara Jayalath <[email protected]>
AuthorDate: Tue Oct 10 21:09:58 2023 -0700

    Updates the Transform Service to accept Python extra packages through the 
Java API (#28783)
    
    * Updates the Transform Service to accept Python extra packages through the 
Java API
    
    * Addressing reviewer comments
    
    * Addressing reviewer comments
---
 build.gradle.kts                                   |   2 +
 .../core/construction/TransformUpgrader.java       |   2 +-
 .../extensions/python/PythonExternalTransform.java |  39 ++---
 sdks/java/transform-service/docker-compose/.env    |   8 +
 .../docker-compose/docker-compose.yml              |   3 +-
 sdks/java/transform-service/launcher/build.gradle  |   3 +
 .../launcher/TransformServiceLauncher.java         | 135 +++++++++++++--
 .../launcher/TransformServiceLauncherTest.java     | 185 +++++++++++++++++++++
 .../sdk/transformservice/ExpansionService.java     | 114 ++++++++++++-
 .../sdk/transformservice/ExpansionServiceTest.java |   7 +-
 .../utils/transform_service_launcher.py            |  14 ++
 sdks/python/expansion-service-container/boot.go    | 102 +++++++++++-
 12 files changed, 566 insertions(+), 48 deletions(-)

diff --git a/build.gradle.kts b/build.gradle.kts
index fbea1a59b28..ea1b4e6784e 100644
--- a/build.gradle.kts
+++ b/build.gradle.kts
@@ -310,6 +310,8 @@ tasks.register("javaPreCommit") {
   dependsOn(":sdks:java:testing:test-utils:build")
   dependsOn(":sdks:java:testing:tpcds:build")
   dependsOn(":sdks:java:testing:watermarks:build")
+  dependsOn(":sdks:java:transform-service:build")
+  dependsOn(":sdks:java:transform-service:launcher:build")
 
   dependsOn(":examples:java:preCommit")
   dependsOn(":examples:java:twitter:preCommit")
diff --git 
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformUpgrader.java
 
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformUpgrader.java
index d657bb31b18..5e1609f27a3 100644
--- 
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformUpgrader.java
+++ 
b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/TransformUpgrader.java
@@ -108,7 +108,7 @@ public class TransformUpgrader implements AutoCloseable {
     } else if (options.getTransformServiceBeamVersion() != null) {
       String projectName = UUID.randomUUID().toString();
       int port = findAvailablePort();
-      service = TransformServiceLauncher.forProject(projectName, port);
+      service = TransformServiceLauncher.forProject(projectName, port, null);
       service.setBeamVersion(options.getTransformServiceBeamVersion());
 
       // Starting the transform service.
diff --git 
a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java
 
b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java
index 4a5f4f12a07..5ba3484964c 100644
--- 
a/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java
+++ 
b/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java
@@ -495,6 +495,20 @@ public class PythonExternalTransform<InputT extends 
PInput, OutputT extends POut
         boolean pythonAvailable = isPythonAvailable();
         boolean dockerAvailable = isDockerAvailable();
 
+        File requirementsFile = null;
+        if (!extraPackages.isEmpty()) {
+          requirementsFile = File.createTempFile("requirements", ".txt");
+          requirementsFile.deleteOnExit();
+          try (Writer fout =
+              new OutputStreamWriter(
+                  new FileOutputStream(requirementsFile.getAbsolutePath()), 
Charsets.UTF_8)) {
+            for (String pkg : extraPackages) {
+              fout.write(pkg);
+              fout.write('\n');
+            }
+          }
+        }
+
         // We use the transform service if either of the following is true.
         // * It was explicitly requested.
         // * Python executable is not available in the system but Docker is 
available.
@@ -514,19 +528,16 @@ public class PythonExternalTransform<InputT extends 
PInput, OutputT extends POut
               projectName,
               port);
 
-          TransformServiceLauncher service = 
TransformServiceLauncher.forProject(projectName, port);
+          String pythonRequirementsFile =
+              requirementsFile != null ? requirementsFile.getAbsolutePath() : 
null;
+          TransformServiceLauncher service =
+              TransformServiceLauncher.forProject(projectName, port, 
pythonRequirementsFile);
           service.setBeamVersion(ReleaseInfo.getReleaseInfo().getSdkVersion());
-          // TODO(https://github.com/apache/beam/issues/26833): add support 
for installing extra
-          // packages.
-          if (!extraPackages.isEmpty()) {
-            throw new RuntimeException(
-                "Transform Service does not support installing extra packages 
yet");
-          }
           try {
             // Starting the transform service.
             service.start();
             // Waiting the service to be ready.
-            service.waitTillUp(15000);
+            service.waitTillUp(-1);
             // Expanding the transform.
             output = apply(input, String.format("localhost:%s", port), 
payload);
           } finally {
@@ -539,17 +550,7 @@ public class PythonExternalTransform<InputT extends 
PInput, OutputT extends POut
           ImmutableList.Builder<String> args = ImmutableList.builder();
           args.add(
               "--port=" + port, "--fully_qualified_name_glob=*", 
"--pickle_library=cloudpickle");
-          if (!extraPackages.isEmpty()) {
-            File requirementsFile = File.createTempFile("requirements", 
".txt");
-            requirementsFile.deleteOnExit();
-            try (Writer fout =
-                new OutputStreamWriter(
-                    new FileOutputStream(requirementsFile.getAbsolutePath()), 
Charsets.UTF_8)) {
-              for (String pkg : extraPackages) {
-                fout.write(pkg);
-                fout.write('\n');
-              }
-            }
+          if (requirementsFile != null) {
             args.add("--requirements_file=" + 
requirementsFile.getAbsolutePath());
           }
           PythonService service =
diff --git a/sdks/java/transform-service/docker-compose/.env 
b/sdks/java/transform-service/docker-compose/.env
index 5de5982cfa3..ed27b267fed 100644
--- a/sdks/java/transform-service/docker-compose/.env
+++ b/sdks/java/transform-service/docker-compose/.env
@@ -12,6 +12,14 @@
 
 BEAM_VERSION=$BEAM_VERSION
 CREDENTIALS_VOLUME=$CREDENTIALS_VOLUME
+DEPENDENCIES_VOLUME=$DEPENDENCIES_VOLUME
+
+# A requirements file listing either of the following:
+# * PyPI packages
+# * Locally available packages relative to the directory provided to
+#   DEPENDENCIES_VOLUME.
+PYTHON_REQUIREMENTS_FILE_NAME=$PYTHON_REQUIREMENTS_FILE_NAME
+
 GOOGLE_APPLICATION_CREDENTIALS_FILE_NAME=application_default_credentials.json
 COMPOSE_PROJECT_NAME=apache.beam.transform.service
 TRANSFORM_SERVICE_PORT=$TRANSFORM_SERVICE_PORT
diff --git a/sdks/java/transform-service/docker-compose/docker-compose.yml 
b/sdks/java/transform-service/docker-compose/docker-compose.yml
index b685be10a32..39235533b9a 100644
--- a/sdks/java/transform-service/docker-compose/docker-compose.yml
+++ b/sdks/java/transform-service/docker-compose/docker-compose.yml
@@ -32,8 +32,9 @@ services:
   expansion-service-2:
     image: "apache/beam_python_expansion_service:${BEAM_VERSION}"
     restart: on-failure
-    command: -id expansion-service-2 -port 5001
+    command: -id expansion-service-2 -port 5001 -requirements_file 
${PYTHON_REQUIREMENTS_FILE_NAME} -dependencies_dir '/dependencies_volume'
     volumes:
       - ${CREDENTIALS_VOLUME}:/credentials_volume
+      - ${DEPENDENCIES_VOLUME}:/dependencies_volume
     environment:
       - 
GOOGLE_APPLICATION_CREDENTIALS=/credentials_volume/${GOOGLE_APPLICATION_CREDENTIALS_FILE_NAME}
diff --git a/sdks/java/transform-service/launcher/build.gradle 
b/sdks/java/transform-service/launcher/build.gradle
index 83c5d60a1ef..0952f37109e 100644
--- a/sdks/java/transform-service/launcher/build.gradle
+++ b/sdks/java/transform-service/launcher/build.gradle
@@ -45,6 +45,9 @@ dependencies {
     shadow library.java.args4j
     shadow library.java.error_prone_annotations
     permitUnusedDeclared(library.java.error_prone_annotations)
+    testImplementation library.java.junit
+    testImplementation library.java.mockito_core
+    testImplementation project(path: ":sdks:java:core")
 }
 
 sourceSets {
diff --git 
a/sdks/java/transform-service/launcher/src/main/java/org/apache/beam/sdk/transformservice/launcher/TransformServiceLauncher.java
 
b/sdks/java/transform-service/launcher/src/main/java/org/apache/beam/sdk/transformservice/launcher/TransformServiceLauncher.java
index f52fdfed710..c0a9097a762 100644
--- 
a/sdks/java/transform-service/launcher/src/main/java/org/apache/beam/sdk/transformservice/launcher/TransformServiceLauncher.java
+++ 
b/sdks/java/transform-service/launcher/src/main/java/org/apache/beam/sdk/transformservice/launcher/TransformServiceLauncher.java
@@ -17,9 +17,11 @@
  */
 package org.apache.beam.sdk.transformservice.launcher;
 
+import java.io.BufferedWriter;
 import java.io.File;
 import java.io.FileOutputStream;
 import java.io.IOException;
+import java.nio.charset.StandardCharsets;
 import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.util.ArrayList;
@@ -28,6 +30,7 @@ import java.util.List;
 import java.util.Locale;
 import java.util.Map;
 import java.util.concurrent.TimeoutException;
+import java.util.stream.Stream;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.Files;
@@ -62,9 +65,9 @@ public class TransformServiceLauncher {
   private static final int STATUS_LOGGER_WAIT_TIME = 3000;
 
   @SuppressWarnings("argument")
-  private TransformServiceLauncher(@Nullable String projectName, int port) 
throws IOException {
-    LOG.info("Initializing the Beam Transform Service {}.", projectName);
-
+  private TransformServiceLauncher(
+      @Nullable String projectName, int port, @Nullable String 
pythonRequirementsFile)
+      throws IOException {
     String tmpDirLocation = System.getProperty("java.io.tmpdir");
     // We use Docker Compose project name as the name of the temporary 
directory to isolate
     // different transform service instances that may be running in the same 
machine.
@@ -83,14 +86,14 @@ public class TransformServiceLauncher {
       ByteStreams.copy(getClass().getResourceAsStream("/.env"), fout);
     }
 
+    // Setting up the credentials directory.
     File credentialsDir = Paths.get(tmpDir, "credentials_dir").toFile();
-    LOG.info(
-        "Creating a temporary directory for storing credentials: "
-            + credentialsDir.getAbsolutePath());
-
     if (credentialsDir.exists()) {
       LOG.info("Reusing the existing credentials directory " + 
credentialsDir.getAbsolutePath());
     } else {
+      LOG.info(
+          "Creating a temporary directory for storing credentials: "
+              + credentialsDir.getAbsolutePath());
       if (!credentialsDir.mkdir()) {
         throw new IOException(
             "Could not create a temporary directory for storing credentials: "
@@ -124,10 +127,84 @@ public class TransformServiceLauncher {
       }
     }
 
+    // Setting up the dependencies directory.
+    File dependenciesDir = Paths.get(tmpDir, "dependencies_dir").toFile();
+    Path updatedRequirementsFilePath = Paths.get(dependenciesDir.toString(), 
"requirements.txt");
+    if (dependenciesDir.exists()) {
+      LOG.info("Reusing the existing dependencies directory " + 
dependenciesDir.getAbsolutePath());
+    } else {
+      LOG.info(
+          "Creating a temporary directory for storing dependencies: "
+              + dependenciesDir.getAbsolutePath());
+      if (!dependenciesDir.mkdir()) {
+        throw new IOException(
+            "Could not create a temporary directory for storing dependencies: "
+                + dependenciesDir.getAbsolutePath());
+      }
+
+      // We create a requirements file with extra dependencies.
+      // If there are no extra dependencies, we just provide an empty 
requirements file.
+      File file = updatedRequirementsFilePath.toFile();
+      if (!file.createNewFile()) {
+        throw new IOException(
+            "Could not create the new requirements file " + 
updatedRequirementsFilePath);
+      }
+
+      // Updating dependencies.
+      if (pythonRequirementsFile != null) {
+        Path requirementsFilePath = Paths.get(pythonRequirementsFile);
+        List<String> updatedLines = new ArrayList<>();
+
+        try (Stream<String> lines = 
java.nio.file.Files.lines(requirementsFilePath)) {
+          lines.forEachOrdered(
+              line -> {
+                Path dependencyFilePath = Paths.get(line);
+                if (java.nio.file.Files.exists(dependencyFilePath)) {
+                  Path fileName = dependencyFilePath.getFileName();
+                  if (fileName == null) {
+                    throw new IllegalArgumentException(
+                        "Could not determine the filename of the local 
artifact "
+                            + dependencyFilePath);
+                  }
+                  try {
+                    java.nio.file.Files.copy(
+                        dependencyFilePath,
+                        Paths.get(dependenciesDir.toString(), 
fileName.toString()));
+                  } catch (IOException e) {
+                    throw new RuntimeException(e);
+                  }
+                  updatedLines.add(fileName.toString());
+                } else {
+                  updatedLines.add(line);
+                }
+              });
+        }
+
+        try (BufferedWriter writer =
+            java.nio.file.Files.newBufferedWriter(file.toPath(), 
StandardCharsets.UTF_8)) {
+          for (String line : updatedLines) {
+            writer.write(line);
+            writer.newLine();
+          }
+          writer.flush();
+        }
+      }
+    }
+
     // Setting environment variables used by the docker-compose.yml file.
     environmentVariables.put("CREDENTIALS_VOLUME", 
credentialsDir.getAbsolutePath());
+    environmentVariables.put("DEPENDENCIES_VOLUME", 
dependenciesDir.getAbsolutePath());
     environmentVariables.put("TRANSFORM_SERVICE_PORT", String.valueOf(port));
 
+    Path updatedRequirementsFileName = 
updatedRequirementsFilePath.getFileName();
+    if (updatedRequirementsFileName == null) {
+      throw new IllegalArgumentException(
+          "Could not determine the file name of the updated requirements file "
+              + updatedRequirementsFilePath);
+    }
+    environmentVariables.put(
+        "PYTHON_REQUIREMENTS_FILE_NAME", 
updatedRequirementsFileName.toString());
+
     // Building the Docker Compose command.
     dockerComposeStartCommandPrefix.add("docker-compose");
     dockerComposeStartCommandPrefix.add("-p");
@@ -136,21 +213,37 @@ public class TransformServiceLauncher {
     dockerComposeStartCommandPrefix.add(dockerComposeFile.getAbsolutePath());
   }
 
+  /**
+   * Specifies the Beam version to get containers for the transform service.
+   *
+   * <p>Could be a release Beam version with containers in Docker Hub or an 
unreleased Beam version
+   * for which containers are available locally.
+   *
+   * @param beamVersion a Beam version to get containers from.
+   */
   public void setBeamVersion(String beamVersion) {
     environmentVariables.put("BEAM_VERSION", beamVersion);
   }
 
-  public void setPythonExtraPackages(String pythonExtraPackages) {
-    environmentVariables.put("$PYTHON_EXTRA_PACKAGES", pythonExtraPackages);
-  }
-
+  /**
+   * Initializes a client for managing transform service instances.
+   *
+   * @param projectName project name for the transform service.
+   * @param port port exposed by the transform service.
+   * @param pythonRequirementsFile a requirements file with extra dependencies 
for the Python
+   *     expansion services.
+   * @return an initialized client for managing the transform service.
+   * @throws IOException if the launcher's temporary directories or files cannot be created.
+   */
   public static synchronized TransformServiceLauncher forProject(
-      @Nullable String projectName, int port) throws IOException {
+      @Nullable String projectName, int port, @Nullable String 
pythonRequirementsFile)
+      throws IOException {
     if (projectName == null || projectName.isEmpty()) {
       projectName = DEFAULT_PROJECT_NAME;
     }
     if (!launchers.containsKey(projectName)) {
-      launchers.put(projectName, new TransformServiceLauncher(projectName, 
port));
+      launchers.put(
+          projectName, new TransformServiceLauncher(projectName, port, 
pythonRequirementsFile));
     }
     return launchers.get(projectName);
   }
@@ -200,10 +293,10 @@ public class TransformServiceLauncher {
 
   public synchronized void waitTillUp(int timeout) throws IOException, 
TimeoutException {
     timeout = timeout <= 0 ? DEFAULT_START_WAIT_TIME : timeout;
-    String statusFileName = getStatus();
 
     long startTime = System.currentTimeMillis();
     while (System.currentTimeMillis() - startTime < timeout) {
+      String statusFileName = getStatus();
       try {
         // We are just waiting for a local process. No need for exponential 
backoff.
         this.wait(1000);
@@ -226,6 +319,7 @@ public class TransformServiceLauncher {
 
   private synchronized String getStatus() throws IOException {
     File outputOverride = File.createTempFile("output_override", null);
+    outputOverride.deleteOnExit();
     runDockerComposeCommand(ImmutableList.of("ps"), outputOverride);
 
     return outputOverride.getAbsolutePath();
@@ -238,6 +332,8 @@ public class TransformServiceLauncher {
     static final String PORT_ARG_NAME = "port";
     static final String BEAM_VERSION_ARG_NAME = "beam_version";
 
+    static final String PYTHON_REQUIREMENTS_FILE_ARG_NAME = 
"python_requirements_file";
+
     @Option(name = "--" + PROJECT_NAME_ARG_NAME, usage = "Docker compose 
project name")
     private String projectName = "";
 
@@ -249,6 +345,11 @@ public class TransformServiceLauncher {
 
     @Option(name = "--" + BEAM_VERSION_ARG_NAME, usage = "Beam version to 
use.")
     private String beamVersion = "";
+
+    @Option(
+        name = "--" + PYTHON_REQUIREMENTS_FILE_ARG_NAME,
+        usage = "Extra Python packages in the form of an requirements file.")
+    private String pythonRequirementsFile = "";
   }
 
   public static void main(String[] args) throws IOException, TimeoutException {
@@ -288,8 +389,12 @@ public class TransformServiceLauncher {
                 : ("port " + Integer.toString(config.port) + ".")));
     System.out.println("===================================================");
 
+    String pythonRequirementsFile =
+        !config.pythonRequirementsFile.isEmpty() ? 
config.pythonRequirementsFile : null;
+
     TransformServiceLauncher service =
-        TransformServiceLauncher.forProject(config.projectName, config.port);
+        TransformServiceLauncher.forProject(
+            config.projectName, config.port, pythonRequirementsFile);
     if (!config.beamVersion.isEmpty()) {
       service.setBeamVersion(config.beamVersion);
     }
diff --git 
a/sdks/java/transform-service/launcher/src/test/java/org/apache/beam/sdk/transformservice/launcher/TransformServiceLauncherTest.java
 
b/sdks/java/transform-service/launcher/src/test/java/org/apache/beam/sdk/transformservice/launcher/TransformServiceLauncherTest.java
new file mode 100644
index 00000000000..4ef84b02061
--- /dev/null
+++ 
b/sdks/java/transform-service/launcher/src/test/java/org/apache/beam/sdk/transformservice/launcher/TransformServiceLauncherTest.java
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.transformservice.launcher;
+
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.Writer;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.UUID;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Charsets;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class TransformServiceLauncherTest {
+
+  @Test
+  public void testLauncherCreatesCredentialsDir() throws IOException {
+    String projectName = UUID.randomUUID().toString();
+    Path expectedTempDir = Paths.get(System.getProperty("java.io.tmpdir"), 
projectName);
+    File file = expectedTempDir.toFile();
+    file.deleteOnExit();
+    TransformServiceLauncher.forProject(projectName, 12345, null);
+    Path expectedCredentialsDir = Paths.get(expectedTempDir.toString(), 
"credentials_dir");
+    assertTrue(expectedCredentialsDir.toFile().exists());
+  }
+
+  @Test
+  public void testLauncherCreatesDependenciesDir() throws IOException {
+    String projectName = UUID.randomUUID().toString();
+    Path expectedTempDir = Paths.get(System.getProperty("java.io.tmpdir"), 
projectName);
+    File file = expectedTempDir.toFile();
+    file.deleteOnExit();
+    TransformServiceLauncher.forProject(projectName, 12345, null);
+    Path expectedCredentialsDir = Paths.get(expectedTempDir.toString(), 
"dependencies_dir");
+    assertTrue(expectedCredentialsDir.toFile().exists());
+  }
+
+  @Test
+  public void testLauncherInstallsDependencies() throws IOException {
+    String projectName = UUID.randomUUID().toString();
+    Path expectedTempDir = Paths.get(System.getProperty("java.io.tmpdir"), 
projectName);
+    File file = expectedTempDir.toFile();
+    file.deleteOnExit();
+
+    File requirementsFile =
+        Paths.get(
+                System.getProperty("java.io.tmpdir"),
+                ("requirements" + UUID.randomUUID().toString() + ".txt"))
+            .toFile();
+    requirementsFile.deleteOnExit();
+
+    try (Writer fout =
+        new OutputStreamWriter(
+            new FileOutputStream(requirementsFile.getAbsolutePath()), 
Charsets.UTF_8)) {
+      fout.write("pypipackage1\n");
+      fout.write("pypipackage2\n");
+    }
+
+    TransformServiceLauncher.forProject(projectName, 12345, 
requirementsFile.getAbsolutePath());
+
+    // Confirming that the Transform Service launcher created a temporary 
requirements file with the
+    // specified set of packages.
+    Path expectedUpdatedRequirementsFile =
+        Paths.get(expectedTempDir.toString(), "dependencies_dir", 
"requirements.txt");
+    assertTrue(expectedUpdatedRequirementsFile.toFile().exists());
+
+    ArrayList<String> expectedUpdatedRequirementsFileLines = new ArrayList<>();
+    try (BufferedReader bufReader =
+        Files.newBufferedReader(expectedUpdatedRequirementsFile, UTF_8)) {
+      String line = bufReader.readLine();
+      while (line != null) {
+        expectedUpdatedRequirementsFileLines.add(line);
+        line = bufReader.readLine();
+      }
+    }
+
+    assertEquals(2, expectedUpdatedRequirementsFileLines.size());
+    assertTrue(expectedUpdatedRequirementsFileLines.contains("pypipackage1"));
+    assertTrue(expectedUpdatedRequirementsFileLines.contains("pypipackage2"));
+  }
+
+  @Test
+  public void testLauncherInstallsLocalDependencies() throws IOException {
+    String projectName = UUID.randomUUID().toString();
+    Path expectedTempDir = Paths.get(System.getProperty("java.io.tmpdir"), 
projectName);
+    File file = expectedTempDir.toFile();
+    file.deleteOnExit();
+
+    String dependency1FileName = "dep_" + UUID.randomUUID().toString();
+    File dependency1 =
+        Paths.get(System.getProperty("java.io.tmpdir"), 
dependency1FileName).toFile();
+    dependency1.deleteOnExit();
+    try (Writer fout =
+        new OutputStreamWriter(
+            new FileOutputStream(dependency1.getAbsolutePath()), 
Charsets.UTF_8)) {
+      fout.write("tempdata\n");
+    }
+
+    String dependency2FileName = "dep_" + UUID.randomUUID().toString();
+    File dependency2 =
+        Paths.get(System.getProperty("java.io.tmpdir"), 
dependency2FileName).toFile();
+    dependency2.deleteOnExit();
+    try (Writer fout =
+        new OutputStreamWriter(
+            new FileOutputStream(dependency2.getAbsolutePath()), 
Charsets.UTF_8)) {
+      fout.write("tempdata\n");
+    }
+
+    File requirementsFile =
+        Paths.get(
+                System.getProperty("java.io.tmpdir"),
+                ("requirements" + UUID.randomUUID().toString() + ".txt"))
+            .toFile();
+    requirementsFile.deleteOnExit();
+    try (Writer fout =
+        new OutputStreamWriter(
+            new FileOutputStream(requirementsFile.getAbsolutePath()), 
Charsets.UTF_8)) {
+      fout.write(dependency1.getAbsolutePath() + "\n");
+      fout.write(dependency2.getAbsolutePath() + "\n");
+      fout.write("pypipackage" + "\n");
+    }
+
+    TransformServiceLauncher.forProject(projectName, 12345, 
requirementsFile.getAbsolutePath());
+
+    // Confirming that the Transform Service launcher created a temporary 
requirements file with the
+    // specified set of packages.
+    Path expectedUpdatedRequirementsFile =
+        Paths.get(expectedTempDir.toString(), "dependencies_dir", 
"requirements.txt");
+    assertTrue(expectedUpdatedRequirementsFile.toFile().exists());
+
+    ArrayList<String> expectedUpdatedRequirementsFileLines = new ArrayList<>();
+    try (BufferedReader bufReader =
+        Files.newBufferedReader(expectedUpdatedRequirementsFile, UTF_8)) {
+      String line = bufReader.readLine();
+      while (line != null) {
+        expectedUpdatedRequirementsFileLines.add(line);
+        line = bufReader.readLine();
+      }
+    }
+
+    // To make local packages available to the expansion service Docker 
containers, the temporary
+    // requirements file should contain names of the local packages relative 
to the dependencies
+    // volume and local packages should have been copied to the dependencies 
volume.
+    assertEquals(3, expectedUpdatedRequirementsFileLines.size());
+    
assertTrue(expectedUpdatedRequirementsFileLines.contains(dependency1FileName));
+    
assertTrue(expectedUpdatedRequirementsFileLines.contains(dependency2FileName));
+    assertTrue(expectedUpdatedRequirementsFileLines.contains("pypipackage"));
+
+    assertTrue(
+        Paths.get(expectedTempDir.toString(), "dependencies_dir", 
dependency1FileName)
+            .toFile()
+            .exists());
+    assertTrue(
+        Paths.get(expectedTempDir.toString(), "dependencies_dir", 
dependency2FileName)
+            .toFile()
+            .exists());
+  }
+}
diff --git 
a/sdks/java/transform-service/src/main/java/org/apache/beam/sdk/transformservice/ExpansionService.java
 
b/sdks/java/transform-service/src/main/java/org/apache/beam/sdk/transformservice/ExpansionService.java
index 17fe5472f9f..0a2e65099e7 100644
--- 
a/sdks/java/transform-service/src/main/java/org/apache/beam/sdk/transformservice/ExpansionService.java
+++ 
b/sdks/java/transform-service/src/main/java/org/apache/beam/sdk/transformservice/ExpansionService.java
@@ -17,15 +17,22 @@
  */
 package org.apache.beam.sdk.transformservice;
 
+import java.io.IOException;
+import java.net.Socket;
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeoutException;
 import org.apache.beam.model.expansion.v1.ExpansionApi;
+import org.apache.beam.model.expansion.v1.ExpansionApi.ExpansionResponse;
 import org.apache.beam.model.expansion.v1.ExpansionServiceGrpc;
 import org.apache.beam.model.pipeline.v1.Endpoints;
 import 
org.apache.beam.runners.core.construction.DefaultExpansionServiceClientFactory;
 import org.apache.beam.runners.core.construction.ExpansionServiceClientFactory;
 import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.ManagedChannelBuilder;
 import org.apache.beam.vendor.grpc.v1p54p0.io.grpc.stub.StreamObserver;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Throwables;
 import org.checkerframework.checker.nullness.qual.Nullable;
 
@@ -40,6 +47,12 @@ public class ExpansionService extends 
ExpansionServiceGrpc.ExpansionServiceImplB
 
   final List<Endpoints.ApiServiceDescriptor> endpoints;
 
+  private boolean checkedAllServices = false;
+
+  private static final long SERVICE_CHECK_TIMEOUT_MILLIS = 60000;
+
+  private boolean disableServiceCheck = false;
+
   ExpansionService(
       List<Endpoints.ApiServiceDescriptor> endpoints,
       @Nullable ExpansionServiceClientFactory clientFactory) {
@@ -48,10 +61,65 @@ public class ExpansionService extends 
ExpansionServiceGrpc.ExpansionServiceImplB
         clientFactory != null ? clientFactory : 
DEFAULT_EXPANSION_SERVICE_CLIENT_FACTORY;
   }
 
+  // Waits till all expansion services are ready.
+  private void waitForAllServicesToBeReady() throws TimeoutException {
+    if (disableServiceCheck) {
+      // Service check disabled. Just returning.
+      return;
+    }
+
+    outer:
+    for (Endpoints.ApiServiceDescriptor endpoint : endpoints) {
+      long start = System.currentTimeMillis();
+      long duration = 10;
+      while (System.currentTimeMillis() - start < 
SERVICE_CHECK_TIMEOUT_MILLIS) {
+        try {
+          String url = endpoint.getUrl();
+          int portIndex = url.lastIndexOf(":");
+          if (portIndex <= 0) {
+            throw new RuntimeException(
+                "Expected the endpoint to be of the form <host>:<port> but 
received " + url);
+          }
+          int port = Integer.parseInt(url.substring(portIndex + 1));
+          String host = url.substring(0, portIndex);
+          new Socket(host, port).close();
+          // Current service is up. Checking the next one.
+          continue outer;
+        } catch (IOException exn) {
+          try {
+            Thread.sleep(duration);
+          } catch (InterruptedException e) {
+            // Ignore
+          }
+          duration = (long) (duration * 1.2);
+        }
+      }
+      throw new TimeoutException(
+          "Timeout waiting for the service "
+              + endpoint.getUrl()
+              + " to startup after "
+              + (System.currentTimeMillis() - start)
+              + " milliseconds.");
+    }
+  }
+
+  @VisibleForTesting
+  void disableServiceCheck() {
+    disableServiceCheck = true;
+  }
+
   @Override
   public void expand(
       ExpansionApi.ExpansionRequest request,
       StreamObserver<ExpansionApi.ExpansionResponse> responseObserver) {
+    if (!checkedAllServices) {
+      try {
+        waitForAllServicesToBeReady();
+      } catch (TimeoutException e) {
+        throw new RuntimeException(e);
+      }
+      checkedAllServices = true;
+    }
     try {
       responseObserver.onNext(processExpand(request));
       responseObserver.onCompleted();
@@ -68,6 +136,14 @@ public class ExpansionService extends 
ExpansionServiceGrpc.ExpansionServiceImplB
   public void discoverSchemaTransform(
       ExpansionApi.DiscoverSchemaTransformRequest request,
       StreamObserver<ExpansionApi.DiscoverSchemaTransformResponse> 
responseObserver) {
+    if (!checkedAllServices) {
+      try {
+        waitForAllServicesToBeReady();
+      } catch (TimeoutException e) {
+        throw new RuntimeException(e);
+      }
+      checkedAllServices = true;
+    }
     try {
       responseObserver.onNext(processDiscover(request));
       responseObserver.onCompleted();
@@ -80,18 +156,41 @@ public class ExpansionService extends 
ExpansionServiceGrpc.ExpansionServiceImplB
     }
   }
 
-  /*package*/ ExpansionApi.ExpansionResponse 
processExpand(ExpansionApi.ExpansionRequest request) {
+  /**
+   * Combines the error responses of several expansion services into a single response.
+   *
+   * @param errorResponses error responses keyed by the URL of the expansion service that
+   *     produced them. Must not be empty.
+   * @return a copy of one of the given responses whose error text is a header line followed by
+   *     one line per expansion service.
+   */
+  private ExpansionApi.ExpansionResponse getAggregatedErrorResponse(
+      Map<String, ExpansionApi.ExpansionResponse> errorResponses) {
+    StringBuilder aggregatedError =
+        new StringBuilder()
+            .append("Aggregated errors from ")
+            .append(errorResponses.size())
+            .append(" expansion services.")
+            .append("\n");
+
+    for (Map.Entry<String, ExpansionApi.ExpansionResponse> failure : errorResponses.entrySet()) {
+      aggregatedError
+          .append("Error from expansion service ")
+          .append(failure.getKey())
+          .append(": ")
+          .append(failure.getValue().getError())
+          .append("\n");
+    }
+
+    // The first response in iteration order serves as the template; only its error is replaced.
+    ExpansionApi.ExpansionResponse template = errorResponses.values().iterator().next();
+    return template.toBuilder().setError(aggregatedError.toString()).build();
+  }
+
+  ExpansionApi.ExpansionResponse processExpand(ExpansionApi.ExpansionRequest 
request) {
     // Trying out expansion services in order till one succeeds.
     // If all services fail, re-raises the last error.
-    // TODO: when all services fail, return an aggregated error with errors 
from all services.
-    ExpansionApi.ExpansionResponse lastErrorResponse = null;
+    Map<String, ExpansionResponse> errorResponses = new HashMap<>();
     RuntimeException lastException = null;
     for (Endpoints.ApiServiceDescriptor endpoint : endpoints) {
       try {
         ExpansionApi.ExpansionResponse response =
             
expansionServiceClientFactory.getExpansionServiceClient(endpoint).expand(request);
         if (!response.getError().isEmpty()) {
-          lastErrorResponse = response;
+          errorResponses.put(endpoint.getUrl(), response);
           continue;
         }
         return response;
@@ -99,8 +198,11 @@ public class ExpansionService extends 
ExpansionServiceGrpc.ExpansionServiceImplB
         lastException = e;
       }
     }
-    if (lastErrorResponse != null) {
-      return lastErrorResponse;
+    if (lastException != null) {
+      throw new RuntimeException("Expansion request to transform service 
failed.", lastException);
+    }
+    if (!errorResponses.isEmpty()) {
+      return getAggregatedErrorResponse(errorResponses);
     } else if (lastException != null) {
       throw new RuntimeException("Expansion request to transform service 
failed.", lastException);
     } else {
diff --git 
a/sdks/java/transform-service/src/test/java/org/apache/beam/sdk/transformservice/ExpansionServiceTest.java
 
b/sdks/java/transform-service/src/test/java/org/apache/beam/sdk/transformservice/ExpansionServiceTest.java
index 298bce87f90..9905abd1d9b 100644
--- 
a/sdks/java/transform-service/src/test/java/org/apache/beam/sdk/transformservice/ExpansionServiceTest.java
+++ 
b/sdks/java/transform-service/src/test/java/org/apache/beam/sdk/transformservice/ExpansionServiceTest.java
@@ -60,6 +60,8 @@ public class ExpansionServiceTest {
     endpoints.add(endpoint2);
     clientFactory = Mockito.mock(ExpansionServiceClientFactory.class);
     expansionService = new ExpansionService(endpoints, clientFactory);
+    // We do not run actual services in unit tests.
+    expansionService.disableServiceCheck();
   }
 
   @Test
@@ -131,7 +133,10 @@ public class ExpansionServiceTest {
     ArgumentCaptor<ExpansionResponse> expansionResponseCapture =
         ArgumentCaptor.forClass(ExpansionResponse.class);
     
Mockito.verify(responseObserver).onNext(expansionResponseCapture.capture());
-    assertEquals("expansion error 2", 
expansionResponseCapture.getValue().getError());
+
+    // Error response should contain errors from both expansion services.
+    
assertTrue(expansionResponseCapture.getValue().getError().contains("expansion 
error 1"));
+    
assertTrue(expansionResponseCapture.getValue().getError().contains("expansion 
error 2"));
   }
 
   @Test
diff --git a/sdks/python/apache_beam/utils/transform_service_launcher.py 
b/sdks/python/apache_beam/utils/transform_service_launcher.py
index 33feab9bf29..ac492513aba 100644
--- a/sdks/python/apache_beam/utils/transform_service_launcher.py
+++ b/sdks/python/apache_beam/utils/transform_service_launcher.py
@@ -86,6 +86,7 @@ class TransformServiceLauncher(object):
 
     compose_file = os.path.join(temp_dir, 'docker-compose.yml')
 
+    # Creating the credentials volume.
     credentials_dir = os.path.join(temp_dir, 'credentials_dir')
     if not os.path.exists(credentials_dir):
       os.mkdir(credentials_dir)
@@ -111,11 +112,24 @@ class TransformServiceLauncher(object):
           'credentials file at the expected location %s.' %
           application_default_path_file)
 
+    # Creating the dependencies volume.
+    dependencies_dir = os.path.join(temp_dir, 'dependencies_dir')
+    if not os.path.exists(dependencies_dir):
+      os.mkdir(dependencies_dir)
+
     self._environmental_variables = {}
     self._environmental_variables['CREDENTIALS_VOLUME'] = credentials_dir
+    self._environmental_variables['DEPENDENCIES_VOLUME'] = dependencies_dir
     self._environmental_variables['TRANSFORM_SERVICE_PORT'] = str(port)
     self._environmental_variables['BEAM_VERSION'] = beam_version
 
+    # Setting an empty requirements file
+    requirements_file_name = os.path.join(dependencies_dir, 'requirements.txt')
+    with open(requirements_file_name, 'w') as _:
+      pass
+    self._environmental_variables['PYTHON_REQUIREMENTS_FILE_NAME'] = (
+        'requirements.txt')
+
     self._docker_compose_start_command_prefix = []
     self._docker_compose_start_command_prefix.append('docker-compose')
     self._docker_compose_start_command_prefix.append('-p')
diff --git a/sdks/python/expansion-service-container/boot.go 
b/sdks/python/expansion-service-container/boot.go
index 90a97c35425..ba56b349c4e 100644
--- a/sdks/python/expansion-service-container/boot.go
+++ b/sdks/python/expansion-service-container/boot.go
@@ -18,8 +18,10 @@
 package main
 
 import (
+       "bufio"
        "flag"
        "fmt"
+       "io/ioutil"
        "log"
        "os"
        "path/filepath"
@@ -31,16 +33,15 @@ import (
 )
 
 var (
-       id   = flag.String("id", "", "Local identifier (required)")
-       port = flag.Int("port", 0, "Port for the expansion service (required)")
+       id                = flag.String("id", "", "Local identifier (required)")
+       port              = flag.Int("port", 0, "Port for the expansion service 
(required)")
+       requirements_file = flag.String("requirements_file", "", "A requirement 
file with extra packages to be made available to the transforms being expanded. 
Path should be relative to the 'dependencies_dir'")
+       dependencies_dir  = flag.String("dependencies_dir", "", "A directory 
that stores locally available extra packages.")
 )
 
 const (
        expansionServiceEntrypoint = 
"apache_beam.runners.portability.expansion_service_main"
        venvDirectory              = "beam_venv" // This should match the venv 
directory name used in the Dockerfile.
-       requirementsFile           = "requirements.txt"
-       beamSDKArtifact            = "apache-beam-sdk.tar.gz"
-       beamSDKOptions             = "[gcp,dataframe]"
 )
 
 func main() {
@@ -58,6 +59,79 @@ func main() {
        }
 }
 
+// getLines reads the file named fileNameToRead and returns its lines in order,
+// without their line terminators. A file with no lines yields an empty, non-nil
+// slice. Errors from opening or scanning the file are returned unchanged.
+func getLines(fileNameToRead string) ([]string, error) {
+       src, err := os.Open(fileNameToRead)
+       if err != nil {
+               return nil, err
+       }
+       defer src.Close()
+
+       collected := []string{}
+       scanner := bufio.NewScanner(src)
+       for scanner.Scan() {
+               collected = append(collected, scanner.Text())
+       }
+
+       // Scan stops at EOF as well as on failure; Err tells the two apart.
+       if scanErr := scanner.Err(); scanErr != nil {
+               return nil, scanErr
+       }
+       return collected, nil
+}
+
+// installExtraPackages installs the requirements listed in requirementsFile by
+// running one 'pip install' per line, and stops at the first failure.
+//
+// Blank lines and lines starting with '#' are skipped: both are valid in a
+// requirements file but are not package specifiers, so handing them to
+// 'pip install' as an argument would fail the whole startup.
+func installExtraPackages(requirementsFile string) error {
+       extraPackages, err := getLines(requirementsFile)
+       if err != nil {
+               return err
+       }
+
+       for _, extraPackage := range extraPackages {
+               extraPackage = strings.TrimSpace(extraPackage)
+               if extraPackage == "" || strings.HasPrefix(extraPackage, "#") {
+                       continue
+               }
+               log.Printf("Installing extra package %v", extraPackage)
+               // We expect 'pip' command in virtual env to be already available at the top of the PATH.
+               args := []string{"install", extraPackage}
+               if err := execx.Execute("pip", args...); err != nil {
+                       return fmt.Errorf("Could not install the package %s: %s", extraPackage, err)
+               }
+       }
+       return nil
+}
+
+// getUpdatedRequirementsFile reads the requirements file oldRequirementsFileName,
+// a path relative to dependenciesDir, writes an updated copy to a new temporary
+// file under /opt/apache/beam and returns the name of that file.
+//
+// A requirement that names an existing file or directory inside dependenciesDir
+// is replaced by its path within dependenciesDir; every other requirement is
+// copied as is. Blank lines and lines starting with '#' are dropped: an empty
+// line would otherwise resolve to dependenciesDir itself and be emitted as a
+// local package.
+//
+// The returned file is not removed by this function. When writing fails, the
+// partially written file is removed and the error is returned.
+func getUpdatedRequirementsFile(oldRequirementsFileName string, dependenciesDir string) (string, error) {
+       oldExtraPackages, err := getLines(filepath.Join(dependenciesDir, oldRequirementsFileName))
+       if err != nil {
+               return "", err
+       }
+       updatedExtraPackages := make([]string, 0, len(oldExtraPackages))
+       for _, extraPackage := range oldExtraPackages {
+               extraPackage = strings.TrimSpace(extraPackage)
+               if extraPackage == "" || strings.HasPrefix(extraPackage, "#") {
+                       continue
+               }
+               // Prefer a copy of the package that was made available in the dependencies directory.
+               potentialLocalFilePath := filepath.Join(dependenciesDir, extraPackage)
+               if _, err := os.Stat(potentialLocalFilePath); err == nil {
+                       // Package exists locally so using that.
+                       extraPackage = potentialLocalFilePath
+                       log.Printf("Using locally available extra package %v", extraPackage)
+               }
+               updatedExtraPackages = append(updatedExtraPackages, extraPackage)
+       }
+
+       updatedRequirementsFile, err := ioutil.TempFile("/opt/apache/beam", "requirements*.txt")
+       if err != nil {
+               return "", err
+       }
+       updatedRequirementsFileName := updatedRequirementsFile.Name()
+
+       datawriter := bufio.NewWriter(updatedRequirementsFile)
+       for _, extraPackage := range updatedExtraPackages {
+               // A failed write is sticky in bufio.Writer and is reported by Flush below.
+               _, _ = datawriter.WriteString(extraPackage + "\n")
+       }
+       err = datawriter.Flush()
+       if closeErr := updatedRequirementsFile.Close(); err == nil {
+               err = closeErr
+       }
+       if err != nil {
+               // Do not leave a partially written requirements file behind.
+               os.Remove(updatedRequirementsFileName)
+               return "", err
+       }
+
+       return updatedRequirementsFileName, nil
+}
+
 func launchExpansionServiceProcess() error {
        pythonVersion, err := expansionx.GetPythonVersion()
        if err != nil {
@@ -70,6 +144,24 @@ func launchExpansionServiceProcess() error {
        os.Setenv("PATH", strings.Join([]string{filepath.Join(dir, "bin"), 
os.Getenv("PATH")}, ":"))
 
        args := []string{"-m", expansionServiceEntrypoint, "-p", 
strconv.Itoa(*port), "--fully_qualified_name_glob", "*"}
+
+       if *requirements_file != "" {
+               log.Printf("Received the requirements file %v", 
*requirements_file)
+               updatedRequirementsFileName, err := 
getUpdatedRequirementsFile(*requirements_file, *dependencies_dir)
+               if err != nil {
+                       return err
+               }
+               defer os.Remove(updatedRequirementsFileName)
+               log.Printf("Updated requirements file is %v", 
updatedRequirementsFileName)
+               // Provide the requirements file to the expansion service so 
that packages get staged by runners.
+               args = append(args, "--requirements_file", 
updatedRequirementsFileName)
+               // Install packages locally so that they can be used by the 
expansion service during transform
+               // expansion if needed.
+               err = installExtraPackages(updatedRequirementsFileName)
+               if err != nil {
+                       return err
+               }
+       }
        if err := execx.Execute(pythonVersion, args...); err != nil {
                return fmt.Errorf("could not start the expansion service: %s", 
err)
        }


Reply via email to