# plot longitude-depth fields

import matplotlib  # NOTE: no backend is selected here; call matplotlib.use('Agg') before importing pyplot to render without a display
import matplotlib.pyplot as plt
import numpy as np
from netCDF4 import Dataset  # https://github.com/Unidata/netcdf4-python
from matplotlib.backends.backend_pdf import PdfPages

def geo_idx(dd, dd_array):
    """Return the index of the entry in ``dd_array`` nearest to ``dd``.

    Distance is the absolute difference; on ties the first (lowest)
    index wins, as with ``argmin``.
    """
    distance = np.abs(dd_array - dd)
    return np.argmin(distance)

def ncdump(nc_mod, verb=False):
    '''
    ncdump outputs dimensions, variables and their attribute information.
    The information is similar to that of NCAR's ncdump utility.
    ncdump requires a valid instance of Dataset.

    Parameters
    ----------
    nc_mod : netCDF4.Dataset
        A netCDF4 dataset object
    verb : Boolean
        whether or not nc_attrs, nc_dims, and nc_vars are printed

    Returns
    -------
    nc_attrs : list
        A Python list of the NetCDF file global attributes
    nc_dims : list
        A Python list of the NetCDF file dimensions
    nc_vars : list
        A Python list of the NetCDF file variables
    '''
    # Every print() below receives exactly one pre-formatted string, so the
    # output is identical under Python 2 (where it parses as the print
    # statement applied to a parenthesised expression) and Python 3.
    def print_ncattr(key):
        """
        Prints the NetCDF file attributes for a given key

        Prints the variable's dtype followed by each of its attributes, or
        a warning when ``nc_mod`` has no variable named ``key`` (e.g. a
        dimension without a matching coordinate variable).

        Parameters
        ----------
        key : unicode
            a valid netCDF4.Dataset.variables key
        """
        try:
            var = nc_mod.variables[key]
        except KeyError:
            print("\t\tWARNING: %s does not contain variable attributes" % (key,))
            return
        print("\t\ttype: %s" % (repr(var.dtype),))
        for ncattr in var.ncattrs():
            print("\t\t%s: %s" % (ncattr, repr(var.getncattr(ncattr))))

    # NetCDF global attributes
    nc_attrs = nc_mod.ncattrs()
    if verb:
        print("NetCDF Global Attributes:")
        for nc_attr in nc_attrs:
            print("\t%s: %s" % (nc_attr, repr(nc_mod.getncattr(nc_attr))))
    nc_dims = list(nc_mod.dimensions)  # list of nc dimensions
    # Dimension shape information.
    if verb:
        print("NetCDF dimension information:")
        for dim in nc_dims:
            print("\tName: %s" % (dim,))
            print("\t\tsize: %s" % (len(nc_mod.dimensions[dim]),))
            print_ncattr(dim)
    # Variable information (coordinate variables were already shown above).
    nc_vars = list(nc_mod.variables)  # list of nc variables
    if verb:
        print("NetCDF variable information:")
        for var in nc_vars:
            if var not in nc_dims:
                print("\tName: %s" % (var,))
                # Wrap in a 1-tuple so the dimensions tuple is printed whole.
                print("\t\tdimensions: %s" % (nc_mod.variables[var].dimensions,))
                print("\t\tsize: %s" % (nc_mod.variables[var].size,))
                print_ncattr(var)
    return nc_attrs, nc_dims, nc_vars

class nf(float):
    """Float whose repr uses one decimal place, dropping a trailing zero.

    e.g. repr(nf(2.0)) == '2', repr(nf(1.3)) == '1.3'. Used below to wrap
    contour levels before labelling them.
    """
    def __repr__(self):
        value = self.__float__()
        one_decimal = '%.1f' % (value,)
        # A trailing '0' means the value rounds to a whole number at this
        # precision, so show it without the decimal part.
        if one_decimal[-1] == '0':
            return '%.0f' % value
        return one_decimal
            
#############################
pdf = PdfPages('out.pdf')

# Second file: supplies the reference field plotted in the middle panel.
nc_r = './data2.nc'
nc_mod2 = Dataset(nc_r, 'r')

