Skip to content

Commit

Permalink
Merge pull request #399 from OrderN/f-proj_ELPA
Browse files Browse the repository at this point in the history
Implementing ELPA
  • Loading branch information
tsuyoshi38 authored Feb 7, 2025
2 parents 34e73a2 + 8fa1eeb commit cb68516
Show file tree
Hide file tree
Showing 14 changed files with 370 additions and 26 deletions.
24 changes: 18 additions & 6 deletions src/DiagModule.f90
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ subroutine FindEvals(electrons)
use density_module, ONLY: get_band_density
use io_module, ONLY: write_eigenvalues, write_eigenvalues_format_ase
use pao_format, ONLY: pao
use ELPA_module, ONLY: flag_use_elpa, init_ELPA, end_ELPA

implicit none

Expand Down Expand Up @@ -551,6 +552,9 @@ subroutine FindEvals(electrons)

! Initialise - start BLACS, sort out matrices, allocate memory
call initDiag
if(flag_use_elpa) then
call init_ELPA(matrix_size_padH, row_size, col_size, desca, info )
endif

scale = one / real(N_procs_in_pg(pgid), double)

Expand Down Expand Up @@ -971,6 +975,8 @@ subroutine FindEvals(electrons)
end if
! global
call endDiag
if(flag_use_elpa) call end_ELPA(info)

min_layer = min_layer + 1
return

Expand Down Expand Up @@ -4232,6 +4238,7 @@ subroutine distrib_and_diag(spin,index_kpoint,mode,flag_store_w,kpassed)
block_size_r, block_size_c, blocks_r, blocks_c, procid, pgroup,&
nkpoints_max, pgid, N_kpoints_in_pg, pg_kpoints, N_procs_in_pg, proc_groups
use GenComms, only: cq_warn
use ELPA_module, only: flag_use_elpa, ELPA_zhegv

implicit none

Expand Down Expand Up @@ -4298,13 +4305,18 @@ subroutine distrib_and_diag(spin,index_kpoint,mode,flag_store_w,kpassed)

! Call the diagonalisation routine for generalised problem
! H.psi = E.S.psi
if( flag_use_elpa ) then
call ELPA_zhegv(mode, matrix_size_padH, row_size, col_size, &
SCHmat(:,:,spin), SCSmat(:,:,spin), local_evals(:,spin), z(:,:,spin), info )
else
call pzhegvx(1, mode, 'A', 'U', matrix_size_padH, SCHmat(:,:,spin), &
1, 1, desca, SCSmat(:,:,spin), 1, 1, descb, &
vl, vu, il, iu, abstol, m, mz, local_evals(:,spin), &
orfac, z(:,:,spin), 1, 1, descz, work, lwork, &
rwork, lrwork, iwork, liwork, ifail, iclustr, &
gap, info)
endif

call pzhegvx(1, mode, 'A', 'U', matrix_size_padH, SCHmat(:,:,spin), &
1, 1, desca, SCSmat(:,:,spin), 1, 1, descb, &
vl, vu, il, iu, abstol, m, mz, local_evals(:,spin), &
orfac, z(:,:,spin), 1, 1, descz, work, lwork, &
rwork, lrwork, iwork, liwork, ifail, iclustr, &
gap, info)
if (info /= 0) then
if(info==2.OR.info==4) then ! These are safe to continue
if(.NOT.flag_info_greater_zero) then
Expand Down
161 changes: 161 additions & 0 deletions src/ELPAModule.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
module ELPA_module

use datatypes
use mpi
!!$ use omp
use elpa
use GenComms, ONLY: cq_abort, cq_warn, myid

implicit none

logical :: flag_elpa_dummy = .false. ! A marker to show ELPA in compilation
logical :: flag_use_elpa = .false. ! whether we use ELPA or not
character(len=16) :: elpa_solver = "ELPA1" ! ELPA1 or ELPA2
character(len=16) :: elpa_kernel = "GENERIC"
integer :: elpa_API = 20241105
integer :: merow, mecol

class(elpa_t), pointer :: elp

private
public :: flag_use_elpa, elpa_solver, elpa_kernel, elpa_API, flag_elpa_dummy
public :: init_ELPA, end_ELPA, ELPA_zhegv

contains

! Create and configure the module-level ELPA handle `elp` for the
! distributed matrices described by the ScaLAPACK-style descriptor `desc`.
!
! Arguments
!   matrix_size : global dimension, passed to ELPA as both "na" and "nev"
!   row_size    : passed to ELPA as "local_nrows"
!   col_size    : passed to ELPA as "local_ncols"
!   desc        : array descriptor; entries 2, 5 and 6 are read here as the
!                 BLACS context and the row/column block sizes
!   info        : status of the last ELPA call made by this routine
!
! Every failure reported by ELPA goes through cq_abort.  The solver and
! kernel names held in elpa_solver / elpa_kernel are compared without regard
! to case, so that values accepted by a case-insensitive input check
! (e.g. "elpa2") are also accepted here.
subroutine init_ELPA (matrix_size, row_size, col_size, desc, info)

  implicit none

  ! Passed variables
  integer, intent(in)  :: matrix_size, row_size, col_size
  integer, intent(in)  :: desc(9)
  integer, intent(out) :: info

  ! Local variables
  integer :: context, block_size_r, block_size_c
  ! numrows/numcols = process grid shape; merow/mecol = this process'
  ! coordinates (these locals hide the module variables of the same name)
  integer :: numrows, numcols, merow, mecol
  character(len=12), parameter :: subname = "init_ELPA: "
  character(len=16) :: solver_uc, kernel_uc

  context      = desc(2)
  block_size_r = desc(5)
  block_size_c = desc(6)

  if( block_size_r /= block_size_c ) then ! restriction for ELPA
     call cq_abort("Diag.BlockSizeR and Diag.BlockSizeC not same !", &
          block_size_r, block_size_c )
  end if

  ! To get the information of blacs grid
  call blacs_gridinfo( context, numrows, numcols, merow, mecol )

  if( mod(numcols,numrows) /= 0 ) then ! restriction for ELPA
     call cq_warn(subname,"Diag.ProcRows is not a factor of Diag.ProcCols !",numrows,numcols)
  end if

  ! The matrix must be strictly larger than one full sweep of blocks over
  ! the process grid in each direction
  if( matrix_size <= block_size_r*numrows ) then ! restriction for ELPA
     call cq_abort("Diag.BlockSizeR should be less than or equal to", (matrix_size-1)/numrows )
  end if
  if( matrix_size <= block_size_c*numcols ) then ! restriction for ELPA
     call cq_abort("Diag.BlockSizeC should be less than or equal to", (matrix_size-1)/numcols )
  end if

  info = elpa_init(elpa_API)
  if( info /= ELPA_OK ) call cq_abort("ELPA_Init: ELPA API version not supported")

  elp => elpa_allocate(info)
  if( info /= ELPA_OK ) call cq_abort("ELPA_Init: Could not allocate ELPA object")

  call set_int( "na",              matrix_size    )
  call set_int( "nev",             matrix_size    )
  call set_int( "local_nrows",     row_size       )
  call set_int( "local_ncols",     col_size       )
  call set_int( "nblk",            block_size_r   )
  ! NOTE(review): the parent communicator is always MPI_COMM_WORLD here;
  ! confirm this matches the process set behind `context`.
  call set_int( "mpi_comm_parent", MPI_COMM_WORLD )
  call set_int( "process_row",     merow          )
  call set_int( "process_col",     mecol          )
  call set_int( "blacs_context",   context        )

  ! Fold to upper case locally; the module variables are left untouched
  solver_uc = to_upper(elpa_solver)
  kernel_uc = to_upper(elpa_kernel)

  select case( trim(solver_uc) )
  case("ELPA1")
     call set_int( "solver", ELPA_SOLVER_1STAGE ) ! ELPA1
  case("ELPA2")
     call set_int( "solver", ELPA_SOLVER_2STAGE ) ! ELPA2

     ! The kernel choice is only applied for the two-stage solver
     select case( trim(kernel_uc) )
     case("GENERIC")
        call set_int( "complex_kernel", ELPA_2STAGE_COMPLEX_GENERIC )
     case("GENERIC_SIMPLE")
        call set_int( "complex_kernel", ELPA_2STAGE_COMPLEX_GENERIC_SIMPLE )
     case("SSE_ASSEMBLY")
        call set_int( "complex_kernel", ELPA_2STAGE_COMPLEX_SSE_ASSEMBLY )
     case("SSE_BLOCK1")
        call set_int( "complex_kernel", ELPA_2STAGE_COMPLEX_SSE_BLOCK1 )
     case("SSE_BLOCK2")
        call set_int( "complex_kernel", ELPA_2STAGE_COMPLEX_SSE_BLOCK2 )
     case("AVX_BLOCK1")
        call set_int( "complex_kernel", ELPA_2STAGE_COMPLEX_AVX_BLOCK1 )
     case("AVX_BLOCK2")
        call set_int( "complex_kernel", ELPA_2STAGE_COMPLEX_AVX_BLOCK2 )
     case("AVX2_BLOCK1")
        call set_int( "complex_kernel", ELPA_2STAGE_COMPLEX_AVX2_BLOCK1 )
     case("AVX2_BLOCK2")
        call set_int( "complex_kernel", ELPA_2STAGE_COMPLEX_AVX2_BLOCK2 )
     case default
        call cq_abort("Invalid Diag.ELPA2Kernel " // trim(elpa_kernel) )
     end select
  case default
     call cq_abort("Invalid Diag.ELPASolver " // trim(elpa_solver) )
  end select

!!$ call elp%set( "omp_threads", omp_get_max_threads(), info )

  info = elp%setup()
  if( info /= ELPA_OK ) call cq_abort("something wrong in ELPA !")

contains

  ! Set one integer ELPA parameter on `elp`; abort, naming the parameter,
  ! if ELPA reports an error.  Writes the host variable `info`.
  subroutine set_int( key, val )
    character(len=*), intent(in) :: key
    integer,          intent(in) :: val

    call elp%set( key, val, info )
    if( info /= ELPA_OK ) &
         call cq_abort("ELPA_Init: Could not set parameter " // key)
  end subroutine set_int

  ! Return `str` with the ASCII letters a-z replaced by A-Z.
  pure function to_upper( str ) result( upper )
    character(len=*), intent(in) :: str
    character(len=len(str))      :: upper
    integer :: i, code

    upper = str
    do i = 1, len(str)
       code = iachar(str(i:i))
       if( code >= iachar('a') .and. code <= iachar('z') ) &
            upper(i:i) = achar(code - iachar('a') + iachar('A'))
    end do
  end function to_upper

end subroutine init_ELPA

! Tear down the ELPA state: release the object held in the module-level
! pointer `elp`, then un-initialise the ELPA library.
!
! Arguments
!   info : status returned by the last ELPA call made here; any value
!          other than ELPA_OK is passed to cq_abort with a message saying
!          which of the two steps failed.
subroutine end_ELPA( info )

  implicit none

  ! Passed variables
  integer, intent(out) :: info

  ! Step 1: free the solver object
  call elpa_deallocate( elp, info )
  if( info /= ELPA_OK ) then
     call cq_abort("end_ELPA: deallocation error")
  end if

  ! Step 2: shut the library down
  call elpa_uninit( info )
  if( info /= ELPA_OK ) then
     call cq_abort("end_ELPA: uninit error")
  end if

end subroutine end_ELPA

! Solve the generalised Hermitian eigenproblem for the pair (Hmat, Smat)
! through the module-level ELPA handle `elp`.
!
! Arguments
!   mode        : 'N' (or 'n') - eigenvalues only, Zmat is not touched
!                 'V' (or 'v') - eigenvalues and eigenvectors
!   matrix_size : length of Wvec
!   row_size,
!   col_size    : dimensions of the Hmat, Smat and Zmat arrays
!   Hmat, Smat  : matrices handed to ELPA (declared inout)
!   Wvec        : eigenvalues
!   Zmat        : eigenvectors (mode 'V' only)
!   info        : status returned by ELPA; set to -1 for an invalid mode
!
! The logical passed to ELPA is always .false.; NOTE(review): this looks
! like ELPA's "is_already_decomposed" flag - confirm against the ELPA
! documentation.
subroutine ELPA_zhegv( mode, matrix_size, row_size, col_size, &
     Hmat, Smat, Wvec, Zmat, info )

  implicit none

  ! Passed variables
  character(len=1),     intent(in)    :: mode
  integer,              intent(in)    :: matrix_size, row_size, col_size
  complex(double_cplx), intent(inout) :: Hmat(row_size,col_size)
  complex(double_cplx), intent(inout) :: Smat(row_size,col_size)
  real(double),         intent(out)   :: Wvec(matrix_size)
  complex(double_cplx), intent(out)   :: Zmat(row_size,col_size)
  integer,              intent(out)   :: info

  select case( mode )
  case( 'N', 'n' )
     call elp%generalized_eigenvalues( &
          Hmat, Smat, Wvec, .false., info )
  case( 'V', 'v' )
     call elp%generalized_eigenvectors( &
          Hmat, Smat, Wvec, Zmat, .false., info )
  case default
     ! Never leave the intent(out) status undefined: the caller tests it
     info = -1
     call cq_abort("ELPA_zhegv: mode must be 'N' or 'V', got " // mode)
  end select

end subroutine ELPA_zhegv

end module ELPA_module
67 changes: 67 additions & 0 deletions src/ELPAModuleDUMMY.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
module ELPA_module

use datatypes
use mpi
use GenComms, ONLY: cq_abort, myid
implicit none

logical :: flag_elpa_dummy = .true. ! A marker to show no ELPA in compilation
logical :: flag_use_elpa = .false. ! This should be false for ELPAModuleDummy
character(len=16) :: elpa_solver = "ELPA1" ! ELPA1 or ELPA2
character(len=16) :: elpa_kernel = "GENERIC"
integer :: elpa_API = 20241105
integer :: merow, mecol

private
public :: flag_use_elpa, elpa_solver, elpa_kernel, elpa_API, flag_elpa_dummy
public :: init_ELPA, end_ELPA, ELPA_zhegv

contains

! Placeholder for init_ELPA in builds without ELPA.  It must never be
! reached in a correct run, so it always calls cq_abort; the message
! depends on whether flag_use_elpa is set.
!
! Arguments mirror the real init_ELPA so that callers compile unchanged;
! none of the inputs is read.  `info` is set to -1 so that the intent(out)
! argument is defined on every path.
subroutine init_ELPA (matrix_size, row_size, col_size, desc, info)

  implicit none

  ! Passed variables
  integer, intent(in)  :: matrix_size, row_size, col_size
  integer, intent(in)  :: desc(9)
  integer, intent(out) :: info

  info = -1

  if(flag_use_elpa) then
     call cq_abort("init_ELPA: init_ELPA is called though ELPA_module is not compiled")
  else
     call cq_abort("init_ELPA: init_ELPA is called even though use_elpa is false")
  endif

  return
end subroutine init_ELPA

! Placeholder for end_ELPA in builds without ELPA: always calls cq_abort.
!
! Arguments
!   info : set to -1 so that the intent(out) argument is defined before
!          cq_abort is called.
subroutine end_ELPA( info )

  implicit none

  ! Passed variables
  integer, intent(out) :: info

  info = -1
  call cq_abort("end_ELPA: end_ELPA should not be called")

  return
end subroutine end_ELPA

! Placeholder for ELPA_zhegv in builds without ELPA: always calls cq_abort.
!
! The argument list mirrors the real ELPA_zhegv so that callers compile
! unchanged.  None of the matrices is read or written; only `info` is
! assigned (to -1) so that the status is defined before cq_abort is called.
subroutine ELPA_zhegv( mode, matrix_size, row_size, col_size, &
     Hmat, Smat, Wvec, Zmat, info )

  implicit none

  ! Passed variables
  character(len=1),     intent(in)    :: mode
  integer,              intent(in)    :: matrix_size, row_size, col_size
  complex(double_cplx), intent(inout) :: Hmat(row_size,col_size)
  complex(double_cplx), intent(inout) :: Smat(row_size,col_size)
  real(double),         intent(out)   :: Wvec(matrix_size)
  complex(double_cplx), intent(out)   :: Zmat(row_size,col_size)
  integer,              intent(out)   :: info

  info = -1
  ! Message prefix now names this routine (it previously read "ELPA_zhev")
  call cq_abort("ELPA_zhegv: CONQUEST should be compiled with ELPA")

  return
end subroutine ELPA_zhegv

end module ELPA_module
1 change: 1 addition & 0 deletions src/energy.obj
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ ENERGY_OBJS = H_matrix_module.o \
PosTan_module.o \
DMMinModule.o \
DiagModule${DIAG_DUMMY}.o \
ELPAModule${ELPA_DUMMY}.o \
ScalapackFormat.o \
blip_minimisation.module.o \
blip_gradient.module.o \
Expand Down
29 changes: 29 additions & 0 deletions src/initial_read_module.f90
Original file line number Diff line number Diff line change
Expand Up @@ -3064,6 +3064,7 @@ subroutine readDiagInfo
type_dbl
use species_module, only: nsf_species
use units, only: en_conv, en_units, energy_units
use ELPA_module, only: flag_use_elpa, elpa_solver, elpa_kernel, elpa_API, flag_elpa_dummy

implicit none

Expand Down Expand Up @@ -3212,6 +3213,34 @@ subroutine readDiagInfo
write(io_lun,2) block_size_r, block_size_c
write(io_lun,3) proc_rows, proc_cols
end if

!Using ELPA or not
flag_use_elpa = fdf_boolean('Diag.UseELPA',.false.)

if( flag_use_elpa ) then
if(flag_elpa_dummy) call cq_abort("Code compiled without ELPA! Set Diag.UseELPA F")
elpa_API = fdf_integer('Diag.ELPA_API',20181113)
elpa_solver = fdf_string(16,'Diag.ELPASolver','ELPA1')
if(leqi(elpa_solver,'ELPA1')) then
elpa_kernel = "NONE"
else if(leqi(elpa_solver,'ELPA2')) then
elpa_kernel = fdf_string(16,'Diag.ELPA2Kernel','GENERIC')
! Check for a valid kernel
if(.not.(leqi(elpa_kernel,'GENERIC').OR.leqi(elpa_kernel,"GENERIC_SIMPLE") &
.OR.leqi(elpa_kernel,"SSE_ASSEMBLY").OR.leqi(elpa_kernel,"SSE_BLOCK1") &
.OR.leqi(elpa_kernel,"SSE_BLOCK2").OR.leqi(elpa_kernel,"AVX_BLOCK1") &
.OR.leqi(elpa_kernel,"AVX_BLOCK2").OR.leqi(elpa_kernel,"AVX2_BLOCK1") &
.OR.leqi(elpa_kernel,"AVX2_BLOCK2"))) then
call cq_abort("Invalid Diag.ELPA2Kernel " // elpa_kernel )
endif
else
call cq_abort("Invalid Diag.ELPASolver " // elpa_solver )
endif
else
elpa_solver = "NONE"
elpa_kernel = "NONE"
end if

! Read k-point mesh type
mp_mesh = fdf_boolean('Diag.MPMesh',.false.)
if(.NOT.mp_mesh) then
Expand Down
12 changes: 10 additions & 2 deletions src/system/system.archer2.make
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,17 @@ XC_COMPFLAGS =
FFT_LIB=-L$(FFTW_ROOT)/lib -lfftw3
FFT_OBJ=fft_fftw3.o

LIBS= $(FFT_LIB) $(XC_LIB) $(BLAS)
# Set ELPA library
#ELPA_LIB = -L/**/lib -lelpa
#ELPA_INC = -I/**/modules/
ELPA_LIB =
ELPA_INC =

LIBS= $(FFT_LIB) $(ELPA_LIB) $(XC_LIB) $(BLAS)

# Compilation flags
# NB for gcc10 you need to add -fallow-argument-mismatch
COMPFLAGS= -O3 -fallow-argument-mismatch -fopenmp $(XC_COMPFLAGS)
COMPFLAGS= -O3 -fallow-argument-mismatch -fopenmp $(XC_COMPFLAGS) $(ELPA_INC)

# Linking flags
LINKFLAGS= -fopenmp -L$(LIBSCI_BASE_DIR)/gnu/9.1/x86_64/lib -lsci_gnu_mpi -lsci_gnu
Expand All @@ -55,3 +61,5 @@ LINKFLAGS= -fopenmp -L$(LIBSCI_BASE_DIR)/gnu/9.1/x86_64/lib -lsci_gnu_mpi -lsci
MULT_KERN = ompGemm_m
# Use dummy DiagModule or not
DIAG_DUMMY =
# Use dummy ELPAModule or not
ELPA_DUMMY =DUMMY
13 changes: 10 additions & 3 deletions src/system/system.cosma.make
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ ARFLAGS=

# Compilation flags
# NB for gcc10 you need to add -fallow-argument-mismatch
COMPFLAGS= -g -fopenmp -O3 $(XC_COMPFLAGS) -I${MKLROOT}/include/intel64/lp64 -I"${MKLROOT}/include" -fno-omit-frame-pointer -xHost
COMPFLAGS= -g -fopenmp -O3 $(XC_COMPFLAGS) $(ELPA_INC) -I${MKLROOT}/include/intel64/lp64 -I"${MKLROOT}/include" -fno-omit-frame-pointer -xHost
COMPFLAGS_F77= $(COMPFLAGS)

# LibXC compatibility (LibXC below) or Conquest XC library
Expand All @@ -29,12 +29,19 @@ XC_COMPFLAGS =
FFT_LIB=-lmkl_rt
FFT_OBJ=fft_fftw3.o

# Set ELPA library
#ELPA_LIB = -L/**/lib -lelpa
#ELPA_INC = -I/**/modules/
ELPA_LIB =
ELPA_INC =

# Full library call; remove scalapack if using dummy diag module
LIBS= $(FFT_LIB) $(XC_LIB)
LIBS= $(FFT_LIB) $(ELPA_LIB) $(XC_LIB)

# Matrix multiplication kernel type
MULT_KERN = default
# Use dummy DiagModule or not
DIAG_DUMMY =


# Use dummy ELPAModule or not
ELPA_DUMMY =DUMMY
Loading

0 comments on commit cb68516

Please sign in to comment.