This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-flight-sql-postgresql.git


The following commit(s) were added to refs/heads/main by this push:
     new 46a2983  Add support for UInt16 (#68)
46a2983 is described below

commit 46a2983a5983fa1df332f18afd540a724aacce1f
Author: Sutou Kouhei <[email protected]>
AuthorDate: Tue Aug 22 15:30:11 2023 +0900

    Add support for UInt16 (#68)
    
    Closes GH-52
---
 src/afs.cc              | 72 +++++++++++++++++++++++++++++++++++++------------
 test/test-flight-sql.rb | 11 ++++----
 2 files changed, 61 insertions(+), 22 deletions(-)

diff --git a/src/afs.cc b/src/afs.cc
index 1859130..a31505a 100644
--- a/src/afs.cc
+++ b/src/afs.cc
@@ -725,6 +725,52 @@ class WorkerProcessor : public Processor {
        dshash_table* sessions_;
 };
 
+class ArrowPGTypeConverter : public arrow::TypeVisitor {  // Arrow field type -> PostgreSQL parameter type OID
+   public:
+       explicit ArrowPGTypeConverter() : oid_(InvalidOid) {}
+       // OID chosen by the last successful Visit(); types without an overload here fall to arrow::TypeVisitor's default Visit().
+       Oid oid() const { return oid_; }
+
+       arrow::Status Visit(const arrow::Int8Type& type)
+       {
+               oid_ = INT2OID;  // PostgreSQL has no 1-byte integer type; smallint holds -128..127
+               return arrow::Status::OK();
+       }
+
+       arrow::Status Visit(const arrow::UInt8Type& type)
+       {
+               oid_ = INT2OID;  // 0..255 fits in smallint
+               return arrow::Status::OK();
+       }
+
+       arrow::Status Visit(const arrow::Int16Type& type)
+       {
+               oid_ = INT2OID;
+               return arrow::Status::OK();
+       }
+
+       arrow::Status Visit(const arrow::UInt16Type& type)
+       {
+               oid_ = INT4OID;  // 0..65535 overflows smallint (max 32767), so widen to integer
+               return arrow::Status::OK();
+       }
+
+       arrow::Status Visit(const arrow::Int32Type& type)
+       {
+               oid_ = INT4OID;
+               return arrow::Status::OK();
+       }
+
+       arrow::Status Visit(const arrow::Int64Type& type)
+       {
+               oid_ = INT8OID;
+               return arrow::Status::OK();
+       }
+
+   private:
+       Oid oid_;
+};
+
 class ArrowPGValueConverter : public arrow::ArrayVisitor {
    public:
        explicit ArrowPGValueConverter(int64_t i_row, Datum& datum)
@@ -750,6 +796,12 @@ class ArrowPGValueConverter : public arrow::ArrayVisitor {
                return arrow::Status::OK();
        }
 
+       arrow::Status Visit(const arrow::UInt16Array& array)
+       {
+               datum_ = Int32GetDatum(array.Value(i_row_));  // widen: every uint16 value is representable as int32
+               return arrow::Status::OK();
+       }
+
        arrow::Status Visit(const arrow::Int32Array& array)
        {
                datum_ = Int32GetDatum(array.Value(i_row_));
@@ -878,25 +930,11 @@ class PreparedStatement {
                const std::shared_ptr<arrow::Schema>& schema)
        {
                std::vector<Oid> pgTypes;
+               ArrowPGTypeConverter converter;
                for (const auto& field : schema->fields())
                {
-                       switch (field->type()->id())
-                       {
-                               case arrow::Type::INT8:
-                               case arrow::Type::UINT8:
-                               case arrow::Type::INT16:
-                                       pgTypes.push_back(INT2OID);
-                                       break;
-                               case arrow::Type::INT32:
-                                       pgTypes.push_back(INT4OID);
-                                       break;
-                               case arrow::Type::INT64:
-                                       pgTypes.push_back(INT8OID);
-                                       break;
-                               default:
-                                       return arrow::Status::NotImplemented(
-                                               "Unsupported Apache Arrow type: 
", field->type()->name());
-                       }
+                       ARROW_RETURN_NOT_OK(field->type()->Accept(&converter));
+                       pgTypes.push_back(converter.oid());
                }
                return std::move(pgTypes);
        }
diff --git a/test/test-flight-sql.rb b/test/test-flight-sql.rb
index 9a58a0b..6431f8e 100644
--- a/test/test-flight-sql.rb
+++ b/test/test-flight-sql.rb
@@ -81,11 +81,12 @@ SELECT * FROM data
     RESULT
   end
 
-  data("int8",  ["smallint", Arrow::Int8Array,  [1, -2, 3]])
-  data("int16", ["smallint", Arrow::Int16Array, [1, -2, 3]])
-  data("int32", ["integer",  Arrow::Int32Array, [1, -2, 3]])
-  data("int64", ["bigint",   Arrow::Int64Array, [1, -2, 3]])
-  data("uint8", ["smallint", Arrow::UInt8Array, [1,  2, 3]])
+  data("int8",   ["smallint", Arrow::Int8Array,   [1, -2, 3]])
+  data("int16",  ["smallint", Arrow::Int16Array,  [1, -2, 3]])
+  data("int32",  ["integer",  Arrow::Int32Array,  [1, -2, 3]])
+  data("int64",  ["bigint",   Arrow::Int64Array,  [1, -2, 3]])
+  data("uint8",  ["smallint", Arrow::UInt8Array,  [1,  2, 3]])
+  data("uint16", ["smallint", Arrow::UInt16Array, [1,  2, 3]])
   def test_insert_type
     unless flight_sql_client.respond_to?(:prepare)
       omit("red-arrow-flight-sql 14.0.0 or later is required")

Reply via email to