Source code for bmtk.simulator.filternet.default_setters.cell_loaders

import numpy as np
from sympy.abc import x as symbolic_x
from sympy.abc import y as symbolic_y

from bmtk.simulator.filternet.filters import TemporalFilterCosineBump, GaussianSpatialFilter, SpatioTemporalFilter
from bmtk.simulator.filternet.cell_models import TwoSubfieldLinearCell, OnUnit, OffUnit, LGNOnOffCell
from bmtk.simulator.filternet.transfer_functions import ScalarTransferFunction, MultiTransferFunction
from bmtk.simulator.filternet.utils import get_data_metrics_for_each_subclass, get_tcross_from_temporal_kernel
from bmtk.simulator.filternet.pyfunction_cache import py_modules


def create_two_sub_cell(dom_lf, non_dom_lf, dom_spont, non_dom_spont, onoff_axis_angle, subfield_separation,
                        dom_location):
    """Build a TwoSubfieldLinearCell from a dominant and a non-dominant linear filter.

    The cell's transfer function is a MultiTransferFunction over the symbolic variables x and y whose
    expression is the sum of two ``Heaviside(v + offset) * (v + offset)`` terms: one in x using
    ``dom_spont`` as the offset and one in y using ``non_dom_spont``.

    :param dom_lf: linear filter passed as the first positional argument of TwoSubfieldLinearCell
    :param non_dom_lf: linear filter passed as the second positional argument of TwoSubfieldLinearCell
    :param dom_spont: offset inserted (via ``str``) into the x term of the transfer function
    :param non_dom_spont: offset inserted (via ``str``) into the y term of the transfer function
    :param onoff_axis_angle: forwarded unchanged to TwoSubfieldLinearCell
    :param subfield_separation: forwarded unchanged to TwoSubfieldLinearCell
    :param dom_location: forwarded as ``dominant_subfield_location``
    :return: the constructed TwoSubfieldLinearCell
    """
    x_offset = str(dom_spont)
    y_offset = str(non_dom_spont)

    # Assemble the expression from the same fragments, in the same order, as a plain concatenation would.
    expression = ''.join([
        'Heaviside(x+', x_offset, ')*(x+', x_offset,
        ')+Heaviside(y+', y_offset, ')*(y+', y_offset, ')',
    ])
    combined_transfer_fn = MultiTransferFunction((symbolic_x, symbolic_y), expression)

    return TwoSubfieldLinearCell(
        dom_lf,
        non_dom_lf,
        subfield_separation=subfield_separation,
        onoff_axis_angle=onoff_axis_angle,
        dominant_subfield_location=dom_location,
        transfer_function=combined_transfer_fn,
    )
def createOneUnitOfTwoSubunitFilter(weights, kpeaks, delays, ttp_exp):
    """Build a delay-shifted cosine-bump temporal filter for one subunit of a two-subunit cell.

    A TemporalFilterCosineBump is first built from the given parameters. Its kernel (threshold=-1.0) is
    used to find the crossing index returned by ``get_tcross_from_temporal_kernel`` and the sum of the
    kernel up to that index. The filter is then rebuilt with every delay shifted by
    ``ttp_exp - tcross_ind`` so that the response latency matches the data.

    :param weights: weights passed to TemporalFilterCosineBump
    :param kpeaks: kpeaks passed to TemporalFilterCosineBump
    :param delays: sequence of delays; converted to a numpy array so the offset is added element-wise
    :param ttp_exp: target value the kernel crossing index is aligned to
    :return: tuple ``(filt_new, filt_sum)`` - the rebuilt filter and the sum of the original kernel
        before the crossing index
    :raises Exception: if the required delay offset is negative
    """
    delays = np.array(delays)
    filt = TemporalFilterCosineBump(weights, kpeaks, delays)

    # Evaluate the kernel once and reuse it for both the crossing index and the partial sum.
    kernel = filt.get_kernel(threshold=-1.0).kernel
    tcross_ind = get_tcross_from_temporal_kernel(kernel)
    filt_sum = kernel[:tcross_ind].sum()

    # Calculate delay offset needed to match response latency with data and rebuild temporal filter
    del_offset = ttp_exp - tcross_ind
    if del_offset >= 0:
        filt_new = TemporalFilterCosineBump(weights, kpeaks, delays + del_offset)
        return filt_new, filt_sum

    raise Exception('del_offset < 0 (ttp_exp={}, tcross_ind={})'.format(ttp_exp, tcross_ind))
