Skip to content

Instantly share code, notes, and snippets.

@datinho
Created August 22, 2018 14:38
Show Gist options
  • Save datinho/d73b7d2ae22872125972b236eee126a3 to your computer and use it in GitHub Desktop.
Save datinho/d73b7d2ae22872125972b236eee126a3 to your computer and use it in GitHub Desktop.
OPT_REFERRALS for Active Directory search issue and multiple search reference issue with empty name_attribute
from __future__ import absolute_import, unicode_literals
from codecs import open
from collections import namedtuple
import logging
import os
from itertools import product
import ldap
# On CentOS 6, python-ldap does not manage SCOPE_SUBORDINATE
try:
from ldap import SCOPE_SUBORDINATE
except ImportError: # pragma: nocover
SCOPE_SUBORDINATE = None
from ldap.dn import str2dn as native_str2dn
from ldap import sasl
from .utils import decode_value, encode_value, PY2, uniq, iter_format_fields
from .utils import Settable
logger = logging.getLogger(__name__)

# Re-exported so that callers can catch LDAP failures without importing
# python-ldap themselves.
LDAPError = ldap.LDAPError

# Scope names mapped to the matching python-ldap constants.
SCOPES = dict([
    ('base', ldap.SCOPE_BASE),
    ('one', ldap.SCOPE_ONELEVEL),
    ('sub', ldap.SCOPE_SUBTREE),
])
if SCOPE_SUBORDINATE:
    # Only registered when the import of SCOPE_SUBORDINATE succeeded.
    SCOPES['children'] = SCOPE_SUBORDINATE

# Reverse lookup: python-ldap constant back to its scope name.
SCOPES_STR = dict(zip(SCOPES.values(), SCOPES.keys()))
def parse_scope(raw):
    """Translate ``raw`` into a python-ldap scope constant.

    A value that is already a key of ``SCOPES_STR`` (a scope constant) is
    returned untouched. Otherwise ``raw`` is looked up as a name in
    ``SCOPES``.

    Raises ``ValueError`` when ``raw`` is neither.
    """
    if raw not in SCOPES_STR:
        if raw not in SCOPES:
            raise ValueError("Unknown scope %r" % (raw,))
        return SCOPES[raw]
    return raw
def str2dn(value):
    """Parse the DN string ``value`` with python-ldap.

    Returns a list of RDNs, each RDN being a list of
    ``(type, value, flags)`` tuples whose type is lowercased.

    Raises ``ValueError`` when python-ldap reports a decoding error.
    """
    try:
        if not PY2:  # pragma: nocover_py2
            parsed = native_str2dn(value)
        else:  # pragma: nocover_py3
            # Workaround for buggy unicode management in python-ldap on
            # Python 2: hand it UTF-8 bytes and decode what comes back. This
            # is not necessary on Python 3.
            parsed = decode_value(native_str2dn(value.encode('utf-8')))
    except ldap.DECODING_ERROR:
        raise ValueError("Can't parse DN '%s'" % (value,))

    rdns = []
    for rdn in parsed:
        rdns.append([
            (type_.lower(), name, flags)
            for type_, name, flags in rdn
        ])
    return rdns
def expand_attributes(entry, formats):
    """Yield every string obtained by filling ``formats`` from ``entry``.

    When ``entry`` is ``None`` the formats are yielded verbatim. Otherwise
    each field referenced by the formats is read from the entry with
    ``get_attribute`` and every format is rendered once per combination of
    the values of its fields (cartesian product).

    A dotted field such as ``member.cn`` is exposed to ``str.format`` as an
    object named ``member`` carrying a ``cn`` attribute.
    """
    if entry is not None:
        values_by_field = dict()
        for field in iter_format_fields(formats):
            found = list(get_attribute(entry, field))
            head, dot, tail = field.partition('.')
            if dot:
                # Wrap each value so that '{head.tail}' resolves on format.
                found = [Settable(**{tail: item}) for item in found]
            values_by_field[head] = found

        for template in formats:
            names = list(iter_format_fields([template], split=True))
            candidates = [values_by_field[name] for name in names]
            for combo in product(*candidates):
                yield template.format(**dict(zip(names, combo)))
    else:
        for template in formats:
            yield template
