-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmy_confusionmat.m
395 lines (331 loc) · 13.4 KB
/
my_confusionmat.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
function [cm,gn] = my_confusionmat(g,ghat,varargin)
% MY_CONFUSIONMAT Confusion matrix for classification algorithms.
%   CM = MY_CONFUSIONMAT(G,GHAT) returns the confusion matrix CM determined
%   by the known group labels G and the predicted group labels GHAT. G and
%   GHAT are grouping variables with the same number of observations. G
%   and GHAT can be categorical, numeric, or logical vectors;
%   single-column cell arrays of character vectors; single-column string
%   arrays; or character matrices (each row representing a group label).
%   G and GHAT must be of the same type (a cellstr/string/char GHAT is
%   first converted to G's type). CM is a square matrix with size equal to
%   the total number of distinct elements in G and GHAT. CM(I,J)
%   represents the count of instances whose known group labels are group I
%   and whose predicted group labels are group J. MY_CONFUSIONMAT treats
%   NaNs, empty strings or 'undefined' values in G or GHAT as missing
%   values, and the corresponding observations are not counted. If inputs
%   are character matrices or cell arrays of character vectors,
%   MY_CONFUSIONMAT trims them with strtrim().
%
%   The sets of groups and the orders of group labels in rows and
%   columns of CM are the same. They include all the groups appearing in
%   GN, and have the same order of group labels as GN, where GN is the
%   second output of grp2idx([G;GHAT]).
%   NOTE: in this copy, categorical categories that are defined but never
%   observed in G or GHAT are NOT given a row/column (the expansion in
%   iFindGroups is disabled) unless they are listed in ORDER.
%
%   CM = MY_CONFUSIONMAT(G,GHAT,'ORDER',ORDER) returns the confusion matrix
%   with the order of rows (and columns) specified by ORDER. ORDER is a
%   vector containing group labels and whose values can be compared to
%   those in G or GHAT using the equality operator. ORDER must contain all
%   the labels appearing in G or GHAT. ORDER can contain labels which do
%   not appear in G and GHAT, and hence CM will have zeros in the
%   corresponding rows and columns. If ORDER is a character matrix or a
%   cell array of character vectors, MY_CONFUSIONMAT trims it with
%   strtrim().
%
%   [CM, GORDER] = MY_CONFUSIONMAT(G, GHAT) returns the order of group
%   labels for rows and columns of CM. GORDER has the same type as G and
%   GHAT.
%
%   Example:
%      % Compute the resubstitution confusion matrix for applying CLASSIFY
%      % on Fisher iris data.
%      load fisheriris
%      x = meas;
%      y = species;
%      yhat = classify(x,x,y);
%      [cm,order] = my_confusionmat(y,yhat);
%
%   See also CONFUSIONCHART, CROSSTAB, GRP2IDX.

%   Copyright 2008-2019 The MathWorks, Inc.

if nargin < 2
    iThrowError('stats:confusionmat:NotEnoughInputs');
end

% Convert the Name 'Order' to char.
if nargin > 2
    [varargin{1}] = convertStringsToChars(varargin{1});
end

% Convert ghat to g's class if necessary
if ~isempty(g) && ~isempty(ghat) && ~strcmp(class(g),class(ghat))
    if iscellstr(g) && (isstring(ghat) || ischar(ghat)) && ismatrix(ghat)
        ghat = cellstr(ghat);
    elseif isstring(g) && (iscellstr(ghat) || ischar(ghat))
        ghat = string(ghat);
    elseif ischar(g) && (iscellstr(ghat) || isstring(ghat))
        ghat = char(ghat);
    end
end

% Remember the caller's type: char input is converted to cellstr below,
% and GN is converted back to char before returning.
gClass = class(g);
if ~strcmp(gClass,class(ghat))
    iThrowError('stats:confusionmat:GTypeMismatch');
end
if ~isnumeric(g) && ~islogical(g) && ~isa(g,'categorical') ...
        && ~iscellstr(g) && ~ischar(g) && ~isstring(g)
    iThrowError('stats:confusionmat:GTypeIncorrect');
end

% Normalise both label sets to column vectors (char rows -> cellstr).
if ischar(g)
    if ~ismatrix(g) || ~ismatrix(ghat)
        iThrowError('stats:confusionmat:BadGroup');
    end
    g = cellstr(g);
    ghat = cellstr(ghat);
elseif ~isvector(g) || ~isvector(ghat)
    iThrowError('stats:confusionmat:BadGroup');
else
    g = g(:);
    ghat = ghat(:);
end
if iscellstr(g)
    g = strtrim(g);
    ghat = strtrim(ghat);
end
if size(g,1) ~= size(ghat,1)
    iThrowError('stats:confusionmat:GRowNumMismatch');
end

% Ordinal categoricals must agree on ordinality and on the category list.
if isa(g,'categorical')
    if isordinal(g)
        if ~isordinal(ghat)
            iThrowError('stats:confusionmat:GOrdinalLevelsMismatch');
        elseif ~isequal(categories(g),categories(ghat))
            iThrowError('stats:confusionmat:GOrdinalLevelsMismatch');
        end
    elseif isordinal(ghat)
        iThrowError('stats:confusionmat:GOrdinalLevelsMismatch');
    end
end

order = iParseOrderNameValuePair(varargin{:});

% Convert order to g's class if necessary. g can no longer be char at this
% point (char input was converted to cellstr above), so only the cellstr
% and string targets are possible.
if ~isempty(order) && ~strcmp(class(g),class(order))
    if iscellstr(g) && (isstring(order) || ischar(order)) && ismatrix(order)
        order = cellstr(order);
    elseif isstring(g) && (iscellstr(order) || ischar(order))
        order = string(order);
    end
end

if ~isempty(order)
    if ischar(order)
        if ~ismatrix(order)
            iThrowError('stats:confusionmat:NDCharArrayORDER');
        end
        order = cellstr(order);
    elseif ~isvector(order)
        iThrowError('stats:confusionmat:NonVectorORDER');
    end

    if isa(g,'categorical')
        if iscellstr(order)
            if any(strcmp('',strtrim(order)))
                iThrowError('stats:confusionmat:OrderHasEmptyString');
            end
        elseif iscategorical(order)
            if any(isundefined(order))
                iThrowError('stats:confusionmat:OrderHasUndefined');
            end
        else
            iThrowError('stats:confusionmat:TypeMismatchOrder');
        end
        % Give g and ghat exactly the categories in ORDER, in that order.
        g = setcatsLocal(g,order);
        ghat = setcatsLocal(ghat,order);
    else % g is not categorical vector
        if isnumeric(g)
            if ~isnumeric(order)
                iThrowError('stats:confusionmat:TypeMismatchOrder');
            end
            if any(isnan(order))
                iThrowError('stats:confusionmat:OrderHasNaN');
            end
        elseif islogical(g)
            if islogical(order)
                %OK. do nothing
            elseif isnumeric(order)
                if any(isnan(order))
                    iThrowError('stats:confusionmat:OrderHasNaN');
                end
                order = logical(order);
            else
                iThrowError('stats:confusionmat:TypeMismatchOrder');
            end
        elseif iscellstr(g)
            if ~iscellstr(order)
                iThrowError('stats:confusionmat:TypeMismatchLevels');
            end
            if any(strcmp('',strtrim(order)))
                iThrowError('stats:confusionmat:OrderHasEmptyString');
            end
            order = strtrim(order);
        end
        try
            uorder = unique(order);
        catch
            iThrowError('stats:confusionmat:UniqueMethodFailedOrder');
        end
        if length(uorder) < length(order)
            iThrowError('stats:confusionmat:DuplicatedOrder');
        end
        order = order(:);
    end
end

% Perform calculation
[cm, gn] = iCalculateConfusion(g, ghat);

if ~isempty(order)
    if isa(g,'categorical')
        % iCalculateConfusion only reports categories actually observed in
        % [g;ghat]. setcatsLocal has made categories(g) equal to ORDER (in
        % the requested sequence), so place the observed rows/columns into
        % that full list. Labels in ORDER that never occur then get all-zero
        % rows and columns, as documented above.
        allCats = categories(g);
        if numel(gn) < numel(allCats)
            [~,map] = ismember(cellstr(gn), allCats);
            cmFull = zeros(numel(allCats));
            cmFull(map,map) = cm;
            cm = cmFull;
            gn = categorical(allCats, allCats, 'Ordinal', isordinal(g));
        end
    else
        %get the map from the default order to the given order
        [hasAllLabel,map] = ismember(gn,order);
        if ~all(hasAllLabel)
            iThrowError('stats:confusionmat:OrderInsufficientLabels');
        end
        orderLen = length(order);
        cm2 = zeros(orderLen, orderLen);
        cm2(map,map) = cm(:,:);
        cm = cm2;
        if nargout > 1
            %convert gn to the same type as g
            if strcmp(gClass,'char')
                gn = char(order);
            else
                gn = order;
            end
        end
    end
elseif strcmp(gClass,'char')
    gn = char(gn);
end
end
function b = setcatsLocal(a,newCategories)
% SETCATSLOCAL Give categorical A exactly the categories NEWCATEGORIES,
% in that order, without changing the values stored in A.
%   NEWCATEGORIES may be a cell array of character vectors or a categorical
%   array. A categorical NEWCATEGORIES must be of A's class and match A's
%   ordinality; when A is ordinal it must also carry exactly A's
%   categories. Otherwise stats:confusionmat:TypeMismatchOrder is thrown.
%   If A has a category that NEWCATEGORIES does not list,
%   stats:confusionmat:OrderInsufficientLabels is thrown.
if iscategorical(newCategories)
    if ~isa(newCategories,class(a))
        iThrowError('stats:confusionmat:TypeMismatchOrder');
    elseif isordinal(a)
        if ~isordinal(newCategories)
            iThrowError('stats:confusionmat:TypeMismatchOrder');
        elseif ~isequal(categories(a),categories(newCategories))
            iThrowError('stats:confusionmat:TypeMismatchOrder');
        end
    elseif isordinal(newCategories)
        iThrowError('stats:confusionmat:TypeMismatchOrder');
    end
    newCategories = cellstr(newCategories);
end
existingCategories = categories(a);
if ~isempty(setdiff(existingCategories,newCategories))
    iThrowError('stats:confusionmat:OrderInsufficientLabels');
end
% Add only the categories A does not already have. The previous version
% anchored the insertion at existingCategories{end}, which throws an index
% error when A has no categories at all (e.g. every element <undefined>).
% addcats appends at the end by default, and reordercats fixes the final
% order anyway.
missingCategories = setdiff(newCategories,existingCategories,'stable');
b = a;
if ~isempty(missingCategories)
    b = addcats(b,missingCategories);
end
b = reordercats(b,newCategories);
end
function order = iParseOrderNameValuePair(varargin)
% IPARSEORDERNAMEVALUEPAIR Extract the value of the optional 'Order' pair.
%   With no arguments, returns []. With exactly two arguments, checks that
%   the first partially matches 'Order' (via validatestring) and returns
%   the second. Any other argument count raises the "wrong number of
%   arguments" error, thrown as the caller so that it appears to come from
%   the confusion-matrix entry point rather than from this helper.
if nargin == 0
    order = [];
elseif nargin == 2
    validatestring(varargin{1}, {'Order'});
    order = varargin{2};
else
    throwAsCaller(iGetErrorWithWrongNumberOfArgs());
end
end
function [cm, classLabels] = iCalculateConfusion(g, ghat)
% ICALCULATECONFUSION Count (true, predicted) label pairs.
%   G and GHAT are the already-validated true and predicted label columns
%   with the same number of rows. CM(I,J) is the number of rows whose true
%   label is CLASSLABELS(I) and whose predicted label is CLASSLABELS(J).
%   Rows where either label is missing (NaN group index) are not counted.
%   CLASSLABELS has the same type as G and GHAT.
nObs = size(g,1);

% Index the stacked labels once, so that true and predicted labels share
% one numbering.
[stackedIdx, classLabels] = iFindGroups(g, ghat);
trueIdx = stackedIdx(1:nObs,:);
predIdx = stackedIdx(nObs+1:2*nObs,:);

% Drop any observation where either side is missing.
dropRows = isnan(trueIdx) | isnan(predIdx);
if any(dropRows)
    keepRows = ~dropRows;
    trueIdx = trueIdx(keepRows,:);
    predIdx = predIdx(keepRows,:);
end

nClasses = length(classLabels);
cm = accumarray([trueIdx, predIdx], 1, [nClasses, nClasses]);

% findgroups sorts text labels; text inputs are presented in order of
% first appearance instead.
if iscell(g) || isstring(g)
    [cm, classLabels] = reorderMatrixAndLabelsToFirstApperanceOrder(g, ghat, cm);
end
end
function [cm, classLabels] = reorderMatrixAndLabelsToFirstApperanceOrder(g, ghat, cm)
% REORDERMATRIXANDLABELSTOFIRSTAPPERANCEORDER Permute a confusion matrix
% built on sorted text labels into first-appearance order.
%   CM is assumed to be indexed by the sorted distinct non-missing labels of
%   [G;GHAT]. Returns CM with rows and columns permuted, and CLASSLABELS
%   listing those labels in the order each first occurs in [G;GHAT].
%   Missing labels are removed before ranking so that they take no slot.
observed = rmmissing([g; ghat]);
[sortedLabels, firstPos] = unique(observed, 'first');
[~, perm] = sort(firstPos);
perm = perm(:);
classLabels = sortedLabels(perm);
cm = cm(perm, perm);
end
function [idx,classLabels] = iFindGroups(g, ghat)
% IFINDGROUPS Group indices and group labels for the stacked labels [g;ghat].
%   [IDX,CLASSLABELS] = IFINDGROUPS(G,GHAT) returns findgroups([G;GHAT]).
%   IDX has one entry per row of [G;GHAT]: the first numel(G) entries belong
%   to G and the rest to GHAT. Missing labels get a NaN index, which the
%   caller filters out. CLASSLABELS lists the distinct labels in the order
%   findgroups returns them.
%
%   NOTE: categorical categories that are defined but never observed are
%   NOT included in CLASSLABELS. The expansion that would add them (the
%   commented-out block below) is disabled in this copy. The caller
%   (my_confusionmat) handles unobserved labels that are listed in ORDER
%   separately.
%
%   For cellstr/string labels, findgroups returns alphabetical order.
%   iCalculateConfusion reorders them to first-appearance order only after
%   the matrix is built, because missing labels must be dropped first.
[idx,classLabels] = findgroups([g;ghat]);

% DISABLED: expansion to the full list of underlying categorical categories,
% kept for reference.
% if iscategorical(g) && length(categories([g; ghat])) > length(classLabels)
%     % Converting to double is equivalent to finding the indices, including
%     % underlying categories that aren't observed.
%     idx = double([g; ghat]);
%
%     % Get the categories (which will appear in the correct order) as a
%     % categorical, preserving ordinality.
%     classLabels = categorical(categories([g; ghat]), categories([g; ghat]), 'Ordinal', isordinal(g));
% end
end
function iThrowError(errid,varargin)
% ITHROWERROR Throw, as the caller, an error identified as ERRID whose text
% comes from the matching "mlearnlib" message catalog entry.
%   ERRID is a "stats:confusionmat:..." identifier. The message text is read
%   from the ID obtained by replacing "stats" with "mlearnlib", with
%   VARARGIN forwarded as hole values. The thrown MException keeps the
%   original "stats:..." identifier, and throwAsCaller makes it appear to
%   come from the function that called iThrowError.
originalMsg = message(errid,varargin{:});
oldID = string(originalMsg.Identifier);
newID = oldID.replace("stats", "mlearnlib");
% Get the error message from the new catalog but throw an error with the
% old ID.
newMsg = message(newID, originalMsg.Arguments{:});
errorText = newMsg.getString();
% The catalog text is already fully formatted. Route it through '%s' so
% that MException does not reinterpret any '%' or '\' inside it as format
% specifiers or escape sequences.
throwAsCaller(MException(oldID, '%s', errorText));
end
function mException = iGetErrorWithWrongNumberOfArgs()
% IGETERRORWITHWRONGNUMBEROFARGS Build, but do not throw, the "wrong number
% of arguments" MException.
%   The returned exception has identifier
%   'stats:internal:parseArgs:WrongNumberArgs', while its text comes from the
%   'mlearnlib:confusionmat:WrongNumberArgs' catalog entry. It is returned
%   rather than thrown so that the calling subfunction can use
%   throwAsCaller and make the error appear to come from its own caller.
oldID = 'stats:internal:parseArgs:WrongNumberArgs';
newMsg = message('mlearnlib:confusionmat:WrongNumberArgs');
errorText = newMsg.getString();
% Route the already-formatted text through '%s' so that MException does not
% reinterpret any '%' or '\' inside it.
mException = MException(oldID, '%s', errorText);
end