This is an automated email from the ASF dual-hosted git repository.
scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new e76b51feb3a Fix RowCoderGenerator to use the encodingPositions when
encoding and decoding the bit set representing null fields. (#32389)
e76b51feb3a is described below
commit e76b51feb3a14abbc2f8c1f3989cbffad9f8f87f
Author: Sam Whittle <[email protected]>
AuthorDate: Wed Sep 11 18:49:27 2024 +0200
Fix RowCoderGenerator to use the encodingPositions when encoding and
decoding the bit set representing null fields. (#32389)
---
.../java/org/apache/beam/sdk/coders/RowCoder.java | 8 +-
.../apache/beam/sdk/coders/RowCoderGenerator.java | 131 +++++++++++++++++---
.../org/apache/beam/sdk/schemas/SchemaCoder.java | 7 +-
.../org/apache/beam/sdk/coders/RowCoderTest.java | 134 ++++++++++++++++++++-
4 files changed, 256 insertions(+), 24 deletions(-)
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java
index 9121b60666a..8fa46dbbd25 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java
@@ -25,6 +25,7 @@ import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.transforms.SerializableFunctions;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TypeDescriptors;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.checkerframework.checker.nullness.qual.Nullable;
/** A sub-class of SchemaCoder that can only encode {@link Row} instances. */
@@ -35,7 +36,12 @@ public class RowCoder extends SchemaCoder<Row> {
/** Override encoding positions for the given schema. */
public static void overrideEncodingPositions(UUID uuid, Map<String, Integer>
encodingPositions) {
- SchemaCoder.overrideEncodingPositions(uuid, encodingPositions);
+ RowCoderGenerator.overrideEncodingPositions(uuid, encodingPositions); // Forwards unchanged to RowCoderGenerator's static override entry point.
+ }
+
+ @VisibleForTesting
+ static void clearGeneratedRowCoders() { // Test hook: delegates to RowCoderGenerator.clearRowCoderCache().
+ RowCoderGenerator.clearRowCoderCache();
}
private RowCoder(Schema schema) {
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java
index e3bd218945b..7a1b16d7e91 100644
---
a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java
@@ -30,6 +30,7 @@ import java.util.BitSet;
import java.util.Map;
import java.util.UUID;
import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
import net.bytebuddy.ByteBuddy;
import net.bytebuddy.description.modifier.FieldManifestation;
import net.bytebuddy.description.modifier.Ownership;
@@ -53,10 +54,14 @@ import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.Field;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.SchemaCoder;
+import org.apache.beam.sdk.util.StringUtils;
import org.apache.beam.sdk.util.common.ReflectHelpers;
import org.apache.beam.sdk.values.Row;
+import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
/**
* A utility for automatically generating a {@link Coder} for {@link Row}
objects corresponding to a
@@ -109,21 +114,99 @@ public abstract class RowCoderGenerator {
private static final String CODERS_FIELD_NAME = "FIELD_CODERS";
private static final String POSITIONS_FIELD_NAME =
"FIELD_ENCODING_POSITIONS";
+ static class WithStackTrace<T> { // Holder pairing a value with a stack-trace string supplied at construction.
+ private final T value; // The wrapped value; set once in the constructor.
+ private final String stackTrace; // Kept as text; read back by this file's log messages to report where the value was recorded.
+
+ public WithStackTrace(T value, String stackTrace) {
+ this.value = value;
+ this.stackTrace = stackTrace;
+ }
+
+ public T getValue() { // Returns the wrapped value as passed to the constructor.
+ return value;
+ }
+
+ public String getStackTrace() { // Returns the stack-trace text as passed to the constructor.
+ return stackTrace;
+ }
+ }
+
// Cache for Coder class that are already generated.
- private static final Map<UUID, Coder<Row>> GENERATED_CODERS =
Maps.newConcurrentMap();
- private static final Map<UUID, Map<String, Integer>>
ENCODING_POSITION_OVERRIDES =
- Maps.newConcurrentMap();
+ @GuardedBy("cacheLock")
+ private static final Map<UUID, WithStackTrace<Coder<Row>>> GENERATED_CODERS
= Maps.newHashMap();
+
+ @GuardedBy("cacheLock")
+ private static final Map<UUID, WithStackTrace<Map<String, Integer>>>
ENCODING_POSITION_OVERRIDES =
+ Maps.newHashMap();
+
+ private static final Object cacheLock = new Object();
+
+ private static final Logger LOG =
LoggerFactory.getLogger(RowCoderGenerator.class);
+
+ private static String getStackTrace() { // Renders the calling thread's current stack trace as a String via StringUtils.arrayToNewlines.
+ return StringUtils.arrayToNewlines(Thread.currentThread().getStackTrace(),
10); // NOTE(review): 10 is presumably a cap on how many frames are rendered — confirm against StringUtils.arrayToNewlines.
+ }
public static void overrideEncodingPositions(UUID uuid, Map<String, Integer>
encodingPositions) {
- ENCODING_POSITION_OVERRIDES.put(uuid, encodingPositions);
+ // Records the override for this schema UUID (replacing any earlier one) and logs when the
+ // override arrives after a coder was already generated for the UUID, or differs from a
+ // previously registered override. The caller's stack trace is captured before taking the
+ // lock so it can be stored with the override and reported in later diagnostics.
+ final String stackTrace = getStackTrace();
+ synchronized (cacheLock) {
+ @Nullable
+ WithStackTrace<Map<String, Integer>> previousEncodingPositions =
+ ENCODING_POSITION_OVERRIDES.put(
+ uuid, new WithStackTrace<>(encodingPositions, stackTrace));
+ @Nullable WithStackTrace<Coder<Row>> existingCoder =
GENERATED_CODERS.get(uuid);
+ if (previousEncodingPositions == null) {
+ if (existingCoder != null) {
+ // First override for this UUID, but a coder already exists in the cache.
+ LOG.error(
+ "Received encoding positions for uuid {} too late after creating
RowCoder. Created: {}\n Override: {}",
+ uuid,
+ existingCoder.getStackTrace(),
+ stackTrace);
+ } else {
+ LOG.info("Received encoding positions {} for uuid {}.",
encodingPositions, uuid);
+ }
+ } else if
(!previousEncodingPositions.getValue().equals(encodingPositions)) {
+ // Arguments below follow placeholder order: "Was {positions} at {trace}" is the previous
+ // override, "Now {positions} at {trace}" is the one being registered.
+ if (existingCoder == null) {
+ LOG.error(
+ "Received differing encoding positions for uuid {} before coder
creation. Was {} at {}\n Now {} at {}",
+ uuid,
+ previousEncodingPositions.getValue(),
+ previousEncodingPositions.getStackTrace(),
+ encodingPositions,
+ stackTrace);
+ } else {
+ LOG.error(
+ "Received differing encoding positions for uuid {} after coder
creation at {}\n. "
+ + "Was {} at {}\n Now {} at {}\n",
+ uuid,
+ existingCoder.getStackTrace(),
+ previousEncodingPositions.getValue(),
+ previousEncodingPositions.getStackTrace(),
+ encodingPositions,
+ stackTrace);
+ }
+ }
+ }
+ }
+
+ @VisibleForTesting
+ static void clearRowCoderCache() { // Empties GENERATED_CODERS only; ENCODING_POSITION_OVERRIDES is not touched here.
+ synchronized (cacheLock) {
+ GENERATED_CODERS.clear();
+ }
}
@SuppressWarnings("unchecked")
public static Coder<Row> generate(Schema schema) {
- // Using ConcurrentHashMap::computeIfAbsent here would deadlock in case of
nested
- // coders. Using HashMap::computeIfAbsent generates
ConcurrentModificationExceptions in Java 11.
- Coder<Row> rowCoder = GENERATED_CODERS.get(schema.getUUID());
- if (rowCoder == null) {
+ String stackTrace = getStackTrace();
+ UUID uuid = Preconditions.checkNotNull(schema.getUUID());
+ // Avoid using computeIfAbsent which may cause issues with nested schemas.
+ synchronized (cacheLock) {
+ @Nullable WithStackTrace<Coder<Row>> existingRowCoder =
GENERATED_CODERS.get(uuid);
+ if (existingRowCoder != null) {
+ return existingRowCoder.getValue();
+ }
TypeDescription.Generic coderType =
TypeDescription.Generic.Builder.parameterizedType(Coder.class,
Row.class).build();
DynamicType.Builder<Coder> builder =
@@ -131,8 +214,13 @@ public abstract class RowCoderGenerator {
builder = implementMethods(schema, builder);
int[] encodingPosToRowIndex = new int[schema.getFieldCount()];
+ @Nullable
+ WithStackTrace<Map<String, Integer>> existingEncodingPositions =
+ ENCODING_POSITION_OVERRIDES.get(uuid);
Map<String, Integer> encodingPositions =
- ENCODING_POSITION_OVERRIDES.getOrDefault(schema.getUUID(),
schema.getEncodingPositions());
+ existingEncodingPositions == null
+ ? schema.getEncodingPositions()
+ : existingEncodingPositions.getValue();
for (int recordIndex = 0; recordIndex < schema.getFieldCount();
++recordIndex) {
String name = schema.getField(recordIndex).getName();
int encodingPosition = encodingPositions.get(name);
@@ -163,6 +251,7 @@ public abstract class RowCoderGenerator {
.withParameters(Coder[].class, int[].class)
.intercept(new GeneratedCoderConstructor());
+ Coder<Row> rowCoder;
try {
rowCoder =
builder
@@ -179,9 +268,14 @@ public abstract class RowCoderGenerator {
| InvocationTargetException e) {
throw new RuntimeException("Unable to generate coder for schema " +
schema, e);
}
- GENERATED_CODERS.put(schema.getUUID(), rowCoder);
+ GENERATED_CODERS.put(uuid, new WithStackTrace<>(rowCoder, stackTrace));
+ LOG.debug(
+ "Created row coder for uuid {} with encoding positions {} at {}",
+ uuid,
+ encodingPositions,
+ stackTrace);
+ return rowCoder;
}
- return rowCoder;
}
private static class GeneratedCoderConstructor implements Implementation {
@@ -326,7 +420,7 @@ public abstract class RowCoderGenerator {
}
// Encode a bitmap for the null fields to save having to encode a
bunch of nulls.
- NULL_LIST_CODER.encode(scanNullFields(fieldValues), outputStream);
+ NULL_LIST_CODER.encode(scanNullFields(fieldValues,
encodingPosToIndex), outputStream);
for (int encodingPos = 0; encodingPos < fieldValues.length;
++encodingPos) {
@Nullable Object fieldValue =
fieldValues[encodingPosToIndex[encodingPos]];
if (fieldValue != null) {
@@ -348,14 +442,15 @@ public abstract class RowCoderGenerator {
// Figure out which fields of the Row are null, and returns a BitSet. This
allows us to save
// on encoding each null field separately.
- private static BitSet scanNullFields(Object[] fieldValues) {
+ private static BitSet scanNullFields(Object[] fieldValues, int[]
encodingPosToIndex) { // encodingPosToIndex[p] is the fieldValues index of the field written at encoding position p.
+ Preconditions.checkState(fieldValues.length ==
encodingPosToIndex.length);
BitSet nullFields = new BitSet(fieldValues.length);
- for (int idx = 0; idx < fieldValues.length; ++idx) {
- if (fieldValues[idx] == null) {
- nullFields.set(idx);
+ for (int encodingPos = 0; encodingPos < encodingPosToIndex.length;
++encodingPos) {
+ int fieldIndex = encodingPosToIndex[encodingPos];
+ if (fieldValues[fieldIndex] == null) {
+ nullFields.set(encodingPos); // Bit index is the encoding position, not the row index, matching the decoder's nullFields.get(encodingPos).
}
}
-
return nullFields;
}
}
@@ -425,7 +520,7 @@ public abstract class RowCoderGenerator {
// in which case we drop the extra fields.
if (encodingPos < coders.length) {
int rowIndex = encodingPosToIndex[encodingPos];
- if (nullFields.get(rowIndex)) {
+ if (nullFields.get(encodingPos)) {
fieldValues[rowIndex] = null;
} else {
Object fieldValue = coders[encodingPos].decode(inputStream);
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java
index 323f4e98dc5..b93b64f7dbe 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java
@@ -164,7 +164,10 @@ public class SchemaCoder<T> extends CustomCoder<T> {
}
// Sets the schema id, and then recursively ensures that all schemas have
ids set.
- private static void setSchemaIds(Schema schema) {
+ private static void setSchemaIds(@Nullable Schema schema) {
+ if (schema == null) {
+ return;
+ }
if (schema.getUUID() == null) {
schema.setUUID(UUID.randomUUID());
}
@@ -187,7 +190,7 @@ public class SchemaCoder<T> extends CustomCoder<T> {
return;
case ARRAY:
- case ITERABLE:;
+ case ITERABLE:
setSchemaIds(fieldType.getCollectionElementType());
return;
diff --git
a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/RowCoderTest.java
b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/RowCoderTest.java
index f62a2611a1c..885ff8f1491 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/RowCoderTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/RowCoderTest.java
@@ -22,10 +22,12 @@ import static org.junit.Assert.assertEquals;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.math.BigDecimal;
+import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
+import java.util.UUID;
import org.apache.beam.sdk.coders.Coder.NonDeterministicException;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.FieldType;
@@ -37,6 +39,7 @@ import org.apache.beam.sdk.values.Row;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
+import org.checkerframework.checker.nullness.qual.NonNull;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
import org.junit.Assume;
@@ -62,7 +65,7 @@ public class RowCoderTest {
.build();
DateTime dateTime =
- new DateTime().withDate(1979, 03, 14).withTime(1, 2, 3,
4).withZone(DateTimeZone.UTC);
+ new DateTime().withDate(1979, 3, 14).withTime(1, 2, 3,
4).withZone(DateTimeZone.UTC);
Row row =
Row.withSchema(schema)
.addValues(
@@ -219,12 +222,14 @@ public class RowCoderTest {
}
@Override
- public Value toBaseType(String input) {
+ @NonNull
+ public Value toBaseType(@NonNull String input) {
return enumeration.valueOf(input);
}
@Override
- public String toInputType(Value base) {
+ @NonNull
+ public String toInputType(@NonNull Value base) {
return enumeration.toString(base);
}
}
@@ -401,6 +406,129 @@ public class RowCoderTest {
assertEquals(expected, decoded);
}
+ @Test
+ public void testEncodingPositionReorderFieldsWithNulls() throws Exception { // Round-trips a row with a null field between two schemas that declare the same nullable fields in opposite order.
+ Schema schema1 =
+ Schema.builder()
+ .addNullableField("f_int32", FieldType.INT32)
+ .addNullableField("f_string", FieldType.STRING)
+ .build();
+ Schema schema2 =
+ Schema.builder()
+ .addNullableField("f_string", FieldType.STRING)
+ .addNullableField("f_int32", FieldType.INT32)
+ .build();
+ schema2.setEncodingPositions(ImmutableMap.of("f_int32", 0, "f_string", 1)); // Encoding order for schema2 follows schema1's declaration order.
+ Row schema1row =
+ Row.withSchema(schema1)
+ .withFieldValue("f_int32", null)
+ .withFieldValue("f_string", "hello world!")
+ .build();
+
+ Row schema2row =
+ Row.withSchema(schema2)
+ .withFieldValue("f_int32", null)
+ .withFieldValue("f_string", "hello world!")
+ .build();
+
+ ByteArrayOutputStream os = new ByteArrayOutputStream();
+ RowCoder.of(schema1).encode(schema1row, os);
+ Row schema1to2decoded = RowCoder.of(schema2).decode(new
ByteArrayInputStream(os.toByteArray()));
+ assertEquals(schema2row, schema1to2decoded);
+
+ os.reset(); // Reuse the buffer for the reverse direction: encode with schema2, decode with schema1.
+ RowCoder.of(schema2).encode(schema2row, os);
+ Row schema2to1decoded = RowCoder.of(schema1).decode(new
ByteArrayInputStream(os.toByteArray()));
+ assertEquals(schema1row, schema2to1decoded);
+ }
+
+ @Test
+ public void testEncodingPositionReorderViaStaticOverride() throws Exception { // Encodes with schema1, clears generated coders, registers an override for the shared UUID, then decodes with a field-reordered schema2.
+ Schema schema1 =
+ Schema.builder()
+ .addNullableField("failsafeTableRowPayload", FieldType.STRING)
+ .addByteArrayField("payload")
+ .addNullableField("timestamp", FieldType.INT32)
+ .addNullableField("unknownFieldsPayload", FieldType.STRING)
+ .build();
+ UUID uuid = UUID.randomUUID();
+ schema1.setUUID(uuid);
+
+ Row row =
+ Row.withSchema(schema1)
+ .addValues("", "hello world!".getBytes(StandardCharsets.UTF_8), 1,
"")
+ .build();
+ ByteArrayOutputStream os = new ByteArrayOutputStream();
+ RowCoder.of(schema1).encode(row, os);
+ // Pretend that we are restarting and want to recover from persisted state
with a compatible
+ // schema using the
+ // overridden encoding positions.
+ RowCoder.clearGeneratedRowCoders();
+ RowCoder.overrideEncodingPositions(
+ uuid,
+ ImmutableMap.of(
+ "failsafeTableRowPayload", 0, "payload", 1, "timestamp", 2,
"unknownFieldsPayload", 3)); // Same order as schema1's field declarations above.
+
+ Schema schema2 =
+ Schema.builder()
+ .addByteArrayField("payload")
+ .addNullableField("timestamp", FieldType.INT32)
+ .addNullableField("unknownFieldsPayload", FieldType.STRING)
+ .addNullableField("failsafeTableRowPayload", FieldType.STRING)
+ .build();
+ schema2.setUUID(uuid); // Same UUID as schema1, so the override registered above applies to this schema's coder.
+
+ Row expected =
+ Row.withSchema(schema2)
+ .addValues("hello world!".getBytes(StandardCharsets.UTF_8), 1, "",
"")
+ .build();
+ Row decoded = RowCoder.of(schema2).decode(new
ByteArrayInputStream(os.toByteArray()));
+ assertEquals(expected, decoded);
+ }
+
+ @Test
+ public void testEncodingPositionReorderViaStaticOverrideWithNulls() throws
Exception { // Same override-and-decode flow as the non-null variant, but the two nullable BYTES fields hold null.
+ Schema schema1 =
+ Schema.builder()
+ .addNullableField("failsafeTableRowPayload", FieldType.BYTES)
+ .addByteArrayField("payload")
+ .addNullableField("timestamp", FieldType.INT32)
+ .addNullableField("unknownFieldsPayload", FieldType.BYTES)
+ .build();
+ UUID uuid = UUID.randomUUID();
+ schema1.setUUID(uuid);
+
+ Row row =
+ Row.withSchema(schema1)
+ .addValues(null, "hello world!".getBytes(StandardCharsets.UTF_8),
1, null)
+ .build();
+ ByteArrayOutputStream os = new ByteArrayOutputStream();
+ RowCoder.of(schema1).encode(row, os);
+ // Pretend that we are restarting and want to recover from persisted state
with a compatible
+ // schema using the overridden encoding positions.
+ RowCoder.clearGeneratedRowCoders();
+ RowCoder.overrideEncodingPositions(
+ uuid,
+ ImmutableMap.of(
+ "failsafeTableRowPayload", 0, "payload", 1, "timestamp", 2,
"unknownFieldsPayload", 3)); // Same order as schema1's field declarations above.
+
+ Schema schema2 =
+ Schema.builder()
+ .addByteArrayField("payload")
+ .addNullableField("timestamp", FieldType.INT32)
+ .addNullableField("unknownFieldsPayload", FieldType.BYTES)
+ .addNullableField("failsafeTableRowPayload", FieldType.BYTES)
+ .build();
+ schema2.setUUID(uuid); // Same UUID as schema1, so the override registered above applies to this schema's coder.
+
+ Row expected =
+ Row.withSchema(schema2)
+ .addValues("hello world!".getBytes(StandardCharsets.UTF_8), 1,
null, null)
+ .build();
+ Row decoded = RowCoder.of(schema2).decode(new
ByteArrayInputStream(os.toByteArray()));
+ assertEquals(expected, decoded);
+ }
+
@Test
public void testEncodingPositionAddNewFields() throws Exception {
Schema schema1 =