TensorFlow 1.8
텐서플로우 1.8 버전의 세부적인 기능과 업데이트 내용을 설명드립니다. Gradient Boosted Trees의 Classifier, Regressor Estimator 모델이 추가되었고, GPU 기반 환경에서 성능이 높아졌습니다. 그리고 generic proto 파싱과 RPC 통신 기능이 추가되었습니다.
주요 내용 외에 몇가지 버그들이 수정되었고 tf.GradientTape 등이 contrib 위치에서 이동었습니다. 그리고 CSV 파일을 dataset으로 바인딩 하기 위해서 tf.contrib.data.make_csv_dataset가 추가되는 tf.data 모듈도 다소 수정이 되었습니다. 신규 기능인 CSV 파일을 Dataset으로 바인딩 하는 함수를 예제를 통해 설명드리도록 하겠습니다.
설치
아래의 명령어로 텐서플로우 1.8 버전으로 설치 또는 업그레이드 가능합니다.
# CPU버전
$ sudo -H pip install --upgrade tensorflow
# GPU 버전일 경우
$ sudo -H pip install --upgrade tensorflow-gpu
직접 소스코드를 빌드할 경우 아래의 글을 참고해 주시기 바랍니다. 현재 소스코드에서 태깅된 버전은 v1.8.0 입니다.
주요 내용
- 여러개의 GPU를 갖는 하나의 시스템에서 Estimator 모델을 실행시키기 위해서 tf.contrib.distribute.MirroredStrategy()에서 tf.estimator.RunConfig() 으로 전달할 수 있게 되었습니다.
- Dataset elements를 GPU 메모리로 prepatch 할 수 있는 tf.contrib.data.prefetch_to_device()가 추가되었습니다.
- Gradient Boosted Trees의 Classifier와 Regressor가 각각 BoostedTreesClassifier, BoostedTreesRegressor 라는 이름으로 Estimator 모델이 추가되었습니다.
- Cloud TPU를 위한 3세대 파이프라인 설정이 사용하기 편리하게 되었고 성능이 높아졌습니다.
- tf.contrib.bayesflow 가 이동되었고, tf.contrib.{proto,rpc}가 generic proto 파싱과 RPC 통신을 위해서 추가되었습니다.
보다 자세한 내용은 아래 페이지 내용을 참고해 주시기 바랍니다.
https://github.com/tensorflow/tensorflow/releases
버그 및 기능 추가
multi-image Estimator 의 실행 요약 내용이 올바르게 출력될 수 있도록 버그 등이 수정되었고 tf.GradientTape 등이 contrib 위치에서 이동었습니다. 그리고 CSV 파일을 dataset으로 바인딩 하기 위해서 tf.contrib.data.make_csv_dataset가 추가되는 tf.data 모듈도 다소 수정이 되었습니다.
신규 기능 테스트
텐서플로우 1.8 에서 tf.contrib.data.make_csv_dataset 가 추가되었는데 CSV 파일로 부터 dataset으로 바인딩 하는 예제를 학습해 보도록 하겠습니다. 아래의 내용을 data.csv 파일을 생성해 주시기 바랍니다.
a,b,c,d
1,2,3,4
5,6,7,8
그리고 아래의 내용으로 python 스크립트를 생성합니다.
import tensorflow as tf
dataset = tf.contrib.data.make_csv_dataset("data.csv", 1)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
print(sess.run(next_element))
print(sess.run(next_element))
위의 내용을 실행하면 아래와 같이 CSV 파일로부터 데이터를 읽은 뒤에 내용이 출력되는 것을 알 수 있습니다.
{'a': array([5], dtype=int32), 'c': array([7], dtype=int32), 'b': array([6], dtype=int32), 'd': array([8], dtype=int32)}
{'a': array([1], dtype=int32), 'c': array([3], dtype=int32), 'b': array([2], dtype=int32), 'd': array([4], dtype=int32)}
( 본문 인용시 출처를 밝혀 주시면 감사하겠습니다.)