들어가며
지난 4월 Meta에서 Segment Anything(SAM)이라는 논문이 나왔습니다.
CLIP처럼 확장성이 좋은 것 같아 곧 SAM을 활용한 논문이 우후죽순 나올 것 같아
읽어봐야지 읽어봐야지 하다가 이제서야 정리해봅니다 :)
⛳️ Image Segmentation 분야의 foundation model
최근 자연어 처리(NLP)분야에서는 매우 큰 데이터 셋으로 학습한 대용량 언어 모델 (Large Language Model)의 일반화 (generalization)성능이 우수하여 프롬프트 튜닝을 통해 zero-shot, few-shot downstrem task에서도 좋은 성능을 보여주고 있다고 합니다.
이러한 LLM은 다른 하위 태스크 수행의 기반이 되는 foundation 모델이라고도 부르는데, 본 논문은 Image segmentation 분야에서의 foundation model을 만드는 것을 목표로 합니다.
논문에서는 foundation model을 아래와 같이 정의하고 있습니다.
Models that are trained on broad data at scale and are adaptable to a wide range of downstream tasks
대충 해석해보자면 큰 규모의 데이터 셋으로 학습해 다양한 downstream task에 활용할 수 있는 모델 정도가 되겠네요
논문에서는 Image segmentation 분야의 foundation model이라고 말하고는 있지만,
활용되고 있는 방향을 보면 computer vision 분야의 foundation model이라고 해도 될 것 같은 수준이네요,,
아마 생성 부분에서는 foundation model로 활용할 정확한 방향성을 제시하지 않았기 때문에
computer vision의 foundation 모델이라고는 말하지 못한게 아닐까요??
해당 논문에서는 제안하는 foundation model을 만들기 위해
- 어떤 task를 학습해야 하는가?
- 어떤 data를 사용해야 하는가?
- 어떤 model 구조를 사용해야하는가?
라는 세 가지 관점에 대한 질문을 던지고 이에 대한 나름의 해결 방법을 제시합니다.
🤖 뭘 학습할까? - Promptable Segmentation
앞서 말했듯 자연어 처리 분야의 foundation model은 프롬프트를 잘 조절하는 방법을 통해 zero-shot, few-shot task에 활용되어왔습니다.
이에 영감을 받은 저자들은 모델이 다양한 프롬프트를 처리하는 능력이 있어야 한다 생각했고, 최종적으로 promptable segmentation task를 제안합니다.
Promptable segmentation task는 다양한 형태의 프롬프트를 받아서 segmentation을 수행하는 task입니다.
논문에서는 포인트, 박스, 마스크, 텍스트 형태의 입력을 프롬프트로 받아서 segmentation을 수행하고 있습니다.
💽 데이터 수집 - 3 step Data Engine
앞서 말한 promptable segmentation task를 학습하기 위해서는 이미지, 프롬프트, segmentation 마스크가 쌍으로 이루어진 거대한 데이터 셋이 필요합니다.
일반적인 LLM은 인터넷에서 크롤링한 자연어를 학습에 사용하는데,
SAM학습에 필요한 image, segmentation mask쌍은 인터넷에서 쉽게 찾아내기 어려운 데이터 입니다.
Meta에서는 규모가 큰 image, segmentation mask 페어 데이터를 수집하기 위해
총 3단계에 걸친 데이터 엔진을 구축했고, 이를 기반으로 SA-1B 데이터셋을 수집했다고 합니다.
Step 1. Assisted-manual stage
첫 번째 단계는 assisted-manual stage라고 부릅니다.
먼저 기존에 공개된 벤치마크 데이터 셋들로 SAM모델을 학습해두고,
SAM의 inference결과를 라벨러들이 수정하여 최종 segmentation mask를 수집합니다.
이 때 라벨러들은 stuff를 라벨링 할지, things를 라벨링할지와 같은 의미론적인 제약을 받지 않았고,
라벨링한 class이름도 작성하지 않았습니다.
단지 이미지에서 가장 눈에 띄는 (prominence)부분부터 라벨링하고, 라벨링에 30초 이상이 걸리면 다음 이미지로 넘어가도록 가이드를 주었다고 합니다.
논문의 저자들은 일정 수준의 데이터가 수집되면 SAM 모델을 재학습했는데,
Step 1에서는 SAM모델을 총 6번 재학습했고,
총 120K 장의 이미지에 해당하는 4.3M 개의 마스크를 수집했습니다.
Step 2. Semi-automatic stage
두 번째 단계는 semi-automatic stage라고 부르며, 마스크의 다양성 확보를 목표로 했습니다.
저자들은 먼저 1단계에서 수집한 마스크들을 기반으로 이미지를 넣었을 때
SAM이 segmentation할 것 같은 object 박스를 예측하는 detector를 학습했습니다.
이 detector를 기반으로 SAM의 Inference결과에서 confident하다고 보이는 mask만 남겨둡니다.
라벨러들은 이렇게 confident하지 않은 부분이 빠진 segmentation mask를 받아서
남아있는 segmentation mask들은 수정하지 않고, 추가적인 물체들만 라벨링합니다.
이 때 라벨러들은 1단계에서보다 눈에 덜 띄는 물체들에 집중해달라는 지시를 받았다고 합니다.
Step 2에서는 SAM모델을 총 5번 재학습했고,
총 180K 장의 이미지에 해당하는 5.9M 개의 마스크를 추가로 수집하여
총 300K 장의 이미지와 이에 해당하는 10.2M 개의 마스크를 수집했습니다.
Step 3. Fully automatic stage
마지막 단계인 Fully automatic stage 에서는 사람의 개입없이
이전 두 단계에서 수집한 데이터를 사용해 학습한 SAM을 활용하여 segmentation mask를 생성했다고 합니다.
이 과정에서는 총 11M 장의 이미지에 해당하는 1.1B 개의 마스크를 생성하였고,
최종적으로 공개된 SA-1B 데이터 셋은 이 단계에서 수집된 데이터만을 포함하고 있다고 합니다.
🏛️ 어떤 모델 구조를 사용할까?
논문에서는 SAM(Segment Anything Model)의 구조를 디자인 할 때 아래와 같은 사항을 염두에 두고 디자인했다고 합니다.
- 다양한 프롬프트를 입력으로 받을 수 있어야 함
- 프롬프트를 입력하는 사용자와 interactive하게 동작할 수 있도록 Amortized real-time으로 마스크를 생성할 수 있어야 함
- 입력의 모호성(ambiguity)을 해결할 수 있어야 함
이 때 저자들이 말하는 모호성이라는 개념이 새로울 수 있는데,
아래 그림의 1열과 같이 타조의 부리 근처에 초록 점 하나가 프롬프트로 주어졌을 때
사용자가 원하는 마스크는 타조 전체에 해당하는 마스크가 될 수도 있고, 타조의 몸통 부분만 포함하는 마스크, 타조의 머리 부분만 포함하는 마스크 등 여러 마스크가 될 수 있어 프롬프트는 모호성을 가지고 있고,
모델은 이를 염두에 두어 설계되어야 한다고 말합니다.
위 사항들을 염두에 두어 디자인된 최종 SAM은 비교적 간단한 모델 구조를 사용할 수 있고, 크게 아래 세가지 부분으로 나눌 수 있습니다.
- Image Encoder: 이미지에서 feature를 뽑아내는 트랜스포머 인코더
- Prompt Encoder: 입력으로 주어진 프롬프트 정보를 추출하는 인코더
- Mask Decoder: 이미지 인코더와 프롬프트 인코더에서 나온 정보를 합쳐서 segmentation mask를 최종적으로 생성하는 디코더
이 때 이미지 인코더는 real-time으로 동작하지 않지만, 인코더로 feature를 한 번 뽑아 두면
그 뒤의 prompt encoder와 mask decoder는 real-time으로 다양한 프롬프트에 해당하는 마스크를 생성할 수 있기 때문에
SAM은 amortized real-time으로 동작한다고 이야기 합니다.
Image Encoder
Masked AutoEncoder(MAE) 형태로 pretrain된 ViT 구조를 사용합니다.
MAE는 Meta에서 CVPR 2022에 발표한 또 다른 연구로,
이미지의 일부를 마스킹하고 다시 복구하는 트랜스포머를 학습하면,
이 트랜스포머의 인코더를 좋은 feature extractor로 사용할 수 있다고 합니다.
Prompt Encoder
프롬프트 인코더는 다양한 형태로 주어지는 프롬프트를 이미지 feature와 함께 잘 사용할 수 있도록 만들어 주는 역할을 수행합니다.
SAM에서는 아래와 같이 프롬프트의 형태에 따라 다른 인코딩 방식을 사용하고 있습니다.
- Mask: CNN을 통해 인코딩 해서 이미지 feature와 element-wise로 더함
- Point and Box: Positional encoding과 학습된 embedding을 더해서 디코더 입력으로 사용
- Text: Pretrain된 CLIP 모델을 통해 인코딩해서 디코더 입력으로 사용
박스 정보를 인코딩 할 때는 네모의 좌상단 좌표와 우하단 좌표를 사용했다고 합니다.
Mask Decoder
마스크 디코더는 앞에서 만든 이미지 feature와 인코딩된 프롬프트 정보를 입력으로 받아 세그멘테이션 마스크를 생성해냅니다.
이 때 저자들은 이미지 feature를 한 번 뽑아 두면 해당 feature를 사용해 다양한 프롬프트로 마스크를 생성하는 과정은 real-time으로 지원하는 amortized real-time을 목표로 했으므로, 마스크 디코더는 상당히 가볍게 디자인 되었습니다.
마스크 디코더에는 ViT의 cls 토큰과 유사한 learnable output token이 있습니다.
이를 인코딩 된 프롬프트들과 concat하여 self attention 연산을 거치고,
프롬프트 토큰이 query, 이미지가 key, value인 cross attention 연산과
반대로 이미지가 query, 프롬프트 토큰이 key, value인 cross attention 연산을 통해 마스크를 계산합니다.
마스크 디코더는 오버뷰 그림의 가위처럼 하나의 프롬프트에 대해 하나의 마스크만을 예측하는 것이 아니라 여러개의 마스크를 예측합니다.
이는 앞서 말한 ambiguity를 해결하기 위함인데요, 여러 개의 마스크와 함께 각 마스크와 GT 사이의 IoU를 예측하는 head도 학습하여 예측된 IoU가 가장 높은 마스크를 최종 예측으로 사용하여 학습합니다.
학습 시 예측하는 마스크의 갯수는 3개로도 충분했고, 보통 whole(ex) 타조), part(ex) 타조의 머리), subpart(ex) 타조의 부리)에 마스크가 생성되었다고 하네요 :)
Training Loss
학습에는 segmentation에서 흔히 사용되는 focal loss와 dice loss를 사용했고,
IoU 예측을 학습하기 위해서는 예측된 IoU와 실제 IoU의 MSE(Mean-squared-error)를 loss로 사용했다고 합니다.
🛠️ SAM은 진짜 foundation model 인가? - Zero-shot Transfer Ability
SAM이 실제로도 foundation model이라고 불릴만한 일반화 능력을 갖췄는지 확인하기 위해
저자들은 총 5개의 task를 수행해봅니다.
결론은 task들을 무리 없이 잘 수행한다는 것이고, 관심있으신 분들은 한 번 찾아보시면 좋을 것 같습니다.
여기서는 제가 관심있었던 Zero-shot Text-to-Mask만 간략히 소개드려 보겠습니다!
Zero-shot Text-to-Mask
본 task의 목표는 사용자에게 텍스트를 입력으로 받아
텍스트가 설명하고 있는 물체를 이미지에서 찾아 segmentation mask를 리턴하는 것 입니다.
가장 처음 SAM에 대한 이야기를 들었을 때 제가 가장 궁금해했고, 기대했던 부분이 바로 이 부분이었지만 성능이 아주 잘 나오는 것은 아닌 것 같아서 아쉽네요 🥹
실제 논문에서도 Proof-of-Concept 정도의 실험이라고 언급하고, 이미지 샘플도 몇 장 보여주지 않은걸로 봐서는 아직 갈 길이 머나 봅니다 🥲
저자들은 Step 1과 Step 2에서 수집한 마스크 중
픽셀이 $100 \times 100$개 이상 포함된 마스크들의 CLIP 이미지 임베딩을 추출했다고 합니다.
그리고 이 임베딩을 프롬프트로 사용하여 똑같이 segmentation을 학습했는데,
CLIP은 이미지와 텍스트 임베딩이 같은 공간에 있으므로,
학습은 이미지 임베딩으로 하더라도 inference 시에는 텍스트 임베딩으로 segmentation mask를 추론할 수 있다고 합니다.
위 방법으로 Zero-shot Text-to-Mask task를 수행하면
Segmentation 학습 시 Text label이 필요 없다는 장점이 있지만,
코드에도 공개되지 않았고, 논문에도 샘플이 많이 없는 것을 보면 성능은 별로 좋지 않은 걸로 예상됩니다.
마무리
본 포스트에서 대충 SAM에 대해 정리를 해보았는데, 내용상 다 담을 수 없지만 학습할 때 여러 종류의 프롬프트가 주어지면 마스크를 하나만 예측하는 등 다양한 테크닉들이 많이 사용되었고, 모든 과정이 본문 및 Appendix에 자세히 잘 설명되어있습니다.
물론 몰라도 활용하는 데는 문제가 없겠지만, 궁금하신 분들은 한 번쯤 찾아보시면 좋을 것 같네요 :)
잘못된 내용이 있다면 알려주시길 바랍니다! 감사합니다 😊
'개발 > 📑 논문 리뷰' 카테고리의 다른 글
[논문 리뷰] RS-DPO(Rejection Sampling DPO) (2) | 2024.11.19 |
---|