diff --git a/qdeeplandia.py b/qdeeplandia.py index 8ee99b6..dcf75d5 100644 --- a/qdeeplandia.py +++ b/qdeeplandia.py @@ -28,10 +28,10 @@ QHBoxLayout, QVBoxLayout, QMessageBox, \ QToolBar, QLabel, QCheckBox -os.environ['DEEPOSL_CONFIG']=os.path.join(os.path.dirname(__file__),'config.ini') -from .deeposlandia import postprocess -from .processing_provider.provider import QDeepLandiaProvider +os.environ['DEEPOSL_CONFIG'] = os.path.join(os.path.dirname(__file__), 'config.ini') +from deeposlandia.postprocess import get_trained_model +from .processing_provider.provider import QDeepLandiaProvider from .gui.NbLabelDialog import NbLabelDialog from .inferenceTask import InferenceTask @@ -61,8 +61,6 @@ def __init__(self, iface): self.layer = self.updateLayer() self.nb_labels = None self.model_path = None - self.datapath = None - self.dataset = None locale = QSettings().value('locale/userLocale') or 'en_USA' locale = locale[0:2] @@ -147,12 +145,10 @@ def load_trained_model(self): self.nb_labels = nbLabelDlg.param() else : return - - self.datapath = os.path.abspath(os.path.join(os.path.dirname(self.model_path), '..', '..', '..', '..')) - self.dataset = os.path.basename(os.path.abspath(os.path.join(os.path.dirname(self.model_path), '..', '..', '..'))) + self.image_size = os.path.splitext(os.path.basename(self.model_path))[0].split('-')[-1] try : - self.model = postprocess.get_trained_model(self.datapath, self.dataset, int(self.image_size), int(self.nb_labels)) + self.model = get_trained_model(self.model_path, int(self.image_size), int(self.nb_labels)) except ValueError as e: self.iface.messageBar().pushMessage(tr("Critical"), str(e), level=Qgis.Critical)