This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-3.1 by this push:
new 7fe4cf924de branch-3.1: [fix](prepare statement)Support FE execute
COM_STMT_EXECUTE show command. #54446 (#54876)
7fe4cf924de is described below
commit 7fe4cf924de68c42dd1cfd991e596d6493f4e4e5
Author: James <[email protected]>
AuthorDate: Tue Aug 19 10:25:37 2025 +0800
branch-3.1: [fix](prepare statement)Support FE execute COM_STMT_EXECUTE
show command. #54446 (#54876)
backport: #54446
---
.../org/apache/doris/qe/MysqlConnectProcessor.java | 2 +-
.../java/org/apache/doris/qe/StmtExecutor.java | 97 +++++++++++++++++---
.../java/org/apache/doris/qe/StmtExecutorTest.java | 100 +++++++++++++++++++++
.../data/prepared_stmt_p0/prepared_show.out | Bin 0 -> 242 bytes
.../suites/prepared_stmt_p0/prepared_show.groovy | 54 +++++++++++
5 files changed, 242 insertions(+), 11 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/qe/MysqlConnectProcessor.java
b/fe/fe-core/src/main/java/org/apache/doris/qe/MysqlConnectProcessor.java
index 50990a753c3..be5e437da0a 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/MysqlConnectProcessor.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/MysqlConnectProcessor.java
@@ -164,7 +164,7 @@ public class MysqlConnectProcessor extends ConnectProcessor
{
}
StatementBase stmt = new LogicalPlanAdapter(executeStmt,
statementContext);
stmt.setOrigStmt(prepareCommand.getOriginalStmt());
- executor = new StmtExecutor(ctx, stmt);
+ executor = new StmtExecutor(ctx, stmt, true);
ctx.setExecutor(executor);
if (null != queryId) {
executor.execute(queryId);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java
b/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java
index 88bdb03c294..f1ebd8a3156 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/StmtExecutor.java
@@ -149,6 +149,7 @@ import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.trees.expressions.Placeholder;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal;
import org.apache.doris.nereids.trees.plans.commands.Command;
import org.apache.doris.nereids.trees.plans.commands.CreatePolicyCommand;
import org.apache.doris.nereids.trees.plans.commands.CreateTableCommand;
@@ -290,6 +291,8 @@ public class StmtExecutor {
// The profile of this execution
private final Profile profile;
private Boolean isForwardedToMaster = null;
+ // True when this executor was created for a COM_STMT_EXECUTE (prepared statement) request;
+ // result rows produced on the FE are then sent in the MySQL binary result-set row format
+ // instead of the text format.
+ private boolean isComStmtExecute = false;
// The result schema if "dry_run_query" is true.
// Only one column to indicate the real return row numbers.
@@ -322,9 +325,14 @@ public class StmtExecutor {
// constructor for receiving parsed stmt from connect processor
public StmtExecutor(ConnectContext ctx, StatementBase parsedStmt) {
+ // Keeps the pre-existing behaviour: isComStmtExecute = false, so FE-side result rows are sent
+ // with the text protocol. Prepared-statement execution uses the three-argument constructor.
+ this(ctx, parsedStmt, false);
+ }
+
+ public StmtExecutor(ConnectContext ctx, StatementBase parsedStmt, boolean
isComStmtExecute) {
this.context = ctx;
this.parsedStmt = parsedStmt;
this.originStmt = parsedStmt.getOrigStmt();
+ this.isComStmtExecute = isComStmtExecute;
if (context.getConnectType() == ConnectType.MYSQL) {
this.serializer = context.getMysqlChannel().getSerializer();
} else {
@@ -2898,18 +2906,14 @@ public class StmtExecutor {
sendMetaData(resultSet.getMetaData(), fieldInfos);
// Send result set.
- for (List<String> row : resultSet.getResultRows()) {
- serializer.reset();
- for (String item : row) {
- if (item == null || item.equals(FeConstants.null_string)) {
- serializer.writeNull();
- } else {
- serializer.writeLenEncodedString(item);
- }
+ if (isComStmtExecute) {
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Use binary protocol to set result.");
}
-
context.getMysqlChannel().sendOnePacket(serializer.toByteBuffer());
+ sendBinaryResultRow(resultSet);
+ } else {
+ sendTextResultRow(resultSet);
}
-
context.getState().setEof();
} else if
(context.getConnectType().equals(ConnectType.ARROW_FLIGHT_SQL)) {
context.updateReturnRows(resultSet.getResultRows().size());
@@ -2921,6 +2925,79 @@ public class StmtExecutor {
}
}
+    /**
+     * Sends every row of {@code resultSet} in the MySQL text result-set row format, one packet per
+     * row: each cell is written as a length-encoded string, except that a cell which is
+     * {@code null} or equal to {@link FeConstants#null_string} is written with the NULL marker.
+     *
+     * @param resultSet rows to send
+     * @throws IOException propagated from the serializer or the MySQL channel
+     */
+    protected void sendTextResultRow(ResultSet resultSet) throws IOException {
+        for (List<String> cells : resultSet.getResultRows()) {
+            serializer.reset();
+            for (String cell : cells) {
+                boolean isNullCell = cell == null || cell.equals(FeConstants.null_string);
+                if (isNullCell) {
+                    serializer.writeNull();
+                    continue;
+                }
+                serializer.writeLenEncodedString(cell);
+            }
+            context.getMysqlChannel().sendOnePacket(serializer.toByteBuffer());
+        }
+    }
+
+    /**
+     * Sends every row of {@code resultSet} in the MySQL binary result-set row format, which is what
+     * a client expects in reply to COM_STMT_EXECUTE.
+     *
+     * <p>Each row packet is a 0x00 header byte, a NULL bitmap whose first two bits are reserved,
+     * and then the non-NULL values in column order. A cell counts as NULL when it is {@code null}
+     * or equals {@link FeConstants#null_string}.
+     *
+     * @param resultSet rows to send; cell {@code i} of a row is encoded according to column
+     *                  {@code i} of the result-set metadata
+     * @throws IOException propagated from the serializer or the MySQL channel
+     */
+    protected void sendBinaryResultRow(ResultSet resultSet) throws IOException {
+        // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_binary_resultset.html#sect_protocol_binary_resultset_row_value
+        ResultSetMetaData metaData = resultSet.getMetaData();
+        // Two reserved bits precede the per-column NULL bits; round up to whole bytes.
+        int nullBitmapLength = (metaData.getColumnCount() + 7 + 2) / 8;
+        for (List<String> row : resultSet.getResultRows()) {
+            serializer.reset();
+            // Packet header, always 0x00 for a binary row.
+            serializer.writeByte((byte) 0x00);
+            byte[] nullBitmap = new byte[nullBitmapLength];
+            for (int i = 0; i < row.size(); i++) {
+                String item = row.get(i);
+                if (item == null || item.equals(FeConstants.null_string)) {
+                    int bitPos = i + 2; // skip the two reserved bits
+                    nullBitmap[bitPos / 8] |= (1 << (bitPos % 8));
+                }
+            }
+            serializer.writeBytes(nullBitmap);
+            // NULL cells are represented only by the bitmap; emit the remaining values in order.
+            for (int i = 0; i < row.size(); i++) {
+                String item = row.get(i);
+                if (item != null && !item.equals(FeConstants.null_string)) {
+                    writeBinaryColumnValue(metaData.getColumn(i), item);
+                }
+            }
+            context.getMysqlChannel().sendOnePacket(serializer.toByteBuffer());
+        }
+    }
+
+    /**
+     * Writes one non-NULL cell using the binary-protocol encoding that matches the column's type.
+     *
+     * <p>In the binary protocol the client decodes fixed-width numeric values and DATE/DATETIME
+     * values according to the column type announced in the metadata, so writing such a value as a
+     * length-encoded string misaligns every later cell of the row. All other types (strings,
+     * decimals, ...) are written as length-encoded strings.
+     * NOTE(review): TIME/TIMEV2 columns still fall through to the string encoding; confirm no
+     * FE-side result set declares them before relying on this path for such columns.
+     *
+     * @param col  metadata column describing {@code item}
+     * @param item textual cell value, never NULL here
+     * @throws NumberFormatException if a numeric cell does not parse as its declared type
+     */
+    private void writeBinaryColumnValue(Column col, String item) throws IOException {
+        switch (col.getType().getPrimitiveType()) {
+            case BOOLEAN:
+                // Sent as a one-byte TINY value; accept both "true"/"false" and numeric text.
+                if ("true".equalsIgnoreCase(item)) {
+                    serializer.writeInt1(1);
+                } else if ("false".equalsIgnoreCase(item)) {
+                    serializer.writeInt1(0);
+                } else {
+                    serializer.writeInt1(Integer.parseInt(item));
+                }
+                break;
+            case TINYINT:
+                serializer.writeInt1(Byte.parseByte(item));
+                break;
+            case SMALLINT:
+                serializer.writeInt2(Short.parseShort(item));
+                break;
+            case INT:
+                serializer.writeInt4(Integer.parseInt(item));
+                break;
+            case BIGINT:
+                serializer.writeInt8(Long.parseLong(item));
+                break;
+            case FLOAT:
+                // IEEE 754 bits, written little-endian like the integer types above.
+                serializer.writeInt4(Float.floatToIntBits(Float.parseFloat(item)));
+                break;
+            case DOUBLE:
+                serializer.writeInt8(Double.doubleToLongBits(Double.parseDouble(item)));
+                break;
+            case DATE:
+            case DATEV2: {
+                // Binary DATE: length 4, then year (2 bytes), month, day.
+                java.time.LocalDate date = java.time.LocalDate.parse(item);
+                serializer.writeInt1(4);
+                serializer.writeInt2(date.getYear());
+                serializer.writeInt1(date.getMonthValue());
+                serializer.writeInt1(date.getDayOfMonth());
+                break;
+            }
+            case DATETIME:
+            case DATETIMEV2: {
+                // Binary DATETIME: length 7 (no microseconds) or 11 (with a 4-byte microsecond field).
+                DateTimeV2Literal datetime = new DateTimeV2Literal(item);
+                long microSecond = datetime.getMicroSecond();
+                // Decide once so the announced length always matches the bytes written.
+                boolean withMicros = microSecond != 0;
+                serializer.writeInt1(withMicros ? 11 : 7);
+                serializer.writeInt2((int) datetime.getYear());
+                serializer.writeInt1((int) datetime.getMonth());
+                serializer.writeInt1((int) datetime.getDay());
+                serializer.writeInt1((int) datetime.getHour());
+                serializer.writeInt1((int) datetime.getMinute());
+                serializer.writeInt1((int) datetime.getSecond());
+                if (withMicros) {
+                    serializer.writeInt4((int) microSecond);
+                }
+                break;
+            }
+            default:
+                serializer.writeLenEncodedString(item);
+                break;
+        }
+    }
+
// Process show statement
private void handleShow() throws IOException, AnalysisException,
DdlException {
ShowExecutor executor = new ShowExecutor(context, (ShowStmt)
parsedStmt);
diff --git a/fe/fe-core/src/test/java/org/apache/doris/qe/StmtExecutorTest.java
b/fe/fe-core/src/test/java/org/apache/doris/qe/StmtExecutorTest.java
index efa3fe2eb4b..46e5a10031b 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/qe/StmtExecutorTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/qe/StmtExecutorTest.java
@@ -17,14 +17,27 @@
package org.apache.doris.qe;
+import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.InternalSchemaInitializer;
+import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.common.Config;
import org.apache.doris.common.FeConstants;
+import org.apache.doris.mysql.MysqlChannel;
+import org.apache.doris.mysql.MysqlSerializer;
+import org.apache.doris.qe.CommonResultSet.CommonResultSetMetaData;
+import org.apache.doris.qe.ConnectContext.ConnectType;
import org.apache.doris.qe.QueryState.MysqlStateType;
import org.apache.doris.utframe.TestWithFeService;
+import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
public class StmtExecutorTest extends TestWithFeService {
@@ -176,4 +189,91 @@ public class StmtExecutorTest extends TestWithFeService {
Assertions.assertEquals(QueryState.MysqlStateType.OK,
connectContext.getState().getStateType());
}
+    @Test
+    public void testSendTextResultRow() throws IOException {
+        ConnectContext mockCtx = Mockito.mock(ConnectContext.class);
+        MysqlChannel channel = Mockito.mock(MysqlChannel.class);
+        Mockito.when(mockCtx.getConnectType()).thenReturn(ConnectType.MYSQL);
+        Mockito.when(mockCtx.getMysqlChannel()).thenReturn(channel);
+        MysqlSerializer mysqlSerializer = MysqlSerializer.newInstance();
+        Mockito.when(channel.getSerializer()).thenReturn(mysqlSerializer);
+        SessionVariable sessionVariable = VariableMgr.newSessionVariable();
+        Mockito.when(mockCtx.getSessionVariable()).thenReturn(sessionVariable);
+        OriginStatement stmt = new OriginStatement("", 1);
+
+        List<List<String>> rows = Lists.newArrayList();
+        List<String> row1 = Lists.newArrayList();
+        row1.add(null);
+        row1.add("row1");
+        List<String> row2 = Lists.newArrayList();
+        row2.add("1234");
+        row2.add("row2");
+        rows.add(row1);
+        rows.add(row2);
+        List<Column> columns = Lists.newArrayList();
+        columns.add(new Column());
+        columns.add(new Column());
+        ResultSet resultSet = new CommonResultSet(new CommonResultSetMetaData(columns), rows);
+        // One packet per row: NULL is the 0xFB (-5) marker, other cells are length-encoded strings.
+        byte[][] expectedPackets = {
+                {-5, 4, 114, 111, 119, 49},
+                {4, 49, 50, 51, 52, 4, 114, 111, 119, 50}
+        };
+        AtomicInteger sent = new AtomicInteger();
+        Mockito.doAnswer(invocation -> {
+            int index = sent.getAndIncrement();
+            Assertions.assertTrue(index < expectedPackets.length, "unexpected extra packet #" + index);
+            ByteBuffer buffer = invocation.getArgument(0);
+            Assertions.assertArrayEquals(expectedPackets[index], buffer.array());
+            return null;
+        }).when(channel).sendOnePacket(Mockito.any(ByteBuffer.class));
+
+        StmtExecutor executor = new StmtExecutor(mockCtx, stmt, false);
+        executor.sendTextResultRow(resultSet);
+        // Without this check the test would pass vacuously if no packet were sent at all.
+        Assertions.assertEquals(expectedPackets.length, sent.get());
+    }
+
+    @Test
+    public void testSendBinaryResultRow() throws IOException {
+        ConnectContext mockCtx = Mockito.mock(ConnectContext.class);
+        MysqlChannel channel = Mockito.mock(MysqlChannel.class);
+        Mockito.when(mockCtx.getConnectType()).thenReturn(ConnectType.MYSQL);
+        Mockito.when(mockCtx.getMysqlChannel()).thenReturn(channel);
+        MysqlSerializer mysqlSerializer = MysqlSerializer.newInstance();
+        Mockito.when(channel.getSerializer()).thenReturn(mysqlSerializer);
+        SessionVariable sessionVariable = VariableMgr.newSessionVariable();
+        Mockito.when(mockCtx.getSessionVariable()).thenReturn(sessionVariable);
+        OriginStatement stmt = new OriginStatement("", 1);
+
+        List<List<String>> rows = Lists.newArrayList();
+        List<String> row1 = Lists.newArrayList();
+        row1.add(null);
+        row1.add("2025-01-01 01:02:03");
+        List<String> row2 = Lists.newArrayList();
+        row2.add("1234");
+        row2.add("2025-01-01 01:02:03.123456");
+        rows.add(row1);
+        rows.add(row2);
+        List<Column> columns = Lists.newArrayList();
+        columns.add(new Column("col1", PrimitiveType.BIGINT));
+        columns.add(new Column("col2", PrimitiveType.DATETIMEV2));
+        ResultSet resultSet = new CommonResultSet(new CommonResultSetMetaData(columns), rows);
+        // Row 1: header 0, NULL bitmap with column 0 at bit 2 (=4), DATETIME length 7
+        //        (year 2025 little-endian = -23, 7).
+        // Row 2: header 0, empty bitmap, BIGINT 1234 as 8 little-endian bytes, DATETIME length 11
+        //        with microseconds 123456 little-endian (64, -30, 1, 0).
+        byte[][] expectedPackets = {
+                {0, 4, 7, -23, 7, 1, 1, 1, 2, 3},
+                {0, 0, -46, 4, 0, 0, 0, 0, 0, 0, 11, -23, 7, 1, 1, 1, 2, 3, 64, -30, 1, 0}
+        };
+        AtomicInteger sent = new AtomicInteger();
+        Mockito.doAnswer(invocation -> {
+            int index = sent.getAndIncrement();
+            Assertions.assertTrue(index < expectedPackets.length, "unexpected extra packet #" + index);
+            ByteBuffer buffer = invocation.getArgument(0);
+            Assertions.assertArrayEquals(expectedPackets[index], buffer.array());
+            return null;
+        }).when(channel).sendOnePacket(Mockito.any(ByteBuffer.class));
+
+        StmtExecutor executor = new StmtExecutor(mockCtx, stmt, false);
+        executor.sendBinaryResultRow(resultSet);
+        // Without this check the test would pass vacuously if no packet were sent at all.
+        Assertions.assertEquals(expectedPackets.length, sent.get());
+    }
}
diff --git a/regression-test/data/prepared_stmt_p0/prepared_show.out
b/regression-test/data/prepared_stmt_p0/prepared_show.out
new file mode 100644
index 00000000000..403246ea021
Binary files /dev/null and
b/regression-test/data/prepared_stmt_p0/prepared_show.out differ
diff --git a/regression-test/suites/prepared_stmt_p0/prepared_show.groovy
b/regression-test/suites/prepared_stmt_p0/prepared_show.groovy
new file mode 100644
index 00000000000..e276e7c401b
--- /dev/null
+++ b/regression-test/suites/prepared_stmt_p0/prepared_show.groovy
@@ -0,0 +1,54 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("prepared_show") {
+    def user = context.config.jdbcUser
+    def password = context.config.jdbcPassword
+    sql """drop database if exists prepared_show"""
+    sql """create database prepared_show"""
+    sql """use prepared_show"""
+    sql """CREATE TABLE IF NOT EXISTS prepared_show_table1 (`k1` tinyint NULL)
+        ENGINE=OLAP
+        DUPLICATE KEY(`k1`)
+        DISTRIBUTED BY HASH(`k1`) BUCKETS 1
+        PROPERTIES (
+        "replication_allocation" = "tag.location.default: 1"
+        )"""
+
+    sql """CREATE TABLE IF NOT EXISTS prepared_show_table2 (`k1` tinyint NULL)
+        ENGINE=OLAP
+        DUPLICATE KEY(`k1`)
+        DISTRIBUTED BY HASH(`k1`) BUCKETS 1
+        PROPERTIES (
+        "replication_allocation" = "tag.location.default: 1"
+        )"""
+    // getServerPrepareJdbcUrl presumably enables server-side prepared statements, so each SHOW
+    // below goes through COM_STMT_EXECUTE and its result set is read in binary form.
+    String url = getServerPrepareJdbcUrl(context.config.jdbcUrl, "prepared_show")
+    connect(user, password, url) {
+        def stmt_read = prepareStatement """show databases like "prepared_show" """
+        qe_stmt_show_db stmt_read
+
+        stmt_read = prepareStatement """show tables from prepared_show"""
+        qe_stmt_show_table stmt_read
+
+        stmt_read = prepareStatement """show table stats prepared_show_table1"""
+        qe_stmt_show_table_stats1 stmt_read
+
+        // The process list is nondeterministic, so only check that it can be fetched and holds
+        // at least the current session.
+        stmt_read = prepareStatement """show processlist"""
+        def processList = stmt_read.executeQuery()
+        assertTrue(processList.next())
+    }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]