Github user rxin commented on a diff in the pull request:

    https://github.com/apache/spark/pull/7534#discussion_r35060597
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
 ---
    @@ -593,58 +593,79 @@ case class Substring(str: Expression, pos: 
Expression, len: Expression)
       override def foldable: Boolean = str.foldable && pos.foldable && 
len.foldable
       override def nullable: Boolean = str.nullable || pos.nullable || 
len.nullable
     
    -  override def dataType: DataType = {
    -    if (!resolved) {
    -      throw new UnresolvedException(this, s"Cannot resolve since $children 
are not resolved")
    -    }
    -    if (str.dataType == BinaryType) str.dataType else StringType
    -  }
    +  override def dataType: DataType = StringType
     
       override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, 
IntegerType)
     
       override def children: Seq[Expression] = str :: pos :: len :: Nil
     
    -  @inline
    -  def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, 
Int) = {
    +  override def eval(input: InternalRow): Any = {
    +    // Information regarding the pos calculation:
         // Hive and SQL use one-based indexing for SUBSTR arguments but also 
accept zero and
         // negative indices for start positions. If a start index i is greater 
than 0, it
         // refers to element i-1 in the sequence. If a start index i is less 
than 0, it refers
         // to the -ith element before the end of the sequence. If a start 
index i is 0, it
         // refers to the first element.
    -
    -    val start = startPos match {
    -      case pos if pos > 0 => pos - 1
    -      case neg if neg < 0 => length() + neg
    -      case _ => 0
    -    }
    -
    -    val end = sliceLen match {
    -      case max if max == Integer.MAX_VALUE => max
    -      case x => start + x
    +    val string = str.eval(input)
    +    if (string != null) {
    +      val po = pos.eval(input)
    +      if (po != null) {
    +        val ln = len.eval(input)
    +        if (ln != null) {
    +          val length = ln.asInstanceOf[Int]
    +          val s = string.asInstanceOf[UTF8String]
    +          val pos = po.asInstanceOf[Int]
    +          val start = {
    +            if (pos > 0) {
    +              pos - 1
    +            } else {
    +              if (pos < 0) s.numChars() + pos else 0
    +            }
    +          }
    +          val end = if (length == Integer.MAX_VALUE) Integer.MAX_VALUE 
else start + length
    +          s.substring(start, end)
    +        } else {
    +          null
    +        }
    +      } else {
    +        null
    +      }
    +    } else {
    +      null
         }
    -
    -    (start, end)
       }
     
    -  override def eval(input: InternalRow): Any = {
    -    val string = str.eval(input)
    -    val po = pos.eval(input)
    -    val ln = len.eval(input)
    -
    -    if ((string == null) || (po == null) || (ln == null)) {
    -      null
    -    } else {
    -      val start = po.asInstanceOf[Int]
    -      val length = ln.asInstanceOf[Int]
    -      string match {
    -        case ba: Array[Byte] =>
    -          val (st, end) = slicePos(start, length, () => ba.length)
    -          ba.slice(st, end)
    -        case s: UTF8String =>
    -          val (st, end) = slicePos(start, length, () => s.numChars())
    -          s.substring(st, end)
    +  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
    +    val strGen = str.gen(ctx)
    +    val posGen = pos.gen(ctx)
    +    val lenGen = len.gen(ctx)
    +
    +    val start = ctx.freshName("start")
    +    val end = ctx.freshName("end")
    +
    +    s"""
    +      ${strGen.code}
    +      boolean ${ev.isNull} = ${strGen.isNull};
    +      ${ctx.javaType(dataType)} ${ev.primitive} = 
${ctx.defaultValue(dataType)};
    +      if (!${ev.isNull}) {
    +        ${posGen.code}
    +        if (!${posGen.isNull}) {
    +          ${lenGen.code}
    +          if (!${lenGen.isNull}) {
    +            int $start = (${posGen.primitive} > 0) ? ${posGen.primitive} - 
1 :
    --- End diff --
    
    Maybe we can create a substring variant in UTF8String that encapsulates this logic? It could be something like substringSQL.
    
    The goal is to minimize the duplicated code between interpreted version and 
code gen version.
    



---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes to enable it, or if the feature is enabled but not working,
please contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to