This is an automated email from the ASF dual-hosted git repository.
jshao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/gravitino.git
The following commit(s) were added to refs/heads/main by this push:
new 19b9e41d6b [#6813] feat(core): Support update URI for model version.
(#7017)
19b9e41d6b is described below
commit 19b9e41d6b11d75eb84c52e2ece2ce77e64891db
Author: Lord of Abyss <[email protected]>
AuthorDate: Mon Apr 21 19:22:15 2025 +0800
[#6813] feat(core): Support update URI for model version. (#7017)
### What changes were proposed in this pull request?
Support updating the URI of a model version.
> This PR adds a temporary option for model version updates; it will be
refactored in the future when the model_version entity is added.
### Why are the changes needed?
Users need a way to change the URI of an existing model version.
Fix: #6813
### Does this PR introduce _any_ user-facing change?
Users can now update the URI of a model version.
### How was this patch tested?
Local testing and unit tests (UT).
Original URI:

`bin/gcli.sh model update -m demo_metalake --name
model_catalog.schema.model2 --version 0 --newuri 's3:///bucket/key' -I`

---
.../integration/test/ModelCatalogOperationsIT.java | 75 ++++++++++++++++++
.../org/apache/gravitino/cli/GravitinoOptions.java | 4 +
.../apache/gravitino/cli/ModelCommandHandler.java | 12 +++
.../apache/gravitino/cli/TestableCommandLine.java | 14 ++++
.../cli/commands/UpdateModelVersionComment.java | 2 +-
...sionComment.java => UpdateModelVersionUri.java} | 36 ++++-----
.../apache/gravitino/cli/TestModelCommands.java | 88 ++++++++++++++++++++++
.../org/apache/gravitino/client/DTOConverters.java | 4 +
.../gravitino/api/model_version_change.py | 52 +++++++++++++
clients/client-python/gravitino/api/types/types.py | 1 +
.../gravitino/client/generic_model_catalog.py | 5 ++
.../dto/requests/model_version_update_request.py | 31 ++++++++
.../tests/integration/test_model_catalog.py | 39 ++++++++++
.../dto/requests/ModelVersionUpdateRequest.java | 33 +++++++-
14 files changed, 376 insertions(+), 20 deletions(-)
diff --git
a/catalogs/catalog-model/src/test/java/org/apache/gravtitino/catalog/model/integration/test/ModelCatalogOperationsIT.java
b/catalogs/catalog-model/src/test/java/org/apache/gravtitino/catalog/model/integration/test/ModelCatalogOperationsIT.java
index d75c21f371..8590878d6d 100644
---
a/catalogs/catalog-model/src/test/java/org/apache/gravtitino/catalog/model/integration/test/ModelCatalogOperationsIT.java
+++
b/catalogs/catalog-model/src/test/java/org/apache/gravtitino/catalog/model/integration/test/ModelCatalogOperationsIT.java
@@ -557,6 +557,81 @@ public class ModelCatalogOperationsIT extends BaseIT {
Assertions.assertEquals(newProperties, updatedModelVersion.properties());
}
+ @Test
+ void testLinkAndUpdateModelVersionUri() {
+ String modelName = RandomNameUtils.genRandomName("model1");
+ String[] aliases = {"alias1"};
+ Map<String, String> properties = ImmutableMap.of("key1", "val1", "key2",
"val2");
+ NameIdentifier modelIdent = NameIdentifier.of(schemaName, modelName);
+
+ String uri = "s3://bucket/path/to/model.zip";
+ String newUri = "s3://bucket/path/to/new_model.zip";
+ String versionComment = "comment";
+
+ gravitinoCatalog.asModelCatalog().registerModel(modelIdent, null, null);
+
+ gravitinoCatalog
+ .asModelCatalog()
+ .linkModelVersion(modelIdent, uri, aliases, versionComment,
properties);
+
+ ModelVersion modelVersion =
gravitinoCatalog.asModelCatalog().getModelVersion(modelIdent, 0);
+
+ Assertions.assertEquals(0, modelVersion.version());
+ Assertions.assertEquals(uri, modelVersion.uri());
+ Assertions.assertArrayEquals(aliases, modelVersion.aliases());
+ Assertions.assertEquals(versionComment, modelVersion.comment());
+ Assertions.assertEquals(properties, modelVersion.properties());
+
+ ModelVersionChange updateUriChange = ModelVersionChange.updateUri(newUri);
+ ModelVersion updatedModelVersion =
+ gravitinoCatalog.asModelCatalog().alterModelVersion(modelIdent, 0,
updateUriChange);
+
+ Assertions.assertEquals(modelVersion.version(),
updatedModelVersion.version());
+ Assertions.assertEquals(newUri, updatedModelVersion.uri());
+ Assertions.assertArrayEquals(modelVersion.aliases(),
updatedModelVersion.aliases());
+ Assertions.assertEquals(modelVersion.comment(),
updatedModelVersion.comment());
+ Assertions.assertEquals(modelVersion.properties(),
updatedModelVersion.properties());
+ }
+
+ @Test
+ void testLinkAndUpdateModelVersionUriByAlias() {
+ String modelName = RandomNameUtils.genRandomName("model1");
+ String[] aliases = {"alias1"};
+ Map<String, String> properties = ImmutableMap.of("key1", "val1", "key2",
"val2");
+ NameIdentifier modelIdent = NameIdentifier.of(schemaName, modelName);
+
+ String uri = "s3://bucket/path/to/model.zip";
+ String newUri = "s3://bucket/path/to/new_model.zip";
+ String versionComment = "comment";
+
+ gravitinoCatalog.asModelCatalog().registerModel(modelIdent, null, null);
+
+ gravitinoCatalog
+ .asModelCatalog()
+ .linkModelVersion(modelIdent, uri, aliases, versionComment,
properties);
+
+ ModelVersion modelVersion =
+ gravitinoCatalog.asModelCatalog().getModelVersion(modelIdent,
aliases[0]);
+
+ Assertions.assertEquals(0, modelVersion.version());
+ Assertions.assertEquals(uri, modelVersion.uri());
+ Assertions.assertArrayEquals(aliases, modelVersion.aliases());
+ Assertions.assertEquals(versionComment, modelVersion.comment());
+ Assertions.assertEquals(properties, modelVersion.properties());
+
+ ModelVersionChange updateUriChange = ModelVersionChange.updateUri(newUri);
+ ModelVersion updatedModelVersion =
+ gravitinoCatalog
+ .asModelCatalog()
+ .alterModelVersion(modelIdent, aliases[0], updateUriChange);
+
+ Assertions.assertEquals(modelVersion.version(),
updatedModelVersion.version());
+ Assertions.assertEquals(newUri, updatedModelVersion.uri());
+ Assertions.assertArrayEquals(modelVersion.aliases(),
updatedModelVersion.aliases());
+ Assertions.assertEquals(modelVersion.comment(),
updatedModelVersion.comment());
+ Assertions.assertEquals(modelVersion.properties(),
updatedModelVersion.properties());
+ }
+
@Test
void testLinkAndRemoveModelVersionProperties() {
String modelName = RandomNameUtils.genRandomName("model1");
diff --git
a/clients/cli/src/main/java/org/apache/gravitino/cli/GravitinoOptions.java
b/clients/cli/src/main/java/org/apache/gravitino/cli/GravitinoOptions.java
index 52c15ad80b..5fd861f078 100644
--- a/clients/cli/src/main/java/org/apache/gravitino/cli/GravitinoOptions.java
+++ b/clients/cli/src/main/java/org/apache/gravitino/cli/GravitinoOptions.java
@@ -66,6 +66,9 @@ public class GravitinoOptions {
public static final String DISABLE = "disable";
public static final String ALIAS = "alias";
public static final String URI = "uri";
+ // TODO: Temporary option for model version update. It will be refactored in
the future; for now it only
+ // proves the E2E flow.
+ public static final String NEW_URI = "newuri";
/**
* Builds and returns the CLI options for Gravitino.
@@ -118,6 +121,7 @@ public class GravitinoOptions {
options.addOption(createArgOption(null, URI, "model version artifact"));
options.addOption(createArgsOption(null, ALIAS, "model aliases"));
options.addOption(createArgOption(null, VERSION, "Gravitino client
version"));
+ options.addOption(createArgOption(null, NEW_URI, "New uri of a model
version"));
// Options that support multiple values
options.addOption(createArgsOption("p", PROPERTIES, "property name/value
pairs"));
diff --git
a/clients/cli/src/main/java/org/apache/gravitino/cli/ModelCommandHandler.java
b/clients/cli/src/main/java/org/apache/gravitino/cli/ModelCommandHandler.java
index ec58f98aa6..ee15d3a88f 100644
---
a/clients/cli/src/main/java/org/apache/gravitino/cli/ModelCommandHandler.java
+++
b/clients/cli/src/main/java/org/apache/gravitino/cli/ModelCommandHandler.java
@@ -194,6 +194,18 @@ public class ModelCommandHandler extends CommandHandler {
.validate()
.handle();
}
+
+ if (line.hasOption(GravitinoOptions.NEW_URI)
+ && (line.hasOption(GravitinoOptions.ALIAS) ||
line.hasOption(GravitinoOptions.VERSION))) {
+ String newUri = line.getOptionValue(GravitinoOptions.NEW_URI);
+ Integer version = getVersionFromLine(line);
+ String alias = getAliasFromLine(line);
+ gravitinoCommandLine
+ .newUpdateModelVersionUri(
+ context, metalake, catalog, schema, model, version, alias,
newUri)
+ .validate()
+ .handle();
+ }
}
/** Handles the "LIST" command. */
diff --git
a/clients/cli/src/main/java/org/apache/gravitino/cli/TestableCommandLine.java
b/clients/cli/src/main/java/org/apache/gravitino/cli/TestableCommandLine.java
index 4c25781865..b32a758ef6 100644
---
a/clients/cli/src/main/java/org/apache/gravitino/cli/TestableCommandLine.java
+++
b/clients/cli/src/main/java/org/apache/gravitino/cli/TestableCommandLine.java
@@ -138,6 +138,7 @@ import
org.apache.gravitino.cli.commands.UpdateMetalakeComment;
import org.apache.gravitino.cli.commands.UpdateMetalakeName;
import org.apache.gravitino.cli.commands.UpdateModelName;
import org.apache.gravitino.cli.commands.UpdateModelVersionComment;
+import org.apache.gravitino.cli.commands.UpdateModelVersionUri;
import org.apache.gravitino.cli.commands.UpdateTableComment;
import org.apache.gravitino.cli.commands.UpdateTableName;
import org.apache.gravitino.cli.commands.UpdateTagComment;
@@ -901,6 +902,19 @@ public class TestableCommandLine {
context, metalake, catalog, schema, model, version, alias, comment);
}
+ protected UpdateModelVersionUri newUpdateModelVersionUri(
+ CommandContext context,
+ String metalake,
+ String catalog,
+ String schema,
+ String model,
+ Integer version,
+ String alias,
+ String uri) {
+ return new UpdateModelVersionUri(
+ context, metalake, catalog, schema, model, version, alias, uri);
+ }
+
protected SetModelProperty newSetModelProperty(
CommandContext context,
String metalake,
diff --git
a/clients/cli/src/main/java/org/apache/gravitino/cli/commands/UpdateModelVersionComment.java
b/clients/cli/src/main/java/org/apache/gravitino/cli/commands/UpdateModelVersionComment.java
index bb280b189e..283848d322 100644
---
a/clients/cli/src/main/java/org/apache/gravitino/cli/commands/UpdateModelVersionComment.java
+++
b/clients/cli/src/main/java/org/apache/gravitino/cli/commands/UpdateModelVersionComment.java
@@ -95,7 +95,7 @@ public class UpdateModelVersionComment extends Command {
}
if (alias != null) {
- printInformation(model + " version " + alias + " comment changed.");
+ printInformation(model + " alias " + alias + " comment changed.");
} else {
printInformation(model + " version " + version + " comment changed.");
}
diff --git
a/clients/cli/src/main/java/org/apache/gravitino/cli/commands/UpdateModelVersionComment.java
b/clients/cli/src/main/java/org/apache/gravitino/cli/commands/UpdateModelVersionUri.java
similarity index 77%
copy from
clients/cli/src/main/java/org/apache/gravitino/cli/commands/UpdateModelVersionComment.java
copy to
clients/cli/src/main/java/org/apache/gravitino/cli/commands/UpdateModelVersionUri.java
index bb280b189e..0b5634bcb1 100644
---
a/clients/cli/src/main/java/org/apache/gravitino/cli/commands/UpdateModelVersionComment.java
+++
b/clients/cli/src/main/java/org/apache/gravitino/cli/commands/UpdateModelVersionUri.java
@@ -29,29 +29,29 @@ import
org.apache.gravitino.exceptions.NoSuchSchemaException;
import org.apache.gravitino.exceptions.NoSuchTableException;
import org.apache.gravitino.model.ModelVersionChange;
-/** Update the comment of a model version. */
-public class UpdateModelVersionComment extends Command {
+/** Update the uri of a model version. */
+public class UpdateModelVersionUri extends Command {
protected final String metalake;
protected final String catalog;
protected final String schema;
protected final String model;
protected final Integer version;
private final String alias;
- private final String comment;
+ private final String uri;
/**
- * Constructs a new {@link UpdateModelVersionComment} instance.
+ * Constructs a new {@link UpdateModelVersionUri} instance.
*
* @param context The command context.
- * @param metalake The name of the metalake.
- * @param catalog The name of the catalog.
- * @param schema The name of the schema.
- * @param model The name of the model.
- * @param version The version of the model.
- * @param alias The alias of the model version.
- * @param comment The new comment for the model version.
+ * @param metalake The metalake name.
+ * @param catalog The catalog name.
+ * @param schema The schema name.
+ * @param model The model name.
+ * @param version The version number; used only when {@code alias} is null.
+ * @param alias The alias name; takes precedence over {@code version} when non-null.
+ * @param uri The new uri of the model version.
*/
- public UpdateModelVersionComment(
+ public UpdateModelVersionUri(
CommandContext context,
String metalake,
String catalog,
@@ -59,7 +59,7 @@ public class UpdateModelVersionComment extends Command {
String model,
Integer version,
String alias,
- String comment) {
+ String uri) {
super(context);
this.metalake = metalake;
this.catalog = catalog;
@@ -67,16 +67,16 @@ public class UpdateModelVersionComment extends Command {
this.model = model;
this.version = version;
this.alias = alias;
- this.comment = comment;
+ this.uri = uri;
}
- /** Update the comment of a model version. */
+ /** Update the uri of a model version. */
@Override
public void handle() {
try {
NameIdentifier modelIdent = NameIdentifier.of(schema, model);
GravitinoClient client = buildClient(metalake);
- ModelVersionChange change = ModelVersionChange.updateComment(comment);
+ ModelVersionChange change = ModelVersionChange.updateUri(uri);
if (alias != null) {
client.loadCatalog(catalog).asModelCatalog().alterModelVersion(modelIdent,
alias, change);
} else {
@@ -95,9 +95,9 @@ public class UpdateModelVersionComment extends Command {
}
if (alias != null) {
- printInformation(model + " version " + alias + " comment changed.");
+ printInformation(model + " alias " + alias + " uri changed.");
} else {
- printInformation(model + " version " + version + " comment changed.");
+ printInformation(model + " version " + version + " uri changed.");
}
}
diff --git
a/clients/cli/src/test/java/org/apache/gravitino/cli/TestModelCommands.java
b/clients/cli/src/test/java/org/apache/gravitino/cli/TestModelCommands.java
index e4014e1c1b..99b17bce23 100644
--- a/clients/cli/src/test/java/org/apache/gravitino/cli/TestModelCommands.java
+++ b/clients/cli/src/test/java/org/apache/gravitino/cli/TestModelCommands.java
@@ -52,6 +52,7 @@ import org.apache.gravitino.cli.commands.SetModelProperty;
import org.apache.gravitino.cli.commands.SetModelVersionProperty;
import org.apache.gravitino.cli.commands.UpdateModelName;
import org.apache.gravitino.cli.commands.UpdateModelVersionComment;
+import org.apache.gravitino.cli.commands.UpdateModelVersionUri;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
@@ -935,4 +936,91 @@ public class TestModelCommands {
Assertions.assertThrows(RuntimeException.class,
commandLine::handleCommandLine);
}
+
+ @Test
+ void testUpdateModelVersionUri() {
+ UpdateModelVersionUri mockUpdate = mock(UpdateModelVersionUri.class);
+
when(mockCommandLine.hasOption(GravitinoOptions.METALAKE)).thenReturn(true);
+
when(mockCommandLine.getOptionValue(GravitinoOptions.METALAKE)).thenReturn("metalake_demo");
+ when(mockCommandLine.hasOption(GravitinoOptions.NAME)).thenReturn(true);
+
when(mockCommandLine.getOptionValue(GravitinoOptions.NAME)).thenReturn("catalog.schema.model");
+ when(mockCommandLine.hasOption(GravitinoOptions.VERSION)).thenReturn(true);
+
when(mockCommandLine.getOptionValue(GravitinoOptions.VERSION)).thenReturn("1");
+ when(mockCommandLine.hasOption(GravitinoOptions.NEW_URI)).thenReturn(true);
+
when(mockCommandLine.getOptionValue(GravitinoOptions.NEW_URI)).thenReturn("uri");
+
+ GravitinoCommandLine commandLine =
+ spy(
+ new GravitinoCommandLine(
+ mockCommandLine, mockOptions, CommandEntities.MODEL,
CommandActions.UPDATE));
+
+ doReturn(mockUpdate)
+ .when(commandLine)
+ .newUpdateModelVersionUri(
+ any(CommandContext.class),
+ eq("metalake_demo"),
+ eq("catalog"),
+ eq("schema"),
+ eq("model"),
+ any(),
+ any(),
+ eq("uri"));
+ doReturn(mockUpdate).when(mockUpdate).validate();
+ commandLine.handleCommandLine();
+ verify(mockUpdate).handle();
+ }
+
+ @Test
+ void testUpdateModelVersionUriByAlias() {
+ UpdateModelVersionUri mockUpdate = mock(UpdateModelVersionUri.class);
+
when(mockCommandLine.hasOption(GravitinoOptions.METALAKE)).thenReturn(true);
+
when(mockCommandLine.getOptionValue(GravitinoOptions.METALAKE)).thenReturn("metalake_demo");
+ when(mockCommandLine.hasOption(GravitinoOptions.NAME)).thenReturn(true);
+
when(mockCommandLine.getOptionValue(GravitinoOptions.NAME)).thenReturn("catalog.schema.model");
+ when(mockCommandLine.hasOption(GravitinoOptions.ALIAS)).thenReturn(true);
+ when(mockCommandLine.getOptionValues(GravitinoOptions.ALIAS))
+ .thenReturn(new String[] {"aliasA"});
+ when(mockCommandLine.hasOption(GravitinoOptions.NEW_URI)).thenReturn(true);
+
when(mockCommandLine.getOptionValue(GravitinoOptions.NEW_URI)).thenReturn("uri");
+
+ GravitinoCommandLine commandLine =
+ spy(
+ new GravitinoCommandLine(
+ mockCommandLine, mockOptions, CommandEntities.MODEL,
CommandActions.UPDATE));
+
+ doReturn(mockUpdate)
+ .when(commandLine)
+ .newUpdateModelVersionUri(
+ any(CommandContext.class),
+ eq("metalake_demo"),
+ eq("catalog"),
+ eq("schema"),
+ eq("model"),
+ any(),
+ any(),
+ eq("uri"));
+ doReturn(mockUpdate).when(mockUpdate).validate();
+ commandLine.handleCommandLine();
+ verify(mockUpdate).handle();
+ }
+
+ @Test
+ void testUpdateModelVersionUriByAliasAndVersion() {
+
when(mockCommandLine.hasOption(GravitinoOptions.METALAKE)).thenReturn(true);
+
when(mockCommandLine.getOptionValue(GravitinoOptions.METALAKE)).thenReturn("metalake_demo");
+ when(mockCommandLine.hasOption(GravitinoOptions.ALIAS)).thenReturn(true);
+ when(mockCommandLine.getOptionValues(GravitinoOptions.ALIAS))
+ .thenReturn(new String[] {"aliasA"});
+ when(mockCommandLine.hasOption(GravitinoOptions.VERSION)).thenReturn(true);
+
when(mockCommandLine.getOptionValue(GravitinoOptions.VERSION)).thenReturn("1");
+ when(mockCommandLine.hasOption(GravitinoOptions.NEW_URI)).thenReturn(true);
+
when(mockCommandLine.getOptionValue(GravitinoOptions.NEW_URI)).thenReturn("uri");
+
+ GravitinoCommandLine commandLine =
+ spy(
+ new GravitinoCommandLine(
+ mockCommandLine, mockOptions, CommandEntities.MODEL,
CommandActions.UPDATE));
+
+ Assertions.assertThrows(RuntimeException.class,
commandLine::handleCommandLine);
+ }
}
diff --git
a/clients/client-java/src/main/java/org/apache/gravitino/client/DTOConverters.java
b/clients/client-java/src/main/java/org/apache/gravitino/client/DTOConverters.java
index 48ddd9df6e..f683f1ad94 100644
---
a/clients/client-java/src/main/java/org/apache/gravitino/client/DTOConverters.java
+++
b/clients/client-java/src/main/java/org/apache/gravitino/client/DTOConverters.java
@@ -400,6 +400,10 @@ class DTOConverters {
return new ModelVersionUpdateRequest.RemoveModelVersionPropertyRequest(
((ModelVersionChange.RemoveProperty) change).property());
+ } else if (change instanceof ModelVersionChange.UpdateUri) {
+ return new ModelVersionUpdateRequest.UpdateModelVersionUriRequest(
+ ((ModelVersionChange.UpdateUri) change).newUri());
+
} else {
throw new IllegalArgumentException(
"Unknown model version change type: " +
change.getClass().getSimpleName());
diff --git a/clients/client-python/gravitino/api/model_version_change.py
b/clients/client-python/gravitino/api/model_version_change.py
index 56550df874..4418b51f7f 100644
--- a/clients/client-python/gravitino/api/model_version_change.py
+++ b/clients/client-python/gravitino/api/model_version_change.py
@@ -55,6 +55,16 @@ class ModelVersionChange(ABC):
"""
return ModelVersionChange.RemoveProperty(key)
+ @staticmethod
+ def update_uri(uri: str):
+ """Creates a new model version change to update the uri of the model
version.
+ Args:
+ uri: The new uri of the model version.
+ Returns:
+ The model version change.
+ """
+ return ModelVersionChange.UpdateUri(uri)
+
class UpdateComment:
"""A model version change to update the comment of the model
version."""
@@ -188,3 +198,45 @@ class ModelVersionChange(ABC):
A string summary of this property remove operation.
"""
return f"RemoveProperty {self.key()}"
+
+ class UpdateUri:
+ """A model version change to update the URI of the model version."""
+
+ def __init__(self, new_uri: str):
+ self._new_uri = new_uri
+
+ def new_uri(self) -> str:
+ """Retrieves the new URI of the model version.
+ Returns:
+ The new URI of the model version.
+ """
+ return self._new_uri
+
+ def __eq__(self, other):
+ """Compares this UpdateUri instance with another object for
equality. Two instances are
+ considered equal if they designate the same new URI for the model
version.
+ Args:
+ other: The object to compare with this instance.
+ Returns:
+ True if the given object represents an identical model version
URI update operation;
+ False otherwise.
+ """
+ if not isinstance(other, ModelVersionChange.UpdateUri):
+ return False
+ return self.new_uri() == other.new_uri()
+
+ def __hash__(self):
+ """Generates a hash code for this UpdateUri instance. The hash
code is based solely on
+ the new URI for the model version.
+ Returns:
+ A hash code value for this URI update operation.
+ """
+ return hash(self.new_uri())
+
+ def __str__(self):
+ """Provides a string representation of the UpdateUri instance.
This string includes the
+ class name followed by the new URI of the model version.
+ Returns:
+ A string summary of this URI update operation.
+ """
+ return f"UpdateUri {self._new_uri}"
diff --git a/clients/client-python/gravitino/api/types/types.py
b/clients/client-python/gravitino/api/types/types.py
index 8136796885..f005c158a6 100644
--- a/clients/client-python/gravitino/api/types/types.py
+++ b/clients/client-python/gravitino/api/types/types.py
@@ -741,6 +741,7 @@ class Types:
class ListType(ComplexType):
"""The list type in Gravitino."""
+
_element_type: Type
_element_nullable: bool
diff --git a/clients/client-python/gravitino/client/generic_model_catalog.py
b/clients/client-python/gravitino/client/generic_model_catalog.py
index b304c5115d..21e63b755a 100644
--- a/clients/client-python/gravitino/client/generic_model_catalog.py
+++ b/clients/client-python/gravitino/client/generic_model_catalog.py
@@ -542,6 +542,11 @@ class GenericModelCatalog(BaseSchemaCatalog):
change.property()
)
+ if isinstance(change, ModelVersionChange.UpdateUri):
+ return ModelVersionUpdateRequest.UpdateModelVersionUriRequest(
+ change.new_uri()
+ )
+
raise ValueError(f"Unknown change type: {type(change).__name__}")
def _check_model_namespace(self, namespace: Namespace):
diff --git
a/clients/client-python/gravitino/dto/requests/model_version_update_request.py
b/clients/client-python/gravitino/dto/requests/model_version_update_request.py
index 948f92359b..73fd43fb0f 100644
---
a/clients/client-python/gravitino/dto/requests/model_version_update_request.py
+++
b/clients/client-python/gravitino/dto/requests/model_version_update_request.py
@@ -107,3 +107,34 @@ class ModelVersionUpdateRequest:
def model_version_change(self):
return ModelVersionChange.remove_property(self._property)
+
+ @dataclass
+ class UpdateModelVersionUriRequest(ModelVersionUpdateRequestBase):
+ """Request to update the uri of a model version."""
+
+ _new_uri: Optional[str] = field(metadata=config(field_name="newUri"))
+ """The new uri of the model version, serialized under the JSON key newUri."""
+
+ def __init__(self, new_uri: str):
+ super().__init__("updateUri")
+ self._new_uri = new_uri
+
+ def new_uri(self):
+ """Retrieves the new uri of the model version.
+ Returns:
+ The new uri of the model version.
+ """
+ return self._new_uri
+
+ def validate(self):
+ """Validates the request; raises ValueError if the new uri is None or empty."""
+ if not self._new_uri:
+ raise ValueError('"newUri" field is required')
+
+ def model_version_change(self):
+ """
+ Returns a ModelVersionChange object representing the update uri
operation.
+ Returns:
+ ModelVersionChange: The ModelVersionChange object representing
the update uri operation.
+ """
+ return ModelVersionChange.update_uri(self._new_uri)
diff --git a/clients/client-python/tests/integration/test_model_catalog.py
b/clients/client-python/tests/integration/test_model_catalog.py
index 54392ae9f3..3ef9ea8b08 100644
--- a/clients/client-python/tests/integration/test_model_catalog.py
+++ b/clients/client-python/tests/integration/test_model_catalog.py
@@ -375,6 +375,45 @@ class TestModelCatalog(IntegrationTestEnv):
self.assertEqual(update_property_model.aliases(), aliases)
self.assertEqual(update_property_model.properties(), {"k1": "v11",
"k3": "v3"})
+ def test_link_update_model_version_uri(self):
+ model_name = f"model_it_model{str(randint(0, 1000))}"
+ model_ident = NameIdentifier.of(self._schema_name, model_name)
+ aliases = ["alias1", "alias2"]
+ comment = "comment"
+ properties = {"k1": "v1", "k2": "v2"}
+ self._catalog.as_model_catalog().register_model(
+ model_ident, comment, properties
+ )
+ self._catalog.as_model_catalog().link_model_version(
+ model_ident,
+ uri="uri",
+ aliases=aliases,
+ comment="comment",
+ properties={"k1": "v1", "k2": "v2"},
+ )
+
+ original_model_version =
self._catalog.as_model_catalog().get_model_version(
+ model_ident, 0
+ )
+
+ self.assertEqual(0, original_model_version.version())
+ self.assertEqual("uri", original_model_version.uri())
+ self.assertEqual(["alias1", "alias2"],
original_model_version.aliases())
+ self.assertEqual("comment", original_model_version.comment())
+ self.assertEqual({"k1": "v1", "k2": "v2"},
original_model_version.properties())
+
+ changes = [ModelVersionChange.update_uri("new_uri")]
+ self._catalog.as_model_catalog().alter_model_version(model_ident, 0,
*changes)
+
+ updated_model_version =
self._catalog.as_model_catalog().get_model_version(
+ model_ident, 0
+ )
+ self.assertEqual(0, updated_model_version.version())
+ self.assertEqual("new_uri", updated_model_version.uri())
+ self.assertEqual(["alias1", "alias2"], updated_model_version.aliases())
+ self.assertEqual("comment", updated_model_version.comment())
+ self.assertEqual({"k1": "v1", "k2": "v2"},
updated_model_version.properties())
+
def test_link_get_model_version(self):
model_name = "model_it_model" + str(randint(0, 1000))
model_ident = NameIdentifier.of(self._schema_name, model_name)
diff --git
a/common/src/main/java/org/apache/gravitino/dto/requests/ModelVersionUpdateRequest.java
b/common/src/main/java/org/apache/gravitino/dto/requests/ModelVersionUpdateRequest.java
index 364d0bc9c4..96f7ebc70b 100644
---
a/common/src/main/java/org/apache/gravitino/dto/requests/ModelVersionUpdateRequest.java
+++
b/common/src/main/java/org/apache/gravitino/dto/requests/ModelVersionUpdateRequest.java
@@ -45,7 +45,10 @@ import org.apache.gravitino.rest.RESTRequest;
name = "setProperty"),
@JsonSubTypes.Type(
value =
ModelVersionUpdateRequest.RemoveModelVersionPropertyRequest.class,
- name = "removeProperty")
+ name = "removeProperty"),
+ @JsonSubTypes.Type(
+ value = ModelVersionUpdateRequest.UpdateModelVersionUriRequest.class,
+ name = "updateUri")
})
public interface ModelVersionUpdateRequest extends RESTRequest {
@@ -136,4 +139,32 @@ public interface ModelVersionUpdateRequest extends
RESTRequest {
StringUtils.isNotBlank(property), "\"property\" field is required
and cannot be empty");
}
}
+
+ /** Request to update the URI of a model version. */
+ @EqualsAndHashCode
+ @AllArgsConstructor
+ @NoArgsConstructor(force = true)
+ @ToString
+ @Getter
+ class UpdateModelVersionUriRequest implements ModelVersionUpdateRequest {
+ @JsonProperty("newUri")
+ private final String newUri;
+
+ /** {@inheritDoc} */
+ @Override
+ public ModelVersionChange modelVersionChange() {
+ return ModelVersionChange.updateUri(newUri);
+ }
+
+ /**
+ * Validates the request, i.e., checks if the newUri is not empty.
+ *
+ * @throws IllegalArgumentException If the request is invalid, this
exception is thrown.
+ */
+ @Override
+ public void validate() throws IllegalArgumentException {
+ Preconditions.checkArgument(
+ StringUtils.isNotBlank(newUri), "\"newUri\" field is required and
cannot be empty");
+ }
+ }
}