This is an automated email from the ASF dual-hosted git repository.

aokolnychyi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new 3c42a10588 Spark 3.3: Add ProcedureInput to simplify procedures (#7038)
3c42a10588 is described below

commit 3c42a105883ff6edcc3cee71155b59ba89d5146d
Author: Anton Okolnychyi <[email protected]>
AuthorDate: Thu Mar 9 19:28:33 2023 -0800

    Spark 3.3: Add ProcedureInput to simplify procedures (#7038)
---
 .../procedures/CreateChangelogViewProcedure.java   | 103 +++++-----
 .../iceberg/spark/procedures/ProcedureInput.java   | 210 +++++++++++++++++++++
 2 files changed, 252 insertions(+), 61 deletions(-)

diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/CreateChangelogViewProcedure.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/CreateChangelogViewProcedure.java
index ab844e08e9..b47cc0de0b 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/CreateChangelogViewProcedure.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/CreateChangelogViewProcedure.java
@@ -24,8 +24,8 @@ import java.util.Map;
 import org.apache.iceberg.MetadataColumns;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
-import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.apache.iceberg.spark.ChangelogIterator;
 import org.apache.iceberg.spark.source.SparkChangelogTable;
 import org.apache.spark.api.java.function.MapPartitionsFunction;
@@ -42,7 +42,6 @@ import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.unsafe.types.UTF8String;
-import scala.runtime.BoxedUnit;
 
 /**
  * A procedure that creates a view for changed rows.
@@ -81,23 +80,29 @@ import scala.runtime.BoxedUnit;
  */
 public class CreateChangelogViewProcedure extends BaseProcedure {
 
+  private static final ProcedureParameter TABLE_PARAM =
+      ProcedureParameter.required("table", DataTypes.StringType);
+  private static final ProcedureParameter CHANGELOG_VIEW_PARAM =
+      ProcedureParameter.optional("changelog_view", DataTypes.StringType);
+  private static final ProcedureParameter OPTIONS_PARAM =
+      ProcedureParameter.optional("options", STRING_MAP);
+  private static final ProcedureParameter COMPUTE_UPDATES_PARAM =
+      ProcedureParameter.optional("compute_updates", DataTypes.BooleanType);
+  private static final ProcedureParameter REMOVE_CARRYOVERS_PARAM =
+      ProcedureParameter.optional("remove_carryovers", DataTypes.BooleanType);
+  private static final ProcedureParameter IDENTIFIER_COLUMNS_PARAM =
+      ProcedureParameter.optional("identifier_columns", STRING_ARRAY);
+
   private static final ProcedureParameter[] PARAMETERS =
       new ProcedureParameter[] {
-        ProcedureParameter.required("table", DataTypes.StringType),
-        ProcedureParameter.optional("changelog_view", DataTypes.StringType),
-        ProcedureParameter.optional("options", STRING_MAP),
-        ProcedureParameter.optional("compute_updates", DataTypes.BooleanType),
-        ProcedureParameter.optional("remove_carryovers", DataTypes.BooleanType),
-        ProcedureParameter.optional("identifier_columns", STRING_ARRAY),
+        TABLE_PARAM,
+        CHANGELOG_VIEW_PARAM,
+        OPTIONS_PARAM,
+        COMPUTE_UPDATES_PARAM,
+        REMOVE_CARRYOVERS_PARAM,
+        IDENTIFIER_COLUMNS_PARAM,
       };
 
-  private static final int TABLE_NAME_ORDINAL = 0;
-  private static final int CHANGELOG_VIEW_NAME_ORDINAL = 1;
-  private static final int OPTIONS_ORDINAL = 2;
-  private static final int COMPUTE_UPDATES_ORDINAL = 3;
-  private static final int REMOVE_CARRYOVERS_ORDINAL = 4;
-  private static final int IDENTIFIER_COLUMNS_ORDINAL = 5;
-
   private static final StructType OUTPUT_TYPE =
       new StructType(
           new StructField[] {
@@ -129,20 +134,21 @@ public class CreateChangelogViewProcedure extends BaseProcedure {
 
   @Override
   public InternalRow[] call(InternalRow args) {
-    Identifier tableIdent =
-        toIdentifier(args.getString(TABLE_NAME_ORDINAL), PARAMETERS[TABLE_NAME_ORDINAL].name());
+    ProcedureInput input = new ProcedureInput(spark(), tableCatalog(), PARAMETERS, args);
+
+    Identifier tableIdent = input.ident(TABLE_PARAM);
 
     // load insert and deletes from the changelog table
     Identifier changelogTableIdent = changelogTableIdent(tableIdent);
-    Dataset<Row> df = loadDataSetFromTable(changelogTableIdent, options(args));
+    Dataset<Row> df = loadDataSetFromTable(changelogTableIdent, options(input));
 
-    if (shouldComputeUpdateImages(args)) {
-      df = computeUpdateImages(identifierColumns(args, tableIdent), df);
-    } else if (shouldRemoveCarryoverRows(args)) {
+    if (shouldComputeUpdateImages(input)) {
+      df = computeUpdateImages(identifierColumns(input, tableIdent), df);
+    } else if (shouldRemoveCarryoverRows(input)) {
       df = removeCarryoverRows(df);
     }
 
-    String viewName = viewName(args, tableIdent.name());
+    String viewName = viewName(input, tableIdent.name());
 
     df.createOrReplaceTempView(viewName);
 
@@ -164,21 +170,14 @@ public class CreateChangelogViewProcedure extends BaseProcedure {
     return applyChangelogIterator(df, repartitionColumns);
   }
 
-  private boolean shouldComputeUpdateImages(InternalRow args) {
-    if (!args.isNullAt(COMPUTE_UPDATES_ORDINAL)) {
-      return args.getBoolean(COMPUTE_UPDATES_ORDINAL);
-    } else {
-      // If the identifier columns are set, we compute pre/post update images by default.
-      return !args.isNullAt(IDENTIFIER_COLUMNS_ORDINAL);
-    }
+  private boolean shouldComputeUpdateImages(ProcedureInput input) {
+    // If the identifier columns are set, we compute pre/post update images by default.
+    boolean defaultValue = input.isProvided(IDENTIFIER_COLUMNS_PARAM);
+    return input.bool(COMPUTE_UPDATES_PARAM, defaultValue);
   }
 
-  private boolean shouldRemoveCarryoverRows(InternalRow args) {
-    if (args.isNullAt(REMOVE_CARRYOVERS_ORDINAL)) {
-      return true;
-    } else {
-      return args.getBoolean(REMOVE_CARRYOVERS_ORDINAL);
-    }
+  private boolean shouldRemoveCarryoverRows(ProcedureInput input) {
+    return input.bool(REMOVE_CARRYOVERS_PARAM, true);
   }
 
   private Dataset<Row> removeCarryoverRows(Dataset<Row> df) {
@@ -190,11 +189,9 @@ public class CreateChangelogViewProcedure extends BaseProcedure {
     return applyChangelogIterator(df, repartitionColumns);
   }
 
-  private String[] identifierColumns(InternalRow args, Identifier tableIdent) {
-    if (!args.isNullAt(IDENTIFIER_COLUMNS_ORDINAL)) {
-      return Arrays.stream(args.getArray(IDENTIFIER_COLUMNS_ORDINAL).array())
-          .map(column -> column.toString())
-          .toArray(String[]::new);
+  private String[] identifierColumns(ProcedureInput input, Identifier tableIdent) {
+    if (input.isProvided(IDENTIFIER_COLUMNS_PARAM)) {
+      return input.stringArray(IDENTIFIER_COLUMNS_PARAM);
     } else {
       Table table = loadSparkTable(tableIdent).table();
       return table.schema().identifierFieldNames().toArray(new String[0]);
@@ -208,29 +205,13 @@ public class CreateChangelogViewProcedure extends BaseProcedure {
     return Identifier.of(namespace.toArray(new String[0]), SparkChangelogTable.TABLE_NAME);
   }
 
-  private Map<String, String> options(InternalRow args) {
-    Map<String, String> options = Maps.newHashMap();
-
-    if (!args.isNullAt(OPTIONS_ORDINAL)) {
-      args.getMap(OPTIONS_ORDINAL)
-          .foreach(
-              DataTypes.StringType,
-              DataTypes.StringType,
-              (k, v) -> {
-                options.put(k.toString(), v.toString());
-                return BoxedUnit.UNIT;
-              });
-    }
-
-    return options;
+  private Map<String, String> options(ProcedureInput input) {
+    return input.stringMap(OPTIONS_PARAM, ImmutableMap.of());
   }
 
-  private static String viewName(InternalRow args, String tableName) {
-    if (args.isNullAt(CHANGELOG_VIEW_NAME_ORDINAL)) {
-      return String.format("`%s_changes`", tableName);
-    } else {
-      return args.getString(CHANGELOG_VIEW_NAME_ORDINAL);
-    }
+  private String viewName(ProcedureInput input, String tableName) {
+    String defaultValue = String.format("`%s_changes`", tableName);
+    return input.string(CHANGELOG_VIEW_PARAM, defaultValue);
   }
 
   private Dataset<Row> applyChangelogIterator(Dataset<Row> df, Column[] repartitionColumns) {
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java
new file mode 100644
index 0000000000..1b994c5c36
--- /dev/null
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.procedures;
+
+import java.lang.reflect.Array;
+import java.util.Map;
+import java.util.function.BiFunction;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.spark.Spark3Util;
+import org.apache.iceberg.spark.Spark3Util.CatalogAndIdentifier;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.catalyst.util.MapData;
+import org.apache.spark.sql.connector.catalog.CatalogPlugin;
+import org.apache.spark.sql.connector.catalog.Identifier;
+import org.apache.spark.sql.connector.catalog.TableCatalog;
+import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+
+/** A class that abstracts common logic for working with input to a procedure. */
+class ProcedureInput {
+
+  private static final DataType STRING_ARRAY = DataTypes.createArrayType(DataTypes.StringType);
+  private static final DataType STRING_MAP =
+      DataTypes.createMapType(DataTypes.StringType, DataTypes.StringType);
+
+  private final SparkSession spark;
+  private final TableCatalog catalog;
+  private final Map<String, Integer> paramOrdinals;
+  private final InternalRow args;
+
+  ProcedureInput(
+      SparkSession spark, TableCatalog catalog, ProcedureParameter[] params, InternalRow args) {
+    this.spark = spark;
+    this.catalog = catalog;
+    this.paramOrdinals = computeParamOrdinals(params);
+    this.args = args;
+  }
+
+  public boolean isProvided(ProcedureParameter param) {
+    int ordinal = ordinal(param);
+    return !args.isNullAt(ordinal);
+  }
+
+  public boolean bool(ProcedureParameter param, boolean defaultValue) {
+    validateParamType(param, DataTypes.BooleanType);
+    int ordinal = ordinal(param);
+    return args.isNullAt(ordinal) ? defaultValue : args.getBoolean(ordinal);
+  }
+
+  public String string(ProcedureParameter param) {
+    String value = string(param, null);
+    Preconditions.checkArgument(value != null, "Parameter '%s' is not set", param.name());
+    return value;
+  }
+
+  public String string(ProcedureParameter param, String defaultValue) {
+    validateParamType(param, DataTypes.StringType);
+    int ordinal = ordinal(param);
+    return args.isNullAt(ordinal) ? defaultValue : args.getString(ordinal);
+  }
+
+  public String[] stringArray(ProcedureParameter param) {
+    String[] value = stringArray(param, null);
+    Preconditions.checkArgument(value != null, "Parameter '%s' is not set", param.name());
+    return value;
+  }
+
+  public String[] stringArray(ProcedureParameter param, String[] defaultValue) {
+    validateParamType(param, STRING_ARRAY);
+    return array(
+        param,
+        (array, ordinal) -> array.getUTF8String(ordinal).toString(),
+        String.class,
+        defaultValue);
+  }
+
+  @SuppressWarnings("unchecked")
+  private <T> T[] array(
+      ProcedureParameter param,
+      BiFunction<ArrayData, Integer, T> convertElement,
+      Class<T> elementClass,
+      T[] defaultValue) {
+
+    int ordinal = ordinal(param);
+
+    if (args.isNullAt(ordinal)) {
+      return defaultValue;
+    }
+
+    ArrayData arrayData = args.getArray(ordinal);
+
+    T[] convertedArray = (T[]) Array.newInstance(elementClass, arrayData.numElements());
+
+    for (int index = 0; index < arrayData.numElements(); index++) {
+      convertedArray[index] = convertElement.apply(arrayData, index);
+    }
+
+    return convertedArray;
+  }
+
+  public Map<String, String> stringMap(ProcedureParameter param, Map<String, String> defaultValue) {
+    validateParamType(param, STRING_MAP);
+    return map(
+        param,
+        (keys, ordinal) -> keys.getUTF8String(ordinal).toString(),
+        (values, ordinal) -> values.getUTF8String(ordinal).toString(),
+        defaultValue);
+  }
+
+  private <K, V> Map<K, V> map(
+      ProcedureParameter param,
+      BiFunction<ArrayData, Integer, K> convertKey,
+      BiFunction<ArrayData, Integer, V> convertValue,
+      Map<K, V> defaultValue) {
+
+    int ordinal = ordinal(param);
+
+    if (args.isNullAt(ordinal)) {
+      return defaultValue;
+    }
+
+    MapData mapData = args.getMap(ordinal);
+
+    Map<K, V> convertedMap = Maps.newHashMap();
+
+    for (int index = 0; index < mapData.numElements(); index++) {
+      K convertedKey = convertKey.apply(mapData.keyArray(), index);
+      V convertedValue = convertValue.apply(mapData.valueArray(), index);
+      convertedMap.put(convertedKey, convertedValue);
+    }
+
+    return convertedMap;
+  }
+
+  public Identifier ident(ProcedureParameter param) {
+    String identAsString = string(param);
+    CatalogAndIdentifier catalogAndIdent = toCatalogAndIdent(identAsString, param.name(), catalog);
+
+    Preconditions.checkArgument(
+        catalogAndIdent.catalog().equals(catalog),
+        "Cannot run procedure in catalog '%s': '%s' is a table in catalog '%s'",
+        catalog.name(),
+        identAsString,
+        catalogAndIdent.catalog().name());
+
+    return catalogAndIdent.identifier();
+  }
+
+  private CatalogAndIdentifier toCatalogAndIdent(
+      String identAsString, String paramName, CatalogPlugin defaultCatalog) {
+
+    Preconditions.checkArgument(
+        StringUtils.isNotBlank(identAsString),
+        "Cannot handle an empty identifier for parameter '%s'",
+        paramName);
+
+    String desc = String.format("identifier for parameter '%s'", paramName);
+    return Spark3Util.catalogAndIdentifier(desc, spark, identAsString, defaultCatalog);
+  }
+
+  private int ordinal(ProcedureParameter param) {
+    return paramOrdinals.get(param.name());
+  }
+
+  private Map<String, Integer> computeParamOrdinals(ProcedureParameter[] params) {
+    Map<String, Integer> ordinals = Maps.newHashMap();
+
+    for (int index = 0; index < params.length; index++) {
+      String paramName = params[index].name();
+
+      Preconditions.checkArgument(
+          !ordinals.containsKey(paramName),
+          "Detected multiple parameters named as '%s'",
+          paramName);
+
+      ordinals.put(paramName, index);
+    }
+
+    return ordinals;
+  }
+
+  private void validateParamType(ProcedureParameter param, DataType expectedDataType) {
+    Preconditions.checkArgument(
+        expectedDataType.sameType(param.dataType()),
+        "Parameter '%s' must be of type %s",
+        param.name(),
+        expectedDataType.catalogString());
+  }
+}

Reply via email to