#!/usr/bin/python

import os
import subprocess
import sys
from commands import getoutput
from sys import argv

from numpy import *

def usage():
    """Print the command-line help text to stdout and exit with status 1.

    The first line is built from ``argv[0]`` so it shows the name the script
    was invoked with.  This function never returns: it always ends with
    ``sys.exit(1)``.
    """
    # Single-argument print(...) calls produce the same output under the
    # Python 2 print statement and the Python 3 print function.
    help_lines = (
        "Usage: " + argv[0] + " <output mat> <input2standard mat> <input image> <x1> <y1> <z1> <x2> <y2> <z2>",
        " ",
        "       First argument is the output matrix which will go from the input image to standard space,",
        "          with the desired linear structure aligned with the y-axis in standard space",
        "       Second argument is the FLIRT transform from the input image to standard",
        "       Third argument is the input image (e.g., the .nii.gz file)",
        "       The remaining arguments are two coordinates in the input image, chosen along the",
        "          linear structure that is to be aligned with the y-axis",
        "       All coordinates are in voxel coordinates (in the input image)",
        "       If the input image is already in standard space then use $FSLDIR/etc/flirtsch/ident.mat",
        "          as the second argument but still use voxel coordinates (not MNI/mm coords)",
        "       You can use the output matrix from this script in a simple resampling call to flirt:",
        "         e.g.  flirt -in input_image -ref standard_image -applyxfm -init outputfromhere.mat -out rotated_image",
        "         You can also use a higher resolution standard image as the -ref in the line above if you want better resolution",
    )
    for help_line in help_lines:
        print(help_line)
    sys.exit(1)

# Nine arguments are required after the script name (argv[1] .. argv[9]).
if len(argv) <= 9:
    usage()

# ---- Load in the necessary info ----
# Input-to-standard affine read from the text file given as argv[2].
# asmatrix() makes every `*` below an explicit matrix product rather than
# relying on operand-priority dispatch between ndarray and matrix.
a = asmatrix(loadtxt(argv[2]))
if a.shape != (4, 4):
    sys.stderr.write("Expected a 4x4 matrix in " + argv[2]
                     + " but got shape " + str(a.shape) + "\n")
    sys.exit(1)

# Query the image geometry with fslsize.  The command is passed as an
# argument list (no shell), so image paths containing spaces or shell
# metacharacters are handled safely.
fsldir = os.environ.get("FSLDIR")
if not fsldir:
    sys.stderr.write("FSLDIR is not set; cannot locate fslsize\n")
    sys.exit(1)
try:
    proc = subprocess.Popen(
        [os.path.join(fsldir, "bin", "fslsize"), argv[3], "-s"],
        stdout=subprocess.PIPE, universal_newlines=True)
    alldims = proc.communicate()[0]
except OSError as err:
    sys.stderr.write("Could not run fslsize: " + str(err) + "\n")
    sys.exit(1)
listdims = alldims.split()

# Whitespace-separated tokens 2/4/6 are used as the image size in voxels and
# tokens 12/14/16 as the voxel dimensions that scale voxel coordinates.
# NOTE(review): these positions depend on the `fslsize -s` output layout.
try:
    sx, sy, sz = [float(listdims[idx]) for idx in (2, 4, 6)]
    dx, dy, dz = [float(listdims[idx]) for idx in (12, 14, 16)]
except (IndexError, ValueError):
    sys.stderr.write("Could not parse fslsize output for " + argv[3]
                     + ": " + repr(alldims) + "\n")
    sys.exit(1)

# The two user-supplied voxel coordinates, scaled by the voxel dimensions and
# written as homogeneous column vectors.
x1 = matrix([[dx * float(argv[4])], [dy * float(argv[5])], [dz * float(argv[6])], [1]])
x2 = matrix([[dx * float(argv[7])], [dy * float(argv[8])], [dz * float(argv[9])], [1]])

# ---- Calculate the desired rotation ----
# Direction of the line after applying the input affine (the homogeneous
# component cancels in the difference, so the translation has no effect).
v = a * (x2 - x1)
# get rid of x-component as we are not interested in this
vn = matrix([[v[1, 0]], [v[2, 0]]])
norm = sqrt((vn.T * vn)[0, 0])
# Without a y/z component the rotation angle is undefined; previously this
# divided by zero and silently wrote a matrix full of NaNs.
if not norm > 0:
    sys.stderr.write("The two points give a direction with no y/z component "
                     "after the input transform (identical points or a line "
                     "parallel to the x-axis); cannot compute a rotation\n")
    sys.exit(1)
vn = vn / norm
# deal with angles greater than 90 degrees (only aligning undirected lines)
if vn[0, 0] < 0:
    vn = vn * -1.0
# With a non-negative y-component, arcsin of the z-component gives the angle
# of the line from the +y axis within the y-z plane.
theta = arcsin(vn[1, 0])

# Rotation about the x-axis that maps (cos(theta), sin(theta)) in the y-z
# plane onto (1, 0), i.e. onto the y-axis.
r = matrix([[1, 0, 0, 0],
            [0, cos(theta), sin(theta), 0],
            [0, -sin(theta), cos(theta), 0],
            [0, 0, 0, 1]])
newa = r * a

# ---- Fix the translation (keep COV in the same place) ----
# The input image space COV is ...
incov = matrix([[dx * sx / 2.0], [dy * sy / 2.0], [dz * sz / 2.0], [1]])
# The standard image space *COV* (not origin) is ...
# (hard-coded; matrix() is used instead of the mat() alias, which NumPy 2.0
# removed)
stdcov = matrix([[91], [109], [91], [1]])
# Shift the translation column so that newa maps incov exactly onto stdcov.
trans = stdcov - newa * incov
newa[0:3, 3] += trans[0:3]

# ---- Save out the result ----
savetxt(argv[1], newa, fmt='%14.10f')

