# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hmac
import time
import base64
import hashlib
from typing import Dict, Type, Optional
from hashlib import sha256
from datetime import datetime
from libcloud.utils.py3 import ET, b, httplib, urlquote, basestring, _real_unicode
from libcloud.utils.xml import findall_ignore_namespace, findtext_ignore_namespace
from libcloud.common.base import BaseDriver, XmlResponse, JsonResponse, ConnectionUserAndKey
from libcloud.common.types import InvalidCredsError, MalformedResponseError
try:
import simplejson as json
except ImportError:
import json # type: ignore
# Public names exported by a star-import of this module.
__all__ = [
"AWSBaseResponse",
"AWSGenericResponse",
"AWSTokenConnection",
"SignedAWSConnection",
"AWSRequestSignerAlgorithmV2",
"AWSRequestSignerAlgorithmV4",
"AWSDriver",
]
# Signature version SignedAWSConnection uses when the caller does not pass
# ``signature_version``.
DEFAULT_SIGNATURE_VERSION = "2"
# Value the V4 signer returns from its payload-hash helper in place of a
# SHA-256 hex digest when the request body is not hashed.
UNSIGNED_PAYLOAD = "UNSIGNED-PAYLOAD"
# Template for the ValueError raised by SignedAWSConnection.add_default_params
# when a parameter value is not a simple type. It is interpolated with
# (key, value, type(value)).
PARAMS_NOT_STRING_ERROR_MSG = """
"params" dictionary contains an attribute "%s" which value (%s, %s) is not a
string.
Parameters are sent via query parameters and not via request body and as such,
all the values need to be of a simple type (string, int, bool).
For arrays and other complex types, you should use notation similar to this
one:
params['TagSpecification.1.Tag.Value'] = 'foo'
params['TagSpecification.2.Tag.Value'] = 'bar'
See https://docs.aws.amazon.com/AWSEC2/latest/APIReference/Query-Requests.html
for details.
""".strip()
class AWSBaseResponse(XmlResponse):
    """
    Base class for AWS XML responses.

    Provides a helper which extracts the code and message from an error
    element.
    """

    # Namespace handed to the namespace-agnostic XML lookup helper.
    namespace = None

    def _parse_error_details(self, element):
        """
        Extract the error code and message from the given error element.

        :param element: Error element holding ``Code`` and ``Message`` child
                        nodes.

        :return: ``tuple`` with two elements: (code, message)
        :rtype: ``tuple``
        """
        # "Code" is looked up first, then "Message".
        code, message = (
            findtext_ignore_namespace(element=element, xpath=tag, namespace=self.namespace)
            for tag in ("Code", "Message")
        )
        return code, message
class AWSGenericResponse(AWSBaseResponse):
    """
    Generic AWS XML response which turns error documents into either an
    immediately raised exception or a newline separated error string.
    """

    # There are multiple error messages in AWS, but they all have an Error node
    # with Code and Message child nodes. ``xpath`` selects those Error nodes;
    # leave it as None if the root node *is* the Error node.
    xpath = None

    # Maps the text of an error's ``Code`` node to a specific exception class
    # that is raised immediately.
    # If a custom exception class is not defined for a code, the error is
    # accumulated and returned from the parse_error method.
    exceptions = {}  # type: Dict[str, Type[Exception]]

    def success(self):
        """
        Return True when the response status is OK, CREATED or ACCEPTED.

        :rtype: ``bool``
        """
        return self.status in (httplib.OK, httplib.CREATED, httplib.ACCEPTED)

    def parse_error(self):
        """
        Parse the error document in the response body.

        :return: Newline separated "code: message" entries for every error
                 which has no custom exception class registered.
        :rtype: ``str``

        :raises InvalidCredsError: When the response status is FORBIDDEN.
        :raises MalformedResponseError: When the body cannot be parsed as XML.
        :raises Exception: The class registered in ``exceptions`` for the
                           first error code which has one.
        """
        context = self.connection.context
        status = int(self.status)

        # FIXME: Probably ditch this as the forbidden message will have
        # corresponding XML.
        if status == httplib.FORBIDDEN:
            if self.body:
                raise InvalidCredsError(self.body)
            raise InvalidCredsError(str(self.status) + ": " + self.error)

        try:
            root = ET.XML(self.body)
        except Exception:
            raise MalformedResponseError(
                "Failed to parse XML", body=self.body, driver=self.connection.driver
            )

        # Without an xpath the root node itself is treated as the error node.
        error_nodes = (
            findall_ignore_namespace(element=root, xpath=self.xpath, namespace=self.namespace)
            if self.xpath
            else [root]
        )

        collected = []
        for node in error_nodes:
            code, message = self._parse_error_details(element=node)
            exception_cls = self.exceptions.get(code, None)

            if exception_cls is not None:
                # Custom exception class is defined, immediately throw an
                # exception. Only the context keys the class declares in its
                # ``kwargs`` attribute are forwarded.
                extra = {
                    name: context[name]
                    for name in getattr(exception_cls, "kwargs", ())
                    if name in context
                }
                raise exception_cls(value=message, driver=self.connection.driver, **extra)

            collected.append("{}: {}".format(code, message))

        return "\n".join(collected)
