This is an automated email from the ASF dual-hosted git repository.
bhulette pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 2f2ffda [BEAM-10139][BEAM-10140] Add cross-language support for Java SpannerIO with python wrapper (#12611)
2f2ffda is described below
commit 2f2ffdafb341458fd626d1135cd87923231b31b8
Author: Piotr Szuberski <[email protected]>
AuthorDate: Mon Nov 16 19:12:22 2020 +0100
[BEAM-10139][BEAM-10140] Add cross-language support for Java SpannerIO with python wrapper (#12611)
* [BEAM-10139][BEAM-10140] Add Support for cross-language transforms with python wrapper to Java SpannerIO
* Change docstrings to use named parameters
* Remove mutation row concept from write operations
* exactStaleness -> staleness
* Deal with nullness
* Remove deprecated javadoc
---
.../expansion-service/build.gradle | 37 ++
.../beam/sdk/io/gcp/spanner/MutationUtils.java | 320 +++++++++++
.../beam/sdk/io/gcp/spanner/SpannerAccessor.java | 6 +
.../beam/sdk/io/gcp/spanner/SpannerConfig.java | 8 +
.../apache/beam/sdk/io/gcp/spanner/SpannerIO.java | 112 +++-
.../io/gcp/spanner/SpannerTransformRegistrar.java | 374 ++++++++++++
.../beam/sdk/io/gcp/spanner/StructUtils.java | 387 +++++++++++++
.../beam/sdk/io/gcp/spanner/MutationUtilsTest.java | 285 +++++++++
.../beam/sdk/io/gcp/spanner/StructUtilsTest.java | 258 +++++++++
sdks/python/apache_beam/io/gcp/spanner.py | 635 +++++++++++++++++++++
.../io/gcp/tests/xlang_spannerio_it_test.py | 339 +++++++++++
sdks/python/test-suites/portable/common.gradle | 2 +
settings.gradle | 1 +
13 files changed, 2758 insertions(+), 6 deletions(-)
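For context on how the new cross-language transforms are meant to be consumed, below is a minimal Python sketch. The transform names (ReadFromSpanner, SpannerInsert), their keyword parameters, and all identifiers are assumptions based on the wrapper added in sdks/python/apache_beam/io/gcp/spanner.py, which is not reproduced in this excerpt; treat it as an illustration, not the definitive API.

    # Hypothetical usage sketch; names below are assumptions, not taken from this diff.
    import typing

    import apache_beam as beam
    from apache_beam import coders
    from apache_beam.io.gcp.spanner import ReadFromSpanner, SpannerInsert


    class ExampleRow(typing.NamedTuple):
        f_int64: int
        f_string: str


    # The Java ReadBuilder needs an explicit schema; the wrapper derives it from
    # a NamedTuple row type registered with a RowCoder.
    coders.registry.register_coder(ExampleRow, coders.RowCoder)

    with beam.Pipeline() as p:
        _ = (
            p
            | beam.Create([ExampleRow(1, 'abc')])
            | SpannerInsert(
                project_id='my-project',
                instance_id='my-instance',
                database_id='my-db',
                table='some_table'))

    with beam.Pipeline() as p:
        rows = p | ReadFromSpanner(
            project_id='my-project',
            instance_id='my-instance',
            database_id='my-db',
            row_type=ExampleRow,
            sql='SELECT f_int64, f_string FROM some_table')

Both transforms expand against a Java expansion service such as the one added below, so a Java environment (or an explicit expansion service endpoint) is needed at pipeline construction time.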
diff --git a/sdks/java/io/google-cloud-platform/expansion-service/build.gradle b/sdks/java/io/google-cloud-platform/expansion-service/build.gradle
new file mode 100644
index 0000000..2d17997
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/expansion-service/build.gradle
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+apply plugin: 'org.apache.beam.module'
+apply plugin: 'application'
+mainClassName = "org.apache.beam.sdk.expansion.service.ExpansionService"
+
+applyJavaNature(
+ automaticModuleName: 'org.apache.beam.sdk.io.gcp.expansion.service',
+ exportJavadoc: false,
+ validateShadowJar: false,
+ shadowClosure: {},
+)
+
+description = "Apache Beam :: SDKs :: Java :: IO :: Google Cloud Platform :: Expansion Service"
+ext.summary = "Expansion service serving GCP Java IOs"
+
+dependencies {
+ compile project(":sdks:java:expansion-service")
+ compile project(":sdks:java:io:google-cloud-platform")
+ runtime library.java.slf4j_jdk14
+}
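The module above packages the GCP Java IOs behind org.apache.beam.sdk.expansion.service.ExpansionService. As a hedged illustration, a Python pipeline could point the wrapper at a manually started instance of this service; the transform name (SpannerDelete), the expansion_service parameter, and the endpoint below are assumptions, not shown in this excerpt.

    # Hypothetical sketch: target a locally running expansion service built from this module.
    import typing

    import apache_beam as beam
    from apache_beam import coders
    from apache_beam.io.gcp.spanner import SpannerDelete  # assumed transform name


    class KeyRow(typing.NamedTuple):  # hypothetical key columns
        f_int64: int


    coders.registry.register_coder(KeyRow, coders.RowCoder)

    with beam.Pipeline() as p:
        _ = (
            p
            | beam.Create([KeyRow(1)])
            | SpannerDelete(
                project_id='my-project',
                instance_id='my-instance',
                database_id='my-db',
                table='some_table',
                expansion_service='localhost:8097'))  # assumed endpoint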
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/MutationUtils.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/MutationUtils.java
index 4edbfa6..a84527c 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/MutationUtils.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/MutationUtils.java
@@ -17,8 +17,24 @@
*/
package org.apache.beam.sdk.io.gcp.spanner;
+import static java.util.stream.Collectors.toList;
+import static org.apache.beam.sdk.io.gcp.spanner.StructUtils.beamRowToStruct;
+import static org.apache.beam.sdk.io.gcp.spanner.StructUtils.beamTypeToSpannerType;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull;
+
+import com.google.cloud.ByteArray;
+import com.google.cloud.Timestamp;
+import com.google.cloud.spanner.Key;
import com.google.cloud.spanner.Mutation;
+import java.math.BigDecimal;
+import java.util.List;
+import java.util.stream.StreamSupport;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.ReadableDateTime;
final class MutationUtils {
private MutationUtils() {}
@@ -34,4 +50,308 @@ final class MutationUtils {
&& Iterables.isEmpty(m.getKeySet().getRanges())
&& Iterables.size(m.getKeySet().getKeys()) == 1;
}
+
+ /**
+ * Utility function to convert row to mutation for cross-language usage.
+ *
+ * @return function that can convert row to mutation
+ */
+ public static SerializableFunction<Row, Mutation> beamRowToMutationFn(
+ Mutation.Op operation, String table) {
+ return (row -> {
+ switch (operation) {
+ case INSERT:
+ return MutationUtils.createMutationFromBeamRows(Mutation.newInsertBuilder(table), row);
+ case DELETE:
+ return Mutation.delete(table, MutationUtils.createKeyFromBeamRow(row));
+ case UPDATE:
+ return MutationUtils.createMutationFromBeamRows(Mutation.newUpdateBuilder(table), row);
+ case REPLACE:
+ return MutationUtils.createMutationFromBeamRows(Mutation.newReplaceBuilder(table), row);
+ case INSERT_OR_UPDATE:
+ return MutationUtils.createMutationFromBeamRows(
+ Mutation.newInsertOrUpdateBuilder(table), row);
+ default:
+ throw new IllegalArgumentException(
+ String.format("Unknown mutation operation type: %s", operation));
+ }
+ });
+ }
+
+ private static Key createKeyFromBeamRow(Row row) {
+ Key.Builder builder = Key.newBuilder();
+ Schema schema = row.getSchema();
+ List<String> columns = schema.getFieldNames();
+ columns.forEach(
+ columnName ->
+ setBeamValueToKey(builder, schema.getField(columnName).getType(), columnName, row));
+ return builder.build();
+ }
+
+ private static Mutation createMutationFromBeamRows(
+ Mutation.WriteBuilder mutationBuilder, Row row) {
+ Schema schema = row.getSchema();
+ List<String> columns = schema.getFieldNames();
+ columns.forEach(
+ columnName ->
+ setBeamValueToMutation(
+ mutationBuilder, schema.getField(columnName).getType(), columnName, row));
+ return mutationBuilder.build();
+ }
+
+ private static void setBeamValueToKey(
+ Key.Builder keyBuilder, Schema.FieldType field, String columnName, Row row) {
+ switch (field.getTypeName()) {
+ case BYTE:
+ @Nullable Byte byteValue = row.getByte(columnName);
+ if (byteValue == null) {
+ keyBuilder.append((Long) null);
+ } else {
+ keyBuilder.append(byteValue);
+ }
+ break;
+ case INT16:
+ @Nullable Short int16 = row.getInt16(columnName);
+ if (int16 == null) {
+ keyBuilder.append((Long) null);
+ } else {
+ keyBuilder.append(int16);
+ }
+ break;
+ case INT32:
+ @Nullable Integer int32 = row.getInt32(columnName);
+ if (int32 == null) {
+ keyBuilder.append((Long) null);
+ } else {
+ keyBuilder.append(int32);
+ }
+ break;
+ case INT64:
+ keyBuilder.append(row.getInt64(columnName));
+ break;
+ case FLOAT:
+ @Nullable Float floatValue = row.getFloat(columnName);
+ if (floatValue == null) {
+ keyBuilder.append((Double) null);
+ } else {
+ keyBuilder.append(floatValue);
+ }
+ break;
+ case DOUBLE:
+ keyBuilder.append(row.getDouble(columnName));
+ break;
+ case DECIMAL:
+ keyBuilder.append(row.getDecimal(columnName));
+ break;
+ // TODO: Implement logical date and datetime
+ case DATETIME:
+ @Nullable ReadableDateTime dateTime = row.getDateTime(columnName);
+ if (dateTime == null) {
+ keyBuilder.append((Timestamp) null);
+ } else {
+ keyBuilder.append(
+ Timestamp.ofTimeMicroseconds(dateTime.toInstant().getMillis() * 1_000L));
+ }
+ break;
+ case BOOLEAN:
+ keyBuilder.append(row.getBoolean(columnName));
+ break;
+ case STRING:
+ keyBuilder.append(row.getString(columnName));
+ break;
+ case BYTES:
+ byte @Nullable [] bytes = row.getBytes(columnName);
+ if (bytes == null) {
+ keyBuilder.append((ByteArray) null);
+ } else {
+ keyBuilder.append(ByteArray.copyFrom(bytes));
+ }
+ break;
+ default:
+ throw new IllegalArgumentException(
+ String.format("Unsupported field type: %s", field.getTypeName()));
+ }
+ }
+
+ private static void setBeamValueToMutation(
+ Mutation.WriteBuilder mutationBuilder,
+ Schema.FieldType fieldType,
+ String columnName,
+ Row row) {
+ switch (fieldType.getTypeName()) {
+ case BYTE:
+ @Nullable Byte byteValue = row.getByte(columnName);
+ if (byteValue == null) {
+ mutationBuilder.set(columnName).to(((Long) null));
+ } else {
+ mutationBuilder.set(columnName).to(byteValue);
+ }
+ break;
+ case INT16:
+ @Nullable Short int16 = row.getInt16(columnName);
+ if (int16 == null) {
+ mutationBuilder.set(columnName).to(((Long) null));
+ } else {
+ mutationBuilder.set(columnName).to(int16);
+ }
+ break;
+ case INT32:
+ @Nullable Integer int32 = row.getInt32(columnName);
+ if (int32 == null) {
+ mutationBuilder.set(columnName).to(((Long) null));
+ } else {
+ mutationBuilder.set(columnName).to(int32);
+ }
+ break;
+ case INT64:
+ mutationBuilder.set(columnName).to(row.getInt64(columnName));
+ break;
+ case FLOAT:
+ @Nullable Float floatValue = row.getFloat(columnName);
+ if (floatValue == null) {
+ mutationBuilder.set(columnName).to(((Double) null));
+ } else {
+ mutationBuilder.set(columnName).to(floatValue);
+ }
+ break;
+ case DOUBLE:
+ mutationBuilder.set(columnName).to(row.getDouble(columnName));
+ break;
+ case DECIMAL:
+ @Nullable BigDecimal decimal = row.getDecimal(columnName);
+ // BigDecimal is not nullable
+ if (decimal == null) {
+ checkNotNull(decimal, "Null decimal at column " + columnName);
+ } else {
+ mutationBuilder.set(columnName).to(decimal);
+ }
+ break;
+ // TODO: Implement logical date and datetime
+ case DATETIME:
+ @Nullable ReadableDateTime dateTime = row.getDateTime(columnName);
+ if (dateTime == null) {
+ mutationBuilder.set(columnName).to(((Timestamp) null));
+ } else {
+ mutationBuilder
+ .set(columnName)
+ .to(Timestamp.ofTimeMicroseconds(dateTime.toInstant().getMillis() * 1000L));
+ }
+ break;
+ case BOOLEAN:
+ mutationBuilder.set(columnName).to(row.getBoolean(columnName));
+ break;
+ case STRING:
+ mutationBuilder.set(columnName).to(row.getString(columnName));
+ break;
+ case BYTES:
+ byte @Nullable [] bytes = row.getBytes(columnName);
+ if (bytes == null) {
+ mutationBuilder.set(columnName).to(((ByteArray) null));
+ } else {
+ mutationBuilder.set(columnName).to(ByteArray.copyFrom(bytes));
+ }
+ break;
+ case ROW:
+ @Nullable Row subRow = row.getRow(columnName);
+ if (subRow == null) {
+ mutationBuilder
+ .set(columnName)
+ .to(beamTypeToSpannerType(row.getSchema().getField(columnName).getType()), null);
+ } else {
+ mutationBuilder
+ .set(columnName)
+ .to(
+ beamTypeToSpannerType(row.getSchema().getField(columnName).getType()),
+ beamRowToStruct(subRow));
+ }
+ break;
+ case ARRAY:
+ addIterableToMutationBuilder(
+ mutationBuilder, row.getArray(columnName), row.getSchema().getField(columnName));
+ break;
+ case ITERABLE:
+ addIterableToMutationBuilder(
+ mutationBuilder, row.getIterable(columnName), row.getSchema().getField(columnName));
+ break;
+ default:
+ throw new IllegalArgumentException(
+ String.format("Unsupported field type: %s", fieldType.getTypeName()));
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ private static void addIterableToMutationBuilder(
+ Mutation.WriteBuilder mutationBuilder,
+ @Nullable Iterable<Object> iterable,
+ Schema.Field field) {
+ String column = field.getName();
+ Schema.FieldType beamIterableType = field.getType().getCollectionElementType();
+ if (beamIterableType == null) {
+ throw new NullPointerException("Null collection element type at field " + field.getName());
+ }
+ Schema.TypeName beamIterableTypeName = beamIterableType.getTypeName();
+ switch (beamIterableTypeName) {
+ case ROW:
+ if (iterable == null) {
+ mutationBuilder.set(column).toStructArray(beamTypeToSpannerType(beamIterableType), null);
+ } else {
+ mutationBuilder
+ .set(column)
+ .toStructArray(
+ beamTypeToSpannerType(beamIterableType),
+ StreamSupport.stream(iterable.spliterator(), false)
+ .map(row -> beamRowToStruct((Row) row))
+ .collect(toList()));
+ }
+ break;
+ case INT16:
+ case INT32:
+ case INT64:
+ case BYTE:
+ mutationBuilder.set(column).toInt64Array((Iterable<Long>) ((Object) iterable));
+ break;
+ case FLOAT:
+ case DOUBLE:
+ mutationBuilder.set(column).toFloat64Array((Iterable<Double>) ((Object) iterable));
+ break;
+ case DECIMAL:
+ mutationBuilder.set(column).toNumericArray((Iterable<BigDecimal>) ((Object) iterable));
+ break;
+ case BOOLEAN:
+ mutationBuilder.set(column).toBoolArray((Iterable<Boolean>) ((Object) iterable));
+ break;
+ case BYTES:
+ if (iterable == null) {
+ mutationBuilder.set(column).toBytesArray(null);
+ } else {
+ mutationBuilder
+ .set(column)
+ .toBytesArray(
+ StreamSupport.stream(iterable.spliterator(), false)
+ .map(object -> ByteArray.copyFrom((byte[]) object))
+ .collect(toList()));
+ }
+ break;
+ case STRING:
+ mutationBuilder.set(column).toStringArray((Iterable<String>) ((Object) iterable));
+ break;
+ case DATETIME:
+ if (iterable == null) {
+ mutationBuilder.set(column).toDateArray(null);
+ } else {
+ mutationBuilder
+ .set(column)
+ .toTimestampArray(
+ StreamSupport.stream(iterable.spliterator(), false)
+ .map(datetime -> Timestamp.parseTimestamp((datetime).toString()))
+ .collect(toList()));
+ }
+ break;
+ default:
+ throw new IllegalArgumentException(
+ String.format(
+ "Unsupported iterable type '%s' while translating row to
struct.",
+ beamIterableType.getTypeName()));
+ }
+ }
}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessor.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessor.java
index 63e85ea..a22b5b9 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessor.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerAccessor.java
@@ -20,6 +20,7 @@ package org.apache.beam.sdk.io.gcp.spanner;
import com.google.api.gax.retrying.RetrySettings;
import com.google.api.gax.rpc.FixedHeaderProvider;
import com.google.api.gax.rpc.UnaryCallSettings;
+import com.google.cloud.NoCredentials;
import com.google.cloud.ServiceFactory;
import com.google.cloud.spanner.BatchClient;
import com.google.cloud.spanner.DatabaseAdminClient;
@@ -134,6 +135,11 @@ class SpannerAccessor implements AutoCloseable {
if (host != null) {
builder.setHost(host.get());
}
+ ValueProvider<String> emulatorHost = spannerConfig.getEmulatorHost();
+ if (emulatorHost != null) {
+ builder.setEmulatorHost(emulatorHost.get());
+ builder.setCredentials(NoCredentials.getInstance());
+ }
String userAgentString = USER_AGENT_PREFIX + "/" + ReleaseInfo.getReleaseInfo().getVersion();
builder.setHeaderProvider(FixedHeaderProvider.create("user-agent", userAgentString));
SpannerOptions options = builder.build();
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java
index 5384547..6ca0a1e 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerConfig.java
@@ -51,6 +51,8 @@ public abstract class SpannerConfig implements Serializable {
public abstract @Nullable ValueProvider<String> getHost();
+ public abstract @Nullable ValueProvider<String> getEmulatorHost();
+
public abstract @Nullable ValueProvider<Duration> getCommitDeadline();
public abstract @Nullable ValueProvider<Duration> getMaxCumulativeBackoff();
@@ -107,6 +109,8 @@ public abstract class SpannerConfig implements Serializable {
abstract Builder setHost(ValueProvider<String> host);
+ abstract Builder setEmulatorHost(ValueProvider<String> emulatorHost);
+
abstract Builder setCommitDeadline(ValueProvider<Duration> commitDeadline);
abstract Builder setMaxCumulativeBackoff(ValueProvider<Duration> maxCumulativeBackoff);
@@ -144,6 +148,10 @@ public abstract class SpannerConfig implements Serializable {
return toBuilder().setHost(host).build();
}
+ public SpannerConfig withEmulatorHost(ValueProvider<String> emulatorHost) {
+ return toBuilder().setEmulatorHost(emulatorHost).build();
+ }
+
public SpannerConfig withCommitDeadline(Duration commitDeadline) {
return withCommitDeadline(ValueProvider.StaticValueProvider.of(commitDeadline));
}
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
index b1b4c69..b31d716 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
@@ -17,6 +17,7 @@
*/
package org.apache.beam.sdk.io.gcp.spanner;
+import static java.util.stream.Collectors.toList;
import static org.apache.beam.sdk.io.gcp.spanner.MutationUtils.isPointDelete;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull;
@@ -45,7 +46,6 @@ import java.util.Comparator;
import java.util.List;
import java.util.OptionalInt;
import java.util.concurrent.TimeUnit;
-import java.util.stream.Collectors;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.annotations.Experimental.Kind;
import org.apache.beam.sdk.coders.SerializableCoder;
@@ -53,6 +53,7 @@ import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Distribution;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.options.ValueProvider;
+import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Flatten;
@@ -60,6 +61,7 @@ import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Reshuffle;
+import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.Wait;
import org.apache.beam.sdk.transforms.display.DisplayData;
@@ -76,6 +78,8 @@ import org.apache.beam.sdk.values.PCollection.IsBounded;
import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.PDone;
+import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.sdk.values.TypeDescriptor;
@@ -455,6 +459,16 @@ public class SpannerIO {
return withHost(ValueProvider.StaticValueProvider.of(host));
}
+ /** Specifies the Cloud Spanner emulator host. */
+ public ReadAll withEmulatorHost(ValueProvider<String> emulatorHost) {
+ SpannerConfig config = getSpannerConfig();
+ return withSpannerConfig(config.withEmulatorHost(emulatorHost));
+ }
+
+ public ReadAll withEmulatorHost(String emulatorHost) {
+ return withEmulatorHost(ValueProvider.StaticValueProvider.of(emulatorHost));
+ }
+
/** Specifies the Cloud Spanner database. */
public ReadAll withDatabaseId(ValueProvider<String> databaseId) {
SpannerConfig config = getSpannerConfig();
@@ -589,6 +603,16 @@ public class SpannerIO {
return withHost(ValueProvider.StaticValueProvider.of(host));
}
+ /** Specifies the Cloud Spanner emulator host. */
+ public Read withEmulatorHost(ValueProvider<String> emulatorHost) {
+ SpannerConfig config = getSpannerConfig();
+ return withSpannerConfig(config.withEmulatorHost(emulatorHost));
+ }
+
+ public Read withEmulatorHost(String emulatorHost) {
+ return withEmulatorHost(ValueProvider.StaticValueProvider.of(emulatorHost));
+ }
+
/** If true, uses the Cloud Spanner batch API. */
public Read withBatching(boolean batching) {
return toBuilder().setBatching(batching).build();
@@ -666,7 +690,7 @@ public class SpannerIO {
+ "columns to set with withColumns method");
checkArgument(
!getReadOperation().getColumns().isEmpty(),
- "For a read operation SpannerIO.read() requires a"
+ "For a read operation SpannerIO.read() requires a non-empty"
+ " list of columns to set with withColumns method");
} else {
throw new IllegalArgumentException(
@@ -681,6 +705,37 @@ public class SpannerIO {
.withTransaction(getTransaction());
return input.apply(Create.of(getReadOperation())).apply("Execute query", readAll);
}
+
+ SerializableFunction<Struct, Row> getFormatFn() {
+ return (SerializableFunction<Struct, Row>)
+ input ->
+ Row.withSchema(Schema.builder().addInt64Field("Key").build())
+ .withFieldValue("Key", 3L)
+ .build();
+ }
+ }
+
+ static class ReadRows extends PTransform<PBegin, PCollection<Row>> {
+ Read read;
+ Schema schema;
+
+ public ReadRows(Read read, Schema schema) {
+ super("Read rows");
+ this.read = read;
+ this.schema = schema;
+ }
+
+ @Override
+ public PCollection<Row> expand(PBegin input) {
+ return input
+ .apply(read)
+ .apply(
+ MapElements.into(TypeDescriptor.of(Row.class))
+ .via(
+ (SerializableFunction<Struct, Row>)
+ struct -> StructUtils.structToBeamRow(struct, schema)))
+ .setRowSchema(schema);
+ }
}
/**
@@ -756,6 +811,16 @@ public class SpannerIO {
return withHost(ValueProvider.StaticValueProvider.of(host));
}
+ /** Specifies the Cloud Spanner emulator host. */
+ public CreateTransaction withEmulatorHost(ValueProvider<String> emulatorHost) {
+ SpannerConfig config = getSpannerConfig();
+ return withSpannerConfig(config.withEmulatorHost(emulatorHost));
+ }
+
+ public CreateTransaction withEmulatorHost(String emulatorHost) {
+ return withEmulatorHost(ValueProvider.StaticValueProvider.of(emulatorHost));
+ }
+
@VisibleForTesting
CreateTransaction withServiceFactory(ServiceFactory<Spanner, SpannerOptions> serviceFactory) {
SpannerConfig config = getSpannerConfig();
@@ -879,10 +944,20 @@ public class SpannerIO {
return withHost(ValueProvider.StaticValueProvider.of(host));
}
+ /** Specifies the Cloud Spanner emulator host. */
+ public Write withEmulatorHost(ValueProvider<String> emulatorHost) {
+ SpannerConfig config = getSpannerConfig();
+ return withSpannerConfig(config.withEmulatorHost(emulatorHost));
+ }
+
+ public Write withEmulatorHost(String emulatorHost) {
+ return withEmulatorHost(ValueProvider.StaticValueProvider.of(emulatorHost));
+ }
+
/**
* Specifies the deadline for the Commit API call. Default is 15 secs. DEADLINE_EXCEEDED errors
* will prompt a backoff/retry until the value of {@link #withMaxCumulativeBackoff(Duration)} is
- * reached. DEADLINE_EXCEEDED errors are are reported with logging and counters.
+ * reached. DEADLINE_EXCEEDED errors are reported with logging and counters.
*/
public Write withCommitDeadline(Duration commitDeadline) {
SpannerConfig config = getSpannerConfig();
@@ -1002,6 +1077,32 @@ public class SpannerIO {
}
}
+ static class WriteRows extends PTransform<PCollection<Row>, PDone> {
+ private final Write write;
+ private final Mutation.Op operation;
+ private final String table;
+
+ private WriteRows(Write write, Mutation.Op operation, String table) {
+ this.write = write;
+ this.operation = operation;
+ this.table = table;
+ }
+
+ public static WriteRows of(Write write, Mutation.Op operation, String table) {
+ return new WriteRows(write, operation, table);
+ }
+
+ @Override
+ public PDone expand(PCollection<Row> input) {
+ input
+ .apply(
+ MapElements.into(TypeDescriptor.of(Mutation.class))
+ .via(MutationUtils.beamRowToMutationFn(operation, table)))
+ .apply(write);
+ return PDone.in(input.getPipeline());
+ }
+ }
+
/** Same as {@link Write} but supports grouped mutations. */
public static class WriteGrouped
extends PTransform<PCollection<MutationGroup>, SpannerWriteResult> {
@@ -1037,8 +1138,7 @@ public class SpannerIO {
LOG.info("Batching of mutationGroups is disabled");
TypeDescriptor<Iterable<MutationGroup>> descriptor =
new TypeDescriptor<Iterable<MutationGroup>>() {};
- batches =
- input.apply(MapElements.into(descriptor).via(element -> ImmutableList.of(element)));
+ batches = input.apply(MapElements.into(descriptor).via(ImmutableList::of));
} else {
// First, read the Cloud Spanner schema.
@@ -1272,7 +1372,7 @@ public class SpannerIO {
out.output(
mutationsToSort.subList(batchStart, batchEnd).stream()
.map(o -> o.mutationGroup)
- .collect(Collectors.toList()));
+ .collect(toList()));
}
@ProcessElement
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerTransformRegistrar.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerTransformRegistrar.java
new file mode 100644
index 0000000..891dd37
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerTransformRegistrar.java
@@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.spanner;
+
+import static com.google.cloud.spanner.TimestampBound.Mode.MAX_STALENESS;
+import static com.google.cloud.spanner.TimestampBound.Mode.READ_TIMESTAMP;
+
+import com.google.auto.service.AutoService;
+import com.google.cloud.Timestamp;
+import com.google.cloud.spanner.Mutation;
+import com.google.cloud.spanner.TimestampBound;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+import org.apache.beam.model.pipeline.v1.SchemaApi;
+import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.annotations.Experimental.Kind;
+import org.apache.beam.sdk.expansion.ExternalTransformRegistrar;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.SchemaTranslation;
+import org.apache.beam.sdk.transforms.ExternalTransformBuilder;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.values.PBegin;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PDone;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.InvalidProtocolBufferException;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.Duration;
+
+/**
+ * Exposes {@link SpannerIO.WriteRows} and {@link SpannerIO.ReadRows} as an external transform for
+ * cross-language usage.
+ */
+@Experimental(Kind.PORTABILITY)
+@AutoService(ExternalTransformRegistrar.class)
+public class SpannerTransformRegistrar implements ExternalTransformRegistrar {
+ public static final String INSERT_URN = "beam:external:java:spanner:insert:v1";
+ public static final String UPDATE_URN = "beam:external:java:spanner:update:v1";
+ public static final String REPLACE_URN = "beam:external:java:spanner:replace:v1";
+ public static final String INSERT_OR_UPDATE_URN =
+ "beam:external:java:spanner:insert_or_update:v1";
+ public static final String DELETE_URN = "beam:external:java:spanner:delete:v1";
+ public static final String READ_URN = "beam:external:java:spanner:read:v1";
+
+ @Override
+ @NonNull
+ public Map<String, ExternalTransformBuilder<?, ?, ?>> knownBuilderInstances() {
+ return ImmutableMap.<String, ExternalTransformBuilder<?, ?, ?>>builder()
+ .put(INSERT_URN, new InsertBuilder())
+ .put(UPDATE_URN, new UpdateBuilder())
+ .put(REPLACE_URN, new ReplaceBuilder())
+ .put(INSERT_OR_UPDATE_URN, new InsertOrUpdateBuilder())
+ .put(DELETE_URN, new DeleteBuilder())
+ .put(READ_URN, new ReadBuilder())
+ .build();
+ }
+
+ public abstract static class CrossLanguageConfiguration {
+ String instanceId = "";
+ String databaseId = "";
+ String projectId = "";
+ @Nullable String host;
+ @Nullable String emulatorHost;
+
+ public void setInstanceId(String instanceId) {
+ this.instanceId = instanceId;
+ }
+
+ public void setDatabaseId(String databaseId) {
+ this.databaseId = databaseId;
+ }
+
+ public void setProjectId(String projectId) {
+ this.projectId = projectId;
+ }
+
+ public void setHost(@Nullable String host) {
+ this.host = host;
+ }
+
+ public void setEmulatorHost(@Nullable String emulatorHost) {
+ this.emulatorHost = emulatorHost;
+ }
+
+ void checkMandatoryFields() {
+ if (projectId.isEmpty()) {
+ throw new IllegalArgumentException("projectId can't be empty");
+ }
+ if (databaseId.isEmpty()) {
+ throw new IllegalArgumentException("databaseId can't be empty");
+ }
+ if (instanceId.isEmpty()) {
+ throw new IllegalArgumentException("instanceId can't be empty");
+ }
+ }
+ }
+
+ @Experimental(Kind.PORTABILITY)
+ public static class ReadBuilder
+ implements ExternalTransformBuilder<ReadBuilder.Configuration, PBegin, PCollection<Row>> {
+
+ public static class Configuration extends CrossLanguageConfiguration {
+ // TODO: BEAM-10851 Come up with something to determine schema without this explicit parameter
+ private Schema schema = Schema.builder().build();
+ private @Nullable String sql;
+ private @Nullable String table;
+ private @Nullable Boolean batching;
+ private @Nullable String timestampBoundMode;
+ private @Nullable String readTimestamp;
+ private @Nullable String timeUnit;
+ private @Nullable Long staleness;
+
+ public void setSql(@Nullable String sql) {
+ this.sql = sql;
+ }
+
+ public void setTable(@Nullable String table) {
+ this.table = table;
+ }
+
+ public void setBatching(@Nullable Boolean batching) {
+ this.batching = batching;
+ }
+
+ public void setTimestampBoundMode(@Nullable String timestampBoundMode) {
+ this.timestampBoundMode = timestampBoundMode;
+ }
+
+ public void setSchema(byte[] schema) throws InvalidProtocolBufferException {
+ this.schema = SchemaTranslation.schemaFromProto(SchemaApi.Schema.parseFrom(schema));
+ }
+
+ public void setReadTimestamp(@Nullable String readTimestamp) {
+ this.readTimestamp = readTimestamp;
+ }
+
+ public void setTimeUnit(@Nullable String timeUnit) {
+ this.timeUnit = timeUnit;
+ }
+
+ public void setStaleness(@Nullable Long staleness) {
+ this.staleness = staleness;
+ }
+
+ private @Nullable TimestampBound getTimestampBound() {
+ if (timestampBoundMode == null) {
+ return null;
+ }
+ TimestampBound.Mode mode = TimestampBound.Mode.valueOf(timestampBoundMode);
+ switch (mode) {
+ case STRONG:
+ return TimestampBound.strong();
+ case MAX_STALENESS:
+ case EXACT_STALENESS:
+ if (staleness == null) {
+ throw new NullPointerException(
+ "Staleness value cannot be empty when MAX_STALENESS or
EXACT_STALENESS mode is selected");
+ }
+ if (timeUnit == null) {
+ throw new NullPointerException(
+ "Time unit cannot be null when MAX_STALENESS or
EXACT_STALENESS mode is selected");
+ }
+ return mode == MAX_STALENESS
+ ? TimestampBound.ofMaxStaleness(staleness, TimeUnit.valueOf(timeUnit))
+ : TimestampBound.ofExactStaleness(staleness, TimeUnit.valueOf(timeUnit));
+ case READ_TIMESTAMP:
+ case MIN_READ_TIMESTAMP:
+ if (readTimestamp == null) {
+ throw new NullPointerException(
+ "Timestamp cannot be null when READ_TIMESTAMP or
MIN_READ_TIMESTAMP mode is selected");
+ }
+ return mode == READ_TIMESTAMP
+ ? TimestampBound.ofReadTimestamp(Timestamp.parseTimestamp(readTimestamp))
+ : TimestampBound.ofMinReadTimestamp(Timestamp.parseTimestamp(readTimestamp));
+ default:
+ throw new IllegalArgumentException("Unknown timestamp bound mode:
" + mode);
+ }
+ }
+
+ public ReadOperation getReadOperation() {
+ if (sql != null && table != null) {
+ throw new IllegalStateException(
+ "Query and table params are mutually exclusive. Set just one of
them.");
+ }
+ ReadOperation readOperation = ReadOperation.create();
+ if (sql != null) {
+ return readOperation.withQuery(sql);
+ }
+ if (Schema.builder().build().equals(schema)) {
+ throw new IllegalArgumentException("Schema can't be empty");
+ }
+ if (table != null) {
+ return readOperation.withTable(table).withColumns(schema.getFieldNames());
+ }
+ throw new IllegalStateException("Can't happen");
+ }
+ }
+
+ @Override
+ @NonNull
+ public PTransform<PBegin, PCollection<Row>> buildExternal(Configuration configuration) {
+ configuration.checkMandatoryFields();
+
+ SpannerIO.Read readTransform =
+ SpannerIO.read()
+ .withProjectId(configuration.projectId)
+ .withDatabaseId(configuration.databaseId)
+ .withInstanceId(configuration.instanceId)
+ .withReadOperation(configuration.getReadOperation());
+
+ if (configuration.host != null) {
+ readTransform = readTransform.withHost(configuration.host);
+ }
+ if (configuration.emulatorHost != null) {
+ readTransform = readTransform.withEmulatorHost(configuration.emulatorHost);
+ }
+ @Nullable TimestampBound timestampBound = configuration.getTimestampBound();
+ if (timestampBound != null) {
+ readTransform = readTransform.withTimestampBound(timestampBound);
+ }
+ if (configuration.batching != null) {
+ readTransform = readTransform.withBatching(configuration.batching);
+ }
+
+ return new SpannerIO.ReadRows(readTransform, configuration.schema);
+ }
+ }
+
+ @Experimental(Kind.PORTABILITY)
+ public static class InsertBuilder extends WriteBuilder {
+ public InsertBuilder() {
+ super(Mutation.Op.INSERT);
+ }
+ }
+
+ @Experimental(Kind.PORTABILITY)
+ public static class UpdateBuilder extends WriteBuilder {
+ public UpdateBuilder() {
+ super(Mutation.Op.UPDATE);
+ }
+ }
+
+ @Experimental(Kind.PORTABILITY)
+ public static class InsertOrUpdateBuilder extends WriteBuilder {
+ public InsertOrUpdateBuilder() {
+ super(Mutation.Op.INSERT_OR_UPDATE);
+ }
+ }
+
+ @Experimental(Kind.PORTABILITY)
+ public static class ReplaceBuilder extends WriteBuilder {
+ public ReplaceBuilder() {
+ super(Mutation.Op.REPLACE);
+ }
+ }
+
+ @Experimental(Kind.PORTABILITY)
+ public static class DeleteBuilder extends WriteBuilder {
+ public DeleteBuilder() {
+ super(Mutation.Op.DELETE);
+ }
+ }
+
+ @Experimental(Kind.PORTABILITY)
+ private abstract static class WriteBuilder
+ implements ExternalTransformBuilder<WriteBuilder.Configuration, PCollection<Row>, PDone> {
+
+ private final Mutation.Op operation;
+
+ WriteBuilder(Mutation.Op operation) {
+ this.operation = operation;
+ }
+
+ public static class Configuration extends CrossLanguageConfiguration {
+ private String table = "";
+ private @Nullable Long maxBatchSizeBytes;
+ private @Nullable Long maxNumberMutations;
+ private @Nullable Long maxNumberRows;
+ private @Nullable Integer groupingFactor;
+ private @Nullable Duration commitDeadline;
+ private @Nullable Duration maxCumulativeBackoff;
+
+ public void setTable(String table) {
+ this.table = table;
+ }
+
+ public void setMaxBatchSizeBytes(@Nullable Long maxBatchSizeBytes) {
+ this.maxBatchSizeBytes = maxBatchSizeBytes;
+ }
+
+ public void setMaxNumberMutations(@Nullable Long maxNumberMutations) {
+ this.maxNumberMutations = maxNumberMutations;
+ }
+
+ public void setMaxNumberRows(@Nullable Long maxNumberRows) {
+ this.maxNumberRows = maxNumberRows;
+ }
+
+ public void setGroupingFactor(@Nullable Long groupingFactor) {
+ if (groupingFactor != null) {
+ this.groupingFactor = groupingFactor.intValue();
+ }
+ }
+
+ public void setCommitDeadline(@Nullable Long commitDeadline) {
+ if (commitDeadline != null) {
+ this.commitDeadline = Duration.standardSeconds(commitDeadline);
+ }
+ }
+
+ public void setMaxCumulativeBackoff(@Nullable Long maxCumulativeBackoff) {
+ if (maxCumulativeBackoff != null) {
+ this.maxCumulativeBackoff = Duration.standardSeconds(maxCumulativeBackoff);
+ }
+ }
+ }
+
+ @Override
+ @NonNull
+ public PTransform<PCollection<Row>, PDone> buildExternal(Configuration configuration) {
+ configuration.checkMandatoryFields();
+
+ SpannerIO.Write writeTransform =
+ SpannerIO.write()
+ .withProjectId(configuration.projectId)
+ .withDatabaseId(configuration.databaseId)
+ .withInstanceId(configuration.instanceId);
+
+ if (configuration.maxBatchSizeBytes != null) {
+ writeTransform = writeTransform.withBatchSizeBytes(configuration.maxBatchSizeBytes);
+ }
+ if (configuration.maxNumberMutations != null) {
+ writeTransform = writeTransform.withMaxNumMutations(configuration.maxNumberMutations);
+ }
+ if (configuration.maxNumberRows != null) {
+ writeTransform = writeTransform.withMaxNumRows(configuration.maxNumberRows);
+ }
+ if (configuration.groupingFactor != null) {
+ writeTransform = writeTransform.withGroupingFactor(configuration.groupingFactor);
+ }
+ if (configuration.host != null) {
+ writeTransform = writeTransform.withHost(configuration.host);
+ }
+ if (configuration.emulatorHost != null) {
+ writeTransform = writeTransform.withEmulatorHost(configuration.emulatorHost);
+ }
+ if (configuration.commitDeadline != null) {
+ writeTransform = writeTransform.withCommitDeadline(configuration.commitDeadline);
+ }
+ if (configuration.maxCumulativeBackoff != null) {
+ writeTransform =
+ writeTransform.withMaxCumulativeBackoff(configuration.maxCumulativeBackoff);
+ }
+ return SpannerIO.WriteRows.of(writeTransform, operation, configuration.table);
+ }
+ }
+}
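The getTimestampBound() logic above turns the string and long configuration fields (timestampBoundMode, readTimestamp, staleness, timeUnit) into a Cloud Spanner TimestampBound. A rough sketch of how a Python caller might populate those fields through the wrapper follows; the enum and parameter names (TimestampBoundMode, TimeUnit, timestamp_bound_mode, staleness, time_unit) are assumptions, not confirmed by this excerpt.

    # Hypothetical sketch of a stale read; parameter and enum names are assumed.
    import typing

    import apache_beam as beam
    from apache_beam import coders
    from apache_beam.io.gcp.spanner import (
        ReadFromSpanner, TimestampBoundMode, TimeUnit)


    class OrderRow(typing.NamedTuple):  # hypothetical row type
        order_id: int
        status: str


    coders.registry.register_coder(OrderRow, coders.RowCoder)

    with beam.Pipeline() as p:
        stale_rows = p | ReadFromSpanner(
            project_id='my-project',
            instance_id='my-instance',
            database_id='my-db',
            row_type=OrderRow,
            table='orders',
            timestamp_bound_mode=TimestampBoundMode.MAX_STALENESS,
            staleness=3,
            time_unit=TimeUnit.HOURS)
        # On the Java side this would resolve to
        # TimestampBound.ofMaxStaleness(3, TimeUnit.HOURS) in ReadBuilder.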
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/StructUtils.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/StructUtils.java
new file mode 100644
index 0000000..c97fcec
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/StructUtils.java
@@ -0,0 +1,387 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.spanner;
+
+import static java.util.stream.Collectors.toList;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull;
+
+import com.google.cloud.ByteArray;
+import com.google.cloud.Timestamp;
+import com.google.cloud.spanner.Struct;
+import com.google.cloud.spanner.Type;
+import java.math.BigDecimal;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.StreamSupport;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.values.Row;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.DateTime;
+import org.joda.time.Instant;
+import org.joda.time.ReadableDateTime;
+
+final class StructUtils {
+
+ // It's not possible to pass nulls as values even when a field is nullable
+ @SuppressWarnings({
+ "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
+ })
+ public static Row structToBeamRow(Struct struct, Schema schema) {
+ Map<String, @Nullable Object> structValues =
+ schema.getFields().stream()
+ .collect(
+ HashMap::new,
+ (map, field) -> map.put(field.getName(), getStructValue(struct, field)),
+ Map::putAll);
+ return Row.withSchema(schema).withFieldValues(structValues).build();
+ }
+
+ public static Struct beamRowToStruct(Row row) {
+ Struct.Builder structBuilder = Struct.newBuilder();
+ List<Schema.Field> fields = row.getSchema().getFields();
+ fields.forEach(
+ field -> {
+ String column = field.getName();
+ switch (field.getType().getTypeName()) {
+ case ROW:
+ @Nullable Row subRow = row.getRow(column);
+ if (subRow == null) {
+ structBuilder.set(column).to(beamTypeToSpannerType(field.getType()), null);
+ } else {
+ structBuilder
+ .set(column)
+ .to(beamTypeToSpannerType(field.getType()), beamRowToStruct(subRow));
+ }
+ break;
+ case ARRAY:
+ addIterableToStructBuilder(structBuilder, row.getArray(column), field);
+ break;
+ case ITERABLE:
+ addIterableToStructBuilder(structBuilder, row.getIterable(column), field);
+ break;
+ case FLOAT:
+ @Nullable Float floatValue = row.getFloat(column);
+ if (floatValue == null) {
+ structBuilder.set(column).to((Double) null);
+ } else {
+ structBuilder.set(column).to(floatValue);
+ }
+ break;
+ case DOUBLE:
+ structBuilder.set(column).to(row.getDouble(column));
+ break;
+ case INT16:
+ @Nullable Short int16 = row.getInt16(column);
+ if (int16 == null) {
+ structBuilder.set(column).to((Long) null);
+ } else {
+ structBuilder.set(column).to(int16);
+ }
+ break;
+ case INT32:
+ @Nullable Integer int32 = row.getInt32(column);
+ if (int32 == null) {
+ structBuilder.set(column).to((Long) null);
+ } else {
+ structBuilder.set(column).to(int32);
+ }
+ break;
+ case INT64:
+ structBuilder.set(column).to(row.getInt64(column));
+ break;
+ case DECIMAL:
+ @Nullable BigDecimal decimal = row.getDecimal(column);
+ // BigDecimal is not nullable
+ if (decimal == null) {
+ checkNotNull(decimal, "Null decimal at column " + column);
+ } else {
+ structBuilder.set(column).to(decimal);
+ }
+ break;
+ // TODO: implement logical type date and timestamp
+ case DATETIME:
+ @Nullable ReadableDateTime dateTime = row.getDateTime(column);
+ if (dateTime == null) {
+ structBuilder.set(column).to((Timestamp) null);
+ } else {
+ structBuilder.set(column).to(Timestamp.parseTimestamp(dateTime.toString()));
+ }
+ break;
+ case STRING:
+ structBuilder.set(column).to(row.getString(column));
+ break;
+ case BYTE:
+ @Nullable Byte byteValue = row.getByte(column);
+ if (byteValue == null) {
+ structBuilder.set(column).to((Long) null);
+ } else {
+ structBuilder.set(column).to(byteValue);
+ }
+ break;
+ case BYTES:
+ byte @Nullable [] bytes = row.getBytes(column);
+ if (bytes == null) {
+ structBuilder.set(column).to((ByteArray) null);
+ } else {
+ structBuilder.set(column).to(ByteArray.copyFrom(bytes));
+ }
+ break;
+ case BOOLEAN:
+ structBuilder.set(column).to(row.getBoolean(column));
+ break;
+ default:
+ throw new IllegalArgumentException(
+ String.format(
+ "Unsupported beam type '%s' while translating row to
struct.",
+ field.getType().getTypeName()));
+ }
+ });
+ return structBuilder.build();
+ }
+
+ public static Type beamTypeToSpannerType(Schema.FieldType beamType) {
+ switch (beamType.getTypeName()) {
+ case ARRAY:
+ case ITERABLE:
+ Schema.@Nullable FieldType elementType = beamType.getCollectionElementType();
+ if (elementType == null) {
+ throw new NullPointerException("Null element type");
+ } else {
+ return Type.array(simpleBeamTypeToSpannerType(elementType));
+ }
+ default:
+ return simpleBeamTypeToSpannerType(beamType);
+ }
+ }
+
+ private static Type simpleBeamTypeToSpannerType(Schema.FieldType beamType) {
+ switch (beamType.getTypeName()) {
+ case ROW:
+ @Nullable Schema schema = beamType.getRowSchema();
+ if (schema == null) {
+ throw new NullPointerException("Null schema");
+ } else {
+ return Type.struct(translateRowFieldsToStructFields(schema.getFields()));
+ }
+ case BYTES:
+ return Type.bytes();
+ case BYTE:
+ case INT64:
+ case INT32:
+ case INT16:
+ return Type.int64();
+ case DOUBLE:
+ case FLOAT:
+ return Type.float64();
+ case DECIMAL:
+ return Type.numeric();
+ case STRING:
+ return Type.string();
+ case BOOLEAN:
+ return Type.bool();
+ // TODO: implement logical type date and timestamp
+ case DATETIME:
+ return Type.timestamp();
+ default:
+ throw new IllegalArgumentException(
+ String.format(
+ "Unable to translate beam type %s to Spanner type",
beamType.getTypeName()));
+ }
+ }
+
+ private static Iterable<Type.StructField> translateRowFieldsToStructFields(
+ List<Schema.Field> rowFields) {
+ return rowFields.stream()
+ .map(field -> Type.StructField.of(field.getName(), beamTypeToSpannerType(field.getType())))
+ .collect(toList());
+ }
+
+ @SuppressWarnings("unchecked")
+ private static void addIterableToStructBuilder(
+ Struct.Builder structBuilder, @Nullable Iterable<Object> iterable, Schema.Field field) {
+ String column = field.getName();
+ Schema.FieldType beamIterableType = field.getType().getCollectionElementType();
+ if (beamIterableType == null) {
+ throw new NullPointerException("Null collection element type at field " + field.getName());
+ }
+ Schema.TypeName beamIterableTypeName = beamIterableType.getTypeName();
+ switch (beamIterableTypeName) {
+ case ROW:
+ if (iterable == null) {
+ structBuilder.set(column).toStructArray(beamTypeToSpannerType(beamIterableType), null);
+ } else {
+ structBuilder
+ .set(column)
+ .toStructArray(
+ beamTypeToSpannerType(beamIterableType),
+ StreamSupport.stream(iterable.spliterator(), false)
+ .map(row -> beamRowToStruct((Row) row))
+ .collect(toList()));
+ }
+ break;
+ case INT16:
+ case INT32:
+ case INT64:
+ case BYTE:
+ structBuilder.set(column).toInt64Array((Iterable<Long>) ((Object) iterable));
+ break;
+ case FLOAT:
+ case DOUBLE:
+ structBuilder.set(column).toFloat64Array((Iterable<Double>) ((Object) iterable));
+ break;
+ case DECIMAL:
+ structBuilder.set(column).toNumericArray((Iterable<BigDecimal>) ((Object) iterable));
+ break;
+ case BOOLEAN:
+ structBuilder.set(column).toBoolArray((Iterable<Boolean>) ((Object) iterable));
+ break;
+ case BYTES:
+ if (iterable == null) {
+ structBuilder.set(column).toBytesArray(null);
+ } else {
+ structBuilder
+ .set(column)
+ .toBytesArray(
+ StreamSupport.stream(iterable.spliterator(), false)
+ .map(bytes -> ByteArray.copyFrom((byte[]) bytes))
+ .collect(toList()));
+ }
+ break;
+ case STRING:
+ structBuilder.set(column).toStringArray((Iterable<String>) ((Object) iterable));
+ break;
+ // TODO: implement logical date and datetime
+ case DATETIME:
+ if (iterable == null) {
+ structBuilder.set(column).toTimestampArray(null);
+ } else {
+ structBuilder
+ .set(column)
+ .toTimestampArray(
+ StreamSupport.stream(iterable.spliterator(), false)
+ .map(timestamp -> Timestamp.parseTimestamp(timestamp.toString()))
+ .collect(toList()));
+ }
+ break;
+ default:
+ throw new IllegalArgumentException(
+ String.format(
+ "Unsupported iterable type '%s' while translating row to
struct.",
+ beamIterableType.getTypeName()));
+ }
+ }
+
+ private static @Nullable Object getStructValue(Struct struct, Schema.Field field) {
+ String column = field.getName();
+ Type.Code typeCode = struct.getColumnType(column).getCode();
+ if (struct.isNull(column)) {
+ return null;
+ }
+ switch (typeCode) {
+ case BOOL:
+ return struct.getBoolean(column);
+ case BYTES:
+ return struct.getBytes(column).toByteArray();
+ // TODO: implement logical datetime
+ case TIMESTAMP:
+ return Instant.ofEpochSecond(struct.getTimestamp(column).getSeconds()).toDateTime();
+ // TODO: implement logical date
+ case DATE:
+ return DateTime.parse(struct.getDate(column).toString());
+ case INT64:
+ return struct.getLong(column);
+ case FLOAT64:
+ return struct.getDouble(column);
+ case NUMERIC:
+ return struct.getBigDecimal(column);
+ case STRING:
+ return struct.getString(column);
+ case ARRAY:
+ return getStructArrayValue(
+ struct, struct.getColumnType(column).getArrayElementType(), field);
+ case STRUCT:
+ @Nullable Schema schema = field.getType().getRowSchema();
+ if (schema == null) {
+ throw new NullPointerException("Null schema at field " +
field.getName());
+ } else {
+ return structToBeamRow(struct.getStruct(column), schema);
+ }
+ default:
+ throw new RuntimeException(
+ String.format("Unsupported spanner type %s for column %s.",
typeCode, column));
+ }
+ }
+
+ private static @Nullable Object getStructArrayValue(
+ Struct struct, Type arrayType, Schema.Field field) {
+ Type.Code arrayCode = arrayType.getCode();
+ String column = field.getName();
+ if (struct.isNull(column)) {
+ return null;
+ }
+ switch (arrayCode) {
+ case BOOL:
+ return struct.getBooleanList(column);
+ case BYTES:
+ return struct.getBytesList(column);
+ // TODO: implement logical datetime
+ case TIMESTAMP:
+ return struct.getTimestampList(column).stream()
+ .map(timestamp -> Instant.ofEpochSecond(timestamp.getSeconds()).toDateTime())
+ .collect(toList());
+ // TODO: implement logical date
+ case DATE:
+ return struct.getDateList(column).stream()
+ .map(date -> DateTime.parse(date.toString()))
+ .collect(toList());
+ case INT64:
+ return struct.getLongList(column);
+ case FLOAT64:
+ return struct.getDoubleList(column);
+ case STRING:
+ return struct.getStringList(column);
+ case NUMERIC:
+ return struct.getBigDecimal(column);
+ case ARRAY:
+ throw new IllegalStateException(
+ String.format("Column %s has array of arrays which is prohibited
in Spanner.", column));
+ case STRUCT:
+ return struct.getStructList(column).stream()
+ .map(
+ structElem -> {
+ Schema.@Nullable FieldType fieldType = field.getType().getCollectionElementType();
+ if (fieldType == null) {
+ throw new NullPointerException(
+ "Null collection element type at field " +
field.getName());
+ }
+
+ @Nullable Schema elementSchema = fieldType.getRowSchema();
+ if (elementSchema == null) {
+ throw new NullPointerException(
+ "Null schema element type at field " +
field.getName());
+ }
+ return structToBeamRow(structElem, elementSchema);
+ })
+ .collect(toList());
+ default:
+ throw new RuntimeException(
+ String.format("Unsupported spanner array type %s for column %s.",
arrayCode, column));
+ }
+ }
+}
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/MutationUtilsTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/MutationUtilsTest.java
new file mode 100644
index 0000000..b0bafa7
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/MutationUtilsTest.java
@@ -0,0 +1,285 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.spanner;
+
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static org.apache.beam.sdk.io.gcp.spanner.MutationUtils.beamRowToMutationFn;
+import static org.junit.Assert.assertEquals;
+
+import com.google.cloud.ByteArray;
+import com.google.cloud.Timestamp;
+import com.google.cloud.spanner.Key;
+import com.google.cloud.spanner.Mutation;
+import com.google.cloud.spanner.Struct;
+import com.google.cloud.spanner.Type;
+import java.util.List;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.joda.time.DateTime;
+import org.junit.Test;
+
+public class MutationUtilsTest {
+ private static final Schema EMPTY_SCHEMA = Schema.builder().build();
+ private static final Schema INT64_SCHEMA = Schema.builder().addInt64Field("int64").build();
+ private static final Row EMPTY_ROW = Row.withSchema(EMPTY_SCHEMA).build();
+ private static final Row INT64_ROW =
+ Row.withSchema(INT64_SCHEMA).withFieldValue("int64", 3L).build();
+ private static final Struct EMPTY_STRUCT = Struct.newBuilder().build();
+ private static final Struct INT64_STRUCT = Struct.newBuilder().set("int64").to(3L).build();
+ private static final String TABLE = "some_table";
+
+ private static final Schema WRITE_ROW_SCHEMA =
+ Schema.builder()
+ .addNullableField("f_int64", Schema.FieldType.INT64)
+ .addNullableField("f_float64", Schema.FieldType.DOUBLE)
+ .addNullableField("f_string", Schema.FieldType.STRING)
+ .addNullableField("f_bytes", Schema.FieldType.BYTES)
+ .addNullableField("f_date_time", Schema.FieldType.DATETIME)
+ .addNullableField("f_bool", Schema.FieldType.BOOLEAN)
+ .addNullableField("f_struct", Schema.FieldType.row(EMPTY_SCHEMA))
+ .addNullableField("f_struct_int64",
Schema.FieldType.row(INT64_SCHEMA))
+ .addNullableField("f_array",
Schema.FieldType.array(Schema.FieldType.INT64))
+ .addNullableField(
+ "f_struct_array",
Schema.FieldType.array(Schema.FieldType.row(INT64_SCHEMA)))
+ .build();
+
+ private static final Row WRITE_ROW =
+ Row.withSchema(WRITE_ROW_SCHEMA)
+ .withFieldValue("f_int64", 1L)
+ .withFieldValue("f_float64", 1.1)
+ .withFieldValue("f_string", "donald_duck")
+ .withFieldValue("f_bytes", "some_bytes".getBytes(UTF_8))
+ .withFieldValue("f_date_time",
DateTime.parse("2077-10-15T00:00:00+00:00"))
+ .withFieldValue("f_bool", false)
+ .withFieldValue("f_struct", EMPTY_ROW)
+ .withFieldValue("f_struct_int64", INT64_ROW)
+ .withFieldValue("f_array", ImmutableList.of(2L, 3L))
+ .withFieldValue("f_struct_array", ImmutableList.of(INT64_ROW,
INT64_ROW))
+ .build();
+
+ private static final Row WRITE_ROW_NULLS =
+ Row.withSchema(WRITE_ROW_SCHEMA)
+ .addValue(null)
+ .addValue(null)
+ .addValue(null)
+ .addValue(null)
+ .addValue(null)
+ .addValue(null)
+ .addValue(null)
+ .addValue(null)
+ .addValue(null)
+ .addValue(null)
+ .build();
+
+ private static final Schema KEY_SCHEMA =
+ Schema.builder()
+ .addNullableField("f_int64", Schema.FieldType.INT64)
+ .addNullableField("f_float64", Schema.FieldType.DOUBLE)
+ .addNullableField("f_string", Schema.FieldType.STRING)
+ .addNullableField("f_bytes", Schema.FieldType.BYTES)
+ .addNullableField("f_date_time", Schema.FieldType.DATETIME)
+ .addNullableField("f_bool", Schema.FieldType.BOOLEAN)
+ .build();
+
+ private static final Row KEY_ROW =
+ Row.withSchema(KEY_SCHEMA)
+ .withFieldValue("f_int64", 1L)
+ .withFieldValue("f_float64", 1.1)
+ .withFieldValue("f_string", "donald_duck")
+ .withFieldValue("f_bytes", "some_bytes".getBytes(UTF_8))
+ .withFieldValue("f_date_time",
DateTime.parse("2077-10-15T00:00:00+00:00"))
+ .withFieldValue("f_bool", false)
+ .build();
+
+ private static final Row KEY_ROW_NULLS =
+ Row.withSchema(KEY_SCHEMA)
+ .addValue(null)
+ .addValue(null)
+ .addValue(null)
+ .addValue(null)
+ .addValue(null)
+ .addValue(null)
+ .build();
+
+ @Test
+ public void testCreateInsertMutationFromRow() {
+ Mutation expectedMutation = createMutation(Mutation.Op.INSERT);
+ Mutation mutation = beamRowToMutationFn(Mutation.Op.INSERT, TABLE).apply(WRITE_ROW);
+ assertEquals(expectedMutation, mutation);
+ }
+
+ @Test
+ public void testCreateUpdateMutationFromRow() {
+ Mutation expectedMutation = createMutation(Mutation.Op.UPDATE);
+ Mutation mutation = beamRowToMutationFn(Mutation.Op.UPDATE, TABLE).apply(WRITE_ROW);
+ assertEquals(expectedMutation, mutation);
+ }
+
+ @Test
+ public void testCreateReplaceMutationFromRow() {
+ Mutation expectedMutation = createMutation(Mutation.Op.REPLACE);
+ Mutation mutation = beamRowToMutationFn(Mutation.Op.REPLACE, TABLE).apply(WRITE_ROW);
+ assertEquals(expectedMutation, mutation);
+ }
+
+ @Test
+ public void testCreateInsertOrUpdateMutationFromRow() {
+ Mutation expectedMutation = createMutation(Mutation.Op.INSERT_OR_UPDATE);
+ Mutation mutation = beamRowToMutationFn(Mutation.Op.INSERT_OR_UPDATE,
TABLE).apply(WRITE_ROW);
+ assertEquals(expectedMutation, mutation);
+ }
+
+ @Test
+ public void testCreateDeleteMutationFromRow() {
+ Mutation expectedMutation = createDeleteMutation();
+ Mutation mutation = beamRowToMutationFn(Mutation.Op.DELETE,
TABLE).apply(KEY_ROW);
+ assertEquals(expectedMutation, mutation);
+ }
+
+ @Test
+ public void testCreateInsertMutationFromRowWithNulls() {
+ Mutation expectedMutation = createMutationNulls(Mutation.Op.INSERT);
+ Mutation mutation = beamRowToMutationFn(Mutation.Op.INSERT,
TABLE).apply(WRITE_ROW_NULLS);
+ assertEquals(expectedMutation, mutation);
+ }
+
+ @Test
+ public void testCreateInsertOrUpdateMutationFromRowWithNulls() {
+ Mutation expectedMutation =
createMutationNulls(Mutation.Op.INSERT_OR_UPDATE);
+ Mutation mutation =
+ beamRowToMutationFn(Mutation.Op.INSERT_OR_UPDATE,
TABLE).apply(WRITE_ROW_NULLS);
+ assertEquals(expectedMutation, mutation);
+ }
+
+ @Test
+ public void testCreateUpdateMutationFromRowWithNulls() {
+ Mutation expectedMutation = createMutationNulls(Mutation.Op.UPDATE);
+ Mutation mutation = beamRowToMutationFn(Mutation.Op.UPDATE,
TABLE).apply(WRITE_ROW_NULLS);
+ assertEquals(expectedMutation, mutation);
+ }
+
+ @Test
+ public void testCreateReplaceMutationFromRowWithNulls() {
+ Mutation expectedMutation = createMutationNulls(Mutation.Op.REPLACE);
+ Mutation mutation = beamRowToMutationFn(Mutation.Op.REPLACE,
TABLE).apply(WRITE_ROW_NULLS);
+ assertEquals(expectedMutation, mutation);
+ }
+
+ @Test
+ public void testCreateDeleteMutationFromRowWithNulls() {
+ Mutation expectedMutation = createDeleteMutationNulls();
+ Mutation mutation = beamRowToMutationFn(Mutation.Op.DELETE,
TABLE).apply(KEY_ROW_NULLS);
+ assertEquals(expectedMutation, mutation);
+ }
+
+ private static Mutation createDeleteMutation() {
+ Key key =
+ Key.newBuilder()
+ .append(1L)
+ .append(1.1)
+ .append("donald_duck")
+ .append(ByteArray.copyFrom("some_bytes".getBytes(UTF_8)))
+ .append(Timestamp.parseTimestamp("2077-10-15T00:00:00"))
+ .append(false)
+ .build();
+ return Mutation.delete(TABLE, key);
+ }
+
+ private static Mutation createDeleteMutationNulls() {
+ Key key =
+ Key.newBuilder()
+ .append((Long) null)
+ .append((Double) null)
+ .append((String) null)
+ .append((ByteArray) null)
+ .append((Timestamp) null)
+ .append((Boolean) null)
+ .build();
+ return Mutation.delete(TABLE, key);
+ }
+
+ private static Mutation createMutation(Mutation.Op operation) {
+ Mutation.WriteBuilder builder = chooseBuilder(operation);
+ return builder
+ .set("f_int64")
+ .to(1L)
+ .set("f_float64")
+ .to(1.1)
+ .set("f_string")
+ .to("donald_duck")
+ .set("f_bytes")
+ .to(ByteArray.copyFrom("some_bytes".getBytes(UTF_8)))
+ .set("f_date_time")
+ .to(Timestamp.parseTimestamp("2077-10-15T00:00:00"))
+ .set("f_bool")
+ .to(false)
+ .set("f_struct")
+ .to(EMPTY_STRUCT)
+ .set("f_struct_int64")
+ .to(Struct.newBuilder().set("int64").to(3L).build())
+ .set("f_array")
+ .toInt64Array(ImmutableList.of(2L, 3L))
+ .set("f_struct_array")
+ .toStructArray(
+ Type.struct(ImmutableList.of(Type.StructField.of("int64",
Type.int64()))),
+ ImmutableList.of(INT64_STRUCT, INT64_STRUCT))
+ .build();
+ }
+
+ private static Mutation createMutationNulls(Mutation.Op operation) {
+ Mutation.WriteBuilder builder = chooseBuilder(operation);
+ return builder
+ .set("f_int64")
+ .to((Long) null)
+ .set("f_float64")
+ .to((Double) null)
+ .set("f_string")
+ .to((String) null)
+ .set("f_bytes")
+ .to((ByteArray) null)
+ .set("f_date_time")
+ .to((Timestamp) null)
+ .set("f_bool")
+ .to((Boolean) null)
+ .set("f_struct")
+ .to(Type.struct(), null)
+ .set("f_struct_int64")
+ .to(Type.struct(Type.StructField.of("int64", Type.int64())), null)
+ .set("f_array")
+ .toInt64Array((List<Long>) null)
+ .set("f_struct_array")
+ .toStructArray(Type.struct(Type.StructField.of("int64",
Type.int64())), null)
+ .build();
+ }
+
+ private static Mutation.WriteBuilder chooseBuilder(Mutation.Op op) {
+ switch (op) {
+ case INSERT:
+ return Mutation.newInsertBuilder(TABLE);
+ case UPDATE:
+ return Mutation.newUpdateBuilder(TABLE);
+ case REPLACE:
+ return Mutation.newReplaceBuilder(TABLE);
+ case INSERT_OR_UPDATE:
+ return Mutation.newInsertOrUpdateBuilder(TABLE);
+ default:
+ throw new IllegalArgumentException("Operation '" + op + "' not
supported");
+ }
+ }
+}
diff --git
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/StructUtilsTest.java
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/StructUtilsTest.java
new file mode 100644
index 0000000..6b4a465
--- /dev/null
+++
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/StructUtilsTest.java
@@ -0,0 +1,258 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.spanner;
+
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static
org.apache.beam.sdk.io.gcp.spanner.StructUtils.beamTypeToSpannerType;
+import static org.hamcrest.Matchers.containsString;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.fail;
+
+import com.google.cloud.ByteArray;
+import com.google.cloud.Date;
+import com.google.cloud.Timestamp;
+import com.google.cloud.spanner.Struct;
+import com.google.cloud.spanner.Type;
+import java.math.BigDecimal;
+import java.util.List;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.values.Row;
+import
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.DateTime;
+import org.junit.Test;
+
+public class StructUtilsTest {
+ private static final Schema EMPTY_SCHEMA = Schema.builder().build();
+ private static final Schema INT64_SCHEMA =
Schema.builder().addInt64Field("int64").build();
+
+ @Test
+ public void testStructToBeamRow() {
+ Schema schema = getSchemaTemplate().addDateTimeField("f_date").build();
+ Row row = getRowTemplate(schema).withFieldValue("f_date",
DateTime.parse("2077-10-24")).build();
+ Struct struct =
+ getStructTemplate().set("f_date").to(Date.fromYearMonthDay(2077, 10,
24)).build();
+ assertEquals(row, StructUtils.structToBeamRow(struct, schema));
+ }
+
+ @Test
+ public void testStructToBeamRowFailsColumnsDontMatch() {
+ Schema schema = Schema.builder().addInt64Field("f_int64").build();
+ Struct struct =
Struct.newBuilder().set("f_different_field").to(5L).build();
+ Exception exception =
+ assertThrows(
+ IllegalArgumentException.class, () ->
StructUtils.structToBeamRow(struct, schema));
+ checkMessage("Field not found: f_int64", exception.getMessage());
+ }
+
+ @Test
+ public void testStructToBeamRowFailsTypesDontMatch() {
+ Schema schema = Schema.builder().addInt64Field("f_int64").build();
+ Struct struct =
Struct.newBuilder().set("f_int64").to("string_value").build();
+ Exception exception =
+ assertThrows(ClassCastException.class, () ->
StructUtils.structToBeamRow(struct, schema));
+ checkMessage("java.lang.String cannot be cast to java.lang.Long",
exception.getMessage());
+ }
+
+ @Test
+ public void testBeamRowToStruct() {
+ Schema schema =
+ getSchemaTemplate()
+ .addIterableField("f_iterable", Schema.FieldType.INT64)
+ .addDecimalField("f_decimal")
+ .build();
+ Row row =
+ getRowTemplate(schema)
+ .withFieldValue("f_iterable", ImmutableList.of(20L))
+ .withFieldValue("f_decimal", BigDecimal.ONE)
+ .build();
+ Struct struct =
+ getStructTemplate()
+ .set("f_iterable")
+ .toInt64Array(ImmutableList.of(20L))
+ .set("f_decimal")
+ .to(BigDecimal.ONE)
+ .build();
+ assertEquals(struct, StructUtils.beamRowToStruct(row));
+ }
+
+ @Test
+ public void testBeamRowToStructNulls() {
+ Schema schema = getSchemaTemplate().build();
+ Row row = getRowBuilder(schema).build();
+ Struct struct = getStructTemplateNulls().build();
+ assertEquals(struct, StructUtils.beamRowToStruct(row));
+ }
+
+ @Test
+ public void testBeamRowToStructNullDecimalNullShouldFail() {
+ Schema schema =
+ getSchemaTemplate().addNullableField("f_decimal",
Schema.FieldType.DECIMAL).build();
+ Row row = getRowBuilder(schema).addValue(null).build();
+ NullPointerException npe =
+ assertThrows(NullPointerException.class, () ->
StructUtils.beamRowToStruct(row));
+ String message = npe.getMessage();
+ checkMessage("Null", message);
+ }
+
+ @Test
+ public void testBeamRowToStructFailsTypeNotSupported() {
+ Schema schema =
+ getSchemaTemplate()
+ .addMapField("f_map", Schema.FieldType.STRING,
Schema.FieldType.STRING)
+ .build();
+ Row row = getRowTemplate(schema).withFieldValue("f_map",
ImmutableMap.of("a", "b")).build();
+ Exception exception =
+ assertThrows(IllegalArgumentException.class, () ->
StructUtils.beamRowToStruct(row));
+ checkMessage(
+ "Unsupported beam type 'MAP' while translating row to struct.",
exception.getMessage());
+ }
+
+ @Test
+ public void testBeamTypeToSpannerTypeTranslation() {
+ assertEquals(Type.int64(), beamTypeToSpannerType(Schema.FieldType.INT64));
+ assertEquals(Type.int64(), beamTypeToSpannerType(Schema.FieldType.INT32));
+ assertEquals(Type.int64(), beamTypeToSpannerType(Schema.FieldType.INT16));
+ assertEquals(Type.int64(), beamTypeToSpannerType(Schema.FieldType.BYTE));
+ assertEquals(Type.bytes(), beamTypeToSpannerType(Schema.FieldType.BYTES));
+ assertEquals(Type.string(),
beamTypeToSpannerType(Schema.FieldType.STRING));
+ assertEquals(Type.float64(),
beamTypeToSpannerType(Schema.FieldType.FLOAT));
+ assertEquals(Type.float64(),
beamTypeToSpannerType(Schema.FieldType.DOUBLE));
+ assertEquals(Type.bool(), beamTypeToSpannerType(Schema.FieldType.BOOLEAN));
+ assertEquals(Type.numeric(),
beamTypeToSpannerType(Schema.FieldType.DECIMAL));
+ assertEquals(
+ Type.struct(ImmutableList.of(Type.StructField.of("int64",
Type.int64()))),
+ beamTypeToSpannerType(Schema.FieldType.row(INT64_SCHEMA)));
+ assertEquals(
+ Type.array(Type.int64()),
+ beamTypeToSpannerType(Schema.FieldType.array(Schema.FieldType.INT64)));
+ }
+
+ private Schema.Builder getSchemaTemplate() {
+ return Schema.builder()
+ .addNullableField("f_int64", Schema.FieldType.INT64)
+ .addNullableField("f_float64", Schema.FieldType.DOUBLE)
+ .addNullableField("f_string", Schema.FieldType.STRING)
+ .addNullableField("f_bytes", Schema.FieldType.BYTES)
+ .addNullableField("f_timestamp", Schema.FieldType.DATETIME)
+ .addNullableField("f_bool", Schema.FieldType.BOOLEAN)
+ .addNullableField("f_struct", Schema.FieldType.row(EMPTY_SCHEMA))
+ .addNullableField("f_struct_int64", Schema.FieldType.row(INT64_SCHEMA))
+ .addNullableField("f_array",
Schema.FieldType.array(Schema.FieldType.INT64))
+ .addNullableField(
+ "f_struct_array",
Schema.FieldType.array(Schema.FieldType.row(INT64_SCHEMA)));
+ }
+
+ private Row.FieldValueBuilder getRowTemplate(Schema schema) {
+ return Row.withSchema(schema)
+ .withFieldValue("f_int64", 1L)
+ .withFieldValue("f_float64", 5.5)
+ .withFieldValue("f_string", "ducky_doo")
+ .withFieldValue("f_bytes",
ByteArray.copyFrom("random_bytes".getBytes(UTF_8)).toByteArray())
+ .withFieldValue("f_timestamp", DateTime.parse("2077-01-10"))
+ .withFieldValue("f_bool", true)
+ .withFieldValue("f_struct", Row.withSchema(EMPTY_SCHEMA).build())
+ .withFieldValue(
+ "f_struct_int64",
Row.withSchema(INT64_SCHEMA).withFieldValue("int64", 10L).build())
+ .withFieldValue("f_array", ImmutableList.of(55L, 43L))
+ .withFieldValue(
+ "f_struct_array",
+ ImmutableList.of(
+ Row.withSchema(INT64_SCHEMA).withFieldValue("int64",
1L).build(),
+ Row.withSchema(INT64_SCHEMA).withFieldValue("int64",
2L).build()));
+ }
+
+ private Row.Builder getRowBuilder(Schema schema) {
+ return Row.withSchema(schema)
+ .addValue(null)
+ .addValue(null)
+ .addValue(null)
+ .addValue(null)
+ .addValue(null)
+ .addValue(null)
+ .addValue(null)
+ .addValue(null)
+ .addValue(null)
+ .addValue(null);
+ }
+
+ private Struct.Builder getStructTemplate() {
+ return Struct.newBuilder()
+ .set("f_int64")
+ .to(1L)
+ .set("f_float64")
+ .to(5.5)
+ .set("f_string")
+ .to("ducky_doo")
+ .set("f_bytes")
+ .to(ByteArray.copyFrom("random_bytes".getBytes(UTF_8)))
+ .set("f_timestamp")
+ .to(
+ Timestamp.ofTimeMicroseconds(
+ DateTime.parse("2077-01-10").toInstant().getMillis() * 1000L))
+ .set("f_bool")
+ .to(true)
+ .set("f_struct")
+ .to(Struct.newBuilder().build())
+ .set("f_struct_int64")
+ .to(Struct.newBuilder().set("int64").to(10L).build())
+ .set("f_array")
+ .toInt64Array(ImmutableList.of(55L, 43L))
+ .set("f_struct_array")
+ .toStructArray(
+ Type.struct(Type.StructField.of("int64", Type.int64())),
+ ImmutableList.of(
+ Struct.newBuilder().set("int64").to(1L).build(),
+ Struct.newBuilder().set("int64").to(2L).build()));
+ }
+
+ private Struct.Builder getStructTemplateNulls() {
+ return Struct.newBuilder()
+ .set("f_int64")
+ .to((Long) null)
+ .set("f_float64")
+ .to((Double) null)
+ .set("f_string")
+ .to((String) null)
+ .set("f_bytes")
+ .to((ByteArray) null)
+ .set("f_timestamp")
+ .to((Timestamp) null)
+ .set("f_bool")
+ .to((Boolean) null)
+ .set("f_struct")
+ .to(Type.struct(), null)
+ .set("f_struct_int64")
+ .to(Type.struct(Type.StructField.of("int64", Type.int64())), null)
+ .set("f_array")
+ .toInt64Array((List<Long>) null)
+ .set("f_struct_array")
+ .toStructArray(Type.struct(Type.StructField.of("int64",
Type.int64())), null);
+ }
+
+ private void checkMessage(String substring, @Nullable String message) {
+ if (message != null) {
+ assertThat(message, containsString(substring));
+ } else {
+ fail();
+ }
+ }
+}
diff --git a/sdks/python/apache_beam/io/gcp/spanner.py
b/sdks/python/apache_beam/io/gcp/spanner.py
new file mode 100644
index 0000000..dd1a0c4
--- /dev/null
+++ b/sdks/python/apache_beam/io/gcp/spanner.py
@@ -0,0 +1,635 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""PTransforms for supporting Spanner in Python pipelines.
+
+ These transforms are currently supported by Beam portable
+ Flink and Spark runners.
+
+ **Setup**
+
+ Transforms provided in this module are cross-language transforms
+ implemented in the Beam Java SDK. During pipeline construction, the Python
+ SDK will connect to a Java expansion service to expand these transforms.
+ To facilitate this, a small amount of setup is needed before using these
+ transforms in a Beam Python pipeline.
+
+ There are several ways to set up cross-language Spanner transforms.
+
+ * Option 1: use the default expansion service
+ * Option 2: specify a custom expansion service
+
+ See below for details regarding each of these options.
+
+ *Option 1: Use the default expansion service*
+
+ This is the recommended and easiest setup option for using Python Spanner
+ transforms. This option is only available for Beam 2.26.0 and later.
+
+ This option requires the following prerequisites before running the Beam
+ pipeline.
+
+ * Install a Java runtime on the computer from which the pipeline is
+ constructed and make sure that the 'java' command is available.
+
+ In this option, the Python SDK will either download (for released Beam
+ versions) or build (when running from a Beam Git clone) an expansion
+ service jar and use that to expand transforms. Currently Spanner
+ transforms use the
+ 'beam-sdks-java-io-google-cloud-platform-expansion-service' jar for this
+ purpose.
+
+ *Option 2: specify a custom expansion service*
+
+ In this option, you start up your own expansion service and provide its
+ address as a parameter when using the transforms provided in this module.
+
+ This option requires the following prerequisites before running the Beam
+ pipeline.
+
+ * Start up your own expansion service.
+ * Update your pipeline to provide the expansion service address when
+ initiating Spanner transforms provided in this module.
+
+ Flink users can use the built-in Expansion Service of the Flink Runner's
+ Job Server. If you start Flink's Job Server, the expansion service will be
+ started on port 8097. For a different address, please set the
+ expansion_service parameter.
+
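+ For example, a pipeline that uses a manually started expansion service can
+ pass its address explicitly. This is only an illustrative sketch: it
+ assumes an expansion service is already listening on localhost:8097 and
+ that ExampleRow is a NamedTuple row type registered with a RowCoder (see
+ the ReadFromSpanner example below)::
+
+   result = (
+       p
+       | ReadFromSpanner(
+           project_id='your_project_id',
+           instance_id='your_instance_id',
+           database_id='your_database_id',
+           row_type=ExampleRow,
+           sql='SELECT * FROM some_table',
+           expansion_service='localhost:8097',
+       ).with_output_types(ExampleRow))
+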
+ **More information**
+
+ For more information regarding cross-language transforms see:
+ - https://beam.apache.org/roadmap/portability/
+
+ For more information specific to Flink runner see:
+ - https://beam.apache.org/documentation/runners/flink/
+"""
+
+# pytype: skip-file
+
+from __future__ import absolute_import
+
+from enum import Enum
+from enum import auto
+from typing import NamedTuple
+from typing import Optional
+
+from past.builtins import unicode
+
+from apache_beam.transforms.external import BeamJarExpansionService
+from apache_beam.transforms.external import ExternalTransform
+from apache_beam.transforms.external import NamedTupleBasedPayloadBuilder
+from apache_beam.typehints.schemas import named_tuple_to_schema
+
+__all__ = [
+ 'ReadFromSpanner',
+ 'SpannerDelete',
+ 'SpannerInsert',
+ 'SpannerInsertOrUpdate',
+ 'SpannerReplace',
+ 'SpannerUpdate',
+ 'TimestampBoundMode',
+ 'TimeUnit',
+]
+
+
+def default_io_expansion_service():
+ return BeamJarExpansionService(
+ 'sdks:java:io:google-cloud-platform:expansion-service:shadowJar')
+
+
+class TimeUnit(Enum):
+ NANOSECONDS = auto()
+ MICROSECONDS = auto()
+ MILLISECONDS = auto()
+ SECONDS = auto()
+ HOURS = auto()
+ DAYS = auto()
+
+
+class TimestampBoundMode(Enum):
+ MAX_STALENESS = auto()
+ EXACT_STALENESS = auto()
+ READ_TIMESTAMP = auto()
+ MIN_READ_TIMESTAMP = auto()
+ STRONG = auto()
+
+
+class ReadFromSpannerSchema(NamedTuple):
+ instance_id: unicode
+ database_id: unicode
+ schema: bytes
+ sql: Optional[unicode]
+ table: Optional[unicode]
+ project_id: Optional[unicode]
+ host: Optional[unicode]
+ emulator_host: Optional[unicode]
+ batching: Optional[bool]
+ timestamp_bound_mode: Optional[unicode]
+ read_timestamp: Optional[unicode]
+ staleness: Optional[int]
+ time_unit: Optional[unicode]
+
+
+class ReadFromSpanner(ExternalTransform):
+ """
+ A PTransform which reads from the specified Spanner instance's database.
+
+ This transform requires the type of the rows it returns in order to
+ provide the schema. Example::
+
+ from typing import NamedTuple
+ from apache_beam import coders
+
+ class ExampleRow(NamedTuple):
+ id: int
+ name: unicode
+
+ coders.registry.register_coder(ExampleRow, coders.RowCoder)
+
+ with Pipeline() as p:
+ result = (
+ p
+ | ReadFromSpanner(
+ instance_id='your_instance_id',
+ database_id='your_database_id',
+ project_id='your_project_id',
+ row_type=ExampleRow,
+ sql='SELECT * FROM some_table',
+ timestamp_bound_mode=TimestampBoundMode.MAX_STALENESS,
+ staleness=3,
+ time_unit=TimeUnit.HOURS,
+ ).with_output_types(ExampleRow))
+
+ Experimental; no backwards compatibility guarantees.
+ """
+
+ URN = 'beam:external:java:spanner:read:v1'
+
+ def __init__(
+ self,
+ project_id,
+ instance_id,
+ database_id,
+ row_type=None,
+ sql=None,
+ table=None,
+ host=None,
+ emulator_host=None,
+ batching=None,
+ timestamp_bound_mode=None,
+ read_timestamp=None,
+ staleness=None,
+ time_unit=None,
+ expansion_service=None,
+ ):
+ """
+ Initializes a read operation from Spanner.
+
+ :param project_id: Specifies the Cloud Spanner project.
+ :param instance_id: Specifies the Cloud Spanner instance.
+ :param database_id: Specifies the Cloud Spanner database.
+ :param row_type: Row type that fits the given query or table. Passed as a
+ NamedTuple, e.g. NamedTuple('name', [('row_name', unicode)])
+ :param sql: A SQL query to execute. Its results must fit the
+ provided row_type. Don't use when table is set.
+ :param table: A Spanner table. When provided, all columns from row_type
+ will be selected to query. Don't use when sql is set.
+ :param batching: By default the Batch API is used to read data from Cloud
+ Spanner. It is useful to disable batching when the underlying query
+ is not root-partitionable.
+ :param host: Specifies the Cloud Spanner host.
+ :param emulator_host: Specifies Spanner emulator host.
+ :param timestamp_bound_mode: Defines how Cloud Spanner will choose a
+ timestamp for a read-only transaction or a single read/query.
+ Passed as TimestampBoundMode enum. Possible values:
+ STRONG: A timestamp bound that will perform reads and queries at a
+ timestamp where all previously committed transactions are visible.
+ READ_TIMESTAMP: Returns a timestamp bound that will perform reads
+ and queries at the given timestamp.
+ MIN_READ_TIMESTAMP: Returns a timestamp bound that will perform reads
+ and queries at a timestamp chosen to be at least the given timestamp value.
+ EXACT_STALENESS: Returns a timestamp bound that will perform reads and
+ queries at an exact staleness. The timestamp is chosen soon after the
+ read is started.
+ MAX_STALENESS: Returns a timestamp bound that will perform reads and
+ queries at a timestamp chosen to be at most the given staleness old.
+ :param read_timestamp: Timestamp in string. Use only when
+ timestamp_bound_mode is set to READ_TIMESTAMP or MIN_READ_TIMESTAMP.
+ :param staleness: Staleness value as int. Use only when
+ timestamp_bound_mode is set to EXACT_STALENESS or MAX_STALENESS.
+ time_unit has to be set along with this param.
+ :param time_unit: Time unit for staleness, passed as a TimeUnit enum.
+ Possible values: NANOSECONDS, MICROSECONDS, MILLISECONDS, SECONDS,
+ HOURS, DAYS.
+ :param expansion_service: The address (host:port) of the ExpansionService.
+ """
+ assert row_type
+ assert (sql or table) and not (sql and table)
+ staleness_value = int(staleness) if staleness else None
+
+ if staleness_value or time_unit:
+ assert staleness_value and time_unit and (
+     timestamp_bound_mode is TimestampBoundMode.MAX_STALENESS or
+     timestamp_bound_mode is TimestampBoundMode.EXACT_STALENESS)
+
+ if read_timestamp:
+ assert timestamp_bound_mode is TimestampBoundMode.MIN_READ_TIMESTAMP\
+ or timestamp_bound_mode is TimestampBoundMode.READ_TIMESTAMP
+
+ super(ReadFromSpanner, self).__init__(
+ self.URN,
+ NamedTupleBasedPayloadBuilder(
+ ReadFromSpannerSchema(
+ instance_id=instance_id,
+ database_id=database_id,
+ sql=sql,
+ table=table,
+ schema=named_tuple_to_schema(row_type).SerializeToString(),
+ project_id=project_id,
+ host=host,
+ emulator_host=emulator_host,
+ batching=batching,
+ timestamp_bound_mode=_get_enum_name(timestamp_bound_mode),
+ read_timestamp=read_timestamp,
+ staleness=staleness,
+ time_unit=_get_enum_name(time_unit),
+ ),
+ ),
+ expansion_service or default_io_expansion_service(),
+ )
+
+
+class WriteToSpannerSchema(NamedTuple):
+ project_id: unicode
+ instance_id: unicode
+ database_id: unicode
+ table: unicode
+ max_batch_size_bytes: Optional[int]
+ max_number_mutations: Optional[int]
+ max_number_rows: Optional[int]
+ grouping_factor: Optional[int]
+ host: Optional[unicode]
+ emulator_host: Optional[unicode]
+ commit_deadline: Optional[int]
+ max_cumulative_backoff: Optional[int]
+
+
+_CLASS_DOC = \
+ """
+ A PTransform which writes {operation} mutations to the specified Spanner
+ table.
+
+ This transform receives rows defined as a NamedTuple. Example::
+
+ from typing import NamedTuple
+ from apache_beam import coders
+
+ class {row_type}(NamedTuple):
+ id: int
+ name: unicode
+
+ coders.registry.register_coder({row_type}, coders.RowCoder)
+
+ with Pipeline() as p:
+ _ = (
+ p
+ | 'Impulse' >> beam.Impulse()
+ | 'Generate' >> beam.FlatMap(lambda x: range(num_rows))
+ | 'To row' >> beam.Map(lambda n: {row_type}(n, str(n)))
+ .with_output_types({row_type})
+ | 'Write to Spanner' >> Spanner{operation_suffix}(
+ instance_id='your_instance',
+ database_id='existing_database',
+ project_id='your_project_id',
+ table='your_table'))
+
+ Experimental; no backwards compatibility guarantees.
+ """
+
+_INIT_DOC = \
+ """
+ Initializes {operation} operation on a Spanner table.
+
+ :param project_id: Specifies the Cloud Spanner project.
+ :param instance_id: Specifies the Cloud Spanner instance.
+ :param database_id: Specifies the Cloud Spanner database.
+ :param table: Specifies the Cloud Spanner table.
+ :param max_batch_size_bytes: Specifies the batch size limit (max number of
+ bytes mutated per batch). Default value is 1048576 bytes = 1MB.
+ :param max_number_mutations: Specifies the cell mutation limit (maximum
+ number of mutated cells per batch). Default value is 5000.
+ :param max_number_rows: Specifies the row mutation limit (maximum number of
+ mutated rows per batch). Default value is 500.
+ :param grouping_factor: Specifies the multiple of max mutation (in terms
+ of both bytes per batch and cells per batch) that is used to select a
+ set of mutations to sort by key for batching. This sort uses local
+ memory on the workers, so using large values can cause out of memory
+ errors. Default value is 1000.
+ :param host: Specifies the Cloud Spanner host.
+ :param emulator_host: Specifies Spanner emulator host.
+ :param commit_deadline: Specifies the deadline for the Commit API call.
+ Default is 15 secs. DEADLINE_EXCEEDED errors will prompt a backoff/retry
+ until the value of commit_deadline is reached. DEADLINE_EXCEEDED errors
+ are reported with logging and counters. Pass the value in seconds.
+ :param max_cumulative_backoff: Specifies the maximum cumulative backoff
+ time when retrying after DEADLINE_EXCEEDED errors. Default is 900s
+ (15min). If the mutations still have not been written after this time,
+ they are treated as a failure, and handled according to the setting of
+ failure_mode. Pass the value in seconds.
+ :param expansion_service: The address (host:port) of the ExpansionService.
+ """
+
+
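+# Decorator factory: fills the shared docstring templates above with the
+# operation-specific names for each Spanner write transform defined below.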
+def _add_doc(
+ value,
+ operation=None,
+ row_type=None,
+ operation_suffix=None,
+):
+ def _doc(obj):
+ obj.__doc__ = value.format(
+ operation=operation,
+ row_type=row_type,
+ operation_suffix=operation_suffix,
+ )
+ return obj
+
+ return _doc
+
+
+@_add_doc(
+ _CLASS_DOC,
+ operation='delete',
+ row_type='ExampleKey',
+ operation_suffix='Delete',
+)
+class SpannerDelete(ExternalTransform):
+
+ URN = 'beam:external:java:spanner:delete:v1'
+
+ @_add_doc(_INIT_DOC, operation='a delete')
+ def __init__(
+ self,
+ project_id,
+ instance_id,
+ database_id,
+ table,
+ max_batch_size_bytes=None,
+ max_number_mutations=None,
+ max_number_rows=None,
+ grouping_factor=None,
+ host=None,
+ emulator_host=None,
+ commit_deadline=None,
+ max_cumulative_backoff=None,
+ expansion_service=None,
+ ):
+ max_cumulative_backoff = int(
+ max_cumulative_backoff) if max_cumulative_backoff else None
+ commit_deadline = int(commit_deadline) if commit_deadline else None
+ super().__init__(
+ self.URN,
+ NamedTupleBasedPayloadBuilder(
+ WriteToSpannerSchema(
+ project_id=project_id,
+ instance_id=instance_id,
+ database_id=database_id,
+ table=table,
+ max_batch_size_bytes=max_batch_size_bytes,
+ max_number_mutations=max_number_mutations,
+ max_number_rows=max_number_rows,
+ grouping_factor=grouping_factor,
+ host=host,
+ emulator_host=emulator_host,
+ commit_deadline=commit_deadline,
+ max_cumulative_backoff=max_cumulative_backoff,
+ ),
+ ),
+ expansion_service=expansion_service or default_io_expansion_service(),
+ )
+
+
+@_add_doc(
+ _CLASS_DOC,
+ operation='insert',
+ row_type='ExampleRow',
+ operation_suffix='Insert',
+)
+class SpannerInsert(ExternalTransform):
+
+ URN = 'beam:external:java:spanner:insert:v1'
+
+ @_add_doc(_INIT_DOC, operation='an insert')
+ def __init__(
+ self,
+ project_id,
+ instance_id,
+ database_id,
+ table,
+ max_batch_size_bytes=None,
+ max_number_mutations=None,
+ max_number_rows=None,
+ grouping_factor=None,
+ host=None,
+ emulator_host=None,
+ commit_deadline=None,
+ max_cumulative_backoff=None,
+ expansion_service=None,
+ ):
+ max_cumulative_backoff = int(
+ max_cumulative_backoff) if max_cumulative_backoff else None
+ commit_deadline = int(commit_deadline) if commit_deadline else None
+ super().__init__(
+ self.URN,
+ NamedTupleBasedPayloadBuilder(
+ WriteToSpannerSchema(
+ project_id=project_id,
+ instance_id=instance_id,
+ database_id=database_id,
+ table=table,
+ max_batch_size_bytes=max_batch_size_bytes,
+ max_number_mutations=max_number_mutations,
+ max_number_rows=max_number_rows,
+ grouping_factor=grouping_factor,
+ host=host,
+ emulator_host=emulator_host,
+ commit_deadline=commit_deadline,
+ max_cumulative_backoff=max_cumulative_backoff,
+ ),
+ ),
+ expansion_service=expansion_service or default_io_expansion_service(),
+ )
+
+
+@_add_doc(
+ _CLASS_DOC,
+ operation='replace',
+ row_type='ExampleRow',
+ operation_suffix='Replace',
+)
+class SpannerReplace(ExternalTransform):
+
+ URN = 'beam:external:java:spanner:replace:v1'
+
+ @_add_doc(_INIT_DOC, operation='a replace')
+ def __init__(
+ self,
+ project_id,
+ instance_id,
+ database_id,
+ table,
+ max_batch_size_bytes=None,
+ max_number_mutations=None,
+ max_number_rows=None,
+ grouping_factor=None,
+ host=None,
+ emulator_host=None,
+ commit_deadline=None,
+ max_cumulative_backoff=None,
+ expansion_service=None,
+ ):
+ max_cumulative_backoff = int(
+ max_cumulative_backoff) if max_cumulative_backoff else None
+ commit_deadline = int(commit_deadline) if commit_deadline else None
+ super().__init__(
+ self.URN,
+ NamedTupleBasedPayloadBuilder(
+ WriteToSpannerSchema(
+ project_id=project_id,
+ instance_id=instance_id,
+ database_id=database_id,
+ table=table,
+ max_batch_size_bytes=max_batch_size_bytes,
+ max_number_mutations=max_number_mutations,
+ max_number_rows=max_number_rows,
+ grouping_factor=grouping_factor,
+ host=host,
+ emulator_host=emulator_host,
+ commit_deadline=commit_deadline,
+ max_cumulative_backoff=max_cumulative_backoff,
+ ),
+ ),
+ expansion_service=expansion_service or default_io_expansion_service(),
+ )
+
+
+@_add_doc(
+ _CLASS_DOC,
+ operation='insert-or-update',
+ row_type='ExampleRow',
+ operation_suffix='InsertOrUpdate',
+)
+class SpannerInsertOrUpdate(ExternalTransform):
+
+ URN = 'beam:external:java:spanner:insert_or_update:v1'
+
+ @_add_doc(_INIT_DOC, operation='an insert-or-update')
+ def __init__(
+ self,
+ project_id,
+ instance_id,
+ database_id,
+ table,
+ max_batch_size_bytes=None,
+ max_number_mutations=None,
+ max_number_rows=None,
+ grouping_factor=None,
+ host=None,
+ emulator_host=None,
+ commit_deadline=None,
+ max_cumulative_backoff=None,
+ expansion_service=None,
+ ):
+ max_cumulative_backoff = int(
+ max_cumulative_backoff) if max_cumulative_backoff else None
+ commit_deadline = int(commit_deadline) if commit_deadline else None
+ super().__init__(
+ self.URN,
+ NamedTupleBasedPayloadBuilder(
+ WriteToSpannerSchema(
+ project_id=project_id,
+ instance_id=instance_id,
+ database_id=database_id,
+ table=table,
+ max_batch_size_bytes=max_batch_size_bytes,
+ max_number_mutations=max_number_mutations,
+ max_number_rows=max_number_rows,
+ grouping_factor=grouping_factor,
+ host=host,
+ emulator_host=emulator_host,
+ commit_deadline=commit_deadline,
+ max_cumulative_backoff=max_cumulative_backoff,
+ ),
+ ),
+ expansion_service=expansion_service or default_io_expansion_service(),
+ )
+
+
+@_add_doc(
+ _CLASS_DOC,
+ operation='update',
+ row_type='ExampleRow',
+ operation_suffix='Update',
+)
+class SpannerUpdate(ExternalTransform):
+
+ URN = 'beam:external:java:spanner:update:v1'
+
+ @_add_doc(_INIT_DOC, operation='an update')
+ def __init__(
+ self,
+ project_id,
+ instance_id,
+ database_id,
+ table,
+ max_batch_size_bytes=None,
+ max_number_mutations=None,
+ max_number_rows=None,
+ grouping_factor=None,
+ host=None,
+ emulator_host=None,
+ commit_deadline=None,
+ max_cumulative_backoff=None,
+ expansion_service=None,
+ ):
+ max_cumulative_backoff = int(
+ max_cumulative_backoff) if max_cumulative_backoff else None
+ commit_deadline = int(commit_deadline) if commit_deadline else None
+ super().__init__(
+ self.URN,
+ NamedTupleBasedPayloadBuilder(
+ WriteToSpannerSchema(
+ project_id=project_id,
+ instance_id=instance_id,
+ database_id=database_id,
+ table=table,
+ max_batch_size_bytes=max_batch_size_bytes,
+ max_number_mutations=max_number_mutations,
+ max_number_rows=max_number_rows,
+ grouping_factor=grouping_factor,
+ host=host,
+ emulator_host=emulator_host,
+ commit_deadline=commit_deadline,
+ max_cumulative_backoff=max_cumulative_backoff,
+ ),
+ ),
+ expansion_service=expansion_service or default_io_expansion_service(),
+ )
+
+
+def _get_enum_name(enum):
+ return None if enum is None else enum.name
diff --git a/sdks/python/apache_beam/io/gcp/tests/xlang_spannerio_it_test.py
b/sdks/python/apache_beam/io/gcp/tests/xlang_spannerio_it_test.py
new file mode 100644
index 0000000..9fa0d4e
--- /dev/null
+++ b/sdks/python/apache_beam/io/gcp/tests/xlang_spannerio_it_test.py
@@ -0,0 +1,339 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+from __future__ import absolute_import
+
+import argparse
+import logging
+import os
+import time
+import unittest
+import uuid
+from typing import NamedTuple
+from typing import Optional
+
+from past.builtins import unicode
+
+import apache_beam as beam
+from apache_beam import coders
+from apache_beam.io.gcp.spanner import ReadFromSpanner
+from apache_beam.io.gcp.spanner import SpannerDelete
+from apache_beam.io.gcp.spanner import SpannerInsert
+from apache_beam.io.gcp.spanner import SpannerInsertOrUpdate
+from apache_beam.io.gcp.spanner import SpannerReplace
+from apache_beam.io.gcp.spanner import SpannerUpdate
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+
+# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
+try:
+ from google.cloud import spanner
+except ImportError:
+ spanner = None
+
+try:
+ from testcontainers.core.container import DockerContainer
+except ImportError:
+ DockerContainer = None
+# pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports
+
+
+class SpannerTestKey(NamedTuple):
+ f_string: unicode
+
+
+class SpannerTestRow(NamedTuple):
+ f_string: unicode
+ f_int64: Optional[int]
+ f_boolean: Optional[bool]
+
+
+class SpannerPartTestRow(NamedTuple):
+ f_string: unicode
+ f_int64: Optional[int]
+
+
+@unittest.skipIf(spanner is None, 'GCP dependencies are not installed.')
+@unittest.skipIf(
+ DockerContainer is None, 'testcontainers package is not installed.')
+class CrossLanguageSpannerIOTest(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--spanner_instance_id',
+ default='beam-test',
+ help='Spanner instance id',
+ )
+ parser.add_argument(
+ '--spanner_project_id',
+ default='beam-testing',
+ help='GCP project with spanner instance',
+ )
+ parser.add_argument(
+ '--use_real_spanner',
+ action='store_true',
+ default=False,
+ help='Whether to use emulator or real spanner instance',
+ )
+
+ pipeline = TestPipeline(is_integration_test=True)
+ argv = pipeline.get_full_options_as_args()
+
+ known_args, _ = parser.parse_known_args(argv)
+ cls.project_id = known_args.spanner_project_id
+ cls.instance_id = known_args.spanner_instance_id
+ use_spanner_emulator = not known_args.use_real_spanner
+ cls.table = 'xlang_beam_spanner'
+ cls.spanner_helper = SpannerHelper(
+ cls.project_id, cls.instance_id, cls.table, use_spanner_emulator)
+
+ coders.registry.register_coder(SpannerTestRow, coders.RowCoder)
+ coders.registry.register_coder(SpannerPartTestRow, coders.RowCoder)
+ coders.registry.register_coder(SpannerTestKey, coders.RowCoder)
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.spanner_helper.shutdown()
+
+ def setUp(self):
+ self.database_id = f'xlang_beam{uuid.uuid4()}'.replace('-', '')[:30]
+ self.spanner_helper.create_database(self.database_id)
+
+ def tearDown(self):
+ self.spanner_helper.drop_database(self.database_id)
+
+ def test_spanner_insert_or_update(self):
+ self.spanner_helper.insert_values(
+ self.database_id, [('or_update0', 5, False), ('or_update1', 9, False)])
+
+ def to_row_fn(i):
+ return SpannerTestRow(
+ f_int64=i, f_string=f'or_update{i}', f_boolean=i % 2 == 0)
+
+ self.run_write_pipeline(3, to_row_fn, SpannerTestRow,
SpannerInsertOrUpdate)
+
+ self.assertEqual(
+ self.spanner_helper.read_data(self.database_id, prefix='or_update'),
+ [[f'or_update{i}', i, i % 2 == 0] for i in range(3)])
+
+ def test_spanner_insert(self):
+ def to_row_fn(num):
+ return SpannerTestRow(
+ f_string=f'insert{num}', f_int64=num, f_boolean=None)
+
+ self.run_write_pipeline(1000, to_row_fn, SpannerTestRow, SpannerInsert)
+
+ def compare_row(row):
+ return row[1]
+
+ self.assertEqual(
+ sorted(
+ self.spanner_helper.read_data(self.database_id, 'insert'),
+ key=compare_row), [[f'insert{i}', i, None] for i in range(1000)])
+
+ def test_spanner_replace(self):
+ self.spanner_helper.insert_values(
+ self.database_id, [('replace0', 0, True), ('replace1', 1, False)])
+
+ def to_row_fn(num):
+ return SpannerPartTestRow(f_string=f'replace{num}', f_int64=num + 10)
+
+ self.run_write_pipeline(2, to_row_fn, SpannerPartTestRow, SpannerReplace)
+
+ self.assertEqual(
+ self.spanner_helper.read_data(self.database_id, prefix='replace'),
+ [['replace0', 10, None], ['replace1', 11, None]])
+
+ def test_spanner_update(self):
+ self.spanner_helper.insert_values(
+ self.database_id, [('update0', 5, False), ('update1', 9, False)])
+
+ def to_row_fn(num):
+ return SpannerPartTestRow(f_string=f'update{num}', f_int64=num + 10)
+
+ self.run_write_pipeline(2, to_row_fn, SpannerPartTestRow, SpannerUpdate)
+
+ self.assertEqual(
+ self.spanner_helper.read_data(self.database_id, 'update'),
+ [['update0', 10, False], ['update1', 11, False]])
+
+ def test_spanner_delete(self):
+ self.spanner_helper.insert_values(
+ self.database_id,
+ values=[
+ ('delete0', 0, None),
+ ('delete6', 6, False),
+ ('delete20', 20, True),
+ ])
+
+ def to_row_fn(num):
+ return SpannerTestKey(f_string=f'delete{num}')
+
+ self.run_write_pipeline(10, to_row_fn, SpannerTestKey, SpannerDelete)
+
+ self.assertEqual(
+ self.spanner_helper.read_data(self.database_id, prefix='delete'),
+ [['delete20', 20, True]])
+
+ def test_spanner_read_query(self):
+ self.insert_read_values('query_read')
+ self.run_read_pipeline('query_read', query=f'SELECT * FROM {self.table}')
+
+ def test_spanner_read_table(self):
+ self.insert_read_values('table_read')
+ self.run_read_pipeline('table_read', table=self.table)
+
+ def run_read_pipeline(self, prefix, table=None, query=None):
+ with TestPipeline(is_integration_test=True) as p:
+ p.not_use_test_runner_api = True
+ result = (
+ p
+ | ReadFromSpanner(
+ instance_id=self.instance_id,
+ database_id=self.database_id,
+ project_id=self.project_id,
+ row_type=SpannerTestRow,
+ sql=query,
+ table=table,
+ emulator_host=self.spanner_helper.get_emulator_host(),
+ ))
+
+ assert_that(
+ result,
+ equal_to([
+ SpannerTestRow(f_int64=0, f_string=f'{prefix}0', f_boolean=None),
+ SpannerTestRow(f_int64=1, f_string=f'{prefix}1', f_boolean=True),
+ SpannerTestRow(f_int64=2, f_string=f'{prefix}2',
f_boolean=False),
+ ]))
+
+ def run_write_pipeline(
+ self, num_rows, to_row_fn, row_type, spanner_transform=None):
+ with TestPipeline(is_integration_test=True) as p:
+ p.not_use_test_runner_api = True
+ _ = (
+ p
+ | 'Impulse' >> beam.Impulse()
+ | 'Generate' >> beam.FlatMap(lambda x: range(num_rows)) # pylint:
disable=range-builtin-not-iterating
+ | 'Map to row' >> beam.Map(to_row_fn).with_output_types(row_type)
+ | 'Write to Spanner' >> spanner_transform(
+ instance_id=self.instance_id,
+ database_id=self.database_id,
+ project_id=self.project_id,
+ table=self.table,
+ emulator_host=self.spanner_helper.get_emulator_host(),
+ ))
+
+ def insert_read_values(self, prefix):
+ self.spanner_helper.insert_values(
+ self.database_id,
+ values=[
+ (f'{prefix}0', 0, None),
+ (f'{prefix}1', 1, True),
+ (f'{prefix}2', 2, False),
+ ])
+
+
+def retry(fn, retries, err_msg, *args, **kwargs):
+ for _ in range(retries):
+ try:
+ return fn(*args, **kwargs)
+ except: # pylint: disable=bare-except
+ pass
+ logging.error(err_msg)
+ raise RuntimeError(err_msg)
+
+
+class SpannerHelper(object):
+ def __init__(self, project_id, instance_id, table, use_emulator):
+ self.use_emulator = use_emulator
+ self.table = table
+ self.host = None
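+ # When the emulator is requested, start the Cloud Spanner emulator in a
+ # Docker container and point the client at it via SPANNER_EMULATOR_HOST.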
+ if use_emulator:
+ self.emulator = DockerContainer(
+ 'gcr.io/cloud-spanner-emulator/emulator:latest').with_exposed_ports(
+ 9010, 9020)
+ retry(self.emulator.start, 3, 'Could not start spanner emulator.')
+ time.sleep(3)
+ self.host = f'{self.emulator.get_container_host_ip()}:' \
+ f'{self.emulator.get_exposed_port(9010)}'
+ os.environ['SPANNER_EMULATOR_HOST'] = self.host
+ self.client = spanner.Client(project_id)
+ self.instance = self.client.instance(instance_id)
+ if use_emulator:
+ self.create_instance()
+
+ def create_instance(self):
+ self.instance.create().result(120)
+
+ def create_database(self, database_id):
+ database = self.instance.database(
+ database_id,
+ ddl_statements=[
+ f'''
+ CREATE TABLE {self.table} (
+ f_string STRING(1024) NOT NULL,
+ f_int64 INT64,
+ f_boolean BOOL
+ ) PRIMARY KEY (f_string)'''
+ ])
+ database.create().result(120)
+
+ def insert_values(self, database_id, values, columns=None):
+ values = values or []
+ columns = columns or ('f_string', 'f_int64', 'f_boolean')
+ with self.instance.database(database_id).batch() as batch:
+ batch.insert(
+ table=self.table,
+ columns=columns,
+ values=values,
+ )
+
+ def get_emulator_host(self):
+ return f'http://{self.host}'
+
+ def read_data(self, database_id, prefix):
+ database = self.instance.database(database_id)
+ with database.snapshot() as snapshot:
+ results = snapshot.execute_sql(
+ f'''SELECT * FROM {self.table}
+ WHERE f_string LIKE "{prefix}%"
+ ORDER BY f_int64''')
+ try:
+ rows = list(results) if results else None
+ except IndexError:
+ raise ValueError(f"Spanner results not found for {prefix}.")
+ return rows
+
+ def drop_database(self, database_id):
+ database = self.instance.database(database_id)
+ database.drop()
+
+ def shutdown(self):
+ if self.use_emulator:
+ try:
+ self.emulator.stop()
+ except: # pylint: disable=bare-except
+ logging.error('Could not stop Spanner Cloud emulator.')
+
+
+if __name__ == '__main__':
+ logging.getLogger().setLevel(logging.INFO)
+ unittest.main()
diff --git a/sdks/python/test-suites/portable/common.gradle
b/sdks/python/test-suites/portable/common.gradle
index 239a974..a70af1e 100644
--- a/sdks/python/test-suites/portable/common.gradle
+++ b/sdks/python/test-suites/portable/common.gradle
@@ -166,6 +166,7 @@ project.task("postCommitPy${pythonVersionSuffix}IT") {
':runners:flink:1.10:job-server:shadowJar',
':sdks:java:container:docker',
':sdks:java:testing:kafka-service:buildTestKafkaServiceJar',
+ ':sdks:java:io:google-cloud-platform:expansion-service:shadowJar',
':sdks:java:io:kinesis:expansion-service:shadowJar',
':sdks:java:extensions:schemaio-expansion-service:shadowJar',
]
@@ -176,6 +177,7 @@ project.task("postCommitPy${pythonVersionSuffix}IT") {
"apache_beam.io.external.xlang_jdbcio_it_test",
"apache_beam.io.external.xlang_kafkaio_it_test",
"apache_beam.io.external.xlang_kinesisio_it_test",
+ "apache_beam.io.gcp.tests.xlang_spannerio_it_test",
]
def testOpts = ["--tests=${tests.join(',')}"]
def cmdArgs = mapToArgString([
diff --git a/settings.gradle b/settings.gradle
index 281ab1a..1ad0b34 100644
--- a/settings.gradle
+++ b/settings.gradle
@@ -139,6 +139,7 @@ include ":sdks:java:io:expansion-service"
include ":sdks:java:io:file-based-io-tests"
include ':sdks:java:io:bigquery-io-perf-tests'
include ":sdks:java:io:google-cloud-platform"
+include ":sdks:java:io:google-cloud-platform:expansion-service"
include ":sdks:java:io:hadoop-common"
include ":sdks:java:io:hadoop-file-system"
include ":sdks:java:io:hadoop-format"