This is an automated email from the ASF dual-hosted git repository. dimuthuupe pushed a commit to branch cybershuttle-staging in repository https://gitbox.apache.org/repos/asf/airavata.git
commit 512012175e20cc5ca6a2b7084ca4460278f0feb8 Author: yasith <[email protected]> AuthorDate: Fri Apr 4 09:10:24 2025 -0400 update notebooks --- .../data/cosyne/cosyne_tutorial_part_1.ipynb | 1 + .../data/cosyne/cosyne_tutorial_part_2.ipynb | 2470 ++++++++++++++++++++ .../jupyterhub/data/cosyne/cybershuttle.yml | 64 + .../jupyterhub/data/gkeyll/plotE_z.ipynb | 6 +- 4 files changed, 2538 insertions(+), 3 deletions(-) diff --git a/modules/agent-framework/deployments/jupyterhub/data/cosyne/cosyne_tutorial_part_1.ipynb b/modules/agent-framework/deployments/jupyterhub/data/cosyne/cosyne_tutorial_part_1.ipynb new file mode 100644 index 0000000000..e94e01ad2f --- /dev/null +++ b/modules/agent-framework/deployments/jupyterhub/data/cosyne/cosyne_tutorial_part_1.ipynb @@ -0,0 +1 @@ +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"gpuType":"T4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["# Foundations of Transformers in Neuroscience Tutorial\n","\n","Authors: Mehdi Azabou | Contributions: Vinam Arora, Shivashriganesh Mahato, Eva Dyer\n","\n","***\n","\n","In this notebook, we will go through an example for preparing a dataset using\n","data objects fro [...] 
\ No newline at end of file diff --git a/modules/agent-framework/deployments/jupyterhub/data/cosyne/cosyne_tutorial_part_2.ipynb b/modules/agent-framework/deployments/jupyterhub/data/cosyne/cosyne_tutorial_part_2.ipynb new file mode 100644 index 0000000000..1504196598 --- /dev/null +++ b/modules/agent-framework/deployments/jupyterhub/data/cosyne/cosyne_tutorial_part_2.ipynb @@ -0,0 +1,2470 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "3ImhQu3SEILJ", + "metadata": { + "id": "3ImhQu3SEILJ" + }, + "source": [ + "# Foundations of Transformers in Neuroscience Tutorial\n", + "\n", + "Authors: Shivashriganesh Mahato, Vinam Arora | Contributions: Mehdi Azabou, Sergey Shuvaev, Julie Young, Eva Dyer\n", + "\n", + "***\n", + "\n", + "The goal of this notebook is to show you how to work with datasets and dataloaders, build and train several neural decoding models (a simple MLP, a Transformer, and POYO), fine-tune a pretrained POYO model on a new session, and visualize training. This notebook is designed to be interactive and provide visual feedback. As you work through the cells, try to run them and observe the results. 
Detailed explanations are provided along the way.\n", + "\n", + "<center>\n", + "<img src=\"https://torch-brain.readthedocs.io/en/latest/_static/torch_brain_logo.png\" width=\"150\" height=\"150\" alt=\"torch_brain Logo\">\n", + "</center>\n", + "\n", + "\n", + "We will focus on three main topics in this notebook:\n", + "* **Part 1: DataLoaders**\n", + "* **Part 2: Training Models**\n", + "* **Part 3: Finetuning and Visualizations**\n", + "\n", + "\n", + "\\\\\n", + "General references:\n", + "\n", + "- [**torch_brain** documentation](<https://torch-brain.readthedocs.io/en/latest/index.html>)\n", + "- [pytorch tutorials](<https://pytorch.org/tutorials/beginner/basics/intro.html>)" + ] + }, + { + "cell_type": "markdown", + "id": "BycghpnsEKBg", + "metadata": { + "id": "BycghpnsEKBg" + }, + "source": [ + "***\n", + "## Setup\n", + "\n", + "First, let's install **torch_brain** and set you up with some simple utility functions that will be used throughout this notebook." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "hDjkZJkEEMJR", + "metadata": { + "id": "hDjkZJkEEMJR" + }, + "outputs": [], + "source": [ + "! uv pip install pytorch_brain -q\n", + "\n", + "# We use uv for the installation here, which tends to work better with Google Colab.\n", + "# Although uv is awesome and we highly recommend it, you can also install the package\n", + "# with vanilla pip in your local environment." + ] + }, + { + "cell_type": "markdown", + "id": "aaMelfAWsl2h", + "metadata": { + "id": "aaMelfAWsl2h" + }, + "source": [ + "### Run the block below to load utility functions."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "agXH4LvWEMJS", + "metadata": { + "id": "agXH4LvWEMJS" + }, + "outputs": [], + "source": [ + "import torch\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from omegaconf import OmegaConf\n", + "import warnings\n", + "import logging\n", + "from torch_brain.utils import seed_everything\n", + "\n", + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "warnings.filterwarnings('ignore')\n", + "logging.disable(logging.WARNING)\n", + "\n", + "\n", + "def move_to_gpu(data, device):\n", + " \"\"\"\n", + " Recursively moves tensors (or collections of tensors) to the given device.\n", + " \"\"\"\n", + " if isinstance(data, torch.Tensor):\n", + " return data.to(device)\n", + " elif isinstance(data, dict):\n", + " return {k: move_to_gpu(v, device) for k, v in data.items()}\n", + " elif isinstance(data, list):\n", + " return [move_to_gpu(elem, device) for elem in data]\n", + " else:\n", + " return data\n", + "\n", + "\n", + "def bin_spikes(spikes, num_units, bin_size, right=True, num_bins=None):\n", + " \"\"\"\n", + " Bins spike timestamps into a 2D array: [num_units x num_bins].\n", + " \"\"\"\n", + " rate = 1 / bin_size # avoid precision issues\n", + " binned_spikes = np.zeros((num_units, num_bins))\n", + " bin_index = np.floor((spikes.timestamps) * rate).astype(int)\n", + " np.add.at(binned_spikes, (spikes.unit_index, bin_index), 1)\n", + " return binned_spikes\n", + "\n", + "\n", + "def r2_score(y_pred, y_true):\n", + " # Compute total sum of squares (variance of the true values)\n", + " y_true_mean = torch.mean(y_true, dim=0, keepdim=True)\n", + " ss_total = torch.sum((y_true - y_true_mean) ** 2)\n", + "\n", + " # Compute residual sum of squares\n", + " ss_res = torch.sum((y_true - y_pred) ** 2)\n", + "\n", + " # Compute R^2\n", + " r2 = 1 - ss_res / ss_total\n", + "\n", + " return r2\n", + "\n", + "\n", + "def compute_r2(dataloader, 
model):\n", + " # Compute R2 score over the entire dataset\n", + " total_target = []\n", + " total_pred = []\n", + " for batch in dataloader:\n", + " batch = move_to_gpu(batch, device)\n", + " pred = model(**batch[\"model_inputs\"])\n", + " target = batch[\"target_values\"]\n", + "\n", + " # Store target and pred for visualization\n", + " mask = torch.ones_like(target, dtype=torch.bool)\n", + " if \"output_mask\" in batch[\"model_inputs\"]:\n", + " mask = batch[\"model_inputs\"][\"output_mask\"]\n", + " total_target.append(target[mask])\n", + " total_pred.append(pred[mask])\n", + "\n", + " # Concatenate all batch outputs\n", + " total_target = torch.cat(total_target)\n", + " total_pred = torch.cat(total_pred)\n", + "\n", + " # Compute the R2 score\n", + " r2 = r2_score(total_pred.flatten(), total_target.flatten())\n", + "\n", + " return r2.item(), total_target, total_pred\n", + "\n", + "\n", + "def print_model(model: torch.nn.Module):\n", + " \"\"\"\n", + " Prints a summary of the model architecture and parameter count.\n", + " \"\"\"\n", + " model_str = str(model).split('\\n')\n", + " print(\"\\nModel:\")\n", + " print('\\n'.join(model_str[:5]))\n", + " print(\"...\")\n", + " print('\\n'.join(model_str[-min(5, len(model_str)):]))\n", + " num_params = sum(p.numel() for p in model.parameters())\n", + " if num_params > 1e9:\n", + " param_str = f\"{num_params/1e9:.1f}G\"\n", + " elif num_params > 1e6:\n", + " param_str = f\"{num_params/1e6:.1f}M\"\n", + " else:\n", + " param_str = f\"{num_params/1e3:.1f}K\"\n", + " print(f\"\\nNumber of parameters: {param_str}\\n\")\n", + "\n", + "\n", + "def plot_training_curves(r2_log, loss_log):\n", + " \"\"\"\n", + " Plots the training curves: training loss and validation R2 score.\n", + " \"\"\"\n", + " plt.figure(figsize=(12, 4))\n", + " plt.subplot(1, 2, 1)\n", + " plt.plot(np.linspace(0, len(loss_log), len(loss_log)), loss_log)\n", + " plt.title(\"Training Loss\")\n", + " plt.xlabel(\"Training Steps\")\n", + " plt.ylabel(\"MSE 
Loss\")\n", + " plt.grid()\n", + " plt.subplot(1, 2, 2)\n", + " plt.plot(r2_log)\n", + " plt.title(\"Validation R2\")\n", + " plt.xlabel(\"Epochs\")\n", + " plt.ylabel(\"R2 Score\")\n", + " plt.grid()\n", + " plt.tight_layout()\n", + " plt.show()\n", + "\n", + "\n", + "def generate_sinusoidal_position_embs(num_timesteps, dim):\n", + " position = torch.arange(num_timesteps).unsqueeze(1)\n", + " div_term = torch.exp(torch.arange(0, dim, 2) * (-np.log(10000.0) / dim))\n", + " pe = torch.empty(num_timesteps, dim)\n", + " pe[:, 0:dim // 2] = torch.sin(position * div_term)\n", + " pe[:, dim//2:] = torch.cos(position * div_term)\n", + " return pe\n", + "\n", + "\n", + "def load_pretrained(ckpt_path, model):\n", + " print(\"Loading pretrained model...\")\n", + " checkpoint = torch.load(ckpt_path, map_location=\"cpu\", weights_only=False)\n", + " # poyo is pretrained using lightning, so model weights are prefixed with \"model.\"\n", + " state_dict = {k.replace(\"model.\", \"\"): v for k, v in checkpoint[\"state_dict\"].items()}\n", + " model.load_state_dict(state_dict)\n", + " print(\"Done!\")\n", + " return model\n", + "\n", + "\n", + "def reinit_vocab(emb_module, vocab):\n", + " emb_module.extend_vocab(vocab)\n", + " emb_module.subset_vocab(vocab)\n", + "\n", + "\n", + "def get_dataset_config(brainset, sessions):\n", + " brainset_norms = {\n", + " \"perich_miller_population_2018\": {\n", + " \"mean\": 0.0,\n", + " \"std\": 20.0\n", + " }\n", + " }\n", + "\n", + " config = f\"\"\"\n", + " - selection:\n", + " - brainset: {brainset}\n", + " sessions:\"\"\"\n", + " if type(sessions) is not list:\n", + " sessions = [sessions]\n", + " for session in sessions:\n", + " config += f\"\"\"\n", + " - {session}\"\"\"\n", + " config += f\"\"\"\n", + " config:\n", + " readout:\n", + " readout_id: cursor_velocity_2d\n", + " normalize_mean: {brainset_norms[brainset][\"mean\"]}\n", + " normalize_std: {brainset_norms[brainset][\"std\"]}\n", + " metrics:\n", + " - metric:\n", + " _target_: 
torchmetrics.R2Score\n", + " \"\"\"\n", + "\n", + " config = OmegaConf.create(config)\n", + "\n", + " return config" + ] + }, + { + "cell_type": "markdown", + "id": "0WE9qTBPy887", + "metadata": { + "id": "0WE9qTBPy887" + }, + "source": [ + "***\n", + "\n", + "## Part 1: Data Loading\n", + "\n", + "***\n", + "\n", + "### Table of contents:\n", + "* 1.1 The life of a data sample\n", + "* 1.2 Setting up a basic data pipeline\n", + "* 1.3 Downloading a session" + ] + }, + { + "cell_type": "markdown", + "id": "_5i33iBdyVrq", + "metadata": { + "id": "_5i33iBdyVrq" + }, + "source": [ + "### 1.1 The life of a data sample\n", + "\n", + "In the previous notebook, we saw how **torch_brain.data.Dataset** efficiently produces a data sample. But once a sample is drawn—what happens next?\n", + "\n", + "1.\tA data sample originates from the dataset. In **torch_brain**, this typically refers to a short time-slice of a neural recording.\n", + "2. Optionally, the sample can be transformed—for example, by applying augmentations such as dropping out neurons or brain regions.\n", + "3. Next, the sample is further processed and reshaped into a format suitable for model input. We refer to this step as **tokenization**.\n", + "4. The tokenized sample is then **collated** with other samples to form a **batch**.\n", + "5.\tFinally, the batch is passed through the model for the forward computation and loss evaluation.\n", + "\n", + "This process can be seen visually through the diagram below:\n", + "\n", + "<center>\n", + "<img src=\"https://ik.imagekit.io/7tkfmw7hc/dataloader.png?updatedAt=1743052497906\" height=420 />\n", + "</center>\n", + "\n", + "That’s a lot of work—and in plain PyTorch, you’d be doing it all by hand. **torch_brain** makes each of these steps easy and intuitive to build and customize." 
+ ] + }, + { + "cell_type": "markdown", + "id": "U0L7aQuiE7qL", + "metadata": { + "id": "U0L7aQuiE7qL" + }, + "source": [ + "### 1.2 Setting up a basic data pipeline\n", + "\n", + "Now let's define a utility function to create the training and validation datasets, samplers, and dataloaders.\n", + "This function will be used to set up the data for training and validation of our models.\n", + "\n", + "Note: We'll handle tokenization in the next part of this notebook." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "EEqTM1Nmy887", + "metadata": { + "id": "EEqTM1Nmy887" + }, + "outputs": [], + "source": [ + "from torch_brain.data import Dataset, collate, chain\n", + "from torch_brain.data.sampler import RandomFixedWindowSampler, SequentialFixedWindowSampler\n", + "from torch.utils.data import DataLoader\n", + "\n", + "def get_train_val_loaders(recording_id=None, cfg=None, batch_size=32, seed=0):\n", + " \"\"\"Sets up train and validation Datasets, Samplers, and DataLoaders\n", + " \"\"\"\n", + "\n", + " # -- Train --\n", + " train_dataset = Dataset(\n", + " root=\"data\", # root directory where .h5 files are found\n", + " recording_id=recording_id, # you either specify a single recording ID\n", + " config=cfg, # or a config for multi-session training / more complex configs\n", + " split=\"train\",\n", + " )\n", + " # We use a random sampler to improve generalization during training\n", + " train_sampling_intervals = train_dataset.get_sampling_intervals()\n", + " train_sampler = RandomFixedWindowSampler(\n", + " sampling_intervals=train_sampling_intervals,\n", + " window_length=1.0, # context window of samples\n", + " generator=torch.Generator().manual_seed(seed),\n", + " )\n", + " # Finally combine them in a dataloader\n", + " train_loader = DataLoader(\n", + " dataset=train_dataset, # dataset\n", + " sampler=train_sampler, # sampler\n", + " batch_size=batch_size, # num of samples per batch\n", + " collate_fn=collate, # the collator\n", + " 
num_workers=4, # data sample processing (slicing, transforms, tokenization) happens in parallel; this sets the amount of that parallelization\n", + " pin_memory=True,\n", + " )\n", + "\n", + " # -- Validation --\n", + " val_dataset = Dataset(\n", + " root=\"data\",\n", + " recording_id=recording_id,\n", + " config=cfg,\n", + " split=\"valid\",\n", + " )\n", + " # For validation we don't randomize samples for reproducibility\n", + " val_sampling_intervals = val_dataset.get_sampling_intervals()\n", + " val_sampler = SequentialFixedWindowSampler(\n", + " sampling_intervals=val_sampling_intervals,\n", + " window_length=1.0,\n", + " )\n", + " # Combine them in a dataloader\n", + " val_loader = DataLoader(\n", + " dataset=val_dataset,\n", + " sampler=val_sampler,\n", + " batch_size=batch_size,\n", + " collate_fn=collate,\n", + " num_workers=4,\n", + " pin_memory=True,\n", + " )\n", + "\n", + " train_dataset.disable_data_leakage_check()\n", + " val_dataset.disable_data_leakage_check()\n", + "\n", + " return train_dataset, train_loader, val_dataset, val_loader" + ] + }, + { + "cell_type": "markdown", + "id": "v9iwnzt8y888", + "metadata": { + "id": "v9iwnzt8y888" + }, + "source": [ + "Aside from the tokenizer, that's all there is to the data pipeline! Most of it is handled behind-the-scenes with **torch_brain**." + ] + }, + { + "cell_type": "markdown", + "id": "QA_B-daGG414", + "metadata": { + "id": "QA_B-daGG414" + }, + "source": [ + "### 1.3 Downloading a session\n", + "For our examples, we'll use data from a single neural recording from [1], which includes recordings of spiking activity from a monkey during a reaching task. Let's quickly download it to our environment." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1qxkvGX0y888", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 7814, + "status": "ok", + "timestamp": 1743092108961, + "user": { + "displayName": "Eva Dyer", + "userId": "05212169819659068372" + }, + "user_tz": 240 + }, + "id": "1qxkvGX0y888", + "outputId": "d962adbc-1583-4ec6-ba7b-224bb47f139b" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading...\n", + "From: https://drive.google.com/uc?id=1W--Sm_BcphEC2snoF4zwPdHkkYGgAaUw\n", + "To: /content/data/perich_miller_population_2018/t_20130819_center_out_reaching.h5\n", + "\r", + " 0% 0.00/9.88M [00:00<?, ?B/s]\r", + "100% 9.88M/9.88M [00:00<00:00, 227MB/s]\n" + ] + } + ], + "source": [ + "! mkdir -p data/perich_miller_population_2018\n", + "! gdown 1W--Sm_BcphEC2snoF4zwPdHkkYGgAaUw -O data/perich_miller_population_2018/t_20130819_center_out_reaching.h5" + ] + }, + { + "cell_type": "markdown", + "id": "7F-vMzr0y888", + "metadata": { + "id": "7F-vMzr0y888" + }, + "source": [ + "Now, let's jump into training some models!" + ] + }, + { + "cell_type": "markdown", + "id": "V-ebX-UMy888", + "metadata": { + "id": "V-ebX-UMy888" + }, + "source": [ + "***\n", + "\n", + "## Part 2: Training Models\n", + "\n", + "***\n", + "\n", + "In this section, we walk through training different neural decoders. We'll structure the code so there's a common training loop, then move onto implementing:\n", + "\n", + "- A simple MLP Neural Decoder\n", + "- A Transformer-based Neural Decoder [2]\n", + "- POYO! [3]" + ] + }, + { + "cell_type": "markdown", + "id": "cfRm95aFy888", + "metadata": { + "id": "cfRm95aFy888" + }, + "source": [ + "### 2.1 Warmup: Implementing an MLP based Neural Activity Decoder\n", + "\n", + "To learn how to write and train models in the `torch_brain` framework, let's kick things off with a simple MLP Neural Decoder." 
+ ] + }, + { + "cell_type": "markdown", + "id": "X59hXx4ay888", + "metadata": { + "id": "X59hXx4ay888" + }, + "source": [ + "#### 2.1.1 Defining the Model\n", + "\n", + "In **torch_brain**, every model needs to define two important methods:\n", + "1. `model.tokenize`: the tokenization step, which processes and reshapes data to align with model input.\n", + "2. `model.forward`: the forward pass, which produces predictions.\n", + "\n", + "In the case of an MLP decoder, these will look like:\n", + "\n", + "<center>\n", + "<img src=\"https://ik.imagekit.io/7tkfmw7hc/mlp.png?updatedAt=1743064332486\" width=900 />\n", + "</center>" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "Fmb6W850y888", + "metadata": { + "id": "Fmb6W850y888" + }, + "outputs": [], + "source": [ + "import torch.nn as nn\n", + "\n", + "class MLPNeuralDecoder(nn.Module):\n", + " def __init__(self, num_units, bin_size, sequence_length, output_dim, hidden_dim):\n", + " \"\"\"Initialize the neural net layers.\"\"\"\n", + " super().__init__()\n", + "\n", + " self.num_timesteps = int(sequence_length / bin_size)\n", + " self.bin_size = bin_size\n", + " self.output_dim = output_dim\n", + "\n", + " self.net = nn.Sequential(\n", + " nn.Linear(self.num_timesteps * num_units, hidden_dim),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_dim, hidden_dim),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_dim, output_dim * self.num_timesteps),\n", + " )\n", + "\n", + " def forward(self, x):\n", + " \"\"\"Produces predictions from a binned spiketrain.\n", + " This is pure PyTorch code.\n", + "\n", + " Shape of x: (B, T, N)\n", + " \"\"\"\n", + "\n", + " x = x.flatten(1) # (B, T, N) -> (B, T*N)\n", + " x = self.net(x) # (B, T*N) -> (B, T*D_out)\n", + " x = x.reshape(-1, self.num_timesteps, self.output_dim) # (B, T*D_out) -> (B, T, D_out)\n", + " return x\n", + "\n", + " def tokenize(self, data):\n", + " \"\"\"Tokenizes a data sample, which is a sliced Data object.\"\"\"\n", + "\n", + " # A. 
Extract and bin neural activity (data.spikes)\n", + " spikes = data.spikes\n", + " x = bin_spikes(\n", + " spikes=spikes,\n", + " num_units=len(data.units),\n", + " bin_size=self.bin_size,\n", + " num_bins=self.num_timesteps\n", + " ).T\n", + " # Final shape of x here is (timestamps, num_neurons)\n", + "\n", + " # B. Extract the corresponding cursor velocity, which will act as targets\n", + " # for training the MLP.\n", + " y = data.cursor.vel\n", + " # Final shape of y is (timestamps, 2)\n", + " # Note that in this example we have chosen the bin size to match the\n", + " # sampling rate of the recorded cursor velocity.\n", + "\n", + " # Finally, we output the \"tokenized\" data in the form of a dictionary.\n", + " data_dict = {\n", + " \"model_inputs\": {\n", + " \"x\": torch.tensor(x, dtype=torch.float32),\n", + " # Models in torch_brain typically follow the convention that\n", + " # fields that are input to model.forward() are stored in\n", + " # \"model_inputs\". Although you are free to deviate from this,\n", + " # we have found that this convention generally produces cleaner\n", + " # training loops.\n", + " },\n", + " \"target_values\": torch.tensor(y, dtype=torch.float32),\n", + " }\n", + " return data_dict" + ] + }, + { + "cell_type": "markdown", + "id": "533abuoQIjb5", + "metadata": { + "id": "533abuoQIjb5" + }, + "source": [ + "**More on tokenization:**\n", + "\n", + "Recall that a data sample emitted by the Dataset is only sliced in time, but still\n", + "contains all fields present in the dataset. In the tokenizer, we:\n", + "1. Extract the fields relevant to our machine learning problem. In this case, we only care about the spiketrain and the cursor velocity.\n", + "2. Process and reshape the extracted data into a format that is convenient for our model to process. In our example, we bin the spiketrain."
+ ] + }, + { + "cell_type": "markdown", + "id": "TOsvWrojy888", + "metadata": { + "id": "TOsvWrojy888" + }, + "source": [ + "#### 2.1.2 Defining the Training Loop\n", + "\n", + "The following functions perform the training steps. At each epoch, we use the data loader to sample batches of data and train the model using an optimizer. We then compute the R² score on the validation set to monitor the model's performance. This `train` function trains the model for a specified number of epochs and logs the R² score on the validation set." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "lrAvNcOPy888", + "metadata": { + "id": "lrAvNcOPy888" + }, + "outputs": [], + "source": [ + "import torch.nn.functional as F\n", + "\n", + "def train(model, optimizer, train_loader, val_loader, num_epochs=50, store_embs=False):\n", + " # We'll store some intermediate outputs for visualization\n", + " train_outputs = {\n", + " 'n_epochs': num_epochs,\n", + " 'unit_emb': [],\n", + " 'session_emb': [],\n", + " 'output_pred': [],\n", + " 'output_gt': [],\n", + " }\n", + "\n", + " r2_log = []\n", + " loss_log = []\n", + "\n", + " # Training loop\n", + " for epoch in range(num_epochs):\n", + " # Compute R² score on validation set\n", + " r2, target, pred = compute_r2(val_loader, model)\n", + " r2_log.append(r2)\n", + "\n", + " # Training steps\n", + " for batch in train_loader:\n", + " batch = move_to_gpu(batch, device)\n", + " loss = training_step(batch, model, optimizer)\n", + " loss_log.append(loss.item())\n", + "\n", + " print(f\"\\rEpoch {epoch+1}/{num_epochs} | Val R2 = {r2:.3f} | Loss = {loss.item():.3f}\", end=\"\")\n", + "\n", + " # Store intermediate outputs\n", + " if store_embs:\n", + " train_outputs['unit_emb'].append(model.unit_emb.weight[1:].detach().cpu().numpy())\n", + " train_outputs['session_emb'].append(model.session_emb.weight[1:].detach().cpu().numpy())\n", + " train_outputs['output_gt'].append(target.detach().cpu().numpy())\n", + " 
train_outputs['output_pred'].append(pred.detach().cpu().numpy())\n", + "\n", + " # Compute final R² score\n", + " r2, _, _ = compute_r2(val_loader, model)\n", + " r2_log.append(r2)\n", + " print(f\"\\nDone! Final validation R2 = {r2:.3f}\")\n", + "\n", + " return r2_log, loss_log, train_outputs\n", + "\n", + "\n", + "def training_step(batch, model, optimizer):\n", + " optimizer.zero_grad() # Step 0. Clear old gradients\n", + " pred = model(**batch[\"model_inputs\"]) # Step 1. Do forward pass\n", + " target = batch[\"target_values\"]\n", + " loss = F.mse_loss(pred, target) # Step 2. Compute loss\n", + " loss.backward() # Step 3. Backward pass\n", + " optimizer.step() # Step 4. Update model params\n", + " return loss\n" + ] + }, + { + "cell_type": "markdown", + "id": "PGeSPwBGy888", + "metadata": { + "id": "PGeSPwBGy888" + }, + "source": [ + "#### 2.1.3 Let's train!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "RuDAItbuy889", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 462 + }, + "executionInfo": { + "elapsed": 313592, + "status": "ok", + "timestamp": 1743092422552, + "user": { + "displayName": "Eva Dyer", + "userId": "05212169819659068372" + }, + "user_tz": 240 + }, + "id": "RuDAItbuy889", + "outputId": "2891859d-553d-4853-df36-3177c4fdc169" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Num Units in Session: 55\n", + "Epoch 100/100 | Val R2 = 0.558 | Loss = 5.405\n", + "Done! 
Final validation R2 = 0.561\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAABKUAAAGGCAYAAACqvTJ0AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAA69VJREFUeJzs3XdcU9f7B/BPEiBsEBAQxa3g1rqtq+7RbYfWtmp3rV122l9rqx121w5bW1vrqFZbv9YOrUpxV9Q6cIOo4EA2siHz/v7AhITcLEi4gJ/36+Wr5N5zzz05RJs8ec5zZIIgCCAiIiIiIiIiIqpDcqkHQERERERERERE1x8GpYiIiIiIiIiIqM4xKEVERERERERERHWOQSkiIiIiIiIiIqpzDEoREREREREREVGdY1CKiIiIiIiIiIjqHINSRERERERERERU5xiUIiIiIiIiIiKiOsegFBERERERERER [...] + "text/plain": [ + "<Figure size 1200x400 with 2 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "seed_everything(0)\n", + "\n", + "# 1. Setup datasets and dataloader\n", + "recording_id = \"perich_miller_population_2018/t_20130819_center_out_reaching\"\n", + "train_dataset, train_loader, val_dataset, val_loader = get_train_val_loaders(recording_id, batch_size=64)\n", + "num_units = len(train_dataset.get_unit_ids())\n", + "print(f\"Num Units in Session: {num_units}\")\n", + "\n", + "# 2. Initialize Model with the new MLP definition\n", + "mlp_model = MLPNeuralDecoder(\n", + " num_units=num_units, # Num. of units inputted (spiking activity)\n", + " #\n", + " bin_size=10e-3, # Duration (s) of bins\n", + " sequence_length=1.0, # Context length of the model\n", + " #\n", + " output_dim=2, # Output dimension of final readout layer\n", + " hidden_dim=32, # Hidden dimension of the model\n", + ")\n", + "mlp_model = mlp_model.to(device)\n", + "\n", + "# 3. Connect Tokenizer to Datasets\n", + "transform = mlp_model.tokenize\n", + "train_dataset.transform = transform\n", + "val_dataset.transform = transform\n", + "\n", + "# 4. Setup Optimizer\n", + "optimizer = torch.optim.AdamW(mlp_model.parameters(), lr=1e-3)\n", + "\n", + "# 5. 
Train!\n", + "mlp_r2_log, mlp_loss_log, mlp_train_outputs = train(mlp_model, optimizer, train_loader, val_loader, num_epochs=100)\n", + "\n", + "# Plot the training loss and validation R2\n", + "plot_training_curves(mlp_r2_log, mlp_loss_log)\n" + ] + }, + { + "cell_type": "markdown", + "id": "mZ6kUr_Sy889", + "metadata": { + "id": "mZ6kUr_Sy889" + }, + "source": [ + "You should now see a training loss curve steadily decreasing and the validation R² rising. These trends mean your model is learning effectively!" + ] + }, + { + "cell_type": "markdown", + "id": "TEt1Bbpcy889", + "metadata": { + "id": "TEt1Bbpcy889" + }, + "source": [ + "### 2.2 Training a simple Transformer for Neural Decoding\n", + "Next up: let's move on to the main course - Transformers! Let's explore how attention can be used for neural decoding by building and training a simple transformer." + ] + }, + { + "cell_type": "markdown", + "id": "SrtWWwspy889", + "metadata": { + "id": "SrtWWwspy889" + }, + "source": [ + "#### 2.2.1 Defining a Transformer model\n", + "\n", + "The philosophy of having `model.tokenize` and `model.forward` methods remains the same as before; however, our model is a bit more complex than the humble MLP.\n", + "\n", + "<center>\n", + " <img src=\"https://ik.imagekit.io/7tkfmw7hc/transformer.png?updatedAt=1743064332335\" alt=\"The Transformer model architecture.\" width=900/>\n", + "</center>\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "mfgVIeUKy889", + "metadata": { + "id": "mfgVIeUKy889" + }, + "outputs": [], + "source": [ + "from torch_brain.nn import FeedForward\n", + "\n", + "class TransformerNeuralDecoder(nn.Module):\n", + " def __init__(\n", + " self, num_units, bin_size, sequence_length, # data properties\n", + " dim_output, dim_hidden, n_layers, n_heads, # transformer properties\n", + " ):\n", + " \"\"\"Initialize the neural net components\"\"\"\n", + " super().__init__()\n", + "\n", + " self.num_timesteps = int(sequence_length / 
bin_size)\n", + " self.bin_size = bin_size\n", + "\n", + " # Create the read-in/out linear layers\n", + " self.readin = nn.Linear(num_units, dim_hidden)\n", + " self.readout = nn.Linear(dim_hidden, dim_output)\n", + "\n", + " # Create the position embeddings\n", + " # Note that these are kept constant in this implementation, i.e. _not_ learnable\n", + " self.position_embeddings = nn.Parameter(\n", + " data=generate_sinusoidal_position_embs(self.num_timesteps, dim_hidden),\n", + " requires_grad=False,\n", + " )\n", + "\n", + " # Create the transformer layers:\n", + " # each composed of the Attention and the feedforward (FFN) blocks\n", + " self.transformer_layers = nn.ModuleList([\n", + " nn.ModuleList([\n", + " nn.MultiheadAttention(\n", + " embed_dim=dim_hidden,\n", + " num_heads=n_heads,\n", + " batch_first=True,\n", + " ),\n", + " FeedForward(dim=dim_hidden),\n", + " ])\n", + " for _ in range(n_layers)\n", + " ])\n", + "\n", + " def forward(self, x):\n", + " \"\"\"Produces predictions from a binned spiketrain.\n", + " This is pure PyTorch code.\n", + "\n", + " Shape of x: (B, T, N)\n", + " \"\"\"\n", + "\n", + " # Read-in: converts our input matrix to transformer tokens; one token for each timestep\n", + " x = self.readin(x) # (B, T, N) -> (B, T, D)\n", + "\n", + " # Add position embeddings to the tokens\n", + " x = x + self.position_embeddings[None, ...] # -> (B, T, D)\n", + "\n", + " # Transformer\n", + " for attn, ffn in self.transformer_layers:\n", + " x = x + attn(x, x, x, need_weights=False)[0]\n", + " x = x + ffn(x)\n", + "\n", + " # Readout: converts tokens to 2D vectors; each vector signifying (v_x, v_y) at that timestep\n", + " x = self.readout(x) # (B, T, D) -> (B, T, 2)\n", + "\n", + " return x\n", + "\n", + " def tokenize(self, data):\n", + " # Same tokenizer as the MLP\n", + "\n", + " # A. 
Bin spikes\n", + " x = bin_spikes(\n", + " spikes=data.spikes,\n", + " num_units=len(data.units),\n", + " bin_size=self.bin_size,\n", + " num_bins=self.num_timesteps,\n", + " ).T\n", + "\n", + " # B. Extract targets\n", + " y = data.cursor.vel\n", + "\n", + " data_dict = {\n", + " \"model_inputs\": {\n", + " \"x\": torch.tensor(x, dtype=torch.float32),\n", + " },\n", + " \"target_values\": torch.tensor(y, dtype=torch.float32),\n", + " }\n", + " return data_dict" + ] + }, + { + "cell_type": "markdown", + "id": "tNLt0os-y889", + "metadata": { + "id": "tNLt0os-y889" + }, + "source": [ + "#### 2.2.2 Let's train!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "qWAph1Vjy889", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 444 + }, + "executionInfo": { + "elapsed": 314676, + "status": "ok", + "timestamp": 1743092737253, + "user": { + "displayName": "Eva Dyer", + "userId": "05212169819659068372" + }, + "user_tz": 240 + }, + "id": "qWAph1Vjy889", + "outputId": "df155464-0118-477c-809e-30c849347425" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 100/100 | Val R2 = 0.671 | Loss = 3.814\n", + "Done! Final validation R2 = 0.710\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAABKUAAAGGCAYAAACqvTJ0AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAA/nFJREFUeJzs3XdYk1f7B/BvEkLYICAgijJU3KMO3KNuba2jrdql1o63ahdvl/21tnbRaa2trX2tVm1rtXbYZVXqHrjFLeJAQBkCQpgh4/n9ERISMkggIajfz3V5lTzjPOc5pPrkzn3uIxIEQQAREREREREREVEDEru6A0REREREREREdPthUIqIiIiIiIiIiBocg1JERERERERERNTgGJQiIiIiIiIiIqIGx6AUERERERERERE1OAaliIiIiIiIiIiowTEoRUREREREREREDY5BKSIiIiIiIiIianAMShERERER [...] + "text/plain": [ + "<Figure size 1200x400 with 2 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "seed_everything(0)\n", + "\n", + "# 1. 
Setup datasets and dataloader\n", + "recording_id = \"perich_miller_population_2018/t_20130819_center_out_reaching\"\n", + "train_dataset, train_loader, val_dataset, val_loader = get_train_val_loaders(recording_id, batch_size=64)\n", + "num_units = len(train_dataset.get_unit_ids())\n", + "\n", + "# 2. Initialize Model\n", + "tf_model = TransformerNeuralDecoder(\n", + " num_units=num_units, # Num. of units inputted (spiking activity)\n", + " #\n", + " bin_size=10e-3, # Duration (s) of bins\n", + " sequence_length=1.0, # Context length of the model\n", + " #\n", + " dim_output=2, # Output dimension of final readout layer\n", + " dim_hidden=128, # Hidden dimension of the model\n", + " n_layers=3, # Num. of transformer layers\n", + " n_heads=4, # Num. of heads in MHA blocks\n", + ").to(device)\n", + "\n", + "# 3. Connect Tokenizer to Datasets\n", + "train_dataset.transform = tf_model.tokenize\n", + "val_dataset.transform = tf_model.tokenize\n", + "\n", + "# 4. Setup Optimizer\n", + "optimizer = torch.optim.AdamW(tf_model.parameters(), lr=1e-3)\n", + "\n", + "# 5. Train!\n", + "transformer_r2_log, transformer_loss_log, transformer_train_outputs = train(tf_model, optimizer, train_loader, val_loader, num_epochs=100)\n", + "\n", + "# Plot the training loss and validation R2\n", + "plot_training_curves(transformer_r2_log, transformer_loss_log)" + ] + }, + { + "cell_type": "markdown", + "id": "rxqk7gJry88-", + "metadata": { + "id": "rxqk7gJry88-" + }, + "source": [ + "\n", + "### 2.3 Training POYO!\n", + "\n", + "Here we show how we can instantiate a POYO model and train it from scratch." + ] + }, + { + "cell_type": "markdown", + "id": "9-csi0JSZUHj", + "metadata": { + "id": "9-csi0JSZUHj" + }, + "source": [ + "From the figure below, we can see that the POYO model has a lot more going on internally in comparison to the two models we just saw. 
Thankfully, **torch_brain** provides implementations of standard neural decoding models that we can directly use in (almost) the same framework as we did with the previous models. We currently have implementations of POYO [3], and POYO+ [4], and are actively working on adding more models, such as NDT-2 [5], MTM [6], etc.\n", + "\n", + "<center>\n", + "<img src=\"https://ik.imagekit.io/7tkfmw7hc/poyo.png?updatedAt=1743064332701\" width=900/>\n", + "</center>\n", + "\n", + "<!-- For custom models (and when training on a single session), it suffices to specify a single recording ID when defining the Dataset and DataLoaders, especially since things like normalization can be defined within the model definition. Since POYO was built for multi-session, it's important to define a more concrete configuration that allows for training across multiple sessions, or even multiple datasets. Here we simplify the process for the session we have been using, by declari [...] + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "PC_5tFCny88-", + "metadata": { + "id": "PC_5tFCny88-" + }, + "outputs": [], + "source": [ + "seed_everything(0)\n", + "\n", + "# 1. Setup datasets and dataloader\n", + "# For a model like POYO, which was built for multi-session training, the way to\n", + "# instantiate a dataset is just slightly more involved than what we have used\n", + "# so far. We have abstracted that difference in a utility function `get_dataset_config`\n", + "train_dataset, train_loader, val_dataset, val_loader = get_train_val_loaders(\n", + " cfg=get_dataset_config(\"perich_miller_population_2018\", \"t_20130819_center_out_reaching\"),\n", + " batch_size=64,\n", + ")\n", + "\n", + "# 2. Instantiate the model. 
The model implementation is provided by torch_brain\n", + "from torch_brain.models import POYO\n", + "from torch_brain.registry import MODALITY_REGISTRY\n", + "poyo_model = POYO(\n", + " sequence_length=1.0, # Context length of the model\n", + " readout_spec=MODALITY_REGISTRY['cursor_velocity_2d'], # POYO allows for multiple readout modalities; this is how we choose\n", + " #\n", + " latent_step=1.0 / 8, # Timestep of the learned latent grid\n", + " num_latents_per_step=16, # Number of unique learned latents per timestep\n", + " #\n", + " dim=64, # Hidden dimension of the model\n", + " depth=6, # Number of transformer layers\n", + " #\n", + " dim_head=64, # Dimension of each attention head\n", + " cross_heads=2, # Num. of heads in cross-attention blocks\n", + " self_heads=8, # Num. of heads in self attention blocks\n", + ").to(device)\n", + "\n", + "# 2.5: Extra step: populate the Unit and Session Embedding Vocabularies\n", + "poyo_model.unit_emb.initialize_vocab(train_dataset.get_unit_ids())\n", + "poyo_model.session_emb.initialize_vocab(train_dataset.get_session_ids())\n", + "\n", + "# 3. Connect tokenizers to Datasets\n", + "train_dataset.transform = poyo_model.tokenize\n", + "val_dataset.transform = poyo_model.tokenize\n", + "\n", + "# 4. Setup Optimizer\n", + "optimizer = torch.optim.AdamW(poyo_model.parameters(), lr=1e-3)\n", + "\n", + "# 5. Train!\n", + "poyo_r2_log, poyo_loss_log, poyo_train_outputs = train(\n", + " poyo_model, optimizer, train_loader, val_loader,\n", + " num_epochs=100, store_embs=True,\n", + ")\n", + "\n", + "# Plot the training loss and validation R2\n", + "plot_training_curves(poyo_r2_log, poyo_loss_log)" + ] + }, + { + "cell_type": "markdown", + "id": "R02IoWf_y88-", + "metadata": { + "id": "R02IoWf_y88-" + }, + "source": [ + "***\n", + "\n", + "<a id=\"Finetuning\"></a>\n", + "## Part 3. 
Finetuning and Visualizing\n", + "\n", + "***\n", + "\n", + "In the above sections, we saw that creating and training models for a single session of neural recordings can be made quick and easy using **torch_brain**. In addition, standard models in **torch_brain** have been pretrained across different sessions, subjects, tasks, and even datasets. We can utilize the generalization capabilities of these scaled-up models and *fine-tune* them on new data with minimal effort.\n", + "\n", + "In this section, we will demonstrate how to fine-tune a pretrained POYO-mp model on a new session. We note that the session we've\n", + "been using so far (`t_20130819_center_out_reaching`) was held out of POYO-mp pretraining.\n", + "\n", + "***\n", + "### Table of contents:\n", + "* 3.1 Loading a Pretrained Model\n", + "* 3.2 Finetuning Strategies\n", + "* 3.3 Let's fine-tune!\n", + "* 3.4 Let's compare\n", + "* 3.5 Visualizing the POYO model" + ] + }, + { + "cell_type": "markdown", + "id": "oFkTHcipy88-", + "metadata": { + "id": "oFkTHcipy88-" + }, + "source": [ + "### 3.1 Loading a Pretrained Model\n", + "First we will download the POYO-mp pretrained checkpoint." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "-6x_N3-hy88-", + "metadata": { + "id": "-6x_N3-hy88-" + }, + "outputs": [], + "source": [ + "! uv pip install boto3 -q\n", + "import boto3\n", + "from botocore import UNSIGNED\n", + "from botocore.config import Config\n", + "\n", + "# Download the pretrained model\n", + "s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))\n", + "s3.download_file(\n", + " \"torch-brain\",\n", + " \"model-zoo/poyo_mp.ckpt\",\n", + " \"poyo_mp.ckpt\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "T_yUo6l0y88-", + "metadata": { + "id": "T_yUo6l0y88-" + }, + "source": [ + "\n", + "### 3.2 Finetuning Strategies\n", + "\n", + "As a reminder, POYO learns an embedding space for units and sessions that is shared across all its training data.
When fine-tuning POYO on a new session, we need to find how to project the new units and session into the previously learned embedding space. For this we keep most of the model weights unchanged and only train the new embeddings (we call this *freezing the backbone*). We can opt to unfreeze them after a certain number of epochs, or keep them frozen throughout.\n", + "\n", + "<center>\n", + " <img src=\"https://media-hosting.imagekit.io/61678065e60f442a/Screenshot 2025-03-26 at 6.59.27 PM.png?Expires=1837637971&Key-Pair-Id=K2ZIVPTIP2VGHC&Signature=aG8nDmY1CwhZ~6bOcuYws22lY3qn3qbNf7m-gWFrO3ZvBNDpfLSpxxhlbrRv5wysohL~COsjtziTinRALYxzkN7QxiQ1hhIpM8ieFamz2Po-0Ocvx12~FQeDOMRFMCfVay3QdOkXTTyP0PgFab7~e9oK2VEljN20SzhAd2mH4qVLLh-kZXHGvrnjbbSIxQiQ~OKL04tNzyRSwKglhv13l54qtpqUAparzW5lA3H50ZuiLMJJ9ue03Qg4a2X32SrHMkJUpoeFZgVgAw046Ukry-MZixsFVynLVv8vDWdl7e4ZhQIE3kJuXxpowWqdlOoSk5-oBDw [...] + "</center>" + ] + }, + { + "cell_type": "markdown", + "id": "PzWxTVE-bMOy", + "metadata": { + "id": "PzWxTVE-bMOy" + }, + "source": [ + "#### Finetuning function\n", + "\n", + "Depending on the value of `epoch_to_unfreeze`, we have different fine-tuning strategies:\n", + "- **Full fine-tuning** (with gradual unfreezing): if we set `epoch_to_unfreeze >= 0` then we will train the new embeddings for `epoch_to_unfreeze` epochs,\n", + "then unfreeze the backbone and train the entire model for the remaining epochs.\n", + "- **Unit Identification**: if we set `epoch_to_unfreeze == -1`, then we will only train the new embeddings for all epochs, i.e. the backbone\n", + "will remain frozen throughout training.\n", + "\n", + "For this we'll define a new training function that allows for different fine-tuning strategies (it's otherwise the same as the previous training function)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5k7z0v1Jy88-", + "metadata": { + "id": "5k7z0v1Jy88-" + }, + "outputs": [], + "source": [ + "from tqdm import tqdm\n", + "from torch.cuda.amp import autocast, GradScaler\n", + "\n", + "def finetune(model, optimizer, train_loader, val_loader, num_epochs=50, epoch_to_unfreeze=30):\n", + " # Freeze the backbone\n", + " backbone_params = [\n", + " p for p in model.named_parameters()\n", + " if (\n", + " 'unit_emb' not in p[0]\n", + " and 'session_emb' not in p[0]\n", + " and 'readout' not in p[0]\n", + " and p[1].requires_grad\n", + " )\n", + " ]\n", + " for _, param in backbone_params:\n", + " param.requires_grad = False\n", + "\n", + " # We'll store some intermediate outputs for visualization\n", + " train_outputs = {\n", + " 'n_epochs': num_epochs,\n", + " 'epoch_to_unfreeze': epoch_to_unfreeze,\n", + " 'unit_emb': [],\n", + " 'session_emb': [],\n", + " 'output_pred': [],\n", + " 'output_gt': [],\n", + " }\n", + "\n", + " r2_log = []\n", + " loss_log = []\n", + "\n", + " # Training loop\n", + " for epoch in range(num_epochs):\n", + " # Unfreeze the backbone after `epoch_to_unfreeze` epochs\n", + " if epoch == epoch_to_unfreeze:\n", + " for _, param in backbone_params:\n", + " param.requires_grad = True\n", + " print(\" Unfreezing entire model\")\n", + "\n", + " with torch.no_grad():\n", + " r2, target, pred = compute_r2(val_loader, model)\n", + " r2_log.append(r2)\n", + "\n", + " for batch in train_loader:\n", + " batch = move_to_gpu(batch, device)\n", + " loss = training_step(batch, model, optimizer)\n", + " loss_log.append(loss.item())\n", + "\n", + " print(f\"\\rEpoch {epoch+1}/{num_epochs} | Val R2 = {r2:.3f} | Loss = {loss.item():.3f}\", end=\"\")\n", + "\n", + " # Store intermediate outputs\n", + " train_outputs['unit_emb'].append(model.unit_emb.weight[1:].detach().cpu().numpy())\n", + " train_outputs['session_emb'].append(model.session_emb.weight[1:].detach().cpu().numpy())\n", + " 
train_outputs['output_gt'].append(target.detach().cpu().numpy())\n", + " train_outputs['output_pred'].append(pred.detach().cpu().numpy())\n", + "\n", + " del target, pred\n", + "\n", + " # Compute final R² score\n", + " r2, _, _ = compute_r2(val_loader, model)\n", + " r2_log.append(r2)\n", + " print(f\"\\nDone! Final validation R2 = {r2:.3f}\")\n", + "\n", + " return r2_log, loss_log, train_outputs" + ] + }, + { + "cell_type": "markdown", + "id": "2XebFpsKy88-", + "metadata": { + "id": "2XebFpsKy88-" + }, + "source": [ + "### 3.3 Let's fine-tune!\n", + "Note: the model was pretrained with the fixed vocabulary of units and sessions used during training. Since we are now fine-tuning on a new session (with new units), we extend the vocabulary to include the new units and session, then subset the vocabulary to only look at\n", + "the new ones.\n", + "\n", + "Here we'll put the full training workflow into a function as well so we can easily repeat it for different sessions or\n", + "finetuning strategies. Of course, if one wishes to manually implement a more complex training loop, they can do so as well." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "CAohIKciy88-", + "metadata": { + "id": "CAohIKciy88-" + }, + "outputs": [], + "source": [ + "seed_everything(0)\n", + "\n", + "# 1. Setup datasets and dataloader\n", + "train_dataset, train_loader, val_dataset, val_loader = get_train_val_loaders(\n", + " cfg=get_dataset_config(\"perich_miller_population_2018\", \"t_20130819_center_out_reaching\"),\n", + " batch_size=64,\n", + ")\n", + "num_units = len(train_dataset.get_unit_ids())\n", + "\n", + "# 2. 
Instantiate the model and load pretrained weights (with poyo-mp hparams)\n", + "from torch_brain.models import POYO\n", + "from torch_brain.registry import MODALITY_REGISTRY\n", + "poyo_ft_model = POYO(\n", + " sequence_length=1.0,\n", + " latent_step=1.0 / 8,\n", + " dim=128,\n", + " readout_spec=MODALITY_REGISTRY['cursor_velocity_2d'],\n", + " dim_head=64,\n", + " num_latents_per_step=32,\n", + " depth=24,\n", + " cross_heads=4,\n", + " self_heads=8,\n", + ")\n", + "\n", + "ckpt_path = 'poyo_mp.ckpt'\n", + "poyo_ft_model = load_pretrained(ckpt_path, poyo_ft_model)\n", + "\n", + "# 2.5. Reinitialize the vocabs for the new session\n", + "reinit_vocab(poyo_ft_model.unit_emb, train_dataset.get_unit_ids())\n", + "reinit_vocab(poyo_ft_model.session_emb, train_dataset.get_session_ids())\n", + "\n", + "poyo_ft_model.to(device)\n", + "\n", + "# 3. Connect tokenizers to Datasets\n", + "train_dataset.transform = poyo_ft_model.tokenize\n", + "val_dataset.transform = poyo_ft_model.tokenize\n", + "\n", + "# 4. Setup Optimizer\n", + "optimizer = torch.optim.AdamW(poyo_ft_model.parameters(), lr=1e-3)\n", + "\n", + "# 5. 
Train!\n", + "poyo_ft_r2_log, poyo_ft_loss_log, poyo_ft_train_outputs = finetune(poyo_ft_model, optimizer, train_loader, val_loader,\n", + " num_epochs=40, epoch_to_unfreeze=10)\n", + "\n", + "# Visualize the results\n", + "plot_training_curves(poyo_ft_r2_log, poyo_ft_loss_log)" + ] + }, + { + "cell_type": "markdown", + "id": "RA1mnGAOfL4Y", + "metadata": { + "id": "RA1mnGAOfL4Y" + }, + "source": [ + "### 3.4 Let's compare" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ZI41uQ2ry88_", + "metadata": { + "id": "ZI41uQ2ry88_" + }, + "outputs": [], + "source": [ + "lcls = locals().copy()\n", + "for lcl in lcls:\n", + " if not lcl.endswith(\"_r2_log\"):\n", + " continue\n", + " model = lcl.split(\"_r2_log\")[0].upper()\n", + " plt.plot(locals()[lcl], label=model)\n", + "plt.xlabel(\"Epoch\")\n", + "plt.ylabel(\"Validation $R^2$\")\n", + "plt.grid()\n", + "plt.legend()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "mHfhuXT8y88_", + "metadata": { + "id": "mHfhuXT8y88_" + }, + "source": [ + "### 3.5 Visualizing the POYO model\n", + "\n", + "POYO's unit and session embeddings are learned during training, and we can visualize them along with a snapshot of the\n", + "model's performance." + ] + }, + { + "cell_type": "markdown", + "id": "08Cr0k2lcKJV", + "metadata": { + "id": "08Cr0k2lcKJV" + }, + "source": [ + "#### Visualization utilities\n", + "\n", + "Just run this block." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "sdWySSFCy88_", + "metadata": { + "id": "sdWySSFCy88_" + }, + "outputs": [], + "source": [ + "from sklearn.decomposition import PCA\n", + "from bokeh.plotting import figure, show\n", + "from bokeh.models import ColumnDataSource, Slider, CustomJS, Div, Spacer\n", + "from bokeh.layouts import column, row\n", + "from bokeh.io import output_notebook\n", + "from bokeh.palettes import Pastel1\n", + "\n", + "# visualize_training(model, train_outputs, prev_session_emb, prev_session_emb_labels, n_epochs, epoch_to_unfreeze):\n", + "def visualize_training(model, ckpt, train_outputs):\n", + " n_epochs = train_outputs[\"n_epochs\"]\n", + " epoch_to_unfreeze = train_outputs.get(\"epoch_to_unfreeze\", 0)\n", + " # Extract info from \"finetuned\" model\n", + " cur_session_emb_labels = list(model.session_emb.vocab.keys())[1:]\n", + " # Extract info from \"pretrained\" model\n", + " if ckpt is not None and ckpt != \"\":\n", + " ckpt_data = torch.load(ckpt, weights_only=False, map_location=\"cpu\")\n", + " prev_session_emb = ckpt_data['state_dict']['model.session_emb.weight'][1:].detach().clone()\n", + " prev_session_emb_labels = [str(x) for x in ckpt_data['state_dict']['model.session_emb.vocab']][1:]\n", + " else:\n", + " prev_session_emb = np.zeros((0, model.dim))\n", + " prev_session_emb_labels = []\n", + "\n", + " n_units = train_outputs['unit_emb'][0].shape[0]\n", + " n_sessions = train_outputs['session_emb'][0].shape[0]\n", + " n_prev_sessions = prev_session_emb.shape[0]\n", + "\n", + " pca = PCA(n_components=2)\n", + " flat_unit_emb = np.concatenate(train_outputs['unit_emb'], axis=0) # (n_epochs*n_units, emb_dim)\n", + " unit_emb_pca = pca.fit_transform(flat_unit_emb) # (n_epochs*n_units, 2)\n", + " unit_emb_x = unit_emb_pca[:, 0]\n", + " unit_emb_y = unit_emb_pca[:, 1]\n", + "\n", + " total_session_emb = np.concatenate([prev_session_emb, # (n_prev_sessions+n_epochs*n_sessions, emb_dim)\n", + " 
*train_outputs['session_emb']], axis=0)\n", + " session_emb_pca = pca.fit_transform(total_session_emb) # (n_prev_sessions+n_epochs*n_sessions, 2)\n", + " session_emb_x = session_emb_pca[:, 0]\n", + " session_emb_y = session_emb_pca[:, 1]\n", + "\n", + " sampling_rate = 100\n", + " sample_out_start = 2000\n", + " sample_out_end = sample_out_start + 500\n", + " pred_x = np.concatenate([output[sample_out_start:sample_out_end, 0] for output in train_outputs['output_pred']])\n", + " pred_y = np.concatenate([output[sample_out_start:sample_out_end, 1] for output in train_outputs['output_pred']])\n", + " gt_x = np.concatenate([output[sample_out_start:sample_out_end, 0] for output in train_outputs['output_gt']])\n", + " gt_y = np.concatenate([output[sample_out_start:sample_out_end, 1] for output in train_outputs['output_gt']])\n", + "\n", + " _Pastel1 = Pastel1.copy()\n", + " _Pastel1 = {**_Pastel1, **{i: _Pastel1[3][:i] for i in range(1, 3)}} # for some reason only starts at 3\n", + " all_session_emb_labels = prev_session_emb_labels + cur_session_emb_labels\n", + " extract_sess_label_grp = lambda label: label.split('/')[0] + '/' + label.split('/')[1].split('_')[0]\n", + " all_sess_emb_label_grps = [extract_sess_label_grp(label) for label in all_session_emb_labels] # map to <brainset/xx>_...\n", + " unique_sess_emb_label_grps = list(set(all_sess_emb_label_grps))\n", + " unique_sess_emb_label_grp_map = {label: i for i, label in enumerate(unique_sess_emb_label_grps)} # create index map for unique labels\n", + " all_sess_emb_label_grp_idx = [unique_sess_emb_label_grp_map[label] for label in all_sess_emb_label_grps] # map labels to index\n", + "\n", + " # Generate pastel colors based on unique labels\n", + " unique_labels = np.unique(all_sess_emb_label_grp_idx)\n", + " color_map = {label: _Pastel1[len(unique_labels)][i % len(_Pastel1[len(unique_labels)])] for i, label in enumerate(unique_labels)}\n", + "\n", + " # Assign colors based on group index\n", + " colors = 
np.array([color_map[label] for label in all_sess_emb_label_grp_idx])\n", + "\n", + " # Define border color: Only sessions after `n_prev_sessions` should have a border\n", + " border_colors = np.array(['black' if i >= n_prev_sessions else None for i in range(n_prev_sessions+n_sessions)])\n", + "\n", + " # Enable inline visualization in Jupyter Notebook\n", + " output_notebook()\n", + "\n", + " # Data sources\n", + " unit_source = ColumnDataSource(data={'x': unit_emb_x[:n_units], 'y': unit_emb_y[:n_units]})\n", + " sess_source = ColumnDataSource(data={\n", + " 'x': np.concatenate([session_emb_x[:n_prev_sessions],\n", + " session_emb_x[n_prev_sessions:n_prev_sessions+n_sessions]]),\n", + " 'y': np.concatenate([session_emb_y[:n_prev_sessions],\n", + " session_emb_y[n_prev_sessions:n_prev_sessions+n_sessions]]),\n", + " 'color': colors,\n", + " 'border_color': border_colors,\n", + " })\n", + "\n", + " time = np.linspace(0, (sample_out_end-sample_out_start)/sampling_rate, sample_out_end-sample_out_start)\n", + " pred_source_x = ColumnDataSource(data={'x': time, 'y': pred_x[:sample_out_end-sample_out_start]})\n", + " pred_source_y = ColumnDataSource(data={'x': time, 'y': pred_y[:sample_out_end-sample_out_start]})\n", + " gt_source_x = ColumnDataSource(data={'x': time, 'y': gt_x[:sample_out_end-sample_out_start]})\n", + " gt_source_y = ColumnDataSource(data={'x': time, 'y': gt_y[:sample_out_end-sample_out_start]})\n", + "\n", + " # Define plot ranges with buffer\n", + " buffer = 0.1\n", + " unit_x_min, unit_x_max = np.min(unit_emb_x), np.max(unit_emb_x)\n", + " unit_y_min, unit_y_max = np.min(unit_emb_y), np.max(unit_emb_y)\n", + " sess_x_min, sess_x_max = np.min(session_emb_x), np.max(session_emb_x)\n", + " sess_y_min, sess_y_max = np.min(session_emb_y), np.max(session_emb_y)\n", + " unit_buffer_x, unit_buffer_y = buffer * abs(unit_x_min - unit_x_max), buffer * abs(unit_y_min - unit_y_max)\n", + " unit_buffer_x, unit_buffer_y = 0.01 if unit_buffer_x == 0 else 
unit_buffer_x, 0.01 if unit_buffer_y == 0 else unit_buffer_y\n", + " sess_buffer_x, sess_buffer_y = buffer * abs(sess_x_min - sess_x_max), buffer * abs(sess_y_min - sess_y_max)\n", + " sess_buffer_x, sess_buffer_y = 0.01 if sess_buffer_x == 0 else sess_buffer_x, 0.01 if sess_buffer_y == 0 else sess_buffer_y\n", + " unit_x_range = (unit_x_min - unit_buffer_x, unit_x_max + unit_buffer_x)\n", + " unit_y_range = (unit_y_min - unit_buffer_y, unit_y_max + unit_buffer_y)\n", + " sess_x_range = (sess_x_min - sess_buffer_x, sess_x_max + sess_buffer_x)\n", + " sess_y_range = (sess_y_min - sess_buffer_y, sess_y_max + sess_buffer_y)\n", + " unit_plot = figure(title=\"Unit Embeddings\", x_axis_label=\"PC1\", y_axis_label=\"PC2\", width=300, height=300, background_fill_color=\"white\",\n", + " x_range=unit_x_range, y_range=unit_y_range)\n", + "\n", + " sess_plot = figure(title=\"Session Embeddings\", x_axis_label=\"PC1\", y_axis_label=\"PC2\", width=300, height=300, background_fill_color=\"white\",\n", + " x_range=sess_x_range, y_range=sess_y_range)\n", + "\n", + " # Unit embedding plot\n", + " unit_plot.scatter('x', 'y', source=unit_source, size=10, alpha=0.7, color='lightblue', line_color='black')\n", + "\n", + " # Session embedding plot\n", + " sess_plot.scatter('x', 'y', source=sess_source, size=10, alpha=0.7, color='color', line_color='border_color')\n", + "\n", + " # Hand velocity plots\n", + " vx_min, vx_max = np.min(np.concatenate([pred_x, gt_x])), np.max(np.concatenate([pred_x, gt_x]))\n", + " vx_range = (vx_min - buffer * abs(vx_min), vx_max + buffer * abs(vx_max))\n", + " vx_plot = figure(title=\"Hand Velocity - Vx\", x_axis_label=\"Time (s)\", y_axis_label=\"Velocity\", width=300, height=150, background_fill_color=\"white\"\n", + " , y_range=vx_range)\n", + " vx_plot.line('x', 'y', source=pred_source_x, color='red')\n", + " vx_plot.line('x', 'y', source=gt_source_x, color='black')\n", + "\n", + " vy_min, vy_max = np.min(np.concatenate([pred_y, gt_y])), 
np.max(np.concatenate([pred_y, gt_y]))\n", + " vy_range = (vy_min - buffer * abs(vy_min), vy_max + buffer * abs(vy_max))\n", + " vy_plot = figure(title=\"Hand Velocity - Vy\", x_axis_label=\"Time (s)\", y_axis_label=\"Velocity\", width=300, height=150, background_fill_color=\"white\"\n", + " , y_range=vy_range)\n", + " vy_plot.line('x', 'y', source=pred_source_y, color='blue')\n", + " vy_plot.line('x', 'y', source=gt_source_y, color='black')\n", + "\n", + " # Model diagram\n", + " if epoch_to_unfreeze == 0:\n", + " model_diagram = Div(text='<img src=\"https://media-hosting.imagekit.io//33c2439b9dc549dc/unit-id-unfrozen.svg?Expires=1837453466&Key-Pair-Id=K2ZIVPTIP2VGHC&Signature=i4i4qMNDJhiTEfC-Uo-DhSaxCUVsrN9W7m7mL4RDWPaRHrhsCXuIXsjPzTksoIQobZr9qMGVRNwTpDD-jGIT1-Y2K42H5uXY37WXNfBhHzcb-GsyYAjx9ztVuYi9OAeaiZdtXe-Yc-xQF88RWsBdSBw0KA26Ewwj2CcuBoexlSL3rNttoWMHVzyisTDFyX2N1uYmpnRbKFarnU0Xvn9OXx2y64fgZyT8oGgbYShjHZApDqEujfRubXwrvH86etyPvuzvbzqMxc27u-BgMQv0l--qsq2tP3y66AhQu~EuwbU9E4Dxud94bekO7 [...] + " else:\n", + " model_diagram = Div(text='<img src=\"https://media-hosting.imagekit.io//7bc43a35ec174ef4/unit-id-frozen.svg?Expires=1837453466&Key-Pair-Id=K2ZIVPTIP2VGHC&Signature=CylM3BrCyxTkiRd2DaTpqU~Q2gwDYBH8LjmL0Duv5ybgWMegyrYWfDVmu4weQx~bM7Dbd2BJ3Q-83WhSVRNRTBEzq3Dkr-RjuPdhbBAKhL-0Ku7aC4eo~EetHwa5PeuYGsQvs0jFaQF-XRv3Ow04kxww8m9gAgLTuDyL6-2FYf~68EXGrD0A6GqYeozU65~nwptnzF~YVu2gYsp5ERTOdbyE5TjNIKp21QBwhQ~BkbZ5NEfvITkdvbTv9j1k2p5hmxW~jm15Llz-oxj2l-l2fM~3UZ8JC6NHze~lTqmANrrxf0GIEYBAAwvDWj6 [...] 
+ "\n", + " # Widgets\n", + " slider = Slider(start=0, end=n_epochs-1, value=0, step=1, title=\"Epoch\")\n", + "\n", + " callback = CustomJS(args=dict(source1=unit_source, source2=sess_source, source3=pred_source_x, source4=pred_source_y,\n", + " source5=gt_source_x, source6=gt_source_y,\n", + " unit_x=unit_emb_x, unit_y=unit_emb_y, sess_x=session_emb_x, sess_y=session_emb_y, #cat=session_cat,\n", + " n_units=n_units, n_sessions=n_sessions, n_prev_sessions=n_prev_sessions,\n", + " diagram=model_diagram, unfreeze_epoch=epoch_to_unfreeze,\n", + " pred_x=pred_x, pred_y=pred_y, gt_x=gt_x, gt_y=gt_y,\n", + " samples=sample_out_end-sample_out_start,\n", + " ),#pred_x=pred_x, pred_y=pred_y),\n", + " code=\"\"\"\n", + " var step = cb_obj.value;\n", + " source1.data.x = unit_x.slice(step*n_units, (step+1)*n_units);\n", + " source1.data.y = unit_y.slice(step*n_units, (step+1)*n_units);\n", + "\n", + " let prev_sess_x = sess_x.slice(0, n_prev_sessions);\n", + " let cur_sess_x = sess_x.slice(n_prev_sessions+n_sessions*step, n_prev_sessions+n_sessions*(step+1));\n", + " source2.data.x = Array.from(prev_sess_x).concat(Array.from(cur_sess_x));\n", + " let prev_sess_y = sess_y.slice(0, n_prev_sessions);\n", + " let cur_sess_y = sess_y.slice(n_prev_sessions+n_sessions*step, n_prev_sessions+n_sessions*(step+1));\n", + " source2.data.y = Array.from(prev_sess_y).concat(Array.from(cur_sess_y));\n", + "\n", + " source3.data.y = pred_x.slice(step*samples, (step+1)*samples);\n", + " source4.data.y = pred_y.slice(step*samples, (step+1)*samples);\n", + " source5.data.y = gt_x.slice(step*samples, (step+1)*samples);\n", + " source6.data.y = gt_y.slice(step*samples, (step+1)*samples);\n", + "\n", + " source1.change.emit();\n", + " source2.change.emit();\n", + " source3.change.emit();\n", + " source4.change.emit();\n", + " source5.change.emit();\n", + " source6.change.emit();\n", + "\n", + " if (unfreeze_epoch >= 0 && step >= unfreeze_epoch) {\n", + " diagram.text = '<img 
src=\"https://media-hosting.imagekit.io//33c2439b9dc549dc/unit-id-unfrozen.svg?Expires=1837453466&Key-Pair-Id=K2ZIVPTIP2VGHC&Signature=i4i4qMNDJhiTEfC-Uo-DhSaxCUVsrN9W7m7mL4RDWPaRHrhsCXuIXsjPzTksoIQobZr9qMGVRNwTpDD-jGIT1-Y2K42H5uXY37WXNfBhHzcb-GsyYAjx9ztVuYi9OAeaiZdtXe-Yc-xQF88RWsBdSBw0KA26Ewwj2CcuBoexlSL3rNttoWMHVzyisTDFyX2N1uYmpnRbKFarnU0Xvn9OXx2y64fgZyT8oGgbYShjHZApDqEujfRubXwrvH86etyPvuzvbzqMxc27u-BgMQv0l--qsq2tP3y66AhQu~EuwbU9E4Dxud94bekO7ZE9a5T [...] + " } else {\n", + " diagram.text = '<img src=\"https://media-hosting.imagekit.io//7bc43a35ec174ef4/unit-id-frozen.svg?Expires=1837453466&Key-Pair-Id=K2ZIVPTIP2VGHC&Signature=CylM3BrCyxTkiRd2DaTpqU~Q2gwDYBH8LjmL0Duv5ybgWMegyrYWfDVmu4weQx~bM7Dbd2BJ3Q-83WhSVRNRTBEzq3Dkr-RjuPdhbBAKhL-0Ku7aC4eo~EetHwa5PeuYGsQvs0jFaQF-XRv3Ow04kxww8m9gAgLTuDyL6-2FYf~68EXGrD0A6GqYeozU65~nwptnzF~YVu2gYsp5ERTOdbyE5TjNIKp21QBwhQ~BkbZ5NEfvITkdvbTv9j1k2p5hmxW~jm15Llz-oxj2l-l2fM~3UZ8JC6NHze~lTqmANrrxf0GIEYBAAwvDWj6QcSmZX [...] + " }\n", + " \"\"\")\n", + "\n", + " slider.js_on_change('value', callback)\n", + "\n", + " # Improved Layout Structure\n", + " layout = column(\n", + " slider,\n", + " row(\n", + " column(unit_plot, sess_plot),\n", + " model_diagram,\n", + " column(Spacer(height=80), vx_plot, vy_plot)\n", + " ),\n", + " styles={\n", + " \"background-color\": \"white\",\n", + " \"padding\": \"20px\",\n", + " }\n", + " )\n", + "\n", + " show(layout, notebook_handle=True)" + ] + }, + { + "cell_type": "markdown", + "id": "3Hx_zKDKcTjn", + "metadata": { + "id": "3Hx_zKDKcTjn" + }, + "source": [ + "#### Visualize" + ] + }, + { + "cell_type": "markdown", + "id": "C9ayfgrSfx_E", + "metadata": { + "id": "C9ayfgrSfx_E" + }, + "source": [ + "We can visualize the training of POYO from scratch:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "rT_Fnw-sy88_", + "metadata": { + "id": "rT_Fnw-sy88_" + }, + "outputs": [], + "source": [ + "visualize_training(poyo_model, ckpt=None, train_outputs=poyo_train_outputs)" + ] + }, + { + 
"cell_type": "markdown", + "id": "tgCMJHHOfvjU", + "metadata": { + "id": "tgCMJHHOfvjU" + }, + "source": [ + "Or from POYO fine-tuning:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "FnDXtBOnfzJ2", + "metadata": { + "id": "FnDXtBOnfzJ2" + }, + "outputs": [], + "source": [ + "visualize_training(poyo_ft_model, ckpt=\"poyo_mp.ckpt\", train_outputs=poyo_ft_train_outputs)" + ] + }, + { + "cell_type": "markdown", + "id": "jSPsG3Jsy88_", + "metadata": { + "id": "jSPsG3Jsy88_" + }, + "source": [ + "***\n", + "\n", + "## Part 4. Your Turn!\n", + "\n", + "***\n", + "Now that you've seen how to train and fine-tune models on neural data, try experimenting on your own! Here are three exercises from us with varying levels of difficulty." + ] + }, + { + "cell_type": "markdown", + "id": "su7e6QfmeZ91", + "metadata": { + "id": "su7e6QfmeZ91" + }, + "source": [ + "### 4.1 Exercise (a): Easy\n", + "\n", + "In the transformer implementation in Part 2.2, change the positional embeddings from being fixed (i.e. non learnable) to being learnable, and train the transformer from scratch. 
Note that since we have limited training data, we won't observe an improvement in performance by making this switch.\n" + ] + }, + { + "cell_type": "markdown", + "id": "tV_IjhBFe-j7", + "metadata": { + "id": "tV_IjhBFe-j7" + }, + "source": [ + "#### Solution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "N7HKd5qngBwV", + "metadata": { + "id": "N7HKd5qngBwV" + }, + "outputs": [], + "source": [ + "from torch_brain.nn import FeedForward\n", + "\n", + "class TransformerNeuralDecoder(nn.Module):\n", + " def __init__(\n", + " self, num_units, bin_size, sequence_length, # data properties\n", + " dim_output, dim_hidden, n_layers, n_heads, # transformer properties\n", + " ):\n", + " \"\"\"Initialize the neural net components\"\"\"\n", + " super().__init__()\n", + "\n", + " self.num_timesteps = int(sequence_length / bin_size)\n", + " self.bin_size = bin_size\n", + "\n", + " self.readin = nn.Linear(num_units, dim_hidden)\n", + " self.readout = nn.Linear(dim_hidden, dim_output)\n", + "\n", + " self.position_embeddings = nn.Parameter(\n", + " data=generate_sinusoidal_position_embs(self.num_timesteps, dim_hidden),\n", + " requires_grad=True, ##### <<<<<< SOLUTION: CHANGED TO TRUE\n", + " )\n", + "\n", + " self.transformer_layers = nn.ModuleList([\n", + " nn.ModuleList([\n", + " nn.MultiheadAttention(\n", + " embed_dim=dim_hidden,\n", + " num_heads=n_heads,\n", + " batch_first=True,\n", + " ),\n", + " FeedForward(dim=dim_hidden),\n", + " ])\n", + " for _ in range(n_layers)\n", + " ])\n", + "\n", + " def forward(self, x):\n", + " # Remains unchanged\n", + "\n", + " x = self.readin(x) # (B, T, N) -> (B, T, D)\n", + " x = x + self.position_embeddings[None, ...] 
# -> (B, T, D)\n", + " for attn, ffn in self.transformer_layers:\n", + " x = x + attn(x, x, x, need_weights=False)[0]\n", + " x = x + ffn(x)\n", + " x = self.readout(x) # (B, T, D) -> (B, T, 2)\n", + "\n", + " return x\n", + "\n", + " def tokenize(self, data):\n", + " # Remains unchanged\n", + "\n", + " x = bin_spikes(\n", + " spikes=data.spikes,\n", + " num_units=len(data.units),\n", + " bin_size=self.bin_size,\n", + " num_bins=self.num_timesteps,\n", + " ).T\n", + " y = data.cursor.vel\n", + " data_dict = {\n", + " \"model_inputs\": {\n", + " \"x\": torch.tensor(x, dtype=torch.float32),\n", + " },\n", + " \"target_values\": torch.tensor(y, dtype=torch.float32),\n", + " }\n", + " return data_dict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "-M92mJFHiLfp", + "metadata": { + "id": "-M92mJFHiLfp" + }, + "outputs": [], + "source": [ + "seed_everything(0)\n", + "\n", + "# 1. Setup datasets and dataloader\n", + "recording_id = \"perich_miller_population_2018/t_20130819_center_out_reaching\"\n", + "train_dataset, train_loader, val_dataset, val_loader = get_train_val_loaders(recording_id, batch_size=64)\n", + "num_units = len(train_dataset.get_unit_ids())\n", + "\n", + "# 2. Initialize Model\n", + "tf_model = TransformerNeuralDecoder(\n", + " num_units=num_units,\n", + " bin_size=10e-3,\n", + " sequence_length=1.0,\n", + " dim_output=2,\n", + " dim_hidden=128,\n", + " n_layers=3,\n", + " n_heads=4,\n", + ").to(device)\n", + "\n", + "# 3. Connect Tokenizer to Datasets\n", + "train_dataset.transform = tf_model.tokenize\n", + "val_dataset.transform = tf_model.tokenize\n", + "\n", + "# 4. Setup Optimizer\n", + "optimizer = torch.optim.AdamW(tf_model.parameters(), lr=1e-3)\n", + "\n", + "# 5. 
Train!\n", + "transformer_r2_log, transformer_loss_log, transformer_train_outputs = train(tf_model, optimizer, train_loader, val_loader, num_epochs=100)\n", + "\n", + "# Plot the training loss and validation R2\n", + "plot_training_curves(transformer_r2_log, transformer_loss_log)" + ] + }, + { + "cell_type": "markdown", + "id": "AeJJIf4Wj06v", + "metadata": { + "id": "AeJJIf4Wj06v" + }, + "source": [ + "You might observe that this model does not perform as well as the version in Part 2.2 where the position embeddings were fixed. This can be attributed to overfitting, now that we can train a lot more parameters (the position embeddings!)." + ] + }, + { + "cell_type": "markdown", + "id": "gqWV0Xbry6tm", + "metadata": { + "id": "gqWV0Xbry6tm" + }, + "source": [ + "### 4.2 Exercise (b): Medium (Multi-Session Training!)\n", + "\n", + "Throughout this notebook, we have been training on a single session (`t_20130819_center_out_reaching`), which includes recordings from a single monkey during a reaching task. What if we want to train on multiple sessions? We noted in Part 2.3 that the configuration used to set up data for training POYO allows for loading multiple sessions, as well as some more complex configurations.\n", + "\n", + "In this exercise, we'll train on two separate sessions from the same monkey during the reaching task. The new session is called `t_20130821_center_out_reaching` and includes recordings from a later date than the first one.
Note that we might be interested in training across different animals or tasks as well, but we are limited by the model capacity.\n", + "\n", + "First, download the new session below:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "-i-u6bA2p5vm", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "executionInfo": { + "elapsed": 7065, + "status": "ok", + "timestamp": 1743098384978, + "user": { + "displayName": "Shivashriganesh Mahato", + "userId": "06240283312510304541" + }, + "user_tz": 240 + }, + "id": "-i-u6bA2p5vm", + "outputId": "eb7f7253-7542-4dc3-90bf-4d3c9d8457d5" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading...\n", + "From: https://drive.google.com/uc?id=1EZUkOE8oiieWja9lblokf8WjF05HFuX-\n", + "To: /content/data/perich_miller_population_2018/t_20130821_center_out_reaching.h5\n", + "100% 10.4M/10.4M [00:00<00:00, 51.1MB/s]\n" + ] + } + ], + "source": [ + "!gdown 1EZUkOE8oiieWja9lblokf8WjF05HFuX- -O data/perich_miller_population_2018/t_20130821_center_out_reaching.h5" + ] + }, + { + "cell_type": "markdown", + "id": "KpEZXyt_0sr0", + "metadata": { + "id": "KpEZXyt_0sr0" + }, + "source": [ + "Now, set up train and validation DataLoaders such that training will be across both sessions, but validation will stay only on the first session (so we can directly compare).
Then, train POYO on the two sessions for 50 epochs (half the number of epochs used in the original training).\n", + "\n", + "**Tips**:\n", + "\n", + "* `get_dataset_config` accepts either a single session as a string (as we've seen so far) or a list of sessions:\n", + "\n", + "```get_dataset_config(brainset, [session1, session2, ...])```\n", + "* You cannot use `get_train_val_loaders` directly, since it uses the same config for both train and val"
+ " collate_fn=collate,\n", + " num_workers=4,\n", + " pin_memory=True,\n", + ")\n", + "\n", + "train_dataset.disable_data_leakage_check()\n", + "val_dataset.disable_data_leakage_check()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ZpFEmw5TrHxW", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 443 + }, + "executionInfo": { + "elapsed": 764067, + "status": "ok", + "timestamp": 1743099192177, + "user": { + "displayName": "Shivashriganesh Mahato", + "userId": "06240283312510304541" + }, + "user_tz": 240 + }, + "id": "ZpFEmw5TrHxW", + "outputId": "c6938d53-f226-435e-af50-e72bbede892d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 100/100 | Val R2 = 0.770 | Loss = 0.004\n", + "Done! Final validation R2 = 0.781\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAABKUAAAGGCAYAAACqvTJ0AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAyW9JREFUeJzs3XlclNX+B/DPzDAM+yayiiJgIipioKhlehUBtVKzUtNUMr1p3Eoq067iWpiZkWVRlrmkuV2zX2kIYlQmivu+4b6xqTCswzDz/P5AxkZAQWdhhs/79eJ1eZ7nPOc558htHr6c8z0iQRAEEBERERERERERGZDY2A0gIiIiIiIiIqKmh0EpIiIiIiIiIiIyOAaliIiIiIiIiIjI4BiUIiIiIiIiIiIig2NQioiIiIiIiIiIDI5BKSIiIiIiIiIiMjgGpYiIiIiIiIiIyOAYlCIiIiIiIiIiIoNjUIqI [...] 
+ "text/plain": [ + "<Figure size 1200x400 with 2 Axes>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Training is the same as before\n", + "\n", + "seed_everything(0)\n", + "\n", + "from torch_brain.models import POYO\n", + "from torch_brain.registry import MODALITY_REGISTRY\n", + "poyo_model = POYO(\n", + "    sequence_length=1.0,\n", + "    readout_spec=MODALITY_REGISTRY['cursor_velocity_2d'],\n", + "    latent_step=1.0 / 8,\n", + "    num_latents_per_step=16,\n", + "    dim=64,\n", + "    depth=6,\n", + "    dim_head=64,\n", + "    cross_heads=2,\n", + "    self_heads=8,\n", + ").to(device)\n", + "\n", + "poyo_model.unit_emb.initialize_vocab(train_dataset.get_unit_ids())\n", + "poyo_model.session_emb.initialize_vocab(train_dataset.get_session_ids())\n", + "\n", + "train_dataset.transform = poyo_model.tokenize\n", + "val_dataset.transform = poyo_model.tokenize\n", + "\n", + "optimizer = torch.optim.AdamW(poyo_model.parameters(), lr=1e-3)\n", + "\n", + "poyo_r2_log, poyo_loss_log, poyo_train_outputs = train(\n", + "    poyo_model, optimizer, train_loader, val_loader,\n", + "    num_epochs=100, store_embs=True,\n", + ")\n", + "\n", + "plot_training_curves(poyo_r2_log, poyo_loss_log)" + ] + }, + { + "cell_type": "markdown", + "id": "h0zOuuou5Ft1", + "metadata": { + "id": "h0zOuuou5Ft1" + }, + "source": [ + "Once training is done, you can see that simply by training on another session, the validation performance on the original session improves slightly. From here, it's not much more work to train at the scale of POYO and POYO+ with this setup: just keep scaling up the dataset size (and model size as well), and performance will keep improving."
+ ] + }, + { + "cell_type": "markdown", + "id": "mXbkMNKhegyu", + "metadata": { + "id": "mXbkMNKhegyu" + }, + "source": [ + "### 4.3 Exercise (c): Medium\n", + "\n", + "In our MLP model in Part 2.1, we had intentionally set `bin_size` equal to the sampling period of the behavior output (10ms) to simplify the code. Now that you are more accustomed to the framework, make `bin_size` independent of the `output_sampling_period`. The signature of the MLP model should be:\n", + "\n", + "```python\n", + "class MLPNeuralDecoder(nn.Module):\n", + " def __init__(\n", + " self, num_units, bin_size, sequence_length,\n", + " output_sampling_period, # NEW ARGUMENT!\n", + " output_dim, hidden_dim,\n", + " ):\n", + " ...\n", + "```\n", + "\n", + "Once you're done modifying the model, train it with `bin_size=20e-3`, and `output_sampling_period=10e-3`." + ] + }, + { + "cell_type": "markdown", + "id": "raNpOx91gBwU", + "metadata": { + "id": "raNpOx91gBwU" + }, + "source": [ + "#### Solution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "MEEvUEu-iunF", + "metadata": { + "id": "MEEvUEu-iunF" + }, + "outputs": [], + "source": [ + "import torch.nn as nn\n", + "\n", + "class MLPNeuralDecoder(nn.Module):\n", + " def __init__(self, num_units, bin_size, sequence_length, output_sampling_period, output_dim, hidden_dim):\n", + " \"\"\"Initialize the neural net layers.\"\"\"\n", + " super().__init__()\n", + "\n", + " self.num_input_timesteps = int(sequence_length / bin_size) # CHANGE\n", + " self.num_output_timesteps = int(sequence_length / output_sampling_period) # CHANGE\n", + " self.bin_size = bin_size\n", + "\n", + " self.net = nn.Sequential(\n", + " nn.Linear(self.num_input_timesteps * num_units, hidden_dim), # CHANGE\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_dim, hidden_dim),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_dim, output_dim * self.num_output_timesteps), # CHANGE\n", + " )\n", + "\n", + " def forward(self, x):\n", + " x = x.flatten(1)\n", + " x = 
self.net(x)\n", + " x = x.reshape(-1, self.num_output_timesteps, 2) # CHANGE\n", + " return x\n", + "\n", + " def tokenize(self, data):\n", + " spikes = data.spikes\n", + " x = bin_spikes(\n", + " spikes=spikes,\n", + " num_units=len(data.units),\n", + " bin_size=self.bin_size,\n", + " num_bins=self.num_input_timesteps # CHANGE\n", + " ).T\n", + " y = data.cursor.vel\n", + " data_dict = {\n", + " \"model_inputs\": {\n", + " \"x\": torch.tensor(x, dtype=torch.float32),\n", + " },\n", + " \"target_values\": torch.tensor(y, dtype=torch.float32),\n", + " }\n", + " return data_dict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "x1LuW5W7kF7w", + "metadata": { + "id": "x1LuW5W7kF7w" + }, + "outputs": [], + "source": [ + "seed_everything(0)\n", + "\n", + "# 1. Setup datasets and dataloader\n", + "recording_id = \"perich_miller_population_2018/t_20130819_center_out_reaching\"\n", + "train_dataset, train_loader, val_dataset, val_loader = get_train_val_loaders(recording_id, batch_size=64)\n", + "num_units = len(train_dataset.get_unit_ids())\n", + "print(f\"Num Units in Session: {num_units}\")\n", + "\n", + "# 2. Initialize Model with the new MLP definition\n", + "mlp_model = MLPNeuralDecoder(\n", + " num_units=num_units,\n", + " bin_size=20e-3, # CHANGE\n", + " sequence_length=1.0,\n", + " output_sampling_period=10e-3, # CHANGE\n", + " output_dim=2,\n", + " hidden_dim=32,\n", + ")\n", + "mlp_model = mlp_model.to(device)\n", + "\n", + "# 3. Connect Tokenizer to Datasets\n", + "transform = mlp_model.tokenize\n", + "train_dataset.transform = transform\n", + "val_dataset.transform = transform\n", + "\n", + "# 4. Setup Optimizer\n", + "optimizer = torch.optim.AdamW(mlp_model.parameters(), lr=1e-3)\n", + "\n", + "# 5. 
Train!\n", + "mlp_r2_log, mlp_loss_log, mlp_train_outputs = train(mlp_model, optimizer, train_loader, val_loader, num_epochs=100)\n", + "\n", + "# Plot the training loss and validation R2\n", + "plot_training_curves(mlp_r2_log, mlp_loss_log)" + ] + }, + { + "cell_type": "markdown", + "id": "ixe6TofbehAk", + "metadata": { + "id": "ixe6TofbehAk" + }, + "source": [ + "### 4.4 Exercise (d): Challenging\n", + "\n", + "So far in all our examples, we've been using the full set of recorded units for decoding. It turns out that the session includes recordings of spiking activity from neurons in both the monkey's premotor cortex (PMd) and primary motor cortex (M1). What if we restrict our model to decode only from neurons in the PMd?\n", + "\n", + "For this exercise, note that the brain regions of the units can be found directly in the unit IDs (refer to training code from any of the examples to see how to access these IDs). Given this, modify the Transformer code from Part 2.2 as follows:\n", + "\n", + "(i) First, explore the unit IDs to figure out how to extract and filter on the brain region.\n", + "\n", + "(ii) Write a new tokenizer that, given a batch of samples from both brain regions, creates tokens only from the PMd neurons.\n", + "\n", + "(iii) Instantiate a Transformer model whose number of units is restricted to the PMd neurons only.\n", + "\n", + "(iv) Train the Transformer model using your new tokenizer, and compare results to training on all neurons.\n", + "\n", + "**Question**: We are reducing the number of units the model is trained on, but in the process restricting training to data from a single brain region. Given this, should we expect model performance to increase or decrease?\n", + "\n", + "**Tip**: save `r2_log` and `loss_log` under different names than before (e.g., `transformer_r2_log`) so you retain both for comparison."
+ ] + }, + { + "cell_type": "markdown", + "id": "JLNQmHTTgCcG", + "metadata": { + "id": "JLNQmHTTgCcG" + }, + "source": [ + "#### Solution" + ] + }, + { + "cell_type": "markdown", + "id": "p8kAOMoPtnxQ", + "metadata": { + "id": "p8kAOMoPtnxQ" + }, + "source": [ + "##### (i) Solution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "lYlQWPtegCcH", + "metadata": { + "id": "lYlQWPtegCcH" + }, + "outputs": [], + "source": [ + "# Setup data loaders\n", + "seed_everything(0)\n", + "recording_id = \"perich_miller_population_2018/t_20130819_center_out_reaching\"\n", + "train_dataset, train_loader, val_dataset, val_loader = get_train_val_loaders(recording_id, batch_size=64)\n", + "\n", + "# Explore the unit ids\n", + "unit_ids = train_dataset.get_unit_ids()\n", + "for unit_id in unit_ids[:5] + unit_ids[-5:]:\n", + " print(unit_id)" + ] + }, + { + "cell_type": "markdown", + "id": "KFavck4ytsIk", + "metadata": { + "id": "KFavck4ytsIk" + }, + "source": [ + "We can see that the unit IDs directly contain the brain region (between M1 and PMd). 
Hence filtering is fairly straightforward:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "-xqxJyZqtqZW", + "metadata": { + "id": "-xqxJyZqtqZW" + }, + "outputs": [], + "source": [ + "# Filter on PMd neurons\n", + "pmd_units = [unit_id for unit_id in unit_ids if 'PMd' in unit_id]\n", + "for unit_id in pmd_units[:5] + pmd_units[-5:]:\n", + " print(unit_id)" + ] + }, + { + "cell_type": "markdown", + "id": "3jS_UQspt1tD", + "metadata": { + "id": "3jS_UQspt1tD" + }, + "source": [ + "##### (ii) Solution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6wPD2K74t246", + "metadata": { + "id": "6wPD2K74t246" + }, + "outputs": [], + "source": [ + "# We adapt the Transformer implementation from Part 2.2\n", + "class TransformerNeuralDecoder(nn.Module):\n", + " def __init__(\n", + " self, num_units, bin_size, sequence_length,\n", + " dim_output, dim_hidden, n_layers, n_heads,\n", + " ):\n", + " # NO CHANGE TO THIS METHOD\n", + " super().__init__()\n", + "\n", + " self.num_timesteps = int(sequence_length / bin_size)\n", + " self.bin_size = bin_size\n", + "\n", + " self.readin = nn.Linear(num_units, dim_hidden)\n", + " self.readout = nn.Linear(dim_hidden, dim_output)\n", + "\n", + " self.position_embeddings = nn.Parameter(\n", + " data=generate_sinusoidal_position_embs(self.num_timesteps, dim_hidden),\n", + " requires_grad=False,\n", + " )\n", + "\n", + " self.transformer_layers = nn.ModuleList([\n", + " nn.ModuleList([\n", + " nn.MultiheadAttention(\n", + " embed_dim=dim_hidden,\n", + " num_heads=n_heads,\n", + " batch_first=True,\n", + " ),\n", + " FeedForward(dim=dim_hidden),\n", + " ])\n", + " for _ in range(n_layers)\n", + " ])\n", + "\n", + " def forward(self, x):\n", + " # NO CHANGE TO THIS METHOD\n", + " x = self.readin(x)\n", + " x = x + self.position_embeddings[None, ...]\n", + " for attn, ffn in self.transformer_layers:\n", + " x = x + attn(x, x, x, need_weights=False)[0]\n", + " x = x + ffn(x)\n", + " x = 
self.readout(x)\n", + " return x\n", + "\n", + " def tokenize(self, data):\n", + " unit_mask = np.array(['PMd' in unit_id for unit_id in unit_ids]) # <<<<<< SOLUTION: ADDED\n", + "\n", + " x = bin_spikes(\n", + " spikes=data.spikes,\n", + " num_units=len(data.units),\n", + " bin_size=self.bin_size,\n", + " num_bins=self.num_timesteps,\n", + " ).T\n", + " x = x[:, unit_mask] # <<<<<< SOLUTION: ADDED\n", + "\n", + " y = data.cursor.vel\n", + "\n", + " data_dict = {\n", + " \"model_inputs\": {\n", + " \"x\": torch.tensor(x, dtype=torch.float32),\n", + " },\n", + " \"target_values\": torch.tensor(y, dtype=torch.float32),\n", + " }\n", + " return data_dict" + ] + }, + { + "cell_type": "markdown", + "id": "EkAdwarVt6xf", + "metadata": { + "id": "EkAdwarVt6xf" + }, + "source": [ + "##### (iii) Solution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5WiHIzjft9TJ", + "metadata": { + "id": "5WiHIzjft9TJ" + }, + "outputs": [], + "source": [ + "# Initialize model for PMd neurons only\n", + "num_units = len(pmd_units)\n", + "pmd_tf_model = TransformerNeuralDecoder(\n", + " num_units=num_units,\n", + " bin_size=10e-3,\n", + " sequence_length=1.0,\n", + " dim_output=2,\n", + " dim_hidden=128,\n", + " n_layers=3,\n", + " n_heads=4,\n", + ").to(device)" + ] + }, + { + "cell_type": "markdown", + "id": "xOfbQ9vSt-of", + "metadata": { + "id": "xOfbQ9vSt-of" + }, + "source": [ + "##### (iv) Solution" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "y8mJq4Hmt_eh", + "metadata": { + "id": "y8mJq4Hmt_eh" + }, + "outputs": [], + "source": [ + "# Connect new tokenizer to Datasets\n", + "train_dataset.transform = pmd_tf_model.tokenize\n", + "val_dataset.transform = pmd_tf_model.tokenize\n", + "\n", + "# Let's train!\n", + "optimizer = torch.optim.AdamW(pmd_tf_model.parameters(), lr=1e-3)\n", + "pmd_transformer_r2_log, pmd_transformer_loss_log, _ = train(pmd_tf_model, optimizer, train_loader, val_loader, num_epochs=100)\n", + 
"plot_training_curves(pmd_transformer_r2_log, pmd_transformer_loss_log)" + ] + }, + { + "cell_type": "markdown", + "id": "psz1G5Q0wDSw", + "metadata": { + "id": "psz1G5Q0wDSw" + }, + "source": [ + "### A note on composing transforms\n", + "\n", + "The solution to Exercise (d) was possible because we were able to edit the tokenizer of the model directly. Often, though, editing the tokenizer is infeasible or inconvenient. In such situations, **torch_brain**'s ability to **compose** transforms will come in handy.\n", + "\n", + "For instance, if we wanted to train POYO on only PMd neurons, a clean approach would be:\n", + "1. Define a transform that removes all M1 neurons in a data sample and keeps only the PMd neurons.\n", + "2. Define `dataset.transform = Compose([drop_M1_neurons_transform, model.tokenize])`\n", + "\n", + "We refer curious readers to the documentation of **Compose**: [link](https://torch-brain.readthedocs.io/en/v0.1.0/package/transforms.html#torch_brain.transforms.Compose).\n", + "\n", + "This page also showcases existing transforms that are already implemented in **torch_brain**, like UnitDropout, some of which may come in handy for you.\n" + ] + }, + { + "cell_type": "markdown", + "id": "XSQNyUF_y89C", + "metadata": { + "id": "XSQNyUF_y89C" + }, + "source": [ + "***\n", + "\n", + "## References\n", + "\n", + "[1] [Perich, M. G., Gallego, J. A., & Miller, L. E. (2018). A neural population mechanism for rapid learning. Neuron, 100(4), 964-976.](https://pubmed.ncbi.nlm.nih.gov/30344047/)\n", + "\n", + "[2] [Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). Attention is all you need. Advances in Neural Information Processing Systems, 30.](https://papers.nips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)\n", + "\n", + "[3] [Azabou, M., Arora, V., Ganesh, V., Mao, X., Nachimuthu, S., Mendelson, M., ... & Dyer, E. (2023).
A unified, scalable framework for neural population decoding. Advances in Neural Information Processing Systems, 36, 44937-44956.](https://proceedings.neurips.cc/paper_files/paper/2023/file/8ca113d122584f12a6727341aaf58887-Paper-Conference.pdf)\n", + "\n", + "[4] [Azabou, M., Pan, K. X., Arora, V., Knight, I. J., Dyer, E. L., & Richards, B. A. (2025). Multi-session, multi-task neural decoding from distinct cell-types and brain regions. International Conference on Learning Representations, 13.](https://openreview.net/pdf?id=IuU0wcO0mo)\n", + "\n", + "[5] [Ye, J., Collinger, J., Wehbe, L., & Gaunt, R. (2023). Neural data transformer 2: multi-context pretraining for neural spiking activity. Advances in Neural Information Processing Systems, 36, 80352-80374.](https://papers.neurips.cc/paper_files/paper/2023/file/fe51de4e7baf52e743b679e3bdba7905-Paper-Conference.pdf)\n", + "\n", + "[6] [Zhang, Y., Wang, Y., Jiménez-Benetó, D., Wang, Z., Azabou, M., Richards, B., ... & Hurwitz, C. (2024). Towards a \"universal translator\" for neural dynamics at single-cell, single-spike resolution.
Advances in Neural Information Processing Systems, 37, 80495-80521.](https://proceedings.neurips.cc/paper_files/paper/2024/file/934eb45b99eff8f16b5cb8e4d3cb5641-Paper-Conference.pdf)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [ + "BycghpnsEKBg", + "PzWxTVE-bMOy", + "08Cr0k2lcKJV", + "tV_IjhBFe-j7", + "PbZisWiX3wO2", + "raNpOx91gBwU", + "JLNQmHTTgCcG", + "p8kAOMoPtnxQ", + "3jS_UQspt1tD", + "EkAdwarVt6xf", + "xOfbQ9vSt-of" + ], + "gpuType": "T4", + "provenance": [ + { + "file_id": "1WDLqGcFm0cNmoggIatTo2Vp2gncc9kF5", + "timestamp": 1742891256057 + } + ], + "toc_visible": true + }, + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python", + "notebook_metadata_filter": "-kernelspec,-jupytext.text_representation.jupytext_version" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/modules/agent-framework/deployments/jupyterhub/data/cosyne/cybershuttle.yml b/modules/agent-framework/deployments/jupyterhub/data/cosyne/cybershuttle.yml new file mode 100644 index 0000000000..1949e8af1d --- /dev/null +++ b/modules/agent-framework/deployments/jupyterhub/data/cosyne/cybershuttle.yml @@ -0,0 +1,64 @@ +project: + name: Foundations of Transformers in Neuroscience + version: 1.0.0 + description: This project demonstrates how to prepare datasets using temporaldata, utilize torch_brain samplers, and build neural decoding models (MLP, Transformer, POYO). You’ll also fine-tune a pretrained POYO model and visualize training, all in an interactive, hands-on notebook. 
+ tags: + - neuroscience + - transformers + homepage: https://github.com/yasithdev/allen-v1 + +workspace: + location: /workspace + resources: + min_cpu: 4 + min_gpu: 0 + min_mem: 4096 + walltime: 60 + cluster: Anvil + group: Cerebrum + queue: shared + model_collection: + - source: cybershuttle + identifier: lgn_stimulus + mount_point: /data/lgn_stimulus + - source: cybershuttle + identifier: v1_point + mount_point: /data/v1_point + data_collection: + - source: cybershuttle + identifier: lgn_stimulus_configs + mount_point: /data/lgn_stimulus_configs + - source: cybershuttle + identifier: v1_point_configs + mount_point: /data/v1_point_configs + +additional_dependencies: + conda: + - python=3.10 + - pip + - ipywidgets + - numba + - numpy=1.23.5 + - matplotlib + - openpyxl + - pandas + - pyqtgraph + - pyyaml + - requests + - scipy + - sqlalchemy + - tqdm + - nest-simulator + - ipytree + - python-jsonpath + - pydantic=2.7 + - anndata + - parse + pip: + - allensdk + - bmtk + - pytree + - git+https://github.com/alleninstitute/abc_atlas_access + - git+https://github.com/alleninstitute/neuroanalysis + - git+https://github.com/alleninstitute/aisynphys + - git+https://github.com/lahirujayathilake/mousev1 diff --git a/modules/agent-framework/deployments/jupyterhub/data/gkeyll/plotE_z.ipynb b/modules/agent-framework/deployments/jupyterhub/data/gkeyll/plotE_z.ipynb index ff7cc617b0..6bb46a321c 100644 --- a/modules/agent-framework/deployments/jupyterhub/data/gkeyll/plotE_z.ipynb +++ b/modules/agent-framework/deployments/jupyterhub/data/gkeyll/plotE_z.ipynb @@ -20,7 +20,7 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install --force-reinstall airavata-jupyter-magic\n", + "# %pip install --force-reinstall airavata-python-sdk[notebook]\n", "import airavata_jupyter_magic\n", "\n", "%authenticate\n", @@ -142,7 +142,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -156,7 
+156,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.6" + "version": "3.10.16" } }, "nbformat": 4,
