anishshiva7 commented on code in PR #5570: URL: https://github.com/apache/texera/pull/5570#discussion_r3411208110
########## common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/PythonCodegenBase.scala: ########## @@ -0,0 +1,869 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.texera.amber.operator.huggingFace.codegen + +import org.apache.texera.amber.pybuilder.PythonTemplateBuilder.PythonTemplateBuilderStringContext + +/** + * Builds the Python script emitted by HuggingFaceInferenceOpDesc. + * + * The script defines a `ProcessTableOperator` class with: + * - Per-instance configuration set in `open(self)` from base64-encoded + * values that the `pyb"..."` macro decodes at runtime (so user-input + * strings never appear as raw Python literals in the source). + * - A provider-fallback system that walks the HF Hub's inference-provider + * list cheapest-first and tries each provider's native chat-completions + * route, with HF Inference Router as the default. + * - A `process_table` loop that validates the prompt column, builds the + * per-row payload via the per-task codegen, posts to the resolved + * provider, and parses the response. + * - A `_parse_response` task switch whose branches are provided by the + * per-task codegen. 
+ * + * Per-task variation lives in `TaskCodegen` implementations. This class + * holds only what is shared across all HF tasks; per-task helpers (image + * loading, audio MIME inference, media-URL fetching, etc.) will be added + * in subsequent PRs as the corresponding task families land. + */ +object PythonCodegenBase { + + def render(ctx: CodegenContext, codegen: TaskCodegen): String = { + val payload = codegen.payloadPython(ctx) + val parse = codegen.parsePython(ctx) + val hfApiToken = ctx.hfApiToken + val modelId = ctx.modelId + val promptColumn = ctx.promptColumn + val resultColumn = ctx.resultColumn + val task = ctx.task + val systemPrompt = ctx.systemPrompt + val maxNewTokens = ctx.safeMaxTokens + val temperature = ctx.safeTemp + val imageInput = ctx.imageInput + val inputImageColumn = ctx.inputImageColumn + val audioInput = ctx.audioInput + val inputAudioColumn = ctx.inputAudioColumn + pyb"""import os + |import re + |import json + |import base64 + |import requests + |import pandas as pd + |from urllib.parse import urlparse + |from pytexera import * + | + |# Defensive format check for MODEL_ID before it is interpolated into + |# HF URL paths. The base host is hardcoded so the worst case isn't + |# SSRF, but rejecting `..` segments / query strings / fragments / + |# control chars keeps the operator's request shape predictable. + |_HF_MODEL_ID_PATTERN = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*(/[A-Za-z0-9._-]+)+$$") + | + |class ProcessTableOperator(UDFTableOperator): + | + | # Providers ranked cheapest-first (lower index = cheaper). + | # Unknown providers are appended at the end. 
+ | PROVIDER_COST_PRIORITY = [ + | "hf-inference", + | "cerebras", + | "sambanova", + | "groq", + | "novita", + | "nebius", + | "fireworks-ai", + | "together", + | "hyperbolic", + | "scaleway", + | "nscale", + | "ovhcloud", + | "deepinfra", + | "featherless-ai", + | "baseten", + | "publicai", + | "nvidia", + | "openai", + | "cohere", + | "clarifai", + | ] + | + | # Per-provider chat-completions route overrides. Providers not listed + | # here use the default `v1/chat/completions` path. Single source of + | # truth for both _post_with_fallback (text-gen) and _call_provider + | # (OpenAI-compatible fallback) so the two stay in sync as providers + | # are added. + | CHAT_ROUTES = { + | "groq": "openai/v1/chat/completions", + | "fireworks-ai": "inference/v1/chat/completions", + | "cohere": "compatibility/v1/chat/completions", + | "clarifai": "v2/ext/openai/v1/chat/completions", + | "deepinfra": "v1/openai/chat/completions", + | } + | + | # Third-party providers that speak the OpenAI chat-completions + | # protocol. Used by _call_provider's OpenAI-compatible branch. + | OPENAI_COMPATIBLE_PROVIDERS = ( + | "cerebras", "sambanova", "groq", "novita", "nebius", + | "fireworks-ai", "together", "hyperbolic", "cohere", "clarifai", + | "deepinfra", "featherless-ai", "nscale", "nvidia", "openai", + | "ovhcloud", "publicai", "scaleway", "baseten", + | ) + | + | def open(self): + | # User-provided strings reach the operator via base64-encoded + | # decode expressions so they cannot break Python syntax or + | # leak raw text into the generated source. 
+ | self.HF_API_TOKEN = $hfApiToken + | self.MODEL_ID = $modelId + | self.PROMPT_COLUMN = $promptColumn + | self.RESULT_COLUMN = $resultColumn + | self.TASK = $task + | self.SYSTEM_PROMPT = $systemPrompt + | self.MAX_NEW_TOKENS = $maxNewTokens + | self.TEMPERATURE = $temperature + | self.IMAGE_INPUT = $imageInput + | self.INPUT_IMAGE_COLUMN = $inputImageColumn + | self.AUDIO_INPUT = $audioInput + | self.INPUT_AUDIO_COLUMN = $inputAudioColumn + | + | def _resolve_providers(self, token): + | '''Query the HF Hub API for inference providers serving this model. + | Returns a list of dicts with 'name' and 'providerId' sorted + | cheapest-first. Falls back to hf-inference if anything goes wrong. + | ''' + | try: + | resp = requests.get( + | f"https://huggingface.co/api/models/{self.MODEL_ID}", + | headers={"Authorization": f"Bearer {token}"}, + | params={"expand[]": "inferenceProviderMapping"}, + | timeout=30, + | ) + | if resp.status_code == 200: + | data = resp.json() + | mapping = ( + | data.get("inferenceProviderMapping") + | or data.get("inference_provider_mapping") + | or {} + | ) + | if mapping: + | live = [ + | { + | "name": p, + | "providerId": v.get("providerId", self.MODEL_ID), + | "task": v.get("task", ""), + | "isModelAuthor": v.get("isModelAuthor", False), + | } + | for p, v in mapping.items() + | if isinstance(v, dict) and v.get("status") == "live" + | ] + | if live: + | priority = {name: idx for idx, name in enumerate(self.PROVIDER_COST_PRIORITY)} + | live.sort(key=lambda prov: priority.get(prov["name"], len(self.PROVIDER_COST_PRIORITY))) + | return live + | except Exception: + | pass + | return [{"name": "hf-inference", "providerId": self.MODEL_ID}] + | + | def _post_with_fallback(self, providers, json_headers, raw_binary_headers, pipeline_payload, use_raw_binary_body, prompt_value): + | '''Try providers in order, using the correct API route for each. + | Returns (response, provider_summary). 
provider_summary is None on + | success or a string describing what failed. + | ''' + | RETRYABLE = (400, 404, 422, 429, 502, 503) + | last_resp = None + | errors = [] + | for prov in providers: + | provider_name = prov["name"] + | provider_id = prov["providerId"] + | is_model_author = prov.get("isModelAuthor", False) + | prov_task = prov.get("task", "") + | try: + | if self.TASK in ("text-generation", "image-text-to-text"): + | route = self.CHAT_ROUTES.get(provider_name, "v1/chat/completions") + | url = f"https://router.huggingface.co/{provider_name}/{route}" + | resp = requests.post(url, headers=json_headers, json=pipeline_payload, timeout=120) + | elif is_model_author and prov_task in ("image-to-text", "image-text-to-text") and provider_name not in ("zai-org",): + | url = f"https://router.huggingface.co/{provider_name}/v1/chat/completions" + | img_b64 = "" + | if use_raw_binary_body and isinstance(pipeline_payload, bytes): + | img_b64 = base64.b64encode(pipeline_payload).decode("utf-8") + | chat_payload = { + | "model": provider_id, + | "messages": [{ + | "role": "user", + | "content": [ + | {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}} if img_b64 else None, + | {"type": "text", "text": prompt_value if prompt_value else "What is in this image?"}, + | ], + | }], + | } + | chat_payload["messages"][0]["content"] = [c for c in chat_payload["messages"][0]["content"] if c is not None] + | resp = requests.post(url, headers=json_headers, json=chat_payload, timeout=120) + | elif provider_name == "hf-inference": + | url = f"https://router.huggingface.co/hf-inference/models/{self.MODEL_ID}" + | if use_raw_binary_body: + | resp = requests.post(url, headers=raw_binary_headers, data=pipeline_payload, timeout=120) + | else: + | resp = requests.post(url, headers=json_headers, json=pipeline_payload, timeout=120) + | else: + | resp = self._call_provider(provider_name, provider_id, json_headers, raw_binary_headers, pipeline_payload, 
use_raw_binary_body, prompt_value) + | except Exception as e: + | errors.append(f"{provider_name}: {type(e).__name__}") + | continue + | if resp.status_code in (200, 201): + | return resp, None + | if resp.status_code == 401: + | return resp, None + | try: + | detail = resp.json().get("error", resp.text[:200]) + | except Exception: + | detail = resp.text[:200] if resp.text else "no details" + | errors.append(f"{provider_name}: HTTP {resp.status_code} - {detail}") + | last_resp = resp + | if resp.status_code not in RETRYABLE: + | return resp, "; ".join(errors) + | summary = "; ".join(errors) if errors else "no providers available" + | return last_resp, summary + | + | def _call_provider(self, provider_name, provider_id, json_headers, raw_binary_headers, pipeline_payload, use_raw_binary_body, prompt_value): + | '''Route to a third-party provider using its native API format. + | Handles OpenAI-compatible chat providers for text-gen, zai-org's + | custom API, Replicate / Fal-ai / Wavespeed for media-generation + | and image-to-image, and an unknown-provider fallback that tries + | the pipeline format then chat completions. + | ''' + | base = f"https://router.huggingface.co/{provider_name}" + | task = self.TASK + | img_b64 = "" + | if use_raw_binary_body and isinstance(pipeline_payload, bytes): + | img_b64 = base64.b64encode(pipeline_payload).decode("utf-8") + | + | # zai-org: custom /api/paas/v4/ surface. 
+ | if provider_name == "zai-org": + | zai_headers = {**json_headers, "x-source-channel": "hugging_face", "accept-language": "en-US,en"} + | if task in ("image-to-text", "image-text-to-text"): + | url = f"{base}/api/paas/v4/layout_parsing" + | file_data = f"data:image/png;base64,{img_b64}" if img_b64 else "" + | return requests.post(url, headers=zai_headers, json={"model": provider_id, "file": file_data}, timeout=120) + | url = f"{base}/api/paas/v4/chat/completions" + | messages = [{"role": "user", "content": prompt_value}] + | if img_b64: + | messages = [{"role": "user", "content": [ + | {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}}, + | {"type": "text", "text": prompt_value if prompt_value else "What is in this image?"}, + | ]}] + | return requests.post(url, headers=zai_headers, json={"model": provider_id, "messages": messages}, timeout=120) + | + | # Replicate: synchronous predictions endpoint with polling fallback. + | if provider_name == "replicate": + | url = f"{base}/v1/models/{provider_id}/predictions" + | hdrs = {**json_headers, "Prefer": "wait"} + | if task == "text-to-speech": + | inp = {"text": prompt_value} + | elif task in ("text-to-image", "text-to-video"): + | inp = {"prompt": prompt_value} + | elif task == "automatic-speech-recognition" and img_b64: + | inp = {"audio": f"data:audio/wav;base64,{img_b64}"} Review Comment: Fixed by using the audio Content-Type already available in `raw_binary_headers` when constructing the Replicate data URL. Also added regression coverage to ensure `audio/wav` is no longer hardcoded. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
