Skip to content
This repository has been archived by the owner on Jun 3, 2020. It is now read-only.

Commit

Permalink
load_model: simplify the model loading with respect to deeposlandia 0…
Browse files Browse the repository at this point in the history
….6.2
  • Loading branch information
delhomer committed May 28, 2020
1 parent 107f22d commit 0e5a51a
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions qdeeplandia.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
QHBoxLayout, QVBoxLayout, QMessageBox, \
QToolBar, QLabel, QCheckBox

os.environ['DEEPOSL_CONFIG']=os.path.join(os.path.dirname(__file__),'config.ini')
from .deeposlandia import postprocess
from .processing_provider.provider import QDeepLandiaProvider
os.environ['DEEPOSL_CONFIG'] = os.path.join(os.path.dirname(__file__), 'config.ini')
from deeposlandia.postprocess import get_trained_model

from .processing_provider.provider import QDeepLandiaProvider
from .gui.NbLabelDialog import NbLabelDialog
from .inferenceTask import InferenceTask

Expand Down Expand Up @@ -61,8 +61,6 @@ def __init__(self, iface):
self.layer = self.updateLayer()
self.nb_labels = None
self.model_path = None
self.datapath = None
self.dataset = None

locale = QSettings().value('locale/userLocale') or 'en_USA'
locale = locale[0:2]
Expand Down Expand Up @@ -147,12 +145,10 @@ def load_trained_model(self):
self.nb_labels = nbLabelDlg.param()
else :
return

self.datapath = os.path.abspath(os.path.join(os.path.dirname(self.model_path), '..', '..', '..', '..'))
self.dataset = os.path.basename(os.path.abspath(os.path.join(os.path.dirname(self.model_path), '..', '..', '..')))

self.image_size = os.path.splitext(os.path.basename(self.model_path))[0].split('-')[-1]
try :
self.model = postprocess.get_trained_model(self.datapath, self.dataset, int(self.image_size), int(self.nb_labels))
self.model = get_trained_model(self.model_path, int(self.image_size), int(self.nb_labels))
except ValueError as e:
self.iface.messageBar().pushMessage(tr("Critical"),
str(e), level=Qgis.Critical)
Expand Down

0 comments on commit 0e5a51a

Please sign in to comment.