def get_attribute(entry, attribute):
    """Yield the values of ``attribute`` found in ``entry``.

    ``entry`` is a ``(dn, attributes)`` pair; the attribute name is
    lowercased before lookup. With a dotted name such as ``member.cn``, each
    value of ``member`` is parsed as a DN and the first value of its ``cn``
    component is yielded instead of the raw value.

    Raises ``ValueError`` for an unknown attribute, an unparsable DN or a DN
    lacking the requested component.
    """
    _dn, attributes = entry
    parts = attribute.lower().split('.')
    head, rest = parts[0], parts[1:]
    if head not in attributes:
        raise ValueError("Unknown attribute %r" % (head,))

    for raw in attributes[head]:
        if not rest:
            yield raw
            continue

        try:
            dn = str2dn(raw)
        except ValueError:
            msg = "Can't parse DN from attribute %s=%s" % (
                attribute, raw)
            raise ValueError(msg)

        # Group the DN values by attribute type. The trailing comma in the
        # loop target unpacks each RDN as a single-item sequence.
        names_by_type = dict()
        for (type_, name, _flags), in dn:
            names_by_type.setdefault(type_.lower(), []).append(name)

        try:
            picked = names_by_type[rest[0]][0]
        except KeyError:
            raise ValueError("Unknown attribute %s" % (rest[0],))
        yield picked
def lower_attributes(entry):
    """Return the ``(dn, attributes)`` entry with lowercased attribute names.

    The DN and the attribute values are passed through unchanged.
    """
    dn, attributes = entry
    lowered = dict()
    for key, values in attributes.items():
        lowered[key.lower()] = values
    return dn, lowered
class EncodedParamsCallable(object):  # pragma: nocover_py3
    """Wrap a callable that does not accept unicode.

    Every positional and keyword argument goes through ``encode_value``
    before the call, and the result goes through ``decode_value``.
    """

    def __init__(self, callable_):
        self.callable_ = callable_

    def __call__(self, *a, **kw):
        encoded_args, encoded_kwargs = encode_value((a, kw))
        result = self.callable_(*encoded_args, **encoded_kwargs)
        return decode_value(result)
class UnicodeModeLDAPObject(object):  # pragma: nocover_py3
    """Simulate the unicode mode of Python 3 on top of python-ldap.

    This is not a Python 2 issue but rather python-ldap not managing
    strings; this proxy does it on its behalf. Every attribute fetched from
    the wrapped object is returned wrapped in ``EncodedParamsCallable``.
    """

    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __getattr__(self, name):
        target = getattr(self.wrapped, name)
        return EncodedParamsCallable(target)
class LDAPLogger(object):
    """Proxy of an LDAP connection logging equivalent OpenLDAP commands.

    Binds and searches are logged at debug level as ``ldapwhoami`` and
    ``ldapsearch`` command lines. Any other attribute is delegated to the
    wrapped object.
    """

    def __init__(self, wrapped):
        self.wrapped = wrapped
        # Command-line flags of the last bind, reused when logging searches.
        self.connect_opts = ''

    def __getattr__(self, name):
        # Only reached for names not defined on this proxy.
        return getattr(self.wrapped, name)

    def search_s(self, base, scope, filter, attributes):
        """Log the search as an ``ldapsearch`` command, then run it."""
        scope_name = SCOPES_STR[scope]
        requested = ' '.join(attributes or [])
        logger.debug(
            "Doing: ldapsearch%s -b %s -s %s '%s' %s",
            self.connect_opts, base, scope_name, filter, requested,
        )
        return self.wrapped.search_s(base, scope, filter, attributes)

    def simple_bind_s(self, binddn, password):
        """Record simple-bind flags, log them, then bind."""
        flags = [' -x']
        if binddn:
            flags.append(' -D %s' % (binddn,))
        if password:
            # The password itself is never logged, only that one is used.
            flags.append(' -W')
        self.connect_opts = ''.join(flags)
        self.log_connect()
        return self.wrapped.simple_bind_s(binddn, password)

    def sasl_interactive_bind_s(self, who, auth, *a, **kw):
        """Record SASL bind flags, log them, then bind."""
        flags = ' -Y %s' % (auth.mech.decode('ascii'),)
        callbacks = auth.cb_value_dict
        if sasl.CB_AUTHNAME in callbacks:
            flags += ' -U %s' % (
                callbacks[sasl.CB_AUTHNAME],)
        if sasl.CB_PASS in callbacks:
            flags += ' -W'
        self.connect_opts = flags
        self.log_connect()
        return self.wrapped.sasl_interactive_bind_s(who, auth, *a, **kw)

    def log_connect(self):
        """Log the current bind as an ``ldapwhoami`` command."""
        logger.debug("Doing: ldapwhoami%s", self.connect_opts)
