#!/usr/bin/python
#emacs: -*- mode: python-mode; py-indent-offset: 4; tab-width: 4; indent-tabs-mode: nil -*-
#ex: set sts=4 ts=4 sw=4 noet:
#------------------------- =+- Python script -+= -------------------------
"""Generate NeuroSynth atlases

 COPYRIGHT: Yaroslav Halchenko 2013

 LICENSE: MIT

  Permission is hereby granted, free of charge, to any person obtaining a copy
  of this software and associated documentation files (the "Software"), to deal
  in the Software without restriction, including without limitation the rights
  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  copies of the Software, and to permit persons to whom the Software is
  furnished to do so, subject to the following conditions:

  The above copyright notice and this permission notice shall be included in
  all copies or substantial portions of the Software.

  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  THE SOFTWARE.
"""
#-----------------\____________________________________/------------------

__author__ = 'Yaroslav Halchenko'
__copyright__ = 'Copyright (c) 2013 Yaroslav Halchenko'
__license__ = 'MIT'
__version__ = "0.1"

import os
import cPickle
from lxml import etree                  # To load/save .xml atlases

import numpy as np
import nibabel as nib                   # To load/save nifti volumes

import neurosynth
from neurosynth.base.dataset import Dataset
from neurosynth.analysis import meta

def get_header(hdr, data, colormap=None):
    """Return a copy of `hdr` with display-range fields set for `data`.

    Parameters
    ----------
    hdr : mapping-like
      Header to start from.  It must provide ``copy()`` and item
      assignment; the object passed in is not modified.
    data : array-like
      Values whose maximum (``np.max``) is stored as ``cal_max``.
      ``cal_min`` is always set to 0.
    colormap : str, optional
      If given (and non-empty), stored in the ``aux_file`` field.

    Returns
    -------
    The copy of `hdr` with ``cal_min``, ``cal_max`` and, optionally,
    ``aux_file`` assigned.
    """
    nifti_hdr = hdr.copy()
    # Assign fields directly instead of looping over a temporary dict with
    # the Python-2-only dict.iteritems()
    nifti_hdr['cal_min'] = 0
    nifti_hdr['cal_max'] = np.max(data)
    if colormap:
        nifti_hdr['aux_file'] = colormap
    return nifti_hdr

def make_atlas(maps,                    #  dict of resolution: map
               labels,
               fullname,
               name=None,
               shortname=None,
               imagefilename=None,
               topdir='.',
               subdir='',             # where to dump volumes
               type_="Probabilistic",
               nifti_hdrs=None,
               images_thr=(),
               colormap_thr='MGH-Subcortical',
               atlas_options=None):
    """Write an atlas: an XML description plus its NIfTI volumes.

    Parameters
    ----------
    maps : dict
      Maps resolution (used verbatim in "<res>mm" file names) to an array
      whose last axis enumerates the labels.
    labels : iterable of str
      Label names; the i-th label describes ``map[..., i]``.
    fullname : str
      Text of the <name> header element.
    name : str
      Base name of the resulting ``<name>.xml`` file.  Although it defaults
      to None, it must be provided since it is concatenated with '.xml'.
    shortname : str, optional
      Not used by this function.
    imagefilename : str, optional
      Base name for the volumes.  Defaults to `name`.
    topdir : str, optional
      Directory where the XML file is stored.
    subdir : str or None, optional
      Sub-directory of `topdir` where volumes are stored (created if
      missing).  Volumes are referenced in the XML as
      ``/<subdir>/<file>``.  With None volumes are stored directly under
      `topdir` and referenced as ``/<file>``.
    type_ : {"Probabilistic", "Statistic"}, optional
      Text of the <type> header element; also selects the "prob"/"stat"
      abbreviation used in file names.
    nifti_hdrs : dict
      Maps resolution to the header used as a template (see `get_header`)
      for the volumes of that resolution.  Must contain every key of
      `maps`.
    images_thr : sequence of numbers, optional
      For each threshold a summary volume is stored, holding per voxel the
      1-based index of the label with the maximal value, or 0 where that
      maximum is below the threshold.
    colormap_thr : str, optional
      ``aux_file`` to assign in the headers of the summary volumes.
    atlas_options : dict, optional
      Additional header elements: ``{tag: value}``; values are converted
      with ``str``.

    Returns
    -------
    str
      Path to the stored XML file.

    Raises
    ------
    KeyError
      If `type_` is not one of the supported types.
    """
    if atlas_options is None:
        atlas_options = {}

    # my guess that we better stay more consistent and first list low-res
    # one since that is the one which would be used for the coordinates
    resolutions = sorted(maps, reverse=True)

    # first generate an xml file
    SE = etree.SubElement
    root = etree.Element('atlas', attrib=dict(version='1.0'))
    header = SE(root, 'header')
    SE(header, 'name').text = fullname
    for k, v in atlas_options.items():
        SE(header, k).text = str(v)
    SE(header, 'type').text = type_

    # Directory for the volumes and the prefix they get within the XML.
    # subdir=None used to end up as a literal "None" in the file names
    if subdir is None:
        volumes_dir = topdir
        prefix = '/'
    else:
        volumes_dir = os.path.join(topdir, subdir)
        prefix = '/%s/' % subdir
    if volumes_dir and not os.path.exists(volumes_dir):
        os.makedirs(volumes_dir)

    imagefilename = imagefilename or name

    typ = {"Probabilistic": "prob",
           "Statistic": "stat"}[type_] # abbreviated version
    for res in resolutions:
        map_res = maps[res]
        hdr_res = nifti_hdrs[res]

        images = SE(header, "images")
        prob_file = '%s%s-%s-%smm' % (prefix, imagefilename, typ, res)
        SE(images, 'imagefile').text = prob_file

        # store the actual probability map
        map_ni = nib.Nifti1Image(map_res, None, get_header(hdr_res, map_res))
        map_ni.to_filename(
            os.path.join(topdir, prob_file.lstrip('/') + '.nii.gz'))

        # Generate maxprob entries/files.
        # Thresholding a full copy of the map and then taking argmax blew
        # memory limits (about 10G in peak) on a 1mm atlas, so work only
        # with the per-voxel maximum and its index.
        map_argmax = np.argmax(map_res, axis=-1)
        map_max = np.max(map_res, axis=-1)
        for thr in images_thr:
            # assign index for the maximum (+1, since 0 stands for "no
            # label") where that max passed the threshold
            map_thr = np.zeros(map_argmax.shape, dtype=int)
            over = map_max >= thr
            map_thr[over] = map_argmax[over] + 1
            # original subcortical atlas is even fancier -- there is a
            # magical
            # /usr/share/data/harvard-oxford-atlases/HarvardOxford/labels
            # providing relabeling

            # store maxprob volume; note that %d truncates a fractional
            # threshold within the file name
            thresh_file = '%s%s-max%s-thr%d-%smm' \
                          % (prefix, imagefilename, typ, thr, res)
            SE(images, 'summaryimagefile').text = thresh_file

            map_ni = nib.Nifti1Image(
                map_thr, None,
                get_header(hdr_res, map_thr, colormap_thr))
            map_ni.to_filename(
                os.path.join(topdir, thresh_file.lstrip('/') + '.nii.gz'))

    data = SE(root, 'data')
    map_ = maps[resolutions[0]]

    # now go through all the labels
    for index, label in enumerate(labels):
        # figure out representative x,y,z: let's take only the 'maximum'
        # voxels
        label_map = map_[..., index]
        nz = np.array(np.where(label_map >= np.max(label_map)))
        # figure out center and closest point within the ROI actually
        center = np.mean(nz, axis=1)
        dnz = nz - center[:, None]
        distances = np.sum(dnz * dnz, axis=0)
        x, y, z = nz[:, np.argmin(distances)]
        label_el = SE(data, 'label')
        # set attributes one by one to place them in a defined order
        for attr, value in (('index', index), ('x', x), ('y', y), ('z', z)):
            label_el.set(attr, str(value))
        label_el.text = label

    # dump XML atlas.  tostring() with an explicit encoding provides an
    # encoded byte string, so write in binary mode and close the file
    filename = os.path.join(topdir, name + '.xml')
    with open(filename, 'wb') as f:
        f.write(etree.tostring(root, pretty_print=True, encoding='ISO-8859-1'))
    return filename

def get_result_volume(data, header, volume):
    """Wrap masked results `data` into a NIfTI image.

    A copy of `header` (the original is left alone) gets ``cal_min`` and
    ``cal_max`` set to the minimum and maximum of `data`, ``aux_file`` set
    to 'Red-Yellow' and its data dtype set to int16.  `data` is expanded
    with ``volume.unmask`` and returned, together with that header, as a
    ``nib.Nifti1Image``.
    """
    hdr = header.copy()
    # TODO: still not sure if this shouldn't just be posterior probabilities
    #       instead of zscores
    fields = (('cal_min', np.min(data)),
              ('cal_max', np.max(data)),
              ('aux_file', 'Red-Yellow'))
    for field, value in fields:
        hdr[field] = value
    # nibabel should figure out scaling/offset from the dtype so an integer
    # type should still work; int16 should allow sufficient precision while
    # not wasting memory, and fslview does not crash on it as it does with
    # int8.  On the other hand int16 makes things a bit more complicated:
    # absolute zeros are now absent (due to intercept/slope), and I think
    # that is what confuses fslview's histogram viewer.  float32 would be
    # the alternative.
    hdr.set_data_dtype(np.int16)
    unmasked = volume.unmask(data)
    return nib.Nifti1Image(unmasked, None, header=hdr)

if __name__ == "__main__":
    # Script entry point: load a NeuroSynth dataset, run a meta-analysis
    # for each selected term and store one atlas per image type

    # Cmdline options
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate FSL atlases from a provided NeuroSynth dataset",
        add_help=True,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    # ArgumentParser(version=...) is deprecated and is not accepted by
    # Python 3; an explicit "version" action keeps the -v/--version options
    parser.add_argument('-v', '--version', action='version',
                        version=__version__)

    parser.add_argument('-d', '--database-file',
                        default="database.txt", help="database.txt CSV")
    parser.add_argument('-f', '--features-file',
                        default="features.txt", help="features.txt CSV")
    parser.add_argument('-n', '--n-top-features', type=int,
                        default=None,
                        help="Number of top (by papers number) features/terms "
                        "to consider.  Default: None which corresponds to all "
                        "papers")
    parser.add_argument('-o', '--output-directory',
                        default='.',
                        help="Output directory")

    args = parser.parse_args()

    # Some additional hardcoded for now options
    q = .05
    min_studies = .01       # has to be present in 1% of studies
    # atlas name: (name of the image in MetaAnalysis.images,
    #              (min, max) range the values are expected to be in)
    image_types = {
        #'pFgA':              ('posterior',   (0, 1)),
        'pFgA_z_FDR_%s' % q: ('posterior_z', (-20, 20)),
        'pAgF_z_FDR_%s' % q: ('prior_z',     (-40, 40)),
        #'pA':                ('prior_article', (0, 1)),
    }
    threshold = 0.001

    # Load the dataset
    dataset = Dataset(args.database_file)
    dataset.add_features(args.features_file)

    # Figure out top features (if not all)
    if args.n_top_features:
        #feature_counts = (dataset.feature_table.data > .001).sum(axis=1)
        feature_counts = dataset.get_feature_counts(threshold=threshold).items()
        feature_counts = sorted(feature_counts, key=lambda x: x[1], reverse=True)

        top_features = [x[0] for x in feature_counts[:args.n_top_features]]
    else:
        top_features = dataset.get_feature_names()

    # and now sort them alphabetically for a sanity of the user
    top_features = sorted(top_features)

    # Prepare result storage and run meta-analysis per each term (AKA feature)
    result_shape = (dataset.volume.num_vox_in_mask, len(top_features))
    result = {image_type: np.zeros(result_shape) for image_type in image_types}
    for i, f in enumerate(top_features):
        ids = dataset.get_ids_by_features(f, threshold=threshold)
        ma = meta.MetaAnalysis(dataset, ids, q=q, min_studies=min_studies)
        for image_type, values in result.items():
            values[:, i] = ma.images[image_type]

    # Sanity check: all values must lie within the expected range.
    # (the upper bound used to be compared against the minimum)
    for name, r in result.items():
        lower, upper = image_types[name][1]
        assert np.min(r) >= lower
        assert np.max(r) <= upper

    # and tune up for pA since it is the same across all features
    if 'pA' in result:
        result['pA'] = result['pA'][:, 0] * 100.

    nfeatures = len(top_features)
    if nfeatures != len(dataset.get_feature_names()):
        filesuffix = "-top%d" % (nfeatures,)
        atlassuffix = " %d top terms" % (nfeatures,)
    else:
        filesuffix = ""
        atlassuffix = " all %d terms" % (nfeatures,)

    # Tune up and save per each interesting statistic
    # name = 'pFgA_z_FDR_%s' % q
    for name in image_types:
        r = result[name]
        r_volume = get_result_volume(r,
                                     header=dataset.volume.get_header(),
                                     volume=dataset.volume)
        # range of absolute non-zero values to be reported in the atlas
        rabs = np.abs(r[np.nonzero(r)])
        rminabs = np.min(rabs)
        rmaxabs = np.max(rabs)
        del rabs

        image_type_ = image_types[name][0]

        # NOTE(review): get_data()/get_header() of nibabel images are
        # deprecated in newer nibabel releases -- confirm against the
        # nibabel version in use before changing
        statmaps = r_volume.get_data()
        if '_z' in name:
            atlas_kwargs = dict(
                type_="Statistic",
                images_thr=[2.3],
                atlas_options=dict(
                    precision=2,
                    statistic="z",
                    units=""))
            features = top_features
        else:
            assert(name == 'pA')
            # This is not good anyways -- makes fslview crash, so do
            # not use for now.  If ever enabled it should become a
            # "Probabilistic" atlas without thresholded images, with a
            # single 'Articles prior' label and a rudimentary 4th dimension
            # added to statmaps.
            raise NotImplementedError("fslview pukes on this one")

        atlas_kwargs['atlas_options'].update(
            shortname="NST%s" % filesuffix,
            description=
                "Generated using NeuroSynth %s using dataset from 2013-03-28 "
                "with following analysis parameters: "
                "%s FDR=%.2f threshold=%g min_studies=%.2f"
                % (neurosynth.__version__, atlassuffix, q, threshold, min_studies),
            lower="%.2f" % rminabs,
            upper="%.2f" % rmaxabs,
            )

        make_atlas(
            {2: statmaps},
            features,
            "NeuroSynth (http://neurosynth.org)%s (%s)" \
              % (atlassuffix, image_type_),
            name="NeuroSynth%s-%s" % (filesuffix, image_type_),
            imagefilename="NeuroSynth%s-%s" % (filesuffix, image_type_),
            topdir=args.output_directory,
            subdir='NeuroSynth',
            nifti_hdrs={2: r_volume.get_header()},
            **atlas_kwargs
            )
