Package mvpa :: Package tests :: Module test_transerror
[hide private]
[frames] | [no frames]

Source Code for Module mvpa.tests.test_transerror

  1  #emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*- 
  2  #ex: set sts=4 ts=4 sw=4 et: 
  3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  4  # 
  5  #   See COPYING file distributed along with the PyMVPA package for the 
  6  #   copyright and license terms. 
  7  # 
  8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  9  """Unit tests for PyMVPA classifier cross-validation""" 
 10   
 11  import unittest 
 12  from mvpa.support.copy import copy 
 13   
 14  from mvpa.base import externals 
 15  from mvpa.datasets import Dataset 
 16  from mvpa.datasets.splitters import OddEvenSplitter 
 17   
 18  from mvpa.clfs.meta import MulticlassClassifier 
 19  from mvpa.clfs.transerror import \ 
 20       TransferError, ConfusionMatrix, ConfusionBasedError 
 21  from mvpa.algorithms.cvtranserror import CrossValidatedTransferError 
 22   
 23  from mvpa.clfs.stats import MCNullDist 
 24   
 25  from mvpa.misc.exceptions import UnknownStateError 
 26   
 27  from tests_warehouse import datasets, sweepargs 
 28  from tests_warehouse_clfs import * 
 29   
class ErrorsTests(unittest.TestCase):
    """Tests for ConfusionMatrix, TransferError and related error measures."""

    def testConfusionMatrix(self):
        """ConfusionMatrix construction, accumulation, arithmetic, printing."""
        reg = [1, 1, 1, 2, 2, 2, 3, 3, 3]
        regl = [1, 2, 1, 2, 2, 2, 3, 2, 1]
        correct_cm = [[2, 0, 1], [1, 3, 1], [0, 0, 1]]
        # Check if we are ok with any input type - either list, or N.array,
        # or tuple
        for t in [reg, tuple(reg), list(reg), N.array(reg)]:
            for p in [regl, tuple(regl), list(regl), N.array(regl)]:
                cm = ConfusionMatrix(targets=t, predictions=p)
                # check table content
                self.failUnless((cm.matrix == correct_cm).all())

        # Do a bit more thorough checking
        cm = ConfusionMatrix()
        # No samples yet -- accessing percentCorrect must raise
        self.failUnlessRaises(ZeroDivisionError,
                              lambda x: x.percentCorrect, cm)

        cm.add(reg, regl)

        self.failUnlessEqual(len(cm.sets), 1,
            msg="Should have a single set so far")
        self.failUnlessEqual(cm.matrix.shape, (3,3),
            msg="should be square matrix (len(reglabels) x len(reglabels)")

        # ConfusionMatrix must complain if the number of samples differs
        self.failUnlessRaises(ValueError, cm.add, reg, N.array([1]))

        # check table content
        self.failUnless((cm.matrix == correct_cm).all())

        # lets add with new labels (not yet known)
        cm.add(reg, N.array([1, 4, 1, 2, 2, 2, 4, 2, 1]))

        self.failUnlessEqual(cm.labels, [1,2,3,4],
                             msg="We should have gotten 4th label")

        matrices = cm.matrices          # separate CM per each given set
        self.failUnlessEqual(len(matrices), 2,
                             msg="Have gotten two splits")

        self.failUnless(
            (matrices[0].matrix + matrices[1].matrix == cm.matrix).all(),
            msg="Total votes should match the sum across split CMs")

        # check pretty print
        # just a silly test to make sure that printing works
        self.failUnless(len(cm.asstring(
            header=True, summary=True,
            description=True)) > 100)
        self.failUnless(len(str(cm)) > 100)
        # and that it knows some parameters for printing
        self.failUnless(len(cm.asstring(summary=True,
                                        header=False)) > 100)

        # lets check iadd -- just itself to itself
        cm += cm
        self.failUnlessEqual(len(cm.matrices), 4, msg="Must be 4 sets now")

        # lets check add -- just itself to itself
        cm2 = cm + cm
        self.failUnlessEqual(len(cm2.matrices), 8, msg="Must be 8 sets now")
        self.failUnlessEqual(cm2.percentCorrect, cm.percentCorrect,
                             msg="Percent of corrrect should remain the same ;-)")

        self.failUnlessEqual(cm2.error, 1.0-cm.percentCorrect/100.0,
                             msg="Test if we get proper error value")


    def testConfusionMatrixACC(self):
        """Half-correct binary predictions must be reported as 50% accuracy."""
        reg = [0, 0, 1, 1]
        regl = [1, 0, 1, 0]
        cm = ConfusionMatrix(targets=reg, predictions=regl)
        self.failUnless('ACC% 50' in str(cm))


    # NOTE(review): the original header line (source line 107) of this test
    # was lost; the method name is reconstructed -- confirm against upstream.
    def testConfusionMatrixWithMappings(self):
        """ConfusionMatrix with a labels_map lists every mapped label name."""
        reg = [1, 1, 1, 2, 2, 2, 3, 3, 3]
        regl = [1, 2, 1, 2, 2, 2, 3, 2, 1]
        correct_cm = [[2, 0, 1], [1, 3, 1], [0, 0, 1]]
        lm = {'apple':1, 'orange':2, 'shitty apple':1, 'candy':3}
        cm = ConfusionMatrix(targets=reg, predictions=regl,
                             labels_map=lm)
        # check table content
        self.failUnless((cm.matrix == correct_cm).all())
        # assure that all labels are somewhere listed ;-)
        s = str(cm)
        for l in lm.keys():
            self.failUnless(l in s)



    @sweepargs(l_clf=clfswh['linear', 'svm'])
    def testConfusionBasedError(self, l_clf):
        """ConfusionBasedError must equal TransferError on the training data."""
        train = datasets['uni2medium_train']
        # to check if we fail to classify for 3 labels
        test3 = datasets['uni3medium_train']
        err = ConfusionBasedError(clf=l_clf)
        terr = TransferError(clf=l_clf)

        # Shouldn't be able to access the state before training
        self.failUnlessRaises(UnknownStateError, err, None)

        l_clf.train(train)
        self.failUnlessEqual(err(None), terr(train),
            msg="ConfusionBasedError should be equal to TransferError on" +
                " traindataset")

        # this will print nasty WARNING but it is ok -- it is just checking code
        # NB warnings are not printed while doing whole testing
        self.failIf(terr(test3) is None)

        # try copying the beast -- only checks that copying does not raise
        copy(terr)


    @sweepargs(l_clf=clfswh['linear', 'svm'])
    def testNullDistProb(self, l_clf):
        """A classifier trained on signal must yield a significant error."""
        train = datasets['uni2medium']

        # define class to estimate NULL distribution of errors
        # use left tail of the distribution since we use MeanMatchFx as error
        # function and lower is better
        terr = TransferError(clf=l_clf,
                             null_dist=MCNullDist(permutations=10,
                                                  tail='left'))

        # check reasonable error range
        err = terr(train, train)
        self.failUnless(err < 0.4)

        # check that the result is highly significant since we know that the
        # data has signal
        null_prob = terr.null_prob
        self.failUnless(null_prob < 0.01,
            msg="Failed to check that the result is highly significant "
                "(got %f) since we know that the data has signal"
                % null_prob)


    @sweepargs(l_clf=clfswh['linear', 'svm'])
    def testPerSampleError(self, l_clf):
        """TransferError must record one boolean error flag per sample."""
        train = datasets['uni2medium']
        terr = TransferError(clf=l_clf, enable_states=['samples_error'])
        terr(train, train)
        se = terr.samples_error

        # one error per sample
        self.failUnless(len(se) == train.nsamples)
        # for this simple test it can only be correct or misclassified
        # (boolean)
        self.failUnless(
            N.sum(N.array(se.values(), dtype='float')
                  - N.array(se.values(), dtype='b')) == 0)


    @sweepargs(clf=clfswh['multiclass'])
    def testAUC(self, clf):
        """Test AUC computation
        """
        if isinstance(clf, MulticlassClassifier):
            # TODO: handle those values correctly
            return
        clf.states._changeTemporarily(enable_states = ['values'])
        try:
            # uni2 dataset with reordered labels
            ds2 = datasets['uni2small'].copy()
            ds2.labels = 1 - ds2.labels     # revert labels
            # same with uni3: rotate each label to the next unique label
            ds3 = datasets['uni3small'].copy()
            ul = ds3.uniquelabels
            nl = ds3.labels.copy()
            for l in xrange(3):
                nl[ds3.labels == ul[l]] = ul[(l+1)%3]
            ds3.labels = nl
            for ds in [datasets['uni2small'], ds2,
                       datasets['uni3small'], ds3]:
                cv = CrossValidatedTransferError(
                    TransferError(clf),
                    OddEvenSplitter(),
                    enable_states=['confusion', 'training_confusion'])
                cv(ds)
                stats = cv.confusion.stats
                Nlabels = len(ds.uniquelabels)
                # so we at least do slightly above chance
                self.failUnless(stats['ACC'] > 1.2 / Nlabels)
                auc = stats['AUC']
                # NaN never compares reliably by identity (an element taken
                # from an array is a fresh scalar), so test it explicitly.
                if (Nlabels == 2) or (Nlabels > 2 and not N.isnan(auc[0])):
                    mauc = N.min(stats['AUC'])
                    if cfg.getboolean('tests', 'labile', default='yes'):
                        self.failUnless(mauc > 0.55,
                             msg='All AUCs must be above chance. Got minimal '
                                 'AUC=%.2g among %s' % (mauc, stats['AUC']))
        finally:
            # Always undo the temporary state change: clf comes from the
            # clfswh warehouse and may be reused by other tests.
            clf.states._resetEnabledTemporarily()



    def testConfusionPlot(self):
        """Build, print and (if pylab is available) plot a ConfusionMatrix.

        Uses recorded (targets, predictions) pairs from an existing cell
        dataset together with a labels_map of named stimuli.
        """
        # Short aliases keep the recorded data literal compact.
        array = N.array
        uint8 = N.uint8
        sets = [
            (array([47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44], dtype=uint8),
             array([40, 39, 47, 43, 45, 41, 44, 41, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 46,
                    45, 38, 44, 39, 46, 38, 39, 39, 38, 43, 45, 41, 44, 40, 46, 42, 38,
                    40, 47, 43, 45, 41, 44, 40, 46, 42, 38, 39, 40, 43, 45, 41, 44, 39,
                    46, 42, 47, 38, 38, 43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 43, 45,
                    41, 44, 40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 47, 38,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 47, 43, 45, 41, 44, 40, 46,
                    42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 39, 39, 38, 43, 45, 41,
                    44, 47, 46, 42, 47, 38, 39, 43, 45, 40, 44, 40, 46, 42, 47, 39, 40,
                    43, 45, 41, 44, 38, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 41,
                    47, 39, 38, 46, 45, 41, 44, 40, 46, 42, 40, 38, 38, 43, 45, 41, 44,
                    40, 45, 42, 47, 39, 39, 43, 45, 41, 44, 38, 46, 42, 47, 38, 42, 43,
                    45, 41, 44, 39, 46, 42, 39, 39, 39, 47, 45, 41, 44], dtype=uint8)),
            (array([40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43], dtype=uint8),
             array([40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 47, 46, 42, 47, 39, 40, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 39, 46, 42, 47, 47, 47, 43, 45, 41, 44, 40,
                    46, 42, 43, 39, 38, 43, 45, 41, 44, 38, 38, 42, 38, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 40, 38, 43, 45, 41, 44, 40, 40, 42, 47, 40,
                    40, 43, 45, 41, 44, 38, 38, 42, 47, 38, 38, 47, 45, 41, 44, 40, 46,
                    42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 47, 47, 39, 43, 45, 41,
                    44, 40, 46, 42, 39, 39, 42, 43, 45, 41, 44, 40, 46, 42, 47, 39, 39,
                    43, 45, 41, 44, 47, 46, 42, 40, 39, 39, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 40, 44, 40, 46, 42, 47, 39, 39, 43, 45, 41, 44,
                    38, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 46, 47, 38, 39, 43,
                    45, 41, 44, 40, 46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 39,
                    39, 38, 47, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43], dtype=uint8)),
            (array([45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47], dtype=uint8),
             array([45, 41, 44, 40, 46, 42, 47, 39, 46, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 39, 43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 43, 45, 41, 44, 40,
                    46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 47, 39, 43, 43, 45,
                    40, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47,
                    40, 43, 45, 41, 44, 40, 47, 42, 38, 47, 38, 43, 45, 41, 44, 40, 40,
                    42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 38, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 42, 47, 40, 38,
                    43, 45, 41, 44, 40, 46, 38, 38, 39, 38, 43, 45, 41, 44, 39, 46, 42,
                    47, 40, 39, 43, 45, 38, 44, 38, 46, 42, 47, 47, 40, 43, 45, 41, 44,
                    40, 40, 42, 47, 40, 38, 43, 39, 41, 44, 41, 46, 42, 39, 39, 38, 38,
                    45, 41, 44, 38, 46, 40, 46, 46, 46, 43, 45, 38, 44, 40, 46, 42, 39,
                    39, 45, 43, 45, 41, 44, 38, 46, 42, 38, 39, 39, 43, 45, 41, 38, 40,
                    46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 40], dtype=uint8)),
            (array([39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40], dtype=uint8),
             array([39, 38, 43, 45, 41, 44, 40, 46, 38, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 41, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 38, 43, 47, 38, 38, 43, 45, 41, 44, 39, 46, 42, 39, 39,
                    38, 43, 45, 41, 44, 43, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 39, 38, 38, 43, 45, 40,
                    44, 47, 46, 38, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 38, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 38, 39, 38, 43, 45, 47, 44, 45, 46, 42,
                    38, 39, 41, 43, 45, 41, 44, 38, 38, 42, 39, 40, 40, 43, 45, 41, 39,
                    40, 46, 42, 47, 39, 40, 43, 45, 41, 44, 40, 47, 42, 47, 38, 38, 43,
                    45, 41, 44, 47, 46, 42, 47, 40, 47, 43, 45, 41, 44, 40, 46, 42, 47,
                    38, 39, 43, 45, 41, 44, 40, 46, 42, 39, 38, 43, 45, 46, 44, 38, 46,
                    42, 47, 38, 44, 43, 45, 42, 44, 41, 46, 42, 47, 47, 38, 43, 45, 41,
                    44, 38, 46, 42, 39, 39, 38, 43, 45, 41, 44, 40], dtype=uint8)),
            (array([46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45], dtype=uint8),
             array([46, 42, 39, 38, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47, 42, 43, 45,
                    42, 44, 40, 46, 42, 38, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47,
                    40, 43, 45, 41, 44, 41, 46, 42, 38, 39, 38, 43, 45, 41, 44, 38, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 46, 38, 38, 43, 45, 41,
                    44, 39, 46, 42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 47, 39, 39,
                    43, 45, 41, 44, 40, 47, 42, 47, 38, 39, 43, 45, 41, 44, 39, 46, 42,
                    47, 39, 46, 43, 45, 41, 44, 39, 46, 42, 39, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 39, 39, 38, 43,
                    45, 41, 44, 40, 38, 42, 46, 39, 38, 43, 45, 41, 44, 38, 46, 42, 46,
                    46, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47, 38, 38, 45, 41, 44, 38,
                    38, 42, 43, 39, 40, 43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 47, 45,
                    46, 44, 40, 46, 42, 47, 40, 38, 43, 45, 41, 44, 40, 46, 42, 47, 40,
                    38, 43, 45, 41, 44, 38, 46, 42, 38, 39, 38, 47, 45], dtype=uint8)),
            (array([41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39], dtype=uint8),
             array([41, 44, 38, 46, 42, 47, 39, 47, 40, 45, 41, 44, 40, 46, 42, 38, 40,
                    38, 43, 45, 41, 44, 40, 46, 42, 38, 38, 38, 43, 45, 41, 44, 46, 38,
                    42, 40, 38, 39, 43, 45, 41, 44, 41, 46, 42, 47, 47, 38, 43, 45, 41,
                    44, 40, 46, 42, 38, 39, 39, 43, 45, 41, 44, 38, 46, 42, 47, 43, 39,
                    43, 45, 41, 44, 40, 46, 42, 38, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    40, 39, 38, 43, 45, 41, 44, 38, 46, 42, 39, 39, 39, 43, 45, 41, 44,
                    40, 46, 42, 39, 38, 47, 43, 45, 38, 44, 40, 38, 42, 47, 38, 38, 43,
                    45, 41, 44, 40, 38, 46, 47, 38, 38, 43, 45, 41, 44, 41, 46, 42, 40,
                    38, 38, 40, 45, 41, 44, 40, 40, 42, 43, 38, 40, 43, 39, 41, 44, 40,
                    40, 42, 47, 38, 46, 43, 45, 41, 44, 47, 41, 42, 43, 40, 47, 43, 45,
                    41, 44, 41, 38, 42, 40, 39, 40, 43, 45, 41, 44, 39, 43, 42, 47, 39,
                    40, 43, 45, 41, 44, 42, 46, 42, 47, 40, 46, 43, 45, 41, 44, 38, 46,
                    42, 47, 47, 38, 43, 45, 41, 44, 40, 38, 39, 47, 38], dtype=uint8)),
            (array([38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46], dtype=uint8),
             array([39, 43, 45, 41, 44, 40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 41, 46,
                    42, 47, 47, 39, 43, 45, 41, 44, 40, 46, 42, 47, 38, 39, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 47, 45, 38,
                    43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42,
                    39, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 40, 39, 43, 45, 41, 44, 40, 39, 42, 40, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 38, 46, 42, 39,
                    39, 47, 43, 45, 41, 44, 40, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40,
                    46, 42, 46, 47, 39, 47, 45, 41, 44, 40, 46, 42, 47, 39, 39, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 38, 46, 42, 47, 39,
                    38, 43, 45, 42, 44, 39, 47, 42, 39, 39, 47, 43, 47, 40, 44, 40, 46,
                    42, 39, 39, 38, 39, 45, 41, 44, 40, 46, 42, 47, 38, 38, 43, 45, 41,
                    44, 46, 38, 42, 47, 39, 43, 43, 45, 41, 44, 40, 46], dtype=uint8)),
            (array([42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                    47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                    40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                    45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                    39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                    46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                    41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                    38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                    42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45], dtype=uint8),
             array([42, 38, 38, 40, 43, 45, 41, 44, 39, 46, 42, 47, 39, 38, 43, 45, 41,
                    44, 39, 38, 42, 47, 41, 40, 43, 45, 41, 44, 40, 41, 42, 47, 38, 46,
                    43, 45, 41, 44, 41, 41, 42, 40, 39, 39, 43, 45, 41, 44, 46, 45, 42,
                    39, 39, 40, 43, 45, 41, 44, 40, 46, 42, 40, 44, 38, 43, 41, 41, 44,
                    39, 46, 42, 39, 39, 39, 43, 45, 41, 44, 40, 43, 42, 47, 39, 39, 43,
                    45, 41, 44, 40, 47, 42, 38, 46, 39, 47, 45, 41, 44, 39, 46, 42, 47,
                    41, 38, 43, 45, 41, 44, 42, 46, 42, 46, 39, 38, 43, 45, 41, 44, 41,
                    46, 42, 46, 39, 38, 43, 45, 41, 44, 40, 46, 42, 38, 38, 38, 43, 45,
                    41, 44, 38, 46, 42, 39, 40, 43, 43, 45, 41, 44, 39, 38, 40, 40, 38,
                    38, 43, 45, 41, 44, 41, 40, 42, 39, 39, 39, 43, 45, 41, 44, 40, 46,
                    42, 47, 40, 40, 43, 45, 41, 44, 40, 46, 42, 41, 39, 39, 43, 45, 41,
                    44, 40, 38, 42, 40, 39, 46, 43, 45, 41, 44, 47, 46, 42, 47, 39, 38,
                    43, 45, 41, 44, 41, 46, 42, 43, 39, 39, 43, 45], dtype=uint8))]
        labels_map = {'12kHz': 40,
                      '20kHz': 41,
                      '30kHz': 42,
                      '3kHz': 38,
                      '7kHz': 39,
                      'song1': 43,
                      'song2': 44,
                      'song3': 45,
                      'song4': 46,
                      'song5': 47}
        # Construction errors propagate so unittest reports the traceback
        # (a bare except here would also swallow KeyboardInterrupt).
        cm = ConfusionMatrix(sets=sets, labels_map=labels_map)
        self.failUnless('3kHz / 38' in cm.asstring())

        if externals.exists("pylab plottable"):
            import pylab as P
            P.figure()
            labels_order = ("3kHz", "7kHz", "12kHz", "20kHz","30kHz", None,
                            "song1","song2","song3","song4","song5")
            # Plot with the first two labels swapped; the assertions below
            # expect indices 0 and 1 of the plotted matrix to be swapped
            # relative to cm.matrix.
            fig, im, cb = cm.plot(labels=labels_order[1:2] + labels_order[:1]
                                  + labels_order[2:], numbers=True)
            self.failUnless(cm._plotted_confusionmatrix[0,0] == cm.matrix[1,1])
            self.failUnless(cm._plotted_confusionmatrix[0,1] == cm.matrix[1,0])
            self.failUnless(cm._plotted_confusionmatrix[1,1] == cm.matrix[0,0])
            self.failUnless(cm._plotted_confusionmatrix[1,0] == cm.matrix[0,1])
            P.close(fig)
            fig, im, cb = cm.plot(labels=labels_order, numbers=True)
            P.close(fig)
def suite():
    """Return a TestSuite holding every test* method of ErrorsTests."""
    loader = unittest.TestLoader()
    return loader.loadTestsFromTestCase(ErrorsTests)
if __name__ == "__main__":
    # Only when executed as a script: hand control to the project's test
    # runner module (presumably it runs the tests on import -- confirm).
    import runner