
GAT-RWOS: Graph Attention-Guided Random Walk Oversampling for Imbalanced Data Classification. A graph attention-guided oversampling method for imbalanced data classification using GAT and random walks.





GAT-RWOS: Graph Attention-Guided Random Walk Oversampling for Imbalanced Data Classification

License: MIT


GAT-RWOS is a graph-based oversampling method that combines Graph Attention Networks (GATs) with random walk-based oversampling to address the class imbalance problem. By utilizing GAT's attention mechanism to guide random walks through informative neighborhoods of minority nodes, GAT-RWOS generates synthetic samples that expand class boundaries while preserving the original data distribution.

TL;DR: We generate synthetic samples around the instances that the GAT deems important.




Requirements

  • Python >= 3.9 (tested on 3.9-3.12)
  • PyTorch (install it by following the official PyTorch instructions for your platform)
  • pandas
  • numpy
  • scikit-learn
  • optuna
  • tqdm
  • pyyaml
  • scipy
  • xgboost
  • pytorch_geometric

Note: these dependencies are installed automatically when you install this package.

Installation

# Clone the repository
git clone
cd gat-rwos

# Install in development mode
pip install -e .


Usage

GAT-RWOS can be used in three ways:

  1. Command-line interface:
     gat_rwos --datasets yeast6 flare-F
  2. Python module:
     python -m gat_rwos.main --datasets yeast6
  3. As a Python package:
     from gat_rwos import main
     main("yeast6")

Command Line Arguments

  • --datasets: Names of the datasets to process (file names without the .csv extension)
  • --tune: Enable hyperparameter tuning. The tuning settings are controlled in the config file
  • --random_state: Random seed; overrides the random_state value set in the config file
  • --config: Path to configuration file (default: configs/config.yaml)

Data Format

You can use the datasets provided in the data/ folder or your own. In either case, the input data must be a CSV file that meets the following requirements:

  • The target variable must be named 'class'
  • Binary classification only
  • No missing values
  • Features can be either numerical or categorical

Example dataset structure:


Project Structure

├── src/gat_rwos/    # Main package source code
├── configs/         # Configuration files
├── data/            # Dataset files (.csv format)
└── results/         # Generated results

The results folder will contain:

  • {dataset_name}/
    • {dataset_name}_results.csv # Results with oversampled data
    • {dataset_name}_original_results.csv # Results with original data
    • {dataset_name}_balanced.csv # Generated balanced dataset
    • {dataset_name}_original_vs_oversampled.png # Visualization


Configuration

GAT-RWOS uses a YAML configuration file to control all aspects of the pipeline. Here are the key configuration parameters:


Data Processing

  scaler: "minmax"      # Data scaling method: minmax, standard, or none
  test_size: 0.1        # Proportion of test set
  val_size: 0.1         # Proportion of validation set
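The `scaler` option names standard per-feature transforms. The sketch below uses a hypothetical `scale_column` helper for illustration only (the package may well use scikit-learn's `MinMaxScaler` and `StandardScaler` instead) to show what each setting computes for a single feature column:

```python
from statistics import mean, pstdev
from typing import List


def scale_column(values: List[float], method: str = "minmax") -> List[float]:
    """Scale one feature column according to a 'scaler' setting (illustrative sketch)."""
    if method == "none":
        return list(values)
    if method == "minmax":
        lo, hi = min(values), max(values)
        span = hi - lo
        # A constant column has no range; map it to 0.0 instead of dividing by zero.
        return [(v - lo) / span if span else 0.0 for v in values]
    if method == "standard":
        # Population standard deviation (ddof=0), as scikit-learn's StandardScaler uses.
        mu, sigma = mean(values), pstdev(values)
        return [(v - mu) / sigma if sigma else 0.0 for v in values]
    raise ValueError(f"unknown scaler: {method!r}")


print(scale_column([2.0, 4.0, 6.0], "minmax"))  # [0.0, 0.5, 1.0]
print(scale_column([2.0, 4.0, 6.0], "standard"))
```

As a general rule, scaling statistics should be computed on the training split only and then applied to the validation and test splits, so no test information leaks into training.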

Graph Construction

  similarity_method: "euclidean"   # Distance metric: euclidean, cosine, manhattan
  similarity_threshold: 0.5        # Threshold for edge creation. Setting this higher will result in a sparser graph
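A minimal sketch of threshold-based graph construction, assuming an edge connects every pair of samples whose similarity reaches `similarity_threshold`. The README does not say how distances are turned into similarities; the `1 / (1 + d)` mapping below is an assumption, and `similarity` and `build_edges` are hypothetical names:

```python
import math
from itertools import combinations


def similarity(a, b, method="euclidean"):
    """Raw cosine similarity, or a distance mapped into (0, 1]."""
    if method == "cosine":
        dot = sum(x * y for x, y in zip(a, b))
        norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
        return dot / norm if norm else 0.0
    if method == "euclidean":
        dist = math.dist(a, b)
    elif method == "manhattan":
        dist = sum(abs(x - y) for x, y in zip(a, b))
    else:
        raise ValueError(f"unknown similarity_method: {method!r}")
    # Assumption: distances become similarities via 1 / (1 + d).
    return 1.0 / (1.0 + dist)


def build_edges(points, method="euclidean", threshold=0.5):
    """Undirected edge list (i, j), i < j, for every pair with similarity >= threshold."""
    return [
        (i, j)
        for i, j in combinations(range(len(points)), 2)
        if similarity(points[i], points[j], method) >= threshold
    ]


points = [(0.0, 0.0), (0.1, 0.0), (1.0, 1.0)]
print(build_edges(points, "euclidean", 0.5))  # [(0, 1)]
```

Lowering the threshold to 0.4 in this toy example connects all three points, which matches the note that a higher threshold gives a sparser graph.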

GAT Architecture

hid: 32              # Hidden dimension size
in_head: 4           # Number of attention heads in input layer
out_head: 3          # Number of attention heads in output layer
dropout_rate: 0.3    # Dropout rate
num_hidden_layers: 3 # Number of hidden layers
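How `hid`, `in_head`, and `out_head` interact depends on whether head outputs are concatenated. In PyTorch Geometric's `GATConv`, `concat=True` (the default) produces `heads * out_channels` features, while `concat=False` averages the heads down to `out_channels`. The sketch below assumes, without confirmation from the README, a common GAT layout: `num_hidden_layers` layers that concatenate `in_head` heads, then one output layer that averages `out_head` heads into the two classes. The hypothetical `gat_layer_widths` only computes layer widths:

```python
def gat_layer_widths(in_features, hid=32, in_head=4, out_head=3,
                     num_hidden_layers=3, num_classes=2):
    """Return (input_width, output_width) for each GAT layer under the assumed layout.

    Because the output layer averages its out_head heads, its width does not
    depend on out_head; the parameter is kept only to mirror the config keys.
    """
    widths = []
    width = in_features
    for _ in range(num_hidden_layers):
        widths.append((width, hid * in_head))  # in_head heads of size hid, concatenated
        width = hid * in_head
    widths.append((width, num_classes))  # out_head heads averaged, not concatenated
    return widths


print(gat_layer_widths(in_features=8))
# [(8, 128), (128, 128), (128, 128), (128, 2)]
```

Under this layout, the hidden width grows linearly with both `hid` and `in_head`, which is why reducing either helps with GPU memory.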

Attention Aggregation

aggregation_method: "mean" # Aggregation method: mean, median, max, mul. This is how we combine the attention weights into an attention matrix
attention_threshold: 0.5  # This controls the "importance" of the connections. Setting this higher will only keep stronger connections (i.e., pairs of nodes with higher attention weights).
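The aggregation step can be sketched as follows, under stated assumptions: each edge carries one attention weight per head, the weights are reduced with the chosen `aggregation_method`, and an edge survives when the result is at least `attention_threshold`. The package combines the weights into an attention matrix; a dict keyed by edge is used here as an equivalent sparse form, and `aggregate_attention` is a hypothetical name:

```python
from math import prod
from statistics import mean, median

# Supported values of aggregation_method, each reducing a list of head weights to one number.
AGGREGATORS = {"mean": mean, "median": median, "max": max, "mul": prod}


def aggregate_attention(head_weights, method="mean", threshold=0.5):
    """Combine per-head attention weights per edge, then keep edges >= threshold."""
    combine = AGGREGATORS[method]
    combined = {edge: combine(weights) for edge, weights in head_weights.items()}
    return {edge: w for edge, w in combined.items() if w >= threshold}


# Hypothetical attention weights from three heads for three edges.
head_weights = {
    (0, 1): [0.9, 0.7, 0.8],
    (1, 2): [0.6, 0.2, 0.4],
    (0, 2): [0.5, 0.5, 0.5],
}
for method in AGGREGATORS:
    print(method, sorted(aggregate_attention(head_weights, method, threshold=0.5)))
```

Note that `mul` multiplies weights in [0, 1], so it never exceeds the smallest head weight; at the same threshold it keeps fewer edges than the other methods.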

Random Walk Parameters

num_steps: 10        # Length of random walks
p: 0.5              # Return parameter (controls likelihood of returning to previous node)
q: 2.0              # In-out parameter (controls search behavior)
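The descriptions of `p` and `q` match the node2vec biased random walk. The sketch below implements that standard second-order walk on an unweighted adjacency list; whether GAT-RWOS uses exactly these transition weights (or also scales them by attention) is an assumption, and `biased_random_walk` is a hypothetical name:

```python
import random


def biased_random_walk(adj, start, num_steps=10, p=0.5, q=2.0, seed=None):
    """node2vec-style second-order random walk taking up to num_steps steps.

    adj maps each node to a set of neighbours. Having moved t -> v, the
    unnormalised weight for stepping from v to a neighbour x is:
      1/p  if x == t                 (return to the previous node)
      1    if x is a neighbour of t  (stay close)
      1/q  otherwise                 (move outward)
    """
    rng = random.Random(seed)
    walk = [start]
    for _ in range(num_steps):
        current = walk[-1]
        neighbours = sorted(adj.get(current, ()))
        if not neighbours:
            break  # isolated node: the walk cannot continue
        if len(walk) == 1:
            walk.append(rng.choice(neighbours))  # no previous node yet: uniform first step
            continue
        previous = walk[-2]
        weights = [
            1.0 / p if x == previous else 1.0 if x in adj[previous] else 1.0 / q
            for x in neighbours
        ]
        walk.append(rng.choices(neighbours, weights=weights, k=1)[0])
    return walk


# Toy graph (hypothetical): a triangle 0-1-2 with a tail node 3 attached to 2.
adj = {0: {1, 2}, 1: {0, 2}, 2: {0, 1, 3}, 3: {2}}
walk = biased_random_walk(adj, start=0, num_steps=10, seed=42)
print(walk)
```

With the defaults (`p = 0.5`, `q = 2.0`), returning to the previous node gets weight 2 and moving outward gets weight 0.5, so walks tend to stay in a local, breadth-first-like neighborhood of the start node.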

Interpolation Settings

num_interpolations: 15  # Number of interpolations per path
min_alpha: 0.1         # Minimum interpolation weight
max_alpha: 0.9         # Maximum interpolation weight
variability: 0.9      # Variability of interpolation weights
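One way to read these settings, sketched here as an assumption rather than the package's exact procedure: each synthetic sample lies on the segment between two consecutive nodes of a walk, at a position `alpha` drawn uniformly from `[min_alpha, max_alpha]`. The sketch does not model `variability`, whose exact role the README leaves open, and `interpolate_path` is a hypothetical name:

```python
import random


def interpolate_path(path_points, num_interpolations=15, min_alpha=0.1,
                     max_alpha=0.9, seed=None):
    """Generate synthetic points on segments between consecutive walk nodes.

    Each sample picks a random consecutive pair (a, b) on the path and an
    alpha in [min_alpha, max_alpha]; the new point is a + alpha * (b - a).
    """
    if len(path_points) < 2:
        return []  # a single node has no segment to interpolate along
    rng = random.Random(seed)
    samples = []
    for _ in range(num_interpolations):
        k = rng.randrange(len(path_points) - 1)
        a, b = path_points[k], path_points[k + 1]
        alpha = rng.uniform(min_alpha, max_alpha)
        samples.append([x + alpha * (y - x) for x, y in zip(a, b)])
    return samples


# Feature vectors of the nodes visited by a hypothetical walk.
path = [(0.0, 0.0), (1.0, 1.0), (2.0, 0.0)]
samples = interpolate_path(path, seed=0)
print(len(samples))  # 15
```

Keeping `alpha` away from 0 and 1 avoids producing near-duplicates of existing minority samples.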

Hyperparameter Tuning

GAT-RWOS uses Optuna for hyperparameter optimization. Tuning proceeds in three hierarchical stages (the number of trials for each is set in the config file):

  1. Main Model Parameters (n_trials_main): Optimizes GAT architecture parameters (hidden dimensions, number of heads, dropout rate)
  2. Attention Parameters (n_trials_attention): Tunes attention threshold and aggregation method
  3. Interpolation Parameters (n_trials_interpolation): Optimizes random walk and interpolation settings

Total number of trials = n_trials_main * n_trials_attention * n_trials_interpolation. Setting these to higher values will result in a more thorough search but at the cost of increased computation time.
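The multiplication follows from nesting the three stages. Assuming each main trial trains one GAT, each attention trial reuses that model, and each interpolation trial reuses the filtered attention graph (an assumption about the nesting, not confirmed by the README), the hypothetical `stage_counts` below tallies the work done at each level:

```python
def stage_counts(n_trials_main, n_trials_attention, n_trials_interpolation):
    """Count how often each stage runs in a nested three-stage search."""
    counts = {"gat_trainings": 0, "attention_settings": 0, "oversample_evaluations": 0}
    for _ in range(n_trials_main):  # stage 1: GAT architecture
        counts["gat_trainings"] += 1
        for _ in range(n_trials_attention):  # stage 2: threshold and aggregation
            counts["attention_settings"] += 1
            for _ in range(n_trials_interpolation):  # stage 3: walks and interpolation
                counts["oversample_evaluations"] += 1
    return counts


print(stage_counts(10, 5, 20))
# {'gat_trainings': 10, 'attention_settings': 50, 'oversample_evaluations': 1000}
```

The innermost count equals the product of the three settings, so doubling any one of them doubles the total number of trials.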

The tuning ranges can be configured in the config file under the tuning.optuna.ranges section:

      similarity_methods: ["cosine", "euclidean", "manhattan"]
        min: 4
        max: 64
        min: 0.0
        max: 0.7
      # ... other parameters



Common Issues and Solutions

  1. CUDA Out of Memory

    • Try reducing the hid parameter in the config file
    • Decrease num_hidden_layers
  2. Graph Construction Failure

    • Try a different similarity_method (e.g., switch from 'euclidean' to 'cosine')
    • Lower the similarity_threshold to create more connections
    • Ensure your data is properly scaled (this makes a large difference)
  3. Poor Performance

    • Increase the number of trials in tuning parameters (n_trials_main, n_trials_attention, n_trials_interpolation)
    • Try a different aggregation_method and tune attention_threshold
    • Adjust the random walk parameters (num_steps, p, q)
  4. Pandas Error on Python 3.12

    • RecursionError: maximum recursion depth exceeded: upgrade both pandas and numpy to their latest versions.








