import contextlib
import os
import sqlite3
import sys
from io import StringIO

import pandas as pd
from bokeh.io import curdoc
from bokeh.layouts import layout
from bokeh.models import ColumnDataSource, Label, Legend
from bokeh.models.widgets import DataTable, TableColumn
from bokeh.plotting import figure
from bokeh.themes import Theme
from bokeh.transform import factor_cmap
from fabric import Connection

from canesm.scripts.cli_functions import setup_ensemble_class
from canesm.util import RemoteDBConn, convert_date, year_from_time, month_from_time, log_directory


# Hover tooltips for the per-job status bars; each ``@field`` names a column of
# the ``source`` ColumnDataSource built later in this script.
TOOLTIPS = [
    ("job", "@job_name"),
    ("status", "@status"),
    ("start time", "@start_time"),
    ("stop time", "@stop_time"),
    ("error_file", "@error_file"),
    ("completion", "@completion{(0,0)}%"),
    ("most recent event", "@recent_event")
]

# Ensemble configuration (YAML) to monitor. It may be passed as the first script
# argument (with ``bokeh serve`` use ``--args <file>``); without an argument the
# developer default below is used.
# test_file = r'/HOME/rlr/PycharmProjects/canesm-ensemble/monitor_test/pamip11.yaml'
# test_file = r'/HOME/rlr/PycharmProjects/canesm-ensemble/canesm-ensemble/tests/setup_files/share_executable.yaml'
DEFAULT_CONFIG_FILE = r'/HOME/rlr/PycharmProjects/canesm-ensemble/volmip-volc_pinatubo_full.yaml'
test_file = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_CONFIG_FILE
ensemble = setup_ensemble_class(test_file)
ensemble_size = ensemble.ensemble_size
if isinstance(ensemble.runid, list):
    # Base runid from the first member: drop its last '-'-separated field.
    # NOTE(review): the remaining fields are joined with '' (hyphens removed) — confirm intended.
    runid = ''.join(ensemble.runid[0].split('-')[:-1])
else:
    runid = ensemble.runid
# NOTE(review): not referenced elsewhere in the visible code; presumably maps
# machine names to associated hosts — confirm before relying on it.
machine_map = {'hare': 'ppp1', 'brooks': 'ppp2'}


def run_command(ens, command: str,
                setup_env: bool = True,
                machine: str = None,
                jobname: str = None,
                conn: Connection = None):
    """
    Run a shell command on the remote machine from an ensemble member's run directory.

    Parameters
    ----------
    ens:
        ensemble used to determine machine and jobname parameters
    command :
        command that will be run on the remote machine
    setup_env :
        if True, ``env_setup_file`` in the run directory is sourced before
        ``command`` so that it runs inside the job environment
    machine :
        remote machine to run the command on; defaults to ``ens.machine``.
        Only used when ``conn`` is not given.
    jobname :
        runid of the ensemble member whose run directory is used. If None the
        first entry of ``ens.run_directory`` is used; if no job has this runid a
        warning is printed and the first job's run directory is used.
    conn :
        Connection to remote machine. If not set, a new connection (through
        ``ens.gateway_conn``) is opened for this call and closed afterwards.
        A supplied connection is left open.

    Returns
    -------
    The result object returned by ``Connection.run``; callers read its ``.stdout``.
    """
    if machine is None:
        machine = ens.machine

    if jobname is None:
        run_dir = ens.run_directory[0]
    else:
        for job in ens.jobs:
            if job.runid == jobname:
                run_dir = job.run_directory
                break
        else:
            print('could not find jobname in list of runids - this may be an error')
            run_dir = ens.jobs[0].run_directory

    if setup_env:
        command = '. env_setup_file && ' + command

    # Remote stdout goes to this in-memory buffer instead of the console.
    txt = StringIO('')
    if conn is None:
        # Dedicated connection for this call; the ``with`` below closes it.
        conn_ctx = Connection(machine, user=ens.user, gateway=Connection(ens.gateway_conn, user=ens.user))
    else:
        # Borrow the caller's connection without closing it on exit.
        conn_ctx = contextlib.nullcontext(conn)

    with conn_ctx as active_conn, active_conn.cd(run_dir):
        return active_conn.run(command, out_stream=txt)


