This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new f98a13250d GH-38732: [Java][FlightRPC] Add support for Array parameter binding in JDBC (#38733)
f98a13250d is described below

commit f98a13250d10dba248a2bb85989d6b80265e82d8
Author: Diego Fernández Giraldo <[email protected]>
AuthorDate: Mon Nov 20 08:13:50 2023 -0700

    GH-38732: [Java][FlightRPC] Add support for Array parameter binding in JDBC (#38733)
    
    This PR adds support for binding Arrays to Prepared Statements in the Arrow
    Flight SQL JDBC driver. This has only been tested locally by tweaking
    `ArrowFlightPreparedStatementTest`, but more thorough testing of all types
    will be done in a follow-up.
    
    To ensure consistency, I used the `BinderVisitor` to bind the child
    vectors. This ensures that any conversion logic added to the converters
    will also be applied when binding array elements.
    * Closes: #38732
    
    Authored-by: Diego Fernandez <[email protected]>
    Signed-off-by: David Li <[email protected]>
---
 .../FixedSizeListAvaticaParameterConverter.java    | 39 ++++++++++++++++++++++
 .../impl/LargeListAvaticaParameterConverter.java   | 31 +++++++++++++++++
 .../impl/ListAvaticaParameterConverter.java        | 30 +++++++++++++++++
 .../driver/jdbc/utils/AvaticaParameterBinder.java  | 14 ++++++--
 .../jdbc/ArrowFlightPreparedStatementTest.java     | 39 +++++++++++++++++-----
 .../driver/jdbc/utils/MockFlightSqlProducer.java   | 10 +++++-
 6 files changed, 152 insertions(+), 11 deletions(-)

diff --git 
a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/FixedSizeListAvaticaParameterConverter.java
 
b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/FixedSizeListAvaticaParameterConverter.java
index 60231a2460..1525bcaaf5 100644
--- 
a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/FixedSizeListAvaticaParameterConverter.java
+++ 
b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/FixedSizeListAvaticaParameterConverter.java
@@ -17,7 +17,11 @@
 
 package org.apache.arrow.driver.jdbc.converter.impl;
 
+import java.util.List;
+
+import org.apache.arrow.driver.jdbc.utils.AvaticaParameterBinder;
 import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.complex.FixedSizeListVector;
 import org.apache.arrow.vector.types.pojo.ArrowType;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.calcite.avatica.AvaticaParameter;
@@ -33,6 +37,41 @@ public class FixedSizeListAvaticaParameterConverter extends 
BaseAvaticaParameter
 
   @Override
   public boolean bindParameter(FieldVector vector, TypedValue typedValue, int 
index) {
+    final List<?> values = (List<?>) typedValue.value;
+    final int arraySize = values.size();
+
+    if (vector instanceof FixedSizeListVector) {
+      FixedSizeListVector listVector = ((FixedSizeListVector) vector);
+      FieldVector childVector = listVector.getDataVector();
+      int maxArraySize = listVector.getListSize();
+
+      if (arraySize != maxArraySize) {
+        if (!childVector.getField().isNullable()) {
+          throw new UnsupportedOperationException("Each array must contain " + 
maxArraySize + " elements");
+        } else if (arraySize > maxArraySize) {
+          throw new UnsupportedOperationException("Each array must contain at 
most " + maxArraySize + " elements");
+        }
+      }
+
+      int startPos = listVector.startNewValue(index);
+      for (int i = 0; i < arraySize; i++) {
+        Object val = values.get(i);
+        int childIndex = startPos + i;
+        if (val == null) {
+          if (childVector.getField().isNullable()) {
+            childVector.setNull(childIndex);
+          } else {
+            throw new UnsupportedOperationException("Can't set null on 
non-nullable child list");
+          }
+        } else {
+          childVector.getField().getType().accept(
+                  new AvaticaParameterBinder.BinderVisitor(
+                          childVector, 
TypedValue.ofSerial(typedValue.componentType, val), childIndex));
+        }
+      }
+      listVector.setValueCount(index + 1);
+      return true;
+    }
     return false;
   }
 
diff --git 
a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/LargeListAvaticaParameterConverter.java
 
b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/LargeListAvaticaParameterConverter.java
index 6ef6920474..a20747693e 100644
--- 
a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/LargeListAvaticaParameterConverter.java
+++ 
b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/LargeListAvaticaParameterConverter.java
@@ -17,7 +17,12 @@
 
 package org.apache.arrow.driver.jdbc.converter.impl;
 
+import java.util.List;
+
+import org.apache.arrow.driver.jdbc.utils.AvaticaParameterBinder;
+import org.apache.arrow.memory.util.LargeMemoryUtil;
 import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.complex.LargeListVector;
 import org.apache.arrow.vector.types.pojo.ArrowType;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.calcite.avatica.AvaticaParameter;
@@ -33,6 +38,32 @@ public class LargeListAvaticaParameterConverter extends 
BaseAvaticaParameterConv
 
   @Override
   public boolean bindParameter(FieldVector vector, TypedValue typedValue, int 
index) {
+    final List<?> values = (List<?>) typedValue.value;
+
+    if (vector instanceof LargeListVector) {
+      LargeListVector listVector = ((LargeListVector) vector);
+      FieldVector childVector = listVector.getDataVector();
+
+      long startPos = listVector.startNewValue(index);
+      for (int i = 0; i < values.size(); i++) {
+        Object val = values.get(i);
+        int childIndex = LargeMemoryUtil.checkedCastToInt(startPos) + i;
+        if (val == null) {
+          if (childVector.getField().isNullable()) {
+            childVector.setNull(childIndex);
+          } else {
+            throw new UnsupportedOperationException("Can't set null on 
non-nullable child list");
+          }
+        } else {
+          childVector.getField().getType().accept(
+                  new AvaticaParameterBinder.BinderVisitor(
+                          childVector, 
TypedValue.ofSerial(typedValue.componentType, val), childIndex));
+        }
+      }
+      listVector.endValue(index, values.size());
+      listVector.setValueCount(index + 1);
+      return true;
+    }
     return false;
   }
 
diff --git 
a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/ListAvaticaParameterConverter.java
 
b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/ListAvaticaParameterConverter.java
index aec59cb4d4..f6cb9f3be2 100644
--- 
a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/ListAvaticaParameterConverter.java
+++ 
b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/converter/impl/ListAvaticaParameterConverter.java
@@ -17,7 +17,11 @@
 
 package org.apache.arrow.driver.jdbc.converter.impl;
 
+import java.util.List;
+
+import org.apache.arrow.driver.jdbc.utils.AvaticaParameterBinder;
 import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.complex.ListVector;
 import org.apache.arrow.vector.types.pojo.ArrowType;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.calcite.avatica.AvaticaParameter;
@@ -33,6 +37,32 @@ public class ListAvaticaParameterConverter extends 
BaseAvaticaParameterConverter
 
   @Override
   public boolean bindParameter(FieldVector vector, TypedValue typedValue, int 
index) {
+    final List<?> values = (List<?>) typedValue.value;
+
+    if (vector instanceof ListVector) {
+      ListVector listVector = ((ListVector) vector);
+      FieldVector childVector = listVector.getDataVector();
+
+      int startPos = listVector.startNewValue(index);
+      for (int i = 0; i < values.size(); i++) {
+        Object val = values.get(i);
+        int childIndex = startPos + i;
+        if (val == null) {
+          if (childVector.getField().isNullable()) {
+            childVector.setNull(childIndex);
+          } else {
+            throw new UnsupportedOperationException("Can't set null on 
non-nullable child list");
+          }
+        } else {
+          childVector.getField().getType().accept(
+                  new AvaticaParameterBinder.BinderVisitor(
+                          childVector, 
TypedValue.ofSerial(typedValue.componentType, val), childIndex));
+        }
+      }
+      listVector.endValue(index, values.size());
+      listVector.setValueCount(index + 1);
+      return true;
+    }
     return false;
   }
 
diff --git 
a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/AvaticaParameterBinder.java
 
b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/AvaticaParameterBinder.java
index 9e805fc79b..5fa3ba38f2 100644
--- 
a/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/AvaticaParameterBinder.java
+++ 
b/java/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/AvaticaParameterBinder.java
@@ -119,12 +119,22 @@ public class AvaticaParameterBinder {
     }
   }
 
-  private static class BinderVisitor implements 
ArrowType.ArrowTypeVisitor<Boolean> {
+  /**
+   * ArrowTypeVisitor that binds Avatica TypedValues to the given FieldVector 
at the specified index.
+   */
+  public static class BinderVisitor implements 
ArrowType.ArrowTypeVisitor<Boolean> {
     private final FieldVector vector;
     private final TypedValue typedValue;
     private final int index;
 
-    private BinderVisitor(FieldVector vector, TypedValue value, int index) {
+    /**
+     * Instantiate a new BinderVisitor.
+     *
+     * @param vector FieldVector to bind values to.
+     * @param value TypedValue to bind.
+     * @param index Vector index (0-based) to bind the value to.
+     */
+    public BinderVisitor(FieldVector vector, TypedValue value, int index) {
       this.vector = vector;
       this.typedValue = value;
       this.index = index;
diff --git 
a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java
 
b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java
index b19f049544..0b521a704b 100644
--- 
a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java
+++ 
b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java
@@ -27,6 +27,7 @@ import java.sql.ResultSet;
 import java.sql.SQLException;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.List;
 
 import org.apache.arrow.driver.jdbc.utils.CoreMockedSqlProducers;
 import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer;
@@ -38,6 +39,7 @@ import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.types.Types;
 import org.apache.arrow.vector.types.pojo.ArrowType;
 import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
 import org.apache.arrow.vector.types.pojo.Schema;
 import org.apache.arrow.vector.util.Text;
 import org.junit.AfterClass;
@@ -89,6 +91,14 @@ public class ArrowFlightPreparedStatementTest {
   public void testQueryWithParameterBinding() throws SQLException {
     final String query = "Fake query with parameters";
     final Schema schema = new 
Schema(Collections.singletonList(Field.nullable("", 
Types.MinorType.INT.getType())));
+    final Schema parameterSchema = new Schema(Arrays.asList(
+            Field.nullable("", ArrowType.Utf8.INSTANCE),
+            new Field("", FieldType.nullable(ArrowType.List.INSTANCE),
+                    Collections.singletonList(Field.nullable("", 
Types.MinorType.INT.getType())))));
+    final List<List<Object>> expected = 
Collections.singletonList(Arrays.asList(
+                    new Text("foo"),
+                    new Integer[]{1, 2, null}));
+
     PRODUCER.addSelectQuery(query, schema,
             Collections.singletonList(listener -> {
               try (final BufferAllocator allocator = new 
RootAllocator(Long.MAX_VALUE);
@@ -105,11 +115,12 @@ public class ArrowFlightPreparedStatementTest {
               }
             }));
 
-    PRODUCER.addExpectedParameters(query,
-            new Schema(Collections.singletonList(Field.nullable("", 
ArrowType.Utf8.INSTANCE))),
-            Collections.singletonList(Collections.singletonList(new 
Text("foo".getBytes(StandardCharsets.UTF_8)))));
+    PRODUCER.addExpectedParameters(query, parameterSchema, expected);
+
     try (final PreparedStatement preparedStatement = 
connection.prepareStatement(query)) {
       preparedStatement.setString(1, "foo");
+      preparedStatement.setArray(2, connection.createArrayOf("INTEGER", new 
Integer[]{1, 2, null}));
+
       try (final ResultSet resultSet = preparedStatement.executeQuery()) {
         resultSet.next();
         assert true;
@@ -171,17 +182,29 @@ public class ArrowFlightPreparedStatementTest {
   @Test
   public void testUpdateQueryWithBatchedParameters() throws SQLException {
     String query = "Fake update with batched parameters";
-    PRODUCER.addUpdateQuery(query, /*updatedRows*/42);
-    PRODUCER.addExpectedParameters(query,
-            new Schema(Collections.singletonList(Field.nullable("", 
ArrowType.Utf8.INSTANCE))),
+    Schema parameterSchema = new Schema(Arrays.asList(
+            Field.nullable("", ArrowType.Utf8.INSTANCE),
+            new Field("", FieldType.nullable(ArrowType.List.INSTANCE),
+                    Collections.singletonList(Field.nullable("", 
Types.MinorType.INT.getType())))));
+    List<List<Object>> expected = Arrays.asList(
+            Arrays.asList(
+                    new Text("foo"),
+                    new Integer[]{1, 2, null}),
             Arrays.asList(
-                    Collections.singletonList(new 
Text("foo".getBytes(StandardCharsets.UTF_8))),
-                    Collections.singletonList(new 
Text("bar".getBytes(StandardCharsets.UTF_8)))));
+                    new Text("bar"),
+                    new Integer[]{0, -1, 100000})
+            );
+
+    PRODUCER.addUpdateQuery(query, /*updatedRows*/42);
+    PRODUCER.addExpectedParameters(query, parameterSchema, expected);
+
     try (final PreparedStatement stmt = connection.prepareStatement(query)) {
       // TODO: make sure this is validated on the server too
       stmt.setString(1, "foo");
+      stmt.setArray(2, connection.createArrayOf("INTEGER", new Integer[]{1, 2, 
null}));
       stmt.addBatch();
       stmt.setString(1, "bar");
+      stmt.setArray(2, connection.createArrayOf("INTEGER", new Integer[]{0, 
-1, 100000}));
       stmt.addBatch();
       int[] updated = stmt.executeBatch();
       assertEquals(42, updated[0]);
diff --git 
a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java
 
b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java
index 2b65f8f5a0..eaba008fbf 100644
--- 
a/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java
+++ 
b/java/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java
@@ -29,6 +29,7 @@ import java.nio.ByteBuffer;
 import java.nio.channels.Channels;
 import java.nio.charset.StandardCharsets;
 import java.util.AbstractMap.SimpleImmutableEntry;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -80,6 +81,7 @@ import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.ipc.WriteChannel;
 import org.apache.arrow.vector.ipc.message.MessageSerializer;
 import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.arrow.vector.util.JsonStringArrayList;
 import org.apache.calcite.avatica.Meta.StatementType;
 
 import com.google.protobuf.Any;
@@ -373,7 +375,13 @@ public final class MockFlightSqlProducer implements 
FlightSqlProducer {
           for (int paramIndex = 0; paramIndex < expectedRow.size(); 
paramIndex++) {
             Object expected = expectedRow.get(paramIndex);
             Object actual = root.getVector(paramIndex).getObject(i);
-            if (!Objects.equals(expected, actual)) {
+            boolean matches;
+            if (expected.getClass().isArray()) {
+              matches = Arrays.equals((Object[]) expected, 
((JsonStringArrayList) actual).toArray());
+            } else {
+              matches = Objects.equals(expected, actual);
+            }
+            if (!matches) {
               streamListener.onError(CallStatus.INVALID_ARGUMENT
                   .withDescription(String.format("Parameter mismatch. 
Expected: %s Actual: %s", expected, actual))
                   .toRuntimeException());

Reply via email to