Commit

add SAM
VVJia committed Aug 10, 2024
1 parent 584f0ba commit fd5db85
Showing 27 changed files with 4,023 additions and 3 deletions.
37 changes: 37 additions & 0 deletions SAM/README.md
@@ -0,0 +1,37 @@
# SAMed_h

## Prerequisites
- Linux (we tested our code on Ubuntu 18.04)
- Anaconda
- Python 3.10.11
- PyTorch 2.0.0 **(PyTorch 2+ is required)**

To get started, first clone the repo:
```
git clone https://github.com/hitachinsk/SAMed.git
cd SAMed_h
```
Then, please run the following commands:
```
conda create -n SAMed_h python=3.10.11
conda activate SAMed_h
pip install -r requirements.txt
```

## Quick start
All steps are the same as in [SAMed](https://github.com/hitachinsk/SAMed), but you also need to prepare the [vit_h version of SAM](https://github.com/facebookresearch/segment-anything#model-checkpoints) and [our pretrained checkpoint](https://drive.google.com/file/d/1Kx_vx9bcxJaiMYWAgljNtwtHcooUsq8m/view?usp=sharing).
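
For reference, the following is a minimal sketch of loading the `vit_h` backbone together with a LoRA checkpoint, mirroring the `vit_b` loading code in `SAM/crop_image.py` from this commit; the checkpoint paths, `num_classes`, and LoRA rank are illustrative assumptions rather than values prescribed by the repo.
```python
# Minimal sketch (assumed paths and arguments), mirroring the vit_b code in crop_image.py.
from segment_anything import sam_model_registry
from sam_lora_image_encoder import LoRA_Sam

sam, img_embedding_size = sam_model_registry['vit_h'](
    image_size=224,                                   # matches the 224x224 training data
    num_classes=1,                                    # illustrative; set to your class count
    checkpoint='./checkpoints/sam_vit_h_4b8939.pth',  # vit_h weights from the SAM release
    pixel_mean=[0, 0, 0],
    pixel_std=[1, 1, 1])
net = LoRA_Sam(sam, 4).cuda()                               # LoRA rank 4, as in crop_image.py
net.load_lora_parameters('./checkpoints/samed_h_lora.pth')  # assumed name for the pretrained checkpoint
net.eval()
```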

## Training
We use one A100 (80 GB) GPU for training.
1. Please download the processed [training set](https://drive.google.com/file/d/1zuOQRyfo0QYgjcU_uZs0X3LdCnAC2m3G/view?usp=share_link), whose resolution is `224x224`, and put it in `<Your folder>`; then unzip it and delete the archive. We also provide the [training set](https://drive.google.com/file/d/1F42WMa80UpH98Pw95oAzYDmxAAO2ApYg/view?usp=share_link) at `512x512` resolution for reference; the `224x224` training set is downsampled from the `512x512` version.
2. Run this command to train SAMed_h:
```bash
python train.py --root_path <Your folder> --output <Your output path> --warmup --AdamW --tf32 --compile --use_amp --lr_exp 7 --max_epochs 400 --stop_epoch 300
```
Check the results in `<Your output path>`. The training process consumes about 70 GB of GPU memory.

## Difference between SAMed_h and SAMed
- SAMed_h adopts the `vit_h` version of SAM as the base model.
- SAMed_h needs more training iterations, so we set the maximum number of epochs to 400 and the early-stop epoch to 300 for better performance.
- An overly large learning rate destabilizes SAMed_h training, so we increase the exponent of the learning-rate decay from 0.9 to 7, which greatly reduces the instability.
- For faster training and lower memory consumption, SAMed_h adopts automatic mixed precision, TensorFloat-32, and `torch.compile` from PyTorch 2.0, so PyTorch 2+ is required to train this model (see the sketch after this list).
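
As a rough, self-contained illustration of how these pieces fit together, the sketch below enables TF32, compiles a toy model with `torch.compile`, trains it under autocast with a `GradScaler`, and decays the learning rate with exponent 7; the toy model, loss, `base_lr`, and iteration count are placeholders, not the actual `train.py` code.
```python
import torch
import torch.nn as nn

# Placeholders throughout: a toy linear model stands in for the LoRA-wrapped vit_h SAM.
torch.backends.cuda.matmul.allow_tf32 = True   # TensorFloat-32 matmuls (--tf32)
torch.backends.cudnn.allow_tf32 = True         # TensorFloat-32 convolutions

model = torch.compile(nn.Linear(16, 2).cuda()) # PyTorch 2.0 compilation (--compile)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.005)
scaler = torch.cuda.amp.GradScaler()           # automatic mixed precision (--use_amp)

base_lr, max_iterations = 0.005, 1000          # illustrative values
for iter_num in range(max_iterations):
    x = torch.randn(8, 16).cuda()
    y = torch.randint(0, 2, (8,)).cuda()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():            # forward pass in mixed precision
        loss = nn.functional.cross_entropy(model(x), y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    # Decay the learning rate with exponent 7 (raised from SAMed's 0.9) for stability.
    lr = base_lr * (1.0 - iter_num / max_iterations) ** 7
    for group in optimizer.param_groups:
        group['lr'] = lr
```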
119 changes: 119 additions & 0 deletions SAM/crop_image.py
@@ -0,0 +1,119 @@
import copy
import gzip
import os
import pickle
import sys

from PIL import Image
from tqdm import tqdm
import logging
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn

from scipy.ndimage import zoom

from pathlib import Path
import cv2
from scipy.ndimage import label

from segment_anything import sam_model_registry
from sam_lora_image_encoder import LoRA_Sam

def generate_cropped_image():
def do_crop(img, prd, crop_size):
h, w = img.shape[:2]
masked_img = img.copy()
if np.max(prd) == 0:
            # compute the center position
center_row = h // 2
center_col = w // 2
            # compute the crop start and end positions
min_row = max(0, center_row - crop_size[0] // 2)
min_col = max(0, center_col - crop_size[1] // 2)
max_row = min(h, center_row + crop_size[0] // 2)
max_col = min(w, center_col + crop_size[1] // 2)

else:
masked_img[prd != 255] = 0

rows, cols = np.where(prd == 255)
min_row, max_row, min_col, max_col = min(rows), max(rows), min(cols), max(cols)
rect_width = max_col - min_col + 1
rect_height = max_row - min_row + 1

if rect_width < crop_size[0] or rect_height < crop_size[1]:
                # compute the boundaries of the crop region
crop_min_row = max(0, min_row - max(0, (crop_size[0] - rect_height) // 2))
crop_max_row = min(prd.shape[0], crop_min_row + max(crop_size[0], rect_height))

crop_min_col = max(0, min_col - max(0, (crop_size[1] - rect_width) // 2))
crop_max_col = min(prd.shape[1], crop_min_col + max(crop_size[1], rect_width))
min_row, max_row, min_col, max_col = crop_min_row, crop_max_row, crop_min_col, crop_max_col

# Crop the corresponding region from the original image
cropped_img = Image.fromarray(masked_img[min_row:max_row, min_col:max_col])

return cropped_img

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
root = r"../datasets/dataset1/global"
classes = ['benign', 'tumor', 'normal']
phases = ['test']
source = root
target = root.replace("global", "local_seg")
input_size = 224
crop_size = (256, 256)
cudnn.benchmark = False
cudnn.deterministic = True
seed = 1234
    random.seed(seed)  # seed Python's RNG with the same seed
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
rank = 4
lora_ckpt = r"./exp4/epoch_199.pth"
ckpt = r"./checkpoints/sam_vit_b_01ec64.pth"
sam, img_embedding_size = sam_model_registry['vit_b'](image_size=input_size,
num_classes=1,
checkpoint=ckpt,
pixel_mean=[0, 0, 0],
pixel_std=[1, 1, 1])

net = LoRA_Sam(sam, rank).cuda()
net.load_lora_parameters(lora_ckpt)

net.eval()
for phase in phases:
for cls in classes:
imgs = os.listdir(os.path.join(source, phase, cls))
for img in tqdm(imgs):
torch.cuda.empty_cache()
img_path = os.path.join(source, phase, cls, img)
image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
origin_image = copy.deepcopy(image)
x, y = image.shape[0:2]
if x != input_size or y != input_size:
image = zoom(image, (input_size / x, input_size / y, 1.0), order=3)
inputs = torch.from_numpy(image.astype(np.float32) / 255.0)
inputs = inputs.permute(2, 0, 1)
inputs = inputs.unsqueeze(0).cuda()
with torch.no_grad():
outputs = net(inputs, False, input_size)
output_masks = outputs['masks']
out = torch.argmax(torch.softmax(output_masks, dim=1), dim=1).squeeze(0)
prediction = out.cpu().detach().numpy()
if x != input_size or y != input_size:
prediction = zoom(prediction, (x / input_size, y / input_size), order=0)
cropped_image = do_crop(img=origin_image.astype(np.uint8),
prd=(prediction * 255).astype(np.uint8),
crop_size=crop_size)
output_path = os.path.join(target, phase, cls, img)
if not os.path.exists(os.path.join(target, phase, cls)):
os.makedirs(os.path.join(target, phase, cls))
cropped_image.save(output_path)

if __name__ == "__main__":
generate_cropped_image()
Empty file added SAM/datasets/__init__.py
Empty file.
137 changes: 137 additions & 0 deletions SAM/datasets/dataset_cancer.py
@@ -0,0 +1,137 @@
import os
import random
import numpy as np
import torch
import torchvision.datasets
from scipy import ndimage
from scipy.ndimage import zoom
from torch.utils.data import Dataset
from torchvision import transforms
from pycocotools import mask as coco_mask

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def random_rot_flip(image, label):
k = np.random.randint(0, 4)
image = np.rot90(image, k)
label = np.rot90(label, k)
axis = np.random.randint(0, 2)
image = np.flip(image, axis=axis).copy()
label = np.flip(label, axis=axis).copy()
return image, label


def random_rotate(image, label):
angle = np.random.randint(-20, 20)
image = ndimage.rotate(image, angle, order=0, reshape=False)
label = ndimage.rotate(label, angle, order=0, reshape=False)
return image, label


class RandomGenerator(object):
def __init__(self, output_size, low_res, phase):
self.output_size = output_size
self.low_res = low_res
self.phase = phase

def __call__(self, sample):
image, label = sample['image'], sample['label']

if self.phase == "train":
if random.random() > 0.5:
image, label = random_rot_flip(image, label)
elif random.random() > 0.5:
image, label = random_rotate(image, label)
x, y = image.shape[0:2]
if x != self.output_size[0] or y != self.output_size[1]:
            image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y, 1.0), order=3)  # cubic interpolation for the image
label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0)
label_h, label_w = label.shape
low_res_label = zoom(label, (self.low_res[0] / label_h, self.low_res[1] / label_w), order=0)
image = torch.from_numpy(image.astype(np.float32) / 255.0)
# image = (image - torch.FloatTensor(mean)) / torch.FloatTensor(std)
image = image.permute(2, 0, 1)
label = torch.from_numpy(label.astype(np.float32))
low_res_label = torch.from_numpy(low_res_label.astype(np.float32))
sample = {'image': image, 'label': label.long(), 'low_res_label': low_res_label.long(), 'case_name': sample['case_name']}
return sample

def convert_coco_poly_to_mask(segmentations, height, width):
masks = []
for polygons in segmentations:
rles = coco_mask.frPyObjects(polygons, height, width)
mask = coco_mask.decode(rles)
if len(mask.shape) < 3:
mask = mask[..., None]
# mask = torch.as_tensor(mask, dtype=torch.uint8)
mask = np.any(mask, axis=2)
masks.append(mask)

merged_mask = np.zeros((height, width), dtype=np.uint8)
if masks:
for mask in masks:
merged_mask = merged_mask | mask

return merged_mask

class COCO_dataset(torchvision.datasets.CocoDetection):
def __init__(self, img_folder, ann_file, split=None, transform=None):
super(COCO_dataset, self).__init__(img_folder, ann_file)
self.split = split
self.transform = transform

def __len__(self):
return super(COCO_dataset, self).__len__()

def __getitem__(self, idx):
img, target = super(COCO_dataset, self).__getitem__(idx)

# get filename
image_info = self.coco.loadImgs(self.ids[idx])[0]
filename = image_info['file_name']

# generate masks
w, h = img.size
segmentations = [obj['segmentation'] for obj in target]
masks = convert_coco_poly_to_mask(segmentations, h, w)

label_value = target[0]['category_id'] + 1
masks[masks == 1] = label_value

img = np.array(img)

sample = {'image': img, 'label': masks}

if self.transform:
sample = self.transform(sample)

sample['case_name'] = os.path.splitext(filename)[0]

return sample

class Cancer_dataset(Dataset):
def __init__(self, data_dir, txt_dir, transform=None):
# train or val or test
phase = os.path.splitext(os.path.basename(txt_dir))[0]
file_path = os.path.join(data_dir, phase)

self.data = [os.path.join(file_path, file) for file in os.listdir(file_path)]
self.transform = transform # using transform in torch!
self.sample_list = open(txt_dir).readlines()

def __len__(self):
return len(self.sample_list)

def __getitem__(self, idx):
data_path = self.data[idx]
data_dic = np.load(data_path)
image, label = data_dic['image'], data_dic['label']
name = os.path.splitext(os.path.basename(data_path))[0]
sample = {'image': image, 'label': label, 'case_name': name}
if self.transform:
sample = self.transform(sample)

return sample
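
For orientation, here is a minimal usage sketch wiring `Cancer_dataset` and `RandomGenerator` into a `DataLoader`; the data directory, list file, and low-resolution label size are illustrative assumptions, not paths shipped with this commit.
```python
# Minimal usage sketch; directory names, the list file, and sizes are illustrative assumptions.
from torch.utils.data import DataLoader
from torchvision import transforms

from datasets.dataset_cancer import Cancer_dataset, RandomGenerator  # assumes SAM/ is the working directory

db_train = Cancer_dataset(
    data_dir='../datasets/dataset1/npz',   # assumed folder with train/, val/, test/ subfolders of .npz files
    txt_dir='./lists/train.txt',           # assumed list file; its basename selects the split
    transform=transforms.Compose([
        RandomGenerator(output_size=[224, 224], low_res=[56, 56], phase='train')]))
loader = DataLoader(db_train, batch_size=8, shuffle=True, num_workers=4)

for sample in loader:
    image = sample['image']                  # float tensor, [B, 3, 224, 224]
    label = sample['label']                  # long tensor, [B, 224, 224]
    low_res_label = sample['low_res_label']  # long tensor, [B, 56, 56]
    break
```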


