Try   HackMD

前言

在「從零開始學Flask」中學到了Flask的基礎知識,而在這個教學中我們將會應用這些知識。我學習Flask最主要的原因就是希望能夠將深度學習模型與API串接在一起,來達到更多的應用。這也是現在許多公司認為應該要有的技能,因此這個教學中我們會學習將深度學習模型串接API,來建立一個用來辨識的API。

我是根據Deploying Deep Learning Model using Flask API這篇文章進行修改。由於我習慣用pytorch,但文章內是使用tensorflow,所以我對一些部分進行修改。如果想使用tensorflow可查閱上面那篇文章。

說明

建立一個ImageNet的辨識API,模型使用pytorch預訓練好的ResNet50,如下圖所示。

  • 執行app.py後,會跳出一個初始網頁
  • 上傳照片進行辨識
  • 顯示辨識結果

初始網頁

初始網頁
辨識結果
辨識結果

About the Model

我是使用PyTorch透過ImageNet預訓練好的ResNet50,這個模型的辨識類別共有1000類,詳細的說明可查閱相關文件。由於是使用PyTorch提供的模型,所以可以直接在程式中載入。

Project Structure

Project Structure

  • static資料夾中用來放辨識影像
  • templates資料用來放所有html檔
  • app.py主要API程式
  • ImageNet類別檔

接下來我們會說明app.py,再來說明home.htmlpredict.html

app. py

app.py是主要用來建立辨識API的檔案,其中包含一些我們會用到的methods:

  • read_image(): 用來讀取影像,並進行影像前處理。其中包含了torch.stack,用來將tensor大小轉為[batch, 3, 244, 244]

  • allow_file(): 用來確認影像的副檔名是否為jpg, jpeg, png。

  • homeapge(): 用來呼叫初始頁面(home.html),其url設定為/

  • predict(): 用來辨識影像的類別,並將辨識結果傳至predict.html

from flask import Flask, render_template, request import torch from torchvision import transforms from torchvision.models import resnet50 from PIL import Image import os # 固定格式 app = Flask(__name__) # 載入resnet50 model = resnet50(pretrained=True) # 由於是直接拿來辨識,所以要轉成eval()模式 model.eval() transform = transforms.Compose([ transforms.Resize([224,224]), transforms.ToTensor(), transforms.Normalize(mean = (0.5,0.5,0.5), std = (0.5,0.5,0.5))]) # 將label轉成list with open('./imagenet-classes.txt', 'r') as f: labels = [line.strip() for line in f.readlines()] # 讀取影像 def read_image(filename): img = Image.open(filename) img = transform(img) # 轉成[1,3,244,244] img_stack = torch.stack([img]) return img_stack # allow files with png, jpg, jpeg ALLOW_EXT = set(['jpg', 'jpeg', 'png']) def allow_file(filename): return '.' in filename and \ filename.rsplit('.',1)[1] in ALLOW_EXT # 最初始頁面 @app.route('/') def homeapge(): return render_template('home.html') # 進行辨識 @app.route('/predict', methods = ['GET', 'POST']) def predict(): if request.method == 'POST': file = request.files['file'] # 判斷副檔名 if file and allow_file(file.filename): filename = file.filename # 將影像儲存 file_path = os.path.join('./static/images', filename) file.save(file_path) # 讀取影像 img = read_image(file_path) # 進行辨識 output = model(img) # 取得機率值最大的index _, preds = torch.max(output.data, 1) # 根據index從label list中取得類別名稱 classes = labels[preds[0]] return render_template('predict.html', classes=classes, user_image=file_path) else: return "Unable to read the file. Please check file extension." # 固定格式 if __name__ == '__main__': app.run(debug=True)

第15-19行是用來進行影像前處理的方法,包含Resize,ToTensor,Normalize
第21-22行是將imagenet-classes.txt中的類別轉成list,以便之後讀取類別名稱。
第62行會將classesfile_path傳到predict.html,用於之後顯示於網頁上。

home.html

用來呈現我們辨識API的初始頁面,我們會接收到從user中上傳的input,也就是要辨識的影像。

<html> <head> <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.4.1/css/bootstrap.min.css"> <title>Chris's ImageNet Prediction</title> </head> <body> <div class="container"> <div class="row"> <div class="col-lg-12"> <center><h1 style="padding:30px;font-weight:600 !important;font-size: 3.2em">IMAGENET PREDICTION</h1></center> </div> <div> <div class="col-lg-12"> <center> <form class="form-inline" action="/predict" method="post" enctype="multipart/form-data"> <input type="file" name="file"/><br> <input type="submit" class="btn btn-success" value="Predict"> </form> </center> </div> </div> </div> </div> </body> </html>

第16-19行會將上傳的影像傳到app.py中的predict()進行處理及辨識。

predict.html

這是辨識API的result page。當在初始頁面按下predict按鈕時,會將影像傳到app.py中的predict()進行處理及辨識。最後將結果呈現在這個頁面。

<!DOCTYPE html> <html> <head> <!--Require meta tags--> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no"> <link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.4.1/css/bootstrap.min.css"> <title>Chris's ImageNet Prediction</title> </head> <body> <div class="container"> <div class="col-lg-12"> <center><h1 style="padding: 30px;font-weight: 600 !important;font-size: 3.2em;">ImageNet Prediction Result</h1></center> </div> <div class="col-lg-12"> <center><form class="form-inline" action="/predict" method="post" enctype="multipart/form-data"> <input type="file" name="file"/><br> <input type="submit" class="btn btn-success" value="Predict"> </form></center> </div> <div class="col-lg-5" style="padding-top: 30px"> <span class="border border-primary"> <img src="{{ user_image }}" alt="User Image" class="img-thumbnail" style="width: 250px;height: 250px;float: right"> </span> </div> <div class="col-lg-5" style="padding-top: 30px"> <h4>Class Name is <mark style="background-color:#04aa6d;color:white;border-radius:5px">{{ classes }}</mark></h4> </div> </div> </body> </html>

第16-19行同樣會將上傳的影像傳到app.py中的predict()進行處理及辨識。
第21-25行用來接收從app.py中的predict()傳來的file_path,並進行呈現。
第26-28行用來接收從app.py中的predict()傳來的classes,並進行呈現。

Output Screen

執行app.py之後,在terminal或cmd會得到這個output

result

接著複製url(http://127.0.0.1:5000),並貼上瀏覽器。接著將影像上傳至初始頁面,並按下predict按鈕。模型就會開始進行辨識,最後將結果呈現在頁面上囉!

Source Code Please Visit