#!/usr/bin/env python3
#
# Verifies that all partitions on a Microsoft Windows Active Directory domain controller are
# replicating correctly. The query is performed via LDAP. There is no dependency on a
# Windows-specific API; the pure-Python ldap3 library is used to ensure wide-ranging platform
# support.
#
# Author: Ondřej Hošek <ondrej.hosek@tuwien.ac.at>
#
# To the extent possible under law, the person who associated CC0 with this work has waived all
# copyright and related or neighboring rights to this work.
# https://creativecommons.org/publicdomain/zero/1.0/
#
import argparse
import datetime
import struct
import sys
import traceback
from typing import ByteString, Dict, Iterable, List, NamedTuple, Optional, Set, Union
from uuid import UUID
import ldap3


# Exit statuses of this check. They are numerically ordered so that run() can report the most
# severe one with max(); note that UNKNOWN sorts above CRITICAL.
STATUS_OK, STATUS_WARN, STATUS_CRIT, STATUS_UNKN = range(4)

# Maps the values accepted by --failure-state/-f to the status they select.
TEXT_TO_STATUS = dict(
    ok=STATUS_OK,
    warn=STATUS_WARN,
    crit=STATUS_CRIT,
    unkn=STATUS_UNKN,
)

# Marker appended to detail lines to flag their status (empty for OK).
STATUS_TO_SUFFIX = dict(zip(
    (STATUS_OK, STATUS_WARN, STATUS_CRIT, STATUS_UNKN),
    ("", " (!)", " (!!!)", " (???)"),
))

# Fixed-size header of a REPS_FROM/REPS_TO attribute value (little-endian, no padding); the
# fields are listed in order on the ReplicationStatus class.
REPS_FROM_TO_STRUCT = struct.Struct("<LLLLQQLLLL84sL24s16s16s16sLL")
# The five little-endian 32-bit values at the start of a DSA_RPC_INST structure.
DCA_RPC_INST_OFFSETS_STRUCT = struct.Struct("<LLLLL")
# Origin of the timestamps stored in REPS_FROM/REPS_TO values: 1601-01-01 00:00 UTC.
WINDOWS_EPOCH = datetime.datetime(1601, 1, 1, tzinfo=datetime.timezone.utc)


def read_nul_terminated_utf16_le_string(bs: Iterable[int]) -> str:
    """
    Reads a NUL-terminated string encoded as UTF-16 little-endian from the given iterable and
    returns it as a decoded string.

    Bytes are consumed two at a time (one UTF-16 code unit). Reading stops at the first code unit
    equal to 0x0000. That terminator is not part of the result, and any bytes after it are ignored.

    :param bs: an iterable of byte values, e.g. ``bytes`` or a ``memoryview`` over bytes
    :raises ValueError: if the input ends before a terminating NUL code unit is found, including
        when it ends in the middle of a code unit
    """
    iterator = iter(bs)
    code_units = bytearray()

    for low in iterator:
        # pair each low byte with its high byte; a missing high byte means truncated input
        high = next(iterator, None)
        if high is None:
            raise ValueError("UTF-16LE string ends in the middle of a code unit")
        if low == 0x00 and high == 0x00:
            # terminating NUL
            return code_units.decode("utf-16le")
        code_units.append(low)
        code_units.append(high)

    # the input was exhausted without a terminator; report it explicitly instead of letting a
    # bare StopIteration escape
    raise ValueError("UTF-16LE string is not NUL-terminated")


class DcaRpcInst(NamedTuple):
    """
    A Python representation of the DSA_RPC_INST Active Directory data structure.

    It is used for the "other DRA" of REPS_FROM/REPS_TO values whose version is not 1. The "Dca"
    in the class name corresponds to "DSA" in the specification.

    https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-drsr/88a39619-6dbe-4ba1-8435-5966c1a490a7
    """
    server: Optional[str]
    annotation: Optional[str]
    network_address: Optional[str]
    guid: Optional[UUID]

    @staticmethod
    def from_bytes(bs: ByteString) -> "DcaRpcInst":
        """
        Decode a DSA_RPC_INST from its binary form.

        The structure starts with five little-endian 32-bit values. The first (cb in the
        specification) is not used here. The other four are the offsets of the server name,
        annotation, network address and GUID, relative to the start of ``bs``.

        An offset of 0 marks an absent field, which becomes None. Strings are NUL-terminated
        UTF-16LE, and the GUID is 16 bytes in little-endian UUID layout.
        """
        view = memoryview(bs)
        header = DCA_RPC_INST_OFFSETS_STRUCT.unpack(view[:DCA_RPC_INST_OFFSETS_STRUCT.size])
        _cb, server_at, annotation_at, address_at, guid_at = header

        def optional_string(offset: int) -> Optional[str]:
            # the offsets are unsigned, so "not positive" means exactly zero, i.e. absent
            if not offset:
                return None
            return read_nul_terminated_utf16_le_string(view[offset:])

        server = optional_string(server_at)
        annotation = optional_string(annotation_at)
        network_address = optional_string(address_at)

        guid: Optional[UUID] = None
        if guid_at:
            guid = UUID(bytes_le=view[guid_at:guid_at + 16].tobytes())

        return DcaRpcInst(
            server=server,
            annotation=annotation,
            network_address=network_address,
            guid=guid,
        )


class ReplicationStatus(NamedTuple):
    """
    A Python representation of the REPS_FROM and REPS_TO Active Directory data structures.

    https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-drsr/f8e930ea-d847-4585-8d58-993e05f55e45
    https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-drsr/b422aa87-7d07-4527-b070-c5d719696c43
    """
    version: int
    reserved_0: int
    cb: int
    consecutive_failures: int
    time_last_success: datetime.datetime
    time_last_attempt: datetime.datetime
    result_last_attempt: int  # 0 means the last attempt succeeded (see run())
    other_dra: Union[str, DcaRpcInst]  # str for version 1 (MTX_ADDR), DcaRpcInst otherwise
    replica_flags: int
    schedule: bytes
    reserved_1: int
    usn_vec: bytes
    dsa_obj: UUID  # GUID of the partner's NTDS Settings object
    invoc_id: UUID
    transport_obj: UUID
    reserved: int
    pas_data_offset: int
    data: bytes  # everything following the fixed-size header, verbatim

    @staticmethod
    def from_bytes(bs: ByteString) -> "ReplicationStatus":
        """
        Parse a raw repsFrom/repsTo attribute value.

        The fixed-size header is decoded with REPS_FROM_TO_STRUCT, and all bytes after it are kept
        in ``data``. The "other DRA" is located via the offset/length fields in the header.

        - Version 1: it is decoded as an MTX_ADDR string.
        - Any other version: it is decoded as a DSA_RPC_INST (DcaRpcInst), i.e. version 2 is
          assumed.
        """
        bs_memview = memoryview(bs)

        # split into known pieces and data
        pieces = REPS_FROM_TO_STRUCT.unpack(bs_memview[:REPS_FROM_TO_STRUCT.size])
        data = bs_memview[REPS_FROM_TO_STRUCT.size:]

        # the layout of "other DRA" depends on the version
        version = pieces[0]

        # parse dates/times: stored as whole seconds since WINDOWS_EPOCH
        time_last_success = WINDOWS_EPOCH + datetime.timedelta(seconds=pieces[4])
        time_last_attempt = WINDOWS_EPOCH + datetime.timedelta(seconds=pieces[5])

        # parse "other DRA"; its offset is relative to the start of the whole value
        other_dra_offset, other_dra_len = pieces[7], pieces[8]
        other_dra_slice = bs_memview[other_dra_offset:other_dra_offset+other_dra_len]
        if version == 1:
            # MTX_ADDR: a 4-byte little-endian length (which counts the terminating NUL),
            # immediately followed by the address itself. The address therefore starts at
            # offset 4, and the NUL is excluded by taking length-1 bytes.
            addr_len_incl_nul = int.from_bytes(other_dra_slice[0:4], "little", signed=False)
            addr_bytes = other_dra_slice[4:4+addr_len_incl_nul-1]
            other_dra: Union[str, DcaRpcInst] = addr_bytes.tobytes().decode("utf-8")
        else:
            # assume version 2
            # DSA_RPC_INST
            other_dra = DcaRpcInst.from_bytes(other_dra_slice)

        # parse UUIDs (stored in little-endian "bytes_le" layout)
        dsa_obj = UUID(bytes_le=pieces[13])
        invoc_id = UUID(bytes_le=pieces[14])
        transport_obj = UUID(bytes_le=pieces[15])

        return ReplicationStatus(
            version=version,
            reserved_0=pieces[1],
            cb=pieces[2],
            consecutive_failures=pieces[3],
            time_last_success=time_last_success,
            time_last_attempt=time_last_attempt,
            result_last_attempt=pieces[6],
            other_dra=other_dra,
            replica_flags=pieces[9],
            schedule=pieces[10],
            reserved_1=pieces[11],
            usn_vec=pieces[12],
            dsa_obj=dsa_obj,
            invoc_id=invoc_id,
            transport_obj=transport_obj,
            reserved=pieces[16],
            pas_data_offset=pieces[17],
            data=data.tobytes(),
        )


class ReplicationStatusPair(NamedTuple):
    """
    A pair of replication statuses for one partner domain controller: one inbound (repsFrom) and
    one outbound (repsTo).

    The fields are annotated as Optional so that a partner appearing in only one of the two
    attributes can be represented, with None in place of the missing direction.
    """
    inbound: Optional[ReplicationStatus]  # parsed repsFrom entry for the partner
    outbound: Optional[ReplicationStatus]  # parsed repsTo entry for the partner


def get_base_dn(conn: ldap3.Connection) -> str:
    """
    Obtain the base DN of the domain.

    It is read from the defaultNamingContext attribute of the RootDSE, i.e. a base-scope search
    on the empty DN.
    """
    naming_context_attr = "defaultNamingContext"
    conn.search(
        "",
        "(objectClass=*)",
        search_scope=ldap3.BASE,
        attributes=[naming_context_attr],
    )
    root_dse = conn.response[0]
    return root_dse["attributes"][naming_context_attr][0]


def get_ad_partitions(conn: ldap3.Connection, base_dn: str) -> List[str]:
    """
    Return a list of DNs that represent the partitions in the domain.

    Every partition is described by a crossRef object directly below
    CN=Partitions,CN=Configuration,<base_dn>. The DN of the partition itself is stored in that
    object's nCName attribute.

    NOTE(review): this assumes the Configuration partition is located under ``base_dn``. That holds
    in a forest root domain; in a child domain it lives under the forest root DN instead. Confirm
    whether child domains need to be supported.
    """
    partitions_container_dn = f"CN=Partitions,CN=Configuration,{base_dn}"
    conn.search(
        partitions_container_dn,
        "(&(objectClass=crossRef)(nCName=*))",
        search_scope=ldap3.LEVEL,
        attributes=["nCName"],
    )
    # one crossRef per partition; take the first (only) nCName value of each
    return [entry["attributes"]["nCName"][0] for entry in conn.response]


def get_dc_name(conn: ldap3.Connection, ntds_guid: UUID) -> str:
    """
    Get the DNS name of a domain controller by the GUID of its NTDS Settings object.

    The object is fetched via a ``<GUID=...>`` base DN and must have the cn "NTDS Settings". The
    DNS host name is then read from the dNSHostName attribute of its parent object.

    :raises ValueError: if the object with the given GUID is not an "NTDS Settings" object
    """
    def fetch_single(dn: str, attribute: str):
        # base-scope lookup of one attribute on one object; returns the first response entry
        conn.search(
            dn,
            "(objectClass=*)",
            search_scope=ldap3.BASE,
            attributes=[attribute],
        )
        return conn.response[0]

    settings_entry = fetch_single(f"<GUID={ntds_guid}>", "cn")
    if settings_entry["attributes"]["cn"][0] != "NTDS Settings":
        raise ValueError(f"GUID {ntds_guid} does not represent an NTDS Settings object")

    # drop the leading RDN (the NTDS Settings object itself) and reassemble the remaining
    # components, including their original separators, into the parent's DN
    ancestor_rdns = ldap3.utils.dn.parse_dn(settings_entry["dn"])[1:]
    parent_dn = "".join(
        f"{attr_name}={attr_value}{separator}"
        for attr_name, attr_value, separator in ancestor_rdns
    )

    server_entry = fetch_single(parent_dn, "dNSHostName")
    return server_entry["attributes"]["dNSHostName"][0]


def get_repl_status(conn: ldap3.Connection, partition_dn: str) -> Dict[str, ReplicationStatusPair]:
    """
    Obtain the replication status of a domain controller for the given partition, represented as a
    dict from partner domain controller names to replication status pairs (inbound and outbound).

    Inbound statuses come from the partition's repsFrom attribute and outbound ones from repsTo.
    Entries are matched by the GUID of the partner's NTDS Settings object. A partner need not
    appear in both attributes; the side of the pair with no entry is None.
    """
    conn.search(
        partition_dn,
        "(objectClass=*)",
        search_scope=ldap3.BASE,
        attributes=["repsFrom", "repsTo"],
    )
    attributes = conn.response[0]["attributes"]

    def index_by_partner(raw_values: Iterable[ByteString]) -> Dict[UUID, ReplicationStatus]:
        """Parse raw repsFrom/repsTo values and key them by the partner's NTDS Settings GUID."""
        parsed = (ReplicationStatus.from_bytes(raw_value) for raw_value in raw_values)
        return {status.dsa_obj: status for status in parsed}

    # parse both attributes now: get_dc_name() below runs further searches on the same
    # connection; a missing attribute simply means "no partners in that direction"
    guid_to_outbound = index_by_partner(attributes.get("repsTo", []))
    guid_to_inbound = index_by_partner(attributes.get("repsFrom", []))

    # find the names of the domain controllers
    all_guids: Set[UUID] = set(guid_to_outbound) | set(guid_to_inbound)
    guid_to_hostname: Dict[UUID, str] = {
        guid: get_dc_name(conn, guid)
        for guid in all_guids
    }

    # return inbound and outbound info for each domain controller; use .get() because a partner
    # from the union above may be missing from one of the two dicts
    return {
        hostname: ReplicationStatusPair(
            guid_to_inbound.get(guid),
            guid_to_outbound.get(guid),
        )
        for (guid, hostname) in guid_to_hostname.items()
    }


def parse_minutes(minute_str: str) -> datetime.timedelta:
    """
    Convert a possibly fractional number of minutes, given as a string, into a timedelta.

    Used as an argparse ``type=`` converter. A non-numeric string raises ValueError, which
    argparse reports as an invalid argument value.
    """
    minutes = float(minute_str)
    return datetime.timedelta(seconds=minutes * 60.0)


def _build_argument_parser() -> argparse.ArgumentParser:
    """Create the command-line argument parser used by run()."""
    parser = argparse.ArgumentParser(
        description=
            "Verifies that all partitions on an Active Directory server are replicating"
            " correctly.",
    )
    parser.add_argument(
        "--hostname", "-H",
        dest="hostname", metavar="HOSTNAME", required=True,
        help="Hostname or URL of the LDAP server to contact.",
    )
    parser.add_argument(
        "--port", "-p",
        # convert to int so that the port is handed to ldap3.Server as a number, not a string
        dest="port", metavar="PORT", type=int, default=None,
        help=
            "Port of the LDAP server. Ignored if --hostname/-H is given as a URL. The default port"
            " is 389; this default is changed to 636 if --ssl/-S is given.",
    )
    parser.add_argument(
        "--ssl", "-S",
        dest="ssl", action="store_true",
        help=
            "Whether to connect to the LDAP server using SSL instead of a plaintext connection."
            " Ignored if --hostname/-H is given as a URL.",
    )
    parser.add_argument(
        "--starttls", "-T",
        dest="starttls", action="store_true",
        help=
            "Whether to launch an encrypted tunnel using STARTTLS.",
    )
    parser.add_argument(
        "--bind", "-D",
        dest="bind_dn", metavar="BINDDN", default=None,
        help=
            "The DN (Distinguished Name) using which to bind to the LDAP server.",
    )
    parser.add_argument(
        "--pass", "-P",
        dest="password", metavar="PASSWORD", default=None,
        help=
            "The password using which to bind to the LDAP server.",
    )
    parser.add_argument(
        "--pass-file", "-F",
        dest="password_file", metavar="PASSWORDFILE", type=argparse.FileType("r"), default=None,
        help=
            "A file containing the password used to bind to the LDAP server. A single trailing"
            " CR, LF or CR+LF is stripped from the file before it is passed to the server; if the"
            " password actually contains these characters, make sure to add another CR+LF at the"
            " end.",
    )
    parser.add_argument(
        "--inbound-age-warn", "-a",
        dest="inbound_age_warn", metavar="MINUTES", type=parse_minutes, default=None,
        help=
            "The maximum age of an inbound replication, in minutes, above which a warning is"
            " raised.",
    )
    parser.add_argument(
        "--inbound-age-crit", "-A",
        dest="inbound_age_crit", metavar="MINUTES", type=parse_minutes, default=None,
        help=
            "The maximum age of an inbound replication, in minutes, above which a critical state is"
            " raised.",
    )
    parser.add_argument(
        "--failure-state", "-f",
        dest="failure_state", choices=TEXT_TO_STATUS.keys(), default="crit",
        help=
            "State to raise if a replication reports that it has failed. A critical state is raised"
            " by default.",
    )
    return parser


def _read_password(args: argparse.Namespace) -> Optional[str]:
    """
    Return the bind password from --pass/-P or --pass-file/-F, or None if neither was given.

    For --pass-file, a single trailing LF, CR or CR+LF is removed, as promised in the help text.
    """
    if args.password_file is None:
        return args.password
    with args.password_file:
        password = args.password_file.read()
    if password.endswith("\n"):
        password = password[:-1]
    if password.endswith("\r"):
        password = password[:-1]
    return password


def run():
    """
    Perform the replication check and exit with the most severe status encountered.

    Steps:
    - Parse the command line and bind to the LDAP server.
    - Collect repsFrom/repsTo information for every partition.
    - Print a one-line summary followed by per-partition, per-partner details.
    """
    parser = _build_argument_parser()

    try:
        args, unknown_args = parser.parse_known_args()
    except SystemExit as exc:
        # On a usage error, argparse prints the usage and message to stderr itself, then exits
        # with status 2. Monitoring systems would read 2 as CRITICAL, so report UNKNOWN instead.
        # Other exits, such as status 0 after --help, pass through unchanged.
        if exc.code == 2:
            sys.exit(STATUS_UNKN)
        raise

    if unknown_args:
        parser.print_usage(sys.stderr)
        print(
            f"{parser.prog}: error: unrecognized arguments: {' '.join(unknown_args)}",
            file=sys.stderr,
        )
        sys.exit(STATUS_UNKN)

    if args.ssl and args.starttls:
        print(
            "The options --ssl/-S and --starttls/-T cannot be used simultaneously.",
            file=sys.stderr,
        )
        sys.exit(STATUS_UNKN)

    if args.password is not None and args.password_file is not None:
        print(
            "The options --pass/-P and --pass-file/-F cannot be used simultaneously.",
            file=sys.stderr,
        )
        sys.exit(STATUS_UNKN)

    # sane defaults for the port
    if args.port is not None:
        port: int = args.port
    elif args.ssl:
        port = 636
    else:
        port = 389

    password = _read_password(args)

    # connect to the server
    server = ldap3.Server(
        args.hostname,
        port,
        use_ssl=args.ssl,
        get_info=ldap3.NONE,
    )
    conn = ldap3.Connection(
        server,
        user=args.bind_dn,
        password=password,
    )

    # if STARTTLS was requested but could not be established, stop instead of binding (and
    # sending the credentials) over an unencrypted connection
    if args.starttls and not conn.start_tls():
        print(f"STARTTLS failed: {conn.result}", file=sys.stderr)
        sys.exit(STATUS_UNKN)

    # do not run the searches below on a connection whose bind failed
    if not conn.bind():
        print(f"Binding to the LDAP server failed: {conn.result}", file=sys.stderr)
        sys.exit(STATUS_UNKN)

    # obtain the base DN and the partitions
    base_dn = get_base_dn(conn)
    partitions = get_ad_partitions(conn, base_dn)

    # collect replication status for each partition
    partition_host_status: Dict[str, Dict[str, ReplicationStatusPair]] = {
        partition: get_repl_status(conn, partition)
        for partition in partitions
    }

    # collect the details
    failure_status = TEXT_TO_STATUS[args.failure_state]
    now = datetime.datetime.now(datetime.timezone.utc)
    inbound_failed_count, outbound_failed_count = 0, 0
    inbound_warn_count, inbound_crit_count = 0, 0
    detail_lines: List[str] = []
    statuses: List[int] = []
    for partition, host_status in partition_host_status.items():
        detail_lines.append(f"partition: {partition}")
        for host, status in host_status.items():
            detail_lines.append(f"  partner DC: {host}")

            inbound = status.inbound
            detail_lines.append("    inbound data obtained:")
            if inbound is None:
                detail_lines.append("      no inbound replication entry")
            else:
                # did the replication succeed?
                inbound_repl_status = STATUS_OK
                if inbound.result_last_attempt != 0x0:
                    inbound_failed_count += 1
                    inbound_repl_status = failure_status
                inbound_result_suffix = STATUS_TO_SUFFIX[inbound_repl_status]

                # how old is the inbound data? Classify by the most severe threshold exceeded,
                # so each replication is counted exactly once (critical wins over warning)
                age = now - inbound.time_last_success
                inbound_age_status = STATUS_OK
                if args.inbound_age_crit is not None and age > args.inbound_age_crit:
                    inbound_crit_count += 1
                    inbound_age_status = STATUS_CRIT
                elif args.inbound_age_warn is not None and age > args.inbound_age_warn:
                    inbound_warn_count += 1
                    inbound_age_status = STATUS_WARN
                inbound_age_suffix = STATUS_TO_SUFFIX[inbound_age_status]

                statuses.extend((inbound_repl_status, inbound_age_status))
                detail_lines.append(f"      last succeeded: {inbound.time_last_success}{inbound_age_suffix}")
                detail_lines.append(f"      last attempted: {inbound.time_last_attempt}")
                detail_lines.append(f"      attempt result: {inbound.result_last_attempt}{inbound_result_suffix}")

            outbound = status.outbound
            detail_lines.append("    outbound triggered:")
            if outbound is None:
                detail_lines.append("      no outbound replication entry")
            else:
                outbound_repl_status = STATUS_OK
                if outbound.result_last_attempt != 0x0:
                    outbound_failed_count += 1
                    outbound_repl_status = failure_status
                outbound_result_suffix = STATUS_TO_SUFFIX[outbound_repl_status]

                statuses.append(outbound_repl_status)
                detail_lines.append(f"      last succeeded: {outbound.time_last_success}")
                detail_lines.append(f"      last attempted: {outbound.time_last_attempt}")
                detail_lines.append(f"      attempt result: {outbound.result_last_attempt}{outbound_result_suffix}")

    # summarize
    output_message = []
    if inbound_failed_count > 0:
        if outbound_failed_count > 0:
            output_message.append(
                f"{inbound_failed_count} inbound and {outbound_failed_count} outbound replications failed"
            )
        else:
            output_message.append(
                f"{inbound_failed_count} inbound replications failed"
            )
    elif outbound_failed_count > 0:
        output_message.append(
            f"{outbound_failed_count} outbound replications failed"
        )

    if inbound_crit_count > 0:
        output_message.append(
            f"{inbound_crit_count} inbound replications critically old"
        )
    if inbound_warn_count > 0:
        output_message.append(
            f"{inbound_warn_count} inbound replications old"
        )

    # the summary if everything is OK
    if not output_message:
        process_count = sum(len(host_status) for host_status in partition_host_status.values())
        output_message.append(f"{process_count} replication processes OK")

    # output summary
    print(", ".join(output_message))

    # output details
    for detail_line in detail_lines:
        print(detail_line)

    # exit with worst exit code; with no replication partners at all (e.g. a single-DC domain)
    # there is nothing to complain about, so fall back to OK instead of crashing on max([])
    sys.exit(max(statuses, default=STATUS_OK))


def main():
    """
    Script entry point: call run() and turn any unexpected exception into an UNKNOWN exit status.

    The traceback of such an exception is printed before exiting. The SystemExit raised by run()
    to report its status is not an Exception subclass, so it passes through untouched.
    """
    crashed = False
    try:
        run()
    except Exception:  # pylint: disable=broad-except
        traceback.print_exc()
        crashed = True
    if crashed:
        sys.exit(STATUS_UNKN)


# only run the check when this file is executed directly, not when it is imported as a module
if __name__ == "__main__":
    main()
