Skip to content
Snippets Groups Projects
Commit 6d774326 authored by Michael Mutote's avatar Michael Mutote
Browse files

22202956 - optimised training supervised training

parent 77bfb9d5
No related branches found
No related tags found
No related merge requests found
...@@ -2,7 +2,7 @@ import numpy as np ...@@ -2,7 +2,7 @@ import numpy as np
import Training_data import Training_data
rng = np.random.default_rng(123) rng = np.random.default_rng(123)
TEACHDATA = 100000 TEACHDATA = 10000
TESTDATA = 1000 TESTDATA = 1000
# ETA = 0.5 # ETA = 0.5
T_NUMBER = 1 # Number to be detected 0-6 T_NUMBER = 1 # Number to be detected 0-6
......
import numpy as np import numpy as np
import Perceptrons import Perceptrons
import PerceptronsSGD
def test_function(ETA, p): def test_function(ETA, p):
...@@ -18,13 +17,13 @@ for ETA in ([0.05, 0.1, 0.2, 0.4, 0.75, 1, 2, 5]): # the list of values for ETA ...@@ -18,13 +17,13 @@ for ETA in ([0.05, 0.1, 0.2, 0.4, 0.75, 1, 2, 5]): # the list of values for ETA
x.train(ETA/200) x.train(ETA/200)
y = Perceptrons.SGDPerceptron(20) y = Perceptrons.SGDPerceptron(20)
y.train(ETA) y.train(ETA)
for i in range(10): for i in range(1):
res = test_function(ETA, w) res = test_function(ETA, w)
print("Thres", res) # print the results list print("Thres", res) # print the results list
for i in range(10): for i in range(1):
res = test_function(ETA/200, x) res = test_function(ETA/200, x)
print("Lin", res) # print the results list print("Lin", res) # print the results list
for i in range(10): for i in range(1):
res = test_function(ETA, y) res = test_function(ETA, y)
print("sgd", res) # print the results list print("sgd", res) # print the results list
print("\n\n") print("\n\n")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment