This is an automated email from the ASF dual-hosted git repository.

mbudiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new a33270390b [CALCITE-6244] Improve `Expressions#constant` to allow Java records
a33270390b is described below

commit a33270390b85170f53616d59608f1c7eb35a1889
Author: Wegdan Ghazi <[email protected]>
AuthorDate: Fri Jan 5 16:52:10 2024 +0100

    [CALCITE-6244] Improve `Expressions#constant` to allow Java records
---
 .../calcite/linq4j/tree/ConstantExpression.java    |  72 +++++++++++---
 .../org/apache/calcite/linq4j/util/Compatible.java |  79 +++++++++++++++
 .../apache/calcite/linq4j/util/package-info.java   |  21 ++++
 .../apache/calcite/linq4j/test/ExpressionTest.java |  29 ++++++
 .../calcite/linq4j/test/util/RecordHelper.java     | 107 +++++++++++++++++++++
 5 files changed, 297 insertions(+), 11 deletions(-)

diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConstantExpression.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConstantExpression.java
index 8e96c40b8e..73002c0e14 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConstantExpression.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConstantExpression.java
@@ -16,10 +16,15 @@
  */
 package org.apache.calcite.linq4j.tree;
 
+import org.apache.calcite.linq4j.util.Compatible;
+
 import org.checkerframework.checker.nullness.qual.Nullable;
 
 import java.lang.reflect.Constructor;
 import java.lang.reflect.Field;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.lang.reflect.Modifier;
 import java.lang.reflect.Type;
 import java.math.BigDecimal;
 import java.math.BigInteger;
@@ -216,19 +221,14 @@ public class ConstantExpression extends Expression {
       return writer;
     }
 
