Try   HackMD

Week 7:Character Region Awareness for Text Detection (CRAFT)

tags: 技術研討

議程

  • 論文導讀 (一 ~ 四) (40 分)
  • 其他補充 (五) (10 分)
  • 程式碼 (六) (40 分)
  • 大家的提問 (20 分)

名詞定義

  • character:字元
  • word (instance):由字元組成的字詞

一、前言

text detection 有很多種方法:

1. 用標籤方式來區分

(1) word-level

  • 優點
    • 標註成本低
  • 缺點
    • 字詞的範圍較難定義
      e.g., machine learning (有空格)和 機器學習 (無空格) 都是一個字詞
    • 很難處理彎曲、變形、過長的文字

(2) character-level

  • 優點
    • 先偵測字元,再將字元組成字詞,故可以處理各種形狀的文字
  • 缺點
    • 標註成本高

2. 用歷史演進來區分

(1) 深度學習出現前:使用傳統 computer vision 的招式

  • MSER (Maximally Stable Extremal Regions) (最大穩定極值區域)
    • 先將圖二值化,然後不斷調整 threshold,在比較寬的灰度閾值範圍內保持形狀穩定的區域就是 MSER
    • OpenCV MSER 算法介绍
    • OpenCV MSER 程式碼
    • 存摺封面測試圖樣
      Image Not Showing Possible Reasons
      • The image file may be corrupted
      • The server hosting the image is unavailable
      • The image path is incorrect
      • The image format is not supported
      Learn More →

      Image Not Showing Possible Reasons
      • The image file may be corrupted
      • The server hosting the image is unavailable
      • The image path is incorrect
      • The image format is not supported
      Learn More →
    • 存摺內頁測試圖樣
      Image Not Showing Possible Reasons
      • The image file may be corrupted
      • The server hosting the image is unavailable
      • The image path is incorrect
      • The image format is not supported
      Learn More →

      Image Not Showing Possible Reasons
      • The image file may be corrupted
      • The server hosting the image is unavailable
      • The image path is incorrect
      • The image format is not supported
      Learn More →
  • SWT (Stroke Width Transform) (筆畫寬度變換)
    • 先找到字的輪廓,接著用相似的筆畫寬度來決定分組
    • SWT 算法介紹
      Image Not Showing Possible Reasons
      • The image file may be corrupted
      • The server hosting the image is unavailable
      • The image path is incorrect
      • The image format is not supported
      Learn More →

(2) 深度學習出現後:當然就是 CNN 類的 model,各種變形

  • SSD (Single Shot MultiBox Detector)
  • Faster R-CNN
  • FCN (Fully Convolutional Networks)

3. 用訓練方法來區分

(1) Regression-based text detectors

Image Not Showing Possible Reasons
  • The image file may be corrupted
  • The server hosting the image is unavailable
  • The image path is incorrect
  • The image format is not supported
Learn More →

(綠框:Ground Truth、紅框:region proposal)

(2) Segmentation-based text detectors

Image Not Showing Possible Reasons
  • The image file may be corrupted
  • The server hosting the image is unavailable
  • The image path is incorrect
  • The image format is not supported
Learn More →

  • 用 pixel-level 偵測字詞區域
  • 相關論文

(3) End-to-end text detectors

  • 將定位和辨識一起訓練,除了速度較快、模型較小外,也能根據辨識結果提升定位準確率

  • 相關論文1:FOTS: Fast Oriented Text Spotting with a Unified Network

    • 程式碼:xieyufei1993/FOTS

    • 主要架構:

      Image Not Showing Possible Reasons
      • The image file may be corrupted
      • The server hosting the image is unavailable
      • The image path is incorrect
      • The image format is not supported
      Learn More →

      Image Not Showing Possible Reasons
      • The image file may be corrupted
      • The server hosting the image is unavailable
      • The image path is incorrect
      • The image format is not supported
      Learn More →

      • ROI Rotate:將 Region of Interest (ROI) 旋轉讓 recognition model 可以訓練
    • 補充

      • RoI Pooling:將 RoI 的框做 Pooling,讓 output 可以餵給 Fully Connected Layer
      • RoI Align:座標進行縮放後可能不為整數,但直接四捨五入可能會讓圖片失真
  • 相關論文2:Character Region Attention For Text Spotting (CRAFTS)

    • character attended feature:recognition model 除了有 word label 外,也有各字元的熱點
      Image Not Showing Possible Reasons
      • The image file may be corrupted
      • The server hosting the image is unavailable
      • The image path is incorrect
      • The image format is not supported
      Learn More →
    • thin-plate spline (TPS):將彎曲的文字校正成矩形
      Image Not Showing Possible Reasons
      • The image file may be corrupted
      • The server hosting the image is unavailable
      • The image path is incorrect
      • The image format is not supported
      Learn More →
    • loss 包含定位和辨識的結果
      Image Not Showing Possible Reasons
      • The image file may be corrupted
      • The server hosting the image is unavailable
      • The image path is incorrect
      • The image format is not supported
      Learn More →

(4) Character-level text detectors

  • 先偵測字元,再將字元組成字詞
  • 相關論文:WordSup: Exploiting Word Annotations for Character based Text Detection
    • 弱監督學習 (weekly-supervised learning)
      • 先用合成圖片訓練一版 character model
      • 取得真實圖片和 word label
      • 用 character model 根據 word label 產出 character mask
    • 缺點
      • 預測出來的字元框是矩形,沒有 perspective transform 成不規則四邊形

二、訓練 (信賢、沛筠)

由於沒有 Character-level 的資料集,因此會使用合成的 Character-level 的資料集。本篇論文是提出由合成的資料集進行訓練,透過預測的方式來顯示真實世界的每個字熱力圖,搭配 weakly-supervised learning 的方式訓練出可以適用於真實世界的字元偵測器。
下圖為 Character-level 的圖片範例

Image Not Showing Possible Reasons
  • The image file may be corrupted
  • The server hosting the image is unavailable
  • The image path is incorrect
  • The image format is not supported
Learn More →

對於文字過度變形,或是文字過大這種的狀況,因為字的框可能並非方正的,需要做變形。使用 Anchor 去實作的 Bounding box 方法還是沒辦法處理得很好這些 case,為了改善這種情況,本篇論文提出 CRAFT 框架 - Character Region Awareness For Text detection,採用了 Character-level awareness。藉由測量每個字元(character),再預測每個字元是否是同一個文字:

  • Region score - 該位置是否是字元的機率
  • Affinity score - 該位置的字元是否需要進行串接成文字的機率

1. Ground Truth Label Generation

不像binary segmentation map,我們對於每個字元畫出 Gaussaian map,這可以很好標註沒有嚴格包圍的邊界區域

Image Not Showing Possible Reasons
  • The image file may be corrupted
  • The server hosting the image is unavailable
  • The image path is incorrect
  • The image format is not supported
Learn More →

(1) 透視轉換(Perspective Transform)

對於 character region score 的標籤生成,由於對box中的每個像素計算高斯分佈值比較耗時,因此本篇論文採用透視轉換,以近似估計的方法來生成標籤,其步驟為:

  • 準備一個二維高斯圖
  • 計算高斯圖區域和每個文字框的透視轉換
  • 將透視轉換後的高斯圖,根據字符框貼到原圖的座標上
    Image Not Showing Possible Reasons
    • The image file may be corrupted
    • The server hosting the image is unavailable
    • The image path is incorrect
    • The image format is not supported
    Learn More →

(2) Region Score Generation

合成資料集有每個字元的bounding-box,我們透過透視轉換,得每個字符的高斯圖。

(3) Affinity Box Generation

判斷字元是不是屬於同一個字,我們就需要知道哪些字和字之間是怎麼關聯的。

    1. 先將 Character Box 點取出
    1. 利用綠點繪製藍色三角形
    1. 在藍色三角形分別取出中心點
    1. 相連四個三角形中心點(藍點),即可將相鄰的字符繪製出 Affinity Box
    1. 同樣的會再將這框線再進行透視轉換,產生 Gaussian 熱力圖。

(4) Summary

2. Weakly-Supervised Learning

(1) 架構

  • region score
    • 目標:找到單一字元所在的位置
    • 意涵:該 pixel 是字元中心點的機率
  • affinity score
    • 目標:將各字元組成一個字詞
    • 意涵:該 pixel 是相鄰字元間空白中心點的機率
  • gaussian heatmap
    • 可以將 region score 和 affinity score 作視覺化的呈現
  • 整體訓練過程如下:

  • Crop 真實圖片而得的文字區域:
    1. 進入模型,得到 region score 的 Gaussian heatmap。
    2. 依 heatmap 結果分割字元,接著計算訓練時用作 learning weight 的
      sconf
  • 真實圖片
  • 合成圖片
  • loss function:
    L=pSc(p)·(Sr(p)Sr(p)22+Sa(p)Sa(p)22)

(2) 圖示說明

  • 字元分割過程可分為:
  1. 裁減含 word-level 標註的圖片,得到 word box。
  2. 放進模型,預測 region score。
  3. 使用 watershed labeling 分割字元 (形狀不規則)。
  4. 找出能包裹各字元的最小 bounding box。
  5. 將 bounding boxes 的座標轉換回原始圖片,即完成字元分割。
  • 如果模型使用不準確的 region scores 訓練,output 將可能在 character regions 模糊。

(3) 算式說明

說明 符號
w
具有 word-level 標註的例子
R(w)
w
的 bounding box region
l(w)
w
的真實文字長度
lc(w)
w
的估計文字長度
sconf
confidence score
Sc
任一圖片的 pixel-wise confidence map
p
R(w)
中的 pixel
Sr
pseudo-ground truth region score
Sa
pseudo-ground truth affinity map
Sr
預測得到的 region score
Sa
預測得到的 affinity score
L
目標函數

sconf(c)=l(w)min(l(w),|l(w)lc(w)|)l(w)
Sc={sconfpR(w),1otherwise,

  • 以 Figure 4 右上方的 COURSE 為例,它的
    l(w)
    = 7,
    lc(w)
    = 5,因此
    sconf(c)
    = (7 - min(7, 2))/7 = 5/7。
  • 當用 synthetic data 訓練時,由於已知實際的 ground truth,所以設
    Sc
    為 1 並假設各字元寬度固定,以
    R(w)/l(w)
    簡單計算 character-level 的預測。
  • 若是
    sconf<0.5
    ,會不利模型訓練;因此,令
    sconf=0.5
    去學習沒看見的部分文字。
    L=pSc(p)·(Sr(p)Sr(p)22+Sa(p)Sa(p)22)

(4) 程式碼

def fake_char_boxes(self, src, word_box, word_length): img, src_points, crop_points = crop_image(src, word_box, dst_height=64.) h, w = img.shape[:2] if min(h, w) == 0: confidence = 0.5 # 依 word_length 切分字元 region_boxes = divide_region(word_box, word_length) # 重新排序四邊形頂點為 (top-left, top-right, bottom right, bottom left) region_boxes = [reorder_points(region_box) for region_box in region_boxes] return region_boxes, confidence # 計算真實圖片各字元的 bounding box 與 confidence img = img_normalize(img) region_score, affinity_score = self.test_net(self.craft, img, self.cuda) heat_map = region_score * 255. heat_map = heat_map.astype(np.uint8) marker_map = watershed(heat_map) # 從不規則的 watershed 結果尋找可包覆該字元的最小 bounding box region_boxes = find_box(marker_map) confidence = cal_confidence(region_boxes, word_length) if confidence <= 0.5: confidence = 0.5 region_boxes = divide_region(word_box, word_length) region_boxes = [reorder_points(region_box) for region_box in region_boxes] else: region_boxes = divide_region(word_box, word_length) region_boxes = [reorder_points(region_box) for region_box in region_boxes] return region_boxes, confidence
  • 其餘部分隨後一併討論

3. Gaussian

    1. gaussian_2d的方法產製了 1000x1000的 gaussian 2D圖
      然後將這個方法丟給 GaussaianGenerator
    1. GaussaianGenerator 這邊實作了一個靜態方法perspective_transform

      python staticmethod 返回函數的靜態方法。
      該方法不強制要求傳遞參數,如下聲明一個靜態方法:
      C類(對象):
      @staticmethod
      def f (arg1 ,arg2 ,):
      以上實例聲明了靜態方法f,從而可以實現實例化使用C().f(),當然也可以不實例化調用該方法Cf()。

    1. perspective_transform主要是將圖片進行透視轉換。
      • 原理:是將圖像投影到一個新平面
      • 程式碼範例:
        src:代表扭轉前的原始座標
        dst:代表想要扭轉的目標座標
        通過下列的透視變換函數,輸入兩個數組,並返回M矩陣-扭轉矩陣
      ​​​​​M = cv2.getPerspectiveTransform(src, dst)
      
      也可以通過反透視變換函數,恢復原來的圖像,只需要對調函數中的數組的順序,返回Minv矩陣
      ​​​​​Minv = cv2.getPerspectiveTransform(dst, src)
      
      將扭轉矩陣M輸入,進行的原始圖片的變換,其中img代表的是原圖像,M代表的是扭轉矩陣,img_size代表的是轉換之後的尺寸(可以設置為相同尺寸)。
      ​​​​​warped = cv2.warpPerspective(img, M, img_size, borderMode=cv2.BORDER_CONSTANT, borderValue=0)
      
      當使用warpPerspective將圖像縮小到較小時,周圍會出現黑色區域。可以調整 borderValue
    1. 最後由 gen 回傳 score map
  • 完整程式碼參考

#產生一個高斯熱力圖
def gaussian_2d():
    """
    Create a 2-dimensional isotropic Gaussian map.
    :return: a 2D Gaussian map. 1000x1000.
    """
    mean = 0
    radius = 2.5
    # a = 1 / (2 * np.pi * (radius ** 2))
    a = 1.
    x0, x1 = np.meshgrid(np.arange(-5, 5, 0.01), np.arange(-5, 5, 0.01))
    x = np.append([x0.reshape(-1)], [x1.reshape(-1)], axis=0).T

    m0 = (x[:, 0] - mean) ** 2
    m1 = (x[:, 1] - mean) ** 2
    gaussian_map = a * np.exp(-0.5 * (m0 + m1) / (radius ** 2))
    gaussian_map = gaussian_map.reshape(len(x0), len(x1))

    max_prob = np.max(gaussian_map)
    min_prob = np.min(gaussian_map)
    gaussian_map = (gaussian_map - min_prob) / (max_prob - min_prob)
    gaussian_map = np.clip(gaussian_map, 0., 1.)
    return gaussian_map

class GaussianGenerator:
    def __init__(self):
        self.gaussian_img = gaussian_2d()

    @staticmethod
    def perspective_transform(src, dst_shape, dst_points):
        """
        Perspective Transform
        :param src: Image to transform.
        :param dst_shape: Tuple of 2 intergers(rows and columns).
        :param dst_points: [[x1, y1], [x2, y2], [x3, y3], [x4, y4]].
        :return: Image after perspective transform.
        """
        img = src.copy()
        h, w = img.shape[:2]

        src_points = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
        dst_points = np.float32(dst_points)
        perspective_mat = cv2.getPerspectiveTransform(src=src_points, dst=dst_points)
        dst = cv2.warpPerspective(img, perspective_mat, (dst_shape[1], dst_shape[0]),
                                  borderValue=0, borderMode=cv2.BORDER_CONSTANT)
        return dst

    def gen(self, score_shape, points_list):
        score_map = np.zeros(score_shape, dtype=np.float32)
        for points in points_list:
            tmp_score_map = self.perspective_transform(self.gaussian_img, score_shape, points)
            score_map = np.where(tmp_score_map > score_map, tmp_score_map, score_map)
        score_map = np.clip(score_map, 0, 1.)
        return score_map

三、預測

  • 在 ICDAR 的資料集,是用 word-level bounding box 的 iou 來衡量成效

  • 產出 bounding boxes 的步驟

    • 步驟 1:產出 binray map (

      M(p)):由 0 和 1 所組成,1 代表附近可能有字詞

      • p 是 pixel 的數字
      • τa
        τr
        分別是 affinity score 與 region score 的 threshold
        M(p)={1   ,if Sr(p)>τrorSa(p)>τa0   ,otherwise
    • 步驟 2:對

      M 執行 Connected Component Labeling (CCL)

      • Connected Component Labeling 就是一種劃分聯通區域的演算法
      • CRAFT-pytorch 是用 cv2.connectedComponentsWithStats (以下是小範例)
    • 步驟 3:產出 QuadBox (word-level bounding boxes)

      • 從 CCL 中的每個聯通區域找出涵蓋區域的最小旋轉矩形 (cv2.minAreaRect)
  • 產出 polygon 的步驟

    • 藍線:通過各字元 local maximum 所產生的垂直線
    • 黃線:將各字元 local maximum 的點相連所產生的線
    • 紅線:通過各字元 local maximum 且與黃線互為垂直的線 (與藍線的角度為
      θ
      )
    • 控制點:
      • 兩邊字元:紅線平行向外延伸至 character region 的最外圍,並取頂點
      • 中間字元:取紅線的頂點
    • polygon:將各字元的控制點相連

四、實驗 (昊中)

資料集

  • Quadrilateral-type datasets (四邊形)
    • ICDAR2013(IC13)、ICDAR2015(IC15)、ICDAR2017(IC17)、MSRA-TD500
#left, top, right, bottom, 'transcription' [1, 3, 200, 88, 'Photo'], [22, 249, 113, 286, 'The']
  • Polygon-type datasets (多邊形)
    • TotalText、CTW-500
    ​​​​'california' ​​​​Polygon-Shpaed Ground Truth: ​​​​ x:[306, 335, 379, 424, 463, 481, 460] ​​​​ y:[26, 28, 52, 85, 104, 76, 50] ​​​​Orientation:Curved

訓練策略

  • 訓練過程包含兩個步驟
    • 使用SynthText dataset訓練模型(迭代50次),再將訓練完成的模型進行fine-tune。
    • Fine-tune過程中會使用1:5的SynthText dataset資料確保字元區域是被分開的,而On-line Hard Example Mining採用1:3的資料來確保在自然場景中過濾出類似紋理的文字。
      • 找到 hard sample 做加強訓練
  • Augmentation如裁切、旋轉、顏色變化等會進行採用。

Weakly-supervised training

  • Weakly-supervised training會需要兩種類型的資料:
    • 用於裁切單詞圖像的四邊形註釋
    • 用於計算單詞長度的轉錄
    • 只有IC13、IC15、IC17 dataset符合
  • CRAFT只有使用ICDAR datasets做訓練,其他資料集則用來做測試,實驗共訓練兩個模型:
    • Train on IC15, evaluate IC15 only.
    • Train on IC13 and IC17, evaluate the other five datasets.

五、補充 (LILI、昱睿)

以下結果因為一些意外無法得到結果,請大家自行測試 (future work)

實際運用的時候

  • 使用 50 張身分證影像檔測試
狀態\模型 CRAFT YOLOv4
讀取 model weight (GPU memory) 1337 MiB 1483 MiB
inference 階段 (GPU memory) 2935 MiB 3581 MiB
運算時間 (平均) 0.2106 (秒) 0.1144 (秒)
運算時間 (標準差) 0.0207 (秒) 0.02 (秒)

CRAFT 的網路架構與 YOLO 有什麼差異,造成運算資源比較大

其他可以補足 CRAFT 功能的 code ?

  • 本週作業: JN

六、程式碼

六 (1) - inference 的程式碼說明

  • code: clovaai/CRAFT-pytorch
    • test.py
      • 參數用白話文講給你聽

        官方 code 參數 官方 paper 參數 參數說明
        text_threshold 決定各聯通區域是否被認定為字詞
        low_text region threshold 決定各 pixel 是否含有字元
        link_threshold affinity threshold 決定各 pixel 是否為字元間的相鄰區域
      • 平行化處理 (參考連結)

        • 模型平行化:無
        • 資料平行化:torch.nn.DataParallel()
          • 用單一程序多線程的方式將資料分割成子集,並分配到運算單元中執行
    • craft.py
    • craft_utils.py
    • imgproc.py
      • 如果要測記憶體用量的最大值,可以微調 resize_aspect_ratio 函式第 59 行
        • 微調前
        ​​​​​​​​​​​​resized = np.zeros((target_h32, target_w32, channel), dtype=np.float32)
        
        • 微調後 (讓影像長寬固定)
        ​​​​​​​​​​​​resized = np.zeros((square_size, square_size, channel), dtype=np.float32)
        

六 (2) - 訓練階段的程式碼說明

六 (2) - 1: brooklyn1900/CRAFT_pytorch

  1. Train pre-trained model
    1-1. main training function
    ​​​​def train(net, epochs, batch_size, test_batch_size, lr, ​​​​ test_interval, max_iter, model_save_path, save_weight=True): ​​​​ train_data = SynthDataset(image_transform=image_transform, ​​​​ label_transform=label_transform, ​​​​ file_path=args.gt_path, ​​​​ image_dir=args.synth_dir) ​​​​ train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True) ​​​​ criterion = nn.MSELoss(reduction='none') ​​​​ optimizer = optim.Adam(net.parameters(), lr=lr) ​​​​ for epoch in range(epochs): ​​​​ #前向传播 ​​​​ y, _ = net(images) ​​​​ score_text = y[:, :, :, 0] ​​​​ score_link = y[:, :, :, 1] ​​​​ #联合损失 ohem loss ​​​​ loss = cal_synthText_loss(criterion, score_text, score_link, labels_region, labels_affinity, device) ​​​​ #反向传播 ​​​​ optimizer.zero_grad() #梯度清零 ​​​​ loss.backward() #计算梯度 ​​​​ optimizer.step() #更新权重
    1-2. synthDataset
    ​​​​class SynthDataset(torch.utils.data.Dataset): ​​​​ #获取图片的高斯热力图 ​​​​ def get_region_scores(self, heat_map_size, char_boxes_list): ​​​​ # 高斯热力图 ​​​​ gaussian_generator = GaussianGenerator() ​​​​ region_scores = gaussian_generator.gen(heat_map_size, char_boxes_list) ​​​​ return region_scores
    1-3. loss function
    ​​​​def get_ohem_num(labels_region, labels_affinity, device): ​​​​ """ ​​​​ labels_region: 训练标签region score ​​​​ labels_affinity: 训练标签affinity score ​​​​ return: 各像素标签的数量 ​​​​ """ ​​​​ numPos_region = torch.sum(torch.gt(labels_region, 0.1)).to(device) ​​​​ numNeg_region = torch.sum(torch.le(labels_region, 0.1)).to(device) ​​​​ numPos_affinity = torch.sum(torch.gt(labels_affinity, 0.1)).to(device) ​​​​ numNeg_affinity = torch.sum(torch.le(labels_affinity, 0.1)).to(device) ​​​​ #pos-neg ratio is 1:3 ​​​​ if numPos_region * 3 < numNeg_region: ​​​​ numNeg_region = numPos_region * 3 ​​​​ if numPos_affinity * 3 < numNeg_affinity: ​​​​ numNeg_affinity = numPos_affinity * 3 ​​​​return numPos_region, numNeg_region, numPos_affinity, numNeg_affinity ​​​​def cal_synthText_loss(criterion, score_text, ​​​​ score_link, labels_region, labels_affinity, device): ​​​​ """ ​​​​ 计算synthText强数据集的loss ​​​​ criterion: 损失函数 ​​​​ score_text: 网络输出的region score ​​​​ score_link: 网络输出的affinity score ​​​​ labels_region: 训练标签region score ​​​​ labels_affinity: 训练标签affinity score ​​​​ return: loss ​​​​ """ ​​​​ numPos_region, numNeg_region, numPos_affinity, numNeg_affinity = get_ohem_num(labels_region, labels_affinity, device) ​​​​ #联合损失 ohem loss ​​​​ #取全部的postive pixels的loss ​​​​ loss1_fg = criterion(score_text[np.where(labels_region > 0.1)], labels_region[np.where(labels_region > 0.1)]) ​​​​ loss1_fg = torch.sum(loss1_fg) / numPos_region.to(torch.float32) ​​​​ loss1_bg = criterion(score_text[np.where(labels_region <= 0.1)], labels_region[np.where(labels_region <= 0.1)]) ​​​​ #selects the pixel with high loss in the negative pixels ​​​​ loss1_bg, _ = loss1_bg.sort(descending=True) ​​​​ loss1_bg = torch.sum(loss1_bg[:numNeg_region]) / numNeg_region.to(torch.float32) ​​​​ #loss2 = labels_affinity... ​​​​ #loss2與loss1一樣,省略 ​​​​ #联合loss ​​​​ loss = loss1_fg + loss1_bg + loss2_fg + loss2_bg ​​​​return loss
  2. Fine-tune pre-trained model (weakly-supervised training)
    • train_finetune.py (line 64~69)
    ​​​​# ic13_length:synth_data = 1:5 ​​​​synth_data = torch.utils.data.Subset(synth_data, range(5*ic13_length)) ​​​​# 使用合併後的 data 進行 fine-tune ​​​​fine_tune_data = torch.utils.data.ConcatDataset([synth_data, ic13_data]) ​​​​train_data, val_data = torch.utils.data.random_split(fine_tune_data, [5*ic13_length, ic13_length])

六 (2) - 2: backtime92/CRAFT-Reimplementation

六 (2) - 2.1: 關於 ground truth

  • 下載 Sythdata:在這個連結可以下載,不過檔案有 40G 這麼大喔!
  • ground truth 是用 .mat 檔儲存,是 matlab 的檔案格式
  • 要用 scipy 的語法讀取
import scipy.io as scio synthtext_folder = os.path.join(current_dir, 'SynthText', 'SynthText') gt = scio.loadmat(os.path.join(synthtext_folder, 'gt.mat'))
  • 讀取出來 gt 這個變數是 dictionary,共有下面幾個 keys
    • __header__
    • __version__
    • __globals__
    • charBB: 字元的 bounding boxes
      • shpae: ((x, y), number of points, number of characters)
      • shape: (2, 4, 54) -> 意思是 54 個字元,每個字元 4 個座標,每個座標用 (x, y) 來描述
    • wordBB: 字詞的 bounding boxes
    • shape: (2, 4, 15) -> 意思是 15 個字詞,每個字元 4 個座標,每個座標用 (x, y) 來描述
    • imnames: 圖片的路徑檔名
    • txt: 是 word level 的字詞,不過需要用 '\n' 分開才可以對照

六 (2) - 2.2: 關於 Synth80k

  • 繼承 craft_base_dataset

    • craft_base_dataset 必須繼承 torch.utils.data.Dataset 這個 class,這樣之後才可以彙整到 torch.utils.data.DataLoader 變成一個新的物件
  • 主要會使用 __getitem__ 這個函數

    • 他會 return self.pull_item 的函數值
  • pull_item

def pull_item(self, index: int): """ 根據序號取得訓練時的答案,主要就是 region score, affinity score Args: index: 圖檔 list 的序號 Returns: image: 圖檔 (torch.Tensor) shape: (batch_size, n_channels, h, w) region_scores_torch 這個圖檔的 region score (torch.Tensor) shape: (batch_size, h/2, w/2) affinity_scores_torch 這個圖檔的 affinity score (torch.Tensor) shape: (batch_size, h/2, w/2) confidence_mask_torch 這個圖檔的 confidence (torch.Tensor) shape: (batch_size, h/2, w/2) confidences: 其實都是 1 (torch.Tensor) shape: (batch_size, ) """ return image, region_scores_torch, affinity_scores_torch, confidence_mask_torch, confidences

六 (2) - 2.3: 關於 train_loader

  • 是 torch.utils.data.DataLoader
    • 用 enumerate,他就會回圈式地 return self.__getitem__ 的值

    • gaussian.py: 主要是看 GaussianTransformer 這個 class

      1. self._gen_gaussian_heatmap 產生出 2 維度的單位常態分佈
      2. self.standardGaussianHeat 會是根據 image 圖片的 heatmap
      3. self.standardGaussianHeat 可以產生 affnity score 以及 region score
      4. 接著上面的這些 attribute 就可以根據 bounding boxes 被 apply 到每個 character 裡面 (在 self.generate_region)
      5. 在 data_loader 就會引用 GaussianTransformer

七、參考資料

觀念

https://blog.csdn.net/xz1308579340/article/details/106432966
https://blog.csdn.net/qq_18315295/article/details/104392379
https://xiaosean.github.io/deep learning/computer vision/2019-07-21-Text-Detection-CRAFT/
https://zhuanlan.zhihu.com/p/68855938

程式碼

問題區

  • 已關閉 (沒有人)
姓名 問題 解答