Package mvpa :: Package clfs :: Module base
[hide private]
[frames] | [no frames]

Source Code for Module mvpa.clfs.base

  1  #emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*- 
  2  #ex: set sts=4 ts=4 sw=4 et: 
  3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  4  # 
  5  #   See COPYING file distributed along with the PyMVPA package for the 
  6  #   copyright and license terms. 
  7  # 
  8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  9  """Base class for all classifiers. 
 10   
 11  At the moment, regressions are treated just as a special case of 
 12  classifier (or vise verse), so the same base class `Classifier` is 
 13  utilized for both kinds. 
 14  """ 
 15   
 16  __docformat__ = 'restructuredtext' 
 17   
 18  import numpy as N 
 19   
 20  from mvpa.support.copy import deepcopy 
 21   
 22  import time 
 23   
 24  from mvpa.misc.support import idhash 
 25  from mvpa.misc.state import StateVariable, Parametrized 
 26  from mvpa.misc.param import Parameter 
 27   
 28  from mvpa.clfs.transerror import ConfusionMatrix, RegressionStatistics 
 29   
 30  from mvpa.base import warning 
 31   
 32  if __debug__: 
 33      from mvpa.base import debug 
 34   
 35   
class Classifier(Parametrized):
    """Abstract classifier class to be inherited by all classifiers
    """

    # Kept separate from doc to don't pollute help(clf), especially if
    # we including help for the parent class
    _DEV__doc__ = """
    Required behavior:

    For every classifier is has to be possible to be instantiated without
    having to specify the training pattern.

    Repeated calls to the train() method with different training data have to
    result in a valid classifier, trained for the particular dataset.

    It must be possible to specify all classifier parameters as keyword
    arguments to the constructor.

    Recommended behavior:

    Derived classifiers should provide access to *values* -- i.e. that
    information that is finally used to determine the predicted class label.

    Michael: Maybe it works well if each classifier provides a 'values'
    state member. This variable is a list as long as and in same order
    as Dataset.uniquelabels (training data). Each item in the list
    corresponds to the likelyhood of a sample to belong to the
    respective class. However the semantics might differ between
    classifiers, e.g. kNN would probably store distances to class-
    neighbors, where PLR would store the raw function value of the
    logistic function. So in the case of kNN low is predictive and for
    PLR high is predictive. Don't know if there is the need to unify
    that.

    As the storage and/or computation of this information might be
    demanding its collection should be switchable and off be default.

    Nomenclature
     * predictions : corresponds to the quantized labels if classifier spits
       out labels by .predict()
     * values : might be different from predictions if a classifier's predict()
       makes a decision based on some internal value such as
       probability or a distance.
    """

    # Dict that contains the parameters of a classifier.
    # This shall provide an interface to plug generic parameter optimizer
    # on all classifiers (e.g. grid- or line-search optimizer)
    # A dictionary is used because Michael thinks that access by name is nicer.
    # Additionally Michael thinks ATM that additional information might be
    # necessary in some situations (e.g. reasonably predefined parameter range,
    # minimal iteration stepsize, ...), therefore the value to each key should
    # also be a dict or we should use mvpa.misc.param.Parameter'...

    # State variables: per-instance results/flags collected during
    # train()/predict().  Those enabled=False by default are only filled
    # when explicitly enabled by the user (they may be costly to keep).
    trained_labels = StateVariable(enabled=True,
        doc="Set of unique labels it has been trained on")

    trained_dataset = StateVariable(enabled=False,
        doc="The dataset it has been trained on")

    training_confusion = StateVariable(enabled=False,
        doc="Confusion matrix of learning performance")

    predictions = StateVariable(enabled=True,
        doc="Most recent set of predictions")

    values = StateVariable(enabled=True,
        doc="Internal classifier values the most recent " +
            "predictions are based on")

    training_time = StateVariable(enabled=True,
        doc="Time (in seconds) which took classifier to train")

    predicting_time = StateVariable(enabled=True,
        doc="Time (in seconds) which took classifier to predict")

    feature_ids = StateVariable(enabled=False,
        doc="Feature IDS which were used for the actual training.")

    # NOTE(review): this is a CLASS-level (shared) list; __init__ below
    # appends 'binary' to it for regression-capable classifiers, which
    # mutates the attribute for the whole class -- presumably subclasses
    # redefine their own _clf_internals; verify before relying on it.
    _clf_internals = []
    """Describes some specifics about the classifier -- is that it is
    doing regression for instance...."""

    regression = Parameter(False, allowedtype='bool',
        doc="""Either to use 'regression' as regression. By default any
        Classifier-derived class serves as a classifier, so regression
        does binary classification.""", index=1001)

    # TODO: make it available only for actually retrainable classifiers
    retrainable = Parameter(False, allowedtype='bool',
        doc="""Either to enable retraining for 'retrainable' classifier.""",
        index=1002)