def connect(**kw):
    """Open and bind an LDAP connection; return the logging proxy.

    Keyword arguments are forwarded to ``gather_options``. A SASL DIGEST-MD5
    bind is attempted when a ``USER`` option is set, a simple bind with
    ``BINDDN`` and ``PASSWORD`` otherwise.
    """
    # Sources order, see ldap.conf(3)

    # variable $LDAPNOINIT, and if that is not set:
    # system file /etc/ldap/ldap.conf,
    # user files $HOME/ldaprc, $HOME/.ldaprc, ./ldaprc,
    # system file $LDAPCONF,
    # user files $HOME/$LDAPRC, $HOME/.$LDAPRC, ./$LDAPRC,
    # user files <ldap2pg.yml>...
    # variables $LDAP<uppercase option name>.
    #
    # Extra variable LDAPPASSWORD is supported.
    options = gather_options(**kw)
    uri = options['URI']

    logger.debug("Connecting to LDAP server %s.", uri)
    conn = ldap.initialize(uri)
    if PY2:  # pragma: nocover_py3
        conn = UnicodeModeLDAPObject(conn)
    conn = LDAPLogger(conn)

    # Referral chasing is switched off before binding; the logged link
    # describes the issue this works around.
    logger.debug("HOTFIX https://stackoverflow.com/questions/18793040/python-ldap-not-able-to-bind-successfully")
    conn.set_option(ldap.OPT_REFERRALS, 0)

    if not options.get('USER'):
        logger.debug("Trying simple bind.")
        conn.simple_bind_s(options['BINDDN'], options['PASSWORD'])
        return conn

    logger.debug("Trying SASL DIGEST-MD5 auth.")
    credentials = {
        sasl.CB_AUTHNAME: options['USER'],
        sasl.CB_PASS: options['PASSWORD'],
    }
    auth = sasl.sasl(credentials, 'DIGEST-MD5')
    conn.sasl_interactive_bind_s("", auth)
    return conn
class Options(dict):
    """Dictionary of LDAP options, keyed by uppercase option name.

    Raw string values are converted by the ``parse_<option>`` attribute
    matching the lowercased option name.
    """

    def set_raw(self, option, raw):
        """Parse ``raw`` and store it under the uppercased ``option``.

        Returns the parsed value, or ``None`` when no parser exists for the
        option, in which case nothing is stored.
        """
        option = option.upper()
        parser = getattr(self, 'parse_' + option.lower(), None)
        if parser is None:
            logger.debug("Unknown option %s", option)
            return None

        parsed = parser(raw)
        self[option] = parsed
        return parsed

    def _parse_raw(self, value):
        # Identity parser for options kept as plain strings.
        return value

    parse_uri = _parse_raw
    parse_host = _parse_raw
    parse_port = int
    parse_binddn = _parse_raw
    parse_user = _parse_raw
    parse_password = _parse_raw
