首頁>技術>

0. 簡介

pytorch lightning透過提供LightningModule和LightningDataModule,使得在用pytorch編寫網路模型時,載入資料、分割資料集、訓練、驗證、測試、計算指標的程式碼全部都能很好地組織起來,顯得主程式呼叫時,程式碼簡潔可讀性大幅度提升。

1. pytorch lightning的安裝
pip install pytorch-lightningconda install pytorch-lightning -c conda-forge
2. 定義一個網路模型模型:LightningModule

透過繼承LightningModule,並實現幾個關鍵的函式,使得模型在訓練、驗證和測試過程中能進行模組化呼叫,具體細節完全被自定義的函式封裝,整體十分簡潔。定義一個LightningModule的基類,可以實現的函式如下:

from pytorch_lightning import LightningModuleclass MyModel(LightningModule):    """    The only required methods in the LightningModule are:    init    training_step    configure_optimizers    """    def __init__(self, *args, **kwargs): pass    def forward(self, *args, **kwargs): pass    def training_step(self, batch, batch_idx, optimizer_idx, hiddens): pass    def training_step_end(self, *args, **kwargs): pass # 接受train_step的返回值    def training_epoch_end(self, outputs): pass # 接受train_step一整個epoch的返回值的列表    def validation_step(self, batch, batch_idx, dataloader_idx): pass # model.eval() and torch.no_grad() are called automatically    def validation_step_end(self, *args, **kwargs): pass # 接受validation_step的返回值    def validation_epoch_end(self, outputs)    def test_step(self, batch, batch_idx, dataloader_idx): pass # model.eval() and torch.no_grad() are called automatically    def test_step_end(self, *args, **kwargs): pass  # 接收test_step的返回值    def test_epoch_end(self, outputs): pass    def configure_optimizers(self, ): pass    def any_extra_hook(...): pass  #  指代任意其他的可過載函式

其中,必須實現的函式只有__init__() 、training_step()、configure_optimizers()。

3. 定義一個數據模型:LightningDataModule

透過定義LightningDataModule的子類,資料集分割、載入的程式碼將整合在一起,可以實現的方法有:

class MyDataModule(LightningDataModule):    def __init__(self):        super().__init__()    def prepare_data(self):        # download, split, etc...        # only called on 1 GPU/TPU in distributed    def setup(self,stage:str):  # stage: "fit", "test", 【暫時不知道驗證步驟叫什麼名字,可以自己列印一下】        # make assignments here (val/train/test split)        # called on every process in DDP    def train_dataloader(self):        train_split = Dataset(...)        return DataLoader(train_split)    def val_dataloader(self):        val_split = Dataset(...)        return DataLoader(val_split)    def test_dataloader(self):        test_split = Dataset(...)        return DataLoader(test_split)
4. 使用pytorch lightning的API開始訓練
def main():    model = MyModule()    data_module = MyDataModule()    trainer = pytorch_lightning.Trainer(...)  # some arugments, 根據需要傳入你的引數    trainer.fit(module, datamodule=data_module)    trainer.test(module, datamodule=data_module, verbose=True) if __name__ == "__main__":    main()

具體實現都透過類封裝之後,主函式就顯得簡潔多了。

17
最新評論
  • BSA-TRITC(10mg/ml) TRITC-BSA 牛血清白蛋白改性標記羅丹明
  • node-exporter安裝