#!/usr/bin/python3

"""
Scan through the various HTCondor signing keys and passwords, looking
for NULL characters that potentially truncated keys in earlier versions
of HTCondor.
"""

import argparse
import os
import pathlib
import re
import sys
import tempfile

import htcondor



# Keys whose effective length is below this many bytes are reported as too short.
MINIMUM_SAFE_LENGTH = 12

# Emit ANSI color codes only when stdout is attached to a terminal.  A missing
# or replaced stdout (None, or a stream without a usable file descriptor) is
# treated as "not a terminal" instead of aborting the script at import time.
try:
    g_color_output = os.isatty(sys.stdout.fileno())
except (AttributeError, OSError, ValueError):
    g_color_output = False

def format_red_terminal_text(text):
    """
    Return `text` wrapped in ANSI bold / bright-red escape sequences when
    the module-level `g_color_output` flag is set; otherwise return `text`
    unchanged.
    """
    if not g_color_output:
        return text
    return "\033[1m\033[91m%s\033[0m" % text
        

def simple_scramble(in_buf):
    """
    Apply (or undo) the simple scramble of HTCondor.

    Every byte of `in_buf` is XORed with the repeating four-byte pattern
    0xde 0xad 0xbe 0xef.  XOR is its own inverse, so calling this function
    on scrambled data yields the original data and vice versa.

    Returns a new `bytes` object with the same length as `in_buf`.
    """
    deadbeef = b'\xde\xad\xbe\xef'
    # Build the result in one pass; growing an immutable bytes object one
    # byte at a time would re-copy the buffer on every iteration.
    return bytes(byteval ^ deadbeef[idx % 4] for idx, byteval in enumerate(in_buf))


def truncate_at_null(in_buf):
    """
    Cut the input buffer at its first NULL (zero) byte, if it has one.

    `in_buf` is expected to be a `bytes` (or `bytearray`) object.

    Returns a 2-tuple `(found, prefix)`:
      - `found` is True when `in_buf` contains a NULL byte, else False.
      - `prefix` is everything before the first NULL byte when one was
        found, or `in_buf` itself when there is none.
    """
    null_idx = in_buf.find(b'\x00')
    if null_idx < 0:
        return False, in_buf
    return True, in_buf[:null_idx]


def atomic_write_contents(dest, contents):
    """
    Atomically write `contents` to Path `dest` and fsync the result.

    The bytes are first written to a temporary file created in the same
    directory as `dest` (name prefix ".condor.<dest name>"), flushed and
    fsync'ed, and only then moved over `dest`.  A reader therefore sees
    either the previous file or the complete new one.

    If anything fails before the move completes, the temporary file is
    removed and the original exception is re-raised.
    """
    tmp_name = None
    try:
        with tempfile.NamedTemporaryFile(mode="w+b", prefix=".condor." + dest.name,
                                         dir=str(dest.parent), delete=False) as tf:
            tmp_name = tf.name
            tf.write(contents)
            # Push the data to disk *before* the rename, so the new name can
            # never point at a file whose contents have not been written yet.
            tf.flush()
            os.fsync(tf.fileno())
        # os.replace overwrites an existing destination on every platform.
        os.replace(tmp_name, str(dest))
    except BaseException:
        # Do not leave a stray temporary copy of key material behind.
        if tmp_name is not None:
            try:
                os.unlink(tmp_name)
            except OSError:
                pass
        raise


def truncate_keyfile(key, contents, truncated_key):
    """
    Truncate a given key file according to the pre-8.9.12 logic.

    `key` is the Path of the key file, `contents` the bytes currently stored
    in it, and `truncated_key` the unscrambled, already-truncated key.

    The original `contents` are first saved next to the key file as
    ".condor.backup.<key name>"; the key file is then overwritten with the
    scrambled form of `truncated_key`.

    The process exits with status 4 if the backup file already exists (it is
    never overwritten) and with status 3 if any step fails.
    """
    try:
        new_fname = key.with_name(".condor.backup." + key.name)
        if new_fname.exists():
            print("Backup file %s already exists; not overwriting." % str(new_fname))
            sys.exit(4)
        # Write the backup before touching the key so the original bytes are
        # recoverable if the second write fails.
        atomic_write_contents(new_fname, contents)
        atomic_write_contents(key, simple_scramble(truncated_key))
    except Exception as exc:
        # SystemExit is not an Exception subclass, so the exit(4) above is
        # not intercepted here.
        print("Failure when truncating %s: %s" % (str(key), str(exc)))
        sys.exit(3)


def enumerate_keys():
    """
    Enumerate all possible keys that HTCondor might use.

    The returned set of Paths is built from the HTCondor configuration:
      - the file named by SEC_TOKEN_POOL_SIGNING_KEY_FILE, if set;
      - the file named by SEC_PASSWORD_FILE, if set;
      - every entry of the directory named by SEC_PASSWORD_DIRECTORY (if
        set) whose name does not match LOCAL_CONFIG_DIR_EXCLUDE_REGEXP.

    Excluded directory entries are reported on stdout.  Errors from listing
    the directory (for example PermissionError) propagate to the caller.
    """
    keys = set()

    # Either single-file knob may be unset; treat both the same way rather
    # than failing with KeyError on one of them.
    for knob in ("SEC_TOKEN_POOL_SIGNING_KEY_FILE", "SEC_PASSWORD_FILE"):
        key_file = htcondor.param.get(knob)
        if key_file:
            keys.add(pathlib.Path(key_file))

    password_directory_name = htcondor.param.get("SEC_PASSWORD_DIRECTORY")
    if not password_directory_name:
        # No password directory configured: nothing further to scan.
        return keys

    exclude_regexp = re.compile(htcondor.param['LOCAL_CONFIG_DIR_EXCLUDE_REGEXP'])
    password_directory = pathlib.Path(password_directory_name)
    skipped_files = 0
    for fname in os.listdir(str(password_directory)):
        if exclude_regexp.match(fname):
            print("Skipping excluded file %s." % fname)
            skipped_files += 1
            continue
        keys.add(password_directory / fname)
    if skipped_files:
        print("Skipped %d excluded files.\n" % skipped_files)

    return keys


def main():
    """
    Command-line entry point: examine key files and report their status.

    The keys to examine are the positional arguments, or, when none are
    given, the set returned by enumerate_keys().  Each key is read, cut at
    the first NULL byte of its on-disk (scrambled) form, unscrambled, and
    cut again at the first NULL byte of the unscrambled form.  One status
    line is printed per key, followed by a summary.

    With `--truncate`, every key found to contain a NULL byte is rewritten
    through truncate_keyfile(), which may itself terminate the process.

    Exit status as set here: 0 when there is nothing to check or every key
    is OK, 1 when at least one key has an issue or could not be read, and 2
    when the default keys cannot be enumerated because of a permission
    error.
    """
    parser = argparse.ArgumentParser(description="Examine HTCondor key files, looking for keys that may be silently truncated in prior HTCondor versions")
    parser.add_argument("key", nargs="*", help="Specific key files to examine; if not given, then the defaults are used from the HTCondor configuration")
    parser.add_argument("--truncate", default=False, help="When a potentially insecure key is encountered, truncate it to match the behavior prior to 8.9.12", action="store_true")
    args = parser.parse_args()

    keys_to_check = [pathlib.Path(fname) for fname in args.key]
    if not keys_to_check:
        try:
            keys_to_check = enumerate_keys()
        except PermissionError as pe:
            print("Cannot enumerate default keys: %s" % str(pe))
            sys.exit(2)

    # Sort so the report order is stable regardless of how keys were found.
    keys_to_check = sorted(keys_to_check)
    good_key_count = 0
    truncated_key_count = 0
    short_key_count = 0

    if not keys_to_check:
        print("There are no keys to check")
        sys.exit(0)

    print("Key tests:")
    for key in keys_to_check:
        try:
            contents = key.read_bytes()
        except OSError as oe:
            # Unreadable keys are reported and later counted as having issues.
            print("%s: Unable to read - %s" % (str(key), oe))
            continue

        # First cut: NULL byte in the scrambled, on-disk form.
        is_truncated, truncated_key = truncate_at_null(contents)
        # Second cut: NULL byte in the unscrambled key itself.
        unscrambled = simple_scramble(truncated_key)
        is_unscrambled_truncated, truncated_key = truncate_at_null(unscrambled)

        if is_truncated or is_unscrambled_truncated:
            truncated_key_count += 1
        if len(truncated_key) < MINIMUM_SAFE_LENGTH:
            short_key_count += 1

        # Pick the (after --truncate, report-only) message pair matching
        # where the NULL byte(s) were found; None means no truncation.
        messages = None
        if is_truncated and is_unscrambled_truncated:
            messages = (
                "%s %s: Scrambled key and key were truncated; new size is %d bytes (originally %d bytes)",
                "%s      %s: Scrambled key and key will be silently truncated; effective size is %d bytes (originally %d bytes)",
            )
        elif is_truncated:
            messages = (
                "%s %s: Scrambled key was truncated; new size is %d bytes (originally %d bytes)",
                "%s      %s: Scrambled key will be silently truncated; effective size is %d bytes (originally %d bytes)",
            )
        elif is_unscrambled_truncated:
            messages = (
                "%s %s: Key was truncated; new size is %d bytes (originally %d bytes)",
                "%s      %s: Key will be silently truncated; effective size is %d bytes (originally %d bytes)",
            )

        if messages is not None:
            truncated_msg, fail_msg = messages
            if args.truncate:
                truncate_keyfile(key, contents, truncated_key)
                print(truncated_msg % (format_red_terminal_text("TRUNCATED"), str(key), len(truncated_key), len(contents)))
            else:
                print(fail_msg % (format_red_terminal_text("FAIL"), str(key), len(truncated_key), len(contents)))
        elif len(truncated_key) < MINIMUM_SAFE_LENGTH:
            print("%s      %s: Key is too short: %d bytes" % (format_red_terminal_text("FAIL"), str(key), len(truncated_key)))
        else:
            print("OK        %s" % str(key))
            good_key_count += 1

    print("")
    print("There were %d compatible keys and %d keys with issues." % (good_key_count, len(keys_to_check) - good_key_count))

    if good_key_count != len(keys_to_check):
        print("")
        if truncated_key_count and not args.truncate:
            print("Some keys have compatibility problems with clients or servers older than HTCondor 8.9.12.")
            print("Before HTCondor 8.9.12, these keys were silently truncated, potentially resulting in insecure setups")
            print("Rerun this tool with the `--truncate` flag to change the files on disk and restore compatibility with older HTCondor versions")
        elif args.truncate:
            print("There were %d keys that were explicitly truncated" % truncated_key_count)
        if short_key_count:
            print("HTCondor recommends passwords be longer than %d characters." % MINIMUM_SAFE_LENGTH)

        sys.exit(1)

# Run the key check only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()