Source code for kadi.lib.oauth.models

# Copyright 2020 Karlsruhe Institute of Technology
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import secrets
from time import time

from authlib.integrations.sqla_oauth2 import OAuth2AuthorizationCodeMixin
from authlib.integrations.sqla_oauth2 import OAuth2ClientMixin
from authlib.integrations.sqla_oauth2 import OAuth2TokenMixin
from flask import current_app
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm.attributes import flag_modified
from sqlalchemy_utils.types.encrypted.encrypted_type import StringEncryptedType

import kadi.lib.constants as const
from kadi.ext.db import db
from kadi.ext.db import KadiAesEngine
from kadi.lib.api.models import AccessTokenMixin
from kadi.lib.db import SimpleTimestampMixin
from kadi.lib.db import unique_constraint
from kadi.lib.db import UTCDateTime
from kadi.lib.security import hash_value
from kadi.lib.security import random_bytes
from kadi.lib.utils import SimpleReprMixin
from kadi.lib.utils import utcnow


[docs]class OAuth2ClientToken(SimpleReprMixin, db.Model): """Model to represent OAuth2 client tokens. Note that this model uses encrypted fields and can potentially raise a :class:`.KadiDecryptionKeyError` when a value cannot be decrypted. """
[docs] class Meta: """Container to store meta class attributes.""" representation = ["id", "user_id", "name"] """See :class:`.SimpleReprMixin`."""
__tablename__ = "oauth2_client_token" __table_args__ = (unique_constraint(__tablename__, "user_id", "name"),) id = db.Column(db.Integer, primary_key=True) """The ID of the client token, auto incremented.""" user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) """The ID of the user the client token belongs to.""" name = db.Column(db.Text, nullable=False) """The name of the client token. Currently always refers to the name of a specific OAuth2 provider. """ access_token = db.Column( StringEncryptedType( type_in=db.Text, engine=KadiAesEngine, key=KadiAesEngine.get_secret_key ), nullable=False, ) """The actual access token value, stored encrypted.""" refresh_token = db.Column( StringEncryptedType( type_in=db.Text, engine=KadiAesEngine, key=KadiAesEngine.get_secret_key ), nullable=True, ) """The optional refresh token value, stored encrypted.""" expires_at = db.Column(UTCDateTime, nullable=True) """The optional expiration date and time of the access token.""" user = db.relationship("User", back_populates="oauth2_client_tokens") @property def is_expired(self): """Check if the access token is expired.""" if self.expires_at is not None: return self.expires_at < utcnow() return False
[docs] @classmethod def create(cls, *, user, name, access_token, refresh_token=None, expires_at=None): """Create a new OAuth2 client token and add it to the database session. :param user: The user the client token should belong to. :param name: The name of the client token. :param access_token: The actual access token value. :param refresh_token: (optional) The refresh token value. :param expires_at: (optional) The expiration date and time of the access token. :return: The new :class:`.OAuth2ClientToken` object. """ oauth2_client_token = cls( user=user, name=name, access_token=access_token, refresh_token=refresh_token, expires_at=expires_at, ) db.session.add(oauth2_client_token) return oauth2_client_token
[docs] def to_authlib_token(self): """Convert the client token to a format usable by an Authlib client. :return: A dictionary representation of the client token. """ expires_at = None if self.expires_at is not None: expires_at = int(self.expires_at.timestamp()) return { "access_token": self.access_token, "refresh_token": self.refresh_token, "expires_at": expires_at, "token_type": const.OAUTH_TOKEN_TYPE, }
[docs]class OAuth2ServerClient( SimpleReprMixin, SimpleTimestampMixin, OAuth2ClientMixin, db.Model ): """Model to represent registered OAuth2 clients/applications."""
[docs] class Meta: """Container to store meta class attributes.""" representation = ["id", "user_id", "client_id", "client_name"] """See :class:`.SimpleReprMixin`."""
__tablename__ = "oauth2_server_client" id = db.Column(db.Integer, primary_key=True) """The ID of the client, auto incremented.""" user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) """The ID of the user who created the client.""" client_id = db.Column(db.Text, unique=True, nullable=False) """The OAuth2 client ID.""" client_secret = db.Column(db.Text, nullable=False) """The OAuth2 client secret. Note that only a hash of the actual client secret value is stored. """ _client_metadata = db.Column("client_metadata", JSONB, nullable=False) """Additional metadata of the client.""" user = db.relationship("User", back_populates="oauth2_server_clients") oauth2_server_tokens = db.relationship( "OAuth2ServerToken", lazy="dynamic", back_populates="client", cascade="all, delete-orphan", ) oauth2_server_auth_codes = db.relationship( "OAuth2ServerAuthCode", lazy="dynamic", back_populates="client", cascade="all, delete-orphan", ) @property def client_metadata(self): """Get the additional metadata of this client.""" return self._client_metadata @property def scope(self): """Get the scope of this client.""" return self.client_metadata.get("scope", None)
[docs] @staticmethod def new_client_secret(): """Create a new random client secret. :return: The generated client secret. """ return random_bytes(num_bytes=32)
[docs] @staticmethod def hash_client_secret(client_secret): """Create a secure hash of a client secret. :param client_secret: The client secret to hash. :return: The calculated hash as a hexadecimal value. """ return hash_value(client_secret)
[docs] @classmethod def create( cls, *, user, client_name, client_uri, redirect_uris, scope=None, client_secret=None, ): """Create a new OAuth2 client and add it to the database session. :param user: The user the client should belong to. :param client_name: The name of the client. Will be stored as part of the client metadata. :param client_uri: The website of the client. Will be stored as part of the client metadata. :param redirect_uris: A list of allowed redirect URIs. Will be stored as part of the client metadata. :param scope: (optional) The scope of the client as a single string defining a list of space-delimited scope values. Will be stored as part of the client metadata. :param client_secret: (optional) The client secret, which will be hashed before persisting. Defaults to a client secret created by :meth:`new_client_secret`. :return: The new :class:`.OAuth2ServerClient` object. """ client_secret = ( client_secret if client_secret is not None else cls.new_client_secret() ) oauth2_client = cls( user=user, client_id=random_bytes(num_bytes=16), client_secret=cls.hash_client_secret(client_secret), client_id_issued_at=int(time()), ) db.session.add(oauth2_client) client_metadata = { "client_name": client_name, "client_uri": client_uri, "redirect_uris": redirect_uris, "scope": scope, "token_endpoint_auth_method": const.OAUTH_TOKEN_ENDPOINT_AUTH_METHOD, "grant_types": [ const.OAUTH_GRANT_AUTH_CODE, const.OAUTH_GRANT_REFRESH_TOKEN, ], "response_types": const.OAUTH_RESPONSE_TYPE, } oauth2_client.set_client_metadata(client_metadata) return oauth2_client
[docs] def set_client_metadata(self, value): """Set the additional metadata of this client. :param value: The metadata as a JSON serializable dictionary. """ self._client_metadata = value
[docs] def update_client_metadata(self, **kwargs): r"""Update the additional metadata of this client. :param \**kwargs: JSON serializable keyword arguments to update the metadata with. """ for key, value in kwargs.items(): if key not in self._client_metadata or self._client_metadata[key] != value: self._client_metadata.update(kwargs) # Make sure to alert SQLAlchemy of the change, as it won't be persisted # otherwise, but only if something actually changed. flag_modified(self, "_client_metadata") return
[docs] def check_client_secret(self, client_secret): """Compare the client secret of this client with a given client secret. :param client_secret: The client secret to compare with, which will be hashed before comparing. """ return secrets.compare_digest( self.client_secret, self.hash_client_secret(client_secret) )
[docs]class OAuth2ServerToken(SimpleReprMixin, AccessTokenMixin, OAuth2TokenMixin, db.Model): """Model to represent OAuth2 server tokens."""
[docs] class Meta: """Container to store meta class attributes.""" representation = ["id", "user_id", "client_id"] """See :class:`.SimpleReprMixin`."""
__tablename__ = "oauth2_server_token" id = db.Column(db.Integer, primary_key=True) """The ID of the server token, auto incremented.""" client_id = db.Column( db.Text, db.ForeignKey("oauth2_server_client.client_id"), nullable=False ) """The client ID of the client the server token belongs to.""" access_token = db.Column(db.Text, index=True, nullable=False) """The actual access token value. Note that only a hash of the actual access token value is stored. """ refresh_token = db.Column(db.Text, index=True, nullable=False) """The actual refresh token value. Note that only a hash of the actual refresh token value is stored. """ client = db.relationship( "OAuth2ServerClient", back_populates="oauth2_server_tokens" ) @property def is_expired(self): """Check if the access token is expired.""" # We simply delegate to the mixin but use a property to stay consistent with the # "AccessTokenMixin". return OAuth2TokenMixin.is_expired(self)
[docs] @staticmethod def new_access_token(include_prefix=True): """Create a new random access token value. :param include_prefix: (optional) Whether to include a prefix before the actual access token value to distinguish it with other types of access tokens. :return: The generated access token value. """ from .utils import new_oauth2_access_token return new_oauth2_access_token(include_prefix=include_prefix)
[docs] @staticmethod def new_refresh_token(): """Create a new random refresh token value. :return: The generated refresh token value. """ from .utils import new_oauth2_refresh_token return new_oauth2_refresh_token()
[docs] @staticmethod def hash_token(token): """Create a secure hash of an access or refresh token value. :param token: The token value to hash. :return: The calculated hash as a hexadecimal value. """ prefix = const.ACCESS_TOKEN_PREFIX_OAUTH # Exclude the prefix before hashing the token, if applicable. if token.startswith(prefix): token = token[len(prefix) :] return hash_value(token)
[docs] @classmethod def get_by_access_token(cls, token): """Get a server token using an access token value. :param token: The access token value to search for. :return: The server token or ``None``. """ token_hash = cls.hash_token(token) return cls.query.filter_by(access_token=token_hash).first()
[docs] @classmethod def get_by_refresh_token(cls, token): """Get a server token using a refresh token value. :param token: The refresh token value to search for. :return: The server token or ``None``. """ token_hash = cls.hash_token(token) return cls.query.filter_by(refresh_token=token_hash).first()
[docs] @classmethod def create( cls, *, user, client, expires_in, access_token=None, refresh_token=None, scope=None, ): """Create a new OAuth2 server token and add it to the database session. :param user: The user the server token should belong to. :param client: The client the server token should belong to. :param expires_in: The expiration time of the access token in seconds. :param access_token: (optional) The actual access token value, which will be hashed before persisting. Defaults to an access token value created by :meth:`new_access_token`. :param refresh_token: (optional) The actual refresh token value, which will be hashed before persisting. Defaults to a refresh token value created by :meth:`new_refresh_token`. :param scope: (optional) The scope of the server token. :return: The new :class:`.OAuth2ServerToken` object. """ access_token = ( access_token if access_token is not None else cls.new_access_token() ) refresh_token = ( refresh_token if refresh_token is not None else cls.new_refresh_token() ) oauth2_server_token = cls( user=user, client=client, expires_in=expires_in, access_token=cls.hash_token(access_token), refresh_token=cls.hash_token(refresh_token), scope=scope, token_type=const.OAUTH_TOKEN_TYPE, ) db.session.add(oauth2_server_token) return oauth2_server_token
[docs]class OAuth2ServerAuthCode(SimpleReprMixin, OAuth2AuthorizationCodeMixin, db.Model): """Model to represent OAuth2 authorization codes. Required for the implementation of the OAuth2 authorization code grant. """
[docs] class Meta: """Container to store meta class attributes.""" representation = ["id", "user_id", "client_id"] """See :class:`.SimpleReprMixin`."""
__tablename__ = "oauth2_server_auth_code" id = db.Column(db.Integer, primary_key=True) """The ID of the authorization code, auto incremented.""" user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False) """The ID of the user the authorization code belongs to.""" client_id = db.Column( db.Text, db.ForeignKey("oauth2_server_client.client_id"), nullable=False ) """The client ID of the client the authorization code belongs to.""" scope = db.Column(db.Text, nullable=True) """The optional scope of the authorization code.""" user = db.relationship("User", back_populates="oauth2_server_auth_codes") client = db.relationship( "OAuth2ServerClient", back_populates="oauth2_server_auth_codes" )
[docs] @classmethod def create( cls, *, user, client, code, redirect_uri, scope=None, code_challenge=None, code_challenge_method=None, ): """Create a new OAuth2 authorization code and add it to the database session. :param user: The user the authorization code should belong to. :param client: The client the authorization code should belong to. :param code: The actual authorization code value. :param redirect_uri: The allowed redirect URI of the authorization code. :param scope: (optional) The scope of the authorization code. :param code_challenge: (optional) The code challenge of the authorization code used for PKCE. :param code_challenge_method: (optional) The code challenge method of the authorization code used for PKCE. :return: The new :class:`.OAuth2ServerAuthCode` object. """ oauth2_auth_code = cls( user=user, client=client, code=code, redirect_uri=redirect_uri, scope=scope, code_challenge=code_challenge, code_challenge_method=code_challenge_method, response_type=const.OAUTH_RESPONSE_TYPE, ) db.session.add(oauth2_auth_code) return oauth2_auth_code
[docs] def is_expired(self): """Check if the authorization code is expired. :return: ``True`` if the authorization code is expired, ``False`` otherwise. """ expiration_time = ( self.auth_time + current_app.config["OAUTH_AUTH_CODE_EXPIRES_IN"] ) return expiration_time < int(time())