This is an automated email from the ASF dual-hosted git repository.
curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new e46abba57 feat(csharp/src/Drivers/Apache): convert Double to Float for
Apache Spark on scalar conversion (#2296)
e46abba57 is described below
commit e46abba579d70c6fec81704e7920c1a8787103c3
Author: Bruce Irschick <[email protected]>
AuthorDate: Fri Nov 1 11:00:53 2024 -0700
feat(csharp/src/Drivers/Apache): convert Double to Float for Apache Spark
on scalar conversion (#2296)
If the Apache Spark server indicates the host data type is FLOAT, and the
`scalar` data type conversion is selected, convert the Double array to a
Float array.
* Performs conversion as prescribed
* Updates tests
* Corrects the SQL statement for older versions of the Spark server.
---
.../src/Drivers/Apache/Hive2/HiveServer2Reader.cs | 41 +++++++++++++++++++---
.../Apache/Hive2/HiveServer2SchemaParser.cs | 4 +--
csharp/src/Drivers/Apache/Spark/README.md | 2 +-
.../test/Drivers/Apache/Spark/NumericValueTests.cs | 2 +-
.../Drivers/Apache/Spark/Resources/SparkData.sql | 2 +-
.../Drivers/Apache/Spark/SparkTestEnvironment.cs | 9 +++--
6 files changed, 47 insertions(+), 13 deletions(-)
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs
b/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs
index e1b711aba..08b0675d0 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs
@@ -63,6 +63,11 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
{ ArrowTypeId.Decimal128, ConvertToDecimal128 },
{ ArrowTypeId.Timestamp, ConvertToTimestamp },
};
+ private static readonly IReadOnlyDictionary<ArrowTypeId,
Func<DoubleArray, IArrowType, IArrowArray>> s_arrowDoubleConverters =
+ new Dictionary<ArrowTypeId, Func<DoubleArray, IArrowType,
IArrowArray>>()
+ {
+ { ArrowTypeId.Float, ConvertToFloat },
+ };
public HiveServer2Reader(
HiveServer2Statement statement,
@@ -147,14 +152,22 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
Func<StringArray, IArrowType, IArrowArray> converter =
s_arrowStringConverters[expectedArrowType.TypeId];
return converter(stringArray, expectedArrowType);
}
+ else if (expectedArrowType != null && arrowArray is DoubleArray
doubleArray && s_arrowDoubleConverters.ContainsKey(expectedArrowType.TypeId))
+ {
+ // Perform a conversion from double to another (float) type.
+ Func<DoubleArray, IArrowType, IArrowArray> converter =
s_arrowDoubleConverters[expectedArrowType.TypeId];
+ return converter(doubleArray, expectedArrowType);
+ }
return arrowArray;
}
internal static Date32Array ConvertToDate32(StringArray array,
IArrowType _)
{
const DateTimeStyles DateTimeStyles =
DateTimeStyles.AllowWhiteSpaces;
- var resultArray = new Date32Array.Builder();
int length = array.Length;
+ var resultArray = new Date32Array
+ .Builder()
+ .Reserve(length);
for (int i = 0; i < length; i++)
{
// Work with UTF8 string.
@@ -178,6 +191,20 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
return resultArray.Build();
}
+ internal static FloatArray ConvertToFloat(DoubleArray array,
IArrowType _)
+ {
+ int length = array.Length;
+ var resultArray = new FloatArray
+ .Builder()
+ .Reserve(length);
+ for (int i = 0; i < length; i++)
+ {
+ resultArray.Append((float?)array.GetValue(i));
+ }
+
+ return resultArray.Build();
+ }
+
internal static bool TryParse(ReadOnlySpan<byte> date, out DateTime
dateTime)
{
if (date.Length == KnownFormatDateLength
@@ -204,12 +231,14 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
private static Decimal128Array ConvertToDecimal128(StringArray array,
IArrowType schemaType)
{
+ int length = array.Length;
// Using the schema type to get the precision and scale.
Decimal128Type decimalType = (Decimal128Type)schemaType;
- var resultArray = new Decimal128Array.Builder(decimalType);
+ var resultArray = new Decimal128Array
+ .Builder(decimalType)
+ .Reserve(length);
Span<byte> buffer = stackalloc byte[decimalType.ByteWidth];
- int length = array.Length;
for (int i = 0; i < length; i++)
{
// Work with UTF8 string.
@@ -235,9 +264,11 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
internal static TimestampArray ConvertToTimestamp(StringArray array,
IArrowType _)
{
const DateTimeStyles DateTimeStyles =
DateTimeStyles.AssumeUniversal | DateTimeStyles.AllowWhiteSpaces;
- // Match the precision of the server
- var resultArrayBuilder = new
TimestampArray.Builder(TimeUnit.Microsecond);
int length = array.Length;
+ // Match the precision of the server
+ var resultArrayBuilder = new TimestampArray
+ .Builder(TimeUnit.Microsecond)
+ .Reserve(length);
for (int i = 0; i < length; i++)
{
// Work with UTF8 string.
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2SchemaParser.cs
b/csharp/src/Drivers/Apache/Hive2/HiveServer2SchemaParser.cs
index 913f4d114..5ef4c01cb 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2SchemaParser.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2SchemaParser.cs
@@ -31,8 +31,8 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
TTypeId.BIGINT_TYPE => Int64Type.Default,
TTypeId.BINARY_TYPE => BinaryType.Default,
TTypeId.BOOLEAN_TYPE => BooleanType.Default,
- TTypeId.DOUBLE_TYPE
- or TTypeId.FLOAT_TYPE => DoubleType.Default,
+ TTypeId.DOUBLE_TYPE => DoubleType.Default,
+ TTypeId.FLOAT_TYPE => convertScalar ? FloatType.Default :
DoubleType.Default,
TTypeId.INT_TYPE => Int32Type.Default,
TTypeId.SMALLINT_TYPE => Int16Type.Default,
TTypeId.TINYINT_TYPE => Int8Type.Default,
diff --git a/csharp/src/Drivers/Apache/Spark/README.md
b/csharp/src/Drivers/Apache/Spark/README.md
index 0fddb4838..7d1f8b560 100644
--- a/csharp/src/Drivers/Apache/Spark/README.md
+++ b/csharp/src/Drivers/Apache/Spark/README.md
@@ -84,7 +84,7 @@ The following table depicts how the Spark ADBC driver
converts a Spark type to a
| DATE* | *String* | *string* | Date32 | DateTime |
| DECIMAL* | *String* | *string* | Decimal128 | SqlDecimal |
| DOUBLE | Double | double | | |
-| FLOAT | *Double* | *double* | | |
+| FLOAT | *Double* | *double* | Float | float |
| INT | Int32 | int | | |
| INTERVAL_DAY_TIME+ | String | string | | |
| INTERVAL_YEAR_MONTH+ | String | string | | |
diff --git a/csharp/test/Drivers/Apache/Spark/NumericValueTests.cs
b/csharp/test/Drivers/Apache/Spark/NumericValueTests.cs
index db041cc04..326caff13 100644
--- a/csharp/test/Drivers/Apache/Spark/NumericValueTests.cs
+++ b/csharp/test/Drivers/Apache/Spark/NumericValueTests.cs
@@ -263,7 +263,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
await InsertSingleValueAsync(table.TableName, columnName,
valueString);
object doubleValue = (double)value;
// Spark over HTTP returns float as double whereas Spark on
Databricks returns float.
- object floatValue = TestEnvironment.ServerType !=
SparkServerType.Databricks ? doubleValue : value;
+ object floatValue = TestEnvironment.ServerType ==
SparkServerType.Databricks ||
TestEnvironment.DataTypeConversion.HasFlag(DataTypeConversion.Scalar) ? value :
doubleValue;
await base.SelectAndValidateValuesAsync(table.TableName,
columnName, floatValue, 1);
string whereClause = GetWhereClause(columnName, value);
if (SupportsDelete) await DeleteFromTableAsync(table.TableName,
whereClause, 1);
diff --git a/csharp/test/Drivers/Apache/Spark/Resources/SparkData.sql
b/csharp/test/Drivers/Apache/Spark/Resources/SparkData.sql
index 8ee0f7e90..9c0a41e0b 100644
--- a/csharp/test/Drivers/Apache/Spark/Resources/SparkData.sql
+++ b/csharp/test/Drivers/Apache/Spark/Resources/SparkData.sql
@@ -14,7 +14,7 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.
-CREATE OR REPLACE TABLE IF NOT EXISTS
{ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} (
+CREATE TABLE IF NOT EXISTS {ADBC_CATALOG}.{ADBC_DATASET}.{ADBC_TABLE} (
id LONG,
byte BYTE,
short SHORT,
diff --git a/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs
b/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs
index 7416772a4..06b6fe246 100644
--- a/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs
+++ b/csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs
@@ -123,6 +123,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
internal SparkServerType ServerType =>
((SparkConnection)Connection).ServerType;
+ internal DataTypeConversion DataTypeConversion =>
((SparkConnection)Connection).DataTypeConversion;
+
public override string VendorVersion =>
((HiveServer2Connection)Connection).VendorVersion;
public override bool SupportsDelete => ServerType ==
SparkServerType.Databricks;
@@ -139,10 +141,11 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
public override SampleDataBuilder GetSampleDataBuilder()
{
SampleDataBuilder sampleDataBuilder = new();
- Type floatNetType = ServerType == SparkServerType.Databricks ?
typeof(float) : typeof(double);
- Type floatArrowType = ServerType == SparkServerType.Databricks ?
typeof(FloatType) : typeof(DoubleType);
+ bool dataTypeIsFloat = ServerType == SparkServerType.Databricks ||
DataTypeConversion.HasFlag(DataTypeConversion.Scalar);
+ Type floatNetType = dataTypeIsFloat ? typeof(float) :
typeof(double);
+ Type floatArrowType = dataTypeIsFloat ? typeof(FloatType) :
typeof(DoubleType);
object floatValue;
- if (ServerType == SparkServerType.Databricks)
+ if (dataTypeIsFloat)
floatValue = 1f;
else
floatValue = 1d;