Skip to content

Commit

Permalink
Tidy up m_kern_exx_eri and correct phi_k index in cri_eri_inner_calcu…
Browse files Browse the repository at this point in the history
…lation
  • Loading branch information
connoraird committed May 17, 2024
1 parent 3a169f4 commit bb66fa7
Showing 1 changed file with 8 additions and 106 deletions.
114 changes: 8 additions & 106 deletions src/exx_kernel_default.f90
Original file line number Diff line number Diff line change
Expand Up @@ -905,15 +905,14 @@ end subroutine get_X_matrix
!
! To ensure thread safety, variables which are altered must be passed in as parameters rather than imported.
! TODO: Change name to something more descriptive
subroutine cri_eri_inner_calculation(nsf1_array, phi_i, Ome_kj, nsf1, nsf2, kpart, dv, &
subroutine cri_eri_inner_calculation(nsf1_array, phi_i, Ome_kj, nsf1, nsf2, nsf_kg, dv, &
multiplier, ncaddr, ncbeg, ia_nsup, ewald_charge, work_out_3d, work_in_3d, &
c, backup_eris, store_eris_inner)

use exx_poisson, only: exx_v_on_grid, exx_ewald_charge

use exx_types, only: phi_j, phi_k, ewald_rho, p_gauss, w_gauss, reckernel_3d, ewald_pot, &
pulay_radius, p_ngauss, r_int, p_omega, exx_psolver, exx_pscheme, extent, eris, &
store_eris
pulay_radius, p_ngauss, r_int, p_omega, exx_psolver, exx_pscheme, extent, store_eris

use GenBlas, only: dot

Expand All @@ -922,7 +921,7 @@ subroutine cri_eri_inner_calculation(nsf1_array, phi_i, Ome_kj, nsf1, nsf2, kpar
implicit none

real(double), pointer, intent(in) :: Ome_kj(:,:,:), phi_i(:,:,:,:)
integer, intent(in) :: kpart, nsf1, nsf2 ! The indices of the loops from which this function is called
integer, intent(in) :: nsf1, nsf2, nsf_kg ! The indices of the loops from which this function is called
integer, intent(in) :: ncbeg, ia_nsup
real(double), intent(in) :: nsf1_array(:,:,:,:), dv, multiplier
real(double), intent(out) :: ewald_charge, work_out_3d(:,:,:), work_in_3d(:,:,:)
Expand Down Expand Up @@ -950,7 +949,7 @@ subroutine cri_eri_inner_calculation(nsf1_array, phi_i, Ome_kj, nsf1, nsf2, kpar
work_out_3d = work_out_3d + ewald_pot*ewald_charge
end if
!
Ome_kj = work_out_3d * phi_k(:,:,:,nsf1)
Ome_kj = work_out_3d * phi_k(:,:,:,nsf_kg)
!
ncaddr = ncbeg + ia_nsup * (nsf2 - 1)
!
Expand Down Expand Up @@ -1258,7 +1257,7 @@ subroutine m_kern_exx_cri(k_off, kpart, ib_nd_acc, ibaddr, nbnab, &
do nsf_kg = 1, kg%nsup
do nsf_ld = 1, jb%nsup
!
call cri_eri_inner_calculation(Phy_k, phi_i, Ome_kj, nsf_kg, nsf_ld, kpart, dv, 1.0d0, &
call cri_eri_inner_calculation(Phy_k, phi_i, Ome_kj, nsf_kg, nsf_ld, nsf_kg, dv, 1.0d0, &
ncaddr, ncbeg, ia%nsup, ewald_charge, work_out_3d, work_in_3d, c, &
.false.)
!
Expand Down Expand Up @@ -1393,7 +1392,6 @@ subroutine m_kern_exx_eri(k_off, kpart, ib_nd_acc, ibaddr, nbnab, &
real(double), dimension(3) :: xyz_zero = zero
!
real(double) :: dr,dv,K_val
real(double) :: exx_mat_elem
!
! We allocate pointers here to point at 1D arrays later and allow contiguous access when passing to BLAS dot later
real(double), pointer :: phi_i(:,:,:,:), Ome_kj(:,:,:), store_eris_inner(:,:)
Expand All @@ -1403,23 +1401,7 @@ subroutine m_kern_exx_eri(k_off, kpart, ib_nd_acc, ibaddr, nbnab, &
type(neigh_atomic_data) :: kg !k_gamma
type(neigh_atomic_data) :: ld !l_delta
!
integer :: nsf_kg, nsf_ld, nsf_ia, nsf_jb
integer :: r, s, t, stat
integer :: k_count, l_count, ld_count, kg_count, i_count, j_count, jb_count, count
!
! GTO
integer :: i_nx, j_nx, k_nx, l_nx
integer :: i_ny, j_ny, k_ny, l_ny
integer :: i_nz, j_nz, k_nz, l_nz
character(len=8) :: i_nt, j_nt, k_nt, l_nt
integer :: ia_gto, jb_gto, kg_gto, ld_gto
real(double) :: ai, aj, ak, al, di, dj, dk, dl
real(double) :: i_norm, j_norm, k_norm, l_norm
!real(double) :: xi, xj, xk, xl
!real(double) :: yi, yj, yk, yl
!real(double) :: zi, zj, zk, zl

real(double) :: eri_gto, eri_pao, test
integer :: nsf_kg, nsf_ld, nsf_jb, count
!
dr = grid_spacing
dv = dr**3
Expand Down Expand Up @@ -1460,12 +1442,6 @@ subroutine m_kern_exx_eri(k_off, kpart, ib_nd_acc, ibaddr, nbnab, &
!
nbbeg = nb_nd_kbeg
!
!print*, 'size jbnab2ch', size(jbnab2ch)
!print*, 'jbnab2ch', jbnab2ch
!print*
!
! The current state of count
k_count = (k - 1) * (nbnab(k_in_part) - 1)
!!$
!!$ ****[ l do loop ]****
!!$
Expand All @@ -1484,27 +1460,10 @@ subroutine m_kern_exx_eri(k_off, kpart, ib_nd_acc, ibaddr, nbnab, &
call exx_phi_on_grid(inode,ld%global_num,ld%spec,extent, &
ld%xyz,ld%nsup,phi_l,r_int,xyz_zero)
!
!xl = ld%xyz_cv(1)
!yl = ld%xyz_cv(2)
!zl = ld%xyz_cv(3)
!
! The current state of count
! l_count = (k - 1) * nbnab(k_in_part) * ld%nsup + &
! (l - 1) * ld%nsup
l_count = k_count + (l - 1)
l_count = l_count * (ld%nsup - 1)
!
ld_loop: do nsf_ld = 1, ld%nsup
!
nbaddr = nbbeg + kg%nsup * (nsf_ld - 1)
!
! The current state of count
! count = (k - 1) * nbnab(k_in_part) * ld%nsup * kg%nsup + &
! (l - 1) * ld%nsup * kg%nsup + &
! (nsf_ld - 1) * kg%nsup
ld_count = l_count + (nsf_ld - 1)
ld_count = ld_count * (kg%nsup - 1)
!
kg_loop: do nsf_kg = 1, kg%nsup
!
if ( backup_eris ) then
Expand All @@ -1513,14 +1472,6 @@ subroutine m_kern_exx_eri(k_off, kpart, ib_nd_acc, ibaddr, nbnab, &
K_val = b(nbaddr+nsf_kg-1)
end if
!
! The current state of count
! kg_count = (k - 1) * nbnab(k_in_part) * ld%nsup * kg%nsup * at%n_hnab(k_in_halo) + &
! (l - 1) * ld%nsup * kg%nsup * at%n_hnab(k_in_halo) + &
! (nsf_ld - 1) * kg%nsup * at%n_hnab(k_in_halo) + &
! (nsf_kg - 1) * at%n_hnab(k_in_halo)
kg_count = ld_count + (nsf_kg - 1)
kg_count = kg_count * (at%n_hnab(k_in_halo) - 1)
!
!!$
!!$ ****[ i loop ]****
!!$
Expand All @@ -1530,42 +1481,21 @@ subroutine m_kern_exx_eri(k_off, kpart, ib_nd_acc, ibaddr, nbnab, &
ni = bundle%iprim_seq (i_in_prim)
np = bundle%iprim_part(i_in_prim)
icad = (i_in_prim - 1) * chalo%ni_in_halo !***
!nbbeg = nb_nd_kbeg
!
call get_iprimdat(ia,kg,ni,i_in_prim,np,.true.,unit_exx_debug)
!
!print*, 'i',i, 'global_num',ia%ip,'spe',ia%spec, 'nsup', ia%nsup
!
if ( exx_alloc ) call exx_mem_alloc(extent,ia%nsup,0,'phi_i_1d_buffer','alloc')
phi_i(1:2*extent+1, 1:2*extent+1, 1:2*extent+1, 1:ia%nsup) => phi_i_1d_buffer
!
call exx_phi_on_grid(inode,ia%ip,ia%spec,extent, &
ia%xyz,ia%nsup,phi_i,r_int,xyz_zero)
!
!xi = ia%xyz_ip(1)
!yi = ia%xyz_ip(2)
!zi = ia%xyz_ip(3)

!print*, size(chalo%i_h2d), shape(chalo%i_h2d)
!
! The current state of count
! i_count = (k - 1) * nbnab(k_in_part) * ld%nsup * kg%nsup * at%n_hnab(k_in_halo) * nbnab(k_in_part) + &
! (l - 1) * ld%nsup * kg%nsup * at%n_hnab(k_in_halo) * nbnab(k_in_part) + &
! (nsf_ld - 1) * kg%nsup * at%n_hnab(k_in_halo) * nbnab(k_in_part) + &
! (nsf_kg - 1) * at%n_hnab(k_in_halo) * nbnab(k_in_part) + &
! (i - 1) * nbnab(k_in_part)
i_count = kg_count + (i - 1)
i_count = i_count * (nbnab(k_in_part) - 1)
!
!!$
!!$ ****[ j loop ]****
!!$
j_loop: do j = 1, nbnab(k_in_part)!mat(np,Xrange)%n_nab(ni)
!nbbeg = nb_nd_kbeg
j_loop: do j = 1, nbnab(k_in_part)
j_in_halo = jbnab2ch(j) !***
!
!print*, j, icad, j_in_halo
!
if ( j_in_halo /= 0 ) then
!
ncbeg = chalo%i_h2d(icad + j_in_halo) !***
Expand All @@ -1577,9 +1507,6 @@ subroutine m_kern_exx_eri(k_off, kpart, ib_nd_acc, ibaddr, nbnab, &
call get_halodat(jb,kg,jseq,chalo%i_hbeg(jpart), &
BCS_parts%lab_cell(BCS_parts%inv_lab_cover(jpart)), &
'j',.true.,unit_exx_debug)
!
!print*, 'j',j, 'global_num',jb%global_num,'spe',jb%spec,'nsup', jb%nsup

!
if ( exx_alloc ) call exx_mem_alloc(extent,jb%nsup,0,'phi_j','alloc')
!
Expand All @@ -1588,20 +1515,6 @@ subroutine m_kern_exx_eri(k_off, kpart, ib_nd_acc, ibaddr, nbnab, &
!
if ( exx_alloc ) call exx_mem_alloc(extent,0,0,'Ome_kj_1d_buffer','alloc')
!
!xj = jb%xyz_cv(1)
!yj = jb%xyz_cv(2)
!zj = jb%xyz_cv(3)
!
! The current state of count
! j_count = (k - 1) * nbnab(k_in_part) * ld%nsup * kg%nsup * at%n_hnab(k_in_halo) * nbnab(k_in_part) * jb%nsup + &
! (l - 1) * ld%nsup * kg%nsup * at%n_hnab(k_in_halo) * nbnab(k_in_part) * jb%nsup + &
! (nsf_ld - 1) * kg%nsup * at%n_hnab(k_in_halo) * nbnab(k_in_part) * jb%nsup + &
! (nsf_kg - 1) * at%n_hnab(k_in_halo) * nbnab(k_in_part) * jb%nsup + &
! (i - 1) * nbnab(k_in_part) * jb%nsup + &
! (j - 1) * jb%nsup + &
j_count = i_count + (j - 1)
j_count = j_count * (jb%nsup - 1)
!
! TODO include bounds in Ome_kj_1d_buffer and store_eris
Ome_kj(1:2*extent+1, 1:2*extent+1, 1:2*extent+1) => Ome_kj_1d_buffer
!
Expand All @@ -1611,18 +1524,7 @@ subroutine m_kern_exx_eri(k_off, kpart, ib_nd_acc, ibaddr, nbnab, &
!
jb_loop: do nsf_jb = 1, jb%nsup
!
! The current state of count
! jb_count = (k - 1) * nbnab(k_in_part) * ld%nsup * kg%nsup * at%n_hnab(k_in_halo) * nbnab(k_in_part) * jb%nsup * ia%nsup + &
! (l - 1) * ld%nsup * kg%nsup * at%n_hnab(k_in_halo) * nbnab(k_in_part) * jb%nsup * ia%nsup + &
! (nsf_ld - 1) * kg%nsup * at%n_hnab(k_in_halo) * nbnab(k_in_part) * jb%nsup * ia%nsup + &
! (nsf_kg - 1) * at%n_hnab(k_in_halo) * nbnab(k_in_part) * jb%nsup * ia%nsup + &
! (i - 1) * nbnab(k_in_part) * jb%nsup * ia%nsup + &
! (j - 1) * jb%nsup * ia%nsup + &
! (nsf_jb - 1) * ia%nsup
jb_count = j_count + (nsf_jb - 1)
jb_count = jb_count * (ia%nsup - 1)
!
call cri_eri_inner_calculation(phi_l, phi_i, Ome_kj, nsf_ld, nsf_jb, kpart, dv, K_val, &
call cri_eri_inner_calculation(phi_l, phi_i, Ome_kj, nsf_ld, nsf_jb, nsf_kg, dv, K_val, &
ncaddr, ncbeg, ia%nsup, ewald_charge, work_out_3d, work_in_3d, c, &
backup_eris, store_eris_inner)
!
Expand Down

0 comments on commit bb66fa7

Please sign in to comment.