首頁>技術>

在本教程中,我們將使用Flask來部署PyTorch模型,並用講解用於模型推斷的 REST API。特別是,我們將部署一個預訓練的DenseNet 121模型來檢測影象。

備註:可在https://github.com/avinassh/pytorch-flask-api上獲取本文用到的完整程式碼

這是在生產中部署PyTorch模型的系列教程中的第一篇。到目前為止,以這種方式使用Flask是開始為PyTorch模型提供服務的最簡單方法,但不適用於具有高效能要求的用例。因此:

如果您已經熟悉TorchScript,則可以直接進入我們的https://github.com/fendouai/PyTorchDocs/blob/master/EigthSection/torchScript_in_C%2B%2B.md教程。如果您首先需要複習TorchScript,請檢視我們的https://github.com/fendouai/PyTorchDocs/blob/master/EigthSection/torchScript.md教程。1.定義API

我們將首先定義API端點、請求和響應型別。我們的API端點將位於/ predict,它接受帶有包含影象的file引數的HTTP POST請求。響應將是包含預測的JSON響應:

{"class_id": "n02124075", "class_name": "Egyptian_cat"}
2.依賴(包)

執行下面的命令來下載我們需要的依賴:

$ pip install Flask==1.0.3 torchvision-0.3.0
3.簡單的Web伺服器

以下是一個簡單的Web伺服器,摘自Flask文件

from flask import Flaskapp = Flask(__name__)@app.route('/')def hello():    return 'Hello World!'

將以上程式碼段儲存在名為app.py的檔案中,您現在可以通過輸入以下內容來執行Flask開發伺服器:

$ FLASK_ENV=development FLASK_APP=app.py flask run

當您在web瀏覽器中訪問http://localhost:5000/時,您會收到文字Hello World的問候!

我們將對以上程式碼片段進行一些更改,以使其適合我們的API定義。首先,我們將重新命名predict方法。我們將端點路徑更新為/predict。由於影象檔案將通過HTTP POST請求傳送,因此我們將對其進行更新,使其也僅接受POST請求:

@app.route('/predict', methods=['POST'])def predict():    return 'Hello World!'

我們還將更改響應型別,以使其返回包含ImageNet類的id和name的JSON響應。更新後的app.py檔案現在為:

from flask import Flask, jsonifyapp = Flask(__name__)@app.route('/predict', methods=['POST'])def predict():    return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})
4.推理

在下一部分中,我們將重點介紹編寫推理程式碼。這將涉及兩部分,第一部分是準備影象,以便可以將其饋送到DenseNet;第二部分,我們將編寫程式碼以從模型中獲取實際的預測。

4.1 準備影象

DenseNet模型要求影象為尺寸為224 x 224的 3 通道RGB影象。我們還將使用所需的均值和標準偏差值對影象張量進行歸一化。你可以點選https://pytorch.org/docs/stable/torchvision/models.html來了解更多關於它的內容。

我們將使用來自torchvision庫的transforms來建立轉換管道,該轉換管道可根據需要轉換影象。您可以在https://pytorch.org/docs/stable/torchvision/transforms.html閱讀有關轉換的更多資訊。

import ioimport torchvision.transforms as transformsfrom PIL import Imagedef transform_image(image_bytes):    my_transforms = transforms.Compose([transforms.Resize(255),                                        transforms.CenterCrop(224),                                        transforms.ToTensor(),                                        transforms.Normalize(                                            [0.485, 0.456, 0.406],                                            [0.229, 0.224, 0.225])])    image = Image.open(io.BytesIO(image_bytes))    return my_transforms(image).unsqueeze(0)

上面的方法以位元組為單位獲取影象資料,應用一系列變換並返回張量。要測試上述方法,請以位元組模式讀取影象檔案(首先將../_static/img/sample_file.jpeg替換為計算機上檔案的實際路徑),然後檢視是否獲得了張量:

with open("../_static/img/sample_file.jpeg", 'rb') as f:    image_bytes = f.read()    tensor = transform_image(image_bytes=image_bytes)    print(tensor)
輸出結果:
tensor([[[[ 0.4508,  0.4166,  0.3994,  ..., -1.3473, -1.3302, -1.3473],          [ 0.5364,  0.4851,  0.4508,  ..., -1.2959, -1.3130, -1.3302],          [ 0.7077,  0.6392,  0.6049,  ..., -1.2959, -1.3302, -1.3644],          ...,          [ 1.3755,  1.3927,  1.4098,  ...,  1.1700,  1.3584,  1.6667],          [ 1.8893,  1.7694,  1.4440,  ...,  1.2899,  1.4783,  1.5468],          [ 1.6324,  1.8379,  1.8379,  ...,  1.4783,  1.7352,  1.4612]],         [[ 0.5728,  0.5378,  0.5203,  ..., -1.3704, -1.3529, -1.3529],          [ 0.6604,  0.6078,  0.5728,  ..., -1.3004, -1.3179, -1.3354],          [ 0.8529,  0.7654,  0.7304,  ..., -1.3004, -1.3354, -1.3704],          ...,          [ 1.4657,  1.4657,  1.4832,  ...,  1.3256,  1.5357,  1.8508],          [ 2.0084,  1.8683,  1.5182,  ...,  1.4657,  1.6583,  1.7283],          [ 1.7458,  1.9384,  1.9209,  ...,  1.6583,  1.9209,  1.6408]],         [[ 0.7228,  0.6879,  0.6531,  ..., -1.6476, -1.6302, -1.6476],          [ 0.8099,  0.7576,  0.7228,  ..., -1.6476, -1.6476, -1.6650],          [ 1.0017,  0.9145,  0.8797,  ..., -1.6476, -1.6650, -1.6999],          ...,          [ 1.6291,  1.6291,  1.6465,  ...,  1.6291,  1.8208,  2.1346],          [ 2.1868,  2.0300,  1.6814,  ...,  1.7685,  1.9428,  2.0125],          [ 1.9254,  2.0997,  2.0823,  ...,  1.9428,  2.2043,  1.9080]]]])
4.2 預測

現在將使用預訓練的DenseNet 121模型來預測影象的類別。我們將使用torchvision庫中的一個庫,載入模型並進行推斷。在此示例中,我們將使用預訓練的模型,但您可以對自己的模型使用相同的方法。在這個https://pytorch.org/tutorials/beginner/saving_loading_models.html中了解有關載入模型的更多資訊。

from torchvision import models# 確保使用`pretrained`作為`True`來使用預訓練的權重:model = models.densenet121(pretrained=True)# 由於我們僅將模型用於推理,因此請切換到“eval”模式:model.eval()def get_prediction(image_bytes):    tensor = transform_image(image_bytes=image_bytes)    outputs = model.forward(tensor)    _, y_hat = outputs.max(1)    return y_hat

張量y_hat將包含預測的類的id的索引。但是,我們需要一個易於閱讀的類名。為此,我們需要一個類id來命名對映。將https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json下載為imagenet_class_index.json並記住它的儲存位置(或者,如果您按照本教程中的確切步驟操作,請將其儲存在tutorials/_static中)。此檔案包含ImageNet類的id到ImageNet類的name的對映。我們將載入此JSON檔案並獲取預測索引的類的name。

import jsonimagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))def get_prediction(image_bytes):    tensor = transform_image(image_bytes=image_bytes)    outputs = model.forward(tensor)    _, y_hat = outputs.max(1)    predicted_idx = str(y_hat.item())    return imagenet_class_index[predicted_idx]

在使用字典imagenet_class_index之前,首先我們將張量值轉換為字串值,因為字典imagenet_class_index中的keys是字串。我們將測試上述方法:

with open("../_static/img/sample_file.jpeg", 'rb') as f:    image_bytes = f.read()    print(get_prediction(image_bytes=image_bytes))
輸出結果:
['n02124075', 'Egyptian_cat']

你會得到這樣的一個響應:

['n02124075', 'Egyptian_cat']

陣列中的第一項是ImageNet類的id,第二項是人類可讀的name。

注意:您是否注意到模型變數不是get_prediction方法的一部分?或者為什麼模型是全域性變數?就記憶體和計算而言,載入模型可能是一項昂貴的操作。如果將模型載入到get_prediction方法中,則每次呼叫該方法時都會不必要地載入該模型。由於我們正在構建Web服務器,因此每秒可能有成千上萬的請求,因此我們不應該浪費時間為每個推斷重複載入模型。因此,我們僅將模型載入到記憶體中一次。在生產系統中,必須有效利用計算以能夠大規模處理請求,因此通常應在處理請求之前載入模型。

5.將模型整合到我們的API伺服器中

在最後一部分中,我們將模型新增到Flask API伺服器中。由於我們的API伺服器應該獲取影象檔案,因此我們將更新predict方法以從請求中讀取檔案:

from flask import [email protected]('/predict', methods=['POST'])def predict():    if request.method == 'POST':        # 從請求中獲得檔案        file = request.files['file']        # 轉化為位元組        bytes = file.read()        class_id, class_name = get_prediction(image_bytes=img_bytes)        return jsonify({'class_id': class_id, 'class_name': class_name})

app.py檔案現已完成。以下是完整版本;將路徑替換為儲存檔案的路徑,它的執行應是如下:

import ioimport jsonfrom torchvision import modelsimport torchvision.transforms as transformsfrom PIL import Imagefrom flask import Flask, jsonify, requestapp = Flask(__name__)imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))model = models.densenet121(pretrained=True)model.eval()def transform_image(image_bytes):    my_transforms = transforms.Compose([transforms.Resize(255),                                        transforms.CenterCrop(224),                                        transforms.ToTensor(),                                        transforms.Normalize(                                            [0.485, 0.456, 0.406],                                            [0.229, 0.224, 0.225])])    image = Image.open(io.BytesIO(image_bytes))    return my_transforms(image).unsqueeze(0)def get_prediction(image_bytes):    tensor = transform_image(image_bytes=image_bytes)    outputs = model.forward(tensor)    _, y_hat = outputs.max(1)    predicted_idx = str(y_hat.item())    return imagenet_class_index[predicted_idx]@app.route('/predict', methods=['POST'])def predict():    if request.method == 'POST':        file = request.files['file']        bytes = file.read()        class_id, class_name = get_prediction(image_bytes=img_bytes)        return jsonify({'class_id': class_id, 'class_name': class_name})if __name__ == '__main__':    app.run()

讓我們測試一下我們的web伺服器,執行:

$ FLASK_ENV=development FLASK_APP=app.py flask run

我們可以使用https://pypi.org/project/requests/庫來發送一個POST請求到我們的app:

import requestsresp = requests.post("http://localhost:5000/predict",                     files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})

列印resp.json()會顯示下面的結果:

{"class_id": "n02124075", "class_name": "Egyptian_cat"}
6.下一步工作

我們編寫的伺服器非常瑣碎,可能無法完成生產應用程式所需的一切。因此,您可以採取一些措施來改善它:

端點/predict假定請求中總會有一個影象檔案。這可能不適用於所有請求。我們的使用者可能傳送帶有其他引數的影象,或者根本不傳送任何影象。使用者也可以傳送非影象型別的檔案。由於我們沒有處理錯誤,因此這將破壞我們的伺服器。新增顯式的錯誤處理路徑來引發異常,這將使我們能夠更好地處理錯誤的輸入即使模型可以識別大量類別的影象,也可能無法識別所有影象。增強實現以處理模型無法識別影象中的任何情況的情況。我們在開發模式下執行Flask伺服器,該伺服器不適合在生產中進行部署。您可以檢視https://flask.palletsprojects.com/en/1.1.x/tutorial/deploy/以在生產環境中部署Flask伺服器。您還可以通過建立一個帶有表單的頁面來新增UI,該表單可以拍攝影象並顯示預測。檢視類似https://pytorch-imagenet.herokuapp.com/的演示及其https://github.com/avinassh/pytorch-flask-api-heroku。在本教程中,我們僅展示了如何構建可以一次返回單個影象預測的服務。我們可以修改服務以能夠一次返回多個影象的預測。此外,https://github.com/ShannonAI/service-streamer庫自動將對服務的請求排隊,並將它們取樣到可用於模型的min-batches中。您可以檢視https://github.com/ShannonAI/service-streamer/wiki/Vision-Recognition-Service-with-Flask-and-service-streamer。最後,我們鼓勵您在頁面頂部檢視連結到的有關部署PyTorch模型的其他教程。

最新評論
  • BSA-TRITC(10mg/ml) TRITC-BSA 牛血清白蛋白改性標記羅丹明
  • 2020必學必備語言Python,來了解下他的優缺點