#!/usr/bin/python3
# -*- coding: utf-8 -*-
#
# Copyright (C) 2017-present Linaro Limited
#
# Author: Remi Duraffort <remi.duraffort@linaro.org>
#
# SPDX-License-Identifier: GPL-2.0-or-later

import argparse
import logging
import os
import shutil
import signal
import sys
import traceback
from pathlib import Path

from setproctitle import setproctitle

from lava_common.constants import DISPATCHER_DOWNLOAD_DIR
from lava_common.exceptions import InfrastructureError, JobCanceled, LAVABug, LAVAError
from lava_common.log import YAMLListFormatter, YAMLLogger
from lava_common.yaml import yaml_safe_dump, yaml_safe_load
from lava_dispatcher.device import NewDevice
from lava_dispatcher.parser import JobParser
from lava_dispatcher.worker import get_prefix


def parser():
    """
    Build the lava-run command line parser.

    Options are registered in a fixed order (it drives the usage/help
    output): the general job options, the "logging" group, the
    "configuration files" group, the remaining switches and finally the
    positional job definition.

    :return: the configured argparse.ArgumentParser (nothing is parsed here)
    """
    cli = argparse.ArgumentParser()

    def add_switch(flag, description):
        # Boolean option: False when absent, True when given.
        cli.add_argument(flag, action="store_true", default=False, help=description)

    # General job options
    cli.add_argument(
        "--job-id",
        metavar="ID",
        required=True,
        help="Job identifier. This alters process name for easier debugging",
    )
    cli.add_argument(
        "--output-dir",
        metavar="DIR",
        type=Path,
        required=True,
        help="Directory for temporary resources",
    )
    add_switch(
        "--validate",
        "validate the job file, do not execute any steps. "
        "The description is saved into description.yaml",
    )

    # Where and how often the job logs are sent
    log_opts = cli.add_argument_group("logging")
    log_opts.add_argument(
        "--url", default=None, metavar="URL", help="URL of the server to send logs"
    )
    log_opts.add_argument(
        "--token", metavar="token", default=None, help="token for server authentication"
    )
    log_opts.add_argument(
        "--job-log-interval",
        metavar="SECONDS",
        default=5,
        type=int,
        help="Time between two job log submissions to the server",
    )

    # Configuration files: only the device one is mandatory
    conf_opts = cli.add_argument_group("configuration files")
    conf_opts.add_argument(
        "--device",
        metavar="PATH",
        required=True,
        type=Path,
        help="Device configuration",
    )
    for flag, description in (
        ("--dispatcher", "Dispatcher configuration"),
        ("--env-dut", "DUT environment"),
    ):
        conf_opts.add_argument(
            flag, metavar="PATH", type=Path, default=None, help=description
        )

    add_switch("--skip-sudo-warning", "Allow to run as a non-root user")
    add_switch(
        "--debug", "Start remote pdb right before running the job, for debugging"
    )

    # The job definition is the only positional argument
    cli.add_argument("definition", type=Path, help="Job definition")

    return cli


def setup_logger(options):
    """
    Configure and return the "dispatcher" logger.

    The logger writes to ``<output_dir>/logs.yaml``. That file is also exposed
    through a symlink in a per-job directory under DISPATCHER_DOWNLOAD_DIR;
    any pre-existing directory for this job is removed first. When ``--url``
    was given the logs are additionally sent over HTTP, otherwise they are
    echoed through a stream handler.

    :param options: parsed command line options; ``output_dir``, ``dispatcher``,
        ``job_id``, ``url``, ``token`` and ``job_log_interval`` are read
    :return: the configured logger
    :raises InfrastructureError: when the download directory or the symlink
        cannot be created (the OSError is attached as the cause)
    """
    # Pipeline always log as YAML so change the base logger.
    # Every calls to logging.getLogger will now return a YAMLLogger
    logging.setLoggerClass(YAMLLogger)

    # The logger can be used by the parser and the Job object in all phases.
    logger = logging.getLogger("dispatcher")
    # Save logs to a file for debugging/post-processing.
    log_file = options.output_dir / "logs.yaml"
    file_handler = logging.FileHandler(log_file, mode="w")
    file_handler.setFormatter(YAMLListFormatter())
    logger.addHandler(file_handler)

    # The job directory name is optionally prefixed; the prefix is derived
    # from the dispatcher configuration by get_prefix.
    prefix = ""
    if options.dispatcher is not None:
        prefix = get_prefix(
            yaml_safe_load(options.dispatcher.read_text(encoding="utf-8"))
        )
    # Create the symlink for downloading via http
    # SymLinks is allowed in the apache2 lava-dispatcher.conf
    job_dir = Path(DISPATCHER_DOWNLOAD_DIR) / f"{prefix}{options.job_id}"
    log_download_path = job_dir / "logs.yaml"
    try:
        # Start from a clean directory: drop leftovers of a previous run
        if job_dir.exists():
            shutil.rmtree(job_dir)
        job_dir.mkdir(mode=0o755, parents=True, exist_ok=True)
        log_download_path.symlink_to(log_file)
    except OSError as exc:
        # Chain explicitly so the underlying OS error stays visible
        raise InfrastructureError(f"Unable to create {log_download_path}") from exc

    if options.url is not None:
        logger.addHTTPHandler(
            f"{options.url}/scheduler/internal/v1/jobs/{options.job_id}/logs/",
            options.token,
            options.job_log_interval,
            options.job_id,
        )
    else:
        logger.addHandler(logging.StreamHandler())

    return logger


def parse_job_file(logger, options):
    """
    Build the job object from the job definition and the configuration files.

    The device is created from ``options.device`` before the definition is
    parsed, so a problem with the device file shows up before the pipeline is
    generated.

    :param logger: logger handed over to the job parser
    :param options: parsed command line options; ``device``, ``definition``,
        ``job_id``, ``env_dut`` and ``dispatcher`` are read
    :return: the result of ``JobParser.parse`` (used as the job object by the
        caller)
    """
    # Prepare the pipeline from the file using the parser.
    device = NewDevice(options.device)
    job_parser = JobParser()

    # Load the optional configuration files (this should *not* fail)
    env_dut = None
    if options.env_dut is not None:
        env_dut = options.env_dut.read_text(encoding="utf-8")
    dispatcher_config = None
    if options.dispatcher is not None:
        dispatcher_config = options.dispatcher.read_text(encoding="utf-8")

    # Generate the pipeline
    return job_parser.parse(
        options.definition.read_text(encoding="utf-8"),
        device,
        options.job_id,
        logger=logger,
        dispatcher_config=dispatcher_config,
        env_dut=env_dut,
    )


def cancelling_handler(*_):
    """
    First-stage handler for the cancellation signals.

    Before raising, the handlers are re-registered: SIGHUP, SIGINT, SIGQUIT,
    SIGTERM and SIGUSR2 now go to terminating_handler while SIGUSR1 goes to
    job_not_found_handler. JobCanceled is then raised so that the exception
    goes through all the stack frames, cleaning and reporting the error.
    """
    second_stage = (
        (signal.SIGHUP, terminating_handler),
        (signal.SIGINT, terminating_handler),
        (signal.SIGQUIT, terminating_handler),
        (signal.SIGTERM, terminating_handler),
        (signal.SIGUSR1, job_not_found_handler),
        (signal.SIGUSR2, terminating_handler),
    )
    for signum, handler in second_stage:
        signal.signal(signum, handler)

    raise JobCanceled("The job was canceled")


def job_not_found_handler(*_):
    """
    Signal handler raising JobCanceled to report that the job was deleted
    on the server. All arguments are ignored.
    """
    reason = "The job was deleted on the lava-server"
    raise JobCanceled(reason)


def terminating_handler(*_):
    """
    Handler installed once a first cancellation was received: a further
    signal raises JobCanceled again, telling the user that the job was
    canceled twice. All arguments are ignored.
    """
    reason = "The job was canceled again (too long to cancel)"
    raise JobCanceled(reason)


