#!/usr/bin/python

import sys
import os
import xml.dom.minidom
import HTMLParser

def attr_value(attrs, name, default=None):
    """Return the value paired with name in attrs, or default.

    attrs is an iterable of (name, value) pairs.  When name occurs more
    than once, the value of the first occurrence is returned.
    """
    matching_values = (value for (key, value) in attrs if key == name)
    return next(matching_values, default)

class HTMLTable:
    """Grid of cell text collected from one HTML table.

    The owner drives the object with tr(), td(), data(), end_td() and
    end_tr() as the matching tags are met.  Cells are stored sparsely in
    self.cells[row][col]; a cell with a rowspan/colspan is copied into
    every grid position it covers.  Iterating the table yields one list
    per row, with None in positions where no cell was recorded.
    """

    def __init__(self, id):
        """Create an empty table; id is stored unchanged as self.id."""
        self.id = id
        # self.cells[row][col] = data
        self.cells = {}
        # index of the row being filled (None until the first tr())
        self.row = None
        # column of the current cell (None until the first td() of a row)
        self.col = None
        # spans and accumulated text of the cell being filled
        # (None while no cell is open)
        self.rowspan = None
        self.colspan = None
        self.cell = None
        return

    def __iter__(self):
        """Yield each row as a list of equal length.

        Rows run from 0 to the highest recorded row index and columns
        from 0 to the highest recorded column index; positions without a
        cell are None.  A table with no cells yields nothing.
        """
        if not self.cells:
            return
        max_row = max(self.cells)
        max_col = 0
        # Only look at rows that actually hold cells: a row index can be
        # absent from self.cells (a row that never received a cell).
        for filled in self.cells.values():
            if filled:
                max_col = max(max_col, max(filled))
        for row in range(max_row + 1):
            filled = self.cells.get(row, {})
            yield [filled.get(col) for col in range(max_col + 1)]
        return

    def tr(self):
        """Advance to the next row (the first call selects row 0)."""
        if self.row is None:
            self.row = 0
        else:
            self.row += 1
        return

    def td(self, attrs):
        """Open a new cell in the current row.

        attrs is a list of (name, value) pairs; 'rowspan' and 'colspan'
        are read from it (default 1) and converted with int(), which
        raises if a value is not an integer literal.  The cell is placed
        in the first column of the current row that is not already
        occupied, e.g. by a cell spanning down from an earlier row.
        """
        self.rowspan = int(attr_value(attrs, 'rowspan', 1))
        self.colspan = int(attr_value(attrs, 'colspan', 1))
        self.cell = ''
        occupied = self.cells.get(self.row)
        if occupied is None:
            # nothing recorded for this row yet
            self.col = 0
            return
        if self.col is None:
            self.col = 0
        # skip positions that already hold a cell
        while self.col in occupied:
            self.col += 1
        return

    def data(self, data):
        """Append text to the cell opened by td()."""
        self.cell += data
        return

    def end_td(self):
        """Store the open cell in every position it spans and close it."""
        for r in range(self.row, self.row + self.rowspan):
            for c in range(self.col, self.col + self.colspan):
                self.cells.setdefault(r, {})[c] = self.cell
        self.rowspan = None
        self.colspan = None
        self.cell = None
        return

    def end_tr(self):
        """Close the current row so the next row starts at column 0."""
        self.col = None
        return

