This is an automated email from the ASF dual-hosted git repository.

libenchao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new 289d082055 [CALCITE-5674] CAST expr to target type should respect 
nullable when it is complex type
289d082055 is described below

commit 289d08205584efefa7b399318e7bb81c4f7f11aa
Author: yongen.ly <[email protected]>
AuthorDate: Fri May 5 20:31:50 2023 +0800

    [CALCITE-5674] CAST expr to target type should respect nullable when it is 
complex type
    
    Close apache/calcite#3189
---
 .../apache/calcite/sql/fun/SqlCastFunction.java    | 66 +++++++++++++++++++++-
 .../org/apache/calcite/test/SqlValidatorTest.java  | 58 ++++++++++++++++++-
 2 files changed, 120 insertions(+), 4 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlCastFunction.java 
b/core/src/main/java/org/apache/calcite/sql/fun/SqlCastFunction.java
index 74b1fc0236..04a05e7358 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlCastFunction.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlCastFunction.java
@@ -43,12 +43,20 @@ import com.google.common.collect.ImmutableSetMultimap;
 import com.google.common.collect.SetMultimap;
 
 import java.text.Collator;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Objects;
 
 import static com.google.common.base.Preconditions.checkArgument;
 
+import static org.apache.calcite.sql.type.SqlTypeUtil.isArray;
+import static org.apache.calcite.sql.type.SqlTypeUtil.isCollection;
+import static org.apache.calcite.sql.type.SqlTypeUtil.isMap;
+import static org.apache.calcite.sql.type.SqlTypeUtil.isRow;
 import static org.apache.calcite.util.Static.RESOURCE;
 
