This is an automated email from the ASF dual-hosted git repository.
yhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 22097177411 Adding error handler for SpannerReadSchemaTransformProvider and missi… (#35241)
22097177411 is described below
commit 2209717741195e9f7969265f5ee618c3d78c2a08
Author: Tanu Sharma <[email protected]>
AuthorDate: Tue Jun 17 03:16:15 2025 +0530
Adding error handler for SpannerReadSchemaTransformProvider and missi… (#35241)
* Adding error handler for SpannerReadSchemaTransformProvider and missing tests for SpannerSchemaTransformProvider
* Removed unused logging
* Spotless Apply
* Spotless Apply
* Spotless Apply
* Typo correction
---
.../SpannerReadSchemaTransformProvider.java | 92 +++++-
.../SpannerSchemaTransformProviderTest.java | 316 +++++++++++++++++++++
sdks/python/apache_beam/yaml/standard_io.yaml | 1 +
3 files changed, 402 insertions(+), 7 deletions(-)
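
Before the diff, a minimal usage sketch of the new error handling path from Java. This is illustrative only: the class name, pipeline setup, and project/instance/database IDs are placeholders, and it assumes the ErrorHandling builder API referenced by the diff below. The sketch lives in the provider's package, as the new tests do, so it can call the typed from() overload.

    package org.apache.beam.sdk.io.gcp.spanner; // same package as the provider, like the tests below

    import org.apache.beam.sdk.Pipeline;
    import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
    import org.apache.beam.sdk.values.PCollection;
    import org.apache.beam.sdk.values.PCollectionRowTuple;
    import org.apache.beam.sdk.values.Row;

    public class SpannerReadWithErrorsSketch {
      public static void main(String[] args) {
        Pipeline pipeline = Pipeline.create();
        PCollectionRowTuple result =
            PCollectionRowTuple.empty(pipeline)
                .apply(
                    new SpannerReadSchemaTransformProvider()
                        .from(
                            SpannerReadSchemaTransformProvider
                                .SpannerReadSchemaTransformConfiguration.builder()
                                .setProjectId("my-project") // placeholder IDs throughout
                                .setInstanceId("my-instance")
                                .setDatabaseId("my-database")
                                .setQuery("SELECT * FROM users")
                                // Route rows that fail Struct-to-Row conversion to an
                                // "errors" output instead of failing the pipeline.
                                .setErrorHandling(
                                    ErrorHandling.builder().setOutput("errors").build())
                                .build()));
        PCollection<Row> rows = result.get("output");
        // Error rows use the schema {error: STRING, row: STRING} defined below.
        PCollection<Row> errors = result.get("errors");
        pipeline.run().waitUntilFinish();
      }
    }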
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadSchemaTransformProvider.java
index 0bcf6e0c4f7..73f4c2dfe30 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadSchemaTransformProvider.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerReadSchemaTransformProvider.java
@@ -17,6 +17,7 @@
*/
package org.apache.beam.sdk.io.gcp.spanner;
+import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull;
@@ -24,9 +25,12 @@ import com.google.auto.service.AutoService;
import com.google.auto.value.AutoValue;
import com.google.cloud.spanner.Struct;
import java.io.Serializable;
+import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import javax.annotation.Nullable;
+import org.apache.beam.sdk.metrics.Counter;
+import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.schemas.AutoValueSchema;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
@@ -34,11 +38,19 @@ import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider;
-import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.DoFn.FinishBundle;
+import org.apache.beam.sdk.transforms.DoFn.FinishBundleContext;
+import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver;
+import org.apache.beam.sdk.transforms.DoFn.ProcessElement;
+import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
+import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.Row;
-import org.apache.beam.sdk.values.TypeDescriptor;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings;
/** A provider for reading from Cloud Spanner using a Schema Transform Provider. */
@@ -61,6 +73,11 @@ public class SpannerReadSchemaTransformProvider
extends TypedSchemaTransformProvider<
SpannerReadSchemaTransformProvider.SpannerReadSchemaTransformConfiguration> {
+ public static final TupleTag<Row> OUTPUT_TAG = new TupleTag<Row>() {};
+ public static final TupleTag<Row> ERROR_TAG = new TupleTag<Row>() {};
+ public static final Schema ERROR_SCHEMA =
+ Schema.builder().addStringField("error").addStringField("row").build();
+
@Override
public String identifier() {
return "beam:schematransform:org.apache.beam:spanner_read:v1";
@@ -133,6 +150,7 @@ public class SpannerReadSchemaTransformProvider
@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
checkNotNull(input, "Input to SpannerReadSchemaTransform cannot be null.");
+ boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling());
SpannerIO.Read read =
SpannerIO.readWithSchema()
.withProjectId(configuration.getProjectId())
@@ -152,12 +170,66 @@ public class SpannerReadSchemaTransformProvider
}
PCollection<Struct> spannerRows = input.getPipeline().apply(read);
Schema schema = spannerRows.getSchema();
- PCollection<Row> rows =
+
+ PCollectionTuple outputTuple =
spannerRows.apply(
- MapElements.into(TypeDescriptor.of(Row.class))
- .via((Struct struct) -> StructUtils.structToBeamRow(struct, schema)));
+ ParDo.of(
+ new ErrorFn("spanner-read-error-counter", ERROR_SCHEMA, schema, handleErrors))
+ .withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));
+
+ PCollectionRowTuple outputRows =
+ PCollectionRowTuple.of("output", outputTuple.get(OUTPUT_TAG).setRowSchema(schema));
+
+ // Error handling
+ PCollection<Row> errorOutput = outputTuple.get(ERROR_TAG).setRowSchema(ERROR_SCHEMA);
+ if (handleErrors) {
+ outputRows =
+ outputRows.and(
+ checkArgumentNotNull(configuration.getErrorHandling()).getOutput(), errorOutput);
+ }
- return PCollectionRowTuple.of("output", rows.setRowSchema(schema));
+ return outputRows;
+ }
+ }
+
+ public static class ErrorFn extends DoFn<Struct, Row> {
+ private final Counter errorCounter;
+ private Long errorsInBundle = 0L;
+ private final boolean handleErrors;
+ private final Schema errorSchema;
+ private final Schema schema;
+
+ public ErrorFn(String name, Schema errorSchema, Schema schema, boolean handleErrors) {
+ this.errorCounter = Metrics.counter(SpannerReadSchemaTransformProvider.class, name);
+ this.handleErrors = handleErrors;
+ this.errorSchema = errorSchema;
+ this.schema = schema;
+ }
+
+ @ProcessElement
+ public void processElement(@DoFn.Element Struct struct, MultiOutputReceiver receiver) {
+ Row mappedRow = null;
+ try {
+ mappedRow = StructUtils.structToBeamRow(struct, schema);
+ } catch (Exception e) {
+ if (!handleErrors) {
+ throw new RuntimeException(e);
+ }
+ errorsInBundle += 1;
+ receiver
+ .get(ERROR_TAG)
+ .output(
+ Row.withSchema(errorSchema).addValues(e.getMessage(), struct.toString()).build());
+ }
+ if (mappedRow != null) {
+ receiver.get(OUTPUT_TAG).output(mappedRow);
+ }
+ }
+
+ @FinishBundle
+ public void finish(FinishBundleContext c) {
+ errorCounter.inc(errorsInBundle);
+ errorsInBundle = 0L;
}
}
@@ -168,7 +240,7 @@ public class SpannerReadSchemaTransformProvider
@Override
public List<String> outputCollectionNames() {
- return Collections.singletonList("output");
+ return Arrays.asList("output", "errors");
}
@DefaultSchema(AutoValueSchema.class)
@@ -193,6 +265,8 @@ public class SpannerReadSchemaTransformProvider
public abstract Builder setBatching(Boolean batching);
+ public abstract Builder setErrorHandling(ErrorHandling errorHandling);
+
public abstract SpannerReadSchemaTransformConfiguration build();
}
@@ -261,6 +335,10 @@ public class SpannerReadSchemaTransformProvider
"Set to false to disable batching. Useful when using a query that is
not compatible with the PartitionQuery API. Defaults to true.")
@Nullable
public abstract Boolean getBatching();
+
+ @SchemaFieldDescription("This option specifies whether and where to output failed rows.")
+ @Nullable
+ public abstract ErrorHandling getErrorHandling();
}
@Override
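
For clarity between the two files: rows routed to the error output are plain Beam Rows with the two string fields of ERROR_SCHEMA. A small illustrative sketch (the literal values are made up):

    // Shape of a row emitted on the "errors" output; values are illustrative.
    Row errorRow =
        Row.withSchema(SpannerReadSchemaTransformProvider.ERROR_SCHEMA)
            .addValues("Field not found: id", "[bad_value]")
            .build();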
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerSchemaTransformProviderTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerSchemaTransformProviderTest.java
new file mode 100644
index 00000000000..00308ef214b
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerSchemaTransformProviderTest.java
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.spanner;
+
+import static org.apache.beam.sdk.io.gcp.spanner.SpannerReadSchemaTransformProvider.SpannerReadSchemaTransformConfiguration;
+import static org.apache.beam.sdk.io.gcp.spanner.SpannerWriteSchemaTransformProvider.SpannerWriteSchemaTransformConfiguration;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThrows;
+
+import com.google.cloud.spanner.Struct;
+import com.google.cloud.spanner.Value;
+import java.util.Arrays;
+import java.util.List;
+import java.util.ServiceLoader;
+import java.util.stream.Collectors;
+import java.util.stream.StreamSupport;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TupleTagList;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link org.apache.beam.sdk.io.gcp.spanner.SpannerReadSchemaTransformProvider} and
+ * {@link org.apache.beam.sdk.io.gcp.spanner.SpannerWriteSchemaTransformProvider}.
+ */
+@RunWith(JUnit4.class)
+public class SpannerSchemaTransformProviderTest {
+
+ @Rule
+ public final transient TestPipeline pipeline =
+ TestPipeline.create().enableAbandonedNodeEnforcement(false);
+
+ @Test
+ public void testSpannerWriteValidations() {
+
+ // Missing instanceId should throw IllegalStateException
+ assertThrows(
+ IllegalStateException.class,
+ () -> {
+ SpannerWriteSchemaTransformConfiguration.builder()
+ .setDatabaseId("validDatabase")
+ .setTableId("validTable")
+ .build()
+ .validate();
+ });
+
+ // Missing databaseId should throw IllegalStateException
+ assertThrows(
+ IllegalStateException.class,
+ () -> {
+ SpannerWriteSchemaTransformConfiguration.builder()
+ .setInstanceId("validInstance")
+ .setTableId("validTable")
+ .build()
+ .validate();
+ });
+
+ // Missing tableId should throw IllegalStateException
+ assertThrows(
+ IllegalStateException.class,
+ () -> {
+ SpannerWriteSchemaTransformConfiguration.builder()
+ .setInstanceId("validInstance")
+ .setDatabaseId("validDatabase")
+ .build()
+ .validate();
+ });
+
+ // Valid config should NOT throw any exceptions
+ SpannerWriteSchemaTransformConfiguration.builder()
+ .setInstanceId("validInstance")
+ .setDatabaseId("validDatabase")
+ .setTableId("validTable")
+ .build()
+ .validate();
+ }
+
+ @Test
+ public void testSpannerReadValidations() {
+
+ // Missing instanceId should throw
+ assertThrows(
+ IllegalStateException.class,
+ () -> {
+ SpannerReadSchemaTransformConfiguration.builder()
+ .setDatabaseId("db")
+ .setTableId("table")
+ .build()
+ .validate();
+ });
+
+ // Missing databaseId should throw
+ assertThrows(
+ IllegalStateException.class,
+ () -> {
+ SpannerReadSchemaTransformConfiguration.builder()
+ .setInstanceId("instance")
+ .setTableId("table")
+ .build()
+ .validate();
+ });
+
+ // Missing both tableId and query should throw
+ assertThrows(
+ IllegalArgumentException.class,
+ () -> {
+ SpannerReadSchemaTransformConfiguration.builder()
+ .setInstanceId("instance")
+ .setDatabaseId("db")
+ .build()
+ .validate();
+ });
+
+ // TableId without columns should throw
+ assertThrows(
+ IllegalArgumentException.class,
+ () -> {
+ SpannerReadSchemaTransformConfiguration.builder()
+ .setInstanceId("instance")
+ .setDatabaseId("db")
+ .setTableId("table")
+ .build()
+ .validate();
+ });
+
+ // Valid table-based read config
+ SpannerReadSchemaTransformConfiguration.builder()
+ .setInstanceId("instance")
+ .setDatabaseId("db")
+ .setTableId("table")
+ .setColumns(Arrays.asList("col1"))
+ .build()
+ .validate();
+
+ // Valid query-based config
+ SpannerReadSchemaTransformConfiguration.builder()
+ .setInstanceId("instance")
+ .setDatabaseId("db")
+ .setQuery("SELECT * FROM table")
+ .build()
+ .validate();
+
+ // Providing both query and tableId should throw, since they are mutually exclusive
+ assertThrows(
+ IllegalArgumentException.class,
+ () -> {
+ SpannerReadSchemaTransformConfiguration.builder()
+ .setInstanceId("instance")
+ .setDatabaseId("db")
+ .setQuery("SELECT * FROM table")
+ .setTableId("table")
+ .build()
+ .validate();
+ });
+ }
+
+ @Test
+ public void testReadBuildTransform() {
+ SpannerReadSchemaTransformProvider provider = new SpannerReadSchemaTransformProvider();
+ provider.from(
+ SpannerReadSchemaTransformConfiguration.builder()
+ .setProjectId("test-project")
+ .setInstanceId("test-instance")
+ .setDatabaseId("test-database")
+ .setQuery("SELECT * FROM users")
+ .build());
+ }
+
+ @Test
+ public void testWriteBuildTransform() {
+ SpannerWriteSchemaTransformProvider provider = new SpannerWriteSchemaTransformProvider();
+ provider.from(
+ SpannerWriteSchemaTransformConfiguration.builder()
+ .setProjectId("test-project")
+ .setInstanceId("test-instance")
+ .setDatabaseId("test-database")
+ .setTableId("test-table")
+ .build());
+ }
+
+ @Test
+ public void testReadFindTransformAndMakeItWork() {
+ ServiceLoader<SchemaTransformProvider> serviceLoader =
+ ServiceLoader.load(SchemaTransformProvider.class);
+ List<SchemaTransformProvider> providers =
+ StreamSupport.stream(serviceLoader.spliterator(), false)
+ .filter(provider -> provider.getClass() == SpannerReadSchemaTransformProvider.class)
+ .collect(Collectors.toList());
+
+ assertFalse("SpannerReadSchemaTransformProvider not found",
providers.isEmpty());
+
+ SchemaTransformProvider spannerProvider = providers.get(0);
+
+ // Check expected output and input collection names (adjust if different)
+ assertEquals(Lists.newArrayList("output", "errors"), spannerProvider.outputCollectionNames());
+ assertEquals(Lists.newArrayList(), spannerProvider.inputCollectionNames());
+
+ assertEquals(
+ Sets.newHashSet(
+ "project_id",
+ "instance_id",
+ "database_id",
+ "table_id",
+ "query",
+ "columns",
+ "index",
+ "batching",
+ "error_handling"),
+ spannerProvider.configurationSchema().getFields().stream()
+ .map(field -> field.getName())
+ .collect(Collectors.toSet()));
+ }
+
+ @Test
+ public void testWriteFindTransformAndMakeItWork() {
+ ServiceLoader<SchemaTransformProvider> serviceLoader =
+ ServiceLoader.load(SchemaTransformProvider.class);
+
+ List<SchemaTransformProvider> providers =
+ StreamSupport.stream(serviceLoader.spliterator(), false)
+ .filter(provider -> provider.getClass() == SpannerWriteSchemaTransformProvider.class)
+ .collect(Collectors.toList());
+
+ assertFalse("SpannerWriteSchemaTransformProvider not found",
providers.isEmpty());
+
+ SchemaTransformProvider spannerWriteProvider = providers.get(0);
+
+ // Write transforms output to the 'post-write' and 'errors' collections
+ assertEquals(
+ Lists.newArrayList("post-write", "errors"), spannerWriteProvider.outputCollectionNames());
+
+ assertEquals(
+ Sets.newHashSet("instance_id", "database_id", "table_id", "project_id", "error_handling"),
+ spannerWriteProvider.configurationSchema().getFields().stream()
+ .map(field -> field.getName())
+ .collect(Collectors.toSet()));
+ }
+
+ @Test
+ public void testErrorFnCapturesStructFailureAsRow() {
+
+ Struct badStruct =
+ Struct.newBuilder().set("non_existing_field").to(Value.string("bad_value")).build();
+
+ // Define a mismatched schema (does not match the struct)
+ Schema expectedSchema = Schema.builder().addStringField("id").build(); // "id" not in struct
+
+ TupleTag<Row> outputTag = SpannerReadSchemaTransformProvider.OUTPUT_TAG;
+ TupleTag<Row> errorTag = SpannerReadSchemaTransformProvider.ERROR_TAG;
+
+ PCollection<Struct> spannerRows = pipeline.apply("CreateBadStruct", Create.of(badStruct));
+
+ PCollectionTuple result =
+ spannerRows.apply(
+ ParDo.of(
+ new SpannerReadSchemaTransformProvider.ErrorFn(
+ "test-counter",
+ SpannerReadSchemaTransformProvider.ERROR_SCHEMA,
+ expectedSchema,
+ true))
+ .withOutputTags(outputTag, TupleTagList.of(errorTag)));
+
+ result.get(outputTag).setRowSchema(expectedSchema);
+ PCollection<Row> errorOutput =
+ result.get(errorTag).setRowSchema(SpannerReadSchemaTransformProvider.ERROR_SCHEMA);
+
+ PAssert.that(errorOutput)
+ .satisfies(
+ rows -> {
+ if (!rows.iterator().hasNext()) {
+ throw new AssertionError("Expected at least one error row but got none.");
+ }
+
+ Row errorRow = rows.iterator().next();
+ String errorMsg = errorRow.getString("error");
+ String rowString = errorRow.getString("row");
+
+ assert errorMsg != null && errorMsg.contains("Field not found: id")
+ : "Missing expected error message. Got: " + errorMsg;
+ assert rowString != null && rowString.contains("bad_value")
+ : "Row string does not contain expected content. Got: " +
rowString;
+ return null;
+ });
+
+ pipeline.run().waitUntilFinish();
+ }
+}
diff --git a/sdks/python/apache_beam/yaml/standard_io.yaml b/sdks/python/apache_beam/yaml/standard_io.yaml
index 15d3ccd3dda..d3d6e8f8d02 100644
--- a/sdks/python/apache_beam/yaml/standard_io.yaml
+++ b/sdks/python/apache_beam/yaml/standard_io.yaml
@@ -332,6 +332,7 @@
columns: 'columns'
index: 'index'
batching: 'batching'
+ error_handling: 'error_handling'
'WriteToSpanner':
project: 'project_id'
instance: 'instance_id'
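
For reference, a hypothetical Beam YAML pipeline exercising the newly exposed option. This assumes the hunk above sits under the ReadFromSpanner mapping (the visible keys match its configuration) and that the error output is consumed like other Beam YAML error_handling outputs; all names and paths are placeholders:

    pipeline:
      transforms:
        - type: ReadFromSpanner
          name: ReadUsers
          config:
            project: my-project        # placeholder IDs throughout
            instance: my-instance
            database: my-database
            query: SELECT * FROM users
            error_handling:
              output: my_errors
        - type: WriteToJson
          name: WriteFailedRows
          input: ReadUsers.my_errors
          config:
            path: /tmp/spanner_read_errors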