SAM (Segment Anything Model) 점, 박스, 라벨을 이용한 세그멘테이션
저번 글에서 객체에 대한 점을 입력으로 전달해 세그멘테이션을 하고 배경제거를 해봤습니다. 이번 글에서는 추가로 객체를 포함하는 박스와 점에 대한 라벨을 전달해 세그멘테이션을 해보겠습니다.
박스로 세그멘테이션 하기

from transformers import SamModel, SamProcessor
import torch
import matplotlib.pyplot as plt
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
라이브러리와 모델을 불러옵니다.
박스로 세그멘테이션을 하기 위해서는 먼저 이미지의 객체를 포함하는 박스의 좌표를 담은 리스트를 만들어야 합니다. 박스 하나는 4개의 숫자로 표현이 됩니다.
순서대로 박스 왼쪽 위 점의 x 좌표, 박스 왼쪽 위 점의 y좌표, 박스 오른쪽 아래 점의 x좌표, 박스 오른쪽 아래 점의 y좌표입니다.
객체가 3개이므로 아래와 같은 형식으로 만들었습니다.
box = [[[21,15,116,130],[165,51,226,115],[239,55,297,123]]]
저번 코드의 세그멘테이션 부분 코드를 약간 수정해서 input_boxes 매개변수에 box를 전달하겠습니다.
if img.max() > 1.0:
rescale = True
else:
rescale = False
inputs = processor(images=img, input_boxes = box, return_tensors="pt",do_rescale=rescale).to(device)
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(),
inputs["original_sizes"].cpu(),
inputs["reshaped_input_sizes"].cpu()
)
저번에는 세그멘테이션을 할 때 processor에 input_points를 입력으로 주었는데 이번에는 input_boxes를 입력으로 줬습니다.
저번 글에서처럼 masks를 이미지로 만들어 원래 이미지와 겹쳐서 출력하면 아래처럼 나옵니다. 점으로 세그멘테이션 했을 때보다 더 잘 되는 것 같습니다.

점에 대한 라벨 이용하기
세그멘테이션을 할 때 점과 함께 점에 대한 라벨을 전달할 수 있습니다. 라벨은 세그멘테이션을 할 객체인지 아닌지를 구분하는 데 이용할 수 있습니다. 다음과 같은 라벨의 값을 경우에 맞게 사용합니다.
- 1 : 세그멘테이션 할 객체에 대한 점
- 0 : 세그멘테이션 할 객체가 아닌 것에 대한 점
- -1 : 배경에 대한 점
위에서 이용한 박스를 그대로 사용하고 그림자에 대한 점 3개에 대해 라벨 1, 물체에 대한 점1개에 라벨 0, 배경에 대한 점 1개에 라벨 -1로 세그멘테이션 해보겠습니다.
box = [[[21,15,116,130],[165,51,226,115],[239,55,297,123]]]
points = [[[[82,69],[50,106],[103,73],[64,64],[58,51]],[[182,106],[196,101],[204,108],[186,60],[197,106]], [[286,100],[258,109],[276,112],[270,64],[243,62]]]]
labels = [[[1,1,1,0,-1],[1,1,1,0,-1], [1,1,1,0,-1]]]
이들을 numpy 배열로 변환해 shape를 출력하면 의미를 이해하기가 좀 더 쉽습니다.
n1 = np.array(box)
n2 = np.array(points)
n3 = np.array(labels)
print(n1.shape)
print(n2.shape)
print(n3.shape)
박스 : (1, 3, 4)
점 : (1, 3, 5, 2)
라벨 : (1, 3, 5)
각 차원이 나타내는 의미는 아래와 같습니다.
- 첫번째 차원 : 이미지의 갯수
- 두번째 차원 : 세그멘테이션 할 객체의 수
- 박스 세 번째 차원 : 박스는 값 4개로 표현 됨 (x1,y1,x2,y2)
- 점 세 번째 차원 : 객체마다 점 5개를 세그멘테이션하는데 이용
- 점 네 번째 차원 : 점은 값 2개로 표현 됨 (x,y)
- 라벨 세 번째 차원 : 점 5개를 이용함
processor에 점과 라벨도 전달하겠습니다.
if img.max() > 1.0:
rescale = True
else:
rescale = False
inputs = processor(images=img, input_boxes = box, input_points = points, input_labels=labels, return_tensors="pt",do_rescale=rescale).to(device)
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(),
inputs["original_sizes"].cpu(),
inputs["reshaped_input_sizes"].cpu()
)

박스만 이용했을 때와 비교해보면 그림자가 추가되고 물체가 제외된 것으로 보입니다.