This is an automated email from the ASF dual-hosted git repository.

shunping pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 9a15261c00d Adding Timesfm anomaly detection colab example (#35844)
9a15261c00d is described below

commit 9a15261c00d44d187460e100e1fdf566960e4235
Author: Ashok Devireddy <[email protected]>
AuthorDate: Thu Aug 14 20:34:35 2025 -0400

    Adding Timesfm anomaly detection colab example (#35844)
    
    * Created using Colab
    
    * add license
    
    * added dataset to gcloud
    
    * cleared outputs
    
    * removed commented code
    
    * Delete examples/notebooks/beam-ml/anomaly_detection /anomaly_detection_timesfm.ipynb
    
    * removed commented code
    
    * changed to beam 2.67.0 and edited file paths
    
    * uncomment plot data line
---
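The window bookkeeping in the notebook's `OrderedSlidingWindowFn` and `FillGapsFn` can be sketched without Beam. This is plain Python and a sketch under stated assumptions, not the notebook's code. The helper names `first_window_bounds`, `next_window_bounds` and `fill_gaps` are hypothetical. The arithmetic mirrors the notebook:

* The first window starts at the event timestamp rounded down to a multiple of `slide_interval`.
* The window ends `window_size` seconds after it starts.
* Each later window shifts both bounds by `slide_interval`.

One deliberate difference: missing ticks are filled with a float NaN here, whereas the notebook's `FillGapsFn` appends the string `'NaN'`.

```python
import math
from typing import Dict, List, Tuple


def first_window_bounds(ts_secs: float, window_size: int,
                        slide_interval: int) -> Tuple[int, int]:
    """Return (start, end) of the first window for an event at ts_secs.

    Mirrors OrderedSlidingWindowFn.process: the start is the timestamp
    rounded down to a multiple of slide_interval, and the timer fires at
    start + window_size.
    """
    start = int(ts_secs // slide_interval) * slide_interval
    return start, start + window_size


def next_window_bounds(start: int, end: int,
                       slide_interval: int) -> Tuple[int, int]:
    """Shift both bounds by one slide, as on_timer does before re-arming."""
    return start + slide_interval, end + slide_interval


def fill_gaps(window_start: float, window_end: float,
              received: Dict[float, float],
              expected_interval: float) -> List[float]:
    """Emit one value per expected tick in [window_start, window_end).

    Ticks with no received element become float NaN (the notebook uses
    the string 'NaN' instead).
    """
    # Round keys the same way FillGapsFn does so lookups tolerate float noise.
    lookup = {round(ts, 5): val for ts, val in received.items()}
    filled = []
    t = window_start
    while t < window_end:
        filled.append(float(lookup.get(round(t, 5), math.nan)))
        t += expected_interval
    return filled


if __name__ == "__main__":
    start, end = first_window_bounds(1000.5, window_size=512, slide_interval=128)
    print(start, end)                              # 896 1408
    print(next_window_bounds(start, end, 128))     # (1024, 1536)
    print(fill_gaps(0.0, 5.0, {0.0: 1.0, 1.0: 2.0, 3.0: 4.0}, 1.0))
    # [1.0, 2.0, nan, 4.0, nan]
```

Rounding down to a slide boundary means every key's windows line up on the same grid. This is what lets `on_timer` compute the next window purely from the fired timestamp.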
 .../anomaly_detection_timesfm.ipynb                | 2712 ++++++++++++++++++++
 1 file changed, 2712 insertions(+)
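The outlier rule inside `DynamicTimesFmModelHandler.run_inference` can be checked in isolation. This is a plain-Python restatement that assumes per-step quantile lists named after the notebook's `q20`/`q30`/`q70`/`q80` variables:

* Q1 is the mean of the 20th and 30th percentile forecasts.
* Q3 is the mean of the 70th and 80th percentile forecasts.
* A point is an outlier when it falls outside `[Q1 - 1.5 * IQR, Q3 + 1.5 * IQR]`.
* The score is the distance past the violated bound, in units of IQR.

Several details differ from the notebook or are hypothetical:

* The function name `iqr_anomalies` is hypothetical.
* Plain averages stand in for the notebook's `np.nanmean`, so NaN quantiles are not skipped here.
* The zero-IQR guard is an addition; the notebook divides by the IQR directly.

```python
from typing import List, Sequence, Tuple


def iqr_anomalies(actual: Sequence[float],
                  q20: Sequence[float], q30: Sequence[float],
                  q70: Sequence[float], q80: Sequence[float],
                  k: float = 1.5) -> List[Tuple[int, float, float, float]]:
    """Return (index, score, lower_bound, upper_bound) for each outlier.

    Q1 and Q3 are approximated from the quantile forecast the same way the
    notebook does: Q1 = mean(q20, q30), Q3 = mean(q70, q80).
    """
    results = []
    for j, x in enumerate(actual):
        q1 = (q20[j] + q30[j]) / 2
        q3 = (q70[j] + q80[j]) / 2
        iqr = q3 - q1
        upper = q3 + k * iqr
        lower = q1 - k * iqr
        if x > upper or x < lower:
            excess = x - upper if x > upper else lower - x
            # Guard added for this sketch: a zero IQR would divide by zero.
            score = excess / iqr if iqr > 0 else float("inf")
            results.append((j, score, lower, upper))
    return results


print(iqr_anomalies(actual=[10.0, 25.0, 0.0],
                    q20=[9.0] * 3, q30=[10.0] * 3,
                    q70=[11.0] * 3, q80=[12.0] * 3))
# [(1, 5.25, 6.5, 14.5), (2, 3.25, 6.5, 14.5)]
```

Points inside the band produce no record. This matches the notebook, which only appends to `anomalies_found` when a bound is crossed.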

diff --git a/examples/notebooks/beam-ml/anomaly_detection/anomaly_detection_timesfm.ipynb b/examples/notebooks/beam-ml/anomaly_detection/anomaly_detection_timesfm.ipynb
new file mode 100644
index 00000000000..65a06fd5d8e
--- /dev/null
+++ b/examples/notebooks/beam-ml/anomaly_detection/anomaly_detection_timesfm.ipynb
@@ -0,0 +1,2712 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "provenance": [],
+      "gpuType": "T4"
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    },
+    "accelerator": "GPU"
+  },
+  "cells": [
+    {
+      "cell_type": "code",
+      "source": [
+        "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
+        "\n",
+        "# Licensed to the Apache Software Foundation (ASF) under one\n",
+        "# or more contributor license agreements. See the NOTICE file\n",
+        "# distributed with this work for additional information\n",
+        "# regarding copyright ownership. The ASF licenses this file\n",
+        "# to you under the Apache License, Version 2.0 (the\n",
+        "# \"License\"); you may not use this file except in compliance\n",
+        "# with the License. You may obtain a copy of the License at\n",
+        "#\n",
+        "#   http://www.apache.org/licenses/LICENSE-2.0\n",
+        "#\n",
+        "# Unless required by applicable law or agreed to in writing,\n",
+        "# software distributed under the License is distributed on an\n",
+        "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
+        "# KIND, either express or implied. See the License for the\n",
+        "# specific language governing permissions and limitations\n",
+        "# under the License."
+      ],
+      "metadata": {
+        "id": "eMMlVe_Gukos"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# TimesFM Anomaly Detection Pipeline Diagram\n",
+        "Time series data is a sequence of data points indexed by time, where each data point is recorded at a specific interval. TimesFM is a foundation model pretrained on a large corpus of time series data. Its architecture is a decoder-only transformer, similar to LLMs, which learns to predict the next part of a time series from previous data. We can use the following pipeline to detect anomalies in time series data and periodically learn from incoming data to improve our TimesFM predic [...]
+        "\n",
+        "![Untitled drawing.jpg](data:image/jpeg;base64,[...]
+      ],
+      "metadata": {
+        "id": "cAgnGkn3GFVb"
+      }
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "collapsed": true,
+        "id": "oCgmuQtdrSkG"
+      },
+      "outputs": [],
+      "source": [
+        "!pip install timesfm[torch]\n",
+        "!pip install 'apache_beam[gcp, test, interactive] == 2.67.0'\n",
+        "!pip install google-generativeai"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# Ordered Sliding Window![preprocessing.jpg](data:image/jpeg;base64,[...]
+      ],
+      "metadata": {
+        "id": "VgyZHICtuRMz"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import logging\n",
+        "\n",
+        "import apache_beam as beam\n",
+        "from apache_beam.coders import BooleanCoder\n",
+        "from apache_beam.coders import PickleCoder\n",
+        "from apache_beam.coders import TimestampCoder\n",
+        "from apache_beam.transforms.timeutil import TimeDomain\n",
+        "from apache_beam.transforms.userstate import OrderedListStateSpec\n",
+        "from apache_beam.transforms.userstate import ReadModifyWriteStateSpec\n",
+        "from apache_beam.transforms.userstate import TimerSpec\n",
+        "from apache_beam.transforms.userstate import on_timer\n",
+        "from apache_beam.utils.timestamp import MAX_TIMESTAMP\n",
+        "from apache_beam.utils.timestamp import Timestamp\n",
+        "\n",
+        "_LOGGER = logging.getLogger(__name__)\n",
+        "logging.basicConfig(level=logging.INFO)\n",
+        "_LOGGER.setLevel(logging.INFO)\n",
+        "\n",
+        "\n",
+        "class OrderedSlidingWindowFn(beam.DoFn):\n",
+        "\n",
+        "  ORDERED_BUFFER_STATE = OrderedListStateSpec('ordered_buffer', PickleCoder())\n",
+        "  WINDOW_TIMER = TimerSpec('window_timer', TimeDomain.WATERMARK)\n",
+        "  TIMER_STATE = ReadModifyWriteStateSpec('timer_state', BooleanCoder())\n",
+        "  EARLIEST_TS_STATE = ReadModifyWriteStateSpec('earliest_ts', TimestampCoder())\n",
+        "\n",
+        "  def __init__(self, window_size, slide_interval):\n",
+        "    self.window_size = window_size\n",
+        "    self.slide_interval = slide_interval\n",
+        "\n",
+        "  def start_bundle(self):\n",
+        "    _LOGGER.debug(\"start bundle\")\n",
+        "\n",
+        "  def finish_bundle(self):\n",
+        "    _LOGGER.debug(\"finish bundle\")\n",
+        "\n",
+        "  def process(\n",
+        "      self,\n",
+        "      element,\n",
+        "      timestamp=beam.DoFn.TimestampParam,\n",
+        "      ordered_buffer=beam.DoFn.StateParam(ORDERED_BUFFER_STATE),\n",
+        "      window_timer=beam.DoFn.TimerParam(WINDOW_TIMER),\n",
+        "      timer_state=beam.DoFn.StateParam(TIMER_STATE),\n",
+        "      earliest_ts_state=beam.DoFn.StateParam(EARLIEST_TS_STATE)):\n",
+        "\n",
+        "    _, value = element\n",
+        "    ordered_buffer.add((timestamp, value))\n",
+        "\n",
+        "    _LOGGER.debug(\"receive %s at %s\", element, timestamp)\n",
+        "    timer_started = timer_state.read()\n",
+        "\n",
+        "    earliest = earliest_ts_state.read()\n",
+        "    if not earliest or earliest > timestamp:\n",
+        "      earliest_ts_state.write(timestamp)\n",
+        "\n",
+        "    if not timer_started:\n",
+        "      earliest_ts_state.write(timestamp)\n",
+        "\n",
+        "      first_slide_start = int(\n",
+        "          timestamp.micros / 1e6 // self.slide_interval) * self.slide_interval\n",
+        "      first_slide_start_ts = Timestamp.of(first_slide_start)\n",
+        "\n",
+        "      first_window_end_ts = first_slide_start_ts + self.window_size\n",
+        "      _LOGGER.debug(\"set timer to %s\", first_window_end_ts)\n",
+        "      window_timer.set(first_window_end_ts)\n",
+        "\n",
+        "      timer_state.write(True)\n",
+        "\n",
+        "    return []\n",
+        "\n",
+        "  @on_timer(WINDOW_TIMER)\n",
+        "  def on_timer(\n",
+        "      self,\n",
+        "      key=beam.DoFn.KeyParam,\n",
+        "      fire_ts=beam.DoFn.TimestampParam,\n",
+        "      ordered_buffer=beam.DoFn.StateParam(ORDERED_BUFFER_STATE),\n",
+        "      window_timer=beam.DoFn.TimerParam(WINDOW_TIMER),\n",
+        "      timer_state=beam.DoFn.StateParam(TIMER_STATE),\n",
+        "      earliest_ts_state=beam.DoFn.StateParam(EARLIEST_TS_STATE)):\n",
+        "    _LOGGER.debug(\"timer fire at %s\", fire_ts)\n",
+        "    window_end_ts = fire_ts\n",
+        "    window_start_ts = window_end_ts - self.window_size\n",
+        "\n",
+        "    window_values = list(\n",
+        "        ordered_buffer.read_range(window_start_ts, window_end_ts))\n",
+        "\n",
+        "    _LOGGER.debug(\n",
+        "        \"window start: %s, window end: %s\", window_start_ts, window_end_ts)\n",
+        "    _LOGGER.debug(\"windowed data in buffer %s\", str(window_values))\n",
+        "    if window_values:\n",
+        "      yield (key, (window_start_ts, window_end_ts, window_values))\n",
+        "\n",
+        "    next_window_end_ts = fire_ts + self.slide_interval\n",
+        "    next_window_start_ts = window_start_ts + self.slide_interval\n",
+        "\n",
+        "    earliest_ts = earliest_ts_state.read()\n",
+        "    ordered_buffer.clear_range(earliest_ts, next_window_start_ts)\n",
+        "\n",
+        "    remaining_data = list(\n",
+        "        ordered_buffer.read_range(next_window_start_ts, MAX_TIMESTAMP))\n",
+        "\n",
+        "    if not remaining_data:\n",
+        "      timer_state.clear()\n",
+        "      earliest_ts_state.write(next_window_start_ts)\n",
+        "      return\n",
+        "\n",
+        "    _LOGGER.debug(\"set timer to %s\", next_window_end_ts)\n",
+        "    window_timer.set(next_window_end_ts)\n",
+        "\n",
+        "\n",
+        "class FillGapsFn(beam.DoFn):\n",
+        "  def __init__(self, expected_interval: float):\n",
+        "    \"\"\"\n",
+        "    Args:\n",
+        "      expected_interval: The expected time delta between elements, in seconds.\n",
+        "    \"\"\"\n",
+        "    self.expected_interval = expected_interval\n",
+        "\n",
+        "  def process(self, element):\n",
+        "    key, (window_start_ts, window_end_ts, window_elements) = element\n",
+        "\n",
+        "    received_data = {\n",
+        "        round(float(ts.micros / 1e6), 5): val\n",
+        "        for ts, val in window_elements\n",
+        "    }\n",
+        "\n",
+        "    start_sec = float(window_start_ts.micros / 1e6)\n",
+        "    end_sec = float(window_end_ts.micros / 1e6)\n",
+        "\n",
+        "    filled_values = []\n",
+        "    current_ts_sec = start_sec\n",
+        "\n",
+        "    while current_ts_sec < end_sec:\n",
+        "      lookup_ts = round(current_ts_sec, 5)\n",
+        "\n",
+        "      if lookup_ts in received_data:\n",
+        "        filled_values.append(float(received_data[lookup_ts]))\n",
+        "      else:\n",
+        "        filled_values.append('NaN')\n",
+        "\n",
+        "      current_ts_sec += self.expected_interval\n",
+        "\n",
+        "    yield (key, (window_start_ts, window_end_ts, filled_values))\n"
+      ],
+      "metadata": {
+        "id": "E1fHKPrkuLFW"
+      },
+      "execution_count": 2,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# Model Handler![detection.jpg](data:image/jpeg;base64,[...]
+      ],
+      "metadata": {
+        "id": "aP8LqLobuViH"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import apache_beam as beam\n",
+        "from apache_beam.ml.inference.base import ModelHandler\n",
+        "import timesfm\n",
+        "import logging\n",
+        "import numpy as np\n",
+        "import os\n",
+        "from google.cloud import storage\n",
+        "from apache_beam.io.gcp.gcsio import GcsIO\n",
+        "from apache_beam.utils.timestamp import Timestamp\n",
+        "\n",
+        "class LatestModelCheckpointLoader(beam.PTransform):\n",
+        "    \"\"\"A PTransform that finds the latest model checkpoint in a GCS path.\"\"\"\n",
+        "    def __init__(self, gcs_bucket, gcs_prefix):\n",
+        "        self.gcs_bucket = gcs_bucket\n",
+        "        self.gcs_prefix = gcs_prefix\n",
+        "\n",
+        "    def expand(self, pcoll):\n",
+        "        return pcoll | \"FindLatestModel\" >> beam.Map(self._find_latest_model_path)\n",
+        "\n",
+        "    def _find_latest_model_path(self, _):\n",
+        "        try:\n",
+        "            storage_client = storage.Client()\n",
+        "            blobs = storage_client.list_blobs(self.gcs_bucket, prefix=self.gcs_prefix)\n",
+        "            # Filter for model files and find the most recent one\n",
+        "            model_blobs = [b for b in blobs if b.name.endswith(\".pth\")]\n",
+        "            latest_blob = max(model_blobs, key=lambda b: b.time_created, default=None)\n",
+        "\n",
+        "            if latest_blob:\n",
+        "                path = f\"gs://{self.gcs_bucket}/{latest_blob.name}\"\n",
+        "                logging.info(f\"Found latest finetuned model at: {path}\")\n",
+        "                return path\n",
+        "        except Exception as e:\n",
+        "            logging.error(f\"Error finding latest model in GCS: {e}\")\n",
+        "\n",
+        "        # Return a path to the base model if no finetuned one exists or an error occurs\n",
+        "        base_model = \"google/timesfm-1.0-200m-pytorch\"\n",
+        "        logging.info(f\"No finetuned model found. Using base model: {base_model}\")\n",
+        "        return base_model\n",
+        "\n",
+        "class DynamicTimesFmModelHandler(ModelHandler[np.ndarray, np.ndarray, timesfm.TimesFm]):\n",
+        "    \"\"\"\n",
+        "    A model handler that loads a TimesFM model from a dynamic path (GCS or Hugging Face).\n",
+        "    The model path is provided as a side input to RunInference.\n",
+        "    \"\"\"\n",
+        "    def __init__(self, model_uri: str, hparams):\n",
+        "        self._hparams = hparams\n",
+        "        self._model = None\n",
+        "        self._model_uri = model_uri\n",
+        "        self._context_len = hparams.context_len\n",
+        "        self._horizon_len = hparams.horizon_len\n",
+        "\n",
+        "    def load_model(self) -> timesfm.TimesFm:\n",
+        "        \"\"\"Loads a model from the handler's current model_uri.\"\"\"\n",
+        "        logging.info(f\"Loading TimesFM model from path: {self._model_uri}...\")\n",
+        "\n",
+        "        checkpoint_config = {}\n",
+        "        if self._model_uri.startswith(\"gs://\"):\n",
+        "            try:\n",
+        "                gcs = GcsIO()\n",
+        "                file_name = os.path.basename(self._model_uri)\n",
+        "                local_path = f\"/tmp/{file_name}\"\n",
+        "                with gcs.open(self._model_uri, 'rb') as f_in, open(local_path, 'wb') as f_out:\n",
+        "                    f_out.write(f_in.read())\n",
+        "                checkpoint_config['path'] = local_path\n",
+        "                logging.info(f\"Downloaded model from GCS to {local_path}\")\n",
+        "            except Exception as e:\n",
+        "                logging.error(f\"Failed to download model from GCS: {e}. Check path and permissions.\")\n",
+        "                raise e  # Re-raise the exception to fail fast if the model can't be loaded.\n",
+        "        else:\n",
+        "            checkpoint_config['huggingface_repo_id'] = self._model_uri\n",
+        "\n",
+        "        self._model = timesfm.TimesFm(\n",
+        "            hparams=self._hparams,\n",
+        "            checkpoint=timesfm.TimesFmCheckpoint(**checkpoint_config)\n",
+        "        )\n",
+        "        logging.info(\"TimesFM model loaded successfully.\")\n",
+        "        return self._model\n",
+        "\n",
+        "    def update_model_path(self, model_path: str):\n",
+        "        \"\"\"\n",
+        "        This method is called by RunInference when new model metadata is available\n",
+        "        from the side input. It updates the model URI that `load_model` uses and\n",
+        "        reloads the model immediately.\n",
+        "        \"\"\"\n",
+        "        if not model_path:\n",
+        "            logging.info(\"Received an empty model path update. No action taken.\")\n",
+        "            return\n",
+        "        logging.info(f\"Received model update. New model URI: {model_path}\")\n",
+        "        self._model_uri = model_path\n",
+        "        self._model = self.load_model()\n",
+        "        logging.info(\"Model has been updated in the handler.\")\n",
+        "\n",
+        "    def run_inference(self, batch, model, inference_args=None):\n",
+        "        \"\"\"\n",
+        "        Runs TimesFM on one window and flags horizon points that fall\n",
+        "        outside IQR-based bounds derived from the quantile forecast.\n",
+        "\n",
+        "        Note: only the first element of the batch is processed, so this\n",
+        "        handler assumes a batch size of 1.\n",
+        "        \"\"\"\n",
+        "        anomalies_found = []\n",
+        "\n",
+        "        key, (window_start_ts, _, values_array) = batch[0]\n",
+        "\n",
+        "        # A window must have enough data for both context and horizon.\n",
+        "        # if len(values_array) < self._context_len + self._horizon_len:\n",
+        "        #     return\n",
+        "\n",
+        "        current_context = np.array(values_array[:self._context_len])\n",
+        "        actual_horizon_values = np.array(\n",
+        "            values_array[self._context_len:self._context_len + self._horizon_len])\n",
+        "\n",
+        "        print(\"Current context shape:\", current_context.shape)\n",
+        "        print(\"Actual horizon values shape:\", actual_horizon_values.shape)\n",
+        "        point_forecast, experimental_quantile_forecast = model.forecast(\n",
+        "            [current_context],\n",
+        "            freq=[0],\n",
+        "        )\n",
+        "\n",
+        "        current_predicted_horizon_values = point_forecast[\n",
+        "            0, :, 0] if point_forecast.ndim == 3 else point_forecast[0]\n",
+        "\n",
+        "        current_q20_values = experimental_quantile_forecast[0, :, 2]\n",
+        "        current_q30_values = experimental_quantile_forecast[0, :, 3]\n",
+        "        current_q70_values = experimental_quantile_forecast[0, :, 7]\n",
+        "        current_q80_values = experimental_quantile_forecast[0, :, 8]\n",
+        "\n",
+        "        for j in range(len(actual_horizon_values)):\n",
+        "            current_actual = actual_horizon_values[j]\n",
+        "\n",
+        "            point_Q1 = np.nanmean([current_q20_values[j], current_q30_values[j]])\n",
+        "            point_Q3 = np.nanmean([current_q70_values[j], current_q80_values[j]])\n",
+        "            point_IQR = point_Q3 - point_Q1\n",
+        "\n",
+        "            upper_thresh = point_Q3 + 1.5 * point_IQR\n",
+        "            lower_thresh = point_Q1 - 1.5 * point_IQR\n",
+        "\n",
+        "            if current_actual > upper_thresh or current_actual < lower_thresh:\n",
+        "                if current_actual > upper_thresh:\n",
+        "                    score = (current_actual - upper_thresh) / point_IQR\n",
+        "                else:\n",
+        "                    score = (lower_thresh - current_actual) / point_IQR\n",
+        "\n",
+        "                anomaly_timestamp_seconds = (window_start_ts.micros / 1e6) + (\n",
+        "                    self._context_len + j)\n",
+        "\n",
+        "                index_in_window = self._context_len + j\n",
+        "\n",
+        "                anomalies_found.append({\n",
+        "                    'key': key,\n",
+        "                    'timestamp': Timestamp(anomaly_timestamp_seconds),\n",
+        "                    'index_in_window': index_in_window,\n",
+        "                    'actual_value': current_actual,\n",
+        "                    'predicted_value': current_predicted_horizon_values[j],\n",
+        "                    'is_anomaly': True,\n",
+        "                    'outlier_score': score,\n",
+        "                    'lower_bound': lower_thresh,\n",
+        "                    'upper_bound': upper_thresh,\n",
+        "                })\n",
+        "        payload = {\n",
+        "            \"start_ts_micros\": window_start_ts.micros,\n",
+        "            \"predicted_values\": current_predicted_horizon_values.tolist(),\n",
+        "            \"q20_values\": current_q20_values.tolist(),\n",
+        "            \"q30_values\": current_q30_values.tolist(),\n",
+        "            \"q70_values\": current_q70_values.tolist(),\n",
+        "            \"q80_values\": current_q80_values.tolist(),\n",
+        "            \"anomalies\": anomalies_found,  # Per-point anomaly records for this horizon\n",
+        "            \"actual_horizon_values\": actual_horizon_values.tolist()\n",
+        "        }\n",
+        "        result_with_context = (batch[0], payload)\n",
+        "\n",
+        "        return [result_with_context]"
+      ],
+      "metadata": {
+        "id": "oT9NIaWcuUgb",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "a26eae71-067a-4314-b49a-3b028ee75903"
+      },
+      "execution_count": 3,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            " See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs.\n",
+            "Loaded PyTorch TimesFM, likely because python version is 3.11.13 (main, Jun  4 2025, 08:57:29) [GCC 11.4.0].\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# LLM Classifier![classification.jpg](data:image/jpeg;base64,[...]
+      ],
+      "metadata": {
+        "id": "CU9zuwUUu7tX"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import apache_beam as beam\n",
+        "import google.generativeai as genai\n",
+        "import logging\n",
+        "import os\n",
+        "import re\n",
+        "import json\n",
+        "import numpy as np\n",
+        "from apache_beam.utils.timestamp import Timestamp\n",
+        "from dotenv import load_dotenv\n",
+        "from apache_beam.coders.coders import PickleCoder\n",
+        "from apache_beam.transforms.userstate import BagStateSpec, ReadModifyWriteStateSpec, TimerSpec, on_timer\n",
+        "\n",
+        "\n",
+        "class CustomJsonEncoderForLLM(json.JSONEncoder):\n",
+        "    \"\"\"Encodes special types like Timestamp and numpy objects into JSON.\"\"\"\n",
+        "    def default(self, obj):\n",
+        "        if isinstance(obj, Timestamp):\n",
+        "            # Store as a dict with a special key for easy decoding\n",
+        "            return {'__timestamp__': True, 'micros': obj.micros}\n",
+        "        if isinstance(obj, np.integer):\n",
+        "            return int(obj)\n",
+        "        if isinstance(obj, np.floating):\n",
+        "            return float(obj)\n",
+        "        if isinstance(obj, np.ndarray):\n",
+        "            return obj.tolist()\n",
+        "        return super().default(obj)\n",
+        "\n",
+        "def custom_json_decoder(dct):\n",
+        "    \"\"\"Decodes a Timestamp object from our custom dict format.\"\"\"\n",
+        "    if '__timestamp__' in dct:\n",
+        "        return Timestamp(micros=dct['micros'])\n",
+        "    return dct\n",
+        "\n",
+        "class JsonCoderWithNumpyAndTimestamp(beam.coders.Coder):\n",
+        "    \"\"\"A custom Beam Coder that handles JSON serialization for Timestamps and numpy types.\"\"\"\n",
+        "    def encode(self, value):\n",
+        "        return json.dumps(value, cls=CustomJsonEncoderForLLM).encode('utf-8')\n",
+        "\n",
+        "    def decode(self, encoded):\n",
+        "        return json.loads(encoded.decode('utf-8'), object_hook=custom_json_decoder)\n",
+        "\n",
+        "    def is_deterministic(self):\n",
+        "        return True\n",
+        "\n",
+        "\n",
+        "# It's highly recommended to manage API keys via GCP Secret Manager\n",
+        "# and access them as environment variables in your Dataflow job.\n",
+        "# genai.configure(api_key=os.environ[\"GEMINI_API_KEY\"])\n",
+        "\n",
+        "class LLMClassifierFn(beam.DoFn):\n",
+        "    \"\"\"\n",
+        "    Takes an anomaly, formats a detailed prompt with surrounding context,\n",
+        "    calls the Gemini model to classify it, and routes the original data\n",
+        "    based on the model's decision.\n",
+        "\n",
+        "    This DoFn is stateful, deferring anomalies that occur too close to\n",
+        "    the end of a window until a subsequent window provides enough context.\n",
+        "    \"\"\"\n",
+        "\n",
+        "    DEFERRED_ANOMALIES_STATE = BagStateSpec(\n",
+        "        'deferred_anomalies', coder=JsonCoderWithNumpyAndTimestamp())\n",
+        "    YIELD_BUFFER_STATE = ReadModifyWriteStateSpec('yield_buffer', PickleCoder())\n",
+        "\n",
+        "    # Watermark timer that fires when no new window arrives within the grace period.\n",
+        "    EXPIRY_TIMER = TimerSpec('expiry', beam.TimeDomain.WATERMARK)\n",
+        "    # State that tracks the last yielded timestamp.\n",
+        "    LAST_YIELDED_TIMESTAMP_STATE = ReadModifyWriteStateSpec('last_yielded_ts', PickleCoder())\n",
+        "\n",
+        "    def __init__(self, secret, context_points=25, slide_interval=128, expected_interval_secs=1):\n",
+        "        self.context_points = context_points\n",
+        "        self._model = None\n",
+        "        self.secret = secret\n",
+        "        self.slide_interval = slide_interval\n",
+        "        self.expected_interval_micros = expected_interval_secs * 1_000_000\n",
+        "\n",
+        "        self._last_window_data = None\n",
+        "\n",
+        "    def setup(self):\n",
+        "        # Configure the generative model\n",
+        "        genai.configure(api_key=self.secret)\n",
+        "        logging.getLogger().setLevel(logging.INFO)\n",
+        "\n",
+        "        generation_config = {\n",
+        "            \"temperature\": 0.2,\n",
+        "            \"top_p\": 1,\n",
+        "            \"top_k\": 1,\n",
+        "            \"max_output_tokens\": 256,\n",
+        "            \"response_mime_type\": \"application/json\",\n",
+        "        }\n",
+        "        # For a full list of safety settings, see the Gemini API documentation\n",
+        "        safety_settings = [\n",
+        "            {\"category\": \"HARM_CATEGORY_HARASSMENT\", \"threshold\": \"BLOCK_NONE\"},\n",
+        "            {\"category\": \"HARM_CATEGORY_HATE_SPEECH\", \"threshold\": \"BLOCK_NONE\"},\n",
+        "        ]\n",
+        "        self._model = genai.GenerativeModel(\n",
+        "            model_name=\"gemini-1.5-flash-latest\",\n",
+        "            generation_config=generation_config,\n",
+        "            safety_settings=safety_settings\n",
+        "        )\n",
+        "        logging.info(\"Gemini Model has been successfully initialized.\")\n",
+        "\n",
+        "    def _build_prompt(self, anomaly_data, context_before, context_after):\n",
+        "        mean_before = np.mean(context_before) if context_before.size > 0 else 0\n",
+        "        mean_after = np.mean(context_after) if context_after.size > 0 else 0\n",
+        "        std_before = np.std(context_before) if context_before.size > 0 else 0\n",
+        "        std_after = np.std(context_after) if context_after.size > 0 else 0\n",
+        "\n",
+        "        return f\"\"\"\n",
+        "        You are an expert time-series analyst classifying an outlier 
from NYC taxi pickup data.\n",
+        "        Normal behavior includes daily and weekly cyclical 
patterns.\n",
+        "\n",
+        "        **1. Outlier Context:**\n",
+        "        * **--> The Outlier:**\n",
+        "            * **Timestamp:** 
{Timestamp(micros=anomaly_data['timestamp'].micros)}\n",
+        "            * **Actual Value:** {anomaly_data['actual_value']:.2f}\n",
+        "            * **Predicted Value:** 
{anomaly_data['predicted_value']:.2f}\n",
+        "            * **Anomaly Upper Bound:** 
{anomaly_data['upper_bound']:.2f}\n",
+        "            * **Anomaly Lower Bound:** 
{anomaly_data['lower_bound']:.2f}\n",
+        "\n",
+        "        **2. Data Surrounding the Outlier:**\n",
+        "        * **Data Before ({len(context_before)} points):** 
{np.round(context_before, 2).tolist()}\n",
+        "        * **Data After ({len(context_after)} points):** 
{np.round(context_after, 2).tolist()}\n",
+        "\n",
+        "        **3. Statistical Context:**\n",
+        "        * **Mean Before:** {mean_before:.2f}\n",
+        "        * **Mean After:** {mean_after:.2f}\n",
+        "        * **Std. Dev. Before:** {std_before:.2f}\n",
+        "        * **Std. Dev. After:** {std_after:.2f}\n",
+        "\n",
+        "        **4. Your Task:**\n",
+        "\n",
+        "        **Step 1: Analyze the Evidence.** In a few sentences, describe the behavior of the data *after* the outlier. Does it quickly revert to the \"Predicted Value\" or the \"Mean Before\"? Or does it establish a new level, closer to the \"Mean After\"?\n",
+        "\n",
+        "        **Step 2: Make a Decision.** Classify the outlier.\n",
+        "        * **REMOVE:** If it's a transient, one-off event. This is likely if the data after the outlier rapidly returns to the established pattern.\n",
+        "        * **KEEP:** If it signifies a sustained shift in the pattern that the model should learn from. This is likely if the `Mean After` has shifted significantly.\n",
+        "\n",
+        "        **Step 3: Provide Final Output.** Respond with a single JSON object. Do not add any text outside the JSON block.\n",
+        "\n",
+        "        {{\n",
+        "          \"reasoning_steps\": \"Your analysis from Step 1 goes here.\",\n",
+        "          \"decision\": \"KEEP or REMOVE\",\n",
+        "          \"confidence_score\": <A float between 0.0 and 1.0>\n",
+        "        }}\n",
+        "        \"\"\"\n",
+        "\n",
+        "    def process(self, element,\n",
+        "                deferred_anomalies=beam.DoFn.StateParam(DEFERRED_ANOMALIES_STATE),\n",
+        "                yield_buffer=beam.DoFn.StateParam(YIELD_BUFFER_STATE),\n",
+        "                expiry_timer=beam.DoFn.TimerParam(EXPIRY_TIMER)):\n",
+        "\n",
+        "        key, data = element\n",
+        "        window_start_ts = data['window_start_ts']\n",
+        "\n",
+        "        # Set a timer to fire based on the event time of the current element.\n",
+        "        # Each new element will push the timer forward. The timer will only\n",
+        "        # fire when a gap in the input stream occurs, allowing the buffer\n",
+        "        # to contain data from multiple consecutive sliding windows.\n",
+        "        # We set it far enough ahead to allow the next window's data to arrive.\n",
+        "        grace_period_secs = self.slide_interval * 2\n",
+        "        expiry_timer.set(window_start_ts + grace_period_secs)\n",
+        "        anomalies_in_window = data.get('anomalies', [])\n",
+        "        values_in_element = data.get('values_array', [])\n",
+        "\n",
+        "        for anomaly in anomalies_in_window:\n",
+        "            deferred_anomalies.add(anomaly)\n",
+        "\n",
+        "        buffer = yield_buffer.read() or {}\n",
+        "        for i, value in enumerate(values_in_element):\n",
+        "            point_timestamp = Timestamp(micros=window_start_ts.micros + (i * self.expected_interval_micros))\n",
+        "            buffer[point_timestamp] = value\n",
+        "        yield_buffer.write(buffer)\n",
+        "\n",
+        "    @on_timer(EXPIRY_TIMER)\n",
+        "    def on_expiry_timer(\n",
+        "        self,\n",
+        "        deferred_anomalies=beam.DoFn.StateParam(DEFERRED_ANOMALIES_STATE),\n",
+        "        yield_buffer=beam.DoFn.StateParam(YIELD_BUFFER_STATE),\n",
+        "        # Timestamp of the most recent point emitted by a previous firing.\n",
+        "        last_yielded_ts_state=beam.DoFn.StateParam(LAST_YIELDED_TIMESTAMP_STATE)):\n",
+        "\n",
+        "        all_anomalies_to_consider = list(deferred_anomalies.read())\n",
+        "        buffered_points_map = yield_buffer.read() or {}\n",
+        "\n",
+        "        if not buffered_points_map:\n",
+        "            return\n",
+        "\n",
+        "        sorted_points = sorted(buffered_points_map.items())\n",
+        "        all_timestamps = [ts for ts, val in sorted_points]\n",
+        "        all_values = [val for ts, val in sorted_points]\n",
+        "\n",
+        "        anomalies_to_process_now = []\n",
+        "        prompts_to_batch = []\n",
+        "        final_deferred = []\n",
+        "\n",
+        "        for anomaly_data in all_anomalies_to_consider:\n",
+        "            anomaly_ts = anomaly_data['timestamp']\n",
+        "            try:\n",
+        "                idx_in_full_data = all_timestamps.index(anomaly_ts)\n",
+        "\n",
+        "                if (idx_in_full_data + self.context_points) < len(all_values):\n",
+        "                    start_ctx = max(0, idx_in_full_data - self.context_points)\n",
+        "                    end_ctx = idx_in_full_data + self.context_points + 1\n",
+        "\n",
+        "                    context_before = np.array(all_values[start_ctx:idx_in_full_data])\n",
+        "                    context_after = np.array(all_values[idx_in_full_data + 1:end_ctx])\n",
+        "\n",
+        "                    anomaly_data['index_in_window'] = idx_in_full_data\n",
+        "                    prompt = self._build_prompt(anomaly_data, context_before, context_after)\n",
+        "                    prompts_to_batch.append(prompt)\n",
+        "                    anomalies_to_process_now.append(anomaly_data)\n",
+        "                else:\n",
+        "                    # Fewer than context_points values follow the anomaly; retry later.\n",
+        "                    final_deferred.append(anomaly_data)\n",
+        "            except ValueError:\n",
+        "                # The anomaly's timestamp is not in the buffer yet; retry later.\n",
+        "                final_deferred.append(anomaly_data)\n",
+        "\n",
+        "        if prompts_to_batch:\n",
+        "            try:\n",
+        "                logging.info(f\"Sending a batch of {len(prompts_to_batch)} prompts to the LLM.\")\n",
+        "                responses = self._model.generate_content(prompts_to_batch)\n",
+        "                for anomaly_data, response in zip(anomalies_to_process_now, responses):\n",
+        "                    try:\n",
+        "                        response_data = json.loads(response.text)\n",
+        "                        decision = response_data.get('decision', 'KEEP').strip().upper()\n",
+        "                        idx = anomaly_data['index_in_window']\n",
+        "\n",
+        "                        if decision == 'REMOVE':\n",
+        "                            logging.warning(f\"LLM decided to REMOVE anomaly at {anomaly_data['timestamp']}. Imputing value.\")\n",
+        "                            all_values[idx] = anomaly_data['predicted_value']\n",
+        "                    except (json.JSONDecodeError, AttributeError) as e:\n",
+        "                        logging.error(f\"Error processing LLM response for {anomaly_data['timestamp']}: {e}. Defaulting to KEEP.\")\n",
+        "            except Exception as e:\n",
+        "                logging.error(f\"Error calling LLM with a batch: {e}. Defaulting to KEEP for all.\")\n",
+        "\n",
+        "        # Yield only points newer than the last timestamp emitted by a previous firing.\n",
+        "        last_yielded_ts = last_yielded_ts_state.read()\n",
+        "        latest_ts_in_batch = None\n",
+        "\n",
+        "        for i, (ts, original_val) in enumerate(sorted_points):\n",
+        "            # Only yield points that are newer than the last batch we yielded\n",
+        "            if last_yielded_ts is None or ts > last_yielded_ts:\n",
+        "                yield {\n",
+        "                    'timestamp': ts,\n",
+        "                    'value': all_values[i]\n",
+        "                }\n",
+        "                latest_ts_in_batch = ts\n",
+        "\n",
+        "        # After yielding, update the state with the latest timestamp from this batch\n",
+        "        if latest_ts_in_batch is not None:\n",
+        "            last_yielded_ts_state.write(latest_ts_in_batch)\n",
+        "\n",
+        "        # Prune the buffer. We need to keep enough historical data to serve\n",
+        "        # as `context_before` for the anomalies that we are re-deferring.\n",
+        "        if latest_ts_in_batch is not None:\n",
+        "            all_buffered_points = yield_buffer.read() or {}\n",
+        "\n",
+        "            # Find the earliest timestamp we need to keep. A re-deferred anomaly\n",
+        "            # can sit within `context_points` positions of the last yielded point\n",
+        "            # and needs `context_points` values before it, so keep\n",
+        "            # 2 * `context_points` points of history.\n",
+        "            try:\n",
+        "                last_yielded_index = all_timestamps.index(latest_ts_in_batch)\n",
+        "                context_start_index = max(0, last_yielded_index - 2 * self.context_points)\n",
+        "                context_start_ts = all_timestamps[context_start_index]\n",
+        "\n",
+        "                pruned_buffer = {\n",
+        "                    ts: val\n",
+        "                    for ts, val in all_buffered_points.items()\n",
+        "                    if ts >= context_start_ts\n",
+        "                }\n",
+        "                yield_buffer.write(pruned_buffer)\n",
+        "            except ValueError:\n",
+        "                # This can happen if the buffer is in an inconsistent state.\n",
+        "                # As a fallback, we clear it if we aren't deferring anything.\n",
+        "                logging.warning(\n",
+        "                    f\"Could not find last yielded timestamp \"\n",
+        "                    f\"{latest_ts_in_batch} in buffer for pruning.\"\n",
+        "                )\n",
+        "                if not final_deferred:\n",
+        "                    yield_buffer.clear()\n",
+        "        elif not final_deferred:\n",
+        "            # If we didn't yield anything and we're not deferring anything,\n",
+        "            # the buffer is fully processed and can be cleared.\n",
+        "            yield_buffer.clear()\n",
+        "\n",
+        "        # Re-add anomalies that couldn't be processed to the state so they can\n",
+        "        # be considered in the next firing.\n",
+        "        deferred_anomalies.clear()\n",
+        "        if final_deferred:\n",
+        "            logging.info(f\"Re-deferring {len(final_deferred)} anomalies due to insufficient context.\")\n",
+        "            for anomaly in final_deferred:\n",
+        "                deferred_anomalies.add(anomaly)\n"
+      ],
+      "metadata": {
+        "id": "c55ou9f5vADf"
+      },
+      "execution_count": 4,
+      "outputs": []
+    },
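+    {
+      "cell_type": "markdown",
+      "source": [
+        "The cell below is a standalone sketch, not part of the pipeline, of the context check in `on_expiry_timer`: an anomaly is sent to the LLM only when at least `context_points` values follow it in the buffer; otherwise it is re-deferred to a later firing. The helper `extract_context` and the sample values are hypothetical, chosen only for illustration."
+      ],
+      "metadata": {}
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import numpy as np\n",
+        "\n",
+        "\n",
+        "def extract_context(values, idx, context_points):\n",
+        "    # Hypothetical helper mirroring the slicing in on_expiry_timer.\n",
+        "    # Returns (context_before, context_after), or None when fewer than\n",
+        "    # context_points values follow the anomaly at position idx.\n",
+        "    if idx + context_points >= len(values):\n",
+        "        return None  # Not enough future points yet; the DoFn re-defers.\n",
+        "    start = max(0, idx - context_points)\n",
+        "    end = idx + context_points + 1\n",
+        "    context_before = np.array(values[start:idx])\n",
+        "    context_after = np.array(values[idx + 1:end])\n",
+        "    return context_before, context_after\n",
+        "\n",
+        "\n",
+        "values = [10.0, 11.0, 50.0, 10.5, 10.2, 9.8]  # spike at index 2\n",
+        "print(extract_context(values, idx=4, context_points=2))  # None: deferred\n",
+        "\n",
+        "before, after = extract_context(values, idx=2, context_points=2)\n",
+        "# Similar means on both sides suggest a transient spike,\n",
+        "# the REMOVE case described in the prompt.\n",
+        "print(before.mean(), after.mean())\n"
+      ],
+      "metadata": {},
+      "execution_count": null,
+      "outputs": []
+    },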
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# Finetuning Component"
+      ],
+      "metadata": {
+        "id": "WSl5lV_9ugQY"
+      }
+    },
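+    {
+      "cell_type": "markdown",
+      "source": [
+        "The sketch below is not used by the finetuner; it illustrates two details of the finetuning code that follows. `TimeSeriesDataset` yields `len(series) - (context_length + horizon_length) + 1` sliding-window samples, or none when a split is shorter than one window (an empty validation set would make `_validate` divide by zero). `_quantile_loss` is twice the pinball loss, so high quantiles penalize under-prediction more. The helper names `num_windows` and `pinball_loss` and the series lengths are hypothetical."
+      ],
+      "metadata": {}
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import torch\n",
+        "\n",
+        "\n",
+        "def num_windows(series_len, context_length, horizon_length):\n",
+        "    # Same count as TimeSeriesDataset._prepare_samples: one sample per start index.\n",
+        "    return max(0, series_len - (context_length + horizon_length) + 1)\n",
+        "\n",
+        "\n",
+        "# Hypothetical 1,000-point series with the default 0.8 train split.\n",
+        "print(num_windows(800, context_length=512, horizon_length=128))  # 161\n",
+        "print(num_windows(200, context_length=512, horizon_length=128))  # 0: too short\n",
+        "\n",
+        "\n",
+        "def pinball_loss(pred, actual, quantile):\n",
+        "    # Same formula as TimesFMFinetuner._quantile_loss, including the factor of 2.\n",
+        "    dev = actual - pred\n",
+        "    return 2 * torch.where(dev * quantile >= 0, dev * quantile, -dev * (1.0 - quantile))\n",
+        "\n",
+        "\n",
+        "actual = torch.tensor([10.0, 10.0])\n",
+        "pred = torch.tensor([8.0, 12.0])  # one under- and one over-prediction\n",
+        "print(pinball_loss(pred, actual, 0.9))  # under-prediction costs more\n",
+        "print(pinball_loss(pred, actual, 0.1))  # over-prediction costs more\n"
+      ],
+      "metadata": {},
+      "execution_count": null,
+      "outputs": []
+    },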
+    {
+      "cell_type": "code",
+      "source": [
+        "\"\"\"\n",
+        "TimesFM Finetuner: A flexible framework for finetuning TimesFM models on custom datasets.\n",
+        "\"\"\"\n",
+        "\n",
+        "import logging\n",
+        "import os\n",
+        "from abc import ABC, abstractmethod\n",
+        "from dataclasses import dataclass, field\n",
+        "from typing import Any, Callable, Dict, List, Optional\n",
+        "\n",
+        "import torch\n",
+        "import torch.distributed as dist\n",
+        "import torch.nn as nn\n",
+        "from torch.nn.parallel import DistributedDataParallel as DDP\n",
+        "from torch.utils.data import DataLoader, Dataset\n",
+        "from timesfm.pytorch_patched_decoder import create_quantiles\n",
+        "\n",
+        "import wandb\n",
+        "\n",
+        "\n",
+        "class MetricsLogger(ABC):\n",
+        "  \"\"\"Abstract base class for logging metrics during training.\n",
+        "\n",
+        "    This class defines the interface for logging metrics during model training.\n",
+        "    Concrete implementations can log to different backends (e.g., WandB, TensorBoard).\n",
+        "    \"\"\"\n",
+        "\n",
+        "  @abstractmethod\n",
+        "  def log_metrics(self,\n",
+        "                  metrics: Dict[str, Any],\n",
+        "                  step: Optional[int] = None) -> None:\n",
+        "    \"\"\"Log metrics to the specified backend.\n",
+        "\n",
+        "        Args:\n",
+        "          metrics: Dictionary containing metric names and values.\n",
+        "          step: Optional step number or epoch for the metrics.\n",
+        "        \"\"\"\n",
+        "    pass\n",
+        "\n",
+        "  @abstractmethod\n",
+        "  def close(self) -> None:\n",
+        "    \"\"\"Clean up any resources used by the logger.\"\"\"\n",
+        "    pass\n",
+        "\n",
+        "\n",
+        "class WandBLogger(MetricsLogger):\n",
+        "  \"\"\"Weights & Biases implementation of metrics logging.\n",
+        "\n",
+        "    Args:\n",
+        "      project: Name of the W&B project.\n",
+        "      config: Configuration dictionary to log.\n",
+        "      rank: Process rank in distributed training.\n",
+        "    \"\"\"\n",
+        "\n",
+        "  def __init__(self, project: str, config: Dict[str, Any], rank: int = 0):\n",
+        "    self.rank = rank\n",
+        "    if rank == 0:\n",
+        "      wandb.init(project=project, config=config)\n",
+        "\n",
+        "  def log_metrics(self,\n",
+        "                  metrics: Dict[str, Any],\n",
+        "                  step: Optional[int] = None) -> None:\n",
+        "    \"\"\"Log metrics to W&B if on the main process.\n",
+        "\n",
+        "        Args:\n",
+        "          metrics: Dictionary of metrics to log.\n",
+        "          step: Current training step or epoch.\n",
+        "        \"\"\"\n",
+        "    if self.rank == 0:\n",
+        "      wandb.log(metrics, step=step)\n",
+        "\n",
+        "  def close(self) -> None:\n",
+        "    \"\"\"Finish the W&B run if on the main process.\"\"\"\n",
+        "    if self.rank == 0:\n",
+        "      wandb.finish()\n",
+        "\n",
+        "\n",
+        "class DistributedManager:\n",
+        "  \"\"\"Manages distributed training setup and cleanup.\n",
+        "\n",
+        "    Args:\n",
+        "      world_size: Total number of processes.\n",
+        "      rank: Process rank.\n",
+        "      master_addr: Address of the master process.\n",
+        "      master_port: Port for distributed communication.\n",
+        "      backend: PyTorch distributed backend to use.\n",
+        "    \"\"\"\n",
+        "\n",
+        "  def __init__(\n",
+        "      self,\n",
+        "      world_size: int,\n",
+        "      rank: int,\n",
+        "      master_addr: str = \"localhost\",\n",
+        "      master_port: str = \"12358\",\n",
+        "      backend: str = \"nccl\",\n",
+        "  ):\n",
+        "    self.world_size = world_size\n",
+        "    self.rank = rank\n",
+        "    self.master_addr = master_addr\n",
+        "    self.master_port = master_port\n",
+        "    self.backend = backend\n",
+        "\n",
+        "  def setup(self) -> None:\n",
+        "    \"\"\"Initialize the distributed environment.\"\"\"\n",
+        "    os.environ[\"MASTER_ADDR\"] = self.master_addr\n",
+        "    os.environ[\"MASTER_PORT\"] = self.master_port\n",
+        "\n",
+        "    if not dist.is_initialized():\n",
+        "      dist.init_process_group(backend=self.backend,\n",
+        "                              world_size=self.world_size,\n",
+        "                              rank=self.rank)\n",
+        "\n",
+        "  def cleanup(self) -> None:\n",
+        "    \"\"\"Clean up the distributed environment.\"\"\"\n",
+        "    if dist.is_initialized():\n",
+        "      dist.destroy_process_group()\n",
+        "\n",
+        "\n",
+        "@dataclass\n",
+        "class FinetuningConfig:\n",
+        "  \"\"\"Configuration for model training.\n",
+        "\n",
+        "    Args:\n",
+        "      batch_size: Number of samples per batch.\n",
+        "      num_epochs: Number of training epochs.\n",
+        "      learning_rate: Initial learning rate.\n",
+        "      weight_decay: L2 regularization factor.\n",
+        "      freq_type: TimesFM frequency category; one of 0, 1, or 2.\n",
+        "      use_quantile_loss: Whether to add the quantile loss to the base loss.\n",
+        "      quantiles: Quantiles for the quantile loss; create_quantiles() is used when None.\n",
+        "      device: Device to train on ('cuda' or 'cpu').\n",
+        "      distributed: Whether to use distributed training.\n",
+        "      gpu_ids: List of GPU IDs to use.\n",
+        "      master_port: Port for distributed training.\n",
+        "      master_addr: Address for distributed training.\n",
+        "      use_wandb: Whether to use Weights & Biases logging.\n",
+        "      wandb_project: W&B project name.\n",
+        "      log_every_n_steps: Log metrics every N steps (batches); inspired by PyTorch Lightning.\n",
+        "      val_check_interval: How often within one training epoch to check validation metrics (also from PyTorch Lightning).\n",
+        "        Can be: float (0.0-1.0): fraction of epoch (e.g., 0.5 = validate twice per epoch)\n",
+        "                int: validate every N batches\n",
+        "      Note: TimesFMFinetuner currently logs and validates once per epoch and does not\n",
+        "        read log_every_n_steps or val_check_interval.\n",
+        "    \"\"\"\n",
+        "\n",
+        "  batch_size: int = 32\n",
+        "  num_epochs: int = 20\n",
+        "  learning_rate: float = 1e-4\n",
+        "  weight_decay: float = 0.01\n",
+        "  freq_type: int = 0\n",
+        "  use_quantile_loss: bool = False\n",
+        "  quantiles: Optional[List[float]] = None\n",
+        "  device: str = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+        "  distributed: bool = False\n",
+        "  gpu_ids: List[int] = field(default_factory=lambda: [0])\n",
+        "  master_port: str = \"12358\"\n",
+        "  master_addr: str = \"localhost\"\n",
+        "  use_wandb: bool = False\n",
+        "  wandb_project: str = \"timesfm-finetuning\"\n",
+        "  log_every_n_steps: int = 50\n",
+        "  val_check_interval: float = 0.5\n",
+        "\n",
+        "\n",
+        "class TimesFMFinetuner:\n",
+        "  \"\"\"Handles model training and validation.\n",
+        "\n",
+        "    Args:\n",
+        "      model: PyTorch model to train.\n",
+        "      config: Training configuration.\n",
+        "      rank: Process rank for distributed training.\n",
+        "      loss_fn: Loss function (defaults to MSE).\n",
+        "      logger: Optional logging.Logger instance.\n",
+        "    \"\"\"\n",
+        "\n",
+        "  def __init__(\n",
+        "      self,\n",
+        "      model: nn.Module,\n",
+        "      config: FinetuningConfig,\n",
+        "      rank: int = 0,\n",
+        "      loss_fn: Optional[Callable] = None,\n",
+        "      logger: Optional[logging.Logger] = None,\n",
+        "  ):\n",
+        "    self.model = model\n",
+        "    self.config = config\n",
+        "    self.rank = rank\n",
+        "    self.logger = logger or logging.getLogger(__name__)\n",
+        "    self.device = torch.device(\n",
+        "        f\"cuda:{rank}\" if torch.cuda.is_available() else \"cpu\")\n",
+        "    self.loss_fn = loss_fn or (lambda x, y: torch.mean((x - y.squeeze(-1))**2))\n",
+        "\n",
+        "    if config.use_wandb:\n",
+        "      self.metrics_logger = WandBLogger(config.wandb_project, config.__dict__,\n",
+        "                                        rank)\n",
+        "\n",
+        "    if config.distributed:\n",
+        "      self.dist_manager = DistributedManager(\n",
+        "          world_size=len(config.gpu_ids),\n",
+        "          rank=rank,\n",
+        "          master_addr=config.master_addr,\n",
+        "          master_port=config.master_port,\n",
+        "      )\n",
+        "      self.dist_manager.setup()\n",
+        "      self.model = self._setup_distributed_model()\n",
+        "\n",
+        "  def _setup_distributed_model(self) -> nn.Module:\n",
+        "    \"\"\"Configure model for distributed training.\"\"\"\n",
+        "    self.model = self.model.to(self.device)\n",
+        "    return DDP(self.model,\n",
+        "               device_ids=[self.config.gpu_ids[self.rank]],\n",
+        "               output_device=self.config.gpu_ids[self.rank])\n",
+        "\n",
+        "  def _create_dataloader(self, dataset: Dataset, is_train: bool) -> DataLoader:\n",
+        "    \"\"\"Create appropriate DataLoader based on training configuration.\n",
+        "\n",
+        "        Args:\n",
+        "          dataset: Dataset to create loader for.\n",
+        "          is_train: Whether this is for training (affects shuffling).\n",
+        "\n",
+        "        Returns:\n",
+        "          DataLoader instance.\n",
+        "        \"\"\"\n",
+        "    if self.config.distributed:\n",
+        "      sampler = torch.utils.data.distributed.DistributedSampler(\n",
+        "          dataset,\n",
+        "          num_replicas=len(self.config.gpu_ids),\n",
+        "          rank=dist.get_rank(),\n",
+        "          shuffle=is_train)\n",
+        "    else:\n",
+        "      sampler = None\n",
+        "\n",
+        "    return DataLoader(\n",
+        "        dataset,\n",
+        "        batch_size=self.config.batch_size,\n",
+        "        shuffle=(is_train and not self.config.distributed),\n",
+        "        sampler=sampler,\n",
+        "    )\n",
+        "\n",
+        "  def _quantile_loss(self, pred: torch.Tensor, actual: torch.Tensor,\n",
+        "                     quantile: float) -> torch.Tensor:\n",
+        "    \"\"\"Calculates quantile loss.\n",
+        "        Args:\n",
+        "            pred: Predicted values\n",
+        "            actual: Actual values\n",
+        "            quantile: Quantile at which loss is computed\n",
+        "        Returns:\n",
+        "            Quantile loss\n",
+        "        \"\"\"\n",
+        "    dev = actual - pred\n",
+        "    loss_first = dev * quantile\n",
+        "    loss_second = -dev * (1.0 - quantile)\n",
+        "    return 2 * torch.where(loss_first >= 0, loss_first, loss_second)\n",
+        "\n",
+        "  def _process_batch(self, batch: List[torch.Tensor]) -> tuple:\n",
+        "    \"\"\"Process a single batch of data.\n",
+        "\n",
+        "        Args:\n",
+        "          batch: List of input tensors.\n",
+        "\n",
+        "        Returns:\n",
+        "          Tuple of (loss, predictions).\n",
+        "        \"\"\"\n",
+        "    x_context, x_padding, freq, x_future = [\n",
+        "        t.to(self.device, non_blocking=True) for t in batch\n",
+        "    ]\n",
+        "\n",
+        "    predictions = self.model(x_context, x_padding.float(), freq)\n",
+        "    predictions_mean = predictions[..., 0]\n",
+        "    last_patch_pred = predictions_mean[:, -1, :]\n",
+        "\n",
+        "    loss = self.loss_fn(last_patch_pred, x_future.squeeze(-1))\n",
+        "    if self.config.use_quantile_loss:\n",
+        "      quantiles = self.config.quantiles or create_quantiles()\n",
+        "      for i, quantile in enumerate(quantiles):\n",
+        "        last_patch_quantile = predictions[:, -1, :, i + 1]\n",
+        "        loss += torch.mean(\n",
+        "            self._quantile_loss(last_patch_quantile, x_future.squeeze(-1),\n",
+        "                                quantile))\n",
+        "\n",
+        "    return loss, predictions\n",
+        "\n",
+        "  def _train_epoch(self, train_loader: DataLoader,\n",
+        "                   optimizer: torch.optim.Optimizer) -> float:\n",
+        "    \"\"\"Train for one epoch (single-process or distributed).\n",
+        "\n",
+        "        Args:\n",
+        "            train_loader: DataLoader for training data.\n",
+        "            optimizer: Optimizer instance.\n",
+        "\n",
+        "        Returns:\n",
+        "            Average training loss for the epoch.\n",
+        "        \"\"\"\n",
+        "    self.model.train()\n",
+        "    total_loss = 0.0\n",
+        "    num_batches = len(train_loader)\n",
+        "\n",
+        "    for batch in train_loader:\n",
+        "      loss, _ = self._process_batch(batch)\n",
+        "\n",
+        "      optimizer.zero_grad()\n",
+        "      loss.backward()\n",
+        "      optimizer.step()\n",
+        "\n",
+        "      total_loss += loss.item()\n",
+        "\n",
+        "    avg_loss = total_loss / num_batches\n",
+        "\n",
+        "    if self.config.distributed:\n",
+        "      avg_loss_tensor = torch.tensor(avg_loss, device=self.device)\n",
+        "      dist.all_reduce(avg_loss_tensor, op=dist.ReduceOp.SUM)\n",
+        "      avg_loss = (avg_loss_tensor / dist.get_world_size()).item()\n",
+        "\n",
+        "    return avg_loss\n",
+        "\n",
+        "  def _validate(self, val_loader: DataLoader) -> float:\n",
+        "    \"\"\"Perform validation.\n",
+        "\n",
+        "        Args:\n",
+        "            val_loader: DataLoader for validation data.\n",
+        "\n",
+        "        Returns:\n",
+        "            Average validation loss.\n",
+        "        \"\"\"\n",
+        "    self.model.eval()\n",
+        "    total_loss = 0.0\n",
+        "    num_batches = len(val_loader)\n",
+        "\n",
+        "    with torch.no_grad():\n",
+        "      for batch in val_loader:\n",
+        "        loss, _ = self._process_batch(batch)\n",
+        "        total_loss += loss.item()\n",
+        "\n",
+        "    avg_loss = total_loss / num_batches\n",
+        "\n",
+        "    if self.config.distributed:\n",
+        "      avg_loss_tensor = torch.tensor(avg_loss, device=self.device)\n",
+        "      dist.all_reduce(avg_loss_tensor, op=dist.ReduceOp.SUM)\n",
+        "      avg_loss = (avg_loss_tensor / dist.get_world_size()).item()\n",
+        "\n",
+        "    return avg_loss\n",
+        "\n",
+        "  def finetune(self, train_dataset: Dataset,\n",
+        "               val_dataset: Dataset) -> Dict[str, Any]:\n",
+        "    \"\"\"Train the model.\n",
+        "\n",
+        "        Args:\n",
+        "          train_dataset: Training dataset.\n",
+        "          val_dataset: Validation dataset.\n",
+        "\n",
+        "        Returns:\n",
+        "          Dictionary containing training history.\n",
+        "        \"\"\"\n",
+        "    self.model = self.model.to(self.device)\n",
+        "    train_loader = self._create_dataloader(train_dataset, is_train=True)\n",
+        "    val_loader = self._create_dataloader(val_dataset, is_train=False)\n",
+        "\n",
+        "    optimizer = torch.optim.Adam(self.model.parameters(),\n",
+        "                                 lr=self.config.learning_rate,\n",
+        "                                 weight_decay=self.config.weight_decay)\n",
+        "\n",
+        "    history = {\"train_loss\": [], \"val_loss\": [], \"learning_rate\": []}\n",
+        "\n",
+        "    self.logger.info(\n",
+        "        f\"Starting training for {self.config.num_epochs} epochs...\")\n",
+        "    self.logger.info(f\"Training samples: {len(train_dataset)}\")\n",
+        "    self.logger.info(f\"Validation samples: {len(val_dataset)}\")\n",
+        "\n",
+        "    try:\n",
+        "      for epoch in range(self.config.num_epochs):\n",
+        "        train_loss = self._train_epoch(train_loader, optimizer)\n",
+        "        val_loss = self._validate(val_loader)\n",
+        "        current_lr = optimizer.param_groups[0][\"lr\"]\n",
+        "\n",
+        "        metrics = {\n",
+        "            \"train_loss\": train_loss,\n",
+        "            \"val_loss\": val_loss,\n",
+        "            \"learning_rate\": current_lr,\n",
+        "            \"epoch\": epoch + 1,\n",
+        "        }\n",
+        "\n",
+        "        if self.config.use_wandb:\n",
+        "          self.metrics_logger.log_metrics(metrics)\n",
+        "\n",
+        "        history[\"train_loss\"].append(train_loss)\n",
+        "        history[\"val_loss\"].append(val_loss)\n",
+        "        history[\"learning_rate\"].append(current_lr)\n",
+        "\n",
+        "        if self.rank == 0:\n",
+        "          self.logger.info(\n",
+        "              f\"[Epoch {epoch+1}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}\"\n",
+        "          )\n",
+        "\n",
+        "    except KeyboardInterrupt:\n",
+        "      self.logger.info(\"Training interrupted by user\")\n",
+        "\n",
+        "    if self.config.distributed:\n",
+        "      self.dist_manager.cleanup()\n",
+        "\n",
+        "    if self.config.use_wandb:\n",
+        "      self.metrics_logger.close()\n",
+        "\n",
+        "    return {\"history\": history}\n",
+        "\n",
+        "import apache_beam as beam\n",
+        "import logging\n",
+        "import torch\n",
+        "import numpy as np\n",
+        "import timesfm\n",
+        "from os import path\n",
+        "from timesfm import TimesFm, TimesFmCheckpoint, TimesFmHparams\n",
+        "from timesfm.pytorch_patched_decoder import PatchedTimeSeriesDecoder\n",
+        "from huggingface_hub import snapshot_download\n",
+        "from apache_beam.io.gcp.gcsio import GcsIO\n",
+        "\n",
+        "from torch.utils.data import Dataset\n",
+        "from google.cloud import storage\n",
+        "from typing import Tuple\n",
+        "\n",
+        "\n",
+        "class TimeSeriesDataset(Dataset):\n",
+        "  \"\"\"Dataset for time series data compatible with TimesFM.\"\"\"\n",
+        "  def __init__(\n",
+        "      self,\n",
+        "      series: np.ndarray,\n",
+        "      context_length: int,\n",
+        "      horizon_length: int,\n",
+        "      freq_type: int = 0):\n",
+        "    \"\"\"\n",
+        "        Initialize dataset.\n",
+        "\n",
+        "        Args:\n",
+        "            series: Time series data\n",
+        "            context_length: Number of past timesteps to use as input\n",
+        "            horizon_length: Number of future timesteps to predict\n",
+        "            freq_type: Frequency type (0, 1, or 2)\n",
+        "        \"\"\"\n",
+        "    if freq_type not in [0, 1, 2]:\n",
+        "      raise ValueError(\"freq_type must be 0, 1, or 2\")\n",
+        "\n",
+        "    self.series = series\n",
+        "    self.context_length = context_length\n",
+        "    self.horizon_length = horizon_length\n",
+        "    self.freq_type = freq_type\n",
+        "    self._prepare_samples()\n",
+        "\n",
+        "  def _prepare_samples(self) -> None:\n",
+        "    \"\"\"Prepare sliding window samples from the time series.\"\"\"\n",
+        "    self.samples = []\n",
+        "    total_length = self.context_length + self.horizon_length\n",
+        "\n",
+        "    for start_idx in range(0, len(self.series) - total_length + 1):\n",
+        "      end_idx = start_idx + self.context_length\n",
+        "      x_context = self.series[start_idx:end_idx]\n",
+        "      x_future = self.series[end_idx:end_idx + self.horizon_length]\n",
+        "      self.samples.append((x_context, x_future))\n",
+        "\n",
+        "  def __len__(self) -> int:\n",
+        "    return len(self.samples)\n",
+        "\n",
+        "  def __getitem__(\n",
+        "      self, index: int\n",
+        "  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n",
+        "    x_context, x_future = self.samples[index]\n",
+        "\n",
+        "    x_context = torch.tensor(x_context, dtype=torch.float32)\n",
+        "    x_future = torch.tensor(x_future, dtype=torch.float32)\n",
+        "\n",
+        "    input_padding = torch.zeros_like(x_context)\n",
+        "    freq = torch.tensor([self.freq_type], dtype=torch.long)\n",
+        "\n",
+        "    return x_context, input_padding, freq, x_future\n",
+        "\n",
+        "\n",
+        "def prepare_datasets(\n",
+        "    series: np.ndarray,\n",
+        "    context_length: int,\n",
+        "    horizon_length: int,\n",
+        "    freq_type: int = 0,\n",
+        "    train_split: float = 0.8) -> Tuple[Dataset, Dataset]:\n",
+        "  \"\"\"\n",
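+        "    Prepare training and validation datasets from time series data.\n",
+        "    The series is split chronologically: the first `train_split` fraction\n",
+        "    is used for training and the remainder for validation.\n",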
+        "\n",
+        "    Args:\n",
+        "        series: Input time series data\n",
+        "        context_length: Number of past timesteps to use\n",
+        "        horizon_length: Number of future timesteps to predict\n",
+        "        freq_type: Frequency type (0, 1, or 2)\n",
+        "        train_split: Fraction of data to use for training\n",
+        "\n",
+        "    Returns:\n",
+        "        Tuple of (train_dataset, val_dataset)\n",
+        "    \"\"\"\n",
+        "  train_size = int(len(series) * train_split)\n",
+        "  train_data = series[:train_size]\n",
+        "  val_data = series[train_size:]\n",
+        "\n",
+        "  # Create datasets with specified frequency type\n",
+        "  train_dataset = TimeSeriesDataset(\n",
+        "      train_data,\n",
+        "      context_length=context_length,\n",
+        "      horizon_length=horizon_length,\n",
+        "      freq_type=freq_type)\n",
+        "\n",
+        "  val_dataset = TimeSeriesDataset(\n",
+        "      val_data,\n",
+        "      context_length=context_length,\n",
+        "      horizon_length=horizon_length,\n",
+        "      freq_type=freq_type)\n",
+        "\n",
+        "  return train_dataset, val_dataset\n",
+        "\n",
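Because `prepare_datasets` splits the series chronologically with `int(len(series) * train_split)`, a fine-tuning batch only yields validation samples if its last 20% still spans a full context-plus-horizon window. A small sketch for checking batch sizes under that rule; the helpers `split_sizes`, `num_windows`, and `min_batch_size` are hypothetical, and 512/128/0.8 are the values this notebook uses later.

```python
from typing import Tuple


def split_sizes(n_points: int, train_split: float = 0.8) -> Tuple[int, int]:
    """Mirror prepare_datasets: (train_len, val_len) for a series of n_points."""
    train_len = int(n_points * train_split)
    return train_len, n_points - train_len


def num_windows(length: int, context_length: int, horizon_length: int) -> int:
    """Number of stride-1 windows TimeSeriesDataset builds from `length` points."""
    return max(0, length - (context_length + horizon_length) + 1)


def min_batch_size(context_length: int, horizon_length: int,
                   train_split: float = 0.8) -> int:
    """Smallest series length whose train and validation splits both have a window."""
    n = context_length + horizon_length
    while True:
        train_len, val_len = split_sizes(n, train_split)
        if (num_windows(train_len, context_length, horizon_length) > 0
                and num_windows(val_len, context_length, horizon_length) > 0):
            return n
        n += 1


if __name__ == "__main__":
    print(min_batch_size(512, 128))  # close to (512 + 128) / 0.2 = 3200
    train_len, val_len = split_sizes(7680)
    print(num_windows(train_len, 512, 128), num_windows(val_len, 512, 128))
```

The brute-force search keeps the exact `int()` truncation used by `prepare_datasets`, which a closed-form `WINDOW_SIZE / 0.2` estimate ignores.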
+        "\n",
+        "class BatchContinuousAndOrderedFn(beam.DoFn):\n",
+        "    \"\"\"\n",
+        "    A stateful DoFn that buffers elements, keeps them sorted by timestamp, and emits\n",
+        "    a batch only when a full, continuous sequence of points is available.\n",
+        "    Logs the buffer status and any detected gaps for debugging.\n",
+        "    \"\"\"\n",
+        "    BUFFER_STATE = ReadModifyWriteStateSpec('buffer', PickleCoder())\n",
+        "\n",
+        "    def __init__(self, batch_size, expected_interval_seconds=1):\n",
+        "        self.batch_size = batch_size\n",
+        "        self.interval = expected_interval_seconds\n",
+        "        # Counter used to log periodically instead of on every element\n",
+        "        self.counter = 0\n",
+        "\n",
+        "    def process(self, element, buffer=beam.DoFn.StateParam(BUFFER_STATE)):\n",
+        "        key, data = element\n",
+        "        timestamp = data['timestamp']\n",
+        "        value = data['value']\n",
+        "\n",
+        "        # Increment the counter\n",
+        "        self.counter += 1\n",
+        "\n",
+        "        current_buffer = buffer.read() or []\n",
+        "        current_buffer.append((timestamp, value))\n",
+        "        current_buffer.sort(key=lambda x: x[0])\n",
+        "\n",
+        "        # Log the buffer status every 100 elements\n",
+        "        if self.counter % 100 == 0 and current_buffer:\n",
+        "            logging.info(\n",
+        "                f\"Batching buffer now contains {len(current_buffer)} points. \"\n",
+        "                f\"Timestamps range from {current_buffer[0][0]} to {current_buffer[-1][0]}.\"\n",
+        "            )\n",
+        "\n",
+        "        start_index = 0\n",
+        "        while start_index + self.batch_size <= len(current_buffer):\n",
+        "            is_continuous = True\n",
+        "            # Check for continuity in the slice of the buffer we are considering\n",
+        "            for i in range(start_index, start_index + self.batch_size - 1):\n",
+        "                ts1_seconds = current_buffer[i][0].seconds()\n",
+        "                ts2_seconds = current_buffer[i + 1][0].seconds()\n",
+        "\n",
+        "                if ts2_seconds - ts1_seconds != self.interval:\n",
+        "                    is_continuous = False\n",
+        "                    # A gap was found: stop and wait for the missing data. Because the\n",
+        "                    # buffer is sorted, the points after the gap are held back as well.\n",
+        "                    logging.info(\n",
+        "                        f\"Gap detected at index {i}. \"\n",
+        "                        f\"Timestamp {current_buffer[i][0]} is followed by {current_buffer[i+1][0]}. \"\n",
+        "                        f\"Actual interval: {ts2_seconds - ts1_seconds}s, Expected: {self.interval}s. \"\n",
+        "                        f\"Waiting for missing data.\"\n",
+        "                    )\n",
+        "                    break\n",
+        "\n",
+        "            if not is_continuous:\n",
+        "                # Keep the remaining points buffered until the gap is filled.\n",
+        "                break\n",
+        "\n",
+        "            # If we are here, the batch from start_index is continuous.\n",
+        "            logging.info(f\"Continuous sequence found! Emitting batch of size {self.batch_size} starting at index {start_index}.\")\n",
+        "\n",
+        "            batch_to_yield = current_buffer[start_index : start_index + self.batch_size]\n",
+        "\n",
+        "            formatted_batch = [{'timestamp': ts, 'value': val} for ts, val in batch_to_yield]\n",
+        "            yield formatted_batch\n",
+        "\n",
+        "            # Move the start_index to the next position after the yielded batch\n",
+        "            start_index += self.batch_size\n",
+        "\n",
+        "        # After the loop, remove all the yielded elements from the buffer.\n",
+        "        if start_index > 0:\n",
+        "            current_buffer = current_buffer[start_index:]\n",
+        "\n",
+        "        buffer.write(current_buffer)\n",
+        "\n",
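The emission rule in `BatchContinuousAndOrderedFn` can be exercised without a runner. This sketch (hypothetical `emit_continuous_batches`; integer seconds stand in for Beam `Timestamp` objects, and state and logging are left out) applies the same loop to a buffer. As in the DoFn, continuity is checked only inside each candidate batch, so a gap that falls exactly between two batches does not block emission.

```python
from typing import List, Tuple

Point = Tuple[int, float]  # (timestamp in seconds, value)


def emit_continuous_batches(
    buffer: List[Point], batch_size: int, interval: int = 1
) -> Tuple[List[List[Point]], List[Point]]:
    """Emit full batches of evenly spaced points; stop at the first gap."""
    buffer = sorted(buffer, key=lambda p: p[0])
    batches = []
    start_index = 0
    while start_index + batch_size <= len(buffer):
        window = buffer[start_index:start_index + batch_size]
        # Every consecutive pair in the candidate batch must be `interval` apart
        if any(b[0] - a[0] != interval for a, b in zip(window, window[1:])):
            break  # wait for the missing points instead of skipping the gap
        batches.append(window)
        start_index += batch_size
    # Points that were not emitted stay in the buffer for the next element
    return batches, buffer[start_index:]


if __name__ == "__main__":
    points = [(t, float(t)) for t in [1, 2, 3, 4, 5]]
    batches, remaining = emit_continuous_batches(points, batch_size=2)
    print(len(batches), remaining)  # 2 [(5, 5.0)]

    gapped = [(t, float(t)) for t in [1, 2, 4, 5]]
    batches, remaining = emit_continuous_batches(gapped, batch_size=3)
    print(len(batches), len(remaining))  # 0 4
```

With `batch_size=3`, the gap between 2 and 4 sits inside the first candidate batch, so nothing is emitted and all four points stay buffered.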
+        "class RunFinetuningFn(beam.DoFn):\n",
+        "  \"\"\"\n",
+        "    Takes a batch of data, loads the latest model, runs fine-tuning,\n",
+        "    and uploads the new model to GCS.\n",
+        "  \"\"\"\n",
+        "  def __init__(\n",
+        "      self,\n",
+        "      initial_model_path,\n",
+        "      finetuned_model_bucket,\n",
+        "      finetuned_model_prefix,\n",
+        "      hparams,\n",
+        "      config):\n",
+        "    # Fallback used only when no fine-tuned checkpoint exists in GCS yet\n",
+        "    self.initial_model_path = initial_model_path\n",
+        "    self.finetuned_model_bucket = finetuned_model_bucket\n",
+        "    self.finetuned_model_prefix = finetuned_model_prefix\n",
+        "    self.hparams = hparams\n",
+        "    self.config = config\n",
+        "    self._storage_client = None\n",
+        "\n",
+        "  def setup(self):\n",
+        "    self._storage_client = storage.Client()\n",
+        "\n",
+        "  def _get_latest_model_from_gcs(self):\n",
+        "    \"\"\"Directly queries GCS for the most recently created model checkpoint.\"\"\"\n",
+        "    try:\n",
+        "        bucket = self._storage_client.get_bucket(self.finetuned_model_bucket)\n",
+        "        blobs = list(bucket.list_blobs(prefix=self.finetuned_model_prefix))\n",
+        "\n",
+        "        # Keep only model files and exclude the initial model if present\n",
+        "        model_blobs = [b for b in blobs if b.name.endswith(\".pth\") and \"initial\" not in b.name]\n",
+        "\n",
+        "        if not model_blobs:\n",
+        "            return None\n",
+        "\n",
+        "        # Find the blob with the latest creation time\n",
+        "        latest_blob = max(model_blobs, key=lambda b: b.time_created)\n",
+        "        latest_model_path = f\"gs://{self.finetuned_model_bucket}/{latest_blob.name}\"\n",
+        "        return latest_model_path\n",
+        "    except Exception as e:\n",
+        "        logging.error(f\"Error querying GCS for the latest model: {e}\")\n",
+        "        return None\n",
+        "\n",
+        "  def process(self, batch_of_data):\n",
+        "    logging.info(\n",
+        "        f\"Received a batch of {len(batch_of_data)} points for finetuning.\")\n",
+        "\n",
+        "    # If a finetuned model exists, use it. Otherwise, use the initial base model.\n",
+        "    latest_model_path = self._get_latest_model_from_gcs()\n",
+        "\n",
+        "    if latest_model_path:\n",
+        "        model_to_load = latest_model_path\n",
+        "        logging.info(f\"Continuously finetuning from latest model: {model_to_load}\")\n",
+        "    else:\n",
+        "        model_to_load = self.initial_model_path\n",
+        "        logging.info(f\"No finetuned model found. Starting from initial model: {model_to_load}\")\n",
+        "\n",
+        "    time_series_values = np.array([d['value'] for d in batch_of_data],\n",
+        "                                  dtype=np.float32)\n",
+        "    train_dataset, val_dataset = prepare_datasets(\n",
+        "        series=time_series_values,\n",
+        "        context_length=self.hparams.context_len,\n",
+        "        horizon_length=self.hparams.horizon_len,\n",
+        "        freq_type=self.config.freq_type,\n",
+        "        train_split=0.8\n",
+        "    )\n",
+        "\n",
+        "    logging.info(\n",
+        "        f\"Training split: {len(train_dataset.series)} points, {len(train_dataset)} samples\")\n",
+        "    logging.info(\n",
+        "        f\"Validation split: {len(val_dataset.series)} points, {len(val_dataset)} samples\")\n",
+        "\n",
+        "    # Load the model (base or latest finetuned).\n",
+        "    # get_model handles both GCS URIs and Hugging Face repo IDs.\n",
+        "    model = get_model(\n",
+        "        model_path=model_to_load,\n",
+        "        hparams=self.hparams,\n",
+        "        load_weights=True\n",
+        "    )\n",
+        "\n",
+        "    # Run fine-tuning\n",
+        "    finetuner = TimesFMFinetuner(model, self.config)\n",
+        "    finetuner.finetune(train_dataset=train_dataset, val_dataset=val_dataset)\n",
+        "\n",
+        "    # Save the new model locally and upload it to GCS\n",
+        "    from datetime import datetime, timezone\n",
+        "    timestamp_str = datetime.now(timezone.utc).strftime('%Y%m%d%H%M%S')\n",
+        "    model_filename = f\"timesfm_finetuned_{timestamp_str}.pth\"\n",
+        "    local_path = f\"/tmp/{model_filename}\"\n",
+        "    torch.save(model.state_dict(), local_path)\n",
+        "    bucket = self._storage_client.bucket(self.finetuned_model_bucket)\n",
+        "    blob_path = f\"{self.finetuned_model_prefix}/{model_filename}\"\n",
+        "    blob = bucket.blob(blob_path)\n",
+        "    blob.upload_from_filename(local_path)\n",
+        "    logging.info(\n",
+        "        f\"Successfully uploaded new model to gs://{self.finetuned_model_bucket}/{blob_path}\"\n",
+        "    )\n",
+        "    yield blob_path\n",
+        "\n",
+        "\n",
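`_get_latest_model_from_gcs` keeps only `.pth` blobs whose names do not contain `initial` and returns the one with the newest `time_created`. The selection step is sketched below against a hypothetical `FakeBlob` dataclass that exposes only `name` and `time_created` (the two `google.cloud.storage.Blob` attributes the DoFn reads), so it runs without a bucket; the bucket and object names are made up.

```python
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Iterable, Optional


@dataclass
class FakeBlob:
    """Hypothetical stand-in with the two Blob attributes the DoFn reads."""
    name: str
    time_created: datetime


def pick_latest_checkpoint(blobs: Iterable[FakeBlob], bucket: str) -> Optional[str]:
    """Return the gs:// URI of the newest non-initial .pth blob, or None."""
    model_blobs = [b for b in blobs
                   if b.name.endswith(".pth") and "initial" not in b.name]
    if not model_blobs:
        return None  # caller falls back to the initial Hugging Face model
    latest = max(model_blobs, key=lambda b: b.time_created)
    return f"gs://{bucket}/{latest.name}"


if __name__ == "__main__":
    def day(d: int) -> datetime:
        return datetime(2025, 8, d, tzinfo=timezone.utc)

    blobs = [
        FakeBlob("ckpt/initial_model.pth", day(20)),
        FakeBlob("ckpt/timesfm_finetuned_20250810000000.pth", day(10)),
        FakeBlob("ckpt/timesfm_finetuned_20250812000000.pth", day(12)),
        FakeBlob("ckpt/notes.txt", day(15)),
    ]
    print(pick_latest_checkpoint(blobs, "my-bucket"))
    # gs://my-bucket/ckpt/timesfm_finetuned_20250812000000.pth
```

Selecting by `time_created` rather than by filename means the timestamp embedded in the filename is informational only.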
+        "def get_model(model_path: str, hparams: TimesFmHparams, load_weights: bool = False):\n",
+        "    \"\"\"\n",
+        "    Loads a TimesFM model from either a Hugging Face repo ID or a GCS path.\n",
+        "    The `load_weights` argument is kept for signature consistency but is\n",
+        "    ignored: TimesFm always loads the checkpoint weights.\n",
+        "    \"\"\"\n",
+        "    checkpoint_config = {}\n",
+        "\n",
+        "    # Case 1: The model path is a GCS URI.\n",
+        "    # Download it to a local file and tell TimesFmCheckpoint to load from that path.\n",
+        "    if model_path.startswith(\"gs://\"):\n",
+        "        logging.info(f\"Preparing to load model from GCS path: {model_path}\")\n",
+        "        local_temp_path = f\"/tmp/{path.basename(model_path)}\"\n",
+        "        with GcsIO().open(model_path, 'rb') as f_in, open(local_temp_path, 'wb') as f_out:\n",
+        "            f_out.write(f_in.read())\n",
+        "        # The key for a local file is 'path'\n",
+        "        checkpoint_config['path'] = local_temp_path\n",
+        "\n",
+        "    # Case 2: The model path is a Hugging Face repository ID.\n",
+        "    else:\n",
+        "        logging.info(f\"Preparing to load model from Hugging Face repo: {model_path}\")\n",
+        "        # The key for a Hugging Face repo is 'huggingface_repo_id'\n",
+        "        checkpoint_config['huggingface_repo_id'] = model_path\n",
+        "\n",
+        "    # Initialize the TimesFm object with the checkpoint config built above.\n",
+        "    # This single call handles model configuration and weight loading.\n",
+        "    tfm = TimesFm(\n",
+        "        hparams=hparams,\n",
+        "        checkpoint=TimesFmCheckpoint(**checkpoint_config)\n",
+        "    )\n",
+        "\n",
+        "    logging.info(\"Model loaded successfully inside get_model.\")\n",
+        "\n",
+        "    # The `TimesFm` object holds the configured model instance.\n",
+        "    # The returned model is a PatchedTimeSeriesDecoder instance with weights loaded.\n",
+        "    return tfm._model"
+      ],
+      "metadata": {
+        "id": "IzEE_R3SuwAR"
+      },
+      "execution_count": 11,
+      "outputs": []
+    },
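`get_model` chooses the `TimesFmCheckpoint` keyword from the form of `model_path`: a `gs://` URI is copied to `/tmp` and passed as `path`, and anything else is treated as a Hugging Face repo ID. The dispatch alone, without the download or the `timesfm` import, might look like the following (the helper `checkpoint_kwargs` and its `tmp_dir` parameter are hypothetical):

```python
import posixpath
from typing import Dict


def checkpoint_kwargs(model_path: str, tmp_dir: str = "/tmp") -> Dict[str, str]:
    """Return the keyword arguments get_model would pass to TimesFmCheckpoint."""
    if model_path.startswith("gs://"):
        # GCS object: the notebook downloads it first, then loads the local copy
        return {"path": posixpath.join(tmp_dir, posixpath.basename(model_path))}
    # Otherwise assume a Hugging Face repository ID
    return {"huggingface_repo_id": model_path}


if __name__ == "__main__":
    print(checkpoint_kwargs("google/timesfm-1.0-200m-pytorch"))
    # {'huggingface_repo_id': 'google/timesfm-1.0-200m-pytorch'}
    print(checkpoint_kwargs("gs://my-bucket/ckpt/timesfm_finetuned_20250101000000.pth"))
    # {'path': '/tmp/timesfm_finetuned_20250101000000.pth'}
```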
+    {
+      "cell_type": "markdown",
+      "source": [
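+        "# Load Time Series Data\n",
+        "\n",
+        "This example uses the [NYC taxi traffic dataset](https://www.kaggle.com/datasets/julienjta/nyc-taxi-traffic/data) from Kaggle. A copy of the CSV is stored in GCS and read directly with pandas."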
+      ],
+      "metadata": {
+        "id": "Lz1OROouy9IV"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import pandas as pd\n",
+        "from apache_beam.utils.timestamp import Timestamp\n",
+        "from google.colab import auth\n",
+        "auth.authenticate_user()\n",
+        "\n",
+        "# Define the path to the file in the GCS bucket\n",
+        "gcs_path = 'gs://apache-beam-samples/anomaly_detection/timesfm-dataset-example/nyc_taxi_timeseries.csv'\n",
+        "\n",
+        "# Read the CSV directly from GCS into a DataFrame\n",
+        "df = pd.read_csv(gcs_path)\n",
+        "\n",
+        "# Convert 'value' column to a numpy array of integers\n",
+        "values_array = pd.to_numeric(df['value'], errors='coerce').astype(int).to_numpy()\n",
+        "\n",
+        "# Create the list of (timestamp, value) tuples. The CSV rows are 30 minutes\n",
+        "# apart; each row gets a synthetic Beam Timestamp one second after the previous\n",
+        "# one, matching EXPECTED_INTERVAL = 1 in the pipeline configuration.\n",
+        "input_data = []\n",
+        "for i in range(len(values_array)):\n",
+        "  input_data.append((Timestamp(i + 1), values_array[i]))\n",
+        "\n",
+        "print(\"DataFrame loaded from GCS:\")\n",
+        "print(df.head())\n",
+        "print(\"\\nInput data created successfully (first 5 entries):\")\n",
+        "print(input_data[:5])"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "N7fAXDjXuDF4",
+        "outputId": "0ea32fa5-221d-44fb-9f83-4f524fd8f3c2"
+      },
+      "execution_count": 6,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "DataFrame loaded from GCS:\n",
+            "   Unnamed: 0           timestamp  value\n",
+            "0           0  2014-07-01 0:00:00  10844\n",
+            "1           1  2014-07-01 0:30:00   8127\n",
+            "2           2  2014-07-01 1:00:00   6210\n",
+            "3           3  2014-07-01 1:30:00   4656\n",
+            "4           4  2014-07-01 2:00:00   3820\n",
+            "\n",
+            "Input data created successfully (first 5 entries):\n",
+            "[(Timestamp(1), np.int64(10844)), (Timestamp(2), np.int64(8127)), (Timestamp(3), np.int64(6210)), (Timestamp(4), np.int64(4656)), (Timestamp(5), np.int64(3820))]\n"
+          ]
+        }
+      ]
+    },
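The CSV rows are 30 minutes apart, but the cell above assigns each row a synthetic timestamp one second after the previous one, which is what `EXPECTED_INTERVAL = 1` in the pipeline configuration expects. A stdlib-only sketch of that mapping, using plain `int` seconds instead of Beam `Timestamp` objects; `SAMPLE_CSV` reuses the first three rows printed above and assumes the file's first column has an empty header (pandas shows it as `Unnamed: 0`).

```python
import csv
import io
from typing import List, Tuple

# First rows of nyc_taxi_timeseries.csv, as printed by df.head() above
SAMPLE_CSV = """,timestamp,value
0,2014-07-01 0:00:00,10844
1,2014-07-01 0:30:00,8127
2,2014-07-01 1:00:00,6210
"""


def to_synthetic_series(csv_text: str) -> List[Tuple[int, int]]:
    """Pair each value with a synthetic timestamp 1, 2, 3, ... (in seconds)."""
    reader = csv.DictReader(io.StringIO(csv_text))
    # The real 'timestamp' column is ignored; only the row order matters
    return [(i + 1, int(row["value"])) for i, row in enumerate(reader)]


if __name__ == "__main__":
    print(to_synthetic_series(SAMPLE_CSV))
    # [(1, 10844), (2, 8127), (3, 6210)]
```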
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# Beam Pipeline Setup"
+      ],
+      "metadata": {
+        "id": "LiQQF_IxquCK"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import apache_beam as beam\n",
+        "from apache_beam.options.pipeline_options import PipelineOptions\n",
+        "from apache_beam.pvalue import AsDict, AsSingleton\n",
+        "from apache_beam.transforms.periodicsequence import PeriodicImpulse\n",
+        "import logging\n",
+        "import os\n",
+        "import json\n",
+        "import timesfm\n",
+        "from apache_beam.utils.timestamp import Timestamp\n",
+        "import csv\n",
+        "from apache_beam.ml.inference.base import RunInference\n",
+        "from apache_beam.ml.inference.utils import WatchFilePattern\n",
+        "import typing\n",
+        "from google.colab import userdata\n",
+        "import apache_beam.transforms.window as window\n",
+        "\n",
+        "logging.getLogger().setLevel(logging.INFO)\n",
+        "\n",
+        "# --- Pipeline Configuration ---\n",
+        "PROJECT_ID = os.environ.get(\"GCP_PROJECT\", \"apache-beam-testing\")\n",
+        "REGION = os.environ.get(\"GCP_REGION\", \"us-central1\")\n",
+        "TEMP_LOCATION = \"gs://apache-beam-testing-temp/timesfm_anomaly_detection/temp\"\n",
+        "STAGING_LOCATION = \"gs://apache-beam-testing-temp/timesfm_anomaly_detection/staging\"\n",
+        "FINETUNED_MODEL_BUCKET = \"apache-beam-testing-temp\"\n",
+        "FINETUNED_MODEL_PREFIX = \"timesfm_anomaly_detection/finetuned-models/timesfm/checkpoints\"\n",
+        "\n",
+        "# --- Model & Window Parameters ---\n",
+        "CONTEXT_LEN = 512\n",
+        "HORIZON_LEN = 128\n",
+        "WINDOW_SIZE = CONTEXT_LEN + HORIZON_LEN\n",
+        "SLIDE_INTERVAL = HORIZON_LEN\n",
+        "EXPECTED_INTERVAL = 1\n",
+        "INITIAL_MODEL = \"google/timesfm-1.0-200m-pytorch\"\n",
+        "\n",
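+        "MODEL_CHECK_INTERVAL_SECONDS = 10  # Check for a new model every 10 seconds\n",
+        "# Number of continuous points collected per finetuning run. Both the 80% training\n",
+        "# split and the 20% validation split must hold at least one WINDOW_SIZE window,\n",
+        "# so this needs to be at least about WINDOW_SIZE / 0.2 = 3200 points.\n",
+        "FINETUNING_BATCH_SIZE = 7680  # can be increased (for example, to 9600)\n",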
+        "FINETUNE_CONFIG = FinetuningConfig(\n",
+        "      batch_size=128,\n",
+        "      num_epochs=5,\n",
+        "      learning_rate=1e-4,\n",
+        "      use_wandb=False,\n",
+        "      freq_type=0, # should change based on your data\n",
+        "      log_every_n_steps=10,\n",
+        "      val_check_interval=0.5,\n",
+        "      use_quantile_loss=True\n",
+        "  )\n",
+        "\n",
+        "# Change the runner (for example, to DataflowRunner) if needed\n",
+        "options = PipelineOptions([\n",
+        "    \"--streaming\",\n",
+        "    \"--environment_type=LOOPBACK\",\n",
+        "    \"--runner=PrismRunner\",\n",
+        "    \"--logging_level=INFO\",\n",
+        "    \"--job_server_timeout=3600\"\n",
+        "])\n",
+        "\n",
+        "\n",
+        "\n",
+        "# HParams for the model\n",
+        "hparams = timesfm.TimesFmHparams(\n",
+        "    backend=\"gpu\",\n",
+        "    per_core_batch_size=32,\n",
+        "    horizon_len=HORIZON_LEN,\n",
+        "    context_len=CONTEXT_LEN,\n",
+        ")\n",
+        "model_handler = DynamicTimesFmModelHandler(model_uri=INITIAL_MODEL, hparams=hparams)\n",
+        "\n",
+        "def print_and_pass_through(label):\n",
+        "    def logger(element):\n",
+        "        print(f\"--- {label} --- \\nELEMENT: {element}\")\n",
+        "        return element\n",
+        "    return logger\n",
+        "\n",
+        "\n",
+        "class CustomJsonEncoder(json.JSONEncoder):\n",
+        "    \"\"\"A JSON encoder that handles Beam Timestamp objects and NumPy types.\"\"\"\n",
+        "    def default(self, obj):\n",
+        "        if isinstance(obj, Timestamp):\n",
+        "            # Convert Timestamp to whole seconds since the epoch (as a float)\n",
+        "            return obj.micros // 1e6\n",
+        "\n",
+        "        # Handle NumPy integer types\n",
+        "        if isinstance(obj, np.integer):\n",
+        "            return int(obj)\n",
+        "\n",
+        "        # Handle NumPy float types (for example, np.float32)\n",
+        "        if isinstance(obj, np.floating):\n",
+        "            return float(obj)\n",
+        "\n",
+        "        # Handle NumPy arrays\n",
+        "        if isinstance(obj, np.ndarray):\n",
+        "            return obj.tolist()\n",
+        "\n",
+        "        # For all other types, fall back to the default behavior\n",
+        "        return super().default(obj)\n",
+        "\n",
+        "class WritePlotDataAndPassThrough(beam.DoFn):\n",
+        "    \"\"\"\n",
+        "    A DoFn that writes plotting data to a file as a side effect\n",
+        "    and then passes the original, unmodified element downstream.\n",
+        "    \"\"\"\n",
+        "    def __init__(self, output_path):\n",
+        "        self._output_path = output_path\n",
+        "        self._file_handle = None\n",
+        "\n",
+        "    def setup(self):\n",
+        "        self._file_handle = open(self._output_path, 'a')\n",
+        "\n",
+        "    def process(self, element):\n",
+        "        _original_window, payload_dict = element\n",
+        "\n",
+        "        # Use the custom encoder to serialize Timestamp and NumPy objects\n",
+        "        json_record = json.dumps(payload_dict, cls=CustomJsonEncoder)\n",
+        "        self._file_handle.write(json_record + '\\n')\n",
+        "\n",
+        "        # Pass the original element through, with the Timestamp object intact\n",
+        "        yield element\n",
+        "\n",
+        "    def teardown(self):\n",
+        "        if self._file_handle:\n",
+        "            self._file_handle.close()\n",
+        "\n",
+        "\n",
+        "# =================================================================\n",
+        "# Get Latest Model Path (Side Input). WatchFilePattern is not\n",
+        "# currently supported on Prism. To run on Dataflow, uncomment the\n",
+        "# following and the model_metadata_pcoll argument of RunInference.\n",
+        "# =================================================================\n",
+        "# model_pattern = os.path.join(\n",
+        "#     f\"gs://{FINETUNED_MODEL_BUCKET}\", FINETUNED_MODEL_PREFIX, \"*.pth\"\n",
+        "# )\n",
+        "\n",
+        "# model_metadata_pcoll = (\n",
+        "#     \"WatchForNewModels\" >> WatchFilePattern(\n",
+        "#         file_pattern=model_pattern,\n",
+        "#         interval=MODEL_CHECK_INTERVAL_SECONDS\n",
+        "#       )\n",
+        "#     | \"PrintModelLocation\" >> beam.Map(print_and_pass_through(\"Model Location\"))\n",
+        "\n",
+        "# )\n",
+        "\n",
+        "# =================================================================\n",
+        "# Ingest and Window Raw Data\n",
+        "# =================================================================\n",
+        "\n",
+        "\n",
+        "windowed_data = (\n",
+        "    PeriodicImpulse(data=input_data, fire_interval=0.01)\n",
+        "    | \"AddKey\" >> beam.WithKeys(lambda x: 0)\n",
+        "    | \"ApplySlidingWindow\" >> beam.ParDo(\n",
+        "        OrderedSlidingWindowFn(window_size=WINDOW_SIZE, slide_interval=SLIDE_INTERVAL))\n",
+        "    | \"FillGaps\" >> beam.ParDo(FillGapsFn(expected_interval=EXPECTED_INTERVAL)).with_output_types(\n",
+        "        typing.Tuple[int, typing.Tuple[Timestamp, Timestamp, typing.List[float]]])\n",
+        "    | \"Skip NaN Values for now\" >> beam.Filter(\n",
+        "      lambda batch: 'NaN' not in batch[1][2])\n",
+        "    | \"PrintWindowedData\" >> beam.Map(print_and_pass_through(\"Windowed Data\"))\n",
+        "\n",
+        ")\n",
+        "\n",
+        "# =================================================================\n",
+        "# Detect Anomalies using the Latest Model\n",
+        "# =================================================================\n",
+        "\n",
+        "inference_results = (\n",
+        "    \"DetectAnomalies\" >> RunInference(\n",
+        "        model_handler=model_handler,\n",
+        "        # model_metadata_pcoll=model_metadata_pcoll\n",
+        "      )\n",
+        "    | \"PrintInference\" >> beam.Map(print_and_pass_through(\"Inference Results\"))\n",
+        ")\n",
+        "\n",
+        "\n",
+        "# Plotting step: converts each payload dictionary to JSON, appends it to a\n",
+        "# file, and passes the element through unchanged.\n",
+        "plotting_data_output = (\n",
+        "    \"WritePlotDataAsSideEffect\" >> beam.ParDo(\n",
+        "          WritePlotDataAndPassThrough('plot_data_original.jsonl'))\n",
+        ")\n",
+        "\n",
+        "def format_for_llm(result_tuple):\n",
+        "    \"\"\"\n",
+        "    Takes the output of RunInference (a PredictionResult) and formats it\n",
+        "    into the dictionary structure needed by the LLMClassifierFn.\n",
+        "    \"\"\"\n",
+        "    original_window_data, result_dict = result_tuple\n",
+        "\n",
+        "    list_of_anomalies = result_dict['anomalies']\n",
+        "\n",
+        "    key, (window_start_ts, _, values_array) = original_window_data\n",
+        "\n",
+        "    return (key, {\n",
+        "        'key': key,\n",
+        "        'window_start_ts': window_start_ts,\n",
+        "        'values_array': values_array,\n",
+        "        'anomalies': list_of_anomalies if list_of_anomalies else []\n",
+        "    })\n",
+        "\n",
+        "\n",
+        "data_for_llm = (\n",
+        "    \"FormatForLLM\" >> beam.Map(format_for_llm)\n",
+        "    | \"PrintDataForLLM\" >> beam.Map(print_and_pass_through(\"Data for LLM\"))\n",
+        ")\n",
+        "\n",
+        "\n",
+        "# =================================================================\n",
+        "# Classify with LLM and Create Clean Data for Finetuning\n",
+        "# =================================================================\n",
+        "# Read the Gemini API key from Colab secrets instead of hard-coding it.\n",
+        "api_key = userdata.get('GEMINI_API_KEY')\n",
+        "\n",
+        "llm_classifier = (\n",
+        "    \"LLMClassifierAndImputer\" >> beam.ParDo(\n",
+        "        LLMClassifierFn(\n",
+        "            secret=api_key,\n",
+        "            slide_interval=SLIDE_INTERVAL,\n",
+        "            expected_interval_secs=EXPECTED_INTERVAL\n",
+        "          )\n",
+        "      )\n",
+        "    # | \"PrintLLMResults\" >> beam.Map(print_and_pass_through(\"LLM Results\"))\n",
+        ")\n",
+        "\n",
+        "\n",
+        "# =================================================================\n",
+        "# Batch Clean Data and Trigger Finetuning\n",
+        "# =================================================================\n",
+        "finetuning_job_input = (\n",
+        "    \"KeyForBatching\" >> beam.WithKeys(lambda _: \"finetune_batch\")\n",
+        "    # | \"BatchAndTrigger\" >> beam.ParDo(BatchAndTriggerFinetuningFn(FINETUNING_BATCH_SIZE))\n",
+        "    | \"BatchAndTrigger\" >> beam.ParDo(\n",
+        "        BatchContinuousAndOrderedFn(\n",
+        "            FINETUNING_BATCH_SIZE,\n",
+        "            expected_interval_seconds=EXPECTED_INTERVAL\n",
+        "            )\n",
+        "        )\n",
+        "    | \"PrintFinetuningJobInput\" >> beam.Map(print_and_pass_through(\"Finetuning Job Input\"))\n",
+        ")\n",
+        "\n",
+        "# =================================================================\n",
+        "# Run Finetuning and Save New Model to GCS\n",
+        "# =================================================================\n",
+        "finetuning = (\n",
+        "    \"RunFinetuning\" >> beam.ParDo(\n",
+        "        RunFinetuningFn(\n",
+        "            initial_model_path=INITIAL_MODEL,\n",
+        "            finetuned_model_bucket=FINETUNED_MODEL_BUCKET,\n",
+        "            finetuned_model_prefix=FINETUNED_MODEL_PREFIX,\n",
+        "            hparams=hparams,\n",
+        "            config=FINETUNE_CONFIG\n",
+        "        ),\n",
+        "    )\n",
+        ")\n"
+      ],
+      "metadata": {
+        "id": "Oud4wLTjqy2j",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "1e0cdb8c-16e5-42ce-ef5a-b870677e954d"
+      },
+      "execution_count": 16,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stderr",
+          "text": [
+            "WARNING:apache_beam.transforms.core:('No iterator is returned by the process method in %s.', <class '__main__.LLMClassifierFn'>)\n"
+          ]
+        }
+      ]
+    },
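`CustomJsonEncoder` relies on `json.JSONEncoder.default`, which `json.dumps` calls only for objects it cannot serialize natively. Returning a plain value substitutes it, and delegating to `super().default` raises `TypeError`. A stdlib-only sketch of the same hook, with a hypothetical `FakeTimestamp` class standing in for Beam's `Timestamp` (which also exposes `micros`):

```python
import json


class FakeTimestamp:
    """Hypothetical stand-in for apache_beam.utils.timestamp.Timestamp."""

    def __init__(self, micros: int):
        self.micros = micros


class SecondsEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, FakeTimestamp):
            # Same conversion as CustomJsonEncoder: floor-divided seconds (a float)
            return obj.micros // 1e6
        # Unknown types: let the base class raise TypeError
        return super().default(obj)


if __name__ == "__main__":
    print(json.dumps({"window_start_ts": FakeTimestamp(1_500_000)}, cls=SecondsEncoder))
    # {"window_start_ts": 1.0}
```

Because `//` floors, a timestamp of 1.5 seconds serializes as `1.0`; the synthetic timestamps in this notebook are whole seconds, so nothing is lost there.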
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# Beam Pipeline"
+      ],
+      "metadata": {
+        "id": "ZMx8KhRyvj3Q"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "with beam.Pipeline(options=options) as p:\n",
+        "  (p\n",
+        "  | windowed_data\n",
+        "  | inference_results\n",
+        "  | plotting_data_output # comment out this line if you don't want to save plot data\n",
+        "  | data_for_llm\n",
+        "  | llm_classifier\n",
+        "  | finetuning_job_input\n",
+        "  | finetuning\n",
+        "  )\n"
+      ],
+      "metadata": {
+        "id": "mKa6Qb_1vnNX"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# Plot Data (Original)"
+      ],
+      "metadata": {
+        "id": "O7egp_5Alzz7"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import json\n",
+        "import numpy as np\n",
+        "import matplotlib.pyplot as plt\n",
+        "\n",
+        "CONTEXT_LEN = 512\n",
+        "HORIZON_LEN = 128\n",
+        "\n",
+        "def plot_anomalies_and_forecast(\n",
+        "    values_array,\n",
+        "    all_anomalies,\n",
+        "    all_predicted_values,\n",
+        "    all_q20_values,\n",
+        "    all_q30_values,\n",
+        "    all_q70_values,\n",
+        "    all_q80_values,\n",
+        "    title_suffix=\"\",\n",
+        "    x_lims=None,\n",
+        "    min_outlier_score_for_plot=0,\n",
+        "    context_len=512,\n",
+        "    output_filename=\"plot.png\"\n",
+        "):\n",
+        "    print(f\"Total anomalies: {len(all_anomalies)}\")\n",
+        "    # Anomaly records store their score under 'outlier_score'\n",
+        "    filtered_anomalies = [a for a in all_anomalies if a['outlier_score'] >= min_outlier_score_for_plot]\n",
+        "    # Anomaly positions come from the record's 'timestamp' field\n",
+        "    anomaly_indices = [(a['timestamp'] - HORIZON_LEN) for a in filtered_anomalies]\n",
+        "    anomaly_values = [a['actual_value'] for a in filtered_anomalies]\n",
+        "\n",
+        "    Q1 = np.nanmean([all_q20_values, all_q30_values], axis=0)\n",
+        "    Q3 = np.nanmean([all_q70_values, all_q80_values], axis=0)\n",
+        "    IQR = Q3 - Q1\n",
+        "    upper_thresh = Q3 + 1.5 * IQR\n",
+        "    lower_thresh = Q1 - 1.5 * IQR\n",
+        "\n",
+        "    plt.figure(figsize=(18, 9))\n",
+        "    # Plot the original data, skipping the first context_len points\n",
+        "    plt.plot(values_array[context_len:], label='Original Time Series', color='blue', alpha=0.7, linewidth=1.5)\n",
+        "\n",
+        "    plt.plot(all_predicted_values, label='Predicted Mean', color='green', linestyle='--', linewidth=1.5)\n",
+        "    plt.plot(lower_thresh, label='Lower Threshold', color='orange', linestyle=':', linewidth=1.2)\n",
+        "    plt.plot(upper_thresh, label='Upper Threshold', color='purple', linestyle=':', linewidth=1.2)\n",
+        "\n",
+        "    plt.scatter([i - context_len for i in anomaly_indices], anomaly_values,\n",
+        "                color='red', s=70, zorder=5,\n",
+        "                label=f'Detected Anomalies (Score >= {min_outlier_score_for_plot:.1f})',\n",
+        "                marker='o', edgecolors='black', linewidths=0.8)\n",
+        "\n",
+        "    plt.title(f'Time Series Anomaly Detection {title_suffix}')\n",
+        "    plt.xlabel('Time Index')\n",
+        "    plt.ylabel('Value')\n",
+        "    if x_lims:\n",
+        "        plt.xlim(x_lims[0], x_lims[1])\n",
+        "    plt.legend()\n",
+        "    plt.grid(True, linestyle='--', alpha=0.6)\n",
+        "    plt.tight_layout()\n",
+        "    # plt.savefig(output_filename) # Save the plot to a file\n",
+        "    plt.show()\n",
+        "    plt.close() # Close the figure to free memory\n",
+        "\n",
+        "# --- Main Script Logic ---\n",
+        "\n",
+        "# 1. Read and parse the data from the Beam output file\n",
+        "all_window_data = []\n",
+        "# 'plot_data_original.jsonl' must be in the current working directory\n",
+        "try:\n",
+        "    with open('plot_data_original.jsonl', 'r') as f:\n",
+        "        for line in f:\n",
+        "            # Skip blank lines\n",
+        "            if line.strip():\n",
+        "                all_window_data.append(json.loads(line))\n",
+        "except FileNotFoundError:\n",
+        "    print(\"Error: 'plot_data_original.jsonl' not found. Please make sure the file is in the correct directory.\")\n",
+        "    raise\n",
+        "\n",
+        "\n",
+        "# 2. Sort data by timestamp to ensure the correct order\n",
+        "all_window_data.sort(key=lambda x: x['start_ts_micros'])\n",
+        "\n",
+        "# 3. Reconstruct the full data arrays\n",
+        "all_anomalies = []\n",
+        "all_predicted_values = []\n",
+        "all_q20_values = []\n",
+        "all_q30_values = []\n",
+        "all_q70_values = []\n",
+        "all_q80_values = []\n",
+        "all_actual_horizon_values = []  # Actual values for each horizon (the blue line)\n",
+        "\n",
+        "for window_data in all_window_data:\n",
+        "    all_predicted_values.extend(window_data['predicted_values'])\n",
+        "    all_q20_values.extend(window_data['q20_values'])\n",
+        "    all_q30_values.extend(window_data['q30_values'])\n",
+        "    all_q70_values.extend(window_data['q70_values'])\n",
+        "    all_q80_values.extend(window_data['q80_values'])\n",
+        "    # Actual observed values for this window's horizon, if present\n",
+        "    all_actual_horizon_values.extend(window_data.get('actual_horizon_values', []))\n",
+        "    all_anomalies.extend(window_data.get('anomalies', []))\n",
+        "\n",
+        "# 4. Convert lists to NumPy arrays\n",
+        "all_predicted_values = np.array(all_predicted_values)\n",
+        "all_q20_values = np.array(all_q20_values)\n",
+        "all_q30_values = np.array(all_q30_values)\n",
+        "all_q70_values = np.array(all_q70_values)\n",
+        "all_q80_values = np.array(all_q80_values)\n",
+        "\n",
+        "# 5. Build `values_array` from the actual horizon values in the file\n",
+        "context_len = CONTEXT_LEN\n",
+        "# Prepend a dummy context of length context_len so the array matches the layout\n",
+        "# the plotting function expects (it slices off the first context_len values).\n",
+        "# Repeating the first actual value keeps the context visually seamless.\n",
+        "if all_actual_horizon_values:\n",
+        "    dummy_context = [all_actual_horizon_values[0]] * context_len\n",
+        "    values_array = np.array(dummy_context + all_actual_horizon_values)\n",
+        "else:\n",
+        "    # Fallback in case the file is empty or lacks the 'actual_horizon_values' key\n",
+        "    print(\"Warning: 'actual_horizon_values' not found. The original time series plot will be empty.\")\n",
+        "    total_len = context_len + len(all_predicted_values)\n",
+        "    values_array = np.zeros(total_len)\n",
+        "\n",
+        "# 6. Call the plotting functions\n",
+        "if values_array.any():\n",
+        "    # Plotting function for full graph\n",
+        "    plot_anomalies_and_forecast(\n",
+        "        values_array, all_anomalies, all_predicted_values,\n",
+        "        all_q20_values, all_q30_values, all_q70_values, all_q80_values,\n",
+        "        title_suffix=\"(Full Graph with Correct Data)\",\n",
+        "        min_outlier_score_for_plot=1, # Set a score threshold\n",
+        "        context_len=context_len,\n",
+        "        output_filename=\"full_graph_correct.png\"\n",
+        "    )\n",
+        "\n",
+        "    # Zoomed-in plots; adjust these ranges as needed\n",
+        "    zoom_ranges = [(2000, 2500), (8300, 9000), (9000, 9600)]\n",
+        "    for i, (start_idx, end_idx) in enumerate(zoom_ranges):\n",
+        "        # x limits are horizon indices, because the plotted arrays start after context_len\n",
+        "        plot_x_start = max(0, start_idx)\n",
+        "        plot_x_end = end_idx\n",
+        "\n",
+        "        plot_anomalies_and_forecast(\n",
+        "            values_array, all_anomalies, all_predicted_values,\n",
+        "            all_q20_values, all_q30_values, all_q70_values, all_q80_values,\n",
+        "            title_suffix=f\"(Zoomed In: {start_idx} to {end_idx})\",\n",
+        "            x_lims=(plot_x_start, plot_x_end),\n",
+        "            min_outlier_score_for_plot=1,\n",
+        "            context_len=context_len,\n",
+        "            output_filename=f\"zoomed_graph_correct_{i}.png\"\n",
+        "        )\n",
+        "    print(\"Plots have been generated and saved with the corrected original data.\")\n",
+        "else:\n",
+        "    print(\"No data found to plot.\")"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 1000
+        },
+        "id": "_HzqIoZbl25b",
+        "outputId": "7f85032d-e2cf-4e7c-b9b1-cae9994dee37"
+      },
+      "execution_count": 18,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "948\n"
+          ]
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<Figure size 1800x900 with 1 Axes>"
+            ]
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "948\n"
+          ]
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<Figure size 1800x900 with 1 Axes>"
+            ]
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "948\n"
+          ]
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<Figure size 1800x900 with 1 Axes>"
+            ]
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "948\n"
+          ]
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<Figure size 1800x900 with 1 Axes>"
+            ]
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Plots have been generated and saved with the corrected original data.\n"
+          ]
+        }
+      ]
+    },
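+    {
+      "cell_type": "markdown",
+      "source": [
+        "The plotting cells derive an anomaly band from the TimesFM quantile forecasts. The next cell is a minimal, illustrative sketch on made-up toy arrays (not pipeline output). Like `plot_anomalies_and_forecast`, it assumes that Q1 can be approximated by the mean of the q20/q30 forecasts and Q3 by the mean of the q70/q80 forecasts, and it places Tukey-style fences at `1.5 * IQR`. The helper `iqr_band` is hypothetical and not part of Beam or TimesFM."
+      ],
+      "metadata": {}
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import numpy as np\n",
+        "\n",
+        "def iqr_band(q20, q30, q70, q80, k=1.5):\n",
+        "    # Hypothetical helper: approximate Q1/Q3 from neighbouring quantile forecasts\n",
+        "    q1 = np.nanmean([q20, q30], axis=0)\n",
+        "    q3 = np.nanmean([q70, q80], axis=0)\n",
+        "    iqr = q3 - q1\n",
+        "    # Tukey-style fences: values outside [q1 - k*iqr, q3 + k*iqr] are flagged\n",
+        "    return q1 - k * iqr, q3 + k * iqr\n",
+        "\n",
+        "# Toy forecasts for three time steps (illustrative values only)\n",
+        "q20 = np.array([9.0, 9.0, 9.0])\n",
+        "q30 = np.array([9.5, 9.5, 9.5])\n",
+        "q70 = np.array([10.5, 10.5, 10.5])\n",
+        "q80 = np.array([11.0, 11.0, 11.0])\n",
+        "actual = np.array([10.0, 14.0, 4.0])\n",
+        "\n",
+        "lower, upper = iqr_band(q20, q30, q70, q80)  # lower = 7.0, upper = 13.0 at every step\n",
+        "outside = (actual < lower) | (actual > upper)\n",
+        "print(outside.tolist())  # [False, True, True]"
+      ],
+      "metadata": {},
+      "execution_count": null,
+      "outputs": []
+    },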
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# Finetuned Model Predictions"
+      ],
+      "metadata": {
+        "id": "_KxBniHGqS3M"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import apache_beam as beam\n",
+        "from apache_beam.options.pipeline_options import PipelineOptions\n",
+        "from apache_beam.pvalue import AsDict, AsSingleton\n",
+        "from apache_beam.transforms.periodicsequence import PeriodicImpulse\n",
+        "import logging\n",
+        "import os\n",
+        "import json\n",
+        "import numpy as np\n",
+        "import timesfm\n",
+        "from apache_beam.utils.timestamp import Timestamp\n",
+        "import csv\n",
+        "from apache_beam.ml.inference.base import RunInference\n",
+        "from apache_beam.ml.inference.utils import WatchFilePattern\n",
+        "import typing\n",
+        "from google.colab import userdata\n",
+        "import apache_beam.transforms.window as window\n",
+        "\n",
+        "logging.getLogger().setLevel(logging.INFO)\n",
+        "\n",
+        "# --- Pipeline Configuration ---\n",
+        "PROJECT_ID = os.environ.get(\"GCP_PROJECT\", \"apache-beam-testing\")\n",
+        "REGION = os.environ.get(\"GCP_REGION\", \"us-central1\")\n",
+        "TEMP_LOCATION = \"gs://apache-beam-testing-temp/timesfm_anomaly_detection/temp\"\n",
+        "STAGING_LOCATION = \"gs://apache-beam-testing-temp/timesfm_anomaly_detection/staging\"\n",
+        "FINETUNED_MODEL_BUCKET = \"apache-beam-testing-temp\"\n",
+        "FINETUNED_MODEL_PREFIX = \"timesfm_anomaly_detection/finetuned-models/timesfm/checkpoints\"\n",
+        "\n",
+        "# --- Model & Window Parameters ---\n",
+        "CONTEXT_LEN = 512\n",
+        "HORIZON_LEN = 128\n",
+        "WINDOW_SIZE = CONTEXT_LEN + HORIZON_LEN\n",
+        "SLIDE_INTERVAL = HORIZON_LEN\n",
+        "EXPECTED_INTERVAL = 1\n",
+        "# Path to a fine-tuned checkpoint; pick the one you want from the checkpoints bucket\n",
+        "INITIAL_MODEL = \"gs://apache-beam-testing-temp/timesfm_anomaly_detection/finetuned-models/timesfm/checkpoints/timesfm_finetuned_20250814192006.pth\"\n",
+        "\n",
+        "MODEL_CHECK_INTERVAL_SECONDS = 10  # Check for a new model every 10 seconds\n",
+        "FINETUNING_BATCH_SIZE = 7680  # e.g. 9600 for a larger run; must be at least WINDOW_SIZE for training and validation\n",
+        "FINETUNE_CONFIG = FinetuningConfig(\n",
+        "      batch_size=128,\n",
+        "      num_epochs=5,\n",
+        "      learning_rate=1e-4,\n",
+        "      use_wandb=False,\n",
+        "      freq_type=0, # should change based on your data\n",
+        "      log_every_n_steps=10,\n",
+        "      val_check_interval=0.5,\n",
+        "      use_quantile_loss=True\n",
+        "  )\n",
+        "\n",
+        "\n",
+        "options = PipelineOptions([\n",
+        "    \"--streaming\",\n",
+        "    \"--environment_type=LOOPBACK\",\n",
+        "    \"--runner=PrismRunner\",\n",
+        "    \"--logging_level=INFO\",\n",
+        "    \"--job_server_timeout=3600\"\n",
+        "])\n",
+        "\n",
+        "\n",
+        "\n",
+        "# HParams for the model\n",
+        "hparams = timesfm.TimesFmHparams(\n",
+        "    backend=\"gpu\",\n",
+        "    per_core_batch_size=32,\n",
+        "    horizon_len=HORIZON_LEN,\n",
+        "    context_len=CONTEXT_LEN,\n",
+        ")\n",
+        "model_handler = DynamicTimesFmModelHandler(model_uri=INITIAL_MODEL, hparams=hparams)\n",
+        "\n",
+        "def print_and_pass_through(label):\n",
+        "    def logger(element):\n",
+        "        print(f\"--- {label} ---\\nELEMENT: {element}\")\n",
+        "        return element\n",
+        "    return logger\n",
+        "\n",
+        "\n",
+        "class CustomJsonEncoder(json.JSONEncoder):\n",
+        "    \"\"\"JSON encoder that handles Beam Timestamp objects and NumPy types.\"\"\"\n",
+        "    def default(self, obj):\n",
+        "        # Beam Timestamp -> whole seconds since the Unix epoch (as a float)\n",
+        "        if isinstance(obj, Timestamp):\n",
+        "            return obj.micros // 1e6\n",
+        "        # NumPy integer types\n",
+        "        if isinstance(obj, np.integer):\n",
+        "            return int(obj)\n",
+        "        # NumPy float types (e.g. float32, which json cannot serialize directly)\n",
+        "        if isinstance(obj, np.floating):\n",
+        "            return float(obj)\n",
+        "        # NumPy arrays\n",
+        "        if isinstance(obj, np.ndarray):\n",
+        "            return obj.tolist()\n",
+        "        # Defer to the base class, which raises TypeError for unsupported types\n",
+        "        return super().default(obj)\n",
+        "\n",
+        "class WritePlotDataAndPassThrough(beam.DoFn):\n",
+        "    \"\"\"\n",
+        "    A DoFn that writes plotting data to a file as a side effect\n",
+        "    and then passes the original, unmodified element downstream.\n",
+        "    \"\"\"\n",
+        "    def __init__(self, output_path):\n",
+        "        self._output_path = output_path\n",
+        "        self._file_handle = None\n",
+        "\n",
+        "    def setup(self):\n",
+        "        self._file_handle = open(self._output_path, 'a')\n",
+        "\n",
+        "    def process(self, element):\n",
+        "        _original_window, payload_dict = element\n",
+        "\n",
+        "        # Serialize with the custom encoder so Timestamp and NumPy values are handled\n",
+        "        json_record = json.dumps(payload_dict, cls=CustomJsonEncoder)\n",
+        "        self._file_handle.write(json_record + '\\n')\n",
+        "\n",
+        "        # Pass the original element through unchanged (Timestamp objects intact)\n",
+        "        yield element\n",
+        "\n",
+        "    def teardown(self):\n",
+        "        if self._file_handle:\n",
+        "            self._file_handle.close()\n",
+        "\n",
+        "\n",
+        "# =================================================================\n",
+        "# 1. Get Latest Model Path (Side Input). WatchFilePattern is not\n",
+        "#    currently supported on Prism; uncomment the following to run\n",
+        "#    on Dataflow.\n",
+        "# =================================================================\n",
+        "# model_pattern = os.path.join(\n",
+        "#     f\"gs://{FINETUNED_MODEL_BUCKET}\", FINETUNED_MODEL_PREFIX, \"*.pth\"\n",
+        "# )\n",
+        "\n",
+        "# model_metadata_pcoll = (\n",
+        "#     \"WatchForNewModels\" >> WatchFilePattern(\n",
+        "#         file_pattern=model_pattern,\n",
+        "#         interval=MODEL_CHECK_INTERVAL_SECONDS\n",
+        "#       )\n",
+        "#     | \"PrintModelLocation\" >> beam.Map(print_and_pass_through(\"Model Location\"))\n",
+        "\n",
+        "# )\n",
+        "\n",
+        "# =================================================================\n",
+        "# Ingest and Window Raw Data\n",
+        "# =================================================================\n",
+        "\n",
+        "\n",
+        "windowed_data = (\n",
+        "    PeriodicImpulse(data=input_data, fire_interval=0.01)\n",
+        "    | \"AddKey\" >> beam.WithKeys(lambda x: 0)\n",
+        "    | \"ApplySlidingWindow\" >> beam.ParDo(\n",
+        "        OrderedSlidingWindowFn(window_size=WINDOW_SIZE, slide_interval=SLIDE_INTERVAL))\n",
+        "    | \"FillGaps\" >> beam.ParDo(FillGapsFn(expected_interval=EXPECTED_INTERVAL)).with_output_types(\n",
+        "        typing.Tuple[int, typing.Tuple[Timestamp, Timestamp, typing.List[float]]])\n",
+        "    | \"Skip NaN Values for now\" >> beam.Filter(\n",
+        "      lambda batch: 'NaN' not in batch[1][2])\n",
+        "    | \"PrintWindowedData\" >> beam.Map(print_and_pass_through(\"Windowed Data\"))\n",
+        "\n",
+        ")\n",
+        "\n",
+        "# =================================================================\n",
+        "# Detect Anomalies using the Latest Model\n",
+        "# =================================================================\n",
+        "\n",
+        "inference_results = (\n",
+        "    \"DetectAnomalies\" >> RunInference(\n",
+        "        model_handler=model_handler,\n",
+        "        # model_metadata_pcoll=model_metadata_pcoll\n",
+        "      )\n",
+        "    | \"PrintInference\" >> beam.Map(print_and_pass_through(\"Inference Results\"))\n",
+        ")\n",
+        "\n",
+        "\n",
+        "# Plotting step: serialize each result's payload dictionary to JSON,\n",
+        "# append it to a JSONL file, and pass the element through unchanged.\n",
+        "plotting_data_output = (\n",
+        "    \"WritePlotDataAsSideEffect\" >> beam.ParDo(\n",
+        "          WritePlotDataAndPassThrough('plot_data_finetuned.jsonl'))\n",
+        ")\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "LRRg2QhPfEZr"
+      },
+      "execution_count": 19,
+      "outputs": []
+    },
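+    {
+      "cell_type": "markdown",
+      "source": [
+        "With `SLIDE_INTERVAL = HORIZON_LEN`, the forecast horizons of consecutive windows should tile the series without gaps or overlap. That is what lets the plotting cells concatenate per-window predictions. The next cell is a minimal sketch of that index arithmetic. It considers only full windows and ignores gap filling. `window_spans` is a hypothetical helper, not part of Beam or of `OrderedSlidingWindowFn`."
+      ],
+      "metadata": {}
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "def window_spans(series_len, context_len=512, horizon_len=128):\n",
+        "    # Hypothetical helper: (context, horizon) index ranges for each full window,\n",
+        "    # assuming windows of context_len + horizon_len that slide by horizon_len\n",
+        "    window_size = context_len + horizon_len\n",
+        "    start = 0\n",
+        "    while start + window_size <= series_len:\n",
+        "        yield (range(start, start + context_len),\n",
+        "               range(start + context_len, start + window_size))\n",
+        "        start += horizon_len  # slide interval == horizon length\n",
+        "\n",
+        "spans = list(window_spans(1024))\n",
+        "for ctx, hzn in spans:\n",
+        "    print(f'context {ctx.start}-{ctx.stop - 1}, horizon {hzn.start}-{hzn.stop - 1}')\n",
+        "\n",
+        "# Each horizon starts exactly where the previous one ends\n",
+        "assert all(a[1].stop == b[1].start for a, b in zip(spans, spans[1:]))"
+      ],
+      "metadata": {},
+      "execution_count": null,
+      "outputs": []
+    },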
+    {
+      "cell_type": "code",
+      "source": [
+        "with beam.Pipeline(options=options) as p:\n",
+        "  (p\n",
+        "  | windowed_data\n",
+        "  | inference_results\n",
+        "  | plotting_data_output\n",
+        "  )\n"
+      ],
+      "metadata": {
+        "id": "qAJDAbdGqW_V"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# Plot Data (After Finetuning)"
+      ],
+      "metadata": {
+        "id": "_fT-AjbrUOl5"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import json\n",
+        "import numpy as np\n",
+        "import matplotlib.pyplot as plt\n",
+        "\n",
+        "def plot_anomalies_and_forecast(\n",
+        "    values_array,\n",
+        "    all_anomalies,\n",
+        "    all_predicted_values,\n",
+        "    all_q20_values,\n",
+        "    all_q30_values,\n",
+        "    all_q70_values,\n",
+        "    all_q80_values,\n",
+        "    title_suffix=\"\",\n",
+        "    x_lims=None,\n",
+        "    min_outlier_score_for_plot=0,\n",
+        "    context_len=512,\n",
+        "    output_filename=\"plot.png\"\n",
+        "):\n",
+        "    # Keep only anomalies whose 'outlier_score' meets the plotting threshold\n",
+        "    filtered_anomalies = [a for a in all_anomalies if a['outlier_score'] >= min_outlier_score_for_plot]\n",
+        "    # Map each anomaly's 'timestamp' to an x position (offset by HORIZON_LEN)\n",
+        "    anomaly_indices = [(a['timestamp'] - HORIZON_LEN) for a in filtered_anomalies]\n",
+        "    anomaly_values = [a['actual_value'] for a in filtered_anomalies]\n",
+        "\n",
+        "    Q1 = np.nanmean([all_q20_values, all_q30_values], axis=0)\n",
+        "    Q3 = np.nanmean([all_q70_values, all_q80_values], axis=0)\n",
+        "    IQR = Q3 - Q1\n",
+        "    upper_thresh = Q3 + 1.5 * IQR\n",
+        "    lower_thresh = Q1 - 1.5 * IQR\n",
+        "\n",
+        "    plt.figure(figsize=(18, 9))\n",
+        "    # Plot the actual values for the forecast horizon (the context prefix is skipped)\n",
+        "    plt.plot(values_array[context_len:], label='Original Time Series', color='blue', alpha=0.7, linewidth=1.5)\n",
+        "\n",
+        "    plt.plot(all_predicted_values, label='Predicted Mean', color='green', linestyle='--', linewidth=1.5)\n",
+        "    plt.plot(lower_thresh, label='Lower Threshold', color='orange', linestyle=':', linewidth=1.2)\n",
+        "    plt.plot(upper_thresh, label='Upper Threshold', color='purple', linestyle=':', linewidth=1.2)\n",
+        "\n",
+        "    plt.scatter([i - context_len for i in anomaly_indices], anomaly_values,\n",
+        "                color='red', s=70, zorder=5,\n",
+        "                label=f'Detected Anomalies (Score >= {min_outlier_score_for_plot:.1f})',\n",
+        "                marker='o', edgecolors='black', linewidths=0.8)\n",
+        "                marker='o', edgecolors='black', linewidths=0.8)\n",
+        "\n",
+        "    plt.title(f'Time Series Anomaly Detection {title_suffix}')\n",
+        "    plt.xlabel('Time Index')\n",
+        "    plt.ylabel('Value')\n",
+        "    if x_lims:\n",
+        "        plt.xlim(x_lims[0], x_lims[1])\n",
+        "    plt.legend()\n",
+        "    plt.grid(True, linestyle='--', alpha=0.6)\n",
+        "    plt.tight_layout()\n",
+        "    # plt.savefig(output_filename) # Save the plot to a file\n",
+        "    plt.show()\n",
+        "    plt.close() # Close the figure to free memory\n",
+        "\n",
+        "# --- Main Script Logic ---\n",
+        "\n",
+        "# 1. Read and parse the data from the Beam output file\n",
+        "all_window_data = []\n",
+        "# 'plot_data_finetuned.jsonl' must be in the current working directory\n",
+        "try:\n",
+        "    with open('plot_data_finetuned.jsonl', 'r') as f:\n",
+        "        for line in f:\n",
+        "            # Skip blank lines\n",
+        "            if line.strip():\n",
+        "                all_window_data.append(json.loads(line))\n",
+        "except FileNotFoundError:\n",
+        "    print(\"Error: 'plot_data_finetuned.jsonl' not found. Please make sure the file is in the correct directory.\")\n",
+        "    raise\n",
+        "\n",
+        "\n",
+        "# 2. Sort data by timestamp to ensure the correct order\n",
+        "all_window_data.sort(key=lambda x: x['start_ts_micros'])\n",
+        "\n",
+        "# 3. Reconstruct the full data arrays\n",
+        "all_anomalies = []\n",
+        "all_predicted_values = []\n",
+        "all_q20_values = []\n",
+        "all_q30_values = []\n",
+        "all_q70_values = []\n",
+        "all_q80_values = []\n",
+        "all_actual_horizon_values = []  # Actual values for each horizon (the blue line)\n",
+        "\n",
+        "for window_data in all_window_data:\n",
+        "    all_predicted_values.extend(window_data['predicted_values'])\n",
+        "    all_q20_values.extend(window_data['q20_values'])\n",
+        "    all_q30_values.extend(window_data['q30_values'])\n",
+        "    all_q70_values.extend(window_data['q70_values'])\n",
+        "    all_q80_values.extend(window_data['q80_values'])\n",
+        "    # Actual observed values for this window's horizon, if present\n",
+        "    all_actual_horizon_values.extend(window_data.get('actual_horizon_values', []))\n",
+        "    all_anomalies.extend(window_data.get('anomalies', []))\n",
+        "\n",
+        "# 4. Convert lists to NumPy arrays\n",
+        "all_predicted_values = np.array(all_predicted_values)\n",
+        "print(len(all_predicted_values))\n",
+        "all_q20_values = np.array(all_q20_values)\n",
+        "all_q30_values = np.array(all_q30_values)\n",
+        "all_q70_values = np.array(all_q70_values)\n",
+        "all_q80_values = np.array(all_q80_values)\n",
+        "\n",
+        "# 5. Build `values_array` from the actual horizon values in the file\n",
+        "context_len = CONTEXT_LEN\n",
+        "# Prepend a dummy context of length context_len so the array matches the layout\n",
+        "# the plotting function expects (it slices off the first context_len values).\n",
+        "# Repeating the first actual value keeps the context visually seamless.\n",
+        "if all_actual_horizon_values:\n",
+        "    dummy_context = [all_actual_horizon_values[0]] * context_len\n",
+        "    values_array = np.array(dummy_context + all_actual_horizon_values)\n",
+        "else:\n",
+        "    # Fallback in case the file is empty or lacks the 'actual_horizon_values' key\n",
+        "    print(\"Warning: 'actual_horizon_values' not found. The original time series plot will be empty.\")\n",
+        "    total_len = context_len + len(all_predicted_values)\n",
+        "    values_array = np.zeros(total_len)\n",
+        "\n",
+        "# 6. Call the plotting functions\n",
+        "if values_array.any():\n",
+        "    # Plotting function for full graph\n",
+        "    plot_anomalies_and_forecast(\n",
+        "        values_array, all_anomalies, all_predicted_values,\n",
+        "        all_q20_values, all_q30_values, all_q70_values, all_q80_values,\n",
+        "        title_suffix=\"(Full Graph with Correct Data)\",\n",
+        "        min_outlier_score_for_plot=5, # Set a score threshold\n",
+        "        context_len=context_len,\n",
+        "        output_filename=\"full_graph_correct.png\"\n",
+        "    )\n",
+        "\n",
+        "    # Zoomed-in plots; adjust these ranges as needed\n",
+        "    zoom_ranges = [(2000, 2500), (8300, 9000), (9000, 9600)]\n",
+        "    for i, (start_idx, end_idx) in enumerate(zoom_ranges):\n",
+        "        # x limits are horizon indices, because the plotted arrays start after context_len\n",
+        "        plot_x_start = max(0, start_idx)\n",
+        "        plot_x_end = end_idx\n",
+        "\n",
+        "        plot_anomalies_and_forecast(\n",
+        "            values_array, all_anomalies, all_predicted_values,\n",
+        "            all_q20_values, all_q30_values, all_q70_values, all_q80_values,\n",
+        "            title_suffix=f\"(Zoomed In: {start_idx} to {end_idx})\",\n",
+        "            x_lims=(plot_x_start, plot_x_end),\n",
+        "            min_outlier_score_for_plot=5,\n",
+        "            context_len=context_len,\n",
+        "            output_filename=f\"zoomed_graph_correct_{i}.png\"\n",
+        "        )\n",
+        "    print(\"Plots have been generated and saved with the corrected original data.\")\n",
+        "else:\n",
+        "    print(\"No data found to plot.\")"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 1000
+        },
+        "id": "RZwpgtAr8nlD",
+        "outputId": "d27b08f1-1e77-4109-bfa6-c709edf4b5e8"
+      },
+      "execution_count": 21,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "9600\n"
+          ]
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<Figure size 1800x900 with 1 Axes>"
+            ]
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<Figure size 1800x900 with 1 Axes>"
+            ]
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<Figure size 1800x900 with 1 Axes>"
+            ]
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<Figure size 1800x900 with 1 Axes>"
+            ]
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Plots have been generated and saved with the corrected original data.\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# Compare with RMSE (Original vs Finetuned on Test Data)"
+      ],
+      "metadata": {
+        "id": "WAyaVAEIt4Ey"
+      }
+    },
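+    {
+      "cell_type": "markdown",
+      "source": [
+        "`calculate_model_rmse` below scores each model only on test points that are not flagged as outliers. The next cell is a minimal sketch of that metric on toy arrays. It assumes that a per-point outlier score is available and omits the time-based train/test split. `masked_rmse` is a hypothetical helper, not part of the pipeline."
+      ],
+      "metadata": {}
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import numpy as np\n",
+        "\n",
+        "def masked_rmse(actual, predicted, outlier_scores, outlier_threshold):\n",
+        "    # Hypothetical helper: RMSE over points whose outlier score is below the threshold\n",
+        "    actual = np.asarray(actual, dtype=float)\n",
+        "    predicted = np.asarray(predicted, dtype=float)\n",
+        "    keep = np.asarray(outlier_scores, dtype=float) < outlier_threshold\n",
+        "    if not keep.any():\n",
+        "        return float('nan')  # nothing left to score\n",
+        "    return float(np.sqrt(np.mean((actual[keep] - predicted[keep]) ** 2)))\n",
+        "\n",
+        "# Toy data: the last point has a high outlier score, so it is excluded\n",
+        "print(masked_rmse([1.0, 2.0, 3.0, 100.0], [1.0, 2.0, 5.0, 3.0],\n",
+        "                  [0.0, 0.0, 0.5, 9.0], outlier_threshold=5))  # sqrt(4/3), about 1.1547"
+      ],
+      "metadata": {},
+      "execution_count": null,
+      "outputs": []
+    },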
+    {
+      "cell_type": "code",
+      "source": [
+        "import json\n",
+        "import numpy as np\n",
+        "import pandas as pd\n",
+        "\n",
+        "def calculate_model_rmse(file_path, finetuning_cutoff_ts, 
outlier_threshold):\n",
+        "    \"\"\"\n",
+        "    Loads windowed model predictions, reconstructs the full time series, and\n",
+        "    calculates RMSE on the test set after filtering out outliers.\n",
+        "\n",
+        "    Args:\n",
+        "        file_path (str): Path to the model's prediction data in JSONL format.\n",
+        "        finetuning_cutoff_ts (int): Timestamp (in seconds) separating training and test data; points after it form the test set.\n",
+        "        outlier_threshold (float): Outlier score at or above which to exclude points.\n",
+        "\n",
+        "    Returns:\n",
+        "        float: The calculated Root Mean Squared Error.\n",
+        "    \"\"\"\n",
+        "    all_window_data = []\n",
+        "    with open(file_path, 'r') as f:\n",
+        "        for line in f:\n",
+        "            if line.strip():\n",
+        "                all_window_data.append(json.loads(line))\n",
+        "\n",
+        "    all_window_data.sort(key=lambda x: x['start_ts_micros'])\n",
+        "\n",
+        "    # --- Reconstruct the full time series from the windows ---\n",
+        "    timestamps = []\n",
+        "    all_predicted_values = []\n",
+        "    all_actual_values = []\n",
+        "    all_anomalies = []\n",
+        "\n",
+        "    for window_data in all_window_data:\n",
+        "        # Extend the series lists\n",
+        "        all_predicted_values.extend(window_data['predicted_values'])\n",
+        "        all_actual_values.extend(window_data.get('actual_horizon_values', []))\n",
+        "        all_anomalies.extend(window_data.get('anomalies', []))\n",
+        "\n",
+        "        # Reconstruct one timestamp (in seconds) per predicted point, assuming a 1-second step\n",
+        "        start_ts = window_data['start_ts_micros'] // 1000000\n",
+        "        n_points = len(window_data['predicted_values'])\n",
+        "        timestamps.extend(range(start_ts, start_ts + n_points))\n",
+        "\n",
+        "    # Create a lookup for outlier scores\n",
+        "    outlier_scores_map = {item['timestamp']: item['outlier_score'] for item in all_anomalies}\n",
+        "\n",
+        "    # Truncate all series to a common length so the DataFrame columns line up\n",
+        "    min_len = min(len(timestamps), len(all_predicted_values), len(all_actual_values))\n",
+        "\n",
+        "    # --- Create a DataFrame for easy filtering and calculation ---\n",
+        "    df = pd.DataFrame({\n",
+        "        'timestamp': timestamps[:min_len],\n",
+        "        'actual': all_actual_values[:min_len],\n",
+        "        'predicted': all_predicted_values[:min_len]\n",
+        "    })\n",
+        "    df['outlier_score'] = df['timestamp'].map(outlier_scores_map).fillna(0.0)\n",
+        "\n",
+        "    # 1. Isolate the test set\n",
+        "    df_test = df[df['timestamp'] > finetuning_cutoff_ts].copy()\n",
+        "    print(f\"\\n--- Analyzing: {file_path} ---\")\n",
+        "    print(f\"Test set size (before filtering): {len(df_test)} points\")\n",
+        "\n",
+        "    # 2. Filter out anomalies based on the threshold\n",
+        "    df_filtered = df_test[df_test['outlier_score'] < outlier_threshold]\n",
+        "    num_outliers = len(df_test) - len(df_filtered)\n",
+        "    print(f\"Test set size (after filtering): {len(df_filtered)} points\")\n",
+        "    print(f\"Removed {num_outliers} points with outlier_score >= {outlier_threshold}\")\n",
+        "\n",
+        "    # 3. Calculate RMSE\n",
+        "    y_true = df_filtered['actual']\n",
+        "    y_pred = df_filtered['predicted']\n",
+        "    rmse = np.sqrt(np.mean((y_true - y_pred)**2))\n",
+        "\n",
+        "    return rmse\n",
+        "\n",
+        "# --- Configuration ---\n",
+        "FINETUNING_CUTOFF_TS = 8320\n",
+        "ORIGINAL_OUTLIER_THRESHOLD = 1.0\n",
+        "FINETUNED_OUTLIER_THRESHOLD = 5.0\n",
+        "\n",
+        "# --- Execution & Comparison ---\n",
+        "original_rmse = calculate_model_rmse(\n",
+        "    file_path=\"plot_data_original.jsonl\",\n",
+        "    finetuning_cutoff_ts=FINETUNING_CUTOFF_TS,\n",
+        "    outlier_threshold=ORIGINAL_OUTLIER_THRESHOLD\n",
+        ")\n",
+        "\n",
+        "finetuned_rmse = calculate_model_rmse(\n",
+        "    file_path=\"plot_data_finetuned.jsonl\",\n",
+        "    finetuning_cutoff_ts=FINETUNING_CUTOFF_TS,\n",
+        "    outlier_threshold=FINETUNED_OUTLIER_THRESHOLD\n",
+        ")\n",
+        "\n",
+        "print(\"\\n--- Final Results ---\")\n",
+        "print(f\"Original Model RMSE: {original_rmse:.2f}\")\n",
+        "print(f\"Fine-tuned Model RMSE: {finetuned_rmse:.2f}\")"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "zrMOmc90lQhJ",
+        "outputId": "d0679b35-b375-4068-aa4b-bd22a8c6166e"
+      },
+      "execution_count": 22,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "\n",
+            "--- Analyzing: plot_data_original.jsonl ---\n",
+            "Test set size (before filtering): 1407 points\n",
+            "Test set size (after filtering): 1369 points\n",
+            "Removed 38 points with outlier_score >= 1.0\n",
+            "\n",
+            "--- Analyzing: plot_data_finetuned.jsonl ---\n",
+            "Test set size (before filtering): 1407 points\n",
+            "Test set size (after filtering): 1384 points\n",
+            "Removed 23 points with outlier_score >= 5.0\n",
+            "\n",
+            "--- Final Results ---\n",
+            "Original Model RMSE: 7164.17\n",
+            "Fine-tuned Model RMSE: 3948.44\n"
+          ]
+        }
+      ]
+    }
+  ]
+}
\ No newline at end of file
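The RMSE comparison in the notebook reduces to a masked root-mean-squared error: keep points whose timestamp is after the fine-tuning cutoff and whose outlier score is below the threshold, then compute sqrt(mean((actual - predicted)^2)). The sketch below restates only that step with NumPy arrays, assuming the series are already aligned; `filtered_rmse` and the toy values are hypothetical and not part of the notebook. It also shows why the threshold matters for the comparison: the notebook filters the original model at 1.0 and the fine-tuned model at 5.0, so the two RMSE values are computed over slightly different test points (1369 vs. 1384).

```python
import numpy as np


def filtered_rmse(timestamps, actual, predicted, outlier_scores,
                  cutoff_ts, outlier_threshold):
    """Hypothetical helper: RMSE over test points that are not outliers.

    Mirrors the filtering in calculate_model_rmse: a point is kept when its
    timestamp is strictly after cutoff_ts and its outlier score is below
    outlier_threshold. Assumes all inputs are already aligned and equal length.
    """
    ts = np.asarray(timestamps)
    y_true = np.asarray(actual, dtype=float)
    y_pred = np.asarray(predicted, dtype=float)
    scores = np.asarray(outlier_scores, dtype=float)

    # Boolean mask: in the test set AND not flagged as an outlier.
    mask = (ts > cutoff_ts) & (scores < outlier_threshold)
    if not mask.any():
        raise ValueError("No points remain after filtering.")

    errors = y_true[mask] - y_pred[mask]
    return float(np.sqrt(np.mean(errors ** 2)))


# Made-up toy data; timestamp 4 has a large error and an outlier score of 3.0.
timestamps = [1, 2, 3, 4, 5]
actual = [10.0, 12.0, 11.0, 50.0, 13.0]
predicted = [11.0, 12.0, 10.0, 14.0, 16.0]
scores = [0.0, 0.2, 0.1, 3.0, 0.4]

strict = filtered_rmse(timestamps, actual, predicted, scores,
                       cutoff_ts=1, outlier_threshold=1.0)
loose = filtered_rmse(timestamps, actual, predicted, scores,
                      cutoff_ts=1, outlier_threshold=5.0)
print(f"threshold 1.0: {strict:.2f}")  # threshold 1.0: 1.83
print(f"threshold 5.0: {loose:.2f}")   # threshold 5.0: 18.07
```

With the toy data, raising the threshold from 1.0 to 5.0 keeps the point at timestamp 4, and its single large error raises the RMSE by an order of magnitude, so RMSE values computed under different thresholds are not strictly like-for-like.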
