PyTorch implementation of Anchor & Transform: Learning Sparse Embeddings for Large Vocabularies
Correspondence to:
- Paul Liang (pliang@cs.cmu.edu)
- Manzil Zaheer (manzilzaheer@google.com)
Anchor & Transform: Learning Sparse Embeddings for Large Vocabularies
Paul Pu Liang, Manzil Zaheer, Yuan Wang, Amr Ahmed
ICLR 2021
If you find this repository useful, please cite our paper:
@inproceedings{liang2021anchor,
author = {Paul Pu Liang and
Manzil Zaheer and
Yuan Wang and
Amr Ahmed},
title = {Anchor & Transform: Learning Sparse Embeddings for Large Vocabularies},
booktitle = {9th International Conference on Learning Representations, {ICLR} 2021},
publisher = {OpenReview.net},
year = {2021},
url = {https://openreview.net/forum?id=Vd7lCMvtLqg}
}
First check that the requirements are satisfied:
Python 3.6
torch 1.2.0
numpy 1.18.1
matplotlib 3.1.2
tqdm 4.45.0
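To compare an existing environment against the versions listed above, a small helper like the one below can be used. It is a convenience sketch, not part of the repository; the names REQUIRED and check_requirements are hypothetical. It compares each package's __version__ string exactly, so local builds such as 1.2.0+cu92 are reported as mismatches and should be treated as informational.

```python
import importlib

# Versions listed in this README, treated here as the tested configuration.
REQUIRED = {
    "torch": "1.2.0",
    "numpy": "1.18.1",
    "matplotlib": "3.1.2",
    "tqdm": "4.45.0",
}


def check_requirements(required):
    """Return {package: (installed_version_or_None, required_version)} for
    every package whose installed version string differs from the required one."""
    mismatches = {}
    for name, wanted in required.items():
        try:
            module = importlib.import_module(name)
        except ImportError:
            mismatches[name] = (None, wanted)  # package is not installed
            continue
        found = getattr(module, "__version__", None)
        if found != wanted:  # exact string comparison
            mismatches[name] = (found, wanted)
    return mismatches
```

For example, check_requirements(REQUIRED) returns an empty dict when every listed package is installed at exactly the listed version.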
The next step is to clone the repository:
git clone https://github.com/pliang279/sparse_discrete.git
Download the MovieLens 25M data from http://files.grouplens.org/datasets/movielens/ml-25m.zip and unzip it into a folder ml-25m/.
Download the MovieLens 1M data from http://files.grouplens.org/datasets/movielens/ml-1m.zip and unzip it into a folder ml-1m/.
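The two download steps can also be scripted. The sketch below uses only the Python standard library; the helper names are hypothetical, and it assumes each GroupLens archive contains a top-level ml-25m/ or ml-1m/ folder, so extracting into the current directory produces the folder layout expected here.

```python
import os
import urllib.request
import zipfile

# URLs copied from the instructions above.
MOVIELENS_URLS = {
    "ml-25m": "http://files.grouplens.org/datasets/movielens/ml-25m.zip",
    "ml-1m": "http://files.grouplens.org/datasets/movielens/ml-1m.zip",
}


def extract_zip(zip_path, dest_dir="."):
    """Unzip zip_path into dest_dir and return its sorted top-level entries."""
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest_dir)
        return sorted({name.split("/")[0] for name in zf.namelist()})


def download_movielens(name, dest_dir="."):
    """Download one MovieLens archive (skipped if the zip already exists) and unzip it."""
    url = MOVIELENS_URLS[name]
    zip_path = os.path.join(dest_dir, name + ".zip")
    if not os.path.exists(zip_path):
        urllib.request.urlretrieve(url, zip_path)
    return extract_zip(zip_path, dest_dir)
```

For example, download_movielens("ml-1m") fetches ml-1m.zip unless it is already present and unpacks it into ml-1m/.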
Run python3 movielens_data.py, which extracts the .dat files in ml-1m/ and generates ml-1m/ml1m_ratings.csv.
At this point, make sure you have the files ml-25m/ratings.csv and ml-1m/ml1m_ratings.csv.
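MovieLens 1M stores its ratings in ml-1m/ratings.dat as UserID::MovieID::Rating::Timestamp lines. The sketch below shows one way such a .dat-to-.csv conversion can be done; it is not the code in movielens_data.py, the function names are hypothetical, and the header row (copied from ml-25m/ratings.csv) is an assumption about the expected format.

```python
import csv

# Header mirroring ml-25m/ratings.csv (an assumption about what the training
# code expects, not taken from movielens_data.py).
HEADER = ["userId", "movieId", "rating", "timestamp"]


def dat_to_csv_rows(lines):
    """Yield [userId, movieId, rating, timestamp] string lists from ml-1m .dat lines."""
    for line in lines:
        line = line.strip()
        if line:  # skip blank lines
            yield line.split("::")


def convert_ratings(dat_path="ml-1m/ratings.dat", csv_path="ml-1m/ml1m_ratings.csv"):
    """Rewrite the '::'-separated ratings file as a comma-separated file with a header."""
    with open(dat_path, encoding="latin-1") as src, \
            open(csv_path, "w", newline="") as dst:
        writer = csv.writer(dst)
        writer.writerow(HEADER)
        writer.writerows(dat_to_csv_rows(src))
```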
Download the Amazon data from http://deepyeti.ucsd.edu/jianmo/amazon/categoryFilesSmall/all_csv_files.csv into a folder called amazon_data/.
Run python3 movielens_data.py, which parses the .csv files in amazon_data/ and generates the file amazon_data/saved_amazon_data_filtered5.h5.
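The file name saved_amazon_data_filtered5.h5 suggests, though this README does not say so, that users and items with fewer than 5 ratings are filtered out. A common way to do this is k-core filtering, sketched below in plain Python on (user, item, rating) triples. The function name is hypothetical, and the repository's preprocessing may differ, for instance by doing a single filtering pass instead of iterating until nothing more is removed.

```python
from collections import Counter


def k_core_filter(ratings, k=5):
    """Repeatedly drop (user, item, rating) triples whose user or item has
    fewer than k ratings, until every remaining user and item has at least k."""
    ratings = list(ratings)
    while True:
        user_counts = Counter(u for u, _, _ in ratings)
        item_counts = Counter(i for _, i, _ in ratings)
        kept = [
            (u, i, r) for u, i, r in ratings
            if user_counts[u] >= k and item_counts[i] >= k
        ]
        if len(kept) == len(ratings):  # nothing removed: fixed point reached
            return kept
        ratings = kept  # removals can push other users/items below k, so repeat
```

Iterating matters because dropping a sparse user can push an item below the threshold, which in turn can affect other users.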
To train on MovieLens 25M:
MF baseline: python3 movielens.py --model_path MF --latent_dim 16 --dataset 25m
MixDim embeddings: python3 movielens.py --model_path mdMF --base_dim 16 --temperature 0.4 --k 8 --dataset 25m
ANT: python3 movielens.py --model_path sparseMF --latent_dim 16 --user_anchors 10 --item_anchors 15 --lda2s 2e-6 --lda2e 2e-6 --dataset 25m
NBANT: python3 movielens.py --model_path sparseMF --latent_dim 16 --lda1 0.1 --lda2s 2e-6 --lda2e 2e-6 --dataset 25m --dynamic
To train on Amazon:
MF baseline: python3 amazon.py --model_path MF --latent_dim 16 --dataset amazon
ANT: python3 amazon.py --model_path sparseMF --latent_dim 16 --user_anchors 8 --item_anchors 8 --lda2s 1e-7 --lda2e 1e-7 --dataset amazon
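As the paper's title suggests, ANT learns a small set of anchor embeddings and a sparse transformation, so that each user or item embedding is a sparse linear combination of the anchors. The NumPy sketch below illustrates that idea only; it is not the repository's sparseMF model. It assumes E = T A with anchors A (k x d) and a non-negative transform T (|V| x k) kept sparse by a proximal L1 step. --user_anchors and --item_anchors presumably set k for the two tables, but how --lda1, --lda2s and --lda2e map onto penalty weights is not documented here, so lam below is a generic stand-in.

```python
import numpy as np


def ant_embeddings(T, A):
    """Embedding table E = T @ A: row v of E combines the k anchor vectors in
    A (k x d) using the weights in row v of the transform T (|V| x k)."""
    return T @ A


def prox_nonneg_l1(T, step, lam):
    """Proximal step for lam * ||T||_1 under T >= 0: shrink every entry by
    step * lam and clip at zero, which produces exact zeros."""
    return np.maximum(T - step * lam, 0.0)


def num_stored_params(T, A):
    """Parameters needed when T is stored sparsely: nonzeros of T plus dense anchors."""
    return int(np.count_nonzero(T)) + A.size
```

The clipping step is what makes the storage savings real: small weights become exactly zero rather than merely small, so num_stored_params can fall well below the |V| x d entries of a dense embedding table.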