0. 簡介
pytorch lightning透過提供LightningModule和LightningDataModule,使得在用pytorch編寫網路模型時,載入資料、分割資料集、訓練、驗證、測試、計算指標的程式碼全部都能很好地組織起來,顯得主程式呼叫時,程式碼簡潔可讀性大幅度提升。
1. pytorch lightning的安裝pip install pytorch-lightningconda install pytorch-lightning -c conda-forge
2. 定義一個網路模型模型:LightningModule透過繼承LightningModule,並實現幾個關鍵的函式,使得模型在訓練、驗證和測試過程中能進行模組化呼叫,具體細節完全被自定義的函式封裝,整體十分簡潔。定義一個LightningModule的基類,可以實現的函式如下:
from pytorch_lightning import LightningModuleclass MyModel(LightningModule): """ The only required methods in the LightningModule are: init training_step configure_optimizers """ def __init__(self, *args, **kwargs): pass def forward(self, *args, **kwargs): pass def training_step(self, batch, batch_idx, optimizer_idx, hiddens): pass def training_step_end(self, *args, **kwargs): pass # 接受train_step的返回值 def training_epoch_end(self, outputs): pass # 接受train_step一整個epoch的返回值的列表 def validation_step(self, batch, batch_idx, dataloader_idx): pass # model.eval() and torch.no_grad() are called automatically def validation_step_end(self, *args, **kwargs): pass # 接受validation_step的返回值 def validation_epoch_end(self, outputs) def test_step(self, batch, batch_idx, dataloader_idx): pass # model.eval() and torch.no_grad() are called automatically def test_step_end(self, *args, **kwargs): pass # 接收test_step的返回值 def test_epoch_end(self, outputs): pass def configure_optimizers(self, ): pass def any_extra_hook(...): pass # 指代任意其他的可過載函式
其中,必須實現的函式只有__init__() 、training_step()、configure_optimizers()。
3. 定義一個數據模型:LightningDataModule透過定義LightningDataModule的子類,資料集分割、載入的程式碼將整合在一起,可以實現的方法有:
class MyDataModule(LightningDataModule): def __init__(self): super().__init__() def prepare_data(self): # download, split, etc... # only called on 1 GPU/TPU in distributed def setup(self,stage:str): # stage: "fit", "test", 【暫時不知道驗證步驟叫什麼名字,可以自己列印一下】 # make assignments here (val/train/test split) # called on every process in DDP def train_dataloader(self): train_split = Dataset(...) return DataLoader(train_split) def val_dataloader(self): val_split = Dataset(...) return DataLoader(val_split) def test_dataloader(self): test_split = Dataset(...) return DataLoader(test_split)
4. 使用pytorch lightning的API開始訓練def main(): model = MyModule() data_module = MyDataModule() trainer = pytorch_lightning.Trainer(...) # some arugments, 根據需要傳入你的引數 trainer.fit(module, datamodule=data_module) trainer.test(module, datamodule=data_module, verbose=True) if __name__ == "__main__": main()
具體實現都透過類封裝之後,主函式就顯得簡潔多了。
最新評論