-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Understood nbdev, and the need for default_exp
1 parent
940cf74
commit c714d85
Showing
6 changed files
with
252 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
"""Base environment with some basic funcitons""" | ||
|
||
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/20_environments/22_envs_pricing/10_base_pricing_env.ipynb. | ||
|
||
# %% auto 0 | ||
__all__ = ['BasePricingEnv'] | ||
|
||
# %% ../../../nbs/20_environments/22_envs_pricing/10_base_pricing_env.ipynb 3 | ||
from abc import ABC, abstractmethod | ||
from typing import Union, Tuple, List | ||
|
||
from ..base import BaseEnvironment | ||
from ...utils import Parameter, MDPInfo | ||
from ...dataloaders.base import BaseDataLoader | ||
from ...loss_functions import pinball_loss | ||
|
||
import gymnasium as gym | ||
|
||
import numpy as np | ||
import time | ||
|
||
# %% ../../../nbs/20_environments/22_envs_pricing/10_base_pricing_env.ipynb 4 | ||
class BasePricingEnv(BaseEnvironment):
    """
    Base class for pricing environments. This class inherits from BaseEnvironment.

    It stores the dataloader, registers the parameters ``alpha`` (market size) and
    ``beta`` (price sensitivity), and provides default implementations for building
    Box observation/action spaces, reading observations and demand responses from the
    dataloader, and resetting the environment.
    """

    def __init__(self,

        ## Parameters for Base env:
        mdp_info: MDPInfo, #
        postprocessors: list[object] | None = None, # default is empty list
        mode: str = "online", # additional mode for the pricing environment TODO: add online mode to training loop
        return_truncation: bool = True, # whether to return a truncated condition in step function
        dataloader: BaseDataLoader = None, # dataloader for the environment

        alpha: Union[float, np.ndarray] = 1, # market size parameter
        beta: Union[float, np.ndarray] = 1, # price sensitivity parameter
        horizon_train: int | str = 100 # horizon for the online learning TODO: check if it can be renamed to horizon
        ) -> None:

        self.dataloader = dataloader

        # self.num_SKUs is read here, before super().__init__() runs, so a subclass has to
        # register it (e.g. via set_param("num_SKUs", ...)) before calling this constructor.
        # NOTE(review): set_param is not defined in this class - presumably inherited from
        # BaseEnvironment; confirm it may be called before the base constructor.
        num_skus = self.num_SKUs[0]
        self.set_param("alpha", alpha, shape=(num_skus,), new=True)
        self.set_param("beta", beta, shape=(num_skus,), new=True)

        # TODO: check in the base env if train_horizon is needed
        super().__init__(
            mdp_info=mdp_info,
            postprocessors=postprocessors,
            mode=mode,
            return_truncation=return_truncation,
            horizon_train=horizon_train,
        )

    def set_observation_space(self,
        shape: tuple, # shape of the dataloader features
        low: Union[np.ndarray, float] = -np.inf, # lower bound of the observation space
        high: Union[np.ndarray, float] = np.inf, # upper bound of the observation space
        samples_dim_included: bool = True # whether the first dimension of the shape input is the number of samples
        ) -> None:

        '''
        Set the observation space of the environment.
        This is a standard function for simple observation spaces. For more complex observation spaces,
        this function should be overwritten. Note that it is assumed that the first dimension
        is n_samples that is not relevant for the observation space.

        If shape is None, the observation space is set to None. Otherwise shape must be a
        tuple (a ValueError is raised if it is not) and a float32 Box space is created.
        '''

        # To handle cases when no external information is available (e.g., parametric NV)
        if shape is None:
            self.observation_space = None
            return

        if not isinstance(shape, tuple):
            raise ValueError("Shape must be a tuple.")

        if samples_dim_included:
            shape = shape[1:] # assumed that the first dimension is the number of samples

        self.observation_space = gym.spaces.Box(low=low, high=high, shape=shape, dtype=np.float32)

    def set_action_space(self,
        shape: tuple, # shape of the dataloader target
        low: Union[np.ndarray, float] = -np.inf, # lower bound of the action space
        high: Union[np.ndarray, float] = np.inf, # upper bound of the action space
        samples_dim_included: bool = True # whether the first dimension of the shape input is the number of samples
        ) -> None:

        '''
        Set the action space of the environment.
        This is a standard function for simple action spaces. For more complex action spaces,
        this function should be overwritten. Note that it is assumed that the first dimension
        is n_samples that is not relevant for the action space.

        shape must be a tuple (a ValueError is raised if it is not); a float32 Box space
        is created from it.
        '''

        if not isinstance(shape, tuple):
            raise ValueError("Shape must be a tuple.")

        if samples_dim_included:
            shape = shape[1:] # assumed that the first dimension is the number of samples

        self.action_space = gym.spaces.Box(low=low, high=high, shape=shape, dtype=np.float32)

    def get_observation(self):

        """
        Return the current observation. In the online learning case only the state is returned
        (the demand is obtained separately through get_demand_response). For other use cases
        this function should be overwritten.
        """

        # Indexing the dataloader with the current index is expected to yield a one-element
        # tuple; the trailing comma unpacks its single entry.
        X_item, = self.dataloader[self.index]

        return X_item

    def get_demand_response(self, action):

        """
        Return the demand and epsilon that the dataloader yields for the current index and
        the given action. This function should be overwritten.
        TODO: add the tuple call to the pricing dataloader
        """

        Y_item, epsilon = self.dataloader[self.index, action]
        return Y_item, epsilon

    def reset(self,
        start_index: int | str = None, # index to start from
        state: np.ndarray = None # initial state (not used by this implementation)
        ) -> np.ndarray:

        """
        Reset function for the pricing environment. It repositions the environment via
        reset_index(start_index) and returns the first observation.

        NOTE(review): per the original notes, val and test modes reset to 0 by default while
        train mode depends on the parameter "horizon_train"; that logic lives in reset_index,
        which is defined outside this class - confirm there.
        """

        # The value returned by reset_index is not needed for the reset result.
        self.reset_index(start_index)

        observation = self.get_observation()

        # The default get_observation returns only the observation. An overriding
        # implementation may still return an (observation, demand) pair; in that case the
        # demand is stored as before.
        if isinstance(observation, tuple) and len(observation) == 2:
            observation, self.demand = observation

        return observation
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
"""Static dynamic pricing environment where a decision only affects the next period""" | ||
|
||
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/20_environments/22_envs_pricing/20_dynamic_pricing_env.ipynb. | ||
|
||
# %% auto 0 | ||
__all__ = ['DynamicPricingEnv'] | ||
|
||
# %% ../../../nbs/20_environments/22_envs_pricing/20_dynamic_pricing_env.ipynb 3 | ||
from abc import ABC, abstractmethod | ||
from typing import Union, Tuple, Literal | ||
|
||
from ...utils import Parameter, MDPInfo | ||
from ...dataloaders.base import BaseDataLoader | ||
from ...loss_functions import pinball_loss, quantile_loss | ||
from .base import BasePricingEnv | ||
|
||
import gymnasium as gym | ||
|
||
import numpy as np | ||
import time | ||
|
||
# %% ../../../nbs/20_environments/22_envs_pricing/20_dynamic_pricing_env.ipynb 4 | ||
class DynamicPricingEnv(BasePricingEnv): | ||
""" | ||
Class implementing the dynamic pricing and learning problem, working for the single- and multi-item case. | ||
If alpha and beta are scalars and they are multiple SKUs, then the same parameters are used for all SKUs. | ||
If alpha and beta are arrays, then they should have the same length as the number of SKUs. | ||
Num_SKUs can be set as parameter or inferrred from the DataLoader. | ||
""" | ||
def __init__(self,
    alpha: Union[np.ndarray, Parameter, int, float] = 1.0, # market size per SKUs
    beta: Union[np.ndarray, Parameter, int, float] = 0.5, # price elasticity per SKUs
    p_bound_low: Union[np.ndarray, Parameter, int, float] = 0.0, # lower price bound per SKUs
    p_bound_high: Union[np.ndarray, Parameter, int, float] = 1.0, # upper price bound per SKUs
    dataloader: BaseDataLoader = None, # dataloader TODO: replace with pricing dataloader
    num_SKUs: int | None = None, # number of SKUs; taken from dataloader.num_units if None
    gamma: float = 1, # discount factor
    horizon_train: int | str = "use_all_data", # if "use_all_data" then horizon is inferred from the DataLoader
    postprocessors: list[object] | None = None, # default is empty list
    mode: str = "online", # TODO: add online to relevant modes
    return_truncation: bool = True # whether to return a truncated condition in step function
    ) -> None:

    """
    Set up the dynamic pricing environment.

    Registers num_SKUs and the price bounds as parameters, builds the observation space
    from dataloader.X_shape and the action space from dataloader.Y_shape (bounded by the
    price bounds), creates the MDPInfo and then hands over to the base pricing
    environment, which registers alpha and beta.

    Raises a ValueError if no dataloader is given or if num_SKUs is not an integer.
    """

    # The dataloader is always needed below (X_shape / Y_shape), so fail early with a
    # clear message instead of an AttributeError on None.
    if dataloader is None:
        raise ValueError("A dataloader must be provided.")

    self.print = False

    if num_SKUs is None:
        num_SKUs = dataloader.num_units

    # NumPy integer scalars are accepted as well and normalised to a plain int.
    if not isinstance(num_SKUs, (int, np.integer)):
        raise ValueError("num_SKUs should be an integer.")
    num_SKUs = int(num_SKUs)

    # Parameters are registered before super().__init__() because the base pricing
    # environment reads num_SKUs when it sets up alpha and beta.
    self.set_param("num_SKUs", num_SKUs, shape=(1,), new=True)

    self.set_param("p_bound_low", p_bound_low, shape=(num_SKUs,), new=True)
    self.set_param("p_bound_high", p_bound_high, shape=(num_SKUs,), new=True)

    self.set_observation_space(dataloader.X_shape)
    self.set_action_space(dataloader.Y_shape, low=self.p_bound_low, high=self.p_bound_high)

    mdp_info = MDPInfo(self.observation_space, self.action_space, gamma=gamma, horizon=horizon_train)

    super().__init__(
        mdp_info=mdp_info,
        postprocessors=postprocessors,
        mode=mode,
        return_truncation=return_truncation,
        alpha=alpha,
        beta=beta,
        dataloader=dataloader,
        horizon_train=horizon_train,
    )
|
||
def step_(self, | ||
action: np.ndarray # prices) | ||
) -> Tuple[np.ndarray, float, bool, bool, dict]: | ||
return observation, reward, terminated, truncated, info |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters