Skip to content

Commit

Permalink
fix documentation type hint issues
Browse files Browse the repository at this point in the history
  • Loading branch information
TheEimer committed Sep 25, 2024
1 parent 73c4081 commit f756430
Show file tree
Hide file tree
Showing 15 changed files with 69 additions and 119 deletions.
14 changes: 5 additions & 9 deletions dacbench/abstract_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,8 @@ def get_config(self):
"""Return current configuration.
Returns:
-------
dict
Current config
--------
dict: Current config
"""
return self.config

Expand Down Expand Up @@ -127,7 +125,7 @@ def jsonify_wrappers(self):
"""Write wrapper description to list.
Returns:
-------
--------
list
"""
Expand Down Expand Up @@ -421,10 +419,8 @@ def get_environment(self):
"""Make benchmark environment.
Returns:
-------
env : gym.Env
Benchmark environment
--------
gym.Env: Benchmark environment
"""
raise NotImplementedError

Expand Down
30 changes: 11 additions & 19 deletions dacbench/abstract_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,8 @@ def step_(self):
"""Pre-step function for step count and cutoff.
Returns:
-------
bool
End of episode
--------
bool: End of episode
"""
truncated = False
Expand Down Expand Up @@ -234,7 +233,7 @@ def step(self, action):
Action to take
Returns:
-------
--------
state
Environment state
reward
Expand All @@ -259,7 +258,7 @@ def reset(self, seed: int | None = None):
Seed for the environment
Returns:
-------
--------
state
Environment state
info: dict
Expand All @@ -272,32 +271,26 @@ def get_inst_id(self):
"""Return instance ID.
Returns:
-------
int
ID of current instance
--------
int: ID of current instance
"""
return self.inst_id

def get_instance_set(self):
"""Return instance set.
Returns:
-------
list
List of instances
--------
list: List of instances
"""
return self.instance_set

def get_instance(self):
"""Return current instance.
Returns:
-------
type flexible
Currently used instance
--------
type flexible: Currently used instance
"""
return self.instance

Expand Down Expand Up @@ -489,9 +482,8 @@ def last(self):
"""Get current step data.
Returns:
-------
--------
np.array, float, bool, bool, dict
"""
return (
self.observation,
Expand Down
22 changes: 10 additions & 12 deletions dacbench/benchmarks/luby_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(self, config_path=None, config=None):
"""Initialize Luby Benchmark.
Parameters
-------
----------
config_path : str
Path to config file (optional)
"""
Expand All @@ -80,9 +80,8 @@ def get_environment(self):
"""Return Luby env with current configuration.
Returns:
-------
LubyEnv
Luby environment
--------
LubyEnv: Luby environment
"""
if "instance_set" not in self.config:
self.read_instance_set()
Expand All @@ -101,8 +100,8 @@ def set_cutoff(self, steps):
"""Set cutoff and adapt dependencies.
Parameters
-------
int
----------
int:
Maximum number of steps
"""
self.config.cutoff = steps
Expand All @@ -122,8 +121,8 @@ def set_history_length(self, length):
"""Set history length and adapt dependencies.
Parameters
-------
int
----------
int:
History length
"""
self.config.hist_length = length
Expand Down Expand Up @@ -177,7 +176,7 @@ def get_benchmark(self, min_l=8, fuzziness=1.5, seed=0):
"""Get Benchmark from DAC paper.
Parameters
-------
----------
min_l : int
Minimum sequence lenght, was 8, 16 or 32 in the paper
fuzziness : float
Expand All @@ -186,9 +185,8 @@ def get_benchmark(self, min_l=8, fuzziness=1.5, seed=0):
Environment seed
Returns:
-------
env : LubyEnv
Luby environment
--------
LubyEnv: Luby Environment
"""
self.config = objdict(LUBY_DEFAULTS.copy())
self.config.min_steps = min_l
Expand Down
5 changes: 2 additions & 3 deletions dacbench/benchmarks/toysgd_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,8 @@ def get_environment(self):
"""Return SGDEnv env with current configuration.
Returns:
-------
SGDEnv
SGD environment
--------
ToySGDEnv: ToySGD environment
"""
if "instance_set" not in self.config:
self.read_instance_set()
Expand Down
20 changes: 9 additions & 11 deletions dacbench/envs/luby.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self, config) -> None:
"""Initialize Luby Env.
Parameters
-------
----------
config : objdict
Environment configuration
"""
Expand Down Expand Up @@ -68,9 +68,8 @@ def step(self, action: int):
action to execute
Returns:
-------
np.array, float, bool, bool, dict
state, reward, terminated, truncated, info
--------
np.array, float, bool, bool, dict: state, reward, terminated, truncated, info
"""
self.done = super().step_()
self.prev_state = self._state.copy()
Expand Down Expand Up @@ -98,9 +97,8 @@ def reset(self, seed=None, options=None) -> list[int]:
"""Resets env.
Returns:
-------
numpy.array
Environment state
--------
numpy.array: Environment state
"""
if options is None:
options = {}
Expand Down Expand Up @@ -142,7 +140,8 @@ def get_default_state(self, _):
_ (_type_): Empty parameter, which can be used when overriding
Returns:
dict: The current state
--------
dict: The current state
"""
if self.c_step == 0:
self._state = [-1 for _ in range(self._hist_len + 1)]
Expand All @@ -159,9 +158,8 @@ def close(self) -> bool:
"""Close Env.
Returns:
-------
bool
Closing confirmation
--------
bool: Closing confirmation
"""
return True

Expand Down
23 changes: 8 additions & 15 deletions dacbench/envs/theory.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,20 +406,17 @@ def get_obs_domain_from_name(var_name): # noqa: ARG004
The observation space will then be created
Returns:
-------
Two int values, e.g., 1, np.inf
--------
Two int values, e.g., 1, np.inf
"""
return 0, np.inf

def reset(self, seed=None, options=None):
"""Resets env.
Returns:
-------
numpy.array
Environment state
--------
numpy.array: Environment state
"""
if options is None:
options = {}
Expand Down Expand Up @@ -475,10 +472,8 @@ def step(self, action):
action to execute
Returns:
-------
state, reward, terminated, truncated, info
np.array, float, bool, bool, dict
--------
np.array, float, bool, bool, dict: state, reward, terminated, truncated, info
"""
truncated = super().step_()

Expand Down Expand Up @@ -592,10 +587,8 @@ def close(self) -> bool:
No additional cleanup necessary
Returns:
-------
bool
Closing confirmation
--------
bool: Closing confirmation
"""
return True

Expand Down
Loading

0 comments on commit f756430

Please sign in to comment.