-    Constructor constructor = matchingConstructor(value);
+    final Field[] classFields = getClassFields(value.getClass());
+    Constructor constructor = matchingConstructor(value, classFields);
     if (constructor != null) {
       writer.append("new ").append(value.getClass());
       list(writer,
-          Arrays.stream(value.getClass().getFields())
+          Arrays.stream(classFields)
               // <@Nullable Object> is needed for CheckerFramework
-              .<@Nullable Object>map(field -> {
-                try {
-                  return field.get(value);
-                } catch (IllegalAccessException e) {
-                  throw new RuntimeException(e);
-                }
-              })
+              .<@Nullable Object>map(field -> getFieldValue(value, field))
               .collect(Collectors.toList()),
           "(\n", ",\n", ")");
       return writer;
@@ -303,8 +303,7 @@ public class ConstantExpression extends Expression {
     return writer.append(end);
   }
 
-  private static @Nullable Constructor matchingConstructor(Object value) {
-    final Field[] fields = value.getClass().getFields();
+  private static @Nullable Constructor matchingConstructor(Object value, 
Field[] fields) {
     for (Constructor<?> constructor : value.getClass().getConstructors()) {
       if (argsMatchFields(fields, constructor.getParameterTypes())) {
         return constructor;
@@ -356,6 +355,57 @@ public class ConstantExpression extends Expression {
     buf.append('"');
   }
 
+  /**
+   * Reads the value of {@code field} from {@code source}. Records are read
+   * through their public accessor method, other classes directly through
+   * the field.
+   */
+  private static @Nullable Object getFieldValue(Object source, Field field) {
+    return isRecord(source.getClass())
+        ? getValueFromGetterMethod(source, field)
+        : getValueFromField(source, field);
+  }
+
+  /**
+   * Obtains the value of {@code field} by invoking the matching public
+   * zero-argument accessor on {@code source}.
+   *
+   * @throws IllegalArgumentException if no accessor exists or it cannot be invoked
+   */
+  private static @Nullable Object getValueFromGetterMethod(Object source, Field field) {
+    final Method getter = findPublicGetter(field, source.getClass().getMethods());
+    try {
+      return getter.invoke(source);
+    } catch (IllegalAccessException | InvocationTargetException e) {
+      throw new IllegalArgumentException("Could not invoke getter method for field: "
+          + field.getName(), e);
+    }
+  }
+
+  /**
+   * Reads {@code field} directly from {@code source} via reflection.
+   *
+   * @throws IllegalArgumentException if the field is not accessible
+   */
+  private static @Nullable Object getValueFromField(Object source, Field field) {
+    final @Nullable Object value;
+    try {
+      value = field.get(source);
+    } catch (IllegalAccessException e) {
+      throw new IllegalArgumentException("Could not get field value for field: "
+          + field.getName(), e);
+    }
+    return value;
+  }
+
+  /**
+   * Returns the fields whose values are written, in order, as the arguments
+   * of the {@code new} expression generated for an instance of {@code clazz}.
+   *
+   * <p>For a record, these are its non-static declared fields, which are the
+   * backing fields of its components. Static fields declared in the record
+   * body are excluded because they are not constructor arguments; including
+   * them would prevent the canonical constructor from matching. For any
+   * other class, these are its public fields, as returned by
+   * {@link Class#getFields()}.
+   */
+  private static Field[] getClassFields(Class<?> clazz) {
+    if (!isRecord(clazz)) {
+      return clazz.getFields();
+    }
+    // NOTE(review): Class#getDeclaredFields does not guarantee declaration
+    // order, but the result is written positionally as constructor arguments.
+    // Class#getRecordComponents (JDK 16+) would guarantee component order.
+    return Arrays.stream(clazz.getDeclaredFields())
+        .filter(field -> !Modifier.isStatic(field.getModifiers()))
+        .toArray(Field[]::new);
+  }
+
+  /** Returns whether {@code clazz} is a record, as reported by {@link Compatible}. */
+  private static boolean isRecord(Class<?> clazz) {
+    final Compatible compatible = Compatible.INSTANCE;
+    return compatible.isRecord(clazz);
+  }
+
+  /**
+   * Finds, among {@code methods}, the public zero-argument method whose name
+   * and return type match {@code field}.
+   *
+   * @throws IllegalArgumentException if there is no such method; the message
+   *     names the field so the failing record component can be identified
+   */
+  private static Method findPublicGetter(Field field, Method[] methods) {
+    return Arrays.stream(methods)
+        .filter(method -> isFieldGetter(field, method))
+        .findFirst()
+        .orElseThrow(() ->
+            new IllegalArgumentException("Could not find public getter method for field: "
+                + field.getName()));
+  }
+
+  /**
+   * Returns whether {@code method} is a public, zero-argument method with the
+   * same name and return type as {@code field}.
+   */
+  private static boolean isFieldGetter(Field field, Method method) {
+    if (method.getParameterCount() != 0
+        || !Modifier.isPublic(method.getModifiers())) {
+      return false;
+    }
+    return nameMatchesGetter(field, method)
+        && method.getReturnType().equals(field.getType());
+  }
+
+  /** Returns whether {@code method} has exactly the same name as {@code field}. */
+  private static boolean nameMatchesGetter(Field field, Method method) {
+    final String fieldName = field.getName();
+    return fieldName.equals(method.getName());
+  }
+
   @Override public boolean equals(@Nullable Object o) {
     // REVIEW: Should constants with the same value and different type
     // (e.g. 3L and 3) be considered equal.
diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/util/Compatible.java b/linq4j/src/main/java/org/apache/calcite/linq4j/util/Compatible.java
new file mode 100644
index 0000000000..63781e0641
--- /dev/null
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/util/Compatible.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.linq4j.util;
+
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+import java.lang.invoke.MethodHandle;
+import java.lang.invoke.MethodHandles;
+import java.lang.invoke.MethodType;
+import java.lang.reflect.Proxy;
+import java.util.Locale;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Compatibility layer.
+ *
+ * <p>Gives access to functionality of newer JDK versions when it is
+ * available at runtime. On older JDKs the functionality reports itself as
+ * absent; for example, {@link #isRecord} returns false.
+ */
+public interface Compatible {
+  Compatible INSTANCE = new Compatible.Factory().create();
+
+  /** Tells whether the given class is a JDK 16+ record. */
+  <T> boolean isRecord(Class<T> clazz);
+
+  /** Creates an implementation of {@link Compatible} suitable for the current environment. */
+  class Factory {
+    /** Handle to {@code Class.isRecord()}, or null if the running JDK lacks it. */
+    private static final @Nullable MethodHandle IS_RECORD =
+        tryGetIsRecordMethod(MethodHandles.lookup());
+
+    Compatible create() {
+      return (Compatible) Proxy.newProxyInstance(
+          Compatible.class.getClassLoader(),
+          new Class<?>[]{Compatible.class},
+          (proxy, method, args) -> {
+            // The proxy also dispatches Object's hashCode, equals and toString
+            // to this handler. Returning null for them would make hashCode and
+            // equals throw NullPointerException when the result is unboxed.
+            if (method.getDeclaringClass() == Object.class) {
+              switch (method.getName()) {
+              case "equals":
+                return proxy == args[0];
+              case "hashCode":
+                return System.identityHashCode(proxy);
+              default:
+                return "Compatible";
+              }
+            }
+            if ("isRecord".equals(method.getName())) {
+              return isRecord(requireNonNull(args[0], "args[0]"));
+            }
+            // Fail loudly rather than return null for a method with no implementation.
+            throw new UnsupportedOperationException("Unsupported method: " + method);
+          });
+    }
+
+    /** Invokes {@code Class.isRecord()} on {@code clazz}; false if the JDK lacks it. */
+    private static boolean isRecord(Object clazz) {
+      if (IS_RECORD == null) {
+        return false;
+      }
+
+      try {
+        return (boolean) IS_RECORD.invoke(clazz);
+      } catch (Throwable e) {
+        throw new RuntimeException(
+            String.format(Locale.ROOT, "Failed to invoke %s on %s", IS_RECORD, clazz), e);
+      }
+    }
+
+    /** Looks up {@code Class.isRecord()}, returning null if it does not exist. */
+    private static @Nullable MethodHandle tryGetIsRecordMethod(MethodHandles.Lookup lookup) {
+      try {
+        MethodType methodType = MethodType.methodType(boolean.class);
+        return lookup.findVirtual(Class.class, "isRecord", methodType);
+      } catch (NoSuchMethodException | IllegalAccessException e) {
+        return null;
+      }
+    }
+  }
+}
diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/util/package-info.java b/linq4j/src/main/java/org/apache/calcite/linq4j/util/package-info.java
new file mode 100644
index 0000000000..d1f427e9ad
--- /dev/null
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/util/package-info.java
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Provides utility classes.
+ */
+package org.apache.calcite.linq4j.util;
diff --git a/linq4j/src/test/java/org/apache/calcite/linq4j/test/ExpressionTest.java b/linq4j/src/test/java/org/apache/calcite/linq4j/test/ExpressionTest.java
index 8aa7e6185a..895d509d00 100644
--- a/linq4j/src/test/java/org/apache/calcite/linq4j/test/ExpressionTest.java
+++ b/linq4j/src/test/java/org/apache/calcite/linq4j/test/ExpressionTest.java
@@ -39,11 +39,13 @@ import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Sets;
 
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
 
 import java.lang.reflect.Modifier;
 import java.lang.reflect.Type;
 import java.math.BigDecimal;
 import java.math.BigInteger;
+import java.nio.file.Path;
 import java.util.AbstractList;
 import java.util.Arrays;
 import java.util.Collections;
@@ -59,6 +61,8 @@ import java.util.TreeSet;
 
 import static org.apache.calcite.linq4j.test.BlockBuilderBase.ONE;
 import static org.apache.calcite.linq4j.test.BlockBuilderBase.TWO;
+import static org.apache.calcite.linq4j.test.util.RecordHelper.createInstance;
+import static 
org.apache.calcite.linq4j.test.util.RecordHelper.createRecordClass;
 
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
@@ -883,6 +887,31 @@ public class ExpressionTest {
             Expressions.constant(Linq4jTest.emps)));
   }
 
+  /** Test case for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-6244">[CALCITE-6244]
+   * Improve Expressions#constant to allow Java records</a>. */
+  @Test void testWriteRecordConstant(@TempDir Path tempDir) {
+    final Class<?> recordClass = createRecordClass(tempDir, "RecordModel");
+    final Object constant =
+        ImmutableSet.of(createInstance(recordClass, "test1", 1),
+            createInstance(recordClass, "test2", 2),
+            createInstance(recordClass, "test3", 3),
+            createInstance(recordClass, "test4", 4));
+
+    // Each record is written as a call to its (String, int) constructor
+    final String expected = "com.google.common.collect.ImmutableSet.of("
+        + "new RecordModel(\n  \"test1\",\n  1),"
+        + "new RecordModel(\n  \"test2\",\n  2),"
+        + "new RecordModel(\n  \"test3\",\n  3),"
+        + "new RecordModel(\n  \"test4\",\n  4))";
+    assertEquals(expected,
+        Expressions.toString(Expressions.constant(constant)));
+  }
+
   @Test void testWriteArray() {
     assertEquals(
         "1 + integers[2 + index]",
diff --git a/linq4j/src/test/java/org/apache/calcite/linq4j/test/util/RecordHelper.java b/linq4j/src/test/java/org/apache/calcite/linq4j/test/util/RecordHelper.java
new file mode 100644
index 0000000000..e95501ebe6
--- /dev/null
+++ b/linq4j/src/test/java/org/apache/calcite/linq4j/test/util/RecordHelper.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.linq4j.test.util;
+
+import org.opentest4j.TestAbortedException;
+
+import java.io.IOException;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationTargetException;
+import java.net.MalformedURLException;
+import java.net.URL;
+import java.net.URLClassLoader;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.Locale;
+import javax.tools.JavaCompiler;
+import javax.tools.ToolProvider;
+
+/**
+ * Helper that dynamically compiles a Java record class, when the running JDK
+ * supports records, and creates instances of it.
+ */
+public final class RecordHelper {
+
+  private static final String RECORD_TEMPLATE = "public record %s(String name, int count) {}";
+  private static final String JAVA_FILE_NAME_TEMPLATE = "%s.java";
+
+  private RecordHelper() {}
+
+  /** Creates a Java record, aborts if the JDK is non-compatible. */
+  public static Class<?> createRecordClass(Path tempDir, String className) {
+    if (canSupportRecords()) {
+      return compileAndLoadClass(tempDir, className);
+    } else {
+      throw new TestAbortedException("Records not supported");
+    }
+  }
+
+  /** Returns whether the running JDK has {@code Class#isRecord}, i.e. supports records. */
+  static boolean canSupportRecords() {
+    try {
+      Class.class.getMethod("isRecord");
+      return true;
+    } catch (NoSuchMethodException e) {
+      return false;
+    }
+  }
+
+  /** Compiles a Java record and loads it from {@code tempDir}. */
+  private static Class<?> compileAndLoadClass(Path tempDir, String className) {
+    createAndCompileTempClass(tempDir, className);
+
+    try {
+      return Class.forName(className,
+          true,
+          URLClassLoader.newInstance(new URL[] {tempDir.toUri().toURL()}));
+    } catch (ClassNotFoundException | MalformedURLException e) {
+      throw new IllegalArgumentException("Could not load class.", e);
+    }
+  }
+
+  /**
+   * Writes the source of a record named {@code className} into
+   * {@code tempDir} and compiles it with the system Java compiler.
+   *
+   * @throws TestAbortedException if no system Java compiler is available
+   * @throws IllegalArgumentException if the source file cannot be written
+   * @throws IllegalStateException if compilation fails
+   */
+  public static void createAndCompileTempClass(Path tempDir, String className) {
+    final JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
+    if (compiler == null) {
+      // A JRE without javac cannot run this test; abort instead of failing with an NPE.
+      throw new TestAbortedException("No system Java compiler available");
+    }
+    final String classSourceCode =
+        String.format(Locale.ROOT, RECORD_TEMPLATE, className);
+    final Path tempJavaClassFile =
+        tempDir.resolve(String.format(Locale.ROOT, JAVA_FILE_NAME_TEMPLATE, className));
+    try {
+      Files.write(tempJavaClassFile, classSourceCode.getBytes(StandardCharsets.UTF_8));
+    } catch (IOException e) {
+      throw new IllegalArgumentException("Could not write file.", e);
+    }
+    // Tool#run returns 0 on success. Without this check a compile error would
+    // only surface later as a confusing "Could not load class." failure.
+    final int exitCode =
+        compiler.run(null, null, null, tempJavaClassFile.toAbsolutePath().toString());
+    if (exitCode != 0) {
+      throw new IllegalStateException("Could not compile " + tempJavaClassFile
+          + " (javac exit code " + exitCode + ")");
+    }
+  }
+
+  /**
+   * Creates an instance of {@code clazz} through its {@code (String, int)}
+   * constructor.
+   *
+   * @throws IllegalArgumentException if the constructor is missing or fails
+   */
+  public static Object createInstance(Class<?> clazz, String nameFieldValue,
+      int countFieldValue) {
+    final Constructor<?> constructor;
+    try {
+      constructor = clazz.getDeclaredConstructor(String.class, int.class);
+    } catch (NoSuchMethodException e) {
+      throw new IllegalArgumentException("Could not find constructor", e);
+    }
+    try {
+      return constructor.newInstance(nameFieldValue, countFieldValue);
+    } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) {
+      throw new IllegalArgumentException("Could not create instance", e);
+    }
+  }
+}

Reply via email to