This is an automated email from the ASF dual-hosted git repository.
panjuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new 0d7bc500aad Refactor PostgreSQLAggregatedCommandPacket (#25961)
0d7bc500aad is described below
commit 0d7bc500aadf295e16845a07267710b009da23f2
Author: Liang Zhang <[email protected]>
AuthorDate: Wed May 31 19:39:11 2023 +0800
Refactor PostgreSQLAggregatedCommandPacket (#25961)
---
.../db/protocol/codec/PacketCodecTest.java | 1 -
.../PostgreSQLAggregatedCommandPacket.java | 57 +++++++++-------------
.../command/OpenGaussCommandExecutorFactory.java | 12 ++---
.../OpenGaussCommandExecutorFactoryTest.java | 4 +-
.../command/PostgreSQLCommandExecutorFactory.java | 12 ++---
.../PostgreSQLCommandExecutorFactoryTest.java | 4 +-
6 files changed, 40 insertions(+), 50 deletions(-)
diff --git
a/db-protocol/core/src/test/java/org/apache/shardingsphere/db/protocol/codec/PacketCodecTest.java
b/db-protocol/core/src/test/java/org/apache/shardingsphere/db/protocol/codec/PacketCodecTest.java
index d2e0808f3d0..3a9ce8e99a9 100644
---
a/db-protocol/core/src/test/java/org/apache/shardingsphere/db/protocol/codec/PacketCodecTest.java
+++
b/db-protocol/core/src/test/java/org/apache/shardingsphere/db/protocol/codec/PacketCodecTest.java
@@ -68,7 +68,6 @@ class PacketCodecTest {
verify(databasePacketCodecEngine, times(0)).decode(context, byteBuf,
Collections.emptyList());
}
- @SuppressWarnings("unchecked")
@Test
void assertEncode() {
DatabasePacket databasePacket = mock(DatabasePacket.class);
diff --git
a/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/command/query/extended/PostgreSQLAggregatedCommandPacket.java
b/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/command/query/extended/PostgreSQLAggregatedCommandPacket.java
index 21a8218aec5..0cf186eb857 100644
---
a/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/command/query/extended/PostgreSQLAggregatedCommandPacket.java
+++
b/db-protocol/postgresql/src/main/java/org/apache/shardingsphere/db/protocol/postgresql/packet/command/query/extended/PostgreSQLAggregatedCommandPacket.java
@@ -17,7 +17,6 @@
package
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended;
-import com.google.common.base.Preconditions;
import lombok.Getter;
import
org.apache.shardingsphere.db.protocol.postgresql.packet.command.PostgreSQLCommandPacket;
import
org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.extended.bind.PostgreSQLComBindPacket;
@@ -27,7 +26,6 @@ import
org.apache.shardingsphere.db.protocol.postgresql.packet.identifier.Postgr
import
org.apache.shardingsphere.db.protocol.postgresql.payload.PostgreSQLPacketPayload;
import java.util.List;
-import java.util.RandomAccess;
@Getter
public final class PostgreSQLAggregatedCommandPacket extends
PostgreSQLCommandPacket {
@@ -36,38 +34,38 @@ public final class PostgreSQLAggregatedCommandPacket
extends PostgreSQLCommandPa
private final boolean containsBatchedStatements;
- private final int firstBindIndex;
+ private final int batchPacketBeginIndex;
- private final int lastExecuteIndex;
+ private final int batchPacketEndIndex;
public PostgreSQLAggregatedCommandPacket(final
List<PostgreSQLCommandPacket> packets) {
this.packets = packets;
- int parseTimes = 0;
- int firstStatementBindTimes = 0;
- int firstStatementExecuteTimes = 0;
- String firstStatement = null;
+ String firstStatementId = null;
String firstPortal = null;
+ int parsePacketCount = 0;
+ int bindPacketCountForFirstStatement = 0;
+ int executePacketCountForFirstStatement = 0;
+ int batchPacketBeginIndex = -1;
+ int batchPacketEndIndex = -1;
int index = 0;
- int firstBindIndex = -1;
- int lastExecuteIndex = -1;
for (PostgreSQLCommandPacket each : packets) {
if (each instanceof PostgreSQLComParsePacket) {
- if (++parseTimes > 1) {
+ if (++parsePacketCount > 1) {
break;
}
- if (null == firstStatement) {
- firstStatement = ((PostgreSQLComParsePacket)
each).getStatementId();
- } else if (!firstStatement.equals(((PostgreSQLComParsePacket)
each).getStatementId())) {
+ if (null == firstStatementId) {
+ firstStatementId = ((PostgreSQLComParsePacket)
each).getStatementId();
+ } else if
(!firstStatementId.equals(((PostgreSQLComParsePacket) each).getStatementId())) {
break;
}
}
if (each instanceof PostgreSQLComBindPacket) {
- if (-1 == firstBindIndex) {
- firstBindIndex = index;
+ if (-1 == batchPacketBeginIndex) {
+ batchPacketBeginIndex = index;
}
- if (null == firstStatement) {
- firstStatement = ((PostgreSQLComBindPacket)
each).getStatementId();
- } else if (!firstStatement.equals(((PostgreSQLComBindPacket)
each).getStatementId())) {
+ if (null == firstStatementId) {
+ firstStatementId = ((PostgreSQLComBindPacket)
each).getStatementId();
+ } else if (!firstStatementId.equals(((PostgreSQLComBindPacket)
each).getStatementId())) {
break;
}
if (null == firstPortal) {
@@ -75,31 +73,24 @@ public final class PostgreSQLAggregatedCommandPacket
extends PostgreSQLCommandPa
} else if (!firstPortal.equals(((PostgreSQLComBindPacket)
each).getPortal())) {
break;
}
- firstStatementBindTimes++;
+ bindPacketCountForFirstStatement++;
}
if (each instanceof PostgreSQLComExecutePacket) {
- if (index > lastExecuteIndex) {
- lastExecuteIndex = index;
+ if (index > batchPacketEndIndex) {
+ batchPacketEndIndex = index;
}
if (null == firstPortal) {
firstPortal = ((PostgreSQLComExecutePacket)
each).getPortal();
} else if (!firstPortal.equals(((PostgreSQLComExecutePacket)
each).getPortal())) {
break;
}
- firstStatementExecuteTimes++;
+ executePacketCountForFirstStatement++;
}
index++;
}
- this.firstBindIndex = firstBindIndex;
- this.lastExecuteIndex = lastExecuteIndex;
- containsBatchedStatements = firstStatementBindTimes ==
firstStatementExecuteTimes && firstStatementBindTimes >= 3;
- if (containsBatchedStatements) {
- ensureRandomAccessible(packets);
- }
- }
-
- private void ensureRandomAccessible(final List<PostgreSQLCommandPacket>
packets) {
- Preconditions.checkArgument(packets instanceof RandomAccess, "Packets
must be RandomAccess.");
+ this.batchPacketBeginIndex = batchPacketBeginIndex;
+ this.batchPacketEndIndex = batchPacketEndIndex;
+ containsBatchedStatements = bindPacketCountForFirstStatement ==
executePacketCountForFirstStatement && bindPacketCountForFirstStatement >= 3;
}
@Override
diff --git
a/proxy/frontend/type/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/OpenGaussCommandExecutorFactory.java
b/proxy/frontend/type/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/OpenGaussCommandExecutorFactory.java
index 2a333e4f580..553e4011df6 100644
---
a/proxy/frontend/type/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/OpenGaussCommandExecutorFactory.java
+++
b/proxy/frontend/type/opengauss/src/main/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/OpenGaussCommandExecutorFactory.java
@@ -95,15 +95,15 @@ public final class OpenGaussCommandExecutorFactory {
private static List<CommandExecutor>
getExecutorsOfAggregatedBatchedStatements(final
PostgreSQLAggregatedCommandPacket aggregatedCommandPacket,
final ConnectionSession connectionSession, final PortalContext
portalContext) throws SQLException {
List<PostgreSQLCommandPacket> packets =
aggregatedCommandPacket.getPackets();
- int firstBindIndex = aggregatedCommandPacket.getFirstBindIndex();
- int lastExecuteIndex = aggregatedCommandPacket.getLastExecuteIndex();
- List<CommandExecutor> result = new ArrayList<>(firstBindIndex +
packets.size() - lastExecuteIndex);
- for (int i = 0; i < firstBindIndex; i++) {
+ int batchPacketBeginIndex =
aggregatedCommandPacket.getBatchPacketBeginIndex();
+ int batchPacketEndIndex =
aggregatedCommandPacket.getBatchPacketEndIndex();
+ List<CommandExecutor> result = new ArrayList<>(batchPacketBeginIndex +
packets.size() - batchPacketEndIndex);
+ for (int i = 0; i < batchPacketBeginIndex; i++) {
PostgreSQLCommandPacket each = packets.get(i);
result.add(getCommandExecutor((CommandPacketType)
each.getIdentifier(), each, connectionSession, portalContext));
}
- result.add(new
PostgreSQLAggregatedBatchedStatementsCommandExecutor(connectionSession,
packets.subList(firstBindIndex, lastExecuteIndex + 1)));
- for (int i = lastExecuteIndex + 1; i < packets.size(); i++) {
+ result.add(new
PostgreSQLAggregatedBatchedStatementsCommandExecutor(connectionSession,
packets.subList(batchPacketBeginIndex, batchPacketEndIndex + 1)));
+ for (int i = batchPacketEndIndex + 1; i < packets.size(); i++) {
PostgreSQLCommandPacket each = packets.get(i);
result.add(getCommandExecutor((CommandPacketType)
each.getIdentifier(), each, connectionSession, portalContext));
}
diff --git
a/proxy/frontend/type/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/OpenGaussCommandExecutorFactoryTest.java
b/proxy/frontend/type/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/OpenGaussCommandExecutorFactoryTest.java
index ccb77e4eaf3..f0d01c2519c 100644
---
a/proxy/frontend/type/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/OpenGaussCommandExecutorFactoryTest.java
+++
b/proxy/frontend/type/opengauss/src/test/java/org/apache/shardingsphere/proxy/frontend/opengauss/command/OpenGaussCommandExecutorFactoryTest.java
@@ -135,8 +135,8 @@ class OpenGaussCommandExecutorFactoryTest {
when(packet.isContainsBatchedStatements()).thenReturn(true);
when(packet.getPackets()).thenReturn(
Arrays.asList(parsePacket, bindPacket, describePacket,
executePacket, bindPacket, describePacket, executePacket, closePacket,
syncPacket, terminationPacket));
- when(packet.getFirstBindIndex()).thenReturn(1);
- when(packet.getLastExecuteIndex()).thenReturn(6);
+ when(packet.getBatchPacketBeginIndex()).thenReturn(1);
+ when(packet.getBatchPacketEndIndex()).thenReturn(6);
CommandExecutor actual =
OpenGaussCommandExecutorFactory.newInstance(null, packet, connectionSession,
portalContext);
assertThat(actual,
instanceOf(PostgreSQLAggregatedCommandExecutor.class));
Iterator<CommandExecutor> actualPacketsIterator =
getExecutorsFromAggregatedCommandExecutor((PostgreSQLAggregatedCommandExecutor)
actual).iterator();
diff --git
a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecutorFactory.java
b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecutorFactory.java
index 9da80efb4b8..20ac3c66998 100644
---
a/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecutorFactory.java
+++
b/proxy/frontend/type/postgresql/src/main/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecutorFactory.java
@@ -90,15 +90,15 @@ public final class PostgreSQLCommandExecutorFactory {
private static List<CommandExecutor>
getExecutorsOfAggregatedBatchedStatements(final
PostgreSQLAggregatedCommandPacket aggregatedCommandPacket,
final ConnectionSession connectionSession, final PortalContext
portalContext) throws SQLException {
List<PostgreSQLCommandPacket> packets =
aggregatedCommandPacket.getPackets();
- int firstBindIndex = aggregatedCommandPacket.getFirstBindIndex();
- int lastExecuteIndex = aggregatedCommandPacket.getLastExecuteIndex();
- List<CommandExecutor> result = new ArrayList<>(firstBindIndex +
packets.size() - lastExecuteIndex);
- for (int i = 0; i < firstBindIndex; i++) {
+ int batchPacketBeginIndex =
aggregatedCommandPacket.getBatchPacketBeginIndex();
+ int batchPacketEndIndex =
aggregatedCommandPacket.getBatchPacketEndIndex();
+ List<CommandExecutor> result = new ArrayList<>(batchPacketBeginIndex +
packets.size() - batchPacketEndIndex);
+ for (int i = 0; i < batchPacketBeginIndex; i++) {
PostgreSQLCommandPacket each = packets.get(i);
result.add(getCommandExecutor((PostgreSQLCommandPacketType)
each.getIdentifier(), each, connectionSession, portalContext));
}
- result.add(new
PostgreSQLAggregatedBatchedStatementsCommandExecutor(connectionSession,
packets.subList(firstBindIndex, lastExecuteIndex + 1)));
- for (int i = lastExecuteIndex + 1; i < packets.size(); i++) {
+ result.add(new
PostgreSQLAggregatedBatchedStatementsCommandExecutor(connectionSession,
packets.subList(batchPacketBeginIndex, batchPacketEndIndex + 1)));
+ for (int i = batchPacketEndIndex + 1; i < packets.size(); i++) {
PostgreSQLCommandPacket each = packets.get(i);
result.add(getCommandExecutor((PostgreSQLCommandPacketType)
each.getIdentifier(), each, connectionSession, portalContext));
}
diff --git
a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecutorFactoryTest.java
b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecutorFactoryTest.java
index 712336bd7fa..b54f3e83214 100644
---
a/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecutorFactoryTest.java
+++
b/proxy/frontend/type/postgresql/src/test/java/org/apache/shardingsphere/proxy/frontend/postgresql/command/PostgreSQLCommandExecutorFactoryTest.java
@@ -141,8 +141,8 @@ class PostgreSQLCommandExecutorFactoryTest {
PostgreSQLAggregatedCommandPacket packet =
mock(PostgreSQLAggregatedCommandPacket.class);
when(packet.isContainsBatchedStatements()).thenReturn(true);
when(packet.getPackets()).thenReturn(Arrays.asList(parsePacket,
bindPacket, describePacket, executePacket, bindPacket, describePacket,
executePacket, syncPacket));
- when(packet.getFirstBindIndex()).thenReturn(1);
- when(packet.getLastExecuteIndex()).thenReturn(6);
+ when(packet.getBatchPacketBeginIndex()).thenReturn(1);
+ when(packet.getBatchPacketEndIndex()).thenReturn(6);
CommandExecutor actual =
PostgreSQLCommandExecutorFactory.newInstance(null, packet, connectionSession,
portalContext);
assertThat(actual,
instanceOf(PostgreSQLAggregatedCommandExecutor.class));
Iterator<CommandExecutor> actualPacketsIterator =
getExecutorsFromAggregatedCommandExecutor((PostgreSQLAggregatedCommandExecutor)
actual).iterator();