!--------------------------------------------------------------------------------------------------!
!   CP2K: A general program to perform molecular dynamics simulations                              !
!   Copyright 2000-2023 CP2K developers group <https://cp2k.org>                                   !
!                                                                                                  !
!   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
!--------------------------------------------------------------------------------------------------!

! **************************************************************************************************
!> \brief Routine for the real time propagation output.
!> \author Florian Schiffmann (02.09)
! **************************************************************************************************

MODULE rt_propagation_output
   USE atomic_kind_types,               ONLY: atomic_kind_type
   USE cp_control_types,                ONLY: dft_control_type
   USE cp_dbcsr_operations,             ONLY: cp_dbcsr_sm_fm_multiply,&
                                              dbcsr_allocate_matrix_set,&
                                              dbcsr_deallocate_matrix_set
   USE cp_fm_basic_linalg,              ONLY: cp_fm_scale_and_add
   USE cp_fm_struct,                    ONLY: cp_fm_struct_create,&
                                              cp_fm_struct_double,&
                                              cp_fm_struct_release,&
                                              cp_fm_struct_type
   USE cp_fm_types,                     ONLY: cp_fm_create,&
                                              cp_fm_get_info,&
                                              cp_fm_release,&
                                              cp_fm_set_all,&
                                              cp_fm_type
   USE cp_log_handling,                 ONLY: cp_get_default_logger,&
                                              cp_logger_get_default_io_unit,&
                                              cp_logger_get_default_unit_nr,&
                                              cp_logger_type
   USE cp_output_handling,              ONLY: cp_iter_string,&
                                              cp_p_file,&
                                              cp_print_key_finished_output,&
                                              cp_print_key_should_output,&
                                              cp_print_key_unit_nr
   USE cp_realspace_grid_cube,          ONLY: cp_pw_to_cube
   USE dbcsr_api,                       ONLY: &
        dbcsr_add, dbcsr_binary_write, dbcsr_checksum, dbcsr_copy, dbcsr_create, &
        dbcsr_deallocate_matrix, dbcsr_desymmetrize, dbcsr_filter, dbcsr_get_info, &
        dbcsr_get_occupation, dbcsr_init_p, dbcsr_iterator_blocks_left, dbcsr_iterator_next_block, &
        dbcsr_iterator_start, dbcsr_iterator_stop, dbcsr_iterator_type, dbcsr_p_type, dbcsr_scale, &
        dbcsr_set, dbcsr_type
   USE efield_utils,                    ONLY: make_field
   USE input_constants,                 ONLY: ehrenfest,&
                                              real_time_propagation
   USE input_section_types,             ONLY: section_get_ivals,&
                                              section_vals_get_subs_vals,&
                                              section_vals_type
   USE kahan_sum,                       ONLY: accurate_sum
   USE kinds,                           ONLY: default_path_length,&
                                              dp
   USE machine,                         ONLY: m_flush
   USE message_passing,                 ONLY: mp_comm_type
   USE parallel_gemm_api,               ONLY: parallel_gemm
   USE particle_list_types,             ONLY: particle_list_type
   USE particle_types,                  ONLY: particle_type
   USE pw_env_types,                    ONLY: pw_env_get,&
                                              pw_env_type
   USE pw_methods,                      ONLY: pw_zero
   USE pw_pool_types,                   ONLY: pw_pool_create_pw,&
                                              pw_pool_give_back_pw,&
                                              pw_pool_type
   USE pw_types,                        ONLY: COMPLEXDATA1D,&
                                              REALDATA3D,&
                                              REALSPACE,&
                                              RECIPROCALSPACE,&
                                              pw_type
   USE qs_environment_types,            ONLY: get_qs_env,&
                                              qs_environment_type
   USE qs_kind_types,                   ONLY: get_qs_kind_set,&
                                              qs_kind_type
   USE qs_linres_current,               ONLY: calculate_jrho_resp
   USE qs_linres_types,                 ONLY: current_env_type
   USE qs_mo_io,                        ONLY: write_rt_mos_to_restart
   USE qs_rho_types,                    ONLY: qs_rho_get,&
                                              qs_rho_type
   USE qs_scf_post_gpw,                 ONLY: qs_scf_post_moments,&
                                              write_mo_dependent_results,&
                                              write_mo_free_results
   USE qs_scf_post_tb,                  ONLY: scf_post_calculation_tb
   USE qs_scf_types,                    ONLY: qs_scf_env_type
   USE qs_subsys_types,                 ONLY: qs_subsys_get,&
                                              qs_subsys_type
   USE rt_projection_mo_utils,          ONLY: compute_and_write_proj_mo
   USE rt_propagation_types,            ONLY: get_rtp,&
                                              rt_prop_type
   USE rt_propagation_utils,            ONLY: calculate_P_imaginary,&
                                              write_rtp_mo_cubes,&
                                              write_rtp_mos_to_output_unit
#include "../base/base_uses.f90"

   IMPLICIT NONE

   PRIVATE

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'rt_propagation_output'

   PUBLIC :: rt_prop_output, &
             rt_convergence, &
             rt_convergence_density, &
             report_density_occupation

CONTAINS

! **************************************************************************************************
!> \brief Per-iteration output of a real-time propagation (RTP) or Ehrenfest (EMD) step.
!>        Writes density, energy and convergence information to the
!>        REAL_TIME_PROPAGATION%PRINT%PROGRAM_RUN_INFO print key. Once rtp%converged is set it also
!>        prints the orthonormality deviation (non linear-scaling only), the applied field, the
!>        moments, MO-dependent results, restart files, currents and MO projections, as enabled.
!>        Aborts if the step is not converged after MAX_ITER iterations.
!> \param qs_env the QS environment
!> \param run_type ehrenfest or real_time_propagation; selects the energy-difference line that is
!>        printed and whether the propagation time is printed
!> \param delta_iter optional convergence measure of the current iteration, printed when present
!> \param used_time optional time printed as "Time needed for propagation" (RTP runs, converged
!>        step only); skipped when absent
! **************************************************************************************************
   SUBROUTINE rt_prop_output(qs_env, run_type, delta_iter, used_time)
      TYPE(qs_environment_type), POINTER                 :: qs_env
      INTEGER, INTENT(in)                                :: run_type
      REAL(dp), INTENT(in), OPTIONAL                     :: delta_iter, used_time

      INTEGER                                            :: n_electrons, n_proj, nspin, output_unit, &
                                                            spin
      REAL(dp)                                           :: orthonormality, tot_rho_r
      REAL(KIND=dp), DIMENSION(:), POINTER               :: qs_tot_rho_r
      TYPE(atomic_kind_type), DIMENSION(:), POINTER      :: atomic_kind_set
      TYPE(cp_fm_type), DIMENSION(:), POINTER            :: mos_new
      TYPE(cp_logger_type), POINTER                      :: logger
      TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: matrix_s, P_im, rho_new
      TYPE(dft_control_type), POINTER                    :: dft_control
      TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
      TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
      TYPE(qs_rho_type), POINTER                         :: rho
      TYPE(rt_prop_type), POINTER                        :: rtp
      TYPE(section_vals_type), POINTER                   :: dft_section, input, rtp_section

      NULLIFY (logger, dft_control)

      logger => cp_get_default_logger()
      CALL get_qs_env(qs_env, &
                      rtp=rtp, &
                      matrix_s=matrix_s, &
                      input=input, &
                      rho=rho, &
                      particle_set=particle_set, &
                      atomic_kind_set=atomic_kind_set, &
                      qs_kind_set=qs_kind_set, &
                      dft_control=dft_control)

      rtp_section => section_vals_get_subs_vals(input, "DFT%REAL_TIME_PROPAGATION")

      ! number of electrons of the (possibly charged) system
      CALL get_qs_kind_set(qs_kind_set, nelectron=n_electrons)
      n_electrons = n_electrons - dft_control%charge

      CALL qs_rho_get(rho_struct=rho, tot_rho_r=qs_tot_rho_r)

      ! compensated summation over the per-spin integrated densities
      tot_rho_r = accurate_sum(qs_tot_rho_r)

      output_unit = cp_print_key_unit_nr(logger, rtp_section, "PRINT%PROGRAM_RUN_INFO", &
                                         extension=".scfLog")

      IF (output_unit > 0) THEN
         WRITE (output_unit, FMT="(/,(T3,A,T40,I5))") &
            "Information at iteration step:", rtp%iter
         WRITE (UNIT=output_unit, FMT="((T3,A,T41,2F20.10))") &
            "Total electronic density (r-space): ", &
            tot_rho_r, &
            tot_rho_r + &
            REAL(n_electrons, dp)
         WRITE (UNIT=output_unit, FMT="((T3,A,T59,F22.14))") &
            "Total energy:", rtp%energy_new
         IF (run_type == ehrenfest) &
            WRITE (UNIT=output_unit, FMT="((T3,A,T61,F20.14))") &
            "Energy difference to previous iteration step:", rtp%energy_new - rtp%energy_old
         ! NOTE(review): same expression as the EMD line above, and energy_old is overwritten with
         ! energy_new at the end of this routine; confirm energy_old really holds the initial-state
         ! energy for RTP runs.
         IF (run_type == real_time_propagation) &
            WRITE (UNIT=output_unit, FMT="((T3,A,T61,F20.14))") &
            "Energy difference to initial state:", rtp%energy_new - rtp%energy_old
         IF (PRESENT(delta_iter)) &
            WRITE (UNIT=output_unit, FMT="((T3,A,T61,E20.6))") &
            "Convergence:", delta_iter
         IF (rtp%converged) THEN
            ! used_time is OPTIONAL: referencing it when absent is not allowed
            IF (run_type == real_time_propagation .AND. PRESENT(used_time)) &
               WRITE (UNIT=output_unit, FMT="((T3,A,T61,F12.2))") &
               "Time needed for propagation:", used_time
            WRITE (UNIT=output_unit, FMT="(/,(T3,A,3X,F16.14))") &
               "CONVERGENCE REACHED", rtp%energy_new - rtp%energy_old
         END IF
      END IF

      ! orthonormality check of the MO coefficients (only available for MO-based propagation)
      IF (rtp%converged) THEN
         IF (.NOT. rtp%linear_scaling) THEN
            CALL get_rtp(rtp=rtp, mos_new=mos_new)
            CALL rt_calculate_orthonormality(orthonormality, &
                                             mos_new, matrix_s(1)%matrix)
            IF (output_unit > 0) &
               WRITE (output_unit, FMT="(/,(T3,A,T60,F20.10))") &
               "Max deviation from orthonormalization:", orthonormality
         END IF
      END IF

      IF (output_unit > 0) &
         CALL m_flush(output_unit)
      CALL cp_print_key_finished_output(output_unit, logger, rtp_section, &
                                        "PRINT%PROGRAM_RUN_INFO")

      IF (rtp%converged) THEN
         dft_section => section_vals_get_subs_vals(input, "DFT")
         IF (BTEST(cp_print_key_should_output(logger%iter_info, &
                                              dft_section, "REAL_TIME_PROPAGATION%PRINT%FIELD"), cp_p_file)) &
            CALL print_field_applied(qs_env, dft_section)
         CALL make_moment(qs_env)
         IF (.NOT. dft_control%qs_control%dftb) THEN
            CALL write_available_results(qs_env=qs_env, rtp=rtp)
         END IF
         IF (rtp%linear_scaling) THEN
            ! density-matrix based propagation: rho_new holds (real, imaginary) pairs per spin
            CALL get_rtp(rtp=rtp, rho_new=rho_new)
            IF (BTEST(cp_print_key_should_output(logger%iter_info, &
                                                 dft_section, "REAL_TIME_PROPAGATION%PRINT%RESTART"), cp_p_file)) THEN
               CALL write_rt_p_to_restart(rho_new, .FALSE.)
            END IF
            IF (BTEST(cp_print_key_should_output(logger%iter_info, &
                                                 dft_section, "REAL_TIME_PROPAGATION%PRINT%RESTART_HISTORY"), cp_p_file)) THEN
               CALL write_rt_p_to_restart(rho_new, .TRUE.)
            END IF
            IF (.NOT. dft_control%qs_control%dftb) THEN
               !Not sure if these things could also work with dftb or not
               IF (BTEST(cp_print_key_should_output(logger%iter_info, &
                                                    dft_section, "REAL_TIME_PROPAGATION%PRINT%CURRENT"), cp_p_file)) THEN
                  DO spin = 1, SIZE(rho_new)/2
                     CALL rt_current(qs_env, rho_new(2*spin)%matrix, dft_section, spin, SIZE(rho_new)/2)
                  END DO
               END IF
            END IF
         ELSE
            ! MO-based propagation: mos_new holds (real, imaginary) pairs per spin
            CALL get_rtp(rtp=rtp, mos_new=mos_new)
            IF (.NOT. dft_control%qs_control%dftb) THEN
               IF (BTEST(cp_print_key_should_output(logger%iter_info, &
                                                    dft_section, "REAL_TIME_PROPAGATION%PRINT%CURRENT"), cp_p_file)) THEN
                  ! build the imaginary part of the density matrix for the current calculation
                  NULLIFY (P_im)
                  nspin = SIZE(mos_new)/2
                  CALL dbcsr_allocate_matrix_set(P_im, nspin)
                  DO spin = 1, nspin
                     CALL dbcsr_init_p(P_im(spin)%matrix)
                     CALL dbcsr_create(P_im(spin)%matrix, template=matrix_s(1)%matrix, matrix_type="N")
                  END DO
                  CALL calculate_P_imaginary(qs_env, rtp, P_im)
                  DO spin = 1, nspin
                     CALL rt_current(qs_env, P_im(spin)%matrix, dft_section, spin, nspin)
                  END DO
                  CALL dbcsr_deallocate_matrix_set(P_im)
               END IF
               IF (dft_control%rtp_control%is_proj_mo) THEN
                  DO n_proj = 1, SIZE(dft_control%rtp_control%proj_mo_list)
                     CALL compute_and_write_proj_mo(qs_env, mos_new, &
                                                    dft_control%rtp_control%proj_mo_list(n_proj)%proj_mo, n_proj)
                  END DO
               END IF
            END IF
            CALL write_rt_mos_to_restart(qs_env%mos, mos_new, particle_set, &
                                         dft_section, qs_kind_set)
         END IF
      END IF

      rtp%energy_old = rtp%energy_new

      IF (.NOT. rtp%converged .AND. rtp%iter >= dft_control%rtp_control%max_iter) &
         CALL cp_abort(__LOCATION__, "EMD did not converge, either increase MAX_ITER "// &
                       "or use a smaller TIMESTEP")

   END SUBROUTINE rt_prop_output

! **************************************************************************************************
!> \brief Computes the deviation from orthonormality of a set of complex MOs.
!>        For every spin the real part of C^H S C, i.e. C_re^T S C_re + C_im^T S C_im, is formed and
!>        the largest absolute deviation of any element from the identity is taken, maximised over
!>        all spins and all MPI ranks. The imaginary part of C^H S C is not checked.
!> \param orthonormality max_ij |Re(C^H S C)_ij - delta_ij|
!> \param mos_new MO coefficients stored as (real, imaginary) pairs at 2*ispin-1 and 2*ispin
!> \param matrix_s overlap matrix S; declared OPTIONAL but always referenced, so it must be present
!> \author Florian Schiffmann (02.09)
! **************************************************************************************************
   SUBROUTINE rt_calculate_orthonormality(orthonormality, mos_new, matrix_s)
      REAL(KIND=dp), INTENT(out)                         :: orthonormality
      TYPE(cp_fm_type), DIMENSION(:), POINTER            :: mos_new
      TYPE(dbcsr_type), OPTIONAL, POINTER                :: matrix_s

      CHARACTER(len=*), PARAMETER :: routineN = 'rt_calculate_orthonormality'

      INTEGER                                            :: handle, i, im, ispin, j, k, n, &
                                                            ncol_local, nrow_local, nspin, re
      INTEGER, DIMENSION(:), POINTER                     :: col_indices, row_indices
      REAL(KIND=dp)                                      :: alpha, max_alpha
      TYPE(cp_fm_struct_type), POINTER                   :: tmp_fm_struct
      TYPE(cp_fm_type)                                   :: overlap_re, svec_im, svec_re

      NULLIFY (tmp_fm_struct)

      CALL timeset(routineN, handle)

      nspin = SIZE(mos_new)/2
      max_alpha = 0.0_dp
      DO ispin = 1, nspin
         re = ispin*2 - 1
         im = ispin*2
         ! get S*C for the real and the imaginary part
         CALL cp_fm_create(svec_re, mos_new(im)%matrix_struct)
         CALL cp_fm_create(svec_im, mos_new(im)%matrix_struct)
         CALL cp_fm_get_info(mos_new(im), &
                             nrow_global=n, ncol_global=k)
         CALL cp_dbcsr_sm_fm_multiply(matrix_s, mos_new(re), &
                                      svec_re, k)
         CALL cp_dbcsr_sm_fm_multiply(matrix_s, mos_new(im), &
                                      svec_im, k)

         ! overlap_re = C_re^T (S*C_re) + C_im^T (S*C_im)  (k x k)
         CALL cp_fm_struct_create(tmp_fm_struct, nrow_global=k, ncol_global=k, &
                                  para_env=mos_new(re)%matrix_struct%para_env, &
                                  context=mos_new(re)%matrix_struct%context)
         CALL cp_fm_create(overlap_re, tmp_fm_struct)

         CALL cp_fm_struct_release(tmp_fm_struct)

         CALL parallel_gemm('T', 'N', k, k, n, 1.0_dp, mos_new(re), &
                            svec_re, 0.0_dp, overlap_re)
         CALL parallel_gemm('T', 'N', k, k, n, 1.0_dp, mos_new(im), &
                            svec_im, 1.0_dp, overlap_re)

         CALL cp_fm_release(svec_re)
         CALL cp_fm_release(svec_im)

         ! local maximum of |overlap_re - 1| using global indices to locate the diagonal
         CALL cp_fm_get_info(overlap_re, nrow_local=nrow_local, ncol_local=ncol_local, &
                             row_indices=row_indices, col_indices=col_indices)
         DO i = 1, nrow_local
            DO j = 1, ncol_local
               alpha = overlap_re%local_data(i, j)
               IF (row_indices(i) .EQ. col_indices(j)) alpha = alpha - 1.0_dp
               max_alpha = MAX(max_alpha, ABS(alpha))
            END DO
         END DO
         CALL cp_fm_release(overlap_re)
      END DO
      ! a single global reduction: max_alpha is the only quantity needed across ranks
      CALL mos_new(1)%matrix_struct%para_env%max(max_alpha)
      orthonormality = max_alpha

      CALL timestop(handle)

   END SUBROUTINE rt_calculate_orthonormality

! **************************************************************************************************
!> \brief computes the convergence criterion for RTP and EMD
!>        The complex MO change D (stored as real/imaginary pairs) is formed in place in delta_mos.
!>        For every spin the complex matrix D^H S D is then assembled from its real and imaginary
!>        blocks. delta_eps is the square root of the largest modulus |(D^H S D)_ij| over all spins
!>        and MPI ranks.
!> \param rtp propagation environment; only mos_new is read from it
!> \param matrix_s Overlap matrix without the derivatives
!> \param delta_mos on entry: MOs with the same layout as mos_new (presumably those of the previous
!>        iteration; confirm at the caller). On return it is overwritten with mos_new - delta_mos,
!>        assuming cp_fm_scale_and_add(alpha, A, beta, B) computes A <- alpha*A + beta*B.
!>        The dummy is declared INTENT(IN), but its matrix data is still modified.
!> \param delta_eps resulting convergence measure
!> \author Florian Schiffmann (02.09)
! **************************************************************************************************

   SUBROUTINE rt_convergence(rtp, matrix_s, delta_mos, delta_eps)
      TYPE(rt_prop_type), POINTER                        :: rtp
      TYPE(dbcsr_type), POINTER                          :: matrix_s
      TYPE(cp_fm_type), DIMENSION(:), INTENT(IN)         :: delta_mos
      REAL(dp), INTENT(out)                              :: delta_eps

      CHARACTER(len=*), PARAMETER                        :: routineN = 'rt_convergence'
      REAL(KIND=dp), PARAMETER                           :: one = 1.0_dp, zero = 0.0_dp

      INTEGER                                            :: handle, i, icol, im, ispin, j, lcol, &
                                                            lrow, nao, newdim, nmo, nspin, re
      LOGICAL                                            :: double_col, double_row
      REAL(KIND=dp)                                      :: alpha, max_alpha
      TYPE(cp_fm_struct_type), POINTER                   :: newstruct, newstruct1, tmp_fm_struct
      TYPE(cp_fm_type)                                   :: work, work1, work2
      TYPE(cp_fm_type), DIMENSION(:), POINTER            :: mos_new

      NULLIFY (tmp_fm_struct)

      CALL timeset(routineN, handle)

      CALL get_rtp(rtp=rtp, mos_new=mos_new)

      nspin = SIZE(delta_mos)/2
      max_alpha = 0.0_dp

      ! D = mos_new - delta_mos, stored back into delta_mos
      DO i = 1, SIZE(mos_new)
         CALL cp_fm_scale_and_add(-one, delta_mos(i), one, mos_new(i))
      END DO

      DO ispin = 1, nspin
         ! indices of the real and imaginary part of this spin
         re = ispin*2 - 1
         im = ispin*2

         ! work/work1: same rows as delta_mos(re), twice the columns (nao x 2*nmo)
         double_col = .TRUE.
         double_row = .FALSE.
         CALL cp_fm_struct_double(newstruct, &
                                  delta_mos(re)%matrix_struct, &
                                  delta_mos(re)%matrix_struct%context, &
                                  double_col, &
                                  double_row)

         CALL cp_fm_create(work, matrix_struct=newstruct)
         CALL cp_fm_create(work1, matrix_struct=newstruct)

         CALL cp_fm_get_info(delta_mos(re), ncol_local=lcol, ncol_global=nmo, &
                             nrow_global=nao)
         CALL cp_fm_get_info(work, ncol_global=newdim)

         CALL cp_fm_set_all(work, zero, zero)

         ! pack [D_re | D_im] side by side. This relies on the doubled struct storing the original
         ! local columns followed by their counterparts at local offset lcol.
         DO icol = 1, lcol
            work%local_data(:, icol) = delta_mos(re)%local_data(:, icol)
            work%local_data(:, icol + lcol) = delta_mos(im)%local_data(:, icol)
         END DO

         ! work1 = S * [D_re | D_im]
         CALL cp_dbcsr_sm_fm_multiply(matrix_s, work, work1, ncol=newdim)

         CALL cp_fm_release(work)

         ! result layout: nmo x nmo, doubled in columns (nmo x 2*nmo)
         CALL cp_fm_struct_create(tmp_fm_struct, nrow_global=nmo, ncol_global=nmo, &
                                  para_env=delta_mos(re)%matrix_struct%para_env, &
                                  context=delta_mos(re)%matrix_struct%context)
         CALL cp_fm_struct_double(newstruct1, &
                                  tmp_fm_struct, &
                                  delta_mos(re)%matrix_struct%context, &
                                  double_col, &
                                  double_row)

         CALL cp_fm_create(work, matrix_struct=newstruct1)
         CALL cp_fm_create(work2, matrix_struct=newstruct1)

         ! work  = D_re^T S [D_re | D_im]
         CALL parallel_gemm("T", "N", nmo, newdim, nao, one, delta_mos(re), &
                            work1, zero, work)

         ! work2 = D_im^T S [D_re | D_im]
         CALL parallel_gemm("T", "N", nmo, newdim, nao, one, delta_mos(im), &
                            work1, zero, work2)

         ! modulus of (D^H S D)_ij with
         !   real part: D_re^T S D_re + D_im^T S D_im = work(i,j) + work2(i,j+lcol)
         !   imag part: D_re^T S D_im - D_im^T S D_re = work(i,j+lcol) - work2(i,j)
         ! NOTE(review): lcol is the local column count of the nao x nmo matrix, reused here for
         ! the nmo x nmo half; this assumes both distributions give the same local column count.
         CALL cp_fm_get_info(work, nrow_local=lrow)
         DO i = 1, lrow
            DO j = 1, lcol
               alpha = SQRT((work%local_data(i, j) + work2%local_data(i, j + lcol))**2 + &
                            (work%local_data(i, j + lcol) - work2%local_data(i, j))**2)
               max_alpha = MAX(max_alpha, ABS(alpha))
            END DO
         END DO

         CALL cp_fm_release(work)
         CALL cp_fm_release(work1)
         CALL cp_fm_release(work2)
         CALL cp_fm_struct_release(tmp_fm_struct)
         CALL cp_fm_struct_release(newstruct)
         CALL cp_fm_struct_release(newstruct1)

      END DO

      ! global maximum over all ranks, then the square root
      CALL delta_mos(1)%matrix_struct%para_env%max(max_alpha)
      delta_eps = SQRT(max_alpha)

      CALL timestop(handle)

   END SUBROUTINE rt_convergence

! **************************************************************************************************
!> \brief Density-matrix based convergence measure for (linear-scaling) RTP and EMD.
!>        delta_P is overwritten with delta_P - rho_new, and every stored element is squared. For
!>        each spin the squared real part is added onto the squared imaginary part.
!>        delta_eps is the square root of the largest resulting stored element over all spins and
!>        MPI ranks.
!> \param rtp propagation environment; only rho_new is read from it
!> \param delta_P on entry: density matrices laid out like rho_new, as (real, imaginary) pairs at
!>        2*ispin-1 and 2*ispin. Their contents are destroyed.
!> \param delta_eps resulting convergence measure
!> \author Samuel Andermatt (02.14)
! **************************************************************************************************

   SUBROUTINE rt_convergence_density(rtp, delta_P, delta_eps)

      TYPE(rt_prop_type), POINTER                        :: rtp
      TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: delta_P
      REAL(dp), INTENT(out)                              :: delta_eps

      CHARACTER(len=*), PARAMETER :: routineN = 'rt_convergence_density'
      REAL(KIND=dp), PARAMETER                           :: one = 1.0_dp, zero = 0.0_dp

      INTEGER                                            :: blk_col, blk_row, comm_handle, handle, &
                                                            imat, ispin, npairs
      REAL(dp)                                           :: blk_max, largest
      REAL(dp), DIMENSION(:), POINTER                    :: blk
      TYPE(dbcsr_iterator_type)                          :: it
      TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: rho_new
      TYPE(dbcsr_type), POINTER                          :: re_full
      TYPE(mp_comm_type)                                 :: comm

      CALL timeset(routineN, handle)

      CALL get_rtp(rtp=rtp, rho_new=rho_new)
      npairs = SIZE(delta_P)/2

      ! delta_P <- delta_P - rho_new
      DO imat = 1, SIZE(rho_new)
         CALL dbcsr_add(delta_P(imat)%matrix, rho_new(imat)%matrix, one, -one)
      END DO

      ! square every stored element of every matrix
      DO imat = 1, SIZE(delta_P)
         CALL dbcsr_iterator_start(it, delta_P(imat)%matrix)
         DO WHILE (dbcsr_iterator_blocks_left(it))
            CALL dbcsr_iterator_next_block(it, blk_row, blk_col, blk)
            blk = blk*blk
         END DO
         CALL dbcsr_iterator_stop(it)
      END DO

      ! accumulate |Re|^2 onto |Im|^2 (full, non-symmetric copy of the real part)
      NULLIFY (re_full)
      ALLOCATE (re_full)
      CALL dbcsr_create(re_full, template=delta_P(1)%matrix, matrix_type="N")
      DO ispin = 1, npairs
         CALL dbcsr_desymmetrize(delta_P(2*ispin - 1)%matrix, re_full)
         CALL dbcsr_add(delta_P(2*ispin)%matrix, re_full, one, one)
      END DO

      ! local maximum over the stored blocks of the accumulated (even-index) matrices
      largest = zero
      DO ispin = 1, npairs
         CALL dbcsr_iterator_start(it, delta_P(2*ispin)%matrix)
         DO WHILE (dbcsr_iterator_blocks_left(it))
            CALL dbcsr_iterator_next_block(it, blk_row, blk_col, blk)
            blk_max = MAXVAL(blk)
            IF (blk_max > largest) largest = blk_max
         END DO
         CALL dbcsr_iterator_stop(it)
      END DO

      ! global maximum over the communicator of the distributed matrix
      CALL dbcsr_get_info(delta_P(1)%matrix, group=comm_handle)
      CALL comm%set_handle(comm_handle)
      CALL comm%max(largest)
      delta_eps = SQRT(largest)

      CALL dbcsr_deallocate_matrix(re_full)
      CALL timestop(handle)

   END SUBROUTINE rt_convergence_density

