[
https://issues.apache.org/jira/browse/BEAM-5918?focusedWorklogId=162736&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-162736
]
ASF GitHub Bot logged work on BEAM-5918:
----------------------------------------
Author: ASF GitHub Bot
Created on: 05/Nov/18 20:11
Start Date: 05/Nov/18 20:11
Worklog Time Spent: 10m
Work Description: kennknowles closed pull request #6888: [BEAM-5918] Add
Cast transform for Rows
URL: https://github.com/apache/beam/pull/6888
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java
index 86a0f4653d5..1587a6bbee7 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java
@@ -292,7 +292,7 @@ public int hashCode() {
INT16, // two-byte signed integer.
INT32, // four-byte signed integer.
INT64, // eight-byte signed integer.
- DECIMAL, // Decimal integer
+ DECIMAL, // Arbitrary-precision decimal number
FLOAT,
DOUBLE,
STRING, // String.
@@ -338,6 +338,47 @@ public boolean isMapType() {
public boolean isCompositeType() {
return COMPOSITE_TYPES.contains(this);
}
+
+    /**
+     * Returns whether this type is a subtype of {@code other}; this is the mirror image of
+     * {@link #isSupertypeOf(TypeName)}.
+     */
+    public boolean isSubtypeOf(TypeName other) {
+      return other.isSupertypeOf(this);
+    }
+
+    /**
+     * Returns whether {@code other} is this very type or one of the numeric types listed for it.
+     *
+     * <p>The relation is reflexive for every type. Beyond that it only holds between numeric
+     * types: each integral type is a supertype of the smaller integral types, DOUBLE is a
+     * supertype of FLOAT, and DECIMAL is a supertype of FLOAT and DOUBLE.
+     *
+     * @throws AssertionError if this is a numeric type that has no case below
+     */
+    public boolean isSupertypeOf(TypeName other) {
+      if (this == other) {
+        return true;
+      }
+
+      // Outside of identity, the relation is defined for pairs of numeric types only.
+      if (!(isNumericType() && other.isNumericType())) {
+        return false;
+      }
+
+      switch (this) {
+        case BYTE:
+        case FLOAT:
+          // Nothing other than the type itself is accepted for these two.
+          return false;
+
+        case INT16:
+          return other == BYTE;
+
+        case INT32:
+          return other == INT16 || other == BYTE;
+
+        case INT64:
+          return other == INT32 || other == INT16 || other == BYTE;
+
+        case DOUBLE:
+          return other == FLOAT;
+
+        case DECIMAL:
+          return other == DOUBLE || other == FLOAT;
+
+        default:
+          throw new AssertionError("Unexpected numeric type: " + this);
+      }
+    }
}
/**
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Cast.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Cast.java
new file mode 100644
index 00000000000..3048806edf0
--- /dev/null
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Cast.java
@@ -0,0 +1,440 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.schemas.transforms;
+
+import com.google.auto.value.AutoValue;
+import com.google.common.base.Joiner;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Maps;
+import java.io.Serializable;
+import java.math.BigDecimal;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.stream.Collectors;
+import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.Field;
+import org.apache.beam.sdk.schemas.Schema.FieldType;
+import org.apache.beam.sdk.schemas.Schema.TypeName;
+import org.apache.beam.sdk.schemas.utils.SchemaZipFold;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.Row;
+
+/** Set of utilities for casting rows between schemas. */
+@Experimental(Experimental.Kind.SCHEMAS)
+@AutoValue
+public abstract class Cast<T> extends PTransform<PCollection<T>,
PCollection<Row>> {
+
+ public abstract Schema outputSchema();
+
+ public abstract Validator validator();
+
+ public static <T> Cast<T> of(Schema outputSchema, Validator validator) {
+ return new AutoValue_Cast<>(outputSchema, validator);
+ }
+
+ public static <T> Cast<T> widening(Schema outputSchema) {
+ return new AutoValue_Cast<>(outputSchema, Widening.of());
+ }
+
+ public static <T> Cast<T> narrowing(Schema outputSchema) {
+ return new AutoValue_Cast<>(outputSchema, Narrowing.of());
+ }
+
+ /** Describes compatibility errors during casting. */
+ @AutoValue
+ public abstract static class CompatibilityError implements Serializable {
+
+ public abstract List<String> path();
+
+ public abstract String message();
+
+ public static CompatibilityError create(List<String> path, String message)
{
+ return new AutoValue_Cast_CompatibilityError(path, message);
+ }
+ }
+
+ /** Interface for statically validating casts. */
+ public interface Validator extends Serializable {
+ List<CompatibilityError> apply(Schema input, Schema output);
+ }
+
+ /**
+ * Widening changes to type that can represent any possible value of the
original type.
+ *
+ * <p>Standard widening conversions:
+ *
+ * <ul>
+ * <li>BYTE to INT16, INT32, INT64, FLOAT, DOUBLE, DECIMAL
+ * <li>INT16 to INT32, INT64, FLOAT, DOUBLE, DECIMAL
+ * <li>INT32 to INT64, FLOAT, DOUBLE, DECIMAL
+ * <li>INT64 to FLOAT, DOUBLE, DECIMAL
+ * <li>FLOAT to DOUBLE, DECIMAL
+ * <li>DOUBLE to DECIMAL
+ * </ul>
+ *
+ * <p>Row widening:
+ *
+ * <ul>
+ * <li>wider schema to schema with a subset of fields
+ * <li>non-nullable fields to nullable fields
+ * </ul>
+ *
+ * <p>Widening doesn't lose information about the overall magnitude in
following cases:
+ *
+ * <ul>
+ * <li>integral type to another integral type
+ * <li>BYTE or INT16 to FLOAT, DOUBLE or DECIMAL
+ * <li>INT32 to DOUBLE
+ * </ul>
+ *
+ * <p>Other conversions may cause loss of precision.
+ */
+ public static class Widening implements Validator {
+ private final Fold fold = new Fold();
+
+ public static Widening of() {
+ return new Widening();
+ }
+
+ @Override
+ public String toString() {
+ return "Cast.Widening";
+ }
+
+ @Override
+ public List<CompatibilityError> apply(final Schema input, final Schema
output) {
+ return fold.apply(input, output);
+ }
+
+ private static class Fold extends SchemaZipFold<List<CompatibilityError>> {
+
+ @Override
+ public List<CompatibilityError> accumulate(
+ List<CompatibilityError> left, List<CompatibilityError> right) {
+ return
ImmutableList.<CompatibilityError>builder().addAll(left).addAll(right).build();
+ }
+
+      /**
+       * Checks one field name of the zipped rows; {@code left} is the field in the input schema
+       * and {@code right} the field in the output schema, either of which may be absent.
+       *
+       * <p>Dropping an input field is allowed. A field that exists only in the output schema, or
+       * a nullable input field paired with a non-nullable output field, is reported as an error.
+       *
+       * @return the compatibility errors for this field, empty when the pair is acceptable
+       */
+      @Override
+      public List<CompatibilityError> accept(
+          Context context, Optional<Field> left, Optional<Field> right) {
+        if (!right.isPresent()) {
+          // Nothing is requested in the output for this name, so there is nothing to check.
+          return Collections.emptyList();
+        }
+
+        if (!left.isPresent()) {
+          // The output asks for a field that the input cannot provide.
+          return Collections.singletonList(
+              CompatibilityError.create(context.path(), "Field is missing in input schema"));
+        }
+
+        if (left.get().getNullable() && !right.get().getNullable()) {
+          return Collections.singletonList(
+              CompatibilityError.create(
+                  context.path(), "Can't cast nullable field to non-nullable field"));
+        }
+
+        return Collections.emptyList();
+      }
+
+ @Override
+ public List<CompatibilityError> accept(Context context, FieldType input,
FieldType output) {
+ TypeName inputType = input.getTypeName();
+ TypeName outputType = output.getTypeName();
+
+ boolean supertype = outputType.isSupertypeOf(inputType);
+
+ if (isIntegral(inputType) && isDecimal(outputType)) {
+ return Collections.emptyList();
+ } else if (!supertype) {
+ return Collections.singletonList(
+ CompatibilityError.create(
+ context.path(), "Can't cast '" + inputType + "' to '" +
outputType + "'"));
+ }
+
+ return Collections.emptyList();
+ }
+ }
+ }
+
+ /**
+ * Narrowing changes type without guarantee to preserve data.
+ *
+ * <p>Standard narrowing conversions:
+ *
+ * <ul>
+ * <li>any conversions of {@link Widening}
+ * <li>conversions the opposite to {@link Widening}
+ * </ul>
+ *
+ * <p>Row narrowing
+ *
+ * <ul>
+ * <li>wider schema to schema with a subset of fields
+ * <li>non-nullable fields to nullable fields
+ * <li>nullable fields to non-nullable fields
+ * </ul>
+ */
+ public static class Narrowing implements Validator {
+ private final Fold fold = new Fold();
+
+ public static Narrowing of() {
+ return new Narrowing();
+ }
+
+ @Override
+ public String toString() {
+ return "Cast.Narrowing";
+ }
+
+ @Override
+ public List<CompatibilityError> apply(final Schema input, final Schema
output) {
+ return fold.apply(input, output);
+ }
+
+ private static class Fold extends SchemaZipFold<List<CompatibilityError>> {
+
+ @Override
+ public List<CompatibilityError> accumulate(
+ List<CompatibilityError> left, List<CompatibilityError> right) {
+ return
ImmutableList.<CompatibilityError>builder().addAll(left).addAll(right).build();
+ }
+
+      /**
+       * Checks one field name of the zipped rows; {@code left} is the field in the input schema
+       * and {@code right} the field in the output schema, either of which may be absent.
+       *
+       * <p>The only error reported here is an output field that has no counterpart in the input
+       * schema; nullability is not checked.
+       *
+       * @return the compatibility errors for this field, empty when the pair is acceptable
+       */
+      @Override
+      public List<CompatibilityError> accept(
+          Context context, Optional<Field> left, Optional<Field> right) {
+
+        if (!left.isPresent() && right.isPresent()) {
+          // The output asks for a field that the input cannot provide.
+          return Collections.singletonList(
+              CompatibilityError.create(context.path(), "Field is missing in input schema"));
+        }
+
+        return Collections.emptyList();
+      }
+
+ @Override
+ public List<CompatibilityError> accept(Context context, FieldType input,
FieldType output) {
+ TypeName inputType = input.getTypeName();
+ TypeName outputType = output.getTypeName();
+
+ boolean supertype = outputType.isSupertypeOf(inputType);
+ boolean subtype = outputType.isSubtypeOf(inputType);
+
+ if (isDecimal(inputType) && isIntegral(outputType)) {
+ return Collections.emptyList();
+ } else if (!supertype && !subtype) {
+ return Collections.singletonList(
+ CompatibilityError.create(
+ context.path(), "Can't cast '" + inputType + "' to '" +
outputType + "'"));
+ }
+
+ return Collections.emptyList();
+ }
+ }
+ }
+
+  /** Returns true for the integer types BYTE, INT16, INT32 and INT64, false for anything else. */
+  public static boolean isIntegral(TypeName type) {
+    if (type == TypeName.BYTE || type == TypeName.INT16) {
+      return true;
+    }
+    return type == TypeName.INT32 || type == TypeName.INT64;
+  }
+
+  /**
+   * Returns true for the non-integer numeric types FLOAT, DOUBLE and DECIMAL, false for anything
+   * else.
+   */
+  public static boolean isDecimal(TypeName type) {
+    if (type == TypeName.FLOAT || type == TypeName.DOUBLE) {
+      return true;
+    }
+    return type == TypeName.DECIMAL;
+  }
+
+  /**
+   * Runs the configured validator against {@code inputSchema} and the output schema of this
+   * transform.
+   *
+   * @throws IllegalArgumentException if the validator reports at least one error; the message
+   *     lists every error as {@code dotted.field.path: message}, one per line
+   */
+  public void verifyCompatibility(Schema inputSchema) {
+    List<CompatibilityError> problems = validator().apply(inputSchema, outputSchema());
+    if (problems.isEmpty()) {
+      return;
+    }
+
+    List<String> lines = new ArrayList<>(problems.size());
+    for (CompatibilityError problem : problems) {
+      lines.add(Joiner.on('.').join(problem.path()) + ": " + problem.message());
+    }
+
+    throw new IllegalArgumentException(
+        "Cast isn't compatible using " + validator() + ":\n\t" + Joiner.on("\n\t").join(lines));
+  }
+
+ @Override
+ public PCollection<Row> expand(PCollection<T> input) {
+ Schema inputSchema = input.getSchema();
+
+ verifyCompatibility(inputSchema);
+
+ return input
+ .apply(
+ ParDo.of(
+ new DoFn<T, Row>() {
+ // TODO: This should be the same as resolved so that Beam
knows which fields
+ // are being accessed. Currently Beam only supports wildcard
descriptors.
+ // Once BEAM-4457 is fixed, fix this.
+ @FieldAccess("filterFields")
+ final FieldAccessDescriptor fieldAccessDescriptor =
+ FieldAccessDescriptor.withAllFields();
+
+ @ProcessElement
+ public void process(
+ @FieldAccess("filterFields") Row input,
OutputReceiver<Row> r) {
+ Row output = castRow(input, inputSchema, outputSchema());
+ r.output(output);
+ }
+ }))
+ .setRowSchema(outputSchema());
+ }
+
+  /**
+   * Builds a row of {@code outputSchema} from {@code input}.
+   *
+   * <p>For every output field, the input field with the same name is looked up in {@code
+   * inputSchema}, its value is converted with {@link #castValue}, and the result is appended to
+   * the new row in output-schema order. Input fields that the output schema does not name are
+   * not copied.
+   *
+   * @return the converted row, or {@code null} when {@code input} is {@code null}
+   */
+  public static Row castRow(Row input, Schema inputSchema, Schema outputSchema) {
+    if (input == null) {
+      return null;
+    }
+
+    Row.Builder builder = Row.withSchema(outputSchema);
+    for (Schema.Field target : outputSchema.getFields()) {
+      // Fields are matched by name, so the position in the input row may differ.
+      int sourceIndex = inputSchema.indexOf(target.getName());
+      Schema.Field source = inputSchema.getField(sourceIndex);
+
+      builder.addValue(castValue(input.getValue(sourceIndex), source.getType(), target.getType()));
+    }
+
+    return builder.build();
+  }
+
+  /**
+   * Converts a numeric value to the Java representation of the {@code output} type.
+   *
+   * <p>Targets BYTE, INT16, INT32, INT64, FLOAT and DOUBLE use the matching {@code
+   * Number.xxxValue()} accessor, which may round or truncate. Target DECIMAL builds a {@link
+   * BigDecimal} from the int, long or double value of the input; for FLOAT and DOUBLE inputs this
+   * is the exact binary value of the floating-point number, not its shortest decimal rendering.
+   *
+   * @param value the number to convert, may be {@code null}
+   * @param input type of {@code value}; must be numeric
+   * @param output type to convert to; must be numeric
+   * @return {@code value} itself when both types are equal, {@code null} for a {@code null}
+   *     value, otherwise the converted number
+   * @throws IllegalArgumentException if {@code input} or {@code output} is not a numeric type
+   */
+  public static Number castNumber(Number value, TypeName input, TypeName output) {
+    if (!input.isNumericType()) {
+      throw new IllegalArgumentException("Can't cast non-numeric types: " + input);
+    }
+
+    if (!output.isNumericType()) {
+      throw new IllegalArgumentException("Can't cast numbers to non-numeric type: " + output);
+    }
+
+    if (value == null) {
+      return null;
+    }
+
+    if (input == output) {
+      return value;
+    }
+
+    switch (output) {
+      case BYTE:
+        return value.byteValue();
+
+      case INT16:
+        return value.shortValue();
+
+      case INT32:
+        return value.intValue();
+
+      case INT64:
+        return value.longValue();
+
+      case FLOAT:
+        return value.floatValue();
+
+      case DOUBLE:
+        return value.doubleValue();
+
+      case DECIMAL:
+        // DECIMAL -> DECIMAL already returned above, so only the other numeric inputs remain.
+        switch (input) {
+          case BYTE:
+          case INT16:
+          case INT32:
+            return new BigDecimal(value.intValue());
+
+          case INT64:
+            return new BigDecimal(value.longValue());
+
+          case FLOAT:
+          case DOUBLE:
+            return new BigDecimal(value.doubleValue());
+
+          default:
+            // This switch is over the input type, so that is the value to report.
+            throw new AssertionError("Unexpected numeric type: " + input);
+        }
+
+      default:
+        throw new AssertionError("Unexpected numeric type: " + output);
+    }
+  }
+
+  /**
+   * Converts a single value from the {@code input} field type to the {@code output} field type.
+   *
+   * <p>Rows are converted with {@link #castRow}; array elements and map keys and values are
+   * converted recursively; numbers are converted with {@link #castNumber}. A value whose input
+   * and output type names are identical (for example STRING to STRING) is returned unchanged, so
+   * rows containing non-numeric fields can pass through a cast.
+   *
+   * @return the converted value, or {@code null} when {@code inputValue} is {@code null}
+   * @throws IllegalArgumentException if the input type is neither a row, array, map or numeric
+   *     type nor identical to the output type
+   */
+  @SuppressWarnings("unchecked")
+  public static Object castValue(Object inputValue, FieldType input, FieldType output) {
+
+    TypeName inputType = input.getTypeName();
+    TypeName outputType = output.getTypeName();
+
+    if (inputValue == null) {
+      return null;
+    }
+
+    switch (inputType) {
+      case ROW:
+        return castRow((Row) inputValue, input.getRowSchema(), output.getRowSchema());
+
+      case ARRAY:
+        {
+          List<Object> inputValues = (List<Object>) inputValue;
+          List<Object> outputValues = new ArrayList<>(inputValues.size());
+
+          for (Object elem : inputValues) {
+            outputValues.add(
+                castValue(
+                    elem, input.getCollectionElementType(), output.getCollectionElementType()));
+          }
+
+          return outputValues;
+        }
+
+      case MAP:
+        {
+          Map<Object, Object> inputMap = (Map<Object, Object>) inputValue;
+          Map<Object, Object> outputMap = Maps.newHashMapWithExpectedSize(inputMap.size());
+
+          for (Map.Entry<Object, Object> entry : inputMap.entrySet()) {
+            Object outputKey =
+                castValue(entry.getKey(), input.getMapKeyType(), output.getMapKeyType());
+            Object outputValue =
+                castValue(entry.getValue(), input.getMapValueType(), output.getMapValueType());
+
+            outputMap.put(outputKey, outputValue);
+          }
+
+          return outputMap;
+        }
+
+      default:
+        if (inputType == outputType) {
+          // Both validators accept identical types, so the value needs no conversion.
+          return inputValue;
+        }
+
+        if (inputType.isNumericType()) {
+          return castNumber((Number) inputValue, inputType, outputType);
+        }
+
+        throw new IllegalArgumentException("input should be array, map, numeric or row");
+    }
+  }
+}
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/SchemaZipFold.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/SchemaZipFold.java
new file mode 100644
index 00000000000..4d24aeda107
--- /dev/null
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/utils/SchemaZipFold.java
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.schemas.utils;
+
+import com.google.auto.value.AutoValue;
+import com.google.common.collect.ImmutableList;
+import java.io.Serializable;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Stream;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.Field;
+import org.apache.beam.sdk.schemas.Schema.FieldType;
+import org.apache.beam.sdk.schemas.Schema.TypeName;
+
+/**
+ * Visitor that zips schemas, and accepts pairs of fields and their types.
+ *
+ * <p>Values returned by `accept` are accumulated.
+ */
+public abstract class SchemaZipFold<T> implements Serializable {
+
+ public final T apply(Schema left, Schema right) {
+ return visit(this, Context.EMPTY, FieldType.row(left),
FieldType.row(right));
+ }
+
+ /** Accumulate two results together. */
+ public abstract T accumulate(T left, T right);
+
+ /** Accepts two components, context.parent() is always ROW, MAP, ARRAY or
absent. */
+ public abstract T accept(Context context, FieldType left, FieldType right);
+
+ /** Accepts two fields, context.parent() is always ROW. */
+ public abstract T accept(Context context, Optional<Field> left,
Optional<Field> right);
+
+ /** Context referring to a current position in a schema. */
+ @AutoValue
+ public abstract static class Context {
+ /** Field path from a root of a schema. */
+ public abstract List<String> path();
+
+ /** Type of parent node in a tree. */
+ public abstract Optional<TypeName> parent();
+
+ public static final Context EMPTY =
Context.create(Collections.emptyList(), Optional.empty());
+
+ public Context withPathPart(String part) {
+ return
create(ImmutableList.<String>builder().addAll(path()).add(part).build(),
parent());
+ }
+
+ public Context withParent(TypeName parent) {
+ return create(path(), Optional.of(parent));
+ }
+
+ public static Context create(List<String> path, Optional<TypeName> parent)
{
+ return new AutoValue_SchemaZipFold_Context(path, parent);
+ }
+ }
+
+  /**
+   * Walks two field types in parallel and folds the results of the {@code accept} callbacks.
+   *
+   * <p>When the type names differ, or the type is not an array, row or map, the pair is handed to
+   * {@code accept} as is. For arrays the pair itself is accepted and then the element types are
+   * visited; for maps the pair itself is accepted and then the key types and the value types are
+   * visited; rows are delegated to {@link #visitRow}. Nested visits see the container type as
+   * {@code context.parent()}.
+   */
+  static <T> T visit(SchemaZipFold<T> zipFold, Context context, FieldType left, FieldType right) {
+    if (left.getTypeName() != right.getTypeName()) {
+      return zipFold.accept(context, left, right);
+    }
+
+    Context newContext = context.withParent(left.getTypeName());
+
+    switch (left.getTypeName()) {
+      case ARRAY:
+        return zipFold.accumulate(
+            zipFold.accept(context, left, right),
+            visit(
+                zipFold,
+                newContext,
+                left.getCollectionElementType(),
+                right.getCollectionElementType()));
+
+      case ROW:
+        return visitRow(zipFold, newContext, left.getRowSchema(), right.getRowSchema());
+
+      case MAP:
+        {
+          // Maps are described by their key and value types, not by a collection element type.
+          T self = zipFold.accept(context, left, right);
+          T keys = visit(zipFold, newContext, left.getMapKeyType(), right.getMapKeyType());
+          T values = visit(zipFold, newContext, left.getMapValueType(), right.getMapValueType());
+
+          return zipFold.accumulate(zipFold.accumulate(self, keys), values);
+        }
+
+      default:
+        return zipFold.accept(context, left, right);
+    }
+  }
+
+  /**
+   * Zips two row schemas by field name and folds the results of the {@code accept} callbacks.
+   *
+   * <p>The result is accumulated in this order: the row pair itself, then the types of every
+   * field present in both schemas (in {@code left} order), then every field name of either
+   * schema ({@code left} names first, each name once) as a pair of optional fields. Each callback
+   * result is accumulated exactly once.
+   */
+  static <T> T visitRow(SchemaZipFold<T> zipFold, Context context, Schema left, Schema right) {
+    T result = zipFold.accept(context, FieldType.row(left), FieldType.row(right));
+
+    Stream<String> union =
+        Stream.concat(
+                left.getFields().stream().map(Schema.Field::getName),
+                right.getFields().stream().map(Schema.Field::getName))
+            .distinct();
+
+    Stream<String> intersection =
+        left.getFields().stream().map(Schema.Field::getName).filter(right::hasField);
+
+    // Field types can only be compared for names that both schemas know.
+    for (String name : (Iterable<String>) intersection::iterator) {
+      T types =
+          visit(
+              zipFold,
+              context.withPathPart(name).withParent(TypeName.ROW),
+              left.getField(name).getType(),
+              right.getField(name).getType());
+      result = zipFold.accumulate(result, types);
+    }
+
+    // Field presence and attributes are reported for every name, with absent sides left empty.
+    for (String name : (Iterable<String>) union::iterator) {
+      Optional<Field> field0 = Optional.empty();
+      Optional<Field> field1 = Optional.empty();
+
+      if (left.hasField(name)) {
+        field0 = Optional.of(left.getField(name));
+      }
+
+      if (right.hasField(name)) {
+        field1 = Optional.of(right.getField(name));
+      }
+
+      Context newContext = context.withPathPart(name).withParent(TypeName.ROW);
+      result = zipFold.accumulate(result, zipFold.accept(newContext, field0, field1));
+    }
+
+    return result;
+  }
+}
diff --git
a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CastTest.java
b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CastTest.java
new file mode 100644
index 00000000000..05312d91a29
--- /dev/null
+++
b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CastTest.java
@@ -0,0 +1,489 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.schemas.transforms;
+
+import static org.junit.Assert.assertEquals;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableMap;
+import java.util.Arrays;
+import java.util.Objects;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.schemas.DefaultSchema;
+import org.apache.beam.sdk.schemas.JavaFieldSchema;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.testing.NeedsRunner;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.Row;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+
+/** Tests for {@link Cast}. */
+public class CastTest {
+
+ @Rule public final transient TestPipeline pipeline = TestPipeline.create();
+
+ @Test
+ @Category(NeedsRunner.class)
+ public void testProjection() throws Exception {
+ Schema outputSchema =
pipeline.getSchemaRegistry().getSchema(Projection2.class);
+ PCollection<Projection2> pojos =
+ pipeline
+ .apply(Create.of(new Projection1()))
+ .apply(Cast.widening(outputSchema))
+ .apply(Convert.to(Projection2.class));
+
+ PAssert.that(pojos).containsInAnyOrder(new Projection2());
+ pipeline.run();
+ }
+
+ @Test
+ @Category(NeedsRunner.class)
+ public void testTypeWiden() throws Exception {
+ Schema outputSchema =
pipeline.getSchemaRegistry().getSchema(TypeWiden2.class);
+
+ PCollection<TypeWiden2> pojos =
+ pipeline
+ .apply(Create.of(new TypeWiden1()))
+ .apply(Cast.widening(outputSchema))
+ .apply(Convert.to(TypeWiden2.class));
+
+ PAssert.that(pojos).containsInAnyOrder(new TypeWiden2());
+ pipeline.run();
+ }
+
+ @Test
+ @Category(NeedsRunner.class)
+ public void testTypeNarrow() throws Exception {
+ // narrowing is the opposite of widening
+ Schema outputSchema =
pipeline.getSchemaRegistry().getSchema(TypeWiden1.class);
+
+ PCollection<TypeWiden1> pojos =
+ pipeline
+ .apply(Create.of(new TypeWiden2()))
+ .apply(Cast.narrowing(outputSchema))
+ .apply(Convert.to(TypeWiden1.class));
+
+ PAssert.that(pojos).containsInAnyOrder(new TypeWiden1());
+ pipeline.run();
+ }
+
+  @Test(expected = IllegalArgumentException.class)
+  @Category(NeedsRunner.class)
+  public void testTypeNarrowFail() throws Exception {
+    // TypeWiden2 -> TypeWiden1 shrinks INT32 to INT16 and INT64 to INT32. That is a narrowing
+    // cast (see testTypeNarrow), so the widening validator has to reject it.
+    Schema inputSchema = pipeline.getSchemaRegistry().getSchema(TypeWiden2.class);
+    Schema outputSchema = pipeline.getSchemaRegistry().getSchema(TypeWiden1.class);
+
+    Cast.widening(outputSchema).verifyCompatibility(inputSchema);
+  }
+
+ @Test
+ @Category(NeedsRunner.class)
+ public void testWeakedNullable() throws Exception {
+ Schema outputSchema =
pipeline.getSchemaRegistry().getSchema(Nullable2.class);
+
+ PCollection<Nullable2> pojos =
+ pipeline
+ .apply(Create.of(new Nullable1()))
+ .apply(Cast.narrowing(outputSchema))
+ .apply(Convert.to(Nullable2.class));
+
+ PAssert.that(pojos).containsInAnyOrder(new Nullable2());
+ pipeline.run();
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ @Category(NeedsRunner.class)
+ public void testWeakedNullableFail() throws Exception {
+ Schema inputSchema =
pipeline.getSchemaRegistry().getSchema(Nullable1.class);
+ Schema outputSchema =
pipeline.getSchemaRegistry().getSchema(Nullable2.class);
+
+ Cast.widening(outputSchema).verifyCompatibility(inputSchema);
+ }
+
+ @Test
+ @Category(NeedsRunner.class)
+ public void testIgnoreNullable() throws Exception {
+ // ignoring nullable is opposite of weakening
+ Schema outputSchema =
pipeline.getSchemaRegistry().getSchema(Nullable1.class);
+
+ PCollection<Nullable1> pojos =
+ pipeline
+ .apply(Create.of(new Nullable2()))
+ .apply(Cast.narrowing(outputSchema))
+ .apply(Convert.to(Nullable1.class));
+
+ PAssert.that(pojos).containsInAnyOrder(new Nullable1());
+ pipeline.run();
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ @Category(NeedsRunner.class)
+ public void testIgnoreNullableFail() throws Exception {
+ // ignoring nullable is opposite of weakening
+ Schema inputSchema =
pipeline.getSchemaRegistry().getSchema(Nullable2.class);
+ Schema outputSchema =
pipeline.getSchemaRegistry().getSchema(Nullable1.class);
+
+ Cast.widening(outputSchema).verifyCompatibility(inputSchema);
+ }
+
+ @Test
+ @Category(NeedsRunner.class)
+ public void testComplexCast() throws Exception {
+ Schema outputSchema = pipeline.getSchemaRegistry().getSchema(All2.class);
+
+ PCollection<All2> pojos =
+ pipeline
+ .apply(Create.of(new All1()))
+ .apply(Cast.narrowing(outputSchema))
+ .apply(Convert.to(All2.class));
+
+ PAssert.that(pojos).containsInAnyOrder(new All2());
+ pipeline.run();
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ @Category(NeedsRunner.class)
+ public void testComplexCastFail() throws Exception {
+ Schema inputSchema = pipeline.getSchemaRegistry().getSchema(All1.class);
+ Schema outputSchema = pipeline.getSchemaRegistry().getSchema(All2.class);
+
+ Cast.widening(outputSchema).verifyCompatibility(inputSchema);
+ }
+
+ @Test
+ public void testCastArray() {
+ Object output =
+ Cast.castValue(
+ Arrays.asList((short) 1, (short) 2, (short) 3),
+ Schema.FieldType.array(Schema.FieldType.INT16),
+ Schema.FieldType.array(Schema.FieldType.INT32));
+
+ assertEquals(Arrays.asList(1, 2, 3), output);
+ }
+
+ @Test
+ public void testCastMap() {
+ Object output =
+ Cast.castValue(
+ ImmutableMap.of((short) 1, 1, (short) 2, 2, (short) 3, 3),
+ Schema.FieldType.map(Schema.FieldType.INT16,
Schema.FieldType.INT32),
+ Schema.FieldType.map(Schema.FieldType.INT32,
Schema.FieldType.INT64));
+
+ assertEquals(ImmutableMap.of(1, 1L, 2, 2L, 3, 3L), output);
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testIgnoreNullFail() {
+ Schema inputSchema = Schema.of(Schema.Field.nullable("f0",
Schema.FieldType.INT32));
+ Schema outputSchema = Schema.of(Schema.Field.of("f0",
Schema.FieldType.INT32));
+
+ Cast.castRow(Row.withSchema(inputSchema).addValue(null).build(),
inputSchema, outputSchema);
+ }
+
+ /** POJO for {@link CastTest#testProjection()}. */
+ @DefaultSchema(JavaFieldSchema.class)
+ @VisibleForTesting
+ public static class Projection1 {
+
+ public Short field1 = 42;
+ public Integer field2 = 1337;
+ public String field3 = "field";
+
+ @Override
+ public boolean equals(final Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ final Projection1 pojo1 = (Projection1) o;
+ return Objects.equals(field1, pojo1.field1) && Objects.equals(field2,
pojo1.field2);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(field1, field2);
+ }
+
+ @Override
+ public String toString() {
+ return "Projection1{"
+ + "field1="
+ + field1
+ + ", field2="
+ + field2
+ + ", field3='"
+ + field3
+ + '\''
+ + '}';
+ }
+ }
+
+ /** POJO for {@link CastTest#testProjection()}. */
+ @DefaultSchema(JavaFieldSchema.class)
+ @VisibleForTesting
+ public static class Projection2 {
+ public Integer field2 = 1337;
+
+ @Override
+ public boolean equals(final Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ final Projection2 that = (Projection2) o;
+ return Objects.equals(field2, that.field2);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(field2);
+ }
+
+ @Override
+ public String toString() {
+ return "Projection2{" + "field2=" + field2 + '}';
+ }
+ }
+
+ /** POJO for {@link CastTest#testTypeWiden()}. */
+ @DefaultSchema(JavaFieldSchema.class)
+ public static class TypeWiden1 {
+
+ public Short field1 = 42;
+ public Integer field2 = 1337;
+
+ @Override
+ public boolean equals(final Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ final TypeWiden1 typeWiden1 = (TypeWiden1) o;
+ return Objects.equals(field1, typeWiden1.field1) &&
Objects.equals(field2, typeWiden1.field2);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(field1, field2);
+ }
+
+ @Override
+ public String toString() {
+ return "TypeWiden1{" + "field1=" + field1 + ", field2=" + field2 + '}';
+ }
+ }
+
+ /** POJO for {@link CastTest#testTypeWiden()}. */
+ @DefaultSchema(JavaFieldSchema.class)
+ @VisibleForTesting
+ public static class TypeWiden2 {
+
+ public Integer field1 = 42;
+ public Long field2 = 1337L;
+
+ @Override
+ public boolean equals(final Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ final TypeWiden2 typeWiden2 = (TypeWiden2) o;
+ return Objects.equals(field1, typeWiden2.field1) &&
Objects.equals(field2, typeWiden2.field2);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(field1, field2);
+ }
+
+ @Override
+ public String toString() {
+ return "TypeWiden2{" + "field1=" + field1 + ", field2=" + field2 + '}';
+ }
+ }
+
+ /** POJO for {@link CastTest#testWeakedNullable()}. */
+ @DefaultSchema(JavaFieldSchema.class)
+ @VisibleForTesting
+ public static class Nullable1 {
+ public Integer field1 = 42;
+ public @Nullable Long field2 = null;
+
+ @Override
+ public boolean equals(final Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ final Nullable1 nullable1 = (Nullable1) o;
+ return Objects.equals(field1, nullable1.field1) &&
Objects.equals(field2, nullable1.field2);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(field1, field2);
+ }
+
+ @Override
+ public String toString() {
+ return "Nullable1{" + "field1=" + field1 + ", field2=" + field2 + '}';
+ }
+ }
+
+ /** POJO for {@link CastTest#testWeakedNullable()}. */
+ @DefaultSchema(JavaFieldSchema.class)
+ @VisibleForTesting
+ public static class Nullable2 {
+ public @Nullable Integer field1 = 42;
+ public @Nullable Long field2 = null;
+
+ @Override
+ public boolean equals(final Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ final Nullable2 nullable2 = (Nullable2) o;
+ return Objects.equals(field1, nullable2.field1) &&
Objects.equals(field2, nullable2.field2);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(field1, field2);
+ }
+
+ @Override
+ public String toString() {
+ return "Nullable2{" + "field1=" + field1 + ", field2=" + field2 + '}';
+ }
+ }
+
+ /** POJO for {@link CastTest#testComplexCast()}. */
+ @DefaultSchema(JavaFieldSchema.class)
+ @VisibleForTesting
+ public static class All1 {
+ public Projection1 field1 = new Projection1();
+ public TypeWiden1 field2 = new TypeWiden1();
+ public TypeWiden2 field3 = new TypeWiden2();
+ public Nullable1 field4 = new Nullable1();
+ public @Nullable Nullable2 field5 = new Nullable2();
+
+ @Override
+ public boolean equals(final Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ final All1 all1 = (All1) o;
+ return Objects.equals(field1, all1.field1)
+ && Objects.equals(field2, all1.field2)
+ && Objects.equals(field3, all1.field3)
+ && Objects.equals(field4, all1.field4)
+ && Objects.equals(field5, all1.field5);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(field1, field2, field3, field4, field5);
+ }
+
+ @Override
+ public String toString() {
+ return "All1{"
+ + "field1="
+ + field1
+ + ", field2="
+ + field2
+ + ", field3="
+ + field3
+ + ", field4="
+ + field4
+ + ", field5="
+ + field5
+ + '}';
+ }
+ }
+
+ /** POJO for {@link CastTest#testComplexCast()}. */
+ @DefaultSchema(JavaFieldSchema.class)
+ @VisibleForTesting
+ public static class All2 {
+ public Projection2 field1 = new Projection2();
+ public TypeWiden2 field2 = new TypeWiden2();
+ public TypeWiden1 field3 = new TypeWiden1();
+ public Nullable2 field4 = new Nullable2();
+ public @Nullable Nullable1 field5 = new Nullable1();
+
+ @Override
+ public boolean equals(final Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ final All2 all2 = (All2) o;
+ return Objects.equals(field1, all2.field1)
+ && Objects.equals(field2, all2.field2)
+ && Objects.equals(field3, all2.field3)
+ && Objects.equals(field4, all2.field4)
+ && Objects.equals(field5, all2.field5);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(field1, field2, field3, field4, field5);
+ }
+
+ @Override
+ public String toString() {
+ return "All2{"
+ + "field1="
+ + field1
+ + ", field2="
+ + field2
+ + ", field3="
+ + field3
+ + ", field4="
+ + field4
+ + ", field5="
+ + field5
+ + '}';
+ }
+ }
+}
diff --git
a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CastValidatorTest.java
b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CastValidatorTest.java
new file mode 100644
index 00000000000..af09d40e619
--- /dev/null
+++
b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/CastValidatorTest.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.schemas.transforms;
+
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.not;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import java.math.BigDecimal;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.FieldType;
+import org.apache.beam.sdk.schemas.Schema.TypeName;
+import org.junit.Test;
+
+/**
+ * Tests for {@link Cast.Widening}, {@link Cast.Narrowing}.
+ *
+ * <p>The numeric tests run over every (input, output) pair of the types in {@link #NUMERICS}.
+ */
+public class CastValidatorTest {
+
+  /** One sample value (42) for each numeric type, in that type's Java representation. */
+  public static final Map<TypeName, Number> NUMERICS =
+      ImmutableMap.<TypeName, Number>builder()
+          .put(TypeName.BYTE, Byte.valueOf((byte) 42))
+          .put(TypeName.INT16, Short.valueOf((short) 42))
+          .put(TypeName.INT32, Integer.valueOf(42))
+          .put(TypeName.INT64, Long.valueOf(42))
+          .put(TypeName.FLOAT, Float.valueOf(42))
+          .put(TypeName.DOUBLE, Double.valueOf(42))
+          .put(TypeName.DECIMAL, BigDecimal.valueOf(42))
+          .build();
+
+  /**
+   * Numeric types ordered from narrowest to widest: a type is expected to widen without errors to
+   * itself and to every type that appears later in this list.
+   */
+  public static final List<TypeName> NUMERIC_ORDER =
+      ImmutableList.of(
+          TypeName.BYTE,
+          TypeName.INT16,
+          TypeName.INT32,
+          TypeName.INT64,
+          TypeName.FLOAT,
+          TypeName.DOUBLE,
+          TypeName.DECIMAL);
+
+  /** Runs {@link #testWideningOrder(TypeName, TypeName)} for every pair of numeric types. */
+  @Test
+  public void testWideningOrder() {
+    NUMERICS
+        .keySet()
+        .forEach(input -> NUMERICS.keySet().forEach(output -> testWideningOrder(input, output)));
+  }
+
+  /** Runs {@link #testCasting(TypeName, TypeName)} for every pair of numeric types. */
+  @Test
+  public void testCasting() {
+    NUMERICS
+        .keySet()
+        .forEach(input -> NUMERICS.keySet().forEach(output -> testCasting(input, output)));
+  }
+
+  /**
+   * Casts the sample value of {@code inputType} to {@code outputType} and expects the sample value
+   * registered for {@code outputType}.
+   */
+  public void testCasting(TypeName inputType, TypeName outputType) {
+    Object output =
+        Cast.castValue(NUMERICS.get(inputType), FieldType.of(inputType), FieldType.of(outputType));
+
+    assertEquals(NUMERICS.get(outputType), output);
+  }
+
+  /**
+   * Verifies that every numeric {@link TypeName} is listed in {@link #NUMERIC_ORDER}.
+   *
+   * <p>The candidates come from {@code TypeName.values()}, not from {@code NUMERIC_ORDER} itself;
+   * checking the list against itself would pass even if a numeric type were missing.
+   */
+  @Test
+  public void testCastingCompleteness() {
+    boolean all =
+        Arrays.stream(TypeName.values())
+            .filter(TypeName::isNumericType)
+            .allMatch(NUMERIC_ORDER::contains);
+
+    assertTrue(all);
+  }
+
+  /**
+   * Validates a widening cast between two single-field schemas: no errors are expected when
+   * {@code input} is not after {@code output} in {@link #NUMERIC_ORDER}, at least one otherwise.
+   */
+  public void testWideningOrder(TypeName input, TypeName output) {
+    Schema inputSchema = Schema.of(Schema.Field.of("f0", FieldType.of(input)));
+    Schema outputSchema = Schema.of(Schema.Field.of("f0", FieldType.of(output)));
+
+    List<Cast.CompatibilityError> errors = Cast.Widening.of().apply(inputSchema, outputSchema);
+
+    if (NUMERIC_ORDER.indexOf(input) <= NUMERIC_ORDER.indexOf(output)) {
+      assertThat(input + " is before " + output, errors, empty());
+    } else {
+      assertThat(input + " is after " + output, errors, not(empty()));
+    }
+  }
+
+  /** Widening must report an error when a nullable field is cast to a non-nullable one. */
+  @Test
+  public void testWideningNullableToNotNullable() {
+    Schema input = Schema.of(Schema.Field.nullable("f0", FieldType.INT32));
+    Schema output = Schema.of(Schema.Field.of("f0", FieldType.INT32));
+
+    List<Cast.CompatibilityError> errors = Cast.Widening.of().apply(input, output);
+    Cast.CompatibilityError expected =
+        Cast.CompatibilityError.create(
+            Arrays.asList("f0"), "Can't cast nullable field to non-nullable field");
+
+    assertThat(errors, containsInAnyOrder(expected));
+  }
+
+  /** Narrowing reports no error when a nullable field is cast to a non-nullable one. */
+  @Test
+  public void testNarrowingNullableToNotNullable() {
+    Schema input = Schema.of(Schema.Field.nullable("f0", FieldType.INT32));
+    Schema output = Schema.of(Schema.Field.of("f0", FieldType.INT32));
+
+    List<Cast.CompatibilityError> errors = Cast.Narrowing.of().apply(input, output);
+
+    assertThat(errors, empty());
+  }
+}
diff --git
a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/SchemaZipFoldTest.java
b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/SchemaZipFoldTest.java
new file mode 100644
index 00000000000..65986a0daff
--- /dev/null
+++
b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/utils/SchemaZipFoldTest.java
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.schemas.utils;
+
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThat;
+
+import com.google.common.base.Joiner;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.beam.sdk.schemas.Schema;
+import org.junit.Test;
+
+/** Example-based tests for {@link SchemaZipFold}, using small folds defined below. */
+public class SchemaZipFoldTest {
+
+  // LEFT and RIGHT share the names f0..f3 (and f3.f0, f3.f1); "left"/"inner_left" occur only in
+  // LEFT, and f2 and f3.f1 are INT32 here but STRING in RIGHT.
+  private static final Schema LEFT =
+      Schema.of(
+          Schema.Field.of("left", Schema.FieldType.INT32),
+          Schema.Field.of("f0", Schema.FieldType.INT32),
+          Schema.Field.of("f1", Schema.FieldType.INT32),
+          Schema.Field.of("f2", Schema.FieldType.INT32),
+          Schema.Field.of(
+              "f3",
+              Schema.FieldType.row(
+                  Schema.of(
+                      Schema.Field.of("inner_left", Schema.FieldType.INT32),
+                      Schema.Field.of("f0", Schema.FieldType.INT32),
+                      Schema.Field.of("f1", Schema.FieldType.INT32)))));
+
+  // Counterpart of LEFT: "right"/"inner_right" occur only here.
+  private static final Schema RIGHT =
+      Schema.of(
+          Schema.Field.of("right", Schema.FieldType.INT32),
+          Schema.Field.of("f0", Schema.FieldType.INT32),
+          Schema.Field.of("f1", Schema.FieldType.INT32),
+          Schema.Field.of("f2", Schema.FieldType.STRING),
+          Schema.Field.of(
+              "f3",
+              Schema.FieldType.row(
+                  Schema.of(
+                      Schema.Field.of("inner_right", Schema.FieldType.INT32),
+                      Schema.Field.of("f0", Schema.FieldType.INT32),
+                      Schema.Field.of("f1", Schema.FieldType.STRING)))));
+
+  /** Folding with {@link CountCommonLeafs} over LEFT and RIGHT yields 3. */
+  @Test
+  public void testCountCommonLeafs() {
+    final int commonLeafs = new CountCommonLeafs().apply(LEFT, RIGHT).intValue();
+    assertEquals(3, commonLeafs);
+  }
+
+  /** Folding with {@link CountCommonFields} over LEFT and RIGHT yields 6. */
+  @Test
+  public void testCountCommonFields() {
+    final int commonFields = new CountCommonFields().apply(LEFT, RIGHT).intValue();
+    assertEquals(6, commonFields);
+  }
+
+  /** Folding with {@link CountMissingFields} over LEFT and RIGHT yields 4. */
+  @Test
+  public void testCountMissingFields() {
+    final int missingFields = new CountMissingFields().apply(LEFT, RIGHT).intValue();
+    assertEquals(4, missingFields);
+  }
+
+  /** Folding with {@link ListCommonFields} yields the dotted paths of the shared fields. */
+  @Test
+  public void testListCommonFields() {
+    final List<String> commonPaths = new ListCommonFields().apply(LEFT, RIGHT);
+    assertThat(commonPaths, containsInAnyOrder("f0", "f1", "f2", "f3", "f3.f0", "f3.f1"));
+  }
+
+  /**
+   * Counts field-type pairs that have the same type name and are not ROW, ARRAY or MAP. Field
+   * pairs themselves contribute nothing.
+   */
+  static class CountCommonLeafs extends SchemaZipFold<Integer> {
+
+    @Override
+    public Integer accumulate(Integer left, Integer right) {
+      return Integer.sum(left, right);
+    }
+
+    @Override
+    public Integer accept(Context context, Schema.FieldType left, Schema.FieldType right) {
+      final Schema.TypeName leftType = left.getTypeName();
+      if (leftType != right.getTypeName()) {
+        return 0;
+      }
+
+      // Container types are not leaves.
+      final boolean container =
+          leftType == Schema.TypeName.ROW
+              || leftType == Schema.TypeName.ARRAY
+              || leftType == Schema.TypeName.MAP;
+      return container ? 0 : 1;
+    }
+
+    @Override
+    public Integer accept(
+        Context context, Optional<Schema.Field> left, Optional<Schema.Field> right) {
+      return 0;
+    }
+  }
+
+  /** Counts field pairs where both sides are present. Type pairs contribute nothing. */
+  static class CountCommonFields extends SchemaZipFold<Integer> {
+
+    @Override
+    public Integer accumulate(Integer left, Integer right) {
+      return Integer.sum(left, right);
+    }
+
+    @Override
+    public Integer accept(Context context, Schema.FieldType left, Schema.FieldType right) {
+      return 0;
+    }
+
+    @Override
+    public Integer accept(
+        Context context, Optional<Schema.Field> left, Optional<Schema.Field> right) {
+      final boolean onBothSides = left.isPresent() && right.isPresent();
+      return onBothSides ? 1 : 0;
+    }
+  }
+
+  /** Counts field pairs where at least one side is absent. Type pairs contribute nothing. */
+  static class CountMissingFields extends SchemaZipFold<Integer> {
+
+    @Override
+    public Integer accumulate(Integer left, Integer right) {
+      return Integer.sum(left, right);
+    }
+
+    @Override
+    public Integer accept(Context context, Schema.FieldType left, Schema.FieldType right) {
+      return 0;
+    }
+
+    @Override
+    public Integer accept(
+        Context context, Optional<Schema.Field> left, Optional<Schema.Field> right) {
+      final boolean onBothSides = left.isPresent() && right.isPresent();
+      return onBothSides ? 0 : 1;
+    }
+  }
+
+  /**
+   * Collects the dot-joined {@code context.path()} of every field pair where both sides are
+   * present. Type pairs contribute nothing.
+   */
+  static class ListCommonFields extends SchemaZipFold<List<String>> {
+
+    @Override
+    public List<String> accumulate(List<String> left, List<String> right) {
+      final Stream<String> joined = Stream.concat(left.stream(), right.stream());
+      return joined.collect(Collectors.toList());
+    }
+
+    @Override
+    public List<String> accept(Context context, Schema.FieldType left, Schema.FieldType right) {
+      return Collections.emptyList();
+    }
+
+    @Override
+    public List<String> accept(
+        Context context, Optional<Schema.Field> left, Optional<Schema.Field> right) {
+      if (!left.isPresent() || !right.isPresent()) {
+        return Collections.emptyList();
+      }
+      return Collections.singletonList(Joiner.on('.').join(context.path()));
+    }
+  }
+}
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
Issue Time Tracking
-------------------
Worklog Id: (was: 162736)
Time Spent: 5h (was: 4h 50m)
> Add Cast transform for Rows
> ---------------------------
>
> Key: BEAM-5918
> URL: https://issues.apache.org/jira/browse/BEAM-5918
> Project: Beam
> Issue Type: Improvement
> Components: sdk-java-core
> Reporter: Gleb Kanterov
> Assignee: Gleb Kanterov
> Priority: Major
> Time Spent: 5h
> Remaining Estimate: 0h
>
> There is a need for a generic transform that, given two Row schemas, will
> convert rows between them. It must be possible to opt out of certain
> kinds of conversions; for instance, converting ints to shorts can
> cause overflow. As another example, a schema could have a nullable field
> that never holds a NULL value in practice because NULLs were filtered out.
> What is needed:
> - widening values (e.g., int -> long)
> - narrowing (e.g., int -> short)
> - runtime check for overflow while narrowing
> - ignoring nullability (nullable=true -> nullable=false)
> - weakening nullability (nullable=false -> nullable=true)
> - projection (Schema(a: Int32, b: Int32) -> Schema(a: Int32))
--
This message was sent by Atlassian JIRA
(v7.6.3#76005)