derrickaw commented on code in PR #38772:
URL: https://github.com/apache/beam/pull/38772#discussion_r3494667103


##########
sdks/java/io/mongodb/src/test/java/org/apache/beam/sdk/io/mongodb/MongoDbReadSchemaTransformProviderTest.java:
##########
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.mongodb;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertThrows;
+
+import java.nio.charset.StandardCharsets;
+import java.util.Collections;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.SchemaRegistry;
+import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionRowTuple;
+import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.sdk.values.TupleTagList;
+import org.bson.Document;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link MongoDbReadSchemaTransformProvider}. */
+@RunWith(JUnit4.class)
+public class MongoDbReadSchemaTransformProviderTest {
+
+  @Rule public transient TestPipeline p = TestPipeline.create();
+
+  @Test
+  public void testInvalidConfigMissingUri() {
+    assertThrows(
+        IllegalStateException.class,
+        () -> {
+          MongoDbReadSchemaTransformConfiguration.builder()
+              .setDatabase("db")
+              .setCollection("col")
+              .setSchema("{}")
+              .build()
+              .validate();
+        });
+  }
+
+  @Test
+  public void testInvalidConfigMissingDatabase() {
+    assertThrows(
+        IllegalStateException.class,
+        () -> {
+          MongoDbReadSchemaTransformConfiguration.builder()
+              .setUri("mongodb://localhost:27017")
+              .setCollection("col")
+              .setSchema("{}")
+              .build()
+              .validate();
+        });
+  }
+
+  @Test
+  public void testInvalidConfigMissingCollection() {
+    assertThrows(
+        IllegalStateException.class,
+        () -> {
+          MongoDbReadSchemaTransformConfiguration.builder()
+              .setUri("mongodb://localhost:27017")
+              .setDatabase("db")
+              .setSchema("{}")
+              .build()
+              .validate();
+        });
+  }
+
+  @Test
+  public void testInvalidConfigMissingSchema() {
+    assertThrows(
+        IllegalStateException.class,
+        () -> {
+          MongoDbReadSchemaTransformConfiguration.builder()
+              .setUri("mongodb://localhost:27017")
+              .setDatabase("db")
+              .setCollection("col")
+              .build()
+              .validate();
+        });
+  }
+
+  @Test
+  public void testConfigurationSchema() throws Exception {
+    Schema schema =
+        SchemaRegistry.createDefault().getSchema(MongoDbReadSchemaTransformConfiguration.class);
+
+    // We expect 6 fields: uri, database, collection, schema, filter, errorHandling
+    assertEquals(6, schema.getFieldCount());
+    assertNotNull(schema.getField("uri"));
+    assertNotNull(schema.getField("database"));
+    assertNotNull(schema.getField("collection"));
+    assertNotNull(schema.getField("schema"));
+    assertNotNull(schema.getField("filter"));
+    assertNotNull(schema.getField("errorHandling"));
+  }
+
+  @Test
+  public void testExpandWithFilter() {
+    MongoDbReadSchemaTransformConfiguration config =
+        MongoDbReadSchemaTransformConfiguration.builder()
+            .setUri("mongodb://localhost:27017")
+            .setDatabase("db")
+            .setCollection("col")
+            .setSchema("{\"type\": \"object\", \"properties\": {\"name\": {\"type\": \"string\"}}}")
+            .setFilter("{\"name\": \"John\"}")
+            .build();
+
+    MongoDbReadSchemaTransformProvider provider = new MongoDbReadSchemaTransformProvider();
+    PCollectionRowTuple output =
+        provider.from(config).expand(PCollectionRowTuple.empty(Pipeline.create()));
+
+    assertNotNull(output.get("output"));
+  }
+
+  @Test
+  public void testDocumentToRowFn() {
+    Schema beamSchema = Schema.builder().addStringField("name").addInt32Field("age").build();
+
+    Document doc = new Document().append("name", "John").append("age", 30);
+
+    PCollection<Document> inputDocs =
+        p.apply(
+            Create.of(Collections.singletonList(doc))
+                .withCoder(MongoDbWriteSchemaTransformProvider.DocumentCoder.of()));
+
+    Schema errorSchema = ErrorHandling.errorSchemaBytes();
+    PCollectionTuple outputTuple =
+        inputDocs.apply(
+            "ConvertToRows",
+            ParDo.of(
+                    new MongoDbReadSchemaTransformProvider.DocumentToRowFn(
+                        beamSchema, false, errorSchema))
+                .withOutputTags(
+                    MongoDbReadSchemaTransformProvider.OUTPUT_TAG,
+                    TupleTagList.of(MongoDbReadSchemaTransformProvider.ERROR_TAG)));
+
+    PCollection<Row> outputRows =
+        outputTuple.get(MongoDbReadSchemaTransformProvider.OUTPUT_TAG).setRowSchema(beamSchema);
+    outputTuple.get(MongoDbReadSchemaTransformProvider.ERROR_TAG).setRowSchema(errorSchema);
+
+    PAssert.that(outputRows)
+        .satisfies(
+            rows -> {
+              Row row = rows.iterator().next();
+              assertEquals("John", row.getString("name"));
+              assertEquals(Integer.valueOf(30), row.getInt32("age"));
+              return null;
+            });
+
+    p.run().waitUntilFinish();
+  }
+
+  @Test
+  public void testDocumentToRowFnWithErrors() {
+    Schema beamSchema = Schema.builder().addInt32Field("age").build();
+
+    // Invalid document: age value is a string "not_an_int" which cannot be converted to INT32
+    Document invalidDoc = new Document().append("age", "not_an_int");
+
+    PCollection<Document> inputDocs =
+        p.apply(
+            Create.of(Collections.singletonList(invalidDoc))
+                .withCoder(MongoDbWriteSchemaTransformProvider.DocumentCoder.of()));
+
+    Schema errorSchema = ErrorHandling.errorSchemaBytes();
+    PCollectionTuple outputTuple =
+        inputDocs.apply(
+            "ConvertToRowsWithErrors",
+            ParDo.of(
+                    new MongoDbReadSchemaTransformProvider.DocumentToRowFn(
+                        beamSchema, true, errorSchema))
+                .withOutputTags(
+                    MongoDbReadSchemaTransformProvider.OUTPUT_TAG,
+                    TupleTagList.of(MongoDbReadSchemaTransformProvider.ERROR_TAG)));
+
+    PCollection<Row> errorRows =
+        outputTuple.get(MongoDbReadSchemaTransformProvider.ERROR_TAG).setRowSchema(errorSchema);
+    outputTuple.get(MongoDbReadSchemaTransformProvider.OUTPUT_TAG).setRowSchema(beamSchema);
+
+    PAssert.that(errorRows)

Review Comment:
   Added.



##########
sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbUtils.java:
##########
@@ -71,4 +76,127 @@ public static Document toDocument(Row row) {
     }
     return value;
   }
+
+  /**
+   * Converts a BSON {@link Document} (or any Map representing fields) to a Beam {@link Row}
+   * matching the given {@link Schema}.
+   */
+  public static Row toRow(Map<?, ?> doc, Schema schema) {
+    Row.Builder rowBuilder = Row.withSchema(schema);
+    for (Field field : schema.getFields()) {
+      Object value = doc.get(field.getName());
+      rowBuilder.addValue(convertFromBsonValue(value, field.getType()));
+    }
+    return rowBuilder.build();
+  }
+
+  @SuppressWarnings("JavaUtilDate")
+  private static @Nullable Object convertFromBsonValue(
+      @Nullable Object value, FieldType fieldType) {
+    if (value == null || value instanceof BsonNull) {
+      return null;
+    }
+
+    switch (fieldType.getTypeName()) {
+      case BYTE:
+        return (value instanceof Number)
+            ? ((Number) value).byteValue()
+            : Byte.parseByte(value.toString());
+      case INT16:
+        return (value instanceof Number)
+            ? ((Number) value).shortValue()
+            : Short.parseShort(value.toString());
+      case INT32:
+        return (value instanceof Number)
+            ? ((Number) value).intValue()
+            : Integer.parseInt(value.toString());
+      case INT64:
+        return (value instanceof Number)
+            ? ((Number) value).longValue()
+            : Long.parseLong(value.toString());
+      case FLOAT:
+        return (value instanceof Number)
+            ? ((Number) value).floatValue()
+            : Float.parseFloat(value.toString());
+      case DOUBLE:
+        return (value instanceof Number)
+            ? ((Number) value).doubleValue()
+            : Double.parseDouble(value.toString());
+      case DECIMAL:
+        return (value instanceof Number)
+            ? java.math.BigDecimal.valueOf(((Number) value).doubleValue())
+            : new java.math.BigDecimal(value.toString());
+      case STRING:
+        return value.toString();
+      case BOOLEAN:
+        return (value instanceof Boolean)
+            ? (Boolean) value
+            : Boolean.parseBoolean(value.toString());
+      case DATETIME:
+        if (value instanceof java.util.Date) {
+          return new Instant(((java.util.Date) value).getTime());
+        } else if (value instanceof Number) {
+          return new Instant(((Number) value).longValue());
+        } else {
+          return Instant.parse(value.toString());
+        }
+      case BYTES:
+        if (value instanceof Binary) {
+          return ((Binary) value).getData();
+        } else if (value instanceof byte[]) {
+          return (byte[]) value;
+        } else {
+          return value.toString().getBytes(java.nio.charset.StandardCharsets.UTF_8);
+        }
+      case ARRAY:
+      case ITERABLE:
+        if (!(value instanceof Iterable)) {
+          throw new IllegalArgumentException(
+              "Expected Iterable for type "
+                  + fieldType
+                  + ", but got: "
+                  + value.getClass().getName());
+        }
+        Iterable<?> iterable = (Iterable<?>) value;
+        List<@Nullable Object> rowList = new ArrayList<>();
+        FieldType elementType = fieldType.getCollectionElementType();
+        if (elementType == null) {
+          throw new IllegalArgumentException(
+              "Collection element type cannot be null for type: " + fieldType);
+        }
+        for (Object item : iterable) {
+          rowList.add(convertFromBsonValue(item, elementType));
+        }
+        return rowList;
+      case MAP:
+        if (!(value instanceof Map)) {
+          throw new IllegalArgumentException(
+              "Expected Map for type " + fieldType + ", but got: " + value.getClass().getName());
+        }
+        Map<?, ?> map = (Map<?, ?>) value;
+        Map<String, @Nullable Object> rowMap = new HashMap<>();
+        FieldType valueType = fieldType.getMapValueType();
+        if (valueType == null) {
+          throw new IllegalArgumentException(
+              "Map value type cannot be null for type: " + fieldType);
+        }
+        for (Map.Entry<?, ?> entry : map.entrySet()) {
+          rowMap.put(
+              String.valueOf(entry.getKey()), convertFromBsonValue(entry.getValue(), valueType));
+        }
+        return rowMap;
+      case ROW:
+        Schema rowSchema = fieldType.getRowSchema();
+        if (rowSchema == null) {
+          throw new IllegalArgumentException("Row schema cannot be null for type: " + fieldType);
+        }
+        if (value instanceof Map) {
+          return toRow((Map<?, ?>) value, rowSchema);
+        } else {
+          throw new IllegalArgumentException("Cannot convert value to Row: " + value);

Review Comment:
   Done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to