Major Study./Computer Science

pytorch hub, torchvision으로 받은 모델의 forward 가져오기

sosal 2024. 3. 9. 05:01
반응형

Transfer learning을 수행해야할 때,

torch의 hub를 활용해서 기존의 pretrained model을 가져오는 경우가 많이 있다.

 

그러나, 단순히 마지막 fully connected layer만을 없애고 싶은게 아니라,

중간의 feature부터 활용하고 싶은 경우가 있는데, 이런 경우는 forward 함수를 건드리면 제일 간편하다.

 

 

예를 들어, vision transformer에서, 마지막 cls token의 값을 가져오는게 아닌

patch의 정보를 가져오고 싶을때?

단순히 모델의 architecture를 수정한다고 해결할 수 있는 문제는 아니다.

forward 함수에서, cls token만 짚어서 return하고 있기 때문이다.

 

이런 경우, python의 inspect를 활용하면 매우 간편하다.

 

import torch

import inspect
from torchvision.models import VisionTransformer, vit_b_16, ViT_B_16_Weights

model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
model.eval()

 

 

ln [4] forward_source = inspect.getsource(model.forward)
ln [5] print(forward_source)
    def forward(self, x: torch.Tensor):
        # Reshape and permute the input tensor
        x = self._process_input(x)
        n = x.shape[0]

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Classifier "token" as used by standard language architectures
        x = x[:, 0]
        x = self.heads(x)

        return x

 

 

이제 forward 함수를 직접 확인할 수 있으니,

여기서 cls token만 가져오는 x[:, 0] 을 수행하지 않고 x를 바로 return하면 patch의 feature를 가져올 수 있다.