kaxil commented on code in PR #64077: URL: https://github.com/apache/airflow/pull/64077#discussion_r2991684800
########## providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_file_analysis.py: ########## @@ -0,0 +1,133 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Example DAGs demonstrating LLMFileAnalysisOperator usage.""" + +from __future__ import annotations + +from pydantic import BaseModel + +from airflow.providers.common.ai.operators.llm_file_analysis import LLMFileAnalysisOperator +from airflow.providers.common.compat.sdk import dag, task + + +class FileAnalysisSummary(BaseModel): Review Comment: Module-level `BaseModel` in example DAGs breaks Sphinx autoapi -- the inherited internal docstrings produce invalid RST. Move the class inside the function that uses it: ```python @dag def example_llm_file_analysis_structured(): class FileAnalysisSummary(BaseModel): findings: list[str] highest_severity: str truncated_inputs: bool LLMFileAnalysisOperator( ... output_type=FileAnalysisSummary, ) ``` This matches the pattern used in `example_llm_analysis_pipeline.py`. ########## providers/common/ai/src/airflow/providers/common/ai/utils/file_analysis.py: ########## @@ -0,0 +1,673 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Helpers for building file-analysis prompts for LLM operators.""" + +from __future__ import annotations + +import csv +import gzip +import io +import json +import logging +from bisect import insort +from dataclasses import dataclass +from pathlib import PurePosixPath +from typing import TYPE_CHECKING, Any + +from pydantic_ai.messages import BinaryContent + +from airflow.providers.common.ai.exceptions import ( + LLMFileAnalysisLimitExceededError, + LLMFileAnalysisMultimodalRequiredError, + LLMFileAnalysisUnsupportedFormatError, +) +from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException, ObjectStoragePath + +if TYPE_CHECKING: + from collections.abc import Sequence + + from pydantic_ai.messages import UserContent + +SUPPORTED_FILE_FORMATS: tuple[str, ...] 
= ( + "avro", + "csv", + "jpeg", + "jpg", + "json", + "log", + "parquet", + "pdf", + "png", +) + +_TEXT_LIKE_FORMATS = frozenset({"csv", "json", "log", "avro", "parquet"}) +_MULTI_MODAL_FORMATS = frozenset({"jpeg", "jpg", "pdf", "png"}) +_COMPRESSION_SUFFIXES = { + "bz2": "bzip2", + "gz": "gzip", + "snappy": "snappy", + "xz": "xz", + "zst": "zstd", +} +_GZIP_SUPPORTED_FORMATS = frozenset({"csv", "json", "log"}) +_TEXT_SAMPLE_HEAD_CHARS = 8_000 +_TEXT_SAMPLE_TAIL_CHARS = 2_000 +_MEDIA_TYPES = { + "jpeg": "image/jpeg", + "jpg": "image/jpeg", + "pdf": "application/pdf", + "png": "image/png", +} +log = logging.getLogger(__name__) + + +@dataclass +class FileAnalysisRequest: + """Prepared prompt content and discovery metadata for the file-analysis operator.""" + + user_content: str | Sequence[UserContent] + resolved_paths: list[str] + total_size_bytes: int + omitted_files: int = 0 + text_truncated: bool = False + attachment_count: int = 0 + text_file_count: int = 0 + + +@dataclass +class _PreparedFile: + path: ObjectStoragePath + file_format: str + size_bytes: int + compression: str | None + partitions: tuple[str, ...] 
+ estimated_rows: int | None = None + text_content: str | None = None + attachment: BinaryContent | None = None + content_size_bytes: int = 0 + content_truncated: bool = False + content_omitted: bool = False + + +@dataclass +class _DiscoveredFile: + path: ObjectStoragePath + file_format: str + size_bytes: int + compression: str | None + + +@dataclass +class _RenderResult: + text: str + estimated_rows: int | None + content_size_bytes: int + + +def build_file_analysis_request( + *, + file_path: str, + file_conn_id: str | None, + prompt: str, + multi_modal: bool, + max_files: int, + max_file_size_bytes: int, + max_total_size_bytes: int, + max_text_chars: int, + sample_rows: int, +) -> FileAnalysisRequest: + """Resolve files, normalize supported formats, and build prompt content for an LLM run.""" + if sample_rows <= 0: + raise ValueError("sample_rows must be greater than zero.") + log.info( + "Preparing file analysis request for path=%s, file_conn_id=%s, multi_modal=%s, " + "max_files=%s, max_file_size_bytes=%s, max_total_size_bytes=%s, max_text_chars=%s, sample_rows=%s", + file_path, + file_conn_id, + multi_modal, + max_files, + max_file_size_bytes, + max_total_size_bytes, + max_text_chars, + sample_rows, + ) + root = ObjectStoragePath(file_path, conn_id=file_conn_id) + resolved_paths, omitted_files = _resolve_paths(root=root, max_files=max_files) + log.info( + "Resolved %s file(s) from %s%s", + len(resolved_paths), + file_path, + f"; omitted {omitted_files} additional file(s) due to max_files limit" if omitted_files else "", + ) + if log.isEnabledFor(logging.DEBUG): + log.debug("Resolved file paths: %s", [str(path) for path in resolved_paths]) + + discovered_files: list[_DiscoveredFile] = [] + total_size_bytes = 0 + for path in resolved_paths: + discovered = _discover_file( + path=path, + max_file_size_bytes=max_file_size_bytes, + ) + total_size_bytes += discovered.size_bytes + if total_size_bytes > max_total_size_bytes: + log.info( + "Rejecting file set before 
content reads because cumulative size reached %s bytes (limit=%s bytes).", + total_size_bytes, + max_total_size_bytes, + ) + raise LLMFileAnalysisLimitExceededError( + "Total input size exceeds the configured limit: " + f"{total_size_bytes} bytes > {max_total_size_bytes} bytes." + ) + discovered_files.append(discovered) + + log.info( + "Validated byte limits for %s file(s) before reading file contents; total_size_bytes=%s.", + len(discovered_files), + total_size_bytes, + ) + + prepared_files: list[_PreparedFile] = [] + processed_size_bytes = 0 + for discovered in discovered_files: + remaining_content_bytes = max_total_size_bytes - processed_size_bytes + if remaining_content_bytes <= 0: + raise LLMFileAnalysisLimitExceededError( + "Total processed input size exceeds the configured limit after decompression." + ) + prepared = _prepare_file( + discovered_file=discovered, + multi_modal=multi_modal, + sample_rows=sample_rows, + max_content_bytes=min(max_file_size_bytes, remaining_content_bytes), + ) + processed_size_bytes += prepared.content_size_bytes + prepared_files.append(prepared) + + text_truncated = _apply_text_budget(prepared_files=prepared_files, max_text_chars=max_text_chars) + if text_truncated: + log.info("Normalized text content exceeded max_text_chars=%s and was truncated.", max_text_chars) + text_preamble = _build_text_preamble( + prompt=prompt, + prepared_files=prepared_files, + omitted_files=omitted_files, + text_truncated=text_truncated, + ) + attachments = [prepared.attachment for prepared in prepared_files if prepared.attachment is not None] + text_file_count = sum(1 for prepared in prepared_files if prepared.text_content is not None) + user_content: str | list[UserContent] + if attachments: + user_content = [text_preamble, *attachments] + else: + user_content = text_preamble + log.info( + "Prepared file analysis request with %s text file(s), %s attachment(s), total_size_bytes=%s.", + text_file_count, + len(attachments), + total_size_bytes, + ) + if 
log.isEnabledFor(logging.DEBUG): + log.debug("Prepared text preamble length=%s", len(text_preamble)) + return FileAnalysisRequest( + user_content=user_content, + resolved_paths=[str(path) for path in resolved_paths], + total_size_bytes=total_size_bytes, + omitted_files=omitted_files, + text_truncated=text_truncated, + attachment_count=len(attachments), + text_file_count=text_file_count, + ) + + +def _resolve_paths(*, root: ObjectStoragePath, max_files: int) -> tuple[list[ObjectStoragePath], int]: + try: + if root.is_file(): + return [root], 0 + except FileNotFoundError: + pass + + try: + selected: list[tuple[str, ObjectStoragePath]] = [] + omitted_files = 0 + for path in root.rglob("*"): + if not path.is_file(): + continue + path_key = str(path) + if len(selected) < max_files: + insort(selected, (path_key, path)) + continue + if path_key < selected[-1][0]: + insort(selected, (path_key, path)) + selected.pop() + omitted_files += 1 + except (FileNotFoundError, NotADirectoryError): + selected = [] + omitted_files = 0 + + if not selected: + raise FileNotFoundError(f"No files found for {root}.") + + return [path for _, path in selected], omitted_files + + +def _discover_file(*, path: ObjectStoragePath, max_file_size_bytes: int) -> _DiscoveredFile: + file_format, compression = detect_file_format(path) + size_bytes = path.stat().st_size + log.debug( + "Discovered file %s (format=%s, size_bytes=%s%s).", + path, + file_format, + size_bytes, + f", compression={compression}" if compression else "", + ) + if size_bytes > max_file_size_bytes: + log.info( + "Rejecting file %s because size_bytes=%s exceeds the per-file limit=%s.", + path, + size_bytes, + max_file_size_bytes, + ) + raise LLMFileAnalysisLimitExceededError( + f"File {path} exceeds the configured per-file limit: {size_bytes} bytes > {max_file_size_bytes} bytes." 
+ ) + return _DiscoveredFile( + path=path, + file_format=file_format, + size_bytes=size_bytes, + compression=compression, + ) + + +def _prepare_file( + *, + discovered_file: _DiscoveredFile, + multi_modal: bool, + sample_rows: int, + max_content_bytes: int, +) -> _PreparedFile: + path = discovered_file.path + file_format = discovered_file.file_format + size_bytes = discovered_file.size_bytes + compression = discovered_file.compression + log.debug( + "Preparing file content for %s (format=%s, size_bytes=%s%s).", + path, + file_format, + size_bytes, + f", compression={compression}" if compression else "", + ) + prepared = _PreparedFile( + path=path, + file_format=file_format, + size_bytes=size_bytes, + compression=compression, + partitions=_infer_partitions(path), + ) + + if file_format in _MULTI_MODAL_FORMATS: + if not multi_modal: + log.info("Rejecting file %s because format=%s requires multi_modal=True.", path, file_format) + raise LLMFileAnalysisMultimodalRequiredError( + f"File {path} has format {file_format!r}; set multi_modal=True to analyze images or PDFs." 
+ ) + prepared.attachment = BinaryContent( + data=_read_raw_bytes(path, compression=compression, max_bytes=max_content_bytes), + media_type=_MEDIA_TYPES[file_format], + identifier=str(path), + ) + prepared.content_size_bytes = len(prepared.attachment.data) + log.debug( + "Attached %s as multimodal binary content with media_type=%s.", path, _MEDIA_TYPES[file_format] + ) + return prepared + + render_result = _render_text_content( + path=path, + file_format=file_format, + compression=compression, + sample_rows=sample_rows, + max_content_bytes=max_content_bytes, + ) + prepared.text_content = render_result.text + prepared.estimated_rows = render_result.estimated_rows + prepared.content_size_bytes = render_result.content_size_bytes + log.debug( + "Normalized %s into text content of %s characters%s.", + path, + len(render_result.text), + f"; estimated_rows={render_result.estimated_rows}" + if render_result.estimated_rows is not None + else "", + ) + return prepared + + +def detect_file_format(path: ObjectStoragePath) -> tuple[str, str | None]: + """Detect the logical file format and compression codec from a path suffix.""" + suffixes = [suffix.removeprefix(".").lower() for suffix in path.suffixes] + compression: str | None = None + if suffixes and suffixes[-1] in _COMPRESSION_SUFFIXES: + compression = _COMPRESSION_SUFFIXES[suffixes[-1]] + suffixes = suffixes[:-1] + detected = suffixes[-1] if suffixes else "log" + if detected not in SUPPORTED_FILE_FORMATS: + raise LLMFileAnalysisUnsupportedFormatError( + f"Unsupported file format {detected!r} for {path}. Supported formats: {', '.join(SUPPORTED_FILE_FORMATS)}." + ) + if compression and compression != "gzip": + log.info("Rejecting file %s because compression=%s is not supported.", path, compression) + raise LLMFileAnalysisUnsupportedFormatError( + f"Compression {compression!r} is not supported for file analysis." 
+ ) + if compression == "gzip" and detected not in _GZIP_SUPPORTED_FORMATS: + raise LLMFileAnalysisUnsupportedFormatError( + f"Compression {compression!r} is not supported for {detected!r} file analysis." + ) + return detected, compression + + +def _render_text_content( + *, + path: ObjectStoragePath, + file_format: str, + compression: str | None, + sample_rows: int, + max_content_bytes: int, +) -> _RenderResult: + if file_format == "json": + return _render_json(path, compression=compression, max_content_bytes=max_content_bytes) + if file_format == "csv": + return _render_csv( + path, compression=compression, sample_rows=sample_rows, max_content_bytes=max_content_bytes + ) + if file_format == "parquet": + return _render_parquet(path, sample_rows=sample_rows, max_content_bytes=max_content_bytes) + if file_format == "avro": + return _render_avro(path, sample_rows=sample_rows, max_content_bytes=max_content_bytes) + return _render_text_like(path, compression=compression, max_content_bytes=max_content_bytes) + + +def _render_text_like( + path: ObjectStoragePath, *, compression: str | None, max_content_bytes: int +) -> _RenderResult: + raw_bytes = _read_raw_bytes(path, compression=compression, max_bytes=max_content_bytes) + text = _decode_text(raw_bytes) + return _RenderResult(text=_truncate_text(text), estimated_rows=None, content_size_bytes=len(raw_bytes)) + + +def _render_json( + path: ObjectStoragePath, *, compression: str | None, max_content_bytes: int +) -> _RenderResult: + raw_bytes = _read_raw_bytes(path, compression=compression, max_bytes=max_content_bytes) + decoded = _decode_text(raw_bytes) + document = json.loads(decoded) + if isinstance(document, list): + estimated_rows = len(document) + else: + estimated_rows = None + pretty = json.dumps(document, indent=2, sort_keys=True, default=str) + return _RenderResult( + text=_truncate_text(pretty), + estimated_rows=estimated_rows, + content_size_bytes=len(raw_bytes), + ) + + +def _render_csv( + path: 
ObjectStoragePath, *, compression: str | None, sample_rows: int, max_content_bytes: int +) -> _RenderResult: + raw_bytes = _read_raw_bytes(path, compression=compression, max_bytes=max_content_bytes) + decoded = _decode_text(raw_bytes) + reader = list(csv.reader(io.StringIO(decoded))) + if not reader: + return _RenderResult(text="", estimated_rows=0, content_size_bytes=len(raw_bytes)) + header, rows = reader[0], reader[1:] + sampled_rows = rows[:sample_rows] + payload = ["Header: " + ", ".join(header)] + if sampled_rows: + payload.append("Sample rows:") + payload += [", ".join(str(value) for value in row) for row in sampled_rows] + return _RenderResult( + text=_truncate_text("\n".join(payload)), + estimated_rows=len(rows), + content_size_bytes=len(raw_bytes), + ) + + +def _render_parquet(path: ObjectStoragePath, *, sample_rows: int, max_content_bytes: int) -> _RenderResult: + try: + import pyarrow.parquet as pq + except ImportError as exc: + raise AirflowOptionalProviderFeatureException( + "Parquet analysis requires the `parquet` extra for apache-airflow-providers-common-ai." + ) from exc + + with path.open("rb") as handle: + parquet_file = pq.ParquetFile(handle) + metadata = parquet_file.metadata + num_rows = metadata.num_rows if metadata is not None else 0 + + handle.seek(0, io.SEEK_END) + content_size_bytes = handle.tell() + handle.seek(0) + if content_size_bytes > max_content_bytes: + raise LLMFileAnalysisLimitExceededError( + f"File {path} exceeds the configured processed-content limit: {content_size_bytes} bytes > {max_content_bytes} bytes." 
+ ) + + schema = ", ".join(f"{field.name}: {field.type}" for field in parquet_file.schema_arrow) + sampled_rows: list[dict[str, Any]] = [] + if sample_rows > 0 and num_rows > 0 and parquet_file.num_row_groups > 0: + remaining_rows = sample_rows + for row_group_index in range(parquet_file.num_row_groups): + if remaining_rows <= 0: + break + row_group = parquet_file.read_row_group(row_group_index) + if row_group.num_rows == 0: + continue + group_rows = row_group.slice(0, remaining_rows).to_pylist() + sampled_rows.extend(group_rows) + remaining_rows -= len(group_rows) + payload = [f"Schema: {schema}", "Sample rows:", json.dumps(sampled_rows, indent=2, default=str)] + return _RenderResult( + text=_truncate_text("\n".join(payload)), + estimated_rows=num_rows, + content_size_bytes=content_size_bytes, + ) + + +def _render_avro(path: ObjectStoragePath, *, sample_rows: int, max_content_bytes: int) -> _RenderResult: + try: + import fastavro + except ImportError as exc: + raise AirflowOptionalProviderFeatureException( + "Avro analysis requires the `avro` extra for apache-airflow-providers-common-ai." + ) from exc + + sampled_rows: list[dict[str, Any]] = [] + total_rows = 0 + with path.open("rb") as handle: + handle.seek(0, io.SEEK_END) + content_size_bytes = handle.tell() + handle.seek(0) + if content_size_bytes > max_content_bytes: + raise LLMFileAnalysisLimitExceededError( + f"File {path} exceeds the configured processed-content limit: {content_size_bytes} bytes > {max_content_bytes} bytes." + ) + reader = fastavro.reader(handle) + writer_schema = getattr(reader, "writer_schema", None) + fully_read = False + if sample_rows > 0: + for record in reader: + total_rows += 1 + if isinstance(record, dict): + sampled_rows.append({str(key): value for key, value in record.items()}) + if len(sampled_rows) >= sample_rows: Review Comment: If records aren't dicts (e.g., Avro union top-level type), `len(sampled_rows) >= sample_rows` never triggers and the loop scans the entire file. 
Should break on `total_rows` instead: ```python if total_rows >= sample_rows: break ``` ########## providers/common/ai/src/airflow/providers/common/ai/utils/file_analysis.py: ########## @@ -0,0 +1,673 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Helpers for building file-analysis prompts for LLM operators.""" + +from __future__ import annotations + +import csv +import gzip +import io +import json +import logging +from bisect import insort +from dataclasses import dataclass +from pathlib import PurePosixPath +from typing import TYPE_CHECKING, Any + +from pydantic_ai.messages import BinaryContent + +from airflow.providers.common.ai.exceptions import ( + LLMFileAnalysisLimitExceededError, + LLMFileAnalysisMultimodalRequiredError, + LLMFileAnalysisUnsupportedFormatError, +) +from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException, ObjectStoragePath + +if TYPE_CHECKING: + from collections.abc import Sequence + + from pydantic_ai.messages import UserContent + +SUPPORTED_FILE_FORMATS: tuple[str, ...] 
= ( + "avro", + "csv", + "jpeg", + "jpg", + "json", + "log", + "parquet", + "pdf", + "png", +) + +_TEXT_LIKE_FORMATS = frozenset({"csv", "json", "log", "avro", "parquet"}) +_MULTI_MODAL_FORMATS = frozenset({"jpeg", "jpg", "pdf", "png"}) +_COMPRESSION_SUFFIXES = { + "bz2": "bzip2", + "gz": "gzip", + "snappy": "snappy", + "xz": "xz", + "zst": "zstd", +} +_GZIP_SUPPORTED_FORMATS = frozenset({"csv", "json", "log"}) +_TEXT_SAMPLE_HEAD_CHARS = 8_000 +_TEXT_SAMPLE_TAIL_CHARS = 2_000 +_MEDIA_TYPES = { + "jpeg": "image/jpeg", + "jpg": "image/jpeg", + "pdf": "application/pdf", + "png": "image/png", +} +log = logging.getLogger(__name__) + + +@dataclass +class FileAnalysisRequest: + """Prepared prompt content and discovery metadata for the file-analysis operator.""" + + user_content: str | Sequence[UserContent] + resolved_paths: list[str] + total_size_bytes: int + omitted_files: int = 0 + text_truncated: bool = False + attachment_count: int = 0 + text_file_count: int = 0 + + +@dataclass +class _PreparedFile: + path: ObjectStoragePath + file_format: str + size_bytes: int + compression: str | None + partitions: tuple[str, ...] 
+ estimated_rows: int | None = None + text_content: str | None = None + attachment: BinaryContent | None = None + content_size_bytes: int = 0 + content_truncated: bool = False + content_omitted: bool = False + + +@dataclass +class _DiscoveredFile: + path: ObjectStoragePath + file_format: str + size_bytes: int + compression: str | None + + +@dataclass +class _RenderResult: + text: str + estimated_rows: int | None + content_size_bytes: int + + +def build_file_analysis_request( + *, + file_path: str, + file_conn_id: str | None, + prompt: str, + multi_modal: bool, + max_files: int, + max_file_size_bytes: int, + max_total_size_bytes: int, + max_text_chars: int, + sample_rows: int, +) -> FileAnalysisRequest: + """Resolve files, normalize supported formats, and build prompt content for an LLM run.""" + if sample_rows <= 0: + raise ValueError("sample_rows must be greater than zero.") + log.info( + "Preparing file analysis request for path=%s, file_conn_id=%s, multi_modal=%s, " + "max_files=%s, max_file_size_bytes=%s, max_total_size_bytes=%s, max_text_chars=%s, sample_rows=%s", + file_path, + file_conn_id, + multi_modal, + max_files, + max_file_size_bytes, + max_total_size_bytes, + max_text_chars, + sample_rows, + ) + root = ObjectStoragePath(file_path, conn_id=file_conn_id) + resolved_paths, omitted_files = _resolve_paths(root=root, max_files=max_files) + log.info( + "Resolved %s file(s) from %s%s", + len(resolved_paths), + file_path, + f"; omitted {omitted_files} additional file(s) due to max_files limit" if omitted_files else "", + ) + if log.isEnabledFor(logging.DEBUG): + log.debug("Resolved file paths: %s", [str(path) for path in resolved_paths]) + + discovered_files: list[_DiscoveredFile] = [] + total_size_bytes = 0 + for path in resolved_paths: + discovered = _discover_file( + path=path, + max_file_size_bytes=max_file_size_bytes, + ) + total_size_bytes += discovered.size_bytes + if total_size_bytes > max_total_size_bytes: + log.info( + "Rejecting file set before 
content reads because cumulative size reached %s bytes (limit=%s bytes).", + total_size_bytes, + max_total_size_bytes, + ) + raise LLMFileAnalysisLimitExceededError( + "Total input size exceeds the configured limit: " + f"{total_size_bytes} bytes > {max_total_size_bytes} bytes." + ) + discovered_files.append(discovered) + + log.info( + "Validated byte limits for %s file(s) before reading file contents; total_size_bytes=%s.", + len(discovered_files), + total_size_bytes, + ) + + prepared_files: list[_PreparedFile] = [] + processed_size_bytes = 0 + for discovered in discovered_files: + remaining_content_bytes = max_total_size_bytes - processed_size_bytes + if remaining_content_bytes <= 0: + raise LLMFileAnalysisLimitExceededError( + "Total processed input size exceeds the configured limit after decompression." + ) + prepared = _prepare_file( + discovered_file=discovered, + multi_modal=multi_modal, + sample_rows=sample_rows, + max_content_bytes=min(max_file_size_bytes, remaining_content_bytes), + ) + processed_size_bytes += prepared.content_size_bytes + prepared_files.append(prepared) + + text_truncated = _apply_text_budget(prepared_files=prepared_files, max_text_chars=max_text_chars) + if text_truncated: + log.info("Normalized text content exceeded max_text_chars=%s and was truncated.", max_text_chars) + text_preamble = _build_text_preamble( + prompt=prompt, + prepared_files=prepared_files, + omitted_files=omitted_files, + text_truncated=text_truncated, + ) + attachments = [prepared.attachment for prepared in prepared_files if prepared.attachment is not None] + text_file_count = sum(1 for prepared in prepared_files if prepared.text_content is not None) + user_content: str | list[UserContent] + if attachments: + user_content = [text_preamble, *attachments] + else: + user_content = text_preamble + log.info( + "Prepared file analysis request with %s text file(s), %s attachment(s), total_size_bytes=%s.", + text_file_count, + len(attachments), + total_size_bytes, + ) + if 
log.isEnabledFor(logging.DEBUG): + log.debug("Prepared text preamble length=%s", len(text_preamble)) + return FileAnalysisRequest( + user_content=user_content, + resolved_paths=[str(path) for path in resolved_paths], + total_size_bytes=total_size_bytes, + omitted_files=omitted_files, + text_truncated=text_truncated, + attachment_count=len(attachments), + text_file_count=text_file_count, + ) + + +def _resolve_paths(*, root: ObjectStoragePath, max_files: int) -> tuple[list[ObjectStoragePath], int]: + try: + if root.is_file(): + return [root], 0 + except FileNotFoundError: + pass + + try: + selected: list[tuple[str, ObjectStoragePath]] = [] + omitted_files = 0 + for path in root.rglob("*"): + if not path.is_file(): + continue + path_key = str(path) + if len(selected) < max_files: + insort(selected, (path_key, path)) + continue + if path_key < selected[-1][0]: + insort(selected, (path_key, path)) + selected.pop() + omitted_files += 1 + except (FileNotFoundError, NotADirectoryError): + selected = [] + omitted_files = 0 + + if not selected: + raise FileNotFoundError(f"No files found for {root}.") + + return [path for _, path in selected], omitted_files + + +def _discover_file(*, path: ObjectStoragePath, max_file_size_bytes: int) -> _DiscoveredFile: + file_format, compression = detect_file_format(path) + size_bytes = path.stat().st_size + log.debug( + "Discovered file %s (format=%s, size_bytes=%s%s).", + path, + file_format, + size_bytes, + f", compression={compression}" if compression else "", + ) + if size_bytes > max_file_size_bytes: + log.info( + "Rejecting file %s because size_bytes=%s exceeds the per-file limit=%s.", + path, + size_bytes, + max_file_size_bytes, + ) + raise LLMFileAnalysisLimitExceededError( + f"File {path} exceeds the configured per-file limit: {size_bytes} bytes > {max_file_size_bytes} bytes." 
+ ) + return _DiscoveredFile( + path=path, + file_format=file_format, + size_bytes=size_bytes, + compression=compression, + ) + + +def _prepare_file( + *, + discovered_file: _DiscoveredFile, + multi_modal: bool, + sample_rows: int, + max_content_bytes: int, +) -> _PreparedFile: + path = discovered_file.path + file_format = discovered_file.file_format + size_bytes = discovered_file.size_bytes + compression = discovered_file.compression + log.debug( + "Preparing file content for %s (format=%s, size_bytes=%s%s).", + path, + file_format, + size_bytes, + f", compression={compression}" if compression else "", + ) + prepared = _PreparedFile( + path=path, + file_format=file_format, + size_bytes=size_bytes, + compression=compression, + partitions=_infer_partitions(path), + ) + + if file_format in _MULTI_MODAL_FORMATS: + if not multi_modal: + log.info("Rejecting file %s because format=%s requires multi_modal=True.", path, file_format) + raise LLMFileAnalysisMultimodalRequiredError( + f"File {path} has format {file_format!r}; set multi_modal=True to analyze images or PDFs." 
+ ) + prepared.attachment = BinaryContent( + data=_read_raw_bytes(path, compression=compression, max_bytes=max_content_bytes), + media_type=_MEDIA_TYPES[file_format], + identifier=str(path), + ) + prepared.content_size_bytes = len(prepared.attachment.data) + log.debug( + "Attached %s as multimodal binary content with media_type=%s.", path, _MEDIA_TYPES[file_format] + ) + return prepared + + render_result = _render_text_content( + path=path, + file_format=file_format, + compression=compression, + sample_rows=sample_rows, + max_content_bytes=max_content_bytes, + ) + prepared.text_content = render_result.text + prepared.estimated_rows = render_result.estimated_rows + prepared.content_size_bytes = render_result.content_size_bytes + log.debug( + "Normalized %s into text content of %s characters%s.", + path, + len(render_result.text), + f"; estimated_rows={render_result.estimated_rows}" + if render_result.estimated_rows is not None + else "", + ) + return prepared + + +def detect_file_format(path: ObjectStoragePath) -> tuple[str, str | None]: + """Detect the logical file format and compression codec from a path suffix.""" + suffixes = [suffix.removeprefix(".").lower() for suffix in path.suffixes] + compression: str | None = None + if suffixes and suffixes[-1] in _COMPRESSION_SUFFIXES: + compression = _COMPRESSION_SUFFIXES[suffixes[-1]] + suffixes = suffixes[:-1] + detected = suffixes[-1] if suffixes else "log" + if detected not in SUPPORTED_FILE_FORMATS: + raise LLMFileAnalysisUnsupportedFormatError( + f"Unsupported file format {detected!r} for {path}. Supported formats: {', '.join(SUPPORTED_FILE_FORMATS)}." + ) + if compression and compression != "gzip": + log.info("Rejecting file %s because compression=%s is not supported.", path, compression) + raise LLMFileAnalysisUnsupportedFormatError( + f"Compression {compression!r} is not supported for file analysis." 
+ ) + if compression == "gzip" and detected not in _GZIP_SUPPORTED_FORMATS: + raise LLMFileAnalysisUnsupportedFormatError( + f"Compression {compression!r} is not supported for {detected!r} file analysis." + ) + return detected, compression + + +def _render_text_content( + *, + path: ObjectStoragePath, + file_format: str, + compression: str | None, + sample_rows: int, + max_content_bytes: int, +) -> _RenderResult: + if file_format == "json": + return _render_json(path, compression=compression, max_content_bytes=max_content_bytes) + if file_format == "csv": + return _render_csv( + path, compression=compression, sample_rows=sample_rows, max_content_bytes=max_content_bytes + ) + if file_format == "parquet": + return _render_parquet(path, sample_rows=sample_rows, max_content_bytes=max_content_bytes) + if file_format == "avro": + return _render_avro(path, sample_rows=sample_rows, max_content_bytes=max_content_bytes) + return _render_text_like(path, compression=compression, max_content_bytes=max_content_bytes) + + +def _render_text_like( + path: ObjectStoragePath, *, compression: str | None, max_content_bytes: int +) -> _RenderResult: + raw_bytes = _read_raw_bytes(path, compression=compression, max_bytes=max_content_bytes) + text = _decode_text(raw_bytes) + return _RenderResult(text=_truncate_text(text), estimated_rows=None, content_size_bytes=len(raw_bytes)) + + +def _render_json( + path: ObjectStoragePath, *, compression: str | None, max_content_bytes: int +) -> _RenderResult: + raw_bytes = _read_raw_bytes(path, compression=compression, max_bytes=max_content_bytes) + decoded = _decode_text(raw_bytes) + document = json.loads(decoded) + if isinstance(document, list): + estimated_rows = len(document) + else: + estimated_rows = None + pretty = json.dumps(document, indent=2, sort_keys=True, default=str) + return _RenderResult( + text=_truncate_text(pretty), + estimated_rows=estimated_rows, + content_size_bytes=len(raw_bytes), + ) + + +def _render_csv( + path: 
def _render_parquet(path: ObjectStoragePath, *, sample_rows: int, max_content_bytes: int) -> _RenderResult:
    """Render a parquet file as schema text plus up to ``sample_rows`` sampled rows.

    :param path: object-storage path of the parquet file.
    :param sample_rows: maximum number of rows to include in the text sample.
    :param max_content_bytes: hard cap on the file's on-storage size, in bytes.
    :raises AirflowOptionalProviderFeatureException: if pyarrow is not installed.
    :raises LLMFileAnalysisLimitExceededError: if the file is larger than
        ``max_content_bytes``.
    """
    try:
        # Imported lazily: pyarrow is an optional extra of this provider.
        import pyarrow.parquet as pq
    except ImportError as exc:
        raise AirflowOptionalProviderFeatureException(
            "Parquet analysis requires the `parquet` extra for apache-airflow-providers-common-ai."
        ) from exc

    with path.open("rb") as handle:
        parquet_file = pq.ParquetFile(handle)
        # The exact row count comes from the footer metadata, so no row data
        # needs to be read to report estimated_rows.
        metadata = parquet_file.metadata
        num_rows = metadata.num_rows if metadata is not None else 0

        # Measure the on-storage size via seek/tell, then rewind before any
        # row-group reads below.
        handle.seek(0, io.SEEK_END)
        content_size_bytes = handle.tell()
        handle.seek(0)
        if content_size_bytes > max_content_bytes:
            raise LLMFileAnalysisLimitExceededError(
                f"File {path} exceeds the configured processed-content limit: {content_size_bytes} bytes > {max_content_bytes} bytes."
            )

        schema = ", ".join(f"{field.name}: {field.type}" for field in parquet_file.schema_arrow)
        sampled_rows: list[dict[str, Any]] = []
        if sample_rows > 0 and num_rows > 0 and parquet_file.num_row_groups > 0:
            # Walk row groups in order, taking rows from each until the sample
            # quota is filled; row groups past the quota are never materialized.
            remaining_rows = sample_rows
            for row_group_index in range(parquet_file.num_row_groups):
                if remaining_rows <= 0:
                    break
                row_group = parquet_file.read_row_group(row_group_index)
                if row_group.num_rows == 0:
                    continue
                # slice(0, remaining_rows) takes at most the still-needed rows
                # from the front of this group.
                group_rows = row_group.slice(0, remaining_rows).to_pylist()
                sampled_rows.extend(group_rows)
                remaining_rows -= len(group_rows)
    payload = [f"Schema: {schema}", "Sample rows:", json.dumps(sampled_rows, indent=2, default=str)]
    return _RenderResult(
        text=_truncate_text("\n".join(payload)),
        estimated_rows=num_rows,
        content_size_bytes=content_size_bytes,
    )
+ ) + reader = fastavro.reader(handle) + writer_schema = getattr(reader, "writer_schema", None) + fully_read = False + if sample_rows > 0: + for record in reader: + total_rows += 1 + if isinstance(record, dict): + sampled_rows.append({str(key): value for key, value in record.items()}) + if len(sampled_rows) >= sample_rows: + break + else: + fully_read = True + payload = [ + f"Schema: {json.dumps(writer_schema, indent=2, default=str)}", + "Sample rows:", + json.dumps(sampled_rows, indent=2, default=str), + ] + return _RenderResult( + text=_truncate_text("\n".join(payload)), + estimated_rows=total_rows if fully_read else None, + content_size_bytes=content_size_bytes, + ) + + +def _read_raw_bytes(path: ObjectStoragePath, *, compression: str | None, max_bytes: int) -> bytes: + with path.open("rb") as handle: + if compression == "gzip": + with gzip.GzipFile(fileobj=handle) as gzip_handle: + return _read_limited_bytes(gzip_handle, path=path, max_bytes=max_bytes) + return _read_limited_bytes(handle, path=path, max_bytes=max_bytes) + + +def _read_limited_bytes(handle: io.BufferedIOBase, *, path: ObjectStoragePath, max_bytes: int) -> bytes: + chunks: list[bytes] = [] + total_bytes = 0 + while True: + chunk = handle.read(min(64 * 1024, max_bytes - total_bytes + 1)) + if not chunk: + break + total_bytes += len(chunk) + if total_bytes > max_bytes: + raise LLMFileAnalysisLimitExceededError( + f"File {path} exceeds the configured processed-content limit: > {max_bytes} bytes." 
+ ) + chunks.append(chunk) + return b"".join(chunks) + + +def _decode_text(data: bytes) -> str: + return data.decode("utf-8", errors="replace") + + +def _apply_text_budget(*, prepared_files: list[_PreparedFile], max_text_chars: int) -> bool: + remaining = max_text_chars + truncated_any = False + for prepared in prepared_files: + if prepared.text_content is None: + continue + if remaining <= 0: + prepared.text_content = None + prepared.content_omitted = True + truncated_any = True + log.debug( + "Omitted normalized text for %s because the prompt text budget was exhausted.", prepared.path + ) + continue + original = prepared.text_content + if len(original) > remaining: + prepared.text_content = _truncate_text(original, max_chars=remaining) + prepared.content_truncated = True + truncated_any = True + log.debug( + "Truncated normalized text for %s from %s to %s characters to fit the remaining budget.", + prepared.path, + len(original), + len(prepared.text_content), + ) + remaining -= len(prepared.text_content) + return truncated_any + + +def _build_text_preamble( + *, + prompt: str, + prepared_files: list[_PreparedFile], + omitted_files: int, + text_truncated: bool, +) -> str: + lines = [ + "You are analyzing read-only files supplied.", Review Comment: `_build_system_prompt()` in the operator already sends "You are a read-only file analysis assistant..." as system instructions. This preamble repeats nearly identical text as user content. The LLM gets both, which wastes tokens and can cause confusion when two authority sources say slightly different things. Consider removing the instructional lines from the preamble and keeping it focused on presenting the data (file metadata, content). ########## providers/common/ai/src/airflow/providers/common/ai/utils/file_analysis.py: ########## @@ -0,0 +1,673 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Helpers for building file-analysis prompts for LLM operators.""" + +from __future__ import annotations + +import csv +import gzip +import io +import json +import logging +from bisect import insort +from dataclasses import dataclass +from pathlib import PurePosixPath +from typing import TYPE_CHECKING, Any + +from pydantic_ai.messages import BinaryContent + +from airflow.providers.common.ai.exceptions import ( + LLMFileAnalysisLimitExceededError, + LLMFileAnalysisMultimodalRequiredError, + LLMFileAnalysisUnsupportedFormatError, +) +from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException, ObjectStoragePath + +if TYPE_CHECKING: + from collections.abc import Sequence + + from pydantic_ai.messages import UserContent + +SUPPORTED_FILE_FORMATS: tuple[str, ...] 
@dataclass
class FileAnalysisRequest:
    """Prepared prompt content and discovery metadata for the file-analysis operator."""

    # Prompt payload: a plain-text preamble, or [preamble, *binary attachments]
    # when multimodal attachments are present.
    user_content: str | Sequence[UserContent]
    # Fully resolved file paths (stringified) included in this request.
    resolved_paths: list[str]
    # Cumulative on-storage size of all resolved files, in bytes.
    total_size_bytes: int
    # Count of files dropped because the max_files limit was reached.
    omitted_files: int = 0
    # True when normalized text was truncated to fit max_text_chars.
    text_truncated: bool = False
    # Number of multimodal binary attachments carried in user_content.
    attachment_count: int = 0
    # Number of files rendered as normalized text content.
    text_file_count: int = 0
def build_file_analysis_request(
    *,
    file_path: str,
    file_conn_id: str | None,
    prompt: str,
    multi_modal: bool,
    max_files: int,
    max_file_size_bytes: int,
    max_total_size_bytes: int,
    max_text_chars: int,
    sample_rows: int,
) -> FileAnalysisRequest:
    """Resolve files, normalize supported formats, and build prompt content for an LLM run.

    The request is assembled in phases: resolve paths under ``file_path``,
    gate the cumulative on-storage size before any content is read, normalize
    each file's content, apply the prompt text budget, and assemble the final
    user content.

    :param file_path: file or prefix to analyze, opened via ObjectStoragePath.
    :param file_conn_id: optional connection id used to open ``file_path``.
    :param prompt: user prompt forwarded to the text preamble builder.
    :param multi_modal: whether binary attachments may be produced per file.
    :param max_files: cap on the number of files resolved under a prefix.
    :param max_file_size_bytes: per-file size limit, in bytes.
    :param max_total_size_bytes: cumulative size limit across all files, in bytes.
    :param max_text_chars: character budget for normalized text content.
    :param sample_rows: rows to sample from tabular formats; must be positive.
    :raises ValueError: if ``sample_rows`` is not positive.
    :raises LLMFileAnalysisLimitExceededError: if the raw or processed byte
        limits are exceeded (per-file format errors propagate from the
        discovery/preparation helpers).
    """
    if sample_rows <= 0:
        raise ValueError("sample_rows must be greater than zero.")
    log.info(
        "Preparing file analysis request for path=%s, file_conn_id=%s, multi_modal=%s, "
        "max_files=%s, max_file_size_bytes=%s, max_total_size_bytes=%s, max_text_chars=%s, sample_rows=%s",
        file_path,
        file_conn_id,
        multi_modal,
        max_files,
        max_file_size_bytes,
        max_total_size_bytes,
        max_text_chars,
        sample_rows,
    )
    # Phase 1: enumerate the files under the given path, bounded by max_files.
    root = ObjectStoragePath(file_path, conn_id=file_conn_id)
    resolved_paths, omitted_files = _resolve_paths(root=root, max_files=max_files)
    log.info(
        "Resolved %s file(s) from %s%s",
        len(resolved_paths),
        file_path,
        f"; omitted {omitted_files} additional file(s) due to max_files limit" if omitted_files else "",
    )
    if log.isEnabledFor(logging.DEBUG):
        log.debug("Resolved file paths: %s", [str(path) for path in resolved_paths])

    # Phase 2: stat every file and fail fast if the cumulative on-storage size
    # exceeds the limit, before any file content is read.
    discovered_files: list[_DiscoveredFile] = []
    total_size_bytes = 0
    for path in resolved_paths:
        discovered = _discover_file(
            path=path,
            max_file_size_bytes=max_file_size_bytes,
        )
        total_size_bytes += discovered.size_bytes
        if total_size_bytes > max_total_size_bytes:
            log.info(
                "Rejecting file set before content reads because cumulative size reached %s bytes (limit=%s bytes).",
                total_size_bytes,
                max_total_size_bytes,
            )
            raise LLMFileAnalysisLimitExceededError(
                "Total input size exceeds the configured limit: "
                f"{total_size_bytes} bytes > {max_total_size_bytes} bytes."
            )
        discovered_files.append(discovered)

    log.info(
        "Validated byte limits for %s file(s) before reading file contents; total_size_bytes=%s.",
        len(discovered_files),
        total_size_bytes,
    )

    # Phase 3: read and normalize each file's content. Processed bytes are
    # tracked separately from the on-storage total because decompression can
    # inflate content past the sizes checked in phase 2.
    prepared_files: list[_PreparedFile] = []
    processed_size_bytes = 0
    for discovered in discovered_files:
        remaining_content_bytes = max_total_size_bytes - processed_size_bytes
        if remaining_content_bytes <= 0:
            raise LLMFileAnalysisLimitExceededError(
                "Total processed input size exceeds the configured limit after decompression."
            )
        prepared = _prepare_file(
            discovered_file=discovered,
            multi_modal=multi_modal,
            sample_rows=sample_rows,
            # Each file may consume at most its own per-file cap and whatever
            # is left of the shared total budget.
            max_content_bytes=min(max_file_size_bytes, remaining_content_bytes),
        )
        processed_size_bytes += prepared.content_size_bytes
        prepared_files.append(prepared)

    # Phase 4: shrink or drop normalized text so the combined text content
    # fits within the prompt character budget.
    text_truncated = _apply_text_budget(prepared_files=prepared_files, max_text_chars=max_text_chars)
    if text_truncated:
        log.info("Normalized text content exceeded max_text_chars=%s and was truncated.", max_text_chars)
    # Phase 5: assemble the final prompt content — a plain string when there
    # are no binary attachments, otherwise [preamble, *attachments].
    text_preamble = _build_text_preamble(
        prompt=prompt,
        prepared_files=prepared_files,
        omitted_files=omitted_files,
        text_truncated=text_truncated,
    )
    attachments = [prepared.attachment for prepared in prepared_files if prepared.attachment is not None]
    text_file_count = sum(1 for prepared in prepared_files if prepared.text_content is not None)
    user_content: str | list[UserContent]
    if attachments:
        user_content = [text_preamble, *attachments]
    else:
        user_content = text_preamble
    log.info(
        "Prepared file analysis request with %s text file(s), %s attachment(s), total_size_bytes=%s.",
        text_file_count,
        len(attachments),
        total_size_bytes,
    )
    if log.isEnabledFor(logging.DEBUG):
        log.debug("Prepared text preamble length=%s", len(text_preamble))
    return FileAnalysisRequest(
        user_content=user_content,
        resolved_paths=[str(path) for path in resolved_paths],
        total_size_bytes=total_size_bytes,
        omitted_files=omitted_files,
        text_truncated=text_truncated,
        attachment_count=len(attachments),
        text_file_count=text_file_count,
    )
