#!/usr/bin/python

import sys
import os
import tempfile
import shutil
import datetime
import subprocess
import xml.dom.minidom
import HTMLParser
import pyxnat

def attr_value(attrs, name, default=None):
    """Look up an attribute in a list of (name, value) pairs.

    attrs   -- iterable of (name, value) 2-tuples
    name    -- attribute name to search for
    default -- value returned when no pair carries that name

    Returns the value of the first pair whose name equals `name`, or
    `default` if there is none.
    """
    matching_values = (value for (key, value) in attrs if key == name)
    return next(matching_values, default)

class HTMLTable:
    """In-memory grid built up from <tr>/<td> parse events.

    The owner calls tr()/td()/data()/end_td()/end_tr() as the matching
    tags and text are encountered.  Cells spanning several rows or columns
    (rowspan/colspan) are stored once per grid position they cover, so
    every position can be addressed as self.cells[row][col].

    Iterating over the table yields one list per row index from 0 to the
    highest row seen; each list has one entry per column index from 0 to
    the highest column seen, with None for positions that hold no cell.
    """

    def __init__(self, id):
        """Create an empty table; `id` is stored unchanged as self.id."""
        self.id = id
        # self.cells[row][col] = data
        self.cells = {}
        # index of the row/column currently being filled (None before the
        # first <tr> / outside a row)
        self.row = None
        self.col = None
        # span and accumulated text of the cell currently being read
        self.rowspan = None
        self.colspan = None
        self.cell = None
        return

    def __iter__(self):
        """Yield each row as a list of cell texts (None where missing).

        An empty table yields nothing, and a row index that never received
        a cell yields a list containing only None values.
        """
        if not self.cells:
            return
        max_row = max(self.cells)
        max_col = -1
        for cols in self.cells.values():
            if cols:
                max_col = max(max_col, max(cols))
        for row in range(max_row + 1):
            # a <tr> without any <td> leaves no entry in self.cells
            row_cells = self.cells.get(row, {})
            yield [row_cells.get(col) for col in range(max_col + 1)]
        return

    def tr(self):
        """Start a new row: the first row is 0, later rows count upward."""
        if self.row is None:
            self.row = 0
        else:
            self.row += 1
        return

    def td(self, attrs):
        """Start a new cell in the current row.

        attrs -- list of (name, value) pairs of the <td> tag; 'rowspan'
                 and 'colspan' are read from it and default to 1.

        The cell is placed in the first column of the current row that is
        not already occupied (by an earlier cell of this row or by a cell
        spanning down from a previous row).
        """
        self.rowspan = int(attr_value(attrs, 'rowspan', 1))
        self.colspan = int(attr_value(attrs, 'colspan', 1))
        self.cell = ''
        if self.row not in self.cells:
            # nothing stored for this row yet, so start at the left edge
            self.col = 0
        else:
            if self.col is None:
                self.col = 0
            occupied = self.cells[self.row]
            while self.col in occupied:
                self.col += 1
        return

    def data(self, data):
        """Append text to the cell currently being read."""
        self.cell += data
        return

    def end_td(self):
        """Store the finished cell at every grid position it spans."""
        for r in range(self.row, self.row + self.rowspan):
            row_cells = self.cells.setdefault(r, {})
            for c in range(self.col, self.col + self.colspan):
                row_cells[c] = self.cell
        self.rowspan = None
        self.colspan = None
        self.cell = None
        return

    def end_tr(self):
        """Finish the current row by resetting the column position."""
        self.col = None
        return

