To design a new model, e.g., NewModel in newmodel.py, define a class like the following:
import torch.nn as nn


class NewModel(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, input_data):
        # input_data is a dict holding the stereo pair
        left_img = input_data['left']
        right_img = input_data['right']
        ...
        return {'disp_pred': disp_pred}

    def get_loss(self, model_preds, input_data):
        disp_gt = input_data['disp']
        disp_pred = model_preds['disp_pred']
        ...  # compute and return the loss from disp_pred and disp_gt
Your model class must implement at least the forward() and get_loss() methods.
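As a rough illustration of those two methods, here is a minimal sketch of a complete toy model. The single convolution layer, the max_disp parameter, the validity mask, and the choice of smooth L1 loss are all assumptions made for the example. The return format that the trainer expects from get_loss() is not specified here, so returning a bare scalar is also an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NewModel(nn.Module):
    def __init__(self, max_disp=192):  # max_disp is a hypothetical parameter
        super().__init__()
        self.max_disp = max_disp
        # Toy head: concatenated left/right RGB (6 channels) -> 1 disparity channel
        self.head = nn.Conv2d(6, 1, kernel_size=3, padding=1)

    def forward(self, input_data):
        left_img = input_data['left']    # (B, 3, H, W)
        right_img = input_data['right']  # (B, 3, H, W)
        x = torch.cat([left_img, right_img], dim=1)
        disp_pred = F.relu(self.head(x)).squeeze(1)  # (B, H, W), non-negative
        return {'disp_pred': disp_pred}

    def get_loss(self, model_preds, input_data):
        disp_gt = input_data['disp']  # (B, H, W)
        disp_pred = model_preds['disp_pred']
        # Only supervise pixels with a valid ground-truth disparity.
        # If no pixel is valid, the mean over an empty selection is NaN.
        mask = (disp_gt > 0) & (disp_gt < self.max_disp)
        return F.smooth_l1_loss(disp_pred[mask], disp_gt[mask])


# Quick smoke test on random data
torch.manual_seed(0)
model = NewModel()
batch = {
    'left': torch.rand(2, 3, 8, 16),
    'right': torch.rand(2, 3, 8, 16),
    'disp': torch.full((2, 8, 16), 10.0),
}
preds = model(batch)
loss = model.get_loss(preds, batch)
loss.backward()
```

Both methods take the same input_data dict, so the ground truth travels with the images and get_loss() needs no extra arguments.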
Then create a trainer.py file:
from stereo.modeling.trainer_template import TrainerTemplate
from .newmodel import NewModel

# Maps the MODEL.NAME value from the yaml file to the model class.
__all__ = {
    'newmodel': NewModel,
}


class Trainer(TrainerTemplate):
    def __init__(self, args, cfgs, local_rank, global_rank, logger, tb_writer):
        # Look up the model class by name and build it from the MODEL config section
        model = __all__[cfgs.MODEL.NAME](cfgs.MODEL)
        super().__init__(args, cfgs, local_rank, global_rank, logger, tb_writer, model)
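The name lookup in the trainer can be tried in isolation. The sketch below uses plain Python stand-ins, so it runs without the framework: MODEL_REGISTRY plays the role of the __all__ dict, and build_model and the SimpleNamespace config are hypothetical helpers, not part of the library.

```python
from types import SimpleNamespace


class NewModel:
    """Stand-in for the real nn.Module; it only stores its config."""

    def __init__(self, model_cfg):
        self.cfg = model_cfg


# Registry: name used in the config -> model class
MODEL_REGISTRY = {'newmodel': NewModel}


def build_model(cfgs):
    """Instantiate the class registered under cfgs.MODEL.NAME."""
    name = cfgs.MODEL.NAME
    if name not in MODEL_REGISTRY:
        raise KeyError(f"Unknown model '{name}'; registered: {sorted(MODEL_REGISTRY)}")
    # The whole MODEL section is handed to the constructor
    return MODEL_REGISTRY[name](cfgs.MODEL)


# Mimics a parsed config with a MODEL section
cfgs = SimpleNamespace(MODEL=SimpleNamespace(NAME='newmodel', param1=1))
model = build_model(cfgs)
```

A name that is missing from the registry fails at construction time with a KeyError, so a typo in the config surfaces before training starts.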
After finishing the trainer file, three steps remain:
Step 1: Put your newmodel.py and trainer.py under openstereo/modeling/models.
Step 2: Register your model trainer by importing your model in openstereo/modeling/__init__.py.
Note that the name of the imported class is the name you must use in the yaml file.
Step 3: Specify the model name in a yaml file:
MODEL:
  NAME: newmodel
  param1: ...
  param2: ...
  param3: ...