"""
This module implements the `MNISTLoader` class, which extends the `Dataloader`
base class. It is specifically designed to handle the MNIST dataset, providing
methods for:
- Loading and filtering digit classes
- Normalizing images
- Converting images into spike trains
- Retrieving balanced or random samples
The `MNISTLoader` supports preprocessing for Spiking Neural Network (SNN)
training workflows.
"""
from ..utils import *
from .loader import Dataloader
from sklearn.datasets import fetch_openml
from typing import Union
from scipy.fft import dct, idct
from sklearn.decomposition import PCA
# (Sphinx "[docs]" source-link marker left over from HTML extraction; not part of the module.)
class MNISTLoader(Dataloader):
    """
    A dataloader class specifically designed for loading and preprocessing the MNIST dataset.

    This class extends the `Dataloader` base class and provides functionality
    for loading MNIST data, normalizing images, converting images into spike
    trains, compressing images (DCT block truncation or PCA) and retrieving
    balanced or random samples.

    Attributes:
        parameters (Parameters): Configuration parameters for the loader. The
            methods below read `training_duration`, `min_frequency`,
            `max_frequency`, `training_images_amount` and
            `testing_images_amount` from it.
        classes (list): List of digit classes to include in the dataset. An
            empty list means "no filtering".
        X_train (np.ndarray): Training images.
        Y_train (np.ndarray): Training labels.
        X_test (np.ndarray): Testing images.
        Y_test (np.ndarray): Testing labels.
    """

    def __init__(self, parameters: Parameters, classes: list = None) -> None:
        """
        Initializes the MNISTLoader with given parameters and filters classes.

        Downloads `mnist_784` through `fetch_openml`, reshapes every sample to
        28x28, uses the first 60000 samples as the training split and the
        remainder as the test split, then keeps only the requested digits.

        Args:
            parameters (Parameters): Configuration parameters.
            classes (list, optional): List of digit classes to filter. Defaults
                to None, which selects all 10 classes. An empty list disables
                filtering.
        """
        super().__init__(parameters)
        self.parameters = parameters

        # A fresh list is created per call instead of sharing one mutable
        # default object between every instance.
        if classes is None:
            classes = list(range(10))
        self.classes = list(classes)

        logger.info("Loading Raw Data")
        mnist = fetch_openml('mnist_784', version=1)
        images = np.array(mnist.data, dtype="uint8").reshape(-1, 28, 28)
        labels = np.array(mnist.target, dtype="int64")

        self.X_train, self.Y_train = images[:60000], labels[:60000]
        self.X_test, self.Y_test = images[60000:], labels[60000:]

        if len(self.classes) != 0:
            train_filter = np.isin(self.Y_train, self.classes)
            self.X_train = self.X_train[train_filter]
            self.Y_train = self.Y_train[train_filter]

            test_filter = np.isin(self.Y_test, self.classes)
            self.X_test = self.X_test[test_filter]
            self.Y_test = self.Y_test[test_filter]

    def normalise(self, img: np.ndarray) -> np.ndarray:
        """
        Normalizes an image to the range [0, 1] with min-max scaling.

        Args:
            img (np.ndarray): The image to normalize.

        Returns:
            np.ndarray: The normalized image. A constant image (max == min)
            is mapped to all zeros instead of producing NaN through 0/0.
        """
        lowest = np.min(img)
        span = np.max(img) - lowest
        if span == 0:
            return np.zeros_like(img, dtype=float)
        return (img - lowest) / span

    def img2spiketrain(self, img: np.ndarray) -> np.ndarray:
        """
        Converts an image into a spike train representation.

        Each pixel is normalized and mapped linearly onto
        [`min_frequency`, `max_frequency`]; a pixel then spikes at every time
        step `t >= 1` that is a multiple of
        `ceil(training_duration / frequency)`. Pixels whose raw value is not
        positive never spike, and no spike is emitted at `t = 0`.

        Args:
            img (np.ndarray): 2D image (it is normalized internally).

        Returns:
            np.ndarray: Integer array of shape
            `(rows * cols, training_duration + 1)` holding 0/1 spikes.
        """
        sx, sy = img.shape
        duration = self.parameters.training_duration
        time_steps = duration + 1

        frequencies = (
            self.normalise(img)
            * (self.parameters.max_frequency - self.parameters.min_frequency)
            + self.parameters.min_frequency
        )

        # Only pixels that can spike get an interval computed; this avoids
        # dividing by a zero frequency and casting the resulting inf to int.
        active = (img > 0) & (frequencies > 0)
        intervals = np.ones((sx, sy), dtype=int)
        intervals[active] = np.ceil(duration / frequencies[active]).astype(int)

        spike_trains = np.zeros((sx, sy, time_steps), dtype=int)
        ticks = np.arange(1, time_steps)
        # Broadcast (sx, sy, 1) intervals against (T-1,) ticks in one shot.
        spike_trains[:, :, 1:] = (
            (ticks % intervals[:, :, np.newaxis] == 0) & active[:, :, np.newaxis]
        )
        return spike_trains.reshape(-1, time_steps)

    def get_random_image(
        self, get_spike_train: bool = False
    ) -> tuple[np.ndarray, int, Union[np.ndarray, None]]:
        """
        Retrieves a random image and its label from the training dataset.

        Args:
            get_spike_train (bool, optional): Whether to return the spike train
                of the image. Defaults to False.

        Returns:
            tuple: `(image, label, spike_train)`; `spike_train` is None unless
            `get_spike_train` is True.
        """
        idx = np.random.randint(0, len(self.X_train))
        image, label = self.X_train[idx], self.Y_train[idx]
        spike_train = self.img2spiketrain(image) if get_spike_train else None
        return image, label, spike_train

    def load_balanced_mnist(
        self, Y: np.ndarray, num_samples: int, seed: Union[int, None] = None
    ) -> np.ndarray:
        """
        Balances the dataset by selecting an equal number of samples for each label.

        `num_samples // n_labels` indices are drawn without replacement for
        every distinct label, so the result may contain fewer than
        `num_samples` entries when the division is not exact.

        Args:
            Y (np.ndarray): The labels of the dataset.
            num_samples (int): Total number of samples to select.
            seed (int, optional): Seed for reproducibility. Defaults to None.

        Returns:
            np.ndarray: Shuffled integer indices of the selected samples
            (empty when `Y` is empty or fewer than one sample per label fits).

        Raises:
            ValueError: If a label has fewer occurrences than the number of
                samples requested per label.
        """
        rng = np.random.default_rng(seed)
        unique_labels = np.unique(Y)
        if len(unique_labels) == 0:
            return np.empty(0, dtype=int)

        samples_per_label = num_samples // len(unique_labels)
        chosen = []
        for label in unique_labels:
            candidates = np.where(Y == label)[0]
            if samples_per_label > len(candidates):
                raise ValueError(
                    f"Requested {samples_per_label} samples for label {label} "
                    f"but only {len(candidates)} are available"
                )
            chosen.append(
                rng.choice(candidates, size=samples_per_label, replace=False)
            )

        # Concatenating keeps an integer dtype even when nothing was drawn, so
        # the result is always usable as an index array.
        selected_indices = np.concatenate(chosen).astype(int, copy=False)
        return rng.permutation(selected_indices)

    def compress_image(self, image: np.ndarray, k: int) -> np.ndarray:
        """
        Compress image by keeping top block of m x m = k coefficients.

        Args:
            image (np.ndarray): 2D numpy array (e.g. 28x28 for MNIST)
            k (int): number of coefficients to keep.
                (k must be chosen so that sqrt(k) is an integer; we keep a
                top-left block of size m x m, where m = sqrt(k). For other
                values m is `int(sqrt(k))`, i.e. rounded down.)

        Returns:
            np.ndarray: The DCT coefficient matrix with only the top-left
            m x m block kept and every other coefficient set to zero.
        """
        # Separable 2D DCT: transform along one axis, then along the other.
        coefficients = dct(dct(image.T, norm="ortho").T, norm="ortho")
        m = int(np.sqrt(k))
        kept = np.zeros_like(coefficients)
        kept[:m, :m] = coefficients[:m, :m]
        return kept

    def uncompress_image(
        self, image: np.ndarray, k: int, threshold: float = 0.5
    ) -> np.ndarray:
        """
        Given the (possibly sparsified) DCT coefficients, reconstruct the image using the inverse DCT.

        The reconstruction is min-max normalized, binarized at `threshold`
        and finally cropped to its top-left m x m corner (m = int(sqrt(k))).

        Args:
            image (np.ndarray): 2D numpy array of DCT coefficients.
            k (int): number of coefficients that were kept.
            threshold (float): The threshold below which all pixel values will
                be zero after normalization; values at or above it become one.
                Defaults to 0.5

        Returns:
            np.ndarray: The binarized, cropped reconstruction. A flat
            reconstruction (e.g. from an all-black image) yields a constant
            array instead of NaN.
        """
        reconstructed = idct(idct(image.T, norm="ortho").T, norm="ortho")
        normalised = self.normalise(reconstructed)
        binary = (normalised >= threshold).astype(normalised.dtype)
        m = int(np.sqrt(k))
        # NOTE(review): this crops the *spatial* image to its top-left corner,
        # not the coefficient block -- confirm that this is the intended way
        # of obtaining a k-pixel image.
        return binary[:m, :m]

    def image2dct2image(
        self, image: np.ndarray, k: int, threshold: float
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Compresses the image, then keeps top k DCT features, then uncompresses the image.

        Args:
            image (np.ndarray): 2D numpy array (e.g. 28x28 for MNIST)
            k (int): number of coefficients to keep.
                (k must be chosen so that sqrt(k) is an integer; we keep a
                top-left block of size m x m, where m = sqrt(k))
            threshold (float): The threshold below which all pixel values will
                be zero after normalization.

        Returns:
            np.ndarray: The DCT coefficient matrix with only the top-left m x m = k block kept.
            np.ndarray: The top k DCT features uncompressed image.
        """
        coefficients = self.compress_image(image.astype(np.float32), k)
        return coefficients, self.uncompress_image(coefficients, k, threshold)

    def dataloader(
        self,
        train: bool = True,
        preprocess: bool = False,
        random_single: bool = False,
        seed: int = 42,
        size: int = None,
        k: int = None,
        threshold: float = 0.5,
        pca=False
    ) -> Union[
        tuple[np.ndarray, np.ndarray], tuple[np.ndarray, np.ndarray, int]
    ]:
        """
        Loads and preprocesses the MNIST dataset.

        Args:
            train (bool, optional): Whether to load training or testing data. Defaults to True.
                Ignored when `random_single` is True (the sample is always
                taken from the training split).
            preprocess (bool, optional): Whether to preprocess the data into spike trains. Defaults to False.
            random_single (bool, optional): Whether to return a single random sample. Defaults to False.
            seed (int, optional): Seed for the balanced selection. Defaults to 42.
            size (int, optional): Number of samples to load. Defaults to None,
                which falls back to `training_images_amount` or
                `testing_images_amount` from the parameters.
            k (int, optional): The top k features to keep after taking DCT of the image. Defaults to None means take the whole image.
            threshold (float): Used only if DCT is used. The threshold below which all pixel values will be zero after normalization. Defaults to 0.5
            pca (bool): If compression should happen according to PCA
                (components explaining 95% of the variance are kept).

        Returns:
            tuple: `(X, Y)` arrays of samples and labels, or, when
            `random_single` is True, `(image, spike_train_or_image, label)`.
        """
        use_dct = isinstance(k, (int, np.integer)) and k > 0
        reducer = PCA(n_components=0.95, svd_solver='full') if pca else None

        if random_single:
            return self._load_random_single(
                preprocess, k, threshold, use_dct, pca, reducer
            )
        return self._load_batch(
            train, preprocess, seed, size, k, threshold, use_dct, pca, reducer
        )

    def _load_batch(
        self, train, preprocess, seed, size, k, threshold, use_dct, pca, reducer
    ) -> tuple[np.ndarray, np.ndarray]:
        """Return a class-balanced batch `(X, Y)` from the chosen split."""
        logger.info(
            f"Loading{' preprocessed' if preprocess else ''} {'train' if train else 'test'} data"
        )
        if train:
            if not size:
                size = self.parameters.training_images_amount
            images, labels = self.X_train, self.Y_train
        else:
            if not size:
                size = self.parameters.testing_images_amount
            images, labels = self.X_test, self.Y_test

        indices = self.load_balanced_mnist(labels, size, seed=seed)
        X, Y = images[indices], labels[indices]

        if pca:
            n_samples, h, w = X.shape
            X = reducer.fit_transform(X.reshape((n_samples, h * w)))
            logger.debug("PCA-reduced data shape: %s", X.shape)

        fin_X = []
        fin_Y = []
        for img, label in zip(X, Y):
            # NOTE: after PCA each sample is 1D, which the 2D block truncation
            # in the DCT path cannot index; combining `pca` with `k` fails here.
            if use_dct:
                _, img = self.image2dct2image(img, k, threshold)
            if preprocess:
                if pca:
                    # img2spiketrain expects a 2D array.
                    img = img.reshape((1, -1))
                img = self.img2spiketrain(img)
            fin_X.append(img)
            fin_Y.append(label)
        return np.array(fin_X), np.array(fin_Y)

    def _load_random_single(
        self, preprocess, k, threshold, use_dct, pca, reducer
    ) -> tuple[np.ndarray, np.ndarray, int]:
        """Return `(image, spike_train_or_image, label)` for one random training sample."""
        idx = np.random.choice(len(self.X_train))
        img = self.X_train[idx]
        label = self.Y_train[idx]

        if use_dct:
            _, img = self.image2dct2image(img, k, threshold)
        elif pca:
            # NOTE(review): PCA is fitted on this single sample only -- confirm
            # that this is intended.
            img = reducer.fit_transform(img.reshape((1, -1)))[0]

        if not preprocess:
            return img, img, label
        if pca:
            img = img.reshape((1, -1))
        return img, self.img2spiketrain(img), label