This is an automated email from the ASF dual-hosted git repository.
mbudiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git
The following commit(s) were added to refs/heads/main by this push:
new a33270390b [CALCITE-6244] Improve `Expressions#constant` to allow Java records
a33270390b is described below
commit a33270390b85170f53616d59608f1c7eb35a1889
Author: Wegdan Ghazi <[email protected]>
AuthorDate: Fri Jan 5 16:52:10 2024 +0100
[CALCITE-6244] Improve `Expressions#constant` to allow Java records
---
.../calcite/linq4j/tree/ConstantExpression.java | 72 +++++++++++---
.../org/apache/calcite/linq4j/util/Compatible.java | 79 +++++++++++++++
.../apache/calcite/linq4j/util/package-info.java | 21 ++++
.../apache/calcite/linq4j/test/ExpressionTest.java | 29 ++++++
.../calcite/linq4j/test/util/RecordHelper.java | 107 +++++++++++++++++++++
5 files changed, 297 insertions(+), 11 deletions(-)
diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConstantExpression.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConstantExpression.java
index 8e96c40b8e..73002c0e14 100644
--- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConstantExpression.java
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConstantExpression.java
@@ -16,10 +16,15 @@
  */
 package org.apache.calcite.linq4j.tree;

+import org.apache.calcite.linq4j.util.Compatible;
+
 import org.checkerframework.checker.nullness.qual.Nullable;

 import java.lang.reflect.Constructor;
 import java.lang.reflect.Field;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.lang.reflect.Modifier;
 import java.lang.reflect.Type;
 import java.math.BigDecimal;
 import java.math.BigInteger;
@@ -216,19 +221,14 @@ public class ConstantExpression extends Expression {
       return writer;
     }

-    Constructor constructor = matchingConstructor(value);
+    final Field[] classFields = getClassFields(value.getClass());
+    Constructor constructor = matchingConstructor(value, classFields);
     if (constructor != null) {
       writer.append("new ").append(value.getClass());
       list(writer,
-          Arrays.stream(value.getClass().getFields())
+          Arrays.stream(classFields)
               // <@Nullable Object> is needed for CheckerFramework
-              .<@Nullable Object>map(field -> {
-                try {
-                  return field.get(value);
-                } catch (IllegalAccessException e) {
-                  throw new RuntimeException(e);
-                }
-              })
+              .<@Nullable Object>map(field -> getFieldValue(value, field))
               .collect(Collectors.toList()),
           "(\n", ",\n", ")");
       return writer;
@@ -303,8 +303,7 @@ public class ConstantExpression extends Expression {
     return writer.append(end);
   }

-  private static @Nullable Constructor matchingConstructor(Object value) {
-    final Field[] fields = value.getClass().getFields();
+  private static @Nullable Constructor matchingConstructor(Object value, Field[] fields) {
     for (Constructor<?> constructor : value.getClass().getConstructors()) {
       if (argsMatchFields(fields, constructor.getParameterTypes())) {
         return constructor;
@@ -356,6 +355,57 @@ public class ConstantExpression extends Expression {
     buf.append('"');
   }

+  private static @Nullable Object getFieldValue(Object source, Field field) {
+    if (isRecord(source.getClass())) {
+      return getValueFromGetterMethod(source, field);
+    }
+    return getValueFromField(source, field);
+  }
+
+  private static @Nullable Object getValueFromGetterMethod(Object source, Field field) {
+    try {
+      return findPublicGetter(field, source.getClass().getMethods()).invoke(source);
+    } catch (IllegalAccessException | InvocationTargetException e) {
+      throw new IllegalArgumentException("Could not invoke getter method for field: "
+          + field.getName(), e);
+    }
+  }
+
+  private static @Nullable Object getValueFromField(Object source, Field field) {
+    try {
+      return field.get(source);
+    } catch (IllegalAccessException e) {
+      throw new IllegalArgumentException("Could not get field value for field: "
+          + field.getName(), e);
+    }
+  }
+
+  private static Field[] getClassFields(Class<?> clazz) {
+    return isRecord(clazz) ? clazz.getDeclaredFields() : clazz.getFields();
+  }
+
+  private static boolean isRecord(Class<?> clazz) {
+    return Compatible.INSTANCE.isRecord(clazz);
+  }
+
+  private static Method findPublicGetter(Field field, Method[] methods) {
+    return Arrays.stream(methods)
+        .filter(method -> isFieldGetter(field, method))
+        .findFirst()
+        .orElseThrow(() -> new IllegalArgumentException("Could not get field value"));
+  }
+
+  private static boolean isFieldGetter(Field field, Method method) {
+    return method.getReturnType().equals(field.getType())
+        && Modifier.isPublic(method.getModifiers())
+        && method.getParameterCount() == 0
+        && nameMatchesGetter(field, method);
+  }
+
+  private static boolean nameMatchesGetter(Field field, Method method) {
+    return method.getName().equals(field.getName());
+  }
+
   @Override public boolean equals(@Nullable Object o) {
     // REVIEW: Should constants with the same value and different type
     // (e.g. 3L and 3) be considered equal.
diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/util/Compatible.java b/linq4j/src/main/java/org/apache/calcite/linq4j/util/Compatible.java
new file mode 100644
index 0000000000..63781e0641
--- /dev/null
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/util/Compatible.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.linq4j.util;
+
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+import java.lang.invoke.MethodHandle;
+import java.lang.invoke.MethodHandles;
+import java.lang.invoke.MethodType;
+import java.lang.reflect.Proxy;
+import java.util.Locale;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Compatibility layer.
+ *
+ * <p>Allows to use advanced functionality if the latest JDK version is present.
+ */
+public interface Compatible {
+  Compatible INSTANCE = new Compatible.Factory().create();
+
+  /** Tells whether the given class is a JDK 16+ record. */
+  <T> boolean isRecord(Class<T> clazz);
+
+  /** Creates an implementation of {@link Compatible} suitable for the current environment. */
+  class Factory {
+    private static final @Nullable MethodHandle IS_RECORD =
+        tryGetIsRecordMethod(MethodHandles.lookup());
+
+    Compatible create() {
+      return (Compatible) Proxy.newProxyInstance(
+          Compatible.class.getClassLoader(),
+          new Class<?>[]{Compatible.class},
+          (proxy, method, args) -> {
+            if ("isRecord".equals(method.getName())) {
+              return isRecord(requireNonNull(args[0], "args[0]"));
+            }
+            return null;
+          });
+    }
+
+    private static boolean isRecord(Object clazz) {
+      if (IS_RECORD == null) {
+        return false;
+      }
+
+      try {
+        return (boolean) IS_RECORD.invoke(clazz);
+      } catch (Throwable e) {
+        throw new RuntimeException(
+            String.format(Locale.ROOT, "Failed to invoke %s on %s", IS_RECORD, clazz), e);
+      }
+    }
+
+    private static @Nullable MethodHandle tryGetIsRecordMethod(MethodHandles.Lookup lookup) {
+      try {
+        MethodType methodType = MethodType.methodType(boolean.class);
+        return lookup.findVirtual(Class.class, "isRecord", methodType);
+      } catch (NoSuchMethodException | IllegalAccessException e) {
+        return null;
+      }
+    }
+  }
+}
diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/util/package-info.java b/linq4j/src/main/java/org/apache/calcite/linq4j/util/package-info.java
new file mode 100644
index 0000000000..d1f427e9ad
--- /dev/null
+++ b/linq4j/src/main/java/org/apache/calcite/linq4j/util/package-info.java
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Provides utility classes.
+ */
+package org.apache.calcite.linq4j.util;
diff --git a/linq4j/src/test/java/org/apache/calcite/linq4j/test/ExpressionTest.java b/linq4j/src/test/java/org/apache/calcite/linq4j/test/ExpressionTest.java
index 8aa7e6185a..895d509d00 100644
--- a/linq4j/src/test/java/org/apache/calcite/linq4j/test/ExpressionTest.java
+++ b/linq4j/src/test/java/org/apache/calcite/linq4j/test/ExpressionTest.java
@@ -39,11 +39,13 @@ import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Sets;

 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;

 import java.lang.reflect.Modifier;
 import java.lang.reflect.Type;
 import java.math.BigDecimal;
 import java.math.BigInteger;
+import java.nio.file.Path;
 import java.util.AbstractList;
 import java.util.Arrays;
 import java.util.Collections;
@@ -59,6 +61,8 @@ import java.util.TreeSet;

 import static org.apache.calcite.linq4j.test.BlockBuilderBase.ONE;
 import static org.apache.calcite.linq4j.test.BlockBuilderBase.TWO;
+import static org.apache.calcite.linq4j.test.util.RecordHelper.createInstance;
+import static org.apache.calcite.linq4j.test.util.RecordHelper.createRecordClass;

 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
@@ -883,6 +887,31 @@ public class ExpressionTest {
             Expressions.constant(Linq4jTest.emps)));
   }

+  /** Test case for
+   * <a href="https://issues.apache.org/jira/browse/CALCITE-6244">[CALCITE-6244]
+   * Allow passing record as constant expression</a>. */
+  @Test void testWriteRecordConstant(@TempDir Path tempDir) {
+    Class<?> recordClass = createRecordClass(tempDir, "RecordModel");
+
+    // Call constructor for record
+    assertEquals(
+        "com.google.common.collect.ImmutableSet.of(new RecordModel(\n"
+            + "  \"test1\",\n"
+            + "  1),new RecordModel(\n"
+            + "  \"test2\",\n"
+            + "  2),new RecordModel(\n"
+            + "  \"test3\",\n"
+            + "  3),new RecordModel(\n"
+            + "  \"test4\",\n"
+            + "  4))",
+        Expressions.toString(
+            Expressions.constant(
+                ImmutableSet.of(createInstance(recordClass, "test1", 1),
+                    createInstance(recordClass, "test2", 2),
+                    createInstance(recordClass, "test3", 3),
+                    createInstance(recordClass, "test4", 4)))));
+  }
+
   @Test void testWriteArray() {
     assertEquals(
         "1 + integers[2 + index]",
diff --git a/linq4j/src/test/java/org/apache/calcite/linq4j/test/util/RecordHelper.java b/linq4j/src/test/java/org/apache/calcite/linq4j/test/util/RecordHelper.java
new file mode 100644
index 0000000000..e95501ebe6
--- /dev/null
+++ b/linq4j/src/test/java/org/apache/calcite/linq4j/test/util/RecordHelper.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.linq4j.test.util;
+
+import org.opentest4j.TestAbortedException;
+
+import java.io.IOException;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationTargetException;
+import java.net.MalformedURLException;
+import java.net.URL;
+import java.net.URLClassLoader;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.Locale;
+import javax.tools.JavaCompiler;
+import javax.tools.ToolProvider;
+
+/**
+ * Helper that compiles a record instance dynamically, if compatible, and initializes instances of
+ * Java records.
+ */
+public class RecordHelper {
+
+  private RecordHelper(){}
+
+  private static final String RECORD_TEMPLATE = "public record %s(String name, int count) {}";
+  private static final String JAVA_FILE_NAME_TEMPLATE = "%s.java";
+
+  /** Creates a Java record, aborts if the JDK is non-compatible. */
+  public static Class<?> createRecordClass(Path tempDir, String className) {
+    if (canSupportRecords()) {
+      return compileAndLoadClass(tempDir, className);
+    } else {
+      throw new TestAbortedException("Records not supported");
+    }
+  }
+
+  static boolean canSupportRecords() {
+    try {
+      Class.class.getMethod("isRecord");
+      return true;
+    } catch (NoSuchMethodException e) {
+      return false;
+    }
+  }
+
+  /** Compiles and loads a Java record dynamically. */
+  private static Class<?> compileAndLoadClass(Path tempDir, String className) {
+    createAndCompileTempClass(tempDir, className);
+
+    try {
+      return Class.forName(className,
+          true,
+          URLClassLoader.newInstance(new URL[] {tempDir.toUri().toURL() }));
+    } catch (ClassNotFoundException | MalformedURLException e) {
+      throw new IllegalArgumentException("Could not load class.");
+    }
+  }
+
+  /** Creates a temporary Java record and compiles it. */
+  public static void createAndCompileTempClass(Path tempDir, String className) {
+    String classSourceCode =
+        String.format(Locale.ROOT, RECORD_TEMPLATE, className);
+    JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
+
+
+    Path tempJavaClassFile =
+        tempDir.resolve(String.format(Locale.ROOT, JAVA_FILE_NAME_TEMPLATE, className));
+    try {
+      Files.write(tempJavaClassFile, classSourceCode.getBytes(StandardCharsets.UTF_8));
+    } catch (IOException e) {
+      throw new IllegalArgumentException("Could not write file.");
+    }
+    compiler.run(null, null, null, tempJavaClassFile.toAbsolutePath().toString());
+  }
+
+  /** Creates new instances of the loaded class. */
+  public static Object createInstance(Class<?> clazz, String nameFieldValue, int countFieldValue) {
+    Constructor<?> constructor = null;
+    try {
+      constructor = clazz.getDeclaredConstructor(String.class, int.class);
+    } catch (NoSuchMethodException e) {
+      throw new IllegalArgumentException("Could not find constructor");
+    }
+    try {
+      return constructor.newInstance(nameFieldValue, countFieldValue);
+    } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) {
+      throw new IllegalArgumentException("Could not create instance");
+    }
+  }
+}