Commit e4915983 by yahan.li

随机森林代码修改:按照target训练和预测,target包含刀号,进给和转速倍率

	修改:     RadomForestSAC/RamdomForestCalculate.cpp
	修改:     RadomForestSAC/RamdomForestCalculate.h
	修改:     RadomForestSAC/RandomForestClassifier.cpp
	修改:     RadomForestSAC/RandomForestClassifier.h
	修改:     RadomForestSAC/common.h
	修改:     main.cpp
	修改:     package/E3/SigerCalculation
parent 259c8b9d
......@@ -17,6 +17,9 @@
#include <QSqlError>
#include <QTimer>
#ifndef SAC_FEATURE_DIM
#define SAC_FEATURE_DIM 10
#endif
// 如果 Dc_SacData 是结构体/类,需要注册元类型
Q_DECLARE_METATYPE(Dc_SacData)
......@@ -48,7 +51,10 @@ struct StreamInfo
}
};
struct SqlFeaRecord{
std::vector<std::vector<float> > sample;
std::vector<int> label;
};
class RamdomForestCalculate : public QObject
{
......@@ -56,6 +62,7 @@ class RamdomForestCalculate : public QObject
public:
explicit RamdomForestCalculate(QObject *parent = nullptr);
std::vector<float> QstrToFeature(const QString& qstr);
//1.查询功率样本(废弃函数)
QString QueryPowerData(const StreamInfo& alarmInfo);
......@@ -64,7 +71,8 @@ public:
void SendStreamAlarminfo(const StreamInfo& alarmInfo);
//3.查询函数
QList<SACRecord> queryRecordsByLabel(const QString& dbFileName, int targetLabel);
bool queryRecordsByLabel(const QString &dbFileName,
std::map<std::string, std::pair<std::vector<SACRecord>, std::vector<SACRecord>>>& data);
//4.接受DC消息,保存振动数据
int Save_SACDATA(const Dc_SacData& sacdata);
......@@ -73,7 +81,7 @@ public:
int judge_alarm();
//6.解析查询到的振动值
void DealRecord(const QList<SACRecord> &record,std::vector<std::vector<float>>& result1,std::vector<int> &result2);
void DealRecord(const std::map<std::string, std::pair<std::vector<SACRecord>, std::vector<SACRecord>>>& data, std::map<std::string, SqlFeaRecord>& result);
//7.重新训练
void Retrain();
......@@ -86,9 +94,7 @@ private:
//QString m_name;
//DBFunc *dbFunc;
TcpServerManager *m_tcpServerManager;
RandomForestClassifier* Classifier;
bool IfHaveModel = false;
std::map<std::string, RandomForestClassifier> m_rf_map;
};
......
......@@ -4,6 +4,16 @@
#include "common.h"
#include "spdlog/spdlog.h"
RandomForestClassifier::RandomForestClassifier(std::string target, int initial_dim, int label_num)
{
target = m_target;
feature_dim = initial_dim;
label_dim = label_num;
}
const std::string& RandomForestClassifier::getTarget() const {return m_target;}
void RandomForestClassifier::set_target(const std::string& target) {m_target=target;}
std::vector<sample_type> RandomForestClassifier::convert_to_dlib_format(const std::vector<std::vector<float>>& X)
{
std::vector<sample_type> result;
......@@ -136,12 +146,12 @@ bool RandomForestClassifier::train(const std::vector<std::vector<float>>& X, con
}
if (X.size() != y.size()) {
DEBUG_PRINT("invalid param: size_X=%ld size_y=%ld\n", X.size(), y.size());
DEBUG_PRINT("invalid param: size_X=%lu size_y=%lu\n", X.size(), y.size());
return false;
}
if (X[0].size() != static_cast<size_t>(feature_dim)){
DEBUG_PRINT("invalid param: X[0].size()=%ld feature_dim=%d\n", X[0].size(), feature_dim);
DEBUG_PRINT("invalid param: X[0].size()=%lu feature_dim=%d\n", X[0].size(), feature_dim);
return false;
}
......@@ -190,9 +200,8 @@ bool RandomForestClassifier::train(const std::vector<std::vector<float>>& X, con
try
{
serialize("rf_model.dat") << forest;
serialize("scaler_mean.dat") << scaler_mean;
serialize("scaler_std.dat") << scaler_std;
std::string filename = (!m_target.empty()) ? (m_target + "_rf_model.dat") : "rf_model.dat";
serialize(filename) << scaler_mean << scaler_std << forest;
DEBUG_PRINT("save random-forest model success\n");
}
catch(std::exception &e) {
......@@ -294,25 +303,18 @@ double RandomForestClassifier::predict(const std::vector<float>& sample)
{
if (sample.size() != static_cast<size_t>(feature_dim))
{
printf("invalid param: sample_size=%ld feature_dim=%d\n", sample.size(), feature_dim);
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"invalid param: sample_size-------sample.size:{}--------feature_dim{}",sample.size(),feature_dim);
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "[predict] invalid param: sample_size={} feature_dim={}", sample.size(), feature_dim);
return -1;
}
if (forest.get_num_trees() == 0) {
try {
deserialize("rf_model.dat") >> forest;
deserialize("scaler_mean.dat") >> scaler_mean;
deserialize("scaler_std.dat") >> scaler_std;
std::string filename = (!m_target.empty()) ? (m_target + "_rf_model.dat") : "rf_model.dat";
deserialize(filename) >> scaler_mean >> scaler_std >> forest;
feature_dim = scaler_mean.size();
}
catch (const std::exception& e) {
DEBUG_PRINT("load model failure: %s\n", e.what());
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"load model failure:--------------");
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "load model failure: {}\n", e.what());
return -1;
}
}
......@@ -325,13 +327,3 @@ double RandomForestClassifier::predict(const std::vector<float>& sample)
double prediction = forest(scaled);
return prediction;
}
std::vector<double> RandomForestClassifier::predict_batch(const std::vector<std::vector<float>>& samples)
{
std::vector<double> result;
for (const auto& s : samples) {
result.push_back(predict(s));
}
return result;
}
......@@ -20,6 +20,7 @@ private:
random_forest_regression_function<> forest;
std::vector<double> scaler_mean;
std::vector<double> scaler_std;
std::string m_target;
int feature_dim;
int label_dim;
......@@ -39,13 +40,13 @@ private:
public:
RandomForestClassifier() : feature_dim(10), label_dim(4) {}
RandomForestClassifier(int initial_dim, int label_num) : feature_dim(initial_dim), label_dim(label_num) {}
RandomForestClassifier(std::string target, int initial_dim=10, int label_num=4);
bool train(const std::vector<std::vector<float>>& X, const std::vector<int>& y,
float test_size = 0.0f, int n_estimators = 100);
double predict(const std::vector<float>& sample);
std::vector<double> predict_batch(const std::vector<std::vector<float>>& samples);
const std::string& getTarget() const;
void set_target(const std::string& target);
};
#endif // RANDOM_FOREST_CLASSIFIER_H
......@@ -9,7 +9,8 @@ enum {
};
//#define DEBUG_PRINT(fmt, ...) debug_printf(" [%s] " fmt, __func__, ##__VA_ARGS__)
#define DEBUG_PRINT(fmt, ...) printf("[%s] " fmt, __func__, ##__VA_ARGS__)
//#define DEBUG_PRINT(fmt, ...) printf("[%s] " fmt, __func__, ##__VA_ARGS__)
#define DEBUG_PRINT
void debug_printf(const char *fmt, ...);
void measure_time(struct timeval *start);
bool isFeatureValid(const std::vector<float>& fea);
......
......@@ -47,7 +47,6 @@ int main(int argc, char *argv[])
MonitorLogfile m_Logfile;
RamdomForestCalculate RandomCaculate;
//qRegisterMetaType<Dc_SacData>("Dc_SacData");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment