cloud-fan commented on a change in pull request #34001:
URL: https://github.com/apache/spark/pull/34001#discussion_r709343205
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
##########
@@ -261,4 +267,116 @@ object DataSourceUtils extends PredicateHelper {
dataFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionSet))
(ExpressionSet(partitionFilters ++ extraPartitionFilter).toSeq,
dataFilters)
}
+
+  def convertV1FilterToV2(v1Filter: sources.Filter): V2Filter = {
+    v1Filter match {
+      case _: sources.AlwaysFalse =>
+        new V2AlwaysFalse
+      case _: sources.AlwaysTrue =>
+        new V2AlwaysTrue
+      case e: sources.EqualNullSafe =>
+        new V2EqualNullSafe(FieldReference(e.attribute), getLiteralValue(e.value))
+      case equal: sources.EqualTo =>
+        new V2EqualTo(FieldReference(equal.attribute), getLiteralValue(equal.value))
+      case g: sources.GreaterThan =>
+        new V2GreaterThan(FieldReference(g.attribute), getLiteralValue(g.value))
+      case ge: sources.GreaterThanOrEqual =>
+        new V2GreaterThanOrEqual(FieldReference(ge.attribute), getLiteralValue(ge.value))
+      case in: sources.In =>
+        new V2In(FieldReference(
+          in.attribute), in.values.map(value => getLiteralValue(value)))
+      case notNull: sources.IsNotNull =>
+        new V2IsNotNull(FieldReference(notNull.attribute))
+      case isNull: sources.IsNull =>
+        new V2IsNull(FieldReference(isNull.attribute))
+      case l: sources.LessThan =>
+        new V2LessThan(FieldReference(l.attribute), getLiteralValue(l.value))
+      case le: sources.LessThanOrEqual =>
+        new V2LessThanOrEqual(FieldReference(le.attribute), getLiteralValue(le.value))
+      case contains: sources.StringContains =>
+        new V2StringContains(
+          FieldReference(contains.attribute), UTF8String.fromString(contains.value))
+      case ends: sources.StringEndsWith =>
+        new V2StringEndsWith(FieldReference(ends.attribute), UTF8String.fromString(ends.value))
+      case starts: sources.StringStartsWith =>
+        new V2StringStartsWith(
+          FieldReference(starts.attribute), UTF8String.fromString(starts.value))
+      case and: sources.And =>
+        new V2And(convertV1FilterToV2(and.left), convertV1FilterToV2(and.right))
+      case or: sources.Or =>
+        new V2Or(convertV1FilterToV2(or.left), convertV1FilterToV2(or.right))
+      case not: sources.Not =>
+        new V2Not(convertV1FilterToV2(not.child))
+      case _ => throw QueryCompilationErrors.invalidFilter(v1Filter)
+    }
+  }
+
+  def getLiteralValue(value: Any): LiteralValue[_] = value match {
+    case _: JavaBigDecimal =>
+      LiteralValue(Decimal(value.asInstanceOf[JavaBigDecimal]), DecimalType.SYSTEM_DEFAULT)
+    case _: JavaBigInteger =>
+      LiteralValue(Decimal(value.asInstanceOf[JavaBigInteger]), DecimalType.SYSTEM_DEFAULT)
+    case _: BigDecimal =>
+      LiteralValue(Decimal(value.asInstanceOf[BigDecimal]), DecimalType.SYSTEM_DEFAULT)
+    case _: Boolean => LiteralValue(value, BooleanType)
+    case _: Byte => LiteralValue(value, ByteType)
+    case _: Array[Byte] => LiteralValue(value, BinaryType)
+    case _: Date =>
+      val date = DateTimeUtils.fromJavaDate(value.asInstanceOf[Date])
+      LiteralValue(date, DateType)
+    case _: LocalDate =>
+      val date = DateTimeUtils.localDateToDays(value.asInstanceOf[LocalDate])
+      LiteralValue(date, DateType)
+    case _: Double => LiteralValue(value, DoubleType)
+    case _: Float => LiteralValue(value, FloatType)
+    case _: Integer => LiteralValue(value, IntegerType)
+    case _: Long => LiteralValue(value, LongType)
+    case _: Short => LiteralValue(value, ShortType)
+    case _: String => LiteralValue(UTF8String.fromString(value.toString), StringType)
+    case _: Timestamp =>
+      val ts = DateTimeUtils.fromJavaTimestamp(value.asInstanceOf[Timestamp])
+      LiteralValue(ts, TimestampType)
+    case _: Instant =>
+      val ts = DateTimeUtils.instantToMicros(value.asInstanceOf[Instant])
+      LiteralValue(ts, TimestampType)
+    case _ =>
+      throw QueryCompilationErrors.invalidDataTypeForFilterValue(value)
Review comment:
ditto
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
##########
@@ -261,4 +267,116 @@ object DataSourceUtils extends PredicateHelper {
dataFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionSet))
(ExpressionSet(partitionFilters ++ extraPartitionFilter).toSeq,
dataFilters)
}
+
+  def convertV1FilterToV2(v1Filter: sources.Filter): V2Filter = {
+    v1Filter match {
+      case _: sources.AlwaysFalse =>
+        new V2AlwaysFalse
+      case _: sources.AlwaysTrue =>
+        new V2AlwaysTrue
+      case e: sources.EqualNullSafe =>
+        new V2EqualNullSafe(FieldReference(e.attribute), getLiteralValue(e.value))
+      case equal: sources.EqualTo =>
+        new V2EqualTo(FieldReference(equal.attribute), getLiteralValue(equal.value))
+      case g: sources.GreaterThan =>
+        new V2GreaterThan(FieldReference(g.attribute), getLiteralValue(g.value))
+      case ge: sources.GreaterThanOrEqual =>
+        new V2GreaterThanOrEqual(FieldReference(ge.attribute), getLiteralValue(ge.value))
+      case in: sources.In =>
+        new V2In(FieldReference(
+          in.attribute), in.values.map(value => getLiteralValue(value)))
+      case notNull: sources.IsNotNull =>
+        new V2IsNotNull(FieldReference(notNull.attribute))
+      case isNull: sources.IsNull =>
+        new V2IsNull(FieldReference(isNull.attribute))
+      case l: sources.LessThan =>
+        new V2LessThan(FieldReference(l.attribute), getLiteralValue(l.value))
+      case le: sources.LessThanOrEqual =>
+        new V2LessThanOrEqual(FieldReference(le.attribute), getLiteralValue(le.value))
+      case contains: sources.StringContains =>
+        new V2StringContains(
+          FieldReference(contains.attribute), UTF8String.fromString(contains.value))
+      case ends: sources.StringEndsWith =>
+        new V2StringEndsWith(FieldReference(ends.attribute), UTF8String.fromString(ends.value))
+      case starts: sources.StringStartsWith =>
+        new V2StringStartsWith(
+          FieldReference(starts.attribute), UTF8String.fromString(starts.value))
+      case and: sources.And =>
+        new V2And(convertV1FilterToV2(and.left), convertV1FilterToV2(and.right))
+      case or: sources.Or =>
+        new V2Or(convertV1FilterToV2(or.left), convertV1FilterToV2(or.right))
+      case not: sources.Not =>
+        new V2Not(convertV1FilterToV2(not.child))
+      case _ => throw QueryCompilationErrors.invalidFilter(v1Filter)
+    }
+  }
+
+  def getLiteralValue(value: Any): LiteralValue[_] = value match {
+    case _: JavaBigDecimal =>
+      LiteralValue(Decimal(value.asInstanceOf[JavaBigDecimal]), DecimalType.SYSTEM_DEFAULT)
+    case _: JavaBigInteger =>
+      LiteralValue(Decimal(value.asInstanceOf[JavaBigInteger]), DecimalType.SYSTEM_DEFAULT)
+    case _: BigDecimal =>
+      LiteralValue(Decimal(value.asInstanceOf[BigDecimal]), DecimalType.SYSTEM_DEFAULT)
+    case _: Boolean => LiteralValue(value, BooleanType)
+    case _: Byte => LiteralValue(value, ByteType)
+    case _: Array[Byte] => LiteralValue(value, BinaryType)
+    case _: Date =>
+      val date = DateTimeUtils.fromJavaDate(value.asInstanceOf[Date])
+      LiteralValue(date, DateType)
+    case _: LocalDate =>
+      val date = DateTimeUtils.localDateToDays(value.asInstanceOf[LocalDate])
+      LiteralValue(date, DateType)
+    case _: Double => LiteralValue(value, DoubleType)
+    case _: Float => LiteralValue(value, FloatType)
+    case _: Integer => LiteralValue(value, IntegerType)
+    case _: Long => LiteralValue(value, LongType)
+    case _: Short => LiteralValue(value, ShortType)
+    case _: String => LiteralValue(UTF8String.fromString(value.toString), StringType)
+    case _: Timestamp =>
+      val ts = DateTimeUtils.fromJavaTimestamp(value.asInstanceOf[Timestamp])
+      LiteralValue(ts, TimestampType)
+    case _: Instant =>
+      val ts = DateTimeUtils.instantToMicros(value.asInstanceOf[Instant])
+      LiteralValue(ts, TimestampType)
+    case _ =>
+      throw QueryCompilationErrors.invalidDataTypeForFilterValue(value)
+  }
+
+  def convertV2FilterToV1(v2Filter: V2Filter): sources.Filter = {
+    v2Filter match {
+      case _: V2AlwaysFalse => sources.AlwaysFalse
+      case _: V2AlwaysTrue => sources.AlwaysTrue
+      case e: V2EqualNullSafe => sources.EqualNullSafe(e.column.describe,
+        CatalystTypeConverters.convertToScala(e.value.value, e.value.dataType))
+      case equal: V2EqualTo => sources.EqualTo(equal.column.describe,
+        CatalystTypeConverters.convertToScala(equal.value.value, equal.value.dataType))
+      case g: V2GreaterThan => sources.GreaterThan(g.column.describe,
+        CatalystTypeConverters.convertToScala(g.value.value, g.value.dataType))
+      case ge: V2GreaterThanOrEqual => sources.GreaterThanOrEqual(ge.column.describe,
+        CatalystTypeConverters.convertToScala(ge.value.value, ge.value.dataType))
+      case in: V2In =>
+        var array: Array[Any] = Array.empty
+        for (value <- in.values) {
+          array = array :+ CatalystTypeConverters.convertToScala(value.value, value.dataType)
+        }
+        sources.In(in.column.describe, array)
+      case notNull: V2IsNotNull => sources.IsNotNull(notNull.column.describe)
+      case isNull: V2IsNull => sources.IsNull(isNull.column.describe)
+      case l: V2LessThan => sources.LessThan(l.column.describe,
+        CatalystTypeConverters.convertToScala(l.value.value, l.value.dataType))
+      case le: V2LessThanOrEqual => sources.LessThanOrEqual(le.column.describe,
+        CatalystTypeConverters.convertToScala(le.value.value, le.value.dataType))
+      case contains: V2StringContains =>
+        sources.StringContains(contains.column.describe, contains.value.toString)
+      case ends: V2StringEndsWith =>
+        sources.StringEndsWith(ends.column.describe, ends.value.toString)
+      case starts: V2StringStartsWith =>
+        sources.StringStartsWith(starts.column.describe, starts.value.toString)
+      case and: V2And => sources.And(convertV2FilterToV1(and.left), convertV2FilterToV1(and.right))
+      case or: V2Or => sources.Or(convertV2FilterToV1(or.left), convertV2FilterToV1(or.right))
+      case not: V2Not => sources.Not(convertV2FilterToV1(not.child))
+      case _ => throw QueryCompilationErrors.invalidFilter(v2Filter)
Review comment:
ditto
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]