import canesm
import logging
import time
import yaml
import os

from canesm.canesm_setup import CanESMsetup
from canesm.canesm_database import CanESMensembleDB
from canesm.job_submitter import CanESMsubmitter
from canesm.util import divide_list, log_directory
from canesm.util import convert_date, add_time
from canesm.util import read_table, table_path
from canesm.util import RemoteFile
from canesm.exceptions import RemoteError
from fabric import Connection
from typing import List, Union
from threading import Thread


class CanESMensemble:
    """
    A class that generates :py:class:`canesm.CanESMsetup` instances for each of the ensemble members.

    Attributes
    ----------
    ver
        git hash or git branch that will be used to pull the code
    config
        `AMIP` or `ESM`
    runid
        name of the run
    user
        User name on the machine where the job will be run
    run_directory
        Directory name where the code will be stored
    machine
        Name of the machine where the job is run, either `hare` or `brooks`
    ensemble_size : int
        Number of ensemble members that will be generated
    share_member_code : bool
        If True, each new member of the ensemble will link to the code of the first
        member instead of cloning from the git repository.
    share_executables : bool
        If True, each new member of the ensemble will use the executable from the
        first ensemble member. Not currently supported!
    start_time : int, List[int]
        Year at which the simulation begins. If a list is provided it should have
        a length of :py:attr:`ensemble_size`
    stop_time : int, List[int]
        Year at which the simulation ends. If a list is provided it should have
        a length of :py:attr:`ensemble_size`
    tapeload : bool, List[bool]
        If the restart files are stored on tape, this option should be set to True
    restart_dates : int, str, List[int, str]
        Date used to load the restart files. If int, the 12th month is assumed.
    restart_files : str, List[str]
        Name of the run used for the restarts

    Examples
    --------

    >>> esm = CanESMensemble(ver='develop-canesm', config='AMIP', runid='testrun',
    ...                      user='raa000', run_directory='test_folder', machine='hare')
    >>> esm.ensemble_size = 3
    >>> esm.restart_files = 'vsa_v4_01'
    >>> esm.restart_dates = 1990
    >>> esm.start_time = 2000
    >>> esm.stop_time = 2100
    >>> esm.tapeload = True
    >>> esm.pp_rdm_num_pert = [2, 4, 6]
    >>> esm.setup_ensemble()
    """

    def __init__(self, ver: str, config: str, runid: str,
                 user: str, run_directory: Union[str, List[str]], machine: str,
                 gateway_conn: str = 'sci-eccc-in.science.gc.ca'):

        # ensemble options
        self.ensemble_size = 1
        self.restart_files = None
        self.restart_dates = None
        self.pp_rdm_num_pert = None

        # canesm run setup
        self.config = config
        self.ver = ver
        self.runid = runid
        self.start_time = None
        self.stop_time = None
        self.tapeload = False
        self.first_job_number = 0
        self.setup_flags = None

        # machine and user specifications
        self.machine = machine  # 'hare', 'brooks'
        self.run_directory = run_directory  # directory where the code is cloned
        self.user = user  # user name used for ssh connections
        self.gateway_conn = gateway_conn

        self.phys_parm_from_local = False
        self.inline_from_local = False

        # dictionaries containing file-specific setup changes
        self.canesm_cfg = {}
        self.phys_parm = {}
        self.basefile = {}
        self.inline_diag_nl = {}
        self.cpp_defs = {}
        self.restart_options = {}

        self.job_delimiter = '-'
        self.job_str_zeropad = 3

        # loading from a base job
        self.share_member_code = True
        self.share_executables = True

        self.submit_ensemble = False

        self._db_is_setup = False
        self.db = None
        self.submitter = None
        self._jobs = None

        self.config_file = None

        self.max_threads = 1
        self.log_file = os.path.join(log_directory(), self.job_runid(0) + '.log')

        self.logger = logging.getLogger('canesm-ensemble')
        fh = logging.FileHandler(self.log_file)
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        formatter.datefmt = '%Y-%m-%d %H:%M:%S'
        fh.setFormatter(formatter)
        self.logger.addHandler(fh)
        self.logger.setLevel(logging.INFO)

    def __repr__(self):
        ret = f'Ensemble object containing {self.ensemble_size} members:\n\t'
        ret += '\n\t'.join(self.runids)
        ret += f'\nuser: {self.user}\nrunid: {self.runid}\nmachine: {self.machine}'
        return ret

    @classmethod
    def from_config_file(cls, config_file):
        """
        Set up a CanESMensemble instance using a YAML or JSON configuration file.

        Parameters
        ----------
        config_file
            path to file that will be used for setup

        Returns
        -------
            The ensemble class
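
        Examples
        --------
        A minimal sketch of a YAML configuration file; the field names mirror the attributes of
        this class, and the file name is hypothetical::

            ver: develop-canesm
            config: AMIP
            runid: testrun
            user: raa000
            run_directory: test_folder
            machine: hare
            ensemble_size: 3
            restart_files: vsa_v4_01
            restart_dates: 1990
            start_time: 2000
            stop_time: 2100

        >>> esm = CanESMensemble.from_config_file('ensemble.yaml')  # doctest: +SKIP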
        """
        with open(config_file, 'r') as f:
            opt = yaml.load(f, Loader=yaml.SafeLoader)

        # setup options from table if available
        if 'config_table' in opt.keys():
            runs = read_table(table_path(opt['config_table'], config_file))
            opt['ensemble_size'] = len(runs['runid'])
            for key in runs.keys():
                if key in opt.keys():
                    if type(opt[key]) is dict:
                        if set(opt[key]) & set(runs[key]):
                            logging.warning(key + ' values in yaml file will be overwritten by table values')
                        opt[key] = {**opt[key], **runs[key]}
                    else:
                        logging.warning(key + ' values in yaml file will be overwritten by table values')
                        opt[key] = runs[key]
                else:
                    opt[key] = runs[key]

        ens = cls(ver=opt['ver'], config=opt['config'], runid=opt['runid'],
                  user=opt['user'], run_directory=opt['run_directory'], machine=opt['machine'])

        for var in opt.keys():
            setattr(ens, var, opt[var])

        # check if restart_every_x_years option is used and set dates accordingly
        if 'restart_dates' in opt.keys():
            if 'restart_every_x_years' in opt.keys():
                ens.restart_dates = [add_time(convert_date(opt['restart_dates']), i * opt['restart_every_x_years'])
                                     for i in range(opt['ensemble_size'])]
            else:
                ens.restart_dates = opt['restart_dates']

        if 'pp_rdm_num_pert' in opt.keys():
            if opt['pp_rdm_num_pert'] == 'from job number':
                ens.pp_rdm_num_pert = [i * 10 for i in range(0, ens.ensemble_size)]

        ens.broadcast_variables()
        ens.setup_database()
        ens.config_file = config_file
        return ens

    @property
    def jobs(self) -> List[CanESMsetup]:
        """
        Get the :py:class:`CanESMsetup` job for each ensemble member. Jobs are constructed
        locally; nothing is submitted or changed on the remote machine.
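
        A minimal sketch (skipped since it requires the run options to be fully set, as in the
        class-level example):

        >>> [job.runid for job in esm.jobs]  # doctest: +SKIP
        ['testrun-000', 'testrun-001', 'testrun-002']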
        """
        if self._jobs is None:
            self._jobs = [self._setup_job(jobidx, setup_on_remote=False) for jobidx in range(self.ensemble_size)]
        return self._jobs

    @property
    def runids(self) -> List[str]:
        return [self.job_runid(idx) for idx in range(self.ensemble_size)]

    @property
    def executable_folder(self):
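        """
        Path on the remote machine where the shared executables are stored. If the run path
        cannot be resolved remotely, a path relative to ``$RUNPATH`` is returned instead.
        """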
        try:
            folder = os.path.abspath(os.path.join(self.jobs[0].runpath, '..', self.runids[0] + '_executable_folder'))
        except RemoteError:
            folder = os.path.join('$RUNPATH', '..', self.runids[0] + '_executable_folder')

        return folder

    def job_runid(self, job_num: int) -> str:
        """
        Get the runid of the job based on the job index.

        Parameters
        ----------
        job_num:
            Integer between 0 and ensemble_size - 1

        Returns
        -------
            Name of the job
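
        Examples
        --------
        A minimal sketch using the default delimiter and zero padding:

        >>> esm = CanESMensemble(ver='develop-canesm', config='AMIP', runid='testrun',
        ...                      user='raa000', run_directory='test_folder', machine='hare')
        >>> esm.job_runid(0)
        'testrun-000'
        >>> esm.job_runid(4)
        'testrun-004'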
        """
        if type(self.runid) is list:
            rid = self.runid[job_num]
        else:
            rid = self.runid
            if rid[-1] == self.job_delimiter:
                rid = rid[:-1]
            rid = f'{rid}{self.job_delimiter}{job_num + self.first_job_number:0{self.job_str_zeropad}}'
        return rid

    def broadcast_variables(self):
        """
        Broadcast scalar run options so that each is a list with the same length as the ensemble size
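
        Examples
        --------
        A minimal sketch; scalar options are expanded into per-member lists:

        >>> esm = CanESMensemble(ver='develop-canesm', config='AMIP', runid='testrun',
        ...                      user='raa000', run_directory='test_folder', machine='hare')
        >>> esm.ensemble_size = 3
        >>> esm.start_time = 2000
        >>> esm.broadcast_variables()
        >>> esm.start_time
        [2000, 2000, 2000]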
        """

        # broadcast the class attributes
        for var in ['run_directory', 'tapeload', 'setup_flags', 'start_time', 'stop_time',
                    'restart_files', 'restart_dates', 'inline_from_local', 'phys_parm_from_local', 'pp_rdm_num_pert']:
            if type(self.__getattribute__(var)) is list:
                if len(self.__getattribute__(var)) != self.ensemble_size:
                    raise ValueError(f'{var} is not the same length as ensemble size')
            else:
                self.__setattr__(var, [self.__getattribute__(var)] * self.ensemble_size)

        # broadcast the dictionary entries
        for var in [self.canesm_cfg, self.basefile, self.phys_parm, self.inline_diag_nl, self.restart_options]:
            for key in var.keys():
                if type(var[key]) is list:
                    if len(var[key]) != self.ensemble_size:
                        raise ValueError(key + ' is not the same length as ensemble size')
                else:
                    var[key] = [var[key]] * self.ensemble_size

    def _setup_job(self, job_num: int, setup_on_remote: bool = True) -> CanESMsetup:
        """

        Parameters
        ----------
        job_num:
            index of the job that will be set up (from 0 to self.ensemble_size - 1)
        setup_on_remote: optional
            If True sets up the job on remote. If False, all job options are set, but no
            changes are made on remote. (Default is True)

        Returns
        -------
            CanESMsetup
        """
        runid = self.job_runid(job_num)
        job_folder = os.path.join(self.run_directory[job_num], runid)

        job = CanESMsetup(ver=self.ver, config=self.config, runid=runid,
                          user=self.user, run_directory=job_folder, machine=self.machine)

        job.gateway_conn = self.gateway_conn
        job.tapeload = self.tapeload[job_num]
        job.setup_flags = self.setup_flags[job_num]
        job.phys_parm_from_local = self.phys_parm_from_local[job_num]
        job.inline_from_local = self.inline_from_local[job_num]
        job.start_time = self.start_time[job_num]
        job.stop_time = self.stop_time[job_num]

        job.phys_parm = {key: self.phys_parm[key][job_num] for key in self.phys_parm.keys()}
        job.basefile = {key: self.basefile[key][job_num] for key in self.basefile.keys()}
        job.canesm_cfg = {key: self.canesm_cfg[key][job_num] for key in self.canesm_cfg.keys()}
        job.inline_diag_nl = {key: self.inline_diag_nl[key][job_num] for key in self.inline_diag_nl.keys()}
        job.restart_files = {key: self.restart_options[key][job_num] for key in self.restart_options}
        job.cpp_defs = self.cpp_defs

        job.canesm_cfg['parent_runid'] = self.restart_files[job_num]
        job.canesm_cfg['parent_branch_time'] = convert_date(self.restart_dates[job_num])

        if self.pp_rdm_num_pert[job_num] is not None:
            job.phys_parm['pp_rdm_num_pert'] = self.pp_rdm_num_pert[job_num]

        if setup_on_remote:
            base_directory = None
            share_exec = False
            if self.share_member_code and (job_num > 0):
                base_directory = os.path.join(self.run_directory[0], self.job_runid(0))
                if self.share_executables:
                    share_exec = True
            job.setup_job(base_directory=base_directory, share_executables=share_exec,
                          executable_directory=self.executable_folder)
            if job_num == 0 and self.share_executables:
                self._save_executables(job)

        return job

    def _save_executables(self, job: CanESMsetup):
        """
        store the executables so that they can be reused for later runs

        Parameters
        ----------
        job:
            ensemble member where the executables will be copied from
        """

        base_runpath = job.run_command(f'cd {job.run_directory} && . env_setup_file && echo $RUNPATH',
                                       setup_env=False, run_directory='~').stdout.strip()
        base_files = job.run_command(f'ls {base_runpath}',
                                     setup_env=False, run_directory='~').stdout.strip().split('\n')

        job.run_command('mkdir -p ' + self.executable_folder, run_directory='~', setup_env=False)
        for file in base_files:
            if any(ex in file for ex in ['_ab.', '_cpl.exe', '_nemo.exe']):
                src = os.path.join(base_runpath, file)
                dest = os.path.join(self.executable_folder, file)
                job.run_command(f'cp {src} {dest}', run_directory='~', setup_env=False)

    def verify_setup(self):
        """
        Check that the ensemble has a valid setup

        Returns
        -------
            True if the setup is valid, i.e. the required options are set and the ensemble
            members differ in at least one option (or there is only a single member).
            Raises a ValueError otherwise.
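
        Examples
        --------
        A minimal sketch; the required run options must be set before verification:

        >>> esm = CanESMensemble(ver='develop-canesm', config='AMIP', runid='testrun',
        ...                      user='raa000', run_directory='test_folder', machine='hare')
        >>> esm.verify_setup()
        Traceback (most recent call last):
            ...
        ValueError: restart_dates, restart_files, start_time and stop_time must be set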
        """

        for var in [self.restart_dates, self.restart_files, self.start_time, self.stop_time]:
            if var is None:
                raise ValueError('restart_dates, restart_files, start_time and stop_time must be set')

        if self.share_executables and not self.share_member_code:
            self.logger.warning('If executables are going to be shared the `share_member_code` option should be `True`')

        if 'pp_rdm_num_pert' in self.phys_parm.keys():
            raise ValueError('pp_rdm_num_pert is a bit special and should not be set as part of phys_parm. '
                             'Use pp_rdm_num_pert: value instead.')

        if len(set(self.restart_options.keys()) & {'runid_in_a', 'runid_in_c', 'runid_in'}) > 0:
            raise ValueError('runid_in is a bit special and should not be set as part of restart_options. '
                             'Use restart_files: value instead.')

        if len(set(self.restart_options.keys()) & {'date_in_a', 'date_in_c', 'date_in'}) > 0:
            raise ValueError('date_in is a bit special and should not be set as part of restart_options. '
                             'Use restart_dates: value instead.')

        if self.ensemble_size == 1:
            return True

        for var in [self.pp_rdm_num_pert, self.restart_dates, self.restart_files, self.start_time, self.stop_time]:
            if type(var) is list:
                if len(set(var)) > 1:
                    return True

        for var in [self.canesm_cfg, self.basefile, self.phys_parm, self.inline_diag_nl]:
            for key in var.keys():
                if len(set(var[key])) > 1:
                    return True

        raise ValueError('All ensemble members appear to be identical')

    def setup_database(self):
        """
        Set up the ensemble database and job submitter
        """
        if self._db_is_setup:
            return

        try:
            db_file = os.path.join(self.run_directory, self.job_runid(0) + '.db')
        except TypeError:
            db_file = os.path.join(self.run_directory[0], self.job_runid(0) + '.db')
        self.db = CanESMensembleDB(db_file, self.machine, self.user, self.gateway_conn)
        self.submitter = CanESMsubmitter(self.db, delay=0)

    def setup_ensemble(self):
        """
        Set up the ensemble members on the remote machine
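
        A minimal sketch of the typical workflow (skipped since it makes changes on the remote
        machine; the configuration file name is hypothetical):

        >>> esm = CanESMensemble.from_config_file('ensemble.yaml')  # doctest: +SKIP
        >>> esm.setup_ensemble()  # doctest: +SKIP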
        """

        self.broadcast_variables()
        self.verify_setup()
        self.setup_database()

        # make the directory on remote and setup the database
        with Connection(self.machine, self.user, gateway=Connection(self.gateway_conn, self.user)) as c:
            for directory in set(self.run_directory):
                c.run('mkdir -p ' + directory)
        self.db.setup(self.jobs)

        # determine what jobs need to be setup
        issetup = self.db.get(column='setup', keys=self.runids)
        if type(issetup) is int:
            issetup = [issetup]

        self.logger.info('setting up the ensemble')
        for idx in [i for i in range(0, self.ensemble_size) if issetup[i]]:
            self.logger.info(f'{self.runids[idx]} is already set up; setup for this job will be skipped')

        # do the first job to get the code setup for linking
        if not issetup[0]:
            self._setup_job(0)
            self.db.set(column='setup', keys=self.runids[0], values=1)
            if self.submit_ensemble:
                self.submitter.submit(self.runids[0])

        # break the job setup into chunks to avoid overloading the ssh connection
        # TODO: it would be better to use a Queue here so all the jobs in the chunk didn't have to wait for the slowest
        for job_nums in divide_list([i for i in range(1, self.ensemble_size) if not issetup[i]], self.max_threads):
            threads = [Thread(target=self._setup_job, args=(job_num,)) for job_num in job_nums]

            for thread in threads:
                thread.start()
                time.sleep(5)  # avoid making too many ssh connections at once

            for thread in threads:
                thread.join()

            # sqlite database insertion is not threadsafe so keep this out of _setup_job
            for job_num in job_nums:
                runid = self.runids[job_num]
                self.db.set(column='setup', keys=runid, values=1)
                if self.submit_ensemble:
                    self.submitter.submit(runid)

        if self.config_file:
            self.copy_config_to_remote()

    def extend_ensemble(self, years: Union[int, str]):
        """
        Extend each ensemble member

        Parameters
        ----------
        years :
            Length of time by which to extend the runs, given either as an integer number of
            years or as a 'YYYY_mMM' string. Every member will be extended.
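
        Examples
        --------
        A minimal sketch (skipped since it makes changes on the remote machine); the integer
        argument is an assumption based on the type hint:

        >>> esm = CanESMensemble.from_config_file('ensemble.yaml')  # doctest: +SKIP
        >>> esm.extend_ensemble(5)  # doctest: +SKIP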
        """

        for job in self.jobs:
            job.extend_run(years)
            self.db.set(column='submitted', keys=job.runid, values=0)
            self.db.set(column='jobstring', keys=job.runid, values=job.job_str)
            if self.submit_ensemble:
                self.submitter.submit(job.runid)

    def delete_ensemble(self):
        """
        Delete the ensemble from the remote machines.
        Note this only deletes files on the backend, i.e. hare or brooks, and not on ppp1 or ppp2.
        """
        for job in self.jobs[::-1]:  # go from last to first to avoid deleting "env_setup_file" until the end
            job.delete_job()

        with Connection(self.machine, self.user, gateway=Connection(self.gateway_conn, self.user)) as c:
            for directory in set(self.run_directory):
                c.run('rm -rf ' + directory)

    def copy_config_to_remote(self):
        """
        Make a copy of the yaml configuration and table files on the remote machine
        """
        # move copies of the configuration and table files over to the remote machine for safekeeping
        with open(self.config_file, 'r') as f:
            opt = yaml.load(f, Loader=yaml.SafeLoader)
        with Connection(self.machine, user=self.user, gateway=Connection(self.gateway_conn, user=self.user)) as c:
            for idx, directory in enumerate(set(self.run_directory)):
                if idx == 0:
                    remote_yaml = os.path.join(directory,
                                               str(os.path.basename(self.config_file).split('.')[0]) + '-copy.yaml')
                    c.put(os.path.realpath(self.config_file), remote_yaml)
                    with RemoteFile(remote_yaml, self.machine, self.user, self.gateway_conn, mode='a+') as f:
                        f.write(f'\n# setup by canesm-ensemble {canesm.__version__}')
                    c.run('chmod u=rw,g=r,o=r ' + remote_yaml)
                    if 'config_table' in opt.keys():
                        table = table_path(opt['config_table'], self.config_file)
                        remote_table = os.path.join(directory, str(os.path.basename(table).split('.')[0]) + '-copy.txt')
                        c.put(table, remote_table)
                        c.run('chmod u=rw,g=r,o=r ' + remote_table)
                else:
                    c.run(f'ln {remote_yaml} {os.path.join(directory, os.path.basename(remote_yaml))}')
                    if 'config_table' in opt.keys():
                        c.run(f'ln {remote_table} {os.path.join(directory, os.path.basename(remote_table))}')
