この記事は福島高専 Advent Calendar 2020 22日目の記事です。
当方初執筆のため、気になる点などあればコメントよろしくお願いします。
はじめに
Tensorflow での学習を研究室のPCで回すことがよくあるのですが、進捗状況の確認のためにいちいち学校へ行くのがとても面倒なので、自宅でも進捗状況を確認できるシステムを作りました。
やりたいこと
- 遠隔地からの学習状況の確認
- 損失関数、評価関数のグラフの描画
実現方法
各 epoch ごとの損失関数や評価関数の値などの諸々の結果を「学習機」から送ってもらい、「監視機」に送信することで実現させます。
学習状況をリアルタイムで監視したいため、Websocketを用いて情報のやり取りを行います。
環境
- Python 3.8.6
- Node.js v15.3.0
実装
思ったよりコードが長くなってしまったので、掲載分のコードは適当なところで端折ってあります。
Githubのリポジトリにソースコードがあるので、そちらを参照してください。
サーバー側
Node.jsを利用します。Websocket サーバーを作成するため、ws
パッケージをnpm
でインストールしてあげます。
npm install ws
Websocketでは、文字列データを受信/送信することができるため、JSON文字列を送り、受信側でデコードすることで簡単に辞書データを送ることができます。
以下のような構造のJSONデータをやり取りすることにより、送られてきたデータによって処理を分岐させます。
{type:"xxx",//データの種類data:{...// 諸々のデータ}}
constws=require("ws");// データの種類を定数で列挙したファイルをrequireconstmessageType=require("./messageType").messageType//サーバオブジェクト作成constserver=newws.Server({port:8765})server.on("connection",(ws)=>{ws.on("message",(data)=>{console.log(data)// 送られてきたデータをパースconstjData=JSON.parse(data)switch(jData.type){// 接続時casemessageType.sessionStart:...break;// 学習タスクに関するメッセージ受信時casemessageType.trainInfo:...break;});});
クライアント側
学習機、監視機どちらもPythonで実装していきます。
学習機でのWebsocket通信にwebsockets
、監視機でのGUI作成にPyQt5
を使います。
pip install websockets PyQt5
学習機側
各エポックの損失関数・評価関数の値をサーバーに送信するためのコールバックを作成していきます。
Tensorflow のモデルで使用するコールバックは、tf.keras.callbacks.Callback
クラスを継承することで自作できます。
今回は各エポック終了時と学習終了時に値を送信したいので、on_epoch_end
とon_train_end
メソッドに送信部分を書いていきます。
importtensorflowastfimportwebsocketsimportjsonfrom.wsConstimportmessageTypeimportredefmakeWSData(dataType:str,data:dict)->dict:return{"type":dataType,"data":data}classwsConnector(tf.keras.callbacks.Callback):def__init__(self,URI,loop,name,details,trainerID,result_regex:dict):super().__init__()self.loop=loopself.trainerID=trainerIDself.URI=URIself.taskID=self.getTaskID(name,details)self.result_regex={k:re.compile(p)fork,pinresult_regex.items()}defgetTaskID(self,name,details):trainData=details.copy()trainData["layer"]=namepacket=makeWSData(dataType=messageType["trainInfo"],data={"id":self.trainerID,"name":name,"type":messageType["train"]["start"],"data":trainData})recvPacket=self.loop.run_until_complete(self.send_and_recv(json.dumps(packet)))recvDict=json.loads(recvPacket)returnrecvDict["data"]["taskID"]defon_epoch_end(self,epoch,logs=None):result={k:{}forkinself.result_regex.keys()}fork,vinlogs.items():forresult_class,result_regexinself.result_regex.items():ifresult_regex.search(k)isnotNone:result[result_class][k]=float(v)packet=makeWSData(dataType=messageType["trainInfo"],data={"id":self.trainerID,"type":messageType["train"]["update"],"data":{"id":self.taskID,"epoch":epoch+1,"result":result}})self.loop.run_until_complete(self.send(json.dumps(packet)))defon_train_end(self,logs=None):iflogsisNone:result=Noneelse:result={k:{}forkinself.result_regex.keys()}fork,vinlogs.items():forresult_class,result_regexinself.result_regex.items():ifresult_regex.search(k)isnotNone:result[result_class][k]=float(v)packet=makeWSData(dataType=messageType["trainInfo"],data={"id":self.trainerID,"type":messageType["train"]["end"],"data":{"id":self.taskID,"status":messageType["train"]["success"],"result":result}})self.loop.run_until_complete(self.send(json.dumps(packet)))asyncdefsend(self,message):asyncwithwebsockets.connect(self.URI)asws:awaitws.send(message)asyncdefsend_and_recv(self,message):asyncwithwebsockets.connect(self.URI)asws:awaitws.send(message)returnawaitws.recv()
コールバックを作成したら、モデルがコールバックを呼び出してくれるようmodel.fit
時に指定してあげます。
wsCallBack=wsConnector(URI="ws://localhost:8765",loop=loop,name=dataName,details=trainParams,trainerID=trainerID,result_regex={"loss":".*loss.*","accuracy":".*accuracy.*"})model.fit(trainData,epochs=epoch,callbacks=[wsCallback,],validation_data=testData,)
受信した値をコンソールに出力する Websocket テストサーバを作成し、起動した状態で学習タスクを走らせると、きちんと損失関数・評価関数の値が送信されているのが確認できます。
importwebsocketsimportjsonimportasyncioasyncdefserver(ws,path):print(json.loads(awaitws.recv()))loop=asyncio.get_event_loop()loop.run_until_complete(websockets.serve(server,"localhost",8765))loop.run_forever()
{'type': 'TRAIN_INFO', 'data': {'id': 'hoge', 'type': 'UPDATE', 'data': {'id': 'Hoge', 'epoch': 1, 'result': {'loss': {'loss': 6.121824748860965, 'val_loss': 5.398087776068485}, 'accuracy': {'sparse_categorical_accuracy_softmax': 0.08984068036079407, 'val_sparse_categorical_accuracy_softmax': 0.08866003900766373}}}}}
{'type': 'TRAIN_INFO', 'data': {'id': 'hoge', 'type': 'UPDATE', 'data': {'id': 'Hoge', 'epoch': 2, 'result': {'loss': {'loss': 5.307668804372631, 'val_loss': 5.248192136937922}, 'accuracy': {'sparse_categorical_accuracy_softmax': 0.09052243083715439, 'val_sparse_categorical_accuracy_softmax': 0.08866003900766373}}}}}
...
監視機側
PyQt5
でGUIを作っていきます。こんな感じで実装しました。
また、タスク一覧に表示されたタスクをダブルクリックすると、損失関数・評価関数のグラフが表示されるようにしました。
あとがき
12月に入るまで書く内容が決まらなかったり、「リモートデスクトップ使えば全部解決するんじゃ…」などと考えてモチベーションがだだ下がりしていたのですが、なんとか記事投稿まで持ってこれたので安心しています。
突貫作業だったこともあり、記事の焦点がブレブレなので次回執筆する際にはしっかり考えて書こうと思います。
いろいろ端折ってしまったため、分かりづらいなと感じた箇所は適宜修正・加筆していきます。
参考文献
Pythonの非同期通信(asyncioモジュール)入門を書きました