Skip to content

Commit

Permalink
Updated signature of registerScans and moved inputCmdFile to end. Set…
Browse files Browse the repository at this point in the history
… up command files for affine and bsplines transforms
  • Loading branch information
adityaapte committed Jul 3, 2024
1 parent 96d40ec commit c24d022
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 42 deletions.
35 changes: 27 additions & 8 deletions cerr/registration/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
from cerr.utils.interp import finterp3
from cerr.radiomics import preprocess
import numpy as np
import subprocess


def registerScans(basePlanC, baseScanIndex, movPlanC, movScanIndex, transformSaveDir,
deforAlgorithm='bsplines', registrationTool='plastimatch',
inputCmdFile=None, baseMask3M=None, movMask3M=None):
baseMask3M=None, movMask3M=None, inputCmdFile=None):
"""
Args:
Expand All @@ -30,9 +31,9 @@ def registerScans(basePlanC, baseScanIndex, movPlanC, movScanIndex, transformSav
movScanIndex (int): integer, identifies moving scan in movPlanC
transformSaveDir (str): Directory to save transformation file
registration_tool (str): registration software to use ('PLASTIMATCH','ELASTIX','ANTS')
inputCmdFile (str): optional, path to registration command file
baseMask3M (numpy.ndarray): optional, 3D or 4D binary mask(s) in target space
movMask3M (numpy.ndarray): optional, 3D or 4D binary mask(s) in moving space
inputCmdFile (str): optional, path to registration command file
Returns:
cerr.plan_container.PlanC: plan container object basePlanC with an element added to planC.deform attribute
Expand All @@ -47,6 +48,7 @@ def registerScans(basePlanC, baseScanIndex, movPlanC, movScanIndex, transformSav
fixed_img_nii = os.path.join(dirpath, 'fixed.nii.gz')
moving_mask_nii = os.path.join(dirpath, 'moving_mask.nii.gz')
fixed_mask_nii = os.path.join(dirpath, 'fixed_mask.nii.gz')
warped_img_nii = os.path.join(dirpath, 'warped_moving.nii.gz')
basePlanC.scan[baseScanIndex].saveNii(fixed_img_nii)
movPlanC.scan[movScanIndex].saveNii(moving_img_nii)
if baseMask3M is not None:
Expand All @@ -61,7 +63,17 @@ def registerScans(basePlanC, baseScanIndex, movPlanC, movScanIndex, transformSav
del movPlanC.structure[-1]

if inputCmdFile is None or not os.path.exists(inputCmdFile):
plmCmdFile = 'plastimatch_ct_ct_intra_pt.txt'
if baseMask3M is not None and movMask3M is not None:
if deforAlgorithm == 'affine':
plmCmdFile = 'plastimatch_ct_ct_intra_pt_w_masks_affine.txt'
elif deforAlgorithm == 'bsplines':
plmCmdFile = 'plastimatch_ct_ct_intra_pt_w_masks_bsplines.txt'
else:
if deforAlgorithm == 'affine':
plmCmdFile = 'plastimatch_ct_ct_intra_pt_affine.txt'
elif deforAlgorithm == 'bsplines':
plmCmdFile = 'plastimatch_ct_ct_intra_pt_bsplines.txt'

regDir = os.path.dirname(os.path.abspath(__file__))
inputCmdFile = os.path.join(regDir,'settings',plmCmdFile)
#cmdFilePathDest = os.path.join(dirpath, plmCmdFile)
Expand Down Expand Up @@ -89,9 +101,15 @@ def registerScans(basePlanC, baseScanIndex, movPlanC, movScanIndex, transformSav

currDir = os.getcwd()
os.chdir(dirpath)
os.system(plm_reg_cmd)
#os.system(plm_reg_cmd)
sts = subprocess.Popen(plm_reg_cmd, shell=True).wait()
os.chdir(currDir)

# Add warped scan to planC
imageType = movPlanC.scan[movScanIndex].scanInfo[0].imageType
direction = ''
basePlanC = pc.loadNiiScan(warped_img_nii, imageType, direction, basePlanC)

# Copy output to the user-specified directory
shutil.copyfile(bspSourcePath, bspDestPath)

Expand Down Expand Up @@ -140,7 +158,7 @@ def warpScan(basePlanC, baseScanIndex, movPlanC, movScanIndex, deformS):
plm_warp_str_cmd = "plastimatch warp --input " + moving_img_nii + \
" --output-img " + warped_img_nii + \
" --xf " + bsplines_coeff_file + \
" --referenced-ct " + fixed_img_nii
" --fixed " + fixed_img_nii

currDir = os.getcwd()
os.chdir(dirpath)
Expand Down Expand Up @@ -422,6 +440,7 @@ def getDvfVectors(deformS, planC, scanNum, outputResV=[0, 0, 0], structNum=None,
xDeformV = finterp3(xSurfV,ySurfV,zSurfV,xDeformM,xFieldV,yFieldV,zFieldV)
yDeformV = finterp3(xSurfV,ySurfV,zSurfV,yDeformM,xFieldV,yFieldV,zFieldV)
zDeformV = finterp3(xSurfV,ySurfV,zSurfV,zDeformM,xFieldV,yFieldV,zFieldV)

# Convert xDeformV,yDeformV,zDeformV to CERR virtual coordinates
onesV = np.ones_like(xDeformV)
zeroV = np.zeros_like(xDeformV)
Expand All @@ -438,9 +457,9 @@ def getDvfVectors(deformS, planC, scanNum, outputResV=[0, 0, 0], structNum=None,
vectors = np.empty((numPts,2,3), dtype=np.float32)
rcsFlag = True # an input argument?
if rcsFlag: # (r,c,s) image coordinates
dx = np.abs(np.median(np.diff(xValsV)))
dy = np.abs(np.median(np.diff(yValsV)))
dz = np.abs(np.median(np.diff(zValsV)))
#dx = np.abs(np.median(np.diff(xValsV)))
#dy = np.abs(np.median(np.diff(yValsV)))
#dz = np.abs(np.median(np.diff(zValsV)))
# Convert CERR virtual coords to DICOM Image coords
for i in range(numPts):
vectors[i,0,:] = [rSurfV[i], cSurfV[i], sSurfV[i]]
Expand Down
36 changes: 36 additions & 0 deletions cerr/registration/settings/plastimatch_ct_ct_intra_pt_affine.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Settings for inter-patient CT-CT registration

[GLOBAL]
fixed=fixed.nii.gz
moving=moving.nii.gz
img_out=warped_moving.nii.gz
xform_out=bspline_coefficients.txt
resample_when_linear=true

[STAGE]
xform=align_center
metric=mi

[STAGE]
xform=translation
optim=rsg
max_its=2000
res=4 4 2
metric=mi

[STAGE]
xform=affine
optim=rsg
max_its=1000
res=2 2 1
metric=mi

#[STAGE]
#xform=bspline
#impl=plastimatch
#threading=openmp
#max_its=30
#regularization_lambda=0.1
#grid_spac=50 50 50
#res=2 2 1
#metric=mse
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ fixed=fixed.nii.gz
moving=moving.nii.gz
img_out=warped_moving.nii.gz
xform_out=bspline_coefficients.txt
resample_when_linear=true

[STAGE]
xform=align_center
Expand Down Expand Up @@ -33,3 +34,13 @@ regularization_lambda=0.1
grid_spac=50 50 50
res=2 2 1
metric=mse

#[STAGE]
#xform=bspline
#impl=plastimatch
#threading=openmp
#max_its=30
#regularization_lambda=0.05
#grid_spac=30 30 30
#res=1 1 1
#metric=mse
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Settings for inter-patient CT-CT registration

[GLOBAL]
fixed=fixed.nii.gz
moving=moving.nii.gz
fixed_roi=fixed_mask.nii.gz
moving_roi=moving_mask.nii.gz
img_out=warped_moving.nii.gz
xform_out=bspline_coefficients.txt
resample_when_linear=true

[STAGE]
xform=align_center
metric=mi

[STAGE]
xform=translation
optim=rsg
max_its=2000
res=2 2 1
metric=mi

[STAGE]
xform=affine
optim=rsg
max_its=1000
res=2 2 1
metric=mi

#[STAGE]
#xform=bspline
#impl=plastimatch
#threading=openmp
#max_its=30
#regularization_lambda=0.1
#grid_spac=50 50 50
#res=2 2 1
#metric=mse
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Settings for inter-patient CT-CT registration

[GLOBAL]
fixed=fixed.nii.gz
moving=moving.nii.gz
fixed_roi=fixed_mask.nii.gz
moving_roi=moving_mask.nii.gz
img_out=warped_moving.nii.gz
xform_out=bspline_coefficients.txt
resample_when_linear=true

[STAGE]
xform=align_center
metric=mi

[STAGE]
xform=translation
optim=rsg
max_its=2000
res=2 2 1
metric=mi

[STAGE]
xform=affine
optim=rsg
max_its=1000
res=2 2 1
metric=mi

[STAGE]
xform=bspline
impl=plastimatch
threading=openmp
max_its=30
regularization_lambda=0.1
grid_spac=50 50 50
res=2 2 1
metric=mse

#[STAGE]
#xform=bspline
#impl=plastimatch
#threading=openmp
#max_its=20
#regularization_lambda=0.05
#grid_spac=30 30 30
#res=1 1 1
#metric=mse
74 changes: 40 additions & 34 deletions cerr/utils/interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,51 +39,57 @@ def finterp3(xInterpV, yInterpV, zInterpV, field3M, xFieldV, yFieldV, zFieldV, O
#slcs = slcs - 1

# Find indices out of bounds.
colNaN = (cols > siz[1]-1) | (cols < 0)
colLast = (cols - siz[1]-1) ** 2 < 1e-3
yInterpColLastV = yInterpV[colLast]
zInterpColLastV = zInterpV[colLast]

rowNaN = (rows > siz[0]-1) | (rows < 0)
rowLast = (rows - siz[0]-1) ** 2 < 1e-3
xInterpRowLastV = xInterpV[rowLast]
zInterpRowLastV = zInterpV[rowLast]

slcNaN = np.isnan(slcs) | (slcs < 0) | (slcs > siz[2]-1)
slcLast = (slcs - siz[2]-1) ** 2 < 1e-3
xInterpLastV = xInterpV[slcLast]
yInterpLastV = yInterpV[slcLast]

# Set those to a proxy 1.
colNaN = (cols > (siz[1]-1)) | (cols < 0)
# colLast = (cols - (siz[1]-1)) ** 2 < 1e-3
# yInterpColLastV = yInterpV[colLast]
# zInterpColLastV = zInterpV[colLast]

rowNaN = (rows > (siz[0]-1)) | (rows < 0)
# rowLast = (rows - (siz[0]-1)) ** 2 < 1e-3
# xInterpRowLastV = xInterpV[rowLast]
# zInterpRowLastV = zInterpV[rowLast]

slcNaN = np.isnan(slcs) | (slcs < 0) | (slcs > (siz[2]-1))
# slcLast = (slcs - (siz[2]-1)) ** 2 < 1e-3
# xInterpLastV = xInterpV[slcLast]
# yInterpLastV = yInterpV[slcLast]

# Set those to a proxy 0.
rows[rowNaN] = 0
cols[colNaN] = 0
slcs[slcNaN] = 0

colFloor = np.floor(cols)
colCeil = np.ceil(cols)
colMod = cols - colFloor
oneMinusColMod = (1 - colMod)

rowFloor = np.floor(rows)
rowCeil = np.ceil(rows)
rowMod = rows - rowFloor
oneMinusRowMod = (1 - rowMod)

slcFloor = np.floor(slcs)
slcCeil = np.ceil(slcs)
slcMod = slcs - slcFloor
oneMinusSlcMod = (1 - slcMod)

rowFloor = np.asarray(rowFloor,dtype=int)
colFloor = np.asarray(colFloor,dtype=int)
slcFloor = np.asarray(slcFloor,dtype=int)
rowCeil = np.asarray(rowCeil,dtype=int)
colCeil = np.asarray(colCeil,dtype=int)
slcCeil = np.asarray(slcCeil,dtype=int)

# Accumulate contribution from each voxel surrounding x,y,z point.
interpV = field3M[rowFloor,colFloor,slcFloor] * oneMinusRowMod * oneMinusColMod * oneMinusSlcMod
interpV += field3M[rowFloor+1,colFloor,slcFloor] * rowMod * oneMinusColMod * oneMinusSlcMod
interpV += field3M[rowFloor,colFloor+1,slcFloor] * oneMinusRowMod * colMod * oneMinusSlcMod
interpV += field3M[rowFloor+1,colFloor+1,slcFloor] * rowMod * colMod * oneMinusSlcMod
interpV += field3M[rowFloor,colFloor,slcFloor+1] * oneMinusRowMod * oneMinusColMod * slcMod
interpV += field3M[rowFloor+1,colFloor,slcFloor+1] * rowMod * oneMinusColMod * slcMod
interpV += field3M[rowFloor,colFloor+1,slcFloor+1] * oneMinusRowMod * colMod * slcMod
interpV += field3M[rowFloor+1,colFloor+1,slcFloor+1] * rowMod * colMod * slcMod
interpV += field3M[rowCeil,colFloor,slcFloor] * rowMod * oneMinusColMod * oneMinusSlcMod
interpV += field3M[rowFloor,colCeil,slcFloor] * oneMinusRowMod * colMod * oneMinusSlcMod
interpV += field3M[rowCeil,colCeil,slcFloor] * rowMod * colMod * oneMinusSlcMod
interpV += field3M[rowFloor,colFloor,slcCeil] * oneMinusRowMod * oneMinusColMod * slcMod
interpV += field3M[rowCeil,colFloor,slcCeil] * rowMod * oneMinusColMod * slcMod
interpV += field3M[rowFloor,colCeil,slcCeil] * oneMinusRowMod * colMod * slcMod
interpV += field3M[rowCeil,colCeil,slcCeil] * rowMod * colMod * slcMod


# # Linear indices of lower bound contributing points.
Expand All @@ -107,17 +113,17 @@ def finterp3(xInterpV, yInterpV, zInterpV, field3M, xFieldV, yFieldV, zFieldV, O
# Replace proxy 1s with out of bounds vals.
interpV[rowNaN | colNaN | slcNaN] = OOBV

# 2D interpolate last slice
if any(slcLast):
interpV[slcLast] = interp2d(xFieldV, yFieldV, field3M[:, :, -1])(xInterpLastV, yInterpLastV)

if any(colLast):
if len(zFieldV) > 1:
interpV[colLast] = interp2d(yFieldV, zFieldV, np.squeeze(field3M[:, -1, :].T))(yInterpColLastV, zInterpColLastV)

if any(rowLast):
if len(zFieldV) > 1:
interpV[rowLast] = interp2d(xFieldV, zFieldV, np.squeeze(field3M[-1, :, :].T))(xInterpRowLastV, zInterpRowLastV)
# # 2D interpolate last slice
# if any(slcLast):
# interpV[slcLast] = interp2d(xFieldV, yFieldV, field3M[:, :, -1])(xInterpLastV, yInterpLastV)
#
# if any(colLast):
# if len(zFieldV) > 1:
# interpV[colLast] = interp2d(yFieldV, zFieldV, np.squeeze(field3M[:, -1, :].T))(yInterpColLastV, zInterpColLastV)
#
# if any(rowLast):
# if len(zFieldV) > 1:
# interpV[rowLast] = interp2d(xFieldV, zFieldV, np.squeeze(field3M[-1, :, :].T))(xInterpRowLastV, zInterpRowLastV)

return interpV

Expand Down

0 comments on commit c24d022

Please sign in to comment.