# S4 class bundling the user-supplied generator functions used by abcel():
# data simulation, summary statistics, prior and proposal.
methods::setClass(
  "genProgPk",
  slots = c(
    data = "function",
    summ = "function",
    prior = "function",
    proposal = "function"
  )
)

# Environment that holds the current generator object; initABC() stores it
# as `generator$generate`.
generator <- new.env()

# ==============================================================================
# Sourcing Helper Scripts for Entropy Estimation
# ==============================================================================
#source("wtSolver.r")
#source("entropyWts.r")
#source("entEst.r")
# ==============================================================================


#' @title Initialize generator class
#' @description Wraps the user-supplied functions in a `genProgPk` object and
#'   stores it as `generator$generate`, from where `abcel()` reads it.
#' @param data.func Function simulating a data set; called as
#'   `data.func(theta = theta, ...)`.
#' @param summ.func Function mapping a data set to its summary statistics.
#' @param prior.func Function evaluating the prior of `theta`.
#' @param proposal.func Optional proposal function, called as
#'   `proposal.func(theta = theta, ...)`. `NULL` selects the default adaptive
#'   random walk (see below).
#' @param transform.func Optional function, called as
#'   `transform.func(theta = new_theta, ...)`, mapping a proposed `theta` to
#'   the value handed to `data.func`. `NULL` selects the identity.
#' @return The `genProgPk` object (the value of the assignment).
#' @keywords internal
#' @noRd
initABC <- function(data.func,
                    summ.func,
                    prior.func,
                    proposal.func = NULL,
                    transform.func = NULL){
  # Summaries of `obs` when it is given; otherwise simulate a data set at
  # `theta` (extra arguments such as `n` go to data.func) and summarise it.
  sm1 <- function(theta = NULL, obs = NULL, ...){
    if (is.null(obs)) {
      obs <- data.func(theta = theta, ...)
    }
    summ.func(obs)
  }

  # Resolve the defaults once instead of on every proposal call.
  if (is.null(proposal.func)) {
    # Default adaptive random walk around the last row of `theta`:
    # N(0, initial.cov) steps (identity by default) for the first
    # `initial_interval` iterations, then N(0, 1/2 * (cov(theta) + 1e-7 * I)).
    # The trailing `...` absorbs arguments meant for a custom transform.func,
    # which are forwarded here too and used to raise "unused argument".
    proposal.func <- function(theta, i, initial.cov = NULL,
                              initial_interval = 1000, ...){
      d <- dim(theta)[2]
      size <- dim(theta)[1]
      if (is.null(initial.cov)) {
        initial.cov <- diag(d)
      }
      if (i <= initial_interval) {
        step.cov <- initial.cov
      } else {
        sn <- 1/2
        eps <- 10^(-7)
        # small ridge keeps the empirical covariance positive definite
        step.cov <- sn * (stats::cov(theta) + eps * diag(d))
      }
      theta[size, 1:d] + MASS::mvrnorm(1, rep(0, d), step.cov)
    }
  }
  if (is.null(transform.func)) {
    transform.func <- function(theta, ...){
      identity(theta)
    }
  }

  # Returns c(proposed theta, transformed proposed theta); abcel() splits the
  # result into its first and second halves.
  proposal <- function(theta, ...){
    new_theta <- do.call(proposal.func, append(list(theta = theta), list(...)))
    c(new_theta, do.call(transform.func, append(list(theta = new_theta), list(...))))
  }

  # set the methods in generator class
  generator$generate <- methods::new("genProgPk", data = data.func, summ = sm1,
                                     prior = prior.func, proposal = proposal)
}

append_res <- function(res, new_row){
  # Grow the chain record by one iteration: each scalar trace (log-posterior,
  # differential entropy, acceptance rate) gets the new entry appended,
  # `theta` gets a new row, and `size` goes up by one.
  extend <- function(field) c(res[[field]], new_row[[field]])

  res$logpost <- extend("logpost")
  res$theta <- rbind(res$theta, new_row$theta)
  res$differential_entropy <- extend("differential_entropy")
  res$acceptance_rate <- extend("acceptance_rate")
  res$size <- res$size + 1

  res
}

pop_res <- function(res){
  # Remove the oldest entry (index 1) from every trace of the chain record
  # and decrement `size`; abcel() uses this to discard burn-in samples.
  if (res$size < 1) {
    stop("pop_res: cannot remove a sample from an empty result", call. = FALSE)
  }

  #decrease the size
  res$size <- res$size - 1

  res$logpost <- res$logpost[-1]
  # drop = FALSE keeps `theta` a matrix with its original column count (and
  # column names) even when only one or zero rows remain.
  res$theta <- res$theta[-1, , drop = FALSE]
  res$differential_entropy <- res$differential_entropy[-1]
  res$acceptance_rate <- res$acceptance_rate[-1]

  #return
  res
}

#' @title Empirical Likelihood-based Approximate Bayesian Computation
#' @name abcel
#' @description Perform empirical likelihood-based posterior approximation for ABC problems.
#' @param data.obs Observed data.
#' @param theta0 Initial parameter values.
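#' @param rep Number of MCMC iterations, counting the starting value `theta0`.
#' @param m Number of data sets (and hence summary-statistic vectors) simulated at each MCMC iteration.
#' @param n Size of each simulated data set; passed to `data.func` as its `n` argument.
#' @param burn_in Length of the burn-in period; the earliest samples (starting with `theta0`) are dropped from the returned chain.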
#' @param data.func This function is used for generating the data.
#' @param summ.func This function is used for generating the statistics on the data.
#' @param prior.func Prior function.
#' @param proposal.func This function is used to propose a new theta given the old thetas. This function needs to include the ellipsis argument.
#' @param transform.func The function is used to transform the theta that will be fed to the generate data function. This function also needs to include the ellipsis argument. The default argument is the identity function.
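#' @param fixed.summ.num Whether the number of summary statistics stays the same from one MCMC step to the next (defaults to TRUE). When TRUE, the entropy-estimation weights are computed once before sampling instead of at every step.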
#' @param print_interval Fixed interval of iterations to print the results, if it's set to NULL, it will not print.
#' @param plot_interval Fixed interval of iterations to plot the results, if it's set to NULL, there will be no plots.
#' @param which_plot vector of parameters to plot.
#' @param k order of the nearest neighbor to be used for k-NN based differential entropy estimation.
#' @param ... Args for transformation and proposal functions.
#' @return A list with the following components:
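#' @return size: Number of samples retained in the chain (rows of `theta`).
#' @return logpost: Log of the unnormalised posterior at each sample (mean log empirical-likelihood weight + differential entropy estimate + log-prior).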
#' @return theta: A matrix of sampled `theta` values. 
#' @return differential_entropy: A vector of computed differential entropies.
#' @return acceptance_rate: Stepwise acceptance rates.
#' @details By default, the sampler performs an adaptive random walk on the parameters. For the first `initial_interval` steps (default 1000) the proposal density is N(0, `initial.cov`) (default: the identity matrix). After that, the proposal covariance is half the sample covariance of the previously sampled parameters, plus a small ridge. Both `initial.cov` and `initial_interval` can be supplied through `...`.
#' @details When defining a new proposal or a new transformation, be sure to insert the ellipsis argument.
#' @references Chaudhuri, S., Ghosh, S., & Pham, K. C. (2020). On an Empirical Likelihood-Based Solution to the Approximate Bayesian Computation Problem. Statistical Analysis and Data Mining, Vol 17(5):e11711.
#' @examples
#'\dontrun{
#'  data.gk1<-function(theta,n=1000){
#'    p<-runif(n,min=0,max=1)
#'    h<-(1-exp(-theta[3]*qnorm(p)))/(1+exp(-theta[3]*qnorm(p)))
#'    x<-theta[1]+theta[2]*(1+0.8*h)*(1+qnorm(p)^2)^theta[4]*qnorm(p)
#'    return(x)
#'  }
#'  summary.gk1<-function(obs){
#'    summary=c(mean(obs),quantile(obs,p=c(.25,.5,.75)))
#'    return(summary)
#'  }
#'  prior.gk1<-function(theta){
#'    pr=c(dunif(theta[1],min=0,max=10,log=TRUE),
#'         dunif(theta[2],min=0,max=10,log=TRUE),
#'         dunif(theta[3],min=0,max=10,log=TRUE),
#'         dunif(theta[4],min=0,max=10,log=TRUE)
#'    )
#'    return(pr)
#'  }
#'  # run the algorithm
#'  # parameters
#'  rep<-100
#'  m<-40
#'  n<-1000
#'  ## target theta and the observed data
#'  theta.t<-c(3,1,2,.5)
#'  pr.t <- prior.gk1(theta.t)
#'  d=length(theta.t)
#'  data.obs<-data.gk1(theta.t,n=n)
#'  initial.cov=diag(d)
#'  diag(initial.cov)=c(1,1,1,1)*10^(-7)
#'  # initialize mean and variance for the initial theta
#'  marginal.mean.A<-3.003707
#'  marginal.mean.B<-1.012046
#'  marginal.mean.g<-2.017939
#'  marginal.mean.k<-0.4894453
#'  mean0<-matrix(0,4,1)
#'  mean0[1,]<-marginal.mean.A
#'  mean0[2,]<-marginal.mean.B
#'  mean0[3,]<-marginal.mean.g
#'  mean0[4,]<-marginal.mean.k
#'  marginal.var.A<-0.0002019947
#'  marginal.var.B<-0.0009072782
#'  marginal.var.g<-0.004954367
#'  marginal.var.k<-0.001020683
#'  initial.cov<-matrix(0,4,4)
#'  initial.cov[1,1]<-marginal.var.A
#'  initial.cov[2,2]<-marginal.var.B
#'  initial.cov[3,3]<-marginal.var.g
#'  initial.cov[4,4]<-marginal.var.k
#'  theta0<-MASS::mvrnorm(1,mean0,initial.cov/100)
#'  burn_in <- 50
#'  res <- abcel(data.obs=data.obs,
#'               theta0=theta0,
#'               rep=rep,
#'               m=m,
#'               n=n,
#'               burn_in=burn_in,
#'               data.func=data.gk1,
#'               summ.func=summary.gk1,
#'               prior.func=prior.gk1,
#'               print_interval=1000,
#'               plot_interval=0,
#'               which_plot=NULL)
#'}
#' @export
abcel <- function(data.obs,
                  theta0,
                  rep,
                  m,
                  n,
                  burn_in,
                  data.func,
                  summ.func,
                  prior.func,
                  proposal.func = NULL,
                  transform.func = NULL,
                  fixed.summ.num=TRUE,
                  print_interval=1000,
                  plot_interval=0,
                  which_plot=NULL,
                  k=NULL, ...){
  # Build `generator$generate` (data / summary / prior / proposal functions).
  initABC(data.func, summ.func, prior.func, proposal.func, transform.func)

  d <- length(theta0)
  start_time <- Sys.time()
  num_burned <- 0
  # EL weights from the most recent el.test() call; stays NULL until a step
  # reaches the empirical-likelihood computation (gates progress printing).
  W.matrix <- NULL

  # TRUE when `interval` is a usable period (neither NULL nor 0) and `iter`
  # is a multiple of it. Checking first avoids `iter %% 0` (NaN) and
  # `iter %% NULL` (length 0), both of which make `if ()` fail.
  at_interval <- function(iter, interval) {
    !is.null(interval) && interval != 0 && iter %% abs(interval) == 0
  }

  # Record a step that is rejected or cannot be evaluated: repeat the current
  # state and update the running acceptance rate with a non-acceptance.
  repeat_last <- function(res, iter) {
    append_res(res, list(logpost = res$logpost[res$size],
                         theta = res$theta[res$size, 1:d],
                         differential_entropy = res$differential_entropy[res$size],
                         acceptance_rate = ((iter - 1) * res$acceptance_rate[res$size]) / iter))
  }

  ## Plot layout: graphics parameters are only changed (and restored on
  ## exit) when plotting is enabled, so no device is opened otherwise.
  if (!is.null(plot_interval) && plot_interval != 0) {
    oldpar <- graphics::par(no.readonly = TRUE)
    on.exit(graphics::par(oldpar), add = TRUE)
    graphics::par(mfrow = c(2, if (d == 1) 2 else length(which_plot) + 1))
  }

  ## observation summaries
  s.obs <- generator$generate@summ(obs = data.obs)

  # Chain record; the first row is theta0 with placeholder log-posterior -9999.
  res <- list(size = 1, logpost = c(-9999), theta = matrix(theta0, nrow = 1),
              differential_entropy = c(0), acceptance_rate = c(0))

  # The entropy-estimator terms and weights depend only on the number of
  # summaries and on m, so compute them once when that number is fixed.
  if (isTRUE(fixed.summ.num)) {
    precomputed_entropy_params <- entropyWts(d = length(s.obs), m = m, k = k, option = 3)
  }

  ## MCMC; seq_len(rep)[-1] is empty for rep < 2 (2:rep would count down).
  for (i in seq_len(rep)[-1]) {
    ## burn in: once more than burn_in samples are stored, drop the oldest
    ## one (theta0 first), burn_in + 1 times in total. max(burn_in, 1) keeps
    ## burn_in = 0 from popping the only row and emptying the chain.
    if (res$size > max(burn_in, 1) && num_burned <= burn_in) {
      res <- pop_res(res)
      num_burned <- num_burned + 1
    }

    ## propose: the first d entries are theta, the next d its transformation
    thetas <- do.call(generator$generate@proposal,
                      append(list(i = i, theta = res$theta), list(...)))
    theta <- thetas[1:d]
    trans_theta <- thetas[(d + 1):(2 * d)]
    prs <- generator$generate@prior(theta)

    if (any(!is.finite(prs))) {
      # prior not defined here (infinite, NA or NaN): keep the current state
      res <- repeat_last(res, i)
    } else {
      # m simulated summary vectors, centred at the observed summaries
      summ <- rep(0, length(s.obs))
      for (j in seq_len(m)) {
        summ <- rbind(summ, generator$generate@summ(theta = trans_theta, n = n) - s.obs)
      }
      summ <- summ[-1, ]

      if (any(!is.finite(summ))) {
        # NA / NaN / +-Inf summaries cannot enter the EL computation
        res <- repeat_last(res, i)
      } else {
        lz <- if (is.matrix(summ)) rep(0, ncol(summ)) else 0
        W.matrix <- tryCatch(emplik::el.test(summ, mu = lz, maxit = 200)$wts,
                             error = function(e) NULL)

        # Reject when el.test failed or its weights do not sum to m
        # (presumably a non-converged / infeasible EL solution); isTRUE()
        # also rejects a NaN sum instead of erroring in `if ()`.
        if (is.null(W.matrix) || !isTRUE(abs(sum(W.matrix) - m) <= 10^(-5))) {
          res <- repeat_last(res, i)
        } else {
          if (!isTRUE(fixed.summ.num)) {
            precomputed_entropy_params <- entropyWts(d = length(s.obs), m = m, k = k, option = 3)
          }
          entropy <- entEst(t(summ),
                            terms = precomputed_entropy_params$terms,
                            wts   = precomputed_entropy_params$wts)

          logpos <- mean(log(W.matrix)) + entropy + sum(prs)
          lograt <- min(0, logpos - res$logpost[res$size])
          res <- tryCatch({
            # accept with probability exp(lograt)
            if (log(stats::runif(1)) < lograt) {
              append_res(res, list(logpost = logpos,
                                   theta = theta,
                                   differential_entropy = entropy,
                                   acceptance_rate = ((i - 1) * res$acceptance_rate[res$size] + 1) / i))
            } else {
              repeat_last(res, i)
            }
          }, error = function(e) {
            # e.g. a NaN log-posterior makes the comparison NA
            stop(conditionMessage(e), " (theta: ", paste(theta, collapse = " "), ")",
                 call. = FALSE)
          })
        }
      }
    }

    if (at_interval(i, print_interval) && !is.null(W.matrix)) {
      thetaPrint <- paste(sprintf("%.4f", round(res$theta[res$size, 1:d], 4)), collapse = " ")
      cat(sprintf("iter %d theta %s; Acceptance rate : %.4f; Prior : %.4f; Time Taken %.4f\n",
                  i, thetaPrint, res$acceptance_rate[res$size], sum(prs),
                  difftime(Sys.time(), start_time, units = "secs")))
      start_time <- Sys.time()
    }

    if (at_interval(i, plot_interval)) {
      # trace plots / histograms over the second half of the stored chain
      half <- floor(res$size / 2):res$size
      for (j in which_plot) {
        graphics::plot(as.vector(res$theta[half, j]), type = "l",
                       main = sprintf("theta %d", j), ylab = "", xlab = "index")
      }
      graphics::plot(as.numeric(res$logpost[half]), type = "l",
                     main = "LLR", ylab = "", xlab = "index")
      for (j in which_plot) {
        graphics::hist(as.vector(res$theta[half, j]),
                       main = sprintf("theta %d", j), ylab = "", xlab = "index")
      }
      graphics::plot(as.numeric(res$acceptance_rate[half]), type = "l",
                     main = "acc", ylab = "", xlab = "index")
    }
  }
  res
}
