#region: Modules.
from fp.plotting.absorption import BseSpectrumPlot
from fp.plotting.gwelbands import GwelbandsPlot
from fp.plotting.phbands import PhbandsPlot
from fp.plotting.dftelbands import DftelbandsPlot
from fp.analysis.xctpol import XctPolCalc
from fp.plotting.xctph import XctphPlot
import h5py 
#endregion

#region: Variables.
#endregion

#region: Functions.
def plot_bse():
    """Build the BSE absorption spectrum plot and save it as ``bse_plot.png``.

    Constructs a ``BseSpectrumPlot`` from ``./absorption_eh.dat`` and
    ``./absorption_noeh.dat`` (passed positionally, in that order) and calls
    its ``save_plot`` with ``show=True``.
    """
    # Named ``spectrum_plot`` rather than ``abs`` so the builtin ``abs()`` is
    # not shadowed inside this function.
    spectrum_plot = BseSpectrumPlot(
        './absorption_eh.dat',
        './absorption_noeh.dat',
    )
    spectrum_plot.save_plot('bse_plot.png', show=True)

def plot_dftelbands():
    """Build a ``DftelbandsPlot`` and save it as ``dftelbands_plot.png``.

    The input files are read from the current directory. ``show=True`` and
    ``ylim=[-2, 2]`` are forwarded to ``save_plot`` unchanged.
    """
    input_files = {
        'dftelbands_xml_filename': './dftelbands.xml',
        'bandpathpkl_filename': './bandpath.pkl',
        'fullgridflow_filename': './fullgridflow.pkl',
    }
    plotter = DftelbandsPlot(**input_files)
    plotter.save_plot('dftelbands_plot.png', show=True, ylim=[-2, 2])

def plot_gwelbands():
    """Build a ``GwelbandsPlot`` and save it as ``gwelbands_plot.png``.

    The input files are read from the current directory. ``show=True`` and
    ``offset=1.0`` are forwarded to ``save_plot`` unchanged.
    """
    inteqp_path = './bandstructure_inteqp.dat'
    bandpath_path = './bandpath.pkl'
    fullgridflow_path = './fullgridflow.pkl'

    plotter = GwelbandsPlot(
        inteqp_filename=inteqp_path,
        bandpathpkl_filename=bandpath_path,
        fullgridflow_filename=fullgridflow_path,
    )
    plotter.save_plot('gwelbands_plot.png', show=True, offset=1.0)

def plot_phbands():
    """Build a ``PhbandsPlot`` and save it as ``phbands_plot.png``.

    The input files are read from the current directory and ``save_plot`` is
    called with ``show=True``.
    """
    input_files = {
        'phbands_filename': './struct.freq.gp',
        'bandpathpkl_filename': './bandpath.pkl',
        'fullgridflow_filename': './fullgridflow.pkl',
    }
    plotter = PhbandsPlot(**input_files)
    plotter.save_plot('phbands_plot.png', show=True)

def plot_xctph():
    """Build an ``XctphPlot`` and save it as ``xctph_plot.png``.

    The plot is configured for exciton state 0 and Q-point index 0 (both
    0-based) with ``xctph_mult_factor=5e2``; ``save_plot`` is called with
    ``show=True``.
    """
    # NOTE(review): the keyword is spelled ``bandpatkpkl_filename`` here while
    # the sibling plot classes take ``bandpathpkl_filename``. It is kept as-is
    # because it must match XctphPlot's signature -- confirm against the class.
    plot_options = {
        'xctph_filename': 'xctph_elhole.h5',
        'phbands_filename': 'struct.freq.gp',
        'fullgridflow_filename': 'fullgridflow.pkl',
        'bandpatkpkl_filename': 'bandpath.pkl',
        'xctph_mult_factor': 5e2,
        'xct_state': 0,    # 0-based index.
        'xct_Qpt_idx': 0,  # 0-based index.
    }
    plotter = XctphPlot(**plot_options)
    plotter.save_plot('xctph_plot.png', show=True)

def get_xctpol_results():
    """Build an ``XctPolCalc``, assemble its components, and return it.

    The calculator is constructed from ``./eph_xctph.h5``,
    ``./xctph_elhole.h5`` and ``./xctph_hole.h5`` (passed positionally, in
    that order) with ``max_error=1.0e-4`` and ``max_steps=10``.

    Returns:
        The ``XctPolCalc`` instance after ``assemble_components()`` has been
        called on it. Previously the object was discarded and the function
        returned ``None``; callers that ignore the return value are
        unaffected.
    """
    xctpol_result = XctPolCalc(
        './eph_xctph.h5',
        './xctph_elhole.h5',
        './xctph_hole.h5',
        max_error=1.0e-4,
        max_steps=10,
    )

    # Stack.
    xctpol_result.assemble_components()

    return xctpol_result
    
def main():
    """Run the currently enabled plotting routines.

    Only the uncommented entries in the tuple below are executed, in order;
    uncomment an entry to enable that plot.
    """
    enabled_plots = (
        plot_phbands,
        # plot_dftelbands,
        # plot_gwelbands,
        # plot_bse,
        # plot_xctph,
    )
    for make_plot in enabled_plots:
        make_plot()
#endregion

#region: Classes.
#endregion

#region: Main.
# Run the plots only when executed as a script, not on import.
if __name__ == '__main__':
    main()
#endregion