def main():
    """
    Entry point of lava-run: parse, validate and (unless --validate) run a job.

    Steps, in order: parse the command line, check the effective user,
    create the output directory, set up logging, install the cancellation
    signal handlers, build the job, export the job secrets to the
    environment, write ``description.yaml``, validate and run the job, log
    the final result, close the logger and write ``result.yaml``.

    :return: 0 when the job finished without error, 1 otherwise (also when
        not running as root without --skip-sudo-warning)
    :raises InfrastructureError: when the output directory cannot be created
        (the OSError is attached as the cause)
    """
    # Parse the command line
    options = parser().parse_args()

    # Check that we are running as root
    if not options.skip_sudo_warning and os.geteuid() != 0:
        print("lava-run should be executed as root")
        return 1

    # Set process title for easier debugging
    setproctitle(f"lava-run [job: {options.job_id}]")

    # Should be an absolute directory
    options.output_dir = options.output_dir.resolve()
    # Create the output directory
    try:
        options.output_dir.mkdir(mode=0o755, parents=True, exist_ok=True)
    except OSError as exc:
        # Chain explicitly so the underlying OS error stays visible
        raise InfrastructureError(f"Unable to create {options.output_dir}") from exc

    # Setup the logger as early as possible
    logger = setup_logger(options)
    if not logger:
        print("lava-run failed to setup logging")
        return 1

    # By default, that's a failure
    success = False
    error_help = error_msg = error_type = None
    try:
        # Set the signal handler
        signal.signal(signal.SIGHUP, cancelling_handler)
        signal.signal(signal.SIGINT, cancelling_handler)
        signal.signal(signal.SIGQUIT, cancelling_handler)
        signal.signal(signal.SIGTERM, cancelling_handler)
        signal.signal(signal.SIGUSR1, job_not_found_handler)
        signal.signal(signal.SIGUSR2, cancelling_handler)

        # Parse the definition and create the job object
        job = parse_job_file(logger, options)

        # Add secrets to the environment
        for k, v in job.parameters.get("secrets", {}).items():
            os.environ[k] = v

        # Generate the description
        description = yaml_safe_dump(job.describe())
        (options.output_dir / "description.yaml").write_text(
            description, encoding="utf-8"
        )

        job.validate()
        if not options.validate:
            if options.debug:
                # Imported lazily: only needed when --debug is given
                from remote_pdb import set_trace

                set_trace()
            job.run()

    except LAVAError as exc:
        if isinstance(exc, LAVABug):
            # Log LAVABug traceback so that it is easier to debug
            logger.exception(traceback.format_exc())

        error_help = exc.error_help
        error_msg = str(exc)
        error_type = exc.error_type
    except BaseException as exc:
        # Anything that is not a LAVAError is reported as a LAVABug
        logger.exception(traceback.format_exc())
        error_help = LAVABug.error_help
        error_msg = str(exc)
        error_type = LAVABug.error_type
    else:
        success = True
    finally:
        # Always log the final "lava/job" result, pass or fail
        result_dict = {"definition": "lava", "case": "job"}
        if success:
            result_dict["result"] = "pass"
            logger.info("Job finished correctly")
        else:
            result_dict["result"] = "fail"
            result_dict["error_msg"] = error_msg or "Unknown error"
            result_dict["error_type"] = error_type or "Unknown"
            logger.error(error_help)
        logger.results(result_dict)

    def logger_terminate_handler(*_):
        # Stop the logger and warn on stderr that logs may be truncated
        logger.terminate()
        sys.stderr.write(
            "Job log upload process terminated due to timeout; logs may be incomplete.\n"
        )
        sys.stderr.flush()

    # At this point, the job run/cancellation has finished. Change signal
    # handler here to prevent 'JobCanceled' from being raised again. In case of
    # job cancellation, the logger has FINISH_MAX_DURATION/2 seconds to finish
    # log uploading and close itself. After that another SIGTERM signal will be
    # sent by the worker to terminate the logger. If the logger is not
    # terminated in FINISH_MAX_DURATION seconds, the lava-run process group
    # will be killed by the worker.
    signal.signal(signal.SIGHUP, logger_terminate_handler)
    signal.signal(signal.SIGINT, logger_terminate_handler)
    signal.signal(signal.SIGQUIT, logger_terminate_handler)
    signal.signal(signal.SIGTERM, logger_terminate_handler)
    signal.signal(signal.SIGUSR1, logger_terminate_handler)
    signal.signal(signal.SIGUSR2, logger_terminate_handler)

    # Closing the socket. We are now sure that all messages were sent.
    logger.close()

    # Save the results file
    (options.output_dir / "result.yaml").write_text(
        yaml_safe_dump(result_dict), encoding="utf-8"
    )

    return 0 if success else 1


if __name__ == "__main__":
    # Use the value returned by main() as the process exit status.
    exit_status = main()
    sys.exit(exit_status)
