/*-------------------------------------------------------------------------------
 This file is part of unityForest.

 Copyright (c) [2014-2018] [Marvin N. Wright]
 Modifications and extensions by Roman Hornung

 This software may be modified and distributed under the terms of the MIT license.

 Please note that the C++ core of divfor is distributed under MIT license and the
 R package "unityForest" under GPL3 license.
 #-------------------------------------------------------------------------------*/

#ifndef TREEREGRESSION_H_
#define TREEREGRESSION_H_

#include <memory>
#include <vector>

#include "globals.h"
#include "Tree.h"

namespace unityForest
{

    class TreeRegression : public Tree
    {
    public:
        TreeRegression() = default;

        // Create from loaded forest
        TreeRegression(std::vector<std::vector<size_t>> &child_nodeIDs, std::vector<size_t> &split_varIDs,
                       std::vector<double> &split_values);

        // Constructor for repr_tree_mode
        TreeRegression(std::vector<std::vector<size_t>> &child_nodeIDs, std::vector<size_t> &split_varIDs,
                       std::vector<double> &split_values,
                       std::vector<size_t> &nodeID_in_root,
                       std::vector<size_t> &inbag_counts, std::vector<size_t> &repr_vars, const Data *data_ptr);

        std::unique_ptr<Tree> clone() const override;

        // TreeRegression(const TreeRegression&) = delete;
        // TreeRegression& operator=(const TreeRegression&) = delete;

        virtual ~TreeRegression() override = default;

        void allocateMemory() override;

        double estimate(size_t nodeID);

        double computeSplitCriterion(std::vector<size_t> sampleIDs_left_child, std::vector<size_t> sampleIDs_right_child) override;

        double computeVariance(std::vector<size_t> sampleIDs_node);

        double computeOOBSplitCriterionValue(size_t nodeID, std::vector<size_t> oob_sampleIDs_nodeID) override;

        double computeOOBSplitCriterionValuePermuted(size_t nodeID, std::vector<size_t> oob_sampleIDs_nodeID, std::vector<size_t> permutations) override;

        double getPrediction(size_t sampleID) const
        {
            size_t terminal_nodeID = prediction_terminal_nodeIDs[sampleID];
            return (split_values[terminal_nodeID]);
        }

        size_t getPredictionTerminalNodeID(size_t sampleID) const
        {
            return prediction_terminal_nodeIDs[sampleID];
        }

    private:
        // Function to evaluate a random tree:
        double evaluateRandomTree(const std::vector<size_t> &terminal_nodes) override;
        bool splitNodeInternal(size_t nodeID, std::vector<size_t> &possible_split_varIDs) override;

        // Create an empty node in a random tree
        void createEmptyNodeRandomTreeInternal() override;

        // Create an empty node in a (full) tree
        void createEmptyNodeFullTreeInternal() override;

        // Function used to clear some objects from the random trees
        void clearRandomTreeInternal() override;

        // Check whether the current node in a random tree is final
        bool checkWhetherFinalRandom(size_t nodeID) override;

        // Called by splitNodeInternal(). Sets split_varIDs and split_values.
        bool findBestSplit(size_t nodeID, std::vector<size_t> &possible_split_varIDs);
        void findBestSplitValueSmallQ(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node,
                                      double &best_value, size_t &best_varID, double &best_decrease);
        void findBestSplitValueSmallQ(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node,
                                      double &best_value, size_t &best_varID, double &best_decrease, std::vector<double> possible_split_values,
                                      std::vector<double> &sums_right, std::vector<size_t> &n_right);
        void findBestSplitValueLargeQ(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node,
                                      double &best_value, size_t &best_varID, double &best_decrease);
        void findBestSplitValueUnordered(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node,
                                         double &best_value, size_t &best_varID, double &best_decrease);

        bool findBestSplitMaxstat(size_t nodeID, std::vector<size_t> &possible_split_varIDs);

        bool findBestSplitExtraTrees(size_t nodeID, std::vector<size_t> &possible_split_varIDs);
        void findBestSplitValueExtraTrees(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node,
                                          double &best_value, size_t &best_varID, double &best_decrease);
        void findBestSplitValueExtraTrees(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node,
                                          double &best_value, size_t &best_varID, double &best_decrease, std::vector<double> possible_split_values,
                                          std::vector<double> &sums_right, std::vector<size_t> &n_right);
        void findBestSplitValueExtraTreesUnordered(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node,
                                                   double &best_value, size_t &best_varID, double &best_decrease);

        double computePredictionMSE();

        void cleanUpInternal() override
        {
            counter.clear();
            counter.shrink_to_fit();
            sums.clear();
            sums.shrink_to_fit();
        }

        std::vector<size_t> counter;
        std::vector<double> sums;
    };

} // namespace unityForest

#endif /* TREEREGRESSION_H_ */
