{ "cells": [ { "cell_type": "markdown", "id": "85a9c03e", "metadata": {}, "source": [ "# Elliptic Bitcoin Dataset: Graph Convolutional Network Example\n", "\n", "This notebook demonstrates how to train a Graph Convolutional Network (GCN) on the Elliptic Bitcoin dataset. It covers:\n", "\n", "- Loading and preparing the graph data\n", "- Training a GCN model for binary classification\n", "- Evaluating model performance\n", "- Hyperparameter tuning with temporal cross-validation\n", "\n", "This example uses PyTorch Geometric for graph neural network operations." ] }, { "cell_type": "code", "execution_count": null, "id": "d5580280", "metadata": {}, "outputs": [], "source": [ "from elliptic_toolkit import download_dataset, process_dataset, temporal_split, TemporalRollingCV, GNNBinaryClassifier\n", "from torch_geometric.data import Data\n", "import torch" ] }, { "cell_type": "markdown", "id": "4d13de55", "metadata": {}, "source": [ "# Loading the Dataset\n", "\n", "First, download the Elliptic Bitcoin dataset. This automatically saves the data files in the location expected by the later processing step." ] }, { "cell_type": "code", "execution_count": null, "id": "9084a4ae", "metadata": {}, "outputs": [], "source": [ "download_dataset()" ] }, { "cell_type": "markdown", "id": "2fb79fe4", "metadata": {}, "source": [ "# Preparing Graph Data\n", "\n", "Process the dataset to create a PyTorch Geometric `Data` object containing:\n", "\n", "- Node features (transaction features)\n", "- Edge indices (transaction connections)\n", "- Node labels (illicit/licit classification)\n", "- Time information for temporal splitting\n", "\n", "We then split the nodes into a train/validation set and a test set by time step and keep only the labeled transactions in each (unlabeled nodes carry the class `-1`)." 
] }, { "cell_type": "code", "execution_count": null, "id": "47eebebd", "metadata": {}, "outputs": [], "source": [ "nodes_df, edges_df = process_dataset()\n", "\n", "# Build the graph: every column except 'time' and 'class' is a node feature.\n", "data = Data(\n", "    x=torch.tensor(nodes_df.drop(columns=['time', 'class']).values, dtype=torch.float),\n", "    edge_index=torch.tensor(edges_df.values.T, dtype=torch.long),\n", "    y=torch.tensor(nodes_df['class'].values, dtype=torch.long),\n", "    time=torch.tensor(nodes_df['time'].values, dtype=torch.long),\n", ")\n", "\n", "# Split node indices by time step.\n", "train_val_idx, test_idx = temporal_split(data.time)\n", "\n", "# Keep only labeled nodes; a class of -1 marks an unlabeled transaction.\n", "labeled_mask = data.y != -1\n", "train_val_idx = train_val_idx[labeled_mask[train_val_idx]]\n", "test_idx = test_idx[labeled_mask[test_idx]]" ] }, { "cell_type": "markdown", "id": "26995ccf", "metadata": {}, "source": [ "# Training a GCN Model\n", "\n", "Create and train a Graph Convolutional Network using the `GNNBinaryClassifier` wrapper.\n", "\n", "**Note:** For demonstration purposes, the model trains for only a small number of iterations (`max_iter=50`), which may cause convergence warnings. In practice, use more iterations so that training converges." ] }, { "cell_type": "code", "execution_count": null, "id": "073f86cc", "metadata": {}, "outputs": [], "source": [ "from torch_geometric.nn import GCN\n", "\n", "gcn_model = GNNBinaryClassifier(\n", "    data,\n", "    GCN,\n", "    hidden_dim=8,\n", "    num_layers=3,\n", "    dropout=0.3,\n", "    verbose=True,\n", "    device='cpu',\n", "    max_iter=50,\n", ")\n", "\n", "gcn_model.fit(train_val_idx)" ] }, { "cell_type": "markdown", "id": "74178e58", "metadata": {}, "source": [ "# Model Evaluation\n", "\n", "Evaluate the trained GCN model with a precision-recall curve on the test set. The curve shows how well the model separates illicit from licit transactions across decision thresholds." 
] }, { "cell_type": "code", "execution_count": null, "id": "03a5becc", "metadata": {}, "outputs": [], "source": [ "from matplotlib import pyplot as plt\n", "from sklearn.metrics import PrecisionRecallDisplay\n", "\n", "PrecisionRecallDisplay.from_estimator(\n", "    gcn_model,\n", "    test_idx,\n", "    data.y[test_idx],\n", "    name=\"GCN Model\",\n", ")\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "5308b11e", "metadata": {}, "source": [ "# Hyperparameter Tuning\n", "\n", "Perform a grid search over the hyperparameters using temporal cross-validation, so that model evaluation respects the temporal ordering of the data.\n", "\n", "The GCN model sees the full graph at training time; we pass only the indices of the nodes over which the loss is computed. The time steps must be passed to the `fit` method as `groups` so that the cross-validation splitter can access them." ] }, { "cell_type": "code", "execution_count": null, "id": "d85660aa", "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import GridSearchCV\n", "\n", "gcn_model.set_params(verbose=False)\n", "\n", "grid_search = GridSearchCV(\n", "    gcn_model,\n", "    param_grid={\n", "        'hidden_dim': [2, 4, 8, 16],\n", "    },\n", "    cv=TemporalRollingCV(3),\n", "    scoring='average_precision',\n", "    n_jobs=-1,\n", "    verbose=1,\n", ")\n", "\n", "grid_search.fit(train_val_idx, data.y[train_val_idx], groups=data.time[train_val_idx])" ] }, { "cell_type": "markdown", "id": "aa0da0ea", "metadata": {}, "source": [ "# Visualizing Results\n", "\n", "Plot the marginal effect of each hyperparameter and the evaluation results over time to understand model performance and parameter sensitivity." 
] }, { "cell_type": "code", "execution_count": null, "id": "0007d5d5", "metadata": {}, "outputs": [], "source": [ "from elliptic_toolkit import plot_marginals, plot_evals\n", "\n", "for fig in plot_marginals(grid_search.cv_results_):\n", "    plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "876df11f", "metadata": {}, "outputs": [], "source": [ "for fig in plot_evals(\n", "    grid_search,\n", "    test_idx,\n", "    data.y[test_idx].numpy(),\n", "    data.y[train_val_idx].numpy(),\n", "    time_steps_test=data.time[test_idx].numpy(),\n", "):\n", "    plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "py313_torch_cuda216", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.5" } }, "nbformat": 4, "nbformat_minor": 5 }