gopidesupavan commented on code in PR #64077: URL: https://github.com/apache/airflow/pull/64077#discussion_r2995546160
########## providers/common/ai/src/airflow/providers/common/ai/utils/file_analysis.py: ########## @@ -0,0 +1,673 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Helpers for building file-analysis prompts for LLM operators.""" + +from __future__ import annotations + +import csv +import gzip +import io +import json +import logging +from bisect import insort +from dataclasses import dataclass +from pathlib import PurePosixPath +from typing import TYPE_CHECKING, Any + +from pydantic_ai.messages import BinaryContent + +from airflow.providers.common.ai.exceptions import ( + LLMFileAnalysisLimitExceededError, + LLMFileAnalysisMultimodalRequiredError, + LLMFileAnalysisUnsupportedFormatError, +) +from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException, ObjectStoragePath + +if TYPE_CHECKING: + from collections.abc import Sequence + + from pydantic_ai.messages import UserContent + +SUPPORTED_FILE_FORMATS: tuple[str, ...] = ( + "avro", + "csv", + "jpeg", + "jpg", + "json", + "log", + "parquet", + "pdf", + "png", +) + +_TEXT_LIKE_FORMATS = frozenset({"csv", "json", "log", "avro", "parquet"}) +_MULTI_MODAL_FORMATS = frozenset({"jpeg", "jpg", "pdf", "png"}) +_COMPRESSION_SUFFIXES = { + "bz2": "bzip2", + "gz": "gzip", + "snappy": "snappy", + "xz": "xz", + "zst": "zstd", +} +_GZIP_SUPPORTED_FORMATS = frozenset({"csv", "json", "log"}) +_TEXT_SAMPLE_HEAD_CHARS = 8_000 +_TEXT_SAMPLE_TAIL_CHARS = 2_000 +_MEDIA_TYPES = { + "jpeg": "image/jpeg", + "jpg": "image/jpeg", + "pdf": "application/pdf", + "png": "image/png", +} +log = logging.getLogger(__name__) + + +@dataclass +class FileAnalysisRequest: + """Prepared prompt content and discovery metadata for the file-analysis operator.""" + + user_content: str | Sequence[UserContent] + resolved_paths: list[str] + total_size_bytes: int + omitted_files: int = 0 + text_truncated: bool = False + attachment_count: int = 0 + text_file_count: int = 0 + + +@dataclass +class _PreparedFile: + path: ObjectStoragePath + file_format: str + size_bytes: int + compression: str | None + partitions: tuple[str, ...] + estimated_rows: int | None = None + text_content: str | None = None + attachment: BinaryContent | None = None + content_size_bytes: int = 0 + content_truncated: bool = False + content_omitted: bool = False + + +@dataclass +class _DiscoveredFile: + path: ObjectStoragePath + file_format: str + size_bytes: int + compression: str | None + + +@dataclass +class _RenderResult: + text: str + estimated_rows: int | None + content_size_bytes: int + + +def build_file_analysis_request( + *, + file_path: str, + file_conn_id: str | None, + prompt: str, + multi_modal: bool, + max_files: int, + max_file_size_bytes: int, + max_total_size_bytes: int, + max_text_chars: int, + sample_rows: int, +) -> FileAnalysisRequest: + """Resolve files, normalize supported formats, and build prompt content for an LLM run.""" + if sample_rows <= 0: + raise ValueError("sample_rows must be greater than zero.") + log.info( + "Preparing file analysis request for path=%s, file_conn_id=%s, multi_modal=%s, " + "max_files=%s, max_file_size_bytes=%s, max_total_size_bytes=%s, max_text_chars=%s, sample_rows=%s", + file_path, + file_conn_id, + multi_modal, + max_files, + max_file_size_bytes, + max_total_size_bytes, + max_text_chars, + sample_rows, + ) + root = ObjectStoragePath(file_path, conn_id=file_conn_id) + resolved_paths, omitted_files = _resolve_paths(root=root, max_files=max_files) + log.info( + "Resolved %s file(s) from %s%s", + len(resolved_paths), + file_path, + f"; omitted {omitted_files} additional file(s) due to max_files limit" if omitted_files else "", + ) + if log.isEnabledFor(logging.DEBUG): + log.debug("Resolved file paths: %s", [str(path) for path in resolved_paths]) + + discovered_files: list[_DiscoveredFile] = [] + total_size_bytes = 0 + for path in resolved_paths: + discovered = _discover_file( + path=path, + max_file_size_bytes=max_file_size_bytes, + ) + total_size_bytes += discovered.size_bytes + if total_size_bytes > max_total_size_bytes: + log.info( + "Rejecting file set before content reads because cumulative size reached %s bytes (limit=%s bytes).", + total_size_bytes, + max_total_size_bytes, + ) + raise LLMFileAnalysisLimitExceededError( + "Total input size exceeds the configured limit: " + f"{total_size_bytes} bytes > {max_total_size_bytes} bytes." + ) + discovered_files.append(discovered) + + log.info( + "Validated byte limits for %s file(s) before reading file contents; total_size_bytes=%s.", + len(discovered_files), + total_size_bytes, + ) + + prepared_files: list[_PreparedFile] = [] + processed_size_bytes = 0 + for discovered in discovered_files: + remaining_content_bytes = max_total_size_bytes - processed_size_bytes + if remaining_content_bytes <= 0: + raise LLMFileAnalysisLimitExceededError( + "Total processed input size exceeds the configured limit after decompression." + ) + prepared = _prepare_file( + discovered_file=discovered, + multi_modal=multi_modal, + sample_rows=sample_rows, + max_content_bytes=min(max_file_size_bytes, remaining_content_bytes), + ) + processed_size_bytes += prepared.content_size_bytes + prepared_files.append(prepared) + + text_truncated = _apply_text_budget(prepared_files=prepared_files, max_text_chars=max_text_chars) + if text_truncated: + log.info("Normalized text content exceeded max_text_chars=%s and was truncated.", max_text_chars) + text_preamble = _build_text_preamble( + prompt=prompt, + prepared_files=prepared_files, + omitted_files=omitted_files, + text_truncated=text_truncated, + ) + attachments = [prepared.attachment for prepared in prepared_files if prepared.attachment is not None] + text_file_count = sum(1 for prepared in prepared_files if prepared.text_content is not None) + user_content: str | list[UserContent] + if attachments: + user_content = [text_preamble, *attachments] + else: + user_content = text_preamble + log.info( + "Prepared file analysis request with %s text file(s), %s attachment(s), total_size_bytes=%s.", + text_file_count, + len(attachments), + total_size_bytes, + ) + if log.isEnabledFor(logging.DEBUG): + log.debug("Prepared text preamble length=%s", len(text_preamble)) + return FileAnalysisRequest( + user_content=user_content, + resolved_paths=[str(path) for path in resolved_paths], + total_size_bytes=total_size_bytes, + omitted_files=omitted_files, + text_truncated=text_truncated, + attachment_count=len(attachments), + text_file_count=text_file_count, + ) + + +def _resolve_paths(*, root: ObjectStoragePath, max_files: int) -> tuple[list[ObjectStoragePath], int]: + try: + if root.is_file(): + return [root], 0 + except FileNotFoundError: + pass + + try: + selected: list[tuple[str, ObjectStoragePath]] = [] + omitted_files = 0 + for path in root.rglob("*"): Review Comment: updated docs https://github.com/apache/airflow/pull/64077/changes/c0093c98328fa7062632b9ad5357bc525b0fb6fa -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
