This is an automated email from the ASF dual-hosted git repository.
eldenmoon pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
new 820b300300d branch-3.0: [Fix](ShortCircuit) fix prepared statement
with partial arguments prepared #45371 (#45465)
820b300300d is described below
commit 820b300300d5f63213fd6aa926b107afa3cb68c1
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Tue Dec 17 09:46:24 2024 +0800
branch-3.0: [Fix](ShortCircuit) fix prepared statement with partial
arguments prepared #45371 (#45465)
Cherry-picked from #45371
Co-authored-by: lihangyu <[email protected]>
---
.../org/apache/doris/nereids/StatementContext.java | 6 ++
.../nereids/rules/analysis/ExpressionAnalyzer.java | 21 +++++
.../org/apache/doris/qe/PointQueryExecutor.java | 40 +++++++---
.../data/point_query_p0/test_point_query.out | 30 +++++++
.../suites/point_query_p0/test_point_query.groovy | 92 +++++++++++++++++-----
5 files changed, 157 insertions(+), 32 deletions(-)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java
index 175b623467a..cd11b3228b9 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java
@@ -129,6 +129,8 @@ public class StatementContext implements Closeable {
private final IdGenerator<PlaceholderId> placeHolderIdGenerator =
PlaceholderId.createGenerator();
// relation id to placeholders for prepared statement, ordered by
placeholder id
private final Map<PlaceholderId, Expression> idToPlaceholderRealExpr = new
TreeMap<>();
+ // map placeholder id to comparison slot, which will used to replace
conjuncts directly
+ private final Map<PlaceholderId, SlotReference> idToComparisonSlot = new
TreeMap<>();
// collect all hash join conditions to compute node connectivity in join
graph
private final List<Expression> joinFilters = new ArrayList<>();
@@ -367,6 +369,10 @@ public class StatementContext implements Closeable {
return idToPlaceholderRealExpr;
}
+ public Map<PlaceholderId, SlotReference> getIdToComparisonSlot() {
+ return idToComparisonSlot;
+ }
+
public Map<CTEId, List<Pair<Multimap<Slot, Slot>, Group>>>
getCteIdToConsumerGroup() {
return cteIdToConsumerGroup;
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java
index 49789aa66e1..5ef3d0fbff3 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java
@@ -24,6 +24,7 @@ import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.common.DdlException;
import org.apache.doris.common.Pair;
import org.apache.doris.common.util.Util;
+import org.apache.doris.mysql.MysqlCommand;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.SqlCacheContext;
import org.apache.doris.nereids.StatementContext;
@@ -75,6 +76,7 @@ import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import
org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
+import org.apache.doris.nereids.trees.plans.PlaceholderId;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.types.ArrayType;
@@ -531,10 +533,29 @@ public class ExpressionAnalyzer extends SubExprAnalyzer<ExpressionRewriteContext
return visit(realExpr, context);
}
+ // Register prepared statement placeholder id to related slot in
comparison predicate.
+ // Used to replace expression in ShortCircuit plan
+ private void registerPlaceholderIdToSlot(ComparisonPredicate cp,
+ ExpressionRewriteContext context, Expression left,
Expression right) {
+ if (ConnectContext.get() != null
+ && ConnectContext.get().getCommand() ==
MysqlCommand.COM_STMT_EXECUTE) {
+ // Used to replace expression in ShortCircuit plan
+ if (cp.right() instanceof Placeholder && left instanceof
SlotReference) {
+ PlaceholderId id = ((Placeholder)
cp.right()).getPlaceholderId();
+
context.cascadesContext.getStatementContext().getIdToComparisonSlot().put(id,
(SlotReference) left);
+ } else if (cp.left() instanceof Placeholder && right instanceof
SlotReference) {
+ PlaceholderId id = ((Placeholder)
cp.left()).getPlaceholderId();
+
context.cascadesContext.getStatementContext().getIdToComparisonSlot().put(id,
(SlotReference) right);
+ }
+ }
+ }
+
@Override
public Expression visitComparisonPredicate(ComparisonPredicate cp,
ExpressionRewriteContext context) {
Expression left = cp.left().accept(this, context);
Expression right = cp.right().accept(this, context);
+ // Used to replace expression in ShortCircuit plan
+ registerPlaceholderIdToSlot(cp, context, left, right);
cp = (ComparisonPredicate) cp.withChildren(left, right);
return TypeCoercionUtils.processComparisonPredicate(cp);
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/PointQueryExecutor.java b/fe/fe-core/src/main/java/org/apache/doris/qe/PointQueryExecutor.java
index 9e4030b768b..b1bf3e227f0 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/PointQueryExecutor.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/PointQueryExecutor.java
@@ -31,7 +31,9 @@ import org.apache.doris.common.UserException;
import org.apache.doris.mysql.MysqlCommand;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.plans.PlaceholderId;
import org.apache.doris.planner.OlapScanNode;
import org.apache.doris.proto.InternalService;
import org.apache.doris.proto.InternalService.KeyTuple;
@@ -59,12 +61,12 @@ import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
+import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
-import java.util.stream.Collectors;
public class PointQueryExecutor implements CoordInterface {
private static final Logger LOG =
LogManager.getLogger(PointQueryExecutor.class);
@@ -142,33 +144,45 @@ public class PointQueryExecutor implements CoordInterface {
Preconditions.checkNotNull(preparedStmtCtx.shortCircuitQueryContext);
ShortCircuitQueryContext shortCircuitQueryContext =
preparedStmtCtx.shortCircuitQueryContext.get();
// update conjuncts
- List<Expr> conjunctVals =
statementContext.getIdToPlaceholderRealExpr().values().stream().map(
- expression -> (
- (Literal) expression).toLegacyLiteral())
- .collect(Collectors.toList());
- if (conjunctVals.size() != preparedStmtCtx.command.placeholderCount())
{
+ Map<String, Expr> colNameToConjunct = Maps.newHashMap();
+ for (Entry<PlaceholderId, SlotReference> entry :
statementContext.getIdToComparisonSlot().entrySet()) {
+ String colName = entry.getValue().getColumn().get().getName();
+ Expr conjunctVal = ((Literal)
statementContext.getIdToPlaceholderRealExpr()
+ .get(entry.getKey())).toLegacyLiteral();
+ colNameToConjunct.put(colName, conjunctVal);
+ }
+ if (colNameToConjunct.size() !=
preparedStmtCtx.command.placeholderCount()) {
throw new AnalysisException("Mismatched conjuncts values size with
prepared"
+ "statement parameters size, expected "
+ preparedStmtCtx.command.placeholderCount()
- + ", but meet " + conjunctVals.size());
+ + ", but meet " + colNameToConjunct.size());
}
- updateScanNodeConjuncts(shortCircuitQueryContext.scanNode,
conjunctVals);
+ updateScanNodeConjuncts(shortCircuitQueryContext.scanNode,
colNameToConjunct);
// short circuit plan and execution
executor.executeAndSendResult(false, false,
shortCircuitQueryContext.analzyedQuery, executor.getContext()
.getMysqlChannel(), null, null);
}
- private static void updateScanNodeConjuncts(OlapScanNode scanNode,
List<Expr> conjunctVals) {
- for (int i = 0; i < conjunctVals.size(); ++i) {
- BinaryPredicate binaryPredicate = (BinaryPredicate)
scanNode.getConjuncts().get(i);
+ private static void updateScanNodeConjuncts(OlapScanNode scanNode,
+ Map<String, Expr> colNameToConjunct) {
+ for (Expr conjunct : scanNode.getConjuncts()) {
+ BinaryPredicate binaryPredicate = (BinaryPredicate) conjunct;
+ SlotRef slot = null;
+ int updateChildIdx = 0;
if (binaryPredicate.getChild(0) instanceof LiteralExpr) {
- binaryPredicate.setChild(0, conjunctVals.get(i));
+ slot = (SlotRef) binaryPredicate.getChildWithoutCast(1);
} else if (binaryPredicate.getChild(1) instanceof LiteralExpr) {
- binaryPredicate.setChild(1, conjunctVals.get(i));
+ slot = (SlotRef) binaryPredicate.getChildWithoutCast(0);
+ updateChildIdx = 1;
} else {
Preconditions.checkState(false, "Should contains literal in "
+ binaryPredicate.toSqlImpl());
}
+ // not a placeholder to replace
+ if (!colNameToConjunct.containsKey(slot.getColumnName())) {
+ continue;
+ }
+ binaryPredicate.setChild(updateChildIdx,
colNameToConjunct.get(slot.getColumnName()));
}
}
diff --git a/regression-test/data/point_query_p0/test_point_query.out b/regression-test/data/point_query_p0/test_point_query.out
index 1cc4142e39f..55c79757820 100644
--- a/regression-test/data/point_query_p0/test_point_query.out
+++ b/regression-test/data/point_query_p0/test_point_query.out
@@ -160,3 +160,33 @@
-- !sql --
-10 20 aabc update val
+-- !point_select --
+user_guid feature sk feature_value 2021-01-01T00:00
+
+-- !point_select --
+user_guid feature sk feature_value 2021-01-01T00:00
+
+-- !point_select --
+user_guid feature sk feature_value 2021-01-01T00:00
+
+-- !point_select --
+user_guid feature sk feature_value 2021-01-01T00:00
+
+-- !point_select --
+user_guid feature sk feature_value 2021-01-01T00:00
+
+-- !point_select --
+user_guid feature sk feature_value 2021-01-01T00:00
+
+-- !point_select --
+user_guid feature sk feature_value 2021-01-01T00:00
+
+-- !point_select --
+user_guid feature sk feature_value 2021-01-01T00:00
+
+-- !point_select --
+user_guid feature sk feature_value 2021-01-01T00:00
+
+-- !point_select --
+user_guid feature sk feature_value 2021-01-01T00:00
+
diff --git a/regression-test/suites/point_query_p0/test_point_query.groovy b/regression-test/suites/point_query_p0/test_point_query.groovy
index f84012a8fd7..0ea879956e3 100644
--- a/regression-test/suites/point_query_p0/test_point_query.groovy
+++ b/regression-test/suites/point_query_p0/test_point_query.groovy
@@ -27,32 +27,30 @@ suite("test_point_query", "nonConcurrent") {
logger.info("update config: code=" + code + ", out=" + out + ",
err=" + err)
}
}
+ def user = context.config.jdbcUser
+ def password = context.config.jdbcPassword
+ def realDb = "regression_test_serving_p0"
+ // Parse url
+ String jdbcUrl = context.config.jdbcUrl
+ String urlWithoutSchema = jdbcUrl.substring(jdbcUrl.indexOf("://") + 3)
+ def sql_ip = urlWithoutSchema.substring(0, urlWithoutSchema.indexOf(":"))
+ def sql_port
+ if (urlWithoutSchema.indexOf("/") >= 0) {
+ // e.g: jdbc:mysql://locahost:8080/?a=b
+ sql_port = urlWithoutSchema.substring(urlWithoutSchema.indexOf(":") +
1, urlWithoutSchema.indexOf("/"))
+ } else {
+ // e.g: jdbc:mysql://locahost:8080
+ sql_port = urlWithoutSchema.substring(urlWithoutSchema.indexOf(":") +
1)
+ }
+ // set server side prepared statement url
+ def prepare_url = "jdbc:mysql://" + sql_ip + ":" + sql_port + "/" + realDb
+ "?&useServerPrepStmts=true"
try {
set_be_config.call("disable_storage_row_cache", "false")
- // nereids do not support point query now
sql "set global enable_fallback_to_original_planner = false"
sql """set global enable_nereids_planner=true"""
- def user = context.config.jdbcUser
- def password = context.config.jdbcPassword
- def realDb = "regression_test_serving_p0"
def tableName = realDb + ".tbl_point_query"
sql "CREATE DATABASE IF NOT EXISTS ${realDb}"
- // Parse url
- String jdbcUrl = context.config.jdbcUrl
- String urlWithoutSchema = jdbcUrl.substring(jdbcUrl.indexOf("://") + 3)
- def sql_ip = urlWithoutSchema.substring(0,
urlWithoutSchema.indexOf(":"))
- def sql_port
- if (urlWithoutSchema.indexOf("/") >= 0) {
- // e.g: jdbc:mysql://locahost:8080/?a=b
- sql_port =
urlWithoutSchema.substring(urlWithoutSchema.indexOf(":") + 1,
urlWithoutSchema.indexOf("/"))
- } else {
- // e.g: jdbc:mysql://locahost:8080
- sql_port =
urlWithoutSchema.substring(urlWithoutSchema.indexOf(":") + 1)
- }
- // set server side prepared statement url
- def prepare_url = "jdbc:mysql://" + sql_ip + ":" + sql_port + "/" +
realDb + "?&useServerPrepStmts=true"
-
def generateString = {len ->
def str = ""
for (int i = 0; i < len; i++) {
@@ -330,4 +328,60 @@ suite("test_point_query", "nonConcurrent") {
qt_sql "select * from table_3821461 where col1 = 10 and col2 = 20 and loc3
= 'aabc';"
sql "update table_3821461 set value = 'update value' where col1 = -10 or
col1 = 20;"
qt_sql """select * from table_3821461 where col1 = -10 and col2 = 20 and
loc3 = 'aabc'"""
+
+ sql "DROP TABLE IF EXISTS test_partial_prepared_statement"
+ sql """
+ CREATE TABLE `test_partial_prepared_statement` (
+ `user_guid` varchar(64) NOT NULL,
+ `feature` varchar(256) NOT NULL,
+ `sk` varchar(256) NOT NULL,
+ `feature_value` text NULL,
+ `data_time` datetime NOT NULL
+ ) ENGINE=OLAP
+ UNIQUE KEY(`user_guid`, `feature`, `sk`)
+ DISTRIBUTED BY HASH(`user_guid`) BUCKETS 32
+ PROPERTIES (
+ "enable_unique_key_merge_on_write" = "true",
+ "light_schema_change" = "true",
+ "function_column.sequence_col" = "data_time",
+ "store_row_column" = "true",
+ "replication_num" = "1",
+ "row_store_page_size" = "16384"
+ );
+ """
+ sql "insert into test_partial_prepared_statement values ('user_guid',
'feature', 'sk','feature_value', '2021-01-01 00:00:00')"
+ def result2 = connect(user, password, prepare_url) {
+ def partial_prepared_stmt = prepareStatement "select /*+
SET_VAR(enable_nereids_planner=true) */ * from
regression_test_point_query_p0.test_partial_prepared_statement where sk = 'sk'
and user_guid = 'user_guid' and feature = ? "
+ assertEquals(partial_prepared_stmt.class,
com.mysql.cj.jdbc.ServerPreparedStatement);
+ partial_prepared_stmt.setString(1, "feature")
+ qe_point_select partial_prepared_stmt
+ qe_point_select partial_prepared_stmt
+
+ partial_prepared_stmt = prepareStatement "select /*+
SET_VAR(enable_nereids_planner=true) */ * from
regression_test_point_query_p0.test_partial_prepared_statement where user_guid
= ? and feature = 'feature' and sk = ?"
+ assertEquals(partial_prepared_stmt.class,
com.mysql.cj.jdbc.ServerPreparedStatement);
+ partial_prepared_stmt.setString(1, "user_guid")
+ partial_prepared_stmt.setString(2, "sk")
+ qe_point_select partial_prepared_stmt
+ qe_point_select partial_prepared_stmt
+
+ partial_prepared_stmt = prepareStatement "select /*+
SET_VAR(enable_nereids_planner=true) */ * from
regression_test_point_query_p0.test_partial_prepared_statement where ? =
user_guid and sk = 'sk' and feature = 'feature' "
+ assertEquals(partial_prepared_stmt.class,
com.mysql.cj.jdbc.ServerPreparedStatement);
+ partial_prepared_stmt.setString(1, "user_guid")
+ qe_point_select partial_prepared_stmt
+ qe_point_select partial_prepared_stmt
+
+ partial_prepared_stmt = prepareStatement "select /*+
SET_VAR(enable_nereids_planner=true) */ * from
regression_test_point_query_p0.test_partial_prepared_statement where ? =
user_guid and sk = 'sk' and feature = ? "
+ assertEquals(partial_prepared_stmt.class,
com.mysql.cj.jdbc.ServerPreparedStatement);
+ partial_prepared_stmt.setString(1, "user_guid")
+ partial_prepared_stmt.setString(2, "feature")
+ qe_point_select partial_prepared_stmt
+ qe_point_select partial_prepared_stmt
+
+ partial_prepared_stmt = prepareStatement "select /*+
SET_VAR(enable_nereids_planner=true) */ * from
regression_test_point_query_p0.test_partial_prepared_statement where sk = ?
and feature = ? and 'user_guid' = user_guid"
+ assertEquals(partial_prepared_stmt.class,
com.mysql.cj.jdbc.ServerPreparedStatement);
+ partial_prepared_stmt.setString(1, "sk")
+ partial_prepared_stmt.setString(2, "feature")
+ qe_point_select partial_prepared_stmt
+ qe_point_select partial_prepared_stmt
+ }
}
\ No newline at end of file
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]