"""
Provides the implementation of the Module class for managing and training spiking neural networks (SNNs).
Includes methods for STDP learning, feed-forward processing, training, testing, and model persistence.
"""
import os
from . import common, np, plt
from ..dataloader import Dataloader
from .parameters import Parameters
from .layer import Layer
from .neuron import LIFNeuron
from typing import Union
from scipy.stats import truncnorm
[docs]
def apply_truncated_noise_vectorized(input_array: np.ndarray, half_range: float) -> np.ndarray:
    """
    Perturb every element of an array with multiplicative truncated-normal noise.

    Each element ``x`` becomes ``clip(x * (1 + e), 0, 1)``, where ``e`` is drawn from a
    zero-mean normal distribution with standard deviation ``half_range / 3`` that is
    truncated to ``[-half_range, +half_range]`` (three standard deviations).

    Args:
        input_array (np.ndarray): Values to perturb; converted with ``np.asarray``.
        half_range (float): Largest relative deviation the noise may take.

    Returns:
        np.ndarray: A new array with the same shape as the input. When ``half_range``
        (or the sigma derived from it) is zero, a plain copy is returned without
        noise and without clipping.
    """
    values = np.asarray(input_array)
    std_dev = half_range / 3.0
    # Nothing to sample when the noise width is zero (or underflows to zero).
    if half_range == 0 or std_dev == 0:
        return values.copy()

    # truncnorm takes its bounds in units of standard deviations.
    normalised_limit = half_range / std_dev
    relative_noise = truncnorm.rvs(
        -normalised_limit,
        normalised_limit,
        loc=0.0,
        scale=std_dev,
        size=values.shape,
    )
    return np.clip(values * (1.0 + relative_noise), 0, 1)
[docs]
def apply_truncated_noise_to_value(value: float, half_range: float) -> float:
    """
    Perturb one number with multiplicative truncated-normal noise.

    The result is ``clip(value * (1 + e), 0, 1)``, where ``e`` is drawn from a zero-mean
    normal distribution with standard deviation ``half_range / 3`` that is truncated to
    ``[-half_range, +half_range]``.

    Args:
        value (float): The number to perturb.
        half_range (float): Largest relative deviation the noise may take.

    Returns:
        float: The noisy, clipped value. When ``half_range`` (or the sigma derived from
        it) is zero, ``value`` is handed back untouched and unclipped.
    """
    std_dev = half_range / 3.0
    # Nothing to sample when the noise width is zero (or underflows to zero).
    if half_range == 0 or std_dev == 0:
        return value

    # truncnorm takes its bounds in units of standard deviations.
    normalised_limit = half_range / std_dev
    relative_noise = truncnorm.rvs(
        -normalised_limit, normalised_limit, loc=0.0, scale=std_dev
    )
    return np.clip(value * (1.0 + relative_noise), 0, 1)
[docs]
class Module:
"""
Represents a module for a spiking neural network.
This class is the main structure used to initialize, train, and evaluate a spiking neural network from a set of parameters and layer sizes.
The user must assign the datasets (`X_train`, `Y_train`, `X_test`, `Y_test`) and call `initialise_layers` before calling `train`.
Attributes:
parameters (Parameters): Configuration parameters for the spiking neural network, including learning rates and thresholds.
t0 (float): Timestamp taken when the module is created; `train` resets it when training starts so elapsed time can be reported.
str_t0 (str): String representation of the initialization timestamp in the format 'YYYY_MM_DD-HH_MM_SS'.
dataloader (Dataloader): Placeholder for the data loader object that manages training and testing datasets (initialized as None).
X_train (np.ndarray): Training input spike trains (initialized as None).
Y_train (np.ndarray): Training labels corresponding to `X_train` (initialized as None).
X_test (np.ndarray): Testing input spike trains (initialized as None).
Y_test (np.ndarray): Testing labels corresponding to `X_test` (initialized as None).
identifier (str): Unique identifier for the module, defaulting to the timestamp `str_t0` if not provided.
logger (Logger): Logger object for tracking and debugging operations within the module.
get_spikeplots (bool): When True, `feed_forward` appends the spike train entering each layer (and the final output spike train) to the per-sample spike-plot list while plotting is requested. Defaults to False.
spikeplots (list): Spike plots collected by `train`, grouped per epoch and per sample.
layerpotentials (list): Neuron potential memories collected by `train`, grouped per epoch and per sample.
get_weight_evolution (bool): When True, `feed_forward` appends copies of the layer synapses to the per-sample weight-evolution list while plotting is requested. Defaults to False.
weight_evolution (list): Weight snapshots collected by `train`, grouped per epoch and per sample.
wta (bool): Winner-Take-All switch for STDP. When True, `feed_forward` re-weighs only the synapses of the highest-potential neuron; when False, it re-weighs the synapses of every neuron in the next layer. Defaults to True.
allowed_levels (optional): Collection of permitted weight values. When it is not None (and cycle-dependent weights are disabled), `update_synapse` returns the level closest to the computed weight. Defaults to None.
synapse_update_counts (list): Tracks the number of STDP updates for each synapse to support cycle-dependent weighting. Starts as an empty list and is replaced in `initialise_layers` by a zero array shaped like the first layer's synapse matrix.
cycle_dependent_weight_lookup (list, optional): Lookup table of allowed weight values, indexed by a synapse's update count.
enable_cycle_dependent_weights (bool): Flag to enable weight updates based on the cycle-count lookup.
enable_synaptic_noise (bool): Flag to enable truncated normal noise on synapse values when they are read during feed-forward and when an updated weight is returned by `update_synapse`.
noise_magnitude_half_range (float): The half-range magnitude for the synaptic noise distribution.
true_lif (bool): When True, potentials and adaptive thresholds decay in proportion to their excess over the resting potential / spike threshold (divided by `tau_m` / `tau_threshold`); when False, they drop by the constant `spike_drop_rate` / `threshold_drop_rate`. Defaults to False.
"""
def __init__(self, parameters: Parameters, identifier: str = None) -> None:
    """
    Create an untrained module with every optional feature switched off.

    Args:
        parameters (Parameters): Network configuration; later handed to each layer
            by ``initialise_layers``.
        identifier (str, optional): Name for this module. A falsy value (``None`` or
            an empty string) falls back to the creation timestamp ``str_t0``.
    """
    # Configuration and bookkeeping.
    self.parameters = parameters
    self.t0 = common.timenow()
    self.str_t0 = common.strftime("%Y_%m_%d-%H_%M_%S")
    self.identifier = identifier or self.str_t0
    self.logger = common.logger

    # Data sources; the caller assigns these before training or testing.
    self.dataloader = None
    self.X_train = self.Y_train = None
    self.X_test = self.Y_test = None

    # Optional diagnostics gathered while training.
    self.get_spikeplots = False
    self.get_weight_evolution = False
    self.spikeplots, self.layerpotentials, self.weight_evolution = [], [], []

    # Learning behaviour switches.
    self.wta = True
    self.true_lif = False
    self.allowed_levels = None

    # Cycle-dependent weight quantisation.
    self.synapse_update_counts = []
    self.cycle_dependent_weight_lookup = None
    self.enable_cycle_dependent_weights = False

    # Synaptic read noise.
    self.enable_synaptic_noise = False
    self.noise_magnitude_half_range = 0.0
[docs]
def STDP(self, delta_t: int) -> float:
    """
    Evaluate the exponential STDP window for one spike-time difference.

    Args:
        delta_t (float): Time difference between the post-synaptic and pre-synaptic spikes.

    Returns:
        float: ``A_up * exp(-delta_t / tau_up)`` when ``delta_t`` is zero or positive,
        otherwise ``A_down * exp(delta_t / tau_down)``.
    """
    params = self.parameters
    if delta_t >= 0:
        amplitude = params.A_up
        exponent = -float(delta_t) / params.tau_up
    else:
        amplitude = params.A_down
        exponent = float(delta_t) / params.tau_down
    return amplitude * np.exp(exponent)
[docs]
def initialise_layers(self, layer_sizes: list) -> None:
    """
    Build one ``Layer`` per pair of adjacent entries in ``layer_sizes``.

    Also resets ``synapse_update_counts`` to a zero array shaped like the synapse
    matrix of the first layer only.

    Args:
        layer_sizes (list): Number of neurons in each layer, input layer first.
    """
    self.layer_sizes = layer_sizes
    built_layers = []
    # Walk consecutive (this size, next size) pairs.
    for size_here, size_next in zip(layer_sizes, layer_sizes[1:]):
        built_layers.append(Layer(self.parameters, size_here, size_next))
    self.layers = built_layers

    first_layer_shape = self.layers[0].synapses.shape
    self.synapse_update_counts = np.zeros(first_layer_shape)
[docs]
def get_all_synapses(self) -> np.ndarray:
    """
    Collect the synapse matrix of every layer into one array.

    Returns:
        np.ndarray: When all layers share one synapse shape (including the
        single-layer case), a stacked array whose first axis indexes the layer.
        When the shapes differ, a 1-D ``dtype=object`` array holding one matrix
        per layer, so that ``result[i]`` is always layer ``i``'s synapses.
    """
    matrices = [np.asarray(layer.synapses) for layer in self.layers]
    if len({matrix.shape for matrix in matrices}) <= 1:
        return np.array(matrices)

    # Layers with different shapes cannot be stacked: np.array() on such a list
    # raises ValueError on recent NumPy, so fill an object array element by element.
    ragged = np.empty(len(matrices), dtype=object)
    for index, matrix in enumerate(matrices):
        ragged[index] = matrix
    return ragged
[docs]
def update_current_potentials_and_adaptive_thresholds(
self,
current_potentials: np.ndarray,
neuron: LIFNeuron,
neuron_index: int,
spike_train_at_timestep: np.ndarray,
synapses_for_neuron_index: np.ndarray,
time_step: float,
in_training: bool = False,
) -> None:
"""
Integrate one time step of input into a single neuron and relax it toward its baselines.
When the neuron is past its refractory window (`neuron.rest_until < time_step`), the dot product of its synapse weights and the incoming spikes is added to `neuron.potential`. With `self.enable_synaptic_noise` set, the weights are first passed through `apply_truncated_noise_vectorized` using `self.noise_magnitude_half_range`, so every call draws fresh noise.
The code then lowers a potential that exceeds `parameters.resting_potential` and an adaptive threshold that exceeds `parameters.spike_threshold`: by the constants `spike_drop_rate` / `threshold_drop_rate`, or, when `self.true_lif` is set, by the excess divided by `tau_m` / `tau_threshold`.
Finally the potential is written to `current_potentials[neuron_index]` and, when `in_training` is True, to `neuron.potential_memory[time_step]`.
NOTE(review): indentation was lost in this copy of the file, so it is not visible here which of the statements after the integration step are also gated by the refractory check - confirm against the original source.
Args:
current_potentials (np.ndarray): Potentials for all neurons in the layer; the entry at `neuron_index` is overwritten in place.
neuron (LIFNeuron): The neuron to update; its `potential`, `adaptive_threshold` and `potential_memory` are modified.
neuron_index (int): Index of the neuron in the layer.
spike_train_at_timestep (np.ndarray): Incoming spikes at the current time step.
synapses_for_neuron_index (np.ndarray): Synapse weights for the neuron.
time_step (float): Current time step in the simulation; compared with `neuron.rest_until` and used to index `neuron.potential_memory`.
in_training (bool): When True, the potential is recorded in `neuron.potential_memory`. `feed_forward` passes its `plotting` flag for this argument.
"""
# Integrate input only once the neuron's refractory window has passed.
if neuron.rest_until < time_step:
if self.enable_synaptic_noise:
# Weights are perturbed by a fresh truncated-normal draw before the dot product.
neuron.potential += np.dot(
apply_truncated_noise_vectorized(synapses_for_neuron_index, self.noise_magnitude_half_range), spike_train_at_timestep
)
else:
neuron.potential += np.dot(
synapses_for_neuron_index, spike_train_at_timestep
)
# Leak: pull a potential that is above rest back toward the resting potential.
if neuron.potential > self.parameters.resting_potential:
if self.true_lif:
diff = (neuron.potential - self.parameters.resting_potential) / self.parameters.tau_m
else:
diff = self.parameters.spike_drop_rate
neuron.potential-=diff
# Relax an adaptive threshold that is above the base spike threshold.
if neuron.adaptive_threshold > self.parameters.spike_threshold:
if self.true_lif:
diff = (neuron.adaptive_threshold - self.parameters.spike_threshold) / self.parameters.tau_threshold
else:
diff = self.parameters.threshold_drop_rate
neuron.adaptive_threshold -= diff
# Publish the potential to the caller's per-layer array.
current_potentials[neuron_index] = neuron.potential
if in_training:
neuron.potential_memory[time_step] = neuron.potential
[docs]
def update_synapse(self, synapse_weight: float, weight_factor: float, next_neuron_idx: int, current_neuron_idx: int) -> float:
    """
    Compute the new weight of one synapse from an STDP weight factor.

    The raw update moves the weight toward ``max_weight`` (non-negative factor) or
    ``min_weight`` (negative factor) by ``eta * weight_factor * headroom ** 0.9``,
    where ``headroom`` is the signed distance to that bound. The raw value is then
    post-processed by the first rule that applies:

    1. ``enable_cycle_dependent_weights``: snap to the nearest value in the lookup
       row selected by this synapse's update count.
    2. ``allowed_levels`` is not None: snap to the nearest allowed level.
    3. ``enable_synaptic_noise``: apply truncated-normal noise to the raw value.
    4. Otherwise return the raw value.

    Level tables may be NumPy arrays or plain sequences such as lists.

    Args:
        synapse_weight (float): The current weight of the synapse.
        weight_factor (float): The adjustment factor based on STDP.
        next_neuron_idx (int): Index of the post-synaptic neuron.
        current_neuron_idx (int): Index of the pre-synaptic neuron.

    Returns:
        float: The updated synapse weight.
    """
    params = self.parameters
    if weight_factor < 0:
        headroom = synapse_weight - params.min_weight
    else:
        headroom = params.max_weight - synapse_weight
    # sign * |x| ** 0.9 keeps the power well defined when the weight is out of bounds.
    updated_wt = synapse_weight + params.eta * weight_factor * (
        np.sign(headroom) * abs(headroom) ** 0.9
    )

    if self.enable_cycle_dependent_weights:
        try:
            cycle = int(self.synapse_update_counts[next_neuron_idx][current_neuron_idx])
            levels = self.cycle_dependent_weight_lookup[cycle]
        except IndexError:
            # Any out-of-range index falls back to the last lookup row.
            levels = self.cycle_dependent_weight_lookup[-1]
        # asarray lets list-based tables take part in the vectorised distance.
        levels = np.asarray(levels)
        return levels[np.argmin(np.abs(levels - updated_wt))]

    if self.allowed_levels is not None:
        levels = np.asarray(self.allowed_levels)
        return levels[np.argmin(np.abs(levels - updated_wt))]

    if self.enable_synaptic_noise:
        return apply_truncated_noise_to_value(updated_wt, self.noise_magnitude_half_range)
    return updated_wt