class AWSTokenConnection(ConnectionUserAndKey):
    """
    Connection which, when a token is set, adds it as
    ``x-amz-security-token`` to both the request params and the headers.
    """

    def __init__(
        self,
        user_id,
        key,
        secure=True,
        host=None,
        port=None,
        url=None,
        timeout=None,
        proxy_url=None,
        token=None,
        retry_delay=None,
        backoff=None,
    ):
        """
        Store ``token`` and forward every other argument to the parent
        constructor.

        :param token: Optional security token; falsy values disable the
                      extra param / header.
        """
        # The token is stored before the parent constructor runs.
        self.token = token
        parent_kwargs = {
            "secure": secure,
            "host": host,
            "port": port,
            "url": url,
            "timeout": timeout,
            "retry_delay": retry_delay,
            "backoff": backoff,
            "proxy_url": proxy_url,
        }
        super().__init__(user_id, key, **parent_kwargs)

    def add_default_params(self, params):
        """
        Add the security token (if any) to ``params`` and delegate to the
        parent implementation.
        """
        # Even though we are adding it to the headers, we need it here too
        # so that the token is added to the signature.
        token = self.token
        if token:
            params["x-amz-security-token"] = token
        return super().add_default_params(params)

    def add_default_headers(self, headers):
        """
        Add the security token (if any) to ``headers`` and delegate to the
        parent implementation.
        """
        token = self.token
        if token:
            headers["x-amz-security-token"] = token
        return super().add_default_headers(headers)
class AWSRequestSigner:
    """
    Base class for objects which sign outgoing AWS requests.

    The hooks defined here return their arguments untouched; subclasses
    override them to add signature related params / headers.
    """

    def __init__(self, access_key, access_secret, version, connection):
        """
        :param access_key: Access key.
        :type access_key: ``str``

        :param access_secret: Access secret.
        :type access_secret: ``str``

        :param version: API version.
        :type version: ``str``

        :param connection: Connection instance.
        :type connection: :class:`Connection`
        """
        self.access_key, self.access_secret = access_key, access_secret
        self.version = version
        # TODO: Remove cycling dependency between connection and signer
        self.connection = connection

    def get_request_params(self, params, method="GET", path="/"):
        """
        Return the query parameters to send; this base version returns
        ``params`` as-is.
        """
        return params

    def get_request_headers(self, params, headers, method="GET", path="/", data=None):
        """
        Return a ``(params, headers)`` tuple; this base version returns both
        arguments as-is.
        """
        unchanged = (params, headers)
        return unchanged
class AWSRequestSignerAlgorithmV2(AWSRequestSigner):
    """
    Signer which authenticates a request by adding signature version "2"
    query parameters (HmacSHA256).
    """

    def get_request_params(self, params, method="GET", path="/"):
        """
        Add the signature related query parameters to ``params`` (in place)
        and return it.

        :param params: Query parameters of the request.
        :type params: ``dict``

        :param method: HTTP method; not used by this signer.
        :param path: Request path which is included in the signed string.

        :rtype: ``dict``
        """
        params.update(
            {
                "SignatureVersion": "2",
                "SignatureMethod": "HmacSHA256",
                "AWSAccessKeyId": self.access_key,
                "Version": self.version,
                "Timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
            }
        )
        # The signature covers every parameter set above, so it goes last.
        params["Signature"] = self._get_aws_auth_param(
            params=params, secret_key=self.access_secret, path=path
        )
        return params

    def _get_aws_auth_param(self, params, secret_key, path="/"):
        """
        Creates the signature required for AWS, per
        http://bit.ly/aR7GaQ [docs.amazonwebservices.com]:

        StringToSign = HTTPVerb + "\n" +
                       ValueOfHostHeaderInLowercase + "\n" +
                       HTTPRequestURI + "\n" +
                       CanonicalizedQueryString

        :return: Base64 encoded HMAC-SHA256 of the string to sign.
        :rtype: ``str``
        """
        conn = self.connection

        # Canonical query string: keys sorted, keys and values percent-encoded.
        canonical_qs = "&".join(
            urlquote(name, safe="") + "=" + urlquote(str(params[name]), safe="-_~")
            for name in sorted(params.keys())
        )

        # The port is appended only when it is not the scheme's default one.
        host = conn.host
        default_port = 443 if conn.secure else 80
        if conn.port != default_port:
            host += ":" + str(conn.port)

        # NOTE(review): the verb is always "GET" here, independent of the
        # ``method`` given to get_request_params - confirm this is intended.
        string_to_sign = "\n".join(("GET", host, path, canonical_qs))

        mac = hmac.new(b(secret_key), b(string_to_sign), digestmod=sha256)
        return base64.b64encode(mac.digest()).decode("utf-8")
class AWSRequestSignerAlgorithmV4(AWSRequestSigner):
    """
    Signer which authenticates a request with an ``AWS4-HMAC-SHA256``
    ``Authorization`` header.
    """

    def get_request_params(self, params, method="GET", path="/"):
        """
        Add the API ``Version`` param for GET requests and return ``params``.
        """
        if method == "GET":
            params["Version"] = self.version
        return params

    def get_request_headers(self, params, headers, method="GET", path="/", data=None):
        """
        Add the date, payload hash and ``Authorization`` headers to
        ``headers`` (in place).

        :return: ``tuple`` of (params, headers)
        :rtype: ``tuple``
        """
        timestamp = datetime.utcnow()
        headers["X-AMZ-Date"] = timestamp.strftime("%Y%m%dT%H%M%SZ")
        headers["X-AMZ-Content-SHA256"] = self._get_payload_hash(method, data)

        # The two headers above are part of the signed header set, so the
        # authorization value is computed only after they have been added.
        authorization = self._get_authorization_v4_header(
            params=params, headers=headers, dt=timestamp, method=method, path=path, data=data
        )
        headers["Authorization"] = authorization
        return params, headers

    def _get_authorization_v4_header(self, params, headers, dt, method="GET", path="/", data=None):
        """
        Build the value of the ``Authorization`` header.

        :rtype: ``str``
        """
        scope = self._get_credential_scope(dt=dt)
        header_names = self._get_signed_headers(headers=headers)
        request_signature = self._get_signature(
            params=params, headers=headers, dt=dt, method=method, path=path, data=data
        )

        fields = {
            "u": self.access_key,
            "c": scope,
            "sh": header_names,
            "s": request_signature,
        }
        return (
            "AWS4-HMAC-SHA256 Credential=%(u)s/%(c)s, "
            "SignedHeaders=%(sh)s, Signature=%(s)s"
        ) % fields

    def _get_signature(self, params, headers, dt, method, path, data):
        """
        Return the hex encoded signature of the request.
        """
        signing_key = self._get_key_to_sign_with(dt)
        to_sign = self._get_string_to_sign(
            params=params, headers=headers, dt=dt, method=method, path=path, data=data
        )
        return _sign(key=signing_key, msg=to_sign, hex=True)

    def _get_key_to_sign_with(self, dt):
        """
        Derive the signing key by chaining HMACs over the date, the driver
        region name, the connection service name and "aws4_request".
        """
        date_key = _sign(("AWS4" + self.access_secret), dt.strftime("%Y%m%d"))
        region_key = _sign(date_key, self.connection.driver.region_name)
        service_key = _sign(region_key, self.connection.service_name)
        return _sign(service_key, "aws4_request")

    def _get_string_to_sign(self, params, headers, dt, method, path, data):
        """
        Return the newline separated string to sign: algorithm name,
        timestamp, credential scope and hash of the canonical request.
        """
        canonical_request = self._get_canonical_request(
            params=params, headers=headers, method=method, path=path, data=data
        )
        parts = [
            "AWS4-HMAC-SHA256",
            dt.strftime("%Y%m%dT%H%M%SZ"),
            self._get_credential_scope(dt),
            _hash(canonical_request),
        ]
        return "\n".join(parts)

    def _get_credential_scope(self, dt):
        """
        Return "<date>/<region>/<service>/aws4_request".
        """
        scope_parts = (
            dt.strftime("%Y%m%d"),
            self.connection.driver.region_name,
            self.connection.service_name,
            "aws4_request",
        )
        return "/".join(scope_parts)

    def _get_signed_headers(self, headers):
        """
        Return the lower-cased header names, sorted case-insensitively and
        joined with ";".
        """
        ordered_names = sorted(headers.keys(), key=str.lower)
        return ";".join(name.lower() for name in ordered_names)

    def _get_canonical_headers(self, headers):
        """
        Return one "name:value" line per header (lower-cased name, stripped
        value), sorted by lower-cased name, with a trailing newline.
        """
        ordered = sorted(headers.items(), key=lambda item: item[0].lower())
        lines = [":".join([name.lower(), str(value).strip()]) for name, value in ordered]
        # The trailing newline is added even when there are no headers.
        return "\n".join(lines) + "\n"

    def _get_payload_hash(self, method, data=None):
        """
        Return the SHA-256 hex digest of the payload, or ``UNSIGNED_PAYLOAD``
        when the payload is not hashed.
        """
        if data is UnsignedPayloadSentinel:
            return UNSIGNED_PAYLOAD

        if method not in ("POST", "PUT"):
            return _hash("")

        if not data:
            return UNSIGNED_PAYLOAD

        # File upload; don't try to read the entire payload
        is_iterator = hasattr(data, "next") or hasattr(data, "__next__")
        return UNSIGNED_PAYLOAD if is_iterator else _hash(data)

    def _get_request_params(self, params):
        """
        Return the canonical query string: percent-encoded "key=value" pairs
        in sorted order, joined with "&".
        """
        # For self.method == GET
        encoded_pairs = [
            "{}={}".format(urlquote(name, safe=""), urlquote(str(value), safe="~"))
            for name, value in sorted(params.items())
        ]
        return "&".join(encoded_pairs)

    def _get_canonical_request(self, params, headers, method, path, data):
        """
        Return the newline separated canonical request: method, path, query
        string, canonical headers, signed header names and payload hash.
        """
        components = [
            method,
            path,
            self._get_request_params(params),
            self._get_canonical_headers(headers),
            self._get_signed_headers(headers),
            self._get_payload_hash(method, data),
        ]
        return "\n".join(components)
class UnsignedPayloadSentinel:
    """
    Marker class: when the class object itself is given as the request data,
    the V4 signer's payload-hash helper returns ``UNSIGNED_PAYLOAD``.
    """
class SignedAWSConnection(AWSTokenConnection):
    """
    Connection which delegates request signing to a signer object selected
    by ``signature_version``.
    """

    # API version which is handed to the signer.
    version = None  # type: Optional[str]

    def __init__(
        self,
        user_id,
        key,
        secure=True,
        host=None,
        port=None,
        url=None,
        timeout=None,
        proxy_url=None,
        token=None,
        retry_delay=None,
        backoff=None,
        signature_version=DEFAULT_SIGNATURE_VERSION,
    ):
        """
        Initialize the connection and create the matching request signer.

        :param signature_version: Signature version to use ("2" or "4");
                                  compared after conversion to ``str``.

        :raises ValueError: When the signature version is not supported.
        """
        super().__init__(
            user_id=user_id,
            key=key,
            secure=secure,
            host=host,
            port=port,
            url=url,
            timeout=timeout,
            token=token,
            retry_delay=retry_delay,
            backoff=backoff,
            proxy_url=proxy_url,
        )
        self.signature_version = str(signature_version)

        signer_classes = {
            "2": AWSRequestSignerAlgorithmV2,
            "4": AWSRequestSignerAlgorithmV4,
        }
        signer_cls = signer_classes.get(self.signature_version)
        if signer_cls is None:
            raise ValueError("Unsupported signature_version: %s" % (signature_version))

        self.signer = signer_cls(
            access_key=self.user_id,
            access_secret=self.key,
            version=self.version,
            connection=self,
        )

    def add_default_params(self, params):
        """
        Let the signer add its query parameters, then validate that every
        value is of a simple type.

        :raises ValueError: When a value is not a string, int or bool.
        """
        params = self.signer.get_request_params(params=params, method=self.method, path=self.action)

        # Verify that params only contain simple types and no nested
        # dictionaries.
        # params are sent via query params so only strings are supported
        for name, value in params.items():
            if isinstance(value, (_real_unicode, basestring, int, bool)):
                continue
            raise ValueError(PARAMS_NOT_STRING_ERROR_MSG % (name, value, type(value)))

        return params

    def pre_connect_hook(self, params, headers):
        """
        Let the signer add its headers.

        :return: ``tuple`` of (params, headers)
        :rtype: ``tuple``
        """
        signed_params, signed_headers = self.signer.get_request_headers(
            params=params,
            headers=headers,
            method=self.method,
            path=self.action,
            data=self.data,
        )
        return signed_params, signed_headers
class AWSJsonResponse(JsonResponse):
    """
    Amazon ECS response class.
    ECS API uses JSON unlike the s3, elb drivers
    """

    def parse_error(self):
        """
        Parse a JSON error body into a "code: message" string.

        The code is read from the ``__type`` key. The message is read from
        the ``Message`` key when present, otherwise from ``message``.

        :rtype: ``str``

        :raises KeyError: When ``__type`` is missing, or when neither
                          ``Message`` nor ``message`` is present.
        """
        response = json.loads(self.body)
        code = response["__type"]

        # ``response.get("Message", response["message"])`` would evaluate the
        # default eagerly and raise KeyError for bodies which only carry the
        # capitalized key, so the fallback is looked up only when needed.
        if "Message" in response:
            message = response["Message"]
        else:
            message = response["message"]

        return "{}: {}".format(code, message)
def _sign(key, msg, hex=False):
    """
    Return the HMAC-SHA256 of ``msg`` keyed with ``key``.

    :param hex: When true return the hex digest (``str``), otherwise the
                raw digest (``bytes``).
    """
    mac = hmac.new(b(key), b(msg), hashlib.sha256)
    return mac.hexdigest() if hex else mac.digest()
def _hash(msg):
    """
    Return the SHA-256 hex digest of ``msg``.
    """
    digest = hashlib.sha256()
    digest.update(b(msg))
    return digest.hexdigest()
class AWSDriver(BaseDriver):
    """
    Driver base class which remembers an optional token and passes it on to
    the connection class.
    """

    def __init__(
        self,
        key,
        secret=None,
        secure=True,
        host=None,
        port=None,
        api_version=None,
        region=None,
        token=None,
        **kwargs,
    ):
        """
        Store ``token`` and forward all arguments (including ``token``) to
        the parent constructor.
        """
        # The token is stored before the parent constructor runs.
        self.token = token
        parent_kwargs = {
            "secret": secret,
            "secure": secure,
            "host": host,
            "port": port,
            "api_version": api_version,
            "region": region,
            "token": token,
        }
        super().__init__(key, **parent_kwargs, **kwargs)

    def _ex_connection_class_kwargs(self):
        """
        Return the parent's connection kwargs with ``token`` added.
        """
        conn_kwargs = super()._ex_connection_class_kwargs()
        conn_kwargs["token"] = self.token
        return conn_kwargs