1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
|
#!/usr/bin/python3.5
#-*-coding:UTF-8 -*
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import pickle
from tensorflow.examples.tutorials.mnist import input_data
from numpy import *
import tensorflow as tf
import matplotlib.pyplot as plt
# Parsed command-line options. Nothing in this file assigns it, so it stays None unless
# another part of the program sets it; mlb() only reads its ``data_dir`` attribute.
FLAGS = None

# Default checkpoint path: written after training, read back on later calls.
CHEMIN_MODELE_DEFAUT = "modeles/basique/model_basique.ckpt"
# MNIST directory used when neither ``chemin_donnee`` nor FLAGS.data_dir is available.
CHEMIN_DONNEE_DEFAUT = "./MNIST_data"


def _construire_graphe():
    """Build the softmax-regression graph in the current default graph.

    Returns:
        Tuple ``(x, y, y_, train_step)``: the input placeholder of shape (None, 784),
        the logits of shape (None, 10), the one-hot label placeholder of shape
        (None, 10) and the gradient-descent training op.
    """
    x = tf.placeholder(tf.float32, [None, 784])
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    y = tf.matmul(x, W) + b
    y_ = tf.placeholder(tf.float32, [None, 10])
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
    return x, y, y_, train_step


def _entrainer(sess, x, y, y_, train_step, chemin_donnee):
    """Initialize the variables, train on MNIST and print the test-set accuracy.

    Args:
        sess: session bound to the graph built by ``_construire_graphe``.
        x, y, y_, train_step: tensors/op returned by ``_construire_graphe``.
        chemin_donnee: MNIST data directory; if None, FLAGS.data_dir is used when
            available, otherwise CHEMIN_DONNEE_DEFAUT.
    """
    if chemin_donnee is None:
        # getattr tolerates FLAGS still being None (no argument parsing happened).
        chemin_donnee = getattr(FLAGS, "data_dir", None) or CHEMIN_DONNEE_DEFAUT
    mnist = input_data.read_data_sets(chemin_donnee, one_hot=True)
    sess.run(tf.global_variables_initializer())
    # 1000 steps of stochastic gradient descent on mini-batches of 100 images.
    for _ in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    # Check how well the trained model does on the held-out test set.
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print("Test accuracy: %s" % sess.run(
        accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))


def mlb(data, chemin_donnee=None, chemin_modele=CHEMIN_MODELE_DEFAUT):
    """Classify one 28x28 image with a softmax-regression model trained on MNIST.

    The model is restored from ``chemin_modele`` when possible. Otherwise it is trained
    on MNIST and saved there. The image is then displayed, and the predicted digit is
    printed and returned.

    Args:
        data: 784 pixel values (or anything reshapeable to 784 elements, e.g. 28x28).
            Assumed to be scaled to [0, 1] like the ``__main__`` section produces;
            TODO confirm for other callers.
        chemin_donnee: MNIST data directory, only read when training is needed.
        chemin_modele: checkpoint path to restore from and save to.

    Returns:
        Array of shape (1,) holding the predicted class index.

    Raises:
        ValueError: if ``data`` does not contain exactly 784 values.
    """
    import os

    import numpy as np

    image = np.asarray(data, dtype=np.float32).reshape(784)
    # Use a fresh graph per call. Reusing the default graph would add new,
    # auto-renamed variables on every call, so the checkpoint could never be restored.
    with tf.Graph().as_default():
        x, y, y_, train_step = _construire_graphe()
        saver = tf.train.Saver()
        with tf.Session() as sess:
            try:
                saver.restore(sess, chemin_modele)
                print("Model restored.")
            except (tf.errors.OpError, ValueError):
                # Missing or incompatible checkpoint: train from scratch, then save.
                print("fichier non trouver!")
                _entrainer(sess, x, y, y_, train_step, chemin_donnee)
                dossier = os.path.dirname(chemin_modele)
                # Saver.save refuses to write into a directory that does not exist.
                if dossier and not os.path.isdir(dossier):
                    os.makedirs(dossier)
                save_path = saver.save(sess, chemin_modele)
                print("Model saved in file: %s" % save_path)
            result2 = sess.run(tf.argmax(y, 1), feed_dict={x: image[None, :]})
    plt.matshow(image.reshape(28, 28))
    plt.show()
    print('resultat ', result2)
    return result2
if __name__ == "__main__":
    from PIL import Image

    # MNIST directory passed to mlb(); it is only read when no saved model can be restored.
    CHEMIN_MNIST = "./MNIST_data"
    # Test images: five variants each of the digits 0 and 1, processed in the order
    # test/0v1.bmp ... test/0v5.bmp, then test/1v1.bmp ... test/1v5.bmp.
    fichiers = ["test/%dv%d.bmp" % (chiffre, variante)
                for chiffre in (0, 1)
                for variante in range(1, 6)]
    for fichier in fichiers:
        # convert() returns an independent copy, so the file handle can be closed right away.
        with Image.open(fichier) as image:
            gris = image.convert("L")
        # Invert the grayscale values and scale them from 0..255 to 0..1, presumably
        # to match MNIST's light-digit-on-dark-background convention.
        data = (255 - array(gris.getdata())) / 255
        mlb(data, chemin_donnee=CHEMIN_MNIST)
Partager