#!/cvmfs/dirac.egi.eu/dirac/v8.0.50/Linux-aarch64/bin/python
"""
Python client for the ARC Candypond service.
"""

import sys
try:
    import httplib
except ImportError:
    import http.client as httplib
import re
import time
import os
import socket
import ssl
import pwd
import argparse
import logging
import xml.etree.ElementTree as ET

# Module logger: records at WARNING and above are written to stderr with
# timestamp, logger name, level and process id.
log_handler_stderr = logging.StreamHandler()
log_handler_stderr.setFormatter(logging.Formatter(
    '[%(asctime)s] [%(name)s] [%(levelname)s] [%(process)d] [%(message)s]'))
logger = logging.getLogger('ARC.Candypond.Client')
logger.setLevel(logging.WARNING)
logger.addHandler(log_handler_stderr)


class HTTPSClientAuthConnection(httplib.HTTPSConnection):
    """HTTPS connection that authenticates with a client certificate and key.

    The client certificate/key pair is presented during the TLS handshake.
    When ``ca_file`` is given the server certificate is verified against it;
    otherwise server certificate verification is disabled.
    """

    def __init__(self, host, port, key_file, cert_file, ca_file=None, timeout=None):
        """
        :param host: server host name
        :param port: server port, or None for the HTTPS default port
        :param key_file: path to the client private key (PEM)
        :param cert_file: path to the client certificate (chain) (PEM)
        :param ca_file: optional CA bundle used to verify the server
        :param timeout: socket timeout in seconds; None means blocking
        """
        # key_file/cert_file are deliberately NOT forwarded to the base class:
        # those keyword arguments were deprecated in Python 3.6 and removed in
        # 3.12 (and on older versions they load the cert chain eagerly).
        # The TLS context is built in connect() instead.
        httplib.HTTPSConnection.__init__(self, host, port)
        self.key_file = key_file
        self.cert_file = cert_file
        self.ca_file = ca_file
        self.timeout = timeout

    def _make_ssl_context(self):
        """Build the client-side SSL context for this connection.

        Reproduces what the removed ``ssl.wrap_socket()`` helper did with the
        same arguments: the client certificate chain is loaded only if
        ``cert_file`` is set, the server is verified (CERT_REQUIRED) only
        when ``ca_file`` is set, and host name checking stays disabled.
        """
        protocol = getattr(ssl, 'PROTOCOL_TLS_CLIENT', ssl.PROTOCOL_SSLv23)
        context = ssl.SSLContext(protocol)
        # check_hostname must be disabled before verify_mode may be CERT_NONE
        context.check_hostname = False
        if self.ca_file:
            context.verify_mode = ssl.CERT_REQUIRED
            context.load_verify_locations(self.ca_file)
        else:
            # If there's no CA File, don't force Server Certificate Check
            context.verify_mode = ssl.CERT_NONE
        if self.cert_file:
            context.load_cert_chain(self.cert_file, self.key_file)
        return context

    def connect(self):
        """Open the TCP connection (through a tunnel if configured) and wrap it in TLS.

        Overrides the base implementation so the client certificate is always
        presented and the server certificate is checked against ``ca_file``
        only when one is configured.
        """
        sock = socket.create_connection((self.host, self.port), self.timeout)
        if self._tunnel_host:
            self.sock = sock
            self._tunnel()
        # Send SNI for the real target host (the tunnel target when proxying)
        server_hostname = self._tunnel_host or self.host
        self.sock = self._make_ssl_context().wrap_socket(sock, server_hostname=server_hostname)


class CacheException(Exception):
    """Error raised by the Candypond client for malformed input, connection
    and HTTP failures, SOAP faults and unexpected service responses."""


def splitURL(url):
    """
    Split a service URL into its components.

    :param url: URL of the form ``protocol://host[:port][/path]``
    :returns: tuple ``(protocol, host, port, path)`` where ``port`` is an int,
        or None when the URL has no explicit port, and ``path`` always
        starts with '/'
    :raises CacheException: if the URL does not match the expected form
    """
    # Raw string: '\w' and '\d' in a plain literal are invalid escape
    # sequences (DeprecationWarning, SyntaxWarning since Python 3.12).
    match = re.match(r'(\w*)://([^/?#:]*):?(\d*)/?(.*)', url)
    if match is None:
        raise CacheException('URL ' + url + ' is malformed')

    protocol, host, port_s, rest = match.groups()
    port = int(port_s) if port_s else None
    return (protocol, host, port, '/' + rest)


def addETElement(root, child, text):
    """
    Append a new ``child`` element holding ``text`` to ``root``.

    :param root: parent ElementTree element
    :param child: tag name of the element to create
    :param text: text content of the new element
    :returns: the newly created element
    """
    new_element = root.makeelement(child, {})
    new_element.text = text
    root.append(new_element)
    return new_element


def checkSOAPFault(element):
    """
    Raise CacheException if a parsed SOAP response carries a SOAP Fault.

    :param element: root element of the parsed SOAP response (the Envelope)
    :raises CacheException: if the Body contains a Fault; the message holds
        the fault code and fault string (empty if the service omitted them)
    """
    soap_ns = '{http://schemas.xmlsoap.org/soap/envelope/}'

    # Locate the Body explicitly: an Envelope may start with a Header, so the
    # first child is not necessarily the Body.
    response_body = element.find(soap_ns + 'Body')
    if response_body is None:
        if len(element) == 0:
            # Nothing to inspect; callers report the bad structure themselves
            return
        # Fall back to the historical behaviour of using the first child
        response_body = element[0]

    fault = response_body.find(soap_ns + 'Fault')
    if fault is None:
        return

    def _fault_text(tag):
        # SOAP 1.1 defines faultcode/faultstring as unqualified children, but
        # accept namespace-qualified ones too, and tolerate missing elements.
        node = fault.find(soap_ns + tag)
        if node is None:
            node = fault.find(tag)
        if node is None or node.text is None:
            return ''
        return node.text

    raise CacheException('SOAP error: ' + _fault_text('faultcode') + ' - ' + _fault_text('faultstring'))


def cacheCheck(service, proxy, urls):
    """
    Ask the cache service at ``service`` whether the given URLs are cached.

    :param service: Candypond endpoint URL
    :param proxy: path to the client proxy certificate file (used as both
        the client certificate and the key)
    :param urls: iterable of URLs to look up
    :returns: dict mapping each URL reported by the service to True if it
        exists in the cache, False otherwise
    :raises CacheException: on connection or HTTP errors, SOAP faults,
        unparsable or unexpectedly structured replies, or an empty result
    """
    soap_ns = '{http://schemas.xmlsoap.org/soap/envelope/}'
    _, host, port, path = splitURL(service)

    # create request with etree
    soap = ET.Element('soap-env:Envelope', attrib={
        'xmlns:echo': 'urn:echo',
        'xmlns:soap-enc': 'http://schemas.xmlsoap.org/soap/encoding/',
        'xmlns:soap-env': 'http://schemas.xmlsoap.org/soap/envelope/',
        'xmlns:xsd': 'http://www.w3.org/2001/XMLSchema',
        'xmlns:xsi': 'http://www.w3.org/2001/XMLSchema-instance'
    })
    body = ET.SubElement(soap, 'soap-env:Body')
    cachecheck = ET.SubElement(body, 'CacheCheck')
    files = ET.SubElement(cachecheck, 'TheseFilesNeedToCheck')
    for url in urls:
        addETElement(files, 'FileURL', url)
    request = ET.tostring(soap)

    conn = HTTPSClientAuthConnection(host, port, proxy, proxy)
    try:
        try:
            conn.request('POST', path, request)
            resp = conn.getresponse()
        except Exception as e:
            raise CacheException('Error connecting to service at ' + host + ': ' + str(e))
        # On SOAP fault 500 is returned - this is caught in checkSOAPFault
        if resp.status not in (200, 500):
            raise CacheException('Error code ' + str(resp.status) + ' returned: ' + resp.reason)
        xmldata = resp.read()
    finally:
        # Release the socket on every path, including failed requests
        conn.close()

    try:
        response = ET.XML(xmldata)
    except ET.ParseError as e:
        raise CacheException('Error parsing XML received from cache service: ' + str(e))
    checkSOAPFault(response)

    # .find() returns None for missing elements, so a wrong structure shows
    # up as AttributeError on the next lookup.
    try:
        results = response.find(soap_ns + 'Body').find('CacheCheckResponse')\
            .find('CacheCheckResult').findall('Result')
    except AttributeError:
        raise CacheException('Error with XML structure received from cache service')

    if not results:
        raise CacheException('No results returned')

    cachefiles = {}
    try:
        for result in results:
            url = result.find('FileURL').text
            cachefiles[url] = result.find('ExistInTheCache').text == 'true'
    except AttributeError:
        raise CacheException('Error with XML structure received from cache service')

    return cachefiles


def cacheLink(service, proxy, user, jobid, urls, dostage, timeout=None):
    """
    Call the cache service at service to link the given dictionary of urls
    to the session directory corresponding to jobid. The url dictionary
    maps remote URLs corresponding to the original source files to local
    filenames on the session directory. If dostage is true then any files not
    in the cache will be downloaded from source. The best way to use this
    method may be to call once with dostage=False, and then if there are
    missing files, make a decision whether to call again with dostage=True.
    Alternatively cacheCheck can be called first.

    If the service reports that some files are still being staged (return
    code '1'), this function waits: it polls for the local file names to
    appear (relative to the current working directory) and re-queries the
    service roughly once a minute, or sooner once all local files exist.

    :param service: Candypond endpoint URL
    :param proxy: path to the client proxy certificate file (used as both
        the client certificate and the key)
    :param user: user name sent to the service
    :param jobid: job ID whose session directory receives the links
    :param urls: dict mapping remote URL -> local file name
    :param dostage: if true, ask the service to download missing files
    :param timeout: optional limit in seconds on waiting for staging to
        finish; None (the default) waits indefinitely
    :returns: dict mapping each URL reported by the service to a tuple
        (return code, explanation), both strings as sent by the service
    :raises CacheException: on connection or HTTP errors, SOAP faults,
        unparsable or unexpectedly structured replies, an empty result, or
        when ``timeout`` expires while files are still being staged

    Possible outcomes reported through the return codes (the numeric values
    are defined by the service; only '1' = staging in progress is
    interpreted here):
    all successful
    one or more cache files is locked
    permission denied on one or more cache files
    download of one or more cache files failed
    one or more cache files is not present and dostage is false
    cache service errors (failed to connect, not authorised, bad config etc)
    too many downloads already in progress according to configured limit on server
    """
    soap_ns = '{http://schemas.xmlsoap.org/soap/envelope/}'
    structure_error = 'Error with XML structure received from cache service'
    _, host, port, path = splitURL(service)

    def _new_envelope():
        # Return the (Envelope, Body) elements of a new SOAP request
        envelope = ET.Element('soap-env:Envelope', attrib={
            'xmlns:echo': 'urn:echo',
            'xmlns:soap-enc': 'http://schemas.xmlsoap.org/soap/encoding/',
            'xmlns:soap-env': 'http://schemas.xmlsoap.org/soap/envelope/',
            'xmlns:xsd': 'http://www.w3.org/2001/XMLSchema',
            'xmlns:xsi': 'http://www.w3.org/2001/XMLSchema-instance'
        })
        return envelope, ET.SubElement(envelope, 'soap-env:Body')

    def _post(request):
        # POST a serialized SOAP request; return the parsed, fault-checked reply
        conn = HTTPSClientAuthConnection(host, port, proxy, proxy)
        try:
            try:
                conn.request('POST', path, request)
                resp = conn.getresponse()
            except Exception as e:
                raise CacheException('Error connecting to service at ' + host + ': ' + str(e))
            # On SOAP fault 500 is returned - this is caught in checkSOAPFault
            if resp.status not in (200, 500):
                raise CacheException('Error code ' + str(resp.status) + ' returned: ' + resp.reason)
            xmldata = resp.read()
        finally:
            # Release the socket on every path, including failed requests
            conn.close()
        try:
            response = ET.XML(xmldata)
        except ET.ParseError as e:
            raise CacheException('Error parsing XML received from cache service: ' + str(e))
        checkSOAPFault(response)
        return response

    # --- initial CacheLink request ---
    envelope, body = _new_envelope()
    cachelink = ET.SubElement(body, 'CacheLink')
    files = ET.SubElement(cachelink, 'TheseFilesNeedToLink')
    for url in urls:
        file_obj = ET.SubElement(files, 'File')
        addETElement(file_obj, 'FileURL', url)
        addETElement(file_obj, 'FileName', urls[url])
    addETElement(cachelink, 'Username', user)
    addETElement(cachelink, 'JobID', jobid)
    addETElement(cachelink, 'Stage', 'true' if dostage else 'false')

    response = _post(ET.tostring(envelope))

    # .find() returns None for missing elements, so a wrong structure shows
    # up as AttributeError on the next lookup.
    try:
        results = response.find(soap_ns + 'Body').find('CacheLinkResponse')\
            .find('CacheLinkResult').findall('Result')
    except AttributeError:
        raise CacheException(structure_error)

    if not results:
        raise CacheException('No results returned')

    cachefiles = {}
    stagingfiles = {}
    try:
        for result in results:
            url = result.find('FileURL').text
            link_result_code = result.find('ReturnCode').text
            link_result_text = result.find('ReturnCodeExplanation').text
            if link_result_code == '1':
                # staging in progress: remember the local name to poll for
                stagingfiles[url] = urls[url]
            else:
                cachefiles[url] = (link_result_code, link_result_text)
    except (AttributeError, KeyError):
        # missing child elements, or a URL that was not part of the request
        raise CacheException(structure_error)

    if not stagingfiles:
        return cachefiles

    # --- some files required staging so poll until finished ---
    # So we don't overload the service, we poll for the link appearing in the
    # session dir, checking the service occasionally in case the transfer
    # failed.
    envelope, body = _new_envelope()
    cachelinkquery = ET.SubElement(body, 'CacheLinkQuery')
    # Username belongs to the query (it used to be appended to the already
    # serialized CacheLink element and so never reached the service)
    addETElement(cachelinkquery, 'Username', user)
    addETElement(cachelinkquery, 'JobID', jobid)
    query = ET.tostring(envelope)

    # Prefer a monotonic clock so wall-clock adjustments don't affect timeout
    clock = getattr(time, 'monotonic', time.time)
    deadline = None if timeout is None else clock() + timeout

    def _expired():
        return deadline is not None and clock() >= deadline

    while True:
        time.sleep(1)

        response = _post(query)
        try:
            query_result = response.find(soap_ns + 'Body').find('CacheLinkQueryResponse')\
                .find('CacheLinkQueryResult').find('Result')
            link_result_code = query_result.find('ReturnCode').text
            link_result_text = query_result.find('ReturnCodeExplanation').text
        except AttributeError:
            raise CacheException(structure_error)

        if link_result_code != '1':
            # finished - either successfully or failed
            for url in stagingfiles:
                cachefiles[url] = (link_result_code, link_result_text)
            return cachefiles

        # Still staging; give up only after a service query confirmed it
        if _expired():
            raise CacheException('Timed out after ' + str(timeout) +
                                 ' s waiting for files to be staged')
        logger.debug("Still staging")

        # Poll for the final links appearing in the local dir for up to one
        # minute, then check the service again. It is assumed that files will
        # appear in the current working dir.
        for _ in range(60):
            if _expired():
                break
            time.sleep(1)
            for localname in stagingfiles.values():
                logger.info('Checking for %s', localname)
                if not os.path.exists(localname):
                    break
            else:
                # all files exist - check service again to make sure (files
                # could take a while to copy to session dir for example)
                break


