import numpy as np
from gpaw import GPAW, PW, FermiDirac
from ase import Atoms
from ase.constraints import FixAtoms
from ase.units import Pascal, m
from gpaw.solvation.sjm import SJM, SJMPower12Potential
from gpaw.solvation import (
    EffectivePotentialCavity,
    LinearDielectric,
    GradientSurface,
    SurfaceInteraction)
from ase.parallel import world
import sys
import os
sys.path.append(os.path.dirname(__file__))
import geoms_from_paper


def read_energies_avail(filename):
    """Load previously computed energies from a ``.npy`` cache file.

    Parameters:
    ----------
    filename: [str] path of a ``.npy`` file written with ``np.save`` from a
        dict mapping calculation name -> energy [eV]

    Output:
    ------
    -dict of {calculation name: energy}; empty dict if the file does not exist
    """
    if not os.path.exists(filename):
        # Nothing cached yet: an empty dict keeps the return type consistent.
        return {}
    # allow_pickle is needed because the file holds a pickled dict (np.save of
    # a dict wraps it in a 0-d object array, unwrapped by .item()).
    # NOTE: only load cache files written by this script, never untrusted ones.
    return np.load(filename, allow_pickle=True).item()

def get_energy(slab, kx=4, ky=6, potential_vs_inner=0, name="output"):
    """Set up a constant-potential (SJM) DFT calculation and return its energy.

    Parameters:
    ----------
    slab: [ase geometry object] of the system; its ``calc`` is replaced
    kx: [int] number of k-points in x-dimension
    ky: [int] number of k-points in y-dimension
    potential_vs_inner: [float] target potential in [V] vs. the inner
        potential of water (passed to SJM as ``target_potential``)
    name: [str] name of the calculation; GPAW writes its log to
        ``<name>.out`` and the SJM traces go to ``sjm_<name>``

    Output:
    ------
    -energy [eV] as returned by ``slab.get_potential_energy()``
    """
    # Implicit solvent parameters (to SolvationGPAW).
    epsinf = 78.36  # dielectric constant of water at 298 K
    gamma = 18.4 * 1e-3 * Pascal * m  # surface tension 18.4e-3 N/m in ASE units
    cavity = EffectivePotentialCavity(
        effective_potential=SJMPower12Potential(H2O_layer=False),
        temperature=298.15,  # K
        surface_calculator=GradientSurface())
    dielectric = LinearDielectric(epsinf=epsinf)
    interactions = [SurfaceInteraction(surface_tension=gamma)]

    # Set up the calculation in GPAW using the SJM method.
    # Optionally fix the real-space grid, e.g. gpts=(36, 20, 88).
    calc = SJM(
        setups={'Pt': '10'},  # large-core setup for Pt
        txt="{}.out".format(name),
        mode='lcao',  # LCAO (atomic-orbital basis); use mode='fd' for finite differences
        basis='dzp',  # LCAO basis set (only the initial guess in other modes)
        xc='PBE',  # exchange-correlation functional
        kpts=(kx, ky, 1),  # k-point grid
        cavity=cavity,  # dielectric cavity
        dielectric=dielectric,  # dielectric constant of the solvent
        interactions=interactions,  # additional solvent-solute interactions
        sj={'target_potential': potential_vs_inner},  # solvated-jellium settings
        parallel=dict(band=1,  # band parallelization
                      augment_grids=True,  # use all cores for XC/Poisson
                      sl_auto=True,  # enable ScaLAPACK
                      use_elpa=True,  # enable ELPA
                      ),
    )

    # Compute the energy once and reuse it (no second, cache-dependent call).
    slab.calc = calc
    energy = slab.get_potential_energy()
    # Write the SJM traces to 'sjm_<name>'.
    calc.write_sjm_traces(path='sjm_{}'.format(name))
    return energy


# Electrode potentials to compute, in V vs. SHE (e.g. [0, -0.6]).
potentials = [0]
# Position of the SHE vs. the inner potential of water, in V. 4.2 is used
# instead of 4.4 to avoid the influence of the water dipole.
SHEpotential = 4.2  # V vs. inner

# Read in energies that have already been computed.
energies_avail = read_energies_avail('energies.npy')

# Compute energies
energies = {}
for potential in potentials:
    for state in ['IS']:  # extend with 'TS', 'FS' for the full reaction path
        name = state + "_{}VSHE".format(potential)
        # Don't recompute if already known.
        if name in energies_avail:
            continue
        # Read the geometry from the paper's data.
        slab = geoms_from_paper.Volmer_top_ads[name]
        # Shift the slab down a bit and cut the cell height to 20 (vacuum was
        # 30) to reduce cost.
        slab.positions[:, 2] -= 2.5
        slab.cell[2, 2] = 20
        # get_energy expects the potential vs. the inner potential of water,
        # so convert from the SHE scale before passing it on.
        energy = get_energy(slab, kx=1, ky=1,
                            potential_vs_inner=potential + SHEpotential,
                            name=name)
        energies[name] = energy

# Re-read the cache right before saving and keep entries written meanwhile by
# other runs. This reduces, but does not eliminate, the chance that parallel
# simulations overwrite each other's results.
energies_avail = read_energies_avail('energies.npy')
for name in energies_avail:
    energies.setdefault(name, energies_avail[name])

# Save energies (only rank 0 writes the file).
if world.rank == 0:
    np.save('energies.npy', energies)
