Skip to content

Commit

Permalink
feat : csv 변환 오류 해결 Fixes #9
Browse files Browse the repository at this point in the history
  • Loading branch information
Batwan01 committed Nov 13, 2024
1 parent 3f6ea4c commit 19a5acf
Showing 1 changed file with 28 additions and 25 deletions.
53 changes: 28 additions & 25 deletions yolo/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ def decode_rle_to_mask(rle, height, width):
return img.reshape(height, width)

# 모델 로드
model = YOLO("/data/ephemeral/home/jiwan/Gyeonggi-Autonomous-Driving-Center-Data-Utilization-Competition/yolo/runs/segment/train2/weights/best.pt")
model = YOLO("/data/ephemeral/home/jiwan/level2-cv-semanticsegmentation-cv-15-lv3/yolo/runs/segment/train/weights/best.pt")

# 예측할 이미지 폴더 경로 설정
image_folder = "/data/ephemeral/home/dataset_yolo/test"

# CSV 파일 생성 및 헤더 작성
csv_file_path = "/data/ephemeral/home/jiwan/yolo/result/predictions.csv"
csv_file_path = "/data/ephemeral/home/jiwan/level2-cv-semanticsegmentation-cv-15-lv3/yolo/result/predictions.csv"
with open(csv_file_path, mode='w', newline='') as csv_file:
writer = csv.writer(csv_file)
writer.writerow(['image_name', 'class', 'rle'])
Expand All @@ -63,31 +63,34 @@ def decode_rle_to_mask(rle, height, width):
# 각 이미지에 대해 예측 수행 후 저장
for image_path in image_files:
# 예측 수행
results = model(image_path)
results = model(image_path, imgsz = 2048)

# 각 결과에 대해 처리
class_rle_mapping = {class_name: '' for class_name in class_names}
for result in results:
# CSV 파일에 결과 저장
with open(csv_file_path, mode='a', newline='') as csv_file:
writer = csv.writer(csv_file)

# 마스크 결과 가져오기
if result.masks is not None:
# 박스와 마스크 데이터를 클래스 인덱스 순서대로 정렬
sorted_results = sorted(zip(result.boxes, result.masks.data), key=lambda x: int(x[0].cls))

for box, mask in sorted_results:
# 클래스 가져오기
class_idx = int(box.cls)
class_name = class_names[class_idx]

# 마스크를 numpy 배열로 변환하고 RLE로 인코딩
mask_np = mask.cpu().numpy().astype(np.uint8) # dtype을 uint8로 변환
# 각 픽셀 값이 0 또는 1인지 확인 (이진화)
mask_np = (mask_np > 0).astype(np.uint8)
rle_str = encode_mask_to_rle(mask_np) # 커스텀 RLE 인코딩 사용

# CSV 파일에 기록
writer.writerow([os.path.basename(image_path), class_name, rle_str])
# 마스크 결과 가져오기
if result.masks is not None:
# 박스와 마스크 데이터를 클래스 인덱스 순서대로 정렬
sorted_results = sorted(zip(result.boxes, result.masks.data), key=lambda x: int(x[0].cls))

for box, mask in sorted_results:
# 클래스 가져오기
class_idx = int(box.cls)
class_name = class_names[class_idx]

# 마스크를 numpy 배열로 변환하고 RLE로 인코딩
mask_np = mask.cpu().numpy().astype(np.uint8) # dtype을 uint8로 변환
# 각 픽셀 값이 0 또는 1인지 확인 (이진화)
mask_np = (mask_np > 0).astype(np.uint8)
rle_str = encode_mask_to_rle(mask_np) # 커스텀 RLE 인코딩 사용

# 클래스에 해당하는 RLE 문자열 업데이트
class_rle_mapping[class_name] = rle_str

# 모든 클래스에 대해 결과 저장 (예측되지 않은 클래스는 빈 RLE 문자열 유지)
with open(csv_file_path, mode='a', newline='') as csv_file:
writer = csv.writer(csv_file)
for class_name, rle_str in class_rle_mapping.items():
writer.writerow([os.path.basename(image_path), class_name, rle_str])

print(f"Predictions for {image_path} logged in {csv_file_path}")

0 comments on commit 19a5acf

Please sign in to comment.