This is an automated email from the ASF dual-hosted git repository.

apilloud pushed a commit to branch fad
in repository https://gitbox.apache.org/repos/asf/beam.git

commit b96099e5949c7541d3d905f3efd22ec478bff4ad
Author: Andrew Pilloud <[email protected]>
AuthorDate: Thu Sep 2 11:53:57 2021 -0700

    [BEAM-12691] FieldAccessDescriptor for BeamCalcRel
---
 .../sdk/extensions/sql/impl/rel/BeamCalcRel.java   |  93 +++++++++-------
 .../extensions/sql/impl/rel/BeamCalcRelTest.java   |  84 ++++++++++++++
 .../extensions/sql/zetasql/BeamZetaSqlCalcRel.java |  43 ++++---
 .../sql/zetasql/BeamZetaSqlCalcRelTest.java        | 123 +++++++++++++++++++++
 4 files changed, 287 insertions(+), 56 deletions(-)

diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java
index c435695..543d322 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRel.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.sdk.extensions.sql.impl.rel;
 
+import static org.apache.beam.sdk.schemas.Schema.Field;
 import static org.apache.beam.sdk.schemas.Schema.FieldType;
 import static 
org.apache.beam.vendor.calcite.v1_26_0.com.google.common.base.Preconditions.checkArgument;
 
@@ -40,6 +41,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.TimeZone;
+import java.util.TreeSet;
 import java.util.stream.Collectors;
 import org.apache.beam.sdk.extensions.sql.impl.BeamSqlPipelineOptions;
 import org.apache.beam.sdk.extensions.sql.impl.JavaUdfLoader;
@@ -48,6 +50,7 @@ import org.apache.beam.sdk.extensions.sql.impl.planner.BeamJavaTypeFactory;
 import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
 import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils.CharType;
 import 
org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils.TimeWithLocalTzType;
+import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes;
 import org.apache.beam.sdk.transforms.DoFn;
@@ -157,15 +160,11 @@ public class BeamCalcRel extends AbstractBeamCalcRel {
       final RelOptPredicateList predicates = 
mq.getPulledUpPredicates(getInput());
       final RexSimplify simplify = new RexSimplify(rexBuilder, predicates, 
RexUtil.EXECUTOR);
       final RexProgram program = getProgram().normalize(rexBuilder, simplify);
+      final InputGetterImpl inputGetter = new InputGetterImpl(rowParam, 
upstream.getSchema());
 
       Expression condition =
           RexToLixTranslator.translateCondition(
-              program,
-              typeFactory,
-              builder,
-              new InputGetterImpl(rowParam, upstream.getSchema()),
-              null,
-              conformance);
+              program, typeFactory, builder, inputGetter, null, conformance);
 
       List<Expression> expressions =
           RexToLixTranslator.translateProjects(
@@ -175,7 +174,7 @@ public class BeamCalcRel extends AbstractBeamCalcRel {
               builder,
               physType,
               DataContext.ROOT,
-              new InputGetterImpl(rowParam, upstream.getSchema()),
+              inputGetter,
               null);
 
       builder.add(
@@ -192,10 +191,8 @@ public class BeamCalcRel extends AbstractBeamCalcRel {
               builder.toBlock().toString(),
               outputSchema,
               options.getVerifyRowValues(),
-              getJarPaths(program));
-
-      // validate generated code
-      calcFn.compile();
+              getJarPaths(program),
+              inputGetter.getFieldAccess());
 
       return upstream.apply(ParDo.of(calcFn)).setRowSchema(outputSchema);
     }
