This is an automated email from the ASF dual-hosted git repository.
jshao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/gravitino.git
The following commit(s) were added to refs/heads/main by this push:
new c022efe585 [#8758] feat(client-python): add relational table (#9132)
c022efe585 is described below
commit c022efe585d43cc1019e68ffb322e77ead4ba10e
Author: George T. C. Lai <[email protected]>
AuthorDate: Tue Dec 2 19:59:33 2025 +0800
[#8758] feat(client-python): add relational table (#9132)
### What changes were proposed in this pull request?
The following classes and methods were included in this PR:
- `RelationalTable`
- `DTOConverters.from_dto(table_dto: TableDTO)`
**NOTE**: this PR also refactors the method
`DTOConverters.from_dto(table_dto: TableDTO)` to resolve a *cyclic import*
between `dto_converters` and `relational_table`.
### Why are the changes needed?
We need to enable table operations in the Python client, which requires
implementing `RelationalTable`.
Fix: #8758
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit tests.
---------
Signed-off-by: George T. C. Lai <[email protected]>
---
.../gravitino/client/relational_table.py | 198 +++++++++++--
.../dto/responses/partition_list_response.py | 4 +
.../gravitino/dto/responses/partition_response.py | 9 +
.../gravitino/dto/util/dto_converters.py | 58 +++-
.../tests/unittests/test_relational_table.py | 328 +++++++++++++++++++++
.../tests/unittests/test_responses.py | 13 +
6 files changed, 576 insertions(+), 34 deletions(-)
diff --git a/clients/client-python/gravitino/client/relational_table.py
b/clients/client-python/gravitino/client/relational_table.py
index 2d02d5e6a7..86c62cbdd3 100644
--- a/clients/client-python/gravitino/client/relational_table.py
+++ b/clients/client-python/gravitino/client/relational_table.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Optional
+from typing import Optional, cast
from gravitino.api.audit import Audit
from gravitino.api.rel.column import Column
@@ -23,57 +23,193 @@ from
gravitino.api.rel.expressions.distributions.distribution import Distributio
from gravitino.api.rel.expressions.sorts.sort_order import SortOrder
from gravitino.api.rel.expressions.transforms.transform import Transform
from gravitino.api.rel.indexes.index import Index
+from gravitino.api.rel.partitions.partition import Partition
from gravitino.api.rel.table import Table
+from gravitino.client.generic_column import GenericColumn
+from gravitino.dto.rel.partitions.partition_dto import PartitionDTO
+from gravitino.dto.rel.table_dto import TableDTO
+from gravitino.dto.requests.add_partitions_request import AddPartitionsRequest
+from gravitino.dto.responses.drop_response import DropResponse
+from gravitino.dto.responses.partition_list_response import
PartitionListResponse
+from gravitino.dto.responses.partition_name_list_response import (
+ PartitionNameListResponse,
+)
+from gravitino.dto.responses.partition_response import PartitionResponse
+from gravitino.dto.util.dto_converters import DTOConverters
+from gravitino.exceptions.handlers.partition_error_handler import (
+ PARTITION_ERROR_HANDLER,
+)
+from gravitino.namespace import Namespace
+from gravitino.rest.rest_utils import encode_string
+from gravitino.utils import HTTPClient
-class RelationalTable(Table): # pylint: disable=too-many-instance-attributes
- """A generic table implementation."""
+class RelationalTable(Table):
+ """Represents a relational table."""
def __init__(
- self,
- name: str,
- columns: list[Column],
- partitioning: list[Transform],
- sort_order: list[SortOrder],
- distribution: Distribution,
- index: list[Index],
- comment: Optional[str],
- properties: dict[str, str],
- audit_info: Audit,
+ self, namespace: Namespace, table_dto: TableDTO, rest_client:
HTTPClient
):
- self._name = name
- self._columns = columns
- self._partitioning = partitioning
- self._sort_order = sort_order
- self._distribution = distribution
- self._index = index
- self._comment = comment
- self._properties = properties
- self._audit_info = audit_info
+ self._namespace = namespace
+ self._table = cast(Table, DTOConverters.from_dto(table_dto))
+ self._rest_client = rest_client
def name(self) -> str:
- return self._name
+ return self._table.name()
def columns(self) -> list[Column]:
- return self._columns
+ metalake, catalog, schema = self._namespace.levels()
+ return [
+ cast(
+ Column,
+ GenericColumn(
+ c, self._rest_client, metalake, catalog, schema,
self._table.name()
+ ),
+ )
+ for c in self._table.columns()
+ ]
def partitioning(self) -> list[Transform]:
- return self._partitioning
+ return self._table.partitioning()
def sort_order(self) -> list[SortOrder]:
- return self._sort_order
+ return self._table.sort_order()
def distribution(self) -> Distribution:
- return self._distribution
+ return self._table.distribution()
def index(self) -> list[Index]:
- return self._index
+ return self._table.index()
def comment(self) -> Optional[str]:
- return self._comment
+ return self._table.comment()
def properties(self) -> dict[str, str]:
- return self._properties
+ return self._table.properties()
def audit_info(self) -> Audit:
- return self._audit_info
+ return self._table.audit_info()
+
+ def _get_partition_request_path(self) -> str:
+ """Get the partition request path.
+
+ Returns:
+ str: The partition request path.
+ """
+
+ return (
+ f"api/metalakes/{encode_string(self._namespace.level(0))}"
+ f"/catalogs/{encode_string(self._namespace.level(1))}"
+ f"/schemas/{encode_string(self._namespace.level(2))}"
+ f"/tables/{encode_string(self._table.name())}"
+ "/partitions"
+ )
+
+ def list_partition_names(self) -> list[str]:
+ """Get the partition names of the table.
+
+ Returns:
+ list[str]: The partition names of the table.
+ """
+
+ resp = self._rest_client.get(
+ endpoint=self._get_partition_request_path(),
+ error_handler=PARTITION_ERROR_HANDLER,
+ )
+ partition_name_list_resp = PartitionNameListResponse.from_json(
+ resp.body, infer_missing=True
+ )
+ partition_name_list_resp.validate()
+
+ return partition_name_list_resp.partition_names()
+
+ def list_partitions(self) -> list[Partition]:
+ """Get the partitions of the table.
+
+ Returns:
+ list[Partition]: The partitions of the table.
+ """
+
+ params = {"details": "true"}
+ resp = self._rest_client.get(
+ endpoint=self._get_partition_request_path(),
+ params=params,
+ error_handler=PARTITION_ERROR_HANDLER,
+ )
+ partition_list_resp = PartitionListResponse.from_json(
+ resp.body, infer_missing=True
+ )
+ partition_list_resp.validate()
+
+ return partition_list_resp.get_partitions()
+
+ def get_partition(self, partition_name: str) -> Partition:
+ """Returns the partition with the given name.
+
+ Args:
+ partition_name (str): the name of the partition
+
+ Returns:
+ Partition: the partition with the given name
+
+ Raises:
+ NoSuchPartitionException:
+ if the partition does not exist, throws this exception.
+ """
+
+ resp = self._rest_client.get(
+
endpoint=f"{self._get_partition_request_path()}/{encode_string(partition_name)}",
+ error_handler=PARTITION_ERROR_HANDLER,
+ )
+ partition_resp = PartitionResponse.from_json(resp.body,
infer_missing=True)
+ partition_resp.validate()
+
+ return partition_resp.get_partition()
+
+ def drop_partition(self, partition_name: str) -> bool:
+ """Drops the partition with the given name.
+
+ Args:
+ partition_name (str): The name of the partition.
+
+ Returns:
+ bool: `True` if the partition is dropped, `False` if the partition
does not exist.
+ """
+ resp = self._rest_client.delete(
+
endpoint=f"{self._get_partition_request_path()}/{encode_string(partition_name)}",
+ error_handler=PARTITION_ERROR_HANDLER,
+ )
+ drop_resp = DropResponse.from_json(resp.body, infer_missing=True)
+ drop_resp.validate()
+
+ return drop_resp.dropped()
+
+ def add_partition(self, partition: Partition) -> Partition:
+ """Adds a partition to the table.
+
+ Args:
+ partition (Partition): The partition to add.
+
+ Returns:
+ Partition: The added partition.
+
+ Raises:
+ PartitionAlreadyExistsException:
+ if the partition already exists, throws this exception.
+ """
+
+ req = AddPartitionsRequest(
+ [cast(PartitionDTO, DTOConverters.to_dto(partition))]
+ )
+ req.validate()
+
+ resp = self._rest_client.post(
+ endpoint=self._get_partition_request_path(),
+ json=req,
+ error_handler=PARTITION_ERROR_HANDLER,
+ )
+ partition_list_resp = PartitionListResponse.from_json(
+ resp.body, infer_missing=True
+ )
+ partition_list_resp.validate()
+ return partition_list_resp.get_partitions()[0]
diff --git
a/clients/client-python/gravitino/dto/responses/partition_list_response.py
b/clients/client-python/gravitino/dto/responses/partition_list_response.py
index 084871fe16..3d01ee6d3f 100644
--- a/clients/client-python/gravitino/dto/responses/partition_list_response.py
+++ b/clients/client-python/gravitino/dto/responses/partition_list_response.py
@@ -19,6 +19,7 @@ from dataclasses import dataclass, field
from dataclasses_json import config
+from gravitino.api.rel.partitions.partition import Partition
from gravitino.dto.rel.partitions.json_serdes.partition_dto_serdes import (
PartitionDTOSerdes,
)
@@ -41,3 +42,6 @@ class PartitionListResponse(BaseResponse):
],
)
)
+
+ def get_partitions(self) -> list[Partition]:
+ return self._partitions
diff --git
a/clients/client-python/gravitino/dto/responses/partition_response.py
b/clients/client-python/gravitino/dto/responses/partition_response.py
index b6a0117cf7..f3daece22d 100644
--- a/clients/client-python/gravitino/dto/responses/partition_response.py
+++ b/clients/client-python/gravitino/dto/responses/partition_response.py
@@ -19,6 +19,7 @@ from dataclasses import dataclass, field
from dataclasses_json import config
+from gravitino.api.rel.partitions.partition import Partition
from gravitino.dto.rel.partitions.json_serdes.partition_dto_serdes import (
PartitionDTOSerdes,
)
@@ -37,3 +38,11 @@ class PartitionResponse(BaseResponse):
encoder=PartitionDTOSerdes.serialize,
)
)
+
+ def get_partition(self) -> Partition:
+ """Returns the partition.
+
+ Returns:
+ Partition: The partition.
+ """
+ return self._partition
diff --git a/clients/client-python/gravitino/dto/util/dto_converters.py
b/clients/client-python/gravitino/dto/util/dto_converters.py
index 94d7224379..d3188d045f 100644
--- a/clients/client-python/gravitino/dto/util/dto_converters.py
+++ b/clients/client-python/gravitino/dto/util/dto_converters.py
@@ -16,8 +16,9 @@
# under the License.
from functools import singledispatchmethod
-from typing import cast, overload
+from typing import Optional, cast, overload
+from gravitino.api.audit import Audit
from gravitino.api.rel.column import Column
from gravitino.api.rel.expressions.distributions.distribution import
Distribution
from gravitino.api.rel.expressions.distributions.distributions import
Distributions
@@ -39,7 +40,6 @@ from gravitino.api.rel.partitions.partition import Partition
from gravitino.api.rel.partitions.range_partition import RangePartition
from gravitino.api.rel.table import Table
from gravitino.api.rel.types.types import Types
-from gravitino.client.relational_table import RelationalTable
from gravitino.dto.rel.column_dto import ColumnDTO
from gravitino.dto.rel.distribution_dto import DistributionDTO
from gravitino.dto.rel.expressions.field_reference_dto import FieldReferenceDTO
@@ -265,7 +265,59 @@ class DTOConverters:
Table: The table.
"""
- return RelationalTable(
+ class TableImpl(Table): # pylint: disable=too-many-instance-attributes
+ """A table implementation."""
+
+ def __init__(
+ self,
+ name: str,
+ columns: list[Column],
+ partitioning: list[Transform],
+ sort_order: list[SortOrder],
+ distribution: Distribution,
+ index: list[Index],
+ comment: Optional[str],
+ properties: dict[str, str],
+ audit_info: Audit,
+ ):
+ self._name = name
+ self._columns = columns
+ self._partitioning = partitioning
+ self._sort_order = sort_order
+ self._distribution = distribution
+ self._index = index
+ self._comment = comment
+ self._properties = properties
+ self._audit_info = audit_info
+
+ def name(self) -> str:
+ return self._name
+
+ def columns(self) -> list[Column]:
+ return self._columns
+
+ def partitioning(self) -> list[Transform]:
+ return self._partitioning
+
+ def sort_order(self) -> list[SortOrder]:
+ return self._sort_order
+
+ def distribution(self) -> Distribution:
+ return self._distribution
+
+ def index(self) -> list[Index]:
+ return self._index
+
+ def comment(self) -> Optional[str]:
+ return self._comment
+
+ def properties(self) -> dict[str, str]:
+ return self._properties
+
+ def audit_info(self) -> Audit:
+ return self._audit_info
+
+ return TableImpl(
name=dto.name(),
columns=DTOConverters.from_dtos(dto.columns()),
partitioning=DTOConverters.from_dtos(dto.partitioning()),
diff --git a/clients/client-python/tests/unittests/test_relational_table.py
b/clients/client-python/tests/unittests/test_relational_table.py
new file mode 100644
index 0000000000..aa2bae4bb7
--- /dev/null
+++ b/clients/client-python/tests/unittests/test_relational_table.py
@@ -0,0 +1,328 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import json
+import unittest
+from http.client import HTTPResponse
+from typing import cast
+from unittest.mock import Mock, patch
+
+from gravitino.api.rel.expressions.distributions.strategy import Strategy
+from gravitino.api.rel.expressions.literals.literals import Literals
+from gravitino.api.rel.expressions.sorts.null_ordering import NullOrdering
+from gravitino.api.rel.expressions.sorts.sort_direction import SortDirection
+from gravitino.api.rel.expressions.transforms.transforms import Transforms
+from gravitino.api.rel.indexes.index import Index
+from gravitino.api.rel.partitions.partitions import Partitions
+from gravitino.client.generic_column import GenericColumn
+from gravitino.client.relational_table import RelationalTable
+from gravitino.dto.rel.partitions.json_serdes.partition_dto_serdes import (
+ PartitionDTOSerdes,
+)
+from gravitino.dto.rel.table_dto import TableDTO
+from gravitino.dto.responses.drop_response import DropResponse
+from gravitino.dto.responses.partition_list_response import
PartitionListResponse
+from gravitino.dto.responses.partition_name_list_response import (
+ PartitionNameListResponse,
+)
+from gravitino.dto.responses.partition_response import PartitionResponse
+from gravitino.namespace import Namespace
+from gravitino.utils import HTTPClient, Response
+
+
+class TestRelationalTable(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls) -> None:
+ cls.TABLE_DTO_JSON_STRING = """
+ {
+ "name": "example_table",
+ "comment": "This is an example table",
+ "audit": {
+ "creator": "Apache Gravitino",
+ "createTime":"2025-10-10T00:00:00"
+ },
+ "columns": [
+ {
+ "name": "id",
+ "type": "integer",
+ "comment": "id column comment",
+ "nullable": false,
+ "autoIncrement": true,
+ "defaultValue": {
+ "type": "literal",
+ "dataType": "integer",
+ "value": "-1"
+ }
+ },
+ {
+ "name": "name",
+ "type": "varchar(500)",
+ "comment": "name column comment",
+ "nullable": true,
+ "autoIncrement": false,
+ "defaultValue": {
+ "type": "literal",
+ "dataType": "null",
+ "value": "null"
+ }
+ },
+ {
+ "name": "StartingDate",
+ "type": "timestamp",
+ "comment": "StartingDate column comment",
+ "nullable": false,
+ "autoIncrement": false,
+ "defaultValue": {
+ "type": "function",
+ "funcName": "current_timestamp",
+ "funcArgs": []
+ }
+ },
+ {
+ "name": "info",
+ "type": {
+ "type": "struct",
+ "fields": [
+ {
+ "name": "position",
+ "type": "string",
+ "nullable": true,
+ "comment": "position field comment"
+ },
+ {
+ "name": "contact",
+ "type": {
+ "type": "list",
+ "elementType": "integer",
+ "containsNull": false
+ },
+ "nullable": true,
+ "comment": "contact field comment"
+ },
+ {
+ "name": "rating",
+ "type": {
+ "type": "map",
+ "keyType": "string",
+ "valueType": "integer",
+ "valueContainsNull": false
+ },
+ "nullable": true,
+ "comment": "rating field comment"
+ }
+ ]
+ },
+ "comment": "info column comment",
+ "nullable": true
+ },
+ {
+ "name": "dt",
+ "type": "date",
+ "comment": "dt column comment",
+ "nullable": true
+ }
+ ],
+ "partitioning": [
+ {
+ "strategy": "identity",
+ "fieldName": [ "dt" ]
+ }
+ ],
+ "distribution": {
+ "strategy": "hash",
+ "number": 32,
+ "funcArgs": [
+ {
+ "type": "field",
+ "fieldName": [ "id" ]
+ }
+ ]
+ },
+ "sortOrders": [
+ {
+ "sortTerm": {
+ "type": "field",
+ "fieldName": [ "age" ]
+ },
+ "direction": "asc",
+ "nullOrdering": "nulls_first"
+ }
+ ],
+ "indexes": [
+ {
+ "indexType": "primary_key",
+ "name": "PRIMARY",
+ "fieldNames": [["id"]]
+ }
+ ],
+ "properties": {
+ "format": "ORC"
+ }
+ }
+ """
+
+ cls.PARTITION_JSON_STRING = """
+ {
+ "type": "identity",
+ "name": "test_identity_partition",
+ "fieldNames": [
+ [
+ "column_name"
+ ]
+ ],
+ "values": [
+ {
+ "type": "literal",
+ "dataType": "integer",
+ "value": "0"
+ },
+ {
+ "type": "literal",
+ "dataType": "integer",
+ "value": "100"
+ }
+ ]
+ }
+ """
+
+ cls.table_dto = TableDTO.from_json(cls.TABLE_DTO_JSON_STRING)
+ cls.namespace = Namespace.of("metalake_demo", "test_catalog",
"test_schema")
+ cls.rest_client = HTTPClient("http://localhost:8090")
+ cls.relational_table = RelationalTable(
+ cls.namespace, cls.table_dto, cls.rest_client
+ )
+
+ def _get_mock_http_resp(self, json_str: str):
+ mock_http_resp = Mock(HTTPResponse)
+ mock_http_resp.getcode.return_value = 200
+ mock_http_resp.read.return_value = json_str
+ mock_http_resp.info.return_value = None
+ mock_http_resp.url = None
+ mock_resp = Response(mock_http_resp)
+ return mock_resp
+
+ def test_list_partition_names(self):
+ resp_body = PartitionNameListResponse(0, ["partition_1",
"partition_2"])
+ mock_resp = self._get_mock_http_resp(resp_body.to_json())
+
+ with patch(
+ "gravitino.utils.http_client.HTTPClient.get",
+ return_value=mock_resp,
+ ):
+ names = self.relational_table.list_partition_names()
+ self.assertListEqual(names, resp_body.partition_names())
+
+ def test_columns(self):
+ cols = self.relational_table.columns()
+ self.assertEqual(len(cols), len(self.table_dto.columns()))
+ self.assertTrue(all(isinstance(col, GenericColumn) for col in cols))
+
+ def test_list_partitions(self):
+ expected_serialized =
json.loads(TestRelationalTable.PARTITION_JSON_STRING)
+ partitions = [PartitionDTOSerdes.deserialize(expected_serialized)]
+ resp_body = PartitionListResponse(0, partitions)
+ mock_resp = self._get_mock_http_resp(resp_body.to_json())
+
+ with patch(
+ "gravitino.utils.http_client.HTTPClient.get",
+ return_value=mock_resp,
+ ):
+ partitions = self.relational_table.list_partitions()
+ self.assertListEqual(partitions, resp_body.get_partitions())
+
+ def test_get_partition(self):
+ expected_serialized =
json.loads(TestRelationalTable.PARTITION_JSON_STRING)
+ partition_dto = PartitionDTOSerdes.deserialize(expected_serialized)
+ resp_body = PartitionResponse(0, partition_dto)
+ mock_resp = self._get_mock_http_resp(resp_body.to_json())
+
+ with patch(
+ "gravitino.utils.http_client.HTTPClient.get",
+ return_value=mock_resp,
+ ):
+ partition = self.relational_table.get_partition("partition_name")
+ self.assertEqual(partition, resp_body.get_partition())
+
+ def test_drop_partition(self):
+ resp_body = DropResponse(0, True)
+ mock_resp = self._get_mock_http_resp(resp_body.to_json())
+
+ with patch(
+ "gravitino.utils.http_client.HTTPClient.delete",
+ return_value=mock_resp,
+ ):
+
self.assertTrue(self.relational_table.drop_partition("partition_name"))
+
+ def test_add_partition(self):
+ partition = Partitions.identity(
+ "test_identity_partition",
+ [["column_name"]],
+ [Literals.integer_literal(0), Literals.integer_literal(100)],
+ )
+ expected_serialized =
json.loads(TestRelationalTable.PARTITION_JSON_STRING)
+ partitions = [PartitionDTOSerdes.deserialize(expected_serialized)]
+ resp_body = PartitionListResponse(0, partitions)
+ mock_resp = self._get_mock_http_resp(resp_body.to_json())
+
+ with patch(
+ "gravitino.utils.http_client.HTTPClient.post",
+ return_value=mock_resp,
+ ):
+ added_partition = self.relational_table.add_partition(partition)
+ self.assertEqual(added_partition, resp_body.get_partitions()[0])
+
+ def test_get_name(self):
+ self.assertEqual(self.relational_table.name(), "example_table")
+
+ def test_get_comment(self):
+ self.assertEqual(self.relational_table.comment(), "This is an example
table")
+
+ def test_get_partitioning(self):
+ partitioning_list = self.relational_table.partitioning()
+ partitioning = cast(Transforms.IdentityTransform, partitioning_list[0])
+
+ self.assertEqual(len(partitioning_list), 1)
+ self.assertListEqual(partitioning.field_name(), ["dt"])
+
+ def test_get_sort_order(self):
+ sort_order_list = self.relational_table.sort_order()
+ sort_order = sort_order_list[0]
+
+ self.assertEqual(len(sort_order_list), 1)
+ self.assertEqual(sort_order.direction(), SortDirection.ASCENDING)
+ self.assertEqual(sort_order.null_ordering(), NullOrdering.NULLS_FIRST)
+
+ def test_get_distribution(self):
+ distribution = self.relational_table.distribution()
+ self.assertEqual(distribution.strategy(), Strategy.HASH)
+ self.assertEqual(distribution.number(), 32)
+
+ def test_get_index(self):
+ index_list = self.relational_table.index()
+ index = index_list[0]
+ self.assertEqual(len(index_list), 1)
+ self.assertEqual(index.name(), "PRIMARY")
+ self.assertEqual(index.type(), Index.IndexType.PRIMARY_KEY)
+
+ def test_get_audit_info(self):
+ audit_info = self.relational_table.audit_info()
+ self.assertEqual(audit_info.creator(), "Apache Gravitino")
+ self.assertEqual(audit_info.create_time(), "2025-10-10T00:00:00")
+
+ def test_get_properties(self):
+ properties = self.relational_table.properties()
+ self.assertDictEqual(properties, {"format": "ORC"})
diff --git a/clients/client-python/tests/unittests/test_responses.py
b/clients/client-python/tests/unittests/test_responses.py
index b6783cb959..de6505cfcd 100644
--- a/clients/client-python/tests/unittests/test_responses.py
+++ b/clients/client-python/tests/unittests/test_responses.py
@@ -17,6 +17,9 @@
import json
import unittest
+from gravitino.dto.rel.partitions.json_serdes.partition_dto_serdes import (
+ PartitionDTOSerdes,
+)
from gravitino.dto.responses.credential_response import CredentialResponse
from gravitino.dto.responses.file_location_response import FileLocationResponse
from gravitino.dto.responses.model_response import ModelResponse
@@ -336,8 +339,12 @@ class TestResponses(unittest.TestCase):
"partition": {TestResponses.PARTITION_JSON_STRING}
}}
"""
+ partition = PartitionDTOSerdes.deserialize(
+ json.loads(TestResponses.PARTITION_JSON_STRING)
+ )
resp: PartitionResponse = PartitionResponse.from_json(json_string)
resp.validate()
+ self.assertEqual(resp.get_partition(), partition)
def test_partition_list_response(self):
json_string = f"""
@@ -346,5 +353,11 @@ class TestResponses(unittest.TestCase):
"partitions": [{TestResponses.PARTITION_JSON_STRING}]
}}
"""
+ partitions = [
+ PartitionDTOSerdes.deserialize(
+ json.loads(TestResponses.PARTITION_JSON_STRING)
+ )
+ ]
resp: PartitionListResponse =
PartitionListResponse.from_json(json_string)
resp.validate()
+ self.assertListEqual(resp.get_partitions(), partitions)