This is an automated email from the ASF dual-hosted git repository.

nielsbasjes pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/avro.git


The following commit(s) were added to refs/heads/master by this push:
     new fa0bb7098 AVRO-3717: [Java] Fix NPE when basic type with Nullable annotation.
fa0bb7098 is described below

commit fa0bb7098083aba41c9aa9d9cf6383eb3e2c2696
Author: Yan Zhao <[email protected]>
AuthorDate: Sat Feb 18 23:41:56 2023 +0800

    AVRO-3717: [Java] Fix NPE when basic type with Nullable annotation.
---
 .../java/org/apache/avro/reflect/FieldAccess.java  |  16 ++
 .../apache/avro/reflect/FieldAccessReflect.java    |  24 ++-
 .../org/apache/avro/reflect/FieldAccessUnsafe.java |  16 +-
 .../avro/reflect/TestReflectDatumReader.java       | 209 +++++++++++++++++++++
 4 files changed, 256 insertions(+), 9 deletions(-)

diff --git a/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccess.java b/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccess.java
index 961884951..dce1aed98 100644
--- a/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccess.java
+++ b/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccess.java
@@ -21,6 +21,22 @@ import java.lang.reflect.Field;
 
 abstract class FieldAccess {
 
+  protected static final int INT_DEFAULT_VALUE = 0;
+
+  protected static final float FLOAT_DEFAULT_VALUE = 0.0f;
+
+  protected static final short SHORT_DEFAULT_VALUE = (short) 0;
+
+  protected static final byte BYTE_DEFAULT_VALUE = (byte) 0;
+
+  protected static final boolean BOOLEAN_DEFAULT_VALUE = false;
+
+  protected static final char CHAR_DEFAULT_VALUE = '\u0000';
+
+  protected static final long LONG_DEFAULT_VALUE = 0L;
+
+  protected static final double DOUBLE_DEFAULT_VALUE = 0.0d;
+
   protected abstract FieldAccessor getAccessor(Field field);
 
 }
diff --git a/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccessReflect.java b/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccessReflect.java
index c790dbfb8..5d51be054 100644
--- a/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccessReflect.java
+++ b/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccessReflect.java
@@ -62,7 +62,29 @@ class FieldAccessReflect extends FieldAccess {
 
     @Override
     public void set(Object object, Object value) throws 
IllegalAccessException, IOException {
-      field.set(object, value);
+      if (value == null && field.getType().isPrimitive()) {
+        Object defaultValue = null;
+        if (int.class.equals(field.getType())) {
+          defaultValue = INT_DEFAULT_VALUE;
+        } else if (float.class.equals(field.getType())) {
+          defaultValue = FLOAT_DEFAULT_VALUE;
+        } else if (short.class.equals(field.getType())) {
+          defaultValue = SHORT_DEFAULT_VALUE;
+        } else if (byte.class.equals(field.getType())) {
+          defaultValue = BYTE_DEFAULT_VALUE;
+        } else if (boolean.class.equals(field.getType())) {
+          defaultValue = BOOLEAN_DEFAULT_VALUE;
+        } else if (char.class.equals(field.getType())) {
+          defaultValue = CHAR_DEFAULT_VALUE;
+        } else if (long.class.equals(field.getType())) {
+          defaultValue = LONG_DEFAULT_VALUE;
+        } else if (double.class.equals(field.getType())) {
+          defaultValue = DOUBLE_DEFAULT_VALUE;
+        }
+        field.set(object, defaultValue);
+      } else {
+        field.set(object, value);
+      }
     }
 
     @Override
diff --git a/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccessUnsafe.java b/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccessUnsafe.java
index f555df49a..a2c5c4e1b 100644
--- a/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccessUnsafe.java
+++ b/lang/java/avro/src/main/java/org/apache/avro/reflect/FieldAccessUnsafe.java
@@ -106,7 +106,7 @@ class FieldAccessUnsafe extends FieldAccess {
 
     @Override
     protected void set(Object object, Object value) {
-      UNSAFE.putInt(object, offset, (Integer) value);
+      UNSAFE.putInt(object, offset, value == null ? INT_DEFAULT_VALUE : 
(Integer) value);
     }
 
     @Override
@@ -132,7 +132,7 @@ class FieldAccessUnsafe extends FieldAccess {
 
     @Override
     protected void set(Object object, Object value) {
-      UNSAFE.putFloat(object, offset, (Float) value);
+      UNSAFE.putFloat(object, offset, value == null ? FLOAT_DEFAULT_VALUE : 
(Float) value);
     }
 
     @Override
@@ -158,7 +158,7 @@ class FieldAccessUnsafe extends FieldAccess {
 
     @Override
     protected void set(Object object, Object value) {
-      UNSAFE.putShort(object, offset, (Short) value);
+      UNSAFE.putShort(object, offset, value == null ? SHORT_DEFAULT_VALUE : 
(Short) value);
     }
 
     @Override
@@ -184,7 +184,7 @@ class FieldAccessUnsafe extends FieldAccess {
 
     @Override
     protected void set(Object object, Object value) {
-      UNSAFE.putByte(object, offset, (Byte) value);
+      UNSAFE.putByte(object, offset, value == null ? BYTE_DEFAULT_VALUE : 
(Byte) value);
     }
 
     @Override
@@ -210,7 +210,7 @@ class FieldAccessUnsafe extends FieldAccess {
 
     @Override
     protected void set(Object object, Object value) {
-      UNSAFE.putBoolean(object, offset, (Boolean) value);
+      UNSAFE.putBoolean(object, offset, value == null ? BOOLEAN_DEFAULT_VALUE 
: (Boolean) value);
     }
 
     @Override
@@ -236,7 +236,7 @@ class FieldAccessUnsafe extends FieldAccess {
 
     @Override
     protected void set(Object object, Object value) {
-      UNSAFE.putChar(object, offset, (Character) value);
+      UNSAFE.putChar(object, offset, value == null ? CHAR_DEFAULT_VALUE : 
(Character) value);
     }
 
     @Override
@@ -262,7 +262,7 @@ class FieldAccessUnsafe extends FieldAccess {
 
     @Override
     protected void set(Object object, Object value) {
-      UNSAFE.putLong(object, offset, (Long) value);
+      UNSAFE.putLong(object, offset, value == null ? LONG_DEFAULT_VALUE : 
(Long) value);
     }
 
     @Override
@@ -288,7 +288,7 @@ class FieldAccessUnsafe extends FieldAccess {
 
     @Override
     protected void set(Object object, Object value) {
-      UNSAFE.putDouble(object, offset, (Double) value);
+      UNSAFE.putDouble(object, offset, value == null ? DOUBLE_DEFAULT_VALUE : 
(Double) value);
     }
 
     @Override
diff --git a/lang/java/avro/src/test/java/org/apache/avro/reflect/TestReflectDatumReader.java b/lang/java/avro/src/test/java/org/apache/avro/reflect/TestReflectDatumReader.java
index 65d01307e..52b40b87b 100644
--- a/lang/java/avro/src/test/java/org/apache/avro/reflect/TestReflectDatumReader.java
+++ b/lang/java/avro/src/test/java/org/apache/avro/reflect/TestReflectDatumReader.java
@@ -30,6 +30,7 @@ import java.util.Set;
 import java.util.Map;
 import java.util.Optional;
 
+import org.apache.avro.Schema;
 import org.apache.avro.io.Decoder;
 import org.apache.avro.io.DecoderFactory;
 import org.apache.avro.io.Encoder;
@@ -160,6 +161,36 @@ public class TestReflectDatumReader {
     assertEquals(pojoWithOptional, deserialized);
   }
 
+  @Test
+  public void testRead_PojoWithNullableAnnotation() throws IOException {
+    PojoWithBasicTypeNullableAnnotationV1 v1Pojo = new 
PojoWithBasicTypeNullableAnnotationV1();
+    int idValue = 1;
+    v1Pojo.setId(idValue);
+    byte[] serializedBytes = serializeWithReflectDatumWriter(v1Pojo, 
PojoWithBasicTypeNullableAnnotationV1.class);
+    Decoder decoder = DecoderFactory.get().binaryDecoder(serializedBytes, 
null);
+
+    ReflectData reflectData = ReflectData.get();
+    Schema schemaV1 = 
reflectData.getSchema(PojoWithBasicTypeNullableAnnotationV1.class);
+    Schema schemaV2 = 
reflectData.getSchema(PojoWithBasicTypeNullableAnnotationV2.class);
+
+    ReflectDatumReader<PojoWithBasicTypeNullableAnnotationV2> 
reflectDatumReader = new ReflectDatumReader<>(schemaV1,
+        schemaV2);
+
+    PojoWithBasicTypeNullableAnnotationV2 v2Pojo = new 
PojoWithBasicTypeNullableAnnotationV2();
+    reflectDatumReader.read(v2Pojo, decoder);
+
+    assertEquals(v1Pojo.id, v2Pojo.id);
+    assertEquals(v2Pojo.id, idValue);
+    assertEquals(v2Pojo.intId, FieldAccess.INT_DEFAULT_VALUE);
+    assertEquals(v2Pojo.floatId, FieldAccess.FLOAT_DEFAULT_VALUE);
+    assertEquals(v2Pojo.shortId, FieldAccess.SHORT_DEFAULT_VALUE);
+    assertEquals(v2Pojo.byteId, FieldAccess.BYTE_DEFAULT_VALUE);
+    assertEquals(v2Pojo.booleanId, FieldAccess.BOOLEAN_DEFAULT_VALUE);
+    assertEquals(v2Pojo.charId, FieldAccess.CHAR_DEFAULT_VALUE);
+    assertEquals(v2Pojo.longId, FieldAccess.LONG_DEFAULT_VALUE);
+    assertEquals(v2Pojo.doubleId, FieldAccess.DOUBLE_DEFAULT_VALUE);
+  }
+
   public static class PojoWithList {
     private int id;
     private List<Integer> relatedIds;
@@ -392,4 +423,182 @@ public class TestReflectDatumReader {
         return relatedId.equals(other.relatedId);
     }
   }
+
+  public static class PojoWithBasicTypeNullableAnnotationV1 {
+
+    private int id;
+
+    public int getId() {
+      return id;
+    }
+
+    public void setId(int id) {
+      this.id = id;
+    }
+
+    @Override
+    public int hashCode() {
+      final int prime = 31;
+      int result = 1;
+      result = prime * result + id;
+      return result;
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+      if (this == obj)
+        return true;
+      if (obj == null)
+        return false;
+      if (getClass() != obj.getClass())
+        return false;
+      PojoWithBasicTypeNullableAnnotationV1 other = 
(PojoWithBasicTypeNullableAnnotationV1) obj;
+      return id == other.id;
+    }
+  }
+
+  public static class PojoWithBasicTypeNullableAnnotationV2 {
+
+    private int id;
+
+    @Nullable
+    private int intId;
+
+    @Nullable
+    private float floatId;
+
+    @Nullable
+    private short shortId;
+
+    @Nullable
+    private byte byteId;
+
+    @Nullable
+    private boolean booleanId;
+
+    @Nullable
+    private char charId;
+
+    @Nullable
+    private long longId;
+
+    @Nullable
+    private double doubleId;
+
+    public int getId() {
+      return id;
+    }
+
+    public void setId(int id) {
+      this.id = id;
+    }
+
+    public int getIntId() {
+      return intId;
+    }
+
+    public void setIntId(int intId) {
+      this.intId = intId;
+    }
+
+    public float getFloatId() {
+      return floatId;
+    }
+
+    public void setFloatId(float floatId) {
+      this.floatId = floatId;
+    }
+
+    public short getShortId() {
+      return shortId;
+    }
+
+    public void setShortId(short shortId) {
+      this.shortId = shortId;
+    }
+
+    public byte getByteId() {
+      return byteId;
+    }
+
+    public void setByteId(byte byteId) {
+      this.byteId = byteId;
+    }
+
+    public boolean isBooleanId() {
+      return booleanId;
+    }
+
+    public void setBooleanId(boolean booleanId) {
+      this.booleanId = booleanId;
+    }
+
+    public char getCharId() {
+      return charId;
+    }
+
+    public void setCharId(char charId) {
+      this.charId = charId;
+    }
+
+    public long getLongId() {
+      return longId;
+    }
+
+    public void setLongId(long longId) {
+      this.longId = longId;
+    }
+
+    public double getDoubleId() {
+      return doubleId;
+    }
+
+    public void setDoubleId(double doubleId) {
+      this.doubleId = doubleId;
+    }
+
+    @Override
+    public int hashCode() {
+      final int prime = 31;
+      long temp;
+      int result = 1;
+      result = prime * result + id;
+      result = prime * result + intId;
+      result = prime * result + (floatId != 0.0f ? 
Float.floatToIntBits(floatId) : 0);
+      result = prime * result + (int) shortId;
+      result = prime * result + (int) byteId;
+      result = prime * result + (booleanId ? 1 : 0);
+      result = prime * result + (int) charId;
+      result = prime * result + (int) (longId ^ (longId >>> 32));
+      temp = Double.doubleToLongBits(doubleId);
+      result = 31 * result + (int) (temp ^ (temp >>> 32));
+      return result;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o)
+        return true;
+      if (o == null || getClass() != o.getClass())
+        return false;
+      PojoWithBasicTypeNullableAnnotationV2 that = 
(PojoWithBasicTypeNullableAnnotationV2) o;
+      if (id != that.id)
+        return false;
+      if (intId != that.intId)
+        return false;
+      if (Float.compare(that.floatId, floatId) != 0)
+        return false;
+      if (shortId != that.shortId)
+        return false;
+      if (byteId != that.byteId)
+        return false;
+      if (booleanId != that.booleanId)
+        return false;
+      if (charId != that.charId)
+        return false;
+      if (longId != that.longId)
+        return false;
+      return Double.compare(that.doubleId, doubleId) == 0;
+    }
+  }
 }

Reply via email to