def get_parser():
    """
    Build the command-line argument parser for the Candypond client.

    Global options select the proxy, endpoint, job ID and user. A mandatory
    sub-command ('get' or 'check') selects the operation and is stored in
    the ``operation`` attribute of the parsed namespace.

    :returns: configured argparse.ArgumentParser
    """
    parser = argparse.ArgumentParser(description='Nordugrid ARC candypond service client parser')
    parser.add_argument('-p', '--proxy', action='store',
                        help='Path to client proxy certificate (X509_USER_PROXY used if not specified)')
    parser.add_argument('-e', '--endpoint-url', action='store',
                        help='ARC Candypond endpoint URL (ARC_CANDYPOND_URL used if not specified)')
    parser.add_argument('-j', '--jobid', action='store',
                        help='Job ID to stage data for (autodetection used if not specified)')
    parser.add_argument('-u', '--user', action='store',
                        help='User to be used by Candypond (user running this command will be used if not specified)')

    parser_ops = parser.add_subparsers(title='Candypond operations', dest='operation')
    # Without this, running with no sub-command yields operation=None and a
    # misleading "Unsupported Candypond operation None" error; let argparse
    # report the missing sub-command together with the usage instead.
    parser_ops.required = True

    parser_get = parser_ops.add_parser('get', help='Get the data from URL delivered to session directory by Candypond')
    parser_get.add_argument('url', help='URL of the content to deliver')
    parser_get.add_argument('file', help='Put the content into specified filename')

    parser_check = parser_ops.add_parser('check', help='Check the file is in ARC cache (returns exit code)')
    parser_check.add_argument('url', help='URL of the content to check')
    return parser


