Understood nbdev, and the need for default_exp
miTTimmiTTim committed Nov 11, 2024
1 parent 940cf74 commit c714d85
Showing 6 changed files with 252 additions and 1 deletion.
20 changes: 20 additions & 0 deletions ddopai/_modidx.py
@@ -786,6 +786,26 @@
'ddopai/envs/inventory/single_period.py'),
'ddopai.envs.inventory.single_period.NewsvendorEnvVariableSL.set_val_test_sl': ( '20_environments/21_envs_inventory/single_period_envs.html#newsvendorenvvariablesl.set_val_test_sl',
'ddopai/envs/inventory/single_period.py')},
'ddopai.envs.pricing.base': { 'ddopai.envs.pricing.base.BasePricingEnv': ( '20_environments/22_envs_pricing/base_pricing_env.html#basepricingenv',
'ddopai/envs/pricing/base.py'),
'ddopai.envs.pricing.base.BasePricingEnv.__init__': ( '20_environments/22_envs_pricing/base_pricing_env.html#basepricingenv.__init__',
'ddopai/envs/pricing/base.py'),
'ddopai.envs.pricing.base.BasePricingEnv.get_demand_response': ( '20_environments/22_envs_pricing/base_pricing_env.html#basepricingenv.get_demand_response',
'ddopai/envs/pricing/base.py'),
'ddopai.envs.pricing.base.BasePricingEnv.get_observation': ( '20_environments/22_envs_pricing/base_pricing_env.html#basepricingenv.get_observation',
'ddopai/envs/pricing/base.py'),
'ddopai.envs.pricing.base.BasePricingEnv.reset': ( '20_environments/22_envs_pricing/base_pricing_env.html#basepricingenv.reset',
'ddopai/envs/pricing/base.py'),
'ddopai.envs.pricing.base.BasePricingEnv.set_action_space': ( '20_environments/22_envs_pricing/base_pricing_env.html#basepricingenv.set_action_space',
'ddopai/envs/pricing/base.py'),
'ddopai.envs.pricing.base.BasePricingEnv.set_observation_space': ( '20_environments/22_envs_pricing/base_pricing_env.html#basepricingenv.set_observation_space',
'ddopai/envs/pricing/base.py')},
'ddopai.envs.pricing.dynamic': { 'ddopai.envs.pricing.dynamic.DynamicPricingEnv': ( '20_environments/22_envs_pricing/dynamic_pricing_env.html#dynamicpricingenv',
'ddopai/envs/pricing/dynamic.py'),
'ddopai.envs.pricing.dynamic.DynamicPricingEnv.__init__': ( '20_environments/22_envs_pricing/dynamic_pricing_env.html#dynamicpricingenv.__init__',
'ddopai/envs/pricing/dynamic.py'),
'ddopai.envs.pricing.dynamic.DynamicPricingEnv.step_': ( '20_environments/22_envs_pricing/dynamic_pricing_env.html#dynamicpricingenv.step_',
'ddopai/envs/pricing/dynamic.py')},
'ddopai.experiments.experiment_functions': { 'ddopai.experiments.experiment_functions.EarlyStoppingHandler': ( '40_experiments/experiment_functions.html#earlystoppinghandler',
'ddopai/experiments/experiment_functions.py'),
'ddopai.experiments.experiment_functions.EarlyStoppingHandler.__init__': ( '40_experiments/experiment_functions.html#earlystoppinghandler.__init__',
Empty file added ddopai/envs/pricing/__init__.py
140 changes: 140 additions & 0 deletions ddopai/envs/pricing/base.py
@@ -0,0 +1,140 @@
"""Base environment with some basic funcitons"""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/20_environments/22_envs_pricing/10_base_pricing_env.ipynb.

# %% auto 0
__all__ = ['BasePricingEnv']

# %% ../../../nbs/20_environments/22_envs_pricing/10_base_pricing_env.ipynb 3
from abc import ABC, abstractmethod
from typing import Union, Tuple, List

from ..base import BaseEnvironment
from ...utils import Parameter, MDPInfo
from ...dataloaders.base import BaseDataLoader
from ...loss_functions import pinball_loss

import gymnasium as gym

import numpy as np
import time

# %% ../../../nbs/20_environments/22_envs_pricing/10_base_pricing_env.ipynb 4
class BasePricingEnv(BaseEnvironment):
"""
Base class for pricing environments. This class inherits from BaseEnvironment.
"""

def __init__(self,

## Parameters for Base env:
mdp_info: MDPInfo, #
postprocessors: list[object] | None = None, # default is empty list
mode: str = "online", # additional mode for the pricing environment TODO: add online mode to training loop
return_truncation: bool = True, # whether to return a truncated condition in the step function
dataloader: BaseDataLoader = None, # dataloader for the environment

alpha: Union[float, np.ndarray] = 1, # market size parameter
beta: Union[float, np.ndarray] = 1, # price sensitivity parameter
horizon_train: int = 100 # horizon for the online learning TODO: check if it can be renamed to horizon
) -> None:

self.dataloader = dataloader

self.set_param("alpha", alpha, shape=(self.nun_SKUs[0],), new=True)
self.set_param("beta", beta, shape=(self.nun_SKUs[0],), new=True)

# TODO: check in the base env if train_horizon is needed
super().__init__(mdp_info=mdp_info, postprocessors = postprocessors, mode = mode, return_truncation=return_truncation, horizon_train=horizon_train)

def set_observation_space(self,
shape: tuple, # shape of the dataloader features
low: Union[np.ndarray, float] = -np.inf, # lower bound of the observation space
high: Union[np.ndarray, float] = np.inf, # upper bound of the observation space
samples_dim_included = True # whether the first dimension of the shape input is the number of samples
) -> None:

'''
Set the observation space of the environment.
This is a standard function for simple observation spaces. For more complex observation spaces,
this function should be overwritten. Note that it is assumed that the first dimension
is n_samples, which is not relevant for the observation space.
'''

# To handle cases when no external information is available (e.g., parametric NV)

if shape is None:
self.observation_space = None

else:
if not isinstance(shape, tuple):
raise ValueError("Shape must be a tuple.")

if samples_dim_included:
shape = shape[1:] # assumed that the first dimension is the number of samples

self.observation_space = gym.spaces.Box(low=low, high=high, shape=shape, dtype=np.float32)

def set_action_space(self,
shape: tuple, # shape of the dataloader target
low: Union[np.ndarray, float] = -np.inf, # lower bound of the action space
high: Union[np.ndarray, float] = np.inf, # upper bound of the action space
samples_dim_included = True # whether the first dimension of the shape input is the number of samples
) -> None:

'''
Set the action space of the environment.
This is a standard function for simple action spaces. For more complex action spaces,
this function should be overwritten. Note that it is assumed that the first dimension
is n_samples, which is not relevant for the action space.
'''

if not isinstance(shape, tuple):
raise ValueError("Shape must be a tuple.")

if samples_dim_included:
shape = shape[1:] # assumed that the first dimension is the number of samples

self.action_space = gym.spaces.Box(low=low, high=high, shape=shape, dtype=np.float32)

def get_observation(self):

"""
Return the current observation. This function covers the online learning case, where only the state is returned;
for other cases, it should be overwritten.
"""

X_item, = self.dataloader[self.index]

return X_item

def get_demand_response(self, action):

"""
Return the demand response for the current action. This function should be overwritten.
TODO: add the tuple call to the pricing dataloader
"""
Y_item, epsilon = self.dataloader[self.index, action]
return Y_item, epsilon

def reset(self,
start_index: int | str = None, # index to start from
state: np.ndarray = None # initial state
) -> Tuple[np.ndarray, bool]:

"""
Reset function for the pricing environment. It will return the first observation.
For val and test modes, it will by default reset to 0, while for the train mode it depends
on the parameter "horizon_train" whether a random point in the training data is selected or 0.
"""

truncated = self.reset_index(start_index)

observation = self.get_observation()

return observation

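BasePricingEnv above only fixes the interface: get_demand_response is meant to be overwritten, and this commit does not include a concrete demand model. The alpha ("market size") and beta ("price sensitivity") parameters suggest, but do not confirm, a linear demand response, so the following is only a minimal standalone sketch under that assumption; linear_demand_response, noise_std and the clipping to non-negative demand are illustrative choices, not part of ddopai.

import numpy as np
from typing import Tuple

def linear_demand_response(prices: np.ndarray,     # chosen prices, shape (num_SKUs,)
                           alpha: np.ndarray,      # market size per SKU
                           beta: np.ndarray,       # price sensitivity per SKU
                           rng: np.random.Generator,
                           noise_std: float = 0.1, # scale of the additive demand noise
                           ) -> Tuple[np.ndarray, np.ndarray]:
    """Return realized demand and the noise term for the given prices."""
    epsilon = rng.normal(0.0, noise_std, size=prices.shape)
    demand = np.clip(alpha - beta * prices + epsilon, 0.0, None)  # demand cannot go negative
    return demand, epsilon

rng = np.random.default_rng(0)
prices = np.array([0.5, 0.8])
demand, _ = linear_demand_response(prices, alpha=np.array([1.0, 1.0]), beta=np.array([1.0, 0.5]), rng=rng)
revenue = prices * demand  # per-SKU revenue that a reward could be based on

A concrete subclass (or the pricing dataloader flagged in the TODO) would supply something along these lines, with the reward derived from price times realized demand.
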
72 changes: 72 additions & 0 deletions ddopai/envs/pricing/dynamic.py
@@ -0,0 +1,72 @@
"""Static dynamic pricing environment where a decision only affects the next period"""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../../../nbs/20_environments/22_envs_pricing/20_dynamic_pricing_env.ipynb.

# %% auto 0
__all__ = ['DynamicPricingEnv']

# %% ../../../nbs/20_environments/22_envs_pricing/20_dynamic_pricing_env.ipynb 3
from abc import ABC, abstractmethod
from typing import Union, Tuple, Literal

from ...utils import Parameter, MDPInfo
from ...dataloaders.base import BaseDataLoader
from ...loss_functions import pinball_loss, quantile_loss
from .base import BasePricingEnv

import gymnasium as gym

import numpy as np
import time

# %% ../../../nbs/20_environments/22_envs_pricing/20_dynamic_pricing_env.ipynb 4
class DynamicPricingEnv(BasePricingEnv):
"""
Class implementing the dynamic pricing and learning problem, working for the single- and multi-item case.
If alpha and beta are scalars and there are multiple SKUs, then the same parameters are used for all SKUs.
If alpha and beta are arrays, then they should have the same length as the number of SKUs.
num_SKUs can be set as a parameter or inferred from the DataLoader.
"""
def __init__(self,
alpha: Union[np.ndarray, Parameter, int, float] = 1.0, # market size per SKU
beta: Union[np.ndarray, Parameter, int, float] = 0.5, # price elasticity per SKU
p_bound_low: Union[np.ndarray, Parameter, int, float] = 0.0, # lower price bound per SKU
p_bound_high: Union[np.ndarray, Parameter, int, float] = 1.0, # upper price bound per SKU
dataloader: BaseDataLoader = None, # dataloader TODO: replace with pricing dataloader
num_SKUs: Union[np.ndarray, Parameter, int, float] = None, # number of SKUs
gamma: float = 1, # discount factor
horizon_train: int | str = "use_all_data", # if "use_all_data" then horizon is inferred from the DataLoader
postprocessors: list[object] | None = None, # default is empty list
mode: str = "online", # TODO: add online to relevant modes
return_truncation: bool = True # whether to return a truncated condition in the step function
) -> None:

self.print=False

num_SKUs = dataloader.num_units if num_SKUs is None else num_SKUs

if not isinstance(num_SKUs, int):
raise ValueError("num_SKUs should be an integer.")

self.set_param("num_SKUs", num_SKUs, shape=(1,), new=True)

self.set_param("p_bound_low", p_bound_low, shape=(num_SKUs,), new=True)
self.set_param("p_bound_high", p_bound_high, shape=(num_SKUs,), new=True)

self.set_observation_space(dataloader.X_shape)
self.set_action_space(dataloader.Y_shape, low = self.p_bound_low, high = self.p_bound_high)

mdp_info = MDPInfo(self.observation_space, self.action_space, gamma=gamma, horizon=horizon_train)

super().__init__(mdp_info=mdp_info,
postprocessors=postprocessors,
mode=mode, return_truncation=return_truncation,
alpha=alpha,
beta=beta,
dataloader=dataloader,
horizon_train=horizon_train)

def step_(self,
action: np.ndarray # prices
) -> Tuple[np.ndarray, float, bool, bool, dict]:
return observation, reward, terminated, truncated, info
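Note that step_ above is committed as a stub: observation, reward, terminated, truncated and info are returned but never assigned, so calling it as-is would raise a NameError. Below is a hedged sketch of what a filled-in step could compute, again assuming the linear demand model from the sketch above; sketch_step, the revenue reward and the horizon-based truncation are assumptions rather than the committed logic.

import numpy as np
from typing import Tuple

def sketch_step(prices: np.ndarray,            # action chosen by the agent
                alpha: np.ndarray,             # market size per SKU
                beta: np.ndarray,              # price sensitivity per SKU
                next_observation: np.ndarray,  # features for the next period (e.g., from the dataloader)
                index: int,                    # current time index
                horizon: int,                  # episode length
                rng: np.random.Generator,
                ) -> Tuple[np.ndarray, float, bool, bool, dict]:
    epsilon = rng.normal(0.0, 0.1, size=prices.shape)
    demand = np.clip(alpha - beta * prices + epsilon, 0.0, None)
    reward = float(np.sum(prices * demand))  # total revenue across SKUs
    terminated = False                       # no natural terminal state in this setting
    truncated = index + 1 >= horizon         # episode ends once the horizon is reached
    info = {"demand": demand, "epsilon": epsilon}
    return next_observation, reward, terminated, truncated, info

In the actual class the demand would presumably come from get_demand_response and the next observation from get_observation, but those calls are not wired up in this commit.
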
12 changes: 11 additions & 1 deletion nbs/20_environments/22_envs_pricing/10_base_pricing_env.ipynb
@@ -11,10 +11,20 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| default_exp envs.pricing.base"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"from nbdev.showdoc import *"
]
},
9 changes: 9 additions & 0 deletions nbs/20_environments/22_envs_pricing/20_dynamic_pricing_env.ipynb
@@ -9,6 +9,15 @@
"> Static dynamic pricing environment where a decision only affects the next period "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| default_exp envs.pricing.dynamic"
]
},
{
"cell_type": "code",
"execution_count": 1,

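The notebook edits above are what the commit message refers to: in nbdev's usual workflow, each notebook declares its target module with a #| default_exp cell, and only cells tagged #| export are written into that module by nbdev_export. Without the default_exp cells, the exporter has no destination for the pricing notebooks, which is why ddopai/envs/pricing/base.py and dynamic.py appear in the same commit. A minimal illustration of the pattern (exported_function is a made-up example; the two snippets would live in separate notebook cells):

#| default_exp envs.pricing.dynamic

#| export
def exported_function():
    # cells tagged '#| export' end up in ddopai/envs/pricing/dynamic.py after nbdev_export
    return "exported by nbdev"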