This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-flight-sql-postgresql.git


The following commit(s) were added to refs/heads/main by this push:
     new 3835280  Add support for Int16 (#65)
3835280 is described below

commit 3835280abb963879fea3936bbaf43eb969b0e3c6
Author: Sutou Kouhei <[email protected]>
AuthorDate: Tue Aug 22 14:50:39 2023 +0900

    Add support for Int16 (#65)
    
    Closes GH-48
---
 src/afs.cc              | 70 ++++++++++++++++++++++++++++++++++++++-----------
 test/test-flight-sql.rb | 18 ++++++++-----
 2 files changed, 66 insertions(+), 22 deletions(-)

diff --git a/src/afs.cc b/src/afs.cc
index 9082336..935df8a 100644
--- a/src/afs.cc
+++ b/src/afs.cc
@@ -738,6 +738,12 @@ class ArrowPGValueConverter : public arrow::ArrayVisitor {
                return arrow::Status::OK();
        }
 
+       arrow::Status Visit(const arrow::Int16Array& array)
+       {
+               datum_ = Int16GetDatum(array.Value(i_row_));
+               return arrow::Status::OK();
+       }
+
        arrow::Status Visit(const arrow::Int32Array& array)
        {
                datum_ = Int32GetDatum(array.Value(i_row_));
@@ -749,6 +755,44 @@ class ArrowPGValueConverter : public arrow::ArrayVisitor {
        Datum& datum_;
 };
 
+class PGArrowValueConverter : public arrow::ArrayVisitor {
+   public:
+       explicit PGArrowValueConverter(Form_pg_attribute attribute) : 
attribute_(attribute) {}
+
+       arrow::Result<std::shared_ptr<arrow::DataType>> convert_type() const
+       {
+               switch (attribute_->atttypid)
+               {
+                       case INT2OID:
+                               return arrow::int16();
+                       case INT4OID:
+                               return arrow::int32();
+                       default:
+                               return 
arrow::Status::NotImplemented("Unsupported PostgreSQL type: ",
+                                                                    
attribute_->atttypid);
+               }
+       }
+
+       arrow::Status convert_value(arrow::ArrayBuilder* builder, Datum datum) 
const
+       {
+               switch (attribute_->atttypid)
+               {
+                       case INT2OID:
+                               return 
static_cast<arrow::Int16Builder*>(builder)->Append(
+                                       DatumGetInt16(datum));
+                       case INT4OID:
+                               return 
static_cast<arrow::Int32Builder*>(builder)->Append(
+                                       DatumGetInt32(datum));
+                       default:
+                               return 
arrow::Status::NotImplemented("Unsupported PostgreSQL type: ",
+                                                                    
attribute_->atttypid);
+               }
+       }
+
+   private:
+       Form_pg_attribute attribute_;
+};
+
 class PreparedStatement {
    public:
        explicit PreparedStatement(std::string query) : 
query_(std::move(query)) {}
@@ -822,6 +866,7 @@ class PreparedStatement {
                        switch (field->type()->id())
                        {
                                case arrow::Type::INT8:
+                               case arrow::Type::INT16:
                                        pgTypes.push_back(INT2OID);
                                        break;
                                case arrow::Type::INT32:
@@ -1233,22 +1278,16 @@ class Executor : public WorkerProcessor {
        arrow::Status write()
        {
                SharedRingBufferOutputStream output(this, session_);
+               std::vector<PGArrowValueConverter> converters;
                std::vector<std::shared_ptr<arrow::Field>> fields;
                for (int i = 0; i < SPI_tuptable->tupdesc->natts; ++i)
                {
                        auto attribute = TupleDescAttr(SPI_tuptable->tupdesc, 
i);
-                       std::shared_ptr<arrow::DataType> type;
-                       switch (attribute->atttypid)
-                       {
-                               case INT4OID:
-                                       type = arrow::int32();
-                                       break;
-                               default:
-                                       return 
arrow::Status::NotImplemented("Unsupported PostgreSQL type: ",
-                                                                            
attribute->atttypid);
-                       }
-                       fields.push_back(
-                               arrow::field(NameStr(attribute->attname), type, 
!attribute->attnotnull));
+                       converters.emplace_back(attribute);
+                       const auto& converter = converters[converters.size() - 
1];
+                       ARROW_ASSIGN_OR_RAISE(auto type, 
converter.convert_type());
+                       fields.push_back(arrow::field(
+                               NameStr(attribute->attname), std::move(type), 
!attribute->attnotnull));
                }
                auto schema = arrow::schema(fields);
                ARROW_ASSIGN_OR_RAISE(
@@ -1293,16 +1332,15 @@ class Executor : public WorkerProcessor {
                                                           
SPI_tuptable->tupdesc,
                                                           iAttribute + 1,
                                                           &isNull);
+                               auto arrayBuilder = 
builder->GetField(iAttribute);
                                if (isNull)
                                {
-                                       auto arrayBuilder = 
builder->GetField(iAttribute);
                                        
ARROW_RETURN_NOT_OK(arrayBuilder->AppendNull());
                                }
                                else
                                {
-                                       auto arrayBuilder =
-                                               
builder->GetFieldAs<arrow::Int32Builder>(iAttribute);
-                                       
ARROW_RETURN_NOT_OK(arrayBuilder->Append(DatumGetInt32(datum)));
+                                       ARROW_RETURN_NOT_OK(
+                                               
converters[iAttribute].convert_value(arrayBuilder, datum));
                                }
                        }
 
diff --git a/test/test-flight-sql.rb b/test/test-flight-sql.rb
index c4ee2e5..a55243f 100644
--- a/test/test-flight-sql.rb
+++ b/test/test-flight-sql.rb
@@ -29,13 +29,18 @@ class FlightSQLTest < Test::Unit::TestCase
     flight_client.authenticate_basic(user, password, @options)
   end
 
-  def test_select_int32
-    info = flight_sql_client.execute("SELECT 1 AS value", @options)
-    assert_equal(Arrow::Schema.new(value: :int32),
+  data("int16", ["smallint", Arrow::Int16Array, -2])
+  data("int32", ["integer",  Arrow::Int32Array, -2])
+  def test_select_type
+    pg_type, array_class, value = data
+    values = array_class.new([value])
+    info = flight_sql_client.execute("SELECT #{value}::#{pg_type} AS value",
+                                     @options)
+    assert_equal(Arrow::Schema.new(value: values.value_data_type),
                  info.get_schema)
     endpoint = info.endpoints.first
     reader = flight_sql_client.do_get(endpoint.ticket, @options)
-    assert_equal(Arrow::Table.new(value: Arrow::Int32Array.new([1])),
+    assert_equal(Arrow::Table.new(value: values),
                  reader.read_all)
   end
 
@@ -75,8 +80,9 @@ SELECT * FROM data
     RESULT
   end
 
-  data("int8", ["smallint", Arrow::Int8Array, [1, -2, 3]])
-  data("int32", ["integer", Arrow::Int32Array, [1, -2, 3]])
+  data("int8",  ["smallint", Arrow::Int8Array,  [1, -2, 3]])
+  data("int16", ["smallint", Arrow::Int16Array, [1, -2, 3]])
+  data("int32", ["integer",  Arrow::Int32Array, [1, -2, 3]])
   def test_insert_type
     unless flight_sql_client.respond_to?(:prepare)
       omit("red-arrow-flight-sql 14.0.0 or later is required")

Reply via email to