import time
from typing import Optional
from authlib.integrations.flask_oauth2 import AuthorizationServer, ResourceProtector
from authlib.integrations.sqla_oauth2 import (
create_save_token_func,
create_bearer_token_validator,
)
from authlib.oauth2 import OAuth2Request
from authlib.oauth2.rfc6749 import grants
from flask import Flask
from timApp.auth.oauth2.models import OAuth2Client, OAuth2Token, OAuth2AuthorizationCode
from timApp.timdb.sqa import db
from timApp.user.user import User
from tim_common.marshmallow_dataclass import class_schema
# Registry of known OAuth2 clients, keyed by client ID. It starts out empty;
# init_oauth() replaces it with the clients loaded from the app's
# "OAUTH2_CLIENTS" config entry, and query_client() looks clients up in it.
ALLOWED_CLIENTS: dict[str, OAuth2Client] = {}
class RefreshTokenGrant(grants.RefreshTokenGrant):
    """Refresh token grant that stores its state in ``OAuth2Token`` rows.

    Clients may authenticate at the token endpoint with either HTTP Basic
    credentials or credentials sent in the POST body.
    """

    TOKEN_ENDPOINT_AUTH_METHODS = [
        "client_secret_basic",
        "client_secret_post",
    ]
    # Authlib setting: hand out a new refresh token together with each
    # refreshed access token.
    INCLUDE_NEW_REFRESH_TOKEN = True

    def authenticate_refresh_token(self, refresh_token: str) -> OAuth2Token | None:
        """Look up a usable token row for the given refresh token.

        :param refresh_token: Refresh token string presented by the client.
        :return: The matching token, or None if there is no such token or it
                 has been revoked or has expired.
        """
        # .first() actually executes the query. Without it ``token`` would be
        # a Query object, which is always truthy and has no is_revoked().
        token: OAuth2Token | None = OAuth2Token.query.filter_by(
            refresh_token=refresh_token
        ).first()
        if token is None or token.is_revoked() or token.is_expired():
            return None
        return token

    def authenticate_user(self, credential: OAuth2Token) -> User:
        """Return the user referenced by the token's ``user_id``.

        :param credential: Token previously accepted by
                           :meth:`authenticate_refresh_token`.
        :return: Result of looking the user up by primary key.
        """
        return User.query.get(credential.user_id)

    def revoke_old_credential(self, credential: OAuth2Token) -> None:
        """Mark the refresh token of the given credential as revoked.

        Stores the current Unix time as the revocation timestamp and commits
        the change.

        :param credential: Token whose refresh token is being replaced.
        """
        credential.refresh_token_revoked_at = int(time.time())
        db.session.add(credential)
        db.session.commit()
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
    """Authorization code grant persisted via ``OAuth2AuthorizationCode`` rows."""

    TOKEN_ENDPOINT_AUTH_METHODS = [
        "client_secret_basic",
        "client_secret_post",
        # TODO: Do we need 'none'?
    ]

    def save_authorization_code(
        self, code: str, request: OAuth2Request
    ) -> OAuth2AuthorizationCode:
        """Store a newly issued authorization code and return its row.

        :param code: The authorization code string.
        :param request: Request supplying the client, redirect URI, scope
                        and user that the code is bound to.
        :return: The committed authorization code object.
        """
        fields = dict(
            code=code,
            client_id=request.client.client_id,
            redirect_uri=request.redirect_uri,
            scope=request.scope,
            user_id=request.user.id,
        )
        record = OAuth2AuthorizationCode(**fields)
        session = db.session
        session.add(record)
        session.commit()
        return record

    def query_authorization_code(
        self, code: str, client: OAuth2Client
    ) -> OAuth2AuthorizationCode | None:
        """Find a non-expired authorization code issued to the given client.

        :param code: The authorization code string to look up.
        :param client: Client that is redeeming the code.
        :return: The matching code object, or None when nothing matches or
                 the match has expired.
        """
        matching = OAuth2AuthorizationCode.query.filter_by(
            code=code, client_id=client.client_id
        )
        found = matching.first()
        if not found or found.is_expired():
            return None
        return found

    def delete_authorization_code(
        self, authorization_code: OAuth2AuthorizationCode
    ) -> None:
        """Delete the given authorization code row and commit.

        :param authorization_code: Code object to remove.
        """
        session = db.session
        session.delete(authorization_code)
        session.commit()

    def authenticate_user(self, authorization_code: OAuth2AuthorizationCode) -> User:
        """Return the user referenced by the authorization code's ``user_id``.

        :param authorization_code: Code object holding the user ID.
        :return: Result of looking the user up by primary key.
        """
        owner_id = authorization_code.user_id
        return User.query.get(owner_id)
def query_client(client_id: str) -> OAuth2Client:
    """Return the registered OAuth2 client with the given ID.

    :param client_id: Client ID to look up in ``ALLOWED_CLIENTS``.
    :return: The matching client.
    :raises Exception: If no client with that ID is registered.
    """
    if client_id in ALLOWED_CLIENTS:
        return ALLOWED_CLIENTS[client_id]
    raise Exception(f"OAuth2 client {client_id} is not in allowed list")
# Token-saving callback created by Authlib's SQLAlchemy integration for the
# OAuth2Token model and the shared database session.
save_token = create_save_token_func(db.session, OAuth2Token)
# Authorization server; it resolves clients through query_client() and stores
# issued tokens through save_token. init_oauth() binds it to the Flask app and
# registers the supported grants on it.
auth_server = AuthorizationServer(query_client=query_client, save_token=save_token)
# Resource protector; init_oauth() registers the bearer token validator on it.
require_oauth = ResourceProtector()
"""Special decorator to request for permission scopes"""
def delete_expired_oauth2_tokens() -> None:
    """Delete OAuth2 tokens that have expired or have been revoked.

    A token row is deleted when at least one of these holds:

    * ``issued_at + expires_in`` lies in the past,
    * the access token has a revocation timestamp that lies in the past,
    * the refresh token has a revocation timestamp that lies in the past.

    The deletion is committed immediately.

    NOTE(review): the ``> 0`` guards assume that ``OAuth2Token`` follows
    Authlib's ``OAuth2TokenMixin``, where both revocation columns default to
    0 for tokens that were never revoked. Without the guards, ``0 < now``
    would match every such row and delete all tokens. Confirm against the
    model definition. If the columns are NULL instead, the guards change
    nothing.
    """
    now_time = int(time.time())

    expired = OAuth2Token.expires_in + OAuth2Token.issued_at < now_time
    access_revoked = (OAuth2Token.access_token_revoked_at > 0) & (
        OAuth2Token.access_token_revoked_at < now_time
    )
    refresh_revoked = (OAuth2Token.refresh_token_revoked_at > 0) & (
        OAuth2Token.refresh_token_revoked_at < now_time
    )

    OAuth2Token.query.filter(expired | access_revoked | refresh_revoked).delete()
    db.session.commit()
def init_oauth(app: Flask) -> None:
    """Set up OAuth2 support for the given Flask application.

    Loads the client definitions from the ``OAUTH2_CLIENTS`` config entry
    into ``ALLOWED_CLIENTS``, binds the authorization server to the app,
    registers the supported grants and the OAuth blueprint, and installs the
    bearer token validator on ``require_oauth``.

    :param app: Flask application to configure.
    """
    global ALLOWED_CLIENTS

    raw_clients = app.config.get("OAUTH2_CLIENTS", [])
    loader = class_schema(OAuth2Client)()
    # Build the complete registry first, so the module-level mapping is only
    # replaced after every configured client has loaded successfully.
    registry: dict[str, OAuth2Client] = {}
    for raw in raw_clients:
        client = loader.load(raw)
        registry[client.client_id] = client
    ALLOWED_CLIENTS = registry

    auth_server.init_app(app)
    for grant_cls in (AuthorizationCodeGrant, RefreshTokenGrant):
        auth_server.register_grant(grant_cls)

    # TODO: Do we need to support revocation?

    # Imported inside the function rather than at module level
    # (presumably to avoid a circular import with the routes module - confirm).
    from timApp.auth.oauth2.routes import oauth

    app.register_blueprint(oauth)

    validator_cls = create_bearer_token_validator(db.session, OAuth2Token)
    require_oauth.register_token_validator(validator_cls())