[
https://issues.apache.org/jira/browse/BEAM-5411?focusedWorklogId=145387&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-145387
]
ASF GitHub Bot logged work on BEAM-5411:
----------------------------------------
Author: ASF GitHub Bot
Created on: 18/Sep/18 17:27
Start Date: 18/Sep/18 17:27
Worklog Time Spent: 10m
Work Description: apilloud closed pull request #6415: [BEAM-5411]
Simplify BeamUnnest
URL: https://github.com/apache/beam/pull/6415
This is a PR merged from a forked repository.
Because GitHub hides the original diff of a foreign (forked) pull request
once it is merged, the diff is reproduced below for the sake of provenance:
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java
index df74a95b52c..d4549a3ca4c 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java
@@ -74,6 +74,7 @@
public abstract List<Object> getValues();
/** Get value by field name, {@link ClassCastException} is thrown if type
doesn't match. */
+ @Nullable
@SuppressWarnings("TypeParameterUnusedInFormals")
public <T> T getValue(String fieldName) {
return getValue(getSchema().indexOf(fieldName));
@@ -83,6 +84,7 @@
* Get a {@link TypeName#BYTE} value by field name, {@link
IllegalStateException} is thrown if
* schema doesn't match.
*/
+ @Nullable
public Byte getByte(String fieldName) {
return getByte(getSchema().indexOf(fieldName));
}
@@ -91,6 +93,7 @@ public Byte getByte(String fieldName) {
* Get a {@link TypeName#BYTES} value by field name, {@link
IllegalStateException} is thrown if
* schema doesn't match.
*/
+ @Nullable
public byte[] getBytes(String fieldName) {
return getBytes(getSchema().indexOf(fieldName));
}
@@ -99,6 +102,7 @@ public Byte getByte(String fieldName) {
* Get a {@link TypeName#INT16} value by field name, {@link
IllegalStateException} is thrown if
* schema doesn't match.
*/
+ @Nullable
public Short getInt16(String fieldName) {
return getInt16(getSchema().indexOf(fieldName));
}
@@ -107,6 +111,7 @@ public Short getInt16(String fieldName) {
* Get a {@link TypeName#INT32} value by field name, {@link
IllegalStateException} is thrown if
* schema doesn't match.
*/
+ @Nullable
public Integer getInt32(String fieldName) {
return getInt32(getSchema().indexOf(fieldName));
}
@@ -115,6 +120,7 @@ public Integer getInt32(String fieldName) {
* Get a {@link TypeName#INT64} value by field name, {@link
IllegalStateException} is thrown if
* schema doesn't match.
*/
+ @Nullable
public Long getInt64(String fieldName) {
return getInt64(getSchema().indexOf(fieldName));
}
@@ -123,6 +129,7 @@ public Long getInt64(String fieldName) {
* Get a {@link TypeName#DECIMAL} value by field name, {@link
IllegalStateException} is thrown if
* schema doesn't match.
*/
+ @Nullable
public BigDecimal getDecimal(String fieldName) {
return getDecimal(getSchema().indexOf(fieldName));
}
@@ -131,6 +138,7 @@ public BigDecimal getDecimal(String fieldName) {
* Get a {@link TypeName#FLOAT} value by field name, {@link
IllegalStateException} is thrown if
* schema doesn't match.
*/
+ @Nullable
public Float getFloat(String fieldName) {
return getFloat(getSchema().indexOf(fieldName));
}
@@ -139,6 +147,7 @@ public Float getFloat(String fieldName) {
* Get a {@link TypeName#DOUBLE} value by field name, {@link
IllegalStateException} is thrown if
* schema doesn't match.
*/
+ @Nullable
public Double getDouble(String fieldName) {
return getDouble(getSchema().indexOf(fieldName));
}
@@ -147,6 +156,7 @@ public Double getDouble(String fieldName) {
* Get a {@link TypeName#STRING} value by field name, {@link
IllegalStateException} is thrown if
* schema doesn't match.
*/
+ @Nullable
public String getString(String fieldName) {
return getString(getSchema().indexOf(fieldName));
}
@@ -155,7 +165,8 @@ public String getString(String fieldName) {
* Get a {@link TypeName#DATETIME} value by field name, {@link
IllegalStateException} is thrown if
* schema doesn't match.
*/
- public @Nullable ReadableDateTime getDateTime(String fieldName) {
+ @Nullable
+ public ReadableDateTime getDateTime(String fieldName) {
return getDateTime(getSchema().indexOf(fieldName));
}
@@ -163,6 +174,7 @@ public String getString(String fieldName) {
* Get a {@link TypeName#BOOLEAN} value by field name, {@link
IllegalStateException} is thrown if
* schema doesn't match.
*/
+ @Nullable
public Boolean getBoolean(String fieldName) {
return getBoolean(getSchema().indexOf(fieldName));
}
@@ -171,6 +183,7 @@ public Boolean getBoolean(String fieldName) {
* Get an array value by field name, {@link IllegalStateException} is thrown
if schema doesn't
* match.
*/
+ @Nullable
public <T> List<T> getArray(String fieldName) {
return getArray(getSchema().indexOf(fieldName));
}
@@ -178,6 +191,7 @@ public Boolean getBoolean(String fieldName) {
/**
* Get a MAP value by field name, {@link IllegalStateException} is thrown if
schema doesn't match.
*/
+ @Nullable
public <T1, T2> Map<T1, T2> getMap(String fieldName) {
return getMap(getSchema().indexOf(fieldName));
}
@@ -186,6 +200,7 @@ public Boolean getBoolean(String fieldName) {
* Get a {@link TypeName#ROW} value by field name, {@link
IllegalStateException} is thrown if
* schema doesn't match.
*/
+ @Nullable
public Row getRow(String fieldName) {
return getRow(getSchema().indexOf(fieldName));
}
@@ -194,6 +209,7 @@ public Row getRow(String fieldName) {
* Get a {@link TypeName#BYTE} value by field index, {@link
ClassCastException} is thrown if
* schema doesn't match.
*/
+ @Nullable
public Byte getByte(int idx) {
return getValue(idx);
}
@@ -202,6 +218,7 @@ public Byte getByte(int idx) {
* Get a {@link TypeName#BYTES} value by field index, {@link
ClassCastException} is thrown if
* schema doesn't match.
*/
+ @Nullable
public byte[] getBytes(int idx) {
return getValue(idx);
}
@@ -210,6 +227,7 @@ public Byte getByte(int idx) {
* Get a {@link TypeName#INT16} value by field index, {@link
ClassCastException} is thrown if
* schema doesn't match.
*/
+ @Nullable
public Short getInt16(int idx) {
return getValue(idx);
}
@@ -218,6 +236,7 @@ public Short getInt16(int idx) {
* Get a {@link TypeName#INT32} value by field index, {@link
ClassCastException} is thrown if
* schema doesn't match.
*/
+ @Nullable
public Integer getInt32(int idx) {
return getValue(idx);
}
@@ -226,6 +245,7 @@ public Integer getInt32(int idx) {
* Get a {@link TypeName#FLOAT} value by field index, {@link
ClassCastException} is thrown if
* schema doesn't match.
*/
+ @Nullable
public Float getFloat(int idx) {
return getValue(idx);
}
@@ -234,6 +254,7 @@ public Float getFloat(int idx) {
* Get a {@link TypeName#DOUBLE} value by field index, {@link
ClassCastException} is thrown if
* schema doesn't match.
*/
+ @Nullable
public Double getDouble(int idx) {
return getValue(idx);
}
@@ -242,6 +263,7 @@ public Double getDouble(int idx) {
* Get a {@link TypeName#INT64} value by field index, {@link
ClassCastException} is thrown if
* schema doesn't match.
*/
+ @Nullable
public Long getInt64(int idx) {
return getValue(idx);
}
@@ -250,6 +272,7 @@ public Long getInt64(int idx) {
* Get a {@link String} value by field index, {@link ClassCastException} is
thrown if schema
* doesn't match.
*/
+ @Nullable
public String getString(int idx) {
return getValue(idx);
}
@@ -258,7 +281,8 @@ public String getString(int idx) {
* Get a {@link TypeName#DATETIME} value by field index, {@link
IllegalStateException} is thrown
* if schema doesn't match.
*/
- public @Nullable ReadableDateTime getDateTime(int idx) {
+ @Nullable
+ public ReadableDateTime getDateTime(int idx) {
ReadableInstant instant = getValue(idx);
return instant == null ? null : new
DateTime(instant).withZone(instant.getZone());
}
@@ -267,6 +291,7 @@ public String getString(int idx) {
* Get a {@link BigDecimal} value by field index, {@link ClassCastException}
is thrown if schema
* doesn't match.
*/
+ @Nullable
public BigDecimal getDecimal(int idx) {
return getValue(idx);
}
@@ -275,6 +300,7 @@ public BigDecimal getDecimal(int idx) {
* Get a {@link Boolean} value by field index, {@link ClassCastException} is
thrown if schema
* doesn't match.
*/
+ @Nullable
public Boolean getBoolean(int idx) {
return getValue(idx);
}
@@ -283,6 +309,7 @@ public Boolean getBoolean(int idx) {
* Get an array value by field index, {@link IllegalStateException} is
thrown if schema doesn't
* match.
*/
+ @Nullable
public <T> List<T> getArray(int idx) {
return getValue(idx);
}
@@ -291,6 +318,7 @@ public Boolean getBoolean(int idx) {
* Get a MAP value by field index, {@link IllegalStateException} is thrown
if schema doesn't
* match.
*/
+ @Nullable
public <T1, T2> Map<T1, T2> getMap(int idx) {
return getValue(idx);
}
@@ -299,6 +327,7 @@ public Boolean getBoolean(int idx) {
* Get a {@link Row} value by field index, {@link IllegalStateException} is
thrown if schema
* doesn't match.
*/
+ @Nullable
public Row getRow(int idx) {
return getValue(idx);
}
diff --git
a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java
b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java
index 55fcc1aea0c..1d37e63a25a 100644
---
a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java
+++
b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java
@@ -17,67 +17,66 @@
*/
package org.apache.beam.sdk.extensions.sql.impl.rel;
-import static com.google.common.base.Preconditions.checkArgument;
-import static com.google.common.base.Preconditions.checkState;
-
import com.google.common.collect.ImmutableList;
import java.util.List;
import javax.annotation.Nullable;
-import
org.apache.beam.sdk.extensions.sql.impl.interpreter.BeamSqlExpressionEnvironment;
-import
org.apache.beam.sdk.extensions.sql.impl.interpreter.BeamSqlExpressionEnvironments;
-import
org.apache.beam.sdk.extensions.sql.impl.interpreter.BeamSqlExpressionExecutor;
-import org.apache.beam.sdk.extensions.sql.impl.interpreter.BeamSqlFnExecutor;
-import org.apache.beam.sdk.extensions.sql.impl.schema.BeamTableUtils;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
-import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.Row;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.core.Correlate;
-import org.apache.calcite.rel.core.CorrelationId;
+import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Uncollect;
-import org.apache.calcite.sql.SemiJoinType;
-import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.sql.validate.SqlValidatorUtil;
/**
* {@link BeamRelNode} to implement UNNEST, supporting specifically only
{@link Correlate} with
* {@link Uncollect}.
*/
-public class BeamUnnestRel extends Correlate implements BeamRelNode {
+public class BeamUnnestRel extends Uncollect implements BeamRelNode {
+
+ private final RelDataType unnestType;
+ private final int unnestIndex;
public BeamUnnestRel(
RelOptCluster cluster,
- RelTraitSet traits,
- RelNode left,
- RelNode right,
- CorrelationId correlationId,
- ImmutableBitSet requiredColumns,
- SemiJoinType joinType) {
- super(cluster, traits, left, right, correlationId, requiredColumns,
joinType);
+ RelTraitSet traitSet,
+ RelNode input,
+ RelDataType unnestType,
+ int unnestIndex) {
+ super(cluster, traitSet, input);
+ this.unnestType = unnestType;
+ this.unnestIndex = unnestIndex;
}
@Override
- public Correlate copy(
- RelTraitSet relTraitSet,
- RelNode left,
- RelNode right,
- CorrelationId correlationId,
- ImmutableBitSet requireColumns,
- SemiJoinType joinType) {
- return new BeamUnnestRel(
- getCluster(), relTraitSet, left, right, correlationId,
requiredColumns, joinType);
+ public Uncollect copy(RelTraitSet traitSet, RelNode input) {
+ return new BeamUnnestRel(getCluster(), traitSet, input, unnestType,
unnestIndex);
}
@Override
- public List<RelNode> getPCollectionInputs() {
- return ImmutableList.of(BeamSqlRelUtils.getBeamRelInput(left));
+ protected RelDataType deriveRowType() {
+ return SqlValidatorUtil.deriveJoinRowType(
+ input.getRowType(),
+ unnestType,
+ JoinRelType.INNER,
+ getCluster().getTypeFactory(),
+ null,
+ ImmutableList.of());
+ }
+
+ @Override
+ public RelWriter explainTerms(RelWriter pw) {
+ return super.explainTerms(pw).item("unnestIndex",
Integer.toString(unnestIndex));
}
@Override
@@ -91,82 +90,40 @@ public Correlate copy(
// The set of rows where we run the correlated unnest for each row
PCollection<Row> outer = pinput.get(0);
- // The correlated subquery
- BeamUncollectRel uncollect = (BeamUncollectRel)
BeamSqlRelUtils.getBeamRelInput(right);
- Schema innerSchema = CalciteUtils.toSchema(uncollect.getRowType());
- checkArgument(
- innerSchema.getFieldCount() == 1, "Can only UNNEST a single column",
getClass());
-
- BeamSqlExpressionExecutor expr =
- new BeamSqlFnExecutor(
- ((BeamCalcRel)
BeamSqlRelUtils.getBeamRelInput(uncollect.getInput())).getProgram());
-
Schema joinedSchema = CalciteUtils.toSchema(rowType);
return outer
- .apply(
- ParDo.of(
- new UnnestFn(correlationId.getId(), expr, joinedSchema,
innerSchema.getField(0))))
+ .apply(ParDo.of(new UnnestFn(joinedSchema, unnestIndex)))
.setRowSchema(joinedSchema);
}
}
private static class UnnestFn extends DoFn<Row, Row> {
- /** The expression that should return an iterable to be uncollected. */
- private final BeamSqlExpressionExecutor expr;
-
- private final int correlationId;
private final Schema outputSchema;
- private final Schema.Field innerField;
-
- private UnnestFn(
- int correlationId,
- BeamSqlExpressionExecutor expr,
- Schema outputSchema,
- Schema.Field innerField) {
- this.correlationId = correlationId;
- this.expr = expr;
+ private final int unnestIndex;
+
+ private UnnestFn(Schema outputSchema, int unnestIndex) {
this.outputSchema = outputSchema;
- this.innerField = innerField;
+ this.unnestIndex = unnestIndex;
}
@ProcessElement
- public void process(@Element Row row, BoundedWindow window,
OutputReceiver<Row> out) {
+ public void process(@Element Row row, OutputReceiver<Row> out) {
- checkState(correlationId == 0, "Only one level of correlation nesting is
supported");
- BeamSqlExpressionEnvironment env =
- BeamSqlExpressionEnvironments.forRowAndCorrelVariables(
- row, window, ImmutableList.of(row));
-
- @Nullable List<Object> rawValues = expr.execute(row, window, env);
+ @Nullable List<Object> rawValues = row.getArray(unnestIndex);
if (rawValues == null) {
return;
}
- checkState(
- rawValues.size() == 1,
- "%s expression to unnest %s resulted in more than one column",
- getClass(),
- expr);
-
- checkState(
- rawValues.get(0) instanceof Iterable,
- "%s expression to unnest %s not iterable",
- getClass(),
- expr);
-
- for (Object uncollectedValue : (Iterable) rawValues.get(0)) {
- Object coercedValue = BeamTableUtils.autoCastField(innerField,
uncollectedValue);
+ for (Object uncollectedValue : rawValues) {
out.output(
-
Row.withSchema(outputSchema).addValues(row.getValues()).addValue(coercedValue).build());
+ Row.withSchema(outputSchema)
+ .addValues(row.getValues())
+ .addValue(uncollectedValue)
+ .build());
}
}
-
- @Teardown
- public void close() {
- expr.close();
- }
}
}
diff --git
a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rule/BeamUncollectRule.java
b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rule/BeamUncollectRule.java
index 1a8872eb438..30f66efa8ec 100644
---
a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rule/BeamUncollectRule.java
+++
b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rule/BeamUncollectRule.java
@@ -24,9 +24,8 @@
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.convert.ConverterRule;
import org.apache.calcite.rel.core.Uncollect;
-import org.apache.calcite.rel.core.Union;
-/** A {@code ConverterRule} to replace {@link Union} with {@link
BeamUncollectRule}. */
+/** A {@code ConverterRule} to replace {@link Uncollect} with {@link
BeamUncollectRule}. */
public class BeamUncollectRule extends ConverterRule {
public static final BeamUncollectRule INSTANCE = new BeamUncollectRule();
diff --git
a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rule/BeamUnnestRule.java
b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rule/BeamUnnestRule.java
index 45c2bb423db..6ce4dce5c6a 100644
---
a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rule/BeamUnnestRule.java
+++
b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rule/BeamUnnestRule.java
@@ -25,11 +25,18 @@
import org.apache.calcite.plan.volcano.RelSubset;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.SingleRel;
+import org.apache.calcite.rel.core.Correlate;
import org.apache.calcite.rel.core.Uncollect;
-import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.logical.LogicalCorrelate;
+import org.apache.calcite.rel.logical.LogicalProject;
+import org.apache.calcite.rex.RexFieldAccess;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SemiJoinType;
-/** A {@code ConverterRule} to replace {@link Union} with {@link
BeamUnnestRule}. */
+/**
+ * A {@code ConverterRule} to replace {@link Correlate} {@link Uncollect} with
{@link
+ * BeamUnnestRule}.
+ */
public class BeamUnnestRule extends RelOptRule {
public static final BeamUnnestRule INSTANCE = new BeamUnnestRule();
@@ -46,6 +53,19 @@ public void onMatch(RelOptRuleCall call) {
LogicalCorrelate correlate = call.rel(0);
RelNode outer = call.rel(1);
RelNode uncollect = call.rel(2);
+
+ if (correlate.getCorrelationId().getId() != 0) {
+ // Only one level of correlation nesting is supported
+ return;
+ }
+ if (correlate.getRequiredColumns().cardinality() != 1) {
+ // can only unnest a single column
+ return;
+ }
+ if (correlate.getJoinType() != SemiJoinType.INNER) {
+ return;
+ }
+
if (!(uncollect instanceof Uncollect)) {
// Drop projection
uncollect = ((SingleRel) uncollect).getInput();
@@ -57,14 +77,31 @@ public void onMatch(RelOptRuleCall call) {
}
}
+ RelNode project = ((Uncollect) uncollect).getInput();
+ if (project instanceof RelSubset) {
+ project = ((RelSubset) project).getOriginal();
+ }
+ if (!(project instanceof LogicalProject)) {
+ return;
+ }
+
+ if (((LogicalProject) project).getProjects().size() != 1) {
+ // can only unnest a single column
+ return;
+ }
+
+ RexNode exp = ((LogicalProject) project).getProjects().get(0);
+ if (!(exp instanceof RexFieldAccess)) {
+ return;
+ }
+ int fieldIndex = ((RexFieldAccess) exp).getField().getIndex();
+
call.transformTo(
new BeamUnnestRel(
correlate.getCluster(),
correlate.getTraitSet().replace(BeamLogicalConvention.INSTANCE),
outer,
- convert(uncollect,
uncollect.getTraitSet().replace(BeamLogicalConvention.INSTANCE)),
- correlate.getCorrelationId(),
- correlate.getRequiredColumns(),
- correlate.getJoinType()));
+ call.rel(2).getRowType(),
+ fieldIndex));
}
}
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
Issue Time Tracking
-------------------
Worklog Id: (was: 145387)
Time Spent: 1h (was: 50m)
> Separate BeamUnnest and BeamCalc
> --------------------------------
>
> Key: BEAM-5411
> URL: https://issues.apache.org/jira/browse/BEAM-5411
> Project: Beam
> Issue Type: Improvement
> Components: dsl-sql
> Reporter: Andrew Pilloud
> Assignee: Andrew Pilloud
> Priority: Major
> Time Spent: 1h
> Remaining Estimate: 0h
>
> Currently, Correlated Uncollect (BeamUnnest) embeds a fork of BeamCalc. This
> isn't actually needed; simplifying this node enables easier replacement of
> Calc.
--
This message was sent by Atlassian JIRA
(v7.6.3#76005)