elliptic_toolkit.model_wrappers module
- class elliptic_toolkit.model_wrappers.GNNBinaryClassifier(data, model, hidden_dim=64, num_layers=3, dropout=0.5, norm=None, jk='last', learning_rate_init=0.01, weight_decay=0.0005, balance_loss=True, max_iter=200, verbose=False, n_iter_no_change=10, tol=0.0001, device='auto', heads=None, **kwargs)[source]
Bases: ClassifierMixin, BaseEstimator
Graph Neural Network Binary Classifier with early stopping.
A scikit-learn-compatible binary classifier that wraps PyTorch Geometric GNN models. Currently supports transductive, full-batch learning models (GCN, GAT).
The training loss is monitored, and training is considered converged once the loss fails to improve by at least tol for n_iter_no_change consecutive iterations. This early stopping mechanism is always enabled, similar to MLPClassifier in scikit-learn.
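The stopping rule above can be sketched in plain Python. This is only an illustration of a patience-style check under assumed details: the helper name replay_early_stopping is hypothetical, and the exact comparison used inside GNNBinaryClassifier is not shown in this reference.

```python
def replay_early_stopping(losses, n_iter_no_change=10, tol=1e-4):
    """Replay per-iteration losses and report when training would stop.

    Returns (iterations_run, converged). Hypothetical helper, not part of
    elliptic_toolkit.
    """
    best_loss = float("inf")
    stale = 0  # consecutive iterations without an improvement of at least tol
    for iteration, loss in enumerate(losses, start=1):
        if loss < best_loss - tol:
            # Improvement of at least tol: remember it and reset the counter.
            best_loss = loss
            stale = 0
        else:
            stale += 1
        if stale >= n_iter_no_change:
            return iteration, True
    # Ran out of iterations (the max_iter case) without converging.
    return len(losses), False
```

For example, replay_early_stopping([1.0, 0.5, 0.4, 0.4, 0.4, 0.4], n_iter_no_change=3) stops at the sixth iteration, after three iterations in a row without an improvement of at least tol.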
- Parameters:
data (torch_geometric.data.Data) – Graph data object containing node features (x), edge indices (edge_index), and node labels (y).
model (torch.nn.Module) – The GNN model class to instantiate for training.
hidden_dim (int, default=64) – Number of hidden units in each layer.
num_layers (int, default=3) – Number of layers in the neural network.
dropout (float, default=0.5) – Dropout probability for regularization.
learning_rate_init (float, default=0.01) – Initial learning rate for the Adam optimizer.
weight_decay (float, default=5e-4) – L2 regularization strength.
balance_loss (bool, default=True) – Whether to balance the loss function by weighting positive samples. If True, uses positive class weighting in BCEWithLogitsLoss based on class frequencies. If False, uses unweighted loss.
max_iter (int, default=200) – Maximum number of training iterations.
verbose (bool, default=False) – Whether to print training progress.
n_iter_no_change (int, default=10) – Number of consecutive iterations with no improvement to trigger early stopping.
tol (float, default=1e-4) – Tolerance for improvement. Training stops if the loss improves by less than this value for n_iter_no_change consecutive iterations.
device (str or torch.device, default='auto') – Device to use for computation. Can be ‘cpu’, ‘cuda’, ‘auto’, or a torch.device object. If ‘auto’, will use CUDA if available, otherwise CPU.
heads (int, default=None) – Number of attention heads for GAT models. Only applicable when model=GAT. Ignored with a warning for other model types.
**kwargs (dict) – Additional keyword arguments passed to the model constructor.
- Attributes:
loss_curve (list) – List of loss values at each training iteration.
model – The trained GNN model after calling fit.
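To make the balance_loss option more concrete, here is a minimal pure-Python sketch of a positively weighted binary cross-entropy on logits. It assumes the weight is the ratio of negative to positive labels, which is a common choice for the pos_weight argument of BCEWithLogitsLoss; the reference only says the weight is "based on class frequencies", so the exact formula is an assumption. The function names are hypothetical.

```python
import math


def softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def positive_class_weight(labels):
    """Assumed weighting: number of negatives divided by number of positives."""
    n_pos = sum(1 for y in labels if y == 1)
    n_neg = sum(1 for y in labels if y == 0)
    if n_pos == 0:
        raise ValueError("at least one positive label is required")
    return n_neg / n_pos


def weighted_bce_with_logits(logits, labels, pos_weight=1.0):
    """Mean binary cross-entropy on raw logits; positive terms scaled by pos_weight."""
    total = 0.0
    for z, y in zip(logits, labels):
        # -log(sigmoid(z)) == softplus(-z) and -log(1 - sigmoid(z)) == softplus(z)
        total += pos_weight * y * softplus(-z) + (1 - y) * softplus(z)
    return total / len(labels)
```

With pos_weight=1.0 this reduces to the unweighted loss, which corresponds to balance_loss=False.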
- __init__(data, model, hidden_dim=64, num_layers=3, dropout=0.5, norm=None, jk='last', learning_rate_init=0.01, weight_decay=0.0005, balance_loss=True, max_iter=200, verbose=False, n_iter_no_change=10, tol=0.0001, device='auto', heads=None, **kwargs)[source]
- fit(X, y=None)[source]
Fit the GNN model to the training data.
Training automatically stops when the loss stops improving for n_iter_no_change consecutive iterations, similar to MLPClassifier.
- Parameters:
X (array-like) – Indices of training samples in the graph.
y (array-like, default=None) – Target values (ignored, present for sklearn compatibility).
- Returns:
self – Returns self for method chaining.
- Return type:
- Warns:
UserWarning – If training stops due to max_iter being reached without convergence.
- predict(X)[source]
Predict class labels for the samples whose indices are given in X.
- Parameters:
X (array-like) – Indices of test samples in the graph.
- Returns:
predictions – Predicted class labels (0 or 1).
- Return type:
ndarray of shape (n_samples,)
- Raises:
ValueError – If the classifier has not been fitted yet.
- predict_proba(X)[source]
Predict class probabilities for the samples whose indices are given in X.
- Parameters:
X (array-like) – Indices of test samples in the graph.
- Returns:
probabilities – Predicted class probabilities. First column contains probabilities for class 0, second column for class 1.
- Return type:
ndarray of shape (n_samples, 2)
- Raises:
ValueError – If the classifier has not been fitted yet.
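The relationship between the two-column output of predict_proba and the 0/1 labels of predict can be illustrated as follows. This sketch assumes the model emits one raw logit per node (consistent with the use of BCEWithLogitsLoss) and that labels are obtained with a 0.5 threshold; both are assumptions, and the helper names are hypothetical. Plain lists stand in for the ndarrays returned by the real methods.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def proba_from_logits(logits):
    """Return one row [P(class 0), P(class 1)] per logit."""
    return [[1.0 - p, p] for p in map(sigmoid, logits)]


def labels_from_proba(proba, threshold=0.5):
    """Assign class 1 when P(class 1) reaches the threshold (0.5 is assumed)."""
    return [int(row[1] >= threshold) for row in proba]
```

Each row sums to 1, and the second column is the probability of the positive class.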
- property classes_
- set_score_request(*, sample_weight: bool | None | str = '$UNCHANGED$') GNNBinaryClassifier
Configure whether metadata should be requested to be passed to the score method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to score.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- class elliptic_toolkit.model_wrappers.DropTime(drop=True)[source]
Bases: BaseEstimator, TransformerMixin
Transformer for dropping the ‘time’ column from a DataFrame. Useful in scikit-learn pipelines.
- class elliptic_toolkit.model_wrappers.MLPWrapper(num_layers=2, hidden_dim=16, hidden_layer_sizes=None, alpha=0.0001, learning_rate_init=0.001, batch_size='auto', max_iter=1000)[source]
Bases: MLPClassifier
Wrapper around sklearn’s MLPClassifier to allow specifying the number of layers and hidden dimension directly. This is useful for hyperparameter tuning where hyperparameters need to be independent. Some parameters of the base MLPClassifier are fixed to ensure consistent behavior:
- shuffle=False: Disable shuffling to maintain temporal order.
- early_stopping=False: Disable internal test/validation split for validation loss based early stopping and use training loss based early stopping instead.
- Parameters:
num_layers (int, default=2) – Number of hidden layers in the MLP.
hidden_dim (int, default=16) – Number of units in each hidden layer.
hidden_layer_sizes (tuple or None, default=None) – If provided, this overrides num_layers and hidden_dim. Should be a tuple specifying the size of each hidden layer.
alpha (float, default=0.0001) – L2 regularization term.
learning_rate_init (float, default=0.001) – Initial learning rate.
batch_size (int or 'auto', default='auto') – Size of minibatches for stochastic optimizers.
max_iter (int, default=1000) – Maximum number of iterations.
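The interplay between num_layers, hidden_dim, and hidden_layer_sizes described above can be sketched as a small helper. The name resolve_hidden_layer_sizes is hypothetical and the wrapper's actual code may differ; the sketch only assumes that every hidden layer gets hidden_dim units unless an explicit tuple is supplied.

```python
def resolve_hidden_layer_sizes(num_layers=2, hidden_dim=16, hidden_layer_sizes=None):
    """Return the tuple of hidden layer widths an MLP would be built with."""
    if hidden_layer_sizes is not None:
        # An explicit tuple overrides num_layers and hidden_dim.
        return tuple(hidden_layer_sizes)
    # Otherwise repeat the same width once per hidden layer.
    return (hidden_dim,) * num_layers
```

Keeping depth and width as separate scalars lets a search space vary each one independently, for example num_layers in {1, 2, 3} and hidden_dim in {16, 32, 64}, instead of enumerating tuples.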
- __init__(num_layers=2, hidden_dim=16, hidden_layer_sizes=None, alpha=0.0001, learning_rate_init=0.001, batch_size='auto', max_iter=1000)[source]
- set_fit_request(*, sample_weight: bool | None | str = '$UNCHANGED$') MLPWrapper
Configure whether metadata should be requested to be passed to the fit method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to fit.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- set_partial_fit_request(*, classes: bool | None | str = '$UNCHANGED$', sample_weight: bool | None | str = '$UNCHANGED$') MLPWrapper
Configure whether metadata should be requested to be passed to the partial_fit method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to partial_fit if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to partial_fit.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- Parameters:
- Returns:
self – The updated object.
- Return type:
- set_score_request(*, sample_weight: bool | None | str = '$UNCHANGED$') MLPWrapper
Configure whether metadata should be requested to be passed to the score method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to score if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to score.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- set_params(**params)[source]
Set the parameters of this estimator.
The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.
- Parameters:
**params (dict) – Estimator parameters.
- Returns:
self – Estimator instance.
- Return type:
estimator instance
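The <component>__<parameter> naming convention can be illustrated with a short sketch that separates an estimator's own parameters from those addressed to nested components. The helper split_nested_params is hypothetical and only mimics the key-splitting idea; it is not the library's implementation.

```python
def split_nested_params(params):
    """Group keys of the form <component>__<parameter> by component name."""
    own, nested = {}, {}
    for key, value in params.items():
        # Split on the first double underscore only, so deeper nesting
        # (a__b__c) is passed down to component "a" as "b__c".
        name, delim, sub_key = key.partition("__")
        if delim:
            nested.setdefault(name, {})[sub_key] = value
        else:
            own[name] = value
    return own, nested
```

For a pipeline with a step named clf (an assumed name), the key clf__hidden_dim would end up in the nested group for clf, while a plain key such as verbose stays with the outer object.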