この記事はランク学習(Learning to Rank) Advent Calendar 2018 - Adventarの13本目の記事です
この記事は何?
ニューラルネットワークを用いたランク学習の手法として、ListNet*1が提案されています。
以前下の記事で、同じくニューラルネットワークを用いたランク学習の手法であるRankNetを紹介しましたが、ListNetはRankNetと異なり、Listwise手法に分類されます。
www.szdrblog.info
この記事では、PyTorchを用いたListNetの実装を紹介します。
ListNet
非常にざっくりとListNetの解説をします。詳細解説は以下文献が詳しいので、そちらをご確認ください。
著者スライド
Learning to Rank: from Pairwise Approach to Listwise Approach
日本語記事(Chainer実装付き)
qiita.com
訓練データに個のクエリに関するデータ点が含まれており、クエリに文書群が紐づいているとします。
それぞれの文書には関連度評価値(例えば、Excellent(4)・Perfect(3)・Good(2)・Fair(1)・Bad(0))が与えられています。
ランキングモデルをとし、文書の特徴量を用いて、ランクスコアが得られます。
…ここまで準備できたところで、論文中で"top one probability"と呼んでいる確率を定義します(ちょっとステップ飛ばしてますが)。
…なんか複雑っぽいですが、それぞれ文書関連度・ランクスコアに対するsoftmaxを計算しているだけです。
上の"top one probability"を使って、ListNetにおける損失関数(論文中では交差エントロピーを使用)は以下のように定義できます。
…というわけで、損失関数が定義できたので、あとはこれを最適化するだけです。
PyTorchを用いたListNetの実装
それでは、本題のPyTorchを用いたListNetの実装を紹介します。
下の記事で紹介したRankNetの実装と重複しているコードも多いですが。。。
www.szdrblog.info
まずはネットワークの定義です。
今回は単純なfeed-forwardニューラルネットワークを使います。
class Net(nn.Module): def __init__(self, D): super(Net, self).__init__() self.l1 = nn.Linear(D, 10) self.l2 = nn.Linear(10, 1) def forward(self, x): x = torch.sigmoid(self.l1(x)) x = self.l2(x) return x
次に、本実装のメインであるListNetの損失関数を実装します。
def listnet_loss(y_i, z_i): """ y_i: (n_i, 1) z_i: (n_i, 1) """ P_y_i = F.softmax(y_i, dim=0) P_z_i = F.softmax(z_i, dim=0) return - torch.sum(P_y_i * torch.log(P_z_i))
…めちゃくちゃあっさりしてますね…は文書関連度のベクトル、は予測スコアのベクトルを表しています。
PyTorchのCrossEntropyLoss使うともっとあっさり書けるんですかね?クエリによって文書数が異なるケース(CrossEntropyLossにおけるクラス数)でもうまく動くか分からなかったので、明示的に書いてみました。
上の実装ですが、一度に複数クエリに関するデータを受け取れない実装になっているので注意してください。その辺りはお好みで拡張を…
では実際に動かしてみます、精度評価はswapped-pairsとNDCGを使います。
def ndcg(ys_true, ys_pred): def dcg(ys_true, ys_pred): _, argsort = torch.sort(ys_pred, descending=True, dim=0) ys_true_sorted = ys_true[argsort] ret = 0 for i, l in enumerate(ys_true_sorted, 1): ret += (2 ** l - 1) / np.log2(1 + i) return ret ideal_dcg = dcg(ys_true, ys_true) pred_dcg = dcg(ys_true, ys_pred) return pred_dcg / ideal_dcg
学習・精度評価を合わせたソースコードの全体像を公開しておきます。
上のコードを実行すると、epoch毎にvalidationにおけるswapped-pairsとndcgを出力します。
epoch: 1 valid swapped pairs: 1095/4950 ndcg: 0.8722 epoch: 2 valid swapped pairs: 787/4950 ndcg: 0.9366 epoch: 3 valid swapped pairs: 548/4950 ndcg: 0.9701 epoch: 4 valid swapped pairs: 385/4950 ndcg: 0.9841 epoch: 5 valid swapped pairs: 275/4950 ndcg: 0.9908 epoch: 6 valid swapped pairs: 224/4950 ndcg: 0.9937 epoch: 7 valid swapped pairs: 182/4950 ndcg: 0.9952 epoch: 8 valid swapped pairs: 146/4950 ndcg: 0.9966 epoch: 9 valid swapped pairs: 139/4950 ndcg: 0.9965 epoch: 10 valid swapped pairs: 113/4950 ndcg: 0.9972
学習を進めていくと、ちゃんとswapped-pairsの数が小さくなり、ndcgが向上していくことが分かります!
まとめ
この記事ではPyTorchを用いたListNetの実装を紹介しました。
ListNetはRankNetよりも効率的に学習でき、NDCGやMAPといった評価指標についても精度で勝つなど、かなり強力な手法だと思います。
*1:"Learning to Rank: From Pairwise Approach to Listwise Approach", Z. Cao, 2007.