[docs]
def feed_forward(
self,
spike_train: np.ndarray,
neuron_label_map: np.ndarray,
training_duration: int,
do_stdp: bool = False,
label: int = None,
spikeplots_for_one_spike_train: list = None,
layer_potentials_for_one_spike_train: list = None,
weight_evolution_for_one_spike_train: list = None,
plotting=False,
) -> tuple[np.ndarray, int]:
"""
Propagate one input spike train through the layers in order, with optional STDP learning.
For each layer (`position` runs from 0 while `position < len(self.layer_sizes) - 1`) the method zeroes its spike counters, potentials and synapse memory, then iterates over `training_duration`. At each time step it calls `update_current_potentials_and_adaptive_thresholds` for the layer's neurons and takes the neuron with the highest potential (`np.argmax`). If that potential is below the neuron's adaptive threshold the time step is skipped (`continue`); otherwise the neuron's spike count is incremented, neurons whose potential exceeds their adaptive threshold are marked as spiking in the next layer's spike train and `state_just_after_firing` is called on them, and the remaining neurons are inhibited. With `do_stdp` the synapses are re-weighed through `reweigh_synapses_for_between_input_and_output_neuron`: for the highest-potential neuron only when `self.wta` is True, otherwise for every neuron of the next layer.
After a layer has been processed its potential memories are updated, `initial()` is called on it, and the spike train it produced becomes the input of the next layer.
NOTE(review): indentation was lost in this copy of the file; the nesting described above is inferred from statement order and should be confirmed against the original source.
Args:
spike_train (np.ndarray): Input spike train; indexed as `spike_train[:, time_step]`.
neuron_label_map (np.ndarray): Map of neuron indices to labels. Modified in place when `do_stdp` is True and `label` is not None.
training_duration (int): Despite the annotation, this is iterated over (`for time_step in training_duration`) and each element is used as a column index into the spike train. The callers in this file pass `np.arange(1, parameters.training_duration + 1)`.
do_stdp (bool, optional): If True, apply STDP learning. Defaults to False.
label (int, optional): Label of the current input. Defaults to None.
spikeplots_for_one_spike_train (list, optional): Receives the spike train entering each layer and the final output spike train when `self.get_spikeplots` and `plotting` are both True. Defaults to None.
layer_potentials_for_one_spike_train (list, optional): Receives each layer's `neuron_potential_memories` when `plotting` is True. Defaults to None.
weight_evolution_for_one_spike_train (list, optional): Receives copies of the layer synapses when `self.get_weight_evolution` and `plotting` are both True. Defaults to None.
plotting (bool, optional): Enables the recording described above; it is also forwarded as `in_training` to `update_current_potentials_and_adaptive_thresholds`. Defaults to False.
Returns:
tuple: The spike counts of the last processed layer (`count_spikes_for_next_layer`) and the value of `highest_potential_neuron_idx` left by the last `np.argmax` call.
"""
# NOTE(review): the list arguments default to None but are appended to when
# `plotting` is True - callers that enable plotting must pass real lists.
position = 0
while position < len(self.layer_sizes) - 1:
if self.get_spikeplots and plotting:
spikeplots_for_one_spike_train.append(spike_train)
count_spikes_for_next_layer = np.zeros(self.layer_sizes[position + 1])
current_potentials_for_next_layer = np.zeros(self.layer_sizes[position + 1])
synapse_memory = np.zeros(
(self.layer_sizes[position + 1], self.layer_sizes[position])
)
synapses = self.layers[position].synapses
spike_train_for_next_layer = np.zeros(
(self.layer_sizes[position + 1], self.parameters.training_duration + 1)
)
# `training_duration` is an iterable of time-step indices here, not a scalar.
for time_step in training_duration:
for neuron_index, neuron in enumerate(self.layers[position].layer):
self.update_current_potentials_and_adaptive_thresholds(
current_potentials_for_next_layer,
neuron,
neuron_index,
spike_train[:, time_step],
synapses[neuron_index],
time_step,
plotting,
)
# Pick the neuron with the highest potential at this time step.
highest_potential_neuron_idx = np.argmax(
current_potentials_for_next_layer
)
highest_potential_neuron = self.layers[position].layer[
highest_potential_neuron_idx
]
# Nothing fires at this time step when even the strongest neuron is below its threshold.
if (
current_potentials_for_next_layer[highest_potential_neuron_idx]
< highest_potential_neuron.adaptive_threshold
):
continue
count_spikes_for_next_layer[highest_potential_neuron_idx] += 1
for neuron_index, neuron in enumerate(self.layers[position].layer):
if (
current_potentials_for_next_layer[neuron_index]
> neuron.adaptive_threshold
):
spike_train_for_next_layer[neuron_index, time_step] = 1
neuron.state_just_after_firing(time_step)
if plotting:
neuron.potential_memory[time_step] = (
neuron.parameters.reset_potential
)
neuron.adaptive_threshold += 1
else:
neuron.inhibit(time_step)
if plotting:
neuron.potential_memory[time_step] = (
neuron.parameters.inhibitory_potential
)
# NOTE(review): `reweigh_synapses_for_between_input_and_output_neuron` is not
# defined in the visible part of this file - confirm where it lives.
if do_stdp:
for current_neuron_idx in range(self.layer_sizes[position]):
if self.wta:
self.reweigh_synapses_for_between_input_and_output_neuron(
current_neuron_idx,
spike_train,
synapses,
time_step,
highest_potential_neuron_idx,
synapse_memory,
)
else:
for next_neuron_idx in range(
self.layer_sizes[position + 1]
):
self.reweigh_synapses_for_between_input_and_output_neuron(
current_neuron_idx,
spike_train,
synapses,
time_step,
next_neuron_idx,
synapse_memory,
)
if self.get_weight_evolution and plotting:
weight_evolution_for_one_spike_train.append(self.layers[position].synapses.copy())
self.layers[position].update_neuron_potential_memories()
if plotting:
layer_potentials_for_one_spike_train.append(
self.layers[position].neuron_potential_memories
)
self.layers[position].initial()
# The spikes produced by this layer feed the next one.
spike_train = spike_train_for_next_layer
position += 1
if self.get_spikeplots and plotting:
spikeplots_for_one_spike_train.append(spike_train)
# NOTE(review): `highest_potential_neuron_idx` and `count_spikes_for_next_layer` are only
# bound inside the loops above; with fewer than two layer sizes or an empty
# `training_duration` the lines below would raise UnboundLocalError.
# The label is attached to the index left by the last np.argmax call, not to the neuron with the most spikes.
if do_stdp and label is not None:
neuron_label_map[highest_potential_neuron_idx] = int(label)
return count_spikes_for_next_layer, highest_potential_neuron_idx
[docs]
def train(self) -> float:
    """
    Train the module on ``X_train`` / ``Y_train`` and evaluate it after every epoch.

    Each sample is run through ``feed_forward`` with STDP enabled and plotting
    switched on. At the end of every epoch the current synapses and
    neuron-to-label map are stored as ``learned_synapses`` /
    ``learned_neuron_label_map``, the test set is evaluated with ``test`` and the
    model is written with ``save_epoch``. The per-sample diagnostics collected by
    ``feed_forward`` are kept in ``spikeplots``, ``layerpotentials`` and
    ``weight_evolution``, grouped per epoch.

    ``initialise_layers`` must have been called and the datasets assigned first.

    Returns:
        float: The test accuracy measured after the last epoch, or NaN when
        ``parameters.epochs`` is zero and nothing was measured.
    """
    self.logger.info("Training Started")
    self.t0 = common.timenow()
    neuron_label_map = np.repeat(-1, self.layer_sizes[-1])
    training_duration = np.arange(1, self.parameters.training_duration + 1, 1)

    # Progress-bar settings do not change while training runs.
    bar_length = 40
    total_images = self.parameters.training_images_amount
    counter_width = len(str(total_images))

    # Defined up front so that a run with zero epochs still returns a value.
    test_accuracy = float("nan")
    spikeplots = []
    layerpotentials = []
    weight_evolution = []
    for epoch in range(self.parameters.epochs):
        spikeplots_for_one_epoch = []
        layer_potentials_for_one_epoch = []
        weight_evolution_for_one_epoch = []
        print(f"Epoch {epoch + 1}/{self.parameters.epochs}")
        samples = zip(self.X_train, self.Y_train)
        for cnt, (spike_train, label) in enumerate(samples, start=1):
            spikeplots_for_one_spike_train = []
            layer_potentials_for_one_spike_train = []
            weight_evolution_for_one_spike_train = []
            self.feed_forward(
                spike_train,
                neuron_label_map,
                training_duration,
                do_stdp=True,
                label=label,
                spikeplots_for_one_spike_train=spikeplots_for_one_spike_train,
                layer_potentials_for_one_spike_train=layer_potentials_for_one_spike_train,
                weight_evolution_for_one_spike_train=weight_evolution_for_one_spike_train,
                plotting=True,
            )
            spikeplots_for_one_epoch.append(spikeplots_for_one_spike_train)
            layer_potentials_for_one_epoch.append(
                layer_potentials_for_one_spike_train
            )
            weight_evolution_for_one_epoch.append(weight_evolution_for_one_spike_train)

            # A configured total of zero would divide by zero; show a full bar instead.
            if total_images:
                progress = int(bar_length * cnt / total_images)
            else:
                progress = bar_length
            bar = "=" * progress + "." * (bar_length - progress)
            print(
                f"\r{str(cnt).zfill(counter_width)}/{total_images} [{bar}] [Max cycles for one synapse: {np.max(self.synapse_update_counts)}]",
                end="",
            )

        # Snapshot what was learned so that test() evaluates this epoch's state.
        self.learned_synapses = self.get_all_synapses()
        self.learned_neuron_label_map = neuron_label_map
        print("\nTesting...")
        test_accuracy = self.test()
        print(
            f"\rTest set accuracy: {round(test_accuracy, 3)}\nTime elapsed since training start: {round(common.timenow() - self.t0, 3)}s"
        )
        self.save_epoch(
            epoch + 1, self.learned_synapses, neuron_label_map, test_accuracy
        )
        spikeplots.append(spikeplots_for_one_epoch)
        layerpotentials.append(layer_potentials_for_one_epoch)
        weight_evolution.append(weight_evolution_for_one_epoch)

    self.spikeplots = spikeplots
    self.layerpotentials = layerpotentials
    self.weight_evolution = weight_evolution
    print("Training complete!")
    return test_accuracy
