semantic_segmentation module

The semantic segmentation module contains the BisenetLearner class, which inherits from the abstract class Learner.

Class BisenetLearner

Bases: engine.learners.Learner

The BisenetLearner class is a wrapper around the BiSeNet model [1] found on [BiseNet]. It is used to train semantic segmentation models on RGB images and to run inference with them. The BisenetLearner class has the following public methods:

BisenetLearner constructor

BisenetLearner(self, lr, iters, batch_size, optimizer, temp_path, checkpoint_after_iter, checkpoint_load_iter, device, val_after, weight_decay, momentum, drop_last, pin_memory, num_workers, num_classes, crop_height, crop_width, context_path)

Constructor parameters:

  • lr: float, default=0.01
    Learning rate during optimization.
  • iters: int, default=1
    Number of epochs to train for.
  • batch_size: int, default=1
    Dataloader batch size.
  • optimizer: str, default="sgd"
    Name of the optimizer to use ("sgd", "rmsprop", or "adam").
  • temp_path: str, default=''
    Path in which to store temporary files.
  • checkpoint_after_iter: int, default=0
    Save a checkpoint every checkpoint_after_iter epochs.
  • checkpoint_load_iter: int, default=0
    Unused parameter.
  • device: str, default="cpu"
    Name of computational device ("cpu" or "cuda").
  • val_after: int, default=1
    Perform validation every val_after epochs.
  • weight_decay: float, default=5e-4
    Weight decay used for optimization.
  • momentum: float, default=0.9
    Momentum used for optimization.
  • drop_last: bool, default=True
    Drop the last batch if it is incomplete, i.e. if the dataset size is not divisible by batch_size.
  • pin_memory: bool, default=False
    Use pinned (page-locked) memory in the dataloader, which can speed up host-to-GPU transfers.
  • num_workers: int, default=4
    Number of workers in dataloader.
  • num_classes: int, default=12
    Number of segmentation classes to predict.
  • crop_height: int, default=720
    Input image height.
  • crop_width: int, default=960
    Input image width.
  • context_path: str, default='resnet18'
    Backbone network used in the context path of the BiSeNet model.

BisenetLearner.fit(self, dataset, val_dataset, silent, verbose)

This method trains the model on a training dataset and, if one is provided, validates it on a validation dataset.


  • dataset: Dataset
    Training dataset.
  • val_dataset: Dataset, default=None
    Validation dataset. If None, validation steps are skipped.
  • silent: bool, default=False
    If set to True, disables all printing of training progress reports and other information to STDOUT.
  • verbose: bool, default=True
    If set to True, enables the maximum logging verbosity.
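The constructor parameters iters, batch_size, drop_last, val_after and checkpoint_after_iter together decide how a training run with this method proceeds. The plain-Python sketch below plans such a run. It assumes epochs are numbered from 1, that validation and checkpointing happen every val_after and checkpoint_after_iter epochs, and that checkpoint_after_iter=0 disables intermediate checkpoints. The helpers batches_per_epoch and plan_training are hypothetical and are not part of OpenDR, and the learner's actual schedule may differ.

```python
import math


def batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool) -> int:
    """Batches a PyTorch-style dataloader yields per epoch."""
    if drop_last:
        # The final, incomplete batch (if any) is discarded
        return num_samples // batch_size
    # The final batch may hold fewer than batch_size samples
    return math.ceil(num_samples / batch_size)


def plan_training(iters: int, val_after: int, checkpoint_after_iter: int,
                  has_val_dataset: bool) -> dict:
    """Hypothetical schedule: epochs (1-based) that validate or checkpoint."""
    epochs = range(1, iters + 1)
    # Validation only runs when a validation dataset is passed
    validate = [e for e in epochs
                if has_val_dataset and val_after > 0 and e % val_after == 0]
    # Assumption: checkpoint_after_iter == 0 means no intermediate checkpoints
    checkpoint = [e for e in epochs
                  if checkpoint_after_iter > 0 and e % checkpoint_after_iter == 0]
    return {"validate": validate, "checkpoint": checkpoint}


if __name__ == "__main__":
    print(batches_per_epoch(10, 4, drop_last=True))   # 2
    print(batches_per_epoch(10, 4, drop_last=False))  # 3
    print(plan_training(iters=10, val_after=2, checkpoint_after_iter=5,
                        has_val_dataset=True))
    # {'validate': [2, 4, 6, 8, 10], 'checkpoint': [5, 10]}
```

With the default drop_last=True, up to batch_size - 1 samples are skipped in each epoch. Depending on the shuffling, different samples may be left out from one epoch to the next.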


BisenetLearner.eval(self, dataset, silent, verbose)

This method evaluates a trained model on an evaluation dataset and returns a dictionary of evaluation statistics.


  • dataset: Dataset
    Dataset on which to evaluate model.
  • silent: bool, default=False
    If set to True, disables all printing of evaluation progress reports and other information to STDOUT.
  • verbose: bool, default=True
    If set to True, enables the maximum logging verbosity.
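The keys of the dictionary returned by this method are not listed here. For orientation, semantic segmentation is usually scored with pixel accuracy and mean intersection-over-union (mIoU), both of which derive from a per-class confusion matrix. The plain-Python sketch below computes these two metrics from flat lists of ground-truth and predicted class indices. The helper segmentation_scores is hypothetical. Whether BisenetLearner.eval reports exactly these quantities, or ignores a void label, is an assumption to check against its actual output.

```python
def segmentation_scores(labels, preds, num_classes):
    """Pixel accuracy and mean IoU from flat sequences of class indices."""
    # confusion[t][p] counts pixels of true class t predicted as class p
    confusion = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(labels, preds):
        confusion[t][p] += 1

    total = sum(map(sum, confusion))
    correct = sum(confusion[c][c] for c in range(num_classes))

    ious = []
    for c in range(num_classes):
        tp = confusion[c][c]
        fp = sum(confusion[t][c] for t in range(num_classes)) - tp
        fn = sum(confusion[c]) - tp
        denom = tp + fp + fn
        if denom > 0:  # skip classes absent from both labels and predictions
            ious.append(tp / denom)

    return {
        "pixel_accuracy": correct / total if total else 0.0,
        "miou": sum(ious) / len(ious) if ious else 0.0,
    }


if __name__ == "__main__":
    labels = [0, 0, 1, 1, 2, 2]
    preds = [0, 1, 1, 1, 2, 0]
    # Per-class IoU is 1/3, 2/3 and 1/2, so mIoU is 0.5; 4 of 6 pixels are correct
    print(segmentation_scores(labels, preds, num_classes=3))
```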


BisenetLearner.infer(self, img)

This method performs semantic segmentation on an image and returns an engine.target.Heatmap object.
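The returned heatmap holds one class index per pixel, which is why the inference example below can use it to index a 12-color map. Segmentation networks such as BiSeNet usually output a score for every class at every pixel and reduce those scores to a class map with an arg-max. The sketch below shows that reduction on nested Python lists. It illustrates the general technique and is not BisenetLearner's internal code, and scores_to_class_map is a hypothetical helper.

```python
def scores_to_class_map(scores):
    """Reduce per-class scores indexed as [C][H][W] to a class map [H][W].

    Each output pixel holds the index of the class with the highest score
    (ties resolve to the lowest class index).
    """
    num_classes = len(scores)
    height, width = len(scores[0]), len(scores[0][0])
    return [
        [max(range(num_classes), key=lambda c: scores[c][y][x])
         for x in range(width)]
        for y in range(height)
    ]


if __name__ == "__main__":
    # Scores for 3 classes on a 2x2 image, indexed as scores[class][row][col]
    scores = [
        [[0.9, 0.1], [0.2, 0.3]],   # class 0
        [[0.05, 0.7], [0.1, 0.3]],  # class 1
        [[0.05, 0.2], [0.7, 0.4]],  # class 2
    ]
    print(scores_to_class_map(scores))  # [[0, 1], [2, 2]]
```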


  • img: Image
    Image on which to perform segmentation.


BisenetLearner.download(self, path, mode, verbose, url)

Download pretrained models and testing images to path.


  • path: str, default=None
    Local directory in which to store the downloaded files.
  • mode: {'pretrained', 'testingImage'}, default='pretrained'
    If 'pretrained', downloads a pretrained segmentation model. If 'testingImage', downloads an image to perform inference on.
  • verbose: bool, default=True
    If True, enables maximum verbosity.
  • url: str, default=OpenDR FTP URL
    URL of the FTP server.


BisenetLearner.save(self, path, verbose)

Save model weights and metadata to path.


  • path: str
    Directory in which to save model weights and metadata.
  • verbose: bool, default=True
    If set to True, enables the maximum logging verbosity.


BisenetLearner.load(self, path)

This method is used to load a previously saved model from its saved folder.


  • path: str
    Path of the folder from which to load the saved model.
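Saving and load follow a folder-based pattern: the weights and a metadata file are written to one directory, and load later reads that directory back. The stdlib-only sketch below illustrates such a round trip with a JSON metadata file. The file names (model.pth, metadata.json), the model_paths key and the helpers save_model and load_model are all hypothetical. They do not describe the actual on-disk layout written by BisenetLearner.

```python
import json
import os
import tempfile


def save_model(folder, weights: bytes, metadata: dict) -> None:
    """Write weights and a JSON metadata file into one folder (hypothetical layout)."""
    os.makedirs(folder, exist_ok=True)
    # The metadata records which file(s) hold the weights
    meta = dict(metadata, model_paths=["model.pth"])
    with open(os.path.join(folder, "model.pth"), "wb") as f:
        f.write(weights)
    with open(os.path.join(folder, "metadata.json"), "w") as f:
        json.dump(meta, f, indent=2)


def load_model(folder):
    """Read the metadata first, then the weight file it points to."""
    with open(os.path.join(folder, "metadata.json")) as f:
        meta = json.load(f)
    with open(os.path.join(folder, meta["model_paths"][0]), "rb") as f:
        weights = f.read()
    return weights, meta


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        folder = os.path.join(tmp, "bisenet_saved_model")
        save_model(folder, b"\x00\x01", {"num_classes": 12, "context_path": "resnet18"})
        weights, meta = load_model(folder)
        print(weights, meta["num_classes"])
```

Keeping the metadata next to the weights lets load rebuild the model from a single folder path, without the caller needing to know the weight file names.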


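  • Training example on the CamVid training set.

    import os
    from opendr.perception.semantic_segmentation import BisenetLearner
    from opendr.perception.semantic_segmentation import CamVidDataset

    if __name__ == '__main__':
        learner = BisenetLearner()
        # The CamVid dataset must be downloaded to ./datasets/CamVid/ beforehand
        if not os.path.exists('./datasets/CamVid/'):
            raise FileNotFoundError("CamVid dataset not found in ./datasets/CamVid/")
        datatrain = CamVidDataset('./datasets/CamVid/', mode='train')
        # Train the model, then save its weights and metadata
        learner.fit(dataset=datatrain)
        learner.save("./bisenet_saved_model")
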
  • Evaluation example on the CamVid test set.

    import os
    from opendr.perception.semantic_segmentation import BisenetLearner
    from opendr.perception.semantic_segmentation import CamVidDataset

    if __name__ == '__main__':
        learner = BisenetLearner()
        # The CamVid dataset must be downloaded to ./datasets/CamVid/ beforehand
        if not os.path.exists('./datasets/CamVid/'):
            raise FileNotFoundError("CamVid dataset not found in ./datasets/CamVid/")
        datatest = CamVidDataset('./datasets/CamVid/', mode='test')
        # Download the pretrained model
        learner.download('./bisenet_camvid', mode='pretrained')
        # Evaluate and print the returned statistics dictionary
        results = learner.eval(dataset=datatest)
        print("Evaluation results = ", results)
  • Inference example on a single test image using a pretrained model.

    import cv2
    import numpy as np
    from matplotlib import colormaps
    from opendr.engine.data import Image
    from opendr.perception.semantic_segmentation import BisenetLearner

    if __name__ == '__main__':
        learner = BisenetLearner()
        # Download the pretrained model
        learner.download('./bisenet_camvid', mode='pretrained')
        # Download the testing image
        learner.download('./', mode='testingImage')
        img = Image.open("./test1.png")
        # Perform inference
        heatmap = learner.infer(img)
        # Map each class index (0-11) to a color of a 12-entry viridis colormap
        segmentation_mask = heatmap.data
        colormap = colormaps['viridis'].resampled(12).colors
        segmentation_img = np.uint8(255*colormap[segmentation_mask][:, :, :3])
        # Blend original image and the segmentation mask
        blended_img = np.uint8(0.4*img.opencv() + 0.6*segmentation_img)
        cv2.imshow('Heatmap', blended_img)
        # Keep the window open until a key is pressed
        cv2.waitKey(0)
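The last steps of the inference example map each class index to a color and then alpha-blend the colors over the input image, with a weight of 0.4 for the image and 0.6 for the mask. The plain-Python sketch below performs the same per-pixel arithmetic on a tiny image without NumPy or OpenCV. The two-entry palette is made up and stands in for the 12-entry viridis map, and blend_mask is a hypothetical helper.

```python
def blend_mask(image, class_map, palette, image_weight=0.4):
    """Overlay a color-coded class map on an RGB image.

    image:     [H][W] list of (r, g, b) tuples with 0-255 values
    class_map: [H][W] list of class indices
    palette:   list of (r, g, b) tuples, one per class
    """
    mask_weight = 1.0 - image_weight
    blended = []
    for img_row, cls_row in zip(image, class_map):
        out_row = []
        for pixel, cls in zip(img_row, cls_row):
            color = palette[cls]
            # image_weight * image + mask_weight * color, truncated like np.uint8
            out_row.append(tuple(int(image_weight * p + mask_weight * c)
                                 for p, c in zip(pixel, color)))
        blended.append(out_row)
    return blended


if __name__ == "__main__":
    image = [[(100, 100, 100), (0, 0, 0)]]    # a 1x2 RGB image
    class_map = [[0, 1]]                       # class index per pixel
    palette = [(100, 100, 100), (255, 0, 0)]   # hypothetical colors for classes 0 and 1
    print(blend_mask(image, class_map, palette))
```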


[1] BiSeNet: Bilateral Segmentation Network for Real-time Semantic Segmentation, arXiv.