def gather_options(environ=None, **kw):
    """Build the LDAP connection options from files, environment and kwargs.

    ``environ`` is a mapping of environment variables and defaults to
    ``os.environ``. Only ``LDAP*`` variables are considered, ``LDAP2PG*``
    ones excluded, with the ``LDAP`` prefix stripped.

    Unless ``LDAPNOINIT`` is defined, sources are applied in this order, the
    later overriding the earlier:

    - ``/etc/ldap/ldap.conf`` and the ``ldaprc`` user files,
    - the file named by ``$LDAPCONF`` and the user files named by
      ``$LDAPRC``,
    - the ``LDAP<OPTION>`` environment variables.

    Finally, truthy keyword arguments matching a known option override
    everything. ``URI`` is derived from ``HOST`` and ``PORT`` when still
    empty.

    Returns the populated ``Options`` dictionary.
    """
    options = Options(
        URI='',
        HOST='',
        PORT=389,
        BINDDN='',
        USER=None,
        PASSWORD='',
    )

    # Test against None so that an explicitly empty mapping is honoured
    # instead of silently falling back to the process environment.
    if environ is None:
        environ = os.environ
    environ = dict([
        (k[4:], v.decode('utf-8') if hasattr(v, 'decode') else v)
        for k, v in environ.items()
        if k.startswith('LDAP') and not k.startswith('LDAP2PG')
    ])

    if 'NOINIT' in environ:
        logger.debug("LDAPNOINIT defined. Disabled ldap.conf loading.")
    else:
        for e in read_files(conf='/etc/ldap/ldap.conf', rc='ldaprc'):
            logger.debug('Read %s from %s.', e.option, e.filename)
            options.set_raw(e.option, e.value)
        # $LDAPCONF and $LDAPRC are environment variables. Options never
        # stores CONF or RC (it has no parser for them), so they must be
        # read from the environment to be taken into account.
        for e in read_files(conf=environ.get('CONF'), rc=environ.get('RC')):
            logger.debug('Read %s from %s.', e.option, e.filename)
            options.set_raw(e.option, e.value)
        for option, value in environ.items():
            logger.debug('Read %s from env.', option)
            options.set_raw(option, value)

    options.update(dict(
        (k.upper(), v)
        for k, v in kw.items()
        if k.upper() in options and v
    ))

    if not options['URI']:
        options['URI'] = 'ldap://%(HOST)s:%(PORT)s' % options

    return options
def read_files(conf, rc):
    """Yield the entries of every readable ldap.conf-style candidate file.

    Candidates are ``conf`` itself, then ``~/<rc>``, ``~/.<rc>`` and
    ``<rc>``; a falsy ``conf`` or ``rc`` contributes nothing. Paths are
    expanded and resolved, then deduplicated with ``uniq``. Files that can't
    be opened or read are logged at debug level and skipped.
    """
    paths = [conf] if conf else []
    if rc:
        paths += ['~/%s' % rc, '~/.%s' % rc, rc]

    resolved = [
        os.path.realpath(os.path.expanduser(path))
        for path in paths
    ]

    for candidate in uniq(resolved):
        try:
            with open(candidate, 'r', encoding='utf-8') as fo:
                logger.debug('Found rcfile %s.', candidate)
                for entry in parserc(fo):
                    yield entry
        except (IOError, OSError) as e:
            logger.debug("Ignoring: %s", e)
# One option read from an rc file, along with where it was found.
RCEntry = namedtuple('RCEntry', ('filename', 'lineno', 'option', 'value'))


def parserc(fo):
    """Yield an ``RCEntry`` for each option line of the file object ``fo``.

    Blank lines and lines starting with ``#`` are skipped. The option name is
    the first whitespace-delimited word, the value is the stripped remainder
    of the line. A line holding only an option name yields an empty value
    rather than failing to unpack.

    ``filename`` is taken from ``fo.name`` and defaults to ``<stdin>``;
    ``lineno`` is 1-based.
    """
    filename = getattr(fo, 'name', '<stdin>')
    for lineno, line in enumerate(fo, 1):
        line = line.strip()
        if not line or line.startswith('#'):
            continue

        parts = line.split(None, 1)
        yield RCEntry(
            filename=filename,
            lineno=lineno,
            option=parts[0],
            value=parts[1] if len(parts) > 1 else '',
        )
