Skip to content

Commit

Permalink
Refactor m_kern_exx_eri to call cri_eri_inner_calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
connoraird committed Apr 26, 2024
1 parent 0a9ed8a commit ddab600
Showing 1 changed file with 45 additions and 97 deletions.
142 changes: 45 additions & 97 deletions src/exx_kernel_default.f90
Original file line number Diff line number Diff line change
Expand Up @@ -897,13 +897,14 @@ end subroutine get_X_matrix
!
! To ensure thread safety, variables which are altered must be passed in as parameters rather than imported.
! TODO: Change name to something more descriptive
subroutine cri_eri_inner_calculation(phi_i, Ome_kj, nsf1, nsf2, nsf3, dv &
ncaddr, ncbeg, ia_nsup, ewald_charge, work_out_3d, work_in_3d, c)
subroutine cri_eri_inner_calculation(nsf1_array, phi_i, Ome_kj, nsf1, nsf2, nsf3, kpart, dv, ncaddr, ncbeg, &
ia_nsup, backup_eris, start_count, ewald_charge, work_out_3d, work_in_3d, c)

use exx_poisson, only: exx_v_on_grid, exx_ewald_charge

use exx_types, only: Phy_k, phi_j, phi_k, ewald_rho, p_gauss, w_gauss, reckernel_3d, ewald_pot, &
pulay_radius, p_ngauss, r_int, p_omega, exx_psolver, exx_pscheme, extent
use exx_types, only: phi_j, phi_k, ewald_rho, p_gauss, w_gauss, reckernel_3d, ewald_pot, &
pulay_radius, p_ngauss, r_int, p_omega, exx_psolver, exx_pscheme, extent, eris, &
store_eris

use GenBlas, only: dot

