Commit 7a69eae5 by yahan.li

模型的保存,加载及清理

	修改:     RadomForestSAC/FeatureExtractor.cpp
	修改:     RadomForestSAC/RamdomForestCalculate.cpp
	修改:     RadomForestSAC/RandomForestClassifier.cpp
	修改:     package/E3/SigerCalculation
parent ddc583d2
......@@ -5,7 +5,6 @@
#include <cmath>
#include <iostream>
#include <complex>
#include <sys/time.h>
#include "spdlog/spdlog.h"
#include "common.h"
......@@ -51,9 +50,6 @@ SignalFeatures FeatureExtractor::extract_features(const std::vector<float>& sign
return features;
}
unsigned long elapsed_time = 0;
struct timeval start,end;
gettimeofday(&start, NULL);
std::vector<float> processed_signal = signal;
// 1. 高通滤波处理 - 使用预定义系数
......@@ -88,9 +84,6 @@ SignalFeatures FeatureExtractor::extract_features(const std::vector<float>& sign
}
}
gettimeofday(&end, NULL);
elapsed_time = (end.tv_sec - start.tv_sec) * 1000L + (end.tv_usec - start.tv_usec)/1000;
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"[extract_features] elapsed_time: {}", elapsed_time);
return features;
}
......
......@@ -36,11 +36,11 @@ void delete_excess_rf_models()
++done;
}
}
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"[delete_rf_models] to_delete={} done={}", to_delete, done);
}
closedir(dir);
if (to_delete > 0)
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"[delete_rf_models] to_delete={} done={}", to_delete, done);
}
int pickPriorityValue(double xpred, double ypred, double zpred)
......@@ -245,6 +245,9 @@ int RamdomForestCalculate::Save_SACDATA(const Dc_SacData &sacdata)
QString FilePath = QString("/home/pi/SigerTMS/stream/SAC/%1").arg(sacdata.sacFile);
//step1: 读取二进制原始文件
unsigned long elapsed_time = 0;
struct timeval start,end;
gettimeofday(&start, NULL);
VibrationDataReader reader;
if (reader.readFileRaw(FilePath)) {
//reader.printDataInfo();
......@@ -255,6 +258,9 @@ int RamdomForestCalculate::Save_SACDATA(const Dc_SacData &sacdata)
raw = reader.convertToRawFormat();
for (int i=0; i<3; ++i)
fea[i] = extractor.extract_features(raw[i]);
gettimeofday(&end, NULL);
elapsed_time = (end.tv_sec - start.tv_sec) * 1000L + (end.tv_usec - start.tv_usec)/1000;
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"),"[extract_features] elapsed_time: {}", elapsed_time);
//step2: 预测报警
int valid_num = 0;
......@@ -355,7 +361,8 @@ bool RamdomForestCalculate::queryRecordsByLabel(const QString &dbFileName,
QSqlQuery query(db);
if (!query.exec(sql)) {
qDebug() << "Query failed:" << query.lastError().text();
QString error = query.lastError().text();
SPDLOG_LOGGER_ERROR(spdlog::get("logger"), "Query failed: {}", error.toUtf8().constData());
db.close();
QSqlDatabase::removeDatabase(conn);
return false;
......@@ -372,7 +379,7 @@ bool RamdomForestCalculate::queryRecordsByLabel(const QString &dbFileName,
record.y = query.value("y").toString();
record.z = query.value("z").toString();
if (record.target.isEmpty() || record.startTime.isEmpty() || record.startTime.isEmpty() ||
if (record.target.isEmpty() || record.startTime.isEmpty() || record.endTime.isEmpty() ||
record.x.isEmpty() || record.y.isEmpty() || record.z.isEmpty())
continue;
......@@ -388,6 +395,8 @@ bool RamdomForestCalculate::queryRecordsByLabel(const QString &dbFileName,
int replace_index = std::rand() % max_normal_per_target;
target_pair.second[replace_index] = record;
}
//SPDLOG_LOGGER_ERROR(spdlog::get("logger"), "target={} size={}", record.target.toStdString().c_str(), target_pair.second.size());
}
}
......
......@@ -6,7 +6,7 @@
RandomForestClassifier::RandomForestClassifier(std::string target, int initial_dim, int label_num)
{
target = m_target;
m_target = target;
feature_dim = initial_dim;
label_dim = label_num;
}
......@@ -202,10 +202,10 @@ bool RandomForestClassifier::train(const std::vector<std::vector<float>>& X, con
{
std::string filename = (!m_target.empty()) ? (m_target + "_rf_model.dat") : "rf_model.dat";
serialize(filename) << scaler_mean << scaler_std << forest;
DEBUG_PRINT("save random-forest model success\n");
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "save {} success", filename.c_str());
}
catch(std::exception &e) {
DEBUG_PRINT("save random-forest model failed: %s\n", e.what());
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "load model failure: {}", e.what());
return false;
}
......@@ -279,7 +279,7 @@ void RandomForestClassifier::evaluate(const std::vector<sample_type>& X_test, co
for (int true_val = 0; true_val < label_dim; ++true_val) {
if (ret < maxlen)
ret += snprintf(&ptr[ret], maxlen-ret, "真实 %d: ", true_val);
for (int pred_val = 0; pred_val < label_dim; ++pred_val) {
for (int pred_val = 0; pred_val < label_dim; ++pred_val){
if (ret < maxlen)
ret += snprintf(&ptr[ret], maxlen-ret, "%3d ", conf_matrix[true_val][pred_val]);
}
......@@ -312,9 +312,10 @@ double RandomForestClassifier::predict(const std::vector<float>& sample)
std::string filename = (!m_target.empty()) ? (m_target + "_rf_model.dat") : "rf_model.dat";
deserialize(filename) >> scaler_mean >> scaler_std >> forest;
feature_dim = scaler_mean.size();
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "load {} success", filename.c_str());
}
catch (const std::exception& e) {
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "load model failure: {}\n", e.what());
SPDLOG_LOGGER_DEBUG(spdlog::get("logger"), "load model failure: {}", e.what());
return -1;
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment