/*-------------------------------------------------------------------------------
 This file is part of unityForest.

 Copyright (c) [2014-2018] [Marvin N. Wright]
 Modifications and extensions by Roman Hornung

 This software may be modified and distributed under the terms of the MIT license.

 Please note that the C++ core of divfor is distributed under MIT license and the
 R package "unityForest" under GPL3 license.
 #-------------------------------------------------------------------------------*/

#include <Rcpp.h>

#include <thread>
#include <chrono>
#include <unordered_map>
#include <random>
#include <algorithm>
#include <iostream>
#include <iterator>
#include <array>

#include "TreeClassification.h"
#include "utility.h"
#include "Data.h"

// #include <chrono>

namespace unityForest
{

  // Constructor for forest construction.
  // Stores the pointers to the class values, the per-sample class IDs, the
  // per-class sample IDs and the class weights. The split-search counters start
  // out empty.
  TreeClassification::TreeClassification(std::vector<double> *class_values, std::vector<uint> *response_classIDs,
                                         std::vector<std::vector<size_t>> *sampleIDs_per_class, std::vector<double> *class_weights)
      : class_values(class_values),
        response_classIDs(response_classIDs),
        sampleIDs_per_class(sampleIDs_per_class),
        class_weights(class_weights),
        counter(),
        counter_per_class()
  {
  }

  // Constructor for prediction mode.
  // Hands the stored tree structure (child node IDs, split variables, split
  // values) to the Tree base class. The per-class sample IDs and the class
  // weights are left as null pointers in this mode.
  TreeClassification::TreeClassification(std::vector<std::vector<size_t>> &child_nodeIDs,
                                         std::vector<size_t> &split_varIDs, std::vector<double> &split_values,
                                         std::vector<double> *class_values,
                                         std::vector<uint> *response_classIDs)
      : Tree(child_nodeIDs, split_varIDs, split_values),
        class_values(class_values),
        response_classIDs(response_classIDs),
        sampleIDs_per_class(nullptr),
        class_weights(nullptr),
        counter(),
        counter_per_class()
  {
  }

  // Constructor for repr_tree_mode.
  // Hands the stored tree structure and the data pointer to the Tree base class,
  // keeps the class values / class weights / class IDs pointers and copies the
  // given nodeID_in_root, inbag_counts and repr_vars vectors into this object.
  // The per-class sample IDs pointer is left null.
  TreeClassification::TreeClassification(std::vector<std::vector<size_t>> &child_nodeIDs, std::vector<size_t> &split_varIDs,
                                         std::vector<double> &split_values,
                                         std::vector<double> *class_values, std::vector<double> *class_weights, std::vector<uint> *response_classIDs,
                                         std::vector<size_t> &nodeID_in_root, std::vector<size_t> &inbag_counts, std::vector<size_t> &repr_vars, const Data *data_ptr)
      : Tree(child_nodeIDs, split_varIDs, split_values, data_ptr),
        class_values(class_values),
        response_classIDs(response_classIDs),
        sampleIDs_per_class(nullptr),
        class_weights(class_weights),
        counter(),
        counter_per_class()
  {
    // Assigned in the body rather than the initializer list (the constructor
    // parameters shadow the member names, hence the explicit this->)
    this->nodeID_in_root = nodeID_in_root;
    this->repr_vars = repr_vars;
    this->inbag_counts = inbag_counts;
  }

  // Returns a copy of this tree, created with the copy constructor and handed
  // back through a pointer to the Tree base class.
  std::unique_ptr<Tree> TreeClassification::clone() const
  {
    std::unique_ptr<Tree> copy = std::make_unique<TreeClassification>(*this);
    return copy;
  }

  // Sizes the reusable split-search buffers `counter` and `counter_per_class`.
  // Nothing is allocated when memory_saving_splitting is enabled.
  void TreeClassification::allocateMemory()
  {
    if (memory_saving_splitting)
    {
      return;
    }

    // One slot per candidate split: the maximum number of unique values reported
    // by the data, or the number of random splits for extratrees if that is larger
    size_t num_slots = data->getMaxNumUniqueValues();
    if (splitrule == EXTRATREES && num_random_splits > num_slots)
    {
      num_slots = num_random_splits;
    }

    const size_t num_classes = class_values->size();
    counter.resize(num_slots);
    counter_per_class.resize(num_classes * num_slots);
  }

  double TreeClassification::estimate(size_t nodeID)
  {

    // Count classes over samples in node and return class with maximum count
    std::vector<double> class_count = std::vector<double>(class_values->size(), 0.0);

    for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos)
    {
      size_t sampleID = sampleIDs[pos];
      size_t value = (*response_classIDs)[sampleID];
      class_count[value] += (*class_weights)[value];
    }

