1
2
3
4
5
6
7
8
9 """Utility class to compute the transfer error of classifiers."""
10
11 __docformat__ = 'restructuredtext'
12
13 import mvpa.support.copy as copy
14
15 import numpy as N
16
17 from sets import Set
18 from StringIO import StringIO
19 from math import log10, ceil
20
21 from mvpa.base import externals
22
23 from mvpa.misc.errorfx import meanPowerFx, rootMeanPowerFx, RMSErrorFx, \
24 CorrErrorFx, CorrErrorPFx, RelativeRMSErrorFx, MeanMismatchErrorFx, \
25 AUCErrorFx
26 from mvpa.base import warning
27 from mvpa.misc.state import StateVariable, Stateful
28 from mvpa.base.dochelpers import enhancedDocString, table2string
29 from mvpa.clfs.stats import autoNullDist
30
31 if __debug__:
32 from mvpa.base import debug
33
34 if externals.exists('scipy'):
35 from scipy.stats.stats import nanmean
36 else:
37 from mvpa.clfs.stats import nanmean
38
40 """Helper to print depending on the type nicely. For some
41 reason %.2g for 100 prints exponential form which is ugly
42 """
43 if isinstance(x, int):
44 return "%d" % x
45 elif isinstance(x, float):
46 s = ("%%.%df" % prec % x).rstrip('0').rstrip('.').lstrip()
47 if s == '':
48 s = '0'
49 return s
50 else:
51 return "%s" % x
52
53
54
56 """Basic class to collect targets/predictions and report summary statistics
57
58 It takes care about collecting the sets, which are just tuples
59 (targets, predictions, values). While 'computing' the matrix, all
60 sets are considered together. Children of the class are
61 responsible for computation and display.
62 """
63
64 _STATS_DESCRIPTION = (
65 ('# of sets',
66 'number of target/prediction sets which were provided',
67 None), )
68
69
def __init__(self, targets=None, predictions=None, values=None, sets=None):
    """Initialize SummaryStatistics

    targets or predictions cannot be provided alone (ie targets
    without predictions)

    :Parameters:
      targets
        Optional set of targets
      predictions
        Optional set of predictions
      values
        Optional set of values (which served for prediction)
      sets
        Optional list of sets
    """
    # flag: statistics not yet computed for the currently stored data
    self._computed = False

    # collected (targets, predictions, values) tuples the statistics
    # are computed over
    self.__sets = [] if sets is None else sets

    have_targets = targets is not None
    have_predictions = predictions is not None
    if have_targets != have_predictions:
        # exactly one of the two was given -- refuse
        raise ValueError(
            "Please provide none or both targets and predictions")
    if have_targets:
        self.add(targets=targets, predictions=predictions, values=values)
99
100
def add(self, targets, predictions, values=None):
    """Add new results to the set of known results"""
    n_targets = len(targets)
    if n_targets != len(predictions):
        raise ValueError(
            "Targets[%d] and predictions[%d]" % (n_targets,
                                                 len(predictions))
            + " have different number of samples")

    if values is not None and n_targets != len(values):
        raise ValueError(
            "Targets[%d] and values[%d]" % (n_targets,
                                            len(values))
            + " have different number of samples")

    # Coerce each prediction into the type of the corresponding target
    # so later comparisons are not fooled by mere type mismatches.
    # None predictions are left untouched.
    nonetype = type(None)
    for i, target in enumerate(targets):
        ttype, ptype = type(target), type(predictions[i])
        if ttype != ptype and ptype != nonetype:
            if isinstance(predictions, tuple):
                # tuples are immutable -- switch to a list copy first
                predictions = list(predictions)
            predictions[i] = ttype(predictions[i])

    self.__sets.append((targets, predictions, values))
    self._computed = False
132
133
def asstring(self, short=False, header=True, summary=True,
             description=False):
    """'Pretty print' the matrix

    :Parameters:
      short : bool
        if True, ignores the rest of the parameters and provides consise
        1 line summary
      header : bool
        print header of the table
      summary : bool
        print summary (accuracy)
      description : bool
        print verbose description of presented statistics
    """
    # Abstract: the concrete rendering is provided by subclasses
    # (e.g. ConfusionMatrix.asstring, RegressionStatistics.asstring).
    raise NotImplementedError
150
151
153 """String summary over the `SummaryStatistics`
154
155 It would print description of the summary statistics if 'CM'
156 debug target is active
157 """
158 if __debug__:
159 description = ('CM' in debug.active)
160 else:
161 description = False
162 return self.asstring(short=False, header=True, summary=True,
163 description=description)
164
165
167 """Add the sets from `other` s `SummaryStatistics` to current one
168 """
169
170
171
172 othersets = copy.copy(other.__sets)
173 for set in othersets:
174 self.add(*set)
175 return self
176
177
179 """Add two `SummaryStatistics`s
180 """
181 result = copy.copy(self)
182 result += other
183 return result
184
185
187 """Actually compute the confusion matrix based on all the sets"""
188 if self._computed:
189 return
190
191 self._compute()
192 self._computed = True
193
194
196 self._stats = {'# of sets' : len(self.sets)}
197
198
199 @property
201 """Return a list of separate summaries per each stored set"""
202 return [ self.__class__(sets=[x]) for x in self.sets ]
203
204
205 @property
207 raise NotImplementedError
208
209
210 @property
212 self.compute()
213 return self._stats
214
215
217 """Cleans summary -- all data/sets are wiped out
218 """
219 self.__sets = []
220 self._computed = False
221
222
223 sets = property(lambda self:self.__sets)
224
225
227 """Generic class for ROC curve computation and plotting
228 """
229
231 """
232 :Parameters:
233 labels : list
234 labels which were used (in order of values if multiclass,
235 or 1 per class for binary problems (e.g. in SMLR))
236 sets : list of tuples
237 list of sets for the analysis
238 """
239 self._labels = labels
240 self._sets = sets
241 self.__computed = False
242
243
245 """Lazy computation if needed
246 """
247 if self.__computed:
248 return
249
250 labels = self._labels
251 Nlabels = len(labels)
252 sets = self._sets
253
254
255 def _checkValues(set_):
256 """Check if values are 'acceptable'"""
257 if len(set_)<3: return False
258 x = set_[2]
259
260 if (x is None) or len(x) == 0: return False
261 for v in x:
262 try:
263 if Nlabels <= 2 and N.isscalar(v):
264 continue
265 if (isinstance(v, dict) or
266 ((Nlabels>=2) and len(v)!=Nlabels)
267 ): return False
268 except Exception, e:
269
270
271
272 if __debug__:
273 debug('ROC', "Exception %s while checking "
274 "either %s are valid labels" % (str(e), x))
275 return False
276 return True
277
278 sets_wv = filter(_checkValues, sets)
279
280 Nsets_wv = len(sets_wv)
281 if Nsets_wv > 0 and len(sets) != Nsets_wv:
282 warning("Only %d sets have values assigned from %d sets" %
283 (Nsets_wv, len(sets)))
284
285
286
287
288 for iset,s in enumerate(sets_wv):
289
290 values = s[2]
291
292 if isinstance(values, N.ndarray) and len(values.shape)==1:
293 values = list(values)
294 rangev = None
295 for i in xrange(len(values)):
296 v = values[i]
297 if N.isscalar(v):
298 if Nlabels == 2:
299 def last_el(x):
300 """Helper function. Returns x if x is scalar, and
301 last element if x is not (ie list/tuple)"""
302 if N.isscalar(x): return x
303 else: return x[-1]
304 if rangev is None:
305
306
307 values_ = [last_el(x) for x in values]
308 rangev = N.min(values_) + N.max(values_)
309 values[i] = [rangev - v, v]
310 else:
311 raise ValueError, \
312 "Cannot have a single 'value' for multiclass" \
313 " classification. Got %s" % (v)
314 elif len(v) != Nlabels:
315 raise ValueError, \
316 "Got %d values whenever there is %d labels" % \
317 (len(v), Nlabels)
318
319 sets_wv[iset] = (s[0], s[1], N.asarray(values))
320
321
322
323
324
325 ROCs, aucs = [], []
326 for i,label in enumerate(labels):
327 aucs_pl = []
328 ROCs_pl = []
329 for s in sets_wv:
330 targets_pl = (s[0] == label).astype(int)
331
332 ROC = AUCErrorFx()
333 aucs_pl += [ROC([x[i] for x in s[2]], targets_pl)]
334 ROCs_pl.append(ROC)
335 if len(aucs_pl)>0:
336 ROCs += [ROCs_pl]
337 aucs += [nanmean(aucs_pl)]
338
339
340
341 self._ROCs = ROCs
342 self._aucs = aucs
343 self.__computed = True
344
345
346 @property
348 """Compute and return set of AUC values 1 per label
349 """
350 self._compute()
351 return self._aucs
352
353
354 @property
356 self._compute()
357 return self._ROCs
358
359
def plot(self, label_index=0):
    """Plot the ROC curve(s) for label `label_index` into the current figure.

    One curve is drawn per stored set that had values, together with the
    chance-level diagonal.  Requires the 'pylab' external.

    TODO: make it friendly to labels given by values?
          should we also treat labels_map?
    """
    externals.exists("pylab", raiseException=True)
    import pylab as P

    # lazily compute ROCs/AUCs if not done yet
    self._compute()

    labels = self._labels
    # select the per-set ROCs computed for the requested label
    ROCs = self.ROCs[label_index]

    # NOTE(review): labels, fig and ax are currently unused below --
    # presumably kept for future customization
    fig = P.gcf()
    ax = P.gca()

    # chance-level diagonal
    P.plot([0, 1], [0, 1], 'k:')

    for ROC in ROCs:
        P.plot(ROC.fp, ROC.tp, linewidth=1)

    P.axis((0.0, 1.0, 0.0, 1.0))
    P.axis('scaled')
    P.title('Label %s. Mean AUC=%.2f' % (label_index, self.aucs[label_index]))

    P.xlabel('False positive rate')
    P.ylabel('True positive rate')
389
390
392 """Class to contain information and display confusion matrix.
393
394 Implementation of the `SummaryStatistics` in the case of
395 classification problem. Actual computation of confusion matrix is
396 delayed until all data is acquired (to figure out complete set of
397 labels). If testing data doesn't have a complete set of labels,
398 but you like to include all labels, provide them as a parameter to
399 the constructor.
400
401 Confusion matrix provides a set of performance statistics (use
402 asstring(description=True) for the description of abbreviations),
403 as well ROC curve (http://en.wikipedia.org/wiki/ROC_curve)
404 plotting and analysis (AUC) in the limited set of problems:
405 binary, multiclass 1-vs-all.
406 """
407
408 _STATS_DESCRIPTION = (
409 ('TP', 'true positive (AKA hit)', None),
410 ('TN', 'true negative (AKA correct rejection)', None),
411 ('FP', 'false positive (AKA false alarm, Type I error)', None),
412 ('FN', 'false negative (AKA miss, Type II error)', None),
413 ('TPR', 'true positive rate (AKA hit rate, recall, sensitivity)',
414 'TPR = TP / P = TP / (TP + FN)'),
415 ('FPR', 'false positive rate (AKA false alarm rate, fall-out)',
416 'FPR = FP / N = FP / (FP + TN)'),
417 ('ACC', 'accuracy', 'ACC = (TP + TN) / (P + N)'),
418 ('SPC', 'specificity', 'SPC = TN / (FP + TN) = 1 - FPR'),
419 ('PPV', 'positive predictive value (AKA precision)',
420 'PPV = TP / (TP + FP)'),
421 ('NPV', 'negative predictive value', 'NPV = TN / (TN + FN)'),
422 ('FDR', 'false discovery rate', 'FDR = FP / (FP + TP)'),
423 ('MCC', "Matthews Correlation Coefficient",
424 "MCC = (TP*TN - FP*FN)/sqrt(P N P' N')"),
425 ('AUC', "Area under (AUC) curve", None),
426 ) + SummaryStatistics._STATS_DESCRIPTION
427
428
def __init__(self, labels=None, labels_map=None, **kwargs):
    """Initialize ConfusionMatrix with optional list of `labels`

    :Parameters:
      labels : list
        Optional set of labels to include in the matrix
      labels_map : None or dict
        Dictionary from original dataset to show mapping into
        numerical labels
      targets
        Optional set of targets
      predictions
        Optional set of predictions
    """

    SummaryStatistics.__init__(self, **kwargs)

    # BUG FIX: use identity test instead of 'labels == None' -- equality
    # would trigger elementwise comparison for array-like labels
    if labels is None:
        labels = []
    # list of known labels
    self.__labels = labels
    # mapping from original dataset labels into the numerical ones
    self.__labels_map = labels_map
    # resultant confusion matrix (computed lazily)
    self.__matrix = None
454
455
456
457
458 @property
464
465
467 """Actually compute the confusion matrix based on all the sets"""
468
469 super(ConfusionMatrix, self)._compute()
470
471 if __debug__:
472 if not self.__matrix is None:
473 debug("LAZY",
474 "Have to recompute %s#%s" \
475 % (self.__class__.__name__, id(self)))
476
477
478
479
480 try:
481
482 labels = \
483 list(reduce(lambda x, y: x.union(Set(y[0]).union(Set(y[1]))),
484 self.sets,
485 Set(self.__labels)))
486 except:
487 labels = self.__labels
488
489
490 labels_map = self.__labels_map
491 if labels_map is not None:
492 labels_set = Set(labels)
493 map_labels_set = Set(labels_map.values())
494
495 if not map_labels_set.issuperset(labels_set):
496 warning("Provided labels_map %s is not coherent with labels "
497 "provided to ConfusionMatrix. No reverse mapping "
498 "will be provided" % labels_map)
499 labels_map = None
500
501
502 labels_map_rev = None
503 if labels_map is not None:
504 labels_map_rev = {}
505 for k,v in labels_map.iteritems():
506 v_mapping = labels_map_rev.get(v, [])
507 v_mapping.append(k)
508 labels_map_rev[v] = v_mapping
509 self.__labels_map_rev = labels_map_rev
510
511 labels.sort()
512 self.__labels = labels
513
514 Nlabels, Nsets = len(labels), len(self.sets)
515
516 if __debug__:
517 debug("CM", "Got labels %s" % labels)
518
519
520 mat_all = N.zeros( (Nsets, Nlabels, Nlabels), dtype=int )
521
522
523
524
525 counts_all = N.zeros( (Nsets, Nlabels) )
526
527
528 rev_map = dict([ (x[1], x[0]) for x in enumerate(labels)])
529 for iset, set_ in enumerate(self.sets):
530 for t,p in zip(*set_[:2]):
531 mat_all[iset, rev_map[p], rev_map[t]] += 1
532
533
534
535
536
537 self.__matrix = N.sum(mat_all, axis=0)
538 self.__Nsamples = N.sum(self.__matrix, axis=0)
539 self.__Ncorrect = sum(N.diag(self.__matrix))
540
541 TP = N.diag(self.__matrix)
542 offdiag = self.__matrix - N.diag(TP)
543 stats = {
544 '# of labels' : Nlabels,
545 'TP' : TP,
546 'FP' : N.sum(offdiag, axis=1),
547 'FN' : N.sum(offdiag, axis=0)}
548
549 stats['CORR'] = N.sum(TP)
550 stats['TN'] = stats['CORR'] - stats['TP']
551 stats['P'] = stats['TP'] + stats['FN']
552 stats['N'] = N.sum(stats['P']) - stats['P']
553 stats["P'"] = stats['TP'] + stats['FP']
554 stats["N'"] = stats['TN'] + stats['FN']
555 stats['TPR'] = stats['TP'] / (1.0*stats['P'])
556
557
558 stats['TPR'][stats['P'] == 0] = 0
559 stats['PPV'] = stats['TP'] / (1.0*stats["P'"])
560 stats['NPV'] = stats['TN'] / (1.0*stats["N'"])
561 stats['FDR'] = stats['FP'] / (1.0*stats["P'"])
562 stats['SPC'] = (stats['TN']) / (1.0*stats['FP'] + stats['TN'])
563
564 MCC_denom = N.sqrt(1.0*stats['P']*stats['N']*stats["P'"]*stats["N'"])
565 nz = MCC_denom!=0.0
566 stats['MCC'] = N.zeros(stats['TP'].shape)
567 stats['MCC'][nz] = \
568 (stats['TP'] * stats['TN'] - stats['FP'] * stats['FN'])[nz] \
569 / MCC_denom[nz]
570
571 stats['ACC'] = N.sum(TP)/(1.0*N.sum(stats['P']))
572 stats['ACC%'] = stats['ACC'] * 100.0
573
574
575
576 ROC = ROCCurve(labels=labels, sets=self.sets)
577 aucs = ROC.aucs
578 if len(aucs)>0:
579 stats['AUC'] = aucs
580 if len(aucs) != Nlabels:
581 raise RuntimeError, \
582 "We must got a AUC per label. Got %d instead of %d" % \
583 (len(aucs), Nlabels)
584 self.ROC = ROC
585 else:
586
587 stats['AUC'] = [N.nan] * Nlabels
588 self.ROC = None
589
590
591
592 for k,v in stats.items():
593 stats['mean(%s)' % k] = N.mean(v)
594
595 self._stats.update(stats)
596
597
def asstring(self, short=False, header=True, summary=True,
             description=False):
    """'Pretty print' the matrix

    :Parameters:
      short : bool
        if True, ignores the rest of the parameters and provides consise
        1 line summary
      header : bool
        print header of the table
      summary : bool
        print summary (accuracy)
      description : bool
        print verbose description of presented statistics
    """
    if len(self.sets) == 0:
        return "Empty"

    self.compute()

    # local bindings for clarity
    labels = self.__labels
    labels_map_rev = self.__labels_map_rev
    matrix = self.__matrix

    # if a reverse mapping is known, render original labels next to the
    # numerical ones
    labels_rev = []
    if labels_map_rev is not None:
        labels_rev = [','.join([str(x) for x in labels_map_rev[l]])
                      for l in labels]

    out = StringIO()

    Nlabels = len(labels)
    Nsamples = self.__Nsamples.astype(int)

    stats = self._stats
    if short:
        return "%(# of sets)d sets %(# of labels)d labels " \
               " ACC:%(ACC).2f" \
               % stats

    # column width: enough digits for the largest count vs longest label
    Ndigitsmax = int(ceil(log10(max(Nsamples))))
    Nlabelsmax = max([len(str(x)) for x in labels])
    L = max(Ndigitsmax + 2, Nlabelsmax)

    stats_perpredict = ["P'", "N'", 'FP', 'FN', 'PPV', 'NPV', 'TPR',
                        'SPC', 'FDR', 'MCC']
    # AUC is only meaningful if ROC analysis was possible
    if self.ROC is not None:
        stats_perpredict += ['AUC']
    stats_pertarget = ['P', 'N', 'TP', 'TN']
    stats_summary = ['ACC', 'ACC%', '# of sets']

    if matrix.shape != (Nlabels, Nlabels):
        # BUG FIX: '%' previously bound only to the second string literal,
        # raising TypeError instead of the intended diagnostic
        raise ValueError(
            "Number of labels %d doesn't correspond the size"
            " of a confusion matrix %s" % (Nlabels, matrix.shape))

    # assemble the table as a list of rows for table2string
    printed = []
    underscores = [" %s" % ("-" * L)] * Nlabels
    if header:
        # top row: original (reverse-mapped) labels, if known
        printed.append(['@l----------. '] + labels_rev)
        printed.append(['@lpredictions\\targets'] + labels)
        # underscores plus per-prediction statistics column headers
        printed.append(['@l `------']
                       + underscores + stats_perpredict)

    # the matrix itself: one row per prediction
    for i, line in enumerate(matrix):
        l = labels[i]
        if labels_rev != []:
            l = '@r%10s / %s' % (labels_rev[i], l)
        printed.append(
            [l] +
            [str(x) for x in line] +
            [_p2(stats[x][i]) for x in stats_perpredict])

    if summary:
        # per-target statistics
        printed.append(['@lPer target:'] + underscores)
        for stat in stats_pertarget:
            printed.append([stat] + [
                _p2(stats[stat][i]) for i in range(Nlabels)])

        # means of the per-prediction statistics
        printed.append(['@lSummary \\ Means:'] + underscores
                       + [_p2(stats['mean(%s)' % x])
                          for x in stats_perpredict])

        for stat in stats_summary:
            printed.append([stat] + [_p2(stats[stat])])

    table2string(printed, out)

    if description:
        out.write("\nStatistics computed in 1-vs-rest fashion per each " \
                  "target.\n")
        out.write("Abbreviations (for details see " \
                  "http://en.wikipedia.org/wiki/ROC_curve):\n")
        for d, val, eq in self._STATS_DESCRIPTION:
            out.write(" %-3s: %s\n" % (d, val))
            if eq is not None:
                out.write(" " + eq + "\n")

    result = out.getvalue()
    out.close()
    return result
