青ポスの部屋

旅と技術とポエムのブログ

TensorFlowでLossを詳しく解析する方法 ②カスタムメトリックを記録する

前回の続きです。

前回:

bluepost69.hatenablog.com

おさらい:簡単にメトリックを取得する方法

前回カスタムメトリックを実装する方法をいくつか示しました。単に学習中のカスタムメトリックを取得するだけなら、train_stepのreturnに取得したいメトリックの値を入れておけばよいです。

def train_step(self, data):
    # (実装)
    return {"loss_a": loss_a, "loss_b": loss_b}

今回はこのようにして記録したloss_aやloss_bをどのように取り出すかをまとめておきます。

historyを使って取り出す方法

fitの戻り値にhistoryというオブジェクトが戻ります。この中にはloss_aなどのメトリックも記録されます。

ただし、historyはfitが完了しないと実行されないので、異常終了したときとかにどうなるのかというちょっとした心配が残ります。

callbackを使って取り出す方法

callbackを使えば1エポック終わった後に実行される処理をカスタムすることができます。今回はcallbackでテキストファイルに値を追記していく形で実装します。

callbackはtensorflow.keras.callbacks.Callbackクラスをオーバーライドして実装します。このクラスはon_epoch_endon_test_beginのようなメソッドを持っており、これらをオーバーライドしてやることで各処理の開始前や終了後に行う処理を実装できます。

今回は1エポック学習が終わった後にテキストにwriteする処理を書きます。ここではcallback_writeという名前で実装します。Callbackクラスをオーバーライドして関数の引数にlogsを書いておけば、変数logsにメトリックの名前がキーのdictが渡ってきます

class callback_write(tf.keras.callbacks.Callback):
    def on_epoch_end(elf, epoch, logs=None):
        with open("loss.log", "a") as f:
            f.write(logs["loss_a"], logs["loss_b"])

※今回は適当に書いていますが、実際はformat文などでよしなに整えて出すとよいでしょう。

あとはfitのcallbacks引数に該当するcallbackを与えます。callbacks引数はcallbackのリストを引数に取るので、例えば「出したいメトリックがlossとcovariantと大別して2種類ある」というような場合は2つcallbackを作ってもよいでしょう。

model.fit(xx, yy, epochs=myepochs, batch_size=mybatch_size, callbacks=[callback_write(),])

なお引数に与えるcallbackはすでにcallされた形(つまり()をつける)で書く必要がある点に注意しましょう。

補足

本記事を作成するにあたってCopilotでヒントを得ました。

参考文献