"""Base class for probes doing direct LDAP queries."""

import logging
import time

import ldap3
from arcnagios.confargparse import UsageError
from arcnagios import nagutils, vomsutils

class LDAPNagiosPlugin(nagutils.NagiosPlugin, vomsutils.NagiosPluginVomsMixin):
    """Base class for Nagios probes which query an LDAP server directly.

    Provides command-line options for selecting the server (``-H``/``-p``
    or ``--ldap-uri``), the base DN and an overall timeout; a connection
    opened by :meth:`prepare_check`; and :meth:`ldap_search`, which turns
    timeouts and ldap3 exceptions into ``nagutils.ServiceCritical``.
    """
    # pylint: disable=abstract-method,super-init-not-called

    def __init__(self, default_uri = None, default_basedn = None, **kwargs):
        """Register the LDAP options on the plugin's argument parser.

        :param default_uri: default value for ``--ldap-uri``.
        :param default_basedn: default value for ``--ldap-basedn``.
        :param kwargs: forwarded to ``nagutils.NagiosPlugin.__init__``.  If
            a ``use_port`` key is present, ``-p`` is not registered here.
        """
        nagutils.NagiosPlugin.__init__(self, **kwargs)
        ap = self.argparser.add_argument_group('LDAP Options')
        ap.add_argument('-H', dest = 'host',
                help = 'Host to query.  This will be used for the LDAP '
                       'connection if --ldap-uri is not specified.')
        if 'use_port' not in kwargs: # compat
            # NOTE(review): when use_port is given, parse_args still reads
            # opts.port; presumably the base class provides it -- confirm.
            ap.add_argument('-p', dest = 'port', type = int, default = 2135,
                    help = 'The LDAP port to use if --ldap-uri was not given.  '
                           'The default is 2135.')
        ap.add_argument('--ldap-basedn', dest = 'ldap_basedn',
                default = default_basedn,
                help = 'Base DN to query if non-standard.')
        ap.add_argument('--ldap-uri', dest = 'ldap_uri',
                default = default_uri,
                help = 'LDAP URI of the infosystem to query.')
        ap.add_argument('-t', '--timeout', dest = 'timeout',
                type = int, default = 300,
                help = 'Overall timeout of the probe.')
        # Absolute deadline (time.time() seconds), set by prepare_check().
        self._time_limit = None
        # True when the URI scheme is ldaps://, i.e. connect over SSL.
        self._ldap_use_tls = False
        # ldap3 Server and Connection objects, created by prepare_check().
        self._ldap_server = None
        self._ldap_conn = None

    def parse_args(self, args):
        """Parse *args* and derive the LDAP host, port, URI and SSL flag.

        If ``--ldap-uri`` is given, host, port and SSL usage are taken from
        it, overriding ``-H``/``-p``.  Otherwise an ``ldap://`` URI is built
        from ``-H`` and ``-p``.

        :raises UsageError: if neither ``--ldap-uri`` nor ``-H`` is given,
            or if ``--ldap-uri`` is not a usable ldap:// or ldaps:// URI.
        """
        nagutils.NagiosPlugin.parse_args(self, args)
        if self.opts.ldap_uri:
            try:
                uri_dict = ldap3.utils.uri.parse_uri(self.opts.ldap_uri)
            except ValueError:
                # Raised by urllib for e.g. a non-numeric port.
                uri_dict = None
            # parse_uri returns None for schemes other than ldap/ldaps; an
            # empty host would make ldap3.Server fail later on.
            if uri_dict is None or not uri_dict.get('host'):
                raise UsageError(
                        'Invalid LDAP URI %s.' % self.opts.ldap_uri)
            self.opts.host = uri_dict['host']
            self.opts.port = uri_dict['port']
            self._ldap_use_tls = uri_dict['ssl']
        else:
            if not self.opts.host:
                raise UsageError('Either --ldap-uri or -H must be specified.')
            self.opts.ldap_uri = 'ldap://%s:%d'%(self.opts.host, self.opts.port)

    @property
    def time_left(self):
        """Seconds remaining until the overall timeout (negative if passed).

        Only valid once prepare_check() has set the deadline.
        """
        assert self._time_limit is not None
        return self._time_limit - time.time()

    def prepare_check(self):
        """Connect and bind to the LDAP server and fetch its server info.

        Also starts the overall ``--timeout`` clock read by ``time_left``.

        :raises nagutils.ServiceCritical: if creating the server object,
            connecting/binding or fetching server info raises an ldap3
            exception.
        """
        self.log.debug('Using LDAP URI %s.', self.opts.ldap_uri)
        # Start the clock before connecting so --timeout bounds the whole
        # probe, including connection setup.
        self._time_limit = time.time() + self.opts.timeout
        try:
            # ldaps:// is selected with use_ssl; ldap3's ``tls`` argument
            # expects a Tls configuration object, not a boolean.
            self._ldap_server = ldap3.Server(
                    self.opts.host, self.opts.port,
                    use_ssl=self._ldap_use_tls,
                    get_info=ldap3.ALL,
                    connect_timeout=self.opts.timeout)
            self._ldap_conn = \
                    ldap3.Connection(self._ldap_server, auto_bind=True)
            self._ldap_server.get_info_from_server(self._ldap_conn)
        except ldap3.core.exceptions.LDAPException as exn:
            raise nagutils.ServiceCritical(
                    'Failed to connect to LDAP server: %s' % exn)

    def ldap_search(self, basedn, filterstr, search_scope = 'SUBTREE',
            attrlist = ldap3.ALL_ATTRIBUTES):
        """Customized LDAP search with timeout and Nagios error reporting.

        :param basedn: DN to start the search from.
        :param filterstr: LDAP filter string.
        :param search_scope: passed to ldap3 as ``search_scope``.
        :param attrlist: attributes to request; all user attributes by
            default.
        :returns: the connection's ``entries`` after the search (possibly
            empty).
        :raises nagutils.ServiceCritical: if the overall timeout has
            already expired or ldap3 raises an exception.
        """
        self.log.debug('Searching %s.', basedn)
        assert self._ldap_conn is not None
        time_left = self.time_left
        if time_left <= 0:
            raise nagutils.ServiceCritical('Timeout before LDAP search.')
        # The LDAP timeLimit is a whole number of seconds, and 0 means "no
        # limit", so round down but never below one second.
        time_limit = max(1, int(time_left))
        try:
            if not self._ldap_conn.search(
                    basedn,
                    filterstr,
                    search_scope = search_scope,
                    attributes = attrlist,
                    time_limit = time_limit):
                self.log.debug('LDAP search returned nothing.')
            return self._ldap_conn.entries
        except ldap3.core.exceptions.LDAPException as exn:
            self.log.error('LDAP details: basedn = %s, filter = %s, scope = %s',
                           basedn, filterstr, search_scope)
            self.log.error('LDAP error: %s', exn)
            raise nagutils.ServiceCritical('LDAP Search failed.')

    def debug_dump_obj(self, obj, name):
        """Log ``vars(obj)`` under the heading *name* if debugging is enabled.

        List values longer than four elements are shortened to their first
        four elements followed by ``'...'``.
        """
        # Skip the work entirely unless DEBUG records would be emitted.
        if not self.log.isEnabledFor(logging.DEBUG):
            return
        self.log.debug('Dump of %s:', name)
        for k, v in vars(obj).items():
            if isinstance(v, list) and len(v) > 4:
                v = v[0:4] + ['...']
            self.log.debug('  %s: %r', k, v)