# First file: the field being evaluated. Variable names used for all the
# membership checks below come from this file only; the second file is
# presumably laid out the same way -- not verified here.
nc_f = './data1.nc'
nc_mod1 = Dataset(nc_f, 'r')
nc_attrs, nc_dims, nc_vars = ncdump(nc_mod1)

# Extract coordinate data from the first NetCDF file.
def _first_available(candidates):
    """Return the first of ``candidates`` listed in ``nc_vars``.

    The last candidate is returned unconditionally, so a file lacking all
    of them still fails with a KeyError at the variable lookup.
    """
    for name in candidates[:-1]:
        if name in nc_vars:
            return name
    return candidates[-1]


lons = nc_mod1.variables[_first_available(('longitude', 'rlon', 'lon'))][:]
lats = nc_mod1.variables[_first_available(('latitude', 'rlat', 'lat'))][:]

# ntime keeps the variable's name so its units can be inspected later.
ntime = _first_available(('time', 'leadtime'))
time = nc_mod1.variables[ntime][:]

levs = nc_mod1.variables[_first_available(('depth', 'lev', 'level'))][:]

var_name = 'VAR_NAME'
# Read the whole variable from both files. The flip below indexes it as
# (time, depth, lat, lon), so a 4-D layout is assumed -- TODO confirm for
# every input file.
data1 = nc_mod1.variables[var_name][:]
data2 = nc_mod2.variables[var_name][:]

# 'wo' fields are scaled by 1e5 so the plot can label them in 1E-5 units.
if 'wo' in nc_vars:
    data1, data2 = data1 * 100000, data2 * 100000

# A latitude axis that starts north of the equator is assumed to run
# north-to-south: reverse it, and the data's latitude axis with it.
if lats[0] > 0:
    lats = lats[::-1]
    data1, data2 = data1[:, :, ::-1, :], data2[:, :, ::-1, :]

###
# Requested section bounds. The right-hand names are not defined in this
# file; they appear to be template placeholders substituted before running.
in_slat = SLAT
in_elat = ELAT
in_slon = SLON
in_elon = ELON
in_sdepth = SDEPTH
in_edepth = EDEPTH

# Nearest grid index for each bound.
lat_idx_start = geo_idx(in_slat, lats)
lat_idx_end = geo_idx(in_elat, lats)
# NOTE(review): unlike the other bounds, the start longitude is shifted one
# grid cell east of the nearest point -- confirm this is intentional.
lon_idx_start = geo_idx(in_slon, lons)+1
lon_idx_end = geo_idx(in_elon, lons)
depth_idx_start = geo_idx(in_sdepth, levs)
depth_idx_end = geo_idx(in_edepth, levs)

# Inclusive index ranges along each axis.
ilats = range(lat_idx_start, lat_idx_end + 1)
jlons = range(lon_idx_start, lon_idx_end + 1)
kdepth = range(depth_idx_start, depth_idx_end + 1)

# The time axis is intentionally left in its native units: the plot title
# divides by 24 itself when the time units are 'hours'.

NumLats = len(ilats)
NumLons = len(jlons)
NumDepth = len(kdepth)

#
#  average along the latitude
#
# cos(latitude) weights for the area-weighted mean over the latitude window.
latr = np.deg2rad(lats[lat_idx_start:lat_idx_end+1])
weights = np.cos(latr)

# Fixed contour levels for the difference panel; a wider range is used for
# 'thetao'/'so' (presumably temperature/salinity).
if 'thetao' in nc_vars or 'so' in nc_vars:
    con_levels = np.arange(-2.0, 2, 0.2)
else:
    con_levels = np.arange(-0.5, 0.5, 0.05)

#
# PLOT
#
time_idx = [TIME_INDEX]

# Index windows of the requested section (inclusive bounds).
lat_sel = slice(lat_idx_start, lat_idx_end + 1)
lon_sel = slice(lon_idx_start, lon_idx_end + 1)
depth_sel = slice(depth_idx_start, depth_idx_end + 1)

