This is an automated email from the ASF dual-hosted git repository.

github-merge-queue[bot] pushed a commit to branch gh-readonly-queue/main/pr-5320-45d28ce962926e1e21a27fa10bf3652ae2956227
in repository https://gitbox.apache.org/repos/asf/texera.git

commit 731d671ccb4ce8bd5bf3f1b55fa4f0ee949567c5
Author: Prateek Ganigi <[email protected]>
AuthorDate: Wed Jun 17 23:24:04 2026 -0700

feat(huggingFace): add image task family via ImageTaskCodegen (#5320)

### What changes were proposed in this PR?

Adds the image task family (9 HF pipeline tasks) as the second `TaskCodegen` plugged into the dispatcher established by #5278:

image-only: image-classification, object-detection, image-segmentation, image-to-text

image + prompt: visual-question-answering, document-question-answering, zero-shot-image-classification, image-text-to-text, image-to-image

- `codegen/ImageTaskCodegen.scala` supplies the per-task payload + parse Python branches for all 9 tasks.
- `TaskCodegen` trait gains a `tasks: Set[String]` default method (defaults to `Set(task)`) so a single codegen can register under multiple task strings; `ImageTaskCodegen` is the first multi-task codegen to use it.
- `CodegenContext` extended with `imageInput` + `inputImageColumn` (`EncodableString`).
- `HuggingFaceInferenceOpDesc.scala` gains 2 new `@JsonProperty` fields and registers `ImageTaskCodegen` via the new `tasks` flat-map.

`PythonCodegenBase.scala` grows to host the shared image infrastructure:

- Task-family tuples (`image_only_tasks`, `image_prompt_tasks`, `image_tasks`) + `image_headers` in `process_table`.
- Per-row image-bytes resolution from upload or column with `_read_image_input` / `_read_binary_value` / `_compress_image_bytes`.
- `_post_with_fallback` extended with `raw_binary_headers` + `use_raw_binary_body`; adds image-text-to-text chat-completions and model-author vision branches.
- `_call_provider` gains zai-org, Replicate predictions + polling, Fal-ai, Wavespeed submit+poll branches, and image embedding for OpenAI-compatible / unknown-provider fallbacks.
- Image content-type response handling returns `data:image/...;base64,...` URLs.
- Image helpers added: `_read_image_input`, `_compress_image_bytes`, `_image_input_as_base64`, `_read_binary_value`, `_looks_like_html`, `_html_to_image_bytes`, `_extract_json_arg`, `_url_to_data_url`.

Frontend integration (HF lines only; no agent / dataset noise): `HuggingFaceImageUploadComponent` declared in `app.module.ts`, `huggingface-image-upload` formly type registered, image upload component .ts/.html/.scss + `HuggingFace.png` + `sample-image.png` assets.

User-input strings continue to flow through `pyb"..."` + `EncodableString` so they reach Python as `self.decode_python_template('<base64>')` rather than raw literals. `PythonCodeRawInvalidTextSpec` still passes (117/117 descriptors `py_compile` cleanly).

### Any related issues, documentation, or discussions?

- Tracking issue: #5319
- Closes: #5319
- Stacked on: #5278 (operator + text-generation; issue #5277)
- Parent issue: #5041
- Closed sibling issue: #5134 (REST resource, landed via #5124)

### How was this PR tested?

- `sbt "WorkflowOperator/compile; WorkflowOperator/Test/compile"` clean.
- `sbt scalafmtCheck` clean.
- `sbt "WorkflowOperator/testOnly org.apache.texera.amber.operator.huggingFace.HuggingFaceInferenceOpDescSpec"`: 18/18 pass (PR 2's 13 spec tests + 5 new image-task tests: image-only routing, VQA / document-QA payload, image-text-to-text chat-completions, image-to-image data-URL parse, all-9-tasks dispatcher coverage).
- `sbt "WorkflowOperator/testOnly org.apache.texera.amber.util.PythonCodeRawInvalidTextSpec"`: 117/117 descriptors `py_compile` cleanly with the new operator code paths, no marker leaks.
- Generated Python verified via `python3 -m py_compile` on sample image-task outputs.

### Was this PR authored or co-authored using generative AI tooling?

Yes, co-authored with Claude Opus 4.7.

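For reference, the `data:image/...;base64,...` handling listed under the `PythonCodegenBase.scala` changes can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the generated operator code: the function names `to_data_url` and `from_data_url` are hypothetical, and only the standard-library `base64` module is used.

```python
import base64


def to_data_url(content: bytes, content_type: str) -> str:
    """Encode raw image bytes as a data URL (hypothetical helper name)."""
    b64 = base64.b64encode(content).decode("utf-8")
    return f"data:{content_type};base64,{b64}"


def from_data_url(data_url: str) -> bytes:
    """Decode the base64 payload of a data URL back into raw bytes."""
    if not data_url.startswith("data:"):
        raise ValueError("not a data URL")
    # Everything after the first comma is the base64-encoded payload.
    _, encoded = data_url.split(",", 1)
    return base64.b64decode(encoded)


# Round trip with the 8-byte PNG signature standing in for a real image.
sample = b"\x89PNG\r\n\x1a\n"
url = to_data_url(sample, "image/png")
assert from_data_url(url) == sample
```

The encode direction corresponds to an image-typed provider response being written to the result column; the decode direction corresponds to an uploaded image arriving as a data URL before its bytes are sent to the provider.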
---------

Signed-off-by: Prateek Ganigi <[email protected]>
Co-authored-by: Claude Opus 4.7 (1M context) <[email protected]>
Co-authored-by: Copilot Autofix powered by AI <[email protected]>
---
 .../huggingFace/HuggingFaceInferenceOpDesc.scala   |  29 +-
 .../huggingFace/codegen/ImageTaskCodegen.scala     | 166 +++++++
 .../huggingFace/codegen/PythonCodegenBase.scala    | 520 ++++++++++++++++++++-
 .../operator/huggingFace/codegen/TaskCodegen.scala |  16 +-
 .../HuggingFaceInferenceOpDescSpec.scala           | 222 ++++++++-
 frontend/src/app/app.module.ts                     |   2 +
 frontend/src/app/common/formly/formly-config.ts    |   2 +
 .../hugging-face-image-upload.component.html       |  51 ++
 .../hugging-face-image-upload.component.scss       |  60 +++
 .../hugging-face-image-upload.component.spec.ts    | 146 ++++++
 .../hugging-face-image-upload.component.ts         | 162 +++++++
 .../src/assets/operator_images/HuggingFace.png     | Bin 0 -> 13831 bytes
 frontend/src/assets/sample-image.png               | Bin 0 -> 101737 bytes
 13 files changed, 1350 insertions(+), 26 deletions(-)

diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDesc.scala b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDesc.scala
index 07466c898e..5f203717d1 100644
--- a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDesc.scala
+++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDesc.scala
@@ -26,6 +26,7 @@ import org.apache.texera.amber.core.workflow.{InputPort, OutputPort, PortIdentit
 import org.apache.texera.amber.operator.PythonOperatorDescriptor
 import org.apache.texera.amber.operator.huggingFace.codegen.{
   CodegenContext,
+  ImageTaskCodegen,
   PythonCodegenBase,
   TaskCodegen,
   TextGenCodegen
@@ -83,6 +84,17 @@ class HuggingFaceInferenceOpDesc extends PythonOperatorDescriptor {
   @AutofillAttributeName
   var promptColumn: EncodableString = ""
+ @JsonProperty(value = "imageInput", required = false) + @JsonSchemaTitle("Image Upload") + @JsonPropertyDescription("Upload an image for Hugging Face image tasks") + var imageInput: EncodableString = "" + + @JsonProperty(value = "inputImageColumn", required = false) + @JsonSchemaTitle("Input Image Column") + @JsonPropertyDescription("Column containing image data from the input table") + @AutofillAttributeName + var inputImageColumn: EncodableString = "" + @JsonProperty( value = "systemPrompt", required = false, @@ -122,8 +134,12 @@ class HuggingFaceInferenceOpDesc extends PythonOperatorDescriptor { * keeps `generatePythonCode` total (it never throws on arbitrary input, * which is required by `PythonCodeRawInvalidTextSpec`). */ - private val registeredCodegens: Map[String, TaskCodegen] = - Map(TextGenCodegen.task -> TextGenCodegen) + private val registeredCodegens: Map[String, TaskCodegen] = { + val byTask = scala.collection.mutable.Map.empty[String, TaskCodegen] + byTask += (TextGenCodegen.task -> TextGenCodegen) + ImageTaskCodegen.tasks.foreach(t => byTask += (t -> ImageTaskCodegen)) + byTask.toMap + } private def codegenForTask(t: String): TaskCodegen = registeredCodegens.getOrElse(t, TextGenCodegen) @@ -161,6 +177,11 @@ class HuggingFaceInferenceOpDesc extends PythonOperatorDescriptor { val safeTemp = math.max(0.0, math.min(if (temperature != null) temperature.doubleValue else 0.7, 2.0)) + val safeImageInput: EncodableString = + if (imageInput == null) "" else imageInput + val safeInputImageColumn: EncodableString = + if (inputImageColumn == null) "" else inputImageColumn + val ctx = CodegenContext( hfApiToken = safeToken, modelId = safeModelId, @@ -169,7 +190,9 @@ class HuggingFaceInferenceOpDesc extends PythonOperatorDescriptor { task = safeTask, systemPrompt = safeSystemPrompt, safeMaxTokens = safeMaxTokens, - safeTemp = safeTemp + safeTemp = safeTemp, + imageInput = safeImageInput, + inputImageColumn = safeInputImageColumn ) PythonCodegenBase.render(ctx, 
codegenForTask(safeTask)) diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/ImageTaskCodegen.scala b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/ImageTaskCodegen.scala new file mode 100644 index 0000000000..c5c4a2669c --- /dev/null +++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/ImageTaskCodegen.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.texera.amber.operator.huggingFace.codegen + +/** + * Codegen for the Hugging Face image-pipeline task family. + * + * Splits into two sub-families: + * - "image-only" tasks send raw image bytes as the request body and don't + * consume the prompt column: image-classification, object-detection, + * image-segmentation, image-to-text. + * - "image + prompt" tasks bundle a base64 image and a text prompt in a + * JSON payload: visual-question-answering, document-question-answering, + * zero-shot-image-classification, image-text-to-text, image-to-image. 
+ * + * Per-row `current_image_bytes` is resolved upstream in + * [[PythonCodegenBase]]'s `process_table` (either from the operator's + * uploaded image or from `INPUT_IMAGE_COLUMN`). The image helpers + * (`_read_image_input`, `_compress_image_bytes`, `_image_input_as_base64`, + * `_read_binary_value`, `_looks_like_html`, `_html_to_image_bytes`, + * `_extract_json_arg`) live in PythonCodegenBase alongside the per-task + * tuples (`image_only_tasks`, `image_prompt_tasks`, `image_tasks`). + */ +object ImageTaskCodegen extends TaskCodegen { + + /** Primary key for registration; the dispatcher maps every task in + * [[tasks]] to this codegen. + */ + override val task: String = "image-classification" + + /** All HF tasks routed through this codegen. */ + override val tasks: Set[String] = Set( + // image-only + "image-classification", + "object-detection", + "image-segmentation", + "image-to-text", + // image + prompt + "visual-question-answering", + "document-question-answering", + "zero-shot-image-classification", + "image-text-to-text", + "image-to-image" + ) + + override def payloadPython(ctx: CodegenContext): String = + """ if task in image_only_tasks: + | payload = current_image_bytes + | use_raw_binary_body = True + | raw_binary_headers = image_headers + | elif task in ("visual-question-answering", "document-question-answering"): + | payload = { + | "inputs": { + | "image": self._image_input_as_base64(current_image_bytes), + | "question": prompt_value, + | } + | } + | elif task == "image-text-to-text": + | img_b64 = self._image_input_as_base64(current_image_bytes) + | payload = { + | "model": self.MODEL_ID, + | "messages": [{ + | "role": "user", + | "content": [ + | {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}}, + | {"type": "text", "text": prompt_value if prompt_value else "Describe this image."}, + | ], + | }], + | "max_tokens": self.MAX_NEW_TOKENS, + | } + | elif task == "image-to-image": + | payload = current_image_bytes + | 
use_raw_binary_body = True + | raw_binary_headers = image_headers + | elif task == "zero-shot-image-classification": + | # Zero-shot requires the caller to supply candidate labels. + | # We reuse the prompt column as a comma-separated label list so + | # the task is shippable without a dedicated operator field. + | # TODO: replace with a first-class `candidateLabels` field once + | # the property panel supports task-specific inputs. + | # + | # Fail fast if usable labels can't be derived. Both modes lead to + | # a meaningless inference call: + | # 1. Empty prompt column -> labels = [] + | # The HF API rejects candidate_labels: [] with an opaque 400. + | # 2. Missing prompt column -> upstream sets prompt_value + | # to the fallback "What is shown in this image?", which has + | # no comma, so labels collapses to a single nonsense entry. + | # Zero-shot classification needs >= 2 candidate labels to be + | # meaningful — surface a configuration error in both cases. + | labels = [s.strip() for s in prompt_value.split(",") if s.strip()] + | if len(labels) < 2: + | raise ValueError( + | "zero-shot-image-classification requires at least 2 candidate " + | "labels: provide a comma-separated list in the prompt column." 
+ | ) + | payload = { + | "inputs": self._image_input_as_base64(current_image_bytes), + | "parameters": {"candidate_labels": labels}, + | } + | else: + | payload = {"inputs": prompt_value}""".stripMargin + + override def parsePython(ctx: CodegenContext): String = + """ if task == "image-to-text": + | if isinstance(body, dict): + | if "md_results" in body: + | return body["md_results"] + | if "choices" in body: + | return body["choices"][0]["message"]["content"] + | if isinstance(body, list) and body and isinstance(body[0], dict): + | return body[0].get("generated_text", json.dumps(body)) + | return json.dumps(body) + | elif task in ("visual-question-answering", "document-question-answering"): + | if isinstance(body, dict): + | return body.get("answer", json.dumps(body)) + | return json.dumps(body) + | elif task == "image-text-to-text": + | if isinstance(body, dict) and "choices" in body: + | return body["choices"][0]["message"]["content"] + | if isinstance(body, list) and body and isinstance(body[0], dict): + | return body[0].get("generated_text", json.dumps(body)) + | return json.dumps(body) + | elif task == "image-to-image": + | if isinstance(body, dict): + | if "output" in body: + | out = body["output"] + | url = out[0] if isinstance(out, list) else out + | if isinstance(url, str) and url.startswith("http"): + | return self._url_to_data_url(url) + | if "images" in body: + | images = body["images"] + | if images and isinstance(images[0], dict) and "url" in images[0]: + | return self._url_to_data_url(images[0]["url"]) + | if "data" in body: + | data = body["data"] + | if isinstance(data, dict) and "outputs" in data: + | outputs = data["outputs"] + | if outputs and isinstance(outputs[0], str) and outputs[0].startswith("http"): + | return self._url_to_data_url(outputs[0]) + | if isinstance(data, list) and data and isinstance(data[0], dict): + | if "b64_json" in data[0]: + | return f"data:image/png;base64,{data[0]['b64_json']}" + | if "url" in data[0]: + | return 
self._url_to_data_url(data[0]["url"]) + | return json.dumps(body) + | elif task in ("image-classification", "object-detection", "image-segmentation", "zero-shot-image-classification"): + | return json.dumps(body)""".stripMargin +} diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/PythonCodegenBase.scala b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/PythonCodegenBase.scala index 16c2cc9bbb..eac4641c62 100644 --- a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/PythonCodegenBase.scala +++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/PythonCodegenBase.scala @@ -55,9 +55,12 @@ object PythonCodegenBase { val systemPrompt = ctx.systemPrompt val maxNewTokens = ctx.safeMaxTokens val temperature = ctx.safeTemp + val imageInput = ctx.imageInput + val inputImageColumn = ctx.inputImageColumn pyb"""import os |import re |import json + |import base64 |import requests |import pandas as pd |from pytexera import * @@ -117,6 +120,9 @@ object PythonCodegenBase { | "ovhcloud", "publicai", "scaleway", "baseten", | ) | + | # Hard cap on bytes pulled from an external (user/response-provided) URL. + | MAX_REMOTE_FETCH_BYTES = 25 * 1024 * 1024 + | | def open(self): | # User-provided strings reach the operator via base64-encoded | # decode expressions so they cannot break Python syntax or @@ -129,6 +135,8 @@ object PythonCodegenBase { | self.SYSTEM_PROMPT = $systemPrompt | self.MAX_NEW_TOKENS = $maxNewTokens | self.TEMPERATURE = $temperature + | self.IMAGE_INPUT = $imageInput + | self.INPUT_IMAGE_COLUMN = $inputImageColumn | | def _resolve_providers(self, token): | '''Query the HF Hub API for inference providers serving this model. 
@@ -168,7 +176,7 @@ object PythonCodegenBase { | pass | return [{"name": "hf-inference", "providerId": self.MODEL_ID}] | - | def _post_with_fallback(self, providers, json_headers, pipeline_payload, prompt_value): + | def _post_with_fallback(self, providers, json_headers, raw_binary_headers, pipeline_payload, use_raw_binary_body, prompt_value): | '''Try providers in order, using the correct API route for each. | Returns (response, provider_summary). provider_summary is None on | success or a string describing what failed. @@ -179,16 +187,38 @@ object PythonCodegenBase { | for prov in providers: | provider_name = prov["name"] | provider_id = prov["providerId"] + | is_model_author = prov.get("isModelAuthor", False) + | prov_task = prov.get("task", "") | try: - | if self.TASK == "text-generation": + | if self.TASK in ("text-generation", "image-text-to-text"): | route = self.CHAT_ROUTES.get(provider_name, "v1/chat/completions") | url = f"https://router.huggingface.co/{provider_name}/{route}" | resp = requests.post(url, headers=json_headers, json=pipeline_payload, timeout=120) + | elif is_model_author and prov_task in ("image-to-text", "image-text-to-text") and provider_name not in ("zai-org",): + | url = f"https://router.huggingface.co/{provider_name}/v1/chat/completions" + | img_b64 = "" + | if use_raw_binary_body and isinstance(pipeline_payload, bytes): + | img_b64 = base64.b64encode(pipeline_payload).decode("utf-8") + | chat_payload = { + | "model": provider_id, + | "messages": [{ + | "role": "user", + | "content": [ + | {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}} if img_b64 else None, + | {"type": "text", "text": prompt_value if prompt_value else "What is in this image?"}, + | ], + | }], + | } + | chat_payload["messages"][0]["content"] = [c for c in chat_payload["messages"][0]["content"] if c is not None] + | resp = requests.post(url, headers=json_headers, json=chat_payload, timeout=120) | elif provider_name == "hf-inference": | 
url = f"https://router.huggingface.co/hf-inference/models/{self.MODEL_ID}" - | resp = requests.post(url, headers=json_headers, json=pipeline_payload, timeout=120) + | if use_raw_binary_body: + | resp = requests.post(url, headers=raw_binary_headers, data=pipeline_payload, timeout=120) + | else: + | resp = requests.post(url, headers=json_headers, json=pipeline_payload, timeout=120) | else: - | resp = self._call_provider(provider_name, provider_id, json_headers, pipeline_payload, prompt_value) + | resp = self._call_provider(provider_name, provider_id, json_headers, raw_binary_headers, pipeline_payload, use_raw_binary_body, prompt_value) | except Exception as e: | errors.append(f"{provider_name}: {type(e).__name__}") | continue @@ -207,18 +237,174 @@ object PythonCodegenBase { | summary = "; ".join(errors) if errors else "no providers available" | return last_resp, summary | - | def _call_provider(self, provider_name, provider_id, json_headers, pipeline_payload, prompt_value): + | def _call_provider(self, provider_name, provider_id, json_headers, raw_binary_headers, pipeline_payload, use_raw_binary_body, prompt_value): | '''Route to a third-party provider using its native API format. - | For the text-gen-only build this covers the OpenAI-compatible chat - | providers and an unknown-provider fallback that tries the pipeline - | format then chat completions. Image / audio / media routing will - | be added in subsequent PRs alongside the corresponding task - | codegens. + | Handles OpenAI-compatible chat providers for text-gen, zai-org's + | custom API, Replicate / Fal-ai / Wavespeed for media-generation + | and image-to-image, and an unknown-provider fallback that tries + | the pipeline format then chat completions. 
| ''' | base = f"https://router.huggingface.co/{provider_name}" + | task = self.TASK + | img_b64 = "" + | if use_raw_binary_body and isinstance(pipeline_payload, bytes): + | img_b64 = base64.b64encode(pipeline_payload).decode("utf-8") + | elif isinstance(pipeline_payload, dict): + | # Image+prompt tasks (visual-question-answering, document-question- + | # answering, zero-shot-image-classification) build dict payloads + | # with use_raw_binary_body=False, so the raw-bytes extraction above + | # doesn't fire. Without this branch, when one of those tasks routes + | # to a third-party provider (replicate / fal-ai / wavespeed / + | # OpenAI-compatible / unknown-fallback) the image is silently + | # dropped and only prompt_value is sent — they happen to work only + | # on hf-inference, where the dict goes through as JSON. Surfacing + | # img_b64 here keeps the provider-specific branches below image- + | # aware without each branch needing to know the dict shape. + | inputs = pipeline_payload.get("inputs") + | if isinstance(inputs, dict) and isinstance(inputs.get("image"), str): + | img_b64 = inputs["image"] + | elif task == "zero-shot-image-classification" and isinstance(inputs, str): + | img_b64 = inputs + | + | # zai-org: custom /api/paas/v4/ surface. 
+ | if provider_name == "zai-org": + | zai_headers = {**json_headers, "x-source-channel": "hugging_face", "accept-language": "en-US,en"} + | if task in ("image-to-text", "image-text-to-text"): + | url = f"{base}/api/paas/v4/layout_parsing" + | file_data = f"data:image/png;base64,{img_b64}" if img_b64 else "" + | return requests.post(url, headers=zai_headers, json={"model": provider_id, "file": file_data}, timeout=120) + | url = f"{base}/api/paas/v4/chat/completions" + | messages = [{"role": "user", "content": prompt_value}] + | if img_b64: + | messages = [{"role": "user", "content": [ + | {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}}, + | {"type": "text", "text": prompt_value if prompt_value else "What is in this image?"}, + | ]}] + | return requests.post(url, headers=zai_headers, json={"model": provider_id, "messages": messages}, timeout=120) + | + | # Replicate: synchronous predictions endpoint with polling fallback. + | if provider_name == "replicate": + | url = f"{base}/v1/models/{provider_id}/predictions" + | hdrs = {**json_headers, "Prefer": "wait"} + | if task == "image-to-image" and img_b64: + | data_url = f"data:image/png;base64,{img_b64}" + | inp = {"image": data_url, "images": [data_url], "input_image": data_url, "prompt": prompt_value} + | elif img_b64: + | inp = {"image": f"data:image/png;base64,{img_b64}", "prompt": prompt_value} + | else: + | inp = {"prompt": prompt_value} + | resp = requests.post(url, headers=hdrs, json={"input": inp}, timeout=120) + | if resp.status_code == 202: + | import time as _time + | pred = resp.json() + | poll_url = pred.get("urls", {}).get("get", "") + | if not poll_url: + | return resp + | from urllib.parse import urlparse as _urlparse + | poll_path = _urlparse(poll_url).path + | poll_url = f"{base}{poll_path}" + | # Worst case: 300 polls × 2s = ~10 minutes per row before we give + | # up. Sized for text-to-video which legitimately takes minutes on + | # Replicate. 
process_table is synchronous, so emit a progress + | # line every 30 polls (~1 min) to distinguish slow work from a + | # hang in the worker log. + | for poll_idx in range(300): + | _time.sleep(2) + | poll_resp = requests.get(poll_url, headers=json_headers, timeout=30) + | if poll_resp.status_code != 200: + | continue + | status = poll_resp.json().get("status", "") + | if status == "succeeded": + | return poll_resp + | if status in ("failed", "canceled"): + | # The polling HTTP request itself returned 200, but the + | # Replicate prediction terminally failed. Without this + | # branch, process_table would treat the 200 as success + | # and emit json.dumps(body) (raw error JSON) into the + | # output cell. Convert to a synthetic 502 so + | # _post_with_fallback's non-200 handler surfaces the + | # actual failure detail via _format_error. + | body_json = poll_resp.json() if poll_resp.text else {} + | detail = (body_json.get("error") or body_json.get("logs") or status) \ + | if isinstance(body_json, dict) else status + | poll_resp.status_code = 502 + | poll_resp._content = json.dumps({ + | "error": f"Replicate prediction {status}: {detail}" + | }).encode("utf-8") + | return poll_resp + | if (poll_idx + 1) % 30 == 0: + | print(f"[hf] Replicate still running for model '{self.MODEL_ID}' after {(poll_idx + 1) * 2}s; will wait up to 600s.") + | return poll_resp + | return resp + | + | # Fal-ai: per-model endpoint. 
+ | if provider_name == "fal-ai": + | url = f"{base}/{provider_id}" + | if task == "image-to-image" and img_b64: + | data_url = f"data:image/png;base64,{img_b64}" + | return requests.post(url, headers=json_headers, json={"image_url": data_url, "image_urls": [data_url], "prompt": prompt_value}, timeout=120) + | if img_b64: + | return requests.post(url, headers=json_headers, json={"image_url": f"data:image/png;base64,{img_b64}", "prompt": prompt_value}, timeout=120) + | return requests.post(url, headers=json_headers, json={"prompt": prompt_value}, timeout=120) + | + | # Wavespeed: async submit + poll. + | if provider_name == "wavespeed": + | url = f"{base}/api/v3/{provider_id}" + | payload = {"prompt": prompt_value} + | if img_b64: + | payload["image"] = img_b64 + | payload["images"] = [img_b64] + | submit_resp = requests.post(url, headers=json_headers, json=payload, timeout=120) + | if submit_resp.status_code not in (200, 201): + | return submit_resp + | get_path = submit_resp.json().get("data", {}).get("urls", {}).get("get", "") + | if not get_path: + | return submit_resp + | from urllib.parse import urlparse as _urlparse + | result_url = f"{base}{_urlparse(get_path).path}" + | import time as _time + | poll_resp = submit_resp + | # Worst case: 120 polls × 1s = ~2 minutes per row. Emit a progress + | # line every 30 polls (~30 s) so the worker log distinguishes slow + | # work from a hang. + | for poll_idx in range(120): + | _time.sleep(1) + | poll_resp = requests.get(result_url, headers=json_headers, timeout=30) + | if poll_resp.status_code != 200: + | continue + | body_json = poll_resp.json() if poll_resp.text else {} + | data_obj = body_json.get("data", {}) if isinstance(body_json, dict) else {} + | status = data_obj.get("status", "") if isinstance(data_obj, dict) else "" + | if status == "completed": + | return poll_resp + | if status == "failed": + | # Same shape as Replicate: HTTP 200 + body says "failed". 
+ | # Synthesize a 502 so _post_with_fallback's non-200 handler + | # reports the actual reason instead of process_table + | # parsing the success-shaped body and writing raw error + | # JSON into the result cell. + | detail = ( + | (data_obj.get("error") if isinstance(data_obj, dict) else None) + | or (body_json.get("error") if isinstance(body_json, dict) else None) + | or "failed" + | ) + | poll_resp.status_code = 502 + | poll_resp._content = json.dumps({ + | "error": f"Wavespeed job failed: {detail}" + | }).encode("utf-8") + | return poll_resp + | if (poll_idx + 1) % 30 == 0: + | print(f"[hf] Wavespeed still running for model '{self.MODEL_ID}' after {poll_idx + 1}s; will wait up to 120s.") + | return poll_resp + | | if provider_name in self.OPENAI_COMPATIBLE_PROVIDERS: | url = f"{base}/{self.CHAT_ROUTES.get(provider_name, 'v1/chat/completions')}" | messages = [{"role": "user", "content": prompt_value}] + | if img_b64: + | messages = [{"role": "user", "content": [ + | {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}}, + | {"type": "text", "text": prompt_value if prompt_value else "What is in this image?"}, + | ]}] | return requests.post( | url, | headers=json_headers, @@ -228,10 +414,18 @@ object PythonCodegenBase { | | # Unknown provider: try pipeline format, then chat completions. 
| url = f"{base}/{provider_id}" - | resp = requests.post(url, headers=json_headers, json=pipeline_payload, timeout=120) + | if use_raw_binary_body: + | resp = requests.post(url, headers=raw_binary_headers, data=pipeline_payload, timeout=120) + | else: + | resp = requests.post(url, headers=json_headers, json=pipeline_payload, timeout=120) | if resp.status_code in (400, 404, 422): | url = f"{base}/v1/chat/completions" | messages = [{"role": "user", "content": prompt_value}] + | if img_b64: + | messages = [{"role": "user", "content": [ + | {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}}, + | {"type": "text", "text": prompt_value if prompt_value else "Describe this image."}, + | ]}] | resp2 = requests.post( | url, | headers=json_headers, @@ -247,6 +441,9 @@ object PythonCodegenBase { | prompt_col = self.PROMPT_COLUMN | result_col = self.RESULT_COLUMN | task = self.TASK + | image_only_tasks = ("image-classification", "object-detection", "image-segmentation", "image-to-text") + | image_prompt_tasks = ("visual-question-answering", "document-question-answering", "zero-shot-image-classification", "image-text-to-text", "image-to-image") + | image_tasks = image_only_tasks + image_prompt_tasks | | # --- validate MODEL_ID format before any HF URL is built --- | if not _HF_MODEL_ID_PATTERN.match(self.MODEL_ID or ""): @@ -266,11 +463,12 @@ object PythonCodegenBase { | # --- resolve all available inference providers for this model (tried in order) --- | providers = self._resolve_providers(token) | - | # --- validate prompt column exists --- - | assert prompt_col in table.columns, ( - | f"Prompt column '{prompt_col}' not found in input table. " - | f"Available columns: {list(table.columns)}" - | ) + | # --- validate prompt column exists (required for non-image tasks) --- + | if task not in image_tasks: + | assert prompt_col in table.columns, ( + | f"Prompt column '{prompt_col}' not found in input table. 
" + | f"Available columns: {list(table.columns)}" + | ) | | # --- handle empty table --- | if table.empty: @@ -282,21 +480,63 @@ object PythonCodegenBase { | "Authorization": f"Bearer {token}", | "Content-Type": "application/json", | } + | image_headers = { + | "Authorization": f"Bearer {token}", + | "Content-Type": "application/octet-stream", + | } + | + | # --- resolve image source (upload or column) for image tasks --- + | has_image_upload = bool(self.IMAGE_INPUT) and bool(str(self.IMAGE_INPUT).strip()) + | use_image_column = not has_image_upload and bool(self.INPUT_IMAGE_COLUMN) and self.INPUT_IMAGE_COLUMN in table.columns + | image_bytes = None + | image_error = None + | if task in image_tasks and not use_image_column: + | if not has_image_upload: + | image_error = "No image source. Set an Input Image Column or upload an image." + | else: + | try: + | image_bytes = self._read_image_input() + | except Exception as e: + | image_error = f"Could not read image input ({type(e).__name__}: {e})" | | results = [] | for idx, row in table.iterrows(): - | prompt_value = row[prompt_col] - | if pd.isna(prompt_value): + | if image_error is not None: + | results.append(self._format_error("Image task configuration error", image_error)) + | continue + | + | if task in image_only_tasks: | prompt_value = "" + | elif task in image_prompt_tasks and prompt_col not in table.columns: + | prompt_value = "What is shown in this image?" 
| else: - | prompt_value = str(prompt_value) + | prompt_value = row[prompt_col] + | if pd.isna(prompt_value): + | prompt_value = "" + | else: + | prompt_value = str(prompt_value) + | + | # --- resolve per-row image bytes from column --- + | current_image_bytes = image_bytes + | if task in image_tasks and use_image_column: + | try: + | raw = self._read_binary_value(row[self.INPUT_IMAGE_COLUMN]) + | if raw is None: + | results.append(self._format_error("Image data error", f"Row {idx}: image column is empty")) + | continue + | current_image_bytes = self._compress_image_bytes(raw) + | except Exception as e: + | results.append(self._format_error("Image data error", f"Row {idx}: {type(e).__name__}: {e}")) + | continue | | # --- build task-specific payload (provided by per-task codegen) --- + | use_raw_binary_body = False + | raw_binary_headers = image_headers |${payload} | | try: | resp, provider_summary = self._post_with_fallback( - | providers, json_headers, payload, prompt_value + | providers, json_headers, raw_binary_headers, payload, use_raw_binary_body, prompt_value | ) | | if resp is None: @@ -331,6 +571,12 @@ object PythonCodegenBase { | ) | continue | + | content_type = resp.headers.get("Content-Type", "") + | if content_type.startswith("image/"): + | b64 = base64.b64encode(resp.content).decode("utf-8") + | results.append(f"data:{content_type};base64,{b64}") + | continue + | | try: | body = resp.json() | except ValueError: @@ -361,6 +607,240 @@ object PythonCodegenBase { | detail = "<empty response>" | return f"{title} [status={status_code}] response={detail}" | + | # ────────────────────────────────────────────────────────────────── + | # Image-task helpers (used by ImageTaskCodegen and image-related + | # branches of _call_provider). + | # ────────────────────────────────────────────────────────────────── + | + | def _fetch_remote_url(self, url): + | '''Fetch an external URL with SSRF hardening. Returns (content_type, data). 
+ | Enforces https-only, rejects private/loopback/link-local/reserved + | addresses (covers the 169.254.169.254 cloud-metadata endpoint), and + | caps the response at MAX_REMOTE_FETCH_BYTES. The address check runs + | before the request, so it mitigates but does not fully prevent DNS + | rebinding (requests re-resolves on connect). + | ''' + | import ipaddress + | import socket + | from urllib.parse import urlparse as _urlparse + | parsed = _urlparse(url) + | if parsed.scheme != "https": + | raise ValueError(f"Only https URLs are allowed (got scheme '{parsed.scheme}').") + | host = parsed.hostname + | if not host: + | raise ValueError("Remote URL has no host.") + | try: + | addrinfos = socket.getaddrinfo(host, parsed.port or 443, proto=socket.IPPROTO_TCP) + | except socket.gaierror as e: + | raise ValueError(f"Could not resolve host '{host}': {e}") + | for info in addrinfos: + | ip = ipaddress.ip_address(info[4][0]) + | if (ip.is_private or ip.is_loopback or ip.is_link_local + | or ip.is_reserved or ip.is_multicast or ip.is_unspecified): + | raise ValueError(f"Refusing to fetch from non-public address {ip}.") + | resp = requests.get(url, timeout=120, stream=True) + | resp.raise_for_status() + | content_type = resp.headers.get("Content-Type", "") + | total = 0 + | chunks = [] + | for chunk in resp.iter_content(65536): + | total += len(chunk) + | if total > self.MAX_REMOTE_FETCH_BYTES: + | resp.close() + | raise ValueError( + | f"Remote file exceeds the {self.MAX_REMOTE_FETCH_BYTES} byte limit." 
+ | ) + | chunks.append(chunk) + | return content_type, b"".join(chunks) + | + | def _read_image_input(self): + | image_input = str(self.IMAGE_INPUT or "").strip() + | if image_input.startswith("data:"): + | _, encoded = image_input.split(",", 1) + | return base64.b64decode(encoded) + | if image_input.startswith("http://") or image_input.startswith("https://"): + | _, data = self._fetch_remote_url(image_input) + | return data + | # Reading arbitrary worker-filesystem paths is intentionally NOT + | # supported: a workflow could otherwise point this at any file on the + | # worker (e.g. /etc/passwd) and exfiltrate it via the inference call. + | # Uploaded images arrive as data URLs; remote images as https URLs. + | raise ValueError( + | "Unsupported image input. Upload an image (sent as a data URL) " + | "or provide a public https image URL." + | ) + | + | def _compress_image_bytes(self, image_bytes, max_bytes=33000): + | from io import BytesIO + | from PIL import Image as PILImage + | if len(image_bytes) <= max_bytes: + | return image_bytes + | try: + | img = PILImage.open(BytesIO(image_bytes)) + | img = img.convert("RGB") + | max_dim = 512 + | quality = 75 + | while max_dim >= 160: + | scale = min(1, max_dim / max(img.width, img.height)) + | w = max(1, round(img.width * scale)) + | h = max(1, round(img.height * scale)) + | resized = img.resize((w, h), PILImage.LANCZOS) + | q = quality + | while q >= 35: + | buf = BytesIO() + | resized.save(buf, format="JPEG", quality=q) + | if buf.tell() <= max_bytes: + | return buf.getvalue() + | q -= 10 + | max_dim = int(max_dim * 0.75) + | buf = BytesIO() + | resized.save(buf, format="JPEG", quality=35) + | return buf.getvalue() + | except Exception: + | return image_bytes + | + | def _image_input_as_base64(self, image_bytes): + | return base64.b64encode(image_bytes).decode("utf-8") + | + | def _read_binary_value(self, value): + | if value is None: + | return None + | if isinstance(value, bytes): + | return value + | # Treat 
scalar pandas/numpy missing sentinels (NaN, pd.NA, NaT) as empty. + | # isinstance(value, float) only catches float('nan'); pd.NA / NaT are not + | # floats and would otherwise be str()-ified into "<NA>"/"NaT" bytes. Guard + | # pd.isna against non-scalar inputs, where it returns an array and `if` + | # raises on an ambiguous truth value. + | try: + | if pd.isna(value): + | return None + | except (TypeError, ValueError): + | pass + | val = str(value).strip() + | if not val: + | return None + | if self._looks_like_html(val): + | return self._html_to_image_bytes(val) + | if val.startswith("data:"): + | _, encoded = val.split(",", 1) + | return base64.b64decode(encoded) + | if val.startswith("http://") or val.startswith("https://"): + | _, data = self._fetch_remote_url(val) + | return data + | # No worker-filesystem path reads here either (see _read_image_input): + | # a column value must be a data URL, http(s) URL, rendered HTML, or + | # base64-encoded bytes. Anything else is treated as raw bytes, never + | # as a path to open. + | try: + | return base64.b64decode(val) + | except Exception: + | return val.encode("utf-8") + | + | def _looks_like_html(self, val): + | s = val.lstrip()[:200].lower() + | if s.startswith("<!doctype html") or s.startswith("<html"): + | return True + | if "plotly.newplot" in val[:5000].lower() or "plotly.react" in val[:5000].lower(): + | return True + | if "<img" in s and "base64," in s: + | return True + | return False + | + | def _html_to_image_bytes(self, html_string): + | match = re.search(r"data:image/[^;]+;base64,([A-Za-z0-9+/\n\r =]+)", html_string) + | if match: + | b64 = match.group(1).replace("\n", "").replace("\r", "").replace(" ", "") + | return base64.b64decode(b64) + | if "Plotly." 
in html_string: + | try: + | import plotly.graph_objects as go + | import plotly.io as pio + | plotly_match = re.search(r"Plotly\.(?:newPlot|react)\s*\(\s*", html_string) + | if plotly_match: + | pos = plotly_match.end() + | if pos < len(html_string) and html_string[pos] in ('"', "'"): + | q = html_string[pos] + | pos += 1 + | while pos < len(html_string) and html_string[pos] != q: + | if html_string[pos] == "\\": + | pos += 1 + | pos += 1 + | pos += 1 + | while pos < len(html_string) and html_string[pos] in " ,\n\r\t": + | pos += 1 + | data_json, pos = self._extract_json_arg(html_string, pos) + | while pos < len(html_string) and html_string[pos] in " ,\n\r\t": + | pos += 1 + | layout_json, _ = self._extract_json_arg(html_string, pos) + | if data_json: + | data = json.loads(data_json) + | layout = json.loads(layout_json) if layout_json else {} + | fig = go.Figure(data=data, layout=layout) + | return pio.to_image(fig, format="png", width=800, height=600) + | except ImportError as ie: + | raise ValueError( + | f"Plotly chart detected but cannot render to image: {ie}. " + | f"Install kaleido: pip install kaleido" + | ) + | except json.JSONDecodeError: + | pass + | raise ValueError( + | "Cannot convert HTML to image. The HTML does not contain " + | "an extractable base64 image or a parseable Plotly chart." 
+ | ) + | + | def _extract_json_arg(self, text, start_pos): + | if start_pos >= len(text): + | return None, start_pos + | ch = text[start_pos] + | openers = {"[": "]", "{": "}"} + | if ch not in openers: + | return None, start_pos + | closer = openers[ch] + | depth = 1 + | pos = start_pos + 1 + | in_string = False + | while pos < len(text) and depth > 0: + | c = text[pos] + | if in_string: + | if c == "\\": + | pos += 2 + | continue + | if c == '"': + | in_string = False + | else: + | if c == '"': + | in_string = True + | elif c == ch: + | depth += 1 + | elif c == closer: + | depth -= 1 + | pos += 1 + | if depth == 0: + | return text[start_pos:pos], pos + | return None, start_pos + | + | def _url_to_data_url(self, url): + | '''Fetch a URL and return a data URL with the correct MIME type. + | Fetched via _fetch_remote_url so a malicious/compromised provider + | cannot redirect this to an internal address or oversized payload. + | ''' + | raw_content_type, data = self._fetch_remote_url(url) + | content_type = raw_content_type.split(";")[0].strip() + | if not content_type or content_type == "application/octet-stream": + | from urllib.parse import urlparse as _urlparse + | ext = os.path.splitext(_urlparse(url).path.lower())[1] + | mime_map = {".png": "image/png", ".jpg": "image/jpeg", ".jpeg": "image/jpeg", ".gif": "image/gif", ".webp": "image/webp", ".svg": "image/svg+xml", ".mp4": "video/mp4", ".webm": "video/webm"} + | guessed = mime_map.get(ext, "") + | if guessed: + | content_type = guessed + | else: + | task_mime = {"image-to-image": "image/png"} + | content_type = task_mime.get(self.TASK, "application/octet-stream") + | b64 = base64.b64encode(data).decode("utf-8") + | return f"data:{content_type};base64,{b64}" + | | def _parse_response(self, body): | task = self.TASK | try: diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/TaskCodegen.scala 
b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/TaskCodegen.scala index 333d1a038c..299ea5d6e3 100644 --- a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/TaskCodegen.scala +++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/TaskCodegen.scala @@ -37,7 +37,9 @@ final case class CodegenContext( task: EncodableString, systemPrompt: EncodableString, safeMaxTokens: Int, - safeTemp: Double + safeTemp: Double, + imageInput: EncodableString = "", + inputImageColumn: EncodableString = "" ) /** @@ -59,9 +61,19 @@ final case class CodegenContext( */ trait TaskCodegen { - /** Canonical Hugging Face pipeline task string, e.g. "text-generation". */ + /** Canonical Hugging Face pipeline task string used as the primary key for + * registration, e.g. "text-generation". Codegens that handle multiple + * task strings (image, audio, …) override [[tasks]] to enumerate all of + * them — the operator's dispatcher registers an entry per task. + */ def task: String + /** All Hugging Face pipeline task strings handled by this codegen. + * Defaults to the singleton `Set(task)` for codegens that handle one + * task; multi-task codegens override this. + */ + def tasks: Set[String] = Set(task) + /** Python text that assigns `payload = …` for one row inside * `process_table`'s per-row loop. The snippet supplies its own leading * `if`/`elif task == "...":` opener and any `else` fallback. 
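Note on the `tasks` default added above: the commit message says the operator registers `ImageTaskCodegen` "via the new `tasks` flat-map", but the dispatcher itself is not part of this section of the diff. The following is only a minimal sketch of that idea, written in Python (the language of the generated script) rather than Scala. `CodegenSketch`, `extra_tasks`, and `build_registry` are hypothetical names used for illustration, and the duplicate-key check is an assumption, since the source does not say how collisions are handled.

```python
from dataclasses import dataclass
from typing import Dict, FrozenSet, Iterable


@dataclass(frozen=True)
class CodegenSketch:
    """Hypothetical stand-in for a TaskCodegen; only task/tasks mirror the diff."""

    task: str
    # Assumption: extra task strings handled besides the primary one.
    extra_tasks: FrozenSet[str] = frozenset()

    @property
    def tasks(self) -> FrozenSet[str]:
        # Mirrors the trait default: a single-task codegen handles just {task}.
        return frozenset({self.task}) | self.extra_tasks


def build_registry(codegens: Iterable[CodegenSketch]) -> Dict[str, CodegenSketch]:
    """Flat-map every codegen over its task strings into one lookup table."""
    registry: Dict[str, CodegenSketch] = {}
    for codegen in codegens:
        for task in codegen.tasks:
            # Assumption: two codegens claiming the same task is a bug.
            if task in registry:
                raise ValueError(f"duplicate codegen for task '{task}'")
            registry[task] = codegen
    return registry


# One single-task codegen and one multi-task codegen.
text_gen = CodegenSketch("text-generation")
image = CodegenSketch(
    "image-classification",
    frozenset({"object-detection", "image-to-image"}),
)
registry = build_registry([text_gen, image])
```

With this shape, a codegen that does not override `tasks` still registers under exactly one key, so existing single-task codegens need no change when a multi-task one is added.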
diff --git a/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDescSpec.scala b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDescSpec.scala index 06424df604..0d6e09302f 100644 --- a/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDescSpec.scala +++ b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDescSpec.scala @@ -37,7 +37,9 @@ class HuggingFaceInferenceOpDescSpec extends AnyFlatSpec with Matchers { systemPrompt: EncodableString = "You are a helpful assistant.", maxNewTokens: Int = 256, temperature: Double = 0.7, - resultColumn: EncodableString = "hf_response" + resultColumn: EncodableString = "hf_response", + imageInput: EncodableString = "", + inputImageColumn: EncodableString = "" ): HuggingFaceInferenceOpDesc = { val desc = new HuggingFaceInferenceOpDesc() desc.hfApiToken = token @@ -48,6 +50,8 @@ class HuggingFaceInferenceOpDescSpec extends AnyFlatSpec with Matchers { desc.maxNewTokens = maxNewTokens desc.temperature = temperature desc.resultColumn = resultColumn + desc.imageInput = imageInput + desc.inputImageColumn = inputImageColumn desc } @@ -146,6 +150,8 @@ class HuggingFaceInferenceOpDescSpec extends AnyFlatSpec with Matchers { desc.task = null desc.maxNewTokens = null desc.temperature = null + desc.imageInput = null + desc.inputImageColumn = null val code = desc.generatePythonCode() code should include("class ProcessTableOperator(UDFTableOperator):") code should include("def open(self):") @@ -182,6 +188,220 @@ class HuggingFaceInferenceOpDescSpec extends AnyFlatSpec with Matchers { TextGenCodegen.parsePython(ctx) should include("""body["choices"][0]["message"]["content"]""") } + "image task family" should + "route image-only tasks through ImageTaskCodegen (raw binary payload + image headers)" in { + val code = + 
makeDesc(task = "image-classification", inputImageColumn = "img").generatePythonCode() + code should include("self.IMAGE_INPUT = ") + code should include("self.INPUT_IMAGE_COLUMN = ") + code should include("if task in image_only_tasks:") + code should include("payload = current_image_bytes") + code should include("use_raw_binary_body = True") + code should include("raw_binary_headers = image_headers") + // image bytes resolution + image content-type response handling exist + code should include("self._read_image_input()") + code should include("self._read_binary_value") + code should include("self._compress_image_bytes") + code should include("""if content_type.startswith("image/"):""") + } + + it should + "not read arbitrary worker-filesystem paths for image inputs (SSRF/LFI hardening)" in { + // Opening an arbitrary path from the worker filesystem would let a workflow + // exfiltrate any file (e.g. /etc/passwd) via the inference call. Image inputs + // must be data URLs, http(s) URLs, rendered HTML, or raw/base64 bytes only — + // never a path passed to open(). + val code = makeDesc(task = "image-classification", inputImageColumn = "img") + .generatePythonCode() + // The removed filesystem-read branches must not reappear. + code should not include "open(image_input" + code should not include "os.path.isfile(image_input)" + code should not include "os.path.exists(image_input)" + code should not include "if os.path.exists(val) and os.path.isfile(val):" + // Unsupported image inputs are rejected with a clear error instead. 
+ code should include("Unsupported image input") + } + + it should "route VQA / document-QA through ImageTaskCodegen (base64 image + question payload)" in { + val code = makeDesc(task = "visual-question-answering").generatePythonCode() + code should include( + """elif task in ("visual-question-answering", "document-question-answering"):""" + ) + code should include("self._image_input_as_base64(current_image_bytes)") + code should include(""""question": prompt_value""") + } + + it should + "emit single-backslash regex/whitespace escapes in the HTML->image helpers" in { + // The HTML->image helpers came from the original monolith where, inside a + // raw triple-quoted Scala string, "\\n"/"\\." emit DOUBLE backslashes to + // Python. That makes the base64 char class match a literal backslash+n + // instead of a newline, and makes the Plotly detection regex require a + // literal backslash before "newPlot" (so it never matches). The generated + // Python must contain single-backslash forms. + val code = makeDesc(task = "image-to-text", inputImageColumn = "img").generatePythonCode() + + // base64 char class allows real newlines/CR; strip uses real newline chars. + code should include("""[A-Za-z0-9+/\n\r =]""") + code should include(""".replace("\n", "").replace("\r", "")""") + // Plotly detection regex uses real regex escapes. + code should include("""r"Plotly\.(?:newPlot|react)\s*\(\s*"""") + // whitespace-skip set contains real whitespace chars. + code should include("""in " ,\n\r\t"""") + + // The broken double-backslash forms must NOT reappear. 
+ code should not include """[A-Za-z0-9+/\\n\\r =]""" + code should not include """Plotly\\.(?:newPlot|react)""" + code should not include """in " ,\\n\\r\\t"""" + } + + it should "harden remote URL fetches against SSRF (https-only, private-IP block, size cap)" in { + // Remote image/result URLs (user-provided or returned by a third-party + // provider) are fetched through _fetch_remote_url, which enforces https, + // rejects private/loopback/link-local/reserved/metadata addresses, and + // caps the response size. + val code = makeDesc(task = "image-to-image", inputImageColumn = "img").generatePythonCode() + code should include("def _fetch_remote_url(self, url):") + // https-only + code should include("""if parsed.scheme != "https":""") + // private / metadata IP blocking (169.254.169.254 is link-local) + code should include("ip.is_private") + code should include("ip.is_loopback") + code should include("ip.is_link_local") + code should include("Refusing to fetch from non-public address") + // size cap + code should include("MAX_REMOTE_FETCH_BYTES") + code should include("Remote file exceeds the") + // all three fetch sites route through the helper (no raw requests.get on these URLs) + code should include("_, data = self._fetch_remote_url(image_input)") + code should include("_, data = self._fetch_remote_url(val)") + code should include("raw_content_type, data = self._fetch_remote_url(url)") + } + + it should "treat pandas NA sentinels (NaN, pd.NA, NaT) as missing in _read_binary_value" in { + // isinstance(value, float) only catches float('nan'); pd.NA / NaT are not + // floats and previously fell through to be str()-ified into bytes. The + // guarded pd.isna check now catches all scalar NA sentinels. + val code = makeDesc(task = "image-classification", inputImageColumn = "img") + .generatePythonCode() + code should include("if pd.isna(value):") + code should include("except (TypeError, ValueError):") + // The old float-only guard must be gone. 
+ code should not include "isinstance(value, float) and pd.isna(value)" + } + + it should "not import the unused top-level urlparse in the generated script" in { + val code = makeDesc().generatePythonCode() + code should not include "from urllib.parse import urlparse\n" + // The local aliased import is still used where needed. + code should include("from urllib.parse import urlparse as _urlparse") + } + + it should + "convert Replicate terminal failed/canceled status into a synthetic 502 with surfaced error detail" in { + // Replicate's polling endpoint returns HTTP 200 even when the prediction + // itself terminally failed. Without this fix, + // _post_with_fallback sees status 200 and process_table parses the + // success-shape, silently emitting json.dumps(body) (raw error JSON) + // into the result column instead of a readable error. We synthesize a + // 502 with a top-level `error` field so the upstream non-200 path + // surfaces the actual reason via _format_error. + val code = makeDesc(task = "image-to-image").generatePythonCode() + code should include("""if status == "succeeded":""") + code should include("""if status in ("failed", "canceled"):""") + code should include("Replicate prediction") + code should include("poll_resp.status_code = 502") + code should include("""body_json.get("error")""") + } + + it should + "convert Wavespeed terminal failed status into a synthetic 502 with surfaced error detail" in { + // Same fix as Replicate, applied to Wavespeed's poll loop where the + // pattern was `status in ("completed", "failed")` collapsing both + // terminal states into a single `return poll_resp`. We now route + // "failed" through the synthetic-502 path so the error reaches the + // user instead of being parsed as a successful body. 
+ val code = makeDesc(task = "image-to-image").generatePythonCode() + code should include("""if status == "completed":""") + code should include("""if status == "failed":""") + code should include("Wavespeed job failed") + } + + it should + "fail fast at runtime when zero-shot-image-classification has fewer than 2 candidate labels" in { + // Without a dedicated candidateLabels field (lands in PR 5), zero-shot + // reuses prompt_value as a comma- + // separated list. Two failure modes the bare list comprehension hides + // are both caught by the >= 2 check: + // 1. Empty prompt column → labels = [] → HF API rejects + // candidate_labels: [] with an opaque 400. + // 2. Missing prompt column → upstream falls back to "What is shown in + // this image?" (no comma) → labels = ["What is shown in this image?"], + // a single nonsense label that returns a useless 1.0 score. + // Zero-shot classification needs >= 2 candidate labels to be meaningful, + // so the fix raises ValueError before the request goes out and the user + // sees a clear configuration error instead of a generic HTTP failure or + // misleading 100%-confidence garbage. + val code = makeDesc(task = "zero-shot-image-classification").generatePythonCode() + code should include("if len(labels) < 2:") + code should include("raise ValueError(") + code should include("at least 2 candidate") + } + + it should + "extract base64 image from image+prompt dict payloads in _call_provider so third-party providers receive it" in { + // Regression test: visual-question-answering, + // document-question-answering, and zero-shot-image-classification build + // dict payloads with use_raw_binary_body=False. Before the fix, when + // those tasks routed off hf-inference to a third-party provider, the + // top-of-_call_provider img_b64 stayed "" and the image was silently + // dropped. 
The fix reads the base64 out of payload["inputs"]["image"] + // (for VQA / doc-QA) or payload["inputs"] (for zero-shot-image- + // classification) so every provider branch below sees a populated img_b64. + val code = makeDesc(task = "visual-question-answering").generatePythonCode() + // VQA / doc-QA: image at payload["inputs"]["image"]. + code should include("""isinstance(inputs, dict) and isinstance(inputs.get("image"), str)""") + code should include("""img_b64 = inputs["image"]""") + // Zero-shot-image-classification: image at payload["inputs"] directly. + code should include( + """elif task == "zero-shot-image-classification" and isinstance(inputs, str):""" + ) + code should include("img_b64 = inputs") + } + + it should "route image-text-to-text through chat completions with embedded base64 image" in { + val code = makeDesc(task = "image-text-to-text").generatePythonCode() + code should include("""elif task == "image-text-to-text":""") + code should include("""data:image/png;base64,{img_b64}""") + code should include("self.MODEL_ID") + } + + it should "route image-to-image as raw binary and parse via _url_to_data_url on JSON response" in { + val code = makeDesc(task = "image-to-image").generatePythonCode() + code should include("""elif task == "image-to-image":""") + code should include("self._url_to_data_url(") + } + + it should + "register all 9 image task strings under the dispatcher (image-only + image+prompt)" in { + // Each image task should pull in ImageTaskCodegen's branch chain. 
+ val imageTasks = Seq( + "image-classification", + "object-detection", + "image-segmentation", + "image-to-text", + "visual-question-answering", + "document-question-answering", + "zero-shot-image-classification", + "image-text-to-text", + "image-to-image" + ) + imageTasks.foreach { t => + val code = makeDesc(task = t).generatePythonCode() + code should include("if task in image_only_tasks:") + } + } + "getOutputSchemas" should "add the result column as a STRING to the inherited schema" in { val desc = makeDesc(resultColumn = "answer") val inputSchema = Schema().add("prompt", AttributeType.STRING) diff --git a/frontend/src/app/app.module.ts b/frontend/src/app/app.module.ts index 524146cf75..35e82f81b7 100644 --- a/frontend/src/app/app.module.ts +++ b/frontend/src/app/app.module.ts @@ -105,6 +105,7 @@ import { CoeditorUserIconComponent } from "./workspace/component/menu/coeditor-u import { AgentPanelComponent } from "./workspace/component/agent/agent-panel/agent-panel.component"; import { AgentChatComponent } from "./workspace/component/agent/agent-panel/agent-chat/agent-chat.component"; import { AgentRegistrationComponent } from "./workspace/component/agent/agent-panel/agent-registration/agent-registration.component"; +import { HuggingFaceImageUploadComponent } from "./workspace/component/hugging-face-image-upload/hugging-face-image-upload.component"; import { DatasetFileSelectorComponent } from "./workspace/component/dataset-file-selector/dataset-file-selector.component"; import { DatasetVersionSelectorComponent } from "./workspace/component/dataset-version-selector/dataset-version-selector.component"; import { DatasetSelectionModalComponent } from "./workspace/component/dataset-selection-modal/dataset-selection-modal.component"; @@ -331,6 +332,7 @@ registerLocaleData(en); AgentChatComponent, AgentRegistrationComponent, AgentInteractionComponent, + HuggingFaceImageUploadComponent, DatasetFileSelectorComponent, DatasetVersionSelectorComponent, 
DatasetSelectionModalComponent, diff --git a/frontend/src/app/common/formly/formly-config.ts b/frontend/src/app/common/formly/formly-config.ts index 707ddfa797..ba80dc51f9 100644 --- a/frontend/src/app/common/formly/formly-config.ts +++ b/frontend/src/app/common/formly/formly-config.ts @@ -29,6 +29,7 @@ import { CollabWrapperComponent } from "./collab-wrapper/collab-wrapper/collab-w import { FormlyRepeatDndComponent } from "./repeat-dnd/repeat-dnd.component"; import { UiUdfParametersComponent } from "../../workspace/component/ui-udf-parameters/ui-udf-parameters.component"; import { DatasetVersionSelectorComponent } from "../../workspace/component/dataset-version-selector/dataset-version-selector.component"; +import { HuggingFaceImageUploadComponent } from "../../workspace/component/hugging-face-image-upload/hugging-face-image-upload.component"; /** * Configuration for using Json Schema with Formly. @@ -80,6 +81,7 @@ export const TEXERA_FORMLY_CONFIG = { { name: "codearea", component: CodeareaCustomTemplateComponent }, { name: "inputautocomplete", component: DatasetFileSelectorComponent, wrappers: ["form-field"] }, { name: "datasetversionselector", component: DatasetVersionSelectorComponent, wrappers: ["form-field"] }, + { name: "huggingface-image-upload", component: HuggingFaceImageUploadComponent, wrappers: ["form-field"] }, { name: "repeat-section-dnd", component: FormlyRepeatDndComponent }, { name: "ui-udf-parameters", component: UiUdfParametersComponent, wrappers: ["form-field"] }, ], diff --git a/frontend/src/app/workspace/component/hugging-face-image-upload/hugging-face-image-upload.component.html b/frontend/src/app/workspace/component/hugging-face-image-upload/hugging-face-image-upload.component.html new file mode 100644 index 0000000000..441c71f9c9 --- /dev/null +++ b/frontend/src/app/workspace/component/hugging-face-image-upload/hugging-face-image-upload.component.html @@ -0,0 +1,51 @@ +<!-- + Licensed to the Apache Software Foundation (ASF) under one + or 
more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +--> + +<div class="hf-image-upload"> + <input + #fileInput + type="file" + accept="image/*" + class="hf-image-upload-input" + (change)="onFileSelected($event)" /> + + <div + *ngIf="previewSrc" + class="hf-image-preview"> + <img + [src]="previewSrc" + alt="Uploaded Hugging Face task input" /> + <div class="hf-image-meta"> + <span>{{ displayFileName || "Selected image" }}</span> + <button + nz-button + nzSize="small" + type="button" + (click)="clearImage(fileInput)"> + Clear + </button> + </div> + </div> + + <div + *ngIf="errorMessage" + class="hf-image-error"> + {{ errorMessage }} + </div> +</div> diff --git a/frontend/src/app/workspace/component/hugging-face-image-upload/hugging-face-image-upload.component.scss b/frontend/src/app/workspace/component/hugging-face-image-upload/hugging-face-image-upload.component.scss new file mode 100644 index 0000000000..b292d8131e --- /dev/null +++ b/frontend/src/app/workspace/component/hugging-face-image-upload/hugging-face-image-upload.component.scss @@ -0,0 +1,60 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +.hf-image-upload { + display: flex; + flex-direction: column; + gap: 8px; +} + +.hf-image-upload-input { + width: 100%; +} + +.hf-image-preview { + border: 1px solid #d9d9d9; + border-radius: 4px; + padding: 8px; +} + +.hf-image-preview img { + display: block; + width: 100%; + max-height: 220px; + object-fit: contain; + background: #f5f5f5; +} + +.hf-image-meta { + display: flex; + align-items: center; + justify-content: space-between; + gap: 8px; + margin-top: 8px; +} + +.hf-image-meta span { + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +.hf-image-error { + color: #cf1322; +} diff --git a/frontend/src/app/workspace/component/hugging-face-image-upload/hugging-face-image-upload.component.spec.ts b/frontend/src/app/workspace/component/hugging-face-image-upload/hugging-face-image-upload.component.spec.ts new file mode 100644 index 0000000000..6bd947ef0e --- /dev/null +++ b/frontend/src/app/workspace/component/hugging-face-image-upload/hugging-face-image-upload.component.spec.ts @@ -0,0 +1,146 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { ComponentFixture, TestBed } from "@angular/core/testing"; +import { FormControl } from "@angular/forms"; +import { HuggingFaceImageUploadComponent } from "./hugging-face-image-upload.component"; +import { commonTestProviders } from "../../../common/testing/test-utils"; + +describe("HuggingFaceImageUploadComponent", () => { + let component: HuggingFaceImageUploadComponent; + let fixture: ComponentFixture<HuggingFaceImageUploadComponent>; + + beforeEach(async () => { + await TestBed.configureTestingModule({ + imports: [HuggingFaceImageUploadComponent], + providers: [...commonTestProviders], + }).compileComponents(); + + fixture = TestBed.createComponent(HuggingFaceImageUploadComponent); + component = fixture.componentInstance; + component.field = { + props: {}, + formControl: new FormControl(""), + key: "image", + model: {}, + } as any; + fixture.detectChanges(); + }); + + it("should create", () => { + expect(component).toBeTruthy(); + }); + + describe("derived view state", () => { + it("reports no image when formControl is empty", () => { + expect(component.hasImage).toBe(false); + expect(component.previewSrc).toBe(""); + expect(component.displayFileName).toBe(""); + }); + + it("reports an image when formControl holds a data URL", () => { + component.formControl.setValue("data:image/jpeg;base64,AAA"); + expect(component.hasImage).toBe(true); + 
expect(component.previewSrc).toBe("data:image/jpeg;base64,AAA"); + expect(component.displayFileName).toBe("Uploaded image"); + }); + + it("prefers the explicit filename over the fallback label", () => { + component.formControl.setValue("data:image/jpeg;base64,AAA"); + component.fileName = "cat.jpg"; + expect(component.displayFileName).toBe("cat.jpg"); + }); + }); + + describe("onFileSelected", () => { + function makeFileInput(file?: File): HTMLInputElement { + const input = document.createElement("input"); + input.type = "file"; + if (file) { + Object.defineProperty(input, "files", { + value: [file] as unknown as FileList, + configurable: true, + }); + } + return input; + } + + it("clears prior error and returns early when no file is provided", async () => { + component.errorMessage = "previous error"; + const input = makeFileInput(); + await component.onFileSelected({ target: input } as unknown as Event); + expect(component.errorMessage).toBe(""); + expect(component.formControl.value).toBe(""); + }); + + it("rejects non-image files and resets the input", async () => { + const txtFile = new File(["hi"], "note.txt", { type: "text/plain" }); + const input = makeFileInput(txtFile); + await component.onFileSelected({ target: input } as unknown as Event); + expect(component.errorMessage).toBe("Choose an image file."); + expect(component.hasImage).toBe(false); + }); + + it("reports an error when image compression fails", async () => { + // jsdom's Image never fires onload/onerror, so compressImage would hang + // forever. Stub FileReader so it synchronously fires onerror, which + // makes compressImage reject and exercises the catch branch. 
+ const realFileReader = globalThis.FileReader; + class FailingFileReader { + onload: ((e: Event) => void) | null = null; + onerror: ((e: Event) => void) | null = null; + readAsDataURL() { + queueMicrotask(() => this.onerror?.(new Event("error"))); + } + } + (globalThis as any).FileReader = FailingFileReader; + try { + const imgFile = new File(["fake"], "broken.png", { type: "image/png" }); + const input = makeFileInput(imgFile); + await component.onFileSelected({ target: input } as unknown as Event); + expect(component.errorMessage).toBe("Could not prepare this image. Try a smaller image file."); + expect(component.hasImage).toBe(false); + } finally { + (globalThis as any).FileReader = realFileReader; + } + }); + }); + + describe("clearImage", () => { + it("resets file state, the form control, and any model value", () => { + (component.field as any).model = { image: "data:image/jpeg;base64,AAA" }; + component.formControl.setValue("data:image/jpeg;base64,AAA"); + component.fileName = "cat.jpg"; + component.errorMessage = "some error"; + + const input = document.createElement("input"); + input.type = "file"; + + component.clearImage(input); + + expect(component.fileName).toBe(""); + expect(component.errorMessage).toBe(""); + expect(input.value).toBe(""); + expect(component.formControl.value).toBe(""); + expect(component.formControl.dirty).toBe(true); + expect(component.formControl.touched).toBe(true); + expect((component.model as any).image).toBe(""); + }); + }); +}); diff --git a/frontend/src/app/workspace/component/hugging-face-image-upload/hugging-face-image-upload.component.ts b/frontend/src/app/workspace/component/hugging-face-image-upload/hugging-face-image-upload.component.ts new file mode 100644 index 0000000000..4b72e14aa5 --- /dev/null +++ b/frontend/src/app/workspace/component/hugging-face-image-upload/hugging-face-image-upload.component.ts @@ -0,0 +1,162 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor 
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { Component } from "@angular/core"; +import { CommonModule } from "@angular/common"; +import { FieldType, FieldTypeConfig } from "@ngx-formly/core"; +import { NzButtonModule } from "ng-zorro-antd/button"; + +// Keep in sync with PythonCodegenBase._compress_image_bytes(max_bytes) on the backend: +// the uploaded data URL must stay within the size the inference helpers expect. +const MAX_DATA_URL_LENGTH = 45000; +const INITIAL_MAX_DIMENSION = 512; +const MIN_MAX_DIMENSION = 160; +const INITIAL_JPEG_QUALITY = 0.75; +const MIN_JPEG_QUALITY = 0.35; + +@Component({ + selector: "texera-hugging-face-image-upload", + templateUrl: "./hugging-face-image-upload.component.html", + styleUrls: ["./hugging-face-image-upload.component.scss"], + imports: [CommonModule, NzButtonModule], +}) +export class HuggingFaceImageUploadComponent extends FieldType<FieldTypeConfig> { + fileName = ""; + errorMessage = ""; + + get hasImage(): boolean { + const value = this.formControl.value; + return typeof value === "string" && value.startsWith("data:image/"); + } + + get previewSrc(): string { + return this.hasImage ? 
this.formControl.value : ""; + } + + get displayFileName(): string { + if (this.fileName) return this.fileName; + if (this.hasImage) return "Uploaded image"; + return ""; + } + + async onFileSelected(event: Event): Promise<void> { + this.errorMessage = ""; + const input = event.target as HTMLInputElement; + const file = input.files?.[0]; + + if (!file) { + return; + } + if (!file.type.startsWith("image/")) { + this.errorMessage = "Choose an image file."; + input.value = ""; + return; + } + + try { + const dataUrl = await this.compressImage(file); + this.fileName = file.name; + this.formControl.setValue(dataUrl); + if (typeof this.key === "string" && this.model) { + this.model[this.key] = dataUrl; + } + this.formControl.markAsDirty(); + this.formControl.markAsTouched(); + this.formControl.updateValueAndValidity(); + } catch { + this.errorMessage = "Could not prepare this image. Try a smaller image file."; + input.value = ""; + } + } + + private compressImage(file: File): Promise<string> { + const reader = new FileReader(); + const image = new Image(); + + return new Promise((resolve, reject) => { + reader.onload = () => { + if (typeof reader.result !== "string") { + reject(); + return; + } + image.onload = () => { + const compressed = this.renderCompressedDataUrl(image); + if (!compressed.startsWith("data:image/") || compressed.length > MAX_DATA_URL_LENGTH) { + reject(); + return; + } + resolve(compressed); + }; + image.onerror = () => reject(); + image.src = reader.result; + }; + reader.onerror = () => reject(); + reader.readAsDataURL(file); + }); + } + + private renderCompressedDataUrl(image: HTMLImageElement): string { + let maxDimension = INITIAL_MAX_DIMENSION; + let quality = INITIAL_JPEG_QUALITY; + let bestDataUrl = ""; + + while (maxDimension >= MIN_MAX_DIMENSION) { + const scale = Math.min(1, maxDimension / Math.max(image.width, image.height)); + const width = Math.max(1, Math.round(image.width * scale)); + const height = Math.max(1, Math.round(image.height 
* scale)); + const canvas = document.createElement("canvas"); + canvas.width = width; + canvas.height = height; + const ctx = canvas.getContext("2d"); + + if (!ctx) { + return bestDataUrl; + } + + ctx.drawImage(image, 0, 0, width, height); + quality = INITIAL_JPEG_QUALITY; + + while (quality >= MIN_JPEG_QUALITY) { + const dataUrl = canvas.toDataURL("image/jpeg", quality); + bestDataUrl = dataUrl; + if (dataUrl.length <= MAX_DATA_URL_LENGTH) { + return dataUrl; + } + quality -= 0.1; + } + + maxDimension = Math.floor(maxDimension * 0.75); + } + + return bestDataUrl; + } + + clearImage(input: HTMLInputElement): void { + this.fileName = ""; + this.errorMessage = ""; + input.value = ""; + this.formControl.setValue(""); + if (typeof this.key === "string" && this.model) { + this.model[this.key] = ""; + } + this.formControl.markAsDirty(); + this.formControl.markAsTouched(); + this.formControl.updateValueAndValidity(); + } +} diff --git a/frontend/src/assets/operator_images/HuggingFace.png b/frontend/src/assets/operator_images/HuggingFace.png new file mode 100644 index 0000000000..673b8ea907 Binary files /dev/null and b/frontend/src/assets/operator_images/HuggingFace.png differ diff --git a/frontend/src/assets/sample-image.png b/frontend/src/assets/sample-image.png new file mode 100644 index 0000000000..c28d120ab7 Binary files /dev/null and b/frontend/src/assets/sample-image.png differ
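Reviewer note: the back-off search in `renderCompressedDataUrl` above can be read in isolation from the canvas. The following is a minimal sketch, not the committed code. It assumes a hypothetical `encode` callback standing in for `canvas.toDataURL("image/jpeg", quality)`, and the names `findEncodingWithinBudget`, `BackoffOptions`, and `DEFAULTS` are illustrative only. One deliberate difference: quality is stepped in integer percent here, whereas the committed code subtracts `0.1` from a float.

```typescript
// Hypothetical stand-in for drawing at `maxDimension` and calling
// canvas.toDataURL("image/jpeg", qualityPercent / 100).
type Encode = (maxDimension: number, qualityPercent: number) => string;

interface BackoffOptions {
  maxLength: number;
  initialDimension: number;
  minDimension: number;
  initialQualityPercent: number;
  minQualityPercent: number;
}

// Values copied from the constants declared in the component in this diff.
const DEFAULTS: BackoffOptions = {
  maxLength: 45000,
  initialDimension: 512,
  minDimension: 160,
  initialQualityPercent: 75,
  minQualityPercent: 35,
};

// Try each dimension from largest to smallest; at each dimension, restart
// from the initial quality and lower it in steps of 10 percent. Return the
// first encoding that fits the budget, or the last attempt if none fits.
function findEncodingWithinBudget(encode: Encode, opts: BackoffOptions = DEFAULTS): string {
  let last = "";
  for (let dim = opts.initialDimension; dim >= opts.minDimension; dim = Math.floor(dim * 0.75)) {
    for (let q = opts.initialQualityPercent; q >= opts.minQualityPercent; q -= 10) {
      last = encode(dim, q);
      if (last.length <= opts.maxLength) {
        return last;
      }
    }
  }
  return last;
}

// Usage with a fake encoder whose output length grows with size and quality.
const fakeEncode: Encode = (dim, q) => "x".repeat(dim * q * 2);
const chosen = findEncodingWithinBudget(fakeEncode);
console.log(chosen.length <= DEFAULTS.maxLength);
```

As in the committed component, a result that still exceeds the budget is returned rather than thrown. There, `compressImage` is the caller that rejects an oversized result, so a caller of this sketch would need the same length check.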
