// SPDX-License-Identifier: LGPL-3.0-or-later
// Author: Kristian Lytje

#include <hist/intensity_calculator/CompositeDistanceHistogramFFExplicitBase.h>
#include <hist/Histogram.h>
#include <table/ArrayDebyeTable.h>
#include <form_factor/lookup/FormFactorProduct.h>
#include <form_factor/lookup/ExvFormFactorProduct.h>
#include <settings/HistogramSettings.h>
#include <utility/MultiThreading.h>

using namespace ausaxs;
using namespace ausaxs::hist;

// Special member functions. Every one of them uses the compiler-generated (defaulted) implementation;
// they are defined out-of-line here so that they are emitted together with the explicit instantiation
// at the bottom of this file.

/// Default-construct an empty histogram.
template<typename AA, typename AXFormFactorTableType, typename XX>
CompositeDistanceHistogramFFExplicitBase<AA, AXFormFactorTableType, XX>::CompositeDistanceHistogramFFExplicitBase() = default;

/// Member-wise copy.
template<typename AA, typename AXFormFactorTableType, typename XX>
CompositeDistanceHistogramFFExplicitBase<AA, AXFormFactorTableType, XX>::CompositeDistanceHistogramFFExplicitBase(const CompositeDistanceHistogramFFExplicitBase&) = default;

/// Member-wise move; declared noexcept.
template<typename AA, typename AXFormFactorTableType, typename XX>
CompositeDistanceHistogramFFExplicitBase<AA, AXFormFactorTableType, XX>::CompositeDistanceHistogramFFExplicitBase(CompositeDistanceHistogramFFExplicitBase&&) noexcept = default;

/// Member-wise copy assignment.
template<typename AA, typename AXFormFactorTableType, typename XX>
CompositeDistanceHistogramFFExplicitBase<AA, AXFormFactorTableType, XX>& CompositeDistanceHistogramFFExplicitBase<AA, AXFormFactorTableType, XX>::operator=(const CompositeDistanceHistogramFFExplicitBase&) = default;

/// Member-wise move assignment; declared noexcept.
template<typename AA, typename AXFormFactorTableType, typename XX>
CompositeDistanceHistogramFFExplicitBase<AA, AXFormFactorTableType, XX>& CompositeDistanceHistogramFFExplicitBase<AA, AXFormFactorTableType, XX>::operator=(CompositeDistanceHistogramFFExplicitBase&&) noexcept = default;

/// Defaulted destructor.
template<typename AA, typename AXFormFactorTableType, typename XX>
CompositeDistanceHistogramFFExplicitBase<AA, AXFormFactorTableType, XX>::~CompositeDistanceHistogramFFExplicitBase() = default;

/**
 * @brief Construct from the aa, aw and ww distance distributions plus an unweighted total distribution.
 *
 * All four distributions are moved into the CompositeDistanceHistogramFFAvgBase<AA> base;
 * nothing else is initialized here.
 */
template<typename AA, typename AXFormFactorTableType, typename XX>
CompositeDistanceHistogramFFExplicitBase<AA, AXFormFactorTableType, XX>::CompositeDistanceHistogramFFExplicitBase(
    hist::Distribution3D&& p_aa, hist::Distribution2D&& p_aw, hist::Distribution1D&& p_ww, hist::Distribution1D&& p_tot
) 
    : CompositeDistanceHistogramFFAvgBase<AA>(
        std::move(p_aa),
        std::move(p_aw),
        std::move(p_ww),
        std::move(p_tot)
    )
{}

/**
 * @brief Construct from the aa, aw and ww distance distributions plus a weighted total distribution.
 *
 * Identical to the unweighted overload except for the type of @p p_tot; everything is moved into the
 * CompositeDistanceHistogramFFAvgBase<AA> base.
 */
template<typename AA, typename AXFormFactorTableType, typename XX>
CompositeDistanceHistogramFFExplicitBase<AA, AXFormFactorTableType, XX>::CompositeDistanceHistogramFFExplicitBase(
    hist::Distribution3D&& p_aa, hist::Distribution2D&& p_aw, hist::Distribution1D&& p_ww, hist::WeightedDistribution1D&& p_tot
) 
    : CompositeDistanceHistogramFFAvgBase<AA>(
        std::move(p_aa),
        std::move(p_aw),
        std::move(p_ww),
        std::move(p_tot)
    )
{}

/**
 * @brief Form-factor product table used for the atom-atom (aa) terms.
 *
 * Simply forwards to get_ff_table() (reached through `this`, presumably inherited from
 * CompositeDistanceHistogramFFAvgBase<AA>). The table is returned by value.
 */
