diff --git a/ext/cuda/remapping_interpolate_array.jl b/ext/cuda/remapping_interpolate_array.jl index 622efc70fb..322c513cd9 100644 --- a/ext/cuda/remapping_interpolate_array.jl +++ b/ext/cuda/remapping_interpolate_array.jl @@ -46,10 +46,11 @@ function interpolate_slab_kernel!( weights::AbstractArray{Tuple{A, A}}, ) where {A} index = threadIdx().x + (blockIdx().x - 1) * blockDim().x + + index <= length(output_array) || return nothing space = axes(field) FT = Spaces.undertype(space) - - if index <= length(output_array) + @inbounds begin I1, I2 = weights[index] Nq1, Nq2 = length(I1), length(I2) @@ -74,10 +75,11 @@ function interpolate_slab_kernel!( weights::AbstractArray{Tuple{A}}, ) where {A} index = threadIdx().x + (blockIdx().x - 1) * blockDim().x - space = axes(field) - FT = Spaces.undertype(space) - if index <= length(output_array) + index <= length(output_array) || return nothing + @inbounds begin + space = axes(field) + FT = Spaces.undertype(space) I1, = weights[index] Nq = length(I1) @@ -130,11 +132,12 @@ function interpolate_slab_level_kernel!( (I1, I2)::Tuple{<:AbstractArray, <:AbstractArray}, ) index = threadIdx().x + (blockIdx().x - 1) * blockDim().x - space = axes(field) - FT = Spaces.undertype(space) - Nq1, Nq2 = length(I1), length(I2) - if index <= length(vidx_ref_coordinates) + index <= length(vidx_ref_coordinates) || return nothing + @inbounds begin + space = axes(field) + FT = Spaces.undertype(space) + Nq1, Nq2 = length(I1), length(I2) v_lo, v_hi, ξ3 = vidx_ref_coordinates[index] f_lo = zero(FT) @@ -165,11 +168,13 @@ function interpolate_slab_level_kernel!( (I1,)::Tuple{<:AbstractArray}, ) index = threadIdx().x + (blockIdx().x - 1) * blockDim().x - space = axes(field) - FT = Spaces.undertype(space) - Nq = length(I1) - if index <= length(vidx_ref_coordinates) + index <= length(vidx_ref_coordinates) || return nothing + @inbounds begin + space = axes(field) + FT = Spaces.undertype(space) + Nq = length(I1) + v_lo, v_hi, ξ3 = vidx_ref_coordinates[index] f_lo = zero(FT) diff --git a/src/Remapping/distributed_remapping.jl b/src/Remapping/distributed_remapping.jl index b6753f8f56..096825fb8d 100644 --- a/src/Remapping/distributed_remapping.jl +++ b/src/Remapping/distributed_remapping.jl @@ -685,7 +685,7 @@ function _set_interpolated_values_device!( for (field_index, field) in enumerate(fields) field_values = Fields.field_values(field) - for (out_index, h) in enumerate(local_horiz_indices) + @inbounds for (out_index, h) in enumerate(local_horiz_indices) out[out_index, field_index] = zero(FT) if hdims == 2 for j in 1:Nq, i in 1:Nq diff --git a/src/Remapping/interpolate_array.jl b/src/Remapping/interpolate_array.jl index c339a9cdd4..5751167715 100644 --- a/src/Remapping/interpolate_array.jl +++ b/src/Remapping/interpolate_array.jl @@ -29,7 +29,7 @@ function interpolate_slab!( space = axes(field) FT = Spaces.undertype(space) - for index in 1:length(output_array) + @inbounds for index in 1:length(output_array) (I1, I2) = weights[index] Nq1, Nq2 = length(I1), length(I2) @@ -56,7 +56,7 @@ function interpolate_slab!( space = axes(field) FT = Spaces.undertype(space) - for index in 1:length(output_array) + @inbounds for index in 1:length(output_array) (I1,) = weights[index] Nq = length(I1) @@ -178,7 +178,7 @@ function interpolate_slab_level!( FT = Spaces.undertype(space) Nq1, Nq2 = length(I1), length(I2) - for index in 1:length(vidx_ref_coordinates) + @inbounds for index in 1:length(vidx_ref_coordinates) v_lo, v_hi, ξ3 = vidx_ref_coordinates[index] f_lo = zero(FT) @@ -213,7 +213,7 @@ function interpolate_slab_level!( FT = Spaces.undertype(space) Nq = length(I1) - for index in 1:length(vidx_ref_coordinates) + @inbounds for index in 1:length(vidx_ref_coordinates) v_lo, v_hi, ξ3 = vidx_ref_coordinates[index] f_lo = zero(FT) @@ -274,7 +274,7 @@ function interpolate_array( vertical_indices_ref_coordinates = [vertical_indices_ref_coordinate(space, zcoord) for zcoord in zpts] - for (ix, xcoord) in enumerate(xpts) + @inbounds for (ix, xcoord) in enumerate(xpts) hcoord = xcoord helem = Meshes.containing_element(horz_mesh, hcoord) quad = Spaces.quadrature_style(space) @@ -313,7 +313,9 @@ function interpolate_array( vertical_indices_ref_coordinates = [vertical_indices_ref_coordinate(space, zcoord) for zcoord in zpts] - for (iy, ycoord) in enumerate(ypts), (ix, xcoord) in enumerate(xpts) + @inbounds for (iy, ycoord) in enumerate(ypts), + (ix, xcoord) in enumerate(xpts) + hcoord = Geometry.product_coordinates(xcoord, ycoord) helem = Meshes.containing_element(horz_mesh, hcoord) quad = Spaces.quadrature_style(space)