Copilot commented on code in PR #62232:
URL: https://github.com/apache/airflow/pull/62232#discussion_r2849613089


##########
providers/common/ai/src/airflow/providers/common/ai/datafusion/format_handlers.py:
##########
@@ -0,0 +1,113 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.common.ai.config import FormatType
+from airflow.providers.common.ai.datafusion.base import FormatHandler
+from airflow.providers.common.ai.exceptions import 
FileFormatRegistrationException
+
+if TYPE_CHECKING:
+    from datafusion import SessionContext
+
+
+class ParquetFormatHandler(FormatHandler):
+    """
+    Parquet format handler.
+
+    :param options: Additional options for the Parquet format.
+        
https://datafusion.apache.org/python/autoapi/datafusion/context/index.html#datafusion.context.SessionContext.register_parquet
+    """
+
+    def __init__(self, options: dict[str, Any] | None = None):
+        self.options = options or {}
+
+    @property
+    def get_format(self) -> str:
+        """Return the format type."""
+        return FormatType.PARQUET.value
+
+    def register_data_source_format(self, ctx: SessionContext, table_name: 
str, path: str):
+        """Register a data source format."""
+        try:
+            ctx.register_parquet(table_name, path, **self.options)
+        except Exception as e:
+            raise FileFormatRegistrationException("Failed to register Parquet 
data source: %s", e)
+

Review Comment:
   `raise FileFormatRegistrationException("Failed to register Parquet data 
source: %s", e)` passes multiple args to the exception (resulting in an odd 
tuple-like message) and drops exception chaining. Prefer interpolating the 
message (or using `from e`) so the resulting exception message is correct and 
the original stack trace is preserved.



##########
providers/common/ai/src/airflow/providers/common/ai/datafusion/object_storage_provider.py:
##########
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datafusion.object_store import AmazonS3, LocalFileSystem
+
+from airflow.providers.common.ai.config import ConnectionConfig, StorageType
+from airflow.providers.common.ai.datafusion.base import ObjectStorageProvider
+from airflow.providers.common.ai.exceptions import ObjectStoreCreationException
+
+
+class S3ObjectStorageProvider(ObjectStorageProvider):
+    """S3 Object Storage Provider using DataFusion's AmazonS3."""
+
+    def get_storage_type(self) -> str:
+        """Return the storage type."""
+        return StorageType.S3.value
+
+    def create_object_store(self, path: str, connection_config: 
ConnectionConfig | None = None):
+        """Create an S3 object store using DataFusion's AmazonS3."""
+        if connection_config is None:
+            raise ValueError("connection_config must be provided")
+
+        try:
+            credentials = connection_config.credentials
+            bucket = self.get_bucket(path)
+
+            s3_store = AmazonS3(**credentials, 
**connection_config.extra_config, bucket_name=bucket)
+            self.log.info("Created S3 object store for bucket %s", bucket)
+
+            return s3_store
+
+        except Exception as e:
+            raise ObjectStoreCreationException("Failed to create S3 object 
store", e)

Review Comment:
   `raise ObjectStoreCreationException("Failed to create S3 object store", e)` 
passes the original exception as a second positional arg, producing an odd 
message and not chaining the exception. Prefer `raise 
ObjectStoreCreationException(f"Failed to create S3 object store: {e}") from e` 
(or similar) so the message and traceback are correct.
   ```suggestion
               raise ObjectStoreCreationException(f"Failed to create S3 object 
store: {e}") from e
   ```



##########
providers/common/ai/src/airflow/providers/common/ai/datafusion/object_storage_provider.py:
##########
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datafusion.object_store import AmazonS3, LocalFileSystem
+
+from airflow.providers.common.ai.config import ConnectionConfig, StorageType
+from airflow.providers.common.ai.datafusion.base import ObjectStorageProvider
+from airflow.providers.common.ai.exceptions import ObjectStoreCreationException
+
+
+class S3ObjectStorageProvider(ObjectStorageProvider):
+    """S3 Object Storage Provider using DataFusion's AmazonS3."""
+
+    def get_storage_type(self) -> str:
+        """Return the storage type."""
+        return StorageType.S3.value
+
+    def create_object_store(self, path: str, connection_config: 
ConnectionConfig | None = None):
+        """Create an S3 object store using DataFusion's AmazonS3."""
+        if connection_config is None:
+            raise ValueError("connection_config must be provided")
+
+        try:
+            credentials = connection_config.credentials
+            bucket = self.get_bucket(path)
+
+            s3_store = AmazonS3(**credentials, 
**connection_config.extra_config, bucket_name=bucket)
+            self.log.info("Created S3 object store for bucket %s", bucket)
+
+            return s3_store
+
+        except Exception as e:
+            raise ObjectStoreCreationException("Failed to create S3 object 
store", e)
+
+    def get_scheme(self) -> str:
+        """Return the scheme for S3."""
+        return "s3://"
+
+
+class LocalObjectStorageProvider(ObjectStorageProvider):
+    """Local Object Storage Provider using DataFusion's LocalFileSystem."""
+
+    def get_storage_type(self) -> str:
+        """Return the storage type."""
+        return StorageType.LOCAL.value
+
+    def create_object_store(self, path: str, connection_config: 
ConnectionConfig | None = None):
+        """Create a Local object store."""
+        return LocalFileSystem()
+
+    def get_scheme(self) -> str:
+        """Return the scheme to a Local file system."""
+        return "file://"
+
+
+class ObjectStorageProviderFactory:
+    """Factory to create object storage providers based on storage type."""
+
+    # TODO: Add support for GCS, Azure, HTTP: 
https://datafusion.apache.org/python/autoapi/datafusion/object_store/index.html
+    _providers: dict[str, type] = {
+        StorageType.S3: S3ObjectStorageProvider,
+        StorageType.LOCAL: LocalObjectStorageProvider,
+    }
+
+    @classmethod
+    def create_provider(cls, storage_type: StorageType) -> 
ObjectStorageProvider:
+        """Create a storage provider instance."""
+        if storage_type not in cls._providers:
+            raise ValueError(
+                f"Unsupported storage type: {storage_type}. Supported types: 
{list(cls._providers.keys())}"

Review Comment:
   `ObjectStorageProviderFactory._providers` is typed as `dict[str, type]` but 
keyed with `StorageType` enum members, and `create_provider()` only accepts 
`StorageType`. Since callers/tests may reasonably pass strings (e.g. from 
config/YAML), consider accepting `str | StorageType` and normalizing (or change 
the mapping keys to `StorageType` consistently and update tests/type hints).
   ```suggestion
       _providers: dict[StorageType, type[ObjectStorageProvider]] = {
           StorageType.S3: S3ObjectStorageProvider,
           StorageType.LOCAL: LocalObjectStorageProvider,
       }
   
       @classmethod
       def create_provider(cls, storage_type: StorageType | str) -> 
ObjectStorageProvider:
           """Create a storage provider instance."""
           from airflow.providers.common.ai.config import StorageType as 
_StorageType
   
           if isinstance(storage_type, str):
               try:
                   storage_type = _StorageType(storage_type)
               except ValueError as e:
                   supported = [st.value for st in _StorageType]
                   raise ValueError(
                       f"Unsupported storage type: {storage_type}. Supported 
types: {supported}"
                   ) from e
   
           if storage_type not in cls._providers:
               supported = [st.value for st in cls._providers.keys()]
               raise ValueError(
                   f"Unsupported storage type: {storage_type}. Supported types: 
{supported}"
   ```



##########
providers/common/ai/src/airflow/providers/common/ai/utils/mixins.py:
##########
@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.common.ai.config import ConnectionConfig
+from airflow.providers.common.compat.sdk import BaseHook, Connection
+
+if TYPE_CHECKING:
+    from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+log = logging.getLogger(__name__)
+
+
+class CommonAIHookMixin:
+    """Mixin for Common AI."""
+
+    def get_db_api_hook(self, conn_id: str) -> DbApiHook:
+        """Get the given connection's database hook."""
+        connection = BaseHook.get_connection(conn_id)
+        return connection.get_hook()
+
+    def get_conn_config_from_airflow_connection(self, conn_id: str) -> 
ConnectionConfig:
+        """Get connection configuration from Airflow connection."""
+        try:
+            airflow_conn = BaseHook.get_connection(conn_id)
+
+            config = self._convert_airflow_connection(airflow_conn)
+
+            log.info("Loaded connection config for: %s", conn_id)
+            return config
+
+        except Exception as e:
+            log.error("Failed to get connection config for %s: %s", conn_id, e)
+            raise
+
+    def _convert_airflow_connection(self, conn: Connection) -> 
ConnectionConfig:
+        """Convert Airflow connection to ConnectionConfig."""
+        credentials = self._get_credentials(conn)
+
+        extra_config = conn.extra_dejson if conn.extra else {}
+
+        return ConnectionConfig(
+            conn_id=conn.conn_id,
+            credentials=credentials,
+            extra_config=extra_config,
+        )
+
+    @classmethod
+    def remove_none_values(cls, params: dict[str, Any]) -> dict[str, Any]:
+        """Filter out None values from the dictionary."""
+        return {k: v for k, v in params.items() if v is not None}
+
+    def _get_credentials(self, conn: Connection) -> dict[str, Any]:
+        from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+        credentials = {}
+
+        match conn.conn_type:
+            case "aws":
+                s3_conn: AwsGenericHook = 
AwsGenericHook(aws_conn_id=conn.conn_id, client_type="s3")
+                creds = s3_conn.get_credentials()

Review Comment:
   `AwsGenericHook` is imported unconditionally inside `_get_credentials()`. If 
users install this provider without the `amazon` extra, using an `aws` 
connection will raise a raw `ImportError`. Consider catching `ImportError` here 
and raising `AirflowOptionalProviderFeatureException` with a clear message 
about installing `apache-airflow-providers-common-ai[amazon]` (or 
`apache-airflow-providers-amazon`).



##########
providers/common/ai/src/airflow/providers/common/ai/operators/analytics.py:
##########
@@ -0,0 +1,163 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+from collections.abc import Sequence
+from functools import cached_property
+from typing import TYPE_CHECKING, Any, Literal
+
+from airflow.providers.common.ai.datafusion.engine import DataFusionEngine
+from airflow.providers.common.ai.utils.mixins import CommonAIHookMixin
+from airflow.sdk import BaseOperator, Context
+
+if TYPE_CHECKING:
+    from airflow.providers.common.ai.config import DataSourceConfig
+
+
+class AnalyticsOperator(BaseOperator, CommonAIHookMixin):
+    """
+    Operator to run queries on various datasource's stored in object stores 
like S3, GCS, Azure, etc.
+
+    :param datasource_configs: List of datasource configurations to register.
+    :param queries: List of SQL queries to execute.
+    :param max_rows_check: Maximum number of rows allowed in query results. 
Queries exceeding this will be skipped.
+    :param engine: Optional DataFusion engine instance.
+    :param result_output_format: List of output formats for results. 
Supported: 'tabulate', 'json'. Default is 'tabulate'.
+    """
+
+    template_fields: Sequence[str] = (
+        "datasource_configs",
+        "queries",
+        "max_rows_check",
+        "result_output_format",
+    )
+
+    def __init__(
+        self,
+        datasource_configs: list[DataSourceConfig],
+        queries: list[str],
+        max_rows_check: int = 100,
+        engine: DataFusionEngine | None = None,
+        result_output_format: Literal["tabulate", "json"] = "tabulate",
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.datasource_configs = datasource_configs
+        self.queries = queries
+        self.engine = engine
+        self.max_rows_check = max_rows_check
+        self.result_output_format = result_output_format
+
+    @cached_property
+    def _df_engine(self):
+        if self.engine is None:
+            return DataFusionEngine()
+        return self.engine
+
+    def execute(self, context: Context) -> str:
+
+        results = []
+        for datasource_config in self.datasource_configs:
+            connection_config = 
self.get_conn_config_from_airflow_connection(datasource_config.conn_id)

Review Comment:
   `execute()` unconditionally calls 
`get_conn_config_from_airflow_connection(datasource_config.conn_id)`. For local 
filesystem examples/tests you pass `conn_id=""`, which will call 
`BaseHook.get_connection("")` and fail at runtime. Consider treating a falsy 
`conn_id` as “no connection needed” and using an empty `ConnectionConfig` (or 
only resolving connections for non-local/object-store types that require it).
   ```suggestion
               # Treat a falsy conn_id (e.g. "", None) as "no connection needed"
               if getattr(datasource_config, "conn_id", None):
                   connection_config = 
self.get_conn_config_from_airflow_connection(datasource_config.conn_id)
               else:
                   connection_config = None
   ```



##########
providers/common/ai/src/airflow/providers/common/ai/operators/analytics.py:
##########
@@ -0,0 +1,163 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+from collections.abc import Sequence
+from functools import cached_property
+from typing import TYPE_CHECKING, Any, Literal
+
+from airflow.providers.common.ai.datafusion.engine import DataFusionEngine
+from airflow.providers.common.ai.utils.mixins import CommonAIHookMixin
+from airflow.sdk import BaseOperator, Context
+
+if TYPE_CHECKING:
+    from airflow.providers.common.ai.config import DataSourceConfig
+
+
+class AnalyticsOperator(BaseOperator, CommonAIHookMixin):
+    """
+    Operator to run queries on various datasource's stored in object stores 
like S3, GCS, Azure, etc.
+
+    :param datasource_configs: List of datasource configurations to register.
+    :param queries: List of SQL queries to execute.
+    :param max_rows_check: Maximum number of rows allowed in query results. 
Queries exceeding this will be skipped.
+    :param engine: Optional DataFusion engine instance.
+    :param result_output_format: List of output formats for results. 
Supported: 'tabulate', 'json'. Default is 'tabulate'.
+    """
+
+    template_fields: Sequence[str] = (
+        "datasource_configs",
+        "queries",
+        "max_rows_check",
+        "result_output_format",
+    )
+
+    def __init__(
+        self,
+        datasource_configs: list[DataSourceConfig],
+        queries: list[str],
+        max_rows_check: int = 100,
+        engine: DataFusionEngine | None = None,
+        result_output_format: Literal["tabulate", "json"] = "tabulate",
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.datasource_configs = datasource_configs
+        self.queries = queries
+        self.engine = engine
+        self.max_rows_check = max_rows_check
+        self.result_output_format = result_output_format
+
+    @cached_property
+    def _df_engine(self):
+        if self.engine is None:
+            return DataFusionEngine()
+        return self.engine
+
+    def execute(self, context: Context) -> str:
+
+        results = []
+        for datasource_config in self.datasource_configs:
+            connection_config = 
self.get_conn_config_from_airflow_connection(datasource_config.conn_id)
+            self._df_engine.register_datasource(datasource_config, 
connection_config)
+
+        # TODO make it parallel as there is no dependency between queries
+        for query in self.queries:
+            result_dict = self._df_engine.execute_query(query)
+            results.append({"query": query, "data": result_dict})
+
+        match self.result_output_format:
+            case "tabulate":
+                return self._build_tabulate_output(results)
+            case "json":
+                return self._build_json_output(results)
+            case _:
+                raise ValueError(f"Unsupported output format: 
{self.result_output_format}")
+
+    def _is_result_too_large(self, result_dict: dict[str, Any]) -> tuple[bool, 
int]:
+        """Check if a result exceeds the max_rows_check limit."""
+        if not result_dict:
+            return False, 0
+        num_rows = len(next(iter(result_dict.values())))
+        max_rows_exceeded = num_rows >= self.max_rows_check
+        if max_rows_exceeded:
+            self.log.warning(
+                "Query returned %s rows, exceeding max_rows_check (%s). 
Skipping result output as large datasets are unsuitable for return.",
+                num_rows,
+                self.max_rows_check,
+            )
+        return max_rows_exceeded, num_rows

Review Comment:
   `_is_result_too_large()` treats a result with exactly `max_rows_check` rows 
as too large (`>=`). The docstring says queries "exceeding" this number will be 
skipped, and the name `max_rows_check` suggests the limit should be inclusive. 
Consider changing the condition to `num_rows > self.max_rows_check` (and keep 
the warning message consistent).



##########
providers/common/ai/src/airflow/providers/common/ai/datafusion/format_handlers.py:
##########
@@ -0,0 +1,113 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.common.ai.config import FormatType
+from airflow.providers.common.ai.datafusion.base import FormatHandler
+from airflow.providers.common.ai.exceptions import 
FileFormatRegistrationException
+
+if TYPE_CHECKING:
+    from datafusion import SessionContext
+
+
+class ParquetFormatHandler(FormatHandler):
+    """
+    Parquet format handler.
+
+    :param options: Additional options for the Parquet format.
+        
https://datafusion.apache.org/python/autoapi/datafusion/context/index.html#datafusion.context.SessionContext.register_parquet
+    """
+
+    def __init__(self, options: dict[str, Any] | None = None):
+        self.options = options or {}
+
+    @property
+    def get_format(self) -> str:
+        """Return the format type."""
+        return FormatType.PARQUET.value
+
+    def register_data_source_format(self, ctx: SessionContext, table_name: 
str, path: str):
+        """Register a data source format."""
+        try:
+            ctx.register_parquet(table_name, path, **self.options)
+        except Exception as e:
+            raise FileFormatRegistrationException("Failed to register Parquet 
data source: %s", e)
+
+
+class CsvFormatHandler(FormatHandler):
+    """
+    CSV format handler.
+
+    :param options: Additional options for the CSV format.
+        
https://datafusion.apache.org/python/autoapi/datafusion/context/index.html#datafusion.context.SessionContext.register_csv
+    """
+
+    def __init__(self, options: dict[str, Any] | None = None):
+        self.options = options or {}
+
+    @property
+    def get_format(self) -> str:
+        """Return the format type."""
+        return FormatType.CSV.value
+
+    def register_data_source_format(self, ctx: SessionContext, table_name: 
str, path: str):
+        """Register a data source format."""
+        try:
+            ctx.register_csv(table_name, path, **self.options)
+        except Exception as e:
+            raise FileFormatRegistrationException("Failed to register csv data 
source: %s", e)
+
+
+class AvroFormatHandler(FormatHandler):
+    """
+    Avro format handler.
+
+    :param options: Additional options for the Avro format.
+        
https://datafusion.apache.org/python/autoapi/datafusion/context/index.html#datafusion.context.SessionContext.register_avro
+    """
+
+    def __init__(self, options: dict[str, Any] | None = None):
+        self.options = options or {}
+
+    @property
+    def get_format(self) -> str:
+        """Return the format type."""
+        return FormatType.AVRO.value
+
+    def register_data_source_format(self, ctx: SessionContext, table_name: 
str, path: str) -> None:
+        """Register a data source format."""
+        try:
+            ctx.register_avro(table_name, path, **self.options)
+        except Exception as e:
+            raise FileFormatRegistrationException("Failed to register Avro 
data source: %s", e)
+

Review Comment:
   Same issue as above: `FileFormatRegistrationException("Failed to register 
Avro data source: %s", e)` passes multiple args and loses exception chaining. 
Prefer a formatted message and `raise ... from e`.



##########
providers/common/ai/tests/unit/common/ai/datafusion/test_object_storage_provider.py:
##########
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import patch
+
+import pytest
+
+from airflow.providers.common.ai.config import ConnectionConfig, StorageType
+from airflow.providers.common.ai.datafusion.object_storage_provider import (
+    LocalObjectStorageProvider,
+    ObjectStorageProviderFactory,
+    S3ObjectStorageProvider,
+)
+from airflow.providers.common.ai.exceptions import ObjectStoreCreationException
+
+
+class TestObjectStorageProvider:
+    
@patch("airflow.providers.common.ai.datafusion.object_storage_provider.AmazonS3")
+    def test_s3_provider_success(self, mock_s3):
+        provider = S3ObjectStorageProvider()
+        connection_config = ConnectionConfig(
+            conn_id="aws_default",
+            credentials={"access_key_id": "fake_key", "secret_access_key": 
"fake_secret"},
+        )
+
+        store = provider.create_object_store("s3://demo-data/path", 
connection_config)
+
+        mock_s3.assert_called_once_with(
+            access_key_id="fake_key", secret_access_key="fake_secret", 
bucket_name="demo-data"
+        )
+        assert store == mock_s3.return_value
+        assert provider.get_storage_type() == StorageType.S3.value
+        assert provider.get_scheme() == "s3://"
+
+    def test_s3_provider_failure(self):
+        provider = S3ObjectStorageProvider()
+        connection_config = ConnectionConfig(conn_id="aws_default")
+
+        with patch(
+            
"airflow.providers.common.ai.datafusion.object_storage_provider.AmazonS3",
+            side_effect=Exception("Error"),
+        ):
+            with pytest.raises(ObjectStoreCreationException, match="Failed to 
create S3 object store"):
+                provider.create_object_store("s3://demo-data/path", 
connection_config)
+
+    
@patch("airflow.providers.common.ai.datafusion.object_storage_provider.LocalFileSystem")
+    def test_local_provider(self, mock_local):
+        provider = LocalObjectStorageProvider()
+        assert provider.get_storage_type() == StorageType.LOCAL.value
+        assert provider.get_scheme() == "file://"
+        local_store = provider.create_object_store("file://path")
+        assert local_store == mock_local.return_value
+
+    def test_factory_create_provider(self):
+        assert isinstance(
+            
ObjectStorageProviderFactory.create_provider(StorageType.S3.value), 
S3ObjectStorageProvider
+        )
+        assert isinstance(
+            
ObjectStorageProviderFactory.create_provider(StorageType.LOCAL.value), 
LocalObjectStorageProvider

Review Comment:
   `ObjectStorageProviderFactory.create_provider()` is called here with 
`StorageType.S3.value` / `StorageType.LOCAL.value` (strings), but the factory 
implementation expects a `StorageType` enum and its `_providers` keys are 
`StorageType` members. This test will fail. Pass the enum values directly (e.g. 
`StorageType.S3`) or update the factory to accept strings and normalize to 
`StorageType`.
   ```suggestion
               ObjectStorageProviderFactory.create_provider(StorageType.S3), 
S3ObjectStorageProvider
           )
           assert isinstance(
               ObjectStorageProviderFactory.create_provider(StorageType.LOCAL), 
LocalObjectStorageProvider
   ```



##########
airflow-core/tests/unit/always/test_project_structure.py:
##########
@@ -98,6 +98,8 @@ def test_providers_modules_should_have_tests(self):
             
"providers/cncf/kubernetes/tests/unit/cncf/kubernetes/utils/test_delete_from.py",
             
"providers/cncf/kubernetes/tests/unit/cncf/kubernetes/utils/test_k8s_hashlib_wrapper.py",
             
"providers/cncf/kubernetes/tests/unit/cncf/kubernetes/utils/test_xcom_sidecar.py",
+            "providers/common/ai/tests/unit/common/ai/datafusion/test_base.py",
+            "providers/common/ai/tests/unit/common/ai/test_exceptions.py",

Review Comment:
   The project structure test now expects 
`providers/common/ai/tests/unit/common/ai/datafusion/test_base.py` and 
`providers/common/ai/tests/unit/common/ai/test_exceptions.py`, but those test 
files are not present in this PR. This will make 
`test_providers_modules_should_have_tests` fail. Either add the missing tests 
or remove these entries (and list the actual test files that exist).
   ```suggestion
   
   ```



##########
providers/common/ai/pyproject.toml:
##########
@@ -59,13 +59,30 @@ requires-python = ">=3.10"
 # After you modify the dependencies, and rebuild your Breeze CI image with 
``breeze ci-image build``
 dependencies = [
     "apache-airflow>=3.0.0",
+    "datafusion>=50.0.0"
+]
+
+# The optional dependencies should be modified in place in the generated file
+# Any change in the dependencies is preserved when the file is regenerated
+[project.optional-dependencies]
+"amazon" = [
+    "apache-airflow-providers-amazon"
+]
+"common.compat" = [
+    "apache-airflow-providers-common-compat"
+]
+"common.sql" = [
+    "apache-airflow-providers-common-sql"
 ]

Review Comment:
   `airflow.providers.common.compat.sdk` is imported by this provider at 
runtime (e.g. in `utils/mixins.py` and `exceptions.py`), so 
`apache-airflow-providers-common-compat` must be a required dependency (as in 
other providers), not only an optional extra. Without it, importing the 
provider will fail for users who install `apache-airflow-providers-common-ai` 
without extras.



##########
providers/common/ai/src/airflow/providers/common/ai/datafusion/format_handlers.py:
##########
@@ -0,0 +1,113 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.common.ai.config import FormatType
+from airflow.providers.common.ai.datafusion.base import FormatHandler
+from airflow.providers.common.ai.exceptions import 
FileFormatRegistrationException
+
+if TYPE_CHECKING:
+    from datafusion import SessionContext
+
+
class ParquetFormatHandler(FormatHandler):
    """
    Parquet format handler.

    :param options: Additional options for the Parquet format.
        https://datafusion.apache.org/python/autoapi/datafusion/context/index.html#datafusion.context.SessionContext.register_parquet
    """

    def __init__(self, options: dict[str, Any] | None = None):
        self.options = options or {}

    @property
    def get_format(self) -> str:
        """Return the format type."""
        return FormatType.PARQUET.value

    def register_data_source_format(self, ctx: SessionContext, table_name: str, path: str):
        """
        Register a Parquet data source as a table in the DataFusion session.

        :param ctx: DataFusion session context to register the table in.
        :param table_name: Name under which the data source is exposed to SQL.
        :param path: Path of the Parquet data source.
        :raises FileFormatRegistrationException: if DataFusion fails to register the source.
        """
        try:
            ctx.register_parquet(table_name, path, **self.options)
        except Exception as e:
            # Exceptions do not lazily format %-style args like loggers do:
            # format the message eagerly and chain the cause so the original
            # traceback is preserved for debugging.
            raise FileFormatRegistrationException(f"Failed to register Parquet data source: {e}") from e
+
class CsvFormatHandler(FormatHandler):
    """
    CSV format handler.

    :param options: Additional options for the CSV format.
        https://datafusion.apache.org/python/autoapi/datafusion/context/index.html#datafusion.context.SessionContext.register_csv
    """

    def __init__(self, options: dict[str, Any] | None = None):
        self.options = options or {}

    @property
    def get_format(self) -> str:
        """Return the format type."""
        return FormatType.CSV.value

    def register_data_source_format(self, ctx: SessionContext, table_name: str, path: str):
        """
        Register a CSV data source as a table in the DataFusion session.

        :param ctx: DataFusion session context to register the table in.
        :param table_name: Name under which the data source is exposed to SQL.
        :param path: Path of the CSV data source.
        :raises FileFormatRegistrationException: if DataFusion fails to register the source.
        """
        try:
            ctx.register_csv(table_name, path, **self.options)
        except Exception as e:
            # Exceptions do not lazily format %-style args like loggers do:
            # format the message eagerly and chain the cause so the original
            # traceback is preserved. "CSV" casing matches the Parquet handler.
            raise FileFormatRegistrationException(f"Failed to register CSV data source: {e}") from e

Review Comment:
   Same issue as above: `FileFormatRegistrationException("Failed to register 
csv data source: %s", e)` passes multiple args and loses exception chaining. 
Prefer a formatted message and `raise ... from e`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to