This is an automated email from the ASF dual-hosted git repository.
ptupitsyn pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/ignite-3.git
The following commit(s) were added to refs/heads/main by this push:
new 6e50d46e96 IGNITE-18257 .NET: LINQ: Implement type casts (#1377)
6e50d46e96 is described below
commit 6e50d46e9665459559991251b27a2a2f5fdec98d
Author: Pavel Tupitsyn <[email protected]>
AuthorDate: Fri Nov 25 08:23:32 2022 +0300
IGNITE-18257 .NET: LINQ: Implement type casts (#1377)
* Unlike Ignite 2.x, we can't rely on the SQL engine to perform type conversion
in many cases, so emit an explicit SQL `cast` for each conversion in the LINQ expression.
* Remove redundant whitespace in generated SQL.
* Improve `ToSqlTypeName`: add an overload for CLR types and map `SqlColumnType.Number` to `numeric`.
---
.../Linq/LinqSqlGenerationTests.cs | 10 +--
.../Linq/LinqTests.Aggregate.cs | 2 +-
.../Apache.Ignite.Tests/Linq/LinqTests.Cast.cs | 88 ++++++++++++++++++++++
.../Apache.Ignite.Tests/Linq/LinqTests.GroupBy.cs | 12 ++-
.../Apache.Ignite.Tests/Linq/LinqTests.Join.cs | 8 +-
.../Sql/SqlColumnTypeExtensionsTests.cs | 13 +++-
.../Internal/Linq/IgniteQueryExpressionVisitor.cs | 25 +++---
.../Internal/Linq/IgniteQueryModelVisitor.cs | 16 ++--
.../Internal/Sql/SqlColumnTypeExtensions.cs | 19 ++++-
9 files changed, 156 insertions(+), 37 deletions(-)
diff --git
a/modules/platforms/dotnet/Apache.Ignite.Tests/Linq/LinqSqlGenerationTests.cs
b/modules/platforms/dotnet/Apache.Ignite.Tests/Linq/LinqSqlGenerationTests.cs
index df1e2a511d..0a0ba656fd 100644
---
a/modules/platforms/dotnet/Apache.Ignite.Tests/Linq/LinqSqlGenerationTests.cs
+++
b/modules/platforms/dotnet/Apache.Ignite.Tests/Linq/LinqSqlGenerationTests.cs
@@ -60,23 +60,23 @@ public partial class LinqSqlGenerationTests
[Test]
public void TestSum() =>
- AssertSql("select sum (_T0.KEY) from PUBLIC.tbl1 as _T0", q => q.Sum(x
=> x.Key));
+ AssertSql("select sum(_T0.KEY) from PUBLIC.tbl1 as _T0", q => q.Sum(x
=> x.Key));
[Test]
public void TestAvg() =>
- AssertSql("select avg (_T0.KEY) from PUBLIC.tbl1 as _T0", q =>
q.Average(x => x.Key));
+ AssertSql("select avg(_T0.KEY) from PUBLIC.tbl1 as _T0", q =>
q.Average(x => x.Key));
[Test]
public void TestMin() =>
- AssertSql("select min (_T0.KEY) from PUBLIC.tbl1 as _T0", q => q.Min(x
=> x.Key));
+ AssertSql("select min(_T0.KEY) from PUBLIC.tbl1 as _T0", q => q.Min(x
=> x.Key));
[Test]
public void TestMax() =>
- AssertSql("select max (_T0.KEY) from PUBLIC.tbl1 as _T0", q => q.Max(x
=> x.Key));
+ AssertSql("select max(_T0.KEY) from PUBLIC.tbl1 as _T0", q => q.Max(x
=> x.Key));
[Test]
public void TestCount() =>
- AssertSql("select count (*) from PUBLIC.tbl1 as _T0", q => q.Count());
+ AssertSql("select count(*) from PUBLIC.tbl1 as _T0", q => q.Count());
[Test]
public void TestDistinct() =>
diff --git
a/modules/platforms/dotnet/Apache.Ignite.Tests/Linq/LinqTests.Aggregate.cs
b/modules/platforms/dotnet/Apache.Ignite.Tests/Linq/LinqTests.Aggregate.cs
index abf6bd8bfb..e186ceac80 100644
--- a/modules/platforms/dotnet/Apache.Ignite.Tests/Linq/LinqTests.Aggregate.cs
+++ b/modules/platforms/dotnet/Apache.Ignite.Tests/Linq/LinqTests.Aggregate.cs
@@ -110,7 +110,7 @@ public partial class LinqTests
Assert.AreEqual(2, res[2].Max);
StringAssert.Contains(
- "select _T0.KEY, count (*) , sum (_T0.KEY) , avg (_T0.KEY) , min
(_T0.KEY) , max (_T0.KEY) " +
+ "select _T0.KEY, count(*), sum(_T0.KEY), avg(_T0.KEY),
min(_T0.KEY), max(_T0.KEY) " +
"from PUBLIC.TBL_INT32 as _T0 " +
"group by (_T0.KEY) " +
"order by (_T0.KEY) asc",
diff --git
a/modules/platforms/dotnet/Apache.Ignite.Tests/Linq/LinqTests.Cast.cs
b/modules/platforms/dotnet/Apache.Ignite.Tests/Linq/LinqTests.Cast.cs
new file mode 100644
index 0000000000..f539304012
--- /dev/null
+++ b/modules/platforms/dotnet/Apache.Ignite.Tests/Linq/LinqTests.Cast.cs
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+namespace Apache.Ignite.Tests.Linq;
+
+using System.Linq;
+using NUnit.Framework;
+
+/// <summary>
+/// Linq type cast tests.
+/// </summary>
+public partial class LinqTests
+{
+ [Test]
+ public void TestProjectionWithCastIntoAnonymousType()
+ {
+ // TODO IGNITE-18258 decimal, BigInteger.
+ var query = PocoIntView.AsQueryable()
+ .Select(x => new
+ {
+ Byte = (sbyte)x.Val,
+ Short = (short)x.Val,
+ Long = (long)x.Val,
+ Float = (float)x.Val / 1000,
+ Double = (double)x.Val / 2000,
+ })
+ .OrderByDescending(x => x.Long);
+
+ var res = query.ToList();
+
+ Assert.AreEqual(-124, res[0].Byte);
+ Assert.AreEqual(900, res[0].Short);
+ Assert.AreEqual(900, res[0].Long);
+ Assert.AreEqual(900f / 1000, res[0].Float);
+ Assert.AreEqual(900d / 2000, res[0].Double);
+
+ StringAssert.Contains(
+ "select cast(_T0.VAL as tinyint), cast(_T0.VAL as smallint),
cast(_T0.VAL as bigint), " +
+ "(cast(_T0.VAL as real) / ?), (cast(_T0.VAL as double) / ?) " +
+ "from PUBLIC.TBL_INT32 as _T0 " +
+ "order by (cast(_T0.VAL as bigint)) desc",
+ query.ToString());
+ }
+
+ [Test]
+ public void TestJoinOnDifferentTypes()
+ {
+ var query = PocoFloatView.AsQueryable()
+ .Join(
+ PocoByteView.AsQueryable(),
+ x => x.Key,
+ y => y.Key,
+ (x, y) => new
+ {
+ x.Key,
+ Val1 = x.Val,
+ Val2 = y.Val
+ })
+ .OrderByDescending(x => x.Key);
+
+ var res = query.ToList();
+
+ Assert.AreEqual(9, res[0].Key);
+ Assert.AreEqual(9f, res[0].Val1);
+ Assert.AreEqual(3, res[0].Val2);
+
+ StringAssert.Contains(
+ "select _T0.KEY, _T0.VAL, _T1.VAL " +
+ "from PUBLIC.TBL_FLOAT as _T0 " +
+ "inner join PUBLIC.TBL_INT8 as _T1 on (cast(_T1.KEY as real) =
_T0.KEY) " +
+ "order by (_T0.KEY) desc",
+ query.ToString());
+ }
+}
diff --git
a/modules/platforms/dotnet/Apache.Ignite.Tests/Linq/LinqTests.GroupBy.cs
b/modules/platforms/dotnet/Apache.Ignite.Tests/Linq/LinqTests.GroupBy.cs
index a980dfad90..c858d20cb0 100644
--- a/modules/platforms/dotnet/Apache.Ignite.Tests/Linq/LinqTests.GroupBy.cs
+++ b/modules/platforms/dotnet/Apache.Ignite.Tests/Linq/LinqTests.GroupBy.cs
@@ -70,10 +70,9 @@ public partial class LinqTests
[Test]
public void TestGroupByWithAggregates()
{
- // TODO IGNITE-18196 Remove cast to long for Sum and Count
var query = PocoByteView.AsQueryable()
.GroupBy(x => x.Val)
- .Select(x => new { x.Key, Count = (long)x.Count(), Sum =
(long)x.Sum(e => e.Key), Avg = x.Average(e => e.Key) })
+ .Select(x => new { x.Key, Count = x.Count(), Sum = x.Sum(e =>
e.Key), Avg = x.Average(e => e.Key) })
.OrderBy(x => x.Key);
var res = query.ToList();
@@ -83,7 +82,7 @@ public partial class LinqTests
Assert.AreEqual(4.0d, res[1].Avg);
StringAssert.Contains(
- "select _T0.VAL, count (*) , sum (_T0.KEY) , avg (_T0.KEY) " +
+ "select _T0.VAL, count(*), sum(cast(_T0.KEY as int)),
avg(cast(_T0.KEY as int)) " +
"from PUBLIC.TBL_INT8 as _T0 " +
"group by (_T0.VAL) " +
"order by (_T0.VAL) asc",
@@ -116,7 +115,6 @@ public partial class LinqTests
[Test]
public void TestGroupByWithJoinAndProjection()
{
- // TODO IGNITE-18196 Remove cast to long for Sum and Count
var query1 = PocoView.AsQueryable();
var query2 = PocoIntView.AsQueryable();
@@ -131,7 +129,7 @@ public partial class LinqTests
Price = a.Val
})
.GroupBy(x => x.Category)
- .Select(g => new {Cat = g.Key, Count = (long)g.Count()})
+ .Select(g => new {Cat = g.Key, Count = g.Count()})
.OrderBy(x => x.Cat);
var res = query.ToList();
@@ -141,9 +139,9 @@ public partial class LinqTests
Assert.AreEqual(10, res.Count);
StringAssert.Contains(
- "select _T0.VAL, count (*) " +
+ "select _T0.VAL, count(*) " +
"from PUBLIC.TBL1 as _T1 " +
- "inner join PUBLIC.TBL_INT32 as _T0 on (_T0.KEY = _T1.KEY) " +
+ "inner join PUBLIC.TBL_INT32 as _T0 on (cast(_T0.KEY as bigint) =
_T1.KEY) " +
"group by (_T0.VAL) " +
"order by (_T0.VAL) asc",
query.ToString());
diff --git
a/modules/platforms/dotnet/Apache.Ignite.Tests/Linq/LinqTests.Join.cs
b/modules/platforms/dotnet/Apache.Ignite.Tests/Linq/LinqTests.Join.cs
index b8373d5fff..94e123a5f7 100644
--- a/modules/platforms/dotnet/Apache.Ignite.Tests/Linq/LinqTests.Join.cs
+++ b/modules/platforms/dotnet/Apache.Ignite.Tests/Linq/LinqTests.Join.cs
@@ -83,7 +83,7 @@ public partial class LinqTests
StringAssert.Contains(
"select _T0.KEY, _T1.VAL " +
"from PUBLIC.TBL1 as _T0 " +
- "inner join PUBLIC.TBL_INT32 as _T1 on (_T1.KEY = _T0.KEY) " +
+ "inner join PUBLIC.TBL_INT32 as _T1 on (cast(_T1.KEY as bigint) =
_T0.KEY) " +
"where (_T0.KEY > ?) " +
"order by (_T0.KEY) asc " +
"limit ?",
@@ -128,7 +128,7 @@ public partial class LinqTests
StringAssert.Contains(
"select _T0.KEY, _T1.VAL, _T2.VAL " +
"from PUBLIC.TBL1 as _T0 " +
- "inner join PUBLIC.TBL_INT32 as _T1 on (_T1.KEY = _T0.KEY) " +
+ "inner join PUBLIC.TBL_INT32 as _T1 on (cast(_T1.KEY as bigint) =
_T0.KEY) " +
"inner join PUBLIC.TBL_INT64 as _T2 on (_T2.KEY = _T0.KEY)",
joinQuery.ToString());
}
@@ -157,7 +157,7 @@ public partial class LinqTests
StringAssert.Contains(
"select _T0.KEY, _T1.VAL " +
"from PUBLIC.TBL1 as _T0 " +
- "inner join PUBLIC.TBL_INT32 as _T1 on (_T1.KEY = _T0.KEY) " +
+ "inner join PUBLIC.TBL_INT32 as _T1 on (cast(_T1.KEY as bigint) =
_T0.KEY) " +
"where (_T1.KEY > ?) " +
"order by (_T1.KEY) asc",
joinQuery.ToString());
@@ -253,7 +253,7 @@ public partial class LinqTests
"select _T0.KEY, _T1.VAL " +
"from PUBLIC.TBL_INT32 as _T0 " +
"left outer join (select * from PUBLIC.TBL_INT16 as _T2 ) as _T1 "
+
- "on (_T1.KEY = _T0.KEY)",
+ "on (cast(_T1.KEY as int) = _T0.KEY)",
joinQuery.ToString());
}
diff --git
a/modules/platforms/dotnet/Apache.Ignite.Tests/Sql/SqlColumnTypeExtensionsTests.cs
b/modules/platforms/dotnet/Apache.Ignite.Tests/Sql/SqlColumnTypeExtensionsTests.cs
index fe16d89563..ee24a7d481 100644
---
a/modules/platforms/dotnet/Apache.Ignite.Tests/Sql/SqlColumnTypeExtensionsTests.cs
+++
b/modules/platforms/dotnet/Apache.Ignite.Tests/Sql/SqlColumnTypeExtensionsTests.cs
@@ -38,7 +38,7 @@ public class SqlColumnTypeExtensionsTests
Assert.AreEqual(sqlColumnType,
sqlColumnType.ToClrType().ToSqlColumnType());
[TestCaseSource(nameof(SqlColumnTypes))]
- public void TestToSqlTypeName(SqlColumnType sqlColumnType)
+ public void TestSqlColumnTypeToSqlTypeName(SqlColumnType sqlColumnType)
{
if (sqlColumnType is SqlColumnType.Duration or SqlColumnType.Period)
{
@@ -47,4 +47,15 @@ public class SqlColumnTypeExtensionsTests
Assert.IsNotNull(sqlColumnType.ToSqlTypeName(),
sqlColumnType.ToString());
}
+
+ [TestCaseSource(nameof(SqlColumnTypes))]
+ public void TestClrTypeToSqlTypeName(SqlColumnType sqlColumnType)
+ {
+ if (sqlColumnType is SqlColumnType.Duration or SqlColumnType.Period)
+ {
+ return;
+ }
+
+ Assert.IsNotNull(sqlColumnType.ToClrType().ToSqlTypeName(),
sqlColumnType.ToString());
+ }
}
diff --git
a/modules/platforms/dotnet/Apache.Ignite/Internal/Linq/IgniteQueryExpressionVisitor.cs
b/modules/platforms/dotnet/Apache.Ignite/Internal/Linq/IgniteQueryExpressionVisitor.cs
index e3a83daf11..1e033cedda 100644
---
a/modules/platforms/dotnet/Apache.Ignite/Internal/Linq/IgniteQueryExpressionVisitor.cs
+++
b/modules/platforms/dotnet/Apache.Ignite/Internal/Linq/IgniteQueryExpressionVisitor.cs
@@ -26,6 +26,7 @@ using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;
+using Common;
using Remotion.Linq;
using Remotion.Linq.Clauses;
using Remotion.Linq.Clauses.Expressions;
@@ -204,7 +205,7 @@ internal sealed class IgniteQueryExpressionVisitor :
ThrowingExpressionVisitor
}
Visit(expression.Right);
- ResultBuilder.Append(')');
+ ResultBuilder.TrimEnd().Append(')');
return expression;
}
@@ -325,8 +326,7 @@ internal sealed class IgniteQueryExpressionVisitor :
ThrowingExpressionVisitor
Visit(expression.IfTrue);
ResultBuilder.Append(" as ");
- var sqlColumnType = expression.Type.ToSqlColumnType() ?? throw new
NotSupportedException("Unsupported type: " + expression.Type);
- ResultBuilder.Append(sqlColumnType.ToSqlTypeName());
+ ResultBuilder.Append(expression.Type.ToSqlTypeName());
ResultBuilder.Append(')');
Visit(expression.IfFalse);
@@ -351,7 +351,7 @@ internal sealed class IgniteQueryExpressionVisitor :
ThrowingExpressionVisitor
{
ResultBuilder.Append('(');
_modelVisitor.VisitQueryModel(subQueryModel, false, true);
- ResultBuilder.Append(')');
+ ResultBuilder.TrimEnd().Append(')');
}
else
{
@@ -384,7 +384,7 @@ internal sealed class IgniteQueryExpressionVisitor :
ThrowingExpressionVisitor
break;
case ExpressionType.Convert:
- // Ignore, let the db do the conversion
+ ResultBuilder.Append("cast(");
break;
default:
@@ -395,7 +395,14 @@ internal sealed class IgniteQueryExpressionVisitor :
ThrowingExpressionVisitor
if (closeBracket)
{
- ResultBuilder.Append(')');
+ ResultBuilder.TrimEnd().Append(')');
+ }
+ else if (expression.NodeType is ExpressionType.Convert)
+ {
+ ResultBuilder
+ .Append(" as ")
+ .Append(expression.Type.ToSqlTypeName())
+ .Append(')');
}
return expression;
@@ -481,7 +488,7 @@ internal sealed class IgniteQueryExpressionVisitor :
ThrowingExpressionVisitor
if (!first)
{
- ResultBuilder.Append(", ");
+ ResultBuilder.TrimEnd().Append(", ");
}
first = false;
@@ -598,7 +605,7 @@ internal sealed class IgniteQueryExpressionVisitor :
ThrowingExpressionVisitor
throw new NotSupportedException("Aggregate functions do
not support multiple fields");
}
- ResultBuilder.Append(", ");
+ ResultBuilder.TrimEnd().Append(", ");
}
first = false;
@@ -664,7 +671,7 @@ internal sealed class IgniteQueryExpressionVisitor :
ThrowingExpressionVisitor
Visit(expression.Left);
ResultBuilder.Append(", ");
Visit(expression.Right);
- ResultBuilder.Append(')');
+ ResultBuilder.TrimEnd().Append(')');
return true;
}
diff --git
a/modules/platforms/dotnet/Apache.Ignite/Internal/Linq/IgniteQueryModelVisitor.cs
b/modules/platforms/dotnet/Apache.Ignite/Internal/Linq/IgniteQueryModelVisitor.cs
index b5f0a4f7d0..8a5e3e3e32 100644
---
a/modules/platforms/dotnet/Apache.Ignite/Internal/Linq/IgniteQueryModelVisitor.cs
+++
b/modules/platforms/dotnet/Apache.Ignite/Internal/Linq/IgniteQueryModelVisitor.cs
@@ -109,7 +109,7 @@ internal sealed class IgniteQueryModelVisitor :
QueryModelVisitorBase
BuildSqlExpression(ordering.Expression);
- _builder.Append(')');
+ _builder.TrimEnd().Append(')');
_builder.Append(ordering.OrderingDirection ==
OrderingDirection.Asc ? " asc" : " desc");
}
@@ -278,7 +278,7 @@ internal sealed class IgniteQueryModelVisitor :
QueryModelVisitorBase
{
// FIELD1, FIELD2
BuildSqlExpression(queryModel.SelectClause.Selector, parenCount >
0, includeAllFields);
- _builder.Append(')', parenCount).Append(' ');
+ _builder.TrimEnd().Append(')', parenCount).Append(' ');
}
}
@@ -339,27 +339,27 @@ internal sealed class IgniteQueryModelVisitor :
QueryModelVisitorBase
{
if (op is CountResultOperator or LongCountResultOperator)
{
- _builder.Append("count (");
+ _builder.Append("count(");
parenCount++;
}
else if (op is SumResultOperator)
{
- _builder.Append("sum (");
+ _builder.Append("sum(");
parenCount++;
}
else if (op is MinResultOperator)
{
- _builder.Append("min (");
+ _builder.Append("min(");
parenCount++;
}
else if (op is MaxResultOperator)
{
- _builder.Append("max (");
+ _builder.Append("max(");
parenCount++;
}
else if (op is AverageResultOperator)
{
- _builder.Append("avg (");
+ _builder.Append("avg(");
parenCount++;
}
else if (op is DistinctResultOperator)
@@ -452,7 +452,7 @@ internal sealed class IgniteQueryModelVisitor :
QueryModelVisitorBase
VisitQueryModel(queryable.GetQueryModel());
}
- _builder.Append(')');
+ _builder.TrimEnd().Append(')');
}
}
}
diff --git
a/modules/platforms/dotnet/Apache.Ignite/Internal/Sql/SqlColumnTypeExtensions.cs
b/modules/platforms/dotnet/Apache.Ignite/Internal/Sql/SqlColumnTypeExtensions.cs
index e8202716d0..dd1c5ac237 100644
---
a/modules/platforms/dotnet/Apache.Ignite/Internal/Sql/SqlColumnTypeExtensions.cs
+++
b/modules/platforms/dotnet/Apache.Ignite/Internal/Sql/SqlColumnTypeExtensions.cs
@@ -33,6 +33,11 @@ internal static class SqlColumnTypeExtensions
private static readonly IReadOnlyDictionary<Type, SqlColumnType> ClrToSql =
Enum.GetValues<SqlColumnType>().ToDictionary(x => x.ToClrType(), x =>
x);
+ private static readonly IReadOnlyDictionary<Type, string> ClrToSqlName =
+ Enum.GetValues<SqlColumnType>()
+ .Where(x => x != SqlColumnType.Period && x !=
SqlColumnType.Duration)
+ .ToDictionary(x => x.ToClrType(), x => x.ToSqlTypeName());
+
/// <summary>
/// Gets corresponding .NET type.
/// </summary>
@@ -66,7 +71,7 @@ internal static class SqlColumnTypeExtensions
/// Gets corresponding SQL type name.
/// </summary>
/// <param name="sqlColumnType">SQL column type.</param>
- /// <returns>CLR type.</returns>
+ /// <returns>SQL type name.</returns>
public static string ToSqlTypeName(this SqlColumnType sqlColumnType) =>
sqlColumnType switch
{
SqlColumnType.Boolean => "boolean",
@@ -85,10 +90,20 @@ internal static class SqlColumnTypeExtensions
SqlColumnType.Bitmask => "bitmap",
SqlColumnType.String => "varchar",
SqlColumnType.ByteArray => "varbinary",
- SqlColumnType.Number => "number",
+ SqlColumnType.Number => "numeric",
_ => throw new InvalidOperationException($"Unsupported
{nameof(SqlColumnType)}: {sqlColumnType}")
};
+ /// <summary>
+ /// Gets the SQL type name for the specified CLR type. Nullable value types are resolved to their underlying type.
+ /// </summary>
+ /// <param name="type">CLR type; may be a <see cref="Nullable{T}"/>, e.g. the target type of a LINQ Convert expression.</param>
+ /// <returns>SQL type name; throws <see cref="InvalidOperationException"/> when the type has no SQL mapping.</returns>
+ public static string ToSqlTypeName(this Type type) =>
+ ClrToSqlName.TryGetValue(Nullable.GetUnderlyingType(type) ?? type, out var sqlTypeName)
+ ? sqlTypeName
+ : throw new InvalidOperationException($"Type is not supported in SQL: {type}");
+
/// <summary>
/// Gets corresponding <see cref="SqlColumnType"/>.
/// </summary>