-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathshapley.go
135 lines (116 loc) · 4.2 KB
/
shapley.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
package attribution
import (
"log"
"math/big"
"sort"
)
// GetTotalValue adds up the Value of every contribution and returns the sum.
// A nil or empty slice yields zero.
func GetTotalValue(contributions []ContributionSet) big.Float {
	// The zero value of big.Float is +0 and ready to use as an accumulator.
	var total big.Float
	for i := range contributions {
		total.Add(&total, &contributions[i].Value)
	}
	return total
}
// GetAllTouchpoints returns every distinct touchpoint that appears in any of
// the given contributions, each listed once and ordered by the sort.Interface
// implementation of Touchpoints. A nil or empty input yields a nil result.
func GetAllTouchpoints(contributions []ContributionSet) Touchpoints {
	seen := make(map[Touchpoint]struct{})
	var touchpoints Touchpoints
	// Index instead of ranging by value: ContributionSet embeds a big.Float,
	// so a value range would copy every element.
	for i := range contributions {
		// Only the keys of contribution.Touchpoints are used here.
		for touchpoint := range contributions[i].Touchpoints {
			if _, found := seen[touchpoint]; found {
				continue
			}
			seen[touchpoint] = struct{}{}
			touchpoints = append(touchpoints, touchpoint)
		}
	}
	sort.Sort(touchpoints)
	return touchpoints
}
// GetCoalitionValue returns the summed Value of every contribution whose
// touchpoints all belong to coalition.
//
// A contribution is counted only when the coalition covers all of its
// touchpoints. As a consequence, a contribution with no touchpoints is counted
// for every coalition, including the empty one.
func GetCoalitionValue(coalition map[Touchpoint]struct{}, allContributions []ContributionSet) big.Float {
	coalitionValue := new(big.Float)
nextContribution:
	for i := range allContributions {
		// Take a pointer rather than ranging by value to avoid copying the
		// embedded big.Float of each ContributionSet.
		contribution := &allContributions[i]
		for touchpoint := range contribution.Touchpoints {
			if _, ok := coalition[touchpoint]; !ok {
				// One touchpoint outside the coalition disqualifies the
				// whole contribution.
				continue nextContribution
			}
		}
		coalitionValue.Add(coalitionValue, &contribution.Value)
	}
	return *coalitionValue
}
// findTouchpoint performs a linear search for touchpoint in slice.
// On a match it returns the index of the first occurrence and true;
// otherwise it returns (-1, false).
func findTouchpoint(touchpoint Touchpoint, slice []Touchpoint) (int, bool) {
	for i := 0; i < len(slice); i++ {
		if slice[i] == touchpoint {
			return i, true
		}
	}
	return -1, false
}
// GetShapleyValue returns the exact Shapley value of touchpoint over all
// provided contributions:
//
//	phi(i) = sum over S subset of N\{i} of  |S|! * (n-|S|-1)! / n! * (v(S+{i}) - v(S))
//
// N is the set of distinct touchpoints returned by GetAllTouchpoints and
// n = |N|. v is GetCoalitionValue. Every subset of the other n-1 touchpoints
// is enumerated, so the running time grows as 2^(n-1).
//
// If touchpoint does not occur in any contribution, the program is terminated
// via log.Fatal.
//
// For a concise introduction to Shapley values, see https://christophm.github.io/interpretable-ml-book/shapley.html
func GetShapleyValue(touchpoint Touchpoint, allContributions []ContributionSet) big.Float {
	allTouchpoints := GetAllTouchpoints(allContributions)
	touchpointIndex, found := findTouchpoint(touchpoint, allTouchpoints)
	if !found {
		// NOTE(review): log.Fatal exits the whole process. It is kept because
		// the signature has no error return to report this through.
		log.Fatal("Illegal touchpoint!")
	}
	n := len(allTouchpoints)
	last := n - 1

	// Move touchpoint into the last slot. The powerset over indices 0..n-2
	// then ranges over exactly the subsets of the other touchpoints.
	// allTouchpoints is freshly built by GetAllTouchpoints, so reordering is safe.
	allTouchpoints[touchpointIndex] = allTouchpoints[last]
	allTouchpoints[last] = touchpoint

	// The weight of a coalition depends only on its size k:
	// weights[k] = k! * (n-k-1)! / n!. Computing every weight once up front
	// avoids recomputing factorials for each of the 2^(n-1) subsets.
	// The arithmetic is the same as evaluating the weight per subset.
	factorial := make([]*big.Int, n+1)
	factorial[0] = big.NewInt(1)
	for k := 1; k <= n; k++ {
		factorial[k] = new(big.Int).Mul(factorial[k-1], big.NewInt(int64(k)))
	}
	denominator := new(big.Float).SetInt(factorial[n])
	weights := make([]*big.Float, n)
	for k := range weights {
		numerator := new(big.Int).Mul(factorial[k], factorial[n-k-1])
		weights[k] = new(big.Float).Quo(new(big.Float).SetInt(numerator), denominator)
	}

	shapleyValue := new(big.Float)
	for _, subset := range getPowerSetIndices(uint(last)) {
		coalition := make(map[Touchpoint]struct{}, len(subset)+1)
		for _, index := range subset {
			coalition[allTouchpoints[index]] = struct{}{}
		}
		// The subset never contains the last index, so touchpoint can only
		// be present if the invariants above were broken.
		if _, ok := coalition[touchpoint]; ok {
			log.Fatal("This should never happen!")
		}
		without := GetCoalitionValue(coalition, allContributions)
		coalition[touchpoint] = struct{}{}
		marginal := GetCoalitionValue(coalition, allContributions)
		marginal.Sub(&marginal, &without)
		shapleyValue.Add(shapleyValue, new(big.Float).Mul(weights[len(subset)], &marginal))
	}
	return *shapleyValue
}
// getPowerSetIndices returns every subset of {0, 1, ..., size-1}. Callers use
// the subsets as index sets to enumerate the powerset of an arbitrary slice.
//
// Subsets are produced in binary-counter order: subset number m contains j
// exactly when bit j of m is set, and the elements of each subset are
// ascending. The result has 2^size entries. The empty subset comes first: it
// is nil for size > 0, and an empty non-nil slice when size is 0.
//
// It panics when 2^size cannot be represented as an int. Such a powerset could
// never be materialized anyway. An unguarded 1<<size would overflow there and
// either panic inside make or silently produce no subsets.
func getPowerSetIndices(size uint) [][]uint {
	if size == 0 {
		return [][]uint{{}}
	}
	// intBits is the bit width of int/uint on this platform (64 or 32).
	// 1<<size stays positive only while size <= intBits-2.
	const intBits = 32 << (^uint(0) >> 63)
	if size > intBits-2 {
		panic("attribution: powerset too large to enumerate")
	}
	total := 1 << size
	powerset := make([][]uint, 0, total)
	for mask := 0; mask < total; mask++ {
		var subset []uint
		for bit := uint(0); bit < size; bit++ {
			if mask&(1<<bit) != 0 {
				subset = append(subset, bit)
			}
		}
		powerset = append(powerset, subset)
	}
	return powerset
}