Ma77Ball commented on code in PR #5570:
URL: https://github.com/apache/texera/pull/5570#discussion_r3383892997

##########
common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/AudioTaskCodegen.scala:
##########
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.texera.amber.operator.huggingFace.codegen
+
+/**
+ * Codegen for Hugging Face audio task families.
+ *
+ * ASR and audio-classification send audio bytes as the raw request body.
+ * Text-to-speech is prompt-driven and sends a JSON payload; its providers
+ * return either audio bytes directly or a JSON envelope pointing to audio.
+ */
+object AudioTaskCodegen extends TaskCodegen {
+
+  override val task: String = "automatic-speech-recognition"
+
+  override val tasks: Set[String] = Set(
+    "automatic-speech-recognition",
+    "audio-classification",
+    "text-to-speech"
+  )
+
+  override def payloadPython(ctx: CodegenContext): String =
+    """    if task in audio_only_tasks:
+      |        payload = current_audio_bytes
+      |        use_raw_binary_body = True
+      |        raw_binary_headers = audio_headers
+      |    elif task == "text-to-speech":
+      |        payload = {"inputs": prompt_value}
+      |    else:
+      |        payload = {"inputs": prompt_value}""".stripMargin

Review Comment:
The `else` branch is unreachable: `tasks` contains only the three audio tasks, and the branches above already cover all of them. It also duplicates the text-to-speech payload, so the two copies can drift apart if only one is edited. Fix:

```suggestion
      |    elif task == "text-to-speech":
      |        payload = {"inputs": prompt_value}""".stripMargin
```

##########
common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/MediaGenCodegen.scala:
##########
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.texera.amber.operator.huggingFace.codegen
+
+/**
+ * Codegen for prompt-driven media generation tasks.
+ *
+ * Providers return media in several shapes: raw bytes, OpenAI-style
+ * b64_json, or URLs. URL responses are normalized to data URLs by the
+ * shared `_url_to_data_url` helper so downstream result rendering receives
+ * a stable string format.
+ */
+object MediaGenCodegen extends TaskCodegen {
+
+  override val task: String = "text-to-image"
+
+  override val tasks: Set[String] = Set(
+    "text-to-image",
+    "text-to-video"
+  )
+
+  override def payloadPython(ctx: CodegenContext): String =
+    """    if task in ("text-to-image", "text-to-video"):
+      |        payload = {"inputs": prompt_value}
+      |    else:
+      |        payload = {"inputs": prompt_value}""".stripMargin

Review Comment:
`tasks` is exactly `{"text-to-image", "text-to-video"}`, so the `if` always matches and the `else` branch is dead code that duplicates the same payload. The conditional can be dropped entirely. Fix:

```suggestion
    """    payload = {"inputs": prompt_value}""".stripMargin
```

##########
common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/PythonCodegenBase.scala:
##########
@@ -0,0 +1,869 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.texera.amber.operator.huggingFace.codegen
+
+import org.apache.texera.amber.pybuilder.PythonTemplateBuilder.PythonTemplateBuilderStringContext
+
+/**
+ * Builds the Python script emitted by HuggingFaceInferenceOpDesc.
+ *
+ * The script defines a `ProcessTableOperator` class with:
+ *   - Per-instance configuration set in `open(self)` from base64-encoded
+ *     values that the `pyb"..."` macro decodes at runtime (so user-input
+ *     strings never appear as raw Python literals in the source).
+ *   - A provider-fallback system that walks the HF Hub's inference-provider
+ *     list cheapest-first and tries each provider's native chat-completions
+ *     route, with HF Inference Router as the default.
+ *   - A `process_table` loop that validates the prompt column, builds the
+ *     per-row payload via the per-task codegen, posts to the resolved
+ *     provider, and parses the response.
+ *   - A `_parse_response` task switch whose branches are provided by the
+ *     per-task codegen.
+ *
+ * Per-task variation lives in `TaskCodegen` implementations. This class
+ * holds only what is shared across all HF tasks; per-task helpers (image
+ * loading, audio MIME inference, media-URL fetching, etc.) will be added
+ * in subsequent PRs as the corresponding task families land.
+ */
+object PythonCodegenBase {
+
+  def render(ctx: CodegenContext, codegen: TaskCodegen): String = {
+    val payload = codegen.payloadPython(ctx)
+    val parse = codegen.parsePython(ctx)
+    val hfApiToken = ctx.hfApiToken
+    val modelId = ctx.modelId
+    val promptColumn = ctx.promptColumn
+    val resultColumn = ctx.resultColumn
+    val task = ctx.task
+    val systemPrompt = ctx.systemPrompt
+    val maxNewTokens = ctx.safeMaxTokens
+    val temperature = ctx.safeTemp
+    val imageInput = ctx.imageInput
+    val inputImageColumn = ctx.inputImageColumn
+    val audioInput = ctx.audioInput
+    val inputAudioColumn = ctx.inputAudioColumn
+    pyb"""import os
+       |import re
+       |import json
+       |import base64
+       |import requests
+       |import pandas as pd
+       |from urllib.parse import urlparse
+       |from pytexera import *
+       |
+       |# Defensive format check for MODEL_ID before it is interpolated into
+       |# HF URL paths. The base host is hardcoded so the worst case isn't
+       |# SSRF, but rejecting `..` segments / query strings / fragments /
+       |# control chars keeps the operator's request shape predictable.
+       |_HF_MODEL_ID_PATTERN = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*(/[A-Za-z0-9._-]+)+$$")
+       |
+       |class ProcessTableOperator(UDFTableOperator):
+       |
+       |    # Providers ranked cheapest-first (lower index = cheaper).
+       |    # Unknown providers are appended at the end.
+       |    PROVIDER_COST_PRIORITY = [
+       |        "hf-inference",
+       |        "cerebras",
+       |        "sambanova",
+       |        "groq",
+       |        "novita",
+       |        "nebius",
+       |        "fireworks-ai",
+       |        "together",
+       |        "hyperbolic",
+       |        "scaleway",
+       |        "nscale",
+       |        "ovhcloud",
+       |        "deepinfra",
+       |        "featherless-ai",
+       |        "baseten",
+       |        "publicai",
+       |        "nvidia",
+       |        "openai",
+       |        "cohere",
+       |        "clarifai",
+       |    ]
+       |
+       |    # Per-provider chat-completions route overrides. Providers not listed
+       |    # here use the default `v1/chat/completions` path. Single source of
+       |    # truth for both _post_with_fallback (text-gen) and _call_provider
+       |    # (OpenAI-compatible fallback) so the two stay in sync as providers
+       |    # are added.
+       |    CHAT_ROUTES = {
+       |        "groq": "openai/v1/chat/completions",
+       |        "fireworks-ai": "inference/v1/chat/completions",
+       |        "cohere": "compatibility/v1/chat/completions",
+       |        "clarifai": "v2/ext/openai/v1/chat/completions",
+       |        "deepinfra": "v1/openai/chat/completions",
+       |    }
+       |
+       |    # Third-party providers that speak the OpenAI chat-completions
+       |    # protocol. Used by _call_provider's OpenAI-compatible branch.
+       |    OPENAI_COMPATIBLE_PROVIDERS = (
+       |        "cerebras", "sambanova", "groq", "novita", "nebius",
+       |        "fireworks-ai", "together", "hyperbolic", "cohere", "clarifai",
+       |        "deepinfra", "featherless-ai", "nscale", "nvidia", "openai",
+       |        "ovhcloud", "publicai", "scaleway", "baseten",
+       |    )
+       |
+       |    def open(self):
+       |        # User-provided strings reach the operator via base64-encoded
+       |        # decode expressions so they cannot break Python syntax or
+       |        # leak raw text into the generated source.
+       |        self.HF_API_TOKEN = $hfApiToken
+       |        self.MODEL_ID = $modelId
+       |        self.PROMPT_COLUMN = $promptColumn
+       |        self.RESULT_COLUMN = $resultColumn
+       |        self.TASK = $task
+       |        self.SYSTEM_PROMPT = $systemPrompt
+       |        self.MAX_NEW_TOKENS = $maxNewTokens
+       |        self.TEMPERATURE = $temperature
+       |        self.IMAGE_INPUT = $imageInput
+       |        self.INPUT_IMAGE_COLUMN = $inputImageColumn
+       |        self.AUDIO_INPUT = $audioInput
+       |        self.INPUT_AUDIO_COLUMN = $inputAudioColumn
+       |
+       |    def _resolve_providers(self, token):
+       |        '''Query the HF Hub API for inference providers serving this model.
+       |        Returns a list of dicts with 'name' and 'providerId' sorted
+       |        cheapest-first. Falls back to hf-inference if anything goes wrong.
+       |        '''
+       |        try:
+       |            resp = requests.get(
+       |                f"https://huggingface.co/api/models/{self.MODEL_ID}",
+       |                headers={"Authorization": f"Bearer {token}"},
+       |                params={"expand[]": "inferenceProviderMapping"},
+       |                timeout=30,
+       |            )
+       |            if resp.status_code == 200:
+       |                data = resp.json()
+       |                mapping = (
+       |                    data.get("inferenceProviderMapping")
+       |                    or data.get("inference_provider_mapping")
+       |                    or {}
+       |                )
+       |                if mapping:
+       |                    live = [
+       |                        {
+       |                            "name": p,
+       |                            "providerId": v.get("providerId", self.MODEL_ID),
+       |                            "task": v.get("task", ""),
+       |                            "isModelAuthor": v.get("isModelAuthor", False),
+       |                        }
+       |                        for p, v in mapping.items()
+       |                        if isinstance(v, dict) and v.get("status") == "live"
+       |                    ]
+       |                    if live:
+       |                        priority = {name: idx for idx, name in enumerate(self.PROVIDER_COST_PRIORITY)}
+       |                        live.sort(key=lambda prov: priority.get(prov["name"], len(self.PROVIDER_COST_PRIORITY)))
+       |                        return live
+       |        except Exception:
+       |            pass
+       |        return [{"name": "hf-inference", "providerId": self.MODEL_ID}]
+       |
+       |    def _post_with_fallback(self, providers, json_headers, raw_binary_headers, pipeline_payload, use_raw_binary_body, prompt_value):
+       |        '''Try providers in order, using the correct API route for each.
+       |        Returns (response, provider_summary). provider_summary is None on
+       |        success or a string describing what failed.
+       |        '''
+       |        RETRYABLE = (400, 404, 422, 429, 502, 503)
+       |        last_resp = None
+       |        errors = []
+       |        for prov in providers:
+       |            provider_name = prov["name"]
+       |            provider_id = prov["providerId"]
+       |            is_model_author = prov.get("isModelAuthor", False)
+       |            prov_task = prov.get("task", "")
+       |            try:
+       |                if self.TASK in ("text-generation", "image-text-to-text"):
+       |                    route = self.CHAT_ROUTES.get(provider_name, "v1/chat/completions")
+       |                    url = f"https://router.huggingface.co/{provider_name}/{route}"
+       |                    resp = requests.post(url, headers=json_headers, json=pipeline_payload, timeout=120)
+       |                elif is_model_author and prov_task in ("image-to-text", "image-text-to-text") and provider_name not in ("zai-org",):
+       |                    url = f"https://router.huggingface.co/{provider_name}/v1/chat/completions"
+       |                    img_b64 = ""
+       |                    if use_raw_binary_body and isinstance(pipeline_payload, bytes):
+       |                        img_b64 = base64.b64encode(pipeline_payload).decode("utf-8")
+       |                    chat_payload = {
+       |                        "model": provider_id,
+       |                        "messages": [{
+       |                            "role": "user",
+       |                            "content": [
+       |                                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}} if img_b64 else None,
+       |                                {"type": "text", "text": prompt_value if prompt_value else "What is in this image?"},
+       |                            ],
+       |                        }],
+       |                    }
+       |                    chat_payload["messages"][0]["content"] = [c for c in chat_payload["messages"][0]["content"] if c is not None]
+       |                    resp = requests.post(url, headers=json_headers, json=chat_payload, timeout=120)
+       |                elif provider_name == "hf-inference":
+       |                    url = f"https://router.huggingface.co/hf-inference/models/{self.MODEL_ID}"
+       |                    if use_raw_binary_body:
+       |                        resp = requests.post(url, headers=raw_binary_headers, data=pipeline_payload, timeout=120)
+       |                    else:
+       |                        resp = requests.post(url, headers=json_headers, json=pipeline_payload, timeout=120)
+       |                else:
+       |                    resp = self._call_provider(provider_name, provider_id, json_headers, raw_binary_headers, pipeline_payload, use_raw_binary_body, prompt_value)
+       |            except Exception as e:
+       |                errors.append(f"{provider_name}: {type(e).__name__}")
+       |                continue
+       |            if resp.status_code in (200, 201):
+       |                return resp, None
+       |            if resp.status_code == 401:
+       |                return resp, None
+       |            try:
+       |                detail = resp.json().get("error", resp.text[:200])
+       |            except Exception:
+       |                detail = resp.text[:200] if resp.text else "no details"
+       |            errors.append(f"{provider_name}: HTTP {resp.status_code} - {detail}")
+       |            last_resp = resp
+       |            if resp.status_code not in RETRYABLE:
+       |                return resp, "; ".join(errors)
+       |        summary = "; ".join(errors) if errors else "no providers available"
+       |        return last_resp, summary
+       |
+       |    def _call_provider(self, provider_name, provider_id, json_headers, raw_binary_headers, pipeline_payload, use_raw_binary_body, prompt_value):
+       |        '''Route to a third-party provider using its native API format.
+       |        Handles OpenAI-compatible chat providers for text-gen, zai-org's
+       |        custom API, Replicate / Fal-ai / Wavespeed for media-generation
+       |        and image-to-image, and an unknown-provider fallback that tries
+       |        the pipeline format then chat completions.
+       |        '''
+       |        base = f"https://router.huggingface.co/{provider_name}"
+       |        task = self.TASK
+       |        img_b64 = ""
+       |        if use_raw_binary_body and isinstance(pipeline_payload, bytes):
+       |            img_b64 = base64.b64encode(pipeline_payload).decode("utf-8")
+       |
+       |        # zai-org: custom /api/paas/v4/ surface.
+       |        if provider_name == "zai-org":
+       |            zai_headers = {**json_headers, "x-source-channel": "hugging_face", "accept-language": "en-US,en"}
+       |            if task in ("image-to-text", "image-text-to-text"):
+       |                url = f"{base}/api/paas/v4/layout_parsing"
+       |                file_data = f"data:image/png;base64,{img_b64}" if img_b64 else ""
+       |                return requests.post(url, headers=zai_headers, json={"model": provider_id, "file": file_data}, timeout=120)
+       |            url = f"{base}/api/paas/v4/chat/completions"
+       |            messages = [{"role": "user", "content": prompt_value}]
+       |            if img_b64:
+       |                messages = [{"role": "user", "content": [
+       |                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}},
+       |                    {"type": "text", "text": prompt_value if prompt_value else "What is in this image?"},
+       |                ]}]
+       |            return requests.post(url, headers=zai_headers, json={"model": provider_id, "messages": messages}, timeout=120)
+       |
+       |        # Replicate: synchronous predictions endpoint with polling fallback.
+       |        if provider_name == "replicate":
+       |            url = f"{base}/v1/models/{provider_id}/predictions"
+       |            hdrs = {**json_headers, "Prefer": "wait"}
+       |            if task == "text-to-speech":
+       |                inp = {"text": prompt_value}
+       |            elif task in ("text-to-image", "text-to-video"):
+       |                inp = {"prompt": prompt_value}
+       |            elif task == "automatic-speech-recognition" and img_b64:
+       |                inp = {"audio": f"data:audio/wav;base64,{img_b64}"}

Review Comment:
In the Replicate branch, the MIME type in the audio data URL is hardcoded to `audio/wav`, but `current_audio_bytes` may hold MP3, FLAC, or OGG data, and the actual type is already computed by `_get_audio_content_type()`. A non-WAV payload labeled `audio/wav` can be rejected or mis-decoded by the model. The detected type is not currently passed into `_call_provider`, so fixing this takes a small refactor: thread it through as a parameter and use it when building the data URL.
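A minimal sketch of the fix proposed in the last comment, written as standalone Python rather than inside the `pyb` template. `sniff_audio_mime` and `replicate_asr_input` are hypothetical helpers: the magic-byte checks only stand in for `_get_audio_content_type()`, whose body is not shown in this hunk, and in the operator the detected type would be passed into `_call_provider` as a new parameter instead of being recomputed there.

```python
import base64


def sniff_audio_mime(data: bytes, default: str = "audio/wav") -> str:
    """Guess an audio MIME type from leading magic bytes (hypothetical helper).

    Recognizes only the formats named in the review comment; anything else
    gets `default`, which matches the current hardcoded behavior.
    """
    if data[:4] == b"RIFF" and data[8:12] == b"WAVE":
        return "audio/wav"
    if data[:4] == b"fLaC":
        return "audio/flac"
    if data[:4] == b"OggS":
        return "audio/ogg"
    if data[:3] == b"ID3":
        # MP3 file that starts with an ID3v2 tag.
        return "audio/mpeg"
    if len(data) >= 2 and data[0] == 0xFF and (data[1] & 0xE0) == 0xE0 and (data[1] & 0x06) != 0:
        # Bare MPEG audio frame: 11-bit sync plus a non-zero layer field
        # (layer bits 00 are reserved in MPEG audio and used by AAC ADTS).
        return "audio/mpeg"
    return default


def replicate_asr_input(audio_bytes: bytes, content_type: str) -> dict:
    """Build Replicate's ASR `input` with the real MIME type in the data URL."""
    audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
    return {"audio": f"data:{content_type};base64,{audio_b64}"}
```

In `_call_provider`, the ASR branch would then read roughly `inp = replicate_asr_input(pipeline_payload, audio_content_type)`, where `audio_content_type` is the hypothetical new parameter. Falling back to `audio/wav` when sniffing fails keeps today's behavior for unrecognized input.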
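For the first two review comments, here is a minimal Python sketch of per-row payload selection once the dead `else` branches are removed. Both codegens are folded into one function for illustration only. `build_payload`, `uncovered_tasks`, and the constant names are hypothetical, and `AUDIO_ONLY_TASKS` assumes that `audio_only_tasks` holds the two audio-input tasks. `uncovered_tasks` spells out the exhaustiveness argument both comments rely on: every task in each codegen's `tasks` set reaches a real branch.

```python
# Hypothetical stand-ins; AUDIO_ONLY_TASKS is an assumption about the
# contents of the generated script's `audio_only_tasks`.
AUDIO_ONLY_TASKS = frozenset({"automatic-speech-recognition", "audio-classification"})
AUDIO_TASKS = AUDIO_ONLY_TASKS | {"text-to-speech"}              # AudioTaskCodegen.tasks
MEDIA_GEN_TASKS = frozenset({"text-to-image", "text-to-video"})  # MediaGenCodegen.tasks


def build_payload(task, prompt_value, current_audio_bytes=None, audio_headers=None):
    """Return (payload, use_raw_binary_body, raw_binary_headers) for one row."""
    if task in AUDIO_ONLY_TASKS:
        # ASR / audio-classification: raw audio bytes as the request body.
        return current_audio_bytes, True, audio_headers
    if task == "text-to-speech" or task in MEDIA_GEN_TASKS:
        # Prompt-driven tasks share one JSON shape, written exactly once.
        return {"inputs": prompt_value}, False, None
    # Optional guard: fail loudly instead of falling through silently.
    raise ValueError(f"no payload branch for task {task!r}")


def uncovered_tasks(tasks):
    """Return the tasks that would hit the guard (expected to be empty)."""
    missing = set()
    for task in tasks:
        try:
            build_payload(task, "prompt", b"", {})
        except ValueError:
            missing.add(task)
    return missing
```

The `raise` is an optional extra, not part of the suggested change. Without it, a task later added to `tasks` but not to the template could leave `payload` unassigned (or holding a previous row's value) instead of failing at the branch.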
