This is an automated email from the ASF dual-hosted git repository.
linxinyuan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/texera.git
The following commit(s) were added to refs/heads/main by this push:
new 3e92803d84 feat: introduce sklearn testing operator (#4211)
3e92803d84 is described below
commit 3e92803d8499eec427ceb3085e9cf91fcd87e6a4
Author: Xinyuan Lin <[email protected]>
AuthorDate: Sat Feb 14 00:27:54 2026 -0800
feat: introduce sklearn testing operator (#4211)
### What changes were proposed in this PR?
Introduce the Sklearn Testing operator, which accepts any number of
machine learning models (from 1 to n) and, for each model, computes
`accuracy`, `F1` score, `precision`, and `recall` (or `R2`, `RMSE`, and
`MAE` when the Regression option is enabled), appending these metrics as
new columns in the output.
The model port depends on the data port: the data port must finish
first. The collected data table is then kept as internal state for
testing. The operator can accept any number of models; each model is
tested against the same data table.
**Input single model:**
<img width="797" alt="Screenshot 2025-07-04 at 22 54 33"
src="https://github.com/user-attachments/assets/d14326f2-4f5f-4476-9eda-eb464ea8049c"
/>
**Input multiple models:**
<img width="822" alt="Screenshot 2025-07-04 at 22 47 09"
src="https://github.com/user-attachments/assets/4333ca08-3717-407a-b978-05995197f8c8"
/>
### Was this PR authored or co-authored using generative AI tooling?
No
---
.../input_port_materialization_reader_runnable.py | 3 +-
.../apache/texera/amber/operator/LogicalOp.scala | 4 +-
.../sklearn/testing/SklearnTestingOpDesc.scala | 115 +++++++++++++++++++++
.../src/assets/operator_images/SklearnTesting.png | Bin 0 -> 843070 bytes
4 files changed, 120 insertions(+), 2 deletions(-)
diff --git
a/amber/src/main/python/core/storage/runnables/input_port_materialization_reader_runnable.py
b/amber/src/main/python/core/storage/runnables/input_port_materialization_reader_runnable.py
index c82926a60a..e49c0316cc 100644
---
a/amber/src/main/python/core/storage/runnables/input_port_materialization_reader_runnable.py
+++
b/amber/src/main/python/core/storage/runnables/input_port_materialization_reader_runnable.py
@@ -17,7 +17,7 @@
import typing
from loguru import logger
-from pyarrow.lib import Table
+from pyarrow import Table
from typing import Union
from core.architecture.sendsemantics.broad_cast_partitioner import (
@@ -146,6 +146,7 @@ class InputPortMaterializationReaderRunnable(Runnable,
Stoppable):
break
# Each tuple is sent to the partitioner and converted to
# a batch-based iterator.
+ tup.cast_to_schema(self.tuple_schema)
for data_frame in self.tuple_to_batch_with_filter(tup):
self.emit_payload(data_frame)
self.emit_ecm("EndChannel",
EmbeddedControlMessageType.PORT_ALIGNMENT)
diff --git
a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/LogicalOp.scala
b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/LogicalOp.scala
index eb319a82d1..a575d5b018 100644
---
a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/LogicalOp.scala
+++
b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/LogicalOp.scala
@@ -137,6 +137,7 @@ import
org.apache.texera.amber.operator.visualization.volcanoPlot.VolcanoPlotOpD
import
org.apache.texera.amber.operator.visualization.waterfallChart.WaterfallChartOpDesc
import org.apache.texera.amber.operator.visualization.wordCloud.WordCloudOpDesc
import org.apache.commons.lang3.builder.{EqualsBuilder, HashCodeBuilder,
ToStringBuilder}
+import org.apache.texera.amber.operator.sklearn.testing.SklearnTestingOpDesc
import
org.apache.texera.amber.operator.visualization.stripChart.StripChartOpDesc
import java.util.UUID
@@ -407,7 +408,8 @@ trait StateTransferFunc
new Type(
value = classOf[SklearnAdvancedSVRTrainerOpDesc],
name = "SVRTrainer"
- )
+ ),
+ new Type(value = classOf[SklearnTestingOpDesc], name = "SklearnTesting")
)
)
abstract class LogicalOp extends PortDescriptor with Serializable {
diff --git
a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/sklearn/testing/SklearnTestingOpDesc.scala
b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/sklearn/testing/SklearnTestingOpDesc.scala
new file mode 100644
index 0000000000..4c7af2db98
--- /dev/null
+++
b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/sklearn/testing/SklearnTestingOpDesc.scala
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.texera.amber.operator.sklearn.testing
+
+import com.fasterxml.jackson.annotation.{JsonProperty, JsonPropertyDescription}
+import com.kjetland.jackson.jsonSchema.annotations.JsonSchemaTitle
+import org.apache.texera.amber.core.tuple.{AttributeType, Schema}
+import org.apache.texera.amber.core.workflow.{InputPort, OutputPort,
PortIdentity}
+import org.apache.texera.amber.operator.PythonOperatorDescriptor
+import org.apache.texera.amber.operator.metadata.annotations.{
+ AutofillAttributeName,
+ AutofillAttributeNameOnPort1
+}
+import org.apache.texera.amber.operator.metadata.{OperatorGroupConstants,
OperatorInfo}
+import org.apache.texera.amber.pybuilder.PyStringTypes.EncodableString
+import
org.apache.texera.amber.pybuilder.PythonTemplateBuilder.PythonTemplateBuilderStringContext
+
+class SklearnTestingOpDesc extends PythonOperatorDescriptor {
+ @JsonProperty(required = true, defaultValue = "false")
+ @JsonSchemaTitle("Regression")
+ @JsonPropertyDescription(
+ "Choose to solve a regression task"
+ )
+ var isRegression: Boolean = false
+
+ @JsonSchemaTitle("Model Attribute")
+ @JsonProperty(required = true, defaultValue = "model")
+ @JsonPropertyDescription("Attribute corresponding to ML model")
+ @AutofillAttributeName
+ var model: EncodableString = _
+
+ @JsonSchemaTitle("Target Attribute")
+ @JsonPropertyDescription("Attribute in your dataset corresponding to
target.")
+ @JsonProperty(required = true)
+ @AutofillAttributeNameOnPort1
+ var target: EncodableString = _
+
+ override def generatePythonCode(): String = {
+ val isRegressionStr = if (isRegression) "True" else "False"
+ pyb"""from pytexera import *
+ |from sklearn.metrics import accuracy_score, f1_score,
precision_score, recall_score, root_mean_squared_error, mean_absolute_error,
r2_score
+ |class ProcessTupleOperator(UDFOperatorV2):
+ | @overrides
+ | def open(self) -> None:
+ | self.data = []
+ | @overrides
+ | def process_tuple(self, tuple_: Tuple, port: int) ->
Iterator[Optional[TupleLike]]:
+ | if port == 1:
+ | self.data.append(tuple_)
+ | else:
+ | model = tuple_[$model]
+ | table = Table(self.data)
+ | Y = table[$target]
+ | X = table.drop($target, axis=1)
+ | predictions = model.predict(X)
+ | if $isRegressionStr:
+ | tuple_["R2"] = r2_score(Y, predictions)
+ | tuple_["RMSE"] = root_mean_squared_error(Y,
predictions)
+ | tuple_["MAE"] = mean_absolute_error(Y, predictions)
+ | else:
+ | tuple_["accuracy"] = round(accuracy_score(Y,
predictions), 4)
+ | tuple_["f1"] = f1_score(Y, predictions,
average="weighted")
+ | tuple_["precision"] = precision_score(Y,
predictions, average="weighted")
+ | tuple_["recall"] = recall_score(Y, predictions,
average="weighted")
+ | yield tuple_""".encode
+ }
+
+ override def operatorInfo: OperatorInfo =
+ OperatorInfo(
+ "Sklearn Testing",
+ "It will generate scorers for Sklearn model",
+ OperatorGroupConstants.SKLEARN_GROUP,
+ inputPorts = List(
+ InputPort(
+ PortIdentity(),
+ "model",
+ dependencies = List(PortIdentity(1)),
+ allowMultiLinks = true
+ ),
+ InputPort(PortIdentity(1), "data")
+ ),
+ outputPorts = List(OutputPort())
+ )
+
+ override def getOutputSchemas(
+ inputSchemas: Map[PortIdentity, Schema]
+ ): Map[PortIdentity, Schema] =
+ Map(
+ operatorInfo.outputPorts.head.id ->
+ (if (!isRegression)
+ Seq("accuracy", "f1", "precision", "recall")
+ else
+ Seq("R2", "RMSE", "MAE"))
+ .foldLeft(inputSchemas(operatorInfo.inputPorts.head.id))(
+ _.add(_, AttributeType.DOUBLE)
+ )
+ )
+}
diff --git a/frontend/src/assets/operator_images/SklearnTesting.png
b/frontend/src/assets/operator_images/SklearnTesting.png
new file mode 100644
index 0000000000..b90f8853fb
Binary files /dev/null and
b/frontend/src/assets/operator_images/SklearnTesting.png differ