@@ -207,20 +204,29 @@ public class BeamCalcRel extends AbstractBeamCalcRel {
     private final Schema outputSchema;
     private final boolean verifyRowValues;
     private final List<String> jarPaths;
+
+    @FieldAccess("row")
+    private final FieldAccessDescriptor fieldAccess;
+
     private transient @Nullable ScriptEvaluator se = null;
 
     public CalcFn(
         String processElementBlock,
         Schema outputSchema,
         boolean verifyRowValues,
-        List<String> jarPaths) {
+        List<String> jarPaths,
+        FieldAccessDescriptor fieldAccess) {
       this.processElementBlock = processElementBlock;
       this.outputSchema = outputSchema;
       this.verifyRowValues = verifyRowValues;
       this.jarPaths = jarPaths;
+      this.fieldAccess = fieldAccess;
+
+      // validate generated code
+      compile(processElementBlock, jarPaths);
     }
 
-    ScriptEvaluator compile() {
+    private static ScriptEvaluator compile(String processElementBlock, 
List<String> jarPaths) {
       ScriptEvaluator se = new ScriptEvaluator();
       if (!jarPaths.isEmpty()) {
         try {
@@ -246,22 +252,22 @@ public class BeamCalcRel extends AbstractBeamCalcRel {
 
     @Setup
     public void setup() {
-      this.se = compile();
+      this.se = compile(processElementBlock, jarPaths);
     }
 
     @ProcessElement
-    public void processElement(ProcessContext c) {
+    public void processElement(@FieldAccess("row") Row row, 
OutputReceiver<Row> r) {
       assert se != null;
       final Object[] v;
       try {
-        v = (Object[]) se.evaluate(new Object[] {c.element(), 
CONTEXT_INSTANCE});
+        v = (Object[]) se.evaluate(new Object[] {row, CONTEXT_INSTANCE});
       } catch (InvocationTargetException e) {
         throw new RuntimeException(
             "CalcFn failed to evaluate: " + processElementBlock, e.getCause());
       }
       if (v != null) {
-        Row row = toBeamRow(Arrays.asList(v), outputSchema, verifyRowValues);
-        c.output(row);
+        final Row output = toBeamRow(Arrays.asList(v), outputSchema, 
verifyRowValues);
+        r.output(output);
       }
     }
   }
@@ -411,14 +417,21 @@ public class BeamCalcRel extends AbstractBeamCalcRel {
 
     private final Expression input;
     private final Schema inputSchema;
+    private final Set<Integer> referencedColumns;
 
     private InputGetterImpl(Expression input, Schema inputSchema) {
       this.input = input;
       this.inputSchema = inputSchema;
+      this.referencedColumns = new TreeSet<>();
+    }
+
+    FieldAccessDescriptor getFieldAccess() {
+      return FieldAccessDescriptor.withFieldIds(this.referencedColumns);
     }
 
     @Override
     public Expression field(BlockBuilder list, int index, Type storageType) {
+      this.referencedColumns.add(index);
       return getBeamField(list, index, input, inputSchema);
     }
 
@@ -431,64 +444,66 @@ public class BeamCalcRel extends AbstractBeamCalcRel {
 
       final Expression expression = list.append(list.newName("current"), 
input);
 
-      FieldType fieldType = schema.getField(index).getType();
-      Expression value;
+      final Field field = schema.getField(index);
+      final FieldType fieldType = field.getType();
+      final Expression fieldName = Expressions.constant(field.getName());
+      final Expression value;
       switch (fieldType.getTypeName()) {
         case BYTE:
-          value = Expressions.call(expression, "getByte", 
Expressions.constant(index));
+          value = Expressions.call(expression, "getByte", fieldName);
           break;
         case INT16:
-          value = Expressions.call(expression, "getInt16", 
Expressions.constant(index));
+          value = Expressions.call(expression, "getInt16", fieldName);
           break;
         case INT32:
-          value = Expressions.call(expression, "getInt32", 
Expressions.constant(index));
+          value = Expressions.call(expression, "getInt32", fieldName);
           break;
         case INT64:
-          value = Expressions.call(expression, "getInt64", 
Expressions.constant(index));
+          value = Expressions.call(expression, "getInt64", fieldName);
           break;
         case DECIMAL:
-          value = Expressions.call(expression, "getDecimal", 
Expressions.constant(index));
+          value = Expressions.call(expression, "getDecimal", fieldName);
           break;
         case FLOAT:
-          value = Expressions.call(expression, "getFloat", 
Expressions.constant(index));
+          value = Expressions.call(expression, "getFloat", fieldName);
           break;
         case DOUBLE:
-          value = Expressions.call(expression, "getDouble", 
Expressions.constant(index));
+          value = Expressions.call(expression, "getDouble", fieldName);
           break;
         case STRING:
-          value = Expressions.call(expression, "getString", 
Expressions.constant(index));
+          value = Expressions.call(expression, "getString", fieldName);
           break;
         case DATETIME:
-          value = Expressions.call(expression, "getDateTime", 
Expressions.constant(index));
+          value = Expressions.call(expression, "getDateTime", fieldName);
           break;
         case BOOLEAN:
-          value = Expressions.call(expression, "getBoolean", 
Expressions.constant(index));
+          value = Expressions.call(expression, "getBoolean", fieldName);
           break;
         case BYTES:
-          value = Expressions.call(expression, "getBytes", 
Expressions.constant(index));
+          value = Expressions.call(expression, "getBytes", fieldName);
           break;
         case ARRAY:
-          value = Expressions.call(expression, "getArray", 
Expressions.constant(index));
+          value = Expressions.call(expression, "getArray", fieldName);
           break;
         case MAP:
-          value = Expressions.call(expression, "getMap", 
Expressions.constant(index));
+          value = Expressions.call(expression, "getMap", fieldName);
           break;
         case ROW:
-          value = Expressions.call(expression, "getRow", 
Expressions.constant(index));
+          value = Expressions.call(expression, "getRow", fieldName);
           break;
         case LOGICAL_TYPE:
           String identifier = fieldType.getLogicalType().getIdentifier();
           if (CharType.IDENTIFIER.equals(identifier)) {
-            value = Expressions.call(expression, "getString", 
Expressions.constant(index));
+            value = Expressions.call(expression, "getString", fieldName);
           } else if (TimeWithLocalTzType.IDENTIFIER.equals(identifier)) {
-            value = Expressions.call(expression, "getDateTime", 
Expressions.constant(index));
+            value = Expressions.call(expression, "getDateTime", fieldName);
           } else if (SqlTypes.DATE.getIdentifier().equals(identifier)) {
             value =
                 Expressions.convert_(
                     Expressions.call(
                         expression,
                         "getLogicalTypeValue",
-                        Expressions.constant(index),
+                        fieldName,
                         Expressions.constant(LocalDate.class)),
                     LocalDate.class);
           } else if (SqlTypes.TIME.getIdentifier().equals(identifier)) {
@@ -497,7 +512,7 @@ public class BeamCalcRel extends AbstractBeamCalcRel {
                     Expressions.call(
                         expression,
                         "getLogicalTypeValue",
-                        Expressions.constant(index),
+                        fieldName,
                         Expressions.constant(LocalTime.class)),
                     LocalTime.class);
           } else if (SqlTypes.DATETIME.getIdentifier().equals(identifier)) {
@@ -506,7 +521,7 @@ public class BeamCalcRel extends AbstractBeamCalcRel {
                     Expressions.call(
                         expression,
                         "getLogicalTypeValue",
-                        Expressions.constant(index),
+                        fieldName,
                         Expressions.constant(LocalDateTime.class)),
                     LocalDateTime.class);
           } else {
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRelTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRelTest.java
index 27baad3..656a514 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRelTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamCalcRelTest.java
@@ -18,20 +18,34 @@
 package org.apache.beam.sdk.extensions.sql.impl.rel;
 
 import java.math.BigDecimal;
+import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.extensions.sql.impl.BeamTableStatistics;
 import org.apache.beam.sdk.extensions.sql.impl.planner.NodeStats;
 import org.apache.beam.sdk.extensions.sql.meta.provider.test.TestBoundedTable;
 import 
org.apache.beam.sdk.extensions.sql.meta.provider.test.TestUnboundedTable;
+import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
 import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.Row;
 import org.apache.beam.vendor.calcite.v1_26_0.org.apache.calcite.rel.RelNode;
 import org.joda.time.DateTime;
 import org.joda.time.Duration;
 import org.junit.Assert;
 import org.junit.BeforeClass;
+import org.junit.Rule;
 import org.junit.Test;
 
 /** Tests related to {@code BeamCalcRel}. */
 public class BeamCalcRelTest extends BaseRelTest {
+
+  @Rule public final TestPipeline pipeline = TestPipeline.create();
+
   private static final DateTime FIRST_DATE = new DateTime(1);
   private static final DateTime SECOND_DATE = new DateTime(1 + 3600 * 1000);
 
@@ -160,4 +174,74 @@ public class BeamCalcRelTest extends BaseRelTest {
     Assert.assertTrue(doubleEqualEstimate.getRowCount() < 
equalEstimate.getRowCount());
     Assert.assertTrue(doubleEqualEstimate.getWindow() < 
equalEstimate.getWindow());
   }
+
+  private static class NodeGetter extends Pipeline.PipelineVisitor.Defaults {
+
+    private final PValue target;
+    private TransformHierarchy.Node producer;
+
+    private NodeGetter(PValue target) {
+      this.target = target;
+    }
+
+    @Override
+    public void visitValue(PValue value, TransformHierarchy.Node producer) {
+      if (value == target) {
+        assert this.producer == null;
+        this.producer = producer;
+      }
+    }
+  }
+
+  @Test
+  public void testSingleFieldAccess() throws IllegalAccessException {
+    String sql = "SELECT order_id FROM ORDER_DETAILS_BOUNDED";
+
+    PCollection<Row> rows = compilePipeline(sql, pipeline);
+
+    final NodeGetter nodeGetter = new NodeGetter(rows);
+    pipeline.traverseTopologically(nodeGetter);
+
+    ParDo.MultiOutput<Row, Row> pardo =
+        (ParDo.MultiOutput<Row, Row>) nodeGetter.producer.getTransform();
+    DoFnSignature sig = DoFnSignatures.getSignature(pardo.getFn().getClass());
+
+    Assert.assertEquals(1, sig.fieldAccessDeclarations().size());
+    DoFnSignature.FieldAccessDeclaration dec =
+        sig.fieldAccessDeclarations().values().iterator().next();
+    FieldAccessDescriptor fieldAccess = (FieldAccessDescriptor) 
dec.field().get(pardo.getFn());
+
+    Assert.assertTrue(fieldAccess.referencesSingleField());
+
+    fieldAccess =
+        
fieldAccess.resolve(nodeGetter.producer.getInputs().values().iterator().next().getSchema());
+    Assert.assertEquals("order_id", 
fieldAccess.fieldNamesAccessed().iterator().next());
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  @Test
+  public void testNoFieldAccess() throws IllegalAccessException {
+    String sql = "SELECT 1 FROM ORDER_DETAILS_BOUNDED";
+
+    PCollection<Row> rows = compilePipeline(sql, pipeline);
+
+    final NodeGetter nodeGetter = new NodeGetter(rows);
+    pipeline.traverseTopologically(nodeGetter);
+
+    ParDo.MultiOutput<Row, Row> pardo =
+        (ParDo.MultiOutput<Row, Row>) nodeGetter.producer.getTransform();
+    DoFnSignature sig = DoFnSignatures.getSignature(pardo.getFn().getClass());
+
+    Assert.assertEquals(1, sig.fieldAccessDeclarations().size());
+    DoFnSignature.FieldAccessDeclaration dec =
+        sig.fieldAccessDeclarations().values().iterator().next();
+    FieldAccessDescriptor fieldAccess = (FieldAccessDescriptor) 
dec.field().get(pardo.getFn());
+
+    Assert.assertFalse(fieldAccess.getAllFields());
+    Assert.assertTrue(fieldAccess.getFieldsAccessed().isEmpty());
+    Assert.assertTrue(fieldAccess.getNestedFieldsAccessed().isEmpty());
+
+    pipeline.run().waitUntilFinish();
+  }
 }
diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRel.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRel.java
index 38c604e..1d93ed3 100644
--- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRel.java
+++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRel.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.sdk.extensions.sql.zetasql;
 
+import static org.apache.beam.sdk.schemas.Schema.Field;
 import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;
 
 import com.google.auto.value.AutoValue;
@@ -39,6 +40,7 @@ import org.apache.beam.sdk.extensions.sql.impl.rel.AbstractBeamCalcRel;
 import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
 import 
org.apache.beam.sdk.extensions.sql.meta.provider.bigquery.BeamBigQuerySqlDialect;
 import 
org.apache.beam.sdk.extensions.sql.meta.provider.bigquery.BeamSqlUnparseContext;
+import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.PTransform;
@@ -142,9 +144,6 @@ public class BeamZetaSqlCalcRel extends AbstractBeamCalcRel {
               options.getZetaSqlDefaultTimezone(),
               options.getVerifyRowValues());
 
-      // validate prepared expressions
-      calcFn.setup();
-
       return upstream.apply(ParDo.of(calcFn)).setRowSchema(outputSchema);
     }
   }
@@ -171,7 +170,11 @@ public class BeamZetaSqlCalcRel extends AbstractBeamCalcRel {
     private final Schema outputSchema;
     private final String defaultTimezone;
     private final boolean verifyRowValues;
-    private transient List<Integer> referencedColumns = ImmutableList.of();
+    private final List<Integer> referencedColumns;
+
+    @FieldAccess("row")
+    private final FieldAccessDescriptor fieldAccess;
+
     private transient Map<BoundedWindow, Queue<TimestampedFuture>> pending = 
new HashMap<>();
     private transient PreparedExpression exp;
     private transient PreparedExpression.@Nullable Stream stream;
@@ -190,10 +193,21 @@ public class BeamZetaSqlCalcRel extends AbstractBeamCalcRel {
       this.outputSchema = outputSchema;
       this.defaultTimezone = defaultTimezone;
       this.verifyRowValues = verifyRowValues;
+
+      try (PreparedExpression exp =
+          prepareExpression(sql, nullParams, inputSchema, defaultTimezone)) {
+        ImmutableList.Builder<Integer> columns = new ImmutableList.Builder<>();
+        for (String c : exp.getReferencedColumns()) {
+          columns.add(Integer.parseInt(c.substring(1)));
+        }
+        this.referencedColumns = columns.build();
+        this.fieldAccess = 
FieldAccessDescriptor.withFieldIds(this.referencedColumns);
+      }
     }
 
     /** exp cannot be reused and is transient so needs to be reinitialized. */
-    private void prepareExpression() {
+    private static PreparedExpression prepareExpression(
+        String sql, Map<String, Value> nullParams, Schema inputSchema, String 
defaultTimezone) {
       AnalyzerOptions options =
           SqlAnalyzer.getAnalyzerOptions(QueryParameters.ofNamed(nullParams), 
defaultTimezone);
       for (int i = 0; i < inputSchema.getFieldCount(); i++) {
@@ -202,21 +216,15 @@ public class BeamZetaSqlCalcRel extends AbstractBeamCalcRel {
             
ZetaSqlBeamTranslationUtils.toZetaSqlType(inputSchema.getField(i).getType()));
       }
 
-      exp = new PreparedExpression(sql);
+      PreparedExpression exp = new PreparedExpression(sql);
       exp.prepare(options);
+      return exp;
     }
 
     @Setup
     public void setup() {
-      prepareExpression();
-
-      ImmutableList.Builder<Integer> columns = new ImmutableList.Builder<>();
-      for (String c : exp.getReferencedColumns()) {
-        columns.add(Integer.parseInt(c.substring(1)));
-      }
-      referencedColumns = columns.build();
-
-      stream = exp.stream();
+      this.exp = prepareExpression(sql, nullParams, inputSchema, 
defaultTimezone);
+      this.stream = exp.stream();
     }
 
     @StartBundle
@@ -231,14 +239,15 @@ public class BeamZetaSqlCalcRel extends AbstractBeamCalcRel {
 
     @ProcessElement
     public void processElement(
-        @Element Row row, @Timestamp Instant t, BoundedWindow w, 
OutputReceiver<Row> r)
+        @FieldAccess("row") Row row, @Timestamp Instant t, BoundedWindow w, 
OutputReceiver<Row> r)
         throws InterruptedException {
       Map<String, Value> columns = new HashMap<>();
       for (int i : referencedColumns) {
+        final Field field = inputSchema.getField(i);
         columns.put(
             columnName(i),
             ZetaSqlBeamTranslationUtils.toZetaSqlValue(
-                row.getBaseValue(i, Object.class), 
inputSchema.getField(i).getType()));
+                row.getBaseValue(field.getName(), Object.class), 
field.getType()));
       }
 
       @NonNull
diff --git a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRelTest.java b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRelTest.java
new file mode 100644
index 0000000..352e83a
--- /dev/null
+++ b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRelTest.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.zetasql;
+
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.extensions.sql.impl.QueryPlanner.QueryParameters;
+import org.apache.beam.sdk.extensions.sql.impl.rel.BeamRelNode;
+import org.apache.beam.sdk.extensions.sql.impl.rel.BeamSqlRelUtils;
+import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
+import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.Row;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+
+/** Tests related to {@code BeamZetaSqlCalcRel}. */
+public class BeamZetaSqlCalcRelTest extends ZetaSqlTestBase {
+
+  private PCollection<Row> compile(String sql) {
+    ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config);
+    BeamRelNode beamRelNode = zetaSQLQueryPlanner.convertToBeamRel(sql, 
QueryParameters.ofNone());
+    return BeamSqlRelUtils.toPCollection(pipeline, beamRelNode);
+  }
+
+  @Rule public final TestPipeline pipeline = TestPipeline.create();
+
+  @Before
+  public void setUp() {
+    initialize();
+  }
+
+  private static class NodeGetter extends Pipeline.PipelineVisitor.Defaults {
+
+    private final PValue target;
+    private TransformHierarchy.Node producer;
+
+    private NodeGetter(PValue target) {
+      this.target = target;
+    }
+
+    @Override
+    public void visitValue(PValue value, TransformHierarchy.Node producer) {
+      if (value == target) {
+        assert this.producer == null;
+        this.producer = producer;
+      }
+    }
+  }
+
+  @Test
+  public void testSingleFieldAccess() throws IllegalAccessException {
+    String sql = "SELECT Key FROM KeyValue";
+
+    PCollection<Row> rows = compile(sql);
+
+    final NodeGetter nodeGetter = new NodeGetter(rows);
+    pipeline.traverseTopologically(nodeGetter);
+
+    ParDo.MultiOutput<Row, Row> pardo =
+        (ParDo.MultiOutput<Row, Row>) nodeGetter.producer.getTransform();
+    DoFnSignature sig = DoFnSignatures.getSignature(pardo.getFn().getClass());
+
+    Assert.assertEquals(1, sig.fieldAccessDeclarations().size());
+    DoFnSignature.FieldAccessDeclaration dec =
+        sig.fieldAccessDeclarations().values().iterator().next();
+    FieldAccessDescriptor fieldAccess = (FieldAccessDescriptor) 
dec.field().get(pardo.getFn());
+
+    Assert.assertTrue(fieldAccess.referencesSingleField());
+
+    fieldAccess =
+        
fieldAccess.resolve(nodeGetter.producer.getInputs().values().iterator().next().getSchema());
+    Assert.assertEquals("Key", 
fieldAccess.fieldNamesAccessed().iterator().next());
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  @Test
+  public void testNoFieldAccess() throws IllegalAccessException {
+    String sql = "SELECT 1 FROM KeyValue";
+
+    PCollection<Row> rows = compile(sql);
+
+    final NodeGetter nodeGetter = new NodeGetter(rows);
+    pipeline.traverseTopologically(nodeGetter);
+
+    ParDo.MultiOutput<Row, Row> pardo =
+        (ParDo.MultiOutput<Row, Row>) nodeGetter.producer.getTransform();
+    DoFnSignature sig = DoFnSignatures.getSignature(pardo.getFn().getClass());
+
+    Assert.assertEquals(1, sig.fieldAccessDeclarations().size());
+    DoFnSignature.FieldAccessDeclaration dec =
+        sig.fieldAccessDeclarations().values().iterator().next();
+    FieldAccessDescriptor fieldAccess = (FieldAccessDescriptor) 
dec.field().get(pardo.getFn());
+
+    Assert.assertFalse(fieldAccess.getAllFields());
+    Assert.assertTrue(fieldAccess.getFieldsAccessed().isEmpty());
+    Assert.assertTrue(fieldAccess.getNestedFieldsAccessed().isEmpty());
+
+    pipeline.run().waitUntilFinish();
+  }
+}

Reply via email to