This is an automated email from the ASF dual-hosted git repository.

chaokunyang pushed a commit to branch releases-0.11.0
in repository https://gitbox.apache.org/repos/asf/fory.git

commit bce9fa2c260a82d6b36e4c9684370353111db9f3
Author: Steven Schlansker <[email protected]>
AuthorDate: Fri Jun 6 11:02:54 2025 -0700

    feat(java): row encoder supports synthesizing interfaces nested inside records (#2304)
    
    ## What does this PR do?
    
    Interfaces synthesized by the row encoder now work when nested inside
    enclosing record classes.
    The previous logic was overcomplicated in order to work around overly
    aggressive type expansion. With the fix in #2265 that workaround is no
    longer necessary, so the logic that detects which interfaces to
    synthesize can be simplified in a way that also better supports nesting.
---
 .../fory/integration_tests/RecordRowTest.java      | 26 ++++++++++++++++++
 .../apache/fory/type/TypeResolutionContext.java    | 32 ++++++++--------------
 .../main/java/org/apache/fory/type/TypeUtils.java  | 18 ++++++------
 .../fory/format/encoder/ArrayDataForEach.java      |  5 +---
 .../format/encoder/BaseBinaryEncoderBuilder.java   | 26 +++---------------
 .../org/apache/fory/format/encoder/Encoders.java   |  4 ++-
 .../fory/format/encoder/RowEncoderBuilder.java     | 16 +++--------
 .../org/apache/fory/format/type/TypeInference.java | 13 ++-------
 8 files changed, 60 insertions(+), 80 deletions(-)

diff --git 
a/integration_tests/latest_jdk_tests/src/test/java/org/apache/fory/integration_tests/RecordRowTest.java
 
b/integration_tests/latest_jdk_tests/src/test/java/org/apache/fory/integration_tests/RecordRowTest.java
index a07bfaf1..99c61c64 100644
--- 
a/integration_tests/latest_jdk_tests/src/test/java/org/apache/fory/integration_tests/RecordRowTest.java
+++ 
b/integration_tests/latest_jdk_tests/src/test/java/org/apache/fory/integration_tests/RecordRowTest.java
@@ -60,4 +60,30 @@ public class RecordRowTest {
     final OuterTestRecord deserializedBean = encoder.fromRow(row);
     Assert.assertEquals(deserializedBean, bean);
   }
+
+  public record TestRecordNestedInterface(NestedInterface f1) {}
+
+  public interface NestedInterface {
+    int f1();
+
+    class Impl implements NestedInterface {
+      @Override
+      public int f1() {
+        return 42;
+      }
+    }
+  }
+
+  @Test
+  public void testRecordNestedInterface() {
+    final TestRecordNestedInterface bean =
+        new TestRecordNestedInterface(new NestedInterface.Impl());
+    final RowEncoder<TestRecordNestedInterface> encoder =
+        Encoders.bean(TestRecordNestedInterface.class);
+    final BinaryRow row = encoder.toRow(bean);
+    final MemoryBuffer buffer = MemoryUtils.wrap(row.toBytes());
+    row.pointTo(buffer, 0, buffer.size());
+    final TestRecordNestedInterface deserializedBean = encoder.fromRow(row);
+    Assert.assertEquals(deserializedBean.f1().f1(), bean.f1().f1());
+  }
 }
diff --git 
a/java/fory-core/src/main/java/org/apache/fory/type/TypeResolutionContext.java 
b/java/fory-core/src/main/java/org/apache/fory/type/TypeResolutionContext.java
index 2e1b5d55..14ce43d3 100644
--- 
a/java/fory-core/src/main/java/org/apache/fory/type/TypeResolutionContext.java
+++ 
b/java/fory-core/src/main/java/org/apache/fory/type/TypeResolutionContext.java
@@ -20,10 +20,7 @@
 package org.apache.fory.type;
 
 import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashSet;
 import java.util.LinkedHashSet;
-import java.util.Set;
 import org.apache.fory.annotation.Internal;
 import org.apache.fory.reflect.TypeRef;
 
