tf.train.SessionRunHook 让 estimator 训练过程可以个性化定制

estimator

estimator 是 tensorflow 提供的使用非常方便的模型封装。estimator 中提供了许多内置的模型,例如 LinearClassifier、DNNLinearCombinedClassifier、LinearRegressor等。用户也可以通过 model_fn 定制模型结构。在 estimator 对象的基础上任何模型都可以直接调用 train 和 eval 函数进行训练和测试,用户无需手动地创建 session 和 run session。estimator 的具体使用方式可以参考[1]。
estimator.png

dataset

tensorflow 底层 API 中都是使用 placeholder 和 feed_dict 向模型输入数据的,这样的方式效率较低。我们可以利用 dataset 库,这里提供了高效读取数据并且输入给模型训练的方式。

可以直接用 numpy 数组创建 dataset。直接用数组创建 dataset 的一个问题是 tensorflow 会直接把 dataset 中的数据写到 graph 中,当数据量较大时会报错,因为 graph 在序列化到 pb 文件时现在最大2GB。

def input_fn():
  features, labels = (np.random.sample((100,2)), np.random.sample((100,1)))
  dataset = tf.data.Dataset.from_tensor_slices((features,labels))
  dataset = dataset.shuffle(100000).repeat().batch(batch_size)
  return dataset

...
estimator.train(input_fn)

为了在大数据量时使用 dataset,我们可以用 placeholder 创建 dataset。这时数据就不会直接写到 graph 中,graph 中只有一个 placeholder 占位符。但是,用了 placeholder 就需要我们在一开始对它进行初始化填数据,需要调用 sess.run(iter.initializer, feed_dict={ x: data })。更多关于 dataset 的使用介绍可以参考文献[2]。

def input_fn():
  x = tf.placeholder(tf.float32, shape=[None,2])
  dataset = tf.data.Dataset.from_tensor_slices(x)
  dataset = dataset.shuffle(100000).repeat().batch(batch_size)
  iter = dataset.make_initializable_iterator()
  return iter.get_next()

SessionRunHook

既然前面说到 estimator 是 tensorflow 对模型的一种封装,我们不需要也无法拿到训练和测试时创建的 session,那么我们如何在 estimator 中对上一节使用 placeholder 的 dataset 的 initializeble_iterator 调用 sess.run 进行初始化呢?这时候就要用到 SessionRunHook 了。
先从字面意思理解一下 SessionRunHook 这个类。Session 就是 tensorflow 运行模型计算时的会话,Run就是整个 session 运行过程,Hook 是挂钩的意思即把某些事情挂在这个对象上可以理解为回调。

再看一下 SessionRunHook 源码[3]中的定义:
A SessionRunHook extends session.run() calls for the MonitoredSession.
SessionRunHooks are useful to track training, report progress, request early
stopping and more. SessionRunHooks use the observer pattern and notify at the
following points:

  • when a session starts being used
  • before a call to the session.run()
  • after a call to the session.run()
  • when the session closed
class SessionRunHook(object):
  """Hook to extend calls to MonitoredSession.run()."""

  def begin(self):
    """Called once before using the session.
    When called, the default graph is the one that will be launched in the
    session.  The hook can modify the graph by adding new operations to it.
    After the `begin()` call the graph will be finalized and the other callbacks
    can not modify the graph anymore. Second call of `begin()` on the same
    graph, should not change the graph.
    """
    pass

  def after_create_session(self, session, coord):  # pylint: disable=unused-argument
    """Called when new TensorFlow session is created.
    This is called to signal the hooks that a new session has been created. This
    has two essential differences with the situation in which `begin` is called:
    * When this is called, the graph is finalized and ops can no longer be added
        to the graph.
    * This method will also be called as a result of recovering a wrapped
        session, not only at the beginning of the overall session.
    Args:
      session: A TensorFlow Session that has been created.
      coord: A Coordinator object which keeps track of all threads.
    """
    pass

  def before_run(self, run_context):  # pylint: disable=unused-argument
    """Called before each call to run().
    You can return from this call a `SessionRunArgs` object indicating ops or
    tensors to add to the upcoming `run()` call.  These ops/tensors will be run
    together with the ops/tensors originally passed to the original run() call.
    The run args you return can also contain feeds to be added to the run()
    call.
    The `run_context` argument is a `SessionRunContext` that provides
    information about the upcoming `run()` call: the originally requested
    op/tensors, the TensorFlow Session.
    At this point graph is finalized and you can not add ops.
    Args:
      run_context: A `SessionRunContext` object.
    Returns:
      None or a `SessionRunArgs` object.
    """
    return None

  def after_run(self,
                run_context,  # pylint: disable=unused-argument
                run_values):  # pylint: disable=unused-argument
    """Called after each call to run().
    The `run_values` argument contains results of requested ops/tensors by
    `before_run()`.
    The `run_context` argument is the same one send to `before_run` call.
    `run_context.request_stop()` can be called to stop the iteration.
    If `session.run()` raises any exceptions then `after_run()` is not called.
    Args:
      run_context: A `SessionRunContext` object.
      run_values: A SessionRunValues object.
    """
    pass

  def end(self, session):  # pylint: disable=unused-argument
    """Called at the end of session.
    The `session` argument can be used in case the hook wants to run final ops,
    such as saving a last checkpoint.
    If `session.run()` raises exception other than OutOfRangeError or
    StopIteration then `end()` is not called.
    Note the difference between `end()` and `after_run()` behavior when
    `session.run()` raises OutOfRangeError or StopIteration. In that case
    `end()` is called but `after_run()` is not called.
    Args:
      session: A TensorFlow Session that will be soon closed.
    """
    pass

我们看到 SessionRunHook 源码中为 5 中不同的事件提供了回调函数,用户只需要继承 SessionRunHook 这个类并且具体实现想要的回调函数即可,具体用法看下一节。

estimator 结合 SessionRunHook 实现 placeholder 初始化

仔细看一下 estimator 的 train 和 evaluate 函数定义可以发现它们都接收 hooks 参数,这个参数的定义是:List of tf.train.SessionRunHook subclass instances. Used for callbacks inside the training loop. 就是上一节提到的用户继承自 SessionRunHook 的类的实例对象列表。

train(
    input_fn,
    hooks=None,
    steps=None,
    max_steps=None,
    saving_listeners=None
)

我们现在想要在训练之前初始化 dataset 的 placeholder,那么我们就应该具体实现 SessionRunHook 的after_create_session 成员函数:

class IteratorInitializerHook(tf.train.SessionRunHook):
   def __init__(self):
       super(IteratorInitializerHook, self).__init__()
       self.iterator_initializer_fn = None

   def after_create_session(self, session, coord):
       del coord
       self.iterator_initializer_fn(session)

def make_input_fn():
   iterator_initializer_hook = IteratorInitializerHook()

   def input_fn():
       x = tf.placeholder(tf.float32, shape=[None,2])
       dataset = tf.data.Dataset.from_tensor_slices(x)
       dataset = dataset.shuffle(100000).repeat().batch(batch_size)
       iter = dataset.make_initializable_iterator()
       data = np.random.sample((100,2))
       iterator_initializer_hook.iterator_initializer_fn = (
           lambda sess: sess.run(iter.initializer, feed_dict={x: data})
       )
       return iter.get_next()
   return input_fn, iterator_initializer_hook

...
input_fn, iterator_initializer_hook = make_input_fn()
estimator.train(input_fn, hooks=[iterator_initializer_hook])

当然,SessionRunHook 不光能用在初始化上,还有许多应用场景,可以参考源码[3]中提供的几个内置 Hook 和文献[4]。

[1] https://github.com/tensorflow/models/tree/master/samples/core/get_started
[2] https://www.jiqizhixin.com/articles/03137
[3] https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/session_run_hook.py
[4] https://blog.csdn.net/mrr1ght/article/details/81011280

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 200,527评论 5 470
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 84,314评论 2 377
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 147,535评论 0 332
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,006评论 1 272
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 62,961评论 5 360
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,220评论 1 277
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,664评论 3 392
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,351评论 0 254
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,481评论 1 294
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,397评论 2 317
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,443评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,123评论 3 315
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,713评论 3 303
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,801评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,010评论 1 255
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,494评论 2 346
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,075评论 2 341

推荐阅读更多精彩内容