[docs]
def test(self) -> float:
"""
Test the Module on the provided test dataset.
Returns:
float: The accuracy of the Module on the test dataset.
"""
t = 0
c = 0
for spike_train, label in zip(self.X_test, self.Y_test):
prediction = self.get_prediction(spike_train)
if prediction == label:
c += 1
t += 1
return c / t
[docs]
def get_prediction(
    self,
    spike_train: np.ndarray,
    all_synapses: np.ndarray = None,
    neuron_label_map: np.ndarray = None,
) -> int:
    """
    Predict the label of a given input spike train.

    Each optional argument is resolved on its own: one that is omitted falls back
    to the learned value, and one that is supplied is always used, even when the
    other is left out.

    Side effect: the layers' synapses are overwritten with (a copy of) the synapses
    used for the prediction.

    Args:
        spike_train (np.ndarray): Input spike train.
        all_synapses (np.ndarray, optional): Synapse weights, indexable by layer.
            Defaults to ``self.learned_synapses``.
        neuron_label_map (np.ndarray, optional): Neuron-to-label mapping. Defaults
            to ``self.learned_neuron_label_map``.

    Returns:
        int: The label mapped to the output neuron with the most spikes.
    """
    if all_synapses is None:
        all_synapses = self.learned_synapses
    if neuron_label_map is None:
        neuron_label_map = self.learned_neuron_label_map

    # Work on copies so the caller's (or the learned) data is not aliased.
    synapses = all_synapses.copy()
    label_map = np.array(neuron_label_map)

    for layer_index, layer in enumerate(self.layers):
        layer.set_synapses(synapses[layer_index])

    training_duration = np.arange(1, self.parameters.training_duration + 1, 1)
    count_spikes_for_output_layer, _ = self.feed_forward(
        spike_train,
        label_map,
        training_duration,
        do_stdp=False,
        plotting=False,
    )
    return label_map[np.argmax(count_spikes_for_output_layer)]
