青ポスの部屋

旅と技術とポエムのブログ

TensorflowのLambdaレイヤーにまつわるお話

この記事はTensorFlow Advent Calendar 202018日目の記事です。

今回はTensorflowのLambdaレイヤーについていくつかTipsを書きます。

「Transform」デザインパターンとLambdaレイヤー

Transformデザインパターンとは、Machine Learning Design Patternsの中で提唱されている実装パターンの一つです。

medium.com

要は「モデルに直接データを食わせれば推論できるように、lambdaで変換するレイヤーをモデルに埋め込んでおきなさい」ということです。例えばCNNではPillowやcv2で読み込んだ画像は255で割るのが一般的です。ですが、学習のコードではちゃんと255.0で割っていても、推論のコードで書くのを忘れて「なんだこのクソモデルは」となることは往々にしてあります。

# train.py
def train(model,train_imgs,train_labels):
    imgs = imgs / 255.0
    model.fit(imgs, train_labels)


# predict.py
model.predict(imgs)

Lambdaレイヤーを含むモデルの保存

VAEとかはLambdaを含みます。そのままだとsave_modelでh5で保存することができません(checkpoint保存はできた気がする)。

その場合、本当に保存しなければならない部分だけ保存するようにします。VAEの場合だとLambdaレイヤーの前のEncoder部分とDecoder部分をそれぞれ抜き出し、Modelにしてそれぞれをsave_modelします。


今年は忙しくてあまり何か作ったりできなかったのでこの辺で。