This is an automated email from the ASF dual-hosted git repository.

morrySnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new e9a146d6367 [fix](connect) Align COM_RESET_CONNECTION behavior with MySQL (#63884)
e9a146d6367 is described below

commit e9a146d6367f300df86ca5643f769e8304a0a6de
Author: yujun <[email protected]>
AuthorDate: Wed Jun 3 17:20:25 2026 +0800

    [fix](connect) Align COM_RESET_CONNECTION behavior with MySQL (#63884)
    
    ### What problem does this PR solve?
    
    `COM_RESET_CONNECTION` was accepted by Doris, but its behavior was not
    compatible with MySQL. The previous implementation cleared the current
    catalog/database state and returned OK after only a partial reset. This
    could cause later unqualified SQL from pooled clients, such as C#
    MySqlConnector with `ConnectionReset=True`, to fail with `Current
    database is not set`. Other session-scoped state, including user
    variables and prepared statements, was not reset at all.
    
    ### What is changed?
    
    - Preserve the current catalog/database state across
    `COM_RESET_CONNECTION` so pooled connections can continue using the
    selected database.
    - Reset session variables, user variables, prepared statements, running
    query state, insert result, command state, and returned row count.
    - Roll back any open transaction during reset and return an error if
    the rollback fails.
    - Drop temporary tables during reset and return an error if cleanup
    fails.
    - Return OK with the autocommit server status when reset succeeds.
    - Return the MySQL-compatible unknown prepared statement error
    (`ERR_UNKNOWN_STMT_HANDLER`) when executing or resetting a prepared
    statement that no longer exists, such as one cleared by reset.
    - Extend regression and FE unit coverage for reset behavior, error
    handling, and current database preservation.
---
 .../java/org/apache/doris/qe/ConnectContext.java   |  99 ++++++++++--
 .../java/org/apache/doris/qe/ConnectProcessor.java |  17 ++-
 .../org/apache/doris/qe/MysqlConnectProcessor.java |  10 +-
 .../org/apache/doris/qe/ConnectContextTest.java    | 166 +++++++++++++++++++++
 .../doris/regression/suite/SuiteContext.groovy     |  16 ++
 .../test_reset_connection_session_variable.groovy  |  76 ++++++++++
 6 files changed, 366 insertions(+), 18 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java 
b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java
index 3cf9b29c16b..996a1ae572a 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java
@@ -56,6 +56,7 @@ import org.apache.doris.mysql.MysqlCommand;
 import org.apache.doris.mysql.MysqlHandshakePacket;
 import org.apache.doris.mysql.MysqlSslContext;
 import org.apache.doris.mysql.ProxyMysqlChannel;
+import org.apache.doris.mysql.privilege.Auth;
 import org.apache.doris.mysql.privilege.PrivPredicate;
 import org.apache.doris.nereids.StatementContext;
 import org.apache.doris.nereids.stats.StatsErrorEstimator;
@@ -113,6 +114,7 @@ public class ConnectContext {
     private static final Logger LOG = 
LogManager.getLogger(ConnectContext.class);
 
     private static final String SSL_PROTOCOL = "TLS";
+    private static final int INITIAL_PREPARED_STMT_ID = Integer.MIN_VALUE;
 
     public enum ConnectType {
         MYSQL,
@@ -128,7 +130,7 @@ public class ConnectContext {
     protected volatile TUniqueId loadId;
     protected volatile long backendId;
     // range [Integer.MIN_VALUE, Integer.MAX_VALUE]
-    protected int preparedStmtId = Integer.MIN_VALUE;
+    protected int preparedStmtId = INITIAL_PREPARED_STMT_ID;
     protected volatile LoadTaskInfo streamLoadInfo;
 
     protected volatile TUniqueId queryId = null;
@@ -369,6 +371,47 @@ public class ConnectContext {
         lastDBOfCatalog.clear();
     }
 
+    public void resetConnection() throws UserException {
+        closeTxnForConnectionReset();
+        if (!dbToTempTableNamesMap.isEmpty()) {
+            cleanupTemporaryTables(true);
+            dbToTempTableNamesMap.clear();
+        }
+        resetSessionVariable();
+        userVars = new HashMap<>();
+        preparedQuerys.clear();
+        preparedStatementContextMap.clear();
+        runningQuery = null;
+        queryId = null;
+        lastQueryId = null;
+        setTraceId(null);
+        insertResult = null;
+        command = MysqlCommand.COM_SLEEP;
+        returnRows = 0;
+    }
+
+    private void resetSessionVariable() {
+        sessionVariable = VariableMgr.newSessionVariable();
+        applyUserSessionVariableDefaults();
+        if (Config.use_fuzzy_session_variable) {
+            sessionVariable.initFuzzyModeVariables();
+        }
+    }
+
+    private void applyUserSessionVariableDefaults() {
+        String qualifiedUser = getQualifiedUser();
+        if (Strings.isNullOrEmpty(qualifiedUser)) {
+            return;
+        }
+        Env currentEnv = env == null ? Env.getCurrentEnv() : env;
+        Auth auth = currentEnv == null ? null : currentEnv.getAuth();
+        if (auth == null) {
+            return;
+        }
+        setUserQueryTimeout(auth.getQueryTimeout(qualifiedUser));
+        setUserInsertTimeout(auth.getInsertTimeout(qualifiedUser));
+    }
+
     public void setNotEvalNondeterministicFunction(boolean 
notEvalNondeterministicFunction) {
         this.notEvalNondeterministicFunction = notEvalNondeterministicFunction;
     }
@@ -385,12 +428,9 @@ public class ConnectContext {
         state = new QueryState();
         returnRows = 0;
         isKilled = false;
-        sessionVariable = VariableMgr.newSessionVariable();
+        resetSessionVariable();
         userVars = new HashMap<>();
         command = MysqlCommand.COM_SLEEP;
-        if (Config.use_fuzzy_session_variable) {
-            sessionVariable.initFuzzyModeVariables();
-        }
 
         sessionId = UUID.randomUUID().toString();
         if (!FeConstants.runningUnitTest) {
@@ -490,6 +530,18 @@ public class ConnectContext {
         }
     }
 
+    private void closeTxnForConnectionReset() throws DdlException {
+        if (isTxnModel()) {
+            try {
+                txnEntry.abortTransaction();
+            } catch (Exception e) {
+                throw new DdlException(String.format("rollback transaction 
failed, db: %s, txnId: %s",
+                        currentDb, txnEntry.getTransactionId()), e);
+            }
+            txnEntry = null;
+        }
+    }
+
     public long getStmtId() {
         return stmtId;
     }
@@ -911,21 +963,41 @@ public class ConnectContext {
     }
 
     protected void deleteTempTable() {
+        try {
+            cleanupTemporaryTables(false);
+        } catch (DdlException e) {
+            LOG.error("drop temporary table error", e);
+        }
+    }
+
+    private void cleanupTemporaryTables(boolean reportFailure) throws 
DdlException {
         // only delete temporary table in its creating session, not proxy 
session in master fe
         if (isProxy) {
             return;
         }
 
+        Map<String, Set<String>> tempTables = new HashMap<>();
+        for (Map.Entry<String, Set<String>> entry : 
dbToTempTableNamesMap.entrySet()) {
+            tempTables.put(entry.getKey(), new HashSet<>(entry.getValue()));
+        }
+
         // if current fe is master, delete temporary table directly
         if (Env.getCurrentEnv().isMaster()) {
-            for (String dbName : dbToTempTableNamesMap.keySet()) {
-                Database db = 
Env.getCurrentEnv().getInternalCatalog().getDb(dbName).get();
-                for (String tableName : dbToTempTableNamesMap.get(dbName)) {
+            for (String dbName : tempTables.keySet()) {
+                for (String tableName : tempTables.get(dbName)) {
                     LOG.info("try to drop temporary table: {}.{}", dbName, 
tableName);
                     try {
+                        Database db = 
Env.getCurrentEnv().getInternalCatalog().getDb(dbName).get();
                         Env.getCurrentEnv().getInternalCatalog()
                             .dropTableWithoutCheck(db, 
db.getTable(tableName).get(), false, true);
-                    } catch (DdlException e) {
+                    } catch (Exception e) {
+                        if (reportFailure) {
+                            if (e instanceof DdlException) {
+                                throw (DdlException) e;
+                            }
+                            throw new DdlException(String.format(
+                                    "drop temporary table error: db: %s, 
table: %s", dbName, tableName), e);
+                        }
                         LOG.error("drop temporary table error: {}.{}", dbName, 
tableName, e);
                     }
                 }
@@ -933,8 +1005,8 @@ public class ConnectContext {
         } else {
             // forward to master fe to drop table
             RedirectStatus redirectStatus = new RedirectStatus(true, false);
-            for (String dbName : dbToTempTableNamesMap.keySet()) {
-                for (String tableName : dbToTempTableNamesMap.get(dbName)) {
+            for (String dbName : tempTables.keySet()) {
+                for (String tableName : tempTables.get(dbName)) {
                     LOG.info("request to delete temporary table: {}.{}", 
dbName, tableName);
                     String dropTableSql = String.format("drop table `%s`", 
tableName);
                     OriginStatement originStmt = new 
OriginStatement(dropTableSql, 0);
@@ -945,6 +1017,11 @@ public class ConnectContext {
                     try {
                         masterOpExecutor.execute();
                     } catch (Exception e) {
+                        if (reportFailure) {
+                            throw new DdlException(String.format(
+                                    "master FE drop temporary table error: db: 
%s, table: %s",
+                                    dbName, tableName), e);
+                        }
                         LOG.error("master FE drop temporary table error: db: 
{}, table: {}", dbName, tableName, e);
                     }
                 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectProcessor.java 
b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectProcessor.java
index 8cd734ceda8..aa77f3d28b1 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectProcessor.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectProcessor.java
@@ -50,7 +50,6 @@ import org.apache.doris.common.util.DebugUtil;
 import org.apache.doris.common.util.SqlUtils;
 import org.apache.doris.common.util.Util;
 import org.apache.doris.datasource.CatalogIf;
-import org.apache.doris.datasource.InternalCatalog;
 import org.apache.doris.metric.MetricRepo;
 import org.apache.doris.mysql.MysqlChannel;
 import org.apache.doris.mysql.MysqlCommand;
@@ -155,12 +154,20 @@ public abstract class ConnectProcessor {
     }
 
     protected void handleResetConnection() {
-        ctx.changeDefaultCatalog(InternalCatalog.INTERNAL_CATALOG_NAME);
-        ctx.clearLastDBOfCatalog();
-        ctx.getState().setOk();
+        try {
+            ctx.resetConnection();
+            ctx.getState().setOk();
+        } catch (UserException e) {
+            ctx.getState().setError(e.getMysqlErrorCode(), e.getMessage());
+        }
     }
 
-    protected void handleStmtReset() {
+    protected void handleStmtResetById(int stmtId) {
+        if (ctx.getPreparedStementContext(String.valueOf(stmtId)) == null) {
+            ctx.getState().setError(ErrorCode.ERR_UNKNOWN_STMT_HANDLER,
+                    String.format("Unknown prepared statement handler (%s) 
given to mysqld_stmt_reset", stmtId));
+            return;
+        }
         ctx.getState().setOk();
     }
 
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/qe/MysqlConnectProcessor.java 
b/fe/fe-core/src/main/java/org/apache/doris/qe/MysqlConnectProcessor.java
index 8e2b805b27b..f70a9d3f622 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/MysqlConnectProcessor.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/MysqlConnectProcessor.java
@@ -82,6 +82,12 @@ public class MysqlConnectProcessor extends ConnectProcessor {
         handleStmtClose(stmtId);
     }
 
+    private void handleStmtReset() {
+        packetBuf = packetBuf.order(ByteOrder.LITTLE_ENDIAN);
+        int stmtId = packetBuf.getInt();
+        handleStmtResetById(stmtId);
+    }
+
     private String getPacket() {
         byte[] bytes = packetBuf.array();
         StringBuilder printB = new StringBuilder();
@@ -214,8 +220,8 @@ public class MysqlConnectProcessor extends ConnectProcessor 
{
         PreparedStatementContext preparedStatementContext = 
ctx.getPreparedStementContext(String.valueOf(stmtId));
         if (preparedStatementContext == null) {
             LOG.warn("No such statement in context, stmtId:{}", stmtId);
-            ctx.getState().setError(ErrorCode.ERR_UNKNOWN_COM_ERROR,
-                    "msg: Not supported such prepared statement");
+            ctx.getState().setError(ErrorCode.ERR_UNKNOWN_STMT_HANDLER,
+                    String.format("Unknown prepared statement handler (%s) 
given to mysqld_stmt_execute", stmtId));
             return;
         }
         handleExecute(preparedStatementContext.command, stmtId, 
preparedStatementContext, packetBuf, null);
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/qe/ConnectContextTest.java 
b/fe/fe-core/src/test/java/org/apache/doris/qe/ConnectContextTest.java
index 090adbb6089..94d922e6987 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/qe/ConnectContextTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/qe/ConnectContextTest.java
@@ -18,8 +18,12 @@
 package org.apache.doris.qe;
 
 import org.apache.doris.analysis.ResourceTypeEnum;
+import org.apache.doris.analysis.SetVar;
+import org.apache.doris.analysis.StringLiteral;
 import org.apache.doris.analysis.UserIdentity;
+import org.apache.doris.catalog.Database;
 import org.apache.doris.catalog.Env;
+import org.apache.doris.catalog.Table;
 import org.apache.doris.cloud.qe.ComputeGroupException;
 import org.apache.doris.cloud.system.CloudSystemInfoService;
 import org.apache.doris.common.Config;
@@ -33,9 +37,11 @@ import org.apache.doris.mysql.MysqlCommand;
 import org.apache.doris.mysql.privilege.AccessControllerManager;
 import org.apache.doris.mysql.privilege.Auth;
 import org.apache.doris.mysql.privilege.PrivPredicate;
+import org.apache.doris.qe.QueryState.MysqlStateType;
 import org.apache.doris.system.Backend;
 import org.apache.doris.system.SystemInfoService;
 import org.apache.doris.thrift.TUniqueId;
+import org.apache.doris.transaction.TransactionStatus;
 
 import com.google.common.collect.Lists;
 import org.junit.Assert;
@@ -50,6 +56,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.UUID;
+import java.util.concurrent.atomic.AtomicInteger;
 
 public class ConnectContextTest {
     private StmtExecutor executor = Mockito.mock(StmtExecutor.class);
@@ -73,6 +80,165 @@ public class ConnectContextTest {
         
Mockito.when(catalogMgr.getCatalog(Mockito.anyString())).thenReturn(internalCatalog);
     }
 
+    @Test
+    public void testResetConnectionClearsSessionState() throws Exception {
+        ConnectContext ctx = new ConnectContext();
+        ctx.setEnv(env);
+        
ctx.setCurrentUserIdentity(UserIdentity.createAnalyzedUserIdentWithIp("testUser",
 "%"));
+        Mockito.when(env.getAuth()).thenReturn(auth);
+        Mockito.when(auth.getQueryTimeout("testUser")).thenReturn(123);
+        Mockito.when(auth.getInsertTimeout("testUser")).thenReturn(456);
+        ctx.setUserQueryTimeout(123);
+        ctx.setUserInsertTimeout(456);
+        VariableMgr.setVar(ctx.getSessionVariable(),
+                new SetVar(SessionVariable.SQL_SELECT_LIMIT, new 
StringLiteral("0")));
+        ctx.getSessionVariable().setQueryTimeoutS(1);
+        ctx.getSessionVariable().setInsertTimeoutS(2);
+        ctx.setUserVar("user_var", new StringLiteral("value"));
+        ctx.changeDefaultCatalog("external_catalog");
+        ctx.currentDb = "test_db";
+        ctx.currentDbId = 10;
+        ctx.addLastDBOfCatalog("external_catalog", "test_db");
+        ctx.addPreparedQuery("1", "select 1");
+        long initialPreparedStmtId = ctx.getPreparedStmtId();
+        ctx.getSessionVariable().enableServeSidePreparedStatement = true;
+        ctx.addPreparedStatementContext(String.valueOf(initialPreparedStmtId),
+                new PreparedStatementContext(null, ctx, null, "select 1"));
+        long nextPreparedStmtId = ctx.getPreparedStmtId();
+        ctx.setRunningQuery("select 1");
+        TUniqueId queryId = new TUniqueId(100, 200);
+        ctx.setQueryId(queryId);
+        ctx.setTraceId("old_trace");
+        ctx.setConnectScheduler(connectScheduler);
+        ctx.setCommand(MysqlCommand.COM_QUERY);
+        ctx.updateReturnRows(10);
+        ctx.setOrUpdateInsertResult(1, "label", "test_db", "test_table", 
TransactionStatus.VISIBLE, 1, 0);
+
+        Assert.assertEquals(0, ctx.getSessionVariable().getSqlSelectLimit());
+        Assert.assertEquals(1, ctx.getSessionVariable().getQueryTimeoutS());
+        Assert.assertEquals(2, ctx.getSessionVariable().getInsertTimeoutS());
+        Assert.assertFalse(ctx.getUserVars().isEmpty());
+        Assert.assertNotNull(ctx.getInsertResult());
+
+        ctx.resetConnection();
+
+        Assert.assertEquals(-1, ctx.getSessionVariable().getSqlSelectLimit());
+        Assert.assertEquals(123, ctx.getSessionVariable().getQueryTimeoutS());
+        Assert.assertEquals(456, ctx.getSessionVariable().getInsertTimeoutS());
+        Assert.assertTrue(ctx.getUserVars().isEmpty());
+        Assert.assertEquals("external_catalog", ctx.getDefaultCatalog());
+        Assert.assertEquals("test_db", ctx.getDatabase());
+        Assert.assertEquals("test_db", 
ctx.getLastDBOfCatalog("external_catalog"));
+        Assert.assertNull(ctx.getPreparedQuery("1"));
+        Assert.assertNull(ctx.getRunningQuery());
+        Assert.assertNull(ctx.queryId());
+        Assert.assertNull(ctx.getLastQueryId());
+        Assert.assertNull(ctx.traceId());
+        Mockito.verify(connectScheduler).removeOldTraceId("old_trace");
+        Assert.assertEquals(nextPreparedStmtId, ctx.getPreparedStmtId());
+        Assert.assertTrue(initialPreparedStmtId != ctx.getPreparedStmtId());
+        Assert.assertNull(ctx.getInsertResult());
+        Assert.assertEquals(MysqlCommand.COM_SLEEP, ctx.getCommand());
+        Assert.assertEquals(0, ctx.getReturnRows());
+    }
+
+    @Test
+    public void testHandleResetConnectionDoesNotSetServerStatus() {
+        ConnectContext ctx = new ConnectContext();
+        ConnectProcessor processor = new ConnectProcessor(ctx) {
+        };
+
+        ctx.getState().reset();
+        processor.handleResetConnection();
+
+        Assert.assertEquals(0, ctx.getState().serverStatus);
+    }
+
+    @Test
+    public void testHandleStmtResetReturnsOkForKnownStatement() throws 
Exception {
+        ConnectContext ctx = new ConnectContext();
+        ctx.getSessionVariable().enableServeSidePreparedStatement = true;
+        ctx.addPreparedStatementContext("1", new 
PreparedStatementContext(null, ctx, null, "select 1"));
+        ConnectProcessor processor = new ConnectProcessor(ctx) {
+        };
+
+        ctx.getState().reset();
+        processor.handleStmtResetById(1);
+
+        Assert.assertEquals(MysqlStateType.OK, ctx.getState().getStateType());
+    }
+
+    @Test
+    public void testHandleStmtResetReturnsErrorForUnknownStatement() {
+        ConnectContext ctx = new ConnectContext();
+        ConnectProcessor processor = new ConnectProcessor(ctx) {
+        };
+
+        ctx.getState().reset();
+        processor.handleStmtResetById(1);
+
+        Assert.assertEquals(MysqlStateType.ERR, ctx.getState().getStateType());
+        Assert.assertEquals(ErrorCode.ERR_UNKNOWN_STMT_HANDLER, 
ctx.getState().getErrorCode());
+        
Assert.assertTrue(ctx.getState().getErrorMessage().contains("mysqld_stmt_reset"));
+    }
+
+    @Test
+    public void testHandleResetConnectionReturnsErrorOnResetFailure() {
+        ConnectContext ctx = new ConnectContext() {
+            @Override
+            public void resetConnection() throws DdlException {
+                throw new DdlException("reset connection failed");
+            }
+        };
+        ConnectProcessor processor = new ConnectProcessor(ctx) {
+        };
+
+        ctx.getState().reset();
+        processor.handleResetConnection();
+
+        Assert.assertEquals(MysqlStateType.ERR, ctx.getState().getStateType());
+        Assert.assertEquals(ErrorCode.ERR_UNKNOWN_ERROR, 
ctx.getState().getErrorCode());
+        Assert.assertTrue(ctx.getState().getErrorMessage().contains("reset 
connection failed"));
+    }
+
+    @Test
+    public void testResetConnectionDropsMultipleTemporaryTables() throws 
Exception {
+        ConnectContext ctx = new ConnectContext();
+        ctx.setEnv(env);
+        ctx.addTempTableToDB("test_db", "test_temp_table1");
+        ctx.addTempTableToDB("test_db", "test_temp_table2");
+
+        Database db = Mockito.mock(Database.class);
+        Table table1 = Mockito.mock(Table.class);
+        Table table2 = Mockito.mock(Table.class);
+        AtomicInteger droppedTableCount = new AtomicInteger();
+
+        Mockito.when(env.isMaster()).thenReturn(true);
+        Mockito.when(env.getAuth()).thenReturn(auth);
+        
Mockito.when(internalCatalog.getDb("test_db")).thenReturn(Optional.of(db));
+        
Mockito.when(db.getTable("test_temp_table1")).thenReturn(Optional.of(table1));
+        
Mockito.when(db.getTable("test_temp_table2")).thenReturn(Optional.of(table2));
+        Mockito.doAnswer(invocation -> {
+            Table table = invocation.getArgument(1);
+            if (table == table1) {
+                ctx.removeTempTableFromDB("test_db", "test_temp_table1");
+            } else {
+                ctx.removeTempTableFromDB("test_db", "test_temp_table2");
+            }
+            droppedTableCount.incrementAndGet();
+            return null;
+        }).when(internalCatalog).dropTableWithoutCheck(Mockito.eq(db), 
Mockito.any(Table.class),
+                Mockito.eq(false), Mockito.eq(true));
+
+        try (MockedStatic<Env> mockedEnv = Mockito.mockStatic(Env.class)) {
+            mockedEnv.when(Env::getCurrentEnv).thenReturn(env);
+            ctx.resetConnection();
+        }
+
+        Assert.assertEquals(2, droppedTableCount.get());
+        Assert.assertTrue(ctx.getDbToTempTableNamesMap().isEmpty());
+    }
+
     @Test
     public void testNormal() {
         try (MockedStatic<Env> mockedEnv = Mockito.mockStatic(Env.class)) {
diff --git 
a/regression-test/framework/src/main/groovy/org/apache/doris/regression/suite/SuiteContext.groovy
 
b/regression-test/framework/src/main/groovy/org/apache/doris/regression/suite/SuiteContext.groovy
index 545b388c5a3..0d599aed817 100644
--- 
a/regression-test/framework/src/main/groovy/org/apache/doris/regression/suite/SuiteContext.groovy
+++ 
b/regression-test/framework/src/main/groovy/org/apache/doris/regression/suite/SuiteContext.groovy
@@ -18,6 +18,10 @@
 package org.apache.doris.regression.suite
 
 import com.google.common.collect.Maps
+import com.mysql.cj.NativeSession
+import com.mysql.cj.jdbc.JdbcConnection
+import com.mysql.cj.protocol.a.NativeConstants
+import com.mysql.cj.protocol.a.NativePacketPayload
 import groovy.transform.CompileStatic
 import org.apache.doris.regression.Config
 import org.apache.doris.regression.util.OutputUtils
@@ -437,6 +441,18 @@ class SuiteContext implements Closeable {
         connectTo(connInfo.conn.getMetaData().getURL(), connInfo.username, 
connInfo.password);
     }
 
+    public void resetConnection() {
+        ConnectionInfo connInfo = threadLocalConn.get()
+        if (connInfo == null) {
+            return
+        }
+        NativeSession session = (NativeSession) 
connInfo.conn.unwrap(JdbcConnection.class).getSession()
+        // COM_RESET_CONNECTION has no payload besides the command byte.
+        NativePacketPayload packet = new NativePacketPayload(1)
+        packet.writeInteger(NativeConstants.IntegerDataType.INT1, 0x1f)
+        session.sendCommand(packet, false, 0)
+    }
+
     public void connectTo(String url, String username, String password) {
         ConnectionInfo oldConn = threadLocalConn.get()
         if (oldConn != null) {
diff --git 
a/regression-test/suites/query_p0/session_variable/test_reset_connection_session_variable.groovy
 
b/regression-test/suites/query_p0/session_variable/test_reset_connection_session_variable.groovy
new file mode 100644
index 00000000000..e9f017e3fd6
--- /dev/null
+++ 
b/regression-test/suites/query_p0/session_variable/test_reset_connection_session_variable.groovy
@@ -0,0 +1,76 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+import com.mysql.cj.NativeSession
+import com.mysql.cj.jdbc.JdbcConnection
+import com.mysql.cj.protocol.a.NativeConstants
+import com.mysql.cj.protocol.a.NativePacketPayload
+
+suite("test_reset_connection_session_variable", "p0") {
+    def resetCurrentConnection = {
+        NativeSession session = (NativeSession) 
context.getConnection().unwrap(JdbcConnection.class).getSession()
+        NativePacketPayload packet = new NativePacketPayload(1)
+        packet.writeInteger(NativeConstants.IntegerDataType.INT1, 0x1f)
+        session.sendCommand(packet, false, 0)
+    }
+    def currentDb = (sql "select database()")[0][0]
+
+    sql "set @reset_connection_user_variable = 1"
+    assertEquals(1, (sql "select @reset_connection_user_variable")[0][0])
+
+    sql "set sql_select_limit = 0"
+
+    def limitedResult = sql "select 1 union all select 2"
+    assertEquals(0, limitedResult.size())
+
+    resetCurrentConnection()
+
+    def resetDb = (sql "select database()")[0][0]
+    assertEquals(currentDb, resetDb)
+
+    def resetResult = sql "select 1 union all select 2"
+    assertEquals(2, resetResult.size())
+    assertNull((sql "select @reset_connection_user_variable")[0][0])
+
+    String url = getServerPrepareJdbcUrl(context.config.jdbcUrl, currentDb, 
false)
+    connect(context.config.jdbcUser, context.config.jdbcPassword, url) {
+        def connectionId = (sql "select connection_id()")[0][0].toString()
+        sql "set enable_server_side_prepared_statement = true"
+        def stmt = prepareStatement "select 1"
+        assertEquals(com.mysql.cj.jdbc.ServerPreparedStatement, stmt.class)
+        assertEquals(1, exec(stmt)[0][0])
+
+        resetCurrentConnection()
+
+        connect(context.config.jdbcUser, context.config.jdbcPassword, 
context.config.jdbcUrl) {
+            def processList = sql_return_maparray "show processlist"
+            def process = processList.find { it.Id.toString() == connectionId }
+            assertNotNull(process)
+            assertEquals("", process.QueryId)
+            assertEquals("", process.Info)
+        }
+
+        try {
+            exec(stmt)
+            assertTrue(false)
+        } catch (Exception e) {
+            assertTrue(e.getMessage().contains("Unknown prepared statement 
handler"))
+        } finally {
+            stmt.close()
+        }
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to