데이터과학 삼학년

Fine tuning 본문

Machine Learning

Fine tuning

Dan-k 2022. 1. 19. 14:31
반응형

Fine tuning

기존 학습되어져 있는 모델을 기반으로 새로운 목적(데이터에 맞게)에 맞게 모델을 변형

- 모델의 파라미터를 미세하게 조정하는 행위 : pre-trained model에 추가 데이터를 이용하여 파라미터를 업데이트

- 파인튜닝은 정교한 파라미터 튜닝

 


Keras에서 fine tuning

- Trainable = False 이냐 True이냐로 fine tuning을 할지 말지 결정

- https://keras.io/guides/transfer_learning/

 

Build a model

Now let's built a model that follows the blueprint we've explained earlier.

Note that:

  • We add a Rescaling layer to scale input values (initially in the [0, 255] range) to the [-1, 1] range.
  • We add a Dropout layer before the classification layer, for regularization.
  • We make sure to pass training=False when calling the base model, so that it runs in inference mode, so that batchnorm statistics don't get updated even after we unfreeze the base model for fine-tuning.
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be scaled
# from (0, 255) to a range of (-1., +1.), the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(x)

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
rescaling (Rescaling)        (None, 150, 150, 3)       0         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,529
Trainable params: 2,049
Non-trainable params: 20,861,480
_________________________________________________________________

Train the top layer

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20
291/291 [==============================] - 133s 451ms/step - loss: 0.1670 - binary_accuracy: 0.9267 - val_loss: 0.0830 - val_binary_accuracy: 0.9716
Epoch 2/20
291/291 [==============================] - 135s 465ms/step - loss: 0.1208 - binary_accuracy: 0.9502 - val_loss: 0.0768 - val_binary_accuracy: 0.9716
Epoch 3/20
291/291 [==============================] - 135s 463ms/step - loss: 0.1062 - binary_accuracy: 0.9572 - val_loss: 0.0757 - val_binary_accuracy: 0.9716
Epoch 4/20
291/291 [==============================] - 137s 469ms/step - loss: 0.1024 - binary_accuracy: 0.9554 - val_loss: 0.0733 - val_binary_accuracy: 0.9725
Epoch 5/20
291/291 [==============================] - 137s 470ms/step - loss: 0.1004 - binary_accuracy: 0.9587 - val_loss: 0.0735 - val_binary_accuracy: 0.9729
Epoch 6/20
291/291 [==============================] - 136s 467ms/step - loss: 0.0979 - binary_accuracy: 0.9577 - val_loss: 0.0747 - val_binary_accuracy: 0.9708
Epoch 7/20
291/291 [==============================] - 134s 462ms/step - loss: 0.0998 - binary_accuracy: 0.9596 - val_loss: 0.0706 - val_binary_accuracy: 0.9725
Epoch 8/20
291/291 [==============================] - 133s 457ms/step - loss: 0.1029 - binary_accuracy: 0.9592 - val_loss: 0.0720 - val_binary_accuracy: 0.9733
Epoch 9/20
291/291 [==============================] - 135s 466ms/step - loss: 0.0937 - binary_accuracy: 0.9625 - val_loss: 0.0707 - val_binary_accuracy: 0.9721
Epoch 10/20
291/291 [==============================] - 137s 472ms/step - loss: 0.0967 - binary_accuracy: 0.9580 - val_loss: 0.0720 - val_binary_accuracy: 0.9712
Epoch 11/20
291/291 [==============================] - 135s 463ms/step - loss: 0.0961 - binary_accuracy: 0.9612 - val_loss: 0.0802 - val_binary_accuracy: 0.9699
Epoch 12/20
291/291 [==============================] - 134s 460ms/step - loss: 0.0963 - binary_accuracy: 0.9638 - val_loss: 0.0721 - val_binary_accuracy: 0.9716
Epoch 13/20
291/291 [==============================] - 136s 468ms/step - loss: 0.0925 - binary_accuracy: 0.9635 - val_loss: 0.0736 - val_binary_accuracy: 0.9686
Epoch 14/20
291/291 [==============================] - 138s 476ms/step - loss: 0.0909 - binary_accuracy: 0.9624 - val_loss: 0.0766 - val_binary_accuracy: 0.9703
Epoch 15/20
291/291 [==============================] - 136s 467ms/step - loss: 0.0949 - binary_accuracy: 0.9598 - val_loss: 0.0704 - val_binary_accuracy: 0.9725
Epoch 16/20
291/291 [==============================] - 133s 456ms/step - loss: 0.0969 - binary_accuracy: 0.9586 - val_loss: 0.0722 - val_binary_accuracy: 0.9708
Epoch 17/20
291/291 [==============================] - 135s 464ms/step - loss: 0.0913 - binary_accuracy: 0.9635 - val_loss: 0.0718 - val_binary_accuracy: 0.9716
Epoch 18/20
291/291 [==============================] - 137s 472ms/step - loss: 0.0915 - binary_accuracy: 0.9639 - val_loss: 0.0727 - val_binary_accuracy: 0.9725
Epoch 19/20
291/291 [==============================] - 134s 460ms/step - loss: 0.0938 - binary_accuracy: 0.9631 - val_loss: 0.0707 - val_binary_accuracy: 0.9733
Epoch 20/20
291/291 [==============================] - 134s 460ms/step - loss: 0.0971 - binary_accuracy: 0.9609 - val_loss: 0.0714 - val_binary_accuracy: 0.9716

<keras.callbacks.History at 0x7f4494e38f70>

Do a round of fine-tuning of the entire model

Finally, let's unfreeze the base model and train the entire model end-to-end with a low learning rate.

Importantly, although the base model becomes trainable, it is still running in inference mode since we passed training=False when calling it when we built the model. This means that the batch normalization layers inside won't update their batch statistics. If they did, they would wreck havoc on the representations learned by the model so far.

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
rescaling (Rescaling)        (None, 150, 150, 3)       0         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,529
Trainable params: 20,809,001
Non-trainable params: 54,528
_________________________________________________________________
Epoch 1/10
291/291 [==============================] - 567s 2s/step - loss: 0.0749 - binary_accuracy: 0.9689 - val_loss: 0.0605 - val_binary_accuracy: 0.9776
Epoch 2/10
291/291 [==============================] - 551s 2s/step - loss: 0.0559 - binary_accuracy: 0.9770 - val_loss: 0.0507 - val_binary_accuracy: 0.9798
Epoch 3/10
291/291 [==============================] - 545s 2s/step - loss: 0.0444 - binary_accuracy: 0.9832 - val_loss: 0.0502 - val_binary_accuracy: 0.9807
Epoch 4/10
291/291 [==============================] - 558s 2s/step - loss: 0.0365 - binary_accuracy: 0.9874 - val_loss: 0.0506 - val_binary_accuracy: 0.9807
Epoch 5/10
291/291 [==============================] - 550s 2s/step - loss: 0.0276 - binary_accuracy: 0.9890 - val_loss: 0.0477 - val_binary_accuracy: 0.9802
Epoch 6/10
291/291 [==============================] - 588s 2s/step - loss: 0.0206 - binary_accuracy: 0.9916 - val_loss: 0.0444 - val_binary_accuracy: 0.9832
Epoch 7/10
291/291 [==============================] - 542s 2s/step - loss: 0.0206 - binary_accuracy: 0.9923 - val_loss: 0.0502 - val_binary_accuracy: 0.9828
Epoch 8/10
291/291 [==============================] - 544s 2s/step - loss: 0.0153 - binary_accuracy: 0.9939 - val_loss: 0.0509 - val_binary_accuracy: 0.9819
Epoch 9/10
291/291 [==============================] - 548s 2s/step - loss: 0.0156 - binary_accuracy: 0.9934 - val_loss: 0.0610 - val_binary_accuracy: 0.9807
Epoch 10/10
291/291 [==============================] - 546s 2s/step - loss: 0.0176 - binary_accuracy: 0.9936 - val_loss: 0.0561 - val_binary_accuracy: 0.9789

<keras.callbacks.History at 0x7f4495056040>

After 10 epochs, fine-tuning gains us a nice improvement here.

 


전이학습의 전체 과정

실습을 하기 전에 전이학습의 전체 과정을 요약해본다면 다음과 같습니다.

1) 사전학습 모델 선택하기

다양하게 공개되어 있는 사전학습 모델 중에서, 내 문제를 푸는 것에 적합해보이는 모델을 선택합니다. 만약 Keras를 이용한다면, 간단하게 바로 여러가지 사전학습 모델을 사용할 수 있습니다. 대표적인 사전학습 모델로는 VGG (Simonyan & Zisserman 2014), Inception (Szegedy et al. 2015), 그리고 ResNet5 (He et al. 2015) 등이 있습니다. 여기를 보시면 Keras가 제공하는 모든 모델을 확인할 수 있습니다.

2) 내 문제가 데이터크기-유사성 그래프에서 어떤 부분에 속하는지 알아보기

[그림 3]은 데이터 크기 데이터 간의 유사성을 기반으로 네 가지 상황을 구분합니다. 이 그래프가 제시하는 네 가지 상황에 따라 내가 속한 상황은 무엇인지 알아보고, 어떤 전략을 써야 하는지를 결정할 수 있습니다. 여기서 작은 데이터셋의 크기는 경험적으로 1000개 이하의 이미지 데이터셋 정도를 말합니다. 데이터셋의 유사성에 대해서는 상식의 선에서 생각해볼 수 있습니다. 예를 들어, 고양이와 강아지를 분류해야 하는 상황이라면, ImageNet 모델은 이미 고양이와 강아지에 대해 학습했기 때문에 내 문제와 유사하다고 할 수 있지만, 암세포를 구분해야 하는 상황이라면, ImageNet은 유사한 데이터셋을 학습한 모델이라고 볼 수 없을 것입니다.

3) 내 모델을 Fine-tuning 하기

위에서 데이터 크기와 유사성을 고려해서 내가 어떤 상황에 속하는지 알아보았다면, 이제 각 상황에 따라 구체적으로 어떻게 Fine-tuning을 진행해야 하는지 결정합니다. [그림 4]는 앞서 알아본 사전학습 모델을 재정의하는 세 가지 전략을 각각 어떤 상황에 적용시켜야 하는지 보여줍니다. 다음은 각 상황에 대한 세부 설명입니다.

  • Quadrant 1 : 크기가 크고 유사성이 작은 데이터셋일 때
    이 경우에는 [전략 1]이 적합합니다. 데이터셋의 크기가 크므로, 모델을 다시 처음부터 내가 원하는 대로 완전히 다시 학습시킬 수 있습니다. 비록 유사성이 거의 없는 데이터로 새로 학습을 시켜야 한다고 해도, 사전 학습 모델의 구조와 파라미터들을 가지고 시작하는 것은 아예 처음 시작하는 것보다 유용할 것입니다.
  • Quadrant 2 : 크기가 크고 유사성도 높은 데이터셋일 때
    Here you are in la-la land. 라라랜드입니다! 어떤 옵션을 선택해도 괜찮은 상황입니다. 그 중에서도, 가장 효과적인 옵션은 [전략 2]로 생각됩니다. 데이터셋의 크기가 커서 오버피팅은 문제가 안 될 것이므로, 우리가 원하는 만큼 학습을 시켜도 됩니다. 하지만, 데이터셋이 유사하다는 이점이 있으므로 모델이 이전에 학습한 지식을 활용하지 않을 이유가 없습니다. 따라서, classifier와 convolutional base의 높은 레벨 계층(뒷단의 계층) 일부만 학습시켜도 충분할 것입니다.
  • Quadrant 3 : 크기가 작고 유사성도 작은 데이터셋일 때
    가장 좋지 않은 상황입니다. 상황을 바꿔보는 게 가장 좋은 옵션일테지만 불가능하다면, 시도해 볼 수 있는 건 [전략 2]뿐입니다. 하지만 convolutoinal base의 계층 중 몇 개의 계층을 학습시키고 몇 개를 얼려야(그대로 두어야) 하는지를 알아내는 것은 어렵습니다. 너무 많은 계층을 새로 학습시키면 작은 데이터셋에 모델이 과적합될 우려가 있고, 너무 적은 계층만을 학습시키면 모델은 제대로 학습되지 않을 것입니다. 아마, Quadrant 2의 상황에서보다는 조금 더 깊은 계층까지 새로 학습을 시킬 필요가 있을 겁니다. 또한, 작은 크기의 데이터셋을 보완하기 위해서 data augmentation(데이터를 왜곡 또는 변형시켜서 데이터셋의 크기를 증가시키는 것)과 같은 테크닉에 대해 알아보야 할 것입니다. (data augmentation 기술에 대해 설명해놓은 좋은 자료는 여기에서 확인하실 수 있습니다.)
  • Quadrant 4 : 크기가 작지만 유사성은 높은 데이터셋일 때
    이 상황에서는 [전략 3]이 최선의 선택일 것입니다. 이 경우에는 사전학습모델의 마지막 부분, classifier만 삭제하고 기존의 convolutional base는 특징 추출기로써 사용하고, 추출된 특징을 새 classifier에 넣어서 분류할 수 있도록 학습시키면 됩니다. 즉, 우리가 해야 할 일은 새 classifier만 학습시키는 것입니다! ㅤ


[그림 3] 데이터크기-유사성 그래프ㅤㅤㅤㅤㅤㅤㅤㅤㅤㅤ ㅤㅤㅤㅤ[그림 4] 각 상황에 따른 Fine-tuning 방법

 

 

https://jeinalog.tistory.com/13

728x90
반응형
LIST
Comments