"""Ideal 5x4 binary digit templates and a noisy training/test-set generator."""
import numpy as np
import copy
# Noise-free 5x4 pixel templates for the digits 0-6, keyed by digit value.
# Each entry is a binary matrix: 1 = pixel on, 0 = pixel off.
ideal = {
    0: np.array([[0, 1, 1, 0],
                 [1, 0, 0, 1],
                 [1, 0, 0, 1],
                 [1, 0, 0, 1],
                 [0, 1, 1, 0]]),
    1: np.array([[0, 0, 1, 0],
                 [0, 1, 1, 0],
                 [0, 0, 1, 0],
                 [0, 0, 1, 0],
                 [0, 0, 1, 0]]),
    2: np.array([[0, 1, 1, 0],
                 [1, 0, 0, 1],
                 [0, 0, 1, 0],
                 [0, 1, 0, 0],
                 [1, 1, 1, 1]]),
    3: np.array([[0, 1, 1, 0],
                 [1, 0, 0, 1],
                 [0, 0, 1, 0],
                 [1, 0, 0, 1],
                 [0, 1, 1, 0]]),
    4: np.array([[0, 0, 1, 0],
                 [0, 1, 1, 0],
                 [1, 0, 1, 0],
                 [1, 1, 1, 1],
                 [0, 0, 1, 0]]),
    5: np.array([[1, 1, 1, 1],
                 [1, 0, 0, 0],
                 [1, 1, 1, 0],
                 [0, 0, 0, 1],
                 [1, 1, 1, 0]]),
    6: np.array([[0, 1, 1, 0],
                 [1, 0, 0, 0],
                 [1, 1, 1, 0],
                 [1, 0, 0, 1],
                 [0, 1, 1, 0]]),
}


def make_testset(set_size, NOISE, digits=None):
    """Generate noisy samples of each digit template.

    Parameters
    ----------
    set_size : int
        Number of noisy variants generated per digit.  When set_size > 1
        the clean template is also included as the first sample, so each
        digit then yields set_size + 1 entries.
    NOISE : float
        Standard deviation of the Gaussian noise added per pixel; it
        controls the "spread" of the samples and strongly affects
        downstream classification results.
    digits : dict[int, np.ndarray] | None
        Mapping of digit label -> template array.  Defaults to the
        module-level ``ideal`` templates.  Labels are assumed to be
        0..len(digits)-1 (they index the returned list).

    Returns
    -------
    list[list[np.ndarray]]
        One list of samples per digit, indexed by digit label.
    """
    if digits is None:
        digits = ideal
    data = [[] for _ in range(len(digits))]
    # Fixed seed keeps the generated set reproducible between runs.
    rng = np.random.default_rng(123)
    for number, template in digits.items():
        if set_size > 1:
            # Copy the template so callers mutating the returned data
            # cannot corrupt the shared module-level templates.
            data[number].append(template.copy())
        for _ in range(set_size):
            # Use the template's own shape rather than hard-coding (5, 4).
            noisy = template + rng.normal(loc=0, scale=NOISE, size=template.shape)
            data[number].append(noisy)
    return data