from fabric import Connection
import os
from typing import Dict, Sequence, Union
import appdirs
import sqlite3
import uuid
from filelock import FileLock
import numpy as np


def log_directory():
    """
    Return the ``canesm-ensemble`` user data directory, creating it if it does not exist.

    Returns
    -------
        path given by ``appdirs.user_data_dir('canesm-ensemble')``

    Raises
    ------
    OSError
        If the directory is missing and cannot be created (e.g. no permission, or the path is a file)
    """
    path = appdirs.user_data_dir('canesm-ensemble')
    # exist_ok only tolerates an already existing *directory*; genuine failures are reported here
    # instead of surfacing later as confusing errors when files are written into the directory
    os.makedirs(path, exist_ok=True)
    return path


class RemoteFile:
    """
    Copy a remote file from the server, open it for changes, and upon exiting copy it back to the server and
    remove the local copy

    The file is only copied back when `mode` allows writing (contains 'w', 'a', 'x' or '+').

    Parameters
    ----------
    filename:
        path of the file on the remote machine
    machine:
        name of remote machine
    user:
        account on remote machine used for ssh connection
    gateway_conn:
        ssh gateway
    mode:
        mode passed to ``open`` for the local copy
    """
    def __init__(self, filename: str, machine: str, user: str,
                 gateway_conn: str = 'sci-eccc-in.science.gc.ca', mode: str = 'r'):
        self.filename = filename
        # unique local name so concurrent RemoteFile instances never share a temporary file
        self.tmp_filename = os.path.join(log_directory(), str(uuid.uuid4()) + '.tmp')
        # kept as a public attribute; the transfers below open their own short-lived connections
        self.conn = Connection(machine, user, gateway=Connection(gateway_conn, user))
        self.mode = mode
        self.machine = machine
        self.user = user
        self.gateway_conn = gateway_conn

    @staticmethod
    def _is_write_mode(mode: str) -> bool:
        """Return True if a file opened with `mode` can be modified ('w', 'a', 'x' or '+' present)."""
        return any(flag in mode for flag in 'wax+')

    def _connect(self):
        """Create a new ssh connection to `machine` through the gateway."""
        return Connection(self.machine, self.user, gateway=Connection(self.gateway_conn, self.user))

    def __enter__(self):
        try:
            with self._connect() as c:
                c.get(self.filename, self.tmp_filename)
            self.open_file = open(self.tmp_filename, self.mode)
        except BaseException:
            # __exit__ is not run when __enter__ raises, so remove any (partial) local copy here
            if os.path.exists(self.tmp_filename):
                os.remove(self.tmp_filename)
            raise
        return self.open_file

    def __exit__(self, *args):
        self.open_file.close()
        if self._is_write_mode(self.mode):
            # if the upload raises, the local copy is left in place so the edits are not lost
            with self._connect() as c:
                c.put(self.tmp_filename, self.filename)
        os.remove(self.tmp_filename)


class RemoteDBConn:
    """
    Helper for sqlite3 connections to databases on remote machines. Given a remote connection
    copy the database file to the local machine, and open a connection. On exit close the connection
    put the file back on remote and get rid of the local copy. While active the database file is locked
    to avoid syncing issues, but this won't be perfect.

    Parameters
    ----------
    filename:
        name of database file
    machine:
        name of remote machine
    user:
        account on remote machine used for ssh connection
    gateway_conn:
        ssh gateway

    Examples
    --------

    >>> with RemoteDBConn(db_file, machine, user, gateway_conn) as db_conn:
    ...     c = db_conn.cursor()
    ...     c.execute('select * from table')
    ...     result = c.fetchall()

    .. note:: sqlite3 is not designed for remote
       connections so this is a bit of a hack until a proper database is implemented.
       The connection is closed without committing, so call ``db_conn.commit()`` to keep changes.
    """
    def __init__(self, filename: str, machine: str, user: str, gateway_conn: str = 'sci-eccc-in.science.gc.ca'):
        self.filename = filename
        # local copy is named after the remote basename; the lock below serialises access to it
        self.tmp_filename = os.path.join(log_directory(), os.path.basename(self.filename) + '.tmp')
        self.lock = FileLock(self.tmp_filename + '.lock', timeout=10)
        self.machine = machine
        self.user = user
        self.gateway_conn = gateway_conn

    def _connect(self):
        """Create a new ssh connection to `machine` through the gateway."""
        return Connection(self.machine, self.user, gateway=Connection(self.gateway_conn, self.user))

    def __enter__(self):
        self.lock.acquire(timeout=10)
        try:
            try:
                with self._connect() as c:
                    c.get(self.filename, self.tmp_filename)
            except FileNotFoundError:
                # presumably no database on the remote yet: sqlite3 starts from the local file instead
                pass
            self.db_conn = sqlite3.connect(self.tmp_filename)
        except BaseException:
            # __exit__ is not run when __enter__ raises, so give the lock back here
            self.lock.release()
            raise
        return self.db_conn

    def __exit__(self, *args):
        try:
            self.db_conn.close()
            with self._connect() as c:
                c.put(self.tmp_filename, self.filename)
            os.remove(self.tmp_filename)
        finally:
            # always release, otherwise later sessions can block until the lock times out
            self.lock.release()


class ProcessString(str):
    """
    Abstract base for strings whose parameter values can be rewritten.

    Subclasses override :meth:`process`; otherwise instances behave like ordinary ``str`` objects.
    """

    def process(self, settings: Dict) -> 'ProcessString':
        """
        Replace the parameter values in the string with the values from `settings`.

        Parameters
        ----------
        settings:
            Dictionary of parameters to be replaced

        Returns
        -------
            string with updated settings

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        # name the concrete class so a missing override is easy to track down
        raise NotImplementedError(f'{type(self).__name__} does not implement process()')


class ProcessBash(ProcessString):
    """
    Process a bash file, replacing option=val

    Keys of ``settings`` are handled in two ways:

    * ``name`` -- the value of every assignment ``name=...`` is replaced, including assignments inside
      ``;``-separated lines. Assignments whose value references ``$name`` are left untouched.
    * ``$name`` -- every occurrence of the text ``$name`` is replaced by the value.

    Blank lines and lines starting with ``#`` or ``!`` are never modified.
    """
    def process(self, settings):
        """
        Replace the assignments/variables listed in `settings`.

        Parameters
        ----------
        settings:
            mapping of option name (or ``$variable``) to its new value; values are inserted with ``str``

        Returns
        -------
            ProcessBash with the updated text

        Raises
        ------
        ValueError
            If an option was not replaced anywhere in the text
        """
        # TODO: allow for variables across multiple lines (file_lists)
        # TODO: there must be a cleaner way than looping over all the keys and lines
        lines = []
        found = {key: False for key in settings}
        for line in self.split('\n'):
            stripped = line.strip()
            # skip comments and blank lines to avoid key loop
            if stripped and stripped[0] not in '#!':
                for key in settings:
                    if key + '=' in line.replace(' ', ''):  # remove white space so "var = value" is found
                        line, replaced = self._substitute(line, key, settings[key])
                        # a substring hit such as `start_year=` for key `year` does not count as found
                        found[key] = found[key] or replaced
                    elif key.startswith('$') and key in line:
                        line = self.process_variable(line, key, settings[key])
                        found[key] = True
            lines.append(line)

        for key, was_found in found.items():
            if not was_found:
                raise ValueError('option: ' + key + ' was not found')

        return ProcessBash('\n'.join(lines))

    @staticmethod
    def process_variable(line, key, val):
        """Replace every occurrence of the variable reference `key` (e.g. ``$RUNID``) in `line` by `val`."""
        return line.replace(key, str(val))

    @staticmethod
    def process_line(line, key, val):
        """
        Return `line` with the value assigned to `key` replaced by `val`.

        Everything after the ``=`` is replaced. The line is returned unchanged if it does not assign
        `key`, or if the assigned value references ``$key``.
        """
        return ProcessBash._substitute(line, key, val)[0]

    @staticmethod
    def process_multivarline(line, key, val):
        """Same as :meth:`process_line`; lines with several ``;``-separated statements are handled per statement."""
        return ProcessBash._substitute(line, key, val)[0]

    @staticmethod
    def _substitute(line, key, val):
        """
        Replace the value assigned to `key` in a single line of bash.

        Returns
        -------
            tuple of (possibly updated line, True if an assignment to `key` was replaced)
        """
        if ';' in line:
            # several statements on one line: treat each ';'-separated piece on its own
            pieces = line.split(';')
            replaced = False
            for idx, piece in enumerate(pieces):
                eq = piece.find('=')
                if eq < 0:  # not an assignment (e.g. `then`, or an empty trailing piece)
                    continue
                name = piece[:eq].strip().replace(' ', '')
                if name == key and '$' + key not in piece[eq + 1:]:
                    pieces[idx] = piece[:eq + 1] + str(val)
                    replaced = True
            return ';'.join(pieces), replaced

        eq = line.find('=')
        if eq < 0 or line[:eq].strip() != key:
            return line, False  # reject new_key=val and lines without an assignment
        if '$' + key in line[eq + 1:]:
            return line, False  # if its a reference to the variable do not replace it
        return line[:eq + 1] + str(val), True


class ProcessCPPDef(ProcessString):
    """
    Process C preprocessor text, switching ``#define``/``#undef`` directives on or off.
    """

    # settings value -> directive written for the option
    cpp_option_map = {True: '#define ', False: '#undef '}

    def process(self, settings):
        """
        Set each option in `settings` to ``#define`` (True) or ``#undef`` (False).

        Existing ``#define NAME`` / ``#undef NAME`` lines whose macro name equals an option are rewritten.
        Any value after the name is dropped. Options that do not appear are appended at the end. Other
        preprocessor lines such as ``#ifdef`` are left unchanged.

        Parameters
        ----------
        settings:
            mapping of macro name to True (define) or False (undefine)

        Returns
        -------
            ProcessCPPDef with the updated text
        """
        is_set = {option: False for option in settings}
        lines = []
        for line in self.split('\n'):
            tokens = line.split()
            # match the macro name exactly; a substring test would also rewrite e.g. FOOBAR for option FOO.
            # conditionals (#if, #ifdef, ...) are ignored because only #define/#undef are considered
            if len(tokens) >= 2 and tokens[0] in ('#define', '#undef') and tokens[1] in is_set:
                option = tokens[1]
                line = self.cpp_option_map[settings[option]] + option
                is_set[option] = True
            lines.append(line)

        for option, was_set in is_set.items():
            if not was_set:  # if the option was not found then set it now
                lines.append(self.cpp_option_map[settings[option]] + option)

        return ProcessCPPDef('\n'.join(lines))


def divide_list(l: Sequence, n: int):
    """
    Split a sequence into consecutive slices of length `n`; the final slice may be shorter.

    Parameters
    ----------
    l :
        sequence to split
    n :
        length of each slice

    Yields
    ------
        the slices ``l[0:n]``, ``l[n:2*n]``, ... in order
    """
    yield from (l[start:start + n] for start in range(0, len(l), n))


def year_from_time(time) -> int:
    """
    Return the year of a time value.

    Parameters
    ----------
    time :
        A time in either integer or str (YYYY_mMM) format

    Returns
    -------
        year as an integer
    """
    # the canonical form is 'YYYY_mMM'; the year is everything before the first '_m'
    year_text, _, _ = convert_date(time).partition('_m')
    return int(year_text)


def month_from_time(time, default_month: int = 12) -> int:
    """
    Return the month of a time value.

    Parameters
    ----------
    time :
        A time in either integer or str (YYYY_mMM) format
    default_month:
        Month reported when `time` only gives a year

    Returns
    -------
        month as an integer
    """
    canesm_date = convert_date(time, default_month=default_month)
    month_text = canesm_date.split('_m')[1]
    return int(month_text)


def convert_date(date, default_month=12) -> str:
    """
    Convert the date to the canesm version 'YYYY_mMM'

    Parameters
    ----------
    date :
        One of the following:

        * a string such as ``'2000'``, ``'2000_m01'``, ``'2000m01'``, ``'2000:01'`` or ``'2000-01'``;
        * an integer year (Python or numpy integer);
        * a float within 0.001 of a whole year (Python or numpy float).
    default_month :
        month used when `date` only specifies a year

    Returns
    -------
        date formatted as 'YYYY_mMM'

    Raises
    ------
    ValueError
        If the date cannot be converted
    """
    # bool is a subclass of int but is never a meaningful year
    if isinstance(date, bool):
        raise ValueError('could not interpret date format, try YYYY_mMM')

    if isinstance(date, str):
        try:
            return f'{int(date)}_m{default_month:02}'
        except ValueError:
            pass
        for mdelim in ['_m', 'm', ':', '-']:
            if mdelim in date:
                break
        else:
            raise ValueError('date format not recognized')
        year = int(date.split(mdelim)[0])  # check that year and month are integers
        month = int(date.split(mdelim)[1])
        return f'{year}_m{month:02d}'

    # numpy scalars (e.g. values loaded from a table by pandas) are accepted like the builtin types
    if isinstance(date, (int, np.integer)):
        return f'{int(date)}_m{default_month:02d}'

    if isinstance(date, (float, np.floating)):
        if abs(int(date) - date) > 0.001:
            raise ValueError('could not interpret float as date format, try YYYY_mMM')
        return f'{int(date)}_m{default_month:02d}'

    raise ValueError('could not interpret date format, try YYYY_mMM')


def previous_month(date: str = None, year: int = None, month: int = None) -> str:
    """
    Return the month before the given one as 'YYYY_mMM'.

    Parameters
    ----------
    date:
        date in any format accepted by ``convert_date``; when given, `year` and `month` are ignored
    year:
        year to use when `date` is None
    month:
        month to use when `date` is None

    Returns
    -------
        the preceding month as 'YYYY_mMM'
    """
    if date is not None:
        normalized = convert_date(date)
        year = year_from_time(normalized)
        month = month_from_time(normalized)

    # January rolls back to December of the previous year
    prior_year, prior_month = (year - 1, 12) if month == 1 else (year, month - 1)
    return f'{prior_year}_m{prior_month:02d}'


def add_time(date: str = None, delta: Union[str, int] = None, year: int = None, month: int = None) -> str:
    """
    Shift a date forward (or backward, for negative values) by `delta`.

    Parameters
    ----------
    date:
        starting date in any format accepted by ``convert_date``; when given, `year` and `month` are ignored
    delta:
        amount of time to add, e.g. ``5`` / ``'5'`` (years only) or ``'1_m03'`` (one year and three months)
    year:
        starting year, used when `date` is None
    month:
        starting month, used when `date` is None

    Returns
    -------
        shifted date as 'YYYY_mMM'
    """
    if date is not None:
        start = convert_date(date)
        year, month = year_from_time(start), month_from_time(start)

    # a bare year count is converted with month 0, i.e. "add no months"
    delta_parts = convert_date(delta, default_month=0).split('_m')
    delta_years, delta_months = int(delta_parts[0]), int(delta_parts[1])

    if not delta_months:
        # whole-year shift: the starting month is kept exactly as given
        return f'{year + delta_years}_m{month:02d}'

    carry, new_month = divmod(month + delta_months, 12)
    if new_month == 0:  # "month 0" of a year is December of the previous year
        new_month, carry = 12, carry - 1
    return f'{year + (delta_years + carry)}_m{new_month:02d}'


def table_path(table: str, config_file: str) -> str:
    """
    Get the path to the table file. First the cwd is searched and if the file is not found the path relative to the
    configuration file is tested.

    Parameters
    ----------
    table:
        path of the table file, either relative to the current working directory or to `config_file`
    config_file:
        path of the configuration file that refers to `table`

    Returns
    -------
        `table` unchanged if it names an existing file, otherwise `table` joined onto the directory
        containing `config_file` (no existence check is made on that result)
    """
    if not os.path.isfile(table):
        config_dir = os.path.dirname(os.path.abspath(config_file))
        return os.path.join(config_dir, table)
    return table


def read_table(filename: str):

    """
    Load the text table into a dictionary for parsing by the ensemble code

    Columns are whitespace separated and ``#`` starts a comment. A column named ``group:name`` is
    stored as ``data_dict[group][name]``; any other column is stored as ``data_dict[column]``.

    Parameters
    ----------
    filename :
        path of the run configuration table

    Returns
    -------
        dictionary mapping column names (or groups of ``group:name`` columns) to lists of values
    """
    import pandas as pd
    # sep=r'\s+' is the documented equivalent of the deprecated delim_whitespace=True
    data = pd.read_csv(filename, sep=r'\s+', comment='#')

    data_dict = {}
    for key in data.keys():
        values = list(data[key].values)  # convert to list from array
        if ':' in key:
            parts = key.split(':')
            data_dict.setdefault(parts[0], {})[parts[1]] = values
        else:
            data_dict[key] = values

    return data_dict


def write_table(data: dict, filename: str):
    """
    Write `data` to `filename` as a tab-separated table with a header row and no index column.

    Nested dictionaries are flattened one level into ``key:subkey`` columns, the naming that
    ``read_table`` groups back into nested dictionaries.

    Parameters
    ----------
    data :
        mapping of column name to a list of values, or to a dictionary of such columns
    filename :
        path of the file to write
    """
    import pandas as pd

    table_dict = {}
    for key, value in data.items():
        # flatten the dictionary; isinstance also accepts dict subclasses such as OrderedDict
        if isinstance(value, dict):
            for sub_key, sub_value in value.items():
                table_dict[f'{key}:{sub_key}'] = sub_value
        else:
            table_dict[key] = value

    pd.DataFrame(data=table_dict).to_csv(filename, sep='\t', index=False)


def is_null(value):
    """
    Return True if `value` represents "no value".

    That is ``None``, or a ``str`` equal to 'none', 'null' or 'nil' ignoring case.
    Only exact ``str`` instances are checked for the null spellings.
    """
    if value is None:
        return True
    return type(value) is str and value.lower() in ('none', 'null', 'nil')
