Skip to content

Commit

Permalink
add get_modified_defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
cmbant committed Aug 14, 2024
1 parent 0cc707a commit 34483d2
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 8 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
## 3.5.4

- allow classes to have both yaml and class attributes as long as no duplicate keys
- Allow classes to have both yaml and class attributes as long as no duplicate keys
- Added get_modified_defaults() class method to cobaya components to dynamically set/modify defaults

### Cosmology
- Option to return lensed scalar Cl's from CAMB (without tensors) (thanks @kimmywu})

## 3.5.3
Expand Down
18 changes: 14 additions & 4 deletions cobaya/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,7 @@ def get_bibtex(cls) -> Optional[str]:
from this class, it will return the result from an inherited class if that
provides bibtex.
"""
filename = cls.__dict__.get('bibtex_file')
if filename:
if filename := cls.__dict__.get('bibtex_file'):
bib = cls.get_text_file_content(filename)
else:
bib = cls.get_associated_file_content('.bibtex')
Expand Down Expand Up @@ -298,6 +297,17 @@ def get_defaults(cls, return_yaml=False, yaml_expand_defaults=True,
else:
return defaults

# noinspection PyUnusedLocal
@classmethod
def get_modified_defaults(cls, defaults, input_options=empty_dict):
"""
After defaults dictionary is loaded, you can dynamically modify them here
as needed,e.g. to add or remove defaults['params']. Use this when you don't
want the inheritance-recursive nature of get_defaults() or don't only
want to affect class attributes (like get_class_options() does0.
"""
return defaults

@classmethod
def get_annotations(cls) -> InfoDict:
d = {}
Expand Down Expand Up @@ -329,6 +339,7 @@ def __init__(self, info: InfoDictIn = empty_dict,
if standalone:
# TODO: would probably be more natural if defaults were always read here
default_info = self.get_defaults(input_options=info)
default_info = self.get_modified_defaults(default_info, input_options=info)
default_info.update(info)
info = default_info

Expand Down Expand Up @@ -459,8 +470,7 @@ def add_instance(self, name, component):
self[name] = component

def dump_timing(self):
timers = [component for component in self.values() if component.timer]
if timers:
if timers := [component for component in self.values() if component.timer]:
sep = "\n "
self.log.info(
"Average computation time:" + sep + sep.join(
Expand Down
6 changes: 4 additions & 2 deletions cobaya/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@

# Local
from cobaya.conventions import products_path, kinds, separator_files, \
get_chi2_name, get_chi2_label, Extension, FileSuffix, \
packages_path_input
get_chi2_name, get_chi2_label, Extension, FileSuffix, packages_path_input
from cobaya.typing import InputDict, InfoDict, ModelDict, ExpandedParamsDict, LikesDict, \
empty_dict
from cobaya.tools import recursive_update, str_to_list, get_base_classes, \
Expand Down Expand Up @@ -241,6 +240,9 @@ def get_default_info(component_or_class, kind=None, return_yaml=False,
cls.get_defaults(return_yaml=return_yaml,
yaml_expand_defaults=yaml_expand_defaults,
input_options=input_options)
if not return_yaml:
default_component_info = cls.get_modified_defaults(
default_component_info, input_options=input_options)
except ComponentNotFoundError:
raise
except Exception as e:
Expand Down
3 changes: 2 additions & 1 deletion cobaya/theory.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,8 @@ def __init__(self, info_theory: TheoriesDict, packages_path=None, timing=None):
logger=self.log, component_path=info.get("python_path"))
self.add_instance(
name, theory_class(
info, packages_path=packages_path, timing=timing, name=name))
info, packages_path=packages_path, timing=timing, name=name,
standalone=False))

def __getattribute__(self, name):
if not name.startswith('_'):
Expand Down

0 comments on commit 34483d2

Please sign in to comment.