Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
picaultj committed Dec 13, 2024
1 parent 76ddb6f commit 3f1c3c3
Show file tree
Hide file tree
Showing 10 changed files with 276 additions and 147 deletions.
27 changes: 27 additions & 0 deletions .github/workflows/pytest.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Runs the test suite with coverage on every pull request targeting main.
name: pytest on Pull Request

on:
  pull_request:
    branches: [main]

jobs:
  build:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.11' # Adjust Python version as needed

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip poetry
          poetry install --extras tests

      # Dependencies are installed by Poetry into its own virtualenv, so the
      # tools must be invoked through `poetry run` to be found.
      - name: Run pytests
        run: |
          poetry run coverage run -m pytest
          poetry run coverage report
3 changes: 0 additions & 3 deletions bertrend/bertrend.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@
)
from bertrend.weak_signals.topic_modeling import preprocess_model, merge_models
from bertrend.weak_signals.weak_signals import (
initialize_new_topic,
update_existing_topic,
apply_decay_to_inactive_topics,
calculate_signal_popularity,
)

Expand Down
6 changes: 1 addition & 5 deletions bertrend/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
"système électrique",
]

stopwords_fr_file = Path(__file__).parent / "stopwords-fr.json"
stopwords_fr_file = Path(__file__).parent / "resources" / "stopwords-fr.json"
with open(stopwords_fr_file, "r", encoding="utf-8") as file:
FRENCH_STOPWORDS = json.load(file)

Expand Down Expand Up @@ -122,10 +122,6 @@
DEFAULT_WINDOW_SIZE = 7 # days
MAX_WINDOW_SIZE = 365 # days

# UI Settings
PAGE_TITLE = "BERTopic Topic Detection"
LAYOUT = "wide"

# Visualization Settings
SANKEY_NODE_PAD = 15
SANKEY_NODE_THICKNESS = 20
Expand Down
File renamed without changes.
178 changes: 178 additions & 0 deletions bertrend/tests/test_topic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# Copyright (c) 2024, RTE (https://www.rte-france.com)
# See AUTHORS.txt
# SPDX-License-Identifier: MPL-2.0
# This file is part of BERTrend.
import pytest
import numpy as np

from unittest.mock import MagicMock

from bertopic import BERTopic

from bertrend.bertrend import TopicModel
from bertrend.parameters import (
DEFAULT_UMAP_N_COMPONENTS,
DEFAULT_UMAP_N_NEIGHBORS,
DEFAULT_HDBSCAN_MIN_CLUSTER_SIZE,
DEFAULT_HDBSCAN_MIN_SAMPLES,
HDBSCAN_CLUSTER_SELECTION_METHODS,
VECTORIZER_NGRAM_RANGES,
DEFAULT_MIN_DF,
DEFAULT_TOP_N_WORDS,
LANGUAGES,
)


# Mocking dependencies
@pytest.fixture
def mock_sentence_transformer():
    """Provide a ``MagicMock`` used in place of a SentenceTransformer."""
    transformer_double = MagicMock()
    return transformer_double


@pytest.fixture
def mock_embedding():
    """Provide a ``MagicMock`` used in place of document embeddings."""
    embedding_double = MagicMock()
    return embedding_double


@pytest.fixture
def topic_model():
    """Provide a TopicModel constructed without any arguments."""
    default_model = TopicModel()
    return default_model


def test_topic_model_initialization_default_values(topic_model):
    """A TopicModel built without arguments exposes the default parameters."""
    # (attribute name, expected default) pairs, checked in order
    expected_defaults = (
        ("umap_n_components", DEFAULT_UMAP_N_COMPONENTS),
        ("umap_n_neighbors", DEFAULT_UMAP_N_NEIGHBORS),
        ("hdbscan_min_cluster_size", DEFAULT_HDBSCAN_MIN_CLUSTER_SIZE),
        ("hdbscan_min_samples", DEFAULT_HDBSCAN_MIN_SAMPLES),
        (
            "hdbscan_cluster_selection_method",
            HDBSCAN_CLUSTER_SELECTION_METHODS[0],
        ),
        ("vectorizer_ngram_range", VECTORIZER_NGRAM_RANGES[0]),
        ("min_df", DEFAULT_MIN_DF),
        ("top_n_words", DEFAULT_TOP_N_WORDS),
        ("language", LANGUAGES[0]),
    )

    for attribute, expected in expected_defaults:
        assert getattr(topic_model, attribute) == expected


