This is an automated email from the ASF dual-hosted git repository.
jshao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/gravitino.git
The following commit(s) were added to refs/heads/main by this push:
new 83de7c61f9 [#7437] feat(client): Java/Python clients and CLI support
model version with multiple URIs (#8267)
83de7c61f9 is described below
commit 83de7c61f9e7442ed2825776fb707c423e1b1382
Author: XiaoZ <[email protected]>
AuthorDate: Wed Aug 27 15:48:21 2025 +0800
[#7437] feat(client): Java/Python clients and CLI support model version
with multiple URIs (#8267)
### What changes were proposed in this pull request?
Add support in the Java and Python clients and in the CLI for model versions with multiple URIs.
### Why are the changes needed?
Fix: #7437
### Does this PR introduce _any_ user-facing change?
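Yes. The CLI `--uri` option is replaced by `--uris`, which takes comma-separated `name=uri` pairs. `ModelCatalog.linkModelVersion(...)` with a URI map, both `ModelCatalog.getModelVersionUri(...)` overloads, and `ModelVersion.uris()` no longer have default implementations, so implementers must now provide them.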
### How was this patch tested?
Unit tests and integration tests.
---------
Co-authored-by: zhanghan <[email protected]>
---
.../org/apache/gravitino/model/ModelCatalog.java | 18 +-
.../org/apache/gravitino/model/ModelVersion.java | 4 +-
.../integration/test/ModelCatalogOperationsIT.java | 262 +++++++++++++++++++++
.../org/apache/gravitino/cli/ErrorMessages.java | 2 +-
.../org/apache/gravitino/cli/GravitinoOptions.java | 4 +-
.../apache/gravitino/cli/ModelCommandHandler.java | 35 ++-
.../apache/gravitino/cli/TestableCommandLine.java | 4 +-
.../apache/gravitino/cli/commands/LinkModel.java | 16 +-
.../apache/gravitino/cli/TestModelCommands.java | 38 +--
.../org/apache/gravitino/client/ErrorHandlers.java | 5 +
.../gravitino/client/GenericModelCatalog.java | 50 +++-
.../gravitino/client/GenericModelVersion.java | 4 +-
.../gravitino/client/TestGenericModelCatalog.java | 248 ++++++++++++++++++-
.../gravitino/api/model_version_change.py | 143 ++++++++++-
.../gravitino/client/generic_model_catalog.py | 91 +++++--
.../gravitino/client/generic_model_version.py | 3 -
.../gravitino/dto/model_version_dto.py | 7 +-
.../dto/requests/model_version_link_request.py | 16 +-
.../dto/requests/model_version_update_request.py | 87 ++++++-
.../dto/responses/model_version_list_response.py | 2 +-
.../dto/responses/model_version_response.py | 2 +-
...n_response.py => model_version_uri_response.py} | 29 +--
.../tests/integration/test_model_catalog.py | 193 +++++++++++++++
.../tests/unittests/test_model_catalog_api.py | 79 ++++++-
.../tests/unittests/test_responses.py | 34 ++-
25 files changed, 1245 insertions(+), 131 deletions(-)
diff --git a/api/src/main/java/org/apache/gravitino/model/ModelCatalog.java
b/api/src/main/java/org/apache/gravitino/model/ModelCatalog.java
index 9f0f8caf1d..e10df19ce2 100644
--- a/api/src/main/java/org/apache/gravitino/model/ModelCatalog.java
+++ b/api/src/main/java/org/apache/gravitino/model/ModelCatalog.java
@@ -250,15 +250,13 @@ public interface ModelCatalog {
* @throws NoSuchModelException If the model does not exist.
* @throws ModelVersionAliasesAlreadyExistException If the aliases already
exist in the model.
*/
- default void linkModelVersion(
+ void linkModelVersion(
NameIdentifier ident,
Map<String, String> uris,
String[] aliases,
String comment,
Map<String, String> properties)
- throws NoSuchModelException, ModelVersionAliasesAlreadyExistException {
- throw new UnsupportedOperationException("Not supported yet");
- }
+ throws NoSuchModelException, ModelVersionAliasesAlreadyExistException;
/**
* Get the URI of the model artifact with a specified version number and URI
name.
@@ -269,10 +267,8 @@ public interface ModelCatalog {
* @throws NoSuchModelVersionException If the model version does not exist.
* @return The URI of the model version.
*/
- default String getModelVersionUri(NameIdentifier ident, int version, String
uriName)
- throws NoSuchModelVersionException, NoSuchModelVersionURINameException {
- throw new UnsupportedOperationException("Not supported yet");
- }
+ String getModelVersionUri(NameIdentifier ident, int version, String uriName)
+ throws NoSuchModelVersionException, NoSuchModelVersionURINameException;
/**
* Get the URI of the model artifact with a specified version alias and URI
name.
@@ -283,10 +279,8 @@ public interface ModelCatalog {
* @throws NoSuchModelVersionException If the model version does not exist.
* @return The URI of the model version.
*/
- default String getModelVersionUri(NameIdentifier ident, String alias, String
uriName)
- throws NoSuchModelVersionException, NoSuchModelVersionURINameException {
- throw new UnsupportedOperationException("Not supported yet");
- }
+ String getModelVersionUri(NameIdentifier ident, String alias, String uriName)
+ throws NoSuchModelVersionException, NoSuchModelVersionURINameException;
/**
* Delete the model version by the {@link NameIdentifier} and version
number. If the model version
diff --git a/api/src/main/java/org/apache/gravitino/model/ModelVersion.java
b/api/src/main/java/org/apache/gravitino/model/ModelVersion.java
index 5f6fa5bf42..a9d3be4abb 100644
--- a/api/src/main/java/org/apache/gravitino/model/ModelVersion.java
+++ b/api/src/main/java/org/apache/gravitino/model/ModelVersion.java
@@ -84,9 +84,7 @@ public interface ModelVersion extends Auditable {
* @return The URIs of the model version, the key is the name of the URI and
the value is the URI
* of the model artifact.
*/
- default Map<String, String> uris() {
- throw new UnsupportedOperationException("Not implemented yet");
- }
+ Map<String, String> uris();
/**
* The properties of the model version. The properties are key-value pairs
that can be used to
diff --git
a/catalogs/catalog-model/src/test/java/org/apache/gravtitino/catalog/model/integration/test/ModelCatalogOperationsIT.java
b/catalogs/catalog-model/src/test/java/org/apache/gravtitino/catalog/model/integration/test/ModelCatalogOperationsIT.java
index f61071dfbb..f16cec8e9c 100644
---
a/catalogs/catalog-model/src/test/java/org/apache/gravtitino/catalog/model/integration/test/ModelCatalogOperationsIT.java
+++
b/catalogs/catalog-model/src/test/java/org/apache/gravtitino/catalog/model/integration/test/ModelCatalogOperationsIT.java
@@ -35,6 +35,7 @@ import
org.apache.gravitino.exceptions.ModelAlreadyExistsException;
import
org.apache.gravitino.exceptions.ModelVersionAliasesAlreadyExistException;
import org.apache.gravitino.exceptions.NoSuchModelException;
import org.apache.gravitino.exceptions.NoSuchModelVersionException;
+import org.apache.gravitino.exceptions.NoSuchModelVersionURINameException;
import org.apache.gravitino.exceptions.NoSuchSchemaException;
import org.apache.gravitino.integration.test.util.BaseIT;
import org.apache.gravitino.model.Model;
@@ -369,6 +370,51 @@ public class ModelCatalogOperationsIT extends BaseIT {
Assertions.assertEquals(0, modelVersionsAfterDelete.length);
}
+ @Test
+ public void testLinkModelVersionWithMultipleUris() {
+ String modelName = RandomNameUtils.genRandomName("model1");
+ NameIdentifier modelIdent = NameIdentifier.of(schemaName, modelName);
+ gravitinoCatalog.asModelCatalog().registerModel(modelIdent, null, null);
+ Map<String, String> uris = ImmutableMap.of("n1", "u1", "n2", "u2");
+ gravitinoCatalog
+ .asModelCatalog()
+ .linkModelVersion(modelIdent, uris, new String[] {"alias1"},
"comment1", null);
+
+
Assertions.assertTrue(gravitinoCatalog.asModelCatalog().modelVersionExists(modelIdent,
0));
+ Assertions.assertTrue(
+ gravitinoCatalog.asModelCatalog().modelVersionExists(modelIdent,
"alias1"));
+
+ // Test get model version
+ ModelVersion modelVersion =
gravitinoCatalog.asModelCatalog().getModelVersion(modelIdent, 0);
+ Assertions.assertEquals(0, modelVersion.version());
+ Assertions.assertEquals(uris, modelVersion.uris());
+ Assertions.assertArrayEquals(new String[] {"alias1"},
modelVersion.aliases());
+ Assertions.assertEquals("comment1", modelVersion.comment());
+ Assertions.assertEquals(Collections.emptyMap(), modelVersion.properties());
+
+ // Test list model versions
+ int[] modelVersions =
gravitinoCatalog.asModelCatalog().listModelVersions(modelIdent);
+ Set<Integer> resultSet =
Arrays.stream(modelVersions).boxed().collect(Collectors.toSet());
+ Assertions.assertEquals(1, resultSet.size());
+ Assertions.assertTrue(resultSet.contains(0));
+
+ // Test list model version infos
+ ModelVersion[] modelVersionInfos =
+ gravitinoCatalog.asModelCatalog().listModelVersionInfos(modelIdent);
+ Assertions.assertEquals(1, modelVersionInfos.length);
+ Assertions.assertEquals(0, modelVersionInfos[0].version());
+ Assertions.assertEquals(uris, modelVersionInfos[0].uris());
+ Assertions.assertArrayEquals(new String[] {"alias1"},
modelVersionInfos[0].aliases());
+ Assertions.assertEquals("comment1", modelVersionInfos[0].comment());
+ Assertions.assertEquals(Collections.emptyMap(),
modelVersionInfos[0].properties());
+
+ // Test delete and list model versions info
+
Assertions.assertTrue(gravitinoCatalog.asModelCatalog().deleteModelVersion(modelIdent,
0));
+ int[] modelVersionsAfterDelete =
+ gravitinoCatalog.asModelCatalog().listModelVersions(modelIdent);
+ Assertions.assertEquals(0, modelVersionsAfterDelete.length);
+ }
+
@Test
void testLinkAndUpdateModelVersionComment() {
String modelName = RandomNameUtils.genRandomName("model1");
@@ -714,6 +760,84 @@ public class ModelCatalogOperationsIT extends BaseIT {
Assertions.assertEquals(modelVersion.properties(),
reloadedModelVersion.properties());
}
+ @Test
+ void testLinkAndUpdateModelVersionUriWithMultipleUris() {
+ String modelName = RandomNameUtils.genRandomName("model1");
+ String[] aliases = {"alias1"};
+ Map<String, String> properties = ImmutableMap.of("key1", "val1", "key2",
"val2");
+ NameIdentifier modelIdent = NameIdentifier.of(schemaName, modelName);
+
+ Map<String, String> uris = ImmutableMap.of("n1", "u1");
+ String versionComment = "comment";
+
+ // create and link model version
+ gravitinoCatalog.asModelCatalog().registerModel(modelIdent, null, null);
+ gravitinoCatalog
+ .asModelCatalog()
+ .linkModelVersion(modelIdent, uris, aliases, versionComment,
properties);
+ ModelVersion modelVersion =
gravitinoCatalog.asModelCatalog().getModelVersion(modelIdent, 0);
+ Assertions.assertEquals(0, modelVersion.version());
+ Assertions.assertEquals(uris, modelVersion.uris());
+ Assertions.assertArrayEquals(aliases, modelVersion.aliases());
+ Assertions.assertEquals(versionComment, modelVersion.comment());
+ Assertions.assertEquals(properties, modelVersion.properties());
+
+ // Test update uri
+ Map<String, String> updatedUris = ImmutableMap.of("n1", "u1-1");
+ ModelVersionChange updateUriChange = ModelVersionChange.updateUri("n1",
"u1-1");
+ ModelVersion updatedModelVersion =
+ gravitinoCatalog.asModelCatalog().alterModelVersion(modelIdent, 0,
updateUriChange);
+ Assertions.assertEquals(modelVersion.version(),
updatedModelVersion.version());
+ Assertions.assertEquals(updatedUris, updatedModelVersion.uris());
+ Assertions.assertArrayEquals(modelVersion.aliases(),
updatedModelVersion.aliases());
+ Assertions.assertEquals(modelVersion.comment(),
updatedModelVersion.comment());
+ Assertions.assertEquals(modelVersion.properties(),
updatedModelVersion.properties());
+
+ ModelVersion reloadedModelVersion =
+ gravitinoCatalog.asModelCatalog().getModelVersion(modelIdent, 0);
+ Assertions.assertEquals(modelVersion.version(),
reloadedModelVersion.version());
+ Assertions.assertEquals(updatedUris, reloadedModelVersion.uris());
+ Assertions.assertArrayEquals(modelVersion.aliases(),
reloadedModelVersion.aliases());
+ Assertions.assertEquals(modelVersion.comment(),
reloadedModelVersion.comment());
+ Assertions.assertEquals(modelVersion.properties(),
reloadedModelVersion.properties());
+
+ // Test add uri
+ updatedUris = ImmutableMap.of("n1", "u1-1", "n2", "u2");
+ ModelVersionChange addUriChange = ModelVersionChange.addUri("n2", "u2");
+ updatedModelVersion =
+ gravitinoCatalog.asModelCatalog().alterModelVersion(modelIdent, 0,
addUriChange);
+ Assertions.assertEquals(modelVersion.version(),
updatedModelVersion.version());
+ Assertions.assertEquals(updatedUris, updatedModelVersion.uris());
+ Assertions.assertArrayEquals(modelVersion.aliases(),
updatedModelVersion.aliases());
+ Assertions.assertEquals(modelVersion.comment(),
updatedModelVersion.comment());
+ Assertions.assertEquals(modelVersion.properties(),
updatedModelVersion.properties());
+
+ reloadedModelVersion =
gravitinoCatalog.asModelCatalog().getModelVersion(modelIdent, 0);
+ Assertions.assertEquals(modelVersion.version(),
reloadedModelVersion.version());
+ Assertions.assertEquals(updatedUris, reloadedModelVersion.uris());
+ Assertions.assertArrayEquals(modelVersion.aliases(),
reloadedModelVersion.aliases());
+ Assertions.assertEquals(modelVersion.comment(),
reloadedModelVersion.comment());
+ Assertions.assertEquals(modelVersion.properties(),
reloadedModelVersion.properties());
+
+ // Test remove uri
+ updatedUris = ImmutableMap.of("n2", "u2");
+ ModelVersionChange removeUriChange = ModelVersionChange.removeUri("n1");
+ updatedModelVersion =
+ gravitinoCatalog.asModelCatalog().alterModelVersion(modelIdent, 0,
removeUriChange);
+ Assertions.assertEquals(modelVersion.version(),
updatedModelVersion.version());
+ Assertions.assertEquals(updatedUris, updatedModelVersion.uris());
+ Assertions.assertArrayEquals(modelVersion.aliases(),
updatedModelVersion.aliases());
+ Assertions.assertEquals(modelVersion.comment(),
updatedModelVersion.comment());
+ Assertions.assertEquals(modelVersion.properties(),
updatedModelVersion.properties());
+
+ reloadedModelVersion =
gravitinoCatalog.asModelCatalog().getModelVersion(modelIdent, 0);
+ Assertions.assertEquals(modelVersion.version(),
reloadedModelVersion.version());
+ Assertions.assertEquals(updatedUris, reloadedModelVersion.uris());
+ Assertions.assertArrayEquals(modelVersion.aliases(),
reloadedModelVersion.aliases());
+ Assertions.assertEquals(modelVersion.comment(),
reloadedModelVersion.comment());
+ Assertions.assertEquals(modelVersion.properties(),
reloadedModelVersion.properties());
+ }
+
@Test
void testLinkAndUpdateModelVersionUriByAlias() {
String modelName = RandomNameUtils.genRandomName("model1");
@@ -948,6 +1072,144 @@ public class ModelCatalogOperationsIT extends BaseIT {
Assertions.assertEquals(modelVersion.properties(),
reloadedModelVersion.properties());
}
+ @Test
+ public void testGetModelVersionUri() {
+ // Test get model version without default uri name
+ String modelName = RandomNameUtils.genRandomName("model1");
+ NameIdentifier modelIdent = NameIdentifier.of(schemaName, modelName);
+ gravitinoCatalog.asModelCatalog().registerModel(modelIdent, null, null);
+ Map<String, String> uris = ImmutableMap.of("n1", "u1", "n2", "u2");
+ gravitinoCatalog
+ .asModelCatalog()
+ .linkModelVersion(modelIdent, uris, new String[] {"alias1"},
"comment1", null);
+
Assertions.assertTrue(gravitinoCatalog.asModelCatalog().modelVersionExists(modelIdent,
0));
+ Assertions.assertTrue(
+ gravitinoCatalog.asModelCatalog().modelVersionExists(modelIdent,
"alias1"));
+
+ Assertions.assertEquals(
+ "u1", gravitinoCatalog.asModelCatalog().getModelVersionUri(modelIdent,
0, "n1"));
+ Assertions.assertEquals(
+ "u2", gravitinoCatalog.asModelCatalog().getModelVersionUri(modelIdent,
0, "n2"));
+ Assertions.assertThrows(
+ NoSuchModelVersionURINameException.class,
+ () -> gravitinoCatalog.asModelCatalog().getModelVersionUri(modelIdent,
0, "n3"));
+ Assertions.assertThrows(
+ IllegalArgumentException.class,
+ () -> gravitinoCatalog.asModelCatalog().getModelVersionUri(modelIdent,
0, null));
+
+ // Test get model version with default uri name
+ String modelName1 = RandomNameUtils.genRandomName("model1");
+ NameIdentifier modelIdent1 = NameIdentifier.of(schemaName, modelName1);
+ Map<String, String> modelProperties =
+ ImmutableMap.of(ModelVersion.PROPERTY_DEFAULT_URI_NAME, "n1");
+ gravitinoCatalog.asModelCatalog().registerModel(modelIdent1, null,
modelProperties);
+ gravitinoCatalog
+ .asModelCatalog()
+ .linkModelVersion(modelIdent1, uris, new String[] {"alias2"},
"comment1", null);
+
Assertions.assertTrue(gravitinoCatalog.asModelCatalog().modelVersionExists(modelIdent1,
0));
+ Assertions.assertTrue(
+ gravitinoCatalog.asModelCatalog().modelVersionExists(modelIdent1,
"alias2"));
+
+ Assertions.assertEquals(
+ "u1",
gravitinoCatalog.asModelCatalog().getModelVersionUri(modelIdent1, 0, "n1"));
+ Assertions.assertEquals(
+ "u2",
gravitinoCatalog.asModelCatalog().getModelVersionUri(modelIdent1, 0, "n2"));
+ Assertions.assertThrows(
+ NoSuchModelVersionURINameException.class,
+ () ->
gravitinoCatalog.asModelCatalog().getModelVersionUri(modelIdent1, 0, "n3"));
+ Assertions.assertEquals(
+ "u1",
gravitinoCatalog.asModelCatalog().getModelVersionUri(modelIdent1, 0, null));
+
+ Map<String, String> modelVersionProperties =
+ ImmutableMap.of(ModelVersion.PROPERTY_DEFAULT_URI_NAME, "n2");
+ gravitinoCatalog
+ .asModelCatalog()
+ .linkModelVersion(
+ modelIdent1, uris, new String[] {"alias3"}, "comment1",
modelVersionProperties);
+
Assertions.assertTrue(gravitinoCatalog.asModelCatalog().modelVersionExists(modelIdent1,
1));
+ Assertions.assertTrue(
+ gravitinoCatalog.asModelCatalog().modelVersionExists(modelIdent1,
"alias3"));
+
+ Assertions.assertEquals(
+ "u1",
gravitinoCatalog.asModelCatalog().getModelVersionUri(modelIdent1, 1, "n1"));
+ Assertions.assertEquals(
+ "u2",
gravitinoCatalog.asModelCatalog().getModelVersionUri(modelIdent1, 1, "n2"));
+ Assertions.assertThrows(
+ NoSuchModelVersionURINameException.class,
+ () ->
gravitinoCatalog.asModelCatalog().getModelVersionUri(modelIdent1, 1, "n3"));
+ Assertions.assertEquals(
+ "u2",
gravitinoCatalog.asModelCatalog().getModelVersionUri(modelIdent1, 1, null));
+ }
+
+ @Test
+ public void testGetModelVersionUriByAlias() {
+ // Test get model version without default uri name
+ String modelName = RandomNameUtils.genRandomName("model1");
+ NameIdentifier modelIdent = NameIdentifier.of(schemaName, modelName);
+ gravitinoCatalog.asModelCatalog().registerModel(modelIdent, null, null);
+ Map<String, String> uris = ImmutableMap.of("n1", "u1", "n2", "u2");
+ gravitinoCatalog
+ .asModelCatalog()
+ .linkModelVersion(modelIdent, uris, new String[] {"alias1"},
"comment1", null);
+
Assertions.assertTrue(gravitinoCatalog.asModelCatalog().modelVersionExists(modelIdent,
0));
+ Assertions.assertTrue(
+ gravitinoCatalog.asModelCatalog().modelVersionExists(modelIdent,
"alias1"));
+
+ Assertions.assertEquals(
+ "u1", gravitinoCatalog.asModelCatalog().getModelVersionUri(modelIdent,
"alias1", "n1"));
+ Assertions.assertEquals(
+ "u2", gravitinoCatalog.asModelCatalog().getModelVersionUri(modelIdent,
"alias1", "n2"));
+ Assertions.assertThrows(
+ NoSuchModelVersionURINameException.class,
+ () -> gravitinoCatalog.asModelCatalog().getModelVersionUri(modelIdent,
"alias1", "n3"));
+ Assertions.assertThrows(
+ IllegalArgumentException.class,
+ () -> gravitinoCatalog.asModelCatalog().getModelVersionUri(modelIdent,
"alias1", null));
+
+ // Test get model version with default uri name
+ String modelName1 = RandomNameUtils.genRandomName("model1");
+ NameIdentifier modelIdent1 = NameIdentifier.of(schemaName, modelName1);
+ Map<String, String> modelProperties =
+ ImmutableMap.of(ModelVersion.PROPERTY_DEFAULT_URI_NAME, "n1");
+ gravitinoCatalog.asModelCatalog().registerModel(modelIdent1, null,
modelProperties);
+ gravitinoCatalog
+ .asModelCatalog()
+ .linkModelVersion(modelIdent1, uris, new String[] {"alias2"},
"comment1", null);
+
Assertions.assertTrue(gravitinoCatalog.asModelCatalog().modelVersionExists(modelIdent1,
0));
+ Assertions.assertTrue(
+ gravitinoCatalog.asModelCatalog().modelVersionExists(modelIdent1,
"alias2"));
+
+ Assertions.assertEquals(
+ "u1",
gravitinoCatalog.asModelCatalog().getModelVersionUri(modelIdent1, "alias2",
"n1"));
+ Assertions.assertEquals(
+ "u2",
gravitinoCatalog.asModelCatalog().getModelVersionUri(modelIdent1, "alias2",
"n2"));
+ Assertions.assertThrows(
+ NoSuchModelVersionURINameException.class,
+ () ->
gravitinoCatalog.asModelCatalog().getModelVersionUri(modelIdent1, "alias2",
"n3"));
+ Assertions.assertEquals(
+ "u1",
gravitinoCatalog.asModelCatalog().getModelVersionUri(modelIdent1, "alias2",
null));
+
+ Map<String, String> modelVersionProperties =
+ ImmutableMap.of(ModelVersion.PROPERTY_DEFAULT_URI_NAME, "n2");
+ gravitinoCatalog
+ .asModelCatalog()
+ .linkModelVersion(
+ modelIdent1, uris, new String[] {"alias3"}, "comment1",
modelVersionProperties);
+
Assertions.assertTrue(gravitinoCatalog.asModelCatalog().modelVersionExists(modelIdent1,
1));
+ Assertions.assertTrue(
+ gravitinoCatalog.asModelCatalog().modelVersionExists(modelIdent1,
"alias3"));
+
+ Assertions.assertEquals(
+ "u1",
gravitinoCatalog.asModelCatalog().getModelVersionUri(modelIdent1, "alias3",
"n1"));
+ Assertions.assertEquals(
+ "u2",
gravitinoCatalog.asModelCatalog().getModelVersionUri(modelIdent1, "alias3",
"n2"));
+ Assertions.assertThrows(
+ NoSuchModelVersionURINameException.class,
+ () ->
gravitinoCatalog.asModelCatalog().getModelVersionUri(modelIdent1, "alias3",
"n3"));
+ Assertions.assertEquals(
+ "u2",
gravitinoCatalog.asModelCatalog().getModelVersionUri(modelIdent1, "alias3",
null));
+ }
+
private void createMetalake() {
GravitinoMetalake[] gravitinoMetalakes = client.listMetalakes();
Assertions.assertEquals(0, gravitinoMetalakes.length);
diff --git
a/clients/cli/src/main/java/org/apache/gravitino/cli/ErrorMessages.java
b/clients/cli/src/main/java/org/apache/gravitino/cli/ErrorMessages.java
index 2f937dd4be..e5dc6b0fd6 100644
--- a/clients/cli/src/main/java/org/apache/gravitino/cli/ErrorMessages.java
+++ b/clients/cli/src/main/java/org/apache/gravitino/cli/ErrorMessages.java
@@ -62,7 +62,7 @@ public class ErrorMessages {
public static final String MISSING_PROPERTY_AND_VALUE = "Missing --property
and --value options.";
public static final String MISSING_ROLE = "Missing --role option.";
public static final String MISSING_TAG = "Missing --tag option.";
- public static final String MISSING_URI = "Missing --uri option.";
+ public static final String MISSING_URIS = "Missing --uris option.";
public static final String MISSING_USER = "Missing --user option.";
public static final String MISSING_VALUE = "Missing --value option.";
diff --git
a/clients/cli/src/main/java/org/apache/gravitino/cli/GravitinoOptions.java
b/clients/cli/src/main/java/org/apache/gravitino/cli/GravitinoOptions.java
index 81404c8b8e..a81a322c7e 100644
--- a/clients/cli/src/main/java/org/apache/gravitino/cli/GravitinoOptions.java
+++ b/clients/cli/src/main/java/org/apache/gravitino/cli/GravitinoOptions.java
@@ -65,7 +65,7 @@ public class GravitinoOptions {
public static final String ENABLE = "enable";
public static final String DISABLE = "disable";
public static final String ALIAS = "alias";
- public static final String URI = "uri";
+ public static final String URIS = "uris";
// TODO: temporary option for model version update, it will be refactored in
the future, just
// prove the E2E flow.
public static final String NEW_URI = "newuri";
@@ -123,7 +123,7 @@ public class GravitinoOptions {
options.addOption(createSimpleOption(null, ALL, "on all entities"));
// model options
- options.addOption(createArgOption(null, URI, "model version artifact"));
+ options.addOption(createArgOption(null, URIS, "model version URIs"));
options.addOption(createArgsOption(null, ALIAS, "model aliases"));
options.addOption(createArgOption(null, VERSION, "Gravitino client
version"));
options.addOption(createArgOption(null, NEW_URI, "New uri of a model
version"));
diff --git
a/clients/cli/src/main/java/org/apache/gravitino/cli/ModelCommandHandler.java
b/clients/cli/src/main/java/org/apache/gravitino/cli/ModelCommandHandler.java
index 222e1d56e6..31c61551e5 100644
---
a/clients/cli/src/main/java/org/apache/gravitino/cli/ModelCommandHandler.java
+++
b/clients/cli/src/main/java/org/apache/gravitino/cli/ModelCommandHandler.java
@@ -19,6 +19,7 @@
package org.apache.gravitino.cli;
+import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import java.util.List;
import java.util.Map;
@@ -27,6 +28,9 @@ import org.apache.gravitino.cli.commands.Command;
/** Handles the command execution for Models based on command type and the
command line options. */
public class ModelCommandHandler extends CommandHandler {
+ private static final String DELIMITER = ",";
+ private static final String KEY_VALUE_SEPARATOR = "=";
+
private final GravitinoCommandLine gravitinoCommandLine;
private final CommandLine line;
private final String command;
@@ -160,15 +164,23 @@ public class ModelCommandHandler extends CommandHandler {
/** Handles the "UPDATE" command. */
private void handleUpdateCommand() {
- if (line.hasOption(GravitinoOptions.URI)) {
+ if (line.hasOption(GravitinoOptions.URIS)) {
String[] alias = line.getOptionValues(GravitinoOptions.ALIAS);
- String uri = line.getOptionValue(GravitinoOptions.URI);
+ Map<String, String> uris = getUrisFromLime(line);
String linkComment = line.getOptionValue(GravitinoOptions.COMMENT);
String[] linkProperties =
line.getOptionValues(CommandActions.PROPERTIES);
Map<String, String> linkPropertityMap = new
Properties().parse(linkProperties);
gravitinoCommandLine
.newLinkModel(
- context, metalake, catalog, schema, model, uri, alias,
linkComment, linkPropertityMap)
+ context,
+ metalake,
+ catalog,
+ schema,
+ model,
+ uris,
+ alias,
+ linkComment,
+ linkPropertityMap)
.validate()
.handle();
}
@@ -190,7 +202,7 @@ public class ModelCommandHandler extends CommandHandler {
.handle();
}
- if (!line.hasOption(GravitinoOptions.URI)
+ if (!line.hasOption(GravitinoOptions.URIS)
&& line.hasOption(GravitinoOptions.COMMENT)
&& (line.hasOption(GravitinoOptions.ALIAS) ||
line.hasOption(GravitinoOptions.VERSION))) {
String comment = line.getOptionValue(GravitinoOptions.COMMENT);
@@ -317,4 +329,19 @@ public class ModelCommandHandler extends CommandHandler {
? getOneAlias(line.getOptionValues(GravitinoOptions.ALIAS))
: null;
}
+
+ private Map<String, String> getUrisFromLime(CommandLine line) {
+ String input = line.getOptionValue(GravitinoOptions.URIS);
+ ImmutableMap.Builder<String, String> uris = ImmutableMap.builder();
+ if (input != null) {
+ String[] pairs = input.split(DELIMITER);
+ for (String pair : pairs) {
+ String[] keyValue = pair.split(KEY_VALUE_SEPARATOR, 2);
+ if (keyValue.length == 2) {
+ uris.put(keyValue[0].trim(), keyValue[1].trim());
+ }
+ }
+ }
+ return uris.build();
+ }
}
diff --git
a/clients/cli/src/main/java/org/apache/gravitino/cli/TestableCommandLine.java
b/clients/cli/src/main/java/org/apache/gravitino/cli/TestableCommandLine.java
index 7b4c809d40..becdf61767 100644
---
a/clients/cli/src/main/java/org/apache/gravitino/cli/TestableCommandLine.java
+++
b/clients/cli/src/main/java/org/apache/gravitino/cli/TestableCommandLine.java
@@ -1000,11 +1000,11 @@ public class TestableCommandLine {
String catalog,
String schema,
String model,
- String uri,
+ Map<String, String> uris,
String[] alias,
String comment,
Map<String, String> properties) {
return new LinkModel(
- context, metalake, catalog, schema, model, uri, alias, comment,
properties);
+ context, metalake, catalog, schema, model, uris, alias, comment,
properties);
}
}
diff --git
a/clients/cli/src/main/java/org/apache/gravitino/cli/commands/LinkModel.java
b/clients/cli/src/main/java/org/apache/gravitino/cli/commands/LinkModel.java
index 94a5f99a5b..8664dacd2a 100644
--- a/clients/cli/src/main/java/org/apache/gravitino/cli/commands/LinkModel.java
+++ b/clients/cli/src/main/java/org/apache/gravitino/cli/commands/LinkModel.java
@@ -37,7 +37,7 @@ public class LinkModel extends Command {
protected final String catalog;
protected final String schema;
protected final String model;
- protected final String uri;
+ protected final Map<String, String> uris;
protected final String[] alias;
protected final String comment;
protected final Map<String, String> properties;
@@ -50,7 +50,7 @@ public class LinkModel extends Command {
* @param catalog The name of the catalog.
* @param schema The name of schema.
* @param model The name of model.
- * @param uri The URI of the model version artifact.
+ * @param uris The URIs of the model version artifact.
* @param alias The aliases of the model version.
* @param comment The comment of the model version.
* @param properties The properties of the model version.
@@ -61,7 +61,7 @@ public class LinkModel extends Command {
String catalog,
String schema,
String model,
- String uri,
+ Map<String, String> uris,
String[] alias,
String comment,
Map<String, String> properties) {
@@ -70,7 +70,7 @@ public class LinkModel extends Command {
this.catalog = catalog;
this.schema = schema;
this.model = model;
- this.uri = uri;
+ this.uris = uris;
this.alias = alias;
this.comment = comment;
this.properties = properties;
@@ -84,7 +84,7 @@ public class LinkModel extends Command {
try {
GravitinoClient client = buildClient(metalake);
ModelCatalog modelCatalog = client.loadCatalog(catalog).asModelCatalog();
- modelCatalog.linkModelVersion(name, uri, alias, comment, properties);
+ modelCatalog.linkModelVersion(name, uris, alias, comment, properties);
} catch (NoSuchMetalakeException err) {
exitWithError(ErrorMessages.UNKNOWN_METALAKE);
} catch (NoSuchCatalogException err) {
@@ -100,12 +100,14 @@ public class LinkModel extends Command {
}
printResults(
- "Linked model " + model + " to " + uri + " with aliases " +
Arrays.toString(alias));
+ "Linked model " + model + " to " + uris + " with aliases " +
Arrays.toString(alias));
}
@Override
public Command validate() {
- if (uri == null) exitWithError(ErrorMessages.MISSING_URI);
+ if (uris == null || uris.isEmpty()) {
+ exitWithError(ErrorMessages.MISSING_URIS);
+ }
return super.validate();
}
}
diff --git
a/clients/cli/src/test/java/org/apache/gravitino/cli/TestModelCommands.java
b/clients/cli/src/test/java/org/apache/gravitino/cli/TestModelCommands.java
index 155778381c..a3102c1261 100644
--- a/clients/cli/src/test/java/org/apache/gravitino/cli/TestModelCommands.java
+++ b/clients/cli/src/test/java/org/apache/gravitino/cli/TestModelCommands.java
@@ -19,8 +19,8 @@
package org.apache.gravitino.cli;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.ArgumentMatchers.eq;
@@ -32,6 +32,7 @@ import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
+import com.google.common.collect.ImmutableMap;
import java.io.ByteArrayOutputStream;
import java.io.PrintStream;
import java.nio.charset.StandardCharsets;
@@ -56,7 +57,6 @@ import
org.apache.gravitino.cli.commands.UpdateModelVersionAliases;
import org.apache.gravitino.cli.commands.UpdateModelVersionComment;
import org.apache.gravitino.cli.commands.UpdateModelVersionUri;
import org.junit.jupiter.api.AfterEach;
-import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.testcontainers.shaded.com.google.common.base.Joiner;
@@ -443,8 +443,8 @@ public class TestModelCommands {
when(mockCommandLine.getOptionValue(GravitinoOptions.METALAKE)).thenReturn("metalake_demo");
when(mockCommandLine.hasOption(GravitinoOptions.NAME)).thenReturn(true);
when(mockCommandLine.getOptionValue(GravitinoOptions.NAME)).thenReturn("catalog.schema.model");
- when(mockCommandLine.hasOption(GravitinoOptions.URI)).thenReturn(true);
-
when(mockCommandLine.getOptionValue(GravitinoOptions.URI)).thenReturn("file:///tmp/file");
+ when(mockCommandLine.hasOption(GravitinoOptions.URIS)).thenReturn(true);
+
when(mockCommandLine.getOptionValue(GravitinoOptions.URIS)).thenReturn("n1=u1,n2=u2");
when(mockCommandLine.hasOption(GravitinoOptions.ALIAS)).thenReturn(false);
GravitinoCommandLine commandLine =
spy(
@@ -459,7 +459,7 @@ public class TestModelCommands {
eq("catalog"),
eq("schema"),
eq("model"),
- eq("file:///tmp/file"),
+ eq(ImmutableMap.of("n1", "u1", "n2", "u2")),
isNull(),
isNull(),
argThat(Map::isEmpty));
@@ -475,8 +475,8 @@ public class TestModelCommands {
when(mockCommandLine.getOptionValue(GravitinoOptions.METALAKE)).thenReturn("metalake_demo");
when(mockCommandLine.hasOption(GravitinoOptions.NAME)).thenReturn(true);
when(mockCommandLine.getOptionValue(GravitinoOptions.NAME)).thenReturn("catalog.schema.model");
- when(mockCommandLine.hasOption(GravitinoOptions.URI)).thenReturn(true);
-
when(mockCommandLine.getOptionValue(GravitinoOptions.URI)).thenReturn("file:///tmp/file");
+ when(mockCommandLine.hasOption(GravitinoOptions.URIS)).thenReturn(true);
+
when(mockCommandLine.getOptionValue(GravitinoOptions.URIS)).thenReturn("n1=u1,n2=u2");
when(mockCommandLine.hasOption(GravitinoOptions.ALIAS)).thenReturn(true);
when(mockCommandLine.getOptionValues(GravitinoOptions.ALIAS))
.thenReturn(new String[] {"aliasA", "aliasB"});
@@ -493,7 +493,7 @@ public class TestModelCommands {
eq("catalog"),
eq("schema"),
eq("model"),
- eq("file:///tmp/file"),
+ eq(ImmutableMap.of("n1", "u1", "n2", "u2")),
argThat(
argument ->
argument.length == 2
@@ -527,7 +527,7 @@ public class TestModelCommands {
assertThrows(RuntimeException.class, spyLinkModel::validate);
verify(spyLinkModel, never()).handle();
String output = new String(errContent.toByteArray(),
StandardCharsets.UTF_8).trim();
- assertEquals(ErrorMessages.MISSING_URI, output);
+ assertEquals(ErrorMessages.MISSING_URIS, output);
}
@Test
@@ -537,8 +537,8 @@ public class TestModelCommands {
when(mockCommandLine.getOptionValue(GravitinoOptions.METALAKE)).thenReturn("metalake_demo");
when(mockCommandLine.hasOption(GravitinoOptions.NAME)).thenReturn(true);
when(mockCommandLine.getOptionValue(GravitinoOptions.NAME)).thenReturn("catalog.schema.model");
- when(mockCommandLine.hasOption(GravitinoOptions.URI)).thenReturn(true);
-
when(mockCommandLine.getOptionValue(GravitinoOptions.URI)).thenReturn("file:///tmp/file");
+ when(mockCommandLine.hasOption(GravitinoOptions.URIS)).thenReturn(true);
+
when(mockCommandLine.getOptionValue(GravitinoOptions.URIS)).thenReturn("n1=u1,n2=u2");
when(mockCommandLine.hasOption(GravitinoOptions.ALIAS)).thenReturn(true);
when(mockCommandLine.getOptionValues(GravitinoOptions.ALIAS))
.thenReturn(new String[] {"aliasA", "aliasB"});
@@ -560,7 +560,7 @@ public class TestModelCommands {
eq("catalog"),
eq("schema"),
eq("model"),
- eq("file:///tmp/file"),
+ eq(ImmutableMap.of("n1", "u1", "n2", "u2")),
argThat(
argument ->
argument.length == 2
@@ -783,7 +783,7 @@ public class TestModelCommands {
new GravitinoCommandLine(
mockCommandLine, mockOptions, CommandEntities.MODEL,
CommandActions.UPDATE));
- Assertions.assertThrows(RuntimeException.class,
commandLine::handleCommandLine);
+ assertThrows(RuntimeException.class, commandLine::handleCommandLine);
}
@Test
@@ -878,7 +878,7 @@ public class TestModelCommands {
new GravitinoCommandLine(
mockCommandLine, mockOptions, CommandEntities.MODEL,
CommandActions.SET));
- Assertions.assertThrows(RuntimeException.class,
commandLine::handleCommandLine);
+ assertThrows(RuntimeException.class, commandLine::handleCommandLine);
}
@Test
@@ -965,7 +965,7 @@ public class TestModelCommands {
new GravitinoCommandLine(
mockCommandLine, mockOptions, CommandEntities.MODEL,
CommandActions.REMOVE));
- Assertions.assertThrows(RuntimeException.class,
commandLine::handleCommandLine);
+ assertThrows(RuntimeException.class, commandLine::handleCommandLine);
}
@Test
@@ -1052,7 +1052,7 @@ public class TestModelCommands {
new GravitinoCommandLine(
mockCommandLine, mockOptions, CommandEntities.MODEL,
CommandActions.UPDATE));
- Assertions.assertThrows(RuntimeException.class,
commandLine::handleCommandLine);
+ assertThrows(RuntimeException.class, commandLine::handleCommandLine);
}
@Test
@@ -1144,7 +1144,7 @@ public class TestModelCommands {
new GravitinoCommandLine(
mockCommandLine, mockOptions, CommandEntities.MODEL,
CommandActions.UPDATE));
- Assertions.assertThrows(RuntimeException.class,
commandLine::handleCommandLine);
+ assertThrows(RuntimeException.class, commandLine::handleCommandLine);
}
@Test
@@ -1165,6 +1165,6 @@ public class TestModelCommands {
new GravitinoCommandLine(
mockCommandLine, mockOptions, CommandEntities.MODEL,
CommandActions.UPDATE));
- Assertions.assertThrows(RuntimeException.class,
commandLine::handleCommandLine);
+ assertThrows(RuntimeException.class, commandLine::handleCommandLine);
}
}
diff --git
a/clients/client-java/src/main/java/org/apache/gravitino/client/ErrorHandlers.java
b/clients/client-java/src/main/java/org/apache/gravitino/client/ErrorHandlers.java
index 11e41aacba..705731808f 100644
---
a/clients/client-java/src/main/java/org/apache/gravitino/client/ErrorHandlers.java
+++
b/clients/client-java/src/main/java/org/apache/gravitino/client/ErrorHandlers.java
@@ -55,6 +55,7 @@ import
org.apache.gravitino.exceptions.NoSuchMetadataObjectException;
import org.apache.gravitino.exceptions.NoSuchMetalakeException;
import org.apache.gravitino.exceptions.NoSuchModelException;
import org.apache.gravitino.exceptions.NoSuchModelVersionException;
+import org.apache.gravitino.exceptions.NoSuchModelVersionURINameException;
import org.apache.gravitino.exceptions.NoSuchPartitionException;
import org.apache.gravitino.exceptions.NoSuchPolicyException;
import org.apache.gravitino.exceptions.NoSuchRoleException;
@@ -1108,6 +1109,10 @@ public class ErrorHandlers {
.getType()
.equals(NoSuchModelVersionException.class.getSimpleName())) {
throw new NoSuchModelVersionException(errorMsg);
+ } else if (errorResponse
+ .getType()
+
.equals(NoSuchModelVersionURINameException.class.getSimpleName())) {
+ throw new NoSuchModelVersionURINameException(errorMsg);
} else {
throw new NotFoundException(errorMsg);
}
diff --git
a/clients/client-java/src/main/java/org/apache/gravitino/client/GenericModelCatalog.java
b/clients/client-java/src/main/java/org/apache/gravitino/client/GenericModelCatalog.java
index 18fe4fe94e..da3977b09e 100644
---
a/clients/client-java/src/main/java/org/apache/gravitino/client/GenericModelCatalog.java
+++
b/clients/client-java/src/main/java/org/apache/gravitino/client/GenericModelCatalog.java
@@ -45,10 +45,12 @@ import org.apache.gravitino.dto.responses.ModelResponse;
import org.apache.gravitino.dto.responses.ModelVersionInfoListResponse;
import org.apache.gravitino.dto.responses.ModelVersionListResponse;
import org.apache.gravitino.dto.responses.ModelVersionResponse;
+import org.apache.gravitino.dto.responses.ModelVersionUriResponse;
import org.apache.gravitino.exceptions.ModelAlreadyExistsException;
import
org.apache.gravitino.exceptions.ModelVersionAliasesAlreadyExistException;
import org.apache.gravitino.exceptions.NoSuchModelException;
import org.apache.gravitino.exceptions.NoSuchModelVersionException;
+import org.apache.gravitino.exceptions.NoSuchModelVersionURINameException;
import org.apache.gravitino.exceptions.NoSuchSchemaException;
import org.apache.gravitino.model.Model;
import org.apache.gravitino.model.ModelCatalog;
@@ -220,14 +222,14 @@ class GenericModelCatalog extends BaseSchemaCatalog
implements ModelCatalog {
@Override
public void linkModelVersion(
NameIdentifier ident,
- String uri,
+ Map<String, String> uris,
String[] aliases,
String comment,
Map<String, String> properties)
throws NoSuchModelException, ModelVersionAliasesAlreadyExistException {
checkModelNameIdentifier(ident);
- ModelVersionLinkRequest req = new ModelVersionLinkRequest(uri, aliases,
comment, properties);
+ ModelVersionLinkRequest req = new ModelVersionLinkRequest(uris, aliases,
comment, properties);
NameIdentifier modelFullIdent = modelFullNameIdentifier(ident);
BaseResponse resp =
restClient.post(
@@ -240,6 +242,50 @@ class GenericModelCatalog extends BaseSchemaCatalog
implements ModelCatalog {
resp.validate();
}
+ @Override
+ public String getModelVersionUri(NameIdentifier ident, int version, String
uriName)
+ throws NoSuchModelVersionException, NoSuchModelVersionURINameException {
+ checkModelNameIdentifier(ident);
+ Preconditions.checkArgument(version >= 0, "Model version must be
non-negative");
+
+ NameIdentifier modelFullIdent = modelFullNameIdentifier(ident);
+ Map<String, String> queryParam =
+ uriName == null ? Collections.emptyMap() : ImmutableMap.of("uriName",
uriName);
+ ModelVersionUriResponse resp =
+ restClient.get(
+ formatModelVersionRequestPath(modelFullIdent) + "/versions/" +
version + "/uri",
+ queryParam,
+ ModelVersionUriResponse.class,
+ Collections.emptyMap(),
+ ErrorHandlers.modelErrorHandler());
+ resp.validate();
+ return resp.getUri();
+ }
+
+ @Override
+ public String getModelVersionUri(NameIdentifier ident, String alias, String
uriName)
+ throws NoSuchModelVersionException, NoSuchModelVersionURINameException {
+ checkModelNameIdentifier(ident);
+ Preconditions.checkArgument(StringUtils.isNotBlank(alias), "Model alias
must be non-empty");
+
+ NameIdentifier modelFullIdent = modelFullNameIdentifier(ident);
+ Map<String, String> queryParam =
+ uriName == null ? Collections.emptyMap() : ImmutableMap.of("uriName",
uriName);
+ ModelVersionUriResponse resp =
+ restClient.get(
+ formatModelVersionRequestPath(modelFullIdent)
+ + "/aliases/"
+ + RESTUtils.encodeString(alias)
+ + "/uri",
+ queryParam,
+ ModelVersionUriResponse.class,
+ Collections.emptyMap(),
+ ErrorHandlers.modelErrorHandler());
+
+ resp.validate();
+ return resp.getUri();
+ }
+
@Override
public boolean deleteModelVersion(NameIdentifier ident, int version) {
checkModelNameIdentifier(ident);
diff --git
a/clients/client-java/src/main/java/org/apache/gravitino/client/GenericModelVersion.java
b/clients/client-java/src/main/java/org/apache/gravitino/client/GenericModelVersion.java
index 28b2d1e93e..406237ee10 100644
---
a/clients/client-java/src/main/java/org/apache/gravitino/client/GenericModelVersion.java
+++
b/clients/client-java/src/main/java/org/apache/gravitino/client/GenericModelVersion.java
@@ -37,8 +37,8 @@ class GenericModelVersion implements ModelVersion {
}
@Override
- public String uri() {
- return modelVersionDTO.uri();
+ public Map<String, String> uris() {
+ return modelVersionDTO.uris();
}
@Override
diff --git
a/clients/client-java/src/test/java/org/apache/gravitino/client/TestGenericModelCatalog.java
b/clients/client-java/src/test/java/org/apache/gravitino/client/TestGenericModelCatalog.java
index 47d00a6599..af5314a3f5 100644
---
a/clients/client-java/src/test/java/org/apache/gravitino/client/TestGenericModelCatalog.java
+++
b/clients/client-java/src/test/java/org/apache/gravitino/client/TestGenericModelCatalog.java
@@ -48,6 +48,7 @@ import org.apache.gravitino.dto.responses.ModelResponse;
import org.apache.gravitino.dto.responses.ModelVersionInfoListResponse;
import org.apache.gravitino.dto.responses.ModelVersionListResponse;
import org.apache.gravitino.dto.responses.ModelVersionResponse;
+import org.apache.gravitino.dto.responses.ModelVersionUriResponse;
import org.apache.gravitino.exceptions.ModelAlreadyExistsException;
import
org.apache.gravitino.exceptions.ModelVersionAliasesAlreadyExistException;
import org.apache.gravitino.exceptions.NoSuchModelException;
@@ -493,7 +494,10 @@ public class TestGenericModelCatalog extends TestBase {
ModelVersionLinkRequest request =
new ModelVersionLinkRequest(
- "uri", new String[] {"alias1", "alias2"}, "comment",
Collections.emptyMap());
+ ImmutableMap.of(ModelVersion.URI_NAME_UNKNOWN, "uri"),
+ new String[] {"alias1", "alias2"},
+ "comment",
+ Collections.emptyMap());
BaseResponse resp = new BaseResponse(0);
buildMockResource(Method.POST, modelVersionPath, request, resp,
HttpStatus.SC_OK);
@@ -565,6 +569,94 @@ public class TestGenericModelCatalog extends TestBase {
"internal error");
}
+ @ParameterizedTest
+ @ValueSource(strings = {"schema1/model1", "스키마1/모델1"})
+ public void testLinkModelVersionWithMultipleUris(String input) throws
JsonProcessingException {
+ String[] split = input.split("/");
+ String schemaName = split[0];
+ String modelName = split[1];
+ NameIdentifier modelId = NameIdentifier.of(schemaName, modelName);
+ String modelVersionPath =
+ withSlash(
+ GenericModelCatalog.formatModelVersionRequestPath(
+ NameIdentifier.of(METALAKE_NAME, CATALOG_NAME, schemaName,
modelName))
+ + "/versions");
+
+ Map<String, String> uris = ImmutableMap.of("n1", "u1", "n2", "u2");
+ ModelVersionLinkRequest request =
+ new ModelVersionLinkRequest(
+ uris, new String[] {"alias1", "alias2"}, "comment",
Collections.emptyMap());
+ BaseResponse resp = new BaseResponse(0);
+ buildMockResource(Method.POST, modelVersionPath, request, resp,
HttpStatus.SC_OK);
+
+ Assertions.assertDoesNotThrow(
+ () ->
+ catalog
+ .asModelCatalog()
+ .linkModelVersion(
+ modelId,
+ uris,
+ new String[] {"alias1", "alias2"},
+ "comment",
+ Collections.emptyMap()));
+
+ // Throw model not found exception
+ ErrorResponse errResp =
+ ErrorResponse.notFound(NoSuchModelException.class.getSimpleName(),
"model not found");
+ buildMockResource(Method.POST, modelVersionPath, request, errResp,
HttpStatus.SC_NOT_FOUND);
+
+ Assertions.assertThrows(
+ NoSuchModelException.class,
+ () ->
+ catalog
+ .asModelCatalog()
+ .linkModelVersion(
+ modelId,
+ uris,
+ new String[] {"alias1", "alias2"},
+ "comment",
+ Collections.emptyMap()),
+ "model not found");
+
+ // Throw ModelVersionAliasesAlreadyExistException
+ ErrorResponse errResp2 =
+ ErrorResponse.alreadyExists(
+ ModelVersionAliasesAlreadyExistException.class.getSimpleName(),
+ "model version already exists");
+ buildMockResource(Method.POST, modelVersionPath, request, errResp2,
HttpStatus.SC_CONFLICT);
+
+ Assertions.assertThrows(
+ ModelVersionAliasesAlreadyExistException.class,
+ () ->
+ catalog
+ .asModelCatalog()
+ .linkModelVersion(
+ modelId,
+ uris,
+ new String[] {"alias1", "alias2"},
+ "comment",
+ Collections.emptyMap()),
+ "model version already exists");
+
+ // Throw RuntimeException
+ ErrorResponse errResp3 = ErrorResponse.internalError("internal error");
+ buildMockResource(
+ Method.POST, modelVersionPath, request, errResp3,
HttpStatus.SC_INTERNAL_SERVER_ERROR);
+
+ Assertions.assertThrows(
+ RuntimeException.class,
+ () ->
+ catalog
+ .asModelCatalog()
+ .linkModelVersion(
+ modelId,
+ uris,
+ new String[] {"alias1", "alias2"},
+ "comment",
+ Collections.emptyMap()),
+ "internal error");
+ }
+
@ParameterizedTest
@ValueSource(strings = {"schema1/model1", "스키마1/모델1"})
public void testDeleteModelVersion(String input) throws
JsonProcessingException {
@@ -846,6 +938,147 @@ public class TestGenericModelCatalog extends TestBase {
"internal error");
}
+ @ParameterizedTest
+ @ValueSource(strings = {"schema1/model1/alias1/alias2", "스키마1/모델1/별칭1/별칭2"})
+ void testUpdateModelVersionWithMultipleUris(String input) throws
JsonProcessingException {
+ String[] split = input.split("/");
+ String schemaName = split[0];
+ String modelName = split[1];
+ String[] aliases = new String[] {split[2], split[3]};
+ String comment = "comment";
+ int version = 0;
+
+ NameIdentifier modelId = NameIdentifier.of(schemaName, modelName);
+ String modelVersionPath =
+ withSlash(
+ GenericModelCatalog.formatModelVersionRequestPath(
+ NameIdentifier.of(METALAKE_NAME, CATALOG_NAME, schemaName,
modelName))
+ + "/aliases/"
+ + RESTUtils.encodeString(aliases[0]));
+
+ // Test update uri
+ Map<String, String> uris = ImmutableMap.of("n1", "u1", "n2", "u2");
+ ModelVersionDTO mockModelVersion =
+ mockModelVersion(version, uris, aliases, comment,
Collections.emptyMap());
+ ModelVersionResponse resp = new ModelVersionResponse(mockModelVersion);
+ ModelVersionUpdateRequest.UpdateModelVersionUriRequest updateUri =
+ new ModelVersionUpdateRequest.UpdateModelVersionUriRequest("n2", "u2");
+ buildMockResource(
+ Method.PUT,
+ modelVersionPath,
+ new ModelVersionUpdatesRequest(ImmutableList.of(updateUri)),
+ resp,
+ HttpStatus.SC_OK);
+ ModelVersion updatedModelVersion =
+ catalog
+ .asModelCatalog()
+ .alterModelVersion(modelId, aliases[0],
updateUri.modelVersionChange());
+ compareModelVersion(mockModelVersion, updatedModelVersion);
+ Assertions.assertEquals(uris, updatedModelVersion.uris());
+ Assertions.assertEquals(comment, updatedModelVersion.comment());
+ Assertions.assertEquals(Collections.emptyMap(),
updatedModelVersion.properties());
+ Assertions.assertEquals(version, updatedModelVersion.version());
+ Assertions.assertArrayEquals(aliases, updatedModelVersion.aliases());
+
+ // Test add uri
+ uris = ImmutableMap.of("n1", "u1", "n2", "u2");
+ mockModelVersion = mockModelVersion(version, uris, aliases, comment,
Collections.emptyMap());
+ resp = new ModelVersionResponse(mockModelVersion);
+ ModelVersionUpdateRequest.AddModelVersionUriRequest addUri =
+ new ModelVersionUpdateRequest.AddModelVersionUriRequest("n2", "u2");
+ buildMockResource(
+ Method.PUT,
+ modelVersionPath,
+ new ModelVersionUpdatesRequest(ImmutableList.of(addUri)),
+ resp,
+ HttpStatus.SC_OK);
+ updatedModelVersion =
+ catalog
+ .asModelCatalog()
+ .alterModelVersion(modelId, aliases[0],
addUri.modelVersionChange());
+ compareModelVersion(mockModelVersion, updatedModelVersion);
+ Assertions.assertEquals(uris, updatedModelVersion.uris());
+ Assertions.assertEquals(comment, updatedModelVersion.comment());
+ Assertions.assertEquals(Collections.emptyMap(),
updatedModelVersion.properties());
+ Assertions.assertEquals(version, updatedModelVersion.version());
+ Assertions.assertArrayEquals(aliases, updatedModelVersion.aliases());
+
+ // Test remove uri
+ uris = ImmutableMap.of("n1", "u1");
+ mockModelVersion = mockModelVersion(version, uris, aliases, comment,
Collections.emptyMap());
+ resp = new ModelVersionResponse(mockModelVersion);
+ ModelVersionUpdateRequest.RemoveModelVersionUriRequest removeUri =
+ new ModelVersionUpdateRequest.RemoveModelVersionUriRequest("n2");
+ buildMockResource(
+ Method.PUT,
+ modelVersionPath,
+ new ModelVersionUpdatesRequest(ImmutableList.of(removeUri)),
+ resp,
+ HttpStatus.SC_OK);
+ updatedModelVersion =
+ catalog
+ .asModelCatalog()
+ .alterModelVersion(modelId, aliases[0],
removeUri.modelVersionChange());
+ compareModelVersion(mockModelVersion, updatedModelVersion);
+ Assertions.assertEquals(uris, updatedModelVersion.uris());
+ Assertions.assertEquals(comment, updatedModelVersion.comment());
+ Assertions.assertEquals(Collections.emptyMap(),
updatedModelVersion.properties());
+ Assertions.assertEquals(version, updatedModelVersion.version());
+ Assertions.assertArrayEquals(aliases, updatedModelVersion.aliases());
+ }
+
+ @ParameterizedTest
+ @ValueSource(strings = {"schema1/model1", "스키마1/모델1"})
+ public void testGetModelVersionUri(String input) throws
JsonProcessingException {
+ String[] split = input.split("/");
+ String schemaName = split[0];
+ String modelName = split[1];
+ NameIdentifier modelId = NameIdentifier.of(schemaName, modelName);
+
+ int version = 0;
+ String modelVersionUriPath =
+ withSlash(
+ GenericModelCatalog.formatModelVersionRequestPath(
+ NameIdentifier.of(METALAKE_NAME, CATALOG_NAME, schemaName,
modelName))
+ + "/versions/"
+ + version
+ + "/uri");
+ String uriName = "name-s3";
+ String uri = "s3://path/to/model";
+ Map<String, String> params = ImmutableMap.of("uriName", uriName);
+ ModelVersionUriResponse resp = new ModelVersionUriResponse(uri);
+ buildMockResource(Method.GET, modelVersionUriPath, params, null, resp,
HttpStatus.SC_OK);
+
+ Assertions.assertEquals(
+ uri, catalog.asModelCatalog().getModelVersionUri(modelId, version,
uriName));
+ }
+
+ @ParameterizedTest
+ @ValueSource(strings = {"schema1/model1", "스키마1/모델1"})
+ public void testGetModelVersionUriByAlias(String input) throws
JsonProcessingException {
+ String[] split = input.split("/");
+ String schemaName = split[0];
+ String modelName = split[1];
+ NameIdentifier modelId = NameIdentifier.of(schemaName, modelName);
+
+ String alias = "alias1";
+ String modelVersionUriPath =
+ withSlash(
+ GenericModelCatalog.formatModelVersionRequestPath(
+ NameIdentifier.of(METALAKE_NAME, CATALOG_NAME, schemaName,
modelName))
+ + "/aliases/"
+ + alias
+ + "/uri");
+ String uriName = "name-s3";
+ String uri = "s3://path/to/model";
+ Map<String, String> params = ImmutableMap.of("uriName", uriName);
+ ModelVersionUriResponse resp = new ModelVersionUriResponse(uri);
+ buildMockResource(Method.GET, modelVersionUriPath, params, null, resp,
HttpStatus.SC_OK);
+
+ Assertions.assertEquals(
+ uri, catalog.asModelCatalog().getModelVersionUri(modelId, alias,
uriName));
+ }
+
private ModelDTO mockModelDTO(
String modelName, int latestVersion, String comment, Map<String, String>
properties) {
return ModelDTO.builder()
@@ -859,9 +1092,19 @@ public class TestGenericModelCatalog extends TestBase {
private ModelVersionDTO mockModelVersion(
int version, String uri, String[] aliases, String comment, Map<String,
String> properties) {
+ return mockModelVersion(
+ version, ImmutableMap.of(ModelVersion.URI_NAME_UNKNOWN, uri), aliases,
comment, properties);
+ }
+
+ private ModelVersionDTO mockModelVersion(
+ int version,
+ Map<String, String> uris,
+ String[] aliases,
+ String comment,
+ Map<String, String> properties) {
return ModelVersionDTO.builder()
.withVersion(version)
- .withUris(ImmutableMap.of(ModelVersion.URI_NAME_UNKNOWN, uri))
+ .withUris(uris)
.withAliases(aliases)
.withComment(comment)
.withProperties(properties)
@@ -879,6 +1122,7 @@ public class TestGenericModelCatalog extends TestBase {
private void compareModelVersion(ModelVersion expect, ModelVersion result) {
Assertions.assertEquals(expect.version(), result.version());
Assertions.assertEquals(expect.uri(), result.uri());
+ Assertions.assertEquals(expect.uris(), result.uris());
Assertions.assertArrayEquals(expect.aliases(), result.aliases());
Assertions.assertEquals(expect.comment(), result.comment());
Assertions.assertEquals(expect.properties(), result.properties());
diff --git a/clients/client-python/gravitino/api/model_version_change.py
b/clients/client-python/gravitino/api/model_version_change.py
index 6ca2aff18b..cdbe2cef2f 100644
--- a/clients/client-python/gravitino/api/model_version_change.py
+++ b/clients/client-python/gravitino/api/model_version_change.py
@@ -56,14 +56,36 @@ class ModelVersionChange(ABC):
return ModelVersionChange.RemoveProperty(key)
@staticmethod
- def update_uri(uri: str):
+ def update_uri(uri: str, uri_name: str = None):
"""Creates a new model version change to update the uri of the model
version.
Args:
uri: The new uri of the model version.
+ uri_name: The uri name of the model version to be updated.
Returns:
The model version change.
"""
- return ModelVersionChange.UpdateUri(uri)
+ return ModelVersionChange.UpdateUri(uri, uri_name)
+
+ @staticmethod
+ def add_uri(uri_name: str, uri: str):
+ """Creates a new model version change to add the uri of the model
version.
+ Args:
+ uri_name: The uri name of the model version to be added.
+ uri: The new uri of the model version to be added.
+ Returns:
+ The model version change.
+ """
+ return ModelVersionChange.AddUri(uri_name, uri)
+
+ @staticmethod
+ def remove_uri(uri_name: str):
+ """Creates a new model version change to remove the uri of the model
version.
+ Args:
+ uri_name: The uri name of the model version to be removed.
+ Returns:
+ The model version change.
+ """
+ return ModelVersionChange.RemoveUri(uri_name)
@staticmethod
def update_aliases(aliases_to_add, aliases_to_remove):
@@ -213,8 +235,9 @@ class ModelVersionChange(ABC):
class UpdateUri:
"""A model version change to update the URI of the model version."""
- def __init__(self, new_uri: str):
+ def __init__(self, new_uri: str, uri_name: str = None):
self._new_uri = new_uri
+ self._uri_name = uri_name
def new_uri(self) -> str:
"""Retrieves the new URI of the model version.
@@ -223,9 +246,16 @@ class ModelVersionChange(ABC):
"""
return self._new_uri
+ def uri_name(self) -> str:
+ """Retrieves the URI name of the model version to be updated.
+ Returns:
+ The URI name of the model version to be updated.
+ """
+ return self._uri_name
+
def __eq__(self, other):
"""Compares this UpdateUri instance with another object for
equality. Two instances are
- considered equal if they designate the same new URI for the model
version.
+ considered equal if they designate the same new URI and URI name
for the model version.
Args:
other: The object to compare with this instance.
Returns:
@@ -234,23 +264,118 @@ class ModelVersionChange(ABC):
"""
if not isinstance(other, ModelVersionChange.UpdateUri):
return False
- return self.new_uri() == other.new_uri()
+ return (
+ self.new_uri() == other.new_uri()
+ and self.uri_name() == other.uri_name()
+ )
def __hash__(self):
"""Generates a hash code for this UpdateUri instance. The hash
code is primarily based on
- the new URI for the model version.
+ the new URI and its name for the model version.
Returns:
A hash code value for this URI update operation.
"""
- return hash(self.new_uri())
+ return hash((self.new_uri(), self.uri_name()))
def __str__(self):
"""Provides a string representation of the UpdateUri instance.
This string includes the
- class name followed by the new URI of the model version.
+ class name followed by the new URI and its name of the model
version.
Returns:
A string summary of this URI update operation.
"""
- return f"UpdateUri {self._new_uri}"
+ return f"UpdateUri uriName: ({self._uri_name}) newUri:
({self._new_uri})"
+
+ class AddUri:
+ """A ModelVersionChange to add a uri of the model version."""
+
+ def __init__(self, uri_name: str, uri: str):
+ self._uri_name = uri_name
+ self._uri = uri
+
+ def uri_name(self) -> str:
+ """Retrieves the URI name of the model version to be added.
+ Returns:
+ The URI name of the model version to be added.
+ """
+ return self._uri_name
+
+ def uri(self) -> str:
+ """Retrieves the URI of the model version to be added.
+ Returns:
+ The new URI of the model version to be added.
+ """
+ return self._uri
+
+ def __eq__(self, other):
+ """Compares this AddUri instance with another object for equality.
Two instances are
+ considered equal if they designate the same URI and URI name for
the model version.
+ Args:
+ other: The object to compare with this instance.
+ Returns:
+ true if the given object represents an identical model version
URI add operation;
+ false otherwise.
+ """
+ if not isinstance(other, ModelVersionChange.AddUri):
+ return False
+ return self.uri_name() == other.uri_name() and self.uri() ==
other.uri()
+
+ def __hash__(self):
+ """Generates a hash code for this AddUri instance. The hash code
is primarily based on
+ the URI and its name for the model version.
+ Returns:
+ A hash code value for this URI add operation.
+ """
+ return hash((self.uri_name(), self.uri()))
+
+ def __str__(self):
+ """Provides a string representation of the AddUri instance. This
string includes the
+ class name followed by the URI and its name of the model version.
+ Returns:
+ A string summary of this URI add operation.
+ """
+ return f"AddUri uriName: ({self._uri_name}) uri: ({self.uri})"
+
+ class RemoveUri:
+ """A ModelVersionChange to remove a uri of the model version."""
+
+ def __init__(self, uri_name: str):
+ self._uri_name = uri_name
+
+ def uri_name(self) -> str:
+ """Retrieves the URI name of the model version to be removed.
+ Returns:
+ The URI name of the model version to be removed.
+ """
+ return self._uri_name
+
+ def __eq__(self, other):
+ """Compares this RemoveUri instance with another object for
equality. Two instances are
+ considered equal if they designate the same URI name for the model
version.
+ Args:
+ other: The object to compare with this instance.
+ Returns:
+ true if the given object represents an identical model version
URI remove operation;
+ false otherwise.
+ """
+ if not isinstance(other, ModelVersionChange.RemoveUri):
+ return False
+ return self.uri_name() == other.uri_name()
+
+ def __hash__(self):
+ """Generates a hash code for this RemoveUri instance. The hash
code is primarily based on
+ the URI name for the model version.
+ Returns:
+ A hash code value for this URI remove operation.
+ """
+ return hash(self.uri_name())
+
+ def __str__(self):
+ """Provides a string representation of the RemoveUri instance.
This string includes the
+ class name followed by the URI name of the model version.
+ Returns:
+ A string summary of this URI remove operation.
+ """
+ return f"RemoveUri uriName: ({self._uri_name})"
class UpdateAliases:
"""A model version change to update the aliases of the model
version."""
diff --git a/clients/client-python/gravitino/client/generic_model_catalog.py
b/clients/client-python/gravitino/client/generic_model_catalog.py
index 860eace840..e71287a1fc 100644
--- a/clients/client-python/gravitino/client/generic_model_catalog.py
+++ b/clients/client-python/gravitino/client/generic_model_catalog.py
@@ -45,6 +45,7 @@ from gravitino.dto.responses.model_version_list_response
import (
ModelVersionListResponse,
)
from gravitino.dto.responses.model_version_response import ModelVersionResponse
+from gravitino.dto.responses.model_version_uri_response import
ModelVersionUriResponse
from gravitino.exceptions.handlers.model_error_handler import
MODEL_ERROR_HANDLER
from gravitino.name_identifier import NameIdentifier
from gravitino.namespace import Namespace
@@ -450,20 +451,11 @@ class GenericModelCatalog(BaseSchemaCatalog):
NoSuchModelException: If the model does not exist.
ModelVersionAliasesAlreadyExistException: If the aliases of the
model version already exist.
"""
- self._check_model_ident(model_ident)
-
- model_full_ident = self._model_full_identifier(model_ident)
+ uris = {ModelVersion.URI_NAME_UNKNOWN: uri} if uri else {}
- request = ModelVersionLinkRequest(uri, comment, aliases, properties)
- request.validate()
-
- resp = self.rest_client.post(
-
f"{self._format_model_version_request_path(model_full_ident)}/versions",
- request,
- error_handler=MODEL_ERROR_HANDLER,
+ return self.link_model_version_with_multiple_uris(
+ model_ident, uris, aliases, comment, properties
)
- base_resp = BaseResponse.from_json(resp.body, infer_missing=True)
- base_resp.validate()
def link_model_version_with_multiple_uris(
self,
@@ -491,7 +483,20 @@ class GenericModelCatalog(BaseSchemaCatalog):
NoSuchModelException: If the model does not exist.
ModelVersionAliasesAlreadyExistException: If the aliases of the
model version already exist.
"""
- raise NotImplementedError("Not supported yet")
+ self._check_model_ident(model_ident)
+
+ model_full_ident = self._model_full_identifier(model_ident)
+
+ request = ModelVersionLinkRequest(uris, comment, aliases, properties)
+ request.validate()
+
+ resp = self.rest_client.post(
+
f"{self._format_model_version_request_path(model_full_ident)}/versions",
+ request,
+ error_handler=MODEL_ERROR_HANDLER,
+ )
+ base_resp = BaseResponse.from_json(resp.body, infer_missing=True)
+ base_resp.validate()
def get_model_version_uri(
self, model_ident: NameIdentifier, version: int, uri_name: str = None
@@ -510,7 +515,25 @@ class GenericModelCatalog(BaseSchemaCatalog):
Returns:
The URI of the model version.
"""
- raise NotImplementedError("Not supported yet")
+ self._check_model_ident(model_ident)
+
+ model_full_ident = self._model_full_identifier(model_ident)
+ params = {}
+ if uri_name is not None:
+ params["uriName"] = encode_string(uri_name)
+
+ resp = self.rest_client.get(
+
f"{self._format_model_version_request_path(model_full_ident)}/versions/{version}/uri",
+ params=params,
+ error_handler=MODEL_ERROR_HANDLER,
+ )
+
+ model_version_uri_resp = ModelVersionUriResponse.from_json(
+ resp.body, infer_missing=True
+ )
+ model_version_uri_resp.validate()
+
+ return model_version_uri_resp.uri()
def get_model_version_uri_by_alias(
self, model_ident: NameIdentifier, alias: str, uri_name: str = None
@@ -529,7 +552,25 @@ class GenericModelCatalog(BaseSchemaCatalog):
Returns:
The URI of the model version.
"""
- raise NotImplementedError("Not supported yet")
+ self._check_model_ident(model_ident)
+
+ model_full_ident = self._model_full_identifier(model_ident)
+ params = {}
+ if uri_name is not None:
+ params["uriName"] = encode_string(uri_name)
+
+ resp = self.rest_client.get(
+
f"{self._format_model_version_request_path(model_full_ident)}/aliases/{encode_string(alias)}/uri",
+ params=params,
+ error_handler=MODEL_ERROR_HANDLER,
+ )
+
+ model_version_uri_resp = ModelVersionUriResponse.from_json(
+ resp.body, infer_missing=True
+ )
+ model_version_uri_resp.validate()
+
+ return model_version_uri_resp.uri()
def delete_model_version(self, model_ident: NameIdentifier, version: int)
-> bool:
"""Delete the model version from the catalog. If the model version
does not exist, return false.
@@ -608,9 +649,10 @@ class GenericModelCatalog(BaseSchemaCatalog):
Returns:
The registered model object.
"""
- model = self.register_model(ident, comment, properties)
- self.link_model_version(ident, uri, aliases, comment, properties)
- return model
+ uris = {ModelVersion.URI_NAME_UNKNOWN: uri} if uri else {}
+ return self.register_model_version_with_multiple_uris(
+ ident, uris, aliases, comment, properties
+ )
def register_model_version_with_multiple_uris(
self,
@@ -664,6 +706,7 @@ class GenericModelCatalog(BaseSchemaCatalog):
raise ValueError(f"Unknown change type: {type(change).__name__}")
+ # pylint: disable=too-many-return-statements
@staticmethod
def to_model_version_update_request(change: ModelVersionChange):
if isinstance(change, ModelVersionChange.UpdateComment):
@@ -683,7 +726,17 @@ class GenericModelCatalog(BaseSchemaCatalog):
if isinstance(change, ModelVersionChange.UpdateUri):
return ModelVersionUpdateRequest.UpdateModelVersionUriRequest(
- change.new_uri()
+ change.new_uri(), change.uri_name()
+ )
+
+ if isinstance(change, ModelVersionChange.AddUri):
+ return ModelVersionUpdateRequest.AddModelVersionUriRequest(
+ change.uri_name(), change.uri()
+ )
+
+ if isinstance(change, ModelVersionChange.RemoveUri):
+ return ModelVersionUpdateRequest.RemoveModelVersionUriRequest(
+ change.uri_name()
)
if isinstance(change, ModelVersionChange.UpdateAliases):
diff --git a/clients/client-python/gravitino/client/generic_model_version.py
b/clients/client-python/gravitino/client/generic_model_version.py
index 0c13f53668..9c0cc977a1 100644
--- a/clients/client-python/gravitino/client/generic_model_version.py
+++ b/clients/client-python/gravitino/client/generic_model_version.py
@@ -37,9 +37,6 @@ class GenericModelVersion(ModelVersion):
def aliases(self) -> List[str]:
return self._model_version_dto.aliases()
- def uri(self) -> str:
- return self._model_version_dto.uri()
-
def uris(self) -> Dict[str, str]:
return self._model_version_dto.uris()
diff --git a/clients/client-python/gravitino/dto/model_version_dto.py
b/clients/client-python/gravitino/dto/model_version_dto.py
index 939eb2e347..fca1a62f4a 100644
--- a/clients/client-python/gravitino/dto/model_version_dto.py
+++ b/clients/client-python/gravitino/dto/model_version_dto.py
@@ -31,7 +31,7 @@ class ModelVersionDTO(ModelVersion, DataClassJsonMixin):
_version: int = field(metadata=config(field_name="version"))
_comment: Optional[str] = field(metadata=config(field_name="comment"))
_aliases: Optional[List[str]] =
field(metadata=config(field_name="aliases"))
- _uri: str = field(metadata=config(field_name="uri"))
+ _uris: Dict[str, str] = field(metadata=config(field_name="uris"))
_properties: Optional[Dict[str, str]] = field(
metadata=config(field_name="properties")
)
@@ -46,11 +46,8 @@ class ModelVersionDTO(ModelVersion, DataClassJsonMixin):
def aliases(self) -> Optional[List[str]]:
return self._aliases
- def uri(self) -> str:
- return self._uri
-
def uris(self) -> Dict[str, str]:
- raise NotImplementedError("Not supported yet")
+ return self._uris
def properties(self) -> Optional[Dict[str, str]]:
return self._properties
diff --git
a/clients/client-python/gravitino/dto/requests/model_version_link_request.py
b/clients/client-python/gravitino/dto/requests/model_version_link_request.py
index 98b1c45514..a5b02a405d 100644
--- a/clients/client-python/gravitino/dto/requests/model_version_link_request.py
+++ b/clients/client-python/gravitino/dto/requests/model_version_link_request.py
@@ -27,7 +27,7 @@ from gravitino.rest.rest_message import RESTRequest
class ModelVersionLinkRequest(RESTRequest):
"""Represents a request to link a model version to a model."""
- _uri: str = field(metadata=config(field_name="uri"))
+ _uris: Dict[str, str] = field(metadata=config(field_name="uris"))
_comment: Optional[str] = field(metadata=config(field_name="comment"))
_aliases: Optional[List[str]] =
field(metadata=config(field_name="aliases"))
_properties: Optional[Dict[str, str]] = field(
@@ -36,12 +36,12 @@ class ModelVersionLinkRequest(RESTRequest):
def __init__(
self,
- uri: str,
+ uris: Dict[str, str],
comment: Optional[str] = None,
aliases: Optional[List[str]] = None,
properties: Optional[Dict[str, str]] = None,
):
- self._uri = uri
+ self._uris = uris
self._comment = comment
self._aliases = aliases
self._properties = properties
@@ -52,11 +52,17 @@ class ModelVersionLinkRequest(RESTRequest):
Raises:
IllegalArgumentException if the request is invalid
"""
- if not self._is_not_blank(self._uri):
+ if not self._uris:
raise IllegalArgumentException(
- '"uri" field is required and cannot be empty'
+ '"uris" field is required and cannot be empty'
)
+ for key, value in self._uris.items():
+ if not self._is_not_blank(key):
+ raise IllegalArgumentException("uri name must not be null or
empty")
+ if not self._is_not_blank(value):
+ raise IllegalArgumentException("uri must not be null or empty")
+
for alias in self._aliases or []:
if not self._is_not_blank(alias):
raise IllegalArgumentException("Alias must not be null or
empty")
diff --git
a/clients/client-python/gravitino/dto/requests/model_version_update_request.py
b/clients/client-python/gravitino/dto/requests/model_version_update_request.py
index 962d0634b3..644cd966c6 100644
---
a/clients/client-python/gravitino/dto/requests/model_version_update_request.py
+++
b/clients/client-python/gravitino/dto/requests/model_version_update_request.py
@@ -113,11 +113,12 @@ class ModelVersionUpdateRequest:
"""Request to update model version uri"""
_new_uri: Optional[str] = field(metadata=config(field_name="newUri"))
- """Represents a request to update the uri on a Metalake."""
+ _uri_name: Optional[str] = field(metadata=config(field_name="uriName"))
- def __init__(self, new_uri: str):
+ def __init__(self, new_uri: str, uri_name: str):
super().__init__("updateUri")
self._new_uri = new_uri
+ self._uri_name = uri_name
def new_uri(self):
"""Retrieves the new uri of the model version.
@@ -126,8 +127,15 @@ class ModelVersionUpdateRequest:
"""
return self._new_uri
+ def uri_name(self):
+ """Retrieves the uri name of the model version.
+ Returns:
+ The uri name of the model version.
+ """
+ return self._uri_name
+
def validate(self):
- """Validates the fields of the request. Always pass."""
+ """Validates the fields of the request."""
if not self._new_uri:
raise ValueError('"newUri" field is required')
@@ -137,7 +145,78 @@ class ModelVersionUpdateRequest:
Returns:
ModelVersionChange: The ModelVersionChange object representing
the update uri operation.
"""
- return ModelVersionChange.update_uri(self._new_uri)
+ return ModelVersionChange.update_uri(self._new_uri, self._uri_name)
+
+ @dataclass
+ class AddModelVersionUriRequest(ModelVersionUpdateRequestBase):
+ """Request to add model version uri"""
+
+ _uri_name: Optional[str] = field(metadata=config(field_name="uriName"))
+ _uri: Optional[str] = field(metadata=config(field_name="uri"))
+
+ def __init__(self, uri_name: str, uri: str):
+ super().__init__("addUri")
+ self._uri_name = uri_name
+ self._uri = uri
+
+ def uri_name(self):
+ """Retrieves the uri name of the model version.
+ Returns:
+ The uri name of the model version.
+ """
+ return self._uri_name
+
+ def uri(self):
+ """Retrieves the uri of the model version.
+ Returns:
+ The uri of the model version.
+ """
+ return self._uri
+
+ def validate(self):
+ """Validates the fields of the request."""
+ if not self._uri_name:
+ raise ValueError('"uriName" field is required')
+ if not self._uri:
+ raise ValueError('"uri" field is required')
+
+ def model_version_change(self):
+ """
+ Returns a ModelVersionChange object representing the add uri
operation.
+ Returns:
+ ModelVersionChange: The ModelVersionChange object representing
the add uri operation.
+ """
+ return ModelVersionChange.add_uri(self._uri_name, self._uri)
+
+ @dataclass
+ class RemoveModelVersionUriRequest(ModelVersionUpdateRequestBase):
+ """Request to remove model version uri"""
+
+ _uri_name: Optional[str] = field(metadata=config(field_name="uriName"))
+
+ def __init__(self, uri_name: str):
+ super().__init__("removeUri")
+ self._uri_name = uri_name
+
+ def uri_name(self):
+ """Retrieves the uri name of the model version.
+ Returns:
+ The uri name of the model version.
+ """
+ return self._uri_name
+
+ def validate(self):
+ """Validates the fields of the request."""
+ if not self._uri_name:
+ raise ValueError('"uriName" field is required')
+
+ def model_version_change(self):
+ """
+ Returns a ModelVersionChange object representing the remove uri
operation.
+ Returns:
+ ModelVersionChange: The ModelVersionChange object representing
the remove uri operation.
+ """
+ return ModelVersionChange.remove_uri(self._uri_name)
@dataclass
class ModelVersionAliasesRequest(ModelVersionUpdateRequestBase):
diff --git
a/clients/client-python/gravitino/dto/responses/model_version_list_response.py
b/clients/client-python/gravitino/dto/responses/model_version_list_response.py
index 84986c00ec..90b158f236 100644
---
a/clients/client-python/gravitino/dto/responses/model_version_list_response.py
+++
b/clients/client-python/gravitino/dto/responses/model_version_list_response.py
@@ -73,7 +73,7 @@ class ModelVersionInfoListResponse(BaseResponse):
raise IllegalArgumentException(
"Model version 'version' must not be null"
)
- if version.uri() is None:
+ if version.uris() is None:
raise IllegalArgumentException("Model version 'uri' must not
be null")
if version.audit_info() is None:
raise IllegalArgumentException(
diff --git
a/clients/client-python/gravitino/dto/responses/model_version_response.py
b/clients/client-python/gravitino/dto/responses/model_version_response.py
index 0c0101d6f9..ee16067b6d 100644
--- a/clients/client-python/gravitino/dto/responses/model_version_response.py
+++ b/clients/client-python/gravitino/dto/responses/model_version_response.py
@@ -45,7 +45,7 @@ class ModelVersionResponse(BaseResponse):
raise IllegalArgumentException("Model version must not be null")
if self._model_version.version() is None:
raise IllegalArgumentException("Model version 'version' must not
be null")
- if self._model_version.uri() is None:
+ if self._model_version.uris() is None:
raise IllegalArgumentException("Model version 'uri' must not be
null")
if self._model_version.audit_info() is None:
raise IllegalArgumentException("Model version 'auditInfo' must not
be null")
diff --git
a/clients/client-python/gravitino/dto/responses/model_version_response.py
b/clients/client-python/gravitino/dto/responses/model_version_uri_response.py
similarity index 52%
copy from
clients/client-python/gravitino/dto/responses/model_version_response.py
copy to
clients/client-python/gravitino/dto/responses/model_version_uri_response.py
index 0c0101d6f9..515faabe36 100644
--- a/clients/client-python/gravitino/dto/responses/model_version_response.py
+++
b/clients/client-python/gravitino/dto/responses/model_version_uri_response.py
@@ -14,38 +14,29 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from dataclasses import field, dataclass
-from dataclasses_json import config
+from dataclasses import dataclass, field
-from gravitino.dto.model_version_dto import ModelVersionDTO
+from dataclasses_json import config
from gravitino.dto.responses.base_response import BaseResponse
from gravitino.exceptions.base import IllegalArgumentException
@dataclass
-class ModelVersionResponse(BaseResponse):
- """Represents a response for a model version."""
+class ModelVersionUriResponse(BaseResponse):
+ """Response for the model version uri."""
- _model_version: ModelVersionDTO =
field(metadata=config(field_name="modelVersion"))
+ _uri: str = field(metadata=config(field_name="uri"))
- def model_version(self) -> ModelVersionDTO:
- """Returns the model version."""
- return self._model_version
+ def uri(self) -> str:
+ return self._uri
def validate(self):
"""Validates the response data.
Raises:
- IllegalArgumentException if the model version is not set.
+ IllegalArgumentException if model version uri is not set.
"""
super().validate()
-
- if self._model_version is None:
- raise IllegalArgumentException("Model version must not be null")
- if self._model_version.version() is None:
- raise IllegalArgumentException("Model version 'version' must not
be null")
- if self._model_version.uri() is None:
- raise IllegalArgumentException("Model version 'uri' must not be
null")
- if self._model_version.audit_info() is None:
- raise IllegalArgumentException("Model version 'auditInfo' must not
be null")
+ if self._uri is None or len(self.uri()) == 0:
+ raise IllegalArgumentException("Model version uri must not be
null")
diff --git a/clients/client-python/tests/integration/test_model_catalog.py
b/clients/client-python/tests/integration/test_model_catalog.py
index 234d604709..a73c6e7efc 100644
--- a/clients/client-python/tests/integration/test_model_catalog.py
+++ b/clients/client-python/tests/integration/test_model_catalog.py
@@ -25,6 +25,7 @@ from gravitino.exceptions.base import (
NoSuchModelException,
NoSuchModelVersionException,
NoSuchSchemaException,
+ NoSuchModelVersionURINameException,
)
from gravitino.namespace import Namespace
from tests.integration.integration_test_env import IntegrationTestEnv
@@ -440,6 +441,84 @@ class TestModelCatalog(IntegrationTestEnv):
self.assertEqual("comment", updated_model_version.comment())
self.assertEqual({"k1": "v1", "k2": "v2"},
updated_model_version.properties())
+ def test_link_add_model_version_uri(self):
+ model_name = f"model_it_model{str(randint(0, 1000))}"
+ model_ident = NameIdentifier.of(self._schema_name, model_name)
+ aliases = ["alias1", "alias2"]
+ comment = "comment"
+ properties = {"k1": "v1", "k2": "v2"}
+ self._catalog.as_model_catalog().register_model(
+ model_ident, comment, properties
+ )
+ self._catalog.as_model_catalog().link_model_version_with_multiple_uris(
+ model_ident,
+ uris={"n1": "u1"},
+ aliases=aliases,
+ comment="comment",
+ properties={"k1": "v1", "k2": "v2"},
+ )
+
+ original_model_version =
self._catalog.as_model_catalog().get_model_version(
+ model_ident, 0
+ )
+
+ self.assertEqual(0, original_model_version.version())
+ self.assertEqual({"n1": "u1"}, original_model_version.uris())
+ self.assertEqual(["alias1", "alias2"],
original_model_version.aliases())
+ self.assertEqual("comment", original_model_version.comment())
+ self.assertEqual({"k1": "v1", "k2": "v2"},
original_model_version.properties())
+
+ changes = [ModelVersionChange.add_uri("n2", "u2")]
+ self._catalog.as_model_catalog().alter_model_version(model_ident, 0,
*changes)
+
+ updated_model_version =
self._catalog.as_model_catalog().get_model_version(
+ model_ident, 0
+ )
+ self.assertEqual(0, updated_model_version.version())
+ self.assertEqual({"n1": "u1", "n2": "u2"},
updated_model_version.uris())
+ self.assertEqual(["alias1", "alias2"], updated_model_version.aliases())
+ self.assertEqual("comment", updated_model_version.comment())
+ self.assertEqual({"k1": "v1", "k2": "v2"},
updated_model_version.properties())
+
+ def test_link_remove_model_version_uri(self):
+ model_name = f"model_it_model{str(randint(0, 1000))}"
+ model_ident = NameIdentifier.of(self._schema_name, model_name)
+ aliases = ["alias1", "alias2"]
+ comment = "comment"
+ properties = {"k1": "v1", "k2": "v2"}
+ self._catalog.as_model_catalog().register_model(
+ model_ident, comment, properties
+ )
+ self._catalog.as_model_catalog().link_model_version_with_multiple_uris(
+ model_ident,
+ uris={"n1": "u1", "n2": "u2"},
+ aliases=aliases,
+ comment="comment",
+ properties={"k1": "v1", "k2": "v2"},
+ )
+
+ original_model_version =
self._catalog.as_model_catalog().get_model_version(
+ model_ident, 0
+ )
+
+ self.assertEqual(0, original_model_version.version())
+ self.assertEqual({"n1": "u1", "n2": "u2"},
original_model_version.uris())
+ self.assertEqual(["alias1", "alias2"],
original_model_version.aliases())
+ self.assertEqual("comment", original_model_version.comment())
+ self.assertEqual({"k1": "v1", "k2": "v2"},
original_model_version.properties())
+
+ changes = [ModelVersionChange.remove_uri("n1")]
+ self._catalog.as_model_catalog().alter_model_version(model_ident, 0,
*changes)
+
+ updated_model_version =
self._catalog.as_model_catalog().get_model_version(
+ model_ident, 0
+ )
+ self.assertEqual(0, updated_model_version.version())
+ self.assertEqual({"n2": "u2"}, updated_model_version.uris())
+ self.assertEqual(["alias1", "alias2"], updated_model_version.aliases())
+ self.assertEqual("comment", updated_model_version.comment())
+ self.assertEqual({"k1": "v1", "k2": "v2"},
updated_model_version.properties())
+
def test_link_update_model_version_aliases(self):
model_name = f"model_it_model{str(randint(0, 1000))}"
model_ident = NameIdentifier.of(self._schema_name, model_name)
@@ -662,6 +741,120 @@ class TestModelCatalog(IntegrationTestEnv):
NameIdentifier.of(self._schema_name, "non_existent_model")
)
+ def test_link_model_version_with_multiple_uris(self):
+ model_name = "model_it_model" + str(randint(0, 1000))
+ model_ident = NameIdentifier.of(self._schema_name, model_name)
+ self._catalog.as_model_catalog().register_model(model_ident,
"comment", {})
+
+ # Test link model version
+ self._catalog.as_model_catalog().link_model_version_with_multiple_uris(
+ model_ident,
+ uris={"n1": "u1", "n2": "u2"},
+ aliases=["alias1", "alias2"],
+ comment="comment",
+ properties={"k1": "v1", "k2": "v2"},
+ )
+
+ # Test get model version
+ model_version = self._catalog.as_model_catalog().get_model_version(
+ model_ident, 0
+ )
+ self.assertEqual(0, model_version.version())
+ self.assertEqual({"n1": "u1", "n2": "u2"}, model_version.uris())
+ self.assertEqual(["alias1", "alias2"], model_version.aliases())
+ self.assertEqual("comment", model_version.comment())
+ self.assertEqual({"k1": "v1", "k2": "v2"}, model_version.properties())
+
+ # Test get model version by alias
+ model_version =
self._catalog.as_model_catalog().get_model_version_by_alias(
+ model_ident, "alias1"
+ )
+ self.assertEqual(0, model_version.version())
+ self.assertEqual({"n1": "u1", "n2": "u2"}, model_version.uris())
+
+ model_version =
self._catalog.as_model_catalog().get_model_version_by_alias(
+ model_ident, "alias2"
+ )
+ self.assertEqual(0, model_version.version())
+ self.assertEqual({"n1": "u1", "n2": "u2"}, model_version.uris())
+
+ # Test list model versions
+ model_versions = self._catalog.as_model_catalog().list_model_versions(
+ model_ident
+ )
+ self.assertEqual(1, len(model_versions))
+ self.assertTrue(0 in model_versions)
+
+ # Test list model version infos
+ model_versions = self._catalog.as_model_catalog().list_model_versions(
+ model_ident
+ )
+ self.assertEqual(1, len(model_versions))
+ self.assertTrue(0 in model_versions)
+ model_versions =
self._catalog.as_model_catalog().list_model_version_infos(
+ model_ident
+ )
+ self.assertEqual(1, len(model_versions))
+ self.assertEqual(0, model_versions[0].version())
+ self.assertEqual({"n1": "u1", "n2": "u2"}, model_versions[0].uris())
+ self.assertEqual("comment", model_versions[0].comment())
+ self.assertEqual(["alias1", "alias2"], model_versions[0].aliases())
+ self.assertEqual({"k1": "v1", "k2": "v2"},
model_versions[0].properties())
+
+ def test_get_model_version_uri(self):
+ model_name = "model_it_model" + str(randint(0, 1000))
+ model_ident = NameIdentifier.of(self._schema_name, model_name)
+ self._catalog.as_model_catalog().register_model(model_ident,
"comment", {})
+
+ # link model version
+ self._catalog.as_model_catalog().link_model_version_with_multiple_uris(
+ model_ident,
+ uris={"n1": "u1", "n2": "u2"},
+ aliases=["alias1", "alias2"],
+ comment="comment",
+ properties={"k1": "v1", "k2": "v2"},
+ )
+
+ # Test get model version
+ model_version = self._catalog.as_model_catalog().get_model_version(
+ model_ident, 0
+ )
+ self.assertEqual(0, model_version.version())
+ self.assertEqual({"n1": "u1", "n2": "u2"}, model_version.uris())
+ self.assertEqual(["alias1", "alias2"], model_version.aliases())
+ self.assertEqual("comment", model_version.comment())
+ self.assertEqual({"k1": "v1", "k2": "v2"}, model_version.properties())
+
+ # Test get model version uri
+ model_version_uri =
self._catalog.as_model_catalog().get_model_version_uri(
+ model_ident, 0, "n1"
+ )
+ self.assertEqual("u1", model_version_uri)
+ model_version_uri =
self._catalog.as_model_catalog().get_model_version_uri(
+ model_ident, 0, "n2"
+ )
+ self.assertEqual("u2", model_version_uri)
+ with self.assertRaises(NoSuchModelVersionURINameException):
+
self._catalog.as_model_catalog().get_model_version_uri(model_ident, 0, "n3")
+
+ # Test get model version uri by alias
+ model_version_uri = (
+ self._catalog.as_model_catalog().get_model_version_uri_by_alias(
+ model_ident, "alias1", "n1"
+ )
+ )
+ self.assertEqual("u1", model_version_uri)
+ model_version_uri = (
+ self._catalog.as_model_catalog().get_model_version_uri_by_alias(
+ model_ident, "alias1", "n2"
+ )
+ )
+ self.assertEqual("u2", model_version_uri)
+ with self.assertRaises(NoSuchModelVersionURINameException):
+ self._catalog.as_model_catalog().get_model_version_uri_by_alias(
+ model_ident, "alias1", "n3"
+ )
+
def test_link_delete_model_version(self):
model_name = "model_it_model" + str(randint(0, 1000))
model_ident = NameIdentifier.of(self._schema_name, model_name)
diff --git a/clients/client-python/tests/unittests/test_model_catalog_api.py
b/clients/client-python/tests/unittests/test_model_catalog_api.py
index 540562367f..663eb439b0 100644
--- a/clients/client-python/tests/unittests/test_model_catalog_api.py
+++ b/clients/client-python/tests/unittests/test_model_catalog_api.py
@@ -35,6 +35,7 @@ from gravitino.dto.responses.model_version_list_response
import (
ModelVersionListResponse,
)
from gravitino.dto.responses.model_version_response import ModelVersionResponse
+from gravitino.dto.responses.model_version_uri_response import
ModelVersionUriResponse
from gravitino.namespace import Namespace
from gravitino.utils import Response
from tests.unittests import mock_base
@@ -274,7 +275,7 @@ class TestModelCatalogApi(unittest.TestCase):
model_versions_dto = [
ModelVersionDTO(
_version=0,
- _uri="http://localhost:8090",
+ _uris={"unknown": "http://localhost:8090"},
_aliases=["alias1", "alias2"],
_comment="this is test",
_properties={"k": "v"},
@@ -325,7 +326,7 @@ class TestModelCatalogApi(unittest.TestCase):
model_version_dto = ModelVersionDTO(
_version=1,
- _uri="http://localhost:8090",
+ _uris={"unknown": "http://localhost:8090"},
_aliases=["alias1", "alias2"],
_comment="new comment",
_properties={"k": "v"},
@@ -363,7 +364,7 @@ class TestModelCatalogApi(unittest.TestCase):
## test with response
model_version_dto = ModelVersionDTO(
_version=1,
- _uri="http://localhost:8090",
+ _uris={"unknown": "http://localhost:8090"},
_aliases=["alias1", "alias2"],
_comment="this is test",
_properties={"k": "v"},
@@ -390,7 +391,7 @@ class TestModelCatalogApi(unittest.TestCase):
## test with empty response
model_version_dto = ModelVersionDTO(
_version=1,
- _uri="http://localhost:8090",
+ _uris={"unknown": "http://localhost:8090"},
_aliases=None,
_comment=None,
_properties=None,
@@ -425,7 +426,7 @@ class TestModelCatalogApi(unittest.TestCase):
## test with response
model_version_dto = ModelVersionDTO(
_version=1,
- _uri="http://localhost:8090",
+ _uris={"unknown": "http://localhost:8090"},
_aliases=["alias1", "alias2"],
_comment="this is test",
_properties={"k": "v"},
@@ -449,6 +450,74 @@ class TestModelCatalogApi(unittest.TestCase):
)
)
+ def test_link_model_version_with_multiple_uris(self, *mock_method):
+ gravitino_client = GravitinoClient(
+ uri="http://localhost:8090", metalake_name=self._metalake_name
+ )
+ catalog = gravitino_client.load_catalog(self._catalog_name)
+
+ model_ident = NameIdentifier.of("schema", "model1")
+
+ ## test with response
+ model_version_dto = ModelVersionDTO(
+ _version=1,
+ _uris={"default-uri-name": "http://localhost:8090"},
+ _aliases=["alias1", "alias2"],
+ _comment="this is test",
+ _properties={"k": "v"},
+ _audit=AuditDTO(_creator="test",
_create_time="2022-01-01T00:00:00Z"),
+ )
+ model_resp = ModelVersionResponse(_model_version=model_version_dto,
_code=0)
+ json_str = model_resp.to_json()
+ mock_resp = self._mock_http_response(json_str)
+
+ with patch(
+ "gravitino.utils.http_client.HTTPClient.post",
+ return_value=mock_resp,
+ ):
+ self.assertIsNone(
+
catalog.as_model_catalog().link_model_version_with_multiple_uris(
+ model_ident,
+ {"default-uri-name": "http://localhost:8090"},
+ ["alias1", "alias2"],
+ "this is test",
+ {"k": "v"},
+ )
+ )
+
+ def test_get_model_version_uri(self, *mock_method):
+ gravitino_client = GravitinoClient(
+ uri="http://localhost:8090", metalake_name=self._metalake_name
+ )
+ catalog = gravitino_client.load_catalog(self._catalog_name)
+
+ model_ident = NameIdentifier.of("schema", "model1")
+ version = 1
+ alias = "alias1"
+
+ ## test with response
+ model_version_uri_resp = ModelVersionUriResponse(
+ _code=0, _uri="s3://path/to/model"
+ )
+ json_str = model_version_uri_resp.to_json()
+ mock_resp = self._mock_http_response(json_str)
+
+ with patch(
+ "gravitino.utils.http_client.HTTPClient.get",
+ return_value=mock_resp,
+ ):
+ model_version_uri =
catalog.as_model_catalog().get_model_version_uri(
+ model_ident, version, "uri_name"
+ )
+ self.assertEqual("s3://path/to/model", model_version_uri)
+
+ model_version_uri = (
+ catalog.as_model_catalog().get_model_version_uri_by_alias(
+ model_ident, alias, "uri_name"
+ )
+ )
+ self.assertEqual("s3://path/to/model", model_version_uri)
+
def test_delete_model_version(self, *mock_method):
gravitino_client = GravitinoClient(
uri="http://localhost:8090", metalake_name=self._metalake_name
diff --git a/clients/client-python/tests/unittests/test_responses.py
b/clients/client-python/tests/unittests/test_responses.py
index 4dc04a402b..3cc451ffc4 100644
--- a/clients/client-python/tests/unittests/test_responses.py
+++ b/clients/client-python/tests/unittests/test_responses.py
@@ -22,6 +22,7 @@ from gravitino.dto.responses.file_location_response import
FileLocationResponse
from gravitino.dto.responses.model_response import ModelResponse
from gravitino.dto.responses.model_version_list_response import
ModelVersionListResponse
from gravitino.dto.responses.model_version_response import ModelVersionResponse
+from gravitino.dto.responses.model_version_uri_response import
ModelVersionUriResponse
from gravitino.exceptions.base import IllegalArgumentException
@@ -161,7 +162,7 @@ class TestResponses(unittest.TestCase):
"modelVersion": {
"version": 0,
"aliases": ["alias1", "alias2"],
- "uri": "http://localhost:8080",
+ "uris": {"unknown": "http://localhost:8080"},
"comment": "test comment",
"properties": {"key1": "value1"},
"audit": {
@@ -188,7 +189,7 @@ class TestResponses(unittest.TestCase):
"code": 0,
"modelVersion": {
"version": 0,
- "uri": "http://localhost:8080",
+ "uris": {"unknown": "http://localhost:8080"},
"audit": {
"creator": "anonymous",
"createTime": "2024-04-05T10:10:35.218Z",
@@ -208,7 +209,7 @@ class TestResponses(unittest.TestCase):
json_data = {
"code": 0,
"modelVersion": {
- "uri": "http://localhost:8080",
+ "uris": {"unknown": "http://localhost:8080"},
"audit": {
"creator": "anonymous",
"createTime": "2024-04-05T10:10:35.218Z",
@@ -241,7 +242,7 @@ class TestResponses(unittest.TestCase):
"code": 0,
"modelVersion": {
"version": 0,
- "uri": "http://localhost:8080",
+ "uris": {"unknown": "http://localhost:8080"},
},
}
json_str = json.dumps(json_data)
@@ -249,3 +250,28 @@ class TestResponses(unittest.TestCase):
json_str, infer_missing=True
)
self.assertRaises(IllegalArgumentException, resp.validate)
+
+ def test_model_version_uri_response(self):
+ json_data = {"code": 0, "uri": "s3://path/to/model"}
+ json_str = json.dumps(json_data)
+ resp: ModelVersionUriResponse = ModelVersionUriResponse.from_json(
+ json_str, infer_missing=True
+ )
+ resp.validate()
+ self.assertEqual("s3://path/to/model", resp.uri())
+
+ json_data_missing = {"code": 0, "uri": ""}
+ json_str_missing = json.dumps(json_data_missing)
+ resp_missing: ModelVersionUriResponse =
ModelVersionUriResponse.from_json(
+ json_str_missing, infer_missing=True
+ )
+ self.assertRaises(IllegalArgumentException, resp_missing.validate)
+
+ json_data_missing_1 = {
+ "code": 0,
+ }
+ json_str_missing_1 = json.dumps(json_data_missing_1)
+ resp_missing_1: ModelVersionUriResponse =
ModelVersionUriResponse.from_json(
+ json_str_missing_1, infer_missing=True
+ )
+ self.assertRaises(IllegalArgumentException, resp_missing_1.validate)