class BIRNParser(HTMLParser.HTMLParser):
    """HTML parser that collects values from the tables of a QA report.

    The parser is a small state machine (self.state is None, 'table',
    'tr' or 'td') that feeds <table>/<tr>/<td> events into an HTMLTable.
    When a table is closed, read_table() interprets it according to its
    id attribute:

    * 'table_top_summary'  -- individual values are stored as attributes
      on the parser (input_*, masked_*, md_*), see the field tables below;
    * 'table_<name>_summary' -- self.data[name]['abs_mean'/'rel_mean'];
    * 'qa_data_<name>'     -- self.data[name]['data'] (list of rows).

    Attributes are only set for rows actually present in the document.
    """

    # Field tables for the rows of 'table_top_summary'.  Keys are the text
    # of column 1; a string value is the attribute that receives column 3,
    # a dict value selects the attribute by the text of column 2.
    _INPUT_FIELDS = {
        '# potentially-clipped voxels': 'input_pcv',
        '# vols. with mean intensity abs. z-score > 3': {
            'individual': 'input_nvmiaz3_ind',
            'rel. to grand mean': 'input_nvmiaz3_rgm',
        },
        '# vols. with mean intensity abs. z-score > 4': {
            'individual': 'input_nvmiaz4_ind',
            'rel. to grand mean': 'input_nvmiaz4_rgm',
        },
        '# vols. with mean volume difference > 1%': 'input_nvmvd1',
        '# vols. with mean volume difference > 2%': 'input_nvmvd2',
    }

    _MASKED_FIELDS = {
        'mean FWHM': {
            'X': 'masked_fwhm_x',
            'Y': 'masked_fwhm_y',
            'Z': 'masked_fwhm_z',
        },
    }

    _MD_FIELDS = {
        '# vols. with mean intensity abs. z-score > 3': {
            'individual': 'md_nvmiaz3_ind',
            'rel. to grand mean': 'md_nvmiaz3_rgm',
        },
        '# vols. with mean intensity abs. z-score > 4': {
            'individual': 'md_nvmiaz4_ind',
            'rel. to grand mean': 'md_nvmiaz4_rgm',
        },
        '# vols. with running difference > 1%': 'md_nvrd1',
        '# vols. with running difference > 2%': 'md_nvrd2',
        '# vols. with > 1% outlier voxels': 'md_nv1ov',
        '# vols. with > 2% outlier voxels': 'md_nv2ov',
        'mean (ROI in middle slice)': 'md_mroims',
        'mean SNR (ROI in middle slice)': 'md_msnroims',
        'mean SFNR (ROI in middle slice)': 'md_msfnrroims',
    }

    def __init__(self):
        HTMLParser.HTMLParser.__init__(self)
        self.state = None
        self.table = None
        # self.data[name] = {'abs_mean': ..., 'rel_mean': ..., 'data': ...}
        self.data = {}
        return

    def handle_starttag(self, tag, attrs):
        """Advance the state machine on <table>, <tr> and <td>."""
        # occasionally a </td> is missing, so we get <td>data<td>more data;
        # handle that case here (simulate the missing </td>)
        if self.state == 'td' and tag == 'td':
            self.handle_endtag('td')
        if not self.state:
            if tag == 'table':
                self.table = HTMLTable(attr_value(attrs, 'id'))
                self.state = 'table'
        elif self.state == 'table':
            if tag == 'tr':
                self.table.tr()
                self.state = 'tr'
        elif self.state == 'tr':
            if tag == 'td':
                self.table.td(attrs)
                self.state = 'td'
        return

    def handle_data(self, data):
        """Collect text, but only while inside a table cell."""
        if self.state == 'td':
            self.table.data(data)
        return

    def handle_endtag(self, tag):
        """Step the state machine back; a closed table is interpreted."""
        if self.state == 'table':
            if tag == 'table':
                self.read_table()
                self.table = None
                self.state = None
        elif self.state == 'tr':
            if tag == 'tr':
                self.table.end_tr()
                self.state = 'table'
        elif self.state == 'td':
            if tag == 'td':
                self.table.end_td()
                self.state = 'tr'
        return

    def read_table(self):
        """Dispatch the just-finished table according to its id."""
        table_id = self.table.id
        if table_id == 'table_top_summary':
            for row in self.table:
                if row[0] == 'input':
                    self.take_summary_input_value(row)
                elif row[0] == 'masked':
                    self.take_summary_masked_value(row)
                elif row[0] == 'masked, detrended':
                    self.take_summary_md_value(row)
        elif table_id is None:
            # tables without an id carry nothing we extract
            pass
        elif table_id.startswith('table_') and \
             table_id.endswith('_summary'):
            # strip the 'table_' prefix and the '_summary' suffix
            self.read_summary_table(table_id[6:-8])
        elif table_id.startswith('qa_data_'):
            self.read_data_table(table_id[8:])
        return

    def read_summary_table(self, name):
        """Store the absolute/relative mean of the current table.

        The values end up in self.data[name]['abs_mean'] and
        self.data[name]['rel_mean'].  The label in column 1 is compared
        with surrounding whitespace removed, because the cell text may
        carry a trailing newline.  Anything already stored under `name`
        (e.g. by read_data_table) is kept.
        """
        values = self.data.setdefault(name, {})
        for row in self.table:
            if row[0] != 'Mean:':
                continue
            kind = (row[1] or '').strip()
            if kind == '(absolute)':
                values['abs_mean'] = row[2]
            elif kind == '(relative)':
                values['rel_mean'] = row[2]
        return

    def read_data_table(self, name):
        """Store all rows except the 'VOLNUM' header row.

        The rows are stored as self.data[name]['data'].  The entry for
        `name` is created when no summary table for it has been read yet.
        """
        data = [row for row in self.table if row[0] != 'VOLNUM']
        self.data.setdefault(name, {})['data'] = data
        return

    def _store_summary_value(self, row, fields):
        """Set the attribute selected by `fields` for `row`, if any.

        row[1] selects the entry in `fields`; when that entry is a dict,
        row[2] selects the attribute within it.  The attribute receives
        row[3].  Rows that match nothing are ignored.
        """
        target = fields.get(row[1])
        if isinstance(target, dict):
            target = target.get(row[2])
        if target is not None:
            setattr(self, target, row[3])
        return

    def take_summary_input_value(self, row):
        """Handle an 'input' row of the top summary table."""
        self._store_summary_value(row, self._INPUT_FIELDS)
        return

    def take_summary_masked_value(self, row):
        """Handle a 'masked' row of the top summary table."""
        self._store_summary_value(row, self._MASKED_FIELDS)
        return

    def take_summary_md_value(self, row):
        """Handle a 'masked, detrended' row of the top summary table."""
        self._store_summary_value(row, self._MD_FIELDS)
        return

def append_text_element(doc, parent, name, value):
    """Append <name>value</name> as the last child of `parent`.

    doc    -- DOM document used to create the nodes
    parent -- node that receives the new element
    name   -- tag name of the new element
    value  -- content of the element's single text node
    """
    element = doc.createElement(name)
    text_node = doc.createTextNode(value)
    element.appendChild(text_node)
    parent.appendChild(element)

def append_ind_rel_element(doc, parent, name, ind_val, rel_grand_mean_val):
    """Append an element of type NVolsIndRelData to `parent`.

    The new element is named `name` and carries two attributes:
    'individual' (ind_val) and 'rel_grand_mean' (rel_grand_mean_val).
    """
    node = doc.createElement(name)
    attributes = (('individual', ind_val),
                  ('rel_grand_mean', rel_grand_mean_val))
    for (attr_name, attr_val) in attributes:
        node.setAttribute(attr_name, attr_val)
    parent.appendChild(node)

def append_data(doc, parent, birn_parser, element_name, data_name):
    """Append an element holding the means recorded under `data_name`.

    The element is named `element_name`; its 'absolute_mean' and
    'relative_mean' attributes are taken from the 'abs_mean' and
    'rel_mean' entries of birn_parser.data[data_name].  A missing entry
    raises KeyError.
    """
    node = doc.createElement(element_name)
    summary = birn_parser.data[data_name]
    node.setAttribute('absolute_mean', summary['abs_mean'])
    node.setAttribute('relative_mean', summary['rel_mean'])
    parent.appendChild(node)

# --- command line handling and XNAT connection ---

progname = os.path.basename(sys.argv[0])

# make the AFNI binaries reachable for the external tools run later on
os.environ['PATH'] += ':/usr/lib/afni/bin'

if len(sys.argv) != 7:
    print('usage: %s <user> <password> <XNAT URL> <project> <experiment ID> <experiment label>' % progname)
    sys.exit(0)

(user,
 password,
 host,
 project_name,
 experiment_id,
 experiment_label) = sys.argv[1:]
host = host.rstrip('/')

interface = pyxnat.Interface(host, user, password)

# The result of interface.select() can be iterated but not indexed, so walk
# it and remember the (last) entry that actually exists.
experiment = None
for e in interface.select('/experiments/%s' % experiment_id):
    if not e.exists():
        continue
    experiment = e
if not experiment:
    sys.stderr.write("%s: can't find experiment %s\n" % (progname,
                                                         experiment_id))
    sys.exit(1)

# location of the experiment's files in the archive
data_dir = '/'.join(['/data/archive', project_name, 'arc001', experiment_label])

def fp_files(data_dir, scan, resource_name):
    """Return the catalogued files of a scan resource as full paths.

    Each name returned by files() is prefixed with
    '<data_dir>/SCANS/<scan label>/<resource_name>/'.  Exceptions raised
    by files() (ValueError for a missing resource) propagate unchanged.
    """
    res_dir = '%s/SCANS/%s/%s' % (data_dir, scan.label(), resource_name)
    return ['%s/%s' % (res_dir, fname)
            for fname in files(data_dir, scan, resource_name)]

def files(data_dir, scan, resource_name):
    """Return the file names listed in a scan resource's catalog.

    data_dir      -- directory of the experiment in the archive
    scan          -- scan object providing label() and resource(name)
    resource_name -- name of the resource, e.g. 'DICOM'

    The catalog is expected at
    '<data_dir>/SCANS/<label>/<resource_name>/scan_<label>_catalog.xml';
    its path is printed.  Returns the non-empty 'URI' attributes of all
    'cat:entry' elements in document order, or an empty list when the
    catalog file does not exist.

    Raises ValueError when the scan has no such resource.
    """
    resource = scan.resource(resource_name)
    if not resource.exists():
        raise ValueError('scan %s has no resource %s' % (scan.label(),
                                                         resource_name))
    scan_label = scan.label()
    catalog_fname = '%s/SCANS/%s/%s/scan_%s_catalog.xml' % (data_dir,
                                                            scan_label,
                                                            resource_name,
                                                            scan_label)
    print(catalog_fname)
    if not os.path.exists(catalog_fname):
        return []
    doc = xml.dom.minidom.parse(catalog_fname)
    uris = [el.getAttribute('URI')
            for el in doc.getElementsByTagName('cat:entry')]
    # getAttribute() gives '' for entries without a URI; drop those
    return [uri for uri in uris if uri]

class QATemporaryDirectory:
    """Context manager providing a scratch directory for one run.

    On entry a fresh temporary directory is created (self.path) and a
    file named 'info' holding the script name (sys.argv[0]) and the
    current time is written into it.  On exit the directory is removed
    unless self.clean has been set to False, and its path is printed.
    Exceptions raised inside the with-block are not suppressed.
    """

    def __enter__(self):
        self.path = tempfile.mkdtemp()
        self.clean = True
        try:
            # 'with' closes the file even when a write fails
            with open('%s/info' % self.path, 'w') as fo:
                fo.write('%s\n' % sys.argv[0])
                fo.write('%s\n' % str(datetime.datetime.now()))
        except Exception:
            # __exit__ is not called when __enter__ raises, so remove the
            # directory here instead of leaving it behind
            shutil.rmtree(self.path, ignore_errors=True)
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.clean:
            shutil.rmtree(self.path)
        print('leaving %s' % self.path)
        return

def _run_tool(args):
    """Run an external command; return True if it exited with status 0.

    A non-zero exit status is reported by printing the CalledProcessError
    and returning False, so the caller can skip the current scan.
    """
    try:
        subprocess.check_call(args)
    except subprocess.CalledProcessError as data:
        print(str(data))
        return False
    return True


def _build_assessment_doc(birn_parser, experiment_id, scan_label):
    """Build the incf:BIRNfMRIQA document from the parsed QA values.

    birn_parser   -- BIRNParser that has been fed the QA index.html
    experiment_id -- written to the xnat:imageSession_ID element
    scan_label    -- written to the source_scan element

    Returns the xml.dom.minidom Document.  Values the parser did not find
    raise AttributeError (summary attributes) or KeyError (per-table
    means).  Child elements are appended in a fixed order.
    """
    doc = xml.dom.minidom.Document()
    root = doc.createElement('incf:BIRNfMRIQA')
    doc.appendChild(root)
    root.setAttribute('xmlns:incf', 'http://xnat.incf.org/xnat')
    root.setAttribute('xmlns:xnat', 'http://nrg.wustl.edu/xnat')
    root.setAttribute('ID', '')

    append_text_element(doc, root, 'xnat:imageSession_ID', experiment_id)
    append_text_element(doc, root, 'source_scan', scan_label)

    # --- values computed on the unprocessed input ---
    input_el = doc.createElement('Input')
    root.appendChild(input_el)
    append_text_element(doc,
                        input_el,
                        'PotentiallyClippedVoxels',
                        birn_parser.input_pcv)
    append_ind_rel_element(doc,
                           input_el,
                           'NVolsMeanIntensityAbsZ3',
                           birn_parser.input_nvmiaz3_ind,
                           birn_parser.input_nvmiaz3_rgm)
    append_ind_rel_element(doc,
                           input_el,
                           'NVolsMeanIntensityAbsZ4',
                           birn_parser.input_nvmiaz4_ind,
                           birn_parser.input_nvmiaz4_rgm)
    append_text_element(doc,
                        input_el,
                        'NVolsMeanVolDifference1',
                        birn_parser.input_nvmvd1)
    append_text_element(doc,
                        input_el,
                        'NVolsMeanVolDifference2',
                        birn_parser.input_nvmvd2)
    for (element_name, data_name) in (('Means', 'volmeans'),
                                      ('MeanMeanVolDiff', 'mdiffvolmeans'),
                                      ('CenterOfMassX', 'cmassx'),
                                      ('CenterOfMassY', 'cmassy'),
                                      ('CenterOfMassZ', 'cmassz')):
        append_data(doc, input_el, birn_parser, element_name, data_name)

    # --- values computed on the masked data ---
    masked = doc.createElement('Masked')
    root.appendChild(masked)
    masked_fwhm = doc.createElement('MeanFWHM')
    masked_fwhm.setAttribute('x', birn_parser.masked_fwhm_x)
    masked_fwhm.setAttribute('y', birn_parser.masked_fwhm_y)
    masked_fwhm.setAttribute('z', birn_parser.masked_fwhm_z)
    masked.appendChild(masked_fwhm)
    for (element_name, data_name) in (('FWHMX', 'FWHMx-X'),
                                      ('FWHMY', 'FWHMx-Y'),
                                      ('FWHMZ', 'FWHMx-Z')):
        append_data(doc, masked, birn_parser, element_name, data_name)

    # --- values computed on the masked, detrended data ---
    md = doc.createElement('MaskedDetrended')
    root.appendChild(md)
    append_ind_rel_element(doc,
                           md,
                           'NVolsMeanIntensityAbsZ3',
                           birn_parser.md_nvmiaz3_ind,
                           birn_parser.md_nvmiaz3_rgm)
    append_ind_rel_element(doc,
                           md,
                           'NVolsMeanIntensityAbsZ4',
                           birn_parser.md_nvmiaz4_ind,
                           birn_parser.md_nvmiaz4_rgm)
    for (element_name, attr_name) in (
            ('NVolsRunningDifference1', 'md_nvrd1'),
            ('NVolsRunningDifference2', 'md_nvrd2'),
            ('NVols1OutlierVoxels', 'md_nv1ov'),
            ('NVols2OutlierVoxels', 'md_nv2ov'),
            ('MeanROIMiddleSlice', 'md_mroims'),
            ('MeanSNRROIMiddleSlice', 'md_msnroims'),
            ('MeanSFNRROIMiddleSlice', 'md_msfnrroims')):
        append_text_element(doc,
                            md,
                            element_name,
                            getattr(birn_parser, attr_name))
    for (element_name, data_name) in (
            ('Means', 'maskedvolmeans'),
            ('RunningDifferenceMeans', 'maskedtdiffvolmeans'),
            ('OutlierVoxelPcts', 'outliercount'),
            ('CenterOfMassX', 'maskedcmassx'),
            ('CenterOfMassY', 'maskedcmassy'),
            ('CenterOfMassZ', 'maskedcmassz'),
            ('FrequencySpectrumMeanOverMask', 'spectrummean'),
            ('FrequencySpectrumMaxOverMask', 'spectrummax')):
        append_data(doc, md, birn_parser, element_name, data_name)

    return doc


def _upload_tree(resource, output_dir):
    """Upload every file below output_dir to `resource`.

    Files keep their path relative to output_dir.  They are read in
    binary mode so that non-text files are passed on unmodified.
    """
    for (path, dirs, fnames) in os.walk(output_dir):
        for fname in fnames:
            full_path = '%s/%s' % (path, fname)
            rel_path = full_path[len(output_dir)+1:]
            print('uploading %s...' % rel_path)
            with open(full_path, 'rb') as fo:
                content = fo.read()
            resource.file(rel_path).put(content)
    return


with QATemporaryDirectory() as co:
    # Debugging aid: set co.clean = False here to keep the temporary
    # directory after the run.
    for scan in experiment.scans():
        scan_label = scan.label()
        print(scan_label)
        assessment_name = 'BIRNQA-%s' % scan_label
        assessment = experiment.assessor(assessment_name)
        if assessment.exists():
            print('assessment %s exists; skipping' % assessment_name)
            continue
        try:
            dicom_files = fp_files(data_dir, scan, 'DICOM')
        except ValueError:
            # the scan has no DICOM resource, so there is nothing to assess
            continue
        scan_tempdir = '%s/%s' % (co.path, scan_label)
        os.mkdir(scan_tempdir)
        xcede_fname = '%s/scan.xcede' % scan_tempdir
        output_dir = '%s/qa' % scan_tempdir

        # step 1: wrap the DICOM files in an XCEDE description
        args = ['/usr/local/bxh_xcede_tools/bin/dicom2bxh', '--xcede']
        args.extend(dicom_files)
        args.append(xcede_fname)
        if not _run_tool(args):
            continue

        # step 2: generate the QA report (HTML) into output_dir
        args = ['/usr/local/bxh_xcede_tools/bin/fmriqa_generate.pl',
                '--verbose',
                xcede_fname,
                output_dir]
        if not _run_tool(args):
            continue

        # step 3: scrape the report and turn it into assessment XML
        birn_parser = BIRNParser()
        with open('%s/index.html' % output_dir) as fo:
            birn_parser.feed(fo.read())

        doc = _build_assessment_doc(birn_parser, experiment_id, scan_label)

        assessment_xml = '%s/assessment.xml' % scan_tempdir
        print('creating xml %s...' % assessment_xml)
        with open(assessment_xml, 'w') as fo:
            fo.write(doc.toxml())

        # step 4: create the assessment and attach the report files
        print('creating assessment %s...' % assessment_name)
        assessment.create(xml=assessment_xml)

        resource = assessment.out_resource('HTML')
        resource.create()
        _upload_tree(resource, output_dir)

sys.exit(0)

# eof