template<typename AA, typename AXFormFactorTableType, typename XX>
auto CompositeDistanceHistogramFFExplicitBase<AA, AXFormFactorTableType, XX>::get_ffaa_table() const -> const AA {
    return this->get_ff_table();
}

/**
 * @brief Excluded-volume scaling factor at scattering vector @p q for a given @p cx.
 *
 * Evaluates cx^3 * exp(-rm2 * (cx^2 - 1) * q^2), where rm2 = r^2/4 and r is
 * constants::radius::average_atomic_radius. For cx == 1 the factor is exactly 1 for every q.
 */
template<typename AA, typename AXFormFactorTableType, typename XX>
double CompositeDistanceHistogramFFExplicitBase<AA, AXFormFactorTableType, XX>::exv_factor(double q, double cx) {
    constexpr double rm2 = constants::radius::average_atomic_radius*constants::radius::average_atomic_radius/4;
    const double cx_squared_minus_one = std::pow(cx, 2) - 1;
    const double exponent = -rm2*cx_squared_minus_one*q*q;
    const double cx_cubed = std::pow(cx, 3);
    return cx_cubed*std::exp(exponent);
}

/**
 * @brief Excluded-volume scaling factor at @p q using the current fit parameter free_params.cx.
 */
template<typename AA, typename AXFormFactorTableType, typename XX>
double CompositeDistanceHistogramFFExplicitBase<AA, AXFormFactorTableType, XX>::exv_factor(double q) const {
    const double current_cx = this->free_params.cx;
    return exv_factor(q, current_cx);
}

/**
 * @brief Recompute the sinqd caches (exv_cache.sinqd.aa, .aw and .ww) for the current q-range.
 *
 * For every q bin in [settings::axes::qmin, settings::axes::qmax], each distance profile is contracted
 * with the matching row of the sinc table via std::inner_product. A single cache per profile type suffices:
 * the aa/ax/xx partial intensities all reuse sinqd.aa, and aw/wx both reuse sinqd.aw.
 *
 * The work is distributed over the global thread pool. The function waits for the pool before returning.
 * exv_cache.sinqd.valid is set only after all results have been written.
 */
template<typename AA, typename AXFormFactorTableType, typename XX>
void CompositeDistanceHistogramFFExplicitBase<AA, AXFormFactorTableType, XX>::cache_refresh_sinqd() const {
    auto pool = utility::multi_threading::get_global_pool();
    const auto& sinqd_table = this->sinc_table.get_sinc_table();

    Axis debye_axis = constants::axes::q_axis.sub_axis(settings::axes::qmin, settings::axes::qmax);
    unsigned int q0 = constants::axes::q_axis.get_bin(settings::axes::qmin); // account for a possibly different qmin
    const auto bins = debye_axis.bins;

    // Allocate on every refresh. The number of q bins depends on settings::axes::qmin/qmax, which may have
    // changed since the previous call. Reusing an older buffer could leave it undersized, and the tasks
    // below would then write out of bounds.
    exv_cache.sinqd.aa = container::Container3D<double>(form_factor::get_count(), form_factor::get_count(), bins);
    exv_cache.sinqd.aw = container::Container2D<double>(form_factor::get_count(), bins);
    exv_cache.sinqd.ww = container::Container1D<double>(bins);

    for (unsigned int ff1 = 0; ff1 < form_factor::get_count_without_excluded_volume(); ++ff1) {
        for (unsigned int ff2 = 0; ff2 < form_factor::get_count_without_excluded_volume(); ++ff2) {
            pool->detach_task([this, q0, bins, ff1, ff2, sinqd_table] () {
                for (unsigned int q = q0; q < q0+bins; ++q) {
                    // Single sinqd calculation - the same histogram is used for aa, ax, and xx
                    exv_cache.sinqd.aa.index(ff1, ff2, q-q0) = std::inner_product(this->distance_profiles.aa.begin(ff1, ff2), this->distance_profiles.aa.end(ff1, ff2), sinqd_table->begin(q), 0.0);
                }
            });
        }
        pool->detach_task([this, q0, bins, ff1, sinqd_table] () {
            for (unsigned int q = q0; q < q0+bins; ++q) {
                // Single sinqd calculation - the same histogram is used for aw and wx
                exv_cache.sinqd.aw.index(ff1, q-q0) = std::inner_product(this->distance_profiles.aw.begin(ff1), this->distance_profiles.aw.end(ff1), sinqd_table->begin(q), 0.0);
            }
        });
    }
    pool->detach_task([this, q0, bins, sinqd_table] () {
        for (unsigned int q = q0; q < q0+bins; ++q) {
            exv_cache.sinqd.ww.index(q-q0) = std::inner_product(this->distance_profiles.ww.begin(), this->distance_profiles.ww.end(), sinqd_table->begin(q), 0.0);
        }
    });
    pool->wait();

    // Only flag the cache as valid once every task above has finished writing its results.
    exv_cache.sinqd.valid = true;
}

/**
 * @brief Rebuild the cached partial intensity profiles (aa, ax, xx, aw, ww, wx) on the current q-range.
 *
 * Each profile is a sum over form-factor pairs of a cached sinqd value times a form-factor product,
 * scaled by per-q prefactors built from free_params. Only the profiles whose inputs changed are
 * rebuilt. The work is spread over the global thread pool, and the function waits for it before returning.
 *
 * @param sinqd_changed The sinqd caches were refreshed. Every profile is built from one of them, so this
 *                      implies both @p cw_changed and @p cx_changed.
 * @param cw_changed    free_params.cw changed: rebuild aw, ww and wx.
 * @param cx_changed    The ax/xx/wx prefactors changed (free_params.cx; presumably also free_params.crho,
 *                      since it enters the same prefactors — NOTE(review): confirm the caller signals crho
 *                      changes through this flag). Rebuild ax, xx and wx.
 */
template<typename AA, typename AXFormFactorTableType, typename XX>
void CompositeDistanceHistogramFFExplicitBase<AA, AXFormFactorTableType, XX>::cache_refresh_intensity_profiles(bool sinqd_changed, bool cw_changed, bool cx_changed) const {
    // aa/ax/xx are contractions of sinqd.aa, aw/wx of sinqd.aw, and ww of sinqd.ww, so a sinqd refresh
    // invalidates every profile. Without this, a call with only sinqd_changed set would leave
    // ax, xx, aw, ww and wx built from the previous sinqd values.
    if (sinqd_changed) {
        cw_changed = true;
        cx_changed = true;
    }

    auto pool = utility::multi_threading::get_global_pool();
    const auto& ff_aa_table = get_ffaa_table();
    const auto& ff_ax_table = get_ffax_table();
    const auto& ff_xx_table = get_ffxx_table();
    const auto& sinqd_table = this->sinc_table.get_sinc_table();

    Axis debye_axis = constants::axes::q_axis.sub_axis(settings::axes::qmin, settings::axes::qmax);
    unsigned int q0 = constants::axes::q_axis.get_bin(settings::axes::qmin); // account for a possibly different qmin

    if (sinqd_changed) {
        this->cache.intensity_profiles.aa = std::vector<double>(debye_axis.bins, 0);
    }
    if (cw_changed) {
        this->cache.intensity_profiles.aw = std::vector<double>(debye_axis.bins, 0);
        this->cache.intensity_profiles.ww = std::vector<double>(debye_axis.bins, 0);
    }
    if (cx_changed) {
        this->cache.intensity_profiles.ax = std::vector<double>(debye_axis.bins, 0);
        this->cache.intensity_profiles.xx = std::vector<double>(debye_axis.bins, 0);
    }
    if (cw_changed || cx_changed) {
        this->cache.intensity_profiles.wx = std::vector<double>(debye_axis.bins, 0);
    }

    // Per-q prefactors of the ax, xx and wx terms. They do not depend on the form-factor pair, so they are
    // computed once here instead of inside the (ff1, ff2, q) loops. Each keeps the multiplication order of
    // the inline expression it replaces, so the accumulated values are unchanged.
    std::vector<double> ax_prefactor(debye_axis.bins, 0);
    std::vector<double> xx_prefactor(debye_axis.bins, 0);
    std::vector<double> wx_prefactor(debye_axis.bins, 0);
    for (unsigned int q = q0; q < q0+debye_axis.bins; ++q) {
        const double cx = exv_factor(constants::axes::q_vals[q]);
        ax_prefactor[q-q0] = 2*this->free_params.crho*cx;
        xx_prefactor[q-q0] = std::pow(cx*this->free_params.crho, 2);
        wx_prefactor[q-q0] = 2*this->free_params.crho*cx*this->free_params.cw;
    }

    if (sinqd_changed) {
        pool->detach_task([&] () {
            for (unsigned int ff1 = 0; ff1 < form_factor::get_count_without_excluded_volume(); ++ff1) {
                for (unsigned int ff2 = 0; ff2 < form_factor::get_count_without_excluded_volume(); ++ff2) {
                    for (unsigned int q = q0; q < q0+debye_axis.bins; ++q) {
                        this->cache.intensity_profiles.aa[q-q0] += 
                            exv_cache.sinqd.aa.index(ff1, ff2, q-q0)*ff_aa_table.index(ff1, ff2).evaluate(q);
                    }
                }
            }
        });
    }

    if (cx_changed) {
        // ax and xx reuse the sinqd.aa values, combined with the ax and xx form-factor tables respectively.
        pool->detach_task([&] () {
            for (unsigned int ff1 = 0; ff1 < form_factor::get_count_without_excluded_volume(); ++ff1) {
                for (unsigned int ff2 = 0; ff2 < form_factor::get_count_without_excluded_volume(); ++ff2) {
                    // The full distance-bin-0 count of the aa profile is treated as self-correlation and its
                    // contribution (count * sinqd_table->lookup(q, 0)) is removed from the ax term.
                    double self_correlation = this->distance_profiles.aa.index(ff1, ff2, 0);
                    for (unsigned int q = q0; q < q0+debye_axis.bins; ++q) {
                        double sinqd_ax = exv_cache.sinqd.aa.index(ff1, ff2, q-q0) - self_correlation * sinqd_table->lookup(q, 0);
                        this->cache.intensity_profiles.ax[q-q0] += 
                            ax_prefactor[q-q0]*sinqd_ax*ff_ax_table.index(ff1, ff2).evaluate(q);
                    }
                }
            }
        });
        pool->detach_task([&] () {
            for (unsigned int ff1 = 0; ff1 < form_factor::get_count_without_excluded_volume(); ++ff1) {
                for (unsigned int ff2 = 0; ff2 < form_factor::get_count_without_excluded_volume(); ++ff2) {
                    for (unsigned int q = q0; q < q0+debye_axis.bins; ++q) {
                        this->cache.intensity_profiles.xx[q-q0] += 
                            xx_prefactor[q-q0]*exv_cache.sinqd.aa.index(ff1, ff2, q-q0)*ff_xx_table.index(ff1, ff2).evaluate(q);
                    }
                }
            }
        });
    }

    if (cw_changed) {
        pool->detach_task([&] () {
            for (unsigned int ff1 = 0; ff1 < form_factor::get_count_without_excluded_volume(); ++ff1) {
                for (unsigned int q = q0; q < q0+debye_axis.bins; ++q) {
                    this->cache.intensity_profiles.aw[q-q0] += 
                        2*this->free_params.cw*exv_cache.sinqd.aw.index(ff1, q-q0)
                        *ff_aa_table.index(ff1, form_factor::water_bin).evaluate(q);
                }
            }
        });
        pool->detach_task([&] () {
            for (unsigned int q = q0; q < q0+debye_axis.bins; ++q) {
                this->cache.intensity_profiles.ww[q-q0] += 
                    this->free_params.cw*this->free_params.cw*exv_cache.sinqd.ww.index(q-q0)
                    *ff_aa_table.index(form_factor::water_bin, form_factor::water_bin).evaluate(q);
            }
        });
    }

    if (cw_changed || cx_changed) {
        // wx reuses the sinqd.aw values with the ax form-factor table (water row).
        pool->detach_task([&] () {
            for (unsigned int ff1 = 0; ff1 < form_factor::get_count_without_excluded_volume(); ++ff1) {
                for (unsigned int q = q0; q < q0+debye_axis.bins; ++q) {
                    this->cache.intensity_profiles.wx[q-q0] += 
                        wx_prefactor[q-q0]*exv_cache.sinqd.aw.index(ff1, q-q0)
                        *ff_ax_table.index(form_factor::water_bin, ff1).evaluate(q);
                }
            }
        });
    }
    pool->wait();

    // Record the parameters these profiles were built with only after every task has finished.
    this->cache.intensity_profiles.cached_cx = this->free_params.cx;
    this->cache.intensity_profiles.cached_crho = this->free_params.crho;
    this->cache.intensity_profiles.cached_cw = this->free_params.cw;
}

// Explicit instantiation for the atomic (aa), cross (ax) and excluded-volume (xx) lookup table types.
// This is the only instantiation emitted from this translation unit.
template class hist::CompositeDistanceHistogramFFExplicitBase<
    form_factor::lookup::atomic::table_t,
    form_factor::lookup::cross::table_t,
    form_factor::lookup::exv::table_t
>;