[docs]
def predict(self, X: np.ndarray, model_location: str) -> int:
    """
    Placeholder for batch inference; users are expected to provide their own version.

    Neither argument is used by this default implementation.

    Args:
        X (np.ndarray): Batch of input spike trains.
        model_location (str): Path to the model for inference.

    Returns:
        int: Always ``-1``.
    """
    not_implemented_marker = -1
    return not_implemented_marker
[docs]
def save_epoch(
    self,
    epoch: int,
    synapses: np.ndarray,
    neuron_label_map: np.ndarray,
    accuracy: float,
) -> None:
    """
    Write the model state reached after one epoch to disk.

    The file is stored as
    ``<common.cwd>/storage/<identifier>[_<str_t0>]/Epoch_<epoch>-<accuracy>/model.red``.
    The timestamp suffix is left out when the identifier already equals the
    timestamp, and the accuracy is rounded to three decimals. Missing directories
    are created.

    Args:
        epoch (int): Epoch number used in the directory name.
        synapses (np.ndarray): Synapse weights after the epoch.
        neuron_label_map (np.ndarray): Neuron-to-label mapping.
        accuracy (float): Test-set accuracy after the epoch.
    """
    # Only add the timestamp when it is not already the identifier.
    if self.str_t0 != self.identifier:
        run_folder = f"{self.identifier}_{self.str_t0}"
    else:
        run_folder = f"{self.identifier}"
    epoch_folder = f"Epoch_{epoch}-{round(accuracy, 3)}"

    target_dir = os.path.join(common.cwd, "storage", run_folder, epoch_folder)
    os.makedirs(target_dir, exist_ok=True)
    common.save_model(
        neuron_label_map,
        synapses,
        self.parameters,
        os.path.join(target_dir, "model.red"),
    )
[docs]
def load_model(self, model_location: str) -> None:
    """
    Restore learned state from a saved model file.

    The values returned by ``common.load_model`` replace ``learned_synapses``,
    ``learned_neuron_label_map`` and ``parameters`` on this module. The layers
    themselves are not touched here.

    Args:
        model_location (str): Path to the saved model.
    """
    (
        self.learned_synapses,
        self.learned_neuron_label_map,
        self.parameters,
    ) = common.load_model(model_location)