#!/usr/bin/env python

import sys, os.path

# Converts the parameters describing a Bruker '2dseq' image into an MRtrix
# '.mih' text header that refers to the image data in place.
#
# The first argument is named '2dseq' in the usage message, but the file name
# itself is not enforced.


def _header_value(chunk):
  """Return the first token following '=' in a 'name=value' entry."""
  return chunk.split('=')[1].split()[0]


def _second_line_values(chunk):
  """Return the whitespace-separated tokens on the entry's second line."""
  return chunk.splitlines()[1].split()


def _body_values(chunk):
  """Return every whitespace-separated token after the entry's first line."""
  return chunk.split('\n', 1)[1].split()


# Lower-cased Bruker parameter name -> (key used in the parsed dictionary,
# function extracting the value from the parameter's lower-cased text).
_EXTRACTORS = {
  'reco_size': ('mat_size', _second_line_values),
  'nslices': ('nslices', _header_value),
  'acq_time_points': ('nacq', lambda chunk: len(_body_values(chunk))),
  'reco_wordtype': ('wtype', _header_value),
  'reco_byte_order': ('byteorder', _header_value),
  'pvm_spatresol': ('res', _second_line_values),
  'pvm_spackarrslicedistance': ('slicethick', lambda chunk: _second_line_values(chunk)[0]),
  'pvm_dweffbval': ('bval', _body_values),
  'pvm_dwgradvec': ('bvec', _body_values),
}

# Parsed keys without which no header can be written, with the Bruker
# parameter each one comes from (used in error messages only).
_REQUIRED = (
  ('mat_size', 'RECO_size'),
  ('res', 'PVM_SpatResol'),
  ('wtype', 'RECO_wordtype'),
  ('byteorder', 'RECO_byte_order'),
)

# Lower-cased RECO_wordtype value -> MRtrix datatype name (without endianness).
_WORD_TYPES = {
  '_16bit_sgn_int': 'int16',
  '_32bit_sgn_int': 'int32',
}


def read_parameter_chunks(data_path):
  """Read the parameter files that accompany the image at *data_path*.

  'reco' is read from the directory containing *data_path*; 'acqp' and
  'method' are read from two directory levels above it. The text of each file
  is split on the '##$' marker that introduces a parameter.

  Returns the list of resulting text chunks, in file order.
  Raises IOError / OSError if any of the three files cannot be opened.
  """
  base = os.path.dirname(data_path)
  chunks = []
  for relative in ('reco', '../../acqp', '../../method'):
    with open(os.path.join(base, relative)) as f:
      chunks += f.read().split('##$')
  return chunks


def parse_parameters(chunks):
  """Extract the parameters needed for the header from *chunks*.

  Each chunk is lower-cased (names and values alike) before being examined.
  Every recognised value is printed to standard output as it is parsed. When
  a parameter occurs more than once, the last occurrence wins.

  Returns a dictionary that may contain the keys 'mat_size', 'nslices',
  'nacq', 'wtype', 'byteorder', 'res', 'slicethick', 'bval' and 'bvec'. All
  values are strings or lists of strings, except 'nacq' which is an int.
  Raises ValueError if a recognised parameter has no value where expected.
  """
  params = {}
  for chunk in chunks:
    chunk = chunk.lower()
    name, separator, _ = chunk.partition('=')
    # Requiring the '=' keeps this equivalent to matching the 'name=' prefix.
    if not separator or name not in _EXTRACTORS:
      continue
    key, extract = _EXTRACTORS[name]
    try:
      value = extract(chunk)
    except IndexError:
      raise ValueError('malformed "' + name + '" entry in Bruker parameter files')
    params[key] = value
    print(key, value)
  return params


def _dw_scheme_lines(params):
  """Return the 'dw_scheme: x,y,z,b' header lines, one per b-value.

  An empty list is returned when no diffusion information was parsed. When
  the information is present but unusable (only one of the two parameters,
  a gradient table that is not three components per b-value, or a
  non-numeric component) a warning is written to standard error and an empty
  list is returned, so that the rest of the header is still produced.
  """
  bval = params.get('bval')
  bvec = params.get('bvec')
  if bval is None and bvec is None:
    return []
  if bval is None or bvec is None or len(bvec) != 3 * len(bval):
    sys.stderr.write("warning: inconsistent diffusion gradient information; dw_scheme not written\n")
    return []
  lines = []
  try:
    for index, value in enumerate(bval):
      # Step through the flat gradient list three components at a time.
      x, y, z = bvec[3 * index:3 * index + 3]
      # The third component is written with its sign inverted.
      # NOTE(review): presumably a coordinate convention difference - confirm.
      lines.append('dw_scheme: ' + x + ',' + y + ',' + str(-float(z)) + ',' + value)
  except ValueError:
    sys.stderr.write("warning: non-numeric diffusion gradient component; dw_scheme not written\n")
    return []
  return lines


def build_header(params, data_path):
  """Return the text of the '.mih' header described by *params*.

  params:    dictionary as returned by parse_parameters().
  data_path: path written to the header's 'file:' entry.

  The third axis size is taken from the third 'mat_size' entry if there is
  one, otherwise from 'nslices' when available; likewise the third voxel size
  comes from 'res' or else 'slicethick'. When 'nacq' is available it becomes
  a further axis whose voxel size is left empty. The 'layout:' entry lists
  one '+n' term per axis written to 'dim:'.

  Raises ValueError if a required parameter is absent or has fewer than two
  entries where two are needed.
  """
  missing = [bruker_name for key, bruker_name in _REQUIRED if key not in params]
  if missing:
    raise ValueError('required parameter(s) not found: ' + ', '.join(missing))

  mat_size = params['mat_size']
  res = params['res']
  if len(mat_size) < 2 or len(res) < 2:
    raise ValueError('RECO_size and PVM_SpatResol must each hold at least two values')
  has_nacq = 'nacq' in params

  dims = list(mat_size[:2])
  if len(mat_size) > 2:
    dims.append(mat_size[2])
  elif 'nslices' in params:
    dims.append(params['nslices'])
  if has_nacq:
    dims.append(str(params['nacq']))

  vox = list(res[:2])
  if len(res) > 2:
    vox.append(res[2])
  elif 'slicethick' in params:
    vox.append(params['slicethick'])
  if has_nacq:
    # Empty final entry: yields the trailing comma for the acquisition axis.
    vox.append('')

  datatype = _WORD_TYPES.get(params['wtype'], '')
  if not datatype:
    # The header is still written, as before, but is flagged as incomplete.
    sys.stderr.write("warning: unsupported RECO_wordtype '" + params['wtype'] + "'; datatype entry is incomplete\n")
  # Any byte order other than 'littleendian' is treated as big-endian.
  datatype += 'le' if params['byteorder'] == 'littleendian' else 'be'

  # Derived from the axes actually written so that it always matches 'dim:'.
  layout = ','.join('+' + str(axis) for axis in range(len(dims)))

  lines = [
    'mrtrix image',
    'dim: ' + ','.join(dims),
    'vox: ' + ','.join(vox),
    'datatype: ' + datatype,
    'layout: ' + layout,
    'file: ' + data_path,
  ]
  lines.extend(_dw_scheme_lines(params))
  return '\n'.join(lines) + '\n'


def main(argv):
  """Run the conversion for the command line *argv*; return the exit status.

  argv[1] is the path of the image data and argv[2] the '.mih' header to
  create. The header is only opened for writing once its full text has been
  built, so a failed conversion does not leave a truncated file behind.
  """
  if len(argv) != 3:
    sys.stderr.write("usage: convert_bruker 2dseq header.mih\n")
    # Exit status 0 is kept here so existing callers see no change.
    return 0

  if not argv[2].endswith('.mih'):
    sys.stderr.write("expected .mih suffix as the second argument\n")
    return 1

  try:
    params = parse_parameters(read_parameter_chunks(argv[1]))
    header = build_header(params, argv[1])
  except ValueError as error:
    sys.stderr.write("error: " + str(error) + "\n")
    return 1

  with open(argv[2], 'w') as f:
    f.write(header)
  return 0


if __name__ == '__main__':
  sys.exit(main(sys.argv))
