Source code for timApp.auth.oauth2.models

import hmac
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional

from authlib.integrations.sqla_oauth2 import (
    OAuth2TokenMixin,
    OAuth2AuthorizationCodeMixin,
)
from authlib.oauth2.rfc6749 import ClientMixin, scope_to_list, list_to_scope

from timApp.timdb.sqa import db


class Scope(Enum):
    """Resource scopes that an OAuth2 client may be allowed to request."""

    # Currently the only scope defined in this module.
    profile = "profile"
@dataclass
class OAuth2Client(ClientMixin):
    """
    An application that is allowed to authenticate as a TIM user and use
    OAUTH-protected REST API.

    The ``get_*`` / ``check_*`` methods below are the hooks this class
    provides on top of authlib's ``ClientMixin``.
    """

    client_id: str
    """Unique identifier for the client."""

    client_name: str | None = None
    """User-friendly client name"""

    client_secret: str = ""
    """Client secret that is used to allow OAUTH2 authentication."""

    redirect_urls: list[str] = field(default_factory=list)
    """List of valid URLs that TIM is allowed to redirect the user to upon successful authentication."""

    allowed_scopes: list[Scope] = field(default_factory=list)
    """Resource scopes that the client can ask for. Scopes are used to limit what REST API can be used."""

    # NOTE(review): this attribute has no type annotation, so ``@dataclass``
    # treats it as a plain class attribute rather than a field. It is therefore
    # not accepted by the generated ``__init__``; it can only be changed by
    # assigning to the instance/class. Confirm this is intended.
    token_endpoint_auth_method = "client_secret_post"
    """How the client authenticates itself with TIM.

    Allowed values:

    * "none": The client is a public client as defined in OAuth 2.0, and does not have a client secret.
    * "client_secret_post": The client uses the HTTP POST parameters as defined in OAuth 2.0
    * "client_secret_basic": The client uses HTTP Basic as defined in OAuth 2.0
    """

    response_types: list[str] = field(default_factory=list)
    """What response types the client can handle.
    In other words, tells TIM how to send user's OAUTH2 token to the client.

    Allowed values: "code" and "token".
    """

    grant_types: list[str] = field(default_factory=list)
    """What grant types the client can handle.

    Default values:

    * authorization_code
    * implicit
    * client_credentials
    * password

    More custom grant types are allowed.
    """

    @property
    def name(self) -> str:
        """Return the display name: ``client_name`` if set, else ``client_id``."""
        return self.client_name or self.client_id

    def get_client_id(self) -> str:
        """Return the client's unique identifier."""
        return self.client_id

    def get_default_redirect_uri(self) -> str | None:
        """Return the first configured redirect URL, or ``None`` if there are none."""
        if self.redirect_urls:
            return self.redirect_urls[0]
        return None

    def get_allowed_scope(self, scope: str) -> str | None:
        """
        Filter a requested scope string down to the scopes this client may use.

        :param scope: Scope string as understood by authlib's ``scope_to_list``.
        :return: Empty string when ``scope`` is falsy; otherwise the result of
                 ``list_to_scope`` applied to the requested scopes that are
                 also in ``allowed_scopes``, in their requested order.
        """
        if not scope:
            return ""
        # Requested scopes are matched against the enum member *names*.
        allowed = {s.name for s in self.allowed_scopes}
        scopes = scope_to_list(scope)
        return list_to_scope([s for s in scopes if s in allowed])

    def check_redirect_uri(self, redirect_uri: str) -> bool:
        """Return whether ``redirect_uri`` is one of the configured redirect URLs."""
        return redirect_uri in self.redirect_urls

    def check_client_secret(self, client_secret: str) -> bool:
        """
        Return whether ``client_secret`` matches this client's secret.

        String secrets are compared with ``hmac.compare_digest`` so that the
        time taken does not reveal how many leading characters matched.

        :param client_secret: Secret presented by the caller.
        :return: ``True`` if the secrets are equal, ``False`` otherwise.
        """
        expected = self.client_secret
        if isinstance(expected, str) and isinstance(client_secret, str):
            # compare_digest only accepts ASCII-only str, so compare the
            # UTF-8 encodings to support arbitrary secrets.
            return hmac.compare_digest(
                expected.encode("utf-8"),
                client_secret.encode("utf-8"),
            )
        # Non-string values (e.g. None) keep the plain equality semantics.
        return expected == client_secret

    def check_endpoint_auth_method(self, method: str, endpoint: str) -> bool:
        """
        Return whether ``method`` is acceptable for authenticating at ``endpoint``.

        Only the ``"token"`` endpoint is restricted (to
        ``token_endpoint_auth_method``); any other endpoint accepts any method.
        """
        if endpoint == "token":
            return self.token_endpoint_auth_method == method
        return True

    def check_response_type(self, response_type: str) -> bool:
        """Return whether ``response_type`` is listed in ``response_types``."""
        return response_type in self.response_types

    def check_grant_type(self, grant_type: str) -> bool:
        """Return whether ``grant_type`` is listed in ``grant_types``."""
        return grant_type in self.grant_types
class OAuth2Token(db.Model, OAuth2TokenMixin):
    """
    Database model stored in the ``oauth2_token`` table.

    Adds a primary key and a link to a user account on top of whatever
    columns ``OAuth2TokenMixin`` contributes.
    """

    __tablename__ = "oauth2_token"

    # Surrogate primary key.
    id = db.Column(db.Integer, primary_key=True)

    # Foreign key to ``useraccount.id``.
    user_id = db.Column(db.Integer, db.ForeignKey("useraccount.id"))
    # Relationship to the ``User`` model.
    user = db.relationship("User")
class OAuth2AuthorizationCode(db.Model, OAuth2AuthorizationCodeMixin):
    """
    Database model stored in the ``oauth2_auth_code`` table.

    Adds a primary key and a link to a user account on top of whatever
    columns ``OAuth2AuthorizationCodeMixin`` contributes.
    """

    __tablename__ = "oauth2_auth_code"

    # Surrogate primary key.
    id = db.Column(db.Integer, primary_key=True)

    # Foreign key to ``useraccount.id``.
    user_id = db.Column(db.Integer, db.ForeignKey("useraccount.id"))
    # Relationship to the ``User`` model.
    user = db.relationship("User")