텐서플로우 Checkpoint


텐서플로우 Checkpoint
텐서플로우에서 Estimator를 이용하여 개발된 모델을 저장하고 복원할 때 사용되는 Checkpoint에 대해서 설명드리도록 하겠습니다. 모델을 저장하고 복원하는 방법과 함께 Checkpoint 구성 파일과 시각화 방법 등을 함께 설명드리도록 하겠습니다.

텐서플로우 Checkpoint

텐서플로우에서는 2가지 종류의 모델 포맷을 제공하고 있습니다. 하나는 Checkpoint 으로 모델 개발에 사용된 코드에 의존되는 포맷입니다. 그리고 다른 하나의 포맷은 SavedModel으로 모델 개발에 사용된 코드에 의존하지 않는 포맷입니다.

이 글에서는 상위 레벨의 API인 Estimator 모델에서 사용되는 checkpoints 포맷에 위주로 설명드리도록 하겠습니다.

텐서플로우 Checkpoint 예제 코드

이 글에서 사용되는 예제 코드는 아래의 명령어로 다운 받으실 수 있습니다.

$ git clone https://github.com/tensorflow/models/

models/samples/core/get_started/premade_estimator.py 파일에서 소스코드를 참고하였습니다. 이 예제 파일은 Iris 데이터셋을 분류하는 모델로 보다 자세한 내용은 아래의 글을 참고해 주시기 바랍니다.

텐서플로우 Iris 예제 튜토리얼

학습 모델 저장

Estimator 모델은 학습을 하는 동안 학습하는 동안 모델 버전을 기록한 파일과, 이벤트 정보, 모델 weight 등이 모델 디렉터리에 저장됩니다. 모델이 저장되는 위치는 아래와 같이 모델을 생성할 때 model_dir 파라미터로 지정하면 됩니다.

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris')

그리고 아래와 같이 학습을 수행하는 코드를 실행하면, 자동으로 모델 파일이 생성되어 저장되게 됩니다.

classifier.train(
        input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
                steps=200)

기본 checkpoint 디렉터리

모델을 생성할 때 모델이 저장될 디렉터리를 지정하지 않게 되면, 기본값으로 /tmp 위치에서 디렉터리가 생성되어 저장하게 됩니다.

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3)

print(classifier.model_dir)

생성된 기본 디렉터리는 위의 코드로 확인 할 수 있습니다.

checkpoint 저장 주기

기본적으로 Estimator 모델은 아래의 주기로 Checkpoint를 저장하게 됩니다.

  • Writes a checkpoint every 10 minutes (600 seconds).
  • Writes a checkpoint when the train method starts (first iteration) and completes (final iteration).
  • Retains only the 5 most recent checkpoints in the directory.

위의 기본설정은 아래와 tf.estimator.RunConfig 를 정의하여 모델에 전달하면서 변경할 수 있습니다.

my_checkpointing_config = tf.estimator.RunConfig(
    save_checkpoints_secs = 20*60,  # Save checkpoints every 20 minutes.
    keep_checkpoint_max = 10,       # Retain the 10 most recent checkpoints.
)

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris',
    config=my_checkpointing_config)

Checkpoint 복원

처음으로 Estimator 모델의 train 함수를 호출하게 되면 checkpoint를 model_dir 위치에 저장합니다. 그리고 train, evaluate 그리고 predict 함수를 호출하게 되면 저장된 checkpoint 부터 weight 값 등을 초기화 합니다.

추가적인 checkpoint 복원 코드 없이 이어서 train 함수를 호출하여 학습하거나, evaluate, predict 함수 등을 호출 하여 사용 할 수 있습니다.

Checkpoint 시각화

아래의 코드는 모델이 저장된 디렉터리의 events 파일의 정보를 추출하여 시각화 하는 에제입니다. 학습 단계에서 저장된 step 과 loss 값을 그래프로 출력하게 됩니다.

import tensorflow as tf
import numpy as np
import glob
import matplotlib.pyplot as plt

paths = glob.glob("models/iris/events.out*.*")
path = paths[0]

x = []
y = []

for e in tf.train.summary_iterator(path):
  for v in e.summary.value:
      if v.tag == 'loss':
        x.append(e.step)
        y.append(v.simple_value)
print (x)
print (y)

plt.title("test")
plt.plot(x, y)

plt.show()

위의 예제를 실행한 결과는 아래와 같습니다.

텐서플로우 Checkpoint

( 본문 인용시 출처를 밝혀 주시면 감사하겠습니다.)