Overlap communication with computation in multiply_module #290

Open: wants to merge 14 commits into base: develop
Changes from 9 commits
76 changes: 74 additions & 2 deletions src/comms_module.f90
@@ -177,7 +177,7 @@ end subroutine Mquest_start_send
!! SOURCE
!!
  subroutine Mquest_get( mx_nponn, ilen2,ilen3,nc_part,send_node,sent_part,myid,&
-      bind_rem,b_rem,lenb_rem,bind,b,istart,mx_babs,mx_part,tag)
+      bind_rem,b_rem,lenb_rem,bind,istart,mx_babs,mx_part,tag)

! Module usage
use mpi
@@ -194,7 +194,6 @@ subroutine Mquest_get( mx_nponn, ilen2,ilen3,nc_part,send_node,sent_part,myid,&
integer :: bind(:)
integer :: lenb_rem
real(double) :: b_rem(lenb_rem)
-   real(double) :: b(:)
! Miscellaneous data
integer :: nc_part,send_node,sent_part,myid,ilen2,ilen3,size,istart,offset
integer :: nrstat(MPI_STATUS_SIZE)
@@ -222,6 +221,79 @@ subroutine Mquest_get( mx_nponn, ilen2,ilen3,nc_part,send_node,sent_part,myid,&
end subroutine Mquest_get
!!***

!!***

! ---------------------------------------------------------------------
! subroutine Mquest_get_nonb
! ---------------------------------------------------------------------

!!****f* comms_module/Mquest_get_nonb *
!!
!! NAME
!! Mquest_get_nonb
!! USAGE
!!
!! PURPOSE
!! Calls non-blocking receives to get data for matrix multiplication
!! INPUTS
!!
!!
!! USES
!!
!! AUTHOR
!! I. Christidi
!! CREATION DATE
!! 2023/10/06
!!
!! MODIFICATION HISTORY
!!
!! SOURCE
!!
subroutine Mquest_get_nonb( mx_nponn, ilen2,ilen3,nc_part,send_node,sent_part,myid,&
bind_rem,b_rem,lenb_rem,bind,istart,mx_babs,mx_part,tag,request)

! Module usage
use mpi
use datatypes
use matrix_comms_module, ONLY: mx_msg_per_part
use GenComms, ONLY: cq_abort

implicit none

! Maxima
integer :: mx_nponn,mx_babs,mx_part
! Arrays for receiving data
integer :: bind_rem(:)
integer :: bind(:)
integer :: lenb_rem
real(double) :: b_rem(lenb_rem)
! Miscellaneous data
integer :: nc_part,send_node,sent_part,myid,ilen2,ilen3,size,istart,offset
integer :: tag
integer :: request(2)

! Local variables
integer :: ilen1,ierr,lenbind_rem

!lenb_rem = size(b_rem)
!lenbind_rem = size(bind_rem)
Review comment (Contributor) on lines +276 to +277:
Suggested change: remove the two commented-out size() lines above.

ierr = 0
ilen1 = nc_part
!if(3*ilen1+5*ilen2>lenbind_rem) call cq_abort('Get error ',3*ilen1+5*ilen2,lenbind_rem)
Review comment (Contributor):
Suggested change: remove the commented-out bounds check above.

if(ilen3>lenb_rem) call cq_abort('Get error 2 ',ilen3,lenb_rem)
Review comment (Contributor): It seems like this check should happen inside the if(ilen3.gt.0) clause. (A sketch of this suggestion follows the subroutine below.)

call MPI_Irecv(bind_rem,3*ilen1+5*ilen2,MPI_INTEGER, &
send_node-1,tag+1,MPI_COMM_WORLD,request(1),ierr)
if(ierr/=0) call cq_abort('Error receiving indices !',ierr)
if(ilen3.gt.0)then ! Get xyz, sequence list and elements
ierr = 0
call MPI_Irecv(b_rem,ilen3, MPI_DOUBLE_PRECISION,send_node-1,&
tag+2,MPI_COMM_WORLD,request(2),ierr)
if(ierr/=0) call cq_abort('Error receiving data !',ierr)
endif
return
end subroutine Mquest_get_nonb
!!***
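
A minimal sketch, not part of the PR as submitted, of what the reviewer's suggestion could look like, with the lenb_rem bounds check moved inside the if(ilen3.gt.0) block so it only runs when a data receive is actually posted:

    call MPI_Irecv(bind_rem,3*ilen1+5*ilen2,MPI_INTEGER, &
         send_node-1,tag+1,MPI_COMM_WORLD,request(1),ierr)
    if(ierr/=0) call cq_abort('Error receiving indices !',ierr)
    if(ilen3.gt.0)then ! Get xyz, sequence list and elements
       ! Sketch: check the receive buffer size only when the data receive is posted
       if(ilen3>lenb_rem) call cq_abort('Get error 2 ',ilen3,lenb_rem)
       ierr = 0
       call MPI_Irecv(b_rem,ilen3, MPI_DOUBLE_PRECISION,send_node-1,&
            tag+2,MPI_COMM_WORLD,request(2),ierr)
       if(ierr/=0) call cq_abort('Error receiving data !',ierr)
    endif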

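For context on the PR title, a caller-side sketch, hypothetical and not taken from this diff, of how Mquest_get_nonb can overlap the receive with local computation; the argument names are assumed to match the existing blocking Mquest_get call site:

    integer :: request(2), ierr
    integer :: statuses(MPI_STATUS_SIZE,2)

    request = MPI_REQUEST_NULL   ! so an unposted request(2) is harmless in Waitall when ilen3<=0
    call Mquest_get_nonb( mx_nponn, ilen2,ilen3,nc_part,send_node,sent_part,myid,&
         bind_rem,b_rem,lenb_rem,bind,istart,mx_babs,mx_part,tag,request)
    ! ... local multiplication work that does not touch bind_rem or b_rem goes here ...
    call MPI_Waitall(2,request,statuses,ierr)   ! receives must complete before the buffers are read
    if(ierr/=0) call cq_abort('Error waiting for receives !',ierr)
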
! ---------------------------------------------------------------------
! subroutine send_trans_data
! ---------------------------------------------------------------------
13 changes: 5 additions & 8 deletions src/exx_kernel_default.f90
@@ -146,7 +146,7 @@ subroutine get_X_matrix( exxspin, level )
integer :: exxspin
integer :: lab_const
integer :: invdir,ierr,kpart,ind_part,ncover_yz,n_which,ipart,nnode
-   integer :: icall,n_cont,kpart_next,ind_partN,k_off
+   integer :: n_cont,kpart_next,ind_partN,k_off
integer :: icall2,stat,ilen2,lenb_rem
! Remote variables to be allocated
integer(integ),allocatable :: ibpart_rem(:)
@@ -356,7 +356,7 @@ subroutine get_X_matrix( exxspin, level )
! !$omp parallel default(none) &
! !$omp shared(a, b, c, a_b_c, myid, lena, lenc, tmr_std_allocation, &
! !$omp ncover_yz, ibpart_rem, atrans, usegemm) &
-! !$omp private(kpart, icall, ind_part, ipart, nnode, b_rem, &
+! !$omp private(kpart, ind_part, ipart, nnode, b_rem, &
! !$omp lenb_rem, n_cont, part_array, ilen2, offset, &
! !$omp nbnab_rem, ibind_rem, ib_nd_acc_rem, ibseq_rem, &
! !$omp npxyz_rem, ibndimj_rem, k_off, icall2)
@@ -375,7 +375,6 @@ subroutine get_X_matrix( exxspin, level )
xyz_ghost = zero
r_ghost = zero
do kpart = 1,mult(S_X_SX)%ahalo%np_in_halo ! Main loop
-      icall=1
ind_part = mult(S_X_SX)%ahalo%lab_hcell(kpart)
!
!print*, 'inode', inode,'kpart', kpart, ind_part
@@ -385,9 +384,7 @@
end if
!
if(kpart>1) then ! Is it a periodic image of the previous partition ?
-         if(ind_part.eq.mult(S_X_SX)%ahalo%lab_hcell(kpart-1)) then
-            icall=0
-         else ! Get the data
+         if(ind_part.ne.mult(S_X_SX)%ahalo%lab_hcell(kpart-1)) then ! Get the data
ipart = mult(S_X_SX)%parts%i_cc2seq(ind_part)
nnode = mult(S_X_SX)%comms%neigh_node_list(kpart)
recv_part(nnode) = recv_part(nnode)+1
@@ -401,7 +398,7 @@
!
allocate(b_rem(lenb_rem))
!
-         call prefetch(kpart,mult(S_X_SX)%ahalo,mult(S_X_SX)%comms,mult(S_X_SX)%bmat,icall, &
+         call prefetch(kpart,mult(S_X_SX)%ahalo,mult(S_X_SX)%comms,mult(S_X_SX)%bmat, &
n_cont,part_array,mult(S_X_SX)%bindex,b_rem,lenb_rem,mat_p(matK( exxspin ))%matrix, &
myid,ilen2,mx_msg_per_part,mult(S_X_SX)%parts,mult(S_X_SX)%prim,mult(S_X_SX)%gcs,&
(recv_part(nnode)-1)*2)
@@ -447,7 +444,7 @@
call stop_timer(tmr_std_exx_allocat,.true.)
!
!
-         call prefetch(kpart,mult(S_X_SX)%ahalo,mult(S_X_SX)%comms,mult(S_X_SX)%bmat,icall, &
+         call prefetch(kpart,mult(S_X_SX)%ahalo,mult(S_X_SX)%comms,mult(S_X_SX)%bmat, &
n_cont,part_array,mult(S_X_SX)%bindex,b_rem,lenb_rem,mat_p(matK( exxspin ))%matrix, &
myid,ilen2,mx_msg_per_part,mult(S_X_SX)%parts,mult(S_X_SX)%prim,mult(S_X_SX)%gcs,&
(recv_part(nnode)-1)*2)