lidavidm commented on code in PR #1288:
URL: https://github.com/apache/arrow-adbc/pull/1288#discussion_r1419666654


##########
c/driver/postgresql/postgres_copy_reader.h:
##########
@@ -1217,6 +1217,142 @@ class PostgresCopyIntervalFieldWriter : public 
PostgresCopyFieldWriter {
   }
 };
 
+// Inspiration for this taken from get_str_from_var in the pg source
+// src/backend/utils/adt/numeric.c
+template<enum ArrowType T>
+class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter {
+public:
+  PostgresCopyNumericFieldWriter<T>(int32_t precision, int32_t scale) :
+    precision_{precision}, scale_{scale} {}
+
+  ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) 
override {
+    struct ArrowDecimal decimal;
+    ArrowDecimalInit(&decimal, bitwidth_, precision_, scale_);
+    ArrowArrayViewGetDecimalUnsafe(array_view_, index, &decimal);
+
+    const int16_t sign = ArrowDecimalSign(&decimal) > 0 ? kNumericPos : 
kNumericNeg;
+
+    // Number of decimal digits per Postgres digit
+    constexpr int kDecDigits = 4;
+    std::vector<int16_t> pg_digits;
+    int16_t weight = -(scale_ / kDecDigits);
+    int16_t dscale = scale_;
+    bool seen_decimal = scale_ == 0;
+    bool truncating_trailing_zeros = true;
+
+    const std::string decimal_string = DecimalToString<bitwidth_>(&decimal);

Review Comment:
   Another micro-optimization might be to use a stack-allocated char array 
here and have DecimalToString just fill that array and return the index of the 
first digit; this avoids allocating a std::string in each iteration.



##########
c/driver/postgresql/postgres_copy_reader.h:
##########
@@ -1217,6 +1217,142 @@ class PostgresCopyIntervalFieldWriter : public 
PostgresCopyFieldWriter {
   }
 };
 
+// Inspiration for this taken from get_str_from_var in the pg source
+// src/backend/utils/adt/numeric.c
+template<enum ArrowType T>
+class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter {
+public:
+  PostgresCopyNumericFieldWriter<T>(int32_t precision, int32_t scale) :
+    precision_{precision}, scale_{scale} {}
+
+  ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) 
override {
+    struct ArrowDecimal decimal;
+    ArrowDecimalInit(&decimal, bitwidth_, precision_, scale_);
+    ArrowArrayViewGetDecimalUnsafe(array_view_, index, &decimal);
+
+    const int16_t sign = ArrowDecimalSign(&decimal) > 0 ? kNumericPos : 
kNumericNeg;
+
+    // Number of decimal digits per Postgres digit
+    constexpr int kDecDigits = 4;
+    std::vector<int16_t> pg_digits;
+    int16_t weight = -(scale_ / kDecDigits);
+    int16_t dscale = scale_;
+    bool seen_decimal = scale_ == 0;
+    bool truncating_trailing_zeros = true;
+
+    const std::string decimal_string = DecimalToString<bitwidth_>(&decimal);
+    int digits_remaining = decimal_string.size();
+    do {
+      const int start_pos = digits_remaining < kDecDigits ?
+        0 : digits_remaining - kDecDigits;
+      const size_t len = digits_remaining < 4 ? digits_remaining : kDecDigits;
+      std::string substr{decimal_string.substr(start_pos, len)};
+      int16_t val = static_cast<int16_t>(std::stoi(substr.data()));
+
+      if (val == 0) {
+        if (!seen_decimal && truncating_trailing_zeros) {
+          dscale -= kDecDigits;
+        }
+      } else {
+        pg_digits.insert(pg_digits.begin(), val);
+        if (!seen_decimal && truncating_trailing_zeros) {
+          if (val % 1000 == 0) {
+            dscale -= 3;
+          } else if (val % 100 == 0) {
+            dscale -= 2;
+          } else if (val % 10 == 0) {
+            dscale -= 1;
+          }
+        }
+        truncating_trailing_zeros = false;
+      }
+      digits_remaining -= kDecDigits;
+      if (digits_remaining <= 0) {
+        break;
+      }
+      weight++;
+
+      if (start_pos <= static_cast<int>(decimal_string.size()) - scale_) {
+        seen_decimal = true;
+      }
+    } while (true);
+
+    int16_t ndigits = pg_digits.size();
+    int32_t field_size_bytes = sizeof(ndigits)
+      + sizeof(weight)
+      + sizeof(sign)
+      + sizeof(dscale)
+      + ndigits * sizeof(int16_t);
+
+    NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, field_size_bytes, 
error));
+    NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, ndigits, error));
+    NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, weight, error));
+    NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, sign, error));
+    NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, dscale, error));
+
+    for (auto pg_digit : pg_digits) {
+      NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, pg_digit, error));

Review Comment:
   Presumably you could check once (rather than once per WriteChecked call) and then memcpy the digits over.



##########
c/driver/postgresql/postgres_copy_reader.h:
##########
@@ -1217,6 +1217,142 @@ class PostgresCopyIntervalFieldWriter : public 
PostgresCopyFieldWriter {
   }
 };
 
