1
2
3
4
5
6
7
8
9 """Base class for all classifiers.
10
11 At the moment, regressions are treated just as a special case of
12 classifier (or vice versa), so the same base class `Classifier` is
13 utilized for both kinds.
14 """
15
16 __docformat__ = 'restructuredtext'
17
18 import numpy as N
19
20 from mvpa.support.copy import deepcopy
21
22 import time
23
24 from mvpa.misc.support import idhash
25 from mvpa.misc.state import StateVariable, Parametrized
26 from mvpa.misc.param import Parameter
27
28 from mvpa.clfs.transerror import ConfusionMatrix, RegressionStatistics
29
30 from mvpa.base import warning
31
32 if __debug__:
33 from mvpa.base import debug
34
35
37 """Abstract classifier class to be inherited by all classifiers
38 """
39
40
41
42 _DEV__doc__ = """
43 Required behavior:
44
45 For every classifier is has to be possible to be instantiated without
46 having to specify the training pattern.
47
48 Repeated calls to the train() method with different training data have to
49 result in a valid classifier, trained for the particular dataset.
50
51 It must be possible to specify all classifier parameters as keyword
52 arguments to the constructor.
53
54 Recommended behavior:
55
56 Derived classifiers should provide access to *values* -- i.e. that
57 information that is finally used to determine the predicted class label.
58
59 Michael: Maybe it works well if each classifier provides a 'values'
60 state member. This variable is a list as long as and in same order
61 as Dataset.uniquelabels (training data). Each item in the list
62 corresponds to the likelihood of a sample to belong to the
63 respective class. However the semantics might differ between
64 classifiers, e.g. kNN would probably store distances to class-
65 neighbors, where PLR would store the raw function value of the
66 logistic function. So in the case of kNN low is predictive and for
67 PLR high is predictive. Don't know if there is the need to unify
68 that.
69
70 As the storage and/or computation of this information might be
71 demanding its collection should be switchable and off by default.
72
73 Nomenclature
74 * predictions : corresponds to the quantized labels if classifier spits
75 out labels by .predict()
76 * values : might be different from predictions if a classifier's predict()
77 makes a decision based on some internal value such as
78 probability or a distance.
79 """
80
81
82
83
84
85
86
87
88
89 trained_labels = StateVariable(enabled=True,
90 doc="Set of unique labels it has been trained on")
91
92 trained_dataset = StateVariable(enabled=False,
93 doc="The dataset it has been trained on")
94
95 training_confusion = StateVariable(enabled=False,
96 doc="Confusion matrix of learning performance")
97
98 predictions = StateVariable(enabled=True,
99 doc="Most recent set of predictions")
100
101 values = StateVariable(enabled=True,
102 doc="Internal classifier values the most recent " +
103 "predictions are based on")
104
105 training_time = StateVariable(enabled=True,
106 doc="Time (in seconds) which took classifier to train")
107
108 predicting_time = StateVariable(enabled=True,
109 doc="Time (in seconds) which took classifier to predict")
110
111 feature_ids = StateVariable(enabled=False,
112 doc="Feature IDS which were used for the actual training.")
113
114 _clf_internals = []
115 """Describes some specifics about the classifier -- is that it is
116 doing regression for instance...."""
117
118 regression = Parameter(False, allowedtype='bool',
119 doc="""Either to use 'regression' as regression. By default any
120 Classifier-derived class serves as a classifier, so regression
121 does binary classification.""", index=1001)
122
123
124 retrainable = Parameter(False, allowedtype='bool',
125 doc="""Either to enable retraining for 'retrainable' classifier.""",
126 index=1002)
127
128
    def __init__(self, **kwargs):
        """Cheap initialization.

        All configuration comes in as keyword arguments handled by
        `Parametrized`, so a classifier is constructed without any
        training data.
        """
        Parametrized.__init__(self, **kwargs)

        self.__trainednfeatures = None
        """Stores number of features for which classifier was trained.
        If None -- it wasn't trained at all"""

        # Initialize retrainability bookkeeping (idhashes etc) according
        # to the current value of the 'retrainable' parameter
        self._setRetrainable(self.params.retrainable, force=True)

        if self.params.regression:
            # unique labels make no sense for regression output
            for statevar in [ "trained_labels"]:
                if self.states.isEnabled(statevar):
                    if __debug__:
                        debug("CLF",
                              "Disabling state %s since doing regression, " %
                              statevar + "not classification")
                    self.states.disable(statevar)
            self._summaryClass = RegressionStatistics
        else:
            self._summaryClass = ConfusionMatrix
            if 'regression' in self._clf_internals:
                # NOTE(review): regression-capable classifier used for
                # classification is treated as binary -- indentation of this
                # branch was reconstructed, TODO confirm against upstream
                self._clf_internals.append('binary')
156
157
158
159
160
161
162
164 if __debug__ and 'CLF_' in debug.active:
165 return "%s / %s" % (repr(self), super(Classifier, self).__str__())
166 else:
167 return repr(self)
168
171
172
    def _pretrain(self, dataset):
        """Functionality prior to training
        """
        params = self.params
        if not params.retrainable:
            # plain classifier: any previous training is simply discarded
            self.untrain()
        else:
            # retrainable: keep trained internals, just reset the states
            self.states.reset()
            if not self.__changedData_isset:
                self.__resetChangedData()
            # local bindings
            _changedData = self._changedData
            __idhashes = self.__idhashes
            __invalidatedChangedData = self.__invalidatedChangedData

            if __debug__:
                debug('CLF_', "IDHashes are %s" % (__idhashes))

            # determine what was changed since the previous training
            for key, data_ in (('traindata', dataset.samples),
                               ('labels', dataset.labels)):
                _changedData[key] = self.__wasDataChanged(key, data_)
                # an entry invalidated by an earlier retrain() is forced
                # to count as changed even if its idhash matches
                if __invalidatedChangedData.get(key, False):
                    if __debug__ and not _changedData[key]:
                        debug('CLF_', 'Found that idhash for %s was '
                              'invalidated by retraining' % key)
                    _changedData[key] = True

            # record which parameters were set since last training
            for col in self._paramscols:
                changedParams = self._collections[col].whichSet()
                if len(changedParams):
                    _changedData[col] = changedParams

            # invalidation flags are consumed by this training
            self.__invalidatedChangedData = {}

            if __debug__:
                debug('CLF_', "Obtained _changedData is %s"
                      % (self._changedData))

        if not params.regression and 'regression' in self._clf_internals \
           and not self.states.isEnabled('trained_labels'):
            # predict() needs trained_labels to convert regression values
            # into discrete labels
            if __debug__:
                debug("CLF", "Enabling trained_labels state since it is needed")
            self.states.enable('trained_labels')
226
227
228 - def _posttrain(self, dataset):
229 """Functionality post training
230
231 For instance -- computing confusion matrix
232 :Parameters:
233 dataset : Dataset
234 Data which was used for training
235 """
236 if self.states.isEnabled('trained_labels'):
237 self.trained_labels = dataset.uniquelabels
238
239 self.trained_dataset = dataset
240
241
242 self.__trainednfeatures = dataset.nfeatures
243
244 if __debug__ and 'CHECK_TRAINED' in debug.active:
245 self.__trainedidhash = dataset.idhash
246
247 if self.states.isEnabled('training_confusion') and \
248 not self.states.isSet('training_confusion'):
249
250
251 self.states._changeTemporarily(
252 disable_states=["predictions"])
253 if self.params.retrainable:
254
255
256
257
258
259 self.__changedData_isset = False
260 predictions = self.predict(dataset.samples)
261 self.states._resetEnabledTemporarily()
262 self.training_confusion = self._summaryClass(
263 targets=dataset.labels,
264 predictions=predictions)
265
266 try:
267 self.training_confusion.labels_map = dataset.labels_map
268 except:
269 pass
270
271 if self.states.isEnabled('feature_ids'):
272 self.feature_ids = self._getFeatureIds()
273
274
276 """Virtual method to return feature_ids used while training
277
278 Is not intended to be called anywhere but from _posttrain,
279 thus classifier is assumed to be trained at this point
280 """
281
282 return range(self.__trainednfeatures)
283
284
286 """Providing summary over the classifier"""
287
288 s = "Classifier %s" % self
289 states = self.states
290 states_enabled = states.enabled
291
292 if self.trained:
293 s += "\n trained"
294 if states.isSet('training_time'):
295 s += ' in %.3g sec' % states.training_time
296 s += ' on data with'
297 if states.isSet('trained_labels'):
298 s += ' labels:%s' % list(states.trained_labels)
299 if states.isSet('trained_dataset'):
300 td = states.trained_dataset
301 s += ' #samples:%d #chunks:%d' % (td.nsamples,
302 len(td.uniquechunks))
303 s += " #features:%d" % self.__trainednfeatures
304 if states.isSet('feature_ids'):
305 s += ", used #features:%d" % len(states.feature_ids)
306 if states.isSet('training_confusion'):
307 s += ", training error:%.3g" % states.training_confusion.error
308 else:
309 s += "\n not yet trained"
310
311 if len(states_enabled):
312 s += "\n enabled states:%s" % ', '.join([str(states[x])
313 for x in states_enabled])
314 return s
315
316
318 """Create full copy of the classifier.
319
320 It might require classifier to be untrained first due to
321 present SWIG bindings.
322
323 TODO: think about proper re-implementation, without enrollment of deepcopy
324 """
325 try:
326 return deepcopy(self)
327 except:
328 self.untrain()
329 return deepcopy(self)
330
331
    def _train(self, dataset):
        """Function to be actually overridden in derived classes
        """
        raise NotImplementedError
336
337
338 - def train(self, dataset):
339 """Train classifier on a dataset
340
341 Shouldn't be overridden in subclasses unless explicitly needed
342 to do so
343 """
344 if __debug__:
345 debug("CLF", "Training classifier %(clf)s on dataset %(dataset)s",
346 msgargs={'clf':self, 'dataset':dataset})
347
348 self._pretrain(dataset)
349
350
351 t0 = time.time()
352
353 if dataset.nfeatures > 0:
354 result = self._train(dataset)
355 else:
356 warning("Trying to train on dataset with no features present")
357 if __debug__:
358 debug("CLF",
359 "No features present for training, no actual training " \
360 "is called")
361 result = None
362
363 self.training_time = time.time() - t0
364 self._posttrain(dataset)
365 return result
366
367
    def _prepredict(self, data):
        """Functionality prior prediction
        """
        if not ('notrain2predict' in self._clf_internals):
            # classifiers which need training must have been trained and
            # must see the same number of features as during training
            if not self.trained:
                raise ValueError, \
                      "Classifier %s wasn't yet trained, therefore can't " \
                      "predict" % self
            nfeatures = data.shape[1]
            if nfeatures != self.__trainednfeatures:
                raise ValueError, \
                      "Classifier %s was trained on data with %d features, " % \
                      (self, self.__trainednfeatures) + \
                      "thus can't predict for %d features" % nfeatures

        if self.params.retrainable:
            if not self.__changedData_isset:
                # record whether test data differs from the previous call
                self.__resetChangedData()
                _changedData = self._changedData
                _changedData['testdata'] = \
                    self.__wasDataChanged('testdata', data)
                if __debug__:
                    debug('CLF_', "prepredict: Obtained _changedData is %s"
                          % (_changedData))
396
397
    def _postpredict(self, data, result):
        """Functionality after prediction is computed
        """
        # publish predictions via the state variable
        self.predictions = result
        if self.params.retrainable:
            # a completed prediction consumes the pending "changed" flags
            self.__changedData_isset = False
404
    def _predict(self, data):
        """Actual prediction

        To be overridden in derived classes.
        """
        raise NotImplementedError
409
410
    def predict(self, data):
        """Predict classifier on data

        Shouldn't be overridden in subclasses unless explicitly needed
        to do so. Also subclasses trying to call super class's predict
        should call _predict if within _predict instead of predict()
        since otherwise it would loop
        """
        data = N.asarray(data)
        if __debug__:
            debug("CLF", "Predicting classifier %(clf)s on data %(data)s",
                  msgargs={'clf':self, 'data':data.shape})

        # time the actual prediction
        t0 = time.time()

        states = self.states
        # assure a clean slate (values/predictions could have been set by
        # a previous call)
        states.reset(['values', 'predictions'])

        self._prepredict(data)

        if self.__trainednfeatures > 0 \
               or 'notrain2predict' in self._clf_internals:
            result = self._predict(data)
        else:
            # trained on a featureless dataset -- produce placeholder output
            warning("Trying to predict using classifier trained on no features")
            if __debug__:
                debug("CLF",
                      "No features were present for training, prediction is " \
                      "bogus")
            result = [None]*data.shape[0]

        states.predicting_time = time.time() - t0

        if 'regression' in self._clf_internals and not self.params.regression:
            # regression-capable classifier used as a classifier: quantize
            # continuous output into the closest trained label
            result_ = N.array(result)
            if states.isEnabled('values'):
                # keep raw regression output as 'values' unless _predict
                # already set them
                if not states.isSet('values'):
                    states.values = result_.copy()
                else:
                    states.values = states.values.copy()

            trained_labels = self.trained_labels
            for i, value in enumerate(result):
                dists = N.abs(value - trained_labels)
                result[i] = trained_labels[N.argmin(dists)]

            if __debug__:
                debug("CLF_", "Converted regression result %(result_)s "
                      "into labels %(result)s for %(self_)s",
                      msgargs={'result_':result_, 'result':result,
                               'self_': self})

        self._postpredict(data, result)
        return result
481
482
484 """Either classifier was already trained.
485
486 MUST BE USED WITH CARE IF EVER"""
487 if dataset is None:
488
489 return not self.__trainednfeatures is None
490 else:
491 res = (self.__trainednfeatures == dataset.nfeatures)
492 if __debug__ and 'CHECK_TRAINED' in debug.active:
493 res2 = (self.__trainedidhash == dataset.idhash)
494 if res2 != res:
495 raise RuntimeError, \
496 "isTrained is weak and shouldn't be relied upon. " \
497 "Got result %b although comparing of idhash says %b" \
498 % (res, res2)
499 return res
500
501
503 """Some classifiers like BinaryClassifier can't be used for
504 regression"""
505
506 if self.params.regression:
507 raise ValueError, "Regression mode is meaningless for %s" % \
508 self.__class__.__name__ + " thus don't enable it"
509
510
    @property
    def trained(self):
        """Either classifier was already trained"""
        return self.isTrained()
515
    def untrain(self):
        """Reset trained state"""
        # forgetting the feature count is what makes isTrained()/trained
        # report False afterwards
        self.__trainednfeatures = None
        # reset state variables via the parent class
        super(Classifier, self).reset()
526
527
    def getSensitivityAnalyzer(self, **kwargs):
        """Factory method to return an appropriate sensitivity analyzer for
        the respective classifier."""
        raise NotImplementedError
532
533
534
535
536
    def _setRetrainable(self, value, force=False):
        """Assign value of retrainable parameter

        If retrainable flag is to be changed, classifier has to be
        untrained.  Also internal attributes such as _changedData,
        __changedData_isset, and __idhashes should be initialized if
        it becomes retrainable
        """
        pretrainable = self.params['retrainable']
        if (force or value != pretrainable.value) \
               and 'retrainable' in self._clf_internals:
            if __debug__:
                debug("CLF_", "Setting retrainable to %s" % value)
            if 'meta' in self._clf_internals:
                warning("Retrainability is not yet crafted/tested for "
                        "meta classifiers. Unpredictable behavior might occur")
            # drop any existing trained state before flipping the flag
            if self.trained:
                self.untrain()
            states = self.states
            if not value and states.isKnown('retrained'):
                # remove retrainability-specific state variables
                states.remove('retrained')
                states.remove('repredicted')
            if value:
                if not 'retrainable' in self._clf_internals:
                    warning("Setting of flag retrainable for %s has no effect"
                            " since classifier has no such capability. It would"
                            " just lead to resources consumption and slowdown"
                            % self)
                states.add(StateVariable(enabled=True,
                        name='retrained',
                        doc="Either retrainable classifier was retrained"))
                states.add(StateVariable(enabled=True,
                        name='repredicted',
                        doc="Either retrainable classifier was repredicted"))

            pretrainable.value = value

            # if retrainable we need to keep track of what actually changed
            if value:
                self.__idhashes = {'traindata': None, 'labels': None,
                                   'testdata': None}
                if __debug__ and 'CHECK_RETRAIN' in debug.active:
                    # under CHECK_RETRAIN also keep the raw entries so the
                    # weak idhash comparison can be double-checked
                    self.__trained = self.__idhashes.copy()
                self.__resetChangedData()
                self.__invalidatedChangedData = {}
            elif 'retrainable' in self._clf_internals:
                # tear the bookkeeping down when retrainability is disabled
                self.__changedData_isset = False
                self._changedData = None
                self.__idhashes = None
                if __debug__ and 'CHECK_RETRAIN' in debug.active:
                    self.__trained = None
594
596 """For retrainable classifier we keep track of what was changed
597 This function resets that dictionary
598 """
599 if __debug__:
600 debug('CLF_',
601 'Retrainable: resetting flags on either data was changed')
602 keys = self.__idhashes.keys() + self._paramscols
603
604
605
606
607
608 self._changedData = dict(zip(keys, [False]*len(keys)))
609 self.__changedData_isset = False
610
611
    def __wasDataChanged(self, key, entry, update=True):
        """Check if given entry was changed from what known prior.

        If so -- store only the ones needed for retrainable beastie
        """
        idhash_ = idhash(entry)
        __idhashes = self.__idhashes

        # cheap comparison via idhash of the stored vs current entry
        changed = __idhashes[key] != idhash_
        if __debug__ and 'CHECK_RETRAIN' in debug.active:
            # expensive double-check against the actually stored entry
            __trained = self.__trained
            changed2 = entry != __trained[key]
            if isinstance(changed2, N.ndarray):
                changed2 = changed2.any()
            if changed != changed2 and not changed:
                # idhash said "same" but the values differ -- hash is too weak
                raise RuntimeError, \
                      'idhash found to be weak for %s. Though hashid %s!=%s %s, '\
                      'values %s!=%s %s' % \
                      (key, idhash_, __idhashes[key], changed,
                       entry, __trained[key], changed2)
            if update:
                __trained[key] = entry

        if __debug__ and changed:
            debug('CLF_', "Changed %s from %s to %s.%s"
                  % (key, __idhashes[key], idhash_,
                     ('','updated')[int(update)]))
        if update:
            __idhashes[key] = idhash_

        return changed
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674 - def retrain(self, dataset, **kwargs):
675 """Helper to avoid check if data was changed actually changed
676
677 Useful if just some aspects of classifier were changed since
678 its previous training. For instance if dataset wasn't changed
679 but only classifier parameters, then kernel matrix does not
680 have to be computed.
681
682 Words of caution: classifier must be previously trained,
683 results always should first be compared to the results on not
684 'retrainable' classifier (without calling retrain). Some
685 additional checks are enabled if debug id 'CHECK_RETRAIN' is
686 enabled, to guard against obvious mistakes.
687
688 :Parameters:
689 kwargs
690 that is what _changedData gets updated with. So, smth like
691 ``(params=['C'], labels=True)`` if parameter C and labels
692 got changed
693 """
694
695
696 if __debug__:
697 if not self.params.retrainable:
698 raise RuntimeError, \
699 "Do not use re(train,predict) on non-retrainable %s" % \
700 self
701
702 if kwargs.has_key('params') or kwargs.has_key('kernel_params'):
703 raise ValueError, \
704 "Retraining for changed params not working yet"
705
706 self.__resetChangedData()
707
708
709 chd = self._changedData
710 ichd = self.__invalidatedChangedData
711
712 chd.update(kwargs)
713
714
715 for key, value in kwargs.iteritems():
716 if value:
717 ichd[key] = True
718 self.__changedData_isset = True
719
720
721 if __debug__ and 'CHECK_RETRAIN' in debug.active:
722 for key, data_ in (('traindata', dataset.samples),
723 ('labels', dataset.labels)):
724
725 if not chd[key] and not ichd.get(key, False):
726 if self.__wasDataChanged(key, data_, update=False):
727 raise RuntimeError, \
728 "Data %s found changed although wasn't " \
729 "labeled as such" % key
730
731
732
733
734
735
736 if __debug__ and 'CHECK_RETRAIN' in debug.active and self.trained \
737 and not self._changedData['traindata'] \
738 and self.__trained['traindata'].shape != dataset.samples.shape:
739 raise ValueError, "In retrain got dataset with %s size, " \
740 "whenever previousely was trained on %s size" \
741 % (dataset.samples.shape, self.__trained['traindata'].shape)
742 self.train(dataset)
743
744
    def repredict(self, data, **kwargs):
        """Helper to avoid check if data was changed actually changed

        Useful if classifier was (re)trained but with the same data
        (so just parameters were changed), so that it could be
        repredicted easily (on the same data as before) without
        recomputing for instance train/test kernel matrix. Should be
        used with caution and always compared to the results on not
        'retrainable' classifier. Some additional checks are enabled
        if debug id 'CHECK_RETRAIN' is enabled, to guard against
        obvious mistakes.

        :Parameters:
          data
            data which is conventionally given to predict
          kwargs
            that is what _changedData gets updated with. So, smth like
            ``(params=['C'], labels=True)`` if parameter C and labels
            got changed
        """
        if len(kwargs)>0:
            raise RuntimeError, \
                  "repredict for now should be used without params since " \
                  "it makes little sense to repredict if anything got changed"
        if __debug__ and not self.params.retrainable:
            raise RuntimeError, \
                  "Do not use retrain/repredict on non-retrainable classifiers"

        self.__resetChangedData()
        chd = self._changedData
        chd.update(**kwargs)
        self.__changedData_isset = True

        # verify the test data was indeed unchanged
        if __debug__ and 'CHECK_RETRAIN' in debug.active:
            for key, data_ in (('testdata', data),):
                if self.__wasDataChanged(key, data_, update=False):
                    raise RuntimeError, \
                          "Data %s found changed although wasn't " \
                          "labeled as such" % key

        # cheap shape check guards against a silently swapped test set
        if __debug__ and 'CHECK_RETRAIN' in debug.active \
               and not self._changedData['testdata'] \
               and self.__trained['testdata'].shape != data.shape:
            raise ValueError, "In repredict got dataset with %s size, " \
                  "whenever previously was trained on %s size" \
                  % (data.shape, self.__trained['testdata'].shape)

        return self.predict(data)
799
800
801
802
803
804