This is an automated email from the ASF dual-hosted git repository. lzljs3620320 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push: new b9f188e36c [python] add test for data type (#5900) b9f188e36c is described below commit b9f188e36cdc3c2b3586bceab8b7b4241f012996 Author: jerry <lining....@alibaba-inc.com> AuthorDate: Wed Jul 16 14:03:25 2025 +0800 [python] add test for data type (#5900) --- pypaimon/api/__init__.py | 22 ++- pypaimon/api/api_response.py | 289 ++++-------------------------------- pypaimon/api/api_resquest.py | 16 +- pypaimon/api/auth.py | 1 + pypaimon/api/client.py | 205 +++++++++++++++++++++++--- pypaimon/api/data_types.py | 341 +++++++++++++++++++++++++++++++++++++++++++ pypaimon/api/rest_json.py | 71 ++++++--- pypaimon/api/typedef.py | 4 +- pypaimon/tests/api_test.py | 104 ++++++++++++- 9 files changed, 739 insertions(+), 314 deletions(-) diff --git a/pypaimon/api/__init__.py b/pypaimon/api/__init__.py index 573fd889b9..a82d9ebb10 100644 --- a/pypaimon/api/__init__.py +++ b/pypaimon/api/__init__.py @@ -18,11 +18,10 @@ import logging from typing import Dict, List, Optional, Callable from urllib.parse import unquote -import api from api.auth import RESTAuthFunction from api.api_response import PagedList, GetTableResponse, ListDatabasesResponse, ListTablesResponse, \ GetDatabaseResponse, ConfigResponse, PagedResponse -from api.api_resquest import CreateDatabaseRequest +from api.api_resquest import CreateDatabaseRequest, AlterDatabaseRequest from api.typedef import Identifier from api.client import HttpClient from api.auth import DLFAuthProvider, DLFToken @@ -36,6 +35,7 @@ class RESTCatalogOptions: DLF_REGION = "dlf.region" DLF_ACCESS_KEY_ID = "dlf.access-key-id" DLF_ACCESS_KEY_SECRET = "dlf.access-key-secret" + DLF_ACCESS_SECURITY_TOKEN = "dlf.security-token" PREFIX = 'prefix' @@ -217,8 +217,8 @@ class RESTApi: databases = response.data() or [] return PagedList(databases, response.get_next_page_token()) - def create_database(self, name: str, properties: Dict[str, str]) -> None: - request = 
CreateDatabaseRequest(name, properties) + def create_database(self, name: str, options: Dict[str, str]) -> None: + request = CreateDatabaseRequest(name, options) self.client.post(self.resource_paths.databases(), request, self.rest_auth_function) def get_database(self, name: str) -> GetDatabaseResponse: @@ -231,6 +231,20 @@ class RESTApi: def drop_database(self, name: str) -> None: self.client.delete(self.resource_paths.database(name), self.rest_auth_function) + def alter_database(self, name: str, removals: Optional[List[str]] = None, + updates: Optional[Dict[str, str]] = None): + if not name or not name.strip(): + raise ValueError("Database name cannot be empty") + removals = removals or [] + updates = updates or {} + request = AlterDatabaseRequest(removals, updates) + + return self.client.post( + self.resource_paths.database(name), + request, + self.rest_auth_function + ) + def list_tables(self, database_name: str) -> List[str]: return self.__list_data_from_page_api( lambda query_params: self.client.get_with_params( diff --git a/pypaimon/api/api_response.py b/pypaimon/api/api_response.py index 4f2361c6b8..6005e11810 100644 --- a/pypaimon/api/api_response.py +++ b/pypaimon/api/api_response.py @@ -19,7 +19,10 @@ limitations under the License. 
from abc import ABC, abstractmethod from typing import Dict, Optional, Any, Generic, List from dataclasses import dataclass, field + +from api.rest_json import json_field from api.typedef import T +from api.data_types import DataField @dataclass @@ -34,15 +37,11 @@ class RESTResponse(ABC): @dataclass class ErrorResponse(RESTResponse): - FIELD_RESOURCE_TYPE: "resourceType" - FIELD_RESOURCE_NAME: "resourceName" - FIELD_MESSAGE: "message" - FIELD_CODE: "code" - resource_type: Optional[str] = None - resource_name: Optional[str] = None - message: Optional[str] = None - code: Optional[int] = None + resource_type: Optional[str] = json_field("resourceType", default=None) + resource_name: Optional[str] = json_field("resourceName", default=None) + message: Optional[str] = json_field("message", default=None) + code: Optional[int] = json_field("code", default=None) def __init__(self, resource_type: Optional[str] = None, @@ -54,23 +53,6 @@ class ErrorResponse(RESTResponse): self.message = message self.code = code - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'ErrorResponse': - return cls( - resource_type=data.get(cls.FIELD_RESOURCE_TYPE), - resource_name=data.get(cls.FIELD_RESOURCE_NAME), - message=data.get(cls.FIELD_MESSAGE), - code=data.get(cls.FIELD_CODE), - ) - - def to_dict(self) -> Dict[str, Any]: - return { - self.FIELD_RESOURCE_TYPE: self.resource_type, - self.FIELD_RESOURCE_NAME: self.resource_name, - self.FIELD_MESSAGE: self.message, - self.FIELD_CODE: self.code - } - @dataclass class AuditRESTResponse(RESTResponse): @@ -80,11 +62,11 @@ class AuditRESTResponse(RESTResponse): FIELD_UPDATED_AT = "updatedAt" FIELD_UPDATED_BY = "updatedBy" - owner: Optional[str] = None - created_at: Optional[int] = None - created_by: Optional[str] = None - updated_at: Optional[int] = None - updated_by: Optional[str] = None + owner: Optional[str] = json_field(FIELD_OWNER, default=None) + created_at: Optional[int] = json_field(FIELD_CREATED_AT, default=None) + created_by: 
Optional[str] = json_field(FIELD_CREATED_BY, default=None) + updated_at: Optional[int] = json_field(FIELD_UPDATED_AT, default=None) + updated_by: Optional[str] = json_field(FIELD_UPDATED_BY, default=None) def get_owner(self) -> Optional[str]: return self.owner @@ -118,21 +100,8 @@ class PagedResponse(RESTResponse, Generic[T]): class ListDatabasesResponse(PagedResponse[str]): FIELD_DATABASES = "databases" - databases: List[str] - next_page_token: str - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'ListDatabasesResponse': - return cls( - databases=data.get(cls.FIELD_DATABASES), - next_page_token=data.get(cls.FIELD_NEXT_PAGE_TOKEN) - ) - - def to_dict(self) -> Dict[str, Any]: - return { - self.FIELD_DATABASES: self.databases, - self.FIELD_NEXT_PAGE_TOKEN: self.next_page_token - } + databases: List[str] = json_field(FIELD_DATABASES) + next_page_token: str = json_field(PagedResponse.FIELD_NEXT_PAGE_TOKEN) def data(self) -> List[str]: return self.databases @@ -145,21 +114,8 @@ class ListDatabasesResponse(PagedResponse[str]): class ListTablesResponse(PagedResponse[str]): FIELD_TABLES = "tables" - tables: Optional[List[str]] - next_page_token: Optional[str] - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'ListTablesResponse': - return cls( - tables=data.get(cls.FIELD_TABLES), - next_page_token=data.get(cls.FIELD_NEXT_PAGE_TOKEN) - ) - - def to_dict(self) -> Dict[str, Any]: - return { - self.FIELD_TABLES: self.tables, - self.FIELD_NEXT_PAGE_TOKEN: self.next_page_token - } + tables: Optional[List[str]] = json_field(FIELD_TABLES) + next_page_token: Optional[str] = json_field(PagedResponse.FIELD_NEXT_PAGE_TOKEN) def data(self) -> Optional[List[str]]: return self.tables @@ -168,88 +124,6 @@ class ListTablesResponse(PagedResponse[str]): return self.next_page_token -@dataclass -class PaimonDataType: - FIELD_TYPE = "type" - FIELD_ELEMENT = "element" - FIELD_FIELDS = "fields" - FIELD_KEY = "key" - FIELD_VALUE = "value" - - type: str - element: 
Optional['PaimonDataType'] = None - fields: List['DataField'] = field(default_factory=list) - key: Optional['PaimonDataType'] = None - value: Optional['PaimonDataType'] = None - - @classmethod - def from_dict(cls, data: Any) -> 'PaimonDataType': - if isinstance(data, dict): - element = data.get(cls.FIELD_ELEMENT, None) - fields = data.get(cls.FIELD_FIELDS, None) - key = data.get(cls.FIELD_KEY, None) - value = data.get(cls.FIELD_VALUE, None) - if element is not None: - element = PaimonDataType.from_dict(data.get(cls.FIELD_ELEMENT)), - if fields is not None: - fields = list(map(lambda f: DataField.from_dict(f), fields)), - if key is not None: - key = PaimonDataType.from_dict(key) - if value is not None: - value = PaimonDataType.from_dict(value) - return cls( - type=data.get(cls.FIELD_TYPE), - element=element, - fields=fields, - key=key, - value=value, - ) - else: - return cls(type=data) - - def to_dict(self) -> Any: - if self.element is None and self.fields is None and self.key: - return self.type - if self.element is not None: - return {self.FIELD_TYPE: self.type, self.FIELD_ELEMENT: self.element} - elif self.fields is not None: - return {self.FIELD_TYPE: self.type, self.FIELD_FIELDS: self.fields} - elif self.value is not None: - return {self.FIELD_TYPE: self.type, self.FIELD_KEY: self.key, self.FIELD_VALUE: self.value} - elif self.key is not None and self.value is None: - return {self.FIELD_TYPE: self.type, self.FIELD_KEY: self.key} - - -@dataclass -class DataField: - FIELD_ID = "id" - FIELD_NAME = "name" - FIELD_TYPE = "type" - FIELD_DESCRIPTION = "description" - - description: str - id: int - name: str - type: PaimonDataType - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'DataField': - return cls( - id=data.get(cls.FIELD_ID), - name=data.get(cls.FIELD_NAME), - type=PaimonDataType.from_dict(data.get(cls.FIELD_TYPE)), - description=data.get(cls.FIELD_DESCRIPTION), - ) - - def to_dict(self) -> Dict[str, Any]: - return { - self.FIELD_ID: self.id, - 
self.FIELD_NAME: self.name, - self.FIELD_TYPE: PaimonDataType.to_dict(self.type), - self.FIELD_DESCRIPTION: self.description - } - - @dataclass class Schema: FIELD_FIELDS = "fields" @@ -258,30 +132,11 @@ class Schema: FIELD_OPTIONS = "options" FIELD_COMMENT = "comment" - fields: List[DataField] = field(default_factory=list) - partition_keys: List[str] = field(default_factory=list) - primary_keys: List[str] = field(default_factory=list) - options: Dict[str, str] = field(default_factory=dict) - comment: Optional[str] = None - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'Schema': - return cls( - fields=list(map(lambda f: DataField.from_dict(f), data.get(cls.FIELD_FIELDS))), - partition_keys=data.get(cls.FIELD_PARTITION_KEYS), - primary_keys=data.get(cls.FIELD_PRIMARY_KEYS), - options=data.get(cls.FIELD_OPTIONS), - comment=data.get(cls.FIELD_COMMENT) - ) - - def to_dict(self) -> Dict[str, Any]: - return { - self.FIELD_FIELDS: list(map(lambda f: DataField.to_dict(f), self.fields)), - self.FIELD_PARTITION_KEYS: self.partition_keys, - self.FIELD_PRIMARY_KEYS: self.primary_keys, - self.FIELD_OPTIONS: self.options, - self.FIELD_COMMENT: self.comment - } + fields: List[DataField] = json_field(FIELD_FIELDS, default_factory=list) + partition_keys: List[str] = json_field(FIELD_PARTITION_KEYS, default_factory=list) + primary_keys: List[str] = json_field(FIELD_PRIMARY_KEYS, default_factory=list) + options: Dict[str, str] = json_field(FIELD_OPTIONS, default_factory=dict) + comment: Optional[str] = json_field(FIELD_COMMENT, default=None) @dataclass @@ -332,12 +187,12 @@ class GetTableResponse(AuditRESTResponse): FIELD_SCHEMA_ID = "schemaId" FIELD_SCHEMA = "schema" - id: Optional[str] = None - name: Optional[str] = None - path: Optional[str] = None - is_external: Optional[bool] = None - schema_id: Optional[int] = None - schema: Optional[Schema] = None + id: Optional[str] = json_field(FIELD_ID, default=None) + name: Optional[str] = json_field(FIELD_NAME, 
default=None) + path: Optional[str] = json_field(FIELD_PATH, default=None) + is_external: Optional[bool] = json_field(FIELD_IS_EXTERNAL, default=None) + schema_id: Optional[int] = json_field(FIELD_SCHEMA_ID, default=None) + schema: Optional[Schema] = json_field(FIELD_SCHEMA, default=None) def __init__(self, id: str, @@ -359,44 +214,6 @@ class GetTableResponse(AuditRESTResponse): self.schema_id = schema_id self.schema = schema - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'GetTableResponse': - return cls( - id=data.get(cls.FIELD_ID), - name=data.get(cls.FIELD_NAME), - path=data.get(cls.FIELD_PATH), - is_external=data.get(cls.FIELD_IS_EXTERNAL), - schema_id=data.get(cls.FIELD_SCHEMA_ID), - schema=Schema.from_dict(data.get(cls.FIELD_SCHEMA)), - owner=data.get(cls.FIELD_OWNER), - created_at=data.get(cls.FIELD_CREATED_AT), - created_by=data.get(cls.FIELD_CREATED_BY), - updated_at=data.get(cls.FIELD_UPDATED_AT), - updated_by=data.get(cls.FIELD_UPDATED_BY) - ) - - def to_dict(self) -> Dict[str, Any]: - result = { - self.FIELD_ID: self.id, - self.FIELD_NAME: self.name, - self.FIELD_PATH: self.path, - self.FIELD_IS_EXTERNAL: self.is_external, - self.FIELD_SCHEMA_ID: self.schema_id, - self.FIELD_SCHEMA: Schema.to_dict(self.schema) - } - if self.owner is not None: - result[self.FIELD_OWNER] = self.owner - if self.created_at is not None: - result[self.FIELD_CREATED_AT] = self.created_at - if self.created_by is not None: - result[self.FIELD_CREATED_BY] = self.created_by - if self.updated_at is not None: - result[self.FIELD_UPDATED_AT] = self.updated_at - if self.updated_by is not None: - result[self.FIELD_UPDATED_BY] = self.updated_by - - return result - @dataclass class GetDatabaseResponse(AuditRESTResponse): @@ -405,10 +222,10 @@ class GetDatabaseResponse(AuditRESTResponse): FIELD_LOCATION = "location" FIELD_OPTIONS = "options" - id: Optional[str] = None - name: Optional[str] = None - location: Optional[str] = None - options: Optional[Dict[str, str]] = 
field(default_factory=dict) + id: Optional[str] = json_field(FIELD_ID, default=None) + name: Optional[str] = json_field(FIELD_NAME, default=None) + location: Optional[str] = json_field(FIELD_LOCATION, default=None) + options: Optional[Dict[str, str]] = json_field(FIELD_OPTIONS, default_factory=dict) def __init__(self, id: Optional[str] = None, @@ -438,54 +255,12 @@ class GetDatabaseResponse(AuditRESTResponse): def get_options(self) -> Dict[str, str]: return self.options or {} - def to_dict(self) -> Dict[str, Any]: - result = { - self.FIELD_ID: self.id, - self.FIELD_NAME: self.name, - self.FIELD_LOCATION: self.location, - self.FIELD_OPTIONS: self.options - } - - if self.owner is not None: - result[self.FIELD_OWNER] = self.owner - if self.created_at is not None: - result[self.FIELD_CREATED_AT] = self.created_at - if self.created_by is not None: - result[self.FIELD_CREATED_BY] = self.created_by - if self.updated_at is not None: - result[self.FIELD_UPDATED_AT] = self.updated_at - if self.updated_by is not None: - result[self.FIELD_UPDATED_BY] = self.updated_by - - return result - - @classmethod - def from_dict(cls, data: Dict[str, Any]) -> 'GetDatabaseResponse': - return cls( - id=data.get(cls.FIELD_ID), - name=data.get(cls.FIELD_NAME), - location=data.get(cls.FIELD_LOCATION), - options=data.get(cls.FIELD_OPTIONS, {}), - owner=data.get(cls.FIELD_OWNER), - created_at=data.get(cls.FIELD_CREATED_AT), - created_by=data.get(cls.FIELD_CREATED_BY), - updated_at=data.get(cls.FIELD_UPDATED_AT), - updated_by=data.get(cls.FIELD_UPDATED_BY) - ) - @dataclass class ConfigResponse(RESTResponse): FILED_DEFAULTS = "defaults" - defaults: Dict[str, str] - - @classmethod - def from_dict(cls, data: Dict[str, str]) -> 'ConfigResponse': - return cls(defaults=data.get(cls.FILED_DEFAULTS)) - - def to_dict(self) -> Dict[str, Any]: - return {self.FILED_DEFAULTS: self.defaults} + defaults: Dict[str, str] = json_field(FILED_DEFAULTS) def merge(self, options: Dict[str, str]) -> Dict[str, str]: 
merged = options.copy() diff --git a/pypaimon/api/api_resquest.py b/pypaimon/api/api_resquest.py index 2d0d2583a8..1434176849 100644 --- a/pypaimon/api/api_resquest.py +++ b/pypaimon/api/api_resquest.py @@ -20,6 +20,8 @@ from abc import ABC from dataclasses import dataclass from typing import Dict, List +from api.rest_json import json_field + class RESTRequest(ABC): pass @@ -27,11 +29,17 @@ class RESTRequest(ABC): @dataclass class CreateDatabaseRequest(RESTRequest): - name: str - properties: Dict[str, str] + FIELD_NAME = "name" + FIELD_OPTIONS = "options" + + name: str = json_field(FIELD_NAME) + options: Dict[str, str] = json_field(FIELD_OPTIONS) @dataclass class AlterDatabaseRequest(RESTRequest): - removals: List[str] - updates: Dict[str, str] + FIELD_REMOVALS = "removals" + FIELD_UPDATES = "updates" + + removals: List[str] = json_field(FIELD_REMOVALS) + updates: Dict[str, str] = json_field(FIELD_UPDATES) diff --git a/pypaimon/api/auth.py b/pypaimon/api/auth.py index 9e88651694..aabefe2d49 100644 --- a/pypaimon/api/auth.py +++ b/pypaimon/api/auth.py @@ -45,6 +45,7 @@ class DLFToken: from api import RESTCatalogOptions self.access_key_id = options.get(RESTCatalogOptions.DLF_ACCESS_KEY_ID) self.access_key_secret = options.get(RESTCatalogOptions.DLF_ACCESS_KEY_SECRET) + self.security_token = options.get(RESTCatalogOptions.DLF_ACCESS_SECURITY_TOKEN) class AuthProvider(ABC): diff --git a/pypaimon/api/client.py b/pypaimon/api/client.py index fc41e2672d..8c4b5b85e0 100644 --- a/pypaimon/api/client.py +++ b/pypaimon/api/client.py @@ -18,9 +18,10 @@ limitations under the License. 
import json import logging +import traceback import urllib.parse from abc import ABC, abstractmethod -from typing import Dict, Optional, Type, TypeVar, Callable +from typing import Dict, Optional, Type, TypeVar, Callable, Any import requests from requests.adapters import HTTPAdapter @@ -38,10 +39,100 @@ class RESTRequest(ABC): class RESTException(Exception): + def __init__(self, message: str = None, *args: Any, cause: Optional[Exception] = None): + if message and args: + try: + formatted_message = message % args + except (TypeError, ValueError): + formatted_message = f"{message} {' '.join(str(arg) for arg in args)}" + else: + formatted_message = message or "REST API error occurred" + + super().__init__(formatted_message) + self.__cause__ = cause + + def get_cause(self) -> Optional[Exception]: + return self.__cause__ + + def get_message(self) -> str: + return str(self) + + def print_stack_trace(self) -> None: + traceback.print_exception(type(self), self, self.__traceback__) + + def get_stack_trace(self) -> str: + return ''.join(traceback.format_exception(type(self), self, self.__traceback__)) + + def __repr__(self) -> str: + if self.__cause__: + return f"{self.__class__.__name__}('{self}', caused by {type(self.__cause__).__name__}: {self.__cause__})" + return f"{self.__class__.__name__}('{self}')" + + +class BadRequestException(RESTException): + + def __init__(self, message: str = None, *args: Any): + super().__init__(message, *args) + + +class BadRequestException(RESTException): + """Exception for 400 Bad Request""" + pass + + +class NotAuthorizedException(RESTException): + """Exception for not authorized (401)""" + + def __init__(self, message: str, *args: Any): + super().__init__(message, *args) + + +class ForbiddenException(RESTException): + """Exception for forbidden access (403)""" + + def __init__(self, message: str, *args: Any): + super().__init__(message, *args) + + +class NoSuchResourceException(RESTException): + """Exception for resource not found 
(404)""" + + def __init__(self, resource_type: Optional[str], resource_name: Optional[str], + message: str, *args: Any): + self.resource_type = resource_type + self.resource_name = resource_name + super().__init__(message, *args) + + +class AlreadyExistsException(RESTException): + """Exception for resource already exists (409)""" + + def __init__(self, resource_type: Optional[str], resource_name: Optional[str], + message: str, *args: Any): + self.resource_type = resource_type + self.resource_name = resource_name + super().__init__(message, *args) + + +class ServiceFailureException(RESTException): + """Exception for service failure (500)""" + + def __init__(self, message: str, *args: Any): + super().__init__(message, *args) + + +class NotImplementedException(RESTException): + """Exception for not implemented (501)""" - def __init__(self, message: str, cause: Optional[Exception] = None): - super().__init__(message) - self.cause = cause + def __init__(self, message: str, *args: Any): + super().__init__(message, *args) + + +class ServiceUnavailableException(RESTException): + """Exception for service unavailable (503)""" + + def __init__(self, message: str, *args: Any): + super().__init__(message, *args) class ErrorHandler(ABC): @@ -51,24 +142,95 @@ class ErrorHandler(ABC): pass +# DefaultErrorHandler implementation class DefaultErrorHandler(ErrorHandler): + """ + Default error handler that converts error responses to appropriate exceptions. + + This class implements the singleton pattern and handles various HTTP error codes + by throwing corresponding exception types. 
+ """ - _instance = None + _instance: Optional['DefaultErrorHandler'] = None + + def __new__(cls) -> 'DefaultErrorHandler': + """Implement singleton pattern""" + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance @classmethod def get_instance(cls) -> 'DefaultErrorHandler': + """Get the singleton instance of DefaultErrorHandler""" if cls._instance is None: cls._instance = cls() return cls._instance def accept(self, error: ErrorResponse, request_id: str) -> None: - message = f"REST API error (request_id: {request_id}): {error.message}" - if error.resource_name: - message += f" (resource: {error.resource_name})" - if error.resource_type: - message += f" (resource_type: {error.resource_type})" + """ + Handle an error response by throwing appropriate exception. + + Args: + error: The error response to handle + request_id: The request ID associated with the error - raise RESTException(message) + Raises: + Appropriate exception based on error code + """ + code = error.code + + # Format message with request ID if not default + if LoggingInterceptor.DEFAULT_REQUEST_ID == request_id: + message = error.message + else: + # If we have a requestId, append it to the message + message = f"{error.message} requestId:{request_id}" + + # Handle different error codes + if code == 400: + raise BadRequestException("%s", message) + + elif code == 401: + raise NotAuthorizedException("Not authorized: %s", message) + + elif code == 403: + raise ForbiddenException("Forbidden: %s", message) + + elif code == 404: + raise NoSuchResourceException( + error.resource_type, + error.resource_name, + "%s", + message + ) + + elif code in [405, 406]: + # These codes are handled but don't throw exceptions + pass + + elif code == 409: + raise AlreadyExistsException( + error.resource_type, + error.resource_name, + "%s", + message + ) + + elif code == 500: + raise ServiceFailureException("Server error: %s", message) + + elif code == 501: + raise 
NotImplementedException(message) + + elif code == 503: + raise ServiceUnavailableException("Service unavailable: %s", message) + + else: + # Default case for unhandled codes + pass + + # If no specific exception was thrown, throw generic RESTException + raise RESTException("Unable to process: %s", message) class ExponentialRetryInterceptor: @@ -97,7 +259,6 @@ class ExponentialRetryInterceptor: class LoggingInterceptor: - REQUEST_ID_KEY = "x-request-id" DEFAULT_REQUEST_ID = "unknown" @@ -250,17 +411,17 @@ class HttpClient(RESTClient): rest_auth_function: Callable[[RESTAuthParameter], Dict[str, str]]) -> T: try: body_str = JSON.to_json(body) - auth_headers = _get_headers(path, "POST", body_str, rest_auth_function) + auth_headers = _get_headers(path, "POST", None, body_str, rest_auth_function) url = self._get_request_url(path, None) - - return self._execute_request("POST", url, data=body_str, headers=auth_headers, - response_type=response_type) - except json.JSONEncodeError as e: - raise RESTException("build request failed.", e) + return self._execute_request("POST", url, data=body_str, headers=auth_headers, response_type=response_type) + except RESTException as e: + raise e + except Exception as e: + raise RESTException("build request failed.", cause=e) def delete(self, path: str, rest_auth_function: Callable[[RESTAuthParameter], Dict[str, str]]) -> T: - auth_headers = _get_headers(path, "DELETE", "", rest_auth_function) + auth_headers = _get_headers(path, "DELETE", None, "", rest_auth_function) url = self._get_request_url(path, None) return self._execute_request("DELETE", url, headers=auth_headers, response_type=None) @@ -327,7 +488,7 @@ class HttpClient(RESTClient): else: raise RESTException("response body is null.") - except RESTException: - raise + except RESTException as e: + raise e except Exception as e: - raise RESTException("rest exception", e) + raise RESTException("rest exception", cause=e) diff --git a/pypaimon/api/data_types.py 
b/pypaimon/api/data_types.py new file mode 100644 index 0000000000..9a8080fdab --- /dev/null +++ b/pypaimon/api/data_types.py @@ -0,0 +1,341 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +import threading +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from typing import Dict, Any, Optional, List, Union + + +class AtomicInteger: + + def __init__(self, initial_value: int = 0): + self._value = initial_value + self._lock = threading.RLock() + + def get(self) -> int: + with self._lock: + return self._value + + def increment_and_get(self) -> int: + with self._lock: + self._value += 1 + return self._value + + def get_and_increment(self) -> int: + with self._lock: + old_value = self._value + self._value += 1 + return old_value + + def set(self, value: int): + with self._lock: + self._value = value + + +class DataType(ABC): + def __init__(self, nullable: bool = True): + self.nullable = nullable + + @abstractmethod + def to_dict(self) -> Dict[str, Any]: + pass + + @abstractmethod + def __str__(self) -> str: + pass + + +@dataclass +class AtomicType(DataType): + type: str + + def __init__(self, type: str, nullable: bool = True): + super().__init__(nullable) + self.type = type 
+ + def to_dict(self) -> Dict[str, Any]: + return { + 'type': self.type, + 'nullable': self.nullable + } + + def __str__(self) -> str: + null_suffix = '' if self.nullable else ' NOT NULL' + return f"{self.type}{null_suffix}" + + +@dataclass +class ArrayType(DataType): + element: DataType + + def __init__(self, nullable: bool, element_type: DataType): + super().__init__(nullable) + self.element = element_type + + def to_dict(self) -> Dict[str, Any]: + return { + 'type': f"ARRAY{'<' + str(self.element) + '>' if self.element else ''}", + 'element': self.element.to_dict() if self.element else None, + 'nullable': self.nullable + } + + def __str__(self) -> str: + null_suffix = '' if self.nullable else ' NOT NULL' + return f"ARRAY<{self.element}>{null_suffix}" + + +@dataclass +class MultisetType(DataType): + element: DataType + + def __init__(self, nullable: bool, element_type: DataType): + super().__init__(nullable) + self.element = element_type + + def to_dict(self) -> Dict[str, Any]: + return { + 'type': f"MULTISET{'<' + str(self.element) + '>' if self.element else ''}", + 'element': self.element.to_dict() if self.element else None, + 'nullable': self.nullable + } + + def __str__(self) -> str: + null_suffix = '' if self.nullable else ' NOT NULL' + return f"MULTISET<{self.element}>{null_suffix}" + + +@dataclass +class MapType(DataType): + key: DataType + value: DataType + + def __init__(self, nullable: bool, key_type: DataType, value_type: DataType): + super().__init__(nullable) + self.key = key_type + self.value = value_type + + def to_dict(self) -> Dict[str, Any]: + return { + 'type': f"MAP<{self.key}, {self.value}>", + 'key': self.key.to_dict() if self.key else None, + 'value': self.value.to_dict() if self.value else None, + 'nullable': self.nullable + } + + def __str__(self) -> str: + null_suffix = '' if self.nullable else ' NOT NULL' + return f"MAP<{self.key}, {self.value}>{null_suffix}" + + +@dataclass +class DataField: + FIELD_ID = "id" + FIELD_NAME = "name" + 
FIELD_TYPE = "type" + FIELD_DESCRIPTION = "description" + FIELD_DEFAULT_VALUE = "defaultValue" + + id: int + name: str + type: DataType + description: Optional[str] = None + default_value: Optional[str] = None + + def __init__(self, id: int, name: str, type: DataType, description: Optional[str] = None, + default_value: Optional[str] = None): + self.id = id + self.name = name + self.type = type + self.description = description + self.default_value = default_value + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'DataField': + return DataTypeParser.parse_data_field(data) + + def to_dict(self) -> Dict[str, Any]: + result = { + self.FIELD_ID: self.id, + self.FIELD_NAME: self.name, + self.FIELD_TYPE: self.type.to_dict() if self.type else None + } + + if self.description is not None: + result[self.FIELD_DESCRIPTION] = self.description + + if self.default_value is not None: + result[self.FIELD_DEFAULT_VALUE] = self.default_value + + return result + + +@dataclass +class RowType(DataType): + fields: List[DataField] + + def __init__(self, nullable: bool, fields: List[DataField]): + super().__init__(nullable) + self.fields = fields or [] + + def to_dict(self) -> Dict[str, Any]: + return { + 'type': 'ROW' + ('' if self.nullable else ' NOT NULL'), + 'fields': [field.to_dict() for field in self.fields], + 'nullable': self.nullable + } + + def __str__(self) -> str: + field_strs = [f"{field.name}: {field.type}" for field in self.fields] + null_suffix = '' if self.nullable else ' NOT NULL' + return f"ROW<{', '.join(field_strs)}>{null_suffix}" + + +class Keyword(Enum): + CHAR = "CHAR" + VARCHAR = "VARCHAR" + STRING = "STRING" + BOOLEAN = "BOOLEAN" + BINARY = "BINARY" + VARBINARY = "VARBINARY" + BYTES = "BYTES" + DECIMAL = "DECIMAL" + NUMERIC = "NUMERIC" + DEC = "DEC" + TINYINT = "TINYINT" + SMALLINT = "SMALLINT" + INT = "INT" + INTEGER = "INTEGER" + BIGINT = "BIGINT" + FLOAT = "FLOAT" + DOUBLE = "DOUBLE" + DATE = "DATE" + TIME = "TIME" + TIMESTAMP = "TIMESTAMP" + 
TIMESTAMP_LTZ = "TIMESTAMP_LTZ" + VARIANT = "VARIANT" + + +class DataTypeParser: + + @staticmethod + def parse_nullability(type_string: str) -> bool: + if ('NOT NULL' in type_string): + return False + elif ('NULL' in type_string): + return True + return True + + @staticmethod + def parse_atomic_type_sql_string(type_string: str) -> DataType: + type_upper = type_string.upper().strip() + + if '(' in type_upper: + base_type = type_upper.split('(')[0] + else: + base_type = type_upper + + try: + Keyword(base_type) + return AtomicType(type_string, DataTypeParser.parse_nullability(type_string)) + except ValueError: + raise Exception(f"Unknown type: {base_type}") + + @staticmethod + def parse_data_type(json_data: Union[Dict[str, Any], str], field_id: Optional[AtomicInteger] = None) -> DataType: + + if isinstance(json_data, str): + return DataTypeParser.parse_atomic_type_sql_string(json_data) + + if isinstance(json_data, dict): + if 'type' not in json_data: + raise ValueError(f"Missing 'type' field in JSON: {json_data}") + + type_string = json_data['type'] + + if type_string.startswith("ARRAY"): + element = DataTypeParser.parse_data_type(json_data.get('element'), field_id) + nullable = 'NOT NULL' not in type_string + return ArrayType(nullable, element) + + elif type_string.startswith("MULTISET"): + element = DataTypeParser.parse_data_type(json_data.get('element'), field_id) + nullable = 'NOT NULL' not in type_string + return MultisetType(nullable, element) + + elif type_string.startswith("MAP"): + key = DataTypeParser.parse_data_type(json_data.get('key'), field_id) + value = DataTypeParser.parse_data_type(json_data.get('value'), field_id) + nullable = 'NOT NULL' not in type_string + return MapType(nullable, key, value) + + elif type_string.startswith("ROW"): + field_array = json_data.get('fields', []) + fields = [] + for field_json in field_array: + fields.append(DataTypeParser.parse_data_field(field_json, field_id)) + nullable = 'NOT NULL' not in type_string + return 
RowType(nullable, fields) + + else: + return DataTypeParser.parse_atomic_type_sql_string(type_string) + + raise ValueError(f"Cannot parse data type: {json_data}") + + @staticmethod + def parse_data_field(json_data: Dict[str, Any], field_id: Optional[AtomicInteger] = None) -> DataField: + + if DataField.FIELD_ID in json_data and json_data[DataField.FIELD_ID] is not None: + if field_id is not None and field_id.get() != -1: + raise ValueError("Partial field id is not allowed.") + field_id_value = int(json_data['id']) + else: + if field_id is None: + raise ValueError("Field ID is required when not provided in JSON") + field_id_value = field_id.increment_and_get() + + if DataField.FIELD_NAME not in json_data: + raise ValueError("Missing 'name' field in JSON") + name = json_data[DataField.FIELD_NAME] + + if DataField.FIELD_TYPE not in json_data: + raise ValueError("Missing 'type' field in JSON") + data_type = DataTypeParser.parse_data_type(json_data[DataField.FIELD_TYPE], field_id) + + description = json_data.get(DataField.FIELD_DESCRIPTION) + + default_value = json_data.get(DataField.FIELD_DEFAULT_VALUE) + + return DataField( + id=field_id_value, + name=name, + type=data_type, + description=description, + default_value=default_value + ) + + +def parse_data_type_from_json(json_str: str, field_id: Optional[AtomicInteger] = None) -> DataType: + json_data = json.loads(json_str) + return DataTypeParser.parse_data_type(json_data, field_id) + + +def parse_data_field_from_json(json_str: str, field_id: Optional[AtomicInteger] = None) -> DataField: + json_data = json.loads(json_str) + return DataTypeParser.parse_data_field(json_data, field_id) diff --git a/pypaimon/api/rest_json.py b/pypaimon/api/rest_json.py index 57bdefdd06..c6a61449d2 100644 --- a/pypaimon/api/rest_json.py +++ b/pypaimon/api/rest_json.py @@ -16,36 +16,69 @@ # under the License. 
def json_field(json_name: str, **kwargs):
    """Create a dataclass field whose JSON key differs from the attribute name.

    The chosen key is stored in field metadata under ``'json_name'`` and is
    honored by both directions of :class:`JSON`.
    """
    return field(metadata={'json_name': json_name}, **kwargs)


class JSON:
    """(De)serializes dataclasses, honoring per-field ``json_name`` metadata."""

    @staticmethod
    def to_json(obj: Any, **kwargs) -> str:
        """Serialize *obj* to a JSON string.

        Dataclass instances are converted through their field metadata.
        Anything else (lists, dicts, primitives) is handed to ``json.dumps``
        directly — the previous implementation raised TypeError for
        non-dataclass input.
        """
        payload = JSON.__to_dict(obj) if is_dataclass(obj) else obj
        return json.dumps(payload, ensure_ascii=False, **kwargs)

    @staticmethod
    def from_json(json_str: str, target_class: Type[T]) -> T:
        """Instantiate *target_class* from a JSON string, mapping custom
        JSON names back to dataclass attribute names."""
        data = json.loads(json_str)
        return JSON.__from_dict(data, target_class)

    @staticmethod
    def __to_dict(obj: Any) -> Dict[str, Any]:
        """Convert a dataclass instance to a dict keyed by JSON names."""
        result = {}
        for field_info in fields(obj):
            field_value = getattr(obj, field_info.name)
            # Custom JSON name from metadata; falls back to the attribute name.
            json_name = field_info.metadata.get('json_name', field_info.name)

            if is_dataclass(field_value):
                result[json_name] = JSON.__to_dict(field_value)
            elif hasattr(field_value, 'to_dict'):
                result[json_name] = field_value.to_dict()
            elif isinstance(field_value, list):
                result[json_name] = [
                    item.to_dict() if hasattr(item, 'to_dict') else item
                    for item in field_value
                ]
            else:
                result[json_name] = field_value
        return result

    @staticmethod
    def __from_dict(data: Dict[str, Any], target_class: Type[T]) -> T:
        """Build *target_class* from a dict, mapping json_name -> field name."""
        field_mapping = {
            field_info.metadata.get('json_name', field_info.name): field_info.name
            for field_info in fields(target_class)
        }
        # Unknown JSON keys are dropped deliberately (forward compatibility
        # with newer server responses).
        kwargs = {
            field_mapping[json_name]: value
            for json_name, value in data.items()
            if json_name in field_mapping
        }
        return target_class(**kwargs)
DataTypeParser, AtomicType, ArrayType, MapType, RowType @dataclass @@ -743,7 +743,94 @@ class RESTCatalogServer: class ApiTestCase(unittest.TestCase): - def test(self): + def test_parse_data(self): + simple_type_test_cases = [ + "DECIMAL", + "DECIMAL(5)", + "DECIMAL(10, 2)", + "DECIMAL(38, 18)", + "VARBINARY", + "VARBINARY(100)", + "VARBINARY(1024)", + "BYTES", + "VARCHAR(255)", + "CHAR(10)", + "INT", + "BOOLEAN" + ] + for type_str in simple_type_test_cases: + data_type = DataTypeParser.parse_data_type(type_str) + self.assertEqual(data_type.nullable, True) + self.assertEqual(data_type.type, type_str) + field_id = AtomicInteger(0) + simple_type = DataTypeParser.parse_data_type("VARCHAR(32)") + self.assertEqual(simple_type.nullable, True) + self.assertEqual(simple_type.type, 'VARCHAR(32)') + + array_json = { + "type": "ARRAY", + "element": "INT" + } + array_type = DataTypeParser.parse_data_type(array_json, field_id) + self.assertEqual(array_type.element.type, 'INT') + + map_json = { + "type": "MAP", + "key": "STRING", + "value": "INT" + } + map_type = DataTypeParser.parse_data_type(map_json, field_id) + self.assertEqual(map_type.key.type, 'STRING') + self.assertEqual(map_type.value.type, 'INT') + row_json = { + "type": "ROW", + "fields": [ + { + "name": "id", + "type": "BIGINT", + "description": "Primary key" + }, + { + "name": "name", + "type": "VARCHAR(100)", + "description": "User name" + }, + { + "name": "scores", + "type": { + "type": "ARRAY", + "element": "DOUBLE" + } + } + ] + } + + row_type: RowType = DataTypeParser.parse_data_type(row_json, AtomicInteger(0)) + self.assertEqual(row_type.fields[0].type.type, 'BIGINT') + self.assertEqual(row_type.fields[1].type.type, 'VARCHAR(100)') + + complex_json = { + "type": "ARRAY", + "element": { + "type": "MAP", + "key": "STRING", + "value": { + "type": "ROW", + "fields": [ + {"name": "count", "type": "BIGINT"}, + {"name": "percentage", "type": "DOUBLE"} + ] + } + } + } + + complex_type: ArrayType = 
DataTypeParser.parse_data_type(complex_json, field_id) + element_type: MapType = complex_type.element + value_type: RowType = element_type.value + self.assertEqual(value_type.fields[0].type.type, 'BIGINT') + self.assertEqual(value_type.fields[1].type.type, 'DOUBLE') + + def test_api(self): """Example usage of RESTCatalogServer""" # Setup logging logging.basicConfig(level=logging.INFO) @@ -773,11 +860,14 @@ class ApiTestCase(unittest.TestCase): "test_db2": server.mock_database("test_db2", {"env": "test"}), "prod_db": server.mock_database("prod_db", {"env": "prod"}) } + data_fields = [ + DataField( 0, "name", AtomicType('INT'), 'desc name'), + DataField( 1, "arr11", ArrayType(True, AtomicType('INT')), 'desc arr11'), + DataField( 2, "map11", MapType(False, AtomicType('INT'), MapType(False, AtomicType('INT'), AtomicType('INT'))), 'desc arr11'), + ] + schema = TableSchema(len(data_fields), data_fields, len(data_fields), [], [], {}, "") test_tables = { - "default.user": TableMetadata(uuid=str(uuid.uuid4()), is_external=True, - schema=TableSchema(1, - [DataField("name", 0, "name", PaimonDataType('int'))], - 1, [], [], {}, "")), + "default.user": TableMetadata(uuid=str(uuid.uuid4()), is_external=True,schema=schema), } server.table_metadata_store.update(test_tables) server.database_store.update(test_databases)