Skip to content

Commit

Permalink
Fix overwrite config file
Browse files Browse the repository at this point in the history
  • Loading branch information
rspwarnaarUT committed Dec 12, 2024
1 parent 34c2738 commit 090fb7a
Showing 1 changed file with 32 additions and 7 deletions.
39 changes: 32 additions & 7 deletions resurfemg/data_connector/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,29 @@ class Config:

required_directories = ['root_data']

def __init__(self, location=None, verbose=False):
def __init__(self, location=None, verbose=False, force=False):
"""
This function initializes the configuration file. If no location is
specified it will try to load the configuration file from the default
locations:
- ./config.json
- ~/.resurfemg/config.json
- /etc/resurfemg/config.json
- PROJECT_ROOT/config.json
-----------------------------------------------------------------------
:param location: The location of the configuration file.
:type location: str
:param verbose: A boolean to print the loaded configuration.
:type verbose: bool
:param force: A boolean to overwrite the configuration file.
:type force: bool
:raises ValueError: If the configuration file is not found.
"""
self._raw = None
self._loaded = None
self.example = 'config_example_resurfemg.json'
self.repo_root = find_repo_root(self.example)
self.force = force
self.created_config = False
# In the ResurfEMG project, the test data is stored in ./test_data
test_path = os.path.join(self.repo_root, 'test_data')
Expand Down Expand Up @@ -138,6 +156,7 @@ def usage(self):
def create_config_from_example(
self,
location: str,
force=False,
):
"""
This function creates a config file from an example file.
Expand All @@ -147,10 +166,15 @@ def create_config_from_example(
:raises ValueError: If the example file is not found.
"""
config_path = location.replace(self.example, 'config.json')
with open(location, 'r') as f:
example = json.load(f)
with open(config_path, 'w') as f:
json.dump(example, f, indent=4, sort_keys=True)
if os.path.isfile(config_path) and not force:
raise ValueError(
f'Config file already exists at {config_path}.'
+ ' Use `force=True` to overwrite.')
else:
with open(location, 'r') as f:
example = json.load(f)
with open(config_path, 'w') as f:
json.dump(example, f, indent=4, sort_keys=True)

def load(self, location, verbose=False):
"""
Expand Down Expand Up @@ -180,9 +204,10 @@ def load(self, location, verbose=False):
logging.info('Failed to load %s: %s', _path, e)
else:
if (self.repo_root is not None and os.path.isfile(
os.path.join(self.repo_root, 'config.json'))):
os.path.join(self.repo_root, self.example))):
self.create_config_from_example(
os.path.join(self.repo_root, self.example))
os.path.join(self.repo_root, self.example),
force=self.force,)
root_path = os.path.join(self.repo_root, 'not_pushed')
if not os.path.isdir(root_path):
os.makedirs(root_path)
Expand Down

0 comments on commit 090fb7a

Please sign in to comment.