# Copyright 2017. Allen Institute. All rights reserved
#
# Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
# following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following
# disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote
# products derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
# INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
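"""Plotting helpers for spike rasters, firing rates and tuning curves.

Several of these functions are deprecated in favor of bmtk.analyzer.spike_trains
(see the DeprecationWarning calls below).
"""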
import os
import csv
import h5py
from six import string_types
import pandas as pd
import numpy as np
import warnings
import matplotlib.pyplot as plt
import matplotlib.cm as cmx
import matplotlib.colors as colors
import matplotlib.gridspec as gridspec
import bmtk.simulator.utils.config as config
from bmtk.utils.reports.spike_trains.plotting import plot_raster, plot_rates # , plot_raster_cmp
from mpl_toolkits.axes_grid1 import make_axes_locatable


def _create_node_table(node_file, node_type_file, group_key=None, exclude=[]):
"""Creates a merged nodes.csv and node_types.csv dataframe with excluded items removed. Returns a dataframe."""
node_types_df = pd.read_csv(node_type_file, sep=' ', index_col='node_type_id')
    nodes_h5 = h5py.File(node_file, 'r')
# TODO: Use utils.spikesReader
node_pop_name = list(nodes_h5['/nodes'])[0]
nodes_grp = nodes_h5['/nodes'][node_pop_name]
    # TODO: Need to be able to handle gid or node_id
    nodes_df = pd.DataFrame({'node_id': nodes_grp['node_id'], 'node_type_id': nodes_grp['node_type_id']})
    nodes_df.set_index('node_id', inplace=True)
full_df = pd.merge(left=nodes_df, right=node_types_df, how='left', left_on='node_type_id', right_index=True)
    if group_key is not None and len(exclude) > 0:
        # make sure the group_key exists as a column
        if group_key not in full_df:
            raise Exception('Could not find column {}'.format(group_key))

        # remove any rows whose group_key value is excluded
        for cond in exclude:
            full_df = full_df[full_df[group_key] != cond]
nodes_h5.close()
return full_df


def _count_spikes(spikes_file, max_gid, interval=None):
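    """Counts spikes per gid from an h5 spikes file, assuming the legacy /spikes/gids + /spikes/timestamps layout.

    :param spikes_file: path to the h5 spikes file
    :param max_gid: largest gid to include; spikes from higher gids are ignored
    :param interval: None (use the full recording), a (t_start, t_stop) pair, or a float start time, in ms
    :return: (list of per-gid spike-time lists, array of mean firing rates in Hz)
    """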
    if interval is None:
        # initialize min/max with inverted sentinels; the scan below finds the real bounds
        t_max = t_bounds_low = -1.0
        t_min = t_bounds_high = 1e16
    elif hasattr(interval, "__getitem__") and len(interval) == 2:
        t_min = t_bounds_low = interval[0]
        t_max = t_bounds_high = interval[1]
    elif isinstance(interval, float):
        t_max = t_min = t_bounds_low = interval
        t_bounds_high = 1e16
    else:
        raise Exception("Unable to determine interval.")
    max_gid = int(max_gid)  # guard against max_gid arriving as a float
spikes = [[] for _ in range(max_gid+1)]
spike_sums = np.zeros(max_gid+1)
# TODO: Use utils.spikesReader
spikes_h5 = h5py.File(spikes_file, 'r')
gid_ds = spikes_h5['/spikes/gids']
ts_ds = spikes_h5['/spikes/timestamps']
for i in range(len(gid_ds)):
ts = ts_ds[i]
gid = gid_ds[i]
if gid <= max_gid and t_bounds_low <= ts <= t_bounds_high:
spikes[gid].append(ts)
spike_sums[gid] += 1
t_min = ts if ts < t_min else t_min
t_max = ts if ts > t_max else t_max
"""
with open(spikes_file, 'r') as fspikes:
for line in fspikes:
ts, gid = parse_line(line)
if gid <= max_gid and t_bounds_low <= ts <= t_bounds_high:
spikes[gid].append(ts)
spike_sums[gid] += 1
t_min = ts if ts < t_min else t_min
t_max = ts if ts > t_max else t_max
"""
spikes_h5.close()
    return spikes, spike_sums / (float(t_max - t_min) * 1e-3)  # timestamps are in ms, so rates come out in Hz


def plot_spikes_config(configure, group_key=None, exclude=[], save_as=None, show_plot=True):
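    """Plots a spike raster using file paths taken from a simulation config (deprecated).

    :param configure: path to a simulation config json file, or an already-parsed config dict
    """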
warnings.warn('Deprecated: Please use bmtk.analyzer.spike_trains.plot_raster instead.', DeprecationWarning)
if isinstance(configure, string_types):
conf = config.from_json(configure)
elif isinstance(configure, dict):
conf = configure
    else:
        raise Exception('configure must be either a config dict or the path to a json config file.')
cells_file_name = conf['internal']['nodes']
cell_models_file_name = conf['internal']['node_types']
spikes_file = conf['output']['spikes_ascii']
plot_spikes(cells_file_name, cell_models_file_name, spikes_file, group_key, exclude, save_as, show_plot)


def plot_spikes(cells_file, cell_models_file, spikes_file, population=None, group_key=None, exclude=[],
                save_as=None, show=True, title=None, legend=True, font_size=None):
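    """Plots a spike raster (top panel) with a population-wide spike-time histogram (bottom panel).

    A typical call, assuming SONATA files produced by a bmtk simulation (paths hypothetical)::

        plot_spikes('network/v1_nodes.h5', 'network/v1_node_types.csv', 'output/spikes.h5',
                    group_key='pop_name')

    :param cells_file: path to a SONATA nodes h5 file
    :param cell_models_file: path to a space-separated node_types csv file
    :param spikes_file: path to an h5 spikes file (legacy /spikes/gids or SONATA /spikes/<pop>/node_ids layout)
    :param population: name of the nodes population to plot; may be omitted if the file holds only one
    :param group_key: node column used to group and color the raster (e.g. 'pop_name')
    :param exclude: list of group_key values to leave out of the raster
    """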
warnings.warn('Deprecated: Please use bmtk.analyzer.spike_trains.plot_raster instead.', DeprecationWarning)
cm_df = pd.read_csv(cell_models_file, sep=' ')
cm_df.set_index('node_type_id', inplace=True)
cells_h5 = h5py.File(cells_file, 'r')
# TODO: Use sonata api
if population is None:
if len(cells_h5['/nodes']) > 1:
raise Exception('Multiple populations in nodes file. Please specify one to plot using population param')
else:
population = list(cells_h5['/nodes'])[0]
nodes_grp = cells_h5['/nodes'][population]
c_df = pd.DataFrame({'node_id': nodes_grp['node_id'], 'node_type_id': nodes_grp['node_type_id']})
c_df.set_index('node_id', inplace=True)
    nodes_df = pd.merge(left=c_df,
                        right=cm_df,
                        how='left',
                        left_on='node_type_id',
                        right_index=True)  # merge on 'node_type_id', which is the index of the node_types table
cells_h5.close()
    # TODO: Use utils.SpikesReader to open
spikes_h5 = h5py.File(spikes_file, 'r')
    try:
        # legacy spikes layout: /spikes/gids and /spikes/timestamps
        spike_ids = np.array(spikes_h5['/spikes/gids'], dtype=np.uint64)
        spike_times = np.array(spikes_h5['/spikes/timestamps'], dtype=np.float64)
    except KeyError:
        # SONATA layout: /spikes/<population>/node_ids and /spikes/<population>/timestamps
        populations = spikes_h5['/spikes/']
        if len(populations) > 1:
            raise Exception('TODO: handle spike files with more than one population (their node_ids overlap, '
                            'so they must be shifted before plotting).')
        for pop in populations:
            spike_ids = np.array(spikes_h5['/spikes/%s/node_ids' % pop], dtype=np.uint64)
            spike_times = np.array(spikes_h5['/spikes/%s/timestamps' % pop], dtype=np.float64)
spikes_h5.close()
    spike_times = spike_times * 1.0e-3  # convert timestamps from ms to s
if group_key is not None:
if group_key not in nodes_df:
raise Exception('Could not find column {}'.format(group_key))
groupings = nodes_df.groupby(group_key)
n_colors = nodes_df[group_key].nunique()
color_norm = colors.Normalize(vmin=0, vmax=(n_colors-1))
scalar_map = cmx.ScalarMappable(norm=color_norm, cmap='hsv')
color_map = [scalar_map.to_rgba(i) for i in range(0, n_colors)]
else:
groupings = [(None, nodes_df)]
color_map = ['blue']
#marker = '.' if len(nodes_df) > 1000 else 'o'
marker = 'o'
# Create plot
gs = gridspec.GridSpec(2, 1, height_ratios=[7, 1])
    if font_size is not None:
        plt.rcParams.update({'font.size': font_size})
ax1 = plt.subplot(gs[0])
gid_min = 10**10
gid_max = -1
for color, (group_name, group_df) in zip(color_map, groupings):
if group_name in exclude:
continue
        group_min_gid = min(group_df.index.tolist())
        group_max_gid = max(group_df.index.tolist())
        gid_min = min(gid_min, group_min_gid)
        gid_max = max(gid_max, group_max_gid)
gids_group = group_df.index
indexes = np.in1d(spike_ids, gids_group)
ax1.scatter(spike_times[indexes], spike_ids[indexes], marker=marker, facecolors=color, label=group_name, lw=0, s=5)
#ax1.set_xlabel('time (s)')
ax1.axes.get_xaxis().set_visible(False)
ax1.set_ylabel('Cell ID')
ax1.set_xlim([0, max(spike_times)])
ax1.set_ylim([gid_min, gid_max])
if legend:
plt.legend(markerscale=2, scatterpoints=1)
ax2 = plt.subplot(gs[1])
plt.hist(spike_times, 100)
ax2.set_xlabel('Time (s)')
ax2.set_xlim([0, max(spike_times)])
#ax2.axes.get_yaxis().set_visible(False)
ax2.set_ylabel('Firing rate (AU)')
if title is not None:
ax1.set_title(title)
ax1.spines['top'].set_visible(False)
ax1.spines['right'].set_visible(False)
ax2.spines['top'].set_visible(False)
ax2.spines['right'].set_visible(False)
if save_as is not None:
plt.savefig(save_as)
if show:
plt.show()


def plot_ratess(cells_file, cell_models_file, spikes_file, group_key='pop_name',
                exclude=['LIF_inh', 'LIF_exc'], save_as=None, show_plot=True):
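    """Plots per-gid spike counts from an ascii spikes file, one color per group_key value.

    A legacy helper; spikes earlier than t=500 ms are discarded as a settling transient.

    :param spikes_file: space-separated text file with one 'time gid' pair per line
    :param exclude: group_key values to drop before plotting
    """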
    cm_df = pd.read_csv(cell_models_file, sep=' ')
    cm_df.set_index('node_type_id', inplace=True)

    c_df = pd.read_csv(cells_file, sep=' ')
    c_df.set_index('node_id', inplace=True)
    nodes_df = pd.merge(left=c_df,
                        right=cm_df,
                        how='left',
                        left_on='node_type_id',
                        right_index=True)  # merge on 'node_type_id', which is the index of the node_types table

    for cond in exclude:
        nodes_df = nodes_df[nodes_df[group_key] != cond]

    groupings = nodes_df.groupby(group_key)
    n_colors = nodes_df[group_key].nunique()
    color_norm = colors.Normalize(vmin=0, vmax=(n_colors - 1))
    scalar_map = cmx.ScalarMappable(norm=color_norm, cmap='hsv')
    color_map = [scalar_map.to_rgba(i) for i in range(0, n_colors)]

    spike_times, spike_gids = np.loadtxt(spikes_file, dtype='float32,int', unpack=True)
    rates = np.zeros(max(spike_gids) + 1)
    for ts, gid in zip(spike_times, spike_gids):
        if ts < 500.0:
            # drop the initial settling transient
            continue
        rates[gid] += 1

    for color, (group_name, group_df) in zip(color_map, groupings):
        plt.plot(group_df.index, rates[group_df.index], '.', color=color, label=group_name)

    plt.xlabel('gid')
    plt.ylabel('spike count')
    plt.legend(fontsize='x-small')
    if save_as is not None:
        plt.savefig(save_as)
    if show_plot:
        plt.show()


def plot_rates_old(cells_file, cell_models_file, spikes_file, group_key=None, exclude=[], interval=None, show=True,
                   title=None, save_as=None, smoothed=False):
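    """Plots per-gid firing rates, grouped and colored by group_key, plus a mean/stddev summary figure.

    :param interval: time window passed to the spike counter (None, a (t_start, t_stop) pair, or a float start, in ms)
    :param smoothed: if True, plot a moving average (window of 100 gids) instead of the raw rates
    :param save_as: if given, the raster is saved to this path and the summary to '<base>.summary.jpg/.csv'
    """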
def smooth(data, window=100):
h = int(window/2)
x_max = len(data)
return [np.mean(data[max(0, x-h):min(x_max, x+h)]) for x in range(0, x_max)]
nodes_df = _create_node_table(cells_file, cell_models_file, group_key, exclude)
_, spike_rates = _count_spikes(spikes_file, max(nodes_df.index), interval)
if group_key is not None:
groupings = nodes_df.groupby(group_key)
group_order = {k: i for i, k in enumerate(nodes_df[group_key].unique())}
n_colors = len(group_order)
color_norm = colors.Normalize(vmin=0, vmax=(n_colors-1))
scalar_map = cmx.ScalarMappable(norm=color_norm, cmap='hsv')
color_map = [scalar_map.to_rgba(i) for i in range(0, n_colors)]
ordered_groupings = [(group_order[name], c, name, df) for c, (name, df) in zip(color_map, groupings)]
    else:
        ordered_groupings = [(0, 'blue', None, nodes_df)]
        group_order = {None: 0}

    keys = ['' for _ in range(len(group_order))]
    means = [0 for _ in range(len(group_order))]
    stds = [0 for _ in range(len(group_order))]
fig = plt.figure()
ax1 = fig.add_subplot(111)
for indx, color, group_name, group_df in ordered_groupings:
keys[indx] = group_name
means[indx] = np.mean(spike_rates[group_df.index])
stds[indx] = np.std(spike_rates[group_df.index])
y = smooth(spike_rates[group_df.index]) if smoothed else spike_rates[group_df.index]
ax1.plot(group_df.index, y, '.', color=color, label=group_name)
    max_rate = np.max(spike_rates)
    ax1.set_ylim(0, 50)  # alternatively: max_rate*1.3
ax1.set_ylabel('Hz')
ax1.set_xlabel('gid')
ax1.legend(fontsize='x-small')
if title is not None:
ax1.set_title(title)
if save_as is not None:
plt.savefig(save_as)
plt.figure()
plt.errorbar(range(len(means)), means, stds, linestyle='None', marker='o')
    plt.xlim(-0.5, len(means) - 0.5)
    plt.ylim(0, 50.0)  # alternatively: max_rate*1.3
plt.xticks(range(len(means)), keys)
if title is not None:
plt.title(title)
if save_as is not None:
if save_as.endswith('.jpg'):
base = save_as[0:-4]
elif save_as.endswith('.jpeg'):
base = save_as[0:-5]
else:
base = save_as
plt.savefig('{}.summary.jpg'.format(base))
with open('{}.summary.csv'.format(base), 'w') as f:
f.write('population mean stddev\n')
for i, key in enumerate(keys):
f.write('{} {} {}\n'.format(key, means[i], stds[i]))
if show:
plt.show()


def plot_rates_popnet(cell_models_file, rates_file, model_keys=None, save_as=None, show_plot=True):
"""Initial method for plotting popnet output
:param cell_models_file:
:param rates_file:
:param model_keys:
:param save_as:
:param show_plot:
:return:
"""
pops_df = pd.read_csv(cell_models_file, sep=' ')
lookup_col = model_keys if model_keys is not None else 'node_type_id'
pop_keys = {str(r['node_type_id']): r[lookup_col] for _, r in pops_df.iterrows()}
# organize the rates file by population
rates_df = pd.read_csv(rates_file, sep=' ', names=['id', 'times', 'rates'])
for grp_key, grp_df in rates_df.groupby('id'):
grp_label = pop_keys[str(grp_key)]
plt.plot(grp_df['times'], grp_df['rates'], label=grp_label)
plt.legend(fontsize='x-small')
plt.xlabel('time (s)')
plt.ylabel('firing rates (Hz)')
if save_as is not None:
plt.savefig(save_as)
if show_plot:
plt.show()


def plot_avg_rates(cell_models_file, rates_file, model_keys=None, save_as=None, show_plot=True):
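    """Plots the mean firing rate (with stddev error bars) of each population in a PopNet rates file.

    :param cell_models_file: path to a space-separated node_types csv file
    :param rates_file: space-separated rates file with one 'id time rate' row per sample
    :param model_keys: node_types column used for the x-axis labels (defaults to 'node_type_id')
    """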
pops_df = pd.read_csv(cell_models_file, sep=' ')
lookup_col = model_keys if model_keys is not None else 'node_type_id'
pop_keys = {str(r['node_type_id']): r[lookup_col] for _, r in pops_df.iterrows()}
# organize the rates file by population
rates = {pop_name: [] for pop_name in pop_keys.keys()}
with open(rates_file, 'r') as f:
reader = csv.reader(f, delimiter=' ')
for row in reader:
if row[0] in rates:
rates[row[0]].append(float(row[2]))
labels = []
means = []
stds = []
for pop_name in pops_df['node_type_id'].unique():
r = rates[str(pop_name)]
if len(r) == 0:
continue
labels.append(pop_keys.get(str(pop_name), str(pop_name)))
means.append(np.mean(r))
stds.append(np.std(r))
plt.figure()
plt.errorbar(range(len(means)), means, stds, linestyle='None', marker='o')
plt.xlim(-0.5, len(means) - 0.5)
plt.xticks(range(len(means)), labels)
plt.ylabel('firing rates (Hz)')
if save_as is not None:
plt.savefig(save_as)
if show_plot:
plt.show()


def plot_tuning(sg_analysis, node, band, Freq=0, show=True, save_as=None):
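    """Plots orientation-by-spatial-frequency tuning matrices, one panel per phase, with a shared colorbar.

    :param sg_analysis: analysis object exposing node_table, get_tunings_file(), orientations,
        spatial_frequencies and phases
    :param node: node name to look up in sg_analysis.node_table
    :param band: band used to select the row (ignored when node == 's4')
    :param Freq: index of the temporal frequency slice to plot
    """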
def index_for_node(node, band):
if node == 's4':
mask = sg_analysis.node_table.node == node
else:
mask = (sg_analysis.node_table.node == node) & (sg_analysis.node_table.band == band)
return str(sg_analysis.node_table[mask].index[0])
index = index_for_node(node, band)
key = index + '/sg/tuning'
analysis_file = sg_analysis.get_tunings_file()
    tuning_matrix = analysis_file[key][()][:, :, :, Freq]  # Dataset.value is deprecated in h5py; [()] reads the array
n_or, n_sf, n_ph = tuning_matrix.shape
    vmax = np.max(tuning_matrix)
    vmin = np.min(tuning_matrix)
fig, ax = plt.subplots(1, n_ph, figsize=(13.9, 4.3), sharex=False, sharey=True)
for phase in range(n_ph):
tuning_to_plot = tuning_matrix[:, :, phase]
im = ax[phase].imshow(tuning_to_plot, interpolation='nearest', vmax=vmax, vmin=vmin)
ax[phase].set_xticklabels([0] + list(sg_analysis.spatial_frequencies))
ax[phase].set_yticklabels([0] + list(sg_analysis.orientations))
ax[phase].set_title('phase = {}'.format(sg_analysis.phases[phase]))
ax[phase].set_xlabel('spatial_frequency')
if phase == 0:
ax[phase].set_ylabel('orientation')
fig.subplots_adjust(right=0.90)
cbar_ax = fig.add_axes([0.92, 0.10, 0.02, 0.75])
cbar = fig.colorbar(im, cax=cbar_ax, ticks=[vmin, 0.0, vmax])
if save_as is not None:
plt.savefig(save_as)
if show:
plt.show()