#!/usr/bin/python

# See file COPYING distributed with the one_click package for the copyright
# and license.

import sys
import os
import optparse
import string
import tempfile
import shutil
import subprocess
import uuid
import StringIO
import getpass
import csv
import dicom
import httplib2

#############################################################################
# classes and functions
#

class InvalidDicomError(Exception):
    """Raised by read_dicom_file() for a file that is not accepted as DICOM.

    A file is rejected when it does not start with a full 128-byte preamble 
    followed by the four characters 'DICM' (that is, 'DICM' at byte 128).
    """

class ServerConfigurationError(Exception):
    """Error in the server configuration.

    msg is a human-readable description of the problem.  It is stored in 
    the msg attribute and is what str() on the exception returns.
    """

    def __init__(self, msg):
        self.msg = msg
        return

    def __str__(self):
        # this used to return the bare name "msg", which is undefined here 
        # and raised NameError whenever the exception was printed
        return str(self.msg)

class RequestError(Exception):
    """Raised when an HTTP request returns an unexpected status.

    The request method, the URL, and the status that came back are kept in 
    the method, url, and status attributes and are combined by str().
    """

    def __init__(self, method, url, status):
        (self.method, self.url, self.status) = (method, url, status)

    def __str__(self):
        parts = (self.method, self.url, self.status)
        return '%s %s returned %s' % parts

class XNAT(object):

    """Small client for the XNAT REST calls this script needs.

    Creating an instance logs in to the server at base_url with the given 
    credentials and keeps the session cookie from the response for all 
    later requests.  Project, subject, and session listings are cached on 
    the instance; the create_*() methods drop the cache entries they 
    invalidate.

    The constructor raises ValueError if the login request does not return 
    status 200 or does not set a cookie.
    """

    def __init__(self, base_url, user_name, password):
        self.base_url = base_url
        h = httplib2.Http()
        h.add_credentials(user_name, password)
        url = '%s/data/JSESSION' % self.base_url
        (response, content) = h.request(url)
        if response['status'] != '200':
            raise ValueError('%s returned status %s' % (url, 
                                                        response['status']))
        if 'set-cookie' not in response:
            raise ValueError('%s did not set a session cookie' % url)
        # keep only the "name=value" part of the cookie, not its attributes
        self.jsession_cookie = response['set-cookie'].split(';')[0]
        self.user_name = user_name
        # caches: None or a missing key means "not fetched yet"
        self._projects = None
        self._user_projects = None
        self._subjects = {}
        self._sessions = {}
        return

    def _request(self, url, method='GET', expected_status='200'):
        """Send a request for url (a path below <base_url>/data).

        Returns (response, content) as given by httplib2.  Raises 
        RequestError if the response status is not expected_status.
        """
        full_url = '%s/data%s' % (self.base_url, url)
        h = httplib2.Http()
        headers = {'Cookie': self.jsession_cookie}
        (response, content) = h.request(full_url, method, headers=headers)
        if response['status'] != expected_status:
            raise RequestError(method, full_url, response['status'])
        return (response, content)

    def _request_table(self, url):
        """GET url as CSV and return its rows as a list of dictionaries 
        keyed by the names in the header row.

        An empty response body yields an empty list.
        """
        (response, content) = self._request('%s?format=csv' % url)
        reader = csv.reader(StringIO.StringIO(content))
        try:
            header = next(reader)
        except StopIteration:
            # no header row at all, so there are no rows either
            return []
        return [ dict(zip(header, row)) for row in reader ]

    def close(self):
        """End the server session by sending DELETE for /JSESSION."""
        self._request('/JSESSION', 'DELETE')
        return

    def projects(self):
        """Return the list of project IDs on the server (cached)."""
        if self._projects is None:
            rows = self._request_table('/projects')
            self._projects = [ row['ID'] for row in rows ]
        return self._projects

    def user_projects(self):
        """Return the projects whose user table lists the login name this 
        instance was created with (cached)."""
        if self._user_projects is None:
            self._user_projects = []
            for project in self.projects():
                user_url = '/projects/%s/users' % project
                for user in self._request_table(user_url):
                    if user['login'] == self.user_name:
                        self._user_projects.append(project)
                        break
        return self._user_projects

    def create_project(self, project):
        """Create a project.  Raises ValueError if it already exists."""
        if project in self.projects():
            raise ValueError('project %s exists' % project)
        self._request('/projects/%s' % project, 'PUT', '200')
        # reset the cache
        self._projects = None
        self._user_projects = None
        return

    def subjects(self, project):
        """Return the sorted subject labels of a project (cached).

        Raises ValueError if the project does not exist.
        """
        if project not in self.projects():
            raise ValueError('project %s not found' % project)
        if project not in self._subjects:
            url = '/projects/%s/subjects' % project
            labels = [ row['label'] for row in self._request_table(url) ]
            labels.sort()
            self._subjects[project] = labels
        return self._subjects[project]

    def create_subject(self, project, subject):
        """Create a subject in a project.

        Raises ValueError if the project does not exist or the subject 
        already exists in it.
        """
        if project not in self.projects():
            raise ValueError('project %s not found' % project)
        if subject in self.subjects(project):
            raise ValueError('subject %s exists in project %s' % (subject, 
                                                                  project))
        self._request('/projects/%s/subjects/%s' % (project, subject), 
                      'PUT', 
                      '201')
        # reset the cache -- this will exist due to our self.subjects(project) 
        # call above
        del self._subjects[project]
        return

    def sessions(self, project):
        """Return the sorted session labels of all subjects in a project.

        Labels are fetched per subject and cached per (project, subject).
        """
        if project not in self._sessions:
            self._sessions[project] = {}
        by_subject = self._sessions[project]
        sessions = []
        for subject in self.subjects(project):
            if subject not in by_subject:
                url = '/projects/%s/subjects/%s/experiments' % (project, 
                                                                subject)
                rows = self._request_table(url)
                by_subject[subject] = [ row['label'] for row in rows ]
            sessions.extend(by_subject[subject])
        sessions.sort()
        return sessions

    def create_session(self, project, subject, session):
        """Create an MR session for a subject.

        Raises ValueError if the project or subject does not exist or if 
        the session label is already used in the project.
        """
        if project not in self.projects():
            raise ValueError('project %s not found' % project)
        if subject not in self.subjects(project):
            raise ValueError('no subject %s in project %s' % (subject, project))
        if session in self.sessions(project):
            raise ValueError('session %s exists in project %s' % (session, 
                                                                  project))
        fmt = '/projects/%s/subjects/%s/experiments/%s?xsiType=mrSessionData'
        url = fmt % (project, subject, session)
        self._request(url, 'PUT', '201')
        # reset the cache -- this will exist due to the calls above
        del self._sessions[project][subject]
        return

