Github user rxin commented on a diff in the pull request:
https://github.com/apache/spark/pull/6775#discussion_r32701067
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
---
@@ -313,3 +312,78 @@ case class StringLength(child: Expression) extends UnaryExpression with ExpectsI
defineCodeGen(ctx, ev, c => s"($c).length()")
}
}
+
+/**
+ * Like Concat below, but with custom separator SEP.
+ */
+case class ConcatWS(children: Expression*) extends Expression {
+ // return type is always String
+ override def dataType: DataType = StringType
+ override def nullable: Boolean = sep.nullable
+ override def foldable: Boolean = children.forall(_.foldable)
+ override def toString: String = s"""CONCAT_WS($children)"""
+ private def sep = children.head
+ private def exprs = children.tail
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ def supportedType(dt: DataType): Boolean = dt match {
+ case ArrayType(StringType, _) => true
+ case ArrayType(NullType, _) => true
+ case StringType => true
+ case NullType => true
+ case _ => false
+ }
+ if (sep.dataType != StringType && sep.dataType != NullType) {
+ TypeCheckResult.TypeCheckFailure(
+ s"type of separator expression in ConcatWS should be string, not ${sep.dataType}")
+ } else if (children.size < 2) {
+ TypeCheckResult.TypeCheckFailure(
+ s"ConcatWS takes at least two arguments")
+ } else if (exprs.exists(expr => !supportedType(expr.dataType))) {
+ TypeCheckResult.TypeCheckFailure(
+ "type of exprs expressions in ConcatWS should be array(string) or string, not" +
+ s" ${exprs.map(_.dataType)}")
+ } else {
+ TypeCheckResult.TypeCheckSuccess
+ }
+ }
+
+ override def eval(input: InternalRow): Any = {
+ val sepEval = sep.eval(input)
+ if (sepEval != null) {
+ val childrenArr = exprs.map(expr => (expr.eval(input), expr.dataType))
+ val separator = sepEval.asInstanceOf[UTF8String].toString
+ val validSeq = childrenArr.filter(_._1 != null).map(child => child._2 match {
+ case StringType => child._1.asInstanceOf[UTF8String].toString
+ case ArrayType(StringType, _) => child._1.asInstanceOf[Seq[UTF8String]].mkString(separator)
+ case ArrayType(NullType, _) => child._1.asInstanceOf[Seq[UTF8String]].mkString(separator)
+ })
+ UTF8String.fromString(validSeq.mkString(separator))
+ } else {
+ null
+ }
+ }
+}
+
+/**
+ * A function that returns the string or bytes resulting from concatenating the strings or bytes
+ * passed in as parameters in order. For example, concat('foo', 'bar') results in 'foobar'. Note
+ * that this function can take any number of input strings.
+ */
+case class Concat(children: Expression*)
+ extends Expression with ExpectsInputTypes {
+ override def dataType: DataType = StringType
+ override def nullable: Boolean = children.exists(_.nullable)
+ override def foldable: Boolean = children.forall(_.foldable)
+ override def expectedChildTypes: Seq[DataType] = Seq.fill(children.size)(StringType)
+ override def toString: String = s"""CONCAT($children)"""
+
+ override def eval(input: InternalRow): Any = {
+ val validSeq = children.map(_.eval(input))
--- End diff --
Doesn't the internal string return UTF8String? If it does, this is calling
toString on every UTF8String -- very inefficient.
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working,
please contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]