# Import packages

from __future__ import print_function, division
import numpy as np
import os
import astropy.constants as cst
from astropy import units as u
import glob
import matplotlib.pyplot as plt                                                                 # type: ignore
import xarray as xr
import time
import h5py
'''
Crop and store the grid of spectra as an xarray dataset.
'''
# Grid name and constants
NAME_OF_THE_GRID = 'R200K_1000-1300K_cloudy_fsed_2025'
c = cst.c.to(u.cm / u.s).value  # Speed of light in cm/s, using astropy for precision and traceability
Nyquist = True  # Set to True for super-Nyquist sampling (R = lambda / 3*dlambda), False for R = lambda / dlambda

def decoupe(second):
    """
    Convert a number of seconds into the hours-minutes-seconds format.
    Args:
        second (float): number of seconds
    Returns:
        - float     : hours
        - float     : minutes
        - float     : seconds

    Author: Simon Petrus, Alice Radcliffe
    """
    hour = second // 3600
    second %= 3600
    minute = second // 60
    second %= 60

    return hour, minute, second
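
# Example (illustrative): decoupe(5025) returns (1.0, 23.0, 45.0), i.e. 1 h 23 min 45 s.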

# Parameter tables (one value appended per model file)
par1_tab = []
par2_tab = []
par3_tab = []
par4_tab = []
par5_tab = []

# Paths to the grid folders you want to include
PATH_INITIAL_MODELS_LIST = [
    '/path/to/folder/1000-1100K/',
    '/path/to/folder/1100-1200K/',
    '/path/to/folder/1200-1300K/',
]

# Collect all model file paths from all folders
all_model_files = []
for path in PATH_INITIAL_MODELS_LIST:
    all_model_files.extend(glob.glob(os.path.join(path, 'spect_*')))

print(f"Found {len(all_model_files)} model files across all folders.")

# Define the region of the grid that you need:
wl_min, wl_max = 1.3, 2.0
teff_min, teff_max = 1000, 1300
logg_min, logg_max = 3.0, 5.0
mh_min, mh_max = 0.32, 100  # linear [M/H] (not yet in dex); converted with log10 later
co_min, co_max = 0.2, 0.8
fsed_min, fsed_max = 0.3, 9
# - - - - - - - - - - - - - - - - - - 

# Iterate on all spectra
for indmod, mod in enumerate(all_model_files):
    # Extract the parameter values from the file name,
    # e.g. 'spect_1000K_logg4.5_met0.32_CO0.55_fsed1.h5'
    par_tab = os.path.basename(mod).split('spect_')[-1].split('.h5')[0].split('_')
    par1_tab.append(float(par_tab[0][:-1]))  # Teff: '1000K' → 1000.0
    par2_tab.append(float(par_tab[1][4:]))   # log(g): 'logg4.5' → 4.5
    par3_tab.append(float(par_tab[2][3:]))   # M/H (linear): 'met0.32' → 0.32
    par4_tab.append(float(par_tab[3][2:]))   # C/O: 'CO0.55' → 0.55
    par5_tab.append(float(par_tab[4][4:]))   # fsed: 'fsed1' → 1.0

    # Reference table for the wavelength and resolution (computed once, on the first iteration)
    if indmod == 0:
        with h5py.File("/path/to/wavenumber/wavenumber.h5", "r") as f:
            wav = f["wavenumber"][:]
        wav_master_full = np.flip(1e4 / np.asarray(wav))  # Wavenumber (cm-1) → wavelength (µm)

        # Only keep wavelengths between wl_min and wl_max (here 1.3–2.0 µm)
        wav_mask = (wav_master_full >= wl_min) & (wav_master_full <= wl_max)
        wav_master = wav_master_full[wav_mask]

        # Per-pixel wavelength step, then the resolving power
        dl = np.abs(wav_master - np.roll(wav_master, 1))
        dl[0] = dl[1]
        if Nyquist:
            print('Grid uses super-Nyquist sampling resolution')
            R = wav_master / (3 * dl)
        else:
            R = wav_master / dl
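
        # Quick sanity check: report the median per-pixel resolving power over the
        # selected range (informational only).
        print(f'Median resolving power over {wl_min}-{wl_max} µm: {np.median(R):.0f}')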


# Get a tab of unique values of each parameter
par1 = np.asarray(par1_tab)
par1_uni = np.unique(par1)
par2 = np.asarray(par2_tab)
par2_uni = np.unique(par2)
par3 = np.asarray(par3_tab)
par3_uni = np.unique(par3)
par4 = np.asarray(par4_tab)
par4_uni = np.unique(par4)
par5 = np.asarray(par5_tab)
par5_uni = np.unique(par5)

# Restrict the parameter space to the requested ranges:

par1_uni = par1_uni[(par1_uni >= teff_min) & (par1_uni <= teff_max)]
par2_uni = par2_uni[(par2_uni >= logg_min) & (par2_uni <= logg_max)]
par3_uni = par3_uni[(par3_uni >= mh_min) & (par3_uni <= mh_max)]
par4_uni = par4_uni[(par4_uni >= co_min) & (par4_uni <= co_max)]
par5_uni = par5_uni[(par5_uni >= fsed_min) & (par5_uni <= fsed_max)]
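
# Optional guard (assumption: every parameter axis must keep at least one value
# for the grid construction below to work).
for name, uni in zip(['Teff', 'log(g)', '[M/H]', 'C/O', 'fsed'],
                     [par1_uni, par2_uni, par3_uni, par4_uni, par5_uni]):
    if len(uni) == 0:
        raise ValueError(f'No {name} values left after restricting the parameter space.')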


print()
print('Grid parameters:')
print('Teff (K):', par1_uni)
print('log(g):', par2_uni)
print('[M/H] (dex):', np.round(np.log10(par3_uni), 1))
print('C/O:', par4_uni)
print('fsed:', par5_uni)

# - - - - - - - - - - - - - - - - - - 

# Creating xarray and saving it
print('Creating xarray...')

# Define the data array: shape (len(wav_master), len(par1_uni), ..., len(par5_uni)), filled with NaN values
DA = np.full((len(wav_master), len(par1_uni), len(par2_uni), len(par3_uni), len(par4_uni), len(par5_uni)), np.nan, dtype=np.float32)

i_tot = 0
tot_par = len(par1_uni) * len(par2_uni) * len(par3_uni) * len(par4_uni) * len(par5_uni)

# Iterate on each parameter
for p1, par1 in enumerate(par1_uni):
    for p2, par2 in enumerate(par2_uni):
        for p3, par3 in enumerate(par3_uni):
            for p4, par4 in enumerate(par4_uni):
                for p5, par5 in enumerate(par5_uni):

                    # Handle fsed formatting: drop ".0" if integer
                    if float(par5).is_integer():
                        fsed_str = f"fsed{int(par5)}"
                    else:
                        fsed_str = f"fsed{par5:.1f}"

                    # Rebuild the name of the file containing the models
                    name_to_open = (
                        "spect_"
                        + "{:.0f}".format(round(par1, 0)) + "K_logg"
                        + "{:.1f}".format(round(par2, 1)) + "_met"
                        + "{:.2f}".format(round(par3, 2)) + "_CO"
                        + "{:.2f}".format(round(par4, 2)) + "_"
                        + fsed_str
                        + ".h5"
                    )
                    
                    # --- Look for the file in all folders ---
                    file_found = False
                    for path in PATH_INITIAL_MODELS_LIST:
                        full_path = os.path.join(path, name_to_open)
                        if os.path.isfile(full_path):
                            mod = full_path
                            file_found = True
                            break

                    if file_found:
                        with h5py.File(mod, "r") as f:
                            flx_full = f["flux"][:]

                        # Flip flux to match wavelength order
                        flx_full = flx_full[::-1]

                        # Apply the wavelength mask (wl_min–wl_max)
                        flx = flx_full[wav_mask]

                        # Convert flux from per-wavenumber (cm-1) to per-wavelength (µm): F_lambda = F_sigma * 1e4 / lambda**2
                        flx = flx * 1e4 / wav_master**2

                        # Replace the nan values by the model in the data array  
                        DA[:, p1, p2, p3, p4, p5] = flx

                        # Progress printing: overwrite the previous line with the current count
                        i_tot += 1
                        line_up = '\033[1A'
                        line_clear = '\x1b[2K'
                        print(line_up, end=line_clear)
                        print(f'{i_tot}/{tot_par} spectra loaded')


print(f'Number of spectra found: {i_tot} / {tot_par}')

# Define the attributes read by ForMoSA to identify which parameter corresponds to which dimension of the grid
attribute = {}
attribute['key'] = ['par1', 'par2', 'par3', 'par4', 'par5']
attribute['par'] = ['teff', 'logg', 'mh', 'co', 'fsed']
attribute['title'] = ['Teff', 'log(g)', '[M/H]', 'C/O', 'fsed']
attribute['unit'] = ['(K)', '(dex)', '', '', '']
attribute['res'] = R

print("NaN count in DA:", np.isnan(DA).sum())


# Transform the data array into an xarray data set
ds_new = xr.Dataset(data_vars=dict(grid=(["wavelength", "par1", "par2", "par3", "par4", "par5"], DA)),
                    coords={"wavelength": wav_master,
                            "par1": par1_uni,
                            "par2": par2_uni,
                            "par3": [round(np.log10(x),1) for x in par3_uni],
                            "par4": par4_uni,
                            "par5": par5_uni},
                    attrs=attribute)

# Store the data set
PATH_TO_STORE_THE_XARRAY_GRID = '/path/to/store/xarray/'


ds_new.to_netcdf(PATH_TO_STORE_THE_XARRAY_GRID + NAME_OF_THE_GRID +'.nc',
                 format='NETCDF4',
                 engine='netcdf4',
                 mode='w')
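
# Confirmation of where the grid was written (same path and name as above).
print('Grid saved to', PATH_TO_STORE_THE_XARRAY_GRID + NAME_OF_THE_GRID + '.nc')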




#===================A FEW USEFUL LINES=============================================
#%%

ds = xr.open_dataset(PATH_TO_STORE_THE_XARRAY_GRID + NAME_OF_THE_GRID + '.nc', decode_cf=False, engine="netcdf4")
print(ds['grid'])
print(ds.attrs['res'])
print(ds.coords['wavelength'])

# %%

grid = ds['grid']
wave = grid['wavelength'].values
flux = grid.sel(par1=1000, par2=4.5, par3=0.0, par4=0.55, par5=7, method="nearest")
## Check
plt.plot(wave, flux)
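# Axis labels for the quick-look plot (flux units depend on the grid convention:
# per-µm after the 1e4/lambda**2 conversion above).
plt.xlabel('Wavelength (µm)')
plt.ylabel('Flux')
plt.show()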

# %%
# To slice the grid even further if need be:
ds_new_subset = ds_new.sel(
    wavelength=slice(1.0, 1.5),  # µm
    par1=slice(1000, 1200),
    par2=slice(3.0, 5.0),
    par3=slice(-0.5, 0),
)
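
# To store the sliced grid as well (hypothetical file name, same convention as above):
# ds_new_subset.to_netcdf(PATH_TO_STORE_THE_XARRAY_GRID + NAME_OF_THE_GRID + '_subset.nc',
#                         format='NETCDF4', engine='netcdf4', mode='w')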