+// Inspiration for this taken from get_str_from_var in the pg source
+// src/backend/utils/adt/numeric.c
+template<enum ArrowType T>
+class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter {
+public:
+  PostgresCopyNumericFieldWriter<T>(int32_t precision, int32_t scale) :
+    precision_{precision}, scale_{scale} {}
+
+  ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) 
override {
+    struct ArrowDecimal decimal;
+    ArrowDecimalInit(&decimal, bitwidth_, precision_, scale_);
+    ArrowArrayViewGetDecimalUnsafe(array_view_, index, &decimal);
+
+    const int16_t sign = ArrowDecimalSign(&decimal) > 0 ? kNumericPos : 
kNumericNeg;
+
+    // Number of decimal digits per Postgres digit
+    constexpr int kDecDigits = 4;
+    std::vector<int16_t> pg_digits;
+    int16_t weight = -(scale_ / kDecDigits);
+    int16_t dscale = scale_;
+    bool seen_decimal = scale_ == 0;
+    bool truncating_trailing_zeros = true;
+
+    const std::string decimal_string = DecimalToString<bitwidth_>(&decimal);
+    int digits_remaining = decimal_string.size();
+    do {
+      const int start_pos = digits_remaining < kDecDigits ?
+        0 : digits_remaining - kDecDigits;
+      const size_t len = digits_remaining < 4 ? digits_remaining : kDecDigits;
+      std::string substr{decimal_string.substr(start_pos, len)};
+      int16_t val = static_cast<int16_t>(std::stoi(substr.data()));
+
+      if (val == 0) {
+        if (!seen_decimal && truncating_trailing_zeros) {
+          dscale -= kDecDigits;
+        }
+      } else {
+        pg_digits.insert(pg_digits.begin(), val);
+        if (!seen_decimal && truncating_trailing_zeros) {
+          if (val % 1000 == 0) {
+            dscale -= 3;
+          } else if (val % 100 == 0) {
+            dscale -= 2;
+          } else if (val % 10 == 0) {
+            dscale -= 1;
+          }
+        }
+        truncating_trailing_zeros = false;
+      }
+      digits_remaining -= kDecDigits;
+      if (digits_remaining <= 0) {
+        break;
+      }
+      weight++;
+
+      if (start_pos <= static_cast<int>(decimal_string.size()) - scale_) {
+        seen_decimal = true;
+      }
+    } while (true);
+
+    int16_t ndigits = pg_digits.size();
+    int32_t field_size_bytes = sizeof(ndigits)
+      + sizeof(weight)
+      + sizeof(sign)
+      + sizeof(dscale)
+      + ndigits * sizeof(int16_t);
+
+    NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, field_size_bytes, 
error));
+    NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, ndigits, error));
+    NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, weight, error));
+    NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, sign, error));
+    NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, dscale, error));
+
+    for (auto pg_digit : pg_digits) {
+      NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, pg_digit, error));

Review Comment:
   (again, possibly the compiler already does this)



##########
c/driver/postgresql/postgres_copy_reader.h:
##########
@@ -1217,6 +1217,142 @@ class PostgresCopyIntervalFieldWriter : public 
PostgresCopyFieldWriter {
   }
 };
 
+// Inspiration for this taken from get_str_from_var in the pg source
+// src/backend/utils/adt/numeric.c
+template<enum ArrowType T>
+class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter {
+public:
+  PostgresCopyNumericFieldWriter<T>(int32_t precision, int32_t scale) :
+    precision_{precision}, scale_{scale} {}
+
+  ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) 
override {
+    struct ArrowDecimal decimal;
+    ArrowDecimalInit(&decimal, bitwidth_, precision_, scale_);
+    ArrowArrayViewGetDecimalUnsafe(array_view_, index, &decimal);
+
+    const int16_t sign = ArrowDecimalSign(&decimal) > 0 ? kNumericPos : 
kNumericNeg;
+
+    // Number of decimal digits per Postgres digit
+    constexpr int kDecDigits = 4;
+    std::vector<int16_t> pg_digits;
+    int16_t weight = -(scale_ / kDecDigits);
+    int16_t dscale = scale_;
+    bool seen_decimal = scale_ == 0;
+    bool truncating_trailing_zeros = true;
+
+    const std::string decimal_string = DecimalToString<bitwidth_>(&decimal);
+    int digits_remaining = decimal_string.size();
+    do {
+      const int start_pos = digits_remaining < kDecDigits ?
+        0 : digits_remaining - kDecDigits;
+      const size_t len = digits_remaining < 4 ? digits_remaining : kDecDigits;
+      std::string substr{decimal_string.substr(start_pos, len)};

Review Comment:
   C++17 would let us use std::string_view to avoid the extra allocation; for 
now, we could track the indices manually here to avoid it explicitly.



##########
c/driver/postgresql/postgres_copy_reader.h:
##########
@@ -1217,6 +1217,142 @@ class PostgresCopyIntervalFieldWriter : public 
PostgresCopyFieldWriter {
   }
 };
 
+// Inspiration for this taken from get_str_from_var in the pg source
+// src/backend/utils/adt/numeric.c
+template<enum ArrowType T>
+class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter {
+public:
+  PostgresCopyNumericFieldWriter<T>(int32_t precision, int32_t scale) :
+    precision_{precision}, scale_{scale} {}
+
+  ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) 
override {
+    struct ArrowDecimal decimal;
+    ArrowDecimalInit(&decimal, bitwidth_, precision_, scale_);
+    ArrowArrayViewGetDecimalUnsafe(array_view_, index, &decimal);
+
+    const int16_t sign = ArrowDecimalSign(&decimal) > 0 ? kNumericPos : 
kNumericNeg;
+
+    // Number of decimal digits per Postgres digit
+    constexpr int kDecDigits = 4;
+    std::vector<int16_t> pg_digits;
+    int16_t weight = -(scale_ / kDecDigits);
+    int16_t dscale = scale_;
+    bool seen_decimal = scale_ == 0;
+    bool truncating_trailing_zeros = true;
+
+    const std::string decimal_string = DecimalToString<bitwidth_>(&decimal);

Review Comment:
   (though it possibly all gets inlined and optimized away anyway)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to