1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
|
# CUDA kernel (cupyx.jit): clean a disc of a 2-D complex plane, in place.
#
# Launch with one thread per pixel over a flat index space of size_x * size_y
# threads (extra threads in the last block do nothing). For each pixel (ii, jj):
#   * if its Euclidean distance to (posX, posY) is strictly less than
#     clean_radius_pix, the element is overwritten with replace_cplx_value;
#   * otherwise the element is replaced by its modulus, stored as a complex
#     number with zero imaginary part.
#
# Parameters:
#   d_plan_cplx        -- 2-D device array indexed [ii, jj], ii in [0, size_x),
#                         jj in [0, size_y); presumably complex64 (elements are
#                         read through cp.complex64) -- confirm with caller.
#   size_x, size_y     -- extents along ii and jj respectively.
#   posX, posY         -- disc centre, compared against ii and jj.
#   clean_radius_pix   -- disc radius, same units as ii/jj (pixels).
#   replace_cplx_value -- value written inside the disc.
@jit.rawkernel()
def clean_plan_cplx_device(d_plan_cplx, size_x, size_y, posX, posY, clean_radius_pix, replace_cplx_value):
    index = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x
    # Split the flat thread index into (ii, jj); ii varies fastest.
    jj = cp.int32(index) // cp.int32(size_x)
    ii = cp.int32(index) - cp.int32(jj * size_x)
    # ii < size_x holds by construction; the jj test discards surplus threads.
    if ii < size_x and jj < size_y:
        # NOTE(review): neighbouring threads step along the FIRST index (ii).
        # If d_plan_cplx is C-ordered with shape (size_x, size_y) these
        # accesses are strided (not coalesced); if its shape is
        # (size_y, size_x) the [ii, jj] indexing is transposed. Confirm the
        # layout against the caller before changing the index order.
        distance = cp.sqrt((posX - ii)**2 + (posY - jj)**2)
        if distance < clean_radius_pix:
            # Honour the caller-supplied replacement value (it was ignored
            # in favour of a hard-coded 0 before).
            d_plan_cplx[ii, jj] = replace_cplx_value
        else:
            # Modulus is only needed outside the disc.
            cplx = cp.complex64(d_plan_cplx[ii, jj])
            mod = cp.sqrt(cp.real(cplx)**2 + cp.imag(cplx)**2)
            d_plan_cplx[ii, jj] = mod + 0j
def clean_plan_cplx(d_plan_cplx, size_x, size_y, posX, posY, clean_radius_pix, replace_value):
    """Launch ``clean_plan_cplx_device`` on ``d_plan_cplx`` with one thread per pixel.

    The grid is 1-D, with 1024 threads per block and enough blocks to cover
    all ``size_x * size_y`` pixels. The kernel updates ``d_plan_cplx`` in
    place; nothing is returned.

    Args:
        d_plan_cplx: 2-D device array handed to the kernel unchanged.
        size_x: Plane extent along the first kernel index.
        size_y: Plane extent along the second kernel index.
        posX: Centre coordinate forwarded to the kernel.
        posY: Centre coordinate forwarded to the kernel.
        clean_radius_pix: Radius in pixels, forwarded to the kernel.
        replace_value: Forwarded to the kernel as ``replace_cplx_value``.
    """
    nthread = 1024
    n_pixels = size_x * size_y
    if n_pixels <= 0:
        # Nothing to do, and CUDA rejects a launch with zero blocks.
        return
    # Ceiling division: floor division here would drop the last partial
    # block (the trailing n_pixels % nthread pixels) and launch 0 blocks
    # for planes smaller than nthread.
    nBlock = (n_pixels + nthread - 1) // nthread
    clean_plan_cplx_device[nBlock, nthread](
        d_plan_cplx, size_x, size_y, posX, posY, clean_radius_pix, replace_value
    )
Partager