#' Estimate the mKBO decomposition
#'
#' This is the main function, computing the multi-group
#' Kitagawa-Blinder-Oaxaca (mKBO) decomposition.
#'
#' @param formula A regression formula specifying the outcome and explanatory
#'   variables, given as a string (e.g. \code{"y ~ x1 + x2"}) or as a
#'   \code{formula} object.
#' @param group A string naming the grouping variable in \code{data}. It is
#'   converted to a factor internally (dropping unused levels), and the
#'   decomposition is performed for each of its levels.
#' @param w A string naming the variable in \code{data} that contains
#'   observation weights. If \code{NULL} (default), equal weights are used.
#' @param data A \code{data.frame} or \code{tibble} containing the microdata.
#'   The data must not contain missing values in any of the variables used in
#'   the decomposition.
#' @param group_fixed Logical. If \code{TRUE} (default), group fixed effects
#'   are included in the pooled model used to estimate the sample-level
#'   coefficients.
#' @param viewpoint Character. Either \code{"group"} (default) or
#'   \code{"sample"}. Specifies the decomposition perspective:
#'   \itemize{
#'     \item \code{"group"}: How would group outcomes change if they had the
#'       endowments/coefficient structure of the full sample?
#'     \item \code{"sample"}: How do group characteristics differ from the
#'       sample, and how much does this explain outcome differences?
#'   }
#' @return An object of class \code{"mkbo"}, which is a list containing:
#' \describe{
#'   \item{\code{RECI}}{A tibble with one row per group giving the mean
#'     outcome (M; the left-hand side of \code{formula}), the difference
#'     between the sample mean and the group mean (D; group mean minus sample
#'     mean when \code{viewpoint = "sample"}), the sum of the components
#'     (R = E + C + I), and the contributions from endowments (E),
#'     coefficients (C), and interactions (I).}
#'   \item{\code{E_var}}{A data frame detailing variable-level contributions
#'     to the endowments (E) component: one row per parameter of the pooled
#'     model (column \code{parameter}) and one column per group.}
#'   \item{\code{C_var}}{A data frame detailing variable-level contributions
#'     to the coefficients (C) component, laid out like \code{E_var}.}
#'   \item{\code{I_var}}{A data frame detailing variable-level contributions
#'     to the interaction (I) component, laid out like \code{E_var}.}
#' }
#'
#' @details
#' The function performs group-wise regressions and compares them to a pooled
#' regression model. It decomposes the differences in group means of the
#' dependent variable into parts due to differences in observed
#' characteristics (endowments), differences in how those characteristics
#' translate into outcomes (coefficients), and the interaction of both.
#'
#' The choice of \code{viewpoint} changes whether the decomposition is
#' anchored on the sample or group averages, and this influences the
#' interpretation of each component.
#'
#' Group-specific coefficients are matched to the pooled model's parameters by
#' name. Terms the group-specific model does not estimate (the pooled model's
#' group dummies, or e.g. a factor level absent from that group) get a
#' coefficient of zero. With these conventions, and when the models include an
#' intercept and all coefficients are estimable, R reproduces D up to
#' numerical error.
#'
#' @importFrom stats as.formula coef contr.treatment filter lm model.frame model.matrix model.response weighted.mean
#' @importFrom dplyr %>% summarise summarise_all across group_by select mutate arrange pull
#' @importFrom tidyselect all_of
#' @importFrom tidyr nest unnest pivot_wider
#' @importFrom purrr map
#' @importFrom tibble tibble rownames_to_column
#' @importFrom rlang .data
#' @importFrom broom tidy
#' @importFrom utils globalVariables
#' @examples
#' mkbo_output <- mkbo("PERNP ~ BACHELOR", group = "RACE", data=pums_subset)
#' @export
mkbo <- function(formula, group, w = NULL, data, group_fixed = TRUE,
                 viewpoint = "group") {
  # Validate and normalise arguments ----
  viewpoint <- match.arg(viewpoint, c("group", "sample"))
  if (!is.logical(group_fixed) || length(group_fixed) != 1L ||
      is.na(group_fixed)) {
    stop("`group_fixed` must be TRUE or FALSE.", call. = FALSE)
  }
  # Accept formula objects as well as strings; everything below uses strings.
  if (inherits(formula, "formula")) {
    formula <- paste(deparse(formula), collapse = " ")
  }
  # A plain data.frame makes `[[` and row subsetting behave identically for
  # data.frame and tibble input.
  data <- as.data.frame(data)
  missing_vars <- setdiff(c(group, w), names(data))
  if (length(missing_vars) > 0) {
    stop("Variable(s) not found in `data`: ",
         paste(missing_vars, collapse = ", "), call. = FALSE)
  }

  # "group" (default, following Jann): sample minus group, i.e. how would the
  # group's outcome change if it had the sample's E, C and I?
  # "sample": group minus sample, i.e. how much do groups differ from the
  # sample, and how much of that is attributable to E, C and I?
  reverse_mod <- if (viewpoint == "group") 1 else -1

  # Re-factoring drops unused levels, so every level has observations.
  data[[group]] <- factor(data[[group]])
  data$mkbo_weight <- if (is.null(w)) rep(1, nrow(data)) else data[[w]]

  groups <- data[[group]]
  group_names <- levels(groups)
  wts <- data$mkbo_weight

  # Pooled model ----
  # The model matrix expands all factor variables (including group dummies).
  pooled_formula <- if (group_fixed) {
    paste(formula, group, sep = " + ")
  } else {
    formula
  }
  model_pooled <- lm(pooled_formula, data = data, weights = mkbo_weight)
  x_pooled <- model.matrix(model_pooled)
  # lm() drops incomplete rows; rows must stay aligned with `groups`/`wts`.
  if (nrow(x_pooled) != nrow(data)) {
    stop("`data` must not contain missing values in the variables used ",
         "by the decomposition.", call. = FALSE)
  }
  params <- colnames(x_pooled)
  coefficients_pooled <- coef(model_pooled)[params]
  # Use the model response so M and D are on the same scale as the
  # regression (e.g. log(y) rather than y for a log-transformed outcome).
  y <- model.response(model.frame(model_pooled))

  # Helpers ----
  # Zero-valued, parameter-named template for per-group vectors.
  template <- numeric(length(params))
  names(template) <- params
  # Apply `fun` to each group name and always return a params x groups
  # matrix (vapply() alone collapses to a vector for a single parameter).
  per_group <- function(fun) {
    matrix(vapply(group_names, fun, template),
           nrow = length(params), dimnames = list(params, group_names))
  }
  # Weighted column means of the pooled model matrix over the selected rows.
  weighted_col_means <- function(rows) {
    colSums(x_pooled[rows, , drop = FALSE] * wts[rows]) / sum(wts[rows])
  }

  # Means of regressors and outcome ----
  means_pooled <- weighted_col_means(rep(TRUE, nrow(x_pooled)))
  means_by_group <- per_group(function(g) weighted_col_means(groups == g))

  dep_mean_pooled <- weighted.mean(y, w = wts)
  dep_mean_group <- vapply(group_names, function(g) {
    rows <- groups == g
    weighted.mean(y[rows], w = wts[rows])
  }, numeric(1))

  # Group-specific coefficients ----
  # Matched to the pooled parameters by name; terms a group's model does not
  # contain (group dummies, absent factor levels) get a zero coefficient.
  # Aliased (NA) coefficients are kept as NA, as reported by lm().
  coefficients_by_group <- per_group(function(g) {
    fit <- lm(formula, data = data[groups == g, , drop = FALSE],
              weights = mkbo_weight)
    estimates <- coef(fit)
    out <- template
    shared <- intersect(names(estimates), params)
    out[shared] <- estimates[shared]
    out
  })

  # The Kitagawa-Oaxaca-Blinder calculations ----
  # Parameter-length vectors recycle down each group column.
  means_diff <- means_pooled - means_by_group
  coef_diff <- coefficients_pooled - coefficients_by_group

  E_var <- means_diff * coefficients_by_group * reverse_mod
  C_var <- means_by_group * coef_diff * reverse_mod
  I_var <- means_diff * coef_diff * reverse_mod

  E <- colSums(E_var)
  C <- colSums(C_var)
  I <- colSums(I_var)
  R <- E + C + I
  D <- (dep_mean_pooled - dep_mean_group) * reverse_mod

  # Tidy up the output ----
  RECI <- tibble(
    group = group_names,
    M = unname(dep_mean_group),
    D = unname(D),
    R = unname(R),
    E = unname(E),
    C = unname(C),
    I = unname(I)
  )

  # params x groups matrix -> data frame with a leading `parameter` column.
  as_var_table <- function(m) {
    rownames_to_column(as.data.frame(m), "parameter")
  }

  output <- list(
    RECI = RECI,
    E_var = as_var_table(E_var),
    C_var = as_var_table(C_var),
    I_var = as_var_table(I_var)
  )
  class(output) <- "mkbo"
  output
}
