
A Request for Collaboration and Code Optimization #3

Open
bstartek opened this issue Apr 7, 2024 · 0 comments

bstartek commented Apr 7, 2024

Good morning, colleagues. Could you please share documentation on how to install the plugin in Slicer 5.6 and later? It cannot be installed through the "install from file" method.

I would also like help adapting the code for Blender, which I use for scientific purposes: segmenting bones to plan orthognathic surgery. My idea is to combine your project with another one to segment all of the teeth. First, I use MONAI with an AI model to place cephalometric landmarks, including points on the tooth crowns. Next, I want to run your model in a loop: for each tooth, the ROI is determined from its crown landmark (expanded by 50), so every tooth is segmented in turn and tagged by name.

I am asking for help optimizing the code so that it no longer depends on Slicer and uses only SimpleITK, VTK, and MONAI. I have written the part of the code below, but the segmentation results are incorrect. For simplicity, I am using a constant ROI for now; it will be assigned dynamically in the future.
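A minimal sketch of the per-tooth ROI step described above, in plain Python. It assumes the crown landmarks are physical points (mm) in the CT's ITK/LPS space and reads "expanded by 50" as a physical margin on every side (the unit is not stated, so mm is an assumption). `physical_to_index` mirrors what SimpleITK's `Image.TransformPhysicalPointToIndex` does for an orthonormal direction matrix; in a real script you would call that method instead. `crown_rois`, the tooth names, and all numbers are hypothetical. Each box uses the `(x, y, z, width, height, depth)` voxel format that `crop_image` takes, so the loop would crop, segment, and store one result per name.

```python
import math


def physical_to_index(point, origin, spacing, direction):
    """Mirror of SimpleITK's TransformPhysicalPointToIndex for an orthonormal,
    row-major 3x3 direction matrix: index = D^T (p - origin) / spacing."""
    diff = [p - o for p, o in zip(point, origin)]
    index = []
    for axis in range(3):
        # Column `axis` of D is the physical direction of that image axis
        proj = sum(direction[row * 3 + axis] * diff[row] for row in range(3))
        index.append(math.floor(proj / spacing[axis] + 0.5))  # round half up, like ITK
    return index


def crown_rois(landmarks, margin, origin, spacing, direction, size):
    """Return {name: (x, y, z, width, height, depth)} voxel boxes, one per
    landmark, expanded by `margin` physical units on every side and clamped
    to the image. Landmarks outside the volume are skipped."""
    boxes = {}
    half = [math.ceil(margin / s) for s in spacing]
    for name, point in landmarks.items():
        center = physical_to_index(point, origin, spacing, direction)
        start = [max(0, c - h) for c, h in zip(center, half)]
        stop = [min(n, c + h + 1) for c, h, n in zip(center, half, size)]
        if any(b <= a for a, b in zip(start, stop)):
            continue  # empty box: the landmark is outside the volume
        boxes[name] = (*start, *(b - a for a, b in zip(start, stop)))
    return boxes


# Example with made-up values: 0.5 x 0.5 x 1.0 mm voxels, identity direction
landmarks = {
    "tooth_11": (10.0, 20.0, 30.0),
    "tooth_21": (-10.0, 20.0, 30.0),
    "tooth_48": (500.0, 0.0, 0.0),  # outside the volume, so it is skipped
}
boxes = crown_rois(landmarks, margin=5.0, origin=(-50.0, -50.0, 0.0),
                   spacing=(0.5, 0.5, 1.0), direction=(1, 0, 0, 0, 1, 0, 0, 0, 1),
                   size=(200, 200, 100))
for name, roi in boxes.items():
    print(name, roi)
```

Clamping to the image keeps `RegionOfInterestImageFilter` from failing on teeth near the border of the scan; the trade-off is that those crops are smaller than the others.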

```python
import os
import subprocess
import sys
import time
from importlib.metadata import PackageNotFoundError, version

import numpy as np
import SimpleITK as sitk
import torch
import vtk
from vtk.util.numpy_support import numpy_to_vtk


def is_installed(package, required_version):
    # Stand-in for the helper the original code assumed (not shown in the post):
    # True only if `package` is installed at exactly `required_version`.
    try:
        return version(package) == required_version
    except PackageNotFoundError:
        return False


def brain_tooth_AI(
        inputVolume,
        outputSegmentation,
        modelPath,
        sphere_center,
        sphere_radius):
    """
    Run the processing algorithm without Slicer or a GUI widget.
    :param inputVolume: path of the volume to be segmented (e.g. a .nii.gz file)
    :param outputSegmentation: directory where the STL surface is written
    :param modelPath: path of the trained U-Net weights
    :param sphere_center: ROI center, assumed to be a physical point (mm) in ITK/LPS space
    :param sphere_radius: ROI half-width, assumed to be in mm
    """

    if not inputVolume or not outputSegmentation:
        raise ValueError("Input or output volume is invalid")

    if not is_installed("monai", "1.3.0"):
        subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "monai", "-y"])
        subprocess.check_call([sys.executable, "-m", "pip", "install", "monai==1.3.0"])

    startTime = time.time()
    print('Processing started')

    ### ROI from Blender #########################################################################################

    def load_nii_gz_file(file_path):
        return sitk.ReadImage(file_path)

    def sitk_to_numpy(image):
        # The returned array is indexed (z, y, x): the reverse of image.GetSize()
        return sitk.GetArrayFromImage(image)

    def adjust_roi_for_simpleitk(input_image, sphere_center, sphere_radius):
        # Map the physical center to a voxel index using the image's origin, spacing
        # and direction (adding the image center in voxels mixes coordinate systems).
        # If the point is in RAS (Slicer convention), pass (-x, -y, z) to get LPS.
        center_index = input_image.TransformPhysicalPointToIndex(
            [float(c) for c in sphere_center])
        img_size = input_image.GetSize()
        half = [int(round(sphere_radius / s)) for s in input_image.GetSpacing()]
        start = [max(0, center_index[i] - half[i]) for i in range(3)]  # start x, y, z
        stop = [min(img_size[i], center_index[i] + half[i]) for i in range(3)]
        # (start x, start y, start z, width, height, depth), clamped to the image
        return (*start, *(stop[i] - start[i] for i in range(3)))

    def crop_image(input_image, roi):
        img_size = input_image.GetSize()
        print(f"Image size: {img_size}")

        x, y, z, width, height, depth = (int(v) for v in roi)
        inside = (min(x, y, z) >= 0 and min(width, height, depth) > 0
                  and x + width <= img_size[0] and y + height <= img_size[1]
                  and z + depth <= img_size[2])
        if not inside:
            raise ValueError(f"ROI {roi} does not fit inside the image {img_size}")

        roi_slice = sitk.RegionOfInterestImageFilter()
        roi_slice.SetSize([width, height, depth])
        roi_slice.SetIndex([x, y, z])
        # The crop keeps its physical origin, so it can be placed back in the CT later
        return roi_slice.Execute(input_image)

    input_image = load_nii_gz_file(inputVolume)

    #roi = adjust_roi_for_simpleitk(input_image, sphere_center, sphere_radius)
    roi = (180, 250, 150, 55, 55, 100)  # temporary: (x, y, z, width, height, depth) in voxels

    cropped_image = crop_image(input_image, roi)

    inputImageArray = sitk_to_numpy(cropped_image)
    inputCrop_shape = inputImageArray.shape  # (depth, height, width)

    print("ROI:", inputCrop_shape)

    ################################################################################################################

    # MONAI is imported only after the version check above
    from monai.inferers import SlidingWindowInferer
    from monai.transforms import (
        Compose,
        EnsureChannelFirst,
        SpatialPad,
        NormalizeIntensity
    )
    from monai.networks.nets import UNet
    from monai.networks.layers.factories import Act
    from monai.networks.layers import Norm

    print("CUDA count: " + str(torch.cuda.device_count()))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using ", device, " for compute")

    # Define U-Net model (must match the architecture the weights were trained with;
    # with 4 channels MONAI uses only the first 3 strides and warns about the 4th)
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        act=Act.RELU,
        norm=Norm.BATCH,
        dropout=0.2).to(device)

    # Load model weights
    inputModelPath = modelPath
    loaded_model = torch.load(inputModelPath, map_location=device)
    model.load_state_dict(loaded_model,
                          strict=True)  # strict=True fails loudly if any checkpoint key does not match
    model.eval()

    inputImageArray = torch.as_tensor(inputImageArray, dtype=torch.float32)

    # define pre-transforms
    pre_transforms = Compose([
        EnsureChannelFirst(channel_dim='no_channel'),  # (D, H, W) -> (1, D, H, W)
        NormalizeIntensity(),
        SpatialPad(spatial_size=[144, 144, 144], mode="reflect"),
        EnsureChannelFirst(channel_dim='no_channel')  # adds the batch axis: (1, 1, D, H, W)
    ])

    # run inference (no gradients needed)
    inputProcessed = pre_transforms(inputImageArray).to(device)
    inferer = SlidingWindowInferer(roi_size=[96, 96, 96])

    with torch.no_grad():
        output = inferer(inputProcessed, model)

    # process prediction output: class index per voxel (softmax would not change the argmax)
    output = torch.argmax(output, dim=1).squeeze().cpu().numpy().astype(np.uint8)

    # Crop the prediction back to the ROI size. SpatialPad pads symmetrically,
    # adding floor(padding / 2) voxels before each axis.
    lower = [max(0, (144 - dim) // 2) for dim in inputCrop_shape]
    output_reshaped = output[tuple(slice(lo, lo + dim) for lo, dim in zip(lower, inputCrop_shape))]

    # # Keep largest connected component
    # largest_comp_transform = KeepLargestConnectedComponent()
    # val_comp = largest_comp_transform(val_outputs)

    print("Inference done")

    # Convert the label map to a VTK image. The numpy array is (z, y, x) and a C-order
    # ravel lets x vary fastest, which is what VTK expects for dimensions (x, y, z).
    nz, ny, nx = output_reshaped.shape
    data_array = numpy_to_vtk(num_array=output_reshaped.ravel(), deep=True, array_type=vtk.VTK_UNSIGNED_CHAR)

    image_data = vtk.vtkImageData()
    image_data.SetDimensions(nx, ny, nz)
    # Put the mesh where the crop sits in the original image (physical mm, ITK/LPS).
    # Assumes an identity direction matrix; flip x and y if Blender expects RAS.
    image_data.SetSpacing(cropped_image.GetSpacing())
    image_data.SetOrigin(cropped_image.GetOrigin())
    image_data.GetPointData().SetScalars(data_array)

    contour_filter = vtk.vtkMarchingCubes()
    contour_filter.SetInputData(image_data)
    contour_filter.SetValue(0, 0.5)  # surface between background (0) and tooth (1)
    contour_filter.Update()

    stl_writer = vtk.vtkSTLWriter()
    stl_writer.SetFileName(os.path.join(outputSegmentation, "test.stl"))
    stl_writer.SetInputData(contour_filter.GetOutput())
    stl_writer.Write()

    stopTime = time.time()
    print(f'Processing completed in {stopTime - startTime:.2f} seconds')
```
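To take each cropped prediction back into the space of the original image and tag the teeth by name, the per-ROI masks can be pasted into one full-size label volume. The sketch below assumes each result is stored as `(name, roi, mask)`, with `roi` in the same `(x, y, z, width, height, depth)` voxel order used by `crop_image` and `mask` being the cropped `uint8` array in numpy's `(z, y, x)` order. `merge_tooth_masks` and the "first tooth wins" rule for overlapping crops are hypothetical design choices, not part of the original project.

```python
import numpy as np


def merge_tooth_masks(full_shape_zyx, predictions):
    """Paste per-tooth binary masks into one multi-label volume.

    predictions: list of (name, roi, mask), where roi is (x, y, z, width,
    height, depth) in voxels (SimpleITK order) and mask is the cropped uint8
    prediction indexed (z, y, x).
    Returns (labels, table): label 0 is background, table maps name -> label.
    """
    labels = np.zeros(full_shape_zyx, dtype=np.uint16)
    table = {}
    for label, (name, roi, mask) in enumerate(predictions, start=1):
        x, y, z, w, h, d = roi
        region = labels[z:z + d, y:y + h, x:x + w]  # a view into `labels`
        if mask.shape != region.shape:
            raise ValueError(f"{name}: mask {mask.shape} does not fit ROI {roi}")
        # First tooth wins: only write voxels that no earlier tooth has claimed
        region[(mask > 0) & (region == 0)] = label
        table[name] = label
    return labels, table


# Example: two overlapping boxes, each 4 x 3 x 2 voxels (x, y, z), in a 6 x 8 x 10 volume
full_shape = (10, 8, 6)  # numpy order (z, y, x)
mask = np.ones((2, 3, 4), dtype=np.uint8)
labels, table = merge_tooth_masks(full_shape, [
    ("tooth_11", (0, 0, 0, 4, 3, 2), mask),
    ("tooth_21", (2, 1, 1, 4, 3, 2), mask),
])
print(table, np.bincount(labels.ravel()))
```

With `full_shape_zyx` set to `sitk.GetArrayFromImage(input_image).shape`, the result can be turned back into an image with `sitk.GetImageFromArray(labels)`, given the CT's geometry via `CopyInformation(input_image)`, and saved with `sitk.WriteImage`, while `table` keeps the name of each label.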