This is an automated email from the ASF dual-hosted git repository.

cdionysio pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new c9a54fe2a3 [SYSTEMDS-3887] Refactor representation optimizers (#2308)
c9a54fe2a3 is described below

commit c9a54fe2a3b4281eec24491f3cd7291f426f2ccc
Author: Christina Dionysio <diony...@tu-berlin.de>
AuthorDate: Mon Aug 18 13:55:55 2025 +0200

    [SYSTEMDS-3887] Refactor representation optimizers (#2308)
    
    This patch adds an updated version of the unimodal and multimodal
    representation optimizers. It includes improved handling of optimization
    results and more readable debug output for better operator tracing. It also
    adds tests for the adapted optimizers and the fusion representations.
---
 src/main/python/systemds/scuro/__init__.py         |   4 +
 .../systemds/scuro/dataloader/audio_loader.py      |  24 +-
 .../systemds/scuro/dataloader/video_loader.py      |  22 +-
 .../python/systemds/scuro/drsearch/dr_search.py    |   2 +-
 .../scuro/drsearch/multimodal_optimizer.py         | 398 +++++++++++++++++++++
 .../systemds/scuro/drsearch/operator_registry.py   |  12 +
 .../systemds/scuro/drsearch/unimodal_optimizer.py  | 304 ++++++++++++++++
 .../systemds/scuro/modality/joined_transformed.py  |   3 +-
 .../python/systemds/scuro/modality/modality.py     |  58 ++-
 .../python/systemds/scuro/modality/transformed.py  |  30 +-
 src/main/python/systemds/scuro/modality/type.py    |  37 +-
 .../systemds/scuro/modality/unimodal_modality.py   |  50 ++-
 .../systemds/scuro/representations/aggregate.py    |   7 +-
 .../systemds/scuro/representations/average.py      |  19 +-
 .../python/systemds/scuro/representations/bert.py  |  48 +--
 .../python/systemds/scuro/representations/bow.py   |   3 +-
 .../scuro/representations/concatenation.py         |  22 +-
 .../systemds/scuro/representations/fusion.py       |  17 +
 .../python/systemds/scuro/representations/glove.py |   1 +
 .../systemds/scuro/representations/hadamard.py     |   3 +-
 .../python/systemds/scuro/representations/lstm.py  |   2 +-
 .../python/systemds/scuro/representations/max.py   |   4 +-
 .../scuro/representations/mel_spectrogram.py       |   9 +-
 .../python/systemds/scuro/representations/mfcc.py  |   9 +-
 .../scuro/representations/representation.py        |   2 +-
 .../systemds/scuro/representations/spectrogram.py  |  11 +-
 .../python/systemds/scuro/representations/sum.py   |  11 +-
 .../python/systemds/scuro/representations/tfidf.py |   4 +-
 .../systemds/scuro/representations/wav2vec.py      |   4 +-
 .../scuro/representations/window_aggregation.py    |  49 ++-
 .../systemds/scuro/representations/word2vec.py     |   5 +-
 .../python/systemds/scuro/utils/schema_helpers.py  |   4 +
 .../static_variables.py}                           |  27 +-
 src/main/python/tests/scuro/data_generator.py      |  49 ++-
 src/main/python/tests/scuro/test_dr_search.py      |   4 +-
 src/main/python/tests/scuro/test_fusion_orders.py  |   2 +-
 .../python/tests/scuro/test_multimodal_fusion.py   |  68 ++--
 .../python/tests/scuro/test_multimodal_join.py     |   4 -
 .../python/tests/scuro/test_unimodal_optimizer.py  |  31 +-
 39 files changed, 1136 insertions(+), 227 deletions(-)

diff --git a/src/main/python/systemds/scuro/__init__.py 
b/src/main/python/systemds/scuro/__init__.py
index 1c3cfe9223..ae9aed44c0 100644
--- a/src/main/python/systemds/scuro/__init__.py
+++ b/src/main/python/systemds/scuro/__init__.py
@@ -73,6 +73,8 @@ from systemds.scuro.drsearch.representation_cache import 
RepresentationCache
 from systemds.scuro.drsearch.unimodal_representation_optimizer import (
     UnimodalRepresentationOptimizer,
 )
+from systemds.scuro.drsearch.multimodal_optimizer import MultimodalOptimizer
+from systemds.scuro.drsearch.unimodal_optimizer import UnimodalOptimizer
 
 
 __all__ = [
@@ -127,4 +129,6 @@ __all__ = [
     "OptimizationData",
     "RepresentationCache",
     "UnimodalRepresentationOptimizer",
+    "UnimodalOptimizer",
+    "MultimodalOptimizer",
 ]
diff --git a/src/main/python/systemds/scuro/dataloader/audio_loader.py 
b/src/main/python/systemds/scuro/dataloader/audio_loader.py
index a1dad304e5..1197617673 100644
--- a/src/main/python/systemds/scuro/dataloader/audio_loader.py
+++ b/src/main/python/systemds/scuro/dataloader/audio_loader.py
@@ -45,18 +45,18 @@ class AudioLoader(BaseLoader):
 
     def extract(self, file: str, index: Optional[Union[str, List[str]]] = 
None):
         self.file_sanity_check(file)
-        # if not self.load_data_from_file:
-        #     import numpy as np
-        #
-        #     self.metadata[file] = self.modality_type.create_audio_metadata(
-        #         1000, np.array([0])
-        #     )
-        # else:
-        audio, sr = librosa.load(file, dtype=self._data_type)
+        if not self.load_data_from_file:
+            import numpy as np
 
-        if self.normalize:
-            audio = librosa.util.normalize(audio)
+            self.metadata[file] = self.modality_type.create_audio_metadata(
+                1000, np.array([0])
+            )
+        else:
+            audio, sr = librosa.load(file, dtype=self._data_type)
 
-        self.metadata[file] = self.modality_type.create_audio_metadata(sr, 
audio)
+            if self.normalize:
+                audio = librosa.util.normalize(audio)
 
-        self.data.append(audio)
+            self.metadata[file] = self.modality_type.create_audio_metadata(sr, 
audio)
+
+            self.data.append(audio)
diff --git a/src/main/python/systemds/scuro/dataloader/video_loader.py 
b/src/main/python/systemds/scuro/dataloader/video_loader.py
index 96ea5f11f6..0e77d5dc57 100644
--- a/src/main/python/systemds/scuro/dataloader/video_loader.py
+++ b/src/main/python/systemds/scuro/dataloader/video_loader.py
@@ -35,11 +35,13 @@ class VideoLoader(BaseLoader):
         data_type: Union[np.dtype, str] = np.float16,
         chunk_size: Optional[int] = None,
         load=True,
+        fps=None,
     ):
         super().__init__(
             source_path, indices, data_type, chunk_size, ModalityType.VIDEO
         )
         self.load_data_from_file = load
+        self.fps = fps
 
     def extract(self, file: str, index: Optional[Union[str, List[str]]] = 
None):
         self.file_sanity_check(file)
@@ -53,25 +55,33 @@ class VideoLoader(BaseLoader):
         if not cap.isOpened():
             raise f"Could not read video at path: {file}"
 
-        fps = cap.get(cv2.CAP_PROP_FPS)
+        orig_fps = cap.get(cv2.CAP_PROP_FPS)
+        frame_interval = 1
+        if self.fps is not None and self.fps < orig_fps:
+            frame_interval = int(round(orig_fps / self.fps))
+        else:
+            self.fps = orig_fps
+
         length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
         width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
         height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
         num_channels = 3
 
         self.metadata[file] = self.modality_type.create_video_metadata(
-            fps, length, width, height, num_channels
+            self.fps, length, width, height, num_channels
         )
 
         frames = []
+        idx = 0
         while cap.isOpened():
             ret, frame = cap.read()
 
             if not ret:
                 break
-            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-            frame = frame.astype(self._data_type) / 255.0
-
-            frames.append(frame)
+            if idx % frame_interval == 0:
+                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                frame = frame.astype(self._data_type) / 255.0
+                frames.append(frame)
+            idx += 1
 
         self.data.append(np.stack(frames))
diff --git a/src/main/python/systemds/scuro/drsearch/dr_search.py 
b/src/main/python/systemds/scuro/drsearch/dr_search.py
index 2000608a1d..601001c742 100644
--- a/src/main/python/systemds/scuro/drsearch/dr_search.py
+++ b/src/main/python/systemds/scuro/drsearch/dr_search.py
@@ -76,7 +76,7 @@ class DRSearch:
         """
 
         # check if modality name is already in dictionary
-        if "_".join(modality_names) not in self.scores.keys():
+        if "_".join(modality_names) not in list(self.scores.keys()):
             # if not add it to dictionary
             self.scores["_".join(modality_names)] = {}
 
diff --git a/src/main/python/systemds/scuro/drsearch/multimodal_optimizer.py 
b/src/main/python/systemds/scuro/drsearch/multimodal_optimizer.py
new file mode 100644
index 0000000000..2da8e7ae19
--- /dev/null
+++ b/src/main/python/systemds/scuro/drsearch/multimodal_optimizer.py
@@ -0,0 +1,398 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import itertools
+
+from systemds.scuro.representations.aggregated_representation import (
+    AggregatedRepresentation,
+)
+
+from systemds.scuro.representations.aggregate import Aggregation
+
+from systemds.scuro.drsearch.operator_registry import Registry
+
+from systemds.scuro.utils.schema_helpers import get_shape
+import dataclasses
+
+
+class MultimodalOptimizer:
+    def __init__(
+        self, modalities, unimodal_optimization_results, tasks, k=2, debug=True
+    ):
+        self.k_best_cache = None
+        self.k_best_modalities = None
+        self.modalities = modalities
+        self.unimodal_optimization_results = unimodal_optimization_results
+        self.tasks = tasks
+        self.k = k
+        self.extract_k_best_modalities_per_task()
+        self.debug = debug
+
+        self.operator_registry = Registry()
+        self.optimization_results = MultimodalResults(
+            modalities, tasks, debug, self.k_best_modalities
+        )
+        self.cache = {}
+
+    def optimize(self):
+        for task in self.tasks:
+            self.optimize_intermodal_representations(task)
+
+    def optimize_intramodal_representations(self, task):
+        for modality in self.modalities:
+            representations = self.k_best_modalities[task.model.name][
+                modality.modality_id
+            ]
+            applied_representations = self.extract_representations(
+                representations, modality, task.model.name
+            )
+
+            for i in range(1, len(applied_representations)):
+                for fusion_method in 
self.operator_registry.get_fusion_operators():
+                    if fusion_method().needs_alignment and not 
applied_representations[
+                        i - 1
+                    ].is_aligned(applied_representations[i]):
+                        continue
+                    combined = applied_representations[i - 1].combine(
+                        applied_representations[i], fusion_method()
+                    )
+                    self.evaluate(
+                        task,
+                        combined,
+                        [i - 1, i],
+                        fusion_method,
+                        [
+                            applied_representations[i - 1].modality_id,
+                            applied_representations[i].modality_id,
+                        ],
+                    )
+                    if not fusion_method().commutative:
+                        combined_comm = applied_representations[i].combine(
+                            applied_representations[i - 1], fusion_method()
+                        )
+                        self.evaluate(
+                            task,
+                            combined_comm,
+                            [i, i - 1],
+                            fusion_method,
+                            [
+                                applied_representations[i - 1].modality_id,
+                                applied_representations[i].modality_id,
+                            ],
+                        )
+
+    def optimize_intermodal_representations(self, task):
+        modality_combos = []
+        n = len(self.k_best_cache[task.model.name])
+        reuse_cache = {}
+
+        def generate_extensions(current_combo, remaining_indices):
+            # Add current combination if it has at least 2 elements
+            if len(current_combo) >= 2:
+                combo_tuple = tuple(i for i in current_combo)
+                modality_combos.append(combo_tuple)
+
+            for i in remaining_indices:
+                new_combo = current_combo + [i]
+                new_remaining = [j for j in remaining_indices if j > i]
+                generate_extensions(new_combo, new_remaining)
+
+        for start_idx in range(n):
+            remaining = list(range(start_idx + 1, n))
+            generate_extensions([start_idx], remaining)
+        fusion_methods = self.operator_registry.get_fusion_operators()
+        fused_representations = []
+        reuse_fused_representations = False
+        for i, modality_combo in enumerate(modality_combos):
+            # clear reuse cache
+            if i % 5 == 0:
+                reuse_cache = self.prune_cache(modality_combos[i:], 
reuse_cache)
+
+            if i != 0:
+                reuse_fused_representations = self.is_prefix_match(
+                    modality_combos[i - 1], modality_combo
+                )
+            if reuse_fused_representations:
+                mods = [
+                    self.k_best_cache[task.model.name][mod_idx]
+                    for mod_idx in modality_combo[len(modality_combos[i - 1]) 
:]
+                ]
+                fused_representations = reuse_cache[modality_combos[i - 1]]
+            else:
+                prefix_idx = self.compute_equal_prefix_index(
+                    modality_combos[i - 1], modality_combo
+                )
+                if prefix_idx > 1:
+                    fused_representations = reuse_cache[
+                        modality_combos[i - 1][:prefix_idx]
+                    ]
+                    reuse_fused_representations = True
+                    mods = [
+                        self.k_best_cache[task.model.name][mod_idx]
+                        for mod_idx in modality_combo[prefix_idx:]
+                    ]
+            if self.debug:
+                print(
+                    f"New modality combo: {modality_combo} - Reuse: 
{reuse_fused_representations} - # fused reps: {len(fused_representations)}"
+                )
+
+            all_mods = [
+                self.k_best_cache[task.model.name][mod_idx]
+                for mod_idx in modality_combo
+            ]
+            temp_fused_reps = []
+            for j, fusion_method in enumerate(fusion_methods):
+                # Evaluate all mods
+                fused_rep = all_mods[0].combine(all_mods[1:], fusion_method())
+                temp_fused_reps.append(fused_rep)
+                self.evaluate(
+                    task,
+                    fused_rep,
+                    [
+                        
self.k_best_modalities[task.model.name][k].representations
+                        for k in modality_combo
+                    ],
+                    fusion_method,
+                    modality_combo,
+                )
+                if reuse_fused_representations:
+                    for fused_representation in fused_representations:
+                        fused_rep = fused_representation.combine(mods, 
fusion_method())
+                        temp_fused_reps.append(fused_rep)
+                        self.evaluate(
+                            task,
+                            fused_rep,
+                            [
+                                self.k_best_modalities[task.model.name][
+                                    k
+                                ].representations
+                                for k in modality_combo
+                            ],
+                            fusion_method,
+                            modality_combo,
+                        )
+
+            if (
+                len(modality_combo) < len(self.k_best_cache[task.model.name])
+                and i + 1 < len(modality_combos)
+                and self.is_prefix_match(modality_combos[i], modality_combos[i 
+ 1])
+            ):
+                reuse_cache[modality_combo] = temp_fused_reps
+            reuse_fused_representations = False
+
+    def prune_cache(self, sequences, cache):
+        seqs_as_tuples = [tuple(seq) for seq in sequences]
+
+        def still_used(key):
+            return any(self.is_prefix_match(key, seq) for seq in 
seqs_as_tuples)
+
+        cache = {key: value for key, value in cache.items() if still_used(key)}
+        return cache
+
+    def is_prefix_match(self, seq1, seq2):
+        if len(seq1) > len(seq2):
+            return False
+
+        # Check if seq1 matches the beginning of seq2
+        return seq2[: len(seq1)] == seq1
+
+    def compute_equal_prefix_index(self, seq1, seq2):
+        max_len = min(len(seq1), len(seq2))
+        i = 0
+        while i < max_len and seq1[i] == seq2[i]:
+            i += 1
+
+        return i
+
+    def extract_representations(self, representations, modality, task_name):
+        applied_representations = []
+        for i in range(0, len(representations)):
+            cache_key = (
+                tuple(representations[i].representations),
+                representations[i].task_time,
+                representations[i].representation_time,
+            )
+            if (
+                cache_key
+                in 
self.unimodal_optimization_results.cache[modality.modality_id][
+                    task_name
+                ]
+            ):
+                applied_representations.append(
+                    
self.unimodal_optimization_results.cache[modality.modality_id][
+                        task_name
+                    ][cache_key]
+                )
+            else:
+                applied_representation = modality
+                for j, rep in enumerate(representations[i].representations):
+                    representation, is_context = (
+                        self.operator_registry.get_representation_by_name(
+                            rep, modality.modality_type
+                        )
+                    )
+                    if representation is None:
+                        if rep == AggregatedRepresentation.__name__:
+                            representation = 
AggregatedRepresentation(Aggregation())
+                    else:
+                        representation = representation()
+                    representation.set_parameters(representations[i].params[j])
+                    if is_context:
+                        applied_representation = 
applied_representation.context(
+                            representation
+                        )
+                    else:
+                        applied_representation = (
+                            
applied_representation.apply_representation(representation)
+                        )
+                self.k_best_cache[task_name].append(applied_representation)
+                applied_representations.append(applied_representation)
+        return applied_representations
+
+    def evaluate(self, task, modality, representations, fusion, 
modality_combo):
+        if task.expected_dim == 1 and get_shape(modality.metadata) > 1:
+            for aggregation in Aggregation().get_aggregation_functions():
+                agg_operator = 
AggregatedRepresentation(Aggregation(aggregation, False))
+                agg_modality = agg_operator.transform(modality)
+
+                scores = task.run(agg_modality.data)
+                reps = representations.copy()
+                reps.append(agg_operator)
+
+                self.optimization_results.add_result(
+                    scores,
+                    reps,
+                    modality.transformation,
+                    modality_combo,
+                    task.model.name,
+                )
+        else:
+            scores = task.run(modality.data)
+            self.optimization_results.add_result(
+                scores,
+                representations,
+                modality.transformation,
+                modality_combo,
+                task.model.name,
+            )
+
+    def add_to_cache(self, result_idx, combined_modality):
+        self.cache[result_idx] = combined_modality
+
+    def extract_k_best_modalities_per_task(self):
+        self.k_best_modalities = {}
+        self.k_best_cache = {}
+        for task in self.tasks:
+            self.k_best_modalities[task.model.name] = []
+            self.k_best_cache[task.model.name] = []
+            for modality in self.modalities:
+                k_best_results, cached_data = (
+                    self.unimodal_optimization_results.get_k_best_results(
+                        modality, self.k, task
+                    )
+                )
+
+                self.k_best_modalities[task.model.name].extend(k_best_results)
+                self.k_best_cache[task.model.name].extend(cached_data)
+
+
+class MultimodalResults:
+    def __init__(self, modalities, tasks, debug, k_best_modalities):
+        self.modality_ids = [modality.modality_id for modality in modalities]
+        self.task_names = [task.model.name for task in tasks]
+        self.results = {}
+        self.debug = debug
+        self.k_best_modalities = k_best_modalities
+
+        for task in tasks:
+            self.results[task.model.name] = {}
+
+    def add_result(
+        self, scores, best_representation_idx, fusion_methods, modality_combo, 
task_name
+    ):
+
+        entry = MultimodalResultEntry(
+            representations=best_representation_idx,
+            train_score=scores[0],
+            val_score=scores[1],
+            fusion_methods=[
+                fusion_method.__class__.__name__ for fusion_method in 
fusion_methods
+            ],
+            modality_combo=modality_combo,
+            task=task_name,
+        )
+
+        modality_id_strings = "_".join(list(map(str, modality_combo)))
+        if not modality_id_strings in self.results[task_name]:
+            self.results[task_name][modality_id_strings] = []
+
+        self.results[task_name][modality_id_strings].append(entry)
+
+        if self.debug:
+            print(f"{modality_id_strings}_{task_name}: {entry}")
+
+    def print_results(self):
+        for task_name in self.task_names:
+            for modality in self.results[task_name].keys():
+                for entry in self.results[task_name][modality]:
+                    reps = []
+                    for i, mod_idx in enumerate(entry.modality_combo):
+                        reps.append(self.k_best_modalities[task_name][mod_idx])
+
+                    print(
+                        f"{modality}_{task_name}: "
+                        f"Validation score: {entry.val_score} - Training 
score: {entry.train_score}"
+                    )
+                    for i, rep in enumerate(reps):
+                        print(
+                            f"    Representation: {entry.modality_combo[i]} - 
{rep.representations}"
+                        )
+
+                    print(f"    Fusion: {entry.fusion_methods[0]} ")
+
+    def store_results(self, file_name=None):
+        for task_name in self.task_names:
+            for modality in self.results[task_name].keys():
+                for entry in self.results[task_name][modality]:
+                    reps = []
+                    for i, mod_idx in enumerate(entry.modality_combo):
+                        reps.append(self.k_best_modalities[task_name][mod_idx])
+                    entry.representations = reps
+
+        import pickle
+
+        if file_name is None:
+            import time
+
+            timestr = time.strftime("%Y%m%d-%H%M%S")
+            file_name = "multimodal_optimizer" + timestr + ".pkl"
+
+        with open(file_name, "wb") as f:
+            pickle.dump(self.results, f)
+
+
+@dataclasses.dataclass
+class MultimodalResultEntry:
+    val_score: float
+    modality_combo: list
+    representations: list
+    fusion_methods: list
+    train_score: float
+    task: str
diff --git a/src/main/python/systemds/scuro/drsearch/operator_registry.py 
b/src/main/python/systemds/scuro/drsearch/operator_registry.py
index cfd313eb56..3909b51ff9 100644
--- a/src/main/python/systemds/scuro/drsearch/operator_registry.py
+++ b/src/main/python/systemds/scuro/drsearch/operator_registry.py
@@ -64,6 +64,18 @@ class Registry:
     def get_fusion_operators(self):
         return self._fusion_operators
 
+    def get_representation_by_name(self, representation_name, modality_type):
+        for representation in self._context_operators:
+            if representation.__name__ == representation_name:
+                return representation, True
+
+        if modality_type is not None:
+            for representation in self._representations[modality_type]:
+                if representation.__name__ == representation_name:
+                    return representation, False
+
+        return None, False
+
 
 def register_representation(modalities: Union[ModalityType, 
List[ModalityType]]):
     """
diff --git a/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py 
b/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py
new file mode 100644
index 0000000000..030f04aa43
--- /dev/null
+++ b/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py
@@ -0,0 +1,304 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import pickle
+import time
+from concurrent.futures import ProcessPoolExecutor, as_completed
+from dataclasses import dataclass, field, asdict
+
+import multiprocessing as mp
+from typing import Union
+
+import numpy as np
+from systemds.scuro.representations.window_aggregation import WindowAggregation
+
+from systemds.scuro.representations.aggregated_representation import (
+    AggregatedRepresentation,
+)
+from systemds.scuro import ModalityType, Aggregation
+from systemds.scuro.drsearch.operator_registry import Registry
+from systemds.scuro.utils.schema_helpers import get_shape
+
+
class UnimodalOptimizer:
    """
    Search over unimodal representation pipelines for a set of modalities.

    For every modality, the registered modality-specific representations and
    context operators are applied (alone and chained). Each resulting modality
    is scored on every task, and the scores are recorded in
    ``self.operator_performance`` (a ``UnimodalResults``).
    """

    def __init__(self, modalities, tasks, debug=True):
        """
        :param modalities: modalities to search representations for
        :param tasks: tasks used to score the representations (must be non-empty)
        :param debug: print every recorded result if True
        """
        self.modalities = modalities
        self.tasks = tasks

        self.operator_registry = Registry()
        self.operator_performance = UnimodalResults(modalities, tasks, debug)

        # When every task expects the same input dimensionality, the data can be
        # prepared (aggregated or padded) once and shared by all tasks.
        self._tasks_require_same_dims = all(
            task.expected_dim == tasks[0].expected_dim for task in tasks
        )
        # Only consulted when _tasks_require_same_dims is True.
        self.expected_dimensions = tasks[-1].expected_dim

    def store_results(self, file_name=None):
        """
        Pickle the collected unimodal results.

        :param file_name: target path; defaults to
            ``unimodal_optimizer<YYYYmmdd-HHMMSS>.pkl`` in the working directory
        """
        if file_name is None:
            timestr = time.strftime("%Y%m%d-%H%M%S")
            file_name = "unimodal_optimizer" + timestr + ".pkl"

        with open(file_name, "wb") as f:
            pickle.dump(self.operator_performance.results, f)

    def optimize_parallel(self, n_workers=None):
        """
        Evaluate all modalities in worker processes and merge their results.

        Each worker fills its own ``UnimodalResults``, which is merged into
        ``self.operator_performance`` as soon as the worker finishes. An
        exception raised in a worker is re-raised here by ``future.result()``.

        :param n_workers: number of worker processes; defaults to
            ``min(len(self.modalities), cpu_count)``
        """
        if not self.modalities:
            # ProcessPoolExecutor rejects max_workers=0; there is nothing to do.
            return

        if n_workers is None:
            n_workers = min(len(self.modalities), mp.cpu_count())

        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = [
                executor.submit(self._process_modality, modality, True)
                for modality in self.modalities
            ]
            for future in as_completed(futures):
                self._merge_results(future.result())

    def optimize(self):
        """
        Evaluate all modalities sequentially.

        Results are recorded directly into ``self.operator_performance``.
        """
        for modality in self.modalities:
            self._process_modality(modality, False)

    def _process_modality(self, modality, parallel):
        """
        Evaluate every representation pipeline for a single modality.

        For every registered representation and context operator, the following
        pipelines are evaluated:

        * ``representation`` and ``representation -> context``
        * ``context``, ``context -> representation`` and
          ``context -> representation -> context`` (not for TEXT/VIDEO)

        :param modality: modality to process
        :param parallel: if True, record into a fresh ``UnimodalResults`` so it
            can be returned from a worker process; otherwise record directly
            into ``self.operator_performance``
        :return: the results object the evaluations were recorded in
        """
        if parallel:
            local_results = UnimodalResults(
                modalities=[modality], tasks=self.tasks, debug=False
            )
        else:
            local_results = self.operator_performance

        context_operators = self.operator_registry.get_context_operators()
        # Materialized because it is iterated once per context operator.
        modality_specific_operators = list(
            self.operator_registry.get_representations(modality.modality_type)
        )

        # Context operators are not applied to raw TEXT and VIDEO data.
        if modality.modality_type not in (ModalityType.TEXT, ModalityType.VIDEO):
            for context_cls in context_operators:
                con_op = context_cls()
                context_representation = modality.context(con_op)
                self._evaluate_local(context_representation, [con_op], local_results)

                for operator_cls in modality_specific_operators:
                    mod_op = operator_cls()
                    mod_context = context_representation.apply_representation(mod_op)
                    self._evaluate_local(mod_context, [con_op, mod_op], local_results)
                    self._evaluate_with_trailing_context(
                        mod_context, [con_op, mod_op], context_operators, local_results
                    )

        # Representations applied directly to the modality. These are evaluated
        # exactly once, however many context operators are registered.
        for operator_cls in modality_specific_operators:
            mod_op = operator_cls()
            mod = modality.apply_representation(mod_op)
            self._evaluate_local(mod, [mod_op], local_results)
            self._evaluate_with_trailing_context(
                mod, [mod_op], context_operators, local_results
            )

        return local_results

    def _evaluate_with_trailing_context(
        self, base_modality, representations, context_operators, local_results
    ):
        """
        Evaluate ``base_modality`` followed by each context operator.

        Every context operator is applied to ``base_modality`` itself, not to
        the output of the previous one. This keeps the recorded operator list
        identical to the pipeline that produced the evaluated data.
        """
        for context_cls in context_operators:
            con_op_after = context_cls()
            self._evaluate_local(
                base_modality.context(con_op_after),
                representations + [con_op_after],
                local_results,
            )

    def _merge_results(self, local_results):
        """
        Merge a worker's ``UnimodalResults`` into ``self.operator_performance``.

        Both ``results`` and ``cache`` are keyed by modality id and task name.
        Result lists are extended, and cache dictionaries are updated.
        """
        for modality_id, task_results in local_results.results.items():
            for task_name, entries in task_results.items():
                self.operator_performance.results[modality_id][task_name].extend(
                    entries
                )

        for modality_id, task_caches in local_results.cache.items():
            for task_name, cached in task_caches.items():
                self.operator_performance.cache[modality_id][task_name].update(cached)

    def _evaluate_local(self, modality, representations, local_results):
        """
        Score ``modality`` on the tasks and record one result per task.

        If a task expects 1-dimensional input but the modality data is
        multi-dimensional (``get_shape(metadata) > 1``), the data is first
        aggregated with ``AggregatedRepresentation``. That operator is then
        appended to the recorded operator list.

        :param modality: modality whose data is fed to the tasks
        :param representations: operators that produced ``modality``
        :param local_results: ``UnimodalResults`` to record into
        """
        if self._tasks_require_same_dims:
            # Prepare the data once and reuse it for every task.
            if self.expected_dimensions == 1 and get_shape(modality.metadata) > 1:
                agg_data, agg_reps = self._aggregate(modality, representations)
                for task in self.tasks:
                    self._run_task(task, agg_data, agg_reps, modality, local_results)
            else:
                modality.pad()
                for task in self.tasks:
                    self._run_task(
                        task, modality.data, representations, modality, local_results
                    )
            return

        for task in self.tasks:
            if task.expected_dim == 1 and get_shape(modality.metadata) > 1:
                agg_data, agg_reps = self._aggregate(modality, representations)
                self._run_task(task, agg_data, agg_reps, modality, local_results)
            else:
                self._run_task(
                    task, modality.data, representations, modality, local_results
                )

    @staticmethod
    def _aggregate(modality, representations):
        """Aggregate ``modality``; return the data and the extended operator list."""
        agg_operator = AggregatedRepresentation(Aggregation())
        agg_modality = agg_operator.transform(modality)
        return agg_modality.data, representations + [agg_operator]

    @staticmethod
    def _run_task(task, data, representations, modality, local_results):
        """Run ``task`` on ``data``, time it, and record the scores."""
        start = time.perf_counter()
        scores = task.run(data)
        task_time = time.perf_counter() - start
        local_results.add_result(
            scores, representations, modality, task.model.name, task_time
        )
+
+
class UnimodalResults:
    """
    Evaluation results of unimodal representation pipelines.

    ``results[modality_id][task_name]`` is a list of ``ResultEntry``.
    ``cache[modality_id][task_name]`` maps
    ``(operator_names, val_score, representation_time)`` to the transformed
    modality that produced the corresponding entry.
    """

    def __init__(self, modalities, tasks, debug=False):
        self.modality_ids = [modality.modality_id for modality in modalities]
        self.task_names = [task.model.name for task in tasks]
        self.debug = debug
        self.results = {
            modality_id: {task_name: [] for task_name in self.task_names}
            for modality_id in self.modality_ids
        }
        self.cache = {
            modality_id: {task_name: {} for task_name in self.task_names}
            for modality_id in self.modality_ids
        }

    @staticmethod
    def _cache_key(entry):
        """Cache key of a result entry; must match the key built in add_result."""
        return (
            tuple(entry.representations),
            entry.val_score,
            entry.representation_time,
        )

    @staticmethod
    def _operator_parameters(rep):
        """Return the recorded parameter values of one operator."""
        if isinstance(rep, AggregatedRepresentation):
            return rep.parameters

        params = {param: getattr(rep, param) for param in rep.parameters}
        if isinstance(rep, WindowAggregation):
            # Store the function's name instead of the aggregation object.
            params["aggregation_function"] = (
                rep.aggregation_function.aggregation_function_name
            )
        return params

    def add_result(self, scores, representations, modality, task_name, task_time):
        """
        Record one evaluation and cache the evaluated modality.

        :param scores: ``(train_score, val_score)``
        :param representations: operator instances that produced ``modality``
        :param modality: evaluated modality (its ``transform_time`` is recorded)
        :param task_name: name of the task's model
        :param task_time: seconds spent running the task
        """
        entry = ResultEntry(
            representations=[type(rep).__name__ for rep in representations],
            params=[self._operator_parameters(rep) for rep in representations],
            train_score=scores[0],
            val_score=scores[1],
            representation_time=modality.transform_time,
            task_time=task_time,
        )
        self.results[modality.modality_id][task_name].append(entry)
        self.cache[modality.modality_id][task_name][self._cache_key(entry)] = modality

        if self.debug:
            print(f"{modality.modality_id}_{task_name}: {entry}")

    def print_results(self):
        for modality in self.modality_ids:
            for task_name in self.task_names:
                for entry in self.results[modality][task_name]:
                    print(f"{modality}_{task_name}: {entry}")

    def get_k_best_results(self, modality, k, task):
        """
        Get the k best results for the given modality and task.

        :param modality: modality to get the best results for
        :param k: number of best results
        :param task: task whose results are ranked
        :return: tuple ``(entries, modalities)``, where ``entries`` are the
            ``k`` results with the highest ``val_score`` (descending) and
            ``modalities`` are the cached transformed modalities that produced
            them
        """
        entries = self.results[modality.modality_id][task.model.name]
        cache = self.cache[modality.modality_id][task.model.name]

        best = sorted(entries, key=lambda entry: entry.val_score, reverse=True)[:k]
        # Look modalities up by key, not by position. Entries with identical
        # keys share a single cache slot, so positions in the cache and in the
        # result list can diverge.
        best_modalities = [cache[self._cache_key(entry)] for entry in best]
        return best, best_modalities
+
+
@dataclass(frozen=True)
class ResultEntry:
    """Immutable record of one unimodal pipeline evaluated on one task."""

    val_score: float
    # Class names of the applied operators, in application order.
    representations: list
    # One parameter mapping per operator in ``representations``.
    params: list
    train_score: float
    # The evaluated modality's ``transform_time`` in seconds; may be None if
    # the modality never had it set.
    representation_time: float
    # Seconds spent in ``task.run`` on the evaluated data.
    task_time: float
diff --git a/src/main/python/systemds/scuro/modality/joined_transformed.py 
b/src/main/python/systemds/scuro/modality/joined_transformed.py
index 6c6190e03c..3e0d8fb9df 100644
--- a/src/main/python/systemds/scuro/modality/joined_transformed.py
+++ b/src/main/python/systemds/scuro/modality/joined_transformed.py
@@ -36,7 +36,8 @@ class JoinedTransformedModality(Modality):
         :param transformation: Representation to be applied on the modality
         """
         super().__init__(
-            reduce(or_, [left_modality.modality_type], 
right_modality.modality_type)
+            reduce(or_, [left_modality.modality_type], 
right_modality.modality_type),
+            data_type=left_modality.data_type,
         )
         self.transformation = transformation
         self.left_modality = left_modality
diff --git a/src/main/python/systemds/scuro/modality/modality.py 
b/src/main/python/systemds/scuro/modality/modality.py
index 87d5b5ee4e..94e745b2cc 100644
--- a/src/main/python/systemds/scuro/modality/modality.py
+++ b/src/main/python/systemds/scuro/modality/modality.py
@@ -22,6 +22,7 @@ from copy import deepcopy
 from typing import List
 
 import numpy as np
+from numpy.f2py.auxfuncs import throw_error
 
 from systemds.scuro.modality.type import ModalityType
 from systemds.scuro.representations import utils
@@ -44,6 +45,7 @@ class Modality:
         self.cost = None
         self.shape = None
         self.modality_id = modality_id
+        self.transform_time = None
 
     @property
     def data(self):
@@ -61,9 +63,11 @@ class Modality:
         """
         Extracts the individual unimodal modalities for a given transformed 
modality.
         """
-        return [
+        modality_names = [
             modality.name for modality in ModalityType if modality in 
self.modality_type
         ]
+        modality_names.append(str(self.modality_id))
+        return modality_names
 
     def copy_from_instance(self):
         """
@@ -90,36 +94,60 @@ class Modality:
             updated_md = self.modality_type.update_metadata(md_v, self.data[i])
             self.metadata[md_k] = updated_md
 
-    def flatten(self, padding=True):
+    def flatten(self, padding=False):
         """
         Flattens modality data by row-wise concatenation
         Prerequisite for some ML-models
         """
         max_len = 0
+        data = []
         for num_instance, instance in enumerate(self.data):
             if type(instance) is np.ndarray:
-                self.data[num_instance] = instance.flatten()
+                d = instance.flatten()
+                max_len = max(max_len, len(d))
+                data.append(d)
             elif isinstance(instance, List):
-                self.data[num_instance] = np.array(
+                d = np.array(
                     [item for sublist in instance for item in sublist]
                 ).flatten()
-            max_len = max(max_len, len(self.data[num_instance]))
+                max_len = max(max_len, len(d))
+                data.append(d)
 
         if padding:
-            for i, instance in enumerate(self.data):
+            for i, instance in enumerate(data):
                 if isinstance(instance, np.ndarray):
                     if len(instance) < max_len:
                         padded_data = np.zeros(max_len, dtype=instance.dtype)
                         padded_data[: len(instance)] = instance
-                        self.data[i] = padded_data
+                        data[i] = padded_data
                 else:
                     padded_data = []
                     for entry in instance:
                         padded_data.append(utils.pad_sequences(entry, max_len))
-                    self.data[i] = padded_data
-        self.data = np.array(self.data)
+                    data[i] = padded_data
+        self.data = np.array(data)
         return self
 
+    def pad(self, value=0, max_len=None):
+        try:
+            if max_len is None:
+                result = np.array(self.data)
+            else:
+                raise "Needs padding to max_len"
+        except:
+            maxlen = (
+                max([len(seq) for seq in self.data]) if max_len is None else 
max_len
+            )
+
+            result = np.full((len(self.data), maxlen), value, 
dtype=self.data_type)
+
+            for i, seq in enumerate(self.data):
+                data = seq[:maxlen]
+                result[i, : len(data)] = data
+                # TODO: add padding to metadata as attention_masks
+
+        self.data = result
+
     def get_data_layout(self):
         if self.has_metadata():
             return 
list(self.metadata.values())[0]["data_layout"]["representation"]
@@ -131,3 +159,15 @@ class Modality:
 
     def has_metadata(self):
         return self.metadata is not None and self.metadata != {}
+
+    def is_aligned(self, other_modality):
+        aligned = True
+        for i in range(len(self.data)):
+            if (
+                list(self.metadata.values())[i]["data_layout"]["shape"]
+                != 
list(other_modality.metadata.values())[i]["data_layout"]["shape"]
+            ):
+                aligned = False
+                continue
+
+        return aligned
diff --git a/src/main/python/systemds/scuro/modality/transformed.py 
b/src/main/python/systemds/scuro/modality/transformed.py
index 362764d21e..6523e9502f 100644
--- a/src/main/python/systemds/scuro/modality/transformed.py
+++ b/src/main/python/systemds/scuro/modality/transformed.py
@@ -20,11 +20,14 @@
 # -------------------------------------------------------------
 from functools import reduce
 from operator import or_
+from typing import Union, List
 
 from systemds.scuro.modality.type import ModalityType
 from systemds.scuro.modality.joined import JoinedModality
 from systemds.scuro.modality.modality import Modality
 from systemds.scuro.representations.window_aggregation import WindowAggregation
+import time
+import copy
 
 
 class TransformedModality(Modality):
@@ -42,7 +45,22 @@ class TransformedModality(Modality):
         super().__init__(
             new_modality_type, modality.modality_id, metadata, 
modality.data_type
         )
-        self.transformation = transformation
+        self.transformation = None
+        self.add_transformation(transformation, modality)
+
+    def add_transformation(self, transformation, modality):
+        if (
+            transformation.__class__.__bases__[0].__name__ == "Fusion"
+            and modality.transformation[0].__class__.__bases__[0].__name__ != 
"Fusion"
+        ):
+            self.transformation = []
+        else:
+            self.transformation = (
+                []
+                if type(modality).__name__ != "TransformedModality"
+                else copy.deepcopy(modality.transformation)
+            )
+        self.transformation.append(transformation)
 
     def copy_from_instance(self):
         return type(self)(self, self.transformation)
@@ -72,22 +90,26 @@ class TransformedModality(Modality):
     def window_aggregation(self, windowSize, aggregation):
         w = WindowAggregation(windowSize, aggregation)
         transformed_modality = TransformedModality(self, w)
+        start = time.time()
         transformed_modality.data = w.execute(self)
-
+        transformed_modality.transform_time = time.time() - start
         return transformed_modality
 
     def context(self, context_operator):
         transformed_modality = TransformedModality(self, context_operator)
-
+        start = time.time()
         transformed_modality.data = context_operator.execute(self)
+        transformed_modality.transform_time = time.time() - start
         return transformed_modality
 
     def apply_representation(self, representation):
+        start = time.time()
         new_modality = representation.transform(self)
         new_modality.update_metadata()
+        new_modality.transform_time = time.time() - start
         return new_modality
 
-    def combine(self, other, fusion_method):
+    def combine(self, other: Union[Modality, List[Modality]], fusion_method):
         """
         Combines two or more modalities with each other using a dedicated 
fusion method
         :param other: The modality to be combined
diff --git a/src/main/python/systemds/scuro/modality/type.py 
b/src/main/python/systemds/scuro/modality/type.py
index a479e07085..b2331d0fae 100644
--- a/src/main/python/systemds/scuro/modality/type.py
+++ b/src/main/python/systemds/scuro/modality/type.py
@@ -26,6 +26,7 @@ from systemds.scuro.utils.schema_helpers import (
     calculate_new_frequency,
     create_timestamps,
 )
+import torch
 
 
 # TODO: needs a way to define if data comes from a dataset with multiple 
instances or is like a streaming scenario where we only have one instance
@@ -99,11 +100,19 @@ class ModalitySchemas:
         dtype = np.nan
         shape = None
         if data_layout is DataLayout.SINGLE_LEVEL:
-            dtype = data.dtype
-            shape = data.shape
+            if isinstance(data, list):
+                dtype = data[0].dtype
+                shape = data[0].shape
+            elif isinstance(data, np.ndarray):
+                dtype = data.dtype
+                shape = data.shape
         elif data_layout is DataLayout.NESTED_LEVEL:
-            shape = data[0].shape
-            dtype = data[0].dtype
+            if data_is_single_instance:
+                dtype = data.dtype
+                shape = data.shape
+            else:
+                shape = data[0].shape
+                dtype = data[0].dtype
 
         md["data_layout"].update(
             {"representation": data_layout, "type": dtype, "shape": shape}
@@ -199,9 +208,15 @@ class ModalityType(Flag):
         md[field] = data
         return md
 
-    def create_audio_metadata(self, sampling_rate, data):
+    def add_field_for_instances(self, md, field, data):
+        for key, value in zip(md.keys(), data):
+            md[key].update({field: value})
+
+        return md
+
+    def create_audio_metadata(self, sampling_rate, data, 
is_single_instance=True):
         md = deepcopy(self.get_schema())
-        md = ModalitySchemas.update_base_metadata(md, data, True)
+        md = ModalitySchemas.update_base_metadata(md, data, is_single_instance)
         md["frequency"] = sampling_rate
         md["length"] = data.shape[0]
         md["timestamp"] = create_timestamps(sampling_rate, md["length"])
@@ -240,10 +255,14 @@ class DataLayout(Enum):
             return None
 
         if data_is_single_instance:
-            if isinstance(data, list):
-                return DataLayout.NESTED_LEVEL
-            elif isinstance(data, np.ndarray):
+            if (
+                isinstance(data, list)
+                or isinstance(data, np.ndarray)
+                and data.ndim == 1
+            ):
                 return DataLayout.SINGLE_LEVEL
+            elif isinstance(data, np.ndarray) or isinstance(data, 
torch.Tensor):
+                return DataLayout.NESTED_LEVEL
 
         if isinstance(data[0], list):
             return DataLayout.NESTED_LEVEL
diff --git a/src/main/python/systemds/scuro/modality/unimodal_modality.py 
b/src/main/python/systemds/scuro/modality/unimodal_modality.py
index c0ee70557c..94d1fa057d 100644
--- a/src/main/python/systemds/scuro/modality/unimodal_modality.py
+++ b/src/main/python/systemds/scuro/modality/unimodal_modality.py
@@ -20,8 +20,9 @@
 # -------------------------------------------------------------
 from functools import reduce
 from operator import or_
-
-
+import time
+import numpy as np
+from systemds.scuro import ModalityType
 from systemds.scuro.dataloader.base_loader import BaseLoader
 from systemds.scuro.modality.modality import Modality
 from systemds.scuro.modality.joined import JoinedModality
@@ -86,12 +87,14 @@ class UnimodalModality(Modality):
         return joined_modality
 
     def context(self, context_operator):
+        start = time.time()
         if not self.has_data():
             self.extract_raw_data()
 
         transformed_modality = TransformedModality(self, context_operator)
 
         transformed_modality.data = context_operator.execute(self)
+        transformed_modality.transform_time = time.time() - start
         return transformed_modality
 
     def aggregate(self, aggregation_function):
@@ -108,18 +111,57 @@ class UnimodalModality(Modality):
             representation,
         )
         new_modality.data = []
-
+        start = time.time()
+        original_lengths = []
         if self.data_loader.chunk_size:
             self.data_loader.reset()
             while self.data_loader.next_chunk < self.data_loader.num_chunks:
                 self.extract_raw_data()
                 transformed_chunk = representation.transform(self)
                 new_modality.data.extend(transformed_chunk.data)
+                for d in transformed_chunk.data:
+                    original_lengths.append(d.shape[0])
                 new_modality.metadata.update(transformed_chunk.metadata)
         else:
-            if not self.data:
+            if not self.has_data():
                 self.extract_raw_data()
             new_modality = representation.transform(self)
 
+            if not all(
+                "attention_masks" in entry for entry in 
new_modality.metadata.values()
+            ):
+                for d in new_modality.data:
+                    original_lengths.append(d.shape[0])
+
+        if len(original_lengths) > 0 and min(original_lengths) < 
max(original_lengths):
+            target_length = max(original_lengths)
+            padded_embeddings = []
+            for embeddings in new_modality.data:
+                current_length = embeddings.shape[0]
+                if current_length < target_length:
+                    padding_needed = target_length - current_length
+
+                    padded = np.pad(
+                        embeddings,
+                        pad_width=(
+                            (0, padding_needed),
+                            (0, 0),
+                        ),
+                        mode="constant",
+                        constant_values=0,
+                    )
+                    padded_embeddings.append(padded)
+                else:
+                    padded_embeddings.append(embeddings)
+
+            attention_masks = np.zeros((len(new_modality.data), target_length))
+            for i, length in enumerate(original_lengths):
+                attention_masks[i, :length] = 1
+
+            ModalityType(self.modality_type).add_field_for_instances(
+                new_modality.metadata, "attention_masks", attention_masks
+            )
+            new_modality.data = padded_embeddings
         new_modality.update_metadata()
+        new_modality.transform_time = time.time() - start
         return new_modality
diff --git a/src/main/python/systemds/scuro/representations/aggregate.py 
b/src/main/python/systemds/scuro/representations/aggregate.py
index 756e6271ea..2c046dc401 100644
--- a/src/main/python/systemds/scuro/representations/aggregate.py
+++ b/src/main/python/systemds/scuro/representations/aggregate.py
@@ -52,12 +52,13 @@ class Aggregation:
             aggregation_function = params["aggregation_function"]
             pad_modality = params["pad_modality"]
 
-        if aggregation_function not in self._aggregation_function.keys():
+        if aggregation_function not in list(self._aggregation_function.keys()):
             raise ValueError("Invalid aggregation function")
 
         self._aggregation_func = 
self._aggregation_function[aggregation_function]
         self.name = "Aggregation"
         self.pad_modality = pad_modality
+        self.aggregation_function_name = aggregation_function
 
         self.parameters = {
             "aggregation_function": aggregation_function,
@@ -91,7 +92,7 @@ class Aggregation:
                         padded_data.append(utils.pad_sequences(entry, max_len))
                     data[i] = padded_data
 
-        return data
+        return np.array(data)
 
     def transform(self, modality):
         return self.execute(modality)
@@ -100,4 +101,4 @@ class Aggregation:
         return self._aggregation_func(instance)
 
     def get_aggregation_functions(self):
-        return self._aggregation_function.keys()
+        return list(self._aggregation_function.keys())
diff --git a/src/main/python/systemds/scuro/representations/average.py 
b/src/main/python/systemds/scuro/representations/average.py
index 8a7e6b9ec8..ac51f5d1e8 100644
--- a/src/main/python/systemds/scuro/representations/average.py
+++ b/src/main/python/systemds/scuro/representations/average.py
@@ -18,7 +18,7 @@
 # under the License.
 #
 # -------------------------------------------------------------
-
+import copy
 from typing import List
 
 import numpy as np
@@ -37,23 +37,14 @@ class Average(Fusion):
         Combines modalities using averaging
         """
         super().__init__("Average")
+        self.needs_alignment = True
         self.associative = True
         self.commutative = True
 
-    def transform(self, modalities: List[Modality]):
-        for modality in modalities:
-            modality.flatten()
-
-        max_emb_size = self.get_max_embedding_size(modalities)
-
-        padded_modalities = []
-        for modality in modalities:
-            d = pad_sequences(modality.data, maxlen=max_emb_size, 
dtype="float32")
-            padded_modalities.append(d)
-
-        data = padded_modalities[0]
+    def execute(self, modalities: List[Modality]):
+        data = copy.deepcopy(modalities[0].data)
         for i in range(1, len(modalities)):
-            data += padded_modalities[i]
+            data += modalities[i].data
 
         data /= len(modalities)
 
diff --git a/src/main/python/systemds/scuro/representations/bert.py 
b/src/main/python/systemds/scuro/representations/bert.py
index 8d8d40f4fd..3478b84e67 100644
--- a/src/main/python/systemds/scuro/representations/bert.py
+++ b/src/main/python/systemds/scuro/representations/bert.py
@@ -18,7 +18,7 @@
 # under the License.
 #
 # -------------------------------------------------------------
-
+import numpy as np
 from systemds.scuro.modality.transformed import TransformedModality
 from systemds.scuro.representations.unimodal import UnimodalRepresentation
 import torch
@@ -34,12 +34,13 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 @register_representation(ModalityType.TEXT)
 class Bert(UnimodalRepresentation):
-    def __init__(self, model_name="bert", output_file=None):
+    def __init__(self, model_name="bert", output_file=None, 
max_seq_length=512):
         parameters = {"model_name": "bert"}
         self.model_name = model_name
         super().__init__("Bert", ModalityType.EMBEDDING, parameters)
 
         self.output_file = output_file
+        self.max_seq_length = max_seq_length
 
     def transform(self, modality):
         transformed_modality = TransformedModality(modality, self)
@@ -55,32 +56,33 @@ class Bert(UnimodalRepresentation):
         if self.output_file is not None:
             save_embeddings(embeddings, self.output_file)
 
+        transformed_modality.data_type = np.float32
         transformed_modality.data = embeddings
         return transformed_modality
 
     def create_embeddings(self, modality, model, tokenizer):
-        embeddings = []
-        for i, d in enumerate(modality.data):
-            inputs = tokenizer(
-                d,
-                return_offsets_mapping=True,
-                return_tensors="pt",
-                padding=True,
-                truncation=True,
-            )
-
-            ModalityType.TEXT.add_field(
-                list(modality.metadata.values())[i],
-                "token_to_character_mapping",
-                inputs.data["offset_mapping"][0].tolist(),
-            )
+        inputs = tokenizer(
+            modality.data,
+            return_offsets_mapping=True,
+            return_tensors="pt",
+            padding="longest",
+            return_attention_mask=True,
+            truncation=True,
+            max_length=self.max_seq_length,
+        )
+        ModalityType.TEXT.add_field_for_instances(
+            modality.metadata,
+            "token_to_character_mapping",
+            inputs.data["offset_mapping"].tolist(),
+        )

-            del inputs.data["offset_mapping"]
+        ModalityType.TEXT.add_field_for_instances(
+            modality.metadata, "attention_masks", inputs.data["attention_mask"].tolist()
+        )
+        del inputs.data["offset_mapping"]

-            with torch.no_grad():
-                outputs = model(**inputs)
+        with torch.no_grad():
+            outputs = model(**inputs)

-                cls_embedding = outputs.last_hidden_state[0].numpy()
-                embeddings.append(cls_embedding)
+            token_embeddings = outputs.last_hidden_state.detach().numpy()
-
-        return embeddings
+        return token_embeddings
diff --git a/src/main/python/systemds/scuro/representations/bow.py 
b/src/main/python/systemds/scuro/representations/bow.py
index 6778811c49..7cfddbb506 100644
--- a/src/main/python/systemds/scuro/representations/bow.py
+++ b/src/main/python/systemds/scuro/representations/bow.py
@@ -18,7 +18,7 @@
 # under the License.
 #
 # -------------------------------------------------------------
-
+import numpy as np
 from sklearn.feature_extraction.text import CountVectorizer
 
 from systemds.scuro.modality.transformed import TransformedModality
@@ -50,5 +50,6 @@ class BoW(UnimodalRepresentation):
         if self.output_file is not None:
             save_embeddings(X, self.output_file)
 
+        transformed_modality.data_type = np.float32
         transformed_modality.data = X
         return transformed_modality
diff --git a/src/main/python/systemds/scuro/representations/concatenation.py 
b/src/main/python/systemds/scuro/representations/concatenation.py
index c7ce33ab5c..a4d4d53c43 100644
--- a/src/main/python/systemds/scuro/representations/concatenation.py
+++ b/src/main/python/systemds/scuro/representations/concatenation.py
@@ -20,7 +20,7 @@
 # -------------------------------------------------------------
 
 from typing import List
-
+import copy
 import numpy as np
 
 from systemds.scuro.modality.modality import Modality
@@ -33,14 +33,13 @@ from systemds.scuro.drsearch.operator_registry import 
register_fusion_operator
 
 @register_fusion_operator()
 class Concatenation(Fusion):
-    def __init__(self, padding=True):
+    def __init__(self):
         """
         Combines modalities using concatenation
         """
         super().__init__("Concatenation")
-        self.padding = padding
 
-    def transform(self, modalities: List[Modality]):
+    def execute(self, modalities: List[Modality]):
         if len(modalities) == 1:
             return np.array(modalities[0].data)
 
@@ -53,19 +52,6 @@ class Concatenation(Fusion):
             data = np.zeros((size, 0))
 
         for modality in modalities:
-            if self.padding:
-                data = np.concatenate(
-                    [
-                        data,
-                        pad_sequences(
-                            modality.data,
-                            maxlen=max_emb_size,
-                            dtype=modality.data.dtype,
-                        ),
-                    ],
-                    axis=-1,
-                )
-            else:
-                data = np.concatenate([data, modality.data], axis=-1)
+            data = np.concatenate([data, copy.deepcopy(modality.data)], 
axis=-1)
 
         return np.array(data)
diff --git a/src/main/python/systemds/scuro/representations/fusion.py 
b/src/main/python/systemds/scuro/representations/fusion.py
index cbbb5606e6..4b746eee21 100644
--- a/src/main/python/systemds/scuro/representations/fusion.py
+++ b/src/main/python/systemds/scuro/representations/fusion.py
@@ -21,9 +21,11 @@
 from typing import List
 
 import numpy as np
+from systemds.scuro import AggregatedRepresentation, Aggregation
 
 from systemds.scuro.modality.modality import Modality
 from systemds.scuro.representations.representation import Representation
+from systemds.scuro.utils.schema_helpers import get_shape
 
 
 class Fusion(Representation):
@@ -44,6 +46,21 @@ class Fusion(Representation):
         :param modalities: List of modalities used in the fusion
         :return: fused data
         """
+        mods = []
+        for modality in modalities:
+            agg_modality = None
+            if get_shape(modality.metadata) > 1:
+                agg_operator = AggregatedRepresentation(Aggregation())
+                agg_modality = agg_operator.transform(modality)
+            mods.append(agg_modality if agg_modality else modality)
+
+        if self.needs_alignment:
+            max_len = self.get_max_embedding_size(mods)
+            for modality in mods:
+                modality.pad(max_len=max_len)
+        return self.execute(mods)
+
+    def execute(self, modalities: List[Modality]):
-        raise f"Not implemented for Fusion: {self.name}"
+        raise NotImplementedError(f"Not implemented for Fusion: {self.name}")
 
     def get_max_embedding_size(self, modalities: List[Modality]):
diff --git a/src/main/python/systemds/scuro/representations/glove.py 
b/src/main/python/systemds/scuro/representations/glove.py
index d948567f3f..9076efecfc 100644
--- a/src/main/python/systemds/scuro/representations/glove.py
+++ b/src/main/python/systemds/scuro/representations/glove.py
@@ -67,5 +67,6 @@ class GloVe(UnimodalRepresentation):
         if self.output_file is not None:
             save_embeddings(np.array(embeddings), self.output_file)
 
+        transformed_modality.data_type = np.float32
         transformed_modality.data = np.array(embeddings)
         return transformed_modality
diff --git a/src/main/python/systemds/scuro/representations/hadamard.py 
b/src/main/python/systemds/scuro/representations/hadamard.py
index 138003b874..a777768ff6 100644
--- a/src/main/python/systemds/scuro/representations/hadamard.py
+++ b/src/main/python/systemds/scuro/representations/hadamard.py
@@ -41,8 +41,7 @@ class Hadamard(Fusion):
         self.commutative = True
         self.associative = True
 
-    def transform(self, modalities: List[Modality], train_indices=None):
-        # TODO: check for alignment in the metadata
+    def execute(self, modalities: List[Modality], train_indices=None):
         fused_data = np.prod([m.data for m in modalities], axis=0)
 
         return fused_data
diff --git a/src/main/python/systemds/scuro/representations/lstm.py 
b/src/main/python/systemds/scuro/representations/lstm.py
index cbab0f6897..0cfafddefa 100644
--- a/src/main/python/systemds/scuro/representations/lstm.py
+++ b/src/main/python/systemds/scuro/representations/lstm.py
@@ -64,7 +64,7 @@ class LSTM(Fusion):
         result = np.zeros((size, 0))
 
         for modality in modalities:
-            if modality.modality_type in self.unimodal_embeddings.keys():
+            if modality.modality_type in list(self.unimodal_embeddings.keys()):
                 out = self.unimodal_embeddings.get(modality.modality_type)
             else:
                 out = self.run_lstm(modality.data)
diff --git a/src/main/python/systemds/scuro/representations/max.py 
b/src/main/python/systemds/scuro/representations/max.py
index 6ecf5fd52f..39f5069c2b 100644
--- a/src/main/python/systemds/scuro/representations/max.py
+++ b/src/main/python/systemds/scuro/representations/max.py
@@ -40,11 +40,9 @@ class RowMax(Fusion):
         self.associative = True
         self.commutative = True
 
-    def transform(
+    def execute(
         self,
         modalities: List[Modality],
     ):
-        # TODO: need to check if data is aligned - same number of dimension
         fused_data = np.maximum.reduce([m.data for m in modalities])
-
         return fused_data
diff --git a/src/main/python/systemds/scuro/representations/mel_spectrogram.py 
b/src/main/python/systemds/scuro/representations/mel_spectrogram.py
index 8c14c03ac6..8e897542b0 100644
--- a/src/main/python/systemds/scuro/representations/mel_spectrogram.py
+++ b/src/main/python/systemds/scuro/representations/mel_spectrogram.py
@@ -46,19 +46,18 @@ class MelSpectrogram(UnimodalRepresentation):
             modality, self, self.output_modality_type
         )
         result = []
-        max_length = 0
+
         for i, sample in enumerate(modality.data):
             sr = list(modality.metadata.values())[i]["frequency"]
             S = librosa.feature.melspectrogram(
-                y=sample,
+                y=np.array(sample),
                 sr=sr,
                 n_mels=self.n_mels,
                 hop_length=self.hop_length,
                 n_fft=self.n_fft,
-            )
+            ).astype(modality.data_type)
             S_dB = librosa.power_to_db(S, ref=np.max)
-            if S_dB.shape[-1] > max_length:
-                max_length = S_dB.shape[-1]
+
             result.append(S_dB.T)
 
         transformed_modality.data = result
diff --git a/src/main/python/systemds/scuro/representations/mfcc.py 
b/src/main/python/systemds/scuro/representations/mfcc.py
index 234e93246f..00f735a756 100644
--- a/src/main/python/systemds/scuro/representations/mfcc.py
+++ b/src/main/python/systemds/scuro/representations/mfcc.py
@@ -48,20 +48,19 @@ class MFCC(UnimodalRepresentation):
             modality, self, self.output_modality_type
         )
         result = []
-        max_length = 0
+
         for i, sample in enumerate(modality.data):
             sr = list(modality.metadata.values())[i]["frequency"]
             mfcc = librosa.feature.mfcc(
-                y=sample,
+                y=np.array(sample),
                 sr=sr,
                 n_mfcc=self.n_mfcc,
                 dct_type=self.dct_type,
                 hop_length=self.hop_length,
                 n_mels=self.n_mels,
-            )
+            ).astype(modality.data_type)
             mfcc = (mfcc - np.mean(mfcc)) / np.std(mfcc)
-            if mfcc.shape[-1] > max_length:  # TODO: check if this needs to be 
done
-                max_length = mfcc.shape[-1]
+
             result.append(mfcc.T)
 
         transformed_modality.data = result
diff --git a/src/main/python/systemds/scuro/representations/representation.py 
b/src/main/python/systemds/scuro/representations/representation.py
index a9f283b6fe..6137baf46d 100644
--- a/src/main/python/systemds/scuro/representations/representation.py
+++ b/src/main/python/systemds/scuro/representations/representation.py
@@ -32,7 +32,7 @@ class Representation:
 
     def get_current_parameters(self):
         current_params = {}
-        for parameter in self.parameters.keys():
+        for parameter in list(self.parameters.keys()):
             current_params[parameter] = getattr(self, parameter)
         return current_params
 
diff --git a/src/main/python/systemds/scuro/representations/spectrogram.py 
b/src/main/python/systemds/scuro/representations/spectrogram.py
index 6a713a3d21..8daa9abb01 100644
--- a/src/main/python/systemds/scuro/representations/spectrogram.py
+++ b/src/main/python/systemds/scuro/representations/spectrogram.py
@@ -41,15 +41,14 @@ class Spectrogram(UnimodalRepresentation):
             modality, self, self.output_modality_type
         )
         result = []
-        max_length = 0
+
         for i, sample in enumerate(modality.data):
             spectrogram = librosa.stft(
-                y=sample, hop_length=self.hop_length, n_fft=self.n_fft
-            )
+                y=np.array(sample), hop_length=self.hop_length, 
n_fft=self.n_fft
+            ).astype(modality.data_type)
             S_dB = librosa.amplitude_to_db(np.abs(spectrogram))
-            if S_dB.shape[-1] > max_length:
-                max_length = S_dB.shape[-1]
-            result.append(S_dB.T)
+
+            result.append(S_dB.T.reshape(-1))
 
         transformed_modality.data = result
         return transformed_modality
diff --git a/src/main/python/systemds/scuro/representations/sum.py 
b/src/main/python/systemds/scuro/representations/sum.py
index 46d93f2eda..5b3710b6e1 100644
--- a/src/main/python/systemds/scuro/representations/sum.py
+++ b/src/main/python/systemds/scuro/representations/sum.py
@@ -37,15 +37,12 @@ class Sum(Fusion):
         Combines modalities using colum-wise sum
         """
         super().__init__("Sum")
+        self.needs_alignment = True
 
-    def transform(self, modalities: List[Modality]):
-        max_emb_size = self.get_max_embedding_size(modalities)
-
-        data = pad_sequences(modalities[0].data, maxlen=max_emb_size, 
dtype="float32")
+    def execute(self, modalities: List[Modality]):
+        data = modalities[0].data.copy()  # copy so += does not mutate the input

         for m in range(1, len(modalities)):
-            data += pad_sequences(
-                modalities[m].data, maxlen=max_emb_size, dtype="float32"
-            )
+            data += modalities[m].data

         return data
diff --git a/src/main/python/systemds/scuro/representations/tfidf.py 
b/src/main/python/systemds/scuro/representations/tfidf.py
index 1df5a1fde0..3b8f069df8 100644
--- a/src/main/python/systemds/scuro/representations/tfidf.py
+++ b/src/main/python/systemds/scuro/representations/tfidf.py
@@ -43,10 +43,10 @@ class TfIdf(UnimodalRepresentation):
         vectorizer = TfidfVectorizer(min_df=self.min_df)
 
         X = vectorizer.fit_transform(modality.data)
-        X = [np.array(x).reshape(1, -1) for x in X.toarray()]
-
+        X = [np.array(x).astype(np.float32).reshape(1, -1) for x in 
X.toarray()]
         if self.output_file is not None:
             save_embeddings(X, self.output_file)
 
+        transformed_modality.data_type = np.float32
         transformed_modality.data = X
         return transformed_modality
diff --git a/src/main/python/systemds/scuro/representations/wav2vec.py 
b/src/main/python/systemds/scuro/representations/wav2vec.py
index 29f5bcbea0..86145e3769 100644
--- a/src/main/python/systemds/scuro/representations/wav2vec.py
+++ b/src/main/python/systemds/scuro/representations/wav2vec.py
@@ -52,7 +52,9 @@ class Wav2Vec(UnimodalRepresentation):
         result = []
         for i, sample in enumerate(modality.data):
             sr = list(modality.metadata.values())[i]["frequency"]
-            audio_resampled = librosa.resample(sample, orig_sr=sr, 
target_sr=16000)
+            audio_resampled = librosa.resample(
+                np.array(sample), orig_sr=sr, target_sr=16000
+            )
             input = self.processor(
                 audio_resampled, sampling_rate=16000, return_tensors="pt", 
padding=True
             )
diff --git 
a/src/main/python/systemds/scuro/representations/window_aggregation.py 
b/src/main/python/systemds/scuro/representations/window_aggregation.py
index bff63729c7..167f4adafe 100644
--- a/src/main/python/systemds/scuro/representations/window_aggregation.py
+++ b/src/main/python/systemds/scuro/representations/window_aggregation.py
@@ -21,7 +21,7 @@
 import numpy as np
 import math
 
-from systemds.scuro.modality.type import DataLayout
+from systemds.scuro.modality.type import DataLayout, ModalityType
 
 from systemds.scuro.drsearch.operator_registry import register_context_operator
 from systemds.scuro.representations.aggregate import Aggregation
@@ -30,7 +30,7 @@ from systemds.scuro.representations.context import Context
 
 @register_context_operator()
 class WindowAggregation(Context):
-    def __init__(self, window_size=10, aggregation_function="mean"):
+    def __init__(self, window_size=10, aggregation_function="mean", pad=True):
         parameters = {
             "window_size": [window_size],
             "aggregation_function": 
list(Aggregation().get_aggregation_functions()),
@@ -38,6 +38,7 @@ class WindowAggregation(Context):
         super().__init__("WindowAggregation", parameters)
         self.window_size = window_size
         self.aggregation_function = aggregation_function
+        self.pad = pad
 
     @property
     def aggregation_function(self):
@@ -49,6 +50,7 @@ class WindowAggregation(Context):
 
     def execute(self, modality):
         windowed_data = []
+        original_lengths = []
         for instance in modality.data:
             new_length = math.ceil(len(instance) / self.window_size)
             if modality.get_data_layout() == DataLayout.SINGLE_LEVEL:
@@ -59,14 +61,53 @@ class WindowAggregation(Context):
                 windowed_instance = self.window_aggregate_nested_level(
                     instance, new_length
                 )
-
+            original_lengths.append(new_length)
             windowed_data.append(windowed_instance)
 
+        if self.pad and not isinstance(windowed_data, np.ndarray):
+            target_length = max(original_lengths)
+            sample_shape = windowed_data[0].shape
+            is_1d = len(sample_shape) == 1
+
+            padded_features = []
+            for i, features in enumerate(windowed_data):
+                current_len = original_lengths[i]
+
+                if current_len < target_length:
+                    padding_needed = target_length - current_len
+
+                    if is_1d:
+                        padding = np.zeros(padding_needed)
+                        padded = np.concatenate([features, padding])
+                    else:
+                        feature_dim = features.shape[-1]
+                        padding = np.zeros((padding_needed, feature_dim))
+                        padded = np.concatenate([features, padding], axis=0)
+
+                    padded_features.append(padded)
+                else:
+                    padded_features.append(features)
+
+            attention_masks = np.zeros((len(windowed_data), target_length))
+            for i, length in enumerate(original_lengths):
+                actual_length = min(length, target_length)
+                attention_masks[i, :actual_length] = 1
+
+            ModalityType(modality.modality_type).add_field_for_instances(
+                modality.metadata, "attention_masks", attention_masks
+            )
+
+            windowed_data = np.array(padded_features)
+            data_type = 
list(modality.metadata.values())[0]["data_layout"]["type"]
+            if data_type != "str":
+                windowed_data = windowed_data.astype(data_type)
+
         return windowed_data
 
     def window_aggregate_single_level(self, instance, new_length):
         if isinstance(instance, str):
             return instance
+        instance = np.array(instance)
         num_cols = instance.shape[1] if instance.ndim > 1 else 1
         result = np.empty((new_length, num_cols))
         for i in range(0, new_length):
@@ -86,4 +127,4 @@ class WindowAggregation(Context):
                 data[i * self.window_size : i * self.window_size + 
self.window_size]
             )
 
-        return result
+        return np.array(result)
diff --git a/src/main/python/systemds/scuro/representations/word2vec.py 
b/src/main/python/systemds/scuro/representations/word2vec.py
index 0210207a01..88d60ac828 100644
--- a/src/main/python/systemds/scuro/representations/word2vec.py
+++ b/src/main/python/systemds/scuro/representations/word2vec.py
@@ -65,9 +65,10 @@ class W2V(UnimodalRepresentation):
         embeddings = []
         for sentences in modality.data:
             tokens = list(tokenize(sentences.lower()))
-            embeddings.append(np.array(get_embedding(tokens, 
model)).reshape(1, -1))
+            embeddings.append(np.array(get_embedding(tokens, 
model)).astype(np.float32))
 
         if self.output_file is not None:
             save_embeddings(np.array(embeddings), self.output_file)
-        transformed_modality.data = embeddings
+        transformed_modality.data_type = np.float32
+        transformed_modality.data = np.array(embeddings)
         return transformed_modality
diff --git a/src/main/python/systemds/scuro/utils/schema_helpers.py 
b/src/main/python/systemds/scuro/utils/schema_helpers.py
index 28af476cca..3d1fbf4d71 100644
--- a/src/main/python/systemds/scuro/utils/schema_helpers.py
+++ b/src/main/python/systemds/scuro/utils/schema_helpers.py
@@ -40,3 +40,7 @@ def calculate_new_frequency(new_length, old_length, 
old_frequency):
     duration = old_length / old_frequency
     new_frequency = new_length / duration
     return new_frequency
+
+
+def get_shape(metadata):
+    return len(list(metadata.values())[0]["data_layout"]["shape"])
diff --git a/src/main/python/systemds/scuro/representations/representation.py 
b/src/main/python/systemds/scuro/utils/static_variables.py
similarity index 61%
copy from src/main/python/systemds/scuro/representations/representation.py
copy to src/main/python/systemds/scuro/utils/static_variables.py
index a9f283b6fe..8237cdf1b3 100644
--- a/src/main/python/systemds/scuro/representations/representation.py
+++ b/src/main/python/systemds/scuro/utils/static_variables.py
@@ -18,24 +18,19 @@
 # under the License.
 #
 # -------------------------------------------------------------
-from abc import abstractmethod
+import numpy as np
+import torch
 
+global_rng = np.random.default_rng(42)
 
-class Representation:
-    def __init__(self, name, parameters):
-        self.name = name
-        self._parameters = parameters
 
-    @property
-    def parameters(self):
-        return self._parameters
+def get_seed():
+    return global_rng.integers(0, 1024)
 
-    def get_current_parameters(self):
-        current_params = {}
-        for parameter in self.parameters.keys():
-            current_params[parameter] = getattr(self, parameter)
-        return current_params
 
-    def set_parameters(self, parameters):
-        for parameter in parameters:
-            setattr(self, parameter, parameters[parameter])
+def get_device():
+    return torch.device(
+        "cuda:0"
+        if torch.cuda.is_available()
+        else "mps" if torch.backends.mps.is_available() else "cpu"
+    )
diff --git a/src/main/python/tests/scuro/data_generator.py 
b/src/main/python/tests/scuro/data_generator.py
index fbb50ac180..e57716fa99 100644
--- a/src/main/python/tests/scuro/data_generator.py
+++ b/src/main/python/tests/scuro/data_generator.py
@@ -93,10 +93,14 @@ class ModalityRandomDataGenerator:
         self.modality_id += 1
         return tf_modality
 
-    def create_audio_data(self, num_instances, num_features):
-        data = np.random.rand(num_instances, num_features).astype(np.float32)
+    def create_audio_data(self, num_instances, max_audio_length):
+        data = [
+            [random.random() for _ in range(random.randint(1, 
max_audio_length))]
+            for _ in range(num_instances)
+        ]
+
         metadata = {
-            i: ModalityType.AUDIO.create_audio_metadata(16000, data[i])
+            i: ModalityType.AUDIO.create_audio_metadata(16000, 
np.array(data[i]))
             for i in range(num_instances)
         }
 
@@ -165,26 +169,41 @@ class ModalityRandomDataGenerator:
 
         return sentences, metadata
 
-    def create_visual_modality(self, num_instances, num_frames=1, height=28, 
width=28):
-        if num_frames == 1:
+    def create_visual_modality(
+        self, num_instances, max_num_frames=1, height=28, width=28
+    ):
+        data = [
+            np.random.randint(
+                0,
+                256,
+                (np.random.randint(5, max_num_frames + 1), height, width, 3),
+                dtype=np.uint8,
+            )
+            for _ in range(num_instances)
+        ]
+        if max_num_frames == 1:
             print(f"TODO: create image metadata")
         else:
             metadata = {
                 i: ModalityType.VIDEO.create_video_metadata(
-                    30, num_frames, width, height, 1
+                    30, data[i].shape[0], width, height, 3
                 )
                 for i in range(num_instances)
             }
 
-        return (
-            np.random.randint(
-                0,
-                256,
-                (num_instances, num_frames, height, width),
-                # ).astype(np.float16).tolist(),
-            ).astype(np.float16),
-            metadata,
-        )
+        return (data, metadata)
+
+    def create_balanced_labels(self, num_instances, num_classes=2):
+        """Return a shuffled label vector with equally many instances per class."""
+        if num_instances % num_classes != 0:
+            raise ValueError("num_instances must be divisible by num_classes.")
+
+        class_size = num_instances // num_classes
+        # Labels 0..num_classes-1, each repeated class_size times.
+        vector = np.repeat(np.arange(num_classes), class_size)
+
+        np.random.shuffle(vector)
+        return vector
 
 
 def setup_data(modalities, num_instances, path):
diff --git a/src/main/python/tests/scuro/test_dr_search.py 
b/src/main/python/tests/scuro/test_dr_search.py
index 50f57eebb2..3e0e702e6f 100644
--- a/src/main/python/tests/scuro/test_dr_search.py
+++ b/src/main/python/tests/scuro/test_dr_search.py
@@ -94,7 +94,9 @@ class TestDataLoaders(unittest.TestCase):
         cls.num_instances = 20
         cls.data_generator = ModalityRandomDataGenerator()
 
-        cls.labels = np.random.choice([0, 1], size=cls.num_instances)
+        cls.labels = ModalityRandomDataGenerator().create_balanced_labels(
+            num_instances=cls.num_instances
+        )
         # TODO: adapt the representation so they return non aggregated values. 
Apply windowing operation instead
 
         cls.video = cls.data_generator.create1DModality(
diff --git a/src/main/python/tests/scuro/test_fusion_orders.py 
b/src/main/python/tests/scuro/test_fusion_orders.py
index eb01d18ffe..22d64bcc0b 100644
--- a/src/main/python/tests/scuro/test_fusion_orders.py
+++ b/src/main/python/tests/scuro/test_fusion_orders.py
@@ -65,7 +65,7 @@ class TestFusionOrders(unittest.TestCase):
 
         self.assertFalse(np.array_equal(r_1_r_2.data, r_2_r_1.data))
         self.assertFalse(np.array_equal(r_1_r_2_r_3.data, r_2_r_1_r_3.data))
-        self.assertFalse(np.array_equal(r_1_r_2_r_3.data, r1_r2_r3.data))
+        self.assertFalse(np.array_equal(r_2_r_1.data, r1_r2_r3.data))
         self.assertFalse(np.array_equal(r_1_r_2.data, r1_r2_r3.data))
 
     def test_fusion_order_max(self):
diff --git a/src/main/python/tests/scuro/test_multimodal_fusion.py 
b/src/main/python/tests/scuro/test_multimodal_fusion.py
index 77f03054eb..ae3ddedffb 100644
--- a/src/main/python/tests/scuro/test_multimodal_fusion.py
+++ b/src/main/python/tests/scuro/test_multimodal_fusion.py
@@ -22,12 +22,15 @@
 
 import shutil
 import unittest
+from multiprocessing import freeze_support
 
 import numpy as np
 from sklearn import svm
 from sklearn.metrics import classification_report
 from sklearn.model_selection import train_test_split
 
+from systemds.scuro.drsearch.multimodal_optimizer import MultimodalOptimizer
+from systemds.scuro.drsearch.unimodal_optimizer import UnimodalOptimizer
 from systemds.scuro.representations.concatenation import Concatenation
 from systemds.scuro.representations.average import Average
 from systemds.scuro.drsearch.fusion_optimizer import FusionOptimizer
@@ -115,7 +118,9 @@ class 
TestMultimodalRepresentationOptimizer(unittest.TestCase):
     def setUpClass(cls):
         cls.num_instances = 10
         cls.mods = [ModalityType.VIDEO, ModalityType.AUDIO, ModalityType.TEXT]
-        cls.labels = np.random.choice([0, 1], size=cls.num_instances)
+        cls.labels = ModalityRandomDataGenerator().create_balanced_labels(
+            num_instances=cls.num_instances
+        )
         cls.indices = np.array(range(cls.num_instances))
 
         split = train_test_split(
@@ -123,31 +128,15 @@ class 
TestMultimodalRepresentationOptimizer(unittest.TestCase):
             cls.labels,
             test_size=0.2,
             random_state=42,
+            stratify=cls.labels,
         )
         cls.train_indizes, cls.val_indizes = [int(i) for i in split[0]], [
             int(i) for i in split[1]
         ]
 
-        cls.tasks = [
-            Task(
-                "UnimodalRepresentationTask1",
-                TestSVM(),
-                cls.labels,
-                cls.train_indizes,
-                cls.val_indizes,
-            ),
-            Task(
-                "UnimodalRepresentationTask2",
-                TestCNN(),
-                cls.labels,
-                cls.train_indizes,
-                cls.val_indizes,
-            ),
-        ]
-
     def test_multimodal_fusion(self):
         task = Task(
-            "UnimodalRepresentationTask1",
+            "MM_Fusion_Task1",
             TestSVM(),
             self.labels,
             self.train_indizes,
@@ -192,22 +181,47 @@ class 
TestMultimodalRepresentationOptimizer(unittest.TestCase):
         ):
             registry = Registry()
             registry._fusion_operators = [Average, Concatenation]
-            unimodal_optimizer = UnimodalRepresentationOptimizer(
-                [text, audio, video], [task], max_chain_depth=2
+            unimodal_optimizer = UnimodalOptimizer(
+                [audio, text, video], [task], debug=False
             )
             unimodal_optimizer.optimize()
+            unimodal_optimizer.operator_performance.get_k_best_results(audio, 
2, task)
 
-            multimodal_optimizer = FusionOptimizer(
+            multimodal_optimizer = MultimodalOptimizer(
                 [audio, text, video],
-                task,
-                unimodal_optimizer.optimization_results,
-                unimodal_optimizer.cache,
-                2,
-                2,
+                unimodal_optimizer.operator_performance,
+                [task],
                 debug=False,
             )
+
             multimodal_optimizer.optimize()
 
+            assert (
+                
len(multimodal_optimizer.optimization_results.results["TestSVM"].keys())
+                == 57
+            )
+            assert (
+                len(
+                    
multimodal_optimizer.optimization_results.results["TestSVM"][
+                        "0_1_2_3_4_5"
+                    ]
+                )
+                == 62
+            )
+            assert (
+                len(
+                    
multimodal_optimizer.optimization_results.results["TestSVM"][
+                        "3_4_5"
+                    ]
+                )
+                == 6
+            )
+            assert (
+                
len(multimodal_optimizer.optimization_results.results["TestSVM"]["0_1"])
+                == 2
+            )
+
 
 if __name__ == "__main__":
+    freeze_support()
     unittest.main()
diff --git a/src/main/python/tests/scuro/test_multimodal_join.py 
b/src/main/python/tests/scuro/test_multimodal_join.py
index 9e3a16ffca..5fd22dc8d9 100644
--- a/src/main/python/tests/scuro/test_multimodal_join.py
+++ b/src/main/python/tests/scuro/test_multimodal_join.py
@@ -20,7 +20,6 @@
 
 # TODO: Test edge cases: unequal number of audio-video timestamps (should 
still work and add the average over all audio/video samples)
 
-import shutil
 import unittest
 
 import numpy as np
@@ -30,9 +29,6 @@ from systemds.scuro.modality.unimodal_modality import 
UnimodalModality
 from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
 from systemds.scuro.representations.resnet import ResNet
 from tests.scuro.data_generator import TestDataLoader, 
ModalityRandomDataGenerator
-
-from systemds.scuro.dataloader.audio_loader import AudioLoader
-from systemds.scuro.dataloader.video_loader import VideoLoader
 from systemds.scuro.modality.type import ModalityType
 
 
diff --git a/src/main/python/tests/scuro/test_unimodal_optimizer.py 
b/src/main/python/tests/scuro/test_unimodal_optimizer.py
index 9ed034e5fe..a73d7b5fcc 100644
--- a/src/main/python/tests/scuro/test_unimodal_optimizer.py
+++ b/src/main/python/tests/scuro/test_unimodal_optimizer.py
@@ -20,7 +20,6 @@
 # -------------------------------------------------------------
 
 
-import shutil
 import unittest
 
 import numpy as np
@@ -31,8 +30,8 @@ from sklearn.model_selection import train_test_split
 from systemds.scuro.drsearch.operator_registry import Registry
 from systemds.scuro.models.model import Model
 from systemds.scuro.drsearch.task import Task
-from systemds.scuro.drsearch.unimodal_representation_optimizer import (
-    UnimodalRepresentationOptimizer,
+from systemds.scuro.drsearch.unimodal_optimizer import (
+    UnimodalOptimizer,
 )
 
 from systemds.scuro.representations.spectrogram import Spectrogram
@@ -41,9 +40,6 @@ from systemds.scuro.modality.unimodal_modality import 
UnimodalModality
 from systemds.scuro.representations.resnet import ResNet
 from tests.scuro.data_generator import ModalityRandomDataGenerator, 
TestDataLoader
 
-from systemds.scuro.dataloader.audio_loader import AudioLoader
-from systemds.scuro.dataloader.video_loader import VideoLoader
-from systemds.scuro.dataloader.text_loader import TextLoader
 from systemds.scuro.modality.type import ModalityType
 
 
@@ -108,7 +104,9 @@ class 
TestUnimodalRepresentationOptimizer(unittest.TestCase):
     def setUpClass(cls):
         cls.num_instances = 10
         cls.mods = [ModalityType.VIDEO, ModalityType.AUDIO, ModalityType.TEXT]
-        cls.labels = np.random.choice([0, 1], size=cls.num_instances)
+        cls.labels = ModalityRandomDataGenerator().create_balanced_labels(
+            num_instances=cls.num_instances
+        )
         cls.indices = np.array(range(cls.num_instances))
 
         split = train_test_split(
@@ -186,24 +184,19 @@ class 
TestUnimodalRepresentationOptimizer(unittest.TestCase):
         ):
             registry = Registry()
 
-            unimodal_optimizer = UnimodalRepresentationOptimizer(
-                [modality], self.tasks, max_chain_depth=2
-            )
+            unimodal_optimizer = UnimodalOptimizer([modality], self.tasks, 
False)
             unimodal_optimizer.optimize()
 
             assert (
-                list(unimodal_optimizer.optimization_results.keys())[0]
+                unimodal_optimizer.operator_performance.modality_ids[0]
                 == modality.modality_id
             )
-            assert 
len(list(unimodal_optimizer.optimization_results.values())[0]) == 2
-            assert (
-                len(
-                    unimodal_optimizer.get_k_best_results(modality, 1, 
self.tasks[0])[
-                        0
-                    ].operator_chain
-                )
-                >= 1
+            assert len(unimodal_optimizer.operator_performance.task_names) == 2
+            result, cached = 
unimodal_optimizer.operator_performance.get_k_best_results(
+                modality, 1, self.tasks[0]
             )
+            assert len(result) == 1
+            assert len(cached) == 1
 
 
 if __name__ == "__main__":


Reply via email to