This is an automated email from the ASF dual-hosted git repository.
shunping pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 9a15261c00d Adding Timesfm anomaly detection colab example (#35844)
9a15261c00d is described below
commit 9a15261c00d44d187460e100e1fdf566960e4235
Author: Ashok Devireddy <[email protected]>
AuthorDate: Thu Aug 14 20:34:35 2025 -0400
Adding Timesfm anomaly detection colab example (#35844)
* Created using Colab
* add license
* added dataset to gcloud
* cleared outputs
* removed commented code
* Delete examples/notebooks/beam-ml/anomaly_detection
/anomaly_detection_timesfm.ipynb
* removed commented code
* changed to beam 2.67.0 and edited file paths
* uncomment plot data line
---
.../anomaly_detection_timesfm.ipynb | 2712 ++++++++++++++++++++
1 file changed, 2712 insertions(+)
diff --git
a/examples/notebooks/beam-ml/anomaly_detection/anomaly_detection_timesfm.ipynb
b/examples/notebooks/beam-ml/anomaly_detection/anomaly_detection_timesfm.ipynb
new file mode 100644
index 00000000000..65a06fd5d8e
--- /dev/null
+++
b/examples/notebooks/beam-ml/anomaly_detection/anomaly_detection_timesfm.ipynb
@@ -0,0 +1,2712 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "gpuType": "T4"
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "code",
+ "source": [
+ "# @title ###### Licensed to the Apache Software Foundation (ASF),
Version 2.0 (the \"License\")\n",
+ "\n",
+ "# Licensed to the Apache Software Foundation (ASF) under one\n",
+ "# or more contributor license agreements. See the NOTICE file\n",
+ "# distributed with this work for additional information\n",
+ "# regarding copyright ownership. The ASF licenses this file\n",
+ "# to you under the Apache License, Version 2.0 (the\n",
+ "# \"License\"); you may not use this file except in compliance\n",
+ "# with the License. You may obtain a copy of the License at\n",
+ "#\n",
+ "# http://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing,\n",
+ "# software distributed under the License is distributed on an\n",
+ "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
+ "# KIND, either express or implied. See the License for the\n",
+ "# specific language governing permissions and limitations\n",
+ "# under the License"
+ ],
+ "metadata": {
+ "id": "eMMlVe_Gukos"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# TimesFM Anomaly Detection Pipeline Diagram\n",
+ "Time series data is a sequence of data points indexed by time, where
each data point is recorded at a specific interval. TimesFM is a foundation
model pretrained on a large corpus of time series data. Its architecture is a
decoder-only transformer, similar to LLMs, which learns to predict the next
part of a time series from previous data. We can use the following pipeline to
detect anomalies in time series data and periodically learn from incoming data
to improve our timesfm predic [...]
+ "\n",
+ "\n",
+ "logging.basicConfig(level=logging.INFO)\n",
+ "_LOGGER.setLevel(logging.INFO)\n",
+ "\n",
+ "\n",
+ "class OrderedSlidingWindowFn(beam.DoFn):\n",
+ "    \"\"\"Buffers keyed elements by event time and emits them as sliding windows.\n",
+ "\n",
+ "    Every element is added to an ordered-list state under its event timestamp.\n",
+ "    A watermark timer marks the end of the current window. When it fires, the\n",
+ "    buffered entries read for the range from `fire_ts - window_size` to\n",
+ "    `fire_ts` are emitted, entries before the next window start are cleared,\n",
+ "    and the timer is re-armed `slide_interval` later while data remains.\n",
+ "\n",
+ "    Output elements have the form\n",
+ "    `(key, (window_start_ts, window_end_ts, [(timestamp, value), ...]))`.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    ORDERED_BUFFER_STATE = OrderedListStateSpec('ordered_buffer', PickleCoder())\n",
+ "    WINDOW_TIMER = TimerSpec('window_timer', TimeDomain.WATERMARK)\n",
+ "    TIMER_STATE = ReadModifyWriteStateSpec('timer_state', BooleanCoder())\n",
+ "    EARLIEST_TS_STATE = ReadModifyWriteStateSpec('earliest_ts', TimestampCoder())\n",
+ "\n",
+ "    def __init__(self, window_size, slide_interval):\n",
+ "        \"\"\"\n",
+ "        Args:\n",
+ "          window_size: Length of each emitted window, in seconds.\n",
+ "          slide_interval: Distance between consecutive window ends, in seconds.\n",
+ "        \"\"\"\n",
+ "        self.window_size = window_size\n",
+ "        self.slide_interval = slide_interval\n",
+ "\n",
+ "    def start_bundle(self):\n",
+ "        _LOGGER.debug(\"start bundle\")\n",
+ "\n",
+ "    def finish_bundle(self):\n",
+ "        _LOGGER.debug(\"finish bundle\")\n",
+ "\n",
+ "    def process(\n",
+ "        self,\n",
+ "        element,\n",
+ "        timestamp=beam.DoFn.TimestampParam,\n",
+ "        ordered_buffer=beam.DoFn.StateParam(ORDERED_BUFFER_STATE),\n",
+ "        window_timer=beam.DoFn.TimerParam(WINDOW_TIMER),\n",
+ "        timer_state=beam.DoFn.StateParam(TIMER_STATE),\n",
+ "        earliest_ts_state=beam.DoFn.StateParam(EARLIEST_TS_STATE)):\n",
+ "        \"\"\"Buffers one `(key, value)` element and arms the window timer if needed.\n",
+ "\n",
+ "        Nothing is emitted here; windows are produced by `on_timer`.\n",
+ "        \"\"\"\n",
+ "        _, value = element\n",
+ "        ordered_buffer.add((timestamp, value))\n",
+ "        _LOGGER.debug(\"receive %s at %s\", element, timestamp)\n",
+ "\n",
+ "        timer_running = timer_state.read()\n",
+ "        known_earliest = earliest_ts_state.read()\n",
+ "\n",
+ "        # Track the earliest buffered timestamp; it is reset to the current\n",
+ "        # element whenever no timer is pending.\n",
+ "        if not timer_running or not known_earliest or known_earliest > timestamp:\n",
+ "            earliest_ts_state.write(timestamp)\n",
+ "\n",
+ "        if not timer_running:\n",
+ "            # Align the first window to a multiple of the slide interval.\n",
+ "            aligned_start = int(\n",
+ "                timestamp.micros / 1e6 // self.slide_interval) * self.slide_interval\n",
+ "            first_window_end_ts = Timestamp.of(aligned_start) + self.window_size\n",
+ "            _LOGGER.debug(\"set timer to %s\", first_window_end_ts)\n",
+ "            window_timer.set(first_window_end_ts)\n",
+ "            timer_state.write(True)\n",
+ "\n",
+ "        return []\n",
+ "\n",
+ "    @on_timer(WINDOW_TIMER)\n",
+ "    def on_timer(\n",
+ "        self,\n",
+ "        key=beam.DoFn.KeyParam,\n",
+ "        fire_ts=beam.DoFn.TimestampParam,\n",
+ "        ordered_buffer=beam.DoFn.StateParam(ORDERED_BUFFER_STATE),\n",
+ "        window_timer=beam.DoFn.TimerParam(WINDOW_TIMER),\n",
+ "        timer_state=beam.DoFn.StateParam(TIMER_STATE),\n",
+ "        earliest_ts_state=beam.DoFn.StateParam(EARLIEST_TS_STATE)):\n",
+ "        \"\"\"Emits the window ending at `fire_ts` and schedules the next one.\"\"\"\n",
+ "        _LOGGER.debug(\"timer fire at %s\", fire_ts)\n",
+ "        window_end_ts = fire_ts\n",
+ "        window_start_ts = window_end_ts - self.window_size\n",
+ "\n",
+ "        window_values = list(\n",
+ "            ordered_buffer.read_range(window_start_ts, window_end_ts))\n",
+ "        _LOGGER.debug(\n",
+ "            \"window start: %s, window end: %s\", window_start_ts, window_end_ts)\n",
+ "        _LOGGER.debug(\"windowed data in buffer %s\", str(window_values))\n",
+ "\n",
+ "        if window_values:\n",
+ "            yield (key, (window_start_ts, window_end_ts, window_values))\n",
+ "\n",
+ "        next_window_start_ts = window_start_ts + self.slide_interval\n",
+ "        next_window_end_ts = fire_ts + self.slide_interval\n",
+ "\n",
+ "        # Drop everything that can no longer belong to a future window.\n",
+ "        ordered_buffer.clear_range(earliest_ts_state.read(), next_window_start_ts)\n",
+ "\n",
+ "        remaining_data = list(\n",
+ "            ordered_buffer.read_range(next_window_start_ts, MAX_TIMESTAMP))\n",
+ "\n",
+ "        if remaining_data:\n",
+ "            _LOGGER.debug(\"set timer to %s\", next_window_end_ts)\n",
+ "            window_timer.set(next_window_end_ts)\n",
+ "        else:\n",
+ "            # Buffer is empty: stop the timer chain until `process` restarts it.\n",
+ "            timer_state.clear()\n",
+ "            earliest_ts_state.write(next_window_start_ts)\n",
+ "\n",
+ "\n",
+ "class FillGapsFn(beam.DoFn):\n",
+ "    \"\"\"Places a window's `(timestamp, value)` pairs on a regular time grid.\n",
+ "\n",
+ "    The grid starts at the window start and advances by `expected_interval`\n",
+ "    seconds until the window end (exclusive). Grid slots with no received\n",
+ "    element are filled with a float NaN so the result stays numeric.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    def __init__(self, expected_interval: float):\n",
+ "        \"\"\"\n",
+ "        Args:\n",
+ "          expected_interval: The expected time delta between elements, in seconds.\n",
+ "            Must be positive.\n",
+ "\n",
+ "        Raises:\n",
+ "          ValueError: If `expected_interval` is not positive (the grid walk in\n",
+ "            `process` would never terminate).\n",
+ "        \"\"\"\n",
+ "        if expected_interval <= 0:\n",
+ "            raise ValueError(\n",
+ "                f\"expected_interval must be positive, got {expected_interval}\")\n",
+ "        self.expected_interval = expected_interval\n",
+ "\n",
+ "    def process(self, element):\n",
+ "        \"\"\"Yields `(key, (window_start_ts, window_end_ts, filled_values))`.\n",
+ "\n",
+ "        `filled_values` is a list of floats, one per grid slot, with NaN where\n",
+ "        no element was received.\n",
+ "        \"\"\"\n",
+ "        key, (window_start_ts, window_end_ts, window_elements) = element\n",
+ "\n",
+ "        # Key by seconds rounded to 5 decimals so grid lookups tolerate float noise.\n",
+ "        received_data = {\n",
+ "            round(float(ts.micros / 1e6), 5): val\n",
+ "            for ts, val in window_elements\n",
+ "        }\n",
+ "\n",
+ "        start_sec = float(window_start_ts.micros / 1e6)\n",
+ "        end_sec = float(window_end_ts.micros / 1e6)\n",
+ "\n",
+ "        filled_values = []\n",
+ "        current_ts_sec = start_sec\n",
+ "\n",
+ "        while current_ts_sec < end_sec:\n",
+ "            lookup_ts = round(current_ts_sec, 5)\n",
+ "\n",
+ "            if lookup_ts in received_data:\n",
+ "                filled_values.append(float(received_data[lookup_ts]))\n",
+ "            else:\n",
+ "                # A real NaN (not the string 'NaN') keeps downstream numpy arrays numeric.\n",
+ "                filled_values.append(float('nan'))\n",
+ "\n",
+ "            current_ts_sec += self.expected_interval\n",
+ "\n",
+ "        yield (key, (window_start_ts, window_end_ts, filled_values))\n"
+ ],
+ "metadata": {
+ "id": "E1fHKPrkuLFW"
+ },
+ "execution_count": 2,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Model
Handler:\n",
+ " \"\"\"A PTransform that finds the latest model checkpoint in a
GCS path.\"\"\"\n",
+ " def __init__(self, gcs_bucket, gcs_prefix):\n",
+ " self.gcs_bucket = gcs_bucket\n",
+ " self.gcs_prefix = gcs_prefix\n",
+ "\n",
+ " def expand(self, pcoll):\n",
+ " return pcoll | \"FindLatestModel\" >>
beam.Map(self._find_latest_model_path)\n",
+ "\n",
+ " def _find_latest_model_path(self, _):\n",
+ " try:\n",
+ " storage_client = storage.Client()\n",
+ " blobs = storage_client.list_blobs(self.gcs_bucket,
prefix=self.gcs_prefix)\n",
+ " # Filter for model files and find the most recent one\n",
+ " model_blobs = [b for b in blobs if
b.name.endswith(\".pth\")]\n",
+ " latest_blob = max(model_blobs, key=lambda b:
b.time_created, default=None)\n",
+ "\n",
+ " if latest_blob:\n",
+ " path =
f\"gs://{self.gcs_bucket}/{latest_blob.name}\"\n",
+ " logging.info(f\"Found latest finetuned model at:
{path}\")\n",
+ " return path\n",
+ " except Exception as e:\n",
+ " logging.error(f\"Error finding latest model in GCS:
{e}\")\n",
+ "\n",
+ " # Return a path to the base model if no finetuned one exists
or an error occurs\n",
+ " base_model = \"google/timesfm-1.0-200m-pytorch\"\n",
+ " logging.info(f\"No finetuned model found. Using base model:
{base_model}\")\n",
+ " return base_model\n",
+ "\n",
+ "class DynamicTimesFmModelHandler(ModelHandler[np.ndarray, np.ndarray, timesfm.TimesFm]):\n",
+ "    \"\"\"\n",
+ "    A model handler that loads a TimesFM model from a dynamic path (GCS or Hugging Face)\n",
+ "    and flags anomalous points in the windows it receives.\n",
+ "\n",
+ "    Each input element is `(key, (window_start_ts, window_end_ts, values))`. The first\n",
+ "    `context_len` values are the forecasting context; the following `horizon_len`\n",
+ "    values are compared with the forecast quantiles to detect outliers.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    # Chunk size used when copying a checkpoint from GCS to local disk.\n",
+ "    _COPY_CHUNK_BYTES = 16 * 1024 * 1024\n",
+ "\n",
+ "    def __init__(self, model_uri: str, hparams):\n",
+ "        \"\"\"\n",
+ "        Args:\n",
+ "            model_uri: A `gs://` checkpoint path or a Hugging Face repo id.\n",
+ "            hparams: TimesFM hyperparameters; `context_len` and `horizon_len`\n",
+ "                are read from it.\n",
+ "        \"\"\"\n",
+ "        self._hparams = hparams\n",
+ "        self._model = None\n",
+ "        self._model_uri = model_uri\n",
+ "        self._context_len = hparams.context_len\n",
+ "        self._horizon_len = hparams.horizon_len\n",
+ "\n",
+ "    def load_model(self) -> timesfm.TimesFm:\n",
+ "        \"\"\"Loads a model from the handler's current model_uri.\n",
+ "\n",
+ "        A `gs://` URI is first copied under `/tmp` and loaded as a local checkpoint;\n",
+ "        any other URI is passed on as a Hugging Face repo id.\n",
+ "\n",
+ "        Raises:\n",
+ "            Exception: Whatever the GCS download raised, re-raised unchanged.\n",
+ "        \"\"\"\n",
+ "        logging.info(f\"Loading TimesFM model from path: {self._model_uri}...\")\n",
+ "\n",
+ "        checkpoint_config = {}\n",
+ "        if self._model_uri.startswith(\"gs://\"):\n",
+ "            try:\n",
+ "                gcs = GcsIO()\n",
+ "                file_name = os.path.basename(self._model_uri)\n",
+ "                local_path = f\"/tmp/{file_name}\"\n",
+ "                with gcs.open(self._model_uri, 'rb') as f_in, open(local_path, 'wb') as f_out:\n",
+ "                    # Copy in chunks so the whole checkpoint is never held in memory.\n",
+ "                    for chunk in iter(lambda: f_in.read(self._COPY_CHUNK_BYTES), b''):\n",
+ "                        f_out.write(chunk)\n",
+ "                checkpoint_config['path'] = local_path\n",
+ "                logging.info(f\"Downloaded model from GCS to {local_path}\")\n",
+ "            except Exception as e:\n",
+ "                logging.error(f\"Failed to download model from GCS: {e}. Check path and permissions.\")\n",
+ "                raise  # Fail fast if the model can't be loaded.\n",
+ "        else:\n",
+ "            checkpoint_config['huggingface_repo_id'] = self._model_uri\n",
+ "\n",
+ "        self._model = timesfm.TimesFm(\n",
+ "            hparams=self._hparams,\n",
+ "            checkpoint=timesfm.TimesFmCheckpoint(**checkpoint_config)\n",
+ "        )\n",
+ "        logging.info(\"TimesFM model loaded successfully.\")\n",
+ "        return self._model\n",
+ "\n",
+ "    def update_model_path(self, model_path: str):\n",
+ "        \"\"\"\n",
+ "        This method is called by RunInference when a new model metadata is available\n",
+ "        from the side input. It updates the model URI that `load_model` will use\n",
+ "        and reloads the model. An empty path is ignored.\n",
+ "        \"\"\"\n",
+ "        if not model_path:\n",
+ "            logging.info(\"Received an empty model path update. No action taken.\")\n",
+ "            return\n",
+ "        logging.info(f\"Received model update. New model URI: {model_path}\")\n",
+ "        self._model_uri = model_path\n",
+ "        self._model = self.load_model()\n",
+ "        logging.info(\"Model has been updated in the handler.\")\n",
+ "\n",
+ "    def run_inference(self, batch, model, inference_args=None):\n",
+ "        \"\"\"\n",
+ "        Runs anomaly detection on every window in the batch.\n",
+ "\n",
+ "        Args:\n",
+ "            batch: Sequence of `(key, (window_start_ts, window_end_ts, values))` elements.\n",
+ "            model: The loaded TimesFM model; its `forecast` method is called per element.\n",
+ "            inference_args: Unused.\n",
+ "\n",
+ "        Returns:\n",
+ "            A list with one `(element, payload)` tuple per input element, in input\n",
+ "            order. `payload` holds the forecast, the quantile bands, the actual\n",
+ "            horizon values and the detected anomalies.\n",
+ "        \"\"\"\n",
+ "        # Every element must be handled: a batch may contain more than one window.\n",
+ "        return [(element, self._detect_anomalies(element, model)) for element in batch]\n",
+ "\n",
+ "    def _detect_anomalies(self, element, model):\n",
+ "        \"\"\"Forecasts the horizon of one window and builds its result payload.\"\"\"\n",
+ "        key, (window_start_ts, _, values_array) = element\n",
+ "\n",
+ "        current_context = np.array(values_array[:self._context_len])\n",
+ "        actual_horizon_values = np.array(\n",
+ "            values_array[self._context_len:self._context_len + self._horizon_len])\n",
+ "        logging.debug(\n",
+ "            \"Context shape: %s, actual horizon shape: %s\",\n",
+ "            current_context.shape, actual_horizon_values.shape)\n",
+ "\n",
+ "        point_forecast, experimental_quantile_forecast = model.forecast(\n",
+ "            [current_context],\n",
+ "            freq=[0],\n",
+ "        )\n",
+ "\n",
+ "        current_predicted_horizon_values = point_forecast[\n",
+ "            0, :, 0] if point_forecast.ndim == 3 else point_forecast[0]\n",
+ "\n",
+ "        # NOTE(review): indices 2, 3, 7 and 8 are taken to be the 20/30/70/80%\n",
+ "        # quantiles, following the variable names - confirm against TimesFM.\n",
+ "        current_q20_values = experimental_quantile_forecast[0, :, 2]\n",
+ "        current_q30_values = experimental_quantile_forecast[0, :, 3]\n",
+ "        current_q70_values = experimental_quantile_forecast[0, :, 7]\n",
+ "        current_q80_values = experimental_quantile_forecast[0, :, 8]\n",
+ "\n",
+ "        anomalies_found = []\n",
+ "        for j, current_actual in enumerate(actual_horizon_values):\n",
+ "            # Approximate the quartiles from the neighbouring forecast quantiles.\n",
+ "            point_q1 = np.nanmean([current_q20_values[j], current_q30_values[j]])\n",
+ "            point_q3 = np.nanmean([current_q70_values[j], current_q80_values[j]])\n",
+ "            point_iqr = point_q3 - point_q1\n",
+ "\n",
+ "            upper_thresh = point_q3 + 1.5 * point_iqr\n",
+ "            lower_thresh = point_q1 - 1.5 * point_iqr\n",
+ "\n",
+ "            if current_actual > upper_thresh:\n",
+ "                excess = current_actual - upper_thresh\n",
+ "            elif current_actual < lower_thresh:\n",
+ "                excess = lower_thresh - current_actual\n",
+ "            else:\n",
+ "                continue\n",
+ "\n",
+ "            # A zero-width band would otherwise divide by zero.\n",
+ "            score = excess / point_iqr if point_iqr != 0 else float('inf')\n",
+ "\n",
+ "            index_in_window = self._context_len + j\n",
+ "            # Adding the index as seconds assumes one point per second - TODO confirm.\n",
+ "            anomaly_timestamp_seconds = (window_start_ts.micros / 1e6) + index_in_window\n",
+ "\n",
+ "            anomalies_found.append({\n",
+ "                'key': key,\n",
+ "                'timestamp': Timestamp(anomaly_timestamp_seconds),\n",
+ "                'index_in_window': index_in_window,\n",
+ "                'actual_value': current_actual,\n",
+ "                'predicted_value': current_predicted_horizon_values[j],\n",
+ "                'is_anomaly': True,\n",
+ "                'outlier_score': score,\n",
+ "                'lower_bound': lower_thresh,\n",
+ "                'upper_bound': upper_thresh,\n",
+ "            })\n",
+ "\n",
+ "        return {\n",
+ "            \"start_ts_micros\": window_start_ts.micros,\n",
+ "            \"predicted_values\": current_predicted_horizon_values.tolist(),\n",
+ "            \"q20_values\": current_q20_values.tolist(),\n",
+ "            \"q30_values\": current_q30_values.tolist(),\n",
+ "            \"q70_values\": current_q70_values.tolist(),\n",
+ "            \"q80_values\": current_q80_values.tolist(),\n",
+ "            \"anomalies\": anomalies_found,\n",
+ "            \"actual_horizon_values\": actual_horizon_values.tolist()\n",
+ "        }"
+ ],
+ "metadata": {
+ "id": "oT9NIaWcuUgb",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "a26eae71-067a-4314-b49a-3b028ee75903"
+ },
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ " See
https://github.com/google-research/timesfm/blob/master/README.md for updated
APIs.\n",
+ "Loaded PyTorch TimesFM, likely because python version is 3.11.13
(main, Jun 4 2025, 08:57:29) [GCC 11.4.0].\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# LLM
Classifier:\n",
+ " \"\"\"Encodes special types like Timestamp and numpy objects into
JSON.\"\"\"\n",
+ " def default(self, obj):\n",
+ " if isinstance(obj, Timestamp):\n",
+ " # Store as a dict with a special key for easy decoding\n",
+ " return {'__timestamp__': True, 'micros': obj.micros}\n",
+ " if isinstance(obj, np.integer):\n",
+ " return int(obj)\n",
+ " if isinstance(obj, np.floating):\n",
+ " return float(obj)\n",
+ " if isinstance(obj, np.ndarray):\n",
+ " return obj.tolist()\n",
+ " return super().default(obj)\n",
+ "\n",
+ "def custom_json_decoder(dct):\n",
+ "    \"\"\"JSON object hook that rebuilds a Timestamp from its tagged-dict form.\n",
+ "\n",
+ "    A dict containing the `__timestamp__` key is replaced by a Timestamp built\n",
+ "    from its `micros` entry; any other dict is returned unchanged.\n",
+ "    \"\"\"\n",
+ "    is_tagged_timestamp = '__timestamp__' in dct\n",
+ "    return Timestamp(micros=dct['micros']) if is_tagged_timestamp else dct\n",
+ "\n",
+ "class JsonCoderWithNumpyAndTimestamp(beam.coders.Coder):\n",
+ "    \"\"\"Beam coder that round-trips values through UTF-8 encoded JSON.\n",
+ "\n",
+ "    Encoding delegates special types to `CustomJsonEncoderForLLM`; decoding\n",
+ "    restores them through `custom_json_decoder`.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    def encode(self, value):\n",
+ "        \"\"\"Serializes `value` to JSON text and returns it as UTF-8 bytes.\"\"\"\n",
+ "        as_text = json.dumps(value, cls=CustomJsonEncoderForLLM)\n",
+ "        return as_text.encode('utf-8')\n",
+ "\n",
+ "    def decode(self, encoded):\n",
+ "        \"\"\"Parses UTF-8 JSON bytes back into Python objects.\"\"\"\n",
+ "        as_text = encoded.decode('utf-8')\n",
+ "        return json.loads(as_text, object_hook=custom_json_decoder)\n",
+ "\n",
+ "    def is_deterministic(self):\n",
+ "        \"\"\"Declares this coder deterministic to Beam.\"\"\"\n",
+ "        return True\n",
+ "\n",
+ "\n",
+ "# It's highly recommended to manage API keys via GCP Secret
Manager\n",
+ "# and access them as environment variables in your Dataflow job.\n",
+ "# genai.configure(api_key=os.environ[\"GEMINI_API_KEY\"])\n",
+ "\n",
+ "class LLMClassifierFn(beam.DoFn):\n",
+ "    \"\"\"\n",
+ "    Buffers windowed values and their anomalies, asks the Gemini model to\n",
+ "    classify each anomaly as KEEP or REMOVE using the surrounding points as\n",
+ "    context, and emits the buffered points with removed anomalies replaced by\n",
+ "    their predicted value.\n",
+ "\n",
+ "    This DoFn is stateful, deferring anomalies that occur too close to\n",
+ "    the end of the buffered data until later data provides enough context.\n",
+ "\n",
+ "    Input elements are `(key, data)` where `data` is a dict with a\n",
+ "    `window_start_ts` entry and optional `anomalies` and `values_array` entries.\n",
+ "    Output elements are dicts with `timestamp` and `value` entries.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    DEFERRED_ANOMALIES_STATE = BagStateSpec(\n",
+ "        'deferred_anomalies', coder=JsonCoderWithNumpyAndTimestamp())\n",
+ "    YIELD_BUFFER_STATE = ReadModifyWriteStateSpec('yield_buffer', PickleCoder())\n",
+ "\n",
+ "    # Fires once the input has been quiet for the grace period set in `process`.\n",
+ "    EXPIRY_TIMER = TimerSpec('expiry', beam.TimeDomain.WATERMARK)\n",
+ "    # Timestamp of the newest point already emitted, so points are emitted once.\n",
+ "    LAST_YIELDED_TIMESTAMP_STATE = ReadModifyWriteStateSpec('last_yielded_ts', PickleCoder())\n",
+ "\n",
+ "    def __init__(self, secret, context_points=25, slide_interval=128, expected_interval_secs=1):\n",
+ "        \"\"\"\n",
+ "        Args:\n",
+ "            secret: API key passed to `genai.configure` in `setup`.\n",
+ "            context_points: Number of points shown to the LLM on each side of an anomaly.\n",
+ "            slide_interval: Slide interval of the upstream windows, in seconds; the\n",
+ "                expiry timer is set two slide intervals after each window start.\n",
+ "            expected_interval_secs: Spacing between consecutive values, in seconds.\n",
+ "        \"\"\"\n",
+ "        self.context_points = context_points\n",
+ "        self._model = None\n",
+ "        self.secret = secret\n",
+ "        self.slide_interval = slide_interval\n",
+ "        self.expected_interval_micros = expected_interval_secs * 1_000_000\n",
+ "\n",
+ "        self._last_window_data = None\n",
+ "\n",
+ "    def setup(self):\n",
+ "        \"\"\"Configures the Gemini client and creates the generative model.\"\"\"\n",
+ "        genai.configure(api_key=self.secret)\n",
+ "        logging.getLogger().setLevel(logging.INFO)\n",
+ "\n",
+ "        generation_config = {\n",
+ "            \"temperature\": 0.2,\n",
+ "            \"top_p\": 1,\n",
+ "            \"top_k\": 1,\n",
+ "            \"max_output_tokens\": 256,\n",
+ "            \"response_mime_type\": \"application/json\",\n",
+ "        }\n",
+ "        # For a full list of safety settings, see the Gemini API documentation\n",
+ "        safety_settings = [\n",
+ "            {\"category\": \"HARM_CATEGORY_HARASSMENT\", \"threshold\": \"BLOCK_NONE\"},\n",
+ "            {\"category\": \"HARM_CATEGORY_HATE_SPEECH\", \"threshold\": \"BLOCK_NONE\"},\n",
+ "        ]\n",
+ "        self._model = genai.GenerativeModel(\n",
+ "            model_name=\"gemini-1.5-flash-latest\",\n",
+ "            generation_config=generation_config,\n",
+ "            safety_settings=safety_settings\n",
+ "        )\n",
+ "        logging.info(\"Gemini Model has been successfully initialized.\")\n",
+ "\n",
+ "    def _build_prompt(self, anomaly_data, context_before, context_after):\n",
+ "        \"\"\"Formats the classification prompt for one anomaly.\n",
+ "\n",
+ "        `context_before` and `context_after` are numpy arrays of the values\n",
+ "        surrounding the anomaly; empty arrays yield zero mean and std.\n",
+ "        \"\"\"\n",
+ "        mean_before = np.mean(context_before) if context_before.size > 0 else 0\n",
+ "        mean_after = np.mean(context_after) if context_after.size > 0 else 0\n",
+ "        std_before = np.std(context_before) if context_before.size > 0 else 0\n",
+ "        std_after = np.std(context_after) if context_after.size > 0 else 0\n",
+ "\n",
+ "        return f\"\"\"\n",
+ " You are an expert time-series analyst classifying an outlier from NYC taxi pickup data.\n",
+ " Normal behavior includes daily and weekly cyclical patterns.\n",
+ "\n",
+ " **1. Outlier Context:**\n",
+ " * **--> The Outlier:**\n",
+ " * **Timestamp:** {Timestamp(micros=anomaly_data['timestamp'].micros)}\n",
+ " * **Actual Value:** {anomaly_data['actual_value']:.2f}\n",
+ " * **Predicted Value:** {anomaly_data['predicted_value']:.2f}\n",
+ " * **Anomaly Upper Bound:** {anomaly_data['upper_bound']:.2f}\n",
+ " * **Anomaly Lower Bound:** {anomaly_data['lower_bound']:.2f}\n",
+ "\n",
+ " **2. Data Surrounding the Outlier:**\n",
+ " * **Data Before ({len(context_before)} points):** {np.round(context_before, 2).tolist()}\n",
+ " * **Data After ({len(context_after)} points):** {np.round(context_after, 2).tolist()}\n",
+ "\n",
+ " **3. Statistical Context:**\n",
+ " * **Mean Before:** {mean_before:.2f}\n",
+ " * **Mean After:** {mean_after:.2f}\n",
+ " * **Std. Dev. Before:** {std_before:.2f}\n",
+ " * **Std. Dev. After:** {std_after:.2f}\n",
+ "\n",
+ " **4. Your Task:**\n",
+ "\n",
+ " **Step 1: Analyze the Evidence.** In a few sentences, describe the behavior of the data *after* the outlier. Does it quickly revert to the \"Predicted Value\" or the \"Mean Before\"? Or does it establish a new level, closer to the \"Mean After\"?\n",
+ "\n",
+ " **Step 2: Make a Decision.** Classify the outlier.\n",
+ " * **REMOVE:** If it's a transient, one-off event. This is likely if the data after the outlier rapidly returns to the established pattern.\n",
+ " * **KEEP:** If it signifies a sustained shift in the pattern that the model should learn from. This is likely if the `Mean After` has shifted significantly.\n",
+ "\n",
+ " **Step 3: Provide Final Output.** Respond with a single JSON object. Do not add any text outside the JSON block.\n",
+ "\n",
+ " {{\n",
+ " \"reasoning_steps\": \"Your analysis from Step 1 goes here.\",\n",
+ " \"decision\": \"KEEP or REMOVE\",\n",
+ " \"confidence_score\": <A float between 0.0 and 1.0>\n",
+ " }}\n",
+ " \"\"\"\n",
+ "\n",
+ "    def _classify(self, anomaly_data, prompt):\n",
+ "        \"\"\"Asks the LLM about one anomaly and returns the upper-cased decision.\n",
+ "\n",
+ "        Any failure (API error, blocked or malformed response) is logged and\n",
+ "        treated as 'KEEP' so one bad response never drops data.\n",
+ "        \"\"\"\n",
+ "        try:\n",
+ "            response = self._model.generate_content(prompt)\n",
+ "            response_data = json.loads(response.text)\n",
+ "            return response_data.get('decision', 'KEEP').strip().upper()\n",
+ "        except Exception as e:\n",
+ "            logging.error(f\"Error processing LLM response for {anomaly_data['timestamp']}: {e}. Defaulting to KEEP.\")\n",
+ "            return 'KEEP'\n",
+ "\n",
+ "    def process(self, element,\n",
+ "                deferred_anomalies=beam.DoFn.StateParam(DEFERRED_ANOMALIES_STATE),\n",
+ "                yield_buffer=beam.DoFn.StateParam(YIELD_BUFFER_STATE),\n",
+ "                expiry_timer=beam.DoFn.TimerParam(EXPIRY_TIMER)):\n",
+ "        \"\"\"Buffers one window's values and anomalies; output happens in the timer.\"\"\"\n",
+ "        _, data = element\n",
+ "        window_start_ts = data['window_start_ts']\n",
+ "\n",
+ "        # Set a timer to fire based on the event time of the current element.\n",
+ "        # Each new element will push the timer forward. The timer will only\n",
+ "        # fire when a gap in the input stream occurs, allowing the buffer\n",
+ "        # to contain data from multiple consecutive sliding windows.\n",
+ "        # We set it far enough ahead to allow the next window's data to arrive.\n",
+ "        grace_period_secs = self.slide_interval * 2\n",
+ "        expiry_timer.set(window_start_ts + grace_period_secs)\n",
+ "\n",
+ "        for anomaly in data.get('anomalies', []):\n",
+ "            deferred_anomalies.add(anomaly)\n",
+ "\n",
+ "        buffer = yield_buffer.read() or {}\n",
+ "        for i, value in enumerate(data.get('values_array', [])):\n",
+ "            point_timestamp = Timestamp(\n",
+ "                micros=window_start_ts.micros + (i * self.expected_interval_micros))\n",
+ "            buffer[point_timestamp] = value\n",
+ "        yield_buffer.write(buffer)\n",
+ "\n",
+ "    @on_timer(EXPIRY_TIMER)\n",
+ "    def on_expiry_timer(\n",
+ "            self,\n",
+ "            deferred_anomalies=beam.DoFn.StateParam(DEFERRED_ANOMALIES_STATE),\n",
+ "            yield_buffer=beam.DoFn.StateParam(YIELD_BUFFER_STATE),\n",
+ "            last_yielded_ts_state=beam.DoFn.StateParam(LAST_YIELDED_TIMESTAMP_STATE)):\n",
+ "        \"\"\"Classifies ready anomalies, emits new points and prunes the buffer.\"\"\"\n",
+ "        all_anomalies_to_consider = list(deferred_anomalies.read())\n",
+ "        buffered_points_map = yield_buffer.read() or {}\n",
+ "\n",
+ "        if not buffered_points_map:\n",
+ "            return\n",
+ "\n",
+ "        sorted_points = sorted(buffered_points_map.items())\n",
+ "        all_timestamps = [ts for ts, _ in sorted_points]\n",
+ "        all_values = [val for _, val in sorted_points]\n",
+ "        # One dict lookup per anomaly instead of a linear scan of the buffer.\n",
+ "        index_by_timestamp = {ts: i for i, ts in enumerate(all_timestamps)}\n",
+ "\n",
+ "        ready_anomalies = []\n",
+ "        final_deferred = []\n",
+ "\n",
+ "        # Build every prompt before any value is imputed so all prompts see\n",
+ "        # the original data.\n",
+ "        for anomaly_data in all_anomalies_to_consider:\n",
+ "            idx_in_full_data = index_by_timestamp.get(anomaly_data['timestamp'])\n",
+ "            has_context_after = (\n",
+ "                idx_in_full_data is not None\n",
+ "                and (idx_in_full_data + self.context_points) < len(all_values))\n",
+ "            if not has_context_after:\n",
+ "                final_deferred.append(anomaly_data)\n",
+ "                continue\n",
+ "\n",
+ "            start_ctx = max(0, idx_in_full_data - self.context_points)\n",
+ "            end_ctx = idx_in_full_data + self.context_points + 1\n",
+ "            context_before = np.array(all_values[start_ctx:idx_in_full_data])\n",
+ "            context_after = np.array(all_values[idx_in_full_data + 1:end_ctx])\n",
+ "\n",
+ "            anomaly_data['index_in_window'] = idx_in_full_data\n",
+ "            prompt = self._build_prompt(anomaly_data, context_before, context_after)\n",
+ "            ready_anomalies.append((anomaly_data, prompt))\n",
+ "\n",
+ "        if ready_anomalies:\n",
+ "            logging.info(f\"Sending {len(ready_anomalies)} prompts to the LLM.\")\n",
+ "\n",
+ "        # One request per anomaly: passing a list of prompts to\n",
+ "        # generate_content sends them as parts of a single request and returns\n",
+ "        # a single response, so the prompts cannot be zipped with the result.\n",
+ "        for anomaly_data, prompt in ready_anomalies:\n",
+ "            if self._classify(anomaly_data, prompt) == 'REMOVE':\n",
+ "                logging.warning(f\"LLM decided to REMOVE anomaly at {anomaly_data['timestamp']}. Imputing value.\")\n",
+ "                all_values[anomaly_data['index_in_window']] = anomaly_data['predicted_value']\n",
+ "\n",
+ "        # Only yield points that are newer than the last batch we yielded.\n",
+ "        # NOTE(review): points of re-deferred anomalies are emitted here with\n",
+ "        # their original value, so a later REMOVE decision for them does not\n",
+ "        # appear to reach the output - confirm this is intended.\n",
+ "        last_yielded_ts = last_yielded_ts_state.read()\n",
+ "        latest_index = None\n",
+ "        for i, ts in enumerate(all_timestamps):\n",
+ "            if last_yielded_ts is None or ts > last_yielded_ts:\n",
+ "                yield {\n",
+ "                    'timestamp': ts,\n",
+ "                    'value': all_values[i]\n",
+ "                }\n",
+ "                latest_index = i\n",
+ "\n",
+ "        if latest_index is not None:\n",
+ "            last_yielded_ts_state.write(all_timestamps[latest_index])\n",
+ "\n",
+ "            # Prune the buffer, keeping `context_points` entries before the last\n",
+ "            # yielded point so they can serve as `context_before` next time.\n",
+ "            context_start_ts = all_timestamps[max(0, latest_index - self.context_points)]\n",
+ "            yield_buffer.write({\n",
+ "                ts: val\n",
+ "                for ts, val in buffered_points_map.items()\n",
+ "                if ts >= context_start_ts\n",
+ "            })\n",
+ "        elif not final_deferred:\n",
+ "            # If we didn't yield anything and we're not deferring anything,\n",
+ "            # the buffer is fully processed and can be cleared.\n",
+ "            yield_buffer.clear()\n",
+ "\n",
+ "        # Re-add anomalies that couldn't be processed to the state so they can\n",
+ "        # be considered in the next firing.\n",
+ "        deferred_anomalies.clear()\n",
+ "        if final_deferred:\n",
+ "            logging.info(f\"Re-deferring {len(final_deferred)} anomalies due to insufficient context.\")\n",
+ "            for anomaly in final_deferred:\n",
+ "                deferred_anomalies.add(anomaly)\n"
+ ],
+ "metadata": {
+ "id": "c55ou9f5vADf"
+ },
+ "execution_count": 4,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Finetuning Component"
+ ],
+ "metadata": {
+ "id": "WSl5lV_9ugQY"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "\"\"\"\n",
+ "TimesFM Finetuner: A flexible framework for finetuning TimesFM models
on custom datasets.\n",
+ "\"\"\"\n",
+ "\n",
+ "import logging\n",
+ "import os\n",
+ "from abc import ABC, abstractmethod\n",
+ "from dataclasses import dataclass, field\n",
+ "from typing import Any, Callable, Dict, List, Optional\n",
+ "\n",
+ "import torch\n",
+ "import torch.distributed as dist\n",
+ "import torch.nn as nn\n",
+ "from torch.nn.parallel import DistributedDataParallel as DDP\n",
+ "from torch.utils.data import DataLoader, Dataset\n",
+ "from timesfm.pytorch_patched_decoder import create_quantiles\n",
+ "\n",
+ "import wandb\n",
+ "\n",
+ "\n",
+ "class MetricsLogger(ABC):\n",
+ " \"\"\"Abstract base class for logging metrics during training.\n",
+ "\n",
+ " This class defines the interface for logging metrics during model
training.\n",
+ " Concrete implementations can log to different backends (e.g.,
WandB, TensorBoard).\n",
+ " \"\"\"\n",
+ "\n",
+ " @abstractmethod\n",
+ " def log_metrics(self,\n",
+ " metrics: Dict[str, Any],\n",
+ " step: Optional[int] = None) -> None:\n",
+ " \"\"\"Log metrics to the specified backend.\n",
+ "\n",
+ " Args:\n",
+ " metrics: Dictionary containing metric names and values.\n",
+ " step: Optional step number or epoch for the metrics.\n",
+ " \"\"\"\n",
+ " pass\n",
+ "\n",
+ " @abstractmethod\n",
+ " def close(self) -> None:\n",
+ " \"\"\"Clean up any resources used by the logger.\"\"\"\n",
+ " pass\n",
+ "\n",
+ "\n",
+ "class WandBLogger(MetricsLogger):\n",
+ "  \"\"\"Metrics logger backed by Weights & Biases.\n",
+ "\n",
+ "  Only the process with rank 0 talks to W&B; on every other rank all\n",
+ "  methods return without doing anything.\n",
+ "\n",
+ "  Args:\n",
+ "    project: Name of the W&B project.\n",
+ "    config: Configuration dictionary attached to the run.\n",
+ "    rank: Process rank in distributed training.\n",
+ "  \"\"\"\n",
+ "\n",
+ "  def __init__(self, project: str, config: Dict[str, Any], rank: int = 0):\n",
+ "    self.rank = rank\n",
+ "    if self.rank != 0:\n",
+ "      return\n",
+ "    wandb.init(project=project, config=config)\n",
+ "\n",
+ "  def log_metrics(self,\n",
+ "                  metrics: Dict[str, Any],\n",
+ "                  step: Optional[int] = None) -> None:\n",
+ "    \"\"\"Send metrics to W&B when running on the main process.\n",
+ "\n",
+ "    Args:\n",
+ "      metrics: Dictionary of metrics to log.\n",
+ "      step: Current training step or epoch.\n",
+ "    \"\"\"\n",
+ "    if self.rank != 0:\n",
+ "      return\n",
+ "    wandb.log(metrics, step=step)\n",
+ "\n",
+ "  def close(self) -> None:\n",
+ "    \"\"\"Finish the W&B run when running on the main process.\"\"\"\n",
+ "    if self.rank != 0:\n",
+ "      return\n",
+ "    wandb.finish()\n",
+ "\n",
+ "\n",
+ "class DistributedManager:\n",
+ " \"\"\"Manages distributed training setup and cleanup.\n",
+ "\n",
+ " Args:\n",
+ " world_size: Total number of processes.\n",
+ " rank: Process rank.\n",
+ " master_addr: Address of the master process.\n",
+ " master_port: Port for distributed communication.\n",
+ " backend: PyTorch distributed backend to use.\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(\n",
+ " self,\n",
+ " world_size: int,\n",
+ " rank: int,\n",
+ " master_addr: str = \"localhost\",\n",
+ " master_port: str = \"12358\",\n",
+ " backend: str = \"nccl\",\n",
+ " ):\n",
+ " self.world_size = world_size\n",
+ " self.rank = rank\n",
+ " self.master_addr = master_addr\n",
+ " self.master_port = master_port\n",
+ " self.backend = backend\n",
+ "\n",
+ " def setup(self) -> None:\n",
+ " \"\"\"Initialize the distributed environment.\"\"\"\n",
+ " os.environ[\"MASTER_ADDR\"] = self.master_addr\n",
+ " os.environ[\"MASTER_PORT\"] = self.master_port\n",
+ "\n",
+ " if not dist.is_initialized():\n",
+ " dist.init_process_group(backend=self.backend,\n",
+ " world_size=self.world_size,\n",
+ " rank=self.rank)\n",
+ "\n",
+ " def cleanup(self) -> None:\n",
+ " \"\"\"Clean up the distributed environment.\"\"\"\n",
+ " if dist.is_initialized():\n",
+ " dist.destroy_process_group()\n",
+ "\n",
+ "\n",
+ "@dataclass\n",
+ "class FinetuningConfig:\n",
+ " \"\"\"Configuration for model training.\n",
+ "\n",
+ " Args:\n",
+ "  batch_size: Number of samples per batch.\n",
+ "  num_epochs: Number of training epochs.\n",
+ "  learning_rate: Initial learning rate.\n",
+ "  weight_decay: L2 regularization factor.\n",
+ "  freq_type: Frequency, can be [0, 1, 2].\n",
+ "  use_quantile_loss: Whether to add the quantile loss terms to the\n",
+ "   training loss.\n",
+ "  quantiles: Quantile levels used when use_quantile_loss is set. When\n",
+ "   None, TimesFMFinetuner falls back to create_quantiles().\n",
+ "  device: Device to train on ('cuda' or 'cpu'). The default is chosen\n",
+ "   once, when the class is defined, from torch.cuda.is_available().\n",
+ "   NOTE(review): TimesFMFinetuner in this cell derives its device from\n",
+ "   `rank` and does not read this field -- confirm.\n",
+ "  distributed: Whether to use distributed training.\n",
+ "  gpu_ids: List of GPU IDs to use.\n",
+ "  master_port: Port for distributed training.\n",
+ "  master_addr: Address for distributed training.\n",
+ "  use_wandb: Whether to use Weights & Biases logging.\n",
+ "  wandb_project: W&B project name.\n",
+ "  log_every_n_steps: Log metrics every N steps (batches); inspired by\n",
+ "   PyTorch Lightning. NOTE(review): not referenced by the trainer code\n",
+ "   in this cell.\n",
+ "  val_check_interval: How often within one training epoch to check val\n",
+ "   metrics (also from PyTorch Lightning). Can be a float in 0.0-1.0\n",
+ "   (fraction of an epoch, e.g. 0.5 = validate twice per epoch) or an\n",
+ "   int (validate every N batches). NOTE(review): not referenced by the\n",
+ "   trainer code in this cell.\n",
+ " \"\"\"\n",
+ "\n",
+ " batch_size: int = 32\n",
+ " num_epochs: int = 20\n",
+ " learning_rate: float = 1e-4\n",
+ " weight_decay: float = 0.01\n",
+ " freq_type: int = 0\n",
+ " use_quantile_loss: bool = False\n",
+ " quantiles: Optional[List[float]] = None\n",
+ " device: str = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+ " distributed: bool = False\n",
+ " gpu_ids: List[int] = field(default_factory=lambda: [0])\n",
+ " master_port: str = \"12358\"\n",
+ " master_addr: str = \"localhost\"\n",
+ " use_wandb: bool = False\n",
+ " wandb_project: str = \"timesfm-finetuning\"\n",
+ " log_every_n_steps: int = 50\n",
+ " val_check_interval: float = 0.5\n",
+ "\n",
+ "\n",
+ "class TimesFMFinetuner:\n",
+ "  \"\"\"Handles model training and validation.\n",
+ "\n",
+ "  Args:\n",
+ "    model: PyTorch model to train.\n",
+ "    config: Training configuration.\n",
+ "    rank: Process rank for distributed training.\n",
+ "    loss_fn: Callable taking (prediction, target) and returning a scalar\n",
+ "      loss; defaults to mean squared error.\n",
+ "    logger: Optional logging.Logger instance.\n",
+ "  \"\"\"\n",
+ "\n",
+ "  def __init__(\n",
+ "      self,\n",
+ "      model: nn.Module,\n",
+ "      config: FinetuningConfig,\n",
+ "      rank: int = 0,\n",
+ "      loss_fn: Optional[Callable] = None,\n",
+ "      logger: Optional[logging.Logger] = None,\n",
+ "  ):\n",
+ "    self.model = model\n",
+ "    self.config = config\n",
+ "    self.rank = rank\n",
+ "    self.logger = logger or logging.getLogger(__name__)\n",
+ "    # NOTE(review): the device is derived from `rank`; `config.device` is\n",
+ "    # not consulted here -- confirm that this is intended.\n",
+ "    self.device = torch.device(\n",
+ "        f\"cuda:{rank}\" if torch.cuda.is_available() else \"cpu\")\n",
+ "    self.loss_fn = loss_fn or (\n",
+ "        lambda x, y: torch.mean((x - y.squeeze(-1))**2))\n",
+ "\n",
+ "    if config.use_wandb:\n",
+ "      self.metrics_logger = WandBLogger(config.wandb_project,\n",
+ "                                        config.__dict__, rank)\n",
+ "\n",
+ "    if config.distributed:\n",
+ "      self.dist_manager = DistributedManager(\n",
+ "          world_size=len(config.gpu_ids),\n",
+ "          rank=rank,\n",
+ "          master_addr=config.master_addr,\n",
+ "          master_port=config.master_port,\n",
+ "      )\n",
+ "      self.dist_manager.setup()\n",
+ "      self.model = self._setup_distributed_model()\n",
+ "\n",
+ "  def _setup_distributed_model(self) -> nn.Module:\n",
+ "    \"\"\"Move the model to this process's device and wrap it in DDP.\"\"\"\n",
+ "    self.model = self.model.to(self.device)\n",
+ "    return DDP(self.model,\n",
+ "               device_ids=[self.config.gpu_ids[self.rank]],\n",
+ "               output_device=self.config.gpu_ids[self.rank])\n",
+ "\n",
+ "  def _create_dataloader(self, dataset: Dataset, is_train: bool) -> DataLoader:\n",
+ "    \"\"\"Create appropriate DataLoader based on training configuration.\n",
+ "\n",
+ "    In distributed mode a DistributedSampler does the sharding (and the\n",
+ "    shuffling for training data), so the DataLoader's own `shuffle` is\n",
+ "    turned off.\n",
+ "\n",
+ "    Args:\n",
+ "      dataset: Dataset to create loader for.\n",
+ "      is_train: Whether this is for training (affects shuffling).\n",
+ "\n",
+ "    Returns:\n",
+ "      DataLoader instance.\n",
+ "    \"\"\"\n",
+ "    if self.config.distributed:\n",
+ "      sampler = torch.utils.data.distributed.DistributedSampler(\n",
+ "          dataset,\n",
+ "          num_replicas=len(self.config.gpu_ids),\n",
+ "          rank=dist.get_rank(),\n",
+ "          shuffle=is_train)\n",
+ "    else:\n",
+ "      sampler = None\n",
+ "\n",
+ "    return DataLoader(\n",
+ "        dataset,\n",
+ "        batch_size=self.config.batch_size,\n",
+ "        shuffle=(is_train and not self.config.distributed),\n",
+ "        sampler=sampler,\n",
+ "    )\n",
+ "\n",
+ "  def _quantile_loss(self, pred: torch.Tensor, actual: torch.Tensor,\n",
+ "                     quantile: float) -> torch.Tensor:\n",
+ "    \"\"\"Calculates quantile loss.\n",
+ "\n",
+ "    Args:\n",
+ "      pred: Predicted values\n",
+ "      actual: Actual values\n",
+ "      quantile: Quantile at which loss is computed\n",
+ "\n",
+ "    Returns:\n",
+ "      Element-wise quantile loss (not reduced).\n",
+ "    \"\"\"\n",
+ "    dev = actual - pred\n",
+ "    loss_first = dev * quantile\n",
+ "    loss_second = -dev * (1.0 - quantile)\n",
+ "    return 2 * torch.where(loss_first >= 0, loss_first, loss_second)\n",
+ "\n",
+ "  def _process_batch(self, batch: List[torch.Tensor]) -> tuple:\n",
+ "    \"\"\"Process a single batch of data.\n",
+ "\n",
+ "    Args:\n",
+ "      batch: List of input tensors, unpacked as\n",
+ "        (x_context, x_padding, freq, x_future).\n",
+ "\n",
+ "    Returns:\n",
+ "      Tuple of (loss, predictions).\n",
+ "    \"\"\"\n",
+ "    x_context, x_padding, freq, x_future = [\n",
+ "        t.to(self.device, non_blocking=True) for t in batch\n",
+ "    ]\n",
+ "\n",
+ "    predictions = self.model(x_context, x_padding.float(), freq)\n",
+ "    # Index 0 of the last axis is used as the mean forecast; only the last\n",
+ "    # patch is compared with the target.\n",
+ "    predictions_mean = predictions[..., 0]\n",
+ "    last_patch_pred = predictions_mean[:, -1, :]\n",
+ "\n",
+ "    loss = self.loss_fn(last_patch_pred, x_future.squeeze(-1))\n",
+ "    if self.config.use_quantile_loss:\n",
+ "      quantiles = self.config.quantiles or create_quantiles()\n",
+ "      for i, quantile in enumerate(quantiles):\n",
+ "        # Quantile i is read from index i + 1 of the last axis.\n",
+ "        last_patch_quantile = predictions[:, -1, :, i + 1]\n",
+ "        loss += torch.mean(\n",
+ "            self._quantile_loss(last_patch_quantile, x_future.squeeze(-1),\n",
+ "                                quantile))\n",
+ "\n",
+ "    return loss, predictions\n",
+ "\n",
+ "  def _average_loss(self, total_loss: float, num_batches: int) -> float:\n",
+ "    \"\"\"Average a summed loss over batches and, if distributed, over ranks.\n",
+ "\n",
+ "    Args:\n",
+ "      total_loss: Sum of the per-batch losses on this process.\n",
+ "      num_batches: Number of batches that were summed.\n",
+ "\n",
+ "    Returns:\n",
+ "      The mean loss, or NaN when `num_batches` is 0 (for example when a\n",
+ "      split is too short to produce a single window).\n",
+ "    \"\"\"\n",
+ "    avg_loss = total_loss / num_batches if num_batches else float(\"nan\")\n",
+ "\n",
+ "    if self.config.distributed:\n",
+ "      # Every rank has to take part in the all-reduce, even with no batches.\n",
+ "      avg_loss_tensor = torch.tensor(avg_loss, device=self.device)\n",
+ "      dist.all_reduce(avg_loss_tensor, op=dist.ReduceOp.SUM)\n",
+ "      avg_loss = (avg_loss_tensor / dist.get_world_size()).item()\n",
+ "\n",
+ "    return avg_loss\n",
+ "\n",
+ "  def _train_epoch(self, train_loader: DataLoader,\n",
+ "                   optimizer: torch.optim.Optimizer) -> float:\n",
+ "    \"\"\"Train for one epoch.\n",
+ "\n",
+ "    Args:\n",
+ "      train_loader: DataLoader for training data.\n",
+ "      optimizer: Optimizer instance.\n",
+ "\n",
+ "    Returns:\n",
+ "      Average training loss for the epoch (NaN if the loader is empty).\n",
+ "    \"\"\"\n",
+ "    self.model.train()\n",
+ "    total_loss = 0.0\n",
+ "\n",
+ "    for batch in train_loader:\n",
+ "      loss, _ = self._process_batch(batch)\n",
+ "\n",
+ "      optimizer.zero_grad()\n",
+ "      loss.backward()\n",
+ "      optimizer.step()\n",
+ "\n",
+ "      total_loss += loss.item()\n",
+ "\n",
+ "    return self._average_loss(total_loss, len(train_loader))\n",
+ "\n",
+ "  def _validate(self, val_loader: DataLoader) -> float:\n",
+ "    \"\"\"Perform validation.\n",
+ "\n",
+ "    Args:\n",
+ "      val_loader: DataLoader for validation data.\n",
+ "\n",
+ "    Returns:\n",
+ "      Average validation loss (NaN if the loader is empty).\n",
+ "    \"\"\"\n",
+ "    self.model.eval()\n",
+ "    total_loss = 0.0\n",
+ "\n",
+ "    with torch.no_grad():\n",
+ "      for batch in val_loader:\n",
+ "        loss, _ = self._process_batch(batch)\n",
+ "        total_loss += loss.item()\n",
+ "\n",
+ "    return self._average_loss(total_loss, len(val_loader))\n",
+ "\n",
+ "  def finetune(self, train_dataset: Dataset,\n",
+ "               val_dataset: Dataset) -> Dict[str, Any]:\n",
+ "    \"\"\"Train the model.\n",
+ "\n",
+ "    The distributed process group and the W&B run (when enabled) are closed\n",
+ "    before this method returns or raises. A KeyboardInterrupt stops training\n",
+ "    early and the history collected so far is returned.\n",
+ "\n",
+ "    Args:\n",
+ "      train_dataset: Training dataset.\n",
+ "      val_dataset: Validation dataset.\n",
+ "\n",
+ "    Returns:\n",
+ "      Dictionary containing training history.\n",
+ "    \"\"\"\n",
+ "    self.model = self.model.to(self.device)\n",
+ "    train_loader = self._create_dataloader(train_dataset, is_train=True)\n",
+ "    val_loader = self._create_dataloader(val_dataset, is_train=False)\n",
+ "\n",
+ "    optimizer = torch.optim.Adam(self.model.parameters(),\n",
+ "                                 lr=self.config.learning_rate,\n",
+ "                                 weight_decay=self.config.weight_decay)\n",
+ "\n",
+ "    history = {\"train_loss\": [], \"val_loss\": [], \"learning_rate\": []}\n",
+ "\n",
+ "    self.logger.info(\n",
+ "        f\"Starting training for {self.config.num_epochs} epochs...\")\n",
+ "    self.logger.info(f\"Training samples: {len(train_dataset)}\")\n",
+ "    self.logger.info(f\"Validation samples: {len(val_dataset)}\")\n",
+ "\n",
+ "    try:\n",
+ "      for epoch in range(self.config.num_epochs):\n",
+ "        if self.config.distributed:\n",
+ "          # Without set_epoch the DistributedSampler reuses the same\n",
+ "          # shuffle order in every epoch.\n",
+ "          train_loader.sampler.set_epoch(epoch)\n",
+ "\n",
+ "        train_loss = self._train_epoch(train_loader, optimizer)\n",
+ "        val_loss = self._validate(val_loader)\n",
+ "        current_lr = optimizer.param_groups[0][\"lr\"]\n",
+ "\n",
+ "        metrics = {\n",
+ "            \"train_loss\": train_loss,\n",
+ "            \"val_loss\": val_loss,\n",
+ "            \"learning_rate\": current_lr,\n",
+ "            \"epoch\": epoch + 1,\n",
+ "        }\n",
+ "\n",
+ "        if self.config.use_wandb:\n",
+ "          self.metrics_logger.log_metrics(metrics)\n",
+ "\n",
+ "        history[\"train_loss\"].append(train_loss)\n",
+ "        history[\"val_loss\"].append(val_loss)\n",
+ "        history[\"learning_rate\"].append(current_lr)\n",
+ "\n",
+ "        if self.rank == 0:\n",
+ "          self.logger.info(\n",
+ "              f\"[Epoch {epoch+1}] Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}\"\n",
+ "          )\n",
+ "\n",
+ "    except KeyboardInterrupt:\n",
+ "      self.logger.info(\"Training interrupted by user\")\n",
+ "    finally:\n",
+ "      # Release the process group and the W&B run even when training fails.\n",
+ "      if self.config.distributed:\n",
+ "        self.dist_manager.cleanup()\n",
+ "\n",
+ "      if self.config.use_wandb:\n",
+ "        self.metrics_logger.close()\n",
+ "\n",
+ "    return {\"history\": history}\n",
+ "\n",
+ "import apache_beam as beam\n",
+ "import logging\n",
+ "import torch\n",
+ "import numpy as np\n",
+ "import timesfm\n",
+ "from os import path\n",
+ "from timesfm import TimesFm, TimesFmCheckpoint, TimesFmHparams\n",
+ "from timesfm.pytorch_patched_decoder import
PatchedTimeSeriesDecoder\n",
+ "from huggingface_hub import snapshot_download\n",
+ "from apache_beam.io.gcp.gcsio import GcsIO # Add this import\n",
+ "\n",
+ "from torch.utils.data import Dataset\n",
+ "from google.cloud import storage\n",
+ "from typing import Tuple\n",
+ "\n",
+ "\n",
+ "class TimeSeriesDataset(Dataset):\n",
+ " \"\"\"Dataset for time series data compatible with
TimesFM.\"\"\"\n",
+ " def __init__(\n",
+ " self,\n",
+ " series: np.ndarray,\n",
+ " context_length: int,\n",
+ " horizon_length: int,\n",
+ " freq_type: int = 0):\n",
+ " \"\"\"\n",
+ " Initialize dataset.\n",
+ "\n",
+ " Args:\n",
+ " series: Time series data\n",
+ " context_length: Number of past timesteps to use as
input\n",
+ " horizon_length: Number of future timesteps to predict\n",
+ " freq_type: Frequency type (0, 1, or 2)\n",
+ " \"\"\"\n",
+ " if freq_type not in [0, 1, 2]:\n",
+ " raise ValueError(\"freq_type must be 0, 1, or 2\")\n",
+ "\n",
+ " self.series = series\n",
+ " self.context_length = context_length\n",
+ " self.horizon_length = horizon_length\n",
+ " self.freq_type = freq_type\n",
+ " self._prepare_samples()\n",
+ "\n",
+ " def _prepare_samples(self) -> None:\n",
+ " \"\"\"Prepare sliding window samples from the time
series.\"\"\"\n",
+ " self.samples = []\n",
+ " total_length = self.context_length + self.horizon_length\n",
+ "\n",
+ " for start_idx in range(0, len(self.series) - total_length +
1):\n",
+ " end_idx = start_idx + self.context_length\n",
+ " x_context = self.series[start_idx:end_idx]\n",
+ " x_future = self.series[end_idx:end_idx +
self.horizon_length]\n",
+ " self.samples.append((x_context, x_future))\n",
+ "\n",
+ " def __len__(self) -> int:\n",
+ " return len(self.samples)\n",
+ "\n",
+ " def __getitem__(\n",
+ " self, index: int\n",
+ " ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor]:\n",
+ " x_context, x_future = self.samples[index]\n",
+ "\n",
+ " x_context = torch.tensor(x_context, dtype=torch.float32)\n",
+ " x_future = torch.tensor(x_future, dtype=torch.float32)\n",
+ "\n",
+ " input_padding = torch.zeros_like(x_context)\n",
+ " freq = torch.tensor([self.freq_type], dtype=torch.long)\n",
+ "\n",
+ " return x_context, input_padding, freq, x_future\n",
+ "\n",
+ "\n",
+ "def prepare_datasets(\n",
+ "    series: np.ndarray,\n",
+ "    context_length: int,\n",
+ "    horizon_length: int,\n",
+ "    freq_type: int = 0,\n",
+ "    train_split: float = 0.8) -> Tuple[Dataset, Dataset]:\n",
+ "  \"\"\"Split a series chronologically and wrap both parts as datasets.\n",
+ "\n",
+ "  The first `train_split` fraction of the points (rounded down) becomes the\n",
+ "  training series and the rest the validation series.\n",
+ "\n",
+ "  Args:\n",
+ "    series: Input time series data\n",
+ "    context_length: Number of past timesteps to use\n",
+ "    horizon_length: Number of future timesteps to predict\n",
+ "    freq_type: Frequency type (0, 1, or 2)\n",
+ "    train_split: Fraction of data to use for training\n",
+ "\n",
+ "  Returns:\n",
+ "    Tuple of (train_dataset, val_dataset)\n",
+ "  \"\"\"\n",
+ "  split_at = int(len(series) * train_split)\n",
+ "\n",
+ "  # Both datasets share the same windowing parameters.\n",
+ "  window_kwargs = dict(\n",
+ "      context_length=context_length,\n",
+ "      horizon_length=horizon_length,\n",
+ "      freq_type=freq_type)\n",
+ "\n",
+ "  train_dataset = TimeSeriesDataset(series[:split_at], **window_kwargs)\n",
+ "  val_dataset = TimeSeriesDataset(series[split_at:], **window_kwargs)\n",
+ "  return train_dataset, val_dataset\n",
+ "\n",
+ "\n",
+ "class BatchContinuousAndOrderedFn(beam.DoFn):\n",
+ "  \"\"\"Stateful DoFn that emits fixed-size batches of evenly spaced points.\n",
+ "\n",
+ "  Incoming points are kept in a state buffer sorted by timestamp. A batch of\n",
+ "  `batch_size` points is emitted only when consecutive timestamps in it are\n",
+ "  exactly `expected_interval_seconds` apart; at the first gap processing\n",
+ "  stops and the remaining points stay buffered.\n",
+ "  \"\"\"\n",
+ "  BUFFER_STATE = ReadModifyWriteStateSpec('buffer', PickleCoder())\n",
+ "\n",
+ "  def __init__(self, batch_size, expected_interval_seconds=1):\n",
+ "    self.batch_size = batch_size\n",
+ "    self.interval = expected_interval_seconds\n",
+ "    # Elements seen by this instance; used to log only every 100th element.\n",
+ "    self.counter = 0\n",
+ "\n",
+ "  def _first_gap(self, points, begin):\n",
+ "    \"\"\"Find the first spacing violation in the batch starting at `begin`.\n",
+ "\n",
+ "    Returns:\n",
+ "      (index, actual_interval) for the first pair of neighbours whose\n",
+ "      spacing differs from `self.interval`, or None if the batch is\n",
+ "      continuous.\n",
+ "    \"\"\"\n",
+ "    for pos in range(begin, begin + self.batch_size - 1):\n",
+ "      earlier = points[pos][0].seconds()\n",
+ "      later = points[pos + 1][0].seconds()\n",
+ "      if later - earlier != self.interval:\n",
+ "        return pos, later - earlier\n",
+ "    return None\n",
+ "\n",
+ "  def process(self, element, buffer=beam.DoFn.StateParam(BUFFER_STATE)):\n",
+ "    \"\"\"Buffer one point and emit every continuous batch now available.\n",
+ "\n",
+ "    Args:\n",
+ "      element: A (key, data) pair where data has 'timestamp' and 'value'.\n",
+ "      buffer: State holding the sorted (timestamp, value) pairs.\n",
+ "\n",
+ "    Yields:\n",
+ "      Lists of {'timestamp', 'value'} dicts, each `batch_size` long.\n",
+ "    \"\"\"\n",
+ "    _, data = element\n",
+ "    point = (data['timestamp'], data['value'])\n",
+ "    self.counter += 1\n",
+ "\n",
+ "    pending = buffer.read() or []\n",
+ "    pending.append(point)\n",
+ "    pending.sort(key=lambda item: item[0])\n",
+ "\n",
+ "    if self.counter % 100 == 0 and pending:\n",
+ "      logging.info(\n",
+ "          f\"Batching buffer now contains {len(pending)} points. \"\n",
+ "          f\"Timestamps range from {pending[0][0]} to {pending[-1][0]}.\"\n",
+ "      )\n",
+ "\n",
+ "    consumed = 0\n",
+ "    while consumed + self.batch_size <= len(pending):\n",
+ "      gap = self._first_gap(pending, consumed)\n",
+ "      if gap is not None:\n",
+ "        # The buffer is sorted, so nothing after a gap can be emitted yet.\n",
+ "        gap_index, actual_interval = gap\n",
+ "        logging.info(\n",
+ "            f\"Gap detected at index {gap_index}. \"\n",
+ "            f\"Timestamp {pending[gap_index][0]} is followed by {pending[gap_index + 1][0]}. \"\n",
+ "            f\"Actual interval: {actual_interval}s, Expected: {self.interval}s. \"\n",
+ "            f\"Waiting for missing data.\"\n",
+ "        )\n",
+ "        break\n",
+ "\n",
+ "      logging.info(f\"Continuous sequence found! Emitting batch of size {self.batch_size} starting at index {consumed}.\")\n",
+ "      batch_end = consumed + self.batch_size\n",
+ "      yield [{'timestamp': ts, 'value': val} for ts, val in pending[consumed:batch_end]]\n",
+ "      consumed = batch_end\n",
+ "\n",
+ "    # Drop whatever was emitted and persist the remainder.\n",
+ "    if consumed > 0:\n",
+ "      pending = pending[consumed:]\n",
+ "\n",
+ "    buffer.write(pending)\n",
+ "\n",
+ "class RunFinetuningFn(beam.DoFn):\n",
+ "  \"\"\"Finetunes TimesFM on a batch of points and uploads the result to GCS.\n",
+ "\n",
+ "  For each batch the DoFn loads the most recently created finetuned\n",
+ "  checkpoint under the configured GCS prefix (or the initial model when\n",
+ "  there is none), finetunes it on the batch, and uploads the new weights.\n",
+ "  \"\"\"\n",
+ "\n",
+ "  def __init__(\n",
+ "      self,\n",
+ "      initial_model_path,\n",
+ "      finetuned_model_bucket,\n",
+ "      finetuned_model_prefix,\n",
+ "      hparams,\n",
+ "      config):\n",
+ "    \"\"\"Store the configuration; the GCS client is created in setup().\n",
+ "\n",
+ "    Args:\n",
+ "      initial_model_path: Model to start from while no finetuned\n",
+ "        checkpoint exists yet.\n",
+ "      finetuned_model_bucket: Name of the GCS bucket for checkpoints.\n",
+ "      finetuned_model_prefix: Object-name prefix for checkpoints.\n",
+ "      hparams: Passed to get_model; `context_len` and `horizon_len` are\n",
+ "        also used to window the batch.\n",
+ "      config: Finetuning configuration passed to TimesFMFinetuner.\n",
+ "    \"\"\"\n",
+ "    self.initial_model_path = initial_model_path\n",
+ "    self.finetuned_model_bucket = finetuned_model_bucket\n",
+ "    self.finetuned_model_prefix = finetuned_model_prefix\n",
+ "    self.hparams = hparams\n",
+ "    self.config = config\n",
+ "    self._storage_client = None\n",
+ "\n",
+ "  def setup(self):\n",
+ "    self._storage_client = storage.Client()\n",
+ "\n",
+ "  def _get_latest_model_from_gcs(self):\n",
+ "    \"\"\"Return the gs:// path of the newest finetuned checkpoint, if any.\n",
+ "\n",
+ "    Only blobs ending in \".pth\" whose name does not contain \"initial\" are\n",
+ "    considered. Returns None when there is no such blob or when the GCS\n",
+ "    query fails (the error is logged).\n",
+ "    \"\"\"\n",
+ "    try:\n",
+ "      bucket = self._storage_client.get_bucket(self.finetuned_model_bucket)\n",
+ "      blobs = list(bucket.list_blobs(prefix=self.finetuned_model_prefix))\n",
+ "\n",
+ "      model_blobs = [b for b in blobs if b.name.endswith(\".pth\") and \"initial\" not in b.name]\n",
+ "\n",
+ "      if not model_blobs:\n",
+ "        return None\n",
+ "\n",
+ "      latest_blob = max(model_blobs, key=lambda b: b.time_created)\n",
+ "      return f\"gs://{self.finetuned_model_bucket}/{latest_blob.name}\"\n",
+ "    except Exception as e:\n",
+ "      # Best effort: the caller falls back to the initial model.\n",
+ "      logging.error(f\"Error querying GCS for the latest model: {e}\")\n",
+ "      return None\n",
+ "\n",
+ "  def process(self, batch_of_data):\n",
+ "    \"\"\"Finetune on one batch and upload the resulting checkpoint.\n",
+ "\n",
+ "    Args:\n",
+ "      batch_of_data: Sequence of dicts with a 'value' entry; the values\n",
+ "        are used in the given order as the training series.\n",
+ "\n",
+ "    Yields:\n",
+ "      The object name (without bucket) of the uploaded checkpoint. Nothing\n",
+ "      is yielded when the batch is too short to form a training window.\n",
+ "    \"\"\"\n",
+ "    import os\n",
+ "    import tempfile\n",
+ "    from datetime import datetime, timezone\n",
+ "\n",
+ "    logging.info(\n",
+ "        f\"Received a batch of {len(batch_of_data)} points for finetuning.\")\n",
+ "\n",
+ "    # If a finetuned model exists, use it. Otherwise, use the initial model.\n",
+ "    latest_model_path = self._get_latest_model_from_gcs()\n",
+ "\n",
+ "    if latest_model_path:\n",
+ "      model_to_load = latest_model_path\n",
+ "      logging.info(f\"Continuously finetuning from latest model: {model_to_load}\")\n",
+ "    else:\n",
+ "      model_to_load = self.initial_model_path\n",
+ "      logging.info(f\"No finetuned model found. Starting from initial model: {model_to_load}\")\n",
+ "\n",
+ "    time_series_values = np.array([d['value'] for d in batch_of_data],\n",
+ "                                  dtype=np.float32)\n",
+ "    train_dataset, val_dataset = prepare_datasets(\n",
+ "        series=time_series_values,\n",
+ "        context_length=self.hparams.context_len,\n",
+ "        horizon_length=self.hparams.horizon_len,\n",
+ "        freq_type=self.config.freq_type,\n",
+ "        train_split=0.8\n",
+ "    )\n",
+ "\n",
+ "    # Log the number of windows rather than dumping every value.\n",
+ "    logging.info(f\"Training dataset size: {len(train_dataset)}\")\n",
+ "    logging.info(f\"Validation dataset size: {len(val_dataset)}\")\n",
+ "\n",
+ "    if len(train_dataset) == 0:\n",
+ "      # Nothing to learn from; avoid uploading an unchanged checkpoint.\n",
+ "      logging.warning(\n",
+ "          f\"Batch of {len(batch_of_data)} points is too short to build a \"\n",
+ "          f\"training window; skipping finetuning.\")\n",
+ "      return\n",
+ "\n",
+ "    # Load the model (initial or latest finetuned).\n",
+ "    model = get_model(\n",
+ "        model_path=model_to_load,\n",
+ "        hparams=self.hparams,\n",
+ "        load_weights=True\n",
+ "    )\n",
+ "\n",
+ "    finetuner = TimesFMFinetuner(model, self.config)\n",
+ "    finetuner.finetune(train_dataset=train_dataset, val_dataset=val_dataset)\n",
+ "\n",
+ "    # Save the weights and upload them under a UTC-timestamped name.\n",
+ "    timestamp_str = datetime.now(timezone.utc).strftime('%Y%m%d%H%M%S')\n",
+ "    model_filename = f\"timesfm_finetuned_{timestamp_str}.pth\"\n",
+ "    blob_path = f\"{self.finetuned_model_prefix}/{model_filename}\"\n",
+ "    bucket = self._storage_client.bucket(self.finetuned_model_bucket)\n",
+ "    # The temporary directory removes the local copy once it is uploaded.\n",
+ "    with tempfile.TemporaryDirectory() as tmp_dir:\n",
+ "      local_path = os.path.join(tmp_dir, model_filename)\n",
+ "      torch.save(model.state_dict(), local_path)\n",
+ "      bucket.blob(blob_path).upload_from_filename(local_path)\n",
+ "    logging.info(\n",
+ "        f\"Successfully uploaded new model to gs://{self.finetuned_model_bucket}/{blob_path}\"\n",
+ "    )\n",
+ "    yield blob_path\n",
+ "\n",
+ "\n",
+ "def get_model(model_path: str, hparams: TimesFmHparams, load_weights: bool = False):\n",
+ "  \"\"\"Loads a TimesFM model from a Hugging Face repo ID or a GCS path.\n",
+ "\n",
+ "  Args:\n",
+ "    model_path: Either a `gs://` URI of a checkpoint file, or anything\n",
+ "      else, which is treated as a Hugging Face repository ID.\n",
+ "    hparams: Hyperparameters passed to TimesFm.\n",
+ "    load_weights: Kept for signature consistency; it is not read, because\n",
+ "      TimesFm handles loading.\n",
+ "\n",
+ "  Returns:\n",
+ "    The `_model` attribute of the constructed TimesFm object.\n",
+ "    NOTE(review): presumably a PatchedTimeSeriesDecoder with weights\n",
+ "    loaded -- confirm against the timesfm version in use.\n",
+ "  \"\"\"\n",
+ "  checkpoint_config = {}\n",
+ "\n",
+ "  if model_path.startswith(\"gs://\"):\n",
+ "    # GCS checkpoint: download to a local file and load from that path.\n",
+ "    import shutil\n",
+ "\n",
+ "    logging.info(f\"Preparing to load model from GCS path: {model_path}\")\n",
+ "    local_temp_path = f\"/tmp/{path.basename(model_path)}\"\n",
+ "    with GcsIO().open(model_path, 'rb') as f_in, open(local_temp_path, 'wb') as f_out:\n",
+ "      # Stream in chunks instead of holding the whole checkpoint in memory.\n",
+ "      shutil.copyfileobj(f_in, f_out)\n",
+ "    # The key for a local file is 'path'.\n",
+ "    checkpoint_config['path'] = local_temp_path\n",
+ "  else:\n",
+ "    logging.info(f\"Preparing to load model from Hugging Face repo: {model_path}\")\n",
+ "    # The key for a Hugging Face repo is 'huggingface_repo_id'.\n",
+ "    checkpoint_config['huggingface_repo_id'] = model_path\n",
+ "\n",
+ "  # This single call handles model configuration and weight loading.\n",
+ "  tfm = TimesFm(\n",
+ "      hparams=hparams,\n",
+ "      checkpoint=TimesFmCheckpoint(**checkpoint_config)\n",
+ "  )\n",
+ "\n",
+ "  logging.info(\"Model loaded successfully inside get_model.\")\n",
+ "\n",
+ "  return tfm._model"
+ ],
+ "metadata": {
+ "id": "IzEE_R3SuwAR"
+ },
+ "execution_count": 11,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Load Time Series Data\n",
+ "\n",
+ "Dataset (NYC taxi traffic, Kaggle): https://www.kaggle.com/datasets/julienjta/nyc-taxi-traffic/data"
+ ],
+ "metadata": {
+ "id": "Lz1OROouy9IV"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import pandas as pd\n",
+ "from google.colab import auth\n",
+ "auth.authenticate_user()\n",
+ "\n",
+ "auth.authenticate_user()\n",
+ "\n",
+ "# Define the path to your file in the GCS bucket\n",
+ "gcs_path =
'gs://apache-beam-samples/anomaly_detection/timesfm-dataset-example/nyc_taxi_timeseries.csv'\n",
+ "\n",
+ "# Read the CSV directly from GCS into a DataFrame\n",
+ "# (pandas is given the gs:// path directly)\n",
+ "df = pd.read_csv(gcs_path)\n",
+ "\n",
+ "# --- Build the (timestamp, value) input for the pipeline ---\n",
+ "\n",
+ "# Convert 'value' column to a numpy array of integers\n",
+ "values_array = pd.to_numeric(df['value'],
errors='coerce').astype(int).to_numpy()\n",
+ "\n",
+ "# Create the list of (timestamp, value) tuples\n",
+ "input_data = []\n",
+ "for i in range(len(values_array)):\n",
+ " input_data.append((Timestamp(i + 1), values_array[i])) # Timestamp is
Beam's apache_beam.utils.timestamp.Timestamp, not pandas.Timestamp\n",
+ "\n",
+ "print(\"DataFrame loaded from GCS:\")\n",
+ "print(df.head())\n",
+ "print(\"\\nInput data created successfully (first 5 entries):\")\n",
+ "print(input_data[:5])"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "N7fAXDjXuDF4",
+ "outputId": "0ea32fa5-221d-44fb-9f83-4f524fd8f3c2"
+ },
+ "execution_count": 6,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "DataFrame loaded from GCS:\n",
+ " Unnamed: 0 timestamp value\n",
+ "0 0 2014-07-01 0:00:00 10844\n",
+ "1 1 2014-07-01 0:30:00 8127\n",
+ "2 2 2014-07-01 1:00:00 6210\n",
+ "3 3 2014-07-01 1:30:00 4656\n",
+ "4 4 2014-07-01 2:00:00 3820\n",
+ "\n",
+ "Input data created successfully (first 5 entries):\n",
+ "[(Timestamp(1), np.int64(10844)), (Timestamp(2), np.int64(8127)),
(Timestamp(3), np.int64(6210)), (Timestamp(4), np.int64(4656)), (Timestamp(5),
np.int64(3820))]\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Beam Pipeline Setup"
+ ],
+ "metadata": {
+ "id": "LiQQF_IxquCK"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import apache_beam as beam\n",
+ "from apache_beam.options.pipeline_options import PipelineOptions\n",
+ "from apache_beam.pvalue import AsDict, AsSingleton\n",
+ "from apache_beam.transforms.periodicsequence import
PeriodicImpulse\n",
+ "import logging\n",
+ "import os\n",
+ "import json\n",
+ "import timesfm\n",
+ "from apache_beam.utils.timestamp import Timestamp\n",
+ "import csv\n",
+ "from apache_beam.ml.inference.base import RunInference\n",
+ "from apache_beam.ml.inference.utils import WatchFilePattern\n",
+ "import typing\n",
+ "from google.colab import userdata\n",
+ "import apache_beam.transforms.window as window\n",
+ "\n",
+ "logging.getLogger().setLevel(logging.INFO)\n",
+ "\n",
+ "# --- Pipeline Configuration ---\n",
+ "PROJECT_ID = os.environ.get(\"GCP_PROJECT\",
\"apache-beam-testing\")\n",
+ "REGION = os.environ.get(\"GCP_REGION\", \"us-central1\")\n",
+ "TEMP_LOCATION =
\"gs://apache-beam-testing-temp/timesfm_anomaly_detection/temp\"\n",
+ "STAGING_LOCATION =
\"gs://apache-beam-testing-temp/timesfm_anomaly_detection/staging\"\n",
+ "FINETUNED_MODEL_BUCKET = \"apache-beam-testing-temp\"\n",
+ "FINETUNED_MODEL_PREFIX =
\"timesfm_anomaly_detection/finetuned-models/timesfm/checkpoints\"\n",
+ "\n",
+ "# --- Model & Window Parameters ---\n",
+ "CONTEXT_LEN = 512\n",
+ "HORIZON_LEN = 128\n",
+ "WINDOW_SIZE = CONTEXT_LEN + HORIZON_LEN\n",
+ "SLIDE_INTERVAL = HORIZON_LEN\n",
+ "EXPECTED_INTERVAL = 1\n",
+ "INITIAL_MODEL = \"google/timesfm-1.0-200m-pytorch\"\n",
+ "\n",
+ "MODEL_CHECK_INTERVAL_SECONDS = 10 # Check for a new model every 10
seconds\n",
+ "FINETUNING_BATCH_SIZE = 7680 # 9600 # make larger later. minimum is
WINDOW_SIZE for validation and training\n",
+ "FINETUNE_CONFIG = FinetuningConfig(\n",
+ " batch_size=128,\n",
+ " num_epochs=5,\n",
+ " learning_rate=1e-4,\n",
+ " use_wandb=False,\n",
+ " freq_type=0, # should change based on your data\n",
+ " log_every_n_steps=10,\n",
+ " val_check_interval=0.5,\n",
+ " use_quantile_loss=True\n",
+ " )\n",
+ "\n",
+ "#Change to Dataflow if needed\n",
+ "options = PipelineOptions([\n",
+ " \"--streaming\",\n",
+ " \"--environment_type=LOOPBACK\",\n",
+ " \"--runner=PrismRunner\",\n",
+ " \"--logging_level=INFO\",\n",
+ " \"--job_server_timeout=3600\"\n",
+ "])\n",
+ "\n",
+ "\n",
+ "\n",
+ "# HParams for the model\n",
+ "hparams = timesfm.TimesFmHparams(\n",
+ " backend=\"gpu\",\n",
+ " per_core_batch_size=32,\n",
+ " horizon_len=HORIZON_LEN,\n",
+ " context_len=CONTEXT_LEN,\n",
+ ")\n",
+ "model_handler = DynamicTimesFmModelHandler(model_uri=INITIAL_MODEL,
hparams=hparams)\n",
+ "\n",
+ "def print_and_pass_through(label):\n",
+ " def logger(element):\n",
+ " print(f\"--- {label} --- \\nELEMENT: %s\", element)\n",
+ " return element\n",
+ " return logger\n",
+ "\n",
+ "\n",
+ "class CustomJsonEncoder(json.JSONEncoder):\n",
+ " \"\"\"A custom JSON encoder that knows how to handle Beam's
Timestamp objects.\"\"\"\n",
+ " def default(self, obj):\n",
+ " if isinstance(obj, Timestamp):\n",
+ " # Convert Timestamp to whole seconds since the epoch (a float,
because the divisor is 1e6)\n",
+ " return obj.micros // 1e6\n",
+ " # Handle NumPy integer types\n",
+ " if isinstance(obj, np.integer):\n",
+ " return int(obj)\n",
+ "\n",
+ " # 3. Handle NumPy float types (e.g. float32, which json cannot
serialize natively)\n",
+ " if isinstance(obj, np.floating):\n",
+ " return float(obj)\n",
+ "\n",
+ " # 4. Handle NumPy arrays\n",
+ " if isinstance(obj, np.ndarray):\n",
+ " return obj.tolist()\n",
+ "\n",
+ " # For all other types, fall back to the default behavior\n",
+ " return super().default(obj)\n",
+ " return json.JSONEncoder.default(self, obj)\n",
+ "\n",
+ "class WritePlotDataAndPassThrough(beam.DoFn):\n",
+ " \"\"\"\n",
+ " A DoFn that writes plotting data to a file as a side effect\n",
+ " and then passes the original, unmodified element downstream.\n",
+ " \"\"\"\n",
+ " def __init__(self, output_path):\n",
+ " self._output_path = output_path\n",
+ " self._file_handle = None\n",
+ "\n",
+ " def setup(self):\n",
+ " self._file_handle = open(self._output_path, 'a')\n",
+ "\n",
+ " def process(self, element):\n",
+ " _original_window, payload_dict = element\n",
+ "\n",
+ " # Use the custom encoder so Timestamp and NumPy values
serialize cleanly\n",
+ " json_record = json.dumps(payload_dict,
cls=CustomJsonEncoder)\n",
+ " self._file_handle.write(json_record + '\\n')\n",
+ "\n",
+ " # Pass the original element through, with the Timestamp
object intact\n",
+ " yield element\n",
+ "\n",
+ " def teardown(self):\n",
+ " if self._file_handle:\n",
+ " self._file_handle.close()\n",
+ "\n",
+ "\n",
+ "#
=================================================================\n",
+ "# 1. Get Latest Model Path (Side Input) - WatchFilePattern is not\n",
+ "# currently supported on Prism. Uncomment the following to run\n",
+ "# on Dataflow\n",
+ "#
=================================================================\n",
+ "# model_pattern = os.path.join(\n",
+ "# f\"gs://{FINETUNED_MODEL_BUCKET}\", FINETUNED_MODEL_PREFIX,
\"*.pth\"\n",
+ "# )\n",
+ "\n",
+ "# model_metadata_pcoll = (\n",
+ "# \"WatchForNewModels\" >> WatchFilePattern(\n",
+ "# file_pattern=model_pattern,\n",
+ "# interval=MODEL_CHECK_INTERVAL_SECONDS\n",
+ "# )\n",
+ "# | \"PrintModelLocation\" >>
beam.Map(print_and_pass_through(\"Model Location\"))\n",
+ "\n",
+ "# )\n",
+ "\n",
+ "#
=================================================================\n",
+ "# Ingest and Window Raw Data\n",
+ "#
=================================================================\n",
+ "\n",
+ "\n",
+ "windowed_data = (\n",
+ " PeriodicImpulse(data=input_data, fire_interval=0.01)\n",
+ " | \"AddKey\" >> beam.WithKeys(lambda x: 0)\n",
+ " | \"ApplySlidingWindow\" >> beam.ParDo(\n",
+ " OrderedSlidingWindowFn(window_size=WINDOW_SIZE,
slide_interval=SLIDE_INTERVAL))\n",
+ " | \"FillGaps\" >>
beam.ParDo(FillGapsFn(expected_interval=EXPECTED_INTERVAL)).with_output_types(\n",
+ " typing.Tuple[int, typing.Tuple[Timestamp, Timestamp,
typing.List[float]]])\n",
+ " | \"Skip NaN Values for now\" >> beam.Filter(\n",
+ " lambda batch: 'NaN' not in batch[1][2])\n",
+ " | \"PrintWindowedData\" >>
beam.Map(print_and_pass_through(\"Windowed Data\"))\n",
+ "\n",
+ ")\n",
+ "\n",
+ "#
=================================================================\n",
+ "# Detect Anomalies using the Latest Model\n",
+ "#
=================================================================\n",
+ "\n",
+ "inference_results = (\n",
+ " \"DetectAnomalies\" >> RunInference(\n",
+ " model_handler=model_handler,\n",
+ " # model_metadata_pcoll=model_metadata_pcoll\n",
+ " )\n",
+ " | \"PrintInference\" >>
beam.Map(print_and_pass_through(\"Inference Results\"))\n",
+ ")\n",
+ "\n",
+ "\n",
+ "# NEW BRANCH: For plotting. It takes the payload dictionary,
converts\n",
+ "# it to JSON, and writes it to a file.\n",
+ "plotting_data_output = (\n",
+ " \"WritePlotDataAsSideEffect\" >> beam.ParDo(\n",
+ " WritePlotDataAndPassThrough('plot_data_original.jsonl'))\n",
+ ")\n",
+ "\n",
+ "def format_for_llm(result_tuple):\n",
+ " \"\"\"\n",
+ " Takes an (original_window_data, result_dict) tuple from the inference
step and formats it\n",
+ " into the dictionary structure needed by the LLMClassifierFn.\n",
+ " \"\"\"\n",
+ " original_window_data, result_dict = result_tuple\n",
+ "\n",
+ " list_of_anomalies = result_dict['anomalies']\n",
+ "\n",
+ " key, (window_start_ts, _, values_array) = original_window_data\n",
+ "\n",
+ " return (key, {\n",
+ " 'key': key,\n",
+ " 'window_start_ts': window_start_ts,\n",
+ " 'values_array': values_array,\n",
+ " 'anomalies': list_of_anomalies if list_of_anomalies else
[]\n",
+ " })\n",
+ "\n",
+ "\n",
+ "data_for_llm = (\n",
+ " \"FormatForLLM\" >> beam.Map(format_for_llm)\n",
+ " | \"PrintDataForLLM\" >> beam.Map(print_and_pass_through(\"Data
for LLM\"))\n",
+ ")\n",
+ "\n",
+ "\n",
+ "#
=================================================================\n",
+ "# Classify with LLM and Create Clean Data for Finetuning\n",
+ "#
=================================================================\n",
+ "api_key = userdata.get('GEMINI_API_KEY')\n",
+ "\n",
+ "llm_classifier = (\n",
+ " \"LLMClassifierAndImputer\" >> beam.ParDo(\n",
+ " LLMClassifierFn(\n",
+ " secret=api_key,\n",
+ " slide_interval=SLIDE_INTERVAL,\n",
+ " expected_interval_secs=EXPECTED_INTERVAL\n",
+ " )\n",
+ " )\n",
+ " # | \"PrintLLMResults\" >> beam.Map(print_and_pass_through(\"LLM
Results\"))\n",
+ ")\n",
+ "\n",
+ "\n",
+ "# #
=================================================================\n",
+ "# # Batch Clean Data and Trigger Finetuning\n",
+ "# #
=================================================================\n",
+ "finetuning_job_input = (\n",
+ " \"KeyForBatching\" >> beam.WithKeys(lambda _:
\"finetune_batch\")\n",
+ " # | \"BatchAndTrigger\" >>
beam.ParDo(BatchAndTriggerFinetuningFn(FINETUNING_BATCH_SIZE))\n",
+ " | \"BatchAndTrigger\" >> beam.ParDo(\n",
+ " BatchContinuousAndOrderedFn(\n",
+ " FINETUNING_BATCH_SIZE,\n",
+ " expected_interval_seconds=1\n",
+ " )\n",
+ " )\n",
+ " | \"PrintFinetuningJobInput\" >>
beam.Map(print_and_pass_through(\"Finetuning Job Input\"))\n",
+ ")\n",
+ "\n",
+ "# #
=================================================================\n",
+ "# # Run Finetuning and Save New Model to GCS\n",
+ "# #
=================================================================\n",
+ "finetuning = (\n",
+ " \"RunFinetuning\" >> beam.ParDo(\n",
+ " RunFinetuningFn(\n",
+ "
initial_model_path=\"google/timesfm-1.0-200m-pytorch\",\n",
+ " finetuned_model_bucket=FINETUNED_MODEL_BUCKET,\n",
+ " finetuned_model_prefix=FINETUNED_MODEL_PREFIX,\n",
+ " hparams=hparams,\n",
+ " config=FINETUNE_CONFIG\n",
+ " ),\n",
+ " )\n",
+ ")\n"
+ ],
+ "metadata": {
+ "id": "Oud4wLTjqy2j",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "1e0cdb8c-16e5-42ce-ef5a-b870677e954d"
+ },
+ "execution_count": 16,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "WARNING:apache_beam.transforms.core:('No iterator is returned by
the process method in %s.', <class '__main__.LLMClassifierFn'>)\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Beam Pipeline"
+ ],
+ "metadata": {
+ "id": "ZMx8KhRyvj3Q"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "with beam.Pipeline(options=options) as p:\n",
+ " (p\n",
+ " | windowed_data\n",
+ " | inference_results\n",
+ " | plotting_data_output # comment this line if you don't want to save
plot data\n",
+ " | data_for_llm\n",
+ " | llm_classifier\n",
+ " | finetuning_job_input\n",
+ " | finetuning\n",
+ " )\n"
+ ],
+ "metadata": {
+ "id": "mKa6Qb_1vnNX"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Plot Data (Original)"
+ ],
+ "metadata": {
+ "id": "O7egp_5Alzz7"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import json\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "CONTEXT_LEN = 512\n",
+ "HORIZON_LEN = 128\n",
+ "\n",
+ "def plot_anomalies_and_forecast(\n",
+ " values_array,\n",
+ " all_anomalies,\n",
+ " all_predicted_values,\n",
+ " all_q20_values,\n",
+ " all_q30_values,\n",
+ " all_q70_values,\n",
+ " all_q80_values,\n",
+ " title_suffix=\"\",\n",
+ " x_lims=None,\n",
+ " min_outlier_score_for_plot=0,\n",
+ " context_len=512,\n",
+ " output_filename=\"plot.png\"\n",
+ "):\n",
+ " print(len(all_anomalies))\n",
+ " # The key from your file is 'outlier_score'\n",
+ " filtered_anomalies = [a for a in all_anomalies if
a['outlier_score'] >= min_outlier_score_for_plot]\n",
+ " # The key from your file is 'timestamp'\n",
+ " anomaly_indices = [(a['timestamp'] - HORIZON_LEN) for a in
filtered_anomalies]\n",
+ " anomaly_values = [a['actual_value'] for a in
filtered_anomalies]\n",
+ "\n",
+ " Q1 = np.nanmean([all_q20_values, all_q30_values], axis=0)\n",
+ " Q3 = np.nanmean([all_q70_values, all_q80_values], axis=0)\n",
+ " IQR = Q3 - Q1\n",
+ " upper_thresh = Q3 + 1.5 * IQR\n",
+ " lower_thresh = Q1 - 1.5 * IQR\n",
+ "\n",
+ " plt.figure(figsize=(18, 9))\n",
+ " # This now plots the correct original data for the horizon\n",
+ " plt.plot(values_array[context_len:], label='Original Time
Series', color='blue', alpha=0.7, linewidth=1.5)\n",
+ "\n",
+ " plt.plot(all_predicted_values, label='Predicted Mean',
color='green', linestyle='--', linewidth=1.5)\n",
+ " plt.plot(lower_thresh, label='Lower Threshold', color='orange',
linestyle=':', linewidth=1.2)\n",
+ " plt.plot(upper_thresh, label='Upper Threshold', color='purple',
linestyle=':', linewidth=1.2)\n",
+ "\n",
+ " plt.scatter([i - context_len for i in anomaly_indices],
anomaly_values,\n",
+ " color='red', s=70, zorder=5,\n",
+ " label=f'Detected Anomalies (Score >=
{min_outlier_score_for_plot:.1f})',\n",
+ " marker='o', edgecolors='black', linewidths=0.8)\n",
+ "\n",
+ " plt.title(f'Time Series Anomaly Detection {title_suffix}')\n",
+ " plt.xlabel('Time Index')\n",
+ " plt.ylabel('Value')\n",
+ " if x_lims:\n",
+ " plt.xlim(x_lims[0], x_lims[1])\n",
+ " plt.legend()\n",
+ " plt.grid(True, linestyle='--', alpha=0.6)\n",
+ " plt.tight_layout()\n",
+ " # plt.savefig(output_filename) # Save the plot to a file\n",
+ " plt.show()\n",
+ " plt.close() # Close the figure to free memory\n",
+ "\n",
+ "# --- Main Script Logic ---\n",
+ "\n",
+ "# 1. Read and parse the data from the Beam output file\n",
+ "all_window_data = []\n",
+ "# Make sure 'plot_data_original.jsonl' is in the same directory as this
notebook\n",
+ "try:\n",
+ " with open('plot_data_original.jsonl', 'r') as f:\n",
+ " for line in f:\n",
+ " # Check for empty lines that might have been added\n",
+ " if line.strip():\n",
+ " all_window_data.append(json.loads(line))\n",
+ "except FileNotFoundError:\n",
+ " print(\"Error: 'plot_data_original.jsonl' not found. Please make sure the
file is in the correct directory.\")\n",
+ " exit()\n",
+ "\n",
+ "\n",
+ "# 2. Sort data by timestamp to ensure the correct order\n",
+ "all_window_data.sort(key=lambda x: x['start_ts_micros'])\n",
+ "\n",
+ "# 3. Reconstruct the full data arrays\n",
+ "all_anomalies = []\n",
+ "all_predicted_values = []\n",
+ "all_q20_values = []\n",
+ "all_q30_values = []\n",
+ "all_q70_values = []\n",
+ "all_q80_values = []\n",
+ "all_actual_horizon_values = [] # This will hold the real \"blue
line\" data\n",
+ "\n",
+ "for window_data in all_window_data:\n",
+ " all_predicted_values.extend(window_data['predicted_values'])\n",
+ " all_q20_values.extend(window_data['q20_values'])\n",
+ " all_q30_values.extend(window_data['q30_values'])\n",
+ " all_q70_values.extend(window_data['q70_values'])\n",
+ " all_q80_values.extend(window_data['q80_values'])\n",
+ " # Populate the list with the actual values from the file\n",
+ "
all_actual_horizon_values.extend(window_data.get('actual_horizon_values',
[]))\n",
+ " all_anomalies.extend(window_data.get('anomalies', []))\n",
+ "\n",
+ "# 4. Convert lists to NumPy arrays\n",
+ "all_predicted_values = np.array(all_predicted_values)\n",
+ "all_q20_values = np.array(all_q20_values)\n",
+ "all_q30_values = np.array(all_q30_values)\n",
+ "all_q70_values = np.array(all_q70_values)\n",
+ "all_q80_values = np.array(all_q80_values)\n",
+ "\n",
+ "# 5. Construct the `values_array` using the REAL data from your
file\n",
+ "context_len = 512\n",
+ "# Create a dummy context so the array has the right shape for the
plotting function.\n",
+ "# The first real value is used to make the context visually
seamless.\n",
+ "if all_actual_horizon_values:\n",
+ " dummy_context = [all_actual_horizon_values[0]] * context_len\n",
+ " values_array = np.array(dummy_context +
all_actual_horizon_values)\n",
+ "else:\n",
+ " # Fallback in case the file is empty or missing the
actual_horizon_values key\n",
+ " print(\"Warning: 'actual_horizon_values' not found. The original
time series plot will be empty.\")\n",
+ " total_len = context_len + len(all_predicted_values)\n",
+ " values_array = np.zeros(total_len)\n",
+ "\n",
+ "# 6. Call the plotting functions\n",
+ "if values_array.any():\n",
+ " # Plotting function for full graph\n",
+ " plot_anomalies_and_forecast(\n",
+ " values_array, all_anomalies, all_predicted_values,\n",
+ " all_q20_values, all_q30_values, all_q70_values,
all_q80_values,\n",
+ " title_suffix=\"(Full Graph with Correct Data)\",\n",
+ " min_outlier_score_for_plot=1, # Set a score threshold\n",
+ " context_len=context_len,\n",
+ " output_filename=\"full_graph_correct.png\"\n",
+ " )\n",
+ "\n",
+ " # Plotting function for zoomed-in graphs - feel free to change\n",
+ " zoom_ranges = [(2000, 2500), (8300, 9000), (9000, 9600)]\n",
+ " for i, (start_idx, end_idx) in enumerate(zoom_ranges):\n",
+ " # Adjust x_lims for the fact that the plotted array is sliced
by context_len\n",
+ " plot_x_start = max(0, start_idx)\n",
+ " plot_x_end = end_idx\n",
+ "\n",
+ " plot_anomalies_and_forecast(\n",
+ " values_array, all_anomalies, all_predicted_values,\n",
+ " all_q20_values, all_q30_values, all_q70_values,
all_q80_values,\n",
+ " title_suffix=f\"(Zoomed In: {start_idx} to
{end_idx})\",\n",
+ " x_lims=(plot_x_start, plot_x_end),\n",
+ " min_outlier_score_for_plot=1,\n",
+ " context_len=context_len,\n",
+ " output_filename=f\"zoomed_graph_correct_{i}.png\"\n",
+ " )\n",
+ " print(\"Plots have been generated and saved with the corrected
original data.\")\n",
+ "else:\n",
+ " print(\"No data found to plot.\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000
+ },
+ "id": "_HzqIoZbl25b",
+ "outputId": "7f85032d-e2cf-4e7c-b9b1-cae9994dee37"
+ },
+ "execution_count": 18,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "948\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "<Figure size 1800x900 with 1 Axes>"
+ ],
+ "image/png":
"iVBORw0KGgoAAAANSUhEUgAABvsAAAN5CAYAAAAmV4erAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs3XdYFMf/B/D3cVSpIiiiVBuoiL0rdozla9eoUYk1KrEbTTQKsccaY+y9GzWx94K9K4gNEcGu2EB6uZvfH/y4eFLVgwXu/Xoennizs7Of2d37BG5uZ2RCCAEiIiIiIiIiIiIiIiIiynd0pA6AiIiIiIiIiIiIiIiIiL4MB/uIiIiIiIiIiIiIiIiI8ikO9hERERERERERERERERHlUxzsIyIiIiIiIiIiIiIiIsqnONhHRERERERERERERERElE9xsI+IiIiIiIiIiIiIiIgon+JgHxERER
[...]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "948\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "<Figure size 1800x900 with 1 Axes>"
+ ],
+ "image/png":
"iVBORw0KGgoAAAANSUhEUgAABv0AAAN5CAYAAAArSffsAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs3Xd4VMX79/HPpjeSEFpASuhNqiCiIkWaFLEAokgvKiCIgOAXpYgUpaoI2CgKKgp2RJqEDgpIL9JCUSBISQgkhOzO8wdP9seShOxi2E3w/bquXLJz5szcs3tyS7hz5liMMUYAAAAAAAAAAAAAciwvTwcAAAAAAAAAAAAA4N+h6AcAAAAAAAAAAADkcBT9AAAAAAAAAAAAgByOoh8AAAAAAAAAAACQw1H0AwAAAAAAAAAAAHI4in4AAAAAAAAAAABADkfRDwAAAAAAAAAAAMjhKPoBAAAAAA
[...]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "948\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "<Figure size 1800x900 with 1 Axes>"
+ ],
+ "image/png":
"iVBORw0KGgoAAAANSUhEUgAABv0AAAN5CAYAAAArSffsAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs3XdcU9f7B/BPwgxbFBnKUhFQcS+sFjfWUfdEhTpbpW6rVuuou64O66h7a9VWbd0L97a4tyCoIC5ANiTn94c/8jUyDAoJVz/v14tXzbnnnvuc5D4JzcM9VyaEECAiIiIiIiIiIiIiIiIiyZLrOwAiIiIiIiIiIiIiIiIi+jAs+hERERERERERERERERFJHIt+RERERERERERERERERBLHoh8RERERERERERERERGRxLHoR0RERERERERERERERCRxLPoRERERERERERERERERSRyLfkRERE
[...]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "948\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "<Figure size 1800x900 with 1 Axes>"
+ ],
+ "image/png":
"iVBORw0KGgoAAAANSUhEUgAABv4AAAN5CAYAAADAfkzvAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs3XdYFFfbBvB7lrp0QQRRmqKIir3GKHaNvceusb5RYjfqq7HEHmsSY6+xx5LExK4Re4kFxY4oYkNUBKTD7vn+8GNfV4qgAyPr/bsurmTPnDnznJ3hYeVhzkhCCAEiIiIiIiIiIiIiIiIiytdUSgdARERERERERERERERERB+OhT8iIiIiIiIiIiIiIiIiA8DCHxEREREREREREREREZEBYOGPiIiIiIiIiIiIiIiIyACw8EdERERERERERERERERkAFj4IyIiIiIiIiIiIiIiIjIALPwRER
[...]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Plots have been generated and saved with the corrected original
data.\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Finetuned Model Predictions"
+ ],
+ "metadata": {
+ "id": "_KxBniHGqS3M"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import apache_beam as beam\n",
+ "from apache_beam.options.pipeline_options import PipelineOptions\n",
+ "from apache_beam.pvalue import AsDict, AsSingleton\n",
+ "from apache_beam.transforms.periodicsequence import
PeriodicImpulse\n",
+ "import logging\n",
+ "import os\n",
+ "import json\n",
+ "import timesfm\n",
+ "from apache_beam.utils.timestamp import Timestamp\n",
+ "import csv\n",
+ "from apache_beam.ml.inference.base import RunInference\n",
+ "from apache_beam.ml.inference.utils import WatchFilePattern\n",
+ "import typing\n",
+ "from google.colab import userdata\n",
+ "import apache_beam.transforms.window as window\n",
+ "\n",
+ "logging.getLogger().setLevel(logging.INFO)\n",
+ "\n",
+ "# --- Pipeline Configuration ---\n",
+ "PROJECT_ID = os.environ.get(\"GCP_PROJECT\",
\"apache-beam-testing\")\n",
+ "REGION = os.environ.get(\"GCP_REGION\", \"us-central1\")\n",
+ "TEMP_LOCATION =
\"gs://apache-beam-testing-temp/timesfm_anomaly_detection/temp\"\n",
+ "STAGING_LOCATION =
\"gs://apache-beam-testing-temp/timesfm_anomaly_detection/staging\"\n",
+ "FINETUNED_MODEL_BUCKET = \"apache-beam-testing-temp\"\n",
+ "FINETUNED_MODEL_PREFIX =
\"timesfm_anomaly_detection/finetuned-models/timesfm/checkpoints\"\n",
+ "\n",
+ "# --- Model & Window Parameters ---\n",
+ "CONTEXT_LEN = 512\n",
+ "HORIZON_LEN = 128\n",
+ "WINDOW_SIZE = CONTEXT_LEN + HORIZON_LEN\n",
+ "SLIDE_INTERVAL = HORIZON_LEN\n",
+ "EXPECTED_INTERVAL = 1\n",
+ "# go to checkpoints bucket and select the correct model path\n",
+ "INITIAL_MODEL =
\"gs://apache-beam-testing-temp/timesfm_anomaly_detection/finetuned-models/timesfm/checkpoints/timesfm_finetuned_20250814192006.pth\"\n",
+ "\n",
+ "MODEL_CHECK_INTERVAL_SECONDS = 10 # Check for a new model every 10
seconds\n",
+ "FINETUNING_BATCH_SIZE = 7680 # 9600 # make larger later. minimum is
WINDOW_SIZE for validation and training\n",
+ "FINETUNE_CONFIG = FinetuningConfig(\n",
+ " batch_size=128,\n",
+ " num_epochs=5,\n",
+ " learning_rate=1e-4,\n",
+ " use_wandb=False,\n",
+ " freq_type=0, # should change based on your data\n",
+ " log_every_n_steps=10,\n",
+ " val_check_interval=0.5,\n",
+ " use_quantile_loss=True\n",
+ " )\n",
+ "\n",
+ "\n",
+ "options = PipelineOptions([\n",
+ " \"--streaming\",\n",
+ " \"--environment_type=LOOPBACK\",\n",
+ " \"--runner=PrismRunner\",\n",
+ " \"--logging_level=INFO\",\n",
+ " \"--job_server_timeout=3600\"\n",
+ "])\n",
+ "\n",
+ "\n",
+ "\n",
+ "# HParams for the model\n",
+ "hparams = timesfm.TimesFmHparams(\n",
+ " backend=\"gpu\",\n",
+ " per_core_batch_size=32,\n",
+ " horizon_len=HORIZON_LEN,\n",
+ " context_len=CONTEXT_LEN,\n",
+ ")\n",
+ "model_handler = DynamicTimesFmModelHandler(model_uri=INITIAL_MODEL,
hparams=hparams)\n",
+ "\n",
+ "def print_and_pass_through(label):\n",
+ " def logger(element):\n",
+ " print(f\"--- {label} --- \\nELEMENT: %s\", element)\n",
+ " return element\n",
+ " return logger\n",
+ "\n",
+ "\n",
+ "class CustomJsonEncoder(json.JSONEncoder):\n",
+ " \"\"\"A custom JSON encoder that knows how to handle Beam's
Timestamp objects.\"\"\"\n",
+ " def default(self, obj):\n",
+ " if isinstance(obj, Timestamp):\n",
+ " # Convert Timestamp to whole seconds since the epoch (a float,
because the divisor is 1e6)\n",
+ " return obj.micros // 1e6\n",
+ " # Handle NumPy integer types\n",
+ " if isinstance(obj, np.integer):\n",
+ " return int(obj)\n",
+ "\n",
+ " # 3. Handle NumPy float types (e.g. float32, which json cannot
serialize natively)\n",
+ " if isinstance(obj, np.floating):\n",
+ " return float(obj)\n",
+ "\n",
+ " # 4. Handle NumPy arrays\n",
+ " if isinstance(obj, np.ndarray):\n",
+ " return obj.tolist()\n",
+ "\n",
+ " # For all other types, fall back to the default behavior\n",
+ " return super().default(obj)\n",
+ " return json.JSONEncoder.default(self, obj)\n",
+ "\n",
+ "class WritePlotDataAndPassThrough(beam.DoFn):\n",
+ " \"\"\"\n",
+ " A DoFn that writes plotting data to a file as a side effect\n",
+ " and then passes the original, unmodified element downstream.\n",
+ " \"\"\"\n",
+ " def __init__(self, output_path):\n",
+ " self._output_path = output_path\n",
+ " self._file_handle = None\n",
+ "\n",
+ " def setup(self):\n",
+ " self._file_handle = open(self._output_path, 'a')\n",
+ "\n",
+ " def process(self, element):\n",
+ " _original_window, payload_dict = element\n",
+ "\n",
+ " # Use the custom encoder so Timestamp and NumPy values
serialize cleanly\n",
+ " json_record = json.dumps(payload_dict,
cls=CustomJsonEncoder)\n",
+ " self._file_handle.write(json_record + '\\n')\n",
+ "\n",
+ " # Pass the original element through, with the Timestamp
object intact\n",
+ " yield element\n",
+ "\n",
+ " def teardown(self):\n",
+ " if self._file_handle:\n",
+ " self._file_handle.close()\n",
+ "\n",
+ "\n",
+ "#
=================================================================\n",
+ "# 1. Get Latest Model Path (Side Input) - WatchFilePattern is not\n",
+ "# currently supported on Prism. Uncomment the following to run\n",
+ "# on Dataflow\n",
+ "#
=================================================================\n",
+ "# model_pattern = os.path.join(\n",
+ "# f\"gs://{FINETUNED_MODEL_BUCKET}\", FINETUNED_MODEL_PREFIX,
\"*.pth\"\n",
+ "# )\n",
+ "\n",
+ "# model_metadata_pcoll = (\n",
+ "# \"WatchForNewModels\" >> WatchFilePattern(\n",
+ "# file_pattern=model_pattern,\n",
+ "# interval=MODEL_CHECK_INTERVAL_SECONDS\n",
+ "# )\n",
+ "# | \"PrintModelLocation\" >>
beam.Map(print_and_pass_through(\"Model Location\"))\n",
+ "\n",
+ "# )\n",
+ "\n",
+ "#
=================================================================\n",
+ "# Ingest and Window Raw Data\n",
+ "#
=================================================================\n",
+ "\n",
+ "\n",
+ "windowed_data = (\n",
+ " PeriodicImpulse(data=input_data, fire_interval=0.01)\n",
+ " | \"AddKey\" >> beam.WithKeys(lambda x: 0)\n",
+ " | \"ApplySlidingWindow\" >> beam.ParDo(\n",
+ " OrderedSlidingWindowFn(window_size=WINDOW_SIZE,
slide_interval=SLIDE_INTERVAL))\n",
+ " | \"FillGaps\" >>
beam.ParDo(FillGapsFn(expected_interval=EXPECTED_INTERVAL)).with_output_types(\n",
+ " typing.Tuple[int, typing.Tuple[Timestamp, Timestamp,
typing.List[float]]])\n",
+ " | \"Skip NaN Values for now\" >> beam.Filter(\n",
+ " lambda batch: 'NaN' not in batch[1][2])\n",
+ " | \"PrintWindowedData\" >>
beam.Map(print_and_pass_through(\"Windowed Data\"))\n",
+ "\n",
+ ")\n",
+ "\n",
+ "#
=================================================================\n",
+ "# Detect Anomalies using the Latest Model\n",
+ "#
=================================================================\n",
+ "\n",
+ "inference_results = (\n",
+ " \"DetectAnomalies\" >> RunInference(\n",
+ " model_handler=model_handler,\n",
+ " # model_metadata_pcoll=model_metadata_pcoll\n",
+ " )\n",
+ " | \"PrintInference\" >>
beam.Map(print_and_pass_through(\"Inference Results\"))\n",
+ ")\n",
+ "\n",
+ "\n",
+ "# NEW BRANCH: For plotting. It takes the payload dictionary,
converts\n",
+ "# it to JSON, and writes it to a file.\n",
+ "plotting_data_output = (\n",
+ " \"WritePlotDataAsSideEffect\" >> beam.ParDo(\n",
+ "
WritePlotDataAndPassThrough('plot_data_finetuned.jsonl'))\n",
+ ")\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "LRRg2QhPfEZr"
+ },
+ "execution_count": 19,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "with beam.Pipeline(options=options) as p:\n",
+ " (p\n",
+ " | windowed_data\n",
+ " | inference_results\n",
+ " | plotting_data_output\n",
+ " )\n"
+ ],
+ "metadata": {
+ "id": "qAJDAbdGqW_V"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Plot Data (After Finetuning)"
+ ],
+ "metadata": {
+ "id": "_fT-AjbrUOl5"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import json\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "def plot_anomalies_and_forecast(\n",
+ " values_array,\n",
+ " all_anomalies,\n",
+ " all_predicted_values,\n",
+ " all_q20_values,\n",
+ " all_q30_values,\n",
+ " all_q70_values,\n",
+ " all_q80_values,\n",
+ " title_suffix=\"\",\n",
+ " x_lims=None,\n",
+ " min_outlier_score_for_plot=0,\n",
+ " context_len=512,\n",
+ " output_filename=\"plot.png\"\n",
+ "):\n",
+ "    # Keep only anomalies whose 'outlier_score' is at least the plotting threshold\n",
+ " filtered_anomalies = [a for a in all_anomalies if
a['outlier_score'] >= min_outlier_score_for_plot]\n",
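+ "    # Shift each anomaly 'timestamp' back by 128 to get its series index\n",
+ "    # NOTE(review): 128 is presumably the forecast horizon length - confirm\n",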
+ " anomaly_indices = [(a['timestamp'] - 128) for a in
filtered_anomalies]\n",
+ " anomaly_values = [a['actual_value'] for a in
filtered_anomalies]\n",
+ "\n",
+ " Q1 = np.nanmean([all_q20_values, all_q30_values], axis=0)\n",
+ " Q3 = np.nanmean([all_q70_values, all_q80_values], axis=0)\n",
+ " IQR = Q3 - Q1\n",
+ " upper_thresh = Q3 + 1.5 * IQR\n",
+ " lower_thresh = Q1 - 1.5 * IQR\n",
+ "\n",
+ " plt.figure(figsize=(18, 9))\n",
+ "    # Plot the original series with the leading context_len points dropped\n",
+ " plt.plot(values_array[context_len:], label='Original Time
Series', color='blue', alpha=0.7, linewidth=1.5)\n",
+ "\n",
+ " plt.plot(all_predicted_values, label='Predicted Mean',
color='green', linestyle='--', linewidth=1.5)\n",
+ " plt.plot(lower_thresh, label='Lower Threshold', color='orange',
linestyle=':', linewidth=1.2)\n",
+ " plt.plot(upper_thresh, label='Upper Threshold', color='purple',
linestyle=':', linewidth=1.2)\n",
+ "\n",
+ " plt.scatter([i - context_len for i in anomaly_indices],
anomaly_values,\n",
+ " color='red', s=70, zorder=5,\n",
+ " label=f'Detected Anomalies (Score >=
{min_outlier_score_for_plot:.1f})',\n",
+ " marker='o', edgecolors='black', linewidths=0.8)\n",
+ "\n",
+ " plt.title(f'Time Series Anomaly Detection {title_suffix}')\n",
+ " plt.xlabel('Time Index')\n",
+ " plt.ylabel('Value')\n",
+ " if x_lims:\n",
+ " plt.xlim(x_lims[0], x_lims[1])\n",
+ " plt.legend()\n",
+ " plt.grid(True, linestyle='--', alpha=0.6)\n",
+ " plt.tight_layout()\n",
+ "    # plt.savefig(output_filename)  # Uncomment to also save the plot to a file\n",
+ " plt.show()\n",
+ " plt.close() # Close the figure to free memory\n",
+ "\n",
+ "# --- Main Script Logic ---\n",
+ "\n",
+ "# 1. Read and parse the data from the Beam output file\n",
+ "all_window_data = []\n",
+ "# Make sure 'plot_data_finetuned.jsonl' is in the same directory as this
script\n",
+ "try:\n",
+ " with open('plot_data_finetuned.jsonl', 'r') as f:\n",
+ " for line in f:\n",
+ " # Check for empty lines that might have been added\n",
+ " if line.strip():\n",
+ " all_window_data.append(json.loads(line))\n",
+ "except FileNotFoundError:\n",
+ " print(\"Error: 'plot_data_finetuned.jsonl' not found. Please make
sure the file is in the correct directory.\")\n",
+ " exit()\n",
+ "\n",
+ "\n",
+ "# 2. Sort data by timestamp to ensure the correct order\n",
+ "all_window_data.sort(key=lambda x: x['start_ts_micros'])\n",
+ "\n",
+ "# 3. Reconstruct the full data arrays\n",
+ "all_anomalies = []\n",
+ "all_predicted_values = []\n",
+ "all_q20_values = []\n",
+ "all_q30_values = []\n",
+ "all_q70_values = []\n",
+ "all_q80_values = []\n",
+ "all_actual_horizon_values = [] # This will hold the real \"blue
line\" data\n",
+ "\n",
+ "for window_data in all_window_data:\n",
+ " all_predicted_values.extend(window_data['predicted_values'])\n",
+ " all_q20_values.extend(window_data['q20_values'])\n",
+ " all_q30_values.extend(window_data['q30_values'])\n",
+ " all_q70_values.extend(window_data['q70_values'])\n",
+ " all_q80_values.extend(window_data['q80_values'])\n",
+ " # Populate the list with the actual values from the file\n",
+ "
all_actual_horizon_values.extend(window_data.get('actual_horizon_values',
[]))\n",
+ " all_anomalies.extend(window_data.get('anomalies', []))\n",
+ "\n",
+ "# 4. Convert lists to NumPy arrays\n",
+ "all_predicted_values = np.array(all_predicted_values)\n",
+ "print(len(all_predicted_values))\n",
+ "all_q20_values = np.array(all_q20_values)\n",
+ "all_q30_values = np.array(all_q30_values)\n",
+ "all_q70_values = np.array(all_q70_values)\n",
+ "all_q80_values = np.array(all_q80_values)\n",
+ "\n",
+ "# 5. Construct the `values_array` using the REAL data from your
file\n",
+ "context_len = 512\n",
+ "# Create a dummy context so the array has the right shape for the
plotting function.\n",
+ "# The first real value is used to make the context visually
seamless.\n",
+ "if all_actual_horizon_values:\n",
+ " dummy_context = [all_actual_horizon_values[0]] * context_len\n",
+ " values_array = np.array(dummy_context +
all_actual_horizon_values)\n",
+ "else:\n",
+ " # Fallback in case the file is empty or missing the
actual_horizon_values key\n",
+ " print(\"Warning: 'actual_horizon_values' not found. The original
time series plot will be empty.\")\n",
+ " total_len = context_len + len(all_predicted_values)\n",
+ " values_array = np.zeros(total_len)\n",
+ "\n",
+ "# 6. Call the plotting functions\n",
+ "if values_array.any():\n",
+ " # Plotting function for full graph\n",
+ " plot_anomalies_and_forecast(\n",
+ " values_array, all_anomalies, all_predicted_values,\n",
+ " all_q20_values, all_q30_values, all_q70_values,
all_q80_values,\n",
+ " title_suffix=\"(Full Graph with Correct Data)\",\n",
+ " min_outlier_score_for_plot=5, # Set a score threshold\n",
+ " context_len=context_len,\n",
+ " output_filename=\"full_graph_correct.png\"\n",
+ " )\n",
+ "\n",
+ " # Plotting function for zoomed-in graphs - feel free to change\n",
+ " zoom_ranges = [(2000, 2500), (8300, 9000), (9000, 9600)]\n",
+ " for i, (start_idx, end_idx) in enumerate(zoom_ranges):\n",
+ "        # The plotted array is already sliced by context_len, so the zoom
range is used as-is; only clamp the start at 0\n",
+ " plot_x_start = max(0, start_idx)\n",
+ " plot_x_end = end_idx\n",
+ "\n",
+ " plot_anomalies_and_forecast(\n",
+ " values_array, all_anomalies, all_predicted_values,\n",
+ " all_q20_values, all_q30_values, all_q70_values,
all_q80_values,\n",
+ " title_suffix=f\"(Zoomed In: {start_idx} to
{end_idx})\",\n",
+ " x_lims=(plot_x_start, plot_x_end),\n",
+ " min_outlier_score_for_plot=5,\n",
+ " context_len=context_len,\n",
+ " output_filename=f\"zoomed_graph_correct_{i}.png\"\n",
+ " )\n",
+ " print(\"Plots have been generated and saved with the corrected
original data.\")\n",
+ "else:\n",
+ " print(\"No data found to plot.\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000
+ },
+ "id": "RZwpgtAr8nlD",
+ "outputId": "d27b08f1-1e77-4109-bfa6-c709edf4b5e8"
+ },
+ "execution_count": 21,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "9600\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "<Figure size 1800x900 with 1 Axes>"
+ ],
+ "image/png":
"iVBORw0KGgoAAAANSUhEUgAABvsAAAN5CAYAAAAmV4erAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs3XdcU9f7B/BPEvYWFHGwRBFsUcRVtA5cWJWvo466cX+11m3Vap2to2rVDkfddVTrqNqqdePeCqIoKoIbcSCKEEZyf3/wS75GEgSFXOB+3q8Xrzbnntz7nOTylObJOUcmCIIAIiIiIiIiIiIiIiIiIipy5GIHQERERERERERERERERETvh8U+IiIiIiIiIiIiIiIioiKKxT4iIiIiIiIiIiIiIiKiIorFPiIiIiIiIiIiIiIiIqIiisU+IiIiIiIiIiIiIiIioiKKxT4iIiIiIiIiIiIiIi
[...]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "<Figure size 1800x900 with 1 Axes>"
+ ],
+ "image/png":
"iVBORw0KGgoAAAANSUhEUgAABv0AAAN5CAYAAAArSffsAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs3XmcjfX///HndWbMwswZxjZkGbuhQdamkiUaWaIIZV9SIXvJR1kLFaGFlEKW0p5CtpCQLNn3ncLIMmMwM5xz/f7wm/N1zAxzTuMcR4/77Ta3nPd5X9f79b7ONS+Z17zfl2GapikAAAAAAAAAAAAAPsvi7QAAAAAAAAAAAAAA/DsU/QAAAAAAAAAAAAAfR9EPAAAAAAAAAAAA8HEU/QAAAAAAAAAAAAAfR9EPAAAAAAAAAAAA8HEU/QAAAAAAAAAAAAAfR9EPAAAAAAAAAAAA8HEU/QAAAA
[...]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "<Figure size 1800x900 with 1 Axes>"
+ ],
+ "image/png":
"iVBORw0KGgoAAAANSUhEUgAABv4AAAN5CAYAAADAfkzvAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs3XdYFNf7NvB76HVBbKBUK6jYS0ii2LFGY48Ve1Rir1+NNbFEjZoYS9SoscSaqAn2gj12FBs2sEQQlSZK3T3vH747P1eKoOgweH+uiyvu2bMzz9mdG6IPMyMJIQSIiIiIiIiIiIiIiIiISNWMlC6AiIiIiIiIiIiIiIiIiN4dG39ERERERERERERERERE+QAbf0RERERERERERERERET5ABt/RERERERERERERERERPkAG39ERERERERERERERERE+QAbf0RERERERERERERERET5ABt/RE
[...]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "<Figure size 1800x900 with 1 Axes>"
+ ],
+ "image/png":
"iVBORw0KGgoAAAANSUhEUgAABv4AAAN5CAYAAADAfkzvAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQABAABJREFUeJzs3XdYFFfbBvB76IuwIIhioVlBxd6ISbBjLLH3AvbXErtRX401tqhRE2OPJZZYU0ywG7FrbCg2oiiWKKLSRPru+f7w23ldKYLusLrev+viinvm7MxzZoaHDQ/njCSEECAiIiIiIiIiIiIiIiKi95qZsQMgIiIiIiIiIiIiIiIiorfHwh8RERERERERERERERGRCWDhj4iIiIiIiIiIiIiIiMgEsPBHREREREREREREREREZAJY+CMiIiIiIiIiIiIiIiIyASz8EREREREREREREREREZkAFv
[...]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Plots have been generated and saved with the corrected original
data.\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Compare with RMSE (Original vs Finetuned on Test Data)"
+ ],
+ "metadata": {
+ "id": "WAyaVAEIt4Ey"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import json\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "\n",
+ "def calculate_model_rmse(file_path, finetuning_cutoff_ts,
outlier_threshold):\n",
+ " \"\"\"\n",
+ " Loads model output data, reconstructs the time series, and
calculates RMSE\n",
+ " on a test set after filtering outliers.\n",
+ "\n",
+ " Args:\n",
+ " file_path (str): Path to the model's prediction data in JSONL
format.\n",
+ " finetuning_cutoff_ts (int): Timestamp to split training and
test data.\n",
+ " outlier_threshold (float): Outlier score at or above which to
exclude points.\n",
+ "\n",
+ " Returns:\n",
+ " float: The calculated Root Mean Squared Error.\n",
+ " \"\"\"\n",
+ " all_window_data = []\n",
+ " with open(file_path, 'r') as f:\n",
+ " for line in f:\n",
+ " if line.strip():\n",
+ " all_window_data.append(json.loads(line))\n",
+ "\n",
+ " all_window_data.sort(key=lambda x: x['start_ts_micros'])\n",
+ "\n",
+ " # --- Reconstruct the full time series from the windows ---\n",
+ " timestamps = []\n",
+ " all_predicted_values = []\n",
+ " all_actual_values = []\n",
+ " all_anomalies = []\n",
+ "\n",
+ " current_ts = -1\n",
+ " if all_window_data:\n",
+ " # Initialize with the first window's start time\n",
+ " current_ts = all_window_data[0]['start_ts_micros'] //
1000000\n",
+ "\n",
+ " for window_data in all_window_data:\n",
+ " # Extend the series lists\n",
+ "
all_predicted_values.extend(window_data['predicted_values'])\n",
+ "
all_actual_values.extend(window_data.get('actual_horizon_values', []))\n",
+ " all_anomalies.extend(window_data.get('anomalies', []))\n",
+ "\n",
+ " # Reconstruct the timestamps for each predicted point\n",
+ " start_ts = window_data['start_ts_micros'] // 1000000\n",
+ " for _ in window_data['predicted_values']:\n",
+ " timestamps.append(start_ts)\n",
+ " start_ts += 1\n",
+ "\n",
+ " # Create a lookup for outlier scores\n",
+ " outlier_scores_map = {item['timestamp']: item['outlier_score']
for item in all_anomalies}\n",
+ "\n",
+ " # Ensure the actual values and predicted values align\n",
+ " min_len = min(len(timestamps), len(all_predicted_values),
len(all_actual_values))\n",
+ "\n",
+ " # --- Create a DataFrame for easy filtering and calculation
---\n",
+ " df = pd.DataFrame({\n",
+ " 'timestamp': timestamps[:min_len],\n",
+ " 'actual': all_actual_values[:min_len],\n",
+ " 'predicted': all_predicted_values[:min_len]\n",
+ " })\n",
+ " df['outlier_score'] =
df['timestamp'].map(outlier_scores_map).fillna(0.0)\n",
+ "\n",
+ " # 1. Isolate the test set\n",
+ " df_test = df[df['timestamp'] > finetuning_cutoff_ts].copy()\n",
+ " print(f\"\\n--- Analyzing: {file_path} ---\")\n",
+ " print(f\"Test set size (before filtering): {len(df_test)}
points\")\n",
+ "\n",
+ " # 2. Filter out anomalies based on the threshold\n",
+ " df_filtered = df_test[df_test['outlier_score'] <
outlier_threshold]\n",
+ " num_outliers = len(df_test) - len(df_filtered)\n",
+ " print(f\"Test set size (after filtering): {len(df_filtered)}
points\")\n",
+ " print(f\"Removed {num_outliers} points with outlier_score >=
{outlier_threshold}\")\n",
+ "\n",
+ " # 3. Calculate RMSE\n",
+ " y_true = df_filtered['actual']\n",
+ " y_pred = df_filtered['predicted']\n",
+ " rmse = np.sqrt(np.mean((y_true - y_pred)**2))\n",
+ "\n",
+ " return rmse\n",
+ "\n",
+ "# --- Configuration ---\n",
+ "FINETUNING_CUTOFF_TS = 8320\n",
+ "ORIGINAL_OUTLIER_THRESHOLD = 1.0\n",
+ "FINETUNED_OUTLIER_THRESHOLD = 5.0\n",
+ "\n",
+ "# --- Execution & Comparison ---\n",
+ "original_rmse = calculate_model_rmse(\n",
+ " file_path=\"plot_data_original.jsonl\",\n",
+ " finetuning_cutoff_ts=FINETUNING_CUTOFF_TS,\n",
+ " outlier_threshold=ORIGINAL_OUTLIER_THRESHOLD\n",
+ ")\n",
+ "\n",
+ "finetuned_rmse = calculate_model_rmse(\n",
+ " file_path=\"plot_data_finetuned.jsonl\",\n",
+ " finetuning_cutoff_ts=FINETUNING_CUTOFF_TS,\n",
+ " outlier_threshold=FINETUNED_OUTLIER_THRESHOLD\n",
+ ")\n",
+ "\n",
+ "print(\"\\n--- Final Results ---\")\n",
+ "print(f\"Original Model RMSE: {original_rmse:.2f}\")\n",
+ "print(f\"Fine-tuned Model RMSE: {finetuned_rmse:.2f}\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "zrMOmc90lQhJ",
+ "outputId": "d0679b35-b375-4068-aa4b-bd22a8c6166e"
+ },
+ "execution_count": 22,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\n",
+ "--- Analyzing: plot_data_original.jsonl ---\n",
+ "Test set size (before filtering): 1407 points\n",
+ "Test set size (after filtering): 1369 points\n",
+ "Removed 38 points with outlier_score >= 1.0\n",
+ "\n",
+ "--- Analyzing: plot_data_finetuned.jsonl ---\n",
+ "Test set size (before filtering): 1407 points\n",
+ "Test set size (after filtering): 1384 points\n",
+ "Removed 23 points with outlier_score >= 5.0\n",
+ "\n",
+ "--- Final Results ---\n",
+ "Original Model RMSE: 7164.17\n",
+ "Fine-tuned Model RMSE: 3948.44\n"
+ ]
+ }
+ ]
+ }
+ ]
+}
\ No newline at end of file