Bonjour,
Je souhaiterais une discussion et un retour de votre part. Ceci d'une part pour savoir si je me plante et d'autre part pour améliorer mon code...
Présentation :
Durant cette période de confinement j'ai décidé de me plonger dans le domaine des réseaux de neurones appliqués à la détection de chiffres à partir du tout début.
J'ai donc codé en Java un Perceptron, un réseau de neurones (propagation et rétropropagation du gradient de l'erreur). J'ai implémenté les fonctions d'activation de type sigmoid, tanh et PreRELU (RELU pas réussi à le faire fonctionner : le réseau se bloque de suite...) et j'ai bien galéré pour appréhender les hyper paramètres.
Bref, à ce jour ça marche ! J'utilise la base d'images MNIST et j'ai obtenu un score de 95,4%. J'aurais pu avoir plus mais j'ai arrêté le prog au bout de 8h...
Je trouve l'apprentissage long : pour parcourir 30 000 images par exemple je mets 1mn ... sachant qu'un tiers du temps correspond à la propagation et le reste à la rétropropagation.
Mon réseau à 784 entrées (28*28), une couche d'entrée de 784 neurones, une couche cachée de 100 neurones et une couche de sortie de 10 neurones
Mon PC est un I7 4600U @2.10 Gz (4 core)
Trouvez vous ce temps long, trop long ou raisonnable ?
Mon deuxième point maintenant : la recherche d'amélioration du temps d'execution
J'ai voulu essayer de profiter pleinement du processeur. Seulement utilisé qu'à 25% au max, j'ai décidé de commencer à optimiser la partie "propagation". J'ai donc lancé mes calculs par des Threads.
Et là comble de l'histoire :
- le temps d'execution est plus long,
- le processeur est utilisé à 100%.
Après avoir testé les threads simplement (temps catastrophiques) j'ai ensuite utilisé les services "Executor" qui permet de n’exécuter que 4 thread à la fois (1 par proc). Ce qui a très nettement amélioré les temps mais pas assez. J'arrive à des temps équivalent que lorsque j'ai des boucles de plus de 10000 itérations pour un Perceptron. Ce qui en pratique n'arrive jamais dans une couche de 800 neurones... J'ai l'impression que la mise en oeuvre du mecanisme des Threads est trop longue et fait perdre du temps d’exécution....
D’où ma deuxième série de questions :
- existe-t-il en java une autre manière de gérer des Thread (pour des petits calculs) ?
- d'une manière plus globale, existe-t-il une manière de calculer une propagation (et une rétropropagtion) adaptée aux threads ?
Merci pour vos retours d'expériences !
Bien cordialement,
Xavier
Partager