129 - def __init__(self, **kwargs):
130 """Cheap initialization. 131 """ 132 Parametrized.__init__(self, **kwargs) 133 134 135 self.__trainednfeatures = None 136 """Stores number of features for which classifier was trained. 137 If None -- it wasn't trained at all""" 138 139 self._setRetrainable(self.params.retrainable, force=True) 140 141 if self.params.regression: 142 for statevar in [ "trained_labels"]: #, "training_confusion" ]: 143 if self.states.isEnabled(statevar): 144 if __debug__: 145 debug("CLF", 146 "Disabling state %s since doing regression, " % 147 statevar + "not classification") 148 self.states.disable(statevar) 149 self._summaryClass = RegressionStatistics 150 else: 151 self._summaryClass = ConfusionMatrix 152 if 'regression' in self._clf_internals: 153 # regressions are used as binary classifiers if not 154 # asked to perform regression explicitely 155 self._clf_internals.append('binary')
156 157 # deprecate 158 #self.__trainedidhash = None 159 #"""Stores id of the dataset on which it was trained to signal 160 #in trained() if it was trained already on the same dataset""" 161 162
163 - def __str__(self):
164 if __debug__ and 'CLF_' in debug.active: 165 return "%s / %s" % (repr(self), super(Classifier, self).__str__()) 166 else: 167 return repr(self)
168
    def __repr__(self, prefixes=[]):
        """Delegate representation to the parent class (`Parametrized`).

        `prefixes` is passed through so subclasses can prepend extra
        constructor-argument strings.
        """
        # NOTE(review): mutable default argument -- harmless only as long
        # as neither this method nor the super implementation mutates
        # `prefixes`; TODO confirm in Parametrized.__repr__
        return super(Classifier, self).__repr__(prefixes=prefixes)
171 172
    def _pretrain(self, dataset):
        """Functionality prior to training

        Resets state variables, and -- for retrainable classifiers --
        figures out which pieces (training data, labels, parameters)
        actually changed since the previous training, so _train can skip
        recomputation where possible.

        :Parameters:
          dataset : Dataset
            Data the classifier is about to be trained on
        """
        # So we reset all state variables and may be free up some memory
        # explicitly
        params = self.params
        if not params.retrainable:
            self.untrain()
        else:
            # just reset the states, do not untrain
            self.states.reset()
            if not self.__changedData_isset:
                # nothing was flagged explicitly (e.g. via retrain()),
                # so detect changes via idhashes
                self.__resetChangedData()
                _changedData = self._changedData
                __idhashes = self.__idhashes
                __invalidatedChangedData = self.__invalidatedChangedData

                # if we don't know what was changed we need to figure
                # them out
                if __debug__:
                    debug('CLF_', "IDHashes are %s" % (__idhashes))

                # Look at the data if any was changed
                for key, data_ in (('traindata', dataset.samples),
                                   ('labels', dataset.labels)):
                    _changedData[key] = self.__wasDataChanged(key, data_)
                    # if those idhashes were invalidated by retraining
                    # we need to adjust _changedData accordingly
                    if __invalidatedChangedData.get(key, False):
                        if __debug__ and not _changedData[key]:
                            debug('CLF_', 'Found that idhash for %s was '
                                  'invalidated by retraining' % key)
                        _changedData[key] = True

                # Look at the parameters
                for col in self._paramscols:
                    changedParams = self._collections[col].whichSet()
                    if len(changedParams):
                        _changedData[col] = changedParams

                self.__invalidatedChangedData = {} # reset it on training

                if __debug__:
                    debug('CLF_', "Obtained _changedData is %s"
                          % (self._changedData))

        if not params.regression and 'regression' in self._clf_internals \
           and not self.states.isEnabled('trained_labels'):
            # if classifier internally does regression we need to have
            # labels it was trained on
            if __debug__:
                debug("CLF", "Enabling trained_labels state since it is needed")
            self.states.enable('trained_labels')
226 227
228 - def _posttrain(self, dataset):
229 """Functionality post training 230 231 For instance -- computing confusion matrix 232 :Parameters: 233 dataset : Dataset 234 Data which was used for training 235 """ 236 if self.states.isEnabled('trained_labels'): 237 self.trained_labels = dataset.uniquelabels 238 239 self.trained_dataset = dataset 240 241 # needs to be assigned first since below we use predict 242 self.__trainednfeatures = dataset.nfeatures 243 244 if __debug__ and 'CHECK_TRAINED' in debug.active: 245 self.__trainedidhash = dataset.idhash 246 247 if self.states.isEnabled('training_confusion') and \ 248 not self.states.isSet('training_confusion'): 249 # we should not store predictions for training data, 250 # it is confusing imho (yoh) 251 self.states._changeTemporarily( 252 disable_states=["predictions"]) 253 if self.params.retrainable: 254 # we would need to recheck if data is the same, 255 # XXX think if there is a way to make this all 256 # efficient. For now, probably, retrainable 257 # classifiers have no chance but not to use 258 # training_confusion... sad 259 self.__changedData_isset = False 260 predictions = self.predict(dataset.samples) 261 self.states._resetEnabledTemporarily() 262 self.training_confusion = self._summaryClass( 263 targets=dataset.labels, 264 predictions=predictions) 265 266 try: 267 self.training_confusion.labels_map = dataset.labels_map 268 except: 269 pass 270 271 if self.states.isEnabled('feature_ids'): 272 self.feature_ids = self._getFeatureIds()
273 274
275 - def _getFeatureIds(self):
276 """Virtual method to return feature_ids used while training 277 278 Is not intended to be called anywhere but from _posttrain, 279 thus classifier is assumed to be trained at this point 280 """ 281 # By default all features are used 282 return range(self.__trainednfeatures)
283 284
285 - def summary(self):
286 """Providing summary over the classifier""" 287 288 s = "Classifier %s" % self 289 states = self.states 290 states_enabled = states.enabled 291 292 if self.trained: 293 s += "\n trained" 294 if states.isSet('training_time'): 295 s += ' in %.3g sec' % states.training_time 296 s += ' on data with' 297 if states.isSet('trained_labels'): 298 s += ' labels:%s' % list(states.trained_labels) 299 if states.isSet('trained_dataset'): 300 td = states.trained_dataset 301 s += ' #samples:%d #chunks:%d' % (td.nsamples, 302 len(td.uniquechunks)) 303 s += " #features:%d" % self.__trainednfeatures 304 if states.isSet('feature_ids'): 305 s += ", used #features:%d" % len(states.feature_ids) 306 if states.isSet('training_confusion'): 307 s += ", training error:%.3g" % states.training_confusion.error 308 else: 309 s += "\n not yet trained" 310 311 if len(states_enabled): 312 s += "\n enabled states:%s" % ', '.join([str(states[x]) 313 for x in states_enabled]) 314 return s
315 316
317 - def clone(self):
318 """Create full copy of the classifier. 319 320 It might require classifier to be untrained first due to 321 present SWIG bindings. 322 323 TODO: think about proper re-implementation, without enrollment of deepcopy 324 """ 325 try: 326 return deepcopy(self) 327 except: 328 self.untrain() 329 return deepcopy(self)
330 331
    def _train(self, dataset):
        """Function to be actually overridden in derived classes

        Abstract hook: `train()` handles pre/post bookkeeping and calls
        this with the dataset; subclasses implement the actual learning.
        """
        raise NotImplementedError
336 337
    def train(self, dataset):
        """Train classifier on a dataset

        Shouldn't be overridden in subclasses unless explicitly needed
        to do so

        :Parameters:
          dataset : Dataset
            Data to train on; forwarded to `_pretrain`, `_train`
            and `_posttrain`

        :Returns:
          Whatever `_train` returned (None if the dataset had no
          features and training was skipped)
        """
        if __debug__:
            debug("CLF", "Training classifier %(clf)s on dataset %(dataset)s",
                  msgargs={'clf':self, 'dataset':dataset})

        self._pretrain(dataset)

        # remember the time when started training
        t0 = time.time()

        if dataset.nfeatures > 0:
            result = self._train(dataset)
        else:
            # degenerate dataset: warn and skip the actual training
            warning("Trying to train on dataset with no features present")
            if __debug__:
                debug("CLF",
                      "No features present for training, no actual training " \
                      "is called")
            result = None

        # stored via the training_time state variable
        self.training_time = time.time() - t0
        self._posttrain(dataset)
        return result
366 367
    def _prepredict(self, data):
        """Functionality prior prediction

        Validates that the classifier is trained and that `data` has the
        same number of features as the training data (unless the
        classifier declares 'notrain2predict').  For retrainable
        classifiers also records whether the test data changed.
        """
        if not ('notrain2predict' in self._clf_internals):
            # check if classifier was trained if that is needed
            if not self.trained:
                raise ValueError, \
                      "Classifier %s wasn't yet trained, therefore can't " \
                      "predict" % self
            nfeatures = data.shape[1]
            # check if number of features is the same as in the data
            # it was trained on
            if nfeatures != self.__trainednfeatures:
                raise ValueError, \
                      "Classifier %s was trained on data with %d features, " % \
                      (self, self.__trainednfeatures) + \
                      "thus can't predict for %d features" % nfeatures


        if self.params.retrainable:
            if not self.__changedData_isset:
                # nothing flagged explicitly -- detect via idhash
                self.__resetChangedData()
                _changedData = self._changedData
                _changedData['testdata'] = \
                    self.__wasDataChanged('testdata', data)
                if __debug__:
                    debug('CLF_', "prepredict: Obtained _changedData is %s"
                          % (_changedData))
396 397
398 - def _postpredict(self, data, result):
399 """Functionality after prediction is computed 400 """ 401 self.predictions = result 402 if self.params.retrainable: 403 self.__changedData_isset = False
404
    def _predict(self, data):
        """Actual prediction

        Abstract hook: `predict()` handles validation, timing and
        regression-to-label conversion; subclasses implement this.
        """
        raise NotImplementedError
409 410
    def predict(self, data):
        """Predict classifier on data

        Shouldn't be overridden in subclasses unless explicitly needed
        to do so. Also subclasses trying to call super class's predict
        should call _predict if within _predict instead of predict()
        since otherwise it would loop

        :Parameters:
          data
            Samples to predict on; converted via `N.asarray`, so it is
            expected to be 2D (samples x features) -- `data.shape[0]`
            and `data.shape[1]` are used below

        :Returns:
          Predicted labels (or values, for regression mode)
        """
        data = N.asarray(data)
        if __debug__:
            debug("CLF", "Predicting classifier %(clf)s on data %(data)s",
                  msgargs={'clf':self, 'data':data.shape})

        # remember the time when started computing predictions
        t0 = time.time()

        states = self.states
        # to assure that those are reset (could be set due to testing
        # post-training)
        states.reset(['values', 'predictions'])

        self._prepredict(data)

        if self.__trainednfeatures > 0 \
               or 'notrain2predict' in self._clf_internals:
            result = self._predict(data)
        else:
            # trained on an empty feature set: produce placeholder output
            warning("Trying to predict using classifier trained on no features")
            if __debug__:
                debug("CLF",
                      "No features were present for training, prediction is " \
                      "bogus")
            result = [None]*data.shape[0]

        states.predicting_time = time.time() - t0

        if 'regression' in self._clf_internals and not self.params.regression:
            # We need to convert regression values into labels
            # XXX unify may be labels -> internal_labels conversion.
            #if len(self.trained_labels) != 2:
            #    raise RuntimeError, "Ask developer to implement for " \
            #        "multiclass mapping from regression into classification"

            # must be N.array so we copy it to assign labels directly
            # into labels, or should we just recreate "result"???
            result_ = N.array(result)
            if states.isEnabled('values'):
                # values could be set by now so assigning 'result' would
                # be misleading
                if not states.isSet('values'):
                    states.values = result_.copy()
                else:
                    # it might be the values are pointing to result at
                    # the moment, so lets assure this silly way that
                    # they do not overlap
                    states.values = states.values.copy()

            # quantize each regression output to the nearest trained label
            trained_labels = self.trained_labels
            for i, value in enumerate(result):
                dists = N.abs(value - trained_labels)
                result[i] = trained_labels[N.argmin(dists)]

            if __debug__:
                debug("CLF_", "Converted regression result %(result_)s "
                      "into labels %(result)s for %(self_)s",
                      msgargs={'result_':result_, 'result':result,
                               'self_': self})

        self._postpredict(data, result)
        return result
