This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-go.git
The following commit(s) were added to refs/heads/main by this push:
new f17963e GH-120: Add initial Decimal32/Decimal64 implementation (#121)
f17963e is described below
commit f17963ea611201e5aa9af755f99a3c1c9aeaaead
Author: Matt Topol <[email protected]>
AuthorDate: Fri Sep 13 10:36:10 2024 -0400
GH-120: Add initial Decimal32/Decimal64 implementation (#121)
Fix GH-120
### Rationale for this change
Extending decimal support beyond Decimal128/256 to bit widths of 32 and 64
improves interoperability with other libraries and utilities that already
support these types. It also opens up more opportunities for zero-copy
interaction with systems such as libcudf and various databases.
### What changes are included in this PR?
This PR contains the basic Go implementations for Decimal32/Decimal64
types, arrays, builders and scalars. It also includes the minimum changes
necessary to get everything compiling and tests passing, without
extending the Acero kernels and Parquet handling (both of which will be
handled in follow-up PRs).
### Are these changes tested?
Yes, tests were extended where applicable to add decimal32/decimal64
cases.
---
.gitignore | 1 +
arrow/array/array.go | 2 +
arrow/array/array_test.go | 2 +
arrow/array/builder.go | 8 +
arrow/array/compare.go | 20 +-
arrow/array/decimal.go | 432 +++++++++++++++++++
arrow/array/decimal128.go | 368 ----------------
arrow/array/decimal256.go | 368 ----------------
arrow/array/dictionary.go | 66 ++-
arrow/array/numeric.gen.go | 17 +
arrow/cdata/cdata.go | 13 +-
arrow/cdata/cdata_exports.go | 4 +
arrow/cdata/cdata_test.go | 2 +-
arrow/datatype.go | 24 +-
arrow/datatype_fixedwidth.go | 89 +++-
arrow/datatype_fixedwidth_test.go | 88 ++++
arrow/decimal/decimal.go | 473 +++++++++++++++++++++
arrow/decimal/decimal_test.go | 470 ++++++++++++++++++++
arrow/decimal/traits.go | 78 ++++
arrow/decimal128/decimal128.go | 10 +
arrow/decimal256/decimal256.go | 10 +
arrow/internal/arrjson/arrjson.go | 84 +++-
arrow/internal/flatbuf/Decimal.go | 24 +-
arrow/ipc/file_reader.go | 2 +-
arrow/ipc/metadata.go | 20 +-
arrow/type_string.go | 6 +-
arrow/type_traits.go | 5 +-
arrow/type_traits_decimal128.go | 14 +-
arrow/type_traits_decimal256.go | 14 +-
...aits_decimal128.go => type_traits_decimal32.go} | 27 +-
...aits_decimal128.go => type_traits_decimal64.go} | 27 +-
arrow/type_traits_test.go | 90 ++++
32 files changed, 2028 insertions(+), 830 deletions(-)
diff --git a/.gitignore b/.gitignore
index 06f4070..5523236 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
+.vscode
/apache-arrow-go.tar.gz
/dev/release/apache-rat-*.jar
/dev/release/filtered_rat.txt
diff --git a/arrow/array/array.go b/arrow/array/array.go
index 586d876..6e281a4 100644
--- a/arrow/array/array.go
+++ b/arrow/array/array.go
@@ -160,6 +160,8 @@ func init() {
arrow.TIME64: func(data arrow.ArrayData)
arrow.Array { return NewTime64Data(data) },
arrow.INTERVAL_MONTHS: func(data arrow.ArrayData)
arrow.Array { return NewMonthIntervalData(data) },
arrow.INTERVAL_DAY_TIME: func(data arrow.ArrayData)
arrow.Array { return NewDayTimeIntervalData(data) },
+ arrow.DECIMAL32: func(data arrow.ArrayData)
arrow.Array { return NewDecimal32Data(data) },
+ arrow.DECIMAL64: func(data arrow.ArrayData)
arrow.Array { return NewDecimal64Data(data) },
arrow.DECIMAL128: func(data arrow.ArrayData)
arrow.Array { return NewDecimal128Data(data) },
arrow.DECIMAL256: func(data arrow.ArrayData)
arrow.Array { return NewDecimal256Data(data) },
arrow.LIST: func(data arrow.ArrayData)
arrow.Array { return NewListData(data) },
diff --git a/arrow/array/array_test.go b/arrow/array/array_test.go
index 203c62e..9509e31 100644
--- a/arrow/array/array_test.go
+++ b/arrow/array/array_test.go
@@ -75,6 +75,8 @@ func TestMakeFromData(t *testing.T) {
{name: "time64", d: &testDataType{arrow.TIME64}},
{name: "month_interval", d:
arrow.FixedWidthTypes.MonthInterval},
{name: "day_time_interval", d:
arrow.FixedWidthTypes.DayTimeInterval},
+ {name: "decimal32", d: &testDataType{arrow.DECIMAL32}},
+ {name: "decimal64", d: &testDataType{arrow.DECIMAL64}},
{name: "decimal128", d: &testDataType{arrow.DECIMAL128}},
{name: "decimal256", d: &testDataType{arrow.DECIMAL256}},
{name: "month_day_nano_interval", d:
arrow.FixedWidthTypes.MonthDayNanoInterval},
diff --git a/arrow/array/builder.go b/arrow/array/builder.go
index 108b615..a2a40d4 100644
--- a/arrow/array/builder.go
+++ b/arrow/array/builder.go
@@ -313,6 +313,14 @@ func NewBuilder(mem memory.Allocator, dtype
arrow.DataType) Builder {
return NewDayTimeIntervalBuilder(mem)
case arrow.INTERVAL_MONTH_DAY_NANO:
return NewMonthDayNanoIntervalBuilder(mem)
+ case arrow.DECIMAL32:
+ if typ, ok := dtype.(*arrow.Decimal32Type); ok {
+ return NewDecimal32Builder(mem, typ)
+ }
+ case arrow.DECIMAL64:
+ if typ, ok := dtype.(*arrow.Decimal64Type); ok {
+ return NewDecimal64Builder(mem, typ)
+ }
case arrow.DECIMAL128:
if typ, ok := dtype.(*arrow.Decimal128Type); ok {
return NewDecimal128Builder(mem, typ)
diff --git a/arrow/array/compare.go b/arrow/array/compare.go
index 4117880..ad3a50b 100644
--- a/arrow/array/compare.go
+++ b/arrow/array/compare.go
@@ -271,12 +271,18 @@ func Equal(left, right arrow.Array) bool {
case *Float64:
r := right.(*Float64)
return arrayEqualFloat64(l, r)
+ case *Decimal32:
+ r := right.(*Decimal32)
+ return arrayEqualDecimal(l, r)
+ case *Decimal64:
+ r := right.(*Decimal64)
+ return arrayEqualDecimal(l, r)
case *Decimal128:
r := right.(*Decimal128)
- return arrayEqualDecimal128(l, r)
+ return arrayEqualDecimal(l, r)
case *Decimal256:
r := right.(*Decimal256)
- return arrayEqualDecimal256(l, r)
+ return arrayEqualDecimal(l, r)
case *Date32:
r := right.(*Date32)
return arrayEqualDate32(l, r)
@@ -527,12 +533,18 @@ func arrayApproxEqual(left, right arrow.Array, opt
equalOption) bool {
case *Float64:
r := right.(*Float64)
return arrayApproxEqualFloat64(l, r, opt)
+ case *Decimal32:
+ r := right.(*Decimal32)
+ return arrayEqualDecimal(l, r)
+ case *Decimal64:
+ r := right.(*Decimal64)
+ return arrayEqualDecimal(l, r)
case *Decimal128:
r := right.(*Decimal128)
- return arrayEqualDecimal128(l, r)
+ return arrayEqualDecimal(l, r)
case *Decimal256:
r := right.(*Decimal256)
- return arrayEqualDecimal256(l, r)
+ return arrayEqualDecimal(l, r)
case *Date32:
r := right.(*Date32)
return arrayEqualDate32(l, r)
diff --git a/arrow/array/decimal.go b/arrow/array/decimal.go
new file mode 100644
index 0000000..1a9d61c
--- /dev/null
+++ b/arrow/array/decimal.go
@@ -0,0 +1,432 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package array
+
+import (
+ "bytes"
+ "fmt"
+ "reflect"
+ "strings"
+ "sync/atomic"
+
+ "github.com/apache/arrow-go/v18/arrow"
+ "github.com/apache/arrow-go/v18/arrow/bitutil"
+ "github.com/apache/arrow-go/v18/arrow/decimal"
+ "github.com/apache/arrow-go/v18/arrow/internal/debug"
+ "github.com/apache/arrow-go/v18/arrow/memory"
+ "github.com/apache/arrow-go/v18/internal/json"
+)
+
+type baseDecimal[T interface {
+ decimal.DecimalTypes
+ decimal.Num[T]
+}] struct {
+ array
+
+ values []T
+}
+
+func newDecimalData[T interface {
+ decimal.DecimalTypes
+ decimal.Num[T]
+}](data arrow.ArrayData) *baseDecimal[T] {
+ a := &baseDecimal[T]{}
+ a.refCount = 1
+ a.setData(data.(*Data))
+ return a
+}
+
+func (a *baseDecimal[T]) Value(i int) T { return a.values[i] }
+
+func (a *baseDecimal[T]) ValueStr(i int) string {
+ if a.IsNull(i) {
+ return NullValueStr
+ }
+ return a.GetOneForMarshal(i).(string)
+}
+
+func (a *baseDecimal[T]) Values() []T { return a.values }
+
+func (a *baseDecimal[T]) String() string {
+ o := new(strings.Builder)
+ o.WriteString("[")
+ for i := 0; i < a.Len(); i++ {
+ if i > 0 {
+ fmt.Fprintf(o, " ")
+ }
+ switch {
+ case a.IsNull(i):
+ o.WriteString(NullValueStr)
+ default:
+ fmt.Fprintf(o, "%v", a.Value(i))
+ }
+ }
+ o.WriteString("]")
+ return o.String()
+}
+
+func (a *baseDecimal[T]) setData(data *Data) {
+ a.array.setData(data)
+ vals := data.buffers[1]
+ if vals != nil {
+ a.values = arrow.GetData[T](vals.Bytes())
+ beg := a.array.data.offset
+ end := beg + a.array.data.length
+ a.values = a.values[beg:end]
+ }
+}
+
+func (a *baseDecimal[T]) GetOneForMarshal(i int) any {
+ if a.IsNull(i) {
+ return nil
+ }
+
+ typ := a.DataType().(arrow.DecimalType)
+ n, scale := a.Value(i), typ.GetScale()
+ return n.ToBigFloat(scale).Text('g', int(typ.GetPrecision()))
+}
+
+func (a *baseDecimal[T]) MarshalJSON() ([]byte, error) {
+ vals := make([]any, a.Len())
+ for i := 0; i < a.Len(); i++ {
+ vals[i] = a.GetOneForMarshal(i)
+ }
+ return json.Marshal(vals)
+}
+
+func arrayEqualDecimal[T interface {
+ decimal.DecimalTypes
+ decimal.Num[T]
+}](left, right *baseDecimal[T]) bool {
+ for i := 0; i < left.Len(); i++ {
+ if left.IsNull(i) {
+ continue
+ }
+
+ if left.Value(i) != right.Value(i) {
+ return false
+ }
+ }
+ return true
+}
+
+type Decimal32 = baseDecimal[decimal.Decimal32]
+
+func NewDecimal32Data(data arrow.ArrayData) *Decimal32 {
+ return newDecimalData[decimal.Decimal32](data)
+}
+
+type Decimal64 = baseDecimal[decimal.Decimal64]
+
+func NewDecimal64Data(data arrow.ArrayData) *Decimal64 {
+ return newDecimalData[decimal.Decimal64](data)
+}
+
+type Decimal128 = baseDecimal[decimal.Decimal128]
+
+func NewDecimal128Data(data arrow.ArrayData) *Decimal128 {
+ return newDecimalData[decimal.Decimal128](data)
+}
+
+type Decimal256 = baseDecimal[decimal.Decimal256]
+
+func NewDecimal256Data(data arrow.ArrayData) *Decimal256 {
+ return newDecimalData[decimal.Decimal256](data)
+}
+
+type Decimal32Builder = baseDecimalBuilder[decimal.Decimal32]
+type Decimal64Builder = baseDecimalBuilder[decimal.Decimal64]
+type Decimal128Builder struct {
+ *baseDecimalBuilder[decimal.Decimal128]
+}
+
+func (b *Decimal128Builder) NewDecimal128Array() *Decimal128 {
+ return b.NewDecimalArray()
+}
+
+type Decimal256Builder struct {
+ *baseDecimalBuilder[decimal.Decimal256]
+}
+
+func (b *Decimal256Builder) NewDecimal256Array() *Decimal256 {
+ return b.NewDecimalArray()
+}
+
+type baseDecimalBuilder[T interface {
+ decimal.DecimalTypes
+ decimal.Num[T]
+}] struct {
+ builder
+ traits decimal.Traits[T]
+
+ dtype arrow.DecimalType
+ data *memory.Buffer
+ rawData []T
+}
+
+func newDecimalBuilder[T interface {
+ decimal.DecimalTypes
+ decimal.Num[T]
+}, DT arrow.DecimalType](mem memory.Allocator, dtype DT)
*baseDecimalBuilder[T] {
+ return &baseDecimalBuilder[T]{
+ builder: builder{refCount: 1, mem: mem},
+ dtype: dtype,
+ }
+}
+
+func (b *baseDecimalBuilder[T]) Type() arrow.DataType { return b.dtype }
+
+func (b *baseDecimalBuilder[T]) Release() {
+ debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases")
+
+ if atomic.AddInt64(&b.refCount, -1) == 0 {
+ if b.nullBitmap != nil {
+ b.nullBitmap.Release()
+ b.nullBitmap = nil
+ }
+ if b.data != nil {
+ b.data.Release()
+ b.data, b.rawData = nil, nil
+ }
+ }
+}
+
+func (b *baseDecimalBuilder[T]) Append(v T) {
+ b.Reserve(1)
+ b.UnsafeAppend(v)
+}
+
+func (b *baseDecimalBuilder[T]) UnsafeAppend(v T) {
+ bitutil.SetBit(b.nullBitmap.Bytes(), b.length)
+ b.rawData[b.length] = v
+ b.length++
+}
+
+func (b *baseDecimalBuilder[T]) AppendNull() {
+ b.Reserve(1)
+ b.UnsafeAppendBoolToBitmap(false)
+}
+
+func (b *baseDecimalBuilder[T]) AppendNulls(n int) {
+ for i := 0; i < n; i++ {
+ b.AppendNull()
+ }
+}
+
+func (b *baseDecimalBuilder[T]) AppendEmptyValue() {
+ var empty T
+ b.Append(empty)
+}
+
+func (b *baseDecimalBuilder[T]) AppendEmptyValues(n int) {
+ for i := 0; i < n; i++ {
+ b.AppendEmptyValue()
+ }
+}
+
+func (b *baseDecimalBuilder[T]) UnsafeAppendBoolToBitmap(isValid bool) {
+ if isValid {
+ bitutil.SetBit(b.nullBitmap.Bytes(), b.length)
+ } else {
+ b.nulls++
+ }
+ b.length++
+}
+
+func (b *baseDecimalBuilder[T]) AppendValues(v []T, valid []bool) {
+ if len(v) != len(valid) && len(valid) != 0 {
+ panic("len(v) != len(valid) && len(valid) != 0")
+ }
+
+ if len(v) == 0 {
+ return
+ }
+
+ b.Reserve(len(v))
+ if len(v) > 0 {
+ copy(b.rawData[b.length:], v)
+ }
+ b.builder.unsafeAppendBoolsToBitmap(valid, len(v))
+}
+
+func (b *baseDecimalBuilder[T]) init(capacity int) {
+ b.builder.init(capacity)
+
+ b.data = memory.NewResizableBuffer(b.mem)
+ bytesN := int(reflect.TypeFor[T]().Size()) * capacity
+ b.data.Resize(bytesN)
+ b.rawData = arrow.GetData[T](b.data.Bytes())
+}
+
+func (b *baseDecimalBuilder[T]) Reserve(n int) {
+ b.builder.reserve(n, b.Resize)
+}
+
+func (b *baseDecimalBuilder[T]) Resize(n int) {
+ nBuilder := n
+ if n < minBuilderCapacity {
+ n = minBuilderCapacity
+ }
+
+ if b.capacity == 0 {
+ b.init(n)
+ } else {
+ b.builder.resize(nBuilder, b.init)
+ b.data.Resize(b.traits.BytesRequired(n))
+ b.rawData = arrow.GetData[T](b.data.Bytes())
+ }
+}
+
+func (b *baseDecimalBuilder[T]) NewDecimalArray() (a *baseDecimal[T]) {
+ data := b.newData()
+ a = newDecimalData[T](data)
+ data.Release()
+ return
+}
+
+func (b *baseDecimalBuilder[T]) NewArray() arrow.Array {
+ return b.NewDecimalArray()
+}
+
+func (b *baseDecimalBuilder[T]) newData() (data *Data) {
+ bytesRequired := b.traits.BytesRequired(b.length)
+ if bytesRequired > 0 && bytesRequired < b.data.Len() {
+ // trim buffers
+ b.data.Resize(bytesRequired)
+ }
+ data = NewData(b.dtype, b.length, []*memory.Buffer{b.nullBitmap,
b.data}, nil, b.nulls, 0)
+ b.reset()
+
+ if b.data != nil {
+ b.data.Release()
+ b.data, b.rawData = nil, nil
+ }
+
+ return
+}
+
+func (b *baseDecimalBuilder[T]) AppendValueFromString(s string) error {
+ if s == NullValueStr {
+ b.AppendNull()
+ return nil
+ }
+
+ val, err := b.traits.FromString(s, b.dtype.GetPrecision(),
b.dtype.GetScale())
+ if err != nil {
+ b.AppendNull()
+ return err
+ }
+ b.Append(val)
+ return nil
+}
+
+func (b *baseDecimalBuilder[T]) UnmarshalOne(dec *json.Decoder) error {
+ t, err := dec.Token()
+ if err != nil {
+ return err
+ }
+
+ var token T
+ switch v := t.(type) {
+ case float64:
+ token, err = b.traits.FromFloat64(v, b.dtype.GetPrecision(),
b.dtype.GetScale())
+ if err != nil {
+ return err
+ }
+ b.Append(token)
+ case string:
+ token, err = b.traits.FromString(v, b.dtype.GetPrecision(),
b.dtype.GetScale())
+ if err != nil {
+ return err
+ }
+ b.Append(token)
+ case json.Number:
+ token, err = b.traits.FromString(v.String(),
b.dtype.GetPrecision(), b.dtype.GetScale())
+ if err != nil {
+ return err
+ }
+ b.Append(token)
+ case nil:
+ b.AppendNull()
+ default:
+ return &json.UnmarshalTypeError{
+ Value: fmt.Sprint(t),
+ Type: reflect.TypeFor[T](),
+ Offset: dec.InputOffset(),
+ }
+ }
+
+ return nil
+}
+
+func (b *baseDecimalBuilder[T]) Unmarshal(dec *json.Decoder) error {
+ for dec.More() {
+ if err := b.UnmarshalOne(dec); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func (b *baseDecimalBuilder[T]) UnmarshalJSON(data []byte) error {
+ dec := json.NewDecoder(bytes.NewReader(data))
+ t, err := dec.Token()
+ if err != nil {
+ return err
+ }
+
+ if delim, ok := t.(json.Delim); !ok || delim != '[' {
+ return fmt.Errorf("decimal builder must unpack from json array,
found %s", delim)
+ }
+
+ return b.Unmarshal(dec)
+}
+
+func NewDecimal32Builder(mem memory.Allocator, dtype *arrow.Decimal32Type)
*Decimal32Builder {
+ b := newDecimalBuilder[decimal.Decimal32](mem, dtype)
+ b.traits = decimal.Dec32Traits
+ return b
+}
+
+func NewDecimal64Builder(mem memory.Allocator, dtype *arrow.Decimal64Type)
*Decimal64Builder {
+ b := newDecimalBuilder[decimal.Decimal64](mem, dtype)
+ b.traits = decimal.Dec64Traits
+ return b
+}
+
+func NewDecimal128Builder(mem memory.Allocator, dtype *arrow.Decimal128Type)
*Decimal128Builder {
+ b := newDecimalBuilder[decimal.Decimal128](mem, dtype)
+ b.traits = decimal.Dec128Traits
+ return &Decimal128Builder{b}
+}
+
+func NewDecimal256Builder(mem memory.Allocator, dtype *arrow.Decimal256Type)
*Decimal256Builder {
+ b := newDecimalBuilder[decimal.Decimal256](mem, dtype)
+ b.traits = decimal.Dec256Traits
+ return &Decimal256Builder{b}
+}
+
+var (
+ _ arrow.Array = (*Decimal32)(nil)
+ _ arrow.Array = (*Decimal64)(nil)
+ _ arrow.Array = (*Decimal128)(nil)
+ _ arrow.Array = (*Decimal256)(nil)
+ _ Builder = (*Decimal32Builder)(nil)
+ _ Builder = (*Decimal64Builder)(nil)
+ _ Builder = (*Decimal128Builder)(nil)
+ _ Builder = (*Decimal256Builder)(nil)
+)
diff --git a/arrow/array/decimal128.go b/arrow/array/decimal128.go
deleted file mode 100644
index c5861dc..0000000
--- a/arrow/array/decimal128.go
+++ /dev/null
@@ -1,368 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package array
-
-import (
- "bytes"
- "fmt"
- "math/big"
- "reflect"
- "strings"
- "sync/atomic"
-
- "github.com/apache/arrow-go/v18/arrow"
- "github.com/apache/arrow-go/v18/arrow/bitutil"
- "github.com/apache/arrow-go/v18/arrow/decimal128"
- "github.com/apache/arrow-go/v18/arrow/internal/debug"
- "github.com/apache/arrow-go/v18/arrow/memory"
- "github.com/apache/arrow-go/v18/internal/json"
-)
-
-// A type which represents an immutable sequence of 128-bit decimal values.
-type Decimal128 struct {
- array
-
- values []decimal128.Num
-}
-
-func NewDecimal128Data(data arrow.ArrayData) *Decimal128 {
- a := &Decimal128{}
- a.refCount = 1
- a.setData(data.(*Data))
- return a
-}
-
-func (a *Decimal128) Value(i int) decimal128.Num { return a.values[i] }
-
-func (a *Decimal128) ValueStr(i int) string {
- if a.IsNull(i) {
- return NullValueStr
- }
- return a.GetOneForMarshal(i).(string)
-}
-
-func (a *Decimal128) Values() []decimal128.Num { return a.values }
-
-func (a *Decimal128) String() string {
- o := new(strings.Builder)
- o.WriteString("[")
- for i := 0; i < a.Len(); i++ {
- if i > 0 {
- fmt.Fprintf(o, " ")
- }
- switch {
- case a.IsNull(i):
- o.WriteString(NullValueStr)
- default:
- fmt.Fprintf(o, "%v", a.Value(i))
- }
- }
- o.WriteString("]")
- return o.String()
-}
-
-func (a *Decimal128) setData(data *Data) {
- a.array.setData(data)
- vals := data.buffers[1]
- if vals != nil {
- a.values = arrow.Decimal128Traits.CastFromBytes(vals.Bytes())
- beg := a.array.data.offset
- end := beg + a.array.data.length
- a.values = a.values[beg:end]
- }
-}
-func (a *Decimal128) GetOneForMarshal(i int) interface{} {
- if a.IsNull(i) {
- return nil
- }
- typ := a.DataType().(*arrow.Decimal128Type)
- n := a.Value(i)
- scale := typ.Scale
- f := (&big.Float{}).SetInt(n.BigInt())
- if scale < 0 {
- f.SetPrec(128).Mul(f,
(&big.Float{}).SetInt(decimal128.GetScaleMultiplier(int(-scale)).BigInt()))
- } else {
- f.SetPrec(128).Quo(f,
(&big.Float{}).SetInt(decimal128.GetScaleMultiplier(int(scale)).BigInt()))
- }
- return f.Text('g', int(typ.Precision))
-}
-
-// ["1.23", ]
-func (a *Decimal128) MarshalJSON() ([]byte, error) {
- vals := make([]interface{}, a.Len())
- for i := 0; i < a.Len(); i++ {
- vals[i] = a.GetOneForMarshal(i)
- }
- return json.Marshal(vals)
-}
-
-func arrayEqualDecimal128(left, right *Decimal128) bool {
- for i := 0; i < left.Len(); i++ {
- if left.IsNull(i) {
- continue
- }
- if left.Value(i) != right.Value(i) {
- return false
- }
- }
- return true
-}
-
-type Decimal128Builder struct {
- builder
-
- dtype *arrow.Decimal128Type
- data *memory.Buffer
- rawData []decimal128.Num
-}
-
-func NewDecimal128Builder(mem memory.Allocator, dtype *arrow.Decimal128Type)
*Decimal128Builder {
- return &Decimal128Builder{
- builder: builder{refCount: 1, mem: mem},
- dtype: dtype,
- }
-}
-
-func (b *Decimal128Builder) Type() arrow.DataType { return b.dtype }
-
-// Release decreases the reference count by 1.
-// When the reference count goes to zero, the memory is freed.
-func (b *Decimal128Builder) Release() {
- debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases")
-
- if atomic.AddInt64(&b.refCount, -1) == 0 {
- if b.nullBitmap != nil {
- b.nullBitmap.Release()
- b.nullBitmap = nil
- }
- if b.data != nil {
- b.data.Release()
- b.data = nil
- b.rawData = nil
- }
- }
-}
-
-func (b *Decimal128Builder) Append(v decimal128.Num) {
- b.Reserve(1)
- b.UnsafeAppend(v)
-}
-
-func (b *Decimal128Builder) UnsafeAppend(v decimal128.Num) {
- bitutil.SetBit(b.nullBitmap.Bytes(), b.length)
- b.rawData[b.length] = v
- b.length++
-}
-
-func (b *Decimal128Builder) AppendNull() {
- b.Reserve(1)
- b.UnsafeAppendBoolToBitmap(false)
-}
-
-func (b *Decimal128Builder) AppendNulls(n int) {
- for i := 0; i < n; i++ {
- b.AppendNull()
- }
-}
-
-func (b *Decimal128Builder) AppendEmptyValue() {
- b.Append(decimal128.Num{})
-}
-
-func (b *Decimal128Builder) AppendEmptyValues(n int) {
- for i := 0; i < n; i++ {
- b.AppendEmptyValue()
- }
-}
-
-func (b *Decimal128Builder) UnsafeAppendBoolToBitmap(isValid bool) {
- if isValid {
- bitutil.SetBit(b.nullBitmap.Bytes(), b.length)
- } else {
- b.nulls++
- }
- b.length++
-}
-
-// AppendValues will append the values in the v slice. The valid slice
determines which values
-// in v are valid (not null). The valid slice must either be empty or be equal
in length to v. If empty,
-// all values in v are appended and considered valid.
-func (b *Decimal128Builder) AppendValues(v []decimal128.Num, valid []bool) {
- if len(v) != len(valid) && len(valid) != 0 {
- panic("len(v) != len(valid) && len(valid) != 0")
- }
-
- if len(v) == 0 {
- return
- }
-
- b.Reserve(len(v))
- if len(v) > 0 {
- arrow.Decimal128Traits.Copy(b.rawData[b.length:], v)
- }
- b.builder.unsafeAppendBoolsToBitmap(valid, len(v))
-}
-
-func (b *Decimal128Builder) init(capacity int) {
- b.builder.init(capacity)
-
- b.data = memory.NewResizableBuffer(b.mem)
- bytesN := arrow.Decimal128Traits.BytesRequired(capacity)
- b.data.Resize(bytesN)
- b.rawData = arrow.Decimal128Traits.CastFromBytes(b.data.Bytes())
-}
-
-// Reserve ensures there is enough space for appending n elements
-// by checking the capacity and calling Resize if necessary.
-func (b *Decimal128Builder) Reserve(n int) {
- b.builder.reserve(n, b.Resize)
-}
-
-// Resize adjusts the space allocated by b to n elements. If n is greater than
b.Cap(),
-// additional memory will be allocated. If n is smaller, the allocated memory
may reduced.
-func (b *Decimal128Builder) Resize(n int) {
- nBuilder := n
- if n < minBuilderCapacity {
- n = minBuilderCapacity
- }
-
- if b.capacity == 0 {
- b.init(n)
- } else {
- b.builder.resize(nBuilder, b.init)
- b.data.Resize(arrow.Decimal128Traits.BytesRequired(n))
- b.rawData = arrow.Decimal128Traits.CastFromBytes(b.data.Bytes())
- }
-}
-
-// NewArray creates a Decimal128 array from the memory buffers used by the
builder and resets the Decimal128Builder
-// so it can be used to build a new array.
-func (b *Decimal128Builder) NewArray() arrow.Array {
- return b.NewDecimal128Array()
-}
-
-// NewDecimal128Array creates a Decimal128 array from the memory buffers used
by the builder and resets the Decimal128Builder
-// so it can be used to build a new array.
-func (b *Decimal128Builder) NewDecimal128Array() (a *Decimal128) {
- data := b.newData()
- a = NewDecimal128Data(data)
- data.Release()
- return
-}
-
-func (b *Decimal128Builder) newData() (data *Data) {
- bytesRequired := arrow.Decimal128Traits.BytesRequired(b.length)
- if bytesRequired > 0 && bytesRequired < b.data.Len() {
- // trim buffers
- b.data.Resize(bytesRequired)
- }
- data = NewData(b.dtype, b.length, []*memory.Buffer{b.nullBitmap,
b.data}, nil, b.nulls, 0)
- b.reset()
-
- if b.data != nil {
- b.data.Release()
- b.data = nil
- b.rawData = nil
- }
-
- return
-}
-
-func (b *Decimal128Builder) AppendValueFromString(s string) error {
- if s == NullValueStr {
- b.AppendNull()
- return nil
- }
- val, err := decimal128.FromString(s, b.dtype.Precision, b.dtype.Scale)
- if err != nil {
- b.AppendNull()
- return err
- }
- b.Append(val)
- return nil
-}
-
-func (b *Decimal128Builder) UnmarshalOne(dec *json.Decoder) error {
- t, err := dec.Token()
- if err != nil {
- return err
- }
-
- switch v := t.(type) {
- case float64:
- val, err := decimal128.FromFloat64(v, b.dtype.Precision,
b.dtype.Scale)
- if err != nil {
- return err
- }
- b.Append(val)
- case string:
- val, err := decimal128.FromString(v, b.dtype.Precision,
b.dtype.Scale)
- if err != nil {
- return err
- }
- b.Append(val)
- case json.Number:
- val, err := decimal128.FromString(v.String(),
b.dtype.Precision, b.dtype.Scale)
- if err != nil {
- return err
- }
- b.Append(val)
- case nil:
- b.AppendNull()
- return nil
- default:
- return &json.UnmarshalTypeError{
- Value: fmt.Sprint(t),
- Type: reflect.TypeOf(decimal128.Num{}),
- Offset: dec.InputOffset(),
- }
- }
-
- return nil
-}
-
-func (b *Decimal128Builder) Unmarshal(dec *json.Decoder) error {
- for dec.More() {
- if err := b.UnmarshalOne(dec); err != nil {
- return err
- }
- }
- return nil
-}
-
-// UnmarshalJSON will add the unmarshalled values to this builder.
-//
-// If the values are strings, they will get parsed with big.ParseFloat using
-// a rounding mode of big.ToNearestAway currently.
-func (b *Decimal128Builder) UnmarshalJSON(data []byte) error {
- dec := json.NewDecoder(bytes.NewReader(data))
- t, err := dec.Token()
- if err != nil {
- return err
- }
-
- if delim, ok := t.(json.Delim); !ok || delim != '[' {
- return fmt.Errorf("decimal128 builder must unpack from json
array, found %s", delim)
- }
-
- return b.Unmarshal(dec)
-}
-
-var (
- _ arrow.Array = (*Decimal128)(nil)
- _ Builder = (*Decimal128Builder)(nil)
-)
diff --git a/arrow/array/decimal256.go b/arrow/array/decimal256.go
deleted file mode 100644
index 7f30f89..0000000
--- a/arrow/array/decimal256.go
+++ /dev/null
@@ -1,368 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package array
-
-import (
- "bytes"
- "fmt"
- "math/big"
- "reflect"
- "strings"
- "sync/atomic"
-
- "github.com/apache/arrow-go/v18/arrow"
- "github.com/apache/arrow-go/v18/arrow/bitutil"
- "github.com/apache/arrow-go/v18/arrow/decimal256"
- "github.com/apache/arrow-go/v18/arrow/internal/debug"
- "github.com/apache/arrow-go/v18/arrow/memory"
- "github.com/apache/arrow-go/v18/internal/json"
-)
-
-// Decimal256 is a type that represents an immutable sequence of 256-bit
decimal values.
-type Decimal256 struct {
- array
-
- values []decimal256.Num
-}
-
-func NewDecimal256Data(data arrow.ArrayData) *Decimal256 {
- a := &Decimal256{}
- a.refCount = 1
- a.setData(data.(*Data))
- return a
-}
-
-func (a *Decimal256) Value(i int) decimal256.Num { return a.values[i] }
-
-func (a *Decimal256) ValueStr(i int) string {
- if a.IsNull(i) {
- return NullValueStr
- }
- return a.GetOneForMarshal(i).(string)
-}
-
-func (a *Decimal256) Values() []decimal256.Num { return a.values }
-
-func (a *Decimal256) String() string {
- o := new(strings.Builder)
- o.WriteString("[")
- for i := 0; i < a.Len(); i++ {
- if i > 0 {
- fmt.Fprintf(o, " ")
- }
- switch {
- case a.IsNull(i):
- o.WriteString(NullValueStr)
- default:
- fmt.Fprintf(o, "%v", a.Value(i))
- }
- }
- o.WriteString("]")
- return o.String()
-}
-
-func (a *Decimal256) setData(data *Data) {
- a.array.setData(data)
- vals := data.buffers[1]
- if vals != nil {
- a.values = arrow.Decimal256Traits.CastFromBytes(vals.Bytes())
- beg := a.array.data.offset
- end := beg + a.array.data.length
- a.values = a.values[beg:end]
- }
-}
-
-func (a *Decimal256) GetOneForMarshal(i int) interface{} {
- if a.IsNull(i) {
- return nil
- }
- typ := a.DataType().(*arrow.Decimal256Type)
- n := a.Value(i)
- scale := typ.Scale
- f := (&big.Float{}).SetInt(n.BigInt())
- if scale < 0 {
- f.SetPrec(256).Mul(f,
(&big.Float{}).SetInt(decimal256.GetScaleMultiplier(int(-scale)).BigInt()))
- } else {
- f.SetPrec(256).Quo(f,
(&big.Float{}).SetInt(decimal256.GetScaleMultiplier(int(scale)).BigInt()))
- }
- return f.Text('g', int(typ.Precision))
-}
-
-func (a *Decimal256) MarshalJSON() ([]byte, error) {
- vals := make([]interface{}, a.Len())
- for i := 0; i < a.Len(); i++ {
- vals[i] = a.GetOneForMarshal(i)
- }
- return json.Marshal(vals)
-}
-
-func arrayEqualDecimal256(left, right *Decimal256) bool {
- for i := 0; i < left.Len(); i++ {
- if left.IsNull(i) {
- continue
- }
- if left.Value(i) != right.Value(i) {
- return false
- }
- }
- return true
-}
-
-type Decimal256Builder struct {
- builder
-
- dtype *arrow.Decimal256Type
- data *memory.Buffer
- rawData []decimal256.Num
-}
-
-func NewDecimal256Builder(mem memory.Allocator, dtype *arrow.Decimal256Type)
*Decimal256Builder {
- return &Decimal256Builder{
- builder: builder{refCount: 1, mem: mem},
- dtype: dtype,
- }
-}
-
-// Release decreases the reference count by 1.
-// When the reference count goes to zero, the memory is freed.
-func (b *Decimal256Builder) Release() {
- debug.Assert(atomic.LoadInt64(&b.refCount) > 0, "too many releases")
-
- if atomic.AddInt64(&b.refCount, -1) == 0 {
- if b.nullBitmap != nil {
- b.nullBitmap.Release()
- b.nullBitmap = nil
- }
- if b.data != nil {
- b.data.Release()
- b.data = nil
- b.rawData = nil
- }
- }
-}
-
-func (b *Decimal256Builder) Append(v decimal256.Num) {
- b.Reserve(1)
- b.UnsafeAppend(v)
-}
-
-func (b *Decimal256Builder) UnsafeAppend(v decimal256.Num) {
- bitutil.SetBit(b.nullBitmap.Bytes(), b.length)
- b.rawData[b.length] = v
- b.length++
-}
-
-func (b *Decimal256Builder) AppendNull() {
- b.Reserve(1)
- b.UnsafeAppendBoolToBitmap(false)
-}
-
-func (b *Decimal256Builder) AppendNulls(n int) {
- for i := 0; i < n; i++ {
- b.AppendNull()
- }
-}
-
-func (b *Decimal256Builder) AppendEmptyValue() {
- b.Append(decimal256.Num{})
-}
-
-func (b *Decimal256Builder) AppendEmptyValues(n int) {
- for i := 0; i < n; i++ {
- b.AppendEmptyValue()
- }
-}
-
-func (b *Decimal256Builder) Type() arrow.DataType { return b.dtype }
-
-func (b *Decimal256Builder) UnsafeAppendBoolToBitmap(isValid bool) {
- if isValid {
- bitutil.SetBit(b.nullBitmap.Bytes(), b.length)
- } else {
- b.nulls++
- }
- b.length++
-}
-
-// AppendValues will append the values in the v slice. The valid slice
determines which values
-// in v are valid (not null). The valid slice must either be empty or be equal
in length to v. If empty,
-// all values in v are appended and considered valid.
-func (b *Decimal256Builder) AppendValues(v []decimal256.Num, valid []bool) {
- if len(v) != len(valid) && len(valid) != 0 {
- panic("arrow/array: len(v) != len(valid) && len(valid) != 0")
- }
-
- if len(v) == 0 {
- return
- }
-
- b.Reserve(len(v))
- if len(v) > 0 {
- arrow.Decimal256Traits.Copy(b.rawData[b.length:], v)
- }
- b.builder.unsafeAppendBoolsToBitmap(valid, len(v))
-}
-
-func (b *Decimal256Builder) init(capacity int) {
- b.builder.init(capacity)
-
- b.data = memory.NewResizableBuffer(b.mem)
- bytesN := arrow.Decimal256Traits.BytesRequired(capacity)
- b.data.Resize(bytesN)
- b.rawData = arrow.Decimal256Traits.CastFromBytes(b.data.Bytes())
-}
-
-// Reserve ensures there is enough space for appending n elements
-// by checking the capacity and calling Resize if necessary.
-func (b *Decimal256Builder) Reserve(n int) {
- b.builder.reserve(n, b.Resize)
-}
-
-// Resize adjusts the space allocated by b to n elements. If n is greater than
b.Cap(),
-// additional memory will be allocated. If n is smaller, the allocated memory
may reduced.
-func (b *Decimal256Builder) Resize(n int) {
- nBuilder := n
- if n < minBuilderCapacity {
- n = minBuilderCapacity
- }
-
- if b.capacity == 0 {
- b.init(n)
- } else {
- b.builder.resize(nBuilder, b.init)
- b.data.Resize(arrow.Decimal256Traits.BytesRequired(n))
- b.rawData = arrow.Decimal256Traits.CastFromBytes(b.data.Bytes())
- }
-}
-
-// NewArray creates a Decimal256 array from the memory buffers used by the
builder and resets the Decimal256Builder
-// so it can be used to build a new array.
-func (b *Decimal256Builder) NewArray() arrow.Array {
- return b.NewDecimal256Array()
-}
-
-// NewDecimal256Array creates a Decimal256 array from the memory buffers used
by the builder and resets the Decimal256Builder
-// so it can be used to build a new array.
-func (b *Decimal256Builder) NewDecimal256Array() (a *Decimal256) {
- data := b.newData()
- a = NewDecimal256Data(data)
- data.Release()
- return
-}
-
-func (b *Decimal256Builder) newData() (data *Data) {
- bytesRequired := arrow.Decimal256Traits.BytesRequired(b.length)
- if bytesRequired > 0 && bytesRequired < b.data.Len() {
- // trim buffers
- b.data.Resize(bytesRequired)
- }
- data = NewData(b.dtype, b.length, []*memory.Buffer{b.nullBitmap,
b.data}, nil, b.nulls, 0)
- b.reset()
-
- if b.data != nil {
- b.data.Release()
- b.data = nil
- b.rawData = nil
- }
-
- return
-}
-
-func (b *Decimal256Builder) AppendValueFromString(s string) error {
- if s == NullValueStr {
- b.AppendNull()
- return nil
- }
- val, err := decimal256.FromString(s, b.dtype.Precision, b.dtype.Scale)
- if err != nil {
- b.AppendNull()
- return err
- }
- b.Append(val)
- return nil
-}
-
-func (b *Decimal256Builder) UnmarshalOne(dec *json.Decoder) error {
- t, err := dec.Token()
- if err != nil {
- return err
- }
-
- switch v := t.(type) {
- case float64:
- val, err := decimal256.FromFloat64(v, b.dtype.Precision,
b.dtype.Scale)
- if err != nil {
- return err
- }
- b.Append(val)
- case string:
- out, err := decimal256.FromString(v, b.dtype.Precision,
b.dtype.Scale)
- if err != nil {
- return err
- }
- b.Append(out)
- case json.Number:
- out, err := decimal256.FromString(v.String(),
b.dtype.Precision, b.dtype.Scale)
- if err != nil {
- return err
- }
- b.Append(out)
- case nil:
- b.AppendNull()
- return nil
- default:
- return &json.UnmarshalTypeError{
- Value: fmt.Sprint(t),
- Type: reflect.TypeOf(decimal256.Num{}),
- Offset: dec.InputOffset(),
- }
- }
-
- return nil
-}
-
-func (b *Decimal256Builder) Unmarshal(dec *json.Decoder) error {
- for dec.More() {
- if err := b.UnmarshalOne(dec); err != nil {
- return err
- }
- }
- return nil
-}
-
-// UnmarshalJSON will add the unmarshalled values to this builder.
-//
-// If the values are strings, they will get parsed with big.ParseFloat using
-// a rounding mode of big.ToNearestAway currently.
-func (b *Decimal256Builder) UnmarshalJSON(data []byte) error {
- dec := json.NewDecoder(bytes.NewReader(data))
- t, err := dec.Token()
- if err != nil {
- return err
- }
-
- if delim, ok := t.(json.Delim); !ok || delim != '[' {
- return fmt.Errorf("arrow/array: decimal256 builder must unpack
from json array, found %s", delim)
- }
-
- return b.Unmarshal(dec)
-}
-
-var (
- _ arrow.Array = (*Decimal256)(nil)
- _ Builder = (*Decimal256Builder)(nil)
-)
diff --git a/arrow/array/dictionary.go b/arrow/array/dictionary.go
index 34f0f2b..0c23934 100644
--- a/arrow/array/dictionary.go
+++ b/arrow/array/dictionary.go
@@ -27,6 +27,7 @@ import (
"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/bitutil"
+ "github.com/apache/arrow-go/v18/arrow/decimal"
"github.com/apache/arrow-go/v18/arrow/decimal128"
"github.com/apache/arrow-go/v18/arrow/decimal256"
"github.com/apache/arrow-go/v18/arrow/float16"
@@ -392,7 +393,8 @@ func createMemoTable(mem memory.Allocator, dt
arrow.DataType) (ret hashing.MemoT
ret = hashing.NewFloat32MemoTable(0)
case arrow.FLOAT64:
ret = hashing.NewFloat64MemoTable(0)
- case arrow.BINARY, arrow.FIXED_SIZE_BINARY, arrow.DECIMAL128,
arrow.DECIMAL256, arrow.INTERVAL_DAY_TIME, arrow.INTERVAL_MONTH_DAY_NANO:
+ case arrow.BINARY, arrow.FIXED_SIZE_BINARY, arrow.DECIMAL32,
arrow.DECIMAL64,
+ arrow.DECIMAL128, arrow.DECIMAL256, arrow.INTERVAL_DAY_TIME,
arrow.INTERVAL_MONTH_DAY_NANO:
ret = hashing.NewBinaryMemoTable(0, 0, NewBinaryBuilder(mem,
arrow.BinaryTypes.Binary))
case arrow.STRING:
ret = hashing.NewBinaryMemoTable(0, 0, NewBinaryBuilder(mem,
arrow.BinaryTypes.String))
@@ -623,6 +625,22 @@ func NewDictionaryBuilderWithDict(mem memory.Allocator, dt
*arrow.DictionaryType
}
}
return ret
+ case arrow.DECIMAL32:
+ ret := &Decimal32DictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*Decimal32)); err
!= nil {
+ panic(err)
+ }
+ }
+ return ret
+ case arrow.DECIMAL64:
+ ret := &Decimal64DictionaryBuilder{bldr}
+ if init != nil {
+ if err = ret.InsertDictValues(init.(*Decimal64)); err
!= nil {
+ panic(err)
+ }
+ }
+ return ret
case arrow.DECIMAL128:
ret := &Decimal128DictionaryBuilder{bldr}
if init != nil {
@@ -906,6 +924,16 @@ func getvalFn(arr arrow.Array) func(i int) interface{} {
return func(i int) interface{} { return typedarr.Value(i) }
case *String:
return func(i int) interface{} { return typedarr.Value(i) }
+ case *Decimal32:
+ return func(i int) interface{} {
+ val := typedarr.Value(i)
+ return
(*(*[arrow.Decimal32SizeBytes]byte)(unsafe.Pointer(&val)))[:]
+ }
+ case *Decimal64:
+ return func(i int) interface{} {
+ val := typedarr.Value(i)
+ return
(*(*[arrow.Decimal64SizeBytes]byte)(unsafe.Pointer(&val)))[:]
+ }
case *Decimal128:
return func(i int) interface{} {
val := typedarr.Value(i)
@@ -1394,6 +1422,42 @@ func (b *FixedSizeBinaryDictionaryBuilder)
InsertDictValues(arr *FixedSizeBinary
return
}
+type Decimal32DictionaryBuilder struct {
+ dictionaryBuilder
+}
+
+func (b *Decimal32DictionaryBuilder) Append(v decimal.Decimal32) error {
+ return
b.appendValue((*(*[arrow.Decimal32SizeBytes]byte)(unsafe.Pointer(&v)))[:])
+}
+func (b *Decimal32DictionaryBuilder) InsertDictValues(arr *Decimal32) (err
error) {
+ data := arrow.Decimal32Traits.CastToBytes(arr.values)
+ for len(data) > 0 {
+ if err = b.insertDictValue(data[:arrow.Decimal32SizeBytes]);
err != nil {
+ break
+ }
+ data = data[arrow.Decimal32SizeBytes:]
+ }
+ return
+}
+
+type Decimal64DictionaryBuilder struct {
+ dictionaryBuilder
+}
+
+func (b *Decimal64DictionaryBuilder) Append(v decimal.Decimal64) error {
+ return
b.appendValue((*(*[arrow.Decimal64SizeBytes]byte)(unsafe.Pointer(&v)))[:])
+}
+func (b *Decimal64DictionaryBuilder) InsertDictValues(arr *Decimal64) (err
error) {
+ data := arrow.Decimal64Traits.CastToBytes(arr.values)
+ for len(data) > 0 {
+ if err = b.insertDictValue(data[:arrow.Decimal64SizeBytes]);
err != nil {
+ break
+ }
+ data = data[arrow.Decimal64SizeBytes:]
+ }
+ return
+}
+
type Decimal128DictionaryBuilder struct {
dictionaryBuilder
}
diff --git a/arrow/array/numeric.gen.go b/arrow/array/numeric.gen.go
index c6c7b0b..7e94fe5 100644
--- a/arrow/array/numeric.gen.go
+++ b/arrow/array/numeric.gen.go
@@ -101,11 +101,13 @@ func (a *Int64) GetOneForMarshal(i int) interface{} {
func (a *Int64) MarshalJSON() ([]byte, error) {
vals := make([]interface{}, a.Len())
for i := 0; i < a.Len(); i++ {
+
if a.IsValid(i) {
vals[i] = a.values[i]
} else {
vals[i] = nil
}
+
}
return json.Marshal(vals)
@@ -196,11 +198,13 @@ func (a *Uint64) GetOneForMarshal(i int) interface{} {
func (a *Uint64) MarshalJSON() ([]byte, error) {
vals := make([]interface{}, a.Len())
for i := 0; i < a.Len(); i++ {
+
if a.IsValid(i) {
vals[i] = a.values[i]
} else {
vals[i] = nil
}
+
}
return json.Marshal(vals)
@@ -398,11 +402,13 @@ func (a *Int32) GetOneForMarshal(i int) interface{} {
func (a *Int32) MarshalJSON() ([]byte, error) {
vals := make([]interface{}, a.Len())
for i := 0; i < a.Len(); i++ {
+
if a.IsValid(i) {
vals[i] = a.values[i]
} else {
vals[i] = nil
}
+
}
return json.Marshal(vals)
@@ -493,11 +499,13 @@ func (a *Uint32) GetOneForMarshal(i int) interface{} {
func (a *Uint32) MarshalJSON() ([]byte, error) {
vals := make([]interface{}, a.Len())
for i := 0; i < a.Len(); i++ {
+
if a.IsValid(i) {
vals[i] = a.values[i]
} else {
vals[i] = nil
}
+
}
return json.Marshal(vals)
@@ -602,6 +610,7 @@ func (a *Float32) MarshalJSON() ([]byte, error) {
default:
vals[i] = f
}
+
}
return json.Marshal(vals)
@@ -692,11 +701,13 @@ func (a *Int16) GetOneForMarshal(i int) interface{} {
func (a *Int16) MarshalJSON() ([]byte, error) {
vals := make([]interface{}, a.Len())
for i := 0; i < a.Len(); i++ {
+
if a.IsValid(i) {
vals[i] = a.values[i]
} else {
vals[i] = nil
}
+
}
return json.Marshal(vals)
@@ -787,11 +798,13 @@ func (a *Uint16) GetOneForMarshal(i int) interface{} {
func (a *Uint16) MarshalJSON() ([]byte, error) {
vals := make([]interface{}, a.Len())
for i := 0; i < a.Len(); i++ {
+
if a.IsValid(i) {
vals[i] = a.values[i]
} else {
vals[i] = nil
}
+
}
return json.Marshal(vals)
@@ -882,11 +895,13 @@ func (a *Int8) GetOneForMarshal(i int) interface{} {
func (a *Int8) MarshalJSON() ([]byte, error) {
vals := make([]interface{}, a.Len())
for i := 0; i < a.Len(); i++ {
+
if a.IsValid(i) {
vals[i] = float64(a.values[i]) // prevent uint8 from
being seen as binary data
} else {
vals[i] = nil
}
+
}
return json.Marshal(vals)
@@ -977,11 +992,13 @@ func (a *Uint8) GetOneForMarshal(i int) interface{} {
func (a *Uint8) MarshalJSON() ([]byte, error) {
vals := make([]interface{}, a.Len())
for i := 0; i < a.Len(); i++ {
+
if a.IsValid(i) {
vals[i] = float64(a.values[i]) // prevent uint8 from
being seen as binary data
} else {
vals[i] = nil
}
+
}
return json.Marshal(vals)
diff --git a/arrow/cdata/cdata.go b/arrow/cdata/cdata.go
index 4688a1e..d5748a3 100644
--- a/arrow/cdata/cdata.go
+++ b/arrow/cdata/cdata.go
@@ -254,12 +254,17 @@ func importSchema(schema *CArrowSchema) (ret arrow.Field,
err error) {
return ret, xerrors.Errorf("could not parse decimal
scale in '%s': %s", f, err.Error())
}
- if bitwidth == 128 {
+ switch bitwidth {
+ case 32:
+ dt = &arrow.Decimal32Type{Precision: int32(precision),
Scale: int32(scale)}
+ case 64:
+ dt = &arrow.Decimal64Type{Precision: int32(precision),
Scale: int32(scale)}
+ case 128:
dt = &arrow.Decimal128Type{Precision: int32(precision),
Scale: int32(scale)}
- } else if bitwidth == 256 {
+ case 256:
dt = &arrow.Decimal256Type{Precision: int32(precision),
Scale: int32(scale)}
- } else {
- return ret, xerrors.Errorf("only decimal128 and
decimal256 are supported, got '%s'", f)
+ default:
+ return ret, xerrors.Errorf("unsupported decimal
bitwidth, got '%s'", f)
}
}
diff --git a/arrow/cdata/cdata_exports.go b/arrow/cdata/cdata_exports.go
index 4ed9d0e..d367348 100644
--- a/arrow/cdata/cdata_exports.go
+++ b/arrow/cdata/cdata_exports.go
@@ -154,6 +154,10 @@ func (exp *schemaExporter) exportFormat(dt arrow.DataType)
string {
return "g"
case *arrow.FixedSizeBinaryType:
return fmt.Sprintf("w:%d", dt.ByteWidth)
+ case *arrow.Decimal32Type:
+ return fmt.Sprintf("d:%d,%d,32", dt.Precision, dt.Scale)
+ case *arrow.Decimal64Type:
+ return fmt.Sprintf("d:%d,%d,64", dt.Precision, dt.Scale)
case *arrow.Decimal128Type:
return fmt.Sprintf("d:%d,%d", dt.Precision, dt.Scale)
case *arrow.Decimal256Type:
diff --git a/arrow/cdata/cdata_test.go b/arrow/cdata/cdata_test.go
index 697a73b..2a86ea6 100644
--- a/arrow/cdata/cdata_test.go
+++ b/arrow/cdata/cdata_test.go
@@ -153,7 +153,7 @@ func TestDecimalSchemaErrors(t *testing.T) {
{"d:a,2,3", "could not parse decimal precision in 'd:a,2,3':"},
{"d:1,a,3", "could not parse decimal scale in 'd:1,a,3':"},
{"d:1,2,a", "could not parse decimal bitwidth in 'd:1,2,a':"},
- {"d:1,2,384", "only decimal128 and decimal256 are supported,
got 'd:1,2,384'"},
+ {"d:1,2,384", "unsupported decimal bitwidth, got 'd:1,2,384'"},
}
for _, tt := range tests {
diff --git a/arrow/datatype.go b/arrow/datatype.go
index 2fba655..9556585 100644
--- a/arrow/datatype.go
+++ b/arrow/datatype.go
@@ -107,7 +107,7 @@ const (
// parameters.
DECIMAL128
- // DECIMAL256 is a precision and scale based decimal type, with 256 bit
max. not yet implemented
+ // DECIMAL256 is a precision and scale based decimal type, with 256 bit
max.
DECIMAL256
// LIST is a list of some logical data type
@@ -116,10 +116,10 @@ const (
// STRUCT of logical types
STRUCT
- // SPARSE_UNION of logical types. not yet implemented
+ // SPARSE_UNION of logical types
SPARSE_UNION
- // DENSE_UNION of logical types. not yet implemented
+ // DENSE_UNION of logical types
DENSE_UNION
// DICTIONARY aka Category type
@@ -138,13 +138,13 @@ const (
// or nanoseconds.
DURATION
- // like STRING, but 64-bit offsets. not yet implemented
+ // like STRING, but 64-bit offsets
LARGE_STRING
- // like BINARY but with 64-bit offsets, not yet implemented
+ // like BINARY but with 64-bit offsets
LARGE_BINARY
- // like LIST but with 64-bit offsets. not yet implemented
+ // like LIST but with 64-bit offsets
LARGE_LIST
// calendar interval with three fields
@@ -165,6 +165,12 @@ const (
// like LIST but with 64-bit offsets
LARGE_LIST_VIEW
+ // Decimal value with 32-bit representation
+ DECIMAL32
+
+ // Decimal value with 64-bit representation
+ DECIMAL64
+
// Alias to ensure we do not break any consumers
DECIMAL = DECIMAL128
)
@@ -365,10 +371,10 @@ func IsLargeBinaryLike(t Type) bool {
return false
}
-// IsFixedSizeBinary returns true for Decimal128/256 and FixedSizeBinary
+// IsFixedSizeBinary returns true for Decimal32/64/128/256 and FixedSizeBinary
func IsFixedSizeBinary(t Type) bool {
switch t {
- case DECIMAL128, DECIMAL256, FIXED_SIZE_BINARY:
+ case DECIMAL32, DECIMAL64, DECIMAL128, DECIMAL256, FIXED_SIZE_BINARY:
return true
}
return false
@@ -377,7 +383,7 @@ func IsFixedSizeBinary(t Type) bool {
-// IsDecimal returns true for Decimal128 and Decimal256
+// IsDecimal returns true for Decimal32, Decimal64, Decimal128 and Decimal256
func IsDecimal(t Type) bool {
switch t {
- case DECIMAL128, DECIMAL256:
+ case DECIMAL32, DECIMAL64, DECIMAL128, DECIMAL256:
return true
}
return false
diff --git a/arrow/datatype_fixedwidth.go b/arrow/datatype_fixedwidth.go
index 41c7b6f..5928be3 100644
--- a/arrow/datatype_fixedwidth.go
+++ b/arrow/datatype_fixedwidth.go
@@ -22,8 +22,9 @@ import (
"sync"
"time"
+ "github.com/apache/arrow-go/v18/arrow/decimal"
+ "github.com/apache/arrow-go/v18/arrow/internal/debug"
"github.com/apache/arrow-go/v18/internal/json"
-
"golang.org/x/xerrors"
)
@@ -532,19 +533,103 @@ type DecimalType interface {
DataType
GetPrecision() int32
GetScale() int32
+ BitWidth() int
+}
+
+// NarrowestDecimalType constructs the smallest decimal type that can represent
+// the requested precision. An error is returned if the requested precision
+// cannot be represented (prec <= 0 || prec > 76).
+//
+// For reference:
+//
+// prec in [ 1, 9] => Decimal32Type
+// prec in [10, 18] => Decimal64Type
+// prec in [19, 38] => Decimal128Type
+// prec in [39, 76] => Decimal256Type
+func NarrowestDecimalType(prec, scale int32) (DecimalType, error) {
+ switch {
+ case prec <= 0:
+ return nil, fmt.Errorf("%w: precision must be > 0 for decimal
types, got %d",
+ ErrInvalid, prec)
+ case prec <= int32(decimal.MaxPrecision[decimal.Decimal32]()):
+ return &Decimal32Type{Precision: prec, Scale: scale}, nil
+ case prec <= int32(decimal.MaxPrecision[decimal.Decimal64]()):
+ return &Decimal64Type{Precision: prec, Scale: scale}, nil
+ case prec <= int32(decimal.MaxPrecision[decimal.Decimal128]()):
+ return &Decimal128Type{Precision: prec, Scale: scale}, nil
+ case prec <= int32(decimal.MaxPrecision[decimal.Decimal256]()):
+ return &Decimal256Type{Precision: prec, Scale: scale}, nil
+ default:
+ return nil, fmt.Errorf("%w: invalid precision for decimal
types, %d",
+ ErrInvalid, prec)
+ }
}
func NewDecimalType(id Type, prec, scale int32) (DecimalType, error) {
switch id {
+ case DECIMAL32:
+ debug.Assert(prec <=
int32(decimal.MaxPrecision[decimal.Decimal32]()), "invalid precision for
decimal32")
+ return &Decimal32Type{Precision: prec, Scale: scale}, nil
+ case DECIMAL64:
+ debug.Assert(prec <=
int32(decimal.MaxPrecision[decimal.Decimal64]()), "invalid precision for
decimal64")
+ return &Decimal64Type{Precision: prec, Scale: scale}, nil
case DECIMAL128:
+ debug.Assert(prec <=
int32(decimal.MaxPrecision[decimal.Decimal128]()), "invalid precision for
decimal128")
return &Decimal128Type{Precision: prec, Scale: scale}, nil
case DECIMAL256:
+ debug.Assert(prec <=
int32(decimal.MaxPrecision[decimal.Decimal256]()), "invalid precision for
decimal256")
return &Decimal256Type{Precision: prec, Scale: scale}, nil
default:
- return nil, fmt.Errorf("%w: must use DECIMAL128 or DECIMAL256
to create a DecimalType", ErrInvalid)
+ return nil, fmt.Errorf("%w: must use one of the DECIMAL IDs to
create a DecimalType", ErrInvalid)
}
}
+// Decimal32Type represents a fixed-size 32-bit decimal type.
+type Decimal32Type struct {
+ Precision int32
+ Scale int32
+}
+
+func (*Decimal32Type) ID() Type { return DECIMAL32 }
+func (*Decimal32Type) Name() string { return "decimal32" }
+func (*Decimal32Type) BitWidth() int { return 32 }
+func (*Decimal32Type) Bytes() int { return Decimal32SizeBytes }
+func (t *Decimal32Type) String() string {
+ return fmt.Sprintf("%s(%d, %d)", t.Name(), t.Precision, t.Scale)
+}
+func (t *Decimal32Type) Fingerprint() string {
+ return fmt.Sprintf("%s[%d,%d,%d]", typeFingerprint(t), t.BitWidth(),
t.Precision, t.Scale)
+}
+func (t *Decimal32Type) GetPrecision() int32 { return t.Precision }
+func (t *Decimal32Type) GetScale() int32 { return t.Scale }
+
+func (Decimal32Type) Layout() DataTypeLayout {
+ return DataTypeLayout{Buffers: []BufferSpec{SpecBitmap(),
SpecFixedWidth(Decimal32SizeBytes)}}
+}
+
+// Decimal64Type represents a fixed-size 64-bit decimal type.
+type Decimal64Type struct {
+ Precision int32
+ Scale int32
+}
+
+func (*Decimal64Type) ID() Type { return DECIMAL64 }
+func (*Decimal64Type) Name() string { return "decimal64" }
+func (*Decimal64Type) BitWidth() int { return 64 }
+func (*Decimal64Type) Bytes() int { return Decimal64SizeBytes }
+func (t *Decimal64Type) String() string {
+ return fmt.Sprintf("%s(%d, %d)", t.Name(), t.Precision, t.Scale)
+}
+func (t *Decimal64Type) Fingerprint() string {
+ return fmt.Sprintf("%s[%d,%d,%d]", typeFingerprint(t), t.BitWidth(),
t.Precision, t.Scale)
+}
+func (t *Decimal64Type) GetPrecision() int32 { return t.Precision }
+func (t *Decimal64Type) GetScale() int32 { return t.Scale }
+
+func (Decimal64Type) Layout() DataTypeLayout {
+ return DataTypeLayout{Buffers: []BufferSpec{SpecBitmap(),
SpecFixedWidth(Decimal64SizeBytes)}}
+}
+
// Decimal128Type represents a fixed-size 128-bit decimal type.
type Decimal128Type struct {
Precision int32
diff --git a/arrow/datatype_fixedwidth_test.go
b/arrow/datatype_fixedwidth_test.go
index d60c6b1..bc899f3 100644
--- a/arrow/datatype_fixedwidth_test.go
+++ b/arrow/datatype_fixedwidth_test.go
@@ -23,6 +23,7 @@ import (
"github.com/apache/arrow-go/v18/arrow"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
// TestTimeUnit_String verifies each time unit matches its string
representation.
@@ -43,6 +44,60 @@ func TestTimeUnit_String(t *testing.T) {
}
}
+func TestDecimal32Type(t *testing.T) {
+ for _, tc := range []struct {
+ precision int32
+ scale int32
+ want string
+ }{
+ {1, 9, "decimal32(1, 9)"},
+ {9, 9, "decimal32(9, 9)"},
+ {9, 1, "decimal32(9, 1)"},
+ } {
+ t.Run(tc.want, func(t *testing.T) {
+ dt := arrow.Decimal32Type{Precision: tc.precision,
Scale: tc.scale}
+ if got, want := dt.BitWidth(), 32; got != want {
+ t.Fatalf("invalid bitwidth: got=%d, want=%d",
got, want)
+ }
+
+ if got, want := dt.ID(), arrow.DECIMAL32; got != want {
+ t.Fatalf("invalid type ID: got=%v, want=%v",
got, want)
+ }
+
+ if got, want := dt.String(), tc.want; got != want {
+ t.Fatalf("invalid stringer: got=%q, want=%q",
got, want)
+ }
+ })
+ }
+}
+
+func TestDecimal64Type(t *testing.T) {
+ for _, tc := range []struct {
+ precision int32
+ scale int32
+ want string
+ }{
+ {1, 10, "decimal64(1, 10)"},
+ {10, 10, "decimal64(10, 10)"},
+ {10, 1, "decimal64(10, 1)"},
+ } {
+ t.Run(tc.want, func(t *testing.T) {
+ dt := arrow.Decimal64Type{Precision: tc.precision,
Scale: tc.scale}
+ if got, want := dt.BitWidth(), 64; got != want {
+ t.Fatalf("invalid bitwidth: got=%d, want=%d",
got, want)
+ }
+
+ if got, want := dt.ID(), arrow.DECIMAL64; got != want {
+ t.Fatalf("invalid type ID: got=%v, want=%v",
got, want)
+ }
+
+ if got, want := dt.String(), tc.want; got != want {
+ t.Fatalf("invalid stringer: got=%q, want=%q",
got, want)
+ }
+ })
+ }
+}
+
func TestDecimal128Type(t *testing.T) {
for _, tc := range []struct {
precision int32
@@ -438,3 +493,36 @@ func TestDateFromTime(t *testing.T) {
assert.EqualValues(t, wantD64, arrow.Date64FromTime(tm))
assert.EqualValues(t, wantD32, arrow.Date32FromTime(tm))
}
+
+func TestNarrowestDecimalType(t *testing.T) {
+ tests := []struct {
+ min, max int32
+ expected arrow.Type
+ }{
+ {1, 9, arrow.DECIMAL32},
+ {10, 18, arrow.DECIMAL64},
+ {19, 38, arrow.DECIMAL128},
+ {39, 76, arrow.DECIMAL256},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.expected.String(), func(t *testing.T) {
+ for i := tt.min; i <= tt.max; i++ {
+ typ, err := arrow.NarrowestDecimalType(i, 5)
+ require.NoError(t, err)
+
+ assert.Equal(t, i, typ.GetPrecision())
+ assert.Equal(t, int32(5), typ.GetScale())
+ assert.Equal(t, tt.expected, typ.ID())
+ }
+ })
+ }
+
+ _, err := arrow.NarrowestDecimalType(-1, 5)
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, arrow.ErrInvalid)
+
+ _, err = arrow.NarrowestDecimalType(78, 5)
+ assert.Error(t, err)
+ assert.ErrorIs(t, err, arrow.ErrInvalid)
+}
diff --git a/arrow/decimal/decimal.go b/arrow/decimal/decimal.go
new file mode 100644
index 0000000..098a4e0
--- /dev/null
+++ b/arrow/decimal/decimal.go
@@ -0,0 +1,473 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package decimal
+
+import (
+ "errors"
+ "fmt"
+ "math"
+ "math/big"
+ "math/bits"
+ "unsafe"
+
+ "github.com/apache/arrow-go/v18/arrow/decimal128"
+ "github.com/apache/arrow-go/v18/arrow/decimal256"
+ "github.com/apache/arrow-go/v18/arrow/internal/debug"
+)
+
+// DecimalTypes is a generic constraint representing the implemented decimal
+// types in this package, and a single point of update for future additions.
+// Everything else is constrained by this.
+type DecimalTypes interface {
+ Decimal32 | Decimal64 | Decimal128 | Decimal256
+}
+
+// Num is an interface that is useful for building generic types for all
+// decimal type implementations. It presents all the methods and interfaces
+// necessary to operate on the decimal objects without having to care about
+// the bit width.
+type Num[T DecimalTypes] interface {
+ Negate() T
+ Add(T) T
+ Sub(T) T
+ Mul(T) T
+ Div(T) (res, rem T)
+ Pow(T) T
+
+ FitsInPrecision(int32) bool
+ Abs() T
+ Sign() int
+ Rescale(int32, int32) (T, error)
+ Cmp(T) int
+
+ IncreaseScaleBy(int32) T
+ ReduceScaleBy(int32, bool) T
+
+ ToFloat32(int32) float32
+ ToFloat64(int32) float64
+ ToBigFloat(int32) *big.Float
+
+ ToString(int32) string
+}
+
+type (
+ Decimal32 int32
+ Decimal64 int64
+ Decimal128 = decimal128.Num
+ Decimal256 = decimal256.Num
+)
+
+func MaxPrecision[T DecimalTypes]() int {
+ // max precision is computed by Floor(log10(2^(nbytes * 8 - 1) - 1))
+ var z T
+ return int(math.Floor(math.Log10(math.Pow(2,
float64(unsafe.Sizeof(z))*8-1) - 1)))
+}
+
+func (d Decimal32) Negate() Decimal32 { return -d }
+func (d Decimal64) Negate() Decimal64 { return -d }
+
+func (d Decimal32) Add(rhs Decimal32) Decimal32 { return d + rhs }
+func (d Decimal64) Add(rhs Decimal64) Decimal64 { return d + rhs }
+
+func (d Decimal32) Sub(rhs Decimal32) Decimal32 { return d - rhs }
+func (d Decimal64) Sub(rhs Decimal64) Decimal64 { return d - rhs }
+
+func (d Decimal32) Mul(rhs Decimal32) Decimal32 { return d * rhs }
+func (d Decimal64) Mul(rhs Decimal64) Decimal64 { return d * rhs }
+
+func (d Decimal32) Div(rhs Decimal32) (res, rem Decimal32) {
+ return d / rhs, d % rhs
+}
+
+func (d Decimal64) Div(rhs Decimal64) (res, rem Decimal64) {
+ return d / rhs, d % rhs
+}
+
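+// intPow returns base raised to the power exp using exponentiation by
+// squaring. It is about 4-5x faster than using math.Pow, which requires
+// converting to float64 and back to integers. exp must be non-negative:
+// a negative exp never shifts down to zero, so the loop would not terminate.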
+func intPow[T int32 | int64](base, exp T) T {
+ result := T(1)
+ for {
+ if exp&1 == 1 {
+ result *= base
+ }
+ exp >>= 1
+ if exp == 0 {
+ break
+ }
+ base *= base
+ }
+
+ return result
+}
+
+func (d Decimal32) Pow(rhs Decimal32) Decimal32 {
+ return Decimal32(intPow(int32(d), int32(rhs)))
+}
+
+func (d Decimal64) Pow(rhs Decimal64) Decimal64 {
+ return Decimal64(intPow(int64(d), int64(rhs)))
+}
+
+func (d Decimal32) Sign() int {
+ if d == 0 {
+ return 0
+ }
+ return int(1 | (d >> 31))
+}
+
+func (d Decimal64) Sign() int {
+ if d == 0 {
+ return 0
+ }
+ return int(1 | (d >> 63))
+}
+
+func (n Decimal32) Abs() Decimal32 {
+ if n < 0 {
+ return -n
+ }
+ return n
+}
+
+func (n Decimal64) Abs() Decimal64 {
+ if n < 0 {
+ return -n
+ }
+ return n
+}
+
+func (n Decimal32) FitsInPrecision(prec int32) bool {
+ debug.Assert(prec > 0, "precision must be > 0")
+ debug.Assert(prec <= 9, "precision must be <= 9")
+ return n.Abs() < Decimal32(math.Pow10(int(prec)))
+}
+
+func (n Decimal64) FitsInPrecision(prec int32) bool {
+ debug.Assert(prec > 0, "precision must be > 0")
+ debug.Assert(prec <= 18, "precision must be <= 18")
+ return n.Abs() < Decimal64(math.Pow10(int(prec)))
+}
+
+func (n Decimal32) ToString(scale int32) string {
+ return n.ToBigFloat(scale).Text('f', int(scale))
+}
+
+func (n Decimal64) ToString(scale int32) string {
+ return n.ToBigFloat(scale).Text('f', int(scale))
+}
+
+var pt5 = big.NewFloat(0.5)
+
+func decimalFromString[T interface {
+ Decimal32 | Decimal64
+ FitsInPrecision(int32) bool
+}](v string, prec, scale int32) (n T, err error) {
+ var nbits = uint(unsafe.Sizeof(T(0))) * 8
+
+ var out *big.Float
+ out, _, err = big.ParseFloat(v, 10, nbits, big.ToNearestEven)
+
+ if scale < 0 {
+ var tmp big.Int
+ val, _ := out.Int(&tmp)
+ if val.BitLen() > int(nbits) {
+ return n, fmt.Errorf("bitlen too large for decimal%d",
nbits)
+ }
+
+ n = T(val.Int64() / int64(math.Pow10(int(-scale))))
+ } else {
+ var precInBits =
uint(math.Round(float64(prec+scale+1)/math.Log10(2))) + 1
+
+ p :=
(&big.Float{}).SetInt(big.NewInt(int64(math.Pow10(int(scale)))))
+ out.SetPrec(precInBits).Mul(out, p)
+ if out.Signbit() {
+ out.Sub(out, pt5)
+ } else {
+ out.Add(out, pt5)
+ }
+
+ var tmp big.Int
+ val, _ := out.Int(&tmp)
+ if val.BitLen() > int(nbits) {
+ return n, fmt.Errorf("bitlen too large for decimal%d",
nbits)
+ }
+ n = T(val.Int64())
+ }
+
+ if !n.FitsInPrecision(prec) {
+ err = fmt.Errorf("val %v doesn't fit in precision %d", n, prec)
+ }
+ return
+}
+
+func Decimal32FromString(v string, prec, scale int32) (n Decimal32, err error)
{
+ return decimalFromString[Decimal32](v, prec, scale)
+}
+
+func Decimal64FromString(v string, prec, scale int32) (n Decimal64, err error)
{
+ return decimalFromString[Decimal64](v, prec, scale)
+}
+
+func Decimal128FromString(v string, prec, scale int32) (n Decimal128, err
error) {
+ return decimal128.FromString(v, prec, scale)
+}
+
+func Decimal256FromString(v string, prec, scale int32) (n Decimal256, err
error) {
+ return decimal256.FromString(v, prec, scale)
+}
+
+func scalePositiveFloat64(v float64, prec, scale int32) (float64, error) {
+ v *= math.Pow10(int(scale))
+ v = math.RoundToEven(v)
+
+ maxabs := math.Pow10(int(prec))
+ if v >= maxabs {
+ return 0, fmt.Errorf("cannot convert %f to
decimal(precision=%d, scale=%d)", v, prec, scale)
+ }
+ return v, nil
+}
+
+// fromPositiveFloat converts the floating point value v into a Decimal32 or
+// Decimal64 with the given precision and scale.
+//
+// It rejects precisions larger than MaxPrecision for T and otherwise scales
+// and rounds v in float64 via scalePositiveFloat64, returning its error
+// unchanged. The visible callers pass the magnitude of the value and apply
+// the sign themselves.
+func fromPositiveFloat[T Decimal32 | Decimal64, F float32 | float64](v F, prec, scale int32) (T, error) {
+	if prec > int32(MaxPrecision[T]()) {
+		return T(0), fmt.Errorf("invalid precision %d for converting float to Decimal", prec)
+	}
+
+	val, err := scalePositiveFloat64(float64(v), prec, scale)
+	if err != nil {
+		return T(0), err
+	}
+
+	// Convert the rounded float64 directly: narrowing it back to F first
+	// (float32 has a 24-bit mantissa) would alter integral values above 2^24.
+	return T(val), nil
+}
+
+// Decimal32FromFloat converts v into a Decimal32 with the given precision and
+// scale. Negative inputs are converted through their magnitude and the result
+// is negated afterwards; errors from fromPositiveFloat are returned as-is.
+func Decimal32FromFloat[F float32 | float64](v F, prec, scale int32) (Decimal32, error) {
+	negative := v < 0
+	if negative {
+		v = -v
+	}
+
+	dec, err := fromPositiveFloat[Decimal32](v, prec, scale)
+	if err != nil {
+		return dec, err
+	}
+	if negative {
+		return -dec, nil
+	}
+	return dec, nil
+}
+
+// Decimal64FromFloat converts v into a Decimal64 with the given precision and
+// scale. Negative inputs are converted through their magnitude and the result
+// is negated afterwards; errors from fromPositiveFloat are returned as-is.
+func Decimal64FromFloat[F float32 | float64](v F, prec, scale int32) (Decimal64, error) {
+	negative := v < 0
+	if negative {
+		v = -v
+	}
+
+	dec, err := fromPositiveFloat[Decimal64](v, prec, scale)
+	if err != nil {
+		return dec, err
+	}
+	if negative {
+		return -dec, nil
+	}
+	return dec, nil
+}
+
+// Decimal128FromFloat converts v into a Decimal128 by delegating to
+// decimal128.FromFloat64.
+func Decimal128FromFloat(v float64, prec, scale int32) (Decimal128, error) {
+	n, err := decimal128.FromFloat64(v, prec, scale)
+	return n, err
+}
+
+// Decimal256FromFloat converts v into a Decimal256 by delegating to
+// decimal256.FromFloat64.
+func Decimal256FromFloat(v float64, prec, scale int32) (Decimal256, error) {
+	n, err := decimal256.FromFloat64(v, prec, scale)
+	return n, err
+}
+
+// ToFloat32 returns n interpreted with the given scale, computed in float64
+// by ToFloat64 and then narrowed to float32.
+func (n Decimal32) ToFloat32(scale int32) float32 {
+	wide := n.ToFloat64(scale)
+	return float32(wide)
+}
+
+// ToFloat32 returns n interpreted with the given scale, computed in float64
+// by ToFloat64 and then narrowed to float32.
+func (n Decimal64) ToFloat32(scale int32) float32 {
+	wide := n.ToFloat64(scale)
+	return float32(wide)
+}
+
+// ToFloat64 returns n multiplied by 10^-scale as a float64.
+func (n Decimal32) ToFloat64(scale int32) float64 {
+	factor := math.Pow10(-int(scale))
+	return float64(n) * factor
+}
+
+// ToFloat64 returns n multiplied by 10^-scale as a float64.
+func (n Decimal64) ToFloat64(scale int32) float64 {
+	factor := math.Pow10(-int(scale))
+	return float64(n) * factor
+}
+
+// ToBigFloat returns n interpreted with the given scale as a *big.Float with
+// 32 bits of mantissa precision: n is multiplied by 10^-scale for negative
+// scales and divided by 10^scale otherwise.
+func (n Decimal32) ToBigFloat(scale int32) *big.Float {
+	f := new(big.Float).SetInt64(int64(n)).SetPrec(32)
+	if scale < 0 {
+		mult := new(big.Float).SetInt64(intPow(10, -int64(scale)))
+		return f.Mul(f, mult)
+	}
+
+	div := new(big.Float).SetInt64(intPow(10, int64(scale)))
+	return f.Quo(f, div)
+}
+
+// ToBigFloat returns n interpreted with the given scale as a *big.Float with
+// 64 bits of mantissa precision: n is multiplied by 10^-scale for negative
+// scales and divided by 10^scale otherwise.
+func (n Decimal64) ToBigFloat(scale int32) *big.Float {
+	f := new(big.Float).SetInt64(int64(n)).SetPrec(64)
+	if scale < 0 {
+		mult := new(big.Float).SetInt64(intPow(10, -int64(scale)))
+		return f.Mul(f, mult)
+	}
+
+	div := new(big.Float).SetInt64(intPow(10, int64(scale)))
+	return f.Quo(f, div)
+}
+
+// cmpDec is a three-way comparison: it returns 1 when lhs > rhs, -1 when
+// lhs < rhs and 0 when both are equal.
+func cmpDec[T Decimal32 | Decimal64](lhs, rhs T) int {
+	if lhs > rhs {
+		return 1
+	}
+	if lhs < rhs {
+		return -1
+	}
+	return 0
+}
+
+// Cmp compares n with other and returns 1, -1 or 0 (see cmpDec).
+func (n Decimal32) Cmp(other Decimal32) int {
+	return cmpDec[Decimal32](n, other)
+}
+
+// Cmp compares n with other and returns 1, -1 or 0 (see cmpDec).
+func (n Decimal64) Cmp(other Decimal64) int {
+	return cmpDec[Decimal64](n, other)
+}
+
+// IncreaseScaleBy returns n multiplied by 10^increase. Debug builds assert
+// that increase lies in [0, 9].
+func (n Decimal32) IncreaseScaleBy(increase int32) Decimal32 {
+	debug.Assert(increase >= 0, "invalid increase scale for decimal32")
+	debug.Assert(increase <= 9, "invalid increase scale for decimal32")
+
+	factor := Decimal32(intPow(10, increase))
+	return n * factor
+}
+
+// IncreaseScaleBy returns n multiplied by 10^increase. Debug builds assert
+// that increase lies in [0, 18]; the power is computed with int64 operands.
+func (n Decimal64) IncreaseScaleBy(increase int32) Decimal64 {
+	debug.Assert(increase >= 0, "invalid increase scale for decimal64")
+	debug.Assert(increase <= 18, "invalid increase scale for decimal64")
+
+	factor := Decimal64(intPow(10, int64(increase)))
+	return n * factor
+}
+
+// reduceScale divides n by 10^reduce.
+//
+// When round is false the quotient is truncated toward zero. When round is
+// true the quotient is moved one step away from zero if the magnitude of the
+// remainder is at least half of the divisor. A reduce of 0 returns n
+// unchanged.
+func reduceScale[T interface {
+	Decimal32 | Decimal64
+	Abs() T
+}](n T, reduce int32, round bool) T {
+	if reduce == 0 {
+		return n
+	}
+
+	// Compute the power of ten with int64 operands, as Decimal64.IncreaseScaleBy
+	// does. NOTE(review): intPow appears to return its operand type, so the
+	// previous int32 call would overflow for reduce > 9, which
+	// Decimal64.ReduceScaleBy permits (up to 18) - confirm against intPow.
+	divisor := T(intPow(10, int64(reduce)))
+	quo := n / divisor
+	if !round {
+		return quo
+	}
+
+	if remainder := n % divisor; remainder.Abs() >= divisor/2 {
+		// round half away from zero, following the sign of n
+		if n > 0 {
+			quo++
+		} else {
+			quo--
+		}
+	}
+
+	return quo
+}
+
+// ReduceScaleBy divides n by 10^reduce via reduceScale; round selects
+// rounding of the dropped digits instead of truncation. Debug builds assert
+// that reduce lies in [0, 9].
+func (n Decimal32) ReduceScaleBy(reduce int32, round bool) Decimal32 {
+	const msg = "invalid reduce scale for decimal32"
+	debug.Assert(reduce >= 0, msg)
+	debug.Assert(reduce <= 9, msg)
+
+	return reduceScale[Decimal32](n, reduce, round)
+}
+
+// ReduceScaleBy divides n by 10^reduce via reduceScale; round selects
+// rounding of the dropped digits instead of truncation. Debug builds assert
+// that reduce lies in [0, 18].
+func (n Decimal64) ReduceScaleBy(reduce int32, round bool) Decimal64 {
+	// the assertion messages name decimal64 (they previously said decimal32)
+	debug.Assert(reduce >= 0, "invalid reduce scale for decimal64")
+	debug.Assert(reduce <= 18, "invalid reduce scale for decimal64")
+
+	return reduceScale(n, reduce, round)
+}
+
+// rescaleWouldCauseDataLoss applies multiplier to n and reports whether the
+// operation lost information. For a negative deltaScale n is divided by
+// multiplier and loss means a non-zero remainder; otherwise n is multiplied
+// and loss means the product overflowed or came out smaller than n. The
+// arithmetic is unsigned; rescale passes the magnitude of the value.
+//
+//lint:ignore U1000 function is being used, staticcheck seems to not follow generics
+func (n Decimal32) rescaleWouldCauseDataLoss(deltaScale int32, multiplier Decimal32) (out Decimal32, loss bool) {
+	un, um := uint32(n), uint32(multiplier)
+	if deltaScale >= 0 {
+		hi, lo := bits.Mul32(un, um)
+		out = Decimal32(lo)
+		return out, hi != 0 || out < n
+	}
+
+	debug.Assert(multiplier != 0, "multiplier must not be zero")
+	quo, rem := bits.Div32(0, un, um)
+	return Decimal32(quo), rem != 0
+}
+
+// rescaleWouldCauseDataLoss applies multiplier to n and reports whether the
+// operation lost information. For a negative deltaScale n is divided by
+// multiplier and loss means a non-zero remainder; otherwise n is multiplied
+// and loss means the product overflowed or came out smaller than n. The
+// arithmetic is unsigned; rescale passes the magnitude of the value.
+//
+//lint:ignore U1000 function is being used, staticcheck seems to not follow generics
+func (n Decimal64) rescaleWouldCauseDataLoss(deltaScale int32, multiplier Decimal64) (out Decimal64, loss bool) {
+	// Use the 64-bit primitives: the 32-bit variants truncated both operands
+	// to their low 32 bits and therefore produced wrong quotients/products
+	// for any value or multiplier above 2^32-1.
+	if deltaScale < 0 {
+		debug.Assert(multiplier != 0, "multiplier must not be zero")
+		quo, remainder := bits.Div64(0, uint64(n), uint64(multiplier))
+		return Decimal64(quo), remainder != 0
+	}
+
+	overflow, result := bits.Mul64(uint64(n), uint64(multiplier))
+	if overflow != 0 {
+		return Decimal64(result), true
+	}
+
+	// a product above MaxInt64 wraps negative and is caught by out < n
+	out = Decimal64(result)
+	return out, out < n
+}
+
+// rescale converts n from originalScale to newScale.
+//
+// The magnitude of n is multiplied (newScale > originalScale) or divided
+// (newScale < originalScale) by 10^|newScale-originalScale| and the original
+// sign is re-applied. When rescaleWouldCauseDataLoss reports loss, the
+// computed value is still returned together with a "rescale data loss" error.
+func rescale[T interface {
+	Decimal32 | Decimal64
+	rescaleWouldCauseDataLoss(int32, T) (T, bool)
+	Sign() int
+}](n T, originalScale, newScale int32) (out T, err error) {
+	if originalScale == newScale {
+		return n, nil
+	}
+
+	deltaScale := newScale - originalScale
+	// Take the absolute value with integers, widened to int64 so the power of
+	// ten is computed with 64-bit operands. NOTE(review): intPow appears to
+	// return its operand type, so the previous int32 call would overflow for
+	// |deltaScale| > 9 - confirm against intPow.
+	absDeltaScale := int64(deltaScale)
+	if absDeltaScale < 0 {
+		absDeltaScale = -absDeltaScale
+	}
+
+	// work on the magnitude and restore the sign at the end
+	sign := n.Sign()
+	if n < 0 {
+		n = -n
+	}
+
+	multiplier := T(intPow(10, absDeltaScale))
+	var wouldHaveLoss bool
+	out, wouldHaveLoss = n.rescaleWouldCauseDataLoss(deltaScale, multiplier)
+	if wouldHaveLoss {
+		err = errors.New("rescale data loss")
+	}
+	out *= T(sign)
+	return out, err
+}
+
+// Rescale converts n from originalScale to newScale via rescale, returning
+// its error when the conversion loses data.
+func (n Decimal32) Rescale(originalScale, newScale int32) (out Decimal32, err error) {
+	out, err = rescale[Decimal32](n, originalScale, newScale)
+	return out, err
+}
+
+// Rescale converts n from originalScale to newScale via rescale, returning
+// its error when the conversion loses data.
+func (n Decimal64) Rescale(originalScale, newScale int32) (out Decimal64, err error) {
+	out, err = rescale[Decimal64](n, originalScale, newScale)
+	return out, err
+}
+
+// Compile-time assertions that each decimal type can be assigned to the Num
+// interface instantiated with itself.
+var (
+ _ Num[Decimal32] = Decimal32(0)
+ _ Num[Decimal64] = Decimal64(0)
+ _ Num[Decimal128] = Decimal128{}
+ _ Num[Decimal256] = Decimal256{}
+)
diff --git a/arrow/decimal/decimal_test.go b/arrow/decimal/decimal_test.go
new file mode 100644
index 0000000..9b4f04f
--- /dev/null
+++ b/arrow/decimal/decimal_test.go
@@ -0,0 +1,470 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package decimal_test
+
+import (
+ "fmt"
+ "math"
+ "math/big"
+ "strconv"
+ "testing"
+
+ "github.com/apache/arrow-go/v18/arrow/decimal"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// ulps64 returns the distance between expected and actual, measured in units
+// of the spacing between actual and the next float64 toward +Inf, truncated
+// to an integer.
+func ulps64(actual, expected float64) int64 {
+	next := math.Nextafter(actual, math.Inf(1))
+	spacing := next - actual
+	diff := expected - actual
+	return int64(math.Abs(diff / spacing))
+}
+
+// ulps32 is the float32 counterpart of ulps64.
+func ulps32(actual, expected float32) int64 {
+	next := math.Nextafter32(actual, float32(math.Inf(1)))
+	spacing := next - actual
+	diff := expected - actual
+	return int64(math.Abs(float64(diff / spacing)))
+}
+
+// assertFloat32Approx asserts that y is within 4 ULPs of x (as measured by
+// ulps32) and returns the result of the assertion.
+func assertFloat32Approx(t *testing.T, x, y float32) bool {
+	const maxulps int64 = 4
+	distance := ulps32(x, y)
+	return assert.LessOrEqualf(t, distance, maxulps, "%f not equal to %f (%d ulps)", x, y, distance)
+}
+
+// assertFloat64Approx asserts that y is within 4 ULPs of x (as measured by
+// ulps64) and returns the result of the assertion.
+func assertFloat64Approx(t *testing.T, x, y float64) bool {
+	const maxulps int64 = 4
+	distance := ulps64(x, y)
+	return assert.LessOrEqualf(t, distance, maxulps, "%f not equal to %f (%d ulps)", x, y, distance)
+}
+
+// TestDecimalToReal parses decimal strings into Decimal32 and Decimal64 values
+// and checks ToFloat32, ToFloat64 and ToBigFloat for a range of scales. The
+// table cases are compared exactly (including their negated forms); the
+// "large values" sub-tests sweep the scale and compare within 4 ULPs.
+func TestDecimalToReal(t *testing.T) {
+ tests := []struct {
+ decimalVal string
+ scale int32
+ exp float64
+ }{
+ {"0", 0, 0},
+ {"0", 10, 0.0},
+ {"0", -10, 0.0},
+ {"1", 0, 1.0},
+ {"12345", 0, 12345.0},
+ {"12345", 1, 1234.5},
+ {"536870912", 0, math.Pow(2, 29)},
+ }
+
+ t.Run("float32", func(t *testing.T) {
+ checkDecimalToFloat := func(t *testing.T, str string, v
float32, scale int32) {
+ n, err := decimal.Decimal32FromString(str, 9, 0)
+ require.NoError(t, err)
+ assert.Equalf(t, v, n.ToFloat32(scale), "Decimal Val:
%s, Scale: %d", str, scale)
+
+ n64, err := decimal.Decimal64FromString(str, 18, 0)
+ require.NoError(t, err)
+ assert.Equalf(t, v, n64.ToFloat32(scale), "Decimal Val:
%s, Scale: %d", str, scale)
+ }
+ for _, tt := range tests {
+ t.Run(tt.decimalVal, func(t *testing.T) {
+ checkDecimalToFloat(t, tt.decimalVal,
float32(tt.exp), tt.scale)
+ if tt.decimalVal != "0" {
+ checkDecimalToFloat(t,
"-"+tt.decimalVal, float32(-tt.exp), tt.scale)
+ }
+ })
+ }
+
+ t.Run("large values", func(t *testing.T) {
+ checkApproxDecimaltoFloat := func(str string, v
float32, scale int32) {
+ n, err := decimal.Decimal32FromString(str, 9, 0)
+ require.NoError(t, err)
+ assertFloat32Approx(t, v, n.ToFloat32(scale))
+ }
+
+ checkApproxDecimal64toFloat := func(str string, v
float32, scale int32) {
+ n, err := decimal.Decimal64FromString(str, 9, 0)
+ require.NoError(t, err)
+ assertFloat32Approx(t, v, n.ToFloat32(scale))
+ }
+
+ // exact comparisons would succeed on most platforms,
but not all power-of-ten
+ // factors are exactly representable in binary floating
point, so we'll use
+ // approx and ensure that the values are within 4 ULP
(unit of least precision)
+ for scale := int32(-9); scale <= 9; scale++ {
+ checkApproxDecimaltoFloat("1",
float32(math.Pow10(-int(scale))), scale)
+ checkApproxDecimaltoFloat("123",
float32(123)*float32(math.Pow10(-int(scale))), scale)
+ }
+
+ for scale := int32(-18); scale <= 18; scale++ {
+ checkApproxDecimal64toFloat("1",
float32(math.Pow10(-int(scale))), scale)
+ checkApproxDecimal64toFloat("123",
float32(123)*float32(math.Pow10(-int(scale))), scale)
+ }
+ })
+ })
+
+ t.Run("float64", func(t *testing.T) {
+ checkDecimalToFloat := func(t *testing.T, str string, v
float64, scale int32) {
+ n, err := decimal.Decimal32FromString(str, 9, 0)
+ require.NoError(t, err)
+ assert.Equalf(t, v, n.ToFloat64(scale), "Decimal Val:
%s, Scale: %d", str, scale)
+
+ assert.Equalf(t, big.NewFloat(v).SetPrec(32),
n.ToBigFloat(scale),
+ "Decimal Val: %s, Scale: %d", str, scale)
+
+ n64, err := decimal.Decimal64FromString(str, 18, 0)
+ require.NoError(t, err)
+ assert.Equalf(t, v, n64.ToFloat64(scale), "Decimal Val:
%s, Scale: %d", str, scale)
+ assert.Equalf(t, big.NewFloat(v).SetPrec(64),
n64.ToBigFloat(scale),
+ "Decimal Val: %s, Scale: %d", str, scale)
+ }
+ for _, tt := range tests {
+ t.Run(tt.decimalVal, func(t *testing.T) {
+ checkDecimalToFloat(t, tt.decimalVal, tt.exp,
tt.scale)
+ if tt.decimalVal != "0" {
+ checkDecimalToFloat(t,
"-"+tt.decimalVal, -tt.exp, tt.scale)
+ }
+ })
+ }
+
+ t.Run("large values", func(t *testing.T) {
+ checkApproxDecimaltoFloat := func(str string, v
float64, scale int32) {
+ n, err := decimal.Decimal32FromString(str, 9, 0)
+ require.NoError(t, err)
+ assertFloat64Approx(t, v, n.ToFloat64(scale))
+
+ assert.Equalf(t, big.NewFloat(v).SetPrec(32),
n.ToBigFloat(scale),
+ "Decimal Val: %s, Scale: %d", str,
scale)
+ }
+
+ checkApproxDecimal64toFloat := func(str string, v
float64, scale int32) {
+ n, err := decimal.Decimal64FromString(str, 9, 0)
+ require.NoError(t, err)
+ assertFloat64Approx(t, v, n.ToFloat64(scale))
+
+ bf, _ := n.ToBigFloat(scale).Float64()
+ assertFloat64Approx(t, v, bf)
+ }
+
+ // exact comparisons would succeed on most platforms,
but not all power-of-ten
+ // factors are exactly representable in binary floating
point, so we'll use
+ // approx and ensure that the values are within 4 ULP
(unit of least precision)
+ for scale := int32(-9); scale <= 9; scale++ {
+ checkApproxDecimaltoFloat("1",
math.Pow10(-int(scale)), scale)
+ checkApproxDecimaltoFloat("123",
float64(123)*math.Pow10(-int(scale)), scale)
+ }
+
+ for scale := int32(-18); scale <= 18; scale++ {
+ checkApproxDecimal64toFloat("1",
math.Pow10(-int(scale)), scale)
+ checkApproxDecimal64toFloat("123",
float64(123)*math.Pow10(-int(scale)), scale)
+ }
+ })
+ })
+}
+
+// TestDecimalFromFloat exercises Decimal32FromFloat and Decimal64FromFloat
+// with float64 and float32 inputs. Table cases are converted back with
+// ToFloat64/ToFloat32, formatted with the case's scale and compared with the
+// expected string; the "large values" sub-tests sweep powers of ten across
+// the float range and check the resulting integer value.
+func TestDecimalFromFloat(t *testing.T) {
+ tests := []struct {
+ val float64
+ precision, scale int32
+ expected string
+ }{
+ {0, 1, 0, "0"},
+ {-0, 1, 0, "0"},
+ {0, 9, 4, "0.0000"},
+ {math.Copysign(0.0, -1), 9, 4, "0.0000"},
+ {123, 7, 4, "123.0000"},
+ {-123, 7, 4, "-123.0000"},
+ {456.78, 7, 4, "456.7800"},
+ {-456.78, 7, 4, "-456.7800"},
+ {456.784, 5, 2, "456.78"},
+ {-456.784, 5, 2, "-456.78"},
+ {456.786, 5, 2, "456.79"},
+ {-456.786, 5, 2, "-456.79"},
+ {999.99, 5, 2, "999.99"},
+ {-999.99, 5, 2, "-999.99"},
+ {123, 9, 0, "123"},
+ {-123, 9, 0, "-123"},
+ {123.4, 9, 0, "123"},
+ {-123.4, 9, 0, "-123"},
+ {123.6, 9, 0, "124"},
+ {-123.6, 9, 0, "-124"},
+ }
+
+ t.Run("float64", func(t *testing.T) {
+ for _, tt := range tests {
+ t.Run(tt.expected, func(t *testing.T) {
+ n, err := decimal.Decimal32FromFloat(tt.val,
tt.precision, tt.scale)
+ require.NoError(t, err)
+ assert.Equal(t, tt.expected,
fmt.Sprintf("%."+strconv.Itoa(int(tt.scale))+"f", n.ToFloat64(tt.scale)))
+ })
+ }
+
+ t.Run("large values", func(t *testing.T) {
+ // test entire float64 range
+ for scale := int32(-308); scale <= 308; scale++ {
+ val := math.Pow10(int(scale))
+ n, err := decimal.Decimal64FromFloat(val, 1,
-scale)
+ require.NoError(t, err)
+ assert.EqualValues(t, 1, n)
+ }
+
+ for scale := int32(-307); scale <= 306; scale++ {
+ val := 123 * math.Pow10(int(scale))
+ n, err := decimal.Decimal64FromFloat(val, 2,
-scale-1)
+ require.NoError(t, err)
+ assert.EqualValues(t, 12, n)
+ n, err = decimal.Decimal64FromFloat(val, 3,
-scale)
+ require.NoError(t, err)
+ assert.EqualValues(t, 123, n)
+ n, err = decimal.Decimal64FromFloat(val, 4,
-scale+1)
+ require.NoError(t, err)
+ assert.EqualValues(t, 1230, n)
+ }
+ })
+ })
+
+ t.Run("float32", func(t *testing.T) {
+ for _, tt := range tests {
+ t.Run(tt.expected, func(t *testing.T) {
+ n, err :=
decimal.Decimal32FromFloat(float32(tt.val), tt.precision, tt.scale)
+ require.NoError(t, err)
+ assert.Equal(t, tt.expected,
fmt.Sprintf("%."+strconv.Itoa(int(tt.scale))+"f", n.ToFloat32(tt.scale)))
+ })
+ }
+
+ t.Run("large values", func(t *testing.T) {
+ // test entire float32 range
+ for scale := int32(-38); scale <= 38; scale++ {
+ val := float32(math.Pow10(int(scale)))
+ n, err := decimal.Decimal64FromFloat(val, 1,
-scale)
+ require.NoError(t, err)
+ assert.EqualValues(t, 1, n)
+ }
+
+ for scale := int32(-37); scale <= 36; scale++ {
+ val := 123 * float32(math.Pow10(int(scale)))
+ n, err := decimal.Decimal64FromFloat(val, 2,
-scale-1)
+ require.NoError(t, err)
+ assert.EqualValues(t, 12, n)
+ n, err = decimal.Decimal64FromFloat(val, 3,
-scale)
+ require.NoError(t, err)
+ assert.EqualValues(t, 123, n)
+ n, err = decimal.Decimal64FromFloat(val, 4,
-scale+1)
+ require.NoError(t, err)
+ assert.EqualValues(t, 1230, n)
+ }
+ })
+ })
+}
+
+// TestFromString parses each input string with precision 8 and the listed
+// scale through Decimal32FromString and Decimal64FromString and checks that
+// both yield the expected integer value.
+func TestFromString(t *testing.T) {
+ tests := []struct {
+ s string
+ expected int64
+ expectedScale int32
+ }{
+ {"12.3", 123, 1},
+ {"0.00123", 123, 5},
+ {"1.23e-8", 123, 10},
+ {"-1.23E-8", -123, 10},
+ {"1.23e+3", 1230, 0},
+ {"-1.23E+3", -1230, 0},
+ {"1.23e+5", 123000, 0},
+ {"1.2345E+7", 12345000, 0},
+ {"1.23e-8", 123, 10},
+ {"-1.23E-8", -123, 10},
+ {"0000000", 0, 0},
+ {"000.0000", 0, 4},
+ {".0000", 0, 5},
+ {"1e1", 10, 0},
+ {"+234.567", 234567, 3},
+ {"1e-8", 1, 8},
+ {"2112.33", 211233, 2},
+ {"-2112.33", -211233, 2},
+ {"12E2", 12, -2},
+ }
+
+ for _, tt := range tests {
+ t.Run(fmt.Sprintf("%s_%d", tt.s, tt.expectedScale), func(t
*testing.T) {
+ n, err := decimal.Decimal32FromString(tt.s, 8,
tt.expectedScale)
+ require.NoError(t, err)
+
+ ex := decimal.Decimal32(tt.expected)
+ assert.Equal(t, ex, n)
+
+ n64, err := decimal.Decimal64FromString(tt.s, 8,
tt.expectedScale)
+ require.NoError(t, err)
+
+ ex64 := decimal.Decimal64(tt.expected)
+ assert.Equal(t, ex64, n64)
+ })
+ }
+}
+
+// TestCmp checks that Decimal32.Cmp and Decimal64.Cmp return 1, -1 and 0 for
+// greater, smaller and equal operands, using positive and negative values.
+func TestCmp(t *testing.T) {
+ for _, tc := range []struct {
+ n decimal.Decimal32
+ rhs decimal.Decimal32
+ want int
+ }{
+ {decimal.Decimal32(2), decimal.Decimal32(1), 1},
+ {decimal.Decimal32(-1), decimal.Decimal32(-2), 1},
+ {decimal.Decimal32(2), decimal.Decimal32(3), -1},
+ {decimal.Decimal32(-3), decimal.Decimal32(-2), -1},
+ {decimal.Decimal32(2), decimal.Decimal32(2), 0},
+ {decimal.Decimal32(-2), decimal.Decimal32(-2), 0},
+ } {
+ t.Run("cmp", func(t *testing.T) {
+ n := tc.n.Cmp(tc.rhs)
+ if got, want := n, tc.want; got != want {
+ t.Fatalf("invalid value. got=%v, want=%v", got,
want)
+ }
+ })
+ }
+
+ for _, tc := range []struct {
+ n decimal.Decimal64
+ rhs decimal.Decimal64
+ want int
+ }{
+ {decimal.Decimal64(2), decimal.Decimal64(1), 1},
+ {decimal.Decimal64(-1), decimal.Decimal64(-2), 1},
+ {decimal.Decimal64(2), decimal.Decimal64(3), -1},
+ {decimal.Decimal64(-3), decimal.Decimal64(-2), -1},
+ {decimal.Decimal64(2), decimal.Decimal64(2), 0},
+ {decimal.Decimal64(-2), decimal.Decimal64(-2), 0},
+ } {
+ t.Run("cmp", func(t *testing.T) {
+ n := tc.n.Cmp(tc.rhs)
+ if got, want := n, tc.want; got != want {
+ t.Fatalf("invalid value. got=%v, want=%v", got,
want)
+ }
+ })
+ }
+}
+
+// TestDecimalRescale checks Rescale on Decimal32 and Decimal64: the table
+// holds lossless conversions between scales, followed by cases that must
+// return an error (a non-zero remainder when reducing the scale, and a
+// "rescale data loss" error when increasing it).
+func TestDecimalRescale(t *testing.T) {
+ tests := []struct {
+ orig, exp int32
+ oldScale, newScale int32
+ }{
+ {111, 11100, 0, 2},
+ {11100, 111, 2, 0},
+ {500000, 5, 6, 1},
+ {5, 500000, 1, 6},
+ {-111, -11100, 0, 2},
+ {-11100, -111, 2, 0},
+ {555, 555, 2, 2},
+ }
+
+ for _, tt := range tests {
+ t.Run("decimal32", func(t *testing.T) {
+ out, err :=
decimal.Decimal32(tt.orig).Rescale(tt.oldScale, tt.newScale)
+ require.NoError(t, err)
+ assert.Equal(t, decimal.Decimal32(tt.exp), out)
+ })
+ t.Run("decimal64", func(t *testing.T) {
+ out, err :=
decimal.Decimal64(tt.orig).Rescale(tt.oldScale, tt.newScale)
+ require.NoError(t, err)
+ assert.Equal(t, decimal.Decimal64(tt.exp), out)
+ })
+ }
+
+ _, err := decimal.Decimal32(555555).Rescale(6, 1)
+ assert.Error(t, err)
+ _, err = decimal.Decimal64(555555).Rescale(6, 1)
+ assert.Error(t, err)
+
+ _, err = decimal.Decimal32(555555).Rescale(0, 5)
+ assert.ErrorContains(t, err, "rescale data loss")
+ _, err = decimal.Decimal64(555555).Rescale(0, 5)
+ assert.ErrorContains(t, err, "rescale data loss")
+}
+
+// TestDecimalIncreaseScale checks IncreaseScaleBy on Decimal32 and Decimal64
+// for an increase of 0 and of 3, with positive and negative values.
+func TestDecimalIncreaseScale(t *testing.T) {
+ assert.Equal(t, decimal.Decimal32(1234),
decimal.Decimal32(1234).IncreaseScaleBy(0))
+ assert.Equal(t, decimal.Decimal32(1234000),
decimal.Decimal32(1234).IncreaseScaleBy(3))
+ assert.Equal(t, decimal.Decimal32(-1234000),
decimal.Decimal32(-1234).IncreaseScaleBy(3))
+
+ assert.Equal(t, decimal.Decimal64(1234),
decimal.Decimal64(1234).IncreaseScaleBy(0))
+ assert.Equal(t, decimal.Decimal64(1234000),
decimal.Decimal64(1234).IncreaseScaleBy(3))
+ assert.Equal(t, decimal.Decimal64(-1234000),
decimal.Decimal64(-1234).IncreaseScaleBy(3))
+}
+
+// TestDecimalReduceScale checks ReduceScaleBy on Decimal32 and Decimal64 with
+// and without rounding. Every case is also run on the negated value and is
+// expected to produce the negated result.
+func TestDecimalReduceScale(t *testing.T) {
+ tests := []struct {
+ value int32
+ scale int32
+ round bool
+ expected int32
+ }{
+ {123456, 0, false, 123456},
+ {123456, 1, false, 12345},
+ {123456, 1, true, 12346},
+ {123451, 1, true, 12345},
+ {123789, 2, true, 1238},
+ {123749, 2, true, 1237},
+ {123750, 2, true, 1238},
+ {5, 1, true, 1},
+ {0, 1, true, 0},
+ }
+
+ for _, tt := range tests {
+ assert.Equal(t, decimal.Decimal32(tt.expected),
+ decimal.Decimal32(tt.value).ReduceScaleBy(tt.scale,
tt.round), "decimal32")
+ assert.Equal(t, decimal.Decimal32(tt.expected).Negate(),
+
decimal.Decimal32(tt.value).Negate().ReduceScaleBy(tt.scale, tt.round),
"decimal32")
+ assert.Equal(t, decimal.Decimal64(tt.expected),
+ decimal.Decimal64(tt.value).ReduceScaleBy(tt.scale,
tt.round), "decimal64")
+ assert.Equal(t, decimal.Decimal64(tt.expected).Negate(),
+
decimal.Decimal64(tt.value).Negate().ReduceScaleBy(tt.scale, tt.round),
"decimal64")
+ }
+}
+
+// TestDecimalBasics compares Add, Sub, Mul and Div on Decimal32 and Decimal64
+// with the results of the corresponding int32 operators for operand pairs of
+// mixed signs.
+func TestDecimalBasics(t *testing.T) {
+ tests := []struct {
+ lhs, rhs int32
+ }{
+ {100, 3},
+ {200, 3},
+ {20100, 301},
+ {-20100, 301},
+ {20100, -301},
+ {-20100, -301},
+ }
+
+ for _, tt := range tests {
+ assert.EqualValues(t, tt.lhs+tt.rhs,
+
decimal.Decimal32(tt.lhs).Add(decimal.Decimal32(tt.rhs)))
+ assert.EqualValues(t, tt.lhs+tt.rhs,
+
decimal.Decimal64(tt.lhs).Add(decimal.Decimal64(tt.rhs)))
+
+ assert.EqualValues(t, tt.lhs-tt.rhs,
+
decimal.Decimal32(tt.lhs).Sub(decimal.Decimal32(tt.rhs)))
+ assert.EqualValues(t, tt.lhs-tt.rhs,
+
decimal.Decimal64(tt.lhs).Sub(decimal.Decimal64(tt.rhs)))
+
+ assert.EqualValues(t, tt.lhs*tt.rhs,
+
decimal.Decimal32(tt.lhs).Mul(decimal.Decimal32(tt.rhs)))
+ assert.EqualValues(t, tt.lhs*tt.rhs,
+
decimal.Decimal64(tt.lhs).Mul(decimal.Decimal64(tt.rhs)))
+
+ expdiv, expmod := tt.lhs/tt.rhs, tt.lhs%tt.rhs
+ div, mod :=
decimal.Decimal32(tt.lhs).Div(decimal.Decimal32(tt.rhs))
+ assert.EqualValues(t, expdiv, div)
+ assert.EqualValues(t, expmod, mod)
+
+ div64, mod64 :=
decimal.Decimal64(tt.lhs).Div(decimal.Decimal64(tt.rhs))
+ assert.EqualValues(t, expdiv, div64)
+ assert.EqualValues(t, expmod, mod64)
+ }
+}
diff --git a/arrow/decimal/traits.go b/arrow/decimal/traits.go
new file mode 100644
index 0000000..0ec0c31
--- /dev/null
+++ b/arrow/decimal/traits.go
@@ -0,0 +1,78 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package decimal
+
+// Traits is a convenience for building generic objects that operate on
+// Decimal values, working around the limitations of Go generics. By accepting
+// this interface, a generic object can produce new decimal values of the
+// proper concrete type.
+type Traits[T DecimalTypes] interface {
+ // BytesRequired takes an int and returns an int; the implementations in
+ // this file return the value width in bytes multiplied by the argument.
+ BytesRequired(int) int
+ // FromString builds a T from a string, a precision and a scale.
+ FromString(string, int32, int32) (T, error)
+ // FromFloat64 builds a T from a float64, a precision and a scale.
+ FromFloat64(float64, int32, int32) (T, error)
+}
+
+// Package-level instances of the stateless trait types declared below, one
+// per decimal width.
+var (
+ Dec32Traits dec32Traits
+ Dec64Traits dec64Traits
+ Dec128Traits dec128Traits
+ Dec256Traits dec256Traits
+)
+
+// Empty struct types that carry the BytesRequired, FromString and FromFloat64
+// methods for the 32, 64, 128 and 256 bit decimal types.
+type (
+ dec32Traits struct{}
+ dec64Traits struct{}
+ dec128Traits struct{}
+ dec256Traits struct{}
+)
+
+// BytesRequired returns 4*n, the size in bytes of n 32-bit values.
+func (dec32Traits) BytesRequired(n int) int {
+	const width = 4 // bytes per value
+	return width * n
+}
+
+// BytesRequired returns 8*n, the size in bytes of n 64-bit values.
+func (dec64Traits) BytesRequired(n int) int {
+	const width = 8 // bytes per value
+	return width * n
+}
+
+// BytesRequired returns 16*n, the size in bytes of n 128-bit values.
+func (dec128Traits) BytesRequired(n int) int {
+	const width = 16 // bytes per value
+	return width * n
+}
+
+// BytesRequired returns 32*n, the size in bytes of n 256-bit values.
+func (dec256Traits) BytesRequired(n int) int {
+	const width = 32 // bytes per value
+	return width * n
+}
+
+// FromString converts v into a Decimal32 via Decimal32FromString.
+func (dec32Traits) FromString(v string, prec, scale int32) (Decimal32, error) {
+	n, err := Decimal32FromString(v, prec, scale)
+	return n, err
+}
+
+// FromString converts v into a Decimal64 via Decimal64FromString.
+func (dec64Traits) FromString(v string, prec, scale int32) (Decimal64, error) {
+	n, err := Decimal64FromString(v, prec, scale)
+	return n, err
+}
+
+// FromString converts v into a Decimal128 via Decimal128FromString.
+func (dec128Traits) FromString(v string, prec, scale int32) (Decimal128, error) {
+	n, err := Decimal128FromString(v, prec, scale)
+	return n, err
+}
+
+// FromString converts v into a Decimal256 via Decimal256FromString.
+func (dec256Traits) FromString(v string, prec, scale int32) (Decimal256, error) {
+	n, err := Decimal256FromString(v, prec, scale)
+	return n, err
+}
+
+// FromFloat64 converts v into a Decimal32 via Decimal32FromFloat.
+func (dec32Traits) FromFloat64(v float64, prec, scale int32) (Decimal32, error) {
+	n, err := Decimal32FromFloat(v, prec, scale)
+	return n, err
+}
+
+// FromFloat64 converts v into a Decimal64 via Decimal64FromFloat.
+func (dec64Traits) FromFloat64(v float64, prec, scale int32) (Decimal64, error) {
+	n, err := Decimal64FromFloat(v, prec, scale)
+	return n, err
+}
+
+// FromFloat64 converts v into a Decimal128 via Decimal128FromFloat.
+func (dec128Traits) FromFloat64(v float64, prec, scale int32) (Decimal128, error) {
+	n, err := Decimal128FromFloat(v, prec, scale)
+	return n, err
+}
+
+// FromFloat64 converts v into a Decimal256 via Decimal256FromFloat.
+func (dec256Traits) FromFloat64(v float64, prec, scale int32) (Decimal256, error) {
+	n, err := Decimal256FromFloat(v, prec, scale)
+	return n, err
+}
diff --git a/arrow/decimal128/decimal128.go b/arrow/decimal128/decimal128.go
index 2e451c1..660c413 100644
--- a/arrow/decimal128/decimal128.go
+++ b/arrow/decimal128/decimal128.go
@@ -327,6 +327,16 @@ func (n Num) ToFloat64(scale int32) float64 {
return n.tofloat64Positive(scale)
}
+// ToBigFloat returns n interpreted with the given scale as a *big.Float with
+// 128 bits of mantissa precision. Negative scales multiply by the matching
+// entry of scaleMultipliers, non-negative scales divide by it.
+func (n Num) ToBigFloat(scale int32) *big.Float {
+	f := new(big.Float).SetInt(n.BigInt()).SetPrec(128)
+	if scale < 0 {
+		mult := new(big.Float).SetInt(scaleMultipliers[-scale].BigInt())
+		return f.Mul(f, mult)
+	}
+
+	div := new(big.Float).SetInt(scaleMultipliers[scale].BigInt())
+	return f.Quo(f, div)
+}
+}
+
// LowBits returns the low bits of the two's complement representation of the
number.
func (n Num) LowBits() uint64 { return n.lo }
diff --git a/arrow/decimal256/decimal256.go b/arrow/decimal256/decimal256.go
index 76b6185..82c52a6 100644
--- a/arrow/decimal256/decimal256.go
+++ b/arrow/decimal256/decimal256.go
@@ -339,6 +339,16 @@ func (n Num) ToFloat64(scale int32) float64 {
return n.tofloat64Positive(scale)
}
+// ToBigFloat returns n interpreted with the given scale as a *big.Float with
+// 256 bits of mantissa precision. Negative scales multiply by the matching
+// entry of scaleMultipliers, non-negative scales divide by it.
+func (n Num) ToBigFloat(scale int32) *big.Float {
+	f := new(big.Float).SetInt(n.BigInt()).SetPrec(256)
+	if scale < 0 {
+		mult := new(big.Float).SetInt(scaleMultipliers[-scale].BigInt())
+		return f.Mul(f, mult)
+	}
+
+	div := new(big.Float).SetInt(scaleMultipliers[scale].BigInt())
+	return f.Quo(f, div)
+}
+}
+
func (n Num) Sign() int {
if n == (Num{}) {
return 0
diff --git a/arrow/internal/arrjson/arrjson.go
b/arrow/internal/arrjson/arrjson.go
index 452809e..2181ebd 100644
--- a/arrow/internal/arrjson/arrjson.go
+++ b/arrow/internal/arrjson/arrjson.go
@@ -29,6 +29,7 @@ import (
"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/array"
"github.com/apache/arrow-go/v18/arrow/bitutil"
+ "github.com/apache/arrow-go/v18/arrow/decimal"
"github.com/apache/arrow-go/v18/arrow/decimal128"
"github.com/apache/arrow-go/v18/arrow/decimal256"
"github.com/apache/arrow-go/v18/arrow/float16"
@@ -224,6 +225,10 @@ func typeToJSON(arrowType arrow.DataType)
(json.RawMessage, error) {
typ = listSizeJSON{"fixedsizelist", dt.Len()}
case *arrow.FixedSizeBinaryType:
typ = byteWidthJSON{"fixedsizebinary", dt.ByteWidth}
+ case *arrow.Decimal32Type:
+ typ = decimalJSON{"decimal", int(dt.Scale), int(dt.Precision),
32}
+ case *arrow.Decimal64Type:
+ typ = decimalJSON{"decimal", int(dt.Scale), int(dt.Precision),
64}
case *arrow.Decimal128Type:
typ = decimalJSON{"decimal", int(dt.Scale), int(dt.Precision),
128}
case *arrow.Decimal256Type:
@@ -491,6 +496,10 @@ func typeFromJSON(typ json.RawMessage, children
[]FieldWrapper) (arrowType arrow
arrowType = &arrow.Decimal256Type{Precision:
int32(t.Precision), Scale: int32(t.Scale)}
case 128, 0: // default to 128 bits when missing
arrowType = &arrow.Decimal128Type{Precision:
int32(t.Precision), Scale: int32(t.Scale)}
+ case 64:
+ arrowType = &arrow.Decimal64Type{Precision:
int32(t.Precision), Scale: int32(t.Scale)}
+ case 32:
+ arrowType = &arrow.Decimal32Type{Precision:
int32(t.Precision), Scale: int32(t.Scale)}
}
case "union":
t := unionJSON{}
@@ -1295,6 +1304,22 @@ func arrayFromJSON(mem memory.Allocator, dt
arrow.DataType, arr Array) arrow.Arr
bldr.AppendValues(data, valids)
return returnNewArrayData(bldr)
+ case *arrow.Decimal32Type:
+ bldr := array.NewDecimal32Builder(mem, dt)
+ defer bldr.Release()
+ data := decimal32FromJSON(arr.Data)
+ valids := validsFromJSON(arr.Valids)
+ bldr.AppendValues(data, valids)
+ return returnNewArrayData(bldr)
+
+ case *arrow.Decimal64Type:
+ bldr := array.NewDecimal64Builder(mem, dt)
+ defer bldr.Release()
+ data := decimal64FromJSON(arr.Data)
+ valids := validsFromJSON(arr.Valids)
+ bldr.AppendValues(data, valids)
+ return returnNewArrayData(bldr)
+
case *arrow.Decimal128Type:
bldr := array.NewDecimal128Builder(mem, dt)
defer bldr.Release()
@@ -1713,6 +1738,22 @@ func arrayToJSON(field arrow.Field, arr arrow.Array)
Array {
Valids: validsToJSON(arr),
}
+ case *array.Decimal32:
+ return Array{
+ Name: field.Name,
+ Count: arr.Len(),
+ Data: decimal32ToJSON(arr),
+ Valids: validsToJSON(arr),
+ }
+
+ case *array.Decimal64:
+ return Array{
+ Name: field.Name,
+ Count: arr.Len(),
+ Data: decimal64ToJSON(arr),
+ Valids: validsToJSON(arr),
+ }
+
case *array.Decimal128:
return Array{
Name: field.Name,
@@ -2038,6 +2079,47 @@ func f64ToJSON(arr *array.Float64) []interface{} {
return o
}
+// decimal32ToJSON returns one entry per element of arr, each holding the
+// string produced by arr.ValueStr for that index.
+func decimal32ToJSON(arr *array.Decimal32) []interface{} {
+	n := arr.Len()
+	out := make([]interface{}, 0, n)
+	for i := 0; i < n; i++ {
+		out = append(out, arr.ValueStr(i))
+	}
+	return out
+}
+
+// decimal32FromJSON converts a slice of JSON string values holding integers
+// into Decimal32 values.
+//
+// It panics when an element is not a string, cannot be parsed as an integer,
+// or does not fit into 32 bits (previously such values were silently
+// truncated).
+func decimal32FromJSON(vs []interface{}) []decimal.Decimal32 {
+	var tmp big.Int
+	o := make([]decimal.Decimal32, len(vs))
+	for i, v := range vs {
+		s, ok := v.(string)
+		if !ok {
+			panic(fmt.Errorf("could not convert %v (%T) to decimal32: expected a string", v, v))
+		}
+		if err := tmp.UnmarshalJSON([]byte(s)); err != nil {
+			panic(fmt.Errorf("could not convert %v (%T) to decimal32: %w", v, v, err))
+		}
+
+		// the round trip through int32 only preserves in-range values
+		if !tmp.IsInt64() || tmp.Int64() != int64(int32(tmp.Int64())) {
+			panic(fmt.Errorf("could not convert %v (%T) to decimal32: value out of range", v, v))
+		}
+
+		o[i] = decimal.Decimal32(tmp.Int64())
+	}
+	return o
+}
+
+// decimal64ToJSON returns one entry per element of arr, each holding the
+// string produced by arr.ValueStr for that index.
+func decimal64ToJSON(arr *array.Decimal64) []interface{} {
+	n := arr.Len()
+	out := make([]interface{}, 0, n)
+	for i := 0; i < n; i++ {
+		out = append(out, arr.ValueStr(i))
+	}
+	return out
+}
+
+// decimal64FromJSON converts a slice of JSON string values holding integers
+// into Decimal64 values.
+//
+// It panics when an element is not a string, cannot be parsed as an integer,
+// or does not fit into 64 bits (previously such values were silently
+// truncated).
+func decimal64FromJSON(vs []interface{}) []decimal.Decimal64 {
+	var tmp big.Int
+	o := make([]decimal.Decimal64, len(vs))
+	for i, v := range vs {
+		s, ok := v.(string)
+		if !ok {
+			panic(fmt.Errorf("could not convert %v (%T) to decimal64: expected a string", v, v))
+		}
+		if err := tmp.UnmarshalJSON([]byte(s)); err != nil {
+			panic(fmt.Errorf("could not convert %v (%T) to decimal64: %w", v, v, err))
+		}
+
+		// Int64 is undefined for values outside the int64 range
+		if !tmp.IsInt64() {
+			panic(fmt.Errorf("could not convert %v (%T) to decimal64: value out of range", v, v))
+		}
+
+		o[i] = decimal.Decimal64(tmp.Int64())
+	}
+	return o
+}
func decimal128ToJSON(arr *array.Decimal128) []interface{} {
o := make([]interface{}, arr.Len())
for i := range o {
@@ -2072,7 +2154,7 @@ func decimal256FromJSON(vs []interface{})
[]decimal256.Num {
o := make([]decimal256.Num, len(vs))
for i, v := range vs {
if err := tmp.UnmarshalJSON([]byte(v.(string))); err != nil {
- panic(fmt.Errorf("could not convert %v (%T) to
decimal128: %w", v, v, err))
+ panic(fmt.Errorf("could not convert %v (%T) to
decimal256: %w", v, v, err))
}
o[i] = decimal256.FromBigInt(&tmp)
diff --git a/arrow/internal/flatbuf/Decimal.go
b/arrow/internal/flatbuf/Decimal.go
index 2fc9d5a..234c396 100644
--- a/arrow/internal/flatbuf/Decimal.go
+++ b/arrow/internal/flatbuf/Decimal.go
@@ -22,10 +22,10 @@ import (
flatbuffers "github.com/google/flatbuffers/go"
)
-// / Exact decimal value represented as an integer value in two's
-// / complement. Currently only 128-bit (16-byte) and 256-bit (32-byte)
integers
-// / are used. The representation uses the endianness indicated
-// / in the Schema.
+/// Exact decimal value represented as an integer value in two's
+/// complement. Currently 32-bit (4-byte), 64-bit (8-byte),
+/// 128-bit (16-byte) and 256-bit (32-byte) integers are used.
+/// The representation uses the endianness indicated in the Schema.
type Decimal struct {
_tab flatbuffers.Table
}
@@ -46,7 +46,7 @@ func (rcv *Decimal) Table() flatbuffers.Table {
return rcv._tab
}
-// / Total number of decimal digits
+/// Total number of decimal digits
func (rcv *Decimal) Precision() int32 {
o := flatbuffers.UOffsetT(rcv._tab.Offset(4))
if o != 0 {
@@ -55,12 +55,12 @@ func (rcv *Decimal) Precision() int32 {
return 0
}
-// / Total number of decimal digits
+/// Total number of decimal digits
func (rcv *Decimal) MutatePrecision(n int32) bool {
return rcv._tab.MutateInt32Slot(4, n)
}
-// / Number of digits after the decimal point "."
+/// Number of digits after the decimal point "."
func (rcv *Decimal) Scale() int32 {
o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
if o != 0 {
@@ -69,13 +69,13 @@ func (rcv *Decimal) Scale() int32 {
return 0
}
-// / Number of digits after the decimal point "."
+/// Number of digits after the decimal point "."
func (rcv *Decimal) MutateScale(n int32) bool {
return rcv._tab.MutateInt32Slot(6, n)
}
-// / Number of bits per value. The only accepted widths are 128 and 256.
-// / We use bitWidth for consistency with Int::bitWidth.
+/// Number of bits per value. The accepted widths are 32, 64, 128 and 256.
+/// We use bitWidth for consistency with Int::bitWidth.
func (rcv *Decimal) BitWidth() int32 {
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
if o != 0 {
@@ -84,8 +84,8 @@ func (rcv *Decimal) BitWidth() int32 {
return 128
}
-// / Number of bits per value. The only accepted widths are 128 and 256.
-// / We use bitWidth for consistency with Int::bitWidth.
+/// Number of bits per value. The accepted widths are 32, 64, 128 and 256.
+/// We use bitWidth for consistency with Int::bitWidth.
func (rcv *Decimal) MutateBitWidth(n int32) bool {
return rcv._tab.MutateInt32Slot(8, n)
}
diff --git a/arrow/ipc/file_reader.go b/arrow/ipc/file_reader.go
index d027db5..2715831 100644
--- a/arrow/ipc/file_reader.go
+++ b/arrow/ipc/file_reader.go
@@ -476,7 +476,7 @@ func (ctx *arrayLoaderContext) loadArray(dt arrow.DataType) arrow.ArrayData {
*arrow.Int8Type, *arrow.Int16Type, *arrow.Int32Type, *arrow.Int64Type,
*arrow.Uint8Type, *arrow.Uint16Type, *arrow.Uint32Type, *arrow.Uint64Type,
*arrow.Float16Type, *arrow.Float32Type, *arrow.Float64Type,
- *arrow.Decimal128Type, *arrow.Decimal256Type,
+ arrow.DecimalType,
*arrow.Time32Type, *arrow.Time64Type,
*arrow.TimestampType,
*arrow.Date32Type, *arrow.Date64Type,
diff --git a/arrow/ipc/metadata.go b/arrow/ipc/metadata.go
index 228f271..a5bf187 100644
--- a/arrow/ipc/metadata.go
+++ b/arrow/ipc/metadata.go
@@ -281,20 +281,12 @@ func (fv *fieldVisitor) visit(field arrow.Field) {
fv.dtype = flatbuf.TypeFloatingPoint
fv.offset = floatToFB(fv.b, int32(dt.BitWidth()))
- case *arrow.Decimal128Type:
+ case arrow.DecimalType:
fv.dtype = flatbuf.TypeDecimal
flatbuf.DecimalStart(fv.b)
- flatbuf.DecimalAddPrecision(fv.b, dt.Precision)
- flatbuf.DecimalAddScale(fv.b, dt.Scale)
- flatbuf.DecimalAddBitWidth(fv.b, 128)
- fv.offset = flatbuf.DecimalEnd(fv.b)
-
- case *arrow.Decimal256Type:
- fv.dtype = flatbuf.TypeDecimal
- flatbuf.DecimalStart(fv.b)
- flatbuf.DecimalAddPrecision(fv.b, dt.Precision)
- flatbuf.DecimalAddScale(fv.b, dt.Scale)
- flatbuf.DecimalAddBitWidth(fv.b, 256)
+ flatbuf.DecimalAddPrecision(fv.b, dt.GetPrecision())
+ flatbuf.DecimalAddScale(fv.b, dt.GetScale())
+ flatbuf.DecimalAddBitWidth(fv.b, int32(dt.BitWidth()))
fv.offset = flatbuf.DecimalEnd(fv.b)
case *arrow.FixedSizeBinaryType:
@@ -947,6 +939,10 @@ func floatToFB(b *flatbuffers.Builder, bw int32) flatbuffers.UOffsetT {
func decimalFromFB(data flatbuf.Decimal) (arrow.DataType, error) {
switch data.BitWidth() {
+ case 32:
+ return &arrow.Decimal32Type{Precision: data.Precision(), Scale: data.Scale()}, nil
+ case 64:
+ return &arrow.Decimal64Type{Precision: data.Precision(), Scale: data.Scale()}, nil
case 128:
return &arrow.Decimal128Type{Precision: data.Precision(), Scale: data.Scale()}, nil
case 256:
diff --git a/arrow/type_string.go b/arrow/type_string.go
index ee3ccb7..6e5a943 100644
--- a/arrow/type_string.go
+++ b/arrow/type_string.go
@@ -51,11 +51,13 @@ func _() {
_ = x[BINARY_VIEW-40]
_ = x[LIST_VIEW-41]
_ = x[LARGE_LIST_VIEW-42]
+ _ = x[DECIMAL32-43]
+ _ = x[DECIMAL64-44]
}
-const _Type_name = "NULLBOOLUINT8INT8UINT16INT16UINT32INT32UINT64INT64FLOAT16FLOAT32FLOAT64STRINGBINARYFIXED_SIZE_BINARYDATE32DATE64TIMESTAMPTIME32TIME64INTERVAL_MONTHSINTERVAL_DAY_TIMEDECIMAL128DECIMAL256LISTSTRUCTSPARSE_UNIONDENSE_UNIONDICTIONARYMAPEXTENSIONFIXED_SIZE_LISTDURATIONLARGE_STRINGLARGE_BINARYLARGE_LISTINTERVAL_MONTH_DAY_NANORUN_END_ENCODEDSTRING_VIEWBINARY_VIEWLIST_VIEWLARGE_LIST_VIEW"
+const _Type_name = "NULLBOOLUINT8INT8UINT16INT16UINT32INT32UINT64INT64FLOAT16FLOAT32FLOAT64STRINGBINARYFIXED_SIZE_BINARYDATE32DATE64TIMESTAMPTIME32TIME64INTERVAL_MONTHSINTERVAL_DAY_TIMEDECIMAL128DECIMAL256LISTSTRUCTSPARSE_UNIONDENSE_UNIONDICTIONARYMAPEXTENSIONFIXED_SIZE_LISTDURATIONLARGE_STRINGLARGE_BINARYLARGE_LISTINTERVAL_MONTH_DAY_NANORUN_END_ENCODEDSTRING_VIEWBINARY_VIEWLIST_VIEWLARGE_LIST_VIEWDECIMAL32DECIMAL64"
-var _Type_index = [...]uint16{0, 4, 8, 13, 17, 23, 28, 34, 39, 45, 50, 57, 64, 71, 77, 83, 100, 106, 112, 121, 127, 133, 148, 165, 175, 185, 189, 195, 207, 218, 228, 231, 240, 255, 263, 275, 287, 297, 320, 335, 346, 357, 366, 381}
+var _Type_index = [...]uint16{0, 4, 8, 13, 17, 23, 28, 34, 39, 45, 50, 57, 64, 71, 77, 83, 100, 106, 112, 121, 127, 133, 148, 165, 175, 185, 189, 195, 207, 218, 228, 231, 240, 255, 263, 275, 287, 297, 320, 335, 346, 357, 366, 381, 390, 399}
func (i Type) String() string {
if i < 0 || i >= Type(len(_Type_index)-1) {
diff --git a/arrow/type_traits.go b/arrow/type_traits.go
index 87e2f06..7185ef2 100644
--- a/arrow/type_traits.go
+++ b/arrow/type_traits.go
@@ -20,8 +20,7 @@ import (
"reflect"
"unsafe"
- "github.com/apache/arrow-go/v18/arrow/decimal128"
- "github.com/apache/arrow-go/v18/arrow/decimal256"
+ "github.com/apache/arrow-go/v18/arrow/decimal"
"github.com/apache/arrow-go/v18/arrow/float16"
"golang.org/x/exp/constraints"
)
@@ -68,7 +67,7 @@ type NumericType interface {
// as a bitmap and thus the buffer can't be just reinterpreted as a []bool
type FixedWidthType interface {
IntType | UintType |
- FloatType | decimal128.Num | decimal256.Num |
+ FloatType | decimal.DecimalTypes |
DayTimeInterval | MonthDayNanoInterval
}
diff --git a/arrow/type_traits_decimal128.go b/arrow/type_traits_decimal128.go
index 860a7f1..6e416cd 100644
--- a/arrow/type_traits_decimal128.go
+++ b/arrow/type_traits_decimal128.go
@@ -19,7 +19,7 @@ package arrow
import (
"unsafe"
- "github.com/apache/arrow-go/v18/arrow/decimal128"
+ "github.com/apache/arrow-go/v18/arrow/decimal"
"github.com/apache/arrow-go/v18/arrow/endian"
)
@@ -28,7 +28,7 @@ var Decimal128Traits decimal128Traits
const (
// Decimal128SizeBytes specifies the number of bytes required to store a single decimal128 in memory
- Decimal128SizeBytes = int(unsafe.Sizeof(decimal128.Num{}))
+ Decimal128SizeBytes = int(unsafe.Sizeof(decimal.Decimal128{}))
)
type decimal128Traits struct{}
@@ -37,7 +37,7 @@ type decimal128Traits struct{}
func (decimal128Traits) BytesRequired(n int) int { return Decimal128SizeBytes * n }
// PutValue
-func (decimal128Traits) PutValue(b []byte, v decimal128.Num) {
+func (decimal128Traits) PutValue(b []byte, v decimal.Decimal128) {
endian.Native.PutUint64(b[:8], uint64(v.LowBits()))
endian.Native.PutUint64(b[8:], uint64(v.HighBits()))
}
@@ -45,14 +45,14 @@ func (decimal128Traits) PutValue(b []byte, v decimal128.Num) {
// CastFromBytes reinterprets the slice b to a slice of type uint16.
//
// NOTE: len(b) must be a multiple of Uint16SizeBytes.
-func (decimal128Traits) CastFromBytes(b []byte) []decimal128.Num {
- return GetData[decimal128.Num](b)
+func (decimal128Traits) CastFromBytes(b []byte) []decimal.Decimal128 {
+ return GetData[decimal.Decimal128](b)
}
// CastToBytes reinterprets the slice b to a slice of bytes.
-func (decimal128Traits) CastToBytes(b []decimal128.Num) []byte {
+func (decimal128Traits) CastToBytes(b []decimal.Decimal128) []byte {
return GetBytes(b)
}
// Copy copies src to dst.
-func (decimal128Traits) Copy(dst, src []decimal128.Num) { copy(dst, src) }
+func (decimal128Traits) Copy(dst, src []decimal.Decimal128) { copy(dst, src) }
diff --git a/arrow/type_traits_decimal256.go b/arrow/type_traits_decimal256.go
index f86bd2a..b196c2e 100644
--- a/arrow/type_traits_decimal256.go
+++ b/arrow/type_traits_decimal256.go
@@ -19,7 +19,7 @@ package arrow
import (
"unsafe"
- "github.com/apache/arrow-go/v18/arrow/decimal256"
+ "github.com/apache/arrow-go/v18/arrow/decimal"
"github.com/apache/arrow-go/v18/arrow/endian"
)
@@ -27,14 +27,14 @@ import (
var Decimal256Traits decimal256Traits
const (
- Decimal256SizeBytes = int(unsafe.Sizeof(decimal256.Num{}))
+ Decimal256SizeBytes = int(unsafe.Sizeof(decimal.Decimal256{}))
)
type decimal256Traits struct{}
func (decimal256Traits) BytesRequired(n int) int { return Decimal256SizeBytes * n }
-func (decimal256Traits) PutValue(b []byte, v decimal256.Num) {
+func (decimal256Traits) PutValue(b []byte, v decimal.Decimal256) {
for i, a := range v.Array() {
start := i * 8
endian.Native.PutUint64(b[start:], a)
@@ -42,12 +42,12 @@ func (decimal256Traits) PutValue(b []byte, v decimal256.Num) {
}
// CastFromBytes reinterprets the slice b to a slice of decimal256
-func (decimal256Traits) CastFromBytes(b []byte) []decimal256.Num {
- return GetData[decimal256.Num](b)
+func (decimal256Traits) CastFromBytes(b []byte) []decimal.Decimal256 {
+ return GetData[decimal.Decimal256](b)
}
-func (decimal256Traits) CastToBytes(b []decimal256.Num) []byte {
+func (decimal256Traits) CastToBytes(b []decimal.Decimal256) []byte {
return GetBytes(b)
}
-func (decimal256Traits) Copy(dst, src []decimal256.Num) { copy(dst, src) }
+func (decimal256Traits) Copy(dst, src []decimal.Decimal256) { copy(dst, src) }
diff --git a/arrow/type_traits_decimal128.go b/arrow/type_traits_decimal32.go
similarity index 60%
copy from arrow/type_traits_decimal128.go
copy to arrow/type_traits_decimal32.go
index 860a7f1..ebca65f 100644
--- a/arrow/type_traits_decimal128.go
+++ b/arrow/type_traits_decimal32.go
@@ -19,40 +19,39 @@ package arrow
import (
"unsafe"
- "github.com/apache/arrow-go/v18/arrow/decimal128"
+ "github.com/apache/arrow-go/v18/arrow/decimal"
"github.com/apache/arrow-go/v18/arrow/endian"
)
-// Decimal128 traits
-var Decimal128Traits decimal128Traits
+// Decimal32 traits
+var Decimal32Traits decimal32Traits
const (
- // Decimal128SizeBytes specifies the number of bytes required to store a single decimal128 in memory
- Decimal128SizeBytes = int(unsafe.Sizeof(decimal128.Num{}))
+ // Decimal32SizeBytes specifies the number of bytes required to store a single decimal32 in memory
+ Decimal32SizeBytes = int(unsafe.Sizeof(decimal.Decimal32(0)))
)
-type decimal128Traits struct{}
+type decimal32Traits struct{}
// BytesRequired returns the number of bytes required to store n elements in memory.
-func (decimal128Traits) BytesRequired(n int) int { return Decimal128SizeBytes * n }
+func (decimal32Traits) BytesRequired(n int) int { return Decimal32SizeBytes * n }
// PutValue
-func (decimal128Traits) PutValue(b []byte, v decimal128.Num) {
- endian.Native.PutUint64(b[:8], uint64(v.LowBits()))
- endian.Native.PutUint64(b[8:], uint64(v.HighBits()))
+func (decimal32Traits) PutValue(b []byte, v decimal.Decimal32) {
+ endian.Native.PutUint32(b[:4], uint32(v))
}
// CastFromBytes reinterprets the slice b to a slice of type uint16.
//
// NOTE: len(b) must be a multiple of Uint16SizeBytes.
-func (decimal128Traits) CastFromBytes(b []byte) []decimal128.Num {
- return GetData[decimal128.Num](b)
+func (decimal32Traits) CastFromBytes(b []byte) []decimal.Decimal32 {
+ return GetData[decimal.Decimal32](b)
}
// CastToBytes reinterprets the slice b to a slice of bytes.
-func (decimal128Traits) CastToBytes(b []decimal128.Num) []byte {
+func (decimal32Traits) CastToBytes(b []decimal.Decimal32) []byte {
return GetBytes(b)
}
// Copy copies src to dst.
-func (decimal128Traits) Copy(dst, src []decimal128.Num) { copy(dst, src) }
+func (decimal32Traits) Copy(dst, src []decimal.Decimal32) { copy(dst, src) }
diff --git a/arrow/type_traits_decimal128.go b/arrow/type_traits_decimal64.go
similarity index 60%
copy from arrow/type_traits_decimal128.go
copy to arrow/type_traits_decimal64.go
index 860a7f1..bd07883 100644
--- a/arrow/type_traits_decimal128.go
+++ b/arrow/type_traits_decimal64.go
@@ -19,40 +19,39 @@ package arrow
import (
"unsafe"
- "github.com/apache/arrow-go/v18/arrow/decimal128"
+ "github.com/apache/arrow-go/v18/arrow/decimal"
"github.com/apache/arrow-go/v18/arrow/endian"
)
-// Decimal128 traits
-var Decimal128Traits decimal128Traits
+// Decimal64 traits
+var Decimal64Traits decimal64Traits
const (
- // Decimal128SizeBytes specifies the number of bytes required to store a single decimal128 in memory
- Decimal128SizeBytes = int(unsafe.Sizeof(decimal128.Num{}))
+ // Decimal64SizeBytes specifies the number of bytes required to store a single decimal64 in memory
+ Decimal64SizeBytes = int(unsafe.Sizeof(decimal.Decimal64(0)))
)
-type decimal128Traits struct{}
+type decimal64Traits struct{}
// BytesRequired returns the number of bytes required to store n elements in memory.
-func (decimal128Traits) BytesRequired(n int) int { return Decimal128SizeBytes * n }
+func (decimal64Traits) BytesRequired(n int) int { return Decimal64SizeBytes * n }
// PutValue
-func (decimal128Traits) PutValue(b []byte, v decimal128.Num) {
- endian.Native.PutUint64(b[:8], uint64(v.LowBits()))
- endian.Native.PutUint64(b[8:], uint64(v.HighBits()))
+func (decimal64Traits) PutValue(b []byte, v decimal.Decimal64) {
+ endian.Native.PutUint64(b[:8], uint64(v))
}
// CastFromBytes reinterprets the slice b to a slice of type uint16.
//
// NOTE: len(b) must be a multiple of Uint16SizeBytes.
-func (decimal128Traits) CastFromBytes(b []byte) []decimal128.Num {
- return GetData[decimal128.Num](b)
+func (decimal64Traits) CastFromBytes(b []byte) []decimal.Decimal64 {
+ return GetData[decimal.Decimal64](b)
}
// CastToBytes reinterprets the slice b to a slice of bytes.
-func (decimal128Traits) CastToBytes(b []decimal128.Num) []byte {
+func (decimal64Traits) CastToBytes(b []decimal.Decimal64) []byte {
return GetBytes(b)
}
// Copy copies src to dst.
-func (decimal128Traits) Copy(dst, src []decimal128.Num) { copy(dst, src) }
+func (decimal64Traits) Copy(dst, src []decimal.Decimal64) { copy(dst, src) }
diff --git a/arrow/type_traits_test.go b/arrow/type_traits_test.go
index d86a67b..93d98b9 100644
--- a/arrow/type_traits_test.go
+++ b/arrow/type_traits_test.go
@@ -23,8 +23,10 @@ import (
"testing"
"github.com/apache/arrow-go/v18/arrow"
+ "github.com/apache/arrow-go/v18/arrow/decimal"
"github.com/apache/arrow-go/v18/arrow/decimal128"
"github.com/apache/arrow-go/v18/arrow/decimal256"
+
"github.com/apache/arrow-go/v18/arrow/float16"
)
@@ -90,6 +92,94 @@ func TestFloat16Traits(t *testing.T) {
}
}
+func TestDecimal32Traits(t *testing.T) {
+ const N = 10
+ nbytes := arrow.Decimal32Traits.BytesRequired(N)
+ b1 := arrow.Decimal32Traits.CastToBytes([]decimal.Decimal32{
+ decimal.Decimal32(0),
+ decimal.Decimal32(1),
+ decimal.Decimal32(2),
+ decimal.Decimal32(3),
+ decimal.Decimal32(4),
+ decimal.Decimal32(5),
+ decimal.Decimal32(6),
+ decimal.Decimal32(7),
+ decimal.Decimal32(8),
+ decimal.Decimal32(9),
+ })
+
+ b2 := make([]byte, nbytes)
+ for i := 0; i < N; i++ {
+ beg := i * arrow.Decimal32SizeBytes
+ end := (i + 1) * arrow.Decimal32SizeBytes
+ arrow.Decimal32Traits.PutValue(b2[beg:end], decimal.Decimal32(i))
+ }
+
+ if !reflect.DeepEqual(b1, b2) {
+ v1 := arrow.Decimal32Traits.CastFromBytes(b1)
+ v2 := arrow.Decimal32Traits.CastFromBytes(b2)
+ t.Fatalf("invalid values:\nb1=%v\nb2=%v\nv1=%v\nv2=%v\n", b1, b2, v1, v2)
+ }
+
+ v1 := arrow.Decimal32Traits.CastFromBytes(b1)
+ for i, v := range v1 {
+ if got, want := v, decimal.Decimal32(i); got != want {
+ t.Fatalf("invalid value[%d]. got=%v, want=%v", i, got, want)
+ }
+ }
+
+ v2 := make([]decimal.Decimal32, N)
+ arrow.Decimal32Traits.Copy(v2, v1)
+
+ if !reflect.DeepEqual(v1, v2) {
+ t.Fatalf("invalid values:\nv1=%v\nv2=%v\n", v1, v2)
+ }
+}
+
+func TestDecimal64Traits(t *testing.T) {
+ const N = 10
+ nbytes := arrow.Decimal64Traits.BytesRequired(N)
+ b1 := arrow.Decimal64Traits.CastToBytes([]decimal.Decimal64{
+ 0,
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9,
+ })
+
+ b2 := make([]byte, nbytes)
+ for i := 0; i < N; i++ {
+ beg := i * arrow.Decimal64SizeBytes
+ end := (i + 1) * arrow.Decimal64SizeBytes
+ arrow.Decimal64Traits.PutValue(b2[beg:end], decimal.Decimal64(i))
+ }
+
+ if !reflect.DeepEqual(b1, b2) {
+ v1 := arrow.Decimal64Traits.CastFromBytes(b1)
+ v2 := arrow.Decimal64Traits.CastFromBytes(b2)
+ t.Fatalf("invalid values:\nb1=%v\nb2=%v\nv1=%v\nv2=%v\n", b1, b2, v1, v2)
+ }
+
+ v1 := arrow.Decimal64Traits.CastFromBytes(b1)
+ for i, v := range v1 {
+ if got, want := v, decimal.Decimal64(i); got != want {
+ t.Fatalf("invalid value[%d]. got=%v, want=%v", i, got, want)
+ }
+ }
+
+ v2 := make([]decimal.Decimal64, N)
+ arrow.Decimal64Traits.Copy(v2, v1)
+
+ if !reflect.DeepEqual(v1, v2) {
+ t.Fatalf("invalid values:\nv1=%v\nv2=%v\n", v1, v2)
+ }
+}
+
func TestDecimal128Traits(t *testing.T) {
const N = 10
nbytes := arrow.Decimal128Traits.BytesRequired(N)