Repository: spark
Updated Branches:
refs/heads/master b3fde5a41 -> e6b466084
[SPARK-23736][SQL] Extending the concat function to support array columns
## What changes were proposed in this pull request?
The PR adds logic for easy concatenation of multiple array columns (see the usage sketch below) and covers:
- Concat expression has been extended to support array columns
- A Python wrapper
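For context, a minimal usage sketch of the extended `concat` over array columns (assumes a spark-shell style session where `spark.implicits._` is in scope; the printed result is approximate, but the null handling matches the new tests):
```
import org.apache.spark.sql.functions.concat
import spark.implicits._  // already in scope in spark-shell

val df = Seq(
  (Seq(1, 2), Seq(3, 4)),
  (Seq(1, 2, 3), null)
).toDF("a", "b")

// Arrays are concatenated in argument order; any null array argument
// makes the whole result null, mirroring the string/binary behaviour.
df.select(concat($"a", $"b")).collect()
// => Array([WrappedArray(1, 2, 3, 4)], [null])   (formatting approximate)
```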
## How was this patch tested?
New tests added into:
- CollectionExpressionsSuite
- DataFrameFunctionsSuite
- typeCoercion/native/concat.sql (sampled in the SQL sketch below)
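For flavour, two cases along the lines of the new typeCoercion/native/concat.sql golden file (a hedged sketch; `spark` is a SparkSession and the output comments are abbreviated):
```
// Same-element-type arrays concatenate directly:
spark.sql("SELECT concat(array(1, 2, 3), array(4, 5), array(6)) AS arr").show(false)
// [1, 2, 3, 4, 5, 6]

// ...and arrays of different but compatible element types are widened first,
// e.g. array<int> || array<bigint> is coerced to array<bigint>:
spark.sql("SELECT array(2, 1) || array(3L, 4L) AS ib_array").show(false)
// [2, 1, 3, 4]
```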
## Codegen examples
### Primitive-type elements
```
val df = Seq(
  (Seq(1, 2), Seq(3, 4)),
  (Seq(1, 2, 3), null)
).toDF("a", "b")
df.filter('a.isNotNull).select(concat('a, 'b)).debugCodegen()
```
Result:
```
/* 033 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 034 */ ArrayData inputadapter_value = inputadapter_isNull ?
/* 035 */ null : (inputadapter_row.getArray(0));
/* 036 */
/* 037 */ if (!(!inputadapter_isNull)) continue;
/* 038 */
/* 039 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 040 */
/* 041 */ ArrayData[] project_args = new ArrayData[2];
/* 042 */
/* 043 */ if (!false) {
/* 044 */ project_args[0] = inputadapter_value;
/* 045 */ }
/* 046 */
/* 047 */ boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1);
/* 048 */ ArrayData inputadapter_value1 = inputadapter_isNull1 ?
/* 049 */ null : (inputadapter_row.getArray(1));
/* 050 */ if (!inputadapter_isNull1) {
/* 051 */ project_args[1] = inputadapter_value1;
/* 052 */ }
/* 053 */
/* 054 */ ArrayData project_value = new Object() {
/* 055 */ public ArrayData concat(ArrayData[] args) {
/* 056 */ for (int z = 0; z < 2; z++) {
/* 057 */ if (args[z] == null) return null;
/* 058 */ }
/* 059 */
/* 060 */ long project_numElements = 0L;
/* 061 */ for (int z = 0; z < 2; z++) {
/* 062 */ project_numElements += args[z].numElements();
/* 063 */ }
/* 064 */ if (project_numElements > 2147483632) {
/* 065 */ throw new RuntimeException("Unsuccessful try to concat arrays with " + project_numElements +
/* 066 */ " elements due to exceeding the array size limit 2147483632.");
/* 067 */ }
/* 068 */
/* 069 */ long project_size = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
/* 070 */ project_numElements,
/* 071 */ 4);
/* 072 */ if (project_size > 2147483632) {
/* 073 */ throw new RuntimeException("Unsuccessful try to concat arrays with " + project_size +
/* 074 */ " bytes of data due to exceeding the limit 2147483632 bytes" +
/* 075 */ " for UnsafeArrayData.");
/* 076 */ }
/* 077 */
/* 078 */ byte[] project_array = new byte[(int)project_size];
/* 079 */ UnsafeArrayData project_arrayData = new UnsafeArrayData();
/* 080 */ Platform.putLong(project_array, 16, project_numElements);
/* 081 */ project_arrayData.pointTo(project_array, 16, (int)project_size);
/* 082 */ int project_counter = 0;
/* 083 */ for (int y = 0; y < 2; y++) {
/* 084 */ for (int z = 0; z < args[y].numElements(); z++) {
/* 085 */ if (args[y].isNullAt(z)) {
/* 086 */ project_arrayData.setNullAt(project_counter);
/* 087 */ } else {
/* 088 */ project_arrayData.setInt(
/* 089 */ project_counter,
/* 090 */ args[y].getInt(z)
/* 091 */ );
/* 092 */ }
/* 093 */ project_counter++;
/* 094 */ }
/* 095 */ }
/* 096 */ return project_arrayData;
/* 097 */ }
/* 098 */ }.concat(project_args);
/* 099 */ boolean project_isNull = project_value == null;
```
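The limit checks in the generated code above use the new `UnsafeArrayData.calculateSizeOfUnderlyingByteArray` helper (see the diff further down). A standalone sketch of that arithmetic, assuming the existing UnsafeArrayData layout (an 8-byte element count, a word-aligned null bitmap, then word-aligned element data):
```
// Sketch of the byte-size computation behind the generated limit check.
def headerPortionInBytes(numFields: Long): Long =
  8 + ((numFields + 63) / 64) * 8            // length word + one null-bitmap word per 64 elements

def roundToNearestWord(numBytes: Long): Long = {
  val remainder = numBytes & 0x07            // numBytes % 8
  if (remainder == 0) numBytes else numBytes + (8 - remainder)
}

def sizeOfUnderlyingByteArray(numFields: Long, elementSize: Int): Long =
  headerPortionInBytes(numFields) + roundToNearestWord(numFields * elementSize)

// Concatenating Seq(1, 2) and Seq(3, 4) gives 4 int elements:
// 16 bytes of header + 16 bytes of element data = 32 bytes for the backing array.
sizeOfUnderlyingByteArray(4, 4)              // 32
// Anything above 2147483632 (ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) triggers the RuntimeException.
```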
### Non-primitive-type elements
```
val df = Seq(
  (Seq("aa", "bb"), Seq("ccc", "ddd")),
  (Seq("x", "y"), null)
).toDF("a", "b")
df.filter('a.isNotNull).select(concat('a, 'b)).debugCodegen()
```
Result:
```
/* 033 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 034 */ ArrayData inputadapter_value = inputadapter_isNull ?
/* 035 */ null : (inputadapter_row.getArray(0));
/* 036 */
/* 037 */ if (!(!inputadapter_isNull)) continue;
/* 038 */
/* 039 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 040 */
/* 041 */ ArrayData[] project_args = new ArrayData[2];
/* 042 */
/* 043 */ if (!false) {
/* 044 */ project_args[0] = inputadapter_value;
/* 045 */ }
/* 046 */
/* 047 */ boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1);
/* 048 */ ArrayData inputadapter_value1 = inputadapter_isNull1 ?
/* 049 */ null : (inputadapter_row.getArray(1));
/* 050 */ if (!inputadapter_isNull1) {
/* 051 */ project_args[1] = inputadapter_value1;
/* 052 */ }
/* 053 */
/* 054 */ ArrayData project_value = new Object() {
/* 055 */ public ArrayData concat(ArrayData[] args) {
/* 056 */ for (int z = 0; z < 2; z++) {
/* 057 */ if (args[z] == null) return null;
/* 058 */ }
/* 059 */
/* 060 */ long project_numElements = 0L;
/* 061 */ for (int z = 0; z < 2; z++) {
/* 062 */ project_numElements += args[z].numElements();
/* 063 */ }
/* 064 */ if (project_numElements > 2147483632) {
/* 065 */ throw new RuntimeException("Unsuccessful try to concat arrays with " + project_numElements +
/* 066 */ " elements due to exceeding the array size limit 2147483632.");
/* 067 */ }
/* 068 */
/* 069 */ Object[] project_arrayObjects = new Object[(int)project_numElements];
/* 070 */ int project_counter = 0;
/* 071 */ for (int y = 0; y < 2; y++) {
/* 072 */ for (int z = 0; z < args[y].numElements(); z++) {
/* 073 */ project_arrayObjects[project_counter] = args[y].getUTF8String(z);
/* 074 */ project_counter++;
/* 075 */ }
/* 076 */ }
/* 077 */ return new org.apache.spark.sql.catalyst.util.GenericArrayData(project_arrayObjects);
/* 078 */ }
/* 079 */ }.concat(project_args);
/* 080 */ boolean project_isNull = project_value == null;
```
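For non-primitive element types the generated code just gathers object references into an `Object[]` and wraps them in `GenericArrayData`; the interpreted path (the new `ArrayType` branch of `Concat.eval` in the collectionOperations.scala diff below) does essentially the same for any element type. A simplified sketch, not the exact committed code:
```
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types.DataType

// Simplified from the new ArrayType branch of Concat.eval.
def concatArrays(inputs: Seq[ArrayData], elementType: DataType): Any = {
  if (inputs.contains(null)) {
    null                                       // any null argument nulls the whole result
  } else {
    val numElements = inputs.map(_.numElements().toLong).sum
    require(numElements <= 2147483632L, s"cannot concat arrays with $numElements elements")
    val result = new Array[AnyRef](numElements.toInt)
    var position = 0
    for (ad <- inputs) {
      val arr = ad.toObjectArray(elementType)  // boxes elements; nulls are preserved
      Array.copy(arr, 0, result, position, arr.length)
      position += arr.length
    }
    new GenericArrayData(result)
  }
}
```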
Author: mn-mikke <mrkAha12346github>
Closes #20858 from mn-mikke/feature/array-api-concat_arrays-to-master.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e6b46608
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e6b46608
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e6b46608
Branch: refs/heads/master
Commit: e6b466084c26fbb9b9e50dd5cc8b25da7533ac72
Parents: b3fde5a
Author: mn-mikke <mrkAha12346github>
Authored: Fri Apr 20 14:58:11 2018 +0900
Committer: Takuya UESHIN <[email protected]>
Committed: Fri Apr 20 14:58:11 2018 +0900
----------------------------------------------------------------------
.../spark/unsafe/array/ByteArrayMethods.java | 6 +-
python/pyspark/sql/functions.py | 34 +--
.../catalyst/expressions/UnsafeArrayData.java | 10 +
.../catalyst/analysis/FunctionRegistry.scala | 2 +-
.../sql/catalyst/analysis/TypeCoercion.scala | 8 +
.../expressions/collectionOperations.scala | 220 ++++++++++++++++++-
.../expressions/stringExpressions.scala | 81 -------
.../CollectionExpressionsSuite.scala | 41 ++++
.../scala/org/apache/spark/sql/functions.scala | 20 +-
.../inputs/typeCoercion/native/concat.sql | 62 ++++++
.../results/typeCoercion/native/concat.sql.out | 78 +++++++
.../spark/sql/DataFrameFunctionsSuite.scala | 74 +++++++
.../spark/sql/execution/command/DDLSuite.scala | 4 +-
13 files changed, 529 insertions(+), 111 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java
----------------------------------------------------------------------
diff --git
a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java
b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java
index 4bc9955..ef0f78d 100644
---
a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java
+++
b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java
@@ -33,7 +33,11 @@ public class ByteArrayMethods {
}
public static int roundNumberOfBytesToNearestWord(int numBytes) {
- int remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8`
+ return (int)roundNumberOfBytesToNearestWord((long)numBytes);
+ }
+
+ public static long roundNumberOfBytesToNearestWord(long numBytes) {
+ long remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8`
if (remainder == 0) {
return numBytes;
} else {
http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 1be68f2..da32ab2 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1427,21 +1427,6 @@ del _name, _doc
@since(1.5)
@ignore_unicode_prefix
-def concat(*cols):
- """
- Concatenates multiple input columns together into a single column.
- If all inputs are binary, concat returns an output as binary. Otherwise,
it returns as string.
-
- >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
- >>> df.select(concat(df.s, df.d).alias('s')).collect()
- [Row(s=u'abcd123')]
- """
- sc = SparkContext._active_spark_context
- return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column)))
-
-
-@since(1.5)
-@ignore_unicode_prefix
def concat_ws(sep, *cols):
"""
Concatenates multiple input string columns together into a single string
column,
@@ -1845,6 +1830,25 @@ def array_contains(col, value):
return Column(sc._jvm.functions.array_contains(_to_java_column(col),
value))
+@since(1.5)
+@ignore_unicode_prefix
+def concat(*cols):
+ """
+ Concatenates multiple input columns together into a single column.
+ The function works with strings, binary and compatible array columns.
+
+ >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
+ >>> df.select(concat(df.s, df.d).alias('s')).collect()
+ [Row(s=u'abcd123')]
+
+ >>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None,
[3])], ['a', 'b', 'c'])
+ >>> df.select(concat(df.a, df.b, df.c).alias("arr")).collect()
+ [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column)))
+
+
@since(2.4)
def array_position(col, value):
"""
http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 8546c28..d5d934b 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -56,9 +56,19 @@ import org.apache.spark.unsafe.types.UTF8String;
public final class UnsafeArrayData extends ArrayData {
public static int calculateHeaderPortionInBytes(int numFields) {
+ return (int)calculateHeaderPortionInBytes((long)numFields);
+ }
+
+ public static long calculateHeaderPortionInBytes(long numFields) {
return 8 + ((numFields + 63)/ 64) * 8;
}
+ public static long calculateSizeOfUnderlyingByteArray(long numFields, int
elementSize) {
+ long size = UnsafeArrayData.calculateHeaderPortionInBytes(numFields) +
+ ByteArrayMethods.roundNumberOfBytesToNearestWord(numFields *
elementSize);
+ return size;
+ }
+
private Object baseObject;
private long baseOffset;
http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index a44f2d5..c41f16c 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -308,7 +308,6 @@ object FunctionRegistry {
expression[BitLength]("bit_length"),
expression[Length]("char_length"),
expression[Length]("character_length"),
- expression[Concat]("concat"),
expression[ConcatWs]("concat_ws"),
expression[Decode]("decode"),
expression[Elt]("elt"),
@@ -413,6 +412,7 @@ object FunctionRegistry {
expression[ArrayMin]("array_min"),
expression[ArrayMax]("array_max"),
expression[Reverse]("reverse"),
+ expression[Concat]("concat"),
CreateStruct.registryEntry,
// misc functions
http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 281f206..cfcbd8d 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -520,6 +520,14 @@ object TypeCoercion {
case None => a
}
+ case c @ Concat(children) if children.forall(c =>
ArrayType.acceptsType(c.dataType)) &&
+ !haveSameType(children) =>
+ val types = children.map(_.dataType)
+ findWiderCommonType(types) match {
+ case Some(finalDataType) => Concat(children.map(Cast(_,
finalDataType)))
+ case None => c
+ }
+
case m @ CreateMap(children) if m.keys.length == m.values.length &&
(!haveSameType(m.keys) || !haveSameType(m.values)) =>
val newKeys = if (haveSameType(m.keys)) {
http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index dba426e..c16793b 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -23,7 +23,9 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData,
MapData, TypeUtils}
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.array.ByteArrayMethods
+import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
/**
* Given an array or map, returns its size. Returns -1 if null.
@@ -665,3 +667,219 @@ case class ElementAt(left: Expression, right: Expression)
extends GetMapValueUti
override def prettyName: String = "element_at"
}
+
+/**
+ * Concatenates multiple input columns together into a single column.
+ * The function works with strings, binary and compatible array columns.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1,
col2, ..., colN.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_('Spark', 'SQL');
+ SparkSQL
+ > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6));
+ | [1,2,3,4,5,6]
+ """)
+case class Concat(children: Seq[Expression]) extends Expression {
+
+ private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+
+ val allowedTypes = Seq(StringType, BinaryType, ArrayType)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (children.isEmpty) {
+ TypeCheckResult.TypeCheckSuccess
+ } else {
+ val childTypes = children.map(_.dataType)
+ if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe)))) {
+ return TypeCheckResult.TypeCheckFailure(
+ s"input to function $prettyName should have been StringType,
BinaryType or ArrayType," +
+ s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ",
"]"))
+ }
+ TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName")
+ }
+ }
+
+ override def dataType: DataType =
children.map(_.dataType).headOption.getOrElse(StringType)
+
+ lazy val javaType: String = CodeGenerator.javaType(dataType)
+
+ override def nullable: Boolean = children.exists(_.nullable)
+
+ override def foldable: Boolean = children.forall(_.foldable)
+
+ override def eval(input: InternalRow): Any = dataType match {
+ case BinaryType =>
+ val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
+ ByteArray.concat(inputs: _*)
+ case StringType =>
+ val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
+ UTF8String.concat(inputs : _*)
+ case ArrayType(elementType, _) =>
+ val inputs = children.toStream.map(_.eval(input))
+ if (inputs.contains(null)) {
+ null
+ } else {
+ val arrayData = inputs.map(_.asInstanceOf[ArrayData])
+ val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum +
ad.numElements())
+ if (numberOfElements > MAX_ARRAY_LENGTH) {
+ throw new RuntimeException(s"Unsuccessful try to concat arrays with
$numberOfElements" +
+ s" elements due to exceeding the array size limit
$MAX_ARRAY_LENGTH.")
+ }
+ val finalData = new Array[AnyRef](numberOfElements.toInt)
+ var position = 0
+ for(ad <- arrayData) {
+ val arr = ad.toObjectArray(elementType)
+ Array.copy(arr, 0, finalData, position, arr.length)
+ position += arr.length
+ }
+ new GenericArrayData(finalData)
+ }
+ }
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode):
ExprCode = {
+ val evals = children.map(_.genCode(ctx))
+ val args = ctx.freshName("args")
+
+ val inputs = evals.zipWithIndex.map { case (eval, index) =>
+ s"""
+ ${eval.code}
+ if (!${eval.isNull}) {
+ $args[$index] = ${eval.value};
+ }
+ """
+ }
+
+ val (concatenator, initCode) = dataType match {
+ case BinaryType =>
+ (classOf[ByteArray].getName, s"byte[][] $args = new
byte[${evals.length}][];")
+ case StringType =>
+ ("UTF8String", s"UTF8String[] $args = new
UTF8String[${evals.length}];")
+ case ArrayType(elementType, _) =>
+ val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType))
{
+ genCodeForPrimitiveArrays(ctx, elementType)
+ } else {
+ genCodeForNonPrimitiveArrays(ctx, elementType)
+ }
+ (arrayConcatClass, s"ArrayData[] $args = new
ArrayData[${evals.length}];")
+ }
+ val codes = ctx.splitExpressionsWithCurrentInputs(
+ expressions = inputs,
+ funcName = "valueConcat",
+ extraArguments = (s"$javaType[]", args) :: Nil)
+ ev.copy(s"""
+ $initCode
+ $codes
+ $javaType ${ev.value} = $concatenator.concat($args);
+ boolean ${ev.isNull} = ${ev.value} == null;
+ """)
+ }
+
+ private def genCodeForNumberOfElements(ctx: CodegenContext) : (String,
String) = {
+ val numElements = ctx.freshName("numElements")
+ val code = s"""
+ |long $numElements = 0L;
+ |for (int z = 0; z < ${children.length}; z++) {
+ | $numElements += args[z].numElements();
+ |}
+ |if ($numElements > $MAX_ARRAY_LENGTH) {
+ | throw new RuntimeException("Unsuccessful try to concat arrays with
" + $numElements +
+ | " elements due to exceeding the array size limit
$MAX_ARRAY_LENGTH.");
+ |}
+ """.stripMargin
+
+ (code, numElements)
+ }
+
+ private def nullArgumentProtection() : String = {
+ if (nullable) {
+ s"""
+ |for (int z = 0; z < ${children.length}; z++) {
+ | if (args[z] == null) return null;
+ |}
+ """.stripMargin
+ } else {
+ ""
+ }
+ }
+
+ private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType:
DataType): String = {
+ val arrayName = ctx.freshName("array")
+ val arraySizeName = ctx.freshName("size")
+ val counter = ctx.freshName("counter")
+ val arrayData = ctx.freshName("arrayData")
+
+ val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
+
+ val unsafeArraySizeInBytes = s"""
+ |long $arraySizeName =
UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
+ | $numElemName,
+ | ${elementType.defaultSize});
+ |if ($arraySizeName > $MAX_ARRAY_LENGTH) {
+ | throw new RuntimeException("Unsuccessful try to concat arrays with "
+ $arraySizeName +
+ | " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes"
+
+ | " for UnsafeArrayData.");
+ |}
+ """.stripMargin
+ val baseOffset = Platform.BYTE_ARRAY_OFFSET
+ val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
+
+ s"""
+ |new Object() {
+ | public ArrayData concat($javaType[] args) {
+ | ${nullArgumentProtection()}
+ | $numElemCode
+ | $unsafeArraySizeInBytes
+ | byte[] $arrayName = new byte[(int)$arraySizeName];
+ | UnsafeArrayData $arrayData = new UnsafeArrayData();
+ | Platform.putLong($arrayName, $baseOffset, $numElemName);
+ | $arrayData.pointTo($arrayName, $baseOffset, (int)$arraySizeName);
+ | int $counter = 0;
+ | for (int y = 0; y < ${children.length}; y++) {
+ | for (int z = 0; z < args[y].numElements(); z++) {
+ | if (args[y].isNullAt(z)) {
+ | $arrayData.setNullAt($counter);
+ | } else {
+ | $arrayData.set$primitiveValueTypeName(
+ | $counter,
+ | ${CodeGenerator.getValue(s"args[y]", elementType, "z")}
+ | );
+ | }
+ | $counter++;
+ | }
+ | }
+ | return $arrayData;
+ | }
+ |}""".stripMargin.stripPrefix("\n")
+ }
+
+ private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType:
DataType): String = {
+ val genericArrayClass = classOf[GenericArrayData].getName
+ val arrayData = ctx.freshName("arrayObjects")
+ val counter = ctx.freshName("counter")
+
+ val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
+
+ s"""
+ |new Object() {
+ | public ArrayData concat($javaType[] args) {
+ | ${nullArgumentProtection()}
+ | $numElemCode
+ | Object[] $arrayData = new Object[(int)$numElemName];
+ | int $counter = 0;
+ | for (int y = 0; y < ${children.length}; y++) {
+ | for (int z = 0; z < args[y].numElements(); z++) {
+ | $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]",
elementType, "z")};
+ | $counter++;
+ | }
+ | }
+ | return new $genericArrayClass($arrayData);
+ | }
+ |}""".stripMargin.stripPrefix("\n")
+ }
+
+ override def toString: String = s"concat(${children.mkString(", ")})"
+
+ override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 5a02ca0..ea005a2 100755
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -37,87 +37,6 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
/**
- * An expression that concatenates multiple inputs into a single output.
- * If all inputs are binary, concat returns an output as binary. Otherwise, it
returns as string.
- * If any input is null, concat returns null.
- */
-@ExpressionDescription(
- usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of str1,
str2, ..., strN.",
- examples = """
- Examples:
- > SELECT _FUNC_('Spark', 'SQL');
- SparkSQL
- """)
-case class Concat(children: Seq[Expression]) extends Expression {
-
- private lazy val isBinaryMode: Boolean = dataType == BinaryType
-
- override def checkInputDataTypes(): TypeCheckResult = {
- if (children.isEmpty) {
- TypeCheckResult.TypeCheckSuccess
- } else {
- val childTypes = children.map(_.dataType)
- if (childTypes.exists(tpe => !Seq(StringType,
BinaryType).contains(tpe))) {
- return TypeCheckResult.TypeCheckFailure(
- s"input to function $prettyName should have StringType or
BinaryType, but it's " +
- childTypes.map(_.simpleString).mkString("[", ", ", "]"))
- }
- TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName")
- }
- }
-
- override def dataType: DataType =
children.map(_.dataType).headOption.getOrElse(StringType)
-
- override def nullable: Boolean = children.exists(_.nullable)
- override def foldable: Boolean = children.forall(_.foldable)
-
- override def eval(input: InternalRow): Any = {
- if (isBinaryMode) {
- val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
- ByteArray.concat(inputs: _*)
- } else {
- val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
- UTF8String.concat(inputs : _*)
- }
- }
-
- override protected def doGenCode(ctx: CodegenContext, ev: ExprCode):
ExprCode = {
- val evals = children.map(_.genCode(ctx))
- val args = ctx.freshName("args")
-
- val inputs = evals.zipWithIndex.map { case (eval, index) =>
- s"""
- ${eval.code}
- if (!${eval.isNull}) {
- $args[$index] = ${eval.value};
- }
- """
- }
-
- val (concatenator, initCode) = if (isBinaryMode) {
- (classOf[ByteArray].getName, s"byte[][] $args = new
byte[${evals.length}][];")
- } else {
- ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];")
- }
- val codes = ctx.splitExpressionsWithCurrentInputs(
- expressions = inputs,
- funcName = "valueConcat",
- extraArguments = (s"${CodeGenerator.javaType(dataType)}[]", args) :: Nil)
- ev.copy(s"""
- $initCode
- $codes
- ${CodeGenerator.javaType(dataType)} ${ev.value} =
$concatenator.concat($args);
- boolean ${ev.isNull} = ${ev.value} == null;
- """)
- }
-
- override def toString: String = s"concat(${children.mkString(", ")})"
-
- override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
-}
-
-
-/**
* An expression that concatenates multiple input strings or array of strings
into a single string,
* using a given separator (the first child).
*
http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 7d8fe21..43c5dda 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -239,4 +239,45 @@ class CollectionExpressionsSuite extends SparkFunSuite
with ExpressionEvalHelper
checkEvaluation(ElementAt(m2, Literal("a")), null)
}
+
+ test("Concat") {
+ // Primitive-type elements
+ val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
+ val ai1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType))
+ val ai2 = Literal.create(Seq(4, null, 5), ArrayType(IntegerType))
+ val ai3 = Literal.create(Seq(null, null), ArrayType(IntegerType))
+ val ai4 = Literal.create(null, ArrayType(IntegerType))
+
+ checkEvaluation(Concat(Seq(ai0)), Seq(1, 2, 3))
+ checkEvaluation(Concat(Seq(ai0, ai1)), Seq(1, 2, 3))
+ checkEvaluation(Concat(Seq(ai1, ai0)), Seq(1, 2, 3))
+ checkEvaluation(Concat(Seq(ai0, ai0)), Seq(1, 2, 3, 1, 2, 3))
+ checkEvaluation(Concat(Seq(ai0, ai2)), Seq(1, 2, 3, 4, null, 5))
+ checkEvaluation(Concat(Seq(ai0, ai3, ai2)), Seq(1, 2, 3, null, null, 4,
null, 5))
+ checkEvaluation(Concat(Seq(ai4)), null)
+ checkEvaluation(Concat(Seq(ai0, ai4)), null)
+ checkEvaluation(Concat(Seq(ai4, ai0)), null)
+
+ // Non-primitive-type elements
+ val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType))
+ val as1 = Literal.create(Seq.empty[String], ArrayType(StringType))
+ val as2 = Literal.create(Seq("d", null, "e"), ArrayType(StringType))
+ val as3 = Literal.create(Seq(null, null), ArrayType(StringType))
+ val as4 = Literal.create(null, ArrayType(StringType))
+
+ val aa0 = Literal.create(Seq(Seq("a", "b"), Seq("c")),
ArrayType(ArrayType(StringType)))
+ val aa1 = Literal.create(Seq(Seq("d"), Seq("e", "f")),
ArrayType(ArrayType(StringType)))
+
+ checkEvaluation(Concat(Seq(as0)), Seq("a", "b", "c"))
+ checkEvaluation(Concat(Seq(as0, as1)), Seq("a", "b", "c"))
+ checkEvaluation(Concat(Seq(as1, as0)), Seq("a", "b", "c"))
+ checkEvaluation(Concat(Seq(as0, as0)), Seq("a", "b", "c", "a", "b", "c"))
+ checkEvaluation(Concat(Seq(as0, as2)), Seq("a", "b", "c", "d", null, "e"))
+ checkEvaluation(Concat(Seq(as0, as3, as2)), Seq("a", "b", "c", null, null,
"d", null, "e"))
+ checkEvaluation(Concat(Seq(as4)), null)
+ checkEvaluation(Concat(Seq(as0, as4)), null)
+ checkEvaluation(Concat(Seq(as4, as0)), null)
+
+ checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"),
Seq("d"), Seq("e", "f")))
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 9c85803..bea8c0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2229,16 +2229,6 @@ object functions {
def base64(e: Column): Column = withExpr { Base64(e.expr) }
/**
- * Concatenates multiple input columns together into a single column.
- * If all inputs are binary, concat returns an output as binary. Otherwise,
it returns as string.
- *
- * @group string_funcs
- * @since 1.5.0
- */
- @scala.annotation.varargs
- def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) }
-
- /**
* Concatenates multiple input string columns together into a single string
column,
* using the given separator.
*
@@ -3039,6 +3029,16 @@ object functions {
}
/**
+ * Concatenates multiple input columns together into a single column.
+ * The function works with strings, binary and compatible array columns.
+ *
+ * @group collection_funcs
+ * @since 1.5.0
+ */
+ @scala.annotation.varargs
+ def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) }
+
+ /**
* Locates the position of the first occurrence of the value in the given
array as long.
* Returns null if either of the arguments are null.
*
http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql
----------------------------------------------------------------------
diff --git
a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql
b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql
index 0beebec..db00a18 100644
---
a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql
+++
b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql
@@ -91,3 +91,65 @@ FROM (
encode(string(id + 3), 'utf-8') col4
FROM range(10)
);
+
+CREATE TEMPORARY VIEW various_arrays AS SELECT * FROM VALUES (
+ array(true, false), array(true),
+ array(2Y, 1Y), array(3Y, 4Y),
+ array(2S, 1S), array(3S, 4S),
+ array(2, 1), array(3, 4),
+ array(2L, 1L), array(3L, 4L),
+ array(9223372036854775809, 9223372036854775808), array(9223372036854775808,
9223372036854775809),
+ array(2.0D, 1.0D), array(3.0D, 4.0D),
+ array(float(2.0), float(1.0)), array(float(3.0), float(4.0)),
+ array(date '2016-03-14', date '2016-03-13'), array(date '2016-03-12', date
'2016-03-11'),
+ array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12
20:54:00.000'),
+ array(timestamp '2016-11-11 20:54:00.000'),
+ array('a', 'b'), array('c', 'd'),
+ array(array('a', 'b'), array('c', 'd')), array(array('e'), array('f')),
+ array(struct('a', 1), struct('b', 2)), array(struct('c', 3), struct('d', 4)),
+ array(map('a', 1), map('b', 2)), array(map('c', 3), map('d', 4))
+) AS various_arrays(
+ boolean_array1, boolean_array2,
+ tinyint_array1, tinyint_array2,
+ smallint_array1, smallint_array2,
+ int_array1, int_array2,
+ bigint_array1, bigint_array2,
+ decimal_array1, decimal_array2,
+ double_array1, double_array2,
+ float_array1, float_array2,
+ date_array1, data_array2,
+ timestamp_array1, timestamp_array2,
+ string_array1, string_array2,
+ array_array1, array_array2,
+ struct_array1, struct_array2,
+ map_array1, map_array2
+);
+
+-- Concatenate arrays of the same type
+SELECT
+ (boolean_array1 || boolean_array2) boolean_array,
+ (tinyint_array1 || tinyint_array2) tinyint_array,
+ (smallint_array1 || smallint_array2) smallint_array,
+ (int_array1 || int_array2) int_array,
+ (bigint_array1 || bigint_array2) bigint_array,
+ (decimal_array1 || decimal_array2) decimal_array,
+ (double_array1 || double_array2) double_array,
+ (float_array1 || float_array2) float_array,
+ (date_array1 || data_array2) data_array,
+ (timestamp_array1 || timestamp_array2) timestamp_array,
+ (string_array1 || string_array2) string_array,
+ (array_array1 || array_array2) array_array,
+ (struct_array1 || struct_array2) struct_array,
+ (map_array1 || map_array2) map_array
+FROM various_arrays;
+
+-- Concatenate arrays of different types
+SELECT
+ (tinyint_array1 || smallint_array2) ts_array,
+ (smallint_array1 || int_array2) si_array,
+ (int_array1 || bigint_array2) ib_array,
+ (double_array1 || float_array2) df_array,
+ (string_array1 || data_array2) std_array,
+ (timestamp_array1 || string_array2) tst_array,
+ (string_array1 || int_array2) sti_array
+FROM various_arrays;
http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out
----------------------------------------------------------------------
diff --git
a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out
b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out
index 09729fd..62befc5 100644
---
a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out
+++
b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out
@@ -237,3 +237,81 @@ struct<col:binary>
78910
891011
9101112
+
+
+-- !query 11
+CREATE TEMPORARY VIEW various_arrays AS SELECT * FROM VALUES (
+ array(true, false), array(true),
+ array(2Y, 1Y), array(3Y, 4Y),
+ array(2S, 1S), array(3S, 4S),
+ array(2, 1), array(3, 4),
+ array(2L, 1L), array(3L, 4L),
+ array(9223372036854775809, 9223372036854775808), array(9223372036854775808,
9223372036854775809),
+ array(2.0D, 1.0D), array(3.0D, 4.0D),
+ array(float(2.0), float(1.0)), array(float(3.0), float(4.0)),
+ array(date '2016-03-14', date '2016-03-13'), array(date '2016-03-12', date
'2016-03-11'),
+ array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12
20:54:00.000'),
+ array(timestamp '2016-11-11 20:54:00.000'),
+ array('a', 'b'), array('c', 'd'),
+ array(array('a', 'b'), array('c', 'd')), array(array('e'), array('f')),
+ array(struct('a', 1), struct('b', 2)), array(struct('c', 3), struct('d', 4)),
+ array(map('a', 1), map('b', 2)), array(map('c', 3), map('d', 4))
+) AS various_arrays(
+ boolean_array1, boolean_array2,
+ tinyint_array1, tinyint_array2,
+ smallint_array1, smallint_array2,
+ int_array1, int_array2,
+ bigint_array1, bigint_array2,
+ decimal_array1, decimal_array2,
+ double_array1, double_array2,
+ float_array1, float_array2,
+ date_array1, data_array2,
+ timestamp_array1, timestamp_array2,
+ string_array1, string_array2,
+ array_array1, array_array2,
+ struct_array1, struct_array2,
+ map_array1, map_array2
+)
+-- !query 11 schema
+struct<>
+-- !query 11 output
+
+
+
+-- !query 12
+SELECT
+ (boolean_array1 || boolean_array2) boolean_array,
+ (tinyint_array1 || tinyint_array2) tinyint_array,
+ (smallint_array1 || smallint_array2) smallint_array,
+ (int_array1 || int_array2) int_array,
+ (bigint_array1 || bigint_array2) bigint_array,
+ (decimal_array1 || decimal_array2) decimal_array,
+ (double_array1 || double_array2) double_array,
+ (float_array1 || float_array2) float_array,
+ (date_array1 || data_array2) data_array,
+ (timestamp_array1 || timestamp_array2) timestamp_array,
+ (string_array1 || string_array2) string_array,
+ (array_array1 || array_array2) array_array,
+ (struct_array1 || struct_array2) struct_array,
+ (map_array1 || map_array2) map_array
+FROM various_arrays
+-- !query 12 schema
+struct<boolean_array:array<boolean>,tinyint_array:array<tinyint>,smallint_array:array<smallint>,int_array:array<int>,bigint_array:array<bigint>,decimal_array:array<decimal(19,0)>,double_array:array<double>,float_array:array<float>,data_array:array<date>,timestamp_array:array<timestamp>,string_array:array<string>,array_array:array<array<string>>,struct_array:array<struct<col1:string,col2:int>>,map_array:array<map<string,int>>>
+-- !query 12 output
+[true,false,true] [2,1,3,4] [2,1,3,4] [2,1,3,4]
[2,1,3,4]
[9223372036854775809,9223372036854775808,9223372036854775808,9223372036854775809]
[2.0,1.0,3.0,4.0] [2.0,1.0,3.0,4.0]
[2016-03-14,2016-03-13,2016-03-12,2016-03-11] [2016-11-15
20:54:00.0,2016-11-12 20:54:00.0,2016-11-11 20:54:00.0] ["a","b","c","d"]
[["a","b"],["c","d"],["e"],["f"]]
[{"col1":"a","col2":1},{"col1":"b","col2":2},{"col1":"c","col2":3},{"col1":"d","col2":4}]
[{"a":1},{"b":2},{"c":3},{"d":4}]
+
+
+-- !query 13
+SELECT
+ (tinyint_array1 || smallint_array2) ts_array,
+ (smallint_array1 || int_array2) si_array,
+ (int_array1 || bigint_array2) ib_array,
+ (double_array1 || float_array2) df_array,
+ (string_array1 || data_array2) std_array,
+ (timestamp_array1 || string_array2) tst_array,
+ (string_array1 || int_array2) sti_array
+FROM various_arrays
+-- !query 13 schema
+struct<ts_array:array<smallint>,si_array:array<int>,ib_array:array<bigint>,df_array:array<double>,std_array:array<string>,tst_array:array<string>,sti_array:array<string>>
+-- !query 13 output
+[2,1,3,4] [2,1,3,4] [2,1,3,4] [2.0,1.0,3.0,4.0]
["a","b","2016-03-12","2016-03-11"] ["2016-11-15 20:54:00","2016-11-12
20:54:00","c","d"] ["a","b","3","4"]
http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 7c976c1..25e5cd6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -617,6 +617,80 @@ class DataFrameFunctionsSuite extends QueryTest with
SharedSQLContext {
)
}
+ test("concat function - arrays") {
+ val nseqi : Seq[Int] = null
+ val nseqs : Seq[String] = null
+ val df = Seq(
+
+ (Seq(1), Seq(2, 3), Seq(5L, 6L), nseqi, Seq("a", "b", "c"), Seq("d",
"e"), Seq("f"), nseqs),
+ (Seq(1, 0), Seq.empty[Int], Seq(2L), nseqi, Seq("a"), Seq.empty[String],
Seq(null), nseqs)
+ ).toDF("i1", "i2", "i3", "in", "s1", "s2", "s3", "sn")
+
+ val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codeGen
on
+
+ // Simple test cases
+ checkAnswer(
+ df.selectExpr("array(1, 2, 3L)"),
+ Seq(Row(Seq(1L, 2L, 3L)), Row(Seq(1L, 2L, 3L)))
+ )
+
+ checkAnswer (
+ df.select(concat($"i1", $"s1")),
+ Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a")))
+ )
+ checkAnswer(
+ df.select(concat($"i1", $"i2", $"i3")),
+ Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2)))
+ )
+ checkAnswer(
+ df.filter(dummyFilter($"i1")).select(concat($"i1", $"i2", $"i3")),
+ Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2)))
+ )
+ checkAnswer(
+ df.selectExpr("concat(array(1, null), i2, i3)"),
+ Seq(Row(Seq(1, null, 2, 3, 5, 6)), Row(Seq(1, null, 2)))
+ )
+ checkAnswer(
+ df.select(concat($"s1", $"s2", $"s3")),
+ Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null)))
+ )
+ checkAnswer(
+ df.selectExpr("concat(s1, s2, s3)"),
+ Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null)))
+ )
+ checkAnswer(
+ df.filter(dummyFilter($"s1"))select(concat($"s1", $"s2", $"s3")),
+ Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null)))
+ )
+
+ // Null test cases
+ checkAnswer(
+ df.select(concat($"i1", $"in")),
+ Seq(Row(null), Row(null))
+ )
+ checkAnswer(
+ df.select(concat($"in", $"i1")),
+ Seq(Row(null), Row(null))
+ )
+ checkAnswer(
+ df.select(concat($"s1", $"sn")),
+ Seq(Row(null), Row(null))
+ )
+ checkAnswer(
+ df.select(concat($"sn", $"s1")),
+ Seq(Row(null), Row(null))
+ )
+
+ // Type error test cases
+ intercept[AnalysisException] {
+ df.selectExpr("concat(i1, i2, null)")
+ }
+
+ intercept[AnalysisException] {
+ df.selectExpr("concat(i1, array(i1, i2))")
+ }
+ }
+
private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false),
(false, true))) {
http://git-wip-us.apache.org/repos/asf/spark/blob/e6b46608/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index cbd7f9d..3998cec 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -1742,8 +1742,8 @@ abstract class DDLSuite extends QueryTest with
SQLTestUtils {
sql("DESCRIBE FUNCTION 'concat'"),
Row("Class: org.apache.spark.sql.catalyst.expressions.Concat") ::
Row("Function: concat") ::
- Row("Usage: concat(str1, str2, ..., strN) - " +
- "Returns the concatenation of str1, str2, ..., strN.") :: Nil
+ Row("Usage: concat(col1, col2, ..., colN) - " +
+ "Returns the concatenation of col1, col2, ..., colN.") :: Nil
)
// extended mode
checkAnswer(