This is an automated email from the ASF dual-hosted git repository.
zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-go.git
The following commit(s) were added to refs/heads/main by this push:
new bb9d77ac fix(arrow/cdata): Handle errors to prevent panic (#614)
bb9d77ac is described below
commit bb9d77ac8888f95d418342706a04d014f3ff8a25
Author: cai.zhang <[email protected]>
AuthorDate: Thu Jan 15 00:18:37 2026 +0800
fix(arrow/cdata): Handle errors to prevent panic (#614)
### Rationale for this change
Fixes: #613
### What changes are included in this PR?
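Propagate errors returned by `importChild` for struct, dense union, and sparse union children instead of silently discarding them. When importing children fails, release any child data that was already imported. On any import failure, always release the C array, using a new `forceRelease` guarded by an atomic compare-and-swap so it runs at most once.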
### Are these changes tested?
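Yes. Added `TestImportStructWithInvalidSchema`, `TestImportDenseUnionWithInvalidSchema`, and `TestImportSPARSEUnionWithInvalidSchema`. Each imports an array against a mismatched schema and asserts two things: an error is returned instead of a panic, and no mallocator memory is leaked (`mem.AssertSize(t, 0)`).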
### Are there any user-facing changes?
No
Signed-off-by: Cai Zhang <[email protected]>
---
arrow/cdata/cdata.go | 39 ++++++++-------
arrow/cdata/cdata_test.go | 106 +++++++++++++++++++++++++++++++++++++++-
arrow/cdata/import_allocator.go | 7 +++
3 files changed, 131 insertions(+), 21 deletions(-)
diff --git a/arrow/cdata/cdata.go b/arrow/cdata/cdata.go
index 4085ed3d..63419469 100644
--- a/arrow/cdata/cdata.go
+++ b/arrow/cdata/cdata.go
@@ -407,7 +407,9 @@ func (imp *cimporter) doImportChildren() error {
st := imp.dt.(*arrow.StructType)
for i, c := range children {
imp.children[i].dt = st.Field(i).Type
- imp.children[i].importChild(imp, c)
+ if err := imp.children[i].importChild(imp, c); err !=
nil {
+ return err
+ }
}
case arrow.RUN_END_ENCODED: // import run-ends and values
st := imp.dt.(*arrow.RunEndEncodedType)
@@ -428,13 +430,17 @@ func (imp *cimporter) doImportChildren() error {
dt := imp.dt.(*arrow.DenseUnionType)
for i, c := range children {
imp.children[i].dt = dt.Fields()[i].Type
- imp.children[i].importChild(imp, c)
+ if err := imp.children[i].importChild(imp, c); err !=
nil {
+ return err
+ }
}
case arrow.SPARSE_UNION:
dt := imp.dt.(*arrow.SparseUnionType)
for i, c := range children {
imp.children[i].dt = dt.Fields()[i].Type
- imp.children[i].importChild(imp, c)
+ if err := imp.children[i].importChild(imp, c); err !=
nil {
+ return err
+ }
}
}
@@ -461,33 +467,28 @@ func (imp *cimporter) doImportArr(src *CArrowArray) error
{
// and only null columns, then we can release the CArrowArray
// struct immediately after import, since we have no imported
// memory that we have to track the lifetime of.
+ // On error, we always release regardless of buffer count to avoid
leaks.
+ var importErr error
defer func() {
- if imp.alloc.bufCount.Load() == 0 {
- C.ArrowArrayRelease(imp.arr)
- C.free(unsafe.Pointer(imp.arr))
+ if importErr != nil || imp.alloc.bufCount.Load() == 0 {
+ imp.alloc.forceRelease()
}
}()
- return imp.doImport()
+ importErr = imp.doImport()
+ return importErr
}
// import is called recursively as needed for importing an array and its
children
// in order to generate array.Data objects
func (imp *cimporter) doImport() error {
- // move the array from the src object passed in to the one referenced by
- // this importer. That way we can set up a finalizer on the created
- // arrow.ArrayData object so we clean up our Array's memory when
garbage collected.
- defer func(arr *CArrowArray) {
- // this should only occur in the case of an error happening
- // during import, at which point we need to clean up the
- // ArrowArray struct we allocated.
- if imp.data == nil {
- C.free(unsafe.Pointer(arr))
- }
- }(imp.arr)
-
// import any children
if err := imp.doImportChildren(); err != nil {
+ for _, c := range imp.children {
+ if c.data != nil {
+ c.data.Release()
+ }
+ }
return err
}
diff --git a/arrow/cdata/cdata_test.go b/arrow/cdata/cdata_test.go
index 170a5151..8fa690f2 100644
--- a/arrow/cdata/cdata_test.go
+++ b/arrow/cdata/cdata_test.go
@@ -669,8 +669,8 @@ func createTestDenseUnion() arrow.Array {
func createTestUnionArr(mode arrow.UnionMode) arrow.Array {
fields := []arrow.Field{
- arrow.Field{Name: "u0", Type: arrow.PrimitiveTypes.Int32,
Nullable: true},
- arrow.Field{Name: "u1", Type: arrow.PrimitiveTypes.Uint8,
Nullable: true},
+ {Name: "u0", Type: arrow.PrimitiveTypes.Int32, Nullable: true},
+ {Name: "u1", Type: arrow.PrimitiveTypes.Uint8, Nullable: true},
}
typeCodes := []arrow.UnionTypeCode{5, 10}
bld := array.NewBuilder(memory.DefaultAllocator, arrow.UnionOf(mode,
fields, typeCodes)).(array.UnionBuilder)
@@ -785,6 +785,104 @@ func TestRecordBatch(t *testing.T) {
assert.True(t, array.RecordEqual(rb, rec))
}
+func TestImportStructWithInvalidSchema(t *testing.T) {
+ mem := mallocator.NewMallocator()
+ defer mem.AssertSize(t, 0)
+
+ arr := createTestStructArr()
+ defer arr.Release()
+
+ carr := createCArr(arr, mem)
+ defer freeTestMallocatorArr(carr, mem)
+
+ sc := testStruct([]string{"+s", "c", "l"}, []string{"", "a", "b"},
[]int64{0, flagIsNullable, flagIsNullable})
+ defer freeMallocedSchemas(sc)
+
+ top := (*[1]*CArrowSchema)(unsafe.Pointer(sc))[0]
+ _, err := ImportCRecordBatch(carr, top)
+ assert.Error(t, err)
+}
+
+func TestImportDenseUnionWithInvalidSchema(t *testing.T) {
+ mem := mallocator.NewMallocator()
+ defer mem.AssertSize(t, 0)
+
+ unionArr := createTestDenseUnion()
+ defer unionArr.Release()
+
+ structBld := array.NewStructBuilder(memory.DefaultAllocator,
arrow.StructOf(
+ arrow.Field{Name: "union_field", Type: unionArr.DataType(),
Nullable: false},
+ ))
+ defer structBld.Release()
+
+ unionBld := structBld.FieldBuilder(0).(*array.DenseUnionBuilder)
+ structBld.Append(true)
+ du := unionArr.(*array.DenseUnion)
+ for i := 0; i < du.Len(); i++ {
+ unionBld.Append(du.TypeCode(i))
+ if du.TypeCode(i) == 5 {
+
unionBld.Child(0).(*array.Int32Builder).Append(du.Field(0).(*array.Int32).Value(int(du.ValueOffset(i))))
+ } else {
+
unionBld.Child(1).(*array.Uint8Builder).Append(du.Field(1).(*array.Uint8).Value(int(du.ValueOffset(i))))
+ }
+ }
+
+ structArr := structBld.NewArray()
+ defer structArr.Release()
+
+ carr := createCArr(structArr, mem)
+ defer freeTestMallocatorArr(carr, mem)
+
+ // Create an invalid schema: wrong type for union field (using "i"
instead of proper union schema)
+ sc := testStruct([]string{"+s", "i"}, []string{"", "union_field"},
[]int64{0, flagIsNullable})
+ defer freeMallocedSchemas(sc)
+
+ top := (*[1]*CArrowSchema)(unsafe.Pointer(sc))[0]
+ _, err := ImportCRecordBatch(carr, top)
+ assert.Error(t, err)
+}
+
+func TestImportSPARSEUnionWithInvalidSchema(t *testing.T) {
+ mem := mallocator.NewMallocator()
+ defer mem.AssertSize(t, 0)
+
+ unionArr := createTestSparseUnion()
+ defer unionArr.Release()
+
+ structBld := array.NewStructBuilder(memory.DefaultAllocator,
arrow.StructOf(
+ arrow.Field{Name: "union_field", Type: unionArr.DataType(),
Nullable: false},
+ ))
+ defer structBld.Release()
+
+ unionBld := structBld.FieldBuilder(0).(*array.SparseUnionBuilder)
+ structBld.Append(true)
+ su := unionArr.(*array.SparseUnion)
+ for i := 0; i < su.Len(); i++ {
+ unionBld.Append(su.TypeCode(i))
+ if su.TypeCode(i) == 5 {
+
unionBld.Child(0).(*array.Int32Builder).Append(su.Field(0).(*array.Int32).Value(i))
+ unionBld.Child(1).(*array.Uint8Builder).AppendNull()
+ } else {
+ unionBld.Child(0).(*array.Int32Builder).AppendNull()
+
unionBld.Child(1).(*array.Uint8Builder).Append(su.Field(1).(*array.Uint8).Value(i))
+ }
+ }
+
+ structArr := structBld.NewArray()
+ defer structArr.Release()
+
+ carr := createCArr(structArr, mem)
+ defer freeTestMallocatorArr(carr, mem)
+
+ // Create an invalid schema: wrong type for union field (using "u"
instead of proper union schema)
+ sc := testStruct([]string{"+s", "u"}, []string{"", "union_field"},
[]int64{0, flagIsNullable})
+ defer freeMallocedSchemas(sc)
+
+ top := (*[1]*CArrowSchema)(unsafe.Pointer(sc))[0]
+ _, err := ImportCRecordBatch(carr, top)
+ assert.Error(t, err)
+}
+
func TestRecordReaderStream(t *testing.T) {
stream := arrayStreamTest()
defer releaseStreamTest(stream)
@@ -1006,17 +1104,21 @@ func (r *failingReader) Schema() *arrow.Schema {
}
return arrdata.Records["primitives"][0].Schema()
}
+
func (r *failingReader) Next() bool {
r.opCount -= 1
return r.opCount > 0
}
+
func (r *failingReader) RecordBatch() arrow.RecordBatch {
arrdata.Records["primitives"][0].Retain()
return arrdata.Records["primitives"][0]
}
+
func (r *failingReader) Record() arrow.Record {
return r.RecordBatch()
}
+
func (r *failingReader) Err() error {
if r.opCount == 0 {
return fmt.Errorf("Expected error message")
diff --git a/arrow/cdata/import_allocator.go b/arrow/cdata/import_allocator.go
index d2cc44b7..2dea1336 100644
--- a/arrow/cdata/import_allocator.go
+++ b/arrow/cdata/import_allocator.go
@@ -29,6 +29,7 @@ import "C"
type importAllocator struct {
bufCount atomic.Int64
+ released atomic.Bool
arr *CArrowArray
}
@@ -49,6 +50,12 @@ func (i *importAllocator) Free([]byte) {
debug.Assert(i.bufCount.Load() > 0, "too many releases")
if i.bufCount.Add(-1) == 0 {
+ i.forceRelease()
+ }
+}
+
+func (i *importAllocator) forceRelease() {
+ if i.released.CompareAndSwap(false, true) {
defer C.free(unsafe.Pointer(i.arr))
C.ArrowArrayRelease(i.arr)
if C.ArrowArrayIsReleased(i.arr) != 1 {