sunchao commented on a change in pull request #32583:
URL: https://github.com/apache/spark/pull/32583#discussion_r683631913
##########
File path:
sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
##########
@@ -900,13 +904,80 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
s"${SQLConf.HIVE_METASTORE_PARTITION_PRUNING_FALLBACK_ON_EXCEPTION.key} " +
" to false and let the query fail instead.", ex)
// HiveShim clients are expected to handle a superset of the
requested partitions
- getAllPartitionsMethod.invoke(hive,
table).asInstanceOf[JSet[Partition]]
+ prunePartitionsEvalClientSide(hive, table, catalogTable,
predicates)
case ex: InvocationTargetException if
ex.getCause.isInstanceOf[MetaException] =>
throw QueryExecutionErrors.getPartitionMetadataByFilterError(ex)
}
}
- partitions.asScala.toSeq
+ partitions.toSeq
+ }
+
+ def prunePartitionsEvalClientSide(hive: Hive,
+ table: Table,
+ catalogTable: CatalogTable,
+ predicates: Seq[Expression]): Seq[Partition] = {
+
+ val timeZoneId = SQLConf.get.sessionLocalTimeZone
+
+    // Because there is no way to know whether the partition properties have a
timeZone,
+    // client-side filtering cannot be used with TimeZoneAwareExpression.
+ def hasTimeZoneAwareExpression(e: Expression): Boolean = {
+ e.collectFirst {
+ case t: TimeZoneAwareExpression => t
+ }.isDefined
+ }
+
+ if (!SQLConf.get.metastorePartitionPruningEvalClientSide ||
+ predicates.isEmpty ||
+ predicates.exists(hasTimeZoneAwareExpression)) {
+ getAllPartitionsMethod.invoke(hive,
table).asInstanceOf[JSet[Partition]].asScala.toSeq
+ } else {
+ try {
+ val partitionSchema =
CharVarcharUtils.replaceCharVarcharWithStringInSchema(
+ catalogTable.partitionSchema)
+ val partitionColumnNames = catalogTable.partitionColumnNames.toSet
+
+ val nonPartitionPruningPredicates = predicates.filterNot {
+ _.references.map(_.name).toSet.subsetOf(partitionColumnNames)
+ }
+ if (nonPartitionPruningPredicates.nonEmpty) {
+ throw
QueryCompilationErrors.nonPartitionPruningPredicatesNotExpectedError(
+ nonPartitionPruningPredicates)
+ }
+
+ val boundPredicate =
+ Predicate.createInterpreted(predicates.reduce(And).transform {
+ case att: AttributeReference =>
+ val index = partitionSchema.indexWhere(_.name == att.name)
+ BoundReference(index, partitionSchema(index).dataType, nullable
= true)
+ })
+
+ def toRow(spec: TablePartitionSpec): InternalRow = {
+ InternalRow.fromSeq(partitionSchema.map { field =>
+ val partValue = if (spec(field.name) ==
ExternalCatalogUtils.DEFAULT_PARTITION_NAME) {
+ null
+ } else {
+ spec(field.name)
+ }
+ Cast(Literal(partValue), field.dataType, Option(timeZoneId)).eval()
+ })
+ }
+
+ val allPartitionNames = hive.getPartitionNames(
Review comment:
+1. This looks like a nice optimization.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]