This is an automated email from the ASF dual-hosted git repository.

cdionysio pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 23b3d08eb1 [SYSTEMDS-3887] Create representation optimizer
23b3d08eb1 is described below

commit 23b3d08eb18305612ce038edd0fbd50a011aba9b
Author: Christina Dionysio <diony...@tu-berlin.de>
AuthorDate: Wed May 28 10:52:52 2025 +0200

    [SYSTEMDS-3887] Create representation optimizer
    
    This patch adds an initial version of the representation optimizer for the 
Scuro library. It is a two-stage optimization: in the first step, the best 
unimodal representation for the given raw modalities is found, and in the next 
step the k-best unimodal representations are combined into multimodal 
representations and evaluated against the target downstream task. Additionally, 
this patch adds tests for each stage of the optimizer.
    
    Closes #2267
---
 src/main/python/systemds/scuro/__init__.py         |  83 ++++--
 .../python/systemds/scuro/aligner/alignment.py     |  48 ----
 .../systemds/scuro/dataloader/audio_loader.py      |  11 +-
 .../scuro/{aligner => drsearch}/__init__.py        |   0
 .../scuro/{aligner => drsearch}/dr_search.py       |   4 +-
 .../systemds/scuro/drsearch/fusion_optimizer.py    | 295 +++++++++++++++++++++
 .../scuro/drsearch/hyperparameter_tuner.py         | 106 ++++++++
 .../systemds/scuro/drsearch/operator_registry.py   | 107 ++++++++
 .../systemds/scuro/drsearch/optimization_data.py   | 164 ++++++++++++
 .../scuro/drsearch/representation_cache.py         | 127 +++++++++
 .../{aligner => drsearch}/similarity_measures.py   |   0
 .../systemds/scuro/{aligner => drsearch}/task.py   |  22 +-
 .../drsearch/unimodal_representation_optimizer.py  | 271 +++++++++++++++++++
 src/main/python/systemds/scuro/main.py             |   4 +-
 src/main/python/systemds/scuro/modality/joined.py  |   6 +-
 .../python/systemds/scuro/modality/modality.py     |   2 +-
 .../systemds/scuro/modality/modality_identifier.py |   7 -
 .../python/systemds/scuro/modality/transformed.py  |   5 +-
 .../systemds/scuro/modality/unimodal_modality.py   |   1 -
 .../systemds/scuro/representations/aggregate.py    |  30 ++-
 .../aggregated_representation.py}                  |  29 +-
 .../systemds/scuro/representations/average.py      |   5 +
 .../python/systemds/scuro/representations/bert.py  |  11 +-
 .../python/systemds/scuro/representations/bow.py   |   2 +
 .../scuro/representations/concatenation.py         |   3 +
 .../systemds/scuro/representations/context.py      |   1 -
 .../python/systemds/scuro/representations/glove.py |   4 +-
 .../python/systemds/scuro/representations/lstm.py  |   3 +
 .../python/systemds/scuro/representations/max.py   |   3 +
 .../scuro/representations/mel_spectrogram.py       |  13 +-
 .../{mel_spectrogram.py => mfcc.py}                |  38 ++-
 .../scuro/representations/multiplication.py        |   3 +
 .../systemds/scuro/representations/optical_flow.py |  79 ++++++
 .../systemds/scuro/representations/resnet.py       |  84 ++----
 .../systemds/scuro/representations/rowmax.py       |  78 ------
 .../{mel_spectrogram.py => spectrogram.py}         |  23 +-
 .../python/systemds/scuro/representations/sum.py   |   3 +
 .../representations/swin_video_transformer.py      | 111 ++++++++
 .../python/systemds/scuro/representations/tfidf.py |   2 +
 .../{mel_spectrogram.py => wav2vec.py}             |  52 ++--
 .../systemds/scuro/representations/window.py       |   6 +-
 .../systemds/scuro/representations/word2vec.py     |   6 +-
 .../scuro/representations/{resnet.py => x3d.py}    | 123 +++------
 .../python/systemds/scuro/utils/schema_helpers.py  |   1 -
 .../python/systemds/scuro/utils/torch_dataset.py   |  63 +++++
 src/main/python/tests/scuro/data_generator.py      |  12 +-
 src/main/python/tests/scuro/test_dr_search.py      |   4 +-
 .../python/tests/scuro/test_multimodal_fusion.py   | 202 ++++++++++++++
 .../python/tests/scuro/test_multimodal_join.py     |   2 -
 .../python/tests/scuro/test_operator_registry.py   |  87 ++++++
 .../python/tests/scuro/test_unimodal_optimizer.py  | 203 ++++++++++++++
 51 files changed, 2133 insertions(+), 416 deletions(-)

diff --git a/src/main/python/systemds/scuro/__init__.py 
b/src/main/python/systemds/scuro/__init__.py
index 53b68d430f..4b2185316a 100644
--- a/src/main/python/systemds/scuro/__init__.py
+++ b/src/main/python/systemds/scuro/__init__.py
@@ -24,27 +24,55 @@ from systemds.scuro.dataloader.video_loader import 
VideoLoader
 from systemds.scuro.dataloader.text_loader import TextLoader
 from systemds.scuro.dataloader.json_loader import JSONLoader
 from systemds.scuro.representations.representation import Representation
+from systemds.scuro.representations.aggregate import Aggregation
+from systemds.scuro.representations.aggregated_representation import (
+    AggregatedRepresentation,
+)
 from systemds.scuro.representations.average import Average
+from systemds.scuro.representations.bert import Bert
+from systemds.scuro.representations.bow import BoW
 from systemds.scuro.representations.concatenation import Concatenation
-from systemds.scuro.representations.sum import Sum
+from systemds.scuro.representations.context import Context
+from systemds.scuro.representations.fusion import Fusion
+from systemds.scuro.representations.glove import GloVe
+from systemds.scuro.representations.lstm import LSTM
 from systemds.scuro.representations.max import RowMax
-from systemds.scuro.representations.multiplication import Multiplication
 from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
+from systemds.scuro.representations.mfcc import MFCC
+from systemds.scuro.representations.multiplication import Multiplication
+from systemds.scuro.representations.optical_flow import OpticalFlow
+from systemds.scuro.representations.representation import Representation
+from systemds.scuro.representations.representation_dataloader import NPY
+from systemds.scuro.representations.representation_dataloader import JSON
+from systemds.scuro.representations.representation_dataloader import Pickle
 from systemds.scuro.representations.resnet import ResNet
-from systemds.scuro.representations.bert import Bert
-from systemds.scuro.representations.lstm import LSTM
-from systemds.scuro.representations.bow import BoW
-from systemds.scuro.representations.glove import GloVe
+from systemds.scuro.representations.spectrogram import Spectrogram
+from systemds.scuro.representations.sum import Sum
+from systemds.scuro.representations.swin_video_transformer import 
SwinVideoTransformer
 from systemds.scuro.representations.tfidf import TfIdf
+from systemds.scuro.representations.unimodal import UnimodalRepresentation
+from systemds.scuro.representations.wav2vec import Wav2Vec
+from systemds.scuro.representations.window import WindowAggregation
 from systemds.scuro.representations.word2vec import W2V
+from systemds.scuro.representations.x3d import X3D
 from systemds.scuro.models.model import Model
 from systemds.scuro.models.discrete_model import DiscreteModel
+from systemds.scuro.modality.joined import JoinedModality
+from systemds.scuro.modality.joined_transformed import 
JoinedTransformedModality
 from systemds.scuro.modality.modality import Modality
-from systemds.scuro.modality.unimodal_modality import UnimodalModality
+from systemds.scuro.modality.modality_identifier import ModalityIdentifier
 from systemds.scuro.modality.transformed import TransformedModality
 from systemds.scuro.modality.type import ModalityType
-from systemds.scuro.aligner.dr_search import DRSearch
-from systemds.scuro.aligner.task import Task
+from systemds.scuro.modality.unimodal_modality import UnimodalModality
+from systemds.scuro.drsearch.dr_search import DRSearch
+from systemds.scuro.drsearch.task import Task
+from systemds.scuro.drsearch.fusion_optimizer import FusionOptimizer
+from systemds.scuro.drsearch.operator_registry import Registry
+from systemds.scuro.drsearch.optimization_data import OptimizationData
+from systemds.scuro.drsearch.representation_cache import RepresentationCache
+from systemds.scuro.drsearch.unimodal_representation_optimizer import (
+    UnimodalRepresentationOptimizer,
+)
 
 
 __all__ = [
@@ -53,25 +81,50 @@ __all__ = [
     "VideoLoader",
     "TextLoader",
     "Representation",
+    "Aggregation",
+    "AggregatedRepresentation",
     "Average",
+    "Bert",
+    "BoW",
     "Concatenation",
-    "Sum",
+    "Context",
+    "Fusion",
+    "GloVe",
+    "LSTM",
     "RowMax",
-    "Multiplication",
     "MelSpectrogram",
+    "MFCC",
+    "Multiplication",
+    "OpticalFlow",
+    "Representation",
+    "NPY",
+    "JSON",
+    "Pickle",
     "ResNet",
-    "Bert",
-    "LSTM",
+    "Spectrogram",
+    "Sum",
     "BoW",
-    "GloVe",
+    "SwinVideoTransformer",
     "TfIdf",
+    "UnimodalRepresentation",
+    "Wav2Vec",
+    "WindowAggregation",
     "W2V",
+    "X3D",
     "Model",
     "DiscreteModel",
+    "JoinedModality",
+    "JoinedTransformedModality",
     "Modality",
-    "UnimodalModality",
+    "ModalityIdentifier",
     "TransformedModality",
     "ModalityType",
+    "UnimodalModality",
     "DRSearch",
     "Task",
+    "FusionOptimizer",
+    "Registry",
+    "OptimizationData",
+    "RepresentationCache",
+    "UnimodalRepresentationOptimizer",
 ]
diff --git a/src/main/python/systemds/scuro/aligner/alignment.py 
b/src/main/python/systemds/scuro/aligner/alignment.py
deleted file mode 100644
index 62f88a272b..0000000000
--- a/src/main/python/systemds/scuro/aligner/alignment.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# -------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-# -------------------------------------------------------------
-from aligner.alignment_strategy import AlignmentStrategy
-from modality.modality import Modality
-from modality.representation import Representation
-from aligner.similarity_measures import Measure
-
-
-class Alignment:
-    def __init__(
-        self,
-        modality_a: Modality,
-        modality_b: Modality,
-        strategy: AlignmentStrategy,
-        similarity_measure: Measure,
-    ):
-        """
-        Defines the core of the library where the alignment of two modalities 
is performed
-        :param modality_a: first modality
-        :param modality_b: second modality
-        :param strategy: the alignment strategy used in the alignment process
-        :param similarity_measure: the similarity measure used to check the 
score of the alignment
-        """
-        self.modality_a = modality_a
-        self.modality_b = modality_b
-        self.strategy = strategy
-        self.similarity_measure = similarity_measure
-
-    def align_modalities(self) -> Modality:
-        return Modality(Representation())
diff --git a/src/main/python/systemds/scuro/dataloader/audio_loader.py 
b/src/main/python/systemds/scuro/dataloader/audio_loader.py
index a6a164b4fb..a008962680 100644
--- a/src/main/python/systemds/scuro/dataloader/audio_loader.py
+++ b/src/main/python/systemds/scuro/dataloader/audio_loader.py
@@ -27,13 +27,22 @@ from systemds.scuro.modality.type import ModalityType
 
 class AudioLoader(BaseLoader):
     def __init__(
-        self, source_path: str, indices: List[str], chunk_size: Optional[int] 
= None
+        self,
+        source_path: str,
+        indices: List[str],
+        chunk_size: Optional[int] = None,
+        normalize: bool = True,
     ):
         super().__init__(source_path, indices, chunk_size, ModalityType.AUDIO)
+        self.normalize = normalize
 
     def extract(self, file: str, index: Optional[Union[str, List[str]]] = 
None):
         self.file_sanity_check(file)
         audio, sr = librosa.load(file)
+
+        if self.normalize:
+            audio = librosa.util.normalize(audio)
+
         self.metadata[file] = self.modality_type.create_audio_metadata(sr, 
audio)
 
         self.data.append(audio)
diff --git a/src/main/python/systemds/scuro/aligner/__init__.py 
b/src/main/python/systemds/scuro/drsearch/__init__.py
similarity index 100%
rename from src/main/python/systemds/scuro/aligner/__init__.py
rename to src/main/python/systemds/scuro/drsearch/__init__.py
diff --git a/src/main/python/systemds/scuro/aligner/dr_search.py 
b/src/main/python/systemds/scuro/drsearch/dr_search.py
similarity index 98%
rename from src/main/python/systemds/scuro/aligner/dr_search.py
rename to src/main/python/systemds/scuro/drsearch/dr_search.py
index b46139dff3..2000608a1d 100644
--- a/src/main/python/systemds/scuro/aligner/dr_search.py
+++ b/src/main/python/systemds/scuro/drsearch/dr_search.py
@@ -22,7 +22,7 @@ import itertools
 import random
 from typing import List
 
-from systemds.scuro.aligner.task import Task
+from systemds.scuro.drsearch.task import Task
 from systemds.scuro.modality.modality import Modality
 from systemds.scuro.representations.representation import Representation
 
@@ -111,7 +111,7 @@ class DRSearch:
         representation = random.choice(self.representations)
 
         modality = modality_combination[0].combine(
-            modality_combination[1:], representation
+            list(modality_combination[1:]), representation
         )
 
         scores = self.task.run(modality.data)
diff --git a/src/main/python/systemds/scuro/drsearch/fusion_optimizer.py 
b/src/main/python/systemds/scuro/drsearch/fusion_optimizer.py
new file mode 100644
index 0000000000..7247720f55
--- /dev/null
+++ b/src/main/python/systemds/scuro/drsearch/fusion_optimizer.py
@@ -0,0 +1,295 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import time
+import copy
+import pickle
+from systemds.scuro.drsearch.operator_registry import Registry
+from systemds.scuro.drsearch.optimization_data import (
+    OptimizationResult,
+    OptimizationStatistics,
+)
+from systemds.scuro.drsearch.representation_cache import RepresentationCache
+from systemds.scuro.drsearch.task import Task
+from systemds.scuro.representations.aggregate import Aggregation
+from systemds.scuro.representations.context import Context
+
+
+def extract_names(operator_chain):
+    result = []
+    for op in operator_chain:
+        result.append(op.name)
+
+    return result
+
+
+class FusionOptimizer:
+    def __init__(
+        self,
+        modalities,
+        task: Task,
+        unimodal_representations_candidates,
+        representation_cache: RepresentationCache,
+        num_best_candidates=4,
+        max_chain_depth=5,
+        debug=False,
+    ):
+        self.modalities = modalities
+        self.task = task
+        self.unimodal_representations_candidates = 
unimodal_representations_candidates
+        self.num_best_candidates = num_best_candidates
+        self.k_best_candidates, self.candidates_per_modality = 
self.get_k_best_results(
+            num_best_candidates
+        )
+        self.operator_registry = Registry()
+        self.max_chain_depth = max_chain_depth
+        self.debug = debug
+        self.evaluated_candidates = set()
+        self.cache = representation_cache
+        self.optimization_statistics = 
OptimizationStatistics(self.k_best_candidates)
+        self.optimization_results = []
+
+    def optimize(self):
+        """
+        This method finds different ways in how to combine modalities and 
evaluates the fused representations against
+        the given task. It can fuse different representations from the same 
modality as well as fuse representations
+        from different modalities.
+        """
+
+        # TODO: add an aligned representation for all modalities with a 
temporal dimension
+        # TODO: keep a map of operator chains so that we don't evaluate them 
multiple times in different orders (if it does not make a difference)
+
+        r = []
+
+        for candidate in self.k_best_candidates:
+            modality = self.candidates_per_modality[str(candidate)]
+            cached_representation, representation_ops, used_op_names = (
+                self.cache.load_from_cache(modality, candidate.operator_chain)
+            )
+            if cached_representation is not None:
+                modality = cached_representation
+            store = False
+            for representation in representation_ops:
+                if isinstance(representation, Context):
+                    modality = modality.context(representation)
+                elif representation.name == "RowWiseConcatenation":
+                    modality = modality.flatten(True)
+                else:
+                    modality = modality.apply_representation(representation)
+                store = True
+            if store:
+                self.cache.save_to_cache(modality, used_op_names, 
representation_ops)
+
+            remaining_candidates = [c for c in self.k_best_candidates if c != 
candidate]
+            r.append(
+                self._optimize_candidate(modality, candidate, 
remaining_candidates, 1)
+            )
+
+        if self.debug:
+            with open(
+                
f"fusion_statistics_{self.task.model.name}_{self.num_best_candidates}_{self.max_chain_depth}.pkl",
+                "wb",
+            ) as fp:
+                pickle.dump(
+                    self.optimization_statistics,
+                    fp,
+                    protocol=pickle.HIGHEST_PROTOCOL,
+                )
+
+            opt_results = copy.deepcopy(self.optimization_results)
+            for i, opt_res in enumerate(self.optimization_results):
+                op_name = []
+                for op in opt_res.operator_chain:
+                    if isinstance(op, list):
+                        for o in op:
+                            if isinstance(o, list):
+                                for j in o:
+                                    op_name.append(j.name)
+                            elif isinstance(o, str):
+                                op_name.append(o)
+                            else:
+                                op_name.append(o.name)
+                    elif isinstance(op, str):
+                        op_name.append(op)
+                    else:
+                        op_name.append(op.name)
+                opt_results[i].operator_chain = op_name
+            with open(
+                
f"fusion_results_{self.task.model.name}_{self.num_best_candidates}_{self.max_chain_depth}.pkl",
+                "wb",
+            ) as fp:
+                pickle.dump(opt_results, fp, protocol=pickle.HIGHEST_PROTOCOL)
+
+            self.optimization_statistics.print_statistics()
+
+    def get_k_best_results(self, k: int):
+        """
+        Get the k best results per modality
+        :param k: number of best results
+        """
+        best_results = []
+        candidate_for_modality = {}
+        for modality in self.modalities:
+            k_results = sorted(
+                self.unimodal_representations_candidates[modality.modality_id][
+                    self.task.model.name
+                ],
+                key=lambda x: x.test_accuracy,
+                reverse=True,
+            )[:k]
+            for k_result in k_results:
+                candidate_for_modality[str(k_result)] = modality
+            best_results.extend(k_results)
+
+        return best_results, candidate_for_modality
+
+    def _optimize_candidate(
+        self, modality, candidate, remaining_candidates, chain_depth
+    ):
+        """
+        Optimize a single candidate by fusing it with others recursively.
+
+        :param candidate: The current candidate representation.
+        :param chain_depth: The current depth of fusion chains.
+        """
+        if chain_depth > self.max_chain_depth:
+            return
+
+        for other_candidate in remaining_candidates:
+            other_modality = self.candidates_per_modality[str(other_candidate)]
+            cached_representation, representation_ops, used_op_names = (
+                self.cache.load_from_cache(
+                    other_modality, other_candidate.operator_chain
+                )
+            )
+            if cached_representation is not None:
+                other_modality = cached_representation
+            store = False
+            for representation in representation_ops:
+                if representation.name == "Aggregation":
+                    params = other_candidate.parameters[representation.name]
+                    representation = Aggregation(
+                        aggregation_function=params["aggregation"]
+                    )
+                if isinstance(representation, Context):
+                    other_modality = other_modality.context(representation)
+                elif isinstance(representation, Aggregation):
+                    other_modality = representation.execute(other_modality)
+                elif representation.name == "RowWiseConcatenation":
+                    other_modality = other_modality.flatten(True)
+                else:
+                    other_modality = 
other_modality.apply_representation(representation)
+                store = True
+            if store:
+                self.cache.save_to_cache(
+                    other_modality, used_op_names, representation_ops
+                )
+
+            fusion_results = self.operator_registry.get_fusion_operators()
+            fusion_representation = None
+            for fusion_operator in fusion_results:
+                fusion_operator = fusion_operator()
+                chain_key = self.create_identifier(
+                    candidate, fusion_operator, other_candidate
+                )
+                # print(fusion_operator.name)
+                representation_start = time.time()
+                if (
+                    isinstance(fusion_operator, Context)
+                    and fusion_representation is not None
+                ):
+                    fusion_representation.context(fusion_operator)
+                elif isinstance(fusion_operator, Context):
+                    continue
+                else:
+                    fused_representation = modality.combine(
+                        other_modality, fusion_operator
+                    )
+
+                representation_end = time.time()
+                if chain_key not in self.evaluated_candidates:
+                    # Evaluate the fused representation
+
+                    score = self.task.run(fused_representation.data)
+                    fusion_params = {fusion_operator.name: 
fusion_operator.parameters}
+                    result = OptimizationResult(
+                        operator_chain=[
+                            candidate.operator_chain,
+                            fusion_operator.name,
+                            other_candidate.operator_chain,
+                        ],
+                        parameters=[
+                            candidate.parameters,
+                            fusion_params,
+                            other_candidate.parameters,
+                        ],
+                        train_accuracy=score[0],
+                        test_accuracy=score[1],
+                        # train_min_it_acc=score[2],
+                        # test_min_it_acc=score[3],
+                        training_runtime=self.task.training_time,
+                        inference_runtime=self.task.inference_time,
+                        representation_time=representation_end - 
representation_start,
+                        output_shape=(1, 1),  # TODO
+                    )
+
+                    # Store the result
+                    self.optimization_results.append(result)
+                    self.optimization_statistics.add_entry(
+                        [
+                            candidate.operator_chain,
+                            [fusion_operator.name],
+                            other_candidate.operator_chain,
+                        ],
+                        score[1],
+                    )
+
+                    # Mark this chain as evaluated
+                    self.evaluated_candidates.add(chain_key)
+
+                    if self.debug:
+                        print(
+                            f"Evaluated chain: {candidate.operator_chain} + 
{fusion_operator.name} + {other_candidate.operator_chain} -> {score[1]}"
+                        )
+
+                    # Recursively optimize further with this fused 
representation
+                    self._optimize_candidate(
+                        fused_representation,
+                        result,
+                        [c for c in remaining_candidates if c != 
other_candidate],
+                        chain_depth + 1,
+                    )
+
+    def create_identifier(self, candidate, fusion, other_candidate):
+        identifier = "".join(flatten_and_join(candidate.operator_chain))
+        identifier += fusion.name
+        identifier += "".join(flatten_and_join(other_candidate.operator_chain))
+
+        return identifier
+
+
+def flatten_and_join(data):
+    flat_list = []
+    for item in data:
+        if isinstance(item, list):
+            flat_list.extend(flatten_and_join(item))
+        else:
+            flat_list.append(item.name if not isinstance(item, str) else item)
+    return flat_list
diff --git a/src/main/python/systemds/scuro/drsearch/hyperparameter_tuner.py 
b/src/main/python/systemds/scuro/drsearch/hyperparameter_tuner.py
new file mode 100644
index 0000000000..04a3fa4701
--- /dev/null
+++ b/src/main/python/systemds/scuro/drsearch/hyperparameter_tuner.py
@@ -0,0 +1,106 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import itertools
+import time
+
+import numpy as np
+
+from systemds.scuro.drsearch.optimization_data import OptimizationResult
+from systemds.scuro.representations.context import Context
+
+
+class HyperparameterTuner:
+    def __init__(self, task, n_trials=10, early_stopping_patience=5):
+        self.task = task
+        self.n_trials = n_trials
+        self.early_stopping_patience = early_stopping_patience
+
+    def tune_operator_chain(self, modality, operator_chain):
+        best_result = None
+        best_score = -np.inf
+
+        param_grids = {}
+
+        for operator in operator_chain:
+            param_grids[operator.name] = operator.parameters
+
+        param_combinations = self._generate_search_space(param_grids)
+
+        for params in param_combinations:
+            modified_modality = modality
+            current_chain = []
+
+            representation_start = time.time()
+            try:
+                for operator in operator_chain:
+
+                    if operator.name in params:
+                        operator.set_parameters(params[operator.name])
+
+                    if isinstance(operator, Context):
+                        modified_modality = modified_modality.context(operator)
+                    else:
+                        modified_modality = 
modified_modality.apply_representation(
+                            operator
+                        )
+
+                    current_chain.append(operator)
+
+                representation_end = time.time()
+
+                score = self.task.run(modified_modality.data)
+
+                if score[1] > best_score:
+                    best_score = score[1]
+                    best_params = params
+                    best_result = OptimizationResult(
+                        operator_chain=current_chain,
+                        parameters=params,
+                        train_accuracy=score[0],
+                        test_accuracy=score[1],
+                        training_runtime=self.task.training_time,
+                        inference_runtime=self.task.inference_time,
+                        representation_time=representation_end - 
representation_start,
+                        output_shape=(1, 1),
+                    )
+
+            except Exception as e:
+                print(f"Failed parameter combination {params}: {str(e)}")
+                continue
+
+        return best_result
+
+    def _generate_search_space(self, param_grids):
+        combinations = {}
+        for operator_name, params in param_grids.items():
+            operator_combinations = [
+                dict(zip(params.keys(), v)) for v in 
itertools.product(*params.values())
+            ]
+            combinations[operator_name] = operator_combinations
+
+        keys = list(combinations.keys())
+        values = [combinations[key] for key in keys]
+
+        parameter_grid = [
+            dict(zip(keys, combo)) for combo in itertools.product(*values)
+        ]
+
+        return parameter_grid
diff --git a/src/main/python/systemds/scuro/drsearch/operator_registry.py 
b/src/main/python/systemds/scuro/drsearch/operator_registry.py
new file mode 100644
index 0000000000..942e5bb80e
--- /dev/null
+++ b/src/main/python/systemds/scuro/drsearch/operator_registry.py
@@ -0,0 +1,107 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+from typing import Union, List
+
+from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.representations.representation import Representation
+
+
+class Registry:
+    """
+    Singleton that keeps track of every registered operator.
+    Representations are grouped per modality type in a dictionary keyed by that type;
+    context operators and fusion operators are kept in flat lists.
+    """
+
+    _instance = None
+    _representations = {}
+    _context_operators = []
+    _fusion_operators = []
+
+    def __new__(cls):
+        if cls._instance:
+            return cls._instance
+
+        # first instantiation: create the shared instance and one empty bucket per modality type
+        cls._instance = super().__new__(cls)
+        cls._representations.update({m_type: [] for m_type in ModalityType})
+        return cls._instance
+
+    def add_representation(
+        self, representation: Representation, modality: ModalityType
+    ):
+        """Append a representation to the list kept for the given modality type."""
+        bucket = self._representations[modality]
+        bucket.append(representation)
+
+    def add_context_operator(self, context_operator):
+        """Append a context operator to the registry."""
+        self._context_operators.append(context_operator)
+
+    def add_fusion_operator(self, fusion_operator):
+        """Append a fusion operator to the registry."""
+        self._fusion_operators.append(fusion_operator)
+
+    def get_representations(self, modality: ModalityType):
+        """Return the list of representations registered for the given modality type."""
+        return self._representations[modality]
+
+    def get_context_operators(self):
+        """Return the list of registered context operators."""
+        return self._context_operators
+
+    def get_fusion_operators(self):
+        """Return the list of registered fusion operators."""
+        return self._fusion_operators
+
+
+def register_representation(modalities: Union[ModalityType, List[ModalityType]]):
+    """
+    Decorator to register a representation class for one or more modalities.
+
+    :param modalities: a single modality type or a list of modality types for which
+        the representation is to be registered
+    :return: class decorator that adds the class to the Registry and returns it unchanged
+    :raises TypeError: if one of the given modalities is not contained in ModalityType
+    """
+    if isinstance(modalities, ModalityType):
+        modalities = [modalities]
+
+    def decorator(cls):
+        for modality in modalities:
+            if modality not in ModalityType:
+                # Raising a bare string is not allowed in Python 3: it surfaced as a
+                # generic TypeError and the message below was lost. The exception type
+                # is kept, but it now carries the intended message.
+                raise TypeError(
+                    f"Modality {modality} not in ModalityTypes please add it to constants.py ModalityTypes first!"
+                )
+
+            Registry().add_representation(cls, modality)
+        return cls
+
+    return decorator
+
+
+def register_context_operator():
+    """
+    Build a class decorator that adds the decorated class to the context operators of the Registry.
+
+    :return: decorator that registers the class and hands it back unchanged
+    """
+
+    def _register(operator_cls):
+        registry = Registry()
+        registry.add_context_operator(operator_cls)
+        return operator_cls
+
+    return _register
+
+
+def register_fusion_operator():
+    """
+    Build a class decorator that adds the decorated class to the fusion operators of the Registry.
+
+    :return: decorator that registers the class and hands it back unchanged
+    """
+
+    def _register(operator_cls):
+        registry = Registry()
+        registry.add_fusion_operator(operator_cls)
+        return operator_cls
+
+    return _register
diff --git a/src/main/python/systemds/scuro/drsearch/optimization_data.py 
b/src/main/python/systemds/scuro/drsearch/optimization_data.py
new file mode 100644
index 0000000000..4ca54c10d3
--- /dev/null
+++ b/src/main/python/systemds/scuro/drsearch/optimization_data.py
@@ -0,0 +1,164 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+from dataclasses import dataclass
+from typing import List, Dict, Any, Union
+
+from systemds.scuro.drsearch.operator_registry import Registry
+from systemds.scuro.representations.representation import Representation
+
+
+@dataclass
+class OptimizationResult:
+    """
+    The OptimizationResult class stores the results of an individual optimization run
+
+    Attributes:
+        operator_chain (List[Representation]): stores the operators used in the optimization run
+        parameters (Union[Dict[str, Any], List[Any]]): stores the parameters used for the operators in the optimization run
+        train_accuracy (float): stores the train accuracy of the optimization run
+        test_accuracy (float): stores the test accuracy of the optimization run
+        training_runtime (float): stores the training runtime of the optimization run
+        inference_runtime (float): stores the inference runtime of the optimization run
+        representation_time (float): stores the time needed to compute the representation
+        output_shape (tuple): stores the output shape of the data produced by the optimization run
+    """
+
+    operator_chain: List[Representation]
+    parameters: Union[Dict[str, Any], List[Any]]
+    train_accuracy: float
+    test_accuracy: float
+    # train_min_it_acc: float
+    # test_min_it_acc: float
+    training_runtime: float
+    inference_runtime: float
+    representation_time: float
+    output_shape: tuple
+
+    # def __str__(self):
+    #     result_string = ""
+    #     for operator in self.operator_chain:
+    #         if isinstance(operator, List):
+    #             result_string += extract_operator_names(operator)
+    #         else:
+    #             result_string += operator.name
+    #     return result_string
+
+
+@dataclass
+class OptimizationData:
+    """
+    Running accuracy statistics for a single representation name.
+
+    Attributes:
+        representation_name (str): name under which the statistics are collected
+        mean_accuracy (float): running mean of all scores added so far
+        min_accuracy (float): smallest score added so far
+        max_accuracy (float): largest score added so far
+        num_times_used (int): number of scores added so far
+    """
+
+    representation_name: str
+    # The counters are annotated so that they are real dataclass fields (part of
+    # __init__, __repr__ and __eq__) instead of plain class attributes.
+    mean_accuracy: float = 0.0
+    min_accuracy: float = 1.0  # NOTE(review): start value assumes scores are <= 1.0
+    max_accuracy: float = 0.0  # NOTE(review): start value assumes scores are >= 0.0
+    num_times_used: int = 0
+
+    def add_entry(self, score):
+        """
+        Fold one more score into the running statistics.
+
+        :param score: the score to add
+        """
+        self.num_times_used += 1
+        self.min_accuracy = min(score, self.min_accuracy)
+        self.max_accuracy = max(score, self.max_accuracy)
+        if self.num_times_used > 1:
+            # incremental mean update, avoids keeping all scores around
+            self.mean_accuracy += (score - self.mean_accuracy) / self.num_times_used
+        else:
+            self.mean_accuracy = score
+
+    def __str__(self):
+        return f"Name: {self.representation_name}  mean: {self.mean_accuracy} max: {self.max_accuracy} min: {self.min_accuracy} num_times: {self.num_times_used}"
+
+
+def extract_names(operator_chain):
+    """
+    Map every entry of an operator chain to its name.
+
+    :param operator_chain: iterable of operator names (str) and/or objects exposing a ``name`` attribute
+    :return: list with one name per entry, in chain order
+    """
+    return [
+        entry if isinstance(entry, str) else entry.name for entry in operator_chain
+    ]
+
+
+class OptimizationStatistics:
+    """
+    Collects OptimizationData entries for candidate representations and fusion operators.
+
+    Attributes:
+        optimization_data (Dict[str, OptimizationData]): statistics keyed by representation name
+        fusion_names (List[str]): class names of the fusion operators found in the Registry
+    """
+
+    optimization_data: Dict[str, OptimizationData]
+    fusion_names: List[str]
+
+    def __init__(self, candidates):
+        """
+        :param candidates: iterable of objects exposing an ``operator_chain``;
+            one statistics entry is created per candidate
+        """
+        # Created per instance: class-level mutable defaults would be shared by all
+        # OptimizationStatistics objects, leaking entries between instances and
+        # appending the fusion names again on every construction.
+        self.optimization_data = {}
+        self.fusion_names = []
+
+        for candidate in candidates:
+            representation_name = "".join(extract_names(candidate.operator_chain))
+            self.optimization_data[representation_name] = OptimizationData(
+                representation_name
+            )
+
+        for fusion_method in Registry().get_fusion_operators():
+            self.optimization_data[fusion_method.__name__] = OptimizationData(
+                fusion_method.__name__
+            )
+            self.fusion_names.append(fusion_method.__name__)
+
+    def parse_representation_name(self, name):
+        """
+        Split a concatenated representation name at every occurrence of a known fusion name.
+
+        :param name: concatenated name string
+        :return: list of parts in order of appearance; fusion names are separate parts
+        """
+        parts = []
+        current_part = ""
+
+        i = 0
+        while i < len(name):
+            found_fusion = False
+            for fusion in self.fusion_names:
+                # the first fusion name matching at position i wins
+                if name.startswith(fusion, i):
+                    if current_part:
+                        parts.append(current_part)
+                    parts.append(fusion)
+                    i += len(fusion)
+                    found_fusion = True
+                    break
+
+            if not found_fusion:
+                current_part += name[i]
+                i += 1
+            else:
+                current_part = ""
+
+        if current_part:
+            parts.append(current_part)
+
+        return parts
+
+    def _record_score(self, name, score):
+        """Add the score to the entry stored under name, creating the entry if it is missing."""
+        entry = self.optimization_data.get(name)
+        if entry is None:
+            entry = OptimizationData(name)
+            self.optimization_data[name] = entry
+        entry.add_entry(score)
+
+    def add_entry(self, representations, score):
+        """
+        Record a score for every representation that took part in an evaluation.
+
+        :param representations: list whose items are either one operator chain or a list of operator chains
+        :param score: the score to record for each of them
+        """
+        for rep in representations:
+            if isinstance(rep[0], list):
+                # nested case: rep is a list of operator chains
+                for r in rep:
+                    self._record_score("".join(extract_names(r)), score)
+            else:
+                self._record_score("".join(extract_names(rep)), score)
+
+    def print_statistics(self):
+        """Print every collected statistics entry."""
+        for statistic in self.optimization_data.values():
+            print(statistic)
+
+
+def extract_operator_names(operators):
+    """
+    Concatenate the names of all operators, descending into nested lists.
+
+    :param operators: iterable of objects exposing a ``name`` attribute and/or (nested) lists of them
+    :return: a single string with all names in traversal order
+    """
+    return "".join(
+        extract_operator_names(entry) if isinstance(entry, list) else entry.name
+        for entry in operators
+    )
diff --git a/src/main/python/systemds/scuro/drsearch/representation_cache.py 
b/src/main/python/systemds/scuro/drsearch/representation_cache.py
new file mode 100644
index 0000000000..fc78167f2e
--- /dev/null
+++ b/src/main/python/systemds/scuro/drsearch/representation_cache.py
@@ -0,0 +1,127 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import copy
+import os
+import pickle
+import tempfile
+
+from systemds.scuro.modality.transformed import TransformedModality
+
+
+class RepresentationCache:
+    """
+    Singleton on-disk cache for transformed modality data.
+    Data and metadata of a modality are pickled into a temporary directory; the file name is
+    derived from the modality id and the names of the applied operators.
+    """
+
+    _instance = None
+    _cache_dir = None
+    debug = False
+
+    def __new__(cls, debug=False):
+        if not cls._instance:
+            # the debug flag is only taken from the first instantiation
+            cls.debug = debug
+            cls._instance = super().__new__(cls)
+            cls._cache_dir = tempfile.TemporaryDirectory()
+        return cls._instance
+
+    def _generate_cache_filename(self, modality_id, operators):
+        """
+        Build the cache file path (without extension) for a modality and an operator chain.
+
+        :param modality_id: string prefix of the file name (the modality id, optionally
+            followed by the names of already applied operators)
+        :param operators: operators or operator names whose names are appended to the prefix
+        :return: tuple of (path inside the cache directory without extension, list of operator names)
+        """
+        op_names = [
+            operator if isinstance(operator, str) else operator.name
+            for operator in operators
+        ]
+        filename = modality_id + "".join(op_names)
+
+        return os.path.join(self._cache_dir.name, filename), op_names
+
+    def save_to_cache(self, modality, used_op_names, operators):
+        """
+        Save the data and metadata of a modality to the cache unless it is already cached.
+
+        :param modality: the modality whose ``data`` and ``metadata`` are stored
+        :param used_op_names: concatenated names of the operators that were loaded from the cache
+        :param operators: the operators (or operator names) that were applied on top of the cached part
+        """
+        filename, op_names = self._generate_cache_filename(
+            str(modality.modality_id) + used_op_names, operators
+        )
+        # The files are written with a ".pkl"/".meta" suffix, so the existence check
+        # has to use the suffixed name; the bare name never exists.
+        if not os.path.exists(f"{filename}.pkl"):
+            with open(f"{filename}.pkl", "wb") as f:
+                pickle.dump(modality.data, f)
+
+            with open(f"{filename}.meta", "wb") as f:
+                pickle.dump(modality.metadata, f)
+
+            if self.debug:
+                str_names = ", ".join(op_names)
+                print(
+                    f"Saved data for operator {str(modality.modality_id)}{used_op_names}{str_names} to cache: {filename}"
+                )
+
+    def load_from_cache(self, modality, operators):
+        """
+        Load the longest cached prefix of an operator chain for the given modality.
+
+        :param modality: the modality the operator chain is applied to
+        :param operators: the full operator chain
+        :return: tuple of (TransformedModality holding the cached data or None on a cache miss,
+            operators that still have to be applied in chain order,
+            concatenated names of the operators covered by the cache)
+        """
+        ops = copy.deepcopy(operators)
+        filename, op_names = self._generate_cache_filename(
+            str(modality.modality_id), ops
+        )
+        dropped_ops = []
+        # Drop trailing operators until a cached prefix of the chain is found.
+        # "ops" is checked first so that an empty chain is a plain cache miss
+        # instead of an IndexError from popping an empty list.
+        while ops and not os.path.exists(f"{filename}.pkl"):
+            op_names.pop()
+            dropped_ops.append(ops.pop())
+            if not ops:
+                break
+            filename, op_names = self._generate_cache_filename(
+                str(modality.modality_id), ops
+            )
+
+        # operators were collected back to front
+        dropped_ops.reverse()
+        op_names = "".join(op_names)
+
+        if os.path.exists(f"{filename}.pkl"):
+            # NOTE(review): pickle must only be used on trusted files; the paths read
+            # here are built from the cache directory created by this class.
+            with open(f"{filename}.meta", "rb") as f:
+                metadata = pickle.load(f)
+
+            transformed_modality = TransformedModality(
+                modality.modality_type, op_names, modality.modality_id, metadata
+            )
+            data = None
+            with open(f"{filename}.pkl", "rb") as f:
+                if self.debug:
+                    print(
+                        f"Loaded cached data for operator '{str(modality.modality_id) + op_names}' from {filename}"
+                    )
+                data = pickle.load(f)
+            transformed_modality.data = data
+            return transformed_modality, dropped_ops, op_names
+
+        return None, dropped_ops, op_names
diff --git a/src/main/python/systemds/scuro/aligner/similarity_measures.py 
b/src/main/python/systemds/scuro/drsearch/similarity_measures.py
similarity index 100%
rename from src/main/python/systemds/scuro/aligner/similarity_measures.py
rename to src/main/python/systemds/scuro/drsearch/similarity_measures.py
diff --git a/src/main/python/systemds/scuro/aligner/task.py 
b/src/main/python/systemds/scuro/drsearch/task.py
similarity index 80%
rename from src/main/python/systemds/scuro/aligner/task.py
rename to src/main/python/systemds/scuro/drsearch/task.py
index f33546ae65..7e05a489e4 100644
--- a/src/main/python/systemds/scuro/aligner/task.py
+++ b/src/main/python/systemds/scuro/drsearch/task.py
@@ -18,6 +18,7 @@
 # under the License.
 #
 # -------------------------------------------------------------
+import time
 from typing import List
 
 from systemds.scuro.models.model import Model
@@ -34,6 +35,7 @@ class Task:
         train_indices: List,
         val_indices: List,
         kfold=5,
+        measure_performance=True,
     ):
         """
         Parent class for the prediction task that is performed on top of the 
aligned representation
@@ -51,6 +53,10 @@ class Task:
         self.train_indices = train_indices
         self.val_indices = val_indices
         self.kfold = kfold
+        self.measure_performance = measure_performance
+        self.inference_time = []
+        self.training_time = []
+        self.expected_dim = 1
 
     def get_train_test_split(self, data):
         X_train = [data[i] for i in self.train_indices]
@@ -67,6 +73,8 @@ class Task:
          :param data: The aligned data used in the prediction process
          :return: the validation accuracy
         """