# Depth rows are plotted in reverse order: contourf draws row 0 at the
# bottom, so (assuming levs increases with depth) the deepest selected
# level ends up at the bottom of each panel. Exactly NumDepth labels,
# one per y tick.
depth_ticklabels = levs[depth_sel][::-1].astype(int)

# One longitude tick every LonTickInterval grid columns, labelled with the
# actual longitude of that column so labels match ticks on any grid spacing.
LonTickInterval = 30
lon_ticks = range(0, NumLons, LonTickInterval)
lon_ticklabels = lons[lon_sel][::LonTickInterval].astype(int)

# 'wo' data were multiplied by 1e5, so its units are shown with a 1E-5 prefix.
unit_prefix = '1E-5 ' if 'wo' in nc_vars else ''
lead_in_hours = nc_mod1.variables[ntime].units == 'hours'


def _lat_mean_section(data, t):
    """Cos-latitude-weighted mean over the latitude window at time index t.

    Returns a (depth, lon) field whose depth rows are reversed as described
    above.
    """
    window = data[t, depth_sel, lat_sel, lon_sel][::-1]
    return np.ma.average(window, axis=1, weights=weights)


def _label_contours(lines):
    """Label every line of a contour set with its level to one decimal."""
    lines.levels = [nf(val) for val in lines.levels]
    plt.clabel(lines, lines.levels, colors='k', fontsize=7, inline=True, fmt='%1.1f')


def _decorate_panel(filled, title, ylabels=None):
    """Add ticks, axis labels, title and a horizontal colour bar to the
    current subplot. Depth tick labels are drawn only when ylabels is given;
    otherwise the y ticks are left blank."""
    ax = plt.gca()
    ax.set_xticks(lon_ticks)
    ax.set_xticklabels(lon_ticklabels, fontsize=7)
    if ylabels is not None:
        plt.yticks(rotation=45)
    ax.set_yticks(range(0, NumDepth, 1))
    if ylabels is None:
        ax.set_yticklabels([''] * NumDepth)
    else:
        ax.set_yticklabels(ylabels, fontsize=7)
        plt.ylabel('Depth(m)', fontsize=7)
    plt.xlabel('Longitude', fontsize=7)
    plt.title(title, fontsize=9, loc='left')
    cbar = plt.colorbar(filled, orientation='horizontal', format="%1.1f")
    cbar.ax.tick_params(labelsize=7)


for t in time_idx:
    Hovdata = _lat_mean_section(data1, t)
    Hovref = _lat_mean_section(data2, t)
    diff = Hovdata - Hovref

    plt.close('all')
    fig = plt.figure()
    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.10, top=0.75)
    lead = time[t] / 24.0 if lead_in_hours else time[t]
    fig.text(0.5, 0.80,
             "MODEL %s (%s%s) on %s lead days :\n lat : %1.1f ~ %1.1f  lon : %1.1f ~ %1.1f "
             % (nc_mod1.variables[var_name].long_name, unit_prefix,
                nc_mod2.variables[var_name].units, lead,
                lats[lat_idx_start], lats[lat_idx_end],
                lons[lon_idx_start], lons[lon_idx_end]),
             ha='center', fontsize=12)

    plt.subplot(131)
    filled = plt.contourf(Hovdata, 15, extend="both")
    _label_contours(plt.contour(filled, colors='k', linewidths=0.5))
    _decorate_panel(filled, "TYPE1", depth_ticklabels)

    plt.subplot(132)
    filled = plt.contourf(Hovref, 15, extend="both")
    _label_contours(plt.contour(filled, colors='k', linewidths=0.5))
    _decorate_panel(filled, "TYPE2")

    plt.subplot(133)
    filled = plt.contourf(diff, con_levels, extend="both")
    _label_contours(plt.contour(filled, con_levels, colors='k', linewidths=0.5))
    _decorate_panel(filled, "TYPE1-TYPE2")

    pdf.savefig(fig)
    plt.clf()

# Finish the PDF and close both input NetCDF files.
pdf.close()
nc_mod1.close()
nc_mod2.close()
