Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add query key population_type #229

Merged
merged 9 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ Improvements
- ``properties`` is now a keyword argument in ``EdgePopulation.get``
- Added ``EdgePopulation.stats`` with two methods: ``divergence``, ``convergence``
- Added new notebooks covering node sets as well as node and edge queries
- Added the possibility to query Edge IDs and Node IDs based on edge/node population type using query key ``population_type``

- the types conform to `node types <https://sonata-extension.readthedocs.io/en/latest/sonata_config.html#populations>`_ and `edge types <https://sonata-extension.readthedocs.io/en/latest/sonata_config.html#id4>`_ defined in the sonata specification


Version v3.0.1
Expand Down
2 changes: 1 addition & 1 deletion bluepysnap/edges/edge_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def _edge_ids_by_filter(self, queries, raise_missing_prop):
chunk_size = int(1e8)
for chunk in np.array_split(ids, 1 + len(ids) // chunk_size):
data = self.get(chunk, properties - unknown_props)
res.extend(chunk[query.resolve_ids(data, self.name, queries)])
res.extend(chunk[query.resolve_ids(data, self.name, self.type, queries)])
return np.array(res, dtype=IDS_DTYPE)

def ids(self, group=None, limit=None, sample=None, raise_missing_property=True):
Expand Down
4 changes: 3 additions & 1 deletion bluepysnap/nodes/node_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
>>> nodes.ids(group={ Node.LAYER: 2}) # returns list of IDs matching layer==2
>>> nodes.ids(group={ Node.LAYER: [2, 3]}) # returns list of IDs with layer in [2,3]
>>> nodes.ids(group={ Node.X: (0, 1)}) # returns list of IDs with 0 < x < 1
>>> # returns list of IDs of biophysical node populations
>>> nodes.ids(group={ "population_type": "biophysical"})
>>> # returns list of IDs matching one of the queries inside the 'or' list
>>> nodes.ids(group={'$or': [{ Node.LAYER: [2, 3]},
>>> { Node.X: (0, 1), Node.MTYPE: 'L1_SLAC' }]})
Expand Down Expand Up @@ -371,7 +373,7 @@ def _node_ids_by_filter(self, queries, raise_missing_prop):
self._check_properties(properties)
# load all the properties needed to execute the query, excluding the unknown properties
data = self._get_data(properties & self.property_names)
idx = query.resolve_ids(data, self.name, queries)
idx = query.resolve_ids(data, self.name, self.type, queries)
return idx.nonzero()[0]

def ids(self, group=None, limit=None, sample=None, raise_missing_property=True):
Expand Down
29 changes: 22 additions & 7 deletions bluepysnap/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,21 @@
NODE_ID_KEY = "node_id"
EDGE_ID_KEY = "edge_id"
POPULATION_KEY = "population"
POPULATION_TYPE_KEY = "population_type"
OR_KEY = "$or"
AND_KEY = "$and"
REGEX_KEY = "$regex"
NODE_SET_KEY = "$node_set"
VALUE_KEYS = {REGEX_KEY}
ALL_KEYS = {NODE_ID_KEY, EDGE_ID_KEY, POPULATION_KEY, OR_KEY, AND_KEY, NODE_SET_KEY} | VALUE_KEYS
ALL_KEYS = {
NODE_ID_KEY,
EDGE_ID_KEY,
POPULATION_KEY,
POPULATION_TYPE_KEY,
OR_KEY,
AND_KEY,
NODE_SET_KEY,
} | VALUE_KEYS


def _logical_and(masks):
Expand Down Expand Up @@ -92,29 +101,34 @@ def _positional_mask(data, ids):
return True
if isinstance(ids, int):
ids = [ids]
elif len(ids) == 0:
return False
mask = np.full(len(data), fill_value=False)
indices = data.index.get_indexer(ids)
mask[indices[indices > -1]] = True
return mask


def _circuit_mask(data, population_name, queries):
def _circuit_mask(data, population_name, population_type, queries):
"""Handle the population, node ID queries."""
populations = queries.pop(POPULATION_KEY, None)
if populations is not None and population_name not in set(utils.ensure_list(populations)):
types = queries.pop(POPULATION_TYPE_KEY, None)
if populations is not None and population_name not in utils.ensure_list(populations):
ids = []
elif types is not None and population_type not in utils.ensure_list(types):
ids = []
else:
ids = queries.pop(NODE_ID_KEY, queries.pop(EDGE_ID_KEY, None))
return queries, _positional_mask(data, ids)


def _properties_mask(data, population_name, queries):
def _properties_mask(data, population_name, population_type, queries):
"""Return mask of IDs matching `props` dict."""
unknown_props = set(queries) - set(data.columns) - ALL_KEYS
if unknown_props:
return False

queries, mask = _circuit_mask(data, population_name, queries)
queries, mask = _circuit_mask(data, population_name, population_type, queries)
if mask is False or isinstance(mask, np.ndarray) and not mask.any():
# Avoid fail and/or processing time if wrong population or no nodes
return False
Expand Down Expand Up @@ -202,12 +216,13 @@ def _resolve(queries, queries_key):
return resolved_queries


def resolve_ids(data, population_name, queries):
def resolve_ids(data, population_name, population_type, queries):
"""Returns an index mask of `data` for given `queries`.

Args:
data (pd.DataFrame): data
population_name (str): population name of `data`
population_type (str): population type
queries (dict): queries

Returns:
Expand All @@ -229,7 +244,7 @@ def _collect(queries, queries_key):
queries[queries_key] = _logical_and(children_mask)
else:
queries[queries_key] = _properties_mask(
data, population_name, {queries_key: queries[queries_key]}
data, population_name, population_type, {queries_key: queries[queries_key]}
)

queries = deepcopy(queries)
Expand Down
159 changes: 150 additions & 9 deletions doc/source/notebooks/09_node_queries.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"On top of these, a query can also be based on `node_id` or the `population_type`.\n",
"\n",
"When the query is a `dict` and there is a `list` in the query, it is (usually) considered as an \"OR\" condition, and the keys of the query are considered as an \"AND\" condition. E.g.,\n",
"```python\n",
"circuit.nodes.ids({ # give me the ids of nodes that\n",
Expand Down Expand Up @@ -398,6 +400,145 @@
"pd.concat([df for _,df in result])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Querying with population type\n",
"We can query nodes (or edges) based on their population type as specified in the [SONATA circuit configuration file](https://sonata-extension.readthedocs.io/en/latest/sonata_config.html).\n",
"\n",
"Let's find all the source nodes of projections (i.e., `virtual` nodes):"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th></th>\n",
" <th>model_template</th>\n",
" <th>model_type</th>\n",
" </tr>\n",
" <tr>\n",
" <th>population</th>\n",
" <th>node_ids</th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th rowspan=\"5\" valign=\"top\">CorticoThalamic_projections</th>\n",
" <th>0</th>\n",
" <td></td>\n",
" <td>virtual</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td></td>\n",
" <td>virtual</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td></td>\n",
" <td>virtual</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td></td>\n",
" <td>virtual</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td></td>\n",
" <td>virtual</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th rowspan=\"5\" valign=\"top\">MedialLemniscus_projections</th>\n",
" <th>5018</th>\n",
" <td></td>\n",
" <td>virtual</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5019</th>\n",
" <td></td>\n",
" <td>virtual</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5020</th>\n",
" <td></td>\n",
" <td>virtual</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5021</th>\n",
" <td></td>\n",
" <td>virtual</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5022</th>\n",
" <td></td>\n",
" <td>virtual</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>88443 rows × 2 columns</p>\n",
"</div>"
],
"text/plain": [
" model_template model_type\n",
"population node_ids \n",
"CorticoThalamic_projections 0 virtual\n",
" 1 virtual\n",
" 2 virtual\n",
" 3 virtual\n",
" 4 virtual\n",
"... ... ...\n",
"MedialLemniscus_projections 5018 virtual\n",
" 5019 virtual\n",
" 5020 virtual\n",
" 5021 virtual\n",
" 5022 virtual\n",
"\n",
"[88443 rows x 2 columns]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result = circuit.nodes.get({'population_type': ['virtual']})\n",
"pd.concat([df for _,df in result])"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -409,7 +550,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -508,7 +649,7 @@
"[65198 rows x 1 columns]"
]
},
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -532,7 +673,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -562,7 +703,7 @@
"dtype: object"
]
},
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -583,7 +724,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -684,7 +825,7 @@
"28381 135.235031 517.312378 428.180695"
]
},
"execution_count": 11,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -710,7 +851,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -740,7 +881,7 @@
" names=['population', 'node_ids'], length=35567)"
]
},
"execution_count": 12,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -769,7 +910,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"metadata": {},
"outputs": [
{
Expand Down
Loading
Loading