Source code for awscli_bastion.sts
from botocore.exceptions import ClientError
from datetime import timedelta
from dateutil.tz import tzutc
import boto3
import click
import datetime
import getpass
import sys
ONE_HOUR_IN_SECONDS = timedelta(hours=1).seconds
TWELVE_HOURS_IN_SECONDS = timedelta(hours=12).seconds
[docs]class STS:
""" A small class that wraps relevant boto3 sts function calls. """
def __init__(self, bastion="bastion", bastion_sts="bastion-sts", region="us-west-2",
credentials=None, cache=None):
self.bastion = bastion
self.bastion_sts = bastion_sts
self.region = region
self.credentials = credentials
self.cache = cache
[docs] def is_mfa_code_invalid(self, mfa_code):
return len(mfa_code) != 6 or not mfa_code.isdigit()
def _get_mfa_code(self, mfa_serial):
""" Prompt the user for the mfa code and return it.
:param mfa_serial: The identification number of the MFA device that is associated with the IAM user.
:type mfa_serial: str
:return: 6 digit mfa code
:rtype: str
"""
is_mfa_code_invalid = True
while is_mfa_code_invalid:
mfa_code = getpass.getpass("Enter MFA code for {}: ".format(mfa_serial))
is_mfa_code_invalid = self.is_mfa_code_invalid(mfa_code)
if is_mfa_code_invalid:
click.echo("Warning: The MFA code must be 6 digits. For example: 123456")
return mfa_code
[docs] def get_session_token(self, mfa_code=None, mfa_serial=None,
duration_seconds=TWELVE_HOURS_IN_SECONDS):
""" Get the short-lived credentials from sts.get_session_token()
if the 'mfa_code' is provided. Otherwise, try to look up sts credentials from the cache.
:param mfa_code: The value provided by the MFA device.
:type mfa_code: str
:param mfa_serial: The identification number of the MFA device that is associated with the IAM user.
:type mfa_serial: str
:param duration_seconds: The duration, in seconds, that the credentials should remain valid.
:type duration_seconds: str
:return: sts credentials
:rtype: dict
"""
if not mfa_serial:
mfa_serial = self.credentials.get_mfa_serial(bastion_sts=self.bastion_sts)
cached_sts_creds = None
if not mfa_code and not self.cache.is_expired():
cached_sts_creds = self.cache.read()
if cached_sts_creds:
sts_creds = cached_sts_creds
else:
if not mfa_code or self.is_mfa_code_invalid(mfa_code):
mfa_code = self._get_mfa_code(mfa_serial)
session = boto3.Session(profile_name=self.bastion, region_name=self.region)
sts = session.client("sts")
try:
sts_creds = sts.get_session_token(
DurationSeconds=duration_seconds,
SerialNumber=mfa_serial,
TokenCode=mfa_code
)["Credentials"]
sts_creds["Expiration"] = sts_creds["Expiration"].isoformat()
except Exception as e:
click.echo(e)
sys.exit(1)
self.cache.write(sts_creds)
return sts_creds
[docs] def assume_role(self, profile, duration_seconds=ONE_HOUR_IN_SECONDS):
"""Get the short-lived credentials from sts.assume_role().
:param profile: The profile that contains the 'role_arn' and 'source_profile' attributes.
:type profile: str
:param duration_seconds: The duration, in seconds, that the credentials should remain valid.
:type duration_seconds: str
:return: sts credentials
:rtype: dict
"""
session = boto3.Session(profile_name=self.bastion_sts, region_name=self.region)
sts = session.client("sts")
try:
role_arn = self.credentials.config[profile]["role_arn"]
except Exception:
click.echo("An error occured when getting the role_arn from '{}' profile.".format(profile))
sys.exit(1)
timestamp = datetime.datetime.now(tzutc()).strftime("%Y-%m-%d")
try:
iam = boto3.client('iam')
username = iam.get_user()["User"]["UserName"]
role_session_name = "{}-{}".format(username, timestamp)
except Exception:
role_session_name = "bastion-assume-role-{}".format(timestamp)
try:
sts_creds = sts.assume_role(
RoleArn=role_arn,
RoleSessionName=role_session_name,
DurationSeconds=duration_seconds
)["Credentials"]
except ClientError as e:
click.echo(e)
sys.exit(1)
return sts_creds