This repository contains the source code for a chess keypoint detection task that was given as a test assignment by the FreeFlex company. In this repo I demonstrate how to solve a keypoint detection problem using the PyTorch and torchvision frameworks.
- Dataset files: Google Drive
- Weights pretrained on this dataset: Google Drive
- Kaggle training script: Kaggle link
- Training reports (with loss graphs and images): wandb.ai
## Contents

- Dataset description
- Launching training
- Repository structure
- Training process
- Visualization of prediction results
- Technology stack
- Project limitations
- Contributions
## Dataset description

The dataset consists of three files:

- `xtrain.npy` - a numpy array of shape (15137, 256, 256, 1): 15137 grayscale square images of size 256x256, pixel range 0-255
- `ytrain.npy` - a numpy array of shape (15137, 8): keypoints for each image in the order (x1, y1, x2, y2, ...), value range 0-1 (should be denormalized to pixel coordinates before being passed to the model)
- `xtest.npy` - a numpy array of shape (5, 256, 256, 1): 5 test images
Visualization of the training data (red points are the keypoints of the chess board):
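For a quick look at the data, here is a minimal sketch of loading the arrays and drawing the keypoints with opencv; the file paths and the denormalization by the image size of 256 follow the description above, and the (4, 2) reshape assumes the 8 values are 4 corner points:

```python
import numpy as np
import cv2

# Paths assume the data/ layout described below
images = np.load("data/xtrain.npy")     # (15137, 256, 256, 1), pixels 0-255
keypoints = np.load("data/ytrain.npy")  # (15137, 8), normalized to 0-1

# Take one sample and denormalize its keypoints to pixel coordinates
img = cv2.cvtColor(images[0].astype(np.uint8), cv2.COLOR_GRAY2BGR)
points = (keypoints[0].reshape(4, 2) * 256).astype(int)  # 4 (x, y) corners

for x, y in points:
    cv2.circle(img, (int(x), int(y)), 3, (0, 0, 255), -1)  # red dot

cv2.imwrite("sample.png", img)
```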
## Launching training

Steps for setting up the training/testing process (assuming you train locally):

- (Optionally) Register on wandb.ai to save logs and model weights.
- Download the dataset from Google Drive and put it into the `data/` folder (so the data folder contains 3 files: `xtest.npy`, `xtrain.npy` and `ytrain.npy`).
- Install the project requirements: `python -m pip install -r requirements.txt` (if something goes wrong, try `pip install -r requirements-dev.txt`).
- (Optionally) Change the default batch size and other model parameters in the `config/config.yaml` file; hydra also lets you override them from the command line, as shown below.
- Run training: `python train.py`. The training script will ask for an API key to connect to wandb.ai; follow the instructions.
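Because configuration is handled by hydra, hyperparameters can also be overridden straight from the command line. The key names in this sketch are hypothetical; check `config/config.yaml` for the actual ones:

```bash
# Hypothetical keys - adjust to whatever config/config.yaml defines
python train.py lr=0.0001 batch_size=10
```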
## Repository structure

- `config/config.yaml` - config used for training the model (lr, batch size and other hyperparameters should be set here or overridden from the command line)
- `data/` - data folder (should contain the `xtest.npy`, `xtrain.npy` and `ytrain.npy` files for training/testing)
- `models/`
  - `dataloader.py` - builds the dataloader from the dataset; torchvision expects a special input format, so a custom collate function was needed (see the sketch after this list)
  - `dataset.py` - reads the dataset from the numpy files and wraps it into a standard pytorch dataset class
  - `trainer.py` - defines the training/testing loop
- `jupyter-notebooks/` - visualization and training/testing notebooks
- `train.py` - training script; sets up image logging and runs the trainer class from the pytorch-lightning library
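Torchvision detection models consume a list of image tensors and a list of per-image target dicts rather than stacked batches, which is what the custom collate function has to preserve. This is a minimal sketch of the idea, not the repository's actual code; building the target dict for a single 4-corner board instance is an assumption:

```python
import torch

def collate_fn(batch):
    """Keep images and targets as lists: torchvision detection models
    expect (list[Tensor], list[dict]) instead of stacked tensors."""
    images, targets = zip(*batch)
    return list(images), list(targets)

def make_target(corners_px):
    """Build a torchvision-style target for one image.

    corners_px: float tensor of shape (4, 2) with pixel coordinates.
    """
    # Keypoints need a visibility flag: (x, y, v), v=1 means visible
    kps = torch.cat([corners_px, torch.ones(4, 1)], dim=1)  # (4, 3)
    # One "object" (the board) whose box encloses all four corners
    x_min, y_min = corners_px.min(dim=0).values
    x_max, y_max = corners_px.max(dim=0).values
    return {
        "boxes": torch.tensor([[x_min, y_min, x_max, y_max]]),
        "labels": torch.tensor([1], dtype=torch.int64),  # class 1 = board
        "keypoints": kps.unsqueeze(0),                   # (1, 4, 3)
    }
```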
## Training process

Training time: 3h 40m.

The model was trained for 4 epochs (the error curve reaches a plateau there, so there is no need to train longer): 2 epochs with learning rate 0.001, then 2 epochs with learning rate 0.0001, with batch size 10 (batch size 16 did not fit into the Kaggle VRAM).

The confidence threshold for prediction scores was set to 0.7: detecting the borders of the chess board turned out to be a simple task, so there was no need to spend much time tuning this parameter.
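At inference time that threshold is applied to the scores returned by the model. A minimal sketch, assuming `model` is the trained torchvision keypoint detector and `images` is a list of float image tensors scaled to 0-1:

```python
import torch

model.eval()
with torch.no_grad():
    predictions = model(images)  # one dict per input image

for pred in predictions:
    keep = pred["scores"] > 0.7          # drop low-confidence detections
    keypoints = pred["keypoints"][keep]  # (n_kept, 4, 3): x, y, visibility
```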
Charts of training errors: check the wandb.ai report for more details.
## Visualization of prediction results

Red points are the model's predictions, green points are the dataset labels.

- Validation dataset: we can see that the model detects the board borders quite well.
- Test images: we can see that the model makes mistakes when a keypoint is hidden by a hand.
## Technology stack

- `pytorch`/`pytorch-lightning` for data extraction and the training loop
- `torchvision` for the keypoint detection model
- `hydra` for model/data configuration
- `wandb` for logging images/losses (optional)
- `opencv` for plotting points
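The torchvision model here is a keypoint detector; a hedged sketch of instantiating one for the four board corners (the exact architecture and arguments used in this repo are assumptions):

```python
from torchvision.models.detection import keypointrcnn_resnet50_fpn

# Assumed setup: 2 classes (background + board), 4 keypoints per instance
model = keypointrcnn_resnet50_fpn(weights=None, num_classes=2, num_keypoints=4)
```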
## Project limitations

Torchvision requires explicitly marking whether each keypoint is visible or hidden; that is why the points predicted on the test images are poor when the keypoints are hidden by a hand.
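For reference, torchvision encodes every keypoint as an (x, y, v) triple where v is the visibility flag, and keypoints with v=0 are ignored by the training loss; the coordinates below are illustrative:

```python
import torch

# (x, y, v): v=1 means the keypoint is labeled and visible,
# v=0 means it is ignored by the keypoint loss during training
visible_corner = torch.tensor([120.0, 64.0, 1.0])
hidden_corner = torch.tensor([120.0, 64.0, 0.0])
```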
## Contributions

Contributions are welcome: please open a PR and describe the implemented functionality.