由TensorFlow團隊的Goldie Gadde和Nikita Namjoshi釋出
TF 2.4在這裡!隨著對分散式培訓和混合精度的更多支援,新的NumPy前端以及用於監視和診斷瓶頸的工具,此版本的全部內容都涉及新功能以及效能和擴充套件方面的增強。
tf.distribute中的新功能引數伺服器策略
在2.4中,該tf.distribute模組引入了對使用ParameterServerStrategy和自定義訓練迴圈對模型進行非同步訓練的實驗支援。像一樣MultiWorkerMirroredStrategy,ParameterServerStrategy是一種多員工資料並行化策略;但是,漸變更新是非同步的。
引數伺服器培訓叢集由工作伺服器和引數伺服器組成。變數在引數伺服器上建立,然後在每個步驟中由工作人員讀取和更新。變數的讀取和更新在整個工作程序中獨立發生,而沒有任何同步。由於工作程序彼此不依賴,因此該策略具有工作程式容錯的優勢,並且在使用可搶佔式VM時很有用。
要開始使用此策略,請檢視Parameter Server Training教程。本教程向您展示如何設定ParameterServerStrategy並定義培訓步驟,並說明如何使用ClusterCoordinator類將培訓步驟的執行分派給遠端工作人員。
多工映象策略
MultiWorkerMirroredStrategy已經退出實驗階段,現在已成為穩定API的一部分。像它的一個工人對口MirroredStrategy,MultiWorkerMirroredStrategy機具分佈與同步資料並行訓練。但是,顧名思義,MultiWorkerMirroredStrategy您可以跨多臺機器進行訓練,每臺機器可能都具有多個GPU。
在同步訓練中,每個工作人員在輸入資料的不同片段上計算前進和後退透過次數,並在更新模型之前彙總梯度。對於這種聚合,稱為all-reduce,MultiWorkerMirroredStrategy使用CollectiveOps使變數保持同步。集合運算是TensorFlow圖中的單個運算,可以根據硬體,網路拓撲和張量大小在TensorFlow執行時中自動選擇全約演算法。
要開始使用MultiWorkerMirroredStrategy,請檢視“使用Keras進行多工作人員培訓”教程,該教程已更新,其中包含有關資料集分片,儲存/載入透過分配策略訓練的模型以及透過BackupAndRestore回撥進行故障恢復的詳細資訊。
如果您不熟悉分散式培訓,並且想學習入門,或者對GCP分散式培訓感興趣,請參閱此部落格文章,以獲取有關關鍵概念和步驟的介紹。
Keras更新混合精度
在TensorFlow 2.4中,Keras混合精度API 已脫離實驗性,現已成為穩定的API。大多數TensorFlow模型使用float32 dtype; 但是,有些低精度型別(例如float16)使用的記憶體更少。混合精度是在同一模型中使用16位和32位浮點型別以進行更快的訓練。該API可以將模型效能在GPU上提高3倍,在TPU上提高60%。要使用混合精度API,必須使用Keras層和最佳化器,但是沒有必要使用其他Keras類,例如模型或損失。如果您想學習如何利用此API以獲得更好的效能,請檢視《混合精度》教程。
最佳化器
此版本包括重構tf.keras.optimizers.Optimizer 類,使model.fit或定製培訓迴圈的使用者能夠編寫可與任何最佳化程式一起使用的培訓程式碼。現在, 所有內建tf.keras.optimizer.Optimizer子類都接受gradient_transformers和gradient_aggregator引數,從而使您可以輕鬆定義自定義漸變變換。
透過重構,您現在可以Optimizer.minimize在編寫自定義訓練迴圈時直接將損失張量傳遞給:
tape = tf.GradientTape()with tape: y_pred = model(x, training=True) loss = loss_fn(y_pred, y_true)# You can pass in the `tf.GradientTape` when using a loss `Tensor` as shown below.optimizer.minimize(loss, model.trainable_variables, tape=tape)
這些更改旨在使Model.fit定製培訓迴圈和定製培訓迴圈均與最佳化器細節無關,從而使您無需修改即可編寫可與任何最佳化器一起使用的培訓程式碼。
功能性API模型構建的內部改進
最後,TensorFlow 2.4包括Keras Functional API內部的主要重構,從而改善了功能模型構建的記憶體消耗並簡化了觸發邏輯。這種重構還可以確保TensorFlowOpLayers行為可預測並可以使用CompositeTensor型別簽名。
介紹tf.experimental.numpyTensorFlow 2.4引入了對NumPy API子集的實驗支援,可透過以下方式獲得tf.experimental.numpy。該模組使您可以執行由TensorFlow加速的NumPy程式碼。由於該API基於TensorFlow構建,因此可與TensorFlow無縫互操作,從而允許訪問所有TensorFlow API,並透過編譯和自動向量化提供最佳化的執行。例如,TensorFlow ND陣列可以與NumPy函式互操作,並且類似地,TensorFlow NumPy函式可以接受包括tf.Tensor和在內的不同型別的輸入np.ndarray。
import tensorflow.experimental.numpy as tnp# Use NumPy code in input pipelinesdataset = tf.data.Dataset.from_tensor_slices( tnp.random.randn(1000, 1024)).map( lambda z: z.clip(-1,1)).batch(100)# Compute gradients through NumPy codedef grad(x, wt): with tf.GradientTape() as tape: tape.watch(wt) output = tnp.dot(x, wt) output = tf.sigmoid(output) return tape.gradient(tnp.sum(output), wt)
您可以在TensorFlow上的NumPy API指南中瞭解有關如何使用此API的更多資訊。
新的探查器工具TensorFlow Profiler中的MultiWorker支援
該TensorFlow探查是你可以用它來衡量你的TensorFlow模型的訓練表現和資源消耗的工具套件。TensorFlow Profiler可幫助您瞭解模型中操作的硬體資源消耗,診斷瓶頸並最終更快地進行訓練。
之前,TensorFlow Profiler支援監視多GPU,單主機培訓作業。在2.4中,您現在可以分析MultiWorkerMirroredStrategy培訓工作。例如,您可以使用取樣模式API執行按需配置,並連線到MultiWorkerMirroredStrategy工作人員正在使用的同一伺服器:埠:
# Start a profiler server before your model runs.tf.profiler.experimental.server.start(6009)# Model code goes here.... # E.g. your worker IP addresses are 10.0.0.2, 10.0.0.3, 10.0.0.4, and you# would like to profile for a duration of 2 seconds. The profiling data will# be saved to the Google Cloud Storage path “your_tb_logdir”. tf.profiler.experimental.client.trace( 'grpc://10.0.0.2:6009,grpc://10.0.0.3:6009,grpc://10.0.0.4:6009', 'gs://your_tb_logdir', 2000)
另外,您可以透過向捕獲配置檔案工具提供工作人員地址來使用TensorBoard配置檔案外掛。
分析後,您可以使用新的Pod Viewer工具選擇培訓步驟,並檢視所有工作人員的培訓時間類別細分。
有關如何使用TensorFlow Profiler的更多資訊,請檢視新發布的GPU效能指南。本指南顯示了在配置模型訓練工作時可能遇到的常見情況,並提供了除錯工作流以幫助您獲得更好的效能,無論您是使用一個GPU,多個GPU還是多個機器進行訓練。
TFLite Profiler
TFLite Profiler支援在Android中跟蹤TFLite內部,以識別效能瓶頸。《TFLite效能評估指南》向您展示瞭如何使用Android Studio CPU Profiler和“系統跟蹤”應用程式新增跟蹤事件,啟用TFLite跟蹤以及捕獲跟蹤。
使用Android系統跟蹤應用程式的示例跟蹤
GPU支援的新功能TensorFlow 2.4與CUDA 11和cuDNN 8一起執行,從而支援最新可用的NVIDIA Ampere GPU架構。要了解有關CUDA 11功能的更多資訊,請檢視此NVIDIA開發人員部落格。