Source code for elliptic_toolkit.model_wrappers

import torch
import numpy as np
import warnings
from torch_geometric.nn import GAT, PairNorm

from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin
from sklearn.neural_network import MLPClassifier


def _get_norm_arg(norm_str):
    """
    Utility function to parse normalization layer and arguments from a string.

    Parameters
    ----------
    norm_str : str or None
        A string containing the normalization layer and optional keyword arguments,
        separated by underscores. For example:
        - 'batch' for BatchNorm
        - 'layer' for LayerNorm
        - 'pair' for PairNorm with default args
        - 'pair_scale=0.5' for PairNorm with scale=0.5
        - 'layer_mode=node' for LayerNorm with mode='node'
        - None for no normalization
    Returns
    -------
    norm : callable or None
        Normalization layer class or None.

    norm_kwargs : dict
        Additional keyword arguments for the normalization layer.

    Notes
    -----
    Since `PairNorm` does not work with the `BasicGNN` API, an already instantiated
    `PairNorm` object is returned when 'pair' is specified.
    """

    if not isinstance(norm_str, str):
        return norm_str, {}

    norm_args = norm_str.split('_')
    norm = norm_args[0]
    kwargs = {}
    for arg in norm_args[1:]:
        key, value = arg.split('=')
        # Coerce numeric values (int first, then float); leave other strings as-is
        try:
            value = int(value)
        except ValueError:
            try:
                value = float(value)
            except ValueError:
                pass
        kwargs[key] = value

    # PairNorm is instantiated here (with any parsed kwargs) because it does not
    # fit the string-based `norm`/`norm_kwargs` API of `BasicGNN`.
    if norm == 'pair':
        return PairNorm(**kwargs), {}
    return norm, kwargs
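

# Illustrative parsing examples (comments only, not executed; they assume the
# behaviour documented in `_get_norm_arg` above):
#
#   _get_norm_arg('batch')            -> ('batch', {})
#   _get_norm_arg('layer_mode=node')  -> ('layer', {'mode': 'node'})
#   _get_norm_arg('pair_scale=0.5')   -> (PairNorm(scale=0.5), {})
#   _get_norm_arg(None)               -> (None, {})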


class GNNBinaryClassifier(ClassifierMixin, BaseEstimator):
    """
    Graph Neural Network binary classifier with early stopping.

    A scikit-learn compatible binary classifier that wraps PyTorch Geometric
    GNN models. Currently supports transductive, full-batch models (GCN, GAT).

    The training loss is monitored and the model is considered converged if
    the loss does not improve by at least `tol` for `n_iter_no_change`
    consecutive iterations. This early stopping mechanism is always enabled,
    similar to `MLPClassifier` in scikit-learn.

    Parameters
    ----------
    data : torch_geometric.data.Data
        Graph data object containing node features (x), edge indices
        (edge_index), and node labels (y).

    model : torch.nn.Module class
        The GNN model class to instantiate for training.

    hidden_dim : int, default=64
        Number of hidden units in each layer.

    num_layers : int, default=3
        Number of layers in the neural network.

    dropout : float, default=0.5
        Dropout probability for regularization.

    norm : str or callable, default=None
        Normalization layer specification parsed by `_get_norm_arg`, e.g.
        'batch', 'layer_mode=node' or 'pair_scale=0.5'.

    jk : str, default='last'
        Jumping-knowledge mode passed to the underlying PyG model.

    learning_rate_init : float, default=0.01
        Initial learning rate for the Adam optimizer.

    weight_decay : float, default=5e-4
        L2 regularization strength.

    balance_loss : bool, default=True
        Whether to balance the loss function by weighting positive samples.
        If True, uses positive class weighting in BCEWithLogitsLoss based on
        class frequencies. If False, uses unweighted loss.

    max_iter : int, default=200
        Maximum number of training iterations.

    verbose : bool, default=False
        Whether to print training progress.

    n_iter_no_change : int, default=10
        Number of consecutive iterations with no improvement to trigger
        early stopping.

    tol : float, default=1e-4
        Tolerance for improvement. Training stops if loss improvement is
        less than this value.

    device : str or torch.device, default='auto'
        Device to use for computation. Can be 'cpu', 'cuda', 'auto', or a
        torch.device object. If 'auto', will use CUDA if available,
        otherwise CPU.

    heads : int, default=None
        Number of attention heads for GAT models. Only applicable when
        model=GAT. Ignored with a warning for other model types.

    **kwargs : dict
        Additional keyword arguments passed to the model constructor.

    Attributes
    ----------
    loss_curve_ : list
        List of loss values at each training iteration.

    model_ : torch.nn.Module
        The trained GNN model after calling `fit`.
    """

    def _validate_data(self, data):
        """
        Validate that the data object has the required attributes.

        Parameters
        ----------
        data : object
            Data object to validate.

        Returns
        -------
        data : object
            Validated data object.

        Raises
        ------
        ValueError
            If the data object is missing required attributes.
        """
        attributes = ['x', 'edge_index', 'y']
        for attr in attributes:
            if not hasattr(data, attr) or getattr(data, attr) is None:
                raise ValueError(
                    f"Data object must have '{attr}' attribute and be non-null.")
        return data

    def _validate_device(self, device):
        """
        Validate and set the device for computation.

        Parameters
        ----------
        device : str or torch.device
            Device specification.

        Returns
        -------
        torch.device
            Validated device object.

        Raises
        ------
        ValueError
            If the device specification is invalid. If CUDA is requested but
            not available, a warning is issued and CPU is used instead.
        """
        if device == 'auto':
            return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        try:
            if isinstance(device, torch.device):
                device_obj = device
            else:
                device_obj = torch.device(device)
            if device_obj.type == 'cuda' and not torch.cuda.is_available():
                warnings.warn(
                    "CUDA is not available, falling back to CPU", UserWarning)
                return torch.device('cpu')
            return device_obj
        except (RuntimeError, ValueError) as e:
            raise ValueError(f"Invalid device '{device}': {e}")

    def __init__(
        self,
        data,
        model,
        hidden_dim=64,
        num_layers=3,
        dropout=0.5,
        norm=None,
        jk='last',
        learning_rate_init=0.01,
        weight_decay=5e-4,
        balance_loss=True,
        max_iter=200,
        verbose=False,
        n_iter_no_change=10,
        tol=1e-4,
        device='auto',
        heads=None,
        **kwargs,
    ):
        super().__init__()
        self.data = self._validate_data(data)
        self.model = model
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.norm = norm
        self.jk = jk
        self.learning_rate_init = learning_rate_init
        self.weight_decay = weight_decay
        self.balance_loss = balance_loss
        self.max_iter = max_iter
        self.verbose = verbose
        self.n_iter_no_change = n_iter_no_change
        self.tol = tol
        self.device = self._validate_device(device)  # Store validated device
        self.heads = heads
        self.kwargs = kwargs

        # Move data to device
        self.data = self.data.to(self.device)

        # Handle the 'heads' parameter
        if heads is not None and model != GAT:
            warnings.warn(
                "'heads' parameter is only applicable for GAT model. "
                "Ignoring 'heads'.",
                UserWarning)
            self.heads = heads  # Stored for sklearn compatibility but not used
        elif model == GAT:
            # For GAT, use heads if provided, otherwise default to 1
            actual_heads = heads if heads is not None else 1
            self.heads = actual_heads
            self.kwargs['heads'] = actual_heads
        else:
            # Non-GAT model with heads=None
            self.heads = heads

        if self.verbose:
            print(f"Using device: {self.device}")

    def _get_pos_weight(self, indices):
        """
        Calculate the positive class weight for balanced loss computation.

        Parameters
        ----------
        indices : torch.Tensor or array-like
            Indices of training samples.

        Returns
        -------
        torch.Tensor
            Weight for the positive class to balance the loss.
        """
        y = self.data.y[indices]
        pos_weight = (y == 0).sum() / (y == 1).sum()
        return pos_weight.to(self.device)
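
    # Illustrative arithmetic (comment only, with made-up counts): if the
    # training indices select 900 licit nodes (y == 0) and 100 illicit nodes
    # (y == 1), then pos_weight = 900 / 100 = 9.0, so each positive sample
    # contributes nine times as much to the weighted BCE loss.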

    def fit(self, X, y=None):
        """
        Fit the GNN model to the training data.

        Training stops automatically when the loss stops improving for
        `n_iter_no_change` consecutive iterations, similar to `MLPClassifier`.

        Parameters
        ----------
        X : array-like
            Indices of training samples in the graph (named ``X`` for
            scikit-learn compatibility).

        y : array-like, default=None
            Target values (ignored, present for sklearn compatibility).

        Returns
        -------
        self : GNNBinaryClassifier
            Returns self for method chaining.

        Warns
        -----
        UserWarning
            If training stops due to max_iter being reached without
            convergence.
        """
        train_indices = X
        num_features = self.data.x.shape[1]
        norm_layer, norm_kwargs = _get_norm_arg(self.norm)
        self.model_ = self.model(
            in_channels=num_features,
            hidden_channels=self.hidden_dim,
            out_channels=1,
            num_layers=self.num_layers,
            dropout=self.dropout,
            norm=norm_layer,
            norm_kwargs=norm_kwargs,
            jk=self.jk,
            **self.kwargs
        ).to(self.device)

        optimizer = torch.optim.Adam(
            self.model_.parameters(),
            lr=self.learning_rate_init,
            weight_decay=self.weight_decay)

        # Convert indices to a tensor and move them to the device
        if not isinstance(train_indices, torch.Tensor):
            train_indices = torch.tensor(
                train_indices, dtype=torch.long, device=self.device)
        else:
            train_indices = train_indices.to(self.device)

        if self.balance_loss:
            criterion = torch.nn.BCEWithLogitsLoss(
                pos_weight=self._get_pos_weight(train_indices))
        else:
            criterion = torch.nn.BCEWithLogitsLoss()

        # Early stopping variables
        best_loss = float('inf')
        no_improvement_count = 0
        self.loss_curve_ = []

        self.model_.train()
        converged = False
        for epoch in range(1, self.max_iter + 1):
            optimizer.zero_grad()
            out = self.model_(self.data.x, self.data.edge_index).squeeze()
            loss = criterion(
                out[train_indices], self.data.y[train_indices].float())
            loss.backward()
            optimizer.step()

            current_loss = loss.item()
            self.loss_curve_.append(current_loss)

            if self.verbose:
                print(f"Epoch {epoch}: Loss = {current_loss:.6f}")

            # Early stopping logic (always enabled)
            if current_loss < best_loss - self.tol:
                best_loss = current_loss
                no_improvement_count = 0
            else:
                no_improvement_count += 1

            if no_improvement_count >= self.n_iter_no_change:
                if self.verbose:
                    print(
                        f"Early stopping at epoch {epoch}. No improvement "
                        f"for {self.n_iter_no_change} iterations.")
                converged = True
                break

        # Warn if training ended without convergence
        if not converged:
            warnings.warn(
                f"Training stopped before reaching convergence. Consider increasing "
                f"max_iter (currently {self.max_iter}) or decreasing tol "
                f"(currently {self.tol}) for better results.",
                UserWarning
            )

        return self

    def predict(self, X):
        """
        Predict class labels for the given node indices.

        Parameters
        ----------
        X : array-like
            Indices of test samples in the graph.

        Returns
        -------
        predictions : ndarray of shape (n_samples,)
            Predicted class labels (0 or 1).

        Raises
        ------
        ValueError
            If the classifier has not been fitted yet.
        """
        probs = self.predict_proba(X)
        predictions = (probs[:, 1] > 0.5).astype(int)
        return predictions

    def predict_proba(self, X):
        """
        Predict class probabilities for the given node indices.

        Parameters
        ----------
        X : array-like
            Indices of test samples in the graph.

        Returns
        -------
        probabilities : ndarray of shape (n_samples, 2)
            Predicted class probabilities. The first column contains
            probabilities for class 0, the second column for class 1.

        Raises
        ------
        ValueError
            If the classifier has not been fitted yet.
        """
        test_indices = X
        if not hasattr(self, 'model_'):
            raise ValueError(
                "This GNNBinaryClassifier instance is not fitted yet.")

        # Convert indices to a tensor and move them to the device
        if not isinstance(test_indices, torch.Tensor):
            test_indices = torch.tensor(
                test_indices, dtype=torch.long, device=self.device)
        else:
            test_indices = test_indices.to(self.device)

        self.model_.eval()
        with torch.no_grad():
            out = self.model_(self.data.x, self.data.edge_index).squeeze()

            # Guard against NaN/Inf in the raw outputs
            raw_outputs = out[test_indices]
            if torch.isnan(raw_outputs).any() or torch.isinf(raw_outputs).any():
                warnings.warn(
                    "Model outputs contain NaN or Inf values. "
                    "Using fallback predictions.")
                # Fall back to neutral probabilities
                proba_positive = np.full(len(test_indices), 0.5)
            else:
                # Convert logits to probabilities
                proba_positive = torch.sigmoid(raw_outputs).cpu().numpy()

        proba_negative = 1 - proba_positive
        return np.column_stack([proba_negative, proba_positive])

    @property
    def classes_(self):
        return np.array([0, 1])
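

# Minimal usage sketch (comments only; `data`, `train_idx` and `test_idx` are
# assumed to exist: `data` is a torch_geometric.data.Data object with `x`,
# `edge_index` and `y`, and the index arrays select train/test nodes). Note
# that the estimator's `X` argument carries node indices, not feature rows.
#
#   from torch_geometric.nn import GCN
#   clf = GNNBinaryClassifier(data, GCN, hidden_dim=32, num_layers=2,
#                             norm='layer_mode=node', max_iter=300)
#   clf.fit(train_idx)
#   proba = clf.predict_proba(test_idx)   # shape (len(test_idx), 2)
#   preds = clf.predict(test_idx)         # 0/1 labels
#   history = clf.loss_curve_             # per-epoch training loss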


class DropTime(BaseEstimator, TransformerMixin):
    """
    Transformer that drops the 'time' column from a DataFrame.

    Useful in scikit-learn pipelines.

    Parameters
    ----------
    drop : bool, default=True
        Whether to drop the 'time' column. If False, the input is passed
        through unchanged.
    """

    def __init__(self, drop=True):
        self.drop = drop

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        if self.drop:
            return X.drop(columns=["time"])
        return X
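

# Illustrative pipeline usage (comments only; assumes `X_train`/`X_test` are
# pandas DataFrames that contain a 'time' column, `y_train` holds the labels,
# and RandomForestClassifier is just an example downstream estimator):
#
#   from sklearn.pipeline import make_pipeline
#   from sklearn.ensemble import RandomForestClassifier
#   pipe = make_pipeline(DropTime(drop=True), RandomForestClassifier())
#   pipe.fit(X_train, y_train)
#   pipe.predict(X_test)                  # 'time' is dropped before the model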


class MLPWrapper(MLPClassifier):
    """
    Wrapper around sklearn's MLPClassifier that allows specifying the number
    of layers and the hidden dimension directly. This is useful for
    hyperparameter tuning where these hyperparameters need to be varied
    independently.

    Some parameters of the base MLPClassifier are fixed to ensure consistent
    behavior:

    - shuffle=False: Disable shuffling to maintain temporal order.
    - early_stopping=False: Disable the internal train/validation split used
      for validation-loss-based early stopping and rely on training-loss-based
      early stopping instead.

    Parameters
    ----------
    num_layers : int, default=2
        Number of hidden layers in the MLP.

    hidden_dim : int, default=16
        Number of units in each hidden layer.

    hidden_layer_sizes : tuple or None, default=None
        If provided, this overrides num_layers and hidden_dim. Should be a
        tuple specifying the size of each hidden layer.

    alpha : float, default=0.0001
        L2 regularization term.

    learning_rate_init : float, default=0.001
        Initial learning rate.

    batch_size : int or 'auto', default='auto'
        Size of minibatches for stochastic optimizers.

    max_iter : int, default=1000
        Maximum number of iterations.
    """

    def __init__(
        self,
        num_layers=2,
        hidden_dim=16,
        hidden_layer_sizes=None,  # Add this to make sklearn happy
        alpha=0.0001,
        learning_rate_init=0.001,
        batch_size='auto',
        max_iter=1000,
    ):
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.hidden_layer_sizes = hidden_layer_sizes  # Store it as an attribute
        self.alpha = alpha
        self.learning_rate_init = learning_rate_init
        self.batch_size = batch_size
        self.max_iter = max_iter

        # Use hidden_layer_sizes if provided, otherwise construct it from
        # num_layers/hidden_dim
        if hidden_layer_sizes is None:
            hidden_layer_sizes = tuple([hidden_dim] * num_layers)

        super().__init__(
            hidden_layer_sizes=hidden_layer_sizes,
            alpha=alpha,
            learning_rate_init=learning_rate_init,
            max_iter=max_iter,
            batch_size=batch_size,
            shuffle=False,
            early_stopping=False,
        )

    def set_params(self, **params):
        if 'num_layers' in params or 'hidden_dim' in params:
            num_layers = params.pop('num_layers', self.num_layers)
            hidden_dim = params.pop('hidden_dim', self.hidden_dim)
            params['hidden_layer_sizes'] = tuple([hidden_dim] * num_layers)
            self.num_layers = num_layers
            self.hidden_dim = hidden_dim
        return super().set_params(**params)
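

# Illustrative hyperparameter search over the decoupled depth/width parameters
# (comments only; `X_train` and `y_train` are assumed feature/label arrays, and
# the grid values are arbitrary examples):
#
#   from sklearn.model_selection import GridSearchCV
#   search = GridSearchCV(
#       MLPWrapper(max_iter=500),
#       param_grid={'num_layers': [1, 2, 3], 'hidden_dim': [16, 32, 64]},
#       scoring='f1',
#   )
#   search.fit(X_train, y_train)
#   search.best_params_                   # e.g. {'hidden_dim': 32, 'num_layers': 2}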