This is an automated email from the ASF dual-hosted git repository. apilloud pushed a commit to branch fad in repository https://gitbox.apache.org/repos/asf/beam.git
commit b96099e5949c7541d3d905f3efd22ec478bff4ad Author: Andrew Pilloud <[email protected]> AuthorDate: Thu Sep 2 11:53:57 2021 -0700 [BEAM-12691] FieldAccessDescriptor for BeamCalcRel --- .../sdk/extensions/sql/impl/rel/BeamCalcRel.java | 93 +++++++++------- .../extensions/sql/impl/rel/BeamCalcRelTest.java | 84 ++++++++++++++ .../extensions/sql/zetasql/BeamZetaSqlCalcRel.java | 43 ++++--- .../sql/zetasql/BeamZetaSqlCalcRelTest.java | 123 +++++++++++++++++++++ 4 files changed, 287 insertions(+), 56 deletions(-) diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java index c435695..543d322 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.extensions.sql.impl.rel; +import static org.apache.beam.sdk.schemas.Schema.Field; import static org.apache.beam.sdk.schemas.Schema.FieldType; import static org.apache.beam.vendor.calcite.v1_26_0.com.google.common.base.Preconditions.checkArgument; @@ -40,6 +41,7 @@ import java.util.List; import java.util.Map; import java.util.Set; import java.util.TimeZone; +import java.util.TreeSet; import java.util.stream.Collectors; import org.apache.beam.sdk.extensions.sql.impl.BeamSqlPipelineOptions; import org.apache.beam.sdk.extensions.sql.impl.JavaUdfLoader; @@ -48,6 +50,7 @@ import org.apache.beam.sdk.extensions.sql.impl.planner.BeamJavaTypeFactory; import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils; import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils.CharType; import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils.TimeWithLocalTzType; +import org.apache.beam.sdk.schemas.FieldAccessDescriptor; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes; import org.apache.beam.sdk.transforms.DoFn; @@ -157,15 +160,11 @@ public class BeamCalcRel extends AbstractBeamCalcRel { final RelOptPredicateList predicates = mq.getPulledUpPredicates(getInput()); final RexSimplify simplify = new RexSimplify(rexBuilder, predicates, RexUtil.EXECUTOR); final RexProgram program = getProgram().normalize(rexBuilder, simplify); + final InputGetterImpl inputGetter = new InputGetterImpl(rowParam, upstream.getSchema()); Expression condition = RexToLixTranslator.translateCondition( - program, - typeFactory, - builder, - new InputGetterImpl(rowParam, upstream.getSchema()), - null, - conformance); + program, typeFactory, builder, inputGetter, null, conformance); List<Expression> expressions = RexToLixTranslator.translateProjects( @@ -175,7 +174,7 @@ public class BeamCalcRel extends AbstractBeamCalcRel { builder, physType, DataContext.ROOT, - new InputGetterImpl(rowParam, upstream.getSchema()), + inputGetter, null); builder.add( @@ -192,10 +191,8 @@ public class BeamCalcRel extends AbstractBeamCalcRel { builder.toBlock().toString(), outputSchema, options.getVerifyRowValues(), - getJarPaths(program)); - - // validate generated code - calcFn.compile(); + getJarPaths(program), + inputGetter.getFieldAccess()); return upstream.apply(ParDo.of(calcFn)).setRowSchema(outputSchema); } @@ -207,20 +204,29 @@ public class BeamCalcRel extends AbstractBeamCalcRel { private final Schema outputSchema; private final boolean verifyRowValues; private final List<String> jarPaths; + + @FieldAccess("row") + private final FieldAccessDescriptor fieldAccess; + private transient @Nullable ScriptEvaluator se = null; public CalcFn( String processElementBlock, Schema outputSchema, boolean verifyRowValues, - List<String> jarPaths) { + List<String> jarPaths, + FieldAccessDescriptor fieldAccess) { this.processElementBlock = processElementBlock; this.outputSchema = outputSchema; this.verifyRowValues = verifyRowValues; this.jarPaths = jarPaths; + this.fieldAccess = fieldAccess; + + // validate generated code + compile(processElementBlock, jarPaths); } - ScriptEvaluator compile() { + private static ScriptEvaluator compile(String processElementBlock, List<String> jarPaths) { ScriptEvaluator se = new ScriptEvaluator(); if (!jarPaths.isEmpty()) { try { @@ -246,22 +252,22 @@ public class BeamCalcRel extends AbstractBeamCalcRel { @Setup public void setup() { - this.se = compile(); + this.se = compile(processElementBlock, jarPaths); } @ProcessElement - public void processElement(ProcessContext c) { + public void processElement(@FieldAccess("row") Row row, OutputReceiver<Row> r) { assert se != null; final Object[] v; try { - v = (Object[]) se.evaluate(new Object[] {c.element(), CONTEXT_INSTANCE}); + v = (Object[]) se.evaluate(new Object[] {row, CONTEXT_INSTANCE}); } catch (InvocationTargetException e) { throw new RuntimeException( "CalcFn failed to evaluate: " + processElementBlock, e.getCause()); } if (v != null) { - Row row = toBeamRow(Arrays.asList(v), outputSchema, verifyRowValues); - c.output(row); + final Row output = toBeamRow(Arrays.asList(v), outputSchema, verifyRowValues); + r.output(output); } } } @@ -411,14 +417,21 @@ public class BeamCalcRel extends AbstractBeamCalcRel { private final Expression input; private final Schema inputSchema; + private final Set<Integer> referencedColumns; private InputGetterImpl(Expression input, Schema inputSchema) { this.input = input; this.inputSchema = inputSchema; + this.referencedColumns = new TreeSet<>(); + } + + FieldAccessDescriptor getFieldAccess() { + return FieldAccessDescriptor.withFieldIds(this.referencedColumns); } @Override public Expression field(BlockBuilder list, int index, Type storageType) { + this.referencedColumns.add(index); return getBeamField(list, index, input, inputSchema); } @@ -431,64 +444,66 @@ public class BeamCalcRel extends AbstractBeamCalcRel { final Expression expression = list.append(list.newName("current"), input); - FieldType fieldType = schema.getField(index).getType(); - Expression value; + final Field field = schema.getField(index); + final FieldType fieldType = field.getType(); + final Expression fieldName = Expressions.constant(field.getName()); + final Expression value; switch (fieldType.getTypeName()) { case BYTE: - value = Expressions.call(expression, "getByte", Expressions.constant(index)); + value = Expressions.call(expression, "getByte", fieldName); break; case INT16: - value = Expressions.call(expression, "getInt16", Expressions.constant(index)); + value = Expressions.call(expression, "getInt16", fieldName); break; case INT32: - value = Expressions.call(expression, "getInt32", Expressions.constant(index)); + value = Expressions.call(expression, "getInt32", fieldName); break; case INT64: - value = Expressions.call(expression, "getInt64", Expressions.constant(index)); + value = Expressions.call(expression, "getInt64", fieldName); break; case DECIMAL: - value = Expressions.call(expression, "getDecimal", Expressions.constant(index)); + value = Expressions.call(expression, "getDecimal", fieldName); break; case FLOAT: - value = Expressions.call(expression, "getFloat", Expressions.constant(index)); + value = Expressions.call(expression, "getFloat", fieldName); break; case DOUBLE: - value = Expressions.call(expression, "getDouble", Expressions.constant(index)); + value = Expressions.call(expression, "getDouble", fieldName); break; case STRING: - value = Expressions.call(expression, "getString", Expressions.constant(index)); + value = Expressions.call(expression, "getString", fieldName); break; case DATETIME: - value = Expressions.call(expression, "getDateTime", Expressions.constant(index)); + value = Expressions.call(expression, "getDateTime", fieldName); break; case BOOLEAN: - value = Expressions.call(expression, "getBoolean", Expressions.constant(index)); + value = Expressions.call(expression, "getBoolean", fieldName); break; case BYTES: - value = Expressions.call(expression, "getBytes", Expressions.constant(index)); + value = Expressions.call(expression, "getBytes", fieldName); break; case ARRAY: - value = Expressions.call(expression, "getArray", Expressions.constant(index)); + value = Expressions.call(expression, "getArray", fieldName); break; case MAP: - value = Expressions.call(expression, "getMap", Expressions.constant(index)); + value = Expressions.call(expression, "getMap", fieldName); break; case ROW: - value = Expressions.call(expression, "getRow", Expressions.constant(index)); + value = Expressions.call(expression, "getRow", fieldName); break; case LOGICAL_TYPE: String identifier = fieldType.getLogicalType().getIdentifier(); if (CharType.IDENTIFIER.equals(identifier)) { - value = Expressions.call(expression, "getString", Expressions.constant(index)); + value = Expressions.call(expression, "getString", fieldName); } else if (TimeWithLocalTzType.IDENTIFIER.equals(identifier)) { - value = Expressions.call(expression, "getDateTime", Expressions.constant(index)); + value = Expressions.call(expression, "getDateTime", fieldName); } else if (SqlTypes.DATE.getIdentifier().equals(identifier)) { value = Expressions.convert_( Expressions.call( expression, "getLogicalTypeValue", - Expressions.constant(index), + fieldName, Expressions.constant(LocalDate.class)), LocalDate.class); } else if (SqlTypes.TIME.getIdentifier().equals(identifier)) { @@ -497,7 +512,7 @@ public class BeamCalcRel extends AbstractBeamCalcRel { Expressions.call( expression, "getLogicalTypeValue", - Expressions.constant(index), + fieldName, Expressions.constant(LocalTime.class)), LocalTime.class); } else if (SqlTypes.DATETIME.getIdentifier().equals(identifier)) { @@ -506,7 +521,7 @@ public class BeamCalcRel extends AbstractBeamCalcRel { Expressions.call( expression, "getLogicalTypeValue", - Expressions.constant(index), + fieldName, Expressions.constant(LocalDateTime.class)), LocalDateTime.class); } else { diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRelTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRelTest.java index 27baad3..656a514 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRelTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRelTest.java @@ -18,20 +18,34 @@ package org.apache.beam.sdk.extensions.sql.impl.rel; import java.math.BigDecimal; +import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.extensions.sql.impl.BeamTableStatistics; import org.apache.beam.sdk.extensions.sql.impl.planner.NodeStats; import org.apache.beam.sdk.extensions.sql.meta.provider.test.TestBoundedTable; import org.apache.beam.sdk.extensions.sql.meta.provider.test.TestUnboundedTable; +import org.apache.beam.sdk.runners.TransformHierarchy; +import org.apache.beam.sdk.schemas.FieldAccessDescriptor; import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature; +import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.RelNode; import org.joda.time.DateTime; import org.joda.time.Duration; import org.junit.Assert; import org.junit.BeforeClass; +import org.junit.Rule; import org.junit.Test; /** Tests related to {@code BeamCalcRel}. */ public class BeamCalcRelTest extends BaseRelTest { + + @Rule public final TestPipeline pipeline = TestPipeline.create(); + private static final DateTime FIRST_DATE = new DateTime(1); private static final DateTime SECOND_DATE = new DateTime(1 + 3600 * 1000); @@ -160,4 +174,74 @@ public class BeamCalcRelTest extends BaseRelTest { Assert.assertTrue(doubleEqualEstimate.getRowCount() < equalEstimate.getRowCount()); Assert.assertTrue(doubleEqualEstimate.getWindow() < equalEstimate.getWindow()); } + + private static class NodeGetter extends Pipeline.PipelineVisitor.Defaults { + + private final PValue target; + private TransformHierarchy.Node producer; + + private NodeGetter(PValue target) { + this.target = target; + } + + @Override + public void visitValue(PValue value, TransformHierarchy.Node producer) { + if (value == target) { + assert this.producer == null; + this.producer = producer; + } + } + } + + @Test + public void testSingleFieldAccess() throws IllegalAccessException { + String sql = "SELECT order_id FROM ORDER_DETAILS_BOUNDED"; + + PCollection<Row> rows = compilePipeline(sql, pipeline); + + final NodeGetter nodeGetter = new NodeGetter(rows); + pipeline.traverseTopologically(nodeGetter); + + ParDo.MultiOutput<Row, Row> pardo = + (ParDo.MultiOutput<Row, Row>) nodeGetter.producer.getTransform(); + DoFnSignature sig = DoFnSignatures.getSignature(pardo.getFn().getClass()); + + Assert.assertEquals(1, sig.fieldAccessDeclarations().size()); + DoFnSignature.FieldAccessDeclaration dec = + sig.fieldAccessDeclarations().values().iterator().next(); + FieldAccessDescriptor fieldAccess = (FieldAccessDescriptor) dec.field().get(pardo.getFn()); + + Assert.assertTrue(fieldAccess.referencesSingleField()); + + fieldAccess = + fieldAccess.resolve(nodeGetter.producer.getInputs().values().iterator().next().getSchema()); + Assert.assertEquals("order_id", fieldAccess.fieldNamesAccessed().iterator().next()); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testNoFieldAccess() throws IllegalAccessException { + String sql = "SELECT 1 FROM ORDER_DETAILS_BOUNDED"; + + PCollection<Row> rows = compilePipeline(sql, pipeline); + + final NodeGetter nodeGetter = new NodeGetter(rows); + pipeline.traverseTopologically(nodeGetter); + + ParDo.MultiOutput<Row, Row> pardo = + (ParDo.MultiOutput<Row, Row>) nodeGetter.producer.getTransform(); + DoFnSignature sig = DoFnSignatures.getSignature(pardo.getFn().getClass()); + + Assert.assertEquals(1, sig.fieldAccessDeclarations().size()); + DoFnSignature.FieldAccessDeclaration dec = + sig.fieldAccessDeclarations().values().iterator().next(); + FieldAccessDescriptor fieldAccess = (FieldAccessDescriptor) dec.field().get(pardo.getFn()); + + Assert.assertFalse(fieldAccess.getAllFields()); + Assert.assertTrue(fieldAccess.getFieldsAccessed().isEmpty()); + Assert.assertTrue(fieldAccess.getNestedFieldsAccessed().isEmpty()); + + pipeline.run().waitUntilFinish(); + } } diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRel.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRel.java index 38c604e..1d93ed3 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRel.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRel.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.extensions.sql.zetasql; +import static org.apache.beam.sdk.schemas.Schema.Field; import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull; import com.google.auto.value.AutoValue; @@ -39,6 +40,7 @@ import org.apache.beam.sdk.extensions.sql.impl.rel.AbstractBeamCalcRel; import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils; import org.apache.beam.sdk.extensions.sql.meta.provider.bigquery.BeamBigQuerySqlDialect; import org.apache.beam.sdk.extensions.sql.meta.provider.bigquery.BeamSqlUnparseContext; +import org.apache.beam.sdk.schemas.FieldAccessDescriptor; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; @@ -142,9 +144,6 @@ public class BeamZetaSqlCalcRel extends AbstractBeamCalcRel { options.getZetaSqlDefaultTimezone(), options.getVerifyRowValues()); - // validate prepared expressions - calcFn.setup(); - return upstream.apply(ParDo.of(calcFn)).setRowSchema(outputSchema); } } @@ -171,7 +170,11 @@ public class BeamZetaSqlCalcRel extends AbstractBeamCalcRel { private final Schema outputSchema; private final String defaultTimezone; private final boolean verifyRowValues; - private transient List<Integer> referencedColumns = ImmutableList.of(); + private final List<Integer> referencedColumns; + + @FieldAccess("row") + private final FieldAccessDescriptor fieldAccess; + private transient Map<BoundedWindow, Queue<TimestampedFuture>> pending = new HashMap<>(); private transient PreparedExpression exp; private transient PreparedExpression.@Nullable Stream stream; @@ -190,10 +193,21 @@ public class BeamZetaSqlCalcRel extends AbstractBeamCalcRel { this.outputSchema = outputSchema; this.defaultTimezone = defaultTimezone; this.verifyRowValues = verifyRowValues; + + try (PreparedExpression exp = + prepareExpression(sql, nullParams, inputSchema, defaultTimezone)) { + ImmutableList.Builder<Integer> columns = new ImmutableList.Builder<>(); + for (String c : exp.getReferencedColumns()) { + columns.add(Integer.parseInt(c.substring(1))); + } + this.referencedColumns = columns.build(); + this.fieldAccess = FieldAccessDescriptor.withFieldIds(this.referencedColumns); + } } /** exp cannot be reused and is transient so needs to be reinitialized. */ - private void prepareExpression() { + private static PreparedExpression prepareExpression( + String sql, Map<String, Value> nullParams, Schema inputSchema, String defaultTimezone) { AnalyzerOptions options = SqlAnalyzer.getAnalyzerOptions(QueryParameters.ofNamed(nullParams), defaultTimezone); for (int i = 0; i < inputSchema.getFieldCount(); i++) { @@ -202,21 +216,15 @@ public class BeamZetaSqlCalcRel extends AbstractBeamCalcRel { ZetaSqlBeamTranslationUtils.toZetaSqlType(inputSchema.getField(i).getType())); } - exp = new PreparedExpression(sql); + PreparedExpression exp = new PreparedExpression(sql); exp.prepare(options); + return exp; } @Setup public void setup() { - prepareExpression(); - - ImmutableList.Builder<Integer> columns = new ImmutableList.Builder<>(); - for (String c : exp.getReferencedColumns()) { - columns.add(Integer.parseInt(c.substring(1))); - } - referencedColumns = columns.build(); - - stream = exp.stream(); + this.exp = prepareExpression(sql, nullParams, inputSchema, defaultTimezone); + this.stream = exp.stream(); } @StartBundle @@ -231,14 +239,15 @@ public class BeamZetaSqlCalcRel extends AbstractBeamCalcRel { @ProcessElement public void processElement( - @Element Row row, @Timestamp Instant t, BoundedWindow w, OutputReceiver<Row> r) + @FieldAccess("row") Row row, @Timestamp Instant t, BoundedWindow w, OutputReceiver<Row> r) throws InterruptedException { Map<String, Value> columns = new HashMap<>(); for (int i : referencedColumns) { + final Field field = inputSchema.getField(i); columns.put( columnName(i), ZetaSqlBeamTranslationUtils.toZetaSqlValue( - row.getBaseValue(i, Object.class), inputSchema.getField(i).getType())); + row.getBaseValue(field.getName(), Object.class), field.getType())); } @NonNull diff --git a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRelTest.java b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRelTest.java new file mode 100644 index 0000000..352e83a --- /dev/null +++ b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRelTest.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.sql.zetasql; + +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.extensions.sql.impl.QueryPlanner.QueryParameters; +import org.apache.beam.sdk.extensions.sql.impl.rel.BeamRelNode; +import org.apache.beam.sdk.extensions.sql.impl.rel.BeamSqlRelUtils; +import org.apache.beam.sdk.runners.TransformHierarchy; +import org.apache.beam.sdk.schemas.FieldAccessDescriptor; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature; +import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PValue; +import org.apache.beam.sdk.values.Row; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; + +/** Tests related to {@code BeamZetaSqlCalcRel}. */ +public class BeamZetaSqlCalcRelTest extends ZetaSqlTestBase { + + private PCollection<Row> compile(String sql) { + ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config); + BeamRelNode beamRelNode = zetaSQLQueryPlanner.convertToBeamRel(sql, QueryParameters.ofNone()); + return BeamSqlRelUtils.toPCollection(pipeline, beamRelNode); + } + + @Rule public final TestPipeline pipeline = TestPipeline.create(); + + @Before + public void setUp() { + initialize(); + } + + private static class NodeGetter extends Pipeline.PipelineVisitor.Defaults { + + private final PValue target; + private TransformHierarchy.Node producer; + + private NodeGetter(PValue target) { + this.target = target; + } + + @Override + public void visitValue(PValue value, TransformHierarchy.Node producer) { + if (value == target) { + assert this.producer == null; + this.producer = producer; + } + } + } + + @Test + public void testSingleFieldAccess() throws IllegalAccessException { + String sql = "SELECT Key FROM KeyValue"; + + PCollection<Row> rows = compile(sql); + + final NodeGetter nodeGetter = new NodeGetter(rows); + pipeline.traverseTopologically(nodeGetter); + + ParDo.MultiOutput<Row, Row> pardo = + (ParDo.MultiOutput<Row, Row>) nodeGetter.producer.getTransform(); + DoFnSignature sig = DoFnSignatures.getSignature(pardo.getFn().getClass()); + + Assert.assertEquals(1, sig.fieldAccessDeclarations().size()); + DoFnSignature.FieldAccessDeclaration dec = + sig.fieldAccessDeclarations().values().iterator().next(); + FieldAccessDescriptor fieldAccess = (FieldAccessDescriptor) dec.field().get(pardo.getFn()); + + Assert.assertTrue(fieldAccess.referencesSingleField()); + + fieldAccess = + fieldAccess.resolve(nodeGetter.producer.getInputs().values().iterator().next().getSchema()); + Assert.assertEquals("Key", fieldAccess.fieldNamesAccessed().iterator().next()); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testNoFieldAccess() throws IllegalAccessException { + String sql = "SELECT 1 FROM KeyValue"; + + PCollection<Row> rows = compile(sql); + + final NodeGetter nodeGetter = new NodeGetter(rows); + pipeline.traverseTopologically(nodeGetter); + + ParDo.MultiOutput<Row, Row> pardo = + (ParDo.MultiOutput<Row, Row>) nodeGetter.producer.getTransform(); + DoFnSignature sig = DoFnSignatures.getSignature(pardo.getFn().getClass()); + + Assert.assertEquals(1, sig.fieldAccessDeclarations().size()); + DoFnSignature.FieldAccessDeclaration dec = + sig.fieldAccessDeclarations().values().iterator().next(); + FieldAccessDescriptor fieldAccess = (FieldAccessDescriptor) dec.field().get(pardo.getFn()); + + Assert.assertFalse(fieldAccess.getAllFields()); + Assert.assertTrue(fieldAccess.getFieldsAccessed().isEmpty()); + Assert.assertTrue(fieldAccess.getNestedFieldsAccessed().isEmpty()); + + pipeline.run().waitUntilFinish(); + } +}
