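Peter Roelants, head of research at the background-check company Onfido, published an article on Medium titled "Higher-Level APIs in TensorFlow". It walks through an example of how to train a model with TensorFlow's high-level APIs (Estimator, Experiment and Dataset). Notably, Experiment and Dataset can each be used on their own. These high-level APIs are included in the recently released TensorFlow 1.3.
TensorFlow has many popular libraries, such as Keras, TFLearn and Sonnet, that make it easy to train models without touching the low-level functions. The Keras API is currently being implemented directly in TensorFlow, and TensorFlow itself is providing more and more high-level constructs, some of which are included in the recently released TensorFlow 1.3.
[Figure: The Experiment, Estimator and Dataset frameworks and how they interact; each component is explained below.]
In this article we use MNIST as the dataset. It is an easy-to-use dataset that can be accessed through TensorFlow. The complete example code is available in a gist and is reproduced at the end of this article. One benefit of using these frameworks is that we do not have to deal with graphs and sessions directly.
Estimator
The Estimator class represents a model, together with how that model is trained and evaluated. We can construct an estimator like this: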
return tf.estimator.Estimator(
    model_fn=model_fn,  # First-class function
    params=params,  # HParams
    config=run_config  # RunConfig
)
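To construct an Estimator, we need to pass it a model function, a collection of parameters and some configuration.
The parameters should be a collection of the model's hyperparameters. This can be a dictionary, but in this example we represent it as an HParams object, which behaves like a namedtuple.
The configuration specifies how training and evaluation are run and where the results are stored. It is represented by a RunConfig object, which tells the Estimator everything it needs to know about the environment the model runs in.
The model function is a Python function that builds the model for a given input (more on this below).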
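To make "a collection of hyperparameters that behaves like a namedtuple" concrete, here is a minimal sketch in plain Python. It does not use TensorFlow: `HyperParams` is a hypothetical stand-in name, not the HParams class itself, and the field names and values simply mirror the ones used in the complete example at the end of this article.

```python
from collections import namedtuple

# Hypothetical stand-in for an HParams-like container (not the TensorFlow class).
HyperParams = namedtuple(
    "HyperParams",
    ["learning_rate", "n_classes", "train_steps", "min_eval_frequency"])

params = HyperParams(
    learning_rate=0.002,
    n_classes=10,
    train_steps=5000,
    min_eval_frequency=100)

# Values are read by attribute, which is how the model function uses them.
print(params.learning_rate)  # 0.002

# A modified copy can be derived without mutating the original.
faster = params._replace(learning_rate=0.01)
print(faster.learning_rate, params.learning_rate)  # 0.01 0.002
```

Grouping the values in one object means a single argument can be handed from the script to the Estimator and on to the model function.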
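The model function
The model function is a Python function that is passed to the Estimator as a first-class function. As we will see later, TensorFlow uses first-class functions in other places too. The benefit of representing the model as a function is that the model can be rebuilt again and again by calling the function. The model can be recreated during training with different inputs, for example to run validation tests while training.
The model function takes the input features as a parameter and the corresponding labels as tensors. It also takes a mode that signals whether the model is training, evaluating or performing inference. The last parameter of the model function is a collection of hyperparameters, the same ones that were passed to the Estimator. The model function must return an EstimatorSpec object, which defines the complete model.
EstimatorSpec takes the prediction, loss, training and evaluation operations, so it defines the full model graph used for training, evaluation and inference. Because EstimatorSpec takes regular TensorFlow operations, we can use a framework such as TF-Slim to define our model.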
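The contract described above (features, labels, mode and params go in, a spec comes out) can be sketched without TensorFlow. Everything in this block is a simplified assumption for illustration: `Spec`, the mode constants and the toy linear model are hypothetical stand-ins, not the real EstimatorSpec or ModeKeys.

```python
from collections import namedtuple

# Hypothetical, simplified stand-ins; these are not TensorFlow classes.
Spec = namedtuple("Spec", ["mode", "predictions", "loss", "train_op"])
TRAIN, EVAL, PREDICT = "train", "eval", "infer"


def model_fn(features, labels, mode, params):
    """Build a toy linear 'model' y = w * x for the requested mode."""
    w = params["initial_weight"]
    predictions = [w * x for x in features]

    loss = None
    train_op = None
    if mode != PREDICT:
        # Labels are only available when training or evaluating.
        squared_errors = [(p - y) ** 2 for p, y in zip(predictions, labels)]
        loss = sum(squared_errors) / len(labels)
    if mode == TRAIN:
        # Only the training mode gets an update step.
        train_op = "apply gradient with lr={}".format(params["learning_rate"])

    return Spec(mode=mode, predictions=predictions, loss=loss, train_op=train_op)


params = {"initial_weight": 2.0, "learning_rate": 0.1}
train_spec = model_fn([1.0, 2.0], [2.0, 5.0], TRAIN, params)
infer_spec = model_fn([3.0], None, PREDICT, params)
print(train_spec.loss)         # 0.5
print(infer_spec.predictions)  # [6.0]
```

Because the model is a function, the caller can invoke it once per mode and get a freshly built spec each time, which is the property the text attributes to passing the model as a first-class function.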
Experiment
The Experiment class defines how to train a model and integrates with the Estimator. We can create an experiment like this:
experiment = tf.contrib.learn.Experiment(
    estimator=estimator,  # Estimator
    train_input_fn=train_input_fn,  # First-class function
    eval_input_fn=eval_input_fn,  # First-class function
    train_steps=params.train_steps,  # Minibatch steps
    min_eval_frequency=params.min_eval_frequency,  # Eval frequency
    train_monitors=[train_input_hook],  # Hooks for training
    eval_hooks=[eval_input_hook],  # Hooks for evaluation
    eval_steps=None  # Use evaluation feeder until it's empty
)
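An Experiment takes as input:
An Estimator (for example, the one defined above).
Training and evaluation hooks. These hooks can be used to monitor or save specific things, or to set up certain operations in the graph or session. For example, we will pass in operations that help initialize the data loaders.
Various parameters that describe how long to train and when to evaluate.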
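The interplay between the number of training steps and the evaluation frequency can be pictured with a small loop. This is only a simplified sketch in plain Python under stated assumptions: `train_and_evaluate`, `toy_train_step` and `toy_evaluate` are hypothetical names, and the loop does not reproduce how tf.contrib.learn.Experiment schedules evaluation internally.

```python
def train_and_evaluate(train_step, evaluate, train_steps, min_eval_frequency):
    """Run `train_steps` steps, evaluating every `min_eval_frequency` steps."""
    eval_log = []
    for step in range(1, train_steps + 1):
        train_step(step)
        if step % min_eval_frequency == 0:
            eval_log.append((step, evaluate()))
    return eval_log


state = {"steps_done": 0}


def toy_train_step(step):
    # Stand-in for running one minibatch update.
    state["steps_done"] = step


def toy_evaluate():
    # Pretend metric that improves with the number of steps taken.
    return state["steps_done"] / 10.0


log = train_and_evaluate(toy_train_step, toy_evaluate,
                         train_steps=10, min_eval_frequency=4)
print(log)  # [(4, 0.4), (8, 0.8)]
```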
Once we have defined the experiment, we can run it with learn_runner.run to train and evaluate the model:
learn_runner.run(
    experiment_fn=experiment_fn,  # First-class function
    run_config=run_config,  # RunConfig
    schedule="train_and_evaluate",  # What to run
    hparams=params  # HParams
)
Just like the model function and the data functions, the learn runner takes the function that creates the experiment as a parameter.
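The idea of handing the runner a function that creates the experiment, plus the name of what to run, can be sketched as follows. This is a hypothetical illustration in plain Python: `ToyExperiment` and this `run` function are made up for the sketch and are not the learn_runner implementation.

```python
class ToyExperiment(object):
    """Hypothetical experiment exposing one method per schedule name."""

    def __init__(self, run_config, hparams):
        self.run_config = run_config
        self.hparams = hparams

    def train_and_evaluate(self):
        return "trained {} steps, output in {}".format(
            self.hparams["train_steps"], self.run_config["model_dir"])


def experiment_fn(run_config, hparams):
    """Factory that the runner calls to build the experiment."""
    return ToyExperiment(run_config, hparams)


def run(experiment_fn, run_config, schedule, hparams):
    """Build the experiment via the factory, then call the method named by `schedule`."""
    experiment = experiment_fn(run_config, hparams)
    return getattr(experiment, schedule)()


result = run(
    experiment_fn=experiment_fn,
    run_config={"model_dir": "./mnist_training"},
    schedule="train_and_evaluate",
    hparams={"train_steps": 5000})
print(result)  # trained 5000 steps, output in ./mnist_training
```

Passing the factory rather than a finished object leaves the runner in control of when, and with which configuration, the experiment is built.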
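Dataset
We will use the Dataset class and the corresponding Iterator to represent our training and evaluation data, and to create data feeders that iterate over the data during training. In this example we use the MNIST data available in TensorFlow and build a Dataset wrapper around it. For example, we represent the training input data as: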
# Define the training inputs
def get_train_inputs(batch_size, mnist_data):
    """Return the input function to get the training data.

    Args:
        batch_size (int): Batch size of training iterator that is returned
                          by the input function.
        mnist_data (Object): Object holding the loaded mnist data.

    Returns:
        (Input function, IteratorInitializerHook):
            - Function that returns (features, labels) when called.
            - Hook to initialise input iterator.
    """
    iterator_initializer_hook = IteratorInitializerHook()

    def train_inputs():
        """Returns training set as Operations.

        Returns:
            (features, labels) Operations that iterate over the dataset
            on every evaluation
        """
        with tf.name_scope("Training_data"):
            # Get Mnist data
            images = mnist_data.train.images.reshape([-1, 28, 28, 1])
            labels = mnist_data.train.labels
            # Define placeholders
            images_placeholder = tf.placeholder(
                images.dtype, images.shape)
            labels_placeholder = tf.placeholder(
                labels.dtype, labels.shape)
            # Build dataset iterator
            dataset = tf.contrib.data.Dataset.from_tensor_slices(
                (images_placeholder, labels_placeholder))
            dataset = dataset.repeat(None)  # Infinite iterations
            dataset = dataset.shuffle(buffer_size=10000)
            dataset = dataset.batch(batch_size)
            iterator = dataset.make_initializable_iterator()
            next_example, next_label = iterator.get_next()
            # Set runhook to initialize iterator
            iterator_initializer_hook.iterator_initializer_func = \
                lambda sess: sess.run(
                    iterator.initializer,
                    feed_dict={images_placeholder: images,
                               labels_placeholder: labels})
            # Return batched (features, labels)
            return next_example, next_label

    # Return function and hook
    return train_inputs, iterator_initializer_hook
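Calling get_train_inputs returns a first-class function that creates the data-loading operations in the TensorFlow graph, together with a hook that initializes the iterator.
In this example, the MNIST data we use is initially represented as NumPy arrays. We create placeholder tensors to feed the data in; using placeholders avoids the data being copied. Next, we create a sliced dataset with the help of from_tensor_slices. We make sure this dataset repeats indefinitely (the experiment takes care of the number of epochs), that the data is shuffled, and that it is split into batches of the required size.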
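The repeat, shuffle and batch steps described above can be mimicked with plain Python generators to show what each stage contributes. This is a conceptual sketch only: the `repeat`, `shuffle` and `batch` functions below are hypothetical stand-ins written for this illustration, and the buffered shuffle is an assumption about the general technique rather than a copy of the Dataset implementation.

```python
import itertools
import random


def repeat(items):
    """Yield the items over and over, like an unbounded repeat."""
    return itertools.cycle(items)


def shuffle(stream, buffer_size, rng):
    """Buffered shuffle: hold `buffer_size` items and emit a random one each time."""
    buffer = []
    for item in stream:
        buffer.append(item)
        if len(buffer) > buffer_size:
            yield buffer.pop(rng.randrange(len(buffer)))
    # Drain what is left once the input stream ends (never reached if infinite).
    while buffer:
        yield buffer.pop(rng.randrange(len(buffer)))


def batch(stream, batch_size):
    """Group consecutive items into lists of at most `batch_size`."""
    stream = iter(stream)
    while True:
        chunk = list(itertools.islice(stream, batch_size))
        if not chunk:
            return
        yield chunk


examples = list(range(10))
rng = random.Random(0)
pipeline = batch(shuffle(repeat(examples), buffer_size=4, rng=rng), batch_size=3)

# The stream is infinite, so the consumer decides how many batches to take.
first_batches = list(itertools.islice(pipeline, 5))
print(len(first_batches), [len(b) for b in first_batches])  # 5 [3, 3, 3, 3, 3]
```

Since the repeated stream never ends, something outside the pipeline has to decide when to stop, which is the role the text assigns to the experiment.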
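To iterate over the data, we need to create an iterator from the dataset. Because we are using placeholders, we need to initialize the placeholders in the relevant session with the NumPy data. We can do this by creating an initializable iterator. When the graph is built, we create a custom IteratorInitializerHook object to initialize the iterator: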
class IteratorInitializerHook(tf.train.SessionRunHook):
    """Hook to initialise data iterator after Session is created."""

    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_func = None

    def after_create_session(self, session, coord):
        """Initialise the iterator after the session has been created."""
        self.iterator_initializer_func(session)
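IteratorInitializerHook inherits from SessionRunHook. Once the relevant session has been created, the hook's after_create_session method is called, which initializes the placeholders with the right data. This hook is returned by the get_train_inputs function and is passed to the Experiment object when it is created.
The data-loading operations returned by the train_inputs function are TensorFlow operations that return a new batch every time they are evaluated.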
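The order of events (the hook is created first, its callback is filled in when the input function runs, and the callback fires only after a session exists) is easy to lose track of. The following plain-Python sketch replays that sequence. All names in it (`ToySession`, `InitializerHook`, `get_inputs`, the string "operations") are hypothetical stand-ins and not TensorFlow objects.

```python
class ToySession(object):
    """Hypothetical session that records what was run and with which feed."""

    def __init__(self):
        self.calls = []

    def run(self, op, feed_dict=None):
        self.calls.append((op, feed_dict))


class InitializerHook(object):
    """Holds an initializer callback that is filled in later."""

    def __init__(self):
        self.initializer_func = None

    def after_create_session(self, session):
        self.initializer_func(session)


def get_inputs(data):
    hook = InitializerHook()

    def input_fn():
        # The callback is only set when the input function is called,
        # i.e. while the "graph" is being built.
        hook.initializer_func = lambda sess: sess.run(
            "iterator_initializer", feed_dict={"placeholder": data})
        return "next_batch_op"

    return input_fn, hook


input_fn, hook = get_inputs(data=[1, 2, 3])
next_batch = input_fn()              # 1. build the input operations
session = ToySession()               # 2. the framework creates a session
hook.after_create_session(session)   # 3. the hook feeds the data once
print(session.calls)  # [('iterator_initializer', {'placeholder': [1, 2, 3]})]
```

The closure is what lets the hook reach objects that only come into existence inside the input function.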
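Running the code
Now that we have defined everything, we can run the code with the following command:
python mnist_estimator.py --model_dir ./mnist_training --data_dir ./mnist_data
If you do not pass any arguments, it uses the default flags at the top of the file to decide where to save the data and the model. Training prints information such as the global step, loss and accuracy to the terminal. In addition, the Experiment and Estimator frameworks log certain statistics that TensorBoard can display. If we run:
tensorboard --logdir="./mnist_training"
we can see all the training statistics, such as the training loss, evaluation accuracy, time per step and the model graph.
[Figure: Evaluation accuracy visualized in TensorBoard.]
There are few examples of the Estimator, Experiment and Dataset frameworks in TensorFlow, which is why this article exists. Hopefully it gives an introduction to how these frameworks work, which abstractions they adopt, and how to use them. If you are interested in them, further related documents are listed below.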
Notes on Estimator, Experiment and Dataset
Paper "TensorFlow Estimators: Managing Simplicity vs. Flexibility in High-Level Machine Learning Frameworks": https://terrytangyuan.github.io/data/papers/tf-estimators-kdd-paper.pdf
Using the Dataset API for TensorFlow Input Pipelines: https://www.tensorflow.org/versions/r1.3/programmers_guide/datasets
tf.estimator.Estimator: https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator
tf.contrib.learn.RunConfig: https://www.tensorflow.org/api_docs/python/tf/contrib/learn/RunConfig
tf.estimator.DNNClassifier: https://www.tensorflow.org/api_docs/python/tf/estimator/DNNClassifier
tf.estimator.DNNRegressor: https://www.tensorflow.org/api_docs/python/tf/estimator/DNNRegressor
Creating Estimators in tf.estimator: https://www.tensorflow.org/extend/estimators
tf.contrib.learn.Head: https://www.tensorflow.org/api_docs/python/tf/contrib/learn/Head
The Slim framework used in this article: https://github.com/tensorflow/models/tree/master/slim
Complete example
"""Script to illustrate usage of tf.estimator.Estimator in TF v1.3"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data as mnist_data
from tensorflow.contrib import slim
from tensorflow.contrib.learn import ModeKeys
from tensorflow.contrib.learn import learn_runner
# Show debugging output
tf.logging.set_verbosity(tf.logging.DEBUG)
# Set default flags for the output directories
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string(
    flag_name="model_dir", default_value="./mnist_training",
    docstring="Output directory for model and training stats.")
tf.app.flags.DEFINE_string(
    flag_name="data_dir", default_value="./mnist_data",
    docstring="Directory to download the data to.")
# Define and run experiment ###############################
def run_experiment(argv=None):
    # Define model parameters
    params = tf.contrib.training.HParams(
        learning_rate=0.002,
        n_classes=10,
        train_steps=5000,
        min_eval_frequency=100
    )
    # Set the run_config and the directory to save the model and stats
    run_config = tf.contrib.learn.RunConfig()
    run_config = run_config.replace(model_dir=FLAGS.model_dir)
    learn_runner.run(
        experiment_fn=experiment_fn,  # First-class function
        run_config=run_config,  # RunConfig
        schedule="train_and_evaluate",  # What to run
        hparams=params  # HParams
    )
def experiment_fn(run_config, params):
    """Create an experiment to train and evaluate the model.

    Args:
        run_config (RunConfig): Configuration for Estimator run.
        params (HParam): Hyperparameters

    Returns:
        (Experiment) Experiment for training the mnist model.
    """
    # You can change a subset of the run_config properties as
    run_config = run_config.replace(
        save_checkpoints_steps=params.min_eval_frequency)
    # Define the mnist classifier
    estimator = get_estimator(run_config, params)
    # Setup data loaders
    mnist = mnist_data.read_data_sets(FLAGS.data_dir, one_hot=False)
    train_input_fn, train_input_hook = get_train_inputs(
        batch_size=128, mnist_data=mnist)
    eval_input_fn, eval_input_hook = get_test_inputs(
        batch_size=128, mnist_data=mnist)
    # Define the experiment
    experiment = tf.contrib.learn.Experiment(
        estimator=estimator,  # Estimator
        train_input_fn=train_input_fn,  # First-class function
        eval_input_fn=eval_input_fn,  # First-class function
        train_steps=params.train_steps,  # Minibatch steps
        min_eval_frequency=params.min_eval_frequency,  # Eval frequency
        train_monitors=[train_input_hook],  # Hooks for training
        eval_hooks=[eval_input_hook],  # Hooks for evaluation
        eval_steps=None  # Use evaluation feeder until it's empty
    )
    return experiment
# Define model ############################################
def get_estimator(run_config, params):
    """Return the model as a Tensorflow Estimator object.

    Args:
        run_config (RunConfig): Configuration for Estimator run.
        params (HParams): hyperparameters.
    """
    return tf.estimator.Estimator(
        model_fn=model_fn,  # First-class function
        params=params,  # HParams
        config=run_config  # RunConfig
    )

def model_fn(features, labels, mode, params):
    """Model function used in the estimator.

    Args:
        features (Tensor): Input features to the model.
        labels (Tensor): Labels tensor for training and evaluation.
        mode (ModeKeys): Specifies if training, evaluation or prediction.
        params (HParams): hyperparameters.

    Returns:
        (EstimatorSpec): Model to be run by Estimator.
    """
    is_training = mode == ModeKeys.TRAIN
    # Define model's architecture
    logits = architecture(features, is_training=is_training)
    predictions = tf.argmax(logits, axis=-1)
    loss = tf.losses.sparse_softmax_cross_entropy(
        labels=tf.cast(labels, tf.int32),
        logits=logits
    )
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=get_train_op_fn(loss, params),
        eval_metric_ops=get_eval_metric_ops(labels, predictions)
    )
def get_train_op_fn(loss, params):
    """Get the training Op.

    Args:
        loss (Tensor): Scalar Tensor that represents the loss function.
        params (HParams): Hyperparameters (needs to have `learning_rate`)

    Returns:
        Training Op
    """
    return tf.contrib.layers.optimize_loss(
        loss=loss,
        global_step=tf.contrib.framework.get_global_step(),
        optimizer=tf.train.AdamOptimizer,
        learning_rate=params.learning_rate
    )

def get_eval_metric_ops(labels, predictions):
    """Return a dict of the evaluation Ops.

    Args:
        labels (Tensor): Labels tensor for training and evaluation.
        predictions (Tensor): Predictions Tensor.

    Returns:
        Dict of metric results keyed by name.
    """
    return {
        "Accuracy": tf.metrics.accuracy(
            labels=labels,
            predictions=predictions,
            name="accuracy")
    }
def architecture(inputs, is_training, scope="MnistConvNet"):
    """Return the output operation following the network architecture.

    Args:
        inputs (Tensor): Input Tensor
        is_training (bool): True iff in training mode
        scope (str): Name of the scope of the architecture

    Returns:
        Logits output Op for the network.
    """
    with tf.variable_scope(scope):
        with slim.arg_scope(
                [slim.conv2d, slim.fully_connected],
                weights_initializer=tf.contrib.layers.xavier_initializer()):
            net = slim.conv2d(inputs, 20, [5, 5], padding="VALID",
                              scope="conv1")
            net = slim.max_pool2d(net, 2, stride=2, scope="pool2")
            net = slim.conv2d(net, 40, [5, 5], padding="VALID",
                              scope="conv3")
            net = slim.max_pool2d(net, 2, stride=2, scope="pool4")
            net = tf.reshape(net, [-1, 4 * 4 * 40])
            net = slim.fully_connected(net, 256, scope="fn5")
            net = slim.dropout(net, is_training=is_training,
                               scope="dropout5")
            net = slim.fully_connected(net, 256, scope="fn6")
            net = slim.dropout(net, is_training=is_training,
                               scope="dropout6")
            net = slim.fully_connected(net, 10, scope="output",
                                       activation_fn=None)
        return net
# Define data loaders #####################################
class IteratorInitializerHook(tf.train.SessionRunHook):
    """Hook to initialise data iterator after Session is created."""

    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_func = None

    def after_create_session(self, session, coord):
        """Initialise the iterator after the session has been created."""
        self.iterator_initializer_func(session)
# Define the training inputs
def get_train_inputs(batch_size, mnist_data):
    """Return the input function to get the training data.

    Args:
        batch_size (int): Batch size of training iterator that is returned
                          by the input function.
        mnist_data (Object): Object holding the loaded mnist data.

    Returns:
        (Input function, IteratorInitializerHook):
            - Function that returns (features, labels) when called.
            - Hook to initialise input iterator.
    """
    iterator_initializer_hook = IteratorInitializerHook()

    def train_inputs():
        """Returns training set as Operations.

        Returns:
            (features, labels) Operations that iterate over the dataset
            on every evaluation
        """
        with tf.name_scope("Training_data"):
            # Get Mnist data
            images = mnist_data.train.images.reshape([-1, 28, 28, 1])
            labels = mnist_data.train.labels
            # Define placeholders
            images_placeholder = tf.placeholder(
                images.dtype, images.shape)
            labels_placeholder = tf.placeholder(
                labels.dtype, labels.shape)
            # Build dataset iterator
            dataset = tf.contrib.data.Dataset.from_tensor_slices(
                (images_placeholder, labels_placeholder))
            dataset = dataset.repeat(None)  # Infinite iterations
            dataset = dataset.shuffle(buffer_size=10000)
            dataset = dataset.batch(batch_size)
            iterator = dataset.make_initializable_iterator()
            next_example, next_label = iterator.get_next()
            # Set runhook to initialize iterator
            iterator_initializer_hook.iterator_initializer_func = \
                lambda sess: sess.run(
                    iterator.initializer,
                    feed_dict={images_placeholder: images,
                               labels_placeholder: labels})
            # Return batched (features, labels)
            return next_example, next_label

    # Return function and hook
    return train_inputs, iterator_initializer_hook
def get_test_inputs(batch_size, mnist_data):
    """Return the input function to get the test data.

    Args:
        batch_size (int): Batch size of the test iterator that is returned
                          by the input function.
        mnist_data (Object): Object holding the loaded mnist data.

    Returns:
        (Input function, IteratorInitializerHook):
            - Function that returns (features, labels) when called.
            - Hook to initialise input iterator.
    """
    iterator_initializer_hook = IteratorInitializerHook()

    def test_inputs():
        """Returns test set as Operations.

        Returns:
            (features, labels) Operations that iterate over the dataset
            on every evaluation
        """
        with tf.name_scope("Test_data"):
            # Get Mnist data
            images = mnist_data.test.images.reshape([-1, 28, 28, 1])
            labels = mnist_data.test.labels
            # Define placeholders
            images_placeholder = tf.placeholder(
                images.dtype, images.shape)
            labels_placeholder = tf.placeholder(
                labels.dtype, labels.shape)
            # Build dataset iterator (no repeat or shuffle for evaluation)
            dataset = tf.contrib.data.Dataset.from_tensor_slices(
                (images_placeholder, labels_placeholder))
            dataset = dataset.batch(batch_size)
            iterator = dataset.make_initializable_iterator()
            next_example, next_label = iterator.get_next()
            # Set runhook to initialize iterator
            iterator_initializer_hook.iterator_initializer_func = \
                lambda sess: sess.run(
                    iterator.initializer,
                    feed_dict={images_placeholder: images,
                               labels_placeholder: labels})
            return next_example, next_label

    # Return function and hook
    return test_inputs, iterator_initializer_hook
# Run script ##############################################
if __name__ == "__main__":
    tf.app.run(
        main=run_experiment
    )
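One detail of the listing that is easy to miss is where the 4 * 4 * 40 in the tf.reshape call of the architecture function comes from. The sketch below redoes that arithmetic in plain Python. The two helper functions are hypothetical names introduced for this illustration, and the calculation assumes the pooling layers apply no padding.

```python
def valid_conv_size(size, kernel, stride=1):
    """Output size of a convolution with 'VALID' padding."""
    return (size - kernel) // stride + 1


def pool_size(size, window, stride):
    """Output size of a pooling layer without padding."""
    return (size - window) // stride + 1


size = 28                                    # MNIST images are 28 x 28
size = valid_conv_size(size, kernel=5)       # conv1 -> 24
size = pool_size(size, window=2, stride=2)   # pool2 -> 12
size = valid_conv_size(size, kernel=5)       # conv3 -> 8
size = pool_size(size, window=2, stride=2)   # pool4 -> 4
channels = 40                                # number of filters in conv3

print(size, size * size * channels)  # 4 640
```

The flattened feature vector therefore has 640 entries per image, which matches the 4 * 4 * 40 used in the reshape.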