def main():
    """
    Command-line entry point: resolve settings and issue the Candypond request.

    Settings come from the command line, falling back to environment
    variables and local defaults. In 'check' mode the exit status is 0 if the
    URL is cached, 1 if it is not, and 2 if the service did not report on it.
    Configuration problems and CacheException errors exit with status 1.
    """
    cmd_args = get_parser().parse_args()

    # proxy certificate is required: option, then $X509_USER_PROXY, then default path
    proxy = cmd_args.proxy
    if proxy is None:
        proxy = os.environ.get('X509_USER_PROXY', '/tmp/x509up_u' + str(os.getuid()))
    if not os.path.exists(proxy):
        logger.critical('Proxy certificate file is required but cannot be accessible at %s', proxy)
        sys.exit(1)

    # candypond endpoint url: option, then $ARC_CANDYPOND_URL
    endpoint = cmd_args.endpoint_url
    if endpoint is None:
        endpoint = os.environ.get('ARC_CANDYPOND_URL')
    if endpoint is None:
        logger.critical('Candypond endpoint URL is not defined.')
        sys.exit(1)

    # job ID: option, then last path component of $GRID_GLOBAL_JOBID
    # (WS already provides non-URL id, for GridFTP the URL part is stripped),
    # then the name of the current working directory
    jobid = cmd_args.jobid
    if jobid is None:
        jobid_source = os.environ.get('GRID_GLOBAL_JOBID')
        if jobid_source is None:
            jobid_source = os.getcwd()
        jobid = jobid_source.rsplit('/', 1)[-1]

    # username: option, then the account running this command
    username = cmd_args.user
    if username is None:
        username = pwd.getpwuid(os.getuid()).pw_name

    # send appropriate request to candypond
    operation = cmd_args.operation
    try:
        if operation == 'check':
            url = cmd_args.url
            cacheurls = cacheCheck(endpoint, proxy, [url])
            if url not in cacheurls:
                logger.error('Failed to check URL %s presence in cache', url)
                sys.exit(2)
            sys.exit(0 if cacheurls[url] else 1)
        elif operation == 'get':
            print(cacheLink(endpoint, proxy, username, jobid, {cmd_args.url: cmd_args.file}, True))
        else:
            logger.error('Unsupported Candypond operation %s is requested.', operation)
            sys.exit(1)
    except CacheException as e:
        logger.error('Error in Candypond request: %s', str(e))
        sys.exit(1)


if __name__ == '__main__':
    main()