This is an automated email from the ASF dual-hosted git repository.

curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new e46abba57 feat(csharp/src/Drivers/Apache): convert Double to Float for Apache Spark on scalar conversion (#2296)
e46abba57 is described below

commit e46abba579d70c6fec81704e7920c1a8787103c3
Author: Bruce Irschick <[email protected]>
AuthorDate: Fri Nov 1 11:00:53 2024 -0700

    feat(csharp/src/Drivers/Apache): convert Double to Float for Apache Spark on scalar conversion (#2296)
    
    If the Apache Spark server indicates that a column's host data type is
    FLOAT and the `scalar` data type conversion is selected, convert the
    Double array to a Float array.
    
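    * Performs the Double-to-Float conversion described above
    * Updates tests
    * Corrects the table-creation SQL statement for older versions of the Spark
      server (uses `CREATE TABLE IF NOT EXISTS` instead of
      `CREATE OR REPLACE TABLE IF NOT EXISTS`).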
---
 .../src/Drivers/Apache/Hive2/HiveServer2Reader.cs  | 41 +++++++++++++++++++---
 .../Apache/Hive2/HiveServer2SchemaParser.cs        |  4 +--
 csharp/src/Drivers/Apache/Spark/README.md          |  2 +-
 .../test/Drivers/Apache/Spark/NumericValueTests.cs |  2 +-
 .../Drivers/Apache/Spark/Resources/SparkData.sql   |  2 +-
 .../Drivers/Apache/Spark/SparkTestEnvironment.cs   |  9 +++--
 6 files changed, 47 insertions(+), 13 deletions(-)

diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs 
b/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs
index e1b711aba..08b0675d0 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs
@@ -63,6 +63,11 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
                 { ArrowTypeId.Decimal128, ConvertToDecimal128 },
                 { ArrowTypeId.Timestamp, ConvertToTimestamp },
             };
+        private static readonly IReadOnlyDictionary<ArrowTypeId, 
Func<DoubleArray, IArrowType, IArrowArray>> s_arrowDoubleConverters =
+            new Dictionary<ArrowTypeId, Func<DoubleArray, IArrowType, 
IArrowArray>>()
+            {
+                { ArrowTypeId.Float, ConvertToFloat },
+            };
 
         public HiveServer2Reader(
             HiveServer2Statement statement,
@@ -147,14 +152,22 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
                 Func<StringArray, IArrowType, IArrowArray> converter = 
s_arrowStringConverters[expectedArrowType.TypeId];
                 return converter(stringArray, expectedArrowType);
             }
+            else if (expectedArrowType != null && arrowArray is DoubleArray 
doubleArray && s_arrowDoubleConverters.ContainsKey(expectedArrowType.TypeId))
+            {
+                // Perform a conversion from double to another (float) type.
+                Func<DoubleArray, IArrowType, IArrowArray> converter = 
s_arrowDoubleConverters[expectedArrowType.TypeId];
+                return converter(doubleArray, expectedArrowType);
+            }
             return arrowArray;
         }
 
         internal static Date32Array ConvertToDate32(StringArray array, 
IArrowType _)
         {
             const DateTimeStyles DateTimeStyles = 
DateTimeStyles.AllowWhiteSpaces;
-            var resultArray = new Date32Array.Builder();
             int length = array.Length;
+            var resultArray = new Date32Array
+                .Builder()
+                .Reserve(length);
             for (int i = 0; i < length; i++)
             {
                 // Work with UTF8 string.
@@ -178,6 +191,20 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
             return resultArray.Build();
         }
 
+        internal static FloatArray ConvertToFloat(DoubleArray array, 
IArrowType _)
+        {
+            int length = array.Length;
+            var resultArray = new FloatArray
+                .Builder()
+                .Reserve(length);
+            for (int i = 0; i < length; i++)
+            {
+                resultArray.Append((float?)array.GetValue(i));
+            }
+
+            return resultArray.Build();
+        }
+
         internal static bool TryParse(ReadOnlySpan<byte> date, out DateTime 
dateTime)
         {
             if (date.Length == KnownFormatDateLength
@@ -204,12 +231,14 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
 
         private static Decimal128Array ConvertToDecimal128(StringArray array, 
IArrowType schemaType)
         {
+            int length = array.Length;
             // Using the schema type to get the precision and scale.
             Decimal128Type decimalType = (Decimal128Type)schemaType;
-            var resultArray = new Decimal128Array.Builder(decimalType);
+            var resultArray = new Decimal128Array
+                .Builder(decimalType)
+                .Reserve(length);
             Span<byte> buffer = stackalloc byte[decimalType.ByteWidth];
 
-            int length = array.Length;
             for (int i = 0; i < length; i++)
             {
                 // Work with UTF8 string.
@@ -235,9 +264,11 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
         internal static TimestampArray ConvertToTimestamp(StringArray array, 
IArrowType _)
         {
             const DateTimeStyles DateTimeStyles = 
DateTimeStyles.AssumeUniversal | DateTimeStyles.AllowWhiteSpaces;
-            // Match the precision of the server
-            var resultArrayBuilder = new 
TimestampArray.Builder(TimeUnit.Microsecond);
             int length = array.Length;
+            // Match the precision of the server
+            var resultArrayBuilder = new TimestampArray
+                .Builder(TimeUnit.Microsecond)
+                .Reserve(length);
             for (int i = 0; i < length; i++)
             {
                 // Work with UTF8 string.
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2SchemaParser.cs 
b/csharp/src/Drivers/Apache/Hive2/HiveServer2SchemaParser.cs
index 913f4d114..5ef4c01cb 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2SchemaParser.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2SchemaParser.cs
@@ -31,8 +31,8 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
                 TTypeId.BIGINT_TYPE => Int64Type.Default,
                 TTypeId.BINARY_TYPE => BinaryType.Default,
                 TTypeId.BOOLEAN_TYPE => BooleanType.Default,
-                TTypeId.DOUBLE_TYPE
-                or TTypeId.FLOAT_TYPE => DoubleType.Default,
+                TTypeId.DOUBLE_TYPE => DoubleType.Default,
+                TTypeId.FLOAT_TYPE => convertScalar ? FloatType.Default : 
DoubleType.Default,
                 TTypeId.INT_TYPE => Int32Type.Default,
                 TTypeId.SMALLINT_TYPE => Int16Type.Default,
                 TTypeId.TINYINT_TYPE => Int8Type.Default,
diff --git a/csharp/src/Drivers/Apache/Spark/README.md 
b/csharp/src/Drivers/Apache/Spark/README.md
index 0fddb4838..7d1f8b560 100644
--- a/csharp/src/Drivers/Apache/Spark/README.md
+++ b/csharp/src/Drivers/Apache/Spark/README.md
@@ -84,7 +84,7 @@ The following table depicts how the Spark ADBC driver 
converts a Spark type to a
 | DATE*                | *String*   | *string* | Date32 | DateTime |
 | DECIMAL*             | *String*   | *string* | Decimal128 | SqlDecimal |
 | DOUBLE               | Double     | double | | |
-| FLOAT                | *Double*   | *double* | | |
+| FLOAT                | *Double*   | *double* | Float | float |
 | INT                  | Int32      | int | | |
 | INTERVAL_DAY_TIME+   | String     | string | | |
 | INTERVAL_YEAR_MONTH+ | String     | string | | |
diff --git a/csharp/test/Drivers/Apache/Spark/NumericValueTests.cs 
b/csharp/test/Drivers/Apache/Spark/NumericValueTests.cs
index db041cc04..326caff13 100644
--- a/csharp/test/Drivers/Apache/Spark/NumericValueTests.cs
+++ b/csharp/test/Drivers/Apache/Spark/NumericValueTests.cs
@@ -263,7 +263,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
             await InsertSingleValueAsync(table.TableName, columnName, 
valueString);
             object doubleValue = (double)value;
             // Spark over HTTP returns float as double whereas Spark on 
Databricks returns float.
-            object floatValue = TestEnvironment.ServerType != 
SparkServerType.Databricks ? doubleValue : value;
+            object floatValue = TestEnvironment.ServerType == 
SparkServerType.Databricks || 
TestEnvironment.DataTypeConversion.HasFlag(DataTypeConversion.Scalar) ? value : 
doubleValue;
             await base.SelectAndValidateValuesAsync(table.TableName, 
columnName, floatValue, 1);
             string whereClause = GetWhereClause(columnName, value);
             if (SupportsDelete) await DeleteFromTableAsync(table.TableName, 
whereClause, 1);
diff --git a/csharp/test/Drivers/Apache/Spark/Resources/SparkData.sql 
b/csharp/test/Drivers/Apache/Spark/Resources/SparkData.sql
index 8ee0f7e90..9c0a41e0b 100644
--- a/csharp/test/Drivers/Apache/Spark/Resources/SparkData.sql
+++ b/csharp/test/Drivers/Apache/Spark/Resources/SparkData.sql
@@ -14,7 +14,7 @@
  -- See the License for the specific language governing permissions and
  -- limitations under the License.
 
-CREATE OR REPLACE TABLE IF NOT EXISTS 
{ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} (
+CREATE TABLE IF NOT EXISTS {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} (
   id LONG,
   byte BYTE,
   short SHORT,
diff --git a/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs 
b/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs
index 7416772a4..06b6fe246 100644
--- a/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs
+++ b/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs
@@ -123,6 +123,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
 
         internal SparkServerType ServerType => 
((SparkConnection)Connection).ServerType;
 
+        internal DataTypeConversion DataTypeConversion => 
((SparkConnection)Connection).DataTypeConversion;
+
         public override string VendorVersion => 
((HiveServer2Connection)Connection).VendorVersion;
 
         public override bool SupportsDelete => ServerType == 
SparkServerType.Databricks;
@@ -139,10 +141,11 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
         public override SampleDataBuilder GetSampleDataBuilder()
         {
             SampleDataBuilder sampleDataBuilder = new();
-            Type floatNetType = ServerType == SparkServerType.Databricks ? 
typeof(float) : typeof(double);
-            Type floatArrowType = ServerType == SparkServerType.Databricks ? 
typeof(FloatType) : typeof(DoubleType);
+            bool dataTypeIsFloat = ServerType == SparkServerType.Databricks || 
DataTypeConversion.HasFlag(DataTypeConversion.Scalar);
+            Type floatNetType = dataTypeIsFloat ? typeof(float) : 
typeof(double);
+            Type floatArrowType = dataTypeIsFloat ? typeof(FloatType) : 
typeof(DoubleType);
             object floatValue;
-            if (ServerType == SparkServerType.Databricks)
+            if (dataTypeIsFloat)
                 floatValue = 1f;
             else
                 floatValue = 1d;

Reply via email to