def get_tf_params(node, dynamics_params, non_dom_props=False):
    """Resolve the temporal-filter parameters (weights, kpeaks, delays) for a node.

    Values set on the node take precedence; any that are ``None`` fall back to the ``opt_wts``,
    ``opt_kpeaks`` and ``opt_delays`` entries of ``dynamics_params``. When ``node.predefined_jitter`` is
    truthy each resolved value ``a`` is replaced element-wise by a draw from
    ``np.random.uniform(a * node.jitter[0], a * node.jitter[1])``.

    :param node: object exposing ``weights``/``kpeaks``/``delays`` (or the ``*_non_dom`` variants),
        ``predefined_jitter`` and ``jitter``; ``jitter`` is assumed to be a (lower, upper) pair of
        multiplicative factors - confirm against the node class
    :param dynamics_params: dict of fallback values. Required (indexed directly) for the dominant
        parameters; may be ``None`` or incomplete for the non-dominant ones, in which case the missing
        values resolve to ``None``
    :param non_dom_props: if True, read the ``*_non_dom`` attributes of the node instead
    :return: tuple ``(weights, kpeaks, delays)``; entries that could not be resolved stay ``None``
    """
    if not non_dom_props:
        weights = node.weights if node.weights is not None else dynamics_params['opt_wts']
        kpeaks = node.kpeaks if node.kpeaks is not None else dynamics_params['opt_kpeaks']
        delays = node.delays if node.delays is not None else dynamics_params['opt_delays']
    else:
        dp = dynamics_params or {}
        weights = node.weights_non_dom if node.weights_non_dom is not None else dp.get('opt_wts', None)
        kpeaks = node.kpeaks_non_dom if node.kpeaks_non_dom is not None else dp.get('opt_kpeaks', None)
        delays = node.delays_non_dom if node.delays_non_dom is not None else dp.get('opt_delays', None)

    if node.predefined_jitter:
        # Draw between the lower and the upper jitter factor. Using jitter[0] for both bounds would
        # collapse the interval to a single point and make the "jitter" a deterministic scaling.
        jitter_lower, jitter_upper = node.jitter[0], node.jitter[1]
        jitter_fnc = np.vectorize(lambda a: np.random.uniform(a * jitter_lower, a * jitter_upper))
        weights = jitter_fnc(weights) if weights is not None else weights
        kpeaks = jitter_fnc(kpeaks) if kpeaks is not None else kpeaks
        delays = jitter_fnc(delays) if delays is not None else delays

    return weights, kpeaks, delays
def _build_two_subunit_cell(node, spatial_filter, translate, dom_tf_params, ttp_non_dom, ttp_dom, spont,
                            max_roff, max_ron, off_gain=1.0):
    """Assemble the two-subfield cell shared by the sONsOFF and sONtOFF models.

    :param node: node providing ``non_dom_params``, ``sf_sep``, ``tuning_angle`` and the jitter attributes
    :param spatial_filter: spatial filter shared by both subunits
    :param translate: location forwarded to ``create_two_sub_cell`` as the dominant subfield location
    :param dom_tf_params: ``(weights, kpeaks, delays)`` of the dominant (OFF) temporal filter
    :param ttp_non_dom: ``ttp_exp`` value used for the non-dominant (sON) temporal filter
    :param ttp_dom: ``ttp_exp`` value used for the dominant (OFF) temporal filter
    :param spont: model constant used in the OFF amplitude; half of it is passed as each subfield's
        offset to ``create_two_sub_cell``
    :param max_roff: model constant used in the OFF amplitude formula
    :param max_ron: model constant used in the OFF amplitude formula
    :param off_gain: extra factor on the first term of the OFF amplitude formula
    :return: the cell returned by ``create_two_sub_cell``
    """
    t_weights, t_kpeaks, t_delays = dom_tf_params

    # The non-dominant parameters are resolved first, then the sON filter is built before the OFF filter.
    t_weights_nd, t_kpeaks_nd, t_delays_nd = get_tf_params(node, node.non_dom_params, non_dom_props=True)
    son_filt_new, son_sum = createOneUnitOfTwoSubunitFilter(t_weights_nd, t_kpeaks_nd, t_delays_nd, ttp_non_dom)
    off_filt_new, off_sum = createOneUnitOfTwoSubunitFilter(t_weights, t_kpeaks, t_delays, ttp_dom)

    amp_on = 1.0  # set the non-dominant subunit amplitude to unity
    amp_off = (-off_gain*(max_roff/max_ron)*(son_sum/off_sum)*amp_on
               - (spont*(max_roff - max_ron))/(max_ron*off_sum))

    linear_filter_son = SpatioTemporalFilter(spatial_filter, son_filt_new, amplitude=amp_on)
    linear_filter_off = SpatioTemporalFilter(spatial_filter, off_filt_new, amplitude=amp_off)

    sf_sep = node.sf_sep
    if node.predefined_jitter:
        sf_sep = np.random.uniform(node.lower_jitter*sf_sep, node.upper_jitter*sf_sep)

    return create_two_sub_cell(linear_filter_off, linear_filter_son, 0.5 * spont, 0.5 * spont,
                               node.tuning_angle, sf_sep, translate)


