[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user mn-mikke commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r218917303 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -665,3 +667,219 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti override def prettyName: String = "element_at" } + +/** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", + examples = """ +Examples: + > SELECT _FUNC_('Spark', 'SQL'); + SparkSQL + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + | [1,2,3,4,5,6] + """) +case class Concat(children: Seq[Expression]) extends Expression { + + private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + val allowedTypes = Seq(StringType, BinaryType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { +if (children.isEmpty) { + TypeCheckResult.TypeCheckSuccess +} else { + val childTypes = children.map(_.dataType) + if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe { +return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + +s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") +} + } + + override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + + lazy val javaType: String = CodeGenerator.javaType(dataType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = dataType match { --- End diff -- Thanks! 
I've created #22471 so that the pattern match is performed only once. What do you think about [Reverse](https://github.com/apache/spark/pull/21034/files#diff-9853dcf5ce3d2ac1e94d473197ff5768R240)? It looks like a similar problem. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user rxin commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r218677837 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -665,3 +667,219 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti override def prettyName: String = "element_at" } + +/** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", + examples = """ +Examples: + > SELECT _FUNC_('Spark', 'SQL'); + SparkSQL + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + | [1,2,3,4,5,6] + """) +case class Concat(children: Seq[Expression]) extends Expression { + + private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + val allowedTypes = Seq(StringType, BinaryType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { +if (children.isEmpty) { + TypeCheckResult.TypeCheckSuccess +} else { + val childTypes = children.map(_.dataType) + if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe { +return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + +s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") +} + } + + override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + + lazy val javaType: String = CodeGenerator.javaType(dataType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = dataType match { --- End diff -- so this pattern match will probably cause significant 
regression in the interpreted (non-codegen) mode, due to the way Scala pattern matching is implemented. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user asfgit closed the pull request at: https://github.com/apache/spark/pull/20858 --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user mn-mikke commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r182357725 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -353,3 +356,218 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast override def prettyName: String = "array_max" } + +/** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", + examples = """ +Examples: + > SELECT _FUNC_('Spark', 'SQL'); + SparkSQL + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + | [1,2,3,4,5,6] + """) +case class Concat(children: Seq[Expression]) extends Expression { + + private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + val allowedTypes = Seq(StringType, BinaryType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { +if (children.isEmpty) { + TypeCheckResult.TypeCheckSuccess +} else { + val childTypes = children.map(_.dataType) + if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe { +return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + +s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") +} + } + + override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + + lazy val javaType: String = CodeGenerator.javaType(dataType) --- End diff -- Good point! But I think it would be better to reuse `javaType` also in `genCodeForPrimitiveArrays` and `genCodeForNonPrimitiveArrays`. 
--- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r182350135 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -353,3 +356,218 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast override def prettyName: String = "array_max" } + +/** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", + examples = """ +Examples: + > SELECT _FUNC_('Spark', 'SQL'); + SparkSQL + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + | [1,2,3,4,5,6] + """) +case class Concat(children: Seq[Expression]) extends Expression { + + private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + val allowedTypes = Seq(StringType, BinaryType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { +if (children.isEmpty) { + TypeCheckResult.TypeCheckSuccess +} else { + val childTypes = children.map(_.dataType) + if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe { +return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + +s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") +} + } + + override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + + lazy val javaType: String = CodeGenerator.javaType(dataType) + + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = dataType match { +case BinaryType => + val inputs = 
children.map(_.eval(input).asInstanceOf[Array[Byte]]) + ByteArray.concat(inputs: _*) +case StringType => + val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) + UTF8String.concat(inputs : _*) +case ArrayType(elementType, _) => + val inputs = children.toStream.map(_.eval(input)) + if (inputs.contains(null)) { +null + } else { +val arrayData = inputs.map(_.asInstanceOf[ArrayData]) +val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements()) +if (numberOfElements > MAX_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" + +s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") +} +val finalData = new Array[AnyRef](numberOfElements.toInt) +var position = 0 +for(ad <- arrayData) { + val arr = ad.toObjectArray(elementType) + Array.copy(arr, 0, finalData, position, arr.length) + position += arr.length +} +new GenericArrayData(finalData) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +val evals = children.map(_.genCode(ctx)) +val args = ctx.freshName("args") + +val inputs = evals.zipWithIndex.map { case (eval, index) => + s""" +${eval.code} +if (!${eval.isNull}) { + $args[$index] = ${eval.value}; +} + """ +} + +val (concatenator, initCode) = dataType match { + case BinaryType => +(classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") + case StringType => +("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") + case ArrayType(elementType, _) => +val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForPrimitiveArrays(ctx, elementType) +} else { + genCodeForNonPrimitiveArrays(ctx, elementType) +} +(arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") +} +val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = inputs, + funcName = "valueConcat", + extraArguments = (s"${javaType}[]", args) :: Nil) +ev.copy(s""" + $initCode + 
$codes + ${javaType} ${ev.value} = $concatenator.concat($args); --- End diff -- ni
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r182349064 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -353,3 +356,218 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast override def prettyName: String = "array_max" } + +/** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", + examples = """ +Examples: + > SELECT _FUNC_('Spark', 'SQL'); + SparkSQL + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + | [1,2,3,4,5,6] + """) +case class Concat(children: Seq[Expression]) extends Expression { + + private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + val allowedTypes = Seq(StringType, BinaryType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { +if (children.isEmpty) { + TypeCheckResult.TypeCheckSuccess +} else { + val childTypes = children.map(_.dataType) + if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe { +return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + +s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") +} + } + + override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + + lazy val javaType: String = CodeGenerator.javaType(dataType) --- End diff -- We can move this into `doGenCode()` method. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user mn-mikke commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r181740076 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +290,231 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + +/** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", + examples = """ +Examples: + > SELECT _FUNC_('Spark', 'SQL'); + SparkSQL + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + | [1,2,3,4,5,6] + """) +case class Concat(children: Seq[Expression]) extends Expression { + + val allowedTypes = Seq(StringType, BinaryType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { +if (children.isEmpty) { + TypeCheckResult.TypeCheckSuccess +} else { + val childTypes = children.map(_.dataType) + if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe { +return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + +s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") +} + } + + override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + + lazy val javaType: String = CodeGenerator.javaType(dataType) + + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = dataType match { +case BinaryType => + val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) + ByteArray.concat(inputs: _*) +case StringType => + val inputs = 
children.map(_.eval(input).asInstanceOf[UTF8String]) + UTF8String.concat(inputs : _*) +case ArrayType(elementType, _) => + val inputs = children.toStream.map(_.eval(input)) + if (inputs.contains(null)) { +null + } else { +val elements = inputs.flatMap(_.asInstanceOf[ArrayData].toObjectArray(elementType)) +new GenericArrayData(elements) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +val evals = children.map(_.genCode(ctx)) +val args = ctx.freshName("args") + +val inputs = evals.zipWithIndex.map { case (eval, index) => + s""" +${eval.code} +if (!${eval.isNull}) { + $args[$index] = ${eval.value}; +} + """ +} + +val (concatenator, initCode) = dataType match { + case BinaryType => +(classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") + case StringType => +("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") + case ArrayType(elementType, _) => +val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForPrimitiveArrayConcat(ctx, elementType) +} else { + genCodeForComplexArrayConcat(ctx, elementType) +} +(arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") +} +val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = inputs, + funcName = "valueConcat", + extraArguments = (s"${javaType}[]", args) :: Nil) +ev.copy(s""" + $initCode + $codes + ${javaType} ${ev.value} = $concatenator.concat($args); + boolean ${ev.isNull} = ${ev.value} == null; +""") + } + + private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { +val tempVariableName = ctx.freshName("tempNumElements") +val numElementsConstant = ctx.freshName("numElements") +val assignments = (0 until children.length) + .map(idx => s"$tempVariableName[0] += args[$idx].numElements();") + +val assignmentSection = ctx.splitExpressions( + expressions = assignments, + funcName = "complexArrayConcat", + arguments = Seq((s"${javaType}[]", "args"), ("int[]", 
tempVariableName))) + +(s""" +|int[] $tempVariableName = new int[]{0}; +|$assignmentSection
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r181640349 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +290,231 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + +/** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", + examples = """ +Examples: + > SELECT _FUNC_('Spark', 'SQL'); + SparkSQL + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + | [1,2,3,4,5,6] + """) +case class Concat(children: Seq[Expression]) extends Expression { + + val allowedTypes = Seq(StringType, BinaryType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { +if (children.isEmpty) { + TypeCheckResult.TypeCheckSuccess +} else { + val childTypes = children.map(_.dataType) + if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe { +return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + +s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") +} + } + + override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + + lazy val javaType: String = CodeGenerator.javaType(dataType) + + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = dataType match { +case BinaryType => + val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) + ByteArray.concat(inputs: _*) +case StringType => + val inputs = 
children.map(_.eval(input).asInstanceOf[UTF8String]) + UTF8String.concat(inputs : _*) +case ArrayType(elementType, _) => + val inputs = children.toStream.map(_.eval(input)) + if (inputs.contains(null)) { +null + } else { +val elements = inputs.flatMap(_.asInstanceOf[ArrayData].toObjectArray(elementType)) +new GenericArrayData(elements) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +val evals = children.map(_.genCode(ctx)) +val args = ctx.freshName("args") + +val inputs = evals.zipWithIndex.map { case (eval, index) => + s""" +${eval.code} +if (!${eval.isNull}) { + $args[$index] = ${eval.value}; +} + """ +} + +val (concatenator, initCode) = dataType match { + case BinaryType => +(classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") + case StringType => +("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") + case ArrayType(elementType, _) => +val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForPrimitiveArrayConcat(ctx, elementType) +} else { + genCodeForComplexArrayConcat(ctx, elementType) +} +(arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") +} +val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = inputs, + funcName = "valueConcat", + extraArguments = (s"${javaType}[]", args) :: Nil) +ev.copy(s""" + $initCode + $codes + ${javaType} ${ev.value} = $concatenator.concat($args); + boolean ${ev.isNull} = ${ev.value} == null; +""") + } + + private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { +val tempVariableName = ctx.freshName("tempNumElements") +val numElementsConstant = ctx.freshName("numElements") +val assignments = (0 until children.length) + .map(idx => s"$tempVariableName[0] += args[$idx].numElements();") + +val assignmentSection = ctx.splitExpressions( + expressions = assignments, + funcName = "complexArrayConcat", + arguments = Seq((s"${javaType}[]", "args"), ("int[]", 
tempVariableName))) + +(s""" +|int[] $tempVariableName = new int[]{0}; +|$assignmentSection +
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r181638570 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +290,231 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + +/** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", + examples = """ +Examples: + > SELECT _FUNC_('Spark', 'SQL'); + SparkSQL + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + | [1,2,3,4,5,6] + """) +case class Concat(children: Seq[Expression]) extends Expression { + + val allowedTypes = Seq(StringType, BinaryType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { +if (children.isEmpty) { + TypeCheckResult.TypeCheckSuccess +} else { + val childTypes = children.map(_.dataType) + if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe { +return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + +s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") +} + } + + override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + + lazy val javaType: String = CodeGenerator.javaType(dataType) + + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = dataType match { +case BinaryType => + val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) + ByteArray.concat(inputs: _*) +case StringType => + val inputs = 
children.map(_.eval(input).asInstanceOf[UTF8String]) + UTF8String.concat(inputs : _*) +case ArrayType(elementType, _) => + val inputs = children.toStream.map(_.eval(input)) + if (inputs.contains(null)) { +null + } else { +val elements = inputs.flatMap(_.asInstanceOf[ArrayData].toObjectArray(elementType)) +new GenericArrayData(elements) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +val evals = children.map(_.genCode(ctx)) +val args = ctx.freshName("args") + +val inputs = evals.zipWithIndex.map { case (eval, index) => + s""" +${eval.code} +if (!${eval.isNull}) { + $args[$index] = ${eval.value}; +} + """ +} + +val (concatenator, initCode) = dataType match { + case BinaryType => +(classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") + case StringType => +("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") + case ArrayType(elementType, _) => +val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForPrimitiveArrayConcat(ctx, elementType) +} else { + genCodeForComplexArrayConcat(ctx, elementType) +} +(arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") +} +val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = inputs, + funcName = "valueConcat", + extraArguments = (s"${javaType}[]", args) :: Nil) +ev.copy(s""" + $initCode + $codes + ${javaType} ${ev.value} = $concatenator.concat($args); + boolean ${ev.isNull} = ${ev.value} == null; +""") + } + + private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { +val tempVariableName = ctx.freshName("tempNumElements") +val numElementsConstant = ctx.freshName("numElements") +val assignments = (0 until children.length) + .map(idx => s"$tempVariableName[0] += args[$idx].numElements();") + +val assignmentSection = ctx.splitExpressions( + expressions = assignments, + funcName = "complexArrayConcat", + arguments = Seq((s"${javaType}[]", "args"), ("int[]", 
tempVariableName))) + +(s""" +|int[] $tempVariableName = new int[]{0}; +|$assignmentSection +
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r181643397 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +290,231 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + +/** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", + examples = """ +Examples: + > SELECT _FUNC_('Spark', 'SQL'); + SparkSQL + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + | [1,2,3,4,5,6] + """) +case class Concat(children: Seq[Expression]) extends Expression { + + val allowedTypes = Seq(StringType, BinaryType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { +if (children.isEmpty) { + TypeCheckResult.TypeCheckSuccess +} else { + val childTypes = children.map(_.dataType) + if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe { +return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + +s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") +} + } + + override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + + lazy val javaType: String = CodeGenerator.javaType(dataType) + + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = dataType match { +case BinaryType => + val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) + ByteArray.concat(inputs: _*) +case StringType => + val inputs = 
children.map(_.eval(input).asInstanceOf[UTF8String]) + UTF8String.concat(inputs : _*) +case ArrayType(elementType, _) => + val inputs = children.toStream.map(_.eval(input)) + if (inputs.contains(null)) { +null + } else { +val elements = inputs.flatMap(_.asInstanceOf[ArrayData].toObjectArray(elementType)) +new GenericArrayData(elements) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +val evals = children.map(_.genCode(ctx)) +val args = ctx.freshName("args") + +val inputs = evals.zipWithIndex.map { case (eval, index) => + s""" +${eval.code} +if (!${eval.isNull}) { + $args[$index] = ${eval.value}; +} + """ +} + +val (concatenator, initCode) = dataType match { + case BinaryType => +(classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") + case StringType => +("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") + case ArrayType(elementType, _) => +val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForPrimitiveArrayConcat(ctx, elementType) +} else { + genCodeForComplexArrayConcat(ctx, elementType) +} +(arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") +} +val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = inputs, + funcName = "valueConcat", + extraArguments = (s"${javaType}[]", args) :: Nil) +ev.copy(s""" + $initCode + $codes + ${javaType} ${ev.value} = $concatenator.concat($args); + boolean ${ev.isNull} = ${ev.value} == null; +""") + } + + private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { +val tempVariableName = ctx.freshName("tempNumElements") +val numElementsConstant = ctx.freshName("numElements") +val assignments = (0 until children.length) + .map(idx => s"$tempVariableName[0] += args[$idx].numElements();") + +val assignmentSection = ctx.splitExpressions( + expressions = assignments, + funcName = "complexArrayConcat", + arguments = Seq((s"${javaType}[]", "args"), ("int[]", 
tempVariableName))) + +(s""" +|int[] $tempVariableName = new int[]{0}; +|$assignmentSection +
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r181641673 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +290,231 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + +/** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", + examples = """ +Examples: + > SELECT _FUNC_('Spark', 'SQL'); + SparkSQL + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + | [1,2,3,4,5,6] + """) +case class Concat(children: Seq[Expression]) extends Expression { + + val allowedTypes = Seq(StringType, BinaryType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { +if (children.isEmpty) { + TypeCheckResult.TypeCheckSuccess +} else { + val childTypes = children.map(_.dataType) + if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe { +return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + +s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") +} + } + + override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + + lazy val javaType: String = CodeGenerator.javaType(dataType) + + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = dataType match { +case BinaryType => + val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) + ByteArray.concat(inputs: _*) +case StringType => + val inputs = 
children.map(_.eval(input).asInstanceOf[UTF8String]) + UTF8String.concat(inputs : _*) +case ArrayType(elementType, _) => + val inputs = children.toStream.map(_.eval(input)) + if (inputs.contains(null)) { +null + } else { +val elements = inputs.flatMap(_.asInstanceOf[ArrayData].toObjectArray(elementType)) +new GenericArrayData(elements) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +val evals = children.map(_.genCode(ctx)) +val args = ctx.freshName("args") + +val inputs = evals.zipWithIndex.map { case (eval, index) => + s""" +${eval.code} +if (!${eval.isNull}) { + $args[$index] = ${eval.value}; +} + """ +} + +val (concatenator, initCode) = dataType match { + case BinaryType => +(classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") + case StringType => +("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") + case ArrayType(elementType, _) => +val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForPrimitiveArrayConcat(ctx, elementType) +} else { + genCodeForComplexArrayConcat(ctx, elementType) +} +(arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") +} +val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = inputs, + funcName = "valueConcat", + extraArguments = (s"${javaType}[]", args) :: Nil) +ev.copy(s""" + $initCode + $codes + ${javaType} ${ev.value} = $concatenator.concat($args); + boolean ${ev.isNull} = ${ev.value} == null; +""") + } + + private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { +val tempVariableName = ctx.freshName("tempNumElements") +val numElementsConstant = ctx.freshName("numElements") +val assignments = (0 until children.length) + .map(idx => s"$tempVariableName[0] += args[$idx].numElements();") + +val assignmentSection = ctx.splitExpressions( + expressions = assignments, + funcName = "complexArrayConcat", + arguments = Seq((s"${javaType}[]", "args"), ("int[]", 
tempVariableName))) + +(s""" +|int[] $tempVariableName = new int[]{0}; +|$assignmentSection +
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r181639154 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +290,231 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + +/** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", + examples = """ +Examples: + > SELECT _FUNC_('Spark', 'SQL'); + SparkSQL + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + | [1,2,3,4,5,6] + """) +case class Concat(children: Seq[Expression]) extends Expression { + + val allowedTypes = Seq(StringType, BinaryType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { +if (children.isEmpty) { + TypeCheckResult.TypeCheckSuccess +} else { + val childTypes = children.map(_.dataType) + if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe { +return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + +s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") +} + } + + override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + + lazy val javaType: String = CodeGenerator.javaType(dataType) + + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = dataType match { +case BinaryType => + val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) + ByteArray.concat(inputs: _*) +case StringType => + val inputs = 
children.map(_.eval(input).asInstanceOf[UTF8String]) + UTF8String.concat(inputs : _*) +case ArrayType(elementType, _) => + val inputs = children.toStream.map(_.eval(input)) + if (inputs.contains(null)) { +null + } else { +val elements = inputs.flatMap(_.asInstanceOf[ArrayData].toObjectArray(elementType)) +new GenericArrayData(elements) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +val evals = children.map(_.genCode(ctx)) +val args = ctx.freshName("args") + +val inputs = evals.zipWithIndex.map { case (eval, index) => + s""" +${eval.code} +if (!${eval.isNull}) { + $args[$index] = ${eval.value}; +} + """ +} + +val (concatenator, initCode) = dataType match { + case BinaryType => +(classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") + case StringType => +("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") + case ArrayType(elementType, _) => +val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForPrimitiveArrayConcat(ctx, elementType) +} else { + genCodeForComplexArrayConcat(ctx, elementType) +} +(arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") +} +val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = inputs, + funcName = "valueConcat", + extraArguments = (s"${javaType}[]", args) :: Nil) +ev.copy(s""" + $initCode + $codes + ${javaType} ${ev.value} = $concatenator.concat($args); + boolean ${ev.isNull} = ${ev.value} == null; +""") + } + + private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { +val tempVariableName = ctx.freshName("tempNumElements") +val numElementsConstant = ctx.freshName("numElements") +val assignments = (0 until children.length) + .map(idx => s"$tempVariableName[0] += args[$idx].numElements();") + +val assignmentSection = ctx.splitExpressions( + expressions = assignments, + funcName = "complexArrayConcat", + arguments = Seq((s"${javaType}[]", "args"), ("int[]", 
tempVariableName))) + +(s""" +|int[] $tempVariableName = new int[]{0}; +|$assignmentSection +
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r181359247 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +290,231 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + +/** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", + examples = """ +Examples: + > SELECT _FUNC_('Spark', 'SQL'); + SparkSQL + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + | [1,2,3,4,5,6] + """) +case class Concat(children: Seq[Expression]) extends Expression { + + val allowedTypes = Seq(StringType, BinaryType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { +if (children.isEmpty) { + TypeCheckResult.TypeCheckSuccess +} else { + val childTypes = children.map(_.dataType) + if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe { +return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + +s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") +} + } + + override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + + lazy val javaType: String = CodeGenerator.javaType(dataType) + + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = dataType match { +case BinaryType => + val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) + ByteArray.concat(inputs: _*) +case StringType => + val inputs = 
children.map(_.eval(input).asInstanceOf[UTF8String]) + UTF8String.concat(inputs : _*) +case ArrayType(elementType, _) => + val inputs = children.toStream.map(_.eval(input)) + if (inputs.contains(null)) { +null + } else { +val elements = inputs.flatMap(_.asInstanceOf[ArrayData].toObjectArray(elementType)) --- End diff -- Can we always allocate an array? I think that the total array element size may overflow in some cases. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r181355000 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +290,231 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + +/** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", + examples = """ +Examples: + > SELECT _FUNC_('Spark', 'SQL'); + SparkSQL + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + | [1,2,3,4,5,6] + """) +case class Concat(children: Seq[Expression]) extends Expression { + + val allowedTypes = Seq(StringType, BinaryType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { +if (children.isEmpty) { + TypeCheckResult.TypeCheckSuccess +} else { + val childTypes = children.map(_.dataType) + if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe { +return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + +s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") +} + } + + override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + + lazy val javaType: String = CodeGenerator.javaType(dataType) + + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = dataType match { +case BinaryType => + val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) + ByteArray.concat(inputs: _*) +case StringType => + val inputs = 
children.map(_.eval(input).asInstanceOf[UTF8String]) + UTF8String.concat(inputs : _*) +case ArrayType(elementType, _) => + val inputs = children.toStream.map(_.eval(input)) + if (inputs.contains(null)) { +null + } else { +val elements = inputs.flatMap(_.asInstanceOf[ArrayData].toObjectArray(elementType)) +new GenericArrayData(elements) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +val evals = children.map(_.genCode(ctx)) +val args = ctx.freshName("args") + +val inputs = evals.zipWithIndex.map { case (eval, index) => + s""" +${eval.code} +if (!${eval.isNull}) { + $args[$index] = ${eval.value}; +} + """ +} + +val (concatenator, initCode) = dataType match { + case BinaryType => +(classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") + case StringType => +("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") + case ArrayType(elementType, _) => +val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForPrimitiveArrayConcat(ctx, elementType) +} else { + genCodeForComplexArrayConcat(ctx, elementType) +} +(arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") +} +val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = inputs, + funcName = "valueConcat", + extraArguments = (s"${javaType}[]", args) :: Nil) +ev.copy(s""" + $initCode + $codes + ${javaType} ${ev.value} = $concatenator.concat($args); + boolean ${ev.isNull} = ${ev.value} == null; +""") + } + + private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { +val tempVariableName = ctx.freshName("tempNumElements") +val numElementsConstant = ctx.freshName("numElements") +val assignments = (0 until children.length) + .map(idx => s"$tempVariableName[0] += args[$idx].numElements();") + +val assignmentSection = ctx.splitExpressions( + expressions = assignments, + funcName = "complexArrayConcat", + arguments = Seq((s"${javaType}[]", "args"), ("int[]", 
tempVariableName))) + +(s""" +|int[] $tempVariableName = new int[]{0}; +|$assignmentSection +
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r180181355 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +290,191 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + +/** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary and compatible array columns. + */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.", + examples = """ +Examples: + > SELECT _FUNC_('Spark', 'SQL'); + SparkSQL + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + | [1,2,3,4,5,6] + """) +case class Concat(children: Seq[Expression]) extends Expression { + + val allowedTypes = Seq(StringType, BinaryType, ArrayType) + + override def checkInputDataTypes(): TypeCheckResult = { +if (children.isEmpty) { + TypeCheckResult.TypeCheckSuccess +} else { + val childTypes = children.map(_.dataType) + if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe { +return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have been StringType, BinaryType or ArrayType," + +s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName") +} + } + + override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType) + + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + + override def eval(input: InternalRow): Any = dataType match { +case BinaryType => + val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]]) + ByteArray.concat(inputs: _*) +case StringType => + val inputs = children.map(_.eval(input).asInstanceOf[UTF8String]) + 
UTF8String.concat(inputs : _*) +case ArrayType(elementType, _) => + val inputs = children.toStream.map(_.eval(input)) + if (inputs.contains(null)) { +null + } else { +val elements = inputs.flatMap(_.asInstanceOf[ArrayData].toObjectArray(elementType)) +new GenericArrayData(elements) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +val evals = children.map(_.genCode(ctx)) +val args = ctx.freshName("args") + +val inputs = evals.zipWithIndex.map { case (eval, index) => + s""" +${eval.code} +if (!${eval.isNull}) { + $args[$index] = ${eval.value}; +} + """ +} + +val (concatenator, initCode) = dataType match { + case BinaryType => +(classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") + case StringType => +("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") + case ArrayType(elementType, _) => +val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForPrimitiveArrayConcat(ctx, elementType) +} else { + genCodeForComplexArrayConcat(ctx) +} +(arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") +} +val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = inputs, + funcName = "valueConcat", + extraArguments = (s"${CodeGenerator.javaType(dataType)}[]", args) :: Nil) +ev.copy(s""" + $initCode + $codes + ${CodeGenerator.javaType(dataType)} ${ev.value} = $concatenator.concat($args); + boolean ${ev.isNull} = ${ev.value} == null; +""") + } + + private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { +val variableName = ctx.freshName("numElements") +val code = (0 until children.length) + .map(idx => s"$variableName += args[$idx].numElements();") + .foldLeft(s"int $variableName = 0;")((acc, s) => acc + "\n" + s) +(code, variableName) + } + + private def nullArgumentProtection() : String = { +children.zipWithIndex + .filter(_._1.nullable) + .map(ci => s"if (args[${ci._2}] == null) return null;") + .mkString("\n") 
+ } + + private def genCodeForPrimitiveArrayConcat(ctx: CodegenContext, elementType: DataType): String = { +
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user mn-mikke commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r178759909 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/functions.scala --- @@ -3046,6 +3036,16 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary columns and arrays of the same time. + * + * @group collection_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def concat(exprs: Column*): Column = withExpr { UnresolvedConcat(exprs.map(_.expr)) } --- End diff -- @gatorsmile, @HyukjinKwon What is your view on this? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user mn-mikke commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r178759108 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +289,152 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + +/** + * Concatenates multiple arrays into one. + */ +@ExpressionDescription( + usage = "_FUNC_(expr, ...) - Concatenates multiple arrays into one.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + [1,2,3,4,5,6] + """) +case class ConcatArrays(children: Seq[Expression]) extends Expression with NullSafeEvaluation { + + override def checkInputDataTypes(): TypeCheckResult = { +val arrayCheck = checkInputDataTypesAreArrays +if(arrayCheck.isFailure) arrayCheck +else TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName") + } + + private def checkInputDataTypesAreArrays(): TypeCheckResult = + { +val mismatches = children.zipWithIndex.collect { + case (child, idx) if !ArrayType.acceptsType(child.dataType) => +s"argument ${idx + 1} has to be ${ArrayType.simpleString} type, " + + s"however, '${child.sql}' is of ${child.dataType.simpleString} type." 
+} + +if (mismatches.isEmpty) { + TypeCheckResult.TypeCheckSuccess +} else { + TypeCheckResult.TypeCheckFailure(mismatches.mkString(" ")) +} + } + + override def dataType: ArrayType = +children + .headOption.map(_.dataType.asInstanceOf[ArrayType]) + .getOrElse(ArrayType.defaultConcreteType.asInstanceOf[ArrayType]) + + + override protected def nullSafeEval(inputs: Seq[Any]): Any = { +val elements = inputs.flatMap(_.asInstanceOf[ArrayData].toObjectArray(dataType.elementType)) +new GenericArrayData(elements) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +nullSafeCodeGen(ctx, ev, arrays => { + val elementType = dataType.elementType + if (CodeGenerator.isPrimitiveType(elementType)) { +genCodeForConcatOfPrimitiveElements(ctx, elementType, arrays, ev.value) + } else { +genCodeForConcatOfComplexElements(ctx, arrays, ev.value) + } +}) + } + + private def genCodeForNumberOfElements( +ctx: CodegenContext, +elements: Seq[String] + ) : (String, String) = { +val variableName = ctx.freshName("numElements") +val code = elements + .map(el => s"$variableName += $el.numElements();") + .foldLeft( s"int $variableName = 0;")((acc, s) => acc + "\n" + s) +(code, variableName) + } + + private def genCodeForConcatOfPrimitiveElements( +ctx: CodegenContext, +elementType: DataType, +elements: Seq[String], +arrayDataName: String + ): String = { +val arrayName = ctx.freshName("array") +val arraySizeName = ctx.freshName("size") +val counter = ctx.freshName("counter") +val tempArrayDataName = ctx.freshName("tempArrayData") + +val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, elements) + +val unsafeArraySizeInBytes = s""" + |int $arraySizeName = UnsafeArrayData.calculateHeaderPortionInBytes($numElemName) + + |${classOf[ByteArrayMethods].getName}.roundNumberOfBytesToNearestWord( + |${elementType.defaultSize} * $numElemName + |); + """.stripMargin +val baseOffset = Platform.BYTE_ARRAY_OFFSET + +val primitiveValueTypeName = 
CodeGenerator.primitiveTypeName(elementType) +val assignments = elements.map { el => + s""" +|for(int z = 0; z < $el.numElements(); z++) { +| if($el.isNullAt(z)) { +| $tempArrayDataName.setNullAt($counter); +| } else { +| $tempArrayDataName.set$primitiveValueTypeName( +| $counter, +| $el.get$primitiveValueTypeName(z) +| ); +| } +| $counter++; +|} +""".stripMargin +}.mkString("\n") + +s""" + |$numElemCode + |$unsafeArraySizeInBytes + |byte[] $arrayName = new byte[$arraySizeName]; + |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData(); + |Platform.putLong($arrayName, $baseOffset, $numElemName); + |$tempArrayDataName.pointTo($arrayName, $baseOffset, $arraySizeName); + |int $counter = 0; + |$assignments + |$arrayDataName = $tempArrayDataName; +""".stripMargin
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user mn-mikke commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r178753211 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala --- @@ -699,3 +699,88 @@ abstract class TernaryExpression extends Expression { * and Hive function wrappers. */ trait UserDefinedExpression + +/** + * The trait covers logic for performing null save evaluation and code generation. + */ +trait NullSafeEvaluation extends Expression +{ + override def foldable: Boolean = children.forall(_.foldable) + + override def nullable: Boolean = children.exists(_.nullable) + + /** + * Default behavior of evaluation according to the default nullability of NullSafeEvaluation. + * If a class utilizing NullSaveEvaluation override [[nullable]], probably should also + * override this. + */ + override def eval(input: InternalRow): Any = + { +val values = children.map(_.eval(input)) +if (values.contains(null)) null +else nullSafeEval(values) + } + + /** + * Called by default [[eval]] implementation. If a class utilizing NullSaveEvaluation keep + * the default nullability, they can override this method to save null-check code. If we need + * full control of evaluation process, we should override [[eval]]. + */ + protected def nullSafeEval(inputs: Seq[Any]): Any = +sys.error(s"The class utilizing NullSaveEvaluation must override either eval or nullSafeEval") + + /** + * Short hand for generating of null save evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f accepts a sequence of variable names and returns Java code to compute the output. + */ + protected def defineCodeGen( +ctx: CodegenContext, +ev: ExprCode, +f: Seq[String] => String): ExprCode = { +nullSafeCodeGen(ctx, ev, values => { + s"${ev.value} = ${f(values)};" +}) + } + + /** + * Called by expressions to generate null safe evaluation code. 
+ * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f a function that accepts a sequence of non-null evaluation result names of children + * and returns Java code to compute the output. + */ + protected def nullSafeCodeGen( --- End diff -- Ok, will try. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user maropu commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r178711505 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +289,152 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + +/** + * Concatenates multiple arrays into one. + */ +@ExpressionDescription( + usage = "_FUNC_(expr, ...) - Concatenates multiple arrays into one.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + [1,2,3,4,5,6] + """) +case class ConcatArrays(children: Seq[Expression]) extends Expression with NullSafeEvaluation { + + override def checkInputDataTypes(): TypeCheckResult = { +val arrayCheck = checkInputDataTypesAreArrays +if(arrayCheck.isFailure) arrayCheck +else TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName") + } + + private def checkInputDataTypesAreArrays(): TypeCheckResult = + { +val mismatches = children.zipWithIndex.collect { + case (child, idx) if !ArrayType.acceptsType(child.dataType) => +s"argument ${idx + 1} has to be ${ArrayType.simpleString} type, " + + s"however, '${child.sql}' is of ${child.dataType.simpleString} type." 
+} + +if (mismatches.isEmpty) { + TypeCheckResult.TypeCheckSuccess +} else { + TypeCheckResult.TypeCheckFailure(mismatches.mkString(" ")) +} + } + + override def dataType: ArrayType = +children + .headOption.map(_.dataType.asInstanceOf[ArrayType]) + .getOrElse(ArrayType.defaultConcreteType.asInstanceOf[ArrayType]) + + + override protected def nullSafeEval(inputs: Seq[Any]): Any = { +val elements = inputs.flatMap(_.asInstanceOf[ArrayData].toObjectArray(dataType.elementType)) +new GenericArrayData(elements) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +nullSafeCodeGen(ctx, ev, arrays => { + val elementType = dataType.elementType + if (CodeGenerator.isPrimitiveType(elementType)) { +genCodeForConcatOfPrimitiveElements(ctx, elementType, arrays, ev.value) + } else { +genCodeForConcatOfComplexElements(ctx, arrays, ev.value) + } +}) + } + + private def genCodeForNumberOfElements( +ctx: CodegenContext, +elements: Seq[String] + ) : (String, String) = { +val variableName = ctx.freshName("numElements") +val code = elements + .map(el => s"$variableName += $el.numElements();") + .foldLeft( s"int $variableName = 0;")((acc, s) => acc + "\n" + s) +(code, variableName) + } + + private def genCodeForConcatOfPrimitiveElements( +ctx: CodegenContext, +elementType: DataType, +elements: Seq[String], +arrayDataName: String + ): String = { +val arrayName = ctx.freshName("array") +val arraySizeName = ctx.freshName("size") +val counter = ctx.freshName("counter") +val tempArrayDataName = ctx.freshName("tempArrayData") + +val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, elements) + +val unsafeArraySizeInBytes = s""" + |int $arraySizeName = UnsafeArrayData.calculateHeaderPortionInBytes($numElemName) + + |${classOf[ByteArrayMethods].getName}.roundNumberOfBytesToNearestWord( + |${elementType.defaultSize} * $numElemName + |); + """.stripMargin +val baseOffset = Platform.BYTE_ARRAY_OFFSET + +val primitiveValueTypeName = 
CodeGenerator.primitiveTypeName(elementType) +val assignments = elements.map { el => + s""" +|for(int z = 0; z < $el.numElements(); z++) { +| if($el.isNullAt(z)) { +| $tempArrayDataName.setNullAt($counter); +| } else { +| $tempArrayDataName.set$primitiveValueTypeName( +| $counter, +| $el.get$primitiveValueTypeName(z) +| ); +| } +| $counter++; +|} +""".stripMargin +}.mkString("\n") + +s""" + |$numElemCode + |$unsafeArraySizeInBytes + |byte[] $arrayName = new byte[$arraySizeName]; + |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData(); + |Platform.putLong($arrayName, $baseOffset, $numElemName); + |$tempArrayDataName.pointTo($arrayName, $baseOffset, $arraySizeName); + |int $counter = 0; + |$assignments + |$arrayDataName = $tempArrayDataName; +""".stripMargin +
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user maropu commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r178710167 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala --- @@ -700,3 +700,88 @@ abstract class TernaryExpression extends Expression { * and Hive function wrappers. */ trait UserDefinedExpression + +/** + * The trait covers logic for performing null safe evaluation and code generation. + */ +trait NullSafeEvaluation extends Expression --- End diff -- nit: `trait NullSafeEvaluation extends Expression {` --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user maropu commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r178708207 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala --- @@ -699,3 +699,88 @@ abstract class TernaryExpression extends Expression { * and Hive function wrappers. */ trait UserDefinedExpression + +/** + * The trait covers logic for performing null save evaluation and code generation. + */ +trait NullSafeEvaluation extends Expression +{ + override def foldable: Boolean = children.forall(_.foldable) + + override def nullable: Boolean = children.exists(_.nullable) + + /** + * Default behavior of evaluation according to the default nullability of NullSafeEvaluation. + * If a class utilizing NullSaveEvaluation override [[nullable]], probably should also + * override this. + */ + override def eval(input: InternalRow): Any = + { +val values = children.map(_.eval(input)) +if (values.contains(null)) null +else nullSafeEval(values) + } + + /** + * Called by default [[eval]] implementation. If a class utilizing NullSaveEvaluation keep + * the default nullability, they can override this method to save null-check code. If we need + * full control of evaluation process, we should override [[eval]]. + */ + protected def nullSafeEval(inputs: Seq[Any]): Any = +sys.error(s"The class utilizing NullSaveEvaluation must override either eval or nullSafeEval") + + /** + * Short hand for generating of null save evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f accepts a sequence of variable names and returns Java code to compute the output. + */ + protected def defineCodeGen( +ctx: CodegenContext, +ev: ExprCode, +f: Seq[String] => String): ExprCode = { +nullSafeCodeGen(ctx, ev, values => { + s"${ev.value} = ${f(values)};" +}) + } + + /** + * Called by expressions to generate null safe evaluation code. 
+ * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f a function that accepts a sequence of non-null evaluation result names of children + * and returns Java code to compute the output. + */ + protected def nullSafeCodeGen( --- End diff -- This is a refactoring issue, so I think we should discuss it in a separate ticket. Can you keep this PR as minimal as possible? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user maropu commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r178701403 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala --- @@ -408,6 +407,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), +expression[UnresolvedConcat]("concat"), --- End diff -- I think we should not put an unresolved expression here. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user maropu commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r178700701 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +289,166 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + +/** + * Replaces [[org.apache.spark.sql.catalyst.analysis.UnresolvedConcat UnresolvedConcat]]s + * with concrete concate expressions. + */ +object ResolveConcat +{ --- End diff -- nit: `object ResolveConcat {` --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user maropu commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r178700550 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/functions.scala --- @@ -3046,6 +3036,16 @@ object functions { ArrayContains(column.expr, Literal(value)) } + /** + * Concatenates multiple input columns together into a single column. + * The function works with strings, binary columns and arrays of the same time. + * + * @group collection_funcs + * @since 1.5.0 + */ + @scala.annotation.varargs + def concat(exprs: Column*): Column = withExpr { UnresolvedConcat(exprs.map(_.expr)) } --- End diff -- If you want to use the existing `concat` to merge arrays, I feel it'd be better to implement new logic to merge arrays in `Concat`. I think this approach could remove `UnresolvedConcat`, too. Thoughts? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user mn-mikke commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r177419513 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +289,152 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + +/** + * Concatenates multiple arrays into one. + */ +@ExpressionDescription( + usage = "_FUNC_(expr, ...) - Concatenates multiple arrays into one.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + [1,2,3,4,5,6] + """) +case class ConcatArrays(children: Seq[Expression]) extends Expression with NullSafeEvaluation { + + override def checkInputDataTypes(): TypeCheckResult = { +val arrayCheck = checkInputDataTypesAreArrays +if(arrayCheck.isFailure) arrayCheck +else TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName") + } + + private def checkInputDataTypesAreArrays(): TypeCheckResult = + { +val mismatches = children.zipWithIndex.collect { + case (child, idx) if !ArrayType.acceptsType(child.dataType) => +s"argument ${idx + 1} has to be ${ArrayType.simpleString} type, " + + s"however, '${child.sql}' is of ${child.dataType.simpleString} type." +} + +if (mismatches.isEmpty) { + TypeCheckResult.TypeCheckSuccess +} else { + TypeCheckResult.TypeCheckFailure(mismatches.mkString(" ")) +} + } + + override def dataType: ArrayType = +children + .headOption.map(_.dataType.asInstanceOf[ArrayType]) + .getOrElse(ArrayType.defaultConcreteType.asInstanceOf[ArrayType]) --- End diff -- Ok, changing the return type to `array<string>` when no children are provided. Also I've created the jira ticket [SPARK-23798](https://issues.apache.org/jira/browse/SPARK-23798) since I don't see any reason why it couldn't return a default concrete type in this case. Hope I'm not missing anything.
--- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user mn-mikke commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r177351705 --- Diff: python/pyspark/sql/functions.py --- @@ -1834,6 +1819,25 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(1.5) +@ignore_unicode_prefix +def concat(*cols): +""" +Concatenates multiple input columns together into a single column. +The function works with strings, binary columns and arrays of the same time. + +>>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) +>>> df.select(concat(df.s, df.d).alias('s')).collect() +[Row(s=u'abcd123')] + +>>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c']) +>>> df.select(concat(df.a, df.b, df.c).alias("arr")).collect() +[Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)] +""" +sc = SparkContext._active_spark_context +return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column))) --- End diff -- The whole file is divided into sections according to groups of functions. Based on @gatorsmile's suggestion, the concat function should be categorized as a collection function. So I moved the function to comply with the file structure. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user HyukjinKwon commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r177279540 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +289,152 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + +/** + * Concatenates multiple arrays into one. + */ +@ExpressionDescription( + usage = "_FUNC_(expr, ...) - Concatenates multiple arrays into one.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6)); + [1,2,3,4,5,6] + """) +case class ConcatArrays(children: Seq[Expression]) extends Expression with NullSafeEvaluation { + + override def checkInputDataTypes(): TypeCheckResult = { +val arrayCheck = checkInputDataTypesAreArrays +if(arrayCheck.isFailure) arrayCheck +else TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName") + } + + private def checkInputDataTypesAreArrays(): TypeCheckResult = + { +val mismatches = children.zipWithIndex.collect { + case (child, idx) if !ArrayType.acceptsType(child.dataType) => +s"argument ${idx + 1} has to be ${ArrayType.simpleString} type, " + + s"however, '${child.sql}' is of ${child.dataType.simpleString} type." +} + +if (mismatches.isEmpty) { + TypeCheckResult.TypeCheckSuccess +} else { + TypeCheckResult.TypeCheckFailure(mismatches.mkString(" ")) +} + } + + override def dataType: ArrayType = +children + .headOption.map(_.dataType.asInstanceOf[ArrayType]) + .getOrElse(ArrayType.defaultConcreteType.asInstanceOf[ArrayType]) --- End diff -- Hm .. but then this is `array<null>` when the children are empty. Seems `CreateArray`'s type is `array<string>` in this case. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user HyukjinKwon commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r177278880 --- Diff: python/pyspark/sql/functions.py --- @@ -1834,6 +1819,25 @@ def array_contains(col, value): return Column(sc._jvm.functions.array_contains(_to_java_column(col), value)) +@since(1.5) +@ignore_unicode_prefix +def concat(*cols): +""" +Concatenates multiple input columns together into a single column. +The function works with strings, binary columns and arrays of the same time. + +>>> df = spark.createDataFrame([('abcd','123')], ['s', 'd']) +>>> df.select(concat(df.s, df.d).alias('s')).collect() +[Row(s=u'abcd123')] + +>>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c']) +>>> df.select(concat(df.a, df.b, df.c).alias("arr")).collect() +[Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)] +""" +sc = SparkContext._active_spark_context +return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column))) --- End diff -- Why did we move this down .. ? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #20858: [SPARK-23736][SQL] Extending the concat function ...
Github user HyukjinKwon commented on a diff in the pull request: https://github.com/apache/spark/pull/20858#discussion_r177278671 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala --- @@ -699,3 +699,88 @@ abstract class TernaryExpression extends Expression { * and Hive function wrappers. */ trait UserDefinedExpression + +/** + * The trait covers logic for performing null save evaluation and code generation. + */ +trait NullSafeEvaluation extends Expression +{ + override def foldable: Boolean = children.forall(_.foldable) + + override def nullable: Boolean = children.exists(_.nullable) + + /** + * Default behavior of evaluation according to the default nullability of NullSafeEvaluation. + * If a class utilizing NullSaveEvaluation override [[nullable]], probably should also + * override this. + */ + override def eval(input: InternalRow): Any = + { --- End diff -- Seems the style fix is missed here. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org