Package mvpa :: Package mappers :: Module wavelet
[hide private]
[frames] | [no frames]

Source Code for Module mvpa.mappers.wavelet

  1  #emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*- 
  2  #ex: set sts=4 ts=4 sw=4 et: 
  3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  4  # 
  5  #   See COPYING file distributed along with the PyMVPA package for the 
  6  #   copyright and license terms. 
  7  # 
  8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  9  """Wavelet mappers""" 
 10   
 11  from mvpa.base import externals 
 12  externals.exists('pywt', raiseException=True) 
 13   
 14  import pywt 
 15  import numpy as N 
 16   
 17  from mvpa.base import warning 
 18  from mvpa.mappers.base import Mapper 
 19  from mvpa.base.dochelpers import enhancedDocString 
 20   
 21  if __debug__: 
 22      from mvpa.base import debug 
 23   
 24  # WaveletPacket and WaveletTransformation mappers share lots of common 
 25  # functionality at the moment 
 26   
class _WaveletMapper(Mapper):
    """Generic class for Wavelet mappers (decomposition and packet)

    Subclasses implement `_forward` and `_reverse`.  This base class
    validates the constructor arguments and, in `forward`, records the
    input shape, the number of time points along the working dimension
    and the output shape.
    """

    def __init__(self, dim=1, wavelet='sym4', mode='per', maxlevel=None):
        """Initialize _WaveletMapper mapper

        :Parameters:
          dim : int or tuple of int
            dimensions to work across (for now just scalar value, ie 1D
            transformation) is supported
          wavelet : basestring
            one from the families available within pywt package
          mode : basestring
            periodization mode
          maxlevel : int or None
            number of levels to use. If None - automatically selected by pywt

        :Raises:
          ValueError
            if `wavelet` is not listed by ``pywt.wavelist()`` or `mode` is
            not listed in ``pywt.MODES.modes``
        """
        Mapper.__init__(self)

        self._dim = dim
        """Dimension to work along"""

        self._maxlevel = maxlevel
        """Maximal level of decomposition. None for automatic"""

        if wavelet not in pywt.wavelist():
            raise ValueError(
                "Unknown family of wavelets '%s'. Please use one "
                "available from the list %s" % (wavelet, pywt.wavelist()))
        self._wavelet = wavelet
        """Wavelet family to use"""

        if mode not in pywt.MODES.modes:
            raise ValueError(
                "Unknown periodization mode '%s'. Please use one "
                "available from the list %s" % (mode, pywt.MODES.modes))
        self._mode = mode
        """Periodization mode"""


    def forward(self, data):
        """Map `data` into the wavelet domain via the subclass `_forward`.

        Side effects: stores ``self._inshape`` (input shape),
        ``self._intimepoints`` (input length along ``self._dim``, used by
        the reverse mappings to trim padded reconstructions) and
        ``self._outshape`` (shape of the result).
        """
        data = N.asanyarray(data)
        self._inshape = data.shape
        self._intimepoints = data.shape[self._dim]
        res = self._forward(data)
        self._outshape = res.shape
        return res


    def reverse(self, data):
        """Map wavelet-domain `data` back via the subclass `_reverse`.

        Expects `forward` to have been called first, since the reverse
        mappings read the shapes it recorded.
        """
        data = N.asanyarray(data)
        return self._reverse(data)


    def _forward(self, *args):
        """Forward mapping; must be implemented by subclasses."""
        raise NotImplementedError


    def _reverse(self, *args):
        """Reverse mapping; must be implemented by subclasses."""
        raise NotImplementedError


    def getInSize(self):
        """Return the input shape without its first axis (a tuple).

        Only available after `forward` has been called.
        """
        return self._inshape[1:]


    def getOutSize(self):
        """Return the output shape without its first axis (a tuple).

        Only available after `forward` has been called.
        """
        return self._outshape[1:]


    def selectOut(self, outIds):
        """Choose a subset of components...

        just use MaskMapper on top?"""
        raise NotImplementedError("Please use in conjunction with MaskMapper")


    __doc__ = enhancedDocString('_WaveletMapper', locals(), Mapper)