def test_topic_model_initialization_custom_values():
    """Keyword arguments given to TopicModel are stored under the same names."""
    custom_params = dict(
        umap_n_components=15,
        umap_n_neighbors=20,
        hdbscan_min_cluster_size=10,
        hdbscan_min_samples=5,
        hdbscan_cluster_selection_method="eom",
        vectorizer_ngram_range=(1, 2),
        min_df=2,
        top_n_words=50,
        language="French",
    )

    topic_model = TopicModel(**custom_params)

    # Each constructor keyword must be readable back as an attribute
    for attribute, expected in custom_params.items():
        assert getattr(topic_model, attribute) == expected


def test_initialize_models_called(topic_model):
    """A freshly built TopicModel carries all of its sub-model attributes."""
    required_attributes = (
        "umap_model",
        "hdbscan_model",
        "vectorizer_model",
        "mmr_model",
        "ctfidf_model",
    )
    for attribute in required_attributes:
        assert hasattr(topic_model, attribute)


def test_create_topic_model_with_valid_input(
    topic_model, mock_sentence_transformer, mock_embedding
):
    """Check that a stubbed create_topic_model is called once and its value returned.

    NOTE(review): ``create_topic_model`` is replaced by a ``MagicMock`` before
    being called, so the real implementation is never executed by this test.
    """
    call_args = (
        ["Document 1", "Document 2"],  # docs
        mock_sentence_transformer,
        mock_embedding,  # embeddings
        ["Topic 1", "Topic 2"],  # zeroshot_topic_list
        0.7,  # zeroshot_min_similarity
    )

    # Stand-in BERTopic object with canned return values
    bertopic_double = MagicMock(spec=BERTopic)
    bertopic_double.fit_transform.return_value = ([], [])
    bertopic_double.reduce_outliers.return_value = []

    topic_model.create_topic_model = MagicMock(return_value=bertopic_double)

    outcome = topic_model.create_topic_model(*call_args)

    topic_model.create_topic_model.assert_called_once_with(*call_args)
    assert outcome == bertopic_double


def test_create_topic_model_with_empty_zeroshot_topic_list(
    topic_model, mock_sentence_transformer, mock_embedding
):
    """Test create_topic_model with an empty zeroshot_topic_list.

    Calls the real ``create_topic_model`` with 200 short documents, no
    embedding model and a precomputed embedding matrix, then checks that the
    returned model reports ``zeroshot_topic_list`` as ``None``.
    """
    docs = ["Document 1", "Document 2"] * 100
    zeroshot_topic_list = []
    zeroshot_min_similarity = 0.7

    # Seeded generator so the test input is identical on every run
    rng = np.random.default_rng(0)
    embeddings = rng.random((len(docs), 768))

    result = topic_model.create_topic_model(
        docs,
        None,  # no embedding model is passed; embeddings are given directly
        embeddings,
        zeroshot_topic_list,
        zeroshot_min_similarity,
    )

    assert result is not None
    # The empty list is expected to have been replaced by None internally
    assert result.zeroshot_topic_list is None


def test_create_topic_model_exception_handling(
    topic_model, mock_sentence_transformer, mock_embedding
):
    """Check that an exception raised by a stubbed create_topic_model propagates.

    NOTE(review): the method is replaced by a ``MagicMock`` with a
    ``side_effect``, so the real implementation's error handling is not
    exercised by this test.
    """
    # Stub that raises as soon as it is called
    failing_stub = MagicMock(side_effect=Exception("Test Exception"))
    topic_model.create_topic_model = failing_stub

    call_args = (
        ["Document 1", "Document 2"],  # docs
        mock_sentence_transformer,
        mock_embedding,  # embeddings
        ["Topic 1"],  # zeroshot_topic_list
        0.7,  # zeroshot_min_similarity
    )

    with pytest.raises(Exception, match="Test Exception"):
        topic_model.create_topic_model(*call_args)
57 changes: 4 additions & 53 deletions bertrend/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
# SPDX-License-Identifier: MPL-2.0
# This file is part of BERTrend.

import json
import os
from pathlib import Path
from typing import List, Tuple, Union

import numpy as np
Expand All @@ -25,6 +23,7 @@
from umap import UMAP

from bertrend import BASE_CACHE_PATH
from bertrend.parameters import STOPWORDS
from bertrend.common.openai_client import OpenAI_Client
from bertrend.common.prompts import FRENCH_TOPIC_REPRESENTATION_PROMPT
from bertrend.utils import (
Expand All @@ -34,7 +33,6 @@

# Parameters:
DEFAULT_EMBEDDING_MODEL_NAME = "paraphrase-multilingual-MiniLM-L12-v2"
DEFAULT_TOP_N_WORDS = 10
DEFAULT_NR_TOPICS = 10
DEFAULT_NGRAM_RANGE = (1, 1)
DEFAULT_MIN_DF = 2
Expand All @@ -47,55 +45,8 @@
prediction_data=True,
)

STOP_WORDS_RTE = [
"w",
"kw",
"mw",
"gw",
"tw",
"wh",
"kwh",
"mwh",
"gwh",
"twh",
"volt",
"volts",
"000",
]
COMMON_NGRAMS = [
"éléctricité",
"RTE",
"France",
"électrique",
"projet",
"année",
"transport électricité",
"réseau électrique",
"gestionnaire réseau",
"réseau transport",
"production électricité",
"milliards euros",
"euros",
"2022",
"2023",
"2024",
"électricité RTE",
"Réseau transport",
"RTE gestionnaire",
"électricité France",
"système électrique",
]

# Define the path to extended list of french stopwords JSON file
stopwords_fr_file = Path(__file__).parent / "weak_signals" / "stopwords-fr.json"

# Read the JSON data from the file and directly assign it to the list
with open(stopwords_fr_file, "r", encoding="utf-8") as file:
FRENCH_STOPWORDS = json.load(file)

DEFAULT_STOP_WORDS = FRENCH_STOPWORDS + STOP_WORDS_RTE + COMMON_NGRAMS
DEFAULT_VECTORIZER_MODEL = CountVectorizer(
stop_words=DEFAULT_STOP_WORDS,
stop_words=STOPWORDS,
ngram_range=DEFAULT_NGRAM_RANGE,
min_df=DEFAULT_MIN_DF,
)
Expand Down Expand Up @@ -273,7 +224,7 @@ def train_BERTopic(
vectorizer_model: CountVectorizer = DEFAULT_VECTORIZER_MODEL,
ctfidf_model: ClassTfidfTransformer = DEFAULT_CTFIDF_MODEL,
representation_model: List[RepresentationModelType] = DEFAULT_REPRESENTATION_MODEL,
top_n_words: int = DEFAULT_TOP_N_WORDS,
top_n_words: int = STOPWORDS,
nr_topics: Union[str, int] = DEFAULT_NR_TOPICS,
use_cache: bool = True,
cache_base_name: str = None,
Expand Down Expand Up @@ -350,7 +301,7 @@ def train_BERTopic(
stop_words = (
stopwords.words("english")
if form_parameters["countvectorizer_stop_words"] == "english"
else DEFAULT_STOP_WORDS
else STOPWORDS
)
vectorizer_model = CountVectorizer(
stop_words=stop_words,
Expand Down
Loading

0 comments on commit 3f1c3c3

Please sign in to comment.