This is an automated email from the ASF dual-hosted git repository.

cdionysio pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new c236a25078 [SYSTEMDS-3835] Improve memory efficiency of text context 
operations
c236a25078 is described below

commit c236a25078947137a4882a5c8868455e51b9894d
Author: Christina Dionysio <[email protected]>
AuthorDate: Wed Jan 28 12:35:43 2026 +0100

    [SYSTEMDS-3835] Improve memory efficiency of text context operations
    
    This patch uses an index-based method to split text into multiple chunks
    and stores a list of the start and end indices for each chunk in the
    instance metadata.
---
 .../systemds/scuro/drsearch/operator_registry.py   |   2 +
 .../systemds/scuro/drsearch/unimodal_optimizer.py  | 107 ++++++++++++---------
 .../python/systemds/scuro/modality/transformed.py  |   9 +-
 src/main/python/systemds/scuro/modality/type.py    |  19 ++++
 .../systemds/scuro/modality/unimodal_modality.py   |  24 ++++-
 .../representations/aggregated_representation.py   |   2 +-
 .../python/systemds/scuro/representations/bert.py  |  35 ++-----
 .../python/systemds/scuro/representations/elmo.py  |  26 +----
 .../systemds/scuro/representations/text_context.py |   4 +-
 .../representations/text_context_with_indices.py   |  26 +++--
 .../python/systemds/scuro/utils/torch_dataset.py   |  42 ++++++++
 .../python/tests/scuro/test_operator_registry.py   |  10 +-
 .../tests/scuro/test_text_context_operators.py     |  32 +++---
 .../python/tests/scuro/test_unimodal_optimizer.py  |   5 +-
 14 files changed, 199 insertions(+), 144 deletions(-)

diff --git a/src/main/python/systemds/scuro/drsearch/operator_registry.py 
b/src/main/python/systemds/scuro/drsearch/operator_registry.py
index bf9547ddbf..e9c302ba90 100644
--- a/src/main/python/systemds/scuro/drsearch/operator_registry.py
+++ b/src/main/python/systemds/scuro/drsearch/operator_registry.py
@@ -97,6 +97,8 @@ class Registry:
         return reps
 
     def get_context_operators(self, modality_type):
+        if modality_type not in self._context_operators.keys():
+            return []
         return self._context_operators[modality_type]
 
     def get_dimensionality_reduction_operators(self, modality_type):
diff --git a/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py 
b/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py
index c555c2b677..5b03147ec1 100644
--- a/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py
+++ b/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py
@@ -356,7 +356,8 @@ class UnimodalOptimizer:
             operator.__class__, [leaf_id], operator.get_current_parameters()
         )
         current_node_id = rep_node_id
-        dags.append(builder.build(current_node_id))
+        rep_dag = builder.build(current_node_id)
+        dags.append(rep_dag)
 
         dimensionality_reduction_dags = 
self.add_dimensionality_reduction_operators(
             builder, current_node_id
@@ -387,11 +388,6 @@ class UnimodalOptimizer:
                     [context_node_id],
                     operator.get_current_parameters(),
                 )