108 109
def _getIndexes(shape, dim):
    """Generate index tuples selecting every 1D slice along axis `dim`.

    For an array of the given `shape`, each yielded tuple holds an integer
    position for every axis except `dim`, which holds `Ellipsis`.
    `Ellipsis` (rather than ``slice(None)``) is used deliberately so that
    the same tuple can also index arrays which have extra axes inserted
    at position `dim`: the Ellipsis then expands over all of them.
    Axis 0 varies fastest.  Nothing is yielded if any axis other than
    `dim` has length 0.

    :Parameters:
      shape : tuple of int
        shape of the array to iterate over
      dim : int
        axis to keep whole; negative values count from the end

    :Raises:
      ValueError
        if `dim` is not a valid axis for `shape`
    """
    n = len(shape)
    axis = dim + n if dim < 0 else dim
    if not 0 <= axis < n:
        raise ValueError("Dimension %d is incorrect for a shape %s"
                         % (dim, shape))

    # Axes enumerated by the odometer below (every axis but `axis`).
    others = [i for i in range(n) if i != axis]
    if any(shape[i] == 0 for i in others):
        return

    curindexes = [0] * n
    curindexes[axis] = Ellipsis  # all elements along `axis`
    while True:
        yield tuple(curindexes)
        # Advance like an odometer (axis 0 fastest), never touching `axis`.
        for i in others:
            if curindexes[i] < shape[i] - 1:
                curindexes[i] += 1
                break
            curindexes[i] = 0
        else:
            # every counter wrapped around -- all slices were visited
            return