! **************************************************************************************************
!> \brief Interface to the moment printouts after a propagation step. DFTB and xTB runs are
!>        handled by scf_post_calculation_tb (DFTB takes precedence if both flags are set); all
!>        other methods use qs_scf_post_moments. Only works for non-periodic dipoles.
!> \param qs_env the QS environment
!> \author Florian Schiffmann (02.09)
! **************************************************************************************************

   SUBROUTINE make_moment(qs_env)

      TYPE(qs_environment_type), POINTER                 :: qs_env

      CHARACTER(len=*), PARAMETER                        :: routineN = 'make_moment'

      INTEGER                                            :: handle, iounit
      TYPE(cp_logger_type), POINTER                      :: logger
      TYPE(dft_control_type), POINTER                    :: dft_control

      CALL timeset(routineN, handle)

      NULLIFY (dft_control)
      CALL get_qs_env(qs_env, dft_control=dft_control)

      logger => cp_get_default_logger()
      iounit = cp_logger_get_default_io_unit(logger)

      IF (.NOT. (dft_control%qs_control%dftb .OR. dft_control%qs_control%xtb)) THEN
         ! regular QS methods
         CALL qs_scf_post_moments(qs_env%input, logger, qs_env, iounit)
      ELSE IF (dft_control%qs_control%dftb) THEN
         CALL scf_post_calculation_tb(qs_env, "DFTB", .FALSE.)
      ELSE
         CALL scf_post_calculation_tb(qs_env, "xTB", .FALSE.)
      END IF

      CALL timestop(handle)

   END SUBROUTINE make_moment

! **************************************************************************************************
!> \brief Reports the sparsity pattern of the complex density matrix.
!>        For every spin, the occupation of the real and the imaginary part is printed after
!>        filtering with the thresholds eps0, 10*eps0, 100*eps0, ... (while eps < 1.1), where
!>        eps0 = MAX(filter_eps, 1.0E-11).
!> \param filter_eps smallest reported filter threshold (clamped from below to 1.0E-11)
!> \param rho complex density matrices as (real, imaginary) pairs at 2*ispin-1 and 2*ispin.
!>        They are not modified; filtering is done on copies.
!> \author Samuel Andermatt (09.14)
! **************************************************************************************************

   SUBROUTINE report_density_occupation(filter_eps, rho)

      REAL(KIND=dp), INTENT(IN)                          :: filter_eps
      TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: rho

      CHARACTER(len=*), PARAMETER :: routineN = 'report_density_occupation'
      CHARACTER(LEN=7), DIMENSION(2), PARAMETER          :: part_label = [" real: ", " imag: "]
      REAL(KIND=dp), PARAMETER                           :: eps_floor = 1.0E-11_dp, &
                                                            eps_limit = 1.1_dp, eps_step = 10.0_dp

      INTEGER                                            :: handle, i, imat, ipart, ispin, unit_nr
      REAL(KIND=dp)                                      :: eps, occ
      TYPE(cp_logger_type), POINTER                      :: logger
      TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: tmp

      CALL timeset(routineN, handle)

      logger => cp_get_default_logger()
      unit_nr = cp_logger_get_default_io_unit(logger)

      ! working copies: they are filtered in place below, rho itself stays untouched
      NULLIFY (tmp)
      CALL dbcsr_allocate_matrix_set(tmp, SIZE(rho))
      DO i = 1, SIZE(rho)
         CALL dbcsr_init_p(tmp(i)%matrix)
         CALL dbcsr_create(tmp(i)%matrix, template=rho(i)%matrix)
         CALL dbcsr_copy(tmp(i)%matrix, rho(i)%matrix)
      END DO

      DO ispin = 1, SIZE(rho)/2
         ! ipart = 1: real part (index 2*ispin-1), ipart = 2: imaginary part (index 2*ispin)
         DO ipart = 1, 2
            imat = 2*ispin - 2 + ipart
            eps = MAX(filter_eps, eps_floor)
            ! the same copy is filtered repeatedly, so each threshold acts on the matrix already
            ! filtered with the previous (smaller) one
            DO WHILE (eps < eps_limit)
               CALL dbcsr_filter(tmp(imat)%matrix, eps)
               occ = dbcsr_get_occupation(tmp(imat)%matrix)
               IF (unit_nr > 0) WRITE (unit_nr, FMT="((T3,A,I1,A,F15.12,A,T61,F20.10))") &
                  "Occupation of rho spin ", ispin, " eps ", eps, part_label(ipart), occ
               eps = eps*eps_step
            END DO
         END DO
      END DO

      CALL dbcsr_deallocate_matrix_set(tmp)
      CALL timestop(handle)

   END SUBROUTINE report_density_occupation

! **************************************************************************************************
!> \brief Writes the real and imaginary parts of the density matrices to binary restart files,
!>        one file per spin and part: <project>_LS_DM_SPIN_{RE,IM}<ispin>[_<iter>]_RESTART.dm.
!>        The positional checksum of each matrix is printed on the source rank. No atomic
!>        positions are written by this routine.
!> \param rho_new density matrices as (real, imaginary) pairs at 2*ispin-1 and 2*ispin
!> \param history if .TRUE., the current iteration string is inserted into the file names
!>        (presumably so files from different steps do not overwrite each other)
!> \author Samuel Andermatt (09.14)
! **************************************************************************************************

   SUBROUTINE write_rt_p_to_restart(rho_new, history)

      TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: rho_new
      LOGICAL                                            :: history

      CHARACTER(LEN=*), PARAMETER :: routineN = 'write_rt_p_to_restart'
      CHARACTER(LEN=2), DIMENSION(2), PARAMETER          :: part_tag = ["RE", "IM"]

      CHARACTER(LEN=default_path_length)                 :: file_name, project_name
      INTEGER                                            :: handle, imat, ipart, ispin, unit_nr
      REAL(KIND=dp)                                      :: checksum_pos
      TYPE(cp_logger_type), POINTER                      :: logger

      CALL timeset(routineN, handle)
      logger => cp_get_default_logger()

      ! only the source rank reports the checksums
      unit_nr = -1
      IF (logger%para_env%is_source()) unit_nr = cp_logger_get_default_unit_nr(logger, local=.TRUE.)

      project_name = logger%iter_info%project_name

      DO ispin = 1, SIZE(rho_new)/2
         ! ipart = 1: real part (index 2*ispin-1), ipart = 2: imaginary part (index 2*ispin)
         DO ipart = 1, 2
            imat = 2*ispin - 2 + ipart
            IF (.NOT. history) THEN
               WRITE (file_name, '(A,I0,A)') &
                  TRIM(project_name)//"_LS_DM_SPIN_"//part_tag(ipart), ispin, "_RESTART.dm"
            ELSE
               WRITE (file_name, '(A,I0,A)') &
                  TRIM(project_name)//"_LS_DM_SPIN_"//part_tag(ipart), ispin, &
                  "_"//TRIM(cp_iter_string(logger%iter_info))//"_RESTART.dm"
            END IF
            checksum_pos = dbcsr_checksum(rho_new(imat)%matrix, pos=.TRUE.)
            IF (unit_nr > 0) &
               WRITE (unit_nr, '(T2,A,E20.8)') "Writing restart DM "//TRIM(file_name)//" with checksum: ", &
               checksum_pos
            CALL dbcsr_binary_write(rho_new(imat)%matrix, file_name)
         END DO
      END DO

      CALL timestop(handle)

   END SUBROUTINE write_rt_p_to_restart

! **************************************************************************************************
!> \brief Collocation of the current and printing of it in a cube file.
!>        For each Cartesian direction (x, y, z) the current is collocated on the
!>        real-space grid via calculate_jrho_resp and written to a separate cube
!>        file with extension "-SPIN-<spin>-<x|y|z>.cube".
!> \param qs_env ...
!> \param P_im density matrix (imaginary part, judging by the name) used as input
!>        to the current collocation; it is copied, never modified
!> \param dft_section DFT input section; REAL_TIME_PROPAGATION%PRINT%CURRENT is read from it
!> \param spin spin index, used only to build the output file extension
!> \param nspin number of spin channels; when nspin == 1 the copy of P_im is scaled by 0.5
!> \author Samuel Andermatt (06.15)
! **************************************************************************************************
   SUBROUTINE rt_current(qs_env, P_im, dft_section, spin, nspin)
      TYPE(qs_environment_type), POINTER                 :: qs_env
      TYPE(dbcsr_type), POINTER                          :: P_im
      TYPE(section_vals_type), POINTER                   :: dft_section
      INTEGER, INTENT(IN)                                :: spin, nspin

      CHARACTER(len=*), PARAMETER                        :: routineN = 'rt_current'

      CHARACTER(len=1)                                   :: char_spin
      CHARACTER(len=14)                                  :: ext
      CHARACTER(len=2)                                   :: sdir
      INTEGER                                            :: dir, handle, print_unit
      INTEGER, DIMENSION(:), POINTER                     :: stride
      LOGICAL                                            :: mpi_io
      TYPE(cp_logger_type), POINTER                      :: logger
      TYPE(current_env_type)                             :: current_env
      TYPE(dbcsr_type), POINTER                          :: tmp, zero
      TYPE(particle_list_type), POINTER                  :: particles
      TYPE(pw_env_type), POINTER                         :: pw_env
      TYPE(pw_pool_type), POINTER                        :: auxbas_pw_pool
      TYPE(pw_type)                                      :: gs, rs
      TYPE(qs_subsys_type), POINTER                      :: subsys

      CALL timeset(routineN, handle)

      logger => cp_get_default_logger()
      CALL get_qs_env(qs_env=qs_env, subsys=subsys, pw_env=pw_env)
      CALL qs_subsys_get(subsys, particles=particles)
      CALL pw_env_get(pw_env, auxbas_pw_pool=auxbas_pw_pool)

      ! zero: matrix with the sparsity of P_im but all entries set to 0, passed for
      ! the remaining density-matrix arguments of calculate_jrho_resp.
      ! tmp: working copy of P_im, so the caller's matrix is left untouched.
      NULLIFY (zero, tmp)
      ALLOCATE (zero, tmp)
      CALL dbcsr_create(zero, template=P_im)
      CALL dbcsr_copy(zero, P_im)
      CALL dbcsr_set(zero, 0.0_dp)
      CALL dbcsr_create(tmp, template=P_im)
      CALL dbcsr_copy(tmp, P_im)
      IF (nspin == 1) THEN
         ! NOTE(review): presumably the closed-shell matrix carries both spins; confirm
         CALL dbcsr_scale(tmp, 0.5_dp)
      END IF
      current_env%gauge = -1
      current_env%gauge_init = .FALSE.
      CALL pw_pool_create_pw(auxbas_pw_pool, rs, use_data=REALDATA3D, in_space=REALSPACE)
      CALL pw_pool_create_pw(auxbas_pw_pool, gs, use_data=COMPLEXDATA1D, in_space=RECIPROCALSPACE)

      ! Point directly at the STRIDE values stored in the input section. Copying them into a
      ! fixed-size buffer is a shape mismatch whenever the keyword holds a number of values
      ! other than 3 (NOTE(review): CP2K STRIDE keywords usually accept 1 or 3 values).
      ! The target is owned by the input section and must not be deallocated here.
      NULLIFY (stride)
      stride => section_get_ivals(dft_section, "REAL_TIME_PROPAGATION%PRINT%CURRENT%STRIDE")

      WRITE (char_spin, "(I1)") spin

      DO dir = 1, 3

         CALL pw_zero(rs)
         CALL pw_zero(gs)

         CALL calculate_jrho_resp(zero, tmp, zero, zero, dir, dir, rs, gs, qs_env, current_env, retain_rsgrid=.TRUE.)

         SELECT CASE (dir)
         CASE (1)
            sdir = "-x"
         CASE (2)
            sdir = "-y"
         CASE DEFAULT
            sdir = "-z"
         END SELECT

         ext = "-SPIN-"//char_spin//sdir//".cube"
         ! Re-request MPI I/O for every file: mpi_io is passed on to cp_print_key_unit_nr,
         ! which may change it (NOTE(review): confirm against its interface).
         mpi_io = .TRUE.
         print_unit = cp_print_key_unit_nr(logger, dft_section, "REAL_TIME_PROPAGATION%PRINT%CURRENT", &
                                           extension=ext, file_status="REPLACE", file_action="WRITE", &
                                           log_filename=.FALSE., mpi_io=mpi_io)

         CALL cp_pw_to_cube(rs, print_unit, "EMD current", particles=particles, stride=stride, &
                            mpi_io=mpi_io)

         CALL cp_print_key_finished_output(print_unit, logger, dft_section, "REAL_TIME_PROPAGATION%PRINT%CURRENT", &
                                           mpi_io=mpi_io)

      END DO

      CALL pw_pool_give_back_pw(auxbas_pw_pool, rs)
      CALL pw_pool_give_back_pw(auxbas_pw_pool, gs)

      CALL dbcsr_deallocate_matrix(zero)
      CALL dbcsr_deallocate_matrix(tmp)

      ! stride only aliases input-section data: drop the association, do not deallocate.
      NULLIFY (stride)

      CALL timestop(handle)

   END SUBROUTINE rt_current

