Categories
Computer Data

วิธีเทรน Segmentation โดย mmSegmentation

ในบทความนี้จะแนะนำเครื่องมือที่มีชื่อว่า mmSegmentation ที่มีหน้าที่ทำ Semantic Segmentation ซึ่งเป็นหัวข้อหนึ่งใน Computer Vision ที่แยกส่วนภาพวัตถุออกจากัน

ปกติเวลาที่เราต้องการแยกวัตถุออกจากวัตถุหนึ่งโดยตาของคน อันนี้ทำได้ไม่ยาก เพราะเราแยกออกจากกันได้ง่ายอยู่แล้ว แต่จะให้คอมพิวเตอร์แยกวัตถุแต่ละอย่างออกจากภาพได้ อันนี้เราจำเป็นต้องให้คอมพิวเตอร์เรียนรู้เสียก่อน

โดยเรื่องนี้อยู่ในหัวข้อ AI (Artificial Intelligence) อย่าง Computer Vision ที่อยู่ในเรื่องของ Semantic Segmentation

Semantic Segmentation

Semantic Segmentation คือเทคนิคหนึ่งของ Computer Vision ที่มีหน้าที่ในการแยกส่วนและจำแนกวัตถุแต่ละวัตถุในภาพแยกจากกันโดยกำหนดคลาสในแต่ละจุดพิคเซลว่าเป็นคลาสอะไร ซึ่งจุดนี้จะแตกต่างกับ

  • Object Detection ตรงที่ Object Detection จะจับวัตถุออกมาเป็น Bounding Box
  • Instance Segmentation ที่แยกส่วนร่วมกับจำแนกวัตถุออกจากกันโดยจำเพาะว่าวัตถุนี้เป็นวัตถุที่ 1,2,3

ถ้ายังนึกภาพไม่ออก ก็ดูตามด้านล่างนี้ได้ครับ

ภาพแสดงความแตกต่างระหว่าง Semantic Segmentation, Object Detection และ Instance Segmentation (จากงานวิจัยของ Li, Johnson and Yeung, 2017 เรื่อง Robot-Human-Learning for Robotic Picking Processes)

mmSegmentation

mmSegmentation

เครื่องมือ mmSegmentation เป็นไลบรารีหนึ่งที่ได้รับการสร้างขึ้นมาโดยทางทีมงาน OpenMMLab ที่เป็นเครื่องมือในการสร้าง การเทรน และการทดสอบโมเดลสำหรับการใช้งานทาง Computer Vision ในด้าน Semantic Segmentation ที่อยู่บนพื้นฐานของ PyTorch ที่เป็นไลบรารีสำหรับการใช้งานทางด้าน AI ในด้าน Deep Learning ครับ

เครื่องมือ mmSegmentation จะมีโมเดลที่มีอยู่แล้วในระบบ ผู้ใช้สามารถนำโมเดลเหล่านี้ไปหยิบใช้งานได้เลยอย่างง่ายดายเพียงพิมพ์โค้ดเพิ่มเติมลงไปก็สามารถเทรน และทดสอบโมเดลได้แล้ว โดยโมเดลที่มีมาให้อยู่แล้วได้แก่

  • FCN (CVPR’2015/TPAMI’2017)
  • ERFNet (T-ITS’2017)
  • UNet (MICCAI’2016/Nat. Methods’2019)
  • PSPNet (CVPR’2017)
  • DeepLabV3 (ArXiv’2017)
  • BiSeNetV1 (ECCV’2018)
  • PSANet (ECCV’2018)
  • DeepLabV3+ (CVPR’2018)
  • UPerNet (ECCV’2018)
  • ICNet (ECCV’2018)
  • NonLocal Net (CVPR’2018)
  • EncNet (CVPR’2018)
  • Semantic FPN (CVPR’2019)
  • DANet (CVPR’2019)
  • APCNet (CVPR’2019)
  • EMANet (ICCV’2019)
  • CCNet (ICCV’2019)
  • DMNet (ICCV’2019)
  • ANN (ICCV’2019)
  • GCNet (ICCVW’2019/TPAMI’2020)
  • FastFCN (ArXiv’2019)
  • Fast-SCNN (ArXiv’2019)
  • ISANet (ArXiv’2019/IJCV’2021)
  • OCRNet (ECCV’2020)
  • DNLNet (ECCV’2020)
  • PointRend (CVPR’2020)
  • CGNet (TIP’2020)
  • BiSeNetV2 (IJCV’2021)
  • STDC (CVPR’2021)
  • SETR (CVPR’2021)
  • DPT (ArXiv’2021)
  • Segmenter (ICCV’2021)
  • SegFormer (NeurIPS’2021)
  • และ K-Net (NeurIPS’2021)

เมื่อทราบรายการโมเดลที่รองรับแล้ว เรามาเริ่มใช้งานกันเลย อย่างแรกก็ติดตั้งไลบรารี

สำหรับผู้อ่านที่ไม่ต้องการติดตั้ง และเขียนโค้ดตามด้านล่างนี้ อีกวิธีหนึ่งที่ศึกษาไลบรารีนี้ได้คือเข้าไปทดลองใช้งาน Google Colab ของ mmSegmentation เองครับ

การติดตั้งไลบรารี

ก่อนติดตั้งไลบรารี เราก็ต้องรู้ความต้องการระบบขั้นต่ำมีรายละเอียดตามด้านล่างนี้ครับ

  • Linux/Mac
  • Python 3.6+
  • PyTorch 1.3+
  • CUDA Toolkit 9.2+
  • GCC 5+
  • MMCV โดยติดตั้งให้ตรงกับรุ่นของ CUDA และ PyTorch

เมื่อรู้ความต้องการแล้ว เราก็มาติดตั้งไลบรารีกัน โดยเราจะใช้งานบน Google Colab ครับ เนื่องมาจากเหมาะกับมือใหม่ และสามารถใช้งานได้ง่าย ไม่ยุ่งยาก แถมไม่ต้องมาติดตั้ง CUDA Toolkit เองอีก

การติดตั้งไลบรารีลงบน Google Colab ทำได้โดย

  • ติดตั้งไลบรารีของไพทอนเองได้แก่ matplotlib numpy scikit-learn scipy และ pillow
!pip3 install matplotlib numpy scikit-learn scipy pillow
  • ติดตั้ง pyTorch รุ่น 1.11.0 ติดตั้งได้โดยการพิมพ์คำสั่งตามด้านล่างนี้
!pip3 install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
  • ติดตั้ง MMCV โดยให้เลือกไลบรารี mmcv-full รุ่นล่าสุด ที่สามารถติดตั้งได้โดย
!pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.10.0/index.html

เมื่อติดตั้ง MMCV เสร็จแล้ว เราติดตั้ง mmSegmentation ได้สองวิธี วิธีแรกติดตั้งผ่านการใช้ pip โดยการพิมพ์ตามด้านล่างนี้ครับ

pip install mmsegmentation

ส่วนอีกวิธี เราสามารถคอมไพล์ตัวโค้ดได้โดยการพิมพ์ตามด้านล่างนี้ครับ ซึ่งเป็นวิธีที่แนะนำ เพราะนอกจากจะได้คอมไพล์ตัวโค้ดแล้ว ยังได้ไฟล์การตั้งค่าที่เกี่ยวข้องกับโมเดลและชุดข้อมูลที่ต้องการครับ

ในตัวอย่างนี้ เราจะใช้วิธีนี้ครับ ส่วนใครที่ใช้วิธีบน เราสามารถดาวน์โหลดไฟล์การตั้งค่าจากเว็บได้จากเว็บของ mmSegmentation เองครับ

git clone https://github.com/open-mmlab/mmsegmentation.git
cd mmsegmentation
pip install -e .  # หรือ "python setup.py develop"

เมื่อติดตั้งไลบรารีที่จำเป็นครบทุกอย่างแล้ว เราทดสอบว่าติดตั้งไลบรารีสมบูรณ์หรือไม่ได้โดยการพิมพ์โค้ดตามด้านล่างนี้

# Check Pytorch installation
import torch, torchvision
print(torch.__version__, torch.cuda.is_available())

# Check MMSegmentation installation
import mmseg
print(mmseg.__version__)

เริ่มต้นการเรียกใช้โมเดลกันเถอะ

เราสามารถเรียกใช้งานโมเดลสำหรับการเทรนเพื่อใช้ทำ Semantic Segmentation ได้โดย

  1. เขียนไฟล์ Python การตั้งค่าโมเดล
  2. เตรียมชุดข้อมูล (Dataset)
  3. ตั้งค่าที่เกี่ยวข้องกับการเทรน
  4. เริ่มต้นการเทรนข้อมูล
  5. การทดสอบโมเดล