724
725
def plot(self, labels=None, numbers=False, origin='upper',
         numbers_alpha=None, xlabels_vertical=True, numbers_kwargs={},
         **kwargs):
    """Provide presentation of confusion matrix in image

    :Parameters:
      labels : list of int or basestring
        Optionally provided labels guarantee the order of
        presentation. Also value of None places empty column/row,
        thus provides visual groupping of labels (Thanks Ingo)
      numbers : bool
        Place values inside of confusion matrix elements
      numbers_alpha : None or float
        Controls textual output of numbers. If None -- all numbers
        are plotted in the same intensity. If some float -- it controls
        alpha level -- higher value would give higher contrast. (good
        value is 2)
      origin : basestring
        Which left corner diagonal should start
      xlabels_vertical : bool
        Either to plot xlabels vertical (benefitial if number of labels
        is large)
      numbers_kwargs : dict
        Additional keyword parameters to be added to numbers (if numbers
        is True)
      **kwargs
        Additional arguments given to imshow (\eg me cmap)

    :Returns:
       (fig, im, cb) -- figure, imshow, colorbar
    """

    externals.exists("pylab", raiseException=True)
    import pylab as P

    self.compute()
    labels_order = labels

    # local bindings
    labels = self.__labels
    labels_map = self.__labels_map
    labels_map_rev = self.__labels_map_rev
    matrix = self.__matrix

    # index of each known label within the computed matrix
    labels_indexes = dict([(x, i) for i, x in enumerate(labels)])

    labels_rev = []
    if labels_map_rev is not None:
        labels_rev = [','.join([str(x) for x in labels_map_rev[l]])
                      for l in labels]
        labels_map_full = dict(zip(labels_rev, labels))

    if labels_order is not None:
        labels_order_filtered = filter(lambda x: x is not None, labels_order)
        labels_order_filtered_set = Set(labels_order_filtered)
        # requested order must cover exactly the known labels -- either
        # numerical or original (mapped) ones
        if Set(labels) == labels_order_filtered_set:
            # we were provided numerical labels
            labels_plot = labels_order
        elif len(labels_rev) \
                 and Set(labels_rev) == labels_order_filtered_set:
            # provided original labels -- map them back into numerical
            labels_plot = []
            for l in labels_order:
                v = None
                if l is not None:
                    v = labels_map_full[l]
                labels_plot += [v]
        else:
            raise ValueError(
                "Provided labels %s do not match set of known "
                "original labels (%s) or mapped labels (%s)"
                % (labels_order, labels, labels_rev))
    else:
        labels_plot = labels

    # None entries become masked spacer rows/columns
    isempty = N.array([l is None for l in labels_plot])
    non_empty = N.where(N.logical_not(isempty))[0]

    NlabelsNN = len(non_empty)
    Nlabels = len(labels_plot)

    if matrix.shape != (NlabelsNN, NlabelsNN):
        # BUG FIX: '%' previously bound only to the second string literal,
        # raising TypeError instead of the intended diagnostic
        raise ValueError(
            "Number of labels %d doesn't correspond the size"
            " of a confusion matrix %s" % (NlabelsNN, matrix.shape))

    confusionmatrix = N.zeros((Nlabels, Nlabels))
    mask = confusionmatrix.copy()
    ticks = []
    tick_labels = []
    # reorder rows/columns of the computed matrix into plotting order
    reordered_indexes = [labels_indexes[i] for i in labels_plot
                         if i is not None]
    for i, l in enumerate(labels_plot):
        if l is not None:
            j = labels_indexes[l]
            confusionmatrix[i, non_empty] = matrix[j, reordered_indexes]
            confusionmatrix[non_empty, i] = matrix[reordered_indexes, j]
            ticks += [i + 0.5]
            if labels_map_rev is not None:
                tick_labels += ['/'.join(labels_map_rev[l])]
            else:
                tick_labels += [str(l)]
        else:
            # spacer row/column -- mask it out
            mask[i, :] = mask[:, i] = 1

    confusionmatrix = N.ma.MaskedArray(confusionmatrix, mask=mask)

    # turn off automatic update if interactive
    if P.matplotlib.get_backend() == 'TkAgg':
        P.ioff()

    fig = P.gcf()
    ax = P.gca()
    ax.axis('off')

    # tick/label geometry depends on where the diagonal starts
    xticks_position, yticks, ybottom = {
        'upper': ('top', [Nlabels - x for x in ticks], 0.1),
        'lower': ('bottom', ticks, 0.2)
        }[origin]

    # plot the actual matrix
    axi = fig.add_axes([0.15, ybottom, 0.7, 0.7])
    im = axi.imshow(confusionmatrix, interpolation="nearest", origin=origin,
                    aspect='equal', **kwargs)

    # plot the numeric values inside the cells
    if numbers:
        numbers_kwargs_ = {'fontsize': 10,
                           'horizontalalignment': 'center',
                           'verticalalignment': 'center'}
        maxv = float(N.max(confusionmatrix))
        colors = [im.to_rgba(0), im.to_rgba(maxv)]
        for i, j in zip(*N.logical_not(mask).nonzero()):
            v = confusionmatrix[j, i]
            if numbers_alpha is None:
                alpha = 1.0
            else:
                # scale (non-linearly) alpha according to the value
                alpha = 1 - N.array(1 - v / maxv) ** numbers_alpha
            y = {'lower': j, 'upper': Nlabels - j - 1}[origin]
            # contrast text color against the cell background
            numbers_kwargs_['color'] = colors[int(v < maxv / 2)]
            numbers_kwargs_.update(numbers_kwargs)
            P.text(i + 0.5, y + 0.5, '%d' % v, alpha=alpha, **numbers_kwargs_)

    maxv = N.max(confusionmatrix)
    # BUG FIX: was N.linspace(0, maxv, N.min(maxv, 10), True) where 10 was
    # (mis)treated as the 'axis' argument of N.min; use builtin min to cap
    # the number of colorbar ticks at 10
    boundaries = N.linspace(0, maxv, min(maxv, 10), True)

    # axes labels
    P.xlabel("targets")
    P.ylabel("predictions")

    P.setp(axi, xticks=ticks, yticks=yticks,
           xticklabels=tick_labels, yticklabels=tick_labels)

    axi.xaxis.set_ticks_position(xticks_position)
    axi.xaxis.set_label_position(xticks_position)

    if xlabels_vertical:
        P.setp(P.getp(axi, 'xticklabels'), rotation='vertical')

    axcb = fig.add_axes([0.8, ybottom, 0.02, 0.7])
    cb = P.colorbar(im, cax=axcb, format='%d', ticks=boundaries)

    if P.matplotlib.get_backend() == 'TkAgg':
        P.ion()
    P.draw()

    # kept primarily for testing/inspection
    self._plotted_confusionmatrix = confusionmatrix
    return fig, im, cb
902
903
904 @property
906 self.compute()
907 return 1.0-self.__Ncorrect*1.0/sum(self.__Nsamples)
908
909
910 @property
912 self.compute()
913 return self.__labels
914
915
917 return self.__labels_map
918
919
921 if val is None or isinstance(val, dict):
922 self.__labels_map = val
923 else:
924 raise ValueError, "Cannot set labels_map to %s" % val
925
926 self.__labels_map_rev = None
927 self._computed = False
928
929
930 @property
932 self.compute()
933 return self.__matrix
934
935
936 @property
938 self.compute()
939 return 100.0*self.__Ncorrect/sum(self.__Nsamples)
940
941 labels_map = property(fget=getLabels_map, fset=setLabels_map)
942
943
945 """Class to contain information and display on regression results.
946
947 """
948
949 _STATS_DESCRIPTION = (
950 ('CCe', 'Error based on correlation coefficient',
951 '1 - corr_coef'),
952 ('CCp', 'Correlation coefficient (p-value)', None),
953 ('RMSE', 'Root mean squared error', None),
954 ('STD', 'Standard deviation', None),
955 ('RMP', 'Root mean power (compare to RMSE of results)',
956 'sqrt(mean( data**2 ))'),
957 ) + SummaryStatistics._STATS_DESCRIPTION
958
959
961 """Initialize RegressionStatistics
962
963 :Parameters:
964 targets
965 Optional set of targets
966 predictions
967 Optional set of predictions
968 """
969
970 SummaryStatistics.__init__(self, **kwargs)
971
972
974 """Actually compute the confusion matrix based on all the sets"""
975
976 super(RegressionStatistics, self)._compute()
977 sets = self.sets
978 Nsets = len(sets)
979
980 stats = {}
981
982 funcs = {
983 'RMP_t': lambda p,t:rootMeanPowerFx(t),
984 'STD_t': lambda p,t:N.std(t),
985 'RMP_p': lambda p,t:rootMeanPowerFx(p),
986 'STD_p': lambda p,t:N.std(p),
987 'CCe': CorrErrorFx(),
988 'CCp': CorrErrorPFx(),
989 'RMSE': RMSErrorFx(),
990 'RMSE/RMP_t': RelativeRMSErrorFx()
991 }
992
993 for funcname, func in funcs.iteritems():
994 funcname_all = funcname + '_all'
995 stats[funcname_all] = []
996 for i, (targets, predictions, values) in enumerate(sets):
997 stats[funcname_all] += [func(predictions, targets)]
998 stats[funcname_all] = N.array(stats[funcname_all])
999 stats[funcname] = N.mean(stats[funcname_all])
1000 stats[funcname+'_std'] = N.std(stats[funcname_all])
1001 stats[funcname+'_max'] = N.max(stats[funcname_all])
1002 stats[funcname+'_min'] = N.min(stats[funcname_all])
1003
1004
1005
1006
1007 targets, predictions = [], []
1008 for i, (targets_, predictions_, values_) in enumerate(sets):
1009 targets += list(targets_)
1010 predictions += list(predictions_)
1011
1012 for funcname, func in funcs.iteritems():
1013 funcname_all = 'Summary ' + funcname
1014 stats[funcname_all] = func(predictions, targets)
1015
1016 self._stats.update(stats)
1017
1018
def plot(self,
         plot=True, plot_stats=True,
         splot=True
         ):
    """Provide presentation of regression performance in image

    :Parameters:
      plot : bool
        Plot regular plot of values (targets/predictions)
      plot_stats : bool
        Print basic statistics in the title
      splot : bool
        Plot scatter plot

    :Returns:
     (fig, im, cb) -- figure, imshow, colorbar
    """
    externals.exists("pylab", raiseException=True)
    import pylab as P

    self.compute()
    # number of subplots requested (bools sum as 0/1)
    nplots = plot + splot

    # turn off automatic update if interactive
    if P.matplotlib.get_backend() == 'TkAgg':
        P.ioff()

    fig = P.gcf()
    P.clf()
    sps = []                            # subplots

    nplot = 0
    if plot:
        # targets (blue) and predictions (red) over sample index,
        # sets concatenated along x and separated by dashed lines
        nplot += 1
        sps.append(P.subplot(nplots, 1, nplot))
        xstart = 0
        lines = []
        for s in self.sets:
            nsamples = len(s[0])
            xend = xstart + nsamples
            xs = xrange(xstart, xend)
            lines += [P.plot(xs, s[0], 'b')]
            lines += [P.plot(xs, s[1], 'r')]
            # vertical line to separate the successive sets
            P.plot([xend, xend], [N.min(s[0]), N.max(s[0])], 'k--')
            xstart = xend
        if len(lines) > 1:
            P.legend(lines[:2], ('Target', 'Prediction'))
        if plot_stats:
            P.title(self.asstring(short='very'))

    if splot:
        # scatter plot of predictions against targets, one series per set
        nplot += 1
        sps.append(P.subplot(nplots, 1, nplot))
        for s in self.sets:
            P.plot(s[0], s[1], 'o',
                   markeredgewidth=0.2,
                   markersize=2)
        P.gca().set_aspect('equal')

    if P.matplotlib.get_backend() == 'TkAgg':
        P.ion()
    P.draw()

    return fig, sps