134 135
class WaveletPacketMapper(_WaveletMapper):
    """Convert signal into an overcomplete representation using Wavelet packet
    """

    def __init__(self, level=None, **kwargs):
        """Initialize WaveletPacketMapper mapper

        :Parameters:
          level : int or None
            What level to decompose at. If 'None' data for all levels
            is provided, but due to different sizes, they are placed
            in 1D row.

        Remaining keyword arguments are passed to `_WaveletMapper`.
        """

        _WaveletMapper.__init__(self, **kwargs)

        self.__level = level


    # XXX too much of duplications between such methods -- it begs
    # refactoring
    def __forwardSingleLevel(self, data):
        """Decompose every 1D slice along `self._dim` at `self.__level`.

        The `self._dim` axis of the result is replaced by the shape of the
        stacked node coefficients of that level (presumably
        nodes x coefficients -- assumes all nodes have equal length).
        The node paths are stored in `self.__level_paths`, which
        `__reverseSingleLevel` needs for reconstruction.
        """
        if __debug__:
            debug('MAP', "Converting signal using DWP (single level)")

        wp = None

        level = self.__level
        wavelet = self._wavelet
        mode = self._mode
        dim = self._dim

        level_paths = None
        for indexes in _getIndexes(data.shape, dim):
            if __debug__:
                debug('MAP_', " %s" % (indexes,), lf=False, cr=True)
            WP = pywt.WaveletPacket(
                data[indexes], wavelet=wavelet,
                mode=mode, maxlevel=level)

            level_nodes = WP.get_level(level)
            if level_paths is None:
                # Needed for reconstruction.  Bind the local as well so the
                # paths are collected once, not for every slice.
                level_paths = N.array([node.path for node in level_nodes])
                self.__level_paths = level_paths
            level_datas = N.array([node.data for node in level_nodes])

            if wp is None:
                newdim = data.shape
                newdim = newdim[:dim] + level_datas.shape + newdim[dim+1:]
                if __debug__:
                    debug('MAP_', "Initializing storage of size %s for single "
                          "level (%d) mapping of data of size %s" % (newdim, level, data.shape))
                wp = N.empty(tuple(newdim))

            # Ellipsis in `indexes` spans both axes inserted at `dim`
            wp[indexes] = level_datas

        return wp


    def __forwardMultipleLevels(self, data):
        """Decompose every 1D slice along `self._dim` at all levels.

        Coefficients of all nodes of levels 1..maxlevel are concatenated
        along `self._dim`.  Per-level node lengths and per-level totals are
        stored in `self.levels_lengths` and `self.levels_length`.

        :Raises:
          RuntimeError
            if different slices produce levels of different sizes
        """
        wp = None
        levels_length = None   # total length at each level
        levels_lengths = None  # list of lengths per each level
        for indexes in _getIndexes(data.shape, self._dim):
            if __debug__:
                debug('MAP_', " %s" % (indexes,), lf=False, cr=True)
            WP = pywt.WaveletPacket(
                data[indexes],
                wavelet=self._wavelet,
                mode=self._mode, maxlevel=self._maxlevel)

            if levels_length is None:
                levels_length = [None] * WP.maxlevel
                levels_lengths = [None] * WP.maxlevel

            levels_datas = []
            for level in range(WP.maxlevel):
                level_nodes = WP.get_level(level+1)
                level_datas = [node.data for node in level_nodes]

                level_lengths = [len(x) for x in level_datas]
                level_length = N.sum(level_lengths)

                if levels_lengths[level] is None:
                    levels_lengths[level] = level_lengths
                elif levels_lengths[level] != level_lengths:
                    raise RuntimeError(
                        "ADs of same level of different samples should have same number of elements."
                        " Got %s, was %s" % (level_lengths, levels_lengths[level]))

                if levels_length[level] is None:
                    levels_length[level] = level_length
                elif levels_length[level] != level_length:
                    raise RuntimeError(
                        "Levels of different samples should have same number of elements."
                        " Got %d, was %d" % (level_length, levels_length[level]))

                levels_datas.append(N.hstack(level_datas))

            if wp is None:
                newdim = list(data.shape)
                newdim[self._dim] = N.sum(levels_length)
                wp = N.empty(tuple(newdim))
            wp[indexes] = N.hstack(levels_datas)

        self.levels_lengths, self.levels_length = levels_lengths, levels_length
        if __debug__:
            debug('MAP_', "")
            debug('MAP', "Done convertion into wp. Total size %s" % str(wp.shape))
        return wp


    def _forward(self, data):
        """Dispatch to single- or multi-level decomposition by `level`."""
        if __debug__:
            debug('MAP', "Converting signal using DWP")

        if self.__level is None:
            return self.__forwardMultipleLevels(data)
        else:
            return self.__forwardSingleLevel(data)

    #
    # Reverse mapping
    #
    def __reverseSingleLevel(self, wp):
        """Reconstruct signals from single-level packet coefficients.

        Uses `self.__level_paths` recorded by `__forwardSingleLevel` and the
        shapes recorded by `forward`; each reconstruction is truncated to
        the original number of time points.
        """

        # local bindings
        level_paths = self.__level_paths

        # define wavelet packet to use
        WP = pywt.WaveletPacket(
            data=None, wavelet=self._wavelet,
            mode=self._mode, maxlevel=self.__level)

        # prepare storage
        signal_shape = wp.shape[:1] + self.getInSize()
        signal = N.zeros(signal_shape)
        Ntime_points = self._intimepoints
        for indexes in _getIndexes(signal_shape, self._dim):
            if __debug__:
                debug('MAP_', " %s" % (indexes,), lf=False, cr=True)

            # Ellipsis in `indexes` spans both the node and coefficient
            # axes of `wp`, so rows of wp[indexes] pair with level_paths
            for path, level_data in zip(level_paths, wp[indexes]):
                WP[path] = level_data

            signal[indexes] = WP.reconstruct(True)[:Ntime_points]

        return signal


    def _reverse(self, data):
        """Reverse mapping; only supported for a single decomposition level.

        :Raises:
          NotImplementedError
            if `level` is None or the installed pywt cannot reconstruct
        """
        if __debug__:
            debug('MAP', "Converting signal back using DWP")

        if self.__level is None:
            raise NotImplementedError(
                "Reverse mapping is supported only for a single level "
                "(level != None)")
        if not externals.exists('pywt wp reconstruct'):
            raise NotImplementedError(
                "Reconstruction for a single level for versions of "
                "pywt < 0.1.7 (revision 103) is not supported")
        if not externals.exists('pywt wp reconstruct fixed'):
            warning("Reconstruction using available version of pywt might "
                    "result in incorrect data in the tails of the signal")
        return self.__reverseSingleLevel(data)
305 306 307 308 309
class WaveletTransformationMapper(_WaveletMapper):
    """Convert signal into wavelet representation
    """

    def _forward(self, data):
        """Decompose signal into wavelets's coefficients via dwt

        Every 1D slice along `self._dim` is decomposed with
        `pywt.wavedec`; its coefficient arrays are concatenated along
        `self._dim`.  Their lengths are stored in `self.lengths`, which
        `_reverse` uses to split the concatenation again.

        :Raises:
          RuntimeError
            if different slices produce coefficient arrays of different
            lengths
        """
        if __debug__:
            debug('MAP', "Converting signal using DWT")
        wd = None
        coeff_lengths = None
        for indexes in _getIndexes(data.shape, self._dim):
            if __debug__:
                debug('MAP_', " %s" % (indexes,), lf=False, cr=True)
            coeffs = pywt.wavedec(
                data[indexes],
                wavelet=self._wavelet,
                mode=self._mode,
                level=self._maxlevel)
            coeff_lengths_ = N.array([len(x) for x in coeffs])
            if coeff_lengths is None:
                coeff_lengths = coeff_lengths_
            elif not N.array_equal(coeff_lengths, coeff_lengths_):
                # explicit check: an `assert` would vanish under `python -O`
                raise RuntimeError(
                    "Coefficients of different samples should have same "
                    "lengths. Got %s, was %s" % (coeff_lengths_, coeff_lengths))
            if wd is None:
                newdim = list(data.shape)
                newdim[self._dim] = N.sum(coeff_lengths)
                wd = N.empty(tuple(newdim))
            wd[indexes] = N.hstack(coeffs)
        if __debug__:
            debug('MAP_', "")
            debug('MAP', "Done DWT. Total size %s" % str(wd.shape))
        self.lengths = coeff_lengths
        return wd


    def _reverse(self, wd):
        """Reconstruct signals from concatenated DWT coefficients.

        Splits each slice along `self._dim` using `self.lengths` recorded
        by `_forward`, runs `pywt.waverec`, and truncates the result to
        the original number of time points.
        """
        if __debug__:
            debug('MAP', "Performing iDWT")
        signal = None
        wd_offsets = [0] + list(N.cumsum(self.lengths))
        Nlevels = len(self.lengths)
        # unfortunately sometimes due to padding iDWT would return longer
        # sequences, thus we just limit to the right ones
        Ntime_points = self._intimepoints

        for indexes in _getIndexes(wd.shape, self._dim):
            if __debug__:
                debug('MAP_', " %s" % (indexes,), lf=False, cr=True)
            wd_sample = wd[indexes]
            # need to compose original list of coefficient arrays
            wd_coeffs = [wd_sample[wd_offsets[i]:wd_offsets[i+1]]
                         for i in range(Nlevels)]
            time_points = pywt.waverec(
                wd_coeffs, wavelet=self._wavelet, mode=self._mode)
            if signal is None:
                newdim = list(wd.shape)
                newdim[self._dim] = Ntime_points
                signal = N.empty(newdim)
            signal[indexes] = time_points[:Ntime_points]
        if __debug__:
            debug('MAP_', "")
            debug('MAP', "Done iDWT. Total size %s" % (signal.shape, ))
        return signal
386