import numpy as np
import matplotlib.pyplot as plt

def propagate(conc, D, dt, nt, dx):
    """Propagate concentration conc over time.

    Integrates the 1D diffusion equation (Fick's second law) with an explicit
    forward-in-time, centered-in-space finite-difference scheme:

        c_i(t+dt) = c_i(t) + D*dt/dx**2 * (c_{i+1}(t) - 2*c_i(t) + c_{i-1}(t))

    The first and the last grid point are kept at their initial values
    (constant-concentration boundaries). The input array is not modified.

    Note: the explicit scheme is only numerically stable for
    D*dt/dx**2 <= 0.5; choose dt and dx accordingly.

    Parameters:
    ----------
    conc: [numpy array] carrying the initial concentration
    D: [float] diffusion coefficient [arbitrary units, but must be compatible to dt and dx]
    dt: [float] time step [arbitrary units, but must be compatible with D]
    nt: [integer] number of timesteps to propagate
    dx: [float] grid size [arbitrary units, but must be compatible with D]

    Output:
    -------
    conc: [numpy array] carrying the final concentration"""
    # Work on a float copy: the caller's array stays untouched and integer
    # input is not truncated when fractional updates are written back.
    conc = np.array(conc, dtype=float)

    # Dimensionless diffusion number; constant over the whole propagation.
    alpha = D * dt / dx**2

    # Perform exactly nt explicit Euler steps.
    for _ in range(nt):
        # Discrete Laplacian on the interior points. The right-hand side is
        # evaluated completely from the old profile before it is added, so
        # the in-place update is safe. Boundary points are never written.
        laplacian = conc[2:] - 2.0 * conc[1:-1] + conc[:-2]
        conc[1:-1] += alpha * laplacian
    return conc

def get_pos_centered(ix, dx, nx):
    """Return position for index ix, assuming the grid is centered.

    The grid point with index int(nx/2) is mapped to position 0.

    Parameters:
    ----------
    ix: [int] index
    dx: [float] grid spacing
    nx: [int] total number of grid points

    Output:
    ------
    x: [float] position
    """
    # Distance of the central grid point from the first grid point
    center_offset = int(nx/2)*dx
    # Position of ix measured from the first grid point
    raw_position = ix*dx
    return raw_position - center_offset

def analyze_spike_propagation(ax, conc, dx, time, color):
    """Plot results of spike propagation and perform some analysis

    Draws the concentration profile over the centered x axis, marks the
    width of the profile with a dotted line (see plot_sigma) and annotates
    the plot with the resulting sigma value.

    Parameters:
    ----------
    ax: [plt.ax object] for figure to plot in 
    conc: [numpy array] concentration profile
    dx: [float] grid spacing
    time: value shown in the legend label of this profile
    color: [hex] definition of plot color
    """
    # Create x range
    nx = len(conc)
    xs = get_pos_centered(np.arange(nx), dx, nx)
    # Plot; pass the color explicitly so that the profile always matches its
    # sigma marker and label instead of relying on the default color cycle
    ax.plot(xs, conc, color=color, label=f"$t$={time}")
    ax.set_xlabel('$x$')
    ax.set_ylabel('concentration')
    ax.legend()
    # Analyze
    sigma = plot_sigma(ax, conc, dx, color)
    my_string = r"$\sigma = $" + f"{sigma:4.3f}"
    # Place the label at x = sigma, at the height of the dotted sigma line
    ax.text(sigma, max(conc)/np.exp(0.5), my_string, color=color)

def plot_sigma(ax, conc, dx, color):
    """Plot and return the standard deviation of conc, assuming a Gaussian profile.

    The width is taken from the outermost grid points whose concentration
    exceeds 1/e^(1/2) of the maximum. A dotted horizontal line connecting
    these two points is drawn into ax, and half of its length is returned.
    """
    n_points = len(conc)
    # relative height at which the profile is cut: 1/e^(1/2) of the maximum
    level = 1/np.exp(0.5)
    peak = max(conc)
    # all indices at which the normalized concentration lies above the level
    above = np.where(conc/peak > level)[0]
    first, last = above[0], above[-1]
    # translate indices into positions
    x_left = get_pos_centered(first, dx, n_points)
    x_right = get_pos_centered(last, dx, n_points)
    # dotted horizontal line connecting both positions at the cut height
    y_level = level*peak
    ax.plot([x_left, x_right], [y_level, y_level], ':', color=color)
    # half of the distance between both positions
    return (x_right - x_left)/2



if __name__ == "__main__":
    # Set Parameters (dimensionless)
    D = 0.1  # Diffusion coefficient
    L = 3  # Length of the domain
    dx = 0.01  # width of a spatial bin
    # Number of spatial points (+1 as we want a final grid point to the left
    # and right). round() instead of int(): the float division may end up
    # marginally below the exact integer, which int() would truncate.
    nx = round(L/dx) + 1
    dt = 1E-4

    # Set times at which to output results
    t_outs = [0] + [0.04, 0.16, 0.36]
    # convert this to time steps; round() because e.g. 0.36/1E-4 evaluates to
    # slightly less than 3600 in floating point and int() would give 3599
    nt_outs = [round(t/dt) for t in t_outs]

    # Set initial concentration: a single spike at the central grid point,
    # i.e. the index that get_pos_centered maps to x = 0. The height 1/dx
    # gives the spike a unit integral (discrete delta function).
    conc = np.zeros(nx)
    conc[nx // 2] = 1.0 / dx

    # Define some plot colors
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']

    # Run and analyze
    fig = plt.figure()
    ax = fig.gca()
    for i in range(1, len(nt_outs)):
        # Number of time steps from the previous to the next output time
        nt = nt_outs[i] - nt_outs[i-1]
        # Propagate
        conc = propagate(conc, D, dt, nt, dx)
        # Analyze
        analyze_spike_propagation(ax, conc, dx, t_outs[i], colors[i-1])
    plt.show()
