cloud-fan commented on a change in pull request #33352: URL: https://github.com/apache/spark/pull/33352#discussion_r672345811
########## File path: sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/AggregateFunc.java ########## @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions; + +import org.apache.spark.annotation.Evolving; + +/** + * Base class of the Aggregate Functions. + * + * @since 3.2.0 + */ +@Evolving +public interface AggregateFunc { Review comment: should this be `Serializable`? ########## File path: sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Aggregation.java ########## @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions; + +import org.apache.spark.annotation.Evolving; + +import java.io.Serializable; + +/** + * Aggregation in SQL statement. + * + * @since 3.2.0 + */ +@Evolving +public class Aggregation implements Serializable { Review comment: why does it need to be `Serializable`? ########## File path: sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Aggregation.java ########## @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions; + +import org.apache.spark.annotation.Evolving; + +import java.io.Serializable; + +/** + * Aggregation in SQL statement. 
+ * + * @since 3.2.0 + */ +@Evolving +public class Aggregation implements Serializable { Review comment: and can be `final class`? ########## File path: sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java ########## @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Aggregation; + +/** + * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to + * push down aggregates. Spark assumes that the data source can't fully complete the + * grouping work, and will group the data source output again. For queries like + * "SELECT min(value) AS m FROM t GROUP BY key", after pushing down the aggregate + * to the data source, the data source can still output data with duplicated keys, which is OK + * as Spark will do GROUP BY key again. 
The final query plan can be something like this: + * {{{ + * Aggregate [key#1], [min(min(value)#2) AS AS m#3] Review comment: nit: `AS AS` -> `AS` ########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/aggregates.scala ########## @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions + +import org.apache.spark.sql.types.DataType + +case class Min(column: FieldReference) extends AggregateFunc Review comment: are these public APIs? If yes we should write them in Java as well. another way is to follow `Transform`. We only expose the function name and arguments, and the concrete implementations are internal. ########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/aggregates.scala ########## @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions + +import org.apache.spark.sql.types.DataType + +case class Min(column: FieldReference) extends AggregateFunc + +case class Max(column: FieldReference) extends AggregateFunc Review comment: We can also document the data type of `Sum`: 1. return long for integral types 2. return double for float/double 3. return `decimal(p + 10, s)` for `decimal(p, s)` ########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala ########## @@ -129,12 +131,32 @@ case class RowDataSourceScanExec( override def inputRDD: RDD[InternalRow] = rdd override val metadata: Map[String, String] = { - val markedFilters = for (filter <- filters) yield { - if (handledFilters.contains(filter)) s"*$filter" else s"$filter" + + def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]") + + val (aggString, groupByString) = if (aggregation.nonEmpty) { + (seqToString(aggregation.get.getAggregateExpressions), + seqToString(aggregation.get.getGroupByColumns)) + } else { + ("[]", "[]") + } + + if (filters.nonEmpty) { Review comment: nit: ``` val markedFilters = ... Map(...) ``` ########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala ########## @@ -221,6 +269,19 @@ private[jdbc] class JDBCRDD( } } + /** + * A GROUP BY clause representing pushed-down grouping columns. 
+ */ + private def getGroupByClause: String = { + if (aggregation.nonEmpty && aggregation.get.getGroupByColumns.length > 0) { + val quotedColumns = aggregation.get.getGroupByColumns + .map(c => JdbcDialects.get(url).quoteIdentifier(c.fieldNames.head)) Review comment: we should look up the dialect with `JdbcDialects.get(url)` only once, outside the `map`, rather than once per group-by column. ########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala ########## @@ -133,6 +144,32 @@ object JDBCRDD extends Logging { }) } + def compileAggregates( + aggregates: Seq[AggregateFunc], + dialect: JdbcDialect): (Array[String]) = { Review comment: `(Array[String])` -> `Array[String]` (the parentheses are redundant) ########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala ########## @@ -133,6 +144,32 @@ object JDBCRDD extends Logging { }) } + def compileAggregates( + aggregates: Seq[AggregateFunc], + dialect: JdbcDialect): (Array[String]) = { + def quote(colName: String): String = dialect.quoteIdentifier(colName) + + val aggBuilder = ArrayBuilder.make[String] + aggregates.map { Review comment: Sorry, I don't get it. Why not just `aggregates.map`? Why do we need an ArrayBuilder?
########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala ########## @@ -49,6 +51,60 @@ case class JDBCScanBuilder( override def pushedFilters(): Array[Filter] = pushedFilter + private var pushedAggregations = Option.empty[Aggregation] + + private var pushedAggregateColumn: Array[String] = Array() + + private def getStructFieldForCol(col: FieldReference): StructField = + schema.fields(schema.fieldNames.toList.indexOf(col.fieldNames.head)) + + override def pushAggregation(aggregation: Aggregation): Boolean = { + if (!jdbcOptions.pushDownAggregate) return false + + val dialect = JdbcDialects.get(jdbcOptions.url) + val compiledAgg = JDBCRDD.compileAggregates(aggregation.getAggregateExpressions, dialect) + // if any of the aggregates is not supported by the data source, not push down + if (compiledAgg.length != aggregation.getAggregateExpressions.size) return false Review comment: this can never happen ########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/aggregates.scala ########## @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.expressions + +import org.apache.spark.sql.types.DataType + +case class Min(column: FieldReference) extends AggregateFunc + +case class Max(column: FieldReference) extends AggregateFunc Review comment: We can also document the return data type of `Sum`: 1. it returns long for integral types; 2. it returns double for float/double; 3. it returns `decimal(p + 10, s)` for `decimal(p, s)`. ########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala ########## @@ -181,7 +228,8 @@ private[jdbc] class JDBCRDD( filters: Array[Filter], partitions: Array[Partition], url: String, - options: JDBCOptions) + options: JDBCOptions, + aggregation: Option[Aggregation]) Review comment: It seems we only need a `groupByColumns: Array[FieldReference]` here. ########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala ########## @@ -152,19 +191,27 @@ object JDBCRDD extends Logging { requiredColumns: Array[String], filters: Array[Filter], parts: Array[Partition], - options: JDBCOptions): RDD[InternalRow] = { + options: JDBCOptions, + requiredSchema: Option[StructType] = None, Review comment: Should this be named `outputSchema`?
########## File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala ########## @@ -152,19 +191,27 @@ object JDBCRDD extends Logging { requiredColumns: Array[String], filters: Array[Filter], parts: Array[Partition], - options: JDBCOptions): RDD[InternalRow] = { + options: JDBCOptions, + requiredSchema: Option[StructType] = None, + aggregation: Option[Aggregation] = None): RDD[InternalRow] = { val url = options.url val dialect = JdbcDialects.get(url) - val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName)) + val quotedColumns = if (aggregation.isEmpty) { + requiredColumns.map(colName => dialect.quoteIdentifier(colName)) + } else { + // these are already quoted in JDBCScanBuilder + requiredColumns + } new JDBCRDD( sc, JdbcUtils.createConnectionFactory(options), - pruneSchema(schema, requiredColumns), + pruneSchema(schema, requiredColumns, requiredSchema), Review comment: `requiredSchema.getOrElse(pruneSchema...)`, then we don't need to change `pruneSchema` at all. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
