SAM (Segment Anything Model) 세그멘테이션 후 배경제거 (누끼따기)
세그멘테이션 (segmentation)은 이미지나 영상의 특정 객체에 대한 픽셀들의 좌표를 구해 다른 물체나 배경과 구별하는 기술입니다.
Meta에서 만든 SAM은 세그멘테이션을 위한 모델로 물체에 대한 점이나 물체를 포함하는 사각형을 입력으로 줘서 세그멘테이션을 수행할 수 있습니다. 그리고 GroundingDINO와 결합하면 GroundingDINO에 물체명을 전달해 사각형을 얻어온 후 SAM에 전달해 물체명을 전달 해 세그멘테이션을 수행할 수 있습니다.
이 글에서는 물체 위의 점의 좌표를 전달 해 세그멘테이션을 수행하고 배경을 투명화 한 이미지를 얻어 보겠습니다.
목차
라이브러리 설치
transformers, torch, matplotlib, numpy를 설치해줍니다.
pip install transformers matplotlib numpy
torch는 https://pytorch.kr/get-started/locally/ 이 링크에서 환경에 맞는 명령어로 설치해줍니다.
코랩은 위의 라이브러리들이 설치되어 있습니다.
세그멘테이션 하기
필요한 라이브러리를 불러오고 cuda 사용 여부에 따라 device를 설정합니다.
from transformers import SamModel, SamProcessor
import torch
import matplotlib.pyplot as plt
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Meta에서 만든 SAM 모델은 그냥 SAM과 SAM2가 있는데 여기서는 SAM을 이용하고 그 중에서도 가장 용량이 작은 base 모델을 이용하겠습니다. base 모델 외에는 large와 huge가 있습니다. 아래코드에서 base를 large 또는 huge로 바꿔주면 사용할 수 있습니다.
허깅페이스에서 모델을 불러옵니다.
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
세그멘테이션 할 이미지를 matplotlib으로 불러오고 확인해보겠습니다. 2번째 줄의 코드는 RGBA 형식의 투명도 채널을 가진 png 이미지에서 알파채널을 없애기 위해 넣었습니다.
img = plt.imread('star.png')
img = img[:, :, :3]
plt.imshow(img)
plt.show()

이미지를 PIL.Image.open으로 이미지를 불러오지 않고 matplotlib으로 불러온 이유는 점의 좌표를 확인하기 위함이였고 PIL.Image.open으로 불러와도 됩니다. 점의 좌표는 그림판으로 이미지를 열어서도 볼 수 있습니다. PIL.Image를 사용한다면 코드가 약간 수정되어야 합니다.
객체 하나에 대한 점 하나로 세그멘테이션
맨 왼쪽 벌처 위의 점 하나를 이용해 세그멘테이션을 해보겠습니다. x=50, y= 80을 고르면 다음과 같이 점에 대한 리스트를 만들 수 있습니다.
points = [[[50,80]]] #점 여러개로 세그멘테이션을 하는 예시 [[[50,80],[60,70]]]
이제 세그멘테이션을 해서 물체에 대한 마스크를 구하겠습니다. 이미지 픽셀의 값은 0~255 이거나 0~1인데 이를 구분해서 rescale 여부를 결정합니다.
if img.max() > 1.0:
rescale = True
else:
rescale = False
inputs = processor(images=img, input_points=points, return_tensors="pt",do_rescale=rescale).to(device)
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(),
inputs["original_sizes"].cpu(),
inputs["reshaped_input_sizes"].cpu()
)
우선 outputs.iou_scores를 확인해보면 세그멘테이션에 대한 점수들이 나옵니다.
print(outputs.iou_scores) # 출력 : tensor([[[0.7525, 0.7871, 0.9555]]])
세그멘테이션 3개에 대해 점수가 나왔습니다. 각각의 세그멘테이션에 대한 마스크는 다음과 같이 접근할 수 있습니다.
print(masks[0][0][0])
print(masks[0][0][1])
print(masks[0][0][2])
# 가장 점수가 높은 mask 가져오기 : masks[0][0][outputs.iou_scores.argmax()]
첫번째 인덱스는 이미지에 대한 인덱스로 이미지를 하나만 전달했기 때문에 0번 밖에 없습니다.
두번째 인덱스는 객체에 대한 인덱스로 위에서 전달한 points는 객체가 하나인 경우에 해당하기 때문에 0번 밖에 없습니다.
세번째 인덱스는 세그멘테이션에 대한 인덱스입니다.
마스크는 각 픽셀에 대해 True 또는 False의 값을 가지고 있고 세그멘테이션 된 픽셀은 True, 되지 않은 픽셀은 False의 값을 가집니다.
마스크와 원본 이미지를 겹쳐서 시각화 해보겠습니다
def mask_to_rgb(*masks):
mg = np.zeros(masks[0].shape + (4, ), dtype=np.uint8)
for mask in masks:
mg[mask == 1] = [0, 255, 0, 127]
return mg
mr0 = mask_to_rgb(masks[0][0][0])
mr1 = mask_to_rgb(masks[0][0][1])
mr2 = mask_to_rgb(masks[0][0][2])
mask를 참조해 반투명 이미지를 만듭니다. 그리고 원래이미지와 겹쳐서 보이게합니다.
fig, axes = plt.subplots(1, 3)
axes[0].imshow(img)
axes[0].imshow(mr0)
axes[1].imshow(img)
axes[1].imshow(mr1)
axes[2].imshow(img)
axes[2].imshow(mr2)
for ax in axes:
ax.axis('off')
plt.tight_layout()
plt.show()

여러 객체에 여러 개의 점으로 세그멘테이션
이번에는 입력했던 points를 바꿔서 그림자에 대한 점을 추가하고 다른 객체에 대한 세그멘테이션도 얻어보겠습니다.
points = [[[[50,80],[50,110]],[[180,80],[177,107]], [[276,85],[279,115]]]]
괄호의 위치를 잘 봐야하는데 객체는 3개이고 각 객체마다 2개의 점을 골랐습니다. 이 리스트로 위의 과정을 수행하면 다음과 같은 형태로 iou_scores가 나옵니다.
print(outputs.iou_scores)
# 출력 : tensor([[[0.9138, 0.8014, 0.8848],
[0.8522, 0.9179, 0.8829],
[0.8364, 0.9401, 0.6040]]])
위에서 사용했던 함수를 이용해 마스크들을 참조해서 반투명 이미지를 만듭니다. 두번째 인덱스로 객체를 구분하고 세번째 인덱스로 세그멘테이션을 선택합니다. 편의상 객체들의 첫번째 세그멘테이션끼리 모으고 두번째끼리, 세번째끼리 모아 이미지를 만들겠습니다.
mr0 = mask_to_rgb(masks[0][0][0],masks[0][1][0],masks[0][2][0])
mr1 = mask_to_rgb(masks[0][0][1],masks[0][1][1],masks[0][2][1])
mr2 = mask_to_rgb(masks[0][0][2],masks[0][1][2],masks[0][2][2])

그림자가 세그멘테이션에 잘 포함되지 않은 것 같은데 세그멘테이션에 포함시키고 싶은 곳에 대한 점을 더 추가하면 결과가 좋아질 수 있습니다.
배경 투명화 하기
세그멘테이션을 통해 얻어진 마스크를 이용해 배경을 투명화 할 수 있습니다. 이미지에 알파채널을 추가하고 입력된 마스크들을 더해 얻어진 최종 마스크의 값에 따라 투명도를 설정합니다.
def transparent(img, *masks):
if img.max() > 1.0:
rescale = 255
else:
rescale = 1
height, width, channels = img.shape
alpha_channel = np.ones((height, width), dtype=np.uint8) * rescale
image_with_alpha = np.dstack((img, alpha_channel))
mask = sum(masks)
image_with_alpha[mask == 0] = [0, 0, 0, 0]
return image_with_alpha
ri = transparent(img,masks[0][0][0],masks[0][1][0],masks[0][2][1])
plt.imsave('generated_image.png', ri)

세그멘테이션을 이용해 배경투명화를 해봤습니다. 세그멘테이션으로 나온 마스크를 이용해 배경제거가 아니라 다른 색으로 배경을 채워넣을 수도 있고 인페인팅으로 배경 또는 물체를 새롭게 그릴 수도 있습니다.
다음 글에서는 사각형을 이용해 세그멘테이션을 해보고 세그멘테이션 점에 대한 라벨을 사용해보겠습니다.
이 글에 대한 주피터노트북 파일입니다.