This is an automated email from the ASF dual-hosted git repository.

dimuthuupe pushed a commit to branch cybershuttle-staging
in repository https://gitbox.apache.org/repos/asf/airavata.git

commit 512012175e20cc5ca6a2b7084ca4460278f0feb8
Author: yasith <[email protected]>
AuthorDate: Fri Apr 4 09:10:24 2025 -0400

    update notebooks
---
 .../data/cosyne/cosyne_tutorial_part_1.ipynb       |    1 +
 .../data/cosyne/cosyne_tutorial_part_2.ipynb       | 2470 ++++++++++++++++++++
 .../jupyterhub/data/cosyne/cybershuttle.yml        |   64 +
 .../jupyterhub/data/gkeyll/plotE_z.ipynb           |    6 +-
 4 files changed, 2538 insertions(+), 3 deletions(-)

diff --git 
a/modules/agent-framework/deployments/jupyterhub/data/cosyne/cosyne_tutorial_part_1.ipynb
 
b/modules/agent-framework/deployments/jupyterhub/data/cosyne/cosyne_tutorial_part_1.ipynb
new file mode 100644
index 0000000000..e94e01ad2f
--- /dev/null
+++ 
b/modules/agent-framework/deployments/jupyterhub/data/cosyne/cosyne_tutorial_part_1.ipynb
@@ -0,0 +1 @@
+{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"gpuType":"T4"},"kernelspec":{"name":"python3","display_name":"Python
 
3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["#
 Foundations of Transformers in Neuroscience Tutorial\n","\n","Authors: Mehdi 
Azabou | Contributions: Vinam Arora, Shivashriganesh Mahato, Eva 
Dyer\n","\n","***\n","\n","In this notebook, we will go through an example for 
preparing a dataset using\n","data objects fro [...]
\ No newline at end of file
diff --git 
a/modules/agent-framework/deployments/jupyterhub/data/cosyne/cosyne_tutorial_part_2.ipynb
 
b/modules/agent-framework/deployments/jupyterhub/data/cosyne/cosyne_tutorial_part_2.ipynb
new file mode 100644
index 0000000000..1504196598
--- /dev/null
+++ 
b/modules/agent-framework/deployments/jupyterhub/data/cosyne/cosyne_tutorial_part_2.ipynb
@@ -0,0 +1,2470 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "3ImhQu3SEILJ",
+   "metadata": {
+    "id": "3ImhQu3SEILJ"
+   },
+   "source": [
+    "# Foundations of Transformers in Neuroscience Tutorial\n",
+    "\n",
+    "Authors: Shivashriganesh Mahato, Vinam Arora | Contributions: Mehdi 
Azabou, Sergey Shuvaev, Julie Young, Eva Dyer\n",
+    "\n",
+    "***\n",
+    "\n",
+    "The goal of this notebook is to show you how to work with datasets and 
dataloaders, build and train several neural decoding models (a simple MLP, a 
Transformer, and POYO), fine-tune a pretrained POYO model on a new session, and 
visualize training. This notebook is designed to be interactive and provide 
visual feedback. As you work through the cells, try to run them and observe the 
results. Detailed explanations are provided along the way.\n",
+    "\n",
+    "<center>\n",
+    "<img 
src=\"https://torch-brain.readthedocs.io/en/latest/_static/torch_brain_logo.png\";
 width=\"150\" height=\"150\" alt=\"torch_brain Logo\">\n",
+    "</center>\n",
+    "\n",
+    "\n",
+    "We will focus on three main topics in this notebook:\n",
+    "* **Part 1: DataLoaders**\n",
+    "* **Part 2: Training Models**\n",
+    "* **Part 3: Finetuning and Visualizations**\n",
+    "\n",
+    "\n",
+    "\\\\\n",
+    "General references:\n",
+    "\n",
+    "- [**torch_brain** 
documentation](<https://torch-brain.readthedocs.io/en/latest/index.html>)\n",
+    "- [pytorch 
tutorials](<https://pytorch.org/tutorials/beginner/basics/intro.html>)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "BycghpnsEKBg",
+   "metadata": {
+    "id": "BycghpnsEKBg"
+   },
+   "source": [
+    "***\n",
+    "## Setup\n",
+    "\n",
+    "First, let's install **torch_brain**, and get you set up with some simple 
utility functions that will be used throughout this notebook."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "hDjkZJkEEMJR",
+   "metadata": {
+    "id": "hDjkZJkEEMJR"
+   },
+   "outputs": [],
+   "source": [
+    "! uv pip install pytorch_brain -q\n",
+    "\n",
+    "# We use uv for the installation here, which seems to work better with 
google colab.\n",
+    "# Although UV is awesome and we highly recommend it, you could install it 
with vanilla pip\n",
+    "# in your local environments"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "aaMelfAWsl2h",
+   "metadata": {
+    "id": "aaMelfAWsl2h"
+   },
+   "source": [
+    "### Run the block below to load utility functions."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "agXH4LvWEMJS",
+   "metadata": {
+    "id": "agXH4LvWEMJS"
+   },
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import numpy as np\n",
+    "import matplotlib.pyplot as plt\n",
+    "from omegaconf import OmegaConf\n",
+    "import warnings\n",
+    "import logging\n",
+    "from torch_brain.utils import seed_everything\n",
+    "\n",
+    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else 
\"cpu\")\n",
+    "\n",
+    "warnings.filterwarnings('ignore')\n",
+    "logging.disable(logging.WARNING)\n",
+    "\n",
+    "\n",
+    "def move_to_gpu(data, device):\n",
+    "    \"\"\"\n",
+    "    Recursively moves tensors (or collections of tensors) to the given 
device.\n",
+    "    \"\"\"\n",
+    "    if isinstance(data, torch.Tensor):\n",
+    "        return data.to(device)\n",
+    "    elif isinstance(data, dict):\n",
+    "        return {k: move_to_gpu(v, device) for k, v in data.items()}\n",
+    "    elif isinstance(data, list):\n",
+    "        return [move_to_gpu(elem, device) for elem in data]\n",
+    "    else:\n",
+    "        return data\n",
+    "\n",
+    "\n",
+    "def bin_spikes(spikes, num_units, bin_size, right=True, 
num_bins=None):\n",
+    "    \"\"\"\n",
+    "    Bins spike timestamps into a 2D array: [num_units x num_bins].\n",
+    "    \"\"\"\n",
+    "    rate = 1 / bin_size  # avoid precision issues\n",
+    "    binned_spikes = np.zeros((num_units, num_bins))\n",
+    "    bin_index = np.floor((spikes.timestamps) * rate).astype(int)\n",
+    "    np.add.at(binned_spikes, (spikes.unit_index, bin_index), 1)\n",
+    "    return binned_spikes\n",
+    "\n",
+    "\n",
+    "def r2_score(y_pred, y_true):\n",
+    "    # Compute total sum of squares (variance of the true values)\n",
+    "    y_true_mean = torch.mean(y_true, dim=0, keepdim=True)\n",
+    "    ss_total = torch.sum((y_true - y_true_mean) ** 2)\n",
+    "\n",
+    "    # Compute residual sum of squares\n",
+    "    ss_res = torch.sum((y_true - y_pred) ** 2)\n",
+    "\n",
+    "    # Compute R^2\n",
+    "    r2 = 1 - ss_res / ss_total\n",
+    "\n",
+    "    return r2\n",
+    "\n",
+    "\n",
+    "def compute_r2(dataloader, model):\n",
+    "    # Compute R2 score over the entire dataset\n",
+    "    total_target = []\n",
+    "    total_pred = []\n",
+    "    for batch in dataloader:\n",
+    "        batch = move_to_gpu(batch, device)\n",
+    "        pred = model(**batch[\"model_inputs\"])\n",
+    "        target = batch[\"target_values\"]\n",
+    "\n",
+    "        # Store target and pred for visualization\n",
+    "        mask = torch.ones_like(target, dtype=torch.bool)\n",
+    "        if \"output_mask\" in batch[\"model_inputs\"]:\n",
+    "            mask = batch[\"model_inputs\"][\"output_mask\"]\n",
+    "        total_target.append(target[mask])\n",
+    "        total_pred.append(pred[mask])\n",
+    "\n",
+    "    # Concatenate all batch outputs\n",
+    "    total_target = torch.cat(total_target)\n",
+    "    total_pred = torch.cat(total_pred)\n",
+    "\n",
+    "    # Compute the R2 score\n",
+    "    r2 = r2_score(total_pred.flatten(), total_target.flatten())\n",
+    "\n",
+    "    return r2.item(), total_target, total_pred\n",
+    "\n",
+    "\n",
+    "def print_model(model: torch.nn.Module):\n",
+    "    \"\"\"\n",
+    "    Prints a summary of the model architecture and parameter count.\n",
+    "    \"\"\"\n",
+    "    model_str = str(model).split('\\n')\n",
+    "    print(\"\\nModel:\")\n",
+    "    print('\\n'.join(model_str[:5]))\n",
+    "    print(\"...\")\n",
+    "    print('\\n'.join(model_str[-min(5, len(model_str)):]))\n",
+    "    num_params = sum(p.numel() for p in model.parameters())\n",
+    "    if num_params > 1e9:\n",
+    "        param_str = f\"{num_params/1e9:.1f}G\"\n",
+    "    elif num_params > 1e6:\n",
+    "        param_str = f\"{num_params/1e6:.1f}M\"\n",
+    "    else:\n",
+    "        param_str = f\"{num_params/1e3:.1f}K\"\n",
+    "    print(f\"\\nNumber of parameters: {param_str}\\n\")\n",
+    "\n",
+    "\n",
+    "def plot_training_curves(r2_log, loss_log):\n",
+    "    \"\"\"\n",
+    "    Plots the training curves: training loss and validation R2 score.\n",
+    "    \"\"\"\n",
+    "    plt.figure(figsize=(12, 4))\n",
+    "    plt.subplot(1, 2, 1)\n",
+    "    plt.plot(np.linspace(0, len(loss_log), len(loss_log)), loss_log)\n",
+    "    plt.title(\"Training Loss\")\n",
+    "    plt.xlabel(\"Training Steps\")\n",
+    "    plt.ylabel(\"MSE Loss\")\n",
+    "    plt.grid()\n",
+    "    plt.subplot(1, 2, 2)\n",
+    "    plt.plot(r2_log)\n",
+    "    plt.title(\"Validation R2\")\n",
+    "    plt.xlabel(\"Epochs\")\n",
+    "    plt.ylabel(\"R2 Score\")\n",
+    "    plt.grid()\n",
+    "    plt.tight_layout()\n",
+    "    plt.show()\n",
+    "\n",
+    "\n",
+    "def generate_sinusoidal_position_embs(num_timesteps, dim):\n",
+    "    position = torch.arange(num_timesteps).unsqueeze(1)\n",
+    "    div_term = torch.exp(torch.arange(0, dim, 2) * (-np.log(10000.0) / 
dim))\n",
+    "    pe = torch.empty(num_timesteps, dim)\n",
+    "    pe[:, 0:dim // 2] = torch.sin(position * div_term)\n",
+    "    pe[:, dim//2:] = torch.cos(position * div_term)\n",
+    "    return pe\n",
+    "\n",
+    "\n",
+    "def load_pretrained(ckpt_path, model):\n",
+    "    print(\"Loading pretrained model...\")\n",
+    "    checkpoint = torch.load(ckpt_path, map_location=\"cpu\", 
weights_only=False)\n",
+    "    # poyo is pretrained using lightning, so model weights are prefixed 
with \"model.\"\n",
+    "    state_dict = {k.replace(\"model.\", \"\"): v for k, v in 
checkpoint[\"state_dict\"].items()}\n",
+    "    model.load_state_dict(state_dict)\n",
+    "    print(\"Done!\")\n",
+    "    return model\n",
+    "\n",
+    "\n",
+    "def reinit_vocab(emb_module, vocab):\n",
+    "    emb_module.extend_vocab(vocab)\n",
+    "    emb_module.subset_vocab(vocab)\n",
+    "\n",
+    "\n",
+    "def get_dataset_config(brainset, sessions):\n",
+    "    brainset_norms = {\n",
+    "        \"perich_miller_population_2018\": {\n",
+    "            \"mean\": 0.0,\n",
+    "            \"std\": 20.0\n",
+    "        }\n",
+    "    }\n",
+    "\n",
+    "    config = f\"\"\"\n",
+    "    - selection:\n",
+    "      - brainset: {brainset}\n",
+    "        sessions:\"\"\"\n",
+    "    if type(sessions) is not list:\n",
+    "        sessions = [sessions]\n",
+    "    for session in sessions:\n",
+    "        config += f\"\"\"\n",
+    "          - {session}\"\"\"\n",
+    "    config += f\"\"\"\n",
+    "      config:\n",
+    "        readout:\n",
+    "          readout_id: cursor_velocity_2d\n",
+    "          normalize_mean: {brainset_norms[brainset][\"mean\"]}\n",
+    "          normalize_std: {brainset_norms[brainset][\"std\"]}\n",
+    "          metrics:\n",
+    "            - metric:\n",
+    "                _target_: torchmetrics.R2Score\n",
+    "    \"\"\"\n",
+    "\n",
+    "    config = OmegaConf.create(config)\n",
+    "\n",
+    "    return config"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0WE9qTBPy887",
+   "metadata": {
+    "id": "0WE9qTBPy887"
+   },
+   "source": [
+    "***\n",
+    "\n",
+    "## Part 1: Data Loading\n",
+    "\n",
+    "***\n",
+    "\n",
+    "### Table of contents:\n",
+    "* 1.1 The life of a data sample\n",
+    "* 1.2 Setting up a basic data pipeline\n",
+    "* 1.3 Downloading a session"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "_5i33iBdyVrq",
+   "metadata": {
+    "id": "_5i33iBdyVrq"
+   },
+   "source": [
+    "### 1.1 The life of a data sample\n",
+    "\n",
+    "In the previous notebook, we saw how **torch_brain.data.Dataset** 
efficiently produces a data sample. But once a sample is drawn—what happens 
next?\n",
+    "\n",
+    "1.\tA data sample originates from the dataset. In **torch_brain**, this 
typically refers to a short time-slice of a neural recording.\n",
+    "2. Optionally, the sample can be transformed—for example, by applying 
augmentations such as dropping out neurons or brain regions.\n",
+    "3. Next, the sample is further processed and reshaped into a format 
suitable for model input. We refer to this step as **tokenization**.\n",
+    "4. The tokenized sample is then **collated** with other samples to form a 
**batch**.\n",
+    "5.\tFinally, the batch is passed through the model for the forward 
computation and loss evaluation.\n",
+    "\n",
+    "This process can be seen visually through the diagram below:\n",
+    "\n",
+    "<center>\n",
+    "<img 
src=\"https://ik.imagekit.io/7tkfmw7hc/dataloader.png?updatedAt=1743052497906\"; 
height=420 />\n",
+    "</center>\n",
+    "\n",
+    "That’s a lot of work—and in plain PyTorch, you’d be doing it all by hand. 
**torch_brain** makes each of these steps easy and intuitive to build and 
customize."
+   ]
+  },
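+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To make steps 3 and 4 concrete, here is a minimal sketch on toy tensors (shapes are illustrative) of what a tokenized sample looks like, and how **torch_brain**'s `collate` turns a list of such samples into a batch. We assume here that `collate` stacks equal-length tensors along a new leading batch dimension, which matches how batches are consumed later in this notebook."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from torch_brain.data import collate\n",
+    "\n",
+    "# Step 3 (tokenization) produces nested dictionaries of tensors...\n",
+    "tokenized_samples = [\n",
+    "    {\n",
+    "        \"model_inputs\": {\"x\": torch.rand(100, 55)},  # (timesteps, units)\n",
+    "        \"target_values\": torch.rand(100, 2),          # (timesteps, 2)\n",
+    "    }\n",
+    "    for _ in range(4)\n",
+    "]\n",
+    "\n",
+    "# ...and step 4 (collation) combines them into one batched dictionary\n",
+    "batch = collate(tokenized_samples)\n",
+    "print(batch[\"model_inputs\"][\"x\"].shape)  # expected: (4, 100, 55)\n",
+    "print(batch[\"target_values\"].shape)       # expected: (4, 100, 2)"
+   ]
+  },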
+  {
+   "cell_type": "markdown",
+   "id": "U0L7aQuiE7qL",
+   "metadata": {
+    "id": "U0L7aQuiE7qL"
+   },
+   "source": [
+    "### 1.2 Setting up a basic data pipeline\n",
+    "\n",
+    "Now let's define a utility function to create the training and validation 
datasets, samplers, and dataloaders.\n",
+    "This function will be used to set up the data for training and validation 
of our models.\n",
+    "\n",
+    "Note: We'll handle tokenization in the next part of this notebook."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "EEqTM1Nmy887",
+   "metadata": {
+    "id": "EEqTM1Nmy887"
+   },
+   "outputs": [],
+   "source": [
+    "from torch_brain.data import Dataset, collate, chain\n",
+    "from torch_brain.data.sampler import RandomFixedWindowSampler, 
SequentialFixedWindowSampler\n",
+    "from torch.utils.data import DataLoader\n",
+    "\n",
+    "def get_train_val_loaders(recording_id=None, cfg=None, batch_size=32, 
seed=0):\n",
+    "    \"\"\"Sets up train and validation Datasets, Samplers, and 
DataLoaders\n",
+    "    \"\"\"\n",
+    "\n",
+    "    # -- Train --\n",
+    "    train_dataset = Dataset(\n",
+    "        root=\"data\",                # root directory where .h5 files 
are found\n",
+    "        recording_id=recording_id,  # you either specify a single 
recording ID\n",
+    "        config=cfg,                 # or a config for multi-session 
training / more complex configs\n",
+    "        split=\"train\",\n",
+    "    )\n",
+    "    # We use a random sampler to improve generalization during 
training\n",
+    "    train_sampling_intervals = train_dataset.get_sampling_intervals()\n",
+    "    train_sampler = RandomFixedWindowSampler(\n",
+    "        sampling_intervals=train_sampling_intervals,\n",
+    "        window_length=1.0,          # context window of samples\n",
+    "        generator=torch.Generator().manual_seed(seed),\n",
+    "    )\n",
+    "    # Finally combine them in a dataloader\n",
+    "    train_loader = DataLoader(\n",
+    "        dataset=train_dataset,      # dataset\n",
+    "        sampler=train_sampler,      # sampler\n",
+    "        batch_size=batch_size,      # num of samples per batch\n",
+    "        collate_fn=collate,         # the collator\n",
+    "        num_workers=4,              # data sample processing (slicing, 
transforms, tokenization) happens in parallel; this sets the amount of that 
parallelization\n",
+    "        pin_memory=True,\n",
+    "    )\n",
+    "\n",
+    "    # -- Validation --\n",
+    "    val_dataset = Dataset(\n",
+    "        root=\"data\",\n",
+    "        recording_id=recording_id,\n",
+    "        config=cfg,\n",
+    "        split=\"valid\",\n",
+    "    )\n",
+    "    # For validation we don't randomize samples for reproducibility\n",
+    "    val_sampling_intervals = val_dataset.get_sampling_intervals()\n",
+    "    val_sampler = SequentialFixedWindowSampler(\n",
+    "        sampling_intervals=val_sampling_intervals,\n",
+    "        window_length=1.0,\n",
+    "    )\n",
+    "    # Combine them in a dataloader\n",
+    "    val_loader = DataLoader(\n",
+    "        dataset=val_dataset,\n",
+    "        sampler=val_sampler,\n",
+    "        batch_size=batch_size,\n",
+    "        collate_fn=collate,\n",
+    "        num_workers=4,\n",
+    "        pin_memory=True,\n",
+    "    )\n",
+    "\n",
+    "    train_dataset.disable_data_leakage_check()\n",
+    "    val_dataset.disable_data_leakage_check()\n",
+    "\n",
+    "    return train_dataset, train_loader, val_dataset, val_loader"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "v9iwnzt8y888",
+   "metadata": {
+    "id": "v9iwnzt8y888"
+   },
+   "source": [
+    "Aside from the tokenizer, that's all there is to the data pipeline! Most 
of it is handled behind-the-scenes with **torch_brain**."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "QA_B-daGG414",
+   "metadata": {
+    "id": "QA_B-daGG414"
+   },
+   "source": [
+    "### 1.3 Downloading a session\n",
+    "For our examples, we'll use data from a single neural recording from [1], 
which includes recordings of spiking activity from a monkey during a reaching 
task. Let's quickly download it to our environment."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1qxkvGX0y888",
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/";
+    },
+    "executionInfo": {
+     "elapsed": 7814,
+     "status": "ok",
+     "timestamp": 1743092108961,
+     "user": {
+      "displayName": "Eva Dyer",
+      "userId": "05212169819659068372"
+     },
+     "user_tz": 240
+    },
+    "id": "1qxkvGX0y888",
+    "outputId": "d962adbc-1583-4ec6-ba7b-224bb47f139b"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Downloading...\n",
+      "From: 
https://drive.google.com/uc?id=1W--Sm_BcphEC2snoF4zwPdHkkYGgAaUw\n",
+      "To: 
/content/data/perich_miller_population_2018/t_20130819_center_out_reaching.h5\n",
+      "\r",
+      "  0% 0.00/9.88M [00:00<?, ?B/s]\r",
+      "100% 9.88M/9.88M [00:00<00:00, 227MB/s]\n"
+     ]
+    }
+   ],
+   "source": [
+    "! mkdir -p data/perich_miller_population_2018\n",
+    "! gdown 1W--Sm_BcphEC2snoF4zwPdHkkYGgAaUw -O 
data/perich_miller_population_2018/t_20130819_center_out_reaching.h5"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "7F-vMzr0y888",
+   "metadata": {
+    "id": "7F-vMzr0y888"
+   },
+   "source": [
+    "Now, let's jump into training some models!"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "V-ebX-UMy888",
+   "metadata": {
+    "id": "V-ebX-UMy888"
+   },
+   "source": [
+    "***\n",
+    "\n",
+    "## Part 2: Training Models\n",
+    "\n",
+    "***\n",
+    "\n",
+    "In this section, we walk through training different neural decoders. 
We'll structure the code so there's a common training loop, then move onto 
implementing:\n",
+    "\n",
+    "- A simple MLP Neural Decoder\n",
+    "- A Transformer-based Neural Decoder [2]\n",
+    "- POYO! [3]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "cfRm95aFy888",
+   "metadata": {
+    "id": "cfRm95aFy888"
+   },
+   "source": [
+    "### 2.1 Warmup: Implementing an MLP based Neural Activity Decoder\n",
+    "\n",
+    "To learn how to write and train models in the `torch_brain` framework, 
let's kick things off with a simple MLP Neural Decoder."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "X59hXx4ay888",
+   "metadata": {
+    "id": "X59hXx4ay888"
+   },
+   "source": [
+    "#### 2.1.1 Defining the Model\n",
+    "\n",
+    "In **torch_brain**, every model needs to define two important methods:\n",
+    "1. `model.tokenize`: the tokenization step, which processes and reshapes 
data to align with model input.\n",
+    "2. `model.forward`: the forward pass, which produces predictions.\n",
+    "\n",
+    "In case of an MLP decoder, these will look like:\n",
+    "\n",
+    "<center>\n",
+    "<img 
src=\"https://ik.imagekit.io/7tkfmw7hc/mlp.png?updatedAt=1743064332486\"; 
width=900 />\n",
+    "</center>"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "Fmb6W850y888",
+   "metadata": {
+    "id": "Fmb6W850y888"
+   },
+   "outputs": [],
+   "source": [
+    "import torch.nn as nn\n",
+    "\n",
+    "class MLPNeuralDecoder(nn.Module):\n",
+    "    def __init__(self, num_units, bin_size, sequence_length, output_dim, 
hidden_dim):\n",
+    "        \"\"\"Initialize the neural net layers.\"\"\"\n",
+    "        super().__init__()\n",
+    "\n",
+    "        self.num_timesteps = int(sequence_length / bin_size)\n",
+    "        self.bin_size = bin_size\n",
+    "\n",
+    "        self.net = nn.Sequential(\n",
+    "            nn.Linear(self.num_timesteps * num_units, hidden_dim),\n",
+    "            nn.ReLU(),\n",
+    "            nn.Linear(hidden_dim, hidden_dim),\n",
+    "            nn.ReLU(),\n",
+    "            nn.Linear(hidden_dim, output_dim * self.num_timesteps),\n",
+    "        )\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        \"\"\"Produces predictions from a binned spiketrain.\n",
+    "        This is pure PyTorch code.\n",
+    "\n",
+    "        Shape of x: (B, T, N)\n",
+    "        \"\"\"\n",
+    "\n",
+    "        x = x.flatten(1)                          # (B, T, N)    -> (B, 
T*N)\n",
+    "        x = self.net(x)                           # (B, T*N)     -> (B, 
T*D_out)\n",
+    "        x = x.reshape(-1, self.num_timesteps, 2)  # (B, T*D_out) -> (B, 
T, D_out)\n",
+    "        return x\n",
+    "\n",
+    "    def tokenize(self, data):\n",
+    "        \"\"\"tokenizes a data sample, which is a sliced Data 
object\"\"\"\n",
+    "\n",
+    "        # A. Extract and bin neural activity (data.spikes)\n",
+    "        spikes = data.spikes\n",
+    "        x = bin_spikes(\n",
+    "            spikes=spikes,\n",
+    "            num_units=len(data.units),\n",
+    "            bin_size=self.bin_size,\n",
+    "            num_bins=self.num_timesteps\n",
+    "        ).T\n",
+    "        # Final shape of x here is (timestamps, num_neurons)\n",
+    "\n",
+    "        # B. Extract the corresponding cursor velocity, which will act as 
targets\n",
+    "        #    for training the MLP.\n",
+    "        y = data.cursor.vel\n",
+    "        # Final shape of y is (timestamps x 2)\n",
+    "        # Note that in this example we have choosen the bin size to match 
the\n",
+    "        # sampling rate of the recorded cursor velocity.\n",
+    "\n",
+    "        # Finally, we output the \"tokenized\" data in the form of a 
dictionary.\n",
+    "        data_dict = {\n",
+    "            \"model_inputs\": {\n",
+    "                \"x\": torch.tensor(x, dtype=torch.float32),\n",
+    "                # Models in torch_brain typically follow the convention 
that\n",
+    "                # fields that are input to model.forward() are stored 
in\n",
+    "                # \"model_inputs\". Although you are free to deviate from 
this,\n",
+    "                # we have found that this convention generally produces 
cleaner\n",
+    "                # training loops.\n",
+    "            },\n",
+    "            \"target_values\": torch.tensor(y, dtype=torch.float32),\n",
+    "        }\n",
+    "        return data_dict"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "533abuoQIjb5",
+   "metadata": {
+    "id": "533abuoQIjb5"
+   },
+   "source": [
+    "**More on tokenization:**\n",
+    "\n",
+    "Recall that a data sample emitted by the Dataset is only sliced in time, 
but still\n",
+    "contains all fields present in the dataset. In the tokenizer, we:\n",
+    "1. Extract the fields relevant to our machine learning problem. In this 
case, we only care about the spiketrain and the cursor velocity.\n",
+    "2. Process and reshape the extracted data in a format that is convenient 
for our model to process. In our example, we bin the spiketrain."
+   ]
+  },
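+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To see the binning step in isolation, here is a tiny, self-contained check of the `bin_spikes` utility from the setup block. A `types.SimpleNamespace` stands in for the real spikes object, since `bin_spikes` only reads the `.timestamps` and `.unit_index` fields."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from types import SimpleNamespace\n",
+    "import numpy as np\n",
+    "\n",
+    "# Two units, four spikes, 10 ms bins over a 40 ms window\n",
+    "toy_spikes = SimpleNamespace(\n",
+    "    timestamps=np.array([0.005, 0.012, 0.013, 0.031]),  # spike times (s)\n",
+    "    unit_index=np.array([0, 1, 1, 0]),                  # which unit fired\n",
+    ")\n",
+    "binned = bin_spikes(toy_spikes, num_units=2, bin_size=0.01, num_bins=4)\n",
+    "print(binned)\n",
+    "# [[1. 0. 0. 1.]\n",
+    "#  [0. 2. 0. 0.]]"
+   ]
+  },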
+  {
+   "cell_type": "markdown",
+   "id": "TOsvWrojy888",
+   "metadata": {
+    "id": "TOsvWrojy888"
+   },
+   "source": [
+    "#### 2.1.2 Defining the Training Loop\n",
+    "\n",
+    "The following functions perform the training steps. At each epoch, we use 
the data loader to sample batches of data and train the model using an 
optimizer. We then compute the R² score on the validation set to monitor the 
model's performance. This `train` function trains the model for a specified 
number of epochs and logs the R² score on the validation set."
+   ]
+  },
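+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a reminder, the `r2_score` utility defined in the setup block computes the coefficient of determination\n",
+    "\n",
+    "$$R^2 = 1 - \\frac{\\sum_i (y_i - \\hat{y}_i)^2}{\\sum_i (y_i - \\bar{y})^2},$$\n",
+    "\n",
+    "where $\\hat{y}_i$ are the predictions and $\\bar{y}$ is the mean of the true values. A score of $1$ is a perfect fit, while a score at or below $0$ means the model does no better than simply predicting the mean."
+   ]
+  },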
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "lrAvNcOPy888",
+   "metadata": {
+    "id": "lrAvNcOPy888"
+   },
+   "outputs": [],
+   "source": [
+    "import torch.nn.functional as F\n",
+    "\n",
+    "def train(model, optimizer, train_loader, val_loader, num_epochs=50, 
store_embs=False):\n",
+    "    # We'll store some intermediate outputs for visualization\n",
+    "    train_outputs = {\n",
+    "        'n_epochs': num_epochs,\n",
+    "        'unit_emb': [],\n",
+    "        'session_emb': [],\n",
+    "        'output_pred': [],\n",
+    "        'output_gt': [],\n",
+    "    }\n",
+    "\n",
+    "    r2_log = []\n",
+    "    loss_log = []\n",
+    "\n",
+    "    # Training loop\n",
+    "    for epoch in range(num_epochs):\n",
+    "        # Compute R² score on validation set\n",
+    "        r2, target, pred = compute_r2(val_loader, model)\n",
+    "        r2_log.append(r2)\n",
+    "\n",
+    "        # Training steps\n",
+    "        for batch in train_loader:\n",
+    "            batch = move_to_gpu(batch, device)\n",
+    "            loss = training_step(batch, model, optimizer)\n",
+    "            loss_log.append(loss.item())\n",
+    "\n",
+    "        print(f\"\\rEpoch {epoch+1}/{num_epochs} | Val R2 = {r2:.3f} | 
Loss = {loss.item():.3f}\", end=\"\")\n",
+    "\n",
+    "        # Store intermediate outputs\n",
+    "        if store_embs:\n",
+    "            
train_outputs['unit_emb'].append(model.unit_emb.weight[1:].detach().cpu().numpy())\n",
+    "            
train_outputs['session_emb'].append(model.session_emb.weight[1:].detach().cpu().numpy())\n",
+    "        
train_outputs['output_gt'].append(target.detach().cpu().numpy())\n",
+    "        
train_outputs['output_pred'].append(pred.detach().cpu().numpy())\n",
+    "\n",
+    "    # Compute final R² score\n",
+    "    r2, _, _ = compute_r2(val_loader, model)\n",
+    "    r2_log.append(r2)\n",
+    "    print(f\"\\nDone! Final validation R2 = {r2:.3f}\")\n",
+    "\n",
+    "    return r2_log, loss_log, train_outputs\n",
+    "\n",
+    "\n",
+    "def training_step(batch, model, optimizer):\n",
+    "    optimizer.zero_grad()                  # Step 0. Clear old 
gradients\n",
+    "    pred = model(**batch[\"model_inputs\"])  # Step 1. Do forward pass\n",
+    "    target = batch[\"target_values\"]\n",
+    "    loss = F.mse_loss(pred, target)        # Step 2. Compute loss\n",
+    "    loss.backward()                        # Step 3. Backward pass\n",
+    "    optimizer.step()                       # Step 4. Update model 
params\n",
+    "    return loss\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "PGeSPwBGy888",
+   "metadata": {
+    "id": "PGeSPwBGy888"
+   },
+   "source": [
+    "#### 2.1.3 Let's train!"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "RuDAItbuy889",
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/";,
+     "height": 462
+    },
+    "executionInfo": {
+     "elapsed": 313592,
+     "status": "ok",
+     "timestamp": 1743092422552,
+     "user": {
+      "displayName": "Eva Dyer",
+      "userId": "05212169819659068372"
+     },
+     "user_tz": 240
+    },
+    "id": "RuDAItbuy889",
+    "outputId": "2891859d-553d-4853-df36-3177c4fdc169"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Num Units in Session: 55\n",
+      "Epoch 100/100 | Val R2 = 0.558 | Loss = 5.405\n",
+      "Done! Final validation R2 = 0.561\n"
+     ]
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAABKUAAAGGCAYAAACqvTJ0AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAA69VJREFUeJzs3XdcU9f7B/BPEiBsEBAQxa3g1rqtq+7RbYfWtmp3rV122l9rqx121w5bW1vrqFZbv9YOrUpxV9Q6cIOo4EA2siHz/v7AhITcLEi4gJ/36+Wr5N5zzz05RJs8ec5zZIIgCCAiIiIiIiIiIqpDcqkHQERERERERERE1x8GpYiIiIiIiIiIqM4xKEVERERERERERHWOQSkiIiIiIiIiIqpzDEoREREREREREVGdY1CKiIiIiIiIiIjqHINSRERERERERERU5xiUIiIiIiIiIiKiOsegFBERERERERER
 [...]
+      "text/plain": [
+       "<Figure size 1200x400 with 2 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "seed_everything(0)\n",
+    "\n",
+    "# 1. Setup datasets and dataloader\n",
+    "recording_id = 
\"perich_miller_population_2018/t_20130819_center_out_reaching\"\n",
+    "train_dataset, train_loader, val_dataset, val_loader = 
get_train_val_loaders(recording_id, batch_size=64)\n",
+    "num_units = len(train_dataset.get_unit_ids())\n",
+    "print(f\"Num Units in Session: {num_units}\")\n",
+    "\n",
+    "# 2. Initialize Model with the new MLP definition\n",
+    "mlp_model = MLPNeuralDecoder(\n",
+    "    num_units=num_units,    # Num. of units inputted (spiking 
activity)\n",
+    "    #\n",
+    "    bin_size=10e-3,         # Duration (s) of bins\n",
+    "    sequence_length=1.0,    # Context length of the model\n",
+    "    #\n",
+    "    output_dim=2,           # Output dimension of final readout layer\n",
+    "    hidden_dim=32,          # Hidden dimension of the model\n",
+    ")\n",
+    "mlp_model = mlp_model.to(device)\n",
+    "\n",
+    "# 3. Connect Tokenizer to Datasets\n",
+    "transform = mlp_model.tokenize\n",
+    "train_dataset.transform = transform\n",
+    "val_dataset.transform = transform\n",
+    "\n",
+    "# 4. Setup Optimizer\n",
+    "optimizer = torch.optim.AdamW(mlp_model.parameters(), lr=1e-3)\n",
+    "\n",
+    "# 5. Train!\n",
+    "mlp_r2_log, mlp_loss_log, mlp_train_outputs = train(mlp_model, optimizer, 
train_loader, val_loader, num_epochs=100)\n",
+    "\n",
+    "# Plot the training loss and validation R2\n",
+    "plot_training_curves(mlp_r2_log, mlp_loss_log)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "mZ6kUr_Sy889",
+   "metadata": {
+    "id": "mZ6kUr_Sy889"
+   },
+   "source": [
+    "You should now see a training loss curve steadily decreasing and the 
validation R² rising. These trends mean your model is learning effectively!"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "TEt1Bbpcy889",
+   "metadata": {
+    "id": "TEt1Bbpcy889"
+   },
+   "source": [
+    "### 2.2 Training a simple Transformer for Neural Decoding\n",
+    "Next up: let's move on to the main course - Transformers! Let's explore 
how attention can be used for neural decoding by building and training a simple 
transformer."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "SrtWWwspy889",
+   "metadata": {
+    "id": "SrtWWwspy889"
+   },
+   "source": [
+    "#### 2.2.1 Defining a Transformer model\n",
+    "\n",
+    "The philosophy of having a `model.tokenize` and `model.forward` methods 
remains the same as before, however our model is a bit more complex than the 
humble MLP.\n",
+    "\n",
+    "<center>\n",
+    "  <img 
src=\"https://ik.imagekit.io/7tkfmw7hc/transformer.png?updatedAt=1743064332335\";
 alt=\"The Transformer model architecture.\" width=900/>\n",
+    "</center>\n"
+   ]
+  },
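+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A quick note on the fixed position embeddings used below. The `generate_sinusoidal_position_embs` utility from the setup block places all sines in the first half of the embedding and all cosines in the second half:\n",
+    "\n",
+    "$$\\mathrm{PE}(t,\\ i) = \\sin\\big(t \\cdot 10000^{-2i/d}\\big), \\qquad \\mathrm{PE}(t,\\ i + d/2) = \\cos\\big(t \\cdot 10000^{-2i/d}\\big), \\qquad 0 \\le i < d/2,$$\n",
+    "\n",
+    "so each timestep $t$ gets a unique, smoothly varying code that the attention layers can use to recover temporal order."
+   ]
+  },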
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "mfgVIeUKy889",
+   "metadata": {
+    "id": "mfgVIeUKy889"
+   },
+   "outputs": [],
+   "source": [
+    "from torch_brain.nn import FeedForward\n",
+    "\n",
+    "class TransformerNeuralDecoder(nn.Module):\n",
+    "    def __init__(\n",
+    "        self, num_units, bin_size, sequence_length,   # data 
properties\n",
+    "        dim_output, dim_hidden, n_layers, n_heads,    # transformer 
properties\n",
+    "    ):\n",
+    "        \"\"\"Initialize the neural net components\"\"\"\n",
+    "        super().__init__()\n",
+    "\n",
+    "        self.num_timesteps = int(sequence_length / bin_size)\n",
+    "        self.bin_size = bin_size\n",
+    "\n",
+    "        # Create the read-in/out linear layers\n",
+    "        self.readin = nn.Linear(num_units, dim_hidden)\n",
+    "        self.readout = nn.Linear(dim_hidden, dim_output)\n",
+    "\n",
+    "        # Create the position embeddings\n",
+    "        # Note that these are kept constant in this implementation, i.e. 
_not_ learnable\n",
+    "        self.position_embeddings = nn.Parameter(\n",
+    "            data=generate_sinusoidal_position_embs(self.num_timesteps, 
dim_hidden),\n",
+    "            requires_grad=False,\n",
+    "        )\n",
+    "\n",
+    "        # Create the transformer layers:\n",
+    "        # each composed of the Attention and the feedforward (FFN) 
blocks\n",
+    "        self.transformer_layers = nn.ModuleList([\n",
+    "            nn.ModuleList([\n",
+    "                nn.MultiheadAttention(\n",
+    "                    embed_dim=dim_hidden,\n",
+    "                    num_heads=n_heads,\n",
+    "                    batch_first=True,\n",
+    "                ),\n",
+    "                FeedForward(dim=dim_hidden),\n",
+    "            ])\n",
+    "            for _ in range(n_layers)\n",
+    "        ])\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        \"\"\"Produces predictions from a binned spiketrain.\n",
+    "        This is pure PyTorch code.\n",
+    "\n",
+    "        Shape of x: (B, T, N)\n",
+    "        \"\"\"\n",
+    "\n",
+    "        # Read-in: converts our input marix to transformer tokens; one 
token for each timestep\n",
+    "        x = self.readin(x)  # (B, T, N) -> (B, T, D)\n",
+    "\n",
+    "        # Add position embeddings to the tokens\n",
+    "        x = x + self.position_embeddings[None, ...]  # -> (B, T, D)\n",
+    "\n",
+    "        # Transformer\n",
+    "        for attn, ffn in self.transformer_layers:\n",
+    "            x = x + attn(x, x, x, need_weights=False)[0]\n",
+    "            x = x + ffn(x)\n",
+    "\n",
+    "        # Readout: converts tokens to 2d vectors; each vector signifying 
(v_x, v_y) at that timestep\n",
+    "        x = self.readout(x)  # (B, T, D) -> (B, T, 2)\n",
+    "\n",
+    "        return x\n",
+    "\n",
+    "    def tokenize(self, data):\n",
+    "        # Same tokenizer as the MLP\n",
+    "\n",
+    "        # A. Bin spikes\n",
+    "        x = bin_spikes(\n",
+    "            spikes=data.spikes,\n",
+    "            num_units=len(data.units),\n",
+    "            bin_size=self.bin_size,\n",
+    "            num_bins=self.num_timesteps,\n",
+    "        ).T\n",
+    "\n",
+    "        # B. Extract targets\n",
+    "        y = data.cursor.vel\n",
+    "\n",
+    "        data_dict = {\n",
+    "            \"model_inputs\": {\n",
+    "                \"x\": torch.tensor(x, dtype=torch.float32),\n",
+    "            },\n",
+    "            \"target_values\": torch.tensor(y, dtype=torch.float32),\n",
+    "        }\n",
+    "        return data_dict"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "tNLt0os-y889",
+   "metadata": {
+    "id": "tNLt0os-y889"
+   },
+   "source": [
+    "#### 2.2.2 Let's train!"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "qWAph1Vjy889",
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/";,
+     "height": 444
+    },
+    "executionInfo": {
+     "elapsed": 314676,
+     "status": "ok",
+     "timestamp": 1743092737253,
+     "user": {
+      "displayName": "Eva Dyer",
+      "userId": "05212169819659068372"
+     },
+     "user_tz": 240
+    },
+    "id": "qWAph1Vjy889",
+    "outputId": "df155464-0118-477c-809e-30c849347425"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 100/100 | Val R2 = 0.671 | Loss = 3.814\n",
+      "Done! Final validation R2 = 0.710\n"
+     ]
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAABKUAAAGGCAYAAACqvTJ0AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAA/nFJREFUeJzs3XdYk1f7B/BvEkLYICAgijJU3KMO3KNuba2jrdql1o63ahdvl/21tnbRaa2trX2tVm1rtXbYZVXqHrjFLeJAQBkCQpgh4/n9ERISMkggIajfz3V5lTzjPOc5pPrkzn3uIxIEQQAREREREREREVEDEru6A0REREREREREdPthUIqIiIiIiIiIiBocg1JERERERERERNTgGJQiIiIiIiIiIqIGx6AUERERERERERE1OAaliIiIiIiIiIiowTEoRUREREREREREDY5BKSIiIiIiIiIianAMShERERER
 [...]
+      "text/plain": [
+       "<Figure size 1200x400 with 2 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "seed_everything(0)\n",
+    "\n",
+    "# 1. Setup datasets and dataloader\n",
+    "recording_id = 
\"perich_miller_population_2018/t_20130819_center_out_reaching\"\n",
+    "train_dataset, train_loader, val_dataset, val_loader = 
get_train_val_loaders(recording_id, batch_size=64)\n",
+    "num_units = len(train_dataset.get_unit_ids())\n",
+    "\n",
+    "# 2. Initialize Model\n",
+    "tf_model = TransformerNeuralDecoder(\n",
+    "    num_units=num_units,    # Num. of units inputted (spiking 
activity)\n",
+    "    #\n",
+    "    bin_size=10e-3,         # Duration (s) of bins\n",
+    "    sequence_length=1.0,    # Context length of the model\n",
+    "    #\n",
+    "    dim_output=2,           # Output dimension of final readout layer\n",
+    "    dim_hidden=128,         # Hidden dimension of the model\n",
+    "    n_layers=3,             # Num. of transformer layers\n",
+    "    n_heads=4,              # Num. of heads in MHA blocks\n",
+    ").to(device)\n",
+    "\n",
+    "# 3. Connect Tokenizer to Datasets\n",
+    "train_dataset.transform = tf_model.tokenize\n",
+    "val_dataset.transform = tf_model.tokenize\n",
+    "\n",
+    "# 4. Setup Optimizer\n",
+    "optimizer = torch.optim.AdamW(tf_model.parameters(), lr=1e-3)\n",
+    "\n",
+    "# 5. Train!\n",
+    "transformer_r2_log, transformer_loss_log, transformer_train_outputs = 
train(tf_model, optimizer, train_loader, val_loader, num_epochs=100)\n",
+    "\n",
+    "# Plot the training loss and validation R2\n",
+    "plot_training_curves(transformer_r2_log, transformer_loss_log)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "rxqk7gJry88-",
+   "metadata": {
+    "id": "rxqk7gJry88-"
+   },
+   "source": [
+    "\n",
+    "### 2.3 Training POYO!\n",
+    "\n",
+    "Here we show how we can instantiate a POYO model and train it from 
scratch."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9-csi0JSZUHj",
+   "metadata": {
+    "id": "9-csi0JSZUHj"
+   },
+   "source": [
+    "From the figure below, we can see that the POYO model has a lot more 
going on internally in comparison to the two models we just saw. Thankfully, 
**torch_brain** provides implementations of standard neural decoding models 
that we can directly use in (almost) the same framework as we did with the 
previous models. We currently have implementations of POYO [3], and POYO+ [4], 
and are actively working on adding more models, such as NDT-2 [5], MTM [6], 
etc.\n",
+    "\n",
+    "<center>\n",
+    "<img 
src=\"https://ik.imagekit.io/7tkfmw7hc/poyo.png?updatedAt=1743064332701\"; 
width=900/>\n",
+    "</center>\n",
+    "\n",
+    "<!-- For custom models (and when training on a single session), it 
suffices to specify a single recording ID when defining the Dataset and 
DataLoaders, especially since things like normalization can be defined within 
the model definition. Since POYO was built for multi-session, it's important to 
define a more concrete configuration that allows for training across multiple 
sessions, or even multiple datasets. Here we simplify the process for the 
session we have been using, by declari [...]
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "PC_5tFCny88-",
+   "metadata": {
+    "id": "PC_5tFCny88-"
+   },
+   "outputs": [],
+   "source": [
+    "seed_everything(0)\n",
+    "\n",
+    "# 1. Setup datasets and dataloader\n",
+    "# For a model like POYO, which was built for multi-session training, the 
way to\n",
+    "# instantiate a dataset is just slightly more involved than what we have 
used\n",
+    "# so far. We have abstracted that difference in a utility function 
`get_dataset_config`\n",
+    "train_dataset, train_loader, val_dataset, val_loader = 
get_train_val_loaders(\n",
+    "    cfg=get_dataset_config(\"perich_miller_population_2018\", 
\"t_20130819_center_out_reaching\"),\n",
+    "    batch_size=64,\n",
+    ")\n",
+    "\n",
+    "# 2. Instantiate the model. The model implementation is provided by 
torch_brain\n",
+    "from torch_brain.models import POYO\n",
+    "from torch_brain.registry import MODALITY_REGISTRY\n",
+    "poyo_model = POYO(\n",
+    "    sequence_length=1.0,                                    # Context 
length of the model\n",
+    "    readout_spec=MODALITY_REGISTRY['cursor_velocity_2d'],   # POYO allows 
for multiple readout modalities; this is how we choose\n",
+    "    #\n",
+    "    latent_step=1.0 / 8,                                    # Timestep of 
the learned latent grid\n",
+    "    num_latents_per_step=16,                                # Number of 
unique learned latents per timestep\n",
+    "    #\n",
+    "    dim=64,                                                 # Hidden 
dimension of the model\n",
+    "    depth=6,                                                # Number of 
transformer layers\n",
+    "    #\n",
+    "    dim_head=64,                                            # Dimension 
of each attention head\n",
+    "    cross_heads=2,                                          # Num. of 
heads in cross-attention blocks\n",
+    "    self_heads=8,                                           # Num. of 
heads in self attention blocks\n",
+    ").to(device)\n",
+    "\n",
+    "# 2.5: Extra step: populate the Unit and Session Embedding 
Vocabularies\n",
+    "poyo_model.unit_emb.initialize_vocab(train_dataset.get_unit_ids())\n",
+    
"poyo_model.session_emb.initialize_vocab(train_dataset.get_session_ids())\n",
+    "\n",
+    "# 3. Connect tokenizers to Datasets\n",
+    "train_dataset.transform = poyo_model.tokenize\n",
+    "val_dataset.transform = poyo_model.tokenize\n",
+    "\n",
+    "# 4. Setup Optimizer\n",
+    "optimizer = torch.optim.AdamW(poyo_model.parameters(), lr=1e-3)\n",
+    "\n",
+    "# 5. Train!\n",
+    "poyo_r2_log, poyo_loss_log, poyo_train_outputs = train(\n",
+    "  poyo_model, optimizer, train_loader, val_loader,\n",
+    "  num_epochs=100, store_embs=True,\n",
+    ")\n",
+    "\n",
+    "# Plot the training loss and validation R2\n",
+    "plot_training_curves(poyo_r2_log, poyo_loss_log)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "R02IoWf_y88-",
+   "metadata": {
+    "id": "R02IoWf_y88-"
+   },
+   "source": [
+    "***\n",
+    "\n",
+    "<a id=\"Finetuning\"></a>\n",
+    "## Part 3. Finetuning and Visualizing\n",
+    "\n",
+    "***\n",
+    "\n",
+    "In the above sections, we saw that creating and training models for a 
single session of neural recordings, can be made quick and easy using 
**torch_brain**. In addition, standard models in **torch_brain** have been 
pretrained across different sessions, subjects, tasks, and even datasets. We 
can utilize the generalization capabilities of these scaled up models, and 
*fine-tune* them on new data with minimal effort.\n",
+    "\n",
+    "In this section, we will demonstrate how to fine-tune a pretrained 
POYO-mp model on a new session. We note that the session we've\n",
+    "been using so far (`t_20130819_center_out_reaching`) was held out of 
POYO-mp pretraining.\n",
+    "\n",
+    "***\n",
+    "### Table of contents:\n",
+    "* 3.1 Loading a Pretrained Model\n",
+    "* 3.2 Finetuning Strategies\n",
+    "* 3.3 Let's fine-tune!\n",
+    "* 3.4 Let's compare\n",
+    "* 3.5 Visualizing the POYO model"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "oFkTHcipy88-",
+   "metadata": {
+    "id": "oFkTHcipy88-"
+   },
+   "source": [
+    "### 3.1 Loading a Pretrained Model\n",
+    "First we will download the POYO-MP pretrained checkpoint."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "-6x_N3-hy88-",
+   "metadata": {
+    "id": "-6x_N3-hy88-"
+   },
+   "outputs": [],
+   "source": [
+    "! uv pip install boto3 -q\n",
+    "import boto3\n",
+    "from botocore import UNSIGNED\n",
+    "from botocore.config import Config\n",
+    "\n",
+    "# Download the pretrained model\n",
+    "s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))\n",
+    "s3.download_file(\n",
+    "    \"torch-brain\",\n",
+    "    \"model-zoo/poyo_mp.ckpt\",\n",
+    "    \"poyo_mp.ckpt\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "T_yUo6l0y88-",
+   "metadata": {
+    "id": "T_yUo6l0y88-"
+   },
+   "source": [
+    "\n",
+    "### 3.2 Finetuning Strategies\n",
+    "\n",
+    "As a reminder, POYO learns an embedding space for units and sessions that 
is shared across all its training data. When fine-tuning POYO on a new session, 
we need to find how to project the new units and session into the previously 
learned embedding space. For this we keep most of the model weights unchanged 
and only train the new embeddings (we call this *freezing the backbone*). We 
can opt to unfreeze them after a certain number of epochs, or keep them frozen 
throughout.\n",
+    "\n",
+    "<center>\n",
+    "  <img 
src=\"https://media-hosting.imagekit.io/61678065e60f442a/Screenshot 2025-03-26 
at 6.59.27 
PM.png?Expires=1837637971&Key-Pair-Id=K2ZIVPTIP2VGHC&Signature=aG8nDmY1CwhZ~6bOcuYws22lY3qn3qbNf7m-gWFrO3ZvBNDpfLSpxxhlbrRv5wysohL~COsjtziTinRALYxzkN7QxiQ1hhIpM8ieFamz2Po-0Ocvx12~FQeDOMRFMCfVay3QdOkXTTyP0PgFab7~e9oK2VEljN20SzhAd2mH4qVLLh-kZXHGvrnjbbSIxQiQ~OKL04tNzyRSwKglhv13l54qtpqUAparzW5lA3H50ZuiLMJJ9ue03Qg4a2X32SrHMkJUpoeFZgVgAw046Ukry-MZixsFVynLVv8vDWdl7e4ZhQIE3kJuXxpowWqdlOoSk5-oBDw
 [...]
+    "</center>"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "PzWxTVE-bMOy",
+   "metadata": {
+    "id": "PzWxTVE-bMOy"
+   },
+   "source": [
+    "#### Finetuning function\n",
+    "\n",
+    "Depending on the value of `epoch_to_unfreeze`, we have different 
fine-tuning strategies:\n",
+    "- **Full fine-tuning** (with gradual unfreezing): if we set 
`epoch_to_unfreeze >= 0` then we will train the new embeddings for 
`epoch_to_unfreeze` epochs,\n",
+    "then unfreeze the backbone and train the entire model for the remaining 
epochs.\n",
+    "- **Unit Identification**: if we set `epoch_to_unfreeze == -1`, then we 
will only train the new embeddings for all epochs, i.e. the backbone\n",
+    "will remain frozen throughout training.\n",
+    "\n",
+    "For this we'll define a new training function that allows for different 
fine-tuning strategies (it's otherwise the same as the previous training 
function)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5k7z0v1Jy88-",
+   "metadata": {
+    "id": "5k7z0v1Jy88-"
+   },
+   "outputs": [],
+   "source": [
+    "from tqdm import tqdm\n",
+    "from torch.cuda.amp import autocast, GradScaler\n",
+    "\n",
+    "def finetune(model, optimizer, train_loader, val_loader, num_epochs=50, 
epoch_to_unfreeze=30):\n",
+    "    # Freeze the backbone\n",
+    "    backbone_params = [\n",
+    "        p for p in model.named_parameters()\n",
+    "        if (\n",
+    "            'unit_emb' not in p[0]\n",
+    "            and 'session_emb' not in p[0]\n",
+    "            and 'readout' not in p[0]\n",
+    "            and p[1].requires_grad\n",
+    "        )\n",
+    "    ]\n",
+    "    for _, param in backbone_params:\n",
+    "        param.requires_grad = False\n",
+    "\n",
+    "    # We'll store some intermediate outputs for visualization\n",
+    "    train_outputs = {\n",
+    "        'n_epochs': num_epochs,\n",
+    "        'epoch_to_unfreeze': epoch_to_unfreeze,\n",
+    "        'unit_emb': [],\n",
+    "        'session_emb': [],\n",
+    "        'output_pred': [],\n",
+    "        'output_gt': [],\n",
+    "    }\n",
+    "\n",
+    "    r2_log = []\n",
+    "    loss_log = []\n",
+    "\n",
+    "    # Training loop\n",
+    "    for epoch in range(num_epochs):\n",
+    "        # Unfreeze the backbone after `epoch_to_unfreeze` epochs\n",
+    "        if epoch == epoch_to_unfreeze:\n",
+    "            for _, param in backbone_params:\n",
+    "                param.requires_grad = True\n",
+    "            print(\" Unfreezing entire model\")\n",
+    "\n",
+    "        with torch.no_grad():\n",
+    "            r2, target, pred = compute_r2(val_loader, model)\n",
+    "            r2_log.append(r2)\n",
+    "\n",
+    "        for batch in train_loader:\n",
+    "            batch = move_to_gpu(batch, device)\n",
+    "            loss = training_step(batch, model, optimizer)\n",
+    "            loss_log.append(loss.item())\n",
+    "\n",
+    "        print(f\"\\rEpoch {epoch+1}/{num_epochs} | Val R2 = {r2:.3f} | 
Loss = {loss.item():.3f}\", end=\"\")\n",
+    "\n",
+    "        # Store intermediate outputs\n",
+    "        
train_outputs['unit_emb'].append(model.unit_emb.weight[1:].detach().cpu().numpy())\n",
+    "        
train_outputs['session_emb'].append(model.session_emb.weight[1:].detach().cpu().numpy())\n",
+    "        
train_outputs['output_gt'].append(target.detach().cpu().numpy())\n",
+    "        
train_outputs['output_pred'].append(pred.detach().cpu().numpy())\n",
+    "\n",
+    "        del target, pred\n",
+    "\n",
+    "    # Compute final R² score\n",
+    "    r2, _, _ = compute_r2(val_loader, model)\n",
+    "    r2_log.append(r2)\n",
+    "    print(f\"\\nDone! Final validation R2 = {r2:.3f}\")\n",
+    "\n",
+    "    return r2_log, loss_log, train_outputs"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2XebFpsKy88-",
+   "metadata": {
+    "id": "2XebFpsKy88-"
+   },
+   "source": [
+    "### 3.3 Let's fine-tune!\n",
+    "Note: the model was pretrained with the fixed vocabulary of units and 
sessions used during training. Since we are now fine-tuning on a new session 
(with new units), we extend the vocabulary to include the new units and 
session, then subset the vocabulary to only look at\n",
+    "the new ones.\n",
+    "\n",
+    "Here we'll put the full training workflow into a function as well so we 
can easily repeat it for different sessions or\n",
+    "finetuning strategies. Of course, if one wishes to manually implement a 
more complex training loop, they can do so as well."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "CAohIKciy88-",
+   "metadata": {
+    "id": "CAohIKciy88-"
+   },
+   "outputs": [],
+   "source": [
+    "seed_everything(0)\n",
+    "\n",
+    "# 1. Setup datasets and dataloader\n",
+    "train_dataset, train_loader, val_dataset, val_loader = 
get_train_val_loaders(\n",
+    "    cfg=get_dataset_config(\"perich_miller_population_2018\", 
\"t_20130819_center_out_reaching\"),\n",
+    "    batch_size=64,\n",
+    ")\n",
+    "num_units = len(train_dataset.get_unit_ids())\n",
+    "\n",
+    "# 2. Instantiate the model and load pretrained weights (with poyo-mp 
hparams)\n",
+    "from torch_brain.models import POYO\n",
+    "from torch_brain.registry import MODALITY_REGISTRY\n",
+    "poyo_ft_model = POYO(\n",
+    "    sequence_length=1.0,\n",
+    "    latent_step=1.0 / 8,\n",
+    "    dim=128,\n",
+    "    readout_spec=MODALITY_REGISTRY['cursor_velocity_2d'],\n",
+    "    dim_head=64,\n",
+    "    num_latents_per_step=32,\n",
+    "    depth=24,\n",
+    "    cross_heads=4,\n",
+    "    self_heads=8,\n",
+    ")\n",
+    "\n",
+    "ckpt_path = 'poyo_mp.ckpt'\n",
+    "poyo_ft_model = load_pretrained(ckpt_path, poyo_ft_model)\n",
+    "\n",
+    "# 2.5. Reinitialize the vocabs for the new session\n",
+    "reinit_vocab(poyo_ft_model.unit_emb, train_dataset.get_unit_ids())\n",
+    "reinit_vocab(poyo_ft_model.session_emb, 
train_dataset.get_session_ids())\n",
+    "\n",
+    "poyo_ft_model.to(device)\n",
+    "\n",
+    "# 3. Connect tokenizers to Datasets\n",
+    "train_dataset.transform = poyo_ft_model.tokenize\n",
+    "val_dataset.transform = poyo_ft_model.tokenize\n",
+    "\n",
+    "# 4. Setup Optimizer\n",
+    "optimizer = torch.optim.AdamW(poyo_ft_model.parameters(), lr=1e-3)\n",
+    "\n",
+    "# 5. Train!\n",
+    "poyo_ft_r2_log, poyo_ft_loss_log, poyo_ft_train_outputs = 
finetune(poyo_ft_model, optimizer, train_loader, val_loader,\n",
+    "                                          num_epochs=40, 
epoch_to_unfreeze=10)\n",
+    "\n",
+    "# Visualize the results\n",
+    "plot_training_curves(poyo_ft_r2_log, poyo_ft_loss_log)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "RA1mnGAOfL4Y",
+   "metadata": {
+    "id": "RA1mnGAOfL4Y"
+   },
+   "source": [
+    "### 3.4 Let's compare"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ZI41uQ2ry88_",
+   "metadata": {
+    "id": "ZI41uQ2ry88_"
+   },
+   "outputs": [],
+   "source": [
+    "lcls = locals().copy()\n",
+    "for lcl in lcls:\n",
+    "    if not lcl.endswith(\"_r2_log\"):\n",
+    "        continue\n",
+    "    model = lcl.split(\"_r2_log\")[0].upper()\n",
+    "    plt.plot(locals()[lcl], label=model)\n",
+    "plt.xlabel(\"Epoch\")\n",
+    "plt.ylabel(\"Validation $R^2$\")\n",
+    "plt.grid()\n",
+    "plt.legend()\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "mHfhuXT8y88_",
+   "metadata": {
+    "id": "mHfhuXT8y88_"
+   },
+   "source": [
+    "### 3.5 Visualizing the POYO model\n",
+    "\n",
+    "POYO's unit and session embeddings are learned during training, and we 
can visualize them along with a snapshot of the\n",
+    "model's performance."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "08Cr0k2lcKJV",
+   "metadata": {
+    "id": "08Cr0k2lcKJV"
+   },
+   "source": [
+    "#### Visualization utilities\n",
+    "\n",
+    "Just run this block."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "sdWySSFCy88_",
+   "metadata": {
+    "id": "sdWySSFCy88_"
+   },
+   "outputs": [],
+   "source": [
+    "from sklearn.decomposition import PCA\n",
+    "from bokeh.plotting import figure, show\n",
+    "from bokeh.models import ColumnDataSource, Slider, CustomJS, Div, 
Spacer\n",
+    "from bokeh.layouts import column, row\n",
+    "from bokeh.io import output_notebook\n",
+    "from bokeh.palettes import Pastel1\n",
+    "\n",
+    "# visualize_training(model, train_outputs, prev_session_emb, 
prev_session_emb_labels, n_epochs, epoch_to_unfreeze):\n",
+    "def visualize_training(model, ckpt, train_outputs):\n",
+    "    n_epochs = train_outputs[\"n_epochs\"]\n",
+    "    epoch_to_unfreeze = train_outputs.get(\"epoch_to_unfreeze\", 0)\n",
+    "    # Extract info from \"finetuned\" model\n",
+    "    cur_session_emb_labels = list(model.session_emb.vocab.keys())[1:]\n",
+    "    # Extract info from \"pretrained\" model\n",
+    "    if ckpt is not None and ckpt != \"\":\n",
+    "        ckpt_data = torch.load(ckpt, weights_only=False, 
map_location=\"cpu\")\n",
+    "        prev_session_emb = 
ckpt_data['state_dict']['model.session_emb.weight'][1:].detach().clone()\n",
+    "        prev_session_emb_labels = [str(x) for x in 
ckpt_data['state_dict']['model.session_emb.vocab']][1:]\n",
+    "    else:\n",
+    "        prev_session_emb = np.zeros((0, model.dim))\n",
+    "        prev_session_emb_labels = []\n",
+    "\n",
+    "    n_units = train_outputs['unit_emb'][0].shape[0]\n",
+    "    n_sessions = train_outputs['session_emb'][0].shape[0]\n",
+    "    n_prev_sessions = prev_session_emb.shape[0]\n",
+    "\n",
+    "    pca = PCA(n_components=2)\n",
+    "    flat_unit_emb = np.concatenate(train_outputs['unit_emb'], axis=0) # 
(n_epochs*n_units, emb_dim)\n",
+    "    unit_emb_pca = pca.fit_transform(flat_unit_emb) # (n_epochs*n_units, 
2)\n",
+    "    unit_emb_x = unit_emb_pca[:, 0]\n",
+    "    unit_emb_y = unit_emb_pca[:, 1]\n",
+    "\n",
+    "    total_session_emb = np.concatenate([prev_session_emb, # 
(n_prev_sessions+n_epochs*n_sessions, emb_dim)\n",
+    "                                        *train_outputs['session_emb']], 
axis=0)\n",
+    "    session_emb_pca = pca.fit_transform(total_session_emb) # 
(n_prev_sessions+n_epochs*n_sessions, 2)\n",
+    "    session_emb_x = session_emb_pca[:, 0]\n",
+    "    session_emb_y = session_emb_pca[:, 1]\n",
+    "\n",
+    "    sampling_rate = 100\n",
+    "    sample_out_start = 2000\n",
+    "    sample_out_end = sample_out_start + 500\n",
+    "    pred_x = np.concatenate([output[sample_out_start:sample_out_end, 0] 
for output in train_outputs['output_pred']])\n",
+    "    pred_y = np.concatenate([output[sample_out_start:sample_out_end, 1] 
for output in train_outputs['output_pred']])\n",
+    "    gt_x = np.concatenate([output[sample_out_start:sample_out_end, 0] for 
output in train_outputs['output_gt']])\n",
+    "    gt_y = np.concatenate([output[sample_out_start:sample_out_end, 1] for 
output in train_outputs['output_gt']])\n",
+    "\n",
+    "    _Pastel1 = Pastel1.copy()\n",
+    "    _Pastel1 = {**_Pastel1, **{i: _Pastel1[3][:i] for i in range(1, 3)}}  
# for some reason only starts at 3\n",
+    "    all_session_emb_labels = prev_session_emb_labels + 
cur_session_emb_labels\n",
+    "    extract_sess_label_grp = lambda label: label.split('/')[0] + '/' + 
label.split('/')[1].split('_')[0]\n",
+    "    all_sess_emb_label_grps = [extract_sess_label_grp(label) for label in 
all_session_emb_labels]  # map to <brainset/xx>_...\n",
+    "    unique_sess_emb_label_grps = list(set(all_sess_emb_label_grps))\n",
+    "    unique_sess_emb_label_grp_map = {label: i for i, label in 
enumerate(unique_sess_emb_label_grps)}  # create index map for unique labels\n",
+    "    all_sess_emb_label_grp_idx = [unique_sess_emb_label_grp_map[label] 
for label in all_sess_emb_label_grps]  # map labels to index\n",
+    "\n",
+    "    # Generate pastel colors based on unique labels\n",
+    "    unique_labels = np.unique(all_sess_emb_label_grp_idx)\n",
+    "    color_map = {label: _Pastel1[len(unique_labels)][i % 
len(_Pastel1[len(unique_labels)])] for i, label in enumerate(unique_labels)}\n",
+    "\n",
+    "    # Assign colors based on group index\n",
+    "    colors = np.array([color_map[label] for label in 
all_sess_emb_label_grp_idx])\n",
+    "\n",
+    "    # Define border color: Only sessions after `n_prev_sessions` should 
have a border\n",
+    "    border_colors = np.array(['black' if i >= n_prev_sessions else None 
for i in range(n_prev_sessions+n_sessions)])\n",
+    "\n",
+    "    # Enable inline visualization in Jupyter Notebook\n",
+    "    output_notebook()\n",
+    "\n",
+    "    # Data sources\n",
+    "    unit_source = ColumnDataSource(data={'x': unit_emb_x[:n_units], 'y': 
unit_emb_y[:n_units]})\n",
+    "    sess_source = ColumnDataSource(data={\n",
+    "        'x': np.concatenate([session_emb_x[:n_prev_sessions],\n",
+    "                            
session_emb_x[n_prev_sessions:n_prev_sessions+n_sessions]]),\n",
+    "        'y': np.concatenate([session_emb_y[:n_prev_sessions],\n",
+    "                            
session_emb_y[n_prev_sessions:n_prev_sessions+n_sessions]]),\n",
+    "        'color': colors,\n",
+    "        'border_color': border_colors,\n",
+    "    })\n",
+    "\n",
+    "    time = np.linspace(0, 
(sample_out_end-sample_out_start)/sampling_rate, 
sample_out_end-sample_out_start)\n",
+    "    pred_source_x = ColumnDataSource(data={'x': time, 'y': 
pred_x[:sample_out_end-sample_out_start]})\n",
+    "    pred_source_y = ColumnDataSource(data={'x': time, 'y': 
pred_y[:sample_out_end-sample_out_start]})\n",
+    "    gt_source_x = ColumnDataSource(data={'x': time, 'y': 
gt_x[:sample_out_end-sample_out_start]})\n",
+    "    gt_source_y = ColumnDataSource(data={'x': time, 'y': 
gt_y[:sample_out_end-sample_out_start]})\n",
+    "\n",
+    "    # Define plot ranges with buffer\n",
+    "    buffer = 0.1\n",
+    "    unit_x_min, unit_x_max = np.min(unit_emb_x), np.max(unit_emb_x)\n",
+    "    unit_y_min, unit_y_max = np.min(unit_emb_y), np.max(unit_emb_y)\n",
+    "    sess_x_min, sess_x_max = np.min(session_emb_x), 
np.max(session_emb_x)\n",
+    "    sess_y_min, sess_y_max = np.min(session_emb_y), 
np.max(session_emb_y)\n",
+    "    unit_buffer_x, unit_buffer_y = buffer * abs(unit_x_min - unit_x_max), 
buffer * abs(unit_y_min - unit_y_max)\n",
+    "    unit_buffer_x, unit_buffer_y = 0.01 if unit_buffer_x == 0 else 
unit_buffer_x, 0.01 if unit_buffer_y == 0 else unit_buffer_y\n",
+    "    sess_buffer_x, sess_buffer_y = buffer * abs(sess_x_min - sess_x_max), 
buffer * abs(sess_y_min - sess_y_max)\n",
+    "    sess_buffer_x, sess_buffer_y = 0.01 if sess_buffer_x == 0 else 
sess_buffer_x, 0.01 if sess_buffer_y == 0 else sess_buffer_y\n",
+    "    unit_x_range = (unit_x_min - unit_buffer_x, unit_x_max + 
unit_buffer_x)\n",
+    "    unit_y_range = (unit_y_min - unit_buffer_y, unit_y_max + 
unit_buffer_y)\n",
+    "    sess_x_range = (sess_x_min - sess_buffer_x, sess_x_max + 
sess_buffer_x)\n",
+    "    sess_y_range = (sess_y_min - sess_buffer_y, sess_y_max + 
sess_buffer_y)\n",
+    "    unit_plot = figure(title=\"Unit Embeddings\", x_axis_label=\"PC1\", 
y_axis_label=\"PC2\", width=300, height=300, 
background_fill_color=\"white\",\n",
+    "                    x_range=unit_x_range, y_range=unit_y_range)\n",
+    "\n",
+    "    sess_plot = figure(title=\"Session Embeddings\", 
x_axis_label=\"PC1\", y_axis_label=\"PC2\", width=300, height=300, 
background_fill_color=\"white\",\n",
+    "                    x_range=sess_x_range, y_range=sess_y_range)\n",
+    "\n",
+    "    # Unit embedding plot\n",
+    "    unit_plot.scatter('x', 'y', source=unit_source, size=10, alpha=0.7, 
color='lightblue', line_color='black')\n",
+    "\n",
+    "    # Session embedding plot\n",
+    "    sess_plot.scatter('x', 'y', source=sess_source, size=10, alpha=0.7, 
color='color', line_color='border_color')\n",
+    "\n",
+    "    # Hand velocity plots\n",
+    "    vx_min, vx_max = np.min(np.concatenate([pred_x, gt_x])), 
np.max(np.concatenate([pred_x, gt_x]))\n",
+    "    vx_range = (vx_min - buffer * abs(vx_min), vx_max + buffer * 
abs(vx_max))\n",
+    "    vx_plot = figure(title=\"Hand Velocity - Vx\", x_axis_label=\"Time 
(s)\", y_axis_label=\"Velocity\", width=300, height=150, 
background_fill_color=\"white\"\n",
+    "                        , y_range=vx_range)\n",
+    "    vx_plot.line('x', 'y', source=pred_source_x, color='red')\n",
+    "    vx_plot.line('x', 'y', source=gt_source_x, color='black')\n",
+    "\n",
+    "    vy_min, vy_max = np.min(np.concatenate([pred_y, gt_y])), 
np.max(np.concatenate([pred_y, gt_y]))\n",
+    "    vy_range = (vy_min - buffer * abs(vy_min), vy_max + buffer * 
abs(vy_max))\n",
+    "    vy_plot = figure(title=\"Hand Velocity - Vy\", x_axis_label=\"Time 
(s)\", y_axis_label=\"Velocity\", width=300, height=150, 
background_fill_color=\"white\"\n",
+    "                        , y_range=vy_range)\n",
+    "    vy_plot.line('x', 'y', source=pred_source_y, color='blue')\n",
+    "    vy_plot.line('x', 'y', source=gt_source_y, color='black')\n",
+    "\n",
+    "    # Model diagram\n",
+    "    if epoch_to_unfreeze == 0:\n",
+    "        model_diagram = Div(text='<img 
src=\"https://media-hosting.imagekit.io//33c2439b9dc549dc/unit-id-unfrozen.svg?Expires=1837453466&Key-Pair-Id=K2ZIVPTIP2VGHC&Signature=i4i4qMNDJhiTEfC-Uo-DhSaxCUVsrN9W7m7mL4RDWPaRHrhsCXuIXsjPzTksoIQobZr9qMGVRNwTpDD-jGIT1-Y2K42H5uXY37WXNfBhHzcb-GsyYAjx9ztVuYi9OAeaiZdtXe-Yc-xQF88RWsBdSBw0KA26Ewwj2CcuBoexlSL3rNttoWMHVzyisTDFyX2N1uYmpnRbKFarnU0Xvn9OXx2y64fgZyT8oGgbYShjHZApDqEujfRubXwrvH86etyPvuzvbzqMxc27u-BgMQv0l--qsq2tP3y66AhQu~EuwbU9E4Dxud94bekO7
 [...]
+    "    else:\n",
+    "        model_diagram = Div(text='<img 
src=\"https://media-hosting.imagekit.io//7bc43a35ec174ef4/unit-id-frozen.svg?Expires=1837453466&Key-Pair-Id=K2ZIVPTIP2VGHC&Signature=CylM3BrCyxTkiRd2DaTpqU~Q2gwDYBH8LjmL0Duv5ybgWMegyrYWfDVmu4weQx~bM7Dbd2BJ3Q-83WhSVRNRTBEzq3Dkr-RjuPdhbBAKhL-0Ku7aC4eo~EetHwa5PeuYGsQvs0jFaQF-XRv3Ow04kxww8m9gAgLTuDyL6-2FYf~68EXGrD0A6GqYeozU65~nwptnzF~YVu2gYsp5ERTOdbyE5TjNIKp21QBwhQ~BkbZ5NEfvITkdvbTv9j1k2p5hmxW~jm15Llz-oxj2l-l2fM~3UZ8JC6NHze~lTqmANrrxf0GIEYBAAwvDWj6
 [...]
+    "\n",
+    "    # Widgets\n",
+    "    slider = Slider(start=0, end=n_epochs-1, value=0, step=1, 
title=\"Epoch\")\n",
+    "\n",
+    "    callback = CustomJS(args=dict(source1=unit_source, 
source2=sess_source, source3=pred_source_x, source4=pred_source_y,\n",
+    "                                source5=gt_source_x, 
source6=gt_source_y,\n",
+    "                                unit_x=unit_emb_x, unit_y=unit_emb_y, 
sess_x=session_emb_x, sess_y=session_emb_y, #cat=session_cat,\n",
+    "                                n_units=n_units, n_sessions=n_sessions, 
n_prev_sessions=n_prev_sessions,\n",
+    "                                diagram=model_diagram, 
unfreeze_epoch=epoch_to_unfreeze,\n",
+    "                                pred_x=pred_x, pred_y=pred_y, gt_x=gt_x, 
gt_y=gt_y,\n",
+    "                                
samples=sample_out_end-sample_out_start,\n",
+    "                                ),#pred_x=pred_x, pred_y=pred_y),\n",
+    "                        code=\"\"\"\n",
+    "        var step = cb_obj.value;\n",
+    "        source1.data.x = unit_x.slice(step*n_units, (step+1)*n_units);\n",
+    "        source1.data.y = unit_y.slice(step*n_units, (step+1)*n_units);\n",
+    "\n",
+    "        let prev_sess_x = sess_x.slice(0, n_prev_sessions);\n",
+    "        let cur_sess_x = sess_x.slice(n_prev_sessions+n_sessions*step, 
n_prev_sessions+n_sessions*(step+1));\n",
+    "        source2.data.x = 
Array.from(prev_sess_x).concat(Array.from(cur_sess_x));\n",
+    "        let prev_sess_y = sess_y.slice(0, n_prev_sessions);\n",
+    "        let cur_sess_y = sess_y.slice(n_prev_sessions+n_sessions*step, 
n_prev_sessions+n_sessions*(step+1));\n",
+    "        source2.data.y = 
Array.from(prev_sess_y).concat(Array.from(cur_sess_y));\n",
+    "\n",
+    "        source3.data.y = pred_x.slice(step*samples, (step+1)*samples);\n",
+    "        source4.data.y = pred_y.slice(step*samples, (step+1)*samples);\n",
+    "        source5.data.y = gt_x.slice(step*samples, (step+1)*samples);\n",
+    "        source6.data.y = gt_y.slice(step*samples, (step+1)*samples);\n",
+    "\n",
+    "        source1.change.emit();\n",
+    "        source2.change.emit();\n",
+    "        source3.change.emit();\n",
+    "        source4.change.emit();\n",
+    "        source5.change.emit();\n",
+    "        source6.change.emit();\n",
+    "\n",
+    "        if (unfreeze_epoch >= 0 && step >= unfreeze_epoch) {\n",
+    "            diagram.text = '<img 
src=\"https://media-hosting.imagekit.io//33c2439b9dc549dc/unit-id-unfrozen.svg?Expires=1837453466&Key-Pair-Id=K2ZIVPTIP2VGHC&Signature=i4i4qMNDJhiTEfC-Uo-DhSaxCUVsrN9W7m7mL4RDWPaRHrhsCXuIXsjPzTksoIQobZr9qMGVRNwTpDD-jGIT1-Y2K42H5uXY37WXNfBhHzcb-GsyYAjx9ztVuYi9OAeaiZdtXe-Yc-xQF88RWsBdSBw0KA26Ewwj2CcuBoexlSL3rNttoWMHVzyisTDFyX2N1uYmpnRbKFarnU0Xvn9OXx2y64fgZyT8oGgbYShjHZApDqEujfRubXwrvH86etyPvuzvbzqMxc27u-BgMQv0l--qsq2tP3y66AhQu~EuwbU9E4Dxud94bekO7ZE9a5T
 [...]
+    "        } else {\n",
+    "            diagram.text = '<img 
src=\"https://media-hosting.imagekit.io//7bc43a35ec174ef4/unit-id-frozen.svg?Expires=1837453466&Key-Pair-Id=K2ZIVPTIP2VGHC&Signature=CylM3BrCyxTkiRd2DaTpqU~Q2gwDYBH8LjmL0Duv5ybgWMegyrYWfDVmu4weQx~bM7Dbd2BJ3Q-83WhSVRNRTBEzq3Dkr-RjuPdhbBAKhL-0Ku7aC4eo~EetHwa5PeuYGsQvs0jFaQF-XRv3Ow04kxww8m9gAgLTuDyL6-2FYf~68EXGrD0A6GqYeozU65~nwptnzF~YVu2gYsp5ERTOdbyE5TjNIKp21QBwhQ~BkbZ5NEfvITkdvbTv9j1k2p5hmxW~jm15Llz-oxj2l-l2fM~3UZ8JC6NHze~lTqmANrrxf0GIEYBAAwvDWj6QcSmZX
 [...]
+    "        }\n",
+    "    \"\"\")\n",
+    "\n",
+    "    slider.js_on_change('value', callback)\n",
+    "\n",
+    "    # Improved Layout Structure\n",
+    "    layout = column(\n",
+    "        slider,\n",
+    "        row(\n",
+    "            column(unit_plot, sess_plot),\n",
+    "            model_diagram,\n",
+    "            column(Spacer(height=80), vx_plot, vy_plot)\n",
+    "        ),\n",
+    "        styles={\n",
+    "            \"background-color\": \"white\",\n",
+    "            \"padding\": \"20px\",\n",
+    "        }\n",
+    "    )\n",
+    "\n",
+    "    show(layout, notebook_handle=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3Hx_zKDKcTjn",
+   "metadata": {
+    "id": "3Hx_zKDKcTjn"
+   },
+   "source": [
+    "#### Visualize"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "C9ayfgrSfx_E",
+   "metadata": {
+    "id": "C9ayfgrSfx_E"
+   },
+   "source": [
+    "We can visualize the training of POYO from scratch:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "rT_Fnw-sy88_",
+   "metadata": {
+    "id": "rT_Fnw-sy88_"
+   },
+   "outputs": [],
+   "source": [
+    "visualize_training(poyo_model, ckpt=None, 
train_outputs=poyo_train_outputs)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "tgCMJHHOfvjU",
+   "metadata": {
+    "id": "tgCMJHHOfvjU"
+   },
+   "source": [
+    "Or from POYO fine-tuning:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "FnDXtBOnfzJ2",
+   "metadata": {
+    "id": "FnDXtBOnfzJ2"
+   },
+   "outputs": [],
+   "source": [
+    "visualize_training(poyo_ft_model, ckpt=\"poyo_mp.ckpt\", 
train_outputs=poyo_ft_train_outputs)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "jSPsG3Jsy88_",
+   "metadata": {
+    "id": "jSPsG3Jsy88_"
+   },
+   "source": [
+    "***\n",
+    "\n",
+    "## Part 4. Your Turn!\n",
+    "\n",
+    "***\n",
+    "Now that you've seen how to train and fine-tune models on neural data, 
try experimenting on your own! Here are three exercises from us with varying 
levels of difficulty."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "su7e6QfmeZ91",
+   "metadata": {
+    "id": "su7e6QfmeZ91"
+   },
+   "source": [
+    "### 4.1 Exercise (a): Easy\n",
+    "\n",
+    "In the transformer implementation in Part 2.2, change the positional 
embeddings from being fixed (i.e. non learnable) to being learnable, and train 
the transformer from scratch. Note that since we have limited training data, we 
won't observe an improvement in performance by making this switch.\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "tV_IjhBFe-j7",
+   "metadata": {
+    "id": "tV_IjhBFe-j7"
+   },
+   "source": [
+    "#### Solution"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "N7HKd5qngBwV",
+   "metadata": {
+    "id": "N7HKd5qngBwV"
+   },
+   "outputs": [],
+   "source": [
+    "from torch_brain.nn import FeedForward\n",
+    "\n",
+    "class TransformerNeuralDecoder(nn.Module):\n",
+    "    def __init__(\n",
+    "        self, num_units, bin_size, sequence_length,   # data 
properties\n",
+    "        dim_output, dim_hidden, n_layers, n_heads,    # transformer 
properties\n",
+    "    ):\n",
+    "        \"\"\"Initialize the neural net components\"\"\"\n",
+    "        super().__init__()\n",
+    "\n",
+    "        self.num_timesteps = int(sequence_length / bin_size)\n",
+    "        self.bin_size = bin_size\n",
+    "\n",
+    "        self.readin = nn.Linear(num_units, dim_hidden)\n",
+    "        self.readout = nn.Linear(dim_hidden, dim_output)\n",
+    "\n",
+    "        self.position_embeddings = nn.Parameter(\n",
+    "            data=generate_sinusoidal_position_embs(self.num_timesteps, 
dim_hidden),\n",
+    "            requires_grad=True,     ##### <<<<<< SOLUTION: CHANGED TO 
TRUE\n",
+    "        )\n",
+    "\n",
+    "        self.transformer_layers = nn.ModuleList([\n",
+    "            nn.ModuleList([\n",
+    "                nn.MultiheadAttention(\n",
+    "                    embed_dim=dim_hidden,\n",
+    "                    num_heads=n_heads,\n",
+    "                    batch_first=True,\n",
+    "                ),\n",
+    "                FeedForward(dim=dim_hidden),\n",
+    "            ])\n",
+    "            for _ in range(n_layers)\n",
+    "        ])\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        # Remains unchanged\n",
+    "\n",
+    "        x = self.readin(x)  # (B, T, N) -> (B, T, D)\n",
+    "        x = x + self.position_embeddings[None, ...]  # -> (B, T, D)\n",
+    "        for attn, ffn in self.transformer_layers:\n",
+    "            x = x + attn(x, x, x, need_weights=False)[0]\n",
+    "            x = x + ffn(x)\n",
+    "        x = self.readout(x)  # (B, T, D) -> (B, T, 2)\n",
+    "\n",
+    "        return x\n",
+    "\n",
+    "    def tokenize(self, data):\n",
+    "        # Remains unchanged\n",
+    "\n",
+    "        x = bin_spikes(\n",
+    "            spikes=data.spikes,\n",
+    "            num_units=len(data.units),\n",
+    "            bin_size=self.bin_size,\n",
+    "            num_bins=self.num_timesteps,\n",
+    "        ).T\n",
+    "        y = data.cursor.vel\n",
+    "        data_dict = {\n",
+    "            \"model_inputs\": {\n",
+    "                \"x\": torch.tensor(x, dtype=torch.float32),\n",
+    "            },\n",
+    "            \"target_values\": torch.tensor(y, dtype=torch.float32),\n",
+    "        }\n",
+    "        return data_dict"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "-M92mJFHiLfp",
+   "metadata": {
+    "id": "-M92mJFHiLfp"
+   },
+   "outputs": [],
+   "source": [
+    "seed_everything(0)\n",
+    "\n",
+    "# 1. Setup datasets and dataloader\n",
+    "recording_id = 
\"perich_miller_population_2018/t_20130819_center_out_reaching\"\n",
+    "train_dataset, train_loader, val_dataset, val_loader = 
get_train_val_loaders(recording_id, batch_size=64)\n",
+    "num_units = len(train_dataset.get_unit_ids())\n",
+    "\n",
+    "# 2. Initialize Model\n",
+    "tf_model = TransformerNeuralDecoder(\n",
+    "    num_units=num_units,\n",
+    "    bin_size=10e-3,\n",
+    "    sequence_length=1.0,\n",
+    "    dim_output=2,\n",
+    "    dim_hidden=128,\n",
+    "    n_layers=3,\n",
+    "    n_heads=4,\n",
+    ").to(device)\n",
+    "\n",
+    "# 3. Connect Tokenizer to Datasets\n",
+    "train_dataset.transform = tf_model.tokenize\n",
+    "val_dataset.transform = tf_model.tokenize\n",
+    "\n",
+    "# 4. Setup Optimizer\n",
+    "optimizer = torch.optim.AdamW(tf_model.parameters(), lr=1e-3)\n",
+    "\n",
+    "# 5. Train!\n",
+    "transformer_r2_log, transformer_loss_log, transformer_train_outputs = 
train(tf_model, optimizer, train_loader, val_loader, num_epochs=100)\n",
+    "\n",
+    "# Plot the training loss and validation R2\n",
+    "plot_training_curves(transformer_r2_log, transformer_loss_log)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "AeJJIf4Wj06v",
+   "metadata": {
+    "id": "AeJJIf4Wj06v"
+   },
+   "source": [
+    "You might observe that this model does not perform as well as the the 
version in Part 2.2 where the position embeddings were fixed. This can be 
attributed to overfitting, now that we can train a lot more parameters (the 
position embeddings!)."
+   ]
+  },
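+  {
+   "cell_type": "markdown",
+   "id": "param-count-md",
+   "metadata": {},
+   "source": [
+    "One quick way to see the added capacity is to count trainable parameters with standard PyTorch; with the settings above (100 timesteps, `dim_hidden=128`), the position-embedding table alone contributes 12,800 of them:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "param-count-code",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# The position embeddings now count toward the trainable total\n",
+    "n_trainable = sum(p.numel() for p in tf_model.parameters() if p.requires_grad)\n",
+    "n_pos = tf_model.position_embeddings.numel()\n",
+    "print(f\"Trainable parameters: {n_trainable:,} ({n_pos:,} from position embeddings)\")"
+   ]
+  },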
+  {
+   "cell_type": "markdown",
+   "id": "gqWV0Xbry6tm",
+   "metadata": {
+    "id": "gqWV0Xbry6tm"
+   },
+   "source": [
+    "### 4.2 Exercise (b): Medium (Multi-Session Training!)\n",
+    "\n",
+    "Throughout this notebook, we have been training on a single session 
(`t_20130819_center_out_reaching`) which includes recordings from a single 
monkey during a reaching task. What if we want to train on multiple sessions? 
We noted in Part 2.3 that the configuration used to setup data for training 
POYO allows for loading multiple sessions, as well as some more complex 
configurations.\n",
+    "\n",
+    "In this exercise, we'll train on two separate sessions from the same 
monkey during the reaching task. The new session is called 
`t_20130821_center_out_reaching` and includes recordings from a later date than 
the first one. Note that we might be interested in training across different 
animals or tasks as well, but we are limited by the model capacity.\n",
+    "\n",
+    "First, download the new session below:\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "-i-u6bA2p5vm",
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/";
+    },
+    "executionInfo": {
+     "elapsed": 7065,
+     "status": "ok",
+     "timestamp": 1743098384978,
+     "user": {
+      "displayName": "Shivashriganesh Mahato",
+      "userId": "06240283312510304541"
+     },
+     "user_tz": 240
+    },
+    "id": "-i-u6bA2p5vm",
+    "outputId": "eb7f7253-7542-4dc3-90bf-4d3c9d8457d5"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Downloading...\n",
+      "From: 
https://drive.google.com/uc?id=1EZUkOE8oiieWja9lblokf8WjF05HFuX-\n";,
+      "To: 
/content/data/perich_miller_population_2018/t_20130821_center_out_reaching.h5\n",
+      "100% 10.4M/10.4M [00:00<00:00, 51.1MB/s]\n"
+     ]
+    }
+   ],
+   "source": [
+    "!gdown 1EZUkOE8oiieWja9lblokf8WjF05HFuX- -O 
data/perich_miller_population_2018/t_20130821_center_out_reaching.h5"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "KpEZXyt_0sr0",
+   "metadata": {
+    "id": "KpEZXyt_0sr0"
+   },
+   "source": [
+    "Now, setup train and validation DataLoaders such that training will be 
across both sessions, but validation will stay only on the first session (so we 
can directly compare). Then, train POYO on the two sessions for 50 epochs (half 
of that from original training).\n",
+    "\n",
+    "**Tips**:\n",
+    "\n",
+    "* `get_dataset_config` accepts both a single session as a string (as 
we've seen so far), or a list of sessions:\n",
+    "\n",
+    "```get_dataset_config(brainset, [session1, session2, ...])```\n",
+    "*   You cannot use `get_train_val_loaders` directly since it uses the 
same config for both train and val"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "PbZisWiX3wO2",
+   "metadata": {
+    "id": "PbZisWiX3wO2"
+   },
+   "source": [
+    "#### Solution"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "oFIaWjgTnBmf",
+   "metadata": {
+    "id": "oFIaWjgTnBmf"
+   },
+   "outputs": [],
+   "source": [
+    "# -- Train --\n",
+    "train_dataset = Dataset(\n",
+    "    root=\"data\",\n",
+    "    config=get_dataset_config(\"perich_miller_population_2018\", [\n",
+    "        \"t_20130819_center_out_reaching\",\n",
+    "        \"t_20130821_center_out_reaching\",  # <<<<<< SOLUTION: ADDED\n",
+    "    ]),\n",
+    "    split=\"train\",\n",
+    ")\n",
+    "train_sampling_intervals = train_dataset.get_sampling_intervals()\n",
+    "train_sampler = RandomFixedWindowSampler(\n",
+    "    sampling_intervals=train_sampling_intervals,\n",
+    "    window_length=1.0,\n",
+    ")\n",
+    "train_loader = DataLoader(\n",
+    "    dataset=train_dataset,\n",
+    "    sampler=train_sampler,\n",
+    "    batch_size=64,\n",
+    "    collate_fn=collate,\n",
+    "    num_workers=4,\n",
+    "    pin_memory=True,\n",
+    ")\n",
+    "\n",
+    "# -- Validation --\n",
+    "val_dataset = Dataset(\n",
+    "    root=\"data\",\n",
+    "    config=get_dataset_config(\"perich_miller_population_2018\", 
\"t_20130819_center_out_reaching\"),\n",
+    "    split=\"valid\",\n",
+    ")\n",
+    "val_sampling_intervals = val_dataset.get_sampling_intervals()\n",
+    "val_sampler = SequentialFixedWindowSampler(\n",
+    "    sampling_intervals=val_sampling_intervals,\n",
+    "    window_length=1.0,\n",
+    ")\n",
+    "val_loader = DataLoader(\n",
+    "    dataset=val_dataset,\n",
+    "    sampler=val_sampler,\n",
+    "    batch_size=64,\n",
+    "    collate_fn=collate,\n",
+    "    num_workers=4,\n",
+    "    pin_memory=True,\n",
+    ")\n",
+    "\n",
+    "train_dataset.disable_data_leakage_check()\n",
+    "val_dataset.disable_data_leakage_check()"
+   ]
+  },
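+  {
+   "cell_type": "markdown",
+   "id": "multi-session-check-md",
+   "metadata": {},
+   "source": [
+    "As a quick sanity check, we can confirm that the training set now spans both sessions while validation stays on the first, using the same `get_session_ids` accessor we use when initializing POYO's vocabularies:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "multi-session-check-code",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Expect two train sessions and one validation session\n",
+    "print(\"Train sessions:\", train_dataset.get_session_ids())\n",
+    "print(\"Val sessions:  \", val_dataset.get_session_ids())"
+   ]
+  },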
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ZpFEmw5TrHxW",
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/";,
+     "height": 443
+    },
+    "executionInfo": {
+     "elapsed": 764067,
+     "status": "ok",
+     "timestamp": 1743099192177,
+     "user": {
+      "displayName": "Shivashriganesh Mahato",
+      "userId": "06240283312510304541"
+     },
+     "user_tz": 240
+    },
+    "id": "ZpFEmw5TrHxW",
+    "outputId": "c6938d53-f226-435e-af50-e72bbede892d"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 100/100 | Val R2 = 0.770 | Loss = 0.004\n",
+      "Done! Final validation R2 = 0.781\n"
+     ]
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAABKUAAAGGCAYAAACqvTJ0AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAyW9JREFUeJzs3XlclNX+B/DPzDAM+yayiiJgIipioKhlehUBtVKzUtNUMr1p3Eoq067iWpiZkWVRlrmkuV2zX2kIYlQmivu+4b6xqTCswzDz/P5AxkZAQWdhhs/79eJ1eZ7nPOc558htHr6c8z0iQRAEEBERERERERERGZDY2A0gIiIiIiIiIqKmh0EpIiIiIiIiIiIyOAaliIiIiIiIiIjI4BiUIiIiIiIiIiIig2NQioiIiIiIiIiIDI5BKSIiIiIiIiIiMjgGpYiIiIiIiIiIyOAYlCIiIiIiIiIiIoNjUIqI
 [...]
+      "text/plain": [
+       "<Figure size 1200x400 with 2 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "# Training is same as we did before\n",
+    "\n",
+    "seed_everything(0)\n",
+    "\n",
+    "from torch_brain.models import POYO\n",
+    "from torch_brain.registry import MODALITY_REGISTRY\n",
+    "poyo_model = POYO(\n",
+    "    sequence_length=1.0,\n",
+    "    readout_spec=MODALITY_REGISTRY['cursor_velocity_2d'],\n",
+    "    latent_step=1.0 / 8,\n",
+    "    num_latents_per_step=16,\n",
+    "    dim=64,\n",
+    "    depth=6,\n",
+    "    dim_head=64,\n",
+    "    cross_heads=2,\n",
+    "    self_heads=8,\n",
+    ").to(device)\n",
+    "\n",
+    "poyo_model.unit_emb.initialize_vocab(train_dataset.get_unit_ids())\n",
+    
"poyo_model.session_emb.initialize_vocab(train_dataset.get_session_ids())\n",
+    "\n",
+    "train_dataset.transform = poyo_model.tokenize\n",
+    "val_dataset.transform = poyo_model.tokenize\n",
+    "\n",
+    "optimizer = torch.optim.AdamW(poyo_model.parameters(), lr=1e-3)\n",
+    "\n",
+    "poyo_r2_log, poyo_loss_log, poyo_train_outputs = train(\n",
+    "  poyo_model, optimizer, train_loader, val_loader,\n",
+    "  num_epochs=100, store_embs=True,\n",
+    ")\n",
+    "\n",
+    "plot_training_curves(poyo_r2_log, poyo_loss_log)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "h0zOuuou5Ft1",
+   "metadata": {
+    "id": "h0zOuuou5Ft1"
+   },
+   "source": [
+    "Once the training is done you can see that just by training on another 
session the validation performance on the original session improves slighly. 
You can imagine it's not too much more work to then train at the scale of POYO 
and POYO+ with this setup. Just keep scaling up the dataset size (and model 
size as well) and performance will keep on improving."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "mXbkMNKhegyu",
+   "metadata": {
+    "id": "mXbkMNKhegyu"
+   },
+   "source": [
+    "### 4.3 Exercise (c): Medium\n",
+    "\n",
+    "In our MLP model in Part 2.1, we had intentionally set `bin_size` equal 
to the sampling period of the behavior output (10ms) to simplify the code. Now 
that you are more accustomed to the framework, make `bin_size` independent of 
the `output_sampling_period`. The signature of the MLP model should be:\n",
+    "\n",
+    "```python\n",
+    "class MLPNeuralDecoder(nn.Module):\n",
+    "    def __init__(\n",
+    "      self, num_units, bin_size, sequence_length,\n",
+    "      output_sampling_period,   # NEW ARGUMENT!\n",
+    "      output_dim, hidden_dim,\n",
+    "    ):\n",
+    "      ...\n",
+    "```\n",
+    "\n",
+    "Once you're done modifying the model, train it with `bin_size=20e-3`, and 
`output_sampling_period=10e-3`."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "raNpOx91gBwU",
+   "metadata": {
+    "id": "raNpOx91gBwU"
+   },
+   "source": [
+    "#### Solution"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "MEEvUEu-iunF",
+   "metadata": {
+    "id": "MEEvUEu-iunF"
+   },
+   "outputs": [],
+   "source": [
+    "import torch.nn as nn\n",
+    "\n",
+    "class MLPNeuralDecoder(nn.Module):\n",
+    "    def __init__(self, num_units, bin_size, sequence_length, 
output_sampling_period, output_dim, hidden_dim):\n",
+    "        \"\"\"Initialize the neural net layers.\"\"\"\n",
+    "        super().__init__()\n",
+    "\n",
+    "        self.num_input_timesteps = int(sequence_length / bin_size)  # 
CHANGE\n",
+    "        self.num_output_timesteps = int(sequence_length / 
output_sampling_period)  # CHANGE\n",
+    "        self.bin_size = bin_size\n",
+    "\n",
+    "        self.net = nn.Sequential(\n",
+    "            nn.Linear(self.num_input_timesteps * num_units, hidden_dim),  
# CHANGE\n",
+    "            nn.ReLU(),\n",
+    "            nn.Linear(hidden_dim, hidden_dim),\n",
+    "            nn.ReLU(),\n",
+    "            nn.Linear(hidden_dim, output_dim * 
self.num_output_timesteps),  # CHANGE\n",
+    "        )\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        x = x.flatten(1)\n",
+    "        x = self.net(x)\n",
+    "        x = x.reshape(-1, self.num_output_timesteps, 2)  # CHANGE\n",
+    "        return x\n",
+    "\n",
+    "    def tokenize(self, data):\n",
+    "        spikes = data.spikes\n",
+    "        x = bin_spikes(\n",
+    "            spikes=spikes,\n",
+    "            num_units=len(data.units),\n",
+    "            bin_size=self.bin_size,\n",
+    "            num_bins=self.num_input_timesteps  # CHANGE\n",
+    "        ).T\n",
+    "        y = data.cursor.vel\n",
+    "        data_dict = {\n",
+    "            \"model_inputs\": {\n",
+    "                \"x\": torch.tensor(x, dtype=torch.float32),\n",
+    "            },\n",
+    "            \"target_values\": torch.tensor(y, dtype=torch.float32),\n",
+    "        }\n",
+    "        return data_dict"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "x1LuW5W7kF7w",
+   "metadata": {
+    "id": "x1LuW5W7kF7w"
+   },
+   "outputs": [],
+   "source": [
+    "seed_everything(0)\n",
+    "\n",
+    "# 1. Setup datasets and dataloader\n",
+    "recording_id = 
\"perich_miller_population_2018/t_20130819_center_out_reaching\"\n",
+    "train_dataset, train_loader, val_dataset, val_loader = 
get_train_val_loaders(recording_id, batch_size=64)\n",
+    "num_units = len(train_dataset.get_unit_ids())\n",
+    "print(f\"Num Units in Session: {num_units}\")\n",
+    "\n",
+    "# 2. Initialize Model with the new MLP definition\n",
+    "mlp_model = MLPNeuralDecoder(\n",
+    "    num_units=num_units,\n",
+    "    bin_size=20e-3,                   # CHANGE\n",
+    "    sequence_length=1.0,\n",
+    "    output_sampling_period=10e-3,     # CHANGE\n",
+    "    output_dim=2,\n",
+    "    hidden_dim=32,\n",
+    ")\n",
+    "mlp_model = mlp_model.to(device)\n",
+    "\n",
+    "# 3. Connect Tokenizer to Datasets\n",
+    "transform = mlp_model.tokenize\n",
+    "train_dataset.transform = transform\n",
+    "val_dataset.transform = transform\n",
+    "\n",
+    "# 4. Setup Optimizer\n",
+    "optimizer = torch.optim.AdamW(mlp_model.parameters(), lr=1e-3)\n",
+    "\n",
+    "# 5. Train!\n",
+    "mlp_r2_log, mlp_loss_log, mlp_train_outputs = train(mlp_model, optimizer, 
train_loader, val_loader, num_epochs=100)\n",
+    "\n",
+    "# Plot the training loss and validation R2\n",
+    "plot_training_curves(mlp_r2_log, mlp_loss_log)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ixe6TofbehAk",
+   "metadata": {
+    "id": "ixe6TofbehAk"
+   },
+   "source": [
+    "### 4.4 Exercise (d): Challenging\n",
+    "\n",
+    "So far in all our examples, we've been using the full set of recorded 
units for decoding. It turns out that the session includes recordings of 
spiking activity from neurons in both the monkey Premotor Cortex (PMd) and 
Primary Motor Cortex (M1). What if we restrict our model to only decode from 
neurons in the PMd?\n",
+    "\n",
+    "For this exercise, note that the brain regions of the units can be found 
directly in the unit IDs (refer to training code from any of the examples to 
see how to access these IDs). Given these, change the Transformer code from 
Part 2.2 to:\n",
+    "\n",
+    "(i) First explore the unit IDs to figure out how to extract and filter on 
the brain region.\n",
+    "\n",
+    "(ii) Write a new tokenizer that, given a batch of samples from both brain 
regions, creates tokens only from the PMd neurons.\n",
+    "\n",
+    "(iii) Instantiate a Transformer model with number of units restricted 
based only to PMd neurons.\n",
+    "\n",
+    "(iv) Train the Transformer model using your new tokenizer, and compare 
results to training on all neurons.\n",
+    "\n",
+    "**Question**: We are dropping the number of units the model is trained 
on, but in the process reducing the training to data from the same brain 
region. Given these, should we expect model performance to increase or 
decrease?\n",
+    "\n",
+    "**Tip**: save `r2_log` and `loss_log` as different names than before 
(e.g., `transformer_r2_log`) so you retain both for comparison."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "JLNQmHTTgCcG",
+   "metadata": {
+    "id": "JLNQmHTTgCcG"
+   },
+   "source": [
+    "#### Solution"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "p8kAOMoPtnxQ",
+   "metadata": {
+    "id": "p8kAOMoPtnxQ"
+   },
+   "source": [
+    "##### (i) Solution"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "lYlQWPtegCcH",
+   "metadata": {
+    "id": "lYlQWPtegCcH"
+   },
+   "outputs": [],
+   "source": [
+    "# Setup data loaders\n",
+    "seed_everything(0)\n",
+    "recording_id = 
\"perich_miller_population_2018/t_20130819_center_out_reaching\"\n",
+    "train_dataset, train_loader, val_dataset, val_loader = 
get_train_val_loaders(recording_id, batch_size=64)\n",
+    "\n",
+    "# Explore the unit ids\n",
+    "unit_ids = train_dataset.get_unit_ids()\n",
+    "for unit_id in unit_ids[:5] + unit_ids[-5:]:\n",
+    "    print(unit_id)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "KFavck4ytsIk",
+   "metadata": {
+    "id": "KFavck4ytsIk"
+   },
+   "source": [
+    "We can see that the unit IDs directly contain the brain region (between 
M1 and PMd). Hence filtering is fairly straightforward:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "-xqxJyZqtqZW",
+   "metadata": {
+    "id": "-xqxJyZqtqZW"
+   },
+   "outputs": [],
+   "source": [
+    "# Filter on PMd neurons\n",
+    "pmd_units = [unit_id for unit_id in unit_ids if 'PMd' in unit_id]\n",
+    "for unit_id in pmd_units[:5] + pmd_units[-5:]:\n",
+    "    print(unit_id)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3jS_UQspt1tD",
+   "metadata": {
+    "id": "3jS_UQspt1tD"
+   },
+   "source": [
+    "##### (ii) Solution"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6wPD2K74t246",
+   "metadata": {
+    "id": "6wPD2K74t246"
+   },
+   "outputs": [],
+   "source": [
+    "# We adapt the Transformer implementation from Part 2.2\n",
+    "class TransformerNeuralDecoder(nn.Module):\n",
+    "    def __init__(\n",
+    "        self, num_units, bin_size, sequence_length,\n",
+    "        dim_output, dim_hidden, n_layers, n_heads,\n",
+    "    ):\n",
+    "        # NO CHANGE TO THIS METHOD\n",
+    "        super().__init__()\n",
+    "\n",
+    "        self.num_timesteps = int(sequence_length / bin_size)\n",
+    "        self.bin_size = bin_size\n",
+    "\n",
+    "        self.readin = nn.Linear(num_units, dim_hidden)\n",
+    "        self.readout = nn.Linear(dim_hidden, dim_output)\n",
+    "\n",
+    "        self.position_embeddings = nn.Parameter(\n",
+    "            data=generate_sinusoidal_position_embs(self.num_timesteps, 
dim_hidden),\n",
+    "            requires_grad=False,\n",
+    "        )\n",
+    "\n",
+    "        self.transformer_layers = nn.ModuleList([\n",
+    "            nn.ModuleList([\n",
+    "                nn.MultiheadAttention(\n",
+    "                    embed_dim=dim_hidden,\n",
+    "                    num_heads=n_heads,\n",
+    "                    batch_first=True,\n",
+    "                ),\n",
+    "                FeedForward(dim=dim_hidden),\n",
+    "            ])\n",
+    "            for _ in range(n_layers)\n",
+    "        ])\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        # NO CHANGE TO THIS METHOD\n",
+    "        x = self.readin(x)\n",
+    "        x = x + self.position_embeddings[None, ...]\n",
+    "        for attn, ffn in self.transformer_layers:\n",
+    "            x = x + attn(x, x, x, need_weights=False)[0]\n",
+    "            x = x + ffn(x)\n",
+    "        x = self.readout(x)\n",
+    "        return x\n",
+    "\n",
+    "    def tokenize(self, data):\n",
+    "        unit_mask = np.array(['PMd' in unit_id for unit_id in unit_ids])  
# <<<<<< SOLUTION: ADDED\n",
+    "\n",
+    "        x = bin_spikes(\n",
+    "            spikes=data.spikes,\n",
+    "            num_units=len(data.units),\n",
+    "            bin_size=self.bin_size,\n",
+    "            num_bins=self.num_timesteps,\n",
+    "        ).T\n",
+    "        x = x[:, unit_mask]   # <<<<<< SOLUTION: ADDED\n",
+    "\n",
+    "        y = data.cursor.vel\n",
+    "\n",
+    "        data_dict = {\n",
+    "            \"model_inputs\": {\n",
+    "                \"x\": torch.tensor(x, dtype=torch.float32),\n",
+    "            },\n",
+    "            \"target_values\": torch.tensor(y, dtype=torch.float32),\n",
+    "        }\n",
+    "        return data_dict"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "EkAdwarVt6xf",
+   "metadata": {
+    "id": "EkAdwarVt6xf"
+   },
+   "source": [
+    "##### (iii) Solution"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5WiHIzjft9TJ",
+   "metadata": {
+    "id": "5WiHIzjft9TJ"
+   },
+   "outputs": [],
+   "source": [
+    "# Initialize model for PMd neurons only\n",
+    "num_units = len(pmd_units)\n",
+    "pmd_tf_model = TransformerNeuralDecoder(\n",
+    "    num_units=num_units,\n",
+    "    bin_size=10e-3,\n",
+    "    sequence_length=1.0,\n",
+    "    dim_output=2,\n",
+    "    dim_hidden=128,\n",
+    "    n_layers=3,\n",
+    "    n_heads=4,\n",
+    ").to(device)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "xOfbQ9vSt-of",
+   "metadata": {
+    "id": "xOfbQ9vSt-of"
+   },
+   "source": [
+    "##### (iv) Solution"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "y8mJq4Hmt_eh",
+   "metadata": {
+    "id": "y8mJq4Hmt_eh"
+   },
+   "outputs": [],
+   "source": [
+    "# Connect new tokenizer to Datasets\n",
+    "train_dataset.transform = pmd_tf_model.tokenize\n",
+    "val_dataset.transform = pmd_tf_model.tokenize\n",
+    "\n",
+    "# Let's train!\n",
+    "optimizer = torch.optim.AdamW(pmd_tf_model.parameters(), lr=1e-3)\n",
+    "pmd_transformer_r2_log, pmd_transformer_loss_log, _ = train(pmd_tf_model, 
optimizer, train_loader, val_loader, num_epochs=100)\n",
+    "plot_training_curves(pmd_transformer_r2_log, pmd_transformer_loss_log)"
+   ]
+  },
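+  {
+   "cell_type": "markdown",
+   "id": "pmd-compare-md",
+   "metadata": {},
+   "source": [
+    "To answer the question above, overlay the two validation curves. This assumes the all-units run's log (e.g., `transformer_r2_log` from earlier) is still in memory:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "pmd-compare-code",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Compare decoding from all units vs. PMd units only\n",
+    "plt.plot(transformer_r2_log, label=\"All units\")\n",
+    "plt.plot(pmd_transformer_r2_log, label=\"PMd only\")\n",
+    "plt.xlabel(\"Epoch\")\n",
+    "plt.ylabel(\"Validation $R^2$\")\n",
+    "plt.grid()\n",
+    "plt.legend()\n",
+    "plt.show()"
+   ]
+  },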
+  {
+   "cell_type": "markdown",
+   "id": "psz1G5Q0wDSw",
+   "metadata": {
+    "id": "psz1G5Q0wDSw"
+   },
+   "source": [
+    "### A note on composing transforms\n",
+    "\n",
+    "The current solution to Exercise (c) was possible because we were able to 
edit the tokenizer of the model directly. Many times, it would be infeasible or 
inconvenient to do so. In such situations, **torch_brain**'s ability to 
**compose** transforms will come in handy.\n",
+    "\n",
+    "For instance, if we wanted to train POYO on only PMd neurons, a clean 
approach would be:\n",
+    "1. Define a transform that removes all M1 neurons in a data sample and 
only keeps the PMd neurons.\n",
+    "2. Define `dataset.transform = Compose([drop_M1_neurons_transform, 
model.tokenize])`\n",
+    "\n",
+    "We refer curious readers to the documentation of **Compose**: 
[link](https://torch-brain.readthedocs.io/en/v0.1.0/package/transforms.html#torch_brain.transforms.Compose).\n",
+    "\n",
+    "This page also showcases existing transforms that are already 
implementated in **torch_brain**, like UnitDropout, some of which may come 
handy for you.\n"
+   ]
+  },
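+  {
+   "cell_type": "markdown",
+   "id": "compose-sketch-md",
+   "metadata": {},
+   "source": [
+    "Here is a minimal sketch of step 1 and step 2 combined. The `data.units.id` / `data.spikes.unit_index` attributes and the `select_by_mask` helper are assumptions about the **temporaldata** interface, so verify them against its documentation before relying on this. Also note that POYO's unit embeddings key off unit IDs, so after dropping units you would re-initialize `unit_emb` with the reduced ID list."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "compose-sketch-code",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "from torch_brain.transforms import Compose\n",
+    "\n",
+    "def keep_pmd_units(data):\n",
+    "    # ASSUMPTION: the attribute names and select_by_mask helpers below are\n",
+    "    # illustrative; a full implementation would also remap `unit_index`\n",
+    "    # after filtering. Check the temporaldata docs for the exact API.\n",
+    "    pmd_mask = np.array([\"PMd\" in uid for uid in data.units.id])\n",
+    "    data.spikes = data.spikes.select_by_mask(pmd_mask[data.spikes.unit_index])\n",
+    "    data.units = data.units.select_by_mask(pmd_mask)\n",
+    "    return data\n",
+    "\n",
+    "# The filtering transform runs before tokenization, so the tokenizer only\n",
+    "# ever sees PMd units\n",
+    "train_dataset.transform = Compose([keep_pmd_units, poyo_model.tokenize])\n",
+    "val_dataset.transform = Compose([keep_pmd_units, poyo_model.tokenize])"
+   ]
+  },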
+  {
+   "cell_type": "markdown",
+   "id": "XSQNyUF_y89C",
+   "metadata": {
+    "id": "XSQNyUF_y89C"
+   },
+   "source": [
+    "***\n",
+    "\n",
+    "## References\n",
+    "\n",
+    "[1] [Perich, M. G., Gallego, J. A., & Miller, L. E. (2018). A neural 
population mechanism for rapid learning. Neuron, 100(4), 
964-976.](https://pubmed.ncbi.nlm.nih.gov/30344047/)\n",
+    "\n",
+    "[2] [Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., 
Gomez, A. N., ... & Polosukhin, I. (2017). Attention is all you need. Advances 
in neural information processing systems, 
30.](https://papers.nips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)\n",
+    "\n",
+    "[3] [Azabou, M., Arora, V., Ganesh, V., Mao, X., Nachimuthu, S., 
Mendelson, M., ... & Dyer, E. (2023). A unified, scalable framework for neural 
population decoding. Advances in Neural Information Processing Systems, 36, 
44937-44956.](https://proceedings.neurips.cc/paper_files/paper/2023/file/8ca113d122584f12a6727341aaf58887-Paper-Conference.pdf)\n",
+    "\n",
+    "[4] [Azabou, M., Pan, K. X., Arora, V., Knight, I.J., Dyer, E. L., 
Richards, B. A. (2025). Multi-session, multi-task neural decoding from distinct 
cell-types and brain regions. International Conference on Learning 
Representations, 13.](https://openreview.net/pdf?id=IuU0wcO0mo)\n",
+    "\n",
+    "[5] [Ye, J., Collinger, J., Wehbe, L., & Gaunt, R. (2023). Neural data 
transformer 2: multi-context pretraining for neural spiking activity. Advances 
in Neural Information Processing Systems, 36, 
80352-80374.](https://papers.neurips.cc/paper_files/paper/2023/file/fe51de4e7baf52e743b679e3bdba7905-Paper-Conference.pdf)\n",
+    "\n",
+    "[6] [Zhang, Y., Wang, Y., Jiménez-Benetó, D., Wang, Z., Azabou, M., 
Richards, B., ... & Hurwitz, C. (2024). Towards a\" universal translator\" for 
neural dynamics at single-cell, single-spike resolution. Advances in Neural 
Information Processing Systems, 37, 
80495-80521.](https://proceedings.neurips.cc/paper_files/paper/2024/file/934eb45b99eff8f16b5cb8e4d3cb5641-Paper-Conference.pdf)"
+   ]
+  }
+ ],
+ "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "collapsed_sections": [
+    "BycghpnsEKBg",
+    "PzWxTVE-bMOy",
+    "08Cr0k2lcKJV",
+    "tV_IjhBFe-j7",
+    "PbZisWiX3wO2",
+    "raNpOx91gBwU",
+    "JLNQmHTTgCcG",
+    "p8kAOMoPtnxQ",
+    "3jS_UQspt1tD",
+    "EkAdwarVt6xf",
+    "xOfbQ9vSt-of"
+   ],
+   "gpuType": "T4",
+   "provenance": [
+    {
+     "file_id": "1WDLqGcFm0cNmoggIatTo2Vp2gncc9kF5",
+     "timestamp": 1742891256057
+    }
+   ],
+   "toc_visible": true
+  },
+  "jupytext": {
+   "cell_metadata_filter": "-all",
+   "main_language": "python",
+   "notebook_metadata_filter": 
"-kernelspec,-jupytext.text_representation.jupytext_version"
+  },
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git 
a/modules/agent-framework/deployments/jupyterhub/data/cosyne/cybershuttle.yml 
b/modules/agent-framework/deployments/jupyterhub/data/cosyne/cybershuttle.yml
new file mode 100644
index 0000000000..1949e8af1d
--- /dev/null
+++ 
b/modules/agent-framework/deployments/jupyterhub/data/cosyne/cybershuttle.yml
@@ -0,0 +1,64 @@
+project:
+  name: Foundations of Transformers in Neuroscience
+  version: 1.0.0
+  description: This project demonstrates how to prepare datasets using 
temporaldata, utilize torch_brain samplers, and build neural decoding models 
(MLP, Transformer, POYO). You’ll also fine-tune a pretrained POYO model and 
visualize training, all in an interactive, hands-on notebook.
+  tags:
+    - neuroscience
+    - transformers
+  homepage: https://github.com/yasithdev/allen-v1
+
+workspace:
+  location: /workspace
+  resources:
+    min_cpu: 4
+    min_gpu: 0
+    min_mem: 4096
+    walltime: 60
+    cluster: Anvil
+    group: Cerebrum
+    queue: shared
+  model_collection:
+    - source: cybershuttle
+      identifier: lgn_stimulus
+      mount_point: /data/lgn_stimulus
+    - source: cybershuttle
+      identifier: v1_point
+      mount_point: /data/v1_point
+  data_collection:
+    - source: cybershuttle
+      identifier: lgn_stimulus_configs
+      mount_point: /data/lgn_stimulus_configs
+    - source: cybershuttle
+      identifier: v1_point_configs
+      mount_point: /data/v1_point_configs
+
+additional_dependencies:
+  conda:
+    - python=3.10
+    - pip
+    - ipywidgets
+    - numba
+    - numpy=1.23.5
+    - matplotlib
+    - openpyxl
+    - pandas
+    - pyqtgraph
+    - pyyaml
+    - requests
+    - scipy
+    - sqlalchemy
+    - tqdm
+    - nest-simulator
+    - ipytree
+    - python-jsonpath
+    - pydantic=2.7
+    - anndata
+    - parse
+  pip:
+    - allensdk
+    - bmtk
+    - pytree
+    - git+https://github.com/alleninstitute/abc_atlas_access
+    - git+https://github.com/alleninstitute/neuroanalysis
+    - git+https://github.com/alleninstitute/aisynphys
+    - git+https://github.com/lahirujayathilake/mousev1
diff --git 
a/modules/agent-framework/deployments/jupyterhub/data/gkeyll/plotE_z.ipynb 
b/modules/agent-framework/deployments/jupyterhub/data/gkeyll/plotE_z.ipynb
index ff7cc617b0..6bb46a321c 100644
--- a/modules/agent-framework/deployments/jupyterhub/data/gkeyll/plotE_z.ipynb
+++ b/modules/agent-framework/deployments/jupyterhub/data/gkeyll/plotE_z.ipynb
@@ -20,7 +20,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "%pip install --force-reinstall airavata-jupyter-magic\n",
+    "# %pip install --force-reinstall airavata-python-sdk[notebook]\n",
     "import airavata_jupyter_magic\n",
     "\n",
     "%authenticate\n",
@@ -142,7 +142,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
+   "display_name": "Python 3",
    "language": "python",
    "name": "python3"
   },
@@ -156,7 +156,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.6"
+   "version": "3.10.16"
   }
  },
  "nbformat": 4,