+        self.inference_time = []
+        self.training_time = []
         skf = KFold(n_splits=self.kfold, shuffle=True, random_state=11)
         train_scores = []
         test_scores = []
@@ -76,13 +84,21 @@ class Task:
         for train, test in skf.split(X, y):
             train_X = np.array(X)[train]
             train_y = np.array(y)[train]
-
+            train_start = time.time()
             train_score = self.model.fit(train_X, train_y, X_test, y_test)
+            train_end = time.time()
+            self.training_time.append(train_end - train_start)
             train_scores.append(train_score)
-
-            test_score = self.model.test(X_test, y_test)
+            test_start = time.time()
+            test_score = self.model.test(np.array(X_test), y_test)
+            test_end = time.time()
+            self.inference_time.append(test_end - test_start)
             test_scores.append(test_score)
 
             fold += 1
 
+        if self.measure_performance:
+            self.inference_time = np.mean(self.inference_time)
+            self.training_time = np.mean(self.training_time)
+
         return [np.mean(train_scores), np.mean(test_scores)]
diff --git 
a/src/main/python/systemds/scuro/drsearch/unimodal_representation_optimizer.py 
b/src/main/python/systemds/scuro/drsearch/unimodal_representation_optimizer.py
new file mode 100644
index 0000000000..e59ddbe9be
--- /dev/null
+++ 
b/src/main/python/systemds/scuro/drsearch/unimodal_representation_optimizer.py
@@ -0,0 +1,271 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import copy
+import os
+import pickle
+import time
+from typing import List
+
+from systemds.scuro.drsearch.operator_registry import Registry
+from systemds.scuro.drsearch.optimization_data import OptimizationResult
+from systemds.scuro.drsearch.representation_cache import RepresentationCache
+from systemds.scuro.drsearch.task import Task
+from systemds.scuro.modality.modality import Modality
+from systemds.scuro.representations.aggregate import Aggregation
+from systemds.scuro.representations.context import Context
+
+
+class UnimodalRepresentationOptimizer:
+    def __init__(
+        self,
+        modalities: List[Modality],
+        tasks: List[Task],
+        max_chain_depth=5,
+        debug=False,
+        folder_name=None,
+    ):
+        """
+        :param modalities: the modalities to find unimodal representations for
+        :param tasks: the tasks every candidate representation is evaluated on
+        :param max_chain_depth: maximum number of operators in one operator chain
+        :param debug: if True, scores are printed and results are pickled to folder_name
+        :param folder_name: output folder for the debug results; created if it does not exist
+        """
+        self.optimization_results = {}
+        self.modalities = modalities
+        self.tasks = tasks
+        self.operator_registry = Registry()
+        self.initialize_optimization_results()
+        self.max_chain_depth = max_chain_depth
+        self.debug = debug
+        self.cache = RepresentationCache(self.debug)
+        # Always defined: optimize() reads folder_name whenever self.debug is set, so the
+        # attribute must exist even if debug is only switched on after construction.
+        self.folder_name = folder_name
+        if self.debug:
+            os.makedirs(self.folder_name, exist_ok=True)
+
+    def initialize_optimization_results(self):
+        """
+        Create an empty result list for every (modality id, model name) combination.
+        """
+        for modality in self.modalities:
+            self.optimization_results[modality.modality_id] = {
+                task.model.name: [] for task in self.tasks
+            }
+
+    def optimize(self):
+        """
+        This method finds different unimodal representations for all given modalities.
+        In debug mode the results of every modality are additionally pickled to folder_name.
+        """
+        for modality in self.modalities:
+            self._optimize_modality(modality)
+
+            # The serializable copy of the results is only needed for the debug
+            # dump, so it is not built at all outside of debug mode.
+            if self.debug:
+                self._dump_results(modality)
+
+    def _dump_results(self, modality):
+        """
+        Pickle the results of one modality, one file per model, with the operator chains
+        replaced by the operator names.
+        :param modality: the modality whose results are written
+        """
+        results_per_model = self.optimization_results[modality.modality_id]
+        for model, model_results in results_per_model.items():
+            serializable = []
+            for result in model_results:
+                # Shallow copy: only the operator chain is replaced, so the operator
+                # objects themselves never have to be copied.
+                entry = copy.copy(result)
+                entry.operator_chain = [
+                    op if isinstance(op, str) else op.name
+                    for op in result.operator_chain
+                ]
+                serializable.append(entry)
+
+            with open(
+                f"{self.folder_name}/results_{model}_{modality.modality_type.name}.p",
+                "wb",
+            ) as fp:
+                pickle.dump(serializable, fp, protocol=pickle.HIGHEST_PROTOCOL)
+
+    def get_k_best_results(self, modality: Modality, k: int, task: Task):
+        """
+        Get the k best results for the given modality, ranked by test accuracy
+        :param modality: modality to get the best results for
+        :param k: number of best results
+        :param task: task whose model name selects the result list
+        :return: list of at most k results, best first
+        """
+
+        def by_test_accuracy(result):
+            return result.test_accuracy
+
+        candidates = self.optimization_results[modality.modality_id][task.model.name]
+        ranked = sorted(candidates, key=by_test_accuracy, reverse=True)
+        return ranked[:k]
+
+    def _optimize_modality(self, modality: Modality):
+        """
+        Start one operator chain per compatible operator for the given modality;
+        the chains are extended and evaluated by _build_operator_chain.
+        :param modality: modality to optimize
+        """
+        start_operators = self._get_compatible_operators(modality.modality_type, [])
+
+        for operator_cls in start_operators:
+            initial_chain = [operator_cls()]
+            self._build_operator_chain(modality, initial_chain, 1)
+
+    def _get_compatible_operators(self, modality_type, used_operators):
+        """
+        Collect the operators that may be appended to a chain.
+        :param modality_type: modality type used to look up the registered representations
+        :param used_operators: class names of the operators already in the chain
+        :return: representations not used yet, followed by the context operators whose
+            name is not contained in the name of the last used operator
+        """
+        registry = self.operator_registry
+        candidates = [
+            representation
+            for representation in registry.get_representations(modality_type)
+            if representation.__name__ not in used_operators
+        ]
+
+        # used_operators[-1] is a class-name string, so "in" is a substring test here
+        candidates.extend(
+            context_operator
+            for context_operator in registry.get_context_operators()
+            if not used_operators
+            or context_operator.__name__ not in used_operators[-1]
+        )
+
+        return candidates
+
+    def _build_operator_chain(self, modality, current_operator_chain, depth):
+        """
+        Evaluate the given chain and recursively extend it by every compatible operator
+        until max_chain_depth is exceeded.
+        :param modality: the modality the chain is applied to
+        :param current_operator_chain: list of operator instances built so far
+        :param depth: depth of the current chain
+        """
+        if depth > self.max_chain_depth:
+            return
+
+        self._apply_operator_chain(modality, current_operator_chain)
+
+        # the last operator that declares an output modality type decides which
+        # representations are looked up next
+        current_modality_type = modality.modality_type
+        for operator in reversed(current_operator_chain):
+            if hasattr(operator, "output_modality_type"):
+                current_modality_type = operator.output_modality_type
+                break
+
+        used_names = [type(op).__name__ for op in current_operator_chain]
+        successors = self._get_compatible_operators(current_modality_type, used_names)
+
+        for successor_cls in successors:
+            extended_chain = current_operator_chain + [successor_cls()]
+            self._build_operator_chain(modality, extended_chain, depth + 1)
+
+    def _evaluate_with_flattened_data(
+        self, modality, operator_chain, op_params, representation_time, task
+    ):
+        """
+        Evaluate the modality once per available aggregation function: the modality is
+        aggregated first and the task is run on the aggregated data.
+        :param modality: the transformed modality to aggregate
+        :param operator_chain: the operators that produced the modality
+        :param op_params: parameters of the operators in the chain
+        :param representation_time: time already spent on computing the representation
+        :param task: the task to evaluate on
+        :return: list with one OptimizationResult per aggregation function
+        """
+        from systemds.scuro.representations.aggregated_representation import (
+            AggregatedRepresentation,
+        )
+
+        collected = []
+        for aggregation in Aggregation().get_aggregation_functions():
+            # only the aggregation itself is timed
+            agg_start = time.time()
+            agg_operator = AggregatedRepresentation(Aggregation(aggregation, True))
+            agg_modality = agg_operator.transform(modality)
+            agg_end = time.time()
+
+            full_chain = operator_chain + [agg_operator]
+            full_params = dict(op_params)
+            full_params[agg_operator.name] = agg_operator.parameters
+
+            score = task.run(agg_modality.data)
+            collected.append(
+                OptimizationResult(
+                    operator_chain=full_chain,
+                    parameters=full_params,
+                    train_accuracy=score[0],
+                    test_accuracy=score[1],
+                    training_runtime=task.training_time,
+                    inference_runtime=task.inference_time,
+                    representation_time=representation_time + agg_end - agg_start,
+                    output_shape=(1, 1),  # TODO
+                )
+            )
+
+            if self.debug:
+                op_name = "".join(
+                    str(operator.__class__.__name__) for operator in full_chain
+                )
+                print(f"{task.name} {task.model.name} {op_name}: {score[1]}")
+
+        return collected
+
+    def _evaluate_operator_chain(
+        self, modality, operator_chain, op_params, representation_time
+    ):
+        """
+        Run every task on the transformed modality and store the results.
+        String data is skipped; data with more than one dimension per entry is
+        aggregated first when the task expects one-dimensional input.
+        :param modality: the transformed modality to evaluate
+        :param operator_chain: the operators that produced the modality
+        :param op_params: parameters of the operators in the chain
+        :param representation_time: time spent on computing the representation
+        """
+        for task in self.tasks:
+            first_entry = modality.data[0]
+            if isinstance(first_entry, str):
+                continue
+
+            needs_aggregation = (
+                task.expected_dim == 1
+                and not isinstance(first_entry, list)
+                and first_entry.ndim > 1
+            )
+            if needs_aggregation:
+                aggregated_results = self._evaluate_with_flattened_data(
+                    modality, operator_chain, op_params, representation_time, task
+                )
+                self.optimization_results[modality.modality_id][
+                    task.model.name
+                ].extend(aggregated_results)
+                continue
+
+            score = task.run(modality.data)
+            result = OptimizationResult(
+                operator_chain=operator_chain,
+                parameters=op_params,
+                train_accuracy=score[0],
+                test_accuracy=score[1],
+                training_runtime=task.training_time,
+                inference_runtime=task.inference_time,
+                representation_time=representation_time,
+                output_shape=(1, 1),  # TODO
+            )
+            self.optimization_results[modality.modality_id][task.model.name].append(
+                result
+            )
+            if self.debug:
+                op_name = "".join(
+                    str(operator.__class__.__name__) for operator in operator_chain
+                )
+                print(f"{task.name} {task.model.name} - {op_name}: {score[1]}")
+
+    def _apply_operator_chain(self, current_modality, operator_chain):
+        """
+        Apply an operator chain to a modality, re-using the cached part of the chain
+        if there is one, cache the new result and evaluate it on all tasks.
+        Any exception raised on the way is printed and swallowed.
+        :param current_modality: the modality the chain is applied to
+        :param operator_chain: list of operator instances to apply
+        """
+        op_params = {}
+        modified_modality = current_modality
+
+        # the measured time covers the cache lookup, the remaining operators and the cache write
+        representation_start = time.time()
+        try:
+            # representation_ops holds the operators that are not covered by the cache
+            cached_representation, representation_ops, used_op_names = (
+                self.cache.load_from_cache(
+                    modified_modality, copy.deepcopy(operator_chain)
+                )
+            )
+            if cached_representation is not None:
+                modified_modality = cached_representation
+            store = False
+            for operator in representation_ops:
+                if isinstance(operator, Context):
+                    modified_modality = modified_modality.context(operator)
+                else:
+                    modified_modality = modified_modality.apply_representation(operator)
+                store = True
+                # NOTE(review): only operators applied here are recorded; parameters of
+                # operators served from the cache are not part of op_params - confirm intended
+                op_params[operator.name] = operator.get_current_parameters()
+            if store:
+                # something new was computed, so the result is written to the cache
+                self.cache.save_to_cache(
+                    modified_modality, used_op_names, representation_ops
+                )
+            representation_end = time.time()
+
+            self._evaluate_operator_chain(
+                modified_modality,
+                operator_chain,
+                op_params,
+                representation_end - representation_start,
+            )
+        except Exception as e:
+            # best effort: a failing chain must not stop the search over the other chains
+            print(f"Failed to evaluate chain {operator_chain}: {str(e)}")
+            return
diff --git a/src/main/python/systemds/scuro/main.py 
b/src/main/python/systemds/scuro/main.py
index 8a51e098cc..f88e211157 100644
--- a/src/main/python/systemds/scuro/main.py
+++ b/src/main/python/systemds/scuro/main.py
@@ -25,8 +25,8 @@ from systemds.scuro.representations.average import Average
 from systemds.scuro.representations.concatenation import Concatenation
 from systemds.scuro.modality.unimodal_modality import UnimodalModality
 from systemds.scuro.models.discrete_model import DiscreteModel
-from systemds.scuro.aligner.task import Task
-from systemds.scuro.aligner.dr_search import DRSearch
+from systemds.scuro.drsearch.task import Task
+from systemds.scuro.drsearch.dr_search import DRSearch
 
 from systemds.scuro.dataloader.audio_loader import AudioLoader
 from systemds.scuro.dataloader.text_loader import TextLoader
diff --git a/src/main/python/systemds/scuro/modality/joined.py 
b/src/main/python/systemds/scuro/modality/joined.py
index c1aa26abf6..1a58df9256 100644
--- a/src/main/python/systemds/scuro/modality/joined.py
+++ b/src/main/python/systemds/scuro/modality/joined.py
@@ -18,13 +18,13 @@
 # under the License.
 #
 # -------------------------------------------------------------
+import importlib
 import sys
 
 import numpy as np
 
 from systemds.scuro.modality.joined_transformed import 
JoinedTransformedModality
 from systemds.scuro.modality.modality import Modality
-from systemds.scuro.representations.aggregate import Aggregation
 from systemds.scuro.representations.utils import pad_sequences
 
 
@@ -167,7 +167,9 @@ class JoinedModality(Modality):
     def aggregate(
         self, aggregation_function, field_name
     ):  # TODO: use the filed name to extract data entries from modalities
-        self.aggregation = Aggregation(aggregation_function, field_name)
+        module = 
importlib.import_module("systemds.scuro.representations.aggregate")
+
+        self.aggregation = module.Aggregation(aggregation_function, field_name)
 
         if not self.chunked_execution and self.joined_right:
             return self.aggregation.aggregate(self.joined_right)
diff --git a/src/main/python/systemds/scuro/modality/modality.py 
b/src/main/python/systemds/scuro/modality/modality.py
index c110a24eba..c16db00172 100644
--- a/src/main/python/systemds/scuro/modality/modality.py
+++ b/src/main/python/systemds/scuro/modality/modality.py
@@ -23,7 +23,7 @@ from typing import List
 
 import numpy as np
 
-from systemds.scuro.modality.type import ModalityType, DataLayout
+from systemds.scuro.modality.type import ModalityType
 from systemds.scuro.representations import utils
 
 
diff --git a/src/main/python/systemds/scuro/modality/modality_identifier.py 
b/src/main/python/systemds/scuro/modality/modality_identifier.py
index 95668c6e58..5eeee7dc13 100644
--- a/src/main/python/systemds/scuro/modality/modality_identifier.py
+++ b/src/main/python/systemds/scuro/modality/modality_identifier.py
@@ -18,13 +18,6 @@
 # under the License.
 #
 # -------------------------------------------------------------
-import os
-import pickle
-from typing import List, Dict, Any, Union
-import tempfile
-from systemds.scuro.representations.representation import Representation
-
-
 class ModalityIdentifier:
     """ """
 
diff --git a/src/main/python/systemds/scuro/modality/transformed.py b/src/main/python/systemds/scuro/modality/transformed.py
index 2b4b049ef4..aba59c1efb 100644
--- a/src/main/python/systemds/scuro/modality/transformed.py
+++ b/src/main/python/systemds/scuro/modality/transformed.py
@@ -100,7 +100,10 @@ class TransformedModality(Modality):
             self.metadata,
         )
         modalities = [self]
-        modalities.extend(other)
+        if isinstance(other, list):
+            modalities.extend(other)
+        else:
+            modalities.append(other)
         fused_modality.data = fusion_method.transform(modalities)
 
         return fused_modality
diff --git a/src/main/python/systemds/scuro/modality/unimodal_modality.py b/src/main/python/systemds/scuro/modality/unimodal_modality.py
index 6173237e0a..714fe42c33 100644
--- a/src/main/python/systemds/scuro/modality/unimodal_modality.py
+++ b/src/main/python/systemds/scuro/modality/unimodal_modality.py
@@ -26,7 +26,6 @@ from systemds.scuro.dataloader.base_loader import BaseLoader
 from systemds.scuro.modality.modality import Modality
 from systemds.scuro.modality.joined import JoinedModality
 from systemds.scuro.modality.transformed import TransformedModality
-from systemds.scuro.modality.type import ModalityType
 from systemds.scuro.modality.modality_identifier import ModalityIdentifier
 
 
diff --git a/src/main/python/systemds/scuro/representations/aggregate.py b/src/main/python/systemds/scuro/representations/aggregate.py
index 4b4545ef47..756e6271ea 100644
--- a/src/main/python/systemds/scuro/representations/aggregate.py
+++ b/src/main/python/systemds/scuro/representations/aggregate.py
@@ -20,7 +20,6 @@
 # -------------------------------------------------------------
 import numpy as np
 
-from systemds.scuro.modality.modality import Modality
 from systemds.scuro.representations import utils
 
 
@@ -48,21 +47,28 @@ class Aggregation:
         "sum": _sum_agg.__func__,
     }
 
-    def __init__(self, aggregation_function="mean", pad_modality=False):
+    def __init__(self, aggregation_function="mean", pad_modality=False, params=None):
+        if params is not None:
+            aggregation_function = params["aggregation_function"]
+            pad_modality = params["pad_modality"]
+
         if aggregation_function not in self._aggregation_function.keys():
             raise ValueError("Invalid aggregation function")
+
         self._aggregation_func = self._aggregation_function[aggregation_function]
         self.name = "Aggregation"
         self.pad_modality = pad_modality
 
+        self.parameters = {
+            "aggregation_function": aggregation_function,
+            "pad_modality": pad_modality,
+        }
+
     def execute(self, modality):
-        aggregated_modality = Modality(
-            modality.modality_type, modality.modality_id, modality.metadata
-        )
-        aggregated_modality.data = []
+        data = []
         max_len = 0
         for i, instance in enumerate(modality.data):
-            aggregated_modality.data.append([])
+            data.append([])
             if isinstance(instance, np.ndarray):
                 aggregated_data = self._aggregation_func(instance)
             else:
@@ -70,22 +76,22 @@ class Aggregation:
                 for entry in instance:
                     aggregated_data.append(self._aggregation_func(entry))
             max_len = max(max_len, len(aggregated_data))
-            aggregated_modality.data[i] = aggregated_data
+            data[i] = aggregated_data
 
         if self.pad_modality:
-            for i, instance in enumerate(aggregated_modality.data):
+            for i, instance in enumerate(data):
                 if isinstance(instance, np.ndarray):
                     if len(instance) < max_len:
                         padded_data = np.zeros(max_len, dtype=instance.dtype)
                         padded_data[: len(instance)] = instance
-                        aggregated_modality.data[i] = padded_data
+                        data[i] = padded_data
                 else:
                     padded_data = []
                     for entry in instance:
                         padded_data.append(utils.pad_sequences(entry, max_len))
-                    aggregated_modality.data[i] = padded_data
+                    data[i] = padded_data
 
-        return aggregated_modality
+        return data
 
     def transform(self, modality):
         return self.execute(modality)
diff --git a/src/main/python/systemds/scuro/aligner/alignment_strategy.py b/src/main/python/systemds/scuro/representations/aggregated_representation.py
similarity index 59%
rename from src/main/python/systemds/scuro/aligner/alignment_strategy.py
rename to src/main/python/systemds/scuro/representations/aggregated_representation.py
index 698a6d0d98..46e6b8bed2 100644
--- a/src/main/python/systemds/scuro/aligner/alignment_strategy.py
+++ b/src/main/python/systemds/scuro/representations/aggregated_representation.py
@@ -18,23 +18,18 @@
 # under the License.
 #
 # -------------------------------------------------------------
-from aligner.similarity_measures import Measure
+from systemds.scuro.modality.transformed import TransformedModality
+from systemds.scuro.representations.representation import Representation
 
 
-class AlignmentStrategy:
-    def __init__(self):
-        pass
+class AggregatedRepresentation(Representation):
+    def __init__(self, aggregation):
+        super().__init__("AggregatedRepresentation", aggregation.parameters)
+        self.aggregation = aggregation
 
-    def align_chunk(self, chunk_a, chunk_b, similarity_measure: Measure):
-        raise "Not implemented error"
-
-
-class ChunkedCrossCorrelation(AlignmentStrategy):
-    def __init__(self):
-        super().__init__()
-
-    def align_chunk(self, chunk_a, chunk_b, similarity_measure: Measure):
-        raise "Not implemented error"
-
-
-# TODO: Add additional alignment methods
+    def transform(self, modality):
+        aggregated_modality = TransformedModality(
+            modality.modality_type, self.name, modality.modality_id, modality.metadata
+        )
+        aggregated_modality.data = self.aggregation.execute(modality)
+        return aggregated_modality
diff --git a/src/main/python/systemds/scuro/representations/average.py b/src/main/python/systemds/scuro/representations/average.py
index db44050e9e..4c6b0e1787 100644
--- a/src/main/python/systemds/scuro/representations/average.py
+++ b/src/main/python/systemds/scuro/representations/average.py
@@ -27,8 +27,10 @@ from systemds.scuro.modality.modality import Modality
 from systemds.scuro.representations.utils import pad_sequences
 
 from systemds.scuro.representations.fusion import Fusion
+from systemds.scuro.drsearch.operator_registry import register_fusion_operator
 
 
+@register_fusion_operator()
 class Average(Fusion):
     def __init__(self):
         """
@@ -37,6 +39,9 @@ class Average(Fusion):
         super().__init__("Average")
 
     def transform(self, modalities: List[Modality]):
+        for modality in modalities:
+            modality.flatten()
+
         max_emb_size = self.get_max_embedding_size(modalities)
 
         padded_modalities = []
diff --git a/src/main/python/systemds/scuro/representations/bert.py b/src/main/python/systemds/scuro/representations/bert.py
index 6395d0b9e6..802d7e3d0b 100644
--- a/src/main/python/systemds/scuro/representations/bert.py
+++ b/src/main/python/systemds/scuro/representations/bert.py
@@ -19,16 +19,16 @@
 #
 # -------------------------------------------------------------
 
-import numpy as np
-
 from systemds.scuro.modality.transformed import TransformedModality
 from systemds.scuro.representations.unimodal import UnimodalRepresentation
 import torch
 from transformers import BertTokenizer, BertModel
 from systemds.scuro.representations.utils import save_embeddings
 from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.drsearch.operator_registry import register_representation
 
 
+@register_representation(ModalityType.TEXT)
 class Bert(UnimodalRepresentation):
     def __init__(self, model_name="bert", output_file=None):
         parameters = {"model_name": "bert"}
@@ -49,7 +49,7 @@ class Bert(UnimodalRepresentation):
         model = BertModel.from_pretrained(model_name)
 
         embeddings = self.create_embeddings(modality.data, model, tokenizer)
-        embeddings = [embeddings[i : i + 1] for i in range(embeddings.shape[0])]
+
         if self.output_file is not None:
             save_embeddings(embeddings, self.output_file)
 
@@ -65,7 +65,6 @@ class Bert(UnimodalRepresentation):
                 outputs = model(**inputs)
 
                 cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
-                embeddings.append(cls_embedding)
+                embeddings.append(cls_embedding.reshape(1, -1))
 
-        embeddings = np.array(embeddings)
-        return embeddings.reshape((embeddings.shape[0], embeddings.shape[-1]))
+        return embeddings
diff --git a/src/main/python/systemds/scuro/representations/bow.py b/src/main/python/systemds/scuro/representations/bow.py
index 52fddc7d3f..e2bc94041f 100644
--- a/src/main/python/systemds/scuro/representations/bow.py
+++ b/src/main/python/systemds/scuro/representations/bow.py
@@ -26,8 +26,10 @@ from systemds.scuro.representations.unimodal import UnimodalRepresentation
 from systemds.scuro.representations.utils import save_embeddings
 
 from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.drsearch.operator_registry import register_representation
 
 
+@register_representation(ModalityType.TEXT)
 class BoW(UnimodalRepresentation):
     def __init__(self, ngram_range=2, min_df=2, output_file=None):
         parameters = {"ngram_range": [ngram_range], "min_df": [min_df]}
diff --git a/src/main/python/systemds/scuro/representations/concatenation.py b/src/main/python/systemds/scuro/representations/concatenation.py
index fd9293d399..1265563b6c 100644
--- a/src/main/python/systemds/scuro/representations/concatenation.py
+++ b/src/main/python/systemds/scuro/representations/concatenation.py
@@ -28,7 +28,10 @@ from systemds.scuro.representations.utils import pad_sequences
 
 from systemds.scuro.representations.fusion import Fusion
 
+from systemds.scuro.drsearch.operator_registry import register_fusion_operator
 
+
+@register_fusion_operator()
 class Concatenation(Fusion):
     def __init__(self, padding=True):
         """
diff --git a/src/main/python/systemds/scuro/representations/context.py b/src/main/python/systemds/scuro/representations/context.py
index 4cbcf54f8e..54f22633cc 100644
--- a/src/main/python/systemds/scuro/representations/context.py
+++ b/src/main/python/systemds/scuro/representations/context.py
@@ -19,7 +19,6 @@
 #
 # -------------------------------------------------------------
 import abc
-from typing import List
 
 from systemds.scuro.modality.modality import Modality
 from systemds.scuro.representations.representation import Representation
diff --git a/src/main/python/systemds/scuro/representations/glove.py b/src/main/python/systemds/scuro/representations/glove.py
index 7bb586dc99..66a6847a94 100644
--- a/src/main/python/systemds/scuro/representations/glove.py
+++ b/src/main/python/systemds/scuro/representations/glove.py
@@ -23,8 +23,9 @@ from gensim.utils import tokenize
 
 
 from systemds.scuro.representations.unimodal import UnimodalRepresentation
-from systemds.scuro.representations.utils import read_data_from_file, save_embeddings
+from systemds.scuro.representations.utils import save_embeddings
 from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.drsearch.operator_registry import register_representation
 
 
 def load_glove_embeddings(file_path):
@@ -38,6 +39,7 @@ def load_glove_embeddings(file_path):
     return embeddings
 
 
+# @register_representation(ModalityType.TEXT)
 class GloVe(UnimodalRepresentation):
     def __init__(self, glove_path, output_file=None):
         super().__init__("GloVe", ModalityType.TEXT)
diff --git a/src/main/python/systemds/scuro/representations/lstm.py b/src/main/python/systemds/scuro/representations/lstm.py
index 6f06e762a5..a82a1e2500 100644
--- a/src/main/python/systemds/scuro/representations/lstm.py
+++ b/src/main/python/systemds/scuro/representations/lstm.py
@@ -28,7 +28,10 @@ import numpy as np
 from systemds.scuro.modality.modality import Modality
 from systemds.scuro.representations.fusion import Fusion
 
+from systemds.scuro.drsearch.operator_registry import register_fusion_operator
 
+
+@register_fusion_operator()
 class LSTM(Fusion):
     def __init__(self, width=128, depth=1, dropout_rate=0.1):
         """
diff --git a/src/main/python/systemds/scuro/representations/max.py b/src/main/python/systemds/scuro/representations/max.py
index 194b20801e..5a787dcf0c 100644
--- a/src/main/python/systemds/scuro/representations/max.py
+++ b/src/main/python/systemds/scuro/representations/max.py
@@ -28,7 +28,10 @@ from systemds.scuro.representations.utils import pad_sequences
 
 from systemds.scuro.representations.fusion import Fusion
 
+from systemds.scuro.drsearch.operator_registry import register_fusion_operator
 
+
+@register_fusion_operator()
 class RowMax(Fusion):
     def __init__(self, split=4):
         """
diff --git a/src/main/python/systemds/scuro/representations/mel_spectrogram.py b/src/main/python/systemds/scuro/representations/mel_spectrogram.py
index dfff4f3b7e..4095ceead0 100644
--- a/src/main/python/systemds/scuro/representations/mel_spectrogram.py
+++ b/src/main/python/systemds/scuro/representations/mel_spectrogram.py
@@ -25,8 +25,10 @@ from systemds.scuro.modality.type import ModalityType
 from systemds.scuro.modality.transformed import TransformedModality
 
 from systemds.scuro.representations.unimodal import UnimodalRepresentation
+from systemds.scuro.drsearch.operator_registry import register_representation
 
 
+@register_representation(ModalityType.AUDIO)
 class MelSpectrogram(UnimodalRepresentation):
     def __init__(self, n_mels=128, hop_length=512, n_fft=2048):
         parameters = {
@@ -45,8 +47,15 @@ class MelSpectrogram(UnimodalRepresentation):
         )
         result = []
         max_length = 0
-        for sample in modality.data:
-            S = librosa.feature.melspectrogram(y=sample, sr=22050)
+        for i, sample in enumerate(modality.data):
+            sr = list(modality.metadata.values())[i]["frequency"]
+            S = librosa.feature.melspectrogram(
+                y=sample,
+                sr=sr,
+                n_mels=self.n_mels,
+                hop_length=self.hop_length,
+                n_fft=self.n_fft,
+            )
             S_dB = librosa.power_to_db(S, ref=np.max)
             if S_dB.shape[-1] > max_length:
                 max_length = S_dB.shape[-1]
diff --git a/src/main/python/systemds/scuro/representations/mel_spectrogram.py b/src/main/python/systemds/scuro/representations/mfcc.py
similarity index 60%
copy from src/main/python/systemds/scuro/representations/mel_spectrogram.py
copy to src/main/python/systemds/scuro/representations/mfcc.py
index dfff4f3b7e..75cc00d62d 100644
--- a/src/main/python/systemds/scuro/representations/mel_spectrogram.py
+++ b/src/main/python/systemds/scuro/representations/mfcc.py
@@ -25,19 +25,23 @@ from systemds.scuro.modality.type import ModalityType
 from systemds.scuro.modality.transformed import TransformedModality
 
 from systemds.scuro.representations.unimodal import UnimodalRepresentation
+from systemds.scuro.drsearch.operator_registry import register_representation
 
 
-class MelSpectrogram(UnimodalRepresentation):
-    def __init__(self, n_mels=128, hop_length=512, n_fft=2048):
+@register_representation(ModalityType.AUDIO)
+class MFCC(UnimodalRepresentation):
+    def __init__(self, n_mfcc=12, dct_type=2, n_mels=128, hop_length=512):
         parameters = {
-            "n_mels": [20, 32, 64, 128],
+            "n_mfcc": [x for x in range(10, 26)],
+            "dct_type": [1, 2, 3],
             "hop_length": [256, 512, 1024, 2048],
-            "n_fft": [1024, 2048, 4096],
-        }
-        super().__init__("MelSpectrogram", ModalityType.TIMESERIES, parameters)
+            "n_mels": [20, 32, 64, 128],
+        }  # TODO
+        super().__init__("MFCC", ModalityType.TIMESERIES, parameters)
+        self.n_mfcc = n_mfcc
+        self.dct_type = dct_type
         self.n_mels = n_mels
         self.hop_length = hop_length
-        self.n_fft = n_fft
 
     def transform(self, modality):
         transformed_modality = TransformedModality(
@@ -45,12 +49,20 @@ class MelSpectrogram(UnimodalRepresentation):
         )
         result = []
         max_length = 0
-        for sample in modality.data:
-            S = librosa.feature.melspectrogram(y=sample, sr=22050)
-            S_dB = librosa.power_to_db(S, ref=np.max)
-            if S_dB.shape[-1] > max_length:
-                max_length = S_dB.shape[-1]
-            result.append(S_dB.T)
+        for i, sample in enumerate(modality.data):
+            sr = list(modality.metadata.values())[i]["frequency"]
+            mfcc = librosa.feature.mfcc(
+                y=sample,
+                sr=sr,
+                n_mfcc=self.n_mfcc,
+                dct_type=self.dct_type,
+                hop_length=self.hop_length,
+                n_mels=self.n_mels,
+            )
+            mfcc = (mfcc - np.mean(mfcc)) / np.std(mfcc)
+            if mfcc.shape[-1] > max_length:  # TODO: check if this needs to be done
+                max_length = mfcc.shape[-1]
+            result.append(mfcc.T)
 
         transformed_modality.data = result
         return transformed_modality
diff --git a/src/main/python/systemds/scuro/representations/multiplication.py b/src/main/python/systemds/scuro/representations/multiplication.py
index 2934fe5b3c..8d1e7f8c90 100644
--- a/src/main/python/systemds/scuro/representations/multiplication.py
+++ b/src/main/python/systemds/scuro/representations/multiplication.py
@@ -28,7 +28,10 @@ from systemds.scuro.representations.utils import pad_sequences
 
 from systemds.scuro.representations.fusion import Fusion
 
+from systemds.scuro.drsearch.operator_registry import register_fusion_operator
 
+
+@register_fusion_operator()
 class Multiplication(Fusion):
     def __init__(self):
         """
diff --git a/src/main/python/systemds/scuro/representations/optical_flow.py b/src/main/python/systemds/scuro/representations/optical_flow.py
new file mode 100644
index 0000000000..1fb922d7a3
--- /dev/null
+++ b/src/main/python/systemds/scuro/representations/optical_flow.py
@@ -0,0 +1,79 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import cv2
+
+from systemds.scuro.modality.transformed import TransformedModality
+from systemds.scuro.representations.unimodal import UnimodalRepresentation
+from typing import Callable, Dict, Tuple, Any
+import torch.utils.data
+import torch
+import torchvision.models as models
+import numpy as np
+from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.drsearch.operator_registry import register_representation
+
+from systemds.scuro.utils.torch_dataset import CustomDataset
+
+if torch.backends.mps.is_available():
+    DEVICE = torch.device("mps")
+# elif torch.cuda.is_available():
+#     DEVICE = torch.device("cuda")
+else:
+    DEVICE = torch.device("cpu")
+
+
+# @register_representation([ModalityType.VIDEO])
+class OpticalFlow(UnimodalRepresentation):
+    def __init__(self):
+        parameters = {}
+        super().__init__("OpticalFlow", ModalityType.TIMESERIES, parameters)
+
+    def transform(self, modality):
+        transformed_modality = TransformedModality(
+            self.output_modality_type,
+            "opticalFlow",
+            modality.modality_id,
+            modality.metadata,
+        )
+
+        for video_id, instance in enumerate(modality.data):
+            transformed_modality.data.append([])
+
+            previous_gray = cv2.cvtColor(instance[0], cv2.COLOR_BGR2GRAY)
+            for frame_id in range(1, len(instance)):
+                gray = cv2.cvtColor(instance[frame_id], cv2.COLOR_BGR2GRAY)
+
+                flow = cv2.calcOpticalFlowFarneback(
+                    previous_gray,
+                    gray,
+                    None,
+                    pyr_scale=0.5,
+                    levels=3,
+                    winsize=15,
+                    iterations=3,
+                    poly_n=5,
+                    poly_sigma=1.1,
+                    flags=0,
+                )
+
+                transformed_modality.data[video_id].append(flow)
+        transformed_modality.update_metadata()
+        return transformed_modality
diff --git a/src/main/python/systemds/scuro/representations/resnet.py b/src/main/python/systemds/scuro/representations/resnet.py
index 60eed9ea12..68771eccdd 100644
--- a/src/main/python/systemds/scuro/representations/resnet.py
+++ b/src/main/python/systemds/scuro/representations/resnet.py
@@ -18,14 +18,14 @@
 # under the License.
 #
 # -------------------------------------------------------------
-
+from systemds.scuro.utils.torch_dataset import CustomDataset
 from systemds.scuro.modality.transformed import TransformedModality
 from systemds.scuro.representations.unimodal import UnimodalRepresentation
 from typing import Callable, Dict, Tuple, Any
+from systemds.scuro.drsearch.operator_registry import register_representation
 import torch.utils.data
 import torch
 import torchvision.models as models
-import torchvision.transforms as transforms
 import numpy as np
 from systemds.scuro.modality.type import ModalityType
 
@@ -37,17 +37,19 @@ else:
     DEVICE = torch.device("cpu")
 
 
+@register_representation(
+    [ModalityType.IMAGE, ModalityType.VIDEO, ModalityType.TIMESERIES]
+)
 class ResNet(UnimodalRepresentation):
     def __init__(self, layer="avgpool", model_name="ResNet18", output_file=None):
         self.model_name = model_name
         parameters = self._get_parameters()
         super().__init__(
             "ResNet", ModalityType.TIMESERIES, parameters
-        )  # TODO: TIMESERIES only for videos - images would be handled as EMBEDDIGN
+        )  # TODO: TIMESERIES only for videos - images would be handled as EMBEDDING
 
         self.output_file = output_file
         self.layer_name = layer
-        self.model = model_name
         self.model.eval()
         for param in self.model.parameters():
             param.requires_grad = False
@@ -59,29 +61,30 @@ class ResNet(UnimodalRepresentation):
         self.model.fc = Identity()
 
     @property
-    def model(self):
-        return self._model
-
-    @model.setter
-    def model(self, model):
-        if model == "ResNet18":
-            self._model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT).to(
+    def model_name(self):
+        return self._model_name
+
+    @model_name.setter
+    def model_name(self, model_name):
+        self._model_name = model_name
+        if model_name == "ResNet18":
+            self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT).to(
                 DEVICE
             )
-        elif model == "ResNet34":
-            self._model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT).to(
+        elif model_name == "ResNet34":
+            self.model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT).to(
                 DEVICE
             )
-        elif model == "ResNet50":
-            self._model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT).to(
+        elif model_name == "ResNet50":
+            self.model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT).to(
                 DEVICE
             )
-        elif model == "ResNet101":
-            self._model = models.resnet101(weights=models.ResNet101_Weights.DEFAULT).to(
+        elif model_name == "ResNet101":
+            self.model = models.resnet101(weights=models.ResNet101_Weights.DEFAULT).to(
                 DEVICE
             )
-        elif model == "ResNet152":
-            self._model = models.resnet152(weights=models.ResNet152_Weights.DEFAULT).to(
+        elif model_name == "ResNet152":
+            self.model = models.resnet152(weights=models.ResNet152_Weights.DEFAULT).to(
                 DEVICE
             )
         else:
@@ -107,20 +110,7 @@ class ResNet(UnimodalRepresentation):
         return parameters
 
     def transform(self, modality):
-
-        t = transforms.Compose(
-            [
-                transforms.ToPILImage(),
-                transforms.Resize(256),
-                transforms.CenterCrop(224),
-                transforms.ToTensor(),
-                transforms.Normalize(
-                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
-                ),
-            ]
-        )
-
-        dataset = ResNetDataset(modality.data, t)
+        dataset = CustomDataset(modality.data)
         embeddings = {}
 
         res5c_output = None
@@ -168,31 +158,3 @@ class ResNet(UnimodalRepresentation):
         transformed_modality.data = list(embeddings.values())
 
         return transformed_modality
-
-
-class ResNetDataset(torch.utils.data.Dataset):
-    def __init__(self, data: str, tf: Callable = None):
-        self.data = data
-        self.tf = tf
-
-    def __getitem__(self, index) -> Dict[str, object]:
-        data = self.data[index]
-        if type(data) is np.ndarray:
-            output = torch.empty((1, 3, 224, 224))
-            d = torch.tensor(data)
-            d = d.repeat(3, 1, 1)
-            output[0] = self.tf(d)
-        else:
-            output = torch.empty((len(data), 3, 224, 224))
-
-            for i, d in enumerate(data):
-                if data[0].ndim < 3:
-                    d = torch.tensor(d)
-                    d = d.repeat(3, 1, 1)
-
-                output[i] = self.tf(d)
-
-        return {"id": index, "data": output}
-
-    def __len__(self) -> int:
-        return len(self.data)
diff --git a/src/main/python/systemds/scuro/representations/rowmax.py b/src/main/python/systemds/scuro/representations/rowmax.py
deleted file mode 100644
index 3152782026..0000000000
--- a/src/main/python/systemds/scuro/representations/rowmax.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# -------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-# -------------------------------------------------------------
-import itertools
-from typing import List
-
-import numpy as np
-
-from systemds.scuro.modality.modality import Modality
-from systemds.scuro.representations.utils import pad_sequences
-
-from systemds.scuro.representations.fusion import Fusion
-
-
-class RowMax(Fusion):
-    def __init__(self, split=1):
-        """
-        Combines modalities by computing the outer product of a modality combination and
-        taking the row max
-        """
-        super().__init__("RowMax")
-        self.split = split
-
-    def transform(self, modalities: List[Modality]):
-        if len(modalities) < 2:
-            return np.array(modalities)
-
-        max_emb_size = self.get_max_embedding_size(modalities)
-
-        padded_modalities = []
-        for modality in modalities:
-            d = pad_sequences(modality.data, maxlen=max_emb_size, dtype="float32")
-            padded_modalities.append(d)
-
-        split_rows = int(len(modalities[0].data) / self.split)
-
-        data = []
-
-        for combination in itertools.combinations(padded_modalities, 2):
-            combined = None
-            for i in range(0, self.split):
-                start = split_rows * i
-                end = (
-                    split_rows * (i + 1)
-                    if i < (self.split - 1)
-                    else len(modalities[0].data)
-                )
-                m = np.einsum(
-                    "bi,bo->bio", combination[0][start:end], combination[1][start:end]
-                )
-                m = m.max(axis=2)
-                if combined is None:
-                    combined = m
-                else:
-                    combined = np.concatenate((combined, m), axis=0)
-            data.append(combined)
-
-        data = np.stack(data)
-        data = data.max(axis=0)
-
-        return np.array(data)
diff --git a/src/main/python/systemds/scuro/representations/mel_spectrogram.py b/src/main/python/systemds/scuro/representations/spectrogram.py
similarity index 72%
copy from src/main/python/systemds/scuro/representations/mel_spectrogram.py
copy to src/main/python/systemds/scuro/representations/spectrogram.py
index dfff4f3b7e..b5558b1b26 100644
--- a/src/main/python/systemds/scuro/representations/mel_spectrogram.py
+++ b/src/main/python/systemds/scuro/representations/spectrogram.py
@@ -25,17 +25,14 @@ from systemds.scuro.modality.type import ModalityType
 from systemds.scuro.modality.transformed import TransformedModality
 
 from systemds.scuro.representations.unimodal import UnimodalRepresentation
+from systemds.scuro.drsearch.operator_registry import register_representation
 
 
-class MelSpectrogram(UnimodalRepresentation):
-    def __init__(self, n_mels=128, hop_length=512, n_fft=2048):
-        parameters = {
-            "n_mels": [20, 32, 64, 128],
-            "hop_length": [256, 512, 1024, 2048],
-            "n_fft": [1024, 2048, 4096],
-        }
-        super().__init__("MelSpectrogram", ModalityType.TIMESERIES, parameters)
-        self.n_mels = n_mels
+@register_representation(ModalityType.AUDIO)
+class Spectrogram(UnimodalRepresentation):
+    def __init__(self, hop_length=512, n_fft=2048):
+        parameters = {"hop_length": [256, 512, 1024, 2048], "n_fft": [1024, 2048, 4096]}
+        super().__init__("Spectrogram", ModalityType.TIMESERIES, parameters)
         self.hop_length = hop_length
         self.n_fft = n_fft
 
@@ -45,9 +42,11 @@ class MelSpectrogram(UnimodalRepresentation):
         )
         result = []
         max_length = 0
-        for sample in modality.data:
-            S = librosa.feature.melspectrogram(y=sample, sr=22050)
-            S_dB = librosa.power_to_db(S, ref=np.max)
+        for i, sample in enumerate(modality.data):
+            spectrogram = librosa.stft(
+                y=sample, hop_length=self.hop_length, n_fft=self.n_fft
+            )
+            S_dB = librosa.amplitude_to_db(np.abs(spectrogram))
             if S_dB.shape[-1] > max_length:
                 max_length = S_dB.shape[-1]
             result.append(S_dB.T)
diff --git a/src/main/python/systemds/scuro/representations/sum.py b/src/main/python/systemds/scuro/representations/sum.py
index 0608338a0f..46d93f2eda 100644
--- a/src/main/python/systemds/scuro/representations/sum.py
+++ b/src/main/python/systemds/scuro/representations/sum.py
@@ -27,7 +27,10 @@ from systemds.scuro.representations.utils import pad_sequences
 
 from systemds.scuro.representations.fusion import Fusion
 
+from systemds.scuro.drsearch.operator_registry import register_fusion_operator
 
+
+@register_fusion_operator()
 class Sum(Fusion):
     def __init__(self):
         """
diff --git a/src/main/python/systemds/scuro/representations/swin_video_transformer.py b/src/main/python/systemds/scuro/representations/swin_video_transformer.py
new file mode 100644
index 0000000000..19b2fd05c4
--- /dev/null
+++ b/src/main/python/systemds/scuro/representations/swin_video_transformer.py
@@ -0,0 +1,111 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+# from torchvision.models.video.swin_transformer import swin3d_t
+
+from systemds.scuro.modality.transformed import TransformedModality
+from systemds.scuro.representations.unimodal import UnimodalRepresentation
+from typing import Callable, Dict, Tuple, Any
+import torch.utils.data
+import torch
+import torchvision.models as models
+import numpy as np
+from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.drsearch.operator_registry import register_representation
+
+from systemds.scuro.utils.torch_dataset import CustomDataset
+
+# Select the torch device used by this module: the MPS backend when torch
+# reports it as available, otherwise the CPU.
+if torch.backends.mps.is_available():
+    DEVICE = torch.device("mps")
+# NOTE(review): the CUDA branch is commented out, so CUDA is never selected
+# here even when available - confirm this is intentional.
+# elif torch.cuda.is_available():
+#     DEVICE = torch.device("cuda")
+else:
+    DEVICE = torch.device("cpu")
+
+
+# @register_representation([ModalityType.VIDEO])
+class SwinVideoTransformer(UnimodalRepresentation):
+    """Unimodal video representation backed by a pretrained Swin3D-T model.
+
+    The output of the module named ``layer_name`` is captured with a forward
+    hook, average-pooled and flattened into one embedding array per video.
+    """
+
+    def __init__(self, layer_name="avgpool"):
+        """Create the representation and load the frozen Swin3D-T backbone.
+
+        :param layer_name: name of the module (as reported by
+            ``model.named_modules()``) whose forward output is used as the
+            representation
+        """
+        parameters = {
+            "layer_name": [
+                "features",
+                "features.1",
+                "features.2",
+                "features.3",
+                "features.4",
+                "features.5",
+                "features.6",
+                "avgpool",
+            ],
+        }
+        super().__init__("SwinVideoTransformer", ModalityType.TIMESERIES, parameters)
+        self.layer_name = layer_name
+        # The model construction used to be commented out, which left
+        # ``self.model`` unassigned and made the ``eval()`` call below fail.
+        # The weights enum member (``.DEFAULT``) is passed, not the enum class.
+        self.model = models.video.swin3d_t(
+            weights=models.video.Swin3D_T_Weights.DEFAULT
+        ).to(DEVICE)
+        self.model.eval()
+        # The backbone is only used for inference, so freeze all parameters.
+        for param in self.model.parameters():
+            param.requires_grad = False
+
+    def transform(self, modality):
+        """Embed every instance of ``modality`` with the Swin3D-T backbone.
+
+        :param modality: modality whose ``data`` holds the raw video instances
+        :return: a ``TransformedModality`` whose ``data`` is a list with one
+            embedding array per instance
+        :raises ValueError: if no module named ``layer_name`` produced an
+            output during the forward pass
+        """
+        embeddings = {}
+        swin_output = None
+
+        def capture_output(
+            _module: torch.nn.Module, _input: Tuple[torch.Tensor], output: Any
+        ):
+            # Forward hook: remember the output of the selected layer.
+            nonlocal swin_output
+            swin_output = output
+
+        hook_handle = None
+        if self.layer_name:
+            for name, layer in self.model.named_modules():
+                if name == self.layer_name:
+                    hook_handle = layer.register_forward_hook(capture_output)
+                    break
+        dataset = CustomDataset(modality.data)
+
+        try:
+            for instance in dataset:
+                video_id = instance["id"]
+                frames = instance["data"].to(DEVICE)
+
+                # Add a batch axis and swap axes 1 and 2 (frames <-> channels)
+                # before feeding the clip to the model.
+                frames = frames.unsqueeze(0).permute(0, 2, 1, 3, 4)
+
+                # Reset so a stale output of a previous instance is never reused.
+                swin_output = None
+                _ = self.model(frames)
+                if swin_output is None:
+                    raise ValueError(
+                        f"No output captured for layer '{self.layer_name}'"
+                    )
+                # NOTE(review): adaptive_avg_pool2d is documented for 3-D/4-D
+                # inputs; the hooked Swin3D layers may produce higher-rank
+                # outputs - confirm the pooling for each supported layer.
+                pooled = torch.nn.functional.adaptive_avg_pool2d(swin_output, (1, 1))
+
+                embeddings[video_id] = np.array(
+                    torch.flatten(pooled, 1).detach().cpu().numpy()
+                )
+        finally:
+            # Remove the hook so repeated transform() calls do not stack hooks
+            # on the shared model.
+            if hook_handle is not None:
+                hook_handle.remove()
+
+        transformed_modality = TransformedModality(
+            self.output_modality_type,
+            "swinVideoTransformer",
+            modality.modality_id,
+            modality.metadata,
+        )
+
+        transformed_modality.data = list(embeddings.values())
+
+        return transformed_modality
diff --git a/src/main/python/systemds/scuro/representations/tfidf.py b/src/main/python/systemds/scuro/representations/tfidf.py
index 30a6655150..c17527b476 100644
--- a/src/main/python/systemds/scuro/representations/tfidf.py
+++ b/src/main/python/systemds/scuro/representations/tfidf.py
@@ -26,8 +26,10 @@ from systemds.scuro.representations.unimodal import UnimodalRepresentation
 from systemds.scuro.representations.utils import save_embeddings
 
 from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.drsearch.operator_registry import register_representation
 
 
+@register_representation(ModalityType.TEXT)
 class TfIdf(UnimodalRepresentation):
     def __init__(self, min_df=2, output_file=None):
         parameters = {"min_df": [min_df]}
diff --git a/src/main/python/systemds/scuro/representations/mel_spectrogram.py b/src/main/python/systemds/scuro/representations/wav2vec.py
similarity index 50%
copy from src/main/python/systemds/scuro/representations/mel_spectrogram.py
copy to src/main/python/systemds/scuro/representations/wav2vec.py
index dfff4f3b7e..bf251b101c 100644
--- a/src/main/python/systemds/scuro/representations/mel_spectrogram.py
+++ b/src/main/python/systemds/scuro/representations/wav2vec.py
@@ -18,39 +18,51 @@
 # under the License.
 #
 # -------------------------------------------------------------
-import librosa
 import numpy as np
-
+from transformers import Wav2Vec2Processor, Wav2Vec2Model
+import librosa
+import torch
 from systemds.scuro.modality.type import ModalityType
 from systemds.scuro.modality.transformed import TransformedModality
 
 from systemds.scuro.representations.unimodal import UnimodalRepresentation
+from systemds.scuro.drsearch.operator_registry import register_representation
+
+import warnings
 
+warnings.filterwarnings("ignore", message="Some weights of")
 
-class MelSpectrogram(UnimodalRepresentation):
-    def __init__(self, n_mels=128, hop_length=512, n_fft=2048):
-        parameters = {
-            "n_mels": [20, 32, 64, 128],
-            "hop_length": [256, 512, 1024, 2048],
-            "n_fft": [1024, 2048, 4096],
-        }
-        super().__init__("MelSpectrogram", ModalityType.TIMESERIES, parameters)
-        self.n_mels = n_mels
-        self.hop_length = hop_length
-        self.n_fft = n_fft
+
+@register_representation(ModalityType.AUDIO)
+class Wav2Vec(UnimodalRepresentation):
+    def __init__(self):
+        super().__init__("Wav2Vec", ModalityType.TIMESERIES, {})
+        self.processor = Wav2Vec2Processor.from_pretrained(
+            "facebook/wav2vec2-base-960h"
+        )
+        self.model = Wav2Vec2Model.from_pretrained(
+            "facebook/wav2vec2-base-960h"
+        ).float()
 
     def transform(self, modality):
         transformed_modality = TransformedModality(
             self.output_modality_type, self, modality.modality_id, 
modality.metadata
         )
+
         result = []
-        max_length = 0
-        for sample in modality.data:
-            S = librosa.feature.melspectrogram(y=sample, sr=22050)
-            S_dB = librosa.power_to_db(S, ref=np.max)
-            if S_dB.shape[-1] > max_length:
-                max_length = S_dB.shape[-1]
-            result.append(S_dB.T)
+        for i, sample in enumerate(modality.data):
+            sr = list(modality.metadata.values())[i]["frequency"]
+            audio_resampled = librosa.resample(sample, orig_sr=sr, 
target_sr=16000)
+            input = self.processor(
+                audio_resampled, sampling_rate=16000, return_tensors="pt", 
padding=True
+            )
+            input.input_values = input.input_values.float()
+            input.data["input_values"] = input.data["input_values"].float()
+            with torch.no_grad():
+                outputs = self.model(**input)
+                features = outputs.extract_features
+                # TODO: check how to get intermediate representations
+            result.append(torch.flatten(features.mean(dim=1), 
1).detach().cpu().numpy())
 
         transformed_modality.data = result
         return transformed_modality
diff --git a/src/main/python/systemds/scuro/representations/window.py b/src/main/python/systemds/scuro/representations/window.py
index 264d40ca42..bff63729c7 100644
--- a/src/main/python/systemds/scuro/representations/window.py
+++ b/src/main/python/systemds/scuro/representations/window.py
@@ -23,12 +23,12 @@ import math
 
 from systemds.scuro.modality.type import DataLayout
 
-# from systemds.scuro.drsearch.operator_registry import 
register_context_operator
+from systemds.scuro.drsearch.operator_registry import register_context_operator
 from systemds.scuro.representations.aggregate import Aggregation
 from systemds.scuro.representations.context import Context
 
 
-# @register_context_operator()
+@register_context_operator()
 class WindowAggregation(Context):
     def __init__(self, window_size=10, aggregation_function="mean"):
         parameters = {
@@ -65,6 +65,8 @@ class WindowAggregation(Context):
         return windowed_data
 
     def window_aggregate_single_level(self, instance, new_length):
+        if isinstance(instance, str):
+            return instance
         num_cols = instance.shape[1] if instance.ndim > 1 else 1
         result = np.empty((new_length, num_cols))
         for i in range(0, new_length):
diff --git a/src/main/python/systemds/scuro/representations/word2vec.py b/src/main/python/systemds/scuro/representations/word2vec.py
index 929dbd4415..e1d1669d9b 100644
--- a/src/main/python/systemds/scuro/representations/word2vec.py
+++ b/src/main/python/systemds/scuro/representations/word2vec.py
@@ -26,10 +26,9 @@ from gensim.models import Word2Vec
 from gensim.utils import tokenize
 
 from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.drsearch.operator_registry import register_representation
 import nltk
 
-nltk.download("punkt_tab")
-
 
 def get_embedding(sentence, model):
     vectors = []
@@ -40,6 +39,7 @@ def get_embedding(sentence, model):
     return np.mean(vectors, axis=0) if vectors else np.zeros(model.vector_size)
 
 
+@register_representation(ModalityType.TEXT)
 class W2V(UnimodalRepresentation):
     def __init__(self, vector_size=3, min_count=2, window=2, output_file=None):
         parameters = {
@@ -71,5 +71,5 @@ class W2V(UnimodalRepresentation):
 
         if self.output_file is not None:
             save_embeddings(np.array(embeddings), self.output_file)
-        transformed_modality.data = np.array(embeddings)
+        transformed_modality.data = embeddings
         return transformed_modality
diff --git a/src/main/python/systemds/scuro/representations/resnet.py b/src/main/python/systemds/scuro/representations/x3d.py
similarity index 50%
copy from src/main/python/systemds/scuro/representations/resnet.py
copy to src/main/python/systemds/scuro/representations/x3d.py
index 60eed9ea12..bb5d1ec5ed 100644
--- a/src/main/python/systemds/scuro/representations/resnet.py
+++ b/src/main/python/systemds/scuro/representations/x3d.py
@@ -18,36 +18,36 @@
 # under the License.
 #
 # -------------------------------------------------------------
-
+from systemds.scuro.utils.torch_dataset import CustomDataset
 from systemds.scuro.modality.transformed import TransformedModality
 from systemds.scuro.representations.unimodal import UnimodalRepresentation
 from typing import Callable, Dict, Tuple, Any
 import torch.utils.data
 import torch
+from torchvision.models.video import r3d_18, s3d
 import torchvision.models as models
 import torchvision.transforms as transforms
 import numpy as np
 from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.drsearch.operator_registry import register_representation
 
 if torch.backends.mps.is_available():
     DEVICE = torch.device("mps")
-elif torch.cuda.is_available():
-    DEVICE = torch.device("cuda")
+# elif torch.cuda.is_available():
+#     DEVICE = torch.device("cuda")
 else:
     DEVICE = torch.device("cpu")
 
 
-class ResNet(UnimodalRepresentation):
-    def __init__(self, layer="avgpool", model_name="ResNet18", 
output_file=None):
+# @register_representation([ModalityType.VIDEO])
+class X3D(UnimodalRepresentation):
+    def __init__(self, layer="avgpool", model_name="r3d", output_file=None):
         self.model_name = model_name
         parameters = self._get_parameters()
-        super().__init__(
-            "ResNet", ModalityType.TIMESERIES, parameters
-        )  # TODO: TIMESERIES only for videos - images would be handled as 
EMBEDDIGN
+        super().__init__("X3D", ModalityType.TIMESERIES, parameters)
 
         self.output_file = output_file
         self.layer_name = layer
-        self.model = model_name
         self.model.eval()
         for param in self.model.parameters():
             param.requires_grad = False
@@ -59,37 +59,22 @@ class ResNet(UnimodalRepresentation):
         self.model.fc = Identity()
 
     @property
-    def model(self):
-        return self._model
-
-    @model.setter
-    def model(self, model):
-        if model == "ResNet18":
-            self._model = 
models.resnet18(weights=models.ResNet18_Weights.DEFAULT).to(
-                DEVICE
-            )
-        elif model == "ResNet34":
-            self._model = 
models.resnet34(weights=models.ResNet34_Weights.DEFAULT).to(
-                DEVICE
-            )
-        elif model == "ResNet50":
-            self._model = 
models.resnet50(weights=models.ResNet50_Weights.DEFAULT).to(
-                DEVICE
-            )
-        elif model == "ResNet101":
-            self._model = 
models.resnet101(weights=models.ResNet101_Weights.DEFAULT).to(
-                DEVICE
-            )
-        elif model == "ResNet152":
-            self._model = 
models.resnet152(weights=models.ResNet152_Weights.DEFAULT).to(
-                DEVICE
-            )
+    def model_name(self):
+        return self._model_name
+
+    @model_name.setter
+    def model_name(self, model_name):
+        self._model_name = model_name
+        if model_name == "r3d":
+            self.model = r3d_18(pretrained=True).to(DEVICE)
+        elif model_name == "s3d":
+            self.model = 
s3d(weights=models.video.S3D_Weights.DEFAULT).to(DEVICE)
         else:
             raise NotImplementedError
 
     def _get_parameters(self, high_level=True):
         parameters = {"model_name": [], "layer_name": []}
-        for m in ["ResNet18", "ResNet34", "ResNet50", "ResNet101", 
"ResNet152"]:
+        for m in ["r3d", "s3d"]:
             parameters["model_name"].append(m)
 
         if high_level:
@@ -107,20 +92,7 @@ class ResNet(UnimodalRepresentation):
         return parameters
 
     def transform(self, modality):
-
-        t = transforms.Compose(
-            [
-                transforms.ToPILImage(),
-                transforms.Resize(256),
-                transforms.CenterCrop(224),
-                transforms.ToTensor(),
-                transforms.Normalize(
-                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
-                ),
-            ]
-        )
-
-        dataset = ResNetDataset(modality.data, t)
+        dataset = CustomDataset(modality.data)
         embeddings = {}
 
         res5c_output = None
@@ -140,59 +112,24 @@ class ResNet(UnimodalRepresentation):
                     layer.register_forward_hook(get_features(name))
                     break
 
-        for instance in torch.utils.data.DataLoader(dataset):
-            video_id = instance["id"][0]
-            frames = instance["data"][0].to(DEVICE)
+        for instance in dataset:
+            video_id = instance["id"]
+            frames = instance["data"].to(DEVICE)
             embeddings[video_id] = []
-            batch_size = 64
-
-            for start_index in range(0, len(frames), batch_size):
-                end_index = min(start_index + batch_size, len(frames))
-                frame_ids_range = range(start_index, end_index)
-                frame_batch = frames[frame_ids_range]
 
-                _ = self.model(frame_batch)
-                values = res5c_output
-                pooled = torch.nn.functional.adaptive_avg_pool2d(values, (1, 
1))
+            frames = frames.unsqueeze(0).permute(0, 2, 1, 3, 4)
+            _ = self.model(frames)
+            values = res5c_output
+            pooled = torch.nn.functional.adaptive_avg_pool2d(values, (1, 1))
 
-                embeddings[video_id].extend(
-                    torch.flatten(pooled, 1).detach().cpu().numpy()
-                )
+            embeddings[video_id].extend(torch.flatten(pooled, 
1).detach().cpu().numpy())
 
             embeddings[video_id] = np.array(embeddings[video_id])
 
         transformed_modality = TransformedModality(
-            self.output_modality_type, "resnet", modality.modality_id, 
modality.metadata
+            self.output_modality_type, "x3d", modality.modality_id, 
modality.metadata
         )
 
         transformed_modality.data = list(embeddings.values())
 
         return transformed_modality
-
-
-class ResNetDataset(torch.utils.data.Dataset):
-    def __init__(self, data: str, tf: Callable = None):
-        self.data = data
-        self.tf = tf
-
-    def __getitem__(self, index) -> Dict[str, object]:
-        data = self.data[index]
-        if type(data) is np.ndarray:
-            output = torch.empty((1, 3, 224, 224))
-            d = torch.tensor(data)
-            d = d.repeat(3, 1, 1)
-            output[0] = self.tf(d)
-        else:
-            output = torch.empty((len(data), 3, 224, 224))
-
-            for i, d in enumerate(data):
-                if data[0].ndim < 3:
-                    d = torch.tensor(d)
-                    d = d.repeat(3, 1, 1)
-
-                output[i] = self.tf(d)
-
-        return {"id": index, "data": output}
-
-    def __len__(self) -> int:
-        return len(self.data)
diff --git a/src/main/python/systemds/scuro/utils/schema_helpers.py b/src/main/python/systemds/scuro/utils/schema_helpers.py
index a88e81f716..28af476cca 100644
--- a/src/main/python/systemds/scuro/utils/schema_helpers.py
+++ b/src/main/python/systemds/scuro/utils/schema_helpers.py
@@ -18,7 +18,6 @@
 # under the License.
 #
 # -------------------------------------------------------------
-import math
 import numpy as np
 
 
diff --git a/src/main/python/systemds/scuro/utils/torch_dataset.py b/src/main/python/systemds/scuro/utils/torch_dataset.py
new file mode 100644
index 0000000000..a0f3d88b6a
--- /dev/null
+++ b/src/main/python/systemds/scuro/utils/torch_dataset.py
@@ -0,0 +1,63 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+from typing import Dict
+
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+
+
+class CustomDataset(torch.utils.data.Dataset):
+    """Torch dataset that converts raw instances into normalized image tensors.
+
+    Every item is returned as ``{"id": index, "data": tensor}`` where the
+    tensor has shape ``(frames, 3, 224, 224)``.
+    """
+
+    def __init__(self, data):
+        """Store ``data`` and build the preprocessing pipeline.
+
+        :param data: indexable collection of instances; each instance is
+            either a single ``np.ndarray`` or a sequence of frames
+        """
+        self.data = data
+        # Resize/crop to 224x224 and normalize per channel
+        # (presumably the ImageNet statistics - verify).
+        pipeline = [
+            transforms.ToPILImage(),
+            transforms.Resize(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize(
+                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+            ),
+        ]
+        self.tf = transforms.Compose(pipeline)
+
+    def __getitem__(self, index) -> Dict[str, object]:
+        """Return the preprocessed instance stored at ``index``."""
+        sample = self.data[index]
+        if type(sample) is not np.ndarray:
+            # Sequence of frames: transform each frame into its own slot.
+            frames = torch.empty((len(sample), 3, 224, 224))
+            for pos, frame in enumerate(sample):
+                # Frames with fewer than three dimensions are repeated
+                # three times along a new leading axis.
+                if sample[0].ndim < 3:
+                    frame = torch.tensor(frame).repeat(3, 1, 1)
+                frames[pos] = self.tf(frame)
+        else:
+            # A single array is treated as one frame.
+            frames = torch.empty((1, 3, 224, 224))
+            frames[0] = self.tf(torch.tensor(sample).repeat(3, 1, 1))
+
+        return {"id": index, "data": frames}
+
+    def __len__(self) -> int:
+        """Return the number of stored instances."""
+        return len(self.data)
diff --git a/src/main/python/tests/scuro/data_generator.py b/src/main/python/tests/scuro/data_generator.py
index 48ff208e43..e31887ff83 100644
--- a/src/main/python/tests/scuro/data_generator.py
+++ b/src/main/python/tests/scuro/data_generator.py
@@ -26,13 +26,11 @@ from scipy.io.wavfile import write
 import random
 import os
 
-from systemds.scuro import (
-    VideoLoader,
-    AudioLoader,
-    TextLoader,
-    UnimodalModality,
-    TransformedModality,
-)
+from systemds.scuro.dataloader.video_loader import VideoLoader
+from systemds.scuro.dataloader.audio_loader import AudioLoader
+from systemds.scuro.dataloader.text_loader import TextLoader
+from systemds.scuro.modality.unimodal_modality import UnimodalModality
+from systemds.scuro.modality.transformed import TransformedModality
 from systemds.scuro.modality.type import ModalityType
 
 
diff --git a/src/main/python/tests/scuro/test_dr_search.py b/src/main/python/tests/scuro/test_dr_search.py
index 0959c246e0..521ff3f468 100644
--- a/src/main/python/tests/scuro/test_dr_search.py
+++ b/src/main/python/tests/scuro/test_dr_search.py
@@ -29,8 +29,8 @@ from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import MinMaxScaler
 
 from systemds.scuro.modality.type import ModalityType
-from systemds.scuro.aligner.dr_search import DRSearch
-from systemds.scuro.aligner.task import Task
+from systemds.scuro.drsearch.dr_search import DRSearch
+from systemds.scuro.drsearch.task import Task
 from systemds.scuro.models.model import Model
 from systemds.scuro.representations.average import Average
 from systemds.scuro.representations.bert import Bert
diff --git a/src/main/python/tests/scuro/test_multimodal_fusion.py b/src/main/python/tests/scuro/test_multimodal_fusion.py
new file mode 100644
index 0000000000..8456279c3d
--- /dev/null
+++ b/src/main/python/tests/scuro/test_multimodal_fusion.py
@@ -0,0 +1,202 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+
+import shutil
+import unittest
+
+import numpy as np
+from sklearn import svm
+from sklearn.metrics import classification_report
+from sklearn.model_selection import train_test_split
+
+from systemds.scuro.representations.concatenation import Concatenation
+from systemds.scuro.representations.average import Average
+from systemds.scuro.drsearch.fusion_optimizer import FusionOptimizer
+from systemds.scuro.drsearch.operator_registry import Registry
+from systemds.scuro.models.model import Model
+from systemds.scuro.drsearch.task import Task
+from systemds.scuro.drsearch.unimodal_representation_optimizer import (
+    UnimodalRepresentationOptimizer,
+)
+
+from systemds.scuro.representations.spectrogram import Spectrogram
+from systemds.scuro.representations.word2vec import W2V
+from systemds.scuro.modality.unimodal_modality import UnimodalModality
+from systemds.scuro.representations.resnet import ResNet
+from tests.scuro.data_generator import setup_data
+
+from systemds.scuro.dataloader.audio_loader import AudioLoader
+from systemds.scuro.dataloader.video_loader import VideoLoader
+from systemds.scuro.dataloader.text_loader import TextLoader
+from systemds.scuro.modality.type import ModalityType
+
+from unittest.mock import patch
+
+
+class TestSVM(Model):
+    """Test model that wraps an RBF-kernel ``svm.SVC`` and reports accuracy."""
+
+    def __init__(self):
+        super().__init__("TestSVM")
+
+    def fit(self, X, y, X_test, y_test):
+        """Fit the classifier on ``X``/``y`` and return the training accuracy.
+
+        ``X_test`` and ``y_test`` are accepted but not used.
+        """
+        # Inputs with more than two dimensions are flattened per sample.
+        features = X.reshape(X.shape[0], -1) if X.ndim > 2 else X
+        self.clf = svm.SVC(C=1, gamma="scale", kernel="rbf", verbose=False)
+        self.clf = self.clf.fit(features, np.array(y))
+        predictions = self.clf.predict(features)
+
+        report = classification_report(
+            y, predictions, output_dict=True, digits=3, zero_division=1
+        )
+        return report["accuracy"]
+
+    def test(self, test_X: np.ndarray, test_y: np.ndarray):
+        """Return the accuracy of the fitted classifier on the given data."""
+        features = test_X
+        if features.ndim > 2:
+            features = features.reshape(features.shape[0], -1)
+        predictions = self.clf.predict(np.array(features))  # noqa
+
+        report = classification_report(
+            np.array(test_y), predictions, output_dict=True, digits=3, zero_division=1
+        )
+        return report["accuracy"]
+
+
+class TestCNN(Model):
+    """Second test model; like ``TestSVM`` it wraps an RBF-kernel ``svm.SVC``."""
+
+    def __init__(self):
+        super().__init__("TestCNN")
+
+    def fit(self, X, y, X_test, y_test):
+        """Fit the classifier on ``X``/``y`` and return the training accuracy.
+
+        ``X_test`` and ``y_test`` are accepted but not used.
+        """
+        # Inputs with more than two dimensions are flattened per sample.
+        features = X.reshape(X.shape[0], -1) if X.ndim > 2 else X
+        self.clf = svm.SVC(C=1, gamma="scale", kernel="rbf", verbose=False)
+        self.clf = self.clf.fit(features, np.array(y))
+        predictions = self.clf.predict(features)
+
+        report = classification_report(
+            y, predictions, output_dict=True, digits=3, zero_division=1
+        )
+        return report["accuracy"]
+
+    def test(self, test_X: np.ndarray, test_y: np.ndarray):
+        """Return the accuracy of the fitted classifier on the given data."""
+        features = test_X
+        if features.ndim > 2:
+            features = features.reshape(features.shape[0], -1)
+        predictions = self.clf.predict(np.array(features))  # noqa
+
+        report = classification_report(
+            np.array(test_y), predictions, output_dict=True, digits=3, zero_division=1
+        )
+        return report["accuracy"]
+
+
+class TestMultimodalRepresentationOptimizer(unittest.TestCase):
+    """Runs the unimodal optimizer followed by the fusion optimizer on generated data."""
+
+    test_file_path = None
+    data_generator = None
+    num_instances = 0
+
+    @classmethod
+    def setUpClass(cls):
+        """Generate data for three modalities and prepare the train/validation split."""
+        cls.test_file_path = "fusion_optimizer_test_data"
+        cls.num_instances = 10
+        cls.mods = [ModalityType.VIDEO, ModalityType.AUDIO, ModalityType.TEXT]
+        cls.data_generator = setup_data(cls.mods, cls.num_instances, cls.test_file_path)
+
+        # Only the index partitions of the split are needed; cast them to int.
+        train_part, val_part, *_ = train_test_split(
+            cls.data_generator.indices,
+            cls.data_generator.labels,
+            test_size=0.2,
+            random_state=42,
+        )
+        cls.train_indizes = [int(idx) for idx in train_part]
+        cls.val_indizes = [int(idx) for idx in val_part]
+
+        cls.tasks = [
+            Task(
+                task_name,
+                model_cls(),
+                cls.data_generator.labels,
+                cls.train_indizes,
+                cls.val_indizes,
+            )
+            for task_name, model_cls in (
+                ("UnimodalRepresentationTask1", TestSVM),
+                ("UnimodalRepresentationTask2", TestCNN),
+            )
+        ]
+
+    @classmethod
+    def tearDownClass(cls):
+        """Delete the generated test data directory."""
+        shutil.rmtree(cls.test_file_path)
+
+    def test_multimodal_fusion(self):
+        """Optimize unimodal representations, then fuse them with the fusion optimizer."""
+        generator = self.data_generator
+        task = Task(
+            "UnimodalRepresentationTask1",
+            TestSVM(),
+            generator.labels,
+            self.train_indizes,
+            self.val_indizes,
+        )
+
+        audio = UnimodalModality(
+            AudioLoader(
+                generator.get_modality_path(ModalityType.AUDIO), generator.indices
+            )
+        )
+        text = UnimodalModality(
+            TextLoader(
+                generator.get_modality_path(ModalityType.TEXT), generator.indices
+            )
+        )
+        video = UnimodalModality(
+            VideoLoader(
+                generator.get_modality_path(ModalityType.VIDEO), generator.indices
+            )
+        )
+
+        # Restrict the registry to a small, fixed set of representations for
+        # the duration of the test.
+        patched_representations = {
+            ModalityType.TEXT: [W2V],
+            ModalityType.AUDIO: [Spectrogram],
+            ModalityType.TIMESERIES: [ResNet],
+            ModalityType.VIDEO: [ResNet],
+            ModalityType.EMBEDDING: [],
+        }
+        with patch.object(Registry, "_representations", patched_representations):
+            registry = Registry()
+            registry._fusion_operators = [Average, Concatenation]
+
+            unimodal_optimizer = UnimodalRepresentationOptimizer(
+                [text, audio, video], [task], max_chain_depth=2
+            )
+            unimodal_optimizer.optimize()
+
+            multimodal_optimizer = FusionOptimizer(
+                [audio, text, video],
+                task,
+                unimodal_optimizer.optimization_results,
+                unimodal_optimizer.cache,
+                2,
+                2,
+                debug=False,
+            )
+            multimodal_optimizer.optimize()
diff --git a/src/main/python/tests/scuro/test_multimodal_join.py b/src/main/python/tests/scuro/test_multimodal_join.py
index 8388829f30..a5e3a7caf9 100644
--- a/src/main/python/tests/scuro/test_multimodal_join.py
+++ b/src/main/python/tests/scuro/test_multimodal_join.py
@@ -24,8 +24,6 @@ import shutil
 import unittest
 
 from systemds.scuro.modality.joined import JoinCondition
-from systemds.scuro.representations.aggregate import Aggregation
-from systemds.scuro.representations.window import WindowAggregation
 from systemds.scuro.modality.unimodal_modality import UnimodalModality
 from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
 from systemds.scuro.representations.resnet import ResNet
diff --git a/src/main/python/tests/scuro/test_operator_registry.py b/src/main/python/tests/scuro/test_operator_registry.py
new file mode 100644
index 0000000000..aaecde2991
--- /dev/null
+++ b/src/main/python/tests/scuro/test_operator_registry.py
@@ -0,0 +1,87 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+import unittest
+
+from systemds.scuro.representations.mfcc import MFCC
+from systemds.scuro.representations.wav2vec import Wav2Vec
+from systemds.scuro.representations.window import WindowAggregation
+from systemds.scuro.representations.bow import BoW
+from systemds.scuro.representations.word2vec import W2V
+from systemds.scuro.representations.tfidf import TfIdf
+from systemds.scuro.drsearch.operator_registry import Registry
+from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.representations.average import Average
+from systemds.scuro.representations.bert import Bert
+from systemds.scuro.representations.concatenation import Concatenation
+from systemds.scuro.representations.lstm import LSTM
+from systemds.scuro.representations.max import RowMax
+from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
+from systemds.scuro.representations.spectrogram import Spectrogram
+from systemds.scuro.representations.multiplication import Multiplication
+from systemds.scuro.representations.resnet import ResNet
+from systemds.scuro.representations.sum import Sum
+
+
+class TestOperatorRegistry(unittest.TestCase):
+    def test_audio_representations_in_registry(self):
+        registry = Registry()
+        for representation in [Spectrogram, MelSpectrogram, Wav2Vec, MFCC]:
+            assert representation in registry.get_representations(
+                ModalityType.AUDIO
+            ), f"{representation} not in registry"
+
+    def test_video_representations_in_registry(self):
+        registry = Registry()
+        assert registry.get_representations(ModalityType.VIDEO) == [ResNet]
+
+    def test_timeseries_representations_in_registry(self):
+        registry = Registry()
+        assert registry.get_representations(ModalityType.TIMESERIES) == [ResNet]
+
+    def test_text_representations_in_registry(self):
+        registry = Registry()
+        for representation in [BoW, TfIdf, W2V, Bert]:
+            assert representation in registry.get_representations(
+                ModalityType.TEXT
+            ), f"{representation} not in registry"
+
+    def test_context_operator_in_registry(self):
+        registry = Registry()
+        assert registry.get_context_operators() == [WindowAggregation]
+
+    # def test_fusion_operator_in_registry(self):
+    #     registry = Registry()
+    #     for fusion_operator in [
+    #         # RowMax,
+    #         Sum,
+    #         Average,
+    #         Concatenation,
+    #         LSTM,
+    #         Multiplication,
+    #     ]:
+    #         assert (
+    #             fusion_operator in registry.get_fusion_operators()
+    #         ), f"{fusion_operator} not in registry"
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/src/main/python/tests/scuro/test_unimodal_optimizer.py b/src/main/python/tests/scuro/test_unimodal_optimizer.py
new file mode 100644
index 0000000000..bfc52f0103
--- /dev/null
+++ b/src/main/python/tests/scuro/test_unimodal_optimizer.py
@@ -0,0 +1,203 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+
+import shutil
+import unittest
+
+import numpy as np
+from sklearn import svm
+from sklearn.metrics import classification_report
+from sklearn.model_selection import train_test_split
+
+from systemds.scuro.drsearch.operator_registry import Registry
+from systemds.scuro.models.model import Model
+from systemds.scuro.drsearch.task import Task
+from systemds.scuro.drsearch.unimodal_representation_optimizer import (
+    UnimodalRepresentationOptimizer,
+)
+
+from systemds.scuro.representations.spectrogram import Spectrogram
+from systemds.scuro.representations.word2vec import W2V
+from systemds.scuro.modality.unimodal_modality import UnimodalModality
+from systemds.scuro.representations.resnet import ResNet
+from tests.scuro.data_generator import setup_data
+
+from systemds.scuro.dataloader.audio_loader import AudioLoader
+from systemds.scuro.dataloader.video_loader import VideoLoader
+from systemds.scuro.dataloader.text_loader import TextLoader
+from systemds.scuro.modality.type import ModalityType
+
+
+class TestSVM(Model):
+    def __init__(self):
+        super().__init__("TestSVM")
+
+    def fit(self, X, y, X_test, y_test):
+        if X.ndim > 2:
+            X = X.reshape(X.shape[0], -1)
+        self.clf = svm.SVC(C=1, gamma="scale", kernel="rbf", verbose=False)
+        self.clf = self.clf.fit(X, np.array(y))
+        y_pred = self.clf.predict(X)
+
+        return classification_report(
+            y, y_pred, output_dict=True, digits=3, zero_division=1
+        )["accuracy"]
+
+    def test(self, test_X: np.ndarray, test_y: np.ndarray):
+        if test_X.ndim > 2:
+            test_X = test_X.reshape(test_X.shape[0], -1)
+        y_pred = self.clf.predict(np.array(test_X))  # noqa
+
+        return classification_report(
+            np.array(test_y), y_pred, output_dict=True, digits=3, zero_division=1
+        )["accuracy"]
+
+
+class TestCNN(Model):
+    def __init__(self):
+        super().__init__("TestCNN")
+
+    def fit(self, X, y, X_test, y_test):
+        if X.ndim > 2:
+            X = X.reshape(X.shape[0], -1)
+        self.clf = svm.SVC(C=1, gamma="scale", kernel="rbf", verbose=False)
+        self.clf = self.clf.fit(X, np.array(y))
+        y_pred = self.clf.predict(X)
+
+        return classification_report(
+            y, y_pred, output_dict=True, digits=3, zero_division=1
+        )["accuracy"]
+
+    def test(self, test_X: np.ndarray, test_y: np.ndarray):
+        if test_X.ndim > 2:
+            test_X = test_X.reshape(test_X.shape[0], -1)
+        y_pred = self.clf.predict(np.array(test_X))  # noqa
+
+        return classification_report(
+            np.array(test_y), y_pred, output_dict=True, digits=3, zero_division=1
+        )["accuracy"]
+
+
+from unittest.mock import patch
+
+
+class TestUnimodalRepresentationOptimizer(unittest.TestCase):
+    test_file_path = None
+    data_generator = None
+    num_instances = 0
+
+    @classmethod
+    def setUpClass(cls):
+        cls.test_file_path = "unimodal_optimizer_test_data"
+
+        cls.num_instances = 10
+        cls.mods = [ModalityType.VIDEO, ModalityType.AUDIO, ModalityType.TEXT]
+
+        cls.data_generator = setup_data(cls.mods, cls.num_instances, cls.test_file_path)
+        split = train_test_split(
+            cls.data_generator.indices,
+            cls.data_generator.labels,
+            test_size=0.2,
+            random_state=42,
+        )
+        cls.train_indizes, cls.val_indizes = [int(i) for i in split[0]], [
+            int(i) for i in split[1]
+        ]
+
+        cls.tasks = [
+            Task(
+                "UnimodalRepresentationTask1",
+                TestSVM(),
+                cls.data_generator.labels,
+                cls.train_indizes,
+                cls.val_indizes,
+            ),
+            Task(
+                "UnimodalRepresentationTask2",
+                TestCNN(),
+                cls.data_generator.labels,
+                cls.train_indizes,
+                cls.val_indizes,
+            ),
+        ]
+
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.test_file_path)
+
+    def test_unimodal_optimizer_for_audio_modality(self):
+        audio_data_loader = AudioLoader(
+            self.data_generator.get_modality_path(ModalityType.AUDIO),
+            self.data_generator.indices,
+        )
+        audio = UnimodalModality(audio_data_loader)
+
+        self.optimize_unimodal_representation_for_modality(audio)
+
+    def test_unimodal_optimizer_for_text_modality(self):
+        text_data_loader = TextLoader(
+            self.data_generator.get_modality_path(ModalityType.TEXT),
+            self.data_generator.indices,
+        )
+        text = UnimodalModality(text_data_loader)
+        self.optimize_unimodal_representation_for_modality(text)
+
+    def test_unimodal_optimizer_for_video_modality(self):
+        video_data_loader = VideoLoader(
+            self.data_generator.get_modality_path(ModalityType.VIDEO),
+            self.data_generator.indices,
+        )
+        video = UnimodalModality(video_data_loader)
+        self.optimize_unimodal_representation_for_modality(video)
+
+    def optimize_unimodal_representation_for_modality(self, modality):
+        with patch.object(
+            Registry,
+            "_representations",
+            {
+                ModalityType.TEXT: [W2V],
+                ModalityType.AUDIO: [Spectrogram],
+                ModalityType.TIMESERIES: [ResNet],
+                ModalityType.VIDEO: [ResNet],
+                ModalityType.EMBEDDING: [],
+            },
+        ):
+            registry = Registry()
+
+            unimodal_optimizer = UnimodalRepresentationOptimizer(
+                [modality], self.tasks, max_chain_depth=2
+            )
+            unimodal_optimizer.optimize()
+
+            assert (
+                list(unimodal_optimizer.optimization_results.keys())[0]
+                == modality.modality_id
+            )
+            assert len(list(unimodal_optimizer.optimization_results.values())[0]) == 2
+            assert (
+                len(
+                    unimodal_optimizer.get_k_best_results(modality, 1, self.tasks[0])[
+                        0
+                    ].operator_chain
+                )
+                >= 1
+            )


Reply via email to