利用Python装饰器来组织Tensorflow代码的结构

装饰器

定义Python装饰器

装饰器是一种设计模式, 可以使用OOP中的继承和组合实现, 而Python还直接从语法层面支持了装饰器.
装饰器可以在不改变函数定义的前提下, 在代码运行期间动态增加函数的功能, 本质上就是将原来的函数与新加的功能包装成一个新的函数wrapper, 并让原函数的名字指向wrapper.

Python中实现decorator有两种方式: 函数方式 和 类方式

函数方式

可以用一个返回函数的高阶函数来实现装饰器

简单的无参数装饰器

def log(func):
    def wrapper(*args, **kw):
        print('call %s():' % func.__name__)
        return func(*args, **kw)
    return wrapper
@log
def now():
    print('NOW')

在函数fun的定义前面放入@decorator实现的功能相当于fun=decorator(fun),
从而现在调用now()将打印前面的调用信息.

实现带参数的装饰器

只要给装饰器提供参数后,返回的object具备一个无参数装饰器的功能即可.
可以用返回无参数装饰器函数的高阶函数来实现.

def log(text):
    def decorator(func):
        def wrapper(*args, **kw):
            print('%s %s():' % (text, func.__name__))
            return func(*args, **kw)
        return wrapper
    return decorator

@log('execute')
def now():
  print("parametric NOW")

该语法糖相当于now=log('execute')(now).

如果要保存原函数的__name__属性, 使用python的functools模块中的wraps()装饰器, 只需要将@functools.wraps(func)放在def wrapper()前面即可.该装饰器实现的功能就相当于添加了wrapper.__name__ = func.__name__语句.

类方式

Python中的类和函数差别不大, 实现类的__call__ method就可以把类当成一个函数来使用了.

实现以上带参数装饰器同样功能的装饰器类的代码如下:

class log():
    def __init__(self, text):
        self.text = text
    def __call__(self,func):
        @functools.wraps(func)
        def wrapper(*args, **kw):
            print("%s %s" % (self.text, func.__name__))
            return func(*args, **kw)
        return wrapper

@log("I love Python")
def now():
    print("class decorator NOW")

使用类的好处是可以继承

使用场景

装饰器最巧妙的使用场景在Flask和Django Web框架中,它可以用来检查某人是否被授权使用Web应用的某个endpoint(假设是f函数), 下面是一个检查授权的示意性代码片段.

from functools import wraps

def require_auth(f):
  @wraps(f)
  def decorated(*args, **kw):
    auth = request.authorization
    if not auth or not check_auth(auth.username, auth.password):
      authenticate()
    return f(*args, **kw)
  return decorated

另一个常见的用处是用于日志记录

from functools import wraps

def logit(func):
    @wraps(func)
    def with_logging(*args, **kwargs):
        print(func.__name__ + " was called")
        return func(*args, **kwargs)
    return with_logging

@logit
def addition_func(x):
   """Do some math."""
   return x + x

result = addition_func(4)

是不是超级灵活呢? 虽然装饰器有点难定义, 但是一旦掌握, 它就像不可思议的魔法. Σ(*゚д゚ノ)ノ

利用装饰器改善你的Tensorflow代码结构

重头戏终于来了! 当你在写Tensorflow代码时, 定义模型的代码和动态运行的代码经常会混乱不清. 一方面, 我们希望定义compute graph的"静态"Python代码只执行一次, 而相反, 我们希望调用session来运行的代码可以运行多次取得不同状态的数据信息, 而两类代码一旦杂糅在一起, 很容易造成Graph中有冗余的nodes被定义了多次, 感觉十分不爽, 写过那种丑代码的你们都懂.

那么,如何以一种可读又可复用的方式来组织你的TF代码结构呢?

版本1

我们都希望用一个类来抽象一个模型, 这无疑是明智的. 但是如何定义类的接口呢?
我们的模型需要接受input的feature data和target value, 需要进行 training, evaluation 和 inference 操作.

class Model:

    def __init__(self, data, target):
        data_size = int(data.get_shape()[1])   # 假设data的shape为[N,D] N为Batch Size  D是输入维度
        target_size = int(target.get_shape()[1]) # 假设target的shape为[N,K] K是one-hot的label深度, 即要分类的类的数量
        weight = tf.Variable(tf.truncated_normal([data_size, target_size]))
        bias = tf.Variable(tf.constant(0.1, shape=[target_size]))
        incoming = tf.matmul(data, weight) + bias
        self._prediction = tf.nn.softmax(incoming)
        cross_entropy = tf.reduce_mean(-tf.reduce_sum(target * tf.log(self._prediction), reduction_indices=[1]))
        self._optimize = tf.train.RMSPropOptimizer(0.03).minimize(cross_entropy)
        mistakes = tf.not_equal(
            tf.argmax(target, 1), tf.argmax(self._prediction, 1))
        self._error = tf.reduce_mean(tf.cast(mistakes, tf.float32))

    @property
    def prediction(self):
        return self._prediction

    @property
    def optimize(self):
        return self._optimize

    @property
    def error(self):
        return self._error

这是最基本的形式, 但是它存在很多问题. 最严重的问题是整个图都被定义在init构造函数中, 这既不可读又不可复用.

版本2

直接将代码分离开来,放在多个函数中是不行的, 因为每次函数调用时都会向Graph中添加nodes, 所以我们必须确保这些Node Operations只在函数第一次调用的时候才添加到Graph中, 这有点类似于singleton模式, 或者叫做lazy-loading(使用时才创建).

class Model:

    def __init__(self, data, target):
        self.data = data
        self.target = target
        self._prediction = None
        self._optimize = None
        self._error = None

    @property
    def prediction(self):
        if not self._prediction:
            data_size = int(self.data.get_shape()[1])
            target_size = int(self.target.get_shape()[1])
            weight = tf.Variable(tf.truncated_normal([data_size, target_size]))
            bias = tf.Variable(tf.constant(0.1, shape=[target_size]))
            incoming = tf.matmul(self.data, weight) + bias
            self._prediction = tf.nn.softmax(incoming)
        return self._prediction

    @property
    def optimize(self):
        if not self._optimize:
             cross_entropy = tf.reduce_mean(-tf.reduce_sum(self.target * tf.log(self._prediction), reduction_indices=[1]))
            optimizer = tf.train.RMSPropOptimizer(0.03)
            self._optimize = optimizer.minimize(cross_entropy)
        return self._optimize

    @property
    def error(self):
        if not self._error:
            mistakes = tf.not_equal(
                tf.argmax(self.target, 1), tf.argmax(self.prediction, 1))
            self._error = tf.reduce_mean(tf.cast(mistakes, tf.float32))
        return self._error

这好多了, 但是每次都需要if判断还是有点太臃肿, 利用装饰器, 我们可以做的更好!

版本3

实现一个自定义装饰器lazy_property, 它的功能和property类似,但是只运行function一次, 然后将返回结果存在一个属性中, 该属性的名字是 "_cache_" + function.__name__, 后续函数调用将直接返回缓存好的属性.

import functools

def lazy_property(function):
    attribute = '_cache_' + function.__name__

    @property
    @functools.wraps(function)
    def decorator(self):
        if not hasattr(self, attribute):
            setattr(self, attribute, function(self))
        return getattr(self, attribute)

    return decorator

使用该装饰器, 优化后的代码如下:

class Model:

    def __init__(self, data, target):
        self.data = data
        self.target = target
        self.prediction
        self.optimize
        self.error

    @lazy_property
    def prediction(self):
        data_size = int(self.data.get_shape()[1])
        target_size = int(self.target.get_shape()[1])
        weight = tf.Variable(tf.truncated_normal([data_size, target_size]))
        bias = tf.Variable(tf.constant(0.1, shape=[target_size]))
        incoming = tf.matmul(self.data, weight) + bias
        return tf.nn.softmax(incoming)

    @lazy_property
    def optimize(self):
        cross_entropy = tf.reduce_mean(-tf.reduce_sum(self.target * tf.log(self.prediction), reduction_indices=[1]))
        optimizer = tf.train.RMSPropOptimizer(0.03)
        return optimizer.minimize(cross_entropy)

    @lazy_property
    def error(self):
        mistakes = tf.not_equal(
            tf.argmax(self.target, 1), tf.argmax(self.prediction, 1))
        return tf.reduce_mean(tf.cast(mistakes, tf.float32))

注意, 在init构造函数中调用了属性prediction,optimize和error, 这会让其第一次执行, 因此构造函数完成后Compute Graph也就构建完毕了.

有时我们使用TensorBoard来可视化Graph时, 希望将相关的Node分组到一起, 这样看起来更为清楚直观, 我们只需要修改之前的lazy_property装饰器, 在其中加上with tf.name_scope("name") 或者 with tf.variable_scope("name")即可, 修改之前的装饰器如下:

import functools

def define_scope(function):
    attribute = '_cache_' + function.__name__

    @property
    @functools.wraps(function)
    def decorator(self):
        if not hasattr(self, attribute):
            with tf.variable_scope(function.__name__):
                setattr(self, attribute, function(self))
        return getattr(self, attribute)

    return decorator

我们现在能够用一种结构化和紧凑的方式来定义TensorFlow的模型了, 这归功于Python的强大的decorator语法糖.

完整的代码在这里, 有关该代码的详细注释请参考我的博客.

References:

  1. https://danijar.com/structuring-your-tensorflow-models/
  2. https://www.liaoxuefeng.com/wiki/0014316089557264a6b348958f449949df42a6d3a2e542c000/0014318435599930270c0381a3b44db991cd6d858064ac0000
  3. https://eastlakeside.gitbooks.io/interpy-zh/content/decorators/deco_class.html
  4. https://www.liaoxuefeng.com/wiki/001374738125095c955c1e6d8bb493182103fac9270762a000/001386820062641f3bcc60a4b164f8d91df476445697b9e000
  5. https://www.tensorflow.org/get_started/mnist/beginners?hl=zh-cn#training
  6. https://mozillazg.github.io/2016/12/python-super-is-not-as-simple-as-you-thought.html
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 199,519评论 5 468
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 83,842评论 2 376
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 146,544评论 0 330
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 53,742评论 1 271
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 62,646评论 5 359
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,027评论 1 275
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,513评论 3 390
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,169评论 0 254
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,324评论 1 294
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,268评论 2 317
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,299评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 32,996评论 3 315
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,591评论 3 303
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,667评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,911评论 1 255
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 42,288评论 2 345
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 41,871评论 2 341

推荐阅读更多精彩内容

  • Python进阶框架 希望大家喜欢,点赞哦首先感谢廖雪峰老师对于该课程的讲解 一、函数式编程 1.1 函数式编程简...
    Gaolex阅读 5,484评论 6 53
  • 我还是那么想你。 当我能够轻松自如地和别人谈笑风生,当我不再辗转反侧失眠到深夜,当我平静地把你的微信拖到了黑名单,...
    安夏的花花世界阅读 247评论 0 0
  • 打开窗户,感觉世界距离我 有着遥远的路程, 而我也懒得动身前往。 ——《孤独》城子玄
    城子玄阅读 172评论 0 0
  • 有一种唠叨,最容易使我们烦躁,那就是母亲的唠叨。这种唠叨,都是源于她内心深处对我们的爱。 在生活...
    林琨皓阅读 520评论 0 2