Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to mask processing #79

Merged
merged 1 commit into from
Aug 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 43 additions & 10 deletions cerr/utils/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def getSurfacePoints(mask3M, sampleTrans=1, sampleAxis=1):
return r,c,s


def createStructuringElement(sizeCm, resolutionCmV, dimensions=3):
def createStructuringElement(sizeCm, resolutionCmV, dimensions=3, shape='flat'):
"""
Function to create structuring element for morphological operations given
desired dimensions in cm.
Expand All @@ -110,6 +110,8 @@ def createStructuringElement(sizeCm, resolutionCmV, dimensions=3):
sizeCm (np.float): Size of structuring element in cm.
resolutionCmV (np.array): Image resolution in cm [dx, dy, dz].
dimensions (int): [optional, default=3] Specify 3 for 3D or 2 for 2D.
shape (string): [optional, default='flat'] Geometric neighborhood shape. Supported
values: 'flat', 'sphere', 'disk'.

Returns:
structuringElement (np.ndarray): Structuring element.
Expand All @@ -120,7 +122,25 @@ def createStructuringElement(sizeCm, resolutionCmV, dimensions=3):
evenIdxV = sizePixels % 2 == 0
if any(evenIdxV):
sizePixels[evenIdxV] += 1 # Ensure odd size for symmetric structuring element
structuringElement = np.ones(tuple(sizePixels.astype(int)), dtype=np.uint8)

if shape == 'flat':
structuringElement = np.ones(tuple(sizePixels.astype(int)), dtype=np.uint8)
elif shape == 'sphere':
x, y, z = np.meshgrid(np.arange(-sizePixels[0], sizePixels[0] + 1),
np.arange(-sizePixels[1], sizePixels[1] + 1),
np.arange(-sizePixels[2], sizePixels[2] + 1))
structuringElement = ((x / sizePixels[0]) ** 2 +
(y / sizePixels[1]) ** 2 +
(z / sizePixels[2]) ** 2) <= 1
elif shape == 'disk':
x, y = np.meshgrid(np.arange(-sizePixels[0], sizePixels[0] + 1),
np.arange(-sizePixels[1], sizePixels[1] + 1))

structuringElement = ((x / sizePixels[0]) ** 2 +
(y / sizePixels[1]) ** 2) <= sizePixels[0]**2

else:
raise ValueError('Structuring element type %s is not supported.' %(shape))

return structuringElement

Expand Down Expand Up @@ -154,28 +174,34 @@ def morphologicalClosing(binaryMask, structuringElement):
return closedMask


def gaussianBlurring(binaryMask, sigmaVox):
def blurring(binaryMask, sigmaVox, filtType='gaussian'):
"""
Function for Gaussian blurring of input binary mask

Args:
binaryMask (numpy.array): Binary mask to blur.
sigmaVox (float): Sigma for Gaussian in units of voxels.
filtType (string): [optional, default:'gaussian'] 'gaussian' or 'box' smoothing filter.

Returns:
numpy.ndarray(dtype=bool): Blurred mask using Gaussian blur with input sigma.
"""

gaussian = sitk.SmoothingRecursiveGaussianImageFilter()
gaussian.SetSigma(sigmaVox)
if filtType == 'gaussian':
filter = sitk.SmoothingRecursiveGaussianImageFilter()
filter.SetSigma(sigmaVox)
elif filtType == 'box':
filter = sitk.BoxMeanImageFilter()
filter.SetRadius(sigmaVox)

dim = binaryMask.shape
blurredMask3M = np.empty_like(binaryMask, dtype=float)
for slc in range(dim[2]):
if not np.any(binaryMask[:,:,slc]):
blurredMask3M[:,:,slc] = binaryMask[:,:,slc]
continue
img = sitk.GetImageFromArray(binaryMask[:,:,slc].astype(float))
blurImage = gaussian.Execute(img)
blurImage = filter.Execute(img)
blurredMask3M[:,:,slc] = sitk.GetArrayFromImage(blurImage)
return blurredMask3M

Expand Down Expand Up @@ -282,27 +308,34 @@ def closeMask(mask3M, inputResV, structuringElementSizeCm):
return filledMask3M


def largestConnComps(mask3M, numConnComponents):
def largestConnComps(mask3M, numConnComponents, minSize=0, dim=3):
"""
Function to retain 'N' largest connected components in input binary mask

Args:
mask3M (np.ndarray(dtype=bool)): 3D binary segmentation mask
(OR) 3D binary mask.
numConnComponents (int): number of largest components to retain.
minSize (int): [optional, default=0] Min. size of connected component to retain.
dim (int): [optional, default=3. Includes 26 neighbours in 3D ] 2 (2D) or 3 (3D).

Returns:
maskOut3M (np.ndarray(dtype=bool)): 3D mask with labels corresponding to components.

"""

if dim == 2:
structure = np.ones((3, 3))
elif dim == 3:
structure = np.ones((3, 3, 3))

if np.sum(mask3M) > 1:
#Extract connected components
labeledArray, numFeatures = label(mask3M, structure=np.ones((3, 3, 3)))
labeledArray, numFeatures = label(mask3M, structure)

# Sort by size
ccSiz = [len(labeledArray[labeledArray == i]) for i in range(1, numFeatures + 1)]
ccSiz = np.array([len(labeledArray[labeledArray == i]) for i in range(1, numFeatures + 1)])
# Filter min acceptable
ccSiz[ccSiz < minSize] = 0
rankV = np.argsort(ccSiz)[::-1]
if len(rankV) > numConnComponents:
selV = rankV[:numConnComponents]
Expand Down
Loading