IdentifiantMot de passe
Loading...
Mot de passe oublié ?Je m'inscris ! (gratuit)
Navigation

Inscrivez-vous gratuitement
pour pouvoir participer, suivre les réponses en temps réel, voter pour les messages, poser vos propres questions et recevoir la newsletter

Python Discussion :

Comment utiliser torch.save et torch.load en POO pour RL ?


Sujet :

Python

  1. #1
    Membre averti
    Femme Profil pro
    Data
    Inscrit en
    Mai 2023
    Messages
    14
    Détails du profil
    Informations personnelles :
    Sexe : Femme
    Âge : 25
    Localisation : France, Morbihan (Bretagne)

    Informations professionnelles :
    Activité : Data

    Informations forums :
    Inscription : Mai 2023
    Messages : 14
    Par défaut Comment utiliser torch.save et torch.load en POO pour RL ?
    Bonjour ! Je suis une débutante en POO et RL et j'ai besoin de petits conseils pour mon jeu PUISSANCE4
    n'hésitez pas si vous voyez quelque chose de choquant dans mon code ahah

    Mais surtout je me demande où sauvegarder mon enregistrement de données d'entraînement et où le charger pour qu'il soit efficace ? J'ai un peu de mal.
    je sais qu'il faut utiliser torch.save et torch.load...... mais dans quelle classe ?

    Merci d'avance !

    Code : Sélectionner tout - Visualiser dans une fenêtre à part
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    249
    250
    251
    252
    253
    254
    255
    256
    257
    258
    259
    260
    261
    262
    263
    264
    265
    266
    267
    268
    269
    270
    271
    272
    273
    274
    275
    276
    277
    278
    279
    280
    281
    282
    283
    284
    285
    286
    287
    288
    289
    290
    291
    292
    293
    294
    295
    296
    297
    298
    299
    300
    301
    302
    303
    304
    305
    306
    307
    308
    309
    310
    311
    312
    313
    314
    315
    316
    317
    318
    319
    320
    321
    322
    323
    324
    325
    326
    327
    328
    329
    330
    331
    332
    333
    334
    335
    336
    337
    338
    339
    340
    341
    342
    343
    344
    345
    346
    347
    348
    349
    350
    351
    352
    353
    354
    355
    356
    357
    358
    359
    360
    361
    362
    363
    364
    365
    366
    367
    368
    369
    370
    371
    372
    373
    374
    375
    376
    377
    378
    379
    380
    381
    382
    383
    384
    385
    386
    387
    388
    389
    390
    391
    392
    393
    394
    395
    396
    397
    398
    399
    400
    401
    402
    403
    404
    405
    406
    407
    408
    409
    410
    411
    412
    413
    414
    415
    416
    import numpy as np
    from colorama import Fore, Style
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from collections import deque
    import random
    import torch.nn.functional as F
    import pickle
    import os
     
    repo = os.path.dirname(os.path.abspath(__file__))
    path = os.path.join(repo, "modeleIA.pth")
     
    import matplotlib.pyplot as plt
     
    if torch.cuda.is_available():
        device = torch.device('cuda') 
    else:
        device = torch.device('cpu')
     
    class Plateau:
        """ Classe représentant le plateau de jeu """
     
        def __init__(self, rows, columns):
            """ Constructeur de la classe, initialise le plateau et ses dimensions """
            self.rows = rows
            self.columns = columns
            self.plato = np.zeros((rows, columns))
     
        def colonne_check(self, col):
            """ Vérifie si la colonne est pleine """
            for i in range(self.rows):
                if self.plato[i][col] == 0:
                    return True
            return False
     
     
        def placement_jeton(self, col, joueur):
            """ Place le jeton du joueur dans la colonne qu'il sélectionne """
            for i in np.r_[:self.rows][::-1]:
                if self.plato[i][col] == 0:
                    self.plato[i][col] = joueur
                    return True
            return False
     
        def affichage(self, joueur):
            """ Affiche le plateau de jeu à l'instant t """
            couleur = Fore.RED if joueur == 1 else Fore.YELLOW
            print(couleur + str(self.plato) + Style.RESET_ALL)
     
        def check_victoire(self):
            """ Vérifie si un joueur a gagné """
            rows, columns = self.rows, self.columns
     
            # vérification en ligne
            for r in np.r_[:rows]:
                for d in np.r_[:columns-3]:
                    f = d + 4
                    s = np.prod(self.plato[r, d:f])
                    if s == 1 or s == 16:
                        return True
     
            # vérification en colonne
            for c in np.r_[:columns]:
                for d in np.r_[:rows-3]:
                    f = d + 4
                    s = np.prod(self.plato[d:f, c])
                    if s == 1 or s == 16:
                        return True
     
            # vérification en diagonale (bas gauche vers haut droite)
            for r in np.r_[:rows-3]:
                for c in np.r_[:columns-3]:
                    f = c + 4
                    s = np.prod([self.plato[r+i, c+i] for i in range(4)])
                    if s == 1 or s == 16:
                        return True
     
            # vérification en diagonale (haut gauche vers bas droite)
            for r in np.r_[3:rows]:
                for c in np.r_[:columns-3]:
                    f = c + 4
                    s = np.prod([self.plato[r-i, c+i] for i in range(4)])
                    if s == 1 or s == 16:
                        return True
            return False
     
        def get_etat(self):
            """ Obtient l'état actuel du plateau sous forme de tableau 1D """
            return self.plato.flatten()
     
        def get_actions(self):
            """ Obtient les actions possibles à partir de l'état actuel du plateau """
            return np.where(self.plato[0] == 0)[0]
     
    class Joueur:
        """ Classe représentant un joueur humain """
     
        def __init__(self, numero, max_choix):
            """ Initialise le joueur et son numéro et le nombre de choix possibles """
            self.numero = numero
            self.max_choix = max_choix
     
        def jouer(self, state, actions):
            """ Demande au joueur de choisir une colonne """
            while True:
                try:
                    choix = int(input(f'Joueur {self.numero}, à vous de jouer (entre 1 et {self.max_choix}): ')) - 1
                    if choix in actions:
                        return choix
                    else:
                        print("Choix invalide. Essayez à nouveau.")
                except ValueError:
                    print("Ce n'est pas un nombre. Essayez encore.")
     
    class DQNAgent:
        # Dans le RL, l'agent DQN utilise une mémoire appelée "replay memory"
        # pour stocker les XP passées (état/action/rec/prochain état etc)
        # afin de les réutiliser lors de l'apprentissage
     
        # Initialisation de l'agent DQN
        def __init__(self, state_size, action_size):
            self.state_size = state_size
            self.action_size = action_size
            self.memory = deque(maxlen=10000) #pareil que sur le github
            self.gamma = 0.95  # facteur d'actualisation, équilibre recomp im et future
            self.epsilon = 1.0  # taux d'exploration initial
            self.epsilon_min = 0.01  # taux d'exploration minimum
            self.epsilon_decay = 0.995  # taux de décroissance de l'exploration
            self.lr = 0.001  # taux d'apprentissage
            self.model = self._build_model()  # Construire le modèle de réseau neuronal
            self.batch_size = 64
            self.update_every = 5
     
            # réseau principal (d'évaluation) pour choisir les actions et réseau cible pour générer les Q-values cibles
            self.dqn_network = self._build_model().to(device)
            self.target_network = self._build_model().to(device)
            #pour l'instant même dim d'entrée, cachée (64) et de sortie
     
            self.optimizer = optim.Adam(self.dqn_network.parameters(), lr=self.lr)  # Optimiseur courrament utilisé en RL
            self.t_step = 0  # Compteur pour la mise à jour du réseau cible
     
        def charger_modele(self, chemin):
                self.modele.load_state_dict(torch.load(chemin))
                self.modele.eval()
     
        def _build_model(self):
                '''modele de réseau de neurones pour l'apprentissage'''
                model = nn.Sequential(
                    nn.Linear(self.state_size, 64),
                    nn.ReLU(),
                    nn.Linear(64, 64),
                    nn.ReLU(),
                    nn.Linear(64, self.action_size)
                )
                return model
     
        # L'agent choisit l'action selon l'état et la politique epsilon-greedy
        def act(self, state, eps=0.1):
            state = torch.from_numpy(state).float().unsqueeze(0).to(device)  # On convertit l'état en tenseur
            self.dqn_network.eval()  # mode évaluation càd pas de mise à jour des poids
     
            with torch.no_grad():
                action_values = self.dqn_network(state)  # Calculer la Q-valeur pour chaque action
            self.dqn_network.train()  # Repasser le réseau en mode entrainement
     
            # politique epsilon-greedy (exploration/exploitation), à améliorer car
            #pour l'instant génère un nombre aléatoire
            if random.random() > eps:
                return np.argmax(action_values.cpu().data.numpy())  # Choix de l'action avec la plus grande Q-valeur (greedy)
            else:
                return random.choice(np.arange(self.action_size))  # Choix d'une action aléatoire
     
        # Stock l'expérience dans la mémoire de remise en état
        def remember(self, state, action, reward, next_state, done):
            self.memory.append((state, action, reward, next_state, done))
     
        # Prend une action et apprendre à partir de l'expérience
        def step(self, state, action, reward, next_state, done):
            self.remember(state, action, reward, next_state, done)  # Stocker l'expérience
            self.epsilon *= self.epsilon_decay #on multiplie par le facteur de décroissance
            self.learn()  # Apprendre de l'expérience
     
        # Apprentissage à partir de l'expérience (implémentation de l'équation de Belman)
        def learn(self):
            # mémoire de remise en état soit assez grande
            if len(self.memory) < self.batch_size:
                return
     
            # on échantillonne un batch d'expériences de taille aléatoire
            experiences = random.sample(self.memory, self.batch_size)
            # dézip les élements de l'échantillon 
            states, actions, rewards, next_states, dones = zip(*experiences)
     
            # Convertion des expériences numpy --> tenseur pytorch
            states = torch.from_numpy(np.vstack(states)).float().to(device)
            actions = torch.from_numpy(np.vstack(actions)).long().to(device)
            rewards = torch.from_numpy(np.vstack(rewards)).float().to(device)
            next_states = torch.from_numpy(np.vstack(next_states)).float().to(device)
            dones = torch.from_numpy(np.vstack(dones).astype(np.int64)).float().to(device)
     
            # Calcul les Q-valeurs cibles et attendues
     
            #on obtient les qvaleur prédites à partir du modèle cible
            Q_cible_next = self.target_network(next_states).detach().max(1)[0].unsqueeze(1)  # Q-valeurs cibles pour les prochains états
     
            #on calcule les Q cibles pour les états actuels
            Q_cible = rewards + (self.gamma * Q_cible_next * (1 - dones))  # Q-valeurs cibles pour les états actuels
     
            #on calcul les q attendus à partir du modèle
            Q_expected = self.dqn_network(states).gather(1, actions)  # Q-valeurs attendues pour les états actuels
     
            # Calcul de la perte et rétropropagation de l'erreur
            loss = F.mse_loss(Q_expected, Q_cible)  # Calcul la perte w/ MSE
            # On minimise la fonction de perte
            self.optimizer.zero_grad()  # Réinitialisation gradients
            loss.backward()  # Rétropropagation de l'erreur
            self.optimizer.step()  # Mise à jour des poids du réseau
     
            # Mise à jour du réseau cible on le copie du réseau DQN
            self.t_step = (self.t_step + 1) % self.update_every
            if self.t_step == 0:
                self.target_network.load_state_dict(self.dqn_network.state_dict())
     
        def sauvegarder_modele(self, path):
            # Sauvegarde du modèle PyTorch avec pickle
            state = {
                    'dqn_network_state_dict': self.dqn_network.state_dict(),
                    'target_network_state_dict': self.target_network.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'memory': self.memory,
                    'epsilon': self.epsilon,
                }
     
            #with open(chemin, "wb") as fichier:
               #pickle.dump(self.model, fichier)
     
        def charger_modele(self, path):
            # Chergament du modèle Pytorch avec pickle
     
            #with open(chemin, "rb") as fichier:
            #    self.model = pickle.load(fichier)
            if os.path.exists(path):
                state = torch.load(path)
                self.dqn_network.load_state_dict(state['dqn_network_state_dict'])
                self.target_network.load_state_dict(state['target_network_state_dict'])
                self.optimizer.load_state_dict(state['optimizer_state_dict'])
                self.memory = state['memory']
                self.epsilon = state['epsilon']
                print(f"Modèle chargé depuis {path}")
            else:
                print(f"Aucun modèle trouvé à {path}")
     
    class IA:
        def __init__(self, numero, max_choix, state_size, agent=None):
            self.numero = numero
            self.max_choix = max_choix
            if agent is None:
                self.agent= DQNAgent(state_size, max_choix)
            else:
                self.agent= agent
     
        def jouer(self, state, actions):
            """ Choisit une colonne en utilisant l'agent IA """
            proba_victoires = self.calculer_proba_victoires(state, actions)
            # Trie en fonction des probabilités de victoire
            tri_indice_action = np.argsort(proba_victoires)
            # actions de la plus probable à la moins probable
            for indice in reversed(tri_indice_action):
                action = actions[indice]
                # Si l'action est possible, on la retourne
                if action in actions:
                    return action
            #(ne devrait pas arriver), retourne une action aléatoire
            return np.random.choice(actions)
     
     
        def calculer_proba_victoires(self, state, actions):
            proba_victoires = np.zeros(len(actions))
            for i, action in enumerate(actions):
                next_state = state.copy()
                next_state[action] = self.numero
                proba_victoires[i] = self.agent.act(next_state)
            return proba_victoires
     
        def apprendre(self, state, action, reward, next_state, done):
            self.agent.step(state, action, reward, next_state, done)
     
    class Jeu:
        """ Classe représentant le jeu en lui-même """
     
        def __init__(self, rows, columns, joueurs):
            """ Initialise le jeu et les joueurs """
            self.plato = Plateau(rows, columns)
            self.joueurs = joueurs
     
        def play(self):
                """ Lance le jeu et vérifie si un joueur a gagné ou si la partie est nulle """
                state = self.plato.get_etat()
                while True:
                    for joueur in self.joueurs:
                        print(f"Joueur {joueur.numero}")
                        actions = self.plato.get_actions()
                        choix = joueur.jouer(state, actions)
                        self.plato.placement_jeton(choix, joueur.numero)
                        self.plato.affichage(joueur.numero)
     
                        if self.plato.check_victoire():
                            print(f"Joueur {joueur.numero} a gagné!")
                            self.plato.affichage(joueur.numero)  # Afficher le plateau final
                            return joueur.numero
                        elif np.all(self.plato.plato != 0):
                            print("Match nul!")
                            return None
     
                        state = self.plato.get_etat()
     
    class Entrainement:
     
        def __init__(self, lignes, colonnes, episodes):
            self.lignes = lignes
            self.colonnes = colonnes
            self.episodes = episodes
            self.victoires = {1: 0, 2: 0, 'Nulles': 0}
     
            ### Liste vides pour enregistrer les données
     
            self.episodes_list = []
            self.victoires_joueur1 = []
            self.victoires_joueur2 = []
            self.parties_nulles = []
     
        def commencer(self):
            print("Choisissez le mode :")
            print("1. Jouer contre l'IA 1")
            print("2. Jouer contre l'IA 2")
            print("3. IA 1 vs IA 2")
            print("4. Jouer humain contre humain")
            print("5. Entraîner les deux IA entre elles")
            choix = int(input("Votre choix : "))
     
            if choix == 1:
                joueur_humain = Joueur(1, colonnes)
                agent_IA1 = DQNAgent(colonnes * lignes, colonnes)
                joueur_IA1 = IA(2, colonnes, colonnes * lignes, agent_IA1)
                joueurs = [joueur_humain, joueur_IA1]
            elif choix == 2:
                joueur_humain = Joueur(1, colonnes)
                agent_IA2 = DQNAgent(colonnes * lignes, colonnes)
                joueur_IA2 = IA(2, colonnes, colonnes * lignes, agent_IA2)
                joueurs = [joueur_humain, joueur_IA2]
            elif choix == 3:
                agent_IA1 = DQNAgent(colonnes * lignes, colonnes)
                agent_IA2 = DQNAgent(colonnes * lignes, colonnes)
                joueur_IA1 = IA(1, colonnes, colonnes * lignes, agent_IA1)
                joueur_IA2 = IA(2, colonnes, colonnes * lignes, agent_IA2)
                joueurs = [joueur_IA1, joueur_IA2]
            elif choix == 4:
                joueur_humain1 = Joueur(1, colonnes)
                joueur_humain2 = Joueur(2, colonnes)
                joueurs = [joueur_humain1, joueur_humain2]
            elif choix == 5:
                agent_IA1 = DQNAgent(colonnes * lignes, colonnes)
                agent_IA2 = DQNAgent(colonnes * lignes, colonnes)
                joueur_IA1 = IA(1, colonnes, colonnes * lignes, agent_IA1)
                joueur_IA2 = IA(2, colonnes, colonnes * lignes, agent_IA2)
                joueurs = [joueur_IA1, joueur_IA2]
                self.entrainement_IA(joueurs)
                return
            else:
                print("Mode invalide. Veuillez choisir 1, 2, 3, 4 ou 5.")
                return
     
            for i in range(self.episodes):
                print(f"Épisode {i+1}/{self.episodes}")
                jeu = Jeu(lignes, colonnes, joueurs)
                vainqueur = jeu.play()
                if vainqueur is not None:
                    self.victoires[vainqueur] += 1
                else:
                    self.victoires['Nulles'] += 1
                print(f"Taux de victoire Joueur 1 : {self.victoires[1]/(i+1):.2f}")
                print(f"Taux de victoire Joueur 2 : {self.victoires[2]/(i+1):.2f}")
                print(f"Parties nulles : {self.victoires['Nulles']/(i+1):.2f}")
     
        def entrainement_IA(self, joueurs):
     
            for i in range(self.episodes):
                print(f"Épisode {i+1}/{self.episodes}")
                jeu = Jeu(lignes, colonnes, joueurs)
                vainqueur = jeu.play()
                if vainqueur is not None:
                    self.victoires[vainqueur] += 1
                else:
                    self.victoires['Nulles'] += 1
     
                # Sauvefarde des données
                self.episodes_list.append(i + 1)
                self.victoires_joueur1.append(self.victoires[1] / (i + 1))
                self.victoires_joueur2.append(self.victoires[2] / (i + 1))
                self.parties_nulles.append(self.victoires['Nulles'] / (i + 1))
     
                print(f"Taux de victoire Joueur 1 : {self.victoires[1]/(i+1):.2f}")
                print(f"Taux de victoire Joueur 2 : {self.victoires[2]/(i+1):.2f}")
                print(f"Parties nulles : {self.victoires['Nulles']/(i+1):.2f}")
     
        # graphique inutile à changer + tard
            plt.plot(self.episodes_list, self.victoires_joueur1, label='Taux de victoire Joueur 1')
            plt.plot(self.episodes_list, self.victoires_joueur2, label='Taux de victoire Joueur 2')
            plt.plot(self.episodes_list, self.parties_nulles, label='Parties nulles')
            plt.xlabel('Épisodes')
            plt.ylabel('Taux de victoire')
            plt.legend()
            plt.savefig('graphique_evolution.png')
            plt.show()

  2. #2
    Expert éminent
    Homme Profil pro
    Architecte technique retraité
    Inscrit en
    Juin 2008
    Messages
    21 736
    Détails du profil
    Informations personnelles :
    Sexe : Homme
    Localisation : France, Manche (Basse Normandie)

    Informations professionnelles :
    Activité : Architecte technique retraité
    Secteur : Industrie

    Informations forums :
    Inscription : Juin 2008
    Messages : 21 736
    Par défaut
    Salut,
    Citation Envoyé par Judicieusement Voir le message
    n'hésitez pas si vous voyez quelque chose de choquant dans mon code ahah
    Normalement, on ne code pas sans un travail de conception préalable. Le code doit traduire ce travail de conception et... en lisant le code on devrait pouvoir retrouver le découpage décidé dans la conception.

    Citation Envoyé par Judicieusement Voir le message
    Mais surtout je me demande où sauvegarder mon enregistrement de données d'entraînement et où le charger pour qu'il soit efficace ? J'ai un peu de mal.
    Vous avez déjà des méthodes charger_modele et sauvegarder_modele dans la classe DQNAgent qui semblent traduire ces opérations.
    Côté conception, la question est d'abord "quand" (fonction du besoin) puis le découpage (arbitraire) du code en fonctionnalités (fait à la conception) va donner le "où"... appeler ces méthodes.

    note: il y a 2 méthodes charger_modele dans cette classe!

    - W
    Architectures post-modernes.
    Python sur DVP c'est aussi des FAQs, des cours et tutoriels

Discussions similaires

  1. Réponses: 4
    Dernier message: 05/03/2014, 12h35
  2. Réponses: 2
    Dernier message: 03/07/2008, 09h10
  3. Réponses: 1
    Dernier message: 19/04/2007, 09h08
  4. Réponses: 5
    Dernier message: 11/06/2002, 15h21

Partager

Partager
  • Envoyer la discussion sur Viadeo
  • Envoyer la discussion sur Twitter
  • Envoyer la discussion sur Google
  • Envoyer la discussion sur Facebook
  • Envoyer la discussion sur Digg
  • Envoyer la discussion sur Delicious
  • Envoyer la discussion sur MySpace
  • Envoyer la discussion sur Yahoo