gemini-code-assist[bot] commented on code in PR #36623:
URL: https://github.com/apache/beam/pull/36623#discussion_r2523782294


##########
sdks/java/ml/remoteinference/src/test/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandlerIT.java:
##########
@@ -0,0 +1,366 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.remoteinference.openai;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TypeDescriptor;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.junit.Assume.assumeNotNull;
+import static org.junit.Assume.assumeTrue;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.beam.sdk.ml.remoteinference.base.*;
+import org.apache.beam.sdk.ml.remoteinference.RemoteInference;
+import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler.StructuredInputOutput;
+import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler.Response;
+
+public class OpenAIModelHandlerIT {
+  private static final Logger LOG = LoggerFactory.getLogger(OpenAIModelHandlerIT.class);
+
+  @Rule
+  public final transient TestPipeline pipeline = TestPipeline.create();
+
+  private String apiKey;
+  private static final String API_KEY_ENV = "OPENAI_API_KEY";
+  private static final String DEFAULT_MODEL = "gpt-4o-mini";
+
+
+  @Before
+  public void setUp() {
+    // Get API key
+    apiKey = System.getenv(API_KEY_ENV);
+
+    // Skip tests if API key is not provided
+    assumeNotNull(
+      "OpenAI API key not found. Set " + API_KEY_ENV
+        + " environment variable to run integration tests.",
+      apiKey);
+    assumeTrue("OpenAI API key is empty. Set " + API_KEY_ENV
+        + " environment variable to run integration tests.",
+      !apiKey.trim().isEmpty());
+  }
+
+  @Test
+  public void testSentimentAnalysisWithSingleInput() {
+    String input = "This product is absolutely amazing! I love it!";
+
+    PCollection<OpenAIModelInput> inputs = pipeline
+      .apply("CreateSingleInput", Create.of(input))
+      .apply("MapToInput", MapElements
+        .into(TypeDescriptor.of(OpenAIModelInput.class))
+        .via(OpenAIModelInput::create));
+
+    PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results = inputs
+      .apply("SentimentInference",
+        RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke()
+          .handler(OpenAIModelHandler.class)
+          .withParameters(OpenAIModelParameters.builder()
+            .apiKey(apiKey)
+            .modelName(DEFAULT_MODEL)
+            .instructionPrompt(
+              "Analyze the sentiment as 'positive' or 'negative'. Return only one word.")
+            .build()));
+
+    // Verify results
+    PAssert.that(results).satisfies(batches -> {
+      int count = 0;
+      for (Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> batch : batches) {
+        for (PredictionResult<OpenAIModelInput, OpenAIModelResponse> result : batch) {
+          count++;
+          assertNotNull("Input should not be null", result.getInput());
+          assertNotNull("Output should not be null", result.getOutput());
+          assertNotNull("Output text should not be null",
+            result.getOutput().getModelResponse());
+
+          String sentiment = result.getOutput().getModelResponse().toLowerCase();
+          assertTrue("Sentiment should be positive or negative, got: " + sentiment,
+            sentiment.contains("positive")
+              || sentiment.contains("negative"));
+        }
+      }
+      assertEquals("Should have exactly 1 result", 1, count);
+      return null;
+    });
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  @Test
+  public void testSentimentAnalysisWithMultipleInputs() {
+    List<String> inputs = Arrays.asList(
+      "An excellent B2B SaaS solution that streamlines business processes efficiently.",
+      "The customer support is terrible. I've been waiting for days without any response.",
+      "The application works as expected. Installation was straightforward.",
+      "Really impressed with the innovative features! The AI capabilities are groundbreaking!",
+      "Mediocre product with occasional glitches. Documentation could be better.");
+
+    PCollection<OpenAIModelInput> inputCollection = pipeline
+      .apply("CreateMultipleInputs", Create.of(inputs))
+      .apply("MapToInputs", MapElements
+        .into(TypeDescriptor.of(OpenAIModelInput.class))
+        .via(OpenAIModelInput::create));
+
+    PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results = inputCollection
+      .apply("SentimentInference",
+        RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke()
+          .handler(OpenAIModelHandler.class)
+          .withParameters(OpenAIModelParameters.builder()
+            .apiKey(apiKey)
+            .modelName(DEFAULT_MODEL)
+            .instructionPrompt(
+              "Analyze sentiment as positive or negative")
+            .build()));
+
+    // Verify we get results for all inputs
+    PAssert.that(results).satisfies(batches -> {
+      int totalCount = 0;
+      for (Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> batch : batches) {
+        for (PredictionResult<OpenAIModelInput, OpenAIModelResponse> result : batch) {
+          totalCount++;
+          assertNotNull("Input should not be null", result.getInput());
+          assertNotNull("Output should not be null", result.getOutput());
+          assertFalse("Output should not be empty",
+            result.getOutput().getModelResponse().trim().isEmpty());
+        }
+      }
+      assertEquals("Should have results for all 5 inputs", 5, totalCount);
+      return null;
+    });
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  @Test
+  public void testTextClassification() {
+    List<String> inputs = Arrays.asList(
+      "How do I reset my password?",
+      "Your product is broken and I want a refund!",
+      "Thank you for the excellent service!");
+
+    PCollection<OpenAIModelInput> inputCollection = pipeline
+      .apply("CreateInputs", Create.of(inputs))
+      .apply("MapToInputs", MapElements
+        .into(TypeDescriptor.of(OpenAIModelInput.class))
+        .via(OpenAIModelInput::create));
+
+    PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results = inputCollection
+      .apply("ClassificationInference",
+        RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke()
+          .handler(OpenAIModelHandler.class)
+          .withParameters(OpenAIModelParameters.builder()
+            .apiKey(apiKey)
+            .modelName(DEFAULT_MODEL)
+            .instructionPrompt(
+              "Classify each text into one category: 'question', 'complaint', or 'praise'. Return only the category.")
+            .build()));
+
+    PAssert.that(results).satisfies(batches -> {
+      List<String> categories = new ArrayList<>();
+      for (Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> batch : batches) {
+        for (PredictionResult<OpenAIModelInput, OpenAIModelResponse> result : batch) {
+          String category = result.getOutput().getModelResponse().toLowerCase();
+          categories.add(category);
+        }
+      }
+
+      assertEquals("Should have 3 categories", 3, categories.size());
+
+      // Verify expected categories
+      boolean hasQuestion = categories.stream().anyMatch(c -> c.contains("question"));
+      boolean hasComplaint = categories.stream().anyMatch(c -> c.contains("complaint"));
+      boolean hasPraise = categories.stream().anyMatch(c -> c.contains("praise"));
+
+      assertTrue("Should have at least one recognized category",
+        hasQuestion || hasComplaint || hasPraise);
+
+      return null;
+    });
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  @Test
+  public void testInputOutputMapping() {
+    List<String> inputs = Arrays.asList("apple", "banana", "cherry");
+
+    PCollection<OpenAIModelInput> inputCollection = pipeline
+      .apply("CreateInputs", Create.of(inputs))
+      .apply("MapToInputs", MapElements
+        .into(TypeDescriptor.of(OpenAIModelInput.class))
+        .via(OpenAIModelInput::create));
+
+    PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results = inputCollection
+      .apply("MappingInference",
+        RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke()
+          .handler(OpenAIModelHandler.class)
+          .withParameters(OpenAIModelParameters.builder()
+            .apiKey(apiKey)
+            .modelName(DEFAULT_MODEL)
+            .instructionPrompt(
+              "Return the input word in uppercase")
+            .build()));
+
+    // Verify input-output pairing is preserved
+    PAssert.that(results).satisfies(batches -> {
+      for (Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> batch : batches) {
+        for (PredictionResult<OpenAIModelInput, OpenAIModelResponse> result : batch) {
+          String input = result.getInput().getModelInput();
+          String output = result.getOutput().getModelResponse().toLowerCase();
+
+          // Verify the output relates to the input
+          assertTrue("Output should relate to input '" + input + "', got: " + output,
+            output.contains(input.toLowerCase()));
+        }
+      }
+      return null;
+    });
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  @Test
+  public void testWithDifferentModel() {
+    // Test with a different model
+    String input = "Explain quantum computing in one sentence.";
+
+    PCollection<OpenAIModelInput> inputs = pipeline
+      .apply("CreateInput", Create.of(input))
+      .apply("MapToInput", MapElements
+        .into(TypeDescriptor.of(OpenAIModelInput.class))
+        .via(OpenAIModelInput::create));
+
+    PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results = inputs
+      .apply("DifferentModelInference",
+        RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke()
+          .handler(OpenAIModelHandler.class)
+          .withParameters(OpenAIModelParameters.builder()
+            .apiKey(apiKey)
+            .modelName("gpt-5")

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The integration test `testWithDifferentModel` uses the model name "gpt-5". 
This model is not currently available, and using it will cause the integration 
test to fail with an "invalid model" error from the OpenAI API. Please use a 
valid and available model name to ensure the test can run successfully. For 
example, you could use another real model like `gpt-4-turbo`.
   
   ```suggestion
               .modelName("gpt-4-turbo")
   ```



##########
sdks/java/ml/remoteinference/build.gradle.kts:
##########
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+plugins {
+  id("org.apache.beam.module")
+  id("java-library")
+}
+
+description = "Apache Beam :: SDKs :: Java :: ML :: RemoteInference"
+
+dependencies {
+  // Core Beam SDK
+  implementation(project(":sdks:java:core"))
+
+  implementation("com.openai:openai-java:4.3.0")
+  compileOnly("com.google.auto.value:auto-value-annotations:1.11.0")
+  compileOnly("org.checkerframework:checker-qual:3.42.0")
+  annotationProcessor("com.google.auto.value:auto-value:1.11.0")
+  implementation("com.fasterxml.jackson.core:jackson-core:2.20.0")

Review Comment:
   ![critical](https://www.gstatic.com/codereviewagent/critical.svg)
   
   The specified version `2.20.0` for `com.fasterxml.jackson.core:jackson-core` 
does not appear to be a valid published version and will likely cause build 
failures. Please use a valid version. For example, `2.17.1` is a recent stable 
version.
   
   ```kotlin
     implementation("com.fasterxml.jackson.core:jackson-core:2.17.1")
   ```



##########
sdks/java/ml/remoteinference/build.gradle.kts:
##########
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+plugins {
+  id("org.apache.beam.module")
+  id("java-library")
+}
+
+description = "Apache Beam :: SDKs :: Java :: ML :: RemoteInference"
+
+dependencies {
+  // Core Beam SDK
+  implementation(project(":sdks:java:core"))
+
+  implementation("com.openai:openai-java:4.3.0")
+  compileOnly("com.google.auto.value:auto-value-annotations:1.11.0")
+  compileOnly("org.checkerframework:checker-qual:3.42.0")
+  annotationProcessor("com.google.auto.value:auto-value:1.11.0")
+  implementation("com.fasterxml.jackson.core:jackson-core:2.20.0")
+  implementation("org.apache.beam:beam-vendor-guava-32_1_2-jre:0.1")
+  implementation("org.slf4j:slf4j-api:2.0.9")
+  implementation("org.slf4j:slf4j-simple:2.0.9")

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The `slf4j-simple` dependency should be scoped to `testImplementation` 
instead of `implementation`. Including a concrete logging implementation like 
`slf4j-simple` in the `implementation` scope can lead to classpath conflicts in 
downstream projects that use this module, as they may have their own preferred 
logging framework. It's a best practice for libraries to only depend on 
`slf4j-api`.
   
   ```kotlin
     testImplementation("org.slf4j:slf4j-simple:2.0.9")
   ```



##########
sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java:
##########
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.remoteinference.openai;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.annotation.JsonPropertyDescription;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.openai.client.OpenAIClient;
+import com.openai.client.okhttp.OpenAIOkHttpClient;
+import com.openai.core.JsonSchemaLocalValidation;
+import com.openai.models.responses.ResponseCreateParams;
+import com.openai.models.responses.StructuredResponseCreateParams;
+import org.apache.beam.sdk.ml.remoteinference.base.BaseModelHandler;
+import org.apache.beam.sdk.ml.remoteinference.base.PredictionResult;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Model handler for OpenAI API inference requests.
+ *
+ * <p>This handler manages communication with OpenAI's API, including client initialization,
+ * request formatting, and response parsing. It uses OpenAI's structured output feature to
+ * ensure reliable input-output pairing.
+ *
+ * <h3>Usage</h3>
+ * <pre>{@code
+ * OpenAIModelParameters params = OpenAIModelParameters.builder()
+ *     .apiKey("sk-...")
+ *     .modelName("gpt-4")
+ *     .instructionPrompt("Classify the following text into one of the categories: {CATEGORIES}")
+ *     .build();
+ *
+ * PCollection<OpenAIModelInput> inputs = ...;
+ * PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results =
+ *     inputs.apply(
+ *         RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke()
+ *             .handler(OpenAIModelHandler.class)
+ *             .withParameters(params)
+ *     );
+ * }</pre>
+ *
+ */
+public class OpenAIModelHandler
+  implements BaseModelHandler<OpenAIModelParameters, OpenAIModelInput, OpenAIModelResponse> {
+
+  private transient OpenAIClient client;
+  private transient StructuredResponseCreateParams<StructuredInputOutput> clientParams;

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The field `clientParams` is declared as a transient instance variable but it 
is only assigned and used within the `request` method. This makes the code 
harder to reason about, as it suggests `clientParams` holds state across method 
calls, which it doesn't. It should be a local variable within the `request` 
method.
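   
   For illustration, a minimal sketch of the change (the builder chain below is the one already used in `request`; only its placement moves):
   ```java
   // Field removed from the class:
   // private transient StructuredResponseCreateParams<StructuredInputOutput> clientParams;
   
   // Declared locally inside request(...) instead:
   StructuredResponseCreateParams<StructuredInputOutput> clientParams =
       ResponseCreateParams.builder()
           .model(modelParameters.getModelName())
           .input(inputBatch)
           .text(StructuredInputOutput.class, JsonSchemaLocalValidation.NO)
           .instructions(modelParameters.getInstructionPrompt())
           .build();
   ```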



##########
sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java:
##########
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.remoteinference.openai;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.annotation.JsonPropertyDescription;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.openai.client.OpenAIClient;
+import com.openai.client.okhttp.OpenAIOkHttpClient;
+import com.openai.core.JsonSchemaLocalValidation;
+import com.openai.models.responses.ResponseCreateParams;
+import com.openai.models.responses.StructuredResponseCreateParams;
+import org.apache.beam.sdk.ml.remoteinference.base.BaseModelHandler;
+import org.apache.beam.sdk.ml.remoteinference.base.PredictionResult;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Model handler for OpenAI API inference requests.
+ *
+ * <p>This handler manages communication with OpenAI's API, including client initialization,
+ * request formatting, and response parsing. It uses OpenAI's structured output feature to
+ * ensure reliable input-output pairing.
+ *
+ * <h3>Usage</h3>
+ * <pre>{@code
+ * OpenAIModelParameters params = OpenAIModelParameters.builder()
+ *     .apiKey("sk-...")
+ *     .modelName("gpt-4")
+ *     .instructionPrompt("Classify the following text into one of the categories: {CATEGORIES}")
+ *     .build();
+ *
+ * PCollection<OpenAIModelInput> inputs = ...;
+ * PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results =
+ *     inputs.apply(
+ *         RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke()
+ *             .handler(OpenAIModelHandler.class)
+ *             .withParameters(params)
+ *     );
+ * }</pre>
+ *
+ */
+public class OpenAIModelHandler
+  implements BaseModelHandler<OpenAIModelParameters, OpenAIModelInput, OpenAIModelResponse> {
+
+  private transient OpenAIClient client;
+  private transient StructuredResponseCreateParams<StructuredInputOutput> clientParams;
+  private OpenAIModelParameters modelParameters;
+
+  /**
+   * Initializes the OpenAI client with the provided parameters.
+   *
+   * <p>This method is called once during setup. It creates an authenticated
+   * OpenAI client using the API key from the parameters.
+   *
+   * @param parameters the configuration parameters including API key and model name
+   */
+  @Override
+  public void createClient(OpenAIModelParameters parameters) {
+    this.modelParameters = parameters;
+    this.client = OpenAIOkHttpClient.builder()
+      .apiKey(this.modelParameters.getApiKey())
+      .build();
+  }
+
+  /**
+   * Performs inference on a batch of inputs using the OpenAI Client.
+   *
+   * <p>This method serializes the input batch to JSON string, sends it to OpenAI with structured
+   * output requirements, and parses the response into {@link PredictionResult} objects
+   * that pair each input with its corresponding output.
+   *
+   * @param input the list of inputs to process
+   * @return an iterable of model results and input pairs
+   */
+  @Override
+  public Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> request(List<OpenAIModelInput> input) {
+
+    try {
+      // Convert input list to JSON string
+      String inputBatch = new ObjectMapper()
+        .writeValueAsString(input.stream().map(OpenAIModelInput::getModelInput).toList());
+
+      // Build structured response parameters
+      this.clientParams = ResponseCreateParams.builder()
+        .model(modelParameters.getModelName())
+        .input(inputBatch)
+        .text(StructuredInputOutput.class, JsonSchemaLocalValidation.NO)
+        .instructions(modelParameters.getInstructionPrompt())
+        .build();

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The `clientParams` is being assigned to an instance field, but it's only 
used locally within the `request` method. It should be a local variable. This 
also requires removing the `clientParams` field from the class.
   
   ```suggestion
         StructuredResponseCreateParams<StructuredInputOutput> clientParams = ResponseCreateParams.builder()
           .model(modelParameters.getModelName())
           .input(inputBatch)
           .text(StructuredInputOutput.class, JsonSchemaLocalValidation.NO)
           .instructions(modelParameters.getInstructionPrompt())
           .build();
   ```



##########
sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelParameters.java:
##########
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.remoteinference.openai;
+
+import org.apache.beam.sdk.ml.remoteinference.base.BaseModelParameters;
+
+/**
+ * Configuration parameters required for OpenAI model inference.
+ *
+ * <p>This class encapsulates all configuration needed to initialize and communicate with
+ * OpenAI's API, including authentication credentials, model selection, and inference instructions.
+ *
+ * <h3>Example Usage</h3>
+ * <pre>{@code
+ * OpenAIModelParameters params = OpenAIModelParameters.builder()
+ *     .apiKey("sk-...")
+ *     .modelName("gpt-4")
+ *     .instructionPrompt("Translate the following text to French:")
+ *     .build();
+ * }</pre>
+ *
+ * @see OpenAIModelHandler
+ */
+public class OpenAIModelParameters implements BaseModelParameters {

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The classes `OpenAIModelParameters`, `OpenAIModelInput`, and 
`OpenAIModelResponse` do not override `equals()` and `hashCode()`. This can 
lead to unexpected behavior when these objects are used in collections (like 
`Set` or as keys in a `Map`) or in tests that rely on object equality. The test 
classes you've written for `RemoteInferenceTest` correctly implement these 
methods, and the production classes should as well.
   
   For `OpenAIModelParameters`, you can add the following:
   ```java
   @Override
   public boolean equals(Object o) {
     if (this == o) return true;
     if (o == null || getClass() != o.getClass()) return false;
     OpenAIModelParameters that = (OpenAIModelParameters) o;
     return java.util.Objects.equals(apiKey, that.apiKey) &&
            java.util.Objects.equals(modelName, that.modelName) &&
            java.util.Objects.equals(instructionPrompt, that.instructionPrompt);
   }
   
   @Override
   public int hashCode() {
     return java.util.Objects.hash(apiKey, modelName, instructionPrompt);
   }
   ```
   Similar implementations should be added to `OpenAIModelInput` and 
`OpenAIModelResponse`.
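   
   For `OpenAIModelInput`, a sketch along the same lines (this assumes the class wraps a single `modelInput` string field, inferred from its `getModelInput()` accessor):
   ```java
   @Override
   public boolean equals(Object o) {
     if (this == o) return true;
     if (o == null || getClass() != o.getClass()) return false;
     OpenAIModelInput that = (OpenAIModelInput) o;
     // modelInput is assumed to be the only field, based on getModelInput()
     return java.util.Objects.equals(modelInput, that.modelInput);
   }
   
   @Override
   public int hashCode() {
     return java.util.Objects.hash(modelInput);
   }
   ```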



##########
sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java:
##########
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.remoteinference;
+
+import org.apache.beam.sdk.ml.remoteinference.base.*;
+import org.apache.beam.sdk.transforms.*;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
+import org.apache.beam.sdk.values.PCollection;
+import com.google.auto.value.AutoValue;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * A {@link PTransform} for making remote inference calls to external machine learning services.
+ *
+ * <p>{@code RemoteInference} provides a framework for integrating remote ML model
+ * inference into Apache Beam pipelines and handles the communication between pipelines
+ * and external inference APIs.
+ *
+ * <h3>Example: OpenAI Model Inference</h3>
+ *
+ * <pre>{@code
+ * // Create model parameters
+ * OpenAIModelParameters params = OpenAIModelParameters.builder()
+ *     .apiKey("your-api-key")
+ *     .modelName("gpt-4")
+ *     .instructionPrompt("Analyse sentiment as positive or negative")
+ *     .build();
+ *
+ * // Apply remote inference transform
+ * PCollection<OpenAIModelInput> inputs = pipeline.apply(Create.of(
+ *     OpenAIModelInput.create("An excellent B2B SaaS solution that streamlines business processes efficiently."),
+ *     OpenAIModelInput.create("Really impressed with the innovative features!")
+ * ));
+ *
+ * PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results =
+ *     inputs.apply(
+ *         RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke()
+ *             .handler(OpenAIModelHandler.class)
+ *             .withParameters(params)
+ *     );
+ * }</pre>
+ *
+ */
+@SuppressWarnings({ "rawtypes", "unchecked" })

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The class is annotated with `@SuppressWarnings({ "rawtypes", "unchecked" })` because `BaseModelHandler` is used as a raw type. This reduces type safety and can hide potential class cast exceptions at runtime.
   
   For example, `Invoke.handler()` returns `Class<? extends BaseModelHandler>`, 
and `RemoteInferenceFn.handler` is a raw `BaseModelHandler`. This means the 
compiler cannot verify that the `BaseModelParameters` and 
`BaseInput`/`BaseResponse` types are compatible between the `Invoke` transform 
and the `BaseModelHandler` implementation.
   
   While fixing this might require some significant refactoring of the 
generics, it would make the framework more robust. A potential direction could 
be to include the `BaseModelParameters` type in the `Invoke` transform's 
generics, like `Invoke<InputT, OutputT, ParamT extends BaseModelParameters>`. 
This would allow for stronger type checking throughout the implementation.
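   
   A minimal sketch of that direction, reusing the existing names (the exact shape would also depend on how `RemoteInferenceFn` consumes the handler):
   ```java
   // Sketch only: threading ParamT through Invoke lets the compiler check that
   // the handler, its parameters, and the input/output types all agree.
   public abstract static class Invoke<
           InputT extends BaseInput,
           OutputT extends BaseResponse,
           ParamT extends BaseModelParameters>
       extends PTransform<PCollection<InputT>, PCollection<Iterable<PredictionResult<InputT, OutputT>>>> {
   
     // The handler class is now tied to ParamT, InputT, and OutputT:
     abstract @Nullable Class<? extends BaseModelHandler<ParamT, InputT, OutputT>> handler();
   
     abstract @Nullable ParamT parameters();
   }
   ```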



##########
sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/RemoteInference.java:
##########
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.remoteinference;
+
+import org.apache.beam.sdk.ml.remoteinference.base.*;
+import org.apache.beam.sdk.transforms.*;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
+import org.apache.beam.sdk.values.PCollection;
+import com.google.auto.value.AutoValue;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * A {@link PTransform} for making remote inference calls to external machine learning services.
+ *
+ * <p>{@code RemoteInference} provides a framework for integrating remote ML model
+ * inference into Apache Beam pipelines and handles the communication between pipelines
+ * and external inference APIs.
+ *
+ * <h3>Example: OpenAI Model Inference</h3>
+ *
+ * <pre>{@code
+ * // Create model parameters
+ * OpenAIModelParameters params = OpenAIModelParameters.builder()
+ *     .apiKey("your-api-key")
+ *     .modelName("gpt-4")
+ *     .instructionPrompt("Analyse sentiment as positive or negative")
+ *     .build();
+ *
+ * // Apply remote inference transform
+ * PCollection<OpenAIModelInput> inputs = pipeline.apply(Create.of(
+ *     OpenAIModelInput.create("An excellent B2B SaaS solution that streamlines business processes efficiently."),
+ *     OpenAIModelInput.create("Really impressed with the innovative features!")
+ * ));
+ *
+ * PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results =
+ *     inputs.apply(
+ *         RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke()
+ *             .handler(OpenAIModelHandler.class)
+ *             .withParameters(params)
+ *     );
+ * }</pre>
+ *
+ */
+@SuppressWarnings({ "rawtypes", "unchecked" })
+public class RemoteInference {
+
+  /** Invoke the model handler with model parameters */
+  public static <InputT extends BaseInput, OutputT extends BaseResponse> Invoke<InputT, OutputT> invoke() {
+    return new AutoValue_RemoteInference_Invoke.Builder<InputT, OutputT>().setParameters(null)
+      .build();
+  }
+
+  private RemoteInference() {
+  }
+
+  @AutoValue
+  public abstract static class Invoke<InputT extends BaseInput, OutputT extends BaseResponse>
+    extends PTransform<PCollection<InputT>, PCollection<Iterable<PredictionResult<InputT, OutputT>>>> {
+
+    abstract @Nullable Class<? extends BaseModelHandler> handler();
+
+    abstract @Nullable BaseModelParameters parameters();
+
+
+    abstract Builder<InputT, OutputT> builder();
+
+    @AutoValue.Builder
+    abstract static class Builder<InputT extends BaseInput, OutputT extends BaseResponse> {
+
+      abstract Builder<InputT, OutputT> setHandler(Class<? extends BaseModelHandler> modelHandler);
+
+      abstract Builder<InputT, OutputT> setParameters(BaseModelParameters modelParameters);
+
+
+      abstract Invoke<InputT, OutputT> build();
+    }
+
+    /**
+     * Model handler class for inference.
+     */
+    public Invoke<InputT, OutputT> handler(Class<? extends BaseModelHandler> modelHandler) {
+      return builder().setHandler(modelHandler).build();
+    }
+
+    /**
+     * Configures the parameters for model initialization.
+     */
+    public Invoke<InputT, OutputT> withParameters(BaseModelParameters modelParameters) {
+      return builder().setParameters(modelParameters).build();
+    }
+
+
+    @Override
+    public PCollection<Iterable<PredictionResult<InputT, OutputT>>> expand(PCollection<InputT> input) {
+      checkArgument(handler() != null, "handler() is required");
+      checkArgument(parameters() != null, "withParameters() is required");
+      return input
+        .apply("WrapInputInList", MapElements.via(new SimpleFunction<InputT, List<InputT>>() {
+          @Override
+          public List<InputT> apply(InputT element) {
+            return Collections.singletonList(element);
+          }
+        }))
+        // Pass the list to the inference function
+        .apply("RemoteInference", ParDo.of(new RemoteInferenceFn<InputT, OutputT>(this)));

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The current implementation processes each input element individually by 
wrapping it in a singleton list. This results in a separate remote inference 
call for every element, which is highly inefficient and will lead to poor 
performance, especially for large datasets. The `BaseModelHandler#request` 
method already accepts a `List<InputT>`, suggesting that batching is intended.
   
   To improve performance, you should introduce batching before the `ParDo` 
transform. You can use `org.apache.beam.sdk.transforms.GroupIntoBatches` to 
group elements into batches of a configurable size. This will significantly 
reduce the number of remote calls.
   
   For example, you could add a `batchSize` parameter to the `Invoke` transform 
and then use it like this:
   ```java
   // First, add a key to each element
   input.apply(WithKeys.of("key"))
        // Then, group elements into batches
        .apply(GroupIntoBatches.ofSize(batchSize))
        // Then, get the values (the batches)
        .apply(Values.create())
        // Finally, perform the remote inference on each batch
        .apply("RemoteInference", ParDo.of(new RemoteInferenceFn<InputT, 
OutputT>(this)));
   ```
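   
   The `batchSize` knob itself could hang off the existing AutoValue builder; for example (the names `batchSize`, `setBatchSize`, and `withBatchSize` here are illustrative, not from the PR):
   ```java
   // Hypothetical additions to Invoke; a matching setBatchSize(...) would be
   // added to the AutoValue Builder.
   abstract @Nullable Integer batchSize();
   
   public Invoke<InputT, OutputT> withBatchSize(int batchSize) {
     return builder().setBatchSize(batchSize).build();
   }
   ```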



##########
sdks/java/ml/remoteinference/src/main/java/org/apache/beam/sdk/ml/remoteinference/openai/OpenAIModelHandler.java:
##########
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.remoteinference.openai;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.annotation.JsonPropertyDescription;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.openai.client.OpenAIClient;
+import com.openai.client.okhttp.OpenAIOkHttpClient;
+import com.openai.core.JsonSchemaLocalValidation;
+import com.openai.models.responses.ResponseCreateParams;
+import com.openai.models.responses.StructuredResponseCreateParams;
+import org.apache.beam.sdk.ml.remoteinference.base.BaseModelHandler;
+import org.apache.beam.sdk.ml.remoteinference.base.PredictionResult;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Model handler for OpenAI API inference requests.
+ *
+ * <p>This handler manages communication with OpenAI's API, including client initialization,
+ * request formatting, and response parsing. It uses OpenAI's structured output feature to
+ * ensure reliable input-output pairing.
+ *
+ * <h3>Usage</h3>
+ * <pre>{@code
+ * OpenAIModelParameters params = OpenAIModelParameters.builder()
+ *     .apiKey("sk-...")
+ *     .modelName("gpt-4")
+ *     .instructionPrompt("Classify the following text into one of the categories: {CATEGORIES}")
+ *     .build();
+ *
+ * PCollection<OpenAIModelInput> inputs = ...;
+ * PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results =
+ *     inputs.apply(
+ *         RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke()
+ *             .handler(OpenAIModelHandler.class)
+ *             .withParameters(params)
+ *     );
+ * }</pre>
+ *
+ */
+public class OpenAIModelHandler
+  implements BaseModelHandler<OpenAIModelParameters, OpenAIModelInput, OpenAIModelResponse> {
+
+  private transient OpenAIClient client;
+  private transient StructuredResponseCreateParams<StructuredInputOutput> clientParams;
+  private OpenAIModelParameters modelParameters;
+
+  /**
+   * Initializes the OpenAI client with the provided parameters.
+   *
+   * <p>This method is called once during setup. It creates an authenticated
+   * OpenAI client using the API key from the parameters.
+   *
+   * @param parameters the configuration parameters including API key and model name
+   */
+  @Override
+  public void createClient(OpenAIModelParameters parameters) {
+    this.modelParameters = parameters;
+    this.client = OpenAIOkHttpClient.builder()
+      .apiKey(this.modelParameters.getApiKey())
+      .build();
+  }
+
+  /**
+   * Performs inference on a batch of inputs using the OpenAI Client.
+   *
+   * <p>This method serializes the input batch to JSON string, sends it to OpenAI with structured
+   * output requirements, and parses the response into {@link PredictionResult} objects
+   * that pair each input with its corresponding output.
+   *
+   * @param input the list of inputs to process
+   * @return an iterable of model results and input pairs
+   */
+  @Override
+  public Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> request(List<OpenAIModelInput> input) {
+
+    try {
+      // Convert input list to JSON string
+      String inputBatch = new ObjectMapper()
+        .writeValueAsString(input.stream().map(OpenAIModelInput::getModelInput).toList());

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   A new `ObjectMapper` instance is created for every call to the `request` 
method. `ObjectMapper` is a heavy object to create, and it is thread-safe. 
Creating it repeatedly in a hot path like this can negatively impact 
performance.
   
   You should create the `ObjectMapper` instance once and reuse it. A good place to initialize it is the `createClient` method, storing the instance in a `transient` field.
   
   1. Add a field to `OpenAIModelHandler`:
   ```java
   private transient ObjectMapper objectMapper;
   ```
   
   2. Initialize it in `createClient`:
   ```java
   @Override
   public void createClient(OpenAIModelParameters parameters) {
     this.modelParameters = parameters;
     this.client = OpenAIOkHttpClient.builder()
       .apiKey(this.modelParameters.getApiKey())
       .build();
     this.objectMapper = new ObjectMapper();
   }
   ```
   
   3. Use the field in `request`:
   
   ```suggestion
      String inputBatch = objectMapper
        .writeValueAsString(input.stream().map(OpenAIModelInput::getModelInput).toList());
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

