lidavidm commented on code in PR #196: URL: https://github.com/apache/arrow-adbc/pull/196#discussion_r1033648780
########## c/driver/sqlite/statement_reader.c: ########## @@ -0,0 +1,832 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "statement_reader.h" + +#include <inttypes.h> +#include <math.h> +#include <stdio.h> + +#include <nanoarrow.h> +#include <sqlite3.h> + +#include "utils.h" + +AdbcStatusCode AdbcSqliteBinderSet(struct AdbcSqliteBinder* binder, + struct AdbcError* error) { + int rc = binder->params.get_schema(&binder->params, &binder->schema); + if (rc != 0) { + const char* message = binder->params.get_last_error(&binder->params); + if (!message) message = "(unknown error)"; + SetError(error, "Failed to get parameter schema: (%d) %s: %s", rc, strerror(rc), + message); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + struct ArrowError arrow_error = {0}; + rc = ArrowArrayViewInitFromSchema(&binder->batch, &binder->schema, &arrow_error); + if (rc != 0) { + SetError(error, "Failed to initialize array view: (%d) %s: %s", rc, strerror(rc), + arrow_error.message); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + if (binder->batch.storage_type != NANOARROW_TYPE_STRUCT) { + SetError(error, "Bind parameters do not have root type STRUCT"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + binder->types = + (enum 
ArrowType*)malloc(binder->schema.n_children * sizeof(enum ArrowType)); + + struct ArrowSchemaView view = {0}; + for (int i = 0; i < binder->schema.n_children; i++) { + rc = ArrowSchemaViewInit(&view, binder->schema.children[i], &arrow_error); + if (rc != 0) { + SetError(error, "Failed to parse schema for column %d: %s (%d): %s", i, + strerror(rc), rc, arrow_error.message); + return ADBC_STATUS_INVALID_ARGUMENT; + } + + if (view.data_type == NANOARROW_TYPE_UNINITIALIZED) { + SetError(error, "Column %d has UNINITIALIZED type", i); + return ADBC_STATUS_INTERNAL; + } + binder->types[i] = view.data_type; + } + + return ADBC_STATUS_OK; +} + +AdbcStatusCode AdbcSqliteBinderSetArray(struct AdbcSqliteBinder* binder, + struct ArrowArray* values, + struct ArrowSchema* schema, + struct AdbcError* error) { + AdbcSqliteBinderRelease(binder); + AdbcStatusCode status = BatchToArrayStream(values, schema, &binder->params, error); + if (status != ADBC_STATUS_OK) return status; + return AdbcSqliteBinderSet(binder, error); +} // NOLINT(whitespace/indent) +AdbcStatusCode AdbcSqliteBinderSetArrayStream(struct AdbcSqliteBinder* binder, + struct ArrowArrayStream* values, + struct AdbcError* error) { + AdbcSqliteBinderRelease(binder); + binder->params = *values; + memset(values, 0, sizeof(*values)); + return AdbcSqliteBinderSet(binder, error); +} +AdbcStatusCode AdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, sqlite3* conn, + sqlite3_stmt* stmt, char* finished, + struct AdbcError* error) { + struct ArrowError arrow_error = {0}; + int rc = 0; + while (!binder->array.release || binder->next_row >= binder->array.length) { + if (binder->array.release) { + ArrowArrayViewReset(&binder->batch); + binder->array.release(&binder->array); + + rc = ArrowArrayViewInitFromSchema(&binder->batch, &binder->schema, &arrow_error); + if (rc != 0) { + SetError(error, "Failed to initialize array view: (%d) %s: %s", rc, strerror(rc), + arrow_error.message); + return ADBC_STATUS_INTERNAL; + } + } + + rc 
= binder->params.get_next(&binder->params, &binder->array); + if (rc != 0) { + const char* message = binder->params.get_last_error(&binder->params); + if (!message) message = "(unknown error)"; + SetError(error, "Failed to get next parameter batch: (%d) %s: %s", rc, strerror(rc), + message); + return ADBC_STATUS_IO; + } + + if (!binder->array.release) { + *finished = 1; + AdbcSqliteBinderRelease(binder); + return ADBC_STATUS_OK; + } + + rc = ArrowArrayViewSetArray(&binder->batch, &binder->array, &arrow_error); + if (rc != 0) { + SetError(error, "Failed to initialize array view: (%d) %s: %s", rc, strerror(rc), + arrow_error.message); + return ADBC_STATUS_INTERNAL; + } + + binder->next_row = 0; + } + + if (sqlite3_reset(stmt) != SQLITE_OK) { + SetError(error, "Failed to reset statement: %s", sqlite3_errmsg(conn)); + return ADBC_STATUS_INTERNAL; + } + if (sqlite3_clear_bindings(stmt) != SQLITE_OK) { + SetError(error, "Failed to clear statement bindings: %s", sqlite3_errmsg(conn)); + return ADBC_STATUS_INTERNAL; + } + + for (int col = 0; col < binder->schema.n_children; col++) { + if (ArrowArrayViewIsNull(binder->batch.children[col], binder->next_row)) { + rc = sqlite3_bind_null(stmt, col + 1); + } else { + switch (binder->types[col]) { + case NANOARROW_TYPE_BINARY: + case NANOARROW_TYPE_LARGE_BINARY: { + struct ArrowBufferView value = + ArrowArrayViewGetBytesUnsafe(binder->batch.children[col], binder->next_row); + rc = sqlite3_bind_text(stmt, col + 1, value.data.as_char, value.n_bytes, + SQLITE_STATIC); + break; + } + case NANOARROW_TYPE_INT64: { + int64_t value = + ArrowArrayViewGetIntUnsafe(binder->batch.children[col], binder->next_row); + rc = sqlite3_bind_int64(stmt, col + 1, value); + break; + } + case NANOARROW_TYPE_STRING: + case NANOARROW_TYPE_LARGE_STRING: { + struct ArrowBufferView value = + ArrowArrayViewGetBytesUnsafe(binder->batch.children[col], binder->next_row); + rc = sqlite3_bind_text(stmt, col + 1, value.data.as_char, value.n_bytes, + SQLITE_STATIC); 
+ break; + } + default: + SetError(error, "Column %d has unsupported type %d", col, binder->types[col]); + return ADBC_STATUS_NOT_IMPLEMENTED; + } + } + + if (rc != SQLITE_OK) { + SetError(error, "Failed to clear statement bindings: %s", sqlite3_errmsg(conn)); + return ADBC_STATUS_INTERNAL; + } + } + + binder->next_row++; + *finished = 0; + return ADBC_STATUS_OK; +} + +void AdbcSqliteBinderRelease(struct AdbcSqliteBinder* binder) { + if (binder->schema.release) { + binder->schema.release(&binder->schema); + } + if (binder->params.release) { + binder->params.release(&binder->params); + } + if (binder->types) { + free(binder->types); + } + if (binder->array.release) { + binder->array.release(&binder->array); + } + ArrowArrayViewReset(&binder->batch); + memset(binder, 0, sizeof(*binder)); +} + +struct StatementReader { + sqlite3* db; + sqlite3_stmt* stmt; + enum ArrowType* types; + struct ArrowSchema schema; + struct ArrowArray initial_batch; + struct AdbcSqliteBinder* binder; + struct ArrowError error; + char done; +}; + +const char* StatementReaderGetLastError(struct ArrowArrayStream* self) { + if (!self->release || !self->private_data) { + return NULL; + } + + struct StatementReader* reader = (struct StatementReader*)self->private_data; + return reader->error.message; +} + +static int kBatchSize = 1024; + +void StatementReaderSetError(struct StatementReader* reader) { + const char* msg = sqlite3_errmsg(reader->db); + strncpy(reader->error.message, msg, sizeof(reader->error.message)); + reader->error.message[sizeof(reader->error.message) - 1] = '\0'; +} + +int StatementReaderGetOneValue(struct StatementReader* reader, int col, + struct ArrowArray* out) { + int sqlite_type = sqlite3_column_type(reader->stmt, col); + + if (sqlite_type == SQLITE_NULL) { + return ArrowArrayAppendNull(out, 1); + } + + switch (reader->types[col]) { + case NANOARROW_TYPE_INT64: { + switch (sqlite_type) { + case SQLITE_INTEGER: { + int64_t value = sqlite3_column_int64(reader->stmt, col); + 
return ArrowArrayAppendInt(out, value); + } + case SQLITE_FLOAT: { + // TODO: behavior needs to be configurable + snprintf(reader->error.message, sizeof(reader->error.message), + "[SQLite] Type mismatch in column %d: expected INT64 but got DOUBLE", + col); + return EIO; + } + case SQLITE_TEXT: + case SQLITE_BLOB: { + snprintf( + reader->error.message, sizeof(reader->error.message), + "[SQLite] Type mismatch in column %d: expected INT64 but got STRING/BINARY", + col); + return EIO; + } + default: { + snprintf(reader->error.message, sizeof(reader->error.message), + "[SQLite] Type mismatch in column %d: expected INT64 but got unknown " + "type %d", + col, sqlite_type); + return ENOTSUP; + } + } + break; + } + + case NANOARROW_TYPE_DOUBLE: { + switch (sqlite_type) { + case SQLITE_INTEGER: + case SQLITE_FLOAT: { + // Let SQLite convert + double value = sqlite3_column_double(reader->stmt, col); + return ArrowArrayAppendDouble(out, value); + } + case SQLITE_TEXT: + case SQLITE_BLOB: { + snprintf(reader->error.message, sizeof(reader->error.message), + "[SQLite] Type mismatch in column %d: expected DOUBLE but got " + "STRING/BINARY", + col); + return EIO; + } + default: { + snprintf(reader->error.message, sizeof(reader->error.message), + "[SQLite] Type mismatch in column %d: expected DOUBLE but got unknown " + "type %d", + col, sqlite_type); + return ENOTSUP; + } + } + break; + } + + case NANOARROW_TYPE_STRING: { + switch (sqlite_type) { + case SQLITE_INTEGER: + case SQLITE_FLOAT: + case SQLITE_TEXT: + case SQLITE_BLOB: { + // Let SQLite convert + struct ArrowStringView value; + value.data = (const char*)sqlite3_column_text(reader->stmt, col); + value.n_bytes = sqlite3_column_bytes(reader->stmt, col); + return ArrowArrayAppendString(out, value); + } + default: { + snprintf(reader->error.message, sizeof(reader->error.message), + "[SQLite] Type mismatch in column %d: expected STRING but got unknown " + "type %d", + col, sqlite_type); + return ENOTSUP; + } + } + break; + } + + 
default: { + snprintf(reader->error.message, sizeof(reader->error.message), + "[SQLite] Internal error: unknown inferred column type %d", + reader->types[col]); + return ENOTSUP; + } + } + + return 0; +} + +int StatementReaderGetNext(struct ArrowArrayStream* self, struct ArrowArray* out) { + if (!self->release || !self->private_data) { + return EINVAL; + } + + struct StatementReader* reader = (struct StatementReader*)self->private_data; + if (reader->initial_batch.release != NULL) { + memcpy(out, &reader->initial_batch, sizeof(*out)); + memset(&reader->initial_batch, 0, sizeof(reader->initial_batch)); + return 0; + } else if (reader->done) { + out->release = NULL; + return 0; + } + + RAISE_NA(ArrowArrayInitFromSchema(out, &reader->schema, &reader->error)); + for (int i = 0; i < reader->schema.n_children; i++) { + RAISE_NA(ArrowArrayStartAppending(out->children[i])); + } + int64_t batch_size = 0; + int result = 0; + + sqlite3_mutex_enter(sqlite3_db_mutex(reader->db)); + while (batch_size < kBatchSize) { + if (reader->binder) { + char finished = 0; + struct AdbcError error = {0}; + AdbcStatusCode status = AdbcSqliteBinderBindNext(reader->binder, reader->db, + reader->stmt, &finished, &error); + if (status != ADBC_STATUS_OK) { + reader->done = 1; + result = EIO; + if (error.release) { + strncpy(reader->error.message, error.message, sizeof(reader->error.message)); + reader->error.message[sizeof(reader->error.message) - 1] = '\0'; + error.release(&error); + } + break; + } else if (finished) { + reader->done = 1; + break; + } + } + + int rc = sqlite3_step(reader->stmt); + if (rc == SQLITE_DONE) { + reader->done = 1; + break; + } else if (rc == SQLITE_ERROR) { + reader->done = 1; + result = EIO; + StatementReaderSetError(reader); + break; + } + + for (int col = 0; col < reader->schema.n_children; col++) { + result = StatementReaderGetOneValue(reader, col, out->children[col]); + if (result != 0) break; + } + + if (result != 0) break; + batch_size++; + } + if (result == 0) 
{ + out->length = batch_size; + for (int i = 0; i < reader->schema.n_children; i++) { + RAISE_NA(ArrowArrayFinishBuilding(out->children[i], &reader->error)); + } + } + + sqlite3_mutex_leave(sqlite3_db_mutex(reader->db)); Review Comment: …yes. I really wish I had `defer`. Or could turn on that GNU extension. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
