본문 바로가기

AI/TensorFlow 학습

24. TensorFlow 2.x - callback 함수 소개

참고: https://youtu.be/5cmPhp0Kz9s?si=3Vqiqbcw7_3Mzxoc

 

콜백 함수는 워낙 많이 쓰이는 개념이니 경력이 어느정도 있는 개발자들은 알고 있는 내용이지만,

이렇게 다들 안다고 하는 내용은 더욱 묻기도 힘들기도 하다는 것을 잘 알고 있어 소개한다.

위의 슬라이드에서 설명하듯, 매번 특정 이벤트 발생하면 어떤 함수를 실행해 해달라고 등록할 수 있는데, 이런 함수를 콜백 함수라고 한다.

 

TensorFlow에서도 역시 콜백 함수를 지원하는데, 일정 조건에서 학습을 멈춘다든가, 저장을 한다는가 하는 것이다.

이전 예제에서 나왔던 ModelCheckpoint(), EarlyStopping()등이 그 예이다. 

 

 

ReduceLROnPlateau(monitor, factor, patience, verbose): 특정 조건에서 학습율(LR)을 줄이는 콜백 함수

monitor: 관찰 대상

factor: 줄이는 비율 (<1)

patience: 관찰 기간 학습 회수

verbose: 로그 출력

위의 예제는 'val_loss'가 5회 학습동안 개선되지 않으면 현재의 학습률을 0.5배로 줄인다는 의미가 된다.

그리고, 콜백함수를 등록하는 방법은 fit()에서 callbacks=[콜백함수명]으로 인자를 넣는다.

 

ModelCheckpoint(file_path, monitor, verbose, save_best_only, mode): 특정 조건에서 학습된 모델을 저장.

file_path: 모델 저장 위치

monitor: 관찰 대상

verbose: 로그 출력

save_best_only: 학습중 최상의 모델을 저장

mode: 저장 모델을 찾는 방법

위의 예제는 학습중 'val_loss'가 개선되었을 때 자동으로 최상의 모델을 저장하게 하는 것이다.

 

EarlyStopping(monitor, patience): 특정 조건에서 학습을 멈춤

monitor: 관찰 대상

patience: 관찰 기간 학습 회수

위의 예제는, 'val_loss'가 5회 학습동안 개선이 없으면 멈추게 하는 것이다.

 

이제 실제 예제 코드를 보자.

이전 글에서 다루었던 mnist 예제를 가지고 콜백 함수를 적용해 보는 것이다.

https://firstmove.tistory.com/25

 

7. TensorFlow 2.x - Neural Network MNIST 예제

참고: https://youtu.be/AyvicBsP8tE?si=037NR0QlKNBz2ve7  영상에서는 MNIST문제를 모든 컴퓨터 언어에서 처음 배우는 "Hello world!"와 같은 수준이라고 소개한다. ㅎㅎㅎ이미지 데이터를 입력으로 사용할 때는

firstmove.tistory.com

 

받아온 mnist 이미지들은 위와 같다.

 

정규화를 하고, label을 원핫 인코딩을 진행한다.

 

모델을 위와 같이 만들었다. 여기서 원핫 인코딩을 했기 때문에 'categorical_crossentropy'를 사용하게 되었다.

 

ModelCheckpoint()와 EarlyStopping()을 이용하여 callback함수를 등록하고 학습하였다.

log를 통해서 epoch 1에서 관찰 대상인 'val_loss'가 개선되어 저장되는 것을 확인할 수 있다.

 

Epoch 6, 7, 8에서는 'val_loss'가 개선되지 않아 등록된 EarlyStopping 조건에 의해서 학습을 멈추었다.

 

위의 예제처럼 새로운 모델을 학습시킬 때  조금이라도 자원을 덜 사용하는 방법들을 제공하고 있다.

 

다음 글에서는 Kaggle 데이터를 이용하여 모델을 학습하고 적용해보자.