updated binary classification AIDS ex.
lshpaner committed Nov 20, 2024
1 parent 750f24d commit 3ff6a27
Showing 11 changed files with 105 additions and 83 deletions.
Binary file modified docs/.doctrees/about.doctree
Binary file modified docs/.doctrees/caveats.doctree
Binary file modified docs/.doctrees/changelog.doctree
Binary file modified docs/.doctrees/environment.pickle
Binary file modified docs/.doctrees/getting_started.doctree
Binary file modified docs/.doctrees/index.doctree
Binary file modified docs/.doctrees/references.doctree
Binary file modified docs/.doctrees/usage_guide.doctree
63 changes: 35 additions & 28 deletions docs/_sources/usage_guide.rst.txt
@@ -448,10 +448,18 @@ You can use this function to evaluate the model by printing the output.
 # ------------------------- VALID AND TEST METRICS -----------------------------

 print("Validation Metrics")
-class_report_val, cm_val = model_xgb.return_metrics(X_valid, y_valid, optimal_threshold=True)
+class_report_val, cm_val = model_xgb.return_metrics(
+    X_valid,
+    y_valid,
+    optimal_threshold=True,
+)
 print()
 print("Test Metrics")
-class_report_test, cm_test = model_xgb.return_metrics(X_test, y_test, optimal_threshold=True)
+class_report_test, cm_test = model_xgb.return_metrics(
+    X_test,
+    y_test,
+    optimal_threshold=True,
+)

 .. code-block:: bash
@@ -521,22 +529,22 @@ Step 10: Calibrate the Model (if needed)
 import matplotlib.pyplot as plt
 from sklearn.calibration import calibration_curve

-# Get the predicted probabilities for the validation data from the uncalibrated model
+## Get the predicted probabilities for the validation data from uncalibrated model
 y_prob_uncalibrated = model_xgb.predict_proba(X_test)[:, 1]

-# Compute the calibration curve for the uncalibrated model
+## Compute the calibration curve for the uncalibrated model
 prob_true_uncalibrated, prob_pred_uncalibrated = calibration_curve(
     y_test,
     y_prob_uncalibrated,
-    n_bins=6,
+    n_bins=10,
 )

-# Calibrate the model
+## Calibrate the model
 if model_xgb.calibrate:
-  model_xgb.calibrateModel(X, y, score="roc_auc")
+    model_xgb.calibrateModel(X, y, score="roc_auc")

-# Predict on the validation set
-y_test_pred = model_xgb.predict_proba(X_test)[:,1]
+## Predict on the validation set
+y_test_pred = model_xgb.predict_proba(X_test)[:, 1]

 .. code-block:: bash
@@ -568,44 +576,43 @@ Step 10: Calibrate the Model (if needed)
 .. code-block:: python

-# Get the predicted probabilities for the validation data from calibrated model
+## Get the predicted probabilities for the validation data from calibrated model
 y_prob_calibrated = model_xgb.predict_proba(X_test)[:, 1]

-# Compute the calibration curve for the calibrated model
+## Compute the calibration curve for the calibrated model
 prob_true_calibrated, prob_pred_calibrated = calibration_curve(
-    y_test,
-    y_prob_calibrated,
-    n_bins=6,
+    y_test,
+    y_prob_calibrated,
+    n_bins=10,
 )

-# Plot the calibration curves
+## Plot the calibration curves
 plt.figure(figsize=(5, 5))
 plt.plot(
-    prob_pred_uncalibrated,
-    prob_true_uncalibrated,
-    marker="o",
-    label="Uncalibrated XGBoost",
+    prob_pred_uncalibrated,
+    prob_true_uncalibrated,
+    marker="o",
+    label="Uncalibrated XGBoost",
 )
 plt.plot(
-    prob_pred_calibrated,
-    prob_true_calibrated,
-    marker="o",
-    label="Calibrated XGBoost",
+    prob_pred_calibrated,
+    prob_true_calibrated,
+    marker="o",
+    label="Calibrated XGBoost",
 )
 plt.plot(
-    [0, 1],
-    [0, 1],
-    linestyle="--",
-    label="Perfectly calibrated",
+    [0, 1],
+    [0, 1],
+    linestyle="--",
+    label="Perfectly calibrated",
 )
 plt.xlabel("Predicted probability")
 plt.ylabel("True probability in each bin")
 plt.title("Calibration plot (reliability curve)")
 plt.legend()
 plt.show()

 .. raw:: html

 <div class="no-click">
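Taken together, the Step 10 hunks re-wrap the snippet, switch its comments to ##, and move the reliability curve from 6 to 10 bins, which gives finer resolution at the cost of noisier bins on small test sets. For trying the comparison outside the guide, below is a minimal, self-contained sketch: synthetic data and scikit-learn's HistGradientBoostingClassifier and CalibratedClassifierCV stand in for the AIDS dataset and for model_tuner's model_xgb and calibrateModel, whose internals may differ.

.. code-block:: python

    import matplotlib.pyplot as plt
    from sklearn.calibration import CalibratedClassifierCV, calibration_curve
    from sklearn.datasets import make_classification
    from sklearn.ensemble import HistGradientBoostingClassifier
    from sklearn.model_selection import train_test_split

    # Synthetic stand-in for the guide's AIDS clinical-trials data
    X, y = make_classification(n_samples=5000, weights=[0.8], random_state=42)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, stratify=y, random_state=42
    )

    # Uncalibrated boosted model (stand-in for model_xgb)
    clf = HistGradientBoostingClassifier(random_state=42).fit(X_train, y_train)
    y_prob_uncalibrated = clf.predict_proba(X_test)[:, 1]
    prob_true_unc, prob_pred_unc = calibration_curve(
        y_test, y_prob_uncalibrated, n_bins=10
    )

    # Sigmoid (Platt) calibration; calibrateModel plays this role in the guide
    cal = CalibratedClassifierCV(clf, method="sigmoid", cv=5).fit(X_train, y_train)
    y_prob_calibrated = cal.predict_proba(X_test)[:, 1]
    prob_true_cal, prob_pred_cal = calibration_curve(
        y_test, y_prob_calibrated, n_bins=10
    )

    plt.plot(prob_pred_unc, prob_true_unc, marker="o", label="Uncalibrated")
    plt.plot(prob_pred_cal, prob_true_cal, marker="o", label="Calibrated")
    plt.plot([0, 1], [0, 1], linestyle="--", label="Perfectly calibrated")
    plt.xlabel("Predicted probability")
    plt.ylabel("True probability in each bin")
    plt.legend()
    plt.show()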
62 changes: 35 additions & 27 deletions docs/usage_guide.html
(Generated Sphinx HTML for the usage guide; the Pygments-highlighted markup mirrors the reST changes above one-for-one: the Step 9 return_metrics calls re-wrapped across multiple lines, comments switched from # to ##, n_bins raised from 6 to 10, the calibrateModel call re-indented under its if, and [:,1] spaced as [:, 1]. The span markup is not reproduced here.)
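The Step 9 hunks, for their part, only re-wrap the return_metrics calls; behavior is unchanged. For orientation, a helper of this shape reduces to roughly the following. This is a hypothetical re-implementation for illustration, not model_tuner's actual code, and the optimal_threshold logic shown (Youden's J on the ROC curve) is an assumption about what the flag does.

.. code-block:: python

    import numpy as np
    from sklearn.metrics import classification_report, confusion_matrix, roc_curve

    def return_metrics_sketch(model, X, y, optimal_threshold=True):
        """Hypothetical stand-in for model_tuner's return_metrics."""
        proba = model.predict_proba(X)[:, 1]
        threshold = 0.5
        if optimal_threshold:
            # Youden's J: the ROC point farthest above the chance line
            fpr, tpr, thresholds = roc_curve(y, proba)
            threshold = thresholds[np.argmax(tpr - fpr)]
        preds = (proba >= threshold).astype(int)
        report = classification_report(y, preds)
        cm = confusion_matrix(y, preds)
        print(report)
        print(cm)
        return report, cm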
63 changes: 35 additions & 28 deletions source/usage_guide.rst
(This diff is identical to the docs/_sources/usage_guide.rst.txt diff shown above; Sphinx copies source/usage_guide.rst into docs/_sources/ verbatim at build time.)
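One caveat worth reading alongside these changes: the bash output in the diff tracks roc_auc before and after calibration, but ROC AUC is rank-based and moves little under monotonic calibration. A proper scoring rule such as the Brier score is the more direct check. A one-line sketch, reusing y_test, y_prob_uncalibrated, and y_prob_calibrated from the sketch after the first diff:

.. code-block:: python

    from sklearn.metrics import brier_score_loss

    # Lower is better; calibration should shrink this even if roc_auc barely moves
    print("Brier, uncalibrated:", brier_score_loss(y_test, y_prob_uncalibrated))
    print("Brier, calibrated:  ", brier_score_loss(y_test, y_prob_calibrated))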