# TODO: Find a more accurate way to determine job completion
def get_most_recent_year(ens, jobnum, conn=None):
    """
    Return the latest year encoded in file names under ``$RUNPATH``.

    ``$RUNPATH`` is listed on the remote machine (filtered with ``grep ab``)
    from member ``jobnum``'s run directory. For every name containing ``'_m'``,
    the four characters immediately before the first ``'_m'`` are parsed as an
    integer. The largest such value is returned, or 0 if there is none (or all
    are negative). Raises ValueError if those characters are not an integer.
    """
    # runpath = r'/space/hall2/work/eccc/crd/ccrn/users/ccc103/pamip11-003/data'
    remote_dir = '$RUNPATH'
    listing = run_command(ens, f'ls {remote_dir} | grep ab',
                          machine=ens.machine, jobname=ens.job_runid(jobnum), conn=conn)
    years = [int(name.split('_m')[0][-4:])
             for name in listing.stdout.strip().split('\n')
             if '_m' in name]
    return max([0] + years)


def get_error_files(ens, conn=None):
    """
    List the entries of ``~/.queue`` on the remote machine.

    The command runs from the first member's run directory. An empty
    directory yields ``['']``. The result can be passed to
    :func:`check_for_errors` via ``error_files`` to avoid one remote call per job.
    """
    listing = run_command(ens, 'ls ~/.queue', machine=ens.machine,
                          jobname=ens.job_runid(0), conn=conn)
    return listing.stdout.strip().split('\n')


def check_for_errors(ens, jobnum, conn=None, error_files=None):
    """
    Report whether member ``jobnum`` has a matching file in ``~/.queue``.

    Parameters
    ----------
    ens :
        ensemble object providing ``job_runid(jobnum)``
    jobnum :
        index of the ensemble member
    conn :
        remote connection used only when ``error_files`` is None
    error_files :
        pre-fetched ``~/.queue`` listing; if None it is fetched remotely

    Returns
    -------
    ``(True, name)`` for the first listed file whose name contains the member's
    runid (substring match), otherwise ``(False, '')``.
    """
    if error_files is not None:
        candidates = error_files
    else:
        listing = run_command(ens, 'ls ~/.queue',
                              machine=ens.machine, jobname=ens.job_runid(jobnum), conn=conn)
        candidates = listing.stdout.strip().split('\n')

    for candidate in candidates:
        if ens.job_runid(jobnum) in candidate:
            return True, candidate
    return False, ''


def get_qstat(ens, all_users=False, conn=None):
    """
    Query PBS ``qstat`` on the remote machine and return the job table.

    Parameters
    ----------
    ens :
        ensemble object (provides ``user`` and connection settings)
    all_users :
        if True list every user's jobs, otherwise only ``ens.user``'s
    conn :
        open remote connection; if None each remote command opens its own

    Returns
    -------
    DataFrame with the ``qstat -aw`` columns. ``Jobname`` is replaced by the full
    name from ``qstat -f`` for each job. If no jobs are listed, a single row of
    ``'--'`` placeholders is returned.
    """
    cmd = 'qstat -aw' if all_users else 'qstat -aw -u ' + ens.user
    result = run_command(ens, cmd, setup_env=False, conn=conn)
    # Drop the first five lines (the qstat header block).
    # NOTE(review): assumes the PBS ``qstat -a`` header layout — confirm on new machines.
    body = '\n'.join(result.stdout.split('\n')[5:])
    qstat_columns = ['Job ID', 'Username', 'Queue', 'Jobname', 'SessID',
                     'NDS', 'TSK', "Req'd Memory", "Req'd Time", 'Status', 'Time Used']

    try:
        # sep=r'\s+' is the documented equivalent of the deprecated delim_whitespace=True.
        qstat_df = pd.read_csv(StringIO(body), sep=r'\s+', names=qstat_columns)
    except pd.errors.EmptyDataError:
        # No jobs listed. Scalar column values need a list-like index, so use [0].
        return pd.DataFrame({column: '--' for column in qstat_columns}, index=[0])

    # ``qstat -a`` may truncate job names; fetch the full name for each job.
    for idx, jobid in enumerate(qstat_df['Job ID'].values):
        jobnum = str(jobid).split('.')[0]
        detail = run_command(ens, 'qstat -f ' + jobnum, setup_env=False, conn=conn)
        # Second line is expected to be "Job_Name = <name>"; split once so '=' in a name survives.
        jobname = detail.stdout.split('\n')[1].split('=', 1)[1].strip()
        qstat_df.loc[idx, 'Jobname'] = jobname
    return qstat_df


def _read_remote_events(conn, filename):
    """
    Download the remote SQLite log ``filename`` and return its ``events`` table.

    The file is copied into ``log_directory()`` under a ``.tmp`` name and read
    with pandas. The temporary copy is deleted even if reading fails.
    """
    tmp_filename = os.path.join(log_directory(), os.path.basename(filename) + '.tmp')
    conn.get(filename, tmp_filename)
    try:
        # A sqlite3 connection used as a context manager only commits/rolls back;
        # ``closing`` is what actually releases the handle.
        with contextlib.closing(sqlite3.connect(tmp_filename)) as db:
            return pd.read_sql_query('SELECT * FROM events', db)
    finally:
        os.remove(tmp_filename)


def get_job_info(ens, jobnum, qstat=None, conn=None, error_files=None):
    """
    Collect monitoring information for one ensemble member.

    Parameters
    ----------
    ens :
        ensemble object; ``ens.jobs[jobnum]`` is the member inspected
    jobnum :
        index of the member in ``ens.jobs``
    qstat :
        DataFrame from :func:`get_qstat`; fetched with ``conn`` if not given
    conn :
        open remote connection. If None, events come from ``job.events``;
        otherwise the member's ``<runid>-log.db`` is downloaded and read.
    error_files :
        pre-fetched ``~/.queue`` listing forwarded to :func:`check_for_errors`

    Returns
    -------
    dict with keys:

    - ``jobname``, ``start_time``, ``stop_time``
    - ``current_year``: Timestamp of the latest ``model-run`` event
    - ``error``, ``error_file``, ``status``, ``most_recent_event``
    - ``job_length``: Timedelta
    - ``completion``: percent
    """
    info = {}
    job = ens.jobs[jobnum]

    filename = os.path.join(job.ccrnsrc, '..', job.runid + '-log.db')
    if conn is None:
        data = job.events
    else:
        data = _read_remote_events(conn, filename)

    # Dates of all 'model-run' events, parsed from the text after the last '-'.
    # NOTE(review): assumes event names end in a date string understood by to_Timestamp.
    model_runs = [to_Timestamp(event.split('-')[-1])
                  for event in data['Event'] if 'model-run' in event]
    if len(model_runs) == 0:
        model_runs = [to_Timestamp(job.start_time), to_Timestamp(job.start_time)]
    elif len(model_runs) == 1:
        # Only one event: measure the run chunk from the job start to that event.
        model_runs = [to_Timestamp(job.start_time), model_runs[0]]
    run_length = model_runs[1] - model_runs[0]

    info['jobname'] = job.runid
    info['start_time'] = pd.Timestamp(f'{job.start_year}-{job.start_month}')
    info['stop_time'] = pd.Timestamp(f'{job.stop_year}-{job.stop_month}')
    # info['current_year'] = get_most_recent_year(ens, jobnum, conn=conn)
    info['current_year'] = model_runs[-1]
    info['error'], info['error_file'] = check_for_errors(ens, jobnum, conn=conn, error_files=error_files)
    info['status'] = 'Not Submitted'
    info['most_recent_event'] = data['Event'].iloc[-1] if len(data) else '--'
    if qstat is None:
        qstat = get_qstat(ens, conn=conn)

    # Status precedence (later wins):
    #   1. Error, if an error file exists.
    #   2. Running / In Queue, if the job appears in qstat.
    #   3. Finished, if one more run chunk would reach the stop time.
    if info['error']:
        info['status'] = 'Error'

    for queued_name, queued_status in zip(qstat['Jobname'], qstat['Status']):
        if info['jobname'] in queued_name:
            info['status'] = 'Running' if queued_status.lower() == 'r' else 'In Queue'
            break

    info['job_length'] = (info['stop_time'] - info['start_time'])
    if info['current_year'] + run_length >= to_Timestamp(job.stop_time):
        info['status'] = 'Finished'

    info['completion'] = (info['current_year'] - info['start_time']) / info['job_length'] * 100
    return info


# def get_text(attr, old, new):
#     proc = subprocess.Popen(new, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
#     qstat_txt.text = str(proc.stdout.read().decode('utf-8'))

def to_Timestamp(date):
    """
    Convert ``date`` to a month-resolution :class:`pandas.Timestamp`.

    The year and month are extracted with ``year_from_time`` and
    ``month_from_time``. They are combined as ``'<year>-<month>'`` and parsed
    by pandas, so the day is the first of the month.
    """
    year = year_from_time(date)
    month = month_from_time(date)
    return pd.Timestamp('{}-{}'.format(year, month))


def update_data():
    """
    Periodic Bokeh callback: re-query the remote machine and refresh the dashboard.

    A single connection (through the gateway) is shared by all remote queries of
    one refresh. The results are pushed into the module-level Bokeh models
    ``source`` (per-job bars and tooltips), ``label`` and ``progress_data``
    (overall progress bar), and ``data_table`` (the user's qstat listing).
    """
    with Connection(ensemble.machine, user=ensemble.user,
                    gateway=Connection(ensemble.gateway_conn, user=ensemble.user)) as c:
        error_files = get_error_files(ensemble, conn=c)
        jobinfo = get_ensemble_info(ensemble, conn=c, error_files=error_files)
        new_df = get_qstat(ensemble, all_users=False, conn=c)

    total = jobinfo['total_years']
    finished = jobinfo['finished_years']
    source.data['status'] = jobinfo['status']
    source.data['completion'] = jobinfo['completion']
    source.data['error_file'] = jobinfo['error_file']
    # The hover tooltip shows this column; it changes as jobs progress.
    source.data['recent_event'] = jobinfo['most_recent_event']

    label.text = f'Finished {int(finished)} of {int(total)} years'
    progress_data.data['finished'] = [finished / total * 100]
    progress_data.data['queued'] = [(total - finished) / total * 100]
    data_table.source = ColumnDataSource(new_df)


def get_ensemble_info(ens, conn=None, error_files=None):
    """
    Gather :func:`get_job_info` for every ensemble member and summarise progress.

    Parameters
    ----------
    ens :
        ensemble object; members ``0 .. ens.ensemble_size - 1`` are inspected
    conn :
        open remote connection shared by all queries (None opens new ones)
    error_files :
        pre-fetched ``~/.queue`` listing; fetched once here if None

    Returns
    -------
    dict mapping each key of :func:`get_job_info` to a list with one entry per
    member (in member order), plus:

    - ``total_years``: summed job length, as a float
    - ``finished_years``: completion-weighted length, as a float

    Both are measured in years of 365 days 05:49:12.
    """
    qstat = get_qstat(ens, conn=conn)
    if error_files is None:
        error_files = get_error_files(ens, conn=conn)
    info = [get_job_info(ens, jobnum, qstat, conn=conn, error_files=error_files)
            for jobnum in range(ens.ensemble_size)]

    # pandas >= 2.0 rejects the ambiguous 'Y' unit. This is the exact length the
    # former pd.Timedelta(1, 'Y') stood for (365.2425 days).
    one_year = pd.Timedelta(days=365, hours=5, minutes=49, seconds=12)
    total_length = sum((job['job_length'] for job in info), pd.Timedelta(0))
    finished_length = sum((job['job_length'] * (job['completion'] / 100) for job in info),
                          pd.Timedelta(0))

    info_dict = {key: [job[key] for job in info] for key in info[0]}
    info_dict['total_years'] = total_length / one_year
    info_dict['finished_years'] = finished_length / one_year
    return info_dict


curdoc().theme = Theme(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'custom_theme.yaml'))

# Job status categories and their bar colors (matched by position).
status_map = {0: 'Finished', 1: 'Running', 2: 'In Queue', 3: 'Error', 4: 'Not Submitted'}
status_color = ['#78c679', '#addd8e', '#d9f0a3', '#C9411B', '#B6B6B6']
status_names = list(status_map.values())

# Initial snapshot. All remote queries share one connection (through the
# gateway). Calling the helpers without ``conn`` would open a new connection
# for every remote command, including one ``qstat -f`` per listed job.
with Connection(ensemble.machine, user=ensemble.user,
                gateway=Connection(ensemble.gateway_conn, user=ensemble.user)) as c:
    jobinfo = get_ensemble_info(ensemble, conn=c)
    df = get_qstat(ensemble, all_users=False, conn=c)

