Showing 27 changed files with 4,023 additions and 3 deletions.
@@ -0,0 +1,37 @@
# SAMed_h

## Prerequisites
- Linux (we tested our code on Ubuntu 18.04)
- Anaconda
- Python 3.10.11
- PyTorch 2.0.0 **(PyTorch 2+ is necessary)**

To get started, first clone the repo:
```
git clone https://github.com/hitachinsk/SAMed.git
cd SAMed_h
```
Then run the following commands:
```
conda create -n SAMed_h python=3.10.11
conda activate SAMed_h
pip install -r requirements.txt
```

## Quick start
All the steps are the same as in [SAMed](https://github.com/hitachinsk/SAMed), but you need to prepare the [vit_h version of SAM](https://github.com/facebookresearch/segment-anything#model-checkpoints) and [our pretrained checkpoint](https://drive.google.com/file/d/1Kx_vx9bcxJaiMYWAgljNtwtHcooUsq8m/view?usp=sharing).
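For reference, loading these checkpoints for inference might look like the sketch below. This is an assumption based on the `vit_b` preprocessing script included in this commit (`sam_model_registry`, `LoRA_Sam`, `load_lora_parameters`), with the `vit_h` backbone swapped in; the paths and `num_classes` are placeholders.

```python
# Hedged sketch: loading the vit_h SAM backbone plus the pretrained LoRA checkpoint,
# mirroring the vit_b script in this commit. Paths and num_classes are placeholders.
from segment_anything import sam_model_registry
from sam_lora_image_encoder import LoRA_Sam

sam, img_embedding_size = sam_model_registry['vit_h'](image_size=224,
                                                      num_classes=1,
                                                      checkpoint='./checkpoints/sam_vit_h_4b8939.pth',
                                                      pixel_mean=[0, 0, 0],
                                                      pixel_std=[1, 1, 1])
net = LoRA_Sam(sam, 4).cuda()                      # rank-4 LoRA, as in the vit_b script
net.load_lora_parameters('<path to our pretrained SAMed_h checkpoint>')
net.eval()
```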
## Training
We use one A100 (80 GB) GPU for training.
1. Download the processed [training set](https://drive.google.com/file/d/1zuOQRyfo0QYgjcU_uZs0X3LdCnAC2m3G/view?usp=share_link), whose resolution is `224x224`, and put it in `<Your folder>`. Then unzip and delete the archive. We also provide a [training set](https://drive.google.com/file/d/1F42WMa80UpH98Pw95oAzYDmxAAO2ApYg/view?usp=share_link) at resolution `512x512` for reference; the `224x224` training set is downsampled from the `512x512` version.
2. Run this command to train SAMed_h.
```bash
python train.py --root_path <Your folder> --output <Your output path> --warmup --AdamW --tf32 --compile --use_amp --lr_exp 7 --max_epochs 400 --stop_epoch 300
```
Check the results in `<Your output path>`; training consumes about 70 GB of GPU memory.

## Difference between SAMed_h and SAMed
- SAMed_h adopts the `vit_h` version of SAM as the base model.
- SAMed_h needs more training iterations, so we set the maximum epoch to 400 and the early-stop epoch to 300 for better performance.
- A learning rate that is too large destabilizes SAMed_h training, so we increase the exponent of the exponential learning-rate decay from 0.9 to 7, which greatly reduces the instability.
- For faster training and lower memory consumption, SAMed_h adopts automatic mixed precision, TensorFloat-32, and `torch.compile` from PyTorch 2.0, so PyTorch 2+ is required to train this model (see the sketch below).
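As a rough illustration of the last point, the sketch below shows how these PyTorch 2 features are typically enabled. It is not the repo's `train.py`: the tiny model, synthetic data, and base learning rate are placeholders, and the decay formula is an assumption about what `--lr_exp` controls.

```python
# Rough sketch (not the repo's train.py; requires a CUDA GPU): how the --tf32,
# --compile, --use_amp and --lr_exp 7 options map onto PyTorch 2 features.
# The tiny model, synthetic data, and base_lr below are placeholders.
import torch

torch.backends.cuda.matmul.allow_tf32 = True   # TensorFloat-32 matmuls (--tf32)
torch.backends.cudnn.allow_tf32 = True         # TensorFloat-32 cuDNN kernels

model = torch.nn.Linear(16, 2).cuda()          # stand-in for the LoRA-wrapped ViT-H SAM
model = torch.compile(model)                   # PyTorch 2 compilation (--compile)

base_lr, lr_exp, max_iterations = 0.005, 7, 100
optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr)
scaler = torch.cuda.amp.GradScaler()           # automatic mixed precision (--use_amp)

for it in range(max_iterations):
    x = torch.randn(8, 16, device="cuda")
    y = torch.randint(0, 2, (8,), device="cuda")
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = torch.nn.functional.cross_entropy(model(x), y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
    # Decay with a large exponent (--lr_exp 7): the learning rate falls off much
    # faster than with the usual 0.9, which is what stabilizes SAMed_h training.
    lr = base_lr * (1.0 - (it + 1) / max_iterations) ** lr_exp
    for group in optimizer.param_groups:
        group["lr"] = lr
```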
@@ -0,0 +1,119 @@
import copy
import gzip
import os
import pickle
import sys

from PIL import Image
from tqdm import tqdm
import logging
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn

from scipy.ndimage import zoom

from pathlib import Path
import cv2
from scipy.ndimage import label

from segment_anything import sam_model_registry
from sam_lora_image_encoder import LoRA_Sam


def generate_cropped_image():
    def do_crop(img, prd, crop_size):
        h, w = img.shape[:2]
        masked_img = img.copy()
        if np.max(prd) == 0:
            # No foreground predicted: fall back to a center crop.
            center_row = h // 2
            center_col = w // 2
            # Compute the start and end positions of the crop window.
            min_row = max(0, center_row - crop_size[0] // 2)
            min_col = max(0, center_col - crop_size[1] // 2)
            max_row = min(h, center_row + crop_size[0] // 2)
            max_col = min(w, center_col + crop_size[1] // 2)

        else:
            masked_img[prd != 255] = 0

            # Bounding box of the predicted foreground.
            rows, cols = np.where(prd == 255)
            min_row, max_row, min_col, max_col = min(rows), max(rows), min(cols), max(cols)
            rect_width = max_col - min_col + 1
            rect_height = max_row - min_row + 1

            if rect_width < crop_size[0] or rect_height < crop_size[1]:
                # Expand the bounding box to the requested crop boundaries.
                crop_min_row = max(0, min_row - max(0, (crop_size[0] - rect_height) // 2))
                crop_max_row = min(prd.shape[0], crop_min_row + max(crop_size[0], rect_height))

                crop_min_col = max(0, min_col - max(0, (crop_size[1] - rect_width) // 2))
                crop_max_col = min(prd.shape[1], crop_min_col + max(crop_size[1], rect_width))
                min_row, max_row, min_col, max_col = crop_min_row, crop_max_row, crop_min_col, crop_max_col

        # Crop the corresponding region from the original image.
        cropped_img = Image.fromarray(masked_img[min_row:max_row, min_col:max_col])

        return cropped_img

    os.environ["CUDA_VISIBLE_DEVICES"] = "1"
    root = r"../datasets/dataset1/global"
    classes = ['benign', 'tumor', 'normal']
    phases = ['test']
    source = root
    target = root.replace("global", "local_seg")
    input_size = 224
    crop_size = (256, 256)
    cudnn.benchmark = False
    cudnn.deterministic = True
    seed = 1234
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    rank = 4
    lora_ckpt = r"./exp4/epoch_199.pth"
    ckpt = r"./checkpoints/sam_vit_b_01ec64.pth"
    sam, img_embedding_size = sam_model_registry['vit_b'](image_size=input_size,
                                                          num_classes=1,
                                                          checkpoint=ckpt,
                                                          pixel_mean=[0, 0, 0],
                                                          pixel_std=[1, 1, 1])

    net = LoRA_Sam(sam, rank).cuda()
    net.load_lora_parameters(lora_ckpt)

    net.eval()
    for phase in phases:
        for cls in classes:
            imgs = os.listdir(os.path.join(source, phase, cls))
            for img in tqdm(imgs):
                torch.cuda.empty_cache()
                img_path = os.path.join(source, phase, cls, img)
                image = cv2.imread(img_path)
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                origin_image = copy.deepcopy(image)
                x, y = image.shape[0:2]
                if x != input_size or y != input_size:
                    image = zoom(image, (input_size / x, input_size / y, 1.0), order=3)
                inputs = torch.from_numpy(image.astype(np.float32) / 255.0)
                inputs = inputs.permute(2, 0, 1)
                inputs = inputs.unsqueeze(0).cuda()
                with torch.no_grad():
                    outputs = net(inputs, False, input_size)
                    output_masks = outputs['masks']
                    out = torch.argmax(torch.softmax(output_masks, dim=1), dim=1).squeeze(0)
                    prediction = out.cpu().detach().numpy()
                    if x != input_size or y != input_size:
                        prediction = zoom(prediction, (x / input_size, y / input_size), order=0)
                cropped_image = do_crop(img=origin_image.astype(np.uint8),
                                        prd=(prediction * 255).astype(np.uint8),
                                        crop_size=crop_size)
                output_path = os.path.join(target, phase, cls, img)
                if not os.path.exists(os.path.join(target, phase, cls)):
                    os.makedirs(os.path.join(target, phase, cls))
                cropped_image.save(output_path)


if __name__ == "__main__":
    generate_cropped_image()
Empty file.
@@ -0,0 +1,137 @@
import os
import random
import numpy as np
import torch
import torchvision.datasets
from scipy import ndimage
from scipy.ndimage import zoom
from torch.utils.data import Dataset
from torchvision import transforms
from pycocotools import mask as coco_mask

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


def random_rot_flip(image, label):
    # Random rotation by a multiple of 90 degrees followed by a random flip.
    k = np.random.randint(0, 4)
    image = np.rot90(image, k)
    label = np.rot90(label, k)
    axis = np.random.randint(0, 2)
    image = np.flip(image, axis=axis).copy()
    label = np.flip(label, axis=axis).copy()
    return image, label


def random_rotate(image, label):
    # Random rotation by a small angle in [-20, 20) degrees.
    angle = np.random.randint(-20, 20)
    image = ndimage.rotate(image, angle, order=0, reshape=False)
    label = ndimage.rotate(label, angle, order=0, reshape=False)
    return image, label


class RandomGenerator(object):
    def __init__(self, output_size, low_res, phase):
        self.output_size = output_size
        self.low_res = low_res
        self.phase = phase

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        if self.phase == "train":
            if random.random() > 0.5:
                image, label = random_rot_flip(image, label)
            elif random.random() > 0.5:
                image, label = random_rotate(image, label)
        x, y = image.shape[0:2]
        if x != self.output_size[0] or y != self.output_size[1]:
            image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y, 1.0), order=3)
            label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0)
        label_h, label_w = label.shape
        low_res_label = zoom(label, (self.low_res[0] / label_h, self.low_res[1] / label_w), order=0)
        image = torch.from_numpy(image.astype(np.float32) / 255.0)
        # image = (image - torch.FloatTensor(mean)) / torch.FloatTensor(std)
        image = image.permute(2, 0, 1)
        label = torch.from_numpy(label.astype(np.float32))
        low_res_label = torch.from_numpy(low_res_label.astype(np.float32))
        sample = {'image': image, 'label': label.long(), 'low_res_label': low_res_label.long(), 'case_name': sample['case_name']}
        return sample


def convert_coco_poly_to_mask(segmentations, height, width):
    # Decode COCO polygon annotations and merge them into a single binary mask.
    masks = []
    for polygons in segmentations:
        rles = coco_mask.frPyObjects(polygons, height, width)
        mask = coco_mask.decode(rles)
        if len(mask.shape) < 3:
            mask = mask[..., None]
        # mask = torch.as_tensor(mask, dtype=torch.uint8)
        mask = np.any(mask, axis=2)
        masks.append(mask)

    merged_mask = np.zeros((height, width), dtype=np.uint8)
    if masks:
        for mask in masks:
            merged_mask = merged_mask | mask

    return merged_mask


class COCO_dataset(torchvision.datasets.CocoDetection):
    def __init__(self, img_folder, ann_file, split=None, transform=None):
        super(COCO_dataset, self).__init__(img_folder, ann_file)
        self.split = split
        self.transform = transform

    def __len__(self):
        return super(COCO_dataset, self).__len__()

    def __getitem__(self, idx):
        img, target = super(COCO_dataset, self).__getitem__(idx)

        # get filename
        image_info = self.coco.loadImgs(self.ids[idx])[0]
        filename = image_info['file_name']

        # generate masks
        w, h = img.size
        segmentations = [obj['segmentation'] for obj in target]
        masks = convert_coco_poly_to_mask(segmentations, h, w)

        label_value = target[0]['category_id'] + 1
        masks[masks == 1] = label_value

        img = np.array(img)

        sample = {'image': img, 'label': masks}

        if self.transform:
            sample = self.transform(sample)

        sample['case_name'] = os.path.splitext(filename)[0]

        return sample


class Cancer_dataset(Dataset):
    def __init__(self, data_dir, txt_dir, transform=None):
        # train or val or test
        phase = os.path.splitext(os.path.basename(txt_dir))[0]
        file_path = os.path.join(data_dir, phase)

        self.data = [os.path.join(file_path, file) for file in os.listdir(file_path)]
        self.transform = transform  # using transform in torch!
        self.sample_list = open(txt_dir).readlines()

    def __len__(self):
        return len(self.sample_list)

    def __getitem__(self, idx):
        data_path = self.data[idx]
        data_dic = np.load(data_path)
        image, label = data_dic['image'], data_dic['label']
        name = os.path.splitext(os.path.basename(data_path))[0]
        sample = {'image': image, 'label': label, 'case_name': name}
        if self.transform:
            sample = self.transform(sample)

        return sample
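# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not code from this commit): how Cancer_dataset
# and RandomGenerator above are typically wired into a DataLoader. The data
# directory layout, list file, batch size, and low-res label size below are
# placeholders.
from torch.utils.data import DataLoader

if __name__ == "__main__":
    db_train = Cancer_dataset(
        data_dir='<Your folder>',       # must contain a 'train' subfolder of .npz files
        txt_dir='./lists/train.txt',    # one line per training case (placeholder path)
        transform=transforms.Compose([
            RandomGenerator(output_size=[224, 224], low_res=[56, 56], phase='train')]))
    trainloader = DataLoader(db_train, batch_size=12, shuffle=True, num_workers=4)
    for batch in trainloader:
        print(batch['image'].shape,           # (B, 3, 224, 224), float in [0, 1]
              batch['label'].shape,           # (B, 224, 224), long
              batch['low_res_label'].shape)   # (B, 56, 56), long
        break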