#!/usr/bin/env python
'''
The heat transports in each basin are calculated online but still need to be
processed and combined into a single output file.
'''
import argparse
import subprocess
import xarray as xr
import numpy as np

# Input variables, in order. The global transport is the first minus the
# second, and the output variable reuses the first name.
invars = [
    'sophtadv',
    'sophtove',
]

def get_options():
    """Parse the command-line arguments.

    Positional arguments, in order: the filename prefix, one argument for
    each name in ``invars``, and the output file name.

    Returns
    -------
    argparse.Namespace
        Parsed arguments with attributes ``pfx``, one attribute per entry of
        ``invars``, and ``out``.
    """
    parser = argparse.ArgumentParser(description=
            "Expand so that the array has the correct basin dimensions, convert heat transport")

    parser.add_argument('pfx', help='Filename prefix', type=str)
    # One positional argument per input variable, named after that variable.
    # NOTE(review): the parsed values do not appear to be read by the main
    # block, which locates input files from the variable names; confirm intent.
    for var in invars:
        parser.add_argument(var, type=str)
    parser.add_argument('out', help='Name of the output file', type=str)

    return parser.parse_args()

def expand_to_basins(global_transport, nbasin=3, fill_value=1.e20):
    """Embed a global transport field in a (time, basin, lat) array.

    Parameters
    ----------
    global_transport : array_like
        2-D array of shape (ntime, nlat).
    nbasin : int, optional
        Length of the basin axis. The global field is stored in the last
        basin slot (index ``nbasin - 1``).
    fill_value : float, optional
        Value written into every other basin slot to mark it as missing.

    Returns
    -------
    numpy.ndarray
        float64 array of shape (ntime, nbasin, nlat).

    Raises
    ------
    ValueError
        If ``global_transport`` is not 2-D.
    """
    values = np.asarray(global_transport)
    ntime, nlat = values.shape
    # Every basin starts out as missing; only the global slot gets real data
    hfbasin = np.full((ntime, nbasin, nlat), fill_value, dtype=np.float64)
    hfbasin[:, nbasin - 1, :] = values
    return hfbasin


def main():
    """Stage the inputs, build the basin-resolved transport, and write it out."""
    args = get_options()
    nbasin = 3  # 1: Atlantic, 2: Indo-Pacific, 3: Global
    fill_value = 1.e20
    raw = None
    try:
        # Stage each input file under its bare variable name
        for var in invars:
            access_cmd = 'access {inp} {pfx}_{inp}.nc'.format(inp=var, pfx=args.pfx)
            subprocess.call(access_cmd, shell=True)
        raw = xr.open_mfdataset(invars, decode_times=False)
        # Drop singleton dimensions but keep time_counter, so that a file with
        # a single time record still gives (time, lat) data, not a 1-D array
        singleton_dims = [dim for dim, size in raw.sizes.items()
                          if size == 1 and dim != 'time_counter']
        data = raw.squeeze(dim=singleton_dims)

        template = data[invars[0]]
        # Global transport is the first input minus the second, times 1e15.
        # NOTE(review): the factor suggests PW -> W, but the units attribute
        # copied from the input below is left unchanged; confirm.
        global_transport = (template - data[invars[1]]).values * 1.e15
        hfbasin = expand_to_basins(global_transport, nbasin=nbasin,
                                   fill_value=fill_value)
        hfbasin_var = xr.Variable(['time_counter', 'basin', 'y'], hfbasin,
                                  template.attrs)

        coords = {coord: data[coord] for coord in data.coords
                  if coord in ['time_counter', 'lat']}
        coords['basin'] = np.arange(1, nbasin + 1)
        # The output variable keeps the name of the first input variable,
        # because the calling decks assume the first variable passed in is
        # the output name
        outdata = xr.Dataset({invars[0]: hfbasin_var}, coords, data.attrs)
        outdata['time_counter_bnds'] = data['time_counter_bnds']
        # Rename the latitude coordinate from lat to nav_lat
        outdata = outdata.rename({'lat': 'nav_lat'})
        # Drop the last two rows along both latitude dimensions.
        # NOTE(review): assumes 'lat' is itself a dimension in the input; the
        # reason for trimming two rows is not visible here.
        outdata = outdata.isel(nav_lat=slice(0, -2), y=slice(0, -2))
        # Declare the fill value so readers treat the unfilled basins as missing
        outdata.to_netcdf('{}_{}'.format(args.pfx, args.out),
                          encoding={invars[0]: {'_FillValue': fill_value}})
    finally:
        # Always close the inputs and release the staged files, even on error
        if raw is not None:
            raw.close()
        for var in invars:
            subprocess.call('release {}'.format(var), shell=True)


if __name__ == "__main__":
    main()
