1
2
3
4
5
6
7
8
9 """Unit tests for PyMVPA nifti dataset"""
10
11 import unittest
12 import os.path
13 import numpy as N
14
15 from mvpa import pymvpa_dataroot
16 from mvpa.datasets.nifti import *
17 from mvpa.misc.exceptions import *
18 from mvpa.misc.fsl import FslEV3
19
21
# Load the 'example4d' image as a NiftiDataset with labels [1, 2]; the checks
# below expect 2 samples of 294912 (= 24 * 96 * 128) features each.
# NOTE(review): failUnless/failUnlessRaises are deprecated unittest aliases
# (removed in Python 3.12); assertTrue/assertRaises are the replacements.
23 data = NiftiDataset(samples=os.path.join(pymvpa_dataroot,'example4d'),
24 labels=[1,2])
25 self.failUnless(data.nfeatures == 294912)
26 self.failUnless(data.nsamples == 2)
27 # mapper element size must equal header pixdim[1:4] in reversed order (pixdim[3], pixdim[2], pixdim[1])
28 self.failUnless((data.mapper.metric.elementsize \
29 == data.niftihdr['pixdim'][3:0:-1]).all())
30
31 # neighbours of voxel (1,1,1): 7 within radius 2.2, 5 within radius 2.0
32 nb22=N.array([i for i in data.mapper.getNeighborIn((1,1,1), 2.2)])
33 nb20=N.array([i for i in data.mapper.getNeighborIn((1,1,1), 2.0)])
34 self.failUnless(nb22.shape[0] == 7)
35 self.failUnless(nb20.shape[0] == 5)
36
37
38 # dt / samplingrate are accepted in either of two scales (NOTE(review): presumably seconds vs. milliseconds depending on header units -- confirm)
39 self.failUnless(data.dt in [2.0, 2000.0])
40 self.failUnless(data.samplingrate in [5e-4, 5e-1])
41 merged = data + data
42 # the merged dataset (data + data) keeps the feature count and doubles the samples
43 self.failUnless(merged.nfeatures == 294912)
44 self.failUnless(merged.nsamples == 4)
45
46
47 # every NIfTI header entry of the merged dataset must be fully equal to the source header entry
48 for k in merged.niftihdr.keys():
49 self.failUnless(N.mean(merged.niftihdr[k] == data.niftihdr[k]) == 1)
50
51 # after deleting the source dataset, merged sample 3 must still equal sample 1 at feature 120000
52 del data
53 self.failUnless(merged.samples[3, 120000] == merged.samples[1, 120000])
54
55 # single-voxel mask at (12, 20, 40): one feature; reverse-mapping [44] must put 44 at that voxel and the volume must sum to 44
56 mask = N.zeros((24, 96, 128), dtype='bool')
57 mask[12,20,40] = True
58 nddata = NiftiDataset(samples=os.path.join(pymvpa_dataroot,'example4d'),
59 labels=[1,2],
60 mask=mask)
61 self.failUnless(nddata.nfeatures == 1)
62 rmap = nddata.mapReverse([44])
63 self.failUnless(rmap.shape == (24, 96, 128))
64 self.failUnless(N.sum(rmap) == 44)
65 self.failUnless(rmap[12,20,40] == 44)
66
67
# map2Nifti checks on a NiftiDataset built from the 'example4d' image
69 data = NiftiDataset(samples=os.path.join(pymvpa_dataroot,'example4d'),
70 labels=[1,2])
71
72 # a single feature vector of ones maps back to one 24x96x128 volume filled with ones
73 vol = data.map2Nifti(N.ones((294912,), dtype='int16'))
74 self.failUnless(vol.data.shape == (24,96,128))
75 self.failUnless((vol.data == 1).all())
76
77 # mapping the whole dataset yields one volume per sample: shape (2, 24, 96, 128)
78 vol = data.map2Nifti(data)
79 self.failUnless(vol.data.shape == (2, 24, 96, 128))
80
81
97
98
117
118
# constructing an ERNiftiDataset without any arguments must raise DatasetError
120 self.failUnlessRaises(DatasetError, ERNiftiDataset)
121
122 # inputs from pymvpa_dataroot: 'bold' image, 'fslev3.txt' EV3 event file, 'mask' image
123 tssrc = os.path.join(pymvpa_dataroot, 'bold')
124 evsrc = os.path.join(pymvpa_dataroot, 'fslev3.txt')
125 masrc = os.path.join(pymvpa_dataroot, 'mask')
126 evs = FslEV3(evsrc).toEvents()
127 # NOTE(review): masrc is assigned but never used in the visible part of this test
128
129 # events taken straight from the EV3 file must be rejected with ValueError (NOTE(review): presumably because they carry no 'label' yet -- confirm)
130 self.failUnlessRaises(ValueError, ERNiftiDataset,
131 samples=tssrc, events=evs)
132
133 # assign label 1 to every event
134 for ev in evs:
135 ev['label'] = 1
136
137
138 # with labelled events: one sample per event, 7201 features each
139 ds = ERNiftiDataset(samples=tssrc, events=evs)
140
141
142 self.failUnless(ds.nfeatures == 7201)
143 self.failUnless(ds.nsamples == len(evs))
144
145 # all but the last feature of each sample must equal the flattened raw window [onset, onset + duration) of its event (NOTE(review): meaning of the trailing feature is not visible here)
146 origsamples = getNiftiFromAnySource(tssrc).data
147 for i, ev in enumerate(evs):
148 self.failUnless((ds.samples[i][:-1] \
149 == origsamples[ev['onset']:ev['onset'] + ev['duration']].ravel()
150 ).all())
151
152 # with evconv=True and storeoffset=True: still one sample per event, now 3202 features
153 ds = ERNiftiDataset(samples=tssrc, events=evs, evconv=True,
154 storeoffset=True)
155 self.failUnless(ds.nsamples == len(evs))
156
157
158 self.failUnless(ds.nfeatures == 3202)
159
160 # map2Nifti() restores the original time-series shape; a single sample maps to shape (4, 1, 20, 40)
161 nim = ds.map2Nifti()
162 self.failUnless(nim.data.shape == origsamples.shape)
163
164 nim = ds.map2Nifti(ds.samples[0])
165 self.failUnless(nim.data.shape == (4, 1, 20, 40))
166
167
# NiftiDataset construction from the 'mask' image and from tuples of sources
169 tssrc = os.path.join(pymvpa_dataroot, 'bold')
170 masrc = os.path.join(pymvpa_dataroot, 'mask')
171
172
173
174 # with enforce4D=False, using the mask image as samples must raise (NOTE(review): any Exception satisfies this, so the check is weak)
175 self.failUnlessRaises(Exception, NiftiDataset,
176 masrc, mask=masrc, labels=1, enforce4D=False)
177 # without enforce4D=False the same call must succeed
178 ds = NiftiDataset(masrc, mask=masrc, labels=1)
179
180 plain_data = NiftiImage(masrc).data
181 # reverse-mapping the dataset must reproduce the mask image's data
182 self.failUnless(N.all(plain_data == \
183 ds.map2Nifti().data.reshape(plain_data.shape)))
184
185
186
187 # combining the 'mask' and 'bold' images in one tuple must raise ValueError
188 self.failUnlessRaises(ValueError, NiftiDataset, (masrc, tssrc),
189 mask=masrc, labels=1)
190
191 # take sample 3 of the full 'bold' dataset and map it back to a NIfTI image
192 dsfull = NiftiDataset(tssrc, mask=masrc, labels=1)
193 ds_selected = dsfull['samples', [3]]
194 nifti_selected = ds_selected.map2Nifti()
195
196 # a tuple of sources (two file paths plus the image object from map2Nifti) yields one sample per source, in order, with the given labels
197 labels = [123,2,123]
198 ds2 = NiftiDataset((masrc, masrc, nifti_selected), mask=masrc, labels=labels)
199 self.failUnless(ds2.nsamples == 3)
200 self.failUnless((ds2.samples[0] == ds2.samples[1]).all())
201 self.failUnless((ds2.samples[2] == dsfull.samples[3]).all())
202 self.failUnless((ds2.labels == labels).all())
203
204
207
208
# NOTE(review): running this file as a script only imports the 'runner'
# module; presumably that module executes the test suites on import -- confirm
209 if __name__ == '__main__':
210 import runner
211