This is an automated email from the ASF dual-hosted git repository.

scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new e76b51feb3a Fix RowCoderGenerator to use the encodingPositions when encoding and decoding the bit set representing null fields. (#32389)
e76b51feb3a is described below

commit e76b51feb3a14abbc2f8c1f3989cbffad9f8f87f
Author: Sam Whittle <[email protected]>
AuthorDate: Wed Sep 11 18:49:27 2024 +0200

    Fix RowCoderGenerator to use the encodingPositions when encoding and decoding the bit set representing null fields. (#32389)
---
 .../java/org/apache/beam/sdk/coders/RowCoder.java  |   8 +-
 .../apache/beam/sdk/coders/RowCoderGenerator.java  | 131 +++++++++++++++++---
 .../org/apache/beam/sdk/schemas/SchemaCoder.java   |   7 +-
 .../org/apache/beam/sdk/coders/RowCoderTest.java   | 134 ++++++++++++++++++++-
 4 files changed, 256 insertions(+), 24 deletions(-)

diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java
index 9121b60666a..8fa46dbbd25 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoder.java
@@ -25,6 +25,7 @@ import org.apache.beam.sdk.schemas.SchemaCoder;
 import org.apache.beam.sdk.transforms.SerializableFunctions;
 import org.apache.beam.sdk.values.Row;
 import org.apache.beam.sdk.values.TypeDescriptors;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import org.checkerframework.checker.nullness.qual.Nullable;
 
 /** A sub-class of SchemaCoder that can only encode {@link Row} instances. */
@@ -35,7 +36,12 @@ public class RowCoder extends SchemaCoder<Row> {
 
   /** Override encoding positions for the given schema. */
   public static void overrideEncodingPositions(UUID uuid, Map<String, Integer> encodingPositions) {
-    SchemaCoder.overrideEncodingPositions(uuid, encodingPositions);
+    RowCoderGenerator.overrideEncodingPositions(uuid, encodingPositions);
+  }
+
+  @VisibleForTesting
+  static void clearGeneratedRowCoders() {
+    RowCoderGenerator.clearRowCoderCache();
   }
 
   private RowCoder(Schema schema) {
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java
index e3bd218945b..7a1b16d7e91 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/RowCoderGenerator.java
@@ -30,6 +30,7 @@ import java.util.BitSet;
 import java.util.Map;
 import java.util.UUID;
 import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
 import net.bytebuddy.ByteBuddy;
 import net.bytebuddy.description.modifier.FieldManifestation;
 import net.bytebuddy.description.modifier.Ownership;
@@ -53,10 +54,14 @@ import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.schemas.Schema.Field;
 import org.apache.beam.sdk.schemas.Schema.FieldType;
 import org.apache.beam.sdk.schemas.SchemaCoder;
+import org.apache.beam.sdk.util.StringUtils;
 import org.apache.beam.sdk.util.common.ReflectHelpers;
 import org.apache.beam.sdk.values.Row;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /**
  * A utility for automatically generating a {@link Coder} for {@link Row} objects corresponding to a
@@ -109,21 +114,99 @@ public abstract class RowCoderGenerator {
   private static final String CODERS_FIELD_NAME = "FIELD_CODERS";
   private static final String POSITIONS_FIELD_NAME = "FIELD_ENCODING_POSITIONS";
 
+  static class WithStackTrace<T> {
+    private final T value;
+    private final String stackTrace;
+
+    public WithStackTrace(T value, String stackTrace) {
+      this.value = value;
+      this.stackTrace = stackTrace;
+    }
+
+    public T getValue() {
+      return value;
+    }
+
+    public String getStackTrace() {
+      return stackTrace;
+    }
+  }
+
   // Cache for Coder class that are already generated.
-  private static final Map<UUID, Coder<Row>> GENERATED_CODERS = Maps.newConcurrentMap();
-  private static final Map<UUID, Map<String, Integer>> ENCODING_POSITION_OVERRIDES =
-      Maps.newConcurrentMap();
+  @GuardedBy("cacheLock")
+  private static final Map<UUID, WithStackTrace<Coder<Row>>> GENERATED_CODERS = Maps.newHashMap();
+
+  @GuardedBy("cacheLock")
+  private static final Map<UUID, WithStackTrace<Map<String, Integer>>> ENCODING_POSITION_OVERRIDES =
+      Maps.newHashMap();
+
+  private static final Object cacheLock = new Object();
+
+  private static final Logger LOG = LoggerFactory.getLogger(RowCoderGenerator.class);
+
+  private static String getStackTrace() {
+    return StringUtils.arrayToNewlines(Thread.currentThread().getStackTrace(), 10);
+  }
 
   public static void overrideEncodingPositions(UUID uuid, Map<String, Integer> encodingPositions) {
-    ENCODING_POSITION_OVERRIDES.put(uuid, encodingPositions);
+    final String stackTrace = getStackTrace();
+    synchronized (cacheLock) {
+      @Nullable
+      WithStackTrace<Map<String, Integer>> previousEncodingPositions =
+          ENCODING_POSITION_OVERRIDES.put(
+              uuid, new WithStackTrace<>(encodingPositions, stackTrace));
+      @Nullable WithStackTrace<Coder<Row>> existingCoder = GENERATED_CODERS.get(uuid);
+      if (previousEncodingPositions == null) {
+        if (existingCoder != null) {
+          LOG.error(
+              "Received encoding positions for uuid {} too late after creating RowCoder. Created: {}\n Override: {}",
+              uuid,
+              existingCoder.getStackTrace(),
+              stackTrace);
+        } else {
+          LOG.info("Received encoding positions {} for uuid {}.", encodingPositions, uuid);
+        }
+      } else if (!previousEncodingPositions.getValue().equals(encodingPositions)) {
+        if (existingCoder == null) {
+          LOG.error(
+              "Received differing encoding positions for uuid {} before coder creation. Was {} at {}\n Now {} at {}",
+              uuid,
+              previousEncodingPositions.getValue(),
+              encodingPositions,
+              previousEncodingPositions.getStackTrace(),
+              stackTrace);
+        } else {
+          LOG.error(
+              "Received differing encoding positions for uuid {} after coder creation at {}\n. "
+                  + "Was {} at {}\n Now {} at {}\n",
+              uuid,
+              existingCoder.getStackTrace(),
+              previousEncodingPositions.getValue(),
+              encodingPositions,
+              previousEncodingPositions.getStackTrace(),
+              stackTrace);
+        }
+      }
+    }
+  }
+
+  @VisibleForTesting
+  static void clearRowCoderCache() {
+    synchronized (cacheLock) {
+      GENERATED_CODERS.clear();
+    }
   }
 
   @SuppressWarnings("unchecked")
   public static Coder<Row> generate(Schema schema) {
-    // Using ConcurrentHashMap::computeIfAbsent here would deadlock in case of nested
-    // coders. Using HashMap::computeIfAbsent generates ConcurrentModificationExceptions in Java 11.
-    Coder<Row> rowCoder = GENERATED_CODERS.get(schema.getUUID());
-    if (rowCoder == null) {
+    String stackTrace = getStackTrace();
+    UUID uuid = Preconditions.checkNotNull(schema.getUUID());
+    // Avoid using computeIfAbsent which may cause issues with nested schemas.
+    synchronized (cacheLock) {
+      @Nullable WithStackTrace<Coder<Row>> existingRowCoder = GENERATED_CODERS.get(uuid);
+      if (existingRowCoder != null) {
+        return existingRowCoder.getValue();
+      }
       TypeDescription.Generic coderType =
           TypeDescription.Generic.Builder.parameterizedType(Coder.class, Row.class).build();
       DynamicType.Builder<Coder> builder =
@@ -131,8 +214,13 @@ public abstract class RowCoderGenerator {
       builder = implementMethods(schema, builder);
 
       int[] encodingPosToRowIndex = new int[schema.getFieldCount()];
+      @Nullable
+      WithStackTrace<Map<String, Integer>> existingEncodingPositions =
+          ENCODING_POSITION_OVERRIDES.get(uuid);
       Map<String, Integer> encodingPositions =
-          ENCODING_POSITION_OVERRIDES.getOrDefault(schema.getUUID(), schema.getEncodingPositions());
+          existingEncodingPositions == null
+              ? schema.getEncodingPositions()
+              : existingEncodingPositions.getValue();
       for (int recordIndex = 0; recordIndex < schema.getFieldCount(); ++recordIndex) {
         String name = schema.getField(recordIndex).getName();
         int encodingPosition = encodingPositions.get(name);
@@ -163,6 +251,7 @@ public abstract class RowCoderGenerator {
               .withParameters(Coder[].class, int[].class)
               .intercept(new GeneratedCoderConstructor());
 
+      Coder<Row> rowCoder;
       try {
         rowCoder =
             builder
@@ -179,9 +268,14 @@ public abstract class RowCoderGenerator {
           | InvocationTargetException e) {
         throw new RuntimeException("Unable to generate coder for schema " + schema, e);
       }
-      GENERATED_CODERS.put(schema.getUUID(), rowCoder);
+      GENERATED_CODERS.put(uuid, new WithStackTrace<>(rowCoder, stackTrace));
+      LOG.debug(
+          "Created row coder for uuid {} with encoding positions {} at {}",
+          uuid,
+          encodingPositions,
+          stackTrace);
+      return rowCoder;
     }
-    return rowCoder;
   }
 
   private static class GeneratedCoderConstructor implements Implementation {
@@ -326,7 +420,7 @@ public abstract class RowCoderGenerator {
         }
 
         // Encode a bitmap for the null fields to save having to encode a bunch of nulls.
-        NULL_LIST_CODER.encode(scanNullFields(fieldValues), outputStream);
+        NULL_LIST_CODER.encode(scanNullFields(fieldValues, encodingPosToIndex), outputStream);
         for (int encodingPos = 0; encodingPos < fieldValues.length; ++encodingPos) {
           @Nullable Object fieldValue = fieldValues[encodingPosToIndex[encodingPos]];
           if (fieldValue != null) {
@@ -348,14 +442,15 @@ public abstract class RowCoderGenerator {
 
     // Figure out which fields of the Row are null, and returns a BitSet. This allows us to save
     // on encoding each null field separately.
-    private static BitSet scanNullFields(Object[] fieldValues) {
+    private static BitSet scanNullFields(Object[] fieldValues, int[] encodingPosToIndex) {
+      Preconditions.checkState(fieldValues.length == encodingPosToIndex.length);
       BitSet nullFields = new BitSet(fieldValues.length);
-      for (int idx = 0; idx < fieldValues.length; ++idx) {
-        if (fieldValues[idx] == null) {
-          nullFields.set(idx);
+      for (int encodingPos = 0; encodingPos < encodingPosToIndex.length; ++encodingPos) {
+        int fieldIndex = encodingPosToIndex[encodingPos];
+        if (fieldValues[fieldIndex] == null) {
+          nullFields.set(encodingPos);
         }
       }
-
       return nullFields;
     }
   }
@@ -425,7 +520,7 @@ public abstract class RowCoderGenerator {
         // in which case we drop the extra fields.
         if (encodingPos < coders.length) {
           int rowIndex = encodingPosToIndex[encodingPos];
-          if (nullFields.get(rowIndex)) {
+          if (nullFields.get(encodingPos)) {
             fieldValues[rowIndex] = null;
           } else {
             Object fieldValue = coders[encodingPos].decode(inputStream);
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java
index 323f4e98dc5..b93b64f7dbe 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/SchemaCoder.java
@@ -164,7 +164,10 @@ public class SchemaCoder<T> extends CustomCoder<T> {
   }
 
   // Sets the schema id, and then recursively ensures that all schemas have ids set.
-  private static void setSchemaIds(Schema schema) {
+  private static void setSchemaIds(@Nullable Schema schema) {
+    if (schema == null) {
+      return;
+    }
     if (schema.getUUID() == null) {
       schema.setUUID(UUID.randomUUID());
     }
@@ -187,7 +190,7 @@ public class SchemaCoder<T> extends CustomCoder<T> {
         return;
 
       case ARRAY:
-      case ITERABLE:;
+      case ITERABLE:
         setSchemaIds(fieldType.getCollectionElementType());
         return;
 
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/RowCoderTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/RowCoderTest.java
index f62a2611a1c..885ff8f1491 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/RowCoderTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/coders/RowCoderTest.java
@@ -22,10 +22,12 @@ import static org.junit.Assert.assertEquals;
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
 import java.math.BigDecimal;
+import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.UUID;
 import org.apache.beam.sdk.coders.Coder.NonDeterministicException;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.schemas.Schema.FieldType;
@@ -37,6 +39,7 @@ import org.apache.beam.sdk.values.Row;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
+import org.checkerframework.checker.nullness.qual.NonNull;
 import org.joda.time.DateTime;
 import org.joda.time.DateTimeZone;
 import org.junit.Assume;
@@ -62,7 +65,7 @@ public class RowCoderTest {
             .build();
 
     DateTime dateTime =
-        new DateTime().withDate(1979, 03, 14).withTime(1, 2, 3, 4).withZone(DateTimeZone.UTC);
+        new DateTime().withDate(1979, 3, 14).withTime(1, 2, 3, 4).withZone(DateTimeZone.UTC);
     Row row =
         Row.withSchema(schema)
             .addValues(
@@ -219,12 +222,14 @@ public class RowCoderTest {
     }
 
     @Override
-    public Value toBaseType(String input) {
+    @NonNull
+    public Value toBaseType(@NonNull String input) {
       return enumeration.valueOf(input);
     }
 
     @Override
-    public String toInputType(Value base) {
+    @NonNull
+    public String toInputType(@NonNull Value base) {
       return enumeration.toString(base);
     }
   }
@@ -401,6 +406,129 @@ public class RowCoderTest {
     assertEquals(expected, decoded);
   }
 
+  @Test
+  public void testEncodingPositionReorderFieldsWithNulls() throws Exception {
+    Schema schema1 =
+        Schema.builder()
+            .addNullableField("f_int32", FieldType.INT32)
+            .addNullableField("f_string", FieldType.STRING)
+            .build();
+    Schema schema2 =
+        Schema.builder()
+            .addNullableField("f_string", FieldType.STRING)
+            .addNullableField("f_int32", FieldType.INT32)
+            .build();
+    schema2.setEncodingPositions(ImmutableMap.of("f_int32", 0, "f_string", 1));
+    Row schema1row =
+        Row.withSchema(schema1)
+            .withFieldValue("f_int32", null)
+            .withFieldValue("f_string", "hello world!")
+            .build();
+
+    Row schema2row =
+        Row.withSchema(schema2)
+            .withFieldValue("f_int32", null)
+            .withFieldValue("f_string", "hello world!")
+            .build();
+
+    ByteArrayOutputStream os = new ByteArrayOutputStream();
+    RowCoder.of(schema1).encode(schema1row, os);
+    Row schema1to2decoded = RowCoder.of(schema2).decode(new ByteArrayInputStream(os.toByteArray()));
+    assertEquals(schema2row, schema1to2decoded);
+
+    os.reset();
+    RowCoder.of(schema2).encode(schema2row, os);
+    Row schema2to1decoded = RowCoder.of(schema1).decode(new ByteArrayInputStream(os.toByteArray()));
+    assertEquals(schema1row, schema2to1decoded);
+  }
+
+  @Test
+  public void testEncodingPositionReorderViaStaticOverride() throws Exception {
+    Schema schema1 =
+        Schema.builder()
+            .addNullableField("failsafeTableRowPayload", FieldType.STRING)
+            .addByteArrayField("payload")
+            .addNullableField("timestamp", FieldType.INT32)
+            .addNullableField("unknownFieldsPayload", FieldType.STRING)
+            .build();
+    UUID uuid = UUID.randomUUID();
+    schema1.setUUID(uuid);
+
+    Row row =
+        Row.withSchema(schema1)
+            .addValues("", "hello world!".getBytes(StandardCharsets.UTF_8), 1, "")
+            .build();
+    ByteArrayOutputStream os = new ByteArrayOutputStream();
+    RowCoder.of(schema1).encode(row, os);
+    // Pretend that we are restarting and want to recover from persisted state with a compatible
+    // schema using the
+    // overridden encoding positions.
+    RowCoder.clearGeneratedRowCoders();
+    RowCoder.overrideEncodingPositions(
+        uuid,
+        ImmutableMap.of(
+            "failsafeTableRowPayload", 0, "payload", 1, "timestamp", 2, "unknownFieldsPayload", 3));
+
+    Schema schema2 =
+        Schema.builder()
+            .addByteArrayField("payload")
+            .addNullableField("timestamp", FieldType.INT32)
+            .addNullableField("unknownFieldsPayload", FieldType.STRING)
+            .addNullableField("failsafeTableRowPayload", FieldType.STRING)
+            .build();
+    schema2.setUUID(uuid);
+
+    Row expected =
+        Row.withSchema(schema2)
+            .addValues("hello world!".getBytes(StandardCharsets.UTF_8), 1, "", "")
+            .build();
+    Row decoded = RowCoder.of(schema2).decode(new ByteArrayInputStream(os.toByteArray()));
+    assertEquals(expected, decoded);
+  }
+
+  @Test
+  public void testEncodingPositionReorderViaStaticOverrideWithNulls() throws Exception {
+    Schema schema1 =
+        Schema.builder()
+            .addNullableField("failsafeTableRowPayload", FieldType.BYTES)
+            .addByteArrayField("payload")
+            .addNullableField("timestamp", FieldType.INT32)
+            .addNullableField("unknownFieldsPayload", FieldType.BYTES)
+            .build();
+    UUID uuid = UUID.randomUUID();
+    schema1.setUUID(uuid);
+
+    Row row =
+        Row.withSchema(schema1)
+            .addValues(null, "hello world!".getBytes(StandardCharsets.UTF_8), 1, null)
+            .build();
+    ByteArrayOutputStream os = new ByteArrayOutputStream();
+    RowCoder.of(schema1).encode(row, os);
+    // Pretend that we are restarting and want to recover from persisted state with a compatible
+    // schema using the overridden encoding positions.
+    RowCoder.clearGeneratedRowCoders();
+    RowCoder.overrideEncodingPositions(
+        uuid,
+        ImmutableMap.of(
+            "failsafeTableRowPayload", 0, "payload", 1, "timestamp", 2, "unknownFieldsPayload", 3));
+
+    Schema schema2 =
+        Schema.builder()
+            .addByteArrayField("payload")
+            .addNullableField("timestamp", FieldType.INT32)
+            .addNullableField("unknownFieldsPayload", FieldType.BYTES)
+            .addNullableField("failsafeTableRowPayload", FieldType.BYTES)
+            .build();
+    schema2.setUUID(uuid);
+
+    Row expected =
+        Row.withSchema(schema2)
+            .addValues("hello world!".getBytes(StandardCharsets.UTF_8), 1, null, null)
+            .build();
+    Row decoded = RowCoder.of(schema2).decode(new ByteArrayInputStream(os.toByteArray()));
+    assertEquals(expected, decoded);
+  }
+
   @Test
   public void testEncodingPositionAddNewFields() throws Exception {
     Schema schema1 =

Reply via email to