#include "RamdomForestCalculate.h"
#include <cstdlib>
#include <dirent.h>
#include <QCoreApplication>
#include "spdlog/spdlog.h"
#include <QFileInfo>

/**
 * Housekeeping for random-forest model files in the current working directory.
 *
 * Counts non-hidden directory entries whose name contains "rf_model.dat". When that
 * count reaches 1000, every matching file is unlinked (not only the surplus).
 * Emits a debug log line whenever at least one matching file was found.
 */
void delete_excess_rf_models()
{
    const size_t delete_threshold = 1000;

    DIR* dir = opendir(".");
    if (!dir) return;

    // Collect candidate names in a single pass and close the stream before deleting:
    // POSIX leaves readdir() behaviour unspecified when entries are removed while the
    // directory stream is being iterated.
    std::vector<std::string> matches;
    struct dirent* entry = nullptr;
    while ((entry = readdir(dir)) != nullptr) {
        if ('.' == entry->d_name[0])
            continue;  // skip hidden entries, "." and ".."

        std::string filename = entry->d_name;
        if (filename.find("rf_model.dat") != std::string::npos)
            matches.push_back(std::move(filename));
    }
    closedir(dir);

    const size_t to_delete = matches.size();
    size_t done = 0;
    if (to_delete >= delete_threshold) {
        for (const std::string& name : matches) {
            if (unlink(name.c_str()) == 0)
                ++done;
        }
    }

    if (to_delete > 0)
        SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"[delete_rf_models] to_delete={} done={}", to_delete, done);
}

bool checkSqliteDb(const QString& dbPath)
{
    QFile file(dbPath);
    if (!file.open(QIODevice::ReadOnly)) {
        SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"[isSqliteDb] open database failed");
        return false;
    }

    QByteArray header = file.read(16);
    file.close();

    if (header.size() < 16) {
        SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"[isSqliteDb] read_bytes lt 16: {}", header.size());
        QFile::remove(dbPath);
        return false;
    }

    static const QByteArray sqliteMagic("SQLite format 3");
    if (!header.startsWith(sqliteMagic)){
        SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"[isSqliteDb] read sqlite magic error");
        QFile::remove(dbPath);
        return false;
    }

    return true;
}

/**
 * Combines three per-axis predictions into one alarm level.
 *
 * Each prediction is rounded with qRound. The highest level from {3, 2, 1} that
 * appears among the rounded values is returned, or 0 when none of them is 1, 2 or 3.
 */
int pickPriorityValue(double xpred, double ypred, double zpred)
{
    const int rounded[3] = {qRound(xpred), qRound(ypred), qRound(zpred)};

    // Walk levels from most to least severe; the first hit wins.
    for (int level = 3; level >= 1; --level) {
        for (int value : rounded) {
            if (value == level)
                return level;
        }
    }

    return 0;  // no axis predicted 1, 2 or 3
}

void CreatTable()
{
    QString dbPath = QString("/home/pi/SigerTMS/stream/SAC/%1.db")
                         .arg("RamdomForestFeatureValue");

    delete_excess_rf_models();
    checkSqliteDb(dbPath);

    // 判断数据库文件是否存在
    QFileInfo fi(dbPath);
    if (!fi.exists()) {
        qDebug() << "Database not exists. Creating...";

        // ---- 短连接：打开 -> 执行建表语句 -> 关闭 ----
        QSqlDatabase db = QSqlDatabase::addDatabase("QSQLITE", "InitConnection");
        db.setDatabaseName(dbPath);

        if (!db.open()) {
            qDebug() << "Create DB failed:" << db.lastError().text();
            return;
        }

        QString createTable = QString(
            "CREATE TABLE IF NOT EXISTS RamdomForestFeatureValue("
            "id INTEGER PRIMARY KEY AUTOINCREMENT, "
            "StartTime TEXT NOT NULL, "
            "EndTime TEXT NOT NULL, "
            "target TEXT NOT NULL, "
            "Label INTEGER NOT NULL, "
            "x TEXT NOT NULL, "
            "y TEXT NOT NULL, "
            "z TEXT NOT NULL, "
            "created_at TEXT DEFAULT CURRENT_TIMESTAMP)"
        );

        QSqlQuery query(db);
        if (!query.exec(createTable)) {
            qDebug() << "Create table error:" << query.lastError().text();
        }

        db.close();
        QSqlDatabase::removeDatabase("InitConnection");

        qDebug() << "Database and table created.";
    }
}

