This is an automated email from the ASF dual-hosted git repository.

chaokunyang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/fory.git


The following commit(s) were added to refs/heads/main by this push:
     new 81842f93 feat(java): row encoder supports custom rewriting values 
without changing their type (#2305)
81842f93 is described below

commit 81842f939ddfd2e75ac0d28f9b7af499a5a493dc
Author: Steven Schlansker <[email protected]>
AuthorDate: Sat Jun 7 10:30:24 2025 -0700

    feat(java): row encoder supports custom rewriting values without changing 
their type (#2305)
    
    ## What does this PR do?
    
    Add support for row encoder CustomCodec to read-replace or write-replace values.
    
    We use this to, e.g., replace Fory-generated interface types with our own
    implementation generated via Immutables (https://immutables.github.io/).
---
 .../org/apache/fory/type/CustomTypeRegistry.java   |  14 ++-
 .../main/java/org/apache/fory/type/TypeUtils.java  |  62 ++++++----
 .../fory/format/encoder/ArrayDataForEach.java      |   2 +-
 .../fory/format/encoder/ArrayEncoderBuilder.java   |   4 +-
 .../format/encoder/BaseBinaryEncoderBuilder.java   | 126 +++++++++++++--------
 .../apache/fory/format/encoder/CustomCodec.java    |  47 ++++++--
 .../fory/format/encoder/RowEncoderBuilder.java     |  95 +++++++++-------
 .../format/type/CustomTypeEncoderRegistry.java     |   4 +-
 .../apache/fory/format/type/CustomTypeHandler.java |  14 ++-
 .../org/apache/fory/format/type/TypeInference.java |  29 +++--
 .../fory/format/encoder/CustomCodecTest.java       |  92 ++++++++++++---
 11 files changed, 333 insertions(+), 156 deletions(-)

diff --git 
a/java/fory-core/src/main/java/org/apache/fory/type/CustomTypeRegistry.java 
b/java/fory-core/src/main/java/org/apache/fory/type/CustomTypeRegistry.java
index 69edb5f7..bfbe4cf2 100644
--- a/java/fory-core/src/main/java/org/apache/fory/type/CustomTypeRegistry.java
+++ b/java/fory-core/src/main/java/org/apache/fory/type/CustomTypeRegistry.java
@@ -20,14 +20,15 @@
 package org.apache.fory.type;
 
 import org.apache.fory.annotation.Internal;
+import org.apache.fory.reflect.TypeRef;
 
 @Internal
 public interface CustomTypeRegistry {
   CustomTypeRegistry EMPTY =
       new CustomTypeRegistry() {
         @Override
-        public boolean hasCodec(final Class<?> beanType, final Class<?> 
fieldType) {
-          return false;
+        public TypeRef<?> replacementTypeFor(final Class<?> beanType, final 
Class<?> fieldType) {
+          return null;
         }
 
         @Override
@@ -35,9 +36,16 @@ public interface CustomTypeRegistry {
             final Class<?> collectionType, final Class<?> elementType) {
           return false;
         }
+
+        @Override
+        public boolean isExtraSupportedType(final TypeRef<?> type) {
+          return false;
+        }
       };
 
-  boolean hasCodec(Class<?> beanType, Class<?> fieldType);
+  TypeRef<?> replacementTypeFor(Class<?> beanType, Class<?> fieldType);
 
   boolean canConstructCollection(Class<?> collectionType, Class<?> 
elementType);
+
+  boolean isExtraSupportedType(TypeRef<?> type);
 }
diff --git a/java/fory-core/src/main/java/org/apache/fory/type/TypeUtils.java 
b/java/fory-core/src/main/java/org/apache/fory/type/TypeUtils.java
index 23785a5c..1c7b3946 100644
--- a/java/fory-core/src/main/java/org/apache/fory/type/TypeUtils.java
+++ b/java/fory-core/src/main/java/org/apache/fory/type/TypeUtils.java
@@ -605,6 +605,10 @@ public class TypeUtils {
     if (Modifier.isAbstract(cls.getModifiers()) || 
Modifier.isInterface(cls.getModifiers())) {
       return false;
     }
+    if (ctx.getWalkedTypePath().contains(typeRef)
+        || ctx.getCustomTypeRegistry().isExtraSupportedType(typeRef)) {
+      return false;
+    }
     // since we need to access class in generated code in our package, the 
class must be public
     // if ReflectionUtils.hasNoArgConstructor(cls) return false, we use Unsafe 
to create object.
     if (Modifier.isPublic(cls.getModifiers())) {
@@ -624,17 +628,21 @@ public class TypeUtils {
               && !ITERABLE_TYPE.isSupertypeOf(typeRef)
               && !MAP_TYPE.isSupertypeOf(typeRef);
       if (maybe) {
-        return Descriptor.getDescriptors(cls).stream()
-            .allMatch(
-                d -> {
-                  TypeRef<?> t = d.getTypeRef();
-                  // do field modifiers and getter/setter validation here, not 
in getDescriptors.
-                  // If Modifier.isFinal(d.getModifiers()), use reflection
-                  // private field that doesn't have getter/setter will be 
handled by reflection.
-                  return ctx.getCustomTypeRegistry().hasCodec(cls, 
t.getRawType())
-                      || isSupported(t, newTypePath)
-                      || isBean(t, newTypePath);
-                });
+        for (Descriptor d : Descriptor.getDescriptors(cls)) {
+          TypeRef<?> t = d.getTypeRef();
+          // do field modifiers and getter/setter validation here, not in 
getDescriptors.
+          // If Modifier.isFinal(d.getModifiers()), use reflection
+          // private field that doesn't have getter/setter will be handled by 
reflection.
+          TypeRef<?> replacementType =
+              ctx.getCustomTypeRegistry().replacementTypeFor(cls, 
t.getRawType());
+          if (replacementType != null) {
+            t = replacementType;
+          }
+          if (!isSupported(t, newTypePath)) {
+            return false;
+          }
+        }
+        return true;
       } else {
         return false;
       }
@@ -658,10 +666,16 @@ public class TypeUtils {
       // box.
       return true;
     }
-    if (SUPPORTED_TYPES.contains(typeRef)) {
+    TypeRef<?> replacementType =
+        ctx.getCustomTypeRegistry()
+            .replacementTypeFor(ctx.getEnclosingType().getRawType(), 
typeRef.getRawType());
+    if (replacementType != null) {
+      return isSupported(replacementType, ctx);
+    } else if (SUPPORTED_TYPES.contains(typeRef)
+        || ctx.getCustomTypeRegistry().isExtraSupportedType(typeRef)) {
       return true;
     } else if (typeRef.isArray()) {
-      return isSupported(Objects.requireNonNull(typeRef.getComponentType()));
+      return isSupported(Objects.requireNonNull(typeRef.getComponentType()), 
ctx);
     } else if (ITERABLE_TYPE.isSupertypeOf(typeRef)) {
       TypeRef<?> elementType = getElementType(typeRef);
       boolean isSuperOfArrayList = cls.isAssignableFrom(ArrayList.class);
@@ -672,7 +686,7 @@ public class TypeUtils {
               .canConstructCollection(typeRef.getRawType(), 
elementType.getRawType())) {
         return false;
       }
-      return isSupported(getElementType(typeRef));
+      return isSupported(elementType, ctx);
     } else if (MAP_TYPE.isSupertypeOf(typeRef)) {
       boolean isSuperOfHashMap = cls.isAssignableFrom(HashMap.class);
       if (!isSuperOfHashMap && (cls.isInterface() || 
Modifier.isAbstract(cls.getModifiers()))) {
@@ -684,7 +698,7 @@ public class TypeUtils {
       return true;
     } else {
       ctx.checkNoCycle(typeRef);
-      return isBean(typeRef, ctx.appendTypePath(typeRef));
+      return isBean(typeRef, ctx);
     }
   }
 
@@ -715,30 +729,32 @@ public class TypeUtils {
     LinkedHashSet<Class<?>> beans = new LinkedHashSet<>();
     Class<?> enclosingType = ctx.getEnclosingType().getRawType();
     Class<?> type = typeRef.getRawType();
-    TypeResolutionContext newCtx = ctx;
-    if (ctx.getCustomTypeRegistry().hasCodec(enclosingType, type)) {
+    TypeRef<?> replacementType =
+        ctx.getCustomTypeRegistry().replacementTypeFor(enclosingType, type);
+    if (replacementType != null && !replacementType.equals(typeRef)) {
+      beans.addAll(listBeansRecursiveInclusive(replacementType, ctx));
       return beans;
     } else if (type == Optional.class) {
       TypeRef<?> elemType = getTypeArguments(typeRef).get(0);
-      beans.addAll(listBeansRecursiveInclusive(elemType, newCtx));
+      beans.addAll(listBeansRecursiveInclusive(elemType, ctx));
     } else if (isCollection(type) || Iterable.class == type) {
       TypeRef<?> elementType = getElementType(typeRef);
-      beans.addAll(listBeansRecursiveInclusive(elementType, newCtx));
+      beans.addAll(listBeansRecursiveInclusive(elementType, ctx));
     } else if (isMap(type)) {
       Tuple2<TypeRef<?>, TypeRef<?>> mapKeyValueType = 
getMapKeyValueType(typeRef);
-      TypeResolutionContext mapCtx = newCtx;
+      TypeResolutionContext mapCtx = ctx;
       beans.addAll(listBeansRecursiveInclusive(mapKeyValueType.f0, mapCtx));
       beans.addAll(listBeansRecursiveInclusive(mapKeyValueType.f1, mapCtx));
     } else if (type.isArray()) {
       Class<?> arrayComponent = getArrayComponent(type);
-      beans.addAll(listBeansRecursiveInclusive(TypeRef.of(arrayComponent), 
newCtx));
-    } else if (isBean(type, newCtx)) {
+      beans.addAll(listBeansRecursiveInclusive(TypeRef.of(arrayComponent), 
ctx));
+    } else if (isBean(type, ctx)) {
       List<Descriptor> descriptors = Descriptor.getDescriptors(type);
       beans.add(type);
       for (Descriptor descriptor : descriptors) {
         ctx.checkNoCycle(typeRef);
         beans.addAll(
-            listBeansRecursiveInclusive(descriptor.getTypeRef(), 
newCtx.appendTypePath(typeRef)));
+            listBeansRecursiveInclusive(descriptor.getTypeRef(), 
ctx.appendTypePath(typeRef)));
       }
     }
     return beans;
diff --git 
a/java/fory-format/src/main/java/org/apache/fory/format/encoder/ArrayDataForEach.java
 
b/java/fory-format/src/main/java/org/apache/fory/format/encoder/ArrayDataForEach.java
index ed98ef3e..7451aeb2 100644
--- 
a/java/fory-format/src/main/java/org/apache/fory/format/encoder/ArrayDataForEach.java
+++ 
b/java/fory-format/src/main/java/org/apache/fory/format/encoder/ArrayDataForEach.java
@@ -85,7 +85,7 @@ public class ArrayDataForEach extends AbstractExpression {
     if (customEncoder == null) {
       accessType = elemType;
     } else {
-      accessType = TypeRef.of(customEncoder.encodedType());
+      accessType = customEncoder.encodedType();
     }
     CustomTypeHandler customTypeHandler = 
CustomTypeEncoderRegistry.customTypeHandler();
     TypeResolutionContext ctx = new TypeResolutionContext(customTypeHandler, 
true);
diff --git 
a/java/fory-format/src/main/java/org/apache/fory/format/encoder/ArrayEncoderBuilder.java
 
b/java/fory-format/src/main/java/org/apache/fory/format/encoder/ArrayEncoderBuilder.java
index 164c8d04..8937a66e 100644
--- 
a/java/fory-format/src/main/java/org/apache/fory/format/encoder/ArrayEncoderBuilder.java
+++ 
b/java/fory-format/src/main/java/org/apache/fory/format/encoder/ArrayEncoderBuilder.java
@@ -23,6 +23,7 @@ import static org.apache.fory.type.TypeUtils.CLASS_TYPE;
 import static org.apache.fory.type.TypeUtils.getRawType;
 
 import java.lang.reflect.Array;
+import java.util.HashSet;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.fory.Fory;
 import org.apache.fory.codegen.CodeGenerator;
@@ -181,7 +182,8 @@ public class ArrayEncoderBuilder extends 
BaseBinaryEncoderBuilder {
             arrayData,
             elemType,
             (i, value) ->
-                new Expression.Invoke(collection, "add", deserializeFor(value, 
elemType, typeCtx)),
+                new Expression.Invoke(
+                    collection, "add", deserializeFor(value, elemType, 
typeCtx, new HashSet<>())),
             i -> new Expression.Invoke(collection, "add", 
ExpressionUtils.nullValue(elemType)));
     return new Expression.ListExpression(collection, addElemsOp, collection);
   }
diff --git 
a/java/fory-format/src/main/java/org/apache/fory/format/encoder/BaseBinaryEncoderBuilder.java
 
b/java/fory-format/src/main/java/org/apache/fory/format/encoder/BaseBinaryEncoderBuilder.java
index 79f369e9..332a2bcc 100644
--- 
a/java/fory-format/src/main/java/org/apache/fory/format/encoder/BaseBinaryEncoderBuilder.java
+++ 
b/java/fory-format/src/main/java/org/apache/fory/format/encoder/BaseBinaryEncoderBuilder.java
@@ -30,6 +30,7 @@ import java.util.HashSet;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.Set;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.Schema;
 import org.apache.fory.builder.CodecBuilder;
@@ -150,22 +151,18 @@ public abstract class BaseBinaryEncoderBuilder extends 
CodecBuilder {
       Expression inputObject,
       Expression writer,
       TypeRef<?> typeRef,
-      Expression arrowField) {
+      Expression arrowField,
+      Set<TypeRef<?>> visitedCustomTypes) {
     Class<?> rawType = getRawType(typeRef);
-    CustomCodec<?, ?> customHandler = 
customTypeHandler.findCodec(beanType.getRawType(), rawType);
-    if (customHandler != null) {
-      TypeRef<?> rewrittenType = TypeRef.of(customHandler.encodedType());
-      Expression newInputObject =
-          new Expression.StaticInvoke(
-              customTypeHandler.getClass(),
-              "encode",
-              "rewrittenValue",
-              rewrittenType,
-              true,
-              new Expression.Null(beanType, true),
-              inputObject);
+    TypeRef<?> rewrittenType = 
customTypeHandler.replacementTypeFor(beanType.getRawType(), rawType);
+    if (rewrittenType != null
+        && !visitedCustomTypes.contains(typeRef)
+        && !typeRef.equals(rewrittenType)) {
+      Expression newInputObject = customEncode(inputObject, rewrittenType);
+      visitedCustomTypes.add(typeRef);
       Expression doSerialize =
-          serializeFor(ordinal, newInputObject, writer, rewrittenType, 
arrowField);
+          serializeFor(
+              ordinal, newInputObject, writer, rewrittenType, arrowField, 
visitedCustomTypes);
       return new If(
           ExpressionUtils.eqNull(inputObject),
           new Invoke(writer, "setNullAt", ordinal),
@@ -177,7 +174,12 @@ public abstract class BaseBinaryEncoderBuilder extends 
CodecBuilder {
       Expression unwrapped =
           new If(ExpressionUtils.eqNull(inputObject), new 
Expression.Null(elemType), orNull);
       return serializeFor(
-          ordinal, new Expression.Cast(unwrapped, elemType), writer, elemType, 
arrowField);
+          ordinal,
+          new Expression.Cast(unwrapped, elemType),
+          writer,
+          elemType,
+          arrowField,
+          visitedCustomTypes);
     } else if (TypeUtils.isPrimitive(rawType)) {
       return new ListExpression(
           // notNull is by default, no need to call setNotNullAt
@@ -310,7 +312,8 @@ public abstract class BaseBinaryEncoderBuilder extends 
CodecBuilder {
                         value,
                         arrayWriter,
                         Objects.requireNonNull(typeRef.getComponentType()),
-                        arrayElementField));
+                        arrayElementField,
+                        new HashSet<>()));
         return new ListExpression(reset, forEach, arrayWriter);
       }
     } else if (getRawType(typeRef) == Iterable.class) {
@@ -322,7 +325,12 @@ public abstract class BaseBinaryEncoderBuilder extends 
CodecBuilder {
               listFromIterable,
               (i, value) ->
                   serializeFor(
-                      i, value, arrayWriter, 
TypeUtils.getElementType(typeRef), arrayElementField));
+                      i,
+                      value,
+                      arrayWriter,
+                      TypeUtils.getElementType(typeRef),
+                      arrayElementField,
+                      new HashSet<>()));
       return new ListExpression(reset, forEach, arrayWriter);
     } else { // collection
       Invoke size = new Invoke(inputObject, "size", 
TypeUtils.PRIMITIVE_INT_TYPE);
@@ -332,7 +340,12 @@ public abstract class BaseBinaryEncoderBuilder extends 
CodecBuilder {
               inputObject,
               (i, value) ->
                   serializeFor(
-                      i, value, arrayWriter, 
TypeUtils.getElementType(typeRef), arrayElementField));
+                      i,
+                      value,
+                      arrayWriter,
+                      TypeUtils.getElementType(typeRef),
+                      arrayElementField,
+                      new HashSet<>()));
       return new ListExpression(reset, forEach, arrayWriter);
     }
   }
@@ -517,36 +530,20 @@ public abstract class BaseBinaryEncoderBuilder extends 
CodecBuilder {
    * typeToken</code>.
    */
   protected Expression deserializeFor(
-      Expression value, TypeRef<?> typeRef, TypeResolutionContext ctx) {
+      Expression value,
+      TypeRef<?> typeRef,
+      TypeResolutionContext ctx,
+      Set<TypeRef<?>> visitedCustomTypes) {
     Class<?> rawType = getRawType(typeRef);
-    CustomCodec<?, ?> customHandler = 
customTypeHandler.findCodec(beanType.getRawType(), rawType);
-    if (customHandler != null) {
-      TypeRef<?> rewrittenType = TypeRef.of(customHandler.encodedType());
-      Class<?> rawRewrittenType = rewrittenType.getRawType();
-      Expression inputValue;
-      if (rawRewrittenType == byte[].class) {
-        inputValue = Invoke.inlineInvoke(value, "toByteArray", 
TypeRef.of(byte[].class));
-      } else {
-        inputValue = value;
-      }
-      Expression newValue =
-          new Expression.StaticInvoke(
-              customTypeHandler.getClass(),
-              "decode",
-              "decodedValue",
-              typeRef,
-              true,
-              new Expression.Null(beanType, true),
-              new Expression.Null(typeRef, true),
-              inputValue);
-      if (rawRewrittenType == MemoryBuffer.class) {
-        return newValue;
-      } else if (rawRewrittenType == BinaryArray.class) {
-        return newValue;
-      } else if (rawRewrittenType == byte[].class) {
-        return newValue;
-      }
-      return deserializeFor(newValue, rewrittenType, ctx);
+    TypeRef<?> rewrittenType = 
customTypeHandler.replacementTypeFor(beanType.getRawType(), rawType);
+    ;
+    if (rewrittenType != null
+        && !visitedCustomTypes.contains(typeRef)
+        && !typeRef.equals(rewrittenType)) {
+      visitedCustomTypes.add(typeRef);
+      final Expression deserializedValue =
+          deserializeFor(value, rewrittenType, ctx, visitedCustomTypes);
+      return customDecode(typeRef, deserializedValue);
     } else if (rawType == Optional.class) {
       TypeRef<?> elemType = TypeUtils.getTypeArguments(typeRef).get(0);
       return new Expression.StaticInvoke(
@@ -555,7 +552,7 @@ public abstract class BaseBinaryEncoderBuilder extends 
CodecBuilder {
           "optional",
           typeRef,
           true,
-          deserializeFor(value, elemType, ctx));
+          deserializeFor(value, elemType, ctx, visitedCustomTypes));
     } else if (TypeUtils.isPrimitive(rawType) || TypeUtils.isBoxed(rawType)) {
       return value;
     } else if (rawType == BigDecimal.class) {
@@ -575,6 +572,10 @@ public abstract class BaseBinaryEncoderBuilder extends 
CodecBuilder {
           DateTimeUtils.class, "microsToInstant", TypeUtils.INSTANT_TYPE, 
false, value);
     } else if (rawType == String.class) {
       return value;
+    } else if (rawType == MemoryBuffer.class) {
+      return value;
+    } else if (rawType == BinaryArray.class) {
+      return value;
     } else if (rawType.isEnum()) {
       return ExpressionUtils.valueOf(typeRef, value);
     } else if (rawType.isArray()) {
@@ -644,7 +645,9 @@ public abstract class BaseBinaryEncoderBuilder extends 
CodecBuilder {
           new ArrayDataForEach(
               arrayData,
               elemType,
-              (i, value) -> new Invoke(collection, "add", 
deserializeFor(value, elemType, typeCtx)),
+              (i, value) ->
+                  new Invoke(
+                      collection, "add", deserializeFor(value, elemType, 
typeCtx, new HashSet<>())),
               i -> new Invoke(collection, "add", 
ExpressionUtils.nullValue(elemType)));
       return new ListExpression(collection, addElemsOp, collection);
     } catch (Exception e) {
@@ -800,7 +803,7 @@ public abstract class BaseBinaryEncoderBuilder extends 
CodecBuilder {
                 arrayData,
                 elemType,
                 (i, value) -> {
-                  Expression elemValue = deserializeFor(value, elemType, 
typeCtx);
+                  Expression elemValue = deserializeFor(value, elemType, 
typeCtx, new HashSet<>());
                   return new AssignArrayElem(javaArray, elemValue, i);
                 });
         // add javaArray at last as expression value
@@ -816,4 +819,27 @@ public abstract class BaseBinaryEncoderBuilder extends 
CodecBuilder {
   protected Expression deserializeForObject(Expression value, TypeRef<?> 
typeRef) {
     return new Invoke(foryRef, "deserialize", typeRef, value);
   }
+
+  protected Expression customEncode(Expression inputObject, TypeRef<?> 
rewrittenType) {
+    return new Expression.StaticInvoke(
+        customTypeHandler.getClass(),
+        "encode",
+        "rewrittenValue",
+        rewrittenType,
+        true,
+        new Expression.Null(beanType, true),
+        inputObject);
+  }
+
+  protected Expression customDecode(TypeRef<?> typeRef, final Expression 
deserializedValue) {
+    return new Expression.StaticInvoke(
+        customTypeHandler.getClass(),
+        "decode",
+        "decodedValue",
+        typeRef,
+        true,
+        new Expression.Null(beanType, true),
+        new Expression.Null(typeRef, true),
+        deserializedValue);
+  }
 }
diff --git 
a/java/fory-format/src/main/java/org/apache/fory/format/encoder/CustomCodec.java
 
b/java/fory-format/src/main/java/org/apache/fory/format/encoder/CustomCodec.java
index 75dcb111..09811234 100644
--- 
a/java/fory-format/src/main/java/org/apache/fory/format/encoder/CustomCodec.java
+++ 
b/java/fory-format/src/main/java/org/apache/fory/format/encoder/CustomCodec.java
@@ -24,20 +24,29 @@ import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.fory.format.row.binary.BinaryArray;
 import org.apache.fory.format.type.DataTypes;
 import org.apache.fory.memory.MemoryBuffer;
+import org.apache.fory.reflect.TypeRef;
 
+/**
+ * Extension point to customize Fory row codec behavior. Supports intercepting 
types to be written
+ * ({@code encode}) and read ({@code decode}).
+ *
+ * @param <T> the type the codec decodes to (used in Java)
+ * @param <E> the type the codec encodes to (byte representation)
+ */
 public interface CustomCodec<T, E> {
   Field getField(String fieldName);
 
-  Class<E> encodedType();
+  TypeRef<E> encodedType();
 
   E encode(T value);
 
   T decode(E value);
 
+  /** Specialized codec base for encoding and decoding to/from {@link 
MemoryBuffer}. */
   interface MemoryBufferCodec<T> extends CustomCodec<T, MemoryBuffer> {
     @Override
-    default Class<MemoryBuffer> encodedType() {
-      return MemoryBuffer.class;
+    default TypeRef<MemoryBuffer> encodedType() {
+      return TypeRef.of(MemoryBuffer.class);
     }
 
     @Override
@@ -46,10 +55,11 @@ public interface CustomCodec<T, E> {
     }
   }
 
+  /** Specialized codec base for encoding and decoding to/from {@code byte[]}. 
*/
   interface ByteArrayCodec<T> extends CustomCodec<T, byte[]> {
     @Override
-    default Class<byte[]> encodedType() {
-      return byte[].class;
+    default TypeRef<byte[]> encodedType() {
+      return TypeRef.of(byte[].class);
     }
 
     @Override
@@ -58,10 +68,11 @@ public interface CustomCodec<T, E> {
     }
   }
 
+  /** Specialized codec base for encoding and decoding to/from {@link 
BinaryArray}. */
   interface BinaryArrayCodec<T> extends CustomCodec<T, BinaryArray> {
     @Override
-    default Class<BinaryArray> encodedType() {
-      return BinaryArray.class;
+    default TypeRef<BinaryArray> encodedType() {
+      return TypeRef.of(BinaryArray.class);
     }
 
     @Override
@@ -69,4 +80,26 @@ public interface CustomCodec<T, E> {
       return DataTypes.primitiveArrayField(fieldName, DataTypes.int8());
     }
   }
+
+  /**
+   * Specialized codec base for read and write replace of a value, without 
changing its type.
+   * Example use: converting Fory generated implementation into a standard 
user-provided
+   * implementation.
+   */
+  interface InterceptingCodec<T> extends CustomCodec<T, T> {
+    @Override
+    default Field getField(final String fieldName) {
+      return null;
+    }
+
+    @Override
+    default T decode(final T value) {
+      return value;
+    }
+
+    @Override
+    default T encode(final T value) {
+      return value;
+    }
+  }
 }
diff --git 
a/java/fory-format/src/main/java/org/apache/fory/format/encoder/RowEncoderBuilder.java
 
b/java/fory-format/src/main/java/org/apache/fory/format/encoder/RowEncoderBuilder.java
index a82f368c..a1026052 100644
--- 
a/java/fory-format/src/main/java/org/apache/fory/format/encoder/RowEncoderBuilder.java
+++ 
b/java/fory-format/src/main/java/org/apache/fory/format/encoder/RowEncoderBuilder.java
@@ -24,6 +24,7 @@ import static org.apache.fory.type.TypeUtils.getRawType;
 
 import java.lang.reflect.Modifier;
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Optional;
 import java.util.SortedMap;
@@ -175,12 +176,17 @@ public class RowEncoderBuilder extends 
BaseBinaryEncoderBuilder {
   @Override
   public Expression buildEncodeExpression() {
     Reference inputObject = new Reference(ROOT_OBJECT_NAME, 
TypeUtils.OBJECT_TYPE, false);
+    Expression bean = new Expression.Cast(inputObject, beanType, 
ctx.newName(beanClass));
     Reference writer = new Reference(ROOT_ROW_WRITER_NAME, rowWriterTypeToken, 
false);
     Reference schemaExpr = new Reference(SCHEMA_NAME, schemaTypeToken, false);
 
+    CustomCodec<?, ?> customCodec = customTypeHandler.findCodec(beanClass, 
beanClass);
+    if (customCodec != null && customCodec.encodedType().equals(beanType)) {
+      bean = customEncode(bean, beanType);
+    }
+
     int numFields = schema.getFields().size();
     Expression.ListExpression expressions = new Expression.ListExpression();
-    Expression.Cast bean = new Expression.Cast(inputObject, beanType, 
ctx.newName(beanClass));
     // schema field's name must correspond to descriptor's name.
     for (int i = 0; i < numFields; i++) {
       Descriptor d = 
getDescriptorByFieldName(schema.getFields().get(i).getName());
@@ -191,7 +197,8 @@ public class RowEncoderBuilder extends 
BaseBinaryEncoderBuilder {
       Expression.StaticInvoke field =
           new Expression.StaticInvoke(
               DataTypes.class, "fieldOfSchema", ARROW_FIELD_TYPE, false, 
schemaExpr, ordinal);
-      Expression fieldExpr = serializeFor(ordinal, fieldValue, writer, 
fieldType, field);
+      Expression fieldExpr =
+          serializeFor(ordinal, fieldValue, writer, fieldType, field, new 
HashSet<>());
       expressions.add(fieldExpr);
     }
     expressions.add(
@@ -210,51 +217,54 @@ public class RowEncoderBuilder extends 
BaseBinaryEncoderBuilder {
 
     addDecoderMethods();
 
-    if (generatedBeanImpl != null) {
-      return new Expression.Return(
-          new Expression.Reference("new " + generatedBeanImplName + "(row)"));
-    }
-
-    int numFields = schema.getFields().size();
-    List<String> fieldNames = new ArrayList<>(numFields);
-    Expression[] values = new Expression[numFields];
-    Descriptor[] descriptors = new Descriptor[numFields];
     Expression.ListExpression expressions = new Expression.ListExpression();
-    // schema field's name must correspond to descriptor's name.
-    for (int i = 0; i < numFields; i++) {
-      Literal ordinal = Literal.ofInt(i);
-      Descriptor d = 
getDescriptorByFieldName(schema.getFields().get(i).getName());
-      fieldNames.add(d.getName());
-      descriptors[i] = d;
-      TypeRef<?> fieldType = d.getTypeRef();
-      Expression.Variable value = new Expression.Variable(d.getName(), 
nullValue(fieldType));
-      values[i] = value;
-      expressions.add(value);
-      Expression.Invoke isNullAt =
-          new Expression.Invoke(row, "isNullAt", 
TypeUtils.PRIMITIVE_BOOLEAN_TYPE, ordinal);
-      Expression decode =
-          new Expression.If(
-              ExpressionUtils.not(isNullAt),
-              new Expression.Assign(
-                  value, new Expression.Reference(decodeMethodName(i) + 
"(row)", fieldType)));
-      expressions.add(decode);
-    }
     Expression bean;
-    if (RecordUtils.isRecord(beanClass)) {
-      int[] map = RecordUtils.buildRecordComponentMapping(beanClass, 
fieldNames);
-      Expression[] args = new Expression[numFields];
+    if (generatedBeanImpl != null) {
+      bean = new Expression.Reference("new " + generatedBeanImplName + 
"(row)");
+    } else {
+      int numFields = schema.getFields().size();
+      List<String> fieldNames = new ArrayList<>(numFields);
+      Expression[] values = new Expression[numFields];
+      Descriptor[] descriptors = new Descriptor[numFields];
+      // schema field's name must correspond to descriptor's name.
       for (int i = 0; i < numFields; i++) {
-        args[i] = values[map[i]];
+        Literal ordinal = Literal.ofInt(i);
+        Descriptor d = 
getDescriptorByFieldName(schema.getFields().get(i).getName());
+        fieldNames.add(d.getName());
+        descriptors[i] = d;
+        TypeRef<?> fieldType = d.getTypeRef();
+        Expression.Variable value = new Expression.Variable(d.getName(), 
nullValue(fieldType));
+        values[i] = value;
+        expressions.add(value);
+        Expression.Invoke isNullAt =
+            new Expression.Invoke(row, "isNullAt", 
TypeUtils.PRIMITIVE_BOOLEAN_TYPE, ordinal);
+        Expression decode =
+            new Expression.If(
+                ExpressionUtils.not(isNullAt),
+                new Expression.Assign(
+                    value, new Expression.Reference(decodeMethodName(i) + 
"(row)", fieldType)));
+        expressions.add(decode);
       }
-      bean = new Expression.NewInstance(beanType, 
beanType.getRawType().getName(), args);
-    } else {
-      bean = newBean();
-      expressions.add(bean);
-      for (int i = 0; i < values.length; i++) {
-        expressions.add(setFieldValue(bean, descriptors[i], values[i]));
+      if (RecordUtils.isRecord(beanClass)) {
+        int[] map = RecordUtils.buildRecordComponentMapping(beanClass, 
fieldNames);
+        Expression[] args = new Expression[numFields];
+        for (int i = 0; i < numFields; i++) {
+          args[i] = values[map[i]];
+        }
+        bean = new Expression.NewInstance(beanType, 
beanType.getRawType().getName(), args);
+      } else {
+        bean = newBean();
+        expressions.add(bean);
+        for (int i = 0; i < values.length; i++) {
+          expressions.add(setFieldValue(bean, descriptors[i], values[i]));
+        }
       }
     }
 
+    CustomCodec<?, ?> customCodec = customTypeHandler.findCodec(beanClass, 
beanClass);
+    if (customCodec != null && customCodec.encodedType().equals(beanType)) {
+      bean = customDecode(beanType, bean);
+    }
     expressions.add(new Expression.Return(bean));
     return expressions;
   }
@@ -284,7 +294,7 @@ public class RowEncoderBuilder extends 
BaseBinaryEncoderBuilder {
         if (customEncoder == null) {
           columnAccessType = fieldType;
         } else {
-          columnAccessType = TypeRef.of(customEncoder.encodedType());
+          columnAccessType = customEncoder.encodedType();
         }
       }
       String columnAccessMethodName =
@@ -298,7 +308,8 @@ public class RowEncoderBuilder extends 
BaseBinaryEncoderBuilder {
               colType,
               false,
               ordinal);
-      Expression value = new Expression.Return(deserializeFor(columnValue, 
fieldType, typeCtx));
+      Expression value =
+          new Expression.Return(deserializeFor(columnValue, fieldType, 
typeCtx, new HashSet<>()));
       ctx.addMethod(
           decodeMethodName(i),
           value.doGenCode(ctx).code(),
diff --git 
a/java/fory-format/src/main/java/org/apache/fory/format/type/CustomTypeEncoderRegistry.java
 
b/java/fory-format/src/main/java/org/apache/fory/format/type/CustomTypeEncoderRegistry.java
index 09945593..b9520879 100644
--- 
a/java/fory-format/src/main/java/org/apache/fory/format/type/CustomTypeEncoderRegistry.java
+++ 
b/java/fory-format/src/main/java/org/apache/fory/format/type/CustomTypeEncoderRegistry.java
@@ -139,7 +139,7 @@ public class CustomTypeEncoderRegistry {
                     + ")"
                     + codecFieldName
                     + ".encode(fieldValue);",
-                enc.encodedType(),
+                enc.encodedType().getRawType(),
                 reg.getBeanType(),
                 "bean",
                 reg.getFieldType(),
@@ -156,7 +156,7 @@ public class CustomTypeEncoderRegistry {
                 "bean",
                 reg.getFieldType(),
                 "fieldNull",
-                enc.encodedType(),
+                enc.encodedType().getRawType(),
                 "encodedValue");
             ctx.addField(
                 true,
diff --git 
a/java/fory-format/src/main/java/org/apache/fory/format/type/CustomTypeHandler.java
 
b/java/fory-format/src/main/java/org/apache/fory/format/type/CustomTypeHandler.java
index c3cad271..28aa0e14 100644
--- 
a/java/fory-format/src/main/java/org/apache/fory/format/type/CustomTypeHandler.java
+++ 
b/java/fory-format/src/main/java/org/apache/fory/format/type/CustomTypeHandler.java
@@ -22,6 +22,9 @@ package org.apache.fory.format.type;
 import org.apache.fory.annotation.Internal;
 import org.apache.fory.format.encoder.CustomCodec;
 import org.apache.fory.format.encoder.CustomCollectionFactory;
+import org.apache.fory.format.row.binary.BinaryArray;
+import org.apache.fory.memory.MemoryBuffer;
+import org.apache.fory.reflect.TypeRef;
 import org.apache.fory.type.CustomTypeRegistry;
 
 @Internal
@@ -46,8 +49,9 @@ public interface CustomTypeHandler extends CustomTypeRegistry 
{
       Class<?> collectionType, Class<?> elementType);
 
+  /**
+   * Looks up the codec registered for {@code fieldType} inside {@code beanType} and returns the
+   * type that codec encodes to, or {@code null} when {@link #findCodec} finds no codec.
+   */
   @Override
-  default boolean hasCodec(final Class<?> beanType, final Class<?> fieldType) {
-    return findCodec(beanType, fieldType) != null;
+  default TypeRef<?> replacementTypeFor(final Class<?> beanType, final 
Class<?> fieldType) {
+    final CustomCodec<?, ?> codec = findCodec(beanType, fieldType);
+    // A null result signals "no replacement"; otherwise defer to the codec's declared type.
+    return codec == null ? null : codec.encodedType();
   }
 
   @Override
@@ -55,4 +59,10 @@ public interface CustomTypeHandler extends 
CustomTypeRegistry {
       final Class<?> collectionType, final Class<?> elementType) {
     return findCollectionFactory(collectionType, elementType) != null;
   }
+
+  /**
+   * Reports {@code true} only when the raw class of {@code type} is exactly {@link BinaryArray} or
+   * {@link MemoryBuffer}. The comparison uses {@code ==}, so subclasses are not matched.
+   */
+  @Override
+  default boolean isExtraSupportedType(final TypeRef<?> type) {
+    final Class<?> cls = type.getRawType();
+    return cls == BinaryArray.class || cls == MemoryBuffer.class;
+  }
 }
diff --git 
a/java/fory-format/src/main/java/org/apache/fory/format/type/TypeInference.java 
b/java/fory-format/src/main/java/org/apache/fory/format/type/TypeInference.java
index a3196017..e7a67ce4 100644
--- 
a/java/fory-format/src/main/java/org/apache/fory/format/type/TypeInference.java
+++ 
b/java/fory-format/src/main/java/org/apache/fory/format/type/TypeInference.java
@@ -22,12 +22,12 @@ package org.apache.fory.format.type;
 import static org.apache.fory.format.type.DataTypes.field;
 import static org.apache.fory.type.TypeUtils.getRawType;
 
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
-import java.util.stream.Collectors;
 import org.apache.arrow.vector.complex.MapVector;
 import org.apache.arrow.vector.types.DateUnit;
 import org.apache.arrow.vector.types.FloatingPointPrecision;
@@ -39,6 +39,7 @@ import org.apache.arrow.vector.types.pojo.Schema;
 import org.apache.fory.collection.Tuple2;
 import org.apache.fory.format.encoder.CustomCodec;
 import org.apache.fory.format.encoder.CustomCollectionFactory;
+import org.apache.fory.format.row.binary.BinaryArray;
 import org.apache.fory.reflect.TypeRef;
 import org.apache.fory.type.Descriptor;
 import org.apache.fory.type.TypeResolutionContext;
@@ -155,8 +156,12 @@ public class TypeInference {
               true, fieldType.getType(), fieldType.getDictionary(), 
fieldType.getMetadata()),
           result.getChildren());
     } else if (customEncoder != null) {
-      return customEncoder.getField(name);
-    } else if (rawType == boolean.class) {
+      Field replacementField = customEncoder.getField(name);
+      if (replacementField != null) {
+        return replacementField;
+      }
+    }
+    if (rawType == boolean.class) {
       return field(name, DataTypes.notNullFieldType(ArrowType.Bool.INSTANCE));
     } else if (rawType == byte.class) {
       return field(name, DataTypes.notNullFieldType(new ArrowType.Int(8, 
true)));
@@ -209,6 +214,8 @@ public class TypeInference {
       return field(name, FieldType.nullable(ArrowType.Utf8.INSTANCE));
     } else if (rawType.isEnum()) {
       return field(name, FieldType.nullable(ArrowType.Utf8.INSTANCE));
+    } else if (rawType == BinaryArray.class) {
+      return field(name, FieldType.nullable(ArrowType.Binary.INSTANCE));
     } else if (rawType.isArray()) { // array
       Field f =
           inferField(
@@ -230,15 +237,13 @@ public class TypeInference {
       return DataTypes.mapField(name, keyField, valueField);
     } else if (TypeUtils.isBean(rawType, ctx)) { // bean field
       ctx.checkNoCycle(rawType);
-      List<Field> fields =
-          Descriptor.getDescriptors(rawType).stream()
-              .map(
-                  descriptor -> {
-                    String n = 
StringUtils.lowerCamelToLowerUnderscore(descriptor.getName());
-                    TypeRef<?> fieldType = descriptor.getTypeRef();
-                    return inferField(n, fieldType, 
ctx.appendTypePath(rawType));
-                  })
-              .collect(Collectors.toList());
+      List<Descriptor> descriptors = Descriptor.getDescriptors(rawType);
+      List<Field> fields = new ArrayList<>(descriptors.size());
+      for (Descriptor descriptor : descriptors) {
+        String n = 
StringUtils.lowerCamelToLowerUnderscore(descriptor.getName());
+        TypeRef<?> fieldType = descriptor.getTypeRef();
+        fields.add(inferField(n, fieldType, ctx.appendTypePath(rawType)));
+      }
       return DataTypes.structField(name, true, fields);
     } else {
       throw new UnsupportedOperationException(
diff --git 
a/java/fory-format/src/test/java/org/apache/fory/format/encoder/CustomCodecTest.java
 
b/java/fory-format/src/test/java/org/apache/fory/format/encoder/CustomCodecTest.java
index bfda9e42..33b0faa4 100644
--- 
a/java/fory-format/src/test/java/org/apache/fory/format/encoder/CustomCodecTest.java
+++ 
b/java/fory-format/src/test/java/org/apache/fory/format/encoder/CustomCodecTest.java
@@ -32,9 +32,9 @@ import org.apache.arrow.vector.types.pojo.ArrowType;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.fory.format.row.binary.BinaryArray;
 import org.apache.fory.format.row.binary.BinaryRow;
-import org.apache.fory.format.type.DataTypes;
 import org.apache.fory.memory.MemoryBuffer;
 import org.apache.fory.memory.MemoryUtils;
+import org.apache.fory.reflect.TypeRef;
 import org.testng.Assert;
 import org.testng.annotations.Test;
 
@@ -46,6 +46,7 @@ public class CustomCodecTest {
     Encoders.registerCustomCodec(CustomByteBuf2.class, new 
CustomByteBuf2Encoder());
     Encoders.registerCustomCodec(CustomByteBuf3.class, new 
CustomByteBuf3Encoder());
     Encoders.registerCustomCodec(UUID.class, new UuidEncoder());
+    Encoders.registerCustomCodec(InterceptedType.class, new 
InterceptedTypeEncoder());
     Encoders.registerCustomCollectionFactory(
         SortedSet.class, UUID.class, new SortedSetOfUuidDecoder());
   }
@@ -154,8 +155,8 @@ public class CustomCodecTest {
     }
 
     @Override
-    public Class<String> encodedType() {
-      return String.class;
+    public TypeRef<String> encodedType() {
+      return TypeRef.of(String.class);
     }
   }
 
@@ -184,16 +185,6 @@ public class CustomCodecTest {
   }
 
   static class CustomByteBuf3Encoder implements 
CustomCodec.BinaryArrayCodec<CustomByteBuf3> {
-    @Override
-    public Field getField(final String fieldName) {
-      return DataTypes.primitiveArrayField(fieldName, DataTypes.int8());
-    }
-
-    @Override
-    public Class<BinaryArray> encodedType() {
-      return BinaryArray.class;
-    }
-
     @Override
     public BinaryArray encode(final CustomByteBuf3 value) {
       return BinaryArray.fromPrimitiveArray(value.buf);
@@ -240,4 +231,79 @@ public class CustomCodecTest {
       return Long.compareUnsigned(o1.getLeastSignificantBits(), 
o2.getLeastSignificantBits());
     }
   }
+
+  /** Bean interface registered with {@link InterceptedTypeEncoder} in this test's setup. */
+  public interface InterceptedType {
+    int f1();
+  }
+
+  /** Immutable {@link InterceptedType} holding a single int; the codec returns instances of it. */
+  public static class InterceptedTypeImpl implements InterceptedType {
+    private final int f1;
+
+    public InterceptedTypeImpl(final int f1) {
+      this.f1 = f1;
+    }
+
+    @Override
+    public int f1() {
+      return f1;
+    }
+  }
+
+  /**
+   * Intercepting codec whose encoded type is {@link InterceptedType} itself, so values are rewritten
+   * without changing their type: {@code encode} adds 2 to {@code f1()} and {@code decode} adds 3.
+   * A full round trip therefore adds 5 and always yields an {@link InterceptedTypeImpl}.
+   */
+  static class InterceptedTypeEncoder implements 
CustomCodec.InterceptingCodec<InterceptedType> {
+    @Override
+    public TypeRef<InterceptedType> encodedType() {
+      return TypeRef.of(InterceptedType.class);
+    }
+
+    @Override
+    public InterceptedType encode(final InterceptedType value) {
+      return new InterceptedTypeImpl(value.f1() + 2);
+    }
+
+    @Override
+    public InterceptedType decode(final InterceptedType value) {
+      return new InterceptedTypeImpl(value.f1() + 3);
+    }
+  }
+
+  /**
+   * Round-trips a top-level {@link InterceptedType} bean. It checks that both codec rewrites were
+   * applied (+2 on encode, +3 on decode) and that the decoded object is an {@link InterceptedTypeImpl}.
+   */
+  @Test
+  public void testCodecTypeInterception() {
+    final InterceptedType bean = new InterceptedTypeImpl(42);
+    final RowEncoder<InterceptedType> encoder = 
Encoders.bean(InterceptedType.class);
+    final BinaryRow row = encoder.toRow(bean);
+    // Re-point the row at a buffer wrapping row.toBytes() before decoding (presumably so decoding
+    // reads from the serialized bytes rather than the encoder's own buffer).
+    final MemoryBuffer buffer = MemoryUtils.wrap(row.toBytes());
+    row.pointTo(buffer, 0, buffer.size());
+    final InterceptedType deserializedBean = encoder.fromRow(row);
+    Assert.assertEquals(deserializedBean.f1(), bean.f1() + 5);
+    Assert.assertEquals(deserializedBean.getClass(), 
InterceptedTypeImpl.class);
+  }
+
+  /** Bean interface with a nested {@link InterceptedType} field, used by the nested-codec test. */
+  public interface WrapInterceptedType {
+    InterceptedType f1();
+  }
+
+  /** Immutable {@link WrapInterceptedType} holding one nested {@link InterceptedType} value. */
+  public static class WrapInterceptedTypeImpl implements WrapInterceptedType {
+    private final InterceptedType f1;
+
+    public WrapInterceptedTypeImpl(final InterceptedType f1) {
+      this.f1 = f1;
+    }
+
+    @Override
+    public InterceptedType f1() {
+      return f1;
+    }
+  }
+
+  /**
+   * Same round trip as {@code testCodecTypeInterception}, but with the intercepted value nested as a
+   * field of {@link WrapInterceptedType}. It checks that the field was rewritten by +5 in total and
+   * decoded as an {@link InterceptedTypeImpl}.
+   */
+  @Test
+  public void testNestedCodecTypeInterception() {
+    final WrapInterceptedType bean = new WrapInterceptedTypeImpl(new 
InterceptedTypeImpl(42));
+    final RowEncoder<WrapInterceptedType> encoder = 
Encoders.bean(WrapInterceptedType.class);
+    final BinaryRow row = encoder.toRow(bean);
+    // Re-point the row at a buffer wrapping row.toBytes() before decoding.
+    final MemoryBuffer buffer = MemoryUtils.wrap(row.toBytes());
+    row.pointTo(buffer, 0, buffer.size());
+    final WrapInterceptedType deserializedBean = encoder.fromRow(row);
+    Assert.assertEquals(deserializedBean.f1().f1(), bean.f1().f1() + 5);
+    Assert.assertEquals(deserializedBean.f1().getClass(), 
InterceptedTypeImpl.class);
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to