Bonjour,

Je suis un peu novice sur l utilisation de python et des multi process, je rencontre quelques problèmes pour les variables.

Dans mon script Python (calcul de basqsin versant), j'ai une fonction appelée initializer_worker() qui est responsable de réaliser des calculs pour obtenir deux variables issues de raster : les directions de flux (fdir) et l'accumulation de flux (acc).
Ces variables doivent être calculées une seule fois au début du processus global et ensuite utilisées dans plusieurs processus enfants pour effectuer des calculs de la délimitation de bassins versants Ã* partir de coordonnées xy des stations.

J'ai énormément de paire xy(definit une stations) donc j'ai initié :

Code : Sélectionner tout - Visualiser dans une fenêtre à part
chunks = [stations_df.iloc[i:i + 200] for i in range(0, len(stations_df), 200)]
pour ne pas tt calculé d"un seul coup

si j 'initialise la fonction dans le ProcessPoolExecutor

Code : Sélectionner tout - Visualiser dans une fenêtre à part
1
2
   with ProcessPoolExecutor(max_workers=8, initializer=initializer_worker, initargs=(elevation,)) as executor:
        futures = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
J'ai bien mes calculs mais par contre initializer_worker est lancé plusieurs fois et c'est ce qui demande le plus de temps dans le traitement

et si j'essaye de lui donner ttes le variables de cette fonction

Code : Sélectionner tout - Visualiser dans une fenêtre à part
1
2
      with ProcessPoolExecutor(max_workers=7) as executor:
        futures = {executor.submit(process_chunk, dirmap, fdir, acc, chunk): chunk for chunk in chunks}
j ai des retour d erreur 'NoneType' object has no attribute 'mask' donc les variables ne sont pas transmises.

J insérée mon code ici et toute aide, lien est la bienvenue ( désolé si il y a de choses grossière en python...)


Code : Sélectionner tout - Visualiser dans une fenêtre à part
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
 
import time
import os
import pandas as pd
import geopandas as gpd
from shapely import geometry, ops
import rasterio
from pysheds.grid import Grid
from concurrent.futures import ProcessPoolExecutor, as_completed
from export_logging import create_logger
 
 
 
 
#répertoires et du logger
dossier_script = os.path.dirname(os.path.abspath(__file__))
dossier_log = os.path.join(dossier_script, 'log')
os.makedirs(dossier_log, exist_ok=True)
log = create_logger(os.path.join(dossier_log, 'transfert_process.log'))
elevation = r'....fr_75m_wgs84.tif'
 
stations = pd.read_csv(
    r'....stations.csv',
    sep=';')
 
def lecture_x_y(listexy):
    '''lecture de coordonnees geo (xy)'''
    stations_df = listexy[['pop_id', 'x', 'y']].drop_duplicates()
    stations_df = stations_df.head(200)
 
    print(time.strftime('%H:%M:%S'), 'lecture stations')
 
    return stations_df
 
def initializer_worker(elevation):
 
    #global grid, dem, fdir, acc
    grid = Grid.from_raster(elevation)
    dem = grid.read_raster(elevation)
    pit_filled_dem = grid.fill_pits(dem)
    flooded_dem = grid.fill_depressions(pit_filled_dem)
    inflated_dem = grid.resolve_flats(flooded_dem)
    dirmap = (64, 128, 1, 2, 4, 8, 16, 32)
    fdir = grid.flowdir(inflated_dem, dirmap=dirmap)
    acc = grid.accumulation(fdir, dirmap=dirmap)
    if acc is None:
        print("Erreur : 'acc' est None après calcul")
    print("process initialisation des parametre de direction et d'accumulation des flux... ok")
    return dirmap, fdir, acc
 
 
