This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 1080888  feat(java/driver/jdbc): expose constraints in GetObjects (#474)
1080888 is described below

commit 10808885d5139e603233c7773847b4cf9ee6873c
Author: Tornike Gurgenidze <[email protected]>
AuthorDate: Mon Feb 27 17:22:05 2023 +0400

    feat(java/driver/jdbc): expose constraints in GetObjects (#474)
    
    - ObjectMetadataBuilder has been partly rewritten to use Writers instead
    of writing directly to Vectors, which removes much of the manual
    row-index bookkeeping.
    - ObjectMetadataBuilder now also reports UNIQUE constraints, derived
    from unique indexes (DatabaseMetaData.getIndexInfo).
    - Added tests for constraints. For now, the tests create the PK and FK
    constraints by issuing SQL statements directly, without any
    driver-specific changes to Quirks. This seems sufficient for both Derby
    and Postgres, but will probably need to change for more complicated
    scenarios.
    
    Fixes #471.
---
 .../flightsql/FlightSqlConnectionMetadataTest.java |   4 +
 .../adbc/driver/jdbc/ObjectMetadataBuilder.java    | 178 ++++++++++++---------
 .../testsuite/AbstractConnectionMetadataTest.java  |  94 +++++++++++
 .../arrow/adbc/driver/testsuite/SqlTestUtil.java   | 125 +++++++++++++++
 .../adbc/driver/testsuite/SqlValidationQuirks.java |  46 ++++++
 5 files changed, 372 insertions(+), 75 deletions(-)

diff --git a/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionMetadataTest.java b/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionMetadataTest.java
index da56280..605da7f 100644
--- a/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionMetadataTest.java
+++ b/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionMetadataTest.java
@@ -27,6 +27,10 @@ public class FlightSqlConnectionMetadataTest extends AbstractConnectionMetadataT
     quirks = new FlightSqlQuirks();
   }
 
+  @Override
+  @Disabled("Not yet implemented")
+  public void getObjectsConstraints() throws Exception {}
+
   @Override
   @Disabled("Not yet implemented")
   public void getObjectsColumns() throws Exception {}
diff --git a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/ObjectMetadataBuilder.java b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/ObjectMetadataBuilder.java
index abc7cf9..93c1133 100644
--- a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/ObjectMetadataBuilder.java
+++ b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/ObjectMetadataBuilder.java
@@ -23,9 +23,12 @@ import java.sql.ResultSet;
 import java.sql.SQLException;
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import org.apache.arrow.adbc.core.AdbcConnection;
 import org.apache.arrow.adbc.core.StandardSchemas;
+import org.apache.arrow.memory.ArrowBuf;
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.util.AutoCloseables;
 import org.apache.arrow.vector.IntVector;
@@ -34,6 +37,10 @@ import org.apache.arrow.vector.VarCharVector;
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.complex.ListVector;
 import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.complex.impl.UnionListWriter;
+import org.apache.arrow.vector.complex.writer.BaseWriter.ListWriter;
+import org.apache.arrow.vector.complex.writer.BaseWriter.StructWriter;
+import org.apache.arrow.vector.complex.writer.VarCharWriter;
 
 /** Helper class to track state needed to build up the object metadata 
structure. */
 final class ObjectMetadataBuilder implements AutoCloseable {
@@ -61,17 +68,18 @@ final class ObjectMetadataBuilder implements AutoCloseable {
   final VarCharVector columnRemarks;
   final SmallIntVector columnXdbcDataTypes;
   final ListVector tableConstraints;
-  final StructVector constraints;
-  final VarCharVector constraintNames;
-  final VarCharVector constraintTypes;
-  final ListVector constraintColumnNames;
-  final VarCharVector constraintColumnNameItems;
-  final ListVector constraintColumnUsage;
-  final StructVector columnUsages;
-  final VarCharVector columnUsageFkCatalogs;
-  final VarCharVector columnUsageFkDbSchemas;
-  final VarCharVector columnUsageFkTables;
-  final VarCharVector columnUsageFkColumns;
+  final UnionListWriter tableConstraintsWriter;
+  final StructWriter tableConstraintsStructWriter;
+  final VarCharWriter constraintNamesWriter;
+  final VarCharWriter constraintTypesWriter;
+  final ListWriter constraintColumnNamesWriter;
+  final ListWriter constraintColumnUsageWriter;
+  final StructWriter constraintColumnUsageStructWriter;
+  final VarCharWriter constraintColumnUsageFkCatalogsWriter;
+  final VarCharWriter constraintColumnUsageFkDbSchemasWriter;
+  final VarCharWriter constraintColumnUsageFkTablesWriter;
+  final VarCharWriter constraintColumnUsageFkColumnsWriter;
+  final BufferAllocator allocator;
 
   ObjectMetadataBuilder(
       BufferAllocator allocator,
@@ -83,6 +91,7 @@ final class ObjectMetadataBuilder implements AutoCloseable {
       final String[] tableTypesFilter,
       final String columnNamePattern)
       throws SQLException {
+    this.allocator = allocator;
     this.depth = depth;
     this.catalogPattern = catalogPattern;
     this.dbSchemaPattern = dbSchemaPattern;
@@ -106,17 +115,21 @@ final class ObjectMetadataBuilder implements AutoCloseable {
     this.columnRemarks = (VarCharVector) columns.getVectorById(2);
     this.columnXdbcDataTypes = (SmallIntVector) columns.getVectorById(3);
     this.tableConstraints = (ListVector) tables.getVectorById(3);
-    this.constraints = (StructVector) tableConstraints.getDataVector();
-    this.constraintNames = (VarCharVector) constraints.getVectorById(0);
-    this.constraintTypes = (VarCharVector) constraints.getVectorById(1);
-    this.constraintColumnNames = (ListVector) constraints.getVectorById(2);
-    this.constraintColumnNameItems = (VarCharVector) 
constraintColumnNames.getDataVector();
-    this.constraintColumnUsage = (ListVector) constraints.getVectorById(3);
-    this.columnUsages = (StructVector) constraintColumnUsage.getDataVector();
-    this.columnUsageFkCatalogs = (VarCharVector) columnUsages.getVectorById(0);
-    this.columnUsageFkDbSchemas = (VarCharVector) 
columnUsages.getVectorById(1);
-    this.columnUsageFkTables = (VarCharVector) columnUsages.getVectorById(2);
-    this.columnUsageFkColumns = (VarCharVector) columnUsages.getVectorById(3);
+    this.tableConstraintsWriter = this.tableConstraints.getWriter();
+    this.tableConstraintsStructWriter = this.tableConstraintsWriter.struct();
+    this.constraintNamesWriter = 
this.tableConstraintsWriter.varChar("constraint_name");
+    this.constraintTypesWriter = 
this.tableConstraintsWriter.varChar("constraint_type");
+    this.constraintColumnNamesWriter = 
this.tableConstraintsWriter.list("constraint_column_names");
+    this.constraintColumnUsageWriter = 
this.tableConstraintsWriter.list("constraint_column_usage");
+    this.constraintColumnUsageStructWriter = 
this.constraintColumnUsageWriter.struct();
+    this.constraintColumnUsageFkCatalogsWriter =
+        this.constraintColumnUsageStructWriter.varChar("fk_catalog");
+    this.constraintColumnUsageFkDbSchemasWriter =
+        this.constraintColumnUsageStructWriter.varChar("fk_db_schema");
+    this.constraintColumnUsageFkTablesWriter =
+        this.constraintColumnUsageStructWriter.varChar("fk_table");
+    this.constraintColumnUsageFkColumnsWriter =
+        this.constraintColumnUsageStructWriter.varChar("fk_column_name");
   }
 
   VectorSchemaRoot build() throws SQLException {
@@ -180,36 +193,34 @@ final class ObjectMetadataBuilder implements AutoCloseable {
     int tableCount = 0;
     try (final ResultSet rs =
         dbmd.getTables(catalogName, dbSchemaName, tableNamePattern, 
tableTypesFilter)) {
+
       while (rs.next()) {
         final String tableName = rs.getString(3);
         final String tableType = rs.getString(4);
         tables.setIndexDefined(rowIndex + tableCount);
         tableNames.setSafe(rowIndex + tableCount, 
tableName.getBytes(StandardCharsets.UTF_8));
         tableTypes.setSafe(rowIndex + tableCount, 
tableType.getBytes(StandardCharsets.UTF_8));
-        final int constraintOffset = tableConstraints.startNewValue(rowIndex + 
tableCount);
-        int constraintCount = 0;
+        tableConstraintsWriter.setPosition(rowIndex + tableCount);
+        tableConstraintsWriter.startList();
+
         // JDBC doesn't directly expose constraints. Merge various info 
methods:
         // 1. Primary keys
         try (final ResultSet pk = dbmd.getPrimaryKeys(catalogName, 
dbSchemaName, tableName)) {
           String constraintName = null;
           List<String> constraintColumns = new ArrayList<>();
-          if (pk.next()) {
-            while (pk.next()) {
-              constraintName = pk.getString(6);
-              String columnName = pk.getString(4);
-              int columnIndex = pk.getInt(5);
-              while (constraintColumns.size() < columnIndex) 
constraintColumns.add(null);
-              constraintColumns.set(columnIndex - 1, columnName);
-            }
+          while (pk.next()) {
+            constraintName = pk.getString(6);
+            String columnName = pk.getString(4);
+            int columnIndex = pk.getInt(5);
+            while (constraintColumns.size() < columnIndex) 
constraintColumns.add(null);
+            constraintColumns.set(columnIndex - 1, columnName);
+          }
+          if (!constraintColumns.isEmpty()) {
             addConstraint(
-                constraintOffset + constraintCount,
-                constraintName,
-                "PRIMARY KEY",
-                constraintColumns,
-                Collections.emptyList());
-            constraintCount++;
+                constraintName, "PRIMARY KEY", constraintColumns, 
Collections.emptyList());
           }
         }
+
         // 2. Foreign keys ("imported" keys)
         try (final ResultSet fk = dbmd.getImportedKeys(catalogName, 
dbSchemaName, tableName)) {
           List<String> names = new ArrayList<>();
@@ -234,20 +245,36 @@ final class ObjectMetadataBuilder implements AutoCloseable {
           }
 
           for (int i = 0; i < names.size(); i++) {
-            addConstraint(
-                constraintOffset + constraintCount,
-                names.get(i),
-                "FOREIGN KEY",
-                columns.get(i),
-                references.get(i));
-            constraintCount++;
+            addConstraint(names.get(i), "FOREIGN KEY", columns.get(i), 
references.get(i));
           }
         }
 
-        // TODO: UNIQUE constraints are exposed under indices
+        // 3. UNIQUE constraints
+        try (final ResultSet uq =
+            dbmd.getIndexInfo(catalogName, dbSchemaName, tableName, true, 
false)) {
+          Map<String, ArrayList<String>> uniqueConstraints = new HashMap<>();
+          while (uq.next()) {
+            String constraintName = uq.getString(6);
+            String columnName = uq.getString(9);
+            int columnIndex = uq.getInt(8);
+
+            if (!uniqueConstraints.containsKey(constraintName)) {
+              uniqueConstraints.put(constraintName, new ArrayList<>());
+            }
+            ArrayList<String> uniqueColumns = 
uniqueConstraints.get(constraintName);
+            while (uniqueColumns.size() < columnIndex) uniqueColumns.add(null);
+            uniqueColumns.set(columnIndex - 1, columnName);
+          }
+
+          uniqueConstraints.forEach(
+              (name, columns) -> {
+                addConstraint(name, "UNIQUE", columns, 
Collections.emptyList());
+              });
+        }
+
         // TODO: how to get CHECK constraints?
+        tableConstraintsWriter.endList();
 
-        tableConstraints.endValue(rowIndex + tableCount, constraintCount);
         if (depth == AdbcConnection.GetObjectsDepth.TABLES) {
           tableColumns.setNull(rowIndex + tableCount);
         } else {
@@ -288,43 +315,44 @@ final class ObjectMetadataBuilder implements AutoCloseable {
     return columnCount;
   }
 
+  private void writeVarChar(VarCharWriter writer, String value) {
+    byte[] bytes = value.getBytes(StandardCharsets.UTF_8);
+    try (ArrowBuf tempBuf = allocator.buffer(bytes.length)) {
+      tempBuf.setBytes(0, bytes, 0, bytes.length);
+      writer.writeVarChar(0, bytes.length, tempBuf);
+    }
+  }
+
   private void addConstraint(
-      int index,
       String constraintName,
       String constraintType,
       List<String> constraintColumns,
       List<ReferencedColumn> referencedColumns) {
-    if (constraintName == null) {
-      constraintNames.setNull(index);
-    } else {
-      constraintNames.setSafe(index, 
constraintName.getBytes(StandardCharsets.UTF_8));
-    }
-    constraintTypes.setSafe(index, 
constraintType.getBytes(StandardCharsets.UTF_8));
+    tableConstraintsStructWriter.start();
+
+    writeVarChar(this.constraintNamesWriter, constraintName);
+    writeVarChar(this.constraintTypesWriter, constraintType);
 
-    int namesOffset = constraintColumnNames.startNewValue(index);
-    for (final String column : constraintColumns) {
-      constraintColumnNameItems.setSafe(namesOffset++, 
column.getBytes(StandardCharsets.UTF_8));
+    constraintColumnNamesWriter.startList();
+    for (final String constraintColumn : constraintColumns) {
+      writeVarChar(constraintColumnNamesWriter.varChar(), constraintColumn);
     }
-    constraintColumnNames.endValue(index, constraintColumns.size());
-    int usageOffset = constraintColumnUsage.startNewValue(index);
-    for (final ReferencedColumn column : referencedColumns) {
-      columnUsages.setIndexDefined(usageOffset);
-      if (column.catalog == null) {
-        columnUsageFkCatalogs.setNull(usageOffset);
-      } else {
-        columnUsageFkCatalogs.setSafe(usageOffset, 
column.catalog.getBytes(StandardCharsets.UTF_8));
-      }
-      if (column.dbSchema == null) {
-        columnUsageFkDbSchemas.setNull(usageOffset);
-      } else {
-        columnUsageFkDbSchemas.setSafe(
-            usageOffset, column.dbSchema.getBytes(StandardCharsets.UTF_8));
+    constraintColumnNamesWriter.endList();
+
+    constraintColumnUsageWriter.startList();
+    for (ReferencedColumn referencedColumn : referencedColumns) {
+      constraintColumnUsageStructWriter.start();
+      if (referencedColumn.catalog != null) {
+        writeVarChar(constraintColumnUsageFkCatalogsWriter, 
referencedColumn.catalog);
       }
-      columnUsageFkTables.setSafe(usageOffset, 
column.table.getBytes(StandardCharsets.UTF_8));
-      columnUsageFkColumns.setSafe(usageOffset, 
column.column.getBytes(StandardCharsets.UTF_8));
-      usageOffset++;
+      writeVarChar(constraintColumnUsageFkDbSchemasWriter, 
referencedColumn.dbSchema);
+      writeVarChar(constraintColumnUsageFkTablesWriter, 
referencedColumn.table);
+      writeVarChar(constraintColumnUsageFkColumnsWriter, 
referencedColumn.column);
+      constraintColumnUsageStructWriter.end();
     }
-    constraintColumnUsage.endValue(index, referencedColumns.size());
+    constraintColumnUsageWriter.endList();
+
+    tableConstraintsStructWriter.end();
   }
 
   @Override
diff --git a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionMetadataTest.java b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionMetadataTest.java
index 763f3fd..5514097 100644
--- a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionMetadataTest.java
+++ b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionMetadataTest.java
@@ -21,6 +21,7 @@ import static org.assertj.core.api.Assertions.assertThat;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
@@ -61,6 +62,8 @@ public abstract class AbstractConnectionMetadataTest {
   protected BufferAllocator allocator;
   protected SqlTestUtil util;
   protected String tableName;
+  protected String mainTable;
+  protected String dependentTable;
 
   @BeforeEach
   public void beforeEach() throws Exception {
@@ -70,11 +73,15 @@ public abstract class AbstractConnectionMetadataTest {
     allocator = new RootAllocator();
     util = new SqlTestUtil(quirks);
     tableName = quirks.caseFoldTableName("foo");
+    mainTable = quirks.caseFoldTableName("product");
+    dependentTable = quirks.caseFoldTableName("sale");
   }
 
   @AfterEach
   public void afterEach() throws Exception {
     quirks.cleanupTable(tableName);
+    quirks.cleanupTable(mainTable);
+    quirks.cleanupTable(dependentTable);
     AutoCloseables.close(connection, database, allocator);
   }
 
@@ -107,6 +114,93 @@ public abstract class AbstractConnectionMetadataTest {
     }
   }
 
+  @Test
+  public void getObjectsConstraints() throws Exception {
+    final Schema schema = util.ingestTableWithConstraints(allocator, 
connection, tableName);
+    util.ingestTablesWithReferentialConstraint(allocator, connection, 
mainTable, dependentTable);
+
+    boolean tableFound = false;
+    try (final ArrowReader reader =
+        connection.getObjects(AdbcConnection.GetObjectsDepth.ALL, null, null, 
null, null, null)) {
+      assertThat(reader.getVectorSchemaRoot().getSchema())
+          .isEqualTo(StandardSchemas.GET_OBJECTS_SCHEMA);
+      assertThat(reader.loadNextBatch()).isTrue();
+
+      final ListVector dbSchemas = (ListVector) 
reader.getVectorSchemaRoot().getVector(1);
+      final ListVector dbSchemaTables =
+          (ListVector) ((StructVector) 
dbSchemas.getDataVector()).getVectorById(1);
+      final StructVector tables = (StructVector) 
dbSchemaTables.getDataVector();
+      final VarCharVector tableNames = (VarCharVector) tables.getVectorById(0);
+      final ListVector tableConstraints = (ListVector) tables.getVectorById(3);
+
+      for (int i = 0; i < tables.getValueCount(); i++) {
+        if (tables.isNull(i)) {
+          continue;
+        }
+        final Text tableName = tableNames.getObject(i);
+        if (tableName != null && 
tableName.toString().equalsIgnoreCase(this.tableName)) {
+          tableFound = true;
+
+          @SuppressWarnings("unchecked")
+          final List<Map<String, ?>> constraints =
+              (List<Map<String, ?>>) tableConstraints.getObject(i);
+
+          assertThat(constraints)
+              .filteredOn(c -> c.get("constraint_type").equals(new 
Text("PRIMARY KEY")))
+              .extracting("constraint_name")
+              .containsExactlyInAnyOrderElementsOf(
+                  Collections.singletonList(new 
Text(quirks.caseFoldColumnName("table_pk"))));
+
+          assertThat(constraints)
+              .filteredOn(c -> c.get("constraint_type").equals(new 
Text("PRIMARY KEY")))
+              .flatExtracting("constraint_column_names")
+              .containsExactlyInAnyOrderElementsOf(
+                  schema.getFields().stream()
+                      .map(field -> new Text(field.getName()))
+                      .collect(Collectors.toList()));
+
+          assertThat(constraints)
+              .filteredOn(c -> c.get("constraint_type").equals(new 
Text("UNIQUE")))
+              .extracting("constraint_name")
+              .hasSize(1);
+
+          assertThat(constraints)
+              .filteredOn(c -> c.get("constraint_type").equals(new 
Text("UNIQUE")))
+              .flatExtracting("constraint_column_names")
+              .containsExactlyInAnyOrderElementsOf(
+                  schema.getFields().stream()
+                      .map(field -> new Text(field.getName()))
+                      .collect(Collectors.toList()));
+        }
+
+        if (tableName != null && 
tableName.toString().equalsIgnoreCase(dependentTable)) {
+          @SuppressWarnings("unchecked")
+          final List<Map<String, ?>> constraints =
+              (List<Map<String, ?>>) tableConstraints.getObject(i);
+
+          assertThat(constraints)
+              .extracting("constraint_name")
+              .containsExactlyInAnyOrderElementsOf(
+                  Collections.singletonList(
+                      new Text(quirks.caseFoldColumnName("SALE_PRODUCT_FK"))));
+
+          assertThat(constraints)
+              .flatExtracting("constraint_column_names")
+              .containsExactlyInAnyOrderElementsOf(
+                  Collections.singletonList(new 
Text(quirks.caseFoldColumnName("product_id"))));
+
+          assertThat(constraints)
+              .flatExtracting("constraint_column_usage")
+              .asList()
+              .first()
+              .extracting("fk_table")
+              .isEqualTo(new Text(quirks.caseFoldColumnName("product")));
+        }
+      }
+      assertThat(tableFound).describedAs("Table FOO exists in 
metadata").isTrue();
+    }
+  }
+
   @Test
   public void getObjectsColumns() throws Exception {
     final Schema schema = util.ingestTableIntsStrs(allocator, connection, 
tableName);
diff --git a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlTestUtil.java b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlTestUtil.java
index fe6a3ba..623bfab 100644
--- a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlTestUtil.java
+++ b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlTestUtil.java
@@ -19,6 +19,7 @@ package org.apache.arrow.adbc.driver.testsuite;
 
 import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
+import java.util.Collections;
 import org.apache.arrow.adbc.core.AdbcConnection;
 import org.apache.arrow.adbc.core.AdbcStatement;
 import org.apache.arrow.adbc.core.BulkIngestMode;
@@ -69,4 +70,128 @@ public final class SqlTestUtil {
     }
     return schema;
   }
+
+  /** Load a table with composite primary key */
+  public Schema ingestTableWithConstraints(
+      BufferAllocator allocator, AdbcConnection connection, String tableName) 
throws Exception {
+    tableName = quirks.caseFoldTableName(tableName);
+    final Schema schema =
+        new Schema(
+            Arrays.asList(
+                Field.notNullable(
+                    quirks.caseFoldColumnName("INTS"), new ArrowType.Int(32, 
/*signed=*/ true)),
+                Field.nullable(quirks.caseFoldColumnName("INTS2"), new 
ArrowType.Int(32, true))));
+    try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, 
allocator)) {
+      final IntVector ints = (IntVector) root.getVector(0);
+      final IntVector strs = (IntVector) root.getVector(1);
+
+      ints.allocateNew(4);
+      ints.setSafe(0, 0);
+      ints.setSafe(1, 1);
+      ints.setSafe(2, 2);
+      ints.setSafe(3, 3);
+      strs.allocateNew(4);
+      strs.setSafe(0, 10);
+      strs.setSafe(1, 11);
+      strs.setSafe(2, 12);
+      strs.setSafe(3, 13);
+      root.setRowCount(4);
+      try (final AdbcStatement stmt = connection.bulkIngest(tableName, 
BulkIngestMode.CREATE)) {
+        stmt.bind(root);
+        stmt.executeUpdate();
+      }
+
+      try (final AdbcStatement stmt = connection.createStatement()) {
+        stmt.setSqlQuery(quirks.generateSetNotNullQuery(tableName, "INTS"));
+        stmt.executeUpdate();
+      }
+
+      try (final AdbcStatement stmt = connection.createStatement()) {
+        stmt.setSqlQuery(quirks.generateSetNotNullQuery(tableName, "INTS2"));
+        stmt.executeUpdate();
+      }
+
+      try (final AdbcStatement stmt = connection.createStatement()) {
+        stmt.setSqlQuery(
+            quirks.generateAddPrimaryKeyQuery(
+                "TABLE_PK", tableName, Arrays.asList("INTS", "INTS2")));
+        stmt.executeUpdate();
+      }
+    }
+    return schema;
+  }
+
+  /** Load two tables with foreign key relationship between them */
+  public void ingestTablesWithReferentialConstraint(
+      BufferAllocator allocator, AdbcConnection connection, String mainTable, 
String dependentTable)
+      throws Exception {
+    mainTable = quirks.caseFoldTableName(mainTable);
+    dependentTable = quirks.caseFoldTableName(dependentTable);
+
+    final Schema mainSchema =
+        new Schema(
+            Collections.singletonList(
+                Field.notNullable(
+                    quirks.caseFoldColumnName("PRODUCT_ID"),
+                    new ArrowType.Int(32, /*signed=*/ true))));
+
+    final Schema dependentSchema =
+        new Schema(
+            Arrays.asList(
+                Field.notNullable(
+                    quirks.caseFoldColumnName("SALE_ID"), new 
ArrowType.Int(32, true)),
+                Field.notNullable(
+                    quirks.caseFoldColumnName("PRODUCT_ID"), new 
ArrowType.Int(32, true))));
+
+    try (final VectorSchemaRoot root = VectorSchemaRoot.create(mainSchema, 
allocator)) {
+      final IntVector product = (IntVector) root.getVector(0);
+      product.allocateNew(4);
+      product.setSafe(0, 1);
+      product.setSafe(1, 2);
+      product.setSafe(2, 3);
+      product.setSafe(3, 4);
+      root.setRowCount(4);
+      try (final AdbcStatement stmt = connection.bulkIngest(mainTable, 
BulkIngestMode.CREATE)) {
+        stmt.bind(root);
+        stmt.executeUpdate();
+      }
+    }
+
+    try (final VectorSchemaRoot root = 
VectorSchemaRoot.create(dependentSchema, allocator)) {
+      final IntVector sale = (IntVector) root.getVector(0);
+      final IntVector product = (IntVector) root.getVector(1);
+
+      sale.allocateNew(2);
+      sale.setSafe(0, 1);
+      sale.setSafe(1, 2);
+      product.allocateNew(2);
+      product.setSafe(0, 2);
+      product.setSafe(1, 4);
+      root.setRowCount(2);
+      try (final AdbcStatement stmt =
+          connection.bulkIngest(dependentTable, BulkIngestMode.CREATE)) {
+        stmt.bind(root);
+        stmt.executeUpdate();
+      }
+    }
+
+    try (final AdbcStatement stmt = connection.createStatement()) {
+      stmt.setSqlQuery(quirks.generateSetNotNullQuery(mainTable, 
"PRODUCT_ID"));
+      stmt.executeUpdate();
+    }
+
+    try (final AdbcStatement stmt = connection.createStatement()) {
+      stmt.setSqlQuery(
+          quirks.generateAddPrimaryKeyQuery(
+              "PRODUCT_PK", mainTable, 
Collections.singletonList("PRODUCT_ID")));
+      stmt.executeUpdate();
+    }
+
+    try (final AdbcStatement stmt = connection.createStatement()) {
+      stmt.setSqlQuery(
+          quirks.generateAddForeignKeyQuery(
+              "SALE_PRODUCT_FK", dependentTable, "PRODUCT_ID", mainTable, 
"PRODUCT_ID"));
+      stmt.executeUpdate();
+    }
+  }
 }
diff --git a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlValidationQuirks.java b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlValidationQuirks.java
index 9ad6883..ee60e8b 100644
--- a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlValidationQuirks.java
+++ b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlValidationQuirks.java
@@ -17,6 +17,8 @@
 
 package org.apache.arrow.adbc.driver.testsuite;
 
+import java.util.List;
+import java.util.stream.Collectors;
 import org.apache.arrow.adbc.core.AdbcDatabase;
 import org.apache.arrow.adbc.core.AdbcException;
 
@@ -35,4 +37,48 @@ public abstract class SqlValidationQuirks {
   public String caseFoldColumnName(String name) {
     return name;
   }
+
+  /** Generates a query to set a column to NOT NULL in a table. */
+  public String generateSetNotNullQuery(String table, String column) {
+    return "ALTER TABLE "
+        + caseFoldTableName(table)
+        + " ALTER COLUMN "
+        + caseFoldColumnName(column)
+        + " SET NOT NULL";
+  }
+
+  public String generateAddPrimaryKeyQuery(
+      String constraintName, String table, List<String> columns) {
+    return "ALTER TABLE "
+        + caseFoldTableName(table)
+        + " \n"
+        + "  ADD CONSTRAINT "
+        + constraintName
+        + " \n"
+        + "  PRIMARY KEY ("
+        + 
columns.stream().map(this::caseFoldColumnName).collect(Collectors.joining(","))
+        + ")";
+  }
+
+  public String generateAddForeignKeyQuery(
+      String constraintName,
+      String table,
+      String column,
+      String referenceTable,
+      String referenceColumn) {
+    return "ALTER TABLE "
+        + caseFoldTableName(table)
+        + " \n"
+        + "  ADD CONSTRAINT "
+        + constraintName
+        + " \n"
+        + "  FOREIGN KEY ("
+        + caseFoldColumnName(column)
+        + ") \n"
+        + "  REFERENCES "
+        + caseFoldTableName(referenceTable)
+        + " ("
+        + caseFoldColumnName(referenceColumn)
+        + ") ";
+  }
 }

Reply via email to