def read_dicom_file(fname):
    """Read a DICOM file, raising an exception if the 'DICM' marker is not 
    present at byte 128.

    dicom.read_file() does this as of pydicom 0.9.5.

    Returns the object produced by dicom.read_file().  Raises 
    InvalidDicomError if the file is shorter than 128 bytes or the four 
    bytes after the first 128 are not 'DICM'.
    """
    # binary mode: DICOM is binary data, and text mode alters the bytes on 
    # platforms that translate line endings
    fo = open(fname, 'rb')
    try:
        preamble = fo.read(128)
        magic = fo.read(4)
        if len(preamble) != 128 or magic != b'DICM':
            raise InvalidDicomError('%s: no DICM marker at byte 128' % fname)
        # rewind so the parser sees the file from the start
        fo.seek(0)
        do = dicom.read_file(fo)
    finally:
        fo.close()
    return do

def patient_studies(dicom_files, studies, patient_id):
    """given our dicom_files and studies records and a patient ID, return a 
    list of (datetime, study instance UID) ordered by date+time

    dicom_files[patient_id] is iterated for its study instance UIDs and 
    studies[uid] must be a (study date, study time) pair; the two parts 
    are concatenated to form the datetime string that is sorted on.
    """
    ps = [ ('%s%s' % studies[uid], uid) for uid in dicom_files[patient_id] ]
    # order on the date+time string only; ties keep their original order
    ps.sort(key=lambda study: study[0])
    return ps

def new_uid(old_uid):
    """Create a new DICOM UID.

    Return an already-created UID if the old UID has already been replaced.

    Replacements are remembered in the module-level uids dictionary, so 
    every occurrence of old_uid maps to the same new UID.  New UIDs are 
    '2.25.' followed by the decimal value of a random UUID.
    """
    if old_uid not in uids:
        # uuid4 (random) rather than uuid1: uuid1 encodes this computer's 
        # network address and the current time, which should not end up 
        # in de-identified data
        uids[old_uid] = '2.25.%d' % uuid.uuid4().int
    return uids[old_uid]

