#include "RandomForestClassifier.h"

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iomanip>
#include <map>
#include <random>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "common.h"
#include "spdlog/spdlog.h"

/**
 * Construct a classifier.
 *
 * @param target       Model name prefix; train()/predict() derive the model file name from it
 *                     ("<target>_rf_model.dat", or "rf_model.dat" when empty).
 * @param initial_dim  Expected number of features per sample.
 * @param label_num    Number of class labels (evaluate() reports labels 0..label_num-1).
 */
RandomForestClassifier::RandomForestClassifier(std::string target, int initial_dim, int label_num)
{
    // Store the caller's target into the member (member <- parameter); the reverse
    // assignment would discard it and leave m_target empty.
    m_target = std::move(target);
    feature_dim = initial_dim;
    label_dim = label_num;
}

const std::string& RandomForestClassifier::getTarget() const {return m_target;}
void RandomForestClassifier::set_target(const std::string& target) {m_target=target;}

std::vector<sample_type> RandomForestClassifier::convert_to_dlib_format(const std::vector<std::vector<float>>& X)
{
    std::vector<sample_type> result;
    for (const auto& vec : X) {
        sample_type s(feature_dim);
        for (int i = 0; i < feature_dim; ++i) {
            s(i) = vec[i];
        }
        result.push_back(s);
    }
    return result;
}

void RandomForestClassifier::stratified_split(
    const std::vector<sample_type>& X, const std::vector<double>& y,
    std::vector<sample_type>& X_train, std::vector<double>& y_train,
    std::vector<sample_type>& X_test, std::vector<double>& y_test,
    float test_size, unsigned int seed)
{

    // 清空输出参数，防止多次调用数据累积
    X_train.clear();
    y_train.clear();
    X_test.clear();
    y_test.clear();

    // 按类别组织数据索引
    std::map<double, std::vector<size_t>> class_indices;
    for (size_t i = 0; i < y.size(); ++i) {
        class_indices[y[i]].push_back(i);
    }

    //使用更现代的随机数生成器
    std::mt19937 rng(seed);

    // 精确控制总测试集比例 (test_size 现在是 float，计算时需注意类型一致性)
    size_t total_desired_test = static_cast<size_t>(round(X.size() * static_cast<double>(test_size)));
    size_t total_actual_test = 0;

    // 第一遍：计算每个类别应该分配的测试样本数
    std::vector<size_t> each_test_count;
    for (auto& pair : class_indices) {
        size_t n = pair.second.size();
        double exact_test = n * static_cast<double>(test_size); // 转换为 double 进行乘法以避免精度损失
        size_t count = static_cast<size_t>(exact_test); // 向下取整

        // 根据累积误差决定是否进位，确保总测试数接近期望值
        double fractional_part = exact_test - count;
        if (fractional_part > 0.5 && total_actual_test < total_desired_test) {
            count++;
        }

        // 确保每个类别至少有1个测试样本，且至少保留1个训练样本
        count = std::max(size_t(1), std::min(count, n - 1));
        each_test_count.push_back(count);
        total_actual_test += count;
    }

    // 第二遍：执行分层抽样
    auto it = class_indices.begin();
    for (size_t i = 0; i < each_test_count.size(); ++i, ++it) {
        auto& indices = it->second;
        size_t n = indices.size();
        size_t test_count = each_test_count[i];

        // 使用标准库的shuffle获得更好的随机性
        std::shuffle(indices.begin(), indices.end(), rng);

        for (size_t j = 0; j < n; ++j) {
            size_t idx = indices[j];
            if (j < test_count) {
                X_test.push_back(X[idx]);
                y_test.push_back(y[idx]);
            } else {
                X_train.push_back(X[idx]);
                y_train.push_back(y[idx]);
            }
        }
    }

    DEBUG_PRINT("expect_test_size=%f actual_test_size=%f\n", test_size, static_cast<float>(X_test.size()) / X.size());
}

/**
 * Compute per-feature standardization parameters (mean and sample standard deviation).
 *
 * Writes feature_dim entries into scaler_mean / scaler_std, which scale_sample() uses.
 * Features whose std is below 1e-8 get std = 1.0, so scaling never divides by ~0. With a
 * single sample the (n-1) variance is undefined, so every std becomes 1.0.
 * Does nothing when samples is empty (previous parameters are kept).
 */
