# Plot 2D difference fields (my_data1.nc minus my_data2.nc) into out.pdf.

import datetime as dt  # Python standard library datetime  module
import numpy as np
from netCDF4 import Dataset  # http://code.google.com/p/netcdf4-python/
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap, addcyclic, shiftgrid
from matplotlib.backends.backend_pdf import PdfPages
from mpl_toolkits.axes_grid1 import make_axes_locatable


def ncdump(nc_mod, verb=False):
    '''
    ncdump outputs dimensions, variables and their attribute information.
    The information is similar to that of NCAR's ncdump utility.
    ncdump requires a valid instance of Dataset.

    Parameters
    ----------
    nc_mod : netCDF4.Dataset
        A netCDF4 dataset object. Only ``ncattrs()``, ``getncattr()``,
        ``dimensions`` and ``variables`` are used.
    verb : bool
        Whether or not nc_attrs, nc_dims, and nc_vars are printed.

    Returns
    -------
    nc_attrs : list
        A Python list of the NetCDF file global attributes
    nc_dims : list
        A Python list of the NetCDF file dimensions
    nc_vars : list
        A Python list of the NetCDF file variables

    Notes
    -----
    Every ``print`` call receives one pre-formatted string, so the output is
    identical under Python 2 (print statement) and Python 3 (print function).
    '''
    def print_ncattr(key):
        """
        Prints the NetCDF file attributes for a given key

        Parameters
        ----------
        key : unicode
            A netCDF4.Dataset.variables key. A dimension with no matching
            variable is reported with a warning instead of raising.
        """
        # Only the lookup can legitimately miss (a dimension without a
        # coordinate variable); keep the try minimal so other errors surface.
        try:
            var_obj = nc_mod.variables[key]
        except KeyError:
            print("\t\tWARNING: %s does not contain variable attributes" % key)
            return
        print("\t\ttype: %s" % repr(var_obj.dtype))
        for ncattr in var_obj.ncattrs():
            print("\t\t%s: %s" % (ncattr, repr(var_obj.getncattr(ncattr))))

    # NetCDF global attributes
    nc_attrs = nc_mod.ncattrs()
    if verb:
        print("NetCDF Global Attributes:")
        for nc_attr in nc_attrs:
            print("\t%s: %s" % (nc_attr, repr(nc_mod.getncattr(nc_attr))))

    # Dimension names and sizes.
    nc_dims = list(nc_mod.dimensions)
    if verb:
        print("NetCDF dimension information:")
        for dim in nc_dims:
            print("\tName: %s" % (dim,))
            print("\t\tsize: %s" % (len(nc_mod.dimensions[dim]),))
            print_ncattr(dim)

    # Variable information; variables named like a dimension were already
    # reported in the dimension section above.
    nc_vars = list(nc_mod.variables)
    if verb:
        print("NetCDF variable information:")
        for var in nc_vars:
            if var in nc_dims:
                continue
            var_obj = nc_mod.variables[var]
            print("\tName: %s" % (var,))
            print("\t\tdimensions: %s" % (var_obj.dimensions,))
            print("\t\tsize: %s" % (var_obj.size,))
            print_ncattr(var)
    return nc_attrs, nc_dims, nc_vars

#############################
# Every figure produced below is appended to this multi-page PDF.
pdf = PdfPages('out.pdf')

# Mask file: 'sftlf' (presumably the land-area fraction in % -- TODO confirm).
# Only its first record along the leading axis is used.
nc_m = './mask.nc'
nc_mask = Dataset(nc_m, 'r')
omask = nc_mask.variables["sftlf"][:]
mask = omask[0, :, :]

# Second dataset: subtracted from the first in every plot below.
nc_r = './my_data2.nc'
nc_mod2 = Dataset(nc_r, 'r')

# First dataset: its variable list (nc_vars) decides which coordinates are
# read and which plot type is drawn for both files.
nc_f = './my_data1.nc'
nc_mod1 = Dataset(nc_f, 'r')
nc_attrs, nc_dims, nc_vars = ncdump(nc_mod1)

# Extract coordinates from the first file. Files containing 'msftmyz' or
# 'msftmyzv' are drawn as a depth/latitude section (needs levels +
# latitudes); all others as a lat/lon map (needs longitudes + latitudes).
# Both plot types need the time axis.


def _first_present(candidates, available):
    """Return the first name in *candidates* that occurs in *available*.

    When none occurs, the last candidate is returned, so the caller's
    ``variables[name]`` lookup raises ``KeyError`` for that final fallback
    name, exactly as a direct lookup of it would.
    """
    for name in candidates:
        if name in available:
            return name
    return candidates[-1]


is_section = 'msftmyzv' in nc_vars or 'msftmyz' in nc_vars
if is_section:
    levs = nc_mod1.variables[_first_present(('depth', 'lev', 'level'), nc_vars)][:]
else:
    lons = nc_mod1.variables[_first_present(('longitude', 'rlon', 'lon'), nc_vars)][:]
lats = nc_mod1.variables[_first_present(('latitude', 'rlat', 'lat'), nc_vars)][:]
# 'time' is preferred; files without it must carry 'leadtime'.
ntime = 'time' if 'time' in nc_vars else 'leadtime'
time = nc_mod1.variables[ntime][:]

# Field to difference, read under the same name from both files. The
# plotting code indexes each as data[t, :, :]: a leading record axis and then
# the two plotted dimensions (assumed lat, lon for maps and level, lat for
# sections -- TODO confirm against the input files).
var_name = 'VAR_NAME'
data1 = nc_mod1.variables[var_name][:]
data2 = nc_mod2.variables[var_name][:]

def _map_contour_levels(var_names):
    """Return the contour levels for a lat/lon difference map.

    The level set is picked from the variable names present, checked in this
    order: 'tos'/'sos'/'hc300', then 't20d'/'mlotst', then 'msftbarot', and
    otherwise a +/-1 default. ``np.arange`` excludes its stop value, so each
    range ends one step below the nominal upper bound.
    """
    if 'tos' in var_names or 'sos' in var_names or 'hc300' in var_names:
        return np.arange(-2.0, 2, 0.2)
    if 't20d' in var_names or 'mlotst' in var_names:
        return np.arange(-100.0, 100, 10)
    if 'msftbarot' in var_names:
        return np.arange(-10.0, 10, 1)
    return np.arange(-1.0, 1, 0.1)


def _set_lead_title(t_index):
    """Title the current axes with the variable's long_name and the time of
    record *t_index*. The time is divided by 24 when the time variable's
    units attribute is exactly 'hours'."""
    lead = time[t_index]
    if nc_mod1.variables[ntime].units == 'hours':
        lead = lead / 24.0
    plt.title("%s on %s lead days :\n %s" % (nc_mod1.variables[var_name].long_name, lead, "MODEL TYPE1-TYPE2"))


time_idx = [TIME_INDEX]  # record indices to plot (template placeholder)
is_section = 'msftmyzv' in nc_vars or 'msftmyz' in nc_vars
units_label = "%s (%s)" % (var_name, nc_mod1.variables[var_name].units)

if not is_section:
    # ---- Lat/lon difference map (data1 - data2), one PDF page per record.
    fig = plt.figure()
    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)
    # The map window depends on the sign of the second longitude; build only
    # the Basemap that is actually used.
    if lons[1] < 0:
        m = Basemap(projection='cyl', llcrnrlat=-90, urcrnrlat=90,
                    llcrnrlon=-280, urcrnrlon=80, resolution='c', lon_0=-279.5)
    else:
        m = Basemap(projection='cyl', llcrnrlat=-90, urcrnrlat=90,
                    llcrnrlon=0, urcrnrlon=360, resolution='c', lon_0=0)
    lon2d, lat2d = np.meshgrid(lons, lats)
    x, y = m(lon2d, lat2d)

    # Loop-invariant: cells where the mask equals 100.0 become NaN, all
    # others are multiplied by 1.
    mask_factor = np.where(mask == 100.0, np.nan, 1)
    con_levels = _map_contour_levels(nc_vars)

    for t in time_idx:
        data = np.subtract(data1[t, :, :], data2[t, :, :])
        if 'msftbarot' in nc_vars:
            data = data / 10000000000.0
        data = data * mask_factor

        # `facecolor` replaces `axisbg`, which matplotlib 2.0 deprecated and
        # 2.2 removed. Presumably the black background is meant to show
        # through the NaN (masked) cells.
        plt.subplot(111, facecolor='black')
        cs = m.contourf(x, y, data, con_levels, cmap=plt.cm.Spectral_r, extend='both')
        m.drawparallels(np.arange(-90., 120., 30.), labels=[1, 0, 0, 0])
        m.drawmeridians(np.arange(0., 360., 60.), labels=[0, 0, 0, 1])
        cbar = plt.colorbar(cs, orientation='horizontal', shrink=0.6)
        cbar.ax.tick_params(labelsize=6)
        cbar.set_label(units_label)
        _set_lead_title(t)
        plt.contour(cs, con_levels, colors='k', linewidths=0.5)
        pdf.savefig(fig)
        plt.clf()
else:
    # ---- Depth/latitude section (data1 - data2), drawn as two stacked
    # panels split at level index n_upper, one PDF page per record.
    lat2d, lev2d = np.meshgrid(lats, -1 * levs)  # depth plotted as negative y
    n_upper = 24  # leading levels drawn in the upper panel; the rest go below
    con_levels = np.arange(-50.0, 50, 5)
    fig = plt.figure()
    fig.subplots_adjust(left=0.2, right=0.9, bottom=0.3, top=0.85)
    for t in time_idx:
        z = np.subtract(data1[t, :, :], data2[t, :, :])
        if 'msftmyz' in nc_vars:
            z = z / 1000000000.0

        ax_upper = plt.subplot(111)
        cf = ax_upper.contourf(lat2d[:n_upper, :], lev2d[:n_upper, :], z[:n_upper, :],
                               con_levels, extend='both')
        plt.contour(cf, con_levels, colors='k', linewidths=1)
        plt.ylabel('Depth (m)', fontsize=8)
        plt.tick_params(labelsize=7)
        _set_lead_title(t)
        cbar = plt.colorbar(cf, orientation='vertical', shrink=0.7)
        cbar.ax.tick_params(labelsize=7)
        cbar.set_label(units_label, fontsize=8)

        # The lower panel starts exactly at n_upper, so no level is skipped
        # (slicing from n_upper + 1 would silently drop level n_upper).
        divider = make_axes_locatable(ax_upper)
        ax_lower = divider.append_axes("bottom", size="100%", pad=0.1, sharex=ax_upper)
        ax_lower.contourf(lat2d[n_upper:, :], lev2d[n_upper:, :], z[n_upper:, :],
                          con_levels, extend='both')
        ax_lower.contour(lat2d[n_upper:, :], lev2d[n_upper:, :], z[n_upper:, :],
                         con_levels, colors='k')
        ax_lower.tick_params(labelsize=7)
        plt.ylabel('Depth (m)', fontsize=8)
        plt.xlabel('Latitude', fontsize=8)
        pdf.savefig(fig)
        plt.clf()
    
# Finish the PDF and close every NetCDF file opened above; previously only
# nc_mod1 was closed and the other two handles were leaked.
pdf.close()
nc_mod1.close()
nc_mod2.close()
nc_mask.close()