/**
 * Executes one SQL statement on @p dbFile over a dedicated short-lived connection.
 *
 * The connection name combines the current thread id and a millisecond timestamp. It is
 * always removed before returning, and only after every QSqlDatabase/QSqlQuery handle
 * has gone out of scope. Before the statement runs, the connection sets the
 * journal_mode=WAL, synchronous=NORMAL and busy_timeout=3000 pragmas.
 *
 * @param sql    statement to execute (taken as-is; callers must escape values).
 * @param dbFile path of the SQLite database file.
 * @return true if the database opened and the statement succeeded.
 */
bool execShortConn(QString sql, const QString &dbFile)
{
    // Unique connection name for this call.
    QString conn = QString("short_%1_%2")
                       .arg((qulonglong)QThread::currentThreadId())
                       .arg(QDateTime::currentMSecsSinceEpoch());

    bool ok = false;

    {
        QSqlDatabase db = QSqlDatabase::addDatabase("QSQLITE", conn);
        db.setDatabaseName(dbFile);

        if (!db.open()) {
            qWarning() << "open failed:" << db.lastError();
        } else {
            {
                // WAL & busy_timeout improve robustness under concurrent access.
                QSqlQuery prag(db);
                prag.exec("PRAGMA journal_mode=WAL;");
                prag.exec("PRAGMA synchronous=NORMAL;");
                prag.exec("PRAGMA busy_timeout=3000;");

                QSqlQuery query(db);
                ok = query.exec(sql);
                if (!ok) {
                    qWarning() << "SQL ERROR:" << query.lastError();
                }
            }
            db.close();
        }
    }

    // Safe now: no handle referencing `conn` is alive any more.
    QSqlDatabase::removeDatabase(conn);
    return ok;
}

//发送给流式的消息
RamdomForestCalculate::RamdomForestCalculate(QObject *parent)
{
    // NOTE(review): `parent` is not forwarded to any base-class constructor; if this
    // class derives from QObject, confirm whether `: QObject(parent)` was intended.
    (void)parent;

    const QString startupMessage = "算法模块开始运行";
    SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"{}",startupMessage.toStdString());

    // Register Dc_SacData so it can be carried by queued signal/slot connections.
    qRegisterMetaType<Dc_SacData>("Dc_SacData");

    // 1. Start the TCP server on port 2026; it emits getdc_sacdata and TMSTrain.
    m_tcpServerManager = new TcpServerManager();
    m_tcpServerManager->startServer(2026);

    // 2. Ensure the feature database and its table exist.
    CreatTable();

    // 3. Train the initial per-target models from stored features.
    Retrain();

    // Each incoming SAC data message is logged, then run through Save_SACDATA.
    QObject::connect(m_tcpServerManager, &TcpServerManager::getdc_sacdata,
                     [this](const Dc_SacData &sac_data) {
        SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"收到 dc_sacdata 信号");
        SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"Source:{}",sac_data.source.toStdString());
        SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"SacFile:{}",sac_data.sacFile.toStdString());
        SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"Target:{}",sac_data.target.toStdString());
        Save_SACDATA(sac_data);
    });

    // A TMS request triggers a full retrain of all models.
    QObject::connect(m_tcpServerManager, &TcpServerManager::TMSTrain,
                     [this]() {
        SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"收到 TMS 信号");
        Retrain();
    });

    SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "RamdomForestCalculate init done");
}

/**
 * Looks up the `data` column of the slice whose start_cursor equals alarmInfo.startIds
 * in /home/pi/SigerTMS/stream/slice/slice_<target>.db.
 *
 * Uses a dedicated, uniquely named connection that is removed before returning, and
 * binds the cursor value as a parameter rather than splicing it into the SQL text.
 * Returns an empty string if the file is missing, cannot be opened, or has no match.
 */
QString RamdomForestCalculate::QueryPowerData(const StreamInfo &alarmInfo)
{
    QString dbPath = QString("/home/pi/SigerTMS/stream/slice/slice_%1.db").arg(alarmInfo.target);

    QString result = "";

    // Opening a non-existent path with QSQLITE would create an empty database file.
    if (!QFileInfo::exists(dbPath))
        return result;

    QString conn = QString("slice_query_%1_%2")
                       .arg((qulonglong)QThread::currentThreadId())
                       .arg(QDateTime::currentMSecsSinceEpoch());

    {
        QSqlDatabase db = QSqlDatabase::addDatabase("QSQLITE", conn);
        db.setDatabaseName(dbPath);

        if (!db.open()) {
            qDebug() << "无法打开数据库:" << db.lastError().text();
        } else {
            {
                QSqlQuery query(db);
                query.prepare("SELECT data FROM slice WHERE start_cursor = ?");
                query.addBindValue(alarmInfo.startIds);

                if (query.exec() && query.next()) {
                    result = query.value(0).toString();
                }
            }
            db.close();
        }
    }

    // All handles are out of scope, so the connection can be removed cleanly.
    QSqlDatabase::removeDatabase(conn);
    return result;
}

void RamdomForestCalculate::SendStreamAlarminfo(const StreamInfo &alarmInfo)
{
    QString ip("127.0.0.1");
    TcpClient tcp_client(ip, 5999);
    QJsonObject json;
    json["source"] = STREAM_MESSAGE;
    json["typealarm"] = QString::number(alarmInfo.alarmtype);
    json["target"] = alarmInfo.target;
    json["startIds"] = alarmInfo.startIds;
    json["endIds"] = alarmInfo.endIds;

    tcp_client.sendData(QString(QJsonDocument(json).toJson(QJsonDocument::Compact)).toStdString().c_str());
    SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"发送流式报警信息成功");

    // 创建短连接 TcpClient
//    TcpClient* client = new TcpClient();

//    // 构造消息
//    QJsonObject message;
//    message["source"] = STREAM_MESSAGE;
//    message["typealarm"] = QString::number(alarmInfo.alarmtype);
//    message["target"] = alarmInfo.target;
//    message["startIds"] = alarmInfo.startIds;
//    message["endIds"] = alarmInfo.endIds;

//    // 连接成功后发送消息
//    QObject::connect(client, &TcpClient::connected, [client, message]() {
//        qDebug() << "=== 连接成功，开始发送报警消息 ===";

//        client->sendMessage(message);

//        // 延迟关闭，确保消息发送完成
//        QTimer::singleShot(200, client, [client]() {
//            //client->disconnectFromServer();
//            client->deleteLater();
//        });

//        SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),
//                            "发送流式报警信息成功");
//    });

//    // 连接错误也自动释放
//    QObject::connect(client, &TcpClient::errorOccurred, [client](const QString &err) {
//        qWarning() << "TcpClient错误:" << err;
//        client->deleteLater();
//    });

//    // 发起连接
//    client->connectToServer("127.0.0.1", STREAM_PORT);
}


//void RamdomForestCalculate::SendStreamAlarminfo(const StreamInfo &alarmInfo)
//{
//    TcpClient* client = new TcpClient();

//    // 统一的清理函数
//    auto cleanup = [client]() {
//        static bool cleaned = false;  // 静态变量确保只清理一次收到 TMS 信号
//        if (!cleaned) {
//            cleaned = true;
//            //qDebug() << "清理客户端资源";
//            client->disconnectFromServer();
//            QTimer::singleShot(0, client, &QObject::deleteLater);  // 延迟删除确保安全
//        }
//    };