เรามากล่าวถึงขั้นตอนแรก

1.) เขียนไฟล์ Python การตั้งค่าโมเดล

ขั้นตอนนี้เป็นการตั้งค่าให้เรียบร้อยก่อนการเทรน โดยปกติเวลาเทรนโมเดลในแต่ละครั้งจะประกอบได้เป็นการตั้งค่าโมเดล ชุดข้อมูล และตั้งการพารามิเตอร์การเทรนที่เกี่ยวข้อง

แต่ก่อนจะเริ่มการตั้งค่าโมเดล เราจำเป็นต้องโหลด Pretrained โมเดลสำหรับการทำ Transfer Learning เสียก่อน โดยดาวน์โหลดตามด้านล่างนี้ที่เป็นโมเดล PSPNet ครับ

!mkdir checkpoints
!wget https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth -P checkpoints

เมื่อดาวน์โหลดเสร็จแล้ว เราสามารถทดลองใช้งานโมเดลได้โดยการพิมพ์โค้ดตามด้านล่างนี้ครับ

from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot
from mmseg.core.evaluation import get_palette

config_file = './mmsegmentation/configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py'
checkpoint_file = './checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'

model = init_segmentor(config_file, checkpoint_file, device='cuda:0')

# test a single image
img = './mmsegmentation/demo/demo.png'
result = inference_segmentor(model, img)

# show the results
show_result_pyplot(model, img, result, get_palette('cityscapes'))

เมื่อพิมพ์โค้ดตามด้านบนนี้แล้ว ระบบจะแสดงผลลัพธ์ออกมาเป็นภาพตามด้านล่างนี้ครับ

ภาพผลลัพธ์ที่ได้จากการใช้งานโมเดลเพื่อทำ Semantic Segmentation

2.) เตรียมชุดข้อมูล (Dataset)

ขั้นตอนต่อไปหลังจากการตั้งค่าโมเดล คือตั้งค่าชุดข้อมูล (Dataset)

เราสามารถเลือก Dataset ที่มีอยู่แล้วได้แก่ Cityscapes, DRIVE, COCO-Stuff 10k หรือ 164k และอื่น ๆ รวมถึงตั้งค่า Data Pipeline กับ Data Augmentation สำหรับการเทรน และทดสอบโมเดลที่เราเลือกใช้

ในที่นี้เราจะใช้งาน Dataset อย่าง Stanford Background Dataset ที่เป็น Dataset แบบที่เปิดให้คนใช้งานได้ทั่วไปที่เป็นภาพเกี่ยวกับฉาก Outdoor ทั้งหมด 715 ภาพ โดยดึงมาจาก Dataset อย่าง LabelMeMSRCPASCAL VOC and Geometric Context ที่มีความละเอียดของภาพอยู่ที่ 320×240 pixel

การทำ Annotation ของ Dataset นี้จะกำหนดคลาสในแต่ละจุด Pixel ของภาพ โดยมีทั้งหมด 8 คลาสครับ ได้แก่ sky, tree, road, grass, water, building, mountain และ foreground object.

การเริ่มต้นใช้งาน Dataset ทำได้โดย การดาวน์โหลด

# download and unzip
!wget http://dags.stanford.edu/data/iccv09Data.tar.gz -O stanford_background.tar.gz
!tar xf stanford_background.tar.gz

เมื่อดาวน์โหลดและแตกไฟล์เสร็จแล้ว เราสามารถเรียกดูรูปภาพได้โดยการพิมพ์โค้ดตามด้านล่างนี้

# Let's take a look at the dataset
import mmcv
import matplotlib.pyplot as plt

img = mmcv.imread('./iccv09Data/images/6000124.jpg')
plt.figure(figsize=(8, 6))
plt.imshow(mmcv.bgr2rgb(img))
plt.show()

กดรันตัวโค้ดแล้วจะได้ภาพตามด้านล่างนี้

ภาพจาก Stanford Background Dataset

เมื่อดูรูปภาพเรียบร้อย เราจำเป็นต้องมาปรับแต่ง Dataset ให้เหมาะสมต่อการเทรน ทำได้โดยการพิมพ์โค้ดตามด้านล่างนี้ครับ

import os.path as osp
import numpy as np
from PIL import Image

# convert dataset annotation to semantic segmentation map
data_root = 'iccv09Data'
img_dir = 'images'
ann_dir = 'labels'

# define class and plaette for better visualization
classes = ('sky', 'tree', 'road', 'grass', 'water', 'bldg', 'mntn', 'fg obj')
palette = [[128, 128, 128], [129, 127, 38], [120, 69, 125], [53, 125, 34], 
           [0, 11, 123], [118, 20, 12], [122, 81, 25], [241, 134, 51]]

for file in mmcv.scandir(osp.join(data_root, ann_dir), suffix='.regions.txt'):
  seg_map = np.loadtxt(osp.join(data_root, ann_dir, file)).astype(np.uint8)
  seg_img = Image.fromarray(seg_map).convert('P')
  seg_img.putpalette(np.array(palette, dtype=np.uint8))
  seg_img.save(osp.join(data_root, ann_dir, file.replace('.regions.txt', 
                                                         '.png')))

จากลักษณะโค้ดตามข้างบนนี้จะเป็นการกำหนดคลาสและสีในแต่ละคลาส รวมถึงสแกนไฟล์ .regions.txt เพื่อดึงข้อมูลการกำหนดสีในแต่คลาสที่กำหนดในแต่ละพิคเซลครับ

เมื่อกำหนดสีเสร็จแล้ว เราสามารถเรียกดูผลลัพธ์ได้โดยการพิมพ์โค้ดตามด้านล่างนี้ครับ

# Let's take a look at the segmentation map we got
import matplotlib.patches as mpatches

img = Image.open('./iccv09Data/labels/6000124.png')
plt.figure(figsize=(8, 6))
im = plt.imshow(np.array(img.convert('RGB')))

# create a patch (proxy artist) for every color 
patches = [mpatches.Patch(color=np.array(palette[i])/255., 
                          label=classes[i]) for i in range(8)]

# put those patched as legend-handles into the legend
plt.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0., 
           fontsize='large')

plt.show()

ผลลัพธ์ที่ได้จะเป็นไปตามด้านล่างนี้

ผลลัพธ์จากการใช้คำสั่ง putpalette ของ PIL

ต่อจากการเตรียม Dataset แล้ว เราจำเป็นต้องเตรียมไฟล์แสดงรายการไฟล์ที่แบ่ง Dataset ออกเป็น

  • Dataset สำหรับการ Training
  • Dataset สำหรับการทำ Validation
  • Dataset สำหรับการทำ Testing

โดยจะมีสัดส่วนที่แตกต่างกันออกไป แต่ในตัวอย่างนี้เราจะเป็นออกเป็น Dataset สำหรับการ Training และ Validation โดยแบ่งเป็น 80:20 ครับ เราเขียนโค้ดได้ตามด้านล่างนี้

# split train/val set randomly
split_dir = 'splits'
mmcv.mkdir_or_exist(osp.join(data_root, split_dir))
filename_list = [osp.splitext(filename)[0] for filename in mmcv.scandir(
    osp.join(data_root, ann_dir), suffix='.png')]

with open(osp.join(data_root, split_dir, 'train.txt'), 'w') as f:
  # select first 4/5 as train set
  train_length = int(len(filename_list)*4/5)
  f.writelines(line + '\n' for line in filename_list[:train_length])

with open(osp.join(data_root, split_dir, 'val.txt'), 'w') as f:
  # select last 1/5 as train set
  f.writelines(line + '\n' for line in filename_list[train_length:])

เมื่อเตรียม Dataset รวมถึงเตรียมไฟล์แสดงรายการไฟล์ที่ถูกแบ่งใน Dataset เพื่อการทำ Training และ Validation แล้ว เราจำเป็นต้องสร้าง Class ในไพทอนใหม่เพื่อกำหนด Dataset ที่สร้างขึ้น เนื่องมาจาก Dataset นี้ทาง mmSegmentation ไม่ได้เตรียมไว้ เราเขียนโค้ดได้ตามด้านล่างนี้

from mmseg.datasets.builder import DATASETS
from mmseg.datasets.custom import CustomDataset

@DATASETS.register_module()
class StanfordBackgroundDataset(CustomDataset):
  CLASSES = classes
  PALETTE = palette
  def __init__(self, split, **kwargs):
    super().__init__(img_suffix='.jpg', seg_map_suffix='.png', 
                     split=split, **kwargs)
    assert osp.exists(self.img_dir) and self.split is not None

ตัวโค้ดตามด้านบนนี้จะเป็นการสร้างคลาส รวมถึงเป็นการกำหนดไฟล์ภาพที่ลงท้ายด้วย .jpg รวมถึงไฟล์ Annotation ที่ลงท้ายด้วย .png ครับ

3.) ตั้งค่าที่เกี่ยวข้องกับการเทรน

เราสามารถตั้งค่าที่เกี่ยวข้องกับการเทรน โดยการตั้งค่านี้เราสามารถตั้งค่าเกี่ยวกับโมเดล, Dataset, Learning Rate และอื่น ๆ ได้ครับ อย่างไรก็ดี เราไม่จำเป็นต้องเขียนขึ้นมาเองหมด เรานำเข้าไฟล์การตั้งค่าที่มีอยู่แล้ว แล้วนำมาปรับแต่งเพิ่มได้

ขั้นตอนแรก นำเข้าไฟล์การตั้งค่า

from mmcv import Config
cfg = Config.fromfile('./mmsegmentation/configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py')

ตั้งค่าโมเดล

ต่อมาเป็นการตั้งค่าตัวโมเดล โดยกำหนดให้ใช้ Batch Normalization แทนที่ SyncBN เนื่องมาจากมีการ์ดจอเพียงใบเดียว

from mmseg.apis import set_random_seed

# Since we use only one GPU, BN is used instead of SyncBN
cfg.norm_cfg = dict(type='BN', requires_grad=True)
cfg.model.backbone.norm_cfg = cfg.norm_cfg
cfg.model.decode_head.norm_cfg = cfg.norm_cfg
cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg

ต่อมาเป็นการตั้งค่าจำนวนคลาส ใน Dataset นี้มี 8 คลาส เราเขียนได้ตามด้านล่างนี้

# modify num classes of the model in decode/auxiliary head
cfg.model.decode_head.num_classes = 8
cfg.model.auxiliary_head.num_classes = 8

ตั้งค่า Dataset

ขั้นตอนถัดไปจากโมเดล เป็นการตั้งค่า Dataset ที่สามารถชนิด Dataset, กำหนดจำนวน Batch size, กำหนด Data Pipeline ที่เกี่ยวข้องกับการโหลดไฟล์รวมถึงการทำ Data Augmentation และตั้งค่าที่อยู่ของไฟล์ Dataset

อย่างแรก เป็นการกำหนดชนิด Dataset โดยกำหนดให้ชื่อตรงกับคลาสที่เราสร้างขึ้น รวมถึงกำหนดโฟลเดอร์หลักของ Dataset ได้ตามด้านล่างนี้

# Modify dataset type and path
cfg.dataset_type = 'StanfordBackgroundDataset'
cfg.data_root = data_root

อย่างที่สอง เป็นการกำหนดจำนวน Batch Size ในที่นี้กำหนดให้มี Batch size เท่ากับ 8

cfg.data.samples_per_gpu = 8
cfg.data.workers_per_gpu = 8

อย่างที่สาม กำหนด Data Pipeline เรากำหนดตั้งค่าโหลดรูปภาพ ปรับขนาดรูป

cfg.img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
cfg.crop_size = (256, 256)
cfg.train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(320, 240), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=cfg.crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **cfg.img_norm_cfg),
    dict(type='Pad', size=cfg.crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]

ตามด้านบนนี้เราโหลดไฟล์ภาพ รวมถึงโหลดข้อมูล Annotation (ที่ได้ Label ไว้) แล้วนำมาทำ Data Augmentation ได้แก่

  1. ปรับขนาดรูป
  2. Crop รูป กับ Flip รูปแบบสุ่ม
  3. ทำ Photometric Distortion ที่ตามเอกสารของตัวไลบรารีได้กล่าวไว้ว่า

Apply photometric distortion to image sequentially, every transformation is applied with a probability of 0.5. The position of random contrast is in second or second to last.

1. random brightness

2. random contrast (mode 0)

3. convert color from BGR to HSV

4. random saturation

5. random hue

6. convert color from HSV to BGR

7. random contrast (mode 1)

8. randomly swap channels

นั่นก็คือเป็นการกำหนดการบิดเบือนทางแสงแบบสุ่ม โดย

  • สุ่มค่า Brightness กับ Contrast
  • ปลี่ยนค่าสีจาก BGR เป็น HSV และเปลี่ยนกลับจาก HSV gป็น BGR
  • ปรับ Saturation และ Hue แบบสุ่ม
  • สลับ Channel ของสีในภาพ

แล้วนำมาแปลงให้อยู่ในรูปแบบที่เหมาะสำหรับการเทรน

ต่อมาเป็นการตั้งค่า Data Pipeline สำหรับการทดสอบ เราเขียนโค้ดออกมา

ส่วนกรณีที่ทดสอบข้อมูล เราจะเริ่มด้วยการโหลดไฟล์ แล้วนำมาปรับขนาดให้อยู่ในขนาดที่ต้องการ จากนั้นทำ Normalize ข้อมูล แล้วแปลงให้อยู่ในรูปที่เหมาะต่อการทดสอบ นั่นคือรูป Tensor

ด้านล่างนี้จะเป็นการตั้งค่า Data Pipeline สำหรับการทดสอบชุดข้อมูล

cfg.test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(320, 240),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **cfg.img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]

ต่อจาก Data Pipeline แล้ว เราจำเป็นต้องกำหนดที่อยู่ไฟล์ Dataset นี้ เราตั้งค่าได้ตามด้านล่างนี้ครับ

cfg.data.train.type = cfg.dataset_type
cfg.data.train.data_root = cfg.data_root
cfg.data.train.img_dir = img_dir
cfg.data.train.ann_dir = ann_dir
cfg.data.train.pipeline = cfg.train_pipeline
cfg.data.train.split = 'splits/train.txt'

cfg.data.val.type = cfg.dataset_type
cfg.data.val.data_root = cfg.data_root
cfg.data.val.img_dir = img_dir
cfg.data.val.ann_dir = ann_dir
cfg.data.val.pipeline = cfg.test_pipeline
cfg.data.val.split = 'splits/val.txt'

cfg.data.test.type = cfg.dataset_type
cfg.data.test.data_root = cfg.data_root
cfg.data.test.img_dir = img_dir
cfg.data.test.ann_dir = ann_dir
cfg.data.test.pipeline = cfg.test_pipeline
cfg.data.test.split = 'splits/val.txt'

จากโค้ดด้านบนจะแบ่งออกเป็น 3 ชุด ได้แก่ Training, Validation และ Testing โดยการตั้งค่า Validation และ Testing ใช้การตั้งค่าแบบเดียวกัน

จากตัวโค้ดด้านบน เราสามารถตั้งค่า

  1. ชนิดของ Dataset
  2. ที่อยู่โฟลเดอร์ของ Dataset
  3. ที่อยู่โฟลเดอร์ของไฟล์ภาพ และไฟล์ Annotation
  4. กำหนด Data Pipeline
  5. ที่อยู่โฟลเดอร์ของไพล์ข้อความที่เก็บรายชื่อไฟล์ภาพสำหรับการทำ Training, Validation และ Testing

กำหนดพารามิเตอร์ที่เกี่ยวข้องกับการ Training และ Validation

เราสามารถกำหนดค่าได้ตามด้านล่างนี้ครับ

# We can still use the pre-trained Mask RCNN model though we do not need to
# use the mask branch
cfg.load_from = './checkpoints/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth'

# Set up working dir to save files and logs.
cfg.work_dir = './work_dirs/tutorial'

cfg.runner.max_iters = 200
cfg.log_config.interval = 10
cfg.evaluation.interval = 200
cfg.checkpoint_config.interval = 200

# Set seed to facitate reproducing the result
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)

จากตัวโค้ดด้านบนจะเป็นการกำหนด

  • Pretrained โมเดลสำหรับการทำ Transfer Learning
  • ที่อยู่โฟลเดอร์สำหรับการเก็บข้อมูลการ Training
  • กำหนดจำนวนรอบสูงสุดต่อการ Training
  • กำหนดความถี่ของการเขียน Log
  • กำหนดความถี่ของการทำ Validation และการบันทึกโมเดล (เป็น Checkpoint)
  • กำหนดการสุ่มของพารามิเตอร์ให้เหมือนกันทุกครั้งที่สุ่ม

เมื่อตั้งค่าเสร็จแล้ว เราสามารถดูการตั้งค่าทั้งหมดได้โดยการพิมพ์คำสั่ง

# Let's have a look at the final config used for training
print(f'Config:\n{cfg.pretty_text}')
print(f"img_dir = { img_dir }")
print(f"ann_dir = { ann_dir }")

เมื่อกดรันแล้ว ระบบจะแสดงผลการตั้งค่าตามด้านล่างนี้ครับ

ระบบจะแสดง Config ที่เราตั้งค่าไว้ครับ

โดยการตั้งค่าที่กำหนดมาให้แล้วคือตัวโมเดล PSPNet ที่ใช้ Loss Function อย่าง Cross-Entropy Loss

Cross-Entropy Loss เป็นชนิดของ Loss Function สำหรับการใช้งานเพื่อจำแนกข้อมูลว่าอยู่ใน Class ไหน โดยคิดตามในแต่ละ Pixel ได้ผลลัพธ์ของค่าที่อยู่ระหว่าง 0 และ 1 อย่างไรก็ตามการคำนวณลักษณะนี้เราไม่รู้ว่าขอบของพื้นที่ที่ Segment ได้อยู่บริเวณไหนครับ

ต่อมา นอกจาก Loss Function แล้ว ตัวการตั้งค่าที่กำหนดมาให้ ยังกำหนด Optimizer อย่าง Stochastic Gradient Descent (SGD) ที่มีค่า Learning Rate เท่ากับ 0.01 ที่มี Momentum 0.9 และมี Weight Decay เท่ากับ 0.0005

รวมถึงการทำ Learning Policy โดยใช้ Poly ครับ

4.) เริ่มต้นการเทรนข้อมูล

เราเริ่มต้นการเทรนข้อมูลได้โดยการเขียนโค้ดตามด้านล่างนี้ครับ

from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.apis import train_segmentor

# Build the dataset
datasets = [build_dataset(cfg.data.train)]

# Build the detector
model = build_segmentor(
    cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))
# Add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES

# Create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
train_segmentor(model, datasets, cfg, distributed=False, validate=True, 
                meta=dict())

จากตัวโค้ดด้านบนนี้ เราจะ

  • เริ่มต้นการสร้าง Dataset โดย build_dataset
  • สร้างโมเดลโดย build_segmentor
  • กำหนดอาเรย์คลาสของโมเดล โดย model.CLASSES = datasets[0].CLASSES
  • สร้างโฟลเดอร์ที่เกี่ยวข้องกับการเทรนโมเดลโดย mmcv.mkdir_or_exist

เมื่อกำหนดทุกอย่างเสร็จแล้ว เราเริ่มต้นการเทรนได้คำสั่ง train_segmentor ระบบจะเริ่มต้นการเทรนโมเดลกับชุดข้อมูล Stanford Background Dataset

ผลลัพธ์จะแสดงหน้าจอตามด้านล่างนี้ครับ

ภาพหน้าจอการเทรน

เมื่อเทรนเสร็จแล้วระบบจะแสดงหน้าจอตามด้านล่างนี้ครับ

เราสามารถนำโมเดลนี้ไปทดสอบได้แล้วครับ

5.) การทดสอบโมเดล

เราสามารถทดสอบโมเดลได้โดยการเขียนโค้ดตามด้านล่างนี้ครับ

img = mmcv.imread('./iccv09Data/images/6000124.jpg')

model.cfg = cfg
result = inference_segmentor(model, img)
plt.figure(figsize=(8, 6))
show_result_pyplot(model, img, result, palette)

เมื่อกดปุ่มรัน ระบบจะทดสอบโมเดลบริเวณคำสั่ง inference_segmentor แล้วผลลัพธ์จะแสดงตามคำสั่ง show_result_pyplot โดยจะแสดงสีตามคลาสที่ได้กำหนดไว้ ผลลัพธ์ที่ได้แสดงตามด้านล่างนี้ครับ

ผลลัพธ์ที่ได้จากการรันโมเดล

เมื่อทดสอบโมเดลเสร็จแล้ว เราสามารถนำโมเดลไปทดสอบเพิ่ม หรือนำไป segment แยกวัตถุในภาพเพิ่มเติมได้อีกครับ

สำหรับผู้อ่านที่ต้องการอ่านข้อมูลเพิ่มเติม ศึกษาได้ในเว็บของ mmSegmentation ได้โดยตรงครับ

By Kittisak Chotikkakamthorn

อดีตนักศึกษาฝึกงานทางด้าน AI ที่ภาควิชาวิศวกรรมไฟฟ้า มหาวิทยาลัย National Chung Cheng ที่ไต้หวัน ที่กำลังหางานทางด้าน Data Engineer ที่มีความสนใจทางด้าน Data, Coding และ Blogging / ติดต่อได้ที่: contact [at] nickuntitled.com

One reply on “วิธีเทรน Segmentation โดย mmSegmentation”

Exit mobile version