+import static java.util.Objects.requireNonNull;
+
 /**
  * SqlCastFunction. Note that the std functions are really singleton objects,
  * because they always get fetched via the StdOperatorTable. So you can't store
@@ -122,8 +130,62 @@ public class SqlCastFunction extends SqlFunction {
   /** Derives the type of "CAST(expression AS targetType)". */
   public static RelDataType deriveType(RelDataTypeFactory typeFactory,
       RelDataType expressionType, RelDataType targetType, boolean safe) {
-    return typeFactory.createTypeWithNullability(targetType,
-        expressionType.isNullable() || safe);
+    // Delegate to the recursive helper, which also visits nested component
+    // types instead of looking only at the two top-level types.
+    return createTypeWithNullabilityFromExpr(typeFactory, expressionType, 
targetType, safe);
+  }
+
+  /**
+   * Derives the result type of a cast by combining the structure of
+   * {@code targetType} with nullability taken from {@code expressionType}.
+   *
+   * <p>The result is nullable if the expression type is nullable or if
+   * {@code safe} is true. When the expression type is a collection, row or
+   * map, the same rule is applied recursively to the element types, to the
+   * fields (matched by position; names come from the target type), or to the
+   * key and value types. Any other type is the target type with the computed
+   * nullability.
+   *
+   * @param typeFactory Type factory used to build the result
+   * @param expressionType Type of the expression being cast
+   * @param targetType Type named in the cast
+   * @param safe Whether the result is nullable regardless of the expression
+   * @return Target type with nullability derived from the expression type
+   * @throws NullPointerException if a component, key or value type that the
+   *     recursion needs is missing from either type
+   */
+  private static RelDataType createTypeWithNullabilityFromExpr(RelDataTypeFactory typeFactory,
+      RelDataType expressionType, RelDataType targetType, boolean safe) {
+    boolean isNullable = expressionType.isNullable() || safe;
+
+    if (isCollection(expressionType)) {
+      RelDataType expressionElementType = expressionType.getComponentType();
+      RelDataType targetElementType = targetType.getComponentType();
+      requireNonNull(expressionElementType, () -> "componentType of " + expressionType);
+      requireNonNull(targetElementType, () -> "componentType of " + targetType);
+      RelDataType newElementType =
+          createTypeWithNullabilityFromExpr(
+              typeFactory, expressionElementType, targetElementType, safe);
+      // The expression type decides whether an array or a multiset is built.
+      return isArray(expressionType)
+          ? SqlTypeUtil.createArrayType(typeFactory, newElementType, isNullable)
+          : SqlTypeUtil.createMultisetType(typeFactory, newElementType, isNullable);
+    }
+
+    if (isRow(expressionType)) {
+      final int fieldCount = expressionType.getFieldCount();
+      final List<RelDataType> typeList = new ArrayList<>(fieldCount);
+      // Fields are paired by ordinal; the loop is driven by the expression
+      // type's field count.
+      for (int i = 0; i < fieldCount; ++i) {
+        RelDataType expressionElementType = expressionType.getFieldList().get(i).getType();
+        RelDataType targetElementType = targetType.getFieldList().get(i).getType();
+        typeList.add(
+            createTypeWithNullabilityFromExpr(typeFactory, expressionElementType,
+                targetElementType, safe));
+      }
+      return typeFactory.createTypeWithNullability(
+          typeFactory.createStructType(
+              typeList,
+              targetType.getFieldNames()), isNullable);
+    }
+
+    if (isMap(expressionType)) {
+      RelDataType expressionKeyType =
+          requireNonNull(expressionType.getKeyType(), () -> "keyType of " + expressionType);
+      RelDataType expressionValueType =
+          requireNonNull(expressionType.getValueType(), () -> "valueType of " + expressionType);
+      RelDataType targetKeyType =
+          requireNonNull(targetType.getKeyType(), () -> "keyType of " + targetType);
+      RelDataType targetValueType =
+          requireNonNull(targetType.getValueType(), () -> "valueType of " + targetType);
+
+      RelDataType keyType =
+          createTypeWithNullabilityFromExpr(
+              typeFactory, expressionKeyType, targetKeyType, safe);
+      RelDataType valueType =
+          createTypeWithNullabilityFromExpr(
+              typeFactory, expressionValueType, targetValueType, safe);
+      // Return the rebuilt map type. Without the return the derived key and
+      // value types are discarded and control falls through to the generic
+      // case below, which only adjusts top-level nullability.
+      return SqlTypeUtil.createMapType(typeFactory, keyType, valueType, isNullable);
+    }
+
+    return typeFactory.createTypeWithNullability(targetType, isNullable);
   }
 
   @Override public String getSignatureTemplate(final int operandsCount) {
diff --git a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java 
b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
index acd8e3d877..6e6119c900 100644
--- a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
+++ b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java
@@ -6984,9 +6984,12 @@ public class SqlValidatorTest extends 
SqlValidatorTestCase {
 
   @Test void testCastAsCollectionType() {
     sql("select cast(array[1,null,2] as int array) from (values (1))")
-        .columnType("INTEGER NOT NULL ARRAY NOT NULL");
+        .columnType("INTEGER ARRAY NOT NULL");
     sql("select cast(array['1',null,'2'] as varchar(5) array) from (values 
(1))")
-        .columnType("VARCHAR(5) NOT NULL ARRAY NOT NULL");
+        .columnType("VARCHAR(5) ARRAY NOT NULL");
+    sql("select cast(multiset[1,null,2] as int multiset) from (values (1))")
+        .columnType("INTEGER MULTISET NOT NULL");
+
     // test array type.
     sql("select cast(\"intArrayType\" as int array) from COMPLEXTYPES.CTC_T1")
         .withExtendedCatalog()
@@ -7060,6 +7063,57 @@ public class SqlValidatorTest extends 
SqlValidatorTestCase {
             + "MULTISET NOT NULL");
   }
 
+  /** Checks the column types derived for {@code safe_cast} to array and
+   * multiset types; every expected type string is written without
+   * {@code NOT NULL}. */
+  @Test void testSafeCastAsCollectionType() {
+    // Operator table for the BigQuery library, passed to each query below.
+    final SqlOperatorTable opTable = operatorTableFor(SqlLibrary.BIG_QUERY);
+
+    // Collection constructors that contain a NULL element.
+    sql("select safe_cast(array[1,null,2] as int array) from (values (1))")
+        .withOperatorTable(opTable)
+        .columnType("INTEGER ARRAY");
+    sql("select safe_cast(multiset[1,null,2] as int multiset) from (values 
(1))")
+        .withOperatorTable(opTable)
+        .columnType("INTEGER MULTISET");
+
+    // test array type.
+    sql("select safe_cast(\"varchar5ArrayArrayType\" as varchar(5) array 
array) "
+        + "from COMPLEXTYPES.CTC_T1")
+        .withOperatorTable(opTable)
+        .withExtendedCatalog()
+        .columnType("VARCHAR(5) ARRAY ARRAY");
+    // test multiset type.
+    sql("select safe_cast(\"varchar5MultisetArrayType\" as varchar(5) multiset 
array) "
+        + "from COMPLEXTYPES.CTC_T1")
+        .withOperatorTable(opTable)
+        .withExtendedCatalog()
+        .columnType("VARCHAR(5) MULTISET ARRAY");
+  }
+
+  /** Checks the column types derived for {@code try_cast} to row types,
+   * including a nested row and an array of rows; every expected type string
+   * is written without {@code NOT NULL}. */
+  @Test void testTryCastAsRowType() {
+    // Operator table for the MSSQL library, passed to each query below.
+    final SqlOperatorTable opTable = operatorTableFor(SqlLibrary.MSSQL);
+
+    // test flat row type.
+    sql("select try_cast(a as row(f0 int, f1 varchar)) from 
COMPLEXTYPES.CTC_T1")
+        .withOperatorTable(opTable)
+        .withExtendedCatalog()
+        .columnType("RecordType(INTEGER F0, VARCHAR F1)");
+    // test nested row type.
+    sql("select "
+        + "try_cast(c as row("
+        + "f0 row(ff0 int, ff1 varchar), "
+        + "f1 timestamp))"
+        + " from COMPLEXTYPES.CTC_T1")
+        .withOperatorTable(opTable)
+        .withExtendedCatalog()
+        .columnType("RecordType("
+            + "RecordType(INTEGER FF0, VARCHAR FF1) F0, "
+            + "TIMESTAMP(0) F1)");
+    // test row type in collection data types.
+    sql("select try_cast(d as row(f0 bigint, f1 decimal) array)\n"
+        + "from COMPLEXTYPES.CTC_T1")
+        .withOperatorTable(opTable)
+        .withExtendedCatalog()
+        .columnType("RecordType(BIGINT F0, DECIMAL(19, 0) F1) "
+            + "ARRAY");
+  }
+
   @Test void testMultisetConstructor() {
     sql("select multiset[1,null,2] as a from (values (1))")
         .columnType("INTEGER MULTISET NOT NULL");

Reply via email to