Skip to content

Commit

Permalink
Fix raster and trace (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomdele authored May 27, 2020
1 parent d3b4b53 commit 018426e
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions bluepysnap/_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def spikes_firing_rate_histogram(filtered_report, time_binsize=None, ax=None):
return ax


def spike_raster(filtered_report, y_axis="node_id", ax=None): # pragma: no cover
def spike_raster(filtered_report, y_axis=None, ax=None): # pragma: no cover
"""Spike raster plot.
Shows a global overview of the circuit's firing nodes. The y axis can project either the
Expand Down Expand Up @@ -122,7 +122,7 @@ def spike_raster(filtered_report, y_axis="node_id", ax=None): # pragma: no cove
}

def _update_raster_properties():
if y_axis == "node_id":
if y_axis is None:
props["node_id_offset"] += spikes.nodes.size
props["pop_separators"].append(props["node_id_offset"])
elif pd.api.types.is_categorical_dtype(spikes.nodes.property_dtypes[y_axis]):
Expand All @@ -141,7 +141,7 @@ def _update_raster_properties():
for population in population_names:
spikes = spike_report[population]
mask = report["population"] == population
if y_axis == "node_id":
if y_axis is None:
data.loc[mask] = report.loc[mask, "ids"] + props["node_id_offset"]
else:
ids = report.loc[mask, "ids"].to_numpy()
Expand All @@ -160,7 +160,7 @@ def _update_raster_properties():
ax.set_xlabel("Time [ms]")
ax.tick_params(axis='y', which='both', length=0)
ax.set_xlim(spike_report.time_start, spike_report.time_stop)
if y_axis == "node_id":
if y_axis is None:
ax.set_ylim(0, props["node_id_offset"])
ax.set_ylabel("nodes")
else:
Expand Down Expand Up @@ -352,7 +352,7 @@ def frame_trace(filtered_report, plot_type='mean', ax=None): # pragma: no cover
elif plot_type == "all":
ax.set_ylabel('Voltage [{}]'.format(data_units))
ax.set_xlabel("Time [{}]".format(filtered_report.frame_report.time_units))
ax.set_xlim([filtered_report.t_start, filtered_report.t_stop])
ax.set_xlim([filtered_report.report.index.min(), filtered_report.report.index.max()])

if plot_type == "mean":
ax.plot(filtered_report.report.T.mean())
Expand Down

0 comments on commit 018426e

Please sign in to comment.