void RandomForestClassifier::compute_scaler_params(const std::vector<sample_type>& samples) {
    if (samples.empty())
        return;

    // assign(), not resize(): resize() keeps existing values when the size is unchanged,
    // so a repeated call (e.g. retraining) would accumulate onto stale statistics.
    scaler_mean.assign(feature_dim, 0.0);
    scaler_std.assign(feature_dim, 0.0);

    const double count = static_cast<double>(samples.size());

    for (const auto& s : samples) {
        for (int i = 0; i < feature_dim; ++i) {
            scaler_mean[i] += s(i);
        }
    }

    for (int i = 0; i < feature_dim; ++i)
        scaler_mean[i] /= count;

    for (const auto& s : samples) {
        for (int i = 0; i < feature_dim; ++i) {
            double diff = s(i) - scaler_mean[i];
            scaler_std[i] += diff * diff;
        }
    }

    for (int i = 0; i < feature_dim; ++i) {
        // Sample variance (n-1 denominator); a single sample would divide by zero -> NaN.
        const double variance = (samples.size() > 1) ? scaler_std[i] / (count - 1.0) : 0.0;
        scaler_std[i] = std::sqrt(variance);
        if (scaler_std[i] < 1e-8)
            scaler_std[i] = 1.0;
    }
}

// Standardize one sample with the stored per-feature parameters: (x - mean) / std.
// Expects scaler_mean / scaler_std to hold at least feature_dim entries, as set by
// compute_scaler_params() or loaded from the model file in predict().
sample_type RandomForestClassifier::scale_sample(const sample_type& s)
{
    sample_type scaled(feature_dim);
    int col = 0;
    while (col < feature_dim) {
        scaled(col) = (s(col) - scaler_mean[col]) / scaler_std[col];
        ++col;
    }
    return scaled;
}

//TMS 训练调用的接口
bool RandomForestClassifier::train(const std::vector<std::vector<float>>& X, const std::vector<int>& y, float test_size, int n_estimators)
{
    if (X.empty()) {
        DEBUG_PRINT("X is nil\n");
        return false;
    }

    if (X.size() != y.size()) {
        DEBUG_PRINT("invalid param: size_X=%lu size_y=%lu\n", X.size(), y.size());
        return false;
    }

    if (X[0].size() != static_cast<size_t>(feature_dim)){
        DEBUG_PRINT("invalid param: X[0].size()=%lu feature_dim=%d\n", X[0].size(), feature_dim);
        return false;
    }

    std::vector<sample_type> X_train, X_test;
    std::vector<double> y_train, y_test;
    if (std::fabs(test_size) < 1e-6f)
    {
        // 转换为dlib矩阵格式
        X_train = convert_to_dlib_format(X);
        for (size_t i = 0; i < y.size(); ++i)
            y_train.push_back(static_cast<double>(y[i]));
    }
    else
    {
        // 转换为dlib矩阵格式
        std::vector<sample_type> samples = convert_to_dlib_format(X);
        std::vector<double> labels;
        for (size_t i = 0; i < y.size(); ++i)
            labels.push_back(static_cast<double>(y[i]));

        // 分层划分训练集和测试集
        stratified_split(samples, labels, X_train, y_train, X_test, y_test, test_size, 42);
        DEBUG_PRINT("total_num=%ld train_num=%ld test_num=%ld\n", X.size(), X_train.size(), X_test.size());
    }

    random_forest_regression_trainer<> trainer;
    trainer.set_num_trees(n_estimators);
    trainer.set_min_samples_per_leaf(1);
    trainer.set_feature_subsampling_fraction(1.0/3.0);
    trainer.set_seed("42");

    if (test_size > 0)
        trainer.be_verbose();

    DEBUG_PRINT("begin to train, and expect_trees=%d\n", n_estimators);
    compute_scaler_params(X_train);
    auto X_train_scaled = scale_samples(X_train);
    try
    {
        forest = trainer.train(X_train_scaled, y_train);
    }
    catch(std::exception &e) {
        DEBUG_PRINT("train failed: %s\n", e.what());
        return false;
    }

    try
    {
        std::string filename = (!m_target.empty()) ? (m_target + "_rf_model.dat") : "rf_model.dat";
        serialize(filename) << scaler_mean << scaler_std << forest;
        DEBUG_PRINT("save random-forest model success\n");
    }
    catch(std::exception &e) {
        DEBUG_PRINT("save random-forest model failed: %s\n", e.what());
        return false;
    }

    DEBUG_PRINT("decision trees: expect=%d actual=%ld\n", n_estimators, forest.get_num_trees());
    DEBUG_PRINT("train done\n");

    if (!X_test.empty())
        evaluate(X_test, y_test);
    return true;
}

/**
 * Evaluate the trained forest on a held-out set and log per-class accuracy, overall
 * accuracy, and a confusion matrix via DEBUG_PRINT.
 *
 * Forest outputs and true labels are rounded to the nearest integer and clamped into
 * [0, label_dim - 1], the label range the reports iterate over.
 *
 * @param X_test  Unscaled test samples (standardized here with the stored scaler params).
 * @param y_test  Labels for X_test; must be non-empty and the same length.
 */
void RandomForestClassifier::evaluate(const std::vector<sample_type>& X_test, const std::vector<double>& y_test)
{
    DEBUG_PRINT("Evaluate the train performance on X_test\n");

    if (X_test.empty() || y_test.empty() || X_test.size() != y_test.size()) {
        DEBUG_PRINT("invalid test set: size_X=%zu size_y=%zu\n", X_test.size(), y_test.size());
        return;
    }

    // Clamp to label_dim - 1: clamping to label_dim created a class that counted as an
    // error but never appeared in the printed confusion matrix. Clamping in double space
    // also avoids an undefined int conversion for huge or NaN values (NaN maps to 0).
    const int max_label = std::max(0, label_dim - 1);
    auto to_label = [max_label](double value) -> int {
        const double rounded = std::round(value);
        if (!(rounded > 0.0))
            return 0;
        if (rounded >= static_cast<double>(max_label))
            return max_label;
        return static_cast<int>(rounded);
    };

    // conf_matrix[true_label][predicted_label] = count
    std::map<int, std::map<int, int>> conf_matrix;
    int total_correct = 0;
    const auto X_test_scaled = scale_samples(X_test);
    for (size_t i = 0; i < X_test_scaled.size(); ++i) {
        const int pred_val = to_label(forest(X_test_scaled[i]));
        const int true_val = to_label(y_test[i]);
        ++conf_matrix[true_val][pred_val];
        if (pred_val == true_val)
            ++total_correct;
    }

    // Read-only lookup into the sparse confusion matrix (missing cells are 0).
    auto cell = [&conf_matrix](int true_val, int pred_val) -> int {
        const auto row = conf_matrix.find(true_val);
        if (row == conf_matrix.end())
            return 0;
        const auto col = row->second.find(pred_val);
        return (col == row->second.end()) ? 0 : col->second;
    };

    // Reports are built in growable strings; a fixed 512-byte buffer truncated (and could
    // overflow) once the label count grew.
    std::ostringstream accuracy;
    accuracy << std::fixed << std::setprecision(4);
    accuracy << "类别准确率\n";
    for (int label = 0; label < label_dim; ++label) {
        const auto row = conf_matrix.find(label);
        if (row == conf_matrix.end())
            continue;
        int total = 0;
        for (const auto& entry : row->second)
            total += entry.second;
        if (total > 0) {
            const int correct = cell(label, label);
            accuracy << "  类别 " << label << ": " << correct << "/" << total
                     << " (" << static_cast<float>(correct) / total << ")\n";
        }
    }
    accuracy << "测试集总准确率: " << static_cast<float>(total_correct) / X_test.size() << "\n";
    const std::string accuracy_text = accuracy.str();
    DEBUG_PRINT("%s\n", accuracy_text.c_str());

    // Dump the confusion matrix (rows: true label, columns: predicted label).
    std::ostringstream matrix;
    matrix << "混淆矩阵:\n预测  : ";
    for (int i = 0; i < label_dim; ++i)
        matrix << std::setw(3) << i << ' ';
    matrix << '\n';
    for (int true_val = 0; true_val < label_dim; ++true_val) {
        matrix << "真实 " << true_val << ": ";
        for (int pred_val = 0; pred_val < label_dim; ++pred_val)
            matrix << std::setw(3) << cell(true_val, pred_val) << ' ';
        matrix << '\n';
    }
    const std::string matrix_text = matrix.str();
    DEBUG_PRINT("%s\n", matrix_text.c_str());
}

std::vector<sample_type> RandomForestClassifier::scale_samples(const std::vector<sample_type>& samples)
{
    std::vector<sample_type> result;
    for (const auto& s : samples) {
        result.push_back(scale_sample(s));
    }
    return result;
}

//查表取出特征值 -> 预测告警类型，发告警消息类型
double RandomForestClassifier::predict(const std::vector<float>& sample)
{
    if (sample.size() != static_cast<size_t>(feature_dim))
    {
        SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "[predict] invalid param: sample_size={} feature_dim={}", sample.size(), feature_dim);
        return -1;
    }

    if (forest.get_num_trees() == 0) {
        try {
            std::string filename = (!m_target.empty()) ? (m_target + "_rf_model.dat") : "rf_model.dat";
            deserialize(filename) >> scaler_mean >> scaler_std >> forest;
            feature_dim = scaler_mean.size();
        }
        catch (const std::exception& e) {
            SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "load model failure: {}\n", e.what());
            return -1;
        }
    }

    sample_type s(feature_dim);
    for (int i = 0; i < feature_dim; ++i) {
        s(i) = static_cast<double>(sample[i]);
    }
    sample_type scaled = scale_sample(s);
    double prediction = forest(scaled);
    return prediction;
}
