Skip to content

Commit

Permalink
Fix demand normalization for out-of-sample (OOS) SKUs (#16)
Browse files (browse the repository at this point in the history)
* fixed error in selection of time_sku features

* fixed error in force save agent

* fixed parameter name for dataset

* accelerated getitem method

* minor fix in predict function

* minor fix in predict function

* enables out-of-sample experiments

* print steps while testing

* fixed error in normalization for oos products

* fixed error in normalizing oos demand

* fixed error in normalization of oos SKUs

* fixed error in normalization of oos SKUs
  • Loading branch information
majoma7 authored Sep 4, 2024
1 parent 22d59ba commit de443eb
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 71 deletions.
14 changes: 6 additions & 8 deletions ddopnew/dataloaders/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,12 +864,13 @@ def normalize_demand_and_features_out_of_sample(self,
# Normalize demand targets
if self.demand_normalization != 'no_normalization':
# Normalizing per SKU on time dimension

self.scaler_out_of_sample_test_demand.fit(self.demand_out_of_sample_test[:self.train_index_end+1])
transformed_demand = self.scaler_out_of_sample_test_demand.transform(self.demand_lag_out_of_sample_test)
transformed_demand = self.scaler_out_of_sample_test_demand.transform(self.demand_out_of_sample_test)
self.demand_out_of_sample_test.iloc[:,:] = transformed_demand

self.scaler_out_of_sample_val_demand.fit(self.demand_out_of_sample_val[:self.train_index_end+1])
transformed_demand = self.scaler_out_of_sample_val_demand.transform(self.demand_lag_out_of_sample_val)
transformed_demand = self.scaler_out_of_sample_val_demand.transform(self.demand_out_of_sample_val)
self.demand_out_of_sample_val.iloc[:,:] = transformed_demand

# Set unit size for demand targets
Expand All @@ -881,11 +882,14 @@ def normalize_demand_and_features_out_of_sample(self,
if self.lag_demand_normalization != self.demand_normalization:
if self.lag_demand_normalization != 'no_normalization':

self.demand_lag_out_of_sample_test = self.demand_out_of_sample_test.copy()

self.demand_lag_out_of_sample_test = self.demand_out_of_sample_test.copy()
self.scaler_out_of_sample_test_demand_lag.fit(self.demand_lag_out_of_sample_test[:self.train_index_end+1])
transformed_demand_lag = self.scaler_out_of_sample_test_demand_lag.transform(self.demand_lag_out_of_sample_test)
self.demand_lag_out_of_sample_test.iloc[:,:] = transformed_demand_lag

self.demand_lag_out_of_sample_val = self.demand_out_of_sample_val.copy()
self.demand_lag_out_of_sample_val = self.demand_out_of_sample_val.copy()
self.scaler_out_of_sample_val_demand_lag.fit(self.demand_lag_out_of_sample_val[:self.train_index_end+1])
transformed_demand_lag = self.scaler_out_of_sample_val_demand_lag.transform(self.demand_lag_out_of_sample_val)
Expand Down Expand Up @@ -932,12 +936,6 @@ def normalize_demand_and_features_out_of_sample(self,

self.normalized_out_of_sample_SKUs = True

print(self.demand_out_of_sample_test)
print(self.demand_out_of_sample_val)

print(self.demand_lag_out_of_sample_test)
print(self.demand_lag_out_of_sample_val)

else:
raise NotImplementedError('Training data can only normalized during initialization - later normlization not implemented yet')

Expand Down
124 changes: 61 additions & 63 deletions nbs/10_dataloaders/12_tabular_dataloaders.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/dataloaders/tabular.py#L21){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/dataloaders/tabular.py#L23){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## XYDataLoader\n",
"\n",
Expand All @@ -341,7 +341,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/dataloaders/tabular.py#L21){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/dataloaders/tabular.py#L23){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## XYDataLoader\n",
"\n",
Expand Down Expand Up @@ -384,7 +384,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/dataloaders/tabular.py#L112){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/dataloaders/tabular.py#L114){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### XYDataLoader.prep_lag_features\n",
"\n",
Expand All @@ -406,7 +406,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/dataloaders/tabular.py#L112){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/dataloaders/tabular.py#L114){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### XYDataLoader.prep_lag_features\n",
"\n",
Expand Down Expand Up @@ -445,7 +445,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/dataloaders/tabular.py#L174){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/dataloaders/tabular.py#L176){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### XYDataLoader.__getitem__\n",
"\n",
Expand All @@ -456,7 +456,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/dataloaders/tabular.py#L174){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/dataloaders/tabular.py#L176){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### XYDataLoader.__getitem__\n",
"\n",
Expand Down Expand Up @@ -484,7 +484,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/dataloaders/tabular.py#L227){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/dataloaders/tabular.py#L229){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### XYDataLoader.get_all_X\n",
"\n",
Expand All @@ -500,7 +500,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/dataloaders/tabular.py#L227){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/dataloaders/tabular.py#L229){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### XYDataLoader.get_all_X\n",
"\n",
Expand Down Expand Up @@ -533,7 +533,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/dataloaders/tabular.py#L247){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/dataloaders/tabular.py#L249){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### XYDataLoader.get_all_Y\n",
"\n",
Expand All @@ -549,7 +549,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/dataloaders/tabular.py#L247){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/dataloaders/tabular.py#L249){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### XYDataLoader.get_all_Y\n",
"\n",
Expand Down Expand Up @@ -588,7 +588,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"sample: [ 0.48934242 -1.5249851 ] [-4.20898833]\n",
"sample: [1.82524154 0.11465852] [4.19037555]\n",
"sample shape Y: (1,)\n",
"length: 100\n"
]
Expand Down Expand Up @@ -627,28 +627,28 @@
"length train: 6 length val: 2 length test: 2\n",
"\n",
"### Data from train set ###\n",
"idx: 0 data: [-0.03919358 -1.81524105] [-5.45285373]\n",
"idx: 1 data: [-0.33238759 -0.25239502] [-2.28257421]\n",
"idx: 2 data: [-0.28088596 1.04119928] [2.99245932]\n",
"idx: 3 data: [-1.41942608 -0.14294466] [-3.4857788]\n",
"idx: 4 data: [-0.94715851 -0.90926992] [-3.88762596]\n",
"idx: 5 data: [ 0.83067309 -1.60634229] [-2.1943678]\n",
"idx: 0 data: [-0.53411714 1.14021542] [3.42541064]\n",
"idx: 1 data: [ 1.11334918 -0.7480284 ] [0.30100931]\n",
"idx: 2 data: [-0.04082374 1.24133612] [3.2659267]\n",
"idx: 3 data: [-1.2857809 -0.12486876] [-1.45752356]\n",
"idx: 4 data: [ 1.56494873 -1.44628854] [-1.45068881]\n",
"idx: 5 data: [-1.11396607 1.20500501] [1.09222255]\n",
"\n",
"### Data from val set ###\n",
"idx: 0 data: [0.31728831 0.1915907 ] [0.61714501]\n",
"idx: 1 data: [0.19605744 0.15569143] [-0.31029128]\n",
"idx: 0 data: [-0.64874237 -0.1925078 ] [-2.86408299]\n",
"idx: 1 data: [-0.95652838 -0.78083716] [-3.82414401]\n",
"\n",
"### Data from test set ###\n",
"idx: 0 data: [ 0.60736842 -1.62492312] [-1.52729382]\n",
"idx: 1 data: [-0.33198421 0.91780232] [2.45557787]\n",
"idx: 0 data: [ 3.34309074 -0.3067389 ] [6.60807345]\n",
"idx: 1 data: [-0.84503599 0.78098089] [0.31357303]\n",
"\n",
"### Data from train set again ###\n",
"idx: 0 data: [-0.03919358 -1.81524105] [-5.45285373]\n",
"idx: 1 data: [-0.33238759 -0.25239502] [-2.28257421]\n",
"idx: 2 data: [-0.28088596 1.04119928] [2.99245932]\n",
"idx: 3 data: [-1.41942608 -0.14294466] [-3.4857788]\n",
"idx: 4 data: [-0.94715851 -0.90926992] [-3.88762596]\n",
"idx: 5 data: [ 0.83067309 -1.60634229] [-2.1943678]\n"
"idx: 0 data: [-0.53411714 1.14021542] [3.42541064]\n",
"idx: 1 data: [ 1.11334918 -0.7480284 ] [0.30100931]\n",
"idx: 2 data: [-0.04082374 1.24133612] [3.2659267]\n",
"idx: 3 data: [-1.2857809 -0.12486876] [-1.45752356]\n",
"idx: 4 data: [ 1.56494873 -1.44628854] [-1.45068881]\n",
"idx: 5 data: [-1.11396607 1.20500501] [1.09222255]\n"
]
}
],
Expand Down Expand Up @@ -702,8 +702,8 @@
{
"data": {
"text/plain": [
"array([[ 0.60736842, -1.62492312],\n",
" [-0.33198421, 0.91780232]])"
"array([[ 3.34309074, -0.3067389 ],\n",
" [-0.84503599, 0.78098089]])"
]
},
"execution_count": null,
Expand All @@ -727,8 +727,8 @@
{
"data": {
"text/plain": [
"array([[-1.52729382],\n",
" [ 2.45557787]])"
"array([[6.60807345],\n",
" [0.31357303]])"
]
},
"execution_count": null,
Expand Down Expand Up @@ -764,36 +764,36 @@
"length train: 4 length val: 2 length test: 2\n",
"\n",
"### Data from train set ###\n",
"idx: 0 data: [[ 1.31636985 0.7897913 -0.75019741]\n",
" [-1.0363052 -0.33563099 5.30310363]] [-1.73725544]\n",
"idx: 1 data: [[-1.0363052 -0.33563099 5.30310363]\n",
" [ 0.61748301 0.12942096 -1.73725544]] [0.75991095]\n",
"idx: 2 data: [[ 0.61748301 0.12942096 -1.73725544]\n",
" [ 0.97507757 0.60012032 0.75991095]] [3.60080094]\n",
"idx: 3 data: [[0.97507757 0.60012032 0.75991095]\n",
" [0.05424665 0.05414227 3.60080094]] [1.06903634]\n",
"idx: 0 data: [[ 1.11387748e+00 4.58463692e-01 5.18417892e+00]\n",
" [ 2.36921073e+00 -5.79605424e-04 2.04994528e+00]] [4.2408043]\n",
"idx: 1 data: [[ 2.36921073e+00 -5.79605424e-04 2.04994528e+00]\n",
" [ 1.79162404e-01 8.47107363e-02 4.24080430e+00]] [1.9189011]\n",
"idx: 2 data: [[ 0.1791624 0.08471074 4.2408043 ]\n",
" [ 0.51586155 -0.03276417 1.9189011 ]] [0.62122014]\n",
"idx: 3 data: [[ 0.51586155 -0.03276417 1.9189011 ]\n",
" [ 1.25322782 -0.25253519 0.62122014]] [1.59908837]\n",
"\n",
"### Data from val set ###\n",
"idx: 0 data: [[ 0.05424665 0.05414227 3.60080094]\n",
" [-0.19670518 2.25039121 1.06903634]] [6.53359576]\n",
"idx: 1 data: [[-0.19670518 2.25039121 1.06903634]\n",
" [-1.84005742 -0.24281547 6.53359576]] [-3.96123686]\n",
"idx: 0 data: [[ 1.25322782 -0.25253519 0.62122014]\n",
" [-0.08575841 0.91153206 1.59908837]] [2.94348434]\n",
"idx: 1 data: [[-0.08575841 0.91153206 1.59908837]\n",
" [ 0.07725105 0.76471112 2.94348434]] [2.01289203]\n",
"\n",
"### Data from test set ###\n",
"idx: 0 data: [[-1.84005742 -0.24281547 6.53359576]\n",
" [ 0.53974671 1.48055778 -3.96123686]] [5.10164607]\n",
"idx: 1 data: [[ 0.53974671 1.48055778 -3.96123686]\n",
" [ 0.0885949 1.45853039 5.10164607]] [5.47333133]\n",
"idx: 0 data: [[ 0.07725105 0.76471112 2.94348434]\n",
" [-0.37725899 0.85497758 2.01289203]] [2.79373294]\n",
"idx: 1 data: [[-0.37725899 0.85497758 2.01289203]\n",
" [-0.66416681 0.79194957 2.79373294]] [1.14615794]\n",
"\n",
"### Data from train set again ###\n",
"idx: 0 data: [[ 1.31636985 0.7897913 -0.75019741]\n",
" [-1.0363052 -0.33563099 5.30310363]] [-1.73725544]\n",
"idx: 1 data: [[-1.0363052 -0.33563099 5.30310363]\n",
" [ 0.61748301 0.12942096 -1.73725544]] [0.75991095]\n",
"idx: 2 data: [[ 0.61748301 0.12942096 -1.73725544]\n",
" [ 0.97507757 0.60012032 0.75991095]] [3.60080094]\n",
"idx: 3 data: [[0.97507757 0.60012032 0.75991095]\n",
" [0.05424665 0.05414227 3.60080094]] [1.06903634]\n"
"idx: 0 data: [[ 1.11387748e+00 4.58463692e-01 5.18417892e+00]\n",
" [ 2.36921073e+00 -5.79605424e-04 2.04994528e+00]] [4.2408043]\n",
"idx: 1 data: [[ 2.36921073e+00 -5.79605424e-04 2.04994528e+00]\n",
" [ 1.79162404e-01 8.47107363e-02 4.24080430e+00]] [1.9189011]\n",
"idx: 2 data: [[ 0.1791624 0.08471074 4.2408043 ]\n",
" [ 0.51586155 -0.03276417 1.9189011 ]] [0.62122014]\n",
"idx: 3 data: [[ 0.51586155 -0.03276417 1.9189011 ]\n",
" [ 1.25322782 -0.25253519 0.62122014]] [1.59908837]\n"
]
}
],
Expand Down Expand Up @@ -1444,12 +1444,13 @@
" # Normalize demand targets\n",
" if self.demand_normalization != 'no_normalization':\n",
" # Normalizing per SKU on time dimension\n",
"\n",
" self.scaler_out_of_sample_test_demand.fit(self.demand_out_of_sample_test[:self.train_index_end+1])\n",
" transformed_demand = self.scaler_out_of_sample_test_demand.transform(self.demand_lag_out_of_sample_test)\n",
" transformed_demand = self.scaler_out_of_sample_test_demand.transform(self.demand_out_of_sample_test)\n",
" self.demand_out_of_sample_test.iloc[:,:] = transformed_demand\n",
"\n",
" self.scaler_out_of_sample_val_demand.fit(self.demand_out_of_sample_val[:self.train_index_end+1])\n",
" transformed_demand = self.scaler_out_of_sample_val_demand.transform(self.demand_lag_out_of_sample_val)\n",
" transformed_demand = self.scaler_out_of_sample_val_demand.transform(self.demand_out_of_sample_val)\n",
" self.demand_out_of_sample_val.iloc[:,:] = transformed_demand\n",
" \n",
" # Set unit size for demand targets\n",
Expand All @@ -1462,11 +1463,14 @@
" if self.lag_demand_normalization != 'no_normalization':\n",
" \n",
" self.demand_lag_out_of_sample_test = self.demand_out_of_sample_test.copy()\n",
" \n",
" self.demand_lag_out_of_sample_test = self.demand_out_of_sample_test.copy()\n",
" self.scaler_out_of_sample_test_demand_lag.fit(self.demand_lag_out_of_sample_test[:self.train_index_end+1])\n",
" transformed_demand_lag = self.scaler_out_of_sample_test_demand_lag.transform(self.demand_lag_out_of_sample_test)\n",
" self.demand_lag_out_of_sample_test.iloc[:,:] = transformed_demand_lag\n",
" \n",
" self.demand_lag_out_of_sample_val = self.demand_out_of_sample_val.copy()\n",
" self.demand_lag_out_of_sample_val = self.demand_out_of_sample_val.copy()\n",
" self.scaler_out_of_sample_val_demand_lag.fit(self.demand_lag_out_of_sample_val[:self.train_index_end+1])\n",
" transformed_demand_lag = self.scaler_out_of_sample_val_demand_lag.transform(self.demand_lag_out_of_sample_val)\n",
" self.demand_lag_out_of_sample_val.iloc[:,:] = transformed_demand_lag\n",
Expand Down Expand Up @@ -1512,12 +1516,6 @@
" \n",
" self.normalized_out_of_sample_SKUs = True\n",
"\n",
" print(self.demand_out_of_sample_test)\n",
" print(self.demand_out_of_sample_val)\n",
"\n",
" print(self.demand_lag_out_of_sample_test)\n",
" print(self.demand_lag_out_of_sample_val)\n",
"\n",
" else:\n",
" raise NotImplementedError('Training data can only normalized during initialization - later normlization not implemented yet')\n",
"\n",
Expand Down

0 comments on commit de443eb

Please sign in to comment.