This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new c32d2266bf GH-38757: [C#] Implement common interfaces for structure 
arrays and record batches (#38759)
c32d2266bf is described below

commit c32d2266bfc2e8730512fa8ad618502ded99b1bc
Author: Curt Hagenlocher <[email protected]>
AuthorDate: Fri Nov 17 14:44:44 2023 -0800

    GH-38757: [C#] Implement common interfaces for structure arrays and record 
batches (#38759)
    
    ### What changes are included in this PR?
    
    New interface IRecordType is implemented by both Schema and StructType.
    New interface IArrowRecord is implemented by both RecordBatch and 
StructArray.
    
    These changes make it easier to write code which handles both types of 
structures, which are virtually identical to each other.
    
    ### Are these changes tested?
    
    Yes.
    
    ### Are there any user-facing changes?
    
    New interfaces have been added to access existing structures.
    
    * Closes: #38757
    
    Authored-by: Curt Hagenlocher <[email protected]>
    Signed-off-by: Curt Hagenlocher <[email protected]>
---
 csharp/src/Apache.Arrow/Arrays/StructArray.cs      |  27 ++++-
 .../IArrowType.cs => Interfaces/IArrowRecord.cs}   |  52 ++-------
 csharp/src/Apache.Arrow/RecordBatch.cs             |  34 +++++-
 csharp/src/Apache.Arrow/Schema.cs                  |  34 +++++-
 csharp/src/Apache.Arrow/Types/IArrowType.cs        |   1 +
 .../Types/{IArrowType.cs => IRecordType.cs}        |  49 +--------
 csharp/src/Apache.Arrow/Types/StructType.cs        |  24 +++-
 csharp/test/Apache.Arrow.Tests/FieldComparer.cs    |  11 ++
 csharp/test/Apache.Arrow.Tests/RecordTests.cs      | 122 +++++++++++++++++++++
 9 files changed, 258 insertions(+), 96 deletions(-)

diff --git a/csharp/src/Apache.Arrow/Arrays/StructArray.cs 
b/csharp/src/Apache.Arrow/Arrays/StructArray.cs
index 31aea9b411..11d40e6d4e 100644
--- a/csharp/src/Apache.Arrow/Arrays/StructArray.cs
+++ b/csharp/src/Apache.Arrow/Arrays/StructArray.cs
@@ -20,7 +20,7 @@ using System.Threading;
 
 namespace Apache.Arrow
 {
-    public class StructArray : Array
+    public class StructArray : Array, IArrowRecord
     {
         private IReadOnlyList<IArrowArray> _fields;
 
@@ -44,7 +44,21 @@ namespace Apache.Arrow
             data.EnsureDataType(ArrowTypeId.Struct);
         }
 
-        public override void Accept(IArrowArrayVisitor visitor) => 
Accept(this, visitor);
+        public override void Accept(IArrowArrayVisitor visitor)
+        {
+            switch (visitor)
+            {
+                case IArrowArrayVisitor<StructArray> structArrayVisitor:
+                    structArrayVisitor.Visit(this);
+                    break;
+                case IArrowArrayVisitor<IArrowRecord> arrowStructVisitor:
+                    arrowStructVisitor.Visit(this);
+                    break;
+                default:
+                    visitor.Visit(this);
+                    break;
+            }
+        }
 
         private IReadOnlyList<IArrowArray> InitializeFields()
         {
@@ -55,5 +69,14 @@ namespace Apache.Arrow
             }
             return result;
         }
+
+        IRecordType IArrowRecord.Schema => (StructType)Data.DataType;
+
+        int IArrowRecord.ColumnCount => _fields.Count;
+
+        IArrowArray IArrowRecord.Column(string columnName, 
IEqualityComparer<string> comparer) =>
+            _fields[((StructType)Data.DataType).GetFieldIndex(columnName, 
comparer)];
+
+        IArrowArray IArrowRecord.Column(int columnIndex) => 
_fields[columnIndex];
     }
 }
diff --git a/csharp/src/Apache.Arrow/Types/IArrowType.cs 
b/csharp/src/Apache.Arrow/Interfaces/IArrowRecord.cs
similarity index 51%
copy from csharp/src/Apache.Arrow/Types/IArrowType.cs
copy to csharp/src/Apache.Arrow/Interfaces/IArrowRecord.cs
index cdf423e56f..126d214df2 100644
--- a/csharp/src/Apache.Arrow/Types/IArrowType.cs
+++ b/csharp/src/Apache.Arrow/Interfaces/IArrowRecord.cs
@@ -13,53 +13,17 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+using System.Collections.Generic;
+using Apache.Arrow.Types;
 
-namespace Apache.Arrow.Types
+namespace Apache.Arrow
 {
-    public enum ArrowTypeId
+    public interface IArrowRecord : IArrowArray
     {
-        Null,
-        Boolean,
-        UInt8,
-        Int8,
-        UInt16,
-        Int16,
-        UInt32,
-        Int32,
-        UInt64,
-        Int64,
-        HalfFloat,
-        Float,
-        Double,
-        String,
-        Binary,
-        FixedSizedBinary,
-        Date32,
-        Date64,
-        Timestamp,
-        Time32,
-        Time64,
-        Interval,
-        Decimal128,
-        Decimal256,
-        List,
-        Struct,
-        Union,
-        Dictionary,
-        Map,
-        FixedSizeList,
-        Duration,
-    }
-
-    public interface IArrowType
-    {
-        ArrowTypeId TypeId { get; }
-
-        string Name { get; }
- 
-        void Accept(IArrowTypeVisitor visitor);
+        IRecordType Schema { get; }
+        int ColumnCount { get; }
 
-        bool IsFixedWidth { get; }
-    
+        IArrowArray Column(string columnName, IEqualityComparer<string> 
comparer = default);
+        IArrowArray Column(int columnIndex);
     }
 }
diff --git a/csharp/src/Apache.Arrow/RecordBatch.cs 
b/csharp/src/Apache.Arrow/RecordBatch.cs
index 566c778302..9cc81b1648 100644
--- a/csharp/src/Apache.Arrow/RecordBatch.cs
+++ b/csharp/src/Apache.Arrow/RecordBatch.cs
@@ -19,10 +19,11 @@ using System.Collections.Generic;
 using System.Diagnostics;
 using System.Linq;
 using Apache.Arrow.Memory;
+using Apache.Arrow.Types;
 
 namespace Apache.Arrow
 {
-    public partial class RecordBatch : IDisposable
+    public partial class RecordBatch : IArrowRecord
     {
         public Schema Schema { get; }
         public int ColumnCount => _arrays.Count;
@@ -41,7 +42,12 @@ namespace Apache.Arrow
 
         public IArrowArray Column(string columnName)
         {
-            int fieldIndex = Schema.GetFieldIndex(columnName);
+            return Column(columnName, null);
+        }
+
+        public IArrowArray Column(string columnName, IEqualityComparer<string> 
comparer)
+        {
+            int fieldIndex = Schema.GetFieldIndex(columnName, comparer);
             return _arrays[fieldIndex];
         }
 
@@ -94,6 +100,30 @@ namespace Apache.Arrow
             return new RecordBatch(Schema, arrays, Length);
         }
 
+        public void Accept(IArrowArrayVisitor visitor)
+        {
+            switch (visitor)
+            {
+                case IArrowArrayVisitor<RecordBatch> recordBatchVisitor:
+                    recordBatchVisitor.Visit(this);
+                    break;
+                case IArrowArrayVisitor<IArrowRecord> arrowStructVisitor:
+                    arrowStructVisitor.Visit(this);
+                    break;
+                default:
+                    visitor.Visit(this);
+                    break;
+            }
+        }
+
         public override string ToString() => $"{nameof(RecordBatch)}: 
{ColumnCount} columns by {Length} rows";
+
+        IRecordType IArrowRecord.Schema => this.Schema;
+        int IArrowArray.NullCount => 0;
+        int IArrowArray.Offset => 0;
+        ArrayData IArrowArray.Data => throw new NotSupportedException("Unable 
to get data for RecordBatch");
+
+        bool IArrowArray.IsNull(int index) => false;
+        bool IArrowArray.IsValid(int index) => true;
     }
 }
diff --git a/csharp/src/Apache.Arrow/Schema.cs 
b/csharp/src/Apache.Arrow/Schema.cs
index 608b967630..4357e8b2dd 100644
--- a/csharp/src/Apache.Arrow/Schema.cs
+++ b/csharp/src/Apache.Arrow/Schema.cs
@@ -13,6 +13,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+using Apache.Arrow.Types;
 using System;
 using System.Collections.Generic;
 using System.Diagnostics;
@@ -20,7 +21,7 @@ using System.Linq;
 
 namespace Apache.Arrow
 {
-    public partial class Schema
+    public partial class Schema : IRecordType
     {
         [Obsolete("Use `FieldsList` or `FieldsLookup` instead")]
         public IReadOnlyDictionary<string, Field> Fields => _fieldsDictionary;
@@ -71,11 +72,17 @@ namespace Apache.Arrow
 
         public Field GetFieldByName(string name) => 
FieldsLookup[name].FirstOrDefault();
 
-        public int GetFieldIndex(string name, StringComparer comparer = 
default)
+        public int GetFieldIndex(string name, StringComparer comparer)
+        {
+            IEqualityComparer<string> equalityComparer = 
(IEqualityComparer<string>)comparer;
+            return GetFieldIndex(name, equalityComparer);
+        }
+
+        public int GetFieldIndex(string name, IEqualityComparer<string> 
comparer = default)
         {
             comparer ??= StringComparer.CurrentCulture;
 
-            return _fieldsList.IndexOf(_fieldsList.First(x => 
comparer.Compare(x.Name, name) == 0));
+            return _fieldsList.IndexOf(_fieldsList.First(x => 
comparer.Equals(x.Name, name)));
         }
 
         public Schema RemoveField(int fieldIndex)
@@ -115,6 +122,27 @@ namespace Apache.Arrow
             return new Schema(fields, Metadata);
         }
 
+        public void Accept(IArrowTypeVisitor visitor)
+        {
+            if (visitor is IArrowTypeVisitor<Schema> schemaVisitor)
+            {
+                schemaVisitor.Visit(this);
+            }
+            else if (visitor is IArrowTypeVisitor<IRecordType> 
interfaceVisitor)
+            {
+                interfaceVisitor.Visit(this);
+            }
+            else
+            {
+                visitor.Visit(this);
+            }
+        }
+
         public override string ToString() => $"{nameof(Schema)}: Num 
fields={_fieldsList.Count}, Num metadata={Metadata?.Count ?? 0}";
+
+        int IRecordType.FieldCount => _fieldsList.Count;
+        string IArrowType.Name => "RecordBatch";
+        ArrowTypeId IArrowType.TypeId => ArrowTypeId.RecordBatch;
+        bool IArrowType.IsFixedWidth => false;
     }
 }
diff --git a/csharp/src/Apache.Arrow/Types/IArrowType.cs 
b/csharp/src/Apache.Arrow/Types/IArrowType.cs
index cdf423e56f..5e107813be 100644
--- a/csharp/src/Apache.Arrow/Types/IArrowType.cs
+++ b/csharp/src/Apache.Arrow/Types/IArrowType.cs
@@ -49,6 +49,7 @@ namespace Apache.Arrow.Types
         Map,
         FixedSizeList,
         Duration,
+        RecordBatch,
     }
 
     public interface IArrowType
diff --git a/csharp/src/Apache.Arrow/Types/IArrowType.cs 
b/csharp/src/Apache.Arrow/Types/IRecordType.cs
similarity index 52%
copy from csharp/src/Apache.Arrow/Types/IArrowType.cs
copy to csharp/src/Apache.Arrow/Types/IRecordType.cs
index cdf423e56f..510edad23c 100644
--- a/csharp/src/Apache.Arrow/Types/IArrowType.cs
+++ b/csharp/src/Apache.Arrow/Types/IRecordType.cs
@@ -13,53 +13,16 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+using System.Collections.Generic;
 
 namespace Apache.Arrow.Types
 {
-    public enum ArrowTypeId
+    public interface IRecordType : IArrowType
     {
-        Null,
-        Boolean,
-        UInt8,
-        Int8,
-        UInt16,
-        Int16,
-        UInt32,
-        Int32,
-        UInt64,
-        Int64,
-        HalfFloat,
-        Float,
-        Double,
-        String,
-        Binary,
-        FixedSizedBinary,
-        Date32,
-        Date64,
-        Timestamp,
-        Time32,
-        Time64,
-        Interval,
-        Decimal128,
-        Decimal256,
-        List,
-        Struct,
-        Union,
-        Dictionary,
-        Map,
-        FixedSizeList,
-        Duration,
-    }
-
-    public interface IArrowType
-    {
-        ArrowTypeId TypeId { get; }
-
-        string Name { get; }
- 
-        void Accept(IArrowTypeVisitor visitor);
+        int FieldCount { get; }
 
-        bool IsFixedWidth { get; }
-    
+        Field GetFieldByIndex(int index);
+        Field GetFieldByName(string name);
+        int GetFieldIndex(string name, IEqualityComparer<string> comparer);
     }
 }
diff --git a/csharp/src/Apache.Arrow/Types/StructType.cs 
b/csharp/src/Apache.Arrow/Types/StructType.cs
index 79e83db165..da2411d34e 100644
--- a/csharp/src/Apache.Arrow/Types/StructType.cs
+++ b/csharp/src/Apache.Arrow/Types/StructType.cs
@@ -19,7 +19,7 @@ using System.Linq;
 
 namespace Apache.Arrow.Types
 {
-    public sealed class StructType : NestedType
+    public sealed class StructType : NestedType, IRecordType
     {
         public override ArrowTypeId TypeId => ArrowTypeId.Struct;
         public override string Name => "struct";
@@ -27,6 +27,8 @@ namespace Apache.Arrow.Types
         public StructType(IReadOnlyList<Field> fields) : base(fields)
         { }
 
+        public Field GetFieldByIndex(int index) => Fields[index];
+
         public Field GetFieldByName(string name,
             IEqualityComparer<string> comparer = default)
         {
@@ -56,6 +58,24 @@ namespace Apache.Arrow.Types
             return -1;
         }
 
-        public override void Accept(IArrowTypeVisitor visitor) => Accept(this, 
visitor);
+        public override void Accept(IArrowTypeVisitor visitor)
+        {
+            if (visitor is IArrowTypeVisitor<StructType> structTypeVisitor)
+            {
+                structTypeVisitor.Visit(this);
+            }
+            else if (visitor is IArrowTypeVisitor<IRecordType> 
interfaceVisitor)
+            {
+                interfaceVisitor.Visit(this);
+            }
+            else
+            {
+                visitor.Visit(this);
+            }
+        }
+
+        int IRecordType.FieldCount => Fields.Count;
+
+        Field IRecordType.GetFieldByName(string name) => GetFieldByName(name);
     }
 }
diff --git a/csharp/test/Apache.Arrow.Tests/FieldComparer.cs 
b/csharp/test/Apache.Arrow.Tests/FieldComparer.cs
index d7dcc398f2..06fd2abdd1 100644
--- a/csharp/test/Apache.Arrow.Tests/FieldComparer.cs
+++ b/csharp/test/Apache.Arrow.Tests/FieldComparer.cs
@@ -14,6 +14,7 @@
 // limitations under the License.
 
 using System.Linq;
+using Apache.Arrow.Types;
 using Xunit;
 
 namespace Apache.Arrow.Tests
@@ -40,5 +41,15 @@ namespace Apache.Arrow.Tests
 
             actual.DataType.Accept(new ArrayTypeComparer(expected.DataType));
         }
+
+        public static void Compare(IRecordType expected, IRecordType actual)
+        {
+            Assert.Equal(expected.FieldCount, actual.FieldCount);
+
+            for (int i = 0; i < expected.FieldCount; i++)
+            {
+                Compare(expected.GetFieldByIndex(i), 
actual.GetFieldByIndex(i));
+            }
+        }
     }
 }
diff --git a/csharp/test/Apache.Arrow.Tests/RecordTests.cs 
b/csharp/test/Apache.Arrow.Tests/RecordTests.cs
new file mode 100644
index 0000000000..09b0d2c665
--- /dev/null
+++ b/csharp/test/Apache.Arrow.Tests/RecordTests.cs
@@ -0,0 +1,122 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+using System.Text;
+using Apache.Arrow.Types;
+using Xunit;
+
+namespace Apache.Arrow.Tests
+{
+    public class RecordTests
+    {
+        [Fact]
+        public void StructArraysAndRecordBatchesAreSimilar()
+        {
+            Field stringField = new Field("column1", StringType.Default, true);
+            StringArray.Builder stringBuilder = new StringArray.Builder();
+            StringArray stringArray = 
stringBuilder.Append("joe").AppendNull().AppendNull().Append("mark").Build();
+
+            Field intField = new Field("column2", Int32Type.Default, true);
+            Int32Array.Builder intBuilder = new Int32Array.Builder();
+            Int32Array intArray = 
intBuilder.Append(1).Append(2).AppendNull().Append(4).Build();
+
+            Schema schema = new Schema(new[] { stringField, intField }, null);
+            RecordBatch batch = new RecordBatch(schema, new IArrowArray[] { 
stringArray, intArray }, intArray.Length);
+            IArrowRecord structArray1 = batch;
+
+            StructType structType = new StructType(new[] { stringField, 
intField });
+            StructArray structArray = new StructArray(structType, 
intArray.Length, new IArrowArray[] { stringArray, intArray }, 
ArrowBuffer.Empty);
+            IArrowRecord structArray2 = structArray;
+
+            FieldComparer.Compare(structArray1.Schema, structArray2.Schema);
+            Assert.Equal(structArray1.Length, structArray2.Length);
+            Assert.Equal(structArray1.ColumnCount, structArray2.ColumnCount);
+            Assert.Equal(structArray1.NullCount, structArray2.NullCount);
+
+            for (int i = 0; i < structArray1.ColumnCount; i++)
+            {
+                ArrowReaderVerifier.CompareArrays(structArray1.Column(i), 
structArray2.Column(i));
+            }
+        }
+
+        [Fact]
+        public void VisitStructAndBatch()
+        {
+            Field stringField = new Field("column1", StringType.Default, true);
+            StructType level1 = new StructType(new[] { stringField });
+            Field level1Field = new Field("column2", level1, false);
+            StructType level2 = new StructType(new[] { level1Field });
+            Field level2Field = new Field("column3", level2, true);
+            Schema schema = new Schema(new[] { level2Field }, null);
+
+            var visitor1 = new TestTypeVisitor1();
+            visitor1.Visit(schema);
+            Assert.Equal("111utf8", visitor1.stringBuilder.ToString());
+            var visitor2 = new TestTypeVisitor2();
+            visitor2.Visit(schema);
+            Assert.Equal("322utf8", visitor2.stringBuilder.ToString());
+
+            StringArray stringArray = new 
StringArray.Builder().Append("one").AppendNull().AppendNull().Append("four").Build();
+            StructArray level1Array = new StructArray(level1, 
stringArray.Length, new[] { stringArray }, ArrowBuffer.Empty);
+            ArrowBuffer nulls = new 
ArrowBuffer.BitmapBuilder(stringArray.Length).Append(false).Append(false).Append(true).Append(false).Build();
+            StructArray level2Array = new StructArray(level2, 
stringArray.Length, new[] { level1Array }, nulls);
+            RecordBatch batch = new RecordBatch(schema, new IArrowArray[] { 
level2Array }, stringArray.Length);
+
+
+        }
+
+        private class TestTypeVisitor1 : IArrowTypeVisitor, 
IArrowTypeVisitor<IRecordType>
+        {
+            public StringBuilder stringBuilder = new StringBuilder();
+
+            public void Visit(IArrowType type) { 
stringBuilder.Append(type.Name); }
+            public void Visit(IRecordType type) { stringBuilder.Append('1'); 
VisitFields(type); }
+
+            protected void VisitFields(IRecordType type)
+            {
+                for (int i = 0; i < type.FieldCount; i++) { 
type.GetFieldByIndex(i).DataType.Accept(this); }
+            }
+        }
+
+        private class TestTypeVisitor2 : TestTypeVisitor1,
+            IArrowTypeVisitor<StructType>,
+            IArrowTypeVisitor<Schema>
+        {
+            public void Visit(StructType type) { stringBuilder.Append('2'); 
VisitFields(type); }
+            public void Visit(Schema type) { stringBuilder.Append('3'); 
VisitFields(type); }
+        }
+
+        private class TestArrayVisitor1 : IArrowArrayVisitor, 
IArrowArrayVisitor<IArrowRecord>
+        {
+            public StringBuilder stringBuilder = new StringBuilder();
+
+            public void Visit(IArrowArray array) { 
stringBuilder.Append(array.Data.DataType.Name); }
+            public void Visit(IArrowRecord array) { stringBuilder.Append('1'); 
VisitFields(array); }
+
+            protected void VisitFields(IArrowRecord array)
+            {
+                for (int i = 0; i < array.ColumnCount; i++) { 
array.Column(i).Accept(this); }
+            }
+        }
+
+        private class TestArrayVisitor2 : TestArrayVisitor1,
+            IArrowArrayVisitor<StructArray>,
+            IArrowArrayVisitor<RecordBatch>
+        {
+            public void Visit(StructArray array) { stringBuilder.Append('2'); 
VisitFields(array); }
+            public void Visit(RecordBatch batch) { stringBuilder.Append('3'); 
VisitFields(batch); }
+        }
+    }
+}

Reply via email to