import h5py
import matplotlib.pyplot as plt
import numpy as np

def main():
    """Draw the simulated results and overlay the analytical RDS curves on them."""
    figure, axes = plot_simulation()
    plot_analytical(figure, axes)
    # NOTE(review): the figure is neither shown nor saved here; display
    # presumably relies on an interactive backend / IDE — confirm.
    plt.tight_layout()

def _plot_dataset(axis, dset, surface_only, ylabel):
    """Plot the selected columns of *dset* against its first column on *axis*.

    Columns whose ``column_labels`` entry contains ``'*'`` are plotted when
    *surface_only* is true; otherwise only the columns without ``'*'`` are.
    """
    data = dset[:]
    labels = dset.attrs['column_labels']
    plotted = False
    for i in range(1, data.shape[1]):
        label = labels[i]
        # Fixed-length HDF5 string attributes are returned as bytes by h5py;
        # decode them so the '*' membership test cannot raise TypeError.
        if isinstance(label, bytes):
            label = label.decode()
        if ('*' in label) == surface_only:
            axis.plot(data[:, 0], data[:, i], 'o', label=label, alpha=0.5)
            plotted = True
    axis.grid(linestyle='--')
    # Only draw a legend when something was plotted (avoids an empty legend
    # and matplotlib's "No artists with labels" warning).
    if plotted:
        axis.legend()
    axis.set_xlabel('Temperature [K]')
    axis.set_ylabel(ylabel)


def plot_simulation():
    """
    Plot the aggregated simulation results stored in ``results/results.h5``.

    A figure with four vertically stacked axes is created. For every dataset
    below that exists in the file, each data column is plotted against the
    first column (temperature):

    * ``ax[0]``: ``aggregated/kinetics/derivatives``, labels without ``'*'``
    * ``ax[1]``: ``aggregated/kinetics/concentrations``, labels with ``'*'``
      (the surface species, as the "Surface concentration" ylabel indicates)
    * ``ax[2]``: ``aggregated/sensitivity/orders``, labels without ``'*'``
    * ``ax[3]``: ``aggregated/sensitivity/eapp``, labels without ``'*'``

    Datasets missing from the file are skipped, leaving that axis empty.

    Returns
    -------
    fig, ax
        The matplotlib figure and the array of its four axes.
    """
    # (HDF5 path, plot only '*'-labelled columns?, y-axis label) per axis
    panels = (
        ('aggregated/kinetics/derivatives', False, 'TOF [-]'),
        ('aggregated/kinetics/concentrations', True, 'Surface concentration [-]'),
        ('aggregated/sensitivity/orders', False, 'Reaction order [-]'),
        ('aggregated/sensitivity/eapp', False, 'Apparent activation energy [J/mol]'),
    )
    with h5py.File('results/results.h5', 'r') as f:
        fig, ax = plt.subplots(4, 1, dpi=144, figsize=(5, 12))
        for axis, (path, surface_only, ylabel) in zip(ax, panels):
            if path in f:
                _plot_dataset(axis, f[path], surface_only, ylabel)

    return fig, ax
        
def plot_analytical(fig, ax):
    """
    Overlay the rate-determining-step (RDS) approximation on the simulation axes.

    For 100 temperatures between 350 K and 1200 K, closed-form expressions give
    the reaction rate, the surface coverages of A and B, the reaction orders in
    A and B, and the apparent activation energy. They are drawn as black
    reference lines on ``ax[0]`` .. ``ax[3]`` respectively. ``fig`` is accepted
    for symmetry with ``plot_simulation`` but is not used.
    """
    # constants and operating conditions
    R = 8.314462618            # gas constant [J/(mol K)]
    atm2Pa = 101325.0
    pA = 1.0 * atm2Pa          # partial pressure of A [Pa]
    pB = 0.5 * atm2Pa          # partial pressure of B [Pa]
    eaf = 120e3                # activation energy of the RDS [J/mol]
    edesA = 120e3              # desorption energy of A [J/mol]
    edesB = 80e3               # desorption energy of B [J/mol]

    temperatures = np.linspace(350, 1200, 100)
    rates, cov_A, cov_B, orders_A, orders_B, eact = [], [], [], [], [], []

    for T in temperatures:
        # adsorption equilibrium constants and forward rate constant of the RDS
        KA = calc_K(1e-20, 12, 1.0, 1, 1, edesA, T)
        KB = calc_K(1e-20, 28, 1.0, 1, 1, edesB, T)
        kf = 1e13 * np.exp(-eaf / (R * T))

        a_term = KA * pA
        # square-root dependence on pB (presumably dissociative adsorption of B)
        b_term = np.sqrt(KB * pB)
        site_balance = 1 + a_term + b_term

        theta_A = a_term / site_balance
        theta_B = b_term / site_balance

        # reaction orders d ln r / d ln p for A and B
        n_A = 1 - 2 * theta_A
        n_B = 1/2 - theta_B

        rates.append(kf * KA * pA * b_term / site_balance**2)
        cov_A.append(theta_A)
        cov_B.append(theta_B)
        orders_A.append(n_A)
        orders_B.append(n_B)
        # apparent activation energy expressed through the reaction orders
        eact.append(eaf + (-edesA) * n_A + (-edesB) * n_B)

    # (axis index, series, line format); drawn behind the simulation markers
    overlays = (
        (0, rates, '--'),
        (1, cov_A, '--'),
        (1, cov_B, '.-'),
        (2, orders_A, '--'),
        (2, orders_B, '.-'),
        (3, eact, '--'),
    )
    for idx, series, fmt in overlays:
        ax[idx].plot(temperatures, series, fmt, color='black', zorder=0)

def calc_K(A, m, theta, sigma, S, Edes, T):
    """
    Return the adsorption equilibrium constant K = k_ads / k_des.

    k_ads has the Hertz-Knudsen form A*S / sqrt(2*pi*m*k_B*T) (per unit
    pressure). k_des has the form used below: k_B*T^3*A/h^3 *
    (2*pi*m*k_B) / (sigma*theta) * exp(-Edes / (R*T)).

    Parameters
    ----------
    A : float
        Site area; presumably in m^2 so that K comes out in 1/Pa — confirm.
    m : float
        Molecular mass in atomic mass units (multiplied by ``amu`` below).
    theta : float
        Presumably the characteristic rotational temperature [K] — confirm.
    sigma : float
        Presumably the rotational symmetry number — confirm.
    S : float
        Sticking coefficient.
    Edes : float
        Desorption energy [J/mol] (divided by R*T with R in J/(mol K)).
    T : float or numpy.ndarray
        Temperature [K]; arrays are evaluated element-wise.

    Returns
    -------
    float or numpy.ndarray
        The equilibrium constant k_ads / k_des.
    """
    k_B = 1.380649e-23          # Boltzmann constant [J/K]
    amu = 1.66053906660e-27     # atomic mass unit [kg]
    h = 6.62607015e-34          # Planck constant [J s]
    R = 8.314462618             # gas constant [J/(mol K)]

    # Separate names for k_B and k_des: the original rebound the Boltzmann
    # constant to the desorption rate constant on the line that used it.
    k_ads = A * S / np.sqrt(2 * np.pi * m * amu * k_B * T)
    k_des = (k_B * T**3 * A / h**3 * (2 * np.pi * m * amu * k_B)
             / (sigma * theta) * np.exp(-Edes / (R * T)))

    return k_ads / k_des

# Build the comparison figure only when run as a script, not when imported.
if __name__ == '__main__':
    main()