diff --git a/.gitignore b/.gitignore index 1d3b5479..ac7f184e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,9 @@ __pycache__ +.idea build +data +logs dist venv wilds.egg-info +.DS_Store diff --git a/README.md b/README.md index 33415393..1b125252 100644 --- a/README.md +++ b/README.md @@ -14,9 +14,10 @@ The WILDS package contains: 1. Data loaders that automatically handle data downloading, processing, and splitting, and 2. Dataset evaluators that standardize model evaluation for each dataset. -In addition, the example scripts contain default models, allowing new algorithms to be easily added and run on all of the WILDS datasets. +In addition, the example scripts contain default models, optimizers, schedulers, and training/evaluation code. +New algorithms can be easily added and run on all of the WILDS datasets. -For more information, please read [our paper](https://arxiv.org/abs/2012.07421) or visit [our website](https://wilds.stanford.edu). +For more information, please visit [our website](https://wilds.stanford.edu) or read the main WILDS paper ([1](https://arxiv.org/abs/2012.07421)) and its follow-up integrating unlabeled data ([2](https://arxiv.org/abs/2112.05090)). For questions and feedback, please post on the [discussion board](https://github.com/p-lambda/wilds/discussions). ## Installation @@ -29,7 +30,7 @@ pip install wilds If you have already installed it, please check that you have the latest version: ```bash python -c "import wilds; print(wilds.__version__)" -# This should print "1.2.2". If it doesn't, update by running: +# This should print "2.0.0". If it doesn't, update by running: pip install -U wilds ``` @@ -40,7 +41,12 @@ cd wilds pip install -e . ``` +In `examples/`, we provide a set of scripts that can be used to train models on the WILDS datasets. These scripts were also used to benchmark baselines in our papers [[1](https://arxiv.org/abs/2012.07421), [2](https://arxiv.org/abs/2112.05090)]. +These scripts are not part of the installed WILDS package. To use them, you should install from source, as described above. + ### Requirements +The WILDS package depends on the following requirements: + - numpy>=1.19.1 - ogb>=1.2.6 - outdated>=0.2.0 @@ -49,91 +55,42 @@ pip install -e . - pytz>=2020.4 - torch>=1.7.0 - torch-scatter>=2.0.5 -- torch-geometric>=1.6.1 +- torch-geometric>=2.0.1 - torchvision>=0.8.2 - tqdm>=4.53.0 - scikit-learn>=0.20.0 - scipy>=1.5.4 Running `pip install wilds` or `pip install -e .` will automatically check for and install all of these requirements -except for the `torch-scatter` and `torch-geometric` packages, which require a [quick manual install](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html#installation-via-binaries). -We recommend torch<1.9.0 because of data loader warnings described [here](https://github.com/pytorch/pytorch/issues/57273). +**except for the `torch-scatter` and `torch-geometric` packages**, which require a +[quick manual install](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html#installation-via-binaries). -### Default models -After installing the WILDS package, you can use the scripts in `examples/` to train default models on the WILDS datasets. -These scripts are not part of the installed WILDS package. 
To use them, you should clone the repo (assuming you did not install from source): -```bash -git clone git@github.com:p-lambda/wilds.git -``` - -To run these scripts, you will also need to install this additional dependency: +### Example script requirements +To run the example scripts, you will also need to install these additional dependencies: - transformers>=3.5.0 +- SwAV requires [Apex](https://github.com/NVIDIA/apex). + To install Apex, please follow the [README from the official SwAV repository](https://github.com/facebookresearch/swav#requirements). +- Our code supports the optional use of [Weights & Biases](https://wandb.ai/site) to track and monitor experiments. + To install the Weights and Biases Python package, run `pip install wandb`. All baseline experiments in the paper were run on Python 3.8.5 and CUDA 10.1. - -## Using the example scripts - -In the `examples/` folder, we provide a set of scripts that can be used to download WILDS datasets and train models on them. -These scripts are configured with the default models and hyperparameters that we used for all of the baselines described in our paper. All baseline results in the paper can be easily replicated with commands like: - -```bash -python examples/run_expt.py --dataset iwildcam --algorithm ERM --root_dir data -python examples/run_expt.py --dataset civilcomments --algorithm groupDRO --root_dir data -``` - -The scripts are set up to facilitate general-purpose algorithm development: new algorithms can be added to `examples/algorithms` and then run on all of the WILDS datasets using the default models. - -### Downloading and training on the WILDS datasets -The first time you run these scripts, you might need to download the datasets. You can do so with the `--download` argument, for example: -``` -python examples/run_expt.py --dataset civilcomments --algorithm groupDRO --root_dir data --download -``` - -Alternatively, you can use the standalone `wilds/download_datasets.py` script to download the datasets, for example: - -```bash -python wilds/download_datasets.py --root_dir data -``` - -This will download all datasets to the specified `data` folder. You can also use the `--datasets` argument to download particular datasets. - -These are the sizes of each of our datasets, as well as their approximate time taken to train and evaluate the default model for a single ERM run using a NVIDIA V100 GPU. - -| Dataset command | Modality | Download size (GB) | Size on disk (GB) | Train+eval time (Hours) | -|-----------------|----------|--------------------|-------------------|-------------------------| -| iwildcam | Image | 11 | 25 | 7 | -| camelyon17 | Image | 10 | 15 | 2 | -| rxrx1 | Image | 7 | 7 | 11 | -| ogb-molpcba | Graph | 0.04 | 2 | 15 | -| globalwheat | Image | 10 | 10 | 2 | -| civilcomments | Text | 0.1 | 0.3 | 4.5 | -| fmow | Image | 50 | 55 | 6 | -| poverty | Image | 12 | 14 | 5 | -| amazon | Text | 7 | 7 | 5 | -| py150 | Text | 0.1 | 0.8 | 9.5 | - -While the `camelyon17` dataset is small and fast to train on, we advise against using it as the only dataset to prototype methods on, as the test performance of models trained on this dataset tend to exhibit a large degree of variability over random seeds. - -The image datasets (`iwildcam`, `camelyon17`, `rxrx1`, `globalwheat`, `fmow`, and `poverty`) tend to have high disk I/O usage. 
If training time is much slower for you than the approximate times listed above, consider checking if I/O is a bottleneck (e.g., by moving to a local disk if you are using a network drive, or by increasing the number of data loader workers). To speed up training, you could also disable evaluation at each epoch or for all splits by toggling `--evaluate_all_splits` and related arguments.
-
-### Evaluating trained models
-We also provide an evaluation script that aggregates prediction CSV files for different replicates and reports on their combined evaluation. To use this, run:
-
-```bash
-python examples/evaluate.py <predictions_dir> <output_dir> --root-dir <root_dir>
-```
-
-where `<predictions_dir>` is the path to your predictions directory, `<output_dir>` is where the results JSON will be writte, and `<root_dir>` is the dataset root directory.
-The predictions directory should have a subdirectory for each dataset
-(e.g. `iwildcam`) containing prediction CSV files to evaluate; see our [submission guidelines](https://wilds.stanford.edu/submit/) for the format.
-The evaluation script will skip over any datasets that has missing prediction files.
-Any dataset not in `<root_dir>` will be downloaded to `<root_dir>`.
-
-### Reproducibility
-We have an [executable version](https://wilds.stanford.edu/codalab) of our paper on CodaLab that contains the exact commands, code, and data for the experiments reported in our paper, which rely on these scripts. Trained model weights for all datasets can also be found there.
-All configurations and hyperparameters can also be found in the `examples/configs` folder of this repo, and dataset-specific parameters are in `examples/configs/datasets.py`.
+## Datasets
+WILDS currently includes 10 datasets, which we've briefly listed below. For full dataset descriptions, please see our papers ([1](https://arxiv.org/abs/2012.07421), [2](https://arxiv.org/abs/2112.05090)).
+
+| Dataset       | Modality | Labeled splits                    | Unlabeled splits                                                 |
+| ------------- | -------- | --------------------------------- | --------------------------------------------------------------- |
+| iwildcam      | Image    | train, val, test, id_val, id_test | extra_unlabeled                                                  |
+| camelyon17    | Image    | train, val, test, id_val          | train_unlabeled, val_unlabeled, test_unlabeled                   |
+| rxrx1         | Image    | train, val, test, id_test         | -                                                                |
+| ogb-molpcba   | Graph    | train, val, test                  | train_unlabeled, val_unlabeled, test_unlabeled                   |
+| globalwheat   | Image    | train, val, test, id_val, id_test | train_unlabeled, val_unlabeled, test_unlabeled, extra_unlabeled  |
+| civilcomments | Text     | train, val, test                  | extra_unlabeled                                                  |
+| fmow          | Image    | train, val, test, id_val, id_test | train_unlabeled, val_unlabeled, test_unlabeled                   |
+| poverty       | Image    | train, val, test, id_val, id_test | train_unlabeled, val_unlabeled, test_unlabeled                   |
+| amazon        | Text     | train, val, test, id_val, id_test | val_unlabeled, test_unlabeled, extra_unlabeled                   |
+| py150         | Text     | train, val, test, id_val, id_test | -                                                                |
 
 ## Using the WILDS package
 ### Data
@@ -143,24 +100,39 @@ This short Python snippet covers all of the steps of getting started with a WILD
 We discuss data loading in more detail in [#Data loading](#data-loading).
```py ->>> from wilds import get_dataset ->>> from wilds.common.data_loaders import get_train_loader ->>> import torchvision.transforms as transforms +from wilds import get_dataset +from wilds.common.data_loaders import get_train_loader +import torchvision.transforms as transforms # Load the full dataset, and download it if necessary ->>> dataset = get_dataset(dataset='iwildcam', download=True) +dataset = get_dataset(dataset="iwildcam", download=True) # Get the training set ->>> train_data = dataset.get_subset('train', -... transform=transforms.Compose([transforms.Resize((448,448)), -... transforms.ToTensor()])) +train_data = dataset.get_subset( + "train", + transform=transforms.Compose( + [transforms.Resize((448, 448)), transforms.ToTensor()] + ), +) # Prepare the standard data loader ->>> train_loader = get_train_loader('standard', train_data, batch_size=16) +train_loader = get_train_loader("standard", train_data, batch_size=16) + +# (Optional) Load unlabeled data +dataset = get_dataset(dataset="iwildcam", download=True, unlabeled=True) +unlabeled_data = dataset.get_subset( + "test_unlabeled", + transform=transforms.Compose( + [transforms.Resize((448, 448)), transforms.ToTensor()] + ), +) +unlabeled_loader = get_train_loader("standard", unlabeled_data, batch_size=16) # Train loop ->>> for x, y_true, metadata in train_loader: -... ... +for labeled_batch, unlabeled_batch in zip(train_loader, unlabeled_loader): + x, y, metadata = labeled_batch + unlabeled_x, unlabeled_metadata = unlabeled_batch + ... ``` The `metadata` contains information like the domain identity, e.g., which camera a photo was taken from, or which hospital the patient's data came from, etc., as well as other metadata. @@ -172,16 +144,16 @@ They are used to initialize group-aware data loaders (as discussed in [#Data loa In the following code snippet, we initialize and use a `Grouper` that extracts the domain annotations on the iWildCam dataset, where the domain is location. ```py ->>> from wilds.common.grouper import CombinatorialGrouper +from wilds.common.grouper import CombinatorialGrouper # Initialize grouper, which extracts domain information # In this example, we form domains based on location ->>> grouper = CombinatorialGrouper(dataset, ['location']) +grouper = CombinatorialGrouper(dataset, ['location']) # Train loop ->>> for x, y_true, metadata in train_loader: -... z = grouper.metadata_to_group(metadata) -... ... +for x, y_true, metadata in train_loader: + z = grouper.metadata_to_group(metadata) + ... ``` ### Data loading @@ -189,10 +161,10 @@ In the following code snippet, we initialize and use a `Grouper` that extracts t For training, the WILDS package provides two types of data loaders. The standard data loader shuffles examples in the training set, and is used for the standard approach of empirical risk minimization (ERM), where we minimize the average loss. ```py ->>> from wilds.common.data_loaders import get_train_loader +from wilds.common.data_loaders import get_train_loader # Prepare the standard data loader ->>> train_loader = get_train_loader('standard', train_data, batch_size=16) +train_loader = get_train_loader('standard', train_data, batch_size=16) ``` To support other algorithms that rely on specific data loading schemes, we also provide the group data loader. @@ -202,24 +174,28 @@ We initialize group loaders as follows, using `Grouper` that specifies the group ```py # Prepare a group data loader that samples from user-specified groups ->>> train_loader = get_train_loader('group', train_data, -... 
grouper=grouper,
-...                                 n_groups_per_batch=2,
-...                                 batch_size=16)
+train_loader = get_train_loader(
+    "group", train_data, grouper=grouper, n_groups_per_batch=2, batch_size=16
+)
+
 ```
 
 Lastly, we also provide a data loader for evaluation, which loads examples without shuffling (unlike the training loaders).
 
 ```py
->>> from wilds.common.data_loaders import get_eval_loader
+from wilds.common.data_loaders import get_eval_loader
 
 # Get the test set
->>> test_data = dataset.get_subset('test',
-...                                transform=transforms.Compose([transforms.Resize((224,224)),
-...                                                              transforms.ToTensor()]))
+test_data = dataset.get_subset(
+    "test",
+    transform=transforms.Compose(
+        [transforms.Resize((224, 224)), transforms.ToTensor()]
+    ),
+)
 
 # Prepare the evaluation data loader
->>> test_loader = get_eval_loader('standard', test_data, batch_size=16)
+test_loader = get_eval_loader("standard", test_data, batch_size=16)
+
 ```
 
 ### Evaluators
@@ -228,36 +204,149 @@ The WILDS package standardizes and automates evaluation for each dataset.
 Invoking the `eval` method of each dataset yields all metrics reported in the paper and on the leaderboard.
 
 ```py
->>> from wilds.common.data_loaders import get_eval_loader
+from wilds.common.data_loaders import get_eval_loader
 
 # Get the test set
->>> test_data = dataset.get_subset('test',
-...                                transform=transforms.Compose([transforms.Resize((224,224)),
-...                                                              transforms.ToTensor()]))
+test_data = dataset.get_subset(
+    "test",
+    transform=transforms.Compose(
+        [transforms.Resize((224, 224)), transforms.ToTensor()]
+    ),
+)
 
 # Prepare the data loader
->>> test_loader = get_eval_loader('standard', test_data, batch_size=16)
+test_loader = get_eval_loader("standard", test_data, batch_size=16)
 
 # Get predictions for the full test set
->>> for x, y_true, metadata in test_loader:
-...     y_pred = model(x)
-...     [accumulate y_true, y_pred, metadata]
+for x, y_true, metadata in test_loader:
+    y_pred = model(x)
+    # Accumulate y_true, y_pred, metadata
 
 # Evaluate
->>> dataset.eval(all_y_pred, all_y_true, all_metadata)
-{'recall_macro_all': 0.66, ...}
+dataset.eval(all_y_pred, all_y_true, all_metadata)
+# {'recall_macro_all': 0.66, ...}
 ```
 
 Most `eval` methods take in predicted labels for `all_y_pred` by default, but the default inputs vary across datasets and are documented in the `eval` docstrings of the corresponding dataset class.
 
+## Using the example scripts
+In `examples/`, we provide a set of scripts that can be used to train models on the WILDS datasets.
+
+```bash
+python examples/run_expt.py --dataset iwildcam --algorithm ERM --root_dir data
+python examples/run_expt.py --dataset civilcomments --algorithm groupDRO --root_dir data
+python examples/run_expt.py --dataset fmow --algorithm DANN --unlabeled_split test_unlabeled --root_dir data
+```
+
+The scripts are configured to use the default models and reasonable hyperparameters. For exact hyperparameter settings used in our papers, please see [our CodaLab executable paper](https://wilds.stanford.edu/codalab).
+
+### Downloading and training on the WILDS datasets
+The first time you run these scripts, you might need to download the datasets. You can do so with the `--download` argument, for example:
+```bash
+# downloads (labeled) dataset
+python examples/run_expt.py --dataset globalwheat --algorithm groupDRO --root_dir data --download
+
+# additionally downloads all unlabeled data
+python examples/run_expt.py --dataset globalwheat --algorithm groupDRO --root_dir data --download --unlabeled_split [...]
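+
+# A hypothetical illustration (not from the original docs): any concrete split name,
+# e.g. train_unlabeled, triggers the same full unlabeled download:
+# python examples/run_expt.py --dataset globalwheat --algorithm groupDRO --root_dir data --download --unlabeled_split train_unlabeled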
+```
+Note that downloading the large amount of unlabeled data is optional; unlabeled data will only be downloaded if some `--unlabeled_split` is set. (It does not matter which `--unlabeled_split` is set; all unlabeled data will be downloaded together.)
+
+Alternatively, you can use the standalone `wilds/download_datasets.py` script to download the datasets, for example:
+
+```bash
+# downloads (labeled) data
+python wilds/download_datasets.py --root_dir data

+# downloads (unlabeled) data
+python wilds/download_datasets.py --root_dir data --unlabeled
+```
+
+This will download all datasets to the specified `data` folder. You can also use the `--datasets` argument to download particular datasets.
+
+These are the sizes of each of our datasets, as well as the approximate time taken to train and evaluate the default model for a single ERM run using an NVIDIA V100 GPU.
+
+| Dataset command | Modality | Download size (GB) | Size on disk (GB) | Train+eval time (Hours) |
+| --------------- | -------- | ------------------ | ----------------- | ----------------------- |
+| iwildcam        | Image    | 11                 | 25                | 7                       |
+| camelyon17      | Image    | 10                 | 15                | 2                       |
+| rxrx1           | Image    | 7                  | 7                 | 11                      |
+| ogb-molpcba     | Graph    | 0.04               | 2                 | 15                      |
+| globalwheat     | Image    | 10                 | 10                | 2                       |
+| civilcomments   | Text     | 0.1                | 0.3               | 4.5                     |
+| fmow            | Image    | 50                 | 55                | 6                       |
+| poverty         | Image    | 12                 | 14                | 5                       |
+| amazon          | Text     | 7                  | 7                 | 5                       |
+| py150           | Text     | 0.1                | 0.8               | 9.5                     |
+
+The following are the sizes of the unlabeled data bundles:
+
+| Dataset command | Modality | Download size (GB) | Size on disk (GB) |
+| --------------- | -------- | ------------------ | ----------------- |
+| iwildcam        | Image    | 41                 | 41                |
+| camelyon17      | Image    | 69.4               | 96                |
+| ogb-molpcba     | Graph    | 1.2                | 21                |
+| globalwheat     | Image    | 103                | 108               |
+| civilcomments   | Text     | 0.3                | 0.6               |
+| fmow\*          | Image    | 50                 | 55                |
+| poverty         | Image    | 172                | 184               |
+| amazon\*        | Text     | 7                  | 7                 |
+
+ \* These unlabeled datasets are downloaded simultaneously with the labeled data and do not need to be downloaded separately.
+
+While the `camelyon17` dataset is small and fast to train on, we advise against using it as the only dataset to prototype methods on, as the test performance of models trained on this dataset tends to exhibit a large degree of variability over random seeds.
+
+The image datasets (`iwildcam`, `camelyon17`, `rxrx1`, `globalwheat`, `fmow`, and `poverty`) tend to have high disk I/O usage. If training time is much slower for you than the approximate times listed above, consider checking if I/O is a bottleneck (e.g., by moving to a local disk if you are using a network drive, or by increasing the number of data loader workers). To speed up training, you could also disable evaluation at each epoch or for all splits by toggling `--evaluate_all_splits` and related arguments.
+
+### Algorithms
+In the `examples/algorithms` folder, we provide implementations of the adaptation algorithms benchmarked in our papers ([1](https://arxiv.org/abs/2012.07421), [2](https://arxiv.org/abs/2112.05090)).
+All algorithms train on labeled data from a WILDS dataset's `train` split.
+Some algorithms are designed to also leverage unlabeled data. To load unlabeled data, specify an `--unlabeled_split` when running.
+
+In addition to shared hyperparameters such as `lr`, `weight_decay`, `batch_size`, and `unlabeled_batch_size`, the scripts also take in command-line arguments for algorithm-specific hyperparameters.
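+For example, a deepCORAL run that also uses unlabeled data might look like the following sketch; the command is hypothetical and the penalty weight is illustrative rather than a tuned value (see the table below for each algorithm's hyperparameters):
+
+```bash
+python examples/run_expt.py --dataset camelyon17 --algorithm deepCORAL --unlabeled_split train_unlabeled --root_dir data --coral_penalty_weight 0.1
+```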
+
+| Algorithm command                                | Hyperparameters                                                                              | Notes                             | See WILDS paper                                                                  |
+| ------------------------------------------------ | -------------------------------------------------------------------------------------------- | --------------------------------- | -------------------------------------------------------------------------------- |
+| ERM                                              | -                                                                                            | Only uses labeled data            | ([1](https://arxiv.org/abs/2012.07421), [2](https://arxiv.org/abs/2112.05090))  |
+| [groupDRO](https://arxiv.org/abs/1911.08731)     | `group_dro_step_size`                                                                        | Only uses labeled data            | ([1](https://arxiv.org/abs/2012.07421))                                          |
+| [deepCORAL](https://arxiv.org/abs/1511.05547)    | `coral_penalty_weight`                                                                       | Can optionally use unlabeled data | ([1](https://arxiv.org/abs/2012.07421), [2](https://arxiv.org/abs/2112.05090))  |
+| [IRM](https://arxiv.org/abs/1907.02893)          | `irm_lambda`, `irm_penalty_anneal_iters`                                                     | Only uses labeled data            | ([1](https://arxiv.org/abs/2012.07421))                                          |
+| [DANN](https://arxiv.org/abs/1505.07818)         | `dann_penalty_weight`, `dann_classifier_lr`, `dann_featurizer_lr`, `dann_discriminator_lr`   | Can use unlabeled data            | ([2](https://arxiv.org/abs/2112.05090))                                          |
+| [AFN](https://arxiv.org/abs/1811.07456)          | `afn_penalty_weight`, `safn_delta_r`, `hafn_r`                                               | Designed to use unlabeled data    | ([2](https://arxiv.org/abs/2112.05090))                                          |
+| [FixMatch](https://arxiv.org/abs/2001.07685)     | `self_training_lambda`, `self_training_threshold`                                            | Designed to use unlabeled data    | ([2](https://arxiv.org/abs/2112.05090))                                          |
+| PseudoLabel                                      | `self_training_lambda`, `self_training_threshold`, `pseudolabel_T2`                          | Designed to use unlabeled data    | ([2](https://arxiv.org/abs/2112.05090))                                          |
+| [NoisyStudent](https://arxiv.org/abs/1911.04252) | `soft_pseudolabels`, `noisystudent_dropout_rate`                                             | Designed to use unlabeled data    | ([2](https://arxiv.org/abs/2112.05090))                                          |
+
+The repository is set up to facilitate general-purpose algorithm development: new algorithms can be added to `examples/algorithms` and then run on all of the WILDS datasets using the default models.
+
+### Evaluating trained models
+We also provide an evaluation script that aggregates prediction CSV files for different replicates and reports on their combined evaluation. To use this, run:
+
+```bash
+python examples/evaluate.py <predictions_dir> <output_dir> --root-dir <root_dir>
+```
+
+where `<predictions_dir>` is the path to your predictions directory, `<output_dir>` is where the results JSON will be written, and `<root_dir>` is the dataset root directory.
+The predictions directory should have a subdirectory for each dataset
+(e.g. `iwildcam`) containing prediction CSV files to evaluate; see our [submission guidelines](https://wilds.stanford.edu/submit/) for the format.
+The evaluation script will skip over any datasets that have missing prediction files.
+Any dataset not in `<root_dir>` will be downloaded to `<root_dir>`.
+
+### Reproducibility
+We have an [executable version](https://wilds.stanford.edu/codalab) of our paper on CodaLab that contains the exact commands, code, and data for the experiments reported in our paper, which rely on these scripts. Trained model weights for all datasets can also be found there.
+All configurations and hyperparameters can also be found in the `examples/configs` folder of this repo, and dataset-specific parameters are in `examples/configs/datasets.py`.
+
 ## Leaderboard
 If you are developing new training algorithms and/or models on WILDS, please consider submitting them to our [public leaderboard](https://wilds.stanford.edu/leaderboard/).
-## Citing WILDS -If you use WILDS datasets in your work, please cite [our paper](https://arxiv.org/abs/2012.07421) ([Bibtex](https://wilds.stanford.edu/assets/files/wilds_bib.txt)): +## Citing WILDS ([Bibtex](https://wilds.stanford.edu/assets/files/wilds_bib.txt)) +If you use WILDS datasets in your work, please cite our paper: + +1. [**WILDS: A Benchmark of in-the-Wild Distribution Shifts.**](https://arxiv.org/abs/2012.07421) Pang Wei Koh*, Shiori Sagawa*, Henrik Marklund, Sang Michael Xie, Marvin Zhang, Akshay Balsubramani, Weihua Hu, Michihiro Yasunaga, Richard Lanas Phillips, Irena Gao, Tony Lee, Etienne David, Ian Stavness, Wei Guo, Berton A. Earnshaw, Imran S. Haque, Sara Beery, Jure Leskovec, Anshul Kundaje, Emma Pierson, Sergey Levine, Chelsea Finn, and Percy Liang. ICML 2021. + +If you use unlabeled data from the WILDS datasets, please also cite: -- **WILDS: A Benchmark of in-the-Wild Distribution Shifts.** Pang Wei Koh*, Shiori Sagawa*, Henrik Marklund, Sang Michael Xie, Marvin Zhang, Akshay Balsubramani, Weihua Hu, Michihiro Yasunaga, Richard Lanas Phillips, Irena Gao, Tony Lee, Etienne David, Ian Stavness, Wei Guo, Berton A. Earnshaw, Imran S. Haque, Sara Beery, Jure Leskovec, Anshul Kundaje, Emma Pierson, Sergey Levine, Chelsea Finn, and Percy Liang. ICML 2021. +2. [**Extending the WILDS Benchmark for Unsupervised Adaptation.**](https://arxiv.org/abs/2112.05090) Shiori Sagawa*, Pang Wei Koh*, Tony Lee*, Irena Gao*, Sang Michael Xie, Kendrick Shen, Ananya Kumar, Weihua Hu, Michihiro Yasunaga, Henrik Marklund, Sara Beery, Etienne David, Ian Stavness, Wei Guo, Jure Leskovec, Kate Saenko, Tatsunori Hashimoto, Sergey Levine, Chelsea Finn, and Percy Liang. NeurIPS 2021 Workshop on Distribution Shifts. -Please also cite the original papers that introduce the datasets, as listed on the [datasets page](https://wilds.stanford.edu/datasets/). +In addition, please cite the original papers that introduced the datasets, as listed on the [datasets page](https://wilds.stanford.edu/datasets/). ## Acknowledgements The design of the WILDS benchmark was inspired by the [Open Graph Benchmark](https://ogb.stanford.edu/), and we are grateful to the Open Graph Benchmark team for their advice and help in setting up WILDS. diff --git a/dataset_preprocessing/amazon_yelp/create_unlabeled_amazon.py b/dataset_preprocessing/amazon_yelp/create_unlabeled_amazon.py new file mode 100644 index 00000000..7d97f19c --- /dev/null +++ b/dataset_preprocessing/amazon_yelp/create_unlabeled_amazon.py @@ -0,0 +1,168 @@ +import argparse +import csv +import os + +import numpy as np +import pandas as pd + +# Fix the seed for reproducibility +np.random.seed(0) + +""" +Create unlabeled splits for Amazon. 
+ +Usage: + python dataset_preprocessing/amazon_yelp/create_unlabeled_amazon.py +""" + +NOT_IN_DATASET = -1 + +# Splits +# 'train': 0, 'val': 1, 'id_val': 2, 'test': 3, 'id_test': 4, +# 'val_unlabeled': 11, 'test_unlabeled': 12, 'extra_unlabeled': 13 +( + TRAIN, + OOD_VAL, + ID_VAL, + OOD_TEST, + ID_TEST, +) = range(5) +VAL_UNLABELED, TEST_UNLABELED, EXTRA_UNLABELED = range(11, 14) + + +def main(dataset_path): + def output_split_sizes(): + print("-" * 50) + print(f'Train size: {len(split_df[split_df["split"] == TRAIN])}') + print(f'Val size: {len(split_df[split_df["split"] == OOD_VAL])}') + print(f'ID Val size: {len(split_df[split_df["split"] == ID_VAL])}') + print(f'Test size: {len(split_df[split_df["split"] == OOD_TEST])}') + print(f'ID Test size: {len(split_df[split_df["split"] == ID_TEST])}') + print( + f'OOD Val Unlabeled size: {len(split_df[split_df["split"] == VAL_UNLABELED])}' + ) + print( + f'OOD Test Unlabeled size: {len(split_df[split_df["split"] == TEST_UNLABELED])}' + ) + print( + f'Extra Unlabeled size: {len(split_df[split_df["split"] == EXTRA_UNLABELED])}' + ) + print( + f'Number of examples not included: {len(split_df[split_df["split"] == NOT_IN_DATASET])}' + ) + print(f'Number of unclean reviews: {len(split_df[~split_df["clean"]])}') + print("-" * 50) + print("\n") + + def set_unlabeled_split(split, reviewers): + # Get unused reviews written by users from `reviewers` + split_df.loc[ + (split_df["split"] == NOT_IN_DATASET) + & split_df["clean"] + & data_df["reviewerID"].isin(reviewers), + "split", + ] = split + + def validate_split(split, expected_reviewers_count): + # Sanity check: + # Ensure the number of reviewers equals the number of reviewers in its unlabeled counterpart + # and each reviewer has at least 75 reviews. + actual_reviewers_counts = ( + data_df[(split_df["split"] == split)]["reviewerID"].unique().size + ) + assert ( + actual_reviewers_counts == expected_reviewers_count + ), "The number of reviewers ({}) did not equal {}".format( + actual_reviewers_counts, expected_reviewers_count + ) + min_reviewers_count = ( + data_df[(split_df["split"] == split)]["reviewerID"].value_counts().min() + ) + assert ( + min_reviewers_count >= 75 + ), "Each reviewer should have at least 75 reviews, but got a minimum of {} reviews.".format( + min_reviewers_count + ) + + data_df = pd.read_csv( + os.path.join(dataset_path, "reviews.csv"), + dtype={ + "reviewerID": str, + "asin": str, + "reviewTime": str, + "unixReviewTime": int, + "reviewText": str, + "summary": str, + "verified": bool, + "category": str, + "reviewYear": int, + }, + keep_default_na=False, + na_values=[], + quoting=csv.QUOTE_NONNUMERIC, + ) + user_csv_path = os.path.join(dataset_path, "splits", "user.csv") + split_df = pd.read_csv(user_csv_path) + assert split_df.shape[0] == data_df.shape[0] + output_split_sizes() + + ood_val_reviewers_ids = data_df[ + split_df["split"] == OOD_VAL + ].reviewerID.unique() # 1334 users + set_unlabeled_split(VAL_UNLABELED, ood_val_reviewers_ids) + + ood_test_reviewers_ids = data_df[ + split_df["split"] == OOD_TEST + ].reviewerID.unique() # 1334 users + set_unlabeled_split(TEST_UNLABELED, ood_test_reviewers_ids) + + # For EXTRA_UNLABELED, use any users not in any of the other splits + existing_reviewer_ids = np.concatenate( + [ + ood_test_reviewers_ids, + ood_val_reviewers_ids, + data_df[split_df["split"] == TRAIN].reviewerID.unique(), + data_df[split_df["split"] == ID_VAL].reviewerID.unique(), + data_df[split_df["split"] == ID_TEST].reviewerID.unique(), + ] + ) + # There are 151,736 
extra reviewers + extra_reviewers_ids = data_df[ + ~data_df.reviewerID.isin(existing_reviewer_ids) + ].reviewerID.unique() + set_unlabeled_split(EXTRA_UNLABELED, extra_reviewers_ids) + + # Exclude reviewers with less than 75 reviews. + review_counts = data_df[(split_df["split"] == EXTRA_UNLABELED)][ + "reviewerID" + ].value_counts() + reviewers_to_filter_out = review_counts[review_counts < 75].keys() + split_df.loc[ + (split_df["split"] == EXTRA_UNLABELED) + & data_df["reviewerID"].isin(reviewers_to_filter_out), + "split", + ] = NOT_IN_DATASET + + # We are done splitting, output stats. + output_split_sizes() + + # Sanity checks + validate_split(VAL_UNLABELED, ood_val_reviewers_ids.size) + validate_split(TEST_UNLABELED, ood_test_reviewers_ids.size) + # After filtering out unclean reviews and ensuring >= 75 reviews per reviewer, we are left with 21,694 reviewers. + validate_split(EXTRA_UNLABELED, 21694) + + # Write out the new unlabeled split to user.csv + split_df.to_csv(user_csv_path, index=False) + print("Done.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Create unlabeled splits for Amazon.") + parser.add_argument( + "path", + type=str, + help="Path to the Amazon dataset", + ) + args = parser.parse_args() + main(args.path) diff --git a/dataset_preprocessing/camelyon17/unlabeled/README.md b/dataset_preprocessing/camelyon17/unlabeled/README.md new file mode 100644 index 00000000..ebfd19f3 --- /dev/null +++ b/dataset_preprocessing/camelyon17/unlabeled/README.md @@ -0,0 +1,24 @@ +## Unlabeled Camelyon17-WILDS patch processing + +#### Requirements + +- openslide-python>=1.1.2 +- opencv-python>=4.4.0 + +openslide-python relies on first installing OpenSlide; +see [installation instructions](https://github.com/openslide/openslide-python). + +#### Instructions + +1. Download the [CAMELYON17 training data](https://drive.google.com/drive/folders/0BzsdkU4jWx9BSEI2X1VOLUpYZ3c?resourcekey=0-41XIPJNyEAo598wHxVAP9w) + into `SLIDE_ROOT`. + +2. Run `python generate_all_patch_coords.py --slide_root SLIDE_ROOT --output_root OUTPUT_ROOT` to generate a .csv of all + potential patches as well as the tissue masks for each WSI. `OUTPUT_ROOT` is wherever you would like the + patches to eventually be written. + +3. Then run `python generate_final_metadata.py --slide_root SLIDE_ROOT --output_root OUTPUT_ROOT` + to generate the metadata.csv file for unlabeled Camelyon. + +4. Finally, run `python extract_final_patches_to_disk.py --slide_root SLIDE_ROOT --output_root OUTPUT_ROOT` to + extract the chosen patches from the WSIs and write them to disk. 
\ No newline at end of file diff --git a/dataset_preprocessing/camelyon17/unlabeled/extract_final_patches_to_disk.py b/dataset_preprocessing/camelyon17/unlabeled/extract_final_patches_to_disk.py new file mode 100644 index 00000000..95d50f09 --- /dev/null +++ b/dataset_preprocessing/camelyon17/unlabeled/extract_final_patches_to_disk.py @@ -0,0 +1,79 @@ +import argparse +import os +import pdb +from tqdm import tqdm + +import openslide +import pandas as pd + +from generate_all_patch_coords import PATCH_LEVEL, CENTER_SIZE + + +def write_patch_images_from_df(slide_root, output_root): + print("Writing patch images to disk...") + read_df = pd.read_csv( + os.path.join(output_root, "metadata.csv"), index_col=0, dtype={"patient": "str"} + ) + + patch_level = PATCH_LEVEL + center_size = CENTER_SIZE + patch_size = center_size * 3 + + for idx in tqdm(read_df.index): + orig_x = read_df.loc[idx, "x_coord"] + orig_y = read_df.loc[idx, "y_coord"] + center = read_df.loc[idx, "center"] + patient = read_df.loc[idx, "patient"] + node = read_df.loc[idx, "node"] + + patch_folder = os.path.join( + output_root, "patches", f"patient_{patient}_node_{node}" + ) + patch_path = os.path.join( + patch_folder, + f"patch_patient_{patient}_node_{node}_x_{orig_x}_y_{orig_y}.png", + ) + + os.makedirs(patch_folder, exist_ok=True) + if os.path.isfile(patch_path): + continue + + slide_path = os.path.join( + slide_root, + f"center_{center}", + f"patient_{patient}", + f"patient_{patient}_node_{node}.tif", + ) + slide = openslide.OpenSlide(slide_path) + + # Coords are at patch_level + # First shift coords to top left corner of the entire patch + x = orig_x - center_size + y = orig_y - center_size + # Then match to level 0 coords so we can use read_region + x = int( + round( + x + * slide.level_dimensions[0][0] + / slide.level_dimensions[patch_level][0] + ) + ) + y = int( + round( + y + * slide.level_dimensions[0][1] + / slide.level_dimensions[patch_level][1] + ) + ) + + patch = slide.read_region((x, y), 2, (patch_size, patch_size)) + patch.save(patch_path) + print("Done.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--slide_root", required=True) + parser.add_argument("--output_root", required=True) + args = parser.parse_args() + write_patch_images_from_df(slide_root=args.slide_root, output_root=args.output_root) diff --git a/dataset_preprocessing/camelyon17/unlabeled/generate_all_patch_coords.py b/dataset_preprocessing/camelyon17/unlabeled/generate_all_patch_coords.py new file mode 100644 index 00000000..3e32c5f4 --- /dev/null +++ b/dataset_preprocessing/camelyon17/unlabeled/generate_all_patch_coords.py @@ -0,0 +1,200 @@ +# Code adapted from https://github.com/liucong3/camelyon17 +# and https://github.com/cv-lee/Camelyon17 + +import argparse +import os +import pdb + +import openslide +import cv2 +import numpy as np +import pandas as pd + +CENTER_SIZE = 32 +MASK_LEVEL = 4 +PATCH_LEVEL = 2 + +NUM_OF_HOSPITALS = 5 + + +def _make_masks(slide_path, mask_level, **args): + """ + Return a slide with annotated tumor, normal, and tissue masks using an Otsu threshold + """ + print("_make_masks(%s)" % slide_path) + + # Load slide + slide = openslide.OpenSlide(slide_path) + slide_map = np.array(slide.get_thumbnail(slide.level_dimensions[mask_level])) + + # draw tissue mask + slide_lv = slide.read_region((0, 0), mask_level, slide.level_dimensions[mask_level]) + slide_lv = cv2.cvtColor(np.array(slide_lv), cv2.COLOR_RGBA2RGB) + slide_lv = cv2.cvtColor(slide_lv, cv2.COLOR_BGR2HSV) + slide_lv = 
slide_lv[:, :, 1] + _, tissue_mask = cv2.threshold( + slide_lv, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU + ) + return slide, slide_map, tissue_mask + + +def _write_masks(mask_folder_path, slide_map, tissue_mask, **args): + """ + Write masks out to disk; used for sanity checking and visualization. + """ + print("_write_masks") + os.makedirs(mask_folder_path, exist_ok=True) + map_path = os.path.join(mask_folder_path, "map.png") + cv2.imwrite(map_path, slide_map) + tissue_mask_path = os.path.join(mask_folder_path, "tissue_mask.png") + cv2.imwrite(tissue_mask_path, np.array(tissue_mask)) + + +def _record_patches( + center_size, + slide, + slide_map, + patch_level, + mask_level, + tissue_mask, + normal_threshold, + **args, +): + """ + Extract all patches using the tissue masks. + """ + width, height = np.array(slide.level_dimensions[patch_level]) // center_size + print("_record_patches(w=%d,h=%d)" % (width, height)) + margin = 5 + mask_max = 255 + assert mask_level >= patch_level + width_mask_step = ( + center_size + * slide.level_dimensions[mask_level][0] + / slide.level_dimensions[patch_level][0] + ) + height_mask_step = ( + center_size + * slide.level_dimensions[mask_level][1] + / slide.level_dimensions[patch_level][1] + ) + + patch_list = [] + + # These mark the coordinates of the central region of the patch + for i in range(margin, width - margin): + for j in range(margin, height - margin): + # We no longer have access to the tumor and normal masks. Just use the tissue mask + mask_i_start = round(width_mask_step * i) + mask_i_end = round(width_mask_step * (i + 1)) + mask_j_start = round(height_mask_step * j) + mask_j_end = round(height_mask_step * (j + 1)) + + # Compute mask only over central region + tissue_mask_avg = tissue_mask[ + mask_j_start:mask_j_end, mask_i_start:mask_i_end + ].mean() + tissue_area_ratio = tissue_mask_avg / mask_max + + # Tissue is the union of normal and tumor, so check the tissue area ratio is above the normal threshold + if tissue_area_ratio > normal_threshold: + # Set the label to be -1 to indicate it's unlabeled data + patch_list.append((center_size * i, center_size * j, -1)) + cv2.rectangle( + slide_map, + (mask_i_start, mask_j_start), + (mask_i_end, mask_j_end), + (100, 149, 237), # cornflower blue for debugging + thickness=1, + ) + + print(f"Added {len(patch_list)} patches...") + df = pd.DataFrame(patch_list, columns=["x_coord", "y_coord", "tumor"]) + return df + + +def generate_file(patient, node, slide_path, folder_path): + args = { + "slide_path": slide_path, + "patch_level": PATCH_LEVEL, + "mask_level": MASK_LEVEL, + "center_size": CENTER_SIZE, + "mask_folder_path": folder_path, + "normal_threshold": 0.2, + } + args["slide"], args["slide_map"], args["tissue_mask"] = _make_masks(**args) + df = _record_patches(**args) + df["patient"] = patient + df["node"] = node + _write_masks(**args) + return df + + +def generate_files(slide_root, output_root, center): + aggregate_df = pd.DataFrame( + columns=["patient", "node", "x_coord", "y_coord", "tumor"] + ) + + # Assume files are organized in the following way: + # center_/patient_/patient__node_.tif + if center is None: + print( + "A value for --center was not specified. Generating patches for all centers..." 
+ ) + centers = range(NUM_OF_HOSPITALS) + else: + centers = [center] + + for center in centers: + print(f"Generating patches for center {center}...") + center_dir = os.path.join(slide_root, f"center_{center}") + patient_dirs = os.listdir(center_dir) + + for patient_dir in patient_dirs: + patient_dir = os.path.join(center_dir, patient_dir) + if not os.path.isdir(patient_dir): + continue + + for slide_file in os.listdir(patient_dir): + if not slide_file.endswith(".tif"): + continue + + slide_path = os.path.join(patient_dir, slide_file) + prefix = slide_file.split(".tif")[0] + try: + assert len(prefix.split("_")) == 4 + + # The XML files have labels so it's not needed for Unlabeled Camelyon + df = generate_file( + patient=prefix.split("_")[1], + node=prefix.split("_")[3], + slide_path=slide_path, + folder_path=os.path.join(output_root, "masks", prefix), + ) + aggregate_df = pd.concat([aggregate_df, df]) + except openslide.OpenSlideError as err: + print(err) + continue + + # Coordinates of all potential patches + aggregate_df = aggregate_df.reset_index(drop=True) + aggregate_df.to_csv(os.path.join(output_root, "all_unlabeled_patch_coords.csv")) + return aggregate_df + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--slide_root", required=True) + parser.add_argument("--output_root", required=True) + parser.add_argument( + "--center", + type=int, + help="Which specific center to extract patches for. If a center is not specified, " + "patches will be extracted for all five centers.", + ) + args = parser.parse_args() + + generate_files( + slide_root=args.slide_root, output_root=args.output_root, center=args.center + ) diff --git a/dataset_preprocessing/camelyon17/unlabeled/generate_final_metadata.py b/dataset_preprocessing/camelyon17/unlabeled/generate_final_metadata.py new file mode 100644 index 00000000..497ee24f --- /dev/null +++ b/dataset_preprocessing/camelyon17/unlabeled/generate_final_metadata.py @@ -0,0 +1,113 @@ +import argparse +import os +import pdb + +import numpy as np +import pandas as pd +from matplotlib import pyplot as plt + +# Fix seed for reproducibility +np.random.seed(0) + +_NUM_CENTERS = 5 +_NUM_PATCHES_TO_SUBSAMPLE = 6667 +_NUM_PATIENTS_PER_HOSPITAL = 20 + + +def generate_final_metadata(slide_root, output_root): + def print_stats(patches_df): + print(f"\nStatistics:\nTotal # of patches: {patches_df.shape[0]}") + for center in range(_NUM_CENTERS): + print( + f"Center {center}: {np.sum(patches_df['center'] == center):6d} patches" + ) + print() + + patches_path = os.path.join(output_root, "all_unlabeled_patch_coords.csv") + print(f"Importing patches from {patches_path}...") + df = pd.read_csv( + patches_path, + index_col=0, + dtype={"patient": "str", "tumor": "int"}, + ) + + # Assign slide numbers to patients + nodes + patient_node_list = list( + set(df[["patient", "node"]].itertuples(index=False, name=None)) + ) + patient_node_list.sort() + patient_node_to_slide_map = {} + for idx, (patient, node) in enumerate(patient_node_list): + patient_node_to_slide_map[(patient, node)] = idx + + for (patient, node), slide_idx in patient_node_to_slide_map.items(): + mask = (df["patient"] == patient) & (df["node"] == node) + df.loc[mask, "slide"] = slide_idx + df["slide"] = df["slide"].astype("int") + + # The raw data has the following assignments: + # Center 0: patients 0 to 19 + # Center 1: patients 20 to 39 + # Center 2: patients 40 to 59 + # Center 3: patients 60 to 79 + # Center 4: patients 80 to 99 + df["center"] = 
df["patient"].astype("int") // _NUM_PATIENTS_PER_HOSPITAL + print_stats(df) + for center, slide in set( + df[["center", "slide"]].itertuples(index=False, name=None) + ): + assert center == slide // 100, "Expected 100 slides per center." + + # Remove patches from the original metadata.csv before subsampling. + # There are 50 XML files in the lesion_annotation folder, so 50 patient-node pairs were + # already used in the original WILDS Camelyon dataset. + print( + "Removing patches from slides that were used in the original Camelyon-WILDS dataset..." + ) + for file in os.listdir(os.path.join(slide_root, "lesion_annotations")): + if file.endswith(".xml") and not file.startswith("._"): + prefix = file.split(".xml")[0] + patient = prefix.split("_")[1] + node = prefix.split("_")[3] + + patient_mask = df["patient"] == patient + node_mask = df["node"] == int(node) + df = df[~(patient_mask & node_mask)] + print_stats(df) + + # The labeled Camelyon-WILDS dataset has approximately 300,000 patches. We want about 10x unlabeled data, + # which corresponds to ~3 million patches. Since each hospital of the original Camelyon17 training set + # has a 100 slides, we subsample 6,667 patches from each slide, resulting in 600,030 patches total from each + # hospital except for Center 0. Slide 38 of Center 0 only has 5,824 patches, so we instead subsample a total of + # 599,187 patches for Center 0. Therefore, there is a total of 2,999,307 unlabeled patches across the hospitals. + print(f"Subsampling {_NUM_PATCHES_TO_SUBSAMPLE} patches from each slide...") + indices_to_keep = [] + for slide in set(df["slide"]): + slide_mask = df["slide"] == slide + slide_indices = list(df.index[slide_mask]) + print( + f"slide={slide}, choosing {_NUM_PATCHES_TO_SUBSAMPLE} patches from {len(slide_indices)} patches" + ) + if _NUM_PATCHES_TO_SUBSAMPLE < len(slide_indices): + indices_to_keep += list( + np.random.choice( + slide_indices, size=_NUM_PATCHES_TO_SUBSAMPLE, replace=False + ) + ) + else: + print("Adding all slides...") + indices_to_keep += slide_indices + df_to_keep = df.loc[indices_to_keep, :].copy().reset_index(drop=True) + + print_stats(df_to_keep) + df_to_keep.to_csv(os.path.join(output_root, "metadata.csv")) + print("Done.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--slide_root", required=True) + parser.add_argument("--output_root", required=True) + args = parser.parse_args() + + generate_final_metadata(slide_root=args.slide_root, output_root=args.output_root) diff --git a/dataset_preprocessing/camelyon17/unlabeled/validate.py b/dataset_preprocessing/camelyon17/unlabeled/validate.py new file mode 100644 index 00000000..2fcebf69 --- /dev/null +++ b/dataset_preprocessing/camelyon17/unlabeled/validate.py @@ -0,0 +1,79 @@ +import argparse +import os +import pdb + +""" +Validate the content of the unlabeled Camelyon17 dataset after +preprocessing and uploading to CodaLab. 
+
+Statistics:
+Total # of patches: 2,999,307
+Center 0: 599,187 patches
+Center 1: 600,030 patches
+Center 2: 600,030 patches
+Center 3: 600,030 patches
+Center 4: 600,030 patches
+
+Usage:
+
+    python dataset_preprocessing/camelyon17/unlabeled/validate.py
+"""
+
+_EXPECTED_SLIDES_COUNT = 450
+
+
+def validate_unlabeled_dataset(root_dir: str):
+    def get_patients_center(patient_id: str):
+        patient_no = int(patient_id)
+        if 0 <= patient_no < 20:
+            return 0
+        elif 20 <= patient_no < 40:
+            return 1
+        elif 40 <= patient_no < 60:
+            return 2
+        elif 60 <= patient_no < 80:
+            return 3
+        elif 80 <= patient_no < 100:
+            return 4
+        else:
+            raise ValueError(f"Can't get center for patient {patient_id}.")
+
+    dataset_dir = os.path.join(root_dir, "camelyon17_unlabeled_v1.0")
+    content = os.listdir(dataset_dir)
+    assert "patches" in content
+    assert "RELEASE_v1.0.txt" in content
+    assert "metadata.csv" in content
+
+    slides_dir = os.path.join(dataset_dir, "patches")
+    slides = os.listdir(slides_dir)
+
+    slide_count = 0
+    patch_counts = [0 for _ in range(5)]
+    for slide in slides:
+        patches_dir = os.path.join(slides_dir, slide)
+        if not os.path.isdir(patches_dir):
+            continue
+        slide_count += 1
+
+        slide_split = slide.split("_")
+        assert len(slide_split) == 4
+        patient_id = slide_split[1]
+        center = get_patients_center(patient_id)
+        for patch in os.listdir(patches_dir):
+            if patch.endswith(".png"):
+                patch_counts[center] += 1
+
+    assert (
+        slide_count == _EXPECTED_SLIDES_COUNT
+    ), f"Got incorrect number of slides. Expected: {_EXPECTED_SLIDES_COUNT}, Actual: {slide_count}"
+    print(f"Patch counts: {patch_counts}")
+    assert patch_counts == [599187, 600030, 600030, 600030, 600030]
+    assert sum(patch_counts) == 2999307
+    print("\nVerified.")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("root_dir", help="Path to the datasets directory.")
+    args = parser.parse_args()
+    validate_unlabeled_dataset(args.root_dir)
diff --git a/dataset_preprocessing/civilcomments/README.md b/dataset_preprocessing/civilcomments/README.md
index dbd96b83..63f63b21 100644
--- a/dataset_preprocessing/civilcomments/README.md
+++ b/dataset_preprocessing/civilcomments/README.md
@@ -4,4 +4,6 @@
 
 1. Download `all_data.csv` from https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification/data.
 
-2. Run `python augment_identities_and_split.py --root ROOT`, where `ROOT` is where you downloaded `ROOT`. This will create `all_data_with_identities.csv` in the same folder, which is what we use in WILDS.
+2. Run `python process_labeled.py --root ROOT`, where `ROOT` is the directory where you downloaded `all_data.csv`. This will create `all_data_with_identities.csv` in the same folder, which is the labeled data that we use in WILDS.
+
+3. After the above step, run `python process_unlabeled.py --root ROOT`, where `ROOT` is the directory where you downloaded `all_data.csv`. This will create `unlabeled_data_with_identities.csv` in the same folder, which is the unlabeled data that we optionally use in WILDS.
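+
+For example, if `all_data.csv` was downloaded to a hypothetical directory `./civilcomments`, the two processing steps would be run as:
+
+```bash
+python process_labeled.py --root ./civilcomments
+python process_unlabeled.py --root ./civilcomments
+```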
diff --git a/dataset_preprocessing/civilcomments/augment_identities_and_split.py b/dataset_preprocessing/civilcomments/process_labeled.py similarity index 97% rename from dataset_preprocessing/civilcomments/augment_identities_and_split.py rename to dataset_preprocessing/civilcomments/process_labeled.py index 89101442..76638f25 100644 --- a/dataset_preprocessing/civilcomments/augment_identities_and_split.py +++ b/dataset_preprocessing/civilcomments/process_labeled.py @@ -1,147 +1,147 @@ -import pandas as pd -from matplotlib import pyplot as plt -import os,sys -import numpy as np -from tqdm import tqdm -import argparse - -from attr_definitions import GROUP_ATTRS, AGGREGATE_ATTRS, ORIG_ATTRS - -def load_df(root): - """ - Loads the data and removes all examples where we don't have identity annotations. - """ - df = pd.read_csv(os.path.join(root, 'all_data.csv')) - df = df.loc[(df['identity_annotator_count'] > 0), :] - df = df.reset_index(drop=True) - return df - -def augment_df(df): - """ - Augment the dataframe with auxiliary attributes. - First, we create aggregate attributes, like `LGBTQ` or `other_religions`. - These are aggregated because there would otherwise not be enough examples to accurately - estimate their accuracy. - - Next, for each category of demographics (e.g., race, gender), we construct an auxiliary - attribute (e.g., `na_race`, `na_gender`) that is 1 if the comment has no identities related to - that demographic, and is 0 otherwise. - Note that we can't just create a single multi-valued attribute like `gender` because there's - substantial overlap: for example, 4.6% of comments mention both male and female identities. - """ - df = df.copy() - for aggregate_attr in AGGREGATE_ATTRS: - aggregate_mask = pd.Series([False] * len(df)) - for attr in AGGREGATE_ATTRS[aggregate_attr]: - attr_mask = (df[attr] >= 0.5) - aggregate_mask = aggregate_mask | attr_mask - df[aggregate_attr] = 0 - df.loc[aggregate_mask, aggregate_attr] = 1 - - attr_count = np.zeros(len(df)) - for attr in ORIG_ATTRS: - attr_mask = (df[attr] >= 0.5) - attr_count += attr_mask - df['num_identities'] = attr_count - df['more_than_one_identity'] = (attr_count > 1) - - for group in GROUP_ATTRS: - print(f'## {group}') - counts = {} - na_mask = np.ones(len(df)) - for attr in GROUP_ATTRS[group]: - attr_mask = (df[attr] >= 0.5) - na_mask = na_mask & ~attr_mask - counts[attr] = np.mean(attr_mask) - counts['n/a'] = np.mean(na_mask) - - col_name = f'na_{group}' - df[col_name] = 0 - df.loc[na_mask, col_name] = 1 - - for k, v in counts.items(): - print(f'{k:40s}: {v:.4f}') - print() - return df - -def construct_splits(df): - """ - Construct splits. - The original data already has a train vs. test split. - We triple the size of the test set so that we can better estimate accuracy on the small groups, - and construct a validation set by randomly sampling articles. 
- """ - - df = df.copy() - train_df = df.loc[df['split'] == 'train'] - test_df = df.loc[df['split'] == 'test'] - train_articles = set(train_df['article_id'].values) - test_articles = set(test_df['article_id'].values) - # Assert no overlap between train and test articles - assert len(train_articles.intersection(test_articles)) == 0 - - n_train = len(train_df) - n_test = len(test_df) - n_train_articles = len(train_articles) - n_test_articles = len(test_articles) - - ## Set params - n_val_articles = n_test_articles - n_new_test_articles = 2 * n_test_articles - - np.random.seed(0) - - # Sample val articles - val_articles = np.random.choice( - list(train_articles), - size=n_val_articles, - replace=False) - df.loc[df['article_id'].isin(val_articles), 'split'] = 'val' - - # Sample new test articles - train_articles = train_articles - set(val_articles) - new_test_articles = np.random.choice( - list(train_articles), - size=n_new_test_articles, - replace=False) - df.loc[df['article_id'].isin(new_test_articles), 'split'] = 'test' - - train_df = df.loc[df['split'] == 'train'] - val_df = df.loc[df['split'] == 'val'] - test_df = df.loc[df['split'] == 'test'] - - train_articles = set(train_df['article_id'].values) - val_articles = set(val_df['article_id'].values) - test_articles = set(test_df['article_id'].values) - - # Sanity checks - assert len(df) == len(train_df) + len(val_df) + len(test_df) - assert n_train == len(train_df) + len(val_df) + np.sum(df['article_id'].isin(new_test_articles)) - assert n_test == len(test_df) - np.sum(df['article_id'].isin(new_test_articles)) - assert n_train_articles == len(train_articles) + len(val_articles) + len(new_test_articles) - assert n_val_articles == len(val_articles) - assert n_test_articles == len(test_articles) - n_new_test_articles - assert len(train_articles.intersection(val_articles)) == 0 - assert len(train_articles.intersection(test_articles)) == 0 - assert len(val_articles.intersection(test_articles)) == 0 - - print('% of examples') - for split in ['train', 'val', 'test']: - print(split, np.mean(df['split'] == split), np.sum(df['split'] == split)) - print('') - - print('class balance') - for split in ['train', 'val', 'test']: - split_df = df.loc[df['split'] == split] - print('pos', np.mean(split_df['toxicity'] > 0.5)) - return df - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--root', required=True) - args = parser.parse_args() - - df = load_df(args.root) - df = augment_df(df) - df = construct_splits(df) - df.to_csv(os.path.join(args.root, f'all_data_with_identities.csv')) +import pandas as pd +from matplotlib import pyplot as plt +import os,sys +import numpy as np +from tqdm import tqdm +import argparse + +from attr_definitions import GROUP_ATTRS, AGGREGATE_ATTRS, ORIG_ATTRS + +def load_df(root): + """ + Loads the data and removes all examples where we don't have identity annotations. + """ + df = pd.read_csv(os.path.join(root, 'all_data.csv')) + df = df.loc[(df['identity_annotator_count'] > 0), :] + df = df.reset_index(drop=True) + return df + +def augment_df(df): + """ + Augment the dataframe with auxiliary attributes. + First, we create aggregate attributes, like `LGBTQ` or `other_religions`. + These are aggregated because there would otherwise not be enough examples to accurately + estimate their accuracy. 
+ + Next, for each category of demographics (e.g., race, gender), we construct an auxiliary + attribute (e.g., `na_race`, `na_gender`) that is 1 if the comment has no identities related to + that demographic, and is 0 otherwise. + Note that we can't just create a single multi-valued attribute like `gender` because there's + substantial overlap: for example, 4.6% of comments mention both male and female identities. + """ + df = df.copy() + for aggregate_attr in AGGREGATE_ATTRS: + aggregate_mask = pd.Series([False] * len(df)) + for attr in AGGREGATE_ATTRS[aggregate_attr]: + attr_mask = (df[attr] >= 0.5) + aggregate_mask = aggregate_mask | attr_mask + df[aggregate_attr] = 0 + df.loc[aggregate_mask, aggregate_attr] = 1 + + attr_count = np.zeros(len(df)) + for attr in ORIG_ATTRS: + attr_mask = (df[attr] >= 0.5) + attr_count += attr_mask + df['num_identities'] = attr_count + df['more_than_one_identity'] = (attr_count > 1) + + for group in GROUP_ATTRS: + print(f'## {group}') + counts = {} + na_mask = np.ones(len(df)) + for attr in GROUP_ATTRS[group]: + attr_mask = (df[attr] >= 0.5) + na_mask = na_mask & ~attr_mask + counts[attr] = np.mean(attr_mask) + counts['n/a'] = np.mean(na_mask) + + col_name = f'na_{group}' + df[col_name] = 0 + df.loc[na_mask, col_name] = 1 + + for k, v in counts.items(): + print(f'{k:40s}: {v:.4f}') + print() + return df + +def construct_splits(df): + """ + Construct splits. + The original data already has a train vs. test split. + We triple the size of the test set so that we can better estimate accuracy on the small groups, + and construct a validation set by randomly sampling articles. + """ + + df = df.copy() + train_df = df.loc[df['split'] == 'train'] + test_df = df.loc[df['split'] == 'test'] + train_articles = set(train_df['article_id'].values) + test_articles = set(test_df['article_id'].values) + # Assert no overlap between train and test articles + assert len(train_articles.intersection(test_articles)) == 0 + + n_train = len(train_df) + n_test = len(test_df) + n_train_articles = len(train_articles) + n_test_articles = len(test_articles) + + ## Set params + n_val_articles = n_test_articles + n_new_test_articles = 2 * n_test_articles + + np.random.seed(0) + + # Sample val articles + val_articles = np.random.choice( + list(train_articles), + size=n_val_articles, + replace=False) + df.loc[df['article_id'].isin(val_articles), 'split'] = 'val' + + # Sample new test articles + train_articles = train_articles - set(val_articles) + new_test_articles = np.random.choice( + list(train_articles), + size=n_new_test_articles, + replace=False) + df.loc[df['article_id'].isin(new_test_articles), 'split'] = 'test' + + train_df = df.loc[df['split'] == 'train'] + val_df = df.loc[df['split'] == 'val'] + test_df = df.loc[df['split'] == 'test'] + + train_articles = set(train_df['article_id'].values) + val_articles = set(val_df['article_id'].values) + test_articles = set(test_df['article_id'].values) + + # Sanity checks + assert len(df) == len(train_df) + len(val_df) + len(test_df) + assert n_train == len(train_df) + len(val_df) + np.sum(df['article_id'].isin(new_test_articles)) + assert n_test == len(test_df) - np.sum(df['article_id'].isin(new_test_articles)) + assert n_train_articles == len(train_articles) + len(val_articles) + len(new_test_articles) + assert n_val_articles == len(val_articles) + assert n_test_articles == len(test_articles) - n_new_test_articles + assert len(train_articles.intersection(val_articles)) == 0 + assert len(train_articles.intersection(test_articles)) == 0 + 
assert len(val_articles.intersection(test_articles)) == 0 + + print('% of examples') + for split in ['train', 'val', 'test']: + print(split, np.mean(df['split'] == split), np.sum(df['split'] == split)) + print('') + + print('class balance') + for split in ['train', 'val', 'test']: + split_df = df.loc[df['split'] == split] + print('pos', np.mean(split_df['toxicity'] > 0.5)) + return df + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--root', required=True) + args = parser.parse_args() + + df = load_df(args.root) + df = augment_df(df) + df = construct_splits(df) + df.to_csv(os.path.join(args.root, f'all_data_with_identities.csv')) diff --git a/dataset_preprocessing/civilcomments/process_unlabeled.py b/dataset_preprocessing/civilcomments/process_unlabeled.py new file mode 100644 index 00000000..8ea8c571 --- /dev/null +++ b/dataset_preprocessing/civilcomments/process_unlabeled.py @@ -0,0 +1,89 @@ +import argparse +import csv +import os +import pdb + +import numpy as np +import pandas as pd + +# Fix the seed for reproducibility +np.random.seed(0) + +""" +Process unlabeled data in CivilComments. +Script is intended to be run after process_labeled.py + +Note that there is substantial overlap between the articles that unlabeled +comments are from and the articles that the labeled comments are from. +Specifically, 92% (1427849 out of 1551516) unlabeled comments are from +articles that also have comments in the labeled set. +""" + +TRAIN, VAL, TEST, UNLABELED = ('train', 'val', 'test', 'extra_unlabeled') + +def load_unlabeled_df(root): + """ + Loads the raw data where we don't have identity annotations. + """ + df = pd.read_csv(os.path.join(root, 'all_data.csv')) + df = df.loc[(df['identity_annotator_count'] == 0), :] + df = df.dropna(axis=0, how='any', subset=['id', 'comment_text', 'article_id']) # make sure data is clean + df = df.reset_index(drop=True) + return df + +def load_labeled_df(root): + """ + Loads the processed data for which we do have identity annotations. 
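+    This is the output of process_labeled.py (all_data_with_identities.csv).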
+ """ + df = pd.read_csv(os.path.join(root, 'all_data_with_identities.csv'), index_col=0) + return df + +def merge_dfs(unlabeled, labeled): + """ + Drops columns that are in unlabeled but not labeled + Adds columns that are in labeled but not unlabeled and sets values to NaN + """ + common_cols = unlabeled.columns & labeled.columns + unlabeled = unlabeled[common_cols] + joint = labeled.append(unlabeled, ignore_index = True) + return joint + +def main(args): + unlabeled = load_unlabeled_df(args.root) + labeled = load_labeled_df(args.root) + + # set all unlabeled examples to the same split + unlabeled['split'] = UNLABELED + + # merge unlabeled, labeled dfs + joint = merge_dfs(unlabeled, labeled) + assert (joint.columns == labeled.columns).all() + + def output_split_sizes(df): + print("-" * 50) + print(f'Train size: {len(df[df["split"] == TRAIN])}') + print(f'Val size: {len(df[df["split"] == VAL])}') + print(f'Test size: {len(df[df["split"] == TEST])}') + print( + f'Unlabeled size: {len(df[df["split"] == UNLABELED])}' + ) + print("-" * 50) + print("\n") + + output_split_sizes(joint) + + # Write out the new unlabeled split to user.csv + joint.to_csv(f'{args.root}/all_data_with_identities_and_unlabeled.csv', index=True) + joint[joint['split'] == UNLABELED].to_csv(f'{args.root}/unlabeled_data_with_identities.csv', index=True) + print("Done.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Create unlabeled splits for CivilComments.") + parser.add_argument( + "--root", + type=str, + help="Path to the dir containing the CivilComments processed labeled csv and full csv.", + ) + args = parser.parse_args() + main(args) diff --git a/dataset_preprocessing/domainnet/generate_metadata.py b/dataset_preprocessing/domainnet/generate_metadata.py new file mode 100644 index 00000000..b0fac190 --- /dev/null +++ b/dataset_preprocessing/domainnet/generate_metadata.py @@ -0,0 +1,151 @@ +import argparse +import os +import pdb + +import pandas as pd +import numpy as np + +# Fix the seed for reproducibility +np.random.seed(0) + +""" +Generate a CSV with the metadata for DomainNet: + + @inproceedings{peng2019moment, + title={Moment matching for multi-source domain adaptation}, + author={Peng, Xingchao and Bai, Qinxun and Xia, Xide and Huang, Zijun and Saenko, Kate and Wang, Bo}, + booktitle={Proceedings of the IEEE International Conference on Computer Vision}, + pages={1406--1415}, + year={2019} + } + +The dataset can be downloaded from http://ai.bu.edu/M3SDA. + +There are 586,576 images in 345 categories (airplane, ball, cup, etc.) across 6 domains (clipart, infograph, +painting, quickdraw, real and sketch). Images are either PNG or JPG files. + +The metadata CSV file has the following fields: + +1. image_path: Path to the image file. The path has the following format: //. +2. domain: One of the 6 possible domains. +3. split: One of "train", "val" or "test". +4. category: One of the 345 possible categories. +5. y: The index corresponding to the category (e.g. 2 if the image is of a ball). + +Example usage: + + python dataset_preprocessing/domainnet/generate_metadata.py . 
+ +""" + +DOMAINS = ["clipart", "infograph", "painting", "quickdraw", "real", "sketch"] +METADATA_COLUMNS = ["image_path", "domain", "split", "category", "y"] +NUM_OF_CATEGORIES = 345 + + +def main(dataset_path, split_val=False): + print("Generating metadata.csv for DomainNet...") + + # Build mapping of image to split ("train", "val" or "test) and label + categories = [""] * NUM_OF_CATEGORIES + image_info = dict() + original_splits_path = os.path.join(dataset_path, "original_splits") + for domain in DOMAINS: + train_count = 0 + with open(os.path.join(original_splits_path, f"{domain}_train.txt")) as f: + for line in f.readlines(): + image_path, label = line.strip().split(" ") + categories[int(label)] = image_path.split(os.path.sep)[1] + image_info[image_path] = ["train", label] + train_count += 1 + + test_count = 0 + bucketed_test_set = [[] for _ in range(NUM_OF_CATEGORIES)] + with open(os.path.join(original_splits_path, f"{domain}_test.txt")) as f: + for line in f.readlines(): + image_path, label = line.strip().split(" ") + label = int(label) + image_info[image_path] = ["test", label] + test_count += 1 + bucketed_test_set[label].append(image_path) + + total_count = train_count + test_count + train_percentage = np.round(float(train_count) / total_count * 100.0, 2) + test_percentage = np.round(float(test_count) / total_count * 100.0, 2) + print( + f"Domain {domain} originally had {train_count} ({train_percentage}%) training examples " + f"and {test_count} ({test_percentage}%) test examples with a total of {total_count} examples." + ) + + val_count = 0 + if split_val: + # Go from 70-30 train-test split to 70-15-15 train-val-test split + print("Creating a validation set from the existing test set...") + for category_images in bucketed_test_set: + new_val_images = np.random.choice( + category_images, len(category_images) // 2, replace=False + ) + for image_path in new_val_images: + image_info[image_path][0] = "val" + val_count += 1 + + val_percentage = np.round(float(val_count) / total_count * 100.0, 2) + test_count -= val_count + test_percentage = np.round(float(test_count) / total_count * 100.0, 2) + print( + f"Domain {domain} now has {train_count} ({train_percentage}%) training examples, " + f"{val_count} ({val_percentage}%) validation examples and {test_count} ({test_percentage}%) test " + f"examples with a total of {total_count} examples.\n" + ) + + # For debugging + print(f"Categories in order: {categories}") + + # Build metadata + metadata_dict = {column: [] for column in METADATA_COLUMNS} + for domain in DOMAINS: + domain_path = os.path.join(dataset_path, domain) + + for category in os.listdir(domain_path): + category_path = os.path.join(domain_path, category) + if not os.path.isdir(category_path): + continue + + for image in os.listdir(category_path): + image_path = os.path.join(domain, category, image) + if ( + image.endswith(".jpg") or image.endswith(".png") + ) and image_path in image_info: + split, y = image_info[image_path] + metadata_dict["image_path"].append(image_path) + metadata_dict["domain"].append(domain) + metadata_dict["split"].append(split) + metadata_dict["category"].append(category) + metadata_dict["y"].append(y) + + # Write metadata out as a CSV file + metadata_df = pd.DataFrame(metadata_dict) + metadata_path = os.path.join(dataset_path, "metadata.csv") + print(f"Writing metadata out to {metadata_path}...") + metadata_df.to_csv(metadata_path, index=False) + print("Done.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate a CSV 
+    )
+    parser.add_argument(
+        "path",
+        type=str,
+        help="Path to the DomainNet dataset downloaded from http://ai.bu.edu/M3SDA",
+    )
+    parser.add_argument(
+        "--split-val",
+        action="store_true",
+        help="Whether to create a separate validation set by splitting the existing test split "
+        "in half (defaults to false).",
+    )
+
+    args = parser.parse_args()
+    main(args.path, args.split_val)
diff --git a/dataset_preprocessing/domainnet/generate_sentry_metadata.py b/dataset_preprocessing/domainnet/generate_sentry_metadata.py
new file mode 100644
index 00000000..b182359c
--- /dev/null
+++ b/dataset_preprocessing/domainnet/generate_sentry_metadata.py
@@ -0,0 +1,118 @@
+import argparse
+import os
+import pdb
+
+import pandas as pd
+import numpy as np
+
+# Fix the seed for reproducibility
+np.random.seed(0)
+
+"""
+Generate a CSV with the metadata for DomainNet (SENTRY version):
+
+    @inproceedings{peng2019moment,
+      title={Moment matching for multi-source domain adaptation},
+      author={Peng, Xingchao and Bai, Qinxun and Xia, Xide and Huang, Zijun and Saenko, Kate and Wang, Bo},
+      booktitle={Proceedings of the IEEE International Conference on Computer Vision},
+      pages={1406--1415},
+      year={2019}
+    }
+
+    @article{prabhu2020sentry,
+      author = {Prabhu, Viraj and Khare, Shivam and Kartik, Deeksha and Hoffman, Judy},
+      title = {SENTRY: Selective Entropy Optimization via Committee Consistency for Unsupervised Domain Adaptation},
+      year = {2020},
+      journal = {arXiv preprint: 2012.11460},
+    }
+
+The dataset can be downloaded from http://ai.bu.edu/M3SDA.
+The SENTRY splits can be found at https://github.com/virajprabhu/SENTRY/tree/main/data/DomainNet/txt.
+
+There are 586,576 images in 345 categories (airplane, ball, cup, etc.) across 6 domains (clipart, infograph,
+painting, quickdraw, real and sketch) in the original DomainNet dataset. Images are either PNG or JPG files.
+
+The SENTRY version of the dataset has 40 categories across 4 domains:
+"Due to labeling noise prevalent in the full version of DomainNet, we instead use the subset proposed in
+Tan et al. [42], which uses 40-commonly seen classes from four domains: Real (R), Clipart (C), Painting (P),
+and Sketch (S)."
+
+The metadata CSV file has the following fields:
+
+1. image_path: Path to the image file. The path has the following format: <domain>/<category>/<image_file>.
+2. domain: One of the 4 possible domains.
+3. split: One of "train" or "test".
+4. category: One of the 40 possible categories.
+5. y: The label index given by the SENTRY split files.
+
+Example usage:
+
+    python dataset_preprocessing/domainnet/generate_sentry_metadata.py .
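+
+The SENTRY split files are expected to be named {domain}_train_mini.txt and
+{domain}_test_mini.txt under the given path.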
+ +""" + +DOMAINS = ["clipart", "painting", "real", "sketch"] +METADATA_COLUMNS = ["image_path", "domain", "split", "category", "y"] +NUM_OF_CATEGORIES = 40 +TEST_SPLIT = "test" +TRAIN_SPLIT = "train" + + +def main(sentry_splits_path): + def process_split(split, split_path): + count = 0 + categories = set() + with open(split_path) as f: + for line in f.readlines(): + image_path, label = line.strip().split(" ") + metadata_values = image_path.split(os.path.sep) + metadata_dict["image_path"].append(image_path) + metadata_dict["domain"].append(metadata_values[0]) + metadata_dict["split"].append(split) + metadata_dict["category"].append(metadata_values[1]) + categories.add(metadata_values[1]) + metadata_dict["y"].append(int(label)) + count += 1 + assert len(categories) == NUM_OF_CATEGORIES + return count + + print("Generating sentry_metadata.csv for DomainNet (SENTRY version)...") + + metadata_dict = {column: [] for column in METADATA_COLUMNS} + for domain in DOMAINS: + train_count = process_split( + TRAIN_SPLIT, + os.path.join(sentry_splits_path, f"{domain}_{TRAIN_SPLIT}_mini.txt"), + ) + test_count = process_split( + TEST_SPLIT, + os.path.join(sentry_splits_path, f"{domain}_{TEST_SPLIT}_mini.txt"), + ) + total_count = train_count + test_count + train_percentage = np.round(float(train_count) / total_count * 100.0, 2) + test_percentage = np.round(float(test_count) / total_count * 100.0, 2) + print( + f"Domain {domain} had {train_count} ({train_percentage}%) training examples " + f"and {test_count} ({test_percentage}%) test examples with a total of {total_count} examples." + ) + + # Write metadata out as a CSV file + metadata_df = pd.DataFrame(metadata_dict) + metadata_path = os.path.join(sentry_splits_path, "sentry_metadata.csv") + print(f"Writing metadata out to {metadata_path}...") + metadata_df.to_csv(metadata_path, index=False) + print("Done.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate a CSV with the metadata for DomainNet (SENTRY version)." + ) + parser.add_argument( + "path", + type=str, + help="Path to the DomainNet dataset downloaded from http://ai.bu.edu/M3SDA", + ) + + args = parser.parse_args() + main(args.path) diff --git a/dataset_preprocessing/molpcba_unlabeled/process.py b/dataset_preprocessing/molpcba_unlabeled/process.py new file mode 100644 index 00000000..7bac7405 --- /dev/null +++ b/dataset_preprocessing/molpcba_unlabeled/process.py @@ -0,0 +1,78 @@ +import numpy as np +from wilds import get_dataset +from rdkit.Chem import AllChem +from rdkit import Chem +from tqdm import tqdm +import pandas as pd +import os +import torch + +def compute_pcba_fingerprint(): + ''' + Compute the fingerprint features for molpcba molecules. + ''' + os.makedirs('processed_fp', exist_ok = True) + + pcba_dataset = get_dataset(dataset = 'ogb-molpcba') + smiles_list = pd.read_csv('data/ogbg_molpcba/mapping/mol.csv.gz')['smiles'].tolist() + x_list = [] + for smiles in tqdm(smiles_list): + mol = Chem.MolFromSmiles(smiles) + x = np.array(list(AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024)), dtype=np.int8) + x_list.append(x) + + x = np.stack(x_list) + + np.save('processed_fp/molpcba.npy', x) + + +def jaccard_similarity(vec, mat): + AND = vec * mat + OR = (vec + mat) > 0 + denom = np.sum(OR, axis = 1) + nom = np.sum(AND, axis = 1) + + denom[denom==0] = 1 + return nom / denom + + +def assign_to_group(): + ''' + Assign unlabeled pubchem molecules to scaffold groups of molpcba. 
+    '''
+    smiles_list = pd.read_csv('molpcba_unlabeled/mapping/unlabeled_smiles.csv', header = None)[0].tolist()
+
+    x_pcba = np.load('processed_fp/molpcba.npy')
+    print(x_pcba.shape)
+    print((x_pcba > 1).sum())
+    scaffold_group = np.load('data/ogbg_molpcba/raw/scaffold_group.npy')
+
+    # ground-truth assignment
+    group_assignment = np.load('molpcba_unlabeled/processed/group_assignment.npy')
+
+    for i, smiles in tqdm(enumerate(smiles_list), total = len(smiles_list)):
+        mol = Chem.MolFromSmiles(smiles)
+        x = np.array(list(AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024)), dtype=np.int8)
+        sim = jaccard_similarity(x, x_pcba)
+
+        max_idx = np.argmax(sim)
+        a = scaffold_group[max_idx]
+        b = group_assignment[i]
+
+        print(a, b)
+        assert a == b # make sure they coincide with each other
+
+
+def test_jaccard():
+    vec = np.random.randn(1024) > 0
+    mat = np.random.randn(1000, 1024)
+    mat[0] = vec
+
+    sim = jaccard_similarity(vec, mat)
+    print(sim)
+
+
+if __name__ == '__main__':
+    compute_pcba_fingerprint()
+    assign_to_group()
+
diff --git a/dataset_preprocessing/poverty/split_npys_unlabeled.py b/dataset_preprocessing/poverty/split_npys_unlabeled.py
new file mode 100644
index 00000000..445d60dd
--- /dev/null
+++ b/dataset_preprocessing/poverty/split_npys_unlabeled.py
@@ -0,0 +1,32 @@
+import os, sys
+import argparse
+import numpy as np
+from PIL import Image
+from pathlib import Path
+from tqdm import tqdm
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--root_dir', required=True,
+                        help='The poverty data directory.')
+    parser.add_argument('--out_dir_root', required=True,
+                        help='The directory where output dir should be made.')
+    args = parser.parse_args()
+
+    data_dir = Path(args.root_dir)
+    indiv_dir = Path(args.out_dir_root) / 'poverty_unlabeled_v1.0_indiv_npz' / 'images'
+    indiv_dir.mkdir(exist_ok=True, parents=True)
+
+    counter = 0
+    for i in range(27):
+        path = data_dir / f'unlabeled_landsat_poverty_imgs_{i}.npy'
+        arr = np.load(path, mmap_mode='r')
+        arr = arr.transpose((0, 3, 1, 2))
+        for j in tqdm(range(len(arr))):
+            x = arr[j]
+            np.savez_compressed(indiv_dir / f'landsat_poverty_img_{counter}.npz', x=x)
+            counter += 1
+
+
+if __name__=='__main__':
+    main()
diff --git a/examples/algorithms/AFN.py b/examples/algorithms/AFN.py
new file mode 100644
index 00000000..7082ae2b
--- /dev/null
+++ b/examples/algorithms/AFN.py
@@ -0,0 +1,131 @@
+import torch
+
+from algorithms.single_model_algorithm import SingleModelAlgorithm
+from models.initializer import initialize_model
+
+class AFN(SingleModelAlgorithm):
+    """
+    Adaptive Feature Norm (AFN)
+
+    Original paper:
+        @InProceedings{Xu_2019_ICCV,
+            author = {Xu, Ruijia and Li, Guanbin and Yang, Jihan and Lin, Liang},
+            title = {Larger Norm More Transferable: An Adaptive Feature Norm Approach for
+                     Unsupervised Domain Adaptation},
+            booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
+            month = {October},
+            year = {2019}
+        }
+    """
+
+    def __init__(
+        self,
+        config,
+        d_out,
+        grouper,
+        loss,
+        metric,
+        n_train_steps,
+    ):
+        # Initialize model
+        featurizer, classifier = initialize_model(config, d_out=d_out, is_featurizer=True)
+        model = torch.nn.Sequential(featurizer, classifier)
+
+        # Initialize module
+        super().__init__(
+            config=config,
+            model=model,
+            grouper=grouper,
+            loss=loss,
+            metric=metric,
+            n_train_steps=n_train_steps,
+        )
+
+        # Model components
+        self.featurizer = featurizer
+        self.classifier = classifier
+
+        # Algorithm hyperparameters
+        # (delta_r is the per-step radius increment for the stepwise SAFN variant;
+        # r is the fixed target radius for the hard HAFN variant)
+        self.penalty_weight = config.afn_penalty_weight
+        
self.delta_r = config.safn_delta_r + self.r = config.hafn_r + self.afn_loss = self.hafn_loss if config.use_hafn else self.safn_loss + + # Additional logging + self.logged_fields.append("classification_loss") + self.logged_fields.append("feature_norm_penalty") + + def safn_loss(self, features): + """ + Adapted from https://github.com/jihanyang/AFN + """ + radius = features.norm(p=2, dim=1).detach() + assert not radius.requires_grad + radius = radius + self.delta_r + loss = ((features.norm(p=2, dim=1) - radius) ** 2).mean() + return loss + + def hafn_loss(self, features): + """ + Adapted from https://github.com/jihanyang/AFN + """ + loss = (features.norm(p=2, dim=1).mean() - self.r) ** 2 + return loss + + def process_batch(self, batch, unlabeled_batch=None): + """ + Overrides single_model_algorithm.process_batch(). + Args: + - batch (tuple of Tensors): a batch of data yielded by data loaders + - unlabeled_batch (tuple of Tensors or None): a batch of data yielded by unlabeled data loader + Output: + - results (dictionary): information about the batch + - y_true (Tensor): ground truth labels for batch + - g (Tensor): groups for batch + - metadata (Tensor): metadata for batch + - features (Tensor): featurizer output for batch + - y_pred (Tensor): full model output for batch + - unlabeled_features (Tensor): featurizer outputs for unlabeled_batch + """ + # Forward pass + x, y_true, metadata = batch + x = x.to(self.device) + y_true = y_true.to(self.device) + g = self.grouper.metadata_to_group(metadata).to(self.device) + features = self.featurizer(x) + y_pred = self.classifier(features) + + results = { + "g": g, + "metadata": metadata, + "y_true": y_true, + "y_pred": y_pred, + "features": features, + } + + if unlabeled_batch is not None: + unlabeled_x, _ = unlabeled_batch + unlabeled_x = unlabeled_x.to(self.device) + results['unlabeled_features'] = self.featurizer(unlabeled_x) + return results + + def objective(self, results): + classification_loss = self.loss.compute( + results["y_pred"], results["y_true"], return_dict=False + ) + + if self.is_training: + f_source = results.pop("features") + f_target = results.pop("unlabeled_features") + feature_norm_penalty = self.afn_loss(f_source) + self.afn_loss(f_target) + else: + feature_norm_penalty = 0.0 + + # Add to results for additional logging + self.save_metric_for_logging( + results, "classification_loss", classification_loss + ) + self.save_metric_for_logging( + results, "feature_norm_penalty", feature_norm_penalty + ) + return classification_loss + self.penalty_weight * feature_norm_penalty \ No newline at end of file diff --git a/examples/algorithms/DANN.py b/examples/algorithms/DANN.py new file mode 100644 index 00000000..7c075669 --- /dev/null +++ b/examples/algorithms/DANN.py @@ -0,0 +1,138 @@ +from typing import Dict, List + +import torch + +from algorithms.single_model_algorithm import SingleModelAlgorithm +from models.domain_adversarial_network import DomainAdversarialNetwork +from models.initializer import initialize_model +from optimizer import initialize_optimizer_with_model_params +from losses import initialize_loss +from utils import concat_input + +class DANN(SingleModelAlgorithm): + """ + Domain-adversarial training of neural networks. 
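+    A domain classifier head is trained to predict the domain of both labeled and
+    unlabeled examples from the features, while the featurizer is trained
+    adversarially to make those features domain-invariant.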
+ + Original paper: + @inproceedings{dann, + title={Domain-Adversarial Training of Neural Networks}, + author={Ganin, Ustinova, Ajakan, Germain, Larochelle, Laviolette, Marchand and Lempitsky}, + booktitle={Journal of Machine Learning Research 17}, + year={2016} + } + """ + + def __init__( + self, + config, + d_out, + grouper, + loss, + metric, + n_train_steps, + n_domains, + group_ids_to_domains, + ): + # Initialize model + featurizer, classifier = initialize_model( + config, d_out=d_out, is_featurizer=True + ) + model = DomainAdversarialNetwork(featurizer, classifier, n_domains) + parameters_to_optimize: List[Dict] = model.get_parameters_with_lr( + featurizer_lr=config.dann_featurizer_lr, + classifier_lr=config.dann_classifier_lr, + discriminator_lr=config.dann_discriminator_lr, + ) + self.optimizer = initialize_optimizer_with_model_params(config, parameters_to_optimize) + self.domain_loss = initialize_loss('cross_entropy', config) + + # Initialize module + super().__init__( + config=config, + model=model, + grouper=grouper, + loss=loss, + metric=metric, + n_train_steps=n_train_steps, + ) + self.group_ids_to_domains = group_ids_to_domains + + # Algorithm hyperparameters + self.penalty_weight = config.dann_penalty_weight + + # Additional logging + self.logged_fields.append("classification_loss") + self.logged_fields.append("domain_classification_loss") + + def process_batch(self, batch, unlabeled_batch=None): + """ + Overrides single_model_algorithm.process_batch(). + Args: + - batch (tuple of Tensors): a batch of data yielded by data loaders + - unlabeled_batch (tuple of Tensors or None): a batch of data yielded by unlabeled data loader + Output: + - results (dictionary): information about the batch + - y_true (Tensor): ground truth labels for batch + - g (Tensor): groups for batch + - metadata (Tensor): metadata for batch + - y_pred (Tensor): model output for batch + - domains_true (Tensor): true domains for batch and unlabeled batch + - domains_pred (Tensor): predicted domains for batch and unlabeled batch + - unlabeled_features (Tensor): featurizer outputs for unlabeled_batch + """ + # Forward pass + x, y_true, metadata = batch + g = self.grouper.metadata_to_group(metadata).to(self.device) + domains_true = self.group_ids_to_domains[g] + + if unlabeled_batch is not None: + unlabeled_x, unlabeled_metadata = unlabeled_batch + unlabeled_domains_true = self.group_ids_to_domains[ + self.grouper.metadata_to_group(unlabeled_metadata) + ] + + # Concatenate examples and true domains + x_cat = concat_input(x, unlabeled_x) + domains_true = torch.cat([domains_true, unlabeled_domains_true]) + else: + x_cat = x + + x_cat = x_cat.to(self.device) + y_true = y_true.to(self.device) + domains_true = domains_true.to(self.device) + y_pred, domains_pred = self.model(x_cat) + + # Ignore the predicted labels for the unlabeled data + y_pred = y_pred[: len(y_true)] + + return { + "g": g, + "metadata": metadata, + "y_true": y_true, + "y_pred": y_pred, + "domains_true": domains_true, + "domains_pred": domains_pred, + } + + def objective(self, results): + classification_loss = self.loss.compute( + results["y_pred"], results["y_true"], return_dict=False + ) + + if self.is_training: + domain_classification_loss = self.domain_loss.compute( + results.pop("domains_pred"), + results.pop("domains_true"), + return_dict=False, + ) + else: + domain_classification_loss = 0.0 + + # Add to results for additional logging + self.save_metric_for_logging( + results, "classification_loss", classification_loss + ) + 
self.save_metric_for_logging( + results, "domain_classification_loss", domain_classification_loss + ) + return classification_loss + domain_classification_loss * self.penalty_weight diff --git a/examples/algorithms/ERM.py b/examples/algorithms/ERM.py index 5b4b0492..35f21c88 100644 --- a/examples/algorithms/ERM.py +++ b/examples/algorithms/ERM.py @@ -1,11 +1,12 @@ import torch from algorithms.single_model_algorithm import SingleModelAlgorithm from models.initializer import initialize_model +from utils import move_to class ERM(SingleModelAlgorithm): def __init__(self, config, d_out, grouper, loss, metric, n_train_steps): - model = initialize_model(config, d_out).to(config.device) + model = initialize_model(config, d_out) # initialize module super().__init__( config=config, @@ -15,6 +16,63 @@ def __init__(self, config, d_out, grouper, loss, metric=metric, n_train_steps=n_train_steps, ) + self.use_unlabeled_y = config.use_unlabeled_y # Expect x,y,m from unlabeled loaders and train on the unlabeled y + + def process_batch(self, batch, unlabeled_batch=None): + """ + Overrides single_model_algorithm.process_batch(). + ERM defines its own process_batch to handle if self.use_unlabeled_y is true. + Args: + - batch (tuple of Tensors): a batch of data yielded by data loaders + - unlabeled_batch (tuple of Tensors or None): a batch of data yielded by unlabeled data loader + Output: + - results (dictionary): information about the batch + - y_true (Tensor): ground truth labels for batch + - g (Tensor): groups for batch + - metadata (Tensor): metadata for batch + - y_pred (Tensor): model output for batch + - unlabeled_g (Tensor): groups for unlabeled batch + - unlabeled_metadata (Tensor): metadata for unlabeled batch + - unlabeled_y_pred (Tensor): predictions for unlabeled batch for fully-supervised ERM experiments + - unlabeled_y_true (Tensor): true labels for unlabeled batch for fully-supervised ERM experiments + """ + x, y_true, metadata = batch + x = move_to(x, self.device) + y_true = move_to(y_true, self.device) + g = move_to(self.grouper.metadata_to_group(metadata), self.device) + + outputs = self.get_model_output(x, y_true) + + results = { + 'g': g, + 'y_true': y_true, + 'y_pred': outputs, + 'metadata': metadata, + } + if unlabeled_batch is not None: + if self.use_unlabeled_y: # expect loaders to return x,y,m + x, y, metadata = unlabeled_batch + y = move_to(y, self.device) + else: + x, metadata = unlabeled_batch + x = move_to(x, self.device) + results['unlabeled_metadata'] = metadata + if self.use_unlabeled_y: + results['unlabeled_y_pred'] = self.get_model_output(x, y) + results['unlabeled_y_true'] = y + results['unlabeled_g'] = self.grouper.metadata_to_group(metadata).to(self.device) + return results def objective(self, results): - return self.loss.compute(results['y_pred'], results['y_true'], return_dict=False) + labeled_loss = self.loss.compute(results['y_pred'], results['y_true'], return_dict=False) + if self.use_unlabeled_y and 'unlabeled_y_true' in results: + unlabeled_loss = self.loss.compute( + results['unlabeled_y_pred'], + results['unlabeled_y_true'], + return_dict=False + ) + lab_size = len(results['y_pred']) + unl_size = len(results['unlabeled_y_pred']) + return (lab_size * labeled_loss + unl_size * unlabeled_loss) / (lab_size + unl_size) + else: + return labeled_loss \ No newline at end of file diff --git a/examples/algorithms/IRM.py b/examples/algorithms/IRM.py index 4b90d21a..c2e7ecd3 100644 --- a/examples/algorithms/IRM.py +++ b/examples/algorithms/IRM.py @@ -87,18 +87,13 @@ def 
objective(self, results): else: penalty_weight = 1.0 - # Package the results - if isinstance(penalty, torch.Tensor): - results['penalty'] = penalty.item() - else: - results['penalty'] = penalty - + self.save_metric_for_logging(results, 'penalty', penalty) return avg_loss + penalty * penalty_weight - def _update(self, results): + def _update(self, results, should_step=True): if self.update_count == self.irm_penalty_anneal_iters: print('Hit IRM penalty anneal iters') # Reset optimizer to deal with the changing penalty weight self.optimizer = initialize_optimizer(self.config, self.model) - super()._update(results) + super()._update(results, should_step=should_step) self.update_count += 1 diff --git a/examples/algorithms/algorithm.py b/examples/algorithms/algorithm.py index 136d21d6..a77bd744 100644 --- a/examples/algorithms/algorithm.py +++ b/examples/algorithms/algorithm.py @@ -1,7 +1,7 @@ -import torch import torch.nn as nn from utils import move_to, detach_and_clone + class Algorithm(nn.Module): def __init__(self, device): super().__init__() @@ -101,4 +101,4 @@ def sanitize_dict(self, in_dict, to_out_device=True): out_dict = detach_and_clone(in_dict) if to_out_device: out_dict = move_to(out_dict, self.out_device) - return out_dict + return out_dict \ No newline at end of file diff --git a/examples/algorithms/deepCORAL.py b/examples/algorithms/deepCORAL.py index e82981d4..0c3e52cf 100644 --- a/examples/algorithms/deepCORAL.py +++ b/examples/algorithms/deepCORAL.py @@ -2,6 +2,7 @@ from models.initializer import initialize_model from algorithms.single_model_algorithm import SingleModelAlgorithm from wilds.common.utils import split_into_groups +from utils import concat_input class DeepCORAL(SingleModelAlgorithm): """ @@ -18,6 +19,9 @@ class DeepCORAL(SingleModelAlgorithm): organization={Springer} } + The original CORAL loss is the distance between second-order statistics (covariances) + of the source and target features. + The CORAL penalty function below is adapted from DomainBed's implementation: https://github.com/facebookresearch/DomainBed/blob/1a61f7ff44b02776619803a1dd12f952528ca531/domainbed/algorithms.py#L539 """ @@ -30,7 +34,7 @@ def __init__(self, config, d_out, grouper, loss, metric, n_train_steps): featurizer, classifier = initialize_model(config, d_out=d_out, is_featurizer=True) featurizer = featurizer.to(config.device) classifier = classifier.to(config.device) - model = torch.nn.Sequential(featurizer, classifier).to(config.device) + model = torch.nn.Sequential(featurizer, classifier) # initialize module super().__init__( config=config, @@ -49,7 +53,7 @@ def __init__(self, config, d_out, grouper, loss, metric, n_train_steps): self.classifier = classifier def coral_penalty(self, x, y): - if x.dim() > 2: + if x.dim() > 2: # featurizers output Tensors of size (batch_size, ..., feature dimensionality). # we flatten to Tensors of size (*, feature dimensionality) x = x.view(-1, x.size(-1)) @@ -65,55 +69,68 @@ def coral_penalty(self, x, y): mean_diff = (mean_x - mean_y).pow(2).mean() cova_diff = (cova_x - cova_y).pow(2).mean() - return mean_diff+cova_diff + return mean_diff + cova_diff - def process_batch(self, batch): + def process_batch(self, batch, unlabeled_batch=None): """ - Override + Overrides single_model_algorithm.process_batch(). 
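+        Labeled and unlabeled inputs are concatenated and passed through the
+        featurizer in a single forward pass; the pairwise CORAL penalty between
+        groups is then computed in objective().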
+ Args: + - batch (tuple of Tensors): a batch of data yielded by data loaders + - unlabeled_batch (tuple of Tensors or None): a batch of data yielded by unlabeled data loader + Output: + - results (dictionary): information about the batch + - y_true (Tensor): ground truth labels for batch + - g (Tensor): groups for batch + - metadata (Tensor): metadata for batch + - unlabeled_g (Tensor): groups for unlabeled batch + - features (Tensor): featurizer output for batch and unlabeled batch + - y_pred (Tensor): full model output for batch and unlabeled batch """ # forward pass x, y_true, metadata = batch - x = x.to(self.device) y_true = y_true.to(self.device) g = self.grouper.metadata_to_group(metadata).to(self.device) - features = self.featurizer(x) - outputs = self.classifier(features) - # package the results results = { 'g': g, 'y_true': y_true, - 'y_pred': outputs, 'metadata': metadata, - 'features': features, - } + } + + if unlabeled_batch is not None: + unlabeled_x, unlabeled_metadata = unlabeled_batch + x = concat_input(x, unlabeled_x) + unlabeled_g = self.grouper.metadata_to_group(unlabeled_metadata).to(self.device) + results['unlabeled_g'] = unlabeled_g + + x = x.to(self.device) + features = self.featurizer(x) + outputs = self.classifier(features) + y_pred = outputs[: len(y_true)] + + results['features'] = features + results['y_pred'] = y_pred return results def objective(self, results): - # extract features - features = results.pop('features') - if self.is_training: - # split into groups - unique_groups, group_indices, _ = split_into_groups(results['g']) - # compute penalty + features = results.pop('features') + + # Split into groups + groups = concat_input(results['g'], results['unlabeled_g']) if 'unlabeled_g' in results else results['g'] + unique_groups, group_indices, _ = split_into_groups(groups) n_groups_per_batch = unique_groups.numel() + + # Compute penalty - perform pairwise comparisons between features of all the groups penalty = torch.zeros(1, device=self.device) for i_group in range(n_groups_per_batch): for j_group in range(i_group+1, n_groups_per_batch): penalty += self.coral_penalty(features[group_indices[i_group]], features[group_indices[j_group]]) if n_groups_per_batch > 1: penalty /= (n_groups_per_batch * (n_groups_per_batch-1) / 2) # get the mean penalty - # save penalty else: penalty = 0. - if isinstance(penalty, torch.Tensor): - results['penalty'] = penalty.item() - else: - results['penalty'] = penalty - - + self.save_metric_for_logging(results, 'penalty', penalty) avg_loss = self.loss.compute(results['y_pred'], results['y_true'], return_dict=False) - return avg_loss + penalty * self.penalty_weight diff --git a/examples/algorithms/fixmatch.py b/examples/algorithms/fixmatch.py new file mode 100644 index 00000000..3a4438cc --- /dev/null +++ b/examples/algorithms/fixmatch.py @@ -0,0 +1,143 @@ +import torch +import torch.nn.functional as F + +from models.initializer import initialize_model +from algorithms.single_model_algorithm import SingleModelAlgorithm +from configs.supported import process_pseudolabels_functions +from utils import detach_and_clone + + +class FixMatch(SingleModelAlgorithm): + """ + FixMatch. + This algorithm was originally proposed as a semi-supervised learning algorithm. 
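+    In this codebase it is applied to WILDS' unlabeled data, which may come from
+    different distributions than the labeled training data.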
+ + Loss is of the form + \ell_s + \lambda * \ell_u + where + \ell_s = cross-entropy with true labels using weakly augmented labeled examples + \ell_u = cross-entropy with pseudolabel generated using weak augmentation and prediction + using strong augmentation + + Original paper: + @article{sohn2020fixmatch, + title={Fixmatch: Simplifying semi-supervised learning with consistency and confidence}, + author={Sohn, Kihyuk and Berthelot, David and Li, Chun-Liang and Zhang, Zizhao and Carlini, Nicholas and Cubuk, Ekin D and Kurakin, Alex and Zhang, Han and Raffel, Colin}, + journal={arXiv preprint arXiv:2001.07685}, + year={2020} + } + """ + def __init__(self, config, d_out, grouper, loss, metric, n_train_steps): + featurizer, classifier = initialize_model( + config, d_out=d_out, is_featurizer=True + ) + model = torch.nn.Sequential(featurizer, classifier) + + # initialize module + super().__init__( + config=config, + model=model, + grouper=grouper, + loss=loss, + metric=metric, + n_train_steps=n_train_steps, + ) + # algorithm hyperparameters + self.fixmatch_lambda = config.self_training_lambda + self.confidence_threshold = config.self_training_threshold + self.process_pseudolabels_function = process_pseudolabels_functions[config.process_pseudolabels_function] + + # Additional logging + self.logged_fields.append("pseudolabels_kept_frac") + self.logged_fields.append("classification_loss") + self.logged_fields.append("consistency_loss") + + def process_batch(self, batch, unlabeled_batch=None): + """ + Overrides single_model_algorithm.process_batch(). + Args: + - batch (x, y, m): a batch of data yielded by data loaders + - unlabeled_batch: examples ((x_weak, x_strong), m) where x_weak is weakly augmented but x_strong is strongly augmented + Output: + - results (dictionary): information about the batch + - y_true (Tensor): ground truth labels for batch + - g (Tensor): groups for batch + - metadata (Tensor): metadata for batch + - y_pred (Tensor): model output for batch + - unlabeled_g (Tensor): groups for unlabeled batch + - unlabeled_metadata (Tensor): metadata for unlabeled batch + - unlabeled_weak_y_pseudo (Tensor): pseudolabels on x_weak of the unlabeled batch, already thresholded + - unlabeled_strong_y_pred (Tensor): model output on x_strong of the unlabeled batch, already thresholded + """ + # Labeled examples + x, y_true, metadata = batch + x = x.to(self.device) + y_true = y_true.to(self.device) + g = self.grouper.metadata_to_group(metadata).to(self.device) + # package the results + results = { + 'g': g, + 'y_true': y_true, + 'metadata': metadata + } + pseudolabels_kept_frac = 0 + + # Unlabeled examples + if unlabeled_batch is not None: + (x_weak, x_strong), metadata = unlabeled_batch + x_weak = x_weak.to(self.device) + x_strong = x_strong.to(self.device) + + g = self.grouper.metadata_to_group(metadata).to(self.device) + results['unlabeled_metadata'] = metadata + results['unlabeled_g'] = g + + with torch.no_grad(): + outputs = self.model(x_weak) + _, pseudolabels, pseudolabels_kept_frac, mask = self.process_pseudolabels_function( + outputs, + self.confidence_threshold, + ) + results['unlabeled_weak_y_pseudo'] = detach_and_clone(pseudolabels) + + self.save_metric_for_logging( + results, "pseudolabels_kept_frac", pseudolabels_kept_frac + ) + + # Concat and call forward + n_lab = x.shape[0] + if unlabeled_batch is not None: + x_concat = torch.cat((x, x_strong), dim=0) + else: + x_concat = x + + outputs = self.model(x_concat) + results['y_pred'] = outputs[:n_lab] + if unlabeled_batch is not None: 
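+            # If a confidence mask was computed on the weak views, keep only the
+            # strong-view outputs whose pseudolabels passed the threshold.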
+ results['unlabeled_strong_y_pred'] = outputs[n_lab:] if mask is None else outputs[n_lab:][mask] + return results + + def objective(self, results): + # Labeled loss + classification_loss = self.loss.compute(results['y_pred'], results['y_true'], return_dict=False) + + # Pseudolabeled loss + if 'unlabeled_weak_y_pseudo' in results: + loss_output = self.loss.compute( + results['unlabeled_strong_y_pred'], + results['unlabeled_weak_y_pseudo'], + return_dict=False, + ) + consistency_loss = loss_output * results['pseudolabels_kept_frac'] + else: + consistency_loss = 0 + + # Add to results for additional logging + self.save_metric_for_logging( + results, "classification_loss", classification_loss + ) + self.save_metric_for_logging( + results, "consistency_loss", consistency_loss + ) + + return classification_loss + self.fixmatch_lambda * consistency_loss diff --git a/examples/algorithms/groupDRO.py b/examples/algorithms/groupDRO.py index f710c4fe..bbac862b 100644 --- a/examples/algorithms/groupDRO.py +++ b/examples/algorithms/groupDRO.py @@ -18,7 +18,7 @@ def __init__(self, config, d_out, grouper, loss, metric, n_train_steps, is_group # check config assert config.uniform_over_groups # initialize model - model = initialize_model(config, d_out).to(config.device) + model = initialize_model(config, d_out) # initialize module super().__init__( config=config, @@ -38,20 +38,7 @@ def __init__(self, config, d_out, grouper, loss, metric, n_train_steps, is_group self.group_weights = self.group_weights/self.group_weights.sum() self.group_weights = self.group_weights.to(self.device) - def process_batch(self, batch): - """ - A helper function for update() and evaluate() that processes the batch - Args: - - batch (tuple of Tensors): a batch of data yielded by data loaders - Output: - - results (dictionary): information about the batch - - g (Tensor) - - y_true (Tensor) - - metadata (Tensor) - - loss (Tensor) - - metrics (Tensor) - all Tensors are of size (batch_size,) - """ + def process_batch(self, batch, unlabeled_batch=None): results = super().process_batch(batch) results['group_weight'] = self.group_weights return results @@ -74,7 +61,7 @@ def objective(self, results): return_dict=False) return group_losses @ self.group_weights - def _update(self, results): + def _update(self, results, should_step=True): """ Process the batch, update the log, and update the model, group weights, and scheduler. 
Args: @@ -101,4 +88,4 @@ def _update(self, results): # save updated group weights results['group_weight'] = self.group_weights # update model - super()._update(results) + super()._update(results, should_step=should_step) diff --git a/examples/algorithms/initializer.py b/examples/algorithms/initializer.py index a25afc8f..a3bf75e2 100644 --- a/examples/algorithms/initializer.py +++ b/examples/algorithms/initializer.py @@ -1,40 +1,27 @@ +from types import SimpleNamespace +import torch +import math from wilds.common.utils import get_counts from algorithms.ERM import ERM +from algorithms.AFN import AFN +from algorithms.DANN import DANN from algorithms.groupDRO import GroupDRO from algorithms.deepCORAL import DeepCORAL from algorithms.IRM import IRM -from configs.supported import algo_log_metrics +from algorithms.fixmatch import FixMatch +from algorithms.pseudolabel import PseudoLabel +from algorithms.noisy_student import NoisyStudent +from configs.supported import algo_log_metrics, losses from losses import initialize_loss -def initialize_algorithm(config, datasets, train_grouper): +def initialize_algorithm(config, datasets, train_grouper, unlabeled_dataset=None): train_dataset = datasets['train']['dataset'] train_loader = datasets['train']['loader'] - - # Configure the final layer of the networks used - # The code below are defaults. Edit this if you need special config for your model. - if train_dataset.is_classification: - if train_dataset.y_size == 1: - # For single-task classification, we have one output per class - d_out = train_dataset.n_classes - elif train_dataset.y_size is None: - d_out = train_dataset.n_classes - elif (train_dataset.y_size > 1) and (train_dataset.n_classes == 2): - # For multi-task binary classification (each output is the logit for each binary class) - d_out = train_dataset.y_size - else: - raise RuntimeError('d_out not defined.') - elif train_dataset.is_detection: - # For detection, d_out is the number of classes - d_out = train_dataset.n_classes - if config.algorithm in ['deepCORAL', 'IRM']: - raise ValueError(f'{config.algorithm} is not currently supported for detection datasets.') - else: - # For regression, we have one output per target dimension - d_out = train_dataset.y_size + d_out = infer_d_out(train_dataset, config) # Other config - n_train_steps = len(train_loader) * config.n_epochs - loss = initialize_loss(config, d_out) + n_train_steps = math.ceil(len(train_loader)/config.gradient_accumulation_steps) * config.n_epochs + loss = initialize_loss(config.loss_function, config) metric = algo_log_metrics[config.algo_log_metric] if config.algorithm == 'ERM': @@ -72,7 +59,97 @@ def initialize_algorithm(config, datasets, train_grouper): loss=loss, metric=metric, n_train_steps=n_train_steps) + elif config.algorithm == 'DANN': + if unlabeled_dataset is not None: + unlabeled_dataset = unlabeled_dataset['dataset'] + metadata_array = torch.cat( + [train_dataset.metadata_array, unlabeled_dataset.metadata_array] + ) + else: + metadata_array = train_dataset.metadata_array + + groups = train_grouper.metadata_to_group(metadata_array) + group_counts = get_counts(groups, train_grouper.n_groups) + group_ids_to_domains = group_counts.tolist() + domain_idx = 0 + for i, count in enumerate(group_ids_to_domains): + if count > 0: + group_ids_to_domains[i] = domain_idx + domain_idx += 1 + group_ids_to_domains = torch.tensor(group_ids_to_domains, dtype=torch.long) + algorithm = DANN( + config=config, + d_out=d_out, + grouper=train_grouper, + loss=loss, + metric=metric, + 
n_train_steps=n_train_steps, + n_domains = domain_idx, + group_ids_to_domains=group_ids_to_domains, + ) + elif config.algorithm == 'AFN': + algorithm = AFN( + config=config, + d_out=d_out, + grouper=train_grouper, + loss=loss, + metric=metric, + n_train_steps=n_train_steps + ) + elif config.algorithm == 'FixMatch': + algorithm = FixMatch( + config=config, + d_out=d_out, + grouper=train_grouper, + loss=loss, + metric=metric, + n_train_steps=n_train_steps) + elif config.algorithm == 'PseudoLabel': + algorithm = PseudoLabel( + config=config, + d_out=d_out, + grouper=train_grouper, + loss=loss, + metric=metric, + n_train_steps=n_train_steps) + elif config.algorithm == 'NoisyStudent': + if config.soft_pseudolabels: + unlabeled_loss = initialize_loss("cross_entropy_logits", config) + else: + unlabeled_loss = loss + algorithm = NoisyStudent( + config=config, + d_out=d_out, + grouper=train_grouper, + loss=loss, + unlabeled_loss=unlabeled_loss, + metric=metric, + n_train_steps=n_train_steps) else: raise ValueError(f"Algorithm {config.algorithm} not recognized") return algorithm + +def infer_d_out(train_dataset, config): + # Configure the final layer of the networks used + # The code below are defaults. Edit this if you need special config for your model. + if train_dataset.is_classification: + if train_dataset.y_size == 1: + # For single-task classification, we have one output per class + d_out = train_dataset.n_classes + elif train_dataset.y_size is None: + d_out = train_dataset.n_classes + elif (train_dataset.y_size > 1) and (train_dataset.n_classes == 2): + # For multi-task binary classification (each output is the logit for each binary class) + d_out = train_dataset.y_size + else: + raise RuntimeError('d_out not defined.') + elif train_dataset.is_detection: + # For detection, d_out is the number of classes + d_out = train_dataset.n_classes + if config.algorithm in ['deepCORAL', 'IRM']: + raise ValueError(f'{config.algorithm} is not currently supported for detection datasets.') + else: + # For regression, we have one output per target dimension + d_out = train_dataset.y_size + return d_out diff --git a/examples/algorithms/noisy_student.py b/examples/algorithms/noisy_student.py new file mode 100644 index 00000000..d53e3b13 --- /dev/null +++ b/examples/algorithms/noisy_student.py @@ -0,0 +1,158 @@ +import torch +import torch.nn as nn + +from configs.supported import process_pseudolabels_functions +from models.initializer import initialize_model +from algorithms.single_model_algorithm import SingleModelAlgorithm +from utils import move_to, collate_list, concat_input + + +class DropoutModel(nn.Module): + def __init__(self, featurizer, classifier, dropout_rate): + super().__init__() + self.featurizer = featurizer + self.dropout = nn.Dropout(p=dropout_rate) + self.classifier = classifier + self.needs_y = featurizer.needs_y + + def forward(self, x): + features = self.featurizer(x) + features_sparse = self.dropout(features) + return self.classifier(features_sparse) + + +class NoisyStudent(SingleModelAlgorithm): + """ + Noisy Student. + This algorithm was originally proposed as a semi-supervised learning algorithm. + + One run of this codebase gives us one iteration (load a teacher, train student). To run another iteration, + re-run the previous command, pointing config.teacher_model_path to the trained student weights. 
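+    In this way, the student trained in one iteration becomes the teacher for the next.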
+
+    To warm start the student model, point config.pretrained_model_path to config.teacher_model_path.
+
+    Based on the original paper, loss is of the form
+        \ell_s + \ell_u
+    where
+        \ell_s = cross-entropy with true labels; student predicts with noise
+        \ell_u = cross-entropy with pseudolabel generated without noise; student predicts with noise
+    The student is noised using:
+        - Input images are augmented using RandAugment
+        - Single dropout layer before final classifier (fc) layer
+    We do not use stochastic depth.
+
+    Pseudolabels are generated in run_expt.py on unlabeled images that have only been randomly cropped and flipped ("weak" transform).
+    By default, we use hard pseudolabels; pass the --soft_pseudolabels flag to use soft pseudolabels instead.
+
+    This code only supports a teacher that is the same class as the student (e.g. both densenet121s).
+
+    Original paper:
+        @inproceedings{xie2020self,
+            title={Self-training with noisy student improves imagenet classification},
+            author={Xie, Qizhe and Luong, Minh-Thang and Hovy, Eduard and Le, Quoc V},
+            booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+            pages={10687--10698},
+            year={2020}
+        }
+    """
+
+    def __init__(
+        self, config, d_out, grouper, loss, unlabeled_loss, metric, n_train_steps
+    ):
+        # initialize student model with dropout before last layer
+        if config.noisystudent_add_dropout:
+            featurizer, classifier = initialize_model(
+                config, d_out=d_out, is_featurizer=True
+            )
+            student_model = DropoutModel(
+                featurizer, classifier, config.noisystudent_dropout_rate
+            ).to(config.device)
+        else:
+            student_model = initialize_model(config, d_out=d_out, is_featurizer=False)
+        self.process_pseudolabels_function = process_pseudolabels_functions[
+            config.process_pseudolabels_function
+        ]
+
+        # initialize module
+        super().__init__(
+            config=config,
+            model=student_model,
+            grouper=grouper,
+            loss=loss,
+            metric=metric,
+            n_train_steps=n_train_steps,
+        )
+        self.unlabeled_loss = unlabeled_loss
+        # additional logging
+        self.logged_fields.append("classification_loss")
+        self.logged_fields.append("consistency_loss")
+
+    def process_batch(self, batch, unlabeled_batch=None):
+        """
+        Overrides single_model_algorithm.process_batch().
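+        Labeled and pseudolabeled examples are concatenated and passed through the
+        student in a single forward pass.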
+ Args: + - batch (x, y, m): a batch of data yielded by data loaders + - unlabeled_batch: examples (x, y_pseudo, m) where y_pseudo is an already-computed teacher pseudolabel + Output: + - results (dictionary): information about the batch + - y_true (Tensor): ground truth labels for batch + - g (Tensor): groups for batch + - metadata (Tensor): metadata for batch + - y_pred (Tensor): model output for batch + - unlabeled_g (Tensor): groups for unlabeled batch + - unlabeled_metadata (Tensor): metadata for unlabeled batch + - unlabeled_y_pseudo (Tensor): pseudolabels for unlabeled batch (from loader) + - unlabeled_y_pred (Tensor): model output on unlabeled batch + """ + # Labeled examples + x, y_true, metadata = batch + n_lab = len(metadata) + x = move_to(x, self.device) + y_true = move_to(y_true, self.device) + g = move_to(self.grouper.metadata_to_group(metadata), self.device) + # package the results + results = {"g": g, "y_true": y_true, "metadata": metadata} + + # Unlabeled examples with pseudolabels + if unlabeled_batch is not None: + x_unlab, y_pseudo, metadata_unlab = unlabeled_batch + x_unlab = move_to(x_unlab, self.device) + g_unlab = move_to(self.grouper.metadata_to_group(metadata_unlab), self.device) + y_pseudo = move_to(y_pseudo, self.device) + results["unlabeled_metadata"] = metadata_unlab + results["unlabeled_y_pseudo"] = y_pseudo + results["unlabeled_g"] = g_unlab + + x_cat = concat_input(x, x_unlab) + y_cat = collate_list([y_true, y_pseudo]) if self.model.needs_y else None + outputs = self.get_model_output(x_cat, y_cat) + results["y_pred"] = outputs[:n_lab] + results["unlabeled_y_pred"] = outputs[n_lab:] + else: + results["y_pred"] = self.get_model_output(x, y_true) + + return results + + def objective(self, results): + # Labeled loss + classification_loss = self.loss.compute( + results["y_pred"], results["y_true"], return_dict=False + ) + + # Pseudolabel loss + if "unlabeled_y_pseudo" in results: + consistency_loss = self.unlabeled_loss.compute( + results["unlabeled_y_pred"], + results["unlabeled_y_pseudo"], + return_dict=False, + ) + else: + consistency_loss = 0 + + # Add to results for additional logging + self.save_metric_for_logging( + results, "classification_loss", classification_loss + ) + self.save_metric_for_logging(results, "consistency_loss", consistency_loss) + + return classification_loss + consistency_loss diff --git a/examples/algorithms/pseudolabel.py b/examples/algorithms/pseudolabel.py new file mode 100644 index 00000000..785cbec7 --- /dev/null +++ b/examples/algorithms/pseudolabel.py @@ -0,0 +1,163 @@ +import torch +import torch.nn.functional as F +from models.initializer import initialize_model +from algorithms.ERM import ERM +from algorithms.single_model_algorithm import SingleModelAlgorithm +from scheduler import LinearScheduleWithWarmupAndThreshold +from wilds.common.utils import split_into_groups, numel +from configs.supported import process_pseudolabels_functions +import copy +from utils import load, move_to, detach_and_clone, collate_list, concat_input + + +class PseudoLabel(SingleModelAlgorithm): + """ + PseudoLabel. + This is a vanilla pseudolabeling algorithm which updates the model per batch and incorporates a confidence threshold. 
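+    Following the schedule in the original paper, the weight on the pseudolabel loss
+    ramps up linearly to self_training_lambda over the first pseudolabel_T2 fraction
+    of training steps (see lambda_scheduler below).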
+ + Original paper: + @inproceedings{lee2013pseudo, + title={Pseudo-label: The simple and efficient semi-supervised learning method for deep neural networks}, + author={Lee, Dong-Hyun and others}, + booktitle={Workshop on challenges in representation learning, ICML}, + volume={3}, + number={2}, + pages={896}, + year={2013} + } + """ + def __init__(self, config, d_out, grouper, loss, metric, n_train_steps): + model = initialize_model(config, d_out=d_out) + model = model.to(config.device) + # initialize module + super().__init__( + config=config, + model=model, + grouper=grouper, + loss=loss, + metric=metric, + n_train_steps=n_train_steps, + ) + # algorithm hyperparameters + self.lambda_scheduler = LinearScheduleWithWarmupAndThreshold( + max_value=config.self_training_lambda, + step_every_batch=True, # step per batch + last_warmup_step=0, + threshold_step=config.pseudolabel_T2 * n_train_steps + ) + self.schedulers.append(self.lambda_scheduler) + self.scheduler_metric_names.append(None) + self.confidence_threshold = config.self_training_threshold + if config.process_pseudolabels_function is not None: + self.process_pseudolabels_function = process_pseudolabels_functions[config.process_pseudolabels_function] + # Additional logging + self.logged_fields.append("pseudolabels_kept_frac") + self.logged_fields.append("classification_loss") + self.logged_fields.append("consistency_loss") + + def process_batch(self, batch, unlabeled_batch=None): + """ + Overrides single_model_algorithm.process_batch(). + Args: + - batch (tuple of Tensors): a batch of data yielded by data loaders + - unlabeled_batch (tuple of Tensors or None): a batch of data yielded by unlabeled data loader + Output: + - results (dictionary): information about the batch + - y_true (Tensor): ground truth labels for batch + - g (Tensor): groups for batch + - metadata (Tensor): metadata for batch + - y_pred (Tensor): model output for batch + - unlabeled_g (Tensor): groups for unlabeled batch + - unlabeled_metadata (Tensor): metadata for unlabeled batch + - unlabeled_y_pseudo (Tensor): pseudolabels on the unlabeled batch, already thresholded + - unlabeled_y_pred (Tensor): model output on the unlabeled batch, already thresholded + """ + # Labeled examples + x, y_true, metadata = batch + n_lab = len(metadata) + x = move_to(x, self.device) + y_true = move_to(y_true, self.device) + g = move_to(self.grouper.metadata_to_group(metadata), self.device) + + # package the results + results = { + 'g': g, + 'y_true': y_true, + 'metadata': metadata + } + + if unlabeled_batch is not None: + x_unlab, metadata_unlab = unlabeled_batch + x_unlab = move_to(x_unlab, self.device) + g_unlab = move_to(self.grouper.metadata_to_group(metadata_unlab), self.device) + results['unlabeled_metadata'] = metadata_unlab + results['unlabeled_g'] = g_unlab + + # Special case for models where we need to pass in y: + # we handle these in two separate forward passes + # and turn off training to avoid errors when y is None + # Note: we have to specifically turn training in the model off + # instead of using self.train, which would reset the log + if self.model.needs_y: + self.model.train(mode=False) + unlabeled_output = self.get_model_output(x_unlab, None) + + _, unlabeled_y_pseudo, pseudolabels_kept_frac, mask = self.process_pseudolabels_function( + unlabeled_output, + self.confidence_threshold + ) + x_unlab = x_unlab[mask] + + self.model.train(mode=True) + outputs = self.get_model_output( + torch.cat((x, x_unlab), dim=0), + collate_list([y_true, unlabeled_y_pseudo]), + ) + 
unlabeled_y_pred = outputs[n_lab:] + else: + x_cat = concat_input(x, x_unlab) + outputs = self.get_model_output(x_cat, None) + unlabeled_output = outputs[n_lab:] + unlabeled_y_pred, unlabeled_y_pseudo, pseudolabels_kept_frac, _ = self.process_pseudolabels_function( + unlabeled_output, + self.confidence_threshold + ) + + results['y_pred'] = outputs[:n_lab] + results['unlabeled_y_pred'] = unlabeled_y_pred + results['unlabeled_y_pseudo'] = detach_and_clone(unlabeled_y_pseudo) + else: + results['y_pred'] = self.get_model_output(x, y_true) + pseudolabels_kept_frac = 0 + + self.save_metric_for_logging( + results, "pseudolabels_kept_frac", pseudolabels_kept_frac + ) + return results + + def objective(self, results): + # Labeled loss + classification_loss = self.loss.compute( + results['y_pred'], + results['y_true'], + return_dict=False) + # Pseudolabeled loss + if 'unlabeled_y_pseudo' in results: + loss_output = self.loss.compute( + results['unlabeled_y_pred'], + results['unlabeled_y_pseudo'], + return_dict=False, + ) + consistency_loss = loss_output * results['pseudolabels_kept_frac'] + else: + consistency_loss = 0 + + # Add to results for additional logging + self.save_metric_for_logging( + results, "classification_loss", classification_loss + ) + self.save_metric_for_logging( + results, "consistency_loss", consistency_loss + ) + + return classification_loss + self.lambda_scheduler.value * consistency_loss diff --git a/examples/algorithms/single_model_algorithm.py b/examples/algorithms/single_model_algorithm.py index f01c21bb..552ff489 100644 --- a/examples/algorithms/single_model_algorithm.py +++ b/examples/algorithms/single_model_algorithm.py @@ -1,7 +1,9 @@ import torch + from algorithms.group_algorithm import GroupAlgorithm from scheduler import initialize_scheduler from optimizer import initialize_optimizer +from torch.nn import DataParallel from torch.nn.utils import clip_grad_norm_ from utils import move_to @@ -18,10 +20,20 @@ def __init__(self, config, model, grouper, loss, metric, n_train_steps): logged_metrics.append(self.metric) else: self.metric = None + # initialize models, optimizers, and schedulers - self.optimizer = initialize_optimizer(config, model) + if not hasattr(self, 'optimizer') or self.optimizer is None: + self.optimizer = initialize_optimizer(config, model) self.max_grad_norm = config.max_grad_norm scheduler = initialize_scheduler(config, self.optimizer, n_train_steps) + + if config.use_data_parallel: + model = DataParallel(model) + model.to(config.device) + + self.batch_idx = 0 + self.gradient_accumulation_steps = config.gradient_accumulation_steps + # initialize the module super().__init__( device=config.device, @@ -34,39 +46,51 @@ def __init__(self, config, model, grouper, loss, metric, n_train_steps): ) self.model = model - def process_batch(self, batch): + def get_model_output(self, x, y_true): + if self.model.needs_y: + if self.training: + outputs = self.model(x, y_true) + else: + outputs = self.model(x, None) + else: + outputs = self.model(x) + return outputs + + def process_batch(self, batch, unlabeled_batch=None): """ A helper function for update() and evaluate() that processes the batch Args: - batch (tuple of Tensors): a batch of data yielded by data loaders + - unlabeled_batch (tuple of Tensors or None): a batch of data yielded by unlabeled data loader Output: - results (dictionary): information about the batch - - y_true (Tensor) - - g (Tensor) - - metadata (Tensor) - - output (Tensor) - - y_true + - y_true (Tensor): ground truth labels for batch + - g 
(Tensor): groups for batch + - metadata (Tensor): metadata for batch + - y_pred (Tensor): model output for batch + - unlabeled_g (Tensor): groups for unlabeled batch + - unlabeled_metadata (Tensor): metadata for unlabeled batch + - unlabeled_features (Tensor): features for unlabeled batch """ x, y_true, metadata = batch x = move_to(x, self.device) y_true = move_to(y_true, self.device) g = move_to(self.grouper.metadata_to_group(metadata), self.device) + outputs = self.get_model_output(x, y_true) - if self.model.needs_y: - if self.training: - outputs = self.model(x, y_true) - else: - outputs = self.model(x, None) - else: - outputs = self.model(x) - results = { 'g': g, 'y_true': y_true, 'y_pred': outputs, 'metadata': metadata, - } + } + if unlabeled_batch is not None: + x, metadata = unlabeled_batch + x = x.to(self.device) + results['unlabeled_metadata'] = metadata + results['unlabeled_features'] = self.featurizer(x) + results['unlabeled_g'] = self.grouper.metadata_to_group(metadata).to(self.device) return results def objective(self, results): @@ -92,11 +116,14 @@ def evaluate(self, batch): self.update_log(results) return self.sanitize_dict(results) - def update(self, batch): + def update(self, batch, unlabeled_batch=None, is_epoch_end=False): """ Process the batch, update the log, and update the model Args: - batch (tuple of Tensors): a batch of data yielded by data loaders + - unlabeled_batch (tuple of Tensors or None): a batch of data yielded by unlabeled data loader + - is_epoch_end: whether this batch is the last batch of the epoch. if so, force optimizer to step, + regardless of whether this batch idx divides self.gradient_accumulation_steps evenly Output: - results (dictionary): information about the batch, such as: - g (Tensor) @@ -107,14 +134,27 @@ def update(self, batch): - objective (float) """ assert self.is_training - # process batch - results = self.process_batch(batch) - self._update(results) - # log results + + # process this batch + results = self.process_batch(batch, unlabeled_batch) + + # update running statistics and update model if we've reached end of effective batch + self._update( + results, + should_step=(((self.batch_idx + 1) % self.gradient_accumulation_steps == 0) or (is_epoch_end)) + ) self.update_log(results) + + # iterate batch index + if is_epoch_end: + self.batch_idx = 0 + else: + self.batch_idx += 1 + + # return only this batch's results return self.sanitize_dict(results) - def _update(self, results): + def _update(self, results, should_step=False): """ Computes the objective and updates the model. Also updates the results dictionary yielded by process_batch(). 
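When should_step is False, this call only accumulates gradients; gradient clipping, the optimizer step, the scheduler step, and zero_grad are deferred until a call with should_step=True.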
@@ -123,13 +163,26 @@ def _update(self, results): # compute objective objective = self.objective(results) results['objective'] = objective.item() - # update - self.model.zero_grad() objective.backward() - if self.max_grad_norm: - clip_grad_norm_(self.model.parameters(), self.max_grad_norm) - self.optimizer.step() - self.step_schedulers( - is_epoch=False, - metrics=results, - log_access=False) + + # update model and logs based on effective batch + if should_step: + if self.max_grad_norm: + clip_grad_norm_(self.model.parameters(), self.max_grad_norm) + self.optimizer.step() + self.step_schedulers( + is_epoch=False, + metrics=self.log_dict, + log_access=False) + self.model.zero_grad() + + def save_metric_for_logging(self, results, metric, value): + if isinstance(value, torch.Tensor): + if value.numel() == 1: + results[metric] = value.item() + else: + raise ValueError( + f"Metric value can only be a number or single-element tensor. value={value}" + ) + else: + results[metric] = value diff --git a/examples/configs/algorithm.py b/examples/configs/algorithm.py index 61787f07..72fdd33d 100644 --- a/examples/configs/algorithm.py +++ b/examples/configs/algorithm.py @@ -3,6 +3,7 @@ 'train_loader': 'standard', 'uniform_over_groups': False, 'eval_loader': 'standard', + 'randaugment_n': 2, # When running ERM + data augmentation }, 'groupDRO': { 'train_loader': 'standard', @@ -17,6 +18,8 @@ 'distinct_groups': True, 'eval_loader': 'standard', 'coral_penalty_weight': 1., + 'randaugment_n': 2, + 'additional_train_transform': 'randaugment', # Apply strong augmentation to labeled & unlabeled examples }, 'IRM': { 'train_loader': 'group', @@ -25,5 +28,55 @@ 'eval_loader': 'standard', 'irm_lambda': 100., 'irm_penalty_anneal_iters': 500, + }, + 'DANN': { + 'train_loader': 'group', + 'uniform_over_groups': True, + 'distinct_groups': True, + 'eval_loader': 'standard', + 'randaugment_n': 2, + 'additional_train_transform': 'randaugment', # Apply strong augmentation to labeled & unlabeled examples + }, + 'AFN': { + 'train_loader': 'standard', + 'uniform_over_groups': False, + 'eval_loader': 'standard', + 'use_hafn': False, + 'afn_penalty_weight': 0.01, + 'safn_delta_r': 1.0, + 'hafn_r': 1.0, + 'additional_train_transform': 'randaugment', # Apply strong augmentation to labeled & unlabeled examples + 'randaugment_n': 2, + }, + 'FixMatch': { + 'train_loader': 'standard', + 'uniform_over_groups': False, + 'eval_loader': 'standard', + 'self_training_lambda': 1, + 'self_training_threshold': 0.7, + 'scheduler': 'FixMatchLR', + 'randaugment_n': 2, + 'additional_train_transform': 'randaugment', # Apply strong augmentation to labeled examples + }, + 'PseudoLabel': { + 'train_loader': 'standard', + 'uniform_over_groups': False, + 'eval_loader': 'standard', + 'self_training_lambda': 1, + 'self_training_threshold': 0.7, + 'pseudolabel_T2': 0.4, + 'scheduler': 'FixMatchLR', + 'randaugment_n': 2, + 'additional_train_transform': 'randaugment', # Apply strong augmentation to labeled & unlabeled examples + }, + 'NoisyStudent': { + 'train_loader': 'standard', + 'uniform_over_groups': False, + 'eval_loader': 'standard', + 'noisystudent_add_dropout': True, + 'noisystudent_dropout_rate': 0.5, + 'scheduler': 'FixMatchLR', + 'randaugment_n': 2, + 'additional_train_transform': 'randaugment', # Apply strong augmentation to labeled & unlabeled examples } } diff --git a/examples/configs/data_loader.py b/examples/configs/data_loader.py index 8ddaa267..eab058f5 100644 --- a/examples/configs/data_loader.py +++ b/examples/configs/data_loader.py @@ 
-3,5 +3,9 @@ 'num_workers': 4, 'pin_memory': True, }, + 'unlabeled_loader_kwargs': { + 'num_workers': 8, + 'pin_memory': True, + }, 'n_groups_per_batch': 4, } diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index e8bbe0d2..3f34070a 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -7,17 +7,28 @@ 'loss_function': 'cross_entropy', 'algo_log_metric': 'accuracy', 'batch_size': 8, + 'unlabeled_batch_size': 8, 'lr': 1e-5, 'weight_decay': 0.01, 'n_epochs': 3, 'n_groups_per_batch': 2, + 'unlabeled_n_groups_per_batch': 2, 'irm_lambda': 1.0, 'coral_penalty_weight': 1.0, + 'dann_penalty_weight': 1.0, + 'dann_featurizer_lr': 1e-6, + 'dann_classifier_lr': 1e-5, + 'dann_discriminator_lr': 1e-5, 'loader_kwargs': { 'num_workers': 1, 'pin_memory': True, }, + 'unlabeled_loader_kwargs': { + 'num_workers': 1, + 'pin_memory': True, + }, 'process_outputs_function': 'multiclass_logits_to_pred', + 'process_pseudolabels_function': 'pseudolabel_multiclass_logits', }, 'bdd100k': { 'split_scheme': 'official', @@ -50,14 +61,21 @@ 'optimizer_kwargs': {'momentum': 0.9}, 'scheduler': None, 'batch_size': 32, + 'unlabeled_batch_size': 32, 'lr': 0.001, 'weight_decay': 0.01, - 'n_epochs': 5, + 'n_epochs': 10, 'n_groups_per_batch': 2, + 'unlabeled_n_groups_per_batch': 2, 'irm_lambda': 1.0, 'coral_penalty_weight': 0.1, + 'dann_penalty_weight': 0.1, + 'dann_featurizer_lr': 0.0001, + 'dann_classifier_lr': 0.001, + 'dann_discriminator_lr': 0.001, 'algo_log_metric': 'accuracy', 'process_outputs_function': 'multiclass_logits_to_pred', + 'process_pseudolabels_function': 'pseudolabel_multiclass_logits', }, 'celebA': { 'split_scheme': 'official', @@ -87,26 +105,80 @@ 'val_metric': 'acc_wg', 'val_metric_decreasing': False, 'batch_size': 16, + 'unlabeled_batch_size': 16, 'lr': 1e-5, 'weight_decay': 0.01, 'n_epochs': 5, + 'n_groups_per_batch': 1, + 'unlabeled_n_groups_per_batch': 1, 'algo_log_metric': 'accuracy', 'max_token_length': 300, 'irm_lambda': 1.0, 'coral_penalty_weight': 10.0, + 'dann_penalty_weight': 1.0, + 'dann_featurizer_lr': 1e-6, + 'dann_classifier_lr': 1e-5, + 'dann_discriminator_lr': 1e-5, 'loader_kwargs': { 'num_workers': 1, 'pin_memory': True, }, + 'unlabeled_loader_kwargs': { + 'num_workers': 1, + 'pin_memory': True, + }, 'process_outputs_function': 'multiclass_logits_to_pred', + 'process_pseudolabels_function': 'pseudolabel_multiclass_logits', + }, + "domainnet": { + "split_scheme": "official", + "dataset_kwargs": { + "source_domain": "real", + "target_domain": "sketch", + "use_sentry": False, + }, + "model": "resnet50", + "model_kwargs": {"pretrained": True}, + "transform": "image_resize", + "resize_scale": 256.0 / 224.0, + "target_resolution": (224, 224), + "loss_function": "cross_entropy", + "groupby_fields": [ + "category", + ], + "val_metric": "acc_avg", + "val_metric_decreasing": False, + "batch_size": 96, + "unlabeled_batch_size": 224, + "optimizer": "SGD", + "optimizer_kwargs": { + "momentum": 0.9, + }, + "lr": 0.0007035737028722148, + "weight_decay": 1e-4, + "n_epochs": 25, + "n_groups_per_batch": 4, + "unlabeled_n_groups_per_batch": 4, + "irm_lambda": 1.0, + "coral_penalty_weight": 1.0, + "dann_penalty_weight": 1.0, + "dann_featurizer_lr": 0.001, + "dann_classifier_lr": 0.01, + "dann_discriminator_lr": 0.01, + "algo_log_metric": "accuracy", + "process_outputs_function": "multiclass_logits_to_pred", + "process_pseudolabels_function": "pseudolabel_multiclass_logits", + "loader_kwargs": { + "num_workers": 2, + "pin_memory": True, + }, }, 'encode': { 
'split_scheme': 'official', 'model': 'unet-seq', 'model_kwargs': {'n_channels_in': 5}, 'loader_kwargs': {'num_workers': 1}, # pybigwig seems to have trouble with multiprocessing - 'train_transform': None, - 'eval_transform': None, + 'transform': None, 'loss_function': 'multitask_bce', 'groupby_fields': ['celltype'], 'val_metric': 'avgprec-macro_all', @@ -125,7 +197,7 @@ }, 'fmow': { 'split_scheme': 'official', - 'dataset_kwargs': { + 'dataset_kwargs': { 'seed': 111, 'use_ood_val': True }, @@ -139,15 +211,22 @@ 'optimizer': 'Adam', 'scheduler': 'StepLR', 'scheduler_kwargs': {'gamma': 0.96}, - 'batch_size': 64, + 'batch_size': 32, + 'unlabeled_batch_size': 32, 'lr': 0.0001, 'weight_decay': 0.0, - 'n_epochs': 50, + 'n_epochs': 60, 'n_groups_per_batch': 8, + 'unlabeled_n_groups_per_batch': 8, 'irm_lambda': 1.0, 'coral_penalty_weight': 0.1, + 'dann_penalty_weight': 1.0, + 'dann_featurizer_lr': 0.00001, + 'dann_classifier_lr': 0.0001, + 'dann_discriminator_lr': 0.0001, 'algo_log_metric': 'accuracy', 'process_outputs_function': 'multiclass_logits_to_pred', + 'process_pseudolabels_function': 'pseudolabel_multiclass_logits', }, 'iwildcam': { 'loss_function': 'cross_entropy', @@ -161,16 +240,23 @@ 'lr': 3e-5, 'weight_decay': 0.0, 'batch_size': 16, + 'unlabeled_batch_size': 16, 'n_epochs': 12, 'optimizer': 'Adam', 'split_scheme': 'official', 'scheduler': None, 'groupby_fields': ['location',], 'n_groups_per_batch': 2, + 'unlabeled_n_groups_per_batch': 2, 'irm_lambda': 1., 'coral_penalty_weight': 10., + 'dann_penalty_weight': 0.1, + 'dann_featurizer_lr': 3e-6, + 'dann_classifier_lr': 3e-5, + 'dann_discriminator_lr': 3e-5, 'no_group_logging': True, - 'process_outputs_function': 'multiclass_logits_to_pred' + 'process_outputs_function': 'multiclass_logits_to_pred', + 'process_pseudolabels_function': 'pseudolabel_multiclass_logits', }, 'ogb-molpcba': { 'split_scheme': 'official', @@ -182,15 +268,27 @@ 'val_metric_decreasing': False, 'optimizer': 'Adam', 'batch_size': 32, - 'lr': 1e-03, + 'unlabeled_batch_size': 32, + 'lr': 1e-3, 'weight_decay': 0., 'n_epochs': 100, 'n_groups_per_batch': 4, + 'unlabeled_n_groups_per_batch': 4, 'irm_lambda': 1., 'coral_penalty_weight': 0.1, + 'dann_penalty_weight': 0.1, + 'dann_featurizer_lr': 1e-3, + 'dann_classifier_lr': 1e-2, + 'dann_discriminator_lr': 1e-2, + 'noisystudent_add_dropout': False, 'no_group_logging': True, - 'process_outputs_function': None, 'algo_log_metric': 'multitask_binary_accuracy', + 'process_outputs_function': None, + 'process_pseudolabels_function': 'pseudolabel_binary_logits', + 'loader_kwargs': { + 'num_workers': 1, + 'pin_memory': True, + }, }, 'py150': { 'split_scheme': 'official', @@ -229,15 +327,22 @@ 'algo_log_metric': 'mse', 'optimizer': 'Adam', 'scheduler': 'StepLR', - 'scheduler_kwargs': {'gamma':0.96}, + 'scheduler_kwargs': {'gamma': 0.96}, 'batch_size': 64, + 'unlabeled_batch_size': 64, 'lr': 0.001, 'weight_decay': 0.0, 'n_epochs': 200, 'n_groups_per_batch': 8, + 'unlabeled_n_groups_per_batch': 4, 'irm_lambda': 1.0, 'coral_penalty_weight': 0.1, + 'dann_penalty_weight': 0.1, + 'dann_featurizer_lr': 0.0001, + 'dann_classifier_lr': 0.001, + 'dann_discriminator_lr': 0.001, 'process_outputs_function': None, + 'process_pseudolabels_function': 'pseudolabel_identity', }, 'waterbirds': { 'split_scheme': 'official', @@ -333,14 +438,18 @@ 'optimizer_kwargs': {}, 'scheduler': None, 'batch_size': 4, + 'unlabeled_batch_size': 4, 'lr': 1e-5, 'weight_decay': 1e-3, - 'n_epochs': 10, + 'n_epochs': 12, + 'noisystudent_add_dropout': False, + 
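# Detection keeps pseudolabeled boxes at a lower confidence threshold than the 0.7 default used for classification:
+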
'self_training_threshold': 0.5, 'loader_kwargs': { 'num_workers': 1, 'pin_memory': True, }, 'process_outputs_function': None, + 'process_pseudolabels_function': 'pseudolabel_detection_discard_empty', } }
diff --git a/examples/configs/model.py b/examples/configs/model.py index dee0ac01..3e4e8d48 100644 --- a/examples/configs/model.py +++ b/examples/configs/model.py @@ -40,7 +40,13 @@ }, 'resnet50': { 'model_kwargs': { - 'pretrained':True, + 'pretrained': True, + }, + 'target_resolution': (224, 224), + }, + 'resnet101': { + 'model_kwargs': { + 'pretrained': True, }, 'target_resolution': (224, 224), },
diff --git a/examples/configs/scheduler.py b/examples/configs/scheduler.py index 2da8d712..9834c194 100644 --- a/examples/configs/scheduler.py +++ b/examples/configs/scheduler.py @@ -17,6 +17,9 @@ 'step_size': 1, } }, + 'FixMatchLR': { + 'scheduler_kwargs': {}, + }, 'MultiStepLR': { 'scheduler_kwargs':{ 'gamma': 0.1,
diff --git a/examples/configs/supported.py b/examples/configs/supported.py index d42a6cc7..7011c793 100644 --- a/examples/configs/supported.py +++ b/examples/configs/supported.py @@ -1,5 +1,16 @@ -# metrics -from wilds.common.metrics.all_metrics import Accuracy, MultiTaskAccuracy, MSE, multiclass_logits_to_pred, binary_logits_to_pred, MultiTaskAveragePrecision +from wilds.common.metrics.all_metrics import ( + Accuracy, + MultiTaskAccuracy, + MSE, + multiclass_logits_to_pred, + binary_logits_to_pred, + pseudolabel_binary_logits, + pseudolabel_multiclass_logits, + pseudolabel_identity, + pseudolabel_detection, + pseudolabel_detection_discard_empty, + MultiTaskAveragePrecision +) algo_log_metrics = { 'accuracy': Accuracy(prediction_fn=multiclass_logits_to_pred), @@ -16,23 +27,33 @@ None: None, } +process_pseudolabels_functions = { + 'pseudolabel_binary_logits': pseudolabel_binary_logits, + 'pseudolabel_multiclass_logits': pseudolabel_multiclass_logits, + 'pseudolabel_identity': pseudolabel_identity, + 'pseudolabel_detection': pseudolabel_detection, + 'pseudolabel_detection_discard_empty': pseudolabel_detection_discard_empty, +} + +# see initialize_*() functions for correspondence +# See algorithms/initializer.py +algorithms = ['ERM', 'groupDRO', 'deepCORAL', 'IRM', 'DANN', 'AFN', 'FixMatch', 'PseudoLabel', 'NoisyStudent'] + # See transforms.py -transforms = ['bert', 'image_base', 'image_resize_and_center_crop', 'poverty', 'rxrx1'] +transforms = ['bert', 'image_base', 'image_resize', 'image_resize_and_center_crop', 'poverty', 'rxrx1'] +additional_transforms = ['randaugment', 'weak'] # See models/initializer.py -models = ['resnet18_ms', 'resnet50', 'resnet34', 'resnet18', 'wideresnet50', +models = ['resnet18_ms', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'wideresnet50', 'densenet121', 'bert-base-uncased', 'distilbert-base-uncased', 'gin-virtual', 'logistic_regression', 'code-gpt-py', 'fasterrcnn', 'unet-seq'] -# See algorithms/initializer.py -algorithms = ['ERM', 'groupDRO', 'deepCORAL', 'IRM'] - # See optimizer.py optimizers = ['SGD', 'Adam', 'AdamW'] # See scheduler.py -schedulers = ['linear_schedule_with_warmup', 'cosine_schedule_with_warmup', 'ReduceLROnPlateau', 'StepLR', 'MultiStepLR'] +schedulers = ['linear_schedule_with_warmup', 'cosine_schedule_with_warmup', 'ReduceLROnPlateau', 'StepLR', 'FixMatchLR', 'MultiStepLR'] # See losses.py -losses = ['cross_entropy', 'lm_cross_entropy', 'MSE', 'multitask_bce', 'fasterrcnn_criterion'] +losses = ['cross_entropy', 'lm_cross_entropy', 'MSE', 'multitask_bce', 'fasterrcnn_criterion', 'cross_entropy_logits']
diff --git
a/examples/configs/utils.py b/examples/configs/utils.py index 020bb43f..724e161f 100644 --- a/examples/configs/utils.py +++ b/examples/configs/utils.py @@ -13,6 +13,66 @@ def populate_defaults(config): assert config.dataset is not None, 'dataset must be specified' assert config.algorithm is not None, 'algorithm must be specified' + # Run oracle using ERM with unlabeled split + if config.use_unlabeled_y: + assert config.algorithm == 'ERM', 'Only ERM is currently supported for training on the true labels of unlabeled data.' + assert config.unlabeled_split is not None, 'Specify an unlabeled split' + assert config.dataset in ['amazon', 'civilcomments', 'fmow', 'iwildcam'], 'The unlabeled data in this dataset are truly unlabeled, and we do not have true labels for them.' + + # Validations + if config.groupby_fields == ['from_source_domain']: + if config.n_groups_per_batch is None: + config.n_groups_per_batch = 1 + elif config.n_groups_per_batch != 1: + raise ValueError( + f"from_source_domain was specified for groupby_fields, but n_groups_per_batch " + f"was {config.n_groups_per_batch}, when it should be 1." + ) + + if config.unlabeled_n_groups_per_batch is None: + config.unlabeled_n_groups_per_batch = 1 + elif config.unlabeled_n_groups_per_batch != 1: + raise ValueError( + f"from_source_domain was specified for groupby_fields, but unlabeled_n_groups_per_batch " + f"was {config.unlabeled_n_groups_per_batch}, when it should be 1." + ) + + if config.algorithm == 'DANN' and config.lr is not None: + raise ValueError( + "Cannot pass in a value for lr. For DANN, only dann_classifier_lr, dann_featurizer_lr " + "and dann_discriminator_lr are valid learning rate parameters." + ) + + if config.additional_train_transform is not None: + if config.algorithm == "NoisyStudent": + raise ValueError( + "Cannot pass in a value for additional_train_transform, NoisyStudent " + "already has a default transformation for the training data." + ) + + if config.load_featurizer_only: + if config.pretrained_model_path is None: + raise ValueError( + "load_featurizer_only cannot be set when there is no pretrained_model_path " + "specified." + ) + + if config.dataset == 'globalwheat': + if config.additional_train_transform is not None: + raise ValueError( + f"Augmentations not supported for detection dataset: {config.dataset}." + ) + config.additional_train_transform = '' + + if config.algorithm == "NoisyStudent": + if config.process_pseudolabels_function is None: + config.process_pseudolabels_function = 'pseudolabel_detection' + elif config.process_pseudolabels_function == 'pseudolabel_detection_discard_empty': + raise ValueError( + f"Filtering out empty images when generating pseudo-labels for {config.algorithm} " + f"is not supported for detection." + ) + # implied defaults from choice of dataset config = populate_config( config, @@ -74,6 +134,7 @@ def populate_defaults(config): return config + def populate_config(config, template: dict, force_compatibility=False): """Populates missing (key, val) pairs in config with (key, val) in template. 
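Entries that are unset (None) in config are treated as missing and filled in from the template.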
Example usage: populate config with defaults diff --git a/examples/data_augmentation/__init__.py b/examples/data_augmentation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/data_augmentation/randaugment.py b/examples/data_augmentation/randaugment.py new file mode 100644 index 00000000..fd3f127c --- /dev/null +++ b/examples/data_augmentation/randaugment.py @@ -0,0 +1,148 @@ +# Adapted from https://github.com/YBZh/Bridging_UDA_SSL + +import torch +from PIL import Image, ImageOps, ImageEnhance, ImageDraw + + +def AutoContrast(img, _): + return ImageOps.autocontrast(img) + + +def Brightness(img, v): + assert v >= 0.0 + return ImageEnhance.Brightness(img).enhance(v) + + +def Color(img, v): + assert v >= 0.0 + return ImageEnhance.Color(img).enhance(v) + + +def Contrast(img, v): + assert v >= 0.0 + return ImageEnhance.Contrast(img).enhance(v) + + +def Equalize(img, _): + return ImageOps.equalize(img) + + +def Invert(img, _): + return ImageOps.invert(img) + + +def Identity(img, v): + return img + + +def Posterize(img, v): # [4, 8] + v = int(v) + v = max(1, v) + return ImageOps.posterize(img, v) + + +def Rotate(img, v): # [-30, 30] + return img.rotate(v) + + +def Sharpness(img, v): # [0.1,1.9] + assert v >= 0.0 + return ImageEnhance.Sharpness(img).enhance(v) + + +def ShearX(img, v): # [-0.3, 0.3] + return img.transform(img.size, Image.AFFINE, (1, v, 0, 0, 1, 0)) + + +def ShearY(img, v): # [-0.3, 0.3] + return img.transform(img.size, Image.AFFINE, (1, 0, 0, v, 1, 0)) + + +def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] + v = v * img.size[0] + return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0)) + + +def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] + return img.transform(img.size, Image.AFFINE, (1, 0, v, 0, 1, 0)) + + +def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] + v = v * img.size[1] + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v)) + + +def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, v)) + + +def Solarize(img, v): # [0, 256] + assert 0 <= v <= 256 + return ImageOps.solarize(img, v) + + +def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] => change to [0, 0.5] + assert 0.0 <= v <= 0.5 + + v = v * img.size[0] + return CutoutAbs(img, v) + + +def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] + if v < 0: + return img + w, h = img.size + x_center = _sample_uniform(0, w) + y_center = _sample_uniform(0, h) + + x0 = int(max(0, x_center - v / 2.0)) + y0 = int(max(0, y_center - v / 2.0)) + x1 = min(w, x0 + v) + y1 = min(h, y0 + v) + + xy = (x0, y0, x1, y1) + color = (125, 123, 114) + img = img.copy() + ImageDraw.Draw(img).rectangle(xy, color) + return img + + +FIX_MATCH_AUGMENTATION_POOL = [ + (AutoContrast, 0, 1), + (Brightness, 0.05, 0.95), + (Color, 0.05, 0.95), + (Contrast, 0.05, 0.95), + (Equalize, 0, 1), + (Identity, 0, 1), + (Posterize, 4, 8), + (Rotate, -30, 30), + (Sharpness, 0.05, 0.95), + (ShearX, -0.3, 0.3), + (ShearY, -0.3, 0.3), + (Solarize, 0, 256), + (TranslateX, -0.3, 0.3), + (TranslateY, -0.3, 0.3), +] + + +def _sample_uniform(a, b): + return torch.empty(1).uniform_(a, b).item() + + +class RandAugment: + def __init__(self, n, augmentation_pool): + assert n >= 1, "RandAugment N has to be a value greater than or equal to 1." 
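+ # Each call samples n ops (with replacement) from the pool, applies each at a random strength within its [min_val, max_val] range, and finishes with a random Cutout.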
+ self.n = n + self.augmentation_pool = augmentation_pool + + def __call__(self, img): + ops = [ + self.augmentation_pool[torch.randint(len(self.augmentation_pool), (1,))] + for _ in range(self.n) + ] + for op, min_val, max_val in ops: + val = min_val + float(max_val - min_val) * _sample_uniform(0, 1) + img = op(img, val) + cutout_val = _sample_uniform(0, 1) * 0.5 + img = Cutout(img, cutout_val) + return img diff --git a/examples/losses.py b/examples/losses.py index a7c2b242..c42fb98c 100644 --- a/examples/losses.py +++ b/examples/losses.py @@ -1,23 +1,27 @@ import torch.nn as nn from wilds.common.metrics.loss import ElementwiseLoss, Loss, MultiTaskLoss from wilds.common.metrics.all_metrics import MSE +from utils import cross_entropy_with_logits_loss -def initialize_loss(config, d_out): - if config.loss_function == 'cross_entropy': - return ElementwiseLoss(loss_fn=nn.CrossEntropyLoss(reduction='none')) +def initialize_loss(loss, config): + if loss == 'cross_entropy': + return ElementwiseLoss(loss_fn=nn.CrossEntropyLoss(reduction='none', ignore_index=-100)) - elif config.loss_function == 'lm_cross_entropy': - return MultiTaskLoss(loss_fn=nn.CrossEntropyLoss(reduction='none')) + elif loss == 'lm_cross_entropy': + return MultiTaskLoss(loss_fn=nn.CrossEntropyLoss(reduction='none', ignore_index=-100)) - elif config.loss_function == 'mse': + elif loss == 'mse': return MSE(name='loss') - elif config.loss_function == 'multitask_bce': + elif loss == 'multitask_bce': return MultiTaskLoss(loss_fn=nn.BCEWithLogitsLoss(reduction='none')) - elif config.loss_function == 'fasterrcnn_criterion': + elif loss == 'fasterrcnn_criterion': from models.detection.fasterrcnn import FasterRCNNLoss return ElementwiseLoss(loss_fn=FasterRCNNLoss(config.device)) + elif loss == 'cross_entropy_logits': + return ElementwiseLoss(loss_fn=cross_entropy_with_logits_loss) + else: - raise ValueError(f'config.loss_function {config.loss_function} not recognized') + raise ValueError(f'loss {loss} not recognized') diff --git a/examples/models/detection/fasterrcnn.py b/examples/models/detection/fasterrcnn.py index 1eb4cd4a..92d2fc5a 100644 --- a/examples/models/detection/fasterrcnn.py +++ b/examples/models/detection/fasterrcnn.py @@ -420,7 +420,7 @@ def __init__(self, backbone, num_classes=None, super(FastWILDS, self).__init__(backbone, rpn, roi_heads, transform) # Set your own forward pass - def forward(self, images, targets=None): + def forward(self, images, targets=None): if self.training: if targets is None: raise ValueError("In training mode, targets should be passed") diff --git a/examples/models/domain_adversarial_network.py b/examples/models/domain_adversarial_network.py new file mode 100644 index 00000000..381c572a --- /dev/null +++ b/examples/models/domain_adversarial_network.py @@ -0,0 +1,115 @@ +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from torch.autograd import Function + + +class DomainDiscriminator(nn.Sequential): + """ + Adapted from https://github.com/thuml/Transfer-Learning-Library + + Domain discriminator model from + `"Domain-Adversarial Training of Neural Networks" `_ + In the original paper and implementation, we distinguish whether the input features come + from the source domain or the target domain. + + We extended this to work with multiple domains, which is controlled by the n_domains + argument. 
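+ When used inside DomainAdversarialNetwork (below), the discriminator's input features pass through a gradient reversal layer, so the same backward pass that trains the discriminator also pushes the featurizer toward domain-invariant features.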
+ + Args: + in_feature (int): dimension of the input feature + n_domains (int): number of domains to discriminate + hidden_size (int): dimension of the hidden features + batch_norm (bool): whether to use :class:`~torch.nn.BatchNorm1d`. + Use :class:`~torch.nn.Dropout` if ``batch_norm`` is False. Default: True. + Shape: + - Inputs: (minibatch, `in_feature`) + - Outputs: :math:`(minibatch, n_domains)` + """
+ + def __init__( + self, in_feature: int, n_domains, hidden_size: int = 1024, batch_norm=True + ): + if batch_norm: + super(DomainDiscriminator, self).__init__( + nn.Linear(in_feature, hidden_size), + nn.BatchNorm1d(hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, hidden_size), + nn.BatchNorm1d(hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, n_domains), + ) + else: + super(DomainDiscriminator, self).__init__( + nn.Linear(in_feature, hidden_size), + nn.ReLU(inplace=True), + nn.Dropout(0.5), + nn.Linear(hidden_size, hidden_size), + nn.ReLU(inplace=True), + nn.Dropout(0.5), + nn.Linear(hidden_size, n_domains), + )
+ + def get_parameters_with_lr(self, lr) -> List[Dict]: + return [{"params": self.parameters(), "lr": lr}] +
+class GradientReverseFunction(Function): + """ + Credit: https://github.com/thuml/Transfer-Learning-Library + """ + @staticmethod + def forward( + ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.0 + ) -> torch.Tensor: + ctx.coeff = coeff + output = input * 1.0 + return output + + @staticmethod + def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]: + return grad_output.neg() * ctx.coeff, None + +
+class GradientReverseLayer(nn.Module): + """ + Credit: https://github.com/thuml/Transfer-Learning-Library + """ + def __init__(self): + super(GradientReverseLayer, self).__init__() + + def forward(self, *input): + return GradientReverseFunction.apply(*input) + +
+class DomainAdversarialNetwork(nn.Module): + def __init__(self, featurizer, classifier, n_domains): + super().__init__() + self.featurizer = featurizer + self.classifier = classifier + self.domain_classifier = DomainDiscriminator(featurizer.d_out, n_domains) + self.gradient_reverse_layer = GradientReverseLayer() + + def forward(self, input): + features = self.featurizer(input) + y_pred = self.classifier(features) + features = self.gradient_reverse_layer(features) + domains_pred = self.domain_classifier(features) + return y_pred, domains_pred + + def get_parameters_with_lr(self, featurizer_lr, classifier_lr, discriminator_lr) -> List[Dict]: + """ + Adapted from https://github.com/thuml/Transfer-Learning-Library + + A parameter list which decides optimization hyper-parameters, + such as the relative learning rate of each layer + """ + # In TLL's implementation, the learning rate of this classifier is set to 10 times that of the + # feature extractor for better accuracy by default. For our implementation, we allow the learning + # rates to be passed in separately for featurizer and classifier.
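+ # These parameter groups are meant to be passed to initialize_optimizer_with_model_params() in optimizer.py, so that each component is optimized with its own learning rate.
+ # e.g. (hypothetical values): initialize_optimizer_with_model_params(config, model.get_parameters_with_lr(1e-6, 1e-5, 1e-5))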
+ params = [ + {"params": self.featurizer.parameters(), "lr": featurizer_lr}, + {"params": self.classifier.parameters(), "lr": classifier_lr}, + ] + return params + self.domain_classifier.get_parameters_with_lr(discriminator_lr) diff --git a/examples/models/initializer.py b/examples/models/initializer.py index f43c1be0..86601fe1 100644 --- a/examples/models/initializer.py +++ b/examples/models/initializer.py @@ -1,7 +1,10 @@ import torch import torch.nn as nn +import os +import traceback from models.layers import Identity +from utils import load def initialize_model(config, d_out, is_featurizer=False): """ @@ -17,9 +20,18 @@ def initialize_model(config, d_out, is_featurizer=False): If is_featurizer=False: - model: a model that is equivalent to nn.Sequential(featurizer, classifier) + + Pretrained weights are loaded according to config.pretrained_model_path using either transformers.from_pretrained (for bert-based models) + or our own utils.load function (for torchvision models, resnet18-ms, and gin-virtual). + There is currently no support for loading pretrained weights from disk for other models. """ - if config.model in ('resnet50', 'resnet34', 'resnet18', 'wideresnet50', 'densenet121'): - if is_featurizer: + # If load_featurizer_only is True, + # then split into (featurizer, classifier) for the purposes of loading only the featurizer, + # before recombining them at the end + featurize = is_featurizer or config.load_featurizer_only + + if config.model in ('resnet18', 'resnet34', 'resnet50', 'resnet101', 'wideresnet50', 'densenet121'): + if featurize: featurizer = initialize_torchvision_model( name=config.model, d_out=None, @@ -33,8 +45,8 @@ def initialize_model(config, d_out, is_featurizer=False): **config.model_kwargs) elif 'bert' in config.model: - if is_featurizer: - featurizer = initialize_bert_based_model(config, d_out, is_featurizer) + if featurize: + featurizer = initialize_bert_based_model(config, d_out, featurize) classifier = nn.Linear(featurizer.d_out, d_out) model = (featurizer, classifier) else: @@ -42,7 +54,7 @@ def initialize_model(config, d_out, is_featurizer=False): elif config.model == 'resnet18_ms': # multispectral resnet 18 from models.resnet_multispectral import ResNet18 - if is_featurizer: + if featurize: featurizer = ResNet18(num_classes=None, **config.model_kwargs) classifier = nn.Linear(featurizer.d_out, d_out) model = (featurizer, classifier) @@ -51,7 +63,7 @@ def initialize_model(config, d_out, is_featurizer=False): elif config.model == 'gin-virtual': from models.gnn import GINVirtual - if is_featurizer: + if featurize: featurizer = GINVirtual(num_tasks=None, **config.model_kwargs) classifier = nn.Linear(featurizer.d_out, d_out) model = (featurizer, classifier) @@ -63,7 +75,7 @@ def initialize_model(config, d_out, is_featurizer=False): from transformers import GPT2Tokenizer name = 'microsoft/CodeGPT-small-py' tokenizer = GPT2Tokenizer.from_pretrained(name) - if is_featurizer: + if featurize: model = GPT2FeaturizerLMHeadLogit.from_pretrained(name) model.resize_token_embeddings(len(tokenizer)) featurizer = model.transformer @@ -74,11 +86,11 @@ def initialize_model(config, d_out, is_featurizer=False): model.resize_token_embeddings(len(tokenizer)) elif config.model == 'logistic_regression': - assert not is_featurizer, "Featurizer not supported for logistic regression" + assert not featurize, "Featurizer not supported for logistic regression" model = nn.Linear(out_features=d_out, **config.model_kwargs) elif config.model == 'unet-seq': from models.CNN_genome import 
UNet - if is_featurizer: + if featurize: featurizer = UNet(num_tasks=None, **config.model_kwargs) classifier = nn.Linear(featurizer.d_out, d_out) model = (featurizer, classifier) @@ -86,7 +98,7 @@ def initialize_model(config, d_out, is_featurizer=False): model = UNet(num_tasks=d_out, **config.model_kwargs) elif config.model == 'fasterrcnn': - if is_featurizer: # TODO + if featurize: raise NotImplementedError('Featurizer not implemented for detection yet') else: model = initialize_fasterrcnn_model(config, d_out) @@ -95,6 +107,41 @@ def initialize_model(config, d_out, is_featurizer=False): else: raise ValueError(f'Model: {config.model} not recognized.') + # Load pretrained weights from disk using our utils.load function + if config.pretrained_model_path is not None: + if config.model in ('code-gpt-py', 'logistic_regression', 'unet-seq'): + # This has only been tested on some models (mostly vision), so run this code iff we're sure it works + raise NotImplementedError(f"Model loading not yet tested for {config.model}.") + + if 'bert' not in config.model: # We've already loaded pretrained weights for bert-based models using the transformers library + try: + if featurize: + if config.load_featurizer_only: + model_to_load = model[0] + else: + model_to_load = nn.Sequential(*model) + else: + model_to_load = model + + prev_epoch, best_val_metric = load( + model_to_load, + config.pretrained_model_path, + device=config.device) + + print( + (f'Initialized model with pretrained weights from {config.pretrained_model_path} ') + + (f'previously trained for {prev_epoch} epochs ' if prev_epoch else '') + + (f'with previous val metric {best_val_metric} ' if best_val_metric else '') + ) + except Exception as e: + print('Something went wrong loading the pretrained model:') + traceback.print_exc() + raise + + # Recombine model if we originally split it up just for loading + if featurize and not is_featurizer: + model = nn.Sequential(*model) + # The `needs_y` attribute specifies whether the model's forward function # needs to take in both (x, y). # If False, Algorithm.process_batch will call model(x). @@ -102,7 +149,7 @@ def initialize_model(config, d_out, is_featurizer=False): # and model(x, None) during eval. 
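# At this point, model is a (featurizer, classifier) tuple only when is_featurizer=True; if it was split above just to load the featurizer, it has already been recombined into an nn.Sequential.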
if not hasattr(model, 'needs_y'): # Sometimes model is a tuple of (featurizer, classifier) - if isinstance(model, tuple): + if is_featurizer: for submodel in model: submodel.needs_y = False else: @@ -111,12 +158,16 @@ def initialize_model(config, d_out, is_featurizer=False): return model -def initialize_bert_based_model(config, d_out, is_featurizer=False): +def initialize_bert_based_model(config, d_out, featurize=False): from models.bert.bert import BertClassifier, BertFeaturizer from models.bert.distilbert import DistilBertClassifier, DistilBertFeaturizer + if config.pretrained_model_path: + print(f'Initialized model with pretrained weights from {config.pretrained_model_path}') + config.model_kwargs['state_dict'] = torch.load(config.pretrained_model_path, map_location=config.device) + if config.model == 'bert-base-uncased': - if is_featurizer: + if featurize: model = BertFeaturizer.from_pretrained(config.model, **config.model_kwargs) else: model = BertClassifier.from_pretrained( @@ -124,7 +175,7 @@ def initialize_bert_based_model(config, d_out, is_featurizer=False): num_labels=d_out, **config.model_kwargs) elif config.model == 'distilbert-base-uncased': - if is_featurizer: + if featurize: model = DistilBertFeaturizer.from_pretrained(config.model, **config.model_kwargs) else: model = DistilBertClassifier.from_pretrained( @@ -145,7 +196,7 @@ def initialize_torchvision_model(name, d_out, **kwargs): elif name == 'densenet121': constructor_name = name last_layer_name = 'classifier' - elif name in ('resnet50', 'resnet34', 'resnet18'): + elif name in ('resnet18', 'resnet34', 'resnet50', 'resnet101'): constructor_name = name last_layer_name = 'fc' else: @@ -162,14 +213,13 @@ def initialize_torchvision_model(name, d_out, **kwargs): last_layer = nn.Linear(d_features, d_out) model.d_out = d_out setattr(model, last_layer_name, last_layer) - return model + return model def initialize_fasterrcnn_model(config, d_out): - from models.detection.fasterrcnn import fasterrcnn_resnet50_fpn - # load a model pre-trained pre-trained on COCO + # load a model pre-trained on COCO model = fasterrcnn_resnet50_fpn( pretrained=config.model_kwargs["pretrained_model"], pretrained_backbone=config.model_kwargs["pretrained_backbone"], diff --git a/examples/noisy_student_wrapper.py b/examples/noisy_student_wrapper.py new file mode 100644 index 00000000..3e0d6d79 --- /dev/null +++ b/examples/noisy_student_wrapper.py @@ -0,0 +1,105 @@ +""" +Helper code to run multiple iterations of Noisy Student, using the same hyperparameters between iterations. The initial teacher's weights must be provided by the command line. + +Normally, to run 2 warm-started iterations with some initial teacher weights, one would run a sequence of commands: + python examples/run_expt.py --root_dir $HOME --log_dir ./student1 --dataset DATASET --algorithm NoisyStudent --unlabeled_split test_unlabeled --teacher_model_path teacher_weights.pth --pretrained_model_path teacher_weights.pth + python examples/run_expt.py --root_dir $HOME --log_dir ./student2 --dataset DATASET --algorithm NoisyStudent --unlabeled_split test_unlabeled --teacher_model_path ./student1/model.pth --pretrained_model_path ./student1/model.pth + +With this script, to run 2 warm-started iterations with some initial teacher weights: + python examples/noisy_student_wrapper.py 2 teacher_weights.pth --root_dir $HOME --log_dir . --dataset DATASET --unlabeled_split test_unlabeled + +i.e. 
usage: + python examples/noisy_student_wrapper.py [NUM_ITERS] [INITIAL_TEACHER_WEIGHTS] [REST OF RUN_EXPT COMMAND STRING] + +Notes: + - Students are all warm-started with the current teacher's weights. + - This command will use the FIRST occurrence of --log_dir (instead of the last). +"""
+import argparse +import os +import pathlib +import subprocess + +SUCCESS_RETURN_CODE = 0 + +parser = argparse.ArgumentParser() +parser.add_argument("num_iters", type=int) +parser.add_argument("initial_teacher_path", type=str) # required +parser.add_argument("cmd", nargs=argparse.REMAINDER) +args = parser.parse_args() + +assert args.initial_teacher_path.endswith(".pth") +assert os.path.exists( + args.initial_teacher_path +), f"Model weights did not exist at {args.initial_teacher_path}" +prefix = pathlib.Path(__file__).parent.resolve() + +
+def remove_arg(args, arg_to_remove): + idx = args.cmd.index(f"--{arg_to_remove}") + value = args.cmd[idx + 1] + args.cmd = args.cmd[:idx] + args.cmd[idx + 2 :] + return value + +
+# Parse out a few args that we need +try: + idx = args.cmd.index("--log_dir") + log_dir = args.cmd[idx + 1] + args.cmd = ( + args.cmd[:idx] + args.cmd[idx + 2 :] + ) # will need to modify this between iters, so remove from args.cmd +except ValueError: + log_dir = "./logs" # default in run_expt.py + +idx = args.cmd.index("--dataset") +dataset = args.cmd[idx + 1] + +try: + idx = args.cmd.index("--seed") + seed = args.cmd[idx + 1] +except ValueError: + seed = 0 # default in run_expt.py + +try: + idx = args.cmd.index("--dataset_kwargs") + fold = args.cmd[idx + 1] + assert fold.startswith("fold=") + fold = fold.replace("fold=", "") +except (ValueError, AssertionError): + fold = "A" +
+# Pull out the args that are re-inserted explicitly for each student run below +unlabeled_split = remove_arg(args, "unlabeled_split") +gradient_accumulation_steps = remove_arg(args, "gradient_accumulation_steps") +n_epochs = remove_arg(args, "n_epochs") +
+# Run student iterations +for i in range(1, args.num_iters + 1): + if i == 1: + teacher_weights = args.initial_teacher_path + else: + if dataset == "poverty": + teacher_weights = ( + f"{log_dir}/student{i - 1}/{dataset}_fold:{fold}_epoch:best_model.pth" + ) + else: + teacher_weights = ( + f"{log_dir}/student{i-1}/{dataset}_seed:{seed}_epoch:best_model.pth" + ) + cmd = ( + f"python {prefix}/run_expt.py --algorithm NoisyStudent {' '.join(args.cmd)}" + + f" --unlabeled_split {unlabeled_split} --gradient_accumulation_steps {gradient_accumulation_steps}" + + f" --n_epochs {n_epochs} --log_dir {log_dir}/student{i}" + + f" --teacher_model_path {teacher_weights}" + + f" --pretrained_model_path {teacher_weights}" # warm starting + ) + print(f">>> Running {cmd}") + return_code = subprocess.Popen(cmd, shell=True).wait() + if return_code != SUCCESS_RETURN_CODE: + raise RuntimeError( + f"FAILED: Iteration {i} failed with return code: {return_code}" + ) + +print(">>> Done!")
diff --git a/examples/optimizer.py b/examples/optimizer.py index bc390394..8e5ef065 100644 --- a/examples/optimizer.py +++ b/examples/optimizer.py @@ -35,3 +35,30 @@ def initialize_optimizer(config, model): raise ValueError(f'Optimizer {config.optimizer} not recognized.') return optimizer + +def initialize_optimizer_with_model_params(config, params): + if config.optimizer=='SGD': + optimizer = SGD( + params, + lr=config.lr, + weight_decay=config.weight_decay, + **config.optimizer_kwargs + ) + elif config.optimizer=='AdamW': + optimizer = AdamW( + params, + lr=config.lr, + weight_decay=config.weight_decay, +
**config.optimizer_kwargs + ) + elif config.optimizer == 'Adam': + optimizer = Adam( + params, + lr=config.lr, + weight_decay=config.weight_decay, + **config.optimizer_kwargs + ) + else: + raise ValueError(f'Optimizer {config.optimizer} not supported.') + + return optimizer diff --git a/examples/pretraining/mlm/README.md b/examples/pretraining/mlm/README.md new file mode 100644 index 00000000..140e2343 --- /dev/null +++ b/examples/pretraining/mlm/README.md @@ -0,0 +1,20 @@ +# Masked LM Pre-training + +## Dependencies +- datasets==1.11.0 +- transformers==4.9.1 + +## Usage +1. Format the unlabeled text data in the hugging-face format +``` +python3 examples/pretraining/mlm/get_data.py +``` + +2. Run the commands in `examples/pretraining/mlm/run_pretrain.sh` to start masked LM pre-training + +3. Use the pre-trained model in WILDS fine-tuning, e.g., +``` +python3 examples/run_expt.py --dataset civilcomments --algorithm ERM --root_dir data \ + --model distilbert-base-uncased \ + --pretrained_model_path examples/pretraining/mlm/data/_run__distilbert-base-uncased__civilcomments__b32a256_lr1e-4/checkpoint-1500/pytorch_model.bin +``` diff --git a/examples/pretraining/mlm/get_data.py b/examples/pretraining/mlm/get_data.py new file mode 100644 index 00000000..1836e55e --- /dev/null +++ b/examples/pretraining/mlm/get_data.py @@ -0,0 +1,79 @@ +import os +import json +import numpy as np +import pandas as pd +from tqdm import tqdm +from collections import defaultdict +import csv + +os.system('mkdir -p examples/pretraining/mlm/data') + + +######################## CivilComments ######################## +CCU_metadata_df = pd.read_csv('data/civilcomments_unlabeled_v1.0/unlabeled_data_with_identities.csv', index_col=0) +CCU_text_array = list(CCU_metadata_df['comment_text']) #1_551_515 + +with open('examples/pretraining/mlm/data/civilcomments_train.json', 'w') as outf: + for text in tqdm(CCU_text_array): + print (json.dumps({'text': text}), file=outf) + + +CC_metadata_df = pd.read_csv('data/civilcomments_v1.0/all_data_with_identities.csv', index_col=0) +CC_text_array_val = list(CC_metadata_df[CC_metadata_df['split'] == 'val']['comment_text']) #45_180 + +with open('examples/pretraining/mlm/data/civilcomments_val.json', 'w') as outf: + for text in tqdm(CC_text_array_val): + print (json.dumps({'text': text}), file=outf) + + + +######################## Amazon ######################## +amazon_data_df: pd.DataFrame = pd.read_csv( + 'data/amazon_v2.1/reviews.csv', + dtype={ + "reviewerID": str, + "asin": str, + "reviewTime": str, + "unixReviewTime": int, + "reviewText": str, + "summary": str, + "verified": bool, + "category": str, + "reviewYear": int, + }, + keep_default_na=False, + na_values=[], + quoting=csv.QUOTE_NONNUMERIC, +) #10_116_947 + +amazon_split_df: pd.DataFrame = pd.read_csv('data/amazon_v2.1/splits/user.csv') #10_116_947 +is_in_dataset: bool = (amazon_split_df["split"] != -1) + +amazon_split_df = amazon_split_df[is_in_dataset] #4_002_170 +amazon_data_df = amazon_data_df[is_in_dataset] #4_002_170 + +# "val_unlabeled": 11, "test_unlabeled": 12, "extra_unlabeled": 13, "val": 1 +_text_array_11 = list(amazon_data_df[amazon_split_df['split']==11]['reviewText']) #266_066 +_text_array_12 = list(amazon_data_df[amazon_split_df['split']==12]['reviewText']) #268_761 +_text_array_13 = list(amazon_data_df[amazon_split_df['split']==13]['reviewText']) #2_927_841 +_text_array_val = list(amazon_data_df[amazon_split_df['split']==1]['reviewText']) #100_050 + +with open('examples/pretraining/mlm/data/amazon_train_11.json', 
'w') as outf: + for text in tqdm(_text_array_11): + print (json.dumps({'text': text}), file=outf) + +with open('examples/pretraining/mlm/data/amazon_train_12.json', 'w') as outf: + for text in tqdm(_text_array_12): + print (json.dumps({'text': text}), file=outf) + +with open('examples/pretraining/mlm/data/amazon_train_13.json', 'w') as outf: + for text in tqdm(_text_array_13): + print (json.dumps({'text': text}), file=outf) + +with open('examples/pretraining/mlm/data/amazon_train_11_12_13.json', 'w') as outf: + for text in tqdm(_text_array_11 + _text_array_12 + _text_array_13): + print (json.dumps({'text': text}), file=outf) + +with open('examples/pretraining/mlm/data/amazon_val.json', 'w') as outf: + for text in tqdm(_text_array_val): + print (json.dumps({'text': text}), file=outf) diff --git a/examples/pretraining/mlm/run_mlm.py b/examples/pretraining/mlm/run_mlm.py new file mode 100644 index 00000000..cf0e9d8f --- /dev/null +++ b/examples/pretraining/mlm/run_mlm.py @@ -0,0 +1,555 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2020 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) on a text file or a dataset. +Here is the full list of checkpoints on the hub that can be fine-tuned by this script: +https://huggingface.co/models?filter=masked-lm +""" + +""" +NOTE: from https://github.com/huggingface/transformers/blob/40de2d5a4f25362ae9fde0aed07f1a7a0cf926a4/examples/pytorch/language-modeling/run_mlm.py +""" + +# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments. + +import logging +import math +import os +import sys +from dataclasses import dataclass, field +from typing import Optional + +import datasets +from datasets import load_dataset + +import transformers +from transformers import ( + CONFIG_MAPPING, + MODEL_FOR_MASKED_LM_MAPPING, + AutoConfig, + AutoModelForMaskedLM, + AutoTokenizer, + DataCollatorForLanguageModeling, + HfArgumentParser, + Trainer, + TrainingArguments, + set_seed, +) +from transformers.trainer_utils import get_last_checkpoint +from transformers.utils import check_min_version +from transformers.utils.versions import require_version + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +# check_min_version("4.10.0.dev0") + +require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") + +logger = logging.getLogger(__name__) +MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys()) +MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. + """ + + model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "The model checkpoint for weights initialization." 
+ "Don't set if you want to train a model from scratch." + }, + ) + model_type: Optional[str] = field( + default=None, + metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, + ) + config_overrides: Optional[str] = field( + default=None, + metadata={ + "help": "Override some existing default config settings when a model is trained from scratch. Example: " + "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" + }, + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + + def __post_init__(self): + if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None): + raise ValueError( + "--config_overrides can't be used in combination with --config_name or --model_name_or_path" + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: Optional[str] = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) + validation_file: Optional[str] = field( + default=None, + metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} + ) + validation_split_percentage: Optional[int] = field( + default=5, + metadata={ + "help": "The percentage of the train set used as validation set in case there's no validation split" + }, + ) + max_seq_length: Optional[int] = field( + default=None, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated." 
+ }, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + mlm_probability: float = field( + default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"} + ) + line_by_line: bool = field( + default=False, + metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."}, + ) + pad_to_max_length: bool = field( + default=False, + metadata={ + "help": "Whether to pad all samples to `max_seq_length`. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + + def __post_init__(self): + if self.dataset_name is None and self.train_file is None and self.validation_file is None: + raise ValueError("Need either a dataset name or a training/validation file.") + else: + if self.train_file is not None: + extension = self.train_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file." + if self.validation_file is not None: + extension = self.validation_file.split(".")[-1] + assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file." + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + # Set the verbosity to info of the Transformers logger (on main process only): + logger.info(f"Training/evaluation parameters {training_args}") + + # Detecting last checkpoint. 
+    last_checkpoint = None
+    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
+        last_checkpoint = get_last_checkpoint(training_args.output_dir)
+        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+            raise ValueError(
+                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+                "Use --overwrite_output_dir to overwrite it."
+            )
+        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+            logger.info(
+                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+            )
+
+    # Set seed before initializing model.
+    set_seed(training_args.seed)
+
+    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
+    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
+    # (the dataset will be downloaded automatically from the datasets Hub).
+    #
+    # For CSV/JSON files, this script will use the column called 'text' or the first column. You can easily tweak this
+    # behavior (see below).
+    #
+    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+    # download the dataset.
+    if data_args.dataset_name is not None:
+        # Downloading and loading a dataset from the hub.
+        raw_datasets = load_dataset(
+            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
+        )
+        if "validation" not in raw_datasets.keys():
+            raw_datasets["validation"] = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                split=f"train[:{data_args.validation_split_percentage}%]",
+                cache_dir=model_args.cache_dir,
+            )
+            raw_datasets["train"] = load_dataset(
+                data_args.dataset_name,
+                data_args.dataset_config_name,
+                split=f"train[{data_args.validation_split_percentage}%:]",
+                cache_dir=model_args.cache_dir,
+            )
+    else:
+        data_files = {}
+        if data_args.train_file is not None:
+            data_files["train"] = data_args.train_file
+            extension = data_args.train_file.split(".")[-1]
+        if data_args.validation_file is not None:
+            data_files["validation"] = data_args.validation_file
+            extension = data_args.validation_file.split(".")[-1]
+        if extension == "txt":
+            extension = "text"
+        raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
+
+        # If no validation split is present, validation_split_percentage will be used to divide the dataset.
+        if "validation" not in raw_datasets.keys():
+            raw_datasets["validation"] = load_dataset(
+                extension,
+                data_files=data_files,
+                split=f"train[:{data_args.validation_split_percentage}%]",
+                cache_dir=model_args.cache_dir,
+            )
+            raw_datasets["train"] = load_dataset(
+                extension,
+                data_files=data_files,
+                split=f"train[{data_args.validation_split_percentage}%:]",
+                cache_dir=model_args.cache_dir,
+            )
+
+    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
+    # https://huggingface.co/docs/datasets/loading_datasets.html.
+
+    # Load pretrained model and tokenizer
+    #
+    # Distributed training:
+    # The .from_pretrained methods guarantee that only one local process can concurrently
+    # download model & vocab.
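+    #
+    # As an illustration (a hypothetical command, not from the repo): training from
+    # scratch with overridden config values might look like
+    #     python run_mlm.py --model_type bert \
+    #         --config_overrides "hidden_size=256,num_hidden_layers=4" \
+    #         --tokenizer_name bert-base-uncased --train_file train.txt --do_train
+    # Note that --config_overrides cannot be combined with --model_name_or_path or
+    # --config_name (enforced in ModelArguments.__post_init__ above).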
+    config_kwargs = {
+        "cache_dir": model_args.cache_dir,
+        "revision": model_args.model_revision,
+        "use_auth_token": True if model_args.use_auth_token else None,
+    }
+    if model_args.config_name:
+        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
+    elif model_args.model_name_or_path:
+        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
+    else:
+        config = CONFIG_MAPPING[model_args.model_type]()
+        logger.warning("You are instantiating a new config instance from scratch.")
+        if model_args.config_overrides is not None:
+            logger.info(f"Overriding config: {model_args.config_overrides}")
+            config.update_from_string(model_args.config_overrides)
+
+    tokenizer_kwargs = {
+        "cache_dir": model_args.cache_dir,
+        "use_fast": model_args.use_fast_tokenizer,
+        "revision": model_args.model_revision,
+        "use_auth_token": True if model_args.use_auth_token else None,
+    }
+    if model_args.tokenizer_name:
+        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
+    elif model_args.model_name_or_path:
+        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
+    else:
+        raise ValueError(
+            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
+            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
+        )
+
+    if model_args.model_name_or_path:
+        model = AutoModelForMaskedLM.from_pretrained(
+            model_args.model_name_or_path,
+            from_tf=bool(".ckpt" in model_args.model_name_or_path),
+            config=config,
+            cache_dir=model_args.cache_dir,
+            revision=model_args.model_revision,
+            use_auth_token=True if model_args.use_auth_token else None,
+        )
+    else:
+        logger.info("Training new model from scratch")
+        model = AutoModelForMaskedLM.from_config(config)
+
+    model.resize_token_embeddings(len(tokenizer))
+
+    # Preprocessing the datasets.
+    # First we tokenize all the texts.
+    if training_args.do_train:
+        column_names = raw_datasets["train"].column_names
+    else:
+        column_names = raw_datasets["validation"].column_names
+    text_column_name = "text" if "text" in column_names else column_names[0]
+
+    if data_args.max_seq_length is None:
+        max_seq_length = tokenizer.model_max_length
+        if max_seq_length > 1024:
+            logger.warning(
+                f"The chosen tokenizer seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
+                "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
+            )
+            max_seq_length = 1024
+    else:
+        if data_args.max_seq_length > tokenizer.model_max_length:
+            logger.warning(
+                f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
+                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+            )
+        max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
+    if data_args.line_by_line:
+        # When using line_by_line, we just tokenize each nonempty line.
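+        # (For example, the two input lines "Hello world" and "Bye" stay two separate
+        # examples here; without --line_by_line, all text is concatenated and
+        # re-chunked into max_seq_length pieces by group_texts below.)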
+ padding = "max_length" if data_args.pad_to_max_length else False + print ('data_args.line_by_line', data_args.line_by_line, 'padding', padding) + + def tokenize_function(examples): + # Remove empty lines + examples[text_column_name] = [ + line for line in examples[text_column_name] if len(line) > 0 and not line.isspace() + ] + return tokenizer( + examples[text_column_name], + padding=padding, + truncation=True, + max_length=max_seq_length, + # We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it + # receives the `special_tokens_mask`. + return_special_tokens_mask=True, + ) + + with training_args.main_process_first(desc="dataset map tokenization"): + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=[text_column_name], + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on dataset line_by_line", + ) + else: + # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts. + # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more + # efficient when it receives the `special_tokens_mask`. + def tokenize_function(examples): + return tokenizer(examples[text_column_name], return_special_tokens_mask=True) + + with training_args.main_process_first(desc="dataset map tokenization"): + tokenized_datasets = raw_datasets.map( + tokenize_function, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on every text in dataset", + ) + + # Main data processing function that will concatenate all texts from our dataset and generate chunks of + # max_seq_length. + def group_texts(examples): + # Concatenate all texts. + concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can + # customize this part to your needs. + if total_length >= max_seq_length: + total_length = (total_length // max_seq_length) * max_seq_length + # Split by chunks of max_len. + result = { + k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)] + for k, t in concatenated_examples.items() + } + return result + + # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a + # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value + # might be slower to preprocess. + # + # To speed up this part, we use multiprocessing. 
See the documentation of the map method for more information: + # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map + + with training_args.main_process_first(desc="grouping texts together"): + tokenized_datasets = tokenized_datasets.map( + group_texts, + batched=True, + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=not data_args.overwrite_cache, + desc=f"Grouping texts in chunks of {max_seq_length}", + ) + + if training_args.do_train: + if "train" not in tokenized_datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = tokenized_datasets["train"] + if data_args.max_train_samples is not None: + train_dataset = train_dataset.select(range(data_args.max_train_samples)) + + if training_args.do_eval: + if "validation" not in tokenized_datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_dataset = tokenized_datasets["validation"] + if data_args.max_eval_samples is not None: + eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) + + # Data collator + # This one will take care of randomly masking the tokens. + pad_to_multiple_of_8 = data_args.line_by_line and training_args.fp16 and not data_args.pad_to_max_length + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm_probability=data_args.mlm_probability, + pad_to_multiple_of=8 if pad_to_multiple_of_8 else None, + ) + + # Initialize our Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + tokenizer=tokenizer, + data_collator=data_collator, + ) + + # Training + if training_args.do_train: + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + trainer.save_model() # Saves the tokenizer too for easy upload + metrics = train_result.metrics + + max_train_samples = ( + data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) + ) + metrics["train_samples"] = min(max_train_samples, len(train_dataset)) + + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # Evaluation + if training_args.do_eval: + logger.info("*** Evaluate ***") + + metrics = trainer.evaluate() + + max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) + metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) + try: + perplexity = math.exp(metrics["eval_loss"]) + except OverflowError: + perplexity = float("inf") + metrics["perplexity"] = perplexity + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + if training_args.push_to_hub: + kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "fill-mask"} + if data_args.dataset_name is not None: + kwargs["dataset_tags"] = data_args.dataset_name + if data_args.dataset_config_name is not None: + kwargs["dataset_args"] = data_args.dataset_config_name + kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" + else: + kwargs["dataset"] = data_args.dataset_name + + trainer.push_to_hub(**kwargs) + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git 
a/examples/pretraining/mlm/run_pretrain.sh b/examples/pretraining/mlm/run_pretrain.sh new file mode 100644 index 00000000..d83e71d2 --- /dev/null +++ b/examples/pretraining/mlm/run_pretrain.sh @@ -0,0 +1,39 @@ +######################## CivilComments ######################## +dt=`date '+%Y%m%d_%H%M%S'` +data_dir="mlm_pretrain/data" +TRAIN_FILE="${data_dir}/civilcomments_train.json" +VAL_FILE="${data_dir}/civilcomments_val.json" +model="distilbert-base-uncased" +outdir="${data_dir}/_run__${model}__civilcomments__b32a256_lr1e-4__${dt}" +mkdir -p $outdir + +CUDA_VISIBLE_DEVICES=1 python3.7 -u mlm_pretrain/src/run_mlm.py \ + --model_name_or_path $model \ + --train_file $TRAIN_FILE --validation_file $VAL_FILE \ + --do_train --do_eval --output_dir $outdir --overwrite_output_dir \ + --line_by_line --max_seq_length 300 --fp16 --preprocessing_num_workers 10 --learning_rate 1e-4 \ + --max_steps 1000 --logging_first_step --logging_steps 10 --save_steps 100 \ + --evaluation_strategy steps --eval_steps 100 \ + --per_device_train_batch_size 32 --per_device_eval_batch_size 64 --gradient_accumulation_steps 256 \ + |& tee $outdir/log.txt + + + +######################## Amazon ######################## +dt=`date '+%Y%m%d_%H%M%S'` +data_dir="mlm_pretrain/data" +TRAIN_FILE="${data_dir}/amazon_train_12.json" +VAL_FILE="${data_dir}/amazon_val.json" +model="distilbert-base-uncased" +outdir="${data_dir}/_run__${model}__amazon_12__b16a512_lr1e-4__${dt}" +mkdir -p $outdir + +CUDA_VISIBLE_DEVICES=9 python3.7 -u mlm_pretrain/src/run_mlm.py \ + --model_name_or_path $model \ + --train_file $TRAIN_FILE --validation_file $VAL_FILE \ + --do_train --do_eval --output_dir $outdir --overwrite_output_dir \ + --line_by_line --max_seq_length 512 --fp16 --preprocessing_num_workers 10 --learning_rate 1e-4 \ + --max_steps 1000 --logging_first_step --logging_steps 10 --save_steps 100 \ + --evaluation_strategy steps --eval_steps 100 \ + --per_device_train_batch_size 16 --per_device_eval_batch_size 32 --gradient_accumulation_steps 512 \ + |& tee $outdir/log.txt diff --git a/examples/pretraining/swav/LICENSE b/examples/pretraining/swav/LICENSE new file mode 100644 index 00000000..6b28d560 --- /dev/null +++ b/examples/pretraining/swav/LICENSE @@ -0,0 +1,399 @@ +Attribution-NonCommercial 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. 
+ + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial 4.0 International Public +License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial 4.0 International Public License ("Public +License"). To the extent this Public License may be interpreted as a +contract, You are granted the Licensed Rights in consideration of Your +acceptance of these terms and conditions, and the Licensor grants You +such rights in consideration of benefits the Licensor receives from +making the Licensed Material available under these terms and +conditions. + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. 
For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + j. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + k. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + l. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. 
The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. 
For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. 
+
+    For the avoidance of doubt, this Section 6(b) does not affect any
+    right the Licensor may have to seek remedies for Your violations
+    of this Public License.
+
+  c. For the avoidance of doubt, the Licensor may also offer the
+     Licensed Material under separate terms or conditions or stop
+     distributing the Licensed Material at any time; however, doing so
+     will not terminate this Public License.
+
+  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+     License.
+
+Section 7 -- Other Terms and Conditions.
+
+  a. The Licensor shall not be bound by any additional or different
+     terms or conditions communicated by You unless expressly agreed.
+
+  b. Any arrangements, understandings, or agreements regarding the
+     Licensed Material not stated herein are separate from and
+     independent of the terms and conditions of this Public License.
+
+Section 8 -- Interpretation.
+
+  a. For the avoidance of doubt, this Public License does not, and
+     shall not be interpreted to, reduce, limit, restrict, or impose
+     conditions on any use of the Licensed Material that could lawfully
+     be made without permission under this Public License.
+
+  b. To the extent possible, if any provision of this Public License is
+     deemed unenforceable, it shall be automatically reformed to the
+     minimum extent necessary to make it enforceable. If the provision
+     cannot be reformed, it shall be severed from this Public License
+     without affecting the enforceability of the remaining terms and
+     conditions.
+
+  c. No term or condition of this Public License will be waived and no
+     failure to comply consented to unless expressly agreed to by the
+     Licensor.
+
+  d. Nothing in this Public License constitutes or may be interpreted
+     as a limitation upon, or waiver of, any privileges and immunities
+     that apply to the Licensor or You, including from the legal
+     processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the “Licensor.” The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
diff --git a/examples/pretraining/swav/README.md b/examples/pretraining/swav/README.md
new file mode 100644
index 00000000..f80dc9d4
--- /dev/null
+++ b/examples/pretraining/swav/README.md
@@ -0,0 +1,42 @@
+# SwAV pre-training
+
+This folder contains a lightly modified version of the SwAV code from https://github.com/facebookresearch/swav, licensed under CC BY-NC 4.0.
+
+If you use this algorithm, please cite the original source:
+```
+@inproceedings{caron2020unsupervised,
+  title={Unsupervised Learning of Visual Features by Contrasting Cluster Assignments},
+  author={Caron, Mathilde and Misra, Ishan and Mairal, Julien and Goyal, Priya and Bojanowski, Piotr and Joulin, Armand},
+  booktitle={Proceedings of Advances in Neural Information Processing Systems (NeurIPS)},
+  year={2020}
+}
+```
+
+SwAV requires installation of the NVIDIA Apex library for mixed-precision training. Each Apex installation is built for a specific CUDA version; more information can be found in the "requirements" section of the original SwAV repository's [README](https://github.com/facebookresearch/swav).
+
+## Changes
+We made the following changes to the SwAV repository to interface with the WILDS code.
+
+### `multicropdataset.py`
+- Added a new dataset class, CustomSplitMultiCropDataset, to accommodate WILDS data loaders, allowing SwAV to train on multiple datasets at once.
+### Model building code
+- Pulled the changes from standard ResNets to SwAV-compatible ResNets into a new file (`model.py`), which lets us incorporate WILDS-Unlabeled architectures, including ResNets and DenseNets.
+### `main_swav.py`
+- Edited data loading and model building code to be compatible with the two changes noted above.
+
+## Pre-training on WILDS
+
+To run SwAV pre-training on the WILDS datasets with the default hyperparameters used in the [paper](https://arxiv.org/abs/2112.05090),
+simply run:
+
+```bash
+python -m torch.distributed.launch --nproc_per_node=<NUM_GPUS> main_swav.py --dataset <DATASET> --root_dir <ROOT_DIR>
+```
+
+We support SwAV pre-training on the following datasets:
+
+- `camelyon17`
+- `iwildcam`
+- `fmow`
+- `poverty`
+- `domainnet`
\ No newline at end of file
diff --git a/examples/pretraining/swav/main_swav.py b/examples/pretraining/swav/main_swav.py
new file mode 100644
index 00000000..6bb791e5
--- /dev/null
+++ b/examples/pretraining/swav/main_swav.py
@@ -0,0 +1,436 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# This file has been modified from the original repository's version in the following ways:
+# 1. The model loading logic uses a SwAVModel class that acts as a wrapper around WILDS-Unlabeled
+#    models.
+# 2. The data loading logic uses a CustomSplitMultiCropDataset class that is compatible with all
+#    WILDS-Unlabeled datasets.
+# More information about both of these classes can be found in the src/ directory.
+#
+
+import argparse
+import math
+import os
+import pdb
+import shutil
+import sys
+import time
+from logging import getLogger
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import torch.optim
+
+try:
+    import apex
+    from apex.parallel.LARC import LARC
+except ImportError:
+    print("Apex not found. Proceeding without it...")
+
+try:
+    import wandb
+except ImportError:
+    print("wandb not found. 
Proceeding without it...") + + +import wilds +from src.utils import ( + bool_flag, + initialize_exp, + restart_from_checkpoint, + fix_random_seeds, + AverageMeter, + init_distributed_mode, + ParseKwargs, + plot_experiment, + populate_defaults_for_swav +) +from src.multicropdataset import CustomSplitMultiCropDataset +from src.model import SwAVModel + +from examples.models.initializer import initialize_model +from examples.utils import initialize_wandb + +logger = getLogger() +parser = argparse.ArgumentParser(description="Implementation of SwAV") + +######################### +##### dataset params #### +######################### +parser.add_argument('-d', '--dataset', required=True, choices=wilds.unlabeled_datasets) +parser.add_argument('--root_dir', required=True, + help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).') +parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={}) +parser.add_argument('--loader_kwargs', nargs='*', action=ParseKwargs, default={}) +parser.add_argument('--splits', nargs='+') + +######################### +#### data aug params #### +######################### +parser.add_argument("--nmb_crops", type=int, nargs="+", help="list of number of crops") +parser.add_argument("--size_crops", type=int, nargs="+", help="crops resolutions") +parser.add_argument("--min_scale_crops", type=float, nargs="+", help="argument in RandomResizedCrop") +parser.add_argument("--max_scale_crops", type=float, nargs="+", help="argument in RandomResizedCrop") + +######################### +## swav specific params # +######################### +parser.add_argument("--crops_for_assign", type=int, nargs="+", default=[0, 1], + help="list of crops id used for computing assignments (default: [0, 1])") +parser.add_argument("--temperature", default=0.1, type=float, + help="temperature parameter in training loss (default: 0.1)") +parser.add_argument("--epsilon", default=0.03, type=float, + help="regularization parameter for Sinkhorn-Knopp algorithm (default: 0.03)") +parser.add_argument("--sinkhorn_iterations", default=3, type=int, + help="number of iterations in Sinkhorn-Knopp algorithm") +parser.add_argument("--feat_dim", default=128, type=int, + help="feature dimension") +parser.add_argument("--nmb_prototypes", type=int, help="number of prototypes") +parser.add_argument("--queue_length", type=int, default=0, + help="length of the queue (0 for no queue)") +parser.add_argument("--epoch_queue_starts", type=int, default=500, + help="from this epoch, we start using a queue") + +######################### +#### optim parameters ### +######################### +parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={}) +parser.add_argument("--n_epochs", default=400, type=int, + help="number of total epochs to run") +parser.add_argument("--warmup_epochs", default=0, type=int, help="number of warmup epochs (default: 0)") +parser.add_argument("--batch_size", type=int, + help="batch size per gpu, i.e. 
how many unique instances per gpu") +parser.add_argument("--lr", type=float, help="base learning rate") +parser.add_argument("--final_lr", type=float, help="final learning rate") +parser.add_argument("--freeze_prototypes_niters", default=5005, type=int, + help="freeze the prototypes during this many iterations from the start (default: 5005).") +parser.add_argument("--weight_decay", default=1e-6, type=float, help="weight decay") +parser.add_argument("--start_warmup", default=0, type=float, + help="initial warmup learning rate") + +######################### +#### dist parameters ### +######################### +parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up distributed + training; see https://pytorch.org/docs/stable/distributed.html""") +parser.add_argument("--world_size", default=-1, type=int, help=""" + number of processes: it is set automatically and + should not be passed as argument""") +parser.add_argument("--rank", default=0, type=int, help="""rank of this process: + it is set automatically and should not be passed as argument""") +parser.add_argument("--local_rank", default=0, type=int, + help="this argument is not used and should be ignored") + +######################### +#### other parameters ### +######################### +parser.add_argument("--model", type=str, help="convnet architecture. If not set, uses default model specified in WILDS.") +parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={}, + help='keyword arguments for model initialization passed as key1=value1 key2=value2') +parser.add_argument("--hidden_mlp", default=2048, type=int, + help="hidden layer dimension in projection head") +parser.add_argument("--checkpoint_freq", type=int, default=50, + help="Save the model periodically") +parser.add_argument("--use_fp16", type=bool_flag, default=True, + help="whether to train with mixed precision or not") +parser.add_argument("--sync_bn", type=str, default="pytorch", help="synchronize bn") +parser.add_argument("--syncbn_process_group_size", type=int, default=8, help=""" see + https://github.com/NVIDIA/apex/blob/master/apex/parallel/__init__.py#L58-L67""") +parser.add_argument("--log_dir", type=str, default=".", + help="experiment dump path for checkpoints and log") +parser.add_argument("--seed", type=int, default=0, help="seed") +parser.add_argument("--is_not_slurm_job", type=bool_flag, default=True, help="Set to true if not running in Slurm.") +parser.add_argument("--cpu_only", type=bool_flag, default=False, + help="Set to true to run experiment on CPUs instead of GPUs (for debugging).") +parser.add_argument('--pretrained_model_path', default=None, type=str) + +# Weights & Biases +parser.add_argument('--use_wandb', type=bool_flag, nargs='?', default=False) +parser.add_argument('--wandb_api_key_path', type=str, + help="Path to Weights & Biases API Key. 
If use_wandb is set to True and this argument is not specified, user will be prompted to authenticate.") +parser.add_argument('--wandb_kwargs', nargs='*', action=ParseKwargs, default={}, + help="Will be passed directly into wandb.init().") + +def main(): + global args + args = parser.parse_args() + args = populate_defaults_for_swav(args) + init_distributed_mode(args) + fix_random_seeds(args.seed) + if not os.path.exists(args.log_dir): + os.makedirs(args.log_dir) + logger, training_stats = initialize_exp(args, "epoch", "loss") + logger.info(f"Initialized distributed mode and applied WILDS default...\n{args}") + + if args.use_wandb: + initialize_wandb(args) + + train_dataset = CustomSplitMultiCropDataset( + args.dataset, + args.root_dir, + args.size_crops, + args.nmb_crops, + args.min_scale_crops, + args.max_scale_crops, + args, + ) + + sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + train_loader = torch.utils.data.DataLoader( + train_dataset, + sampler=sampler, + batch_size=args.batch_size, + **args.loader_kwargs, + ) + logger.info("Building data done with {} images loaded.".format(len(train_dataset))) + + d_out = 1 # this can be arbitrary; final layer is discarded for SwAVModel + base_model, _ = initialize_model(args, d_out, is_featurizer=True) # discard classifier + model = SwAVModel( + base_model, normalize=True, output_dim=args.feat_dim, + hidden_mlp=args.hidden_mlp, nmb_prototypes=args.nmb_prototypes + ) + + # synchronize batch norm layers + if args.sync_bn == "pytorch": + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + elif args.sync_bn == "apex": + # with apex syncbn we sync bn per group because it speeds up computation + # compared to global syncbn + process_group = apex.parallel.create_syncbn_process_group(args.syncbn_process_group_size) + model = apex.parallel.convert_syncbn_model(model, process_group=process_group) + # copy model to GPU + model = model.cuda() + if args.rank == 0: + logger.info(model) + logger.info("Building model done.") + + # build optimizer + optimizer = torch.optim.SGD( + model.parameters(), + lr=args.lr, + momentum=0.9, + weight_decay=args.weight_decay, + ) + optimizer = LARC(optimizer=optimizer, trust_coefficient=0.001, clip=False) + warmup_lr_schedule = np.linspace(args.start_warmup, args.lr, len(train_loader) * args.warmup_epochs) + iters = np.arange(len(train_loader) * (args.n_epochs - args.warmup_epochs)) + cosine_lr_schedule = np.array([args.final_lr + 0.5 * (args.lr - args.final_lr) * (1 + \ + math.cos(math.pi * t / (len(train_loader) * (args.n_epochs - args.warmup_epochs)))) for t in iters]) + lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule)) + logger.info("Building optimizer done.") + + # init mixed precision + if args.use_fp16: + model, optimizer = apex.amp.initialize(model, optimizer, opt_level="O1") + logger.info("Initializing mixed precision done.") + + # wrap model + model = nn.parallel.DistributedDataParallel( + model, + device_ids=[args.gpu_to_work_on] + ) + + # optionally resume from a checkpoint + to_restore = {"epoch": 0} + restart_from_checkpoint( + os.path.join(args.log_dir, "checkpoint.pth.tar"), + run_variables=to_restore, + state_dict=model, + optimizer=optimizer, + amp=apex.amp, + ) + start_epoch = to_restore["epoch"] + + # build the queue + queue = None + queue_path = os.path.join(args.log_dir, "queue" + str(args.rank) + ".pth") + if os.path.isfile(queue_path): + queue = torch.load(queue_path)["queue"] + # the queue needs to be divisible by the batch size + args.queue_length -= 
args.queue_length % (args.batch_size * args.world_size) + + cudnn.benchmark = True + + for epoch in range(start_epoch, args.n_epochs): + # train the network for one epoch + logger.info("============ Starting epoch %i ... ============" % epoch) + + # set sampler + train_loader.sampler.set_epoch(epoch) + + # optionally starts a queue + if args.queue_length > 0 and epoch >= args.epoch_queue_starts and queue is None: + queue = torch.zeros( + len(args.crops_for_assign), + args.queue_length // args.world_size, + args.feat_dim, + ).cuda() + + # train the network + scores, queue = train(train_loader, model, optimizer, epoch, lr_schedule, queue) + training_stats.update(scores) + + # save checkpoints + if args.rank == 0: + save_dict = { + "epoch": epoch + 1, + "state_dict": model.state_dict(), + "optimizer": optimizer.state_dict(), + } + if args.use_fp16: + save_dict["amp"] = apex.amp.state_dict() + torch.save( + save_dict, + os.path.join(args.log_dir, "checkpoint.pth.tar"), + ) + if epoch % args.checkpoint_freq == 0 or epoch == args.n_epochs - 1: + shutil.copyfile( + os.path.join(args.log_dir, "checkpoint.pth.tar"), + os.path.join(args.dump_checkpoints, "ckp-" + str(epoch) + ".pth"), + ) + if queue is not None: + torch.save({"queue": queue}, queue_path) + + if args.rank == 0: + plot_experiment(args.log_dir) + + +def train(train_loader, model, optimizer, epoch, lr_schedule, queue): + batch_time = AverageMeter() + data_time = AverageMeter() + losses = AverageMeter() + + model.train() + use_the_queue = False + + end = time.time() + for it, inputs in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + # update learning rate + iteration = epoch * len(train_loader) + it + for param_group in optimizer.param_groups: + param_group["lr"] = lr_schedule[iteration] + + # normalize the prototypes + with torch.no_grad(): + w = model.module.prototypes.weight.data.clone() + w = nn.functional.normalize(w, dim=1, p=2) + model.module.prototypes.weight.copy_(w) + + # ============ multi-res forward passes ... ============ + embedding, output = model(inputs) + embedding = embedding.detach() + bs = inputs[0].size(0) + + # ============ swav loss ... ============ + loss = 0 + for i, crop_id in enumerate(args.crops_for_assign): + with torch.no_grad(): + out = output[bs * crop_id: bs * (crop_id + 1)].detach() + + # time to use the queue + if queue is not None: + if use_the_queue or not torch.all(queue[i, -1, :] == 0): + use_the_queue = True + out = torch.cat((torch.mm( + queue[i], + model.module.prototypes.weight.t() + ), out)) + # fill the queue + queue[i, bs:] = queue[i, :-bs].clone() + queue[i, :bs] = embedding[crop_id * bs: (crop_id + 1) * bs] + + # get assignments + q = distributed_sinkhorn(out)[-bs:] + + # cluster assignment prediction + subloss = 0 + for v in np.delete(np.arange(np.sum(args.nmb_crops)), crop_id): + x = output[bs * v: bs * (v + 1)] / args.temperature + subloss -= torch.mean(torch.sum(q * F.log_softmax(x, dim=1), dim=1)) + loss += subloss / (np.sum(args.nmb_crops) - 1) + loss /= len(args.crops_for_assign) + + # ============ backward and optim step ... ============ + optimizer.zero_grad() + if args.use_fp16: + with apex.amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + # cancel gradients for the prototypes + if iteration < args.freeze_prototypes_niters: + for name, p in model.named_parameters(): + if "prototypes" in name: + p.grad = None + optimizer.step() + + # ============ misc ... 
============
+        losses.update(loss.item(), inputs[0].size(0))
+        batch_time.update(time.time() - end)
+        end = time.time()
+        if args.rank == 0 and it % 50 == 0:
+            logger.info(
+                "Epoch: [{0}][{1}]\t"
+                "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
+                "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
+                "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
+                "Lr: {lr:.4f}".format(
+                    epoch,
+                    it,
+                    batch_time=batch_time,
+                    data_time=data_time,
+                    loss=losses,
+                    lr=optimizer.optim.param_groups[0]["lr"],
+                )
+            )
+            if args.use_wandb:
+                wandb.log(
+                    {
+                        "epoch": epoch,
+                        "loss": losses.val,
+                        "loss_avg": losses.avg,
+                    }
+                )
+    return (epoch, losses.avg), queue
+
+
+@torch.no_grad()
+def distributed_sinkhorn(out):
+    Q = torch.exp(out / args.epsilon).t()  # Q is K-by-B for consistency with notations from our paper
+    B = Q.shape[1] * args.world_size  # number of samples to assign
+    K = Q.shape[0]  # how many prototypes
+
+    # make the matrix sum to 1
+    sum_Q = torch.sum(Q)
+    dist.all_reduce(sum_Q)
+    Q /= sum_Q
+
+    for it in range(args.sinkhorn_iterations):
+        # normalize each row: total weight per prototype must be 1/K
+        sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
+        dist.all_reduce(sum_of_rows)
+        Q /= sum_of_rows
+        Q /= K
+
+        # normalize each column: total weight per sample must be 1/B
+        Q /= torch.sum(Q, dim=0, keepdim=True)
+        Q /= B
+
+    Q *= B  # the columns must sum to 1 so that Q is an assignment
+    return Q.t()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/pretraining/swav/src/config.py b/examples/pretraining/swav/src/config.py
new file mode 100644
index 00000000..9e92b24d
--- /dev/null
+++ b/examples/pretraining/swav/src/config.py
@@ -0,0 +1,208 @@
+#############################################
+# SwAV-specific defaults for WILDS Datasets #
+#############################################
+
+# Run SwAV on 4 GPUs
+NUM_GPUS = 4
+
+# Maximum batch size that fits on a 12GB GPU
+MAX_BATCH_SIZE_PER_GPU = {
+    "camelyon17": 168,
+    "iwildcam": 24,
+    "fmow": 72,
+    "poverty": 120,
+    "domainnet": 96,
+}
+
+
+def get_base_lr(dataset, gpus=NUM_GPUS):
+    # base_lr = DEFAULT_LR / (DEFAULT_BATCH_SIZE / effective_batch_size),
+    # where DEFAULT_LR = 4.8, DEFAULT_BATCH_SIZE = 4096, and
+    # effective_batch_size = batch size per GPU * number of GPUs.
+    # That is, base_lr = 4.8 / (4096 / effective_batch_size).
+    batch_size_per_gpu = MAX_BATCH_SIZE_PER_GPU[dataset]
+    effective_batch_size = batch_size_per_gpu * gpus
+    if effective_batch_size == 256:
+        return 0.6
+    return 4.8 / (4096.0 / effective_batch_size)
+
+
+def get_queue_length(dataset, gpus=NUM_GPUS):
+    batch_size_per_gpu = MAX_BATCH_SIZE_PER_GPU[dataset]
+    effective_batch_size = batch_size_per_gpu * gpus
+    return 4096 - effective_batch_size
+
+
+# All the defaults are configured to run on 4 GPUs.
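+# As a worked example of the rule in get_base_lr above (with the 4-GPU default):
+#   camelyon17: 168 images per GPU * 4 GPUs = 672 effective batch size,
+#   so base_lr = 4.8 / (4096 / 672) = 0.7875 and final_lr = base_lr / 1000;
+#   get_queue_length gives 4096 - 672 = 3424.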
+DATASET_DEFAULTS = { + "camelyon17": { + "splits": ["test_unlabeled"], + "split_scheme": "official", + "dataset_kwargs": {}, + "model": "densenet121", + "model_kwargs": {"pretrained": False}, + "train_transform": "image_base", + "eval_transform": "image_base", + "target_resolution": (96, 96), + "nmb_crops": [6], + "size_crops": [96], + "min_scale_crops": [0.14], + "max_scale_crops": [1], + "loss_function": "cross_entropy", + "optimizer": "SGD", + "optimizer_kwargs": {"momentum": 0.9}, + "scheduler": None, + "batch_size": MAX_BATCH_SIZE_PER_GPU["camelyon17"], + "lr": get_base_lr("camelyon17"), + "final_lr": get_base_lr("camelyon17") / 1000.0, + "epsilon": 0.03, + "nmb_prototypes": 20, + "queue_length": get_queue_length("camelyon17"), + "epoch_queue_starts": 500, + "warmup_epochs": 0, + "n_epochs": 400, + "algo_log_metric": "accuracy", + "process_outputs_function": "multiclass_logits_to_pred", + "loader_kwargs": { + "num_workers": 4, + "pin_memory": True, + "drop_last": True, + }, + }, + "domainnet": { + "splits": ["test_unlabeled"], + "split_scheme": "official", + "dataset_kwargs": { + "source_domain": "real", + "target_domain": "sketch", + "use_sentry": False, + }, + "model": "resnet50", + "model_kwargs": {"pretrained": True}, + "train_transform": "image_resize_and_center_crop", + "eval_transform": "image_resize_and_center_crop", + "resize_scale": 256.0 / 224.0, + "target_resolution": (224, 224), + "nmb_crops": [2, 6], + "size_crops": [224, 96], + "min_scale_crops": [0.14, 0.05], + "max_scale_crops": [1, 0.14], + "loss_function": "cross_entropy", + "batch_size": MAX_BATCH_SIZE_PER_GPU["domainnet"], + "optimizer": "SGD", + "lr": get_base_lr("domainnet"), + "final_lr": get_base_lr("domainnet") / 1000.0, + "epsilon": 0.03, + "nmb_prototypes": 3450, + "queue_length": get_queue_length("domainnet"), + "epoch_queue_starts": 500, + "warmup_epochs": 0, + "n_epochs": 400, + "algo_log_metric": "accuracy", + "process_outputs_function": "multiclass_logits_to_pred", + "loader_kwargs": { + "num_workers": 4, + "pin_memory": True, + "drop_last": True, + }, + }, + "fmow": { + "splits": ["test_unlabeled"], + "split_scheme": "official", + "dataset_kwargs": {"seed": 111, "use_ood_val": True}, + "model": "densenet121", + "model_kwargs": {"pretrained": True}, + "target_resolution": (224, 224), + "nmb_crops": [2, 6], + "size_crops": [224, 96], + "min_scale_crops": [0.14, 0.05], + "max_scale_crops": [1, 0.14], + "train_transform": "image_base", + "eval_transform": "image_base", + "loss_function": "cross_entropy", + "optimizer": "Adam", + "scheduler": "StepLR", + "batch_size": MAX_BATCH_SIZE_PER_GPU["fmow"], + "lr": get_base_lr("fmow"), + "final_lr": get_base_lr("fmow") / 1000.0, + "warmup_epochs": 0, + "epsilon": 0.03, + "nmb_prototypes": 620, + "queue_length": get_queue_length("fmow"), + "epoch_queue_starts": 500, + "n_epochs": 400, + "algo_log_metric": "accuracy", + "process_outputs_function": "multiclass_logits_to_pred", + "loader_kwargs": { + "num_workers": 4, + "pin_memory": True, + "drop_last": True, + }, + }, + "iwildcam": { + "splits": ["extra_unlabeled"], + "dataset_kwargs": {}, + "loss_function": "cross_entropy", + "val_metric": "F1-macro_all", + "train_transform": "image_base", + "eval_transform": "image_base", + "target_resolution": (448, 448), + "nmb_crops": [2, 2], + "size_crops": [448, 96], + "min_scale_crops": [0.14, 0.05], + "max_scale_crops": [1, 0.14], + "model": "resnet50", + "model_kwargs": {"pretrained": True}, + "lr": get_base_lr("iwildcam"), + "final_lr": get_base_lr("iwildcam") / 1000.0, 
+ "batch_size": MAX_BATCH_SIZE_PER_GPU["iwildcam"], + "warmup_epochs": 0, + "epsilon": 0.03, + "nmb_prototypes": 1860, + "queue_length": get_queue_length("iwildcam"), + "epoch_queue_starts": 500, + "n_epochs": 400, + "optimizer": "Adam", + "split_scheme": "official", + "scheduler": None, + "groupby_fields": ["location"], + "no_group_logging": True, + "process_outputs_function": "multiclass_logits_to_pred", + "loader_kwargs": { + "num_workers": 4, + "pin_memory": True, + "drop_last": True, + }, + }, + "poverty": { + "splits": ["test_unlabeled"], + "split_scheme": "official", + "dataset_kwargs": {"no_nl": False, "fold": "A", "use_ood_val": True}, + "model": "resnet18_ms", + "model_kwargs": {"num_channels": 8}, + "train_transform": "poverty_train", + "eval_transform": None, + "target_resolution": (224, 224), + "nmb_crops": [2, 6], + "size_crops": [224, 96], + "min_scale_crops": [0.14, 0.05], + "max_scale_crops": [1, 0.14], + "loss_function": "mse", + "optimizer": "Adam", + "scheduler": "StepLR", + "batch_size": MAX_BATCH_SIZE_PER_GPU["poverty"], + "lr": get_base_lr("poverty"), + "final_lr": get_base_lr("poverty") / 1000.0, + "epsilon": 0.03, + "nmb_prototypes": 1000, + "queue_length": get_queue_length("poverty"), + "epoch_queue_starts": 500, + "warmup_epochs": 0, + "n_epochs": 400, + "process_outputs_function": None, + "loader_kwargs": { + "num_workers": 4, + "pin_memory": True, + "drop_last": True, + }, + }, +} diff --git a/examples/pretraining/swav/src/logger.py b/examples/pretraining/swav/src/logger.py new file mode 100644 index 00000000..f73ff47a --- /dev/null +++ b/examples/pretraining/swav/src/logger.py @@ -0,0 +1,94 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import os +import logging +import time +from datetime import timedelta +import pandas as pd + + +class LogFormatter: + def __init__(self): + self.start_time = time.time() + + def format(self, record): + elapsed_seconds = round(record.created - self.start_time) + + prefix = "%s - %s - %s" % ( + record.levelname, + time.strftime("%x %X"), + timedelta(seconds=elapsed_seconds), + ) + message = record.getMessage() + message = message.replace("\n", "\n" + " " * (len(prefix) + 3)) + return "%s - %s" % (prefix, message) if message else "" + + +def create_logger(filepath, rank): + """ + Create a logger. + Use a different log file for each process. 
+ """ + # create log formatter + log_formatter = LogFormatter() + + # create file handler and set level to debug + if filepath is not None: + if rank > 0: + filepath = "%s-%i" % (filepath, rank) + file_handler = logging.FileHandler(filepath, "a") + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(log_formatter) + + # create console handler and set level to info + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + console_handler.setFormatter(log_formatter) + + # create logger and set level to debug + logger = logging.getLogger() + logger.handlers = [] + logger.setLevel(logging.DEBUG) + logger.propagate = False + if filepath is not None: + logger.addHandler(file_handler) + logger.addHandler(console_handler) + + # reset logger elapsed time + def reset_time(): + log_formatter.start_time = time.time() + + logger.reset_time = reset_time + + return logger + + +class PD_Stats(object): + """ + Log stuff with pandas library + """ + + def __init__(self, path, columns): + self.path = path + + # reload path stats + if os.path.isfile(self.path): + self.stats = pd.read_pickle(self.path) + + # check that columns are the same + assert list(self.stats.columns) == list(columns) + + else: + self.stats = pd.DataFrame(columns=columns) + + def update(self, row, save=True): + self.stats.loc[len(self.stats.index)] = row + + # save the statistics + if save: + self.stats.to_pickle(self.path) diff --git a/examples/pretraining/swav/src/model.py b/examples/pretraining/swav/src/model.py new file mode 100644 index 00000000..7ce67c61 --- /dev/null +++ b/examples/pretraining/swav/src/model.py @@ -0,0 +1,94 @@ +# +# This file defines the SwAVModel class, a wrapper around WILDS-Unlabeled architectures +# that implements the changes necessary to make the networks compatible with SwAV +# training (e.g. prototypes, projection head, etc.). Currently, the supported architectures +# are ResNets and DenseNets. 
+# + +import os +import sys + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models + +sys.path.insert(1, os.path.join(sys.path[0], '../../..')) +import examples.models.resnet_multispectral as resnet_ms + +class SwAVModel(nn.Module): + def __init__( + self, + base_model, + normalize=False, + output_dim=0, + hidden_mlp=0, + nmb_prototypes=0, + ): + super(SwAVModel, self).__init__() + + self.base_model = base_model # base CNN architecture + self.l2norm = normalize # whether to normalize output features + + # projection head + last_dim = base_model.d_out # output dimensionality of final featurizer layer + if output_dim == 0: + self.projection_head = None + elif hidden_mlp == 0: + self.projection_head = nn.Linear(last_dim, output_dim) + else: + self.projection_head = nn.Sequential( + nn.Linear(last_dim, hidden_mlp), + nn.BatchNorm1d(hidden_mlp), + nn.ReLU(inplace=True), + nn.Linear(hidden_mlp, output_dim), + ) + + # prototype layer + self.prototypes = None + if isinstance(nmb_prototypes, list): + self.prototypes = MultiPrototypes(output_dim, nmb_prototypes) + elif nmb_prototypes > 0: + self.prototypes = nn.Linear(output_dim, nmb_prototypes, bias=False) + + def forward_head(self, x): + if self.projection_head is not None: + x = self.projection_head(x) + + if self.l2norm: + x = F.normalize(x, dim=1, p=2) + + if self.prototypes is not None: + return x, self.prototypes(x) + return x + + def forward(self, inputs): + if not isinstance(inputs, list): + inputs = [inputs] + idx_crops = torch.cumsum(torch.unique_consecutive( + torch.tensor([inp.shape[-1] for inp in inputs]), + return_counts=True, + )[1], 0) + start_idx = 0 + for end_idx in idx_crops: + _out = self.base_model( + torch.cat(inputs[start_idx: end_idx]).cuda(non_blocking=True)) + if start_idx == 0: + output = _out + else: + output = torch.cat((output, _out)) + start_idx = end_idx + return self.forward_head(output) + +class MultiPrototypes(nn.Module): + def __init__(self, output_dim, nmb_prototypes): + super(MultiPrototypes, self).__init__() + self.nmb_heads = len(nmb_prototypes) + for i, k in enumerate(nmb_prototypes): + self.add_module("prototypes" + str(i), nn.Linear(output_dim, k, bias=False)) + + def forward(self, x): + out = [] + for i in range(self.nmb_heads): + out.append(getattr(self, "prototypes" + str(i))(x)) + return out diff --git a/examples/pretraining/swav/src/multicropdataset.py b/examples/pretraining/swav/src/multicropdataset.py new file mode 100644 index 00000000..63b66fb4 --- /dev/null +++ b/examples/pretraining/swav/src/multicropdataset.py @@ -0,0 +1,168 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# This file has been modified from the original to include the CustomSplitDataset, +# which acts as a wrapper for WILDS-Unlabeled datasets and allows for simultaneous SwAV +# training of multiple datasets. 
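+#
+# Roughly: CustomSplitDataset concatenates one or more unlabeled WILDS splits into a
+# single index space, and CustomSplitMultiCropDataset wraps it so that each example
+# comes back as a list of augmented crops (counts and sizes are taken from the
+# per-dataset configs, e.g. two 224px global views plus six 96px local views).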
+# + +import random +from logging import getLogger + +from PIL import ImageFilter +import numpy as np +import torchvision.transforms as transforms +from torch.utils.data import Dataset + +from wilds import get_dataset +from examples.transforms import poverty_rgb_color_transform + +logger = getLogger() + + +class CustomSplitDataset(Dataset): + def __init__( + self, + dataset_name, + root_dir, + config, + ): + super().__init__() + + self.datasets = [] + dataset = get_dataset( + dataset=dataset_name, + root_dir=root_dir, + unlabeled=True, + download=True, + **config.dataset_kwargs + ) + for split in config.splits: + subset = dataset.get_subset(split, transform=None) + self.datasets.append(subset) + self.dataset_lengths = [len(d) for d in self.datasets] + + def __len__(self): + return sum(self.dataset_lengths) + + def __getitem__(self, index): + # determine from which dataset to take this + ds_idx = 0 + while ds_idx < len(self.dataset_lengths): + if index < self.dataset_lengths[ds_idx]: + break + index -= self.dataset_lengths[ds_idx] + ds_idx += 1 + # ds_idx now stores the correct dataset, and index stores + # the correct position within that dataset + x, _ = self.datasets[ds_idx][index] # discard metadata + return x + + +class CustomSplitMultiCropDataset(Dataset): + def __init__( + self, + dataset_name, + root_dir, + size_crops, + nmb_crops, + min_scale_crops, + max_scale_crops, + config, + return_index=False + ): + super().__init__() + + assert len(size_crops) == len(nmb_crops) + assert len(min_scale_crops) == len(nmb_crops) + assert len(max_scale_crops) == len(nmb_crops) + self.return_index = return_index + + self.ds = CustomSplitDataset(dataset_name, root_dir, config) + color_distortion = get_color_distortion() + color_transform = [color_distortion, PILRandomGaussianBlur()] + trans = [] + means = [0.485, 0.456, 0.406] + stds = [0.229, 0.224, 0.225] + for i in range(len(size_crops)): + random_resized_crop = transforms.RandomResizedCrop( + size_crops[i], + scale=(min_scale_crops[i], max_scale_crops[i]), + ) + if dataset_name == "poverty": + # The Poverty-WILDS dataset is made up of 8 x 224 x 224 multispectral, normalized images. + # Apply spatial-level transformations first, then apply pixel-level transformations + # on RGB channels only. + # We use PyTorch's GaussianBlur because we want to blur all channels; + # the PIL implementation will only blur the RGB channels. + # The PyTorch and PIL GaussianBlur APIs differ. + # Here, we follow SimCLR defaults for the kernel size. + trans.extend([transforms.Compose([ + random_resized_crop, + transforms.RandomHorizontalFlip(p=0.5), + transforms.Lambda(lambda ms_img: poverty_rgb_color_transform( + ms_img, + color_distortion)), + transforms.RandomApply( + [transforms.GaussianBlur( + kernel_size=23, # nearest odd number to image size (224) / 10 + sigma=(0.1,2))], + p=0.5) + ]) + ] * nmb_crops[i]) + else: + trans.extend([transforms.Compose([ + random_resized_crop, + transforms.RandomHorizontalFlip(p=0.5), + transforms.Compose(color_transform), + transforms.ToTensor(), + transforms.Normalize(mean=means, std=stds)]) + ] * nmb_crops[i]) + + self.trans = trans + + def __len__(self): + return len(self.ds) + + def __getitem__(self, index): + image = self.ds[index] + multi_crops = list(map(lambda trans: trans(image), self.trans)) + if self.return_index: + return index, multi_crops + return multi_crops + + +class PILRandomGaussianBlur(object): + """ + Apply Gaussian Blur to the PIL image. Take the radius and probability of + application as the parameter. 
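+    With the defaults below, each image is blurred with probability 0.5, using a radius
+    drawn uniformly from [0.1, 2.0].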
+ This transform was used in SimCLR - https://arxiv.org/abs/2002.05709 + """ + + def __init__(self, p=0.5, radius_min=0.1, radius_max=2.): + self.prob = p + self.radius_min = radius_min + self.radius_max = radius_max + + def __call__(self, img): + do_it = np.random.rand() <= self.prob + if not do_it: + return img + + return img.filter( + ImageFilter.GaussianBlur( + radius=random.uniform(self.radius_min, self.radius_max) + ) + ) + + +def get_color_distortion(s=1.0): + # s is the strength of color distortion. + color_jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s) + rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8) + rnd_gray = transforms.RandomGrayscale(p=0.2) + color_distort = transforms.Compose([rnd_color_jitter, rnd_gray]) + return color_distort diff --git a/examples/pretraining/swav/src/utils.py b/examples/pretraining/swav/src/utils.py new file mode 100644 index 00000000..8461245a --- /dev/null +++ b/examples/pretraining/swav/src/utils.py @@ -0,0 +1,308 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import argparse +from logging import getLogger +import pickle +import os +import pathlib +import random + +import numpy as np +import torch + +import matplotlib +import matplotlib.pyplot as plt + +from .logger import create_logger, PD_Stats + +import torch.distributed as dist + +import sys +sys.path.insert(1, os.path.join(sys.path[0], '../../..')) +sys.path.insert(1, os.path.join(sys.path[0], '../..')) +from examples.pretraining.swav.src.config import DATASET_DEFAULTS +from examples.configs.utils import populate_config + + +matplotlib.use('Agg') + +FALSY_STRINGS = {"off", "false", "0"} +TRUTHY_STRINGS = {"on", "true", "1"} + + +logger = getLogger() + + +def bool_flag(s): + """ + Parse boolean arguments from the command line. 
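+    For example, "on", "true", and "1" parse to True, while "off", "false", and "0"
+    parse to False; any other value raises an argparse.ArgumentTypeError.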
+ """ + if s.lower() in FALSY_STRINGS: + return False + elif s.lower() in TRUTHY_STRINGS: + return True + else: + raise argparse.ArgumentTypeError("invalid value for a boolean flag") + + +def init_distributed_mode(args): + """ + Initialize the following variables: + - world_size + - rank + """ + if args.cpu_only: + return + + if args.is_not_slurm_job: + args.is_slurm_job = False + else: + args.is_slurm_job = "SLURM_JOB_ID" in os.environ + + if args.is_slurm_job: + args.rank = int(os.environ["SLURM_PROCID"]) + args.world_size = int(os.environ["SLURM_NNODES"]) * int( + os.environ["SLURM_TASKS_PER_NODE"][0] + ) + else: + # multi-GPU job (local or multi-node) - jobs started with torch.distributed.launch + # read environment variables + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + + logger.info("\n" + "=" * 50) + logger.info(f"rank={args.rank}, world_size={args.world_size}") + logger.info("=" * 50) + + # prepare distributed + dist.init_process_group( + backend="nccl", + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + + # set cuda device + args.gpu_to_work_on = args.rank % torch.cuda.device_count() + torch.cuda.set_device(args.gpu_to_work_on) + return + + +def initialize_exp(params, *args, dump_params=True): + """ + Initialize the experience: + - dump parameters + - create checkpoint repo + - create a logger + - create a panda object to keep track of the training statistics + """ + + # dump parameters + if dump_params: + pickle.dump(params, open(os.path.join(params.log_dir, "params.pkl"), "wb")) + + # create repo to store checkpoints + params.dump_checkpoints = os.path.join(params.log_dir, "checkpoints") + if not params.rank and not os.path.isdir(params.dump_checkpoints): + os.mkdir(params.dump_checkpoints) + + # create a panda object to log loss and acc + training_stats = PD_Stats( + os.path.join(params.log_dir, "stats" + str(params.rank) + ".pkl"), args + ) + + # create a logger + logger = create_logger( + os.path.join(params.log_dir, "train.log"), rank=params.rank + ) + logger.info("============ Initialized logger ============") + logger.info( + "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(params)).items())) + ) + logger.info("The experiment will be stored in %s\n" % params.log_dir) + logger.info("") + return logger, training_stats + + +def restart_from_checkpoint(ckp_paths, run_variables=None, **kwargs): + """ + Re-start from checkpoint + """ + # look for a checkpoint in exp repository + if isinstance(ckp_paths, list): + for ckp_path in ckp_paths: + if os.path.isfile(ckp_path): + break + else: + ckp_path = ckp_paths + + if not os.path.isfile(ckp_path): + return + + logger.info("Found checkpoint at {}".format(ckp_path)) + + # open checkpoint file + checkpoint = torch.load( + ckp_path, map_location="cuda:" + str(torch.distributed.get_rank() % torch.cuda.device_count()) + ) + + # key is what to look for in the checkpoint file + # value is the object to load + # example: {'state_dict': model} + for key, value in kwargs.items(): + if key in checkpoint and value is not None: + try: + msg = value.load_state_dict(checkpoint[key], strict=False) + print(msg) + except TypeError: + msg = value.load_state_dict(checkpoint[key]) + logger.info("=> loaded {} from checkpoint '{}'".format(key, ckp_path)) + else: + logger.warning( + "=> failed to load {} from checkpoint '{}'".format(key, ckp_path) + ) + + # re load variable important for the run + if run_variables is not None: + for var_name in run_variables: + if 
var_name in checkpoint: + run_variables[var_name] = checkpoint[var_name] + + +def fix_random_seeds(seed=31): + """ + Fix random seeds. + """ + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +class AverageMeter(object): + """computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +# Taken from https://sumit-ghosh.com/articles/parsing-dictionary-key-value-pairs-kwargs-argparse-python/ +class ParseKwargs(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, dict()) + for value in values: + key, value_str = value.split('=') + if value_str.replace('-', '').isnumeric(): + processed_val = int(value_str) + elif value_str.replace('-', '').replace('.', '').isnumeric(): + processed_val = float(value_str) + elif value_str in ['True', 'true']: + processed_val = True + elif value_str in ['False', 'false']: + processed_val = False + else: + processed_val = value_str + getattr(namespace, self.dest)[key] = processed_val + + +def save_plot(df, key_name, plot_name, save_folder): + ''' + Saves a plot of a particular statistic provided by the dataframe. + + Parameters + ---------- + df : pandas.DataFrame + DataFrame whose columns are statistics from an experiment and whose + rows are epochs. + key_name : str + Column name in the DataFrame for the desired statistic to plot. + plot_name : str + Name to use for plot title, y-axis, and filename. + save_folder : Union[str, pathlib.Path] + Directory to save plots. + ''' + if key_name in df: + ax = df[key_name].plot() + ax.set_title(plot_name) + ax.set_xlabel('Epoch') + ax.set_ylabel(plot_name) + filename = f'{plot_name}.png' + ax.get_figure().savefig(pathlib.Path(save_folder) / filename) + + +def plot_experiment(log_dir): + ''' + Plots some statistics from the specified experiment and saves the plots. + + Parameters + ---------- + log_dir : Union[str, pathlib.Path] + Path containing the results of the experiment. Should have files of the + form stats*.pkl. + ''' + log_dir = pathlib.Path(log_dir) + df_list = [] + for filepath in log_dir.iterdir(): + filename = str(filepath.name) + if filename.startswith('stats') and filename.endswith('.pkl'): + with open(filepath, 'rb') as open_file: + df_list.append(pickle.load(open_file)) + avg_df = sum(df_list) / len(df_list) + + STAT_NAMES = [ + ('loss', 'Training Loss'), + ('prec1', 'Training Accuracy'), + ('prec1_val', 'Source Validation Accuracy'), + ('prec1_tgt', 'Target Accuracy') + ] + for stat in STAT_NAMES: + save_plot(avg_df, stat[0], stat[1], log_dir) + plt.close() + + +def populate_defaults_for_swav(config): + """ + Populate defaults for SwAV pretraining. 
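+    A rough usage sketch (the namespace fields shown are illustrative):
+        config = argparse.Namespace(dataset="camelyon17", ...)
+        config = populate_defaults_for_swav(config)
+    populate_config() is expected to fill in only the fields that are still unset,
+    so explicitly passed values take precedence over DATASET_DEFAULTS.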
+ """ + assert config.dataset is not None, 'dataset must be specified' + config = populate_config(config, DATASET_DEFAULTS[config.dataset]) + + # Sanity checks + assert config.warmup_epochs < config.n_epochs, \ + f'The number of warmup_epochs ({config.warmup_epochs}) cannot be greater than n_epochs ({config.n_epochs}).' + + return config \ No newline at end of file diff --git a/examples/run_expt.py b/examples/run_expt.py index c533c35c..4d287b46 100644 --- a/examples/run_expt.py +++ b/examples/run_expt.py @@ -1,4 +1,4 @@ -import os, csv +import os import time import argparse import pandas as pd @@ -8,22 +8,33 @@ import sys from collections import defaultdict +try: + import wandb +except Exception as e: + pass + import wilds from wilds.common.data_loaders import get_train_loader, get_eval_loader from wilds.common.grouper import CombinatorialGrouper +from wilds.datasets.unlabeled.wilds_unlabeled_dataset import WILDSPseudolabeledSubset -from utils import set_seed, Logger, BatchLogger, log_config, ParseKwargs, load, initialize_wandb, log_group_data, parse_bool, get_model_prefix -from train import train, evaluate -from algorithms.initializer import initialize_algorithm +from utils import set_seed, Logger, BatchLogger, log_config, ParseKwargs, load, initialize_wandb, log_group_data, parse_bool, get_model_prefix, move_to +from train import train, evaluate, infer_predictions +from algorithms.initializer import initialize_algorithm, infer_d_out from transforms import initialize_transform +from models.initializer import initialize_model from configs.utils import populate_defaults import configs.supported as supported import torch.multiprocessing -def main(): +# Necessary for large images of GlobalWheat +from PIL import ImageFile +ImageFile.LOAD_TRUNCATED_IMAGES = True - ''' to see default hyperparams for each dataset/model, look at configs/ ''' +def main(): + + ''' Arg defaults are filled in according to examples/configs/ ''' parser = argparse.ArgumentParser() # Required arguments @@ -34,45 +45,78 @@ def main(): # Dataset parser.add_argument('--split_scheme', help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.') - parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={}) + parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={}, + help='keyword arguments for dataset initialization passed as key1=value1 key2=value2') parser.add_argument('--download', default=False, type=parse_bool, const=True, nargs='?', - help='If true, tries to downloads the dataset if it does not exist in root_dir.') + help='If true, tries to download the dataset if it does not exist in root_dir.') parser.add_argument('--frac', type=float, default=1.0, help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes. Note that this also scales the test set down, so the reported numbers are not comparable with the full test set.') - parser.add_argument('--version', default=None, type=str) + parser.add_argument('--version', default=None, type=str, help='WILDS labeled dataset version number.') + + # Unlabeled Dataset + parser.add_argument('--unlabeled_split', default=None, type=str, choices=wilds.unlabeled_splits, help='Unlabeled split to use. 
Some datasets only have some splits available.')
+    parser.add_argument('--unlabeled_version', default=None, type=str, help='WILDS unlabeled dataset version number.')
+    parser.add_argument('--use_unlabeled_y', default=False, type=parse_bool, const=True, nargs='?',
+                        help='If true, unlabeled loaders will also load the true labels for the unlabeled data. This is only available for some datasets. Used for "fully-labeled ERM experiments" in the paper. Correct functionality relies on CrossEntropyLoss using ignore_index=-100.')

     # Loaders
     parser.add_argument('--loader_kwargs', nargs='*', action=ParseKwargs, default={})
+    parser.add_argument('--unlabeled_loader_kwargs', nargs='*', action=ParseKwargs, default={})
     parser.add_argument('--train_loader', choices=['standard', 'group'])
-    parser.add_argument('--uniform_over_groups', type=parse_bool, const=True, nargs='?')
-    parser.add_argument('--distinct_groups', type=parse_bool, const=True, nargs='?')
+    parser.add_argument('--uniform_over_groups', type=parse_bool, const=True, nargs='?', help='If true, sample examples such that batches are uniform over groups.')
+    parser.add_argument('--distinct_groups', type=parse_bool, const=True, nargs='?', help='If true, enforce groups sampled per batch are distinct.')
     parser.add_argument('--n_groups_per_batch', type=int)
+    parser.add_argument('--unlabeled_n_groups_per_batch', type=int)
     parser.add_argument('--batch_size', type=int)
+    parser.add_argument('--unlabeled_batch_size', type=int)
     parser.add_argument('--eval_loader', choices=['standard'], default='standard')
+    parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help='Number of batches to process before stepping optimizer and schedulers. If > 1, we simulate having a larger effective batch size (though batchnorm behaves differently).')

     # Model
     parser.add_argument('--model', choices=supported.models)
     parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},
-        help='keyword arguments for model initialization passed as key1=value1 key2=value2')
+        help='keyword arguments for model initialization passed as key1=value1 key2=value2')
+    parser.add_argument('--noisystudent_add_dropout', type=parse_bool, const=True, nargs='?', help='If true, adds a dropout layer to the student model of NoisyStudent.')
+    parser.add_argument('--noisystudent_dropout_rate', type=float)
+    parser.add_argument('--pretrained_model_path', default=None, type=str, help='Specify a path to pretrained model weights.')
+    parser.add_argument('--load_featurizer_only', default=False, type=parse_bool, const=True, nargs='?', help='If true, only loads the featurizer weights and not the classifier weights.')
+
+    # NoisyStudent-specific loading
+    parser.add_argument('--teacher_model_path', type=str, help='Path to NoisyStudent teacher model weights. If this is defined, pseudolabels will first be computed for unlabeled data before anything else runs.')

     # Transforms
     parser.add_argument('--transform', choices=supported.transforms)
+    parser.add_argument('--additional_train_transform', choices=supported.additional_transforms, help='Optional data augmentations to layer on top of the default transforms.')
     parser.add_argument('--target_resolution', nargs='+', type=int, help='The input resolution that images will be resized to before being passed into the model. 
For example, use --target_resolution 224 224 for a standard ResNet.') parser.add_argument('--resize_scale', type=float) parser.add_argument('--max_token_length', type=int) + parser.add_argument('--randaugment_n', type=int, help='Number of RandAugment transformations to apply.') # Objective - parser.add_argument('--loss_function', choices = supported.losses) + parser.add_argument('--loss_function', choices=supported.losses) parser.add_argument('--loss_kwargs', nargs='*', action=ParseKwargs, default={}, - help='keyword arguments for loss initialization passed as key1=value1 key2=value2') + help='keyword arguments for loss initialization passed as key1=value1 key2=value2') # Algorithm parser.add_argument('--groupby_fields', nargs='+') parser.add_argument('--group_dro_step_size', type=float) parser.add_argument('--coral_penalty_weight', type=float) + parser.add_argument('--dann_penalty_weight', type=float) + parser.add_argument('--dann_classifier_lr', type=float) + parser.add_argument('--dann_featurizer_lr', type=float) + parser.add_argument('--dann_discriminator_lr', type=float) + parser.add_argument('--afn_penalty_weight', type=float) + parser.add_argument('--safn_delta_r', type=float) + parser.add_argument('--hafn_r', type=float) + parser.add_argument('--use_hafn', default=False, type=parse_bool, const=True, nargs='?') parser.add_argument('--irm_lambda', type=float) parser.add_argument('--irm_penalty_anneal_iters', type=int) + parser.add_argument('--self_training_lambda', type=float) + parser.add_argument('--self_training_threshold', type=float) + parser.add_argument('--pseudolabel_T2', type=float, help='Percentage of total iterations at which to end linear scheduling and hold lambda at the max value') + parser.add_argument('--soft_pseudolabels', default=False, type=parse_bool, const=True, nargs='?') parser.add_argument('--algo_log_metric') + parser.add_argument('--process_pseudolabels_function', choices=supported.process_pseudolabels_functions) # Model selection parser.add_argument('--val_metric') @@ -84,11 +128,13 @@ def main(): parser.add_argument('--lr', type=float) parser.add_argument('--weight_decay', type=float) parser.add_argument('--max_grad_norm', type=float) - parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={}) + parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={}, + help='keyword arguments for optimizer initialization passed as key1=value1 key2=value2') # Scheduler parser.add_argument('--scheduler', choices=supported.schedulers) - parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={}) + parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={}, + help='keyword arguments for scheduler initialization passed as key1=value1 key2=value2') parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val') parser.add_argument('--scheduler_metric_name') @@ -100,7 +146,7 @@ def main(): parser.add_argument('--eval_epoch', default=None, type=int, help='If eval_only is set, then eval_epoch allows you to specify evaluating at a particular epoch. 
By default, it evaluates the best epoch by validation performance.') # Misc - parser.add_argument('--device', type=int, default=0) + parser.add_argument('--device', type=int, nargs='+', default=[0]) parser.add_argument('--seed', type=int, default=0) parser.add_argument('--log_dir', default='./logs') parser.add_argument('--log_every', default=50, type=int) @@ -109,9 +155,15 @@ def main(): parser.add_argument('--save_last', type=parse_bool, const=True, nargs='?', default=True) parser.add_argument('--save_pred', type=parse_bool, const=True, nargs='?', default=True) parser.add_argument('--no_group_logging', type=parse_bool, const=True, nargs='?') - parser.add_argument('--use_wandb', type=parse_bool, const=True, nargs='?', default=False) parser.add_argument('--progress_bar', type=parse_bool, const=True, nargs='?', default=False) - parser.add_argument('--resume', type=parse_bool, const=True, nargs='?', default=False) + parser.add_argument('--resume', type=parse_bool, const=True, nargs='?', default=False, help='Whether to resume from the most recent saved model in the current log_dir.') + + # Weights & Biases + parser.add_argument('--use_wandb', type=parse_bool, const=True, nargs='?', default=False) + parser.add_argument('--wandb_api_key_path', type=str, + help="Path to Weights & Biases API Key. If use_wandb is set to True and this argument is not specified, user will be prompted to authenticate.") + parser.add_argument('--wandb_kwargs', nargs='*', action=ParseKwargs, default={}, + help='keyword arguments for wandb.init() passed as key1=value1 key2=value2') config = parser.parse_args() config = populate_defaults(config) @@ -123,7 +175,18 @@ def main(): torch.multiprocessing.set_sharing_strategy('file_system') # Set device - config.device = torch.device("cuda:" + str(config.device)) if torch.cuda.is_available() else torch.device("cpu") + if torch.cuda.is_available(): + device_count = torch.cuda.device_count() + if len(config.device) > device_count: + raise ValueError(f"Specified {len(config.device)} devices, but only {device_count} devices found.") + + config.use_data_parallel = len(config.device) > 1 + device_str = ",".join(map(str, config.device)) + os.environ["CUDA_VISIBLE_DEVICES"] = device_str + config.device = torch.device("cuda") + else: + config.use_data_parallel = False + config.device = torch.device("cpu") # Initialize logs if os.path.exists(config.log_dir) and config.resume: @@ -155,6 +218,7 @@ def main(): split_scheme=config.split_scheme, **config.dataset_kwargs) + # Transforms & data augmentations for labeled dataset # To modify data augmentation, modify the following code block. # If you want to use transforms that modify both `x` and `y`, # set `do_transform_y` to True when initializing the `WILDSSubset` below. 
@@ -162,6 +226,7 @@ def main(): transform_name=config.transform, config=config, dataset=full_dataset, + additional_transform_name=config.additional_train_transform, is_training=True) eval_transform = initialize_transform( transform_name=config.transform, @@ -169,10 +234,104 @@ def main(): dataset=full_dataset, is_training=False) - train_grouper = CombinatorialGrouper( - dataset=full_dataset, - groupby_fields=config.groupby_fields) + # Configure unlabeled datasets + unlabeled_dataset = None + if config.unlabeled_split is not None: + split = config.unlabeled_split + full_unlabeled_dataset = wilds.get_dataset( + dataset=config.dataset, + version=config.unlabeled_version, + root_dir=config.root_dir, + download=config.download, + unlabeled=True, + **config.dataset_kwargs + ) + train_grouper = CombinatorialGrouper( + dataset=[full_dataset, full_unlabeled_dataset], + groupby_fields=config.groupby_fields + ) + + # Transforms & data augmentations for unlabeled dataset + if config.algorithm == "FixMatch": + # For FixMatch, we need our loader to return batches in the form ((x_weak, x_strong), m) + # We do this by initializing a special transform function + unlabeled_train_transform = initialize_transform( + config.transform, config, full_dataset, is_training=True, additional_transform_name="fixmatch" + ) + else: + # Otherwise, use the same data augmentations as the labeled data. + unlabeled_train_transform = train_transform + + if config.algorithm == "NoisyStudent": + # For Noisy Student, we need to first generate pseudolabels using the teacher + # and then prep the unlabeled dataset to return these pseudolabels in __getitem__ + print("Inferring teacher pseudolabels for Noisy Student") + assert config.teacher_model_path is not None + if not config.teacher_model_path.endswith(".pth"): + # Use the best model + config.teacher_model_path = os.path.join( + config.teacher_model_path, f"{config.dataset}_seed:{config.seed}_epoch:best_model.pth" + ) + + d_out = infer_d_out(full_dataset, config) + teacher_model = initialize_model(config, d_out).to(config.device) + load(teacher_model, config.teacher_model_path, device=config.device) + # Infer teacher outputs on weakly augmented unlabeled examples in sequential order + weak_transform = initialize_transform( + transform_name=config.transform, + config=config, + dataset=full_dataset, + is_training=True, + additional_transform_name="weak" + ) + unlabeled_split_dataset = full_unlabeled_dataset.get_subset(split, transform=weak_transform, frac=config.frac) + sequential_loader = get_eval_loader( + loader=config.eval_loader, + dataset=unlabeled_split_dataset, + grouper=train_grouper, + batch_size=config.unlabeled_batch_size, + **config.unlabeled_loader_kwargs + ) + teacher_outputs = infer_predictions(teacher_model, sequential_loader, config) + teacher_outputs = move_to(teacher_outputs, torch.device("cpu")) + unlabeled_split_dataset = WILDSPseudolabeledSubset( + reference_subset=unlabeled_split_dataset, + pseudolabels=teacher_outputs, + transform=unlabeled_train_transform, + collate=full_dataset.collate, + ) + teacher_model = teacher_model.to(torch.device("cpu")) + del teacher_model + else: + unlabeled_split_dataset = full_unlabeled_dataset.get_subset( + split, + transform=unlabeled_train_transform, + frac=config.frac, + load_y=config.use_unlabeled_y + ) + + unlabeled_dataset = { + 'split': split, + 'name': full_unlabeled_dataset.split_names[split], + 'dataset': unlabeled_split_dataset + } + unlabeled_dataset['loader'] = get_train_loader( + loader=config.train_loader, + 
dataset=unlabeled_dataset['dataset'], + batch_size=config.unlabeled_batch_size, + uniform_over_groups=config.uniform_over_groups, + grouper=train_grouper, + distinct_groups=config.distinct_groups, + n_groups_per_batch=config.unlabeled_n_groups_per_batch, + **config.unlabeled_loader_kwargs + ) + else: + train_grouper = CombinatorialGrouper( + dataset=full_dataset, + groupby_fields=config.groupby_fields + ) + # Configure labeled torch datasets (WILDS dataset splits) datasets = defaultdict(dict) for split in full_dataset.split_dict.keys(): if split=='train': @@ -215,12 +374,14 @@ def main(): # Loggers datasets[split]['eval_logger'] = BatchLogger( - os.path.join(config.log_dir, f'{split}_eval.csv'), mode=mode, use_wandb=(config.use_wandb and verbose)) + os.path.join(config.log_dir, f'{split}_eval.csv'), mode=mode, use_wandb=config.use_wandb + ) datasets[split]['algo_logger'] = BatchLogger( - os.path.join(config.log_dir, f'{split}_algo.csv'), mode=mode, use_wandb=(config.use_wandb and verbose)) + os.path.join(config.log_dir, f'{split}_algo.csv'), mode=mode, use_wandb=config.use_wandb + ) - if config.use_wandb: - initialize_wandb(config) + if config.use_wandb: + initialize_wandb(config) # Logging dataset info # Show class breakdown if feasible @@ -233,16 +394,20 @@ def main(): else: log_grouper = train_grouper log_group_data(datasets, log_grouper, logger) + if unlabeled_dataset is not None: + log_group_data({"unlabeled": unlabeled_dataset}, log_grouper, logger) - ## Initialize algorithm + # Initialize algorithm & load pretrained weights if provided algorithm = initialize_algorithm( config=config, datasets=datasets, - train_grouper=train_grouper) + train_grouper=train_grouper, + unlabeled_dataset=unlabeled_dataset, + ) model_prefix = get_model_prefix(datasets['train'], config) if not config.eval_only: - ## Load saved results if resuming + # Resume from most recent model in log_dir resume_success = False if resume: save_path = model_prefix + 'epoch:last_model.pth' @@ -254,30 +419,41 @@ def main(): latest_epoch = max(epochs) save_path = model_prefix + f'epoch:{latest_epoch}_model.pth' try: - prev_epoch, best_val_metric = load(algorithm, save_path) + prev_epoch, best_val_metric = load(algorithm, save_path, device=config.device) epoch_offset = prev_epoch + 1 logger.write(f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}') resume_success = True except FileNotFoundError: pass - if resume_success == False: epoch_offset=0 best_val_metric=None + # Log effective batch size + if config.gradient_accumulation_steps > 1: + logger.write( + (f'\nUsing gradient_accumulation_steps {config.gradient_accumulation_steps} means that') + + (f' the effective labeled batch size is {config.batch_size * config.gradient_accumulation_steps}') + + (f' and the effective unlabeled batch size is {config.unlabeled_batch_size * config.gradient_accumulation_steps}' + if unlabeled_dataset and config.unlabeled_batch_size else '') + + ('. 
Updates behave as if torch loaders have drop_last=False\n')
+        )
+
         train(
             algorithm=algorithm,
             datasets=datasets,
             general_logger=logger,
             config=config,
             epoch_offset=epoch_offset,
-            best_val_metric=best_val_metric)
+            best_val_metric=best_val_metric,
+            unlabeled_dataset=unlabeled_dataset,
+        )
     else:
         if config.eval_epoch is None:
             eval_model_path = model_prefix + 'epoch:best_model.pth'
         else:
             eval_model_path = model_prefix + f'epoch:{config.eval_epoch}_model.pth'
-        best_epoch, best_val_metric = load(algorithm, eval_model_path)
+        best_epoch, best_val_metric = load(algorithm, eval_model_path, device=config.device)
         if config.eval_epoch is None:
             epoch = best_epoch
         else:
@@ -292,6 +468,8 @@ def main():
             config=config,
             is_best=is_best)

+    if config.use_wandb:
+        wandb.finish()
     logger.close()
     for split in datasets:
         datasets[split]['eval_logger'].close()
diff --git a/examples/scheduler.py b/examples/scheduler.py
index 9b927d9c..4cb4aea1 100644
--- a/examples/scheduler.py
+++ b/examples/scheduler.py
@@ -1,12 +1,11 @@
-from transformers import (get_linear_schedule_with_warmup,
-    get_cosine_schedule_with_warmup)
-from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR, MultiStepLR
+from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau, StepLR, CosineAnnealingLR, MultiStepLR

 def initialize_scheduler(config, optimizer, n_train_steps):
     # construct schedulers
     if config.scheduler is None:
         return None
-    elif config.scheduler=='linear_schedule_with_warmup':
+    elif config.scheduler == 'linear_schedule_with_warmup':
+        from transformers import get_linear_schedule_with_warmup
         scheduler = get_linear_schedule_with_warmup(
             optimizer,
             num_training_steps=n_train_steps,
@@ -14,6 +13,7 @@ def initialize_scheduler(config, optimizer, n_train_steps):
         step_every_batch = True
         use_metric = False
     elif config.scheduler == 'cosine_schedule_with_warmup':
+        from transformers import get_cosine_schedule_with_warmup
         scheduler = get_cosine_schedule_with_warmup(
             optimizer,
             num_training_steps=n_train_steps,
@@ -31,13 +31,21 @@ def initialize_scheduler(config, optimizer, n_train_steps):
         scheduler = StepLR(optimizer, **config.scheduler_kwargs)
         step_every_batch = False
         use_metric = False
+    elif config.scheduler == 'FixMatchLR':
+        scheduler = LambdaLR(
+            optimizer,
+            lambda x: (1.0 + 10 * float(x) / n_train_steps) ** -0.75
+        )
+        step_every_batch = True
+        use_metric = False
     elif config.scheduler == 'MultiStepLR':
         scheduler = MultiStepLR(optimizer, **config.scheduler_kwargs)
         step_every_batch = False
         use_metric = False
     else:
-        raise ValueError('Scheduler not recognized.')
-    # add a step_every_batch field
+        raise ValueError(f'Scheduler: {config.scheduler} not supported.')
+
+    # add a step_every_batch field
     scheduler.step_every_batch = step_every_batch
     scheduler.use_metric = use_metric
     return scheduler
@@ -48,3 +56,38 @@ def step_scheduler(scheduler, metric=None):
         scheduler.step(metric)
     else:
         scheduler.step()
+
+class LinearScheduleWithWarmupAndThreshold():
+    """
+    Linear scheduler with warmup and threshold for non-LR parameters.
+    The parameter is held at 0 until some T1, linearly increased until T2, and then held
+    at some max value after T2.
+    Designed to be called by step_scheduler() above and used within Algorithm class.
+    Args:
+        - last_warmup_step: aka T1. for steps [0, T1) keep param = 0
+        - threshold_step: aka T2. 
step over period [T1, T2) to reach param = max value + - max value: end value of the param + """ + def __init__(self, max_value, last_warmup_step=0, threshold_step=1, step_every_batch=False): + self.max_value = max_value + self.T1 = last_warmup_step + self.T2 = threshold_step + assert (0 <= self.T1) and (self.T1 < self.T2) + + # internal tracker of which step we're on + self.current_step = 0 + self.value = 0 + + # required fields called in Algorithm when stepping schedulers + self.step_every_batch = step_every_batch + self.use_metric = False + + def step(self): + """This function is first called AFTER step 0, so increment first to set value for next step""" + self.current_step += 1 + if self.current_step < self.T1: + self.value = 0 + elif self.current_step < self.T2: + self.value = (self.current_step - self.T1) / (self.T2 - self.T1) * self.max_value + else: + self.value = self.max_value diff --git a/examples/train.py b/examples/train.py index c1caa3fd..c0a0fdb2 100644 --- a/examples/train.py +++ b/examples/train.py @@ -1,11 +1,12 @@ -import os -from tqdm import tqdm +import copy import torch -from utils import save_model, save_pred, get_pred_prefix, get_model_prefix, detach_and_clone, collate_list -from configs.supported import process_outputs_functions +from tqdm import tqdm +import math -def run_epoch(algorithm, dataset, general_logger, epoch, config, train): +from configs.supported import process_outputs_functions, process_pseudolabels_functions +from utils import save_model, save_pred, get_pred_prefix, get_model_prefix, collate_list, detach_and_clone, InfiniteDataIterator +def run_epoch(algorithm, dataset, general_logger, epoch, config, train, unlabeled_dataset=None): if dataset['verbose']: general_logger.write(f"\n{dataset['name']}:\n") @@ -23,16 +24,31 @@ def run_epoch(algorithm, dataset, general_logger, epoch, config, train): epoch_y_pred = [] epoch_metadata = [] + # Assert that data loaders are defined for the datasets + assert 'loader' in dataset, "A data loader must be defined for the dataset." + if unlabeled_dataset: + assert 'loader' in unlabeled_dataset, "A data loader must be defined for the dataset." + + batches = dataset['loader'] + if config.progress_bar: + batches = tqdm(batches) + last_batch_idx = len(batches)-1 + + if unlabeled_dataset: + unlabeled_data_iterator = InfiniteDataIterator(unlabeled_dataset['loader']) + # Using enumerate(iterator) can sometimes leak memory in some environments (!) 
# so we manually increment batch_idx batch_idx = 0 - iterator = tqdm(dataset['loader']) if config.progress_bar else dataset['loader'] - - for batch in iterator: + for labeled_batch in batches: if train: - batch_results = algorithm.update(batch) + if unlabeled_dataset: + unlabeled_batch = next(unlabeled_data_iterator) + batch_results = algorithm.update(labeled_batch, unlabeled_batch, is_epoch_end=(batch_idx==last_batch_idx)) + else: + batch_results = algorithm.update(labeled_batch, is_epoch_end=(batch_idx==last_batch_idx)) else: - batch_results = algorithm.evaluate(batch) + batch_results = algorithm.evaluate(labeled_batch) # These tensors are already detached, but we need to clone them again # Otherwise they don't get garbage collected properly in some versions @@ -45,8 +61,13 @@ def run_epoch(algorithm, dataset, general_logger, epoch, config, train): epoch_y_pred.append(y_pred) epoch_metadata.append(detach_and_clone(batch_results['metadata'])) - if train and (batch_idx+1) % config.log_every==0: - log_results(algorithm, dataset, general_logger, epoch, batch_idx) + if train: + effective_batch_idx = (batch_idx + 1) / config.gradient_accumulation_steps + else: + effective_batch_idx = batch_idx + 1 + + if train and effective_batch_idx % config.log_every==0: + log_results(algorithm, dataset, general_logger, epoch, math.ceil(effective_batch_idx)) batch_idx += 1 @@ -66,7 +87,7 @@ def run_epoch(algorithm, dataset, general_logger, epoch, config, train): log_access=(not train)) # log after updating the scheduler in case it needs to access the internal logs - log_results(algorithm, dataset, general_logger, epoch, batch_idx) + log_results(algorithm, dataset, general_logger, epoch, math.ceil(effective_batch_idx)) results['epoch'] = epoch dataset['eval_logger'].log(results) @@ -77,12 +98,20 @@ def run_epoch(algorithm, dataset, general_logger, epoch, config, train): return results, epoch_y_pred -def train(algorithm, datasets, general_logger, config, epoch_offset, best_val_metric): +def train(algorithm, datasets, general_logger, config, epoch_offset, best_val_metric, unlabeled_dataset=None): + """ + Train loop that, each epoch: + - Steps an algorithm on the datasets['train'] split and the unlabeled split + - Evaluates the algorithm on the datasets['val'] split + - Saves models / preds with frequency according to the configs + - Evaluates on any other specified splits in the configs + Assumes that the datasets dict contains labeled data. + """ for epoch in range(epoch_offset, config.n_epochs): general_logger.write('\nEpoch [%d]:\n' % epoch) # First run training - run_epoch(algorithm, datasets['train'], general_logger, epoch, config, train=True) + run_epoch(algorithm, datasets['train'], general_logger, epoch, config, train=True, unlabeled_dataset=unlabeled_dataset) # Then run val val_results, y_pred = run_epoch(algorithm, datasets['val'], general_logger, epoch, config, train=False) @@ -151,12 +180,38 @@ def evaluate(algorithm, datasets, epoch, general_logger, config, is_best): if split != 'train': save_pred_if_needed(epoch_y_pred, dataset, epoch, config, is_best, force_save=True) +def infer_predictions(model, loader, config): + """ + Simple inference loop that performs inference using a model (not algorithm) and returns model outputs. + Compatible with both labeled and unlabeled WILDS datasets. 
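+    By default, outputs are converted to hard pseudolabels via
+    config.process_pseudolabels_function; if config.soft_pseudolabels is set, softmax
+    probabilities are returned instead (e.g. for NoisyStudent-style distillation).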
+ """ + model.eval() + y_pred = [] + iterator = tqdm(loader) if config.progress_bar else loader + for batch in iterator: + x = batch[0] + x = x.to(config.device) + with torch.no_grad(): + output = model(x) + if not config.soft_pseudolabels and config.process_pseudolabels_function is not None: + _, output, _, _ = process_pseudolabels_functions[config.process_pseudolabels_function]( + output, + confidence_threshold=config.self_training_threshold if config.dataset == 'globalwheat' else 0 + ) + elif config.soft_pseudolabels: + output = torch.nn.functional.softmax(output, dim=1) + if isinstance(output, list): + y_pred.extend(detach_and_clone(output)) + else: + y_pred.append(detach_and_clone(output)) + + return torch.cat(y_pred, 0) if torch.is_tensor(y_pred[0]) else y_pred -def log_results(algorithm, dataset, general_logger, epoch, batch_idx): +def log_results(algorithm, dataset, general_logger, epoch, effective_batch_idx): if algorithm.has_log: log = algorithm.get_log() log['epoch'] = epoch - log['batch'] = batch_idx + log['batch'] = effective_batch_idx dataset['algo_logger'].log(log) if dataset['verbose']: general_logger.write(algorithm.get_pretty_log_str()) diff --git a/examples/transforms.py b/examples/transforms.py index a997110d..c1ca1a34 100644 --- a/examples/transforms.py +++ b/examples/transforms.py @@ -1,120 +1,128 @@ -import random +import copy +from typing import List +import numpy as np +import torch import torchvision.transforms as transforms import torchvision.transforms.functional as TF from transformers import BertTokenizerFast, DistilBertTokenizerFast -import torch -def initialize_transform(transform_name, config, dataset, is_training): +from data_augmentation.randaugment import FIX_MATCH_AUGMENTATION_POOL, RandAugment + + +_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [0.485, 0.456, 0.406] +_DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [0.229, 0.224, 0.225] + + +def initialize_transform( + transform_name, config, dataset, is_training, additional_transform_name=None +): """ By default, transforms should take in `x` and return `transformed_x`. For transforms that take in `(x, y)` and return `(transformed_x, transformed_y)`, - set `do_transform_y` to True when initializing the WILDSSubset. + set `do_transform_y` to True when initializing the WILDSSubset. 
""" if transform_name is None: return None - elif transform_name=='bert': + elif transform_name == "bert": return initialize_bert_transform(config) - elif transform_name=='image_base': - return initialize_image_base_transform(config, dataset) - elif transform_name=='image_resize_and_center_crop': - return initialize_image_resize_and_center_crop_transform(config, dataset) - elif transform_name=='poverty': - return initialize_poverty_transform(is_training) - elif transform_name=='rxrx1': + elif transform_name == 'rxrx1': return initialize_rxrx1_transform(is_training) + + # For images + normalize = True + if transform_name == "image_base": + transform_steps = get_image_base_transform_steps(config, dataset) + elif transform_name == "image_resize": + transform_steps = get_image_resize_transform_steps( + config, dataset + ) + elif transform_name == "image_resize_and_center_crop": + transform_steps = get_image_resize_and_center_crop_transform_steps( + config, dataset + ) + elif transform_name == "poverty": + if not is_training: + return None + transform_steps = [] + normalize = False else: raise ValueError(f"{transform_name} not recognized") + default_normalization = transforms.Normalize( + _DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN, + _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD, + ) + if additional_transform_name == "fixmatch": + if transform_name == 'poverty': + transformations = add_poverty_fixmatch_transform(config, dataset, transform_steps) + else: + transformations = add_fixmatch_transform( + config, dataset, transform_steps, default_normalization + ) + transform = MultipleTransforms(transformations) + elif additional_transform_name == "randaugment": + if transform_name == 'poverty': + transform = add_poverty_rand_augment_transform( + config, dataset, transform_steps + ) + else: + transform = add_rand_augment_transform( + config, dataset, transform_steps, default_normalization + ) + elif additional_transform_name == "weak": + transform = add_weak_transform( + config, dataset, transform_steps, normalize, default_normalization + ) + else: + if transform_name != "poverty": + # The poverty data is already a tensor at this point + transform_steps.append(transforms.ToTensor()) + if normalize: + transform_steps.append(default_normalization) + transform = transforms.Compose(transform_steps) + + return transform + + def initialize_bert_transform(config): - assert 'bert' in config.model + def get_bert_tokenizer(model): + if model == "bert-base-uncased": + return BertTokenizerFast.from_pretrained(model) + elif model == "distilbert-base-uncased": + return DistilBertTokenizerFast.from_pretrained(model) + else: + raise ValueError(f"Model: {model} not recognized.") + + assert "bert" in config.model assert config.max_token_length is not None - tokenizer = getBertTokenizer(config.model) + tokenizer = get_bert_tokenizer(config.model) + def transform(text): tokens = tokenizer( text, - padding='max_length', + padding="max_length", truncation=True, max_length=config.max_token_length, - return_tensors='pt') - if config.model == 'bert-base-uncased': - x = torch.stack( - (tokens['input_ids'], - tokens['attention_mask'], - tokens['token_type_ids']), - dim=2) - elif config.model == 'distilbert-base-uncased': + return_tensors="pt", + ) + if config.model == "bert-base-uncased": x = torch.stack( - (tokens['input_ids'], - tokens['attention_mask']), - dim=2) - x = torch.squeeze(x, dim=0) # First shape dim is always 1 + ( + tokens["input_ids"], + tokens["attention_mask"], + tokens["token_type_ids"], + ), + dim=2, + ) + elif 
config.model == "distilbert-base-uncased": + x = torch.stack((tokens["input_ids"], tokens["attention_mask"]), dim=2) + x = torch.squeeze(x, dim=0) # First shape dim is always 1 return x - return transform - -def getBertTokenizer(model): - if model == 'bert-base-uncased': - tokenizer = BertTokenizerFast.from_pretrained(model) - elif model == 'distilbert-base-uncased': - tokenizer = DistilBertTokenizerFast.from_pretrained(model) - else: - raise ValueError(f'Model: {model} not recognized.') - return tokenizer - -def initialize_image_base_transform(config, dataset): - transform_steps = [] - if dataset.original_resolution is not None and min(dataset.original_resolution)!=max(dataset.original_resolution): - crop_size = min(dataset.original_resolution) - transform_steps.append(transforms.CenterCrop(crop_size)) - if config.target_resolution is not None and config.dataset!='fmow': - transform_steps.append(transforms.Resize(config.target_resolution)) - transform_steps += [ - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) - ] - transform = transforms.Compose(transform_steps) - return transform - -def initialize_image_resize_and_center_crop_transform(config, dataset): - """ - Resizes the image to a slightly larger square then crops the center. - """ - assert dataset.original_resolution is not None - assert config.resize_scale is not None - scaled_resolution = tuple(int(res*config.resize_scale) for res in dataset.original_resolution) - if config.target_resolution is not None: - target_resolution = config.target_resolution - else: - target_resolution = dataset.original_resolution - transform = transforms.Compose([ - transforms.Resize(scaled_resolution), - transforms.CenterCrop(target_resolution), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) - ]) return transform -def initialize_poverty_transform(is_training): - if is_training: - transforms_ls = [ - transforms.ToPILImage(), - transforms.RandomHorizontalFlip(), - transforms.RandomVerticalFlip(), - transforms.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.1), - transforms.ToTensor()] - rgb_transform = transforms.Compose(transforms_ls) - - def transform_rgb(img): - # bgr to rgb and back to bgr - img[:3] = rgb_transform(img[:3][[2,1,0]])[[2,1,0]] - return img - transform = transforms.Lambda(lambda x: transform_rgb(x)) - return transform - else: - return None - def initialize_rxrx1_transform(is_training): def standardize(x: torch.Tensor) -> torch.Tensor: mean = x.mean(dim=(1, 2)) @@ -145,3 +153,171 @@ def random_rotation(x: torch.Tensor) -> torch.Tensor: ] transform = transforms.Compose(transforms_ls) return transform + +def get_image_base_transform_steps(config, dataset) -> List: + transform_steps = [] + + if dataset.original_resolution is not None and min( + dataset.original_resolution + ) != max(dataset.original_resolution): + crop_size = min(dataset.original_resolution) + transform_steps.append(transforms.CenterCrop(crop_size)) + + if config.target_resolution is not None: + transform_steps.append(transforms.Resize(config.target_resolution)) + + return transform_steps + + +def get_image_resize_and_center_crop_transform_steps(config, dataset) -> List: + """ + Resizes the image to a slightly larger square then crops the center. 
+ """ + transform_steps = get_image_resize_transform_steps(config, dataset) + target_resolution = _get_target_resolution(config, dataset) + transform_steps.append( + transforms.CenterCrop(target_resolution), + ) + return transform_steps + + +def get_image_resize_transform_steps(config, dataset) -> List: + """ + Resizes the image to a slightly larger square. + """ + assert dataset.original_resolution is not None + assert config.resize_scale is not None + scaled_resolution = tuple( + int(res * config.resize_scale) for res in dataset.original_resolution + ) + return [ + transforms.Resize(scaled_resolution) + ] + +def add_fixmatch_transform(config, dataset, base_transform_steps, normalization): + return ( + add_weak_transform(config, dataset, base_transform_steps, True, normalization), + add_rand_augment_transform(config, dataset, base_transform_steps, normalization) + ) + +def add_poverty_fixmatch_transform(config, dataset, base_transform_steps): + return ( + add_weak_transform(config, dataset, base_transform_steps, False, None), + add_poverty_rand_augment_transform(config, dataset, base_transform_steps) + ) + +def add_weak_transform(config, dataset, base_transform_steps, should_normalize, normalization): + # Adapted from https://github.com/YBZh/Bridging_UDA_SSL + target_resolution = _get_target_resolution(config, dataset) + weak_transform_steps = copy.deepcopy(base_transform_steps) + weak_transform_steps.extend( + [ + transforms.RandomHorizontalFlip(), + transforms.RandomCrop( + size=target_resolution, + ), + ] + ) + if should_normalize: + weak_transform_steps.append(transforms.ToTensor()) + weak_transform_steps.append(normalization) + return transforms.Compose(weak_transform_steps) + +def add_rand_augment_transform(config, dataset, base_transform_steps, normalization): + # Adapted from https://github.com/YBZh/Bridging_UDA_SSL + target_resolution = _get_target_resolution(config, dataset) + strong_transform_steps = copy.deepcopy(base_transform_steps) + strong_transform_steps.extend( + [ + transforms.RandomHorizontalFlip(), + transforms.RandomCrop( + size=target_resolution + ), + RandAugment( + n=config.randaugment_n, + augmentation_pool=FIX_MATCH_AUGMENTATION_POOL, + ), + transforms.ToTensor(), + normalization, + ] + ) + return transforms.Compose(strong_transform_steps) + +def poverty_rgb_color_transform(ms_img, transform): + from wilds.datasets.poverty_dataset import _MEANS_2009_17, _STD_DEVS_2009_17 + poverty_rgb_means = np.array([_MEANS_2009_17[c] for c in ['RED', 'GREEN', 'BLUE']]).reshape((-1, 1, 1)) + poverty_rgb_stds = np.array([_STD_DEVS_2009_17[c] for c in ['RED', 'GREEN', 'BLUE']]).reshape((-1, 1, 1)) + + def unnormalize_rgb_in_poverty_ms_img(ms_img): + result = ms_img.detach().clone() + result[:3] = (result[:3] * poverty_rgb_stds) + poverty_rgb_means + return result + + def normalize_rgb_in_poverty_ms_img(ms_img): + result = ms_img.detach().clone() + result[:3] = (result[:3] - poverty_rgb_means) / poverty_rgb_stds + return ms_img + + color_transform = transforms.Compose([ + transforms.Lambda(lambda ms_img: unnormalize_rgb_in_poverty_ms_img(ms_img)), + transform, + transforms.Lambda(lambda ms_img: normalize_rgb_in_poverty_ms_img(ms_img)), + ]) + # The first three channels of the Poverty MS images are BGR + # So we shuffle them to the standard RGB to do the ColorJitter + # Before shuffling them back + ms_img[:3] = color_transform(ms_img[[2,1,0]])[[2,1,0]] # bgr to rgb to bgr + return ms_img + +def add_poverty_rand_augment_transform(config, dataset, base_transform_steps): + def 
poverty_color_jitter(ms_img):
+        return poverty_rgb_color_transform(
+            ms_img,
+            transforms.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.1))
+
+    def ms_cutout(ms_img):
+        def _sample_uniform(a, b):
+            return torch.empty(1).uniform_(a, b).item()
+
+        assert ms_img.shape[1] == ms_img.shape[2]
+        img_width = ms_img.shape[1]
+        cutout_width = _sample_uniform(0, img_width/2)
+        cutout_center_x = _sample_uniform(0, img_width)
+        cutout_center_y = _sample_uniform(0, img_width)
+        x0 = int(max(0, cutout_center_x - cutout_width/2))
+        y0 = int(max(0, cutout_center_y - cutout_width/2))
+        x1 = int(min(img_width, cutout_center_x + cutout_width/2))
+        y1 = int(min(img_width, cutout_center_y + cutout_width/2))
+
+        # Fill with 0 because the data is already normalized to mean zero
+        ms_img[:, x0:x1, y0:y1] = 0
+        return ms_img
+
+    target_resolution = _get_target_resolution(config, dataset)
+    strong_transform_steps = copy.deepcopy(base_transform_steps)
+    strong_transform_steps.extend([
+        transforms.RandomHorizontalFlip(),
+        transforms.RandomVerticalFlip(),
+        transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), shear=0.1, scale=(0.9, 1.1)),
+        transforms.Lambda(lambda ms_img: poverty_color_jitter(ms_img)),
+        transforms.Lambda(lambda ms_img: ms_cutout(ms_img)),
+    ])
+
+    return transforms.Compose(strong_transform_steps)
+
+def _get_target_resolution(config, dataset):
+    if config.target_resolution is not None:
+        return config.target_resolution
+    else:
+        return dataset.original_resolution
+
+
+class MultipleTransforms(object):
+    """When multiple transformations of the same data need to be returned."""
+
+    def __init__(self, transformations):
+        self.transformations = transformations
+
+    def __call__(self, x):
+        return tuple(transform(x) for transform in self.transformations)
diff --git a/examples/utils.py b/examples/utils.py
index 73fa1b12..85106bb7 100644
--- a/examples/utils.py
+++ b/examples/utils.py
@@ -7,12 +7,39 @@ import numpy as np
 import torch
 import pandas as pd
+import re
+
+from torch.utils.data import DataLoader
 
 try:
     import wandb
-except Exception as e:
+except ImportError as e:
     pass
+
+try:
+    from torch_geometric.data import Batch
+except ImportError:
+    pass
+
+
+def cross_entropy_with_logits_loss(input, soft_target):
+    """
+    Implementation of CrossEntropy loss using a soft target. Extension of BCEWithLogitsLoss to
+    multi-class cross entropy (MCE).
+    Normally, cross entropy loss is
+        \sum_j 1{j == y} -log \frac{e^{s_j}}{\sum_k e^{s_k}} = -log \frac{e^{s_y}}{\sum_k e^{s_k}}
+    Here we use
+        \sum_j P_j * -log \frac{e^{s_j}}{\sum_k e^{s_k}}
+    where 0 <= P_j <= 1.
+    Does not support fancy nn.CrossEntropy options (e.g. weight, size_average, ignore_index, reductions, etc.)
+
+    Args:
+    - input (N, k): logits
+    - soft_target (N, k): targets for softmax(input); likely want to use class probabilities
+    Returns:
+    - losses (N,)
+    """
+    return torch.sum(- soft_target * torch.nn.functional.log_softmax(input, 1), 1)
+
 def update_average(prev_avg, prev_counts, curr_avg, curr_counts):
     denom = prev_counts + curr_counts
     if isinstance(curr_counts, torch.Tensor):
@@ -59,10 +86,86 @@ def save_model(algorithm, epoch, best_val_metric, path):
     state['best_val_metric'] = best_val_metric
     torch.save(state, path)
 
-def load(algorithm, path):
-    state = torch.load(path)
-    algorithm.load_state_dict(state['algorithm'])
-    return state['epoch'], state['best_val_metric']
+def load(module, path, device=None, tries=2):
+    """
+    Handles loading weights saved from this repo/model into an algorithm/model.
+    Attempts to handle key mismatches between this module's state_dict and the loaded state_dict.
+    Args:
+        - module (torch module): module to load parameters for
+        - path (str): path to .pth file
+        - device: device to load tensors on
+        - tries: number of times to run the match_keys() function
+    """
+    if device is not None:
+        state = torch.load(path, map_location=device)
+    else:
+        state = torch.load(path)
+
+    # Loading from a saved WILDS Algorithm object
+    if 'algorithm' in state:
+        prev_epoch = state['epoch']
+        best_val_metric = state['best_val_metric']
+        state = state['algorithm']
+    # Loading from a pretrained SwAV model
+    elif 'state_dict' in state:
+        state = state['state_dict']
+        prev_epoch, best_val_metric = None, None
+    else:
+        prev_epoch, best_val_metric = None, None
+
+    # If keys match perfectly, load_state_dict() will work
+    try:
+        module.load_state_dict(state)
+    except Exception:
+        # Otherwise, attempt to reconcile mismatched keys and load with strict=False
+        module_keys = module.state_dict().keys()
+        for _ in range(tries):
+            state = match_keys(state, list(module_keys))
+            module.load_state_dict(state, strict=False)
+            leftover_state = {k:v for k,v in state.items() if k in list(state.keys()-module_keys)}
+            leftover_module_keys = module_keys - state.keys()
+            if len(leftover_state) == 0 or len(leftover_module_keys) == 0: break
+            state, module_keys = leftover_state, leftover_module_keys
+        if len(module_keys-state.keys()) > 0: print(f"Some module parameters could not be found in the loaded state: {module_keys-state.keys()}")
+    return prev_epoch, best_val_metric
+
+def match_keys(d, ref):
+    """
+    Matches the format of keys between d (a dict) and ref (a list of keys).
+
+    Helper function for situations where two algorithms share the same model, and we'd like to warm-start one
+    algorithm with the model of another. Some algorithms (e.g. FixMatch) save the featurizer and classifier within a Sequential,
+    so the featurizer keys may look like 'model.module.0._', 'model.0._', or 'model.module.model.0._',
+    and the classifier keys may look like 'model.module.1._', 'model.1._', or 'model.module.model.1._',
+    while simple algorithms (e.g. ERM) use no Sequential and just 'model._'.
+    """
+    # hard-coded exceptions
+    d = {re.sub('model.1.', 'model.classifier.', k): v for k,v in d.items()}
+    d = {k: v for k,v in d.items() if 'pre_classifier' not in k} # the pre_classifier keys cause errors when loading
+
+    # probe the proper transformation from d.keys() -> reference
+    # do this by splitting d's first key on '.' until we get a string that is a strict substring of something in ref
+    success = False
+    probe = list(d.keys())[0].split('.')
+    for i in range(len(probe)):
+        probe_str = '.'.join(probe[i:])
+        matches = list(filter(lambda ref_k: len(ref_k) >= len(probe_str) and probe_str == ref_k[-len(probe_str):], ref))
+        matches = list(filter(lambda ref_k: 'layer' not in ref_k, matches)) # handle the resnet probe being too simple, e.g. 'weight'
+        if len(matches) == 0: continue
+        else:
+            success = True
+            append = [m[:-len(probe_str)] for m in matches]
+            remove = '.'.join(probe[:i]) + '.'
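+            # Illustrative example (hypothetical keys): if d's first key is 'model.0.weight'
+            # and ref contains 'model.module.0.weight', the probe '0.weight' matches, giving
+            # append = ['model.module.'] and remove = 'model.'; d's keys are then rewritten below.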
+ break + if not success: raise Exception("These dictionaries have irreconcilable keys") + + return_d = {} + for a in append: + for k,v in d.items(): return_d[re.sub(remove, a, k)] = v + + # hard-coded exceptions + if 'model.classifier.weight' in return_d: + return_d['model.1.weight'], return_d['model.1.bias'] = return_d['model.classifier.weight'], return_d['model.classifier.bias'] + return return_d def log_group_data(datasets, grouper, logger): for k, dataset in datasets.items(): @@ -171,9 +274,11 @@ def log_config(config, logger): logger.write('\n') def initialize_wandb(config): - name = config.dataset + '_' + config.algorithm + '_' + config.log_dir - wandb.init(name=name, - project=f"wilds") + if config.wandb_api_key_path is not None: + with open(config.wandb_api_key_path, "r") as f: + os.environ["WANDB_API_KEY"] = f.read().strip() + + wandb.init(**config.wandb_kwargs) wandb.config.update(config) def save_pred(y_pred, path_prefix): @@ -266,3 +371,35 @@ def remove(d): raise TypeError("remove_key must take in a dict") return {k: v for (k,v) in d.items() if k != key} return remove + +def concat_input(labeled_x, unlabeled_x): + if isinstance(labeled_x, torch.Tensor): + x_cat = torch.cat((labeled_x, unlabeled_x), dim=0) + elif isinstance(labeled_x, Batch): + labeled_x.y = None + x_cat = Batch.from_data_list([labeled_x, unlabeled_x]) + else: + raise TypeError("x must be Tensor or Batch") + return x_cat + +class InfiniteDataIterator: + """ + Adapted from https://github.com/thuml/Transfer-Learning-Library + + A data iterator that will never stop producing data + """ + def __init__(self, data_loader: DataLoader): + self.data_loader = data_loader + self.iter = iter(self.data_loader) + + def __next__(self): + try: + data = next(self.iter) + except StopIteration: + print("Reached the end, resetting data loader...") + self.iter = iter(self.data_loader) + data = next(self.iter) + return data + + def __len__(self): + return len(self.data_loader) diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index de21122c..00000000 --- a/mypy.ini +++ /dev/null @@ -1,8 +0,0 @@ -# Mypy is a static type checker for Python 3 and Python 2.7. If you sprinkle your code with type annotations, -# mypy can type check your code and find common bugs. As mypy is a static analyzer, or a lint-like tool, the type -# annotations are just hints for mypy and don’t interfere when running your program. You run your program with a -# standard Python interpreter, and the annotations are treated effectively as comments. -# See https://mypy.readthedocs.io/en/stable/index.html for more information. 
-
-[mypy]
-ignore_missing_imports = True
diff --git a/requirements.dev.txt b/requirements.dev.txt
deleted file mode 100644
index 683e5b5a..00000000
--- a/requirements.dev.txt
+++ /dev/null
@@ -1,2 +0,0 @@
-black==20.8b1 # Python code formatter
-mypy==0.782 # Python static type checker
diff --git a/setup.py b/setup.py
index 53003e7a..8f7ef226 100644
--- a/setup.py
+++ b/setup.py
@@ -26,6 +26,7 @@
         'outdated>=0.2.0',
         'pandas>=1.1.0',
         'pillow>=7.2.0',
+        'ogb>=1.2.6',
         'pytz>=2020.4',
         'torch>=1.7.0',
         'torchvision>=0.8.2',
@@ -34,7 +35,7 @@
         'scipy>=1.5.4'
     ],
     license='MIT',
-    packages=setuptools.find_packages(exclude=['dataset_preprocessing', 'examples', 'examples.models', 'examples.models.bert']),
+    packages=setuptools.find_packages(exclude=['dataset_preprocessing', 'examples', 'examples.models', 'examples.models.bert', 'examples.data_augmentation']),
     classifiers=[
         'Topic :: Scientific/Engineering :: Artificial Intelligence',
         'Intended Audience :: Science/Research',
diff --git a/wilds/__init__.py b/wilds/__init__.py
index 639a9f4a..0d0adec2 100644
--- a/wilds/__init__.py
+++ b/wilds/__init__.py
@@ -16,6 +16,7 @@
 additional_datasets = [
     'celebA',
+    'domainnet',
     'waterbirds',
     'yelp',
     'bdd100k',
@@ -24,3 +25,22 @@
 ]
 
 supported_datasets = benchmark_datasets + additional_datasets
+
+unlabeled_datasets = [
+    'amazon',
+    'camelyon17',
+    'domainnet',
+    'civilcomments',
+    'iwildcam',
+    'ogb-molpcba',
+    'poverty',
+    'fmow',
+    'globalwheat',
+]
+
+unlabeled_splits = [
+    'train_unlabeled',
+    'val_unlabeled',
+    'test_unlabeled',
+    'extra_unlabeled'
+]
\ No newline at end of file
diff --git a/wilds/common/data_loaders.py b/wilds/common/data_loaders.py
index d051878e..79215f19 100644
--- a/wilds/common/data_loaders.py
+++ b/wilds/common/data_loaders.py
@@ -1,5 +1,4 @@
 import numpy as np
-import torch
 from torch.utils.data import DataLoader
 from torch.utils.data.sampler import WeightedRandomSampler, SubsetRandomSampler
 from wilds.common.utils import get_counts, split_into_groups
diff --git a/wilds/common/grouper.py b/wilds/common/grouper.py
index 07dc92a3..aaa68df4 100644
--- a/wilds/common/grouper.py
+++ b/wilds/common/grouper.py
@@ -1,7 +1,10 @@
+import copy
+from typing import Dict, List, Union
+
 import numpy as np
 import torch
 from wilds.common.utils import get_counts
-from wilds.datasets.wilds_dataset import WILDSSubset
+from wilds.datasets.wilds_dataset import WILDSDataset, WILDSSubset
 import warnings
 
 class Grouper:
@@ -72,41 +76,75 @@ def __init__(self, dataset, groupby_fields):
         If groupby_fields is None, then all data points are assigned to group 0.
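+
+        A single grouper can now also span multiple datasets, e.g. (illustrative) a labeled
+        dataset and its unlabeled counterpart:
+            CombinatorialGrouper(dataset=[labeled, unlabeled], groupby_fields=['hospital'])
+        which groups over their concatenated metadata.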
         Args:
-        - dataset (WILDSDataset)
+        - dataset (WILDSDataset or list of WILDSDataset)
         - groupby_fields (list of str)
         """
-        if isinstance(dataset, WILDSSubset):
-            raise ValueError("Grouper should be defined for the full dataset, not a subset")
-        self.groupby_fields = groupby_fields
+        if isinstance(dataset, list):
+            if len(dataset) == 0:
+                raise ValueError("At least one dataset must be defined for Grouper.")
+            datasets: List[WILDSDataset] = dataset
+        else:
+            datasets: List[WILDSDataset] = [dataset]
+
+        metadata_fields: List[str] = datasets[0].metadata_fields
+        # Build the largest metadata_map to check if all the metadata_maps are subsets of each other
+        largest_metadata_map: Dict[str, Union[List, np.ndarray]] = copy.deepcopy(datasets[0].metadata_map)
+        for i, dataset in enumerate(datasets):
+            if isinstance(dataset, WILDSSubset):
+                raise ValueError("Grouper should be defined with full dataset(s) and not subset(s).")
+
+            # The first dataset was used to get the metadata_fields and initial metadata_map
+            if i == 0:
+                continue
+            if dataset.metadata_fields != metadata_fields:
+                raise ValueError(
+                    f"The datasets passed in have different metadata_fields: {dataset.metadata_fields}. "
+                    f"Expected: {metadata_fields}"
+                )
+
+            if dataset.metadata_map is None: continue
+            for field, values in dataset.metadata_map.items():
+                n_overlap = min(len(values), len(largest_metadata_map[field]))
+                if not (np.asarray(values[:n_overlap]) == np.asarray(largest_metadata_map[field][:n_overlap])).all():
+                    raise ValueError("The metadata_maps of the datasets need to be ordered subsets of each other.")
+
+                if len(values) > len(largest_metadata_map[field]):
+                    largest_metadata_map[field] = values
+
+        self.groupby_fields = groupby_fields
         if groupby_fields is None:
             self._n_groups = 1
         else:
-            # We assume that the metadata fields are integers,
-            # so we can measure the cardinality of each field by taking its max + 1.
-            # Note that this might result in some empty groups.
-            self.groupby_field_indices = [i for (i, field) in enumerate(dataset.metadata_fields) if field in groupby_fields]
+            self.groupby_field_indices = [i for (i, field) in enumerate(metadata_fields) if field in groupby_fields]
             if len(self.groupby_field_indices) != len(self.groupby_fields):
                 raise ValueError('At least one group field not found in dataset.metadata_fields')
-            grouped_metadata = dataset.metadata_array[:, self.groupby_field_indices]
+
+            metadata_array = torch.cat([dataset.metadata_array for dataset in datasets])
+            grouped_metadata = metadata_array[:, self.groupby_field_indices]
             if not isinstance(grouped_metadata, torch.LongTensor):
                 grouped_metadata_long = grouped_metadata.long()
                 if not torch.all(grouped_metadata == grouped_metadata_long):
                     warnings.warn(f'CombinatorialGrouper: converting metadata with fields [{", ".join(groupby_fields)}] into long')
                 grouped_metadata = grouped_metadata_long
+
             for idx, field in enumerate(self.groupby_fields):
                 min_value = grouped_metadata[:,idx].min()
                 if min_value < 0:
                     raise ValueError(f"Metadata for CombinatorialGrouper cannot have values less than 0: {field}, {min_value}")
                 if min_value > 0:
                     warnings.warn(f"Minimum metadata value for CombinatorialGrouper is not 0 ({field}, {min_value}). This will result in empty groups")
-            self.cardinality = 1 + torch.max(
-                grouped_metadata, dim=0)[0]
+
+            # We assume that the metadata fields are integers,
+            # so we can measure the cardinality of each field by taking its max + 1.
+            # Note that this might result in some empty groups.
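+            # Illustrative example (hypothetical fields): cardinalities [5, 50] for
+            # ['hospital', 'slide'] give factors [1, 5], so a point with metadata (h, s)
+            # falls in group h + 5 * s, out of 250 groups in total.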
+            assert grouped_metadata.min() >= 0, "Group numbers cannot be negative."
+            self.cardinality = 1 + torch.max(grouped_metadata, dim=0)[0]
             cumprod = torch.cumprod(self.cardinality, dim=0)
             self._n_groups = cumprod[-1].item()
             self.factors_np = np.concatenate(([1], cumprod[:-1]))
             self.factors = torch.from_numpy(self.factors_np)
-            self.metadata_map = dataset.metadata_map
+            self.metadata_map = largest_metadata_map
 
     def metadata_to_group(self, metadata, return_counts=False):
         if self.groupby_fields is None:
diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py
index 5c93eda7..a4c97c75 100644
--- a/wilds/common/metrics/all_metrics.py
+++ b/wilds/common/metrics/all_metrics.py
@@ -1,6 +1,8 @@
+import copy
+
 import numpy as np
 import torch
-import torch.nn as nn
+
 import torch.nn.functional as F
 from torchvision.ops.boxes import box_iou
 from torchvision.models.detection._utils import Matcher
@@ -31,6 +33,133 @@ def multiclass_logits_to_pred(logits):
 def binary_logits_to_pred(logits):
     return (logits>0).long()
 
+def pseudolabel_binary_logits(logits, confidence_threshold):
+    """
+    Input:
+        logits (Tensor): Binary logits of size (batch_size, n_tasks).
+                         If an entry is >0, it means the prediction for that
+                         (example, task) is positive.
+        confidence_threshold (float): In [0,1]
+
+    Output:
+        unlabeled_y_pred (Tensor): Filtered version of logits, discarding any rows (examples) that
+                                   have no predictions with confidence above confidence_threshold.
+        unlabeled_y_pseudo (Tensor): Corresponding hard-pseudo-labeled version of logits. All
+                                     entries with confidence below confidence_threshold are set to
+                                     nan. All rows with no confident entries are discarded.
+        pseudolabels_kept_frac (float): Fraction of (examples, tasks) not set to nan or discarded.
+        mask (Tensor): Mask used to discard predictions with confidence under the confidence threshold.
+    """
+    if len(logits.shape) != 2:
+        raise ValueError('Logits must be 2-dimensional.')
+    probs = 1 / (1 + torch.exp(-logits))
+    mask = (torch.max(probs, 1-probs) >= confidence_threshold)
+    unlabeled_y_pseudo = (logits > 0).float()
+    unlabeled_y_pseudo[~mask] = float('nan')
+    pseudolabels_kept_frac = mask.sum() / mask.numel() # mask is bool, so no .mean()
+    example_mask = torch.any(~torch.isnan(unlabeled_y_pseudo), dim=1)
+    unlabeled_y_pseudo = unlabeled_y_pseudo[example_mask]
+    unlabeled_y_pred = logits[example_mask]
+    return unlabeled_y_pred, unlabeled_y_pseudo, pseudolabels_kept_frac, example_mask
+
+def pseudolabel_multiclass_logits(logits, confidence_threshold):
+    """
+    Input:
+        logits (Tensor): Multi-class logits of size (batch_size, ..., n_classes).
+        confidence_threshold (float): In [0,1]
+
+    Output:
+        unlabeled_y_pred (Tensor): Filtered version of logits, discarding any rows (examples) that
+                                   have no predictions with confidence above confidence_threshold.
+        unlabeled_y_pseudo (Tensor): Corresponding hard-pseudo-labeled version of logits. All
+                                     examples with confidence below confidence_threshold are discarded.
+        pseudolabels_kept_frac (float): Fraction of examples not discarded.
+        mask (Tensor): Mask used to discard predictions with confidence under the confidence threshold.
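+
+    Example (illustrative):
+        >>> logits = torch.tensor([[3.0, 0.0], [0.1, 0.0]])
+        >>> _, y_pseudo, frac, mask = pseudolabel_multiclass_logits(logits, 0.9)
+        >>> y_pseudo.tolist(), float(frac), mask.tolist()
+        ([0], 0.5, [True, False])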
+ """ + mask = torch.max(F.softmax(logits, -1), -1)[0] >= confidence_threshold + unlabeled_y_pseudo = multiclass_logits_to_pred(logits) + unlabeled_y_pseudo = unlabeled_y_pseudo[mask] + unlabeled_y_pred = logits[mask] + pseudolabels_kept_frac = mask.sum() / mask.numel() # mask is bool, so no .mean() + return unlabeled_y_pred, unlabeled_y_pseudo, pseudolabels_kept_frac, mask + +def pseudolabel_identity(logits, confidence_threshold): + return logits, logits, 1, None + +def pseudolabel_detection(preds, confidence_threshold): + """ + Input: + preds (List): List of len batch_size. Each entry is a dict containing + the keys 'boxes', 'labels', 'scores', and 'losses' + ('losses' is empty) + confidence_threshold (float): In [0,1] + """ + preds, pseudolabels_kept_frac = _mask_pseudolabels_detection(preds, confidence_threshold) + unlabeled_y_pred = [ + { + 'boxes': pred['boxes'], + 'labels': pred['labels'], + 'scores': pred['scores'], + 'losses': pred['losses'], + } for pred in preds + ] + unlabeled_y_pseudo = [ + { + 'boxes': pred['boxes'], + 'labels': pred['labels'], + } for pred in preds + ] + + # Keep all examples even if they don't have any confident-enough predictions + # They will be treated as empty images + example_mask = torch.ones(len(preds), dtype=torch.bool) + return unlabeled_y_pred, unlabeled_y_pseudo, pseudolabels_kept_frac, example_mask + + +def pseudolabel_detection_discard_empty(preds, confidence_threshold): + """ + Input: + preds (List): List of len batch_size. Each entry is a dict containing + the keys 'boxes', 'labels', 'scores', and 'losses' + ('losses' is empty) + confidence_threshold (float): In [0,1] + """ + preds, pseudolabels_kept_frac = _mask_pseudolabels_detection(preds, confidence_threshold) + unlabeled_y_pred = [ + { + 'boxes': pred['boxes'], + 'labels': pred['labels'], + 'scores': pred['scores'], + 'losses': pred['losses'], + } for pred in preds if len(pred['labels']) > 0 + ] + unlabeled_y_pseudo = [ + { + 'boxes': pred['boxes'], + 'labels': pred['labels'], + } for pred in preds if len(pred['labels']) > 0 + ] + example_mask = torch.tensor([len(pred['labels']) > 0 for pred in preds]) + return unlabeled_y_pred, unlabeled_y_pseudo, pseudolabels_kept_frac, example_mask + + +def _mask_pseudolabels_detection(preds, confidence_threshold): + total_boxes = 0.0 + kept_boxes = 0.0 + + preds = copy.deepcopy(preds) + for pred in preds: + mask = (pred['scores'] >= confidence_threshold) + pred['boxes'] = pred['boxes'][mask] + pred['labels'] = pred['labels'][mask] + pred['scores'] = pred['scores'][mask] + total_boxes += len(mask) + kept_boxes += mask.sum() + + pseudolabels_kept_frac = kept_boxes / total_boxes + return preds, pseudolabels_kept_frac + + class Accuracy(ElementwiseMetric): def __init__(self, prediction_fn=None, name=None): self.prediction_fn = prediction_fn @@ -101,9 +230,6 @@ def _compute_group_wise(self, y_pred, y_true, g, n_groups): return group_metrics, group_counts, worst_group_metric - # def _compute(self, y_pred, y_true): - # return self._compute_flattened(y_pred, y_true) - def worst(self, metrics): return minimum(metrics) @@ -219,6 +345,7 @@ class DetectionAccuracy(ElementwiseMetric): Given a specific Intersection over union threshold, determine the accuracy achieved for a one-class detector """ + def __init__(self, iou_threshold=0.5, score_threshold=0.5, name=None): self.iou_threshold = iou_threshold self.score_threshold = score_threshold diff --git a/wilds/common/metrics/metric.py b/wilds/common/metrics/metric.py index 89582577..0fc25c1d 100644 --- 
a/wilds/common/metrics/metric.py +++ b/wilds/common/metrics/metric.py @@ -83,7 +83,10 @@ def compute(self, y_pred, y_true, return_dict=True): - results (dict): Dictionary of results, mapping metric.agg_metric_field to avg_metric """ if numel(y_true) == 0: - agg_metric = torch.tensor(0., device=y_true.device) + if hasattr(y_true, 'device'): + agg_metric = torch.tensor(0., device=y_true.device) + else: + agg_metric = torch.tensor(0.) else: agg_metric = self._compute(y_pred, y_true) if return_dict: diff --git a/wilds/datasets/amazon_dataset.py b/wilds/datasets/amazon_dataset.py index 81e633b7..11cc1e38 100644 --- a/wilds/datasets/amazon_dataset.py +++ b/wilds/datasets/amazon_dataset.py @@ -1,13 +1,16 @@ -import os, csv +import csv +import os +from typing import Any, Dict, List, Optional, Tuple, Union + import torch import pandas as pd import numpy as np + from wilds.datasets.wilds_dataset import WILDSDataset from wilds.common.utils import map_to_id_array from wilds.common.metrics.all_metrics import Accuracy from wilds.common.grouper import CombinatorialGrouper -NOT_IN_DATASET = -1 class AmazonDataset(WILDSDataset): """ @@ -39,66 +42,106 @@ class AmazonDataset(WILDSDataset): https://nijianmo.github.io/amazon/index.html Original publication: - @inproceedings{ni2019justifying, - author = {J. Ni and J. Li and J. McAuley}, - booktitle = {Empirical Methods in Natural Language Processing (EMNLP)}, - pages = {188--197}, - title = {Justifying recommendations using distantly-labeled reviews and fine-grained aspects}, - year = {2019}, - } + @inproceedings{ni2019justifying, + author = {J. Ni and J. Li and J. McAuley}, + booktitle = {Empirical Methods in Natural Language Processing (EMNLP)}, + pages = {188--197}, + title = {Justifying recommendations using distantly-labeled reviews and fine-grained aspects}, + year = {2019}, + } License: None. However, the original authors request that the data be used for research purposes only. 
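+
+    Example (illustrative usage through the package API):
+        >>> from wilds import get_dataset
+        >>> dataset = get_dataset(dataset="amazon", download=True)
+        >>> train_data = dataset.get_subset("train")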
""" - _dataset_name = 'amazon' - _versions_dict = { - '1.0': { - 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x60237058e01749cda7b0701c2bd01420/contents/blob/', - 'compressed_size': 4_066_541_568 + + _NOT_IN_DATASET: int = -1 + + _dataset_name: str = "amazon" + _versions_dict: Dict[str, Dict[str, Union[str, int]]] = { + "1.0": { + "download_url": "https://worksheets.codalab.org/rest/bundles/0x60237058e01749cda7b0701c2bd01420/contents/blob/", + "compressed_size": 4_066_541_568, + }, + "2.0": { + "download_url": "https://worksheets.codalab.org/rest/bundles/0xadbf6198d3a64bdc96fb64d6966b5e79/contents/blob/", + "compressed_size": 1_987_523_759, }, - '2.0': { - 'download_url': 'https://worksheets.codalab.org/rest/bundles/0xadbf6198d3a64bdc96fb64d6966b5e79/contents/blob/', - 'compressed_size': 1_987_523_759 + "2.1": { + "download_url": "https://worksheets.codalab.org/rest/bundles/0xe3ed909786d34ee79d430d065582aa29/contents/blob/", + "compressed_size": 1_989_805_589, }, } - def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): - self._version = version - # the official split is the user split - if split_scheme=='official': - split_scheme = 'user' - self._split_scheme = split_scheme - self._y_type = 'long' - self._y_size = 1 - self._n_classes = 5 - # path - self._data_dir = self.initialize_data_dir(root_dir, download) + def __init__( + self, + version: str = None, + root_dir: str = "data", + download: bool = False, + split_scheme: str = "official", + ): + # Dataset information + self._version: Optional[str] = version + # The official split is to split by users + self._split_scheme: str = "user" if split_scheme == "official" else split_scheme + self._y_type: str = "long" + self._y_size: int = 1 + self._n_classes: int = 5 # One for each star rating + # Path of the dataset + self._data_dir: str = self.initialize_data_dir(root_dir, download) + # Load data - data_df = pd.read_csv(os.path.join(self.data_dir, 'reviews.csv'), - dtype={'reviewerID':str, 'asin':str, 'reviewTime':str,'unixReviewTime':int, - 'reviewText':str,'summary':str,'verified':bool,'category':str, 'reviewYear':int}, - keep_default_na=False, na_values=[], quoting=csv.QUOTE_NONNUMERIC) - split_df = pd.read_csv(os.path.join(self.data_dir, 'splits', f'{self.split_scheme}.csv')) - is_in_dataset = split_df['split']!=NOT_IN_DATASET + data_df: pd.DataFrame = pd.read_csv( + os.path.join(self.data_dir, "reviews.csv"), + dtype={ + "reviewerID": str, + "asin": str, + "reviewTime": str, + "unixReviewTime": int, + "reviewText": str, + "summary": str, + "verified": bool, + "category": str, + "reviewYear": int, + }, + keep_default_na=False, + na_values=[], + quoting=csv.QUOTE_NONNUMERIC, + ) + split_df: pd.DataFrame = pd.read_csv( + os.path.join(self.data_dir, "splits", f"{self.split_scheme}.csv") + ) + is_in_dataset: bool = split_df["split"] != AmazonDataset._NOT_IN_DATASET split_df = split_df[is_in_dataset] data_df = data_df[is_in_dataset] # Get arrays - self._split_array = split_df['split'].values - self._input_array = list(data_df['reviewText']) + self._split_array: List[str] = split_df["split"].values + self._input_array: List[str] = list(data_df["reviewText"]) # Get metadata - self._metadata_fields, self._metadata_array, self._metadata_map = self.load_metadata(data_df, self.split_array) + ( + self._metadata_fields, + self._metadata_array, + self._metadata_map, + ) = self.load_metadata(data_df, self.split_array) # Get y from metadata - self._y_array = 
getattr(self.metadata_array[:,self.metadata_fields.index('y')], self._y_type)() + self._y_array = getattr( + self.metadata_array[:, self.metadata_fields.index("y")], self._y_type + )() # Set split info self.initialize_split_dicts() # eval self.initialize_eval_grouper() - super().__init__(root_dir, download, split_scheme) + super().__init__(root_dir, download, self._split_scheme) - def get_input(self, idx): + def get_input(self, idx) -> str: return self._input_array[idx] - def eval(self, y_pred, y_true, metadata, prediction_fn=None): + def eval( + self, + y_pred: torch.Tensor, + y_true: torch.LongTensor, + metadata: torch.Tensor, + prediction_fn=None, + ) -> Tuple[Dict[str, Any], str]: """ Computes all evaluation metrics. Args: @@ -107,31 +150,36 @@ def eval(self, y_pred, y_true, metadata, prediction_fn=None): are predicted labels. - y_true (LongTensor): Ground-truth labels - metadata (Tensor): Metadata - - prediction_fn (function): A function that turns y_pred into predicted labels + - prediction_fn (function): A function that turns y_pred into predicted labels Output: - results (dictionary): Dictionary of evaluation metrics - results_str (str): String summarizing the evaluation metrics """ - metric = Accuracy(prediction_fn=prediction_fn) - if self.split_scheme=='user': + metric: Accuracy = Accuracy(prediction_fn=prediction_fn) + + if self.split_scheme == "user": # first compute groupwise accuracies - g = self._eval_grouper.metadata_to_group(metadata) - results = { + g: torch.Tensor= self._eval_grouper.metadata_to_group(metadata) + results: Dict[str, Any] = { **metric.compute(y_pred, y_true), - **metric.compute_group_wise(y_pred, y_true, g, self._eval_grouper.n_groups) + **metric.compute_group_wise( + y_pred, y_true, g, self._eval_grouper.n_groups + ), } - accs = [] + + accs: List[float] = [] for group_idx in range(self._eval_grouper.n_groups): - group_str = self._eval_grouper.group_field_str(group_idx) - group_metric = results.pop(metric.group_metric_field(group_idx)) - group_counts = results.pop(metric.group_count_field(group_idx)) - results[f'{metric.name}_{group_str}'] = group_metric - results[f'count_{group_str}'] = group_counts - if group_counts>0: + group_str: str = self._eval_grouper.group_field_str(group_idx) + group_metric: float = results.pop(metric.group_metric_field(group_idx)) + group_counts: int = results.pop(metric.group_count_field(group_idx)) + results[f"{metric.name}_{group_str}"] = group_metric + results[f"count_{group_str}"] = group_counts + if group_counts > 0: accs.append(group_metric) + accs = np.array(accs) - results['10th_percentile_acc'] = np.percentile(accs, 10) - results[f'{metric.worst_group_metric_field}'] = metric.worst(accs) + results["10th_percentile_acc"] = np.percentile(accs, 10) + results[f"{metric.worst_group_metric_field}"] = metric.worst(accs) results_str = ( f"Average {metric.name}: {results[metric.agg_metric_field]:.3f}\n" f"10th percentile {metric.name}: {results['10th_percentile_acc']:.3f}\n" @@ -140,55 +188,78 @@ def eval(self, y_pred, y_true, metadata, prediction_fn=None): return results, results_str else: return self.standard_group_eval( - metric, - self._eval_grouper, - y_pred, y_true, metadata) + metric, self._eval_grouper, y_pred, y_true, metadata + ) def initialize_split_dicts(self): - if self.split_scheme in ('user', 'time') or self.split_scheme.endswith('_generalization'): #category generalization - self._split_dict = {'train': 0, 'val': 1, 'id_val': 2, 'test': 3, 'id_test': 4} - self._split_names = {'train': 'Train', 'val': 'Validation 
(OOD)', 'id_val': 'Validation (ID)', 'test':'Test (OOD)', 'id_test': 'Test (ID)'} - elif self.split_scheme in ('category_subpopulation', ): - # use defaults - pass - elif self.split_scheme.endswith('_baseline'): - # use defaults + if self.split_scheme in ("user", "time") or self.split_scheme.endswith( + "_generalization" + ): + # Category generalization + self._split_dict: Dict[str, int] = { + "train": 0, + "val": 1, + "id_val": 2, + "test": 3, + "id_test": 4, + } + self._split_names: Dict[str, str] = { + "train": "Train", + "val": "Validation (OOD)", + "id_val": "Validation (ID)", + "test": "Test (OOD)", + "id_test": "Test (ID)", + } + self._source_domain_splits = [0, 2, 4] + elif ( + self.split_scheme == "category_subpopulation" + or self.split_scheme.endswith("_baseline") + ): + # Use defaults pass else: - raise ValueError(f'Split scheme {self.split_scheme} not recognized') + raise ValueError(f"Split scheme {self.split_scheme} is not recognized.") - def load_metadata(self, data_df, split_array): + def load_metadata( + self, data_df, split_array + ) -> Tuple[List[str], torch.Tensor, Dict]: # Get metadata - columns = ['reviewerID','asin','category','reviewYear', 'overall'] - metadata_fields = ['user', 'product', 'category', 'year','y'] - metadata_df = data_df[columns].copy() + columns: List[str] = ["reviewerID", "asin", "category", "reviewYear", "overall"] + metadata_fields: List[str] = ["user", "product", "category", "year", "y"] + metadata_df: pd.DataFrame = data_df[columns].copy() metadata_df.columns = metadata_fields + sort_idx = np.argsort(split_array) ordered_maps = {} - for field in ['user', 'product', 'category']: + for field in ["user", "product", "category"]: # map to IDs in the order of split values ordered_maps[field] = pd.unique(metadata_df.iloc[sort_idx][field]) - ordered_maps['y'] = range(1,6) - ordered_maps['year'] = range(metadata_df['year'].min(), metadata_df['year'].max()+1) + ordered_maps["y"] = range(1, 6) + ordered_maps["year"] = range( + metadata_df["year"].min(), metadata_df["year"].max() + 1 + ) metadata_map, metadata = map_to_id_array(metadata_df, ordered_maps) - return metadata_fields, torch.from_numpy(metadata.astype('long')), metadata_map + return metadata_fields, torch.from_numpy(metadata.astype("long")), metadata_map def initialize_eval_grouper(self): - if self.split_scheme=='user': + if self.split_scheme == "user": self._eval_grouper = CombinatorialGrouper( - dataset=self, - groupby_fields=['user']) - elif self.split_scheme.endswith('generalization') or self.split_scheme=='category_subpopulation': + dataset=self, groupby_fields=["user"] + ) + elif ( + self.split_scheme.endswith("generalization") + or self.split_scheme == "category_subpopulation" + ): self._eval_grouper = CombinatorialGrouper( - dataset=self, - groupby_fields=['category']) - elif self.split_scheme in ('time', 'time_baseline'): + dataset=self, groupby_fields=["category"] + ) + elif self.split_scheme in ("time", "time_baseline"): self._eval_grouper = CombinatorialGrouper( - dataset=self, - groupby_fields=['year']) - elif self.split_scheme.endswith('_baseline'): # user baselines + dataset=self, groupby_fields=["year"] + ) + elif self.split_scheme.endswith("_baseline"): # user baselines self._eval_grouper = CombinatorialGrouper( - dataset=self, - groupby_fields=['user']) + dataset=self, groupby_fields=["user"] + ) else: - raise ValueError(f'Split scheme {self.split_scheme} not recognized') + raise ValueError(f"Split scheme {self.split_scheme} not recognized.") diff --git 
a/wilds/datasets/camelyon17_dataset.py b/wilds/datasets/camelyon17_dataset.py index b84dd6d2..e406cf8c 100644 --- a/wilds/datasets/camelyon17_dataset.py +++ b/wilds/datasets/camelyon17_dataset.py @@ -7,6 +7,13 @@ from wilds.common.grouper import CombinatorialGrouper from wilds.common.metrics.all_metrics import Accuracy + +# Note that the hospital numbering here is different from what's in the paper, +# where to avoid confusing readers we used a 1-indexed scheme and just labeled the test hospital as 5. +# Here, the numbers are 0-indexed. +TEST_CENTER = 2 +VAL_CENTER = 1 + class Camelyon17Dataset(WILDSDataset): """ The CAMELYON17-WILDS histopathology dataset. @@ -74,13 +81,6 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' for patient, node, x, y in self._metadata_df.loc[:, ['patient', 'node', 'x_coord', 'y_coord']].itertuples(index=False, name=None)] - # Extract splits - # Note that the hospital numbering here is different from what's in the paper, - # where to avoid confusing readers we used a 1-indexed scheme and just labeled the test hospital as 5. - # Here, the numbers are 0-indexed. - test_center = 2 - val_center = 1 - self._split_dict = { 'train': 0, 'id_val': 1, @@ -93,10 +93,12 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' 'test': 'Test', 'val': 'Validation (OOD)', } + + # Extract splits centers = self._metadata_df['center'].values.astype('long') num_centers = int(np.max(centers)) + 1 - val_center_mask = (self._metadata_df['center'] == val_center) - test_center_mask = (self._metadata_df['center'] == test_center) + val_center_mask = (self._metadata_df['center'] == VAL_CENTER) + test_center_mask = (self._metadata_df['center'] == TEST_CENTER) self._metadata_df.loc[val_center_mask, 'split'] = self.split_dict['val'] self._metadata_df.loc[test_center_mask, 'split'] = self.split_dict['test'] diff --git a/wilds/datasets/civilcomments_dataset.py b/wilds/datasets/civilcomments_dataset.py index c4d6bb8b..a40f9473 100644 --- a/wilds/datasets/civilcomments_dataset.py +++ b/wilds/datasets/civilcomments_dataset.py @@ -59,7 +59,9 @@ class CivilCommentsDataset(WILDSDataset): _versions_dict = { '1.0': { 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x8cd3de0634154aeaad2ee6eb96723c6e/contents/blob/', - 'compressed_size': 90_644_480}} + 'compressed_size': 90_644_480 + } + } def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): self._version = version diff --git a/wilds/datasets/domainnet_dataset.py b/wilds/datasets/domainnet_dataset.py new file mode 100644 index 00000000..beba1e5c --- /dev/null +++ b/wilds/datasets/domainnet_dataset.py @@ -0,0 +1,567 @@ +import csv +import os +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import pandas as pd +from PIL import Image + +from wilds.common.utils import map_to_id_array +from wilds.common.metrics.all_metrics import Accuracy +from wilds.common.grouper import CombinatorialGrouper +from wilds.datasets.wilds_dataset import WILDSDataset + +DOMAIN_NET_CATEGORIES = [ + "aircraft_carrier", + "airplane", + "alarm_clock", + "ambulance", + "angel", + "animal_migration", + "ant", + "anvil", + "apple", + "arm", + "asparagus", + "axe", + "backpack", + "banana", + "bandage", + "barn", + "baseball", + "baseball_bat", + "basket", + "basketball", + "bat", + "bathtub", + "beach", + "bear", + "beard", + "bed", + "bee", + "belt", + "bench", + "bicycle", + "binoculars", + "bird", + "birthday_cake", + "blackberry", + "blueberry", + 
"book", + "boomerang", + "bottlecap", + "bowtie", + "bracelet", + "brain", + "bread", + "bridge", + "broccoli", + "broom", + "bucket", + "bulldozer", + "bus", + "bush", + "butterfly", + "cactus", + "cake", + "calculator", + "calendar", + "camel", + "camera", + "camouflage", + "campfire", + "candle", + "cannon", + "canoe", + "car", + "carrot", + "castle", + "cat", + "ceiling_fan", + "cello", + "cell_phone", + "chair", + "chandelier", + "church", + "circle", + "clarinet", + "clock", + "cloud", + "coffee_cup", + "compass", + "computer", + "cookie", + "cooler", + "couch", + "cow", + "crab", + "crayon", + "crocodile", + "crown", + "cruise_ship", + "cup", + "diamond", + "dishwasher", + "diving_board", + "dog", + "dolphin", + "donut", + "door", + "dragon", + "dresser", + "drill", + "drums", + "duck", + "dumbbell", + "ear", + "elbow", + "elephant", + "envelope", + "eraser", + "eye", + "eyeglasses", + "face", + "fan", + "feather", + "fence", + "finger", + "fire_hydrant", + "fireplace", + "firetruck", + "fish", + "flamingo", + "flashlight", + "flip_flops", + "floor_lamp", + "flower", + "flying_saucer", + "foot", + "fork", + "frog", + "frying_pan", + "garden", + "garden_hose", + "giraffe", + "goatee", + "golf_club", + "grapes", + "grass", + "guitar", + "hamburger", + "hammer", + "hand", + "harp", + "hat", + "headphones", + "hedgehog", + "helicopter", + "helmet", + "hexagon", + "hockey_puck", + "hockey_stick", + "horse", + "hospital", + "hot_air_balloon", + "hot_dog", + "hot_tub", + "hourglass", + "house", + "house_plant", + "hurricane", + "ice_cream", + "jacket", + "jail", + "kangaroo", + "key", + "keyboard", + "knee", + "knife", + "ladder", + "lantern", + "laptop", + "leaf", + "leg", + "light_bulb", + "lighter", + "lighthouse", + "lightning", + "line", + "lion", + "lipstick", + "lobster", + "lollipop", + "mailbox", + "map", + "marker", + "matches", + "megaphone", + "mermaid", + "microphone", + "microwave", + "monkey", + "moon", + "mosquito", + "motorbike", + "mountain", + "mouse", + "moustache", + "mouth", + "mug", + "mushroom", + "nail", + "necklace", + "nose", + "ocean", + "octagon", + "octopus", + "onion", + "oven", + "owl", + "paintbrush", + "paint_can", + "palm_tree", + "panda", + "pants", + "paper_clip", + "parachute", + "parrot", + "passport", + "peanut", + "pear", + "peas", + "pencil", + "penguin", + "piano", + "pickup_truck", + "picture_frame", + "pig", + "pillow", + "pineapple", + "pizza", + "pliers", + "police_car", + "pond", + "pool", + "popsicle", + "postcard", + "potato", + "power_outlet", + "purse", + "rabbit", + "raccoon", + "radio", + "rain", + "rainbow", + "rake", + "remote_control", + "rhinoceros", + "rifle", + "river", + "roller_coaster", + "rollerskates", + "sailboat", + "sandwich", + "saw", + "saxophone", + "school_bus", + "scissors", + "scorpion", + "screwdriver", + "sea_turtle", + "see_saw", + "shark", + "sheep", + "shoe", + "shorts", + "shovel", + "sink", + "skateboard", + "skull", + "skyscraper", + "sleeping_bag", + "smiley_face", + "snail", + "snake", + "snorkel", + "snowflake", + "snowman", + "soccer_ball", + "sock", + "speedboat", + "spider", + "spoon", + "spreadsheet", + "square", + "squiggle", + "squirrel", + "stairs", + "star", + "steak", + "stereo", + "stethoscope", + "stitches", + "stop_sign", + "stove", + "strawberry", + "streetlight", + "string_bean", + "submarine", + "suitcase", + "sun", + "swan", + "sweater", + "swing_set", + "sword", + "syringe", + "table", + "teapot", + "teddy-bear", + "telephone", + "television", + "tennis_racquet", + "tent", + 
"The_Eiffel_Tower", + "The_Great_Wall_of_China", + "The_Mona_Lisa", + "tiger", + "toaster", + "toe", + "toilet", + "tooth", + "toothbrush", + "toothpaste", + "tornado", + "tractor", + "traffic_light", + "train", + "tree", + "triangle", + "trombone", + "truck", + "trumpet", + "t-shirt", + "umbrella", + "underwear", + "van", + "vase", + "violin", + "washing_machine", + "watermelon", + "waterslide", + "whale", + "wheel", + "windmill", + "wine_bottle", + "wine_glass", + "wristwatch", + "yoga", + "zebra", + "zigzag", +] +DOMAIN_NET_DOMAINS = ["clipart", "infograph", "painting", "quickdraw", "real", "sketch"] +SENTRY_DOMAINS = ["clipart", "painting", "real", "sketch"] + + +class DomainNetDataset(WILDSDataset): + """ + DomainNet dataset. + 586,576 images in 345 categories (airplane, ball, cup, etc.) across 6 domains (clipart, infograph, painting, + quickdraw, real and sketch). + + Supported `split_scheme`: + 'official': use the official split from DomainNet + + Input (x): + 224 x 224 x 3 RGB image. + + Label (y): + y is one of the 345 categories in DomainNet. + + Metadata: + None + + Website: + http://ai.bu.edu/M3SDA + + Original publication: + + @inproceedings{peng2019moment, + title={Moment matching for multi-source domain adaptation}, + author={Peng, Xingchao and Bai, Qinxun and Xia, Xide and Huang, Zijun and Saenko, Kate and Wang, Bo}, + booktitle={Proceedings of the IEEE International Conference on Computer Vision}, + pages={1406--1415}, + year={2019} + } + + SENTRY publication: + + @article{prabhu2020sentry + author = {Prabhu, Viraj and Khare, Shivam and Kartik, Deeksha and Hoffman, Judy}, + title = {SENTRY: Selective Entropy Optimization via Committee Consistency for Unsupervised Domain Adaptation}, + year = {2020}, + journal = {arXiv preprint: 2012.11460}, + } + + Fair Use Notice: + "This dataset contains some copyrighted material whose use has not been specifically authorized by the copyright owners. + In an effort to advance scientific research, we make this material available for academic research. We believe this + constitutes a fair use of any such copyrighted material as provided for in section 107 of the US Copyright Law. + In accordance with Title 17 U.S.C. Section 107, the material on this site is distributed without profit for + non-commercial research and educational purposes. For more information on fair use please click here. If you wish + to use copyrighted material on this site or in our dataset for purposes of your own that go beyond non-commercial + research and academic purposes, you must obtain permission directly from the copyright owner." + """ + + _dataset_name: str = "domainnet" + _versions_dict: Dict[str, Dict[str, Union[str, int]]] = { + "1.0": { + "download_url": "https://worksheets.codalab.org/rest/bundles/0x0b8ca76eef384b98b879d0c8c4681a32/contents/blob/", + "compressed_size": 19_255_770_459, + }, + } + + def __init__( + self, + version: str = None, + root_dir: str = "data", + download: bool = False, + split_scheme: str = "official", + source_domain: str = "sketch", + target_domain: str = "real", + use_sentry: bool = False, + ): + # Dataset information + self._version: Optional[str] = version + self._split_scheme: str = split_scheme + self._original_resolution = (224, 224) + self._y_type: str = "long" + self._y_size: int = 1 + # Path of the dataset + self._data_dir: str = self.initialize_data_dir(root_dir, download) + + # The original dataset contains 345 categories. The SENTRY version contains 40 categories. 
+        if use_sentry:
+            assert source_domain in SENTRY_DOMAINS
+            assert target_domain in SENTRY_DOMAINS
+            print("Using the SENTRY version of DomainNet...")
+            metadata_filename = "sentry_metadata.csv"
+            self._n_classes = 40
+        else:
+            metadata_filename = "metadata.csv"
+            self._n_classes = 345
+
+        metadata_df: pd.DataFrame = pd.read_csv(
+            os.path.join(self.data_dir, metadata_filename),
+            dtype={
+                "image_path": str,
+                "domain": str,
+                "split": str,
+                "category": str,
+                "y": int,
+            },
+            keep_default_na=False,
+            na_values=[],
+            quoting=csv.QUOTE_NONNUMERIC,
+        )
+        source_metadata_df = metadata_df.loc[metadata_df["domain"] == source_domain]
+        target_metadata_df = metadata_df.loc[metadata_df["domain"] == target_domain]
+        metadata_df = pd.concat([source_metadata_df, target_metadata_df])
+
+        self._input_image_paths = metadata_df["image_path"].values
+        self._y_array = torch.from_numpy(metadata_df["y"].values).type(torch.LongTensor)
+        self.initialize_split_dicts()
+        self.initialize_split_array(metadata_df, source_domain, target_domain)
+
+        # Populate metadata fields
+        self._metadata_fields = ["domain", "category", "y"]
+        metadata_df = metadata_df[self._metadata_fields]
+        possible_metadata_values = {
+            "domain": DOMAIN_NET_DOMAINS,
+            "category": DOMAIN_NET_CATEGORIES,
+            "y": range(self._n_classes),
+        }
+        self._metadata_map, metadata = map_to_id_array(
+            metadata_df, possible_metadata_values
+        )
+        self._metadata_array = torch.from_numpy(metadata.astype("long"))
+
+        # Eval
+        self.initialize_eval_grouper()
+        super().__init__(root_dir, download, self._split_scheme)
+
+    def get_input(self, idx) -> Image.Image:
+        img_path = os.path.join(self.data_dir, self._input_image_paths[idx])
+        img = Image.open(img_path).convert("RGB")
+        return img
+
+    def eval(
+        self,
+        y_pred: torch.Tensor,
+        y_true: torch.LongTensor,
+        metadata: torch.Tensor,
+        prediction_fn=None,
+    ) -> Tuple[Dict[str, Any], str]:
+        """
+        Computes all evaluation metrics.
+        Args:
+            - y_pred (Tensor): Predictions from a model. By default, they are predicted labels (LongTensor).
+                               But they can also be other model outputs such that prediction_fn(y_pred)
+                               are predicted labels.
+ - y_true (LongTensor): Ground-truth labels + - metadata (Tensor): Metadata + - prediction_fn (function): A function that turns y_pred into predicted labels + Output: + - results (dictionary): Dictionary of evaluation metrics + - results_str (str): String summarizing the evaluation metrics + """ + metric: Accuracy = Accuracy(prediction_fn=prediction_fn) + return self.standard_group_eval( + metric, self._eval_grouper, y_pred, y_true, metadata + ) + + def initialize_split_dicts(self): + if self.split_scheme == "official": + self._split_dict: Dict[str, int] = { + "train": 0, + "val": 1, + "test": 2, + "id_test": 3, + } + self._split_names: Dict[str, str] = { + "train": "Train", + "val": "Validation (OOD)", + "test": "Test (OOD)", + "id_test": "Test (ID)", + } + self._source_domain_splits = [0, 3] + else: + raise ValueError(f"Split scheme {self.split_scheme} is not recognized.") + + def initialize_split_array(self, metadata_df, source_domain, target_domain): + def get_split(row): + if row["domain"] == source_domain: + if row["split"] == "train": + return 0 + elif row["split"] == "test": + return 3 + elif row["domain"] == target_domain: + if row["split"] == "train": + return 1 + elif row["split"] == "test": + return 2 + else: + raise ValueError( + f"Domain should be one of {source_domain}, {target_domain}" + ) + + self._split_array = metadata_df.apply( + lambda row: get_split(row), axis=1 + ).to_numpy() + + def initialize_eval_grouper(self): + if self.split_scheme == "official": + self._eval_grouper = CombinatorialGrouper( + dataset=self, groupby_fields=["category"] + ) + else: + raise ValueError(f"Split scheme {self.split_scheme} not recognized.") diff --git a/wilds/datasets/fmow_dataset.py b/wilds/datasets/fmow_dataset.py index 21b9e099..b96c849d 100644 --- a/wilds/datasets/fmow_dataset.py +++ b/wilds/datasets/fmow_dataset.py @@ -70,6 +70,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._split_dict = {'train': 0, 'id_val': 1, 'id_test': 2, 'val': 3, 'test': 4} self._split_names = {'train': 'Train', 'id_val': 'ID Val', 'id_test': 'ID Test', 'val': 'OOD Val', 'test': 'OOD Test'} + self._source_domain_splits = [0, 1, 2] self.oracle_training_set = False if split_scheme == 'official': diff --git a/wilds/datasets/globalwheat_dataset.py b/wilds/datasets/globalwheat_dataset.py old mode 100644 new mode 100755 index 7e3eefa6..6f29e084 --- a/wilds/datasets/globalwheat_dataset.py +++ b/wilds/datasets/globalwheat_dataset.py @@ -94,14 +94,14 @@ 'VLB', 'VSC', 'Wad Medani', + 'Eschikon' ] STAGES = [ 'Filling', - 'Filling - Ripening', + 'Filling-Ripening', 'multiple', 'Post-flowering', - 'Post-Flowering', 'Ripening', ] @@ -109,7 +109,6 @@ class GlobalWheatDataset(WILDSDataset): """ The GlobalWheat-WILDS wheat head localization dataset. This is a modified version of the original Global Wheat Head Dataset 2021. 
-
     Supported `split_scheme`:
         - 'official'
         - 'official_with_subsampled_test'
@@ -152,7 +151,10 @@ class GlobalWheatDataset(WILDSDataset):
     _versions_dict = {
         '1.0': {
            'download_url': 'https://worksheets.codalab.org/rest/bundles/0x443fbcb18eeb4f80b5ea4a9f77795168/contents/blob/',
-            'compressed_size': 10_286_120_960}
+            'compressed_size': 10_286_120_960},
+        '1.1': {
+            'download_url': 'https://worksheets.codalab.org/rest/bundles/0x36e16907b7254571b708b725f8beae52/contents/blob/',
+            'compressed_size': 10_284_949_504},
     }
 
     def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'):
diff --git a/wilds/datasets/poverty_dataset.py b/wilds/datasets/poverty_dataset.py
index 8fcb5245..a2c16dd3 100644
--- a/wilds/datasets/poverty_dataset.py
+++ b/wilds/datasets/poverty_dataset.py
@@ -155,6 +155,7 @@ def __init__(self, version=None, root_dir='data', download=False,
                  cache_size=100):
         self._version = version
         self._data_dir = self.initialize_data_dir(root_dir, download)
+        self._original_resolution = (224, 224)
 
         self._split_dict = {'train': 0, 'id_val': 1, 'id_test': 2, 'val': 3, 'test': 4}
         self._split_names = {'train': 'Train', 'id_val': 'ID Val', 'id_test': 'ID Test', 'val': 'OOD Val', 'test': 'OOD Test'}
diff --git a/wilds/datasets/py150_dataset.py b/wilds/datasets/py150_dataset.py
index 1aade110..98deb0d5 100644
--- a/wilds/datasets/py150_dataset.py
+++ b/wilds/datasets/py150_dataset.py
@@ -124,7 +124,7 @@ def eval(self, y_pred, y_true, metadata, prediction_fn=None):
 
         #y_pred: [n_samples, seqlen-1]
         #y_true: [n_samples, seqlen-1]
-        tok_type = metadata[:, 1:] #[n_samples, seqlen-1]
+        tok_type = metadata[:, 1:-1] #[n_samples, seqlen-1]. Must slice off the coarse domain info at the end.
         results = {}
         results_str = ""
diff --git a/wilds/datasets/unlabeled/__init__.py b/wilds/datasets/unlabeled/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/wilds/datasets/unlabeled/amazon_unlabeled_dataset.py b/wilds/datasets/unlabeled/amazon_unlabeled_dataset.py
new file mode 100644
index 00000000..8e2c3417
--- /dev/null
+++ b/wilds/datasets/unlabeled/amazon_unlabeled_dataset.py
@@ -0,0 +1,159 @@
+import csv
+import os
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import pandas as pd
+import numpy as np
+
+from wilds.datasets.unlabeled.wilds_unlabeled_dataset import WILDSUnlabeledDataset
+from wilds.common.utils import map_to_id_array
+
+
+class AmazonUnlabeledDataset(WILDSUnlabeledDataset):
+    """
+    Unlabeled Amazon-WILDS dataset.
+    This is a modified version of the 2018 Amazon Reviews dataset.
+
+    Supported `split_scheme`:
+        'official': official split, which is equivalent to 'user'
+        'user': shifts to unseen reviewers
+        'time': shifts from reviews written before 2013 to reviews written after 2013
+        'category_subpopulation': the training distribution is a random subset following the natural distribution, and the
+                                  evaluation splits include each category uniformly (to the extent it is possible)
+        '*_generalization': domain generalization setting where the domains are categories; train categories vary.
+        '*_baseline': oracle baseline splits for user or time shifts
+
+    Input (x):
+        Review text of maximum token length of 512.
+
+    Metadata:
+        reviewer: reviewer ID
+        year: year in which the review was written
+        category: product category
+        product: product ID
+
+    Website:
+        https://nijianmo.github.io/amazon/index.html
+
+    Original publication:
+        @inproceedings{ni2019justifying,
+            author = {J. Ni and J. Li and J.
McAuley}, + booktitle = {Empirical Methods in Natural Language Processing (EMNLP)}, + pages = {188--197}, + title = {Justifying recommendations using distantly-labeled reviews and fine-grained aspects}, + year = {2019}, + } + + License: + None. However, the original authors request that the data be used for research purposes only. + """ + + _NOT_IN_DATASET: int = -1 + + _dataset_name: str = "amazon_unlabeled" + _versions_dict: Dict[str, Dict[str, Union[str, int]]] = { + "1.0": { + "download_url": "https://worksheets.codalab.org/rest/bundles/0xe3ed909786d34ee79d430d065582aa29/contents/blob/", + "compressed_size": 1_989_805_589, + "equivalent_dataset": "amazon_v2.1", + }, + } + + def __init__( + self, + version: str = None, + root_dir: str = "data", + download: bool = False, + split_scheme: str = "official", + ): + # Dataset information + self._version: Optional[str] = version + # The official split is to split by users + self._split_scheme: str = "user" if split_scheme == "official" else split_scheme + # Path of the dataset + self._data_dir: str = self.initialize_data_dir(root_dir, download) + + # Load data + data_df: pd.DataFrame = pd.read_csv( + os.path.join(self.data_dir, "reviews.csv"), + dtype={ + "reviewerID": str, + "asin": str, + "reviewTime": str, + "unixReviewTime": int, + "reviewText": str, + "summary": str, + "verified": bool, + "category": str, + "reviewYear": int, + }, + keep_default_na=False, + na_values=[], + quoting=csv.QUOTE_NONNUMERIC, + ) + split_df: pd.DataFrame = pd.read_csv( + os.path.join(self.data_dir, "splits", f"{self.split_scheme}.csv") + ) + is_in_dataset: bool = ( + split_df["split"] != AmazonUnlabeledDataset._NOT_IN_DATASET + ) + split_df = split_df[is_in_dataset] + data_df = data_df[is_in_dataset] + # Get arrays + self._split_array: List[str] = split_df["split"].values + self._input_array: List[str] = list(data_df["reviewText"]) + # Get metadata + ( + self._metadata_fields, + self._metadata_array, + self._metadata_map, + ) = self.load_metadata(data_df, self.split_array) + # Get y from metadata + self._y_type: str = "long" + self._y_array = getattr( + self.metadata_array[:, self.metadata_fields.index("y")], self._y_type + )() + # Set split info + self.initialize_split_dicts() + + super().__init__(root_dir, download, self._split_scheme) + + def get_input(self, idx) -> str: + return self._input_array[idx] + + def initialize_split_dicts(self): + if self.split_scheme == "user": + self._split_dict = { + "val_unlabeled": 11, + "test_unlabeled": 12, + "extra_unlabeled": 13, + } + self._split_names = { + "val_unlabeled": "Unlabeled Validation", + "test_unlabeled": "Unlabeled Test", + "extra_unlabeled": "Unlabeled Extra", + } + else: + raise ValueError(f"Split scheme {self.split_scheme} is not recognized.") + + def load_metadata( + self, data_df, split_array + ) -> Tuple[List[str], torch.Tensor, Dict]: + # Get metadata + columns: List[str] = ["reviewerID", "asin", "category", "reviewYear", "overall"] + metadata_fields: List[str] = ["user", "product", "category", "year", "y"] + metadata_df: pd.DataFrame = data_df[columns].copy() + metadata_df.columns = metadata_fields + + sort_idx = np.argsort(split_array) + ordered_maps = {} + for field in ["user", "product", "category"]: + # map to IDs in the order of split values + ordered_maps[field] = pd.unique(metadata_df.iloc[sort_idx][field]) + ordered_maps["y"] = range(1, 6) + ordered_maps["year"] = range( + metadata_df["year"].min(), metadata_df["year"].max() + 1 + ) + metadata_map, metadata = map_to_id_array(metadata_df, 
ordered_maps) + return metadata_fields, torch.from_numpy(metadata.astype("long")), metadata_map diff --git a/wilds/datasets/unlabeled/camelyon17_unlabeled_dataset.py b/wilds/datasets/unlabeled/camelyon17_unlabeled_dataset.py new file mode 100644 index 00000000..dd30fe87 --- /dev/null +++ b/wilds/datasets/unlabeled/camelyon17_unlabeled_dataset.py @@ -0,0 +1,139 @@ +import os + +import numpy as np +import pandas as pd +import torch +from PIL import Image + +from wilds.datasets.camelyon17_dataset import TEST_CENTER, VAL_CENTER +from wilds.datasets.unlabeled.wilds_unlabeled_dataset import WILDSUnlabeledDataset +from wilds.common.grouper import CombinatorialGrouper + + +class Camelyon17UnlabeledDataset(WILDSUnlabeledDataset): + """ + Unlabeled Camelyon17-WILDS dataset. + This dataset contains patches from all of the slides in the original CAMELYON17 training data, + except for the slides that were labeled with lesion annotations and therefore used in the + labeled Camelyon17Dataset. + + Supported `split_scheme`: + 'official' + + Input (x): + 96x96 image patches extracted from histopathology slides. + + Metadata: + Each patch is annotated with the ID of the hospital it came from (integer from 0 to 4) + and the slide it came from (integer from 0 to 49). + + Website: + https://camelyon17.grand-challenge.org/ + + Original publication: + @article{bandi2018detection, + title={From detection of individual metastases to classification of lymph node status at the patient level: the camelyon17 challenge}, + author={Bandi, Peter and Geessink, Oscar and Manson, Quirine and Van Dijk, Marcory and Balkenhol, Maschenka and Hermsen, Meyke and Bejnordi, Babak Ehteshami and Lee, Byungjae and Paeng, Kyunghyun and Zhong, Aoxiao and others}, + journal={IEEE transactions on medical imaging}, + volume={38}, + number={2}, + pages={550--560}, + year={2018}, + publisher={IEEE} + } + + License: + This dataset is in the public domain and is distributed under CC0. 
+        https://creativecommons.org/publicdomain/zero/1.0/
+    """
+
+    _dataset_name = "camelyon17_unlabeled"
+    _versions_dict = {
+        "1.0": {
+            "download_url": "https://worksheets.codalab.org/rest/bundles/0xa78be8a88a00487a92006936514967d2/contents/blob/",
+            "compressed_size": 69_442_379_933,
+        }
+    }
+
+    def __init__(
+        self, version=None, root_dir="data", download=False, split_scheme="official"
+    ):
+        self._version = version
+        self._data_dir = self.initialize_data_dir(root_dir, download)
+        self._original_resolution = (96, 96)
+
+        # Read in metadata
+        self._metadata_df = pd.read_csv(
+            os.path.join(self._data_dir, "metadata.csv"),
+            index_col=0,
+            dtype={"patient": "str"},
+        )
+
+        # Get filenames
+        self._input_array = [
+            f"patches/patient_{patient}_node_{node}/patch_patient_{patient}_node_{node}_x_{x}_y_{y}.png"
+            for patient, node, x, y in self._metadata_df.loc[
+                :, ["patient", "node", "x_coord", "y_coord"]
+            ].itertuples(index=False, name=None)
+        ]
+
+        self._split_scheme = split_scheme
+        if self._split_scheme == "official":
+            self._split_dict = {
+                "train_unlabeled": 10,
+                "val_unlabeled": 11,
+                "test_unlabeled": 12,
+            }
+            self._split_names = {
+                "train_unlabeled": "Unlabeled Train",
+                "val_unlabeled": "Unlabeled Validation",
+                "test_unlabeled": "Unlabeled Test",
+            }
+        else:
+            raise ValueError(f"Split scheme {self._split_scheme} not recognized")
+
+        # Extract splits
+        centers = self._metadata_df["center"].values.astype("long")
+        self._metadata_df["split"] = self.split_dict["train_unlabeled"]
+        val_center_mask = self._metadata_df["center"] == VAL_CENTER
+        test_center_mask = self._metadata_df["center"] == TEST_CENTER
+        self._metadata_df.loc[val_center_mask, "split"] = self.split_dict[
+            "val_unlabeled"
+        ]
+        self._metadata_df.loc[test_center_mask, "split"] = self.split_dict[
+            "test_unlabeled"
+        ]
+        # Centers 1 and 2 have 600,030 unlabeled examples each.
+        # The rest of the unlabeled data is used for the train_unlabeled split (1,799,247 total).
+        assert self._metadata_df.loc[val_center_mask].shape[0] == 600_030
+        assert self._metadata_df.loc[test_center_mask].shape[0] == 600_030
+        train_center_mask = ~self._metadata_df["center"].isin([VAL_CENTER, TEST_CENTER])
+        assert self._metadata_df.loc[train_center_mask].shape[0] == 1_799_247
+
+        self._split_array = self._metadata_df["split"].values
+
+        # In metadata.csv, 'tumor' is -1 for every unlabeled patch, so y is
+        # uniformly -100, the placeholder value used for missing labels.
+        self._y_array = 100 * torch.LongTensor(self._metadata_df["tumor"].values)
+        self._metadata_array = torch.stack(
+            (
+                torch.LongTensor(centers),
+                torch.LongTensor(self._metadata_df["slide"].values),
+                self._y_array,
+            ),
+            dim=1,
+        )
+        self._metadata_fields = ["hospital", "slide", "y"]
+
+        self._eval_grouper = CombinatorialGrouper(
+            dataset=self, groupby_fields=["slide"]
+        )
+
+        super().__init__(root_dir, download, split_scheme)
+
+    def get_input(self, idx):
+        """
+        Returns x for a given idx.
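+        Output:
+            - x (PIL.Image): a 96x96 RGB patch (per self._original_resolution above).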
+ """ + img_filename = os.path.join(self.data_dir, self._input_array[idx]) + x = Image.open(img_filename).convert("RGB") + return x diff --git a/wilds/datasets/unlabeled/civilcomments_unlabeled_dataset.py b/wilds/datasets/unlabeled/civilcomments_unlabeled_dataset.py new file mode 100644 index 00000000..5c2ece4d --- /dev/null +++ b/wilds/datasets/unlabeled/civilcomments_unlabeled_dataset.py @@ -0,0 +1,115 @@ +import csv +import os +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import pandas as pd +import numpy as np + +from wilds.datasets.unlabeled.wilds_unlabeled_dataset import WILDSUnlabeledDataset +from wilds.common.utils import map_to_id_array + + +class CivilCommentsUnlabeledDataset(WILDSUnlabeledDataset): + """ + Unlabeled CivilComments-WILDS toxicity classification dataset. + This is a modified version of the original CivilComments dataset. + + Supported `split_scheme`: + 'official' + + Input (x): + A comment on an online article, comprising one or more sentences of text. + + Website: + https://www.kaggle.com/c/jigsaw-unintended-bias-in-toxicity-classification + + Original publication: + @inproceedings{borkan2019nuanced, + title={Nuanced metrics for measuring unintended bias with real data for text classification}, + author={Borkan, Daniel and Dixon, Lucas and Sorensen, Jeffrey and Thain, Nithum and Vasserman, Lucy}, + booktitle={Companion Proceedings of The 2019 World Wide Web Conference}, + pages={491--500}, + year={2019} + } + + License: + This dataset is in the public domain and is distributed under CC0. + https://creativecommons.org/publicdomain/zero/1.0/ + """ + + _NOT_IN_DATASET: int = -1 + + _dataset_name: str = "civilcomments_unlabeled" + _versions_dict: Dict[str, Dict[str, Union[str, int]]] = { + "1.0": { + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x1c471f23448e4518b000fe47aa7724e0/contents/blob/', + 'compressed_size': 254_142_009 + }, + } + + def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): + self._version = version + self._data_dir = self.initialize_data_dir(root_dir, download) + + # Read in metadata + self._metadata_df = pd.read_csv( + os.path.join(self._data_dir, 'unlabeled_data_with_identities.csv'), + index_col=0) + + # Extract text + self._text_array = list(self._metadata_df['comment_text']) + + # Extract splits + self._split_scheme = split_scheme + if self._split_scheme != 'official': + raise ValueError(f'Split scheme {self._split_scheme} not recognized') + + # metadata_df contains split names in strings, so convert them to ints + self._split_dict = { "extra_unlabeled": 13 } + self._split_names = { "extra_unlabeled": "Unlabeled Extra" } + self._metadata_df['split'] = self.split_dict["extra_unlabeled"] + self._split_array = self._metadata_df['split'].values + + # Metadata (Not Available) + # We want grouper to assign all values to their own group, so fill + # all metadata fields with '2'. The normal dataset has binary metadata, + # so this will not overlap. 
+        self._identity_vars = [
+            'male',
+            'female',
+            'LGBTQ',
+            'christian',
+            'muslim',
+            'other_religions',
+            'black',
+            'white'
+        ]
+        self._auxiliary_vars = [
+            'identity_any',
+            'severe_toxicity',
+            'obscene',
+            'threat',
+            'insult',
+            'identity_attack',
+            'sexual_explicit'
+        ]
+
+        self._y_array = torch.LongTensor(self._metadata_df['toxicity'].values >= 0.5)
+        self._metadata_array = torch.cat(
+            (
+                torch.ones(
+                    len(self._metadata_df),
+                    len(self._identity_vars) + len(self._auxiliary_vars)
+                ) * 2,
+                self._y_array.unsqueeze(dim=-1)
+            ),
+            dim=1
+        )
+        self._metadata_fields = self._identity_vars + self._auxiliary_vars + ['y']
+
+        super().__init__(root_dir, download, split_scheme)
+
+    def get_input(self, idx):
+        return self._text_array[idx]
+
diff --git a/wilds/datasets/unlabeled/domainnet_unlabeled_dataset.py b/wilds/datasets/unlabeled/domainnet_unlabeled_dataset.py
new file mode 100644
index 00000000..4ea5efe5
--- /dev/null
+++ b/wilds/datasets/unlabeled/domainnet_unlabeled_dataset.py
@@ -0,0 +1,176 @@
+import csv
+import os
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import pandas as pd
+from PIL import Image
+
+from wilds.common.utils import map_to_id_array
+from wilds.datasets.domainnet_dataset import DOMAIN_NET_CATEGORIES, DOMAIN_NET_DOMAINS, SENTRY_DOMAINS
+from wilds.datasets.unlabeled.wilds_unlabeled_dataset import WILDSUnlabeledDataset
+
+
+class DomainNetUnlabeledDataset(WILDSUnlabeledDataset):
+    """
+    Unlabeled DomainNet dataset.
+
+    Supported `split_scheme`:
+        'official': use the official split from DomainNet
+
+    Input (x):
+        224 x 224 x 3 RGB image.
+
+    Label (y):
+        y is one of the 345 categories in DomainNet.
+
+    Metadata:
+        Each image is annotated with its domain, its category, and its (hidden) label y.
+
+    Website:
+        http://ai.bu.edu/M3SDA
+
+    Original publication:
+
+        @inproceedings{peng2019moment,
+            title={Moment matching for multi-source domain adaptation},
+            author={Peng, Xingchao and Bai, Qinxun and Xia, Xide and Huang, Zijun and Saenko, Kate and Wang, Bo},
+            booktitle={Proceedings of the IEEE International Conference on Computer Vision},
+            pages={1406--1415},
+            year={2019}
+        }
+
+    SENTRY publication:
+
+        @article{prabhu2020sentry,
+            author = {Prabhu, Viraj and Khare, Shivam and Kartik, Deeksha and Hoffman, Judy},
+            title = {SENTRY: Selective Entropy Optimization via Committee Consistency for Unsupervised Domain Adaptation},
+            year = {2020},
+            journal = {arXiv preprint arXiv:2012.11460},
+        }
+
+    Fair Use Notice:
+        "This dataset contains some copyrighted material whose use has not been specifically authorized by the copyright owners.
+        In an effort to advance scientific research, we make this material available for academic research. We believe this
+        constitutes a fair use of any such copyrighted material as provided for in section 107 of the US Copyright Law.
+        In accordance with Title 17 U.S.C. Section 107, the material on this site is distributed without profit for
+        non-commercial research and educational purposes. For more information on fair use please click here. If you wish
+        to use copyrighted material on this site or in our dataset for purposes of your own that go beyond non-commercial
+        research and academic purposes, you must obtain permission directly from the copyright owner."
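+
+    Example:
+        A minimal usage sketch (illustrative only; it assumes the `unlabeled`
+        flag that this change wires into `wilds.get_dataset`):
+
+            import wilds
+            dataset = wilds.get_dataset("domainnet", unlabeled=True, root_dir="data")
+            test_unlabeled = dataset.get_subset("test_unlabeled")
+            x, metadata = test_unlabeled[0]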
+ """ + + _dataset_name: str = "domainnet_unlabeled" + _versions_dict: Dict[str, Dict[str, Union[str, int]]] = { + "1.0": { + "download_url": "https://worksheets.codalab.org/rest/bundles/0x0b8ca76eef384b98b879d0c8c4681a32/contents/blob/", + "compressed_size": 19_255_770_459, + "equivalent_dataset": "domainnet_v1.0", + }, + } + + def __init__( + self, + version: str = None, + root_dir: str = "data", + download: bool = False, + split_scheme: str = "official", + source_domain: str = "sketch", + target_domain: str = "real", + extra_domain: str = "clipart", + use_sentry: bool = False, + ): + # Dataset information + self._version: Optional[str] = version + self._split_scheme: str = split_scheme + self._original_resolution = (224, 224) + self._y_type: str = "long" + self._y_size: int = 1 + # Path of the dataset + self._data_dir: str = self.initialize_data_dir(root_dir, download) + + if use_sentry: + for domain in [source_domain, target_domain, extra_domain]: + assert domain in SENTRY_DOMAINS + print("Using the SENTRY version of DomainNet (unlabeled)...") + metadata_filename = "sentry_metadata.csv" + self._n_classes = 40 + else: + metadata_filename = "metadata.csv" + self._n_classes = 345 + + # Load data + metadata_df: pd.DataFrame = pd.read_csv( + os.path.join(self.data_dir, metadata_filename), + dtype={ + "image_path": str, + "domain": str, + "split": str, + "category": str, + "y": int, + }, + keep_default_na=False, + na_values=[], + quoting=csv.QUOTE_NONNUMERIC, + ) + target_metadata_df = metadata_df.loc[metadata_df["domain"] == target_domain] + extra_metadata_df = metadata_df.loc[metadata_df["domain"] == extra_domain] + metadata_df = pd.concat([target_metadata_df, extra_metadata_df]) + + self._input_image_paths = metadata_df["image_path"].values + self._y_array = torch.from_numpy(metadata_df["y"].values).type(torch.LongTensor) + self.initialize_split_dicts() + self.initialize_split_array(metadata_df, target_domain, extra_domain) + + # Populate metadata fields + self._metadata_fields = ["domain", "category", "y"] + metadata_df = metadata_df[self._metadata_fields] + possible_metadata_values = { + "domain": DOMAIN_NET_DOMAINS, + "category": DOMAIN_NET_CATEGORIES, + "y": range(self._n_classes), + } + self._metadata_map, metadata = map_to_id_array( + metadata_df, possible_metadata_values + ) + self._metadata_array = torch.from_numpy(metadata.astype("long")) + + super().__init__(root_dir, download, self._split_scheme) + + def get_input(self, idx) -> str: + img_path = os.path.join(self.data_dir, self._input_image_paths[idx]) + img = Image.open(img_path).convert("RGB") + return img + + def initialize_split_dicts(self): + if self.split_scheme == "official": + self._split_dict = { + "test_unlabeled": 12, + "extra_unlabeled": 13, + } + self._split_names = { + "test_unlabeled": "Unlabeled Test", + "extra_unlabeled": "Unlabeled Extra", + } + else: + raise ValueError(f"Split scheme {self.split_scheme} is not recognized.") + + def initialize_split_array(self, metadata_df, target_domain, extra_domain): + def get_split(row): + if row["domain"] == target_domain: + if row["split"] == "train": + return 12 + else: + return -1 + elif row["domain"] == extra_domain: + if row["split"] == "train": + return 13 + else: + return -1 + else: + raise ValueError( + f"Domain should be one of {target_domain}, {extra_domain}" + ) + + self._split_array = metadata_df.apply( + lambda row: get_split(row), axis=1 + ).to_numpy() diff --git a/wilds/datasets/unlabeled/fmow_unlabeled_dataset.py 
new file mode 100644
index 00000000..63b95e2d
--- /dev/null
+++ b/wilds/datasets/unlabeled/fmow_unlabeled_dataset.py
@@ -0,0 +1,161 @@
+import datetime
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import pytz
+import torch
+from PIL import Image
+
+from wilds.datasets.unlabeled.wilds_unlabeled_dataset import WILDSUnlabeledDataset
+from wilds.datasets.fmow_dataset import categories
+
+Image.MAX_IMAGE_PIXELS = 10000000000
+
+class FMoWUnlabeledDataset(WILDSUnlabeledDataset):
+    """
+    The unlabeled FMoW-WILDS land use / building classification dataset.
+    This is a processed version of the Functional Map of the World dataset originally sourced from https://github.com/fMoW/dataset.
+
+    Supported `split_scheme`:
+        'official': the official split, which is equivalent to 'time_after_2016'
+        'time_after_{YEAR}' for YEAR between 2002--2018
+
+    Input (x):
+        224 x 224 x 3 RGB satellite image.
+
+    Label (y):
+        y is one of 62 land use / building classes.
+
+    Metadata:
+        Each image is annotated with a location coordinate, a timestamp, and a country code. Regions are derived from the country codes.
+
+    Website: https://github.com/fMoW/dataset
+
+    Original publication:
+        @inproceedings{fmow2018,
+            title={Functional Map of the World},
+            author={Christie, Gordon and Fendley, Neil and Wilson, James and Mukherjee, Ryan},
+            booktitle={CVPR},
+            year={2018}
+        }
+
+    License:
+        Distributed under the FMoW Challenge Public License.
+        https://github.com/fMoW/dataset/blob/master/LICENSE
+    """
+
+    _dataset_name = 'fmow_unlabeled'
+    _versions_dict = {
+        '1.0': {
+            'download_url': 'https://worksheets.codalab.org/rest/bundles/0xaec91eb7c9d548ebb15e1b5e60f966ab/contents/blob/',
+            'compressed_size': 53_893_324_800,
+            'equivalent_dataset': 'fmow_v1.1',
+        }
+    }
+
+    def __init__(self, version=None, root_dir='data', download=False, split_scheme='official', seed=111, use_ood_val=True):
+        self._version = version
+        self._data_dir = self.initialize_data_dir(root_dir, download)
+
+        if split_scheme == 'official':
+            split_scheme = 'time_after_2016'
+        self._split_scheme = split_scheme
+
+        self.root = Path(self._data_dir)
+        self.seed = int(seed)
+        self._original_resolution = (224, 224)
+
+        self.metadata = pd.read_csv(self.root / 'rgb_metadata.csv')
+        country_codes_df = pd.read_csv(self.root / 'country_code_mapping.csv')
+        countrycode_to_region = {k: v for k, v in zip(country_codes_df['alpha-3'], country_codes_df['region'])}
+        regions = [countrycode_to_region.get(code, 'Other') for code in self.metadata['country_code'].to_list()]
+        self.metadata['region'] = regions
+        all_countries = self.metadata['country_code']
+
+        if self._split_scheme.startswith('time_after'):
+            year = int(self._split_scheme.split('_')[2])
+            year_dt = datetime.datetime(year, 1, 1, tzinfo=pytz.UTC)
+            self.test_ood_mask = np.asarray(pd.to_datetime(self.metadata['timestamp']) >= year_dt)
+            # use 3 years of the training set as validation
+            year_minus_3_dt = datetime.datetime(year - 3, 1, 1, tzinfo=pytz.UTC)
+            self.val_ood_mask = np.asarray(pd.to_datetime(self.metadata['timestamp']) >= year_minus_3_dt) & ~self.test_ood_mask
+            self.ood_mask = self.test_ood_mask | self.val_ood_mask
+        else:
+            raise ValueError(f"Not supported: self._split_scheme = {self._split_scheme}")
+
+        if self.split_scheme.startswith('time_after'):
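+            # For the official scheme ('time_after_2016'), images from 2016 on
+            # form the unlabeled test split and images from 2013-2015 form the
+            # unlabeled validation split, per the masks computed above.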
self._split_dict = { + "train_unlabeled": 10, + "val_unlabeled": 11, + "test_unlabeled": 12, + } + self._split_names = { + "train_unlabeled": "Unlabeled Train", + "val_unlabeled": "Unlabeled Validation", + "test_unlabeled": "Unlabeled Test", + } + else: + raise ValueError(f"Split scheme {self.split_scheme} is not recognized.") + + test_mask = np.asarray(self.metadata['split'] == 'test') + val_mask = np.asarray(self.metadata['split'] == 'val') + seq_mask = np.asarray(self.metadata['split'] == 'seq') + self._split_array = -1 * np.ones(len(self.metadata)) + for split in self._split_dict.keys(): + # unused data from labeled FMoW + if split == 'test_unlabeled': + test_unlabeled_mask = self.test_ood_mask & ~test_mask & ~val_mask + idxs = np.arange(len(self.metadata))[test_unlabeled_mask] + + elif split == 'val_unlabeled': + val_unlabeled_mask = self.val_ood_mask & ~test_mask & ~val_mask + idxs = np.arange(len(self.metadata))[val_unlabeled_mask] + + elif split == 'train_unlabeled': + train_unlabeled_mask = seq_mask & ~self.ood_mask + idxs = np.arange(len(self.metadata))[train_unlabeled_mask] + + self._split_array[idxs] = self._split_dict[split] + unlabeled_mask = (self._split_array != -1) + self.full_idxs = np.arange(len(self.metadata))[unlabeled_mask] + self._split_array = self._split_array[unlabeled_mask] + + # convert region to idxs + all_regions = list(self.metadata['region'].unique()) + region_to_region_idx = {region: i for i, region in enumerate(all_regions)} + self._metadata_map = {'region': all_regions} + region_idxs = [region_to_region_idx[region] for region in self.metadata['region'].tolist()] + self.metadata['region'] = region_idxs + + # make a year column in metadata + year_array = -1 * np.ones(len(self.metadata)) + ts = pd.to_datetime(self.metadata['timestamp']) + for year in range(2002, 2018): + year_mask = np.asarray(ts >= datetime.datetime(year, 1, 1, tzinfo=pytz.UTC)) \ + & np.asarray(ts < datetime.datetime(year+1, 1, 1, tzinfo=pytz.UTC)) + year_array[year_mask] = year - 2002 + self.metadata['year'] = year_array + self._metadata_map['year'] = list(range(2002, 2018)) + + # hidden labels + self.category_to_idx = {cat: i for i, cat in enumerate(categories)} + self.metadata['y'] = np.asarray([self.category_to_idx[y] for y in list(self.metadata['category'])]) + self._y_array = torch.LongTensor(self.metadata['y'].values)[unlabeled_mask] + + self._metadata_fields = ['region', 'year', 'y'] + self._metadata_array = torch.from_numpy(self.metadata[self._metadata_fields].astype(int).to_numpy()).long()[unlabeled_mask] + + super().__init__(root_dir, download, split_scheme) + + def get_input(self, idx): + """ + Returns x for a given idx. 
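+        Note: idx indexes into the unlabeled subset; it is first mapped through
+        self.full_idxs to recover the image's index in the original metadata file.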
+ """ + idx = self.full_idxs[idx] + img = Image.open(self.root / 'images' / f'rgb_img_{idx}.png').convert('RGB') + return img diff --git a/wilds/datasets/unlabeled/globalwheat_unlabeled_dataset.py b/wilds/datasets/unlabeled/globalwheat_unlabeled_dataset.py new file mode 100755 index 00000000..6744db65 --- /dev/null +++ b/wilds/datasets/unlabeled/globalwheat_unlabeled_dataset.py @@ -0,0 +1,320 @@ +import numpy as np +import pandas as pd +import torch +from pathlib import Path +from PIL import Image +from wilds.datasets.unlabeled.wilds_unlabeled_dataset import WILDSUnlabeledDataset +from wilds.common.grouper import CombinatorialGrouper +from wilds.datasets.globalwheat_dataset import GlobalWheatDataset + +SESSIONS = [ + 'Arvalis_1', + 'Arvalis_2', + 'Arvalis_3', + 'Arvalis_4', + 'Arvalis_5', + 'Arvalis_6', + 'Arvalis_7', + 'Arvalis_8', + 'Arvalis_9', + 'Arvalis_10', + 'Arvalis_11', + 'Arvalis_12', + 'ETHZ_1', + 'Inrae_1', + 'NMBU_1', + 'NMBU_2', + 'Rres_1', + 'ULiège-GxABT_1', + 'Utokyo_1', + 'Utokyo_2', + 'Utokyo_3', + 'Ukyoto_1', + 'NAU_1', + 'NAU_2', + 'NAU_3', + 'ARC_1', + 'UQ_1', + 'UQ_2', + 'UQ_3', + 'UQ_4', + 'UQ_5', + 'UQ_6', + 'UQ_7', + 'UQ_8', + 'UQ_9', + 'UQ_10', + 'UQ_11', + 'Terraref_1', + 'Terraref_2', + 'KSU_1', + 'KSU_2', + 'KSU_3', + 'KSU_4', + 'CIMMYT_1', + 'CIMMYT_2', + 'CIMMYT_3', + 'Usask_1', + "unlabeled_CIMMYT_PHY_9SAT_HEAT_RGB_20200415", + "unlabeled_CIMMYT_PHY_9SAT_HEAT_RGB_20200422", + "unlabeled_CIMMYT_PHY_BestPT_HEAT_RGB_20200422", + "unlabeled_CIMMYT_PHY_BestPT_HEAT_RGB_280420", + "unlabeled_CIMMYT_PHY_BestPT_RGO_RGB_20200319", + "unlabeled_CIMMYT_PHY_HIBAPIII_HEAT_RGB_20200422", + "unlabeled_CIMMYT_PHY_HIBAPIII_RGO_RGB_20200311", + "unlabeled_CIMMYT_PHY_HIBAPIII_RGO_RGB_20200312", + "unlabeled_CIMMYT_PHY_HIBAPIII_RGO_RGB_20200313", + "unlabeled_CIMMYT_PHY_HIBAPIII_RGO_RGB_20200503", + "unlabeled_CIMMYT_PHY_ROOTS-ANAT_HEAT_RGB_20200422", + "unlabeled_CIMMYT_PHY_ROOTS_ANAT_RGO_RGB_20200313", + "unlabeled_CIMMYT_PHY_ROOTS_ANAT_SQ_RGB_20200313", + "unlabeled_KSU_17ASH_LakinFuller_KEY_PCTHEAD_20170423", + "unlabeled_KSU__16ASH_AM-PANEL_KEY_PCTHEAD_20160504", + "unlabeled_RRes-Dataset1", + "unlabeled_RRes-Dataset2", + "unlabeled_RRes-Dataset3", + "unlabeled_RRes-Dataset4", + "unlabeled_arvalis_greoux_Session_2020-05-25_12-13-52", + "unlabeled_arvalis_greoux_Session_2020-06-22_07-46-39", + "unlabeled_arvalis_VLB_Session 2020-06-02 09-33-05_VBProb", + "unlabeled_arvalis_BIGNAN_Session 2021-06-17 13-49-26", + "unlabeled_arvalis_ENCRAMBADE_Session 2021-05-18 07-59-37", + "unlabeled_arvalis_ENCRAMBADE_Session 2021-06-11 07-48-21", + "unlabeled_arvalis_OLM_Session 2021-06-02 07-47-38", + "unlabeled_arvalis_OLM_Session 2021-06-17 11-55-36", + "unlabeled_arvalis_VLB_Session 2021-05-28 07-08-58", + "unlabeled_arvalis_VLB_Session 2021-06-14 08-51-33", + "unlabeled_arvalis_mons_2018", + "unlabeled_hokkaido_SK_SK_2021_7_20", + "unlabeled_hokkaido_SK_SK_2021_7_28", + "unlabeled_hokkaido_SM_SM_2021_7_10", + "unlabeled_hokkaido_Tsujino_2021_6_16", + "unlabeled_hokkaido_Tsujino_2021_6_23", + "unlabeled_hokkaido_Tsujino_2021_6_7", + "unlabeled_hokkaido_Tsujino_2021_6_9", + "unlabeled_hokkaido_Tsujino_2021_7_10", + "unlabeled_hokkaido_Tsujino_2021_7_11", + "unlabeled_hokkaido_Tsujino_2021_7_20", + "unlabeled_hokkaido_Tsujino_2021_7_3", + "unlabeled_inrae_clermont", + "unlabeled_uliege_6_11", + "unlabeled_uliege_6_15", + "unlabeled_uliege_6_16", + "unlabeled_uliege_6_18", + "unlabeled_uliege_6_23", + "unlabeled_uliege_6_26", + "unlabeled_uliege_7_13", + "unlabeled_uliege_7_7", + 
"unlabeled_usask_2019_08_06_sampled", + "unlabeled_usask_2019_08_12_sampled", + "unlabeled_ETHZ_2" +] + +COUNTRIES = [ + 'Switzerland', + 'UK', + 'Belgium', + 'Norway', + 'France', + 'Canada', + 'US', + 'Mexico', + 'Japan', + 'China', + 'Australia', + 'Sudan', +] + +LOCATIONS = [ + 'Baima', + 'Brookstead', + 'Ciudad Obregon', + 'Gatton', + 'Gembloux', + 'Gréoux', + 'KSU', + 'Kyoto', + 'Maricopa, AZ', + 'McAllister', + 'Mons', + 'NARO-Hokkaido', + 'NARO-Tsukuba', + 'NMBU', + 'Rothamsted', + 'Saskatchewan', + 'Toulouse', + 'Usask', + 'VLB', + 'VSC', + 'Wad Medani', + 'Eschikon', + 'Bignan', + 'Clermont', + 'Encrambade', + 'NMBU', + 'OLM', +] + +STAGES = [ + 'Filling', + 'Filling-Ripening', + 'multiple', + 'Post-flowering', + 'Ripening', + 'Emergence', +] + +class GlobalWheatUnlabeledDataset(WILDSUnlabeledDataset): + """ + The GlobalWheat-WILDS wheat head localization dataset. + This is a modified version of the original Global Wheat Head Dataset 2021. + + Supported `split_scheme`: + - 'official' + - 'official_with_subsampled_test' + - 'fixed-test' + - 'mixed-train' + Input (x): + 1024 x 1024 RGB images of wheat field canopy starting from anthesis (flowering) to ripening. + Metadata: + Each image is annotated with the ID of the domain (session) it came from (integer from 0 to 46). + Website: + http://www.global-wheat.com/ + Original publication: + @article{david_global_2020, + title = {Global {Wheat} {Head} {Detection} ({GWHD}) {Dataset}: {A} {Large} and {Diverse} {Dataset} of {High}-{Resolution} {RGB}-{Labelled} {Images} to {Develop} and {Benchmark} {Wheat} {Head} {Detection} {Methods}}, + volume = {2020}, + url = {https://doi.org/10.34133/2020/3521852}, + doi = {10.34133/2020/3521852}, + journal = {Plant Phenomics}, + author = {David, Etienne and Madec, Simon and Sadeghi-Tehran, Pouria and Aasen, Helge and Zheng, Bangyou and Liu, Shouyang and Kirchgessner, Norbert and Ishikawa, Goro and Nagasawa, Koichi and Badhon, Minhajul A. and Pozniak, Curtis and de Solan, Benoit and Hund, Andreas and Chapman, Scott C. and Baret, Frédéric and Stavness, Ian and Guo, Wei}, + month = Aug, + year = {2020}, + note = {Publisher: AAAS}, + pages = {3521852}, + } + @misc{david2021global, + title={Global Wheat Head Dataset 2021: more diversity to improve the benchmarking of wheat head localization methods}, + author={Etienne David and Mario Serouart and Daniel Smith and Simon Madec and Kaaviya Velumani and Shouyang Liu and Xu Wang and Francisco Pinto Espinosa and Shahameh Shafiee and Izzat S. A. Tahir and Hisashi Tsujimoto and Shuhei Nasuda and Bangyou Zheng and Norbert Kichgessner and Helge Aasen and Andreas Hund and Pouria Sadhegi-Tehran and Koichi Nagasawa and Goro Ishikawa and Sébastien Dandrifosse and Alexis Carlier and Benoit Mercatoris and Ken Kuroki and Haozhou Wang and Masanori Ishii and Minhajul A. Badhon and Curtis Pozniak and David Shaner LeBauer and Morten Lilimo and Jesse Poland and Scott Chapman and Benoit de Solan and Frédéric Baret and Ian Stavness and Wei Guo}, + year={2021}, + eprint={2105.07660}, + archivePrefix={arXiv}, + primaryClass={cs.CV} + } + License: + This dataset is distributed under the MIT license. 
+ """ + + _dataset_name = "globalwheat_unlabeled" + _versions_dict = { + "1.0": { + "download_url": "https://worksheets.codalab.org/rest/bundles/0x59d4c1b8b7ad45cc83e080d11d1eaf94/contents/blob/", + "compressed_size": 103_766_940_000, + } + } + + def __init__( + self, version=None, root_dir="data", download=False, split_scheme="official" + ): + + self._version = version + self._data_dir = self.initialize_data_dir(root_dir, download) + self._original_resolution = (1024, 1024) + self.root = Path(self.data_dir) + + self._n_classes = 1 + self._split_scheme = split_scheme + + data_dfs = {} + + if self._split_scheme == "official": + self._split_dict = { + "train_unlabeled": 10, + "val_unlabeled": 11, + "test_unlabeled": 12, + "extra_unlabeled": 13, + } + self._split_names = { + "train_unlabeled": "Unlabeled Train", + "val_unlabeled": "Unlabeled Validation", + "test_unlabeled": "Unlabeled Test", + "extra_unlabeled": "Unlabeled Extra", + } + + data_dfs["train_unlabeled"] = pd.read_csv( + self.root / f"official_train_unlabeled.csv" + ) + + data_dfs["val_unlabeled"] = pd.read_csv( + self.root / f"official_val_unlabeled.csv" + ) + + data_dfs["test_unlabeled"] = pd.read_csv( + self.root / f"official_test_unlabeled.csv" + ) + + data_dfs["extra_unlabeled"] = pd.read_csv( + self.root / f"official_extra_unlabeled.csv" + ) + + else: + raise ValueError(f"Split scheme {self._split_scheme} not recognized") + + self._image_array = [] + self._split_array = [] + self._metadata_array = [] + + # Extract splits + + for split_name, split_idx in self._split_dict.items(): + df = data_dfs[split_name] + + self._image_array.extend(list(df["image_name"].values)) + self._split_array.extend([split_idx] * len(df)) + + self._metadata_array.extend([int(item) for item in df["domain"].values]) + + self._split_array = np.array(self._split_array) + self._metadata_array = torch.tensor( + self._metadata_array, dtype=torch.long + ).unsqueeze(1) + self._metadata_array = torch.cat( + ( + self._metadata_array, + torch.zeros((len(self._metadata_array), 3), dtype=torch.long), + ), + dim=1, + ) + + domain_df = pd.read_csv(self.root / "metadata_domain_unlabeled.csv", sep=";") + + for session_idx, session_name in enumerate(SESSIONS): + idx = pd.Index(domain_df["name"]).get_loc(session_name) + + country = domain_df.loc[idx, "country"] + location = domain_df.loc[idx, "location"] + stage = domain_df.loc[idx, "development_stage"] + + session_mask = self._metadata_array[:, 0] == session_idx + + self._metadata_array[session_mask, 1] = COUNTRIES.index(country) + self._metadata_array[session_mask, 2] = LOCATIONS.index(location) + self._metadata_array[session_mask, 3] = STAGES.index(stage) + + self._metadata_fields = ["session", "country", "location", "stage"] + self._metadata_map = { + "session": SESSIONS, + "country": COUNTRIES, + "location": LOCATIONS, + "stage": STAGES, + } + + super().__init__(root_dir, download, split_scheme) + + def get_input(self, idx): + """ + Returns x for a given idx. 
+ """ + img_filename = self.root / "images" / self._image_array[idx] + x = Image.open(img_filename) + return x diff --git a/wilds/datasets/unlabeled/iwildcam_unlabeled_dataset.py b/wilds/datasets/unlabeled/iwildcam_unlabeled_dataset.py new file mode 100644 index 00000000..a973a4b9 --- /dev/null +++ b/wilds/datasets/unlabeled/iwildcam_unlabeled_dataset.py @@ -0,0 +1,153 @@ +from datetime import datetime +from pathlib import Path +import os + +from PIL import Image +import pandas as pd +import numpy as np +import torch +import json + +from wilds.datasets.unlabeled.wilds_unlabeled_dataset import WILDSUnlabeledDataset +from wilds.common.grouper import CombinatorialGrouper +from wilds.common.metrics.all_metrics import Accuracy, Recall, F1 + + +class IWildCamUnlabeledDataset(WILDSUnlabeledDataset): + """ + The unlabeled iWildCam2020-WILDS dataset. + This is a modified version of the original iWildCam2020 competition dataset. + Input (x): + RGB images from camera traps + Metadata: + Each image is annotated with the ID of the location (camera trap) it came from. + Website: + http://lila.science/datasets/wcscameratraps + https://library.wcs.org/ScienceData/Camera-Trap-Data-Summary.aspx + Original publication: + @misc{wcsdataset, + title = {Wildlife Conservation Society Camera Traps Dataset}, + howpublished = {\\url{http://lila.science/datasets/wcscameratraps}}, + } + License: + This dataset is distributed under Community Data License Agreement – Permissive – Version 1.0 + https://cdla.io/permissive-1-0/ + """ + + _dataset_name = "iwildcam_unlabeled" + _versions_dict = { + "1.0": { + "download_url": "https://worksheets.codalab.org/rest/bundles/0xff56ea50fbf64aabbc4d09b2e8d50e18/contents/blob/", + "compressed_size": 41_016_937_676, + } + } + + def __init__( + self, version=None, root_dir="data", download=False, split_scheme="official" + ): + + self._version = version + self._split_scheme = split_scheme + if self._split_scheme != "official": + raise ValueError(f"Split scheme {self._split_scheme} not recognized") + + # path + self._data_dir = Path(self.initialize_data_dir(root_dir, download)) + + # Load splits + df = pd.read_csv(self._data_dir / "metadata.csv") + + # Splits + self._split_dict = {"extra_unlabeled": 0} + self._split_names = {"extra_unlabeled": "Extra Unlabeled"} + df["split_id"] = 0 + self._split_array = df["split_id"].values + + # Filenames + df["filename"] = df["uid"].apply(lambda x: x + ".jpg") + self._input_array = df["filename"].values + + # Location/group info + n_groups = df["location_remapped"].nunique() + self._n_groups = n_groups + + def get_date(x): + if isinstance(x, str): + return datetime.strptime(x, "%Y-%m-%d %H:%M:%S.%f") + else: + return -1 + + ## Extract datetime subcomponents and include in metadata + df["datetime_obj"] = df["datetime"].apply(get_date) + df["year"] = df["datetime_obj"].apply( + lambda x: int(x.year) if isinstance(x, datetime) else -1 + ) + df["month"] = df["datetime_obj"].apply( + lambda x: int(x.month) if isinstance(x, datetime) else -1 + ) + df["day"] = df["datetime_obj"].apply( + lambda x: int(x.day) if isinstance(x, datetime) else -1 + ) + df["hour"] = df["datetime_obj"].apply( + lambda x: int(x.hour) if isinstance(x, datetime) else -1 + ) + df["minute"] = df["datetime_obj"].apply( + lambda x: int(x.minute) if isinstance(x, datetime) else -1 + ) + df["second"] = df["datetime_obj"].apply( + lambda x: int(x.second) if isinstance(x, datetime) else -1 + ) + + df["y"] = df["y"].apply( # filter out "bad" labels (-1 means the category was not in 
iwildcam_v2.0; 99999 means the category was unknown). Map all of these to -100.
+            lambda x: x if ((x != -1) and (x != 99999)) else -100
+        )
+        self._y_array = torch.LongTensor(df['y'].values)
+
+        self._metadata_array = torch.tensor(
+            np.stack(
+                [
+                    df["location_remapped"].values,
+                    df["sequence_remapped"].values,
+                    df["year"].values,
+                    df["month"].values,
+                    df["day"].values,
+                    df["hour"].values,
+                    df["minute"].values,
+                    df["second"].values,
+                    df["y"].values,
+                ],
+                axis=1,
+            )
+        )
+        self._metadata_fields = [
+            "location",
+            "sequence",
+            "year",
+            "month",
+            "day",
+            "hour",
+            "minute",
+            "second",
+            "y",
+        ]
+
+        # eval grouper
+        self._eval_grouper = CombinatorialGrouper(
+            dataset=self, groupby_fields=(["location"])
+        )
+
+        super().__init__(root_dir, download, split_scheme)
+
+    def get_input(self, idx):
+        """
+        Args:
+            - idx (int): Index of a data point
+        Output:
+            - x (PIL.Image): Input image of the idx-th data point
+        """
+
+        # All images are in the train folder
+        img_path = self.data_dir / "images" / self._input_array[idx]
+        img = Image.open(img_path)
+
+        return img
diff --git a/wilds/datasets/unlabeled/ogbmolpcba_unlabeled_dataset.py b/wilds/datasets/unlabeled/ogbmolpcba_unlabeled_dataset.py
new file mode 100644
index 00000000..b030cc04
--- /dev/null
+++ b/wilds/datasets/unlabeled/ogbmolpcba_unlabeled_dataset.py
@@ -0,0 +1,101 @@
+import os
+import torch
+import numpy as np
+
+from wilds.datasets.unlabeled.wilds_unlabeled_dataset import WILDSUnlabeledDataset
+
+from ogb.graphproppred import PygGraphPropPredDataset
+from packaging.version import parse as parse_version
+import torch_geometric
+# Parse versions before comparing; plain string comparison misorders e.g. '1.10.0' vs '1.7.0'.
+if parse_version(torch_geometric.__version__) >= parse_version('2.0.0'):
+    from torch_geometric.loader.dataloader import Collater as PyGCollater
+else:
+    from torch_geometric.data.dataloader import Collater as PyGCollater
+
+class OGBPCBAUnlabeledDataset(WILDSUnlabeledDataset):
+    """
+    Unlabeled dataset for OGB-molpcba. There are 5 million unlabeled molecules randomly sampled from the entire PubChem database.
+
+    Input (x):
+        Molecular graphs represented as Pytorch Geometric data objects
+
+    Metadata:
+        - scaffold
+          Each molecule is annotated with the scaffold ID that the molecule is assigned to.
+
+    Website:
+        https://ogb.stanford.edu/docs/graphprop/#ogbg-mol
+
+    Original publication:
+        @article{hu2020ogb,
+            title={Open Graph Benchmark: Datasets for Machine Learning on Graphs},
+            author={W. {Hu}, M. {Fey}, M. {Zitnik}, Y. {Dong}, H. {Ren}, B. {Liu}, M. {Catasta}, J. {Leskovec}},
+            journal={arXiv preprint arXiv:2005.00687},
+            year={2020}
+        }
+
+    License:
+        This dataset is distributed under the MIT license.
+        https://github.com/snap-stanford/ogb/blob/master/LICENSE
+    """
+
+    _dataset_name = 'ogb-molpcba_unlabeled'
+    _versions_dict = {
+        '1.0': {
+            'download_url': None,
+            'compressed_size': None}}
+
+    def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'):
+        self._version = version
+        if version is not None:
+            raise ValueError('Versioning for Unlabeled MolPCBA is handled through the OGB package. Please set version=None.')
+        # internally call ogb package
+
+        ### Setting up meta-information for the dataset
+        meta_dict = {}
+        meta_dict['dir_path'] = os.path.join(root_dir, 'molpcba_unlabeled')
+        meta_dict['url'] = 'http://snap.stanford.edu/ogb/data/wilds/molpcba_unlabeled.zip'
+        meta_dict['num tasks'] = 0
+        meta_dict['eval metric'] = None
+        meta_dict['download_name'] = 'molpcba_unlabeled'
+        meta_dict['version'] = 1
+        meta_dict['add_inverse_edge'] = 'False'
+        meta_dict['data type'] = 'mol'
+        meta_dict['has_node_attr'] = 'True'
+        meta_dict['has_edge_attr'] = 'True'
+        meta_dict['task type'] = 'classification'
+        meta_dict['num classes'] = -1
+        meta_dict['split'] = 'scaffold'
+        meta_dict['additional node files'] = 'None'
+        meta_dict['additional edge files'] = 'None'
+        meta_dict['binary'] = 'True'
+
+        self.ogb_dataset = PygGraphPropPredDataset(name='molpcba_unlabeled', root=root_dir, meta_dict=meta_dict)
+        self.ogb_dataset.data.y = None
+
+        # set variables
+        self._data_dir = self.ogb_dataset.root
+        if split_scheme == 'official':
+            split_scheme = 'scaffold'
+        self._split_scheme = split_scheme
+
+        self._split_array = torch.zeros(len(self.ogb_dataset)).long()
+        split_idx = self.ogb_dataset.get_idx_split()
+        self._split_array[split_idx['train']] = 10
+        self._split_array[split_idx['valid']] = 11
+        self._split_array[split_idx['test']] = 12
+
+        self._metadata_fields = ['scaffold']
+
+        metadata_file_path = os.path.join(self.ogb_dataset.root, 'processed', 'group_assignment.npy')
+        self._metadata_array = torch.from_numpy(np.load(metadata_file_path)).reshape(-1, 1).long()
+
+        if parse_version(torch_geometric.__version__) >= parse_version('1.7.0'):
+            self._collate = PyGCollater(follow_batch=[], exclude_keys=[])
+        else:
+            self._collate = PyGCollater(follow_batch=[])
+
+        super().__init__(root_dir, download, split_scheme)
+
+    def get_input(self, idx):
+        return self.ogb_dataset[int(idx)]
diff --git a/wilds/datasets/unlabeled/poverty_unlabeled_dataset.py b/wilds/datasets/unlabeled/poverty_unlabeled_dataset.py
new file mode 100644
index 00000000..3d15109d
--- /dev/null
+++ b/wilds/datasets/unlabeled/poverty_unlabeled_dataset.py
@@ -0,0 +1,149 @@
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import torch
+
+from wilds.datasets.unlabeled.wilds_unlabeled_dataset import WILDSUnlabeledDataset
+from wilds.datasets.poverty_dataset import (
+    DHS_COUNTRIES,
+    SURVEY_NAMES,
+    split_by_countries,
+)
+
+
+class PovertyMapUnlabeledDataset(WILDSUnlabeledDataset):
+    """
+    The unlabeled PovertyMap-WILDS poverty measure prediction dataset.
+    This is a processed version of LandSat 5/7/8 satellite imagery originally from Google Earth Engine under the names `LANDSAT/LC08/C01/T1_SR`, `LANDSAT/LE07/C01/T1_SR`, and `LANDSAT/LT05/C01/T1_SR`,
+    of nighttime light imagery from the DMSP and VIIRS satellites (Google Earth Engine names `NOAA/DMSP-OLS/CALIBRATED_LIGHTS_V4` and `NOAA/VIIRS/DNB/MONTHLY_V1/VCMSLCFG`),
+    and of processed DHS survey metadata obtained from https://github.com/sustainlab-group/africa_poverty and originally from https://dhsprogram.com/data/available-datasets.cfm.
+    Unlabeled data are sampled from around DHS survey locations.
+
+    Supported `split_scheme`:
+        'official' and 'countries', which are equivalent
+
+    Input (x):
+        224 x 224 x 8 satellite image, with 7 channels from LandSat and 1 nighttime light channel from DMSP/VIIRS. Already mean/std normalized.
+
+    Output (y):
+        y is a real-valued asset wealth index. A higher index corresponds to more asset wealth.
+
+    Metadata:
+        Each image is annotated with location coordinates (noised for anonymity), survey year, urban/rural classification, country, nighttime light mean, and nighttime light median.
+
+    Website: https://github.com/sustainlab-group/africa_poverty
+
+    Original publication:
+        @article{yeh2020using,
+            author = {Yeh, Christopher and Perez, Anthony and Driscoll, Anne and Azzari, George and Tang, Zhongyi and Lobell, David and Ermon, Stefano and Burke, Marshall},
+            day = {22},
+            doi = {10.1038/s41467-020-16185-w},
+            issn = {2041-1723},
+            journal = {Nature Communications},
+            month = {5},
+            number = {1},
+            title = {{Using publicly available satellite imagery and deep learning to understand economic well-being in Africa}},
+            url = {https://www.nature.com/articles/s41467-020-16185-w},
+            volume = {11},
+            year = {2020}
+        }
+
+    License:
+        LandSat/DMSP/VIIRS data is in the U.S. Public Domain.
+    """
+
+    _dataset_name = 'poverty_unlabeled'
+    _versions_dict = {
+        '1.0': {
+            'download_url': 'https://worksheets.codalab.org/rest/bundles/0xdfcf71b4f6164cc1a7edb0cbb7444c8c/contents/blob/',
+            'compressed_size': 172_742_430_134,
+        }
+    }
+
+    def __init__(self, version=None, root_dir='data', download=False,
+                 split_scheme='official',
+                 no_nl=False, fold='A',
+                 use_ood_val=True,
+                 cache_size=100):
+        self._version = version
+        self._data_dir = self.initialize_data_dir(root_dir, download)
+        self._original_resolution = (224, 224)
+
+        if split_scheme == 'official':
+            split_scheme = 'countries'
+
+        self._split_scheme = split_scheme
+        if self._split_scheme == 'countries':
+            self._split_dict = {
+                "train_unlabeled": 10,
+                "val_unlabeled": 11,
+                "test_unlabeled": 12,
+            }
+            self._split_names = {
+                "train_unlabeled": "Unlabeled Train",
+                "val_unlabeled": "Unlabeled Validation",
+                "test_unlabeled": "Unlabeled Test",
+            }
+        else:
+            raise ValueError("Split scheme not recognized")
+
+        self.no_nl = no_nl
+        if fold not in {'A', 'B', 'C', 'D', 'E'}:
+            raise ValueError("Fold must be A, B, C, D, or E")
+
+        self.root = Path(self._data_dir)
+        self.metadata = pd.read_csv(self.root / 'unlabeled_metadata.csv')
+        country_folds = SURVEY_NAMES[f'2009-17{fold}']
+
+        self._split_array = -1 * np.ones(len(self.metadata))
+
+        incountry_folds_split = np.arange(len(self.metadata))
+        # take the test countries to be ood
+        idxs_id, idxs_ood_test = split_by_countries(incountry_folds_split, country_folds['test'], self.metadata)
+        # also create a validation OOD set
+        idxs_id, idxs_ood_val = split_by_countries(idxs_id, country_folds['val'], self.metadata)
+
+        self._split_array[idxs_id] = self._split_dict['train_unlabeled']
+        self._split_array[idxs_ood_val] = self._split_dict['val_unlabeled']
+        self._split_array[idxs_ood_test] = self._split_dict['test_unlabeled']
+
+        # no labels
+        self.metadata['y'] = (-100 * np.ones(len(self.metadata)))
+        # no urban/rural classification
+        self.metadata['urban'] = (-100 * np.ones(len(self.metadata)))
+
+        # add country group field
+        country_to_idx = {country: i for i, country in enumerate(DHS_COUNTRIES)}
+        self.metadata['country'] = [country_to_idx[country] for country in self.metadata['country'].tolist()]
+        self._metadata_map = {'country': DHS_COUNTRIES}
+        # assemble the metadata fields; y and urban are -100 placeholders for the unlabeled data
+        self._metadata_fields = ['urban', 'y', 'country']
+        self._metadata_array = torch.from_numpy(self.metadata[self._metadata_fields].astype(float).to_numpy())
+        super().__init__(root_dir, download, split_scheme)
+
+    def get_input(self, idx):
+        """
+        Returns x for a given idx.
+        """
+        img = np.load(self.root / 'images' / f'landsat_poverty_img_{idx}.npz')['x']
+        if self.no_nl:
+            img[-1] = 0
+        img = torch.from_numpy(img).float()
+
+        return img
diff --git a/wilds/datasets/unlabeled/wilds_unlabeled_dataset.py b/wilds/datasets/unlabeled/wilds_unlabeled_dataset.py
new file mode 100644
index 00000000..ef111fb7
--- /dev/null
+++ b/wilds/datasets/unlabeled/wilds_unlabeled_dataset.py
@@ -0,0 +1,244 @@
+import os
+
+import torch
+import numpy as np
+
+from wilds.datasets.wilds_dataset import WILDSDataset
+
+
+class WILDSUnlabeledDataset(WILDSDataset):
+    """
+    Shared dataset class for all unlabeled WILDS datasets.
+    Each data point in the dataset is an (x, metadata) tuple, where:
+    - x is the input features
+    - metadata is a vector of relevant information, e.g., domain.
+    """
+
+    # The corresponding indices for the unlabeled splits should not overlap with
+    # the indices of their labeled counterparts (indices start from 0).
+    # So, for unlabeled splits, the indices should start from 10.
+    DEFAULT_SPLITS = {
+        "train_unlabeled": 10,
+        "val_unlabeled": 11,
+        "test_unlabeled": 12,
+        "extra_unlabeled": 13,
+    }
+    DEFAULT_SPLIT_NAMES = {
+        "train_unlabeled": "Unlabeled Train",
+        "val_unlabeled": "Unlabeled Validation",
+        "test_unlabeled": "Unlabeled Test",
+        "extra_unlabeled": "Unlabeled Extra",
+    }
+    DEFAULT_SOURCE_DOMAIN_SPLITS = [10]
+
+    _UNSUPPORTED_FUNCTIONALITY_ERROR = "Not supported - no labels available."
+
+    def __len__(self):
+        return len(self.metadata_array)
+
+    def __getitem__(self, idx):
+        # Any transformations are handled by the WILDSSubset
+        # since different subsets (e.g., train vs test) might have different transforms
+        x = self.get_input(idx)
+        metadata = self.metadata_array[idx]
+        return x, metadata
+
+    def get_subset(self, split, frac=1.0, transform=None, load_y=False):
+        """
+        Args:
+            - split (str): Split identifier, e.g., 'train_unlabeled', 'val_unlabeled',
+                           'test_unlabeled', 'extra_unlabeled'. Must be in self.split_dict.
+            - frac (float): What fraction of the split to randomly sample.
+                            Used for fast development on a small dataset.
+            - transform (function): Any data transformations to be applied to the input x.
+            - load_y (bool): If True, the subset also returns the (hidden) labels y.
+        Output:
+            - subset (WILDSUnlabeledSubset): A (potentially subsampled) subset of the dataset.
+        """
+        if split not in self.split_dict:
+            raise ValueError(f"Split {split} not found in dataset's split_dict.")
+        split_mask = self.split_array == self.split_dict[split]
+        split_idx = np.where(split_mask)[0]
+
+        if frac < 1.0:
+            num_to_retain = int(np.round(float(len(split_idx)) * frac))
+            split_idx = np.sort(np.random.permutation(split_idx)[:num_to_retain])
+
+        return WILDSUnlabeledSubset(self, split_idx, transform, load_y=load_y)
+
+    def check_init(self):
+        """
+        Convenience function to check that the WILDSUnlabeledDataset is properly configured.
+        """
+        required_attrs = [
+            "_dataset_name",
+            "_data_dir",
+            "_split_scheme",
+            "_split_array",
+            "_metadata_fields",
+            "_metadata_array",
+        ]
+        for attr_name in required_attrs:
+            assert hasattr(
+                self, attr_name
+            ), f"WILDSUnlabeledDataset is missing {attr_name}."
+
+        # Check that data directory exists
+        if not os.path.exists(self.data_dir):
+            raise ValueError(
+                f"{self.data_dir} does not exist yet. Please generate the dataset first."
+ ) + + # Check splits + assert self.split_dict.keys() == self.split_names.keys() + + # Check that required arrays are Tensors + assert isinstance( + self.metadata_array, torch.Tensor + ), "metadata_array must be a torch.Tensor" + + # Check that dimensions match + assert len(self.split_array) == len(self.metadata_array) + + # Check metadata + assert len(self.metadata_array.shape) == 2 + assert len(self.metadata_fields) == self.metadata_array.shape[1] + + def initialize_data_dir(self, root_dir, download): + if "equivalent_dataset" in self.versions_dict[self.version]: + self.check_version() + os.makedirs(root_dir, exist_ok=True) + + # If the dataset has an equivalent dataset, check if the equivalent dataset already exists + # at the root directory. If it does, don't download and just return the equivalent dataset path. + data_dir = os.path.join( + root_dir, self.versions_dict[self.version]["equivalent_dataset"] + ) + if not os.path.exists(data_dir): + # Proceed with downloading the equivalent dataset. + self.download_dataset(data_dir, download) + return data_dir + else: + return super().initialize_data_dir(root_dir, download) + + def eval(self, y_pred, y_true, metadata): + raise AttributeError(WILDSUnlabeledDataset._UNSUPPORTED_FUNCTIONALITY_ERROR) + + @property + def y_array(self): + raise AttributeError(WILDSUnlabeledDataset._UNSUPPORTED_FUNCTIONALITY_ERROR) + + @property + def y_size(self): + raise AttributeError(WILDSUnlabeledDataset._UNSUPPORTED_FUNCTIONALITY_ERROR) + + @property + def split_dict(self): + """ + A dictionary mapping splits to integer identifiers (used in split_array), + Keys should match up with split_names. + """ + return getattr(self, "_split_dict", WILDSUnlabeledDataset.DEFAULT_SPLITS) + + @property + def split_names(self): + """ + A dictionary mapping splits to their pretty names, + Keys should match up with split_dict. + """ + return getattr(self, "_split_names", WILDSUnlabeledDataset.DEFAULT_SPLIT_NAMES) + + @property + def source_domain_splits(self): + """ + List of split IDs that are from the source domain. 
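+        For unlabeled datasets this defaults to [10], i.e. 'train_unlabeled',
+        mirroring the labeled WILDSDataset default of [0], i.e. 'train'.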
+ """ + return getattr( + self, + "_source_domain_splits", + WILDSUnlabeledDataset.DEFAULT_SOURCE_DOMAIN_SPLITS, + ) + + +class WILDSUnlabeledSubset(WILDSUnlabeledDataset): + def __init__(self, dataset, indices, transform, load_y=False): + self.dataset = dataset + self.indices = indices + inherited_attrs = [ + "_dataset_name", + "_data_dir", + "_collate", + "_split_scheme", + "_split_dict", + "_split_names", + "_metadata_fields", + "_metadata_map", + "_y_array", + ] + for attr_name in inherited_attrs: + if hasattr(dataset, attr_name): + setattr(self, attr_name, getattr(dataset, attr_name)) + self.transform = transform + self.load_y = load_y + + def __getitem__(self, idx): + x, metadata = self.dataset[self.indices[idx]] + if self.transform is not None: + x = self.transform(x) + if self.load_y: + y = self._y_array[self.indices[idx]] + return x, y, metadata + else: + return x, metadata + + def __len__(self): + return len(self.indices) + + @property + def split_array(self): + return self.dataset._split_array[self.indices] + + @property + def metadata_array(self): + return self.dataset.metadata_array[self.indices] + +class WILDSPseudolabeledSubset(WILDSUnlabeledDataset): + """Pseudolabeled subset initialized from an unlabeled subset""" + def __init__(self, reference_subset, pseudolabels, transform, collate=None): + assert len(reference_subset) == len(pseudolabels) + self.pseudolabels = pseudolabels + copied_attrs = [ + "dataset", + "indices", + "_dataset_name", + "_data_dir", + "_collate", + "_split_scheme", + "_split_dict", + "_split_names", + "_metadata_fields", + "_metadata_map", + ] + for attr_name in copied_attrs: + if hasattr(reference_subset, attr_name): + setattr(self, attr_name, getattr(reference_subset, attr_name, None)) + self.transform = transform + if collate: + self._collate = collate + + def __getitem__(self, idx): + x, metadata = self.dataset[self.indices[idx]] + y_pseudo = self.pseudolabels[idx] + if self.transform is not None: + x = self.transform(x) + return x, y_pseudo, metadata + + def __len__(self): + return len(self.indices) + + @property + def split_array(self): + return self.dataset._split_array[self.indices] + + @property + def metadata_array(self): + return self.dataset.metadata_array[self.indices] \ No newline at end of file diff --git a/wilds/datasets/wilds_dataset.py b/wilds/datasets/wilds_dataset.py index 094c1a28..d4f301e3 100644 --- a/wilds/datasets/wilds_dataset.py +++ b/wilds/datasets/wilds_dataset.py @@ -15,10 +15,12 @@ class WILDSDataset: """ DEFAULT_SPLITS = {'train': 0, 'val': 1, 'test': 2} DEFAULT_SPLIT_NAMES = {'train': 'Train', 'val': 'Validation', 'test': 'Test'} + DEFAULT_SOURCE_DOMAIN_SPLITS = [0] def __init__(self, root_dir, download, split_scheme): if len(self._metadata_array.shape) == 1: self._metadata_array = self._metadata_array.unsqueeze(1) + self._add_coarse_domain_metadata() self.check_init() def __len__(self): @@ -66,13 +68,32 @@ def get_subset(self, split, frac=1.0, transform=None): """ if split not in self.split_dict: raise ValueError(f"Split {split} not found in dataset's split_dict.") + split_mask = self.split_array == self.split_dict[split] split_idx = np.where(split_mask)[0] + if frac < 1.0: + # Randomly sample a fraction of the split num_to_retain = int(np.round(float(len(split_idx)) * frac)) split_idx = np.sort(np.random.permutation(split_idx)[:num_to_retain]) - subset = WILDSSubset(self, split_idx, transform) - return subset + + return WILDSSubset(self, split_idx, transform) + + def _add_coarse_domain_metadata(self): + """ + Update 
metadata fields, map and values with coarse-grained domain information. + """ + if hasattr(self, '_metadata_map'): + self._metadata_map['from_source_domain'] = [False, True] + self._metadata_fields.append('from_source_domain') + from_source_domain = torch.as_tensor( + [1 if split in self.source_domain_splits else 0 for split in self.split_array], + dtype=torch.int64 + ).unsqueeze(dim=1) + self._metadata_array = torch.cat( + [self._metadata_array, from_source_domain], + dim=1 + ) def check_init(self): """ @@ -205,6 +226,13 @@ def split_names(self): """ return getattr(self, '_split_names', WILDSDataset.DEFAULT_SPLIT_NAMES) + @property + def source_domain_splits(self): + """ + List of split IDs that are from the source domain. + """ + return getattr(self, '_source_domain_splits', WILDSDataset.DEFAULT_SOURCE_DOMAIN_SPLITS) + @property def split_array(self): """ @@ -302,64 +330,51 @@ def initialize_data_dir(self, root_dir, download): Datasets for which we don't control the download, like Yelp, might not handle versions similarly. """ - if self.version not in self.versions_dict: - raise ValueError(f'Version {self.version} not supported. Must be in {self.versions_dict.keys()}.') - - download_url = self.versions_dict[self.version]['download_url'] - compressed_size = self.versions_dict[self.version]['compressed_size'] + self.check_version() os.makedirs(root_dir, exist_ok=True) - data_dir = os.path.join(root_dir, f'{self.dataset_name}_v{self.version}') version_file = os.path.join(data_dir, f'RELEASE_v{self.version}.txt') - current_major_version, current_minor_version = tuple(map(int, self.version.split('.'))) - # Check if we specified the latest version. Otherwise, print a warning. - latest_major_version, latest_minor_version = tuple(map(int, self.latest_version.split('.'))) - if latest_major_version > current_major_version: - print( - f'*****************************\n' - f'{self.dataset_name} has been updated to version {self.latest_version}.\n' - f'You are currently using version {self.version}.\n' - f'We highly recommend updating the dataset by not specifying the older version in the command-line argument or dataset constructor.\n' - f'See https://wilds.stanford.edu/changelog for changes.\n' - f'*****************************\n') - elif latest_minor_version > current_minor_version: - print( - f'*****************************\n' - f'{self.dataset_name} has been updated to version {self.latest_version}.\n' - f'You are currently using version {self.version}.\n' - f'Please consider updating the dataset.\n' - f'See https://wilds.stanford.edu/changelog for changes.\n' - f'*****************************\n') + # If the dataset exists at root_dir, then don't download. + if not self.dataset_exists_locally(data_dir, version_file): + self.download_dataset(data_dir, download) + return data_dir - # If the data_dir exists and contains the right RELEASE file, - # we assume the dataset is correctly set up - if os.path.exists(data_dir) and os.path.exists(version_file): - return data_dir - - # If the data_dir exists and does not contain the right RELEASE file, but it is not empty and the download_url is not set, - # we assume the dataset is correctly set up - if ((os.path.exists(data_dir)) and - (len(os.listdir(data_dir)) > 0) and - (download_url is None)): - return data_dir - - # Otherwise, we assume the dataset needs to be downloaded. - # If download == False, then return an error. 
- if download == False: - if download_url is None: - raise FileNotFoundError(f'The {self.dataset_name} dataset could not be found in {data_dir}. {self.dataset_name} cannot be automatically downloaded. Please download it manually.') - else: - raise FileNotFoundError(f'The {self.dataset_name} dataset could not be found in {data_dir}. Initialize the dataset with download=True to download the dataset. If you are using the example script, run with --download. This might take some time for large datasets.') + def dataset_exists_locally(self, data_dir, version_file): + download_url = self.versions_dict[self.version]['download_url'] + # There are two ways to download a dataset: + # 1. Automatically through the WILDS package + # 2. From a third party (e.g. OGB-MolPCBA is downloaded through the OGB package) + # Datasets downloaded from a third party need not have a download_url and RELEASE text file. + return ( + os.path.exists(data_dir) and ( + os.path.exists(version_file) or + (len(os.listdir(data_dir)) > 0 and download_url is None) + ) + ) - # Otherwise, proceed with downloading. + def download_dataset(self, data_dir, download_flag): + version_dict = self.versions_dict[self.version] + download_url = version_dict['download_url'] + compressed_size = version_dict['compressed_size'] + + # Check that download_url exists. if download_url is None: - raise ValueError(f'Sorry, {self.dataset_name} cannot be automatically downloaded. Please download it manually.') + raise ValueError(f'{self.dataset_name} cannot be automatically downloaded. Please download it manually.') + + # Check that the download_flag is set to true. + if not download_flag: + raise FileNotFoundError( + f'The {self.dataset_name} dataset could not be found in {data_dir}. Initialize the dataset with ' + f'download=True to download the dataset. If you are using the example script, run with --download. ' + f'This might take some time for large datasets.' + ) from wilds.datasets.download_utils import download_and_extract_archive print(f'Downloading dataset to {data_dir}...') print(f'You can also download the dataset manually at https://wilds.stanford.edu/downloads.') + try: start_time = time.time() download_and_extract_archive( @@ -368,14 +383,37 @@ def initialize_data_dir(self, root_dir, download): filename='archive.tar.gz', remove_finished=True, size=compressed_size) - download_time_in_minutes = (time.time() - start_time) / 60 - print(f"It took {round(download_time_in_minutes, 2)} minutes to download and uncompress the dataset.") + print(f"\nIt took {round(download_time_in_minutes, 2)} minutes to download and uncompress the dataset.\n") except Exception as e: print(f"\n{os.path.join(data_dir, 'archive.tar.gz')} may be corrupted. Please try deleting it and rerunning this command.\n") print(f"Exception: ", e) - return data_dir + def check_version(self): + # Check that the version is valid. + if self.version not in self.versions_dict: + raise ValueError(f'Version {self.version} not supported. Must be in {self.versions_dict.keys()}.') + + # Check that the specified version is the latest version. Otherwise, warn. 
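+        # For example, version '1.2' parses to (major, minor) = (1, 2); a newer
+        # major version triggers the strong warning below, and a newer minor
+        # version the softer one.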
+        current_major_version, current_minor_version = tuple(map(int, self.version.split('.')))
+        latest_major_version, latest_minor_version = tuple(map(int, self.latest_version.split('.')))
+        if latest_major_version > current_major_version:
+            print(
+                f'*****************************\n'
+                f'{self.dataset_name} has been updated to version {self.latest_version}.\n'
+                f'You are currently using version {self.version}.\n'
+                f'We highly recommend updating the dataset by not specifying an older version in the '
+                f'command-line argument or dataset constructor.\n'
+                f'See https://wilds.stanford.edu/changelog for changes.\n'
+                f'*****************************\n')
+        elif latest_minor_version > current_minor_version:
+            print(
+                f'*****************************\n'
+                f'{self.dataset_name} has been updated to version {self.latest_version}.\n'
+                f'You are currently using version {self.version}.\n'
+                f'Please consider updating the dataset.\n'
+                f'See https://wilds.stanford.edu/changelog for changes.\n'
+                f'*****************************\n')
 
     @staticmethod
     def standard_eval(metric, y_pred, y_true):
diff --git a/wilds/download_datasets.py b/wilds/download_datasets.py
index bf085739..d4dcd79e 100644
--- a/wilds/download_datasets.py
+++ b/wilds/download_datasets.py
@@ -12,6 +12,8 @@ def main():
                         help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')
     parser.add_argument('--datasets', nargs='*', default=None,
                         help=f'Specify a space-separated list of dataset names to download. If left unspecified, the script will download all of the official benchmark datasets. Available choices are {wilds.supported_datasets}.')
+    parser.add_argument('--unlabeled', default=False, action='store_true',
+                        help='If this flag is set, the unlabeled version of each dataset will be downloaded instead of the labeled one.')
     config = parser.parse_args()
 
     if config.datasets is None:
@@ -27,6 +29,7 @@ def main():
         wilds.get_dataset(
             dataset=dataset,
             root_dir=config.root_dir,
+            unlabeled=config.unlabeled,
             download=True)
 
diff --git a/wilds/get_dataset.py b/wilds/get_dataset.py
index a57e2f02..74f50356 100644
--- a/wilds/get_dataset.py
+++ b/wilds/get_dataset.py
@@ -1,12 +1,15 @@
+from typing import Optional
+
 import wilds
 
-def get_dataset(dataset, version=None, **dataset_kwargs):
+def get_dataset(dataset: str, version: Optional[str] = None, unlabeled: bool = False, **dataset_kwargs):
     """
     Returns the appropriate WILDS dataset class.
     Input:
         dataset (str): Name of the dataset
-        version (str): Dataset version number, e.g., '1.0'.
-            Defaults to the latest version.
+        version (Optional[str]): Dataset version number, e.g., '1.0'.
+            Defaults to the latest version.
+        unlabeled (bool): If True, use the unlabeled version of the dataset.
         dataset_kwargs: Other keyword arguments to pass to the dataset constructors.
     Output:
         The specified WILDSDataset class.
@@ -17,28 +20,55 @@ def get_dataset(dataset, version=None, **dataset_kwargs):
     if dataset not in wilds.supported_datasets:
         raise ValueError(f'The dataset {dataset} is not recognized. Must be one of {wilds.supported_datasets}.')
 
+    if unlabeled and dataset not in wilds.unlabeled_datasets:
+        raise ValueError(f'Unlabeled data is not available for {dataset}. 
Must be one of {wilds.unlabeled_datasets}.') + if dataset == 'amazon': - from wilds.datasets.amazon_dataset import AmazonDataset - return AmazonDataset(version=version, **dataset_kwargs) + if unlabeled: + from wilds.datasets.unlabeled.amazon_unlabeled_dataset import AmazonUnlabeledDataset + return AmazonUnlabeledDataset(version=version, **dataset_kwargs) + else: + from wilds.datasets.amazon_dataset import AmazonDataset + return AmazonDataset(version=version, **dataset_kwargs) elif dataset == 'camelyon17': - from wilds.datasets.camelyon17_dataset import Camelyon17Dataset - return Camelyon17Dataset(version=version, **dataset_kwargs) + if unlabeled: + from wilds.datasets.unlabeled.camelyon17_unlabeled_dataset import Camelyon17UnlabeledDataset + return Camelyon17UnlabeledDataset(version=version, **dataset_kwargs) + else: + from wilds.datasets.camelyon17_dataset import Camelyon17Dataset + return Camelyon17Dataset(version=version, **dataset_kwargs) elif dataset == 'celebA': from wilds.datasets.celebA_dataset import CelebADataset return CelebADataset(version=version, **dataset_kwargs) elif dataset == 'civilcomments': - from wilds.datasets.civilcomments_dataset import CivilCommentsDataset - return CivilCommentsDataset(version=version, **dataset_kwargs) + if unlabeled: + from wilds.datasets.unlabeled.civilcomments_unlabeled_dataset import CivilCommentsUnlabeledDataset + return CivilCommentsUnlabeledDataset(version=version, **dataset_kwargs) + else: + from wilds.datasets.civilcomments_dataset import CivilCommentsDataset + return CivilCommentsDataset(version=version, **dataset_kwargs) + + elif dataset == 'domainnet': + if unlabeled: + from wilds.datasets.unlabeled.domainnet_unlabeled_dataset import DomainNetUnlabeledDataset + return DomainNetUnlabeledDataset(version=version, **dataset_kwargs) + else: + from wilds.datasets.domainnet_dataset import DomainNetDataset + return DomainNetDataset(version=version, **dataset_kwargs) elif dataset == 'iwildcam': - if version == '1.0': - from wilds.datasets.archive.iwildcam_v1_0_dataset import IWildCamDataset + if unlabeled: + from wilds.datasets.unlabeled.iwildcam_unlabeled_dataset import IWildCamUnlabeledDataset + return IWildCamUnlabeledDataset(version=version, **dataset_kwargs) else: - from wilds.datasets.iwildcam_dataset import IWildCamDataset - return IWildCamDataset(version=version, **dataset_kwargs) + if version == '1.0': + from wilds.datasets.archive.iwildcam_v1_0_dataset import IWildCamDataset + else: + from wilds.datasets.iwildcam_dataset import IWildCamDataset # type:ignore + return IWildCamDataset(version=version, **dataset_kwargs) elif dataset == 'waterbirds': from wilds.datasets.waterbirds_dataset import WaterbirdsDataset @@ -49,22 +79,34 @@ def get_dataset(dataset, version=None, **dataset_kwargs): return YelpDataset(version=version, **dataset_kwargs) elif dataset == 'ogb-molpcba': - from wilds.datasets.ogbmolpcba_dataset import OGBPCBADataset - return OGBPCBADataset(version=version, **dataset_kwargs) + if unlabeled: + from wilds.datasets.unlabeled.ogbmolpcba_unlabeled_dataset import OGBPCBAUnlabeledDataset + return OGBPCBAUnlabeledDataset(version=version, **dataset_kwargs) + else: + from wilds.datasets.ogbmolpcba_dataset import OGBPCBADataset + return OGBPCBADataset(version=version, **dataset_kwargs) elif dataset == 'poverty': - if version == '1.0': - from wilds.datasets.archive.poverty_v1_0_dataset import PovertyMapDataset + if unlabeled: + from wilds.datasets.unlabeled.poverty_unlabeled_dataset import PovertyMapUnlabeledDataset + return 
PovertyMapUnlabeledDataset(version=version, **dataset_kwargs) else: - from wilds.datasets.poverty_dataset import PovertyMapDataset - return PovertyMapDataset(version=version, **dataset_kwargs) + if version == '1.0': + from wilds.datasets.archive.poverty_v1_0_dataset import PovertyMapDataset + else: + from wilds.datasets.poverty_dataset import PovertyMapDataset # type:ignore + return PovertyMapDataset(version=version, **dataset_kwargs) elif dataset == 'fmow': - if version == '1.0': - from wilds.datasets.archive.fmow_v1_0_dataset import FMoWDataset + if unlabeled: + from wilds.datasets.unlabeled.fmow_unlabeled_dataset import FMoWUnlabeledDataset + return FMoWUnlabeledDataset(version=version, **dataset_kwargs) else: - from wilds.datasets.fmow_dataset import FMoWDataset - return FMoWDataset(version=version, **dataset_kwargs) + if version == '1.0': + from wilds.datasets.archive.fmow_v1_0_dataset import FMoWDataset + else: + from wilds.datasets.fmow_dataset import FMoWDataset # type:ignore + return FMoWDataset(version=version, **dataset_kwargs) elif dataset == 'bdd100k': from wilds.datasets.bdd100k_dataset import BDD100KDataset @@ -77,6 +119,14 @@ def get_dataset(dataset, version=None, **dataset_kwargs): elif dataset == 'sqf': from wilds.datasets.sqf_dataset import SQFDataset return SQFDataset(version=version, **dataset_kwargs) + + elif dataset == 'globalwheat': + if unlabeled: + from wilds.datasets.unlabeled.globalwheat_unlabeled_dataset import GlobalWheatUnlabeledDataset + return GlobalWheatUnlabeledDataset(version=version, **dataset_kwargs) + else: + from wilds.datasets.globalwheat_dataset import GlobalWheatDataset # type:ignore + return GlobalWheatDataset(version=version, **dataset_kwargs) elif dataset == 'encode': from wilds.datasets.encode_dataset import EncodeDataset diff --git a/wilds/version.py b/wilds/version.py index 531433b0..42e216c5 100644 --- a/wilds/version.py +++ b/wilds/version.py @@ -4,7 +4,7 @@ import logging from threading import Thread -__version__ = '1.2.2' +__version__ = '2.0.0' try: os.environ['OUTDATED_IGNORE'] = '1'
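Taken together, the changes above thread a single `unlabeled` flag from the command line (`download_datasets.py`) through `get_dataset` to the new `wilds.datasets.unlabeled.*` classes. The snippet below is a minimal, illustrative sketch of how these code paths might be exercised; it is not part of this diff, and it assumes a local `data/` root and the `train_unlabeled` split name used by the unlabeled Camelyon17 dataset.

```python
import torchvision.transforms as transforms

from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader

# Camelyon17 patches are 96x96; convert to tensors for batching.
transform = transforms.Compose([transforms.Resize((96, 96)), transforms.ToTensor()])

# Labeled data: same call signature as before this change.
labeled = get_dataset(dataset='camelyon17', root_dir='data', download=True)
labeled_train = labeled.get_subset('train', transform=transform)

# Unlabeled data: routed to Camelyon17UnlabeledDataset by unlabeled=True.
unlabeled = get_dataset(dataset='camelyon17', unlabeled=True, root_dir='data', download=True)
unlabeled_train = unlabeled.get_subset('train_unlabeled', transform=transform)

# Standard (non-grouped) train loaders over both subsets.
labeled_loader = get_train_loader('standard', labeled_train, batch_size=16)
unlabeled_loader = get_train_loader('standard', unlabeled_train, batch_size=16)
```

From the shell, the equivalent unlabeled download is `python wilds/download_datasets.py --root_dir data --unlabeled`.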