Commit e4915983 by yahan.li

随机森林代码修改:按照target训练和预测,target包含刀号,进给和转速倍率

	修改:     RadomForestSAC/RamdomForestCalculate.cpp
	修改:     RadomForestSAC/RamdomForestCalculate.h
	修改:     RadomForestSAC/RandomForestClassifier.cpp
	修改:     RadomForestSAC/RandomForestClassifier.h
	修改:     RadomForestSAC/common.h
	修改:     main.cpp
	修改:     package/E3/SigerCalculation
parent 259c8b9d
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
#include <QSqlError> #include <QSqlError>
#include <QTimer> #include <QTimer>
#ifndef SAC_FEATURE_DIM
#define SAC_FEATURE_DIM 10
#endif
// 如果 Dc_SacData 是结构体/类,需要注册元类型 // 如果 Dc_SacData 是结构体/类,需要注册元类型
Q_DECLARE_METATYPE(Dc_SacData) Q_DECLARE_METATYPE(Dc_SacData)
...@@ -48,7 +51,10 @@ struct StreamInfo ...@@ -48,7 +51,10 @@ struct StreamInfo
} }
}; };
struct SqlFeaRecord{
std::vector<std::vector<float> > sample;
std::vector<int> label;
};
class RamdomForestCalculate : public QObject class RamdomForestCalculate : public QObject
{ {
...@@ -56,6 +62,7 @@ class RamdomForestCalculate : public QObject ...@@ -56,6 +62,7 @@ class RamdomForestCalculate : public QObject
public: public:
explicit RamdomForestCalculate(QObject *parent = nullptr); explicit RamdomForestCalculate(QObject *parent = nullptr);
std::vector<float> QstrToFeature(const QString& qstr);
//1.查询功率样本(废弃函数) //1.查询功率样本(废弃函数)
QString QueryPowerData(const StreamInfo& alarmInfo); QString QueryPowerData(const StreamInfo& alarmInfo);
...@@ -64,7 +71,8 @@ public: ...@@ -64,7 +71,8 @@ public:
void SendStreamAlarminfo(const StreamInfo& alarmInfo); void SendStreamAlarminfo(const StreamInfo& alarmInfo);
//3.查询函数 //3.查询函数
QList<SACRecord> queryRecordsByLabel(const QString& dbFileName, int targetLabel); bool queryRecordsByLabel(const QString &dbFileName,
std::map<std::string, std::pair<std::vector<SACRecord>, std::vector<SACRecord>>>& data);
//4.接受DC消息,保存振动数据 //4.接受DC消息,保存振动数据
int Save_SACDATA(const Dc_SacData& sacdata); int Save_SACDATA(const Dc_SacData& sacdata);
...@@ -73,7 +81,7 @@ public: ...@@ -73,7 +81,7 @@ public:
int judge_alarm(); int judge_alarm();
//6.解析查询到的振动值 //6.解析查询到的振动值
void DealRecord(const QList<SACRecord> &record,std::vector<std::vector<float>>& result1,std::vector<int> &result2); void DealRecord(const std::map<std::string, std::pair<std::vector<SACRecord>, std::vector<SACRecord>>>& data, std::map<std::string, SqlFeaRecord>& result);
//7.重新训练 //7.重新训练
void Retrain(); void Retrain();
...@@ -86,9 +94,7 @@ private: ...@@ -86,9 +94,7 @@ private:
//QString m_name; //QString m_name;
//DBFunc *dbFunc; //DBFunc *dbFunc;
TcpServerManager *m_tcpServerManager; TcpServerManager *m_tcpServerManager;
std::map<std::string, RandomForestClassifier> m_rf_map;
RandomForestClassifier* Classifier;
bool IfHaveModel = false;
}; };
......
...@@ -4,6 +4,16 @@ ...@@ -4,6 +4,16 @@
#include "common.h" #include "common.h"
#include "spdlog/spdlog.h" #include "spdlog/spdlog.h"
RandomForestClassifier::RandomForestClassifier(std::string target, int initial_dim, int label_num)
{
target = m_target;
feature_dim = initial_dim;
label_dim = label_num;
}
const std::string& RandomForestClassifier::getTarget() const {return m_target;}
void RandomForestClassifier::set_target(const std::string& target) {m_target=target;}
std::vector<sample_type> RandomForestClassifier::convert_to_dlib_format(const std::vector<std::vector<float>>& X) std::vector<sample_type> RandomForestClassifier::convert_to_dlib_format(const std::vector<std::vector<float>>& X)
{ {
std::vector<sample_type> result; std::vector<sample_type> result;
...@@ -136,12 +146,12 @@ bool RandomForestClassifier::train(const std::vector<std::vector<float>>& X, con ...@@ -136,12 +146,12 @@ bool RandomForestClassifier::train(const std::vector<std::vector<float>>& X, con
} }
if (X.size() != y.size()) { if (X.size() != y.size()) {
DEBUG_PRINT("invalid param: size_X=%ld size_y=%ld\n", X.size(), y.size()); DEBUG_PRINT("invalid param: size_X=%lu size_y=%lu\n", X.size(), y.size());
return false; return false;
} }
if (X[0].size() != static_cast<size_t>(feature_dim)){ if (X[0].size() != static_cast<size_t>(feature_dim)){
DEBUG_PRINT("invalid param: X[0].size()=%ld feature_dim=%d\n", X[0].size(), feature_dim); DEBUG_PRINT("invalid param: X[0].size()=%lu feature_dim=%d\n", X[0].size(), feature_dim);
return false; return false;
} }
...@@ -190,9 +200,8 @@ bool RandomForestClassifier::train(const std::vector<std::vector<float>>& X, con ...@@ -190,9 +200,8 @@ bool RandomForestClassifier::train(const std::vector<std::vector<float>>& X, con
try try
{ {
serialize("rf_model.dat") << forest; std::string filename = (!m_target.empty()) ? (m_target + "_rf_model.dat") : "rf_model.dat";
serialize("scaler_mean.dat") << scaler_mean; serialize(filename) << scaler_mean << scaler_std << forest;
serialize("scaler_std.dat") << scaler_std;
DEBUG_PRINT("save random-forest model success\n"); DEBUG_PRINT("save random-forest model success\n");
} }
catch(std::exception &e) { catch(std::exception &e) {
...@@ -294,25 +303,18 @@ double RandomForestClassifier::predict(const std::vector<float>& sample) ...@@ -294,25 +303,18 @@ double RandomForestClassifier::predict(const std::vector<float>& sample)
{ {
if (sample.size() != static_cast<size_t>(feature_dim)) if (sample.size() != static_cast<size_t>(feature_dim))
{ {
printf("invalid param: sample_size=%ld feature_dim=%d\n", sample.size(), feature_dim); SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "[predict] invalid param: sample_size={} feature_dim={}", sample.size(), feature_dim);
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"invalid param: sample_size-------sample.size:{}--------feature_dim{}",sample.size(),feature_dim);
return -1; return -1;
} }
if (forest.get_num_trees() == 0) { if (forest.get_num_trees() == 0) {
try { try {
deserialize("rf_model.dat") >> forest; std::string filename = (!m_target.empty()) ? (m_target + "_rf_model.dat") : "rf_model.dat";
deserialize("scaler_mean.dat") >> scaler_mean; deserialize(filename) >> scaler_mean >> scaler_std >> forest;
deserialize("scaler_std.dat") >> scaler_std;
feature_dim = scaler_mean.size(); feature_dim = scaler_mean.size();
} }
catch (const std::exception& e) { catch (const std::exception& e) {
DEBUG_PRINT("load model failure: %s\n", e.what()); SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "load model failure: {}\n", e.what());
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"load model failure:--------------");
return -1; return -1;
} }
} }
...@@ -325,13 +327,3 @@ double RandomForestClassifier::predict(const std::vector<float>& sample) ...@@ -325,13 +327,3 @@ double RandomForestClassifier::predict(const std::vector<float>& sample)
double prediction = forest(scaled); double prediction = forest(scaled);
return prediction; return prediction;
} }
std::vector<double> RandomForestClassifier::predict_batch(const std::vector<std::vector<float>>& samples)
{
std::vector<double> result;
for (const auto& s : samples) {
result.push_back(predict(s));
}
return result;
}
...@@ -20,6 +20,7 @@ private: ...@@ -20,6 +20,7 @@ private:
random_forest_regression_function<> forest; random_forest_regression_function<> forest;
std::vector<double> scaler_mean; std::vector<double> scaler_mean;
std::vector<double> scaler_std; std::vector<double> scaler_std;
std::string m_target;
int feature_dim; int feature_dim;
int label_dim; int label_dim;
...@@ -39,13 +40,13 @@ private: ...@@ -39,13 +40,13 @@ private:
public: public:
RandomForestClassifier() : feature_dim(10), label_dim(4) {} RandomForestClassifier() : feature_dim(10), label_dim(4) {}
RandomForestClassifier(int initial_dim, int label_num) : feature_dim(initial_dim), label_dim(label_num) {} RandomForestClassifier(std::string target, int initial_dim=10, int label_num=4);
bool train(const std::vector<std::vector<float>>& X, const std::vector<int>& y, bool train(const std::vector<std::vector<float>>& X, const std::vector<int>& y,
float test_size = 0.0f, int n_estimators = 100); float test_size = 0.0f, int n_estimators = 100);
double predict(const std::vector<float>& sample); double predict(const std::vector<float>& sample);
std::vector<double> predict_batch(const std::vector<std::vector<float>>& samples); const std::string& getTarget() const;
void set_target(const std::string& target);
}; };
#endif // RANDOM_FOREST_CLASSIFIER_H #endif // RANDOM_FOREST_CLASSIFIER_H
...@@ -9,7 +9,8 @@ enum { ...@@ -9,7 +9,8 @@ enum {
}; };
//#define DEBUG_PRINT(fmt, ...) debug_printf(" [%s] " fmt, __func__, ##__VA_ARGS__) //#define DEBUG_PRINT(fmt, ...) debug_printf(" [%s] " fmt, __func__, ##__VA_ARGS__)
#define DEBUG_PRINT(fmt, ...) printf("[%s] " fmt, __func__, ##__VA_ARGS__) //#define DEBUG_PRINT(fmt, ...) printf("[%s] " fmt, __func__, ##__VA_ARGS__)
#define DEBUG_PRINT
void debug_printf(const char *fmt, ...); void debug_printf(const char *fmt, ...);
void measure_time(struct timeval *start); void measure_time(struct timeval *start);
bool isFeatureValid(const std::vector<float>& fea); bool isFeatureValid(const std::vector<float>& fea);
......
...@@ -47,7 +47,6 @@ int main(int argc, char *argv[]) ...@@ -47,7 +47,6 @@ int main(int argc, char *argv[])
MonitorLogfile m_Logfile; MonitorLogfile m_Logfile;
RamdomForestCalculate RandomCaculate; RamdomForestCalculate RandomCaculate;
//qRegisterMetaType<Dc_SacData>("Dc_SacData"); //qRegisterMetaType<Dc_SacData>("Dc_SacData");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment