青ポスの部屋

旅と技術とポエムのブログ

TensorFlowでLossを詳しく解析する方法 ①カスタムメトリックを自作する

はじめに

機械学習をやるとき、損失関数を例えば

loss = loss_a + loss_b + loss_c

のような形で定義することはよくあります。このとき普通にやると全部の総和になっているlossの記録は簡単に取れますが、loss_aやloss_bがそれぞれどのように変化していくのか詳細に見たいことはあります。ただし、少なくともデフォルトで簡単にこうすれば取れるというような機能はTensorFlowには(私が探した限りでは)ありません。

そこで、今回はloss_aやloss_bをmetricとして記録して後からhistoryで解析する方法を書きます。

カスタムメトリックの実装方法

仕様

modelをcompileするときに設定するmetric引数の仕様は下記の通りになっています。

List of metrics to be evaluated by the model during training and testing. Each of this can be a string (name of a built-in function), function or a keras.metrics.Metric instance. See keras.metrics. Typically you will use metrics=['accuracy']. A function is any callable with the signature result = fn(y_true, _pred).

https://www.tensorflow.org/api_docs/python/tf/keras/Model

つまりmetricはy_trueとy_predを引数とする関数として自作することができます。

設定(1) コンパイルするときに渡す

自作したmetricを設定する方法は2つあります。

まずmodel.compileするときにmetricの引数のリストを与える方法です。下記ではmymodelkeras.Modelをオーバーライドした自作モデルとします。compileするときにmetric引数にリストを与えます。こうするとmodel.metricsにメトリックを登録することができます。

mymodel.compile(optimizer="adam", loss="mse", metics=["acc", loss_a, loss_b])

このとき、当然ですがmymodelとloss_aなどは同じ名前空間に属する必要があります。もう少し言い換えるとmymodel.compileするときにちゃんと参照できる場所にloss_aなどが定義されている必要があります。

設定方法(2) metricsメソッドを自作する

一方で名前空間の分け方的にmymodelの実装の中でlossを定義しておきたいこともあります。そのような場合、modelのpropertyであるmetricsメソッドをオーバーライドしてリストを与えることもできます。

class mymodel(keras.Model):
    #
    # いろいろな実装
    #
    @property
    def metrics(self):
        return [loss_a, loss_b]

train_stepを自作している場合

なおtrain_stepを自作している場合、metricを上記の方法で追加しても学習中のプレビューには出力されずhistoryにも記録されません。

この場合、自作したtrain_stepのreturnに含めてやる必要があります。この記述はModel.metricとは無関係です。逆にhistoryやcallbackで受け取りたいときや学習中に表示させたいだけの場合は下記だけを記述することで実装できます

def train_step(self, data):
    # (実装)
    return {"loss_a": loss_a, "loss_b": loss_b}

(y_pred, y_true) 以外の引数が必要な場合

最初に述べた通り、上記でModel.metricsに追加できるメトリックは引数が (y_pred, y_true) であるものに限られます。例えば入力値 x が必要なメトリックは上記の方法では追加することができません。

ただしtrain_stepではxを与えることができるし、その値をreturnすることで値を出力させたりhistoryに記録することができます。

あと「add_metricという関数を使うことでメトリックを追加できる」という情報もあります。ただこれがmodel.metricsに登録することを指すのか単に出力するだけなのかはよくわかりません。

そもそもModel.metricsに追加した場合のメリットは何なのでしょうか。一番重要な違いはevaluateするときに出力されるかどうかです。つまりevaluateを使っていない場合は固執する必要はないと思います。もしもどうしてもevaluateでカスタムメトリックが出ないと困る場合はevaluateごと自作してオーバーライドしてやればよいと思います(未検証)。

続き

bluepost69.hatenablog.com

補足

本記事を作成するにあたってCopilotでヒントを得ました。

参考文献