This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 07325a844 [spark] Implement of rollback and tag actions using Spark
Procedure (#1897)
07325a844 is described below
commit 07325a844f1902f68891995d6085b1172d0ccd73
Author: Kunni <[email protected]>
AuthorDate: Tue Aug 29 16:16:00 2023 +0800
[spark] Implement of rollback and tag actions using Spark Procedure (#1897)
---
docs/content/maintenance/manage-snapshots.md | 10 ++
docs/content/maintenance/manage-tags.md | 28 ++++
.../java/org/apache/paimon/spark/SparkCatalog.java | 2 +-
.../org/apache/paimon/spark/SparkProcedures.java | 6 +
.../java/org/apache/paimon/spark/SparkUtils.java | 123 ++++++++++++++++++
.../paimon/spark/procedure/BaseProcedure.java | 142 +++++++++++++++++++++
.../paimon/spark/procedure/CreateTagProcedure.java | 90 +++++++++++++
.../paimon/spark/procedure/DeleteTagProcedure.java | 87 +++++++++++++
.../paimon/spark/procedure/RollbackProcedure.java | 92 +++++++++++++
.../extensions/PaimonSparkSessionExtensions.scala | 4 +
.../sql/execution/datasources/v2/CallExec.scala | 36 ++++++
.../v2/ExtendedDataSourceV2Strategy.scala | 43 +++++++
.../sql/CreateAndDeleteTagProcedureTest.scala | 86 +++++++++++++
.../paimon/spark/sql/RollbackProcedureTest.scala | 94 ++++++++++++++
14 files changed, 842 insertions(+), 1 deletion(-)
diff --git a/docs/content/maintenance/manage-snapshots.md
b/docs/content/maintenance/manage-snapshots.md
index ca4e99be4..0bcca42a0 100644
--- a/docs/content/maintenance/manage-snapshots.md
+++ b/docs/content/maintenance/manage-snapshots.md
@@ -257,4 +257,14 @@ public class RollbackTo {
{{< /tab >}}
+{{< tab "Spark" >}}
+
+Run the following SQL:
+
+```sql
+CALL rollback(table => 'test.T', version => '2');
+```
+
+{{< /tab >}}
+
{{< /tabs >}}
\ No newline at end of file
diff --git a/docs/content/maintenance/manage-tags.md
b/docs/content/maintenance/manage-tags.md
index b37fe3c15..92ec18a55 100644
--- a/docs/content/maintenance/manage-tags.md
+++ b/docs/content/maintenance/manage-tags.md
@@ -127,6 +127,14 @@ public class CreateTag {
{{< /tab >}}
+{{< tab "Spark" >}}
+Run the following SQL:
+```sql
+CALL create_tag(table => 'test.T', tag => 'test_tag', snapshot => 2);
+```
+
+{{< /tab >}}
+
{{< /tabs >}}
## Delete Tags
@@ -168,6 +176,16 @@ public class DeleteTag {
{{< /tab >}}
+
+{{< tab "Spark" >}}
+Run the following SQL:
+```sql
+CALL delete_tag(table => 'test.T', tag => 'test_tag');
+```
+
+{{< /tab >}}
+
{{< /tabs >}}
## Rollback to Tag
@@ -219,6 +237,16 @@ public class RollbackTo {
{{< /tab >}}
+{{< tab "Spark" >}}
+
+Run the following SQL:
+
+```sql
+CALL rollback(table => 'test.T', version => '2');
+```
+
+{{< /tab >}}
+
{{< /tabs >}}
## Work with Flink Savepoint
diff --git
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkCatalog.java
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkCatalog.java
index 81dc8c27e..d81d0517b 100644
---
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkCatalog.java
+++
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkCatalog.java
@@ -322,7 +322,7 @@ public class SparkCatalog implements TableCatalog,
ProcedureCatalog, SupportsNam
@Override
public Procedure loadProcedure(Identifier identifier) throws
NoSuchProcedureException {
if (isValidateNamespace(identifier.namespace())) {
- ProcedureBuilder builder = SparkProcedures.newBuilder(name);
+ ProcedureBuilder builder =
SparkProcedures.newBuilder(identifier.name());
if (builder != null) {
return builder.withTableCatalog(this).build();
}
diff --git
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkProcedures.java
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkProcedures.java
index 8eb08b659..84d37cedc 100644
---
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkProcedures.java
+++
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkProcedures.java
@@ -18,8 +18,11 @@
package org.apache.paimon.spark;
+import org.apache.paimon.spark.procedure.CreateTagProcedure;
+import org.apache.paimon.spark.procedure.DeleteTagProcedure;
import org.apache.paimon.spark.procedure.Procedure;
import org.apache.paimon.spark.procedure.ProcedureBuilder;
+import org.apache.paimon.spark.procedure.RollbackProcedure;
import org.apache.hadoop.shaded.com.google.common.collect.ImmutableMap;
@@ -42,6 +45,9 @@ public class SparkProcedures {
private static Map<String, Supplier<ProcedureBuilder>>
initProcedureBuilders() {
ImmutableMap.Builder<String, Supplier<ProcedureBuilder>>
procedureBuilders =
ImmutableMap.builder();
+ procedureBuilders.put("rollback", RollbackProcedure::builder);
+ procedureBuilders.put("create_tag", CreateTagProcedure::builder);
+ procedureBuilders.put("delete_tag", DeleteTagProcedure::builder);
return procedureBuilders.build();
}
}
diff --git
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkUtils.java
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkUtils.java
index 2a293261f..f315c7140 100644
---
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkUtils.java
+++
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkUtils.java
@@ -19,8 +19,23 @@ package org.apache.paimon.spark;
import org.apache.paimon.disk.IOManager;
import org.apache.paimon.disk.IOManagerImpl;
+import org.apache.paimon.utils.Pair;
+import org.apache.paimon.utils.Preconditions;
import org.apache.spark.SparkEnv;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.parser.ParseException;
+import org.apache.spark.sql.catalyst.parser.ParserInterface;
+import org.apache.spark.sql.connector.catalog.CatalogManager;
+import org.apache.spark.sql.connector.catalog.CatalogPlugin;
+import org.apache.spark.sql.connector.catalog.Identifier;
+
+import java.util.List;
+import java.util.function.BiFunction;
+import java.util.function.Function;
+
+import scala.collection.JavaConverters;
+import scala.collection.immutable.Seq;
/** Utils for Spark. */
public class SparkUtils {
@@ -29,4 +44,112 @@ public class SparkUtils {
String[] localDirs =
SparkEnv.get().blockManager().diskBlockManager().localDirsString();
return new IOManagerImpl(localDirs);
}
+
+ /** This mimics a class inside of Spark which is private inside of
LookupCatalog. */
+ public static class CatalogAndIdentifier {
+ private final CatalogPlugin catalog;
+ private final Identifier identifier;
+
+ public CatalogAndIdentifier(Pair<CatalogPlugin, Identifier>
identifier) {
+ this.catalog = identifier.getLeft();
+ this.identifier = identifier.getRight();
+ }
+
+ public CatalogPlugin catalog() {
+ return catalog;
+ }
+
+ public Identifier identifier() {
+ return identifier;
+ }
+ }
+
+ /**
+ * A modified version of Spark's
LookupCatalog.CatalogAndIdentifier.unapply. Attempts to find the
+ * catalog and identifier a multipart identifier represents.
+ *
+ * @param nameParts Multipart identifier representing a table
+ * @return The CatalogPlugin and Identifier for the table
+ */
+ public static <C, T> Pair<C, T> catalogAndIdentifier(
+ List<String> nameParts,
+ Function<String, C> catalogProvider,
+ BiFunction<String[], String, T> identifierProvider,
+ C currentCatalog,
+ String[] currentNamespace) {
+ Preconditions.checkArgument(
+ !nameParts.isEmpty(), "Cannot determine catalog and identifier
from empty name");
+
+ int lastElementIndex = nameParts.size() - 1;
+ String name = nameParts.get(lastElementIndex);
+
+ if (nameParts.size() == 1) {
+ // Only a single element, use current catalog and namespace
+ return Pair.of(currentCatalog,
identifierProvider.apply(currentNamespace, name));
+ } else {
+ C catalog = catalogProvider.apply(nameParts.get(0));
+ if (catalog == null) {
+ // The first element was not a valid catalog, treat it like
part of the namespace
+ String[] namespace = nameParts.subList(0,
lastElementIndex).toArray(new String[0]);
+ return Pair.of(currentCatalog,
identifierProvider.apply(namespace, name));
+ } else {
+ // Assume the first element is a valid catalog
+ String[] namespace = nameParts.subList(1,
lastElementIndex).toArray(new String[0]);
+ return Pair.of(catalog, identifierProvider.apply(namespace,
name));
+ }
+ }
+ }
+
+ /**
+ * A modified version of Spark's
LookupCatalog.CatalogAndIdentifier.unapply. Attempts to find the
+ * catalog and identifier a multipart identifier represents.
+ *
+ * @param spark Spark session to use for resolution
+ * @param nameParts Multipart identifier representing a table
+ * @param defaultCatalog Catalog to use if none is specified
+ * @return The CatalogPlugin and Identifier for the table
+ */
+ public static CatalogAndIdentifier catalogAndIdentifier(
+ SparkSession spark, List<String> nameParts, CatalogPlugin
defaultCatalog) {
+ CatalogManager catalogManager = spark.sessionState().catalogManager();
+
+ String[] currentNamespace;
+ if (defaultCatalog.equals(catalogManager.currentCatalog())) {
+ currentNamespace = catalogManager.currentNamespace();
+ } else {
+ currentNamespace = defaultCatalog.defaultNamespace();
+ }
+
+ Pair<CatalogPlugin, Identifier> catalogIdentifier =
+ SparkUtils.catalogAndIdentifier(
+ nameParts,
+ catalogName -> {
+ try {
+ return catalogManager.catalog(catalogName);
+ } catch (Exception e) {
+ return null;
+ }
+ },
+ Identifier::of,
+ defaultCatalog,
+ currentNamespace);
+ return new CatalogAndIdentifier(catalogIdentifier);
+ }
+
+ public static CatalogAndIdentifier catalogAndIdentifier(
+ SparkSession spark, String name, CatalogPlugin defaultCatalog)
throws ParseException {
+ ParserInterface parser = spark.sessionState().sqlParser();
+ Seq<String> multiPartIdentifier =
parser.parseMultipartIdentifier(name).toIndexedSeq();
+ List<String> javaMultiPartIdentifier =
JavaConverters.seqAsJavaList(multiPartIdentifier);
+ return catalogAndIdentifier(spark, javaMultiPartIdentifier,
defaultCatalog);
+ }
+
+ public static CatalogAndIdentifier catalogAndIdentifier(
+ String description, SparkSession spark, String name, CatalogPlugin
defaultCatalog) {
+ try {
+ return catalogAndIdentifier(spark, name, defaultCatalog);
+ } catch (ParseException e) {
+ throw new IllegalArgumentException("Cannot parse " + description +
": " + name, e);
+ }
+ }
}
diff --git
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/BaseProcedure.java
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/BaseProcedure.java
new file mode 100644
index 000000000..797a41833
--- /dev/null
+++
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/BaseProcedure.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.procedure;
+
+import org.apache.paimon.spark.SparkTable;
+import org.apache.paimon.spark.SparkUtils;
+import org.apache.paimon.utils.Preconditions;
+
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
+import org.apache.spark.sql.connector.catalog.CatalogPlugin;
+import org.apache.spark.sql.connector.catalog.Identifier;
+import org.apache.spark.sql.connector.catalog.Table;
+import org.apache.spark.sql.connector.catalog.TableCatalog;
+import org.apache.spark.sql.execution.CacheManager;
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation;
+
+import java.util.function.Function;
+
+import scala.Option;
+
+/** A base class for procedures. */
+abstract class BaseProcedure implements Procedure {
+
+ private final SparkSession spark;
+ private final TableCatalog tableCatalog;
+
+ protected BaseProcedure(TableCatalog tableCatalog) {
+ this.spark = SparkSession.active();
+ this.tableCatalog = tableCatalog;
+ }
+
+ protected Identifier toIdentifier(String identifierAsString, String
argName) {
+ SparkUtils.CatalogAndIdentifier catalogAndIdentifier =
+ toCatalogAndIdentifier(identifierAsString, argName,
tableCatalog);
+
+ Preconditions.checkArgument(
+ catalogAndIdentifier.catalog().equals(tableCatalog),
+ "Cannot run procedure in catalog '%s': '%s' is a table in
catalog '%s'",
+ tableCatalog.name(),
+ identifierAsString,
+ catalogAndIdentifier.catalog().name());
+
+ return catalogAndIdentifier.identifier();
+ }
+
+ protected SparkUtils.CatalogAndIdentifier toCatalogAndIdentifier(
+ String identifierAsString, String argName, CatalogPlugin catalog) {
+ Preconditions.checkArgument(
+ identifierAsString != null && !identifierAsString.isEmpty(),
+ "Cannot handle an empty identifier for argument %s",
+ argName);
+
+ return SparkUtils.catalogAndIdentifier(
+ "identifier for arg " + argName, spark, identifierAsString,
catalog);
+ }
+
+ protected <T> T modifyPaimonTable(
+ Identifier ident, Function<org.apache.paimon.table.Table, T> func)
{
+ return execute(ident, true, func);
+ }
+
+ private <T> T execute(
+ Identifier ident,
+ boolean refreshSparkCache,
+ Function<org.apache.paimon.table.Table, T> func) {
+ SparkTable sparkTable = loadSparkTable(ident);
+ org.apache.paimon.table.Table table = sparkTable.getTable();
+
+ T result = func.apply(table);
+
+ if (refreshSparkCache) {
+ refreshSparkCache(ident, sparkTable);
+ }
+
+ return result;
+ }
+
+ protected SparkTable loadSparkTable(Identifier ident) {
+ try {
+ Table table = tableCatalog.loadTable(ident);
+ Preconditions.checkArgument(
+ table instanceof SparkTable, "%s is not %s", ident,
SparkTable.class.getName());
+ return (SparkTable) table;
+ } catch (NoSuchTableException e) {
+ String errMsg =
+ String.format(
+ "Couldn't load table '%s' in catalog '%s'", ident,
tableCatalog.name());
+ throw new RuntimeException(errMsg, e);
+ }
+ }
+
+ protected void refreshSparkCache(Identifier ident, Table table) {
+ CacheManager cacheManager = spark.sharedState().cacheManager();
+ DataSourceV2Relation relation =
+ DataSourceV2Relation.create(table, Option.apply(tableCatalog),
Option.apply(ident));
+ cacheManager.recacheByPlan(spark, relation);
+ }
+
+ protected InternalRow newInternalRow(Object... values) {
+ return new GenericInternalRow(values);
+ }
+
+ protected abstract static class Builder<T extends BaseProcedure>
implements ProcedureBuilder {
+ private TableCatalog tableCatalog;
+
+ @Override
+ public Builder<T> withTableCatalog(TableCatalog newTableCatalog) {
+ this.tableCatalog = newTableCatalog;
+ return this;
+ }
+
+ @Override
+ public T build() {
+ return doBuild();
+ }
+
+ protected abstract T doBuild();
+
+ TableCatalog tableCatalog() {
+ return tableCatalog;
+ }
+ }
+}
diff --git
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/CreateTagProcedure.java
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/CreateTagProcedure.java
new file mode 100644
index 000000000..c0c63b9b0
--- /dev/null
+++
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/CreateTagProcedure.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.procedure;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.connector.catalog.Identifier;
+import org.apache.spark.sql.connector.catalog.TableCatalog;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+import static org.apache.spark.sql.types.DataTypes.LongType;
+import static org.apache.spark.sql.types.DataTypes.StringType;
+
+/** A procedure to create a tag. */
+public class CreateTagProcedure extends BaseProcedure {
+
+ private static final ProcedureParameter[] PARAMETERS =
+ new ProcedureParameter[] {
+ ProcedureParameter.required("table", StringType),
+ ProcedureParameter.required("tag", StringType),
+ ProcedureParameter.required("snapshot", LongType)
+ };
+
+ private static final StructType OUTPUT_TYPE =
+ new StructType(
+ new StructField[] {
+ new StructField("result", DataTypes.BooleanType, true,
Metadata.empty())
+ });
+
+ protected CreateTagProcedure(TableCatalog tableCatalog) {
+ super(tableCatalog);
+ }
+
+ @Override
+ public ProcedureParameter[] parameters() {
+ return PARAMETERS;
+ }
+
+ @Override
+ public StructType outputType() {
+ return OUTPUT_TYPE;
+ }
+
+ @Override
+ public InternalRow[] call(InternalRow args) {
+ Identifier tableIdent = toIdentifier(args.getString(0),
PARAMETERS[0].name());
+ String tag = args.getString(1);
+ long snapshot = args.getLong(2);
+
+ return modifyPaimonTable(
+ tableIdent,
+ table -> {
+ table.createTag(tag, snapshot);
+ InternalRow outputRow = newInternalRow(true);
+ return new InternalRow[] {outputRow};
+ });
+ }
+
+ public static ProcedureBuilder builder() {
+ return new BaseProcedure.Builder<CreateTagProcedure>() {
+ @Override
+ public CreateTagProcedure doBuild() {
+ return new CreateTagProcedure(tableCatalog());
+ }
+ };
+ }
+
+ @Override
+ public String description() {
+ return "CreateTagProcedure";
+ }
+}
diff --git
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/DeleteTagProcedure.java
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/DeleteTagProcedure.java
new file mode 100644
index 000000000..3a499c75a
--- /dev/null
+++
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/DeleteTagProcedure.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.procedure;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.connector.catalog.Identifier;
+import org.apache.spark.sql.connector.catalog.TableCatalog;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+import static org.apache.spark.sql.types.DataTypes.StringType;
+
+/** A procedure to delete a tag. */
+public class DeleteTagProcedure extends BaseProcedure {
+
+ private static final ProcedureParameter[] PARAMETERS =
+ new ProcedureParameter[] {
+ ProcedureParameter.required("table", StringType),
+ ProcedureParameter.required("tag", StringType)
+ };
+
+ private static final StructType OUTPUT_TYPE =
+ new StructType(
+ new StructField[] {
+ new StructField("result", DataTypes.BooleanType, true,
Metadata.empty())
+ });
+
+ protected DeleteTagProcedure(TableCatalog tableCatalog) {
+ super(tableCatalog);
+ }
+
+ @Override
+ public ProcedureParameter[] parameters() {
+ return PARAMETERS;
+ }
+
+ @Override
+ public StructType outputType() {
+ return OUTPUT_TYPE;
+ }
+
+ @Override
+ public InternalRow[] call(InternalRow args) {
+ Identifier tableIdent = toIdentifier(args.getString(0),
PARAMETERS[0].name());
+ String tag = args.getString(1);
+
+ return modifyPaimonTable(
+ tableIdent,
+ table -> {
+ table.deleteTag(tag);
+ InternalRow outputRow = newInternalRow(true);
+ return new InternalRow[] {outputRow};
+ });
+ }
+
+ public static ProcedureBuilder builder() {
+ return new BaseProcedure.Builder<DeleteTagProcedure>() {
+ @Override
+ public DeleteTagProcedure doBuild() {
+ return new DeleteTagProcedure(tableCatalog());
+ }
+ };
+ }
+
+ @Override
+ public String description() {
+ return "DeleteTagProcedure";
+ }
+}
diff --git
a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/RollbackProcedure.java
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/RollbackProcedure.java
new file mode 100644
index 000000000..d5362836b
--- /dev/null
+++
b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/procedure/RollbackProcedure.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.spark.procedure;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.connector.catalog.Identifier;
+import org.apache.spark.sql.connector.catalog.TableCatalog;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+import static org.apache.spark.sql.types.DataTypes.StringType;
+
+/** A procedure to roll back to a snapshot or a tag. */
+public class RollbackProcedure extends BaseProcedure {
+
+ private static final ProcedureParameter[] PARAMETERS =
+ new ProcedureParameter[] {
+ ProcedureParameter.required("table", StringType),
+ // snapshot id or tag name
+ ProcedureParameter.required("version", StringType)
+ };
+
+ private static final StructType OUTPUT_TYPE =
+ new StructType(
+ new StructField[] {
+ new StructField("result", DataTypes.BooleanType, true,
Metadata.empty())
+ });
+
+ private RollbackProcedure(TableCatalog tableCatalog) {
+ super(tableCatalog);
+ }
+
+ @Override
+ public ProcedureParameter[] parameters() {
+ return PARAMETERS;
+ }
+
+ @Override
+ public StructType outputType() {
+ return OUTPUT_TYPE;
+ }
+
+ @Override
+ public InternalRow[] call(InternalRow args) {
+ Identifier tableIdent = toIdentifier(args.getString(0),
PARAMETERS[0].name());
+ String version = args.getString(1);
+
+ return modifyPaimonTable(
+ tableIdent,
+ table -> {
+ if (version.chars().allMatch(Character::isDigit)) {
+ table.rollbackTo(Long.parseLong(version));
+ } else {
+ table.rollbackTo(version);
+ }
+ InternalRow outputRow = newInternalRow(true);
+ return new InternalRow[] {outputRow};
+ });
+ }
+
+ public static ProcedureBuilder builder() {
+ return new BaseProcedure.Builder<RollbackProcedure>() {
+ @Override
+ public RollbackProcedure doBuild() {
+ return new RollbackProcedure(tableCatalog());
+ }
+ };
+ }
+
+ @Override
+ public String description() {
+ return "RollbackProcedure";
+ }
+}
diff --git
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/extensions/PaimonSparkSessionExtensions.scala
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/extensions/PaimonSparkSessionExtensions.scala
index 4e1117b26..dae604073 100644
---
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/extensions/PaimonSparkSessionExtensions.scala
+++
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/paimon/spark/extensions/PaimonSparkSessionExtensions.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.catalyst.analysis.{CoerceArguments,
PaimonAnalysis, ResolveProcedures}
import
org.apache.spark.sql.catalyst.parser.extensions.PaimonSparkSqlExtensionsParser
import org.apache.spark.sql.catalyst.plans.logical.PaimonTableValuedFunctions
+import
org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Strategy
/** Spark session extension to extends the syntax and adds the rules. */
class PaimonSparkSessionExtensions extends (SparkSessionExtensions => Unit) {
@@ -40,5 +41,8 @@ class PaimonSparkSessionExtensions extends
(SparkSessionExtensions => Unit) {
extensions.injectTableFunction(
PaimonTableValuedFunctions.getTableValueFunctionInjection(fnName))
}
+
+ // planner extensions
+ extensions.injectPlannerStrategy(spark =>
ExtendedDataSourceV2Strategy(spark))
}
}
diff --git
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CallExec.scala
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CallExec.scala
new file mode 100644
index 000000000..492fe6027
--- /dev/null
+++
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CallExec.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.paimon.spark.procedure.Procedure
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.util.truncatedString
+
+case class CallExec(output: Seq[Attribute], procedure: Procedure, input:
InternalRow)
+ extends LeafV2CommandExec {
+
+ override protected def run(): Seq[InternalRow] = {
+ procedure.call(input)
+ }
+
+ override def simpleString(maxFields: Int): String = {
+ s"CallExec${truncatedString(output, "[", ", ", "]", maxFields)}
${procedure.description}"
+ }
+}
diff --git
a/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala
new file mode 100644
index 000000000..cde5816d8
--- /dev/null
+++
b/paimon-spark/paimon-spark-common/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.{SparkSession, Strategy}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression,
GenericInternalRow, PredicateHelper}
+import org.apache.spark.sql.catalyst.plans.logical.{CallCommand, LogicalPlan}
+import org.apache.spark.sql.execution.SparkPlan
+
+case class ExtendedDataSourceV2Strategy(spark: SparkSession) extends Strategy
with PredicateHelper {
+
+ override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case c @ CallCommand(procedure, args) =>
+ val input = buildInternalRow(args)
+ CallExec(c.output, procedure, input) :: Nil
+ case _ => Nil
+ }
+
+ private def buildInternalRow(exprs: Seq[Expression]): InternalRow = {
+ val values = new Array[Any](exprs.size)
+ for (index <- exprs.indices) {
+ values(index) = exprs(index).eval()
+ }
+ new GenericInternalRow(values)
+ }
+
+}
diff --git
a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/CreateAndDeleteTagProcedureTest.scala
b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/CreateAndDeleteTagProcedureTest.scala
new file mode 100644
index 000000000..eae7a832d
--- /dev/null
+++
b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/CreateAndDeleteTagProcedureTest.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.paimon.spark.sql
+
+import org.apache.paimon.spark.PaimonSparkTestBase
+
+import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.streaming.StreamTest
+
+/**
+ * End-to-end check of the `create_tag` and `delete_tag` Spark procedures.
+ *
+ * A streaming query appends three micro-batches to a primary-key table. A tag is then created
+ * on snapshot 2 and deleted again, and the `T$tags` system table is queried after each call.
+ */
+class CreateAndDeleteTagProcedureTest extends PaimonSparkTestBase with StreamTest {
+
+  import testImplicits._
+
+  test("Paimon Procedure: create and delete tag") {
+    failAfter(streamingTimeout) {
+      withTempDir {
+        checkpointDir =>
+          // Primary-key change-log table that the streaming sink below appends to.
+          spark.sql(s"""
+                       |CREATE TABLE T (a INT, b STRING)
+                       |TBLPROPERTIES ('primary-key'='a', 'write-mode'='change-log', 'bucket'='3')
+                       |""".stripMargin)
+          val location = loadTable("T").location().getPath
+
+          // Each micro-batch of the in-memory source is written to the table path
+          // through `foreachBatch`.
+          val inputData = MemoryStream[(Int, String)]
+          val stream = inputData
+            .toDS()
+            .toDF("a", "b")
+            .writeStream
+            .option("checkpointLocation", checkpointDir.getCanonicalPath)
+            .foreachBatch {
+              (batch: Dataset[Row], _: Long) =>
+                batch.write.format("paimon").mode("append").save(location)
+            }
+            .start()
+
+          val query = () => spark.sql("SELECT * FROM T ORDER BY a")
+
+          try {
+            // snapshot-1
+            inputData.addData((1, "a"))
+            stream.processAllAvailable()
+            checkAnswer(query(), Row(1, "a") :: Nil)
+
+            // snapshot-2
+            inputData.addData((2, "b"))
+            stream.processAllAvailable()
+            checkAnswer(query(), Row(1, "a") :: Row(2, "b") :: Nil)
+
+            // snapshot-3
+            inputData.addData((2, "b2"))
+            stream.processAllAvailable()
+            checkAnswer(query(), Row(1, "a") :: Row(2, "b2") :: Nil)
+
+            // Tag snapshot 2; the procedure is expected to return a single `true` row.
+            checkAnswer(
+              spark.sql("CALL create_tag(table => 'test.T', tag => 'test_tag', snapshot => 2)"),
+              Row(true) :: Nil)
+            // The new tag must be listed by the tags system table.
+            checkAnswer(
+              spark.sql("SELECT tag_name FROM paimon.test.`T$tags`"),
+              Row("test_tag") :: Nil)
+
+            // Delete the tag again; afterwards the tags system table must be empty.
+            checkAnswer(
+              spark.sql("CALL delete_tag(table => 'test.T', tag => 'test_tag')"),
+              Row(true) :: Nil)
+            checkAnswer(spark.sql("SELECT tag_name FROM paimon.test.`T$tags`"), Nil)
+          } finally {
+            // Always stop the streaming query, even when an assertion above fails.
+            stream.stop()
+          }
+      }
+    }
+  }
+}
diff --git
a/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/RollbackProcedureTest.scala
b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/RollbackProcedureTest.scala
new file mode 100644
index 000000000..ace04f894
--- /dev/null
+++
b/paimon-spark/paimon-spark-common/src/test/scala/org/apache/paimon/spark/sql/RollbackProcedureTest.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.paimon.spark.sql
+
+import org.apache.paimon.spark.PaimonSparkTestBase
+
+import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.streaming.StreamTest
+
+/**
+ * End-to-end check of the `rollback` Spark procedure.
+ *
+ * A streaming query appends three micro-batches to a primary-key table and a tag is created on
+ * snapshot 1. The table is then rolled back first to snapshot 2 (by id) and then to the tag,
+ * and the visible rows are verified after each rollback.
+ */
+class RollbackProcedureTest extends PaimonSparkTestBase with StreamTest {
+
+  import testImplicits._
+
+  test("Paimon Procedure: rollback to snapshot and tag") {
+    failAfter(streamingTimeout) {
+      withTempDir {
+        checkpointDir =>
+          // Primary-key change-log table that the streaming sink below appends to.
+          spark.sql(s"""
+                       |CREATE TABLE T (a INT, b STRING)
+                       |TBLPROPERTIES ('primary-key'='a', 'write-mode'='change-log', 'bucket'='3')
+                       |""".stripMargin)
+          val location = loadTable("T").location().getPath
+
+          // Each micro-batch of the in-memory source is written to the table path
+          // through `foreachBatch`.
+          val inputData = MemoryStream[(Int, String)]
+          val stream = inputData
+            .toDS()
+            .toDF("a", "b")
+            .writeStream
+            .option("checkpointLocation", checkpointDir.getCanonicalPath)
+            .foreachBatch {
+              (batch: Dataset[Row], _: Long) =>
+                batch.write.format("paimon").mode("append").save(location)
+            }
+            .start()
+
+          val query = () => spark.sql("SELECT * FROM T ORDER BY a")
+
+          try {
+            // snapshot-1
+            inputData.addData((1, "a"))
+            stream.processAllAvailable()
+            checkAnswer(query(), Row(1, "a") :: Nil)
+
+            // Tag snapshot 1 so that it can be used as a rollback target by name later on.
+            checkAnswer(
+              spark.sql("CALL create_tag(table => 'test.T', tag => 'test_tag', snapshot => 1)"),
+              Row(true) :: Nil)
+
+            // snapshot-2
+            inputData.addData((2, "b"))
+            stream.processAllAvailable()
+            checkAnswer(query(), Row(1, "a") :: Row(2, "b") :: Nil)
+
+            // snapshot-3
+            inputData.addData((2, "b2"))
+            stream.processAllAvailable()
+            checkAnswer(query(), Row(1, "a") :: Row(2, "b2") :: Nil)
+
+            // `T_exception` is never created in this test, so the call is expected to fail.
+            assertThrows[RuntimeException] {
+              spark.sql("CALL rollback(table => 'test.T_exception', version => '2')")
+            }
+
+            // rollback to snapshot
+            checkAnswer(
+              spark.sql("CALL rollback(table => 'test.T', version => '2')"),
+              Row(true) :: Nil)
+            checkAnswer(query(), Row(1, "a") :: Row(2, "b") :: Nil)
+
+            // rollback to tag
+            checkAnswer(
+              spark.sql("CALL rollback(table => 'test.T', version => 'test_tag')"),
+              Row(true) :: Nil)
+            checkAnswer(query(), Row(1, "a") :: Nil)
+          } finally {
+            // Always stop the streaming query, even when an assertion above fails.
+            stream.stop()
+          }
+      }
+    }
+  }
+}