Ganeshsivakumar commented on code in PR #36623:
URL: https://github.com/apache/beam/pull/36623#discussion_r2549842559


##########
sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandler.java:
##########
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.remoteinference.openai;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.annotation.JsonPropertyDescription;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.openai.client.OpenAIClient;
+import com.openai.client.okhttp.OpenAIOkHttpClient;
+import com.openai.core.JsonSchemaLocalValidation;
+import com.openai.models.responses.ResponseCreateParams;
+import com.openai.models.responses.StructuredResponseCreateParams;
+import org.apache.beam.sdk.ml.remoteinference.base.BaseModelHandler;
+import org.apache.beam.sdk.ml.remoteinference.base.PredictionResult;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Model handler for OpenAI API inference requests.
+ *
+ * <p>This handler manages communication with OpenAI's API, including client initialization,
+ * request formatting, and response parsing. It uses OpenAI's structured output feature to
+ * ensure reliable input-output pairing.
+ *
+ * <h3>Usage</h3>
+ * <pre>{@code
+ * OpenAIModelParameters params = OpenAIModelParameters.builder()
+ *     .apiKey("sk-...")
+ *     .modelName("gpt-4")
+ *     .instructionPrompt("Classify the following text into one of the categories: {CATEGORIES}")
+ *     .build();
+ *
+ * PCollection<OpenAIModelInput> inputs = ...;
+ * PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results =
+ *     inputs.apply(
+ *         RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke()
+ *             .handler(OpenAIModelHandler.class)
+ *             .withParameters(params)
+ *     );
+ * }</pre>
+ *
+ */
+public class OpenAIModelHandler
+  implements BaseModelHandler<OpenAIModelParameters, OpenAIModelInput, OpenAIModelResponse> {
+
+  private transient OpenAIClient client;
+  private transient StructuredResponseCreateParams<StructuredInputOutput> clientParams;
+  private OpenAIModelParameters modelParameters;
+
+  /**
+   * Initializes the OpenAI client with the provided parameters.
+   *
+   * <p>This method is called once during setup. It creates an authenticated
+   * OpenAI client using the API key from the parameters.
+   *
+   * @param parameters the configuration parameters including API key and model name
+   */
+  @Override
+  public void createClient(OpenAIModelParameters parameters) {
+    this.modelParameters = parameters;
+    this.client = OpenAIOkHttpClient.builder()
+      .apiKey(this.modelParameters.getApiKey())
+      .build();
+  }
+
+  /**
+   * Performs inference on a batch of inputs using the OpenAI client.
+   *
+   * <p>This method serializes the input batch to a JSON string, sends it to OpenAI with
+   * structured output requirements, and parses the response into {@link PredictionResult} objects
+   * that pair each input with its corresponding output.
+   *
+   * @param input the list of inputs to process
+   * @return an iterable of prediction results pairing each input with its model output
+   */
+  @Override
+  public Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> request(List<OpenAIModelInput> input) {
+
+    try {
+      // Convert input list to JSON string
+      String inputBatch = new ObjectMapper()
+        .writeValueAsString(input.stream().map(OpenAIModelInput::getModelInput).toList());

Review Comment:
   committed
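
   The diff excerpt above stops just after the batch is serialized, so the `StructuredInputOutput` and `Response` types referenced by the `clientParams` field (and imported in the IT below) are not visible here. Purely as an assumption about their shape, not the committed code, a minimal Jackson-annotated sketch might look like:

   ```java
   // Assumed shape only; the committed StructuredInputOutput/Response classes are
   // not visible in this excerpt and may differ.
   // Relies on imports already present in OpenAIModelHandler.java:
   //   com.fasterxml.jackson.annotation.JsonProperty
   //   com.fasterxml.jackson.annotation.JsonPropertyDescription
   //   java.util.List
   public static class StructuredInputOutput {
     @JsonPropertyDescription("One entry per element of the serialized input batch, in the same order")
     @JsonProperty(required = true)
     public List<Response> responses;
   }

   public static class Response {
     @JsonPropertyDescription("The original input text")
     @JsonProperty(required = true)
     public String input;

     @JsonPropertyDescription("The model's output for this input")
     @JsonProperty(required = true)
     public String output;
   }
   ```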



##########
sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java:
##########
@@ -0,0 +1,366 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.remoteinference.openai;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TypeDescriptor;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.junit.Assume.assumeNotNull;
+import static org.junit.Assume.assumeTrue;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.beam.sdk.ml.remoteinference.base.*;
+import org.apache.beam.sdk.ml.remoteinference.RemoteInference;
+import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler.StructuredInputOutput;
+import org.apache.beam.sdk.ml.remoteinference.openai.OpenAIModelHandler.Response;
+
+public class OpenAIModelHandlerIT {
+  private static final Logger LOG = LoggerFactory.getLogger(OpenAIModelHandlerIT.class);
+
+  @Rule
+  public final transient TestPipeline pipeline = TestPipeline.create();
+
+  private String apiKey;
+  private static final String API_KEY_ENV = "OPENAI_API_KEY";
+  private static final String DEFAULT_MODEL = "gpt-4o-mini";
+
+
+  @Before
+  public void setUp() {
+    // Get API key
+    apiKey = System.getenv(API_KEY_ENV);
+
+    // Skip tests if API key is not provided
+    assumeNotNull(
+      "OpenAI API key not found. Set " + API_KEY_ENV
+        + " environment variable to run integration tests.",
+      apiKey);
+    assumeTrue("OpenAI API key is empty. Set " + API_KEY_ENV
+        + " environment variable to run integration tests.",
+      !apiKey.trim().isEmpty());
+  }
+
+  @Test
+  public void testSentimentAnalysisWithSingleInput() {
+    String input = "This product is absolutely amazing! I love it!";
+
+    PCollection<OpenAIModelInput> inputs = pipeline
+      .apply("CreateSingleInput", Create.of(input))
+      .apply("MapToInput", MapElements
+        .into(TypeDescriptor.of(OpenAIModelInput.class))
+        .via(OpenAIModelInput::create));
+
+    PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results = inputs
+      .apply("SentimentInference",
+        RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke()
+          .handler(OpenAIModelHandler.class)
+          .withParameters(OpenAIModelParameters.builder()
+            .apiKey(apiKey)
+            .modelName(DEFAULT_MODEL)
+            .instructionPrompt(
+              "Analyze the sentiment as 'positive' or 'negative'. Return only 
one word.")
+            .build()));
+
+    // Verify results
+    PAssert.that(results).satisfies(batches -> {
+      int count = 0;
+      for (Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> batch : batches) {
+        for (PredictionResult<OpenAIModelInput, OpenAIModelResponse> result : batch) {
+          count++;
+          assertNotNull("Input should not be null", result.getInput());
+          assertNotNull("Output should not be null", result.getOutput());
+          assertNotNull("Output text should not be null",
+            result.getOutput().getModelResponse());
+
+          String sentiment = result.getOutput().getModelResponse().toLowerCase();
+          assertTrue("Sentiment should be positive or negative, got: " + sentiment,
+            sentiment.contains("positive")
+              || sentiment.contains("negative"));
+        }
+      }
+      assertEquals("Should have exactly 1 result", 1, count);
+      return null;
+    });
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  @Test
+  public void testSentimentAnalysisWithMultipleInputs() {
+    List<String> inputs = Arrays.asList(
+      "An excellent B2B SaaS solution that streamlines business processes 
efficiently.",
+      "The customer support is terrible. I've been waiting for days without 
any response.",
+      "The application works as expected. Installation was straightforward.",
+      "Really impressed with the innovative features! The AI capabilities are 
groundbreaking!",
+      "Mediocre product with occasional glitches. Documentation could be 
better.");
+
+    PCollection<OpenAIModelInput> inputCollection = pipeline
+      .apply("CreateMultipleInputs", Create.of(inputs))
+      .apply("MapToInputs", MapElements
+        .into(TypeDescriptor.of(OpenAIModelInput.class))
+        .via(OpenAIModelInput::create));
+
+    PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results = inputCollection
+      .apply("SentimentInference",
+        RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke()
+          .handler(OpenAIModelHandler.class)
+          .withParameters(OpenAIModelParameters.builder()
+            .apiKey(apiKey)
+            .modelName(DEFAULT_MODEL)
+            .instructionPrompt(
+              "Analyze sentiment as positive or negative")
+            .build()));
+
+    // Verify we get results for all inputs
+    PAssert.that(results).satisfies(batches -> {
+      int totalCount = 0;
+      for (Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> batch : batches) {
+        for (PredictionResult<OpenAIModelInput, OpenAIModelResponse> result : batch) {
+          totalCount++;
+          assertNotNull("Input should not be null", result.getInput());
+          assertNotNull("Output should not be null", result.getOutput());
+          assertFalse("Output should not be empty",
+            result.getOutput().getModelResponse().trim().isEmpty());
+        }
+      }
+      assertEquals("Should have results for all 5 inputs", 5, totalCount);
+      return null;
+    });
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  @Test
+  public void testTextClassification() {
+    List<String> inputs = Arrays.asList(
+      "How do I reset my password?",
+      "Your product is broken and I want a refund!",
+      "Thank you for the excellent service!");
+
+    PCollection<OpenAIModelInput> inputCollection = pipeline
+      .apply("CreateInputs", Create.of(inputs))
+      .apply("MapToInputs", MapElements
+        .into(TypeDescriptor.of(OpenAIModelInput.class))
+        .via(OpenAIModelInput::create));
+
+    PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results = inputCollection
+      .apply("ClassificationInference",
+        RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke()
+          .handler(OpenAIModelHandler.class)
+          .withParameters(OpenAIModelParameters.builder()
+            .apiKey(apiKey)
+            .modelName(DEFAULT_MODEL)
+            .instructionPrompt(
+              "Classify each text into one category: 'question', 'complaint', 
or 'praise'. Return only the category.")
+            .build()));
+
+    PAssert.that(results).satisfies(batches -> {
+      List<String> categories = new ArrayList<>();
+      for (Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> batch : batches) {
+        for (PredictionResult<OpenAIModelInput, OpenAIModelResponse> result : batch) {
+          String category = result.getOutput().getModelResponse().toLowerCase();
+          categories.add(category);
+        }
+      }
+
+      assertEquals("Should have 3 categories", 3, categories.size());
+
+      // Verify expected categories
+      boolean hasQuestion = categories.stream().anyMatch(c -> c.contains("question"));
+      boolean hasComplaint = categories.stream().anyMatch(c -> c.contains("complaint"));
+      boolean hasPraise = categories.stream().anyMatch(c -> c.contains("praise"));
+
+      assertTrue("Should have at least one recognized category",
+        hasQuestion || hasComplaint || hasPraise);
+
+      return null;
+    });
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  @Test
+  public void testInputOutputMapping() {
+    List<String> inputs = Arrays.asList("apple", "banana", "cherry");
+
+    PCollection<OpenAIModelInput> inputCollection = pipeline
+      .apply("CreateInputs", Create.of(inputs))
+      .apply("MapToInputs", MapElements
+        .into(TypeDescriptor.of(OpenAIModelInput.class))
+        .via(OpenAIModelInput::create));
+
+    PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results = inputCollection
+      .apply("MappingInference",
+        RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke()
+          .handler(OpenAIModelHandler.class)
+          .withParameters(OpenAIModelParameters.builder()
+            .apiKey(apiKey)
+            .modelName(DEFAULT_MODEL)
+            .instructionPrompt(
+              "Return the input word in uppercase")
+            .build()));
+
+    // Verify input-output pairing is preserved
+    PAssert.that(results).satisfies(batches -> {
+      for (Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> batch : batches) {
+        for (PredictionResult<OpenAIModelInput, OpenAIModelResponse> result : batch) {
+          String input = result.getInput().getModelInput();
+          String output = result.getOutput().getModelResponse().toLowerCase();
+
+          // Verify the output relates to the input
+          assertTrue("Output should relate to input '" + input + "', got: " + 
output,
+            output.contains(input.toLowerCase()));
+        }
+      }
+      return null;
+    });
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  @Test
+  public void testWithDifferentModel() {
+    // Test with a different model
+    String input = "Explain quantum computing in one sentence.";
+
+    PCollection<OpenAIModelInput> inputs = pipeline
+      .apply("CreateInput", Create.of(input))
+      .apply("MapToInput", MapElements
+        .into(TypeDescriptor.of(OpenAIModelInput.class))
+        .via(OpenAIModelInput::create));
+
+    PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results = inputs
+      .apply("DifferentModelInference",
+        RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke()
+          .handler(OpenAIModelHandler.class)
+          .withParameters(OpenAIModelParameters.builder()
+            .apiKey(apiKey)
+            .modelName("gpt-5")

Review Comment:
   valid model
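
   On the model choice: since the second model name is hard-coded in the test, an optional follow-up (not required for this PR) would be to let CI override it through an environment variable. A minimal sketch, assuming a hypothetical `OPENAI_TEST_ALT_MODEL` variable:

   ```java
   // Hypothetical helper for OpenAIModelHandlerIT; OPENAI_TEST_ALT_MODEL is an
   // assumed variable name and is not part of this PR.
   private static String altModelName() {
     String fromEnv = System.getenv("OPENAI_TEST_ALT_MODEL");
     return (fromEnv == null || fromEnv.trim().isEmpty()) ? "gpt-5" : fromEnv;
   }
   ```

   The test could then pass `.modelName(altModelName())` instead of the literal.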



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
