This is an automated email from the ASF dual-hosted git repository. github-merge-queue[bot] pushed a commit to branch gh-readonly-queue/main/pr-5898-aec31f32bfc75b025190d84c2554e819273244b1 in repository https://gitbox.apache.org/repos/asf/texera.git
commit ad4f56fd32fa383b299098a7a963a09d52125f6c Author: Xinyuan Lin <[email protected]> AuthorDate: Tue Jun 23 04:29:04 2026 -0700 test(workflow-operator): add unit test coverage for Hugging Face operator descriptors (#5898) ### What changes were proposed in this PR? Pin behavior of four previously-untested Hugging Face operator descriptors in `common/workflow-operator`. No production-code changes. | Spec | Source class | Tests | | --- | --- | --- | | `HuggingFaceSentimentAnalysisOpDescSpec` | `HuggingFaceSentimentAnalysisOpDesc` | 7 | | `HuggingFaceSpamSMSDetectionOpDescSpec` | `HuggingFaceSpamSMSDetectionOpDesc` | 6 | | `HuggingFaceTextSummarizationOpDescSpec` | `HuggingFaceTextSummarizationOpDesc` | 7 | | `HuggingFaceIrisLogisticRegressionOpDescSpec` | `HuggingFaceIrisLogisticRegressionOpDesc` | 7 | **Behavior pinned** | Surface | Contract | | --- | --- | | `operatorInfo` | exact name + description; Hugging Face group; 1-in/1-out (Sentiment also `supportReconfiguration`) | | field defaults | all column fields default to `null` | | `getOutputSchemas` | appended result-column types (Sentiment 3×DOUBLE; Spam BOOLEAN+DOUBLE; Summarization STRING; Iris STRING+DOUBLE) keyed by the declared output port, plus the unset (`null`) result-name guards for Sentiment, Summarization, and Iris | | `generatePythonCode` | emits the model id + structural fragments and carries the configured columns | | `getPhysicalOp` | wires `OpExecWithCode` tagged `"python"` and carries the operator's port identities | | Round-trip | config fields preserved through the polymorphic `LogicalOp` base | Note: column fields are `EncodableString`, so in the emitted (encoded) code they appear as `self.decode_python_template('<base64>')` — assertions use a `carries` helper that matches only the base64 form, because the raw name can also appear in the generated code for unrelated reasons. ### Any related issues, documentation, discussions? Part of the ongoing `workflow-operator` unit-test coverage effort (follow-up to #5843, #5844). ### How was this PR tested? 
- `sbt "WorkflowOperator/testOnly *HuggingFaceSentimentAnalysisOpDescSpec *HuggingFaceSpamSMSDetectionOpDescSpec *HuggingFaceTextSummarizationOpDescSpec *HuggingFaceIrisLogisticRegressionOpDescSpec"` — 27 tests, all green - `sbt "WorkflowOperator/Test/scalafmtCheck"` and `sbt "WorkflowOperator/scalafixAll --check"` — clean - CI to confirm ### Was this PR authored or co-authored using generative AI tooling? Generated-by: Claude Code (Opus 4.8 [1M context]) --- ...ggingFaceIrisLogisticRegressionOpDescSpec.scala | 131 ++++++++++++++++++++ .../HuggingFaceSentimentAnalysisOpDescSpec.scala | 133 +++++++++++++++++++++ .../HuggingFaceSpamSMSDetectionOpDescSpec.scala | 121 +++++++++++++++++++ .../HuggingFaceTextSummarizationOpDescSpec.scala | 125 +++++++++++++++++++ 4 files changed, 510 insertions(+) diff --git a/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceIrisLogisticRegressionOpDescSpec.scala b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceIrisLogisticRegressionOpDescSpec.scala new file mode 100644 index 0000000000..1c3701b4d6 --- /dev/null +++ b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceIrisLogisticRegressionOpDescSpec.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.texera.amber.operator.huggingFace + +import org.apache.texera.amber.core.executor.OpExecWithCode +import org.apache.texera.amber.core.tuple.{AttributeType, Schema} +import org.apache.texera.amber.core.virtualidentity.{ExecutionIdentity, WorkflowIdentity} +import org.apache.texera.amber.operator.LogicalOp +import org.apache.texera.amber.operator.metadata.OperatorGroupConstants +import org.apache.texera.amber.util.JSONUtils.objectMapper +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +import java.nio.charset.StandardCharsets +import java.util.Base64 + +class HuggingFaceIrisLogisticRegressionOpDescSpec extends AnyFlatSpec with Matchers { + + private val workflowId = WorkflowIdentity(1L) + private val executionId = ExecutionIdentity(1L) + + private def b64(s: String): String = + Base64.getEncoder.encodeToString(s.getBytes(StandardCharsets.UTF_8)) + + // EncodableString fields are always base64-wrapped in .encode mode + // (self.decode_python_template('<base64>')), so assert on the base64 form only rather than + // the raw column name, which could appear in the generated Python for unrelated reasons.
+ private def carries(output: String, name: String): Boolean = + output.contains(b64(name)) + + private def configured(): HuggingFaceIrisLogisticRegressionOpDesc = { + val d = new HuggingFaceIrisLogisticRegressionOpDesc + d.petalLengthCmAttribute = "petalLength" + d.petalWidthCmAttribute = "petalWidth" + d.predictionClassName = "species" + d.predictionProbabilityName = "probability" + d + } + + "HuggingFaceIrisLogisticRegressionOpDesc.operatorInfo" should + "advertise the name, Hugging Face group, and a 1-in/1-out shape" in { + val info = (new HuggingFaceIrisLogisticRegressionOpDesc).operatorInfo + info.userFriendlyName shouldBe "Hugging Face Iris Logistic Regression" + info.operatorDescription shouldBe + "Predict whether an iris is an Iris-setosa using a pre-trained logistic regression model" + info.operatorGroupName shouldBe OperatorGroupConstants.HUGGINGFACE_GROUP + info.inputPorts should have length 1 + info.outputPorts should have length 1 + } + + "HuggingFaceIrisLogisticRegressionOpDesc" should "default all column fields to null" in { + val d = new HuggingFaceIrisLogisticRegressionOpDesc + d.petalLengthCmAttribute shouldBe null + d.petalWidthCmAttribute shouldBe null + d.predictionClassName shouldBe null + d.predictionProbabilityName shouldBe null + } + + "HuggingFaceIrisLogisticRegressionOpDesc.getOutputSchemas" should + "reject an unset (null) prediction result name" in { + val d = new HuggingFaceIrisLogisticRegressionOpDesc + val in = Schema().add("sepal", AttributeType.STRING) + val ex = intercept[RuntimeException] { + d.getOutputSchemas(Map(d.operatorInfo.inputPorts.head.id -> in)) + } + ex.getMessage shouldBe "Result attribute name should not be empty" + } + + it should "append a STRING class column and a DOUBLE probability column, keyed by the output port" in { + val d = configured() + val in = Schema().add("sepal", AttributeType.STRING) + val out = d.getOutputSchemas(Map(d.operatorInfo.inputPorts.head.id -> in)) + val schema =
out(d.operatorInfo.outputPorts.head.id) + schema.getAttribute("sepal").getType shouldBe AttributeType.STRING + schema.getAttribute("species").getType shouldBe AttributeType.STRING + schema.getAttribute("probability").getType shouldBe AttributeType.DOUBLE + } + + "HuggingFaceIrisLogisticRegressionOpDesc.generatePythonCode" should + "emit the logistic-regression operator carrying the configured columns (encoded)" in { + val d = configured() + val code = d.generatePythonCode() + code should include("class ProcessTupleOperator(UDFOperatorV2)") + code should include("LinearModel.from_pretrained") + code should include("sadhaklal/logistic-regression-iris") + code should include("self.decode_python_template(") + Seq("petalLength", "petalWidth").foreach(c => withClue(c)(carries(code, c) shouldBe true)) + Seq("species", "probability").foreach(c => withClue(c)(carries(code, c) shouldBe true)) + } + + "HuggingFaceIrisLogisticRegressionOpDesc.getPhysicalOp" should + "wire an OpExecWithCode python executor carrying the operator's ports" in { + val d = configured() + val physical = d.getPhysicalOp(workflowId, executionId) + physical.opExecInitInfo match { + case OpExecWithCode(_, language) => language shouldBe "python" + case other => fail(s"expected OpExecWithCode, got $other") + } + physical.inputPorts.keySet shouldBe d.operatorInfo.inputPorts.map(_.id).toSet + physical.outputPorts.keySet shouldBe d.operatorInfo.outputPorts.map(_.id).toSet + } + + "HuggingFaceIrisLogisticRegressionOpDesc" should + "round-trip its config fields through the polymorphic base" in { + val d = configured() + val restored = objectMapper.readValue(objectMapper.writeValueAsString(d), classOf[LogicalOp]) + restored shouldBe a[HuggingFaceIrisLogisticRegressionOpDesc] + val h = restored.asInstanceOf[HuggingFaceIrisLogisticRegressionOpDesc] + h.petalLengthCmAttribute shouldBe "petalLength" + h.petalWidthCmAttribute shouldBe "petalWidth" + h.predictionClassName shouldBe "species" + h.predictionProbabilityName shouldBe "probability" + } +} diff --git
a/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceSentimentAnalysisOpDescSpec.scala b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceSentimentAnalysisOpDescSpec.scala new file mode 100644 index 0000000000..7902c572cd --- /dev/null +++ b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceSentimentAnalysisOpDescSpec.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.texera.amber.operator.huggingFace + +import org.apache.texera.amber.core.executor.OpExecWithCode +import org.apache.texera.amber.core.tuple.{AttributeType, Schema} +import org.apache.texera.amber.core.virtualidentity.{ExecutionIdentity, WorkflowIdentity} +import org.apache.texera.amber.operator.LogicalOp +import org.apache.texera.amber.operator.metadata.OperatorGroupConstants +import org.apache.texera.amber.util.JSONUtils.objectMapper +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +import java.nio.charset.StandardCharsets +import java.util.Base64 + +class HuggingFaceSentimentAnalysisOpDescSpec extends AnyFlatSpec with Matchers { + + private val workflowId = WorkflowIdentity(1L) + private val executionId = ExecutionIdentity(1L) + + private def b64(s: String): String = + Base64.getEncoder.encodeToString(s.getBytes(StandardCharsets.UTF_8)) + + // EncodableString fields are always base64-wrapped in .encode mode + // (self.decode_python_template('<base64>')), so assert on the base64 form only: the raw + // column name can appear in the generated Python for unrelated reasons (e.g. the + // "positive"/"neutral"/"negative" label keys), which would mask a missing splice.
+ private def carries(output: String, name: String): Boolean = + output.contains(b64(name)) + + private def configured(): HuggingFaceSentimentAnalysisOpDesc = { + val d = new HuggingFaceSentimentAnalysisOpDesc + d.attribute = "text" + d.resultAttributePositive = "pos" + d.resultAttributeNeutral = "neu" + d.resultAttributeNegative = "neg" + d + } + + "HuggingFaceSentimentAnalysisOpDesc.operatorInfo" should + "advertise the name, Hugging Face group, and a 1-in/1-out reconfigurable shape" in { + val info = (new HuggingFaceSentimentAnalysisOpDesc).operatorInfo + info.userFriendlyName shouldBe "Hugging Face Sentiment Analysis" + info.operatorDescription shouldBe + "Analyzing Sentiments with a Twitter-Based Model from Hugging Face" + info.operatorGroupName shouldBe OperatorGroupConstants.HUGGINGFACE_GROUP + info.inputPorts should have length 1 + info.outputPorts should have length 1 + info.supportReconfiguration shouldBe true + } + + "HuggingFaceSentimentAnalysisOpDesc" should "default all column fields to null" in { + val d = new HuggingFaceSentimentAnalysisOpDesc + d.attribute shouldBe null + d.resultAttributePositive shouldBe null + d.resultAttributeNeutral shouldBe null + d.resultAttributeNegative shouldBe null + } + + "HuggingFaceSentimentAnalysisOpDesc.getOutputSchemas" should + "return null when the result columns are all unset" in { + val d = new HuggingFaceSentimentAnalysisOpDesc + val in = Schema().add("text", AttributeType.STRING) + d.getOutputSchemas(Map(d.operatorInfo.inputPorts.head.id -> in)) shouldBe null + } + + it should "append the three sentiment columns as DOUBLE, keyed by the declared output port" in { + val d = configured() + val in = Schema().add("text", AttributeType.STRING) + val out = d.getOutputSchemas(Map(d.operatorInfo.inputPorts.head.id -> in)) + val schema = out(d.operatorInfo.outputPorts.head.id) + schema.getAttribute("text").getType shouldBe AttributeType.STRING + schema.getAttribute("pos").getType shouldBe AttributeType.DOUBLE +
schema.getAttribute("neu").getType shouldBe AttributeType.DOUBLE + schema.getAttribute("neg").getType shouldBe AttributeType.DOUBLE + } + + "HuggingFaceSentimentAnalysisOpDesc.generatePythonCode" should + "emit the cardiffnlp sentiment operator carrying the configured columns (encoded)" in { + val d = configured() + val code = d.generatePythonCode() + code should include("class ProcessTupleOperator(UDFOperatorV2)") + code should include("cardiffnlp/twitter-roberta-base-sentiment-latest") + code should include("from scipy.special import softmax") + code should include("self.decode_python_template(") + carries(code, "text") shouldBe true + Seq("pos", "neu", "neg").foreach(c => withClue(c)(carries(code, c) shouldBe true)) + // EncodableString columns are base64-encoded, not embedded raw. + code should not include "\"text\"]" + } + + "HuggingFaceSentimentAnalysisOpDesc.getPhysicalOp" should + "wire an OpExecWithCode python executor carrying the operator's ports" in { + val d = configured() + val physical = d.getPhysicalOp(workflowId, executionId) + physical.opExecInitInfo match { + case OpExecWithCode(_, language) => language shouldBe "python" + case other => fail(s"expected OpExecWithCode, got $other") + } + physical.inputPorts.keySet shouldBe d.operatorInfo.inputPorts.map(_.id).toSet + physical.outputPorts.keySet shouldBe d.operatorInfo.outputPorts.map(_.id).toSet + } + + "HuggingFaceSentimentAnalysisOpDesc" should + "round-trip its config fields through the polymorphic base" in { + val d = configured() + val restored = objectMapper.readValue(objectMapper.writeValueAsString(d), classOf[LogicalOp]) + restored shouldBe a[HuggingFaceSentimentAnalysisOpDesc] + val h = restored.asInstanceOf[HuggingFaceSentimentAnalysisOpDesc] + h.attribute shouldBe "text" + h.resultAttributePositive shouldBe "pos" + h.resultAttributeNeutral shouldBe "neu" + h.resultAttributeNegative shouldBe "neg" + } +} diff --git
a/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceSpamSMSDetectionOpDescSpec.scala b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceSpamSMSDetectionOpDescSpec.scala new file mode 100644 index 0000000000..9da5d52782 --- /dev/null +++ b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceSpamSMSDetectionOpDescSpec.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.texera.amber.operator.huggingFace + +import org.apache.texera.amber.core.executor.OpExecWithCode +import org.apache.texera.amber.core.tuple.{AttributeType, Schema} +import org.apache.texera.amber.core.virtualidentity.{ExecutionIdentity, WorkflowIdentity} +import org.apache.texera.amber.operator.LogicalOp +import org.apache.texera.amber.operator.metadata.OperatorGroupConstants +import org.apache.texera.amber.util.JSONUtils.objectMapper +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +import java.nio.charset.StandardCharsets +import java.util.Base64 + +class HuggingFaceSpamSMSDetectionOpDescSpec extends AnyFlatSpec with Matchers { + + private val workflowId = WorkflowIdentity(1L) + private val executionId = ExecutionIdentity(1L) + + private def b64(s: String): String = + Base64.getEncoder.encodeToString(s.getBytes(StandardCharsets.UTF_8)) + + // EncodableString fields are always base64-wrapped in .encode mode + // (self.decode_python_template('<base64>')), so assert on the base64 form only: the raw + // column name can appear in the generated Python for unrelated reasons (e.g. "text" in the + // "text-classification" task string, "score" in result["score"]), masking a missing splice.
+ private def carries(output: String, name: String): Boolean = + output.contains(b64(name)) + + private def configured(): HuggingFaceSpamSMSDetectionOpDesc = { + val d = new HuggingFaceSpamSMSDetectionOpDesc + d.attribute = "text" + d.resultAttributeSpam = "is_spam" + d.resultAttributeProbability = "score" + d + } + + "HuggingFaceSpamSMSDetectionOpDesc.operatorInfo" should + "advertise the name, Hugging Face group, and a 1-in/1-out shape" in { + val info = (new HuggingFaceSpamSMSDetectionOpDesc).operatorInfo + info.userFriendlyName shouldBe "Hugging Face Spam Detection" + info.operatorDescription shouldBe "Spam Detection by SMS Spam Detection Model from Hugging Face" + info.operatorGroupName shouldBe OperatorGroupConstants.HUGGINGFACE_GROUP + info.inputPorts should have length 1 + info.outputPorts should have length 1 + } + + "HuggingFaceSpamSMSDetectionOpDesc" should "default all column fields to null" in { + val d = new HuggingFaceSpamSMSDetectionOpDesc + d.attribute shouldBe null + d.resultAttributeSpam shouldBe null + d.resultAttributeProbability shouldBe null + } + + "HuggingFaceSpamSMSDetectionOpDesc.getOutputSchemas" should + "append a BOOLEAN spam column and a DOUBLE score column, keyed by the declared output port" in { + val d = configured() + val in = Schema().add("text", AttributeType.STRING) + val out = d.getOutputSchemas(Map(d.operatorInfo.inputPorts.head.id -> in)) + val schema = out(d.operatorInfo.outputPorts.head.id) + schema.getAttribute("text").getType shouldBe AttributeType.STRING + schema.getAttribute("is_spam").getType shouldBe AttributeType.BOOLEAN + schema.getAttribute("score").getType shouldBe AttributeType.DOUBLE + } + + "HuggingFaceSpamSMSDetectionOpDesc.generatePythonCode" should + "emit the spam-detection pipeline carrying the configured columns (encoded)" in { + val d = configured() + val code = d.generatePythonCode() + code should include("from transformers import pipeline") + code should include("class
ProcessTupleOperator(UDFOperatorV2)") + code should include("mrm8488/bert-tiny-finetuned-sms-spam-detection") + code should include("result[\"label\"] == \"LABEL_1\"") + code should include("self.decode_python_template(") + carries(code, "text") shouldBe true + carries(code, "is_spam") shouldBe true + carries(code, "score") shouldBe true + } + + "HuggingFaceSpamSMSDetectionOpDesc.getPhysicalOp" should + "wire an OpExecWithCode python executor carrying the operator's ports" in { + val d = configured() + val physical = d.getPhysicalOp(workflowId, executionId) + physical.opExecInitInfo match { + case OpExecWithCode(_, language) => language shouldBe "python" + case other => fail(s"expected OpExecWithCode, got $other") + } + physical.inputPorts.keySet shouldBe d.operatorInfo.inputPorts.map(_.id).toSet + physical.outputPorts.keySet shouldBe d.operatorInfo.outputPorts.map(_.id).toSet + } + + "HuggingFaceSpamSMSDetectionOpDesc" should + "round-trip its config fields through the polymorphic base" in { + val d = configured() + val restored = objectMapper.readValue(objectMapper.writeValueAsString(d), classOf[LogicalOp]) + restored shouldBe a[HuggingFaceSpamSMSDetectionOpDesc] + val h = restored.asInstanceOf[HuggingFaceSpamSMSDetectionOpDesc] + h.attribute shouldBe "text" + h.resultAttributeSpam shouldBe "is_spam" + h.resultAttributeProbability shouldBe "score" + } +} diff --git a/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceTextSummarizationOpDescSpec.scala b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceTextSummarizationOpDescSpec.scala new file mode 100644 index 0000000000..7e62242372 --- /dev/null +++ b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/huggingFace/HuggingFaceTextSummarizationOpDescSpec.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.texera.amber.operator.huggingFace + +import org.apache.texera.amber.core.executor.OpExecWithCode +import org.apache.texera.amber.core.tuple.{AttributeType, Schema} +import org.apache.texera.amber.core.virtualidentity.{ExecutionIdentity, WorkflowIdentity} +import org.apache.texera.amber.operator.LogicalOp +import org.apache.texera.amber.operator.metadata.OperatorGroupConstants +import org.apache.texera.amber.util.JSONUtils.objectMapper +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +import java.nio.charset.StandardCharsets +import java.util.Base64 + +class HuggingFaceTextSummarizationOpDescSpec extends AnyFlatSpec with Matchers { + + private val workflowId = WorkflowIdentity(1L) + private val executionId = ExecutionIdentity(1L) + + private def b64(s: String): String = + Base64.getEncoder.encodeToString(s.getBytes(StandardCharsets.UTF_8)) + + // EncodableString fields are always base64-wrapped in .encode mode + // (self.decode_python_template('<base64>')), so assert on the base64 form only: the raw + // column name can appear in the generated Python for unrelated reasons (e.g. the local + // variables "text"/"summary"), which would mask a missing splice.
+ private def carries(output: String, name: String): Boolean = + output.contains(b64(name)) + + private def configured(): HuggingFaceTextSummarizationOpDesc = { + val d = new HuggingFaceTextSummarizationOpDesc + d.attribute = "text" + d.resultAttribute = "summary" + d + } + + "HuggingFaceTextSummarizationOpDesc.operatorInfo" should + "advertise the name, Hugging Face group, and a 1-in/1-out shape" in { + val info = (new HuggingFaceTextSummarizationOpDesc).operatorInfo + info.userFriendlyName shouldBe "Hugging Face Text Summarization" + info.operatorDescription shouldBe + "Summarize the given text content with a mini2bert pre-trained model from Hugging Face" + info.operatorGroupName shouldBe OperatorGroupConstants.HUGGINGFACE_GROUP + info.inputPorts should have length 1 + info.outputPorts should have length 1 + } + + "HuggingFaceTextSummarizationOpDesc" should "default attribute and resultAttribute to null" in { + val d = new HuggingFaceTextSummarizationOpDesc + d.attribute shouldBe null + d.resultAttribute shouldBe null + } + + "HuggingFaceTextSummarizationOpDesc.getOutputSchemas" should + "reject an unset (null) result attribute name" in { + val d = new HuggingFaceTextSummarizationOpDesc + val in = Schema().add("text", AttributeType.STRING) + val ex = intercept[RuntimeException] { + d.getOutputSchemas(Map(d.operatorInfo.inputPorts.head.id -> in)) + } + ex.getMessage shouldBe "Result attribute name should be given" + } + + it should "append the summary column as STRING, keyed by the declared output port" in { + val d = configured() + val in = Schema().add("text", AttributeType.STRING) + val out = d.getOutputSchemas(Map(d.operatorInfo.inputPorts.head.id -> in)) + val schema = out(d.operatorInfo.outputPorts.head.id) + schema.getAttribute("text").getType shouldBe AttributeType.STRING + schema.getAttribute("summary").getType shouldBe AttributeType.STRING + } + + "HuggingFaceTextSummarizationOpDesc.generatePythonCode" should + "emit the bert-mini2bert summarizer carrying the
configured columns (encoded)" in { + val d = configured() + val code = d.generatePythonCode() + code should include("class ProcessTupleOperator(UDFOperatorV2)") + code should include("from transformers import BertTokenizerFast, EncoderDecoderModel") + code should include("mrm8488/bert-mini2bert-mini-finetuned-cnn_daily_mail-summarization") + code should include("self.decode_python_template(") + carries(code, "text") shouldBe true + carries(code, "summary") shouldBe true + } + + "HuggingFaceTextSummarizationOpDesc.getPhysicalOp" should + "wire an OpExecWithCode python executor carrying the operator's ports" in { + val d = configured() + val physical = d.getPhysicalOp(workflowId, executionId) + physical.opExecInitInfo match { + case OpExecWithCode(_, language) => language shouldBe "python" + case other => fail(s"expected OpExecWithCode, got $other") + } + physical.inputPorts.keySet shouldBe d.operatorInfo.inputPorts.map(_.id).toSet + physical.outputPorts.keySet shouldBe d.operatorInfo.outputPorts.map(_.id).toSet + } + + "HuggingFaceTextSummarizationOpDesc" should + "round-trip its config fields through the polymorphic base" in { + val d = configured() + val restored = objectMapper.readValue(objectMapper.writeValueAsString(d), classOf[LogicalOp]) + restored shouldBe a[HuggingFaceTextSummarizationOpDesc] + val h = restored.asInstanceOf[HuggingFaceTextSummarizationOpDesc] + h.attribute shouldBe "text" + h.resultAttribute shouldBe "summary" + } +}