1089
def asstring(self, short=False, header=True, summary=True,
             description=False):
    """'Pretty print' the statistics

    :Parameters:
      short : bool or 'very'
        one-line summary instead of the full table; 'very' selects an
        even more condensed variant including the pooled 'Summary' stats
      header : bool
        print header of the table
      summary : bool
        print summary (number of sets)
      description : bool
        print verbose description of presented statistics
    """

    if len(self.sets) == 0:
        return "Empty"

    self.compute()

    stats = self.stats

    if short:
        if short == 'very':
            return "%(# of sets)d sets CCe=%(CCe).2f p=%(CCp).2g" \
                   " RMSE:%(RMSE).2f" \
                   " Summary: " \
                   "CCe=%(Summary CCe).2f p=%(Summary CCp).2g" \
                   % stats
        else:
            return "%(# of sets)d sets CCe=%(CCe).2f+-%(CCe_std).3f" \
                   " RMSE=%(RMSE).2f+-%(RMSE_std).3f" \
                   " RMSE/RMP_t=%(RMSE/RMP_t).2f+-%(RMSE/RMP_t_std).3f" \
                   % stats

    # statistics describing the data itself vs the fit results
    stats_data = ['RMP_t', 'STD_t', 'RMP_p', 'STD_p']
    stats_ = ['CCe', 'RMSE', 'RMSE/RMP_t']
    stats_summary = ['# of sets']

    out = StringIO()

    # table assembled as a list of rows for table2string
    printed = []
    if header:
        printed.append(['Statistics', 'Mean', 'Std', 'Min', 'Max'])
        printed.append(['----------', '-----', '-----', '-----', '-----'])

    def print_stats(printed, stats_):
        # one row per statistic: mean/std/min/max across sets
        for stat in stats_:
            s = [stat]
            for suffix in ['', '_std', '_min', '_max']:
                s += [ _p2(stats[stat+suffix], 3) ]
            printed.append(s)

    printed.append(["Data: "])
    print_stats(printed, stats_data)
    printed.append(["Results: "])
    print_stats(printed, stats_)
    # 'Summary *' entries are computed over all sets pooled together
    printed.append(["Summary: "])
    printed.append(["CCe", _p2(stats['Summary CCe']), "", "p=", '%g' % stats['Summary CCp']])
    printed.append(["RMSE", _p2(stats['Summary RMSE'])])
    printed.append(["RMSE/RMP_t", _p2(stats['Summary RMSE/RMP_t'])])

    if summary:
        for stat in stats_summary:
            printed.append([stat] + [_p2(stats[stat])])

    table2string(printed, out)

    if description:
        out.write("\nDescription of printed statistics.\n"
                  " Suffixes: _t - targets, _p - predictions\n")
        for d, val, eq in self._STATS_DESCRIPTION:
            out.write(" %-3s: %s\n" % (d, val))
            if eq is not None:
                out.write(" " + eq + "\n")

    result = out.getvalue()
    out.close()
    return result
1164
1165
1166 @property
1170
1171
1172
1174 """Compute (or return) some error of a (trained) classifier on a dataset.
1175 """
1176
1177 confusion = StateVariable(enabled=False)
1178 """TODO Think that labels might be also symbolic thus can't directly
1179 be indicies of the array
1180 """
1181
1182 training_confusion = StateVariable(enabled=False,
1183 doc="Proxy training_confusion from underlying classifier.")
1184
1185
def __init__(self, clf, labels=None, train=True, **kwargs):
    """Initialization.

    :Parameters:
      clf : Classifier
        Either trained or untrained classifier
      labels : list
        if provided, should be a set of labels to add on top of the
        ones present in testdata
      train : bool
        unless train=False, classifier gets trained if
        trainingdata provided to __call__
    """
    Stateful.__init__(self, **kwargs)

    # classifier whose error is to be computed
    self.__clf = clf
    # labels to add on top of the ones present in the testing data
    self._labels = labels
    # whether to (re)train the classifier when training data is given
    self.__train = train
1207
1208
1209 __doc__ = enhancedDocString('ClassifierError', locals(), Stateful)
1210
1211
1217
1218
def _precall(self, testdataset, trainingdataset=None):
    """Generic part which trains the classifier if necessary
    """
    if not trainingdataset is None:
        if self.__train:
            # temporarily enable the classifier's training_confusion
            # state if our own is enabled, so we can proxy it through
            if self.states.isEnabled('training_confusion'):
                self.__clf.states._changeTemporarily(
                    enable_states=['training_confusion'])
            self.__clf.train(trainingdataset)
            if self.states.isEnabled('training_confusion'):
                self.training_confusion = self.__clf.training_confusion
                self.__clf.states._resetEnabledTemporarily()

    # warn if the test data contains labels the classifier never saw --
    # usually a sign of swapped/misnamed arguments
    if self.__clf.states.isEnabled('trained_labels') and \
           not testdataset is None:
        newlabels = Set(testdataset.uniquelabels) \
                    - Set(self.__clf.trained_labels)
        if len(newlabels)>0:
            warning("Classifier %s wasn't trained to classify labels %s" %
                    (`self.__clf`, `newlabels`) +
                    " present in testing dataset. Make sure that you have" +
                    " not mixed order/names of the arguments anywhere")
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
def _call(self, testdataset, trainingdataset=None):
    """Compute the error value for `testdataset` (abstract).

    Subclasses (e.g. TransferError) implement the actual error
    computation; `trainingdataset` is passed along for those that
    need it.
    """
    raise NotImplementedError
1266
1267
1268 - def _postcall(self, testdataset, trainingdataset=None, error=None):
1270
1271
def __call__(self, testdataset, trainingdataset=None):
    """Compute the transfer error for a certain test dataset.

    If `trainingdataset` is not `None` the classifier is trained using the
    provided dataset before computing the transfer error. Otherwise the
    classifier is used in it's current state to make the predictions on
    the test dataset.

    Returns a scalar value of the transfer error.
    """
    # template-method sequence: _precall (optional training) ->
    # _call (actual error computation) -> _postcall (cleanup/states)
    self._precall(testdataset, trainingdataset)
    error = self._call(testdataset, trainingdataset)
    self._postcall(testdataset, trainingdataset, error)
    if __debug__:
        debug('CERR', 'Classifier error on %s: %.2f'
              % (testdataset, error))
    return error
1289
1290
1291 @property
1294
1295
1296 @property
1299
1300
1301
1303 """Compute the transfer error of a (trained) classifier on a dataset.
1304
1305 The actual error value is computed using a customizable error function.
1306 Optionally the classifier can be trained by passing an additional
1307 training dataset to the __call__() method.
1308 """
1309
    # Probability of the observed error under the estimated NULL
    # distribution; assigned in _postcall() when a null_dist is configured.
    null_prob = StateVariable(enabled=True,
        doc="Stores the probability of an error result under "
            "the NULL hypothesis")
    # Optional per-sample errors keyed by each sample's origid; filled in
    # _call() only when this state is explicitly enabled.
    samples_error = StateVariable(enabled=False,
        doc="Per sample errors computed by invoking the "
            "error function for each sample individually. "
            "Errors are available in a dictionary with each "
            "samples origid as key.")
1318
1321 """Initialization.
1322
1323 :Parameters:
1324 clf : Classifier
1325 Either trained or untrained classifier
1326 errorfx
1327 Functor that computes a scalar error value from the vectors of
1328 desired and predicted values (e.g. subclass of `ErrorFunction`)
1329 labels : list
1330 if provided, should be a set of labels to add on top of the
1331 ones present in testdata
1332 null_dist : instance of distribution estimator
1333 """
1334 ClassifierError.__init__(self, clf, labels, **kwargs)
1335 self.__errorfx = errorfx
1336 self.__null_dist = autoNullDist(null_dist)
1337
1338
    # Extend the class docstring with parameter/state descriptions gathered
    # from ClassifierError (PyMVPA documentation helper).
    __doc__ = enhancedDocString('TransferError', locals(), ClassifierError)
1340
1341
1349
1350
1351 - def _call(self, testdataset, trainingdataset=None):
1352 """Compute the transfer error for a certain test dataset.
1353
1354 If `trainingdataset` is not `None` the classifier is trained using the
1355 provided dataset before computing the transfer error. Otherwise the
1356 classifier is used in it's current state to make the predictions on
1357 the test dataset.
1358
1359 Returns a scalar value of the transfer error.
1360 """
1361
1362 clf = self.clf
1363 if testdataset is None:
1364
1365
1366 import traceback as tb
1367 filenames = [x[0] for x in tb.extract_stack(limit=100)]
1368 rfe_matches = [f for f in filenames if f.endswith('/rfe.py')]
1369 cv_matches = [f for f in filenames if
1370 f.endswith('cvtranserror.py')]
1371 msg = ""
1372 if len(rfe_matches) > 0 and len(cv_matches):
1373 msg = " It is possible that you used RFE with stopping " \
1374 "criterion based on the TransferError and directly" \
1375 " from CrossValidatedTransferError, such approach" \
1376 " would require exposing testing dataset " \
1377 " to the classifier which might heavily bias " \
1378 " generalization performance estimate. If you are " \
1379 " sure to use it that way, create CVTE with " \
1380 " parameter expose_testdataset=True"
1381 raise ValueError, "Transfer error call obtained None " \
1382 "as a dataset for testing.%s" % msg
1383 predictions = clf.predict(testdataset.samples)
1384
1385
1386
1387
1388
1389
1390
1391 states = self.states
1392 if states.isEnabled('confusion'):
1393 confusion = clf._summaryClass(
1394
1395 targets=testdataset.labels,
1396 predictions=predictions,
1397 values=clf.states.get('values', None))
1398 try:
1399 confusion.labels_map = testdataset.labels_map
1400 except:
1401 pass
1402 states.confusion = confusion
1403
1404 if states.isEnabled('samples_error'):
1405 samples_error = []
1406 for i, p in enumerate(predictions):
1407 samples_error.append(self.__errorfx(p, testdataset.labels[i]))
1408
1409 states.samples_error = dict(zip(testdataset.origids, samples_error))
1410
1411
1412 error = self.__errorfx(predictions, testdataset.labels)
1413
1414 return error
1415
1416
    def _postcall(self, vdata, wdata=None, error=None):
        """Estimate the NULL distribution of errors (if an estimator was
        configured) and store the probability of `error` under it in the
        `null_prob` state.
        """
        # estimate the NULL distribution only when both an estimator and
        # training data are available
        if not self.__null_dist is None and not wdata is None:
            # fit needs a matching transfer error instance (same error
            # function etc.), but with the null-distribution estimation
            # disabled on the copy to prevent infinite recursion
            null_terr = copy.copy(self)
            null_terr.__null_dist = None
            self.__null_dist.fit(null_terr, wdata, vdata)

        # express the measured error in terms of the NULL distribution
        if not error is None and not self.__null_dist is None:
            self.null_prob = self.__null_dist.p(error)
1434
1435
1436 @property
1437 - def errorfx(self): return self.__errorfx
1438
1439 @property
1440 - def null_dist(self): return self.__null_dist
1441
1442
1443
1445 """For a given classifier report an error based on internally
1446 computed error measure (given by some `ConfusionMatrix` stored in
1447 some state variable of `Classifier`).
1448
1449 This way we can perform feature selection taking as the error
1450 criterion either learning error, or transfer to splits error in
1451 the case of SplitClassifier
1452 """
1453
1454 - def __init__(self, clf, labels=None, confusion_state="training_confusion",
1455 **kwargs):
1456 """Initialization.
1457
1458 :Parameters:
1459 clf : Classifier
1460 Either trained or untrained classifier
1461 confusion_state
1462 Id of the state variable which stores `ConfusionMatrix`
1463 labels : list
1464 if provided, should be a set of labels to add on top of the
1465 ones present in testdata
1466 """
1467 ClassifierError.__init__(self, clf, labels, **kwargs)
1468
1469 self.__confusion_state = confusion_state
1470 """What state to extract from"""
1471
1472 if not clf.states.isKnown(confusion_state):
1473 raise ValueError, \
1474 "State variable %s is not defined for classifier %s" % \
1475 (confusion_state, `clf`)
1476 if not clf.states.isEnabled(confusion_state):
1477 if __debug__:
1478 debug('CERR', "Forcing state %s to be enabled for %s" %
1479 (confusion_state, `clf`))
1480 clf.states.enable(confusion_state)
1481
1482
    # Extend the class docstring with parameter/state descriptions gathered
    # from ClassifierError (PyMVPA documentation helper).
    __doc__ = enhancedDocString('ConfusionBasedError', locals(),
                                ClassifierError)
1485
1486
1487 - def _call(self, testdata, trainingdata=None):
1493