#!/usr/bin/env python3
# Copyright (c) Microsoft. All rights reserved.

# This utility creates and modifies mssql.conf, which contains
# the configurable settings used by SQL Server on Linux.

import argparse
from argparse import SUPPRESS
import sys
import os
import os.path
import mssqlconfhelper
import mssqlsettings
import mssqlsettingsmanager
from configparser import ConfigParser
import logging
import logging.handlers
import mssqlad
import pwd
import atexit
import pyadutil
import pathlib
import functools

# Shared logger and translation helper, re-exported from mssqlconfhelper so the
# rest of this script can refer to them by short names.
#
logger = mssqlconfhelper.logger
_ = mssqlconfhelper._

# Logging constants and defaults. The logger object is defined in
# mssqlconfhelper.py but it is configured here, because this script is the
# application's entry point.
#
defaultLogFile = mssqlconfhelper.sqlPathLogDir + "/mssql-conf/mssql-conf.log"
defaultLogFormatter = logging.Formatter("%(asctime)s:%(levelname)s:%(name)s: %(message)s")
defaultLogLevel = logging.INFO

# Mode bits applied to the log files (read/write for owner and group) and to
# the log directory (read/write/execute for owner and group).
#
logFilePermissions = 0o660
logFileDirectoryPermissions = 0o770

# Absolute log file path -> file handler already attached to the root logger,
# so that initializing the same file twice does not add a duplicate handler.
#
logFileHandlers = {}

def handleWelcome():
    """Handle the welcome command.

    Looks up the printer for ``args.welcometype`` and exits the process with
    whatever that printer returns. A missing or unrecognized welcome type exits
    with the error code.
    """
    welcomePrinters = {
        "engine": mssqlconfhelper.printEngineWelcomeMessage,
        "agent": mssqlconfhelper.printAgentWelcomeMessage,
        "fts": mssqlconfhelper.printFTSWelcomeMessage,
        "polybase": mssqlconfhelper.printPolyBaseWelcomeMessage,
        "polybase-hadoop": mssqlconfhelper.printPolyBaseHadoopWelcomeMessage,
        "machinelearning": mssqlconfhelper.printMachineLearningServicesWelcomeMessage,
    }

    printer = welcomePrinters.get(args.welcometype)
    if printer is None:
        exit(mssqlconfhelper.errorExitCode)

    exit(printer())

def handleList():
    """Handle the list command by printing every setting mssql-conf supports."""
    supportedSettings = mssqlsettingsmanager.supportedSettingsList
    mssqlconfhelper.listSupportedSettings(supportedSettings)

def handleSetup():
    """Handle the setup command. Must be run as the superuser.

    With an explicit ``args.setup_option``:
      * "accept-eula"    -> run engine setup with the EULA accepted and exit
                            with its result;
      * "accept-eula-ml" -> run Machine Learning Services setup with its EULA
                            accepted (when installed) and exit with its result;
      * anything else    -> print the supported options and exit with an error.

    Without an option, EULA acceptance is taken from the presence of the
    ACCEPT_EULA / ACCEPT_EULA_ML environment variables. The engine is then set
    up, followed by Machine Learning Services when they are installed. The first
    step that returns a non-success exit code determines the process exit code.
    """

    if not mssqlconfhelper.checkSudo():
        mssqlconfhelper.printError(_("Setup must be run as the superuser. This can be done using 'sudo mssql-conf setup'."))
        exit(mssqlconfhelper.errorExitCode)

    if args.setup_option is not None:
        if args.setup_option == "accept-eula":
            exit(mssqlconfhelper.setupSqlServer(True, noprompt=args.noprompt))
        elif args.setup_option == "accept-eula-ml":
            if mssqlconfhelper.isMlServicesInstalled():
                exit(mssqlconfhelper.setupSqlServerMlServices(True))
            else:
                # NOTE(review): this prints a message but the command still
                # finishes successfully; confirm that is intended.
                print(_("Machine Learning services has not been installed."))
        else:
            print(_("The setup option of %s is not supported.") % (args.setup_option))
            print(_("The supported setup options are accept-eula and accept-eula-ml."))
            exit(mssqlconfhelper.errorExitCode)
        return

    # Presence of the variable (with any value) counts as acceptance.
    #
    eulaAccepted = "ACCEPT_EULA" in os.environ
    eulaMlAccepted = "ACCEPT_EULA_ML" in os.environ

    results = [mssqlconfhelper.setupSqlServer(eulaAccepted, noprompt=args.noprompt)]

    if mssqlconfhelper.isMlServicesInstalled():
        results.append(mssqlconfhelper.setupSqlServerMlServices(eulaMlAccepted))

    # The explicit accept-eula path above exits with these functions' results,
    # so they are exit codes. Propagate the first failure here as well instead
    # of always finishing with success. None is treated as success, as exit(None) is.
    #
    for result in results:
        if result is not None and result != mssqlconfhelper.successExitCode:
            exit(result)

def handleValidate():
    """Handle the validate command.

    Exits with the success code when mssql.conf does not exist or when
    mssqlsettingsmanager.validateConfig() accepts it (printing a confirmation
    in the latter case). Otherwise, exits with the error code.
    """
    configPath = mssqlconfhelper.configurationFilePath

    if os.path.exists(configPath):
        # Keys without values are accepted so the parser does not stop on
        # them; validateConfig() then performs the actual validation.
        #
        parsedConfig = ConfigParser(allow_no_value=True)
        mssqlconfhelper.readConfigFromFile(parsedConfig, configPath)

        if not mssqlsettingsmanager.validateConfig(parsedConfig):
            exit(mssqlconfhelper.errorExitCode)

        print(_("Validation successful."))

    exit(mssqlconfhelper.successExitCode)

def handleSet():
    """Handle the set command: store ``args.setting_value`` for a setting.

    Section-only settings store the key ``args.setting_name`` inside their
    section via setSectionKeyVal(). All other settings go through setSetting().
    mssql.conf is rewritten only when that call returns a truthy value. Changing
    a setting in the extensibility section also prints the launchpad-restart
    notice. An unsupported section/setting pair reports an error and exits with
    the error code.
    """

    config = ConfigParser()
    mssqlconfhelper.readConfigFromFile(config, mssqlconfhelper.configurationFilePath)
    settingChosen = mssqlsettingsmanager.findSetting(args.section_name, args.setting_name)

    if settingChosen is None:
        mssqlconfhelper.printErrorUnsupportedSetting(args.section_name, args.setting_name)
        # Never fall through to dereference None below, even if the helper
        # above does not terminate the process itself.
        #
        exit(mssqlconfhelper.errorExitCode)

    if settingChosen.sectionOnly:
        if mssqlsettingsmanager.setSectionKeyVal(config, settingChosen, args.setting_name, args.setting_value):
            mssqlconfhelper.writeConfigToFile(config, mssqlconfhelper.configurationFilePath)
    elif mssqlsettingsmanager.setSetting(config, settingChosen, args.setting_value):
        mssqlconfhelper.writeConfigToFile(config, mssqlconfhelper.configurationFilePath)
        if settingChosen.section == mssqlsettings.SectionForSetting.extensibility:
            mssqlconfhelper.printLaunchpadRestartRequiredMessage()

def handleUnset():
    """Handle the unset command: remove a setting from mssql.conf.

    Section-only settings remove the key ``args.setting_name`` via
    unsetSectionKeyVal(). All other settings go through unsetSetting().
    mssql.conf is rewritten only when that call returns a truthy value. For
    regular settings, the restart-required notice is printed when the setting
    needs it, and the launchpad-restart notice is printed for the extensibility
    section. An unsupported section/setting pair reports an error and exits
    with the error code.
    """

    settingToRemove = mssqlsettingsmanager.findSetting(args.section_name, args.setting_name)

    if settingToRemove is None:
        mssqlconfhelper.printErrorUnsupportedSetting(args.section_name, args.setting_name)
        # Never fall through to dereference None below, even if the helper
        # above does not terminate the process itself.
        #
        exit(mssqlconfhelper.errorExitCode)

    config = ConfigParser()
    mssqlconfhelper.readConfigFromFile(config, mssqlconfhelper.configurationFilePath)

    if settingToRemove.sectionOnly:
        if mssqlsettingsmanager.unsetSectionKeyVal(config, settingToRemove, args.setting_name):
            mssqlconfhelper.writeConfigToFile(config, mssqlconfhelper.configurationFilePath)
    elif mssqlsettingsmanager.unsetSetting(config, settingToRemove):
        if settingToRemove.restart_required:
            mssqlconfhelper.printRestartRequiredMessage()
        mssqlconfhelper.writeConfigToFile(config, mssqlconfhelper.configurationFilePath)
        if settingToRemove.section == mssqlsettings.SectionForSetting.extensibility:
            mssqlconfhelper.printLaunchpadRestartRequiredMessage()

def handleSetSaPassword():
    """Handle the set-sa-password command.

    Obtains the new SA password from getSystemAdministratorPassword() (which
    honours --noprompt) and passes it to sqlservr through the
    MSSQL_SA_PASSWORD variable. Exits with the success code when sqlservr
    returns 0. Otherwise, points the user at the ERRORLOG and exits with
    sqlservr's return code.
    """

    password = mssqlconfhelper.getSystemAdministratorPassword(noprompt=args.noprompt)

    # The helper signals failure either with None or with the error exit code.
    #
    if password is None or password == mssqlconfhelper.errorExitCode:
        exit(mssqlconfhelper.errorExitCode)

    sqlservrResult = mssqlconfhelper.configureSqlservrWithArguments("--setup --reset-sa-password", MSSQL_SA_PASSWORD=password)

    if sqlservrResult == 0:
        print(_("The system administrator password has been changed."))
        mssqlconfhelper.printStartSqlServerMessage()
        exit(mssqlconfhelper.successExitCode)
    else:
        print(_("Unable to set the system administrator password. Please consult the ERRORLOG"))
        print(_("in %s for more information.") % (mssqlconfhelper.getErrorLogFile()))
        exit(sqlservrResult)

def handleTraceflag():
    """Handle the traceflag command.

    Adds ("on") or removes ("off") every trace flag in ``args.traceflags_list``
    and then rewrites mssql.conf.
    """
    config = ConfigParser()
    mssqlconfhelper.readConfigFromFile(config, mssqlconfhelper.configurationFilePath)

    applyFlag = {
        'on': mssqlsettingsmanager.addTraceFlag,
        'off': mssqlsettingsmanager.removeTraceFlag,
    }.get(args.tf_choice)

    if applyFlag is not None:
        for flag in args.traceflags_list:
            applyFlag(config, flag)

    mssqlconfhelper.writeConfigToFile(config, mssqlconfhelper.configurationFilePath)

def handleSetCollation():
    """Handle the set-collation command.

    Prompts for a collation name, checks it with validateCollation(), and
    passes it to sqlservr as ``-q<collation>``. Exits with the success code
    when sqlservr returns 0. Otherwise (including an invalid collation), exits
    with the error code.
    """

    # Passed through _() like every other user-facing string in this script,
    # so the prompt can be localized.
    #
    collation = input(_('Enter the collation: '))

    if not mssqlconfhelper.validateCollation(collation):
        exit(mssqlconfhelper.errorExitCode)

    ret = mssqlconfhelper.configureSqlservrWithArguments("-q%s" % (collation))

    if ret == 0:
        print(_("The server collation has been changed."))
        mssqlconfhelper.printStartSqlServerMessage()
        exit(mssqlconfhelper.successExitCode)
    else:
        exit(mssqlconfhelper.errorExitCode)

def handleSetEdition():
    """Handle the set-edition command.

    Refuses to run while an instance is running. Otherwise, obtains a product
    ID via getPid() and passes it to sqlservr as MSSQL_PID. Exits with sqlservr's
    return code on failure. On success, prints the start-server hint.
    """
    if mssqlconfhelper.checkRunningInstance():
        exit(mssqlconfhelper.errorExitCode)

    productId = mssqlconfhelper.getPid(args.noprompt)
    if productId is None:
        exit(mssqlconfhelper.errorExitCode)

    result = mssqlconfhelper.configureSqlservrWithArguments("--setup", MSSQL_PID=productId)
    if result != mssqlconfhelper.successExitCode:
        exit(result)

    mssqlconfhelper.printStartSqlServerMessage()


def handleGet():
    """
    Handle the get command, which gets the setting values of the entire section specified or of a particular setting.

    Prints one ``name : value`` line per entry returned by getSettings(). When
    nothing is returned, prints a notice naming the configuration file instead.
    """
    settingsDictionary = mssqlconfhelper.getSettings(mssqlconfhelper.configurationFilePath,
                                                     args.section_name, args.setting_name)
    if not settingsDictionary:
        print(_("No setting for the given option found in '%s'.") % (mssqlconfhelper.configurationFilePath))

    for settingName, settingValue in settingsDictionary.items():
        print("%s : %s" % (settingName, settingValue))

def handleADCommon():
    """
    Apply setup shared by the Active Directory commands.

    Currently this only points pyadutil at the adutil binary given with
    --adutil-path, when that value is not the empty string. More shared setup
    may be added here later. Note that COMMAND_TABLE does not reference this
    function.
    """
    customAdutilPath = args.adutil_path
    if customAdutilPath == "":
        return

    # Setting the module attribute globally is acceptable because this script
    # is the program's top level.
    #
    pyadutil.adutilLocation = customAdutilPath

def handleValidateADConfig():
    """
    Handle the validate-ad-config command.

    Builds an mssqlad.adconfig from the keytab and realm arguments and calls its
    validate() method. When --adutil-path is non-empty, the work runs inside
    pyadutil.withAdutilLocation() so that adutil binary is used.
    """

    def validateConfiguration():
        mssqlad.adconfig(args.keytab, args.realm).validate()

    if args.adutil_path == "":
        validateConfiguration()
        return

    with pyadutil.withAdutilLocation(args.adutil_path):
        validateConfiguration()

def handleSetupADKeytab():
    """
    Handle the setup-ad-keytab command.

    The KVNO is chosen as follows:
      * an explicit --kvno value wins;
      * otherwise --use-next-kvno selects mssqlad.USE_NEXT_KVNO;
      * otherwise mssqlad.USE_CURRENT_KVNO is used.

    The keytab is created through mssqlad.setupADKeytab(), interactively unless
    --noprompt was given. A non-empty --adutil-path wraps the call in
    pyadutil.withAdutilLocation().
    """
    promptUser = not args.noprompt

    if args.kvno is not None:
        kvnoChoice = args.kvno
    elif args.useNextKvno:
        kvnoChoice = mssqlad.USE_NEXT_KVNO
    else:
        kvnoChoice = mssqlad.USE_CURRENT_KVNO

    def createKeytab():
        mssqlad.setupADKeytab(args.keytab, args.user, interactively=promptUser, inputKvno=kvnoChoice)

    if args.adutil_path == "":
        createKeytab()
    else:
        with pyadutil.withAdutilLocation(args.adutil_path):
            createKeytab()

# Maps each sub-command name (the ``which`` default set on its argparse
# sub-parser) to the function that handles it.
#
COMMAND_TABLE = {
    # Installation and general configuration
    "welcome": handleWelcome,
    "setup": handleSetup,
    "validate": handleValidate,
    "list": handleList,
    "set": handleSet,
    "unset": handleUnset,
    "traceflag": handleTraceflag,
    # Commands that reconfigure the server through sqlservr
    "set-sa-password": handleSetSaPassword,
    "set-collation": handleSetCollation,
    "set-edition": handleSetEdition,
    "get": handleGet,
    # Active Directory
    "validate-ad-config": handleValidateADConfig,
    "setup-ad-keytab": handleSetupADKeytab,
}

def processCommands():
    """Run the handler registered in COMMAND_TABLE for ``args.which``.

    The start and end of the command are logged. If the handler returns
    normally, the process exits with the success code.
    """
    commandName = args.which
    logger.info("Executing command: [%s]", commandName)

    handler = COMMAND_TABLE[commandName]
    handler()

    logger.info("Executed command [%s] successfully", commandName)
    exit(mssqlconfhelper.successExitCode)

