Commit e4915983 by yahan.li

随机森林代码修改:按照target训练和预测,target包含刀号,进给和转速倍率

	修改:     RadomForestSAC/RamdomForestCalculate.cpp
	修改:     RadomForestSAC/RamdomForestCalculate.h
	修改:     RadomForestSAC/RandomForestClassifier.cpp
	修改:     RadomForestSAC/RandomForestClassifier.h
	修改:     RadomForestSAC/common.h
	修改:     main.cpp
	修改:     package/E3/SigerCalculation
parent 259c8b9d
#include "RamdomForestCalculate.h" #include "RamdomForestCalculate.h"
#include <cstdlib>
#include <dirent.h>
#include <QCoreApplication> #include <QCoreApplication>
#include "spdlog/spdlog.h" #include "spdlog/spdlog.h"
#include <QFileInfo> #include <QFileInfo>
/**
 * Remove persisted model files from the current working directory once
 * enough of them have accumulated.
 *
 * A file counts as a model file when its name does not start with '.' and
 * contains "rf_model.dat". The directory is scanned once to count matches;
 * only when at least 1000 are found is it rescanned and every match unlinked.
 * The match count and the number of successful unlinks are logged whenever
 * at least one match exists.
 */
void delete_excess_rf_models()
{
    DIR* dir = opendir(".");
    if (!dir)
        return;

    // Shared by the counting pass and the deleting pass.
    const auto is_model_file = [](const struct dirent* item) {
        if (item->d_name[0] == '.')
            return false;
        const std::string name = item->d_name;
        return name.find("rf_model.dat") != std::string::npos;
    };

    int to_delete = 0;
    for (struct dirent* item = readdir(dir); item != nullptr; item = readdir(dir)) {
        if (is_model_file(item))
            ++to_delete;
    }

    int done = 0;
    if (to_delete >= 1000) {
        // Restart the directory stream for the deleting pass.
        rewinddir(dir);
        for (struct dirent* item = readdir(dir); item != nullptr; item = readdir(dir)) {
            // unlink() returns 0 on success.
            if (is_model_file(item) && !unlink(item->d_name))
                ++done;
        }
    }
    closedir(dir);

    if (to_delete > 0) {
        SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"[delete_rf_models] to_delete={} done={}", to_delete, done);
    }
}
bool checkSqliteDb(const QString& dbPath)
{
QFile file(dbPath);
if (!file.open(QIODevice::ReadOnly)) {
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"[isSqliteDb] open database failed");
return false;
}
QByteArray header = file.read(16);
file.close();
if (header.size() < 16) {
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"[isSqliteDb] read_bytes lt 16: {}", header.size());
QFile::remove(dbPath);
return false;
}
static const QByteArray sqliteMagic("SQLite format 3");
if (!header.startsWith(sqliteMagic)){
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"[isSqliteDb] read sqlite magic error");
QFile::remove(dbPath);
return false;
}
return true;
}
int pickPriorityValue(double xpred, double ypred, double zpred) int pickPriorityValue(double xpred, double ypred, double zpred)
{ {
// 四舍五入 // 四舍五入
...@@ -23,6 +90,9 @@ void CreatTable() ...@@ -23,6 +90,9 @@ void CreatTable()
QString dbPath = QString("/home/pi/SigerTMS/stream/SAC/%1.db") QString dbPath = QString("/home/pi/SigerTMS/stream/SAC/%1.db")
.arg("RamdomForestFeatureValue"); .arg("RamdomForestFeatureValue");
delete_excess_rf_models();
checkSqliteDb(dbPath);
// 判断数据库文件是否存在 // 判断数据库文件是否存在
QFileInfo fi(dbPath); QFileInfo fi(dbPath);
if (!fi.exists()) { if (!fi.exists()) {
...@@ -114,26 +184,11 @@ RamdomForestCalculate::RamdomForestCalculate(QObject *parent) ...@@ -114,26 +184,11 @@ RamdomForestCalculate::RamdomForestCalculate(QObject *parent)
//dbFunc = new DBFunc(); // 实例化 //dbFunc = new DBFunc(); // 实例化
//dbFunc->startDBThread("RamdomForestFeatureValue"); //dbFunc->startDBThread("RamdomForestFeatureValue");
//1.初始化分类器
Classifier = new RandomForestClassifier(10, 4);
//2.创建数据库表 //2.创建数据库表
CreatTable(); CreatTable();
//3.训练初始模型 //3.训练初始模型
QString dbPath = QString("/home/pi/SigerTMS/stream/SAC/%1.db").arg("RamdomForestFeatureValue"); Retrain();
QList<SACRecord> record = queryRecordsByLabel(dbPath,0);
if(record.size() != 0)
{
IfHaveModel = true;
std::vector<std::vector<float>> FearValues;
std::vector<int> TypeValues;
DealRecord(record,FearValues,TypeValues);
Classifier->train(FearValues, TypeValues, 0.2f, 100);
}
//GLS Test //GLS Test
// StreamInfo info(1,"C$14$1$1$1$101$0","1764733704270-0","0"); // StreamInfo info(1,"C$14$1$1$1$101$0","1764733704270-0","0");
...@@ -156,11 +211,11 @@ RamdomForestCalculate::RamdomForestCalculate(QObject *parent) ...@@ -156,11 +211,11 @@ RamdomForestCalculate::RamdomForestCalculate(QObject *parent)
QObject::connect(m_tcpServerManager, &TcpServerManager::TMSTrain, QObject::connect(m_tcpServerManager, &TcpServerManager::TMSTrain,
[=]() { [=]() {
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"收到 TMS 信号"); SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"收到 TMS 信号");
IfHaveModel = true;
Retrain(); Retrain();
//重新训练 //重新训练
}); });
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "RamdomForestCalculate init done");
} }
QString RamdomForestCalculate::QueryPowerData(const StreamInfo &alarmInfo) QString RamdomForestCalculate::QueryPowerData(const StreamInfo &alarmInfo)
...@@ -286,67 +341,6 @@ void RamdomForestCalculate::SendStreamAlarminfo(const StreamInfo &alarmInfo) ...@@ -286,67 +341,6 @@ void RamdomForestCalculate::SendStreamAlarminfo(const StreamInfo &alarmInfo)
// client->connectToServer("127.0.0.1", STREAM_PORT); // client->connectToServer("127.0.0.1", STREAM_PORT);
//} //}
/**
 * Load every row of RamdomForestFeatureValue whose Label differs from
 * targetLabel.
 *
 * @param dbFileName  Path of the SQLite database file.
 * @param targetLabel Label value to exclude from the result.
 * @return The matching rows; empty when the database cannot be opened or
 *         the query fails (the error text is written with qDebug()).
 *
 * A uniquely named connection is created for the call and removed before
 * returning. removeDatabase() is only called after every QSqlDatabase and
 * QSqlQuery object that uses the connection has been destroyed; calling it
 * earlier makes Qt report the connection as still in use.
 */
