-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
pricing env, main TODOs: mode and dataloader.
1 parent
c065b70
commit 940cf74
Showing
2 changed files
with
284 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
124 changes: 124 additions & 0 deletions
124
nbs/20_environments/22_envs_pricing/20_dynamic_pricing_env.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Dynamic Pricing Env\n", | ||
"\n", | ||
"> Static dynamic pricing environment where a decision only affects the next period " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#| hide\n", | ||
"from nbdev.showdoc import *" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#| export\n", | ||
"from abc import ABC, abstractmethod\n", | ||
"from typing import Union, Tuple, Literal\n", | ||
"\n", | ||
"from ddopai.utils import Parameter, MDPInfo\n", | ||
"from ddopai.dataloaders.base import BaseDataLoader\n", | ||
"from ddopai.loss_functions import pinball_loss, quantile_loss\n", | ||
"from ddopai.envs.pricing.base import BasePricingEnv\n", | ||
"\n", | ||
"import gymnasium as gym\n", | ||
"\n", | ||
"import numpy as np\n", | ||
"import time" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# | export\n", | ||
"class DynamicPricingEnv(BasePricingEnv):\n", | ||
" \"\"\"\n", | ||
" Class implementing the dynamic pricing and learning problem, working for the single- and multi-item case.\n", | ||
" If alpha and beta are scalars and they are multiple SKUs, then the same parameters are used for all SKUs.\n", | ||
" If alpha and beta are arrays, then they should have the same length as the number of SKUs.\n", | ||
" Num_SKUs can be set as parameter or inferrred from the DataLoader.\n", | ||
" \"\"\"\n", | ||
" def __init__(self,\n", | ||
" alpha: Union[np.ndarray, Parameter, int, float] = 1.0, # market size per SKUs\n", | ||
" beta: Union[np.ndarray, Parameter, int, float] = 0.5, # price elasticity per SKUs\n", | ||
" p_bound_low: Union[np.ndarray, Parameter, int, float] = 0.0, # lower price bound per SKUs\n", | ||
" p_bound_high: Union[np.ndarray, Parameter, int, float] = 1.0, # upper price bound per SKUs\n", | ||
" dataloader: BaseDataLoader = None, # dataloader TODO: replace with pricing dataloader\n", | ||
" num_SKUs: Union[np.ndarray, Parameter, int, float] = None, # number of SKUs\n", | ||
" gamma: float = 1, # discount factor\n", | ||
" horizon_train: int | str = \"use_all_data\", # if \"use_all_data\" then horizon is inferred from the DataLoader\n", | ||
" postprocessors: list[object] | None = None, # default is empty list \n", | ||
" mode: str = \"online\", # TODO: add online to relevant modes\n", | ||
" return_truncation: str = True # TODO:Why is this a string?\n", | ||
" ) -> None:\n", | ||
"\n", | ||
" self.print=False\n", | ||
" \n", | ||
" num_SKUs = dataloader.num_units if num_SKUs is None else num_SKUs\n", | ||
" \n", | ||
" if not isinstance(num_SKUs, int):\n", | ||
" raise ValueError(\"num_SKUs should be an integer.\")\n", | ||
" \n", | ||
" self.set_param(\"num_SKUs\", num_SKUs, shape=(1,), new=True)\n", | ||
" \n", | ||
" self.set_param(\"p_bound_low\", p_bound_low, shape=(num_SKUs,), new=True)\n", | ||
" self.set_param(\"p_bound_high\", p_bound_high, shape=(num_SKUs,), new=True)\n", | ||
" \n", | ||
" self.set_observation_space(dataloader.X_shape)\n", | ||
" self.set_action_space(dataloader.Y_shape, low = self.p_bound_low, high = self.p_bound_high)\n", | ||
" \n", | ||
" mdp_info = MDPInfo(self.observation_space, self.action_space, gamma=gamma, horizon=horizon_train)\n", | ||
" \n", | ||
" super().__init__(mdp_info=mdp_info,\n", | ||
" postprocessors=postprocessors,\n", | ||
" mode=mode, return_truncation=return_truncation,\n", | ||
" alpha=alpha,\n", | ||
" beta=beta,\n", | ||
" dataloader=dataloader,\n", | ||
" horizon_train=horizon_train)\n", | ||
" \n", | ||
" def step_(self,\n", | ||
" action: np.ndarray # prices)\n", | ||
" ) -> Tuple[np.ndarray, float, bool, bool, dict]:\n", | ||
" return observation, reward, terminated, truncated, info" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "ddop", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.15" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |