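# Plot time-longitude (Hovmoller) diagrams of a field averaged over a latitude
# band (cos-latitude weighted) from two NetCDF files, plus their difference.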

import datetime as dt  # Python standard library datetime  module
import numpy as np
from netCDF4 import Dataset  # http://code.google.com/p/netcdf4-python/
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap, addcyclic, shiftgrid

import matplotlib # use the Agg environment to generate an image rather than outputting to screen
from matplotlib.backends.backend_pdf import PdfPages
from mpl_toolkits.axes_grid1 import make_axes_locatable

def geo_idx(dd, dd_array):
    """Return the index of the element of dd_array nearest to the value dd.

    Distance is the absolute difference; on a tie the lowest index wins
    (argmin semantics).
    """
    offsets = np.abs(dd_array - dd)
    nearest = offsets.argmin()
    return nearest

def ncdump(nc_mod, verb=False):
    '''
    ncdump outputs dimensions, variables and their attribute information.
    The information is similar to that of NCAR's ncdump utility.
    ncdump requires a valid instance of Dataset.

    Parameters
    ----------
    nc_mod : netCDF4.Dataset
        A netCDF4 dataset object
    verb : Boolean
        whether or not nc_attrs, nc_dims, and nc_vars are printed

    Returns
    -------
    nc_attrs : list
        A Python list of the NetCDF file global attributes
    nc_dims : list
        A Python list of the NetCDF file dimensions
    nc_vars : list
        A Python list of the NetCDF file variables
    '''
    # Every print below takes one pre-formatted string, so the output is the
    # same whether this runs under Python 2 (print statement) or Python 3.
    def print_ncattr(key):
        """
        Prints the NetCDF file attributes for a given key

        Parameters
        ----------
        key : unicode
            a valid netCDF4.Dataset.variables key
        """
        try:
            variable = nc_mod.variables[key]
        except KeyError:
            # Happens for a dimension that has no variable of the same name.
            print("\t\tWARNING: %s does not contain variable attributes" % key)
            return
        print("\t\ttype: %s" % repr(variable.dtype))
        for ncattr in variable.ncattrs():
            print("\t\t%s: %s" % (ncattr, repr(variable.getncattr(ncattr))))

    # NetCDF global attributes
    nc_attrs = nc_mod.ncattrs()
    if verb:
        print("NetCDF Global Attributes:")
        for nc_attr in nc_attrs:
            print("\t%s: %s" % (nc_attr, repr(nc_mod.getncattr(nc_attr))))
    nc_dims = list(nc_mod.dimensions)  # list of nc dimensions
    # Dimension shape information.
    if verb:
        print("NetCDF dimension information:")
        for dim in nc_dims:
            print("\tName: %s" % dim)
            print("\t\tsize: %s" % len(nc_mod.dimensions[dim]))
            print_ncattr(dim)
    # Variable information (variables named like a dimension were already
    # reported in the dimension section).
    nc_vars = list(nc_mod.variables)  # list of nc variables
    if verb:
        print("NetCDF variable information:")
        for var in nc_vars:
            if var not in nc_dims:
                variable = nc_mod.variables[var]
                print("\tName: %s" % var)
                # dimensions is a tuple: wrap it so % does not unpack it.
                print("\t\tdimensions: %s" % (variable.dimensions,))
                print("\t\tsize: %s" % variable.size)
                print_ncattr(var)
    return nc_attrs, nc_dims, nc_vars

class nf(float):
    """Float whose repr has one decimal place, minus a trailing zero.

    Examples: nf(1.0) -> '1', nf(2.5) -> '2.5', nf(3.04) -> '3'.
    """

    def __repr__(self):
        value = self.__float__()
        one_decimal = '%.1f' % (value,)
        if one_decimal.endswith('0'):
            return '%.0f' % value
        return one_decimal
            
#############################
# Output PDF, then the two input datasets (read-only).
pdf = PdfPages('out.pdf')

nc_r = './data2.nc'
nc_mod2 = Dataset(nc_r, 'r')

# Variable-name checks later in the script use only data1.nc's metadata.
nc_f = './data1.nc'
nc_mod1 = Dataset(nc_f, 'r')
nc_attrs, nc_dims, nc_vars = ncdump(nc_mod1)

# Extract data from NetCDF file
def _pick_name(candidates, available):
    """Return the first of candidates found in available, else the last candidate."""
    for name in candidates[:-1]:
        if name in available:
            return name
    # The final fallback is used without checking, so a missing variable
    # raises KeyError at the lookup below.
    return candidates[-1]

lons = nc_mod1.variables[_pick_name(('longitude', 'rlon', 'lon'), nc_vars)][:]
lats = nc_mod1.variables[_pick_name(('latitude', 'rlat', 'lat'), nc_vars)][:]  # extract/copy the data

ntime = _pick_name(('time', 'leadtime'), nc_vars)
time = nc_mod1.variables[ntime][:]

levs = nc_mod1.variables[_pick_name(('depth', 'lev', 'level'), nc_vars)][:]  # extract/copy the data

var_name='VAR_NAME'
# Both fields are used as 4-D (time, level, lat, lon) arrays: the latitude
# average later in the script slices them with four indices.
data1 = nc_mod1.variables[var_name][:]
data2 = nc_mod2.variables[var_name][:]

# 'wo' is scaled by 1E5; the figure title labels its units as "1E-5 <units>".
if 'wo' in nc_vars:
    data1 = data1 * 100000
    data2 = data2 * 100000

# Put latitudes in ascending order so the geo_idx() start/end indices of the
# requested band give a forward slice. Reverse the data's latitude axis too,
# keeping the level axis, which the 4-index slicing later relies on.
if lats[0] > lats[-1]:
    lats = lats[::-1]
    data1 = data1[:, :, ::-1, :]
    data2 = data2[:, :, ::-1, :]

###
# Analysis window. SLAT/ELAT/SLON/ELON/STIME/ETIME are not defined in this
# file; they appear to be template placeholders substituted before the
# script runs (TODO confirm).
in_slat = SLAT
in_elat = ELAT
in_slon = SLON
in_elon = ELON
in_stime = STIME
in_etime = ETIME

# Nearest-grid-point indices of the window bounds.
lat_idx_start = geo_idx(in_slat, lats)
lat_idx_end = geo_idx(in_elat, lats)
# NOTE(review): only the start longitude is moved one grid point past the
# nearest point; confirm this asymmetry is intended.
lon_idx_start = geo_idx(in_slon, lons)+1
lon_idx_end = geo_idx(in_elon, lons)
time_idx_start = geo_idx(in_stime, time)
time_idx_end = geo_idx(in_etime, time)

ilats = range(lat_idx_start, lat_idx_end + 1)
jlons = range(lon_idx_start, lon_idx_end + 1)
ktime = range(time_idx_start, time_idx_end + 1)

NumLats = len(ilats)
NumLons = len(jlons)
NumTimes = len(ktime)

# Fail early with a clear message rather than deep inside averaging/plotting
# (e.g. when a start bound lies beyond its end bound).
if NumLats == 0 or NumLons == 0 or NumTimes == 0:
    raise ValueError(
        "empty analysis window: %d lats, %d lons, %d times selected"
        % (NumLats, NumLons, NumTimes))

#
#  cos(latitude)-weighted average over the latitude band at level index 0,
#  giving (time, lon) arrays for the Hovmoller panels.
#
time_sel = slice(time_idx_start, time_idx_end + 1)
lat_sel = slice(lat_idx_start, lat_idx_end + 1)
lon_sel = slice(lon_idx_start, lon_idx_end + 1)
weights = np.cos(np.deg2rad(lats[lat_sel]))

Hovdata1 = np.ma.average(data1[time_sel, 0, lat_sel, lon_sel], axis=1, weights=weights)
Hovdata2 = np.ma.average(data2[time_sel, 0, lat_sel, lon_sel], axis=1, weights=weights)

diff = Hovdata1 - Hovdata2
# Fixed contour levels for the difference panel (np.arange excludes the upper
# bound, so the top level is one step below it).
if 'thetao' in nc_vars or 'so' in nc_vars:
    con_levels = np.arange(-2.0, 2.0, 0.2)
else:
    con_levels = np.arange(-0.5, 0.5, 0.05)

#
#  PLOT
#
LonTickInterval = 30   # one longitude tick every 30 grid columns
TimeTickInterval = 6   # one time tick every 6 time steps

# Tick positions are indices into the (time, lon) arrays. Each label is the
# coordinate value at exactly that index, so the label count always matches
# the tick count and labels stay correct for any grid spacing or start time.
lon_ticks = range(0, NumLons, LonTickInterval)
lon_labels = lons[lon_idx_start:lon_idx_end + 1][::LonTickInterval].astype(int)
time_ticks = range(0, NumTimes, TimeTickInterval)
tick_times = time[time_idx_start:time_idx_end + 1][::TimeTickInterval]
# The y axis is labelled in days, so convert hours before building labels.
if getattr(nc_mod1.variables[ntime], 'units', None) == 'hours':
    tick_times = tick_times / 24.0
time_labels = ['%g' % t for t in tick_times]


def _hov_panel(position, field, fill_levels, title, time_labels=None,
               line_levels=None):
    """Draw one time-longitude panel and return its filled ContourSet.

    position    -- subplot code passed to plt.subplot (e.g. 131)
    field       -- 2-D (time, lon) array to contour
    fill_levels -- contourf levels (a count or a sequence of levels)
    title       -- left-aligned panel title
    time_labels -- y tick labels; None blanks them and omits the y-axis label
    line_levels -- contour-line levels; None reuses the filled set's levels

    Reads the module-level lon_ticks, lon_labels and time_ticks.
    """
    plt.subplot(position)
    filled = plt.contourf(field, fill_levels, extend="both")
    if line_levels is None:
        lines = plt.contour(filled, colors='k', linewidths=0.5)
    else:
        lines = plt.contour(filled, line_levels, colors='k', linewidths=0.5)
    lines.levels = [nf(val) for val in lines.levels]
    plt.clabel(lines, lines.levels, colors='k', fontsize=7, inline=True, fmt='%1.1f')

    ax = plt.gca()
    ax.set_xticks(lon_ticks)
    ax.set_xticklabels(lon_labels, fontsize=7)
    ax.set_yticks(time_ticks)
    if time_labels is None:
        ax.set_yticklabels([''] * len(time_ticks))
    else:
        ax.set_yticklabels(time_labels, fontsize=7, rotation=45)
        plt.ylabel('Lead Time(days)', fontsize=7)
    plt.xlabel('Longitude', fontsize=7)
    plt.title(title, fontsize=9, loc='left')
    cbar = plt.colorbar(filled, orientation='horizontal', format="%1.1f")
    cbar.ax.tick_params(labelsize=7)
    return filled


plt.close('all')
fig = plt.figure()
fig.subplots_adjust(left=0.1, right=0.9, bottom=0.00, top=0.85)

src_var = nc_mod1.variables[var_name]
if 'wo' in nc_vars:
    # 'wo' is multiplied by 1E5 earlier in the script.
    unit_label = '1E-5 %s' % src_var.units
else:
    unit_label = src_var.units
fig.text(0.5, 0.90,
         "MODEL %s at %s m (%s) :\n lat : %1.1f ~ %1.1f  lon : %1.1f ~ %1.1f "
         % (src_var.long_name, levs[0], unit_label,
            lats[lat_idx_start], lats[lat_idx_end],
            lons[lon_idx_start], lons[lon_idx_end]),
         ha='center', fontsize=12)

m = _hov_panel(131, Hovdata1, 15, "TYPE1", time_labels=time_labels)
levels = m.levels  # panel 2 shares panel 1's fill levels for comparability
_hov_panel(132, Hovdata2, levels, "TYPE2")
_hov_panel(133, diff, con_levels, "TYPE1-TYPE2", line_levels=con_levels)

pdf.savefig(fig)
plt.clf()

# Close the PDF first, then both NetCDF input files.
for handle in (pdf, nc_mod1, nc_mod2):
    handle.close()