@@ -31,21 +28,26 @@ import org.apache.fory.reflect.TypeRef;
 public class TypeResolutionContext {
   private final CustomTypeRegistry customTypeRegistry;
   private final LinkedHashSet<TypeRef<?>> walkedTypePath;
-  private final Set<Class<?>> synthesizedBeanTypes;
+  private final boolean synthesizeInterfaces;
 
   public TypeResolutionContext(CustomTypeRegistry customTypeRegistry) {
+    this(customTypeRegistry, false);
+  }
+
+  public TypeResolutionContext(
+      CustomTypeRegistry customTypeRegistry, boolean synthesizeInterfaces) {
     this.customTypeRegistry = customTypeRegistry;
+    this.synthesizeInterfaces = synthesizeInterfaces;
     walkedTypePath = new LinkedHashSet<>();
-    synthesizedBeanTypes = Collections.emptySet();
   }
 
   public TypeResolutionContext(
       CustomTypeRegistry customTypeRegistry,
       LinkedHashSet<TypeRef<?>> walkedTypePath,
-      Set<Class<?>> synthesizedBeanTypes) {
+      boolean synthesizeInterfaces) {
     this.customTypeRegistry = customTypeRegistry;
     this.walkedTypePath = walkedTypePath;
-    this.synthesizedBeanTypes = synthesizedBeanTypes;
+    this.synthesizeInterfaces = synthesizeInterfaces;
   }
 
   public CustomTypeRegistry getCustomTypeRegistry() {
@@ -56,8 +58,8 @@ public class TypeResolutionContext {
     return walkedTypePath;
   }
 
-  public Set<Class<?>> getSynthesizedBeanTypes() {
-    return synthesizedBeanTypes;
+  public boolean isSynthesizeInterfaces() {
+    return synthesizeInterfaces;
   }
 
   public TypeRef<?> getEnclosingType() {
@@ -71,23 +73,13 @@ public class TypeResolutionContext {
   public TypeResolutionContext appendTypePath(TypeRef<?>... typeRef) {
     LinkedHashSet<TypeRef<?>> newWalkedTypePath = new 
LinkedHashSet<>(walkedTypePath);
     newWalkedTypePath.addAll(Arrays.asList(typeRef));
-    return new TypeResolutionContext(customTypeRegistry, newWalkedTypePath, 
synthesizedBeanTypes);
+    return new TypeResolutionContext(customTypeRegistry, newWalkedTypePath, 
synthesizeInterfaces);
   }
 
   public TypeResolutionContext appendTypePath(Class<?> clz) {
     return appendTypePath(TypeRef.of(clz));
   }
 
-  public TypeResolutionContext withSynthesizedBeanType(Class<?> clz) {
-    Set<Class<?>> newSynthesizedBeanTypes = new 
HashSet<>(synthesizedBeanTypes);
-    newSynthesizedBeanTypes.add(clz);
-    return new TypeResolutionContext(customTypeRegistry, walkedTypePath, 
newSynthesizedBeanTypes);
-  }
-
-  public boolean isSynthesizedBeanType(Class<?> cls) {
-    return synthesizedBeanTypes.contains(cls);
-  }
-
   public void checkNoCycle(Class<?> clz) {
     checkNoCycle(TypeRef.of(clz));
   }
diff --git a/java/fory-core/src/main/java/org/apache/fory/type/TypeUtils.java 
b/java/fory-core/src/main/java/org/apache/fory/type/TypeUtils.java
index 15b8d9ca..23785a5c 100644
--- a/java/fory-core/src/main/java/org/apache/fory/type/TypeUtils.java
+++ b/java/fory-core/src/main/java/org/apache/fory/type/TypeUtils.java
@@ -599,7 +599,7 @@ public class TypeUtils {
 
   public static boolean isBean(TypeRef<?> typeRef, TypeResolutionContext ctx) {
     Class<?> cls = getRawType(typeRef);
-    if (ctx.isSynthesizedBeanType(cls) || RecordUtils.isRecord(cls)) {
+    if (ctx.isSynthesizeInterfaces() && (RecordUtils.isRecord(cls) || 
(cls.isInterface()))) {
       return true;
     }
     if (Modifier.isAbstract(cls.getModifiers()) || 
Modifier.isInterface(cls.getModifiers())) {
@@ -702,12 +702,14 @@ public class TypeUtils {
   public static LinkedHashSet<Class<?>> listBeansRecursiveInclusive(
       Class<?> beanClass, CustomTypeRegistry customTypes) {
     TypeResolutionContext ctx = new TypeResolutionContext(customTypes);
-    if (beanClass.isInterface()) {
-      ctx = ctx.withSynthesizedBeanType(beanClass);
-    }
     return listBeansRecursiveInclusive(TypeRef.of(beanClass), ctx);
   }
 
+  public static LinkedHashSet<Class<?>> listBeansRecursiveInclusive(
+      Class<?> beanClass, TypeResolutionContext typeCtx) {
+    return listBeansRecursiveInclusive(TypeRef.of(beanClass), typeCtx);
+  }
+
   private static LinkedHashSet<Class<?>> listBeansRecursiveInclusive(
       TypeRef<?> typeRef, TypeResolutionContext ctx) {
     LinkedHashSet<Class<?>> beans = new LinkedHashSet<>();
@@ -735,12 +737,8 @@ public class TypeUtils {
       beans.add(type);
       for (Descriptor descriptor : descriptors) {
         ctx.checkNoCycle(typeRef);
-        TypeRef<?> propertyTypeRef = descriptor.getTypeRef();
-        Class<?> propertyType = propertyTypeRef.getRawType();
-        if (propertyType.isInterface()) {
-          newCtx = newCtx.withSynthesizedBeanType(propertyType);
-        }
-        beans.addAll(listBeansRecursiveInclusive(propertyTypeRef, 
newCtx.appendTypePath(typeRef)));
+        beans.addAll(
+            listBeansRecursiveInclusive(descriptor.getTypeRef(), 
newCtx.appendTypePath(typeRef)));
       }
     }
     return beans;
diff --git 
a/java/fory-format/src/main/java/org/apache/fory/format/encoder/ArrayDataForEach.java
 
b/java/fory-format/src/main/java/org/apache/fory/format/encoder/ArrayDataForEach.java
index 3f8a1bd2..ed98ef3e 100644
--- 
a/java/fory-format/src/main/java/org/apache/fory/format/encoder/ArrayDataForEach.java
+++ 
b/java/fory-format/src/main/java/org/apache/fory/format/encoder/ArrayDataForEach.java
@@ -88,10 +88,7 @@ public class ArrayDataForEach extends AbstractExpression {
       accessType = TypeRef.of(customEncoder.encodedType());
     }
     CustomTypeHandler customTypeHandler = 
CustomTypeEncoderRegistry.customTypeHandler();
-    TypeResolutionContext ctx = new TypeResolutionContext(customTypeHandler);
-    if (inputArrayData.type().getRawType().isInterface() && 
elemType.getRawType().isInterface()) {
-      ctx = ctx.withSynthesizedBeanType(elemType.getRawType());
-    }
+    TypeResolutionContext ctx = new TypeResolutionContext(customTypeHandler, 
true);
     this.accessMethod = BinaryUtils.getElemAccessMethodName(accessType, ctx);
     this.elemType = BinaryUtils.getElemReturnType(accessType, ctx);
     this.notNullAction = notNullAction;
diff --git 
a/java/fory-format/src/main/java/org/apache/fory/format/encoder/BaseBinaryEncoderBuilder.java
 
b/java/fory-format/src/main/java/org/apache/fory/format/encoder/BaseBinaryEncoderBuilder.java
index 41a0448e..79f369e9 100644
--- 
a/java/fory-format/src/main/java/org/apache/fory/format/encoder/BaseBinaryEncoderBuilder.java
+++ 
b/java/fory-format/src/main/java/org/apache/fory/format/encoder/BaseBinaryEncoderBuilder.java
@@ -110,11 +110,8 @@ public abstract class BaseBinaryEncoderBuilder extends 
CodecBuilder {
     ctx.addImport(BinaryRow.class.getPackage().getName() + ".*");
     ctx.addImport(BinaryWriter.class.getPackage().getName() + ".*");
     ctx.addImport(Schema.class.getPackage().getName() + ".*");
-    TypeResolutionContext typeCtx = new 
TypeResolutionContext(customTypeHandler);
+    TypeResolutionContext typeCtx = new 
TypeResolutionContext(customTypeHandler, true);
     typeCtx.appendTypePath(beanClass);
-    if (beanClass.isInterface()) {
-      typeCtx = typeCtx.withSynthesizedBeanType(beanClass);
-    }
     this.typeCtx = typeCtx;
   }
 
@@ -254,7 +251,7 @@ public abstract class BaseBinaryEncoderBuilder extends 
CodecBuilder {
           expression);
     } else if (TypeUtils.MAP_TYPE.isSupertypeOf(typeRef)) {
       return serializeForMap(ordinal, writer, inputObject, typeRef, 
arrowField);
-    } else if (TypeUtils.isBean(rawType, createElementTypeContext(typeRef))) {
+    } else if (TypeUtils.isBean(rawType, typeCtx)) {
       return serializeForBean(ordinal, writer, inputObject, typeRef, 
arrowField);
     } else if (rawType == BinaryArray.class) {
       Invoke writeExp =
@@ -647,11 +644,7 @@ public abstract class BaseBinaryEncoderBuilder extends 
CodecBuilder {
           new ArrayDataForEach(
               arrayData,
               elemType,
-              (i, value) ->
-                  new Invoke(
-                      collection,
-                      "add",
-                      deserializeFor(value, elemType, 
createElementTypeContext(elemType))),
+              (i, value) -> new Invoke(collection, "add", 
deserializeFor(value, elemType, typeCtx)),
               i -> new Invoke(collection, "add", 
ExpressionUtils.nullValue(elemType)));
       return new ListExpression(collection, addElemsOp, collection);
     } catch (Exception e) {
@@ -807,8 +800,7 @@ public abstract class BaseBinaryEncoderBuilder extends 
CodecBuilder {
                 arrayData,
                 elemType,
                 (i, value) -> {
-                  Expression elemValue =
-                      deserializeFor(value, elemType, 
createElementTypeContext(elemType));
+                  Expression elemValue = deserializeFor(value, elemType, 
typeCtx);
                   return new AssignArrayElem(javaArray, elemValue, i);
                 });
         // add javaArray at last as expression value
@@ -824,14 +816,4 @@ public abstract class BaseBinaryEncoderBuilder extends 
CodecBuilder {
   protected Expression deserializeForObject(Expression value, TypeRef<?> 
typeRef) {
     return new Invoke(foryRef, "deserialize", typeRef, value);
   }
-
-  protected TypeResolutionContext createElementTypeContext(TypeRef<?> 
elemType) {
-    TypeResolutionContext newTypeCtx;
-    if (elemType.isInterface() && beanClass.isInterface()) {
-      newTypeCtx = typeCtx.withSynthesizedBeanType(elemType.getRawType());
-    } else {
-      newTypeCtx = typeCtx;
-    }
-    return newTypeCtx;
-  }
 }
diff --git 
a/java/fory-format/src/main/java/org/apache/fory/format/encoder/Encoders.java 
b/java/fory-format/src/main/java/org/apache/fory/format/encoder/Encoders.java
index 8a7bd6b3..020167f6 100644
--- 
a/java/fory-format/src/main/java/org/apache/fory/format/encoder/Encoders.java
+++ 
b/java/fory-format/src/main/java/org/apache/fory/format/encoder/Encoders.java
@@ -49,6 +49,7 @@ import org.apache.fory.logging.LoggerFactory;
 import org.apache.fory.memory.MemoryBuffer;
 import org.apache.fory.memory.MemoryUtils;
 import org.apache.fory.reflect.TypeRef;
+import org.apache.fory.type.TypeResolutionContext;
 import org.apache.fory.type.TypeUtils;
 
 /**
@@ -679,7 +680,8 @@ public class Encoders {
   public static Class<?> loadOrGenRowCodecClass(Class<?> beanClass) {
     Set<Class<?>> classes =
         TypeUtils.listBeansRecursiveInclusive(
-            beanClass, CustomTypeEncoderRegistry.customTypeHandler());
+            beanClass,
+            new 
TypeResolutionContext(CustomTypeEncoderRegistry.customTypeHandler(), true));
     LOG.info("Create RowCodec for classes {}", classes);
     CompileUnit[] compileUnits =
         classes.stream()
diff --git 
a/java/fory-format/src/main/java/org/apache/fory/format/encoder/RowEncoderBuilder.java
 
b/java/fory-format/src/main/java/org/apache/fory/format/encoder/RowEncoderBuilder.java
index 36d13806..cf351365 100644
--- 
a/java/fory-format/src/main/java/org/apache/fory/format/encoder/RowEncoderBuilder.java
+++ 
b/java/fory-format/src/main/java/org/apache/fory/format/encoder/RowEncoderBuilder.java
@@ -51,7 +51,6 @@ import org.apache.fory.logging.Logger;
 import org.apache.fory.logging.LoggerFactory;
 import org.apache.fory.reflect.TypeRef;
 import org.apache.fory.type.Descriptor;
-import org.apache.fory.type.TypeResolutionContext;
 import org.apache.fory.type.TypeUtils;
 import org.apache.fory.util.GraalvmSupport;
 import org.apache.fory.util.Preconditions;
@@ -80,8 +79,7 @@ public class RowEncoderBuilder extends 
BaseBinaryEncoderBuilder {
 
   public RowEncoderBuilder(TypeRef<?> beanType) {
     super(new CodegenContext(), beanType);
-    Preconditions.checkArgument(
-        beanClass.isInterface() || TypeUtils.isBean(beanType.getType(), 
customTypeHandler));
+    Preconditions.checkArgument(beanClass.isInterface() || 
TypeUtils.isBean(beanType, typeCtx));
     className = codecClassName(beanClass);
     this.schema = TypeInference.inferSchema(getRawType(beanType));
     this.descriptorsMap = Descriptor.getDescriptorsMap(beanClass);
@@ -278,12 +276,6 @@ public class RowEncoderBuilder extends 
BaseBinaryEncoderBuilder {
       Descriptor d = 
getDescriptorByFieldName(schema.getFields().get(i).getName());
       TypeRef<?> fieldType = d.getTypeRef();
       Class<?> rawFieldType = fieldType.getRawType();
-      TypeResolutionContext fieldCtx;
-      if (beanClass.isInterface() && rawFieldType.isInterface()) {
-        fieldCtx = typeCtx.withSynthesizedBeanType(rawFieldType);
-      } else {
-        fieldCtx = typeCtx;
-      }
       TypeRef<?> columnAccessType;
       if (rawFieldType == Optional.class) {
         columnAccessType = TypeUtils.getTypeArguments(fieldType).get(0);
@@ -296,8 +288,8 @@ public class RowEncoderBuilder extends 
BaseBinaryEncoderBuilder {
         }
       }
       String columnAccessMethodName =
-          BinaryUtils.getElemAccessMethodName(columnAccessType, fieldCtx);
-      TypeRef<?> colType = BinaryUtils.getElemReturnType(columnAccessType, 
fieldCtx);
+          BinaryUtils.getElemAccessMethodName(columnAccessType, typeCtx);
+      TypeRef<?> colType = BinaryUtils.getElemReturnType(columnAccessType, 
typeCtx);
       Expression.Invoke columnValue =
           new Expression.Invoke(
               row,
@@ -306,7 +298,7 @@ public class RowEncoderBuilder extends 
BaseBinaryEncoderBuilder {
               colType,
               false,
               ordinal);
-      Expression value = new Expression.Return(deserializeFor(columnValue, 
fieldType, fieldCtx));
+      Expression value = new Expression.Return(deserializeFor(columnValue, 
fieldType, typeCtx));
       ctx.addMethod(
           decodeMethodName(i),
           value.doGenCode(ctx).code(),
diff --git 
a/java/fory-format/src/main/java/org/apache/fory/format/type/TypeInference.java 
b/java/fory-format/src/main/java/org/apache/fory/format/type/TypeInference.java
index 872ba72f..a3196017 100644
--- 
a/java/fory-format/src/main/java/org/apache/fory/format/type/TypeInference.java
+++ 
b/java/fory-format/src/main/java/org/apache/fory/format/type/TypeInference.java
@@ -121,11 +121,7 @@ public class TypeInference {
 
   private static Field inferField(TypeRef<?> arrayTypeRef, TypeRef<?> typeRef) 
{
     TypeResolutionContext ctx =
-        new 
TypeResolutionContext(CustomTypeEncoderRegistry.customTypeHandler());
-    Class<?> clz = getRawType(typeRef);
-    if (clz.isInterface()) {
-      ctx = ctx.withSynthesizedBeanType(clz);
-    }
+        new 
TypeResolutionContext(CustomTypeEncoderRegistry.customTypeHandler(), true);
     String name = "";
     if (arrayTypeRef != null) {
       Field f = inferField(DataTypes.ARRAY_ITEM_NAME, typeRef, ctx);
@@ -239,13 +235,8 @@ public class TypeInference {
               .map(
                   descriptor -> {
                     String n = 
StringUtils.lowerCamelToLowerUnderscore(descriptor.getName());
-                    TypeResolutionContext newCtx = ctx.appendTypePath(rawType);
                     TypeRef<?> fieldType = descriptor.getTypeRef();
-                    Class<?> rawFieldType = getRawType(fieldType);
-                    if (rawFieldType.isInterface()) {
-                      newCtx = newCtx.withSynthesizedBeanType(rawFieldType);
-                    }
-                    return inferField(n, fieldType, newCtx);
+                    return inferField(n, fieldType, 
ctx.appendTypePath(rawType));
                   })
               .collect(Collectors.toList());
       return DataTypes.structField(name, true, fields);


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to