This is an automated email from the ASF dual-hosted git repository.
jshao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/gravitino.git
The following commit(s) were added to refs/heads/main by this push:
new 615b8eb378 [#6814] feat(core): Support update aliases for model
version. (#7037)
615b8eb378 is described below
commit 615b8eb3783bda54287715c94ee68867f5d576e9
Author: Lord of Abyss <[email protected]>
AuthorDate: Thu May 8 10:58:41 2025 +0800
[#6814] feat(core): Support update aliases for model version. (#7037)
### What changes were proposed in this pull request?
Support updating the aliases of a model version.
- [X] PR1: Add the ModelVersionChange API for alias updates, implement the
alias-update logic in the model catalog and the JDBC backend, and update the
related events.
- [ ] PR2: Add a REST endpoint for model version changes, and add Java and
Python client support for updating model version aliases.
### Why are the changes needed?
Fix: #6814
### Does this PR introduce _any_ user-facing change?
no.
### How was this patch tested?
Tested locally, and unit tests were added. When the aliases are updated, all
old aliases are removed and all new aliases are inserted, keeping the version
number unchanged.

---------
Signed-off-by: dependabot[bot] <[email protected]>
Signed-off-by: George T. C. Lai <[email protected]>
Co-authored-by: mchades <[email protected]>
Co-authored-by: roryqi <[email protected]>
Co-authored-by: Qiming Teng <[email protected]>
Co-authored-by: gavin.wang <[email protected]>
Co-authored-by: FANNG <[email protected]>
Co-authored-by: dependabot[bot]
<49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jerry Shao <[email protected]>
Co-authored-by: yangyang zhong <[email protected]>
Co-authored-by: Eric Chang <[email protected]>
Co-authored-by: Mini Yu <[email protected]>
Co-authored-by: Kang <[email protected]>
Co-authored-by: Justin Mclean <[email protected]>
Co-authored-by: Danhua Wang <[email protected]>
Co-authored-by: yunchi <[email protected]>
Co-authored-by: RickyMa <[email protected]>
Co-authored-by: Yuhui <[email protected]>
Co-authored-by: Tian Lu <[email protected]>
Co-authored-by: tian bao <[email protected]>
Co-authored-by: Jimmy Lee <[email protected]>
Co-authored-by: Brijesh Thummar <[email protected]>
Co-authored-by: Qian Xia <[email protected]>
Co-authored-by: Xiaojian Sun <[email protected]>
Co-authored-by: Cyber Star <[email protected]>
Co-authored-by: George T. C. Lai <[email protected]>
Co-authored-by: AndreVale69 <[email protected]>
Co-authored-by: Zhengke Zhou <[email protected]>
---
.../apache/gravitino/model/ModelVersionChange.java | 113 +++++++
.../gravitino/model/TestModelVersionChange.java | 110 +++++++
.../catalog/model/ModelCatalogOperations.java | 28 ++
.../catalog/model/TestModelCatalogOperations.java | 359 +++++++++++++++++++++
.../mapper/ModelVersionAliasRelMapper.java | 6 +
.../ModelVersionAliasSQLProviderFactory.java | 5 +
.../base/ModelVersionAliasRelBaseSQLProvider.java | 16 +
.../service/ModelVersionMetaService.java | 61 +++-
.../storage/relational/utils/POConverters.java | 41 +++
.../catalog/TestModelOperationDispatcher.java | 290 +++++++++++++++++
.../gravitino/connector/TestCatalogOperations.java | 40 +++
.../service/TestModelVersionMetaService.java | 83 +++++
12 files changed, 1141 insertions(+), 11 deletions(-)
diff --git
a/api/src/main/java/org/apache/gravitino/model/ModelVersionChange.java
b/api/src/main/java/org/apache/gravitino/model/ModelVersionChange.java
index ced922cffc..4e28071e05 100644
--- a/api/src/main/java/org/apache/gravitino/model/ModelVersionChange.java
+++ b/api/src/main/java/org/apache/gravitino/model/ModelVersionChange.java
@@ -19,7 +19,14 @@
package org.apache.gravitino.model;
+import com.google.common.base.Joiner;
+import com.google.common.collect.ImmutableSortedSet;
+import com.google.common.collect.Lists;
+import java.util.Arrays;
+import java.util.List;
import java.util.Objects;
+import java.util.Set;
+import java.util.stream.Collectors;
import org.apache.gravitino.annotation.Evolving;
/**
@@ -29,6 +36,8 @@ import org.apache.gravitino.annotation.Evolving;
*/
@Evolving
public interface ModelVersionChange {
+ /**
+ * A Joiner that concatenates values with "," (no surrounding whitespace) and silently skips
+ * null elements; used by {@link UpdateAliases#toString()}. As an interface field it is
+ * implicitly public, static and final.
+ */
+ // NOTE(review): this helper becomes part of the public ModelVersionChange API surface;
+ // consider making it a private static field of UpdateAliases instead.
+ Joiner COMMA_JOINER = Joiner.on(",").skipNulls();
/**
* Create a ModelVersionChange for updating the comment of a model version.
@@ -71,6 +80,22 @@ public interface ModelVersionChange {
return new ModelVersionChange.UpdateUri(newUri);
}
+ /**
+ * Create a ModelVersionChange for updating the aliases of a model version.
+ *
+ * <p>A {@code null} array is treated as an empty array. The arrays are copied into sorted,
+ * de-duplicated sets by {@link UpdateAliases}, so element order and duplicates are not kept.
+ *
+ * @param aliasesToAdd The new aliases to be added for the model version, or {@code null} for
+ * none.
+ * @param aliasesToDelete The aliases to be removed from the model version, or {@code null}
+ * for none.
+ * @return A new ModelVersionChange instance for updating the aliases of a model version.
+ */
+ static ModelVersionChange updateAliases(String[] aliasesToAdd, String[]
aliasesToDelete) {
+ // Normalize null inputs so UpdateAliases always receives non-null lists.
+ String[] toAdd = aliasesToAdd == null ? new String[0] : aliasesToAdd;
+ String[] toDelete = aliasesToDelete == null ? new String[0] :
aliasesToDelete;
+
+ // UpdateAliases turns these lists into immutable sorted sets.
+ return new UpdateAliases(
+ Arrays.stream(toAdd).collect(Collectors.toList()),
+ Arrays.stream(toDelete).collect(Collectors.toList()));
+ }
+
/** A ModelVersionChange to update the model version comment. */
final class UpdateComment implements ModelVersionChange {
@@ -327,4 +352,92 @@ public interface ModelVersionChange {
return "UpdateUri " + newUri;
}
}
+
+ /**
+ * Represents an update to a model version's aliases, specifying which aliases to add and which
+ * to remove.
+ *
+ * <p>Both alias collections are stored as immutable sets kept in natural (lexicographic) String
+ * order; duplicate entries in the inputs are collapsed.
+ */
+ final class UpdateAliases implements ModelVersionChange {
+ private final ImmutableSortedSet<String> aliasesToAdd;
+ private final ImmutableSortedSet<String> aliasesToDelete;
+
+ /**
+ * Constructs a new aliases-update operation, specifying the aliases to add and remove.
+ *
+ * <p>NOTE(review): {@code ImmutableSortedSet.copyOf} rejects {@code null} elements (Guava
+ * throws NullPointerException), so a list containing a null alias fails here — confirm that
+ * callers never pass null entries.
+ *
+ * @param aliasesToAdd the aliases to add, or null for none
+ * @param aliasesToDelete the aliases to remove, or null for none
+ */
+ public UpdateAliases(List<String> aliasesToAdd, List<String>
aliasesToDelete) {
+ this.aliasesToAdd =
+ ImmutableSortedSet.copyOf(aliasesToAdd != null ? aliasesToAdd :
Lists.newArrayList());
+ this.aliasesToDelete =
+ ImmutableSortedSet.copyOf(
+ aliasesToDelete != null ? aliasesToDelete :
Lists.newArrayList());
+ }
+
+ /**
+ * Returns the set of aliases to add.
+ *
+ * @return an immutable, sorted set of aliases to add; never null (empty when none were given)
+ */
+ public Set<String> aliasesToAdd() {
+ return aliasesToAdd;
+ }
+
+ /**
+ * Returns the set of aliases to remove.
+ *
+ * @return an immutable, sorted set of aliases to remove; never null (empty when none were
+ * given)
+ */
+ public Set<String> aliasesToDelete() {
+ return aliasesToDelete;
+ }
+
+ /**
+ * Compares this UpdateAliases instance with another object for equality. The comparison is
+ * based on both the aliases to add and the aliases to remove, compared as sets.
+ *
+ * @param o The object to compare with this instance.
+ * @return {@code true} if the given object represents the same aliases-update operation;
+ * {@code false} otherwise.
+ */
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (!(o instanceof UpdateAliases)) return false;
+ UpdateAliases that = (UpdateAliases) o;
+ return aliasesToAdd.equals(that.aliasesToAdd) &&
aliasesToDelete.equals(that.aliasesToDelete);
+ }
+
+ /**
+ * Generates a hash code for this UpdateAliases instance. The hash code is based on both the
+ * aliases to add and the aliases to remove, consistent with {@link #equals(Object)}.
+ *
+ * @return A hash code value for this aliases-update operation.
+ */
+ @Override
+ public int hashCode() {
+ return Objects.hash(aliasesToAdd, aliasesToDelete);
+ }
+
+ /**
+ * Provides a string representation of this UpdateAliases instance, in the form {@code
+ * UpdateAlias AliasToAdd: (a,b) AliasToDelete: (c,d)}, with each set's aliases in sorted order
+ * joined by commas. Note the prefix is the singular "UpdateAlias"; existing tests assert this
+ * exact text.
+ *
+ * @return A string summary of the UpdateAliases instance.
+ */
+ @Override
+ public String toString() {
+ return "UpdateAlias "
+ + "AliasToAdd: ("
+ + COMMA_JOINER.join(aliasesToAdd)
+ + ")"
+ + " "
+ + "AliasToDelete: ("
+ + COMMA_JOINER.join(aliasesToDelete)
+ + ")";
+ }
+ }
}
diff --git
a/api/src/test/java/org/apache/gravitino/model/TestModelVersionChange.java
b/api/src/test/java/org/apache/gravitino/model/TestModelVersionChange.java
index c4e24110e3..d2b6526841 100644
--- a/api/src/test/java/org/apache/gravitino/model/TestModelVersionChange.java
+++ b/api/src/test/java/org/apache/gravitino/model/TestModelVersionChange.java
@@ -19,6 +19,9 @@
package org.apache.gravitino.model;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Lists;
+import java.util.List;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
@@ -199,4 +202,111 @@ public class TestModelVersionChange {
Assertions.assertNotEquals(modelVersionChange1.hashCode(),
modelVersionChange3.hashCode());
Assertions.assertNotEquals(modelVersionChange2.hashCode(),
modelVersionChange3.hashCode());
}
+
+ /**
+ * Verifies that the static factory {@code updateAliases(String[], String[])} builds an
+ * UpdateAliases change holding both alias sets and renders the expected toString().
+ */
+ @Test
+ void testCreateUpdateVersionAliasUseStaticMethod() {
+ String[] aliasesToAdd = {"alias add 1", "alias add 2"};
+ String[] aliasesToDelete = {"alias delete 1", "alias delete 2"};
+
+ ModelVersionChange modelVersionChange =
+ ModelVersionChange.updateAliases(aliasesToAdd, aliasesToDelete);
+
+ Assertions.assertEquals(ModelVersionChange.UpdateAliases.class,
modelVersionChange.getClass());
+
+ ModelVersionChange.UpdateAliases updateAliasesChange =
+ (ModelVersionChange.UpdateAliases) modelVersionChange;
+ Assertions.assertEquals(
+ ImmutableSet.of("alias add 1", "alias add 2"),
updateAliasesChange.aliasesToAdd());
+ Assertions.assertEquals(
+ ImmutableSet.of("alias delete 1", "alias delete 2"),
updateAliasesChange.aliasesToDelete());
+ // Aliases appear in sorted order, joined by "," with no spaces.
+ Assertions.assertEquals(
+ "UpdateAlias "
+ + "AliasToAdd: (alias add 1,alias add 2)"
+ + " "
+ + "AliasToDelete: (alias "
+ + "delete 1,alias delete 2)",
+ updateAliasesChange.toString());
+ }
+
+ /** Verifies that null arrays passed to the static factory are treated as empty alias sets. */
+ @Test
+ void testCreateUpdateVersionAliasUseStaticMethodWithNull() {
+ ModelVersionChange modelVersionChange =
ModelVersionChange.updateAliases(null, null);
+ Assertions.assertEquals(ModelVersionChange.UpdateAliases.class,
modelVersionChange.getClass());
+
+ ModelVersionChange.UpdateAliases updateAliasesChange =
+ (ModelVersionChange.UpdateAliases) modelVersionChange;
+ Assertions.assertEquals(ImmutableSet.of(),
updateAliasesChange.aliasesToAdd());
+ Assertions.assertEquals(ImmutableSet.of(),
updateAliasesChange.aliasesToDelete());
+ Assertions.assertEquals(
+ "UpdateAlias AliasToAdd: () AliasToDelete: ()",
updateAliasesChange.toString());
+ }
+
+ /**
+ * Verifies that constructing UpdateAliases directly from lists yields the same sets and
+ * toString() as the static factory.
+ */
+ @Test
+ void testCreateUpdateVersionAliasUseConstructor() {
+ List<String> aliasesToAdd = Lists.newArrayList("alias add 1", "alias add
2");
+ List<String> aliasesToDelete = Lists.newArrayList("alias delete 1", "alias
delete 2");
+
+ ModelVersionChange modelVersionChange =
+ new ModelVersionChange.UpdateAliases(aliasesToAdd, aliasesToDelete);
+
+ Assertions.assertEquals(ModelVersionChange.UpdateAliases.class,
modelVersionChange.getClass());
+
+ ModelVersionChange.UpdateAliases updateAliasesChange =
+ (ModelVersionChange.UpdateAliases) modelVersionChange;
+ Assertions.assertEquals(
+ ImmutableSet.of("alias add 1", "alias add 2"),
updateAliasesChange.aliasesToAdd());
+ Assertions.assertEquals(
+ ImmutableSet.of("alias delete 1", "alias delete 2"),
updateAliasesChange.aliasesToDelete());
+ Assertions.assertEquals(
+ "UpdateAlias "
+ + "AliasToAdd: (alias add 1,alias add 2)"
+ + " "
+ + "AliasToDelete: (alias "
+ + "delete 1,alias delete 2)",
+ updateAliasesChange.toString());
+ }
+
+ /** Verifies that null lists passed to the UpdateAliases constructor become empty sets. */
+ @Test
+ void testCreateUpdateVersionAliasUseConstructorWithNull() {
+ ModelVersionChange modelVersionChange = new
ModelVersionChange.UpdateAliases(null, null);
+ Assertions.assertEquals(ModelVersionChange.UpdateAliases.class,
modelVersionChange.getClass());
+
+ ModelVersionChange.UpdateAliases updateAliasesChange =
+ (ModelVersionChange.UpdateAliases) modelVersionChange;
+ Assertions.assertEquals(ImmutableSet.of(),
updateAliasesChange.aliasesToAdd());
+ Assertions.assertEquals(ImmutableSet.of(),
updateAliasesChange.aliasesToDelete());
+ Assertions.assertEquals(
+ "UpdateAlias AliasToAdd: () AliasToDelete: ()",
updateAliasesChange.toString());
+ }
+
+ /**
+ * Verifies equals()/hashCode(): changes built from equal alias lists are equal, and changing
+ * either the add set or the delete set makes them unequal.
+ */
+ @Test
+ void testUpdateVersionAliasChangeEquals() {
+ List<String> aliasesToAdd = Lists.newArrayList("alias add 1", "alias add
2");
+ List<String> aliasesToDelete = Lists.newArrayList("alias delete 1", "alias
delete 2");
+
+ List<String> differentAliasesToAdd = Lists.newArrayList("alias add 1",
"alias add 3");
+ List<String> differentAliasesToDelete = Lists.newArrayList("alias delete
1", "alias delete 3");
+
+ ModelVersionChange modelVersionChange1 =
+ new ModelVersionChange.UpdateAliases(aliasesToAdd, aliasesToDelete);
+ ModelVersionChange modelVersionChange2 =
+ new ModelVersionChange.UpdateAliases(aliasesToAdd, aliasesToDelete);
+ ModelVersionChange modelVersionChange3 =
+ new ModelVersionChange.UpdateAliases(differentAliasesToAdd,
aliasesToDelete);
+ ModelVersionChange modelVersionChange4 =
+ new ModelVersionChange.UpdateAliases(aliasesToAdd,
differentAliasesToDelete);
+
+ Assertions.assertEquals(modelVersionChange1, modelVersionChange2);
+ Assertions.assertNotEquals(modelVersionChange1, modelVersionChange3);
+ Assertions.assertNotEquals(modelVersionChange1, modelVersionChange4);
+ Assertions.assertNotEquals(modelVersionChange2, modelVersionChange3);
+ Assertions.assertNotEquals(modelVersionChange2, modelVersionChange4);
+ Assertions.assertNotEquals(modelVersionChange3, modelVersionChange4);
+
+ // NOTE(review): the hashCode contract does not require unequal objects to have different
+ // hash codes; the assertNotEquals checks below hold for these inputs but are not guaranteed
+ // in general.
+ Assertions.assertEquals(modelVersionChange1.hashCode(),
modelVersionChange2.hashCode());
+ Assertions.assertNotEquals(modelVersionChange1.hashCode(),
modelVersionChange3.hashCode());
+ Assertions.assertNotEquals(modelVersionChange1.hashCode(),
modelVersionChange4.hashCode());
+ Assertions.assertNotEquals(modelVersionChange2.hashCode(),
modelVersionChange3.hashCode());
+ Assertions.assertNotEquals(modelVersionChange2.hashCode(),
modelVersionChange4.hashCode());
+ }
}
diff --git
a/catalogs/catalog-model/src/main/java/org/apache/gravitino/catalog/model/ModelCatalogOperations.java
b/catalogs/catalog-model/src/main/java/org/apache/gravitino/catalog/model/ModelCatalogOperations.java
index 47b709d957..11c81ed9b0 100644
---
a/catalogs/catalog-model/src/main/java/org/apache/gravitino/catalog/model/ModelCatalogOperations.java
+++
b/catalogs/catalog-model/src/main/java/org/apache/gravitino/catalog/model/ModelCatalogOperations.java
@@ -21,10 +21,13 @@ package org.apache.gravitino.catalog.model;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
import java.io.IOException;
import java.time.Instant;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
+import java.util.Set;
import org.apache.gravitino.Catalog;
import org.apache.gravitino.Entity;
import org.apache.gravitino.EntityAlreadyExistsException;
@@ -435,6 +438,17 @@ public class ModelCatalogOperations extends
ManagedSchemaOperations
ModelVersionChange.UpdateUri updateUriChange =
(ModelVersionChange.UpdateUri) change;
entityUri = updateUriChange.newUri();
+ } else if (change instanceof ModelVersionChange.UpdateAliases) {
+ ModelVersionChange.UpdateAliases updateAliasesChange =
+ (ModelVersionChange.UpdateAliases) change;
+ Set<String> addTmpSet = updateAliasesChange.aliasesToAdd();
+ Set<String> deleteTmpSet = updateAliasesChange.aliasesToDelete();
+ Set<String> aliasToAdd = Sets.difference(addTmpSet,
deleteTmpSet).immutableCopy();
+ Set<String> aliasToDelete = Sets.difference(deleteTmpSet,
addTmpSet).immutableCopy();
+
+ doDeleteAlias(entityAliases, aliasToDelete);
+ doSetAlias(entityAliases, aliasToAdd);
+
} else {
throw new IllegalArgumentException(
"Unsupported model version change: " +
change.getClass().getSimpleName());
@@ -518,4 +532,18 @@ public class ModelCatalogOperations extends
ManagedSchemaOperations
Map<String, String> entityProperties, ModelVersionChange.RemoveProperty
change) {
entityProperties.remove(change.property());
}
+
+ /**
+ * Removes, in place, every alias in {@code deleteSet} from {@code entityAliases}. Aliases that
+ * are not present are ignored; {@code entityAliases} must therefore be a mutable list.
+ */
+ private void doDeleteAlias(List<String> entityAliases, Set<String>
deleteSet) {
+ entityAliases.removeAll(deleteSet);
+ }
+
+ /**
+ * Appends, in place, each alias from {@code addSet} that is not already in {@code
+ * entityAliases}, in {@code addSet}'s iteration order. Existing aliases keep their positions;
+ * {@code entityAliases} must therefore be a mutable list.
+ */
+ private void doSetAlias(List<String> entityAliases, Set<String> addSet) {
+ // HashSet mirror of the current aliases: O(1) membership checks, and Set.add returning
+ // false also prevents appending the same alias twice.
+ Set<String> aliasSet = new HashSet<>(entityAliases);
+ for (String alias : addSet) {
+ if (aliasSet.add(alias)) {
+ entityAliases.add(alias);
+ }
+ }
+ }
}
diff --git
a/catalogs/catalog-model/src/test/java/org/apache/gravtitino/catalog/model/TestModelCatalogOperations.java
b/catalogs/catalog-model/src/test/java/org/apache/gravtitino/catalog/model/TestModelCatalogOperations.java
index 882b1ec798..e297c7195e 100644
---
a/catalogs/catalog-model/src/test/java/org/apache/gravtitino/catalog/model/TestModelCatalogOperations.java
+++
b/catalogs/catalog-model/src/test/java/org/apache/gravtitino/catalog/model/TestModelCatalogOperations.java
@@ -35,6 +35,7 @@ import static
org.apache.gravitino.Configs.VERSION_RETENTION_COUNT;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import java.io.File;
import java.io.IOException;
@@ -1245,6 +1246,364 @@ public class TestModelCatalogOperations {
Assertions.assertEquals(versionProperties,
updatedModelVersion.properties());
}
+ /**
+ * Replaces all aliases of version 0 (addressed by version number) and checks that the altered
+ * and reloaded versions carry only the new aliases while uri, comment and properties are
+ * unchanged.
+ */
+ @Test
+ void testUpdateModelAlias() {
+ String schemaName = randomSchemaName();
+ createSchema(schemaName);
+
+ String modelName = "model1";
+ String modelComment = "model1 comment";
+
+ String versionComment = "version1 comment";
+ String versionUri = "model_version_path";
+ String[] versionAliases = new String[] {"alias1", "alias2"};
+ String[] newVersionAliases = new String[] {"new_alias1", "new_alias2"};
+
+ NameIdentifier modelIdent =
+ NameIdentifierUtil.ofModel(METALAKE_NAME, CATALOG_NAME, schemaName,
modelName);
+ StringIdentifier stringId = StringIdentifier.fromId(idGenerator.nextId());
+ Map<String, String> properties =
StringIdentifier.newPropertiesWithId(stringId, null);
+
+ ops.registerModel(modelIdent, modelComment, properties);
+ StringIdentifier versionId = StringIdentifier.fromId(idGenerator.nextId());
+ Map<String, String> versionProperties =
+ StringIdentifier.newPropertiesWithId(
+ versionId, ImmutableMap.of("key1", "value1", "key2", "value2"));
+
+ ops.linkModelVersion(modelIdent, versionUri, versionAliases,
versionComment, versionProperties);
+
+ // validate loaded model
+ Model loadedModel = ops.getModel(modelIdent);
+ Assertions.assertEquals(1, loadedModel.latestVersion());
+
+ // validate loaded version
+ ModelVersion loadedVersion = ops.getModelVersion(modelIdent, 0);
+ Assertions.assertEquals(0, loadedVersion.version());
+ Assertions.assertArrayEquals(versionAliases, loadedVersion.aliases());
+ Assertions.assertEquals(versionComment, loadedVersion.comment());
+ Assertions.assertEquals(versionUri, loadedVersion.uri());
+ Assertions.assertEquals(versionProperties, loadedVersion.properties());
+
+ // validate update version aliases: add both new aliases and delete both existing ones
+ ModelVersionChange change =
ModelVersionChange.updateAliases(newVersionAliases, versionAliases);
+ ModelVersion updatedModelVersion = ops.alterModelVersion(modelIdent, 0,
change);
+
+ Assertions.assertEquals(0, updatedModelVersion.version());
+ Assertions.assertEquals(versionUri, updatedModelVersion.uri());
+ Assertions.assertEquals(versionComment, updatedModelVersion.comment());
+ Assertions.assertArrayEquals(newVersionAliases,
updatedModelVersion.aliases());
+ Assertions.assertEquals(versionProperties,
updatedModelVersion.properties());
+
+ // Reload the version to check the change is visible on a fresh read
+ ModelVersion reloadVersion = ops.getModelVersion(modelIdent, 0);
+ Assertions.assertEquals(0, reloadVersion.version());
+ Assertions.assertArrayEquals(newVersionAliases, reloadVersion.aliases());
+ Assertions.assertEquals(versionUri, reloadVersion.uri());
+ Assertions.assertEquals(versionComment, reloadVersion.comment());
+ Assertions.assertEquals(versionProperties, reloadVersion.properties());
+ }
+
+ /**
+ * Same as testUpdateModelAlias, but addresses the version by alias: alters it via "alias1"
+ * (which the change itself deletes) and reloads it via the newly added "new_alias2".
+ */
+ @Test
+ void testUpdateModelAliasByAlias() {
+ String schemaName = randomSchemaName();
+ createSchema(schemaName);
+
+ String modelName = "model1";
+ String modelComment = "model1 comment";
+
+ String versionComment = "version1 comment";
+ String versionUri = "model_version_path";
+ String[] versionAliases = new String[] {"alias1", "alias2"};
+ String[] newVersionAliases = new String[] {"new_alias1", "new_alias2"};
+
+ NameIdentifier modelIdent =
+ NameIdentifierUtil.ofModel(METALAKE_NAME, CATALOG_NAME, schemaName,
modelName);
+ StringIdentifier stringId = StringIdentifier.fromId(idGenerator.nextId());
+ Map<String, String> properties =
StringIdentifier.newPropertiesWithId(stringId, null);
+
+ ops.registerModel(modelIdent, modelComment, properties);
+ StringIdentifier versionId = StringIdentifier.fromId(idGenerator.nextId());
+ Map<String, String> versionProperties =
+ StringIdentifier.newPropertiesWithId(
+ versionId, ImmutableMap.of("key1", "value1", "key2", "value2"));
+
+ ops.linkModelVersion(modelIdent, versionUri, versionAliases,
versionComment, versionProperties);
+
+ // validate loaded model
+ Model loadedModel = ops.getModel(modelIdent);
+ Assertions.assertEquals(1, loadedModel.latestVersion());
+
+ // validate loaded version
+ ModelVersion loadedVersion = ops.getModelVersion(modelIdent, "alias1");
+ Assertions.assertEquals(0, loadedVersion.version());
+ Assertions.assertArrayEquals(versionAliases, loadedVersion.aliases());
+ Assertions.assertEquals(versionComment, loadedVersion.comment());
+ Assertions.assertEquals(versionUri, loadedVersion.uri());
+ Assertions.assertEquals(versionProperties, loadedVersion.properties());
+
+ // validate update version aliases: add both new aliases and delete both existing ones
+ ModelVersionChange change =
ModelVersionChange.updateAliases(newVersionAliases, versionAliases);
+ ModelVersion updatedModelVersion = ops.alterModelVersion(modelIdent,
"alias1", change);
+
+ Assertions.assertEquals(0, updatedModelVersion.version());
+ Assertions.assertEquals(versionUri, updatedModelVersion.uri());
+ Assertions.assertEquals(versionComment, updatedModelVersion.comment());
+ Assertions.assertArrayEquals(newVersionAliases,
updatedModelVersion.aliases());
+ Assertions.assertEquals(versionProperties,
updatedModelVersion.properties());
+
+ // Reload the version through one of the new aliases
+ ModelVersion reloadVersion = ops.getModelVersion(modelIdent, "new_alias2");
+ Assertions.assertEquals(0, reloadVersion.version());
+ Assertions.assertArrayEquals(newVersionAliases, reloadVersion.aliases());
+ Assertions.assertEquals(versionUri, reloadVersion.uri());
+ Assertions.assertEquals(versionComment, reloadVersion.comment());
+ Assertions.assertEquals(versionProperties, reloadVersion.properties());
+ }
+
+ /**
+ * Adds two new aliases while deleting only "alias1" (version addressed by number); "alias2"
+ * must survive.
+ */
+ @Test
+ void testUpdateModelVersionWithPartialAliasChanges() {
+ String schemaName = randomSchemaName();
+ createSchema(schemaName);
+
+ String modelName = "model2";
+ String modelComment = "model2 comment";
+
+ String versionComment = "version1 comment";
+ String versionUri = "model_version_path";
+ String[] versionAliases = new String[] {"alias1", "alias2"};
+ String[] newVersionAliases = new String[] {"new_alias1", "new_alias2"};
+
+ NameIdentifier modelIdent =
+ NameIdentifierUtil.ofModel(METALAKE_NAME, CATALOG_NAME, schemaName,
modelName);
+ StringIdentifier stringId = StringIdentifier.fromId(idGenerator.nextId());
+ Map<String, String> properties =
StringIdentifier.newPropertiesWithId(stringId, null);
+
+ ops.registerModel(modelIdent, modelComment, properties);
+ StringIdentifier versionId = StringIdentifier.fromId(idGenerator.nextId());
+ Map<String, String> versionProperties =
+ StringIdentifier.newPropertiesWithId(
+ versionId, ImmutableMap.of("key1", "value1", "key2", "value2"));
+
+ ops.linkModelVersion(modelIdent, versionUri, versionAliases,
versionComment, versionProperties);
+
+ // validate loaded model
+ Model loadedModel = ops.getModel(modelIdent);
+ Assertions.assertEquals(1, loadedModel.latestVersion());
+
+ // validate loaded version
+ ModelVersion loadedVersion = ops.getModelVersion(modelIdent, 0);
+ Assertions.assertEquals(0, loadedVersion.version());
+ Assertions.assertArrayEquals(versionAliases, loadedVersion.aliases());
+ Assertions.assertEquals(versionComment, loadedVersion.comment());
+ Assertions.assertEquals(versionUri, loadedVersion.uri());
+ Assertions.assertEquals(versionProperties, loadedVersion.properties());
+
+ // validate update version aliases: add both new aliases, delete only "alias1"
+ ModelVersionChange change =
+ ModelVersionChange.updateAliases(newVersionAliases, new String[]
{"alias1"});
+ ModelVersion updatedModelVersion = ops.alterModelVersion(modelIdent, 0,
change);
+
+ // NOTE(review): assertArrayEquals is order-sensitive; it expects the retained alias first,
+ // followed by the added aliases in sorted order.
+ Assertions.assertEquals(0, updatedModelVersion.version());
+ Assertions.assertEquals(versionUri, updatedModelVersion.uri());
+ Assertions.assertEquals(versionComment, updatedModelVersion.comment());
+ Assertions.assertArrayEquals(
+ new String[] {"alias2", "new_alias1", "new_alias2"},
updatedModelVersion.aliases());
+ Assertions.assertEquals(versionProperties,
updatedModelVersion.properties());
+
+ // Reload the version to check the change is visible on a fresh read
+ ModelVersion reloadVersion = ops.getModelVersion(modelIdent, 0);
+ Assertions.assertEquals(0, reloadVersion.version());
+ Assertions.assertArrayEquals(
+ new String[] {"alias2", "new_alias1", "new_alias2"},
reloadVersion.aliases());
+ Assertions.assertEquals(versionUri, reloadVersion.uri());
+ Assertions.assertEquals(versionComment, reloadVersion.comment());
+ Assertions.assertEquals(versionProperties, reloadVersion.properties());
+ }
+
+ /**
+ * Same as testUpdateModelVersionWithPartialAliasChanges, but alters the version via "alias1"
+ * (which the change deletes) and reloads it via the new alias "new_alias2".
+ */
+ @Test
+ void testUpdateModelVersionByAliasWithPartialAliasChanges() {
+ String schemaName = randomSchemaName();
+ createSchema(schemaName);
+
+ String modelName = "model3";
+ String modelComment = "model3 comment";
+
+ String versionComment = "version1 comment";
+ String versionUri = "model_version_path";
+ String[] versionAliases = new String[] {"alias1", "alias2"};
+ String[] newVersionAliases = new String[] {"new_alias1", "new_alias2"};
+
+ NameIdentifier modelIdent =
+ NameIdentifierUtil.ofModel(METALAKE_NAME, CATALOG_NAME, schemaName,
modelName);
+ StringIdentifier stringId = StringIdentifier.fromId(idGenerator.nextId());
+ Map<String, String> properties =
StringIdentifier.newPropertiesWithId(stringId, null);
+
+ ops.registerModel(modelIdent, modelComment, properties);
+ StringIdentifier versionId = StringIdentifier.fromId(idGenerator.nextId());
+ Map<String, String> versionProperties =
+ StringIdentifier.newPropertiesWithId(
+ versionId, ImmutableMap.of("key1", "value1", "key2", "value2"));
+
+ ops.linkModelVersion(modelIdent, versionUri, versionAliases,
versionComment, versionProperties);
+
+ // validate loaded model
+ Model loadedModel = ops.getModel(modelIdent);
+ Assertions.assertEquals(1, loadedModel.latestVersion());
+
+ // validate loaded version
+ ModelVersion loadedVersion = ops.getModelVersion(modelIdent, "alias1");
+ Assertions.assertEquals(0, loadedVersion.version());
+ Assertions.assertArrayEquals(versionAliases, loadedVersion.aliases());
+ Assertions.assertEquals(versionComment, loadedVersion.comment());
+ Assertions.assertEquals(versionUri, loadedVersion.uri());
+ Assertions.assertEquals(versionProperties, loadedVersion.properties());
+
+ // validate update version aliases: add both new aliases, delete only "alias1"
+ ModelVersionChange change =
+ ModelVersionChange.updateAliases(newVersionAliases, new String[]
{"alias1"});
+ ModelVersion updatedModelVersion = ops.alterModelVersion(modelIdent,
"alias1", change);
+
+ Assertions.assertEquals(0, updatedModelVersion.version());
+ Assertions.assertEquals(versionUri, updatedModelVersion.uri());
+ Assertions.assertEquals(versionComment, updatedModelVersion.comment());
+ Assertions.assertArrayEquals(
+ new String[] {"alias2", "new_alias1", "new_alias2"},
updatedModelVersion.aliases());
+ Assertions.assertEquals(versionProperties,
updatedModelVersion.properties());
+
+ // Reload the version through one of the new aliases
+ ModelVersion reloadVersion = ops.getModelVersion(modelIdent, "new_alias2");
+ Assertions.assertEquals(0, reloadVersion.version());
+ Assertions.assertArrayEquals(
+ new String[] {"alias2", "new_alias1", "new_alias2"},
reloadVersion.aliases());
+ Assertions.assertEquals(versionUri, reloadVersion.uri());
+ Assertions.assertEquals(versionComment, reloadVersion.comment());
+ Assertions.assertEquals(versionProperties, reloadVersion.properties());
+ }
+
+ /**
+ * Covers an alias ("alias2") that appears in both the add and the delete set. The catalog
+ * drops aliases present in both sets from each (Sets.difference in alterModelVersion), so
+ * "alias2" is kept, "alias3" is removed and "alias1" is added.
+ *
+ * <p>NOTE(review): this test uses Arrays and Collectors, but this change only adds the
+ * ImmutableSet import — confirm java.util.Arrays and java.util.stream.Collectors are already
+ * imported in this file.
+ */
+ @Test
+ void testUpdateModelVersionAliasesOverlapAddAndRemove() {
+ String schemaName = randomSchemaName();
+ createSchema(schemaName);
+
+ String modelName = "model1";
+ String modelComment = "model1 comment";
+
+ String versionComment = "version1 comment";
+ String versionUri = "model_version_path";
+ String[] versionAliases = new String[] {"alias2", "alias3"};
+
+ NameIdentifier modelIdent =
+ NameIdentifierUtil.ofModel(METALAKE_NAME, CATALOG_NAME, schemaName,
modelName);
+ StringIdentifier stringId = StringIdentifier.fromId(idGenerator.nextId());
+ Map<String, String> properties =
StringIdentifier.newPropertiesWithId(stringId, null);
+
+ ops.registerModel(modelIdent, modelComment, properties);
+ StringIdentifier versionId = StringIdentifier.fromId(idGenerator.nextId());
+ Map<String, String> versionProperties =
+ StringIdentifier.newPropertiesWithId(
+ versionId, ImmutableMap.of("key1", "value1", "key2", "value2"));
+
+ ops.linkModelVersion(modelIdent, versionUri, versionAliases,
versionComment, versionProperties);
+
+ // validate loaded model
+ Model loadedModel = ops.getModel(modelIdent);
+ Assertions.assertEquals(1, loadedModel.latestVersion());
+
+ // validate loaded version
+ ModelVersion loadedVersion = ops.getModelVersion(modelIdent, 0);
+ Assertions.assertEquals(0, loadedVersion.version());
+ Assertions.assertArrayEquals(versionAliases, loadedVersion.aliases());
+ Assertions.assertEquals(versionComment, loadedVersion.comment());
+ Assertions.assertEquals(versionUri, loadedVersion.uri());
+ Assertions.assertEquals(versionProperties, loadedVersion.properties());
+
+ // validate update version aliases: "alias2" is in both sets and is therefore left untouched
+ ModelVersionChange change =
+ ModelVersionChange.updateAliases(
+ new String[] {"alias1", "alias2"}, new String[] {"alias2",
"alias3"});
+ ModelVersion updatedModelVersion = ops.alterModelVersion(modelIdent, 0,
change);
+
+ // Aliases are compared as sets, so this check does not depend on their order.
+ Assertions.assertEquals(0, updatedModelVersion.version());
+ Assertions.assertEquals(versionUri, updatedModelVersion.uri());
+ Assertions.assertEquals(versionComment, updatedModelVersion.comment());
+ Assertions.assertEquals(
+ ImmutableSet.of("alias1", "alias2"),
+
Arrays.stream(updatedModelVersion.aliases()).collect(Collectors.toSet()));
+ Assertions.assertEquals(versionProperties,
updatedModelVersion.properties());
+
+ // Reload the version to check the change is visible on a fresh read
+ ModelVersion reloadVersion = ops.getModelVersion(modelIdent, 0);
+ Assertions.assertEquals(0, reloadVersion.version());
+ Assertions.assertEquals(
+ ImmutableSet.of("alias1", "alias2"),
+ Arrays.stream(reloadVersion.aliases()).collect(Collectors.toSet()));
+ Assertions.assertEquals(versionUri, reloadVersion.uri());
+ Assertions.assertEquals(versionComment, reloadVersion.comment());
+ Assertions.assertEquals(versionProperties, reloadVersion.properties());
+ }
+
+ /**
+ * Same overlap scenario as testUpdateModelVersionAliasesOverlapAddAndRemove, but the version
+ * is altered via "alias3" (which the change removes) and reloaded via the retained "alias2".
+ */
+ @Test
+ void testUpdateModelVersionAliasesByAliasOverlapAddAndRemove() {
+ String schemaName = randomSchemaName();
+ createSchema(schemaName);
+
+ String modelName = "model1";
+ String modelComment = "model1 comment";
+
+ String versionComment = "version1 comment";
+ String versionUri = "model_version_path";
+ String[] versionAliases = new String[] {"alias2", "alias3"};
+
+ NameIdentifier modelIdent =
+ NameIdentifierUtil.ofModel(METALAKE_NAME, CATALOG_NAME, schemaName,
modelName);
+ StringIdentifier stringId = StringIdentifier.fromId(idGenerator.nextId());
+ Map<String, String> properties =
StringIdentifier.newPropertiesWithId(stringId, null);
+
+ ops.registerModel(modelIdent, modelComment, properties);
+ StringIdentifier versionId = StringIdentifier.fromId(idGenerator.nextId());
+ Map<String, String> versionProperties =
+ StringIdentifier.newPropertiesWithId(
+ versionId, ImmutableMap.of("key1", "value1", "key2", "value2"));
+
+ ops.linkModelVersion(modelIdent, versionUri, versionAliases,
versionComment, versionProperties);
+
+ // validate loaded model
+ Model loadedModel = ops.getModel(modelIdent);
+ Assertions.assertEquals(1, loadedModel.latestVersion());
+
+ // validate loaded version
+ ModelVersion loadedVersion = ops.getModelVersion(modelIdent, "alias2");
+ Assertions.assertEquals(0, loadedVersion.version());
+ Assertions.assertArrayEquals(versionAliases, loadedVersion.aliases());
+ Assertions.assertEquals(versionComment, loadedVersion.comment());
+ Assertions.assertEquals(versionUri, loadedVersion.uri());
+ Assertions.assertEquals(versionProperties, loadedVersion.properties());
+
+ // validate update version aliases: "alias2" is in both sets and is therefore left untouched
+ ModelVersionChange change =
+ ModelVersionChange.updateAliases(
+ new String[] {"alias1", "alias2"}, new String[] {"alias2",
"alias3"});
+ ModelVersion updatedModelVersion = ops.alterModelVersion(modelIdent,
"alias3", change);
+
+ Assertions.assertEquals(0, updatedModelVersion.version());
+ Assertions.assertEquals(versionUri, updatedModelVersion.uri());
+ Assertions.assertEquals(versionComment, updatedModelVersion.comment());
+ Assertions.assertEquals(
+ ImmutableSet.of("alias1", "alias2"),
+
Arrays.stream(updatedModelVersion.aliases()).collect(Collectors.toSet()));
+ Assertions.assertEquals(versionProperties,
updatedModelVersion.properties());
+
+ // Reload the version through the retained alias
+ ModelVersion reloadVersion = ops.getModelVersion(modelIdent, "alias2");
+ Assertions.assertEquals(0, reloadVersion.version());
+ Assertions.assertEquals(
+ ImmutableSet.of("alias1", "alias2"),
+ Arrays.stream(reloadVersion.aliases()).collect(Collectors.toSet()));
+ Assertions.assertEquals(versionUri, reloadVersion.uri());
+ Assertions.assertEquals(versionComment, reloadVersion.comment());
+ Assertions.assertEquals(versionProperties, reloadVersion.properties());
+ }
+
private String randomSchemaName() {
return "schema_" + UUID.randomUUID().toString().replace("-", "");
}
diff --git
a/core/src/main/java/org/apache/gravitino/storage/relational/mapper/ModelVersionAliasRelMapper.java
b/core/src/main/java/org/apache/gravitino/storage/relational/mapper/ModelVersionAliasRelMapper.java
index 6960649759..a84396c997 100644
---
a/core/src/main/java/org/apache/gravitino/storage/relational/mapper/ModelVersionAliasRelMapper.java
+++
b/core/src/main/java/org/apache/gravitino/storage/relational/mapper/ModelVersionAliasRelMapper.java
@@ -36,6 +36,12 @@ public interface ModelVersionAliasRelMapper {
void insertModelVersionAliasRels(
@Param("modelVersionAliasRel") List<ModelVersionAliasRelPO>
modelVersionAliasRelPOs);
+ /**
+ * Inserts one alias row per element of {@code modelVersionAliasRelPOs}.
+ *
+ * <p>Despite its name, the base SQL provider renders this statement as a multi-row INSERT, not an
+ * UPDATE. The caller in {@code ModelVersionMetaService} soft-deletes the version's old alias rows
+ * first and then passes the complete new alias set here.
+ *
+ * <p>NOTE(review): the base provider renders one VALUES tuple per element, so an empty list yields
+ * an INSERT without rows; callers should skip the call when there is nothing to insert.
+ *
+ * @param modelVersionAliasRelPOs the alias rows to insert
+ */
+ @UpdateProvider(
+ type = ModelVersionAliasSQLProviderFactory.class,
+ method = "updateModelVersionAliasRel")
+ void updateModelVersionAliasRel(
+ @Param("modelVersionAliasRel") List<ModelVersionAliasRelPO>
modelVersionAliasRelPOs);
+
@SelectProvider(
type = ModelVersionAliasSQLProviderFactory.class,
method = "selectModelVersionAliasRelsByModelId")
diff --git
a/core/src/main/java/org/apache/gravitino/storage/relational/mapper/ModelVersionAliasSQLProviderFactory.java
b/core/src/main/java/org/apache/gravitino/storage/relational/mapper/ModelVersionAliasSQLProviderFactory.java
index c83e9deaa2..292ad5a4c8 100644
---
a/core/src/main/java/org/apache/gravitino/storage/relational/mapper/ModelVersionAliasSQLProviderFactory.java
+++
b/core/src/main/java/org/apache/gravitino/storage/relational/mapper/ModelVersionAliasSQLProviderFactory.java
@@ -103,4 +103,9 @@ public class ModelVersionAliasSQLProviderFactory {
@Param("legacyTimeline") Long legacyTimeline, @Param("limit") int limit)
{
return
getProvider().deleteModelVersionAliasRelsByLegacyTimeline(legacyTimeline,
limit);
}
+
+ /**
+ * Returns the SQL for {@code ModelVersionAliasRelMapper#updateModelVersionAliasRel} by delegating
+ * to the provider returned by {@code getProvider()}.
+ *
+ * @param modelVersionAliasRelPOs the alias rows to write
+ * @return the MyBatis script built by the selected provider (an INSERT in the base provider)
+ */
+ public static String updateModelVersionAliasRel(
+ @Param("modelVersionAliasRel") List<ModelVersionAliasRelPO>
modelVersionAliasRelPOs) {
+ return getProvider().updateModelVersionAliasRel(modelVersionAliasRelPOs);
+ }
}
diff --git
a/core/src/main/java/org/apache/gravitino/storage/relational/mapper/provider/base/ModelVersionAliasRelBaseSQLProvider.java
b/core/src/main/java/org/apache/gravitino/storage/relational/mapper/provider/base/ModelVersionAliasRelBaseSQLProvider.java
index abaaa5a8ae..dfbf81e9f8 100644
---
a/core/src/main/java/org/apache/gravitino/storage/relational/mapper/provider/base/ModelVersionAliasRelBaseSQLProvider.java
+++
b/core/src/main/java/org/apache/gravitino/storage/relational/mapper/provider/base/ModelVersionAliasRelBaseSQLProvider.java
@@ -149,4 +149,20 @@ public class ModelVersionAliasRelBaseSQLProvider {
+ ModelVersionAliasRelMapper.TABLE_NAME
+ " WHERE deleted_at > 0 AND deleted_at < #{legacyTimeline} LIMIT
#{limit}";
}
+
+ /**
+ * Renders a multi-row INSERT into {@link ModelVersionAliasRelMapper#TABLE_NAME} with one VALUES
+ * tuple per element of {@code modelVersionAliasRelPOs}.
+ *
+ * <p>This is called while replacing a model version's aliases; it does not itself issue an UPDATE.
+ * The previous alias rows are expected to have been soft-deleted beforehand.
+ *
+ * <p>NOTE(review): for an empty list the {@code <foreach>} renders nothing and the statement ends
+ * at {@code VALUES}, which is not valid SQL, so callers must not invoke it without rows.
+ *
+ * @param modelVersionAliasRelPOs the alias rows to insert; should be non-empty
+ * @return the MyBatis {@code <script>} text for the insert
+ */
+ public String updateModelVersionAliasRel(
+ @Param("modelVersionAliasRel") List<ModelVersionAliasRelPO>
modelVersionAliasRelPOs) {
+ return "<script>"
+ + "INSERT INTO "
+ + ModelVersionAliasRelMapper.TABLE_NAME
+ + " (model_id, model_version, model_version_alias, deleted_at)"
+ + " VALUES "
+ + " <foreach collection='modelVersionAliasRel' item='item'
separator=','>"
+ + " (#{item.modelId},"
+ + " #{item.modelVersion},"
+ + " #{item.modelVersionAlias},"
+ + " #{item.deletedAt})"
+ + " </foreach>"
+ + "</script>";
+ }
}
diff --git
a/core/src/main/java/org/apache/gravitino/storage/relational/service/ModelVersionMetaService.java
b/core/src/main/java/org/apache/gravitino/storage/relational/service/ModelVersionMetaService.java
index ab440e8c4d..4cd8a3a7f9 100644
---
a/core/src/main/java/org/apache/gravitino/storage/relational/service/ModelVersionMetaService.java
+++
b/core/src/main/java/org/apache/gravitino/storage/relational/service/ModelVersionMetaService.java
@@ -288,7 +288,7 @@ public class ModelVersionMetaService {
ident.toString());
}
- List<ModelVersionAliasRelPO> aliasRelPOs =
+ List<ModelVersionAliasRelPO> oldAliasRelPOs =
SessionUtils.getWithoutCommit(
ModelVersionAliasRelMapper.class,
mapper -> {
@@ -302,7 +302,7 @@ public class ModelVersionMetaService {
});
ModelVersionEntity oldModelVersionEntity =
- POConverters.fromModelVersionPO(modelIdent, oldModelVersionPO,
aliasRelPOs);
+ POConverters.fromModelVersionPO(modelIdent, oldModelVersionPO,
oldAliasRelPOs);
ModelVersionEntity newModelVersionEntity =
(ModelVersionEntity) updater.apply((E) oldModelVersionEntity);
@@ -312,25 +312,64 @@ public class ModelVersionMetaService {
newModelVersionEntity.version(),
oldModelVersionEntity.version());
- Integer updateResult;
+    boolean isAliasChanged =
+        isModelVersionAliasUpdated(oldModelVersionEntity, newModelVersionEntity);
+    List<ModelVersionAliasRelPO> newAliasRelPOs =
+        POConverters.updateModelVersionAliasRelPO(oldAliasRelPOs, newModelVersionEntity);
+
+    // Affected-row count of the version-meta update, checked after the commit below.
+    final AtomicInteger updateResult = new AtomicInteger(0);
try {
- updateResult =
- SessionUtils.doWithCommitAndFetchResult(
- ModelVersionMetaMapper.class,
- mapper ->
- mapper.updateModelVersionMeta(
- POConverters.updateModelVersionPO(oldModelVersionPO,
newModelVersionEntity),
- oldModelVersionPO));
+      // The version row update and the alias replacement are committed together by
+      // doMultipleWithCommit.
+      SessionUtils.doMultipleWithCommit(
+          () ->
+              updateResult.set(
+                  SessionUtils.doWithoutCommitAndFetchResult(
+                      ModelVersionMetaMapper.class,
+                      mapper ->
+                          mapper.updateModelVersionMeta(
+                              POConverters.updateModelVersionPO(
+                                  oldModelVersionPO, newModelVersionEntity),
+                              oldModelVersionPO))),
+          () -> {
+            if (isAliasChanged) {
+              // Replace the alias set: soft-delete every old alias, then insert the new ones.
+              SessionUtils.doWithoutCommit(
+                  ModelVersionAliasRelMapper.class,
+                  mapper -> {
+                    oldModelVersionEntity
+                        .aliases()
+                        .forEach(
+                            alias ->
+                                mapper.softDeleteModelVersionAliasRelsByModelIdAndAlias(
+                                    modelEntity.id(), alias));
+                  });
+
+              // Removing every alias leaves no rows to insert, and the multi-row INSERT behind
+              // updateModelVersionAliasRel is invalid SQL with an empty VALUES list.
+              if (!newAliasRelPOs.isEmpty()) {
+                SessionUtils.doWithoutCommit(
+                    ModelVersionAliasRelMapper.class,
+                    mapper -> mapper.updateModelVersionAliasRel(newAliasRelPOs));
+              }
+            }
+          });
+
} catch (RuntimeException re) {
ExceptionUtils.checkSQLException(
- re, Entity.EntityType.CATALOG,
+          re, Entity.EntityType.MODEL_VERSION,
newModelVersionEntity.nameIdentifier().toString());
throw re;
}
- if (updateResult > 0) {
+    if (updateResult.get() > 0) {
return newModelVersionEntity;
} else {
throw new IOException("Failed to update the entity: " + ident);
}
}
+
+  /**
+   * Returns whether the aliases of a model version differ between the persisted and the updated
+   * entity.
+   *
+   * <p>Aliases are compared as sets. They are stored one row each and addressed by (model id,
+   * alias), so a pure reordering carries no information and should not trigger a soft-delete and
+   * re-insert of every alias row.
+   *
+   * @param oldModelVersionEntity the entity as currently persisted
+   * @param newModelVersionEntity the entity produced by the updater
+   * @return {@code true} if at least one alias was added or removed
+   */
+  private boolean isModelVersionAliasUpdated(
+      ModelVersionEntity oldModelVersionEntity, ModelVersionEntity newModelVersionEntity) {
+    List<String> oldAliases = oldModelVersionEntity.aliases();
+    List<String> newAliases = newModelVersionEntity.aliases();
+
+    // Assumes aliases are unique within one version, so equal sizes plus mutual containment
+    // means both lists hold the same alias set regardless of order.
+    return oldAliases.size() != newAliases.size()
+        || !oldAliases.containsAll(newAliases)
+        || !newAliases.containsAll(oldAliases);
+  }
}
diff --git
a/core/src/main/java/org/apache/gravitino/storage/relational/utils/POConverters.java
b/core/src/main/java/org/apache/gravitino/storage/relational/utils/POConverters.java
index 36728458de..e8ae898a7c 100644
---
a/core/src/main/java/org/apache/gravitino/storage/relational/utils/POConverters.java
+++
b/core/src/main/java/org/apache/gravitino/storage/relational/utils/POConverters.java
@@ -1433,6 +1433,28 @@ public class POConverters {
}
}
+  /**
+   * Builds the alias rows a model version should own after an alias update. It creates one row per
+   * alias of {@code newModelVersion}, each with the default deleted-at value.
+   *
+   * <p>If the version already has alias rows, the model id and version number are copied from the
+   * first of them. Otherwise they are read from {@code newModelVersion} itself.
+   *
+   * <p>NOTE(review): the fallback path takes the model id from {@code newModelVersion.id()}.
+   * Confirm that {@code ModelVersionEntity#id()} returns the owning model's id (and does not throw)
+   * for a version that had no aliases before the update.
+   *
+   * @param oldModelVersionAliasRelPOs the alias rows currently stored for the version
+   * @param newModelVersion the updated {@link ModelVersionEntity}
+   * @return the new alias rows, in the order of {@code newModelVersion.aliases()}
+   */
+  public static List<ModelVersionAliasRelPO> updateModelVersionAliasRelPO(
+      List<ModelVersionAliasRelPO> oldModelVersionAliasRelPOs,
+      ModelVersionEntity newModelVersion) {
+    if (oldModelVersionAliasRelPOs.isEmpty()) {
+      return newModelVersion.aliases().stream()
+          .map(alias -> createAliasRelPO(newModelVersion, alias))
+          .collect(Collectors.toList());
+    }
+
+    // Existing rows already carry the persisted model id and version number; reuse them.
+    ModelVersionAliasRelPO template = oldModelVersionAliasRelPOs.get(0);
+    return newModelVersion.aliases().stream()
+        .map(alias -> createAliasRelPO(template, alias))
+        .collect(Collectors.toList());
+  }
+
public static ModelVersionPO initializeModelVersionPO(
ModelVersionEntity modelVersionEntity, ModelVersionPO.Builder builder) {
try {
@@ -1469,4 +1491,23 @@ public class POConverters {
.build())
.collect(Collectors.toList());
}
+
+  /**
+   * Creates an alias row for {@code alias}, copying the model id and version number from an
+   * existing alias row of the same model version.
+   */
+  private static ModelVersionAliasRelPO createAliasRelPO(
+      ModelVersionAliasRelPO oldModelVersionAliasRelPO, String alias) {
+    Integer version = oldModelVersionAliasRelPO.getModelVersion();
+    Long modelId = oldModelVersionAliasRelPO.getModelId();
+    return newAliasRelPO(version, alias, modelId);
+  }
+
+  /**
+   * Creates an alias row for {@code alias}, taking the version number from {@code entity.version()}
+   * and the model id from {@code entity.id()}.
+   */
+  private static ModelVersionAliasRelPO createAliasRelPO(ModelVersionEntity entity, String alias) {
+    Integer version = entity.version();
+    Long modelId = entity.id();
+    return newAliasRelPO(version, alias, modelId);
+  }
+
+  /** Assembles an alias row from its parts and stamps it with {@code DEFAULT_DELETED_AT}. */
+  private static ModelVersionAliasRelPO newAliasRelPO(Integer version, String alias, Long modelId) {
+    return ModelVersionAliasRelPO.builder()
+        .withModelVersion(version)
+        .withModelVersionAlias(alias)
+        .withModelId(modelId)
+        .withDeletedAt(DEFAULT_DELETED_AT)
+        .build();
+  }
}
diff --git
a/core/src/test/java/org/apache/gravitino/catalog/TestModelOperationDispatcher.java
b/core/src/test/java/org/apache/gravitino/catalog/TestModelOperationDispatcher.java
index eb086a4de3..63996c9ba7 100644
---
a/core/src/test/java/org/apache/gravitino/catalog/TestModelOperationDispatcher.java
+++
b/core/src/test/java/org/apache/gravitino/catalog/TestModelOperationDispatcher.java
@@ -24,6 +24,7 @@ import static
org.apache.gravitino.Configs.TREE_LOCK_MIN_NODE_IN_MEMORY;
import static org.apache.gravitino.StringIdentifier.ID_KEY;
import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import java.io.IOException;
import java.util.Arrays;
@@ -658,6 +659,295 @@ public class TestModelOperationDispatcher extends
TestOperationDispatcher {
Assertions.assertEquals(modelVersion.properties(),
alteredModelVersion.properties());
}
+  @Test
+  void testUpdateModelVersionAliases() {
+    NameIdentifier modelIdent = linkVersionWithAliases(new String[] {"alias1", "alias2"});
+    ModelVersionChange change =
+        ModelVersionChange.updateAliases(
+            new String[] {"new_alias1", "new_alias2"}, new String[] {"alias1", "alias2"});
+    String[] expectedAliases = {"new_alias1", "new_alias2"};
+
+    ModelVersion modelVersion = modelOperationDispatcher.getModelVersion(modelIdent, 0);
+    ModelVersion alteredModelVersion =
+        modelOperationDispatcher.alterModelVersion(modelIdent, 0, change);
+    assertAliasesInOrder(modelVersion, alteredModelVersion, expectedAliases);
+
+    // Reload model version
+    ModelVersion reloadedModelVersion = modelOperationDispatcher.getModelVersion(modelIdent, 0);
+    assertAliasesInOrder(modelVersion, reloadedModelVersion, expectedAliases);
+  }
+
+  @Test
+  void testUpdateModelVersionAliasesByAlias() {
+    NameIdentifier modelIdent = linkVersionWithAliases(new String[] {"alias1", "alias2"});
+    ModelVersionChange change =
+        ModelVersionChange.updateAliases(
+            new String[] {"new_alias1", "new_alias2"}, new String[] {"alias1", "alias2"});
+    String[] expectedAliases = {"new_alias1", "new_alias2"};
+
+    ModelVersion modelVersion = modelOperationDispatcher.getModelVersion(modelIdent, "alias1");
+    ModelVersion alteredModelVersion =
+        modelOperationDispatcher.alterModelVersion(modelIdent, "alias1", change);
+    assertAliasesInOrder(modelVersion, alteredModelVersion, expectedAliases);
+
+    // Reload through one of the newly added aliases
+    ModelVersion reloadedModelVersion =
+        modelOperationDispatcher.getModelVersion(modelIdent, "new_alias1");
+    assertAliasesInOrder(modelVersion, reloadedModelVersion, expectedAliases);
+  }
+
+  @Test
+  void testUpdatePartialModelVersionAliases() {
+    NameIdentifier modelIdent = linkVersionWithAliases(new String[] {"alias1", "alias2"});
+    ModelVersionChange change =
+        ModelVersionChange.updateAliases(
+            new String[] {"new_alias1", "new_alias2"}, new String[] {"alias1"});
+    String[] expectedAliases = {"alias2", "new_alias1", "new_alias2"};
+
+    ModelVersion modelVersion = modelOperationDispatcher.getModelVersion(modelIdent, 0);
+    ModelVersion alteredModelVersion =
+        modelOperationDispatcher.alterModelVersion(modelIdent, 0, change);
+    assertAliasesInOrder(modelVersion, alteredModelVersion, expectedAliases);
+
+    // Reload model version
+    ModelVersion reloadedModelVersion = modelOperationDispatcher.getModelVersion(modelIdent, 0);
+    assertAliasesInOrder(modelVersion, reloadedModelVersion, expectedAliases);
+  }
+
+  @Test
+  void testUpdatePartialModelVersionAliasesByAlias() {
+    NameIdentifier modelIdent = linkVersionWithAliases(new String[] {"alias1", "alias2"});
+    ModelVersionChange change =
+        ModelVersionChange.updateAliases(
+            new String[] {"new_alias1", "new_alias2"}, new String[] {"alias1"});
+    String[] expectedAliases = {"alias2", "new_alias1", "new_alias2"};
+
+    ModelVersion modelVersion = modelOperationDispatcher.getModelVersion(modelIdent, "alias1");
+    ModelVersion alteredModelVersion =
+        modelOperationDispatcher.alterModelVersion(modelIdent, "alias1", change);
+    assertAliasesInOrder(modelVersion, alteredModelVersion, expectedAliases);
+
+    // Reload through one of the newly added aliases
+    ModelVersion reloadedModelVersion =
+        modelOperationDispatcher.getModelVersion(modelIdent, "new_alias1");
+    assertAliasesInOrder(modelVersion, reloadedModelVersion, expectedAliases);
+  }
+
+  @Test
+  void testUpdateModelVersionAliasesOverlapAddAndRemove() {
+    NameIdentifier modelIdent = linkVersionWithAliases(new String[] {"alias2", "alias3"});
+    // "alias2" is both added and removed, so it must survive the update.
+    ModelVersionChange change =
+        ModelVersionChange.updateAliases(
+            new String[] {"alias1", "alias2"}, new String[] {"alias2", "alias3"});
+
+    ModelVersion modelVersion = modelOperationDispatcher.getModelVersion(modelIdent, 0);
+    ModelVersion alteredModelVersion =
+        modelOperationDispatcher.alterModelVersion(modelIdent, 0, change);
+    assertAliasSet(modelVersion, alteredModelVersion, "alias1", "alias2");
+
+    // Reload model version
+    ModelVersion reloadedModelVersion = modelOperationDispatcher.getModelVersion(modelIdent, 0);
+    assertAliasSet(modelVersion, reloadedModelVersion, "alias1", "alias2");
+  }
+
+  @Test
+  void testUpdateModelVersionAliasesByAliasOverlapAddAndRemove() {
+    NameIdentifier modelIdent = linkVersionWithAliases(new String[] {"alias2", "alias3"});
+    // "alias2" is both added and removed, so it must survive the update.
+    ModelVersionChange change =
+        ModelVersionChange.updateAliases(
+            new String[] {"alias1", "alias2"}, new String[] {"alias2", "alias3"});
+
+    ModelVersion modelVersion = modelOperationDispatcher.getModelVersion(modelIdent, "alias2");
+    ModelVersion alteredModelVersion =
+        modelOperationDispatcher.alterModelVersion(modelIdent, "alias2", change);
+    assertAliasSet(modelVersion, alteredModelVersion, "alias1", "alias2");
+
+    // Reload through the newly added alias
+    ModelVersion reloadedModelVersion =
+        modelOperationDispatcher.getModelVersion(modelIdent, "alias1");
+    assertAliasSet(modelVersion, reloadedModelVersion, "alias1", "alias2");
+  }
+
+  /**
+   * Creates a fresh schema and model, then links the model's first version (version 0 in these
+   * tests) with the given aliases and a fixed URI, comment and properties.
+   *
+   * @param versionAliases aliases to attach to the linked version
+   * @return identifier of the newly registered model
+   */
+  private NameIdentifier linkVersionWithAliases(String[] versionAliases) {
+    Map<String, String> props = ImmutableMap.of("k1", "v1", "k2", "v2");
+    String schemaName = randomSchemaName();
+
+    NameIdentifier schemaIdent = NameIdentifier.of(metalake, catalog, schemaName);
+    schemaOperationDispatcher.createSchema(schemaIdent, "schema which tests update", props);
+
+    NameIdentifier modelIdent =
+        NameIdentifierUtil.ofModel(metalake, catalog, schemaName, randomModelName());
+    modelOperationDispatcher.registerModel(modelIdent, "model which tests update", props);
+
+    modelOperationDispatcher.linkModelVersion(
+        modelIdent,
+        "s3://test-bucket/test-path/model.json",
+        versionAliases,
+        "version which tests update",
+        props);
+    return modelIdent;
+  }
+
+  /**
+   * Asserts that {@code actual} matches {@code original} in every field except aliases, and that its
+   * aliases equal {@code expectedAliases} in exactly that order.
+   */
+  private void assertAliasesInOrder(
+      ModelVersion original, ModelVersion actual, String[] expectedAliases) {
+    assertSameExceptAliases(original, actual);
+    Assertions.assertArrayEquals(expectedAliases, actual.aliases());
+  }
+
+  /**
+   * Asserts that {@code actual} matches {@code original} in every field except aliases, and that its
+   * aliases form the same set as {@code expectedAliases}, ignoring order.
+   */
+  private void assertAliasSet(
+      ModelVersion original, ModelVersion actual, String... expectedAliases) {
+    assertSameExceptAliases(original, actual);
+    Assertions.assertEquals(
+        ImmutableSet.copyOf(expectedAliases),
+        Arrays.stream(actual.aliases()).collect(Collectors.toSet()));
+  }
+
+  /** Asserts that URI, version number, comment and properties are unchanged. */
+  private void assertSameExceptAliases(ModelVersion original, ModelVersion actual) {
+    Assertions.assertEquals(original.uri(), actual.uri());
+    Assertions.assertEquals(original.version(), actual.version());
+    Assertions.assertEquals(original.comment(), actual.comment());
+    Assertions.assertEquals(original.properties(), actual.properties());
+  }
+
private String randomSchemaName() {
return "schema_" + UUID.randomUUID().toString().replace("-", "");
}
diff --git
a/core/src/test/java/org/apache/gravitino/connector/TestCatalogOperations.java
b/core/src/test/java/org/apache/gravitino/connector/TestCatalogOperations.java
index 8e91f21575..8370f6a1ae 100644
---
a/core/src/test/java/org/apache/gravitino/connector/TestCatalogOperations.java
+++
b/core/src/test/java/org/apache/gravitino/connector/TestCatalogOperations.java
@@ -23,15 +23,19 @@ import static
org.apache.gravitino.file.Fileset.PROPERTY_DEFAULT_LOCATION_NAME;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
import java.io.File;
import java.io.IOException;
import java.time.Instant;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
@@ -1064,6 +1068,18 @@ public class TestCatalogOperations
ModelVersionChange.SetProperty setProperty =
(ModelVersionChange.SetProperty) change;
newProps.put(setProperty.property(), setProperty.value());
+ } else if (change instanceof ModelVersionChange.UpdateAliases) {
+ ModelVersionChange.UpdateAliases updateAliasesChange =
+ (ModelVersionChange.UpdateAliases) change;
+
+ Set<String> addTmpSet = updateAliasesChange.aliasesToAdd();
+ Set<String> deleteTmpSet = updateAliasesChange.aliasesToDelete();
+ Set<String> aliasToAdd = Sets.difference(addTmpSet,
deleteTmpSet).immutableCopy();
+ Set<String> aliasToDelete = Sets.difference(deleteTmpSet,
addTmpSet).immutableCopy();
+
+ newAliases = doDeleteAlias(newAliases, aliasToDelete);
+ newAliases = doSetAlias(newAliases, aliasToAdd);
+
} else if (change instanceof ModelVersionChange.UpdateUri) {
ModelVersionChange.UpdateUri updateUriChange =
(ModelVersionChange.UpdateUri) change;
newUri = updateUriChange.newUri();
@@ -1084,6 +1100,10 @@ public class TestCatalogOperations
.build();
modelVersions.put(versionPair, updatedModelVersion);
+
+ Arrays.stream(newAliases)
+ .map(alias -> Pair.of(ident, alias))
+ .forEach(pair -> modelAliasToVersion.put(pair, newVersion));
return updatedModelVersion;
}
@@ -1284,4 +1304,24 @@ public class TestCatalogOperations
.sorted(Comparator.comparingInt(TestColumn::position))
.toArray(TestColumn[]::new);
}
+
+  /**
+   * Returns a copy of {@code entityAliases} without the entries contained in {@code aliasToDelete},
+   * keeping the remaining aliases in their original order.
+   */
+  private String[] doDeleteAlias(String[] entityAliases, Set<String> aliasToDelete) {
+    // Fail on a null set even for an empty array, as Collection.removeAll does.
+    Preconditions.checkNotNull(aliasToDelete);
+    return Arrays.stream(entityAliases)
+        .filter(alias -> !aliasToDelete.contains(alias))
+        .toArray(String[]::new);
+  }
+
+  /**
+   * Appends each alias from {@code aliasToAdd} that is not already present, preserving the order of
+   * {@code entityAliases} followed by the iteration order of {@code aliasToAdd}.
+   */
+  private String[] doSetAlias(String[] entityAliases, Set<String> aliasToAdd) {
+    List<String> merged = new ArrayList<>(Arrays.asList(entityAliases));
+    Set<String> present = new HashSet<>(merged);
+
+    for (String candidate : aliasToAdd) {
+      if (present.contains(candidate)) {
+        continue;
+      }
+      present.add(candidate);
+      merged.add(candidate);
+    }
+
+    return merged.toArray(new String[0]);
+  }
}
diff --git
a/core/src/test/java/org/apache/gravitino/storage/relational/service/TestModelVersionMetaService.java
b/core/src/test/java/org/apache/gravitino/storage/relational/service/TestModelVersionMetaService.java
index 38bfc1e94e..06102f6447 100644
---
a/core/src/test/java/org/apache/gravitino/storage/relational/service/TestModelVersionMetaService.java
+++
b/core/src/test/java/org/apache/gravitino/storage/relational/service/TestModelVersionMetaService.java
@@ -715,6 +715,89 @@ public class TestModelVersionMetaService extends
TestJDBCBackend {
updatePropertiesUpdater));
}
+  @Test
+  void testUpdateModelVersionAliases() throws IOException {
+    createParentEntities(METALAKE_NAME, CATALOG_NAME, SCHEMA_NAME, auditInfo);
+
+    Map<String, String> properties = ImmutableMap.of("k1", "v1", "k2", "v2");
+    String modelName = randomModelName();
+    String modelComment = "model1 comment";
+    String modelVersionUri = "S3://test/path/to/model/version";
+    List<String> modelVersionAliases = ImmutableList.of("alias1", "alias2");
+    List<String> updatedVersionAliases = ImmutableList.of("alias2", "alias3");
+    String modelVersionComment = "test comment";
+    int version = 0;
+
+    ModelEntity modelEntity =
+        createModelEntity(
+            RandomIdGenerator.INSTANCE.nextId(),
+            MODEL_NS,
+            modelName,
+            modelComment,
+            0,
+            properties,
+            auditInfo);
+
+    ModelVersionEntity modelVersionEntity =
+        createModelVersionEntity(
+            modelEntity.nameIdentifier(),
+            version,
+            modelVersionUri,
+            modelVersionAliases,
+            modelVersionComment,
+            properties,
+            auditInfo);
+
+    ModelVersionEntity updatedModelVersionEntity =
+        createModelVersionEntity(
+            modelVersionEntity.modelIdentifier(),
+            modelVersionEntity.version(),
+            modelVersionEntity.uri(),
+            updatedVersionAliases,
+            modelVersionEntity.comment(),
+            modelVersionEntity.properties(),
+            modelVersionEntity.auditInfo());
+
+    Assertions.assertDoesNotThrow(
+        () -> ModelMetaService.getInstance().insertModel(modelEntity, false));
+
+    Assertions.assertDoesNotThrow(
+        () -> ModelVersionMetaService.getInstance().insertModelVersion(modelVersionEntity));
+
+    // The updater ignores its input and swaps in the entity carrying the new alias list.
+    Function<ModelVersionEntity, ModelVersionEntity> updateAliasesUpdater =
+        oldModelVersionEntity -> updatedModelVersionEntity;
+
+    ModelVersionEntity alteredModelVersionEntity =
+        ModelVersionMetaService.getInstance()
+            .updateModelVersion(modelVersionEntity.nameIdentifier(), updateAliasesUpdater);
+
+    Assertions.assertEquals(updatedModelVersionEntity, alteredModelVersionEntity);
+
+    // updateModelVersion returns the updater's output, so verify persistence by reloading: the
+    // version must be reachable through the newly added alias and carry exactly the new aliases.
+    ModelVersionEntity reloadedModelVersionEntity =
+        ModelVersionMetaService.getInstance()
+            .getModelVersionByIdentifier(
+                NameIdentifierUtil.ofModelVersion(
+                    METALAKE_NAME, CATALOG_NAME, SCHEMA_NAME, modelName, "alias3"));
+    Assertions.assertEquals(version, reloadedModelVersionEntity.version());
+    Assertions.assertEquals(
+        updatedVersionAliases.size(), reloadedModelVersionEntity.aliases().size());
+    Assertions.assertTrue(reloadedModelVersionEntity.aliases().containsAll(updatedVersionAliases));
+
+    // The removed alias must no longer resolve to the version.
+    Assertions.assertThrows(
+        NoSuchEntityException.class,
+        () ->
+            ModelVersionMetaService.getInstance()
+                .getModelVersionByIdentifier(
+                    NameIdentifierUtil.ofModelVersion(
+                        METALAKE_NAME, CATALOG_NAME, SCHEMA_NAME, modelName, "alias1")));
+
+    // Test update a non-exist model
+    Assertions.assertThrows(
+        NoSuchEntityException.class,
+        () ->
+            ModelVersionMetaService.getInstance()
+                .updateModelVersion(
+                    NameIdentifierUtil.ofModelVersion(
+                        METALAKE_NAME,
+                        CATALOG_NAME,
+                        SCHEMA_NAME,
+                        "non_exist_model",
+                        "non_exist_version"),
+                    updateAliasesUpdater));
+
+    // Test update a non-exist model version
+    Assertions.assertThrows(
+        NoSuchEntityException.class,
+        () ->
+            ModelVersionMetaService.getInstance()
+                .updateModelVersion(
+                    NameIdentifierUtil.ofModelVersion(
+                        METALAKE_NAME, CATALOG_NAME, SCHEMA_NAME, modelName, "non_exist_version"),
+                    updateAliasesUpdater));
+  }
+
private NameIdentifier getModelVersionIdent(NameIdentifier modelIdent, int
version) {
List<String> parts = Lists.newArrayList(modelIdent.namespace().levels());
parts.add(modelIdent.name());