--- tags: Python --- # 使用YOLO辨識K線圖 * [安裝GIT](https://git-scm.com/) * [安裝Pytorch](https://pytorch.org/get-started/locally/) * [安裝CUDA(選擇性)](https://developer.nvidia.com/zh-cn/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=10&target_type=exelocal) * pip install [mplfinance](https://github.com/matplotlib/mplfinance) * `git clone https://github.com/ultralytics/yolov5` * `cd yolov5` * `pip install -r requirements.txt` * 新增`yolov5/train.yaml` ```yaml= train: train/ val: train/ # number of classes nc: 2 # class names names: ['sell','buy'] ``` ## 自動產生資料集 ### 引入套件 ```python= import os import sqlite3 import numpy as np import pandas as pd import matplotlib.pyplot as plt import matplotlib.patches as patches import matplotlib.dates as mpl_dates from mplfinance.original_flavor import candlestick_ohlc ``` ### 建立資料夾 ```python= os.makedirs('yolov5/train/images', exist_ok=True) os.makedirs('yolov5/train/labels', exist_ok=True) ``` ### 讀取資料庫 取得資料庫請參考[分析系統建置/步驟一](https://hackmd.io/dBWiFw-qQIiwGqx5QGrVyQ?view) ```python= symbol = 2330 with sqlite3.connect('flaskr.db') as con: sql = f'SELECT Date,Open,High,Low,Close FROM TW{symbol};' df = pd.read_sql(sql, con) df['Date'] = pd.to_datetime(df['Date']).apply(mpl_dates.date2num) ``` ### 技術指標與交叉點 ```python= ROLL = 1 # 框框半徑 # 技術指標 DATE,O,H,L,C = df['Date'],df['Open'],df['High'],df['Low'],df['Close'] FAST = C.rolling(12).mean().values SLOW = C.rolling(26).mean().values # 交叉點 rects = [] pos = FAST[ROLL]>SLOW[ROLL] for i in range(ROLL,len(SLOW)): curPos = FAST[i]>SLOW[i] if curPos != pos: arange = range(i-ROLL,i+ROLL) x = DATE[i-ROLL] y = min(*FAST[arange],*SLOW[arange]) w = 2*ROLL h = max(*FAST[arange],*SLOW[arange])-y label = 1 if curPos else 0 rects.append((i,label,x,y,w,h)) pos = curPos ``` ### 繪圖 ```python= STRIDE = 30 # 時窗 DOHLC = df[['Date', 'Open', 'High', 'Low', 'Close']].values for j in range(0,len(df),STRIDE)[1:]: range_rects = [r for r in rects if j<r[0] and r[0]<j+STRIDE] # 跳過沒有標記的 if not range_rects: continue # 畫蠟燭線、技術折線,圖片尺寸=fig.get_size_inches()*fig.dpi fig, ax = plt.subplots() candlestick_ohlc(ax, DOHLC[j:j+STRIDE,:], width=0.6, colorup='green', colordown='red', alpha=1) ax.plot(DATE[j:j+STRIDE],FAST[j:j+STRIDE]) ax.plot(DATE[j:j+STRIDE],SLOW[j:j+STRIDE]) # 畫方框 for i,label,x,y,w,h in range_rects: patch = patches.Rectangle((x,y), w, h, linewidth=1, edgecolor='r', facecolor='none') ax.add_patch(patch) # 標題刻度 ax.xaxis.set_major_formatter(mpl_dates.DateFormatter('%Y-%m-%d')) fig.autofmt_xdate() plt.axis('off') fig.tight_layout() # 儲存圖片、標籤 date_str = mpl_dates.num2date(x).strftime('%Y-%m-%SLOW') plt.savefig(f'yolov5/train/images/{symbol}-{date_str}.jpg') ``` ### YOLO圖片格式 ![](https://user-images.githubusercontent.com/26833433/91506361-c7965000-e886-11ea-8291-c72b98c25eec.jpg) ```python= # 將左下至右上的數值,轉換成由左上至右下的圖片的pixel絕對座標,同時正規化成[0,1] def yolo_formatter(x,y,w,h,fig,ax): width, height = fig.canvas.get_width_height() left, bottom = ax.transData.transform([x,y]) right, top = ax.transData.transform([x+w,y+h]) w = round((right-left)/width,2) h = round((top-bottom)/height,2) cx = round((right+left)/2/width,2) cy = round(1-(top+bottom)/2/height,2) return cx,cy,w,h ``` ### 產生標籤檔 ```python= # 寫入標籤文字檔 with open(f'yolov5/train/labels/{symbol}-{date_str}.txt','w') as f: text = '' for i,label,x,y,w,h in range_rects: cx,cy,w,h = yolo_formatter(x,y,w,h,fig,ax) text += f'{label} {cx} {cy} {w} {h}\n' f.write(text) # 最後記得清空繪圖記憶體 plt.close() ``` 檔案`2330-2018-10-01.txt`內容: ``` 1 0.24 0.27 0.04 0.03 0 0.36 0.33 0.04 0.06 1 0.84 0.5 0.04 0.01 ``` ## 訓練、偵測 ### 模型選擇 ![](https://user-images.githubusercontent.com/26833433/103595982-ab986000-4eb1-11eb-8c57-4726261b0a88.png) ### 超參數設定 ```yaml= lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3) lrf: 0.2 # final OneCycleLR learning rate (lr0 * lrf) momentum: 0.937 # SGD momentum/Adam beta1 weight_decay: 0.0005 # optimizer weight decay 5e-4 warmup_epochs: 3.0 # warmup epochs (fractions ok) warmup_momentum: 0.8 # warmup initial momentum warmup_bias_lr: 0.1 # warmup initial bias lr box: 0.05 # box loss gain cls: 0.5 # cls loss gain cls_pw: 1.0 # cls BCELoss positive_weight obj: 1.0 # obj loss gain (scale with pixels) obj_pw: 1.0 # obj BCELoss positive_weight iou_t: 0.20 # IoU training threshold anchor_t: 4.0 # anchor-multiple threshold # anchors: 3 # anchors per output layer (0 to ignore) fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5) hsv_h: 0.015 # image HSV-Hue augmentation (fraction) hsv_s: 0.7 # image HSV-Saturation augmentation (fraction) hsv_v: 0.4 # image HSV-Value augmentation (fraction) degrees: 0.0 # image rotation (+/- deg) translate: 0.1 # image translation (+/- fraction) scale: 0.5 # image scale (+/- gain) shear: 0.0 # image shear (+/- deg) perspective: 0.0 # image perspective (+/- fraction), range 0-0.001 flipud: 0.0 # image flip up-down (probability) fliplr: 0.5 # image flip left-right (probability) mosaic: 1.0 # image mosaic (probability) mixup: 0.0 # image mixup (probability) ``` ### 訓練 * **資料數量**:每個類別≥1.5k圖像 * **每類實例**:每類總計≥10,000個實例(帶標籤的對象) * **圖像種類**:必須代表所部署的環境。對於現實世界中的例,建議使用一天中不同時間,不同季節,不同天氣,不同照明,不同角度,不同來源(網路圖片、本地收集、不同相機)等圖像。 * **標籤一致性**:所有圖像中所有類的所有實例都必須標記,部分標籤將不起作用。 * **標籤精度**:標籤必須緊密包圍每個對象,對象與其邊界框之間不應存在任何空間,任何物體都不應缺少標籤。 * **背景圖像**:背景圖像是沒有對象的圖像,這些圖像被添加到數據集中以減少誤報(FP),建議使用大約0-10%的背景圖片來幫助減少FP(COCO提供1000張背景圖片作為參考,佔總數的1%)。 * **Epochs**:起手300,若未過度擬和,可訓練至600,1200以上 * **Image size**:預設`--img 640` * **Batch size**:應避免使用較小batch,否則會產生不準確的統計數據 * 使用[Google Cloud訓練](https://github.com/ultralytics/yolov5/wiki/GCP-Quickstart) * 使用[AWS訓練](https://github.com/ultralytics/yolov5/wiki/AWS-Quickstart) 官方指令 `python train.py --img 640 --batch 16 --epochs 5 --data coco128.yaml --weights yolov5s.pt` 個人指令(workers核心、evolve視覺化超參數調校) `python train.py --batch 16 --epochs 5 --data train.yaml --workers 0 --evolve` 繪製訓練結果 ```python= import sys sys.path.append('.') from yolov5.utils.plots import plot_results plot_results(save_dir='yolov5/runs/train/exp') ``` * `hyp.yaml` * `opt.yaml` * `results.png` * `test_batch*.jpg` ![](https://i.imgur.com/zihccmr.jpg) ### 其他偵測指令 使用攝像頭(0代表webcam) `python detect.py --source 0 --weights yolov5s.pt --conf 0.25` 使用glob批次偵測 `python detect.py --source path/*.jpg`