#include "RamdomForestCalculate.h"

#include <memory>

#include <QCoreApplication>
#include "spdlog/spdlog.h"

// Sets up the random-forest algorithm module:
//  - starts the TCP server on port 2026 and the DB writer thread for
//    "RamdomForestFeatureValue";
//  - creates the classifier and, if labelled records (Label != 0) already
//    exist in the feature database, trains an initial model from them;
//  - connects TcpServerManager's data signal to Save_SACDATA() and its TMS
//    signal to Retrain().
//
// NOTE(review): `parent` is not forwarded to any base class in this
// definition; confirm against the class declaration whether it should be.
RamdomForestCalculate::RamdomForestCalculate(QObject *parent)
{
    const QString startMessage = "算法模块开始运行";
    SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "{}", startMessage.toStdString());

    qRegisterMetaType<Dc_SacData>("Dc_SacData");

    m_tcpServerManager = new TcpServerManager();
    m_tcpServerManager->startServer(2026);

    dbFunc = new DBFunc();
    dbFunc->startDBThread("RamdomForestFeatureValue");

    // 1. Create the classifier (constructor arguments unchanged: 10, 4).
    Classifier = new RandomForestClassifier(10, 4);

    // 2. Train the initial model from records that already carry a label.
    const QString dbPath = QString("/home/pi/SigerTMS/stream/SAC/%1.db").arg("RamdomForestFeatureValue");
    const QList<SACRecord> record = queryRecordsByLabel(dbPath, 0);

    // Assign the flag in both cases so it never depends on how (or whether)
    // the member is initialised in the class declaration.
    IfHaveModel = !record.isEmpty();
    if (IfHaveModel) {
        std::vector<std::vector<float>> FearValues;
        std::vector<int> TypeValues;
        DealRecord(record, FearValues, TypeValues);
        Classifier->train(FearValues, TypeValues, 0.2f, 100);
    }

    // A new SAC file was announced: log it, then extract features / predict / store.
    // Explicit [this] capture: [=] capturing `this` implicitly is deprecated in C++20.
    QObject::connect(m_tcpServerManager, &TcpServerManager::getdc_sacdata,
                     [this](const Dc_SacData &sac_data) {
        SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "收到 dc_sacdata 信号");
        SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "Source:{}", sac_data.source.toStdString());
        SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "SacFile:{}", sac_data.sacFile.toStdString());
        SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "Target:{}", sac_data.target.toStdString());

        Save_SACDATA(sac_data);
    });

    // TMS signal: retrain the model from the current labelled records.
    QObject::connect(m_tcpServerManager, &TcpServerManager::TMSTrain,
                     [this]() {
        SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "收到 TMS 信号");
        IfHaveModel = true;
        Retrain();
    });
}

// Sends one alarm to the stream service at 127.0.0.1:STREAM_PORT.
//
// A fresh TcpClient is created for every call. Once it is connected, a JSON
// message (source / typealarm / target / startIds / endIds) is sent. The
// client is then disconnected and scheduled for deletion on the FIRST of:
// message sent, disconnected, messageReceived, errorOccurred, or a 5 s timeout.
void RamdomForestCalculate::SendStreamAlarminfo(const StreamInfo &alarmInfo)
{
    TcpClient* client = new TcpClient();

    // Per-call "already cleaned up" flag, shared by every copy of `cleanup`.
    // A function-local `static` would be shared by ALL calls of this method:
    // after the first alarm, every later client would never be disconnected
    // or deleted.
    auto cleaned = std::make_shared<bool>(false);
    auto cleanup = [client, cleaned]() {
        if (*cleaned) {
            return;
        }
        *cleaned = true;
        client->disconnectFromServer();
        // Deferred deletion, so this is safe even when cleanup runs inside
        // one of the client's own signal emissions.
        QTimer::singleShot(0, client, &QObject::deleteLater);
    };

    // On connect: build and send the alarm message, then clean up right away.
    QObject::connect(client, &TcpClient::connected, [client, alarmInfo, cleanup]() {
        qDebug() << "=== 连接成功，开始发送报警消息 ===";

        QJsonObject message;
        message["source"] = STREAM_MESSAGE;
        message["typealarm"] = alarmInfo.alarmtype;
        message["target"] = alarmInfo.target;
        message["startIds"] = alarmInfo.startIds;
        message["endIds"] = alarmInfo.endIds;

        client->sendMessage(message);
        SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"发送给流式告警信息成功");

        cleanup();
    });

    // Any other terminal event also tears the client down (at most once).
    QObject::connect(client, &TcpClient::disconnected, cleanup);
    QObject::connect(client, &TcpClient::messageReceived, cleanup);
    QObject::connect(client, &TcpClient::errorOccurred, cleanup);

    // Connection timeout: 5 seconds. Using `client` as the context object
    // cancels the timer if the client has already been deleted.
    QTimer::singleShot(5000, client, cleanup);

    qDebug() << "正在连接到服务器...";
    client->connectToServer("127.0.0.1", STREAM_PORT);
}