def delineate_catchment(x, y, station_id, acc, dirmap, fdir):
 
    """ fonction de délimitation des bassins versants """
    log_file = os.path.join(dossier_log, f'log_station_{station_id}.log')
    log = create_logger(log_file)
    grid = Grid.from_raster(elevation)
    try:
        with rasterio.Env():
            # on permet au bassin versant de coller au xy
            x_snap, y_snap = grid.snap_to_mask(acc > 1200, (x, y))
 
            # on vérifie que la station est Ã* l'intérieur du raster
            if x_snap is None or y_snap is None:
                log.error(f"Station {station_id} en dehors de l'emprise du raster.")
                print(f"Station {station_id} en dehors de l'emprise du raster.")
                return None, station_id
            # print ('stations ok')
 
            # calcul du bassin
            catch = grid.catchment(x=x_snap, y=y_snap, fdir=fdir, dirmap=dirmap, xytype='coordinate')
            # on decoupe
            grid.clip_to(catch)
            # on trasforme en polygones
            shapes = grid.polygonize()
 
            # vérifiesi le bassin versant est vide
            if not shapes:
                print(f"Bassin versant vide ou trop petit pour la station {station_id}")
                log.warning(f"Bassin versant vide ou trop petit pour la station {station_id}")
                return None, station_id
            # print('shape ok')
            catchment_polygon = ops.unary_union([geometry.shape(shape) for shape, value in shapes])
 
            # Vérifie si le polygone retourné est vide
            if catchment_polygon.is_empty:
                print(f"Bassin versant trop petit pour la station {station_id}")
                log.warning(f"Bassin versant trop petit pour la station {station_id}")
                return None, station_id
 
            log.info(f"Bassin versant délimité pour la station {station_id}")
            return catchment_polygon, station_id
    except Exception as e:
        log.error(f"Erreur pour la station {station_id}: {e}")
        import traceback, sys
        tb = sys.exc_info()[2]
        print("erreur ligne {0} -  {1} \n".format(tb.tb_lineno, e))
        return None, station_id
#
def process_chunk(dirmap, fdir, acc, chunk):
 
    """Traitement par lots de stations"""
 
    results = []
    for row in chunk.itertuples():
        try:
            result = delineate_catchment(row.x, row.y, row.pop_id,acc, dirmap, fdir)
            if result[0] is not None:
                results.append(result)
        except Exception as e:
            log.error(f"Erreur lors du traitement de la station {row.pop_id}: {e}")
    return results
 
 
if __name__ == '__main__':
    start = time.time()
 
 
    log.info("Début")
    print(time.strftime('%H:%M:%S'), 'Début')
    dirmap, fdir, acc = initializer_worker(elevation)
 
    stations_df = lecture_x_y(stations)
    # diviser les stations en morceaux pour le traitement par lots
    chunks = [stations_df.iloc[i:i + 200] for i in range(0, len(stations_df), 200)]
 
    all_catchment_polygons = []
    all_station_ids = []
 
 
    print(time.strftime('%H:%M:%S'), 'envoi')
 
    with ProcessPoolExecutor(max_workers=7) as executor:
        futures = {executor.submit(process_chunk, dirmap, fdir, acc, chunk): chunk for chunk in chunks}
 
        for future in as_completed(futures):
            chunk_index = list(futures.keys()).index(future)
            try:
                result = future.result()
                if result:
                    for catchment_polygon, station_id in result:
                        all_catchment_polygons.append(catchment_polygon)
                        all_station_ids.append(station_id)
                print(f"Chunk {chunk_index + 1}/{len(chunks)} processed")
            except Exception as e:
                print(f"Error processing chunk {chunk_index + 1}: {e}")
 
 
    gdf = gpd.GeoDataFrame({'pop_id': all_station_ids, 'geometry': all_catchment_polygons}, crs='EPSG:4326')
    print(time.strftime('%H:%M:%S', time.gmtime(time.time() - start)))
 
    dossier_resultat = os.path.join(dossier_script, 'resultat')
    os.makedirs(dossier_resultat, exist_ok=True)
    output_shapefile = os.path.join(dossier_script, 'resultat/bv_75m.shp')
    gdf.to_file(output_shapefile)
    log.info('Fin')
    print(time.strftime('%H:%M:%S'), 'fin')