class BIRNParser(HTMLParser.HTMLParser):
    """HTML parser that collects values from the tables of a QA report.

    Each <table> is gathered into an HTMLTable and interpreted by
    read_table() when its closing tag is seen.  Results end up in two
    places:

    * self.data[name] -- a dict per named table holding 'abs_mean' and
      'rel_mean' (from 'table_<name>_summary') and 'data' (from
      'qa_data_<name>').
    * individual attributes such as self.input_pcv, set from the rows of
      the table whose id is 'table_top_summary'.  An attribute is only
      created when its row is present in the document.

    self.state tracks the innermost open element of interest: None,
    'table', 'tr' or 'td'.
    """

    # Maps for the 'table_top_summary' rows.  Keys are the text of the
    # second column; a value is either the attribute that receives the
    # fourth column, or a dict keyed by the third column's text giving
    # that attribute.
    _INPUT_FIELDS = {
        '# potentially-clipped voxels': 'input_pcv',
        '# vols. with mean intensity abs. z-score > 3': {
            'individual': 'input_nvmiaz3_ind',
            'rel. to grand mean': 'input_nvmiaz3_rgm',
        },
        '# vols. with mean intensity abs. z-score > 4': {
            'individual': 'input_nvmiaz4_ind',
            'rel. to grand mean': 'input_nvmiaz4_rgm',
        },
        '# vols. with mean volume difference > 1%': 'input_nvmvd1',
        '# vols. with mean volume difference > 2%': 'input_nvmvd2',
    }

    _MASKED_FIELDS = {
        'mean FWHM': {
            'X': 'masked_fwhm_x',
            'Y': 'masked_fwhm_y',
            'Z': 'masked_fwhm_z',
        },
    }

    _MD_FIELDS = {
        '# vols. with mean intensity abs. z-score > 3': {
            'individual': 'md_nvmiaz3_ind',
            'rel. to grand mean': 'md_nvmiaz3_rgm',
        },
        '# vols. with mean intensity abs. z-score > 4': {
            'individual': 'md_nvmiaz4_ind',
            'rel. to grand mean': 'md_nvmiaz4_rgm',
        },
        '# vols. with running difference > 1%': 'md_nvrd1',
        '# vols. with running difference > 2%': 'md_nvrd2',
        '# vols. with > 1% outlier voxels': 'md_nv1ov',
        '# vols. with > 2% outlier voxels': 'md_nv2ov',
        'mean (ROI in middle slice)': 'md_mroims',
        'mean SNR (ROI in middle slice)': 'md_msnroims',
        'mean SFNR (ROI in middle slice)': 'md_msfnrroims',
    }

    def __init__(self):
        """Set up the base parser with no table open and no data read."""
        HTMLParser.HTMLParser.__init__(self)
        self.state = None
        self.table = None
        self.data = {}
        return

    def handle_starttag(self, tag, attrs):
        """Open a table, row or cell when the tag fits the current state."""
        # occasionally a </td> is missing, so we get <td>data<td>more data;
        # handle that case here (simulate the missing </td>)
        if self.state == 'td' and tag == 'td':
            self.handle_endtag('td')
        if not self.state:
            if tag == 'table':
                self.table = HTMLTable(attr_value(attrs, 'id'))
                self.state = 'table'
        elif self.state == 'table':
            if tag == 'tr':
                self.table.tr()
                self.state = 'tr'
        elif self.state == 'tr':
            if tag == 'td':
                self.table.td(attrs)
                self.state = 'td'
        return

    def handle_data(self, data):
        """Add text to the open cell; text outside a cell is ignored."""
        if self.state == 'td':
            self.table.data(data)
        return

    def handle_endtag(self, tag):
        """Close the innermost open element if tag matches it.

        Closing a table hands it to read_table() and then discards it.
        """
        if self.state == 'table':
            if tag == 'table':
                self.read_table()
                self.table = None
                self.state = None
        elif self.state == 'tr':
            if tag == 'tr':
                self.table.end_tr()
                self.state = 'table'
        elif self.state == 'td':
            if tag == 'td':
                self.table.end_td()
                self.state = 'tr'
        return

    def read_table(self):
        """Interpret the finished table according to its id.

        'table_top_summary' rows are dispatched on their first column;
        'table_<name>_summary' and 'qa_data_<name>' tables are stored
        under <name> in self.data.  Tables with no id or any other id
        are ignored.
        """
        table_id = self.table.id
        if table_id == 'table_top_summary':
            for row in self.table:
                if row[0] == 'input':
                    self.take_summary_input_value(row)
                elif row[0] == 'masked':
                    self.take_summary_masked_value(row)
                elif row[0] == 'masked, detrended':
                    self.take_summary_md_value(row)
        elif table_id is None:
            pass
        elif table_id.startswith('table_') and table_id.endswith('_summary'):
            # strip the 'table_' prefix and '_summary' suffix
            self.read_summary_table(table_id[6:-8])
        elif table_id.startswith('qa_data_'):
            # strip the 'qa_data_' prefix
            self.read_data_table(table_id[8:])
        return

    def read_summary_table(self, name):
        """Store the absolute and relative means of the current table.

        Rows whose first column is 'Mean:' are examined; the second
        column selects 'abs_mean' or 'rel_mean' and the third column is
        the value.  Whitespace around the second column is ignored, so
        '(absolute)' matches with or without a trailing newline.  Values
        already stored under name (such as 'data') are kept.
        """
        stats = self.data.setdefault(name, {})
        for row in self.table:
            if row[0] != 'Mean:':
                continue
            qualifier = (row[1] or '').strip()
            if qualifier == '(absolute)':
                stats['abs_mean'] = row[2]
            elif qualifier == '(relative)':
                stats['rel_mean'] = row[2]
        return

    def read_data_table(self, name):
        """Store the rows of the current table as self.data[name]['data'].

        Rows whose first column is 'VOLNUM' (the header) are left out.
        The entry for name is created if no summary table has created it
        yet, so the order of the tables in the document does not matter.
        """
        rows = [row for row in self.table if row[0] != 'VOLNUM']
        self.data.setdefault(name, {})['data'] = rows
        return

    def _store_summary_value(self, fields, row):
        """Set the attribute that fields selects for row, if there is one.

        row[1] is looked up in fields; when that yields a nested dict,
        row[2] is looked up in it.  The attribute found receives row[3].
        Rows with no matching entry are ignored.
        """
        target = fields.get(row[1])
        if isinstance(target, dict):
            target = target.get(row[2])
        if target is not None:
            setattr(self, target, row[3])
        return

    def take_summary_input_value(self, row):
        """Record the value of one 'input' row of the top summary table."""
        self._store_summary_value(self._INPUT_FIELDS, row)
        return

    def take_summary_masked_value(self, row):
        """Record the value of one 'masked' row of the top summary table."""
        self._store_summary_value(self._MASKED_FIELDS, row)
        return

    def take_summary_md_value(self, row):
        """Record the value of one 'masked, detrended' row of the top
        summary table."""
        self._store_summary_value(self._MD_FIELDS, row)
        return

def append_text_element(doc, parent, name, value):
    """Append <name>value</name> as the last child of parent.

    doc is the document used to create the nodes; value becomes the
    content of the new element's single text node.  Returns None.
    """
    child = doc.createElement(name)
    text_node = doc.createTextNode(value)
    child.appendChild(text_node)
    parent.appendChild(child)

def append_ind_rel_element(doc, parent, name, ind_val, rel_grand_mean_val):
    """Append an element of type NVolsIndRelData to parent.

    The new element is called name and carries ind_val in its
    'individual' attribute and rel_grand_mean_val in its
    'rel_grand_mean' attribute.  Returns None.
    """
    attributes = (
        ('individual', ind_val),
        ('rel_grand_mean', rel_grand_mean_val),
    )
    node = doc.createElement(name)
    for attr_name, attr_val in attributes:
        node.setAttribute(attr_name, attr_val)
    parent.appendChild(node)

def append_data(doc, parent, birn_parser, element_name, data_name):
    """Append an element holding the means stored under data_name.

    The values birn_parser.data[data_name]['abs_mean'] and
    birn_parser.data[data_name]['rel_mean'] become the 'absolute_mean'
    and 'relative_mean' attributes of a new element_name element, which
    is appended to parent.  A missing key raises KeyError.  Returns None.
    """
    means = birn_parser.data[data_name]
    node = doc.createElement(element_name)
    node.setAttribute('absolute_mean', means['abs_mean'])
    node.setAttribute('relative_mean', means['rel_mean'])
    parent.appendChild(node)

# Script body: parse the QA report in qa_dir and print the collected values
# as an incf:TimeSeriesQA XML document on standard output.

progname = os.path.basename(sys.argv[0])

source_scan = '1101'
qa_dir = 'QC'
session_id = 'incf_000001'

birn_parser = BIRNParser()
# read the whole report, closing the file as soon as it has been read
with open('%s/index.html' % qa_dir) as qa_index:
    birn_parser.feed(qa_index.read())

doc = xml.dom.minidom.Document()
root = doc.createElement('incf:TimeSeriesQA')
doc.appendChild(root)
root.setAttribute('xmlns:incf', 'http://xnat.incf.org/xnat')
root.setAttribute('xmlns:xnat', 'http://nrg.wustl.edu/xnat')
root.setAttribute('source_scan', source_scan)
root.setAttribute('ID', '')

append_text_element(doc, root, 'xnat:imageSession_ID', session_id)

# --- values from the 'input' rows and the unmasked per-volume tables ---
input_el = doc.createElement('Input')
root.appendChild(input_el)
append_text_element(doc,
                    input_el,
                    'PotentiallyClippedVoxels',
                    birn_parser.input_pcv)
append_ind_rel_element(doc,
                       input_el,
                       'NVolsMeanIntensityAbsZ3',
                       birn_parser.input_nvmiaz3_ind,
                       birn_parser.input_nvmiaz3_rgm)
append_ind_rel_element(doc,
                       input_el,
                       'NVolsMeanIntensityAbsZ4',
                       birn_parser.input_nvmiaz4_ind,
                       birn_parser.input_nvmiaz4_rgm)
append_text_element(doc,
                    input_el,
                    'NVolsMeanVolDifference1',
                    birn_parser.input_nvmvd1)
append_text_element(doc,
                    input_el,
                    'NVolsMeanVolDifference2',
                    birn_parser.input_nvmvd2)
# (element name, key in birn_parser.data), in output order
for element_name, data_name in (
        ('Means', 'volmeans'),
        ('MeanMeanVolDiff', 'mdiffvolmeans'),
        ('CenterOfMassX', 'cmassx'),
        ('CenterOfMassY', 'cmassy'),
        ('CenterOfMassZ', 'cmassz')):
    append_data(doc, input_el, birn_parser, element_name, data_name)

# --- values from the 'masked' rows and the FWHM tables ---
masked = doc.createElement('Masked')
root.appendChild(masked)
masked_fwhm = doc.createElement('MeanFWHM')
masked_fwhm.setAttribute('x', birn_parser.masked_fwhm_x)
masked_fwhm.setAttribute('y', birn_parser.masked_fwhm_y)
masked_fwhm.setAttribute('z', birn_parser.masked_fwhm_z)
masked.appendChild(masked_fwhm)
for element_name, data_name in (
        ('FWHMX', 'FWHMx-X'),
        ('FWHMY', 'FWHMx-Y'),
        ('FWHMZ', 'FWHMx-Z')):
    append_data(doc, masked, birn_parser, element_name, data_name)

# --- values from the 'masked, detrended' rows and the masked tables ---
md = doc.createElement('MaskedDetrended')
root.appendChild(md)
append_ind_rel_element(doc,
                       md,
                       'NVolsMeanIntensityAbsZ3',
                       birn_parser.md_nvmiaz3_ind,
                       birn_parser.md_nvmiaz3_rgm)
append_ind_rel_element(doc,
                       md,
                       'NVolsMeanIntensityAbsZ4',
                       birn_parser.md_nvmiaz4_ind,
                       birn_parser.md_nvmiaz4_rgm)
# (element name, text value), in output order
for element_name, value in (
        ('NVolsRunningDifference1', birn_parser.md_nvrd1),
        ('NVolsRunningDifference2', birn_parser.md_nvrd2),
        ('NVols1OutlierVoxels', birn_parser.md_nv1ov),
        ('NVols2OutlierVoxels', birn_parser.md_nv2ov),
        ('MeanROIMiddleSlice', birn_parser.md_mroims),
        ('MeanSNRROIMiddleSlice', birn_parser.md_msnroims),
        ('MeanSFNRROIMiddleSlice', birn_parser.md_msfnrroims)):
    append_text_element(doc, md, element_name, value)
for element_name, data_name in (
        ('Means', 'maskedvolmeans'),
        ('RunningDifferenceMeans', 'maskedtdiffvolmeans'),
        ('OutlierVoxelPcts', 'outliercount'),
        ('CenterOfMassX', 'maskedcmassx'),
        ('CenterOfMassY', 'maskedcmassy'),
        ('CenterOfMassZ', 'maskedcmassz'),
        ('FrequencySpectrumMeanOverMask', 'spectrummean'),
        ('FrequencySpectrumMaxOverMask', 'spectrummax')):
    append_data(doc, md, birn_parser, element_name, data_name)

# a single parenthesised argument prints the same text under the
# Python 2 print statement as well
print(doc.toxml())

sys.exit(0)

# eof