Expand All @@ -912,9 +913,10 @@ subroutine cri_eri_inner_calculation(phi_i, Ome_kj, nsf1, nsf2, nsf3, dv &
implicit none

real(double), pointer, intent(in) :: Ome_kj(:,:,:), phi_i(:,:,:,:)
integer, intent(in) :: nsf1, nsf2 ! The indices of the loops from which this function is called
integer, intent(in) :: ncbeg, ia_nsup
real(double), intent(in) :: dv
integer, intent(in) :: kpart, nsf1, nsf2 ! The indices of the loops from which this function is called
integer, intent(in) :: ncbeg, ia_nsup, start_count
logical, intent(in) :: backup_eris
real(double), intent(in) :: nsf1_array(:,:,:,:), dv
real(double), intent(out) :: ewald_charge, work_out_3d(:,:,:), work_in_3d(:,:,:)
real(double), intent(inout) :: c(:)

Expand All @@ -923,11 +925,11 @@ subroutine cri_eri_inner_calculation(phi_i, Ome_kj, nsf1, nsf2, nsf3, dv &

work_out_3d = zero
!
work_in_3d = Phy_k(:,:,:,nsf1) * phi_j(:,:,:,nsf2)
work_in_3d = nsf1_array(:,:,:,nsf1) * phi_j(:,:,:,nsf2)
!
if (exx_psolver=='fftw' .and. exx_pscheme=='ewald') then
call exx_ewald_charge(work_in_3d,extent,dv,ewald_charge)
work_in_3d = work_in_3d - ewald_rho*ewald_charge
work_in_3d = work_in_3d - ewald_rho*ewald_charge
end if
!
call exx_v_on_grid(inode,extent,work_in_3d,work_out_3d,r_int, &
Expand All @@ -944,8 +946,17 @@ subroutine cri_eri_inner_calculation(phi_i, Ome_kj, nsf1, nsf2, nsf3, dv &
!
do nsf3 = 1, ia_nsup
!
c(ncaddr + nsf3 - 1) = c(ncaddr + nsf3 - 1) &
+ dot((2*extent+1)**3, phi_i(:,:,:,nsf3), 1, Ome_kj, 1) * dv
exx_mat_elem = dot((2*extent+1)**3, phi_i(:,:,:,nsf3), 1, Ome_kj, 1) * dv
!
if ( backup_eris ) then
!
eris(kpart)%store_eris( start_count + nsf3 ) = exx_mat_elem
!
else
!
c(ncaddr + nsf3 - 1) = c(ncaddr + nsf3 - 1) + exx_mat_elem
!
end if
!
end do ! nsf3 = 1, ia%nsup
end subroutine cri_eri_inner_calculation
Expand Down Expand Up @@ -1222,19 +1233,19 @@ subroutine m_kern_exx_cri(k_off, kpart, ib_nd_acc, ibaddr, nbnab, &
! Begin the parallel region here as earlier allocations make it difficult to do before now.
! However, this should be possible in future work.
!
!$omp parallel default(none) reduction(+: c) &
!$omp shared(kg,jb,tmr_std_exx_poisson,tmr_std_exx_nsup,Phy_k,phi_j,phi_k,ncbeg,ia, &
!$omp tmr_std_exx_matmult,ewald_pot,phi_i,exx_psolver,exx_pscheme,extent,dv, &
!$omp ewald_rho,inode,pulay_radius,p_omega,p_gauss,w_gauss,reckernel_3d,r_int) &
!$omp private(nsf1,nsf2,work_out_3d,work_in_3d,ewald_charge,Ome_kj_1d_buffer,Ome_kj, &
!$omp parallel default(none) reduction(+: c) &
!$omp shared(kg,jb,tmr_std_exx_poisson,tmr_std_exx_nsup,Phy_k,phi_j,phi_k,ncbeg,ia,kpart, &
!$omp tmr_std_exx_matmult,ewald_pot,phi_i,exx_psolver,exx_pscheme,extent,dv, &
!$omp ewald_rho,inode,pulay_radius,p_omega,p_gauss,w_gauss,reckernel_3d,r_int) &
!$omp private(nsf1,nsf2,work_out_3d,work_in_3d,ewald_charge,Ome_kj_1d_buffer,Ome_kj, &
!$omp ncaddr,nsf3,exx_mat_elem,r,s,t)
Ome_kj(1:2*extent+1, 1:2*extent+1, 1:2*extent+1) => Ome_kj_1d_buffer
!$omp do schedule(runtime) collapse(2)
do nsf1 = 1, kg%nsup
do nsf2 = 1, jb%nsup
!
call cri_eri_inner_calculation(phi_i, Ome_kj, nsf1, nsf2, nsf3, ewald_charge, dv, ncaddr, &
ncbeg, ia%nsup, work_out_3d, work_in_3d, c)
call cri_eri_inner_calculation(Phy_k, phi_i, Ome_kj, nsf1, nsf2, kpart, dv, ncaddr, ncbeg, &
ia%nsup, .false., 0, ewald_charge, work_out_3d, work_in_3d, c)
!
end do ! nsf2 = 1, jb%nsup
end do ! nsf1 = 1, kg%nsup
Expand Down Expand Up @@ -1319,13 +1330,13 @@ subroutine m_kern_exx_eri(k_off, kpart, ib_nd_acc, ibaddr, nbnab, &
use exx_types, only: prim_atomic_data, neigh_atomic_data, &
tmr_std_exx_accumul, tmr_std_exx_poisson, &
tmr_std_exx_poisson, grid_spacing, r_int, extent,&
ewald_charge, ewald_rho, ewald_pot, eris, &
ewald_charge, ewald_rho, ewald_pot, &
pulay_radius, p_omega, p_ngauss, p_gauss, w_gauss, &
exx_psolver,exx_pscheme, &
unit_exx_debug, unit_eri_debug
!
use exx_types, only: phi_i, phi_j, phi_k, phi_l, eris, &
work_in_3d, work_out_3d, exx_gto, exx_gto_poisson
use exx_types, only: phi_i_1d_buffer, phi_j, phi_k, phi_l, &
Ome_kj_1d_buffer, work_in_3d, work_out_3d
use exx_types, only: exx_alloc
!
use exx_memory, only: exx_mem_alloc
Expand Down Expand Up @@ -1371,14 +1382,17 @@ subroutine m_kern_exx_eri(k_off, kpart, ib_nd_acc, ibaddr, nbnab, &
real(double) :: dr,dv,K_val
real(double) :: exx_mat_elem
!
! Declare pointers here; they are associated with 1D buffers further down so that
! contiguous storage is passed to the BLAS dot routine
real(double), pointer :: phi_i(:,:,:,:), Ome_kj(:,:,:)
!
type(prim_atomic_data) :: ia !i_alpha
type(neigh_atomic_data) :: jb !j_beta
type(neigh_atomic_data) :: kg !k_gamma
type(neigh_atomic_data) :: ld !l_delta
!
integer :: maxsuppfuncs
integer :: nsf_kg, nsf_ld, nsf_ia, nsf_jb
integer :: r, s, t,
integer :: r, s, t
integer :: k_count, l_count, ld_count, kg_count, i_count, j_count, jb_count, count
!
! GTO
Expand Down Expand Up @@ -1510,7 +1524,8 @@ subroutine m_kern_exx_eri(k_off, kpart, ib_nd_acc, ibaddr, nbnab, &
!
!print*, 'i',i, 'global_num',ia%ip,'spe',ia%spec, 'nsup', ia%nsup
!
if ( exx_alloc ) call exx_mem_alloc(extent,ia%nsup,0,'phi_i','alloc')
if ( exx_alloc ) call exx_mem_alloc(extent,ia%nsup,0,'phi_i_1d_buffer','alloc')
phi_i(1:2*extent+1, 1:2*extent+1, 1:2*extent+1, 1:ia%nsup) => phi_i_1d_buffer
!
call exx_phi_on_grid(inode,ia%ip,ia%spec,extent, &
ia%xyz,ia%nsup,phi_i,r_int,xyz_zero)
Expand Down Expand Up @@ -1559,6 +1574,8 @@ subroutine m_kern_exx_eri(k_off, kpart, ib_nd_acc, ibaddr, nbnab, &
call exx_phi_on_grid(inode,jb%global_num,jb%spec,extent, &
jb%xyz,jb%nsup,phi_j,r_int,xyz_zero)
!
if ( exx_alloc ) call exx_mem_alloc(extent,0,0,'Ome_kj_1d_buffer','alloc')
!
!xj = jb%xyz_cv(1)
!yj = jb%xyz_cv(2)
!zj = jb%xyz_cv(3)
Expand All @@ -1573,6 +1590,7 @@ subroutine m_kern_exx_eri(k_off, kpart, ib_nd_acc, ibaddr, nbnab, &
j_count = i_count + (j - 1)
j_count = j_count * jb%nsup
!
Ome_kj(1:2*extent+1, 1:2*extent+1, 1:2*extent+1) => Ome_kj_1d_buffer
jb_loop: do nsf_jb = 1, jb%nsup
!
! The current state of count
Expand All @@ -1586,86 +1604,16 @@ subroutine m_kern_exx_eri(k_off, kpart, ib_nd_acc, ibaddr, nbnab, &
jb_count = j_count + (nsf_jb - 1)
jb_count = jb_count * ia%nsup
!
ncaddr = ncbeg + ia%nsup * (nsf_jb - 1)
!
call start_timer(tmr_std_exx_poisson)
work_out_3d = zero
work_in_3d = phi_l(:,:,:,nsf_ld)*phi_j(:,:,:,nsf_jb)
!
if (exx_psolver=='fftw' .and. exx_pscheme=='ewald') then
call exx_ewald_charge(work_in_3d,extent,dv,ewald_charge)
work_in_3d = work_in_3d - ewald_rho*ewald_charge
end if
!
call exx_v_on_grid(inode,extent,work_in_3d,work_out_3d,r_int, &
exx_psolver,exx_pscheme,pulay_radius,p_omega,p_ngauss,p_gauss,&
w_gauss,reckernel_3d)
!
if (exx_psolver=='fftw' .and. exx_pscheme=='ewald') then
work_out_3d = work_out_3d + ewald_pot*ewald_charge
end if
!
call stop_timer(tmr_std_exx_poisson,.true.)

ia_loop: do nsf_ia = 1, ia%nsup
!
! The current state of count
! count = (k - 1) * nbnab(k_in_part) * ld%nsup * kg%nsup * at%n_hnab(k_in_halo) * nbnab(k_in_part) * jb%nsup * ia%nsup + &
! (l - 1) * ld%nsup * kg%nsup * at%n_hnab(k_in_halo) * nbnab(k_in_part) * jb%nsup * ia%nsup + &
! (nsf_ld - 1) * kg%nsup * at%n_hnab(k_in_halo) * nbnab(k_in_part) * jb%nsup * ia%nsup + &
! (nsf_kg - 1) * at%n_hnab(k_in_halo) * nbnab(k_in_part) * jb%nsup * ia%nsup + &
! (i - 1) * nbnab(k_in_part) * jb%nsup * ia%nsup + &
! (j - 1) * jb%nsup * ia%nsup + &
! (nsf_jb - 1) * ia%nsup + &
! nsf_ia
count = jb_count + nsf_ia
!
exx_mat_elem = zero
!
call start_timer(tmr_std_exx_accumul)
!
do r = 1, 2*extent+1
do s = 1, 2*extent+1
do t = 1, 2*extent+1

exx_mat_elem = exx_mat_elem &
+ phi_k(t,s,r,nsf_kg) * phi_i(t,s,r,nsf_ia) * K_val &
* work_out_3d(t,s,r) * dv

end do
end do
end do
!
call stop_timer(tmr_std_exx_accumul,.true.)
!
if ( exx_debug ) then

write(unit_eri_debug,10) count, exx_mat_elem, K_val, &
'[',ia%ip, kg%global_num,'|',ld%global_num, jb%global_num,']', &
'(',nsf_ia,nsf_kg, '|',nsf_ld,nsf_jb, ')' , &
'[',ia%name,kg%name,'|',ld%name,jb%name,']' , &
ia%xyz_ip(3), kg%xyz_cv(3), ld%xyz_cv(3), jb%xyz_cv(3)

end if
!
if ( backup_eris ) then
!
eris(kpart)%store_eris( count ) = exx_mat_elem
!
else
!
c(ncaddr + nsf_ia - 1) = c(ncaddr + nsf_ia - 1) + exx_mat_elem
!
end if
!
end do ia_loop
call cri_eri_inner_calculation(phi_l, phi_i, Ome_kj, nsf_kg, nsf_jb, kpart, dv, ncaddr, ncbeg, &
ia%nsup, backup_eris, jb_count, ewald_charge, work_out_3d, work_in_3d, c)
!
end do jb_loop
!
end if
!
end if
!
if ( exx_alloc ) call exx_mem_alloc(extent,0,0,'Ome_kj_1d_buffer','dealloc')
if ( exx_alloc ) call exx_mem_alloc(extent,jb%nsup,0,'phi_j','dealloc')
!
!!$
Expand All @@ -1674,7 +1622,7 @@ subroutine m_kern_exx_eri(k_off, kpart, ib_nd_acc, ibaddr, nbnab, &
!
end do j_loop
!
if ( exx_alloc ) call exx_mem_alloc(extent,ia%nsup,0,'phi_i','dealloc')
if ( exx_alloc ) call exx_mem_alloc(extent,ia%nsup,0,'phi_i_1d_buffer','dealloc')
!
!!$
!!$ ****[ i end loop ]****
Expand Down

0 comments on commit ddab600

Please sign in to comment.