QList<SACRecord> RamdomForestCalculate::queryRecordsByLabel(const QString &dbFileName, int targetLabel)
{
    QList<SACRecord> records;

    // Connection name is built from the thread id and the current time.
    const QString conn = QString("rf_query_%1_%2")
            .arg((qulonglong)QThread::currentThreadId())
            .arg(QDateTime::currentMSecsSinceEpoch());

    {
        QSqlDatabase db = QSqlDatabase::addDatabase("QSQLITE", conn);
        db.setDatabaseName(dbFileName);

        if (!db.open()) {
            qDebug() << "Database open failed:" << db.lastError().text();
        } else {
            {
                // Inner scope: both query objects die before db.close().
                QSqlQuery prag(db);
                prag.exec("PRAGMA journal_mode=WAL;");
                prag.exec("PRAGMA synchronous=NORMAL;");
                prag.exec("PRAGMA busy_timeout=3000;");

                const QString sql = QString("SELECT * FROM RamdomForestFeatureValue WHERE Label != %1")
                        .arg(targetLabel);

                QSqlQuery query(db);
                if (!query.exec(sql)) {
                    qDebug() << "Query failed:" << query.lastError().text();
                } else {
                    while (query.next()) {
                        SACRecord record;
                        record.startTime = query.value("StartTime").toString();
                        record.endTime   = query.value("EndTime").toString();
                        record.target    = query.value("target").toString();
                        record.label     = query.value("Label").toInt();
                        record.x         = query.value("x").toString();
                        record.y         = query.value("y").toString();
                        record.z         = query.value("z").toString();
                        records.append(record);
                    }
                }
            }
            db.close();
        }
    }

    // db is out of scope here, so the connection can be dropped safely.
    QSqlDatabase::removeDatabase(conn);

    return records;
}
int RamdomForestCalculate::Save_SACDATA(const Dc_SacData &sacdata) int RamdomForestCalculate::Save_SACDATA(const Dc_SacData &sacdata)
{ {
std::array<std::vector<float>, 3> raw; std::array<std::vector<float>, 3> raw;
...@@ -359,108 +353,45 @@ int RamdomForestCalculate::Save_SACDATA(const Dc_SacData &sacdata) ...@@ -359,108 +353,45 @@ int RamdomForestCalculate::Save_SACDATA(const Dc_SacData &sacdata)
reader.printDataInfo(); reader.printDataInfo();
} }
raw = reader.convertToRawFormat();
//parseSACData("斯凯孚换刀周期1.dat", raw[0], raw[1], raw[2]);
QString x;
QString y;
QString z;
FeatureExtractor extractor; FeatureExtractor extractor;
SignalFeatures fea[3]; SignalFeatures fea[3];
raw = reader.convertToRawFormat();
for (int i=0; i<3; ++i) for (int i=0; i<3; ++i)
fea[i] = extractor.extract_features(raw[i]); fea[i] = extractor.extract_features(raw[i]);
//step2: 预测报警
int valid_num = 0;
double pred[3] = {-1.0, -1.0, -1.0};
auto rfc_it = m_rf_map.find(sacdata.target.toStdString());
for (int i=0; i<3; ++i) for (int i=0; i<3; ++i)
{ {
fea[i].print(); //fea[i].print();
if(i == 0) if (fea[i].isValid()){
{ ++valid_num;
fea[i].fill(x); if (rfc_it != m_rf_map.end())
} pred[i] = rfc_it->second.predict(fea[i].toVector());
else if(i == 1) }else{
{ //无效数据存储为0
fea[i].fill(y); fea[i].clear();
}
else if(i == 2)
{
fea[i].fill(z);
} }
} }
//step2: 预测报警
//std::string data;
qDebug()<<"x "<<x;
if(IfHaveModel) int predRes = pickPriorityValue(pred[0], pred[1], pred[2]);
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"valid={} xpred={} ypred={} zpred={} final_pred={}", valid_num, pred[0], pred[1], pred[2], predRes);
if (predRes > 0)
{ {
//X轴预测
std::vector<float> xtest_sample;
//用逗号分割
QStringList xlist = x.split(',');
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"xlist: {}",xlist.size());
//转成 float 并存到 vector
for (const QString &s : xlist) {
bool ok = false;
float f = s.toFloat(&ok);
if (ok) {
xtest_sample.push_back(f);
} else {
std::cerr << "转换失败: " << s.toStdString() << std::endl;
}
}
double xpred = Classifier->predict(xtest_sample);
//cout << "测试样本预测类别: " << pred << endl;
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"x轴测试样本预测类别: {}",xpred);
//Y轴预测
std::vector<float> ytest_sample;
//用逗号分割
QStringList ylist = y.split(',');
//转成 float 并存到 vector
for (const QString &s : ylist) {
bool ok = false;
float f = s.toFloat(&ok);
if (ok) {
ytest_sample.push_back(f);
} else {
std::cerr << "转换失败: " << s.toStdString() << std::endl;
}
}
double ypred = Classifier->predict(ytest_sample);
//cout << "测试样本预测类别: " << pred << endl;
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"y轴测试样本预测类别: {}",ypred);
//Z轴预测
std::vector<float> ztest_sample;
//用逗号分割
QStringList zlist = z.split(',');
//转成 float 并存到 vector
for (const QString &s : zlist) {
bool ok = false;
float f = s.toFloat(&ok);
if (ok) {
ztest_sample.push_back(f);
} else {
std::cerr << "转换失败: " << s.toStdString() << std::endl;
}
}
double zpred = Classifier->predict(ztest_sample);
//cout << "测试样本预测类别: " << pred << endl;
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"z轴测试样本预测类别: {}",zpred);
int predRes = pickPriorityValue(xpred,ypred,zpred);
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"报警类型: {}",predRes);
StreamInfo info(predRes,sacdata.target,sacdata.startIds,sacdata.endIds); StreamInfo info(predRes,sacdata.target,sacdata.startIds,sacdata.endIds);
SendStreamAlarminfo(info); SendStreamAlarminfo(info);
} }
//step3: 存入数据库 //step3: 存入数据库
QString x;
QString y;
QString z;
fea[0].fill(x);
fea[1].fill(y);
fea[2].fill(z);
std::string data = QString( std::string data = QString(
"INSERT INTO RamdomForestFeatureValue (StartTime, EndTime, target, Label, x, y, z) " "INSERT INTO RamdomForestFeatureValue (StartTime, EndTime, target, Label, x, y, z) "
"VALUES ('%1', '%2', '%3', %4, '%5', '%6', '%7')" "VALUES ('%1', '%2', '%3', %4, '%5', '%6', '%7')"
...@@ -494,67 +425,198 @@ int RamdomForestCalculate::Save_SACDATA(const Dc_SacData &sacdata) ...@@ -494,67 +425,198 @@ int RamdomForestCalculate::Save_SACDATA(const Dc_SacData &sacdata)
int RamdomForestCalculate::judge_alarm() int RamdomForestCalculate::judge_alarm()
{ {
return 0;
} }
void RamdomForestCalculate::DealRecord(const QList<SACRecord> &record, std::vector<std::vector<float> > &result1, std::vector<int> &result2) const size_t max_normal_per_target = 200;
bool RamdomForestCalculate::queryRecordsByLabel(const QString &dbFileName,
std::map<std::string, std::pair<std::vector<SACRecord>, std::vector<SACRecord>>>& data)
{ {
std::vector<float> tmpresult; // ==== 1. 建立唯一连接名 ====
QString conn = QString("rf_query_%1_%2")
foreach (auto tmp, record) { .arg((qulonglong)QThread::currentThreadId())
//1.解析振动量 .arg(QDateTime::currentMSecsSinceEpoch());
QStringList list = tmp.x.split(',');
tmpresult.clear(); {
for (const QString& item : list) { QSqlDatabase db = QSqlDatabase::addDatabase("QSQLITE", conn);
bool ok; db.setDatabaseName(dbFileName);
float value = item.toFloat(&ok);
if (ok) { if (!db.open()) {
tmpresult.push_back(value); qDebug() << "Database open failed:" << db.lastError().text();
} QSqlDatabase::removeDatabase(conn);
return false;
} }
result1.push_back(tmpresult);
result2.push_back(tmp.label); // PRAGMA(可选但推荐)
QSqlQuery prag(db);
list.clear(); prag.exec("PRAGMA journal_mode=WAL;");
tmpresult.clear(); prag.exec("PRAGMA synchronous=NORMAL;");
list = tmp.y.split(','); prag.exec("PRAGMA busy_timeout=3000;");
for (const QString& item : list) {
bool ok; // ==== 2. 执行查询 ====
float value = item.toFloat(&ok); QString sql = QString("SELECT * FROM RamdomForestFeatureValue");
if (ok) {
tmpresult.push_back(value); QSqlQuery query(db);
} if (!query.exec(sql)) {
qDebug() << "Query failed:" << query.lastError().text();
db.close();
QSqlDatabase::removeDatabase(conn);
return false;
} }
result1.push_back(tmpresult);
result2.push_back(tmp.label); // ==== 3. 遍历结果 ====
while (query.next()) {
list.clear(); SACRecord record;
tmpresult.clear(); record.startTime = query.value("StartTime").toString();
list = tmp.z.split(','); record.endTime = query.value("EndTime").toString();
for (const QString& item : list) { record.target = query.value("target").toString();
bool ok; record.label = query.value("Label").toInt();
float value = item.toFloat(&ok); record.x = query.value("x").toString();
if (ok) { record.y = query.value("y").toString();
tmpresult.push_back(value); record.z = query.value("z").toString();
if (record.target.isEmpty() || record.startTime.isEmpty() || record.startTime.isEmpty() ||
record.x.isEmpty() || record.y.isEmpty() || record.z.isEmpty())
continue;
std::pair<std::vector<SACRecord>, std::vector<SACRecord>>& target_pair = data[record.target.toStdString()];
if (record.label != NORMAL) { // 异常
target_pair.first.push_back(record);
} else { // 正常
if (target_pair.second.size() < static_cast<size_t>(max_normal_per_target)) {
target_pair.second.push_back(record);
} else {
// 随机替换
int replace_index = std::rand() % max_normal_per_target;
if (replace_index < 0)
replace_index = -replace_index;
target_pair.second[replace_index] = record;
}
} }
} }
result1.push_back(tmpresult);
//2.解析特征值
result2.push_back(tmp.label);
db.close();
}
// ==== 4. 删除连接(必须先关闭再删除) ====
QSqlDatabase::removeDatabase(conn);
return true;
}
std::vector<float> RamdomForestCalculate::QstrToFeature(const QString& qstr)
{
std::vector<float> v;
std::string token;
std::stringstream ss;
ss.str(qstr.toStdString());
while (std::getline(ss, token, ',')) {
if (token.empty() || token == "NA") continue;
try {
float f = std::stof(token);
v.push_back(f);
} catch(const std::exception& e) {
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "ERROR: standard exception: {}", e.what());
}
} }
return;
if (v.size() != SAC_FEATURE_DIM){
return std::vector<float>();
}
bool isAllZero = true;
for (auto e:v)
if (e!=0.0f){
isAllZero = false;
break;
}
if (isAllZero)
return std::vector<float>();
return v;
}
/**
 * Convert the per-target database records into per-target training sets.
 *
 * @param data   Records keyed by target; each value holds two record lists
 *               that are processed in order (first, then second).
 * @param result Cleared on entry. On return it holds one SqlFeaRecord for
 *               every target that produced at least one usable sample.
 *
 * Each record contributes up to three samples, one per axis string (x, y,
 * z), all tagged with the record's label. An axis whose string
 * QstrToFeature() rejects (empty return) contributes nothing.
 */
void RamdomForestCalculate::DealRecord(const std::map<std::string, std::pair<std::vector<SACRecord>,
                     std::vector<SACRecord>>>& data, std::map<std::string, SqlFeaRecord>& result)
{
    result.clear();

    for (const auto& entry : data)
    {
        const std::string& target = entry.first;
        const std::vector<SACRecord>* groups[2] = {&entry.second.first, &entry.second.second};

        // Collected locally so a target with no usable data never gets an entry.
        std::vector<std::vector<float>> samples;
        std::vector<int> labels;

        for (const std::vector<SACRecord>* group : groups)
        {
            for (const auto& rec : *group)
            {
                std::vector<float> axes[3] = {QstrToFeature(rec.x),
                                              QstrToFeature(rec.y),
                                              QstrToFeature(rec.z)};
                for (auto& feature : axes)
                {
                    if (feature.empty())
                        continue;
                    samples.push_back(std::move(feature));
                    labels.push_back(rec.label);
                }
            }
        }

        if (!samples.empty())
        {
            result[target] = SqlFeaRecord{std::move(samples), std::move(labels)};
        }
    }
}
void RamdomForestCalculate::Retrain() void RamdomForestCalculate::Retrain()
{ {
std::map<std::string, std::pair<std::vector<SACRecord>, std::vector<SACRecord>>> record;
std::map<std::string, SqlFeaRecord> result;
QString dbPath = QString("/home/pi/SigerTMS/stream/SAC/%1.db").arg("RamdomForestFeatureValue"); QString dbPath = QString("/home/pi/SigerTMS/stream/SAC/%1.db").arg("RamdomForestFeatureValue");
QList<SACRecord> record = queryRecordsByLabel(dbPath,0); struct timeval start,end;
gettimeofday(&start, NULL);
RamdomForestCalculate::queryRecordsByLabel(dbPath, record);
DealRecord(record, result);
m_rf_map.clear();
std::vector<std::vector<float>> FearValues; for (const auto& item : result)
std::vector<int> TypeValues; {
DealRecord(record,FearValues,TypeValues); const std::string& target = item.first;
Classifier->train(FearValues, TypeValues, 0.2f, 100); const SqlFeaRecord& data = item.second;
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"Begin Train RamdomForestModel"); if (target.empty() || data.sample.empty() || data.sample.size()!=data.label.size())
continue;
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"train start: target={} sample={}", target.c_str(), data.sample.size());
auto it = m_rf_map.find(target);
if (it != m_rf_map.end()){
it->second.train(data.sample, data.label);
}else{
m_rf_map[target] = RandomForestClassifier(target, 10, 4);
m_rf_map[target].train(data.sample, data.label);
}
}
gettimeofday(&end, NULL);
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "train done, use time: {} ms", (end.tv_sec - start.tv_sec) * 1000L + (end.tv_usec - start.tv_usec)/1000);
} }
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
#include <QSqlError> #include <QSqlError>
#include <QTimer> #include <QTimer>
// Number of values a parsed feature vector must contain: QstrToFeature()
// returns an empty vector for any other length. The #ifndef guard lets a
// definition supplied earlier take precedence over this default.
#ifndef SAC_FEATURE_DIM
#define SAC_FEATURE_DIM 10
#endif
// 如果 Dc_SacData 是结构体/类,需要注册元类型 // 如果 Dc_SacData 是结构体/类,需要注册元类型
Q_DECLARE_METATYPE(Dc_SacData) Q_DECLARE_METATYPE(Dc_SacData)
...@@ -48,7 +51,10 @@ struct StreamInfo ...@@ -48,7 +51,10 @@ struct StreamInfo
} }
}; };
// Training set for one target: feature vectors plus their labels.
// sample[i] belongs with label[i]; DealRecord() appends to both together and
// Retrain() skips a record whose two sizes differ.
// Member order matters: DealRecord() brace-initializes this as {samples, labels}.
struct SqlFeaRecord{
std::vector<std::vector<float> > sample; // one feature vector per entry
std::vector<int> label; // label of the same-index entry in `sample`
};
class RamdomForestCalculate : public QObject class RamdomForestCalculate : public QObject
{ {
...@@ -56,6 +62,7 @@ class RamdomForestCalculate : public QObject ...@@ -56,6 +62,7 @@ class RamdomForestCalculate : public QObject
public: public:
explicit RamdomForestCalculate(QObject *parent = nullptr); explicit RamdomForestCalculate(QObject *parent = nullptr);
std::vector<float> QstrToFeature(const QString& qstr);
//1.查询功率样本(废弃函数) //1.查询功率样本(废弃函数)
QString QueryPowerData(const StreamInfo& alarmInfo); QString QueryPowerData(const StreamInfo& alarmInfo);
...@@ -64,7 +71,8 @@ public: ...@@ -64,7 +71,8 @@ public:
void SendStreamAlarminfo(const StreamInfo& alarmInfo); void SendStreamAlarminfo(const StreamInfo& alarmInfo);
//3.查询函数 //3.查询函数
QList<SACRecord> queryRecordsByLabel(const QString& dbFileName, int targetLabel); bool queryRecordsByLabel(const QString &dbFileName,
std::map<std::string, std::pair<std::vector<SACRecord>, std::vector<SACRecord>>>& data);
//4.接受DC消息,保存振动数据 //4.接受DC消息,保存振动数据
int Save_SACDATA(const Dc_SacData& sacdata); int Save_SACDATA(const Dc_SacData& sacdata);
...@@ -73,7 +81,7 @@ public: ...@@ -73,7 +81,7 @@ public:
int judge_alarm(); int judge_alarm();
//6.解析查询到的振动值 //6.解析查询到的振动值
void DealRecord(const QList<SACRecord> &record,std::vector<std::vector<float>>& result1,std::vector<int> &result2); void DealRecord(const std::map<std::string, std::pair<std::vector<SACRecord>, std::vector<SACRecord>>>& data, std::map<std::string, SqlFeaRecord>& result);
//7.重新训练 //7.重新训练
void Retrain(); void Retrain();
...@@ -86,9 +94,7 @@ private: ...@@ -86,9 +94,7 @@ private:
//QString m_name; //QString m_name;
//DBFunc *dbFunc; //DBFunc *dbFunc;
TcpServerManager *m_tcpServerManager; TcpServerManager *m_tcpServerManager;
std::map<std::string, RandomForestClassifier> m_rf_map;
RandomForestClassifier* Classifier;
bool IfHaveModel = false;
}; };
......
...@@ -4,6 +4,16 @@ ...@@ -4,6 +4,16 @@
#include "common.h" #include "common.h"
#include "spdlog/spdlog.h" #include "spdlog/spdlog.h"
/**
 * Create an untrained classifier bound to one target.
 *
 * @param target      Key identifying the model; train() and predict() use it
 *                    to build the "<target>_rf_model.dat" file name and fall
 *                    back to "rf_model.dat" when it is empty.
 * @param initial_dim Number of features expected in every sample.
 * @param label_num   Stored as label_dim.
 *
 * The target is stored into m_target. The previous body assigned in the
 * opposite direction, which left m_target empty so that every instance
 * shared the same model file.
 */
RandomForestClassifier::RandomForestClassifier(std::string target, int initial_dim, int label_num)
    : m_target(target),
      feature_dim(initial_dim),
      label_dim(label_num)
{
}
/// Returns the target key held by this classifier (a reference to m_target).
const std::string& RandomForestClassifier::getTarget() const
{
    return m_target;
}
/// Replaces the target key held by this classifier with a copy of `target`.
void RandomForestClassifier::set_target(const std::string& target)
{
    m_target = target;
}
std::vector<sample_type> RandomForestClassifier::convert_to_dlib_format(const std::vector<std::vector<float>>& X) std::vector<sample_type> RandomForestClassifier::convert_to_dlib_format(const std::vector<std::vector<float>>& X)
{ {
std::vector<sample_type> result; std::vector<sample_type> result;
...@@ -136,12 +146,12 @@ bool RandomForestClassifier::train(const std::vector<std::vector<float>>& X, con ...@@ -136,12 +146,12 @@ bool RandomForestClassifier::train(const std::vector<std::vector<float>>& X, con
} }
if (X.size() != y.size()) { if (X.size() != y.size()) {
DEBUG_PRINT("invalid param: size_X=%ld size_y=%ld\n", X.size(), y.size()); DEBUG_PRINT("invalid param: size_X=%lu size_y=%lu\n", X.size(), y.size());
return false; return false;
} }
if (X[0].size() != static_cast<size_t>(feature_dim)){ if (X[0].size() != static_cast<size_t>(feature_dim)){
DEBUG_PRINT("invalid param: X[0].size()=%ld feature_dim=%d\n", X[0].size(), feature_dim); DEBUG_PRINT("invalid param: X[0].size()=%lu feature_dim=%d\n", X[0].size(), feature_dim);
return false; return false;
} }
...@@ -190,9 +200,8 @@ bool RandomForestClassifier::train(const std::vector<std::vector<float>>& X, con ...@@ -190,9 +200,8 @@ bool RandomForestClassifier::train(const std::vector<std::vector<float>>& X, con
try try
{ {
serialize("rf_model.dat") << forest; std::string filename = (!m_target.empty()) ? (m_target + "_rf_model.dat") : "rf_model.dat";
serialize("scaler_mean.dat") << scaler_mean; serialize(filename) << scaler_mean << scaler_std << forest;
serialize("scaler_std.dat") << scaler_std;
DEBUG_PRINT("save random-forest model success\n"); DEBUG_PRINT("save random-forest model success\n");
} }
catch(std::exception &e) { catch(std::exception &e) {
...@@ -294,25 +303,18 @@ double RandomForestClassifier::predict(const std::vector<float>& sample) ...@@ -294,25 +303,18 @@ double RandomForestClassifier::predict(const std::vector<float>& sample)
{ {
if (sample.size() != static_cast<size_t>(feature_dim)) if (sample.size() != static_cast<size_t>(feature_dim))
{ {
printf("invalid param: sample_size=%ld feature_dim=%d\n", sample.size(), feature_dim); SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "[predict] invalid param: sample_size={} feature_dim={}", sample.size(), feature_dim);
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"invalid param: sample_size-------sample.size:{}--------feature_dim{}",sample.size(),feature_dim);
return -1; return -1;
} }
if (forest.get_num_trees() == 0) { if (forest.get_num_trees() == 0) {
try { try {
deserialize("rf_model.dat") >> forest; std::string filename = (!m_target.empty()) ? (m_target + "_rf_model.dat") : "rf_model.dat";
deserialize("scaler_mean.dat") >> scaler_mean; deserialize(filename) >> scaler_mean >> scaler_std >> forest;
deserialize("scaler_std.dat") >> scaler_std;
feature_dim = scaler_mean.size(); feature_dim = scaler_mean.size();
} }
catch (const std::exception& e) { catch (const std::exception& e) {
DEBUG_PRINT("load model failure: %s\n", e.what()); SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "load model failure: {}\n", e.what());
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"load model failure:--------------");
return -1; return -1;
} }
} }
...@@ -325,13 +327,3 @@ double RandomForestClassifier::predict(const std::vector<float>& sample) ...@@ -325,13 +327,3 @@ double RandomForestClassifier::predict(const std::vector<float>& sample)
double prediction = forest(scaled); double prediction = forest(scaled);
return prediction; return prediction;
} }
/**
 * Run predict() on every sample, preserving order.
 *
 * @param samples Feature vectors to classify.
 * @return One prediction per input sample. An entry is -1 when predict()
 *         rejected that sample (wrong feature count) or could not load a
 *         model.
 */
std::vector<double> RandomForestClassifier::predict_batch(const std::vector<std::vector<float>>& samples)
{
    std::vector<double> result;
    // Output size is known up front; avoid regrowing the vector in the loop.
    result.reserve(samples.size());
    for (const auto& sample : samples) {
        result.push_back(predict(sample));
    }
    return result;
}
...@@ -20,6 +20,7 @@ private: ...@@ -20,6 +20,7 @@ private:
random_forest_regression_function<> forest; random_forest_regression_function<> forest;
std::vector<double> scaler_mean; std::vector<double> scaler_mean;
std::vector<double> scaler_std; std::vector<double> scaler_std;
std::string m_target;
int feature_dim; int feature_dim;
int label_dim; int label_dim;
...@@ -39,13 +40,13 @@ private: ...@@ -39,13 +40,13 @@ private:
public: public:
RandomForestClassifier() : feature_dim(10), label_dim(4) {} RandomForestClassifier() : feature_dim(10), label_dim(4) {}
RandomForestClassifier(int initial_dim, int label_num) : feature_dim(initial_dim), label_dim(label_num) {} RandomForestClassifier(std::string target, int initial_dim=10, int label_num=4);
bool train(const std::vector<std::vector<float>>& X, const std::vector<int>& y, bool train(const std::vector<std::vector<float>>& X, const std::vector<int>& y,
float test_size = 0.0f, int n_estimators = 100); float test_size = 0.0f, int n_estimators = 100);
double predict(const std::vector<float>& sample); double predict(const std::vector<float>& sample);
std::vector<double> predict_batch(const std::vector<std::vector<float>>& samples); const std::string& getTarget() const;
void set_target(const std::string& target);
}; };
#endif // RANDOM_FOREST_CLASSIFIER_H #endif // RANDOM_FOREST_CLASSIFIER_H
...@@ -9,7 +9,8 @@ enum { ...@@ -9,7 +9,8 @@ enum {
}; };
//#define DEBUG_PRINT(fmt, ...) debug_printf(" [%s] " fmt, __func__, ##__VA_ARGS__) //#define DEBUG_PRINT(fmt, ...) debug_printf(" [%s] " fmt, __func__, ##__VA_ARGS__)
#define DEBUG_PRINT(fmt, ...) printf("[%s] " fmt, __func__, ##__VA_ARGS__) //#define DEBUG_PRINT(fmt, ...) printf("[%s] " fmt, __func__, ##__VA_ARGS__)
// Debug printing is disabled: every DEBUG_PRINT(...) call expands to a no-op
// statement and its arguments are not evaluated. An empty object-like macro
// would instead leave the parenthesised argument list behind as a comma
// expression at each call site.
#define DEBUG_PRINT(...) ((void)0)
void debug_printf(const char *fmt, ...); void debug_printf(const char *fmt, ...);
void measure_time(struct timeval *start); void measure_time(struct timeval *start);
bool isFeatureValid(const std::vector<float>& fea); bool isFeatureValid(const std::vector<float>& fea);
......
...@@ -47,7 +47,6 @@ int main(int argc, char *argv[]) ...@@ -47,7 +47,6 @@ int main(int argc, char *argv[])
MonitorLogfile m_Logfile; MonitorLogfile m_Logfile;
RamdomForestCalculate RandomCaculate; RamdomForestCalculate RandomCaculate;
//qRegisterMetaType<Dc_SacData>("Dc_SacData"); //qRegisterMetaType<Dc_SacData>("Dc_SacData");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment