import torch
import numpy as np
import warnings
from torch_geometric.nn import GAT, PairNorm
from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin
from sklearn.neural_network import MLPClassifier
def _get_norm_arg(norm_str):
"""
Utility function to parse normalization layer and arguments from a string.
Parameters
----------
norm_str : str or None
        A string containing the normalization layer and optional keyword arguments,
separated by underscores. For example:
- 'batch' for BatchNorm
- 'layer' for LayerNorm
- 'pair' for PairNorm with default args
- 'pair_scale=0.5' for PairNorm with scale=0.5
- 'layer_mode=node' for LayerNorm with mode='node'
- None for no normalization
Returns
-------
norm : callable or None
Normalization layer class or None.
norm_kwargs : dict
Additional keyword arguments for the normalization layer.
    Notes
    -----
    Since `PairNorm` does not work with the `BasicGNN` API, an already-instantiated
    `PairNorm` object is returned when `pair` is specified.
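
    Examples
    --------
    A few illustrative parses (numeric values are converted to int/float where possible):

    >>> _get_norm_arg('batch')
    ('batch', {})
    >>> _get_norm_arg('layer_mode=node')
    ('layer', {'mode': 'node'})
    >>> _get_norm_arg(None)
    (None, {})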
"""
if not isinstance(norm_str, str):
return norm_str, {}
norm_args = norm_str.split('_')
if len(norm_args) == 1:
return norm_args[0], {}
norm = norm_args[0]
kwargs = {}
for arg in norm_args[1:]:
key, value = arg.split('=')
try:
value = int(value)
except ValueError:
try:
value = float(value)
except ValueError:
pass
kwargs[key] = value
if norm == 'pair':
return PairNorm(**(kwargs or {})), {}
return norm, kwargs
class GNNBinaryClassifier(ClassifierMixin, BaseEstimator):
"""
Graph Neural Network Binary Classifier with early stopping.
A scikit-learn compatible binary classifier that wraps around PyTorch Geometric GNN models.
    Currently supports transductive, full-batch learning models (e.g. GCN, GAT).
The training loss is monitored and the model is considered converged if the loss does not improve
for `n_iter_no_change` consecutive iterations by at least `tol`. This early stopping
mechanism is always enabled, similar to `MLPClassifier` in scikit-learn.
Parameters
----------
data : torch_geometric.data.Data
Graph data object containing node features (x), edge indices (edge_index),
and node labels (y).
model : torch.nn.Module class
The GNN model class to instantiate for training.
hidden_dim : int, default=64
Number of hidden units in each layer.
num_layers : int, default=3
Number of layers in the neural network.
    dropout : float, default=0.5
        Dropout probability for regularization.
    norm : str, callable or None, default=None
        Normalization layer specification, parsed by ``_get_norm_arg`` (e.g. 'batch',
        'layer', 'pair', or 'pair_scale=0.5'). None disables normalization.
    jk : str, default='last'
        Jumping Knowledge mode passed to the underlying model
        (e.g. 'last', 'cat', 'max', or 'lstm').
learning_rate_init : float, default=0.01
Initial learning rate for the Adam optimizer.
weight_decay : float, default=5e-4
L2 regularization strength.
balance_loss : bool, default=True
Whether to balance the loss function by weighting positive samples.
If True, uses positive class weighting in BCEWithLogitsLoss based on
class frequencies. If False, uses unweighted loss.
max_iter : int, default=200
Maximum number of training iterations.
verbose : bool, default=False
Whether to print training progress.
n_iter_no_change : int, default=10
Number of consecutive iterations with no improvement to trigger early stopping.
tol : float, default=1e-4
Tolerance for improvement. Training stops if loss improvement is less than this value.
device : str or torch.device, default='auto'
Device to use for computation. Can be 'cpu', 'cuda', 'auto', or a torch.device object.
If 'auto', will use CUDA if available, otherwise CPU.
heads : int, default=None
Number of attention heads for GAT models. Only applicable when model=GAT.
Ignored with a warning for other model types.
**kwargs : dict
Additional keyword arguments passed to the model constructor.
    Attributes
    ----------
loss_curve_ : list
List of loss values at each training iteration.
model_ : torch.nn.Module
The trained GNN model after calling `fit`.
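
    Examples
    --------
    A minimal usage sketch. ``data`` is assumed to be a ``torch_geometric.data.Data``
    object with binary node labels, and ``train_idx`` / ``test_idx`` are hypothetical
    arrays of node indices:

    >>> from torch_geometric.nn import GCN, GAT
    >>> clf = GNNBinaryClassifier(data, GCN, hidden_dim=32, max_iter=100)
    >>> clf = clf.fit(train_idx)
    >>> proba = clf.predict_proba(test_idx)   # shape (len(test_idx), 2)
    >>> labels = clf.predict(test_idx)        # 0/1 predictions
    >>> gat_clf = GNNBinaryClassifier(data, GAT, heads=4)  # 'heads' only applies to GAT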
"""
def _validate_data(self, data):
"""
Validate that the data object has required attributes.
Parameters
----------
data : object
Data object to validate.
Returns
-------
data : object
Validated data object.
Raises
------
ValueError
If data object is missing required attributes.
"""
attributes = ['x', 'edge_index', 'y']
for attr in attributes:
if not hasattr(data, attr) or getattr(data, attr) is None:
raise ValueError(
f"Data object must have '{attr}' attribute and be non-null.")
return data
def _validate_device(self, device):
"""
Validate and set the device for computation.
Parameters
----------
device : str or torch.device
Device specification.
Returns
-------
torch.device
Validated device object.
        Raises
        ------
        ValueError
            If the device specification is invalid.
        Warns
        -----
        UserWarning
            If CUDA is requested but not available; the CPU is used instead.
"""
if device == 'auto':
return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
try:
if isinstance(device, torch.device):
device_obj = device
else:
device_obj = torch.device(device)
if device_obj.type == 'cuda' and not torch.cuda.is_available():
warnings.warn(
"CUDA is not available, falling back to CPU", UserWarning)
return torch.device('cpu')
return device_obj
except (RuntimeError, ValueError) as e:
raise ValueError(f"Invalid device '{device}': {e}")
def __init__(
self,
data,
model,
hidden_dim=64,
num_layers=3,
dropout=0.5,
norm=None,
jk='last',
learning_rate_init=0.01,
weight_decay=5e-4,
balance_loss=True,
max_iter=200,
verbose=False,
n_iter_no_change=10,
tol=1e-4,
device='auto',
heads=None,
**kwargs,
):
super().__init__()
self.data = self._validate_data(data)
self.model = model
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.dropout = dropout
self.norm = norm
self.jk = jk
self.learning_rate_init = learning_rate_init
self.weight_decay = weight_decay
self.balance_loss = balance_loss
self.max_iter = max_iter
self.verbose = verbose
self.n_iter_no_change = n_iter_no_change
self.tol = tol
self.device = self._validate_device(device) # Store validated device
self.heads = heads
self.kwargs = kwargs
# Move data to device
self.data = self.data.to(self.device)
# Handle heads parameter properly
if heads is not None and model != GAT:
warnings.warn(
"'heads' parameter is only applicable for GAT model. Ignoring 'heads'.",
UserWarning)
self.heads = heads # Store for sklearn but don't use
elif model == GAT:
# For GAT, use heads if provided, otherwise default to 1
actual_heads = heads if heads is not None else 1
self.heads = actual_heads
self.kwargs['heads'] = actual_heads
else:
# For non-GAT models with heads=None
self.heads = heads
if self.verbose:
print(f"Using device: {self.device}")
def _get_pos_weight(self, indices):
"""
Calculate positive class weight for balanced loss computation.
Parameters
----------
indices : torch.Tensor or array-like
Indices of training samples.
Returns
-------
torch.Tensor
Weight for positive class to balance the loss.
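
        Notes
        -----
        The weight is the ratio of negative to positive training labels. For
        example, 90 negative and 10 positive samples yield a weight of 9.0.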
"""
y = self.data.y[indices]
pos_weight = (y == 0).sum() / (y == 1).sum()
return pos_weight.to(self.device)
def fit(self, X, y=None):
"""
Fit the GNN model to the training data.
Training automatically stops when the loss stops improving for
n_iter_no_change consecutive iterations, similar to MLPClassifier.
Parameters
----------
        X : array-like
            Indices of training samples (nodes) in the graph.
y : array-like, default=None
Target values (ignored, present for sklearn compatibility).
Returns
-------
self : GNNBinaryClassifier
Returns self for method chaining.
Warns
-----
UserWarning
If training stops due to max_iter being reached without convergence.
"""
train_indices = X
num_features = self.data.x.shape[1]
norm_layer, norm_kwargs = _get_norm_arg(self.norm)
self.model_ = self.model(
in_channels=num_features,
hidden_channels=self.hidden_dim,
out_channels=1,
num_layers=self.num_layers,
dropout=self.dropout,
norm=norm_layer,
norm_kwargs=norm_kwargs,
jk=self.jk,
**self.kwargs
).to(self.device)
optimizer = torch.optim.Adam(
self.model_.parameters(),
lr=self.learning_rate_init,
weight_decay=self.weight_decay)
# Convert indices to tensor and move to device
if not isinstance(train_indices, torch.Tensor):
train_indices = torch.tensor(
train_indices, dtype=torch.long, device=self.device)
else:
train_indices = train_indices.to(self.device)
if self.balance_loss:
criterion = torch.nn.BCEWithLogitsLoss(
pos_weight=self._get_pos_weight(train_indices))
else:
criterion = torch.nn.BCEWithLogitsLoss()
# Early stopping variables
best_loss = float('inf')
no_improvement_count = 0
self.loss_curve_ = []
self.model_.train()
converged = False
for epoch in range(1, self.max_iter + 1):
optimizer.zero_grad()
out = self.model_(self.data.x, self.data.edge_index).squeeze()
loss = criterion(
out[train_indices],
self.data.y[train_indices].float())
loss.backward()
optimizer.step()
current_loss = loss.item()
self.loss_curve_.append(current_loss)
if self.verbose:
print(f"Epoch {epoch}: Loss = {current_loss:.6f}")
# Early stopping logic (always enabled)
if current_loss < best_loss - self.tol:
best_loss = current_loss
no_improvement_count = 0
else:
no_improvement_count += 1
if no_improvement_count >= self.n_iter_no_change:
if self.verbose:
                        print(
                            f"Early stopping at epoch {epoch}. No improvement for "
                            f"{self.n_iter_no_change} iterations.")
converged = True
break
# Warn if training ended without convergence
if not converged:
warnings.warn(
f"Training stopped before reaching convergence. Consider increasing "
f"max_iter (currently {self.max_iter}) or decreasing tol "
f"(currently {self.tol}) for better results.",
UserWarning
)
return self
def predict(self, X):
"""
        Predict class labels for the given node indices.
        Parameters
        ----------
        X : array-like
            Indices of test samples (nodes) in the graph.
Returns
-------
predictions : ndarray of shape (n_samples,)
Predicted class labels (0 or 1).
Raises
------
ValueError
If the classifier has not been fitted yet.
"""
probs = self.predict_proba(X)
predictions = (probs[:, 1] > 0.5).astype(int)
return predictions
def predict_proba(self, X):
"""
        Predict class probabilities for the given node indices.
        Parameters
        ----------
        X : array-like
            Indices of test samples (nodes) in the graph.
Returns
-------
probabilities : ndarray of shape (n_samples, 2)
Predicted class probabilities. First column contains probabilities
for class 0, second column for class 1.
Raises
------
ValueError
If the classifier has not been fitted yet.
"""
test_indices = X
if not hasattr(self, 'model_'):
raise ValueError(
"This GNNBinaryClassifier instance is not fitted yet.")
# Convert indices to tensor and move to device
if not isinstance(test_indices, torch.Tensor):
test_indices = torch.tensor(
test_indices, dtype=torch.long, device=self.device)
else:
test_indices = test_indices.to(self.device)
self.model_.eval()
with torch.no_grad():
out = self.model_(self.data.x, self.data.edge_index).squeeze()
# Debug: Check for inf/nan in raw outputs
raw_outputs = out[test_indices]
if torch.isnan(raw_outputs).any(
) or torch.isinf(raw_outputs).any():
warnings.warn(
"Model outputs contain NaN or Inf values. Using fallback predictions.")
# Fallback to neutral probabilities
proba_positive = np.full(len(test_indices), 0.5)
else:
                # Map logits to probabilities with the sigmoid
                proba_positive = torch.sigmoid(raw_outputs).cpu().numpy()
proba_negative = 1 - proba_positive
return np.column_stack([proba_negative, proba_positive])
@property
def classes_(self):
return np.array([0, 1])
class DropTime(BaseEstimator, TransformerMixin):
"""
Transformer for dropping the 'time' column from a DataFrame.
Useful in scikit-learn pipelines.
"""
def __init__(self, drop=True):
self.drop = drop
def fit(self, X, y=None):
return self
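
    def transform(self, X):
        # Hedged sketch: the original module did not define ``transform``. This
        # assumed implementation follows the class docstring and drops the
        # 'time' column from a pandas DataFrame when ``drop`` is True.
        if self.drop and 'time' in getattr(X, 'columns', []):
            return X.drop(columns=['time'])
        return X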
class MLPWrapper(MLPClassifier):
"""
    Wrapper around sklearn's MLPClassifier that allows specifying the number of layers
    and the hidden dimension directly. This is useful for hyperparameter tuning, where
    each tunable parameter should be an independent scalar rather than a tuple.
Some parameters of the base MLPClassifier are fixed to ensure consistent behavior:
- shuffle=False: Disable shuffling to maintain temporal order.
    - early_stopping=False: Disable the internal validation split used for
      validation-loss-based early stopping; stopping is based on the training loss instead.
Parameters
----------
num_layers : int, default=2
Number of hidden layers in the MLP.
hidden_dim : int, default=16
Number of units in each hidden layer.
hidden_layer_sizes : tuple or None, default=None
If provided, this overrides num_layers and hidden_dim.
Should be a tuple specifying the size of each hidden layer.
alpha : float, default=0.0001
L2 regularization term.
learning_rate_init : float, default=0.001
Initial learning rate.
batch_size : int or 'auto', default='auto'
Size of minibatches for stochastic optimizers.
max_iter : int, default=1000
Maximum number of iterations.
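
    Examples
    --------
    A small illustration of how ``num_layers`` and ``hidden_dim`` map onto the
    underlying ``hidden_layer_sizes``:

    >>> mlp = MLPWrapper(num_layers=3, hidden_dim=32)
    >>> mlp.hidden_layer_sizes
    (32, 32, 32)
    >>> mlp.set_params(num_layers=2, hidden_dim=8).hidden_layer_sizes
    (8, 8)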
"""
def __init__(
self,
num_layers=2,
hidden_dim=16,
        hidden_layer_sizes=None,  # Exposed so sklearn's get_params/clone can see it
alpha=0.0001,
learning_rate_init=0.001,
batch_size='auto',
max_iter=1000,
):
self.num_layers = num_layers
self.hidden_dim = hidden_dim
self.hidden_layer_sizes = hidden_layer_sizes # Store it as an attribute
self.alpha = alpha
self.learning_rate_init = learning_rate_init
self.batch_size = batch_size
self.max_iter = max_iter
# Use hidden_layer_sizes if provided, otherwise construct from
# num_layers/hidden_dim
if hidden_layer_sizes is None:
hidden_layer_sizes = tuple([hidden_dim] * num_layers)
super().__init__(
hidden_layer_sizes=hidden_layer_sizes,
alpha=alpha,
learning_rate_init=learning_rate_init,
max_iter=max_iter,
batch_size=batch_size,
shuffle=False,
early_stopping=False,
)
def set_params(self, **params):
if 'num_layers' in params or 'hidden_dim' in params:
num_layers = params.pop('num_layers', self.num_layers)
hidden_dim = params.pop('hidden_dim', self.hidden_dim)
params['hidden_layer_sizes'] = tuple([hidden_dim] * num_layers)
self.num_layers = num_layers
self.hidden_dim = hidden_dim
return super().set_params(**params)