diff --git a/Reinforcement_Learning/Perceptrons.py b/Reinforcement_Learning/Perceptrons.py index 66c47b54156198ce67b468745b5f549ed7af217b..c9468354efa5906523d30c9aeaff57ef0f8ed596 100644 --- a/Reinforcement_Learning/Perceptrons.py +++ b/Reinforcement_Learning/Perceptrons.py @@ -26,7 +26,7 @@ class Perceptron: self.activation = activation self.weights = rng.random(input_count + 1) - def train(self, ETA): + def train_thres(self, ETA): teach_data = Training_data.make_testset(TEACHDATA) for i in range(TEACHDATA): old_weights = np.copy(self.weights) @@ -38,7 +38,6 @@ class Perceptron: delta = ETA * \ (T - self.activation(ix.dot(self.weights))) * ix self.weights = self.weights + delta - # print(self.weights[0], self.weights[1], self.weights[2], self.weights[3], self.weights[4], self.weights[5], self.weights[6]) if np.linalg.norm(old_weights - self.weights) == 0.00: return self.weights return self.weights diff --git a/Reinforcement_Learning/PerceptronsSGD.py b/Reinforcement_Learning/PerceptronsSGD.py index 805f222798b1ba6a576a4f5d444fe26ef1f27b7b..71d1e815846e718e109cb6ee5c9767ec5a165a46 100644 --- a/Reinforcement_Learning/PerceptronsSGD.py +++ b/Reinforcement_Learning/PerceptronsSGD.py @@ -34,7 +34,7 @@ class PerceptronSGD: RI = self.activation(z) error = T - RI delta = ETA * error * self.activation_derivative(z) * ix - self.weights += delta + self.weights += delta return self.weights def test(self): diff --git a/Reinforcement_Learning/Solution_Testing_1.py b/Reinforcement_Learning/Solution_Testing_1.py index 334b49e9265795a98c969195f83585540e60c02e..a308778f62f65ec507015c6ee25fff68328a3e33 100644 --- a/Reinforcement_Learning/Solution_Testing_1.py +++ b/Reinforcement_Learning/Solution_Testing_1.py @@ -1,4 +1,5 @@ import Perceptrons +import PerceptronsSGD def test_function(ETA): @@ -6,14 +7,14 @@ def test_function(ETA): results = [] p = Perceptrons.Perceptron(input_count) - p.train(ETA) + p.train_thres(ETA) output = p.test() results.append((ETA, output)) return results for ETA in ([0.05, 0.1, 0.2, 0.4, 0.75, 1, 2, 5]): # the list of values for ETA - for i in range(5): + for i in range(1): res = test_function(ETA) print(res) # print the results list print("\n\n")