Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create feature maps for visualization #107

Merged
merged 2 commits into from
Dec 30, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 122 additions & 12 deletions cerr/mri_metrics/dce_mri.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import numpy as np
from matplotlib import pyplot as plt

from scipy.signal import resample
from scipy.ndimage import gaussian_filter1d
from matplotlib import pyplot as plt
from scipy.integrate import cumtrapz

from cerr import plan_container as pc
from cerr.contour.rasterseg import getStrMask
from cerr.utils.statistics import round


EPS = np.finfo(float).eps
def loadTimeSeq(planC, structNum):
"""loadTimeSeq
Function to extract 4D DCE scan array associated with input structure from planC
Expand Down Expand Up @@ -121,7 +125,7 @@ def normalizeToBaseline(scanArr4M, mask3M, timePtsV, basePts=None, imgSmoothDict

maskedSlcSeq3M = np.ma.masked_invalid(slcSeq3M[:, :, 0:basePts]) #Prevents RuntimeWarning: Mean of empty slice
baselineM = np.mean(maskedSlcSeq3M, axis=2).filled(np.nan)
baselineM[baselineM == 0] = np.finfo(float).eps
baselineM[baselineM == 0] = EPS
normScan4M[:, :, slc, :] = scanArr4M[:, :, slc, :] / baselineM[:, :, np.newaxis]

timePtsV = timePtsV - timePtsV[basePts]
Expand Down Expand Up @@ -226,15 +230,91 @@ def smoothResample(sigM, timeV, temporalSmoothDict=None, resampFlag=False):
# Un-pad
tSkip = round(nPad * tdiff / ts)
resampSigM = resampPadSigM[:, tSkip:-tSkip]
timeOutV = np.linspace(0, (resampSigM.shape[1] - 1) * ts, num=resampSigM.shape[1] - 1)
timeOutV = np.linspace(0, (resampSigM.shape[1] - 1) * ts, num=resampSigM.shape[1])

return resampSigM, timeOutV


def computeFeatures(procSlcSigM, procTimeV):

#TBD
return 0
relEnhancementM = procSlcSigM - 1 # S(t)/S(0) - 1
nVox = relEnhancementM.shape[0]

# Peak enhancement
PEv = np.max(relEnhancementM, axis=1)
peakIdxV = np.argmax(relEnhancementM, axis=1)
TTPv = procTimeV[peakIdxV] #Time-to-peak

# Half-peak
halfMaxSig = (np.max(procSlcSigM, axis=1) - 1) / 2
SHPcolIdx = np.argmax(relEnhancementM >= halfMaxSig[:, np.newaxis], axis=1)
SHPv = procSlcSigM[np.arange(len(procSlcSigM)), SHPcolIdx] #Signal at half-peak
TTHPv = procTimeV[SHPcolIdx] #Time to half-peak

# Wash-in / wash-out slopes
WISv = PEv / (TTPv + EPS) # Wash in slope, WIS = PE / TTP
Tend = procTimeV[-1]
RSEendV = relEnhancementM[:, -1]
peakAtEndIdx = TTPv == Tend
WOSv = (PEv - RSEendV)/ (TTPv - Tend) # WOS = (PE - RSE(Tend)) / (TTP – Tend), if PE does not occur at Tend
WOSv[peakAtEndIdx] = 0

# Wash-in/out gradients
## Initial gradient estimated by linear regression of RSE between 20 % and 80 % PE
id_20v = np.argmax(relEnhancementM >= .2 * PEv[:, np.newaxis], axis=1)
id_80v = np.argmax(relEnhancementM > .8 * PEv[:, np.newaxis], axis=1)
id_80v = id_80v - 1
IGv = np.full((nVox, ), fill_value=np.nan)
igIdxV = np.full(relEnhancementM.shape, fill_value=False)
for i in range(nVox):
idxV = np.arange(id_20v[i], id_80v[i]+1)
y = relEnhancementM[i, idxV].T
x = np.hstack((np.ones((len(idxV), 1)), procTimeV[idxV].T[:,np.newaxis]))
b, __, __, __ = np.linalg.lstsq(x, y, rcond=None)
IGv[i] = b[1]
igIdxV[i, idxV] = True

## Wash-out gradient estimated by linear regression of RSE between PE and 1 min post-PE
t0 = peakIdxV
t1IdxM = procTimeV[:,None] >= (procTimeV[t0] + 1)
skipRowV = ~np.any(t1IdxM, axis=0)
t1 = np.argmax(t1IdxM, axis=0)
WOGv = np.full((nVox, ), fill_value=np.nan)
for i in range(nVox):
if ~skipRowV[i]:
x = np.hstack((np.ones((t1[i] - t0[i] + 1, 1)), procTimeV[t0[i]:t1[i]+1][:, np.newaxis]))
y = relEnhancementM[i, t0[i]: t1[i]+1].T
b, __, __, __ = np.linalg.lstsq(x, y, rcond=None)
WOGv[i] = b[1]
else:
WOGv[i] = np.nan

#Signal enhancement ratio
tse1 = np.argmax(procTimeV >= .5)
tse2 = np.argmax(procTimeV >= 2.5)
SERv = relEnhancementM[:, tse1] / relEnhancementM[:, tse2]

#IAUC
IAUCv = cumtrapz(y=relEnhancementM.T, x=procTimeV.T, axis=0, initial=0).T
IAUCtthpV = np.full((nVox,), fill_value=np.nan)
IAUCttpV = np.full((nVox,), fill_value=np.nan)
for i in range(nVox):
IAUCtthpV[i] = IAUCv[i, np.argmax(procTimeV >= TTHPv[i])]
IAUCttpV[i] = IAUCv[i, np.argmax(procTimeV >= TTPv[i])]

featureDict = {'PeakEnhancement': PEv,
'SignalAtHalfPeak': SHPv,
'TimeToPeak': TTPv,
'TimeToHalfPeak': TTHPv,
'SignalEnhancementRatio': SERv,
'WashInSlope': WISv,
'WashOutSlope': WOSv,
'InitialGradient': IGv,
'WashOutGradient': WOGv,
'AUCatPeak': IAUCttpV,
'AUCatHalfPeak': IAUCtthpV}

return featureDict


def calcSemiQuantFeatures(planC, structNum, basePts=None, temporalSmoothDict=None,
Expand All @@ -261,10 +341,40 @@ def calcSemiQuantFeatures(planC, structNum, basePts=None, temporalSmoothDict=Non
normROISlcSigM = normSlcSigM[~skipIdxV, :]
## Smoothing + resampling
procSlcSigM, procTimeV = smoothResample(normROISlcSigM, selTimePtsV,
temporalSmoothDict=temporalSmoothDict, resampFlag=resampFlag)
#
# # Compute features
# featureDict = computeFeatures(procSlcSigM, procTimeV)
# featureList.append(featureDict)
temporalSmoothDict=temporalSmoothDict, resampFlag=resampFlag)

# Compute features
featureDict = computeFeatures(procSlcSigM, procTimeV)
featureList.append(featureDict)

return featureList

def createFeatureMaps(featureList, strNum, planC, importFlag=False):

# Get mask, associated scan and grid
mask3M = getStrMask(strNum, planC)
validSlcV = np.sum(np.sum(mask3M, axis=0), axis=0) > 0
mask3M = mask3M[:, :, validSlcV]

if importFlag:
assocScan = planC.structure[strNum].getStructureAssociatedScan(planC)
xV, yV, zV = planC.scan[assocScan].getScanXYZVals()
zV = zV[validSlcV]

# Extract list of available features
feats = featureList[0].keys()
numFeats = len(feats)
numRow, numCol, numSlc = mask3M.shape

mapDict = {f"{key}": np.zeros_like(mask3M, dtype=float) for key in feats}

# Create 3D maps
for key in feats:
for s in range(numSlc):
maskSlcM = mask3M[:, :, s]
mapDict[key][...,s][maskSlcM] = featureList[s][key]
# Import as pseudo-scan array
if importFlag:
planC = pc.importScanArray(mapDict[key], xV, yV, zV, key, assocScan, planC)

return featureList
return mapDict, planC
Loading