// Reads feature records from the SQLite database file `dbFileName`.
//
// NOTE: despite the name, this is an EXCLUSION filter. It returns every row of
// RamdomForestFeatureValue whose Label is NOT equal to `targetLabel` (callers
// pass 0 to obtain the labelled rows).
//
// Returns an empty list if the database cannot be opened or the query fails;
// the error is logged with qDebug.
QList<SACRecord> RamdomForestCalculate::queryRecordsByLabel(const QString &dbFileName, int targetLabel)
{
    // Module-specific connection name. It is added and removed within this
    // call, so it never clobbers a connection owned by other code.
    const QString connectionName = QStringLiteral("RamdomForestCalculate_queryRecords");
    QList<SACRecord> records;

    // Every QSqlDatabase/QSqlQuery handle lives in this scope. They are
    // destroyed before removeDatabase(), which Qt requires for a clean
    // removal ("connection is still in use" otherwise).
    {
        QSqlDatabase db = QSqlDatabase::addDatabase("QSQLITE", connectionName);
        db.setDatabaseName(dbFileName);

        if (!db.open()) {
            qDebug() << "Database open failed:" << db.lastError().text();
        } else {
            // Prepared statement with a bound value; executed exactly once.
            QSqlQuery query(db);
            query.prepare("SELECT * FROM RamdomForestFeatureValue WHERE Label != ?");
            query.addBindValue(targetLabel);

            if (!query.exec()) {
                qDebug() << "Query failed:" << query.lastError().text();
            } else {
                while (query.next()) {
                    SACRecord record;
                    record.startTime = query.value("StartTime").toString();
                    record.endTime = query.value("EndTime").toString();
                    record.target = query.value("target").toString();
                    record.label = query.value("Label").toInt();
                    record.x = query.value("x").toString();
                    record.y = query.value("y").toString();
                    record.z = query.value("z").toString();
                    qDebug()<<"x: "<<record.x;
                    qDebug()<<"-----------------";
                    records.append(record);
                }
            }
        }
        db.close();
    }

    QSqlDatabase::removeDatabase(connectionName);
    return records;
}

int RamdomForestCalculate::Save_SACDATA(const Dc_SacData &sacdata)
{
    std::array<std::vector<float>, 3> raw;

    QString FilePath = QString("/home/pi/SigerTMS/stream/SAC/%1").arg(sacdata.sacFile);

    //step1: 读取二进制原始文件
    VibrationDataReader reader;
    if (reader.readFileRaw(FilePath)) {
        reader.printDataInfo();
    }

    raw = reader.convertToRawFormat();
    //parseSACData("斯凯孚换刀周期1.dat", raw[0], raw[1], raw[2]);
    QString x;
    QString y;
    QString z;

    FeatureExtractor extractor;
    SignalFeatures fea[3];
    for (int i=0; i<3; ++i)
        fea[i] = extractor.extract_features(raw[i]);

    for (int i=0; i<3; ++i)
    {
        fea[i].print();
        fea[i].fill(x);
        fea[i].fill(y);
        fea[i].fill(z);
    }
    //step2: 预测报警
    //std::string data;
    qDebug()<<"x "<<x;

    if(IfHaveModel)
    {
        std::vector<float> test_sample(10,0);
        //用逗号分割
        QStringList list = x.split(',');

        //转成 float 并存到 vector
        for (const QString &s : list) {
            bool ok = false;
            float f = s.toFloat(&ok);
            if (ok) {
                test_sample.push_back(f);
            } else {
                std::cerr << "转换失败: " << s.toStdString() << std::endl;
            }
        }
        double pred = Classifier->predict(test_sample);
        cout << "测试样本预测类别: " << pred << endl;

        StreamInfo(pred,sacdata.target,sacdata.startIds,sacdata.endIds);
    }

    //step3: 存入数据库
    std::string data = QString(
        "INSERT INTO RamdomForestFeatureValue (StartTime, EndTime, target, Label, x, y, z) "
        "VALUES ('%1', '%2', '%3', %4, '%5', '%6', '%7')"
    )
    .arg(sacdata.startIds) //开始时间
    .arg(sacdata.endIds) // 结束时间
    .arg(sacdata.target)  // 设备编号
    .arg(0)  // Label: 1,2,3 循环
    .arg(x)   // x 特征值
    .arg(y)   // y 特征值
    .arg(z)   // z 特征值
    .toStdString();

    dbFunc->addDataToQueue(data);

    //step4: 删除二进制文件
    QFile::remove(FilePath);

    return 0;
}

// Placeholder for alarm judgement; no logic is implemented yet.
// Returns 0, the same success value Save_SACDATA() uses. The previous empty
// body fell off the end of a non-void function, which is undefined behaviour.
int RamdomForestCalculate::judge_alarm()
{
    // TODO: implement alarm judgement.
    return 0;
}

void RamdomForestCalculate::DealRecord(const QList<SACRecord> &record, std::vector<std::vector<float> > &result1, std::vector<int> &result2)
{
    std::vector<float> tmpresult;

    foreach (auto tmp, record) {
        //1.解析振动量
        QStringList list = tmp.x.split(',');
        tmpresult.clear();
        for (const QString& item : list) {
            bool ok;
            float value = item.toFloat(&ok);
            if (ok) {
                tmpresult.push_back(value);
            }
        }
        result1.push_back(tmpresult);
        result2.push_back(tmp.label);

        list.clear();
        tmpresult.clear();
        list = tmp.y.split(',');
        for (const QString& item : list) {
            bool ok;
            float value = item.toFloat(&ok);
            if (ok) {
                tmpresult.push_back(value);
            }
        }
        result1.push_back(tmpresult);
        result2.push_back(tmp.label);

        list.clear();
        tmpresult.clear();
        list = tmp.z.split(',');
        for (const QString& item : list) {
            bool ok;
            float value = item.toFloat(&ok);
            if (ok) {
                tmpresult.push_back(value);
            }
        }
        result1.push_back(tmpresult);
        //2.解析特征值
        result2.push_back(tmp.label);

    }
    return;
}

void RamdomForestCalculate::Retrain()
{
    QString dbPath = QString("/home/pi/SigerTMS/stream/SAC/%1.db").arg("RamdomForestFeatureValue");
    QList<SACRecord> record = queryRecordsByLabel(dbPath,0);

    std::vector<std::vector<float>> FearValues;
    std::vector<int> TypeValues;
    DealRecord(record,FearValues,TypeValues);
    Classifier->train(FearValues, TypeValues, 0.2f, 100);

}
