#------------------------------------------------------------------
# fitExp: Exponential fit to histogram contents
#
# CC 31 May 2018
#
# Changed the calling sequence for truth, added labels to plot
# CC 18 Feb 2019
#-------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt

def _expFitCovariance(par, X, S):
    """
    Weighted Jacobian and parameter covariance of par[0]*exp(-par[1]*x).
    Inputs:
      par : the two model parameters
      X   : bin centers, dimension N
      S   : squared errors of the bins, dimension N
    Returns:
      weighted : (Jacobian transposed) / S, a 2xN matrix
      cov      : 2x2 covariance matrix, inverse of weighted * Jacobian
    """
    decay = np.exp(-par[1]*X)
    # derivatives of the prediction wrt par[0] and par[1], one row each
    jacT = np.array([decay, -par[0]*X*decay])          # 2xN matrix
    # Dividing each column by S is the same as multiplying by diag(1/S),
    # without ever building the NxN matrix.
    weighted = jacT / S                                # 2xN matrix
    cov = np.linalg.inv(np.matmul(weighted, jacT.T))   # 2x2 matrix
    return weighted, cov


def _plotExpChisqContour(par, cov, chisq, X, Y, S, truth):
    """
    Draws the delta-chisq contours in the (1/par[1], par[0]) plane.
    The scan covers +-4 sigma around the fitted values, where the error
    on tau = 1/par[1] is propagated as sigma(par[1]) / par[1]**2.
    Inputs:
      par    : fitted parameters
      cov    : 2x2 covariance matrix of par
      chisq  : chisquared at the fitted parameters
      X,Y,S  : bin centers, bin contents, squared errors of the fitted bins
      truth  : true parameters (marked in red if of length 2), or None
    """
    fig2, ax2 = plt.subplots(1, 1)

    # a is the constant and tau is the lifetime
    nscan = 100
    halfA = 4*np.sqrt(cov[0][0])
    halfTau = 1./par[1]**2 * 4*np.sqrt(cov[1][1])
    amin, amax = par[0] - halfA, par[0] + halfA
    taumin, taumax = 1./par[1] - halfTau, 1./par[1] + halfTau
    ax2.set_xlim(taumin, taumax)
    ax2.set_ylim(amin, amax)
    a = np.linspace(amin, amax, nscan)
    tau = np.linspace(taumin, taumax, nscan)

    # The exponential depends only on tau: compute it once for all tau
    # values (nscan x N) and reuse it for every value of a.
    decay = np.exp(-X[np.newaxis, :] / tau[:, np.newaxis])
    z = np.zeros(shape=(nscan, nscan))   # z[index of a][index of tau]
    for iA, this_a in enumerate(a):
        z[iA] = ((this_a*decay - Y)**2 / S).sum(axis=1) - chisq

    # Correspond to the 68% 95% 99% coverage for 2 parameters
    # Particle Data Group (PDG) Table 38.2
    # http://pdg.lbl.gov/2015/reviews/rpp2015-rev-statistics.pdf
    CS = ax2.contour(tau, a, z, [2.30, 5.99, 9.21])
    fmt = dict(zip(CS.levels, ['68%', '95%', '99%']))
    ax2.clabel(CS, inline=True, fmt=fmt, fontsize=10)

    # put a point at the best fit
    ax2.plot(1./par[1], par[0], 'ko')

    # label the axes and set the title and make a grid
    ax2.set_title("Fit to p[0]*exp(-p[1]*x)")
    ax2.set_xlabel("1/p[1]")
    ax2.set_ylabel("p[0]")
    ax2.grid(True, which='both')

    # put a point at the truth, if given
    if truth is not None and len(truth) == 2:
        ax2.plot(1./truth[1], truth[0], 'ro')

    # show the figure
    fig2.show()


def fitExp(contents, bins, guess, truth=None, firstBin=0, lastBin=-1,
           nIter=4, plotContour=True, verbosity=0):
    """
    Performs an exponential chisq fit to an array (histogram).
    Fits to par[0]*exp(-par[1]*x), with x taken at the bin centers,
    by repeated linearized least-squares steps.
    Takes sigma=sqrt(Nobserved) instead of sqrt(Npredicted).
    Bins with zero content are given an error of 1 (a warning is printed).
    Inputs:
      contents    : array of dimension N to be fit
      bins        : bin edges, dimension N+1
      guess       : 1st guess to parameters (not modified; integer
                    values are accepted and treated as floats)
      truth       : true value of par, marked on the contour plot
                    when it has length 2
      firstBin    : first bin to be fit
      lastBin     : end of the fitted range, used as a slice end:
                    bins firstBin ... lastBin-1 are fit
                    (-1 to fit through the last bin of the histogram)
      nIter       : number of iterations (0 just evaluates the guess)
      plotContour : output contour plot if True
      verbosity   : >0 for some debug output
    Returns:
      par      : fitted parameters (float array of dimension 2)
      chisq    : chisquared at the returned parameters
      ndof     : number of degrees of freedom
      cov      : covariance matrix, evaluated at the returned parameters
    """

    # only keep the bins that we are fitting
    window = slice(firstBin, None if lastBin == -1 else lastBin)

    # bin centers
    edges = np.asarray(bins, dtype=float)
    X = (0.5 * (edges[1:] + edges[:-1]))[window]

    # Bin values, square of errors (private float copies of the input)
    Y = np.array(contents, dtype=float)[window]
    S = Y.copy()

    # careful about bins with zero content
    if np.any(S == 0):
        print("Warning: found some bins with zero error.  Will set those errors to 1")
        S[S == 0] = 1.0

    # Parameters.  Force a float copy: with an integer guess the
    # updates below would otherwise be truncated to integers.
    par = np.array(guess, dtype=float)

    # Start the iteration
    for i in range(0, nIter):

        # chisquared
        Yfit = par[0]*np.exp(-par[1]*X)
        chisq = ((Yfit-Y)**2/S).sum()
        if verbosity>0: print ("before iteration ", i, "chisq =", chisq)

        # linearized least-squares step: dpar = cov * Atrans * W * (Y-Yfit)
        weighted, cov = _expFitCovariance(par, X, S)
        par = par + np.matmul(cov, np.matmul(weighted, Y - Yfit))

    # The fit is now done...calculate a few things
    Yfit = par[0]*np.exp(-par[1]*X)
    chisq = ((Yfit-Y)**2/S).sum()  # chisquared
    ndof = len(X) - 2              # number of degrees of freedom
    if verbosity>0: print ("At the end chisq =", chisq)

    # Covariance at the final parameters.  Computed here rather than
    # inside the loop so that it also exists when nIter is 0.
    _, cov = _expFitCovariance(par, X, S)

    # chisquare contour plot
    if plotContour:
        _plotExpChisqContour(par, cov, chisq, X, Y, S, truth)

    # we are done
    return par, chisq, ndof, cov
    