//    // 连接所有信号到统一的清理函数
//    QObject::connect(client, &TcpClient::connected, [client, alarmInfo, cleanup]() {
//        qDebug() << "=== 连接成功，开始发送报警消息 ===";

//        QJsonObject message;
//        message["source"] = STREAM_MESSAGE;
//        message["typealarm"] = QString::number(alarmInfo.alarmtype);
//        message["target"] = alarmInfo.target;
//        message["startIds"] = alarmInfo.startIds;
//        message["endIds"] = alarmInfo.endIds;
//        //message["power"] = alarmInfo.endIds;

//        client->sendMessage(message);
//        SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"发送给流式告警信息成功");

//        // 发送完成后立即开始清理
//        cleanup();
//    });

//    QObject::connect(client, &TcpClient::disconnected, cleanup);
//    QObject::connect(client, &TcpClient::messageReceived, cleanup);
//    QObject::connect(client, &TcpClient::errorOccurred, cleanup);

//    // 设置连接超时（5秒）
//    QTimer::singleShot(5000, client, cleanup);

//    // 连接到服务器
//    qDebug() << "正在连接到服务器...";
//    client->connectToServer("127.0.0.1", STREAM_PORT);
//}

int RamdomForestCalculate::Save_SACDATA(const Dc_SacData &sacdata)
{
    std::array<std::vector<float>, 3> raw;

    QString FilePath = QString("/home/pi/SigerTMS/stream/SAC/%1").arg(sacdata.sacFile);

    //step1: 读取二进制原始文件
    VibrationDataReader reader;
    if (reader.readFileRaw(FilePath)) {
        reader.printDataInfo();
    }

    FeatureExtractor extractor;
    SignalFeatures fea[3];
    raw = reader.convertToRawFormat();
    for (int i=0; i<3; ++i)
        fea[i] = extractor.extract_features(raw[i]);

    //step2: 预测报警
    int valid_num = 0;
    double pred[3] = {-1.0, -1.0, -1.0};
    auto rfc_it = m_rf_map.find(sacdata.target.toStdString());
    for (int i=0; i<3; ++i)
    {
        //fea[i].print();
        if (fea[i].isValid()){
            ++valid_num;
            if (rfc_it != m_rf_map.end())
                pred[i] = rfc_it->second.predict(fea[i].toVector());
        }else{
            //无效数据存储为0
            fea[i].clear();
        }
    }

    int predRes  = pickPriorityValue(pred[0], pred[1], pred[2]);
    SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"valid={} xpred={} ypred={} zpred={} final_pred={}", valid_num, pred[0], pred[1], pred[2], predRes);
    if (predRes > 0)
    {
        StreamInfo info(predRes,sacdata.target,sacdata.startIds,sacdata.endIds);
        SendStreamAlarminfo(info);
    }

    //step3: 存入数据库
    QString x;
    QString y;
    QString z;
    fea[0].fill(x);
    fea[1].fill(y);
    fea[2].fill(z);

    std::string data = QString(
        "INSERT INTO RamdomForestFeatureValue (StartTime, EndTime, target, Label, x, y, z) "
        "VALUES ('%1', '%2', '%3', %4, '%5', '%6', '%7')"
    )
    .arg(sacdata.startIds) //开始时间
    .arg(sacdata.endIds) // 结束时间
    .arg(sacdata.target)  // 设备编号
    .arg(0)  // Label: 1,2,3 循环
    .arg(x)   // x 特征值
    .arg(y)   // y 特征值
    .arg(z)   // z 特征值
    .toStdString();

    QString dbPath = QString("/home/pi/SigerTMS/stream/SAC/%1.db")
                         .arg("RamdomForestFeatureValue");

    QString sql = QString::fromStdString(data);
    bool ok = execShortConn(sql, dbPath);

    if (!ok) {
        qWarning() << "Insert failed!";
    }

    //dbFunc->addDataToQueue(data);

    //step4: 删除二进制文件
    QFile::remove(FilePath);

    return 0;
}

// Placeholder alarm-judgement hook: currently performs no evaluation and always yields 0.
int RamdomForestCalculate::judge_alarm()
{
    const int noAlarm = 0;
    return noAlarm;
}

// Maximum number of NORMAL-labelled records kept per target for training.
const size_t max_normal_per_target = 200;

/**
 * Loads every complete row of RamdomForestFeatureValue from @p dbFileName into @p data,
 * grouped by target: `first` holds all non-NORMAL records, and `second` holds at most
 * max_normal_per_target NORMAL records.
 *
 * Rows with an empty target, start/end time or feature string are skipped.
 * Uses a dedicated connection that is removed only after all handles are destroyed.
 *
 * @return true if the database opened and the query ran; false otherwise.
 */
bool RamdomForestCalculate::queryRecordsByLabel(const QString &dbFileName,
        std::map<std::string, std::pair<std::vector<SACRecord>, std::vector<SACRecord>>>& data)
{
    // ==== 1. Unique connection name ====
    QString conn = QString("rf_query_%1_%2")
                       .arg((qulonglong)QThread::currentThreadId())
                       .arg(QDateTime::currentMSecsSinceEpoch());

    bool ok = false;

    {
        QSqlDatabase db = QSqlDatabase::addDatabase("QSQLITE", conn);
        db.setDatabaseName(dbFileName);

        if (!db.open()) {
            qDebug() << "Database open failed:" << db.lastError().text();
        } else {
            {
                QSqlQuery prag(db);
                prag.exec("PRAGMA journal_mode=WAL;");
                prag.exec("PRAGMA synchronous=NORMAL;");
                prag.exec("PRAGMA busy_timeout=3000;");

                // ==== 2. Run the query (forward-only: rows are read once) ====
                QSqlQuery query(db);
                query.setForwardOnly(true);
                const QString sql("SELECT StartTime, EndTime, target, Label, x, y, z "
                                  "FROM RamdomForestFeatureValue");

                if (!query.exec(sql)) {
                    qDebug() << "Query failed:" << query.lastError().text();
                } else {
                    ok = true;

                    // ==== 3. Group rows by target ====
                    while (query.next()) {
                        SACRecord record;
                        record.startTime = query.value("StartTime").toString();
                        record.endTime = query.value("EndTime").toString();
                        record.target = query.value("target").toString();
                        record.label = query.value("Label").toInt();
                        record.x = query.value("x").toString();
                        record.y = query.value("y").toString();
                        record.z = query.value("z").toString();

                        if (record.target.isEmpty() || record.startTime.isEmpty() || record.endTime.isEmpty() ||
                                record.x.isEmpty() || record.y.isEmpty() || record.z.isEmpty())
                            continue;

                        std::pair<std::vector<SACRecord>, std::vector<SACRecord>>& target_pair =
                                data[record.target.toStdString()];

                        if (record.label != NORMAL) {
                            target_pair.first.push_back(std::move(record));          // abnormal
                        } else if (target_pair.second.size() < max_normal_per_target) {
                            target_pair.second.push_back(std::move(record));         // normal, under cap
                        } else {
                            // Cap reached: overwrite a random slot.
                            // NOTE(review): every later record always replaces one slot, so the
                            // kept set is strongly biased toward the most recent rows rather than
                            // being a uniform sample — confirm this is intended.
                            const size_t replace_index =
                                    static_cast<size_t>(std::rand()) % max_normal_per_target;
                            target_pair.second[replace_index] = std::move(record);
                        }
                    }
                }
            }
            db.close();
        }
    }

    // ==== 4. Remove the connection (all handles are out of scope now) ====
    QSqlDatabase::removeDatabase(conn);

    return ok;
}

