import json
import random

NUM_GUMDROPS = 8


def choose(weights):
    """Pick an index at random, with probability proportional to its weight.

    A roll is drawn uniformly from ``1..sum(weights)`` and the first index
    whose cumulative weight reaches the roll is returned.

    Args:
        weights: sequence of integer weights.

    Returns:
        The chosen index, or ``None`` if no cumulative weight reaches the roll.
    """
    roll = random.randint(1, sum(weights))
    cumulative = 0
    for index, weight in enumerate(weights):
        cumulative += weight
        if roll <= cumulative:
            return index
    return None


def check_game_end(state):
    """Report whether the game is over and, if so, who won.

    Args:
        state: nine board cells in row-major order; '-' marks an empty cell.

    Returns:
        ``(True, mark)`` when a row, column or diagonal holds three equal
        non-empty marks, ``(True, None)`` when the board is full without
        such a line, and ``(False, None)`` otherwise.
    """
    lines = (
        (0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
        (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
        (0, 4, 8), (2, 4, 6),             # diagonals
    )
    for first, second, third in lines:
        mark = state[first]
        if mark != '-' and mark == state[second] and mark == state[third]:
            return True, mark
    # No completed line: the game is over only if no empty cell remains.
    board_full = '-' not in state
    return board_full, None


def make_move(state, move, player):
    """Return a new board string with ``player`` placed at index ``move``.

    The input ``state`` is not modified.
    """
    cells = [*state]
    cells[move] = player
    return ''.join(cells)


def get_legal_moves(state):
    """Return the indices of all empty ('-') cells, in ascending order."""
    empty_squares = [index for index, mark in enumerate(state) if mark == '-']
    return empty_squares


def print_board(state):
    """Print the 1-9 position legend followed by the current board.

    Each board row is printed as three cells separated by single spaces.
    """
    print("Board positions:")
    for first in (1, 4, 7):
        print(f"{first} {first + 1} {first + 2}")
    print("\nCurrent board:")
    for row_start in (0, 3, 6):
        row = state[row_start:row_start + 3]
        print(' '.join(row))


def human_move(state):
    """Prompt on the console until the user picks an empty square.

    The board is printed once, then the user is asked for a position from
    1 to 9 until the answer maps to an empty cell.

    Returns:
        The chosen cell as a 0-based index.
    """
    print_board(state)
    # The board does not change while prompting, so compute this once.
    legal = get_legal_moves(state)
    while True:
        try:
            # The user types 1-9; the board is indexed 0-8.
            move = int(input("Enter your move (1-9): ")) - 1
        except ValueError:
            print("Please enter a number.")
            continue
        if move in legal:
            return move
        print("Invalid move, try again.")


def save_menace_state(menace):
    """Write the MENACE weight table to ``MENACE/menace_state.json`` as JSON.

    The path is relative to the current working directory. The ``MENACE``
    directory is created when it does not exist yet, so the first save on a
    fresh checkout does not fail with ``FileNotFoundError``.

    Args:
        menace: JSON-serialisable mapping of board strings to weight lists.
    """
    import os

    path = 'MENACE/menace_state.json'
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'w') as file:
        json.dump(menace, file)


def load_menace_state():
    """Load the MENACE weight table from ``MENACE/menace_state.json``.

    This is the path ``save_menace_state`` writes to, relative to the
    current working directory, so a saved table is picked up on the next run.

    Returns:
        The decoded JSON object, or an empty dict when the file (or its
        directory) does not exist.
    """
    try:
        with open('MENACE/menace_state.json', 'r') as file:
            return json.load(file)
    except FileNotFoundError:
        return {}


def one_game(menace, human_player=False, self_play=False):
    """Play one game of tic-tac-toe and reinforce MENACE from the result.

    'X' moves whenever both marks appear equally often on the board (so 'X'
    opens the game); otherwise 'O' moves. Every move is recorded and passed
    to ``update_menace`` once the game is over.

    Args:
        menace: dict mapping a board string to a list of weights, one per
            empty cell in board order. Mutated in place.
        human_player: when true, 'O' moves are read from the console and the
            final board and result are printed.
        self_play: accepted for backward compatibility. It no longer has any
            effect: weights are created only for states that are not yet in
            ``menace`` and are never reset, so that what was learned in
            earlier games carries over to later ones.

    Returns:
        'X' or 'O' for the winner, or ``None`` for a draw.
    """
    state = "---------"
    history = []
    game_end, winner = check_game_end(state)
    while not game_end:
        player = 'X' if state.count('X') == state.count('O') else 'O'
        if human_player and player == 'O':
            move = human_move(state)
        else:
            # Initialise a state only on first sight. Re-creating the list
            # for a known state would throw away its learned weights.
            if state not in menace:
                menace[state] = [NUM_GUMDROPS] * state.count('-')
            # Weights are stored in the same order as the legal moves.
            legal_moves = get_legal_moves(state)
            move = legal_moves[choose(menace[state])]

        history.append((state, move))
        state = make_move(state, move, player)
        game_end, winner = check_game_end(state)
    if human_player:
        print_board(state)
        if winner:
            print(f"Game ended, winner: {winner}")
        else:
            print("Game ended, draw")
    update_menace(menace, history, winner)
    return winner


def update_menace(menace, history, winner):
    """Adjust the weights of the moves played in a finished game.

    For every recorded ``(state, move)`` whose state is in ``menace``, the
    weight of that move is raised by 10 if the side that moved from that
    state won; otherwise (loss or draw) it is lowered by 1, never below 1.

    Args:
        menace: dict mapping a board string to its list of move weights;
            mutated in place.
        history: sequence of ``(state, move)`` pairs from the game.
        winner: 'X', 'O', or ``None`` for a draw.
    """
    reward, penalty = 10, 1
    for state, move in reversed(history):
        if state not in menace:
            continue
        weights = menace[state]
        slot = get_legal_moves(state).index(move)
        x_count = state.count('X')
        o_count = state.count('O')
        # Equal counts mean 'X' moved from this state; more X than O means 'O' did.
        x_moved_and_won = winner == 'X' and x_count == o_count
        o_moved_and_won = winner == 'O' and x_count > o_count
        if x_moved_and_won or o_moved_and_won:
            weights[slot] += reward
        else:
            weights[slot] = max(1, weights[slot] - penalty)


def play_again():
    """Ask the user whether to play another game.

    Returns:
        ``True`` only when the answer is 'y' (in either case).
    """
    answer = input("Play again? (y/n): ")
    return answer.lower() == 'y'


def train_menace_self_play(menace, games):
    """Train MENACE by having it play ``games`` games against itself.

    Progress is written to stdout on a single line that is overwritten
    (carriage return) every 1000 games and after the last game.

    Args:
        menace: weight table handed to ``one_game``, which mutates it.
        games: number of self-play games to run.
    """
    for i in range(games):
        one_game(menace, human_player=False, self_play=True)
        if i % 1000 == 0 or i == games - 1:
            # Report games completed (i + 1) so the last update reads 100%.
            print(
                f"Self-Play Training Progress: {(i + 1) / games * 100:.1f}%", end="\r")
    print()



# Start from previously saved weights, or from an empty table.
menace = load_menace_state()

# Self-play training, then persist the result before going interactive.
train_menace_self_play(menace, 50000)
save_menace_state(menace)

# Play against the human until they decline another round.
while True:
    one_game(menace, human_player=True)
    if not play_again():
        break

# Persist whatever was learned from the interactive games.
save_menace_state(menace)