from gpaw import GPAW, PW, FermiDirac
from ase import Atoms
from ase.constraints import FixAtoms
from ase.units import Pascal, m
from gpaw.solvation.sjm import SJM, SJMPower12Potential
from gpaw.solvation import (
    EffectivePotentialCavity,
    LinearDielectric,
    GradientSurface,
    SurfaceInteraction)
import sys
import os

# Add this script's directory to sys.path; otherwise geoms_from_paper is not found in parallel runs on Snellius.
sys.path.append(os.path.dirname(__file__))
import geoms_from_paper




# Define a function that runs a single-point, constant-potential (SJM) DFT calculation with the PBE functional in finite-difference mode (k-point mesh set by kx, ky). Returns the energy.
def get_energy(slab, kx=4, ky=6, potential_vs_inner=0, name="output"):
    """Run a constant-potential SJM single-point calculation and return its energy.

    GPAW is run in finite-difference ('fd') mode with the PBE functional,
    an implicit water solvent (effective-potential cavity, linear
    dielectric, surface-tension term) and the solvated jellium method (SJM)
    targeting ``potential_vs_inner``. No geometry optimization is done:
    the energy is evaluated for ``slab`` exactly as given.

    Input:
    - slab: ASE Atoms object to compute. Its ``calc`` attribute is replaced
      by the SJM calculator built here.
    - kx, ky: number of k-points along the first and second cell vectors;
      one k-point is always used along the third cell vector.
    - potential_vs_inner: target potential [V] passed to SJM as
      ``target_potential``. In this script the caller passes the potential
      vs. SHE plus the SHE offset vs. the inner potential.
    - name: prefix for output files. The GPAW log goes to ``<name>.out`` and
      the SJM traces are written to ``sjm_<name>``.
    Output:
    - energy [eV] of the converged SJM calculation
    """
    # Set up jellium solvation (constant-potential) control.
    sj = {'target_potential': potential_vs_inner}
    # Implicit solvent parameters (to SolvationGPAW).
    epsinf = 78.36  # dielectric constant of water at 298 K
    # Effective surface tension: 18.4e-3 N/m (Pa*m) converted to ASE units.
    gamma = 18.4 * 1e-3 * Pascal * m
    cavity = EffectivePotentialCavity(
        effective_potential=SJMPower12Potential(H2O_layer=True),
        temperature=298.15,  # K
        surface_calculator=GradientSurface())
    dielectric = LinearDielectric(epsinf=epsinf)
    interactions = [SurfaceInteraction(surface_tension=gamma)]
    # Set up the calculation in GPAW.
    calc = SJM(
        # Fixed real-space grid. NOTE(review): presumably chosen for the cell
        # used by the caller (z = 20 Angstrom); revisit if the cell changes.
        gpts=(36, 20, 88),
        # Use the '10' setup variant for Pt (10 valence electrons).
        setups={'Pt': '10'},
        txt="{}.out".format(name),
        mode='fd',
        # In 'fd' mode the basis set is only used for the initial
        # wave-function guess.
        basis='dzp',
        xc='PBE',
        kpts=(kx, ky, 1),
        sj=sj,
        cavity=cavity,
        dielectric=dielectric,
        interactions=interactions,
        parallel=dict(band=1,             # band parallelization
                      augment_grids=True,  # use all cores for XC/Poisson
                      sl_auto=True,        # enable ScaLAPACK
                      use_elpa=True,       # enable Elpa
                      ),
    )

    slab.calc = calc
    # Evaluate once; a second get_potential_energy() call would only return
    # the cached result.
    energy = slab.get_potential_energy()
    calc.write_sjm_traces(path='sjm_{}'.format(name))
    return energy


# Electrode potentials to compute, in V vs. SHE (e.g. extend with -0.6).
potentials = [0]  # , -0.6]  V vs. SHE
# Shift of SHE vs. the inner potential (use 4.2 instead of 4.4 to avoid the
# influence of the water dipole).
SHEpotential = 4.2  # V vs. inner

energies = {}
for potential in potentials:
    for state in ['FS']:  # full set: ['IS', 'TS', 'FS']
        name = state + "_{}VSHE".format(potential)
        print("computing {}".format(name))
        # Work on a copy so the geometry stored in geoms_from_paper is not
        # modified in place by the shift and cell change below.
        slab = geoms_from_paper.Volmer_top_ads[name].copy()
        # Shift the slab down a bit and reduce the vacuum (cell z from 30 to
        # 20) to reduce cost.
        slab.positions[:, 2] -= 2.5
        slab.cell[2, 2] = 20
        energy = get_energy(slab, kx=1, ky=1,
                            potential_vs_inner=potential + SHEpotential,
                            name=name)
        energies[name] = energy

def _fmt_diff(table, a, b):
    """Return ``table[a] - table[b]`` formatted as ``>7.3f``.

    Returns a right-aligned 'n/a' instead when either state is missing from
    ``table``, so a run that computes only some states (e.g. only 'FS')
    still prints a summary rather than raising KeyError.
    """
    if a in table and b in table:
        return f"{table[a] - table[b]:>7.3f}"
    return f"{'n/a':>7}"


print("# -----------------------------------------")
print("# -----------------Summary-----------------")
print(energies)
header = f"{'potential':<10} {'TS-IS':>7} {'FS-IS':>7} {'FS-TS':>7}"
print(header)
print("-" * len(header))
for potential in potentials:
    IS = 'IS_{}VSHE'.format(potential)
    TS = 'TS_{}VSHE'.format(potential)
    FS = 'FS_{}VSHE'.format(potential)
    print(f"{potential:<10} {_fmt_diff(energies, TS, IS)} "
          f"{_fmt_diff(energies, FS, IS)} {_fmt_diff(energies, FS, TS)}")
print("# -------------End Summary-----------------")