# --- per-job completion bars --------------------------------------------------
job_idx = [str(i + 1) for i in range(ensemble_size)]
source = ColumnDataSource(data=dict(job_idx=job_idx, completion=jobinfo['completion'],
                                    status=jobinfo['status'], job_name=jobinfo['jobname'],
                                    start_time=jobinfo['start_time'], stop_time=jobinfo['stop_time'],
                                    error_file=jobinfo['error_file'], recent_event=jobinfo['most_recent_event']))
# width/height work on both Bokeh 2.x and 3.x; plot_width/plot_height were removed in 3.0.
job_status = figure(x_range=job_idx, height=330, width=600, sizing_mode='stretch_width',
                    toolbar_location='right', tooltips=TOOLTIPS, name='job_status')

# Dummy glyphs drawn off-screen (x=-100) so the legend lists every status,
# even statuses that no job currently has.
items = []
for idx, (color, name) in enumerate(zip(status_color, status_names)):
    items += [(name, [job_status.rect(-100, idx, width=1, height=1, color=color)])]

legend = Legend(items=items, orientation='horizontal', location=(50, 10), label_text_font_size='12pt')

# Bars start at -10 so that jobs at 0% completion remain visible.
job_status.vbar(x='job_idx', top='completion', bottom=-10, width=0.9, source=source, line_color='white',
                fill_color=factor_cmap('status', palette=status_color, factors=status_names))


job_status.yaxis.axis_label = "Completion [%]"
job_status.xgrid.grid_line_color = None
job_status.xaxis.major_tick_line_color = None  # turn off x-axis major ticks
job_status.xaxis.minor_tick_line_color = None  # turn off x-axis minor ticks
job_status.xaxis.major_label_text_font_size = '0pt'  # preferred method for removing tick labels
job_status.y_range.start = -10
job_status.y_range.end = 100
job_status.yaxis.axis_label_text_font_size = '14pt'
job_status.yaxis.axis_label_text_font_style = 'bold'
job_status.yaxis.major_label_text_font_size = '12pt'
job_status.add_layout(legend, 'below')

# --- overall progress bar (finished vs. remaining simulated years, in %) ------
idx = [0]
years = ["finished", "queued"]
colors = ["#31a354", "#B6B6B6"]
total = jobinfo['total_years']
finished = jobinfo['finished_years']
progress_data = ColumnDataSource({'idx': idx,
                                  'finished': [finished / total * 100],
                                  'queued': [(total - finished) / total * 100]})

progress = figure(height=75, width=600, sizing_mode='stretch_width', toolbar_location=None, name='progress')
progress.hbar_stack(years, y='idx', height=1.2, color=colors, source=progress_data)
progress.x_range.start = 0
progress.x_range.end = 100
progress.axis.visible = False
label = Label(y=0, x=50,
              text=f'Finished {int(finished)} of {int(total)} years',
              text_align='center', text_baseline='middle', text_color='#FFFFFF',
              text_font_size='22pt', text_font_style='normal')
progress.add_layout(label)

# --- qstat table for the ensemble user -----------------------------------------
table_source = ColumnDataSource(df)
columns = [TableColumn(field=key, title=key) for key in df.keys()]
data_table = DataTable(source=table_source, columns=columns, sizing_mode='stretch_width',
                       index_position=None, name='data_table')

# Refresh every 30 s (period given in milliseconds).
curdoc().add_periodic_callback(update_data, 30000)
curdoc().add_root(progress)
curdoc().add_root(job_status)
curdoc().add_root(data_table)
# curdoc().add_root(layout([[progress], [p], [data_table]]))
# show(layout([[progress], [p], [data_table]]))

# Values consumed by the page template's summary cards.
curdoc().title = "Bokeh Dashboard"
curdoc().template_variables['stats_names'] = ['users', 'machine', 'sim_years', 'ens_size', 'config']
curdoc().template_variables['stats'] = {'users': {'icon': 'user',
                                                  'value': ensemble.user,
                                                  'label': 'User Account'},
                                        'machine': {'icon': 'server',
                                                    'value': ensemble.machine,
                                                    'label': 'Remote Machine'},
                                        'sim_years': {'icon': 'clock-o',
                                                      'value': jobinfo['total_years'],
                                                      'label': 'Simulation Years'},
                                        'ens_size': {'icon': 'ellipsis-v',
                                                     'value': ensemble_size,
                                                     'label': 'Ensemble Size'},
                                        'config': {'icon': 'globe-americas',
                                                   'value': ensemble.config,
                                                   'label': 'Configuration'}}