! **************************************************************************************************
!> \brief Interface routine to trigger writing of results available from normal
!>        SCF. MO-free results are always written; MO-dependent results and the
!>        time-dependent MO prints are written only when the propagation is not
!>        linear scaling (needed for call from the linear scaling code).
!>        Update: trigger also some of prints for time-dependent runs
!> \param qs_env ...
!> \param rtp real-time propagation environment; its linear_scaling flag selects what is written
!> \par History
!>      2022-11 Update [Guillaume Le Breton]
! **************************************************************************************************
   SUBROUTINE write_available_results(qs_env, rtp)
      TYPE(qs_environment_type), POINTER                 :: qs_env
      TYPE(rt_prop_type), POINTER                        :: rtp

      CHARACTER(len=*), PARAMETER :: routineN = 'write_available_results'

      INTEGER                                            :: timing_handle
      TYPE(qs_scf_env_type), POINTER                     :: scf_env

      CALL timeset(routineN, timing_handle)

      CALL get_qs_env(qs_env, scf_env=scf_env)

      ! Results that do not need MOs are written for every propagation scheme.
      CALL write_mo_free_results(qs_env)

      ! MO-based output is skipped for linear-scaling propagation.
      IF (.NOT. rtp%linear_scaling) THEN
         CALL write_mo_dependent_results(qs_env, scf_env)
         ! Time-dependent MO print
         CALL write_rtp_mos_to_output_unit(qs_env, rtp)
         CALL write_rtp_mo_cubes(qs_env, rtp)
      END IF

      CALL timestop(timing_handle)

   END SUBROUTINE write_available_results

! **************************************************************************************************
!> \brief Print the field applied to the system. Either the electric
!>        field or the vector potential depending on the gauge used
!>        (selected by dft_control%apply_efield_field / apply_vector_potential).
!>        Output goes to the REAL_TIME_PROPAGATION%PRINT%FIELD print key; when
!>        that key opens a separate file, its name is announced on the default
!>        output unit.
!> \param qs_env ...
!> \param dft_section DFT input section that holds REAL_TIME_PROPAGATION%PRINT%FIELD
!> \par History
!>      2023-01  Created [Guillaume Le Breton]
! **************************************************************************************************
   SUBROUTINE print_field_applied(qs_env, dft_section)
      TYPE(qs_environment_type), POINTER                 :: qs_env
      TYPE(section_vals_type), POINTER                   :: dft_section

      CHARACTER(len=*), PARAMETER :: routineN = 'print_field_applied'
      CHARACTER(LEN=3), DIMENSION(3), PARAMETER          :: rlab = [CHARACTER(LEN=3) :: "X", "Y", "Z"]

      CHARACTER(LEN=default_path_length)                 :: filename
      INTEGER                                            :: handle, i, output_unit, unit_nr
      REAL(kind=dp)                                      :: field(3), to_write(3)
      TYPE(cp_logger_type), POINTER                      :: logger
      TYPE(dft_control_type), POINTER                    :: dft_control

      CALL timeset(routineN, handle)

      NULLIFY (dft_control)

      logger => cp_get_default_logger()
      output_unit = cp_logger_get_default_io_unit(logger)

      CALL get_qs_env(qs_env, dft_control=dft_control)

      unit_nr = cp_print_key_unit_nr(logger, dft_section, &
                                     "REAL_TIME_PROPAGATION%PRINT%FIELD", extension=".dat")

      ! Only write when the print key actually produced a unit on this rank. Guarding on
      ! output_unit alone would INQUIRE/WRITE on an invalid unit when the key is inactive.
      IF (unit_nr > 0) THEN
         IF (unit_nr /= output_unit) THEN
            ! Separate file: announce its name in the main output, then tag the step.
            IF (output_unit > 0) THEN
               INQUIRE (UNIT=unit_nr, NAME=filename)
               WRITE (UNIT=output_unit, FMT="(/,T2,A,2(/,T3,A),/)") &
                  "FIELD", "The field applied is written to the file:", &
                  TRIM(filename)
            END IF
            WRITE (UNIT=unit_nr, FMT="(/,(T2,A,T40,I6))") &
               "Real time propagation step:", qs_env%sim_step
         ELSE
            ! Printing into the main output: unit_nr == output_unit here.
            WRITE (UNIT=unit_nr, FMT="(/,T2,A)") "FIELD APPLIED"
         END IF

         IF (dft_control%apply_efield_field) THEN
            WRITE (unit_nr, "(T3,A)") "Electric Field (LG) in atomic units:"
            CALL make_field(dft_control, field, qs_env%sim_step, qs_env%sim_time)
            to_write = field
         ELSE IF (dft_control%apply_vector_potential) THEN
            WRITE (unit_nr, "(T3,A)") "Vector potential (VG) in atomic units:"
            to_write = dft_control%rtp_control%vec_pot
         ELSE
            WRITE (unit_nr, "(T3,A)") "No electric field applied"
            to_write = 0._dp
         END IF

         WRITE (unit_nr, "(T5,3(A,A,E16.8,1X))") &
            (TRIM(rlab(i)), "=", to_write(i), i=1, 3)
      END IF

      ! Always called, even when nothing was written on this rank.
      CALL cp_print_key_finished_output(unit_nr, logger, dft_section, &
                                        "REAL_TIME_PROPAGATION%PRINT%FIELD")

      CALL timestop(handle)

   END SUBROUTINE print_field_applied

END MODULE rt_propagation_output
