This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 2e044e85a fix(go/adbc/driver/snowflake): handle empty result sets (#1805)
2e044e85a is described below

commit 2e044e85a70f42130b70532fcd995e76a2671933
Author: David Li <[email protected]>
AuthorDate: Sat May 4 20:13:52 2024 +0900

    fix(go/adbc/driver/snowflake): handle empty result sets (#1805)
    
    Fixes #1804.
---
 c/driver/flightsql/dremio_flightsql_test.cc |  1 +
 c/validation/adbc_validation.h              |  2 ++
 c/validation/adbc_validation_statement.cc   | 35 ++++++++++++++++++++
 go/adbc/driver/snowflake/driver_test.go     | 19 +++++++++++
 go/adbc/driver/snowflake/record_reader.go   | 50 +++++++++++++++++------------
 5 files changed, 87 insertions(+), 20 deletions(-)

diff --git a/c/driver/flightsql/dremio_flightsql_test.cc b/c/driver/flightsql/dremio_flightsql_test.cc
index 8c59eb4a2..acc068279 100644
--- a/c/driver/flightsql/dremio_flightsql_test.cc
+++ b/c/driver/flightsql/dremio_flightsql_test.cc
@@ -92,6 +92,7 @@ class DremioFlightSqlStatementTest : public ::testing::Test,
   void TestSqlIngestColumnEscaping() {
     GTEST_SKIP() << "Column escaping not implemented";
   }
+  void TestSqlQueryEmpty() { GTEST_SKIP() << "Dremio doesn't support 'acceptPut'"; }
   void TestSqlQueryRowsAffectedDelete() {
     GTEST_SKIP() << "Cannot query rows affected in delete (not implemented)";
   }
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index 6c59d95e0..abe9a7686 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -407,6 +407,7 @@ class StatementTest {
   void TestSqlPrepareErrorNoQuery();
   void TestSqlPrepareErrorParamCountMismatch();
 
+  void TestSqlQueryEmpty();
   void TestSqlQueryInts();
   void TestSqlQueryFloats();
   void TestSqlQueryStrings();
@@ -504,6 +505,7 @@ class StatementTest {
   TEST_F(FIXTURE, SqlPrepareErrorParamCountMismatch) {                         
         \
     TestSqlPrepareErrorParamCountMismatch();                                   
         \
   }                                                                            
         \
+  TEST_F(FIXTURE, SqlQueryEmpty) { TestSqlQueryEmpty(); }                      
         \
   TEST_F(FIXTURE, SqlQueryInts) { TestSqlQueryInts(); }                        
         \
   TEST_F(FIXTURE, SqlQueryFloats) { TestSqlQueryFloats(); }                    
         \
   TEST_F(FIXTURE, SqlQueryStrings) { TestSqlQueryStrings(); }                  
         \
diff --git a/c/validation/adbc_validation_statement.cc b/c/validation/adbc_validation_statement.cc
index 59f3f3f9a..333baf141 100644
--- a/c/validation/adbc_validation_statement.cc
+++ b/c/validation/adbc_validation_statement.cc
@@ -2062,6 +2062,41 @@ void StatementTest::TestSqlPrepareErrorParamCountMismatch() {
       ::testing::Not(IsOkStatus(&error)));
 }
 
+void StatementTest::TestSqlQueryEmpty() {
+  ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
+
+  ASSERT_THAT(quirks()->DropTable(&connection, "QUERYEMPTY", &error), IsOkStatus(&error));
+  ASSERT_THAT(
+      AdbcStatementSetSqlQuery(&statement, "CREATE TABLE QUERYEMPTY (FOO INT)", &error),
+      IsOkStatus(&error));
+  ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
+              IsOkStatus(&error));
+
+  ASSERT_THAT(
+      AdbcStatementSetSqlQuery(&statement, "SELECT * FROM QUERYEMPTY WHERE 1=0", &error),
+      IsOkStatus(&error));
+  {
+    StreamReader reader;
+    ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+                                          &reader.rows_affected, &error),
+                IsOkStatus(&error));
+    ASSERT_THAT(reader.rows_affected,
+                ::testing::AnyOf(::testing::Eq(0), ::testing::Eq(-1)));
+
+    ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+    ASSERT_EQ(1, reader.schema->n_children);
+
+    while (true) {
+      ASSERT_NO_FATAL_FAILURE(reader.Next());
+      if (!reader.array->release) {
+        break;
+      }
+      ASSERT_EQ(0, reader.array->length);
+    }
+  }
+  ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
+}
+
 void StatementTest::TestSqlQueryInts() {
   ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
   ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42", &error),
diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go
index de175c3a0..af94e6108 100644
--- a/go/adbc/driver/snowflake/driver_test.go
+++ b/go/adbc/driver/snowflake/driver_test.go
@@ -2031,3 +2031,22 @@ func (suite *SnowflakeTests) TestMetadataOnlyQuery() {
        // all the rows from each record in the stream.
        suite.Equal(n, recv)
 }
+
+func (suite *SnowflakeTests) TestEmptyResultSet() {
+       // regression test for apache/arrow-adbc#1804
+       // this would previously crash
+       suite.Require().NoError(suite.stmt.SetSqlQuery(`SELECT 42 WHERE 1=0`))
+       rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx)
+       suite.Require().NoError(err)
+       defer rdr.Release()
+
+       recv := int64(0)
+       for rdr.Next() {
+               recv += rdr.Record().NumRows()
+       }
+
+       // verify that we got the exepected number of rows if we sum up
+       // all the rows from each record in the stream.
+       suite.Equal(n, recv)
+       suite.Equal(recv, int64(0))
+}
diff --git a/go/adbc/driver/snowflake/record_reader.go b/go/adbc/driver/snowflake/record_reader.go
index bda3e8f70..e404f116d 100644
--- a/go/adbc/driver/snowflake/record_reader.go
+++ b/go/adbc/driver/snowflake/record_reader.go
@@ -571,6 +571,34 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake
        }
 
        ch := make(chan arrow.Record, bufferSize)
