import os
import sys
from gpaw import GPAW, PW, FermiDirac
from gpaw import mpi
from ase import Atoms
from ase.optimize import QuasiNewton, BFGS
from ase.constraints import FixAtoms
import numpy as np
from scipy.optimize import curve_fit
from ase.parallel import world
sys.path.append(os.getcwd())
from geomaker import *

# Run a DFT geometry optimization (PBE functional, finite-difference real-space
# grid, (k, k, 1) k-point mesh) and return the final energy.
def get_opt_energy(slab, k=3, comm=None):
    """Relax a geometry with GPAW and return its final potential energy.

    Settings: PBE functional in finite-difference ('fd') mode with a grid
    spacing of 0.22, the 10-valence-electron setup for Pt, a (k, k, 1) k-point
    mesh, and a QuasiNewton relaxation down to a maximum force of 0.07 eV/A.
    The GPAW text output is written to ``<outname>.out`` and the relaxation
    trajectory to ``<outname>.traj``. outname is "H2" for a two-atom system
    and "Pt111_<number of atoms tagged 0>" otherwise.

    Input:
    - slab: ase Atoms object to optimize; it is relaxed in place and gets the
      GPAW calculator attached
    - k: number of k points in x and y (a single k point is used along z)
    - comm: GPAW MPI communicator selecting the ranks that run this
      calculation (allows running different calculations on different
      cores); defaults to gpaw.mpi.world, i.e. all ranks
    Output:
    - energy [eV]
    """
    if comm is None:
        comm = mpi.world

    # label the output files
    if len(slab) == 2:
        outname = "H2"
    else:
        # labels the output with the number of adsorbed H; assumes the slab
        # builder tags adsorbates with 0 -- confirm in geomaker
        outname = "Pt111_{}".format(np.count_nonzero(slab.get_tags() == 0))
    # print from the first rank of *this* calculation so that jobs running on
    # ranks other than world rank 0 also report what they compute
    if comm.rank == 0:
        print(outname)

    calc = GPAW(communicator=comm,  # run only on the ranks in comm
                setups={'Pt': '10'},  # Pt setup with only 10 valence e-
                txt=outname + ".out",
                mode='fd',  # finite-difference mode
                h=0.22,  # grid spacing
                xc='PBE',  # exchange-correlation functional
                kpts=(k, k, 1),  # k-point mesh
                )
    # add the calculator to the atoms object
    slab.calc = calc

    # ASE optimizers write trajectory/log output only on the "master" rank,
    # which by default is world rank 0. Make the first rank of comm the master
    # so that calculations running on other ranks also write their .traj file.
    relax = QuasiNewton(slab, trajectory=outname + ".traj",
                        master=(comm.rank == 0))
    force_convergence_threshold = 0.07  # eV/A --- note that this is rather large for such a simple system
    relax.run(fmax=force_convergence_threshold)

    # final energy of the relaxed geometry
    return slab.get_potential_energy()

if __name__ == "__main__":
    # Surface unit cell (Nx x Ny sites) and maximum number of H adsorbates
    # (one per surface site). l is passed through to
    # build_slab_with_adsorbate -- presumably the number of Pt layers; confirm
    # in geomaker.
    Nx = 2
    Ny = 2
    l = 3
    Nmax = Nx * Ny

    # Reuse energies from a previous run; an entry of (almost) zero means
    # "not computed yet".
    if os.path.exists('slab_energies.npy'):
        energies = np.load('slab_energies.npy')
    else:
        energies = np.zeros(Nmax + 1)

    # --- slab calculations: each one runs on a single rank ---
    rank = mpi.world.rank
    # Communicator containing only this rank. new_communicator is a collective
    # call over mpi.world, so every rank creates its own one here, exactly once.
    comm = mpi.world.new_communicator(np.array([rank]))

    # Round-robin assignment: rank r computes slabs r, r + size, r + 2*size, ...
    # so all Nmax + 1 slabs are covered whatever the number of ranks.
    new_energies = np.zeros(Nmax + 1)
    for N in range(rank, Nmax + 1, mpi.world.size):
        if abs(energies[N]) > 1E-5:  # not equal to 0 ==> has been computed already
            continue
        slab = build_slab_with_adsorbate(N, Nx, Ny, l)
        new_energies[N] = get_opt_energy(slab, k=1, comm=comm)

    # Every missing entry was computed by exactly one rank and is zero on all
    # others, so an in-place element-wise sum gathers all results on every
    # rank. Unlike matched send/receive pairs this cannot deadlock when some
    # rank has nothing to compute.
    mpi.world.sum(new_energies)
    energies = energies + new_energies

    if world.rank == 0:
        np.save('slab_energies.npy', energies)

    # --- H2 reference calculation, run on all ranks ---
    if os.path.exists('h2_energy.npy'):
        h2_energy = float(np.load('h2_energy.npy'))
    else:
        h2 = build_h2()
        # NOTE(review): k=2 for an isolated molecule is unusual (a molecule in
        # a box normally only needs the Gamma point) -- confirm it is intended.
        h2_energy = get_opt_energy(h2, k=2, comm=mpi.world)
        if world.rank == 0:
            np.save('h2_energy.npy', h2_energy)

    if world.rank == 0:
        # print total energies
        print("Total energies")
        print("slabs: ", energies)
        print("H2:    ", h2_energy)

        # Binding energy of N adsorbed H relative to the clean slab and
        # gas-phase H2: E_b(N) = E(slab + N H) - E(slab) - N/2 * E(H2)
        n_h = np.arange(1, Nmax + 1)
        binding_energies = energies[1:Nmax + 1] - energies[0] - n_h / 2 * h2_energy
        print("Binding energies")
        print("All:  ", binding_energies)
        np.save('binding_energies.npy', binding_energies)



