This is an automated email from the ASF dual-hosted git repository. kou pushed a commit to branch maint-10.0.x in repository https://gitbox.apache.org/repos/asf/arrow.git
commit 89f1b91a86583c3bcd7122a058fcb81815157599 Author: David Li <[email protected]> AuthorDate: Wed Nov 9 16:07:19 2022 -0500 ARROW-18294: [Java] Fix Flight SQL JDBC PreparedStatement#executeUpdate (#14616) We need to implement a separate code path for executing a prepared statement that returns an update count. Authored-by: David Li <[email protected]> Signed-off-by: David Li <[email protected]> --- .../arrow/driver/jdbc/ArrowFlightConnection.java | 2 +- .../arrow/driver/jdbc/ArrowFlightMetaImpl.java | 78 +++++++++++++++++++--- .../jdbc/ArrowFlightPreparedStatementTest.java | 15 ++++- 3 files changed, 83 insertions(+), 12 deletions(-) diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java index d2b6e89e3f..79bc04d27f 100644 --- a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java @@ -139,7 +139,7 @@ public final class ArrowFlightConnection extends AvaticaConnection { * * @return the handler. */ - ArrowFlightSqlClientHandler getClientHandler() throws SQLException { + ArrowFlightSqlClientHandler getClientHandler() { return clientHandler; } diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java index cc7addc3a7..f825e7d13c 100644 --- a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java @@ -42,7 +42,7 @@ import org.apache.calcite.avatica.remote.TypedValue; * Metadata handler for Arrow Flight. */ public class ArrowFlightMetaImpl extends MetaImpl { - private final Map<StatementHandle, PreparedStatement> statementHandlePreparedStatementMap; + private final Map<StatementHandleKey, PreparedStatement> statementHandlePreparedStatementMap; /** * Constructs a {@link MetaImpl} object specific for Arrow Flight. @@ -67,7 +67,8 @@ public class ArrowFlightMetaImpl extends MetaImpl { @Override public void closeStatement(final StatementHandle statementHandle) { - PreparedStatement preparedStatement = statementHandlePreparedStatementMap.remove(statementHandle); + PreparedStatement preparedStatement = + statementHandlePreparedStatementMap.remove(new StatementHandleKey(statementHandle)); // Testing if the prepared statement was created because the statement can be not created until this moment if (preparedStatement != null) { preparedStatement.close(); @@ -82,12 +83,25 @@ public class ArrowFlightMetaImpl extends MetaImpl { @Override public ExecuteResult execute(final StatementHandle statementHandle, final List<TypedValue> typedValues, final long maxRowCount) { - // TODO Why is maxRowCount ignored? - Preconditions.checkNotNull(statementHandle.signature, "Signature not found."); - return new ExecuteResult( - Collections.singletonList(MetaResultSet.create( - statementHandle.connectionId, statementHandle.id, - true, statementHandle.signature, null))); + Preconditions.checkArgument(connection.id.equals(statementHandle.connectionId), + "Connection IDs are not consistent"); + if (statementHandle.signature == null) { + // Update query + final StatementHandleKey key = new StatementHandleKey(statementHandle); + PreparedStatement preparedStatement = statementHandlePreparedStatementMap.get(key); + if (preparedStatement == null) { + throw new IllegalStateException("Prepared statement not found: " + statementHandle); + } + long updatedCount = preparedStatement.executeUpdate(); + return new ExecuteResult(Collections.singletonList(MetaResultSet.count(statementHandle.connectionId, + statementHandle.id, updatedCount))); + } else { + // TODO Why is maxRowCount ignored? + return new ExecuteResult( + Collections.singletonList(MetaResultSet.create( + statementHandle.connectionId, statementHandle.id, + true, statementHandle.signature, null))); + } } @Override @@ -121,6 +135,9 @@ public class ArrowFlightMetaImpl extends MetaImpl { final String query, final long maxRowCount) { final StatementHandle handle = super.createStatement(connectionHandle); handle.signature = newSignature(query); + final PreparedStatement preparedStatement = + ((ArrowFlightConnection) connection).getClientHandler().prepare(query); + statementHandlePreparedStatementMap.put(new StatementHandleKey(handle), preparedStatement); return handle; } @@ -143,7 +160,7 @@ public class ArrowFlightMetaImpl extends MetaImpl { final PreparedStatement preparedStatement = ((ArrowFlightConnection) connection).getClientHandler().prepare(query); final StatementType statementType = preparedStatement.getType(); - statementHandlePreparedStatementMap.put(handle, preparedStatement); + statementHandlePreparedStatementMap.put(new StatementHandleKey(handle), preparedStatement); final Signature signature = newSignature(query); final long updateCount = statementType.equals(StatementType.UPDATE) ? preparedStatement.executeUpdate() : -1; @@ -195,6 +212,47 @@ public class ArrowFlightMetaImpl extends MetaImpl { } PreparedStatement getPreparedStatement(StatementHandle statementHandle) { - return statementHandlePreparedStatementMap.get(statementHandle); + return statementHandlePreparedStatementMap.get(new StatementHandleKey(statementHandle)); + } + + // Helper used to look up prepared statement instances later. Avatica doesn't give us the signature in + // an UPDATE code path so we can't directly use StatementHandle as a map key. + private static final class StatementHandleKey { + public final String connectionId; + public final int id; + + StatementHandleKey(String connectionId, int id) { + this.connectionId = connectionId; + this.id = id; + } + + StatementHandleKey(StatementHandle statementHandle) { + this.connectionId = statementHandle.connectionId; + this.id = statementHandle.id; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + StatementHandleKey that = (StatementHandleKey) o; + + if (id != that.id) { + return false; + } + return connectionId.equals(that.connectionId); + } + + @Override + public int hashCode() { + int result = connectionId.hashCode(); + result = 31 * result + id; + return result; + } } } diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java index 51c491be28..8af529296f 100644 --- a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java @@ -18,6 +18,7 @@ package org.apache.arrow.driver.jdbc; import static org.hamcrest.CoreMatchers.equalTo; +import static org.junit.jupiter.api.Assertions.assertEquals; import java.sql.Connection; import java.sql.PreparedStatement; @@ -25,6 +26,7 @@ import java.sql.ResultSet; import java.sql.SQLException; import org.apache.arrow.driver.jdbc.utils.CoreMockedSqlProducers; +import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.ClassRule; @@ -34,9 +36,10 @@ import org.junit.rules.ErrorCollector; public class ArrowFlightPreparedStatementTest { + public static final MockFlightSqlProducer PRODUCER = CoreMockedSqlProducers.getLegacyProducer(); @ClassRule public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE = FlightServerTestRule - .createStandardTestRule(CoreMockedSqlProducers.getLegacyProducer()); + .createStandardTestRule(PRODUCER); private static Connection connection; @@ -75,4 +78,14 @@ public class ArrowFlightPreparedStatementTest { collector.checkThat(6, equalTo(psmt.getMetaData().getColumnCount())); } } + + @Test + public void testUpdateQuery() throws SQLException { + String query = "Fake update"; + PRODUCER.addUpdateQuery(query, /*updatedRows*/42); + try (final PreparedStatement stmt = connection.prepareStatement(query)) { + int updated = stmt.executeUpdate(); + assertEquals(42, updated); + } + } }
