こんにちは、NRI システムデザインコンサルティング部 高橋です。
本投稿ではFederated Learningについてご紹介します。
目次
1. 背景と課題
近年のAI活用の流れではDeep Learningモデルをはじめとした多くのモデルが活用されています。
特に画像・動画・自然言語・音声といった非構造化データが活用できるようになり、大きなブレークスルーとなっていることも大きく影響していると思います。
一方で画像をはじめとした大容量かつ個人情報を含むようなデータを扱う際、データ通信負荷やデータ秘匿性への懸念もあります。
例えばモバイル端末で収集したデータや工場内や車載機で取得されるデータなど、分散したところからデータが生成される場合、それらを活用する際にデータを1ヶ所に集約しモデルを作成するケースがあげられます。
また、医療分野のデータでは特にプライバシーに配慮したデータの利活用が求められています。
これらの生成されるデータを活用しモデルを作成する際、大容量データの転送負荷やデータ秘匿性に課題が発生します。
2. Federated Learningの概要・技術要素
課題として挙げたデータ通信量の削減とデータ秘匿性について対応するための技術としてFederated Learning(連合学習)があります。
Federated Learningはデータを集約せず各端末に分散した状態で機械学習を行う手法であり、Googleにより2017年に提唱されました。
Applied Federated Learning: Improving Google Keyboard Query Suggestionsの論文ではGoogle Keyboard(Gboard:モバイルデバイス用の文字入力システム)の検索キーワード予測機能が取り上げられており、実際にGoogleのキーボード予測変換にはこの技術が活用されています。
詳細な紹介は下記が詳しいです。
Federated Learning:モバイルデバイスを用いた分散学習技術
一部引用すると下記のような内容となります。
検索テキストに対する関連キーワードの選択で機械学習モデルを使用しているが、ユーザーの入力テキストをサーバーに送信するのはプライバシー観点から問題となり得る。
従ってデータはユーザーのデバイス上に配置しておきデバイス上でモデルを更新し、更新したパラメータの差分データのみサーバーへ送信する。
Federated Learningでは、モデルは各端末上で作成しモデルのパラメータのみ中央サーバーに送るため、学習データは中央サーバーに送信する必要が無く、データの通信量の削減につながります。画像などの大容量データの場合は特に効果を発揮すると思います。
また、データは各デバイス上に分散して存在しており、既存の中央集権的なモデルと比べデータの移動が最小限となるため、プライバシー保護に優れています。
懸念点として、モデルによってはGPUが必要となり各デバイスにGPU環境を用意する場合はその分コストがかかるため、適用する際は自社のデータ通信環境やデータ転送業務コストとの兼ね合いになります。
また、初期モデルの構築以降でモデルのメンテナンス等をする際、誤差によるモデル改善示唆などが得られにくい点も懸念として挙げられます。
分散デバイス上でモデルを学習する際に誤差が大きいサンプルを記録しておき、それらのデータのみ中央サーバーへ送るなどにより対応可能かもしれません。
続いて、Federated Learningの具体的な活用フローを下記に示します。
①クライアント (Collaborator)はサーバー(Aggregator)からに最新の共有モデル(Global Model)をダウンロードする
②クライアントではクライアント内のローカルデータを用いてモデルを作成し、共有モデルのパラメータとの差分を計算する
③差分情報のみをクライアントからサーバーに送信する
④サーバーは複数のクライアントから連携された差分情報を取り込み(※モデルの重みを対象として、データサイズによる加重平均など)、共有モデルをアップデートする
⑤上記①~④を1つのラウンドとして指定のラウンドまで繰り返す
このようなアーキテクチャーとすることで下記のメリットが得られます。
- データ通信量の削減
- 中央サーバーのモデル重みデータの保管コスト削減
- 中央サーバーの電力消費量の削減
- プライバシーの保護
このようにFederated LearningではAIモデルの構築において従来のようにデータを1か所に集約することなく、分散された複数のクライアント環境によって並列的に学習を行うことが可能となります。
各クライアント環境にあるデータを用いてモデルの学習を行い、そこで取得されるモデルの重みを中央サーバーに集約・統合することで、それぞれのクライアント環境のデータ量が少量・不均衡であっても全てのデータを使用して学習を行う場合と同等のモデルを構築することが可能となります。
< 実装例 >
ここでは実装例を紹介します。
CIFAR10の画像データを使用した分類タスク(クラス:10)を想定します。
ここではflowerというライブラリを使用し、ローカルPC環境内でサーバーとクライアントを立てて完結する検証とします。
flowerの詳細は下記を参照ください。
下記が準備するファイル群です。
=========================================================
centralized.py:
Global Modelのクラス、学習・テスト用の関数を記述。
server.py:
サーバーとしてクライアントからの要求に応答。
複数クライアントからの結果を加重平均した評価関数を記述。
client_1.py:
Global Modelのロード、サーバーからのパラメータ取得、
サーバーへのパラメータ送信、ローカルデータを使用したモデル学習を記述。
client_2.py:
client_1.pyと同様。
=========================================================
centralized.py
import torch import torch.nn as nn import torch.nn.functional as F import torchvision import torchvision.transforms as transforms from torchvision.transforms import Compose, ToTensor, Normalize from torch.utils.data import DataLoader, random_split from torchvision.datasets import CIFAR10 DEVICE = "cpu" class Net(nn.Module): def __init__(self) -> None: <ネットワークの要素を記述> def forward(self, x: torch.Tensor) -> torch.Tensor: <ネットワークのアーキテクチャーを記述> def train(net, trainloader, epochs: int, verbose=False): """Train the network on the training set.""" criterion = torch.nn.CrossEntropyLoss() optimizer = torch.optim.Adam(net.parameters(), lr=0.001) net.train() for epoch in range(epochs): <epochごとにNetモデルを学習するコードを記述> def test(net, testloader): """Evaluate the network on the entire test set.""" criterion = torch.nn.CrossEntropyLoss() correct, total, loss = 0, 0, 0.0 net.eval() with torch.no_grad(): <分類クラスを予測するコードを記述> return loss, accuracy def load_datasets(): < CIFAR-10をダウウンロードしてデータローダーとして返却するコードを記述> return trainloader, testloader def load_model(): return Net().to(DEVICE)
server.py
import flwr as fl from flwr.common import Metrics from typing import List, Tuple # 評価関数の定義:加重平均 def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] examples = [num_examples for num_examples, _ in metrics] return {"weighted average accuracy": sum(accuracies) / sum(examples)} fl.server.start_server( server_address="0.0.0.0:8080", # サーバーのアドレスを指定 config=fl.server.ServerConfig(num_rounds=4), # ラウンド数を指定 strategy=fl.server.strategy.FedAvg( evaluate_metrics_aggregation_fn=weighted_average ) )
flowerは「import flwr as fl」のように記述してインポートします。
※事前に「pip install flwr」でインストールが必要です。
fl.server.start_server()でクライアントからの接続に応答します。
※サーバーのアドレスは、すべてのインターフェースでLISTENする0.0.0.0としています。
num_rounds=4としているので、サーバーとクライアント間で4往復して学習が実施されます。
client_1.py
(client_2.pyも同様)
import torch import flwr as fl import numpy as np from typing import List, Tuple from collections import OrderedDict from centralized import load_datasets, load_model, train, test # サーバーから取得したパラメータでローカルモデルを更新 def set_parameters(net, parameters: List[np.ndarray]): params_dict = zip(net.state_dict().keys(), parameters) state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict}) net.load_state_dict(state_dict, strict=True) # クライアントのローカルモデルで学習したパラメータを取得 def get_parameters(net) -> List[np.ndarray]: return [val.cpu().numpy() for _, val in net.state_dict().items()] net = load_model() trainloader, testloader = load_datasets() class FlowerClient(fl.client.NumPyClient): def __init__(self, net, trainloader, testloader): self.net = net self.trainloader = trainloader self.testloader = testloader def get_parameters(self, config): return get_parameters(self.net) def fit(self, parameters, config): set_parameters(self.net, parameters) train(self.net, self.trainloader, epochs=2, verbose=True) return get_parameters(self.net), len(self.trainloader.dataset), {} def evaluate(self, parameters, config): set_parameters(self.net, parameters) loss, accuracy = test(self.net, self.testloader) return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)} fl.client.start_numpy_client( server_address="127.0.0.1:8080", # サーバーのアドレス client=FlowerClient(net, trainloader, testloader) )
サーバーと通信するためのサーバーアドレスは、
fl.client.start_numpy_client(server_address=***)の箇所に記述していますが、
ここではサーバーはローカルホストにあるためserver_address="127.0.0.1:8080"としています。
実行手順
ターミナル(またはコマンドプロンプト)を3つ用意し下記の順に実行します。
①ターミナル1で> python server.pyを実行
②ターミナル2で> python client_1.pyを実行
③ターミナル3で> python client_2.pyを実行
下記のログのように、ラウンドごとに精度が向上することが確認できます。
ここでは4ラウンド実行しておりAccuracyが0.4711 → 0.632 → 0.6519 → 0.6659となっていることがわかります。
INFO flwr 2023-08-15 16:51:52,241 | app.py:220 | app_fit: metrics_distributed {'weighted average accuracy': [(1, 0.4711), (2, 0.632), (3, 0.6519), (4, 0.6659)]}
3. ビジネスへの活用と事例
期待される活用領域としては、これまでの業界横断的なデータ活用が難しかった医療業界のデータや、不正検知や与信審査など業界共通でモデルを活用できる金融業界などがあげられます。
例えば病状を判定するモデルや疾患箇所を特定するモデルなどを作成する場合、病院ごとに患者のプライバシー情報を保護する必要がありますが、病状などの情報は病院に限らないため複数の病院の患者情報を用いてモデルを作成したほうが良いはずです。
Federated Learningを用いる場合、病院ごとに患者情報を用いてモデルを作成するため、データの外部持ち出しはなくプライバシーを確保することができます。
そして、各病院から算出されるパラメータ更新情報を中央サーバで集計することで複数の病院の患者情報を考慮したモデルを作成することができます。
下記ではFederated Learningを活用した事例を紹介します。
- インテルとペンシルバニア大学の取り組み
この取り組みでは71 の国際的な医療/研究機関から6,314 人の患者から取得した370 万枚のMRI画像によって脳腫瘍判定モデルが学習されました。
これはこれまでで最大の脳腫瘍データセットとなり、悪性脳腫瘍の検出を 33% 向上させる能力を実証しています。
Intel and Penn Medicine Announce Results of Largest Medical Federated Learning Study
- Owkinの取り組み
医療系AIスタートアップのOwkinは、連合学習をコア技術に医療データのプラットフォームを運営しています。
Owkinは、フランスの4つの主要病院内に保管されているデータを使用して、ネオアジュバント化学療法に対するトリプルネガティブ乳がん(TNBC)患者の将来の反応を正確に予測できるAIモデルを構築しています。
この研究は、複数の病院の病理組織データを使って、病院からデータを出さずに機械学習モデルを学習させた初めての例とされており、医学研究におけるFederatedLearningの画期的な実証となります。
また、下記のように国内での取り組みの動きもあるので今後期待される分野となります。
総務省委託研究開発「安全なデータ連携による最適化AI推進コンソーシアム」
4. NRIの取り組み
NRIではFederated Learningを適用している事例はありませんが、適用が検討できるケースを例として紹介いたします。
NRIではカーボンニュートラルに向け、トランザクションベースでCO2排出量の算定を行う製品(NRI-CTS)を開発していますが、現状多くの企業では自社の製品を作る際に排出するCO2排出量は統計的に算定された外部のデータベースの原単位をもとに活動量を掛け合わせて算出しています。
一方で、これからの業界動向として実測値に基づく算定が求められてきますが、すべてを実測値算定に置き換えることが現実的に難しい企業もあるかと思います。
その際に活動量や削減努力などの項目をもとに排出量を予測するモデルを作成することで、一律な原単位によるCO2排出量の算定と比較して確度が高い算定が期待できるでしょう。
5. まとめ
本投稿ではFederated Learningの概要と実装デモ、事例についてを記載しました。
まだ事例は少ないFederated Learningですが、ビジネスや業界標準化の取り組みと合わせて活用の検討が進めば、より社会へのインパクトがある分野となります。