This repository hosts a project for breast cancer classification using machine learning models with high accuracy and interpretability. The project features data preprocessing, model training and selection, interpretability analysis, deployment as an interactive web application, and integration with a PostgreSQL database to store predictions and track data drift for retraining.
- Objective
- Project Overview
- Workflow Diagram
- Repository Structure
- Methodology
- Graphical User Interface (GUI)
- Application Deployment
- How to Run the Project Locally
- Screenshots of the Application
To develop a machine learning tool to classify breast tumors as benign or malignant, targeting an F1-score greater than 0.95. The tool will incorporate interpretability techniques for actionable insights, implement a database for storing predictions, and include a monthly data drift detection mechanism to trigger model retraining when drift is detected.
Breast cancer is among the most diagnosed cancers worldwide, and early detection is critical for improving patient outcomes. This project emphasizes:
- High-performing models such as Gradient Boosting, LightGBM, and XGBoost, alongside interpretable models like Logistic Regression.
- Interpretability analysis for gaining insights into model predictions.
- Deployment to Streamlit Cloud and Docker for accessibility.
- A PostgreSQL database hosted on Neon to store predictions made by the app.
- A monthly data drift detection mechanism that triggers model retraining if drift is detected.
The project follows a structured workflow from data exploration to model deployment, including interpretability analysis and data drift detection for model retraining. The diagram below illustrates the key steps in the project pipeline:
.github/workflows/                   # GitHub Actions workflow for data drift and model retraining
.streamlit/                          # Streamlit configuration
    config.toml                      # Streamlit configuration file
data/
    processed/                       # Processed dataset after data exploration and feature engineering
    raw/                             # Original dataset
deployment/
    models/                          # Models for the deployed app
    app.py                           # Streamlit app for the deployed models
    Dockerfile                       # Docker configuration
    requirements_app.txt             # Dependencies for Docker
models/                              # Models generated during experimentation
notebooks/                           # Jupyter notebooks
    01_data_exploration.ipynb        # Data exploration and visualization
    02_feature_engineering.ipynb     # Feature engineering and selection
    03_modeling.ipynb                # Model training and evaluation
    04_models_interpretability.ipynb # Interpretability analysis (including SHAP and LIME)
    05_error_analysis.ipynb          # Error analysis
visuals/
    app_screenshots/                 # Screenshots of the deployed app
    figures/                         # Generated visualizations
    metrics/                         # Metrics visualizations
py_scripts/                          # Python scripts
    data_drift_monitoring.py         # Data drift detection and model retraining
    query_db.py                      # Query the PostgreSQL database
full_requirements.txt                # Comprehensive list of all project dependencies
requirements.txt                     # Simplified dependencies for Streamlit Cloud & GitHub Actions
The methodology for this project follows the CRISP-DM (Cross-Industry Standard Process for Data Mining) framework, which includes the following phases:
- Business Understanding: The goal of this project is to develop a model that can accurately predict whether a breast tumor is benign or malignant. The model is intended for clinical use, where interpretability and reliability are paramount.
- Data Understanding: Data exploration was performed to understand the characteristics of the dataset, identify any missing values, and explore the distribution of features across the two classes.
- Data Preparation: The dataset was cleaned, missing values were handled, features were scaled, and feature selection techniques were applied to improve model performance.
- Modeling: Multiple machine learning models were trained and evaluated, including Logistic Regression, Gradient Boosting, XGBoost, and LightGBM, with an emphasis on both performance and interpretability.
- Evaluation: Model performance was evaluated using metrics such as the F1-score, and interpretability analysis was performed using SHAP and LIME. Models were selected for deployment based on their performance and interpretability.
- Deployment: The selected models were deployed via a Streamlit app on Streamlit Cloud, with a Docker image also pushed to Docker Hub. Predictions made by the app are stored in a PostgreSQL database for monitoring and retraining, and a data drift detection workflow ensures that the models remain accurate over time.
Checked basic statistics and null values in the dataset, along with:
- Diagnosis Distribution: Explored the distribution of benign vs. malignant cases.
- Boxplots grouped by 'diagnosis': Visualized feature distributions for benign and malignant cases.
- Correlation Analysis: Visualized feature relationships with heatmaps and correlation to the target variable.
- PCA Visualization: Reduced dimensionality for visualization in 2D and 3D.
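As a concrete illustration, here is a minimal EDA sketch. It assumes the scikit-learn copy of the Wisconsin Breast Cancer dataset as a stand-in for data/raw/; the notebook (01_data_exploration.ipynb) works with the project's own CSV instead.

```python
# Minimal EDA sketch; assumes the scikit-learn breast cancer dataset as a
# stand-in for the project's raw data (0 = malignant, 1 = benign).
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_breast_cancer
from sklearn.decomposition import PCA

df = load_breast_cancer(as_frame=True).frame

print(df.describe())                 # basic statistics
print(df.isnull().sum())             # null values per column
print(df["target"].value_counts())   # diagnosis distribution

# Correlation heatmap across features and target.
sns.heatmap(df.corr(), cmap="coolwarm", center=0)
plt.show()

# 2D PCA projection colored by diagnosis.
pcs = PCA(n_components=2).fit_transform(df.drop(columns="target"))
plt.scatter(pcs[:, 0], pcs[:, 1], c=df["target"], cmap="coolwarm", s=12)
plt.xlabel("PC1"); plt.ylabel("PC2"); plt.title("PCA of tumor features")
plt.show()
```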
Applied preprocessing and feature selection techniques (a sketch follows this list):
- Scaled features using standardization.
- Removed low-variance features.
- Selected important features using:
  - Random Forest Feature Importance.
  - Mutual Information.
  - ANOVA F-Test.
  - Kruskal-Wallis H Test.
- Feature Importance Analysis: Highlighted the most impactful features.
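A minimal sketch of these steps, with illustrative thresholds; note that the variance filter is applied before standardization here, since every feature has unit variance after scaling. The notebook may differ in ordering and parameters.

```python
# Preprocessing and feature-selection sketch with illustrative thresholds.
import pandas as pd
from scipy.stats import kruskal
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import VarianceThreshold, f_classif, mutual_info_classif
from sklearn.preprocessing import StandardScaler

X, y = load_breast_cancer(return_X_y=True, as_frame=True)

# Remove low-variance features, then standardize the rest.
keep = VarianceThreshold(threshold=0.01).fit(X).get_support()
X = X.loc[:, keep]
X = pd.DataFrame(StandardScaler().fit_transform(X), columns=X.columns)

# Rank features with several complementary criteria.
rf = RandomForestClassifier(n_estimators=200, random_state=42).fit(X, y)
rf_importance = pd.Series(rf.feature_importances_, index=X.columns)
mi = pd.Series(mutual_info_classif(X, y, random_state=42), index=X.columns)
f_stat, _ = f_classif(X, y)                      # ANOVA F-test
anova = pd.Series(f_stat, index=X.columns)
kw_p = pd.Series({c: kruskal(X.loc[y == 0, c], X.loc[y == 1, c]).pvalue
                  for c in X.columns})           # Kruskal-Wallis H test

print(rf_importance.sort_values(ascending=False).head(10))
```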
Trained and evaluated multiple machine learning models:
- Classification Models: Logistic Regression, Decision Tree, Random Forest, Gradient Boosting, XGBoost, LightGBM, CatBoost, SVM, k-NN, Naive Bayes, AdaBoost, Bagging, Extra Trees, Voting, Gaussian Process.
- Neural Networks: Artificial Neural Network (ANN) and Multi-Layer Perceptron (MLP).
Models achieving F1-scores above 0.96 were selected for further analysis (see the sketch below).
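A minimal sketch of this comparison step, assuming an 80/20 stratified split; the notebook covers far more models (XGBoost, LightGBM, CatBoost, etc., via their scikit-learn-compatible wrappers) and tuned hyperparameters.

```python
# Train several candidates and keep those clearing the F1 > 0.96 bar.
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

candidates = {
    "Logistic Regression": LogisticRegression(max_iter=5000),
    "Random Forest": RandomForestClassifier(random_state=42),
    "Gradient Boosting": GradientBoostingClassifier(random_state=42),
}

selected = {}
for name, model in candidates.items():
    model.fit(X_train, y_train)
    f1 = f1_score(y_test, model.predict(X_test))
    print(f"{name}: F1 = {f1:.3f}")
    if f1 > 0.96:
        selected[name] = model
```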
- Used SHAP and LIME to perform interpretability analysis:
- SHAP for Tree-Based Models and ANN.
- LIME for general models.
- The models for the next step were chosen based on a combination of interpretability and performance. This approach ensures that we have high-performing models along with simple interpretable models, which is critical for clinical applications where understanding the decision-making process is essential.
- Selected Models:
- Logistic Regression
- Gradient Boosting
- LightGBM
- XGBoost
- CatBoost
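The snippet below sketches the SHAP side of this analysis for a tree-based model; the `shap` package is assumed to be installed, and the notebook additionally applies LIME (e.g. lime.lime_tabular.LimeTabularExplainer) to the remaining models.

```python
# SHAP sketch for a tree-based model: one global and one local explanation.
import shap
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import GradientBoostingClassifier

X, y = load_breast_cancer(return_X_y=True, as_frame=True)
model = GradientBoostingClassifier(random_state=42).fit(X, y)

explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)

# Global explanation: mean |SHAP value| per feature across the dataset.
shap.summary_plot(shap_values, X)

# Local explanation for a single prediction (rendered with matplotlib).
shap.force_plot(explainer.expected_value, shap_values[0], X.iloc[0], matplotlib=True)
```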
Conducted a thorough error analysis, inspecting misclassified cases to identify the models' weaknesses.
- Confusion Matrices: Evaluated model predictions on the test set.
- Performance Comparison: Evaluated key metrics across the selected models.
| Metric    | Logistic Regression | Gradient Boosting | LightGBM | XGBoost | CatBoost |
|-----------|---------------------|-------------------|----------|---------|----------|
| Accuracy  | 99.1%               | 98.2%             | 99.1%    | 99.1%   | 99.1%    |
| Precision | 100.0%              | 95.2%             | 100.0%   | 97.6%   | 97.6%    |
| Recall    | 97.5%               | 100.0%            | 97.5%    | 100.0%  | 100.0%   |
| F1-Score  | 98.7%               | 97.6%             | 98.7%    | 98.8%   | 98.8%    |
| ROC-AUC   | 1.0                 | 1.0               | 1.0      | 1.0     | 1.0      |
The models evaluated (Logistic Regression, Gradient Boosting, LightGBM, XGBoost, and CatBoost) show high accuracy and excellent predictive performance on the test set.
- Overall Performance:
  - All models perform well, with high F1-scores, accuracy, and ROC-AUC values.
  - Differences in performance are marginal, but minor trade-offs exist between precision and recall.
- Key Trade-Offs:
  - Logistic Regression and LightGBM: Perfect precision but slightly lower recall, making them conservative predictors.
  - Gradient Boosting, XGBoost, and CatBoost: Perfect recall but slightly reduced precision, better for high-recall use cases such as avoiding missed malignant cases.
- Confusion Matrix Insights:
  - Logistic Regression and LightGBM produce the fewest misclassifications, with only one false negative.
  - Gradient Boosting has more false positives but no false negatives, ideal for minimizing missed critical cases.
- Logistic Regression:
  - Suitable if simplicity and interpretability are priorities.
  - Effective for clinical scenarios where precision (avoiding unnecessary treatments) is important.
- Gradient Boosting:
  - Prioritize for cases where recall (detecting all malignant cases) is essential, even if it introduces more false positives.
- LightGBM, XGBoost, and CatBoost:
  - Excellent alternatives combining strong recall with slightly reduced precision.
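The confusion-matrix comparison above can be reproduced in a few lines, reusing the `selected` models and test split from the training sketch earlier:

```python
# Per-model confusion-matrix breakdown on the held-out test set.
from sklearn.metrics import confusion_matrix

for name, model in selected.items():
    tn, fp, fn, tp = confusion_matrix(y_test, model.predict(X_test)).ravel()
    print(f"{name}: TN={tn} FP={fp} FN={fn} TP={tp}")
```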
The Streamlit-based GUI provides an intuitive interface for the breast cancer prediction app. The interface emphasizes simplicity, interpretability, and usability, ensuring users can make informed decisions with ease.
- Interactive Input Fields:
  - Users can enter patient data (tumor features) directly into the app and use buttons to increase/decrease values.
- Upload Document:
  - Users can upload a CSV file containing tumor features for batch predictions.
  - When a CSV file is uploaded, the app processes the rows in batch and appends new columns (Prediction, Prediction label, Prediction probabilities, Model used to predict); the user can then download the results or visualize them in the app.
- Model Selection:
  - The default model is Logistic Regression, chosen for its simplicity and interpretability.
  - Users can choose between multiple models (Gradient Boosting, XGBoost, LightGBM, CatBoost) to see predictions and compare results.
- Prediction Visualization:
  - Results are displayed classifying tumors as either Benign or Malignant.
  - Probabilities are provided alongside predictions to help assess certainty.
- Interpretability Tools:
  - Global explanations are provided for all models.
  - Local explanations for single predictions are also available.
- Stored Predictions:
  - All predictions, whether from uploaded documents or manual input, are stored in a PostgreSQL database, making it possible to track historical results and analyze patterns over time (see the storage sketch below).
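A hypothetical sketch of how one prediction could be written to the database; the table name, columns, and connection handling in deployment/app.py are assumptions:

```python
# Hypothetical write of a single prediction row to PostgreSQL; the
# `predictions` table and its columns are illustrative, not the app's schema.
import os
import psycopg2

conn = psycopg2.connect(os.environ["DATABASE_URL"])  # Neon connection string
with conn, conn.cursor() as cur:
    cur.execute(
        """
        INSERT INTO predictions (input_features, prediction, probability, model_used)
        VALUES (%s, %s, %s, %s)
        """,
        ('{"radius_mean": 14.2, "...": "..."}', "Benign", 0.97, "Logistic Regression"),
    )
```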
The project includes a data drift detection system that monitors for significant changes in the input data distribution every month. A Kolmogorov-Smirnov (KS) test compares the cumulative distribution of each feature in the current data against the original training dataset (baseline data). The KS test measures the maximum difference between two cumulative distributions and determines whether they are statistically significantly different.
If the test detects significant differences (p-value < 0.05) in more than 20% of the features, the system flags the model for retraining to adapt to the new data distribution.
- Data Monitoring: Each month, the system compares the distributions of incoming data (stored in the PostgreSQL database hosted on Neon) against the training data using the KS test.
- Drift Detection: Data drift is detected if the KS test p-value is below 0.05 for more than 20% of the features.
- Retraining: When data drift is detected, the system retrains the models.
- Deployment: The retrained models are automatically deployed to the Streamlit app.
This process is managed using GitHub Actions, which automates the monitoring and retraining process. Monthly, GitHub Actions triggers the data drift detection workflow, which runs the KS test and retrains the models if necessary.
- Drift detection is unsupervised and does not require labels. However, retraining requires access to labeled data, making this system most effective in scenarios where new data can be labeled and incorporated into the training pipeline.
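The decision rule described above reduces to a few lines. This is a minimal sketch assuming baseline and current feature frames with matching columns; the real logic lives in py_scripts/data_drift_monitoring.py and may differ in detail.

```python
# KS-based drift check: flag drift when more than `drift_share` of features
# differ significantly (p < alpha) between baseline and current data.
import pandas as pd
from scipy.stats import ks_2samp

def drift_detected(baseline: pd.DataFrame, current: pd.DataFrame,
                   alpha: float = 0.05, drift_share: float = 0.20) -> bool:
    drifted = [col for col in baseline.columns
               if ks_2samp(baseline[col], current[col]).pvalue < alpha]
    return len(drifted) / len(baseline.columns) > drift_share
```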
For deployment, I integrated multiple models in the Streamlit app. This approach pairs simple models like Logistic Regression, valued for interpretability, with complex models like Gradient Boosting, LightGBM, XGBoost, and CatBoost for high performance. For both types, the application provides explanations for the predictions while giving users the opportunity to explore different models and compare results.
The app is live on Streamlit Cloud and can be accessed here: Breast Cancer Prediction App
The app is fully Dockerized and can be pulled and run locally:
- Pull the Docker image:
  docker pull touradbaba/model_engineering_app:latest
- Run the container:
  docker run -p 8501:8501 touradbaba/model_engineering_app:latest
git clone https://github.com/TouradBaba/model_engineering.git
cd model_engineering
Setting up a virtual environment ensures that your project dependencies are isolated from your system Python installation, preventing conflicts with other projects.
On Windows:
- Create a virtual environment:
  python -m venv myenv
- Activate the virtual environment:
  myenv\Scripts\activate

On macOS/Linux:
- Create a virtual environment:
  python3 -m venv myenv
- Activate the virtual environment:
  source myenv/bin/activate
Install all required packages using the full_requirements.txt file:
pip install -r full_requirements.txt
- Create a PostgreSQL database on Neon and update the connection details in your local configuration.
- Ensure the database is accessible for storing predictions.
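To sanity-check the connection and inspect stored predictions, something along these lines can be used, in the spirit of py_scripts/query_db.py; the DATABASE_URL environment variable and `predictions` table name are assumptions:

```python
# Connect via the Neon connection string and preview recent predictions.
import os
import pandas as pd
from sqlalchemy import create_engine

engine = create_engine(os.environ["DATABASE_URL"])  # postgresql://user:pass@host/db
print(pd.read_sql("SELECT * FROM predictions LIMIT 10", engine))
```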
Start the Streamlit app:
streamlit run deployment/app.py