-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtidymodels.qmd
653 lines (495 loc) · 29.2 KB
/
tidymodels.qmd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
---
title: "| ![](logo@2x.png){width=25%} \\vspace{0.25in} \nVorhersage der Kundenfluktuation mit Tidymodels"
author: "Patrik Häcki"
date: "today"
date-format: "long"
format: html
editor: visual
lang: de
---
## Prolog
Dieses Dokument beschreibt einen Prozess zur Entwicklung eines maschinellen Lernmodells für die Vorhersage der **Kundenfluktuation** bzw. **Kundenabwanderung**. Der Prozess beinhaltet eine einfache Datenexploration/-analyse, bevor zwei Modelle trainiert und getestet werden. Die konkrete Aufgabe ist die Klassifizierung, d.h. die Vorhersage, welche Kunden abwandern und welche nicht. Deshalb fiel die Auswahl auf zwei Modelle, die üblicherweise für die **Klassifikation** verwendet werden: die **logistische Regression** und **Random Forest**.
## Datenexploration/-analyse
Es ist wichtig, dass wir uns mit dem Datensatz vertraut machen und ihn verstehen, bevor wir beginnen.
```{r}
#| message: false
#| echo: false # Die Option echo: false unterdrückt die Ausgabe des Codes
# (nur das Resultat wird angezeigt).
library(tidyverse)
library(tidymodels)
library(caret)
library(corrplot)
library(DALEXtra)
library(lime)
library(ranger)
library(skimr)
library(themis)
library(vip)
```
```{r}
telco_data <- read_csv(file = "data/WA_Fn-UseC_-Telco-Customer-Churn.csv",
show_col_types = FALSE)
```
```{r}
telco_data <- telco_data %>%
mutate(Churn = as.factor(Churn))
```
```{r}
head(telco_data)
```
```{r}
str(telco_data)
```
```{r}
skim(telco_data)
```
```{r}
telco_data %>%
select(where(is.numeric)) %>%
na.omit() %>%
cor() %>%
corrplot(method = "number")
```
Aus der Datenexploration gewinnen wir verschiedene Informationen. Zunächst einmal sehen wir eine Vielzahl von Variablen des Typs Character. Nur eine von ihnen hat mehr als vier eindeutige Werte, dies ist die Spalte customerID. Das gibt uns eine Menge Faktoren, mit denen wir arbeiten können und die wir in Betracht ziehen sollten. Daneben gibt es vier numerische Variablen, von denen sich allerdings eine etwas von den anderen unterscheidet. SeniorCitizen hat nur zwei Ausprägungen, 0 und 1. Es handelt sich um eine boolesche Variable, die wir also später wie die anderen Zeichenvariablen behandeln werden.
Was die anderen numerischen Variablen betrifft, so können wir feststellen, dass sie in ihrer Grössenordnung stark variieren und dass keine von ihnen einer Normalverteilung entspricht. Wir werden uns dies in Kürze anhand einiger grösserer Diagramme genauer ansehen, aber wir müssen diese Hürden überwinden, wenn wir mit der logistischen Regression auch nur annähernd genaue Vorhersagen machen wollen.
Wir haben nicht viele fehlende Werte, aber sie sind da. Zum Glück beschränken sie sich auf die Spalte TotalCharges. Fehlende Werte sind in erster Linie eine Herausforderung für Random Forest. Deshalb müssen wir bei der Erstellung des Modells einen Weg für den Umgang mit diesem Problem finden.
Ein weiterer Punkt, den wir überprüfen müssen, ist die Verteilung der Kundenfluktuation. Wie viele von beiden Werten sind im Datensatz enthalten?
```{r}
telco_data %>%
group_by(Churn) %>%
summarise(Anzahl = n(), Prozent = round(n() / nrow(telco_data) * 100,
digits = 1))
```
Was wir sehen, ist ein erhebliches Ungleichgewicht bzw. ein ungefähres Verhältnis von 75/25 für Nein/Ja. Unausgewogene Daten (imbalanced data) sind ein Problem, da sie Vorhersagen erschweren und den Wert bestimmter Metriken zunichte machen. Die **Accuracy** (Richtigkeit) kann zum Beispiel völlig irreführend sein, da der 75er-Anteil den kleineren und manchmal wertvolleren Anteil überwiegt. Zum Glück lassen sich unausgewogene Daten mit bestimmten Methoden korrigieren. Eine erste ist die **Data Stratification** (Datenschichtung). Durch diese Technik wird erreicht, dass der kleinere Datensatz die gleiche statistische Aufschlüsselung aufweist wie der grössere. Wir können die Schichtung auf unsere Abwanderungsspalte anwenden und so sicherstellen, dass alle Stichproben die wir ziehen, in etwa die Verteilung des Hauptdatensatzes repräsentieren. Die zweite Technik ist das **Up-Sampling**, das mit dem `themis`-Paket durchgeführt wird. Bei dieser Technik werden so lange Stichproben gezogen, bis die beiden Teile des Faktors, in unserem Fall «Ja» und «Nein», in ihrer Häufigkeit gleich sind. In diesem Projekt wurden beide Techniken angewandt.
Als nächstes betrachten wir die nicht-booleschen numerischen Variablen genauer. Anhand der Analyse mit `skim()` zu urteilen, liegt keine Normalverteilung vor. Lassen Sie uns daher einen genaueren Blick darauf werfen.
```{r}
telco_data %>%
ggplot(aes(x = tenure)) +
geom_histogram(bins = 30, color = "white")
```
```{r}
telco_data %>%
ggplot(aes(x = MonthlyCharges)) +
geom_histogram(bins = 30, color = "white")
```
```{r}
#| warning: false
telco_data %>%
ggplot(aes(x = TotalCharges)) +
geom_histogram(bins = 30, color = "white")
```
Der Verdacht bestätigt sich. Die Skalen sind vollkommen unterschiedlich und nichts, was einer Normalverteilung nahe kommt. Um dieses Problem zu umgehen, werden zwei Techniken angewandt: **Log-Transformation** und **Normalisierung**. Auf diese Weise wird die Arbeit mit den Daten für das Modell der logistischen Regression wesentlich erleichtert. Für das Random Forest-Modell ist das jedoch nicht erforderlich.
Es gibt einen weiteren Punkt, den wir uns ansehen können. Wie hoch ist die Häufigkeit der Kundenfluktuation auf der Grundlage der numerischen Variablen? Gibt es zum Beispiel auf der Grundlage der monatlichen Gebühren bei einem bestimmten Wert eine höhere Häufigkeit der Kundenabwanderung?
```{r}
telco_data %>%
ggplot(aes(x = tenure, fill = Churn)) +
geom_bar(position = "dodge")
```
```{r}
telco_data %>%
filter(Churn == "Yes") %>%
ggplot(aes(x = MonthlyCharges)) +
geom_histogram(bins = 30, color = "white")
```
```{r}
telco_data %>%
filter(Churn == "Yes") %>%
ggplot(aes(x = TotalCharges)) +
geom_histogram(bins = 30, color = "white")
```
```{r}
telco_data %>%
pivot_longer(c(tenure, MonthlyCharges, TotalCharges),
names_to = "key",
values_to = "value", values_drop_na = TRUE) %>%
ggplot(aes(x = value, fill = key)) +
geom_boxplot() +
facet_wrap(~key, ncol = 1, scales = "free")
```
Daraus können wir einige wichtige Schlussfolgerungen ableiten:
- Es ist nicht weiter überraschend, dass Kunden mit kürzerer Vertragslaufzeit die abonnierten Dienste früher kündigen. Trotzdem ist es wichtig, dass sich dies in den Daten in einem solchen Ausmass widerspiegelt.
- Wir sehen auch, dass Kunden, die mehr pro Monat zahlen, den Dienst häufiger kündigen. Allenfalls gibt es einen Schwellenwert bei den monatlichen Kosten, ab denen die Kunden nicht mehr bereit sind, das Abonnement fortzusetzen.
- Anhand der vorliegenden Zahlen ist es wahrscheinlicher, dass Kunden abwandern, die insgesamt weniger für Telekommunikationsdienste bezahlen. Dieser Wert korreliert mit der Dauer der Vertragslaufzeit.
- Die drei Faktoren scheinen gute Prädiktoren für die Abwanderung von Kunden zu sein, da sie intuitiv verständlich sind.
Nachdem der grundlegende Überblick über die Daten vorliegt, können wir unser erstes Modell erstellen. Bevor wir dies tun, benötigen wir idealerweise ein Nullmodell. Es geht vielfach vergessen, dass die Erstellung von Modellen ohne einen Bezugspunkt sinnlos ist.
```{r}
# Auf Klassenungleichgewicht prüfen
table(telco_data$Churn)
prop.table(table(telco_data$Churn))
```
## Das Nullmodell
Als erstes müssen wir unsere Daten aufteilen.
```{r}
set.seed(123)
telco_split <-
initial_split(data = telco_data,
prop = 0.8,
strata = Churn)
telco_train <- training(telco_split)
telco_test <- testing(telco_split)
```
```{r}
cat(dim(telco_train), dim(telco_test))
```
Anschliessend erstellen wir ein sehr einfaches Nullmodell. Im Wesentlichen geht es von der Nullhypothese aus, dass der Kunde nicht abwandert. Ausgehend von der 75/25-Aufteilung des vollständigen Datensatzes können wir annehmen, dass dieses Modell eine Genauigkeit von etwa 75% haben wird. Das ist unsere Vergleichsbasis. Wenn unsere Modelle diesen Wert nicht übertreffen, haben wir ein Problem.
```{r}
# Nullmodell erstellen
mod_null <-
logistic_reg(mode = "classification") %>%
set_engine(engine = "glm") %>%
fit(Churn ~ 1, data = telco_train)
```
```{r}
# Berechnung der Accuracy (Richtigkeit) des Nullmodells für spätere Vergleiche
pred <- telco_train %>%
bind_cols(
predict(mod_null, new_data = telco_train, type = "class")
) %>%
rename(Churn_null = .pred_class)
```
```{r}
# Vergleich 1
accuracy(data = pred, Churn, Churn_null)
```
```{r}
# Vergleich 2
confusion_null <- pred %>%
conf_mat(truth = Churn, estimate = Churn_null)
confusion_null
```
## Logistisches Regressionsmodell
Im Falle der logistischen Regression liegt der grösste Teil der Arbeit im Rezept. Die Vorverarbeitung ist daher sehr wichtig und es ist unerlässlich, dass die Daten und Variablen vorbereitet sind.
- Zunächst haben wir ein Rezept, das unsere Formel repräsentiert und Churn als Output beinhaltet. Der Umfang gibt an, dass wir alle Prädiktorvariablen für die Vorhersage verwenden.
- Als Nächstes erhöhen wir die Stichprobe, indem wir ein Verhältnis von 1:1 für die Ja/Nein-Werte der Variable Churn festlegen. Wir werden prüfen, ob dies von Nutzen ist, da auch eine Stratifizierung verwendet wird.
- Wir machen SeniorCitizen zu einem Faktor und entfernen alle fehlenden Werte aus den Spalten.
- Numerische Variablen werden logarithmiert und normalisiert. Dies geschieht nach dem vorherigen Schritt, um sicherzustellen, dass «SeniorCitizen» nicht davon betroffen ist.
- Wir setzen alle unsere nominalen (kategorialen) Variablen als Dummy-Variablen.
- Wir entfernen Variablen, die zu stark miteinander korrelieren und alle diejenigen, welche eine Varianz von Null haben.
```{r}
# Vorverarbeitungsrezept erstellen
log_reg_recipe <-
recipe(Churn ~ ., data = telco_train) %>%
# Anzahl der Stichproben für «Ja» und «Nein» gleich setzen; Alternative: step_downsample()
step_upsample(Churn, over_ratio = 1) %>%
update_role(customerID, new_role = "id") %>%
step_mutate(SeniorCitizen = as.factor(SeniorCitizen)) %>%
step_naomit(everything(), role = TRUE) %>%
# logarithmische Transformation nicht-normaler, numerischer Variablen
step_log(tenure, MonthlyCharges, TotalCharges) %>%
# z-Standardisierung aller numerischen Variablen
step_normalize(tenure, MonthlyCharges, TotalCharges) %>%
step_dummy(all_nominal_predictors()) %>%
# Hoch korrelierte Variablen entfernen; Alternative: all_numeric()
step_corr(all_numeric_predictors(), threshold = 0.7) %>%
# Numerische Variablen mit Varianz Null entfernen; Alternative: step_nzv()
step_zv(all_numeric_predictors())
```
```{r}
prep(log_reg_recipe) %>%
bake(new_data = telco_train) # Vorverarbeitete Daten anzeigen
```
Als Nächstes richten wir unseren **Workflow** (Arbeitsablauf) ein, der ein **Modell** und das gerade erstellte **Rezept** umfasst.
```{r}
# Modell erstellen
log_reg_model <-
logistic_reg(mode = "classification") %>%
set_engine(engine = "glm")
```
```{r}
# Workflow mit Modell erstellen; Paket "workflows"
log_reg_workflow <-
workflow() %>%
add_model(spec = log_reg_model) %>%
add_recipe(recipe = log_reg_recipe)
```
Danach folgen die Re-Sampling-Methoden und die Ermittlung von Metriken für unser Modell. Zu diesem Zweck werden zwei Re-Sampling-Methoden angewandt: **Bootstrapping** und **10-Fold Cross Validation** (10-fache Kreuzvalidierung). Schauen wir uns zunächst die Bootstrapping-Methode an.
```{r}
# Auflistung von relevanten Metriken
my_metrics <- metric_set(accuracy, f_meas, roc_auc)
```
```{r}
# Bootstrap fünfmal auf Trainingsdaten anwenden
telco_bstraps <-
bootstraps(data = telco_train, times = 5, strata = Churn)
```
```{r}
# Vorhersagen für jede der 5 Stichproben
log_reg_resamples <-
fit_resamples(object = log_reg_workflow,
resamples = telco_bstraps,
metrics = my_metrics,
# Vorhersagen für Konfusionsmatrix speichern
control = control_resamples(save_pred = TRUE))
```
```{r}
# Kennzahlen der Stichproben ausgeben
log_reg_resamples %>%
collect_metrics()
```
Es gibt eine Reihe von wichtigen Messgrössen, die man auch im Hinblick auf die Nullhypothese verstehen muss.
- **Nullhypothese:** Der Kunde wechselt nicht. Wahr positiv ist die korrekte Vermutung, dass der Kunde nicht abwandert. Wahr negativ ist die korrekte Vermutung, dass der Kunde abwandert.
- **Accuracy (Richtigkeit):** Die intuitivste aller Metriken zeigt, wie oft die Vorhersagen des Modells richtig waren. Unser Modell erreichte einen Wert von ca. 0.75, was ungefähr dem Nullmodell entspricht. Das bedeutet im Wesentlichen, dass unser Modell in 3 von 4 Fällen, in denen es «ja» oder «nein» vorhersagt, die richtige Antwort gibt. Es ist wichtig, alle Probleme mit den Daten oder methodischen Entscheidungen zu berücksichtigen, die die «Accuracy» mehr oder weniger relevant machen. In unserem Fall können mögliche Konflikte zwischen Schichtung und Up-Sampling dazu führen, dass diese Kennzahl weniger zuverlässig ist. Wenn die Schichtung beispielsweise das Up-Sampling ausser Kraft setzt, überwiegen die «Nein»-Raten die «Ja»-Raten, was zu einer hohen Richtigkeit trotz schlechter Leistung bei den «Ja»-Vorhersagen führen kann.
- **Precision (Genauigkeit):** Diese Metrik zeigt die Anzahl der richtigen positiven Vorhersagen im Verhältnis zur Anzahl der falschen positiven Vorhersagen des Modells. Unser Modell hat einen recht respektablen Wert von etwa 0.90. Das bedeutet, dass das Modell in 90% der Fälle richtig lag, in denen es vorhersagte, dass ein Kunde nicht abwandert.
- **Recall (Sensitivität):** Dies ist die Rate der richtig-positiven Werte. Ein hoher Sensitivitätswert bedeutet, dass das Modell nur wenige falsch-negative Werte hatte und fast alle positiven Werte im Satz erfassen konnte. In unserem Fall bedeutet dies, dass unser Modell in der Lage ist, so viele Kunden wie möglich korrekt zu erfassen, die nicht abgewandert sind. Unser Sensitivitätswert lag bei 0.73, was bedeutet, dass das Modell etwa ein Viertel aller wirklich positiven Werte falsch klassifiziert hat. Im Wesentlichen nahm unser Modell bei einem Viertel der Kunden an, dass sie abwandern würden, obschon sie das nicht taten.
- **F_meas oder F1-Score:** Der Wert ist das harmonische Mittel aus «Precision» und «Recall». Er zeigt das Gleichgewicht zwischen den beiden Metriken im Modell. Dies ist wichtig, da diese beiden Metriken einen Kompromiss darstellen, d.h. der Versuch, eine der beiden Metriken zu maximieren, verringert die andere. Unser Wert liegt bei 0.81, was darauf hindeutet, dass die beiden Kennzahlen insgesamt relativ ausgeglichen sind.
- **Specificity (Spezifität):** Dies zeigt die Rate der echten Negativwerte, die angibt, wie gut ein Modell alle Werte erfasst, die tatsächlich im Widerspruch zu unserer Nullhypothese stehen, d.h. in unserem Fall Kunden, die abwandern. Anhand dieser Metrik können wir also sehen, wie gut das Modell alle Kunden erfasst, die den Service kündigen. Wir haben etwa 0.80 erreicht, was in Ordnung ist. Allerdings wird etwa ein Fünftel der abwandernden Kunden übersehen, was verbesserungswürdig ist.
- **roc_auc:** Das ist die Fläche unter der roc-Kurve des gegebenen Modells. Die Kurve zeigt auf der y-Achse die wahr-positive Rate (Recall) und auf der x-Achse die falsch-positive Rate (1 - Spezifität). Je näher die Fläche unter dieser Kurve bei 1 liegt, desto besser. Unser Modell erreicht hier einen Wert von 0.84, was respektabel ist.
```{r}
# Konfusionsmatrix für Wiederholungsstichproben
log_reg_resamples %>%
conf_mat_resampled()
```
Anhand der **Konfusionsmatrix** können wir die Ungenauigkeit des Modells deutlich erkennen. Das vorliegende Modell kann die «Nein»-Vorhersage gut handhaben, aber es ist im Grunde ein Münzwurf, wenn es um die «Ja»-Vorhersage geht. Eine Lösung dieses Problems würde die Gesamtleistung des Modells erheblich verbessern.
Lassen Sie uns die Berechnungen mit einer 10-fachen Kreuzvalidierung durchführen. Wir sehen ziemlich ähnliche Zahlen wie bei der verwendeten Bootstrapping-Methode.
```{r}
# Kreuzvalidierung erstellen; Paket "rsample"
telco_folds <-
vfold_cv(data = telco_train, v = 10) # Alternative: v = 3 oder v = 5, strata = Churn
```
```{r}
keep_pred <- control_resamples(save_pred = TRUE,
save_workflow = TRUE)
```
```{r}
cv_log_reg_res <-
fit_resamples(object = log_reg_workflow,
resamples = telco_folds,
metrics = my_metrics,
control = keep_pred)
```
```{r}
cv_log_reg_res %>%
collect_metrics()
```
```{r}
cv_log_reg_res %>%
conf_mat_resampled()
```
Das logistische Modell konnte nicht soweit optimiert werden, wie es notwendig wäre, um einen relevanten Unterschied zum Nullmodell zu erreichen. Es ist daher an der Zeit, ein Modell auszuprobieren, welches das Tuning eigenständig erlernen kann, wie z.B. **Random Forest**.
## Random Forest
Ein grosser Unterschied zum logistischen Regressionsmodell besteht in der Vorverarbeitung. Random Forest benötigt hierfür nur sehr wenig. Ein Punkt, den man dafür im Auge behalten muss, sind fehlende Werte. Es gibt einige Strategien, um damit umzugehen, wie z.B. das Entfernen oder die statistische Imputation. Im vorliegenden Fall wird Letzteres umgesetzt und der Mittelwert der Spalte für alle fehlenden Werte verwendet. Dies ist eine gängige Strategie für den Umgang mit fehlenden Werten beim Trainieren eines Modells. Die fehlenden Werte stammten alle aus der Spalte TotalCharges.
Für Random Forest gibt es wiederum ein Rezept, ein Modell und einen Arbeitsablauf, gleich wie beim logistischen Regressionsmodell.
```{r}
# Vorverarbeitungsrezept erstellen; Paket "recipes"
set.seed(123)
rf_recipe <-
recipe(Churn ~ ., data = telco_train) %>%
update_role(customerID, new_role = "id") %>%
# Anzahl der Stichproben für «Ja» und «Nein» gleich setzen; Alternative: step_downsample()
step_upsample(Churn, over_ratio = 1) %>%
# Verwendung der statistischen Imputation zur Behandlung fehlender Werte
step_impute_mean(TotalCharges)
# Rezept vorbereiten, dass es für andere Daten verwendet werden kann
# %>% prep()
# Weitere mögliche Einstellungen:
# step_normalize(all_numeric_predictors()) %>% # Prädiktoren normalisieren
# step_corr(all_numeric_predictors()) %>% # Hoch korrelierte Variablen entfernen
# step_nzv(all_numeric_predictors()) # Numerische Variablen mit nahezu Varianz Null entfernen
```
```{r}
# Modellspezifikation festlegen; Pakete "parsnip" & "dials"
model_spec <-
rand_forest(mode = "classification", # Alternative: %>% set_mode("classification")
mtry = tune(),
trees = tune(),
min_n = tune()
) %>%
# Die Wichtigkeit auf Permutation setzen für spätere Wichtigkeitsdiagramme; Alt.: "impurity"
set_engine(engine = "ranger", importance = "permutation")
```
```{r}
# Tune-Parameter vorgängig auskommentieren
# model_spec %>%
# fit(Churn ~ ., data = telco_train)
```
```{r}
# Raster erstellen
tree_grid <-
grid_regular(mtry() %>% range_set(c(1, 10)),
trees() %>% range_set(c(50, 300)),
min_n() %>% range_set(c(2, 10)),
levels = 3)
```
```{r}
# Workflow mit Rezept und Modell erstellen; Paket "workflows"
tuned_wf <-
workflow() %>%
add_model(spec = model_spec) %>%
add_recipe(recipe = rf_recipe)
```
In den obigen Codeabschnitten werden mehrere Schritte durchlaufen. In `model_spec` wird ein Random Forest-Modell mit drei Parametern erstellt, die alle auf `tune()` gesetzt sind. Dieser Wert ist ein Platzhalter, mit dem wir uns weiter unten in `tree_grid` beschäftigen. Grundsätzlich geht es darum, eine Reihe von Werten bereitzustellen, aus denen `grid_regular()` auswählt. Die Anzahl der auszuwählenden Werte wird durch `levels` festgelegt. Dies ermöglicht uns, auf einfache Weise eine Vielzahl möglicher Hyperparameterwerte für den von uns gewählten Modelltyp zu testen.
Im vorliegenden Beispiel erhält mtry die Werte 1, 5 und 10. Es hat 3 Werte generiert, weil wir `levels = 3` angegeben haben. Wenn wir dies später an eine 10-fache Kreuzvalidierung weitergeben, werden die Modelle mit jedem dieser Werte angepasst und wir können das beste Modell aus dem Stapel auswählen. Dies gilt für jeden der Parameterbereiche und es ist daher darauf zu achten, dass die Rechenzeit nicht aus dem Ruder läuft. Mit Hilfe einer einfachen Kombinatorik erzeugt das vorhandene Raster 3^3^ = 27 Zeilen, wenn man die `levels` auf 5 erhöht, kommt man auf 5^3^ = 125 Zeilen. Für die Kreuzvalidierung bedeutet dies eine Vervierfachung der Berechnungszeit. Seien Sie sich daher immer über den Umfang der Berechnungen im Klaren, die Sie dem Computer zumuten.
```{r}
# Modell abstimmen; Paket "tune"
tree_results <-
tune_grid(object = tuned_wf, # Workflow, definiert mit den Paketen "parsnip" & "workflows"
resamples = telco_folds, # Definiert mit den Paketen "rsample" & "recipes"
grid = tree_grid, # Gitterraum, definiert mit dem Paket "dials"; Alt.: grid = 10
# Yardstick-Paket zur Definition der Metriken für die Bewertung der Modellleistung
metrics = metric_set(accuracy, f_meas, roc_auc),
# Zusätzlich: sensitivity, precision, recall mit Paket "caret"
control = control_resamples(save_pred = TRUE))
# Alternative: control_grid(verbose = FALSE/TRUE)
```
Da wir unsere Ergebnisse durch Kreuzvalidierung überprüft haben, können wir die Metriken vergleichen. Wir wählen nun das beste Modell aus. Daher betrachten wir die Leistung basierend auf dem Wert von `mtry` für die verschiedenen gewählten Werte der Bäume.
```{r}
tree_results %>%
collect_metrics() %>%
mutate(trees = as.factor(trees)) %>%
ggplot(aes(x = mtry, y = mean, color = trees)) +
geom_line(linewidth = 1.5, alpha = 0.6) +
geom_point(size = 2) +
facet_wrap(~ .metric, nrow = 2, scales = "free") +
scale_color_viridis_d(option = "plasma", begin = 0.9, end = 0)
```
Es scheint, dass ein `mtry` von 5 ideal ist, um die verschiedenen Metriken auszugleichen. Es ist wichtig, hier die freie Skala zu beachten. Der Wert von `roc_auc` scheint zu sinken, bewegt sich aber kaum mehr als 0.01. `specificity` hingegen sinkt wirklich, je mehr `mtry` steigt. Mit einem `mtry` von 5 maximieren wir unseren `f_meas` (F1-Score) und die Richtigkeit, während wir für `roc_auc` einen schönen Mittelwert erzielen. Die einzige Möglichkeit, die `specificity` zu maximieren, besteht darin, `mtry` auf 1 zu setzen, was eine schlechte Idee zu sein scheint. `mtry` gibt an, wie viele Prädikatorspalten abgefragt werden. Bei einem Datensatz mit 19 Prädikatoren kommt die Abfrage von nur einem Schätzwert einer Abfrage irreführender Informationen gleich. Es ist bemerkenswert, dass die verschiedenen Baumwerte in ihren Metriken wenig variieren, wobei 50 der schlechteste der drei ist. 175 scheint eine solide Wahl zu sein.
Auf der Grundlage einer beliebigen Metrik kann das beste Modell ausgewählt werden. Der Entscheid fällt auf den `f_meas`, weil dieser eine solide Metrik zur Bewertung der Modell-Leistung ist.
```{r}
tree_results %>%
show_best(metric = "f_meas",
n = 10)
```
```{r}
best_tree <-
tree_results %>%
select_best(metric = "f_meas") # Alternative: "accuracy"
```
```{r}
final_wf <-
tuned_wf %>%
finalize_workflow(best_tree) # Alternative: model %>% finalize_model()
```
Bevor wir fortfahren, sollten wir uns die Metriken ansehen, die für den besten Tree (Baum) vorliegen. Dies ist nützlich für den Vergleich mit der Leistung des Test-Datensatzes. So können wir kontrollieren, ob die Daten zu gut oder zu schlecht angepasst sind.
```{r}
tree_results %>%
collect_metrics() %>%
filter(mtry == 5, trees == 175, min_n == 6)
```
Wir kommen zum letzten Schritt, dem Testen unseres Modells auf dem Test-Datensatz. `last_fit()` verwendet das ursprüngliche Split-Objekt, passt das Modell an den Trainings-Split an und evaluiert es automatisch anhand des Test-Splits. Das ist sehr praktisch und wir können alle abschliessenden Schritte anhand der Ergebnisse durchführen. Zuerst sammeln wir jedoch zwei grundlegende Metriken.
```{r}
# Das Modell wird an den vollständigen Trainingssatz angepasst
# und das endgültige Modell anhand der Testdaten bewertet.
final_fit <-
final_wf %>%
last_fit(telco_split)
# Alternative:
# workflow() %>%
# add_model(spec = final_model) %>%
# add_recipe(recipe = recipe) %>%
# last_fit(data_split)
```
```{r}
# Gesammelte Metriken für die endgültige Anpassung anzeigen
final_fit %>%
collect_metrics()
# Alternative:
# preds <-
# collect_predictions(final_wf) %>%
# head()
```
Wenn wir uns die Metriken des besten Baums ansehen, sehen wir fast identische Metriken für `accuracy` und `roc_auc`. Das ist grossartig, denn es zeigt, dass unser Modell anscheinend gut zu den Daten passt. Wenn wir grosse Unterschiede sehen, kann das ein Zeichen dafür sein, dass unser Modell zu gut oder zu schlecht an die Trainingsdaten angepasst ist. Es wäre sicherlich sinnvoll, mehr Metriken zu vergleichen, um zu prüfen, ob es drastische Unterschiede gibt. Mit den vorhandenen, begrenzten Kennzahlen scheint unser Modell gut zu passen.
Es gibt drei weitere Dinge, die wir untersuchen können: ein Diagramm unserer ROC-Kurve, eine Konfusionsmatrix und ein Diagramm der Variablenbedeutung.
```{r}
# ROC-Kurve für die endgültige Anpassung erstellen
final_fit %>%
collect_predictions() %>%
roc_curve(Churn, .pred_No) %>%
autoplot()
```
Die Kurve zeigt die `sensitivity` und `specificity` für verschiedene Schwellenwerte. Wenn die Kurve nahe der diagonalen Linie verläuft oder wild hin und her springt, sind die Vorhersagen möglicherweise zufälliger als gewünscht. Eine ideale ROC-Kurve ist so nah wie möglich oben links. Unsere Kurve liegt relativ nahe an der oberen linken Ecke und weit von der diagonalen Linie entfernt, was darauf hindeutet, dass die Vorhersagen des Modells nicht ausschliesslich auf Zufall beruhen. Es ist auch eine relativ glatte Linie. Im Idealfall würde man zwei dieser Kurven benötigen, eine aus der Trainingsmenge und eine aus der Testmenge. Ein Vergleich von `roc_auc` ist hilfreich, denn die Werte können ähnlich sein, obwohl sie sehr unterschiedliche Kurven haben. Es ist vorteilhaft, diese Unterschiede zu vergleichen.
```{r}
# Konfusionsmatrix
final_pred <-
final_fit %>%
unnest(.predictions)
confusionMatrix(final_pred$.pred_class, telco_test$Churn)
# Alternative: table(final_pred$.pred_class, telco_test$Churn)
```
```{r}
# Globales Variablen-Bedeutungsdiagramm erstellen
final_fit %>%
extract_workflow() %>%
extract_fit_parsnip() %>%
vip() +
ggtitle("Globale Bedeutung der Variablen") +
ylab("Bedeutung")
# Alternative:
# final_fit %>%
# extract_fit_parsnip() %>%
# vip() +
# ggtitle("Feature Importance") +
# theme(title = element_text(size = 20),
# axis.text.y.left = element_text(size = 20))
```
Dem Diagramm können wir entnehmen, dass - zumindest nach der Art und Weise, wie wir unser Modell aufgebaut haben - die Vertragslaufzeit, die Gesamtkosten und die monatlichen Kosten die wichtigsten Prädiktoren für die Kundenabwanderung sind. Wir haben diese Faktoren bereits in unseren Diagrammen weiter oben im Bericht gesehen. Daher ist es nicht überraschend, dass sie sich als gute Prädiktoren erweisen.
```{r}
final_pred %>%
select(.pred_Yes,
.pred_No,
.pred_class) %>%
slice(1:4) %>%
rename("P Yes" = .pred_Yes,
"P No" = .pred_No,
"FCST Klasse" = .pred_class)
```
```{r}
model <-
fit(final_wf,
telco_train)
```
```{r}
pred_test <-
predict(object = model,
new_data = telco_test %>% select(-Churn)) %>%
bind_cols(true = telco_test$Churn)
```
```{r}
pred_test %>%
summarise(accuracy = accuracy_vec(true, .pred_class),
sensitivity = sens_vec(true, .pred_class),
precision = precision_vec(true, .pred_class),
specificity = spec_vec(true, .pred_class))
```
```{r}
# Lokales, modell-agnostisches Variablen-Bedeutungsdiagramm erstellen
telco_train_num <- telco_train
telco_train_num$Churn <-
recode_factor(telco_train$Churn,
Yes = 1,
No = 0)
telco_train_num$Churn <- as.numeric(as.character(telco_train_num$Churn))
```
```{r}
# Explainer erstellen
explainer <-
explain_tidymodels(
model = model,
data = telco_train_num %>% select(-Churn),
y = telco_train_num %>% pull(Churn),
label = "rf"
)
```
```{r}
model_type.dalex_explainer <- DALEXtra::model_type.dalex_explainer
predict_model.dalex_explainer <- DALEXtra::predict_model.dalex_explainer
explanation <-
predict_surrogate(explainer = explainer,
new_observation = telco_test[1:4, ] %>%
select(-Churn),
n_features = 10,
n_permutations = 1000,
type = "lime")
```
```{r}
explanation %>%
select(model_r2, model_prediction, prediction) %>%
distinct()
```
```{r}
plot(explanation) +
ggtitle("Bedeutung der Variablen") +
xlab("Merkmale") +
ylab("Gewichtung")
```
## Quarto
Quarto ermöglicht es, Inhalte und ausführbaren Code in einem Dokument zu kombinieren. Mehr über Quarto erfahren Sie unter <https://quarto.org>.