#include "FeatureExtractor.h"
#include <fftw3.h>
#include <algorithm>
#include <numeric>
#include <cmath>
#include <chrono>
#include <cstdio>
#include <iostream>
#include <complex>
#include <sys/time.h>
#include "spdlog/spdlog.h"
#include "common.h"

using namespace std;

// Pre-computed filter coefficient table. The values are used verbatim as
// produced by a Python design script (NOTE(review): presumably
// scipy.signal.butter designed for the default sample rate m_fs = 2000 — confirm).
// Each entry is {type, lowcut, highcut, {b, a}}: get_coeffs() matches on
// (type, lowcut, highcut) and the filter wrappers pass `coeffs.first` /
// `coeffs.second` to filtfilt() as the numerator `b` / denominator `a`.
// The filter order is not part of an entry.
// g_filter_coeffs_count must equal the number of entries below.
const int g_filter_coeffs_count = 2;
const struct FilterCoefficients g_filter_coeffs_tbl[]
{
    // High-pass: lowcut holds the cutoff (200), highcut is unused (0).
    // 5 taps, i.e. 4th order.
    {FilterType::HighPass, 200, 0,
    {{0.43284664f, -1.73138658f, 2.59707987f, -1.73138658f, 0.43284664f},
     {1.0f, -2.36951301f, 2.31398841f, -1.05466541f, 0.18737949f}}},

    // Band-pass 500-900: 9 taps, i.e. 8th order overall
    // (presumably a 4th-order band-pass design — confirm against the script).
    {FilterType::BandPass, 500, 900,
    {{0.04658291f, 0.0f, -0.18633163f, 0.0f, 0.27949744f, 0.0f, -0.18633163f, 0.0f, 0.04658291f},
    {1.0f, 3.47439553f, 5.54673549f, 5.71826677f, 4.3888535f, 2.46412347f, 0.92628774f, 0.22025224f, 0.03011888f}}}
};

// Debug helper: prints the first `len` values of `data` with 8 decimals,
// six values per line, followed by a final newline.
// Despite its name, no hexadecimal output is produced.
// `len` is clamped to [0, data.size()], so an oversized length no longer
// reads past the end of the vector.
// NOTE: `data` is taken by value to keep the existing signature; this copies
// the vector on every call.
void hexdump(std::vector<float> data, int len)
{
    const size_t count = (len > 0) ? std::min(static_cast<size_t>(len), data.size()) : 0;
    for (size_t i = 0; i < count; ++i)
    {
        printf("%.8f ", data[i]);
        if ((i + 1) % 6 == 0) printf("\n");
    }
    printf("\n");
}

// Default configuration: sample rate 2000, high-pass cutoff 200,
// band-pass 500-900, order 4. Delegates to the full constructor so the
// member initialisation lives in one place.
FeatureExtractor::FeatureExtractor()
    : FeatureExtractor(2000, 200, 500, 900, 4)
{
}

// Stores the sampling frequency, filter corner frequencies and filter order.
FeatureExtractor::FeatureExtractor(int sample_freq, int cutoff, int lowcut, int highcut, int order)
    : m_fs(sample_freq),
      m_cutoff(cutoff),
      m_lowcut(lowcut),
      m_highcut(highcut),
      m_order(order)
{
}

// Nothing to release explicitly.
FeatureExtractor::~FeatureExtractor() = default;

