Package mvpa :: Package tests :: Module test_splitter
[hide private]
[frames] | [no frames]

Source Code for Module mvpa.tests.test_splitter

  1  #emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*- 
  2  #ex: set sts=4 ts=4 sw=4 et: 
  3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  4  # 
  5  #   See COPYING file distributed along with the PyMVPA package for the 
  6  #   copyright and license terms. 
  7  # 
  8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  9  """Unit tests for PyMVPA pattern handling""" 
 10   
 11  from mvpa.datasets.masked import MaskedDataset 
 12  from mvpa.datasets.splitters import NFoldSplitter, OddEvenSplitter, \ 
 13                                     NoneSplitter, HalfSplitter, CustomSplitter 
 14  import unittest 
 15  import numpy as N 
 16   
 17   
18 -class SplitterTests(unittest.TestCase):
19
20 - def setUp(self):
21 self.data = \ 22 MaskedDataset(samples=N.random.normal(size=(100,10)), 23 labels=[ i%4 for i in range(100) ], 24 chunks=[ i/10 for i in range(100)])
25 26
27 - def testSimplestCVPatGen(self):
28 # create the generator 29 nfs = NFoldSplitter(cvtype=1) 30 31 # now get the xval pattern sets One-Fold CV) 32 xvpat = [ (train, test) for (train,test) in nfs(self.data) ] 33 34 self.failUnless( len(xvpat) == 10 ) 35 36 for i,p in enumerate(xvpat): 37 self.failUnless( len(p) == 2 ) 38 self.failUnless( p[0].nsamples == 90 ) 39 self.failUnless( p[1].nsamples == 10 ) 40 self.failUnless( p[1].chunks[0] == i )
41 42
43 - def testOddEvenSplit(self):
44 oes = OddEvenSplitter() 45 46 splits = [ (train, test) for (train, test) in oes(self.data) ] 47 48 self.failUnless(len(splits) == 2) 49 50 for i,p in enumerate(splits): 51 self.failUnless( len(p) == 2 ) 52 self.failUnless( p[0].nsamples == 50 ) 53 self.failUnless( p[1].nsamples == 50 ) 54 55 self.failUnless((splits[0][1].uniquechunks == [1, 3, 5, 7, 9]).all()) 56 self.failUnless((splits[0][0].uniquechunks == [0, 2, 4, 6, 8]).all()) 57 self.failUnless((splits[1][0].uniquechunks == [1, 3, 5, 7, 9]).all()) 58 self.failUnless((splits[1][1].uniquechunks == [0, 2, 4, 6, 8]).all()) 59 60 # check if it works on pure odd and even chunk ids 61 moresplits = [ (train, test) for (train, test) in oes(splits[0][0])] 62 63 for split in moresplits: 64 self.failUnless(split[0] != None) 65 self.failUnless(split[1] != None)
66 67
68 - def testHalfSplit(self):
69 hs = HalfSplitter() 70 71 splits = [ (train, test) for (train, test) in hs(self.data) ] 72 73 self.failUnless(len(splits) == 2) 74 75 for i,p in enumerate(splits): 76 self.failUnless( len(p) == 2 ) 77 self.failUnless( p[0].nsamples == 50 ) 78 self.failUnless( p[1].nsamples == 50 ) 79 80 self.failUnless((splits[0][1].uniquechunks == [0, 1, 2, 3, 4]).all()) 81 self.failUnless((splits[0][0].uniquechunks == [5, 6, 7, 8, 9]).all()) 82 self.failUnless((splits[1][1].uniquechunks == [5, 6, 7, 8, 9]).all()) 83 self.failUnless((splits[1][0].uniquechunks == [0, 1, 2, 3, 4]).all()) 84 85 # check if it works on pure odd and even chunk ids 86 moresplits = [ (train, test) for (train, test) in hs(splits[0][0])] 87 88 for split in moresplits: 89 self.failUnless(split[0] != None) 90 self.failUnless(split[1] != None)
91 92
93 - def testCustomSplit(self):
94 #simulate half splitter 95 hs = CustomSplitter([(None,[0,1,2,3,4]),(None,[5,6,7,8,9])]) 96 splits = list(hs(self.data)) 97 self.failUnless(len(splits) == 2) 98 99 for i,p in enumerate(splits): 100 self.failUnless( len(p) == 2 ) 101 self.failUnless( p[0].nsamples == 50 ) 102 self.failUnless( p[1].nsamples == 50 ) 103 104 self.failUnless((splits[0][1].uniquechunks == [0, 1, 2, 3, 4]).all()) 105 self.failUnless((splits[0][0].uniquechunks == [5, 6, 7, 8, 9]).all()) 106 self.failUnless((splits[1][1].uniquechunks == [5, 6, 7, 8, 9]).all()) 107 self.failUnless((splits[1][0].uniquechunks == [0, 1, 2, 3, 4]).all()) 108 109 110 # check fully customized split with working and validation set specified 111 cs = CustomSplitter([([0,3,4],[5,9])]) 112 splits = list(cs(self.data)) 113 self.failUnless(len(splits) == 1) 114 115 for i,p in enumerate(splits): 116 self.failUnless( len(p) == 2 ) 117 self.failUnless( p[0].nsamples == 30 ) 118 self.failUnless( p[1].nsamples == 20 ) 119 120 self.failUnless((splits[0][1].uniquechunks == [5, 9]).all()) 121 self.failUnless((splits[0][0].uniquechunks == [0, 3, 4]).all()) 122 123 # full test with additional sampling and 3 datasets per split 124 cs = CustomSplitter([([0,3,4],[5,9],[2])], 125 nperlabel=[3,4,1], 126 nrunspersplit=3) 127 splits = list(cs(self.data)) 128 self.failUnless(len(splits) == 3) 129 130 for i,p in enumerate(splits): 131 self.failUnless( len(p) == 3 ) 132 self.failUnless( p[0].nsamples == 12 ) 133 self.failUnless( p[1].nsamples == 16 ) 134 self.failUnless( p[2].nsamples == 4 ) 135 136 # lets test selection of samples by ratio and combined with 137 # other ways 138 cs = CustomSplitter([([0,3,4],[5,9],[2])], 139 nperlabel=[[0.3, 0.6, 1.0, 0.5], 140 0.5, 141 'all'], 142 nrunspersplit=3) 143 csall = CustomSplitter([([0,3,4],[5,9],[2])], 144 nrunspersplit=3) 145 # lets craft simpler dataset 146 #ds = Dataset(samples=N.arange(12), labels=[1]*6+[2]*6, chunks=1) 147 splits = list(cs(self.data)) 148 splitsall = 
list(csall(self.data)) 149 150 self.failUnless(len(splits) == 3) 151 ul = self.data.uniquelabels 152 153 self.failUnless(((N.array(splitsall[0][0].samplesperlabel.values()) 154 *[0.3, 0.6, 1.0, 0.5]).round().astype(int) == 155 N.array(splits[0][0].samplesperlabel.values())).all()) 156 157 self.failUnless(((N.array(splitsall[0][1].samplesperlabel.values())*0.5 158 ).round().astype(int) == 159 N.array(splits[0][1].samplesperlabel.values())).all()) 160 161 self.failUnless((N.array(splitsall[0][2].samplesperlabel.values()) == 162 N.array(splits[0][2].samplesperlabel.values())).all())
163 164
165 - def testNoneSplitter(self):
166 nos = NoneSplitter() 167 splits = [ (train, test) for (train, test) in nos(self.data) ] 168 self.failUnless(len(splits) == 1) 169 self.failUnless(splits[0][0] == None) 170 self.failUnless(splits[0][1].nsamples == 100) 171 172 nos = NoneSplitter(mode='first') 173 splits = [ (train, test) for (train, test) in nos(self.data) ] 174 self.failUnless(len(splits) == 1) 175 self.failUnless(splits[0][1] == None) 176 self.failUnless(splits[0][0].nsamples == 100) 177 178 179 # test sampling tools 180 # specified value 181 nos = NoneSplitter(nrunspersplit=3, 182 nperlabel=10) 183 splits = [ (train, test) for (train, test) in nos(self.data) ] 184 185 self.failUnless(len(splits) == 3) 186 for split in splits: 187 self.failUnless(split[0] == None) 188 self.failUnless(split[1].nsamples == 40) 189 self.failUnless(split[1].samplesperlabel.values() == [10,10,10,10]) 190 191 # auto-determined 192 nos = NoneSplitter(nrunspersplit=3, 193 nperlabel='equal') 194 splits = [ (train, test) for (train, test) in nos(self.data) ] 195 196 self.failUnless(len(splits) == 3) 197 for split in splits: 198 self.failUnless(split[0] == None) 199 self.failUnless(split[1].nsamples == 100) 200 self.failUnless(split[1].samplesperlabel.values() == [25,25,25,25])
201 202
203 - def testLabelSplitter(self):
204 oes = OddEvenSplitter(attr='labels') 205 206 splits = [ (first, second) for (first, second) in oes(self.data) ] 207 208 self.failUnless((splits[0][0].uniquelabels == [0,2]).all()) 209 self.failUnless((splits[0][1].uniquelabels == [1,3]).all()) 210 self.failUnless((splits[1][0].uniquelabels == [1,3]).all()) 211 self.failUnless((splits[1][1].uniquelabels == [0,2]).all())
212 213
214 - def testCountedSplitting(self):
215 # count > #chunks, should result in 10 splits 216 nchunks = len(self.data.uniquechunks) 217 for strategy in NFoldSplitter._STRATEGIES: 218 for count, target in [ (nchunks*2, nchunks), 219 (nchunks, nchunks), 220 (nchunks-1, nchunks-1), 221 (3, 3), 222 (0, 0), 223 (1, 1) 224 ]: 225 nfs = NFoldSplitter(cvtype=1, count=count, strategy=strategy) 226 splits = [ (train, test) for (train,test) in nfs(self.data) ] 227 self.failUnless(len(splits) == target) 228 chosenchunks = [int(s[1].uniquechunks) for s in splits] 229 if strategy == 'first': 230 self.failUnlessEqual(chosenchunks, range(target)) 231 elif strategy == 'equidistant': 232 if target == 3: 233 self.failUnlessEqual(chosenchunks, [0, 3, 7]) 234 elif strategy == 'random': 235 # none is selected twice 236 self.failUnless(len(set(chosenchunks)) == len(chosenchunks)) 237 self.failUnless(target == len(chosenchunks)) 238 else: 239 raise RuntimeError, "Add unittest for strategy %s" \ 240 % strategy
241 242
def suite():
    """Return a unittest suite containing every SplitterTests test method."""
    # unittest.makeSuite was deprecated in Python 3.11 and removed in 3.13;
    # TestLoader.loadTestsFromTestCase is the equivalent, portable call.
    return unittest.TestLoader().loadTestsFromTestCase(SplitterTests)


if __name__ == '__main__':
    # NOTE(review): the local ``runner`` module is imported only for its
    # side effects -- presumably it runs the test suites; confirm.
    import runner