This is an automated email from the ASF dual-hosted git repository. github-merge-queue[bot] pushed a commit to branch gh-readonly-queue/main/pr-5278-3dab771a2fe3ea5bf97c4c69cfbd761f9cd01e54 in repository https://gitbox.apache.org/repos/asf/texera.git
commit 2b9add956c9e63c3c4f6e717221a0c5e33e54875
Author: Prateek Ganigi <[email protected]>
AuthorDate: Mon Jun 15 15:27:49 2026 -0700

feat(huggingFace): refactor operator into per-task codegen + text-generation (#5278)

> ⚠️ This PR is stacked on #5124. Until that lands, the diff below includes #5124's `HuggingFaceModelResource.scala` and the 1-line registration in `TexeraWebApplication.scala`. The new code in this PR is everything under `common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/` and the new test under `common/workflow-operator/src/test/.../huggingFace/HuggingFaceInferenceOpDescSpec.scala`. Once #5124 merges, this diff will auto-clean to ~839 lines.

### What changes were proposed in this PR?

Refactors the monolithic 1,278-line `HuggingFaceInferenceOpDesc` from the team's feature branch into a dispatcher + per-task codegen architecture and ships the first task family (text-generation):

- `codegen/TaskCodegen.scala` introduces the trait + `CodegenContext` that model per-task variation.
- `codegen/PythonCodegenBase.scala` emits the shared provider-fallback / `process_table` / `_parse_response` infrastructure with two holes for the per-task payload and parse snippets.
- `codegen/TextGenCodegen.scala` supplies text-generation's chat-completions payload and the `body["choices"][0]["message"]["content"]` parse branch.
- `HuggingFaceInferenceOpDesc.scala` becomes a thin (~180-line) dispatcher holding the `@JsonProperty` fields and the `registeredCodegens` map.

User-input string fields are typed `EncodableString` and emitted via the `pyb"..."` macro so values reach Python as `self.decode_python_template('<base64>')` rather than raw literals. Class constants are assigned in `open(self)` so `self` is in scope for the decode call. The generated `process_table` runs a defensive `_HF_MODEL_ID_PATTERN` check at runtime before any HF URL is composed.
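
For illustration, the runtime model-ID check described above can be exercised on its own. This is a minimal sketch, not the generated operator: the pattern is copied from the template in `PythonCodegenBase.scala` (where the Scala source writes `$$` so that the interpolated string contains a single `$`), while `check_model_id` is a hypothetical helper name that stands in for the guard at the top of `process_table`.

```python
import re

# Pattern as it appears in the generated script (single `$` after interpolation).
_HF_MODEL_ID_PATTERN = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*(/[A-Za-z0-9._-]+)+$")


def check_model_id(model_id):
    """Hypothetical helper: raise ValueError unless model_id matches the pattern."""
    # `model_id or ""` mirrors the generated code, so None is treated as empty.
    if not _HF_MODEL_ID_PATTERN.match(model_id or ""):
        raise ValueError(f"Invalid Hugging Face model ID '{model_id}'.")
    return model_id


# The operator's default model passes the check.
check_model_id("Qwen/Qwen2.5-72B-Instruct")

# IDs without an org segment, or carrying a query string or fragment, are rejected.
for bad in ("gpt2", "org/model?x=1", "org/model#frag"):
    try:
        check_model_id(bad)
    except ValueError as err:
        print(err)
```

Because the pattern requires at least one `/segment`, a bare model name such as `gpt2` fails the check under this sketch.
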
The `TaskCodegen` trait also exposes a `tasks: Set[String]` default so a single codegen can register under multiple task strings; this becomes relevant in PR 3 (image family).

### Any related issues, documentation, or discussions?

Tracked in #5277 & #5041 (umbrella issue for the HuggingFace operator end-to-end implementation).

Closes #5277

Stacked on #5124 (PR 1 - REST resource).

This is PR 2 of a multi-PR series landing the HuggingFace operator end-to-end. The full plan and umbrella issue live separately; this PR's scope is exactly the dispatcher pattern + text-generation codegen.

### How was this PR tested?

- `sbt "WorkflowOperator/compile; WorkflowOperator/Test/compile"` clean.
- `sbt scalafmtCheck` clean.
- `sbt "WorkflowOperator/testOnly org.apache.texera.amber.operator.huggingFace.HuggingFaceInferenceOpDescSpec"` - 10/10 pass (operator info, validation, codegen wiring, MODEL_ID runtime check, leak-prevention, clamping, schema).
- `sbt "WorkflowOperator/testOnly org.apache.texera.amber.util.PythonCodeRawInvalidTextSpec"` - 117/117 descriptors `py_compile` cleanly, no raw-text leaks. The new operator is included in this scan.
- Generated Python verified via `python3 -m py_compile` on a sample output.

### Was this PR authored or co-authored using generative AI tooling?

Co-authored with Claude Opus 4.7

---------

Co-authored-by: Elliot Lin <[email protected]>
Co-authored-by: Claude Opus 4.7 (1M context) <[email protected]>
Co-authored-by: Xuan Gu <[email protected]>
---
 .../apache/texera/amber/operator/LogicalOp.scala   |   2 +
 .../huggingFace/HuggingFaceInferenceOpDesc.scala   | 194 +++++++++++
 .../huggingFace/codegen/PythonCodegenBase.scala    | 376 +++++++++++++++++++++
 .../operator/huggingFace/codegen/TaskCodegen.scala |  77 +++++
 .../huggingFace/codegen/TextGenCodegen.scala       |  54 +++
 .../HuggingFaceInferenceOpDescSpec.scala           | 202 +++++++++++
 6 files changed, 905 insertions(+)

diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/LogicalOp.scala b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/LogicalOp.scala
index 4e9d6c6e2c..55e241ecaf 100644
--- a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/LogicalOp.scala
+++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/LogicalOp.scala
@@ -43,6 +43,7 @@ import org.apache.texera.amber.operator.dummy.DummyOpDesc
 import org.apache.texera.amber.operator.filter.SpecializedFilterOpDesc
 import org.apache.texera.amber.operator.hashJoin.HashJoinOpDesc
 import org.apache.texera.amber.operator.huggingFace.{
+  HuggingFaceInferenceOpDesc,
   HuggingFaceIrisLogisticRegressionOpDesc,
   HuggingFaceSentimentAnalysisOpDesc,
   HuggingFaceSpamSMSDetectionOpDesc,
@@ -396,6 +397,7 @@ trait StateTransferFunc
     ),
     new Type(value = classOf[SklearnDummyClassifierOpDesc], name = "SklearnDummyClassifier"),
     new Type(value = classOf[SklearnPredictionOpDesc], name = "SklearnPrediction"),
+    new Type(value = classOf[HuggingFaceInferenceOpDesc], name = "HuggingFace"),
     new Type(
       value = classOf[HuggingFaceSentimentAnalysisOpDesc],
       name = "HuggingFaceSentimentAnalysis"
diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDesc.scala
b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDesc.scala new file mode 100644 index 0000000000..07466c898e --- /dev/null +++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDesc.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.texera.amber.operator.huggingFace + +import com.fasterxml.jackson.annotation.{JsonProperty, JsonPropertyDescription} +import com.kjetland.jackson.jsonSchema.annotations.JsonSchemaTitle +import org.apache.texera.amber.core.tuple.{AttributeType, Schema} +import org.apache.texera.amber.core.workflow.{InputPort, OutputPort, PortIdentity} +import org.apache.texera.amber.operator.PythonOperatorDescriptor +import org.apache.texera.amber.operator.huggingFace.codegen.{ + CodegenContext, + PythonCodegenBase, + TaskCodegen, + TextGenCodegen +} +import org.apache.texera.amber.operator.metadata.annotations.AutofillAttributeName +import org.apache.texera.amber.operator.metadata.{OperatorGroupConstants, OperatorInfo} +import org.apache.texera.amber.pybuilder.PyStringTypes.EncodableString + +/** + * Generic Hugging Face inference operator. + * + * This is the first slice of a feature that will eventually cover ~20 HF + * pipeline tasks. PR 2 ships text-generation only; image, audio, + * media-generation, and QA task families land in subsequent PRs as new + * `TaskCodegen` implementations registered in `registeredCodegens`. + * + * The Python script that runs at execution time is assembled by + * `PythonCodegenBase.render(ctx, codegen)`, which composes the shared + * provider-fallback / request-loop infrastructure with the per-task + * payload + parse snippets supplied by the selected `TaskCodegen`. + * + * User-provided string fields are typed as [[EncodableString]] so the + * `pyb"..."` macro inside `PythonCodegenBase` emits them as + * base64-decoded expressions at runtime instead of raw Python literals — + * this is what allows the operator to satisfy + * `PythonCodeRawInvalidTextSpec`'s contract that arbitrary `@JsonProperty` + * values must not leak into generated source. 
+ */ +class HuggingFaceInferenceOpDesc extends PythonOperatorDescriptor { + + @JsonProperty(value = "hfApiToken", required = true) + @JsonSchemaTitle("HF API Token") + @JsonPropertyDescription( + "Your Hugging Face API token (from https://huggingface.co/settings/tokens)" + ) + var hfApiToken: EncodableString = "" + + @JsonProperty(value = "task", required = true, defaultValue = "text-generation") + @JsonSchemaTitle("Task") + @JsonPropertyDescription("The Hugging Face pipeline task type") + var task: EncodableString = "text-generation" + + @JsonProperty( + value = "modelId", + required = true, + defaultValue = "Qwen/Qwen2.5-72B-Instruct" + ) + @JsonSchemaTitle("Model") + @JsonPropertyDescription("Select a Hugging Face model") + var modelId: EncodableString = "Qwen/Qwen2.5-72B-Instruct" + + @JsonProperty(value = "promptColumn", required = true) + @JsonSchemaTitle("Prompt Column") + @JsonPropertyDescription("Column in the input table to use as the user prompt") + @AutofillAttributeName + var promptColumn: EncodableString = "" + + @JsonProperty( + value = "systemPrompt", + required = false, + defaultValue = "You are a helpful assistant." + ) + @JsonSchemaTitle("System Prompt") + @JsonPropertyDescription("Optional system message to set model behavior") + var systemPrompt: EncodableString = "You are a helpful assistant." 
+ + @JsonProperty(value = "maxNewTokens", required = false, defaultValue = "256") + @JsonSchemaTitle("Max New Tokens") + @JsonPropertyDescription("Maximum number of tokens to generate (1-4096)") + var maxNewTokens: java.lang.Integer = 256 + + @JsonProperty(value = "temperature", required = false) + @JsonSchemaTitle("Temperature") + @JsonPropertyDescription("Sampling temperature (0.0 = deterministic, up to 2.0)") + var temperature: java.lang.Double = 0.7 + + @JsonProperty( + value = "resultColumn", + required = false, + defaultValue = "hf_response" + ) + @JsonSchemaTitle("Result Column Name") + @JsonPropertyDescription("Name of the new column added to the output table") + var resultColumn: EncodableString = "hf_response" + + /** + * Per-task code generators. New entries are added as task families land + * in subsequent PRs (e.g. ImageTaskCodegen, AudioTaskCodegen, etc.). + * + * An unrecognized task string falls back to [[TextGenCodegen]]; the + * generated Python's `else` branch then produces a generic `{"inputs": + * prompt_value}` payload and the HF endpoint surfaces the real error at + * runtime. This matches the original monolithic operator's behavior and + * keeps `generatePythonCode` total (it never throws on arbitrary input, + * which is required by `PythonCodeRawInvalidTextSpec`). + */ + private val registeredCodegens: Map[String, TaskCodegen] = + Map(TextGenCodegen.task -> TextGenCodegen) + + private def codegenForTask(t: String): TaskCodegen = + registeredCodegens.getOrElse(t, TextGenCodegen) + + /** + * The output column name to use in generated Python and in the output + * schema. Falls back to the `"hf_response"` sentinel when the user + * leaves the field null or blank. + * + * Shared between [[generatePythonCode]] and [[getOutputSchemas]] so the + * two never drift apart (a divergence would cause the Python operator + * to write to a column the schema didn't declare). 
Returns + * [[EncodableString]] rather than `String` so the value flows into the + * `pyb` template with the encoding annotation intact. + */ + private def resolvedResultColumn: EncodableString = + if (resultColumn == null || resultColumn.trim.isEmpty) "hf_response" + else resultColumn + + override def generatePythonCode(): String = { + val safeTask: EncodableString = + if (task == null || task.trim.isEmpty) "text-generation" else task + val safeModelId: EncodableString = + if (modelId == null) "" else modelId.trim + val safePromptCol: EncodableString = + if (promptColumn == null) "" else promptColumn + val safeResultCol: EncodableString = resolvedResultColumn + val safeSystemPrompt: EncodableString = + if (systemPrompt == null) "" else systemPrompt + val safeToken: EncodableString = + if (hfApiToken == null) "" else hfApiToken + + val safeMaxTokens = + math.max(1, math.min(if (maxNewTokens != null) maxNewTokens.intValue else 256, 4096)) + val safeTemp = + math.max(0.0, math.min(if (temperature != null) temperature.doubleValue else 0.7, 2.0)) + + val ctx = CodegenContext( + hfApiToken = safeToken, + modelId = safeModelId, + promptColumn = safePromptCol, + resultColumn = safeResultCol, + task = safeTask, + systemPrompt = safeSystemPrompt, + safeMaxTokens = safeMaxTokens, + safeTemp = safeTemp + ) + + PythonCodegenBase.render(ctx, codegenForTask(safeTask)) + } + + override def operatorInfo: OperatorInfo = + OperatorInfo( + "Hugging Face", + "Call a Hugging Face model via the Inference API", + OperatorGroupConstants.HUGGINGFACE_GROUP, + inputPorts = List(InputPort()), + outputPorts = List(OutputPort()) + ) + + override def getOutputSchemas( + inputSchemas: Map[PortIdentity, Schema] + ): Map[PortIdentity, Schema] = + Map( + operatorInfo.outputPorts.head.id -> inputSchemas.values.head + .add(resolvedResultColumn, AttributeType.STRING) + ) +} diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/PythonCodegenBase.scala 
b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/PythonCodegenBase.scala new file mode 100644 index 0000000000..16c2cc9bbb --- /dev/null +++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/PythonCodegenBase.scala @@ -0,0 +1,376 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.texera.amber.operator.huggingFace.codegen + +import org.apache.texera.amber.pybuilder.PythonTemplateBuilder.PythonTemplateBuilderStringContext + +/** + * Builds the Python script emitted by HuggingFaceInferenceOpDesc. + * + * The script defines a `ProcessTableOperator` class with: + * - Per-instance configuration set in `open(self)` from base64-encoded + * values that the `pyb"..."` macro decodes at runtime (so user-input + * strings never appear as raw Python literals in the source). + * - A provider-fallback system that walks the HF Hub's inference-provider + * list cheapest-first and tries each provider's native chat-completions + * route, with HF Inference Router as the default. 
+ * - A `process_table` loop that validates the prompt column, builds the + * per-row payload via the per-task codegen, posts to the resolved + * provider, and parses the response. + * - A `_parse_response` task switch whose branches are provided by the + * per-task codegen. + * + * Per-task variation lives in `TaskCodegen` implementations. This class + * holds only what is shared across all HF tasks; per-task helpers (image + * loading, audio MIME inference, media-URL fetching, etc.) will be added + * in subsequent PRs as the corresponding task families land. + */ +object PythonCodegenBase { + + def render(ctx: CodegenContext, codegen: TaskCodegen): String = { + val payload = codegen.payloadPython(ctx) + val parse = codegen.parsePython(ctx) + val hfApiToken = ctx.hfApiToken + val modelId = ctx.modelId + val promptColumn = ctx.promptColumn + val resultColumn = ctx.resultColumn + val task = ctx.task + val systemPrompt = ctx.systemPrompt + val maxNewTokens = ctx.safeMaxTokens + val temperature = ctx.safeTemp + pyb"""import os + |import re + |import json + |import requests + |import pandas as pd + |from pytexera import * + | + |# Defensive format check for MODEL_ID before it is interpolated into + |# HF URL paths. The base host is hardcoded so the worst case isn't + |# SSRF, but rejecting `..` segments / query strings / fragments / + |# control chars keeps the operator's request shape predictable. + |_HF_MODEL_ID_PATTERN = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*(/[A-Za-z0-9._-]+)+$$") + | + |class ProcessTableOperator(UDFTableOperator): + | + | # Providers ranked cheapest-first (lower index = cheaper). + | # Unknown providers are appended at the end. 
+ | PROVIDER_COST_PRIORITY = [ + | "hf-inference", + | "cerebras", + | "sambanova", + | "groq", + | "novita", + | "nebius", + | "fireworks-ai", + | "together", + | "hyperbolic", + | "scaleway", + | "nscale", + | "ovhcloud", + | "deepinfra", + | "featherless-ai", + | "baseten", + | "publicai", + | "nvidia", + | "openai", + | "cohere", + | "clarifai", + | ] + | + | # Per-provider chat-completions route overrides. Providers not listed + | # here use the default `v1/chat/completions` path. Single source of + | # truth for both _post_with_fallback (text-gen) and _call_provider + | # (OpenAI-compatible fallback) so the two stay in sync as providers + | # are added. + | CHAT_ROUTES = { + | "groq": "openai/v1/chat/completions", + | "fireworks-ai": "inference/v1/chat/completions", + | "cohere": "compatibility/v1/chat/completions", + | "clarifai": "v2/ext/openai/v1/chat/completions", + | "deepinfra": "v1/openai/chat/completions", + | } + | + | # Third-party providers that speak the OpenAI chat-completions + | # protocol. Used by _call_provider's OpenAI-compatible branch. + | OPENAI_COMPATIBLE_PROVIDERS = ( + | "cerebras", "sambanova", "groq", "novita", "nebius", + | "fireworks-ai", "together", "hyperbolic", "cohere", "clarifai", + | "deepinfra", "featherless-ai", "nscale", "nvidia", "openai", + | "ovhcloud", "publicai", "scaleway", "baseten", + | ) + | + | def open(self): + | # User-provided strings reach the operator via base64-encoded + | # decode expressions so they cannot break Python syntax or + | # leak raw text into the generated source. + | self.HF_API_TOKEN = $hfApiToken + | self.MODEL_ID = $modelId + | self.PROMPT_COLUMN = $promptColumn + | self.RESULT_COLUMN = $resultColumn + | self.TASK = $task + | self.SYSTEM_PROMPT = $systemPrompt + | self.MAX_NEW_TOKENS = $maxNewTokens + | self.TEMPERATURE = $temperature + | + | def _resolve_providers(self, token): + | '''Query the HF Hub API for inference providers serving this model. 
+ | Returns a list of dicts with 'name' and 'providerId' sorted + | cheapest-first. Falls back to hf-inference if anything goes wrong. + | ''' + | try: + | resp = requests.get( + | f"https://huggingface.co/api/models/{self.MODEL_ID}", + | headers={"Authorization": f"Bearer {token}"}, + | params={"expand[]": "inferenceProviderMapping"}, + | timeout=30, + | ) + | if resp.status_code == 200: + | data = resp.json() + | mapping = ( + | data.get("inferenceProviderMapping") + | or data.get("inference_provider_mapping") + | or {} + | ) + | if mapping: + | live = [ + | { + | "name": p, + | "providerId": v.get("providerId", self.MODEL_ID), + | "task": v.get("task", ""), + | "isModelAuthor": v.get("isModelAuthor", False), + | } + | for p, v in mapping.items() + | if isinstance(v, dict) and v.get("status") == "live" + | ] + | if live: + | priority = {name: idx for idx, name in enumerate(self.PROVIDER_COST_PRIORITY)} + | live.sort(key=lambda prov: priority.get(prov["name"], len(self.PROVIDER_COST_PRIORITY))) + | return live + | except Exception: + | pass + | return [{"name": "hf-inference", "providerId": self.MODEL_ID}] + | + | def _post_with_fallback(self, providers, json_headers, pipeline_payload, prompt_value): + | '''Try providers in order, using the correct API route for each. + | Returns (response, provider_summary). provider_summary is None on + | success or a string describing what failed. 
+ | ''' + | RETRYABLE = (400, 404, 422, 429, 502, 503) + | last_resp = None + | errors = [] + | for prov in providers: + | provider_name = prov["name"] + | provider_id = prov["providerId"] + | try: + | if self.TASK == "text-generation": + | route = self.CHAT_ROUTES.get(provider_name, "v1/chat/completions") + | url = f"https://router.huggingface.co/{provider_name}/{route}" + | resp = requests.post(url, headers=json_headers, json=pipeline_payload, timeout=120) + | elif provider_name == "hf-inference": + | url = f"https://router.huggingface.co/hf-inference/models/{self.MODEL_ID}" + | resp = requests.post(url, headers=json_headers, json=pipeline_payload, timeout=120) + | else: + | resp = self._call_provider(provider_name, provider_id, json_headers, pipeline_payload, prompt_value) + | except Exception as e: + | errors.append(f"{provider_name}: {type(e).__name__}") + | continue + | if resp.status_code in (200, 201): + | return resp, None + | if resp.status_code == 401: + | return resp, None + | try: + | detail = resp.json().get("error", resp.text[:200]) + | except Exception: + | detail = resp.text[:200] if resp.text else "no details" + | errors.append(f"{provider_name}: HTTP {resp.status_code} - {detail}") + | last_resp = resp + | if resp.status_code not in RETRYABLE: + | return resp, "; ".join(errors) + | summary = "; ".join(errors) if errors else "no providers available" + | return last_resp, summary + | + | def _call_provider(self, provider_name, provider_id, json_headers, pipeline_payload, prompt_value): + | '''Route to a third-party provider using its native API format. + | For the text-gen-only build this covers the OpenAI-compatible chat + | providers and an unknown-provider fallback that tries the pipeline + | format then chat completions. Image / audio / media routing will + | be added in subsequent PRs alongside the corresponding task + | codegens. 
+ | ''' + | base = f"https://router.huggingface.co/{provider_name}" + | if provider_name in self.OPENAI_COMPATIBLE_PROVIDERS: + | url = f"{base}/{self.CHAT_ROUTES.get(provider_name, 'v1/chat/completions')}" + | messages = [{"role": "user", "content": prompt_value}] + | return requests.post( + | url, + | headers=json_headers, + | json={"model": provider_id, "messages": messages}, + | timeout=120, + | ) + | + | # Unknown provider: try pipeline format, then chat completions. + | url = f"{base}/{provider_id}" + | resp = requests.post(url, headers=json_headers, json=pipeline_payload, timeout=120) + | if resp.status_code in (400, 404, 422): + | url = f"{base}/v1/chat/completions" + | messages = [{"role": "user", "content": prompt_value}] + | resp2 = requests.post( + | url, + | headers=json_headers, + | json={"model": provider_id, "messages": messages}, + | timeout=120, + | ) + | if resp2.status_code == 200: + | return resp2 + | return resp + | + | @overrides + | def process_table(self, table: Table, port: int) -> Iterator[Optional[TableLike]]: + | prompt_col = self.PROMPT_COLUMN + | result_col = self.RESULT_COLUMN + | task = self.TASK + | + | # --- validate MODEL_ID format before any HF URL is built --- + | if not _HF_MODEL_ID_PATTERN.match(self.MODEL_ID or ""): + | raise ValueError( + | f"Invalid Hugging Face model ID '{self.MODEL_ID}'. " + | f"Expected format like 'org/model-name' or 'org/model-name/revision'." + | ) + | + | # --- resolve API token --- + | token = self.HF_API_TOKEN if self.HF_API_TOKEN else os.environ.get("HF_TOKEN", "") + | if not token: + | raise ValueError( + | "Hugging Face API token is not set. " + | "Provide it in the operator config or via HF_TOKEN env var." 
+ | ) + | + | # --- resolve all available inference providers for this model (tried in order) --- + | providers = self._resolve_providers(token) + | + | # --- validate prompt column exists --- + | assert prompt_col in table.columns, ( + | f"Prompt column '{prompt_col}' not found in input table. " + | f"Available columns: {list(table.columns)}" + | ) + | + | # --- handle empty table --- + | if table.empty: + | table[result_col] = pd.Series(dtype="object") + | yield table + | return + | + | json_headers = { + | "Authorization": f"Bearer {token}", + | "Content-Type": "application/json", + | } + | + | results = [] + | for idx, row in table.iterrows(): + | prompt_value = row[prompt_col] + | if pd.isna(prompt_value): + | prompt_value = "" + | else: + | prompt_value = str(prompt_value) + | + | # --- build task-specific payload (provided by per-task codegen) --- + |${payload} + | + | try: + | resp, provider_summary = self._post_with_fallback( + | providers, json_headers, payload, prompt_value + | ) + | + | if resp is None: + | results.append( + | self._format_error( + | "All inference providers failed", + | f"No provider could serve model '{self.MODEL_ID}'. " + | f"Tried: {provider_summary}" + | ) + | ) + | continue + | + | if resp.status_code == 429: + | results.append( + | self._format_http_error( + | "HF API rate limit hit, retry later", resp.status_code, resp.text + | ) + | ) + | continue + | if resp.status_code == 401: + | results.append( + | self._format_http_error("Invalid HF API token", resp.status_code, resp.text) + | ) + | continue + | if resp.status_code not in (200, 201): + | results.append( + | self._format_error( + | "All inference providers failed", + | f"No provider could serve model '{self.MODEL_ID}'. 
" + | f"Tried: {provider_summary}" + | ) + | ) + | continue + | + | try: + | body = resp.json() + | except ValueError: + | body = resp.text + | content = self._parse_response(body) + | results.append(content) + | + | except Exception as e: + | import warnings + | warnings.warn( + | f"Row {idx}: request failed ({type(e).__name__}: {e}), " + | f"setting result to readable error text." + | ) + | results.append(self._format_error("Request failed", f"{type(e).__name__}: {e}")) + | + | table[result_col] = results + | yield table + | + | def _format_error(self, title, detail): + | return f"{title}: {detail}" + | + | def _format_http_error(self, title, status_code, response_text): + | # Cap at 200 chars to match the truncation in _post_with_fallback's + | # error-detail extraction; a large body / HTML error page would + | # otherwise land verbatim in the result cell. + | detail = response_text.strip()[:200] + | if not detail: + | detail = "<empty response>" + | return f"{title} [status={status_code}] response={detail}" + | + | def _parse_response(self, body): + | task = self.TASK + | try: + | if isinstance(body, str): + | return body + |${parse} + | else: + | return json.dumps(body) + | except (KeyError, IndexError, TypeError): + | return json.dumps(body) + |""".encode + } +} diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/TaskCodegen.scala b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/TaskCodegen.scala new file mode 100644 index 0000000000..333d1a038c --- /dev/null +++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/TaskCodegen.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.texera.amber.operator.huggingFace.codegen + +import org.apache.texera.amber.pybuilder.PyStringTypes.EncodableString + +/** + * Inputs the dispatcher passes through to each TaskCodegen. + * + * User-provided string fields are typed as [[EncodableString]] so the + * `pyb"..."` macro in [[PythonCodegenBase]] emits them as base64-decoded + * runtime expressions rather than raw Python string literals — required to + * pass `PythonCodeRawInvalidTextSpec`'s leakage check. + */ +final case class CodegenContext( + hfApiToken: EncodableString, + modelId: EncodableString, + promptColumn: EncodableString, + resultColumn: EncodableString, + task: EncodableString, + systemPrompt: EncodableString, + safeMaxTokens: Int, + safeTemp: Double +) + +/** + * A bundle of Python snippets that customize generated inference code for + * one Hugging Face pipeline task family. + * + * Concrete implementations are `object`s registered in + * `HuggingFaceInferenceOpDesc.registeredCodegens`. New task families + * (image, audio, QA, etc.) land in subsequent PRs by introducing new + * `*Codegen` objects and adding them to that map. + * + * Snippets returned by these methods are Python source spliced into the + * shared template assembled by [[PythonCodegenBase.render]]. 
Snippets must
+ * NOT directly inline user-provided strings — reference the per-instance
+ * attributes `self.HF_API_TOKEN`, `self.MODEL_ID`, `self.PROMPT_COLUMN`,
+ * etc. that the base class initializes from `CodegenContext` via the
+ * `pyb` macro's safe encoding. The snippet author is responsible for the
+ * correct indentation column (see existing implementations).
+ */
+trait TaskCodegen {
+
+  /** Canonical Hugging Face pipeline task string, e.g. "text-generation". */
+  def task: String
+
+  /** Python text that assigns `payload = …` for one row inside
+    * `process_table`'s per-row loop. The snippet supplies its own leading
+    * `if`/`elif task == "...":` opener and any `else` fallback.
+    */
+  def payloadPython(ctx: CodegenContext): String
+
+  /** Python text for the body of `_parse_response`'s task switch. The
+    * snippet supplies its own leading `if`/`elif task == "...":` opener.
+    * The base class wraps the result in the try/except matching the
+    * source layout.
+    */
+  def parsePython(ctx: CodegenContext): String
+}
diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/TextGenCodegen.scala b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/TextGenCodegen.scala
new file mode 100644
index 0000000000..b836de9e12
--- /dev/null
+++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/huggingFace/codegen/TextGenCodegen.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.texera.amber.operator.huggingFace.codegen
+
+/**
+ * Codegen for the `text-generation` Hugging Face pipeline task.
+ *
+ * The payload is the OpenAI chat-completions shape — `messages` with a
+ * system + user pair plus `max_tokens` / `temperature` knobs — which is
+ * what the HF router and every OpenAI-compatible third-party provider
+ * (Cerebras, Groq, Sambanova, Together, …) accepts.
+ *
+ * The parse step pulls `body["choices"][0]["message"]["content"]` out of
+ * the response.
+ */
+object TextGenCodegen extends TaskCodegen {
+
+  override val task: String = "text-generation"
+
+  override def payloadPython(ctx: CodegenContext): String =
+    """ if task == "text-generation":
+      |     payload = {
+      |         "model": self.MODEL_ID,
+      |         "messages": [
+      |             {"role": "system", "content": self.SYSTEM_PROMPT},
+      |             {"role": "user", "content": prompt_value},
+      |         ],
+      |         "max_tokens": self.MAX_NEW_TOKENS,
+      |         "temperature": self.TEMPERATURE,
+      |     }
+      | else:
+      |     payload = {"inputs": prompt_value}""".stripMargin
+
+  override def parsePython(ctx: CodegenContext): String =
+    """ if task == "text-generation":
+      |     return body["choices"][0]["message"]["content"]""".stripMargin
+}
diff --git a/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDescSpec.scala b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDescSpec.scala
new file mode 100644
index 0000000000..06424df604
--- /dev/null
+++ b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceInferenceOpDescSpec.scala
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.texera.amber.operator.huggingFace
+
+import org.apache.texera.amber.core.tuple.{AttributeType, Schema}
+import org.apache.texera.amber.core.workflow.PortIdentity
+import org.apache.texera.amber.operator.huggingFace.codegen.{CodegenContext, TextGenCodegen}
+import org.apache.texera.amber.operator.metadata.OperatorGroupConstants
+import org.apache.texera.amber.pybuilder.PyStringTypes.EncodableString
+import org.scalatest.flatspec.AnyFlatSpec
+import org.scalatest.matchers.should.Matchers
+
+class HuggingFaceInferenceOpDescSpec extends AnyFlatSpec with Matchers {
+
+  private def makeDesc(
+      token: EncodableString = "token",
+      modelId: EncodableString = "Qwen/Qwen2.5-72B-Instruct",
+      promptColumn: EncodableString = "prompt",
+      task: EncodableString = "text-generation",
+      systemPrompt: EncodableString = "You are a helpful assistant.",
+      maxNewTokens: Int = 256,
+      temperature: Double = 0.7,
+      resultColumn: EncodableString = "hf_response"
+  ): HuggingFaceInferenceOpDesc = {
+    val desc = new HuggingFaceInferenceOpDesc()
+    desc.hfApiToken = token
+    desc.modelId = modelId
+    desc.promptColumn = promptColumn
+    desc.task = task
+    desc.systemPrompt = systemPrompt
+    desc.maxNewTokens = maxNewTokens
+    desc.temperature = temperature
+    desc.resultColumn = resultColumn
+    desc
+  }
+
+  "HuggingFaceInferenceOpDesc.operatorInfo" should
+    "advertise the user-friendly name, HuggingFace group, and one input/output port" in {
+    val info = (new HuggingFaceInferenceOpDesc).operatorInfo
+    info.userFriendlyName shouldBe "Hugging Face"
+    info.operatorGroupName shouldBe OperatorGroupConstants.HUGGINGFACE_GROUP
+    info.inputPorts.size shouldBe 1
+    info.outputPorts.size shouldBe 1
+  }
+
+  "generatePythonCode" should
+    "fall back to the text-gen codegen on an unrecognized task (HF reports the real error at runtime)" in {
+    // generatePythonCode must be total — never throw on arbitrary @JsonProperty
+    // values — per the PythonCodeRawInvalidTextSpec contract. An unknown task
+    // routes through TextGenCodegen, whose payload `if/else` hits the generic
+    // `{"inputs": prompt_value}` branch at runtime.
+    val code = makeDesc(task = "not-a-real-task").generatePythonCode()
+    code should include("""payload = {"inputs": prompt_value}""")
+  }
+
+  it should "emit a ProcessTableOperator that initializes config in open()" in {
+    val code = makeDesc().generatePythonCode()
+    code should include("class ProcessTableOperator(UDFTableOperator):")
+    code should include("def open(self):")
+    // User-input strings are decoded at runtime, not embedded as literals.
+    code should include("self.HF_API_TOKEN = self.decode_python_template(")
+    code should include("self.MODEL_ID = self.decode_python_template(")
+    code should include("self.PROMPT_COLUMN = self.decode_python_template(")
+    code should include("self.TASK = self.decode_python_template(")
+    code should include("self.SYSTEM_PROMPT = self.decode_python_template(")
+  }
+
+  it should "wire the text-gen payload and response parse correctly" in {
+    val code = makeDesc().generatePythonCode()
+    // Payload — chat-completions shape against the configured model + system prompt.
+    code should include("self.MODEL_ID")
+    code should include("self.SYSTEM_PROMPT")
+    code should include("self.MAX_NEW_TOKENS")
+    code should include("self.TEMPERATURE")
+    // Parse — text-gen pulls choices[0].message.content out of the response.
+    code should include("""body["choices"][0]["message"]["content"]""")
+  }
+
+  it should
+    "emit a runtime check that rejects malformed MODEL_ID values before any HF URL is built" in {
+    val code = makeDesc().generatePythonCode()
+    // Pattern that fences MODEL_ID to org/model-name (allowing org/model-name/revision).
+    code should include("_HF_MODEL_ID_PATTERN = re.compile(")
+    // Runtime fail-fast inside process_table — happens before _resolve_providers
+    // composes the URL, so a malformed value never escapes into a request.
+    code should include("if not _HF_MODEL_ID_PATTERN.match(")
+    code should include("raise ValueError(")
+    code should include("Invalid Hugging Face model ID")
+  }
+
+  it should "not leak raw user-input strings into the generated Python source" in {
+    // Sentinel value chosen to be distinctive and non-overlapping with anything
+    // else in the template. If our encoding regressed back to raw literals
+    // (e.g. `MODEL_ID = "MARKER_zXyq42"`), this assertion would fail.
+    val marker = "MARKER_zXyq42"
+    val code =
+      makeDesc(modelId = marker, promptColumn = marker, token = marker).generatePythonCode()
+    code should not include marker
+  }
+
+  it should "clamp maxNewTokens into the 1-4096 range" in {
+    makeDesc(maxNewTokens = -5).generatePythonCode() should include(
+      "self.MAX_NEW_TOKENS = 1"
+    )
+    makeDesc(maxNewTokens = 99999).generatePythonCode() should include(
+      "self.MAX_NEW_TOKENS = 4096"
+    )
+  }
+
+  it should "clamp temperature into the 0.0-2.0 range" in {
+    makeDesc(temperature = -1.0).generatePythonCode() should include(
+      "self.TEMPERATURE = 0.0"
+    )
+    makeDesc(temperature = 5.0).generatePythonCode() should include(
+      "self.TEMPERATURE = 2.0"
+    )
+  }
+
+  it should "tolerate null @JsonProperty values and fall back to safe defaults" in {
+    // Every user-input field can land as null when the JSON deserializer is
+    // handed a workflow that omits the field. generatePythonCode must not
+    // throw on any combination — and the generated Python must still parse.
+    val desc = new HuggingFaceInferenceOpDesc()
+    desc.hfApiToken = null
+    desc.modelId = null
+    desc.promptColumn = null
+    desc.systemPrompt = null
+    desc.resultColumn = null
+    desc.task = null
+    desc.maxNewTokens = null
+    desc.temperature = null
+    val code = desc.generatePythonCode()
+    code should include("class ProcessTableOperator(UDFTableOperator):")
+    code should include("def open(self):")
+    // System-prompt default is the empty-string sentinel (no fallback string
+    // injected) but the operator class still initializes the constant.
+    code should include("self.SYSTEM_PROMPT = ")
+    // maxNewTokens null path defaults to 256.
+    code should include("self.MAX_NEW_TOKENS = 256")
+    // temperature null path defaults to 0.7.
+    code should include("self.TEMPERATURE = 0.7")
+  }
+
+  "TextGenCodegen" should "advertise text-generation as its canonical task" in {
+    TextGenCodegen.task shouldBe "text-generation"
+  }
+
+  it should
+    "emit payload and parse snippets that don't depend on the CodegenContext" in {
+    // For text-generation, the codegen's only inputs to Python are static
+    // strings referencing self.* attributes — exercising both methods
+    // confirms they don't accidentally consume ctx fields (a future
+    // refactor regression would surface here).
+    val ctx = CodegenContext(
+      hfApiToken = "irrelevant",
+      modelId = "irrelevant",
+      promptColumn = "irrelevant",
+      resultColumn = "irrelevant",
+      task = "irrelevant",
+      systemPrompt = "irrelevant",
+      safeMaxTokens = 0,
+      safeTemp = 0.0
+    )
+    TextGenCodegen.payloadPython(ctx) should include("self.MODEL_ID")
+    TextGenCodegen.parsePython(ctx) should include("""body["choices"][0]["message"]["content"]""")
+  }
+
+  "getOutputSchemas" should "add the result column as a STRING to the inherited schema" in {
+    val desc = makeDesc(resultColumn = "answer")
+    val inputSchema = Schema().add("prompt", AttributeType.STRING)
+    val out = desc.getOutputSchemas(Map(PortIdentity(0) -> inputSchema))
+    val outSchema = out(desc.operatorInfo.outputPorts.head.id)
+    outSchema.getAttributeNames.contains("prompt") shouldBe true
+    outSchema.getAttributeNames.contains("answer") shouldBe true
+    outSchema.getAttribute("answer").getType shouldBe AttributeType.STRING
+  }
+
+  it should "fall back to the default 'hf_response' name when resultColumn is empty" in {
+    val desc = makeDesc(resultColumn = "")
+    val inputSchema = Schema().add("prompt", AttributeType.STRING)
+    val out = desc.getOutputSchemas(Map(PortIdentity(0) -> inputSchema))
+    val outSchema = out(desc.operatorInfo.outputPorts.head.id)
+    outSchema.getAttributeNames.contains("hf_response") shouldBe true
+  }
+}
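
For readers who want to see the behaviour the spec above asserts without running the Scala codegen, here is a minimal standalone Python sketch. It is not the generated operator: the function names (`safe_max_tokens`, `safe_temperature`, `build_payload`, `parse_text_generation`) are hypothetical, and the clamping is written in Python only for illustration, since the spec suggests the real clamping happens before the Python source is emitted. The bounds (1 to 4096, 0.0 to 2.0), the null defaults (256, 0.7), the payload shape, and the parse path are taken from the diff.

```python
from typing import Any, Dict, Optional


def safe_max_tokens(value: Optional[int]) -> int:
    """Hypothetical helper: default to 256 when unset, else clamp to 1..4096."""
    if value is None:
        return 256
    return max(1, min(4096, value))


def safe_temperature(value: Optional[float]) -> float:
    """Hypothetical helper: default to 0.7 when unset, else clamp to 0.0..2.0."""
    if value is None:
        return 0.7
    return max(0.0, min(2.0, value))


def build_payload(
    task: str,
    model_id: str,
    system_prompt: str,
    prompt_value: str,
    max_new_tokens: Optional[int],
    temperature: Optional[float],
) -> Dict[str, Any]:
    """Mirror of the payload snippet: chat-completions shape for
    text-generation, generic {"inputs": ...} for any other task string."""
    if task == "text-generation":
        return {
            "model": model_id,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt_value},
            ],
            "max_tokens": safe_max_tokens(max_new_tokens),
            "temperature": safe_temperature(temperature),
        }
    return {"inputs": prompt_value}


def parse_text_generation(body: Dict[str, Any]) -> str:
    """Mirror of the parse snippet: pull the first choice's message content."""
    return body["choices"][0]["message"]["content"]
```

In the generated operator these values live on `self` (`self.MODEL_ID`, `self.SYSTEM_PROMPT`, and so on) rather than being passed as arguments; the sketch uses plain parameters only so it can run on its own.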
