IdentifiantMot de passe
Loading...
Mot de passe oublié ?Je m'inscris ! (gratuit)
Navigation

Inscrivez-vous gratuitement
pour pouvoir participer, suivre les réponses en temps réel, voter pour les messages, poser vos propres questions et recevoir la newsletter

Python Discussion :

les arbres de décision


Sujet :

Python

  1. #1
    Nouveau membre du Club
    Homme Profil pro
    Étudiant
    Inscrit en
    Novembre 2016
    Messages
    98
    Détails du profil
    Informations personnelles :
    Sexe : Homme
    Localisation : France

    Informations professionnelles :
    Activité : Étudiant
    Secteur : Enseignement

    Informations forums :
    Inscription : Novembre 2016
    Messages : 98
    Points : 26
    Points
    26
    Par défaut les arbres de décision
    salut a tous.

    j'ai un programme qui implémente l'arbre de décision, la construction d'arbre en forme string ça marche bien, mais y a un problème lors de l'affichage de l'arbre avec une fonction dotgraph().
    pouvez vous m'aider, et merci d'avance.

    erreur:
    Code : Sélectionner tout - Visualiser dans une fenêtre à part
    1
    2
    3
    4
    5
     
    line 276, in dotgraph
        p_node = dcParent[szSplit]
     
    KeyError: '3-PetalLength >= 4.9'
    programme:

    Code : Sélectionner tout - Visualiser dans une fenêtre à part
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    193
    194
    195
    196
    197
    198
    199
    200
    201
    202
    203
    204
    205
    206
    207
    208
    209
    210
    211
    212
    213
    214
    215
    216
    217
    218
    219
    220
    221
    222
    223
    224
    225
    226
    227
    228
    229
    230
    231
    232
    233
    234
    235
    236
    237
    238
    239
    240
    241
    242
    243
    244
    245
    246
    247
    248
    249
    250
    251
    252
    253
    254
    255
    256
    257
    258
    259
    260
    261
    262
    263
    264
    265
    266
    267
    268
    269
    270
    271
    272
    273
    274
    275
    276
    277
    278
    279
    280
    281
    282
    283
    284
    285
    286
    287
    288
    289
    290
    291
    292
    293
    294
    295
    296
    297
    298
    299
    300
    301
    302
    303
    304
    305
    306
    307
    308
    309
    310
    311
    312
    313
    314
    315
    316
    317
    318
    319
    320
    321
    322
    323
    324
    325
    326
    327
    328
    329
    330
    331
    332
    333
    334
    335
    336
    337
    338
    339
    340
    341
    342
    343
    344
    345
    346
    347
     
     
    import csv
    from collections import defaultdict
    import pydotplus
     
     
     
     
     
    class DecisionTree:
        """Binary tree implementation with true and false branch. """
        def __init__(self, col=-1, value=None, trueBranch=None, falseBranch=None, results=None, summary=None):
            self.col = col
            self.value = value
            self.trueBranch = trueBranch
            self.falseBranch = falseBranch
            self.results = results # None for nodes, not None for leaves
            self.summary = summary
     
     
    def divideSet(rows, column, value):
        splittingFunction = None
        if isinstance(value, int) or isinstance(value, float): # for int and float values
            splittingFunction = lambda row : row[column] >= value
        else: # for strings
            splittingFunction = lambda row : row[column] == value
        list1 = [row for row in rows if splittingFunction(row)]
        list2 = [row for row in rows if not splittingFunction(row)]
        return (list1, list2)
     
     
    def uniqueCounts(rows):
        results = {}
        for row in rows:
            #response variable is in the last column
            r = row[-1]
            if r not in results: results[r] = 0
            results[r] += 1
        return results
     
     
    def entropy(rows):
        from math import log
        log2 = lambda x: log(x)/log(2)
        results = uniqueCounts(rows)
     
        entr = 0.0
        for r in results:
            p = float(results[r])/len(rows)
            entr -= p*log2(p)
        return entr
     
     
    def gini(rows):
        total = len(rows)
        counts = uniqueCounts(rows)
        imp = 0.0
     
        for k1 in counts:
            p1 = float(counts[k1])/total
            for k2 in counts:
                if k1 == k2: continue
                p2 = float(counts[k2])/total
                imp += p1*p2
        return imp
     
     
    def variance(rows):
        if len(rows) == 0: return 0
        data = [float(row[len(row) - 1]) for row in rows]
        mean = sum(data) / len(data)
     
        variance = sum([(d-mean)**2 for d in data]) / len(data)
        return variance
     
     
    def growDecisionTreeFrom(rows, evaluationFunction=entropy):
        """Grows and then returns a binary decision tree.
        evaluationFunction: entropy or gini"""
     
        if len(rows) == 0: return DecisionTree()
        currentScore = evaluationFunction(rows)
     
        bestGain = 0.0
        bestAttribute = None
        bestSets = None
     
        columnCount = len(rows[0]) - 1  # last column is the result/target column
        for col in range(0, columnCount):
            columnValues = [row[col] for row in rows]
     
            #unique values
            lsUnique = list(set(columnValues))
     
            for value in lsUnique:
                (set1, set2) = divideSet(rows, col, value)
     
                # Gain -- Entropy or Gini
                p = float(len(set1)) / len(rows)
                gain = currentScore - p*evaluationFunction(set1) - (1-p)*evaluationFunction(set2)
                if gain>bestGain and len(set1)>0 and len(set2)>0:
                    bestGain = gain
                    bestAttribute = (col, value)
                    bestSets = (set1, set2)
     
        dcY = {'impurity' : '%.3f' % currentScore, 'samples' : '%d' % len(rows)}
        if bestGain > 0:
            trueBranch = growDecisionTreeFrom(bestSets[0], evaluationFunction)
            falseBranch = growDecisionTreeFrom(bestSets[1], evaluationFunction)
            return DecisionTree(col=bestAttribute[0], value=bestAttribute[1], trueBranch=trueBranch,
                                falseBranch=falseBranch, summary=dcY)
        else:
            return DecisionTree(results=uniqueCounts(rows), summary=dcY)
     
     
    def prune(tree, minGain, evaluationFunction=entropy, notify=False):
        """Prunes the obtained tree according to the minimal gain (entropy or Gini). """
        # recursive call for each branch
        if tree.trueBranch.results == None: prune(tree.trueBranch, minGain, evaluationFunction, notify)
        if tree.falseBranch.results == None: prune(tree.falseBranch, minGain, evaluationFunction, notify)
     
        # merge leaves (potentionally)
        if tree.trueBranch.results != None and tree.falseBranch.results != None:
            tb, fb = [], []
     
            for v, c in tree.trueBranch.results.items(): tb += [[v]] * c
            for v, c in tree.falseBranch.results.items(): fb += [[v]] * c
     
            p = float(len(tb)) / len(tb + fb)
            delta = evaluationFunction(tb+fb) - p*evaluationFunction(tb) - (1-p)*evaluationFunction(fb)
            if delta < minGain:
                if notify: print('A branch was pruned: gain = %f' % delta)
                tree.trueBranch, tree.falseBranch = None, None
                tree.results = uniqueCounts(tb + fb)
     
     
    def classify(observations, tree, dataMissing=False):
        """Classifies the observationss according to the tree.
        dataMissing: true or false if data are missing or not. """
     
        def classifyWithoutMissingData(observations, tree):
            if tree.results != None:  # leaf
                return tree.results
            else:
                v = observations[tree.col]
                branch = None
                if isinstance(v, int) or isinstance(v, float):
                    if v >= tree.value: branch = tree.trueBranch
                    else: branch = tree.falseBranch
                else:
                    if v == tree.value: branch = tree.trueBranch
                    else: branch = tree.falseBranch
            return classifyWithoutMissingData(observations, branch)
     
     
        def classifyWithMissingData(observations, tree):
            if tree.results != None:  # leaf
                return tree.results
            else:
                v = observations[tree.col]
                if v == None:
                    tr = classifyWithMissingData(observations, tree.trueBranch)
                    fr = classifyWithMissingData(observations, tree.falseBranch)
                    tcount = sum(tr.values())
                    fcount = sum(fr.values())
                    tw = float(tcount)/(tcount + fcount)
                    fw = float(fcount)/(tcount + fcount)
                    result = defaultdict(int) # Problem description: http://blog.ludovf.net/python-collections-defaultdict/
                    for k, v in tr.items(): result[k] += v*tw
                    for k, v in fr.items(): result[k] += v*fw
                    return dict(result)
                else:
                    branch = None
                    if isinstance(v, int) or isinstance(v, float):
                        if v >= tree.value: branch = tree.trueBranch
                        else: branch = tree.falseBranch
                    else:
                        if v == tree.value: branch = tree.trueBranch
                        else: branch = tree.falseBranch
                return classifyWithMissingData(observations, branch)
     
        # function body
        if dataMissing:
            return classifyWithMissingData(observations, tree)
        else:
            return classifyWithoutMissingData(observations, tree)
     
     
    def plot(decisionTree):
        """Plots the obtained decision tree. """
        def toString(decisionTree, indent=''):
            if decisionTree.results != None:  # leaf node
                return str(decisionTree.results)
            else:
                szCol = 'Column %s' % decisionTree.col
                if szCol in dcHeadings:
                    szCol = dcHeadings[szCol]
                if isinstance(decisionTree.value, int) or isinstance(decisionTree.value, float):
                    decision = '%s >= %s?' % (szCol, decisionTree.value)
                else:
                    decision = '%s == %s?' % (szCol, decisionTree.value)
                trueBranch = indent + 'yes -> ' + toString(decisionTree.trueBranch, indent + '\t\t')
                falseBranch = indent + 'no  -> ' + toString(decisionTree.falseBranch, indent + '\t\t')
                return (decision + '\n' + trueBranch + '\n' + falseBranch)
     
        print(toString(decisionTree))
     
     
    def dotgraph(decisionTree):
        global dcHeadings
        dcNodes = defaultdict(list)
        """Plots the obtained decision tree. """
        def toString(iSplit, decisionTree, bBranch, szParent = "null", indent=''):
            if decisionTree.results != None:  # leaf node
                lsY = []
                for szX, n in decisionTree.results.items():
                        lsY.append('%s:%d' % (szX, n))
                dcY = {"name": "%s" % ', '.join(lsY), "parent" : szParent}
                dcSummary = decisionTree.summary
                dcNodes[iSplit].append(['leaf', dcY['name'], szParent, bBranch, dcSummary['impurity'],
                                        dcSummary['samples']])
                return dcY
            else:
                szCol = 'Column %s' % decisionTree.col
                if szCol in dcHeadings:
                        szCol = dcHeadings[szCol]
                if isinstance(decisionTree.value, int) or isinstance(decisionTree.value, float):
                        decision = '%s >= %s' % (szCol, decisionTree.value)
                else:
                        decision = '%s == %s' % (szCol, decisionTree.value)
                trueBranch = toString(iSplit+1, decisionTree.trueBranch, True, decision, indent + '\t\t')
                falseBranch = toString(iSplit+1, decisionTree.falseBranch, False, decision, indent + '\t\t')
                dcSummary = decisionTree.summary
                dcNodes[iSplit].append([iSplit+1, decision, szParent, bBranch, dcSummary['impurity'],
                                        dcSummary['samples']])
                return
     
        toString(0, decisionTree, None)
        lsDot = ['digraph Tree {',
                    'node [shape=box, style="filled, rounded", color="black", fontname=helvetica] ;',
                    'edge [fontname=helvetica] ;'
        ]
        i_node = 0
        dcParent = {}
        for nSplit, lsY in dcNodes.items():
            for lsX in lsY:
                iSplit, decision, szParent, bBranch, szImpurity, szSamples =lsX
                if type(iSplit) == int:
                    szSplit = '%d-%s' % (iSplit, decision)
                    dcParent[szSplit] = i_node
                    lsDot.append('%d [label=<%s<br/>impurity %s<br/>samples %s>, fillcolor="#e5813900"] ;' % (i_node,
                                            decision.replace('>=', '&ge;').replace('?', ''),
                                            szImpurity,
                                            szSamples))
                else:
                    lsDot.append('%d [label=<impurity %s<br/>samples %s<br/>class %s>, fillcolor="#e5813900"] ;' % (i_node,
                                            szImpurity,
                                            szSamples,
                                            decision))
     
                if szParent != 'null':
                    if bBranch:
                        szAngle = '45'
                        szHeadLabel = 'True'
                    else:
                        szAngle = '-45'
                        szHeadLabel = 'False'
                    szSplit = '%d-%s' % (nSplit, szParent)
                    p_node = dcParent[szSplit]
                    if nSplit == 1:
                        lsDot.append('%d -> %d [labeldistance=2.5, labelangle=%s, headlabel="%s"] ;' % (p_node,
                                                            i_node, szAngle, szHeadLabel))
                    else:
                        lsDot.append('%d -> %d ;' % (p_node, i_node))
                i_node += 1
        lsDot.append('}')
        dot_data = '\n'.join(lsDot)
        return dot_data
     
     
    def loadCSV(file):
        """Loads a CSV file and converts all floats and ints into basic datatypes."""
        def convertTypes(s):
            s = s.strip()
            try:
                return float(s) if '.' in s else int(s)
            except ValueError:
                return s
     
        reader = csv.reader(open(file, 'rt'))
        dcHeader = {}
        if bHeader:
            lsHeader = next(reader)
            for i, szY in enumerate(lsHeader):
                    szCol = 'Column %d' % i
                    dcHeader[szCol] = str(szY)
        return dcHeader, [[convertTypes(item) for item in row] for row in reader]
     
     
    if __name__ == '__main__':
     
        # Select the example you want to classify
        example = 2
     
        # All examples do the following steps:
        #   1. Load training data
        #   2. Let the decision tree grow
        #   4. Plot the decision tree
        #   5. classify without missing data
        #   6. Classifiy with missing data
        #   (7.) Prune the decision tree according to a minimal gain level
        #   (8.) Plot the pruned tree
     
        if example == 1:
            # the smaller examples
            bHeader = False
            dcHeadings, trainingData = loadCSV('tbc.csv') # sorry for not translating the TBC and pneumonia symptoms
            decisionTree = growDecisionTreeFrom(trainingData)
            #decisionTree = growDecisionTreeFrom(trainingData, evaluationFunction=gini) # with gini
            result = plot(decisionTree)
            print(result)
            dot_data = dotgraph(decisionTree)
            graph = pydotplus.graph_from_dot_data(dot_data)
            graph.write_pdf("tbc.pdf")
            graph.write_png("tbc.png")
     
            print(classify(['ohne', 'leicht', 'Streifen', 'normal', 'normal'], decisionTree, dataMissing=False))
            print(classify([None, 'leicht', None, 'Flocken', 'fiepend'], decisionTree, dataMissing=True)) # no longer unique
     
            # Don' forget if you compare the resulting tree with the tree in my presentation: here it is a binary tree!
     
        else:
            bHeader = True
            # the bigger example
            dcHeadings, trainingData = loadCSV('fishiris.csv') # demo data from matlab
            decisionTree = growDecisionTreeFrom(trainingData, evaluationFunction=gini)
            prune(decisionTree, 0.8, notify=True) # notify, when a branch is pruned (one time in this example)
            result = plot(decisionTree)
            print(result)
            dot_data = dotgraph(decisionTree)
            graph = pydotplus.graph_from_dot_data(dot_data)
            graph.write_pdf("iris.pdf")
            graph.write_png("iris.png")
     
            print(classify([6.0, 2.2, 5.0, 1.5], decisionTree)) # dataMissing=False is the default setting
            print(classify([None, None, None, 1.5], decisionTree, dataMissing=True)) # no longer unique

    data:

    fishiris.zip

  2. #2
    Membre émérite

    Homme Profil pro
    Ingénieur calcul scientifique
    Inscrit en
    Mars 2013
    Messages
    1 229
    Détails du profil
    Informations personnelles :
    Sexe : Homme
    Localisation : France, Alpes Maritimes (Provence Alpes Côte d'Azur)

    Informations professionnelles :
    Activité : Ingénieur calcul scientifique

    Informations forums :
    Inscription : Mars 2013
    Messages : 1 229
    Points : 2 328
    Points
    2 328
    Par défaut
    Salut

    Il faut que tu construises un exemple minimal pour présenter ton problème.
    Car là personne n'ira lire tout ce code ...

    NB : Ca arrive souvent qu'en voulant construire un exemple minimal, on comprends où est l'erreur

  3. #3
    Membre averti
    Profil pro
    Inscrit en
    Octobre 2005
    Messages
    788
    Détails du profil
    Informations personnelles :
    Localisation : France

    Informations forums :
    Inscription : Octobre 2005
    Messages : 788
    Points : 446
    Points
    446
    Par défaut
    Bonjour

    Sans lire tous le code, l'erreur indique que tu cherches à atteindre une valeur dans un dictionnaire à travers une clé qui n'hésite pas

    La solution simple à ce genre de problème est de tester la présence de la clé dans le dictionnaire. (if ou méthode get en fonction de tes besoins)
    Le savoir est une arme alors soyons armés

  4. #4
    Nouveau membre du Club
    Homme Profil pro
    Étudiant
    Inscrit en
    Novembre 2016
    Messages
    98
    Détails du profil
    Informations personnelles :
    Sexe : Homme
    Localisation : France

    Informations professionnelles :
    Activité : Étudiant
    Secteur : Enseignement

    Informations forums :
    Inscription : Novembre 2016
    Messages : 98
    Points : 26
    Points
    26
    Par défaut
    salut..

    j'ai fait le teste sur la clé mais ça cause un problème, les données ne s'affiche pas complètement,

  5. #5
    Membre émérite

    Homme Profil pro
    Ingénieur calcul scientifique
    Inscrit en
    Mars 2013
    Messages
    1 229
    Détails du profil
    Informations personnelles :
    Sexe : Homme
    Localisation : France, Alpes Maritimes (Provence Alpes Côte d'Azur)

    Informations professionnelles :
    Activité : Ingénieur calcul scientifique

    Informations forums :
    Inscription : Mars 2013
    Messages : 1 229
    Points : 2 328
    Points
    2 328
    Par défaut
    Ok alors présente nous un exemple minimal avec juste le dictionnaire, et l'accès à la clé qui pose problème.

Discussions similaires

  1. [2008R2] [Datamining] Aucune donnée dans les arbres de décision
    Par mgesche dans le forum SSAS
    Réponses: 1
    Dernier message: 12/12/2012, 15h00
  2. Problèmes de pointeurs avec les arbres
    Par thierry57 dans le forum C
    Réponses: 17
    Dernier message: 22/12/2005, 23h35
  3. Tutoriel sur les arbres
    Par emidelphi77 dans le forum Langage
    Réponses: 2
    Dernier message: 09/10/2005, 23h09
  4. [LG]Les Arbres
    Par SaladinDev dans le forum Langage
    Réponses: 6
    Dernier message: 08/03/2005, 11h51
  5. Recherche documentation sur les arbres
    Par Oberown dans le forum Algorithmes et structures de données
    Réponses: 2
    Dernier message: 22/09/2004, 01h40

Partager

Partager
  • Envoyer la discussion sur Viadeo
  • Envoyer la discussion sur Twitter
  • Envoyer la discussion sur Google
  • Envoyer la discussion sur Facebook
  • Envoyer la discussion sur Digg
  • Envoyer la discussion sur Delicious
  • Envoyer la discussion sur MySpace
  • Envoyer la discussion sur Yahoo