This is an automated email from the ASF dual-hosted git repository.

github-merge-queue[bot] pushed a commit to branch gh-readonly-queue/main/pr-5570-18e4e67c70fcc4d4c739338d9c30ce552db3decc
in repository https://gitbox.apache.org/repos/asf/texera.git
commit 439ea72e46b78aec7f71e8889225f2c90942a2c2
Author: Anish Shivamurthy <[email protected]>
AuthorDate: Mon Jun 22 15:10:55 2026 -0700

feat(huggingface): add audio and media generation tasks (#5570)

## What changes were proposed in this PR?

Adds the audio and media-generation task families (5 HF pipeline tasks) as new `TaskCodegen`s plugged into the dispatcher established by the text-generation PR:

- audio tasks: `automatic-speech-recognition`, `audio-classification`, `text-to-speech`
- media-generation tasks: `text-to-image`, `text-to-video`

`codegen/AudioTaskCodegen.scala` supplies the per-task payload + parse Python branches for the 3 audio tasks.

`codegen/MediaGenCodegen.scala` supplies the per-task payload + parse Python branches for the 2 media-generation tasks.

`CodegenContext` is extended with `audioInput` + `inputAudioColumn` (`EncodableString`).

`HuggingFaceInferenceOpDesc.scala` gains 2 new `@JsonProperty` fields and registers `AudioTaskCodegen` + `MediaGenCodegen` in the dispatcher.

`PythonCodegenBase.scala` grows to host the shared audio/media infrastructure:

- Audio task-family tuple (`audio_only_tasks`) in `process_table`.
- Per-row audio-byte resolution from upload or column input.
- Raw binary request handling for `automatic-speech-recognition` and `audio-classification`.
- JSON payload handling for `text-to-speech`.
- Provider-specific routing for media generation and audio generation through `_call_provider`, including OpenAI-compatible image/audio endpoints where supported.
- Response parsing for audio/media outputs, including data-URL conversion for generated media URLs.
- Media helper support for converting remote URLs into `data:image/...`, `data:audio/...`, or `data:video/...` URLs where needed.
- Hardened audio input loading to match the image-input path: uploaded audio is accepted as a data URL, remote audio is fetched through the existing HTTPS-only `_fetch_remote_url` helper, and arbitrary worker-local file paths are no longer read.
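The hardened audio input loading described in the last item can be sketched as a standalone function. This is a minimal illustration and not the generated operator code: `read_audio_input` and its `fetch_remote` parameter are hypothetical names, with `fetch_remote` standing in for the operator's `_fetch_remote_url` helper and assumed to return a `(content_type, data)` pair.

```python
import base64
from typing import Callable, Tuple


def read_audio_input(
    audio_input: str,
    fetch_remote: Callable[[str], Tuple[str, bytes]],
) -> bytes:
    """Resolve an audio upload string to raw bytes.

    `fetch_remote` is a hypothetical stand-in for the operator's
    `_fetch_remote_url` helper; it is assumed to return (content_type, data)
    and to enforce its own scheme and size restrictions.
    """
    value = (audio_input or "").strip()
    if value.startswith("data:"):
        # Data URL: everything after the first comma is the base64 payload.
        _, encoded = value.split(",", 1)
        return base64.b64decode(encoded)
    if value.startswith(("http://", "https://")):
        # Remote audio is only ever fetched through the injected helper.
        _, data = fetch_remote(value)
        return data
    # Anything else (for example a worker-local file path) is rejected.
    raise ValueError(
        "Unsupported audio input: expected a data URL or a remote audio URL."
    )
```

Passing the fetch helper in as a parameter keeps the sketch testable without network access; in the operator the helper is a method on the generated class.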

User-input strings continue to flow through `pyb"..."` + `EncodableString` so they reach Python as `self.decode_python_template('<base64>')` rather than raw literals. `PythonCodeRawInvalidTextSpec` still passes: 117/117 descriptors compile cleanly under `py_compile`.

## Any related issues, documentation, or discussions?

Tracking issue: Add audio and media-generation task families to HuggingFace operator apache#5288
Closes apache#5288
Stacked on: Add image task family (`ImageTaskCodegen`) to HuggingFace operator / `hf/03-image-tasks`
Parent issue: Add Hugging Face inference operator apache#5041
Closed sibling issue: Add HuggingFaceModelResource REST endpoints for HF operator UI apache#5134

## How was this PR tested?

- `sbt "WorkflowOperator/compile; WorkflowOperator/Test/compile"` clean.
- `sbt scalafmtCheck` clean.
- `sbt "WorkflowOperator/testOnly org.apache.texera.amber.operator.huggingFace.HuggingFaceInferenceOpDescSpec org.apache.texera.amber.util.PythonCodeRawInvalidTextSpec"`: 26 focused tests pass, including HuggingFace audio/media task coverage and the raw Python descriptor scan.
- `sbt "WorkflowOperator/testOnly org.apache.texera.amber.util.PythonCodeRawInvalidTextSpec"`: 117/117 descriptors compile cleanly under `py_compile` with the new operator code paths, no marker leaks.
- Added regression coverage that audio remote input routes through `_fetch_remote_url(audio_input)` and no longer uses raw `requests.get(audio_input)` or local file reads.

## Was this PR authored or co-authored using generative AI tooling?

Yes, co-authored with generative AI tooling (Codex).
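The point in the PR description about user-input strings reaching Python as `self.decode_python_template('<base64>')` rather than raw literals can be illustrated with a small sketch. Everything here is an assumption made for illustration: `encode_user_string` is a hypothetical stand-in for what the `pyb"..."` interpolator emits, and the decoder is assumed to be plain base64 over UTF-8, which the commit message does not confirm.

```python
import base64


def encode_user_string(value: str) -> str:
    """Render a user string as a Python call expression that decodes it at runtime.

    Hypothetical stand-in for the text emitted for an `EncodableString`.
    """
    b64 = base64.b64encode(value.encode("utf-8")).decode("ascii")
    return f"decode_python_template('{b64}')"


def decode_python_template(b64: str) -> str:
    """Assumed runtime counterpart: base64 text back to the original UTF-8 string."""
    return base64.b64decode(b64).decode("utf-8")


# A value that would break out of a quoted literal if pasted in verbatim.
hostile = "x'); import os #"
rendered = encode_user_string(hostile)
# The base64 alphabet (A-Z, a-z, 0-9, +, /, =) contains no quote characters,
# so the user's quote cannot terminate the string literal in generated code.
```

Because only base64 characters appear between the quotes, generated descriptors stay syntactically valid regardless of what the user typed, which is the kind of property a `py_compile` scan over all descriptors can check.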
--- .../huggingFace/HuggingFaceInferenceOpDesc.scala | 23 ++++- .../huggingFace/codegen/AudioTaskCodegen.scala | 79 +++++++++++++++ .../huggingFace/codegen/MediaGenCodegen.scala | 78 +++++++++++++++ .../huggingFace/codegen/PythonCodegenBase.scala | 110 +++++++++++++++++++-- .../operator/huggingFace/codegen/TaskCodegen.scala | 4 +- .../HuggingFaceInferenceOpDescSpec.scala | 101 ++++++++++++++++++- 6 files changed, 384 insertions(+), 11 deletions(-) diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDesc.scala b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDesc.scala index 5f203717d1..f7805266cf 100644 --- a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDesc.scala +++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDesc.scala @@ -25,8 +25,10 @@ import org.apache.texera.amber.core.tuple.{AttributeType, Schema} import org.apache.texera.amber.core.workflow.{InputPort, OutputPort, PortIdentity} import org.apache.texera.amber.operator.PythonOperatorDescriptor import org.apache.texera.amber.operator.huggingFace.codegen.{ + AudioTaskCodegen, CodegenContext, ImageTaskCodegen, + MediaGenCodegen, PythonCodegenBase, TaskCodegen, TextGenCodegen @@ -95,6 +97,17 @@ class HuggingFaceInferenceOpDesc extends PythonOperatorDescriptor { @AutofillAttributeName var inputImageColumn: EncodableString = "" + @JsonProperty(value = "audioInput", required = false) + @JsonSchemaTitle("Audio Upload") + @JsonPropertyDescription("Upload audio for Hugging Face audio tasks") + var audioInput: EncodableString = "" + + @JsonProperty(value = "inputAudioColumn", required = false) + @JsonSchemaTitle("Input Audio Column") + @JsonPropertyDescription("Column containing audio data from the input table") + @AutofillAttributeName + var inputAudioColumn: 
EncodableString = "" + @JsonProperty( value = "systemPrompt", required = false, @@ -138,6 +151,8 @@ class HuggingFaceInferenceOpDesc extends PythonOperatorDescriptor { val byTask = scala.collection.mutable.Map.empty[String, TaskCodegen] byTask += (TextGenCodegen.task -> TextGenCodegen) ImageTaskCodegen.tasks.foreach(t => byTask += (t -> ImageTaskCodegen)) + AudioTaskCodegen.tasks.foreach(t => byTask += (t -> AudioTaskCodegen)) + MediaGenCodegen.tasks.foreach(t => byTask += (t -> MediaGenCodegen)) byTask.toMap } @@ -181,6 +196,10 @@ class HuggingFaceInferenceOpDesc extends PythonOperatorDescriptor { if (imageInput == null) "" else imageInput val safeInputImageColumn: EncodableString = if (inputImageColumn == null) "" else inputImageColumn + val safeAudioInput: EncodableString = + if (audioInput == null) "" else audioInput + val safeInputAudioColumn: EncodableString = + if (inputAudioColumn == null) "" else inputAudioColumn val ctx = CodegenContext( hfApiToken = safeToken, @@ -192,7 +211,9 @@ class HuggingFaceInferenceOpDesc extends PythonOperatorDescriptor { safeMaxTokens = safeMaxTokens, safeTemp = safeTemp, imageInput = safeImageInput, - inputImageColumn = safeInputImageColumn + inputImageColumn = safeInputImageColumn, + audioInput = safeAudioInput, + inputAudioColumn = safeInputAudioColumn ) PythonCodegenBase.render(ctx, codegenForTask(safeTask)) diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/AudioTaskCodegen.scala b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/AudioTaskCodegen.scala new file mode 100644 index 0000000000..560244962a --- /dev/null +++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/AudioTaskCodegen.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.texera.amber.operator.huggingFace.codegen + +/** + * Codegen for Hugging Face audio task families. + * + * ASR and audio-classification send audio bytes as the raw request body. + * Text-to-speech is prompt-driven and sends a JSON payload; its providers + * return either audio bytes directly or a JSON envelope pointing to audio. 
+ */ +object AudioTaskCodegen extends TaskCodegen { + + override val task: String = "automatic-speech-recognition" + + override val tasks: Set[String] = Set( + "automatic-speech-recognition", + "audio-classification", + "text-to-speech" + ) + + override def payloadPython(ctx: CodegenContext): String = + """ if task in audio_only_tasks: + | payload = current_audio_bytes + | use_raw_binary_body = True + | raw_binary_headers = audio_headers + | elif task == "text-to-speech": + | payload = {"inputs": prompt_value}""".stripMargin + + override def parsePython(ctx: CodegenContext): String = + """ if task == "text-to-speech": + | if isinstance(body, dict): + | if "output" in body: + | out = body["output"] + | url = out[0] if isinstance(out, list) else out + | if isinstance(url, str) and url.startswith("http"): + | return self._url_to_data_url(url) + | if "audio" in body: + | audio = body["audio"] + | if isinstance(audio, dict): + | if "url" in audio: + | return self._url_to_data_url(audio["url"]) + | if "b64_json" in audio: + | return f"data:audio/mpeg;base64,{audio['b64_json']}" + | if "data" in body: + | data = body["data"] + | if data and isinstance(data[0], dict): + | if "url" in data[0]: + | return self._url_to_data_url(data[0]["url"]) + | if "b64_json" in data[0]: + | return f"data:audio/mpeg;base64,{data[0]['b64_json']}" + | return json.dumps(body) + | elif task == "automatic-speech-recognition": + | if isinstance(body, dict): + | if "text" in body: + | return body["text"] + | if "generated_text" in body: + | return body["generated_text"] + | return json.dumps(body) + | elif task == "audio-classification": + | return json.dumps(body)""".stripMargin +} diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/MediaGenCodegen.scala b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/MediaGenCodegen.scala new file mode 100644 index 0000000000..73047da89c --- /dev/null +++ 
b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/MediaGenCodegen.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.texera.amber.operator.huggingFace.codegen + +/** + * Codegen for prompt-driven media generation tasks. + * + * Providers return media in several shapes: raw bytes, OpenAI-style + * b64_json, or URLs. URL responses are normalized to data URLs by the + * shared `_url_to_data_url` helper so downstream result rendering receives + * a stable string format. 
+ */ +object MediaGenCodegen extends TaskCodegen { + + override val task: String = "text-to-image" + + override val tasks: Set[String] = Set( + "text-to-image", + "text-to-video" + ) + + override def payloadPython(ctx: CodegenContext): String = + """ payload = {"inputs": prompt_value}""".stripMargin + + override def parsePython(ctx: CodegenContext): String = + """ if task == "text-to-image": + | if isinstance(body, dict): + | if "output" in body: + | out = body["output"] + | url = out[0] if isinstance(out, list) else out + | if isinstance(url, str) and url.startswith("http"): + | return self._url_to_data_url(url) + | if "images" in body: + | images = body["images"] + | if images and isinstance(images[0], dict) and "url" in images[0]: + | return self._url_to_data_url(images[0]["url"]) + | if "data" in body: + | data = body["data"] + | if isinstance(data, dict) and "outputs" in data: + | outputs = data["outputs"] + | if outputs and isinstance(outputs[0], str) and outputs[0].startswith("http"): + | return self._url_to_data_url(outputs[0]) + | if isinstance(data, list) and data and isinstance(data[0], dict): + | if "b64_json" in data[0]: + | return f"data:image/png;base64,{data[0]['b64_json']}" + | if "url" in data[0]: + | return self._url_to_data_url(data[0]["url"]) + | return json.dumps(body) + | elif task == "text-to-video": + | if isinstance(body, dict): + | if "output" in body: + | out = body["output"] + | url = out[0] if isinstance(out, list) else out + | if isinstance(url, str) and url.startswith("http"): + | return self._url_to_data_url(url) + | if "video" in body: + | video = body["video"] + | if isinstance(video, dict) and "url" in video: + | return self._url_to_data_url(video["url"]) + | return json.dumps(body)""".stripMargin +} diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/PythonCodegenBase.scala 
b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/PythonCodegenBase.scala index eac4641c62..8671b9a76a 100644 --- a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/PythonCodegenBase.scala +++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/PythonCodegenBase.scala @@ -57,6 +57,8 @@ object PythonCodegenBase { val temperature = ctx.safeTemp val imageInput = ctx.imageInput val inputImageColumn = ctx.inputImageColumn + val audioInput = ctx.audioInput + val inputAudioColumn = ctx.inputAudioColumn pyb"""import os |import re |import json @@ -137,6 +139,8 @@ object PythonCodegenBase { | self.TEMPERATURE = $temperature | self.IMAGE_INPUT = $imageInput | self.INPUT_IMAGE_COLUMN = $inputImageColumn + | self.AUDIO_INPUT = $audioInput + | self.INPUT_AUDIO_COLUMN = $inputAudioColumn | | def _resolve_providers(self, token): | '''Query the HF Hub API for inference providers serving this model. @@ -286,7 +290,14 @@ object PythonCodegenBase { | if provider_name == "replicate": | url = f"{base}/v1/models/{provider_id}/predictions" | hdrs = {**json_headers, "Prefer": "wait"} - | if task == "image-to-image" and img_b64: + | if task == "text-to-speech": + | inp = {"text": prompt_value} + | elif task in ("text-to-image", "text-to-video"): + | inp = {"prompt": prompt_value} + | elif task in ("automatic-speech-recognition", "audio-classification") and img_b64: + | audio_content_type = raw_binary_headers.get("Content-Type", "audio/mpeg") + | inp = {"audio": f"data:{audio_content_type};base64,{img_b64}"} + | elif task == "image-to-image" and img_b64: | data_url = f"data:image/png;base64,{img_b64}" | inp = {"image": data_url, "images": [data_url], "input_image": data_url, "prompt": prompt_value} | elif img_b64: @@ -340,6 +351,10 @@ object PythonCodegenBase { | # Fal-ai: per-model endpoint. 
| if provider_name == "fal-ai": | url = f"{base}/{provider_id}" + | if task == "text-to-speech": + | return requests.post(url, headers=json_headers, json={"text": prompt_value}, timeout=120) + | if task in ("text-to-image", "text-to-video"): + | return requests.post(url, headers=json_headers, json={"prompt": prompt_value}, timeout=120) | if task == "image-to-image" and img_b64: | data_url = f"data:image/png;base64,{img_b64}" | return requests.post(url, headers=json_headers, json={"image_url": data_url, "image_urls": [data_url], "prompt": prompt_value}, timeout=120) @@ -398,6 +413,12 @@ object PythonCodegenBase { | return poll_resp | | if provider_name in self.OPENAI_COMPATIBLE_PROVIDERS: + | if task == "text-to-image": + | url = f"{base}/v1/images/generations" + | return requests.post(url, headers=json_headers, json={"model": provider_id, "prompt": prompt_value}, timeout=120) + | if task == "text-to-speech": + | url = f"{base}/v1/audio/speech" + | return requests.post(url, headers=json_headers, json={"model": provider_id, "input": prompt_value}, timeout=120) | url = f"{base}/{self.CHAT_ROUTES.get(provider_name, 'v1/chat/completions')}" | messages = [{"role": "user", "content": prompt_value}] | if img_b64: @@ -444,6 +465,7 @@ object PythonCodegenBase { | image_only_tasks = ("image-classification", "object-detection", "image-segmentation", "image-to-text") | image_prompt_tasks = ("visual-question-answering", "document-question-answering", "zero-shot-image-classification", "image-text-to-text", "image-to-image") | image_tasks = image_only_tasks + image_prompt_tasks + | audio_only_tasks = ("automatic-speech-recognition", "audio-classification") | | # --- validate MODEL_ID format before any HF URL is built --- | if not _HF_MODEL_ID_PATTERN.match(self.MODEL_ID or ""): @@ -463,8 +485,8 @@ object PythonCodegenBase { | # --- resolve all available inference providers for this model (tried in order) --- | providers = self._resolve_providers(token) | - | # --- validate prompt 
column exists (required for non-image tasks) --- - | if task not in image_tasks: + | # --- validate prompt column exists (skipped for image tasks and binary-only audio tasks) --- + | if task not in image_tasks and task not in audio_only_tasks: | assert prompt_col in table.columns, ( | f"Prompt column '{prompt_col}' not found in input table. " | f"Available columns: {list(table.columns)}" @@ -484,12 +506,19 @@ object PythonCodegenBase { | "Authorization": f"Bearer {token}", | "Content-Type": "application/octet-stream", | } - | | # --- resolve image source (upload or column) for image tasks --- | has_image_upload = bool(self.IMAGE_INPUT) and bool(str(self.IMAGE_INPUT).strip()) | use_image_column = not has_image_upload and bool(self.INPUT_IMAGE_COLUMN) and self.INPUT_IMAGE_COLUMN in table.columns | image_bytes = None | image_error = None + | has_audio_upload = bool(self.AUDIO_INPUT) and bool(str(self.AUDIO_INPUT).strip()) + | use_audio_column = not has_audio_upload and bool(self.INPUT_AUDIO_COLUMN) and self.INPUT_AUDIO_COLUMN in table.columns + | audio_headers = { + | "Authorization": f"Bearer {token}", + | "Content-Type": "application/octet-stream" if use_audio_column else self._get_audio_content_type(), + | } + | audio_bytes = None + | audio_error = None | if task in image_tasks and not use_image_column: | if not has_image_upload: | image_error = "No image source. Set an Input Image Column or upload an image." @@ -498,15 +527,28 @@ object PythonCodegenBase { | image_bytes = self._read_image_input() | except Exception as e: | image_error = f"Could not read image input ({type(e).__name__}: {e})" + | if task in audio_only_tasks and not use_audio_column: + | if not has_audio_upload: + | audio_error = "No audio source. Set an Input Audio Column or upload audio." 
+ | else: + | try: + | audio_bytes = self._read_audio_input() + | except Exception as e: + | audio_error = f"Could not read audio input ({type(e).__name__}: {e})" | | results = [] | for idx, row in table.iterrows(): | if image_error is not None: | results.append(self._format_error("Image task configuration error", image_error)) | continue + | if audio_error is not None: + | results.append(self._format_error("Audio task configuration error", audio_error)) + | continue | | if task in image_only_tasks: | prompt_value = "" + | elif task in audio_only_tasks: + | prompt_value = "" | elif task in image_prompt_tasks and prompt_col not in table.columns: | prompt_value = "What is shown in this image?" | else: @@ -529,6 +571,18 @@ object PythonCodegenBase { | results.append(self._format_error("Image data error", f"Row {idx}: {type(e).__name__}: {e}")) | continue | + | # --- resolve per-row audio bytes from column --- + | current_audio_bytes = audio_bytes + | if task in audio_only_tasks and use_audio_column: + | try: + | current_audio_bytes = self._read_binary_value(row[self.INPUT_AUDIO_COLUMN]) + | if current_audio_bytes is None: + | results.append(self._format_error("Audio data error", f"Row {idx}: audio column is empty")) + | continue + | except Exception as e: + | results.append(self._format_error("Audio data error", f"Row {idx}: {type(e).__name__}: {e}")) + | continue + | | # --- build task-specific payload (provided by per-task codegen) --- | use_raw_binary_body = False | raw_binary_headers = image_headers @@ -576,6 +630,10 @@ object PythonCodegenBase { | b64 = base64.b64encode(resp.content).decode("utf-8") | results.append(f"data:{content_type};base64,{b64}") | continue + | if content_type.startswith("audio/") or content_type.startswith("video/"): + | b64 = base64.b64encode(resp.content).decode("utf-8") + | results.append(f"data:{content_type};base64,{b64}") + | continue | | try: | body = resp.json() @@ -702,6 +760,22 @@ object PythonCodegenBase { | def 
_image_input_as_base64(self, image_bytes): | return base64.b64encode(image_bytes).decode("utf-8") | + | def _read_audio_input(self): + | audio_input = str(self.AUDIO_INPUT or "").strip() + | if audio_input.startswith("data:"): + | _, encoded = audio_input.split(",", 1) + | return base64.b64decode(encoded) + | if audio_input.startswith("http://") or audio_input.startswith("https://"): + | _, data = self._fetch_remote_url(audio_input) + | return data + | # Reading arbitrary worker-filesystem paths is intentionally NOT + | # supported: uploaded audio arrives as a data URL and remote audio + | # must be fetched through the hardened https-only helper above. + | raise ValueError( + | "Unsupported audio input. Upload an audio file (sent as a data URL) " + | "or provide a public https audio URL." + | ) + | | def _read_binary_value(self, value): | if value is None: | return None @@ -821,6 +895,30 @@ object PythonCodegenBase { | return text[start_pos:pos], pos | return None, start_pos | + | def _get_audio_content_type(self): + | audio_input = str(self.AUDIO_INPUT or "").strip().lower() + | if audio_input.startswith("data:"): + | header = audio_input.split(",", 1)[0] + | if ";" in header: + | return header[5:header.index(";")] + | return header[5:] + | extension_map = { + | ".mp3": "audio/mpeg", + | ".mpeg": "audio/mpeg", + | ".wav": "audio/wav", + | ".flac": "audio/flac", + | ".ogg": "audio/ogg", + | ".oga": "audio/ogg", + | ".webm": "audio/webm", + | ".opus": "audio/webm;codecs=opus", + | ".amr": "audio/amr", + | ".m4a": "audio/m4a", + | } + | from urllib.parse import urlparse as _urlparse + | path = _urlparse(audio_input).path if audio_input.startswith("http") else audio_input + | _, ext = os.path.splitext(path) + | return extension_map.get(ext, "audio/mpeg") + | | def _url_to_data_url(self, url): | '''Fetch a URL and return a data URL with the correct MIME type. 
| Fetched via _fetch_remote_url so a malicious/compromised provider @@ -831,12 +929,12 @@ object PythonCodegenBase { | if not content_type or content_type == "application/octet-stream": | from urllib.parse import urlparse as _urlparse | ext = os.path.splitext(_urlparse(url).path.lower())[1] - | mime_map = {".png": "image/png", ".jpg": "image/jpeg", ".jpeg": "image/jpeg", ".gif": "image/gif", ".webp": "image/webp", ".svg": "image/svg+xml", ".mp4": "video/mp4", ".webm": "video/webm"} + | mime_map = {".png": "image/png", ".jpg": "image/jpeg", ".jpeg": "image/jpeg", ".gif": "image/gif", ".webp": "image/webp", ".svg": "image/svg+xml", ".mp3": "audio/mpeg", ".mpeg": "audio/mpeg", ".wav": "audio/wav", ".flac": "audio/flac", ".ogg": "audio/ogg", ".oga": "audio/ogg", ".m4a": "audio/m4a", ".mp4": "video/mp4", ".webm": "video/webm"} | guessed = mime_map.get(ext, "") | if guessed: | content_type = guessed | else: - | task_mime = {"image-to-image": "image/png"} + | task_mime = {"image-to-image": "image/png", "text-to-image": "image/png", "text-to-video": "video/mp4", "text-to-speech": "audio/mpeg"} | content_type = task_mime.get(self.TASK, "application/octet-stream") | b64 = base64.b64encode(data).decode("utf-8") | return f"data:{content_type};base64,{b64}" diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/TaskCodegen.scala b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/TaskCodegen.scala index 299ea5d6e3..80bbcc58fc 100644 --- a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/TaskCodegen.scala +++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/TaskCodegen.scala @@ -39,7 +39,9 @@ final case class CodegenContext( safeMaxTokens: Int, safeTemp: Double, imageInput: EncodableString = "", - inputImageColumn: EncodableString = "" + inputImageColumn: EncodableString = "", + audioInput: 
EncodableString = "", + inputAudioColumn: EncodableString = "" ) /** diff --git a/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDescSpec.scala b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDescSpec.scala index 0d6e09302f..eb728945f3 100644 --- a/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDescSpec.scala +++ b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDescSpec.scala @@ -21,7 +21,12 @@ package org.apache.texera.amber.operator.huggingFace import org.apache.texera.amber.core.tuple.{AttributeType, Schema} import org.apache.texera.amber.core.workflow.PortIdentity -import org.apache.texera.amber.operator.huggingFace.codegen.{CodegenContext, TextGenCodegen} +import org.apache.texera.amber.operator.huggingFace.codegen.{ + AudioTaskCodegen, + CodegenContext, + MediaGenCodegen, + TextGenCodegen +} import org.apache.texera.amber.operator.metadata.OperatorGroupConstants import org.apache.texera.amber.pybuilder.PyStringTypes.EncodableString import org.scalatest.flatspec.AnyFlatSpec @@ -39,7 +44,9 @@ class HuggingFaceInferenceOpDescSpec extends AnyFlatSpec with Matchers { temperature: Double = 0.7, resultColumn: EncodableString = "hf_response", imageInput: EncodableString = "", - inputImageColumn: EncodableString = "" + inputImageColumn: EncodableString = "", + audioInput: EncodableString = "", + inputAudioColumn: EncodableString = "" ): HuggingFaceInferenceOpDesc = { val desc = new HuggingFaceInferenceOpDesc() desc.hfApiToken = token @@ -52,6 +59,8 @@ class HuggingFaceInferenceOpDescSpec extends AnyFlatSpec with Matchers { desc.resultColumn = resultColumn desc.imageInput = imageInput desc.inputImageColumn = inputImageColumn + desc.audioInput = audioInput + desc.inputAudioColumn = inputAudioColumn desc } @@ -152,6 +161,8 @@ class 
HuggingFaceInferenceOpDescSpec extends AnyFlatSpec with Matchers { desc.temperature = null desc.imageInput = null desc.inputImageColumn = null + desc.audioInput = null + desc.inputAudioColumn = null val code = desc.generatePythonCode() code should include("class ProcessTableOperator(UDFTableOperator):") code should include("def open(self):") @@ -272,10 +283,15 @@ class HuggingFaceInferenceOpDescSpec extends AnyFlatSpec with Matchers { // size cap code should include("MAX_REMOTE_FETCH_BYTES") code should include("Remote file exceeds the") - // all three fetch sites route through the helper (no raw requests.get on these URLs) + // all remote fetch sites route through the helper (no raw requests.get on these URLs) code should include("_, data = self._fetch_remote_url(image_input)") + code should include("_, data = self._fetch_remote_url(audio_input)") code should include("_, data = self._fetch_remote_url(val)") code should include("raw_content_type, data = self._fetch_remote_url(url)") + code should not include "def _audio_url_to_data_url" + code should not include "requests.get(audio_input" + code should not include "os.path.exists(audio_input)" + code should not include "open(audio_input" } it should "treat pandas NA sentinels (NaN, pd.NA, NaT) as missing in _read_binary_value" in { @@ -402,6 +418,85 @@ class HuggingFaceInferenceOpDescSpec extends AnyFlatSpec with Matchers { } } + "audio task family" should + "route ASR and audio-classification through AudioTaskCodegen as raw binary payloads" in { + val code = + makeDesc(task = "automatic-speech-recognition", inputAudioColumn = "audio") + .generatePythonCode() + code should include("self.AUDIO_INPUT = ") + code should include("self.INPUT_AUDIO_COLUMN = ") + code should include( + """audio_only_tasks = ("automatic-speech-recognition", "audio-classification")""" + ) + code should include("payload = current_audio_bytes") + code should include("raw_binary_headers = audio_headers") + code should 
include("self._read_audio_input()") + code should include( + """"Content-Type": "application/octet-stream" if use_audio_column else self._get_audio_content_type()""" + ) + code should include( + """path = _urlparse(audio_input).path if audio_input.startswith("http") else audio_input""" + ) + code should include( + """audio_content_type = raw_binary_headers.get("Content-Type", "audio/mpeg")""" + ) + code should include( + """elif task in ("automatic-speech-recognition", "audio-classification") and img_b64:""" + ) + code should not include "data:audio/wav;base64" + code should include( + """if content_type.startswith("audio/") or content_type.startswith("video/"):""" + ) + } + + it should "route text-to-speech through AudioTaskCodegen and normalize audio URLs" in { + val code = makeDesc(task = "text-to-speech").generatePythonCode() + code should include("""elif task == "text-to-speech":""") + code should include("""payload = {"inputs": prompt_value}""") + code should include("self._url_to_data_url(") + code should include(""""text-to-speech": "audio/mpeg"""") + code should include("""".m4a": "audio/m4a"""") + code should not include "_audio_url_to_data_url" + code should include("data:audio/mpeg;base64") + } + + it should "register all audio task strings under the dispatcher" in { + AudioTaskCodegen.tasks should contain allOf ( + "automatic-speech-recognition", + "audio-classification", + "text-to-speech" + ) + AudioTaskCodegen.tasks.foreach { t => + val code = makeDesc(task = t, inputAudioColumn = "audio").generatePythonCode() + code should include("if task in audio_only_tasks:") + } + } + + "media generation task family" should + "route text-to-image through MediaGenCodegen and parse URL or b64 responses as data URLs" in { + val code = makeDesc(task = "text-to-image").generatePythonCode() + code should include("if task not in image_tasks and task not in audio_only_tasks:") + code should include("""payload = {"inputs": prompt_value}""") + code should include("""if 
task == "text-to-image":""") + code should include("self._url_to_data_url(") + code should include("data:image/png;base64") + } + + it should "route text-to-video through MediaGenCodegen and normalize remote video URLs" in { + val code = makeDesc(task = "text-to-video").generatePythonCode() + code should include("""elif task == "text-to-video":""") + code should include("self._url_to_data_url(") + code should include("video/mp4") + } + + it should "register all media generation task strings under the dispatcher" in { + MediaGenCodegen.tasks should contain allOf ("text-to-image", "text-to-video") + MediaGenCodegen.tasks.foreach { t => + val code = makeDesc(task = t).generatePythonCode() + code should include("""payload = {"inputs": prompt_value}""") + } + } + "getOutputSchemas" should "add the result column as a STRING to the inherited schema" in { val desc = makeDesc(resultColumn = "answer") val inputSchema = Schema().add("prompt", AttributeType.STRING)
