{ "cells": [ { "cell_type": "markdown", "id": "7cda2fe0", "metadata": {}, "source": [ "# Temporal Rolling CV: Splits Visualization\n", "This notebook shows how `TemporalRollingCV` works. We start by importing the necessary packages and defining a small utility function to plot the splits." ] }, { "cell_type": "code", "execution_count": null, "id": "3d2f250f", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from elliptic_toolkit import TemporalRollingCV, temporal_split" ] }, { "cell_type": "code", "execution_count": null, "id": "36b479d9", "metadata": {}, "outputs": [], "source": [ "# Utility function to visualize the splits: one row per fold, one marker per unique time step\n", "def show_splits(cv, train_val_times, test_times):\n", "    n_folds = cv.get_n_splits() + 1\n", "    plt.figure(figsize=(10, 6))\n", "\n", "    for (fold, (train_indices, val_indices)) in enumerate(cv.split(train_val_times, groups=train_val_times)):\n", "        # Unique time steps covered by the training and validation indices of this fold\n", "        train_times = np.unique(train_val_times[train_indices])\n", "        val_times = np.unique(train_val_times[val_indices])\n", "        # Label only the first fold so each legend entry appears once\n", "        plt.scatter(train_times, [fold+1]*len(train_times), color='blue', label='Train' if fold==0 else None, marker='o', s=100)\n", "        plt.scatter(val_times, [fold+1]*len(val_times), color='orange', label='Validation' if fold==0 else None, marker='s', s=100)\n", "        plt.scatter(test_times, [fold+1]*len(test_times), color='red', label='Test' if fold==0 else None, marker='^', s=100)\n", "\n", "    folds = range(1, n_folds)\n", "    plt.xlabel('Time Step')\n", "    plt.ylabel('Fold')\n", "    plt.title('Train and Validation Time Steps per Fold')\n", "    plt.legend()\n", "    plt.yticks(folds, [f'Fold {i}' for i in folds])\n", "    plt.tight_layout()" ] }, { "cell_type": "markdown", "id": "42c6d163", "metadata": {}, "source": [ "We then create an array that mimics a time index and that has more than one sample for each time step. 
We then hold out 20% of the unique time steps; these are the ones we would normally reserve for the final test set." ] }, { "cell_type": "code", "execution_count": null, "id": "eb7c7595", "metadata": {}, "outputs": [], "source": [ "# 30 samples spread over time steps 0-9, so most time steps have several samples\n", "times = np.sort(np.random.randint(0, 10, size=30))\n", "train_val_times, test_times = temporal_split(times, test_size=0.2)" ] }, { "cell_type": "markdown", "id": "33ac2251", "metadata": {}, "source": [ "## Basic TemporalRollingCV\n", "Standard temporal cross-validation with 5 folds. Each fold uses all previous time steps for training and the next available time step for validation. The number of time steps used for training and validation is computed automatically." ] }, { "cell_type": "code", "execution_count": null, "id": "5ed40d6a", "metadata": {}, "outputs": [], "source": [ "show_splits(TemporalRollingCV(n_splits=5), train_val_times, test_times)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "4b4a5a8c", "metadata": {}, "source": [ "## TemporalRollingCV with Gap\n", "Adds a 2-time-step gap between training and validation sets." ] }, { "cell_type": "code", "execution_count": null, "id": "bc7eaaac", "metadata": {}, "outputs": [], "source": [ "show_splits(TemporalRollingCV(n_splits=5, gap=2), train_val_times, test_times)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "4246d4a3", "metadata": {}, "source": [ "## TemporalRollingCV with Limited Training Window\n", "Limits training data to a maximum of 4 time steps, so the training set becomes a sliding window of bounded size rather than an ever-expanding one." 
] }, { "cell_type": "code", "execution_count": null, "id": "0fe229c6", "metadata": {}, "outputs": [], "source": [ "show_splits(TemporalRollingCV(n_splits=5, max_train_size=4), train_val_times, test_times)\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "5fbcca8c", "metadata": {}, "source": [ "## TemporalRollingCV with Fixed Number of Time Steps for Validation\n", "Fixes the number of time steps used for validation in each fold to 2." ] }, { "cell_type": "code", "execution_count": null, "id": "909ca316", "metadata": {}, "outputs": [], "source": [ "show_splits(TemporalRollingCV(n_splits=5, test_size=2), train_val_times, test_times)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "d8344cce", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py313_torch_cuda216", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.5" } }, "nbformat": 4, "nbformat_minor": 5 }