+       group, ctx := errgroup.WithContext(compute.WithAllocator(ctx, alloc))
+       ctx, cancelFn := context.WithCancel(ctx)
+       group.SetLimit(prefetchConcurrency)
+
+       defer func() {
+               if err != nil {
+                       close(ch)
+                       cancelFn()
+               }
+       }()
+
+       chs := make([]chan arrow.Record, len(batches))
+       rdr := &reader{
+               refCount: 1,
+               chs:      chs,
+               err:      nil,
+               cancelFn: cancelFn,
+       }
+
+       if len(batches) == 0 {
+               schema, err := rowTypesToArrowSchema(ctx, ld, useHighPrecision)
+               if err != nil {
+                       return nil, err
+               }
+               rdr.schema, _ = getTransformer(schema, ld, useHighPrecision)
+               return rdr, nil
+       }
+
        r, err := batches[0].GetStream(ctx)
        if err != nil {
                return nil, errToAdbcErr(adbc.StatusIO, err)
@@ -584,19 +612,9 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake
                }
        }
 
-       group, ctx := errgroup.WithContext(compute.WithAllocator(ctx, alloc))
-       ctx, cancelFn := context.WithCancel(ctx)
-
-       schema, recTransform := getTransformer(rr.Schema(), ld, useHighPrecision)
+       var recTransform recordTransformer
+       rdr.schema, recTransform = getTransformer(rr.Schema(), ld, useHighPrecision)
 
-       defer func() {
-               if err != nil {
-                       close(ch)
-                       cancelFn()
-               }
-       }()
-
-       group.SetLimit(prefetchConcurrency)
        group.Go(func() error {
                defer rr.Release()
                defer r.Close()
@@ -615,15 +633,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake
                return rr.Err()
        })
 
-       chs := make([]chan arrow.Record, len(batches))
        chs[0] = ch
-       rdr := &reader{
-               refCount: 1,
-               chs:      chs,
-               err:      nil,
-               cancelFn: cancelFn,
-               schema:   schema,
-       }
 
        lastChannelIndex := len(chs) - 1
        go func() {

Reply via email to