[ML] Federated Learning: Google 輸入法的應用
本文內容主要來自: Hard, Andrew, et al. "Federated learning for mobile keyboard prediction." arXiv preprint arXiv:1811.03604 (2018).
我們通常將大量資料收集到一個伺服器, 並學習網路之參數,
這樣架構的假設是: 這些資料包含了所有使用者的行為,
因此, 在給定一個好的網路架構的前提下,
我們便可藉由該網路的參數學習, 來表現所有資料的特徵.
然而, 這樣的假設, 面臨到行為的動態, 以及資料增長的問題.
舉例來說, 在此研究中的題目, 是根據現有輸入, 對下一個輸入詞彙的預測.
此問題會隨著地域與時間的變化, 而有不同結果.
因此, 如何透過使用者的回饋, 動態更新模型, 就是一個困難的問題.
因此, Federated Learning 被提出用一個簡單的方法來解決此問題,
其中的想法與資料交換大致如下圖所示:
在一開始, 每一個裝置會先下載一個共通的模型 (w_t),
在此應用中, 下一個字元的預測藉由 LSTM 網路進行,
因此, w_t 代表的是 LSTM 中的參數, 透過預先收集的大量資料學習出來,
接著, 根據每個裝置上收集的資料 (個數為n_k), 裝置自行學習並更新 LSTM 網路,
此參數值 (w^k_{t+1}) 也會回傳至中央伺服器, 執行 Federated Learning 更新,
中央伺服器收到每個裝置的更新參數後, 根據資料的大小賦與權重 (n_k/N),
並更新整體網路的共通模型 (w_{t+1}).
在 Federated Learning 的架構下, 不但能夠利用裝置有限的計算資源,
提供一個分散式的網路學習架構, 動態對應越來越大的資料叢集,
另一方面, 考慮到回傳的資料為 LSTM 網路的參數, 而非使用者行為,
此方法可以減少對使用者隱私的侵害, 並減少中間網路的資料傳輸.
留言
張貼留言