/**
 * Parses a comma-separated feature string into a float vector.
 *
 * Empty fields and "NA" are skipped. Fields std::stof rejects are logged and skipped.
 * Returns an empty vector unless exactly SAC_FEATURE_DIM values were parsed and at
 * least one of them is non-zero.
 */
std::vector<float> RamdomForestCalculate::QstrToFeature(const QString& qstr)
{
    std::stringstream fields(qstr.toStdString());
    std::vector<float> values;
    std::string field;

    while (std::getline(fields, field, ',')) {
        if (field.empty() || field == "NA")
            continue;

        try {
            values.push_back(std::stof(field));
        } catch (const std::exception& e) {
            SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "ERROR: standard exception: {}", e.what());
        }
    }

    // Wrong dimensionality -> unusable sample.
    if (values.size() != SAC_FEATURE_DIM)
        return std::vector<float>();

    // Accept only if some component is non-zero.
    for (float value : values) {
        if (value != 0.0f)
            return values;
    }
    return std::vector<float>();
}

void RamdomForestCalculate::DealRecord(const std::map<std::string, std::pair<std::vector<SACRecord>,
                                       std::vector<SACRecord>>>& data, std::map<std::string, SqlFeaRecord>& result)
{
    result.clear();
    std::map<std::string, std::pair<std::vector<SACRecord>, std::vector<SACRecord>>>::const_iterator it;

    for (it = data.begin(); it != data.end(); ++it)
    {
        const std::string& target = it->first;
        const std::pair<std::vector<SACRecord>, std::vector<SACRecord>>& target_data = it->second;
        const std::vector<SACRecord>* records[2] = {&target_data.first, &target_data.second};

        // 临时存储
        std::vector<std::vector<float>> temp_samples;
        std::vector<int> temp_labels;
        bool has_valid_data = false;

        for (int idx=0; idx<2; ++idx)
        {
            if (!records[idx] || records[idx]->empty())
                continue;

            for (const auto& tmp : *records[idx])
            {
                std::vector<float> v[3];
                v[0] = QstrToFeature(tmp.x);
                v[1] = QstrToFeature(tmp.y);
                v[2] = QstrToFeature(tmp.z);

                for (int i=0; i<3; ++i)
                {
                    if (v[i].empty())
                        continue;

                    temp_samples.push_back(std::move(v[i]));
                    temp_labels.push_back(tmp.label);
                    has_valid_data = true;
                }
            } //foreach
        } //for idx

        // 只有有有效数据才创建
        if (has_valid_data)
        {
            result[target] = SqlFeaRecord{std::move(temp_samples), std::move(temp_labels)};
        }
    } //for it
}

void RamdomForestCalculate::Retrain()
{
    std::map<std::string, std::pair<std::vector<SACRecord>, std::vector<SACRecord>>> record;
    std::map<std::string, SqlFeaRecord> result;
    QString dbPath = QString("/home/pi/SigerTMS/stream/SAC/%1.db").arg("RamdomForestFeatureValue");
    struct timeval start,end;

    gettimeofday(&start, NULL);
    RamdomForestCalculate::queryRecordsByLabel(dbPath, record);
    DealRecord(record, result);
    m_rf_map.clear();

    for (const auto& item : result)
    {
        const std::string& target = item.first;
        const SqlFeaRecord& data = item.second;
        if (target.empty() || data.sample.empty() || data.sample.size()!=data.label.size())
            continue;

        SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"train start: target={} sample={}", target.c_str(), data.sample.size());
        auto it = m_rf_map.find(target);
        if (it != m_rf_map.end()){
            it->second.train(data.sample, data.label);
        }else{
            m_rf_map[target] = RandomForestClassifier(target, 10, 4);
            m_rf_map[target].train(data.sample, data.label);
        }
    }

    gettimeofday(&end, NULL);
    SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "train done, use time: {} ms", (end.tv_sec - start.tv_sec) * 1000L + (end.tv_usec - start.tv_usec)/1000);
}
