# Author: Xuye Luo
# Date: December 11, 2025

#' @title Plot Matrix with Entries Represented 
#'   by Balloons of Varying Sizes and Colors 
#'
#' @description Creates a "balloon plot" to visualize 
#'   numeric data in a matrix or contingency table.
#' 
#' @details Each entry in the matrix is represented by
#'   a shape whose size and color reflect the magnitude
#'   of the value in that entry. It offers an alternative
#'   to a heatmap for displaying count data.
#'
#' @param x a numeric matrix or table to be plotted.
#' @param title a character string for the main title of the plot.
#'   Defaults to \code{"Balloon plot"}.
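#' @param shape.color a character vector of colors for entries
#'   (e.g., \code{"tomato"}, \code{"blue"}). Colors are recycled across
#'   rows or columns when \code{color.by} is \code{"row"} or
#'   \code{"column"}; otherwise only the first color is used.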
#' @param s.min a numeric value specifying the minimum size
#'   of the shapes. Defaults to 1.
#' @param s.max a numeric value specifying the maximum size
#'   of the shapes. Defaults to 30. If \code{NULL}, it is set to
#'   120 divided by the larger dimension of \code{x}.
#' @param x.axis a character vector for custom x-axis labels.
#'   If \code{NULL}, column names of \code{x} are used. 
#'   Set to \code{""} to hide labels.
#' @param y.axis a character vector for custom y-axis labels.
#'   If \code{NULL}, row names of \code{x} are used. 
#'   Set to \code{""} to hide labels.
#' @param x.lab a character string for the x-axis title. 
#'   Defaults to \code{""}.
#' @param y.lab a character string for the y-axis title.
#'   Defaults to \code{""}.
#' @param bg.color a character string for the background
#'   color of the tiles. Defaults to \code{"white"}.
#' @param grid.color a character string specifying color of
#'   grid lines (\code{NA} to remove).
#' @param grid.width a numeric value to specify the width
#'   of grid lines.
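#' @param size.by a character string selecting the group within which
#'   values are min-max normalized to set the color intensity of each
#'   entry: \code{"column"} (Default), \code{"row"}, \code{"global"}, or
#'   \code{"none"}. Balloon sizes always follow the raw values on a
#'   common scale. With \code{"none"}, entries are drawn as filled tiles
#'   (a heatmap) shaded by globally normalized values.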
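#' @param color.by a character string to specify how base colors from
#'   \code{shape.color} are assigned: \code{"column"} (Default),
#'   \code{"row"}, \code{"global"} (first color only), or \code{"none"}
#'   (one solid color for all non-zero entries, without shading).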
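#' @param number.size a numeric value specifying the font size of the
#'   value labels in balloon mode (in heatmap mode the label size is
#'   derived from the matrix dimensions).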
#' @param shape.by a character string to specify how to 
#'   choose the shape of balloon: \code{"column"} (Default), 
#'   \code{"row"}, or \code{""} (none).
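#' @param shapes a vector of point shape codes, recycled across rows or
#'   columns according to \code{shape.by}. Use the fillable codes 21 to
#'   25; other codes ignore the fill color and are drawn with a
#'   transparent outline, so they are not visible.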
#'
#' @return A ggplot object.
#' 
#' @importFrom reshape2 melt
#' @importFrom scales rescale
#' @importFrom grDevices col2rgb rgb colorRamp
#' @importFrom stats ave
#' @import ggplot2
#' @export
#' 
#' @examples
#' library(ggplot2)
#' mat <- matrix(c(10, 20, 30, 50, 80, 60, 40, 30), nrow = 2)
#' rownames(mat) <- c("Row1", "Row2")
#' colnames(mat) <- c("C1", "C2", "C3", "C4")
#' 
#' # Color by Row (Row 1 = tomato, Row 2 = steelblue)
#' plot_matrix(mat, color.by = "row", shape.color = c("tomato", "steelblue"))
#' 
#' # Color by Column (Rainbow colors)
#' plot_matrix(mat, color.by = "column", shape.color = c("red", "green", "blue", "orange"))
#' @keywords internal
plot_matrix <- function(x,
                        title = "Balloon plot",
                        shape.color = c("tomato"),
                        s.min = 1,
                        s.max = 30,
                        x.axis = NULL, y.axis = NULL,
                        x.lab = "", y.lab = "",
                        bg.color = "white",
                        grid.color = "black",
                        grid.width = 0.1,
                        size.by = c("column", "row", "global", "none"),
                        color.by = c("column", "row", "global", "none"),
                        number.size = 6,
                        shape.by = c("column", "row", ""),
                        shapes = c(21, 22, 23, 24)) {
  # Columns referenced inside aes(); bound to NULL only to silence
  # R CMD check's "no visible binding for global variable" notes.
  Column <- Row <- Value <- FinalFill <- Shape <- NULL

  x <- as.matrix(x)
  if (length(x) == 0L) {
    stop("'x' must have at least one row and one column.", call. = FALSE)
  }
  if (is.null(rownames(x))) rownames(x) <- paste0("R", seq_len(nrow(x)))
  if (is.null(colnames(x))) colnames(x) <- paste0("C", seq_len(ncol(x)))

  # Axis labels: "" hides them, NULL falls back to the dimnames of x.
  hide_x <- identical(x.axis, "")
  hide_y <- identical(y.axis, "")
  final_x_labels <- if (hide_x) NULL else if (is.null(x.axis)) colnames(x) else x.axis
  final_y_labels <- if (hide_y) NULL else if (is.null(y.axis)) rownames(x) else y.axis

  size_scope <- match.arg(size.by)
  color_scope <- match.arg(color.by)
  # pmatch() never matches the empty string, so match.arg() would reject
  # the documented "" choice. Handle it (and "none", for symmetry with
  # size.by / color.by) before delegating to match.arg().
  shape_scope <- if (identical(shape.by, "") || identical(shape.by, "none")) {
    "none"
  } else {
    match.arg(shape.by)
  }

  # With a single row or column, per-row/per-column scaling degenerates,
  # so fall back to global scaling.
  if (nrow(x) == 1L || ncol(x) == 1L) {
    if (size_scope != "none") size_scope <- "global"
    if (color_scope %in% c("row", "column")) color_scope <- "global"
  }

  # Long format, one row per cell. as.is = TRUE keeps the dimnames as
  # strings; by default reshape2 applies type.convert(), which turns names
  # such as "01" into 1 so they no longer match the factor levels below.
  x_melt <- reshape2::melt(x, na.rm = FALSE, as.is = TRUE)
  colnames(x_melt) <- c("Row", "Column", "Value")
  x_melt$Row <- factor(x_melt$Row, levels = rownames(x))
  x_melt$Column <- factor(x_melt$Column, levels = colnames(x))

  # Min-max scale the non-zero, non-missing values of v into [0, 1]
  # (all set to 1 when they are equal). Zeros get the sentinel -0.1,
  # which is painted with the background color; NA stays NA.
  scale_func <- function(v) {
    out <- rep(NA_real_, length(v))
    is_zero <- !is.na(v) & v == 0
    is_value <- !is.na(v) & v != 0
    out[is_zero] <- -0.1
    vals <- v[is_value]
    if (length(vals) > 0L) {
      if (length(unique(vals)) > 1L) {
        min_v <- min(vals)
        out[is_value] <- (vals - min_v) / (max(vals) - min_v)
      } else {
        out[is_value] <- 1
      }
    }
    out
  }

  # The normalized value drives the color intensity of each entry.
  x_melt$NormVal <- switch(
    size_scope,
    row = stats::ave(x_melt$Value, x_melt$Row, FUN = scale_func),
    column = stats::ave(x_melt$Value, x_melt$Column, FUN = scale_func),
    scale_func(x_melt$Value)
  )

  # Shape codes are recycled across rows or columns.
  if (shape_scope == "row") {
    shape_map <- rep(shapes, length.out = nlevels(x_melt$Row))
    x_melt$Shape <- shape_map[as.integer(x_melt$Row)]
  } else if (shape_scope == "column") {
    shape_map <- rep(shapes, length.out = nlevels(x_melt$Column))
    x_melt$Shape <- shape_map[as.integer(x_melt$Column)]
  } else {
    x_melt$Shape <- shapes[1]
  }

  if (color_scope == "none") {
    # One solid color; zero and missing cells show the background.
    is_blank <- is.na(x_melt$Value) | x_melt$Value == 0
    x_melt$FinalFill <- ifelse(is_blank, bg.color, shape.color[1])
  } else {
    # Base color per row / column (recycled) or a single global color.
    if (color_scope == "row") {
      color_vec <- rep(shape.color, length.out = nrow(x))
      x_melt$BaseColor <- color_vec[as.integer(x_melt$Row)]
    } else if (color_scope == "column") {
      color_vec <- rep(shape.color, length.out = ncol(x))
      x_melt$BaseColor <- color_vec[as.integer(x_melt$Column)]
    } else {
      x_melt$BaseColor <- shape.color[1]
    }

    # Interpolate between a 10% tint of the base color (norm_val = 0) and
    # a slightly darkened base color (norm_val = 1). Zero (sentinel < 0)
    # and missing cells get the background color.
    get_hex_color <- function(norm_val, base_col, bg_col) {
      if (is.na(norm_val) || norm_val < 0) return(bg_col)
      rgb_base <- grDevices::col2rgb(base_col)
      rgb_white <- grDevices::col2rgb("white")
      rgb_light <- round(rgb_base * 0.1 + rgb_white * 0.9)
      rgb_dark <- pmax(rgb_base - 30, 0)
      ramp_func <- grDevices::colorRamp(c(
        grDevices::rgb(rgb_light[1], rgb_light[2], rgb_light[3], maxColorValue = 255),
        grDevices::rgb(rgb_dark[1], rgb_dark[2], rgb_dark[3], maxColorValue = 255)
      ))
      final_rgb <- ramp_func(norm_val)
      grDevices::rgb(final_rgb[1], final_rgb[2], final_rgb[3], maxColorValue = 255)
    }

    x_melt$FinalFill <- mapply(get_hex_color,
                               x_melt$NormVal,
                               x_melt$BaseColor,
                               MoreArgs = list(bg_col = bg.color))
  }

  p <- ggplot(x_melt, aes(x = Column, y = Row)) +
    coord_fixed() +
    labs(title = title, x = x.lab, y = y.lab) +
    scale_fill_identity() +
    scale_x_discrete(labels = final_x_labels, drop = FALSE) +
    # Rows are drawn top-to-bottom in matrix order, hence the reversal.
    scale_y_discrete(
      labels = if (is.null(final_y_labels)) NULL else rev(final_y_labels),
      limits = rev, drop = FALSE
    ) +
    theme(
      plot.title = element_text(hjust = 0.5, size = 18),
      axis.text.x = element_text(size = 10, angle = 20, hjust = 1, vjust = 1),
      legend.position = "none",
      panel.background = element_rect(fill = "white", color = NULL)
    )

  if (size_scope == "none") {
    # HEATMAP MODE: filled tiles; labels only for small matrices.
    p <- p +
      geom_tile(aes(fill = FinalFill), color = grid.color, linewidth = grid.width)
    if (max(dim(x)) <= 20) {
      p <- p + geom_text(aes(label = Value), size = sqrt(800 / max(dim(x))),
                         na.rm = TRUE)
    }
  } else {
    # BALLOON MODE: point size follows the raw value on a common scale.
    if (is.null(s.max)) {
      s.max <- 120 / max(nrow(x), ncol(x))
    }

    # Name the shape values by their own codes so scale_shape_manual()
    # maps each factor level to the matching code. Unnamed values are
    # assigned in sorted-level order, which swaps shapes whenever `shapes`
    # is not in ascending order.
    shape_values <- unique(x_melt$Shape)
    names(shape_values) <- as.character(shape_values)

    p <- p +
      geom_tile(fill = bg.color, color = grid.color, linewidth = grid.width) +
      geom_point(aes(size = Value, fill = FinalFill, shape = as.factor(Shape)),
                 color = "transparent", stroke = 0.5, na.rm = TRUE) +
      geom_text(aes(label = Value), size = number.size, na.rm = TRUE) +
      scale_size(range = c(s.min, s.max),
                 limits = c(0, max(x_melt$Value, na.rm = TRUE)),
                 guide = "none") +
      scale_shape_manual(values = shape_values, guide = "none")
  }

  if (hide_x) p <- p + theme(axis.ticks.x = element_blank())
  if (hide_y) p <- p + theme(axis.ticks.y = element_blank())

  p
}