481 482 # deprecate ???
483 - def isTrained(self, dataset=None):
484 """Either classifier was already trained. 485 486 MUST BE USED WITH CARE IF EVER""" 487 if dataset is None: 488 # simply return if it was trained on anything 489 return not self.__trainednfeatures is None 490 else: 491 res = (self.__trainednfeatures == dataset.nfeatures) 492 if __debug__ and 'CHECK_TRAINED' in debug.active: 493 res2 = (self.__trainedidhash == dataset.idhash) 494 if res2 != res: 495 raise RuntimeError, \ 496 "isTrained is weak and shouldn't be relied upon. " \ 497 "Got result %b although comparing of idhash says %b" \ 498 % (res, res2) 499 return res
500 501
    def _regressionIsBogus(self):
        """Some classifiers like BinaryClassifier can't be used for
        regression

        Raises ValueError if the `regression` parameter is enabled on a
        classifier that declares regression meaningless.
        """

        if self.params.regression:
            raise ValueError, "Regression mode is meaningless for %s" % \
                  self.__class__.__name__ + " thus don't enable it"
    @property
    def trained(self):
        """Either classifier was already trained"""
        # convenience read-only alias for isTrained() without a dataset
        return self.isTrained()
515
    def untrain(self):
        """Reset trained state

        Clears the trained-features marker and resets the parent
        (`Parametrized`) state via `reset()`.
        """
        self.__trainednfeatures = None
        # probably not needed... retrainable shouldn't be fully untrained
        # or should be???
        #if self.params.retrainable:
        #    # ??? don't duplicate the code ;-)
        #    self.__idhashes = {'traindata': None, 'labels': None,
        #                       'testdata': None, 'testtraindata': None}
        super(Classifier, self).reset()
526 527
    def getSensitivityAnalyzer(self, **kwargs):
        """Factory method to return an appropriate sensitivity analyzer for
        the respective classifier.

        Abstract -- subclasses that support sensitivity analysis
        override this; `kwargs` are forwarded to the analyzer.
        """
        raise NotImplementedError
532 533 534 # 535 # Methods which are needed for retrainable classifiers 536 #
    def _setRetrainable(self, value, force=False):
        """Assign value of retrainable parameter

        If retrainable flag is to be changed, classifier has to be
        untrained.  Also internal attributes such as _changedData,
        __changedData_isset, and __idhashes should be initialized if
        it becomes retrainable

        :Parameters:
          value : bool
            Target value for the 'retrainable' parameter
          force : bool
            If True, perform (re)initialization even when the value
            did not change (used from __init__)
        """
        pretrainable = self.params['retrainable']
        if (force or value != pretrainable.value) \
               and 'retrainable' in self._clf_internals:
            if __debug__:
                debug("CLF_", "Setting retrainable to %s" % value)
            if 'meta' in self._clf_internals:
                warning("Retrainability is not yet crafted/tested for "
                        "meta classifiers. Unpredictable behavior might occur")
            # assure that we don't drag anything behind
            if self.trained:
                self.untrain()
            states = self.states
            if not value and states.isKnown('retrained'):
                states.remove('retrained')
                states.remove('repredicted')
            if value:
                # NOTE(review): this inner check can never fire -- the
                # enclosing branch already requires 'retrainable' in
                # _clf_internals; kept as-is (dead code) for fidelity
                if not 'retrainable' in self._clf_internals:
                    warning("Setting of flag retrainable for %s has no effect"
                            " since classifier has no such capability. It would"
                            " just lead to resources consumption and slowdown"
                            % self)
                states.add(StateVariable(enabled=True,
                        name='retrained',
                        doc="Either retrainable classifier was retrained"))
                states.add(StateVariable(enabled=True,
                        name='repredicted',
                        doc="Either retrainable classifier was repredicted"))

            pretrainable.value = value

            # if retrainable we need to keep track of things
            if value:
                self.__idhashes = {'traindata': None, 'labels': None,
                                   'testdata': None} #, 'testtraindata': None}
                if __debug__ and 'CHECK_RETRAIN' in debug.active:
                    # ??? it is not clear though if idhash is faster than
                    # simple comparison of (dataset != __traineddataset).any(),
                    # but if we like to get rid of __traineddataset then we
                    # should use idhash anyways
                    self.__trained = self.__idhashes.copy() # just same Nones
                self.__resetChangedData()
                self.__invalidatedChangedData = {}
            elif 'retrainable' in self._clf_internals:
                # dropping retrainability: release the tracking structures
                #self.__resetChangedData()
                self.__changedData_isset = False
                self._changedData = None
                self.__idhashes = None
                if __debug__ and 'CHECK_RETRAIN' in debug.active:
                    self.__trained = None
594
595 - def __resetChangedData(self):
596 """For retrainable classifier we keep track of what was changed 597 This function resets that dictionary 598 """ 599 if __debug__: 600 debug('CLF_', 601 'Retrainable: resetting flags on either data was changed') 602 keys = self.__idhashes.keys() + self._paramscols 603 # we might like to just reinit values to False??? 604 #_changedData = self._changedData 605 #if isinstance(_changedData, dict): 606 # for key in _changedData.keys(): 607 # _changedData[key] = False 608 self._changedData = dict(zip(keys, [False]*len(keys))) 609 self.__changedData_isset = False
610 611
    def __wasDataChanged(self, key, entry, update=True):
        """Check if given entry was changed from what known prior.

        If so -- store only the ones needed for retrainable beastie

        :Parameters:
          key : str
            Which tracked piece ('traindata', 'labels', 'testdata')
          entry
            Current data for that piece; its idhash is compared against
            the stored one
          update : bool
            If True, remember the new idhash (and, under CHECK_RETRAIN,
            the data itself) for the next comparison

        :Returns:
          bool -- whether the entry differs from the stored one
        """
        idhash_ = idhash(entry)
        __idhashes = self.__idhashes

        changed = __idhashes[key] != idhash_
        if __debug__ and 'CHECK_RETRAIN' in debug.active:
            # paranoid cross-check: compare actual values as well and
            # fail loudly if idhash missed a change
            __trained = self.__trained
            changed2 = entry != __trained[key]
            if isinstance(changed2, N.ndarray):
                changed2 = changed2.any()
            if changed != changed2 and not changed:
                raise RuntimeError, \
                  'idhash found to be weak for %s. Though hashid %s!=%s %s, '\
                  'values %s!=%s %s' % \
                  (key, idhash_, __idhashes[key], changed,
                   entry, __trained[key], changed2)
            if update:
                __trained[key] = entry

        if __debug__ and changed:
            debug('CLF_', "Changed %s from %s to %s.%s"
                      % (key, __idhashes[key], idhash_,
                         ('','updated')[int(update)]))
        if update:
            __idhashes[key] = idhash_

        return changed
643 644 645 # def __updateHashIds(self, key, data): 646 # """Is twofold operation: updates hashid if was said that it changed. 647 # 648 # or if it wasn't said that data changed, but CHECK_RETRAIN and it found 649 # to be changed -- raise Exception 650 # """ 651 # 652 # check_retrain = __debug__ and 'CHECK_RETRAIN' in debug.active 653 # chd = self._changedData 654 # 655 # # we need to updated idhashes 656 # if chd[key] or check_retrain: 657 # keychanged = self.__wasDataChanged(key, data) 658 # if check_retrain and keychanged and not chd[key]: 659 # raise RuntimeError, \ 660 # "Data %s found changed although wasn't " \ 661 # "labeled as such" % key 662 663 664 # 665 # Additional API which is specific only for retrainable classifiers. 666 # For now it would just puke if asked from not retrainable one. 667 # 668 # Might come useful and efficient for statistics testing, so if just 669 # labels of dataset changed, then 670 # self.retrain(dataset, labels=True) 671 # would cause efficient retraining (no kernels recomputed etc) 672 # and subsequent self.repredict(data) should be also quite fase ;-) 673
    def retrain(self, dataset, **kwargs):
        """Helper to avoid check if data was changed actually changed

        Useful if just some aspects of classifier were changed since
        its previous training. For instance if dataset wasn't changed
        but only classifier parameters, then kernel matrix does not
        have to be computed.

        Words of caution: classifier must be previously trained,
        results always should first be compared to the results on not
        'retrainable' classifier (without calling retrain). Some
        additional checks are enabled if debug id 'CHECK_RETRAIN' is
        enabled, to guard against obvious mistakes.

        :Parameters:
          kwargs
            that is what _changedData gets updated with. So, smth like
            ``(params=['C'], labels=True)`` if parameter C and labels
            got changed
        """
        # Note that it also demolishes anything for repredicting,
        # which should be ok in most of the cases
        if __debug__:
            if not self.params.retrainable:
                raise RuntimeError, \
                      "Do not use re(train,predict) on non-retrainable %s" % \
                      self

            if kwargs.has_key('params') or kwargs.has_key('kernel_params'):
                raise ValueError, \
                      "Retraining for changed params not working yet"

        self.__resetChangedData()

        # local bindings
        chd = self._changedData
        ichd = self.__invalidatedChangedData

        chd.update(kwargs)
        # mark for future 'train()' items which are explicitely
        # mentioned as changed
        for key, value in kwargs.iteritems():
            if value:
                ichd[key] = True
        self.__changedData_isset = True

        # To check if we are not fooled
        if __debug__ and 'CHECK_RETRAIN' in debug.active:
            for key, data_ in (('traindata', dataset.samples),
                               ('labels', dataset.labels)):
                # so it wasn't told to be invalid
                if not chd[key] and not ichd.get(key, False):
                    if self.__wasDataChanged(key, data_, update=False):
                        raise RuntimeError, \
                              "Data %s found changed although wasn't " \
                              "labeled as such" % key

        # TODO: parameters of classifiers... for now there is explicit
        # 'forbidance' above

        # Below check should be superseeded by check above, thus never occur.
        # remove later on ???
        if __debug__ and 'CHECK_RETRAIN' in debug.active and self.trained \
               and not self._changedData['traindata'] \
               and self.__trained['traindata'].shape != dataset.samples.shape:
            raise ValueError, "In retrain got dataset with %s size, " \
                  "whenever previousely was trained on %s size" \
                  % (dataset.samples.shape, self.__trained['traindata'].shape)
        self.train(dataset)
743 744
    def repredict(self, data, **kwargs):
        """Helper to avoid check if data was changed actually changed

        Useful if classifier was (re)trained but with the same data
        (so just parameters were changed), so that it could be
        repredicted easily (on the same data as before) without
        recomputing for instance train/test kernel matrix. Should be
        used with caution and always compared to the results on not
        'retrainable' classifier. Some additional checks are enabled
        if debug id 'CHECK_RETRAIN' is enabled, to guard against
        obvious mistakes.

        :Parameters:
          data
            data which is conventionally given to predict
          kwargs
            that is what _changedData gets updated with. So, smth like
            ``(params=['C'], labels=True)`` if parameter C and labels
            got changed
        """
        if len(kwargs)>0:
            raise RuntimeError, \
                  "repredict for now should be used without params since " \
                  "it makes little sense to repredict if anything got changed"
        if __debug__ and not self.params.retrainable:
            raise RuntimeError, \
                  "Do not use retrain/repredict on non-retrainable classifiers"

        self.__resetChangedData()
        chd = self._changedData
        chd.update(**kwargs)
        self.__changedData_isset = True


        # check if we are attempted to perform on the same data
        if __debug__ and 'CHECK_RETRAIN' in debug.active:
            for key, data_ in (('testdata', data),):
                # so it wasn't told to be invalid
                #if not chd[key]:# and not ichd.get(key, False):
                if self.__wasDataChanged(key, data_, update=False):
                    raise RuntimeError, \
                          "Data %s found changed although wasn't " \
                          "labeled as such" % key

        # Should be superseded by above
        # remove in future???
        if __debug__ and 'CHECK_RETRAIN' in debug.active \
               and not self._changedData['testdata'] \
               and self.__trained['testdata'].shape != data.shape:
            raise ValueError, "In repredict got dataset with %s size, " \
                  "whenever previously was trained on %s size" \
                  % (data.shape, self.__trained['testdata'].shape)

        return self.predict(data)
799 800 801 # TODO: callback into retrainable parameter 802 #retrainable = property(fget=_getRetrainable, fset=_setRetrainable, 803 # doc="Specifies either classifier should be retrainable") 804