This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 3eb64c9 Clarify bulk ingestion semantics (#34)
3eb64c9 is described below
commit 3eb64c9d929e25d090561917935a36b88dc7b223
Author: David Li <[email protected]>
AuthorDate: Mon Jul 11 12:10:22 2022 -0400
Clarify bulk ingestion semantics (#34)
---
adbc.h | 9 +-
drivers/sqlite/sqlite.cc | 98 ++++++++++------------
drivers/sqlite/sqlite_test.cc | 57 ++++++++++++-
.../org/apache/arrow/adbc/core/AdbcConnection.java | 9 +-
.../arrow/adbc/driver/jdbc/JdbcDriverUtil.java | 16 +++-
.../arrow/adbc/driver/jdbc/JdbcStatement.java | 58 ++++++++++---
.../driver/testsuite/AbstractStatementTest.java | 43 +++++++++-
.../adbc/driver/testsuite/ArrowAssertions.java | 24 +++++-
8 files changed, 239 insertions(+), 75 deletions(-)
diff --git a/adbc.h b/adbc.h
index 3c2445b..b201a40 100644
--- a/adbc.h
+++ b/adbc.h
@@ -410,7 +410,9 @@ AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection*
connection,
///
/// 1. The column's ordinal position in the table (starting from 1).
/// 2. Database-specific description of the column.
-/// 3. Optional, JDBC/ODBC-compatible value.
+/// 3. Optional value. Should be null if not supported by the driver.
+/// xdbc_ values are meant to provide JDBC/ODBC-compatible metadata
+/// in an agnostic manner.
///
/// CONSTRAINT_SCHEMA is a Struct with fields:
///
@@ -719,6 +721,11 @@ AdbcStatusCode AdbcStatementSetOption(struct
AdbcStatement* statement, const cha
/// @{
/// \brief The name of the target table for a bulk insert.
+///
+/// The driver should attempt to create the table if it does not
+/// exist. If the table exists but has a different schema,
+/// ADBC_STATUS_ALREADY_EXISTS should be raised. Else, data should be
+/// appended to the target table.
#define ADBC_INGEST_OPTION_TARGET_TABLE "adbc.ingest.target_table"
/// }@
diff --git a/drivers/sqlite/sqlite.cc b/drivers/sqlite/sqlite.cc
index 8703e66..cb2cedb 100644
--- a/drivers/sqlite/sqlite.cc
+++ b/drivers/sqlite/sqlite.cc
@@ -87,11 +87,20 @@ AdbcStatusCode CheckRc(sqlite3* db, sqlite3_stmt* stmt, int
rc, const char* cont
return ADBC_STATUS_OK;
}
+template <typename CallbackFn>
+AdbcStatusCode DoQuery(sqlite3* db, sqlite3_stmt* stmt, struct AdbcError*
error,
+ CallbackFn&& callback) {
+ auto status = std::move(callback)();
+ std::ignore = CheckRc(db, stmt, sqlite3_finalize(stmt), "sqlite3_finalize",
error);
+ return status;
+}
+
template <typename CallbackFn>
AdbcStatusCode DoQuery(sqlite3* db, const char* query, struct AdbcError* error,
CallbackFn&& callback) {
sqlite3_stmt* stmt;
int rc = sqlite3_prepare_v2(db, query, std::strlen(query), &stmt,
/*pzTail=*/nullptr);
+ if (rc != SQLITE_OK) return CheckRc(db, stmt, rc, "sqlite3_prepare_v2",
error);
auto status = std::move(callback)(stmt);
std::ignore = CheckRc(db, stmt, sqlite3_finalize(stmt), "sqlite3_finalize",
error);
return status;
@@ -1045,28 +1054,11 @@ class SqliteStatementImpl {
}
sqlite3* db = connection_->db();
- sqlite3_stmt* stmt = nullptr;
- int rc = SQLITE_OK;
-
- auto check_status = [&](const arrow::Status& st) mutable {
- if (!st.ok()) {
- SetError(error, st);
- if (stmt) {
- rc = sqlite3_finalize(stmt);
- if (rc != SQLITE_OK) {
- SetError(db, "sqlite3_finalize", error);
- }
- }
- return ADBC_STATUS_IO;
- }
- return ADBC_STATUS_OK;
- };
// Create the table
- // TODO: parameter to choose append/overwrite/error
{
// XXX: not injection-safe
- std::string query = "CREATE TABLE ";
+ std::string query = "CREATE TABLE IF NOT EXISTS ";
query += bulk_table_;
query += " (";
const auto& fields = bind_parameters_->schema()->fields();
@@ -1076,19 +1068,15 @@ class SqliteStatementImpl {
}
query += ')';
- rc = sqlite3_prepare_v2(db, query.c_str(),
static_cast<int>(query.size()), &stmt,
- /*pzTail=*/nullptr);
- ADBC_RETURN_NOT_OK(CheckRc(db, stmt, rc, "sqlite3_prepare_v2", error));
-
- rc = sqlite3_step(stmt);
- if (rc != SQLITE_DONE) return CheckRc(db, stmt, rc, "sqlite3_step",
error);
-
- rc = sqlite3_finalize(stmt);
- ADBC_RETURN_NOT_OK(CheckRc(db, stmt, rc, "sqlite3_finalize", error));
+ ADBC_RETURN_NOT_OK(
+ DoQuery(db, query.c_str(), error, [&](sqlite3_stmt* stmt) ->
AdbcStatusCode {
+ const int rc = sqlite3_step(stmt);
+ if (rc == SQLITE_DONE) return ADBC_STATUS_OK;
+ return CheckRc(db, stmt, rc, "sqlite3_step", error);
+ }));
}
// Insert the rows
-
{
std::string query = "INSERT INTO ";
query += bulk_table_;
@@ -1099,37 +1087,43 @@ class SqliteStatementImpl {
query += '?';
}
query += ')';
- rc = sqlite3_prepare_v2(db, query.c_str(),
static_cast<int>(query.size()), &stmt,
- /*pzTail=*/nullptr);
- ADBC_RETURN_NOT_OK(CheckRc(db, stmt, rc, query.c_str(), error));
- }
- while (true) {
- std::shared_ptr<arrow::RecordBatch> batch;
- auto status = bind_parameters_->Next().Value(&batch);
- ADBC_RETURN_NOT_OK(check_status(status));
- if (!batch) break;
+ sqlite3_stmt* stmt;
+ int rc = sqlite3_prepare_v2(db, query.c_str(),
static_cast<int>(query.size()),
+ &stmt, /*pzTail=*/nullptr);
+ if (rc != SQLITE_OK) {
+ std::ignore = CheckRc(db, stmt, rc, "sqlite3_prepare_v2", error);
+ return ADBC_STATUS_ALREADY_EXISTS;
+ }
+ ADBC_RETURN_NOT_OK(DoQuery(db, stmt, error, [&]() -> AdbcStatusCode {
+ int rc = SQLITE_OK;
+ while (true) {
+ std::shared_ptr<arrow::RecordBatch> batch;
+ ADBC_RETURN_NOT_OK(
+ FromArrowStatus(bind_parameters_->Next().Value(&batch), error));
+ if (!batch) break;
- for (int64_t row = 0; row < batch->num_rows(); row++) {
- // TODO: if this fails we won't release the statement
- ADBC_RETURN_NOT_OK(BindParameters(stmt, *batch, row, &rc, error));
- ADBC_RETURN_NOT_OK(CheckRc(db, stmt, rc, "sqlite3_bind", error));
+ for (int64_t row = 0; row < batch->num_rows(); row++) {
+ ADBC_RETURN_NOT_OK(BindParameters(stmt, *batch, row, &rc, error));
+ ADBC_RETURN_NOT_OK(CheckRc(db, stmt, rc, "sqlite3_bind", error));
- rc = sqlite3_step(stmt);
- if (rc != SQLITE_DONE) {
- return CheckRc(db, stmt, rc, "sqlite3_step", error);
- }
+ rc = sqlite3_step(stmt);
+ if (rc != SQLITE_DONE) {
+ return CheckRc(db, stmt, rc, "sqlite3_step", error);
+ }
- rc = sqlite3_reset(stmt);
- ADBC_RETURN_NOT_OK(CheckRc(db, stmt, rc, "sqlite3_reset", error));
+ rc = sqlite3_reset(stmt);
+ ADBC_RETURN_NOT_OK(CheckRc(db, stmt, rc, "sqlite3_reset", error));
- rc = sqlite3_clear_bindings(stmt);
- ADBC_RETURN_NOT_OK(CheckRc(db, stmt, rc, "sqlite3_clear_bindings",
error));
- }
+ rc = sqlite3_clear_bindings(stmt);
+ ADBC_RETURN_NOT_OK(CheckRc(db, stmt, rc, "sqlite3_clear_bindings",
error));
+ }
+ }
+ return ADBC_STATUS_OK;
+ }));
}
- rc = sqlite3_finalize(stmt);
- return CheckRc(db, nullptr, rc, "sqlite3_finalize", error);
+ return ADBC_STATUS_OK;
}
AdbcStatusCode ExecutePrepared(struct AdbcError* error) {
diff --git a/drivers/sqlite/sqlite_test.cc b/drivers/sqlite/sqlite_test.cc
index be232f5..6fa9d2e 100644
--- a/drivers/sqlite/sqlite_test.cc
+++ b/drivers/sqlite/sqlite_test.cc
@@ -247,10 +247,10 @@ TEST_F(Sqlite, BulkIngestTable) {
auto bulk_schema = arrow::schema(
{arrow::field("ints", arrow::int64()), arrow::field("strs",
arrow::utf8())});
auto bulk_table = adbc::RecordBatchFromJSON(bulk_schema, R"([[1, "foo"], [2,
"bar"]])");
- ASSERT_OK(ExportRecordBatch(*bulk_table, &export_table));
- ASSERT_OK(ExportSchema(*bulk_schema, &export_schema));
{
+ ASSERT_OK(ExportRecordBatch(*bulk_table, &export_table));
+ ASSERT_OK(ExportSchema(*bulk_schema, &export_schema));
AdbcStatement statement;
std::memset(&statement, 0, sizeof(statement));
ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement,
&error));
@@ -277,6 +277,59 @@ TEST_F(Sqlite, BulkIngestTable) {
ASSERT_SCHEMA_EQ(*schema, *bulk_schema);
EXPECT_THAT(batches, ::testing::UnorderedPointwise(PointeesEqual(),
{bulk_table}));
}
+
+ // Append
+ {
+ ASSERT_OK(ExportRecordBatch(*bulk_table, &export_table));
+ ASSERT_OK(ExportSchema(*bulk_schema, &export_schema));
+
+ AdbcStatement statement;
+ std::memset(&statement, 0, sizeof(statement));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement,
&error));
+ ADBC_ASSERT_OK_WITH_ERROR(
+ error, AdbcStatementSetOption(&statement,
ADBC_INGEST_OPTION_TARGET_TABLE,
+ "bulk_insert", &error));
+ ADBC_ASSERT_OK_WITH_ERROR(
+ error, AdbcStatementBind(&statement, &export_table, &export_schema,
&error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementRelease(&statement, &error));
+
+ std::memset(&statement, 0, sizeof(statement));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement,
&error));
+ ADBC_ASSERT_OK_WITH_ERROR(
+ error, AdbcStatementSetSqlQuery(&statement, "SELECT * FROM
bulk_insert", &error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementExecute(&statement, &error));
+
+ std::shared_ptr<arrow::Schema> schema;
+ arrow::RecordBatchVector batches;
+ ASSERT_NO_FATAL_FAILURE(ReadStatement(&statement, &schema, &batches));
+ ASSERT_SCHEMA_EQ(*schema, *bulk_schema);
+ EXPECT_THAT(
+ batches,
+ ::testing::UnorderedPointwise(
+ PointeesEqual(),
+ {adbc::RecordBatchFromJSON(
+ bulk_schema, R"([[1, "foo"], [2, "bar"], [1, "foo"], [2,
"bar"]])")}));
+ }
+
+ // Conflict
+ {
+ auto bulk_schema = arrow::schema({arrow::field("ints", arrow::int64())});
+ auto bulk_table = adbc::RecordBatchFromJSON(bulk_schema, R"([[1], [2]])");
+ ASSERT_OK(ExportRecordBatch(*bulk_table, &export_table));
+ ASSERT_OK(ExportSchema(*bulk_schema, &export_schema));
+
+ AdbcStatement statement;
+ std::memset(&statement, 0, sizeof(statement));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementNew(&connection, &statement,
&error));
+ ADBC_ASSERT_OK_WITH_ERROR(
+ error, AdbcStatementSetOption(&statement,
ADBC_INGEST_OPTION_TARGET_TABLE,
+ "bulk_insert", &error));
+ ADBC_ASSERT_OK_WITH_ERROR(
+ error, AdbcStatementBind(&statement, &export_table, &export_schema,
&error));
+ ASSERT_EQ(ADBC_STATUS_ALREADY_EXISTS, AdbcStatementExecute(&statement,
&error));
+ ADBC_ASSERT_OK_WITH_ERROR(error, AdbcStatementRelease(&statement, &error));
+ }
}
TEST_F(Sqlite, BulkIngestStream) {
diff --git
a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcConnection.java
b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcConnection.java
index 0917f65..b92a086 100644
--- a/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcConnection.java
+++ b/java/core/src/main/java/org/apache/arrow/adbc/core/AdbcConnection.java
@@ -34,7 +34,14 @@ public interface AdbcConnection extends AutoCloseable {
/** Create a new statement that can be executed. */
AdbcStatement createStatement() throws AdbcException;
- /** Create a new statement to bulk insert a {@link VectorSchemaRoot} into a
table. */
+ /**
+ * Create a new statement to bulk insert a {@link VectorSchemaRoot} into a
table.
+ *
+ * <p>Bind data to the statement, then call {@link AdbcStatement#execute()}.
The table will be
+ * created if it does not exist. Otherwise data will be appended.
<tt>execute()</tt> will throw
+ * AdbcException with status {@link AdbcStatusCode#ALREADY_EXISTS} if the
schema of the bound data
+ * does not match the table schema.
+ */
default AdbcStatement bulkIngest(String targetTableName) throws
AdbcException {
throw new UnsupportedOperationException("Connection does not support bulk
ingestion");
}
diff --git
a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcDriverUtil.java
b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcDriverUtil.java
index 80fcec6..1a9a6a9 100644
---
a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcDriverUtil.java
+++
b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcDriverUtil.java
@@ -29,7 +29,7 @@ final class JdbcDriverUtil {
return "[JDBC] " + s;
}
- static AdbcException fromSqlException(final SQLException e) {
+ static AdbcException fromSqlException(SQLException e) {
return new AdbcException(
prefixExceptionMessage(e.getMessage()),
e.getCause(),
@@ -37,4 +37,18 @@ final class JdbcDriverUtil {
e.getSQLState(),
e.getErrorCode());
}
+
+ static AdbcException fromSqlException(String format, SQLException e,
Object... values) {
+ return fromSqlException(AdbcStatusCode.UNKNOWN, format, e, values);
+ }
+
+ static AdbcException fromSqlException(
+ AdbcStatusCode status, String format, SQLException e, Object... values) {
+ return new AdbcException(
+ String.format(format, values) + prefixExceptionMessage(e.getMessage()),
+ e.getCause(),
+ status,
+ e.getSQLState(),
+ e.getErrorCode());
+ }
}
diff --git
a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcStatement.java
b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcStatement.java
index 56bb626..b95f9b8 100644
---
a/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcStatement.java
+++
b/java/driver/jdbc/src/main/java/org/apache/arrow/adbc/driver/jdbc/JdbcStatement.java
@@ -18,6 +18,7 @@
package org.apache.arrow.adbc.driver.jdbc;
import java.sql.Connection;
+import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
@@ -25,6 +26,7 @@ import java.sql.Statement;
import java.util.Objects;
import org.apache.arrow.adbc.core.AdbcException;
import org.apache.arrow.adbc.core.AdbcStatement;
+import org.apache.arrow.adbc.core.AdbcStatusCode;
import org.apache.arrow.adbc.driver.jdbc.util.JdbcParameterBinder;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
@@ -80,12 +82,7 @@ public class JdbcStatement implements AdbcStatement {
}
}
- private void executeBulk() throws AdbcException {
- if (bindRoot == null) {
- throw new IllegalStateException("Must bind() before bulk insert");
- }
-
- // TODO: also create the table
+ private void createBulkTable() throws AdbcException {
final StringBuilder create = new StringBuilder("CREATE TABLE ");
create.append(bulkTargetTable);
create.append(" (");
@@ -135,6 +132,27 @@ public class JdbcStatement implements AdbcStatement {
} catch (SQLException e) {
throw JdbcDriverUtil.fromSqlException(e);
}
+ }
+
+ private void executeBulk() throws AdbcException {
+ if (bindRoot == null) {
+ throw new IllegalStateException("Must bind() before bulk insert");
+ }
+
+ // Check if table exists, create it if necessary.
+ // XXX: TOCTOU (time-of-check/time-of-use) race.
+ try {
+ final DatabaseMetaData dbmd = connection.getMetaData();
+ try (final ResultSet rs =
+ dbmd.getTables(/*catalog*/ null, /*schema*/ null, bulkTargetTable,
/*types*/ null)) {
+ if (!rs.next()) {
+ createBulkTable();
+ }
+ }
+ } catch (SQLException e) {
+ throw JdbcDriverUtil.fromSqlException(
+ "Could not determine if table %s exists: ", e, bulkTargetTable);
+ }
// XXX: potential injection
// TODO: consider (optionally?) depending on jOOQ to generate SQL and
support different dialects
@@ -149,14 +167,28 @@ public class JdbcStatement implements AdbcStatement {
}
insert.append(")");
- try (final PreparedStatement statement =
connection.prepareStatement(insert.toString())) {
- final JdbcParameterBinder binder =
- JdbcParameterBinder.builder(statement, bindRoot).bindAll().build();
- statement.clearBatch();
- while (binder.next()) {
- statement.addBatch();
+ final PreparedStatement statement;
+ try {
+ statement = connection.prepareStatement(insert.toString());
+ } catch (SQLException e) {
+ throw JdbcDriverUtil.fromSqlException(
+ AdbcStatusCode.ALREADY_EXISTS,
+ "Could not bulk insert into table %s: ",
+ e,
+ bulkTargetTable);
+ }
+ try {
+ try {
+ final JdbcParameterBinder binder =
+ JdbcParameterBinder.builder(statement, bindRoot).bindAll().build();
+ statement.clearBatch();
+ while (binder.next()) {
+ statement.addBatch();
+ }
+ statement.executeBatch();
+ } finally {
+ statement.close();
}
- statement.executeBatch();
} catch (SQLException e) {
throw JdbcDriverUtil.fromSqlException(e);
}
diff --git
a/java/driver/testsuite/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java
b/java/driver/testsuite/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java
index fc10c56..f1870af 100644
---
a/java/driver/testsuite/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java
+++
b/java/driver/testsuite/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java
@@ -19,13 +19,16 @@ package org.apache.arrow.adbc.driver.testsuite;
import static
org.apache.arrow.adbc.driver.testsuite.ArrowAssertions.assertRoot;
import static org.assertj.core.api.Assertions.assertThat;
+import static org.junit.jupiter.api.Assertions.assertThrows;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
+import java.util.Collections;
import org.apache.arrow.adbc.core.AdbcConnection;
import org.apache.arrow.adbc.core.AdbcDatabase;
import org.apache.arrow.adbc.core.AdbcException;
import org.apache.arrow.adbc.core.AdbcStatement;
+import org.apache.arrow.adbc.core.AdbcStatusCode;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
@@ -82,7 +85,9 @@ public abstract class AbstractStatementTest {
strs.setSafe(3, "asdf".getBytes(StandardCharsets.UTF_8));
root.setRowCount(4);
- try (final AdbcStatement stmt = connection.bulkIngest("foo")) {
+ // TODO: XXX: need a "quirks" system to handle idiosyncrasies. For
example: Derby forces table
+ // names to uppercase, but does not do case folding in all places.
+ try (final AdbcStatement stmt = connection.bulkIngest("FOO")) {
stmt.bind(root);
stmt.execute();
}
@@ -94,6 +99,42 @@ public abstract class AbstractStatementTest {
assertRoot(arrowReader.getVectorSchemaRoot()).isEqualTo(root);
}
}
+
+ // Append
+ try (final AdbcStatement stmt = connection.bulkIngest("FOO")) {
+ stmt.bind(root);
+ stmt.execute();
+ }
+ try (final AdbcStatement stmt = connection.createStatement()) {
+ stmt.setSqlQuery("SELECT * FROM FOO");
+ stmt.execute();
+ try (ArrowReader arrowReader = stmt.getArrowReader()) {
+ assertThat(arrowReader.loadNextBatch()).isTrue();
+ root.setRowCount(8);
+ ints.setSafe(4, 0);
+ ints.setSafe(5, 1);
+ ints.setSafe(6, 2);
+ ints.setNull(7);
+ strs.setNull(4);
+ strs.setSafe(5, "foo".getBytes(StandardCharsets.UTF_8));
+ strs.setSafe(6, "".getBytes(StandardCharsets.UTF_8));
+ strs.setSafe(7, "asdf".getBytes(StandardCharsets.UTF_8));
+ assertRoot(arrowReader.getVectorSchemaRoot()).isEqualTo(root);
+ }
+ }
+ }
+
+ // Conflict
+ final Schema schema2 =
+ new Schema(
+ Collections.singletonList(
+ Field.nullable("INTS", new ArrowType.Int(32, /*signed=*/
true))));
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema2,
allocator)) {
+ try (final AdbcStatement stmt = connection.bulkIngest("FOO")) {
+ stmt.bind(root);
+ final AdbcException e = assertThrows(AdbcException.class,
stmt::execute);
+ assertThat(e.getStatus()).isEqualTo(AdbcStatusCode.ALREADY_EXISTS);
+ }
}
}
}
diff --git
a/java/driver/testsuite/src/main/java/org/apache/arrow/adbc/driver/testsuite/ArrowAssertions.java
b/java/driver/testsuite/src/main/java/org/apache/arrow/adbc/driver/testsuite/ArrowAssertions.java
index 194fa00..2b3f280 100644
---
a/java/driver/testsuite/src/main/java/org/apache/arrow/adbc/driver/testsuite/ArrowAssertions.java
+++
b/java/driver/testsuite/src/main/java/org/apache/arrow/adbc/driver/testsuite/ArrowAssertions.java
@@ -17,7 +17,9 @@
package org.apache.arrow.adbc.driver.testsuite;
+import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.compare.VectorEqualsVisitor;
import org.assertj.core.api.AbstractAssert;
/** AssertJ assertions for Arrow. */
@@ -41,13 +43,27 @@ public final class ArrowAssertions {
expected.getClass().getName());
}
final VectorSchemaRoot expectedRoot = (VectorSchemaRoot) expected;
- if (!actual.equals(expectedRoot)) {
+ if (!actual.getSchema().equals(expectedRoot.getSchema())) {
throw failureWithActualExpected(
actual,
expected,
- "Expected Root:\n%sActual Root:\n%s",
- expectedRoot.contentToTSVString(),
- actual.contentToTSVString());
+ "Expected Schema:\n%sActual Schema:\n%s",
+ expectedRoot.getSchema(),
+ actual.getSchema());
+ }
+
+ for (int i = 0; i < expectedRoot.getSchema().getFields().size(); i++) {
+ final FieldVector expectedVector = expectedRoot.getVector(i);
+ final FieldVector actualVector = actual.getVector(i);
+ if (!VectorEqualsVisitor.vectorEquals(expectedVector, actualVector)) {
+ throw failureWithActualExpected(
+ actual,
+ expected,
+ "Vector %s does not match.\nExpected vector: %s\nActual vector
: %s",
+ expectedVector.getField(),
+ expectedVector,
+ actualVector);
+ }
}
return this;
}