diff --git a/cerr/registration/register.py b/cerr/registration/register.py index 046edc4..65a8d37 100644 --- a/cerr/registration/register.py +++ b/cerr/registration/register.py @@ -16,11 +16,12 @@ from cerr.utils.interp import finterp3 from cerr.radiomics import preprocess import numpy as np +import subprocess def registerScans(basePlanC, baseScanIndex, movPlanC, movScanIndex, transformSaveDir, deforAlgorithm='bsplines', registrationTool='plastimatch', - inputCmdFile=None, baseMask3M=None, movMask3M=None): + baseMask3M=None, movMask3M=None, inputCmdFile=None): """ Args: @@ -30,9 +31,9 @@ def registerScans(basePlanC, baseScanIndex, movPlanC, movScanIndex, transformSav movScanIndex (int): integer, identifies moving scan in movPlanC transformSaveDir (str): Directory to save transformation file registration_tool (str): registration software to use ('PLASTIMATCH','ELASTIX','ANTS') - inputCmdFile (str): optional, path to registration command file baseMask3M (numpy.ndarray): optional, 3D or 4D binary mask(s) in target space movMask3M (numpy.ndarray): optional, 3D or 4D binary mask(s) in moving space + inputCmdFile (str): optional, path to registration command file Returns: cerr.plan_container.PlanC: plan container object basePlanC with an element added to planC.deform attribute @@ -47,6 +48,7 @@ def registerScans(basePlanC, baseScanIndex, movPlanC, movScanIndex, transformSav fixed_img_nii = os.path.join(dirpath, 'fixed.nii.gz') moving_mask_nii = os.path.join(dirpath, 'moving_mask.nii.gz') fixed_mask_nii = os.path.join(dirpath, 'fixed_mask.nii.gz') + warped_img_nii = os.path.join(dirpath, 'warped_moving.nii.gz') basePlanC.scan[baseScanIndex].saveNii(fixed_img_nii) movPlanC.scan[movScanIndex].saveNii(moving_img_nii) if baseMask3M is not None: @@ -61,7 +63,17 @@ def registerScans(basePlanC, baseScanIndex, movPlanC, movScanIndex, transformSav del movPlanC.structure[-1] if inputCmdFile is None or not os.path.exists(inputCmdFile): - plmCmdFile = 'plastimatch_ct_ct_intra_pt.txt' + if baseMask3M is not None and movMask3M is not None: + if deforAlgorithm == 'affine': + plmCmdFile = 'plastimatch_ct_ct_intra_pt_w_masks_affine.txt' + elif deforAlgorithm == 'bsplines': + plmCmdFile = 'plastimatch_ct_ct_intra_pt_w_masks_bsplines.txt' + else: + if deforAlgorithm == 'affine': + plmCmdFile = 'plastimatch_ct_ct_intra_pt_affine.txt' + elif deforAlgorithm == 'bsplines': + plmCmdFile = 'plastimatch_ct_ct_intra_pt_bsplines.txt' + regDir = os.path.dirname(os.path.abspath(__file__)) inputCmdFile = os.path.join(regDir,'settings',plmCmdFile) #cmdFilePathDest = os.path.join(dirpath, plmCmdFile) @@ -89,9 +101,15 @@ def registerScans(basePlanC, baseScanIndex, movPlanC, movScanIndex, transformSav currDir = os.getcwd() os.chdir(dirpath) - os.system(plm_reg_cmd) + #os.system(plm_reg_cmd) + sts = subprocess.Popen(plm_reg_cmd, shell=True).wait() os.chdir(currDir) + # Add warped scan to planC + imageType = movPlanC.scan[movScanIndex].scanInfo[0].imageType + direction = '' + basePlanC = pc.loadNiiScan(warped_img_nii, imageType, direction, basePlanC) + # Copy output to the user-specified directory shutil.copyfile(bspSourcePath, bspDestPath) @@ -140,7 +158,7 @@ def warpScan(basePlanC, baseScanIndex, movPlanC, movScanIndex, deformS): plm_warp_str_cmd = "plastimatch warp --input " + moving_img_nii + \ " --output-img " + warped_img_nii + \ " --xf " + bsplines_coeff_file + \ - " --referenced-ct " + fixed_img_nii + " --fixed " + fixed_img_nii currDir = os.getcwd() os.chdir(dirpath) @@ -422,6 +440,7 @@ def getDvfVectors(deformS, planC, scanNum, outputResV=[0, 0, 0], structNum=None, xDeformV = finterp3(xSurfV,ySurfV,zSurfV,xDeformM,xFieldV,yFieldV,zFieldV) yDeformV = finterp3(xSurfV,ySurfV,zSurfV,yDeformM,xFieldV,yFieldV,zFieldV) zDeformV = finterp3(xSurfV,ySurfV,zSurfV,zDeformM,xFieldV,yFieldV,zFieldV) + # Convert xDeformV,yDeformV,zDeformV to CERR virtual coordinates onesV = np.ones_like(xDeformV) zeroV = np.zeros_like(xDeformV) @@ -438,9 +457,9 @@ def getDvfVectors(deformS, planC, scanNum, outputResV=[0, 0, 0], structNum=None, vectors = np.empty((numPts,2,3), dtype=np.float32) rcsFlag = True # an input argument? if rcsFlag: # (r,c,s) image coordinates - dx = np.abs(np.median(np.diff(xValsV))) - dy = np.abs(np.median(np.diff(yValsV))) - dz = np.abs(np.median(np.diff(zValsV))) + #dx = np.abs(np.median(np.diff(xValsV))) + #dy = np.abs(np.median(np.diff(yValsV))) + #dz = np.abs(np.median(np.diff(zValsV))) # Convert CERR virtual coords to DICOM Image coords for i in range(numPts): vectors[i,0,:] = [rSurfV[i], cSurfV[i], sSurfV[i]] diff --git a/cerr/registration/settings/plastimatch_ct_ct_intra_pt_affine.txt b/cerr/registration/settings/plastimatch_ct_ct_intra_pt_affine.txt new file mode 100644 index 0000000..e09a22d --- /dev/null +++ b/cerr/registration/settings/plastimatch_ct_ct_intra_pt_affine.txt @@ -0,0 +1,36 @@ +# Settings for inter-patient CT-CT registration + +[GLOBAL] +fixed=fixed.nii.gz +moving=moving.nii.gz +img_out=warped_moving.nii.gz +xform_out=bspline_coefficients.txt +resample_when_linear=true + +[STAGE] +xform=align_center +metric=mi + +[STAGE] +xform=translation +optim=rsg +max_its=2000 +res=4 4 2 +metric=mi + +[STAGE] +xform=affine +optim=rsg +max_its=1000 +res=2 2 1 +metric=mi + +#[STAGE] +#xform=bspline +#impl=plastimatch +#threading=openmp +#max_its=30 +#regularization_lambda=0.1 +#grid_spac=50 50 50 +#res=2 2 1 +#metric=mse diff --git a/cerr/registration/settings/plastimatch_ct_ct_intra_pt.txt b/cerr/registration/settings/plastimatch_ct_ct_intra_pt_bsplines.txt similarity index 73% rename from cerr/registration/settings/plastimatch_ct_ct_intra_pt.txt rename to cerr/registration/settings/plastimatch_ct_ct_intra_pt_bsplines.txt index f8572db..60ee9a2 100644 --- a/cerr/registration/settings/plastimatch_ct_ct_intra_pt.txt +++ b/cerr/registration/settings/plastimatch_ct_ct_intra_pt_bsplines.txt @@ -5,6 +5,7 @@ fixed=fixed.nii.gz moving=moving.nii.gz img_out=warped_moving.nii.gz xform_out=bspline_coefficients.txt +resample_when_linear=true [STAGE] xform=align_center @@ -33,3 +34,13 @@ regularization_lambda=0.1 grid_spac=50 50 50 res=2 2 1 metric=mse + +#[STAGE] +#xform=bspline +#impl=plastimatch +#threading=openmp +#max_its=30 +#regularization_lambda=0.05 +#grid_spac=30 30 30 +#res=1 1 1 +#metric=mse diff --git a/cerr/registration/settings/plastimatch_ct_ct_intra_pt_w_masks_affine.txt b/cerr/registration/settings/plastimatch_ct_ct_intra_pt_w_masks_affine.txt new file mode 100644 index 0000000..dd7ea10 --- /dev/null +++ b/cerr/registration/settings/plastimatch_ct_ct_intra_pt_w_masks_affine.txt @@ -0,0 +1,38 @@ +# Settings for inter-patient CT-CT registration + +[GLOBAL] +fixed=fixed.nii.gz +moving=moving.nii.gz +fixed_roi=fixed_mask.nii.gz +moving_roi=moving_mask.nii.gz +img_out=warped_moving.nii.gz +xform_out=bspline_coefficients.txt +resample_when_linear=true + +[STAGE] +xform=align_center +metric=mi + +[STAGE] +xform=translation +optim=rsg +max_its=2000 +res=2 2 1 +metric=mi + +[STAGE] +xform=affine +optim=rsg +max_its=1000 +res=2 2 1 +metric=mi + +#[STAGE] +#xform=bspline +#impl=plastimatch +#threading=openmp +#max_its=30 +#regularization_lambda=0.1 +#grid_spac=50 50 50 +#res=2 2 1 +#metric=mse diff --git a/cerr/registration/settings/plastimatch_ct_ct_intra_pt_w_masks_bsplines.txt b/cerr/registration/settings/plastimatch_ct_ct_intra_pt_w_masks_bsplines.txt new file mode 100644 index 0000000..3f4974a --- /dev/null +++ b/cerr/registration/settings/plastimatch_ct_ct_intra_pt_w_masks_bsplines.txt @@ -0,0 +1,48 @@ +# Settings for inter-patient CT-CT registration + +[GLOBAL] +fixed=fixed.nii.gz +moving=moving.nii.gz +fixed_roi=fixed_mask.nii.gz +moving_roi=moving_mask.nii.gz +img_out=warped_moving.nii.gz +xform_out=bspline_coefficients.txt +resample_when_linear=true + +[STAGE] +xform=align_center +metric=mi + +[STAGE] +xform=translation +optim=rsg +max_its=2000 +res=2 2 1 +metric=mi + +[STAGE] +xform=affine +optim=rsg +max_its=1000 +res=2 2 1 +metric=mi + +[STAGE] +xform=bspline +impl=plastimatch +threading=openmp +max_its=30 +regularization_lambda=0.1 +grid_spac=50 50 50 +res=2 2 1 +metric=mse + +#[STAGE] +#xform=bspline +#impl=plastimatch +#threading=openmp +#max_its=20 +#regularization_lambda=0.05 +#grid_spac=30 30 30 +#res=1 1 1 +#metric=mse diff --git a/cerr/utils/interp.py b/cerr/utils/interp.py index 4f1b33b..4bee848 100644 --- a/cerr/utils/interp.py +++ b/cerr/utils/interp.py @@ -39,51 +39,57 @@ def finterp3(xInterpV, yInterpV, zInterpV, field3M, xFieldV, yFieldV, zFieldV, O #slcs = slcs - 1 # Find indices out of bounds. - colNaN = (cols > siz[1]-1) | (cols < 0) - colLast = (cols - siz[1]-1) ** 2 < 1e-3 - yInterpColLastV = yInterpV[colLast] - zInterpColLastV = zInterpV[colLast] - - rowNaN = (rows > siz[0]-1) | (rows < 0) - rowLast = (rows - siz[0]-1) ** 2 < 1e-3 - xInterpRowLastV = xInterpV[rowLast] - zInterpRowLastV = zInterpV[rowLast] - - slcNaN = np.isnan(slcs) | (slcs < 0) | (slcs > siz[2]-1) - slcLast = (slcs - siz[2]-1) ** 2 < 1e-3 - xInterpLastV = xInterpV[slcLast] - yInterpLastV = yInterpV[slcLast] - - # Set those to a proxy 1. + colNaN = (cols > (siz[1]-1)) | (cols < 0) + # colLast = (cols - (siz[1]-1)) ** 2 < 1e-3 + # yInterpColLastV = yInterpV[colLast] + # zInterpColLastV = zInterpV[colLast] + + rowNaN = (rows > (siz[0]-1)) | (rows < 0) + # rowLast = (rows - (siz[0]-1)) ** 2 < 1e-3 + # xInterpRowLastV = xInterpV[rowLast] + # zInterpRowLastV = zInterpV[rowLast] + + slcNaN = np.isnan(slcs) | (slcs < 0) | (slcs > (siz[2]-1)) + # slcLast = (slcs - (siz[2]-1)) ** 2 < 1e-3 + # xInterpLastV = xInterpV[slcLast] + # yInterpLastV = yInterpV[slcLast] + + # Set those to a proxy 0. rows[rowNaN] = 0 cols[colNaN] = 0 slcs[slcNaN] = 0 colFloor = np.floor(cols) + colCeil = np.ceil(cols) colMod = cols - colFloor oneMinusColMod = (1 - colMod) rowFloor = np.floor(rows) + rowCeil = np.ceil(rows) rowMod = rows - rowFloor oneMinusRowMod = (1 - rowMod) slcFloor = np.floor(slcs) + slcCeil = np.ceil(slcs) slcMod = slcs - slcFloor oneMinusSlcMod = (1 - slcMod) rowFloor = np.asarray(rowFloor,dtype=int) colFloor = np.asarray(colFloor,dtype=int) slcFloor = np.asarray(slcFloor,dtype=int) + rowCeil = np.asarray(rowCeil,dtype=int) + colCeil = np.asarray(colCeil,dtype=int) + slcCeil = np.asarray(slcCeil,dtype=int) # Accumulate contribution from each voxel surrounding x,y,z point. interpV = field3M[rowFloor,colFloor,slcFloor] * oneMinusRowMod * oneMinusColMod * oneMinusSlcMod - interpV += field3M[rowFloor+1,colFloor,slcFloor] * rowMod * oneMinusColMod * oneMinusSlcMod - interpV += field3M[rowFloor,colFloor+1,slcFloor] * oneMinusRowMod * colMod * oneMinusSlcMod - interpV += field3M[rowFloor+1,colFloor+1,slcFloor] * rowMod * colMod * oneMinusSlcMod - interpV += field3M[rowFloor,colFloor,slcFloor+1] * oneMinusRowMod * oneMinusColMod * slcMod - interpV += field3M[rowFloor+1,colFloor,slcFloor+1] * rowMod * oneMinusColMod * slcMod - interpV += field3M[rowFloor,colFloor+1,slcFloor+1] * oneMinusRowMod * colMod * slcMod - interpV += field3M[rowFloor+1,colFloor+1,slcFloor+1] * rowMod * colMod * slcMod + interpV += field3M[rowCeil,colFloor,slcFloor] * rowMod * oneMinusColMod * oneMinusSlcMod + interpV += field3M[rowFloor,colCeil,slcFloor] * oneMinusRowMod * colMod * oneMinusSlcMod + interpV += field3M[rowCeil,colCeil,slcFloor] * rowMod * colMod * oneMinusSlcMod + interpV += field3M[rowFloor,colFloor,slcCeil] * oneMinusRowMod * oneMinusColMod * slcMod + interpV += field3M[rowCeil,colFloor,slcCeil] * rowMod * oneMinusColMod * slcMod + interpV += field3M[rowFloor,colCeil,slcCeil] * oneMinusRowMod * colMod * slcMod + interpV += field3M[rowCeil,colCeil,slcCeil] * rowMod * colMod * slcMod # # Linear indices of lower bound contributing points. @@ -107,17 +113,17 @@ def finterp3(xInterpV, yInterpV, zInterpV, field3M, xFieldV, yFieldV, zFieldV, O # Replace proxy 1s with out of bounds vals. interpV[rowNaN | colNaN | slcNaN] = OOBV - # 2D interpolate last slice - if any(slcLast): - interpV[slcLast] = interp2d(xFieldV, yFieldV, field3M[:, :, -1])(xInterpLastV, yInterpLastV) - - if any(colLast): - if len(zFieldV) > 1: - interpV[colLast] = interp2d(yFieldV, zFieldV, np.squeeze(field3M[:, -1, :].T))(yInterpColLastV, zInterpColLastV) - - if any(rowLast): - if len(zFieldV) > 1: - interpV[rowLast] = interp2d(xFieldV, zFieldV, np.squeeze(field3M[-1, :, :].T))(xInterpRowLastV, zInterpRowLastV) + # # 2D interpolate last slice + # if any(slcLast): + # interpV[slcLast] = interp2d(xFieldV, yFieldV, field3M[:, :, -1])(xInterpLastV, yInterpLastV) + # + # if any(colLast): + # if len(zFieldV) > 1: + # interpV[colLast] = interp2d(yFieldV, zFieldV, np.squeeze(field3M[:, -1, :].T))(yInterpColLastV, zInterpColLastV) + # + # if any(rowLast): + # if len(zFieldV) > 1: + # interpV[rowLast] = interp2d(xFieldV, zFieldV, np.squeeze(field3M[-1, :, :].T))(xInterpRowLastV, zInterpRowLastV) return interpV