#ifndef RANDOM_FOREST_CLASSIFIER_H
#define RANDOM_FOREST_CLASSIFIER_H

#include "dlib/random_forest.h"
#include <vector>
#include <map>

using namespace dlib;

// Sample type used throughout the classifier: a dlib column vector of doubles
// (one column; per dlib convention, a 0 row count means the length is chosen
// at run time).
using sample_type = matrix<double, 0, 1>;

// dlib's random-forest *regression* model and its trainer, both instantiated
// with their default template arguments.
using rf_model_type = random_forest_regression_function<>;
using rf_trainer_type = random_forest_regression_trainer<>;

/// Classifier wrapper built on dlib's random-forest regression model
/// (see the `forest` member).
///
/// train() takes integer labels, while predict()/predict_batch() return
/// doubles. How regression outputs map back to class labels is decided in the
/// out-of-line implementation — NOTE(review): confirm there.
class RandomForestClassifier
{
private:
    // Trained model. NOTE(review): this is the same type as the
    // rf_model_type alias declared above; consider spelling it that way.
    random_forest_regression_function<> forest;

    // Feature-scaling statistics, presumably one entry per feature dimension
    // as filled in by compute_scaler_params() — verify against the .cpp.
    std::vector<double> scaler_mean;
    std::vector<double> scaler_std;

    // Expected number of features per sample and number of class labels.
    int feature_dim;
    int label_dim;

    // ---- Private helpers (defined out of line) ----

    // Partitions (X, y) into train/test sets. Presumably `test_size` is the
    // held-out fraction and `seed` drives the shuffle — confirm in the .cpp.
    void stratified_split(
        const std::vector<sample_type>& X, const std::vector<double>& y,
        std::vector<sample_type>& X_train, std::vector<double>& y_train,
        std::vector<sample_type>& X_test, std::vector<double>& y_test,
        float test_size, unsigned int seed);

    // Converts rows of floats into dlib sample vectors.
    std::vector<sample_type> convert_to_dlib_format(const std::vector<std::vector<float>>& X);

    // Scaling helpers: derive statistics from a sample set, then apply them
    // to one sample or to a whole set.
    void compute_scaler_params(const std::vector<sample_type>& samples);
    sample_type scale_sample(const sample_type& s);
    std::vector<sample_type> scale_samples(const std::vector<sample_type>& samples);

    // Evaluates the model on a held-out set (output channel not visible here).
    void evaluate(const std::vector<sample_type>& X_test, const std::vector<double>& y_test);

public:
    /// Builds a classifier configured for 10 features and 4 labels.
    RandomForestClassifier() : RandomForestClassifier(10, 4) {}

    /// @param initial_dim  expected number of features per sample
    /// @param label_num    number of class labels
    RandomForestClassifier(int initial_dim, int label_num)
        : feature_dim(initial_dim),
          label_dim(label_num)
    {
    }

    /// Trains the forest on feature rows `X` with integer labels `y`.
    /// @param test_size     presumably the fraction held out for evaluation
    ///                      (0 by default) — confirm in the .cpp
    /// @param n_estimators  presumably the number of trees (100 by default)
    /// @return a success flag
    bool train(const std::vector<std::vector<float>>& X, const std::vector<int>& y,
               float test_size = 0.0f, int n_estimators = 100);

    /// Predicts a value for a single feature row.
    double predict(const std::vector<float>& sample);

    /// Predicts a value for each feature row in `samples`.
    std::vector<double> predict_batch(const std::vector<std::vector<float>>& samples);
};

#endif // RANDOM_FOREST_CLASSIFIER_H
