This is an automated email from the ASF dual-hosted git repository. jark pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit e13ae26e806eff403787d5b75e90008bcc13d8dc Author: luoyuxia <[email protected]> AuthorDate: Fri Jul 8 11:42:32 2022 +0800 [FLINK-28451][table][hive] Use UserCodeClassloader instead of the current thread's classloader to load function This closes #20211 --- .../factories/HiveFunctionDefinitionFactory.java | 66 ++++++++++------ .../flink/table/functions/hive/HiveFunction.java | 2 +- .../table/functions/hive/HiveFunctionWrapper.java | 79 +++++++++++-------- .../table/functions/hive/HiveGenericUDAF.java | 7 +- .../flink/table/functions/hive/HiveGenericUDF.java | 6 +- .../table/functions/hive/HiveGenericUDTF.java | 2 +- .../table/functions/hive/HiveScalarFunction.java | 16 ++-- .../flink/table/functions/hive/HiveSimpleUDF.java | 6 +- .../apache/flink/table/module/hive/HiveModule.java | 16 +++- .../flink/table/module/hive/HiveModuleFactory.java | 2 +- .../table/planner/delegation/hive/HiveParser.java | 3 +- .../hive/parse/HiveParserDDLSemanticAnalyzer.java | 11 ++- .../functions/hive/HiveFunctionWrapperTest.java | 91 ++++++++++++++++++++++ .../table/functions/hive/HiveGenericUDAFTest.java | 3 +- .../table/functions/hive/HiveGenericUDFTest.java | 5 +- .../table/functions/hive/HiveGenericUDTFTest.java | 4 +- .../table/functions/hive/HiveSimpleUDFTest.java | 5 +- .../client/gateway/context/ExecutionContext.java | 9 --- .../table/client/gateway/local/LocalExecutor.java | 11 +-- .../flink/table/catalog/FunctionCatalog.java | 7 +- .../table/factories/FunctionDefinitionFactory.java | 38 ++++++++- .../factories/TestFunctionDefinitionFactory.java} | 27 +++---- .../table/planner/factories/TestValuesCatalog.java | 8 +- .../table/planner/catalog/CatalogTableITCase.scala | 40 +++++++++- 24 files changed, 334 insertions(+), 130 deletions(-) diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/factories/HiveFunctionDefinitionFactory.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/factories/HiveFunctionDefinitionFactory.java index 934f77902bd..62458186f60 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/factories/HiveFunctionDefinitionFactory.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/catalog/hive/factories/HiveFunctionDefinitionFactory.java @@ -21,9 +21,11 @@ package org.apache.flink.table.catalog.hive.factories; import org.apache.flink.connectors.hive.HiveTableFactory; import org.apache.flink.table.api.TableException; import org.apache.flink.table.catalog.CatalogFunction; +import org.apache.flink.table.catalog.FunctionLanguage; import org.apache.flink.table.catalog.hive.client.HiveShim; import org.apache.flink.table.factories.FunctionDefinitionFactory; import org.apache.flink.table.functions.FunctionDefinition; +import org.apache.flink.table.functions.UserDefinedFunction; import org.apache.flink.table.functions.UserDefinedFunctionHelper; import org.apache.flink.table.functions.hive.HiveFunctionWrapper; import org.apache.flink.table.functions.hive.HiveGenericUDAF; @@ -54,17 +56,39 @@ public class HiveFunctionDefinitionFactory implements FunctionDefinitionFactory @Override public FunctionDefinition createFunctionDefinition( - String name, CatalogFunction catalogFunction) { - if (catalogFunction.isGeneric()) { - return createFunctionDefinitionFromFlinkFunction(name, catalogFunction); + String name, CatalogFunction catalogFunction, Context context) { + if (isFlinkFunction(catalogFunction, context.getClassLoader())) { + return createFunctionDefinitionFromFlinkFunction(name, catalogFunction, context); } - return createFunctionDefinitionFromHiveFunction(name, catalogFunction.getClassName()); + return createFunctionDefinitionFromHiveFunction( + name, catalogFunction.getClassName(), context); } public FunctionDefinition createFunctionDefinitionFromFlinkFunction( - String name, CatalogFunction catalogFunction) { + String name, CatalogFunction catalogFunction, Context context) { return UserDefinedFunctionHelper.instantiateFunction( - Thread.currentThread().getContextClassLoader(), null, name, catalogFunction); + context.getClassLoader(), null, name, catalogFunction); + } + + /** + * Distinguish if the function is a generic function. + * + * @return whether the function is a generic function + */ + private boolean isFlinkFunction(CatalogFunction catalogFunction, ClassLoader classLoader) { + if (catalogFunction.getFunctionLanguage() == FunctionLanguage.PYTHON) { + return true; + } + try { + Class<?> c = Class.forName(catalogFunction.getClassName(), true, classLoader); + if (UserDefinedFunction.class.isAssignableFrom(c)) { + return true; + } + } catch (ClassNotFoundException e) { + throw new RuntimeException( + String.format("Can't resolve udf class %s", catalogFunction.getClassName()), e); + } + return false; } /** @@ -72,10 +96,10 @@ public class HiveFunctionDefinitionFactory implements FunctionDefinitionFactory * org.apache.flink.table.module.hive.HiveModule}. */ public FunctionDefinition createFunctionDefinitionFromHiveFunction( - String name, String functionClassName) { - Class<?> clazz; + String name, String functionClassName, Context context) { + Class<?> functionClz; try { - clazz = Thread.currentThread().getContextClassLoader().loadClass(functionClassName); + functionClz = context.getClassLoader().loadClass(functionClassName); LOG.info("Successfully loaded Hive udf '{}' with class '{}'", name, functionClassName); } catch (ClassNotFoundException e) { @@ -84,33 +108,31 @@ public class HiveFunctionDefinitionFactory implements FunctionDefinitionFactory e); } - if (UDF.class.isAssignableFrom(clazz)) { + if (UDF.class.isAssignableFrom(functionClz)) { LOG.info("Transforming Hive function '{}' into a HiveSimpleUDF", name); - return new HiveSimpleUDF(new HiveFunctionWrapper<>(functionClassName), hiveShim); - } else if (GenericUDF.class.isAssignableFrom(clazz)) { + return new HiveSimpleUDF(new HiveFunctionWrapper<>(functionClz), hiveShim); + } else if (GenericUDF.class.isAssignableFrom(functionClz)) { LOG.info("Transforming Hive function '{}' into a HiveGenericUDF", name); - return new HiveGenericUDF(new HiveFunctionWrapper<>(functionClassName), hiveShim); - } else if (GenericUDTF.class.isAssignableFrom(clazz)) { + return new HiveGenericUDF(new HiveFunctionWrapper<>(functionClz), hiveShim); + } else if (GenericUDTF.class.isAssignableFrom(functionClz)) { LOG.info("Transforming Hive function '{}' into a HiveGenericUDTF", name); - return new HiveGenericUDTF(new HiveFunctionWrapper<>(functionClassName), hiveShim); - } else if (GenericUDAFResolver2.class.isAssignableFrom(clazz) - || UDAF.class.isAssignableFrom(clazz)) { + return new HiveGenericUDTF(new HiveFunctionWrapper<>(functionClz), hiveShim); + } else if (GenericUDAFResolver2.class.isAssignableFrom(functionClz) + || UDAF.class.isAssignableFrom(functionClz)) { - if (GenericUDAFResolver2.class.isAssignableFrom(clazz)) { + if (GenericUDAFResolver2.class.isAssignableFrom(functionClz)) { LOG.info( "Transforming Hive function '{}' into a HiveGenericUDAF without UDAF bridging", name); - return new HiveGenericUDAF( - new HiveFunctionWrapper<>(functionClassName), false, hiveShim); + return new HiveGenericUDAF(new HiveFunctionWrapper<>(functionClz), false, hiveShim); } else { LOG.info( "Transforming Hive function '{}' into a HiveGenericUDAF with UDAF bridging", name); - return new HiveGenericUDAF( - new HiveFunctionWrapper<>(functionClassName), true, hiveShim); + return new HiveGenericUDAF(new HiveFunctionWrapper<>(functionClz), true, hiveShim); } } else { throw new IllegalArgumentException( diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveFunction.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveFunction.java index c51a3c78370..c37f9e6ac2a 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveFunction.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveFunction.java @@ -86,7 +86,7 @@ public interface HiveFunction<UDFType> { if (throwOnFailure) { throw callContext.newValidationError( "Cannot find a suitable Hive function from %s for the input arguments", - hiveFunction.getFunctionWrapper().getClassName()); + hiveFunction.getFunctionWrapper().getUDFClassName()); } else { return Optional.empty(); } diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveFunctionWrapper.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveFunctionWrapper.java index a53d99c2069..eebd1e2153f 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveFunctionWrapper.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveFunctionWrapper.java @@ -21,9 +21,13 @@ package org.apache.flink.table.functions.hive; import org.apache.flink.annotation.Internal; import org.apache.flink.util.Preconditions; -import org.apache.hadoop.hive.ql.exec.SerializationUtilities; import org.apache.hadoop.hive.ql.exec.UDF; +import org.apache.hive.com.esotericsoftware.kryo.Kryo; +import org.apache.hive.com.esotericsoftware.kryo.io.Input; +import org.apache.hive.com.esotericsoftware.kryo.io.Output; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.Serializable; /** @@ -37,30 +41,30 @@ public class HiveFunctionWrapper<UDFType> implements Serializable { public static final long serialVersionUID = 393313529306818205L; - private final String className; - // a field to hold the string serialized for the UDF. + private final Class<UDFType> functionClz; + // a field to hold the bytes serialized for the UDF. // we sometimes need to hold it in case of some serializable UDF will contain // additional information such as Hive's GenericUDFMacro and if we construct the UDF directly by // getUDFClass#newInstance, the information will be missed. - private String udfSerializedString; + private byte[] udfSerializedBytes; private transient UDFType instance = null; - public HiveFunctionWrapper(String className) { - this.className = className; + public HiveFunctionWrapper(Class<?> functionClz) { + this.functionClz = (Class<UDFType>) functionClz; } /** * Create a HiveFunctionWrapper with a UDF instance. In this constructor, the instance will be * serialized to string and held on in the HiveFunctionWrapper. */ - public HiveFunctionWrapper(String className, UDFType serializableInstance) { - this(className); + public HiveFunctionWrapper(Class<?> functionClz, UDFType serializableInstance) { + this(functionClz); Preconditions.checkArgument( - serializableInstance.getClass().getName().equals(className), + serializableInstance.getClass().getName().equals(getUDFClassName()), String.format( "Expect the UDF is instance of %s, but is instance of %s.", - className, serializableInstance.getClass().getName())); + getUDFClassName(), serializableInstance.getClass().getName())); Preconditions.checkArgument( serializableInstance instanceof Serializable, String.format( @@ -68,8 +72,7 @@ public class HiveFunctionWrapper<UDFType> implements Serializable { serializableInstance.getClass().getName())); // we need to use the SerializationUtilities#serializeObject to serialize UDF for the UDF // may not be serialized by Java serializer - this.udfSerializedString = - SerializationUtilities.serializeObject((Serializable) serializableInstance); + this.udfSerializedBytes = serializeObjectToKryo((Serializable) serializableInstance); } /** @@ -78,7 +81,7 @@ public class HiveFunctionWrapper<UDFType> implements Serializable { * @return a Hive function instance */ public UDFType createFunction() { - if (udfSerializedString != null) { + if (udfSerializedBytes != null) { // deserialize the string to udf instance return deserializeUDF(); } else if (instance != null) { @@ -86,10 +89,11 @@ public class HiveFunctionWrapper<UDFType> implements Serializable { } else { UDFType func; try { - func = getUDFClass().newInstance(); - } catch (InstantiationException | IllegalAccessException | ClassNotFoundException e) { + func = functionClz.newInstance(); + } catch (InstantiationException | IllegalAccessException e) { throw new FlinkHiveUDFException( - String.format("Failed to create function from %s", className), e); + String.format("Failed to create function from %s", functionClz.getName()), + e); } if (!(func instanceof UDF)) { @@ -107,18 +111,12 @@ public class HiveFunctionWrapper<UDFType> implements Serializable { * * @return class name of the Hive function */ - public String getClassName() { - return className; + public String getUDFClassName() { + return functionClz.getName(); } - /** - * Get class of the Hive function. - * - * @return class of the Hive function - * @throws ClassNotFoundException thrown when the class is not found in classpath - */ - public Class<UDFType> getUDFClass() throws ClassNotFoundException { - return (Class<UDFType>) Thread.currentThread().getContextClassLoader().loadClass(className); + public Class<UDFType> getUDFClass() { + return functionClz; } /** @@ -127,13 +125,26 @@ public class HiveFunctionWrapper<UDFType> implements Serializable { * @return the UDF deserialized */ private UDFType deserializeUDF() { - try { - return (UDFType) - SerializationUtilities.deserializeObject( - udfSerializedString, (Class<Serializable>) getUDFClass()); - } catch (ClassNotFoundException e) { - throw new FlinkHiveUDFException( - String.format("Failed to deserialize function %s.", className), e); - } + return (UDFType) + deserializeObjectFromKryo(udfSerializedBytes, (Class<Serializable>) getUDFClass()); + } + + private static byte[] serializeObjectToKryo(Serializable object) { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + Output output = new Output(baos); + Kryo kryo = new Kryo(); + kryo.writeObject(output, object); + output.close(); + return baos.toByteArray(); + } + + private static <T extends Serializable> T deserializeObjectFromKryo( + byte[] bytes, Class<T> clazz) { + Input inp = new Input(new ByteArrayInputStream(bytes)); + Kryo kryo = new Kryo(); + kryo.setClassLoader(clazz.getClassLoader()); + T func = kryo.readObject(inp, clazz); + inp.close(); + return func; } } diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveGenericUDAF.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveGenericUDAF.java index df753df152a..c15b5d85255 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveGenericUDAF.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveGenericUDAF.java @@ -172,7 +172,7 @@ public class HiveGenericUDAF throw new FlinkHiveUDFException( String.format( "Failed to create accumulator for %s", - hiveFunctionWrapper.getClassName()), + hiveFunctionWrapper.getUDFClassName()), e); } } @@ -206,7 +206,8 @@ public class HiveGenericUDAF } catch (HiveException e) { throw new FlinkHiveUDFException( String.format( - "Failed to get final result on %s", hiveFunctionWrapper.getClassName()), + "Failed to get final result on %s", + hiveFunctionWrapper.getUDFClassName()), e); } } @@ -247,7 +248,7 @@ public class HiveGenericUDAF throw new FlinkHiveUDFException( String.format( "Failed to get Hive result type from %s", - hiveFunctionWrapper.getClassName()), + hiveFunctionWrapper.getUDFClassName()), e); } } diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveGenericUDF.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveGenericUDF.java index cca3a9ce4b8..10483d18024 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveGenericUDF.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveGenericUDF.java @@ -47,13 +47,13 @@ public class HiveGenericUDF extends HiveScalarFunction<GenericUDF> { public HiveGenericUDF(HiveFunctionWrapper<GenericUDF> hiveFunctionWrapper, HiveShim hiveShim) { super(hiveFunctionWrapper); this.hiveShim = hiveShim; - LOG.info("Creating HiveGenericUDF from '{}'", hiveFunctionWrapper.getClassName()); + LOG.info("Creating HiveGenericUDF from '{}'", hiveFunctionWrapper.getUDFClassName()); } @Override public void openInternal() { - LOG.info("Open HiveGenericUDF as {}", hiveFunctionWrapper.getClassName()); + LOG.info("Open HiveGenericUDF as {}", hiveFunctionWrapper.getUDFClassName()); function = createFunction(); @@ -96,7 +96,7 @@ public class HiveGenericUDF extends HiveScalarFunction<GenericUDF> { public DataType inferReturnType() throws UDFArgumentException { LOG.info( "Getting result type of HiveGenericUDF from {}", - hiveFunctionWrapper.getClassName()); + hiveFunctionWrapper.getUDFClassName()); ObjectInspector[] argumentInspectors = HiveInspectors.getArgInspectors(hiveShim, arguments); ObjectInspector resultObjectInspector = diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveGenericUDTF.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveGenericUDTF.java index 616ff096dec..35e35333bab 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveGenericUDTF.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveGenericUDTF.java @@ -145,7 +145,7 @@ public class HiveGenericUDTF extends TableFunction<Row> implements HiveFunction< public DataType inferReturnType() throws UDFArgumentException { LOG.info( "Getting result type of HiveGenericUDTF with {}", - hiveFunctionWrapper.getClassName()); + hiveFunctionWrapper.getUDFClassName()); ObjectInspector[] argumentInspectors = HiveInspectors.getArgInspectors(hiveShim, arguments); return HiveTypeUtil.toFlinkType( hiveFunctionWrapper.createFunction().initialize(argumentInspectors)); diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveScalarFunction.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveScalarFunction.java index 0620910f9da..459575303c2 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveScalarFunction.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveScalarFunction.java @@ -52,16 +52,12 @@ public abstract class HiveScalarFunction<UDFType> extends ScalarFunction @Override public boolean isDeterministic() { - try { - org.apache.hadoop.hive.ql.udf.UDFType udfType = - hiveFunctionWrapper - .getUDFClass() - .getAnnotation(org.apache.hadoop.hive.ql.udf.UDFType.class); - - return udfType != null && udfType.deterministic() && !udfType.stateful(); - } catch (ClassNotFoundException e) { - throw new FlinkHiveUDFException(e); - } + org.apache.hadoop.hive.ql.udf.UDFType udfType = + hiveFunctionWrapper + .getUDFClass() + .getAnnotation(org.apache.hadoop.hive.ql.udf.UDFType.class); + + return udfType != null && udfType.deterministic() && !udfType.stateful(); } @Override diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveSimpleUDF.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveSimpleUDF.java index 23adb64f043..daf2684482f 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveSimpleUDF.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveSimpleUDF.java @@ -61,12 +61,12 @@ public class HiveSimpleUDF extends HiveScalarFunction<UDF> { public HiveSimpleUDF(HiveFunctionWrapper<UDF> hiveFunctionWrapper, HiveShim hiveShim) { super(hiveFunctionWrapper); this.hiveShim = hiveShim; - LOG.info("Creating HiveSimpleUDF from '{}'", this.hiveFunctionWrapper.getClassName()); + LOG.info("Creating HiveSimpleUDF from '{}'", this.hiveFunctionWrapper.getUDFClassName()); } @Override public void openInternal() { - LOG.info("Opening HiveSimpleUDF as '{}'", hiveFunctionWrapper.getClassName()); + LOG.info("Opening HiveSimpleUDF as '{}'", hiveFunctionWrapper.getUDFClassName()); function = hiveFunctionWrapper.createFunction(); @@ -105,7 +105,7 @@ public class HiveSimpleUDF extends HiveScalarFunction<UDF> { throw new FlinkHiveUDFException( String.format( "Failed to open HiveSimpleUDF from %s", - hiveFunctionWrapper.getClassName()), + hiveFunctionWrapper.getUDFClassName()), e); } } diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java index b5297f72f25..562e053f1c8 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java @@ -22,6 +22,7 @@ import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.table.catalog.hive.client.HiveShim; import org.apache.flink.table.catalog.hive.client.HiveShimLoader; import org.apache.flink.table.catalog.hive.factories.HiveFunctionDefinitionFactory; +import org.apache.flink.table.factories.FunctionDefinitionFactory; import org.apache.flink.table.functions.FunctionDefinition; import org.apache.flink.table.module.Module; import org.apache.flink.table.module.hive.udf.generic.GenericUDFLegacyGroupingID; @@ -82,12 +83,17 @@ public class HiveModule implements Module { private final String hiveVersion; private final HiveShim hiveShim; private Set<String> functionNames; + private final ClassLoader classLoader; public HiveModule() { - this(HiveShimLoader.getHiveVersion()); + this(HiveShimLoader.getHiveVersion(), Thread.currentThread().getContextClassLoader()); } public HiveModule(String hiveVersion) { + this(hiveVersion, Thread.currentThread().getContextClassLoader()); + } + + public HiveModule(String hiveVersion, ClassLoader classLoader) { checkArgument( !StringUtils.isNullOrWhitespaceOnly(hiveVersion), "hiveVersion cannot be null"); @@ -95,6 +101,7 @@ public class HiveModule implements Module { this.hiveShim = HiveShimLoader.loadHiveShim(hiveVersion); this.factory = new HiveFunctionDefinitionFactory(hiveShim); this.functionNames = new HashSet<>(); + this.classLoader = classLoader; } @Override @@ -114,18 +121,19 @@ public class HiveModule implements Module { if (BUILT_IN_FUNC_BLACKLIST.contains(name)) { return Optional.empty(); } + FunctionDefinitionFactory.Context context = () -> classLoader; // We override Hive's grouping function. Refer to the implementation for more details. if (name.equalsIgnoreCase("grouping")) { return Optional.of( factory.createFunctionDefinitionFromHiveFunction( - name, HiveGenericUDFGrouping.class.getName())); + name, HiveGenericUDFGrouping.class.getName(), context)); } // this function is used to generate legacy GROUPING__ID value for old hive versions if (name.equalsIgnoreCase(GenericUDFLegacyGroupingID.NAME)) { return Optional.of( factory.createFunctionDefinitionFromHiveFunction( - name, GenericUDFLegacyGroupingID.class.getName())); + name, GenericUDFLegacyGroupingID.class.getName(), context)); } // We override Hive's internal_interval. Refer to the implementation for more details @@ -140,7 +148,7 @@ public class HiveModule implements Module { return info.map( functionInfo -> factory.createFunctionDefinitionFromHiveFunction( - name, functionInfo.getFunctionClass().getName())); + name, functionInfo.getFunctionClass().getName(), context)); } public String getHiveVersion() { diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModuleFactory.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModuleFactory.java index 451ab52aaae..7d81e25fe62 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModuleFactory.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModuleFactory.java @@ -63,6 +63,6 @@ public class HiveModuleFactory implements ModuleFactory { .getOptional(HIVE_VERSION) .orElseGet(HiveShimLoader::getHiveVersion); - return new HiveModule(hiveVersion); + return new HiveModule(hiveVersion, context.getClassLoader()); } } diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParser.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParser.java index af8c14637f1..1a534c61d36 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParser.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/HiveParser.java @@ -226,7 +226,8 @@ public class HiveParser extends ParserImpl { context, dmlHelper, frameworkConfig, - plannerContext.getCluster()); + plannerContext.getCluster(), + plannerContext.getFlinkContext().getClassLoader()); operation = ddlAnalyzer.convertToOperation(node); return Collections.singletonList(operation); } else { diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/parse/HiveParserDDLSemanticAnalyzer.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/parse/HiveParserDDLSemanticAnalyzer.java index 53679d97de8..835fc6b7783 100644 --- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/parse/HiveParserDDLSemanticAnalyzer.java +++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/planner/delegation/hive/parse/HiveParserDDLSemanticAnalyzer.java @@ -200,6 +200,7 @@ public class HiveParserDDLSemanticAnalyzer { private final HiveParserDMLHelper dmlHelper; private final FrameworkConfig frameworkConfig; private final RelOptCluster cluster; + private final ClassLoader classLoader; static { TokenToTypeName.put(HiveASTParser.TOK_BOOLEAN, serdeConstants.BOOLEAN_TYPE_NAME); @@ -262,7 +263,8 @@ public class HiveParserDDLSemanticAnalyzer { HiveParserContext context, HiveParserDMLHelper dmlHelper, FrameworkConfig frameworkConfig, - RelOptCluster cluster) + RelOptCluster cluster, + ClassLoader classLoader) throws SemanticException { this.queryState = queryState; this.conf = queryState.getConf(); @@ -276,6 +278,7 @@ public class HiveParserDDLSemanticAnalyzer { this.dmlHelper = dmlHelper; this.frameworkConfig = frameworkConfig; this.cluster = cluster; + this.classLoader = classLoader; reservedPartitionValues = new HashSet<>(); // Partition can't have this name reservedPartitionValues.add(HiveConf.getVar(conf, HiveConf.ConfVars.DEFAULTPARTITIONNAME)); @@ -524,7 +527,8 @@ public class HiveParserDDLSemanticAnalyzer { FunctionDefinition funcDefinition = funcDefFactory.createFunctionDefinition( functionName, - new CatalogFunctionImpl(className, FunctionLanguage.JAVA)); + new CatalogFunctionImpl(className, FunctionLanguage.JAVA), + () -> classLoader); return new CreateTempSystemFunctionOperation(functionName, false, funcDefinition); } else { ObjectIdentifier identifier = parseObjectIdentifier(functionName); @@ -558,8 +562,7 @@ public class HiveParserDDLSemanticAnalyzer { FunctionDefinition macroDefinition = new HiveGenericUDF( - new HiveFunctionWrapper<>(GenericUDFMacro.class.getName(), macro), - hiveShim); + new HiveFunctionWrapper<>(GenericUDFMacro.class, macro), hiveShim); // hive's marco is more like flink's temp system function return new CreateTempSystemFunctionOperation(macroName, false, macroDefinition); } diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/functions/hive/HiveFunctionWrapperTest.java b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/functions/hive/HiveFunctionWrapperTest.java new file mode 100644 index 00000000000..8c993381987 --- /dev/null +++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/functions/hive/HiveFunctionWrapperTest.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.functions.hive; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.table.functions.ScalarFunction; +import org.apache.flink.util.FlinkUserCodeClassLoaders; +import org.apache.flink.util.UserClassLoaderJarTestUtils; + +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFMacro; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.File; +import java.net.URL; +import java.util.Random; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Test for {@link HiveFunctionWrapper}. */ +public class HiveFunctionWrapperTest { + + @TempDir private static File tempFolder; + + private static final Random random = new Random(); + private static String udfClassName; + private static File udfJar; + + @BeforeAll + static void before() throws Exception { + udfClassName = "MyToLower" + random.nextInt(50); + String udfCode = + "public class " + + "%s" + + " extends org.apache.flink.table.functions.ScalarFunction {\n" + + " public String eval(String str) {\n" + + " return str.toLowerCase();\n" + + " }\n" + + "}\n"; + udfJar = + UserClassLoaderJarTestUtils.createJarFile( + tempFolder, + "test-classloader-udf.jar", + udfClassName, + String.format(udfCode, udfClassName)); + } + + @SuppressWarnings("unchecked") + @Test + public void testDeserializeUDF() throws Exception { + // test deserialize udf + GenericUDFMacro udfMacro = new GenericUDFMacro(); + HiveFunctionWrapper<GenericUDFMacro> functionWrapper = + new HiveFunctionWrapper<>(GenericUDFMacro.class, udfMacro); + GenericUDFMacro deserializeUdfMacro = functionWrapper.createFunction(); + assertThat(deserializeUdfMacro.getClass().getName()) + .isEqualTo(GenericUDFMacro.class.getName()); + + // test deserialize udf loaded by user code class loader instead of current thread class + // loader + ClassLoader userClassLoader = + FlinkUserCodeClassLoaders.create( + new URL[] {udfJar.toURI().toURL()}, + getClass().getClassLoader(), + new Configuration()); + Class<ScalarFunction> udfClass = + (Class<ScalarFunction>) userClassLoader.loadClass(udfClassName); + ScalarFunction udf = udfClass.newInstance(); + HiveFunctionWrapper<ScalarFunction> functionWrapper1 = + new HiveFunctionWrapper<>(udfClass, udf); + ScalarFunction deserializedUdf = functionWrapper1.createFunction(); + assertThat(deserializedUdf.getClass().getName()).isEqualTo(udfClassName); + } +} diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/functions/hive/HiveGenericUDAFTest.java b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/functions/hive/HiveGenericUDAFTest.java index dfd73889033..771d85cef9e 100644 --- a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/functions/hive/HiveGenericUDAFTest.java +++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/functions/hive/HiveGenericUDAFTest.java @@ -116,8 +116,7 @@ public class HiveGenericUDAFTest { private static HiveGenericUDAF init( Class<?> hiveUdfClass, Object[] constantArgs, DataType[] argTypes) throws Exception { - HiveFunctionWrapper<GenericUDAFResolver> wrapper = - new HiveFunctionWrapper<>(hiveUdfClass.getName()); + HiveFunctionWrapper<GenericUDAFResolver> wrapper = new HiveFunctionWrapper<>(hiveUdfClass); CallContextMock callContext = new CallContextMock(); callContext.argumentDataTypes = Arrays.asList(argTypes); diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/functions/hive/HiveGenericUDFTest.java b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/functions/hive/HiveGenericUDFTest.java index 7ded892f942..c5162df6087 100644 --- a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/functions/hive/HiveGenericUDFTest.java +++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/functions/hive/HiveGenericUDFTest.java @@ -334,9 +334,8 @@ public class HiveGenericUDFTest { } private static HiveGenericUDF init( - Class hiveUdfClass, Object[] constantArgs, DataType[] argTypes) { - HiveGenericUDF udf = - new HiveGenericUDF(new HiveFunctionWrapper(hiveUdfClass.getName()), hiveShim); + Class<?> hiveUdfClass, Object[] constantArgs, DataType[] argTypes) { + HiveGenericUDF udf = new HiveGenericUDF(new HiveFunctionWrapper<>(hiveUdfClass), hiveShim); CallContextMock callContext = new CallContextMock(); callContext.argumentDataTypes = Arrays.asList(argTypes); diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/functions/hive/HiveGenericUDTFTest.java b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/functions/hive/HiveGenericUDTFTest.java index b82fd4b5f9e..3e7235a4e7e 100644 --- a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/functions/hive/HiveGenericUDTFTest.java +++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/functions/hive/HiveGenericUDTFTest.java @@ -148,8 +148,8 @@ public class HiveGenericUDTFTest { } private static HiveGenericUDTF init( - Class hiveUdfClass, Object[] constantArgs, DataType[] argTypes) throws Exception { - HiveFunctionWrapper<GenericUDTF> wrapper = new HiveFunctionWrapper(hiveUdfClass.getName()); + Class<?> hiveUdfClass, Object[] constantArgs, DataType[] argTypes) throws Exception { + HiveFunctionWrapper<GenericUDTF> wrapper = new HiveFunctionWrapper<>(hiveUdfClass); CallContextMock callContext = new CallContextMock(); callContext.argumentDataTypes = Arrays.asList(argTypes); diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/functions/hive/HiveSimpleUDFTest.java b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/functions/hive/HiveSimpleUDFTest.java index 5a4e864b5b9..b143d2442d9 100644 --- a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/functions/hive/HiveSimpleUDFTest.java +++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/functions/hive/HiveSimpleUDFTest.java @@ -214,9 +214,8 @@ public class HiveSimpleUDFTest { assertThat(udf.eval(5, testInputs, testInputs)).isEqualTo(11); } - protected static HiveSimpleUDF init(Class hiveUdfClass, DataType[] argTypes) { - HiveSimpleUDF udf = - new HiveSimpleUDF(new HiveFunctionWrapper(hiveUdfClass.getName()), hiveShim); + protected static HiveSimpleUDF init(Class<?> hiveUdfClass, DataType[] argTypes) { + HiveSimpleUDF udf = new HiveSimpleUDF(new HiveFunctionWrapper<>(hiveUdfClass), hiveShim); // Hive UDF won't have literal args CallContextMock callContext = new CallContextMock(); diff --git a/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/gateway/context/ExecutionContext.java b/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/gateway/context/ExecutionContext.java index bc662f56e91..67b796de25b 100644 --- a/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/gateway/context/ExecutionContext.java +++ b/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/gateway/context/ExecutionContext.java @@ -82,15 +82,6 @@ public class ExecutionContext { this.tableEnv = createTableEnvironment(); } - /** - * Executes the given supplier using the execution context's classloader as thread classloader. - */ - public <R> R wrapClassLoader(Supplier<R> supplier) { - try (TemporaryClassLoaderContext ignored = TemporaryClassLoaderContext.of(classLoader)) { - return supplier.get(); - } - } - public StreamTableEnvironment getTableEnvironment() { return tableEnv; } diff --git a/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/gateway/local/LocalExecutor.java b/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/gateway/local/LocalExecutor.java index 1540dec8f65..82446e92dc2 100644 --- a/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/gateway/local/LocalExecutor.java +++ b/flink-table/flink-sql-client/src/main/java/org/apache/flink/table/client/gateway/local/LocalExecutor.java @@ -169,7 +169,7 @@ public class LocalExecutor implements Executor { List<Operation> operations; try { - operations = context.wrapClassLoader(() -> parser.parse(statement)); + operations = parser.parse(statement); } catch (Throwable t) { throw new SqlExecutionException("Failed to parse statement: " + statement, t); } @@ -186,10 +186,7 @@ public class LocalExecutor implements Executor { (TableEnvironmentInternal) context.getTableEnvironment(); try { - return context.wrapClassLoader( - () -> - Arrays.asList( - tableEnv.getParser().getCompletionHints(statement, position))); + return Arrays.asList(tableEnv.getParser().getCompletionHints(statement, position)); } catch (Throwable t) { // catch everything such that the query does not crash the executor if (LOG.isDebugEnabled()) { @@ -206,7 +203,7 @@ public class LocalExecutor implements Executor { final TableEnvironmentInternal tEnv = (TableEnvironmentInternal) context.getTableEnvironment(); try { - return context.wrapClassLoader(() -> tEnv.executeInternal(operation)); + return tEnv.executeInternal(operation); } catch (Throwable t) { throw new SqlExecutionException(MESSAGE_SQL_EXECUTION_ERROR, t); } @@ -219,7 +216,7 @@ public class LocalExecutor implements Executor { final TableEnvironmentInternal tEnv = (TableEnvironmentInternal) context.getTableEnvironment(); try { - return context.wrapClassLoader(() -> tEnv.executeInternal(operations)); + return tEnv.executeInternal(operations); } catch (Throwable t) { throw new SqlExecutionException(MESSAGE_SQL_EXECUTION_ERROR, t); } diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/FunctionCatalog.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/FunctionCatalog.java index 6a6c0c940ff..75a5be4d0ff 100644 --- a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/FunctionCatalog.java +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/FunctionCatalog.java @@ -578,10 +578,15 @@ public final class FunctionCatalog { FunctionDefinition fd; if (catalog.getFunctionDefinitionFactory().isPresent() && catalogFunction.getFunctionLanguage() != FunctionLanguage.PYTHON) { + registerFunctionJarResources( + oi.asSummaryString(), catalogFunction.getFunctionResources()); fd = catalog.getFunctionDefinitionFactory() .get() - .createFunctionDefinition(oi.getObjectName(), catalogFunction); + .createFunctionDefinition( + oi.getObjectName(), + catalogFunction, + resourceManager::getUserClassLoader); } else { fd = getFunctionDefinition(oi.asSummaryString(), catalogFunction); } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/factories/FunctionDefinitionFactory.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/factories/FunctionDefinitionFactory.java index 1f91db8f156..7ff642863cf 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/factories/FunctionDefinitionFactory.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/factories/FunctionDefinitionFactory.java @@ -21,6 +21,7 @@ package org.apache.flink.table.factories; import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.table.catalog.CatalogFunction; import org.apache.flink.table.functions.FunctionDefinition; +import org.apache.flink.util.TemporaryClassLoaderContext; /** A factory to create {@link FunctionDefinition}. */ @PublicEvolving @@ -32,6 +33,41 @@ public interface FunctionDefinitionFactory { * @param name name of the {@link CatalogFunction} * @param catalogFunction the catalog function * @return a {@link FunctionDefinition} + * @deprecated Please implement {@link #createFunctionDefinition(String, CatalogFunction, + * Context)} instead. */ - FunctionDefinition createFunctionDefinition(String name, CatalogFunction catalogFunction); + @Deprecated + default FunctionDefinition createFunctionDefinition( + String name, CatalogFunction catalogFunction) { + throw new RuntimeException( + "Please implement FunctionDefinitionFactory#createFunctionDefinition(String, CatalogFunction, Context) instead."); + } + + /** + * Creates a {@link FunctionDefinition} from given {@link CatalogFunction} with the given {@link + * Context} containing the class loader of the current session, which is useful when it's needed + * to load class from class name. + * + * <p>The default implementation will call {@link #createFunctionDefinition(String, + * CatalogFunction)} directly. + * + * @param name name of the {@link CatalogFunction} + * @param catalogFunction the catalog function + * @param context the {@link Context} for creating function definition + * @return a {@link FunctionDefinition} + */ + default FunctionDefinition createFunctionDefinition( + String name, CatalogFunction catalogFunction, Context context) { + try (TemporaryClassLoaderContext ignored = + TemporaryClassLoaderContext.of(context.getClassLoader())) { + return createFunctionDefinition(name, catalogFunction); + } + } + + /** Context provided when a function definition is created. */ + @PublicEvolving + interface Context { + /** Returns the class loader of the current session. */ + ClassLoader getClassLoader(); + } } diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/factories/FunctionDefinitionFactory.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestFunctionDefinitionFactory.java similarity index 54% copy from flink-table/flink-table-common/src/main/java/org/apache/flink/table/factories/FunctionDefinitionFactory.java copy to flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestFunctionDefinitionFactory.java index 1f91db8f156..78b25978cd8 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/factories/FunctionDefinitionFactory.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestFunctionDefinitionFactory.java @@ -16,22 +16,23 @@ * limitations under the License. */ -package org.apache.flink.table.factories; +package org.apache.flink.table.planner.factories; -import org.apache.flink.annotation.PublicEvolving; import org.apache.flink.table.catalog.CatalogFunction; +import org.apache.flink.table.factories.FunctionDefinitionFactory; import org.apache.flink.table.functions.FunctionDefinition; +import org.apache.flink.table.functions.UserDefinedFunctionHelper; -/** A factory to create {@link FunctionDefinition}. */ -@PublicEvolving -public interface FunctionDefinitionFactory { +/** + * Use TestFunctionDefinitionFactory to test loading function to ensure the function can be loaded + * correctly if only implement legacy interface {@link + * FunctionDefinitionFactory#createFunctionDefinition(String, CatalogFunction)}. + */ +public class TestFunctionDefinitionFactory implements FunctionDefinitionFactory { - /** - * Creates a {@link FunctionDefinition} from given {@link CatalogFunction}. - * - * @param name name of the {@link CatalogFunction} - * @param catalogFunction the catalog function - * @return a {@link FunctionDefinition} - */ - FunctionDefinition createFunctionDefinition(String name, CatalogFunction catalogFunction); + public FunctionDefinition createFunctionDefinition( + String name, CatalogFunction catalogFunction) { + return UserDefinedFunctionHelper.instantiateFunction( + Thread.currentThread().getContextClassLoader(), null, name, catalogFunction); + } } diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesCatalog.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesCatalog.java index 86d9d9b6b69..69ee1955488 100644 --- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesCatalog.java +++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/factories/TestValuesCatalog.java @@ -29,6 +29,7 @@ import org.apache.flink.table.catalog.exceptions.TableNotExistException; import org.apache.flink.table.catalog.exceptions.TableNotPartitionedException; import org.apache.flink.table.expressions.Expression; import org.apache.flink.table.expressions.ResolvedExpression; +import org.apache.flink.table.factories.FunctionDefinitionFactory; import org.apache.flink.table.planner.utils.FilterUtils; import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.logical.BooleanType; @@ -44,7 +45,7 @@ import java.util.Optional; import java.util.function.Function; import java.util.stream.Collectors; -/** Use TestValuesCatalog to test partition push down. */ +/** Use TestValuesCatalog to test partition push down and create function definition. */ public class TestValuesCatalog extends GenericInMemoryCatalog { private final boolean supportListPartitionByFilter; @@ -95,6 +96,11 @@ public class TestValuesCatalog extends GenericInMemoryCatalog { .collect(Collectors.toList()); } + @Override + public Optional<FunctionDefinitionFactory> getFunctionDefinitionFactory() { + return Optional.of(new TestFunctionDefinitionFactory()); + } + private Function<String, Comparable<?>> getValueGetter( Map<String, String> spec, TableSchema schema) { return field -> { diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/catalog/CatalogTableITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/catalog/CatalogTableITCase.scala index d4019fecaa0..2704d564305 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/catalog/CatalogTableITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/catalog/CatalogTableITCase.scala @@ -22,12 +22,14 @@ import org.apache.flink.table.api.config.{ExecutionConfigOptions, TableConfigOpt import org.apache.flink.table.api.internal.TableEnvironmentImpl import org.apache.flink.table.catalog._ import org.apache.flink.table.planner.expressions.utils.Func0 +import org.apache.flink.table.planner.factories.TestValuesCatalog import org.apache.flink.table.planner.factories.utils.TestCollectionTableFactory import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedScalarFunctions.JavaFunc0 import org.apache.flink.table.planner.utils.DateTimeTestUtil.localDateTime +import org.apache.flink.table.utils.UserDefinedFunctions.{GENERATED_LOWER_UDF_CLASS, GENERATED_LOWER_UDF_CODE} import org.apache.flink.test.util.AbstractTestBase import org.apache.flink.types.Row -import org.apache.flink.util.FileUtils +import org.apache.flink.util.{FileUtils, UserClassLoaderJarTestUtils} import org.junit.{Before, Rule, Test} import org.junit.Assert.{assertEquals, assertNotEquals, fail} @@ -39,8 +41,10 @@ import java.io.File import java.math.{BigDecimal => JBigDecimal} import java.net.URI import java.util +import java.util.UUID import scala.collection.JavaConversions._ +import scala.util.Random /** Test cases for catalog table. */ @RunWith(classOf[Parameterized]) @@ -1235,6 +1239,40 @@ class CatalogTableITCase(isStreamingMode: Boolean) extends AbstractTestBase { expectedProperty.put("k2", "b") assertEquals(expectedProperty, database.getProperties) } + + @Test + def testLoadFunction(): Unit = { + tableEnv.registerCatalog("cat2", new TestValuesCatalog("cat2", "default", true)) + tableEnv.executeSql("use catalog cat2") + // test load customer function packaged in a jar + val random = new Random(); + val udfClassName = GENERATED_LOWER_UDF_CLASS + random.nextInt(50) + val jarPath = UserClassLoaderJarTestUtils + .createJarFile( + AbstractTestBase.TEMPORARY_FOLDER.newFolder(String.format("test-jar-%s", UUID.randomUUID)), + "test-classloader-udf.jar", + udfClassName, + String.format(GENERATED_LOWER_UDF_CODE, udfClassName) + ) + .toURI + .toString + tableEnv.executeSql(s"""create function lowerUdf as '$udfClassName' using jar '$jarPath'""") + + TestCollectionTableFactory.reset() + TestCollectionTableFactory.initData(List(Row.of("BoB"))) + val ddl1 = + """ + |create table t1( + | a varchar + |) with ( + | 'connector' = 'COLLECTION' + |) + """.stripMargin + tableEnv.executeSql(ddl1) + assertEquals( + "+I[bob]", + tableEnv.executeSql("select lowerUdf(a) from t1").collect().next().toString) + } } object CatalogTableITCase {
