This is an automated email from the ASF dual-hosted git repository.
shengkai pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
     new 95869e0bc99 [FLINK-30579][hive] Introducing configurable option to use hive native function
95869e0bc99 is described below
commit 95869e0bc99f9703807a114a4240dbc443d970cb
Author: fengli <[email protected]>
AuthorDate: Mon Jan 9 21:07:48 2023 +0800
    [FLINK-30579][hive] Introducing configurable option to use hive native function
This closes #21629
---
.../26d337fc-45c4-4d03-a84a-6692c37fafbc | 6 +
.../apache/flink/connectors/hive/HiveOptions.java | 7 +
.../table/endpoint/hive/HiveServer2Endpoint.java | 18 +-
.../table/functions/hive/HiveSumAggFunction.java | 10 +-
.../apache/flink/table/module/hive/HiveModule.java | 22 ++-
.../flink/table/module/hive/HiveModuleFactory.java | 2 +-
.../connectors/hive/HiveDialectAggITCase.java | 207 +++++++++++++++++++++
.../connectors/hive/HiveDialectQueryITCase.java | 85 ---------
.../endpoint/hive/HiveServer2EndpointITCase.java | 7 +-
.../src/test/resources/endpoint/hive_module.q | 65 +++++++
.../explain/testSumAggFunctionFallbackPlan.out | 21 +++
.../gateway/api/session/SessionEnvironment.java | 89 ++++++---
.../gateway/service/context/SessionContext.java | 31 ++-
13 files changed, 443 insertions(+), 127 deletions(-)
diff --git a/flink-connectors/flink-connector-hive/archunit-violations/26d337fc-45c4-4d03-a84a-6692c37fafbc b/flink-connectors/flink-connector-hive/archunit-violations/26d337fc-45c4-4d03-a84a-6692c37fafbc
index 0113d84f20c..f5cbf4e0c18 100644
--- a/flink-connectors/flink-connector-hive/archunit-violations/26d337fc-45c4-4d03-a84a-6692c37fafbc
+++ b/flink-connectors/flink-connector-hive/archunit-violations/26d337fc-45c4-4d03-a84a-6692c37fafbc
@@ -8,6 +8,12 @@ org.apache.flink.connectors.hive.HiveDialectQueryITCase does not satisfy: only o
 * reside in a package 'org.apache.flink.runtime.*' and contain any fields that are static, final, and of type InternalMiniClusterExtension and annotated with @RegisterExtension\
 * reside outside of package 'org.apache.flink.runtime.*' and contain any fields that are static, final, and of type MiniClusterExtension and annotated with @RegisterExtension\
 * reside in a package 'org.apache.flink.runtime.*' and is annotated with @ExtendWith with class InternalMiniClusterExtension\
+* reside outside of package 'org.apache.flink.runtime.*' and is annotated with @ExtendWith with class MiniClusterExtension\
+  or contain any fields that are public, static, and of type MiniClusterWithClientResource and final and annotated with @ClassRule or contain any fields that is of type MiniClusterWithClientResource and public and final and not static and annotated with @Rule
+org.apache.flink.connectors.hive.HiveDialectAggITCase does not satisfy: only one of the following predicates match:\
+* reside in a package 'org.apache.flink.runtime.*' and contain any fields that are static, final, and of type InternalMiniClusterExtension and annotated with @RegisterExtension\
+* reside outside of package 'org.apache.flink.runtime.*' and contain any fields that are static, final, and of type MiniClusterExtension and annotated with @RegisterExtension\
+* reside in a package 'org.apache.flink.runtime.*' and is annotated with @ExtendWith with class InternalMiniClusterExtension\
 * reside outside of package 'org.apache.flink.runtime.*' and is annotated with @ExtendWith with class MiniClusterExtension\
   or contain any fields that are public, static, and of type MiniClusterWithClientResource and final and annotated with @ClassRule or contain any fields that is of type MiniClusterWithClientResource and public and final and not static and annotated with @Rule
 org.apache.flink.connectors.hive.HiveLookupJoinITCase does not satisfy: only one of the following predicates match:\
diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/connectors/hive/HiveOptions.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/connectors/hive/HiveOptions.java
index cb4abc85eed..fc74424cd50 100644
--- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/connectors/hive/HiveOptions.java
+++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/connectors/hive/HiveOptions.java
@@ -230,6 +230,13 @@ public class HiveOptions {
                     .withDescription(
                             "The cache TTL (e.g. 10min) for the build table in lookup join.");

+    public static final ConfigOption<Boolean> TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED =
+            key("table.exec.hive.native-agg-function.enabled")
+                    .booleanType()
+                    .defaultValue(false)
+                    .withDescription(
+                            "Enables native aggregate functions for the Hive dialect, which use the hash-agg strategy to improve aggregation performance.");
+
     // --------------------------------------------------------------------------------------------
     // Enums
     // --------------------------------------------------------------------------------------------
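A usage note (illustrative, not part of the patch): the option takes effect per session/TableEnvironment, and the new HiveDialectAggITCase below toggles it exactly this way. A minimal Java sketch, assuming a Hive-dialect batch TableEnvironment `tableEnv` whose Hive module was created with this configuration, and an illustrative table foo(x int, y int):

    // With the option on, sum is planned as HashAggregate (native implementation).
    tableEnv.getConfig().set(HiveOptions.TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED, true);
    tableEnv.executeSql("explain select x, sum(y) from foo group by x").print();

    // With the option off (the default), the query falls back to Hive's sum UDAF,
    // planned as SortAggregate.
    tableEnv.getConfig().set(HiveOptions.TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED, false);
    tableEnv.executeSql("explain select x, sum(y) from foo group by x").print();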
diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/endpoint/hive/HiveServer2Endpoint.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/endpoint/hive/HiveServer2Endpoint.java
index ecbdc75b731..a7fb69334f5 100644
--- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/endpoint/hive/HiveServer2Endpoint.java
+++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/endpoint/hive/HiveServer2Endpoint.java
@@ -28,6 +28,7 @@ import org.apache.flink.table.catalog.CatalogBaseTable.TableKind;
 import org.apache.flink.table.catalog.ResolvedSchema;
 import org.apache.flink.table.catalog.hive.HiveCatalog;
 import org.apache.flink.table.catalog.hive.client.HiveShimLoader;
+import org.apache.flink.table.factories.FactoryUtil;
 import org.apache.flink.table.gateway.api.SqlGatewayService;
 import org.apache.flink.table.gateway.api.endpoint.EndpointVersion;
 import org.apache.flink.table.gateway.api.endpoint.SqlGatewayEndpoint;
@@ -40,8 +41,6 @@ import org.apache.flink.table.gateway.api.session.SessionEnvironment;
 import org.apache.flink.table.gateway.api.session.SessionHandle;
 import org.apache.flink.table.gateway.api.utils.SqlGatewayException;
 import org.apache.flink.table.gateway.api.utils.ThreadUtils;
-import org.apache.flink.table.module.Module;
-import org.apache.flink.table.module.hive.HiveModule;
 import org.apache.flink.util.ExceptionUtils;

 import org.apache.hadoop.hive.conf.HiveConf;
@@ -310,7 +309,14 @@ public class HiveServer2Endpoint implements TCLIService.Iface, SqlGatewayEndpoin
         // all the alive PersistenceManager in the ObjectStore, which may get error like
         // "Persistence Manager has been closed" in the later connection.
         hiveCatalog.open();
-        Module hiveModule = new HiveModule();
+        // create hive module lazily
+        SessionEnvironment.ModuleCreator hiveModuleCreator =
+                (readableConfig, classLoader) ->
+                        FactoryUtil.createModule(
+                                moduleName,
+                                Collections.emptyMap(),
+                                readableConfig,
+                                classLoader);
         // set variables to HiveConf and Session's conf
         Map<String, String> sessionConfig = new HashMap<>();
         sessionConfig.put(TABLE_SQL_DIALECT.key(), SqlDialect.HIVE.name());
@@ -321,8 +327,10 @@ public class HiveServer2Endpoint implements TCLIService.Iface, SqlGatewayEndpoin
             service.openSession(
                     SessionEnvironment.newBuilder()
                             .setSessionEndpointVersion(sessionVersion)
-                            .registerCatalog(catalogName, hiveCatalog)
-                            .registerModuleAtHead(moduleName, hiveModule)
+                            .registerCatalogCreator(
+                                    catalogName,
+                                    (readableConfig, classLoader) -> hiveCatalog)
+                            .registerModuleCreatorAtHead(moduleName, hiveModuleCreator)
                             .setDefaultCatalog(catalogName)
                             .addSessionConfig(sessionConfig)
                             .build());
diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveSumAggFunction.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveSumAggFunction.java
index 0feb1645c42..713cd72e13a 100644
--- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveSumAggFunction.java
+++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveSumAggFunction.java
@@ -26,6 +26,7 @@ import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
 import org.apache.flink.table.types.DataType;
 import org.apache.flink.table.types.inference.CallContext;

+import static org.apache.flink.connectors.hive.HiveOptions.TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED;
 import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef;
 import static org.apache.flink.table.planner.expressions.ExpressionBuilder.hiveAggDecimalPlus;
 import static org.apache.flink.table.planner.expressions.ExpressionBuilder.ifThenElse;
@@ -125,11 +126,16 @@ public class HiveSumAggFunction extends HiveDeclarativeAggregateFunction {
                 int precision =
                         Math.min(MAX_PRECISION, getPrecision(argsType.getLogicalType()) + 10);
                 return DataTypes.DECIMAL(precision, getScale(argsType.getLogicalType()));
+            case TIMESTAMP_WITHOUT_TIME_ZONE:
+                throw new TableException(
+                        String.format(
+                                "Native hive sum aggregate function does not support type: %s. Please set option '%s' to false.",
+                                argsType, TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED.key()));
             default:
                 throw new TableException(
                         String.format(
-                                "Sum aggregate function does not support type: '%s'. Please re-check the data type.",
-                                argsType.getLogicalType().getTypeRoot()));
+                                "Only numeric or string type arguments are accepted but %s is passed.",
+                                argsType));
         }
     }
diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java
index 8cb9cb721bb..1982ddb79ba 100644
--- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java
+++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModule.java
@@ -19,6 +19,8 @@
 package org.apache.flink.table.module.hive;

 import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.ReadableConfig;
 import org.apache.flink.table.catalog.hive.client.HiveShim;
 import org.apache.flink.table.catalog.hive.client.HiveShimLoader;
 import org.apache.flink.table.catalog.hive.factories.HiveFunctionDefinitionFactory;
@@ -41,6 +43,7 @@ import java.util.HashSet;
 import java.util.Optional;
 import java.util.Set;

+import static org.apache.flink.connectors.hive.HiveOptions.TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED;
 import static org.apache.flink.util.Preconditions.checkArgument;

 /** Module to provide Hive built-in metadata. */
@@ -85,17 +88,27 @@ public class HiveModule implements Module {
     private final String hiveVersion;
     private final HiveShim hiveShim;
     private Set<String> functionNames;
+    private final ReadableConfig config;
     private final ClassLoader classLoader;

+    @VisibleForTesting
     public HiveModule() {
-        this(HiveShimLoader.getHiveVersion(), Thread.currentThread().getContextClassLoader());
+        this(
+                HiveShimLoader.getHiveVersion(),
+                new Configuration(),
+                Thread.currentThread().getContextClassLoader());
     }

+    @VisibleForTesting
     public HiveModule(String hiveVersion) {
         this(hiveVersion, Thread.currentThread().getContextClassLoader());
     }

     public HiveModule(String hiveVersion, ClassLoader classLoader) {
+        this(hiveVersion, new Configuration(), classLoader);
+    }
+
+    public HiveModule(String hiveVersion, ReadableConfig config, ClassLoader classLoader) {
         checkArgument(
                 !StringUtils.isNullOrWhitespaceOnly(hiveVersion), "hiveVersion cannot be null");

@@ -103,6 +116,7 @@ public class HiveModule implements Module {
         this.hiveShim = HiveShimLoader.loadHiveShim(hiveVersion);
         this.factory = new HiveFunctionDefinitionFactory(hiveShim);
         this.functionNames = new HashSet<>();
+        this.config = config;
         this.classLoader = classLoader;
     }
@@ -128,7 +142,7 @@ public class HiveModule implements Module {
         FunctionDefinitionFactory.Context context = () -> classLoader;

         // We override Hive's sum function with a native implementation to support hash-agg
-        if (name.equalsIgnoreCase("sum")) {
+        if (isNativeAggFunctionEnabled() && name.equalsIgnoreCase("sum")) {
             return Optional.of(new HiveSumAggFunction());
         }
@@ -178,4 +192,8 @@ public class HiveModule implements Module {
     public String getHiveVersion() {
         return hiveVersion;
     }
+
+    private boolean isNativeAggFunctionEnabled() {
+        return config.get(TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED);
+    }
 }
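For illustration (mirroring getTableEnvWithHiveCatalog in the new HiveDialectAggITCase): only the new constructor makes the module honor the option, since the legacy constructors pin a fresh default Configuration. A sketch, assuming an existing Hive-dialect TableEnvironment `tableEnv`:

    // Pass the live table configuration so isNativeAggFunctionEnabled() sees later SET calls.
    HiveModule hiveModule =
            new HiveModule(
                    HiveShimLoader.getHiveVersion(),
                    tableEnv.getConfig(), // TableConfig is a ReadableConfig
                    Thread.currentThread().getContextClassLoader());
    tableEnv.loadModule("hive", hiveModule);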
diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModuleFactory.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModuleFactory.java
index 7d81e25fe62..8978529d107 100644
--- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModuleFactory.java
+++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/module/hive/HiveModuleFactory.java
@@ -63,6 +63,6 @@ public class HiveModuleFactory implements ModuleFactory {
                         .getOptional(HIVE_VERSION)
                         .orElseGet(HiveShimLoader::getHiveVersion);

-        return new HiveModule(hiveVersion, context.getClassLoader());
+        return new HiveModule(hiveVersion, context.getConfiguration(), context.getClassLoader());
     }
 }
diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectAggITCase.java b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectAggITCase.java
new file mode 100644
index 00000000000..c92bf6546f1
--- /dev/null
+++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectAggITCase.java
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.connectors.hive;
+
+import org.apache.flink.table.api.SqlDialect;
+import org.apache.flink.table.api.TableEnvironment;
+import org.apache.flink.table.catalog.hive.HiveCatalog;
+import org.apache.flink.table.catalog.hive.HiveTestUtils;
+import org.apache.flink.table.module.CoreModule;
+import org.apache.flink.table.module.hive.HiveModule;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CollectionUtil;
+
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.ClassRule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.List;
+
+import static org.apache.flink.connectors.hive.HiveOptions.TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED;
+import static org.apache.flink.core.testutils.FlinkAssertions.anyCauseMatches;
+import static org.apache.flink.table.planner.utils.TableTestUtil.readFromResource;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Test for native hive agg function compatibility. */
+public class HiveDialectAggITCase {
+
+    @ClassRule public static TemporaryFolder tempFolder = new TemporaryFolder();
+
+ private static HiveCatalog hiveCatalog;
+ private static TableEnvironment tableEnv;
+
+ @BeforeClass
+ public static void setup() throws Exception {
+ hiveCatalog = HiveTestUtils.createHiveCatalog();
+ // required by query like "src.`[k].*` from src"
+        hiveCatalog.getHiveConf().setVar(HiveConf.ConfVars.HIVE_QUOTEDID_SUPPORT, "none");
+ hiveCatalog.open();
+ tableEnv = getTableEnvWithHiveCatalog();
+
+ // create tables
+ tableEnv.executeSql("create table foo (x int, y int)");
+
+ HiveTestUtils.createTextTableInserter(hiveCatalog, "default", "foo")
+ .addRow(new Object[] {1, 1})
+ .addRow(new Object[] {2, 2})
+ .addRow(new Object[] {3, 3})
+ .addRow(new Object[] {4, 4})
+ .addRow(new Object[] {5, 5})
+ .commit();
+ }
+
+ @Before
+ public void before() {
+ // enable native hive agg function
+        tableEnv.getConfig().set(TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED, true);
+ }
+
+ @Test
+ public void testSumAggFunctionPlan() {
+ // test explain
+ String actualPlan = explainSql("select x, sum(y) from foo group by x");
+        assertThat(actualPlan).isEqualTo(readFromResource("/explain/testSumAggFunctionPlan.out"));
+
+        // test fallback to hive sum udaf
+        tableEnv.getConfig().set(TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED, false);
+        String actualSortAggPlan = explainSql("select x, sum(y) from foo group by x");
+        assertThat(actualSortAggPlan)
+                .isEqualTo(readFromResource("/explain/testSumAggFunctionFallbackPlan.out"));
+ }
+
+    @Test
+    public void testSimpleSumAggFunction() throws Exception {
+        tableEnv.executeSql(
+                "create table test_sum(x string, y string, z int, d decimal(10,5), e float, f double, ts timestamp)");
+        tableEnv.executeSql(
+                        "insert into test_sum values (NULL, '2', 1, 1.11, 1.2, 1.3, '2021-08-04 16:26:33.4'), "
+                                + "(NULL, 'b', 2, 2.22, 2.3, 2.4, '2021-08-07 16:26:33.4'), "
+                                + "(NULL, '4', 3, 3.33, 3.5, 3.6, '2021-08-08 16:26:33.4'), "
+                                + "(NULL, NULL, 4, 4.45, 4.7, 4.8, '2021-08-09 16:26:33.4')")
+                .await();
+
+        // test sum where all elements are null
+        List<Row> result =
+                CollectionUtil.iteratorToList(
+                        tableEnv.executeSql("select sum(x) from test_sum").collect());
+        assertThat(result.toString()).isEqualTo("[+I[null]]");
+
+        // test sum on string type where some elements can't be converted to double; the result type is double
+        List<Row> result2 =
+                CollectionUtil.iteratorToList(
+                        tableEnv.executeSql("select sum(y) from test_sum").collect());
+        assertThat(result2.toString()).isEqualTo("[+I[6.0]]");
+
+        // test decimal type
+        List<Row> result3 =
+                CollectionUtil.iteratorToList(
+                        tableEnv.executeSql("select sum(d) from test_sum").collect());
+        assertThat(result3.toString()).isEqualTo("[+I[11.11000]]");
+
+        // test sum of int; the result type is bigint
+        List<Row> result4 =
+                CollectionUtil.iteratorToList(
+                        tableEnv.executeSql("select sum(z) from test_sum").collect());
+        assertThat(result4.toString()).isEqualTo("[+I[10]]");
+
+        // test float type
+        List<Row> result5 =
+                CollectionUtil.iteratorToList(
+                        tableEnv.executeSql("select sum(e) from test_sum").collect());
+        float actualFloatValue = ((Double) result5.get(0).getField(0)).floatValue();
+        assertThat(actualFloatValue).isEqualTo(11.7f);
+
+        // test double type
+        List<Row> result6 =
+                CollectionUtil.iteratorToList(
+                        tableEnv.executeSql("select sum(f) from test_sum").collect());
+        actualFloatValue = ((Double) result6.get(0).getField(0)).floatValue();
+        assertThat(actualFloatValue).isEqualTo(12.1f);
+
+        // test summing string and int types simultaneously
+        List<Row> result7 =
+                CollectionUtil.iteratorToList(
+                        tableEnv.executeSql("select sum(y), sum(z) from test_sum").collect());
+        assertThat(result7.toString()).isEqualTo("[+I[6.0, 10]]");
+
+        // test the unsupported timestamp type
+        assertThatThrownBy(
+                        () ->
+                                CollectionUtil.iteratorToList(
+                                        tableEnv.executeSql("select sum(ts) from test_sum")
+                                                .collect()))
+                .rootCause()
+                .satisfiesAnyOf(
+                        anyCauseMatches(
+                                "Native hive sum aggregate function does not support type: TIMESTAMP(9). "
+                                        + "Please set option 'table.exec.hive.native-agg-function.enabled' to false."));
+
+        tableEnv.executeSql("drop table test_sum");
+    }
+
+    @Test
+    public void testSumAggWithGroupKey() throws Exception {
+        tableEnv.executeSql(
+                "create table test_sum_group(name string, num bigint, price decimal(10,5))");
+        tableEnv.executeSql(
+                        "insert into test_sum_group values ('tom', 2, 7.2), ('tony', 2, 23.7), ('tom', 10, 3.33), ('tony', 4, 4.45), ('nadal', 4, 10.455)")
+                .await();
+
+        List<Row> result =
+                CollectionUtil.iteratorToList(
+                        tableEnv.executeSql(
+                                        "select name, sum(num), sum(price), sum(num * price) from test_sum_group group by name")
+                                .collect());
+        assertThat(result.toString())
+                .isEqualTo(
+                        "[+I[tom, 12, 10.53000, 47.70000], +I[tony, 6, 28.15000, 65.20000], +I[nadal, 4, 10.45500, 41.82000]]");
+
+        tableEnv.executeSql("drop table test_sum_group");
+    }
+
+ private String explainSql(String sql) {
+ return (String)
+                CollectionUtil.iteratorToList(tableEnv.executeSql("explain " + sql).collect())
+ .get(0)
+ .getField(0);
+ }
+
+    private static TableEnvironment getTableEnvWithHiveCatalog() {
+        TableEnvironment tableEnv = HiveTestUtils.createTableEnvInBatchMode(SqlDialect.HIVE);
+        tableEnv.registerCatalog(hiveCatalog.getName(), hiveCatalog);
+        tableEnv.useCatalog(hiveCatalog.getName());
+        // automatically load hive module in hive-compatible mode
+        HiveModule hiveModule =
+                new HiveModule(
+                        hiveCatalog.getHiveVersion(),
+                        tableEnv.getConfig(),
+                        Thread.currentThread().getContextClassLoader());
+        CoreModule coreModule = CoreModule.INSTANCE;
+        for (String loaded : tableEnv.listModules()) {
+            tableEnv.unloadModule(loaded);
+        }
+        tableEnv.loadModule("hive", hiveModule);
+        tableEnv.loadModule("core", coreModule);
+        return tableEnv;
+    }
+}
diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectQueryITCase.java b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectQueryITCase.java
index 2ffa4a7d44e..54951d47be7 100644
--- a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectQueryITCase.java
+++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectQueryITCase.java
@@ -986,91 +986,6 @@ public class HiveDialectQueryITCase {
}
}
-    @Test
-    public void testSumAggFunctionPlan() {
-        // test explain
-        String actualPlan = explainSql("select x, sum(y) from foo group by x");
-        assertThat(actualPlan).isEqualTo(readFromResource("/explain/testSumAggFunctionPlan.out"));
-    }
-
-    @Test
-    public void testSimpleSumAggFunction() throws Exception {
-        tableEnv.executeSql(
-                "create table test_sum(x string, y string, z int, d decimal(10,5), e float, f double)");
-        tableEnv.executeSql(
-                        "insert into test_sum values (NULL, '2', 1, 1.11, 1.2, 1.3), "
-                                + "(NULL, 'b', 2, 2.22, 2.3, 2.4), "
-                                + "(NULL, '4', 3, 3.33, 3.5, 3.6), "
-                                + "(NULL, NULL, 4, 4.45, 4.7, 4.8)")
-                .await();
-
-        // test sum with all elements are null
-        List<Row> result =
-                CollectionUtil.iteratorToList(
-                        tableEnv.executeSql("select sum(x) from test_sum").collect());
-        assertThat(result.toString()).isEqualTo("[+I[null]]");
-
-        // test sum string type with partial element can't convert to double, result type is double
-        List<Row> result2 =
-                CollectionUtil.iteratorToList(
-                        tableEnv.executeSql("select sum(y) from test_sum").collect());
-        assertThat(result2.toString()).isEqualTo("[+I[6.0]]");
-
-        // test decimal type
-        List<Row> result3 =
-                CollectionUtil.iteratorToList(
-                        tableEnv.executeSql("select sum(d) from test_sum").collect());
-        assertThat(result3.toString()).isEqualTo("[+I[11.11000]]");
-
-        // test sum int, result type is bigint
-        List<Row> result4 =
-                CollectionUtil.iteratorToList(
-                        tableEnv.executeSql("select sum(z) from test_sum").collect());
-        assertThat(result4.toString()).isEqualTo("[+I[10]]");
-
-        // test float type
-        List<Row> result5 =
-                CollectionUtil.iteratorToList(
-                        tableEnv.executeSql("select sum(e) from test_sum").collect());
-        float actualFloatValue = ((Double) result5.get(0).getField(0)).floatValue();
-        assertThat(actualFloatValue).isEqualTo(11.7f);
-
-        // test double type
-        List<Row> result6 =
-                CollectionUtil.iteratorToList(
-                        tableEnv.executeSql("select sum(f) from test_sum").collect());
-        actualFloatValue = ((Double) result6.get(0).getField(0)).floatValue();
-        assertThat(actualFloatValue).isEqualTo(12.1f);
-
-        // test sum string&int type simultaneously
-        List<Row> result7 =
-                CollectionUtil.iteratorToList(
-                        tableEnv.executeSql("select sum(y), sum(z) from test_sum").collect());
-        assertThat(result7.toString()).isEqualTo("[+I[6.0, 10]]");
-
-        tableEnv.executeSql("drop table test_sum");
-    }
-
-    @Test
-    public void testSumAggWithGroupKey() throws Exception {
-        tableEnv.executeSql(
-                "create table test_sum_group(name string, num bigint, price decimal(10,5))");
-        tableEnv.executeSql(
-                        "insert into test_sum_group values ('tom', 2, 7.2), ('tony', 2, 23.7), ('tom', 10, 3.33), ('tony', 4, 4.45), ('nadal', 4, 10.455)")
-                .await();
-
-        List<Row> result =
-                CollectionUtil.iteratorToList(
-                        tableEnv.executeSql(
-                                        "select name, sum(num), sum(price), sum(num * price) from test_sum_group group by name")
-                                .collect());
-        assertThat(result.toString())
-                .isEqualTo(
-                        "[+I[tom, 12, 10.53000, 47.70000], +I[tony, 6, 28.15000, 65.20000], +I[nadal, 4, 10.45500, 41.82000]]");
-
-        tableEnv.executeSql("drop table test_sum_group");
-    }
-
private void runQFile(File qfile) throws Exception {
QTest qTest = extractQTest(qfile);
for (int i = 0; i < qTest.statements.size(); i++) {
diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/endpoint/hive/HiveServer2EndpointITCase.java b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/endpoint/hive/HiveServer2EndpointITCase.java
index f2a8e5f72fa..98e8612a54a 100644
--- a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/endpoint/hive/HiveServer2EndpointITCase.java
+++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/table/endpoint/hive/HiveServer2EndpointITCase.java
@@ -107,6 +107,7 @@ import java.util.stream.Collectors;
 import static org.apache.flink.api.common.RuntimeExecutionMode.BATCH;
 import static org.apache.flink.configuration.ExecutionOptions.RUNTIME_MODE;
 import static org.apache.flink.configuration.PipelineOptionsInternal.PIPELINE_FIXED_JOB_ID;
+import static org.apache.flink.connectors.hive.HiveOptions.TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED;
 import static org.apache.flink.core.testutils.FlinkAssertions.anyCauseMatches;
 import static org.apache.flink.table.api.config.TableConfigOptions.MAX_LENGTH_GENERATED_CODE;
 import static org.apache.flink.table.api.config.TableConfigOptions.TABLE_DML_SYNC;
@@ -162,6 +163,8 @@ public class HiveServer2EndpointITCase extends TestLogger {
configs.put("set:system:ks", "vs");
configs.put("set:key1", "value1");
configs.put("set:hivevar:key2", "${hiveconf:common-key}");
+ // enable native hive agg function
+ configs.put(TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED.key(), "true");
openSessionReq.setConfiguration(configs);
TOpenSessionResp openSessionResp = client.OpenSession(openSessionReq);
SessionHandle sessionHandle =
@@ -177,7 +180,9 @@ public class HiveServer2EndpointITCase extends TestLogger {
                 new AbstractMap.SimpleEntry<>(RUNTIME_MODE.key(), BATCH.name()),
                 new AbstractMap.SimpleEntry<>(MAX_LENGTH_GENERATED_CODE.key(), "-1"),
                 new AbstractMap.SimpleEntry<>("key1", "value1"),
-                new AbstractMap.SimpleEntry<>("key2", "common-val"));
+                new AbstractMap.SimpleEntry<>("key2", "common-val"),
+                new AbstractMap.SimpleEntry<>(
+                        TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED.key(), "true"));
}
@Test
diff --git a/flink-connectors/flink-connector-hive/src/test/resources/endpoint/hive_module.q b/flink-connectors/flink-connector-hive/src/test/resources/endpoint/hive_module.q
index 4e9f0274c94..00dc445f68e 100644
--- a/flink-connectors/flink-connector-hive/src/test/resources/endpoint/hive_module.q
+++ b/flink-connectors/flink-connector-hive/src/test/resources/endpoint/hive_module.q
@@ -59,6 +59,71 @@ SELECT SUBSTRING_INDEX('www.apache.org', '.', 2) FROM (VALUES (1, 'Hello World')
1 row in set
!ok
+# ==========================================================================
+# test using the built-in native agg function of the hive module
+# ==========================================================================
+
+CREATE TABLE source (
+ a INT
+);
+!output
++--------+
+| result |
++--------+
+| OK |
++--------+
+1 row in set
+!ok
+
+EXPLAIN SELECT SUM(a) FROM source;
+!output
+== Abstract Syntax Tree ==
+LogicalProject(_o__c0=[$0])
++- LogicalAggregate(group=[{}], agg#0=[sum($0)])
+ +- LogicalProject($f0=[$0])
+ +- LogicalTableScan(table=[[hive, default, source]])
+
+== Optimized Physical Plan ==
+SortAggregate(isMerge=[false], select=[sum(a) AS $f0])
++- Exchange(distribution=[single])
+ +- TableSourceScan(table=[[hive, default, source]], fields=[a])
+
+== Optimized Execution Plan ==
+SortAggregate(isMerge=[false], select=[sum(a) AS $f0])
++- Exchange(distribution=[single])
+ +- TableSourceScan(table=[[hive, default, source]], fields=[a])
+!ok
+
+# enable the hive native agg function, which uses the hash-agg strategy
+SET table.exec.hive.native-agg-function.enabled = true;
+!output
++--------+
+| result |
++--------+
+| OK |
++--------+
+1 row in set
+!ok
+
+EXPLAIN SELECT SUM(a) FROM source;
+!output
+== Abstract Syntax Tree ==
+LogicalProject(_o__c0=[$0])
++- LogicalAggregate(group=[{}], agg#0=[sum($0)])
+ +- LogicalProject($f0=[$0])
+ +- LogicalTableScan(table=[[hive, default, source]])
+
+== Optimized Physical Plan ==
+HashAggregate(isMerge=[false], select=[sum(a) AS $f0])
++- Exchange(distribution=[single])
+ +- TableSourceScan(table=[[hive, default, source]], fields=[a])
+
+== Optimized Execution Plan ==
+HashAggregate(isMerge=[false], select=[sum(a) AS $f0])
++- Exchange(distribution=[single])
+ +- TableSourceScan(table=[[hive, default, source]], fields=[a])
+!ok
+
# load hive module with module name as string literal
LOAD MODULE 'hive';
!output
diff --git a/flink-connectors/flink-connector-hive/src/test/resources/explain/testSumAggFunctionFallbackPlan.out b/flink-connectors/flink-connector-hive/src/test/resources/explain/testSumAggFunctionFallbackPlan.out
new file mode 100644
index 00000000000..e297c060e46
--- /dev/null
+++ b/flink-connectors/flink-connector-hive/src/test/resources/explain/testSumAggFunctionFallbackPlan.out
@@ -0,0 +1,21 @@
+== Abstract Syntax Tree ==
+LogicalProject(x=[$0], _o__c1=[$1])
++- LogicalAggregate(group=[{0}], agg#0=[sum($1)])
+ +- LogicalProject($f0=[$0], $f1=[$1])
+ +- LogicalTableScan(table=[[test-catalog, default, foo]])
+
+== Optimized Physical Plan ==
+SortAggregate(isMerge=[true], groupBy=[x], select=[x, Final_sum($f1) AS $f1])
++- Sort(orderBy=[x ASC])
+ +- Exchange(distribution=[hash[x]])
+ +- LocalSortAggregate(groupBy=[x], select=[x, Partial_sum(y) AS $f1])
+ +- Sort(orderBy=[x ASC])
+            +- TableSourceScan(table=[[test-catalog, default, foo]], fields=[x, y])
+
+== Optimized Execution Plan ==
+SortAggregate(isMerge=[true], groupBy=[x], select=[x, Final_sum($f1) AS $f1])
++- Sort(orderBy=[x ASC])
+ +- Exchange(distribution=[hash[x]])
+ +- LocalSortAggregate(groupBy=[x], select=[x, Partial_sum(y) AS $f1])
+ +- Sort(orderBy=[x ASC])
+            +- TableSourceScan(table=[[test-catalog, default, foo]], fields=[x, y])
diff --git a/flink-table/flink-sql-gateway-api/src/main/java/org/apache/flink/table/gateway/api/session/SessionEnvironment.java b/flink-table/flink-sql-gateway-api/src/main/java/org/apache/flink/table/gateway/api/session/SessionEnvironment.java
index 8be8d265442..3e372f5681e 100644
--- a/flink-table/flink-sql-gateway-api/src/main/java/org/apache/flink/table/gateway/api/session/SessionEnvironment.java
+++ b/flink-table/flink-sql-gateway-api/src/main/java/org/apache/flink/table/gateway/api/session/SessionEnvironment.java
@@ -20,6 +20,7 @@ package org.apache.flink.table.gateway.api.session;
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.catalog.Catalog;
import org.apache.flink.table.gateway.api.endpoint.EndpointVersion;
@@ -40,8 +41,8 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
public class SessionEnvironment {
private final @Nullable String sessionName;
private final EndpointVersion version;
- private final Map<String, Catalog> registeredCatalogs;
- private final Map<String, Module> registeredModules;
+ private final Map<String, CatalogCreator> registeredCatalogCreators;
+ private final Map<String, ModuleCreator> registeredModuleCreators;
private final @Nullable String defaultCatalog;
private final Map<String, String> sessionConfig;
@@ -49,14 +50,14 @@ public class SessionEnvironment {
SessionEnvironment(
@Nullable String sessionName,
EndpointVersion version,
- Map<String, Catalog> registeredCatalogs,
- Map<String, Module> registeredModules,
+ Map<String, CatalogCreator> registeredCatalogCreators,
+ Map<String, ModuleCreator> registeredModuleCreators,
@Nullable String defaultCatalog,
Map<String, String> sessionConfig) {
this.sessionName = sessionName;
this.version = version;
- this.registeredCatalogs = registeredCatalogs;
- this.registeredModules = registeredModules;
+ this.registeredCatalogCreators = registeredCatalogCreators;
+ this.registeredModuleCreators = registeredModuleCreators;
this.defaultCatalog = defaultCatalog;
this.sessionConfig = sessionConfig;
}
@@ -77,12 +78,12 @@ public class SessionEnvironment {
return Collections.unmodifiableMap(sessionConfig);
}
- public Map<String, Catalog> getRegisteredCatalogs() {
- return Collections.unmodifiableMap(registeredCatalogs);
+ public Map<String, CatalogCreator> getRegisteredCatalogCreators() {
+ return Collections.unmodifiableMap(registeredCatalogCreators);
}
- public Map<String, Module> getRegisteredModules() {
- return Collections.unmodifiableMap(registeredModules);
+ public Map<String, ModuleCreator> getRegisteredModuleCreators() {
+ return Collections.unmodifiableMap(registeredModuleCreators);
}
public Optional<String> getDefaultCatalog() {
@@ -102,8 +103,8 @@ public class SessionEnvironment {
SessionEnvironment that = (SessionEnvironment) o;
return Objects.equals(sessionName, that.sessionName)
&& Objects.equals(version, that.version)
- && Objects.equals(registeredCatalogs, that.registeredCatalogs)
- && Objects.equals(registeredModules, that.registeredModules)
+                && Objects.equals(registeredCatalogCreators, that.registeredCatalogCreators)
+                && Objects.equals(registeredModuleCreators, that.registeredModuleCreators)
&& Objects.equals(defaultCatalog, that.defaultCatalog)
&& Objects.equals(sessionConfig, that.sessionConfig);
}
@@ -113,8 +114,8 @@ public class SessionEnvironment {
return Objects.hash(
sessionName,
version,
- registeredCatalogs,
- registeredModules,
+ registeredCatalogCreators,
+ registeredModuleCreators,
defaultCatalog,
sessionConfig);
}
@@ -133,8 +134,8 @@ public class SessionEnvironment {
private @Nullable String sessionName;
private EndpointVersion version;
private final Map<String, String> sessionConfig = new HashMap<>();
-        private final Map<String, Catalog> registeredCatalogs = new HashMap<>();
-        private final Map<String, Module> registeredModules = new HashMap<>();
+        private final Map<String, CatalogCreator> registeredCatalogCreators = new HashMap<>();
+        private final Map<String, ModuleCreator> registeredModuleCreators = new HashMap<>();
private @Nullable String defaultCatalog;
public Builder setSessionName(String sessionName) {
@@ -158,21 +159,41 @@
         }

         public Builder registerCatalog(String catalogName, Catalog catalog) {
-            if (registeredCatalogs.containsKey(catalogName)) {
+            if (registeredCatalogCreators.containsKey(catalogName)) {
                 throw new ValidationException(
                         String.format("A catalog with name '%s' already exists.", catalogName));
             }
-            this.registeredCatalogs.put(catalogName, catalog);
+            this.registeredCatalogCreators.put(
+                    catalogName, (configuration, classLoader) -> catalog);
+            return this;
+        }
+
+        public Builder registerCatalogCreator(String catalogName, CatalogCreator catalogCreator) {
+            if (registeredCatalogCreators.containsKey(catalogName)) {
+                throw new ValidationException(
+                        String.format("A catalog with name '%s' already exists.", catalogName));
+            }
+            this.registeredCatalogCreators.put(catalogName, catalogCreator);
             return this;
         }

         public Builder registerModuleAtHead(String moduleName, Module module) {
-            if (registeredModules.containsKey(moduleName)) {
+            if (registeredModuleCreators.containsKey(moduleName)) {
                 throw new ValidationException(
                         String.format("A module with name '%s' already exists", moduleName));
             }
-            this.registeredModules.put(moduleName, module);
+            this.registeredModuleCreators.put(moduleName, (configuration, classLoader) -> module);
+            return this;
+        }
+
+        public Builder registerModuleCreatorAtHead(String moduleName, ModuleCreator moduleCreator) {
+            if (registeredModuleCreators.containsKey(moduleName)) {
+                throw new ValidationException(
+                        String.format("A module with name '%s' already exists", moduleName));
+            }
+
+            this.registeredModuleCreators.put(moduleName, moduleCreator);
             return this;
         }
@@ -180,10 +201,34 @@ public class SessionEnvironment {
return new SessionEnvironment(
sessionName,
checkNotNull(version),
- registeredCatalogs,
- registeredModules,
+ registeredCatalogCreators,
+ registeredModuleCreators,
defaultCatalog,
sessionConfig);
}
}
+
+    /** An interface used to create {@link Module}. */
+    @PublicEvolving
+    public interface ModuleCreator {
+
+        /**
+         * @param configuration The read-only configuration with which the module is created.
+         * @param classLoader The class loader with which the module is created.
+         * @return The created module object.
+         */
+        Module create(ReadableConfig configuration, ClassLoader classLoader);
+    }
+
+    /** An interface used to create {@link Catalog}. */
+    @PublicEvolving
+    public interface CatalogCreator {
+
+        /**
+         * @param configuration The read-only configuration with which the catalog is created.
+         * @param classLoader The class loader with which the catalog is created.
+         * @return The created catalog object.
+         */
+        Catalog create(ReadableConfig configuration, ClassLoader classLoader);
+    }
 }
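A hedged sketch of the new creator hooks from the caller's side (mirroring the HiveServer2Endpoint hunk above; `sessionVersion` and `sessionConfig` are assumed to be in scope): the module is not instantiated at registration time but when the session context is built, so it observes the final session configuration:

    SessionEnvironment environment =
            SessionEnvironment.newBuilder()
                    .setSessionEndpointVersion(sessionVersion)
                    // Deferred creation: runs when SessionContext materializes the module.
                    .registerModuleCreatorAtHead(
                            "hive",
                            (readableConfig, classLoader) ->
                                    FactoryUtil.createModule(
                                            "hive",
                                            Collections.emptyMap(),
                                            readableConfig,
                                            classLoader))
                    .addSessionConfig(sessionConfig)
                    .build();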
diff --git a/flink-table/flink-sql-gateway/src/main/java/org/apache/flink/table/gateway/service/context/SessionContext.java b/flink-table/flink-sql-gateway/src/main/java/org/apache/flink/table/gateway/service/context/SessionContext.java
index d4da49115d2..121de1617a5 100644
--- a/flink-table/flink-sql-gateway/src/main/java/org/apache/flink/table/gateway/service/context/SessionContext.java
+++ b/flink-table/flink-sql-gateway/src/main/java/org/apache/flink/table/gateway/service/context/SessionContext.java
@@ -21,6 +21,7 @@ package org.apache.flink.table.gateway.service.context;
import org.apache.flink.configuration.ConfigOption;
import org.apache.flink.configuration.ConfigOptions;
import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.EnvironmentSettings;
import org.apache.flink.table.api.SqlDialect;
@@ -46,6 +47,7 @@ import org.apache.flink.table.gateway.api.utils.SqlGatewayException;
import org.apache.flink.table.gateway.service.operation.OperationExecutor;
import org.apache.flink.table.gateway.service.operation.OperationManager;
import org.apache.flink.table.gateway.service.utils.SqlExecutionException;
+import org.apache.flink.table.module.Module;
import org.apache.flink.table.module.ModuleManager;
import org.apache.flink.table.operations.ModifyOperation;
import org.apache.flink.table.resource.ResourceManager;
@@ -265,7 +267,8 @@ public class SessionContext {
         final ResourceManager resourceManager = new ResourceManager(configuration, userClassLoader);

-        final ModuleManager moduleManager = buildModuleManager(environment);
+        final ModuleManager moduleManager =
+                buildModuleManager(environment, configuration, userClassLoader);

         final CatalogManager catalogManager =
                 buildCatalogManager(configuration, userClassLoader, environment);
@@ -394,17 +397,21 @@
                     TableConfigOptions.RESOURCES_DOWNLOAD_DIR, path.toAbsolutePath().toString());
         }

-    private static ModuleManager buildModuleManager(SessionEnvironment environment) {
+    private static ModuleManager buildModuleManager(
+            SessionEnvironment environment,
+            ReadableConfig readableConfig,
+            ClassLoader classLoader) {
         final ModuleManager moduleManager = new ModuleManager();

         environment
-                .getRegisteredModules()
+                .getRegisteredModuleCreators()
                 .forEach(
-                        (moduleName, module) -> {
+                        (moduleName, moduleCreator) -> {
                             Deque<String> moduleNames =
                                     new ArrayDeque<>(moduleManager.listModules());
                             moduleNames.addFirst(moduleName);
+                            Module module = moduleCreator.create(readableConfig, classLoader);
                             moduleManager.loadModule(moduleName, module);
                             moduleManager.useModules(moduleNames.toArray(new String[0]));
                         });
@@ -427,13 +434,17 @@
         Catalog defaultCatalog;
         if (environment.getDefaultCatalog().isPresent()) {
             defaultCatalogName = environment.getDefaultCatalog().get();
-            defaultCatalog = environment.getRegisteredCatalogs().get(defaultCatalogName);
+            defaultCatalog =
+                    environment
+                            .getRegisteredCatalogCreators()
+                            .get(defaultCatalogName)
+                            .create(configuration, userClassLoader);
         } else {
             EnvironmentSettings settings =
                     EnvironmentSettings.newInstance().withConfiguration(configuration).build();
             defaultCatalogName = settings.getBuiltInCatalogName();

-            if (environment.getRegisteredCatalogs().containsKey(defaultCatalogName)) {
+            if (environment.getRegisteredCatalogCreators().containsKey(defaultCatalogName)) {
                 throw new SqlGatewayException(
                         String.format(
                                 "The name of the registered catalog conflicts with the built-in default catalog name: %s.",
@@ -451,11 +462,13 @@

         // filter the default catalog out to avoid repeated registration
         environment
-                .getRegisteredCatalogs()
+                .getRegisteredCatalogCreators()
                 .forEach(
-                        (catalogName, catalog) -> {
+                        (catalogName, catalogCreator) -> {
                             if (!catalogName.equals(defaultCatalogName)) {
-                                catalogManager.registerCatalog(catalogName, catalog);
+                                catalogManager.registerCatalog(
+                                        catalogName,
+                                        catalogCreator.create(configuration, userClassLoader));
                             }
                         });