// Runs the complete feature pipeline on one raw signal:
//   high-pass -> band-pass (both zero-phase, via filtfilt) -> Hilbert envelope
//   -> envelope spectrum -> time-domain and spectral envelope features.
// @param signal  raw input samples.
// @return the extracted features. If `signal` is empty an error is printed and
//         the features are returned as left by features.clear(). Time-domain
//         fields are only written when extract_time_domain_features() accepts
//         the envelope; spectral fields only when the envelope spectrum is
//         non-empty.
SignalFeatures FeatureExtractor::extract_features(const std::vector<float>& signal)
{
    SignalFeatures features;
    features.clear();

    if (signal.empty()) {
        std::cerr << "Error: Input signal is empty!" << '\n';
        return features;
    }

    // Monotonic clock: unlike gettimeofday(), it cannot jump backwards when the
    // wall clock is adjusted (which used to wrap the unsigned elapsed value).
    const auto start = std::chrono::steady_clock::now();

    // 1. High-pass filter (pre-computed coefficients)
    std::vector<float> processed_signal = butterworth_highpass_filter(signal, m_cutoff, m_order);

    // 2. Band-pass filter (pre-computed coefficients)
    processed_signal = butterworth_bandpass_filter(processed_signal, m_lowcut, m_highcut, m_order);

    // 3. Envelope via the Hilbert transform
    const std::vector<float> envelope = hilbert_transform(processed_signal);

    // 4. Envelope spectrum: {frequency axis, single-sided amplitude}
    const std::pair<std::vector<float>, std::vector<float>> envelope_spectrum_result =
        compute_envelope_spectrum(envelope);
    const std::vector<float>& envelope_freq = envelope_spectrum_result.first;
    const std::vector<float>& envelope_spectrum = envelope_spectrum_result.second;

    // 5. Time-domain features. The return value is deliberately ignored: on
    //    failure the time-domain fields simply keep their cleared values.
    extract_time_domain_features(envelope, features);

    // 6. Frequency-domain features. The size check protects the paired
    //    indexing of the two vectors below.
    if (!envelope_spectrum.empty() && envelope_freq.size() == envelope_spectrum.size()) {
        // Accumulate in double to limit rounding error over long spectra.
        double weighted_sum = 0.0;
        double amplitude_sum = 0.0;
        double energy = 0.0;
        for (size_t i = 0; i < envelope_spectrum.size(); ++i) {
            const double s = envelope_spectrum[i];
            weighted_sum += static_cast<double>(envelope_freq[i]) * s;
            amplitude_sum += s;
            energy += s * s;
        }
        features.Envelope_SpectralCentroid =
            (amplitude_sum > 0.0) ? static_cast<float>(weighted_sum / amplitude_sum) : 0.0f;
        features.Envelope_TotalEnergy = static_cast<float>(energy);
    }

    const auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
        std::chrono::steady_clock::now() - start).count();
    (void)elapsed_ms;  // unused when debug logging is compiled out
    // spdlog::get() returns nullptr if no logger named "logger" is registered;
    // without this guard a build with debug logging enabled would dereference it.
    if (auto logger = spdlog::get("logger")) {
        SPDLOG_LOGGER_DEBUG(logger, "[extract_features] elapsed_time: {}", elapsed_ms);
    }
    return features;
}

namespace {

// Direct-form-II-transposed IIR filter applied in place. This is the same
// recurrence scipy.signal.lfilter uses. `b` and `a` must have equal length with
// a[0] == 1. `state` holds the b.size()-1 delay elements and is updated as the
// filter runs.
void filtfilt_lfilter_df2t(const std::vector<double>& b, const std::vector<double>& a,
                           std::vector<double>& state, std::vector<double>& data)
{
    const size_t order = b.size() - 1;
    for (double& sample : data) {
        const double in = sample;
        const double out = b[0] * in + (order > 0 ? state[0] : 0.0);
        for (size_t k = 0; k + 1 < order; ++k) {
            state[k] = b[k + 1] * in - a[k + 1] * out + state[k + 1];
        }
        if (order > 0) {
            state[order - 1] = b[order] * in - a[order] * out;
        }
        sample = out;
    }
}

// Delay-line contents at the steady state of the unit-step response, which is
// what scipy.signal.lfilter_zi returns. With DC gain G = sum(b) / sum(a), the
// DF2T steady state is zi[k] = sum_{j > k} (b[j] - a[j] * G).
// Returns zeros if sum(a) == 0 (pole at z = 1, so there is no finite steady state).
std::vector<double> filtfilt_steady_state_zi(const std::vector<double>& b, const std::vector<double>& a)
{
    const size_t order = b.size() - 1;
    std::vector<double> zi(order, 0.0);
    const double sum_a = std::accumulate(a.begin(), a.end(), 0.0);
    if (order == 0 || sum_a == 0.0) {
        return zi;
    }
    const double gain = std::accumulate(b.begin(), b.end(), 0.0) / sum_a;
    double tail = 0.0;
    for (size_t k = order; k-- > 0;) {
        tail += b[k + 1] - a[k + 1] * gain;
        zi[k] = tail;
    }
    return zi;
}

} // namespace

// Zero-phase IIR filtering that mirrors scipy.signal.filtfilt with its defaults
// (padtype='odd', padlen=3*max(len(a), len(b)), method='pad'):
//   1. odd-extend x by `edge` samples at each end,
//   2. filter forward with the delay line set to lfilter_zi(b, a) * ext[0],
//   3. filter the reversed result with lfilter_zi(b, a) * (its first sample),
//   4. reverse back and drop the padding.
// Starting from the steady state, rather than from zero, suppresses the start-up
// transient that would otherwise leak into both ends of the output.
// The arithmetic is done in double precision and the result is returned as float.
// @param b  numerator coefficients
// @param a  denominator coefficients; every coefficient is normalised by a[0]
// @param x  input signal
// @return filtered signal of the same length as x. x is returned unchanged if
//         any input is empty, x has fewer than 100 samples, |a[0]| < 1e-10,
//         or x is not longer than the padding length.
std::vector<float> FeatureExtractor::filtfilt(const std::vector<float>& b, const std::vector<float>& a, const std::vector<float>& x)
{
    if (x.empty() || b.empty() || a.empty() || x.size() < 100 || std::abs(a[0]) < 1e-10f){
        DEBUG_PRINT("invalid param\n");
        return x;
    }

    const size_t n = x.size();
    const size_t ntaps = std::max(b.size(), a.size());
    const size_t edge = 3 * ntaps;  // scipy default padlen
    // The odd extension reads x[edge] and x[n - 1 - edge]. scipy likewise
    // requires the signal to be longer than padlen.
    if (n <= edge) {
        DEBUG_PRINT("invalid param\n");
        return x;
    }

    // Normalise by a[0] and zero-pad both coefficient vectors to equal length.
    const double a0 = a[0];
    std::vector<double> bn(ntaps, 0.0);
    std::vector<double> an(ntaps, 0.0);
    for (size_t i = 0; i < b.size(); ++i) bn[i] = b[i] / a0;
    for (size_t i = 0; i < a.size(); ++i) an[i] = a[i] / a0;

    // Odd (point-symmetric) extension about the first and last samples.
    std::vector<double> ext(n + 2 * edge);
    for (size_t i = 0; i < edge; ++i) {
        ext[i] = 2.0 * x[0] - x[edge - i];
        ext[edge + n + i] = 2.0 * x[n - 1] - x[n - 2 - i];
    }
    std::copy(x.begin(), x.end(), ext.begin() + edge);

    const std::vector<double> zi = filtfilt_steady_state_zi(bn, an);
    std::vector<double> state(zi.size());

    // Forward pass, starting in steady state for the first extended sample.
    const double x0 = ext.front();
    for (size_t k = 0; k < zi.size(); ++k) state[k] = zi[k] * x0;
    filtfilt_lfilter_df2t(bn, an, state, ext);

    // Backward pass over the time-reversed forward output.
    std::reverse(ext.begin(), ext.end());
    const double y0 = ext.front();  // last sample of the forward output
    for (size_t k = 0; k < zi.size(); ++k) state[k] = zi[k] * y0;
    filtfilt_lfilter_df2t(bn, an, state, ext);
    std::reverse(ext.begin(), ext.end());

    // Drop the padding.
    std::vector<float> result(n);
    for (size_t i = 0; i < n; ++i) {
        result[i] = static_cast<float>(ext[edge + i]);
    }
    return result;
}

// Looks up pre-computed coefficients in g_filter_coeffs_tbl by the key
// (type, lowcut, highcut). `order` is not part of the key because the table
// stores no order; the parameter is kept only for interface compatibility.
// @return {b, a} of the first matching entry, or the pass-through filter
//         {{1}, {1}} when no entry matches.
std::pair<std::vector<float>, std::vector<float>> FeatureExtractor::get_coeffs(FilterType type, int lowcut, int highcut, int order)
{
    (void)order;  // intentionally unused

    const FilterCoefficients* const first = g_filter_coeffs_tbl;
    const FilterCoefficients* const last = g_filter_coeffs_tbl + g_filter_coeffs_count;
    const FilterCoefficients* const match = std::find_if(first, last,
        [&](const FilterCoefficients& entry) {
            return entry.type == type && entry.lowcut == lowcut && entry.highcut == highcut;
        });

    if (match == last) {
        return { {1.0f}, {1.0f} };
    }
    return match->coeffs;
}

std::vector<float> FeatureExtractor::butterworth_highpass_filter(const std::vector<float>& signal, int cutoff, int order)
{
    order = order;
    auto coeffs = get_coeffs(FilterType::HighPass, cutoff, 0, 4);
    return filtfilt(coeffs.first, coeffs.second, signal);
}

std::vector<float> FeatureExtractor::butterworth_bandpass_filter(const std::vector<float>& signal, int lowcut, int highcut, int order)
{
    order = order;
    auto coeffs = get_coeffs(FilterType::BandPass, lowcut, highcut, 4);
    return filtfilt(coeffs.first, coeffs.second, signal);
}

// Extracts the amplitude envelope |x + j*H{x}| through the analytic signal.
// The analytic signal is built in the frequency domain the same way as
// scipy.signal.hilbert:
//   h[0] = 1, h[k] = 2 for 1 <= k < (N+1)/2, h[N/2] = 1 when N is even,
//   h[k] = 0 for the remaining (negative-frequency) bins.
// Doubling the positive bins makes the envelope of A*cos(wt) equal A; leaving
// them at 1 would halve the envelope of any zero-mean signal.
// @return envelope of length N, or an empty vector if the input is empty or
//         an FFT step fails (compute_fft/compute_ifft then return empty vectors).
std::vector<float> FeatureExtractor::hilbert_transform(const std::vector<float>& signal) {
    const size_t N = signal.size();
    if (N == 0) return {};

    std::vector<std::complex<float>> spectrum = compute_fft(signal);
    if (spectrum.size() != N) return {};  // FFT failed

    // Strictly positive frequencies: bins 1 .. (N+1)/2 - 1.
    const size_t positive_end = (N + 1) / 2;
    for (size_t k = 1; k < positive_end; ++k) {
        spectrum[k] *= 2.0f;
    }
    // DC (and the Nyquist bin N/2 for even N) stay as-is; negative frequencies are removed.
    const size_t negative_begin = (N % 2 == 0) ? N / 2 + 1 : positive_end;
    for (size_t k = negative_begin; k < N; ++k) {
        spectrum[k] = std::complex<float>(0.0f, 0.0f);
    }

    const std::vector<std::complex<float>> analytic_signal = compute_ifft(spectrum);
    if (analytic_signal.size() != N) return {};  // inverse FFT failed

    std::vector<float> envelope(N);
    for (size_t i = 0; i < N; ++i) {
        envelope[i] = std::abs(analytic_signal[i]);
    }
    return envelope;
}

// 计算FFT
std::vector<std::complex<float>> FeatureExtractor::compute_fft(const std::vector<float>& signal) {
    int N = signal.size();
    if (N == 0) return {};

    fftwf_complex* in = fftwf_alloc_complex(N);
    fftwf_complex* out = fftwf_alloc_complex(N);
    if (in == nullptr || out == nullptr) {
        if (in) fftwf_free(in);
        if (out) fftwf_free(out);
        return {};
    }

    fftwf_plan plan = fftwf_plan_dft_1d(N, in, out, FFTW_FORWARD, FFTW_ESTIMATE);

    for (int i = 0; i < N; ++i) {
        in[i][0] = signal[i];
        in[i][1] = 0.0f;
    }

    fftwf_execute(plan);

    std::vector<std::complex<float>> result(N);
    for (int i = 0; i < N; ++i) {
        result[i] = std::complex<float>(out[i][0], out[i][1]);
    }

    fftwf_destroy_plan(plan);
    fftwf_free(in);
    fftwf_free(out);

    return result;
}

// 计算逆FFT
std::vector<std::complex<float>> FeatureExtractor::compute_ifft(const std::vector<std::complex<float>>& freq_domain) {
    int N = freq_domain.size();
    if (N == 0) return {};

    fftwf_complex* in = fftwf_alloc_complex(N);
    fftwf_complex* out = fftwf_alloc_complex(N);
    if (in == nullptr || out == nullptr) {
        if (in) fftwf_free(in);
        if (out) fftwf_free(out);
        return {};
    }

    fftwf_plan plan = fftwf_plan_dft_1d(N, in, out, FFTW_BACKWARD, FFTW_ESTIMATE);

    for (int i = 0; i < N; ++i) {
        in[i][0] = freq_domain[i].real();
        in[i][1] = freq_domain[i].imag();
    }

    fftwf_execute(plan);

    std::vector<std::complex<float>> result(N);
    for (int i = 0; i < N; ++i) {
        result[i] = std::complex<float>(out[i][0], out[i][1]) / static_cast<float>(N);
    }

    fftwf_destroy_plan(plan);
    fftwf_free(in);
    fftwf_free(out);

    return result;
}

// Computes the single-sided amplitude spectrum of the envelope.
// The mean (DC) is removed before the FFT. Bin i (0 <= i < N/2) lies at
// frequency i * m_fs / N and has amplitude 2/N * |FFT[i]|.
// @return {frequency axis, amplitude}, each of length N/2. Both are empty when
//         the envelope has fewer than 2 samples or the FFT fails.
std::pair<std::vector<float>, std::vector<float>>
FeatureExtractor::compute_envelope_spectrum(const std::vector<float>& envelope)
{
    const size_t N = envelope.size();
    if (N < 2) {
        DEBUG_PRINT("请先提取包络信号");
        return {{}, {}};
    }

    // 1. Remove the DC component. The mean is accumulated in double to limit
    //    rounding error on long signals.
    const float mean_val = static_cast<float>(
        std::accumulate(envelope.begin(), envelope.end(), 0.0) / static_cast<double>(N));
    std::vector<float> centered_envelope(N);
    std::transform(envelope.begin(), envelope.end(), centered_envelope.begin(),
                   [mean_val](float v) { return v - mean_val; });

    // 2. FFT. compute_fft returns an empty vector if FFTW allocation or
    //    planning fails, so guard before indexing.
    const std::vector<std::complex<float>> yf = compute_fft(centered_envelope);
    if (yf.size() != N) {
        return {{}, {}};
    }

    // 3. Frequency axis and single-sided amplitude spectrum.
    const size_t half = N / 2;
    std::vector<float> xf(half);
    std::vector<float> amplitude(half);
    for (size_t i = 0; i < half; ++i) {
        xf[i] = static_cast<float>(i) * m_fs / N;
        amplitude[i] = 2.0f / N * std::abs(yf[i]);
    }

    return {xf, amplitude};
}

// Computes time-domain statistics of the envelope signal into `features`.
// It uses two passes with double-precision accumulation: the mean first, then
// the central moments around it. This avoids the catastrophic cancellation of
// the one-pass formula (sum_squares - sum * mean), which matters whenever the
// mean is large relative to the spread.
//   Envelope_Mean            arithmetic mean
//   Envelope_RMS             sqrt(mean(x^2))
//   Envelope_Peak            max(x) (signed maximum, not max |x|)
//   Envelope_Std             sample standard deviation (N - 1 denominator)
//   Envelope_Kurtosis        excess kurtosis m4 / m2^2 - 3 from population
//                            central moments (the scipy.stats.kurtosis default);
//                            0 when m2 is ~0
//   Envelope_CrestFactor     Peak / RMS
//   Envelope_ImpulseFactor   Peak / mean(|x|)
//   Envelope_ClearanceFactor Peak / mean(sqrt|x|)^2
// Each ratio is set to 0 when its denominator is ~0.
// @return false, leaving `features` untouched, when the signal has fewer than
//         100 samples; true otherwise.
bool FeatureExtractor::extract_time_domain_features(const std::vector<float>& signal, SignalFeatures& features)
{
    const size_t N = signal.size();
    if (N < 100){
        DEBUG_PRINT("invalid input: signal_size=%zu", N);
        return false;
    }

    // Pass 1: raw sums and peak.
    double sum = 0.0;
    double sum_squares = 0.0;
    double sum_abs = 0.0;
    double sum_sqrt_abs = 0.0;
    float peak = signal[0];
    for (const float val : signal) {
        const double v = val;
        const double abs_val = std::abs(v);
        sum += v;
        sum_squares += v * v;
        sum_abs += abs_val;
        sum_sqrt_abs += std::sqrt(abs_val);
        if (val > peak) peak = val;
    }

    const double count = static_cast<double>(N);
    const double mean = sum / count;
    const double mean_squares = sum_squares / count;
    const double mean_abs = sum_abs / count;
    const double mean_sqrt_abs = sum_sqrt_abs / count;

    // Pass 2: central moments around the mean.
    double sum_dev2 = 0.0;
    double sum_dev4 = 0.0;
    for (const float val : signal) {
        const double d = static_cast<double>(val) - mean;
        const double d2 = d * d;
        sum_dev2 += d2;
        sum_dev4 += d2 * d2;
    }
    const double sample_variance = sum_dev2 / (count - 1.0);  // ddof = 1
    const double m2 = sum_dev2 / count;                       // population moments
    const double m4 = sum_dev4 / count;

    features.Envelope_Mean = static_cast<float>(mean);
    features.Envelope_RMS = static_cast<float>(std::sqrt(mean_squares));
    features.Envelope_Peak = peak;
    features.Envelope_Std = static_cast<float>(std::sqrt(sample_variance));
    features.Envelope_Kurtosis = (m2 > 1e-10) ? static_cast<float>(m4 / (m2 * m2) - 3.0) : 0.0f;

    features.Envelope_CrestFactor = (features.Envelope_RMS > 1e-10f) ? peak / features.Envelope_RMS : 0.0f;
    features.Envelope_ImpulseFactor = (mean_abs > 1e-10) ? static_cast<float>(peak / mean_abs) : 0.0f;
    features.Envelope_ClearanceFactor = (mean_sqrt_abs > 1e-10)
        ? static_cast<float>(peak / (mean_sqrt_abs * mean_sqrt_abs))
        : 0.0f;
    return true;
}