-                dimensionality_reduction_dags = 
self.add_dimensionality_reduction_operators(
-                    builder, context_rep_node_id
-                )  # TODO: check if this is correctly using the 3d approach of 
the dimensionality reduction operator
-                if dimensionality_reduction_dags is not None:
-                    dags.extend(dimensionality_reduction_dags)
 
                 agg_operator = AggregatedRepresentation()
                 context_agg_node_id = builder.create_operation_node(
@@ -409,64 +405,88 @@ class UnimodalOptimizer:
             not_self_contained_reps = [
                 rep for rep in not_self_contained_reps if rep != 
operator.__class__
             ]
+            rep_id = current_node_id
 
-            for combination in self._combination_operators:
-                current_node_id = rep_node_id
-                for other_rep in not_self_contained_reps:
-                    other_rep_id = builder.create_operation_node(
-                        other_rep, [leaf_id], other_rep().parameters
-                    )
-
+            for rep in not_self_contained_reps:
+                other_rep_id = builder.create_operation_node(
+                    rep, [leaf_id], rep().parameters
+                )
+                for combination in self._combination_operators:
                     combine_id = builder.create_operation_node(
                         combination.__class__,
-                        [current_node_id, other_rep_id],
+                        [rep_id, other_rep_id],
                         combination.get_current_parameters(),
                     )
-                    dags.append(builder.build(combine_id))
-                    current_node_id = combine_id
-                if modality.modality_type in [
-                    ModalityType.EMBEDDING,
-                    ModalityType.IMAGE,
-                    ModalityType.AUDIO,
-                ]:
-                    dags.extend(
-                        self.default_context_operators(
-                            modality, builder, leaf_id, current_node_id
+                    rep_dag = builder.build(combine_id)
+                    dags.append(rep_dag)
+                    if modality.modality_type in [
+                        ModalityType.EMBEDDING,
+                        ModalityType.IMAGE,
+                        ModalityType.AUDIO,
+                    ]:
+                        dags.extend(
+                            self.default_context_operators(
+                                modality, builder, leaf_id, rep_dag, False
+                            )
                         )
-                    )
-                elif modality.modality_type == ModalityType.TIMESERIES:
-                    dags.extend(
-                        self.temporal_context_operators(
-                            modality, builder, leaf_id, current_node_id
+                    elif modality.modality_type == ModalityType.TIMESERIES:
+                        dags.extend(
+                            self.temporal_context_operators(
+                                modality,
+                                builder,
+                                leaf_id,
+                            )
                         )
-                    )
+                rep_id = combine_id
+
+        if rep_dag.nodes[-1].operation().output_modality_type in [
+            ModalityType.EMBEDDING
+        ]:
+            dags.extend(
+                self.default_context_operators(
+                    modality, builder, leaf_id, rep_dag, True
+                )
+            )
+
+        if (
+            modality.modality_type == ModalityType.TIMESERIES
+            or modality.modality_type == ModalityType.AUDIO
+        ):
+            dags.extend(self.temporal_context_operators(modality, builder, 
leaf_id))
         return dags
 
-    def default_context_operators(self, modality, builder, leaf_id, 
current_node_id):
+    def default_context_operators(
+        self, modality, builder, leaf_id, rep_dag, apply_context_to_leaf=False
+    ):
         dags = []
-        context_operators = self._get_context_operators(modality.modality_type)
-        for context_op in context_operators:
+        if apply_context_to_leaf:
             if (
                 modality.modality_type != ModalityType.TEXT
                 and modality.modality_type != ModalityType.VIDEO
             ):
-                context_node_id = builder.create_operation_node(
-                    context_op,
-                    [leaf_id],
-                    context_op().get_current_parameters(),
-                )
-                dags.append(builder.build(context_node_id))
+                context_operators = 
self._get_context_operators(modality.modality_type)
+                for context_op in context_operators:
+                    context_node_id = builder.create_operation_node(
+                        context_op,
+                        [leaf_id],
+                        context_op().get_current_parameters(),
+                    )
+                    dags.append(builder.build(context_node_id))
 
+        context_operators = self._get_context_operators(
+            rep_dag.nodes[-1].operation().output_modality_type
+        )
+        for context_op in context_operators:
             context_node_id = builder.create_operation_node(
                 context_op,
-                [current_node_id],
+                [rep_dag.nodes[-1].node_id],
                 context_op().get_current_parameters(),
             )
             dags.append(builder.build(context_node_id))
 
         return dags
 
-    def temporal_context_operators(self, modality, builder, leaf_id, 
current_node_id):
+    def temporal_context_operators(self, modality, builder, leaf_id):
         aggregators = 
self.operator_registry.get_representations(modality.modality_type)
         context_operators = self._get_context_operators(modality.modality_type)
 
@@ -561,12 +581,11 @@ class UnimodalResults:
 
         results = results[: self.k]
         sorted_indices = sorted_indices[: self.k]
-
         task_cache = self.cache.get(modality.modality_id, 
{}).get(task.model.name, None)
         if not task_cache:
             cache = [
-                list(task_results[i].dag.execute([modality]).values())[-1]
-                for i in sorted_indices
+                list(results[i].dag.execute([modality]).values())[-1]
+                for i in range(len(results))
             ]
         elif isinstance(task_cache, list):
             cache = task_cache
diff --git a/src/main/python/systemds/scuro/modality/transformed.py 
b/src/main/python/systemds/scuro/modality/transformed.py
index 078b65f0bc..a443f5a313 100644
--- a/src/main/python/systemds/scuro/modality/transformed.py
+++ b/src/main/python/systemds/scuro/modality/transformed.py
@@ -31,7 +31,12 @@ import copy
 class TransformedModality(Modality):
 
     def __init__(
-        self, modality, transformation, new_modality_type=None, 
self_contained=True
+        self,
+        modality,
+        transformation,
+        new_modality_type=None,
+        self_contained=True,
+        set_data=False,
     ):
         """
         Parent class of the different Modalities (unimodal & multimodal)
@@ -49,6 +54,8 @@ class TransformedModality(Modality):
             modality.data_type,
             modality.transform_time,
         )
+        if set_data:
+            self.data = modality.data
         self.transformation = None
         self.self_contained = (
             self_contained and transformation.self_contained
diff --git a/src/main/python/systemds/scuro/modality/type.py 
b/src/main/python/systemds/scuro/modality/type.py
index 23d97e869b..85f4d04e9b 100644
--- a/src/main/python/systemds/scuro/modality/type.py
+++ b/src/main/python/systemds/scuro/modality/type.py
@@ -210,6 +210,25 @@ class ModalityType(Flag):
     def get_schema(self):
         return ModalitySchemas.get(self.name)
 
+    def has_field(self, md, field):
+        for value in md.values():
+            if field in value:
+                return True
+            else:
+                return False
+        return False
+
+    def get_field_for_instances(self, md, field):
+        data = []
+        for items in md.values():
+            data.append(self.get_field(items, field))
+        return data
+
+    def get_field(self, md, field):
+        if field in md:
+            return md[field]
+        return None
+
     def update_metadata(self, md, data):
         return ModalitySchemas.update_metadata(self.name, md, data)
 
diff --git a/src/main/python/systemds/scuro/modality/unimodal_modality.py 
b/src/main/python/systemds/scuro/modality/unimodal_modality.py
index 4efaa7d733..89d95810e0 100644
--- a/src/main/python/systemds/scuro/modality/unimodal_modality.py
+++ b/src/main/python/systemds/scuro/modality/unimodal_modality.py
@@ -91,9 +91,14 @@ class UnimodalModality(Modality):
         if not self.has_data():
             self.extract_raw_data()
 
-        transformed_modality = TransformedModality(self, context_operator)
-
-        transformed_modality.data = context_operator.execute(self)
+        transformed_modality = TransformedModality(
+            self, context_operator, set_data=True
+        )
+        d = context_operator.execute(transformed_modality)
+        if d is not None:
+            transformed_modality.data = d
+        else:
+            transformed_modality.data = self.data
         transformed_modality.transform_time += time.time() - start
         return transformed_modality
 
@@ -212,14 +217,23 @@ class UnimodalModality(Modality):
                                 mode="constant",
                                 constant_values=0,
                             )
-                        else:
+                        elif len(embeddings.shape) == 2:
                             padded = np.pad(
                                 embeddings,
                                 ((0, padding_needed), (0, 0)),
                                 mode="constant",
                                 constant_values=0,
                             )
-                        padded_embeddings.append(padded)
+                        elif len(embeddings.shape) == 3:
+                            padded = np.pad(
+                                embeddings,
+                                ((0, padding_needed), (0, 0), (0, 0)),
+                                mode="constant",
+                                constant_values=0,
+                            )
+                            padded_embeddings.append(padded)
+                        else:
+                            raise ValueError(f"Unsupported shape: 
{embeddings.shape}")
                 else:
                     padded_embeddings.append(embeddings)
 
diff --git 
a/src/main/python/systemds/scuro/representations/aggregated_representation.py 
b/src/main/python/systemds/scuro/representations/aggregated_representation.py
index bcc36f4621..cad1a4a448 100644
--- 
a/src/main/python/systemds/scuro/representations/aggregated_representation.py
+++ 
b/src/main/python/systemds/scuro/representations/aggregated_representation.py
@@ -38,7 +38,7 @@ class AggregatedRepresentation(Representation):
         aggregated_modality = TransformedModality(
             modality, self, self_contained=modality.self_contained
         )
+        aggregated_modality.data = self.aggregation.execute(modality)
         end = time.perf_counter()
         aggregated_modality.transform_time += end - start
-        aggregated_modality.data = self.aggregation.execute(modality)
         return aggregated_modality
diff --git a/src/main/python/systemds/scuro/representations/bert.py 
b/src/main/python/systemds/scuro/representations/bert.py
index be579c0dd6..6f4d3705a1 100644
--- a/src/main/python/systemds/scuro/representations/bert.py
+++ b/src/main/python/systemds/scuro/representations/bert.py
@@ -28,35 +28,12 @@ from systemds.scuro.modality.type import ModalityType
 from systemds.scuro.drsearch.operator_registry import register_representation
 from systemds.scuro.utils.static_variables import get_device
 import os
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import DataLoader
+from systemds.scuro.utils.torch_dataset import TextDataset, TextSpanDataset
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
-class TextDataset(Dataset):
-    def __init__(self, texts):
-
-        self.texts = []
-        if isinstance(texts, list):
-            self.texts = texts
-        else:
-            for text in texts:
-                if text is None:
-                    self.texts.append("")
-                elif isinstance(text, np.ndarray):
-                    self.texts.append(str(text.item()) if text.size == 1 else 
str(text))
-                elif not isinstance(text, str):
-                    self.texts.append(str(text))
-                else:
-                    self.texts.append(text)
-
-    def __len__(self):
-        return len(self.texts)
-
-    def __getitem__(self, idx):
-        return self.texts[idx]
-
-
 class BertFamily(UnimodalRepresentation):
     def __init__(
         self,
@@ -96,10 +73,12 @@ class BertFamily(UnimodalRepresentation):
                     layer.register_forward_hook(get_activation(name))
                     break
 
-        if isinstance(modality.data[0], list):
+        if ModalityType.TEXT.has_field(modality.metadata, "text_spans"):
+            dataset = TextSpanDataset(modality.data, modality.metadata)
             embeddings = []
-            for d in modality.data:
-                embeddings.append(self.create_embeddings(d, self.model, 
tokenizer))
+            for text in dataset:
+                embedding = self.create_embeddings(text, self.model, tokenizer)
+                embeddings.append(embedding)
         else:
             embeddings = self.create_embeddings(modality.data, self.model, 
tokenizer)
 
diff --git a/src/main/python/systemds/scuro/representations/elmo.py 
b/src/main/python/systemds/scuro/representations/elmo.py
index ba2a99f8e1..33e4f74141 100644
--- a/src/main/python/systemds/scuro/representations/elmo.py
+++ b/src/main/python/systemds/scuro/representations/elmo.py
@@ -29,34 +29,10 @@ from systemds.scuro.modality.type import ModalityType
 from systemds.scuro.utils.static_variables import get_device
 from flair.embeddings import ELMoEmbeddings
 from flair.data import Sentence
-from torch.utils.data import Dataset
+from systemds.scuro.utils.torch_dataset import TextDataset
 from torch.utils.data import DataLoader
 
 
-class TextDataset(Dataset):
-    def __init__(self, texts):
-
-        self.texts = []
-        if isinstance(texts, list):
-            self.texts = texts
-        else:
-            for text in texts:
-                if text is None:
-                    self.texts.append("")
-                elif isinstance(text, np.ndarray):
-                    self.texts.append(str(text.item()) if text.size == 1 else 
str(text))
-                elif not isinstance(text, str):
-                    self.texts.append(str(text))
-                else:
-                    self.texts.append(text)
-
-    def __len__(self):
-        return len(self.texts)
-
-    def __getitem__(self, idx):
-        return self.texts[idx]
-
-
 # @register_representation([ModalityType.TEXT])
 class ELMoRepresentation(UnimodalRepresentation):
     def __init__(
diff --git a/src/main/python/systemds/scuro/representations/text_context.py 
b/src/main/python/systemds/scuro/representations/text_context.py
index b98b90e187..b4f82bda19 100644
--- a/src/main/python/systemds/scuro/representations/text_context.py
+++ b/src/main/python/systemds/scuro/representations/text_context.py
@@ -72,7 +72,7 @@ def _extract_text(instance: Any) -> str:
     return text
 
 
-@register_context_operator(ModalityType.TEXT)
+# @register_context_operator(ModalityType.TEXT)
 class SentenceBoundarySplit(Context):
     """
     Splits text at sentence boundaries while respecting maximum word count.
@@ -154,7 +154,7 @@ class SentenceBoundarySplit(Context):
         return chunked_data
 
 
-@register_context_operator(ModalityType.TEXT)
+# @register_context_operator(ModalityType.TEXT)
 class OverlappingSplit(Context):
     """
     Splits text with overlapping chunks using a sliding window approach.
diff --git 
a/src/main/python/systemds/scuro/representations/text_context_with_indices.py 
b/src/main/python/systemds/scuro/representations/text_context_with_indices.py
index 7daf93855f..5a3c3b34e0 100644
--- 
a/src/main/python/systemds/scuro/representations/text_context_with_indices.py
+++ 
b/src/main/python/systemds/scuro/representations/text_context_with_indices.py
@@ -134,7 +134,7 @@ class WordCountSplitIndices(Context):
         return chunked_data
 
 
-# @register_context_operator(ModalityType.TEXT)
+@register_context_operator(ModalityType.TEXT)
 class SentenceBoundarySplitIndices(Context):
     """
     Splits text at sentence boundaries while respecting maximum word count.
@@ -162,18 +162,17 @@ class SentenceBoundarySplitIndices(Context):
         Returns:
             List of lists, where each inner list contains text chunks (strings)
         """
-        chunked_data = []
 
-        for instance in modality.data:
+        for instance, metadata in zip(modality.data, 
modality.metadata.values()):
             text = _extract_text(instance)
             if not text:
-                chunked_data.append((0, 0))
+                ModalityType.TEXT.add_field(metadata, "text_spans", [(0, 0)])
                 continue
 
             sentences = _split_into_sentences(text)
 
             if not sentences:
-                chunked_data.append((0, len(text)))
+                ModalityType.TEXT.add_field(metadata, "text_spans", [(0, 
len(text))])
                 continue
 
             chunks = []
@@ -225,12 +224,12 @@ class SentenceBoundarySplitIndices(Context):
             if not chunks:
                 chunks = [(0, len(text))]
 
-            chunked_data.append(chunks)
+            ModalityType.TEXT.add_field(metadata, "text_spans", chunks)
 
-        return chunked_data
+        return None
 
 
-# @register_context_operator(ModalityType.TEXT)
+@register_context_operator(ModalityType.TEXT)
 class OverlappingSplitIndices(Context):
     """
     Splits text with overlapping chunks using a sliding window approach.
@@ -263,18 +262,17 @@ class OverlappingSplitIndices(Context):
         Returns:
             List of tuples, where each tuple contains start and end index to 
the text chunks
         """
-        chunked_data = []
 
-        for instance in modality.data:
+        for instance, metadata in zip(modality.data, 
modality.metadata.values()):
             text = _extract_text(instance)
             if not text:
-                chunked_data.append((0, 0))
+                ModalityType.TEXT.add_field(metadata, "text_spans", [(0, 0)])
                 continue
 
             words = _split_into_words(text)
 
             if len(words) <= self.max_words:
-                chunked_data.append((0, len(text)))
+                ModalityType.TEXT.add_field(metadata, "text_spans", [(0, 
len(text))])
                 continue
 
             chunks = []
@@ -295,6 +293,6 @@ class OverlappingSplitIndices(Context):
             if not chunks:
                 chunks = [(0, len(text))]
 
-            chunked_data.append(chunks)
+            ModalityType.TEXT.add_field(metadata, "text_spans", chunks)
 
-        return chunked_data
+        return None
diff --git a/src/main/python/systemds/scuro/utils/torch_dataset.py 
b/src/main/python/systemds/scuro/utils/torch_dataset.py
index 9c462e3675..ba3e24a317 100644
--- a/src/main/python/systemds/scuro/utils/torch_dataset.py
+++ b/src/main/python/systemds/scuro/utils/torch_dataset.py
@@ -24,6 +24,8 @@ import numpy as np
 import torch
 import torchvision.transforms as transforms
 
+from systemds.scuro.modality.type import ModalityType
+
 
 class CustomDataset(torch.utils.data.Dataset):
     def __init__(self, data, data_type, device, size=None, tf=None):
@@ -78,3 +80,43 @@ class CustomDataset(torch.utils.data.Dataset):
 
     def __len__(self) -> int:
         return len(self.data)
+
+
+class TextDataset(torch.utils.data.Dataset):
+    def __init__(self, texts):
+
+        self.texts = []
+        if isinstance(texts, list):
+            self.texts = texts
+        else:
+            for text in texts:
+                if text is None:
+                    self.texts.append("")
+                elif isinstance(text, np.ndarray):
+                    self.texts.append(str(text.item()) if text.size == 1 else 
str(text))
+                elif not isinstance(text, str):
+                    self.texts.append(str(text))
+                else:
+                    self.texts.append(text)
+
+    def __len__(self):
+        return len(self.texts)
+
+    def __getitem__(self, idx):
+        return self.texts[idx]
+
+
+class TextSpanDataset(torch.utils.data.Dataset):
+    def __init__(self, full_texts, metadata):
+        self.full_texts = full_texts
+        self.spans_per_text = ModalityType.TEXT.get_field_for_instances(
+            metadata, "text_spans"
+        )
+
+    def __len__(self):
+        return len(self.full_texts)
+
+    def __getitem__(self, idx):
+        text = self.full_texts[idx]
+        spans = self.spans_per_text[idx]
+        return [text[s:e] for (s, e) in spans]
diff --git a/src/main/python/tests/scuro/test_operator_registry.py 
b/src/main/python/tests/scuro/test_operator_registry.py
index 189e3e44d7..443cc039d6 100644
--- a/src/main/python/tests/scuro/test_operator_registry.py
+++ b/src/main/python/tests/scuro/test_operator_registry.py
@@ -21,9 +21,9 @@
 
 import unittest
 
-from systemds.scuro.representations.text_context import (
-    SentenceBoundarySplit,
-    OverlappingSplit,
+from systemds.scuro.representations.text_context_with_indices import (
+    SentenceBoundarySplitIndices,
+    OverlappingSplitIndices,
 )
 
 from systemds.scuro.representations.covarep_audio_features import (
@@ -134,8 +134,8 @@ class TestOperatorRegistry(unittest.TestCase):
             DynamicWindow,
         ]
         assert registry.get_context_operators(ModalityType.TEXT) == [
-            SentenceBoundarySplit,
-            OverlappingSplit,
+            SentenceBoundarySplitIndices,
+            OverlappingSplitIndices,
         ]
 
     # def test_fusion_operator_in_registry(self):
diff --git a/src/main/python/tests/scuro/test_text_context_operators.py 
b/src/main/python/tests/scuro/test_text_context_operators.py
index 1f04165407..ffa702b7c8 100644
--- a/src/main/python/tests/scuro/test_text_context_operators.py
+++ b/src/main/python/tests/scuro/test_text_context_operators.py
@@ -36,6 +36,7 @@ from tests.scuro.data_generator import (
 )
 from systemds.scuro.modality.unimodal_modality import UnimodalModality
 from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.representations.bert import Bert
 
 
 class TestTextContextOperator(unittest.TestCase):
@@ -80,33 +81,30 @@ class TestTextContextOperator(unittest.TestCase):
 
     def test_sentence_boundary_split_indices(self):
         sentence_boundary_split = SentenceBoundarySplitIndices(10, min_words=4)
-        chunks = sentence_boundary_split.execute(self.text_modality)
-        for i in range(0, len(chunks)):
-            for chunk in chunks[i]:
-                text = self.text_modality.data[i][chunk[0] : chunk[1]].split(" 
")
+        sentence_boundary_split.execute(self.text_modality)
+        for instance, md in zip(
+            self.text_modality.data, self.text_modality.metadata.values()
+        ):
+            for chunk in md["text_spans"]:
+                text = instance[chunk[0] : chunk[1]].split(" ")
                 assert len(text) <= 10 and (
                     text[-1][-1] == "." or text[-1][-1] == "!" or text[-1][-1] 
== "?"
                 )
 
     def test_overlapping_split_indices(self):
         overlapping_split = OverlappingSplitIndices(40, 0.1)
-        chunks = overlapping_split.execute(self.text_modality)
-        for i in range(len(chunks)):
+        overlapping_split.execute(self.text_modality)
+        for instance, md in zip(
+            self.text_modality.data, self.text_modality.metadata.values()
+        ):
             prev_chunk = (0, 0)
-            for j, chunk in enumerate(chunks[i]):
+            for j, chunk in enumerate(md["text_spans"]):
                 if j > 0:
-                    prev_words = self.text_modality.data[i][
-                        prev_chunk[0] : prev_chunk[1]
-                    ].split(" ")
-                    curr_words = self.text_modality.data[i][chunk[0] : 
chunk[1]].split(
-                        " "
-                    )
+                    prev_words = instance[prev_chunk[0] : 
prev_chunk[1]].split(" ")
+                    curr_words = instance[chunk[0] : chunk[1]].split(" ")
                     assert prev_words[-4:] == curr_words[:4]
                 prev_chunk = chunk
-                assert (
-                    len(self.text_modality.data[i][chunk[0] : 
chunk[1]].split(" "))
-                    <= 40
-                )
+                assert len(instance[chunk[0] : chunk[1]].split(" ")) <= 40
 
 
 if __name__ == "__main__":
diff --git a/src/main/python/tests/scuro/test_unimodal_optimizer.py 
b/src/main/python/tests/scuro/test_unimodal_optimizer.py
index 0d8ae90177..7fa606d835 100644
--- a/src/main/python/tests/scuro/test_unimodal_optimizer.py
+++ b/src/main/python/tests/scuro/test_unimodal_optimizer.py
@@ -36,6 +36,7 @@ from systemds.scuro.representations.covarep_audio_features 
import (
 )
 from systemds.scuro.representations.word2vec import W2V
 from systemds.scuro.representations.bow import BoW
+from systemds.scuro.representations.bert import Bert
 from systemds.scuro.modality.unimodal_modality import UnimodalModality
 from systemds.scuro.representations.resnet import ResNet
 from tests.scuro.data_generator import (
@@ -124,7 +125,7 @@ class 
TestUnimodalRepresentationOptimizer(unittest.TestCase):
         ):
             registry = Registry()
 
-            unimodal_optimizer = UnimodalOptimizer([modality], self.tasks, 
False)
+            unimodal_optimizer = UnimodalOptimizer([modality], self.tasks, 
False, k=1)
             unimodal_optimizer.optimize()
 
             assert (
@@ -133,7 +134,7 @@ class 
TestUnimodalRepresentationOptimizer(unittest.TestCase):
             )
             assert len(unimodal_optimizer.operator_performance.task_names) == 2
             result, cached = 
unimodal_optimizer.operator_performance.get_k_best_results(
-                modality, 1, self.tasks[0], "accuracy"
+                modality, self.tasks[0], "accuracy"
             )
             assert len(result) == 1
             assert len(cached) == 1

Reply via email to