# plot 2D fields

import datetime as dt  # Python standard library datetime  module
import numpy as np
from netCDF4 import Dataset  # http://code.google.com/p/netcdf4-python/
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap, addcyclic, shiftgrid

from matplotlib.backends.backend_pdf import PdfPages
from mpl_toolkits.axes_grid1 import make_axes_locatable


def ncdump(nc_fid, verb=False):
    '''
    ncdump outputs dimensions, variables and their attribute information.
    The information is similar to that of NCAR's ncdump utility.
    ncdump requires a valid instance of Dataset.

    Parameters
    ----------
    nc_fid : netCDF4.Dataset
        A netCDF4 dataset object
    verb : Boolean
        whether or not nc_attrs, nc_dims, and nc_vars are printed

    Returns
    -------
    nc_attrs : list
        A Python list of the NetCDF file global attributes
    nc_dims : list
        A Python list of the NetCDF file dimensions
    nc_vars : list
        A Python list of the NetCDF file variables
    '''
    # Every print() below receives exactly one pre-formatted string, so the
    # output is identical under Python 2 (print statement) and Python 3
    # (print function) without requiring a __future__ import.

    def print_ncattr(key):
        """
        Prints the NetCDF file attributes for a given key

        Parameters
        ----------
        key : unicode
            a valid netCDF4.Dataset.variables key
        """
        try:
            var = nc_fid.variables[key]
        except KeyError:
            # e.g. a dimension that has no coordinate variable of the same name
            print("\t\tWARNING: %s does not contain variable attributes" % (key,))
            return
        print("\t\ttype: %s" % (repr(var.dtype),))
        for ncattr in var.ncattrs():
            print("\t\t%s: %s" % (ncattr, repr(var.getncattr(ncattr))))

    # NetCDF global attributes
    nc_attrs = nc_fid.ncattrs()
    if verb:
        print("NetCDF Global Attributes:")
        for nc_attr in nc_attrs:
            print("\t%s: %s" % (nc_attr, repr(nc_fid.getncattr(nc_attr))))

    # Dimension names and shape information.
    nc_dims = list(nc_fid.dimensions)
    if verb:
        print("NetCDF dimension information:")
        for dim in nc_dims:
            print("\tName: %s" % (dim,))
            print("\t\tsize: %s" % (len(nc_fid.dimensions[dim]),))
            print_ncattr(dim)

    # Variable information; coordinate variables (named like a dimension)
    # were already reported in the dimension section above.
    nc_vars = list(nc_fid.variables)
    if verb:
        print("NetCDF variable information:")
        for var_name in nc_vars:
            if var_name in nc_dims:
                continue
            var = nc_fid.variables[var_name]
            print("\tName: %s" % (var_name,))
            print("\t\tdimensions: %s" % (var.dimensions,))
            print("\t\tsize: %s" % (var.size,))
            print_ncattr(var_name)
    return nc_attrs, nc_dims, nc_vars

#############################
# Script body: read a (time, lat, lon) field from a NetCDF file and write one
# global contour map per requested time index to a multi-page PDF.
# NOTE(review): the input/output file names, the data-variable name, the list
# of time indices and the model tag in the title look like template
# placeholders filled in by an external driver; keep those tokens intact.

# Variables that are multiplied by 86400 (seconds per day) and labelled mm/day.
# NOTE(review): assumes their native units are kg m-2 s-1 (CMIP naming) — confirm.
FLUX_VARS = ('pr', 'evspsbl')
SECONDS_PER_DAY = 86400.0


def _pick_var_name(nc_vars, candidates):
    """Return the first name in *candidates* that is present in *nc_vars*.

    Falls back to the last candidate, so a file lacking all of them still
    fails with a KeyError naming that variable when it is looked up.
    """
    for name in candidates:
        if name in nc_vars:
            return name
    return candidates[-1]


pdf = PdfPages('out.pdf')
nc_f = './my_data.nc'
nc_fid = Dataset(nc_f, 'r')
try:
    nc_attrs, nc_dims, nc_vars = ncdump(nc_fid)

    # Extract coordinate data; files differ in naming conventions.
    lons = nc_fid.variables[_pick_var_name(nc_vars, ('longitude', 'lon'))][:]
    lats = nc_fid.variables[_pick_var_name(nc_vars, ('latitude', 'lat'))][:]
    # Keep a handle on the chosen time variable so its units are read from the
    # same variable whose values are plotted (not unconditionally 'leadtime').
    time_var = nc_fid.variables[_pick_var_name(nc_vars, ('time', 'leadtime'))]
    time = time_var[:]

    var_name = 'VAR_NAME'
    data_var = nc_fid.variables[var_name]
    air = data_var[:]  # indexed below as [time, lat, lon]

    # Loop-invariant plot settings, computed once.
    is_flux = var_name in FLUX_VARS
    if is_flux:
        cbar_label = "%s (mm/day)" % (var_name)
    else:
        cbar_label = "%s (%s)" % (var_name, getattr(data_var, 'units', ''))
    long_name = getattr(data_var, 'long_name', var_name)
    # Lead time stored in hours is converted to days for the title.
    time_in_hours = getattr(time_var, 'units', None) == 'hours'

    fig = plt.figure()
    fig.subplots_adjust(left=0.1, right=0.9, bottom=0.1, top=0.9)
    m = Basemap(projection='cyl', llcrnrlat=-90, urcrnrlat=90,
                llcrnrlon=0, urcrnrlon=360, resolution='c', lon_0=0)
    lon2d, lat2d = np.meshgrid(lons, lats)
    x, y = m(lon2d, lat2d)

    time_idx = [TIME_INDEX]
    for t in time_idx:
        air1 = air[t, :, :]
        if is_flux:
            air1 = air1 * SECONDS_PER_DAY
        cs = m.contourf(x, y, air1, 21, cmap=plt.cm.Spectral_r)
        m.drawcoastlines()
        m.drawmapboundary()
        m.drawparallels(np.arange(-90., 120., 30.), labels=[1, 0, 0, 0])
        m.drawmeridians(np.arange(0., 360., 60.), labels=[0, 0, 0, 1])
        cbar = plt.colorbar(cs, orientation='horizontal', shrink=0.6)
        cbar.set_label(cbar_label)
        lead = time[t] / 24.0 if time_in_hours else time[t]
        plt.title("%s on %s lead days :\n %s" % (long_name, lead, "MODEL"))

        pdf.savefig(fig)
        plt.clf()  # start the next page from an empty figure
finally:
    # Always finalize the PDF and close the original NetCDF file, even if
    # reading or plotting fails part-way through.
    pdf.close()
    nc_fid.close()
