[ML] Federated Learning: Google 輸入法的應用
本文內容主要來自: Hard, Andrew, et al. "Federated learning for mobile keyboard prediction." arXiv preprint arXiv:1811.03604 (2018). 連結參考: https://arxiv.org/abs/1811.03604 在舊有的 deep learning 架構中, 我們通常將大量資料收集到一個伺服器, 並學習網路之參數, 這樣架構的假設是: 這些資料包含了所有使用者的行為, 因此, 在給定一個好的網路架構的前提下, 我們便可藉由該網路的參數學習, 來表現所有資料的特徵. 然而, 這樣的假設, 面臨到行為的動態, 以及資料增長的問題. 舉例來說, 在此研究中的題目, 是根據現有輸入, 對下一個輸入詞彙的預測. 此問題會隨著地域與時間的變化, 而有不同結果. 因此, 如何透過使用者的回饋, 動態更新模型, 就是一個困難的問題. 因此, Federated Learning 被提出用一個簡單的方法來解決此問題, 其中的想法與資料交換大致如下圖所示: 在一開始, 每一個裝置會先下載一個共通的模型 (w_t), 在此應用中, 下一個字元的預測藉由 LSTM 網路進行, 因此, w_t 代表的是 LSTM 中的參數, 透過預先收集的大量資料學習出來, 接著, 根據每個裝置上收集的資料 (個數為n_k), 裝置自行學習並更新 LSTM 網路, 此參數值 (w^k_{t+1}) 也會回傳至中央伺服器, 執行 Federated Learning 更新, 中央伺服器收到每個裝置的更新參數後, 根據資料的大小賦與權重 (n_k/N), 並更新整體網路的共通模型 (w_{t+1}). 在 Federated Learning 的架構下, 不但能夠利用裝置有限的計算資源, 提供一個分散式的網路學習架構, 動態對應越來越大的資料叢集, 另一方面, 考慮到回傳的資料為 LSTM 網路的參數, 而非使用者行為, 此方法可以減少對使用者隱私的侵害, 並減少中間網路的資料傳輸.