데이터과학 삼학년

Grad-CAM 본문

Explainable AI

Grad-CAM

Dan-k 2023. 3. 24. 11:47
반응형

Grad-CAM(Gradient-weighted Class Activation Mapping)

딥러닝 모델이 어떤 부분을 보고 특정 클래스를 판단했는지를 시각화하는 기술 -> 이미지에서 주로 사용 -> 텍스트에서도 사용 가능

이를 통해 모델이 어떤 부분을 주로 활용하는지를 알 수 있어 모델의 해석성(interpretability)을 높일 수 있음

 

Grad-CAM은 기존의 Class Activation Mapping(CAM)을 발전시킨 기술로, CAM은 Global Average Pooling(GAP)을 사용하여 클래스에 대한 중요도를 계산

이와 달리 Grad-CAM은 전체적인 특성 맵의 중요도가 아닌 클래스에 대한 중요도를 계산

 

Grad-CAM 계산 단계

1. 모델의 gradient를 계산

- 모델의 gradient를 계산하기 위해서는 해당 클래스의 loss에 대한 gradient가 필요한데 이를 위해 역전파(backpropagation)를 사용 -> loss에 대한 gradient는 모델의 각 층(layer)에 대해 계산되며, 각 층의 gradient는 다음 층으로 전파

2. gradient를 이용하여 특성 맵의 중요도를 계산

- 이를 위해서는 gradient의 각 채널(channel)에 대해 가중치를 계산

- 이 가중치는 gradient의 각 채널에서 클래스에 대한 중요도를 나타내며, 이를 통해 특성 맵의 각 위치에서 클래스에 대한 중요도를 계산

- 이렇게 계산된 클래스에 대한 중요도를 특성 맵에 곱하여 시각화

 

 

Grad-CAM은 딥러닝 모델의 클래스 분류와 같은 문제를 해결하는데 매우 유용

이를 통해 모델이 어떤 부분을 중점적으로 활용하는지를 파악할 수 있으며, 이를 통해 모델을 개선하는 방향으로 이어질 수 있음

https://glassboxmedicine.files.wordpress.com/2020/05/modified-figure-1-dog-cat.png

 

Grad-CAM 코드

import cv2
import torch
import torch.nn as nn
import numpy as np
from torchvision import models, transforms

class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.activation = {}
        self.gradient = {}

    def save_gradient(self, name):
        def hook(grad):
            self.gradient[name] = grad
        return hook

    def get_activation(self, x):
        self.activation = {}
        for name, module in self.model._modules.items():
            x = module(x)
            if name == self.target_layer:
                x.register_hook(self.save_gradient(name))
                self.activation[name] = x
        return x

    def forward(self, x):
        return self.model(x)

    def backward(self, idx):
        self.model.zero_grad()
        gradient = self.gradient[self.target_layer]
        gradient = nn.functional.adaptive_avg_pool2d(gradient, 1)
        activation = self.activation[self.target_layer]
        b, c, h, w = activation.shape
        weights = gradient.view(b, c)
        activation = activation.view(b, c, h*w)
        saliency_map = torch.bmm(weights.unsqueeze(1), activation)
        saliency_map = saliency_map.view(b, h, w)
        saliency_map = nn.functional.relu(saliency_map)
        saliency_map = nn.functional.interpolate(saliency_map.unsqueeze(1), (224, 224), mode='bilinear', align_corners=False)
        saliency_map = saliency_map.squeeze()
        saliency_map = saliency_map - saliency_map.min()
        saliency_map = saliency_map / saliency_map.max()
        return saliency_map

model = models.resnet50(pretrained=True)
gradcam = GradCAM(model, 'layer4')

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

img_path = 'example.jpg'
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_tensor = transform(img)
img_tensor = img_tensor.unsqueeze(0)

output = gradcam.forward(img_tensor)
pred = output.argmax()

saliency_map = gradcam.backward(pred)
saliency_map = saliency_map.detach().numpy()
saliency_map = cv2.resize(saliency_map, (img.shape[1], img.shape[0]))
saliency_map = (255*saliency_map).astype(np.uint8)
heatmap = cv2.applyColorMap(saliency_map, cv2.COLORMAP_JET)

result = cv2.addWeighted(img, 0.5, heatmap, 0.5, 0)
cv2.imwrite('result.jpg', result)
728x90
반응형
LIST

'Explainable AI' 카테고리의 다른 글

SHAP 그래프 해석  (0) 2024.01.12
Counterfactual Explanations  (0) 2023.03.27
Feature Interaction  (0) 2022.12.21
Accumulated Local Effects (ALE) Plot  (0) 2022.12.03
LIME 결과 소수점 자리 핸들링  (0) 2022.07.21
Comments