""" Module for defining CCCma NEMO grids, reading NEMO netCDF and handing
    over to CMOR for writing in CMIP6/CF compliant output.

    NCS, 12/2018

    TO DO:
      - Implement proper masking of NEMO data (for land points). Currently var=0.0 is masked.
      - Fix definition of z-grid bounds
      - Figure out where to determine if positive=up or down
      - Implement additional "grid types", including "basin" and "oline" dimensions
"""
import cmor
import nmor_types
from nmor_types import nmor_var
import xarray as xr
import numpy as np
import warnings
import traceback
import json
import cftime
from datetime import timedelta

def nemo_to_cmor(cmortable, cmorvar, vardata, cccts, cccvar, nemo_grid, rowvar,
                 positive='down', user_f='cccma_user_input.json', cmor_logfile=None):
    """Read a CCCma timeseries, define grids, and pass to CMOR

       Inputs:
          cmortable    (str)        : The CMOR table being used (e.g. 'CMIP6_Omon.json')
          cmorvar      (str)        : The CMOR variable name (e.g. 'thetao')
          vardata      (xr.Dataset) : Dataset containing all the data in cccts
          cccts        (str)        : filename (or full path) of the NEMO netCDF timeseries.
                                    : Note that this is only used to determine the grid stagger.
          cccvar       (str)        : The CCCma variable name, as found in the cccts file
          nemo_grid    (xr.Dataset) : The grid ("mesh_mask") dataset to use
          rowvar       (dict)       : All the spreadsheet columns for this variable; the 'units',
                                    : 'dimensions' and CCCma history/comment columns are read here
          positive     (str)        : 'up' or 'down' for fluxes
          user_f       (str)        : filename (or full path) of the user json file to be used by CMOR
          cmor_logfile (str)        : optional filename for the CMOR log; if None, CMOR output goes to stdout

       Outputs:
          CMOR output is written to stdout/stderr (or to cmor_logfile).
          If successful, returns True and the resulting CMORized file will appear in the CMIP6 directory.
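
          Example rowvar (a minimal sketch; only the spreadsheet columns read by this function are
          shown, and the values are placeholders rather than a real variable request):
              rowvar = {'CCCma optional deck'  : 'optional deck history text',
                        'CCCma NetCDF comment' : 'optional comment',
                        'dimensions'           : 'longitude latitude olevel time',
                        'units'                : 'K'}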
    """
    misval = 1.e20 # Set the missing value to use for masked variables; this is the value expected by CMOR

    # Create convenience variables from spreadsheet information
    ccchistory = str(rowvar['CCCma optional deck'])
    ccccomment = str(rowvar['CCCma NetCDF comment'])
    cmor_dims = rowvar['dimensions'].split(' ')
    units = str(rowvar['units'])

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # Load & setup CMOR tables
    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

    # CMOR tables location and setup
    d = {
        'inpath'                : 'cmip6-cmor-tables/Tables'
    ,   'netcdf_file_action'    : cmor.CMOR_REPLACE_4
    }
    # if a log file is defined and is a valid string name, pass it to cmor
    if cmor_logfile:
        if isinstance(cmor_logfile, str):
            d.update({'logfile' : cmor_logfile})
        else:
            print("cmor log filename must be a string...")
            print("\tsending cmor output to stdout")

    cmor.setup(**d)

    # Define user json file
    cmor.dataset_json(user_f)
    with open(user_f,'r') as f:
        user_tab = json.load(f)
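    # The keys read from user_tab later in this routine are '#time_axis_origin' and 'calendar';
    # a minimal sketch of what the user file is assumed to contain (values are examples only):
    #   { "#time_axis_origin": "days since 1850-01-01 00:00:00", "calendar": "365_day", ... }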
    # Also load the coordinates table for some special dimensions
    with open(d['inpath']+'/CMIP6_coordinate.json','r') as f:
        coord_tab = json.load(f)

    tables = {}
    # Load all needed tables and assign shorthand ids
    tables['grids']      = cmor.load_table("CMIP6_grids.json")
    tables['cmortable']  = cmor.load_table(cmortable)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # Open NEMO file and parse grid dimensions
    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # Load the variable into the 'nmor_var' class which contains the xarray dataset and additional
    # grid information
    dr = nmor_var(vardata, cccts, cccvar, cmor_dims)
    dr.vardata = strip_extra_row_cols(dr.vardata, dr.has_x, dr.has_y, dr.has_z, zvar = dr.depth_var)
    # Skip applying the mask if this is a 'basin' diagnostic
    if dr.field_type != 'basin':
        dr.apply_mask(nemo_grid, misval)

    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # Define grid & variable using CMOR table properties
    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # horizontal axes
    # For the NEMO tripolar grid we need to define the index axes first, and
    # then define the actual grid with 2D lon/lat using cmor.grid.

    # Define the list containing all the cmor axes ids. This will be appended to before "CMOR writing"
    axes = []
    cmor.set_table(tables['cmortable']) # the actual CMOR table being used, e.g. Omon

    if dr.has_x and dr.has_y:
        # Take the lon/lat grid from the nemo_grid file instead of the timeseries file
        latvar = 'gphi'+dr.hor_stagger
        lonvar = 'glam'+dr.hor_stagger


        cmor.set_table(tables['grids'])
        # Ensure that longitude goes from 0-360
        lon_coords = np.mod(nemo_grid[lonvar].squeeze().values,360.)
        x = dr.vardata.x.values
        ilon = cmor.axis(table_entry= 'i_index',
                         units= '1', coord_vals=x)
        lat_coords = nemo_grid[latvar].squeeze().values
        y = dr.vardata.y.values
        ilat = cmor.axis(table_entry= 'j_index',
                         units= '1', coord_vals=y)
        # Data contains x & y dimensions and grid needs to be defined
        lat_vertices, lon_vertices = dr.calc_vertices(nemo_grid)
        lon_vertices = np.mod(lon_vertices,360)
        grid_id = cmor.grid([ilat, ilon], lat_coords, lon_coords, lat_vertices, lon_vertices)
        cmor.set_table(tables['cmortable'])
    elif dr.has_x:
        lon_id = cmor.axis(table_entry='longitude',units='degrees_east', coord_vals=dr.vardata.nav_lon.values)
    elif dr.has_y:
        lat_bnds = dr.calc_lat_bnds()
        lat_id = cmor.axis(table_entry='latitude',units='degrees_north', coord_vals=dr.vardata.nav_lat.values,
                           cell_bounds = lat_bnds)

    # Time axis
    if dr.has_t:
        # Determine the name of the time dimension requested by CMOR, note that str is needed because
        # xarray stores the dimensions as unicode and not as a python string
        cmor_time_dim = str([dim for dim in cmor_dims if 'time' in dim][0])
        itime = cmor.axis(table_entry = cmor_time_dim, units=user_tab['#time_axis_origin'])
        axes.append(itime)

    if dr.has_z:
        # Note here that the bottom level of all files is identically 0.
        depth_coords = dr.vardata[dr.depth_var].values # depths
        # For a t-grid, the cell bounds are top and bottom interface (w-points)
        if   dr.ver_stagger == 'w':
            cell_bounds = np.append(0.,nemo_grid['gdept_0'][0,:-1].copy())
        elif dr.ver_stagger == 't':
            cell_bounds = np.append(0.,nemo_grid['gdepw_0'][0,1:].copy())
        idep = cmor.axis(table_entry='depth_coord',
                         units='m', coord_vals=depth_coords,
                         cell_bounds=cell_bounds)

    if   dr.field_type == '3d':
        axes.extend((idep, grid_id))
    elif dr.field_type == '2d':
        axes.append(grid_id)

    # Deal with some special cases
    if 'basin' in cmor_dims:
        basin_names = [ str(name) for name in coord_tab['axis_entry']['basin']['requested'] ]
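        # For CMIP6 these are expected to be names like 'atlantic_arctic_ocean', 'indian_pacific_ocean'
        # and 'global_ocean'; they are taken from the coordinate table rather than hard-coded here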
        basin_id = cmor.axis(table_entry='basin', coord_vals = basin_names, units ="")
        axes.append(basin_id)
        if dr.has_z:
            axes.append(idep)
        axes.append(lat_id)
    if 'oline' in cmor_dims:
        line_names = [ str(name) for name in coord_tab['axis_entry']['oline']['requested'] ]
        axes.append(cmor.axis(table_entry='oline', coord_vals = line_names, units = ""))

    if ccccomment == "None": ccccomment = None # needed incase a null comment was cast to a string
    if ccchistory == "None": ccchistory = None
    varid = cmor.variable(cmorvar, units, axes, positive=positive,
                          original_name=cccvar, history=ccchistory,
                          comment=ccccomment, missing_value=misval)
    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # Do the CMOR writing
    #+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
    # Arguments are slightly different if we are writing a static 'fixed' field or a time-varying one
    if dr.has_t:
        # Shift the time bounds to the correct time origin and calendar
        dr.shift_time_origin( user_tab['#time_axis_origin'], user_tab['calendar'] )
        time_bounds = check_time_bounds(cmortable, dr.shifted_time, dr.shifted_time_bnds,
                                        user_tab['#time_axis_origin'], user_tab['calendar'])
        for i in range(dr.ntime):
            vardat = dr.vardata[cccvar][i,...].values
            blist=[time_bounds[i,0], time_bounds[i,1]]
            cmor.write(varid, vardat, 1, time_vals=dr.shifted_time[i], time_bnds=blist)
    else:
        cmor.write(varid, dr.vardata[cccvar].values)

    # close CMOR and dataset
    cmor.close()
    return True

def strip_extra_row_cols( data, has_x, has_y, has_z, xvar = 'x', yvar = 'y', zvar = 'z', xidx = [0,361],
                          yidx = [291], zidx = [45] ):
    """
    Given an xarray object, strip out the duplicated wrap-around columns, the extra (north-fold) row
    and the extra bottom depth level from 2d and 3d files

    Inputs:
        data  (xarray dataset) : xarray dataset or variables
        has_x (bool)           : True if extra longitudes should be removed
        has_y (bool)           : True if extra latitudes should be removed
        has_z (bool)           : True if extra depths should be removed
    Inputs (optional):
        NOTE: All values are set to default values for CanESM5 using the ORCA1 grid
        xvar  (str)            : Name of the zonal dimension variable
        yvar  (str)            : Name of the meridional dimension variable
        zvar  (str)            : Name of the vertical dimension variable
        xidx  (list)           : Indices to remove in the x-variable
        yidx  (list)           : Indices to remove in the y-variable
        zidx  (list)           : Indices to remove in the z-variable
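
    Example (assuming the standard CanESM5 ORCA1 extents implied by the defaults above, i.e.
    362 x 292 x 46 points in x, y, z):
        a field of shape (time, 46, 292, 362) is returned with shape (time, 45, 291, 360),
        i.e. the wrap columns i=0 and i=361, the north-fold row j=291 and the extra bottom
        level k=45 are dropped.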
    """

    # Create the dictionary of indexers to retain
    retain_idx = {}
    if has_x:
        retain_idx[xvar] = [ x for x in range(0,len(data[xvar])) if x not in xidx ]
    if has_y:
        retain_idx[yvar] = [ y for y in range(0,len(data[yvar])) if y not in yidx ]
    if has_z:
        retain_idx[zvar] = [ z for z in range(0,len(data[zvar])) if z not in zidx ]

    return data.isel(**retain_idx).copy()

def check_time_bounds( tabname, time, time_bnds, units, calendar ):
    """
    Check that the time_counter values fall within their time_counter_bnds. If they do not, then for
    every 'time' cell method OTHER than 'point', reset the time interval bounds based on the frequency
    of the table (e.g. 'day', 'mon', 'yr').
    Inputs:
        tabname       (str)             : Name of the table (e.g. Omon)
        time          (ntime array)     : Timestamps corresponding to the input file
        time_bnds     (ntime x 2 array) : Time bounds of the given timestamps
        units         (str)             : The units of the reference calendar e.g. 'days since 1850-1-1 00:00:00'
        calendar      (str)             : The type of calendar used (e.g. '365_day')
    Output:
        time_bnds     (ntime x 2 array) : Either the original time_bnds or the new valid ones
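
    Example (an illustration only, assuming units='days since 1850-1-1 00:00:00' and calendar='365_day'):
        for a monthly (Omon) timestamp of 15.5 (mid-January 1850) whose stored bounds do not contain it,
        the bounds are reset to [0., 31.], i.e. 1850-01-01 00:00 to 1850-02-01 00:00.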
    """

    # Flag any timestamp that does not fall within its stored time bounds
    rework_time_bound = np.zeros(time.shape, dtype=bool)
    for t in range(len(time)):
        rework_time_bound[t] = not (time_bnds[t,0] <= time[t] <= time_bnds[t,1])

    # If any of the timestamps are not within its time bound, then we need to figure out the
    # correct interval
    if any(rework_time_bound):
        print("WARNING: Time bounds invalid. Resetting to middle of the time period")
        valid_frequencies = ['day', 'mon', 'yr']
        freq = [ f for f in valid_frequencies if f in tabname ]
        if not freq:
            raise Exception('Reworking time bounds only works for day, mon, yr frequencies, not the requested {}'.format(tabname))

        # Now that we know what time period this variable should cover, we can set the time bounds in the following way:
        # Day: Hour 00 of current day and Hour 00 of next day
        # Month: Day 1 of current month and day 1 of next month (00 hour)
        # Year: Day 1 of current year and day 1 of next year (00 hour)
        for t in range(len(time)):
            year, month, day = cftime.num2date(time[t], units, calendar).timetuple()[0:3]
            if freq[0] == 'day':
                time_bnds[t,0] = cftime.date2num(cftime.datetime(year,month,day),units,calendar)
                # next day: the num2date/date2num round trip handles month/year rollover in this calendar
                next_day = cftime.num2date(time_bnds[t,0], units, calendar) + timedelta(days=1)
                time_bnds[t,1] = cftime.date2num(next_day,units,calendar)
            if freq[0] == 'mon':
                time_bnds[t,0] = cftime.date2num(cftime.datetime(year,month,1),units,calendar)
                # handle the December -> January rollover explicitly
                nyear, nmonth = (year+1, 1) if month == 12 else (year, month+1)
                time_bnds[t,1] = cftime.date2num(cftime.datetime(nyear,nmonth,1),units,calendar)
            if freq[0] == 'yr':
                time_bnds[t,0] = cftime.date2num(cftime.datetime(year,1,1),units,calendar)
                time_bnds[t,1] = cftime.date2num(cftime.datetime(year+1,1,1),units,calendar)

    return time_bnds

if __name__ == '__main__':
    # Example driver (a sketch only: the filenames, mesh_mask.nc and the minimal spreadsheet row below are placeholders)
    vardata   = xr.open_dataset('sc_rc2-hist_1980_2014_1m_grid_t_votemper.nc.001')
    nemo_grid = xr.open_dataset('mesh_mask.nc')
    rowvar = {'CCCma optional deck': None, 'CCCma NetCDF comment': None, 'dimensions': 'longitude latitude olevel time', 'units': 'K'}
    nemo_to_cmor('CMIP6_Omon.json', 'thetao', vardata, 'sc_rc2-hist_1980_2014_1m_grid_t_votemper.nc.001', 'votemper', nemo_grid, rowvar)
    # A daily example would use e.g. 'CMIP6_Oday.json', 'tos' and 'sc_rc2-hist_1980_2014_1d_grid_t_tos.nc.001'
