#!/usr/bin/python3
"""
    Copyright (C) 2021 Michael Ablassmeier <abi@grinser.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
import os
import sys
import tempfile
import logging
import argparse
import pprint
import json

from libvirtnbdbackup import argopt
from libvirtnbdbackup import __version__
from libvirtnbdbackup import nbdhelper
from libvirtnbdbackup import libvirthelper
from libvirtnbdbackup import qemuhelper
from libvirtnbdbackup import outputhelper
from libvirtnbdbackup import exceptions as baseexception
from libvirtnbdbackup import sshutil
from libvirtnbdbackup.common import common as lib
from libvirtnbdbackup.logcount import logCount
from libvirtnbdbackup.sparsestream import streamer
from libvirtnbdbackup.sparsestream import types
from libvirtnbdbackup.sparsestream import exceptions


def dump(args, stream, dataFiles):
    """Print the metadata headers of the given backup files as JSON.

    When ``args.disk`` is set, files whose basename does not start with it
    are skipped. Files whose header cannot be read are left out of the
    output. Always returns True.
    """
    logging.info("Dumping saveset meta information")
    headers = []
    for candidate in dataFiles:
        wanted = args.disk is None or os.path.basename(candidate).startswith(
            args.disk
        )
        if not wanted:
            continue
        logging.info(candidate)

        # With --sequence the file names get the input directory prepended.
        path = f"{args.input}/{candidate}" if args.sequence else candidate
        header = getHeader(path, stream)
        if header:
            headers.append(header)
            if lib.isCompressed(header):
                logging.info(
                    "Compressed stream found: [%s].", header["compressionMethod"]
                )

    print(json.dumps(headers, indent=4))
    return True


def _consumeTerminator(reader, term):
    """Read the frame terminator from `reader` and verify it.

    An explicit comparison is used instead of ``assert``: assert statements
    are removed when Python runs with -O, which would also drop the read
    itself and leave the stream positioned in front of the terminator.

    Raises baseexception.RestoreError if the terminator is not found.
    """
    if reader.read(len(term)) != term:
        logging.error("Missing frame terminator before pos: [%s]", reader.tell())
        raise baseexception.RestoreError


def restoreData(args, stream, dataFile, targetFile, nbdClient, connection):
    """Apply one backup data file to the restore target.

    Reads the metadata header of `dataFile`, then walks the frames that
    follow and writes every data frame through `connection` at its recorded
    offset; zero frames are only logged.

    Returns False if `dataFile` cannot be opened, True otherwise (including
    the case where the file contains no data).

    Raises baseexception.RestoreError on stream format errors, missing frame
    terminators, read/write failures or a size mismatch at end of stream.
    Raises baseexception.UntilCheckpointReached after the file whose
    checkpoint name equals `args.until` has been applied.
    """
    sTypes = types.SparseStreamTypes()

    try:
        reader = open(dataFile, "rb")
    except OSError as errmsg:
        logging.critical("Failed to open backup file for reading: [%s].", errmsg)
        return False

    # The context manager closes the backup file on every exit path
    # (early return, exception, normal end of stream).
    with reader:
        try:
            kind, start, length = stream.readFrame(reader)
            meta = stream.loadMetadata(reader.read(length))
        except exceptions.StreamFormatException as errmsg:
            logging.fatal(errmsg)
            raise baseexception.RestoreError from errmsg

        trailer = None
        if lib.isCompressed(meta) is True:
            trailer = stream.readCompressionTrailer(reader)
            logging.info("Found compression trailer.")
            logging.debug("%s", trailer)

        if meta["dataSize"] == 0:
            logging.info("File [%s] contains no dirty blocks, skipping.", dataFile)
            return True

        logging.info(
            "Applying data from backup file [%s] to target file [%s].",
            dataFile,
            targetFile,
        )
        pprint.pprint(meta)
        _consumeTerminator(reader, sTypes.TERM)

        progressBar = lib.progressBar(
            meta["dataSize"], f"restoring disk [{meta['diskName']}]", args
        )
        dataSize = 0
        dataBlockCnt = 0
        while True:
            try:
                kind, start, length = stream.readFrame(reader)
            except exceptions.StreamFormatException as err:
                logging.error(
                    "Cant read stream at pos: [%s]: [%s]", reader.tell(), err
                )
                raise baseexception.RestoreError from err
            if kind == sTypes.ZERO:
                logging.debug("Zero segment from [%s] length: [%s]", start, length)
            elif kind == sTypes.DATA:
                logging.debug(
                    "Processing data segment from [%s] length: [%s]", start, length
                )

                # The frame header carries the uncompressed size; for
                # compressed streams the number of bytes stored in the file
                # is taken from the trailer, indexed by data block number.
                originalSize = length
                if trailer:
                    logging.debug("Block: [%s]", dataBlockCnt)
                    logging.debug("Original block size: [%s]", length)
                    length = trailer[dataBlockCnt]
                    logging.debug("Compressed block size: [%s]", length)

                if originalSize >= nbdClient.maxRequestSize:
                    logging.debug(
                        "Chunked read/write, start: [%s], len: [%s]", start, length
                    )
                    try:
                        written = lib.readChunk(
                            reader,
                            start,
                            length,
                            nbdClient.maxRequestSize,
                            connection,
                            lib.isCompressed(meta),
                        )
                    except Exception as e:
                        raise baseexception.RestoreError from e
                    logging.debug("Wrote: [%s]", written)
                else:
                    try:
                        data = reader.read(length)
                        if lib.isCompressed(meta):
                            data = lib.lz4DecompressFrame(data)
                        connection.pwrite(data, start)
                        written = len(data)
                    except Exception as e:
                        raise baseexception.RestoreError from e

                _consumeTerminator(reader, sTypes.TERM)
                dataSize += originalSize
                progressBar.update(written)
                dataBlockCnt += 1
            elif kind == sTypes.STOP:
                progressBar.close()
                if dataSize != meta["dataSize"]:
                    logging.error(
                        "Restored data size does not match [%s] != [%s]",
                        dataSize,
                        meta["dataSize"],
                    )
                    raise baseexception.RestoreError
                break

    logging.info("End of stream, [%s] bytes of data processed", dataSize)
    if meta["checkpointName"] == args.until:
        logging.info("Reached checkpoint [%s], stopping", args.until)
        raise baseexception.UntilCheckpointReached

    return True


def restoreSequence(args, dataFiles, virtClient, sshClient):
    """Reconstruct an image from an explicitly given sequence of data files.

    Every entry of `dataFiles` is looked up below `args.input` and applied,
    in order, to `args.output/<diskName>`; the target image is created the
    first time it is needed.

    Returns True if the sequence was applied (or the checkpoint given via
    `args.until` was reached), False as soon as one step fails or if
    `dataFiles` is empty.
    """
    stream = streamer.SparseStream(types)
    # Defined up front so an empty sequence returns a value instead of
    # raising UnboundLocalError.
    result = False

    for disk in dataFiles:
        sourceFile = f"{args.input}/{disk}"

        meta = getHeader(sourceFile, stream)
        if not meta:
            return False

        diskName = meta["diskName"]
        targetFile = f"{args.output}/{diskName}"

        # Pass sshClient so the check looks at the same location that
        # createDiskFile checks; otherwise a remote restore would try to
        # re-create the already existing image for the second file.
        if not lib.exists(targetFile, sshClient) and not createDiskFile(
            meta, targetFile, sshClient
        ):
            return False

        # startNbd returns False (not a tuple) if the server failed to start.
        nbd = startNbd(args, diskName, targetFile, virtClient, sshClient)
        if not nbd:
            return False
        nbdClient, connection = nbd

        try:
            result = writeData(
                args, stream, sourceFile, targetFile, nbdClient, connection
            )
        except baseexception.UntilCheckpointReached:
            # Requested point in time reached: remaining files are skipped.
            return True
        except baseexception.RestoreError:
            return False
        finally:
            nbdClient.disconnect()

        # Do not let a later file hide an earlier failure.
        if result is False:
            return False

    return result


def writeData(args, stream, disk, targetFile, nbdClient, connection):
    """Apply the data stream of `disk` to the target file.

    Thin wrapper around restoreData: its return value is passed through,
    except that None is reported as True.
    """
    outcome = restoreData(args, stream, disk, targetFile, nbdClient, connection)
    # None is treated as "no data has been processed", which is not an error.
    return True if outcome is None else outcome


def createDiskFile(meta, targetFile, sshClient):
    """Create the image file that the restore writes into.

    Format and virtual size are taken from the backup metadata `meta`.
    Returns True once the image was created, False if `targetFile` already
    exists or the creation fails.
    """
    diskFormat = meta["diskFormat"]
    virtualSize = meta["virtualSize"]
    logging.info(
        "Create virtual disk [%s] format: [%s] size: [%s]",
        targetFile,
        diskFormat,
        virtualSize,
    )

    # Never overwrite an existing image.
    if lib.exists(targetFile, sshClient):
        logging.error("Target file already exists: [%s], wont overwrite.", targetFile)
        return False

    imageHelper = qemuhelper.qemuHelper(meta["diskName"])
    try:
        imageHelper.create(targetFile, virtualSize, diskFormat, sshClient)
    except qemuhelper.exceptions.ProcessError as e:
        logging.error("Cant create restore target: [%s]", e)
        return False
    else:
        return True


def getHeader(diskFile, stream):
    """Return the metadata header of `diskFile`.

    Stream format errors and output helper errors are logged and reported
    as False instead of being raised.
    """
    try:
        header = lib.dumpMetaData(diskFile, stream)
    except exceptions.StreamFormatException as errmsg:
        logging.error("Reading metadata from [%s] failed: [%s]", diskFile, errmsg)
        header = False
    except outputhelper.exceptions.OutputException as errmsg:
        logging.error("Reading data file [%s] failed: [%s]", diskFile, errmsg)
        header = False
    return header


def startNbd(args, exportName, targetFile, virtClient, sshClient):
    """Start the NBD server for `targetFile` and connect to it.

    Without a remote host the server listens on a unix socket, otherwise it
    is started via `sshClient` and reached over TCP (`args.nbd_ip` overrides
    the remote host address).

    Returns a tuple (nbdClient, connection), or False if the local NBD
    server could not be started.
    """
    helper = qemuhelper.qemuHelper(exportName)
    remoteHost = virtClient.remoteHost
    if remoteHost:
        remoteIP = args.nbd_ip if args.nbd_ip is not None else remoteHost
        # NOTE(review): the return value of the remote start is not checked.
        helper.startRemoteRestoreNbdServer(args, sshClient, targetFile)
        cType = nbdhelper.nbdConnTCP(
            exportName, None, remoteIP, args.tls, args.nbd_port
        )
    else:
        socketFile = lib.getSocketFile(args.socketfile)
        logging.info("Starting NBD server on socket: [%s]", socketFile)
        err = helper.startRestoreNbdServer(targetFile, socketFile)
        if err.err is not None:
            logging.error("Unable to start NBD server: [%s]", err)
            return False
        cType = nbdhelper.nbdConnUnix(exportName, None, socketFile)

    nbdClient = nbdhelper.nbdClient(cType)
    return nbdClient, nbdClient.waitForServer()


def readConfig(vmConfig):
    """Read the saved virtual machine config and return it as a string.

    `vmConfig` is the path of the config file. Any error while opening,
    reading or decoding the file is logged and re-raised.
    """
    try:
        # Context manager so the handle is closed even if read/decode fails.
        with outputhelper.openfile(vmConfig, "rb") as fh:
            return fh.read().decode()
    except Exception:
        # Not a bare except: KeyboardInterrupt/SystemExit should not be
        # reported as an unreadable config file.
        logging.error("Cant read config file: [%s]", vmConfig)
        raise


def getDisksFromConfig(args, vmConfig, virtClient):
    """Return the disks defined in the saved config file `vmConfig`.

    The file content is read and handed to virtClient.getDomainDisks
    together with `args`.
    """
    return virtClient.getDomainDisks(args, readConfig(vmConfig))


def checkBackingStore(args, disk):
    """Warn if the restored disk has backing store images configured.

    In that case the virtual machine configuration has to be changed before
    the VM can be started. Nothing is logged if the disk has no backing
    stores or if the config is adjusted anyway (`args.adjust_config`).
    """
    if len(disk.backingstores) == 0 or args.adjust_config:
        return

    logging.warning("Target image [%s] seems to be a snapshot image.", disk.filename)
    for hint in (
        "Target virtual machine configuration must be altered!",
        "Configured backing store images must be changed.",
    ):
        logging.warning(hint)


def restoreFiles(args, vmConfig, virtClient, sshClient):
    """Restore additional files referenced by the saved domain config.

    virtClient.getDomainInfo maps boot options to file paths (the original
    note mentions loader / nvram). For every path that does not exist yet,
    the latest matching copy from the backup directory `args.input` is
    copied to that path. Existing files are left untouched.
    """
    config = readConfig(vmConfig)
    info = virtClient.getDomainInfo(config)

    for setting, val in info.items():
        if lib.exists(val, sshClient):
            logging.info(
                "File [%s]: for boot option [%s] already exists, skipping.",
                val,
                setting,
            )
            continue

        # Only search the backup directory once we know a copy is needed.
        f = lib.getLatest(args.input, f"*{os.path.basename(val)}*", -1)
        if not f:
            # Nothing to copy; handing an empty source to lib.copy would fail.
            logging.warning(
                "No backup copy of file [%s] for boot option [%s] found, skipping.",
                val,
                setting,
            )
            continue

        logging.info(
            "Restoring configured file [%s] for boot option [%s]", val, setting
        )
        lib.copy(f, val, sshClient)


def setTargetFile(args, disk):
    """Return the path below `args.output` that the disk is restored to.

    The disk's file name is used if it is set, otherwise the name of its
    target device.
    """
    name = disk.target if disk.filename is None else disk.filename
    return f"{args.output}/{name}"


def restore(args, vmConfig, virtClient, sshClient):
    """Restore all (or the selected) disks of the backed up domain.

    For every disk found in the saved config `vmConfig` the matching data
    files from `args.input` are applied to a newly created image below
    `args.output`. With `args.raw`, raw disks are copied as is. With
    `args.adjust_config` the domain config is adjusted along the way.

    Returns a tuple (result, restConfig): `result` is True if at least one
    disk was restored and no restore failed, `restConfig` is the config
    that belongs to the restored disks. On error (False, False) is returned.
    """
    stream = streamer.SparseStream(types)
    vmDisks = getDisksFromConfig(args, vmConfig, virtClient)
    vmConfig = readConfig(vmConfig)
    if not vmDisks:
        return False, False

    # Defaults for the case that every disk leaves the loop early (skipped,
    # excluded or copied raw); without them the final return would raise
    # UnboundLocalError.
    result = False
    restConfig = vmConfig

    for disk in vmDisks:
        if args.disk not in (None, disk.target):
            logging.info("Skipping disk [%s] for restore", disk.target)
            continue

        restoreDisk = lib.getLatest(args.input, f"{disk.target}*.data")
        logging.debug("Restoring disk: [%s]", restoreDisk)
        if not restoreDisk:
            logging.warning(
                "No backup file for disk [%s] found, assuming it has been excluded.",
                disk.target,
            )
            if args.adjust_config is True:
                # Carry the removal forward in vmConfig too, otherwise the
                # next adjustment starts from the config that still
                # contains this disk.
                vmConfig = restConfig = virtClient.adjustDomainConfigRemoveDisk(
                    vmConfig, disk.target
                )
            continue

        targetFile = setTargetFile(args, disk)

        if args.raw and disk.format == "raw":
            logging.info("Restoring raw image to [%s]", targetFile)
            lib.copy(restoreDisk[0], targetFile, sshClient)
            result = True
            continue

        if "full" not in restoreDisk[0] and "copy" not in restoreDisk[0]:
            logging.error(
                "[%s]: Unable to locate base full or copy backup.", restoreDisk[0]
            )
            return False, False

        meta = getHeader(restoreDisk[0], stream)
        if not meta:
            logging.error("Reading metadata from [%s] failed", restoreDisk[0])
            return False, False

        if not createDiskFile(meta, targetFile, sshClient):
            return False, False

        # startNbd returns False (not a tuple) if the server failed to start.
        nbd = startNbd(args, meta["diskName"], targetFile, virtClient, sshClient)
        if not nbd:
            return False, False
        nbdClient, connection = nbd

        result = False
        try:
            for dataFile in restoreDisk:
                try:
                    result = writeData(
                        args, stream, dataFile, targetFile, nbdClient, connection
                    )
                except baseexception.UntilCheckpointReached:
                    result = True
                    break
                except baseexception.RestoreError:
                    result = False
                    break
                # A file that could not be applied breaks the chain; later
                # increments must not be applied on top of it.
                if not result:
                    break
        finally:
            nbdClient.disconnect()

        if not result:
            # Stop here so a later disk cannot hide this failure.
            logging.error("Restore of disk [%s] failed.", disk.target)
            return False, False

        checkBackingStore(args, disk)
        if args.adjust_config is True:
            vmConfig = restConfig = virtClient.adjustDomainConfig(
                args, disk, vmConfig, targetFile
            )
        else:
            restConfig = vmConfig

    return result, restConfig


def restoreConfig(args, vmConfig, adjustedConfig, sshClient):
    """Place the virtual machine config in the restore target directory.

    With `args.adjust_config` the content of `adjustedConfig` is written to
    `args.output/<basename of vmConfig>` (via `sshClient` if one is given),
    otherwise the original config file `vmConfig` is copied there.
    """
    targetFile = f"{args.output}/{os.path.basename(vmConfig)}"
    if args.adjust_config is not True:
        lib.copy(vmConfig, targetFile, sshClient)
        logging.info("Copied original vm config to [%s]", targetFile)
        logging.info("Note: virtual machine config must be adjusted manually.")
        return

    if sshClient:
        with tempfile.NamedTemporaryFile(delete=True) as fh:
            fh.write(adjustedConfig)
            # lib.copy is handed the file name, not the handle: flush the
            # write buffer first, otherwise the copy may be empty or cut off.
            fh.flush()
            lib.copy(fh.name, targetFile, sshClient)
    else:
        with outputhelper.openfile(targetFile, "wb") as cnf:
            cnf.write(adjustedConfig)
    logging.info("Adjusted config placed in: [%s]", targetFile)
    if args.define is False:
        logging.info("Use 'virsh define %s' to define VM", targetFile)


def _buildParser():
    """Create the argument parser with all command line options."""
    parser = argparse.ArgumentParser(
        description="Restore virtual machine disks",
        epilog=(
            "Examples:\n"
            "   # Dump backup metadata:\n"
            "\t%(prog)s -i /backup/ -o dump\n"
            "   # Complete restore with all disks:\n"
            "\t%(prog)s -i /backup/ -o /target\n"
            "   # Complete restore, adjust config and redefine vm after restore:\n"
            "\t%(prog)s -cD -i /backup/ -o /target\n"
            "   # Complete restore, adjust config and redefine vm with name 'foo':\n"
            "\t%(prog)s -cD --name foo -i /backup/ -o /target\n"
            "   # Restore only disk 'vda':\n"
            "\t%(prog)s -i /backup/ -o /target -d vda\n"
            "   # Point in time restore:\n"
            "\t%(prog)s -i /backup/ -o /target --until virtnbdbackup.2\n"
            "   # Restore and process specific file sequence:\n"
            "\t%(prog)s -i /backup/ -o /target "
            "--sequence vdb.full.data,vdb.inc.virtnbdbackup.1.data\n"
            "   # Restore to remote system:\n"
            "\t%(prog)s -U qemu+ssh://root@remotehost/system"
            " --ssh-user root -i /backup/ -o /remote_target"
        ),
        formatter_class=argparse.RawTextHelpFormatter,
    )
    opt = parser.add_argument_group("General options")
    opt.add_argument(
        "-a",
        "--action",
        required=False,
        type=str,
        choices=["dump", "restore"],
        default="restore",
        help="Action to perform: (default: %(default)s)",
    )
    opt.add_argument(
        "-i",
        "--input",
        required=True,
        type=str,
        help="Directory including a backup set",
    )
    opt.add_argument(
        "-o", "--output", required=True, type=str, help="Restore target directory"
    )
    opt.add_argument(
        "-u",
        "--until",
        required=False,
        type=str,
        help="Restore only until checkpoint, point in time restore.",
    )
    opt.add_argument(
        "-s",
        "--sequence",
        required=False,
        type=str,
        default=None,
        help="Restore image based on specified backup files.",
    )
    opt.add_argument(
        "-d",
        "--disk",
        required=False,
        type=str,
        default=None,
        help="Process only disk matching target dev name. (default: %(default)s)",
    )
    opt.add_argument(
        "-n",
        "--noprogress",
        required=False,
        action="store_true",
        default=False,
        help="Disable progress bar",
    )
    opt.add_argument(
        "-f",
        "--socketfile",
        default=None,
        type=str,
        help="Use specified file for NBD Server socket instead of random file",
    )
    opt.add_argument(
        "-r",
        "--raw",
        default=False,
        action="store_true",
        help="Copy raw images as is during restore. (default: %(default)s)",
    )
    opt.add_argument(
        "-c",
        "--adjust-config",
        default=False,
        action="store_true",
        help="Adjust vm configuration during restore. (default: %(default)s)",
    )
    opt.add_argument(
        "-D",
        "--define",
        default=False,
        action="store_true",
        help="Register/define VM after restore. (default: %(default)s)",
    )
    opt.add_argument(
        "-N",
        "--name",
        default=None,
        type=str,
        help="Define restored domain with specified name",
    )
    remopt = parser.add_argument_group("Remote Restore options")
    argopt.addRemoteArgs(remopt)
    logopt = parser.add_argument_group("Logging options")
    argopt.addLogArgs(logopt, parser.prog)
    debopt = parser.add_argument_group("Debug options")
    argopt.addDebugArgs(debopt)

    return parser


def main():
    """Command line entry point: dump backup metadata or restore a backup.

    Exits with status 1 if the log file, the backup source, the data files
    or the domain config cannot be found, if the ssh connection cannot be
    set up, if the disk restore fails or if defining the domain fails.
    Exits with status 0 after a dump.
    """
    parser = _buildParser()
    args = lib.argparse(parser)

    # default values for common usage of lib.getDomainDisks
    args.exclude = None
    args.include = args.disk
    stream = streamer.SparseStream(types)
    fileLog = lib.getLogFile(args.logfile)
    if not fileLog:
        sys.exit(1)
    counter = logCount()

    lib.configLogger(args, fileLog, counter)
    lib.printVersion(__version__)

    if not lib.exists(args.input):
        logging.error("Backup source [%s] does not exist.", args.input)
        sys.exit(1)

    if args.sequence is not None:
        logging.info("Using manual specified sequence of files.")
        logging.info("Disabling redefine and config adjust options.")
        args.define = False
        args.adjust_config = False
        dataFiles = args.sequence.split(",")

        if "full" not in dataFiles[0] and "copy" not in dataFiles[0]:
            logging.error("Sequence must start with full or copy backup.")
            sys.exit(1)
    else:
        dataFiles = lib.getLatest(args.input, "*.data")
        if not dataFiles:
            logging.error("No data files found in directory: [%s]", args.input)
            sys.exit(1)

    # "-o dump" is accepted as a shortcut for "--action dump".
    if args.action == "dump" or args.output == "dump":
        dump(args, stream, dataFiles)
        sys.exit(0)

    if args.action == "restore":
        virtClient = libvirthelper.client(args)
        sshClient = None
        if virtClient.remoteHost:
            try:
                sshClient = sshutil.Client(
                    virtClient.remoteHost, args.ssh_user, sshutil.Mode.UPLOAD
                )
            except sshutil.exceptions.sshutilError as e:
                logging.error("Unable to setup ssh connection: [%s]", e)
                sys.exit(1)

        vmConfig = lib.getLatest(args.input, "vmconfig*.xml", -1)
        if not vmConfig:
            logging.error("No domain config file found")
            sys.exit(1)
        logging.info("Using latest config file: [%s]", vmConfig)

        # Create the restore target directory, locally or on the remote host.
        if not virtClient.remoteHost:
            outputhelper.outputHelper.Directory(args.output)
        else:
            sshClient.sftp.mkdir(args.output)

        ret = False
        if args.sequence is not None:
            restConfig = None
            ret = restoreSequence(args, dataFiles, virtClient, sshClient)
        else:
            ret, restConfig = restore(args, vmConfig, virtClient, sshClient)

        if ret is False:
            logging.fatal("Disk restore failed.")
            sys.exit(1)

        restoreFiles(args, vmConfig, virtClient, sshClient)
        if args.define is True and args.adjust_config is False:
            logging.warning(
                "Redefine domain needs adjusted config, please "
                "use --adjust-config option. Skipping.."
            )
        restoreConfig(args, vmConfig, restConfig, sshClient)
        virtClient.refreshPool(args.output)
        if args.define is True and args.adjust_config is True:
            if not virtClient.defineDomain(restConfig):
                sys.exit(1)


# Run the command line interface only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
