#!/usr/bin/env python
# This program is part of the UCLA Multimodal Connectivity Package (UMCP)
#
# UMCP is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright 2013 Jesse Brown
import os
import sys
import struct
import numpy as np
import nibabel as nib
import core
def get_floats(track_file):
    """Read in all tracks from a .trk (TrackVis) file and store them in a list.

    Parameters
    ----------
    track_file : str
        Path to the .trk file. Its header is parsed with get_header() to find
        the number of per-point scalars and per-track properties.

    Returns
    -------
    list of list of tuple
        One entry per track; each track is a list of (x, y, z) float tuples.
        Per-point scalars and per-track properties are skipped, not returned.
        Tracks with zero points are omitted. If the last track is truncated
        by end-of-file, reading stops there.

    Raises
    ------
    ValueError
        If a track record declares a negative number of points.
    """
    header_dict = get_header(track_file)
    n_s = header_dict["n_scalars"]
    n_p = header_dict["n_properties"]
    # each point is stored as x, y, z followed by n_s scalar floats
    floats_per_point = 3 + n_s
    with open(track_file, 'rb') as f:  # 'rb' so Windows does not alter bytes
        contents = f.read()
    size = len(contents)
    track_list = []
    current = 1000  # track records start right after the 1000-byte header
    while current + 4 < size:
        n_points = struct.unpack_from('i', contents, current)[0]
        if n_points < 0:
            raise ValueError("corrupt track record at byte %d: %d points"
                             % (current, n_points))
        current += 4
        n_floats = n_points * floats_per_point
        end = current + 4 * n_floats
        if end > size:
            break  # truncated final track
        values = struct.unpack_from('%df' % n_floats, contents, current)
        # keep x, y, z of each point and drop its trailing scalars
        points = [tuple(values[i:i + 3])
                  for i in range(0, n_floats, floats_per_point)]
        if points:
            track_list.append(points)
        # per-track properties (n_p floats) follow the points; skip them
        current = end + 4 * n_p
    return track_list
def get_header(track_file):
    """Read in header values from a .trk (TrackVis) file and store them in a dictionary.

    Only the first 1000 bytes (the header) are read, so large track files
    are not loaded into memory.

    Parameters
    ----------
    track_file : str
        Path to the .trk file.

    Returns
    -------
    dict
        Keys, in insertion order: "dims" (3 ints), "vox_size" (3 floats),
        "origin" (3 floats), "n_scalars" (int), "n_properties" (int),
        "vox_order" (3 one-byte strings), "paddings" (3 one-byte strings),
        "img_orient_patient" (6 floats), "inverts" (3 ints), "swaps" (3 ints),
        "num_fibers" (int).

    Raises
    ------
    struct.error
        If the file is too short to contain the requested fields.
    """
    with open(track_file, 'rb') as f:  # 'rb' for Windows reading
        contents = f.read(1000)
    header_dict = {}
    header_dict["dims"] = struct.unpack_from('3h', contents, 6)
    header_dict["vox_size"] = struct.unpack_from('3f', contents, 12)
    header_dict["origin"] = struct.unpack_from('3f', contents, 24)
    header_dict["n_scalars"] = struct.unpack_from('h', contents, 36)[0]
    header_dict["n_properties"] = struct.unpack_from('h', contents, 238)[0]
    header_dict["vox_order"] = struct.unpack_from('3c', contents, 948)
    header_dict["paddings"] = struct.unpack_from('3c', contents, 952)
    header_dict["img_orient_patient"] = struct.unpack_from('6f', contents, 956)
    header_dict["inverts"] = struct.unpack_from('3B', contents, 982)
    header_dict["swaps"] = struct.unpack_from('3B', contents, 985)
    header_dict["num_fibers"] = struct.unpack_from('i', contents, 988)[0]
    return header_dict
def mm_to_vox_convert(tracks,header,dsi_studio=False,y_extent=240):
    """Convert track coordinates from mm dimensions to voxel dimensions.

    Each coordinate is floor-divided by the matching entry of
    header["vox_size"] and truncated to int.

    Parameters
    ----------
    tracks : iterable of iterables of (x, y, z)
        Track points in mm.
    header : dict
        Must contain "vox_size" (x, y, z voxel sizes).
    dsi_studio : bool
        If True, flip the y axis as (y_extent - y) before dividing. This is a
        hack for 96x96x48 LPS-oriented .trk files created by DSI Studio.
    y_extent : float
        Extent in mm used for the y flip. The default of 240 keeps the
        original behaviour (presumably 96 voxels * 2.5 mm -- confirm for
        other data).

    Returns
    -------
    list of list of tuple
        Integer (x, y, z) voxel indices, one list per track.
    """
    xsize, ysize, zsize = np.array(header["vox_size"])
    if dsi_studio:
        def to_vox(x, y, z):
            return (int(x // xsize), int((y_extent - y) // ysize), int(z // zsize))
    else:
        def to_vox(x, y, z):
            return (int(x // xsize), int(y // ysize), int(z // zsize))
    return [[to_vox(x, y, z) for x, y, z in track] for track in tracks]
def add_missing_vox(tracks):
    """Add voxels between track points that are separated by more than 1 voxel in
    x, y, or z directions.

    The gap check looks at each pair of consecutive points a, b. If any axis
    difference has magnitude >= 2, a box of voxels is built from, per axis:

    - range(a_i, b_i) stepping toward b, for axes with |b_i - a_i| >= 2;
    - just [a_i] on the other axes.

    Box voxels not already in the track are inserted after a. Note this fills
    a box, not a straight line.

    Parameters
    ----------
    tracks : iterable of sequences of (x, y, z) integer tuples

    Returns
    -------
    list of list
        One filled track per input track. Empty tracks come back as empty
        lists; single-point tracks are returned unchanged.
    """
    tracks_filled = []
    for track in tracks:
        if len(track) == 0:
            tracks_filled.append([])  # nothing to fill; avoids track[-1] IndexError
            continue
        track_vox_set = set(track)
        new_track = []
        for a, b in zip(track[:-1], track[1:]):
            new_track.append(a)
            dif = [bj - aj for aj, bj in zip(a, b)]
            if max(abs(d) for d in dif) < 2:
                continue  # neighbouring voxels: no gap to fill
            axis_ranges = []
            for start, d in zip(a, dif):
                step = -1 if d < 0 else 1
                # large jumps span up to (but excluding) b; small ones keep a only
                stop = start + d if abs(d) >= 2 else start + step
                axis_ranges.append(range(start, stop, step))
            missing_vox_set = {(x, y, z)
                               for x in axis_ranges[0]
                               for y in axis_ranges[1]
                               for z in axis_ranges[2]}
            new_track.extend(missing_vox_set - track_vox_set)
        new_track.append(track[-1])
        tracks_filled.append(new_track)
    return tracks_filled
# NOTE(review): this stretch of the file is damaged and is NOT valid Python as
# it stands. Two lines below read "if all(np.array([x,y,z]) length_thresh:",
# i.e. text between a comparison and a later "length_thresh" test is missing
# (it looks as if everything between a "<" and a ">" was stripped).
# As a result mask_tracks is cut off part-way through. The code after the second
# garbled line uses names never defined in this block: cur_start, cur_end,
# mask_matrix_file, mask_matrix_array, connect_mat, outfile, write_tracks,
# write_tracks_filename, track_file. It reads like the tail of a separate
# connectivity-matrix function whose "def" line was lost -- presumably the
# "mask_connectivity_matrix" named in the track_stats_list /
# track_stats_group docstrings.
# TODO: restore this region from version control rather than reconstructing
# it by hand.
def mask_tracks(tracks,header,masks,nonzero_thresh=0,through=1,write_nii=0,outprefix="mask",tracks_mm=0,length_thresh=0):
"""
Creates density files for all tracks passing through a set of masks
"""
# Each volume in vox_tracks_img is the density volume for a single mask
# Leave 'through' argument as 0 to count number of tracks that originate/terminate
# within a mask, set through to 1 to count number of tracks that intersect a mask
xdim,ydim,zdim=header["dims"]
mm_dims=np.array(header["vox_size"])*np.array(header["dims"])
masks_coords_list=[]
if write_nii == 1:
vox_tracks_img=np.zeros((xdim,ydim,zdim,len(masks)))
# tracknums[k] collects indices of tracks that touch mask k
tracknums=[[] for x in range(len(masks))]
for mask in masks:
masks_coords_list.append(set(core.get_nonzero_coords(mask,nonzero_thresh)))
for tracknum,track in enumerate(tracks):
if through == 0:
# only the first and last voxel of the track are tested against each mask
track_start_set=set([track[0]])
track_end_set=set([track[-1]])
for count,mask_coords_set in enumerate(masks_coords_list):
if track_start_set & mask_coords_set or track_end_set & mask_coords_set:
if length_thresh:
# length is measured on the mm coordinates, not the voxel track
track_len = tracklength(np.array(tracks_mm[tracknum]))
if track_len > length_thresh:
tracknums[count].append(tracknum)
if write_nii==1:
for x,y,z in track:
# NOTE(review): garbled line -- see the note at the top of this block.
if all(np.array([x,y,z]) length_thresh:
tracknums[count].append(tracknum)
if write_nii==1:
for x,y,z in track:
# NOTE(review): garbled line -- see the note at the top of this block.
# Everything from here on appears to belong to a different function.
if all(np.array([x,y,z]) length_thresh:
cur_start.append(count)
else:
cur_start.append(count)
elif track_end_set & mask_coords_set:
if length_thresh:
track_len = tracklength(np.array(tracks_mm[tracknum]))
if track_len > length_thresh:
cur_end.append(count)
else:
cur_end.append(count)
for x in cur_start:
for y in cur_end:
# allow for fiber to start/end in multiple (overlapping) masks
if mask_matrix_file:
# only count pairs that are enabled in the (undefined here) mask matrix
if mask_matrix_array[x,y]:
connect_mat[x,y] += 1
tracknums[(x*len(masks))+y].append(tracknum)
else:
connect_mat[x,y] += 1
tracknums[(x*len(masks))+y].append(tracknum)
elif through == 1:
cur=[]
track_set=set(track)
for count,mask_coords_set in enumerate(masks_coords_list):
if track_set & mask_coords_set:
if length_thresh:
track_len = tracklength(np.array(tracks_mm[tracknum]))
if track_len > length_thresh:
cur.append(count)
else:
cur.append(count)
for x,y in list(core.combinations(cur,2)):
if mask_matrix_file:
if mask_matrix_array[x,y]:
connect_mat[x,y] += 1
tracknums[(x*len(masks))+y].append(tracknum)
else:
connect_mat[x,y] += 1
tracknums[(x*len(masks))+y].append(tracknum)
connect_mat_sym = core.symmetrize_mat_sum(connect_mat)
tracknums_sym = core.symmetrize_tracknum_list(tracknums)
np.savetxt('%s_connectmat.txt'%outfile,connect_mat_sym)
if write_tracks:
# write every track that contributed to any mask pair, in index order
tracknum_list = list(set([item for sublist in tracknums for item in sublist]))
tracknum_list_ordered = sorted(tracknum_list)
track_list = [tracks_mm[n] for n in tracknum_list_ordered]
make_floats(track_list,write_tracks_filename,track_file)
return connect_mat_sym,tracknums_sym
def tracklength(track):
    """Return the polyline length of a track.

    The Euclidean distances between consecutive points are summed in order.
    A track with fewer than two points has length 0.

    Parameters
    ----------
    track : numpy.ndarray
        Points of the track, one per row.
    """
    total = 0
    for here, after in zip(track[:-1], track[1:]):
        step = here - after
        total = total + np.sqrt(np.dot(step, step))
    return total
def trackcurve(track):
    """Return the summed vertex angle of a track, in degrees.

    For every interior point b with predecessor a and successor c, the
    angle between the vectors (a - b) and (c - b) is computed; the angles are
    summed and converted to degrees.

    NOTE(review): both vectors point away from b, so a perfectly straight
    stretch contributes 180 degrees per vertex and a hairpin contributes about
    0. Confirm this is the intended notion of "curvature".

    Parameters
    ----------
    track : numpy.ndarray
        Points of the track, one per row.

    Returns
    -------
    float
        Sum of vertex angles in degrees; 0 for tracks with fewer than 3
        points. A zero-length segment yields nan (0/0), which callers such as
        track_stats filter out.
    """
    total = 0
    for a, b, c in zip(track[:-2], track[1:-1], track[2:]):
        ab = a - b
        bc = c - b
        cos_angle = np.dot(ab, bc) / (np.sqrt(np.dot(ab, ab)) * np.sqrt(np.dot(bc, bc)))
        # rounding can push |cos| just past 1, where arccos returns nan and
        # would poison the whole sum; clipping keeps straight tracks finite
        total = total + np.arccos(np.clip(cos_angle, -1.0, 1.0))
    return total * 180/np.pi
def track_stats(tracknums,tracks_mm,header,vox_volume,vox_dims,tracks_vox=0,statimage=0,statimage_data=None):
    """Given a list of track numbers and the tracks object with mm coordinates,
    calculate statistics for the tracks: total volume, avg track length,
    avg track curvature, and (optionally, requires input image) avg value from a
    statistical image such as FA, MD.

    Parameters
    ----------
    tracknums : iterable of int
        Indices into tracks_mm (and tracks_vox) of the tracks to summarize.
    tracks_mm : sequence
        Per-track point lists in mm; used for volume, length and curvature.
    header : dict
        Not used here; kept for interface compatibility.
    vox_volume : float
        Volume of one voxel; every mm track point counts as one voxel.
    vox_dims : sequence of int
        Image dimensions used to bounds-check voxel coordinates.
    tracks_vox : sequence, optional
        Per-track voxel coordinates; only read when statimage is set. Point i
        of tracks_mm[t] is paired with point i of tracks_vox[t].
    statimage : str or 0
        Path to a statistical image; falsy disables image statistics.
    statimage_data : array-like or None
        Pre-loaded image data. It is loaded from statimage when None or empty.

    Returns
    -------
    tuple
        (total_vol, avg_distance, avg_curve), plus avg_imageval when statimage
        is set.
        - avg_curve is the mean of trackcurve() values (degrees) over tracks
          whose curvature is not nan.
        - avg_imageval is the mean image value over all in-bounds voxels of
          all tracks, or 0 when there are none.
    """
    # NOTE: track curvature, length, volume calculated from tracks_mm
    # stats from statimage calculated from tracks_vox
    track_vols = []
    track_lens = []
    track_curves = []
    track_imagevals_list = []  # [in-bounds voxel count, summed value] per track
    if statimage and (statimage_data is None or len(statimage_data) == 0):
        # np.asanyarray(img.dataobj) replaces nibabel's removed get_data()
        # and keeps the stored dtype
        statimage_data = np.asanyarray(nib.load(statimage).dataobj)
    for tracknum in tracknums:
        track = np.array(tracks_mm[tracknum])
        track_vols.append(len(track) * vox_volume)
        track_lens.append(tracklength(track))
        # trackcurve() already returns degrees; do not convert a second time
        track_curves.append(trackcurve(track))
        if statimage:
            vox_track = tracks_vox[tracknum]
            in_bounds = 0
            val_sum = 0
            for i in range(len(track)):
                x2, y2, z2 = vox_track[i]
                # exclude voxels outside the image; negative indices would
                # otherwise silently wrap around to the far side
                if (0 <= x2 < vox_dims[0] and 0 <= y2 < vox_dims[1]
                        and 0 <= z2 < vox_dims[2]):
                    val_sum = val_sum + statimage_data[x2, y2, z2]
                    in_bounds += 1
            track_imagevals_list.append([in_bounds, val_sum])
    total_vol = sum(track_vols)
    track_curves = [z for z in track_curves if not np.isnan(z)]  # drop nan angles
    avg_distance = sum(track_lens) / len(track_lens) if track_lens else 0
    avg_curve = sum(track_curves) / len(track_curves) if track_curves else 0
    if not statimage:
        return total_vol, avg_distance, avg_curve
    # weighted average over in-bounds voxels only
    total_count = sum(count for count, _ in track_imagevals_list)
    if total_count:
        avg_imageval = float(sum(val for _, val in track_imagevals_list)) / total_count
    else:
        avg_imageval = 0
    return total_vol, avg_distance, avg_curve, avg_imageval
def track_stats_list(tracknums_list,tracks_mm,header,outprefix,tracks_vox=0,statimage=0):
    """Calculate matrices for total track bundle volume, avg track length,
    avg track curvature, and optionally an average track statistic from a statistical
    image like FA or MD.
    Requires tracknums_list output by mask_connectivity_matrix.

    Parameters
    ----------
    tracknums_list : sequence of lists of int
        One list of track indices per bundle; each is passed to track_stats().
    tracks_mm, tracks_vox, statimage
        Forwarded to track_stats(). The stat image is loaded once here and
        shared across all bundles.
    header : dict
        Must contain "vox_size" and "dims".
    outprefix : str
        Prefix of the text files written with np.savetxt.

    Returns
    -------
    tuple of numpy.ndarray
        (volumelist, lengthlist, curvelist), plus statlist when statimage is
        set. Each array is also written to '<outprefix>_<name>.txt'.
    """
    n_bundles = len(tracknums_list)
    volumelist = np.zeros((n_bundles))
    lengthlist = np.zeros((n_bundles))
    curvelist = np.zeros((n_bundles))
    statlist = np.zeros((n_bundles))
    xsize, ysize, zsize = header["vox_size"]
    vox_volume = xsize * ysize * zsize
    vox_dims = header["dims"]
    if statimage:
        # np.asanyarray(img.dataobj) replaces nibabel's removed get_data()
        # and keeps the stored dtype
        statimage_data = np.asanyarray(nib.load(statimage).dataobj)
    for i, tracknums in enumerate(tracknums_list):
        if statimage:
            volumelist[i], lengthlist[i], curvelist[i], statlist[i] = track_stats(
                tracknums, tracks_mm, header, vox_volume, vox_dims,
                tracks_vox, statimage, statimage_data)
        else:
            volumelist[i], lengthlist[i], curvelist[i] = track_stats(
                tracknums, tracks_mm, header, vox_volume, vox_dims)
    np.savetxt('%s_volumelist.txt'%outprefix,volumelist)
    np.savetxt('%s_lengthlist.txt'%outprefix,lengthlist)
    np.savetxt('%s_curvelist.txt'%outprefix,curvelist)
    if statimage:
        np.savetxt('%s_statlist.txt'%outprefix,statlist)
        return volumelist,lengthlist,curvelist,statlist
    return volumelist,lengthlist,curvelist
def track_stats_group(tracknums_list,tracks_mm,header,outprefix,tracks_vox=0,statimage=0):
    """Calculate matrices for total track bundle volume, avg track length,
    avg track curvature, and optionally an average track statistic from a statistical
    image like FA or MD.
    Requires tracknums_list output by mask_connectivity_matrix.

    Parameters
    ----------
    tracknums_list : sequence of lists of int
        Flattened square (N*N) list of per-mask-pair track index lists;
        entry (i*N)+j holds the tracks for mask pair (i, j).
    tracks_mm, tracks_vox, statimage
        Forwarded to track_stats(). The stat image is loaded once here.
    header : dict
        Must contain "vox_size" and "dims".
    outprefix : str
        Prefix of the text files written with np.savetxt.

    Returns
    -------
    tuple of numpy.ndarray
        (volumemat, lengthmat, curvemat), plus statmat when statimage is set.
        - Only the upper triangle (i < j) is computed; it is then mirrored
          with core.symmetrize_mat(..., 'top').
        - Each matrix is also written to '<outprefix>_<name>.txt'.
    """
    xdim = np.sqrt(len(tracknums_list)).astype('int')
    ydim = xdim
    volumemat = np.zeros((xdim,ydim))
    lengthmat = np.zeros((xdim,ydim))
    curvemat = np.zeros((xdim,ydim))
    statmat = np.zeros((xdim,ydim))
    xsize, ysize, zsize = header["vox_size"]
    vox_volume = xsize * ysize * zsize
    vox_dims = header["dims"]
    if statimage:
        # np.asanyarray(img.dataobj) replaces nibabel's removed get_data()
        # and keeps the stored dtype
        statimage_data = np.asanyarray(nib.load(statimage).dataobj)
    for i in range(xdim - 1):
        for j in range(i + 1, ydim):
            index = (i * xdim) + j
            if statimage:
                volumemat[i,j], lengthmat[i,j], curvemat[i,j], statmat[i,j] = track_stats(
                    tracknums_list[index], tracks_mm, header, vox_volume, vox_dims,
                    tracks_vox, statimage, statimage_data)
            else:
                volumemat[i,j], lengthmat[i,j], curvemat[i,j] = track_stats(
                    tracknums_list[index], tracks_mm, header, vox_volume, vox_dims)
    volumemat = core.symmetrize_mat(volumemat,'top')
    lengthmat = core.symmetrize_mat(lengthmat,'top')
    curvemat = core.symmetrize_mat(curvemat,'top')
    if statimage:
        statmat = core.symmetrize_mat(statmat,'top')
        np.savetxt('%s_volumemat.txt'%outprefix,volumemat)
        np.savetxt('%s_lengthmat.txt'%outprefix,lengthmat)
        np.savetxt('%s_curvemat.txt'%outprefix,curvemat)
        np.savetxt('%s_statmat.txt'%outprefix,statmat)
        return volumemat,lengthmat,curvemat,statmat
    np.savetxt('%s_volumemat.txt'%outprefix,volumemat)
    np.savetxt('%s_lengthmat.txt'%outprefix,lengthmat)
    np.savetxt('%s_curvemat.txt'%outprefix,curvemat)
    return volumemat,lengthmat,curvemat
def get_tracks_dsi_studio(tracks_file,xsize=2.5,ysize=2.5,zsize=2.5,mm_convert=False,flip_dim=96):
    """
    Read tracks from DSI studio tracks .txt file.

    Each track returned by core.file_reader is treated as a flat sequence of
    numbers x0, y0, z0, x1, y1, z1, ...; any 1-2 trailing values that do not
    form a full point are ignored.
    (Assumes file_reader yields numeric values -- confirm.)

    Parameters
    ----------
    tracks_file : str
        Path to the DSI Studio .txt tracks file.
    xsize, ysize, zsize : float
        Voxel size, used only when mm_convert is True.
    mm_convert : bool
        If True, each coordinate is floor-divided by the voxel size.
        If False (the default, matching the original hard-coded behaviour),
        x and y are flipped as int(flip_dim - v) and z is int(v).
    flip_dim : number
        Extent used for the x/y flip; 96 keeps the original behaviour.

    Returns
    -------
    list of list of tuple
        Integer (x, y, z) voxel coordinates, one list per track.
    """
    tracks = core.file_reader(tracks_file)
    tracks_new = []
    for track in tracks:
        # integer division: range() needs an int on Python 3
        n_points = len(track) // 3
        track_new = []
        for start in range(0, 3 * n_points, 3):
            x, y, z = track[start], track[start + 1], track[start + 2]
            if mm_convert:
                track_new.append((int(x // xsize), int(y // ysize), int(z // zsize)))
            else:
                track_new.append((int(flip_dim - x), int(flip_dim - y), int(z)))
        tracks_new.append(track_new)
    return tracks_new
def mask_connectivity_matrix_dsi(tracks,masks,outfile,nonzero_thresh=0,through=0,tracks_mm=0,length_thresh=0,header=None):
    """Calculate the (symmetric) connectivity matrix for a set of tracks from a DSI studio .txt file and a
    set of masks.

    Parameters
    ----------
    tracks : sequence of lists of hashable voxel coordinates
    masks : sequence
        Each mask's voxel set comes from core.get_nonzero_coords(mask, nonzero_thresh).
    outfile : str
        Prefix for the '<outfile>_connectmat.txt' output.
    through : int
        0: a track connects mask x to mask y when its first voxel lies in x
        and its last voxel lies in y. A mask that holds the first voxel is not
        also tested for the last voxel.
        1: a track connects every pair of masks it intersects.
        Other values count nothing.
    tracks_mm : sequence
        mm coordinates; only read when length_thresh is set.
    length_thresh : float
        If nonzero, only tracks with tracklength(tracks_mm[t]) > length_thresh
        are counted.
    header : unused
        Kept for interface compatibility.

    Returns
    -------
    (connect_mat_sym, tracknums_sym)
        Results of core.symmetrize_mat_sum and core.symmetrize_tracknum_list.
        tracknums entry (x*len(masks))+y lists the track indices counted for
        the pair (x, y).
    """
    n_masks = len(masks)
    connect_mat = np.zeros((n_masks, n_masks))
    tracknums = [[] for _ in range(n_masks * n_masks)]
    masks_coords_list = [set(core.get_nonzero_coords(mask, nonzero_thresh))
                         for mask in masks]
    for tracknum, track in enumerate(tracks):
        if len(track) == 0:
            continue  # an empty track touches no mask (track[0] would raise)
        if through == 0:
            start_vox = track[0]
            end_vox = track[-1]
            cur_start = []
            cur_end = []
            for count, mask_coords_set in enumerate(masks_coords_list):
                if start_vox in mask_coords_set:
                    cur_start.append(count)
                elif end_vox in mask_coords_set:
                    cur_end.append(count)
            matched = cur_start or cur_end
            # allow for fiber to start/end in multiple (overlapping) masks
            pairs = [(x, y) for x in cur_start for y in cur_end]
        elif through == 1:
            track_set = set(track)
            cur = [count for count, mask_coords_set in enumerate(masks_coords_list)
                   if track_set & mask_coords_set]
            matched = cur
            pairs = core.combinations(cur, 2)
        else:
            continue
        if matched and length_thresh:
            # computed once per track instead of once per matching mask
            track_len = tracklength(np.array(tracks_mm[tracknum]))
            if not track_len > length_thresh:
                continue
        for x, y in pairs:
            connect_mat[x, y] += 1
            tracknums[(x * n_masks) + y].append(tracknum)
    connect_mat_sym = core.symmetrize_mat_sum(connect_mat)
    tracknums_sym = core.symmetrize_tracknum_list(tracknums)
    np.savetxt('%s_connectmat.txt'%outfile,connect_mat_sym)
    return connect_mat_sym,tracknums_sym
def make_floats(track_list,output_filename,input_trackfile):
    """
    Take a track list and generate a .trk (TrackVis) file.

    Parameters
    ----------
    track_list : sequence of sequences of points
        Each point is an iterable of numbers written as native floats.
    output_filename : str
        Path of the .trk file to create (overwritten if present).
    input_trackfile : str
        Existing .trk file whose first 1000 bytes are copied as the header.
        Only num_fibers (bytes 988-991) is replaced, with len(track_list).

    NOTE(review): n_scalars / n_properties are copied from the input header
    unchanged, and no per-track properties are written. The output is only
    self-consistent when each point carries 3 + n_scalars values and the
    source has n_properties == 0 -- confirm with callers.
    """
    # only the 1000-byte header is needed from the source file
    with open(input_trackfile, 'rb') as f:  # 'rb' for Windows reading
        header = f.read(1000)
    with open(output_filename, 'wb') as outfile:
        outfile.write(header[0:988])
        outfile.write(struct.pack('i', len(track_list)))
        outfile.write(header[992:1000])
        for track in track_list:
            coords = [coord for point in track for coord in point]
            outfile.write(struct.pack('i', len(track)))  # number of points
            outfile.write(struct.pack('%df' % len(coords), *coords))