Skip to content
Snippets Groups Projects
Commit fe66ad70 authored by jaannigu's avatar jaannigu
Browse files

kill the container when the user fails to connect

parent 546b6b0e
No related branches found
No related tags found
No related merge requests found
......@@ -2,4 +2,5 @@
.env
tutorials/apps/testcontainer/aws/
tutorials/apps/testcontainer/awscliv2.zip
__pycache__
\ No newline at end of file
__pycache__
tutorials/log/app.log
\ No newline at end of file
import boto3
import time
from dotenv import load_dotenv
import os
import json
load_dotenv()
task_arn_dict = {}
def run_ecs_container_fargate(new_image):
"""
1. Describes 'atlas-task' to get its container definitions.
2. Registers a new revision with container image "public.ecr.aws/a3o6k3z1/atlas:<new_image>".
3. Runs the new revision on Fargate with 'assignPublicIp=ENABLED'.
4. Waits for the task to be RUNNING, then extracts & returns the public IPv4.
"""
ecs_client = boto3.client('ecs', region_name='eu-north-1')
username = os.getenv("USERNAME", None)
user_family = f"atlas-{username}"
cluster_name = "atlas-ecs-test1"
template_task_def = "atlas-task"
subnets = json.loads(os.getenv("SUBNETS", []))
security_groups = json.loads(os.getenv("SECURITY_GROUPS", []))
try:
try:
existing_td_resp = ecs_client.describe_task_definition(taskDefinition=user_family)
print(f"Found existing task definition family: {user_family}")
td = existing_td_resp["taskDefinition"]
except ecs_client.exceptions.ClientException as e:
print(f"Task definition {user_family} does not exist yet. ")
template_resp = ecs_client.describe_task_definition(taskDefinition=template_task_def)
td = template_resp["taskDefinition"]
td = {
**td,
"family": user_family
}
# Describe the base 'atlas-task'
#resp = ecs_client.describe_task_definition(taskDefinition=template_task_def)
#td = resp["taskDefinition"]
container_defs = td["containerDefinitions"]
container_defs[0]["name"] = "atlas-container"
container_defs[0]["image"] = f"{new_image}"
register_args = {
"family": td["family"],
"networkMode": td.get("networkMode"),
"containerDefinitions": container_defs,
"volumes": td.get("volumes", []),
"placementConstraints": td.get("placementConstraints", []),
"requiresCompatibilities": td.get("requiresCompatibilities", []),
"cpu": td.get("cpu"),
#"memory": td.get("memory")
"memory": "2048" # Temporary increase to memory
}
if td.get("taskRoleArn"):
register_args["taskRoleArn"] = td["taskRoleArn"]
if td.get("executionRoleArn"):
register_args["executionRoleArn"] = td["executionRoleArn"]
# Register new revision
new_td_resp = ecs_client.register_task_definition(**register_args)
new_td_arn = new_td_resp["taskDefinition"]["taskDefinitionArn"]
print(f"Registered new revision: {new_td_arn}")
# Run the new revision
run_resp = ecs_client.run_task(
cluster=cluster_name,
launchType="FARGATE",
taskDefinition=new_td_arn,
networkConfiguration={
"awsvpcConfiguration": {
"subnets": subnets,
"securityGroups": security_groups,
"assignPublicIp": "ENABLED"
}
}
)
tasks = run_resp.get("tasks", [])
if not tasks:
return "No tasks started"
task_arn = tasks[0]["taskArn"]
print("Task started:", task_arn)
# Wait for the task to be RUNNING
public_ip = wait_for_running_and_get_ip_fargate(ecs_client, cluster_name, task_arn, timeout=180)
if not public_ip:
return f"Task {task_arn} never reached RUNNING or no public IP found."
# Store the IP and container's task ARN
task_arn_dict[public_ip] = task_arn
return public_ip
except Exception as e:
return f"Error registering/running ECS container: {str(e)}"
def wait_for_running_and_get_ip_fargate(ecs_client, cluster, task_arn, timeout=120):
"""
Polls until the ECS task is RUNNING or until 'timeout' seconds pass.
Returns the public IP or None.
"""
start_time = time.time()
while True:
desc = ecs_client.describe_tasks(cluster=cluster, tasks=[task_arn])
tasks = desc.get("tasks", [])
if tasks:
task = tasks[0]
if task["lastStatus"] == "RUNNING":
print("Task is RUNNING now.")
# https://stackoverflow.com/questions/62690894/aws-fargate-how-to-get-the-public-ip-address-of-task-by-using-python-boto3
attachments = task.get("attachments", [])
if attachments:
details = attachments[0].get("details", [])
for item in details:
if item["name"] == "networkInterfaceId":
eni_id = item["value"]
eni = boto3.resource('ec2').NetworkInterface(eni_id)
public_ip = eni.association_attribute['PublicIp']
# print(f"Public IP: {public_ip}")
return public_ip
return None
if time.time() - start_time > timeout:
print("Timed out waiting for RUNNING.")
return None
time.sleep(5)
def stop_ecs_container_fargate(container_ip, task_arn="", cluster_name="atlas-ecs-test1", region_name="eu-north-1"):
"""
Stops the given ECS task (Fargate container) by calling ecs_client.stop_task().
:param task_arn: The full ARN of the running ECS task to stop.
:param cluster_name: The name of the ECS cluster (default: "atlas-ecs-test1").
:param region_name: The AWS region (default: "eu-north-1").
:return: A string message indicating success or error details.
"""
ecs_client = boto3.client("ecs", region_name=region_name)
try:
if task_arn == "":
task_arn = task_arn_dict[container_ip]
response = ecs_client.stop_task(
cluster=cluster_name,
task=task_arn,
reason="Container stopped due to logout or inactivity"
)
task_arn_dict.pop(container_ip)
return f"Stopped ECS task: {task_arn}"
except Exception as e:
return f"Error stopping ECS container: {str(e)}"
\ No newline at end of file
import os
from flask import Flask, redirect, url_for, session, request, jsonify, make_response
from authlib.integrations.flask_client import OAuth
from urllib.parse import urljoin
import boto3
import awscontroller as ac
import time
from dotenv import load_dotenv
import token, jwt
load_dotenv()
app = Flask(__name__)
# TODO: Use a more secure secret key
app.secret_key = os.urandom(24)
# ------------------------------------------------------------------------------
# COGNITO SETTINGS
# ------------------------------------------------------------------------------
COGNITO_POOL_ID = os.getenv("COGNITO_POOL_ID", "eu-north-1_z6hZRjOrR")
CLIENT_ID = os.getenv("COGNITO_CLIENT_ID")
CLIENT_SECRET = os.getenv("COGNITO_CLIENT_SECRET")
CALLBACK_URL = os.getenv("CALLBACK_URL")
SCOPES = "openid email phone profile"
# ------------------------------------------------------------------------------
# SETUP OAUTH WITH AUTHLIB
# ------------------------------------------------------------------------------
app.config.update({'SERVER_NAME': os.getenv("SERVER_NAME")})
oauth = OAuth(app)
# OIDC Discovery URL for user pool:
# https://cognito-idp.<region>.amazonaws.com/<user_pool_id>/.well-known/openid-configuration
server_metadata_url = f"https://cognito-idp.eu-north-1.amazonaws.com/{COGNITO_POOL_ID}/.well-known/openid-configuration"
oauth.register(
name='oidc',
client_id=CLIENT_ID,
client_secret=CLIENT_SECRET,
server_metadata_url=server_metadata_url,
client_kwargs={'scope': SCOPES},
)
# ------------------------------------------------------------------------------
# ROUTES
# ------------------------------------------------------------------------------
@app.route('/')
def index():
userinfo = session.get('user')
if not userinfo:
return """
<html>
<body>
<h1>Welcome!</h1>
<p>You are not logged in.</p>
<p><a href="/login">Login with Cognito</a></p>
</body>
</html>
"""
else:
containerImage = userinfo.get("custom:containerImage", "<not found>")
return f"""
<html>
<head>
<title>Starting Container</title>
<style>
.loader {{
border: 16px solid #f3f3f3;
border-radius: 50%;
border-top: 16px solid #3498db;
width: 60px;
height: 60px;
animation: spin 2s linear infinite;
margin: 20px auto;
}}
@keyframes spin {{
0% {{ transform: rotate(0deg); }}
100% {{ transform: rotate(360deg); }}
}}
body {{
font-family: Arial, sans-serif;
text-align: center;
margin-top: 50px;
}}
</style>
<script>
function getCookie(name) {{
const match = document.cookie.match(new RegExp('(^| )' + name + '=([^;]+)'));
if (match) return match[2];
return '';
}}
window.addEventListener('DOMContentLoaded', () => {{
const myToken = getCookie('my_token');
fetch('/launch?containerImage={containerImage}')
.then(response => response.json())
.then(data => {{
if (data.public_ip) {{
window.location.href = 'http://' + data.public_ip + ':5006/app?token='+ encodeURIComponent(myToken) + "&ip=" + encodeURIComponent(data.public_ip);
}} else {{
document.body.innerHTML =
'<h1>Error</h1><p>' + (data.error || 'No IP found') + '</p>';
}}
}})
.catch(err => {{
document.body.innerHTML = '<h1>Request Failed</h1><p>' + err + '</p>';
}});
}});
</script>
</head>
<body>
<h1>Loading...</h1>
<div class="loader"></div>
<p>Starting your container with image: <strong>{containerImage}</strong></p>
</body>
</html>
"""
@app.route('/launch')
def launch():
"""
This route does the actual container start (run_ecs_container),
waits for a public IP, and returns JSON.
The front-end JavaScript then redirects the user to that IP.
"""
userinfo = session.get('user')
if not userinfo:
return redirect(url_for('index'))
container_image = request.args.get('containerImage', '<No containerImage>')
public_ip = ac.run_ecs_container_fargate(container_image)
time.sleep(5) # Give more time for the container and app to load
if public_ip.startswith('Task '):
return jsonify({"error": public_ip})
elif public_ip.startswith("Error"):
return jsonify({"error": public_ip})
else:
return jsonify({"public_ip": public_ip})
@app.route('/login')
def login():
"""
Login route:
- This is the only place to start the Cognito login flow.
- Authlib sets the OAuth "state" in the user's session
before redirecting to Cognito's login page.
"""
redirect_uri = url_for('authorize', _external=True)
return oauth.oidc.authorize_redirect(redirect_uri)
@app.route('/authorize')
def authorize():
"""
After a successful Cognito login, user is redirected here with '?code=...'.
Authlib exchanges the code for tokens and fetches userinfo.
We store userinfo in the session. If the user tries to skip the "/login"
step, they'd lack the correct session state, resulting in a mismatch.
"""
token = oauth.oidc.authorize_access_token()
# token usually includes 'access_token', 'id_token', 'refresh_token', etc.
# Authlib automatically calls the userinfo endpoint if present, storing it in token['userinfo']
userinfo = token.get('userinfo', {})
session['user'] = userinfo
id_token = token.get('id_token')
#decoded = jwt.decode(id_token, options={"verify_signature": False})# debugging
resp = make_response(redirect(url_for('index')))
if 'sub' in userinfo:
resp.set_cookie('my_token', id_token, httponly=False)
else:
resp.set_cookie('my_token', 'No_id_token', httponly=False)
return resp
@app.route('/logout')
def logout():
"""
Stops the ECS Fargate container and clears the local session.
"""
container_ip = request.args.get('ip')
stopped = ac.stop_ecs_container_fargate(container_ip)
session.pop('user', None)
resp = make_response(redirect(url_for('index')))
resp.set_cookie('my_token', '', expires=0)
print(f"{stopped}; ip: {container_ip}")
return resp
# ------------------------------------------------------------------------------
# MAIN
# ------------------------------------------------------------------------------
if __name__ == '__main__':
app.run(debug=True, host='0.0.0.0', port=5000)
......@@ -6,7 +6,7 @@ import json
load_dotenv()
task_arn_dict = {}
from login import log_into_file
def run_ecs_container_fargate(new_image):
def run_ecs_container_fargate(new_image, email):
"""
1. Describes 'atlas-task' to get its container definitions.
2. Registers a new revision with container image "public.ecr.aws/a3o6k3z1/atlas:<new_image>".
......@@ -14,8 +14,8 @@ def run_ecs_container_fargate(new_image):
4. Waits for the task to be RUNNING, then extracts & returns the public IPv4.
"""
ecs_client = boto3.client('ecs', region_name='eu-north-1')
username = os.getenv("USERNAME", None)
user_family = f"atlas-{username}"
user_family = f"atlas-{email}"
cluster_name = "atlas-ecs-test1"
template_task_def = "atlas-task"
subnets = json.loads(os.getenv("SUBNETS", []))
......
......@@ -6,7 +6,7 @@ import boto3
import awscontroller as ac
import time
from dotenv import load_dotenv
import token, jwt
import token, jwt, re
from datetime import datetime
load_dotenv()
app = Flask(__name__)
......@@ -47,6 +47,8 @@ def log_into_file(info):
with open("../../log/app.log", "a", encoding="utf-8") as f:
f.write(f"{datetime.now().strftime("%d-%m-%Y %H:%M:%S")} - {info}\n")
def sanitize_email(email):
return re.sub(r'[^a-zA-Z0-9_-]', '-', email)
# ------------------------------------------------------------------------------
# ROUTES
# ------------------------------------------------------------------------------
......@@ -192,7 +194,7 @@ def dashboard():
document.body.innerHTML = `
<div class="container">
<h1>Error</h1>
<p>(data.error || 'No IP found')</p>
<p>${{data.error || 'No IP found'}}</p>
</div>
`;
}}
......@@ -229,9 +231,10 @@ def launch():
userinfo = session.get('user')
if not userinfo:
return redirect(url_for('index'))
user_email = userinfo['email']
clean_email = sanitize_email(user_email)
container_image = request.args.get('containerImage', '<No containerImage>')
public_ip = ac.run_ecs_container_fargate(container_image)
public_ip = ac.run_ecs_container_fargate(container_image, clean_email)
time.sleep(5) # Give more time for the container and app to load
if public_ip.startswith('Task '):
return jsonify({"error": public_ip})
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment