Skip to content

Commit

Permalink
new example notebooks!!
Browse files Browse the repository at this point in the history
  • Loading branch information
gAldeia committed Jun 6, 2024
1 parent 6fd411d commit 3112a18
Show file tree
Hide file tree
Showing 5 changed files with 869 additions and 184 deletions.
365 changes: 365 additions & 0 deletions docs/guide/archive.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,365 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# The archive\n",
"\n",
"When you fit a brush estimator, two new attributes are created: `best_estimator_` and `archive_`.\n",
"\n",
"If you set `use_arch` to `True` when instantiating the estimator, then it will store the pareto front as a list in `archive_`. This pareto front is always created with individuals from the final population that are not dominated in objectives **error** and **complexity**.\n",
"\n",
"In case you need more flexibility, the archive will contain the entire final population if `use_arch` is `False`, and you can iterate through this list to select individuals with different criteria. It is also good to remind that Brush supports different optimization objectives using the argument `objectives`.\n",
"\n",
"Each element from the archive is a serialized individual (JSON object)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from pybrush import BrushClassifier\n",
"\n",
"# load data\n",
"df = pd.read_csv('../examples/datasets/d_analcatdata_aids.csv')\n",
"X = df.drop(columns='target')\n",
"y = df['target']"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Completed 100% [====================]\n",
"score: 0.7\n"
]
}
],
"source": [
"est = BrushClassifier(\n",
" functions=['SplitBest','Add','Mul','Sin','Cos','Exp','Logabs'],\n",
" use_arch=True,\n",
" max_gens=100,\n",
" verbosity=1\n",
")\n",
"\n",
"est.fit(X,y)\n",
"y_pred = est.predict(X)\n",
"print('score:', est.score(X,y))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can see individuals from archive using the index:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5\n"
]
},
{
"data": {
"text/plain": [
"{'fitness': {'complexity': 80,\n",
" 'crowding_dist': 0.0,\n",
" 'dcounter': 0,\n",
" 'depth': 3,\n",
" 'dominated': [],\n",
" 'loss': 0.5091069936752319,\n",
" 'loss_v': 0.5091069936752319,\n",
" 'rank': 1,\n",
" 'size': 12,\n",
" 'values': [0.5091069936752319, 12.0],\n",
" 'weights': [-1.0, -1.0],\n",
" 'wvalues': [-0.5091069936752319, -12.0]},\n",
" 'id': 10060,\n",
" 'objectives': ['error', 'size'],\n",
" 'parent_id': [9628],\n",
" 'program': {'Tree': [{'W': 15890.5,\n",
" 'arg_types': ['ArrayF', 'ArrayF'],\n",
" 'center_op': True,\n",
" 'feature': 'AIDS',\n",
" 'fixed': False,\n",
" 'is_weighted': False,\n",
" 'name': 'SplitBest',\n",
" 'node_type': 'SplitBest',\n",
" 'prob_change': 1.0,\n",
" 'ret_type': 'ArrayF',\n",
" 'sig_dual_hash': 9996486434638833164,\n",
" 'sig_hash': 10001460114883919497},\n",
" {'W': 1.0,\n",
" 'arg_types': ['ArrayF'],\n",
" 'center_op': True,\n",
" 'feature': '',\n",
" 'fixed': False,\n",
" 'is_weighted': False,\n",
" 'name': 'Logabs',\n",
" 'node_type': 'Logabs',\n",
" 'prob_change': 1.0,\n",
" 'ret_type': 'ArrayF',\n",
" 'sig_dual_hash': 10617925524997611780,\n",
" 'sig_hash': 13326223354425868050},\n",
" {'W': 2.7182815074920654,\n",
" 'arg_types': [],\n",
" 'center_op': True,\n",
" 'feature': 'Cf',\n",
" 'fixed': False,\n",
" 'is_weighted': False,\n",
" 'name': 'Constant',\n",
" 'node_type': 'Constant',\n",
" 'prob_change': 1.0,\n",
" 'ret_type': 'ArrayF',\n",
" 'sig_dual_hash': 509529941281334733,\n",
" 'sig_hash': 17717457037689164349},\n",
" {'W': 1572255.5,\n",
" 'arg_types': ['ArrayF', 'ArrayF'],\n",
" 'center_op': True,\n",
" 'feature': 'Total',\n",
" 'fixed': False,\n",
" 'is_weighted': False,\n",
" 'name': 'SplitBest',\n",
" 'node_type': 'SplitBest',\n",
" 'prob_change': 1.0,\n",
" 'ret_type': 'ArrayF',\n",
" 'sig_dual_hash': 9996486434638833164,\n",
" 'sig_hash': 10001460114883919497},\n",
" {'W': 0.2222222238779068,\n",
" 'arg_types': [],\n",
" 'center_op': True,\n",
" 'feature': 'MeanLabel',\n",
" 'fixed': False,\n",
" 'is_weighted': True,\n",
" 'name': 'MeanLabel',\n",
" 'node_type': 'MeanLabel',\n",
" 'prob_change': 1.0,\n",
" 'ret_type': 'ArrayF',\n",
" 'sig_dual_hash': 509529941281334733,\n",
" 'sig_hash': 17717457037689164349},\n",
" {'W': 0.5217871069908142,\n",
" 'arg_types': [],\n",
" 'center_op': True,\n",
" 'feature': 'Cf',\n",
" 'fixed': False,\n",
" 'is_weighted': False,\n",
" 'name': 'Constant',\n",
" 'node_type': 'Constant',\n",
" 'prob_change': 1.0,\n",
" 'ret_type': 'ArrayF',\n",
" 'sig_dual_hash': 509529941281334733,\n",
" 'sig_hash': 17717457037689164349}],\n",
" 'is_fitted_': True}}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(len(est.archive_[0]))\n",
"\n",
"est.archive_[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And you can call `predict` (or `predict_proba`, if your `est` is an instance of `BrushClassifier`) with the entire archive:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'id': 10060,\n",
" 'y_pred': array([False, True, True, True, True, False, True, True, True,\n",
" False, True, True, True, True, False, True, True, True,\n",
" True, True, True, True, True, True, True, False, False,\n",
" False, False, False, False, False, False, False, False, False,\n",
" False, False, True, False, True, True, True, True, True,\n",
" True, True, True, True, True])},\n",
" {'id': 9789,\n",
" 'y_pred': array([False, True, True, True, True, False, True, True, True,\n",
" False, True, True, True, True, False, True, True, True,\n",
" True, True, True, True, True, True, True, False, False,\n",
" False, False, False, False, False, False, False, False, False,\n",
" False, False, True, False, True, True, True, True, True,\n",
" True, True, True, True, True])},\n",
" {'id': 10049,\n",
" 'y_pred': array([False, True, True, True, True, False, True, True, True,\n",
" False, False, True, True, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False])},\n",
" {'id': 4384,\n",
" 'y_pred': array([False, True, True, True, True, False, True, True, True,\n",
" False, False, True, True, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False])},\n",
" {'id': 9692,\n",
" 'y_pred': array([ True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True, True, True, True, True,\n",
" True, True, True, True, True])},\n",
" {'id': 9552,\n",
" 'y_pred': array([False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False, False, False, False, False,\n",
" False, False, False, False, False])}]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"est.predict_archive(X)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'id': 10060,\n",
" 'y_pred': array([0.22222222, 0.9999999 , 0.9999999 , 0.9999999 , 0.9999999 ,\n",
" 0.22222222, 0.9999999 , 0.9999999 , 0.9999999 , 0.22222222,\n",
" 0.5217871 , 0.9999999 , 0.9999999 , 0.5217871 , 0.22222222,\n",
" 0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 ,\n",
" 0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 ,\n",
" 0.22222222, 0.22222222, 0.22222222, 0.22222222, 0.22222222,\n",
" 0.22222222, 0.22222222, 0.22222222, 0.22222222, 0.22222222,\n",
" 0.22222222, 0.22222222, 0.22222222, 0.5217871 , 0.22222222,\n",
" 0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 ,\n",
" 0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 ],\n",
" dtype=float32)},\n",
" {'id': 9789,\n",
" 'y_pred': array([0.22222222, 0.99994993, 0.99994993, 0.99994993, 0.99994993,\n",
" 0.22222222, 0.99994993, 0.99994993, 0.99994993, 0.22222222,\n",
" 0.5217871 , 0.99994993, 0.99994993, 0.5217871 , 0.22222222,\n",
" 0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 ,\n",
" 0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 ,\n",
" 0.22222222, 0.22222222, 0.22222222, 0.22222222, 0.22222222,\n",
" 0.22222222, 0.22222222, 0.22222222, 0.22222222, 0.22222222,\n",
" 0.22222222, 0.22222222, 0.22222222, 0.5217871 , 0.22222222,\n",
" 0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 ,\n",
" 0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 , 0.5217871 ],\n",
" dtype=float32)},\n",
" {'id': 10049,\n",
" 'y_pred': array([0.39024392, 0.9999999 , 0.9999999 , 0.9999999 , 0.9999999 ,\n",
" 0.39024392, 0.9999999 , 0.9999999 , 0.9999999 , 0.39024392,\n",
" 0.39024392, 0.9999999 , 0.9999999 , 0.39024392, 0.39024392,\n",
" 0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392,\n",
" 0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392,\n",
" 0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392,\n",
" 0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392,\n",
" 0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392,\n",
" 0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392,\n",
" 0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392],\n",
" dtype=float32)},\n",
" {'id': 4384,\n",
" 'y_pred': array([0.39024392, 0.9999522 , 0.9999522 , 0.9999522 , 0.9999522 ,\n",
" 0.39024392, 0.9999522 , 0.9999522 , 0.9999522 , 0.39024392,\n",
" 0.39024392, 0.9999522 , 0.9999522 , 0.39024392, 0.39024392,\n",
" 0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392,\n",
" 0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392,\n",
" 0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392,\n",
" 0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392,\n",
" 0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392,\n",
" 0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392,\n",
" 0.39024392, 0.39024392, 0.39024392, 0.39024392, 0.39024392],\n",
" dtype=float32)},\n",
" {'id': 9692,\n",
" 'y_pred': array([0.5317098 , 0.93985564, 0.9835824 , 0.8686745 , 0.68970597,\n",
" 0.53089285, 0.8455727 , 0.9291562 , 0.7663612 , 0.6237519 ,\n",
" 0.5169323 , 0.7368382 , 0.794476 , 0.63628834, 0.5578266 ,\n",
" 0.50047225, 0.50908357, 0.51443684, 0.506959 , 0.50320625,\n",
" 0.5003231 , 0.50484663, 0.5051821 , 0.50173986, 0.5005965 ,\n",
" 0.5060892 , 0.5592239 , 0.56642807, 0.5267187 , 0.5222307 ,\n",
" 0.5185086 , 0.64804167, 0.68591666, 0.5714386 , 0.5314499 ,\n",
" 0.50612646, 0.5576549 , 0.5636914 , 0.5241404 , 0.5113072 ,\n",
" 0.50007457, 0.5010315 , 0.5013173 , 0.50085753, 0.50068355,\n",
" 0.5000373 , 0.50096935, 0.50095695, 0.5003852 , 0.500174 ],\n",
" dtype=float32)},\n",
" {'id': 9552,\n",
" 'y_pred': array([0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,\n",
" 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,\n",
" 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,\n",
" 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],\n",
" dtype=float32)}]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"est.predict_proba_archive(X)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "brush",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 2 additions & 0 deletions docs/guide/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,7 @@ data
search_space
working_with_programs
json
saving_loading_populations
archive
deap
```
Loading

0 comments on commit 3112a18

Please sign in to comment.