junrushao1994 commented on code in PR #11961:
URL: https://github.com/apache/tvm/pull/11961#discussion_r912448190


##########
python/tvm/meta_schedule/cost_model/mlp_model.py:
##########
@@ -0,0 +1,743 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+MLP-based cost model
+"""
+
+import logging
+import math
+import os
+import random
+import tempfile
+from collections import OrderedDict
+from itertools import chain as itertools_chain
+from typing import Dict, List, NamedTuple
+from tqdm import tqdm
+
+import numpy as np
+import torch
+from torch import nn
+
+# pylint: disable=relative-beyond-top-level
+from ...contrib.tar import tar, untar
+from ...runtime import NDArray
+from ..cost_model import PyCostModel
+from ..feature_extractor import FeatureExtractor, PerStoreFeature
+from ..runner import RunnerResult
+from ..search_strategy import MeasureCandidate
+from ..tune_context import TuneContext
+from ..utils import derived_object, shash2hex
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+
+
+# pylint: disable=no-member
+
+
+class SegmentSumMLPConfig(NamedTuple):
+    """SegmentSum MLP model configuration
+
+    Parameters
+    ----------
+    input_dim : int
+        The input dim for the model.
+    hidden_dim : int
+        The hidden dim for the model.
+    output_dim : int
+        The output dim for the model.
+    use_norm : bool
+        Whether to normalize the segment sum or not.
+    use_sigmoid : bool
+        Whether to use sigmoid on the final output or not.
+    """
+
+    input_dim: int = 172
+    hidden_dim: int = 256
+    output_dim: int = 1
+    use_norm: bool = False
+    use_sigmoid: bool = False
+
+    def to_dict(self):  # pylint: disable=missing-function-docstring
+        return {
+            "input_dim": self.input_dim,
+            "hidden_dim": self.hidden_dim,
+            "output_dim": self.output_dim,
+            "use_norm": self.use_norm,
+            "use_sigmoid": self.use_sigmoid,
+        }
+
+
+# pylint: disable=too-few-public-methods
+class FeatureGroup:
+    """Feature group
+
+    Parameters
+    ----------
+    group_hash : str
+        The hash of the group
+    features : List[np.ndarray]
+        The features
+    costs : List[float]
+        The costs
+    min_cost : float
+        The minimum cost
+    """
+
+    group_hash: str
+    features: List[np.ndarray]
+    costs: np.ndarray
+    min_cost: float
+
+    def __init__(
+        self,
+        group_hash: str,
+        features: List[np.ndarray],
+        costs: np.ndarray,
+    ) -> None:
+        self.group_hash = group_hash
+        self.features = features
+        self.costs = costs
+        self.min_cost = np.min(costs)
+
+    def append(  # pylint: disable=missing-function-docstring
+        self,
+        features: List[np.ndarray],
+        costs: np.ndarray,
+    ) -> None:
+        self.features.extend(features)
+        self.costs = np.append(self.costs, costs)
+        self.min_cost = np.min(self.costs)
+
+
+# pylint: disable=too-many-instance-attributes
+class SegmentDataLoader:
+    """Dataloader for SegmentSum MLP model.
+
+    Parameters
+    ----------
+    features : List[np.ndarray]
+        The features
+    results : np.ndarray
+        The measured results
+    batch_size : int
+        The batch size
+    shuffle : bool
+        Whether to shuffle the dataset or not
+    """
+
+    def __init__(
+        self,
+        features,
+        results,
+        batch_size=128,
+        shuffle=False,
+    ):
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        self.data_size = len(features)
+
+        # flatten features and store the starting indices
+        self.segment_sizes = torch.tensor([len(feature) for feature in 
features])

Review Comment:
   let's also specify the dtype explicitly here



##########
python/tvm/meta_schedule/cost_model/mlp_model.py:
##########
@@ -0,0 +1,743 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+MLP-based cost model
+"""
+
+import logging
+import math
+import os
+import random
+import tempfile
+from collections import OrderedDict
+from itertools import chain as itertools_chain
+from typing import Dict, List, NamedTuple
+from tqdm import tqdm
+
+import numpy as np
+import torch
+from torch import nn
+
+# pylint: disable=relative-beyond-top-level
+from ...contrib.tar import tar, untar
+from ...runtime import NDArray
+from ..cost_model import PyCostModel
+from ..feature_extractor import FeatureExtractor, PerStoreFeature
+from ..runner import RunnerResult
+from ..search_strategy import MeasureCandidate
+from ..tune_context import TuneContext
+from ..utils import derived_object, shash2hex
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+
+
+# pylint: disable=no-member
+
+
+class SegmentSumMLPConfig(NamedTuple):
+    """SegmentSum MLP model configuration
+
+    Parameters
+    ----------
+    input_dim : int
+        The input dim for the model.
+    hidden_dim : int
+        The hidden dim for the model.
+    output_dim : int
+        The output dim for the model.
+    use_norm : bool
+        Whether to normalize the segment sum or not.
+    use_sigmoid : bool
+        Whether to use sigmoid on the final output or not.
+    """
+
+    input_dim: int = 172
+    hidden_dim: int = 256
+    output_dim: int = 1
+    use_norm: bool = False
+    use_sigmoid: bool = False
+
+    def to_dict(self):  # pylint: disable=missing-function-docstring
+        return {
+            "input_dim": self.input_dim,
+            "hidden_dim": self.hidden_dim,
+            "output_dim": self.output_dim,
+            "use_norm": self.use_norm,
+            "use_sigmoid": self.use_sigmoid,
+        }
+
+
+# pylint: disable=too-few-public-methods
+class FeatureGroup:
+    """Feature group
+
+    Parameters
+    ----------
+    group_hash : str
+        The hash of the group
+    features : List[np.ndarray]
+        The features
+    costs : List[float]
+        The costs
+    min_cost : float
+        The minimum cost
+    """
+
+    group_hash: str
+    features: List[np.ndarray]
+    costs: np.ndarray
+    min_cost: float
+
+    def __init__(
+        self,
+        group_hash: str,
+        features: List[np.ndarray],
+        costs: np.ndarray,
+    ) -> None:
+        self.group_hash = group_hash
+        self.features = features
+        self.costs = costs
+        self.min_cost = np.min(costs)
+
+    def append(  # pylint: disable=missing-function-docstring
+        self,
+        features: List[np.ndarray],
+        costs: np.ndarray,
+    ) -> None:
+        self.features.extend(features)
+        self.costs = np.append(self.costs, costs)
+        self.min_cost = np.min(self.costs)
+
+
+# pylint: disable=too-many-instance-attributes
+class SegmentDataLoader:
+    """Dataloader for SegmentSum MLP model.
+
+    Parameters
+    ----------
+    features : List[np.ndarray]
+        The features
+    results : np.ndarray
+        The measured results
+    batch_size : int
+        The batch size
+    shuffle : bool
+        Whether to shuffle the dataset or not
+    """
+
+    def __init__(
+        self,
+        features,
+        results,
+        batch_size=128,
+        shuffle=False,

Review Comment:
   quick question: do you think by default we should use `shuffle=True`?



##########
python/tvm/meta_schedule/cost_model/mlp_model.py:
##########
@@ -0,0 +1,743 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+MLP-based cost model
+"""
+
+import logging
+import math
+import os
+import random
+import tempfile
+from collections import OrderedDict
+from itertools import chain as itertools_chain
+from typing import Dict, List, NamedTuple
+from tqdm import tqdm
+
+import numpy as np
+import torch
+from torch import nn
+
+# pylint: disable=relative-beyond-top-level
+from ...contrib.tar import tar, untar
+from ...runtime import NDArray
+from ..cost_model import PyCostModel
+from ..feature_extractor import FeatureExtractor, PerStoreFeature
+from ..runner import RunnerResult
+from ..search_strategy import MeasureCandidate
+from ..tune_context import TuneContext
+from ..utils import derived_object, shash2hex
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+
+
+# pylint: disable=no-member
+
+
+class SegmentSumMLPConfig(NamedTuple):
+    """SegmentSum MLP model configuration
+
+    Parameters
+    ----------
+    input_dim : int
+        The input dim for the model.
+    hidden_dim : int
+        The hidden dim for the model.
+    output_dim : int
+        The output dim for the model.
+    use_norm : bool
+        Whether to normalize the segment sum or not.
+    use_sigmoid : bool
+        Whether to use sigmoid on the final output or not.
+    """
+
+    input_dim: int = 172
+    hidden_dim: int = 256
+    output_dim: int = 1
+    use_norm: bool = False
+    use_sigmoid: bool = False
+
+    def to_dict(self):  # pylint: disable=missing-function-docstring
+        return {
+            "input_dim": self.input_dim,
+            "hidden_dim": self.hidden_dim,
+            "output_dim": self.output_dim,
+            "use_norm": self.use_norm,
+            "use_sigmoid": self.use_sigmoid,
+        }
+
+
+# pylint: disable=too-few-public-methods
+class FeatureGroup:
+    """Feature group
+
+    Parameters
+    ----------
+    group_hash : str
+        The hash of the group
+    features : List[np.ndarray]
+        The features
+    costs : List[float]
+        The costs
+    min_cost : float
+        The minimum cost
+    """
+
+    group_hash: str
+    features: List[np.ndarray]
+    costs: np.ndarray
+    min_cost: float
+
+    def __init__(
+        self,
+        group_hash: str,
+        features: List[np.ndarray],
+        costs: np.ndarray,
+    ) -> None:
+        self.group_hash = group_hash
+        self.features = features
+        self.costs = costs
+        self.min_cost = np.min(costs)
+
+    def append(  # pylint: disable=missing-function-docstring
+        self,
+        features: List[np.ndarray],
+        costs: np.ndarray,
+    ) -> None:
+        self.features.extend(features)
+        self.costs = np.append(self.costs, costs)
+        self.min_cost = np.min(self.costs)
+
+
+# pylint: disable=too-many-instance-attributes
+class SegmentDataLoader:
+    """Dataloader for SegmentSum MLP model.
+
+    Parameters
+    ----------
+    features : List[np.ndarray]
+        The features
+    results : np.ndarray
+        The measured results
+    batch_size : int
+        The batch size
+    shuffle : bool
+        Whether to shuffle the dataset or not
+    """
+
+    def __init__(
+        self,
+        features,
+        results,
+        batch_size=128,
+        shuffle=False,
+    ):
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        self.data_size = len(features)
+
+        # flatten features and store the starting indices
+        self.segment_sizes = torch.tensor([len(feature) for feature in 
features])
+        self.feature_offsets = (
+            torch.cumsum(self.segment_sizes, 0, dtype=torch.int32) - 
self.segment_sizes
+        )
+        features = torch.cat([torch.tensor(feature) for feature in features])
+        norm = features.max(dim=0)[0]

Review Comment:
   nit: we usually prefer this form, since it makes it clear that `.max` returns two items
   
   ```suggestion
           norm, _ = features.max(dim=0)
   ```



##########
python/tvm/meta_schedule/cost_model/mlp_model.py:
##########
@@ -0,0 +1,743 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+MLP-based cost model
+"""
+
+import logging
+import math
+import os
+import random
+import tempfile
+from collections import OrderedDict
+from itertools import chain as itertools_chain
+from typing import Dict, List, NamedTuple
+from tqdm import tqdm
+
+import numpy as np
+import torch
+from torch import nn
+
+# pylint: disable=relative-beyond-top-level
+from ...contrib.tar import tar, untar
+from ...runtime import NDArray
+from ..cost_model import PyCostModel
+from ..feature_extractor import FeatureExtractor, PerStoreFeature
+from ..runner import RunnerResult
+from ..search_strategy import MeasureCandidate
+from ..tune_context import TuneContext
+from ..utils import derived_object, shash2hex
+
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
+
+
+# pylint: disable=no-member
+
+
+class SegmentSumMLPConfig(NamedTuple):
+    """SegmentSum MLP model configuration
+
+    Parameters
+    ----------
+    input_dim : int
+        The input dim for the model.
+    hidden_dim : int
+        The hidden dim for the model.
+    output_dim : int
+        The output dim for the model.
+    use_norm : bool
+        Whether to normalize the segment sum or not.
+    use_sigmoid : bool
+        Whether to use sigmoid on the final output or not.
+    """
+
+    input_dim: int = 172
+    hidden_dim: int = 256
+    output_dim: int = 1
+    use_norm: bool = False
+    use_sigmoid: bool = False
+
+    def to_dict(self):  # pylint: disable=missing-function-docstring
+        return {
+            "input_dim": self.input_dim,
+            "hidden_dim": self.hidden_dim,
+            "output_dim": self.output_dim,
+            "use_norm": self.use_norm,
+            "use_sigmoid": self.use_sigmoid,
+        }
+
+
+# pylint: disable=too-few-public-methods
+class FeatureGroup:
+    """Feature group
+
+    Parameters
+    ----------
+    group_hash : str
+        The hash of the group
+    features : List[np.ndarray]
+        The features
+    costs : List[float]
+        The costs
+    min_cost : float
+        The minimum cost
+    """
+
+    group_hash: str
+    features: List[np.ndarray]
+    costs: np.ndarray
+    min_cost: float
+
+    def __init__(
+        self,
+        group_hash: str,
+        features: List[np.ndarray],
+        costs: np.ndarray,
+    ) -> None:
+        self.group_hash = group_hash
+        self.features = features
+        self.costs = costs
+        self.min_cost = np.min(costs)
+
+    def append(  # pylint: disable=missing-function-docstring
+        self,
+        features: List[np.ndarray],
+        costs: np.ndarray,
+    ) -> None:
+        self.features.extend(features)
+        self.costs = np.append(self.costs, costs)
+        self.min_cost = np.min(self.costs)
+
+
+# pylint: disable=too-many-instance-attributes
+class SegmentDataLoader:
+    """Dataloader for SegmentSum MLP model.
+
+    Parameters
+    ----------
+    features : List[np.ndarray]
+        The features
+    results : np.ndarray
+        The measured results
+    batch_size : int
+        The batch size
+    shuffle : bool
+        Whether to shuffle the dataset or not
+    """
+
+    def __init__(
+        self,
+        features,
+        results,
+        batch_size=128,
+        shuffle=False,
+    ):
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        self.data_size = len(features)
+
+        # flatten features and store the starting indices
+        self.segment_sizes = torch.tensor([len(feature) for feature in 
features])
+        self.feature_offsets = (
+            torch.cumsum(self.segment_sizes, 0, dtype=torch.int32) - 
self.segment_sizes
+        )
+        features = torch.cat([torch.tensor(feature) for feature in features])
+        norm = features.max(dim=0)[0]
+        norm[norm == 0] = 1
+        self.features = features / norm
+        self.results = torch.tensor(results)
+        self.iter_order = self.pointer = None
+
+    def __len__(self):
+        return self.data_size
+
+    def __iter__(self):
+        if self.shuffle:
+            self.iter_order = torch.randperm(self.data_size)
+        else:
+            self.iter_order = torch.arange(self.data_size)
+        self.pointer = 0
+        return self
+
+    def __next__(self):
+        if self.pointer >= self.data_size:
+            raise StopIteration
+        batch_indices = self.iter_order[self.pointer : self.pointer + 
self.batch_size]
+        self.pointer += self.batch_size
+        return self._fetch_indices(batch_indices)
+
+    def _fetch_indices(self, indices):
+        segment_sizes, feature_offsets = self.segment_sizes[indices], 
self.feature_offsets[indices]
+        feature_indices = torch.empty(segment_sizes.sum(), dtype=torch.int32)
+        idx = 0
+        for offset, seg_size in zip(feature_offsets, segment_sizes):
+            feature_indices[idx : idx + seg_size] = torch.arange(offset, 
offset + seg_size)
+            idx += seg_size
+        features = self.features[feature_indices.long()]
+        results = self.results[indices.long()]
+        return segment_sizes, features, results
+
+
+class SegmentSumMLPModule(nn.Module):
+    """SegmentSum MLP model.
+
+    Parameters
+    ----------
+    input_dim : int
+        The input dim for the model.
+    hidden_dim : int
+        The hidden dim for the model.
+    output_dim : int
+        The output dim for the model.
+    use_norm : bool
+        Whether to normalize the segment sum or not.
+    use_sigmoid : bool
+        Whether to use sigmoid on the final output or not.
+    """
+
+    input_dim: int
+    hidden_dim: int
+    output_dim: int
+    use_norm: bool
+    use_sigmoid: bool
+
+    def __init__(  # pylint: disable=too-many-arguments
+        self,
+        input_dim: int = 172,
+        hidden_dim: int = 256,
+        output_dim: int = 1,
+        use_norm: bool = False,
+        use_sigmoid: bool = False,
+    ):
+        super().__init__()
+        self.encoder = nn.Sequential(
+            nn.Linear(input_dim, hidden_dim),
+            nn.ReLU(),
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.ReLU(),
+        )
+        self.norm = nn.BatchNorm1d(hidden_dim) if use_norm else nn.Identity()
+        self.layer0 = nn.Sequential(
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.ReLU(),
+        )
+        self.layer1 = nn.Sequential(
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.ReLU(),
+        )
+        self.decoder = nn.Linear(hidden_dim, output_dim)
+        self.sigmoid = nn.Sigmoid() if use_sigmoid else nn.Identity()
+
+    def forward(
+        self,
+        segment_sizes: torch.Tensor,
+        features: torch.Tensor,
+    ):
+        """Forward the inputs with the model.
+
+        Parameters
+        ----------
+        segment_sizes : Tensor
+            The sizes of the segments.
+        features : Tensor
+            The feature vectors of the candidates.
+        """
+        n_seg = len(segment_sizes)
+        encoded_features = self.encoder(features)
+        segment_indices = torch.repeat_interleave(
+            torch.arange(n_seg, device=features.device),
+            segment_sizes.long(),
+        )
+        n_dim = encoded_features.shape[1]
+        segment_sum = torch.scatter_add(
+            input=torch.zeros((n_seg, n_dim), dtype=encoded_features.dtype, 
device=features.device),
+            dim=0,
+            index=segment_indices.view(-1, 1).expand(-1, n_dim),
+            src=encoded_features,
+        )
+        out = self.norm(segment_sum)
+        out = self.layer0(out) + out
+        out = self.layer1(out) + out
+        out = self.decoder(out).squeeze()
+        out = self.sigmoid(out)
+        return out
+
+    # pylint: 
disable=no-self-use,too-many-arguments,too-many-locals,invalid-name
+    def lambda_rank_loss(self, preds, labels, k=None, eps=1e-10, sigma=1.0):
+        """LambdaLoss: Metric-Driven Loss for Learning-to-Rank
+
+        Parameters
+        ----------
+        preds : Tensor
+            The predicted runtime for each candidate.
+        labels : Tensor
+            The measured runtime for each candidate.
+        k : int
+            Loss for top k.
+            Default is None, which means computing all scores.
+        eps : float
+            The minimum value to the denominator and argument of log if they 
reach 0.
+        sigma : float
+            The scaling factor to the input of the sigmoid function.
+        """
+        device = preds.device
+        y_pred, y_true = preds[None, :], labels[None, :]
+        y_pred_sorted, indices_pred = y_pred.sort(descending=True, dim=-1)
+        y_true_sorted, _ = y_true.sort(descending=True, dim=-1)
+        true_sorted_by_preds = torch.gather(y_true, dim=1, index=indices_pred)
+        true_diffs = true_sorted_by_preds[:, :, None] - 
true_sorted_by_preds[:, None, :]
+        padded_pairs_mask = torch.isfinite(true_diffs) & (true_diffs > 0)
+        ndcg_at_k_mask = torch.zeros(
+            (y_pred.shape[1], y_pred.shape[1]), dtype=torch.bool, device=device
+        )
+        ndcg_at_k_mask[:k, :k] = 1
+        true_sorted_by_preds.clamp_(min=0.0)
+        y_true_sorted.clamp_(min=0.0)
+        pos_idxs = torch.arange(1, y_pred.shape[1] + 1).to(device)
+        D = torch.log2(1.0 + pos_idxs.float())[None, :]
+        maxDCGs = torch.sum(((torch.pow(2, y_true_sorted) - 1) / D)[:, :k], 
dim=-1).clamp(min=eps)
+        G = (torch.pow(2, true_sorted_by_preds) - 1) / maxDCGs[:, None]
+        weights = torch.abs(
+            torch.pow(D[:, :, None], -1.0) - torch.pow(D[:, None, :], -1.0)
+        ) * torch.abs(G[:, :, None] - G[:, None, :])
+        scores_diffs = (y_pred_sorted[:, :, None] - y_pred_sorted[:, None, 
:]).clamp(
+            min=-1e8, max=1e8
+        )
+        scores_diffs[torch.isnan(scores_diffs)] = 0.0
+        weighted_probs = (torch.sigmoid(sigma * scores_diffs).clamp(min=eps) 
** weights).clamp(
+            min=eps
+        )
+        losses = torch.log2(weighted_probs)
+        masked_losses = losses[padded_pairs_mask & ndcg_at_k_mask]
+        loss = -torch.sum(masked_losses)
+        return loss
+
+    def topk_score(
+        self,
+        pred_results: torch.Tensor,
+        gt_results: torch.Tensor,
+        k: int,
+    ) -> float:
+        """
+        Evaluate the top-k score
+
+        Parameters
+        ----------
+        pred_results: Tensor
+            The raw prediction
+        gt_results: Tensor
+            The measured results
+        k : int
+            The k in top k score
+
+        Returns
+        -------
+        score : float
+            The top-k score
+        """
+        topk_indices = torch.topk(pred_results, k, largest=False).indices
+        score = gt_results.min() / gt_results[topk_indices].min()
+        return score.item()
+
+
+@derived_object
+class MLPModel(PyCostModel):

Review Comment:
   Let's do some slight refactoring, moving code around, to make sure it's easy 
enough to use in both pre-training and MetaSchedule integration.
   
   In this file, we will have:
   
   ```python
   class SegmentDataLoader: ...
   
   class SegmentSumMLP(torch.nn.Module):
     def __init__(...): ...
     def forward(...): ...
     def lambda_rank_loss(...): ...
     def topk_score(...): ...
   
   class SegmentSumMLPTrainer:
     # this handles data loading and training of `SegmentSumMLP`
   
   class MLPModel(PyCostModel):
     model: SegmentSumMLP
   
     def __init__(...): ...
     def load(...): ...
     def save(...): ...
     def update(...): ... # an empty method if the model is frozen
     def predict(...):  ... # calls `self.model` to make predictions
     <del>def _train(...): ... </del> # move to the Trainer; no need to 
implement this method
     <del>def _validate(...): ... </del> # move to the Trainer; no need to 
implement this method
   ```
   
   The reason I suggested this refactoring is that I would like to make sure 
the two pieces of logic are decoupled:
   - Model pretraining
   - Integration with MetaSchedule
   
   With the `Trainer`, the logic for model pretraining doesn't need `MLPModel` 
anymore, and `MLPModel` can focus on the integration part without having to 
worry about the model training.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to