集中學習(單機)
一個最簡單的例子,我們想學習人的身高和體重之間的線性關係,並且我們擁有100人的體重和身高數據,想訓練一種線性模型,該模型使用身高預測人們的體重,線性迴歸W = [a,b]如下:
我們怎麼找到w?為了求w,使用梯度下降法(GD),從一個隨機的w開始,然後通過沿誤差的相反方向在100個數據點上最小化模型的誤差。
設置A = 0和B = 2,併為每個數據點計算我們的模型,如下所示:
上面的方程肯定是不成立的,因為2 * 1.70 + 0不等於72。我們的目標是找到一個a和b使這個等式成立。所以需要計算該模型對於所有100人的數據點的誤差:
目標是找到使所有數據點的誤差為零的模型,我們假定負誤差與正誤差相等。因此將總誤差定義為所有數據點平方誤差的平均值,如下所示:
強調一下這個總誤差或者說損失函數的關鍵點是對所有數據點的平均值,也就是說每個數據點對總誤差的貢獻是相等的。損失函數是通過平均所有數據點的誤差來計算的,每個數據點對損失函數的貢獻是相等的。
為了用梯度下降法求出a和b的最優值,需要計算b在初始b點的梯度,並按如下方式更新值:
Lambda是學習率,繼續看下圖
要計算F的梯度,首先需要以完整的形式編寫F。
現在,準備計算F相對於B的梯度:
到梯度是每個數據點錯誤梯度的平均值! 使用上面定義的符號,我們可以按以下方式完成梯度下降更新規則:
通過平均每個數據點的誤差來計算損失函數的真實梯度,然後將新B替換為上一個B,直到我們的總錯誤足夠小。 這是一個迭代過程,通過多次寵物可以找到A和B的最佳價值。
隨機梯度下降(SGD)
我們通過在100個數據點的所有梯度上平均來計算F的梯度。 如果我們僅使用20個數據點進行估計,該怎麼辦?
這被為小批量的隨機梯度下降,僅利用數據子集來計算梯度。
分佈式隨機梯度下降(D-SGD)
讓我們看一下從另一個角度計算的梯度。
如果我們按照上面的公式重寫梯度並將其分為2部分求和時,每個和式都有其意義。第一部分實際上是前50個點數據的平均梯度,第二部分是數據集後50個點數據的平均梯度。
這意味著我們不需要將所有的100個數據點放在一個地方(同一臺服務器)!我們可以將數據分成兩部分然後分別計算每個部分的梯度,然後對這兩個梯度求平均值,來計算整個數據的梯度。這就是D-SGD的主要思想。
現在,我們有兩個客戶機的分佈式SGD。
如上所示,在D-SGD中兩個客戶端都從相同的b點開始,然後各自用50個數據點計算每個客戶端的梯度。然後將局部梯度發送到充當協調器的服務器上。該協調器會對兩個梯度求平均值,然後計算整個數據的梯度或叫全局梯度。服務器返回這個全局梯度給兩個客戶端,客戶端使用這個全局梯度來更新他們的b值或他們的模型。b的新值對每個客戶端都是一樣的,因為全局梯度是一樣的,計算出來的新b也應該是一樣的。這個過程如下圖所示。
從1(計算局部梯度)到4(下載全局梯度)的步驟不斷迭代,直到達到預定義的誤差水平。在這個示例中,我們只使用了兩個客戶端,但是它可以擴展到許多客戶端。
需要說明的是,我們是用局部梯度來估計全局梯度!
聯邦學習(FL)
如果我們利用每個客戶端的局部梯度來計算每個局部模型,或者在我們的例子中,b如下所示,會發生什麼?
在這個場景中,會以每個客戶端不同的b值結束,如上圖所示,我們稱之為本地模型。
如果我們這樣做,每個局部模型都會進行參數b的更新,這意味著不需要發送局部梯度。而是將局部模型的參數或者中間結果發送到服務器進行平均,然後得到全局模型。這是聯邦學習的主要思想。
FL系統通過重複以下過程來優化全局機器學習(ML)模型:
i)每個客戶端設備對其數據進行本地計算以最小化全局模型w。
ii)然後將其本地更新的模型發送到FL服務器進行聚合;
iii) FL服務器對接收到的局部模型進行聚合,生成改進的全局模型;
Iv),服務器將更新後的全局模型發送給客戶端設備,客戶端設備使用新的全局模型進行下一次的計算。
這個過程會不斷迭代,直到模型達到預定義的精度水平。這個過程如下圖所示。
聯邦學習vs分佈式SGD
在FL中使用模型權重,但在D-SGD中只使用梯度。在我們討論的例子中,在發送更新之前只進行了梯度下降的一個局部步驟。在這種情況下,FL相當於分佈式sgd。如果要進行多個步驟,需要使用FL發送模型權重。一般形式的FL的收斂分析(多個局部步驟)不同於我們所做的分佈式- sgd分析。但是原理都是差不多的。
我們在本文中描述的D-SGD算法(中心化D-SGD)和FL算法(FEDAVG)只是D-SGD和FL的眾多算法之一。
為什麼聯邦學習是有用的?
我們需要FL的主要原因是因為隱私。我們不希望將私人原始數據洩露給任何用於訓練機器學習模型的服務器。所以需要一種不需要從客戶端設備發送原始數據就可以訓練機器學習算法,這就是聯邦學習的作用。例如,谷歌利用FL來改進它的鍵盤應用程序(Gboard)。FL在不同的應用中有用還有其他原因。例如FL使系統能夠利用移動設備等本地計算,以減輕服務器的壓力。
聯邦學習的挑戰
我們可以將FL面臨的挑戰分為兩類。第一類是在運行FL流程之前的數據準備流程流程。這個的關鍵問題是,不能訪問原始數據,甚至不能訪問FL系統的設備。我們需要知道如何在不訪問設備的情況下設計模型或評估數據?
第二類的挑戰是運行FL流程時出現的問題。需要考慮到參與FL系統的客戶端資源是受限的,他們在發送或處理ML模型方面的能力有限,例如在本文的例子中,我們的參數只有b,傳輸完整的參數是可行的,但是如果模型很大,例如BERT,那麼我們不可能在客戶端和服務器之間傳輸幾個G的數據,這是不可能的。
總結
聯邦學習是一個建立在分佈式學習框架上的新興主題,它試圖解決現實應用程序中訓練ML模型的隱私問題。在本文中,我們只觸及了這些系統的表面,如果你想深入瞭解這方面的知識可以自己搜素相關的文章或者等待我們後續的相關文章。
https://avoid.overfit.cn/post/ea6d50f42f904c97b4fa299be0c389b5