A PyTorch implementation of a language model using LSTM (Long Short-Term Memory) networks to generate political news headlines. The model is trained on the Hugging Face `heegyu/news-category-dataset`, focusing on headlines from the POLITICS category.
- LSTM-based language model with an embedding layer
- Two training implementations:
  - Standard training
  - Truncated Backpropagation Through Time (TBPTT) (see the sketch after this list)
- Two text generation strategies:
  - Random sampling with top-k
  - Greedy (argmax) sampling
- Performance visualization with loss and perplexity plots
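For readers who want to see what truncating the backpropagation looks like in code, below is a minimal sketch of a TBPTT training step. It is illustrative only: the function name, the chunk length, and the assumption that the model returns `(logits, hidden)` may differ from the repo's actual `train_truncated()`.

```python
import torch

def tbptt_step_sketch(model, batch, criterion, optimizer, chunk_len=10, clip=1.0):
    """One TBPTT pass over a batch of token ids shaped (batch, seq_len).

    The sequence is split into chunks of `chunk_len` tokens; gradients are
    backpropagated only within each chunk, and the hidden state is detached
    between chunks so the graph never spans the whole sequence.
    """
    hidden = None
    total_loss = 0.0
    inputs, targets = batch[:, :-1], batch[:, 1:]    # predict the next token

    for start in range(0, inputs.size(1), chunk_len):
        x = inputs[:, start:start + chunk_len]
        y = targets[:, start:start + chunk_len]

        optimizer.zero_grad()
        logits, hidden = model(x, hidden)            # logits: (batch, chunk, vocab)
        loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        # Detach so the next chunk starts from the current state but gradients
        # do not flow back through previous chunks.
        hidden = tuple(h.detach() for h in hidden)
        total_loss += loss.item()

    return total_loss
```

Detaching the hidden state is what makes the backpropagation "truncated": state still flows forward across the whole sequence, but gradients only cover one chunk at a time.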
```
torch
datasets
matplotlib
numpy
```
The project uses the `heegyu/news-category-dataset` from Hugging Face, filtering for political headlines. The data processing pipeline includes:
- Lowercase conversion
- Basic tokenization
- Addition of `<EOS>` tokens
- Creation of word-to-index and index-to-word mappings
- Padding for batch processing
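A rough sketch of this pipeline is shown below; the dataset field names (`headline`, `category`) and the `<PAD>` token are assumptions about how the repo's `Dataset` class works rather than a copy of it.

```python
from datasets import load_dataset

# Load the dataset and keep only POLITICS headlines (field names are assumed).
ds = load_dataset("heegyu/news-category-dataset", split="train")
headlines = [row["headline"].lower().split()              # lowercase + basic tokenization
             for row in ds if row["category"] == "POLITICS"]
headlines = [tokens + ["<EOS>"] for tokens in headlines]  # append end-of-sequence token

# Word-to-index and index-to-word mappings, reserving index 0 for padding.
vocab = {"<PAD>": 0}
for tokens in headlines:
    for tok in tokens:
        vocab.setdefault(tok, len(vocab))
idx_to_word = {i: w for w, i in vocab.items()}

# Pad every headline to the longest length so batches can be stacked as tensors.
max_len = max(len(t) for t in headlines)
encoded = [[vocab[tok] for tok in tokens] + [0] * (max_len - len(tokens))
           for tokens in headlines]
```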
The LSTM model includes:
- Embedding layer (configurable dimension)
- LSTM layer(s) (configurable number of layers and hidden size)
- Fully connected output layer
- Dropout for regularization
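A minimal sketch of such a model is given below; the constructor signature and the dropout rate are illustrative, not the repo's exact `Model` class.

```python
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, vocab_size, embed_dim=150, hidden_size=1024,
                 num_layers=1, dropout=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)   # token embeddings
        self.lstm = nn.LSTM(embed_dim, hidden_size, num_layers,
                            batch_first=True)                  # recurrent core
        self.dropout = nn.Dropout(dropout)                     # regularization
        self.fc = nn.Linear(hidden_size, vocab_size)           # per-step logits

    def forward(self, x, hidden=None):
        emb = self.embedding(x)                  # (batch, seq, embed_dim)
        out, hidden = self.lstm(emb, hidden)     # (batch, seq, hidden_size)
        logits = self.fc(self.dropout(out))      # (batch, seq, vocab_size)
        return logits, hidden
```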
- Clone the repository:

```bash
git clone https://github.com/steq28/lstm-news-generator
cd lstm-news-generator
```

- Install dependencies:

```bash
pip install -r requirements.txt
```

- Run the training:

```bash
python lstm-model.py
```
Default hyperparameters:
- Hidden size: 1024 (standard) / 2048 (truncated)
- Embedding dimension: 150
- Number of LSTM layers: 1
- Learning rate: 0.001
- Batch size: 32
- Training epochs: 6
- Gradient clipping: 1.0
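For illustration, this is roughly how those defaults could be wired together, reusing the `Model` and `vocab` names from the sketches above. The choice of Adam and the `ignore_index` for padding are assumptions, not something the defaults list specifies.

```python
import torch
import torch.nn as nn

model = Model(vocab_size=len(vocab), embed_dim=150, hidden_size=1024, num_layers=1)
criterion = nn.CrossEntropyLoss(ignore_index=vocab["<PAD>"])   # skip padded positions (assumed)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)     # optimizer choice assumed

# During training, gradients are clipped before each optimizer step:
# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```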
The model achieves:
- Loss < 1.5 by epoch 6 in standard training
- Loss < 0.9 by epoch 6 in truncated training
- Coherent political headlines from both sampling strategies
Example outputs:
Random sampling:
- "the president wants a letter to foreign and justice <EOS>"
- "the president wants to help the other <EOS>"
- "the president wants a money advantage in american politics <EOS>"
Greedy sampling:
- "the president wants to help the koch brothers <EOS>"
The training process generates four plots:
- Training loss (standard training)
- Perplexity (standard training)
- Training loss (truncated training)
- Perplexity (truncated training)
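Perplexity is simply the exponential of the mean cross-entropy loss, so the loss and perplexity plots show the same quantity on different scales. A small illustrative snippet (the loss values are placeholders, not the repo's actual results):

```python
import math
import matplotlib.pyplot as plt

epoch_losses = [3.2, 2.4, 1.9, 1.7, 1.6, 1.5]           # placeholder per-epoch losses
perplexities = [math.exp(loss) for loss in epoch_losses]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.plot(range(1, len(epoch_losses) + 1), epoch_losses)
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")
ax1.set_title("Training loss")
ax2.plot(range(1, len(perplexities) + 1), perplexities)
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Perplexity")
ax2.set_title("Perplexity")
plt.tight_layout()
plt.savefig("training_curves.png")
```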
- `Dataset` class: Handles data preprocessing and batching
- `Model` class: Implements the LSTM architecture
- Training functions:
  - `train()`: Standard training implementation
  - `train_truncated()`: TBPTT implementation
- Generation functions:
  - `random_sample_next()`: Top-k random sampling
  - `sample_argmax()`: Greedy sampling
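Below are illustrative sketches of the two sampling strategies plus a small generation loop. The signatures and the `generate()` helper are assumptions; the script's `random_sample_next()` and `sample_argmax()` may differ in detail.

```python
import torch

def random_sample_next(logits, k=5):
    """Sample the next token id from the top-k most likely tokens."""
    top_logits, top_idx = torch.topk(logits, k)        # keep the k best candidates
    probs = torch.softmax(top_logits, dim=-1)          # renormalize over those k
    choice = torch.multinomial(probs, num_samples=1)   # random draw among them
    return top_idx[choice].item()

def sample_argmax(logits):
    """Greedy decoding: always pick the single most likely next token."""
    return torch.argmax(logits).item()

def generate(model, vocab, idx_to_word, prompt="the president wants",
             max_len=20, sampler=random_sample_next):
    """Extend a prompt token by token until <EOS> or max_len (prompt words must be in vocab)."""
    model.eval()
    tokens = [vocab[w] for w in prompt.lower().split()]
    with torch.no_grad():
        for _ in range(max_len):
            x = torch.tensor([tokens])                 # (1, current_len)
            logits, _ = model(x)                       # re-run on the full prefix
            next_id = sampler(logits[0, -1])           # distribution over the next word
            tokens.append(next_id)
            if idx_to_word[next_id] == "<EOS>":
                break
    return " ".join(idx_to_word[i] for i in tokens)
```

Greedy decoding always produces the same headline for a given prompt, while top-k sampling yields a different continuation each run, which matches the example outputs above.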
Unlike Word2Vec, this model uses contextual embeddings, meaning the vector representations depend on the surrounding context and don't maintain static arithmetic properties (e.g., King - Man + Woman ≠ Queen).
Stefano Quaggio (stefano.quaggio@usi.ch)