def default_cell_loader(node, template_name, dynamics_params):
    """Create the filter-based cell model for a node.

    The model is selected by name: ``template_name[1]`` when ``template_name`` is truthy, otherwise
    ``node['pop_name']``.

    * ``sONsOFF`` / ``sONsOFF_001`` and ``sONtOFF`` / ``sONtOFF_001`` build a two-subfield cell.
    * ``LGNOnOFFCell`` builds an LGNOnOffCell from the ``*_dom_0`` / ``*_dom_1`` and ``sigma_on`` /
      ``sigma_off`` node properties.
    * Any other name is split on ``'_'`` into a cell type and a subclass key (default ``'TF8'``) and
      builds an OnUnit (cell type contains ``ON``) or an OffUnit (cell type contains ``OFF``).

    :param node: node object supporting item access, ``in`` tests, and the attributes read by
        ``get_tf_params``
    :param template_name: sequence whose second element is the model name, or a falsy value
    :param dynamics_params: fallback temporal-filter parameters, forwarded to ``get_tf_params``
    :return: the constructed cell object
    :raises ValueError: if a single-unit model name contains neither ``ON`` nor ``OFF``
    """
    # Create the spatial filter
    origin = (0.0, 0.0)
    translate = (node['x'], node['y'])
    sigma = node['spatial_size'] / 3.0  # convert from degree to SD
    if np.isscalar(sigma):
        sigma = (sigma, sigma)
    rotation = node['spatial_rotation'] if 'spatial_rotation' in node else 0.0
    spatial_filter = GaussianSpatialFilter(translate=translate, sigma=sigma, origin=origin, rotation=rotation)

    # Resolved for every model type, including those that do not use the result.
    t_weights, t_kpeaks, t_delays = get_tf_params(node, dynamics_params)

    model_name = template_name[1] if template_name else node['pop_name']

    if model_name in ['sONsOFF_001', 'sONsOFF']:
        return _build_two_subunit_cell(node, spatial_filter, translate, (t_weights, t_kpeaks, t_delays),
                                       ttp_non_dom=121.0, ttp_dom=115.0,
                                       spont=4.0, max_roff=35.0, max_ron=21.0)

    if model_name in ['sONtOFF_001', 'sONtOFF']:
        return _build_two_subunit_cell(node, spatial_filter, translate, (t_weights, t_kpeaks, t_delays),
                                       ttp_non_dom=93.5, ttp_dom=64.8,
                                       spont=5.5, max_roff=46.0, max_ron=31.0, off_gain=0.7)

    if model_name == 'LGNOnOFFCell':
        wts = [node['weight_dom_0'], node['weight_dom_1']]
        kpeaks = [node['kpeaks_dom_0'], node['kpeaks_dom_1']]
        delays = [node['delay_dom_0'], node['delay_dom_1']]
        temporal_filter = TemporalFilterCosineBump(wts, kpeaks, delays)

        # The ON and OFF subunits share the temporal filter but use their own sigma and opposite amplitudes.
        spatial_filter_on = GaussianSpatialFilter(sigma=node['sigma_on'], origin=origin, translate=translate)
        on_linear_filter = SpatioTemporalFilter(spatial_filter_on, temporal_filter, amplitude=20)

        spatial_filter_off = GaussianSpatialFilter(sigma=node['sigma_off'], origin=origin, translate=translate)
        off_linear_filter = SpatioTemporalFilter(spatial_filter_off, temporal_filter, amplitude=-20)

        return LGNOnOffCell(on_linear_filter, off_linear_filter)

    # Single-unit models: "<cell_type>_<subclass>" or just "<cell_type>".
    type_split = model_name.split('_')
    if len(type_split) == 1:
        cell_type, tf_str = model_name, 'TF8'
    else:
        cell_type, tf_str = type_split[0], type_split[1]

    # Get spontaneous firing rate, either from the cell property or calculated from experimental data
    if 'spont_fr' in node:
        spont_fr = node['spont_fr']
    else:
        exp_prs_dict = get_data_metrics_for_each_subclass(cell_type)
        subclass_prs_dict = exp_prs_dict[tf_str]
        spont_fr = subclass_prs_dict['spont_exp'][0]

    # Get filters
    transfer_function = ScalarTransferFunction('Heaviside(s+{})*(s+{})'.format(spont_fr, spont_fr))
    temporal_filter = TemporalFilterCosineBump(t_weights, t_kpeaks, t_delays)

    # ON cell types get a positive-amplitude OnUnit and OFF cell types a negative-amplitude OffUnit.
    if 'ON' in cell_type:
        linear_filter = SpatioTemporalFilter(spatial_filter, temporal_filter, amplitude=1.0)
        return OnUnit(linear_filter, transfer_function)
    if 'OFF' in cell_type:
        linear_filter = SpatioTemporalFilter(spatial_filter, temporal_filter, amplitude=-1.0)
        return OffUnit(linear_filter, transfer_function)

    raise ValueError('Unable to determine ON/OFF unit type from model name "{}"'.format(model_name))
# Register default_cell_loader with the pyfunction cache under two names, 'default' and 'preset_params'.
# NOTE(review): overwrite=False presumably keeps a processor that is already registered under the same
# name - confirm against py_modules.add_cell_processor.
py_modules.add_cell_processor('default', default_cell_loader, overwrite=False)
py_modules.add_cell_processor('preset_params', default_cell_loader, overwrite=False)