@functools.lru_cache(maxsize=None)
def initializeLogging():
    """ Configure Python logging for this run of mssql-conf.

        The lru_cache on this zero-argument function makes calls after the
        first completed one return immediately.

        --verbose: records also go to stdout. Otherwise, logging is globally
        disabled until a log file handler re-enables it.

        Unless --quiet is given, a rotating log file is set up through
        initializeLoggingFile(). It rolls over at 2,000,000 bytes and keeps two
        backups.

        Log file ownership and permissions are fixed now and again at
        interpreter exit. Anything logged before this runs is unlikely to be
        captured.
    """
    # Handlers attached to the root logger also receive records from every
    # non-root logger.
    #
    rootLogger = logging.getLogger()
    rootLogger.setLevel(defaultLogLevel)

    # Console logging is configured first so problems with the log file can be
    # diagnosed on stdout.
    #
    if not args.verbose:
        # Keep records off the console; initializeLoggingFile() lifts this
        # once the file handler is in place.
        #
        logging.disable(level=logging.CRITICAL)
    else:
        consoleHandler = logging.StreamHandler(sys.stdout)
        consoleHandler.setLevel(defaultLogLevel)
        consoleHandler.setFormatter(defaultLogFormatter)
        rootLogger.addHandler(consoleHandler)
        logger.info("Enabled mssql-conf logging to stdout.")

    if not args.quiet:
        initializeLoggingFile()

    fixLogFilePermissions()
    atexit.register(fixLogFilePermissions)
    logger.info("Added shutdown hook to ensure ownership on the log files is correct.")

    logger.info("Logging has been fully configured.")

def initializeLoggingFile(filename=defaultLogFile, logLevel=defaultLogLevel):
    """ Attach a rotating file handler for ``filename`` to the root logger.

        The file and its directory are first ensured through
        ensureLogFileExists(). The file rolls over at 2,000,000 bytes with two
        backups kept. On success, any global logging.disable() applied earlier
        is lifted.

        When a handler for the same absolute path was already registered in
        this process, only its level is updated. OSError and
        LogFileAbsentException are logged and reported as a printed warning
        instead of being raised.

        Unlike initializeLogging(), this does not need the parsed command-line
        ``args``, so unit tests can use it to log to an alternative file.
    """
    logPath = os.path.abspath(filename)
    rootLogger = logging.getLogger()
    rootLogger.setLevel(logLevel)

    # Reuse an existing handler for this path rather than attaching a duplicate.
    # This matters mainly when several test modules run in one interpreter.
    #
    registeredHandler = logFileHandlers.get(logPath)
    if registeredHandler is not None:
        # Passing NOTSET lifts any logging.disable() applied earlier.
        #
        logging.disable(logging.NOTSET)
        registeredHandler.setLevel(logLevel)
        logger.info("Not initializing [%s] as log file because it is already initialized.", logPath)
        return

    try:
        ensureLogFileExists(logPath)
        rotatingHandler = logging.handlers.RotatingFileHandler(logPath, maxBytes=2 * 1000 * 1000, backupCount=2)
        rotatingHandler.setFormatter(defaultLogFormatter)
        rotatingHandler.setLevel(logLevel)
        rootLogger.addHandler(rotatingHandler)
        logging.disable(logging.NOTSET)
        logFileHandlers[logPath] = rotatingHandler
    except (OSError, LogFileAbsentException):
        logger.exception("Logging to file resulted in exception")
        print(_("Warning: could not create log file for mssql-conf at %s.") % logPath)

def formatFileMode(mode):
    """ Format the file mode ``mode`` in octal, left-padded with zeroes to at least 3 digits.

        For example, 0o660 -> "660" and 0o60 -> "060".
    """
    # "03o" zero-pads on the left. The previous "0<3o" left-aligned the
    # digits and padded on the right, and it also ignored ``mode``.
    #
    return "{0:03o}".format(mode)

class LogFileAbsentException(Exception):
    """ Raised when the mssql-conf log file (or its directory) cannot be ensured to exist.
    """

    def __init__(self, *args, **kwargs):
        """ Forward every positional and keyword argument unchanged to Exception.
        """
        super().__init__(*args, **kwargs)

def ensureLogFileExists(fullFileName=defaultLogFile):
    """ Make sure the mssql-conf log directory and log file exist.

        The parent of the log directory (e.g. /var/opt/mssql/log) must already
        exist. The log directory and the log file are created when missing,
        with logFileDirectoryPermissions and logFilePermissions respectively.
        An existing directory or file that is a symbolic link is refused, so
        logs are never written through a link planted in the log directory.

        Raises:
            LogFileAbsentException: if any of the above cannot be ensured. Any
            other exception is wrapped in LogFileAbsentException and chained.
    """
    try:
        # mssql-server must be installed because it provides the parent
        # directories of the mssql-conf log file. The check is skipped for a
        # non-default file, which tests use to log to a local path.
        #
        if fullFileName == defaultLogFile and not mssqlconfhelper.checkInstall(runAsRoot=mssqlconfhelper.checkSudo()):
            raise LogFileAbsentException("mssql-server is not correctly installed.")
        directory = os.path.dirname(fullFileName)
        mssqlLogDir = os.path.dirname(directory)

        if not os.path.isdir(mssqlLogDir):
            raise LogFileAbsentException("The parent directory of the mssql-conf logs, {0}, does not exist.".format(mssqlLogDir))
        elif os.path.islink(directory):
            raise LogFileAbsentException("Refusing to use {0} as the log directory because it is a symbolic link.".format(directory))
        elif os.path.exists(directory) and not os.path.isdir(directory):
            raise LogFileAbsentException("Expected {0} to be a directory but it already exists and is NOT a directory.".format(directory))
        elif os.path.islink(fullFileName):
            raise LogFileAbsentException("Refusing to use {0} as the log file because it is a symbolic link.".format(fullFileName))
        elif os.path.exists(fullFileName) and not os.path.isfile(fullFileName):
            raise LogFileAbsentException("Expected {0} to be a log file but it already exists and is NOT a normal file.".format(fullFileName))

        if not os.path.exists(directory):
            # os.mkdir's mode is filtered through the umask, so follow it up
            # with chmod to get exactly the intended permissions. Ownership is
            # handled elsewhere; overly lenient permissions are the security
            # concern here.
            #
            os.mkdir(directory, logFileDirectoryPermissions)
            os.chmod(directory, logFileDirectoryPermissions)

        if not os.path.exists(fullFileName):
            # Path.touch() also combines its mode with the umask; chmod for the
            # same reason as the directory above.
            #
            pathlib.Path(fullFileName).touch(logFilePermissions)
            os.chmod(fullFileName, logFilePermissions)
    except LogFileAbsentException:
        raise
    except Exception as e:
        raise LogFileAbsentException("Unexpected error occurred while ensuring log file's existence.") from e

def _fixLogFileEntryPermissions(dirFd, entryName, displayName, userId, groupId):
    """ Give one entry of the log directory to userId/groupId with logFilePermissions.

        The entry is accessed relative to dirFd without following symlinks.
        Only regular files with a single hard link are modified; anything else
        is skipped with a warning. The log directory is group-writable by the
        mssql user while this script may run as root, so following a planted
        symlink or hard link would hand an arbitrary file to that account.
    """
    import stat

    entryStat = os.stat(entryName, dir_fd=dirFd, follow_symlinks=False)
    if not stat.S_ISREG(entryStat.st_mode) or entryStat.st_nlink != 1:
        logger.warning("Skipping %s because it is not a regular file with a single link.", displayName)
        return

    # O_NOFOLLOW guards against the entry being swapped for a symlink after the
    # stat above. O_NONBLOCK keeps a swapped-in FIFO from blocking the open.
    #
    entryFd = os.open(entryName, os.O_RDONLY | os.O_NOFOLLOW | os.O_NONBLOCK, dir_fd=dirFd)
    try:
        openedStat = os.fstat(entryFd)
        if ((openedStat.st_dev, openedStat.st_ino) != (entryStat.st_dev, entryStat.st_ino)
                or openedStat.st_nlink != 1):
            logger.warning("Skipping %s because it changed while its ownership was being fixed.", displayName)
            return

        logger.info("Trying to set ownership for file %s.", displayName)
        os.fchown(entryFd, userId, groupId)
        logger.info("Successfully set ownership for file.")

        # This number should be printed in octal with 3 digits (filling with 0).
        #
        logger.info("Changing the permissions of the file to %s.", formatFileMode(logFilePermissions))

        os.fchmod(entryFd, logFilePermissions)
        logger.info("Changed the permissions of the file.")
    finally:
        os.close(entryFd)

def fixLogFilePermissions(fullFileName=defaultLogFile, mssqlUser=mssqlconfhelper.mssqlUser):
    """ Try to make mssqlUser the owner and group of the log directory and its files.

        The directory of fullFileName is chowned to the user's UID/GID. Each
        regular file directly inside it is chowned the same way and chmodded to
        logFilePermissions.

        Symlinks are never followed, so links planted in the group-writable log
        directory cannot redirect the change to other files. A missing user or
        directory is logged as a warning and skipped. Any other error is
        reported through mssqlconfhelper.printException() instead of being
        raised.
    """

    directory = os.path.dirname(fullFileName)

    try:
        logger.info("Trying to make owner and group of the log files at %s %s.", directory, mssqlUser)

        try:
            passwdEntry = pwd.getpwnam(mssqlUser)
        except KeyError:
            # Log as a warning and return; this is normal, particularly the first
            # time the script runs.
            #
            logger.warning("User %s does not yet exist so file permissions cannot be given to it.", mssqlUser)
            return

        userId = passwdEntry.pw_uid
        groupId = passwdEntry.pw_gid
        logger.info("Found user ID (%s) and group ID (%s) for user %s.", userId, groupId, mssqlUser)

        # A missing directory means creating it failed earlier, so it is not
        # retried here.
        #
        if not os.path.exists(directory):
            logger.warning("Not changing permissions for log directory %s since directory does not exist.", directory)
            return

        # Open the directory itself (not a symlink to it) and work relative to
        # that descriptor for the rest of the operation.
        #
        dirFd = os.open(directory, os.O_RDONLY | os.O_DIRECTORY | os.O_NOFOLLOW)
        try:
            logger.info("Trying chown the directory.")
            os.fchown(dirFd, userId, groupId)
            logger.info("Successfully chown'd the directory.")

            for entryName in os.listdir(dirFd):
                _fixLogFileEntryPermissions(dirFd, entryName, os.path.join(directory, entryName), userId, groupId)
        finally:
            os.close(dirFd)

        logger.info("Successfully changed ownership for the directory %s and all files in that directory.", directory)
    except Exception:
        mssqlconfhelper.printException(_("Unexpected error ocurred while trying to change permissions for log files at %s.") % directory)

def initializeCommandLine():
    """Build the mssql-conf argument parser and store the parsed result in the global ``args``.

    Each sub-command sets a ``which`` default that names its COMMAND_TABLE
    entry. Running with no arguments at all, or with options but no
    sub-command, prints the help text and exits with the error code.
    """

    global args

    parser = argparse.ArgumentParser(prog='mssql-conf')

    parser.add_argument("-n", "--noprompt", action="store_true", help=_("Does not prompt the user and uses environment variables or defaults."))
    parser.add_argument("-v", "--verbose", action="store_true", help=_("Enables verbose logging for mssql-conf (messages might not be localized)."))
    quietHelp = _("Completely disables the logging of mssql-conf. When this option is not selected, logs are stored in %s.") % defaultLogFile
    parser.add_argument("-q", "--quiet", action="store_true", help=quietHelp)

    sp = parser.add_subparsers(metavar='')

    sp_setup = sp.add_parser('setup', help=_('Initialize and setup Microsoft SQL Server'))
    sp_setup.add_argument("setup_option", metavar="setup_option", help=_("Specify accept-eula to accept the End-User License Agreement"), nargs='?')
    sp_setup.set_defaults(which='setup')

    sp_set = sp.add_parser('set', help=_('Set the value of a setting'))
    sp_set.add_argument('section_name', metavar='SECTION', help=_('The section of the setting'), action='store')
    sp_set.add_argument('setting_name', metavar='SETTING', help=_('The setting name to set'), action='store', nargs='?')
    sp_set.add_argument('setting_value', metavar='VALUE', help=_('Value for the setting'), action='store')
    sp_set.set_defaults(which='set')

    sp_unset = sp.add_parser('unset', help=_('Unset the value of a setting'))
    sp_unset.add_argument('section_name', metavar='SECTION', help=_('The section of the setting'), action='store')
    sp_unset.add_argument('setting_name', metavar='SETTING', help=_('The setting name to unset'), action='store',
                          nargs='?')
    sp_unset.set_defaults(which='unset')

    sp_list = sp.add_parser('list', help=_('List the supported settings'))
    sp_list.set_defaults(which='list')

    sp_get = sp.add_parser('get', help=_('Gets the value of all settings in a section or of an individual setting'))
    sp_get.add_argument('section_name', metavar='SECTION', help=_('The section of the setting'), action='store')
    sp_get.add_argument('setting_name', metavar='SETTING', help=_('Optional: The setting whose value to get'), action='store',
                        nargs='?')
    sp_get.set_defaults(which='get')

    sp_traceflag = sp.add_parser('traceflag', help=_('Enable/disable one or more traceflags'))
    sp_traceflag.add_argument('traceflags_list', metavar='TRACEFLAGS', type=int, nargs='+',
                              help=_('List of traceflags separated by whitespace.'))
    sp_traceflag.add_argument('tf_choice', metavar='on/off', choices=['on', 'off'], help=_('Turn the traceflags on or off.'))
    sp_traceflag.set_defaults(which='traceflag')

    sp_sa_password = sp.add_parser('set-sa-password',
                                   help=_('Set the system administrator (SA) password'))
    sp_sa_password.set_defaults(which='set-sa-password')

    sp_collation = sp.add_parser('set-collation',
                                 help=_('Set the collation of system databases'))
    sp_collation.set_defaults(which='set-collation')

    sp_validate = sp.add_parser('validate',
                                help=_('Validate the configuration file'))
    sp_validate.set_defaults(which='validate')

    sp_set_edition = sp.add_parser('set-edition',
                                   help=_("Set the edition of the SQL Server instance"))
    sp_set_edition.set_defaults(which='set-edition')

    sp_welcome = sp.add_parser('welcome', usage=SUPPRESS)
    sp_welcome.add_argument("welcometype", metavar="welcometype", nargs='?')
    sp_welcome.set_defaults(which='welcome')

    adutilPathHelp = _("Path to adutil binary. See aka.ms/install-adutil for instructions on how to install adutil.")

    sp_validateADConfig = sp.add_parser('validate-ad-config', help=_('Validate configuration for Active Directory Authentication'))
    sp_validateADConfig.set_defaults(which='validate-ad-config')
    sp_validateADConfig.add_argument('keytab', metavar="keytab", help=_("Path to keytab which SQL Server will use"), action="store", nargs="?", default="")
    sp_validateADConfig.add_argument('realm', metavar="realm", help=_("Realm which the host is joined to. Defaults to the default_realm value in /etc/krb5.conf"), action="store", nargs='?', default="")
    sp_validateADConfig.add_argument('--adutil-path', metavar="adutil_path", help=adutilPathHelp, action="store", nargs="?", default="")

    sp_setupADKeytab = sp.add_parser('setup-ad-keytab', help=_('Create a keytab for SQL Server to use to authenticate AD users. Password may be specified interactively (unless noprompt is set) or through the MSSQL_CONF_PASSWORD environment variable.'))
    sp_setupADKeytab.add_argument('keytab', metavar='keytab', help=_('Path to the keytab that will be created'), action='store')
    sp_setupADKeytab.add_argument('user', metavar='user', help=_('AD Managed Service Account (MSA) which owns the SPNs'), action='store', nargs='?', default='')
    sp_setupADKeytab.add_argument('--adutil-path', metavar='adutil_path', help=adutilPathHelp, action='store', nargs='?', default='adutil')
    sp_setupADKeytab.add_argument('--kvno', metavar='kvno', help=_('KVNO to use in the keytab. User must already exist if this option is given'), action='store', nargs='?', default=None, type=int)
    sp_setupADKeytab.add_argument('--use-next-kvno', dest='useNextKvno', action='store_true', help=_('Lookup the current KVNO and use that value plus one for keytab entries'))
    sp_setupADKeytab.set_defaults(which='setup-ad-keytab')

    if (len(sys.argv) == 1):
        parser.print_help()
        exit(mssqlconfhelper.errorExitCode)

    args = parser.parse_args()

    # Sub-commands are optional for argparse, so options alone (e.g.
    # "mssql-conf -n") parse without setting ``which``. Show the help instead
    # of crashing later when the command is dispatched.
    #
    if not hasattr(args, "which"):
        parser.print_help()
        exit(mssqlconfhelper.errorExitCode)

def main():
    """Program entry point.

    Initializes the helper modules, parses the command line, and sets up
    logging. The command runs only for the superuser or a member of the mssql
    group; any other caller gets a message and the error exit code.
    """

    mssqlconfhelper.initialize()
    mssqlsettingsmanager.initialize()

    initializeCommandLine()
    initializeLogging()

    if mssqlconfhelper.checkSudoOrMssql():
        processCommands()
        exit(mssqlconfhelper.successExitCode)

    print(_("This program must be run as superuser or as a user with membership in the mssql"))
    print(_("group."))
    exit(mssqlconfhelper.errorExitCode)

if __name__ == "__main__":
    main()