from __future__ import unicode_literals
from fnmatch import fnmatch
import logging
from .ldap import LDAPError, expand_attributes, lower_attributes
from .privilege import Grant
from .privilege import Acl
from .role import (
Role,
RoleOptions,
RoleSet,
)
from .utils import UserError, decode_value, match
from .psql import expandqueries
logger = logging.getLogger(__name__)
class SyncManager(object):
    """Drive the synchronization of Postgres roles and grants from LDAP.

    The manager queries LDAP through ``ldapconn``, inspects the Postgres
    cluster through ``inspector`` and runs the resulting queries through
    ``psql``.
    """

    def __init__(
            self, ldapconn=None, psql=None, inspector=None,
            privileges=None, privilege_aliases=None, blacklist=None,
    ):
        self.ldapconn = ldapconn
        self.psql = psql
        self.inspector = inspector
        self.privileges = privileges or {}
        self.privilege_aliases = privilege_aliases or {}
        self._blacklist = blacklist

    def query_ldap(self, base, filter, attributes, scope):
        """Search LDAP and return decoded entries with lowercased attributes.

        Results without a DN are discarded. Raises ``UserError`` when the
        search fails or when an entry can't be decoded.
        """
        try:
            raw_entries = self.ldapconn.search_s(
                base, scope, filter, attributes,
            )
        except LDAPError as e:
            message = "Failed to query LDAP: %s." % (e,)
            raise UserError(message)

        logger.debug('Got %d entries from LDAP.', len(raw_entries))
        entries = []
        # The loop variable is named apart from the `attributes` parameter to
        # avoid shadowing the list of requested attributes.
        for dn, raw_attributes in raw_entries:
            try:
                entry = decode_value((dn, raw_attributes))
            except UnicodeDecodeError as e:
                message = "Failed to decode data from %r: %s." % (dn, e,)
                raise UserError(message)

            logger.debug(
                ">> Entry [[ %s ]][[ %s ]][[ %s ]]",
                entry, dn, raw_attributes)
            if not dn:
                # Presumably a search reference returned now that referral
                # chasing is disabled -- TODO confirm.
                continue
            logger.debug(
                "<< Entry [[ %s ]][[ %s ]][[ %s ]]",
                entry, dn, raw_attributes)

            entries.append(lower_attributes(entry))

        return entries

    def process_ldap_entry(self, entry, names, **kw):
        """Yield one ``Role`` per name expanded from ``entry``.

        ``members`` and ``parents`` formats found in ``kw`` are expanded
        against the same entry and lowercased; ``options`` is handed to the
        role untouched.
        """
        members = [
            m.lower() for m in
            expand_attributes(entry, kw.get('members', []))
        ]
        parents = [
            p.lower() for p in
            expand_attributes(entry, kw.get('parents', []))
        ]

        for name in expand_attributes(entry, names):
            # A name equal to one of the formats was not expanded from the
            # entry, so it is reported as coming from YAML.
            log_source = " from " + ("YAML" if name in names else entry[0])

            name = name.lower()
            logger.debug("Found role %s%s.", name, log_source)
            if members:
                logger.debug(
                    "Role %s must have members %s.", name, ', '.join(members),
                )
            if parents:
                logger.debug(
                    "Role %s is member of %s.", name, ', '.join(parents))
            role = Role(
                name=name,
                # Copies, so that roles do not share the same lists.
                members=members[:],
                options=kw.get('options', {}),
                parents=parents[:],
            )

            yield role

    def apply_role_rules(self, rules, entries):
        """Yield the roles produced by each rule applied to each entry.

        Raises ``UserError`` when an entry can't be processed.
        """
        for rule in rules:
            for entry in entries:
                try:
                    for role in self.process_ldap_entry(entry=entry, **rule):
                        yield role
                except ValueError as e:
                    msg = "Failed to process %.48s: %s" % (entry[0], e,)
                    raise UserError(msg)

    def apply_grant_rules(self, grant, entries=()):
        """Yield a ``Grant`` per role expanded from each rule and entry.

        ``grant`` is the list of grant rules. Roles not matching the
        ``role_match`` pattern of the rule, if any, are skipped. Raises
        ``UserError`` when role names can't be expanded from an entry.
        """
        for rule in grant:
            privilege = rule.get('privilege')

            databases = rule.get('databases', '__all__')
            if databases == '__all__':
                databases = Grant.ALL_DATABASES

            schemas = rule.get('schemas', '__all__')
            if schemas in (None, '__all__', '__any__'):
                schemas = None

            pattern = rule.get('role_match')

            for entry in entries:
                try:
                    roles = list(expand_attributes(entry, rule['roles']))
                except ValueError as e:
                    msg = "Failed to process %.32s: %s" % (entry, e,)
                    raise UserError(msg)

                for role in roles:
                    role = role.lower()
                    if pattern and not fnmatch(role, pattern):
                        logger.debug(
                            "Don't grant %s to %s not matching %s",
                            privilege, role, pattern,
                        )
                        continue
                    yield Grant(privilege, databases, schemas, role)

    def inspect_ldap(self, syncmap):
        """Compute the wanted roles and grants from ``syncmap``.

        Mappings with an ``ldap`` key are resolved against the directory; the
        others are processed once with no entry. Returns a ``(RoleSet, Acl)``
        pair. Raises ``UserError`` when a role can't be merged with a
        previous definition of itself.
        """
        ldaproles = {}
        ldapacl = Acl()
        for mapping in syncmap:
            if 'ldap' in mapping:
                logger.info(
                    "Querying LDAP %.24s... %.12s...",
                    mapping['ldap']['base'], mapping['ldap']['filter'])
                entries = self.query_ldap(**mapping['ldap'])
                log_source = 'in LDAP'
            else:
                entries = [None]
                log_source = 'from YAML'

            for role in self.apply_role_rules(mapping['roles'], entries):
                if role in ldaproles:
                    try:
                        role.merge(ldaproles[role])
                    except ValueError:
                        msg = "Role %s redefined with different options." % (
                            role,)
                        raise UserError(msg)
                ldaproles[role] = role

            grant_rules = mapping.get('grant', [])
            for grant in self.apply_grant_rules(grant_rules, entries):
                logger.debug("Found GRANT %s %s.", grant, log_source)
                ldapacl.add(grant)

        # Lazy apply of role options defaults
        roleset = RoleSet()
        for role in ldaproles.values():
            role.options.fill_with_defaults()
            roleset.add(role)

        return roleset, ldapacl

    def postprocess_acl(self, acl, schemas):
        """Return a new ``Acl`` holding the grants expanded from ``acl``.

        Raises ``UserError`` when expanding or adding a grant raises
        ``ValueError``.
        """
        expanded_grants = acl.expandgrants(
            aliases=self.privilege_aliases,
            privileges=self.privileges,
            databases=schemas,
        )

        acl = Acl()
        try:
            for grant in expanded_grants:
                acl.add(grant)
        except ValueError as e:
            raise UserError(e)

        return acl

    def sync(self, syncmap):
        """Synchronize roles, then grants; return the number of queries.

        Grants are only processed when privileges are defined. Raises
        ``UserError`` when role membership can't be resolved.
        """
        logger.info("Inspecting roles in Postgres cluster...")
        me, issuper = self.inspector.fetch_me()
        # Make sure the current user ends up in the roles blacklist.
        if not match(me, self.inspector.roles_blacklist):
            self.inspector.roles_blacklist.append(me)

        if not issuper:
            logger.warning("Running ldap2pg as non superuser.")
            RoleOptions.filter_super_columns()

        databases, pgallroles, pgmanagedroles = self.inspector.fetch_roles()
        pgallroles, pgmanagedroles = self.inspector.filter_roles(
            pgallroles, pgmanagedroles)

        logger.debug("Postgres inspection done.")
        ldaproles, ldapacl = self.inspect_ldap(syncmap)
        logger.debug("LDAP inspection completed. Post processing.")
        try:
            ldaproles.resolve_membership()
        except ValueError as e:
            raise UserError(str(e))

        count = 0
        count += self.psql.run_queries(expandqueries(
            pgmanagedroles.diff(other=ldaproles, available=pgallroles),
            databases=databases))
        if self.privileges:
            logger.info("Inspecting GRANTs in Postgres cluster...")
            if self.psql.dry and count:
                logger.warning(
                    "In dry mode, some owners aren't created, "
                    "their default privileges can't be determined.")
            schemas = self.inspector.fetch_schemas(databases, ldaproles)
            pgacl = self.inspector.fetch_grants(schemas, pgmanagedroles)
            ldapacl = self.postprocess_acl(ldapacl, schemas)
            count += self.psql.run_queries(expandqueries(
                pgacl.diff(ldapacl, self.privileges),
                databases=schemas))
        else:
            logger.debug("No privileges defined. Skipping GRANT and REVOKE.")

        if count:
            # If log does not fit in 24 row screen, we should tell how much is
            # to be done.
            level = logger.debug if count < 20 else logger.info
            level("Generated %d querie(s).", count)
        else:
            logger.info("Nothing to do.")

        return count
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment