This is an automated email from the ASF dual-hosted git repository.
curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 319ba01b6 feat(csharp/src/Drivers/Apache/Spark): extend SQL type name
parsing for all types (#1911)
319ba01b6 is described below
commit 319ba01b60109e824fdf87b30dc2a9f394767adf
Author: Bruce Irschick <[email protected]>
AuthorDate: Thu Jun 13 09:53:40 2024 -0700
feat(csharp/src/Drivers/Apache/Spark): extend SQL type name parsing for all
types (#1911)
Extend SQL type name parsing to all possible types for Spark.
Additional support for:
* ARRAY
* BIGINT
* BINARY
* BOOLEAN
* DATE
* DOUBLE
* FLOAT
* INTEGER
* JAVA_OBJECT
* SMALLINT
* STRUCT
* TIMESTAMP
* TIMESTAMP_WITH_TIMEZONE
* TINYINT
Add extensive tests for SQL type name parsing
---
csharp/src/Drivers/Apache/Spark/SparkConnection.cs | 42 +--
.../src/Drivers/Apache/Spark/SqlTypeNameParser.cs | 387 +++++++++++++++++++--
csharp/test/Drivers/Apache/Spark/DriverTests.cs | 3 +
.../Drivers/Apache/Spark/SqlTypeNameParserTests.cs | 316 +++++++++++++++++
4 files changed, 698 insertions(+), 50 deletions(-)
diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
index 798fdeec2..45111446c 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
@@ -708,7 +708,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
case (short)ColumnTypeId.DECIMAL:
case (short)ColumnTypeId.NUMERIC:
{
- SqlDecimalParserResult result = new
SqlDecimalTypeParser().ParseOrDefault(typeName, new
SqlDecimalParserResult(typeName));
+ SqlDecimalParserResult result =
SqlTypeNameParser<SqlDecimalParserResult>.Parse(typeName, colType);
tableInfo?.Precision.Add(result.Precision);
tableInfo?.Scale.Add((short)result.Scale);
tableInfo?.BaseTypeName.Add(result.BaseTypeName);
@@ -717,30 +717,26 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
case (short)ColumnTypeId.CHAR:
case (short)ColumnTypeId.NCHAR:
- {
- bool success = new
SqlCharTypeParser().TryParse(typeName, out SqlCharVarcharParserResult? result);
- tableInfo?.Precision.Add(success ? result!.ColumnSize
: SqlVarcharTypeParser.VarcharColumnSizeDefault);
- tableInfo?.Scale.Add(null);
- tableInfo?.BaseTypeName.Add(success ?
result!.BaseTypeName : "CHAR");
- break;
- }
case (short)ColumnTypeId.VARCHAR:
case (short)ColumnTypeId.LONGVARCHAR:
case (short)ColumnTypeId.LONGNVARCHAR:
case (short)ColumnTypeId.NVARCHAR:
{
- bool success = new
SqlVarcharTypeParser().TryParse(typeName, out SqlCharVarcharParserResult?
result);
- tableInfo?.Precision.Add(success ? result!.ColumnSize
: SqlVarcharTypeParser.VarcharColumnSizeDefault);
+ SqlCharVarcharParserResult result =
SqlTypeNameParser<SqlCharVarcharParserResult>.Parse(typeName, colType);
+ tableInfo?.Precision.Add(result.ColumnSize);
tableInfo?.Scale.Add(null);
- tableInfo?.BaseTypeName.Add(success ?
result!.BaseTypeName : "STRING");
+ tableInfo?.BaseTypeName.Add(result.BaseTypeName);
break;
}
default:
- tableInfo?.Precision.Add(null);
- tableInfo?.Scale.Add(null);
- tableInfo?.BaseTypeName.Add(typeName);
- break;
+ {
+ bool parsed = SqlTypeNameParser<SqlTypeNameParserResult>.TryParse(typeName, out SqlTypeNameParserResult? result, colType);
+ tableInfo?.Precision.Add(null);
+ tableInfo?.Scale.Add(null);
+ tableInfo?.BaseTypeName.Add(parsed && result != null ? result.BaseTypeName : typeName); // unrecognized type names keep their raw name instead of failing GetObjects
+ break;
+ }
}
}
@@ -783,8 +779,8 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
case (int)ColumnTypeId.NUMERIC:
// Note: parsing the type name for SQL DECIMAL types as
the precision and scale values
// are not returned in the Thrift call to GetColumns
- return new SqlDecimalTypeParser()
- .ParseOrDefault(typeName, new
SqlDecimalParserResult(typeName))
+ return SqlTypeNameParser<SqlDecimalParserResult>
+ .Parse(typeName, columnTypeId)
.Decimal128Type;
case (int)ColumnTypeId.NULL:
return NullType.Default;
@@ -797,7 +793,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
}
}
- private StructArray GetDbSchemas(
+ private static StructArray GetDbSchemas(
GetObjectsDepth depth,
Dictionary<string, Dictionary<string, TableInfo>> schemaMap)
{
@@ -841,7 +837,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
nullBitmapBuffer.Build());
}
- private StructArray GetTableSchemas(
+ private static StructArray GetTableSchemas(
GetObjectsDepth depth,
Dictionary<string, TableInfo> tableMap)
{
@@ -892,7 +888,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
nullBitmapBuffer.Build());
}
- private StructArray GetColumnSchema(TableInfo tableInfo)
+ private static StructArray GetColumnSchema(TableInfo tableInfo)
{
StringArray.Builder columnNameBuilder = new StringArray.Builder();
Int32Array.Builder ordinalPositionBuilder = new
Int32Array.Builder();
@@ -976,7 +972,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
nullBitmapBuffer.Build());
}
- private string PatternToRegEx(string? pattern)
+ private static string PatternToRegEx(string? pattern)
{
if (pattern == null)
return ".*";
@@ -984,13 +980,13 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
StringBuilder builder = new StringBuilder("(?i)^");
string convertedPattern = pattern.Replace("_", ".").Replace("%",
".*");
builder.Append(convertedPattern);
- builder.Append("$");
+ builder.Append('$');
return builder.ToString();
}
- private string GetProductVersion()
+ private static string GetProductVersion()
{
FileVersionInfo fileVersionInfo =
FileVersionInfo.GetVersionInfo(Assembly.GetExecutingAssembly().Location);
return fileVersionInfo.ProductVersion ?? ProductVersionDefault;
diff --git a/csharp/src/Drivers/Apache/Spark/SqlTypeNameParser.cs
b/csharp/src/Drivers/Apache/Spark/SqlTypeNameParser.cs
index 915fb1a19..ab56703eb 100644
--- a/csharp/src/Drivers/Apache/Spark/SqlTypeNameParser.cs
+++ b/csharp/src/Drivers/Apache/Spark/SqlTypeNameParser.cs
@@ -16,29 +16,133 @@
*/
using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Linq;
using System.Text.RegularExpressions;
using Apache.Arrow.Types;
namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
{
+ /// <summary>
+ /// Common, non-generic contract for the SQL type name parsers, so parsers with different result types can be tried uniformly.
+ /// </summary>
+ internal interface ISqlTypeNameParser
+ {
+ /// <summary>
+ /// Tries to parse the input string for a valid SQL type definition.
+ /// </summary>
+ /// <param name="input">The SQL type definition string to parse.</param>
+ /// <param name="result">If successful, the result; otherwise <c>null</c>.</param>
+ /// <returns>True if it can successfully parse the type definition input string; otherwise false.</returns>
+ bool TryParse(string input, out SqlTypeNameParserResult? result);
+ }
+
/// <summary>
/// Abstract and generic SQL data type name parser.
/// </summary>
- /// <typeparam name="T">The <see cref="ParserResult"/> type when returning
a successful parse</typeparam>
- internal abstract class SqlTypeNameParser<T> where T : ParserResult
+ /// <typeparam name="T">The <see cref="SqlTypeNameParserResult"/> type
when returning a successful parse</typeparam>
+ internal abstract class SqlTypeNameParser<T> : ISqlTypeNameParser where T
: SqlTypeNameParserResult
{
+ private static readonly ConcurrentDictionary<string,
SqlTypeNameParserResult> s_cache = new();
+
+ private static readonly IReadOnlyDictionary<int, ISqlTypeNameParser>
s_parserMap = new Dictionary<int, ISqlTypeNameParser>()
+ {
+ { (int)SparkConnection.ColumnTypeId.ARRAY,
SqlArrayTypeParser.Default },
+ { (int)SparkConnection.ColumnTypeId.BIGINT,
SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.BIGINT.ToString()) },
+ { (int)SparkConnection.ColumnTypeId.BIT,
SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.BIT.ToString()) },
+ { (int)SparkConnection.ColumnTypeId.BINARY,
SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.BINARY.ToString()) },
+ { (int)SparkConnection.ColumnTypeId.BLOB,
SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.BLOB.ToString()) },
+ { (int)SparkConnection.ColumnTypeId.BOOLEAN,
SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.BOOLEAN.ToString()) },
+ { (int)SparkConnection.ColumnTypeId.CHAR,
SqlCharTypeParser.Default },
+ { (int)SparkConnection.ColumnTypeId.CLOB,
SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.CLOB.ToString()) },
+ { (int)SparkConnection.ColumnTypeId.DATALINK,
SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.DATALINK.ToString()) },
+ { (int)SparkConnection.ColumnTypeId.DATE,
SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.DATE.ToString()) },
+ { (int)SparkConnection.ColumnTypeId.DECIMAL,
SqlDecimalTypeParser.Default },
+ { (int)SparkConnection.ColumnTypeId.DISTINCT,
SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.DISTINCT.ToString()) },
+ { (int)SparkConnection.ColumnTypeId.DOUBLE,
SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.DOUBLE.ToString()) },
+ { (int)SparkConnection.ColumnTypeId.FLOAT,
SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.FLOAT.ToString()) },
+ { (int)SparkConnection.ColumnTypeId.INTEGER,
SqlIntegerTypeParser.Default },
+ { (int)SparkConnection.ColumnTypeId.JAVA_OBJECT,
SqlMapTypeParser.Default },
+ { (int)SparkConnection.ColumnTypeId.LONGNVARCHAR,
SqlVarcharTypeParser.Default },
+ { (int)SparkConnection.ColumnTypeId.LONGVARCHAR,
SqlVarcharTypeParser.Default },
+ { (int)SparkConnection.ColumnTypeId.NCHAR,
SqlCharTypeParser.Default },
+ { (int)SparkConnection.ColumnTypeId.NCLOB,
SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.NCLOB.ToString()) },
+ { (int)SparkConnection.ColumnTypeId.NULL,
SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.NULL.ToString()) },
+ { (int)SparkConnection.ColumnTypeId.NUMERIC,
SqlDecimalTypeParser.Default },
+ { (int)SparkConnection.ColumnTypeId.NVARCHAR,
SqlVarcharTypeParser.Default },
+ { (int)SparkConnection.ColumnTypeId.OTHER,
SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.OTHER.ToString()) },
+ { (int)SparkConnection.ColumnTypeId.REAL,
SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.REAL.ToString()) },
+ { (int)SparkConnection.ColumnTypeId.REF_CURSOR,
SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.REF_CURSOR.ToString())
},
+ { (int)SparkConnection.ColumnTypeId.REF,
SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.REF.ToString()) },
+ { (int)SparkConnection.ColumnTypeId.ROWID,
SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.ROWID.ToString()) },
+ { (int)SparkConnection.ColumnTypeId.SMALLINT,
SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.SMALLINT.ToString()) },
+ { (int)SparkConnection.ColumnTypeId.STRUCT,
SqlStructTypeParser.Default },
+ { (int)SparkConnection.ColumnTypeId.TIME,
SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.TIME.ToString()) },
+ { (int)SparkConnection.ColumnTypeId.TIME_WITH_TIMEZONE,
SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.TIME_WITH_TIMEZONE.ToString())
},
+ { (int)SparkConnection.ColumnTypeId.TIMESTAMP,
SqlTimestampTypeParser.Default },
+ { (int)SparkConnection.ColumnTypeId.TIMESTAMP_WITH_TIMEZONE,
SqlTimestampTypeParser.Default },
+ { (int)SparkConnection.ColumnTypeId.TINYINT,
SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.TINYINT.ToString()) },
+ { (int)SparkConnection.ColumnTypeId.VARCHAR,
SqlVarcharTypeParser.Default },
+ { (int)SparkConnection.ColumnTypeId.SQLXML,
SqlSimpleTypeParser.Default(SparkConnection.ColumnTypeId.SQLXML.ToString()) },
+ };
+
+ // Note: the INTERVAL sql type does not have an associated column type
id.
+ private static readonly HashSet<ISqlTypeNameParser> s_parsers =
s_parserMap.Values
+ .Concat([SqlIntervalTypeParser.Default,
SqlSimpleTypeParser.Default("VOID")])
+ .ToHashSet();
+
+ /// <summary>
+ /// Gets the base SQL type name without decoration or sub clauses
+ /// </summary>
+ public abstract string BaseTypeName { get; }
+
+ /// <summary>
+ /// Parses the input type name string and produces a result.
+ /// When a matching parser is found that successfully parses the type
name string, the result of that parse is returned.
+ /// If no parser is able to successfully match the input type name,
+ /// then a <see cref="NotSupportedException"/> is thrown.
+ /// </summary>
+ /// <param name="input">The type name string to parse</param>
+ /// <param name="columnTypeIdHint">If provided, the column type id is
used as a hint to find the most likely matching parser.</param>
+ /// <returns>
+ /// A parser result, from a successful match and parse.
+ /// </returns>
+ public static T Parse(string input, int? columnTypeIdHint = null) =>
+ SqlTypeNameParser<T>.TryParse(input, out SqlTypeNameParserResult?
result, columnTypeIdHint) && result != null
+ ? CastResultOrThrow(input, result)
+ : throw new NotSupportedException($"Unsupported SQL type name:
'{input}'");
+
/// <summary>
/// Gets the <see cref="Regex"/> expression to parse the SQL type name
/// </summary>
- public abstract Regex Expression { get; }
+ protected abstract Regex Expression { get; }
/// <summary>
- /// Generates the successful result of a matching parse
+ /// Generates the successful result for a matching parse
/// </summary>
/// <param name="input">The original SQL type name</param>
/// <param name="match">The successful <see cref="Match"/>
result</param>
/// <returns></returns>
- public abstract T GenerateResult(string input, Match match);
+ protected virtual T GenerateResult(string input, Match match) =>
(T)new SqlTypeNameParserResult(input, BaseTypeName);
+
+ private static T CastResultOrThrow(string input,
SqlTypeNameParserResult result) =>
+ (result is T typedResult)
+ ? typedResult
+ : throw new InvalidCastException($"Cannot cast return type
'{result.GetType().Name}' to type '{(typeof(T)).Name}' for input SQL type name:
'{input}'.");
+
+ /// <summary>
+ /// Tries to parse the input string for a valid SQL type definition.
+ /// </summary>
+ /// <param name="input">The SQL type defintion string to parse.</param>
+ /// <param name="result">If successful, the result; otherwise
<c>null</c>.</param>
+ /// <returns>True if it can successfully parse the type definition
input string; otherwise false.</returns>
+ bool ISqlTypeNameParser.TryParse(string input, out
SqlTypeNameParserResult? result)
+ {
+ bool success = TryParse(input, out T? typedResult);
+ result = success ? typedResult : (SqlTypeNameParserResult?)default;
+ return success;
+ }
/// <summary>
/// Tries to parse the input string for a valid SQL type definition.
@@ -46,7 +150,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
/// <param name="input">The SQL type defintion string to parse.</param>
/// <param name="result">If successful, the result; otherwise
<c>null</c>.</param>
/// <returns>True if it can successfully parse the type definition
input string; otherwise false.</returns>
- public bool TryParse(string input, out T? result)
+ internal bool TryParse(string input, out T? result)
{
Match match = Expression.Match(input);
if (!match.Success)
@@ -60,23 +164,57 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
}
/// <summary>
- /// Parses the input string for a valid SQL type definition and returns the result or returns the <c>defaultValue</c>, if invalid.
+ /// Tries to parse the input SQL type name. If a matching parser is found and can parse the type name, its result is set in <c>parserResult</c> and <c>true</c> is returned.
+ /// If no parser matches, <c>parserResult</c> is set to <c>null</c> and <c>false</c> is returned. Successful results are cached by trimmed type name.
/// </summary>
- /// <param name="input">The SQL type defintion string to parse.</param>
- /// <param name="defaultValue">If input string is an invalid type definition, this result is returned instead.</param>
- /// <returns>If input string is a valid SQL type definition, it returns the result; otherwise <c>defaultValue</c>.</returns>
- public T ParseOrDefault(string input, T defaultValue)
+ /// <param name="input">The SQL type name to parse</param>
+ /// <param name="parserResult">The result of a successful parse, <c>null</c> otherwise</param>
+ /// <param name="columnTypeIdHint">The column type id as a hint to find the most appropriate parser</param>
+ /// <returns><c>true</c> if a matching parser is able to parse the SQL type name, <c>false</c> otherwise</returns>
+ internal static bool TryParse(string input, out SqlTypeNameParserResult? parserResult, int? columnTypeIdHint = null)
{
- return TryParse(input, out T? result) ? result! : defaultValue;
+ // Note: there may be multiple calls that successfully add/set the value in the cache
+ // - but the parser will produce the same result in each case.
+ string trimmedInput = input.Trim();
+ if (s_cache.TryGetValue(trimmedInput, out SqlTypeNameParserResult cachedResult))
+ {
+ parserResult = cachedResult;
+ return true;
+ }
+
+ ISqlTypeNameParser? sqlTypeNameParser = null;
+ if (columnTypeIdHint != null && s_parserMap.TryGetValue(columnTypeIdHint.Value, out ISqlTypeNameParser hintedParser))
+ {
+ sqlTypeNameParser = hintedParser;
+ if (sqlTypeNameParser.TryParse(input, out SqlTypeNameParserResult? result) && result != null)
+ {
+ parserResult = result;
+ s_cache[trimmedInput] = result;
+ return true;
+ }
+ }
+ foreach (ISqlTypeNameParser parser in s_parsers)
+ {
+ if (parser == sqlTypeNameParser) continue;
+ if (parser.TryParse(input, out SqlTypeNameParserResult? result) && result != null)
+ {
+ parserResult = result;
+ s_cache[trimmedInput] = result;
+ return true;
+ }
+ }
+
+ parserResult = null;
+ return false;
}
}
/// <summary>
- /// An result for parsing a SQL data type.
+ /// A result for parsing a SQL data type.
/// </summary>
/// <param name="typeName">The original SQL type name to parse</param>
/// <param name="baseTypeName">The 'base' type name to use which is
typically more simple without sub-clauses</param>
- internal class ParserResult(string typeName, string baseTypeName)
+ internal class SqlTypeNameParserResult(string typeName, string
baseTypeName)
{
/// <summary>
/// The original SQL type name
@@ -87,6 +225,19 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
/// The 'base' type name to use which is typically more simple without
sub-clauses
/// </summary>
public string BaseTypeName { get; } = baseTypeName;
+
+ public override bool Equals(object? obj)
+ {
+ if (ReferenceEquals(this, obj)) return true;
+ if (obj is not SqlTypeNameParserResult other) return false;
+ return TypeName.Equals(other.TypeName)
+ && BaseTypeName.Equals(other.BaseTypeName);
+ }
+
+ public override int GetHashCode()
+ {
+ return unchecked((TypeName.GetHashCode() * 397) ^ BaseTypeName.GetHashCode()); // plain XOR yields 0 whenever TypeName == BaseTypeName
+ }
}
/// <summary>
@@ -95,12 +246,23 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
/// <param name="typeName">The original SQL type name to parse</param>
/// <param name="baseTypeName">The 'base' type name without the length
clause</param>
/// <param name="columnSize">The length of the column for this type
name</param>
- internal class SqlCharVarcharParserResult(string typeName, string
baseTypeName, int columnSize) : ParserResult(typeName, baseTypeName)
+ internal class SqlCharVarcharParserResult(string typeName, string
baseTypeName, int columnSize = SqlVarcharTypeParser.VarcharColumnSizeDefault) :
SqlTypeNameParserResult(typeName, baseTypeName)
{
/// <summary>
/// The length of the column for this type name
/// </summary>
public int ColumnSize { get; } = columnSize;
+
+ public override bool Equals(object? obj) => obj is
SqlCharVarcharParserResult result
+ && base.Equals(obj)
+ && TypeName == result.TypeName
+ && BaseTypeName == result.BaseTypeName
+ && ColumnSize == result.ColumnSize;
+
+ // TypeName and BaseTypeName are already folded into base.GetHashCode(); XOR-ing them in again can cancel them out.
+ public override int GetHashCode() => unchecked(
+ (base.GetHashCode() * 397)
+ ^ ColumnSize.GetHashCode());
}
/// <summary>
@@ -110,7 +272,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
/// <param name="baseTypeName">The 'base' type name without the precision
or scale clause</param>
/// <param name="precision">The precision of the decimal type</param>
/// <param name="scale">The scale (decimal digits) of the decimal
type</param>
- internal class SqlDecimalParserResult(string typeName, string
baseTypeName, int precision, int scale) : ParserResult(typeName, baseTypeName)
+ internal class SqlDecimalParserResult(string typeName, string
baseTypeName, int precision, int scale) : SqlTypeNameParserResult(typeName,
baseTypeName)
{
/// <summary>
/// Constructs a new default result given the original type name.
@@ -132,6 +294,37 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
/// The <see cref='Types.Decimal128Type'/> representing the parsed
type name
/// </summary>
public Decimal128Type Decimal128Type { get; } = new
Decimal128Type(precision, scale);
+
+ public override bool Equals(object? obj) => obj is
SqlDecimalParserResult result
+ && base.Equals(obj)
+ && TypeName == result.TypeName
+ && BaseTypeName == result.BaseTypeName
+ && Precision == result.Precision
+ && Scale == result.Scale
+ &&
EqualityComparer<Decimal128Type>.Default.Equals(Decimal128Type,
result.Decimal128Type);
+
+ // TypeName and BaseTypeName are already folded into base.GetHashCode(); XOR-ing them in again can cancel them out.
+ public override int GetHashCode() => unchecked(
+ (base.GetHashCode() * 397)
+ ^ (Precision.GetHashCode() * 31)
+ ^ (Scale.GetHashCode() * 7)
+ ^ Decimal128Type.GetHashCode());
+ }
+
+ internal class SqlIntervalParserResult(string typeName, string
baseTypeName, string qualifiers) : SqlTypeNameParserResult(typeName,
baseTypeName)
+ {
+ public string Qualifiers { get; } = qualifiers;
+
+ public override bool Equals(object? obj) => obj is
SqlIntervalParserResult result
+ && base.Equals(obj)
+ && TypeName == result.TypeName
+ && BaseTypeName == result.BaseTypeName
+ && Qualifiers == result.Qualifiers;
+
+ // TypeName and BaseTypeName are already folded into base.GetHashCode(); XOR-ing them in again can cancel them out.
+ public override int GetHashCode() => unchecked(
+ (base.GetHashCode() * 397)
+ ^ Qualifiers.GetHashCode());
}
/// <summary>
@@ -139,15 +332,17 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
/// </summary>
internal class SqlCharTypeParser :
SqlTypeNameParser<SqlCharVarcharParserResult>
{
- private const string BaseTypeName = "CHAR";
+ public static SqlCharTypeParser Default { get; } = new();
+
+ public override string BaseTypeName => "CHAR";
private static readonly Regex s_expression = new(
@"^\s*(?<typeName>((CHAR)|(NCHAR)))(\s*\(\s*(?<precision>\d{1,10})\s*\))\s*$",
RegexOptions.IgnoreCase | RegexOptions.Compiled |
RegexOptions.CultureInvariant);
- public override Regex Expression => s_expression;
+ protected override Regex Expression => s_expression;
- public override SqlCharVarcharParserResult GenerateResult(string
input, Match match)
+ protected override SqlCharVarcharParserResult GenerateResult(string
input, Match match)
{
GroupCollection groups = match.Groups;
Group precisionGroup = groups["precision"];
@@ -165,17 +360,19 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
internal class SqlVarcharTypeParser :
SqlTypeNameParser<SqlCharVarcharParserResult>
{
internal const int VarcharColumnSizeDefault = int.MaxValue;
-
- private const string VarcharBaseTypeName = "VARCHAR";
private const string StringBaseTypeName = "STRING";
+ public static SqlVarcharTypeParser Default => new();
+
+ public override string BaseTypeName => "VARCHAR";
+
private static readonly Regex s_expression = new(
@"^\s*(?<typeName>((STRING)|(VARCHAR)|(LONGVARCHAR)|(LONGNVARCHAR)|(NVARCHAR)))(\s*\(\s*(?<precision>\d{1,10})\s*\))?\s*$",
RegexOptions.IgnoreCase | RegexOptions.Compiled |
RegexOptions.CultureInvariant);
- public override Regex Expression => s_expression;
+ protected override Regex Expression => s_expression;
- public override SqlCharVarcharParserResult GenerateResult(string
input, Match match)
+ protected override SqlCharVarcharParserResult GenerateResult(string
input, Match match)
{
GroupCollection groups = match.Groups;
Group precisionGroup = groups["precision"];
@@ -183,7 +380,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
string baseTypeName =
typeNameGroup.Value.Equals(StringBaseTypeName,
StringComparison.InvariantCultureIgnoreCase)
? StringBaseTypeName
- : VarcharBaseTypeName;
+ : BaseTypeName;
int precision = precisionGroup.Success &&
int.TryParse(precisionGroup.Value, out int candidatePrecision)
? candidatePrecision
: VarcharColumnSizeDefault;
@@ -199,7 +396,9 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
internal const int DecimalPrecisionDefault = 10;
internal const int DecimalScaleDefault = 0;
- private const string BaseTypeName = "DECIMAL";
+ public static SqlDecimalTypeParser Default => new();
+
+ public override string BaseTypeName => "DECIMAL";
// Pattern is based on this definition
//
https://docs.databricks.com/en/sql/language-manual/data-types/decimal-type.html#syntax
@@ -210,9 +409,9 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
@"^\s*(?<typeName>((DECIMAL)|(DEC)|(NUMERIC)))(\s*\(\s*((?<precision>\d{1,2})(\s*\,\s*(?<scale>\d{1,2}))?)\s*\))?\s*$",
RegexOptions.IgnoreCase | RegexOptions.Compiled |
RegexOptions.CultureInvariant);
- public override Regex Expression => s_expression;
+ protected override Regex Expression => s_expression;
- public override SqlDecimalParserResult GenerateResult(string input,
Match match)
+ protected override SqlDecimalParserResult GenerateResult(string input,
Match match)
{
GroupCollection groups = match.Groups;
Group precisionGroup = groups["precision"];
@@ -224,4 +423,138 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
return new SqlDecimalParserResult(input, BaseTypeName, precision,
scale);
}
}
+
+ /// <summary>
+ /// Provides a parser for SQL INTEGER type definitions.
+ /// </summary>
+ internal class SqlIntegerTypeParser :
SqlTypeNameParser<SqlTypeNameParserResult>
+ {
+ public static SqlIntegerTypeParser Default => new();
+
+ public override string BaseTypeName => "INTEGER";
+
+ // Pattern is based on this definition
+ //
https://docs.databricks.com/en/sql/language-manual/data-types/int-type.html#syntax
+ // { INT | INTEGER }
+ private static readonly Regex s_expression = new(
+ @"^\s*(?<typeName>((INTEGER)|(INT)))\s*$",
+ RegexOptions.IgnoreCase | RegexOptions.Compiled |
RegexOptions.CultureInvariant);
+
+ protected override Regex Expression => s_expression;
+ }
+
+ /// <summary>
+ /// Provides a parser for SQL TIMESTAMP type definitions.
+ /// </summary>
+ internal class SqlTimestampTypeParser : SqlTypeNameParser<SqlTypeNameParserResult>
+ {
+ public static SqlTimestampTypeParser Default => new();
+
+ public override string BaseTypeName => "TIMESTAMP";
+
+ // Pattern is based on these definitions
+ // https://docs.databricks.com/en/sql/language-manual/data-types/timestamp-type.html#syntax
+ // https://docs.databricks.com/en/sql/language-manual/data-types/timestamp-ntz-type.html#syntax
+ // { TIMESTAMP | TIMESTAMP_LTZ | TIMESTAMP_NTZ }
+ // All variants produce the base type name TIMESTAMP (via the default GenerateResult).
+ private static readonly Regex s_expression = new(
+ @"^\s*(?<typeName>((TIMESTAMP)|(TIMESTAMP_LTZ)|(TIMESTAMP_NTZ)))\s*$",
+ RegexOptions.IgnoreCase | RegexOptions.Compiled | RegexOptions.CultureInvariant);
+
+ protected override Regex Expression => s_expression;
+ }
+
+ /// <summary>
+ /// Provides a parser for SQL STRUCT type definitions.
+ /// </summary>
+ internal class SqlStructTypeParser : SqlTypeNameParser<SqlTypeNameParserResult>
+ {
+ public static SqlStructTypeParser Default => new();
+
+ public override string BaseTypeName => "STRUCT";
+
+ // Pattern is based on this definition
+ // https://docs.databricks.com/en/sql/language-manual/data-types/struct-type.html#syntax
+ // STRUCT < [fieldName [:] fieldType [NOT NULL] [COMMENT str] [, …] ] >
+ // fieldName: An identifier naming the field. The names need not be unique.
+ // fieldType: Any data type.
+ // NOT NULL: When specified the struct guarantees that the value of this field is never NULL.
+ // COMMENT str: An optional string literal describing the field.
+ private static readonly Regex s_expression = new(
+ @"^\s*(?<typeName>STRUCT)(?<structClause>\s*\<(.+)\>)\s*$", // e.g. STRUCT<F1: INT, F2: STRING>
+ RegexOptions.IgnoreCase | RegexOptions.Compiled | RegexOptions.CultureInvariant);
+
+ protected override Regex Expression => s_expression;
+ }
+
+ /// <summary>
+ /// Provides a parser for SQL ARRAY type definitions.
+ /// </summary>
+ internal class SqlArrayTypeParser :
SqlTypeNameParser<SqlTypeNameParserResult>
+ {
+ public static SqlArrayTypeParser Default => new();
+
+ public override string BaseTypeName => "ARRAY";
+
+ // Pattern is based on this definition
+ //
https://docs.databricks.com/en/sql/language-manual/data-types/array-type.html#syntax
+ // ARRAY < elementType >
+ // elementType: Any data type defining the type of the elements of the
array.
+ private static readonly Regex s_expression = new(
+ @"^\s*(?<typeName>ARRAY)(?<arrayClause>\s*\<(.+)\>)\s*$",
+ RegexOptions.IgnoreCase | RegexOptions.Compiled |
RegexOptions.CultureInvariant);
+
+ protected override Regex Expression => s_expression;
+ }
+
+ /// <summary>
+ /// Provides a parser for SQL MAP type definitions.
+ /// </summary>
+ internal class SqlMapTypeParser :
SqlTypeNameParser<SqlTypeNameParserResult>
+ {
+ public static SqlMapTypeParser Default => new();
+
+ public override string BaseTypeName => "MAP";
+
+ // Pattern is based on this definition
+ //
https://docs.databricks.com/en/sql/language-manual/data-types/map-type.html#syntax
+ // MAP <keyType, valueType>
+ // keyType: Any data type other than MAP specifying the keys.
+ // valueType: Any data type specifying the values.
+ private static readonly Regex s_expression = new(
+ @"^\s*(?<typeName>MAP)(?<mapClause>\s*\<(.+)\>)\s*$",
+ RegexOptions.IgnoreCase | RegexOptions.Compiled |
RegexOptions.CultureInvariant);
+
+ protected override Regex Expression => s_expression;
+ }
+
+ internal class SqlIntervalTypeParser :
SqlTypeNameParser<SqlTypeNameParserResult>
+ {
+ public static SqlIntervalTypeParser Default => new();
+
+ public override string BaseTypeName { get; } = "INTERVAL";
+
+ // See:
https://docs.databricks.com/en/sql/language-manual/data-types/interval-type.html#syntax
+ private static readonly Regex s_expression = new(
+ @"^\s*(?<typeName>INTERVAL)\s+.*$",
+ RegexOptions.IgnoreCase | RegexOptions.Compiled |
RegexOptions.CultureInvariant);
+
+ protected override Regex Expression => s_expression;
+ }
+
+ internal class SqlSimpleTypeParser(string baseTypeName) : SqlTypeNameParser<SqlTypeNameParserResult>
+ {
+ private static readonly ConcurrentDictionary<string, SqlSimpleTypeParser> s_parserMap = new ConcurrentDictionary<string, SqlSimpleTypeParser>();
+
+ public static SqlSimpleTypeParser Default(string baseTypeName) =>
+ s_parserMap.GetOrAdd(baseTypeName, (typeName) => new SqlSimpleTypeParser(typeName));
+
+ public override string BaseTypeName { get; } = baseTypeName;
+
+ private readonly Regex _expression = new( // built once per instance; instances are shared via Default()
+ @"^\s*" + Regex.Escape(baseTypeName) + @"\s*$",
+ RegexOptions.IgnoreCase | RegexOptions.Compiled | RegexOptions.CultureInvariant);
+
+ protected override Regex Expression => _expression;
+ }
}
diff --git a/csharp/test/Drivers/Apache/Spark/DriverTests.cs
b/csharp/test/Drivers/Apache/Spark/DriverTests.cs
index 8fb341f87..42f0e08b8 100644
--- a/csharp/test/Drivers/Apache/Spark/DriverTests.cs
+++ b/csharp/test/Drivers/Apache/Spark/DriverTests.cs
@@ -18,6 +18,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
+using System.Text.RegularExpressions;
using System.Threading.Tasks;
using Apache.Arrow.Adbc.Tests.Metadata;
using Apache.Arrow.Adbc.Tests.Xunit;
@@ -350,6 +351,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
Assert.Equal(i + 1, column.OrdinalPosition);
Assert.False(string.IsNullOrEmpty(column.Name));
Assert.False(string.IsNullOrEmpty(column.XdbcTypeName));
+ Assert.False(Regex.IsMatch(column.XdbcTypeName,
@"[_,\d\<\>\(\)]", RegexOptions.IgnoreCase | RegexOptions.CultureInvariant),
+ "Unexpected character found in field XdbcTypeName");
var supportedTypes =
Enum.GetValues(typeof(SupportedSparkDataType)).Cast<SupportedSparkDataType>();
Assert.Contains((SupportedSparkDataType)column.XdbcSqlDataType!,
supportedTypes);
diff --git a/csharp/test/Drivers/Apache/Spark/SqlTypeNameParserTests.cs
b/csharp/test/Drivers/Apache/Spark/SqlTypeNameParserTests.cs
new file mode 100644
index 000000000..388a8f82a
--- /dev/null
+++ b/csharp/test/Drivers/Apache/Spark/SqlTypeNameParserTests.cs
@@ -0,0 +1,316 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.InteropServices;
+using System.Text;
+using System.Threading.Tasks;
+using Apache.Arrow.Adbc.Drivers.Apache.Spark;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
+{
+    /// <summary>
+    /// Unit tests for the Spark SQL type name parsers: the generic
+    /// <c>SqlTypeNameParser&lt;T&gt;</c> entry points and the individual CHAR, VARCHAR,
+    /// DECIMAL, INTEGER, TIMESTAMP, ARRAY, MAP and STRUCT parsers.
+    /// </summary>
+    public class SqlTypeNameParserTests(ITestOutputHelper outputHelper)
+    {
+        private readonly ITestOutputHelper _outputHelper = outputHelper;
+
+        // The generic Parse entry point must accept each type name, echo it back as TypeName,
+        // and report the expected base type name.
+        [Theory]
+        [InlineData("ARRAY<INT>", "ARRAY")]
+        [InlineData("ARRAY < INT >", "ARRAY")]
+        [InlineData(" ARRAY < ARRAY < INT > > ", "ARRAY")]
+        [InlineData("ARRAY<VARCHAR(255)>", "ARRAY")]
+        [InlineData("DATE", "DATE")]
+        [InlineData("dec(15)", "DECIMAL")]
+        [InlineData("numeric", "DECIMAL")]
+        [InlineData("STRUCT<F1:INT>", "STRUCT")]
+        [InlineData("STRUCT< F1 INT >", "STRUCT")]
+        [InlineData("STRUCT < F1: ARRAY < INT > > ", "STRUCT")]
+        [InlineData("STRUCT<F1: VARCHAR(255), F2 ARRAY<STRING>>", "STRUCT")]
+        [InlineData("MAP<INT,STRING>", "MAP")]
+        [InlineData("MAP< INT , VARCHAR(255) >", "MAP")]
+        [InlineData("MAP < ARRAY < INT >, INT > ", "MAP")]
+        [InlineData("TIMESTAMP", "TIMESTAMP")]
+        [InlineData("TIMESTAMP_LTZ", "TIMESTAMP")]
+        [InlineData("TIMESTAMP_NTZ", "TIMESTAMP")]
+        internal void CanParseAnyType(string testTypeName, string expectedBaseTypeName)
+        {
+            SqlTypeNameParserResult result = SqlTypeNameParser<SqlTypeNameParserResult>.Parse(testTypeName);
+            Assert.NotNull(result);
+            Assert.Equal(testTypeName, result.TypeName);
+            Assert.Equal(expectedBaseTypeName, result.BaseTypeName);
+        }
+
+        // Simple type names must be accepted by the generic TryParse and keep their own base name.
+        [Theory]
+        [InlineData("BIGINT", "BIGINT")]
+        [InlineData("BINARY", "BINARY")]
+        [InlineData("BOOLEAN", "BOOLEAN")]
+        [InlineData("DATE", "DATE")]
+        [InlineData("DOUBLE", "DOUBLE")]
+        [InlineData("FLOAT", "FLOAT")]
+        [InlineData("SMALLINT", "SMALLINT")]
+        [InlineData("TINYINT", "TINYINT")]
+        internal void CanParseSimpleTypeName(string testTypeName, string expectedBaseTypeName)
+        {
+            Assert.True(SqlTypeNameParser<SqlTypeNameParserResult>.TryParse(testTypeName, out SqlTypeNameParserResult? result));
+            Assert.NotNull(result);
+            Assert.Equal(expectedBaseTypeName, result.BaseTypeName);
+        }
+
+        // Every INTERVAL qualifier form, single field or "X TO Y" range, maps to base type INTERVAL.
+        [Theory]
+        [InlineData("INTERVAL YEAR", "INTERVAL")]
+        [InlineData("INTERVAL MONTH", "INTERVAL")]
+        [InlineData("INTERVAL DAY", "INTERVAL")]
+        [InlineData("INTERVAL HOUR", "INTERVAL")]
+        [InlineData("INTERVAL MINUTE", "INTERVAL")]
+        [InlineData("INTERVAL SECOND", "INTERVAL")]
+        [InlineData("INTERVAL YEAR TO MONTH", "INTERVAL")]
+        [InlineData("INTERVAL DAY TO HOUR", "INTERVAL")]
+        [InlineData("INTERVAL DAY TO MINUTE", "INTERVAL")]
+        [InlineData("INTERVAL DAY TO SECOND", "INTERVAL")]
+        [InlineData("INTERVAL HOUR TO MINUTE", "INTERVAL")]
+        [InlineData("INTERVAL HOUR TO SECOND", "INTERVAL")]
+        [InlineData("INTERVAL MINUTE TO SECOND", "INTERVAL")]
+        internal void CanParseInterval(string testTypeName, string expectedBaseTypeName)
+        {
+            Assert.True(SqlTypeNameParser<SqlTypeNameParserResult>.TryParse(testTypeName, out SqlTypeNameParserResult? result));
+            Assert.NotNull(result);
+            Assert.Equal(expectedBaseTypeName, result.BaseTypeName);
+        }
+
+        [Theory]
+        [MemberData(nameof(GenerateCharTestData), "CHAR")]
+        [MemberData(nameof(GenerateCharTestData), "NCHAR")]
+        [MemberData(nameof(GenerateCharTestData), "CHaR")]
+        internal void CanParseChar(string testTypeName, SqlCharVarcharParserResult expected)
+        {
+            _outputHelper.WriteLine(testTypeName);
+            Assert.True(SqlCharTypeParser.Default.TryParse(testTypeName, out SqlCharVarcharParserResult? result));
+            Assert.NotNull(result);
+            Assert.Equal(expected, result);
+        }
+
+        // The STRING case comes from its own data source so that it runs exactly once rather than
+        // once per VARCHAR-family MemberData attribute.
+        [Theory]
+        [MemberData(nameof(GenerateVarcharTestData), "VARCHAR")]
+        [MemberData(nameof(GenerateVarcharTestData), "LONGVARCHAR")]
+        [MemberData(nameof(GenerateVarcharTestData), "NVARCHAR")]
+        [MemberData(nameof(GenerateVarcharTestData), "LONGNVARCHAR")]
+        [MemberData(nameof(GenerateVarcharTestData), "VaRCHaR")]
+        [MemberData(nameof(GenerateStringTestData))]
+        internal void CanParseVarchar(string testTypeName, SqlCharVarcharParserResult expected)
+        {
+            _outputHelper.WriteLine(testTypeName);
+            Assert.True(SqlVarcharTypeParser.Default.TryParse(testTypeName, out SqlCharVarcharParserResult? result));
+            Assert.NotNull(result);
+            Assert.Equal(expected, result);
+        }
+
+        [Theory]
+        [MemberData(nameof(GenerateDecimalTestData), "DECIMAL")]
+        [MemberData(nameof(GenerateDecimalTestData), "DEC")]
+        [MemberData(nameof(GenerateDecimalTestData), "NUMERIC")]
+        [MemberData(nameof(GenerateDecimalTestData), "DeCiMaL")]
+        internal void CanParseDecimal(string testTypeName, SqlDecimalParserResult expected)
+        {
+            _outputHelper.WriteLine(testTypeName);
+            Assert.True(SqlDecimalTypeParser.Default.TryParse(testTypeName, out SqlDecimalParserResult? result));
+            Assert.NotNull(result);
+            Assert.Equal(expected.TypeName, result.TypeName);
+            Assert.Equal(expected.BaseTypeName, result.BaseTypeName);
+            // Note: Decimal128Type does not override Equals/GetHashCode, so compare its members.
+            Assert.Equal(expected.Decimal128Type.Name, result.Decimal128Type.Name);
+            Assert.Equal(expected.Decimal128Type.Precision, result.Decimal128Type.Precision);
+            Assert.Equal(expected.Decimal128Type.Scale, result.Decimal128Type.Scale);
+        }
+
+        [Theory]
+        [InlineData("INT")]
+        [InlineData("INTEGER")]
+        [InlineData(" INT ")]
+        [InlineData(" INTEGER ")]
+        [InlineData(" iNTeGeR ")]
+        internal void CanParseInteger(string testTypeName)
+        {
+            string baseTypeName = SqlIntegerTypeParser.Default.BaseTypeName;
+            SqlTypeNameParserResult expected = new(testTypeName, baseTypeName);
+            _outputHelper.WriteLine(testTypeName);
+            Assert.True(SqlIntegerTypeParser.Default.TryParse(testTypeName, out SqlTypeNameParserResult? result));
+            Assert.NotNull(result);
+            Assert.Equal(expected, result);
+        }
+
+        [Theory]
+        [InlineData("TIMESTAMP")]
+        [InlineData("TIMESTAMP_LTZ")]
+        [InlineData("TIMESTAMP_NTZ")]
+        [InlineData("TiMeSTaMP")]
+        internal void CanParseTimestamp(string testTypeName)
+        {
+            string baseTypeName = SqlTimestampTypeParser.Default.BaseTypeName;
+            SqlTypeNameParserResult expected = new(testTypeName, baseTypeName);
+            _outputHelper.WriteLine(testTypeName);
+            Assert.True(SqlTimestampTypeParser.Default.TryParse(testTypeName, out SqlTypeNameParserResult? result));
+            Assert.NotNull(result);
+            Assert.Equal(expected, result);
+        }
+
+        [Theory]
+        [InlineData("ARRAY<INT>")]
+        [InlineData("ARRAY < INT >")]
+        [InlineData(" ARRAY < ARRAY < INT > > ")]
+        [InlineData("ARRAY<VARCHAR(255)>")]
+        [InlineData("aRRaY<iNT>")]
+        internal void CanParseArray(string testTypeName)
+        {
+            string baseTypeName = SqlArrayTypeParser.Default.BaseTypeName;
+            SqlTypeNameParserResult expected = new(testTypeName, baseTypeName);
+            _outputHelper.WriteLine(testTypeName);
+            Assert.True(SqlArrayTypeParser.Default.TryParse(testTypeName, out SqlTypeNameParserResult? result));
+            Assert.NotNull(result);
+            Assert.Equal(expected, result);
+        }
+
+        [Theory]
+        [InlineData("MAP<INT,STRING>")]
+        [InlineData("MAP< INT , VARCHAR(255) >")]
+        [InlineData("MAP < ARRAY < INT >, INT > ")]
+        [InlineData("MaP<iNT,STRiNG>")]
+        internal void CanParseMap(string testTypeName)
+        {
+            string baseTypeName = SqlMapTypeParser.Default.BaseTypeName;
+            SqlTypeNameParserResult expected = new(testTypeName, baseTypeName);
+            _outputHelper.WriteLine(testTypeName);
+            Assert.True(SqlMapTypeParser.Default.TryParse(testTypeName, out SqlTypeNameParserResult? result));
+            Assert.NotNull(result);
+            Assert.Equal(expected, result);
+        }
+
+        [Theory]
+        [InlineData("STRUCT<F1:INT>")]
+        [InlineData("STRUCT< F1 INT >")]
+        [InlineData("STRUCT < F1: ARRAY < INT > > ")]
+        [InlineData("STRUCT<F1: VARCHAR(255), F2 ARRAY<STRING>>")]
+        [InlineData("STRuCT<F1:iNT>")]
+        internal void CanParseStruct(string testTypeName)
+        {
+            string baseTypeName = SqlStructTypeParser.Default.BaseTypeName;
+            SqlTypeNameParserResult expected = new(testTypeName, baseTypeName);
+            _outputHelper.WriteLine(testTypeName);
+            Assert.True(SqlStructTypeParser.Default.TryParse(testTypeName, out SqlTypeNameParserResult? result));
+            Assert.NotNull(result);
+            Assert.Equal(expected, result);
+        }
+
+        // Incomplete complex types (missing or unbalanced angle brackets), a bare INTERVAL with
+        // no qualifier, and an unknown TIMESTAMP suffix must all be rejected by TryParse.
+        [Theory]
+        [InlineData("ARRAY")]
+        [InlineData("MAP")]
+        [InlineData("STRUCT")]
+        [InlineData("ARRAY<")]
+        [InlineData("MAP<")]
+        [InlineData("STRUCT<")]
+        [InlineData("ARRAY>")]
+        [InlineData("MAP>")]
+        [InlineData("STRUCT>")]
+        [InlineData("INTERVAL")]
+        [InlineData("TIMESTAMP_ZZZ")]
+        internal void CannotParseUnexpectedTypeName(string testTypeName)
+        {
+            Assert.False(SqlTypeNameParser<SqlTypeNameParserResult>.TryParse(testTypeName, out _), $"Expecting type {testTypeName} to fail to parse.");
+        }
+
+        // Requesting a decimal result for an INTEGER column must throw InvalidCastException.
+        [Fact]
+        internal void CanDetectInvalidReturnType()
+        {
+            Func<object?> testCode = () => SqlTypeNameParser<SqlDecimalParserResult>.Parse("INTEGER", (int)SparkConnection.ColumnTypeId.INTEGER);
+            _outputHelper.WriteLine(Assert.Throws<InvalidCastException>(testCode).Message);
+        }
+
+        /// <summary>
+        /// Builds CHAR test cases for <paramref name="typeName"/>. The cases are the cross product of
+        /// the lengths below with the leading and trailing whitespace variants.
+        /// </summary>
+        public static IEnumerable<object[]> GenerateCharTestData(string typeName)
+        {
+            // CHAR is only exercised with an explicit length, so the null-length branches below
+            // are never taken here. They keep this generator parallel to GenerateVarcharTestData.
+            int?[] lengths = [1, 10, int.MaxValue,];
+            string[] spaces = ["", " ", "\t"];
+            string baseTypeName = SqlCharTypeParser.Default.BaseTypeName;
+            foreach (int? length in lengths)
+            {
+                foreach (string leadingSpace in spaces)
+                {
+                    foreach (string trailingSpace in spaces)
+                    {
+                        string clause = length == null ? "" : $"{leadingSpace}({leadingSpace}{length}{trailingSpace})";
+                        string testTypeName = $"{leadingSpace}{typeName}{clause}{trailingSpace}";
+                        SqlCharVarcharParserResult expectedResult = new(testTypeName, baseTypeName, length ?? int.MaxValue);
+                        yield return new object[] { testTypeName, expectedResult };
+                    }
+                }
+            }
+        }
+
+        /// <summary>
+        /// Builds VARCHAR-family test cases for <paramref name="typeName"/>. The cases are the cross
+        /// product of the lengths below (null means no length clause) with the leading and trailing
+        /// whitespace variants. When no length clause is present, the expected length is
+        /// <see cref="int.MaxValue"/>.
+        /// </summary>
+        public static IEnumerable<object[]> GenerateVarcharTestData(string typeName)
+        {
+            int?[] lengths = [null, 1, 10, int.MaxValue,];
+            string[] spaces = ["", " ", "\t"];
+            string baseTypeName = SqlVarcharTypeParser.Default.BaseTypeName;
+            foreach (int? length in lengths)
+            {
+                foreach (string leadingSpace in spaces)
+                {
+                    foreach (string trailingSpace in spaces)
+                    {
+                        string clause = length == null ? "" : $"{leadingSpace}({leadingSpace}{length}{trailingSpace})";
+                        string testTypeName = $"{leadingSpace}{typeName}{clause}{trailingSpace}";
+                        SqlCharVarcharParserResult expectedResult = new(testTypeName, baseTypeName, length ?? int.MaxValue);
+                        yield return new object[] { testTypeName, expectedResult };
+                    }
+                }
+            }
+        }
+
+        /// <summary>
+        /// The single STRING case for the VARCHAR parser. It is kept separate from
+        /// <see cref="GenerateVarcharTestData"/> so it is not repeated for every type-name variant.
+        /// </summary>
+        public static IEnumerable<object[]> GenerateStringTestData()
+        {
+            yield return new object[] { "STRING", new SqlCharVarcharParserResult("STRING", "STRING") };
+        }
+
+        /// <summary>
+        /// Builds DECIMAL-family test cases for <paramref name="typeName"/>. The cases combine each
+        /// precision/scale pair below with the leading and trailing whitespace variants. When
+        /// precision or scale is omitted, the expected values are 10 and 0 respectively.
+        /// </summary>
+        public static IEnumerable<object[]> GenerateDecimalTestData(string typeName)
+        {
+            string baseTypeName = SqlDecimalTypeParser.Default.BaseTypeName;
+            var precisionScales = new[]
+            {
+                new { Precision = (int?)null, Scale = (int?)null },
+                new { Precision = (int?)1, Scale = (int?)null },
+                new { Precision = (int?)1, Scale = (int?)1 },
+                new { Precision = (int?)38, Scale = (int?)null },
+                new { Precision = (int?)38, Scale = (int?)38 },
+                new { Precision = (int?)99, Scale = (int?)null },
+                new { Precision = (int?)99, Scale = (int?)99 },
+            };
+            string[] spaces = ["", " ", "\t"];
+            foreach (var precisionScale in precisionScales)
+            {
+                foreach (string leadingSpace in spaces)
+                {
+                    foreach (string trailingSpace in spaces)
+                    {
+                        string clause = precisionScale.Precision == null ? ""
+                            : precisionScale.Scale == null
+                                ? $"({leadingSpace}{precisionScale.Precision}{trailingSpace})"
+                                : $"({leadingSpace}{precisionScale.Precision}{trailingSpace},{leadingSpace}{precisionScale.Scale}{trailingSpace})";
+                        string testTypeName = $"{leadingSpace}{typeName}{clause}{trailingSpace}";
+                        SqlDecimalParserResult expectedResult = new(testTypeName, baseTypeName, precisionScale.Precision ?? 10, precisionScale.Scale ?? 0);
+                        yield return new object[] { testTypeName, expectedResult };
+                    }
+                }
+            }
+        }
+    }
+}