This is an automated email from the ASF dual-hosted git repository.
aokolnychyi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new 3c42a10588 Spark 3.3: Add ProcedureInput to simplify procedures (#7038)
3c42a10588 is described below
commit 3c42a105883ff6edcc3cee71155b59ba89d5146d
Author: Anton Okolnychyi <[email protected]>
AuthorDate: Thu Mar 9 19:28:33 2023 -0800
Spark 3.3: Add ProcedureInput to simplify procedures (#7038)
---
.../procedures/CreateChangelogViewProcedure.java | 103 +++++-----
.../iceberg/spark/procedures/ProcedureInput.java | 210 +++++++++++++++++++++
2 files changed, 252 insertions(+), 61 deletions(-)
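For orientation before the diff: the change replaces positional, ordinal-based argument access with named parameter lookups on a new ProcedureInput helper. A condensed before/after sketch of the pattern (names are taken from the diff below; this snippet itself is not part of the commit):

    // Before: positional access with manual null handling
    String tableName = args.getString(TABLE_NAME_ORDINAL);
    boolean removeCarryovers =
        args.isNullAt(REMOVE_CARRYOVERS_ORDINAL) || args.getBoolean(REMOVE_CARRYOVERS_ORDINAL);

    // After: named lookups; the default value lives at the call site
    ProcedureInput input = new ProcedureInput(spark(), tableCatalog(), PARAMETERS, args);
    Identifier tableIdent = input.ident(TABLE_PARAM);
    boolean removeCarryovers = input.bool(REMOVE_CARRYOVERS_PARAM, true);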
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/CreateChangelogViewProcedure.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/CreateChangelogViewProcedure.java
index ab844e08e9..b47cc0de0b 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/CreateChangelogViewProcedure.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/CreateChangelogViewProcedure.java
@@ -24,8 +24,8 @@ import java.util.Map;
import org.apache.iceberg.MetadataColumns;
import org.apache.iceberg.Table;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
-import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.spark.ChangelogIterator;
import org.apache.iceberg.spark.source.SparkChangelogTable;
import org.apache.spark.api.java.function.MapPartitionsFunction;
@@ -42,7 +42,6 @@ import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.types.UTF8String;
-import scala.runtime.BoxedUnit;
/**
* A procedure that creates a view for changed rows.
@@ -81,23 +80,29 @@ import scala.runtime.BoxedUnit;
*/
public class CreateChangelogViewProcedure extends BaseProcedure {
+ private static final ProcedureParameter TABLE_PARAM =
+ ProcedureParameter.required("table", DataTypes.StringType);
+ private static final ProcedureParameter CHANGELOG_VIEW_PARAM =
+ ProcedureParameter.optional("changelog_view", DataTypes.StringType);
+ private static final ProcedureParameter OPTIONS_PARAM =
+ ProcedureParameter.optional("options", STRING_MAP);
+ private static final ProcedureParameter COMPUTE_UPDATES_PARAM =
+ ProcedureParameter.optional("compute_updates", DataTypes.BooleanType);
+ private static final ProcedureParameter REMOVE_CARRYOVERS_PARAM =
+ ProcedureParameter.optional("remove_carryovers", DataTypes.BooleanType);
+ private static final ProcedureParameter IDENTIFIER_COLUMNS_PARAM =
+ ProcedureParameter.optional("identifier_columns", STRING_ARRAY);
+
private static final ProcedureParameter[] PARAMETERS =
new ProcedureParameter[] {
- ProcedureParameter.required("table", DataTypes.StringType),
- ProcedureParameter.optional("changelog_view", DataTypes.StringType),
- ProcedureParameter.optional("options", STRING_MAP),
- ProcedureParameter.optional("compute_updates", DataTypes.BooleanType),
- ProcedureParameter.optional("remove_carryovers",
DataTypes.BooleanType),
- ProcedureParameter.optional("identifier_columns", STRING_ARRAY),
+ TABLE_PARAM,
+ CHANGELOG_VIEW_PARAM,
+ OPTIONS_PARAM,
+ COMPUTE_UPDATES_PARAM,
+ REMOVE_CARRYOVERS_PARAM,
+ IDENTIFIER_COLUMNS_PARAM,
};
- private static final int TABLE_NAME_ORDINAL = 0;
- private static final int CHANGELOG_VIEW_NAME_ORDINAL = 1;
- private static final int OPTIONS_ORDINAL = 2;
- private static final int COMPUTE_UPDATES_ORDINAL = 3;
- private static final int REMOVE_CARRYOVERS_ORDINAL = 4;
- private static final int IDENTIFIER_COLUMNS_ORDINAL = 5;
-
private static final StructType OUTPUT_TYPE =
new StructType(
new StructField[] {
@@ -129,20 +134,21 @@ public class CreateChangelogViewProcedure extends BaseProcedure {
@Override
public InternalRow[] call(InternalRow args) {
- Identifier tableIdent =
- toIdentifier(args.getString(TABLE_NAME_ORDINAL), PARAMETERS[TABLE_NAME_ORDINAL].name());
+ ProcedureInput input = new ProcedureInput(spark(), tableCatalog(), PARAMETERS, args);
+
+ Identifier tableIdent = input.ident(TABLE_PARAM);
// load insert and deletes from the changelog table
Identifier changelogTableIdent = changelogTableIdent(tableIdent);
- Dataset<Row> df = loadDataSetFromTable(changelogTableIdent, options(args));
+ Dataset<Row> df = loadDataSetFromTable(changelogTableIdent,
options(input));
- if (shouldComputeUpdateImages(args)) {
- df = computeUpdateImages(identifierColumns(args, tableIdent), df);
- } else if (shouldRemoveCarryoverRows(args)) {
+ if (shouldComputeUpdateImages(input)) {
+ df = computeUpdateImages(identifierColumns(input, tableIdent), df);
+ } else if (shouldRemoveCarryoverRows(input)) {
df = removeCarryoverRows(df);
}
- String viewName = viewName(args, tableIdent.name());
+ String viewName = viewName(input, tableIdent.name());
df.createOrReplaceTempView(viewName);
@@ -164,21 +170,14 @@ public class CreateChangelogViewProcedure extends BaseProcedure {
return applyChangelogIterator(df, repartitionColumns);
}
- private boolean shouldComputeUpdateImages(InternalRow args) {
- if (!args.isNullAt(COMPUTE_UPDATES_ORDINAL)) {
- return args.getBoolean(COMPUTE_UPDATES_ORDINAL);
- } else {
- // If the identifier columns are set, we compute pre/post update images by default.
- return !args.isNullAt(IDENTIFIER_COLUMNS_ORDINAL);
- }
+ private boolean shouldComputeUpdateImages(ProcedureInput input) {
+ // If the identifier columns are set, we compute pre/post update images by default.
+ boolean defaultValue = input.isProvided(IDENTIFIER_COLUMNS_PARAM);
+ return input.bool(COMPUTE_UPDATES_PARAM, defaultValue);
}
- private boolean shouldRemoveCarryoverRows(InternalRow args) {
- if (args.isNullAt(REMOVE_CARRYOVERS_ORDINAL)) {
- return true;
- } else {
- return args.getBoolean(REMOVE_CARRYOVERS_ORDINAL);
- }
+ private boolean shouldRemoveCarryoverRows(ProcedureInput input) {
+ return input.bool(REMOVE_CARRYOVERS_PARAM, true);
}
private Dataset<Row> removeCarryoverRows(Dataset<Row> df) {
@@ -190,11 +189,9 @@ public class CreateChangelogViewProcedure extends BaseProcedure {
return applyChangelogIterator(df, repartitionColumns);
}
- private String[] identifierColumns(InternalRow args, Identifier tableIdent) {
- if (!args.isNullAt(IDENTIFIER_COLUMNS_ORDINAL)) {
- return Arrays.stream(args.getArray(IDENTIFIER_COLUMNS_ORDINAL).array())
- .map(column -> column.toString())
- .toArray(String[]::new);
+ private String[] identifierColumns(ProcedureInput input, Identifier tableIdent) {
+ if (input.isProvided(IDENTIFIER_COLUMNS_PARAM)) {
+ return input.stringArray(IDENTIFIER_COLUMNS_PARAM);
} else {
Table table = loadSparkTable(tableIdent).table();
return table.schema().identifierFieldNames().toArray(new String[0]);
@@ -208,29 +205,13 @@ public class CreateChangelogViewProcedure extends BaseProcedure {
return Identifier.of(namespace.toArray(new String[0]), SparkChangelogTable.TABLE_NAME);
}
- private Map<String, String> options(InternalRow args) {
- Map<String, String> options = Maps.newHashMap();
-
- if (!args.isNullAt(OPTIONS_ORDINAL)) {
- args.getMap(OPTIONS_ORDINAL)
- .foreach(
- DataTypes.StringType,
- DataTypes.StringType,
- (k, v) -> {
- options.put(k.toString(), v.toString());
- return BoxedUnit.UNIT;
- });
- }
-
- return options;
+ private Map<String, String> options(ProcedureInput input) {
+ return input.stringMap(OPTIONS_PARAM, ImmutableMap.of());
}
- private static String viewName(InternalRow args, String tableName) {
- if (args.isNullAt(CHANGELOG_VIEW_NAME_ORDINAL)) {
- return String.format("`%s_changes`", tableName);
- } else {
- return args.getString(CHANGELOG_VIEW_NAME_ORDINAL);
- }
+ private String viewName(ProcedureInput input, String tableName) {
+ String defaultValue = String.format("`%s_changes`", tableName);
+ return input.string(CHANGELOG_VIEW_PARAM, defaultValue);
}
private Dataset<Row> applyChangelogIterator(Dataset<Row> df, Column[]
repartitionColumns) {
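For reference, end users invoke the refactored procedure exactly as before. A hedged invocation example via Spark SQL, where the catalog name my_catalog, the table db.tbl, and the start-snapshot-id option are placeholders for illustration:

    // Calling the procedure from Java through Spark SQL.
    // Parameter names match the PARAMETERS array above.
    spark.sql(
        "CALL my_catalog.system.create_changelog_view("
            + "table => 'db.tbl', "
            + "options => map('start-snapshot-id', '1'), "
            + "compute_updates => true)");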
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java
new file mode 100644
index 0000000000..1b994c5c36
--- /dev/null
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/procedures/ProcedureInput.java
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.procedures;
+
+import java.lang.reflect.Array;
+import java.util.Map;
+import java.util.function.BiFunction;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.spark.Spark3Util;
+import org.apache.iceberg.spark.Spark3Util.CatalogAndIdentifier;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.catalyst.util.MapData;
+import org.apache.spark.sql.connector.catalog.CatalogPlugin;
+import org.apache.spark.sql.connector.catalog.Identifier;
+import org.apache.spark.sql.connector.catalog.TableCatalog;
+import org.apache.spark.sql.connector.iceberg.catalog.ProcedureParameter;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+
+/** A class that abstracts common logic for working with input to a procedure. */
+class ProcedureInput {
+
+ private static final DataType STRING_ARRAY = DataTypes.createArrayType(DataTypes.StringType);
+ private static final DataType STRING_MAP =
+ DataTypes.createMapType(DataTypes.StringType, DataTypes.StringType);
+
+ private final SparkSession spark;
+ private final TableCatalog catalog;
+ private final Map<String, Integer> paramOrdinals;
+ private final InternalRow args;
+
+ ProcedureInput(
+ SparkSession spark, TableCatalog catalog, ProcedureParameter[] params, InternalRow args) {
+ this.spark = spark;
+ this.catalog = catalog;
+ this.paramOrdinals = computeParamOrdinals(params);
+ this.args = args;
+ }
+
+ public boolean isProvided(ProcedureParameter param) {
+ int ordinal = ordinal(param);
+ return !args.isNullAt(ordinal);
+ }
+
+ public boolean bool(ProcedureParameter param, boolean defaultValue) {
+ validateParamType(param, DataTypes.BooleanType);
+ int ordinal = ordinal(param);
+ return args.isNullAt(ordinal) ? defaultValue : args.getBoolean(ordinal);
+ }
+
+ public String string(ProcedureParameter param) {
+ String value = string(param, null);
+ Preconditions.checkArgument(value != null, "Parameter '%s' is not set", param.name());
+ return value;
+ }
+
+ public String string(ProcedureParameter param, String defaultValue) {
+ validateParamType(param, DataTypes.StringType);
+ int ordinal = ordinal(param);
+ return args.isNullAt(ordinal) ? defaultValue : args.getString(ordinal);
+ }
+
+ public String[] stringArray(ProcedureParameter param) {
+ String[] value = stringArray(param, null);
+ Preconditions.checkArgument(value != null, "Parameter '%s' is not set", param.name());
+ return value;
+ }
+
+ public String[] stringArray(ProcedureParameter param, String[] defaultValue) {
+ validateParamType(param, STRING_ARRAY);
+ return array(
+ param,
+ (array, ordinal) -> array.getUTF8String(ordinal).toString(),
+ String.class,
+ defaultValue);
+ }
+
+ @SuppressWarnings("unchecked")
+ private <T> T[] array(
+ ProcedureParameter param,
+ BiFunction<ArrayData, Integer, T> convertElement,
+ Class<T> elementClass,
+ T[] defaultValue) {
+
+ int ordinal = ordinal(param);
+
+ if (args.isNullAt(ordinal)) {
+ return defaultValue;
+ }
+
+ ArrayData arrayData = args.getArray(ordinal);
+
+ T[] convertedArray = (T[]) Array.newInstance(elementClass, arrayData.numElements());
+
+ for (int index = 0; index < arrayData.numElements(); index++) {
+ convertedArray[index] = convertElement.apply(arrayData, index);
+ }
+
+ return convertedArray;
+ }
+
+ public Map<String, String> stringMap(ProcedureParameter param, Map<String,
String> defaultValue) {
+ validateParamType(param, STRING_MAP);
+ return map(
+ param,
+ (keys, ordinal) -> keys.getUTF8String(ordinal).toString(),
+ (values, ordinal) -> values.getUTF8String(ordinal).toString(),
+ defaultValue);
+ }
+
+ private <K, V> Map<K, V> map(
+ ProcedureParameter param,
+ BiFunction<ArrayData, Integer, K> convertKey,
+ BiFunction<ArrayData, Integer, V> convertValue,
+ Map<K, V> defaultValue) {
+
+ int ordinal = ordinal(param);
+
+ if (args.isNullAt(ordinal)) {
+ return defaultValue;
+ }
+
+ MapData mapData = args.getMap(ordinal);
+
+ Map<K, V> convertedMap = Maps.newHashMap();
+
+ for (int index = 0; index < mapData.numElements(); index++) {
+ K convertedKey = convertKey.apply(mapData.keyArray(), index);
+ V convertedValue = convertValue.apply(mapData.valueArray(), index);
+ convertedMap.put(convertedKey, convertedValue);
+ }
+
+ return convertedMap;
+ }
+
+ public Identifier ident(ProcedureParameter param) {
+ String identAsString = string(param);
+ CatalogAndIdentifier catalogAndIdent = toCatalogAndIdent(identAsString, param.name(), catalog);
+
+ Preconditions.checkArgument(
+ catalogAndIdent.catalog().equals(catalog),
+ "Cannot run procedure in catalog '%s': '%s' is a table in catalog
'%s'",
+ catalog.name(),
+ identAsString,
+ catalogAndIdent.catalog().name());
+
+ return catalogAndIdent.identifier();
+ }
+
+ private CatalogAndIdentifier toCatalogAndIdent(
+ String identAsString, String paramName, CatalogPlugin defaultCatalog) {
+
+ Preconditions.checkArgument(
+ StringUtils.isNotBlank(identAsString),
+ "Cannot handle an empty identifier for parameter '%s'",
+ paramName);
+
+ String desc = String.format("identifier for parameter '%s'", paramName);
+ return Spark3Util.catalogAndIdentifier(desc, spark, identAsString, defaultCatalog);
+ }
+
+ private int ordinal(ProcedureParameter param) {
+ return paramOrdinals.get(param.name());
+ }
+
+ private Map<String, Integer> computeParamOrdinals(ProcedureParameter[]
params) {
+ Map<String, Integer> ordinals = Maps.newHashMap();
+
+ for (int index = 0; index < params.length; index++) {
+ String paramName = params[index].name();
+
+ Preconditions.checkArgument(
+ !ordinals.containsKey(paramName),
+ "Detected multiple parameters named as '%s'",
+ paramName);
+
+ ordinals.put(paramName, index);
+ }
+
+ return ordinals;
+ }
+
+ private void validateParamType(ProcedureParameter param, DataType expectedDataType) {
+ Preconditions.checkArgument(
+ expectedDataType.sameType(param.dataType()),
+ "Parameter '%s' must be of type %s",
+ param.name(),
+ expectedDataType.catalogString());
+ }
+}
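To round out the new helper, a minimal sketch of a hypothetical procedure built on ProcedureInput; the dry_run parameter is invented for illustration, and only the ProcedureInput calls mirror the API added above:

    // Hypothetical procedure fragment; assumes the BaseProcedure helpers
    // spark() and tableCatalog() seen in CreateChangelogViewProcedure.
    private static final ProcedureParameter TABLE_PARAM =
        ProcedureParameter.required("table", DataTypes.StringType);
    private static final ProcedureParameter DRY_RUN_PARAM = // illustrative only
        ProcedureParameter.optional("dry_run", DataTypes.BooleanType);

    private static final ProcedureParameter[] PARAMETERS =
        new ProcedureParameter[] {TABLE_PARAM, DRY_RUN_PARAM};

    @Override
    public InternalRow[] call(InternalRow args) {
      ProcedureInput input = new ProcedureInput(spark(), tableCatalog(), PARAMETERS, args);
      Identifier tableIdent = input.ident(TABLE_PARAM); // parses and validates the identifier
      boolean dryRun = input.bool(DRY_RUN_PARAM, false); // defaults to false when null
      // ... procedure body using tableIdent and dryRun ...
      return new InternalRow[0];
    }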