This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new f98a13250d GH-38732: [Java][FlightRPC] Add support for Array parameter
binding in JDBC (#38733)
f98a13250d is described below
commit f98a13250d10dba248a2bb85989d6b80265e82d8
Author: Diego Fernández Giraldo <[email protected]>
AuthorDate: Mon Nov 20 08:13:50 2023 -0700
GH-38732: [Java][FlightRPC] Add support for Array parameter binding in JDBC
(#38733)
This PR adds support for binding arrays to prepared statements in the Arrow
Flight SQL JDBC driver. So far this has only been tested locally by tweaking
`ArrowFlightPreparedStatementTest`; more thorough testing of all types will
be done in a follow-up.
To ensure consistency, I used the `BinderVisitor` to bind the child
vectors. This way, any conversion logic added to the converters will
automatically apply to array elements as well.
* Closes: #38732
Authored-by: Diego Fernandez <[email protected]>
Signed-off-by: David Li <[email protected]>
---
.../FixedSizeListAvaticaParameterConverter.java | 39 ++++++++++++++++++++++
.../impl/LargeListAvaticaParameterConverter.java | 31 +++++++++++++++++
.../impl/ListAvaticaParameterConverter.java | 30 +++++++++++++++++
.../driver/jdbc/utils/AvaticaParameterBinder.java | 14 ++++++--
.../jdbc/ArrowFlightPreparedStatementTest.java | 39 +++++++++++++++++-----
.../driver/jdbc/utils/MockFlightSqlProducer.java | 10 +++++-
6 files changed, 152 insertions(+), 11 deletions(-)
diff --git
a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/FixedSizeListAvaticaParameterConverter.java
b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/FixedSizeListAvaticaParameterConverter.java
index 60231a2460..1525bcaaf5 100644
---
a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/FixedSizeListAvaticaParameterConverter.java
+++
b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/FixedSizeListAvaticaParameterConverter.java
@@ -17,7 +17,11 @@
package org.apache.arrow.driver.jdbc.converter.impl;
+import java.util.List;
+
+import org.apache.arrow.driver.jdbc.utils.AvaticaParameterBinder;
import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.complex.FixedSizeListVector;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.calcite.avatica.AvaticaParameter;
@@ -33,6 +37,41 @@ public class FixedSizeListAvaticaParameterConverter extends
BaseAvaticaParameter
@Override
public boolean bindParameter(FieldVector vector, TypedValue typedValue, int
index) {
+ final List<?> values = (List<?>) typedValue.value;
+ final int arraySize = values.size();
+
+ if (vector instanceof FixedSizeListVector) {
+ FixedSizeListVector listVector = ((FixedSizeListVector) vector);
+ FieldVector childVector = listVector.getDataVector();
+ int maxArraySize = listVector.getListSize();
+
+ if (arraySize != maxArraySize) {
+ if (!childVector.getField().isNullable()) {
+ throw new UnsupportedOperationException("Each array must contain " +
maxArraySize + " elements");
+ } else if (arraySize > maxArraySize) {
+ throw new UnsupportedOperationException("Each array must contain at
most " + maxArraySize + " elements");
+ }
+ }
+
+ int startPos = listVector.startNewValue(index);
+ for (int i = 0; i < arraySize; i++) {
+ Object val = values.get(i);
+ int childIndex = startPos + i;
+ if (val == null) {
+ if (childVector.getField().isNullable()) {
+ childVector.setNull(childIndex);
+ } else {
+ throw new UnsupportedOperationException("Can't set null on
non-nullable child list");
+ }
+ } else {
+ childVector.getField().getType().accept(
+ new AvaticaParameterBinder.BinderVisitor(
+ childVector,
TypedValue.ofSerial(typedValue.componentType, val), childIndex));
+ }
+ }
+ listVector.setValueCount(index + 1);
+ return true;
+ }
return false;
}
diff --git
a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/LargeListAvaticaParameterConverter.java
b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/LargeListAvaticaParameterConverter.java
index 6ef6920474..a20747693e 100644
---
a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/LargeListAvaticaParameterConverter.java
+++
b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/LargeListAvaticaParameterConverter.java
@@ -17,7 +17,12 @@
package org.apache.arrow.driver.jdbc.converter.impl;
+import java.util.List;
+
+import org.apache.arrow.driver.jdbc.utils.AvaticaParameterBinder;
+import org.apache.arrow.memory.util.LargeMemoryUtil;
import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.complex.LargeListVector;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.calcite.avatica.AvaticaParameter;
@@ -33,6 +38,32 @@ public class LargeListAvaticaParameterConverter extends
BaseAvaticaParameterConv
@Override
public boolean bindParameter(FieldVector vector, TypedValue typedValue, int
index) {
+ final List<?> values = (List<?>) typedValue.value;
+
+ if (vector instanceof LargeListVector) {
+ LargeListVector listVector = ((LargeListVector) vector);
+ FieldVector childVector = listVector.getDataVector();
+
+ long startPos = listVector.startNewValue(index);
+ for (int i = 0; i < values.size(); i++) {
+ Object val = values.get(i);
+ int childIndex = LargeMemoryUtil.checkedCastToInt(startPos) + i;
+ if (val == null) {
+ if (childVector.getField().isNullable()) {
+ childVector.setNull(childIndex);
+ } else {
+ throw new UnsupportedOperationException("Can't set null on
non-nullable child list");
+ }
+ } else {
+ childVector.getField().getType().accept(
+ new AvaticaParameterBinder.BinderVisitor(
+ childVector,
TypedValue.ofSerial(typedValue.componentType, val), childIndex));
+ }
+ }
+ listVector.endValue(index, values.size());
+ listVector.setValueCount(index + 1);
+ return true;
+ }
return false;
}
diff --git
a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/ListAvaticaParameterConverter.java
b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/ListAvaticaParameterConverter.java
index aec59cb4d4..f6cb9f3be2 100644
---
a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/ListAvaticaParameterConverter.java
+++
b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/ListAvaticaParameterConverter.java
@@ -17,7 +17,11 @@
package org.apache.arrow.driver.jdbc.converter.impl;
+import java.util.List;
+
+import org.apache.arrow.driver.jdbc.utils.AvaticaParameterBinder;
import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.calcite.avatica.AvaticaParameter;
@@ -33,6 +37,32 @@ public class ListAvaticaParameterConverter extends
BaseAvaticaParameterConverter
@Override
public boolean bindParameter(FieldVector vector, TypedValue typedValue, int
index) {
+ final List<?> values = (List<?>) typedValue.value;
+
+ if (vector instanceof ListVector) {
+ ListVector listVector = ((ListVector) vector);
+ FieldVector childVector = listVector.getDataVector();
+
+ int startPos = listVector.startNewValue(index);
+ for (int i = 0; i < values.size(); i++) {
+ Object val = values.get(i);
+ int childIndex = startPos + i;
+ if (val == null) {
+ if (childVector.getField().isNullable()) {
+ childVector.setNull(childIndex);
+ } else {
+ throw new UnsupportedOperationException("Can't set null on
non-nullable child list");
+ }
+ } else {
+ childVector.getField().getType().accept(
+ new AvaticaParameterBinder.BinderVisitor(
+ childVector,
TypedValue.ofSerial(typedValue.componentType, val), childIndex));
+ }
+ }
+ listVector.endValue(index, values.size());
+ listVector.setValueCount(index + 1);
+ return true;
+ }
return false;
}
diff --git
a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/AvaticaParameterBinder.java
b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/AvaticaParameterBinder.java
index 9e805fc79b..5fa3ba38f2 100644
---
a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/AvaticaParameterBinder.java
+++
b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/AvaticaParameterBinder.java
@@ -119,12 +119,22 @@ public class AvaticaParameterBinder {
}
}
- private static class BinderVisitor implements
ArrowType.ArrowTypeVisitor<Boolean> {
+ /**
+ * ArrowTypeVisitor that binds Avatica TypedValues to the given FieldVector
at the specified index.
+ */
+ public static class BinderVisitor implements
ArrowType.ArrowTypeVisitor<Boolean> {
private final FieldVector vector;
private final TypedValue typedValue;
private final int index;
- private BinderVisitor(FieldVector vector, TypedValue value, int index) {
+ /**
+ * Instantiate a new BinderVisitor.
+ *
+ * @param vector FieldVector to bind values to.
+ * @param value TypedValue to bind.
+ * @param index Vector index (0-based) to bind the value to.
+ */
+ public BinderVisitor(FieldVector vector, TypedValue value, int index) {
this.vector = vector;
this.typedValue = value;
this.index = index;
diff --git
a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java
b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java
index b19f049544..0b521a704b 100644
---
a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java
+++
b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java
@@ -27,6 +27,7 @@ import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.Collections;
+import java.util.List;
import org.apache.arrow.driver.jdbc.utils.CoreMockedSqlProducers;
import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer;
@@ -38,6 +39,7 @@ import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.Text;
import org.junit.AfterClass;
@@ -89,6 +91,14 @@ public class ArrowFlightPreparedStatementTest {
public void testQueryWithParameterBinding() throws SQLException {
final String query = "Fake query with parameters";
final Schema schema = new
Schema(Collections.singletonList(Field.nullable("",
Types.MinorType.INT.getType())));
+ final Schema parameterSchema = new Schema(Arrays.asList(
+ Field.nullable("", ArrowType.Utf8.INSTANCE),
+ new Field("", FieldType.nullable(ArrowType.List.INSTANCE),
+ Collections.singletonList(Field.nullable("",
Types.MinorType.INT.getType())))));
+ final List<List<Object>> expected =
Collections.singletonList(Arrays.asList(
+ new Text("foo"),
+ new Integer[]{1, 2, null}));
+
PRODUCER.addSelectQuery(query, schema,
Collections.singletonList(listener -> {
try (final BufferAllocator allocator = new
RootAllocator(Long.MAX_VALUE);
@@ -105,11 +115,12 @@ public class ArrowFlightPreparedStatementTest {
}
}));
- PRODUCER.addExpectedParameters(query,
- new Schema(Collections.singletonList(Field.nullable("",
ArrowType.Utf8.INSTANCE))),
- Collections.singletonList(Collections.singletonList(new
Text("foo".getBytes(StandardCharsets.UTF_8)))));
+ PRODUCER.addExpectedParameters(query, parameterSchema, expected);
+
try (final PreparedStatement preparedStatement =
connection.prepareStatement(query)) {
preparedStatement.setString(1, "foo");
+ preparedStatement.setArray(2, connection.createArrayOf("INTEGER", new
Integer[]{1, 2, null}));
+
try (final ResultSet resultSet = preparedStatement.executeQuery()) {
resultSet.next();
assert true;
@@ -171,17 +182,29 @@ public class ArrowFlightPreparedStatementTest {
@Test
public void testUpdateQueryWithBatchedParameters() throws SQLException {
String query = "Fake update with batched parameters";
- PRODUCER.addUpdateQuery(query, /*updatedRows*/42);
- PRODUCER.addExpectedParameters(query,
- new Schema(Collections.singletonList(Field.nullable("",
ArrowType.Utf8.INSTANCE))),
+ Schema parameterSchema = new Schema(Arrays.asList(
+ Field.nullable("", ArrowType.Utf8.INSTANCE),
+ new Field("", FieldType.nullable(ArrowType.List.INSTANCE),
+ Collections.singletonList(Field.nullable("",
Types.MinorType.INT.getType())))));
+ List<List<Object>> expected = Arrays.asList(
+ Arrays.asList(
+ new Text("foo"),
+ new Integer[]{1, 2, null}),
Arrays.asList(
- Collections.singletonList(new
Text("foo".getBytes(StandardCharsets.UTF_8))),
- Collections.singletonList(new
Text("bar".getBytes(StandardCharsets.UTF_8)))));
+ new Text("bar"),
+ new Integer[]{0, -1, 100000})
+ );
+
+ PRODUCER.addUpdateQuery(query, /*updatedRows*/42);
+ PRODUCER.addExpectedParameters(query, parameterSchema, expected);
+
try (final PreparedStatement stmt = connection.prepareStatement(query)) {
// TODO: make sure this is validated on the server too
stmt.setString(1, "foo");
+ stmt.setArray(2, connection.createArrayOf("INTEGER", new Integer[]{1, 2,
null}));
stmt.addBatch();
stmt.setString(1, "bar");
+ stmt.setArray(2, connection.createArrayOf("INTEGER", new Integer[]{0,
-1, 100000}));
stmt.addBatch();
int[] updated = stmt.executeBatch();
assertEquals(42, updated[0]);
diff --git
a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java
b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java
index 2b65f8f5a0..eaba008fbf 100644
---
a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java
+++
b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java
@@ -29,6 +29,7 @@ import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.charset.StandardCharsets;
import java.util.AbstractMap.SimpleImmutableEntry;
+import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
@@ -80,6 +81,7 @@ import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.WriteChannel;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.arrow.vector.util.JsonStringArrayList;
import org.apache.calcite.avatica.Meta.StatementType;
import com.google.protobuf.Any;
@@ -373,7 +375,13 @@ public final class MockFlightSqlProducer implements
FlightSqlProducer {
for (int paramIndex = 0; paramIndex < expectedRow.size();
paramIndex++) {
Object expected = expectedRow.get(paramIndex);
Object actual = root.getVector(paramIndex).getObject(i);
- if (!Objects.equals(expected, actual)) {
+ boolean matches;
+ if (expected.getClass().isArray()) {
+ matches = Arrays.equals((Object[]) expected,
((JsonStringArrayList) actual).toArray());
+ } else {
+ matches = Objects.equals(expected, actual);
+ }
+ if (!matches) {
streamListener.onError(CallStatus.INVALID_ARGUMENT
.withDescription(String.format("Parameter mismatch.
Expected: %s Actual: %s", expected, actual))
.toRuntimeException());