"""Unit tests for PyMVPA dataset splitters.

Exercises NFoldSplitter, OddEvenSplitter, HalfSplitter, CustomSplitter and
NoneSplitter on a small synthetic dataset, including per-label subsampling
(``nperlabel``) and repeated runs per split (``nrunspersplit``).
"""

# NOTE(review): the file's original header lines (license/copyright) appear
# to have been lost; restore them from version control.

import unittest

import numpy as N

from mvpa.datasets.masked import MaskedDataset
from mvpa.datasets.splitters import (NFoldSplitter, OddEvenSplitter,
                                     NoneSplitter, HalfSplitter,
                                     CustomSplitter)


def _countsPerLabel(dataset, labels):
    """Return `dataset.samplesperlabel` counts as an array ordered by `labels`.

    Indexing explicitly by label avoids depending on the iteration order of
    the mapping's ``values()``.
    """
    spl = dataset.samplesperlabel
    return N.array([spl[l] for l in labels])


class SplitterTests(unittest.TestCase):
    """Splitter behaviour on a 100-sample, 4-label, 10-chunk dataset."""

    def setUp(self):
        # 100 samples x 10 random features; labels cycle 0..3 and every 10
        # consecutive samples share a chunk id.  Floor division keeps chunk
        # ids integral (0..9) regardless of the interpreter's `/` semantics.
        self.data = \
            MaskedDataset(samples=N.random.normal(size=(100, 10)),
                          labels=[i % 4 for i in range(100)],
                          chunks=[i // 10 for i in range(100)])

    def testSimplestCVPatGen(self):
        """One-fold NFoldSplitter yields one (train, test) pair per chunk."""
        nfs = NFoldSplitter(cvtype=1)

        xvpat = [(train, test) for (train, test) in nfs(self.data)]

        self.assertEqual(len(xvpat), 10)

        for i, p in enumerate(xvpat):
            self.assertEqual(len(p), 2)
            self.assertEqual(p[0].nsamples, 90)
            self.assertEqual(p[1].nsamples, 10)
            # the i-th split holds out chunk i
            self.assertEqual(p[1].chunks[0], i)

    def testOddEvenSplit(self):
        """OddEvenSplitter separates odd and even chunks, in both orders."""
        oes = OddEvenSplitter()

        splits = [(train, test) for (train, test) in oes(self.data)]

        self.assertEqual(len(splits), 2)

        for p in splits:
            self.assertEqual(len(p), 2)
            self.assertEqual(p[0].nsamples, 50)
            self.assertEqual(p[1].nsamples, 50)

        N.testing.assert_array_equal(splits[0][1].uniquechunks, [1, 3, 5, 7, 9])
        N.testing.assert_array_equal(splits[0][0].uniquechunks, [0, 2, 4, 6, 8])
        N.testing.assert_array_equal(splits[1][0].uniquechunks, [1, 3, 5, 7, 9])
        N.testing.assert_array_equal(splits[1][1].uniquechunks, [0, 2, 4, 6, 8])

        # re-splitting a dataset that holds only even chunk ids must still
        # produce both parts (neither may be None)
        moresplits = [(train, test) for (train, test) in oes(splits[0][0])]

        for split in moresplits:
            self.assertTrue(split[0] is not None)
            self.assertTrue(split[1] is not None)

    def testHalfSplit(self):
        """HalfSplitter separates the first and second half of the chunks."""
        hs = HalfSplitter()

        splits = [(train, test) for (train, test) in hs(self.data)]

        self.assertEqual(len(splits), 2)

        for p in splits:
            self.assertEqual(len(p), 2)
            self.assertEqual(p[0].nsamples, 50)
            self.assertEqual(p[1].nsamples, 50)

        N.testing.assert_array_equal(splits[0][1].uniquechunks, [0, 1, 2, 3, 4])
        N.testing.assert_array_equal(splits[0][0].uniquechunks, [5, 6, 7, 8, 9])
        N.testing.assert_array_equal(splits[1][1].uniquechunks, [5, 6, 7, 8, 9])
        N.testing.assert_array_equal(splits[1][0].uniquechunks, [0, 1, 2, 3, 4])

        # re-splitting one half must still produce both parts
        moresplits = [(train, test) for (train, test) in hs(splits[0][0])]

        for split in moresplits:
            self.assertTrue(split[0] is not None)
            self.assertTrue(split[1] is not None)

    def testCustomSplit(self):
        """CustomSplitter with explicit chunk specs and per-label sampling."""
        # Emulate HalfSplitter: the assertions below check that a None spec
        # receives the chunks not listed for the other dataset.
        hs = CustomSplitter([(None, [0, 1, 2, 3, 4]), (None, [5, 6, 7, 8, 9])])
        splits = list(hs(self.data))
        self.assertEqual(len(splits), 2)

        for p in splits:
            self.assertEqual(len(p), 2)
            self.assertEqual(p[0].nsamples, 50)
            self.assertEqual(p[1].nsamples, 50)

        N.testing.assert_array_equal(splits[0][1].uniquechunks, [0, 1, 2, 3, 4])
        N.testing.assert_array_equal(splits[0][0].uniquechunks, [5, 6, 7, 8, 9])
        N.testing.assert_array_equal(splits[1][1].uniquechunks, [5, 6, 7, 8, 9])
        N.testing.assert_array_equal(splits[1][0].uniquechunks, [0, 1, 2, 3, 4])

        # Fully specified chunk sets for both datasets of a single split.
        cs = CustomSplitter([([0, 3, 4], [5, 9])])
        splits = list(cs(self.data))
        self.assertEqual(len(splits), 1)

        for p in splits:
            self.assertEqual(len(p), 2)
            self.assertEqual(p[0].nsamples, 30)
            self.assertEqual(p[1].nsamples, 20)

        N.testing.assert_array_equal(splits[0][1].uniquechunks, [5, 9])
        N.testing.assert_array_equal(splits[0][0].uniquechunks, [0, 3, 4])

        # Three datasets per split with an absolute per-label sample count
        # for each (4 labels -> 12, 16 and 4 samples); nrunspersplit=3
        # repeats the split three times.
        cs = CustomSplitter([([0, 3, 4], [5, 9], [2])],
                            nperlabel=[3, 4, 1],
                            nrunspersplit=3)
        splits = list(cs(self.data))
        self.assertEqual(len(splits), 3)

        for p in splits:
            self.assertEqual(len(p), 3)
            self.assertEqual(p[0].nsamples, 12)
            self.assertEqual(p[1].nsamples, 16)
            self.assertEqual(p[2].nsamples, 4)

        # nperlabel given as per-label ratios, a single ratio, and 'all';
        # counts are compared against `csall`, which uses the default
        # nperlabel on the same chunk specification.
        cs = CustomSplitter([([0, 3, 4], [5, 9], [2])],
                            nperlabel=[[0.3, 0.6, 1.0, 0.5],
                                       0.5,
                                       'all'],
                            nrunspersplit=3)
        csall = CustomSplitter([([0, 3, 4], [5, 9], [2])],
                               nrunspersplit=3)

        splits = list(cs(self.data))
        splitsall = list(csall(self.data))

        self.assertEqual(len(splits), 3)
        # order per-label counts by label so they line up with the ratios
        ul = self.data.uniquelabels

        N.testing.assert_array_equal(
            (_countsPerLabel(splitsall[0][0], ul)
             * [0.3, 0.6, 1.0, 0.5]).round().astype(int),
            _countsPerLabel(splits[0][0], ul))

        N.testing.assert_array_equal(
            (_countsPerLabel(splitsall[0][1], ul) * 0.5).round().astype(int),
            _countsPerLabel(splits[0][1], ul))

        N.testing.assert_array_equal(
            _countsPerLabel(splitsall[0][2], ul),
            _countsPerLabel(splits[0][2], ul))

    def testNoneSplitter(self):
        """NoneSplitter puts the whole dataset in one slot and None in the other."""
        nos = NoneSplitter()
        splits = [(train, test) for (train, test) in nos(self.data)]
        self.assertEqual(len(splits), 1)
        self.assertTrue(splits[0][0] is None)
        self.assertEqual(splits[0][1].nsamples, 100)

        # mode='first' swaps the slots
        nos = NoneSplitter(mode='first')
        splits = [(train, test) for (train, test) in nos(self.data)]
        self.assertEqual(len(splits), 1)
        self.assertTrue(splits[0][1] is None)
        self.assertEqual(splits[0][0].nsamples, 100)

        # explicit number of samples per label, repeated three times
        nos = NoneSplitter(nrunspersplit=3,
                           nperlabel=10)
        splits = [(train, test) for (train, test) in nos(self.data)]

        self.assertEqual(len(splits), 3)
        for split in splits:
            self.assertTrue(split[0] is None)
            self.assertEqual(split[1].nsamples, 40)
            self.assertEqual(list(split[1].samplesperlabel.values()),
                             [10, 10, 10, 10])

        # nperlabel='equal'; every label has 25 samples in this dataset
        nos = NoneSplitter(nrunspersplit=3,
                           nperlabel='equal')
        splits = [(train, test) for (train, test) in nos(self.data)]

        self.assertEqual(len(splits), 3)
        for split in splits:
            self.assertTrue(split[0] is None)
            self.assertEqual(split[1].nsamples, 100)
            self.assertEqual(list(split[1].samplesperlabel.values()),
                             [25, 25, 25, 25])

    def testLabelSplitter(self):
        """OddEvenSplitter(attr='labels') splits on labels instead of chunks."""
        oes = OddEvenSplitter(attr='labels')

        splits = [(first, second) for (first, second) in oes(self.data)]

        N.testing.assert_array_equal(splits[0][0].uniquelabels, [0, 2])
        N.testing.assert_array_equal(splits[0][1].uniquelabels, [1, 3])
        N.testing.assert_array_equal(splits[1][0].uniquelabels, [1, 3])
        N.testing.assert_array_equal(splits[1][1].uniquelabels, [0, 2])

    def testCountedSplitting(self):
        """NFoldSplitter's `count` limits the number of splits per strategy."""
        nchunks = len(self.data.uniquechunks)
        for strategy in NFoldSplitter._STRATEGIES:
            # (requested count, expected number of splits); a request above
            # the number of chunks yields one split per chunk
            for count, target in [(nchunks * 2, nchunks),
                                  (nchunks, nchunks),
                                  (nchunks - 1, nchunks - 1),
                                  (3, 3),
                                  (0, 0),
                                  (1, 1)]:
                nfs = NFoldSplitter(cvtype=1, count=count, strategy=strategy)
                splits = [(train, test) for (train, test) in nfs(self.data)]
                self.assertEqual(len(splits), target)
                # every test set must hold exactly one chunk
                for s in splits:
                    self.assertEqual(len(s[1].uniquechunks), 1)
                chosenchunks = [int(s[1].uniquechunks[0]) for s in splits]
                if strategy == 'first':
                    self.assertEqual(chosenchunks, list(range(target)))
                elif strategy == 'equidistant':
                    if target == 3:
                        self.assertEqual(chosenchunks, [0, 3, 7])
                elif strategy == 'random':
                    # no chunk is selected twice
                    self.assertEqual(len(set(chosenchunks)), len(chosenchunks))
                    self.assertEqual(len(chosenchunks), target)
                else:
                    raise RuntimeError("Add unittest for strategy %s"
                                       % strategy)


def suite():
    """Return a TestSuite containing every SplitterTests case."""
    return unittest.TestLoader().loadTestsFromTestCase(SplitterTests)


if __name__ == '__main__':
    # NOTE(review): importing the project-local `runner` module is what runs
    # the tests here (presumably via `suite()`); confirm against runner.py.
    import runner