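# Plot latitude-averaged time-longitude (Hovmoller) sections of a model field, a reference field and their difference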

import matplotlib # use the Agg environment to generate an image rather than outputting to screen
import matplotlib.pyplot as plt
import numpy as np
from netCDF4 import Dataset  # http://code.google.com/p/netcdf4-python/
from matplotlib.backends.backend_pdf import PdfPages

def geo_idx(dd, dd_array):
    """Return the index of the element of ``dd_array`` nearest to ``dd``.

    ``dd_array`` must support elementwise subtraction (a NumPy array, masked
    or plain).  Ties resolve to the lowest index, as with ``argmin``.
    """
    distance = np.abs(dd_array - dd)
    return distance.argmin()

def ncdump(nc_mod, verb=False):
    '''
    ncdump outputs dimensions, variables and their attribute information.
    The information is similar to that of NCAR's ncdump utility.
    ncdump requires a valid instance of Dataset.

    Every print() call below receives one pre-formatted string, so the output
    is identical under Python 2 (where print is a statement) and Python 3.

    Parameters
    ----------
    nc_mod : netCDF4.Dataset
        A netCDF4 dataset object
    verb : bool
        whether or not nc_attrs, nc_dims, and nc_vars are printed

    Returns
    -------
    nc_attrs : list
        A Python list of the NetCDF file global attributes
    nc_dims : list
        A Python list of the NetCDF file dimensions
    nc_vars : list
        A Python list of the NetCDF file variables
    '''
    def print_ncattr(key):
        """
        Prints the NetCDF variable attributes for a given key

        Parameters
        ----------
        key : str
            a netCDF4.Dataset.variables key; a key with no matching variable
            (e.g. a dimension without a coordinate variable) prints a warning
        """
        try:
            variable = nc_mod.variables[key]
        except KeyError:
            print("\t\tWARNING: %s does not contain variable attributes" % key)
            return
        print("\t\ttype: %r" % (variable.dtype,))
        for ncattr in variable.ncattrs():
            print("\t\t%s: %r" % (ncattr, variable.getncattr(ncattr)))

    # NetCDF global attributes
    nc_attrs = nc_mod.ncattrs()
    if verb:
        print("NetCDF Global Attributes:")
        for nc_attr in nc_attrs:
            print("\t%s: %r" % (nc_attr, nc_mod.getncattr(nc_attr)))

    # Dimension names and sizes.
    nc_dims = list(nc_mod.dimensions)
    if verb:
        print("NetCDF dimension information:")
        for dim in nc_dims:
            print("\tName: %s" % (dim,))
            print("\t\tsize: %s" % (len(nc_mod.dimensions[dim]),))
            print_ncattr(dim)

    # Variable information; coordinate variables named after a dimension
    # were already reported in the dimension section above.
    nc_vars = list(nc_mod.variables)
    if verb:
        dim_names = set(nc_dims)
        print("NetCDF variable information:")
        for var in nc_vars:
            if var in dim_names:
                continue
            variable = nc_mod.variables[var]
            print("\tName: %s" % (var,))
            # Wrap in a 1-tuple: dimensions is itself a tuple.
            print("\t\tdimensions: %s" % (variable.dimensions,))
            print("\t\tsize: %s" % (variable.size,))
            print_ncattr(var)
    return nc_attrs, nc_dims, nc_vars

class nf(float):
    """Float subclass with a compact repr, used for contour-level values.

    ``repr`` rounds to one decimal place and drops a trailing ``.0``:
    ``2.0 -> '2'``, ``2.04 -> '2'``, ``2.5 -> '2.5'``.  Values that round to
    zero are shown as ``'0'``, never ``'-0'``.
    """

    def __repr__(self):
        value = float(self)
        text = '%.1f' % value
        if text.endswith('0'):
            # Whole number at one-decimal precision: show it without decimals,
            # normalising negative zero (e.g. -0.04 would give '-0') to '0'.
            text = '%.0f' % value
            if text == '-0':
                text = '0'
        return text
            
#############################
# Input files and the variable to compare.  The bare upper-case names in this
# block are placeholders, presumably substituted by an external templating
# step before the script runs -- TODO confirm.
nc_r = './ref_data.nc'
nc_ref = Dataset(nc_r, 'r')
var_name = 'VAR_NAME'
# Indexed below as (time, lat, lon), like the model field -- confirm per file.
ref = nc_ref.variables[var_name][:]

nc_f = './my_data.nc'
nc_mod = Dataset(nc_f, 'r')
nc_attrs, nc_dims, nc_vars = ncdump(nc_mod)

# Coordinate variables may use either the long or the short name.
lons = nc_mod.variables['longitude' if 'longitude' in nc_vars else 'lon'][:]
lats = nc_mod.variables['latitude' if 'latitude' in nc_vars else 'lat'][:]
# Only needed for value-based time selection via geo_idx(in_stime, time);
# the time window below is currently selected by raw index.
time = nc_mod.variables['time' if 'time' in nc_vars else 'leadtime'][:]

data = nc_mod.variables[var_name][:]  # indexed below as (time, lat, lon)

# A positive first latitude is taken to mean the grid runs north to south;
# flip it so latitude increases.  The reference field is flipped together with
# the model field.
# NOTE(review): assumes ref_data.nc uses the same latitude order as
# my_data.nc -- confirm.
if lats[0] > 0:
    lats = lats[::-1]
    data = data[:, ::-1, :]
    ref = ref[:, ::-1, :]

# Requested region and lead-time window.
in_slat = SLAT
in_elat = ELAT
in_slon = SLON
in_elon = ELON
in_stime = STIME
in_etime = ETIME

lat_idx_start = geo_idx(in_slat, lats)
lat_idx_end = geo_idx(in_elat, lats)
# NOTE(review): unlike latitude, the longitude start index is moved one grid
# point past the nearest match -- intent unclear, confirm.
lon_idx_start = geo_idx(in_slon, lons) + 1
lon_idx_end = geo_idx(in_elon, lons)
time_idx_start = in_stime
time_idx_end = in_etime

ilats = range(lat_idx_start, lat_idx_end + 1)
jlons = range(lon_idx_start, lon_idx_end + 1)
ktime = range(time_idx_start, time_idx_end + 1)
NumLats = len(ilats)
NumLons = len(jlons)
NumTimes = len(ktime)
if 0 in (NumLats, NumLons, NumTimes):
    raise ValueError(
        "empty selection (lat: %d, lon: %d, time: %d points); each end bound "
        "must not precede its start bound" % (NumLats, NumLons, NumTimes))

# cos(latitude)-weighted mean over the latitude band: (time, lat, lon) -> (time, lon).
tsel = slice(time_idx_start, time_idx_end + 1)
ysel = slice(lat_idx_start, lat_idx_end + 1)
xsel = slice(lon_idx_start, lon_idx_end + 1)
weights = np.cos(np.deg2rad(lats[ysel]))
Hovdata = np.ma.average(data[tsel, ysel, xsel], axis=1, weights=weights)
Hovref = np.ma.average(ref[tsel, ysel, xsel], axis=1, weights=weights)
diff = Hovdata - Hovref

# Fixed difference-contour levels, chosen by which known variable name appears
# among the model file's variables.
if 'tp' in nc_vars:
    con_levels = np.arange(-4.0, 4, 0.4)
elif any(name in nc_vars for name in ('hfls', 'hfss', 'rlds', 'rsds')):
    con_levels = np.arange(-100.0, 100, 10)
elif 'msl' in nc_vars:
    con_levels = np.arange(-1000.0, 1000, 100)
elif 'tcc' in nc_vars:
    con_levels = np.arange(-50.0, 50, 5)
elif 'sst' in nc_vars:
    con_levels = np.arange(-2.0, 2.0, 0.2)
else:
    con_levels = np.arange(-2.0, 2, 0.5)

# Longitude ticks every LonTickInterval grid points, each labelled with the
# longitude actually found at that point, so tick and label counts always
# match whatever the grid spacing.
LonTickInterval = 30
lon_ticks = np.arange(0, NumLons, LonTickInterval)
lon_labels = np.asarray(lons[xsel])[lon_ticks].astype(int)


def _draw_panel(position, field, fill_levels, title, line_levels=None,
                show_time_axis=False):
    """Draw one (time, lon) panel and return its filled ContourSet.

    Draws filled contours of ``field`` at ``fill_levels``, black labelled
    contour lines on top (passing ``line_levels`` when given), the shared
    longitude axis, the title and a horizontal colour bar.  Only a panel with
    ``show_time_axis`` gets lead-time tick labels and a y-axis label.
    """
    ax = plt.subplot(position)
    filled = plt.contourf(field, fill_levels, extend="both")
    if line_levels is None:
        lines = plt.contour(filled, colors='k', linewidths=0.5)
    else:
        lines = plt.contour(filled, line_levels, colors='k', linewidths=0.5)
    lines.levels = [nf(val) for val in lines.levels]
    plt.clabel(lines, lines.levels, colors='k', fontsize=7, inline=True,
               fmt='%1.1f')
    ax.set_xticks(lon_ticks)
    ax.set_xticklabels(lon_labels, fontsize=7)
    ax.set_yticks(range(NumTimes))
    if show_time_axis:
        ax.set_yticklabels(list(ktime), fontsize=5, rotation=45)
        ax.set_ylabel('Lead Time(days)', fontsize=7)
    else:
        ax.set_yticklabels([''] * NumTimes)
    ax.set_xlabel('Longitude', fontsize=7)
    ax.set_title(title, fontsize=9, loc='left')
    cbar = plt.colorbar(filled, orientation='horizontal', format="%1.1f")
    cbar.ax.tick_params(labelsize=7)
    return filled


plt.close('all')
fig = plt.figure()
fig.subplots_adjust(left=0.1, right=0.9, bottom=0.00, top=0.85)

var = nc_mod.variables[var_name]
# 'tp' is labelled mm/day; every other variable uses its units attribute.
units = 'mm/day' if 'tp' in nc_vars else var.units
fig.text(0.5, 0.90,
         "%s (%s) :\n lat : %1.1f ~ %1.1f  lon : %1.1f ~ %1.1f "
         % (var.long_name, units, lats[lat_idx_start], lats[lat_idx_end],
            lons[lon_idx_start], lons[lon_idx_end]),
         ha='center', fontsize=12)

model_fill = _draw_panel(131, Hovdata, 15, "MODEL", show_time_axis=True)
# The reference panel reuses the model panel's levels so they are comparable.
_draw_panel(132, Hovref, model_fill.levels, "REF")
_draw_panel(133, diff, con_levels, "MODEL-REF", line_levels=con_levels)

# Create the PDF only once the figure is complete, so a failure above does
# not leave an empty output file behind.
pdf = PdfPages('out.pdf')
pdf.savefig(fig)
pdf.close()
plt.close(fig)

# Close both input NetCDF files.
nc_mod.close()
nc_ref.close()
