This is an automated email from the ASF dual-hosted git repository. srowen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new d71b180295e [SPARK-40398][CORE][SQL] Use Loop instead of Arrays.stream api d71b180295e is described below commit d71b180295ea89b39047cff8397c5b3c2fe0bd20 Author: yangjie01 <yangji...@baidu.com> AuthorDate: Fri Sep 16 08:29:31 2022 -0500 [SPARK-40398][CORE][SQL] Use Loop instead of Arrays.stream api ### What changes were proposed in this pull request? This PR replaces `Arrays.stream` api with loop where performance improvement can be obtained. ### Why are the changes needed? Minor performance improvement. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass Github actions Closes #37843 from LuciferYang/ExpressionArrayToStrings. Authored-by: yangjie01 <yangji...@baidu.com> Signed-off-by: Sean Owen <sro...@gmail.com> --- .../network/shuffle/OneForOneBlockFetcher.java | 24 ++++++++- .../sql/connector/expressions/Expression.java | 20 +++++-- .../sql/connector/metric/CustomAvgMetric.java | 7 ++- .../sql/connector/metric/CustomSumMetric.java | 8 +-- .../sql/connector/util/V2ExpressionSQLBuilder.java | 62 +++++++++++++--------- .../datasources/v2/V2PredicateSuite.scala | 4 +- 6 files changed, 87 insertions(+), 38 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index a788b508e7b..b93db3f570b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -113,10 +113,30 @@ public class OneForOneBlockFetcher { * @return whether the array contains only shuffle block IDs */ private boolean areShuffleBlocksOrChunks(String[] blockIds) { - if (Arrays.stream(blockIds).anyMatch(blockId -> !blockId.startsWith(SHUFFLE_BLOCK_PREFIX))) { + if 
(isAnyBlockNotStartWithShuffleBlockPrefix(blockIds)) { // It comes here because there is a blockId which doesn't have "shuffle_" prefix so we // check if all the block ids are shuffle chunk Ids. - return Arrays.stream(blockIds).allMatch(blockId -> blockId.startsWith(SHUFFLE_CHUNK_PREFIX)); + return isAllBlocksStartWithShuffleChunkPrefix(blockIds); + } + return true; + } + + // SPARK-40398: Replace `Arrays.stream().anyMatch()` with this method due to perf gain. + private static boolean isAnyBlockNotStartWithShuffleBlockPrefix(String[] blockIds) { + for (String blockId : blockIds) { + if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX)) { + return true; + } + } + return false; + } + + // SPARK-40398: Replace `Arrays.stream().allMatch()` with this method due to perf gain. + private static boolean isAllBlocksStartWithShuffleChunkPrefix(String[] blockIds) { + for (String blockId : blockIds) { + if (!blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) { + return false; + } + } return true; } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java index 76dfe73f666..25953ec32e4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java @@ -17,7 +17,9 @@ package org.apache.spark.sql.connector.expressions; -import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; import org.apache.spark.annotation.Evolving; @@ -30,6 +32,13 @@ import org.apache.spark.annotation.Evolving; public interface Expression { Expression[] EMPTY_EXPRESSION = new Expression[0]; + /** + * `EMPTY_NAMED_REFERENCE` is only used as an input when the + * default `references` method builds the result array to avoid + * repeatedly allocating an empty array.
+ */ + NamedReference[] EMPTY_NAMED_REFERENCE = new NamedReference[0]; + /** * Format the expression as a human readable SQL-like string. */ @@ -44,7 +53,12 @@ public interface Expression { * List of fields or columns that are referenced by this expression. */ default NamedReference[] references() { - return Arrays.stream(children()).map(e -> e.references()) - .flatMap(Arrays::stream).distinct().toArray(NamedReference[]::new); + // SPARK-40398: Replace `Arrays.stream()...distinct()` + // to this for perf gain, the result order is not important. + Set<NamedReference> set = new HashSet<>(); + for (Expression e : children()) { + Collections.addAll(set, e.references()); + } + return set.toArray(EMPTY_NAMED_REFERENCE); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomAvgMetric.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomAvgMetric.java index 71e83002dda..99ac3ac8d20 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomAvgMetric.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomAvgMetric.java @@ -19,7 +19,6 @@ package org.apache.spark.sql.connector.metric; import org.apache.spark.annotation.Evolving; -import java.util.Arrays; import java.text.DecimalFormat; /** @@ -33,7 +32,11 @@ public abstract class CustomAvgMetric implements CustomMetric { @Override public String aggregateTaskMetrics(long[] taskMetrics) { if (taskMetrics.length > 0) { - double average = ((double)Arrays.stream(taskMetrics).sum()) / taskMetrics.length; + long sum = 0L; + for (long taskMetric : taskMetrics) { + sum += taskMetric; + } + double average = ((double) sum) / taskMetrics.length; return new DecimalFormat("#0.000").format(average); } else { return "0"; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomSumMetric.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomSumMetric.java index 
ba28e9b9187..47d0ae9b805 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomSumMetric.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomSumMetric.java @@ -19,8 +19,6 @@ package org.apache.spark.sql.connector.metric; import org.apache.spark.annotation.Evolving; -import java.util.Arrays; - /** * Built-in `CustomMetric` that sums up metric values. Note that please extend this class * and override `name` and `description` to create your custom metric for real usage. @@ -31,6 +29,10 @@ import java.util.Arrays; public abstract class CustomSumMetric implements CustomMetric { @Override public String aggregateTaskMetrics(long[] taskMetrics) { - return String.valueOf(Arrays.stream(taskMetrics).sum()); + long sum = 0L; + for (long taskMetric : taskMetrics) { + sum += taskMetric; + } + return String.valueOf(sum); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java index 315b3309054..b32958d13da 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -17,10 +17,9 @@ package org.apache.spark.sql.connector.util; -import java.util.Arrays; +import java.util.ArrayList; import java.util.List; import java.util.StringJoiner; -import java.util.stream.Collectors; import org.apache.spark.sql.connector.expressions.Cast; import org.apache.spark.sql.connector.expressions.Expression; @@ -62,9 +61,9 @@ public class V2ExpressionSQLBuilder { String name = e.name(); switch (name) { case "IN": { - List<String> children = - Arrays.stream(e.children()).map(c -> build(c)).collect(Collectors.toList()); - return visitIn(children.get(0), children.subList(1, children.size())); + Expression[] expressions = e.children(); + List<String> children = 
expressionsToStringList(expressions, 1, expressions.length - 1); + return visitIn(build(expressions[0]), children); } case "IS_NULL": return visitIsNull(build(e.children()[0])); @@ -159,25 +158,18 @@ public class V2ExpressionSQLBuilder { case "BIT_LENGTH": case "CHAR_LENGTH": case "CONCAT": - return visitSQLFunction(name, - Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new)); + return visitSQLFunction(name, expressionsToStringArray(e.children())); case "CASE_WHEN": { - List<String> children = - Arrays.stream(e.children()).map(c -> build(c)).collect(Collectors.toList()); - return visitCaseWhen(children.toArray(new String[e.children().length])); + return visitCaseWhen(expressionsToStringArray(e.children())); } case "TRIM": - return visitTrim("BOTH", - Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new)); + return visitTrim("BOTH", expressionsToStringArray(e.children())); case "LTRIM": - return visitTrim("LEADING", - Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new)); + return visitTrim("LEADING", expressionsToStringArray(e.children())); case "RTRIM": - return visitTrim("TRAILING", - Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new)); + return visitTrim("TRAILING", expressionsToStringArray(e.children())); case "OVERLAY": - return visitOverlay( - Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new)); + return visitOverlay(expressionsToStringArray(e.children())); // TODO supports other expressions default: return visitUnexpectedExpr(expr); @@ -185,37 +177,37 @@ public class V2ExpressionSQLBuilder { } else if (expr instanceof Min) { Min min = (Min) expr; return visitAggregateFunction("MIN", false, - Arrays.stream(min.children()).map(c -> build(c)).toArray(String[]::new)); + expressionsToStringArray(min.children())); } else if (expr instanceof Max) { Max max = (Max) expr; return visitAggregateFunction("MAX", false, - Arrays.stream(max.children()).map(c -> 
build(c)).toArray(String[]::new)); + expressionsToStringArray(max.children())); } else if (expr instanceof Count) { Count count = (Count) expr; return visitAggregateFunction("COUNT", count.isDistinct(), - Arrays.stream(count.children()).map(c -> build(c)).toArray(String[]::new)); + expressionsToStringArray(count.children())); } else if (expr instanceof Sum) { Sum sum = (Sum) expr; return visitAggregateFunction("SUM", sum.isDistinct(), - Arrays.stream(sum.children()).map(c -> build(c)).toArray(String[]::new)); + expressionsToStringArray(sum.children())); } else if (expr instanceof CountStar) { return visitAggregateFunction("COUNT", false, new String[]{"*"}); } else if (expr instanceof Avg) { Avg avg = (Avg) expr; return visitAggregateFunction("AVG", avg.isDistinct(), - Arrays.stream(avg.children()).map(c -> build(c)).toArray(String[]::new)); + expressionsToStringArray(avg.children())); } else if (expr instanceof GeneralAggregateFunc) { GeneralAggregateFunc f = (GeneralAggregateFunc) expr; return visitAggregateFunction(f.name(), f.isDistinct(), - Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new)); + expressionsToStringArray(f.children())); } else if (expr instanceof UserDefinedScalarFunc) { UserDefinedScalarFunc f = (UserDefinedScalarFunc) expr; return visitUserDefinedScalarFunction(f.name(), f.canonicalName(), - Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new)); + expressionsToStringArray(f.children())); } else if (expr instanceof UserDefinedAggregateFunc) { UserDefinedAggregateFunc f = (UserDefinedAggregateFunc) expr; return visitUserDefinedAggregateFunction(f.name(), f.canonicalName(), f.isDistinct(), - Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new)); + expressionsToStringArray(f.children())); } else { return visitUnexpectedExpr(expr); } @@ -393,4 +385,22 @@ public class V2ExpressionSQLBuilder { } return joiner.toString(); } + + private String[] expressionsToStringArray(Expression[] expressions) { + 
String[] result = new String[expressions.length]; + for (int i = 0; i < expressions.length; i++) { + result[i] = build(expressions[i]); + } + return result; + } + + private List<String> expressionsToStringList(Expression[] expressions, int offset, int length) { + List<String> list = new ArrayList<>(length); + final int till = Math.min(offset + length, expressions.length); + while (offset < till) { + list.add(build(expressions[offset])); + offset++; + } + return list; + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala index de556c50f5d..a5fee51dc91 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala @@ -263,7 +263,7 @@ class V2PredicateSuite extends SparkFunSuite { new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))), new Predicate("=", Array[Expression](ref("b"), LiteralValue(1, IntegerType)))) assert(predicate1.equals(predicate2)) - assert(predicate1.references.map(_.describe()).toSeq == Seq("a", "b")) + assert(predicate1.references.map(_.describe()).toSeq.sorted == Seq("a", "b")) assert(predicate1.describe.equals("(a = 1) AND (b = 1)")) val v1Filter = V1And(EqualTo("a", 1), EqualTo("b", 1)) @@ -287,7 +287,7 @@ class V2PredicateSuite extends SparkFunSuite { new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))), new Predicate("=", Array[Expression](ref("b"), LiteralValue(1, IntegerType)))) assert(predicate1.equals(predicate2)) - assert(predicate1.references.map(_.describe()).toSeq == Seq("a", "b")) + assert(predicate1.references.map(_.describe()).toSeq.sorted == Seq("a", "b")) assert(predicate1.describe.equals("(a = 1) OR (b = 1)")) val v1Filter = V1Or(EqualTo("a", 1), EqualTo("b", 1)) 
--------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org