This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new d571e93ad2 ARROW-17730: [Go] Implement Take kernels for FSB and VarBinary (#14127)
d571e93ad2 is described below

commit d571e93ad24d5111800540b42a3b8d56459edd9b
Author: Matt Topol <[email protected]>
AuthorDate: Fri Sep 16 11:05:46 2022 -0400

    ARROW-17730: [Go] Implement Take kernels for FSB and VarBinary (#14127)
    
    Authored-by: Matt Topol <[email protected]>
    Signed-off-by: Matt Topol <[email protected]>
---
 .../compute/internal/kernels/vector_selection.go   | 177 +++++++++++++++++++--
 go/arrow/compute/vector_selection_test.go          |  51 ++++++
 2 files changed, 218 insertions(+), 10 deletions(-)

diff --git a/go/arrow/compute/internal/kernels/vector_selection.go b/go/arrow/compute/internal/kernels/vector_selection.go
index fa1c33be59..c4bfcca8bc 100644
--- a/go/arrow/compute/internal/kernels/vector_selection.go
+++ b/go/arrow/compute/internal/kernels/vector_selection.go
@@ -991,19 +991,171 @@ func binaryFilterImpl[OffsetT int32 | int64](ctx *exec.KernelCtx, values, filter
        return nil
 }
 
-func FilterFSB(ctx *exec.KernelCtx, batch *exec.ExecSpan, out 
*exec.ExecResult) error {
+// takeExecImpl walks the indices span (its values read as unsigned integers
+// of type T) and, for every output slot, calls visitValid with the selected
+// index when both the index and the value it refers to are valid, or
+// visitNull otherwise. It builds the output validity bitmap itself and sets
+// out.Len, out.Nulls and out.Buffers[0]; the callbacks own any other output
+// buffers. The first callback error stops the walk and is returned as-is.
+//
+// Indices are not range-checked here: they are used directly to probe the
+// values validity bitmap. NOTE(review): callers must validate them first
+// (TakeExec does so when TakeState.BoundsCheck is set).
+func takeExecImpl[T exec.UintTypes](ctx *exec.KernelCtx, outputLen int64, 
values, indices *exec.ArraySpan, out *exec.ExecResult, visitValid func(int64) 
error, visitNull func() error) error {
        var (
-               values       = &batch.Values[0].Array
-               selection    = &batch.Values[1].Array
-               outputLength = getFilterOutputSize(selection, 
ctx.State.(FilterState).NullSelection)
-               valueSize    = 
int64(values.Type.(arrow.FixedWidthDataType).Bytes())
-               valueData    = values.Buffers[1].Buf[values.Offset*valueSize:]
+               validityBuilder = validityBuilder{mem: 
exec.GetAllocator(ctx.Ctx)}
+               indicesValues   = exec.GetSpanValues[T](indices, 1)
+               isValid         = indices.Buffers[0].Buf
+               valuesHaveNulls = values.MayHaveNulls()
+
+               indicesIsValid = bitutil.OptionalBitIndexer{Bitmap: isValid, 
Offset: int(indices.Offset)}
+               valuesIsValid  = bitutil.OptionalBitIndexer{Bitmap: 
values.Buffers[0].Buf, Offset: int(values.Offset)}
+               bitCounter     = bitutils.NewOptionalBitBlockCounter(isValid, 
indices.Offset, indices.Len)
+               pos            int64
+       )
+
+       // reserve the whole output bitmap up front: the loop below appends
+       // exactly one bit per index using the Unsafe* append methods
+       validityBuilder.Reserve(outputLen)
+       for pos < indices.Len {
+               // Popcnt is the number of valid indices among the block's Len;
+               // together with valuesHaveNulls it selects one of three paths
+               block := bitCounter.NextBlock()
+               indicesHaveNulls := block.Popcnt < block.Len
+               if !indicesHaveNulls && !valuesHaveNulls {
+                       // fastest path, neither indices nor values have nulls
+                       validityBuilder.UnsafeAppendN(int64(block.Len), true)
+                       for i := 0; i < int(block.Len); i++ {
+                               if err := 
visitValid(int64(indicesValues[pos])); err != nil {
+                                       return err
+                               }
+                               pos++
+                       }
+               } else if block.Popcnt > 0 {
+                       // since we have to branch on whether indices are null 
or not,
+                       // we combine the "non-null indices block but some 
values null"
+                       // and "some null indices block but values non-null" 
into single loop
+                       for i := 0; i < int(block.Len); i++ {
+                               // the index validity test short-circuits, so a null
+                               // index is never used to probe the values bitmap
+                               if (!indicesHaveNulls || 
indicesIsValid.GetBit(int(pos))) && 
valuesIsValid.GetBit(int(indicesValues[pos])) {
+                                       validityBuilder.UnsafeAppend(true)
+                                       if err := 
visitValid(int64(indicesValues[pos])); err != nil {
+                                               return err
+                                       }
+                               } else {
+                                       validityBuilder.UnsafeAppend(false)
+                                       if err := visitNull(); err != nil {
+                                               return err
+                                       }
+                               }
+                               pos++
+                       }
+               } else {
+                       // the whole block is null
+                       validityBuilder.UnsafeAppendN(int64(block.Len), false)
+                       for i := 0; i < int(block.Len); i++ {
+                               if err := visitNull(); err != nil {
+                                       return err
+                               }
+                       }
+                       pos += int64(block.Len)
+               }
+       }
+
+       out.Len = int64(validityBuilder.bitLength)
+       out.Nulls = int64(validityBuilder.falseCount)
+       out.Buffers[0].WrapBuffer(validityBuilder.Finish())
+       return nil
+}
+
+// takeExec selects the takeExecImpl instantiation that matches the byte width
+// of the index type (1, 2, 4 or 8 bytes, read as uint8/16/32/64) and forwards
+// all arguments to it. Any other width yields an arrow.ErrInvalid error.
+// NOTE(review): signed index types are reinterpreted as unsigned, so a
+// negative index becomes a large one; presumably checkIndexBounds is what
+// rejects those before this runs — confirm.
+func takeExec(ctx *exec.KernelCtx, outputLen int64, values, indices 
*exec.ArraySpan, out *exec.ExecResult, visitValid func(int64) error, visitNull 
func() error) error {
+       indexWidth := indices.Type.(arrow.FixedWidthDataType).Bytes()
+
+       switch indexWidth {
+       case 1:
+               return takeExecImpl[uint8](ctx, outputLen, values, indices, 
out, visitValid, visitNull)
+       case 2:
+               return takeExecImpl[uint16](ctx, outputLen, values, indices, 
out, visitValid, visitNull)
+       case 4:
+               return takeExecImpl[uint32](ctx, outputLen, values, indices, 
out, visitValid, visitNull)
+       case 8:
+               return takeExecImpl[uint64](ctx, outputLen, values, indices, 
out, visitValid, visitNull)
+       default:
+               return fmt.Errorf("%w: invalid index width", arrow.ErrInvalid)
+       }
+}
+
+// outputFn drives a selection over (values, selection) and is expected to
+// report each output slot through the visitValid(index) / visitNull()
+// callbacks; takeExec and filterExec both have this shape.
+type outputFn func(*exec.KernelCtx, int64, *exec.ArraySpan, *exec.ArraySpan, 
*exec.ExecResult, func(int64) error, func() error) error
+// implFn builds the type-specific output buffers for a selection kernel from
+// the batch, the precomputed output length and the outputFn that drives the
+// selection; FSBImpl and VarBinaryImpl are implementations.
+type implFn func(*exec.KernelCtx, *exec.ExecSpan, int64, *exec.ExecResult, 
outputFn) error
+
+// FilterExec adapts a type-specific impl and a selection driver fn into a
+// filter kernel. The output length is derived from the filter array in
+// batch.Values[1] and the kernel's FilterState.NullSelection setting;
+// ctx.State must hold a FilterState (the type assertion panics otherwise).
+func FilterExec(impl implFn, fn outputFn) exec.ArrayKernelExec {
+       return func(ctx *exec.KernelCtx, batch *exec.ExecSpan, out 
*exec.ExecResult) error {
+               var (
+                       selection    = &batch.Values[1].Array
+                       outputLength = getFilterOutputSize(selection, 
ctx.State.(FilterState).NullSelection)
+               )
+               return impl(ctx, batch, outputLength, out, fn)
+       }
+}
+
+// TakeExec adapts a type-specific impl and a selection driver fn into a take
+// kernel. The output has one slot per index in batch.Values[1]. When the
+// kernel's TakeState.BoundsCheck is set, the indices are first validated
+// against the length of batch.Values[0], and any error is returned before
+// impl runs. ctx.State must hold a TakeState (the type assertion panics
+// otherwise).
+func TakeExec(impl implFn, fn outputFn) exec.ArrayKernelExec {
+       return func(ctx *exec.KernelCtx, batch *exec.ExecSpan, out 
*exec.ExecResult) error {
+               if ctx.State.(TakeState).BoundsCheck {
+                       if err := checkIndexBounds(&batch.Values[1].Array, 
uint64(batch.Values[0].Array.Len)); err != nil {
+                               return err
+                       }
+               }
+
+               return impl(ctx, batch, batch.Values[1].Array.Len, out, fn)
+       }
+}
+
+// VarBinaryImpl writes the offsets and data buffers for a selection (take or
+// filter) over a variable-width binary-like array with OffsetT offsets.
+//
+// batch.Values[0] holds the values and batch.Values[1] the selection that is
+// passed to fn; outputLength is the number of output slots. fn calls back once
+// per output slot: the valid callback appends the current offset and copies
+// the bytes of the selected value, while the null callback appends the current
+// offset only (an empty slot). The final end offset is appended after fn
+// returns, and the offsets and data are stored in out.Buffers[1] and
+// out.Buffers[2].
+//
+// Take may repeat values, so the output can need more bytes than OffsetT can
+// address (notably with int32 offsets). That case is reported as an
+// arrow.ErrInvalid error instead of silently wrapping the offsets.
+func VarBinaryImpl[OffsetT int32 | int64](ctx *exec.KernelCtx, batch *exec.ExecSpan, outputLength int64, out *exec.ExecResult, fn outputFn) error {
+       var (
+               values        = &batch.Values[0].Array
+               selection     = &batch.Values[1].Array
+               rawOffsets    = exec.GetSpanOffsets[OffsetT](values, 1)
+               rawData       = values.Buffers[2].Buf
+               offsetBuilder = newBufferBuilder[OffsetT](exec.GetAllocator(ctx.Ctx))
+               dataBuilder   = newBufferBuilder[uint8](exec.GetAllocator(ctx.Ctx))
+       )
+
+       // presize the data builder with a rough estimate of the required data
+       // size: the mean input value length times the number of output slots.
+       // (Reserving only the mean length of one value makes the estimate
+       // useless for any output longer than a single slot.)
+       if values.Len > 0 {
+               dataLength := rawOffsets[values.Len] - rawOffsets[0]
+               meanValueLen := float64(dataLength) / float64(values.Len)
+               dataBuilder.reserve(int(meanValueLen * float64(outputLength)))
+       }
+
+       // one offset per output slot plus the trailing end offset
+       offsetBuilder.reserve(int(outputLength) + 1)
+       spaceAvail := dataBuilder.cap()
+       var offset OffsetT
+       err := fn(ctx, outputLength, values, selection, out,
+               func(idx int64) error {
+                       offsetBuilder.unsafeAppend(offset)
+                       valOffset := rawOffsets[idx]
+                       valSize := rawOffsets[idx+1] - valOffset
+
+                       // signed overflow wraps in Go, so a running offset that
+                       // decreases means the output no longer fits in OffsetT
+                       nextOffset := offset + valSize
+                       if nextOffset < offset {
+                               return fmt.Errorf("%w: take operation overflowed binary array capacity", arrow.ErrInvalid)
+                       }
+                       offset = nextOffset
+
+                       // compare as int so that a large spaceAvail is not
+                       // truncated when OffsetT is int32
+                       if int(valSize) > spaceAvail {
+                               dataBuilder.reserve(int(valSize))
+                               spaceAvail = dataBuilder.cap() - dataBuilder.len()
+                       }
+                       dataBuilder.unsafeAppendSlice(rawData[valOffset : valOffset+valSize])
+                       spaceAvail -= int(valSize)
+                       return nil
+               }, func() error {
+                       offsetBuilder.unsafeAppend(offset)
+                       return nil
+               })
+
+       if err != nil {
+               return err
+       }
+
+       offsetBuilder.unsafeAppend(offset)
+       out.Buffers[1].WrapBuffer(offsetBuilder.finish())
+       out.Buffers[2].WrapBuffer(dataBuilder.finish())
+       return nil
+}
+
+func FSBImpl(ctx *exec.KernelCtx, batch *exec.ExecSpan, outputLength int64, 
out *exec.ExecResult, fn outputFn) error {
+       var (
+               values    = &batch.Values[0].Array
+               selection = &batch.Values[1].Array
+               valueSize = 
int64(values.Type.(arrow.FixedWidthDataType).Bytes())
+               valueData = values.Buffers[1].Buf[values.Offset*valueSize:]
        )
 
        out.Buffers[1].WrapBuffer(ctx.Allocate(int(valueSize * outputLength)))
        buf := out.Buffers[1].Buf
 
-       err := filterExec(ctx, outputLength, values, selection, out,
+       err := fn(ctx, outputLength, values, selection, out,
                func(idx int64) error {
                        start := idx * int64(valueSize)
                        copy(buf, valueData[start:start+valueSize])
@@ -1076,9 +1228,9 @@ func GetVectorSelectionKernels() (filterkernels, takeKernels []SelectionKernelDa
        filterkernels = []SelectionKernelData{
                {In: exec.NewMatchedInput(exec.Primitive()), Exec: 
PrimitiveFilter},
                {In: exec.NewExactInput(arrow.Null), Exec: NullFilter},
-               {In: exec.NewIDInput(arrow.DECIMAL128), Exec: FilterFSB},
-               {In: exec.NewIDInput(arrow.DECIMAL256), Exec: FilterFSB},
-               {In: exec.NewIDInput(arrow.FIXED_SIZE_BINARY), Exec: FilterFSB},
+               {In: exec.NewIDInput(arrow.DECIMAL128), Exec: 
FilterExec(FSBImpl, filterExec)},
+               {In: exec.NewIDInput(arrow.DECIMAL256), Exec: 
FilterExec(FSBImpl, filterExec)},
+               {In: exec.NewIDInput(arrow.FIXED_SIZE_BINARY), Exec: 
FilterExec(FSBImpl, filterExec)},
                {In: exec.NewMatchedInput(exec.BinaryLike()), Exec: 
FilterBinary},
                {In: exec.NewMatchedInput(exec.LargeBinaryLike()), Exec: 
FilterBinary},
        }
@@ -1086,6 +1238,11 @@ func GetVectorSelectionKernels() (filterkernels, takeKernels []SelectionKernelDa
        takeKernels = []SelectionKernelData{
                {In: exec.NewExactInput(arrow.Null), Exec: NullTake},
                {In: exec.NewMatchedInput(exec.Primitive()), Exec: 
PrimitiveTake},
+               {In: exec.NewIDInput(arrow.DECIMAL128), Exec: TakeExec(FSBImpl, 
takeExec)},
+               {In: exec.NewIDInput(arrow.DECIMAL256), Exec: TakeExec(FSBImpl, 
takeExec)},
+               {In: exec.NewIDInput(arrow.FIXED_SIZE_BINARY), Exec: 
TakeExec(FSBImpl, takeExec)},
+               {In: exec.NewMatchedInput(exec.BinaryLike()), Exec: 
TakeExec(VarBinaryImpl[int32], takeExec)},
+               {In: exec.NewMatchedInput(exec.LargeBinaryLike()), Exec: 
TakeExec(VarBinaryImpl[int64], takeExec)},
        }
        return
 }
diff --git a/go/arrow/compute/vector_selection_test.go b/go/arrow/compute/vector_selection_test.go
index 59b0e1d07b..e5fdfbcb77 100644
--- a/go/arrow/compute/vector_selection_test.go
+++ b/go/arrow/compute/vector_selection_test.go
@@ -663,11 +663,62 @@ func (tk *TakeKernelTestNumeric) TestTakeNumeric() {
        })
 }
 
+// TakeKernelTestFSB runs the Take kernel tests against FixedSizeBinary
+// values.
+type TakeKernelTestFSB struct {
+       TakeKernelTestTyped
+}
+
+// SetupSuite fixes the element type to a 3-byte FixedSizeBinaryType,
+// matching the 3-byte "aaa"/"bbb"/"ccc" test values.
+func (tk *TakeKernelTestFSB) SetupSuite() {
+       tk.dt = &arrow.FixedSizeBinaryType{ByteWidth: 3}
+}
+
+// TestFixedSizeBinary checks Take on FixedSizeBinary values: plain selection,
+// null values, null indices, assertNoValidityBitmapUnknownNullCountJSON, and
+// out-of-range indices.
+func (tk *TakeKernelTestFSB) TestFixedSizeBinary() {
+       // YWFh == base64("aaa")
+       // YmJi == base64("bbb")
+       // Y2Nj == base64("ccc")
+       tk.assertTake(`["YWFh", "YmJi", "Y2Nj"]`, `[0, 1, 0]`, `["YWFh", 
"YmJi", "YWFh"]`)
+       tk.assertTake(`[null, "YmJi", "Y2Nj"]`, `[0, 1, 0]`, `[null, "YmJi", 
null]`)
+       tk.assertTake(`["YWFh", "YmJi", "Y2Nj"]`, `[null, 1, 0]`, `[null, 
"YmJi", "YWFh"]`)
+
+       tk.assertNoValidityBitmapUnknownNullCountJSON(tk.dt, `["YWFh", "YmJi", 
"Y2Nj"]`, `[0, 1, 0]`)
+
+       // indices 9 and 5 are out of range for the 3-element input and must
+       // fail with arrow.ErrIndex
+       _, err := tk.takeJSON(tk.dt, `["YWFh", "YmJi", "Y2Nj"]`, 
arrow.PrimitiveTypes.Int8, `[0, 9, 0]`)
+       tk.ErrorIs(err, arrow.ErrIndex)
+       _, err = tk.takeJSON(tk.dt, `["YWFh", "YmJi", "Y2Nj"]`, 
arrow.PrimitiveTypes.Int64, `[2, 5]`)
+       tk.ErrorIs(err, arrow.ErrIndex)
+}
+
+// TakeKernelTestString runs the Take kernel tests against a single
+// binary-like element type held in the embedded TakeKernelTestTyped.dt
+// (presumably one entry of baseBinaryTypes — see its caller).
+type TakeKernelTestString struct {
+       TakeKernelTestTyped
+}
+
+// TestTakeString checks Take on the suite's binary-like type: plain
+// selection, null values, null indices,
+// assertNoValidityBitmapUnknownNullCountJSON, and out-of-range indices.
+func (tk *TakeKernelTestString) TestTakeString() {
+       tk.Run(tk.dt.String(), func() {
+               // base64 encoded so the binary non-utf8 arrays work
+               // YQ== -> "a"
+               // Yg== -> "b"
+               // Yw== -> "c"
+               tk.assertTake(`["YQ==", "Yg==", "Yw=="]`, `[0, 1, 0]`, 
`["YQ==", "Yg==", "YQ=="]`)
+               tk.assertTake(`[null, "Yg==", "Yw=="]`, `[0, 1, 0]`, `[null, 
"Yg==", null]`)
+               tk.assertTake(`["YQ==", "Yg==", "Yw=="]`, `[null, 1, 0]`, 
`[null, "Yg==", "YQ=="]`)
+
+               tk.assertNoValidityBitmapUnknownNullCountJSON(tk.dt, `["YQ==", 
"Yg==", "Yw=="]`, `[0, 1, 0]`)
+
+               // indices 9 and 5 are out of range for the 3-element input and
+               // must fail with arrow.ErrIndex
+               _, err := tk.takeJSON(tk.dt, `["YQ==", "Yg==", "Yw=="]`, 
arrow.PrimitiveTypes.Int8, `[0, 9, 0]`)
+               tk.ErrorIs(err, arrow.ErrIndex)
+               _, err = tk.takeJSON(tk.dt, `["YQ==", "Yg==", "Yw=="]`, 
arrow.PrimitiveTypes.Int64, `[2, 5]`)
+               tk.ErrorIs(err, arrow.ErrIndex)
+       })
+}
+
+// TestTakeKernels runs the generic Take suite, one typed suite per numeric
+// type, the FixedSizeBinary suite, and one string suite per entry in
+// baseBinaryTypes.
 func TestTakeKernels(t *testing.T) {
        suite.Run(t, new(TakeKernelTest))
        for _, dt := range numericTypes {
                suite.Run(t, &TakeKernelTestNumeric{TakeKernelTestTyped: 
TakeKernelTestTyped{dt: dt}})
        }
+       // fixed-width and variable-width binary element types
+       suite.Run(t, new(TakeKernelTestFSB))
+       for _, dt := range baseBinaryTypes {
+               suite.Run(t, &TakeKernelTestString{TakeKernelTestTyped: 
TakeKernelTestTyped{dt: dt}})
+       }
 }
 
 func TestFilterKernels(t *testing.T) {

Reply via email to