    if (end_pos[nodeID] > start_pos[nodeID])
    {
      size_t result_classID = mostFrequentClass(class_count, random_number_generator);
      return ((*class_values)[result_classID]);
    }
    else
    {
      throw std::runtime_error("Error: Empty node.");
    }
  }

  // Evaluate a random candidate tree root.
  //
  // For every node in `terminal_nodes`, the samples in
  // [start_pos_loop[nodeID], end_pos_loop[nodeID]) are counted per class and the
  // node contributes  sum_k w_k * c_k^2 / n_node  (w_k = class weight, c_k = class
  // count, n_node = number of samples in the node). The function returns the sum
  // of these contributions divided by the total number of samples seen, or 0.0
  // if no samples were seen at all.
  //
  // Nodes without samples are skipped: dividing by a node size of zero would
  // produce NaN and turn the whole score into NaN.
  double TreeClassification::evaluateRandomTree(const std::vector<size_t> &terminal_nodes)
  {
    // Constants and cached references (resolve all indirections once)
    const size_t num_classes = class_values->size();
    const auto &class_w = *class_weights;
    const auto &class_id = *response_classIDs;

    // Per-class counters: fast path (<=64 classes) uses a stack array, the
    // fallback uses a heap-allocated vector
    std::array<size_t, 64> cnt64{};
    std::vector<size_t> cnt_big;
    size_t *cnt = nullptr;

    if (num_classes <= cnt64.size())
    {
      cnt = cnt64.data();
    }
    else
    {
      cnt_big.resize(num_classes);
      cnt = cnt_big.data();
    }

    double sum_impurity = 0.0;
    size_t tot_samples = 0;

    // --- main loop over terminal nodes ----------------------------------
    for (size_t nodeID : terminal_nodes)
    {
      const size_t sBeg = start_pos_loop[nodeID];
      const size_t sEnd = end_pos_loop[nodeID];

      // An empty node has nothing to contribute (and would divide by zero)
      if (sEnd <= sBeg)
      {
        continue;
      }

      // 1. Zero the used part of cnt
      std::fill_n(cnt, num_classes, size_t{0});

      // 2. Count classes in the node (tight inner loop, no branches)
      for (size_t p = sBeg; p < sEnd; ++p)
      {
        ++cnt[class_id[sampleIDs[p]]];
      }

      const size_t node_n = sEnd - sBeg;
      tot_samples += node_n;

      // 3. Compute sum_k w_k * c_k^2 (manually unrolled by 4; the summation
      //    order is the same as in a plain loop)
      double ss = 0.0;
      size_t k = 0;
      for (; k + 3 < num_classes; k += 4)
      {
        ss += class_w[k] * static_cast<double>(cnt[k] * cnt[k]);
        ss += class_w[k + 1] * static_cast<double>(cnt[k + 1] * cnt[k + 1]);
        ss += class_w[k + 2] * static_cast<double>(cnt[k + 2] * cnt[k + 2]);
        ss += class_w[k + 3] * static_cast<double>(cnt[k + 3] * cnt[k + 3]);
      }
      for (; k < num_classes; ++k) // handle the rest (0-3 iterations)
      {
        ss += class_w[k] * static_cast<double>(cnt[k] * cnt[k]);
      }

      sum_impurity += ss / static_cast<double>(node_n);
    }
    // ---------------------------------------------------------------------

    return tot_samples ? (sum_impurity / static_cast<double>(tot_samples))
                       : 0.0;
  }

  // Split in a tree sprout.
  bool TreeClassification::splitNodeInternal(size_t nodeID, std::vector<size_t> &possible_split_varIDs)
  {

    // Stop if maximum node size or depth reached
    size_t num_samples_node = end_pos[nodeID] - start_pos[nodeID];
    if (num_samples_node <= min_node_size || (nodeID >= last_left_nodeID && max_depth > 0 && depth >= max_depth))
    {
      split_values[nodeID] = estimate(nodeID);
      return true;
    }

    // Check if node is pure and set split_value to estimate and stop if pure
    bool pure = true;
    double pure_value = 0;
    for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos)
    {
      size_t sampleID = sampleIDs[pos];
      double value = data->get(sampleID, dependent_varID);
      if (pos != start_pos[nodeID] && value != pure_value)
      {
        pure = false;
        break;
      }
      pure_value = value;
    }
    if (pure)
    {
      split_values[nodeID] = pure_value;
      return true;
    }

    // Find best split, stop if no decrease of impurity
    bool stop;
    if (splitrule == EXTRATREES)
    {
      stop = findBestSplitExtraTrees(nodeID, possible_split_varIDs);
    }
    else
    {
      stop = findBestSplit(nodeID, possible_split_varIDs);
    }

    if (stop)
    {
      split_values[nodeID] = estimate(nodeID);
      return true;
    }

    return false;
  }

  // Check whether the current node in a random candidate tree root is final.
  bool TreeClassification::checkWhetherFinalRandom(size_t nodeID)
  {

    // Stop if maximum node size or depth reached
    size_t num_samples_node = end_pos_loop[nodeID] - start_pos_loop[nodeID];
    if (num_samples_node <= min_node_size_root || (nodeID >= last_left_nodeID_loop && max_depth_root > 0 && depth >= max_depth_root))
    {
      return true;
    }

    // Check if node is pure and set split_value to estimate and stop if pure
    bool pure = true;
    size_t pos = start_pos_loop[nodeID];
    size_t sampleID = sampleIDs[pos];
    double pure_value = data->get(sampleID, dependent_varID);
    ++pos;
    for (; pos < end_pos_loop[nodeID]; ++pos)
    {
      sampleID = sampleIDs[pos];
      double value = data->get(sampleID, dependent_varID);
      if (value != pure_value)
      {
        pure = false;
        break;
      }
    }

    if (pure)
    {
      return true;
    }

    return false;
  }

  // Create an empty node in a random candidate tree root.
  // Deliberately a no-op here: the body performs no work.
  // NOTE(review): presumably a hook declared by the Tree base class - confirm in the header.
  void TreeClassification::createEmptyNodeRandomTreeInternal()
  {
    // Empty on purpose
  }

  // Create an empty node in a tree sprout.
  // Deliberately a no-op here: the body performs no work.
  // NOTE(review): presumably a hook declared by the Tree base class - confirm in the header.
  void TreeClassification::createEmptyNodeFullTreeInternal()
  {
    // Empty on purpose
  }

  // Function used to clear some objects from the random candidate tree root.
  // Deliberately a no-op here: the body performs no work.
  // NOTE(review): presumably a hook declared by the Tree base class - confirm in the header.
  void TreeClassification::clearRandomTreeInternal()
  {
    // Empty on purpose
  }

  // Compute the Gini impurity reduction for a split in the unity VIM computation.
  // The parent node is the union of both children. The result is
  //   gini(parent) - n_left/n_parent * gini(left) - n_right/n_parent * gini(right),
  // where an empty child counts with impurity 0 and an empty parent with
  // impurity 1.
  double TreeClassification::computeSplitCriterion(std::vector<size_t> sampleIDs_left_child, std::vector<size_t> sampleIDs_right_child)
  {
    const size_t n_left = sampleIDs_left_child.size();
    const size_t n_right = sampleIDs_right_child.size();
    const size_t n_parent = n_left + n_right;

    // Parent node: left child samples followed by right child samples
    std::vector<size_t> parent_sampleIDs;
    parent_sampleIDs.reserve(n_parent);
    std::copy(sampleIDs_left_child.begin(), sampleIDs_left_child.end(), std::back_inserter(parent_sampleIDs));
    std::copy(sampleIDs_right_child.begin(), sampleIDs_right_child.end(), std::back_inserter(parent_sampleIDs));

    // Impurities of the parent and of both children
    const double gini_parent = (n_parent > 0) ? computeGiniImpurity(parent_sampleIDs) : 1.0;
    const double gini_left = (n_left > 0) ? computeGiniImpurity(sampleIDs_left_child) : 0.0;
    const double gini_right = (n_right > 0) ? computeGiniImpurity(sampleIDs_right_child) : 0.0;

    // Children are weighted by their share of the parent's samples
    const double share_left = static_cast<double>(n_left) / static_cast<double>(n_parent);
    const double share_right = static_cast<double>(n_right) / static_cast<double>(n_parent);

    return gini_parent - share_left * gini_left - share_right * gini_right;
  }

  // Compute the Gini impurity reduction for a split in the CRTR analysis.
  // The given OOB samples are sent left if their value of the node's split
  // variable is <= the node's split value, otherwise right. The result is
  //   gini(parent) - n_left/n * gini(left) - n_right/n * gini(right),
  // where an empty child counts with impurity 0.
  double TreeClassification::computeOOBSplitCriterionValue(size_t nodeID, std::vector<size_t> oob_sampleIDs_nodeID)
  {
    const size_t n_parent = oob_sampleIDs_nodeID.size();
    const double gini_parent = computeGiniImpurity(oob_sampleIDs_nodeID);

    // Distribute the OOB samples over the two children
    std::vector<size_t> left_ids;
    std::vector<size_t> right_ids;
    left_ids.reserve(n_parent);
    right_ids.reserve(n_parent);
    for (const size_t sampleID : oob_sampleIDs_nodeID)
    {
      const bool goes_left = data->get(sampleID, split_varIDs[nodeID]) <= split_values[nodeID];
      (goes_left ? left_ids : right_ids).push_back(sampleID);
    }

    // Child impurities (0 for an empty child)
    const double gini_left = left_ids.empty() ? 0.0 : computeGiniImpurity(left_ids);
    const double gini_right = right_ids.empty() ? 0.0 : computeGiniImpurity(right_ids);

    // Children are weighted by their share of the node's OOB samples
    const double share_left = static_cast<double>(left_ids.size()) / static_cast<double>(n_parent);
    const double share_right = static_cast<double>(right_ids.size()) / static_cast<double>(n_parent);

    return gini_parent - share_left * gini_left - share_right * gini_right;
  }

  // Compute the OOB split criterion value for the node after permuting the OOB observations (unity VIM).
  // Works like the unpermuted variant, except that the i-th OOB sample is routed
  // left or right using the split-variable value of sample permutations[i]
  // instead of its own value. `permutations` is indexed with the same positions
  // as `oob_sampleIDs_nodeID`.
  double TreeClassification::computeOOBSplitCriterionValuePermuted(size_t nodeID, std::vector<size_t> oob_sampleIDs_nodeID, std::vector<size_t> permutations)
  {
    const size_t n_parent = oob_sampleIDs_nodeID.size();
    const double gini_parent = computeGiniImpurity(oob_sampleIDs_nodeID);

    // Distribute the OOB samples over the two children using the permuted values
    std::vector<size_t> left_ids;
    std::vector<size_t> right_ids;
    left_ids.reserve(n_parent);
    right_ids.reserve(n_parent);
    for (size_t i = 0; i < n_parent; ++i)
    {
      const size_t donorID = permutations[i];
      const bool goes_left = data->get(donorID, split_varIDs[nodeID]) <= split_values[nodeID];
      (goes_left ? left_ids : right_ids).push_back(oob_sampleIDs_nodeID[i]);
    }

    // Child impurities (0 for an empty child)
    const double gini_left = left_ids.empty() ? 0.0 : computeGiniImpurity(left_ids);
    const double gini_right = right_ids.empty() ? 0.0 : computeGiniImpurity(right_ids);

    // Children are weighted by their share of the node's OOB samples
    const double share_left = static_cast<double>(left_ids.size()) / static_cast<double>(n_parent);
    const double share_right = static_cast<double>(right_ids.size()) / static_cast<double>(n_parent);

    return gini_parent - share_left * gini_left - share_right * gini_right;
  }

  // Compute the Gini impurity (needed for the unity VIM computation and the CRTR analysis). 
  double TreeClassification::computeGiniImpurity(std::vector<size_t> sampleIDs_node)
  {
    // Compute the number of samples in the current node:
    size_t num_samples_node = sampleIDs_node.size();

    // Compute the number of classes:
    size_t num_classes = class_values->size();

    // Compute the class counts:
    std::vector<size_t> class_counts(num_classes, 0);
    for (size_t i = 0; i < num_samples_node; ++i)
    {
      size_t sampleID = sampleIDs_node[i];
      uint sample_classID = (*response_classIDs)[sampleID];
      ++class_counts[sample_classID];
    }

    // Compute the Gini impurity:
    double gini = 1.0;
    for (size_t i = 0; i < num_classes; ++i)
    {
      double proportion = (double)class_counts[i] / (double)num_samples_node;
      gini -= (*class_weights)[i] * proportion * proportion;
    }

    return gini;
  }

  // Find the best split for a node in the tree sprout. 
  bool TreeClassification::findBestSplit(size_t nodeID, std::vector<size_t> &possible_split_varIDs)
  {

    size_t num_samples_node = end_pos[nodeID] - start_pos[nodeID];
    size_t num_classes = class_values->size();
    double best_decrease = -1;
    size_t best_varID = 0;
    double best_value = 0;

    std::vector<size_t> class_counts(num_classes);
    // Compute overall class counts
    for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos)
    {
      size_t sampleID = sampleIDs[pos];
      uint sample_classID = (*response_classIDs)[sampleID];
      ++class_counts[sample_classID];
    }

    // For all possible split variables
    for (auto &varID : possible_split_varIDs)
    {
      // Find best split value, if ordered consider all values as split values, else all 2-partitions
      if (data->isOrderedVariable(varID))
      {

        // Use memory saving method if option set
        if (memory_saving_splitting)
        {
          findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID,
                                   best_decrease);
        }
        else
        {
          // Use faster method for both cases
          double q = (double)num_samples_node / (double)data->getNumUniqueDataValues(varID);
          if (q < Q_THRESHOLD)
          {
            findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID,
                                     best_decrease);
          }
          else
          {
            findBestSplitValueLargeQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID,
                                     best_decrease);
          }
        }
      }
      else
      {
        findBestSplitValueUnordered(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID,
                                    best_decrease);
      }
    }

    // Stop if no good split found
    if (best_decrease < 0)
    {
      return true;
    }

    // Save best values
    split_varIDs[nodeID] = best_varID;
    split_values[nodeID] = best_value;

    return false;
  }

  // Small-Q split search for an ordered variable (entry point).
  // Collects the candidate values of `varID` in the node via
  // Data::getAllValues(), prepares zeroed count buffers and delegates to the
  // counting overload. Leaves the best_* outputs untouched if there are fewer
  // than two candidate values.
  void TreeClassification::findBestSplitValueSmallQ(size_t nodeID, size_t varID, size_t num_classes,
                                                    const std::vector<size_t> &class_counts, size_t num_samples_node, double &best_value, size_t &best_varID,
                                                    double &best_decrease)
  {
    std::vector<double> candidates;
    data->getAllValues(candidates, sampleIDs, varID, start_pos[nodeID], end_pos[nodeID]);

    // Nothing to split on if all values are equal
    if (candidates.size() < 2)
    {
      return;
    }

    // The last candidate cannot be a split point, hence one split less
    const size_t num_splits = candidates.size() - 1;

    if (!memory_saving_splitting)
    {
      // Reuse the preallocated member buffers; only the used part is reset
      std::fill_n(counter_per_class.begin(), num_splits * num_classes, 0);
      std::fill_n(counter.begin(), num_splits, 0);
      findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID,
                               best_decrease, candidates, counter_per_class, counter);
      return;
    }

    // Memory saving mode: temporary buffers of exactly the required size
    std::vector<size_t> class_counts_right(num_splits * num_classes);
    std::vector<size_t> n_right(num_splits);
    findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID,
                             best_decrease, candidates, class_counts_right, n_right);
  }

  void TreeClassification::findBestSplitValueSmallQ(size_t nodeID, size_t varID, size_t num_classes,
                                                    const std::vector<size_t> &class_counts, size_t num_samples_node, double &best_value, size_t &best_varID,
                                                    double &best_decrease, const std::vector<double> &possible_split_values, std::vector<size_t> &class_counts_right,
                                                    std::vector<size_t> &n_right)
  {
    const size_t num_splits = possible_split_values.size() - 1;

    // Count samples in right child per class and possbile split
    for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos)
    {
      size_t sampleID = sampleIDs[pos];
      double value = data->get(sampleID, varID);
      uint sample_classID = (*response_classIDs)[sampleID];

      // Count samples until split_value reached
      for (size_t i = 0; i < num_splits; ++i)
      {
        if (value > possible_split_values[i])
        {
          ++n_right[i];
          ++class_counts_right[i * num_classes + sample_classID];
        }
        else
        {
          break;
        }
      }
    }

    // Compute decrease of impurity for each possible split
    for (size_t i = 0; i < num_splits; ++i)
    {

      // Stop if one child empty
      size_t n_left = num_samples_node - n_right[i];
      if (n_left == 0 || n_right[i] == 0)
      {
        continue;
      }

      // Sum of squares
      double sum_left = 0;
      double sum_right = 0;
      for (size_t j = 0; j < num_classes; ++j)
      {
        size_t class_count_right = class_counts_right[i * num_classes + j];
        size_t class_count_left = class_counts[j] - class_count_right;

        sum_right += (*class_weights)[j] * class_count_right * class_count_right;
        sum_left += (*class_weights)[j] * class_count_left * class_count_left;
      }

      // Decrease of impurity
      double decrease = sum_left / (double)n_left + sum_right / (double)n_right[i];

      // If better than before, use this
      if (decrease > best_decrease)
      {
        best_value = (possible_split_values[i] + possible_split_values[i + 1]) / 2;
        best_varID = varID;
        best_decrease = decrease;

        // Use smaller value if average is numerically the same as the larger value
        if (best_value == possible_split_values[i + 1])
        {
          best_value = possible_split_values[i];
        }
      }
    }
  }

  // Large-Q split search for an ordered variable.
  // Builds a histogram of the node's samples over the unique-value indices of
  // `varID` (total in `counter`, per class in `counter_per_class`), then sweeps
  // the indices from left to right while accumulating the left child. A split's
  // criterion is sum_left / n_left + sum_right / n_right, with
  // sum_* = sum_k w_k * c_k^2; the best_* outputs are updated whenever a split
  // beats `best_decrease`.
  // Uses the member buffers `counter` / `counter_per_class`, which must already
  // be large enough for the variable's number of unique values.
  void TreeClassification::findBestSplitValueLargeQ(size_t nodeID, size_t varID, size_t num_classes,
                                                    const std::vector<size_t> &class_counts, size_t num_samples_node, double &best_value, size_t &best_varID,
                                                    double &best_decrease)
  {
    // Reset the part of the buffers used for this variable
    const size_t num_unique = data->getNumUniqueDataValues(varID);
    std::fill_n(counter.begin(), num_unique, 0);
    std::fill_n(counter_per_class.begin(), num_unique * num_classes, 0);

    // Histogram of the node's samples over the unique-value indices
    for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos)
    {
      const size_t sampleID = sampleIDs[pos];
      const size_t value_index = data->getIndex(sampleID, varID);
      const size_t classID = (*response_classIDs)[sampleID];

      ++counter[value_index];
      ++counter_per_class[value_index * num_classes + classID];
    }

    // Running totals of the left child while sweeping over the unique values
    std::vector<size_t> left_per_class(num_classes);
    size_t left_total = 0;

    for (size_t i = 0; i < num_unique - 1; ++i)
    {
      // Values that do not occur in this node cannot be split points
      if (counter[i] == 0)
      {
        continue;
      }

      left_total += counter[i];

      // Once everything is on the left there is nothing more to evaluate
      const size_t right_total = num_samples_node - left_total;
      if (right_total == 0)
      {
        break;
      }

      // Weighted sums of squared class counts in both children
      double sum_left = 0;
      double sum_right = 0;
      for (size_t c = 0; c < num_classes; ++c)
      {
        left_per_class[c] += counter_per_class[i * num_classes + c];
        const size_t right_count = class_counts[c] - left_per_class[c];

        sum_left += (*class_weights)[c] * left_per_class[c] * left_per_class[c];
        sum_right += (*class_weights)[c] * right_count * right_count;
      }

      const double decrease = sum_right / (double)right_total + sum_left / (double)left_total;

      // If better than before, use this
      if (decrease > best_decrease)
      {
        // Next unique value that actually occurs in this node
        size_t next = i + 1;
        while (next < num_unique && counter[next] == 0)
        {
          ++next;
        }

        // Split at the mid-point, but fall back to the smaller value if the
        // average is numerically the same as the larger value
        const double lower = data->getUniqueDataValue(varID, i);
        const double upper = data->getUniqueDataValue(varID, next);
        const double midpoint = (lower + upper) / 2;

        best_value = (midpoint == upper) ? lower : midpoint;
        best_varID = varID;
        best_decrease = decrease;
      }
    }
  }

  // Split search for an unordered (factor) variable.
  // Enumerates the 2-partitions of the factor levels that occur in the node.
  // A partition is encoded as a bit mask (`splitID`) in which bit
  // (floor(level) - 1) is set for every level sent to the right child. A
  // partition's criterion is sum_left / n_left + sum_right / n_right, with
  // sum_* = sum_k w_k * c_k^2; the best_* outputs are updated whenever a
  // partition beats `best_decrease` (best_value then holds the bit mask).
  //
  // All masks are built by shifting a size_t one. Shifting the int literal 1, as
  // was done before, is undefined once a level index reaches the width of int,
  // which broke factors with more than 31 levels. The level index must still be
  // smaller than the bit width of size_t.
  void TreeClassification::findBestSplitValueUnordered(size_t nodeID, size_t varID, size_t num_classes,
                                                       const std::vector<size_t> &class_counts, size_t num_samples_node, double &best_value, size_t &best_varID,
                                                       double &best_decrease)
  {

    // Create possible split values
    std::vector<double> factor_levels;
    data->getAllValues(factor_levels, sampleIDs, varID, start_pos[nodeID], end_pos[nodeID]);

    // Try next variable if all equal for this
    if (factor_levels.size() < 2)
    {
      return;
    }

    // Shift base with the full width of size_t
    const size_t one = 1;

    // Number of possible splits is 2^num_levels
    const size_t num_splits = one << factor_levels.size();

    // Compute decrease of impurity for each possible split
    // Split where all left (0) or all right (1) are excluded
    // The second half of numbers is just left/right switched the first half -> Exclude second half
    for (size_t local_splitID = 1; local_splitID < num_splits / 2; ++local_splitID)
    {

      // Compute overall splitID by shifting local factorIDs to global positions
      size_t splitID = 0;
      for (size_t j = 0; j < factor_levels.size(); ++j)
      {
        if ((local_splitID & (one << j)))
        {
          const double level = factor_levels[j];
          const size_t factorID = floor(level) - 1;
          splitID = splitID | (one << factorID);
        }
      }

      // Initialize
      std::vector<size_t> class_counts_right(num_classes);
      size_t n_right = 0;

      // Count classes in left and right child
      for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos)
      {
        const size_t sampleID = sampleIDs[pos];
        const uint sample_classID = (*response_classIDs)[sampleID];
        const double value = data->get(sampleID, varID);
        const size_t factorID = floor(value) - 1;

        // If in right child, count
        // In right child, if bitwise splitID at position factorID is 1
        if ((splitID & (one << factorID)))
        {
          ++n_right;
          ++class_counts_right[sample_classID];
        }
      }
      const size_t n_left = num_samples_node - n_right;

      // Sum of squares
      double sum_left = 0;
      double sum_right = 0;
      for (size_t j = 0; j < num_classes; ++j)
      {
        const size_t class_count_right = class_counts_right[j];
        const size_t class_count_left = class_counts[j] - class_count_right;

        sum_right += (*class_weights)[j] * class_count_right * class_count_right;
        sum_left += (*class_weights)[j] * class_count_left * class_count_left;
      }

      // Decrease of impurity
      const double decrease = sum_left / (double)n_left + sum_right / (double)n_right;

      // If better than before, use this
      if (decrease > best_decrease)
      {
        best_value = splitID;
        best_varID = varID;
        best_decrease = decrease;
      }
    }
  }

  bool TreeClassification::findBestSplitExtraTrees(size_t nodeID, std::vector<size_t> &possible_split_varIDs)
  {

    size_t num_samples_node = end_pos[nodeID] - start_pos[nodeID];
    size_t num_classes = class_values->size();
    double best_decrease = -1;
    size_t best_varID = 0;
    double best_value = 0;

    std::vector<size_t> class_counts(num_classes);
    // Compute overall class counts
    for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos)
    {
      size_t sampleID = sampleIDs[pos];
      uint sample_classID = (*response_classIDs)[sampleID];
      ++class_counts[sample_classID];
    }

    // For all possible split variables
    for (auto &varID : possible_split_varIDs)
    {
      // Find best split value, if ordered consider all values as split values, else all 2-partitions
      if (data->isOrderedVariable(varID))
      {
        findBestSplitValueExtraTrees(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID,
                                     best_decrease);
      }
      else
      {
        findBestSplitValueExtraTreesUnordered(nodeID, varID, num_classes, class_counts, num_samples_node, best_value,
                                              best_varID, best_decrease);
      }
    }

    // Stop if no good split found
    if (best_decrease < 0)
    {
      return true;
    }

    // Save best values
    split_varIDs[nodeID] = best_varID;
    split_values[nodeID] = best_value;

    return false;
  }

  // Extratrees split search for an ordered variable (entry point).
  // Draws `num_random_splits` thresholds uniformly from the value range of
  // `varID` in the node, prepares zeroed count buffers and delegates to the
  // counting overload. Leaves the best_* outputs untouched if all values of the
  // variable in the node are equal.
  //
  // The thresholds are sorted in ascending order before counting: the counting
  // overload stops at the first threshold a sample does not exceed, so unsorted
  // thresholds give incomplete right-child counts as soon as more than one
  // random split is drawn.
  void TreeClassification::findBestSplitValueExtraTrees(size_t nodeID, size_t varID, size_t num_classes,
                                                        const std::vector<size_t> &class_counts, size_t num_samples_node, double &best_value, size_t &best_varID,
                                                        double &best_decrease)
  {

    // Get min/max values of covariate in node
    double min;
    double max;
    data->getMinMaxValues(min, max, sampleIDs, varID, start_pos[nodeID], end_pos[nodeID]);

    // Try next variable if all equal for this
    if (min == max)
    {
      return;
    }

    // Create possible split values: Draw randomly between min and max
    std::vector<double> possible_split_values;
    std::uniform_real_distribution<double> udist(min, max);
    possible_split_values.reserve(num_random_splits);
    for (size_t i = 0; i < num_random_splits; ++i)
    {
      possible_split_values.push_back(udist(random_number_generator));
    }

    // Ascending order is required by the early exit of the counting overload
    std::sort(possible_split_values.begin(), possible_split_values.end());

    const size_t num_splits = possible_split_values.size();
    if (memory_saving_splitting)
    {
      // Temporary buffers of exactly the required size
      std::vector<size_t> class_counts_right(num_splits * num_classes), n_right(num_splits);
      findBestSplitValueExtraTrees(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID,
                                   best_decrease, possible_split_values, class_counts_right, n_right);
    }
    else
    {
      // Reuse the preallocated member buffers; only the used part is reset
      std::fill_n(counter_per_class.begin(), num_splits * num_classes, 0);
      std::fill_n(counter.begin(), num_splits, 0);
      findBestSplitValueExtraTrees(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID,
                                   best_decrease, possible_split_values, counter_per_class, counter);
    }
  }

  void TreeClassification::findBestSplitValueExtraTrees(size_t nodeID, size_t varID, size_t num_classes,
                                                        const std::vector<size_t> &class_counts, size_t num_samples_node, double &best_value, size_t &best_varID,
                                                        double &best_decrease, const std::vector<double> &possible_split_values, std::vector<size_t> &class_counts_right,
                                                        std::vector<size_t> &n_right)
  {
    const size_t num_splits = possible_split_values.size();

    // Count samples in right child per class and possbile split
    for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos)
    {
      size_t sampleID = sampleIDs[pos];
      double value = data->get(sampleID, varID);
      uint sample_classID = (*response_classIDs)[sampleID];

      // Count samples until split_value reached
      for (size_t i = 0; i < num_splits; ++i)
      {
        if (value > possible_split_values[i])
        {
          ++n_right[i];
          ++class_counts_right[i * num_classes + sample_classID];
        }
        else
        {
          break;
        }
      }
    }

    // Compute decrease of impurity for each possible split
    for (size_t i = 0; i < num_splits; ++i)
    {

      // Stop if one child empty
      size_t n_left = num_samples_node - n_right[i];
      if (n_left == 0 || n_right[i] == 0)
      {
        continue;
      }

      // Sum of squares
      double sum_left = 0;
      double sum_right = 0;
      for (size_t j = 0; j < num_classes; ++j)
      {
        size_t class_count_right = class_counts_right[i * num_classes + j];
        size_t class_count_left = class_counts[j] - class_count_right;

        sum_right += (*class_weights)[j] * class_count_right * class_count_right;
        sum_left += (*class_weights)[j] * class_count_left * class_count_left;
      }

      // Decrease of impurity
      double decrease = sum_left / (double)n_left + sum_right / (double)n_right[i];

      // If better than before, use this
      if (decrease > best_decrease)
      {
        best_value = possible_split_values[i];
        best_varID = varID;
        best_decrease = decrease;
      }
    }
  }

  // Extra-trees split search for an unordered (factor) covariate in one node.
  //
  // Draws num_random_splits random two-way partitions of the factor levels, scores each one and
  // overwrites best_value / best_varID / best_decrease whenever a partition scores strictly higher
  // than the current best_decrease. A partition is encoded as a bit mask in which a set bit k means
  // "level index k goes to the right child"; the winning mask is reported through best_value.
  //
  // Every mask is built with size_t operands. Shifting a plain int literal (as in `1 << j`) is an
  // int operation and overflows once a factor has 31 or more levels, corrupting the size_t mask.
  // NOTE(review): this still presumes the factor has fewer levels than size_t has bits - confirm
  // that the caller enforces such a limit.
  void TreeClassification::findBestSplitValueExtraTreesUnordered(size_t nodeID, size_t varID, size_t num_classes,
                                                                 const std::vector<size_t> &class_counts, size_t num_samples_node, double &best_value, size_t &best_varID,
                                                                 double &best_decrease)
  {
    // Typed constants so that each shift below is carried out in size_t rather than int
    const size_t one = 1;
    const size_t two = 2;

    size_t num_unique_values = data->getNumUniqueDataValues(varID);

    // Get all factor indices in node
    std::vector<bool> factor_in_node(num_unique_values, false);
    for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos)
    {
      size_t sampleID = sampleIDs[pos];
      size_t index = data->getIndex(sampleID, varID);
      factor_in_node[index] = true;
    }

    // Vector of indices in and out of node
    std::vector<size_t> indices_in_node;
    std::vector<size_t> indices_out_node;
    indices_in_node.reserve(num_unique_values);
    indices_out_node.reserve(num_unique_values);
    for (size_t i = 0; i < num_unique_values; ++i)
    {
      if (factor_in_node[i])
      {
        indices_in_node.push_back(i);
      }
      else
      {
        indices_out_node.push_back(i);
      }
    }

    // Per-class counts of the right child; allocated once and reset for every random split
    std::vector<size_t> class_counts_right;

    // Generate num_random_splits splits
    for (size_t i = 0; i < num_random_splits; ++i)
    {
      // Levels assigned to the right child, as a bit mask over level indices
      size_t splitID = 0;

      // Draw random subsets, sample all partitions with equal probability.
      // The draw for levels present in the node always precedes the draw for absent levels.
      if (indices_in_node.size() > 1)
      {
        size_t num_partitions = (two << (indices_in_node.size() - 1)) - 2; // 2^n-2 (don't allow full or empty)
        std::uniform_int_distribution<size_t> udist(1, num_partitions);
        size_t splitID_in_node = udist(random_number_generator);
        for (size_t j = 0; j < indices_in_node.size(); ++j)
        {
          if ((splitID_in_node & (one << j)) != 0)
          {
            splitID |= one << indices_in_node[j];
          }
        }
      }
      if (indices_out_node.size() > 1)
      {
        size_t num_partitions = (two << (indices_out_node.size() - 1)) - 1; // 2^n-1 (allow full or empty)
        std::uniform_int_distribution<size_t> udist(0, num_partitions);
        size_t splitID_out_node = udist(random_number_generator);
        for (size_t j = 0; j < indices_out_node.size(); ++j)
        {
          if ((splitID_out_node & (one << j)) != 0)
          {
            splitID |= one << indices_out_node[j];
          }
        }
      }

      // Initialize
      class_counts_right.assign(num_classes, 0);
      size_t n_right = 0;

      // Count classes in left and right child
      for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos)
      {
        size_t sampleID = sampleIDs[pos];
        uint sample_classID = (*response_classIDs)[sampleID];
        double value = data->get(sampleID, varID);
        size_t factorID = floor(value) - 1;

        // If in right child, count
        // In right child, if bitwise splitID at position factorID is 1
        if ((splitID & (one << factorID)) != 0)
        {
          ++n_right;
          ++class_counts_right[sample_classID];
        }
      }
      size_t n_left = num_samples_node - n_right;

      // Stop if one child empty (mirrors the ordered variant; avoids a 0/0 score)
      if (n_left == 0 || n_right == 0)
      {
        continue;
      }

      // Sum of squares
      double sum_left = 0;
      double sum_right = 0;
      for (size_t j = 0; j < num_classes; ++j)
      {
        size_t class_count_right = class_counts_right[j];
        size_t class_count_left = class_counts[j] - class_count_right;

        sum_right += (*class_weights)[j] * class_count_right * class_count_right;
        sum_left += (*class_weights)[j] * class_count_left * class_count_left;
      }

      // Decrease of impurity
      double decrease = sum_left / (double)n_left + sum_right / (double)n_right;

      // If better than before, use this
      if (decrease > best_decrease)
      {
        best_value = splitID;
        best_varID = varID;
        best_decrease = decrease;
      }
    }
  }

  // Class-wise bootstrap with replacement.
  //
  // For every entry i of sample_fraction, draws round(num_samples * sample_fraction[i]) sample IDs
  // with replacement from (*sampleIDs_per_class)[i] and appends them to sampleIDs. Every sample that
  // was never drawn is appended to oob_sampleIDs, and num_samples_oob is set to their count.
  // inbag_counts is released afterwards unless keep_inbag is set.
  void TreeClassification::bootstrapClassWise()
  {
    // Number of in-bag samples is the sum over classes of the rounded per-class draw counts
    // (same rounding as the draw loop below, so the reservation matches what is pushed)
    size_t num_samples_inbag = 0;
    double sum_sample_fraction = 0;
    for (auto &s : *sample_fraction)
    {
      num_samples_inbag += static_cast<size_t>(round(num_samples * s));
      sum_sample_fraction += s;
    }

    // Reserve space, reserve a little more for the OOB samples to be safe
    sampleIDs.reserve(num_samples_inbag);
    oob_sampleIDs.reserve(num_samples * (exp(-sum_sample_fraction) + 0.1));

    // Start with all samples OOB
    inbag_counts.resize(num_samples, 0);

    // Draw samples for each class
    for (size_t i = 0; i < sample_fraction->size(); ++i)
    {
      size_t num_samples_class = (*sampleIDs_per_class)[i].size();

      // A class without samples offers nothing to draw from. Without this guard,
      // num_samples_class - 1 wraps around and the draw below indexes an empty vector.
      if (num_samples_class == 0)
      {
        continue;
      }

      // Draw samples of class with replacement as inbag and mark as not OOB
      size_t num_samples_inbag_class = round(num_samples * (*sample_fraction)[i]);
      std::uniform_int_distribution<size_t> unif_dist(0, num_samples_class - 1);
      for (size_t s = 0; s < num_samples_inbag_class; ++s)
      {
        size_t draw = (*sampleIDs_per_class)[i][unif_dist(random_number_generator)];
        sampleIDs.push_back(draw);
        ++inbag_counts[draw];
      }
    }

    // Save OOB samples
    for (size_t s = 0; s < inbag_counts.size(); ++s)
    {
      if (inbag_counts[s] == 0)
      {
        oob_sampleIDs.push_back(s);
      }
    }
    num_samples_oob = oob_sampleIDs.size();

    // The counts were only needed to identify the OOB samples
    if (!keep_inbag)
    {
      inbag_counts.clear();
      inbag_counts.shrink_to_fit();
    }
  }

  void TreeClassification::bootstrapWithoutReplacementClassWise()
  {
    // Draw samples for each class
    for (size_t i = 0; i < sample_fraction->size(); ++i)
    {
      size_t num_samples_class = (*sampleIDs_per_class)[i].size();
      size_t num_samples_inbag_class = round(num_samples * (*sample_fraction)[i]);

      shuffleAndSplitAppend(sampleIDs, oob_sampleIDs, num_samples_class, num_samples_inbag_class,
                            (*sampleIDs_per_class)[i], random_number_generator);
    }

    if (keep_inbag)
    {
      // All observation are 0 or 1 times inbag
      inbag_counts.resize(num_samples, 1);
      for (size_t i = 0; i < oob_sampleIDs.size(); i++)
      {
        inbag_counts[oob_sampleIDs[i]] = 0;
      }
    }
  }

} // namespace unityForest