def identifier_is_valid(identifier):
    """Check if a project/subject/session identifier is valid.

    Identifiers can only contain alphanumeric characters and underscores.

    Returns True if every character is an ASCII letter, a digit, or an 
    underscore (so the empty string is reported as valid; callers reject 
    empty input themselves) and False otherwise.  Raises ValueError if 
    identifier is not a string.
    """
    try:
        string_types = basestring
    except NameError:
        # no basestring on this interpreter; str is the only string type
        string_types = str
    if not isinstance(identifier, string_types):
        raise ValueError('identifier must be a string')
    # ascii_letters, not letters: string.letters depends on the locale
    allowed = string.ascii_letters + string.digits + '_'
    for c in identifier:
        if c not in allowed:
            return False
    return True

def parse_server_configuration(data):
    """Parse the text of a version 1 server configuration.

    The first line must be 'conf 1'.  Every other non-blank line is a key 
    and a value separated by the first space.  xnat_base and push_target 
    are stored as single strings; every other key (including 
    client_protocol) collects its values in a list.  push_target has the 
    form host:port and is also split into push_host and push_port.

    Returns the configuration dictionary.  Raises ServerConfigurationError 
    if the header, a line, or push_target is malformed, and KeyError if 
    client_protocol, xnat_base, or push_target is missing.
    """
    # splitlines() so that \r\n line endings don't leave \r on the values
    lines = data.splitlines()
    # header -- conf version 1 is text key/value (the only thing we handle here)
    if not lines or lines[0] != 'conf 1':
        raise ServerConfigurationError('can\'t parse server configuration')
    conf = {}
    for line in lines[1:]:
        if not line.strip():
            continue
        fields = line.split(' ', 1)
        if len(fields) != 2:
            msg = 'malformed configuration line: %r' % line
            raise ServerConfigurationError(msg)
        (key, value) = fields
        if key in ('xnat_base', 'push_target'):
            conf[key] = value
        else:
            conf.setdefault(key, []).append(value)
    for key in ('client_protocol', 'xnat_base', 'push_target'):
        if key not in conf:
            raise KeyError(key)
    target = conf['push_target'].split(':')
    if len(target) != 2:
        msg = 'push_target is not host:port: %r' % conf['push_target']
        raise ServerConfigurationError(msg)
    (conf['push_host'], conf['push_port']) = target
    return conf

#############################################################################
# defaults and definitions
#

# name this program was started as, for messages
progname = os.path.basename(sys.argv[0])

# Check that storescu can be started at all by running it without 
# arguments.  Its output is collected and thrown away; communicate() is 
# used so the probe cannot block on a full pipe.  Only a failure to start 
# the program (OSError) counts as "not found".
try:
    probe = subprocess.Popen(['storescu'], 
                             stdout=subprocess.PIPE, 
                             stderr=subprocess.PIPE)
    probe.communicate()
except OSError:
    have_storescu = False
else:
    have_storescu = True

# shown as the epilog of the command line help when storescu could not be 
# run (see the OptionParser construction below)
storescu_warning = """WARNING: storescu (the program used to send data to
the server) was not found on your system.  Please install DCMTK (or adjust
your path) before proceeding."""

# usage, version, and description are handed to optparse.OptionParser, 
# which substitutes the program name for %prog
usage = 'usage: %prog <DICOM file or directory> [...]' 

version = '%prog 2.0'

description = """Upload DICOM data to the INCF data sharing server."""

# where the server configuration (parsed by parse_server_configuration()) 
# is fetched from
# NOTE(review): this is fetched over plain http and supplies the address 
# that the user's credentials are later sent to -- confirm whether an 
# https URL is available
conf_url = 'http://xnat.incf.org/conf.txt'

# Elements that are deleted from every file before it is sent.
#
# The (group, element) pairs are wrapped in dicom.tag.Tag because 
# del dicom_object[(xxxx, yyyy)] doesn't work for me, but 
# del dicom_object[tag] does, where tag = dicom.tag.Tag((xxxx, yyyy))
deident_tags = [ dicom.tag.Tag(group_element) for group_element in (
    (0x0008, 0x0050), # Accession Number
    (0x0008, 0x0080), # InstitutionName
    (0x0008, 0x0090), # Referring Physician's Name
    (0x0008, 0x0096), # Referring Physician Identification
    (0x0008, 0x1048), # Physician(s) of Record
    (0x0008, 0x1049), # Physician(s) of Record Identification
    (0x0008, 0x1050), # Performing Physicians' Name
    (0x0008, 0x1052), # Performing Physician Identification
    (0x0008, 0x1060), # Name of Physician(s) Reading Study
    (0x0008, 0x1062), # Physician(s) Reading Study Identification
    (0x0010, 0x0030), # Patient's Birth Date
    (0x0010, 0x0050), # Patient's Insurance Plan Code SQ
    (0x0010, 0x0101), # Patient's Primary Language Code SQ
    (0x0010, 0x1000), # Other Patient IDs
    (0x0010, 0x1001), # Other Patient Names
    (0x0010, 0x1002), # Other Patient IDs SQ
    (0x0010, 0x1005), # Patient's Birth Name
    (0x0010, 0x1010), # Patient's Age
    (0x0010, 0x1040), # Patient's Address
    (0x0010, 0x1060), # Patient's Mother's Birth Name
) ]

# our cache of new UIDs: uids[old UID] -> new UID
# (filled in by new_uid() so that an old UID is always replaced by the 
# same new UID)
uids = {}

# maximum number of files passed to a single storescu call
files_per_push = 500

# comments on performed procedure
# if present, this can mess up the project change in the prearchive
# (the element is deleted from each file before it is sent)
copp_tag = dicom.tag.Tag((0x0040, 0x0280))

#############################################################################
# command line processing
#

# the help text only gets the storescu warning as an epilog when storescu 
# could not be run
parser_kwargs = {'usage': usage, 
                 'description': description, 
                 'version': version}
if not have_storescu:
    parser_kwargs['epilog'] = storescu_warning
parser = optparse.OptionParser(**parser_kwargs)

(opts, args) = parser.parse_args()

# at least one file or directory to upload is required
if len(args) == 0:
    for line in ('%s: no sources given\n' % progname, 
                 'run %s -h for help\n' % progname):
        sys.stderr.write(line)
    sys.exit(1)

#############################################################################
# begin execution
#

#----------------------------------------------------------------------------
# read the server configuration and check that the server supports the way 
# we're doing things here
#

# Any failure to fetch or parse the configuration is reported the same way, 
# but only real errors are caught: Ctrl-C and sys.exit() pass through.
try:
    h = httplib2.Http()
    (response, content) = h.request(conf_url)
    if response['status'] != '200':
        msg = 'configuration response status was %s' % response['status']
        raise ValueError(msg)
    server_configuration = parse_server_configuration(content)
except Exception:
    sys.stderr.write('%s: error in server configuration\n' % progname)
    sys.stderr.write('please contact xnat-admin@incf.org for assistance\n')
    sys.exit(1)

# this program speaks client protocol 2; the server lists the protocols it 
# accepts
if '2' not in server_configuration['client_protocol']:
    sys.stderr.write('%s: server configuration error\n' % progname)
    sys.stderr.write('\n')
    sys.stderr.write('The INCF server does not support the protocol used in this program.  Please \n')
    sys.stderr.write('update this program (updates are available at xnat.incf.org).\n')
    sys.stderr.write('\n')
    sys.exit(1)

#----------------------------------------------------------------------------
# upload agreement
#

print
print 'You are about to upload data to the INCF neuroimaging data store.  The '
print 'INCF encourages high ethical standards in data sharing, and requires '
print 'you as the data provider to assert the following:'
print
print '    I have the right to upload and share this data and I have gained '
print '    local ethics approval for the sharing of this data.  I consent '
print '    to this data being freely available for download.  INCF is '
print '    not responsible for checking that the data meets these criteria.'
print
response = raw_input('Do you agree to these terms (enter "Yes" to agree)? ')
print
if response.strip().lower() != 'yes':
    print 'exiting'
    sys.exit(2)

#----------------------------------------------------------------------------
# scan for DICOM files
#

print 'Scanning for files...'

all_files = []
for arg in args:
    if os.path.isfile(arg):
        all_files.append(arg)
    elif os.path.isdir(arg):
        for (dirpath, dirnames, fnames) in os.walk(arg):
            for f in fnames:
                if f == 'DICOMDIR':
                    continue
                all_files.append('%s/%s' % (dirpath, f))
    else:
        sys.stderr.write('%s: %s is not a file or directory\n' % (progname, 
                                                                  arg))

print 'Reading all files...'

# dicom_files[PatientID][StudyInstanceUID] -> list of file names
dicom_files = {}

# studies[StudyInstanceUID] -> (study date, study time)
studies = {}

for (i, fname) in enumerate(all_files):
    print '%d/%d: %s' % (i+1, len(all_files), fname)
    try:
        do = read_dicom_file(fname)
        patient_files = dicom_files.setdefault(do.PatientID, {})
        patient_files.setdefault(do.StudyInstanceUID, []).append(fname)
        if do.StudyInstanceUID not in studies:
            studies[do.StudyInstanceUID] = (do.StudyDate, do.StudyTime)
    except:
        continue

#----------------------------------------------------------------------------
# prompt for information
#

# xnat_projects[PatientID] -> xnat project
xnat_projects = {}

# xnat_subjects[PatientID] -> xnat subject name
xnat_subjects = {}

# xnat_sessions[StudyInstanceUID] -> xnat session ID
xnat_sessions = {}

print

if not dicom_files:
    print 'no DICOM files found'
    sys.exit(0)

n_sessions = sum([ len(df) for df in dicom_files.itervalues() ])
if n_sessions == 1:
    print '1 subject and 1 session found'
elif len(dicom_files) == 1:
    print '1 subject and %d sessions found' % n_sessions
else:
    print '%d subjects and %d sessions found' % (len(dicom_files), n_sessions)

print
print 'Let\'s get some information before sending this data.'
print

# keep asking for credentials until a connection object can be created; 
# empty answers are asked for again
while True:
    user_name = ''
    while not user_name:
        user_name = raw_input('INCF user name: ')
    password = ''
    while not password:
        password = getpass.getpass('INCF password: ')
    try:
        xnat = XNAT(server_configuration['xnat_base'], user_name, password)
    except ValueError:
        print 'error connecting to INCF'
    else:
        break

print
print 'Each session will need a project, subject, and session label on the '
print 'server.  Session identifiers must be unique within each project.'
print

subject_number = 0
for patient_id in sorted(dicom_files):
    subject_number += 1
    print 'Subject %s: %s' % (subject_number, patient_id)
    print
    print 'Projects that you can upload to are marked with an asterisk.  Please '
    print 'select one for subject %s or enter a different name to create a new '
    print 'project.  Project names can only contain alphanumeric characters '
    print 'or underscores.'
    print
    print 'Projects on the INCF server are:'
    print
    for p in xnat.projects():
        if p in xnat.user_projects():
            print '    * %s' % p
        else:
            print '      %s' % p
    print
    project = None
    while not project:
        project = raw_input('Project: ')
        if not project:
            continue
        if project in xnat.user_projects():
            break
        if project in xnat.projects():
            print '%s exists but you don\'t have permission to upload to it.' % project
            project = None
            continue
        if not identifier_is_valid(project):
            print 'Project names can only contain alphanumeric characters and underscores.'
            project = None
            continue
        xnat.create_project(project)
    xnat_projects[patient_id] = project
    print
    if not xnat.subjects(project):
        print 'Now choose an identifier for this subject on the INCF server.  There are '
        print 'no subjects in this project, so this will create a new one.  Subject '
        print 'identifiers can only contain alphanumeric characters and underscores.'
    else:
        print 'The existing subjects for this project are:'
        print
        for subject in xnat.subjects(project):
            print '    %s' % subject
        print
        print 'Now choose an identifier for this subject on the INCF server.  You can '
        print 'choose an existing subject or create a new subject by entering a new '
        print 'subject identifier.  Subject identifiers can only contain alphanumeric '
        print 'characters and underscores.'
    print
    subject = None
    while not subject:
        subject = raw_input('Subject: ')
        if not subject:
            continue
        if not identifier_is_valid(subject):
            print 'Subject identifiers can only contain alphanumeric characters and underscores.'
            subject = None
            continue
    if subject not in xnat.subjects(project):
        xnat.create_subject(project, subject)
    xnat_subjects[patient_id] = subject
    print
    session_number = 0
    ps = patient_studies(dicom_files, studies, patient_id)
    for(date_time, study_instance_uid) in ps:
        session_number += 1
        date_time_fmt = '%s-%s-%s %s:%s:%s' % (date_time[0:4], 
                                               date_time[4:6], 
                                               date_time[6:8], 
                                               date_time[8:10], 
                                               date_time[10:12], 
                                               date_time[12:14])
        print 'Session %d for this subject was acquired %s.' % (session_number, 
                                                                date_time_fmt)
        print
        if not xnat.sessions(project):
            print 'Choose a session identifier.  There are currently no sessions '
            print 'for this project.  Session identifiers can only contain alphanumeric '
            print 'characters and underscores.'
        else:
            print 'Choose a session identifier.  It must be unique for this project, '
            print 'so cannot take any of the following (existing) values:'
            print
            for s in xnat.sessions(project):
                print '    %s' % s
        print
        session = None
        while not session:
            session = raw_input('Session: ')
            if not session:
                continue
            if session in xnat.sessions(project):
                print 'Session %s exists.' % session
                session = None
                continue
            if not identifier_is_valid(session):
                print 'Session identifiers can only contain alphanumeric characters and underscores.'
                session = None
                continue
        xnat.create_session(project, subject, session)
        xnat_sessions[study_instance_uid] = session
        print

print 'That\'s all for now.'
print
print 'Uploading...'
print

#############################################################################
# prepare and push DICOM files
#

session_number = 0
for patient_id in sorted(dicom_files):
    ps = patient_studies(dicom_files, studies, patient_id)
    for(date_time, study_instance_uid) in ps:
        session_number += 1
        fmt = 'Uploading session %s of %s (%s %s %s)...'
        print fmt % (session_number, 
                     n_sessions, 
                     xnat_projects[patient_id], 
                     xnat_subjects[patient_id], 
                     xnat_sessions[study_instance_uid])
        print
        tempdir = tempfile.mkdtemp()
        staged_files = []
        try:
            fnames = dicom_files[patient_id][study_instance_uid]
            for (i, fname) in enumerate(fnames):
                print '    preparing file %d of %d' % (i+1, len(fnames))
                do = read_dicom_file(fname)
                study_comments = 'incf 2\n'
                study_comments += 'upload_agreement signed\n'
                study_comments += 'user %s\n' % user_name
                study_comments += 'project %s' % xnat_projects[patient_id]
                do.StudyComments = study_comments
                do.PatientsName = xnat_subjects[patient_id]
                do.PatientID = xnat_sessions[study_instance_uid]
                for name in ('SOPInstanceUID', 
                             'StudyInstanceUID', 
                             'SeriesInstanceUID', 
                             'FrameofReferenceUID'):
                    try:
                        el = do.data_element(name)
                    except KeyError:
                        pass
                    else:
                        if el is not None:
                            do.data_element(name).value = new_uid(el.value)
                for tag in deident_tags:
                    if tag in do:
                        del do[tag]
                if copp_tag in do:
                    del do[copp_tag]
                out_fname = '%s/%08d' % (tempdir, i)
                dicom.write_file(out_fname, do)
                staged_files.append(out_fname)
            print
            # calculate the number of pushes (split into batches of some 
            # manageable number of files each)
            n_pushes = len(staged_files) / files_per_push
            if len(staged_files) % files_per_push:
                n_pushes += 1
            i = 0
            while staged_files:
                i += 1
                print '    push %d of %d...' % (i, n_pushes)
                args = ['storescu', 
                        '-aec', 
                        'XNAT', 
                        server_configuration['push_host'], 
                        server_configuration['push_port']]
                args.extend(staged_files[:files_per_push])
                rv = subprocess.call(args)
                if rv != 0:
                    sys.stderr.write('\n')
                    sys.stderr.write('%s: error in push; bailing out\n' % progname)
                    sys.stderr.write('It\'s likely that your data upload is partial and incomplete.  \n')
                    sys.stderr.write('Please make a note of your subject and session names and contact \n')
                    sys.stderr.write('xnat-admin@incf.org for assistance.\n')
                    sys.stderr.write('\n')
                    sys.exit(1)
                staged_files = staged_files[files_per_push:]
            print
        finally:
            shutil.rmtree(tempdir)

#############################################################################
# clean up and exit
#

# everything has been sent
print 'done!'
print

# end the server session (sends DELETE for /JSESSION)
xnat.close()

sys.exit(0)

# eof
