#!/usr/bin/env python
'''
In some cases of CanESM5, the mfo diagnostic contains only 15 straits,
omitting the Strait of Gibraltar. Additionally, the ordering of the
straits is not quite correct. This deck checks the number of straits
and, if it is not 16, pads the data to 16 straits, filling the Gibraltar
entry with the CMIP6 missing value (1.e20), and permutes the straits into
the proper order. If the number of straits is 16, the file is simply
written to the requested output file name.
'''

import argparse
import subprocess
import xarray as xr
import numpy as np

def get_options():
    '''
    Parse the command line arguments of this script.

    Three positional string arguments are required, in this order:
    pfx (filename prefix), inp (input name) and out (output file name).

    Returns:
        The argparse.Namespace produced by parse_args(), with the
        attributes pfx, inp and out.
    '''
    parser = argparse.ArgumentParser(
        description="Expand 15 strait files to 16 and reorder")

    parser.add_argument('pfx', help='Filename prefix', type=str)
    parser.add_argument('inp', help='Name of the input variable', type=str)
    parser.add_argument('out', help='Name of the output file', type=str)

    return parser.parse_args()

def expand_straits(mfo):
    '''
    Pad a 15-strait mfo array to 16 straits and put them in the proper order.

    Args:
        mfo: array-like of shape (ntime, 15).

    Returns:
        A new float array of shape (ntime, 16). The strait that is absent
        from the input (Gibraltar) is filled with 1.e20.

    Raises:
        ValueError: if mfo is not two-dimensional with exactly 15 columns.
    '''
    mfo = np.asarray(mfo)
    if mfo.ndim != 2 or mfo.shape[1] != 15:
        raise ValueError(
            "Expected mfo with shape (ntime, 15), got {}".format(mfo.shape))

    # Based on the diagnostic file this is the reordering of indexes needed
    # to get the straits in the proper order (1-based, converted to 0-based).
    # Index 16 is the padded column, which ends up in the 11th position.
    reorder = np.array([1,2,4,3,5,6,7,8,9,10,16,11,12,13,14,15]) - 1

    # 1.e20 is the missing value for CMIP6
    padded = np.full((mfo.shape[0], 16), 1.e20)
    # Store the original data in the first 15 columns
    padded[:, 0:15] = mfo
    # Fancy indexing returns a permuted copy
    return padded[:, reorder]


if __name__ == "__main__":
    args = get_options()
    out_path = '{}_{}'.format(args.pfx, args.out)

    # NOTE(review): `access`/`release` are external commands; presumably
    # `access` makes {pfx}_{inp}.nc available under the local name {inp} and
    # `release` removes it again -- confirm against the site tooling.
    access_cmd = 'access {inp} {pfx}_{inp}.nc'.format(inp = args.inp, pfx =args.pfx)
    print(access_cmd)
    subprocess.call(access_cmd,shell=True)
    try:
        # The context manager closes the input file once the output is written
        with xr.open_dataset(args.inp, decode_times = False) as data:
            # Unpacking also enforces that mfo is two-dimensional
            _ntime, nlines = data['mfo'].shape
            # If all 16 lines are present, assume that all calculations were done correctly
            if nlines == 16:
                data.to_netcdf(out_path)
            # Otherwise, Gibraltar is missing and the indexing needs to be reworked
            else:
                # Get all the other variables, coordinates, and metadata from
                # the original file. Only true data variables go in data_vars:
                # a name present in both data_vars and coords is rejected by
                # xr.Dataset.
                data_vars = { var:data[var] for var in data.data_vars
                              if var not in ('mfo','line','time_counter') }
                # Compare by equality: `coord not in 'line'` would be a
                # substring test and would also drop names such as 'i' or 'in'
                coords = { coord:data[coord] for coord in data.coords if coord != 'line' }

                # Pad to 16 straits and permute (raises ValueError unless
                # exactly 15 straits are present)
                out = expand_straits(data['mfo'].values)
                mfo_var = xr.Variable(['time_counter','line'],out,data['mfo'].attrs)
                # Next create an xarray dataset and then save
                coords['line'] = np.arange(0,16) + 1
                data_vars['mfo'] = mfo_var
                outdata = xr.Dataset(data_vars, coords, data.attrs)
                outdata.to_netcdf(out_path)
    finally:
        # Always release the accessed input, whichever branch ran and even
        # if an error occurred
        subprocess.call('release {}'.format(args.inp),shell=True)

