注:この記事の実装は非効率的かもしれません.
この記事は何?
マルチタスク学習をニューラルネットワークに適用した研究がいくつか報告されています.
適用例として,Georgeら(2014)の研究では,タンパク質に対する化合物の活性予測にニューラルネットワークを用いたマルチタスク学習を適用しています.
また,Xiaodongら(2015)の研究では,自然言語処理のタスクとしてクエリ分類と情報検索の2つのタスクを同時に解くマルチタスク学習を提案しています.
私はDNNを実装する時はChainerを使っているのですが,Chainerでマルチタスク学習を実装した例が見当たらなかったため自分で実装してみました.
ネットワーク構造
解くタスクは2つとします.2つのタスクに対応するネットワークが存在し,それぞれのネットワーク間で一部の層を共有するようなネットワーク構造を考えます(下図参照).
学習の進め方と実装
学習は,「task: 0のデータで学習」→「task: 1のデータで学習」→「task: 0のデータで学習」→…というように交互に行っていきます.
記事冒頭でも示したように上手い実装が思いつかなかった,分からなかったので,次図のようにナイーブな実装を行いました.
この記事で紹介する実装では,各タスク毎にネットワークを定義します.
「task: 0のデータで学習」が終わると,およびの重みがそれぞれ更新されます.
その後,共有したい重みの値をもう片方のタスクにおけるネットワークにコピーします.
この操作を繰り返すことで,タスク毎に一部の重みは共有し,一部の重みは独立に学習を進めるということができます.
Chainerによる実装
まずはネットワーク構造を実装します,共有したい層をSharedNet,タスク毎に独立させたい層をSeparetedNet,それぞれの層を結合したものをCombinedNetとします.今回は簡単のためMLPで実装しましたが,他のネットワーク構造も簡単に扱えます.
import chainer import chainer.links as L import chainer.functions as F class SharedNet(chainer.Chain): def __init__(self, n_out): super(SharedNet, self).__init__( l1=L.Linear(None, n_out) ) def __call__(self, x): a = self.l1(x) z = F.sigmoid(a) return z class SeparatedNet(chainer.Chain): def __init__(self, n_out): super(SeparatedNet, self).__init__( l1=L.Linear(None, n_out) ) def __call__(self, x): a = self.l1(x) z = a return z class CombinedNet(chainer.Chain): def __init__(self, shared_net, separated_net): super(CombinedNet, self).__init__( shared_net=shared_net, separated_net=separated_net ) def __call__(self, x): a1 = self.shared_net(x) a2 = self.separated_net(a1) return a2
次に,学習用のコードを書きます.データセットには有名なirisデータセットを用いました.
irisデータセットのサンプルには3つのラベル(0, 1, 2)のどれかが振られていますが,task: 0を「ラベル0とラベル1の分類問題」,task: 1を「ラベル1とラベル2の分類問題」としました.3クラス分類問題を無理やりマルチタスクの問題にしてみました.
import numpy as np import chainer import chainer.functions as F from sklearn import datasets import mymodel N_TASK = 2 def main(): np.random.seed(0) # Model definition shared_nets = [mymodel.SharedNet(3) for i_task in range(N_TASK)] separated_nets = [mymodel.SeparatedNet(2) for i_task in range(N_TASK)] combined_nets = [mymodel.CombinedNet(shared_net, separated_net) for shared_net, separated_net in zip( shared_nets, separated_nets)] # Setup an optimizer optimizers = [chainer.optimizers.Adam() for i_task in range(N_TASK)] for i_task, optimizer in enumerate(optimizers): optimizer.use_cleargrads() optimizer.setup(combined_nets[i_task]) # Load dataset X, ys = datasets.load_iris(return_X_y=True) X = X.astype(np.float32) ys = ys.astype(np.int32) task_index_lst = [(ys == 0) | (ys == 1), (ys == 1) | (ys == 2)] X_by_task = [X[task_index] for task_index in task_index_lst] ys_by_task = [ys[task_index] for task_index in task_index_lst] """ labels should be 0 or 1 change labels of task1 1 or 2 -> 0 or 1 """ ys_by_task[1] -= 1 # First Training for i_task in range(N_TASK): combined_nets[i_task].cleargrads() optimizers[i_task].update(F.softmax_cross_entropy, combined_nets[i_task](X_by_task[i_task]), ys_by_task[i_task]) # Multi-task DNN (parameter sharing) if i_task == 0: combined_nets[1].shared_net.copyparams(combined_nets[0].shared_net) elif i_task == 1: combined_nets[0].shared_net.copyparams(combined_nets[1].shared_net) # Before training on task0 print("** Before training on task0 **") for i_task in range(N_TASK): print("task{}".format(i_task)) print("shared\n", combined_nets[i_task].shared_net.l1.W.data) print("separated\n", combined_nets[i_task].separated_net.l1.W.data) # Training on task0 optimizers[0].update(F.softmax_cross_entropy, combined_nets[i_task](X_by_task[0]), ys_by_task[0]) # After training on task0 print("\n** After training on task0 **") for i_task in range(N_TASK): print("task{}".format(i_task)) print("shared\n", combined_nets[i_task].shared_net.l1.W.data) print("separated\n", combined_nets[i_task].separated_net.l1.W.data) # Multi-task DNN (parameter sharing) combined_nets[1].shared_net.copyparams(combined_nets[0].shared_net) print("\n** After sharing parameter **") for i_task in range(N_TASK): print("task{}".format(i_task)) print("shared\n", combined_nets[i_task].shared_net.l1.W.data) print("separated\n", combined_nets[i_task].separated_net.l1.W.data) if __name__ == '__main__': main()
TrainerやUpdaterを使わず,泥臭く実装してみました.
First trainingではtask: 0およびtask: 1についてそれぞれ1回学習を行います.
その後,重みが正しく更新されているかどうかを見ていきます.
実行結果
上のコードを実行すると,以下のような結果が得られました.
** Before training on task0 ** task0 shared [[ 0.88203126 0.20008963 0.48937538 1.12046123] [ 0.93177909 -0.49063882 0.47304434 -0.07767794] [-0.05160943 0.20529924 0.0700218 0.72513676]] separated [[ 0.44038531 0.0712491 0.25726452] [ 0.19164696 0.86160696 -0.11944815]] task1 shared [[ 0.88203126 0.20008963 0.48937538 1.12046123] [ 0.93177909 -0.49063882 0.47304434 -0.07767794] [-0.05160943 0.20529924 0.0700218 0.72513676]] separated [[ 0.18174972 -0.49211243 -1.47296929] [ 0.37636688 0.49808249 -0.42948917]] ** After training on task0 ** task0 shared [[ 0.88270104 0.20075931 0.4900445 1.12112439] [ 0.93110901 -0.49130887 0.47237432 -0.07834772] [-0.05093938 0.2059693 0.06935176 0.72446674]] separated [[ 0.44105536 0.07191915 0.25793457] [ 0.1909769 0.86093688 -0.12011819]] task1 shared [[ 0.88203126 0.20008963 0.48937538 1.12046123] [ 0.93177909 -0.49063882 0.47304434 -0.07767794] [-0.05160943 0.20529924 0.0700218 0.72513676]] separated [[ 0.18174972 -0.49211243 -1.47296929] [ 0.37636688 0.49808249 -0.42948917]] ** After sharing parameter ** task0 shared [[ 0.88270104 0.20075931 0.4900445 1.12112439] [ 0.93110901 -0.49130887 0.47237432 -0.07834772] [-0.05093938 0.2059693 0.06935176 0.72446674]] separated [[ 0.44105536 0.07191915 0.25793457] [ 0.1909769 0.86093688 -0.12011819]] task1 shared [[ 0.88270104 0.20075931 0.4900445 1.12112439] [ 0.93110901 -0.49130887 0.47237432 -0.07834772] [-0.05093938 0.2059693 0.06935176 0.72446674]] separated [[ 0.18174972 -0.49211243 -1.47296929] [ 0.37636688 0.49808249 -0.42948917]]
"** Before training on task0 **"の段階では,task: 0およびtask: 1それぞれの共有部分の重みが等しくなっていることが分かります.
"** After training on task0 **"の段階では,task: 0のネットワークの重みが更新されたことが分かります.
"** After sharing parameter **"の段階では,task: 0およびtask: 1それぞれの共有部分の重みが等しくなっていることが分かります.
その他
重みを更新する度に,他方のタスクへの重みコピーが走るので効率的とは言えませんが,一応実装はできました.
ネットワーク定義自体で何とか出来るのでしょうか…ご存知の方いらっしゃったらご教授お願いします…(ヽ´ω`)
2017/3/8 追記
marugari2さんに,CombinedNetおよびoptimizerをそれぞれ単一で実装する方法を紹介していただきました.ありがとうございます!コメント欄をご参照ください.