alamb commented on code in PR #15776:
URL: https://github.com/apache/datafusion/pull/15776#discussion_r2192719901


##########
datafusion/sqllogictest/test_files/corr_type_coercion.slt:
##########
@@ -0,0 +1,248 @@
+# Licensed to the Apache Software Foundation (ASF) under one

Review Comment:
   I am not sure how much coverage this file provides compared to what was 
added in aggregate.slt
   
   However, I think it is ok to keep it.



##########
datafusion/sqllogictest/test_files/corr_type_coercion.slt:
##########
@@ -0,0 +1,248 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+statement ok
+CREATE TABLE numeric_types (
+    int8_col TINYINT,
+    int16_col SMALLINT,
+    int32_col INT,
+    int64_col BIGINT,
+    uint8_col TINYINT UNSIGNED,
+    uint16_col SMALLINT UNSIGNED,
+    uint32_col INT UNSIGNED,
+    uint64_col BIGINT UNSIGNED,
+    float32_col FLOAT,
+    float64_col DOUBLE
+);
+
+statement ok
+INSERT INTO numeric_types VALUES
+    (1, 2, 3, 4, 5, 6, 7, 8, 9.1, 10.1),
+    (2, 3, 4, 5, 6, 7, 8, 9, 10.2, 11.2),
+    (3, 4, 5, 6, 7, 8, 9, 10, 11.3, 12.3),
+    (4, 5, 6, 7, 8, 9, 10, 11, 12.4, 13.4),
+    (5, 6, 7, 8, 9, 10, 11, 12, 13.5, 14.5);
+
+# Mixed int8 and int16
+query R
+SELECT corr(int8_col, int16_col) FROM numeric_types;
+----
+1
+
+# Mixed int8 and int32
+query R
+SELECT corr(int8_col, int32_col) FROM numeric_types;
+----
+1
+
+# Mixed int8 and int64
+query R
+SELECT corr(int8_col, int64_col) FROM numeric_types;
+----
+1
+
+# Mixed int8 and float32
+query R
+SELECT corr(int8_col, float32_col) FROM numeric_types;
+----
+1
+
+# Mixed int8 and float64
+query R
+SELECT corr(int8_col, float64_col) FROM numeric_types;
+----
+1
+
+# Mixed int16 and int32
+query R
+SELECT corr(int16_col, int32_col) FROM numeric_types;
+----
+1
+
+# Mixed int16 and float64
+query R
+SELECT corr(int16_col, float64_col) FROM numeric_types;
+----
+1
+
+# Mixed int32 and float32
+query R
+SELECT corr(int32_col, float32_col) FROM numeric_types;
+----
+1
+
+# Mixed uint8 and int64
+query R
+SELECT corr(uint8_col, int64_col) FROM numeric_types;
+----
+1
+
+# Mixed uint16 and float32
+query R
+SELECT corr(uint16_col, float32_col) FROM numeric_types;
+----
+1
+
+# Mixed uint32 and float64
+query R
+SELECT corr(uint32_col, float64_col) FROM numeric_types;
+----
+1
+
+# Mixed uint64 and int8
+query R
+SELECT corr(uint64_col, int8_col) FROM numeric_types;
+----
+1
+
+# Float32 and Float64
+query R
+SELECT corr(float32_col, float64_col) FROM numeric_types;
+----
+1
+
+# Literal of different numeric types
+query R
+SELECT corr(int8_col, 10.5) FROM numeric_types;
+----
+0
+
+# Literal integer and column
+query R
+SELECT corr(123, float32_col) FROM numeric_types;
+----
+0
+
+# Test with NULL values
+statement ok
+INSERT INTO numeric_types VALUES
+    (NULL, 7, 8, 9, 10, 11, 12, 13, 14.6, 15.6),
+    (6, NULL, 9, 10, 11, 12, 13, 14, 15.7, 16.7),
+    (7, 8, NULL, 11, 12, 13, 14, 15, 16.8, 17.8);
+
+# NULL handling
+query R
+SELECT corr(int8_col, int16_col) FROM numeric_types;
+----
+1
+
+# Complex query with GROUP BY
+statement ok
+CREATE TABLE grouped_data (
+    category VARCHAR,
+    value1 INT,
+    value2 FLOAT
+);
+
+statement ok
+INSERT INTO grouped_data VALUES
+    ('A', 1, 10.1),
+    ('A', 2, 11.2),
+    ('A', 3, 12.3),
+    ('B', 4, 20.4),
+    ('B', 5, 21.5),
+    ('B', 6, 22.6),
+    ('C', 7, 30.7),
+    ('C', 8, 31.8),
+    ('C', 9, 32.9);
+
+query TR rowsort
+SELECT category, corr(value1, value2) FROM grouped_data GROUP BY category;
+----
+A 1
+B 1
+C 0.999999999999
+
+# Verify that the physical plan is using our type coercion
+# by showing the optimized plan for a correlation query
+
+query TT
+EXPLAIN SELECT corr(int8_col, float32_col) FROM numeric_types;
+----
+logical_plan
+01)Aggregate: groupBy=[[]], aggr=[[corr(CAST(numeric_types.int8_col AS 
Float64), CAST(numeric_types.float32_col AS Float64))]]
+02)--TableScan: numeric_types projection=[int8_col, float32_col]
+physical_plan
+01)AggregateExec: mode=Single, gby=[], 
aggr=[corr(numeric_types.int8_col,numeric_types.float32_col)]
+02)--DataSourceExec: partitions=1, partition_sizes=[2]
+
+# Verify the return type is Float64 with a cast that would fail otherwise
+query R
+SELECT CAST(corr(int8_col, float32_col) AS DOUBLE) + 0.0 FROM numeric_types;
+----
+0.99318328739
+
+# Benchmark query pattern - corr with integer columns, power function, and 
GROUP BY
+statement ok
+CREATE TABLE benchmark_data (
+    id1 VARCHAR,
+    id2 VARCHAR,
+    id3 VARCHAR,
+    id4 INT,
+    id5 INT,
+    id6 INT,
+    v1 INT,
+    v2 INT,
+    v3 DOUBLE
+);
+
+statement ok
+INSERT INTO benchmark_data VALUES
+    ('a1', 'group1', 'x1', 1, 10, 20, 100, 200, 1.1),
+    ('a2', 'group1', 'x2', 1, 11, 21, 110, 220, 1.2),
+    ('a3', 'group1', 'x3', 1, 12, 22, 120, 240, 1.3),
+    ('a4', 'group1', 'x4', 1, 13, 23, 130, 260, 1.4),
+    ('a5', 'group1', 'x5', 1, 14, 24, 140, 280, 1.5),
+    ('b1', 'group2', 'y1', 2, 15, 25, 200, 150, 2.1),
+    ('b2', 'group2', 'y2', 2, 16, 26, 220, 165, 2.2),
+    ('b3', 'group2', 'y3', 2, 17, 27, 240, 180, 2.3),
+    ('b4', 'group2', 'y4', 2, 18, 28, 260, 195, 2.4),
+    ('b5', 'group2', 'y5', 2, 19, 29, 280, 210, 2.5);
+
+# Run the benchmark query pattern
+query TIR rowsort
+SELECT id2, id4, power(corr(v1, v2), 2) as r2 FROM benchmark_data GROUP BY 
id2, id4;
+----
+group1 1 1
+group2 2 1
+
+# Verify the benchmark query's physical plan to confirm type coercion is 
happening

Review Comment:
   I didn't really see any verification in this plan -- I would expect to see 
casting happening in the datasource or something similar



##########
datafusion/functions-aggregate/src/correlation.rs:
##########
@@ -83,10 +82,13 @@ impl Default for Correlation {
 }
 
 impl Correlation {
-    /// Create a new COVAR_POP aggregate function
+    /// Create a new CORR aggregate function
     pub fn new() -> Self {
         Self {
-            signature: Signature::uniform(2, NUMERICS.to_vec(), 
Volatility::Immutable),
+            signature: Signature::exact(
+                vec![DataType::Float64, DataType::Float64],

Review Comment:
   👍 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org

Reply via email to