Source code for luminadb.models.helpers

"""Model helpers"""

# pylint: disable=invalid-name,too-few-public-methods,abstract-method,protected-access

from typing import Any, Callable, Type, TypeAlias, TypeVar, overload
from enum import StrEnum

import luminadb
from .errors import ValidationError
from ..column import BuilderColumn, text, integer, blob, boolean, real

# Signature of the column builders imported from ``..column`` (``text``,
# ``integer``, ...): one string argument (presumably the column name) in,
# a BuilderColumn out.
TypeFunction: TypeAlias = Callable[[str], BuilderColumn]

# The model base class, referenced by its dotted name as a string alias.
BaseModel: TypeAlias = "luminadb.BaseModel"

# Type variables for the generic signatures in this module.
Model = TypeVar("Model", bound="luminadb.BaseModel")  # any model subclass
FuncT = TypeVar("FuncT", bound=Callable[..., bool])  # a bool-returning callable

# Python annotation type -> column-builder function producing the matching
# BuilderColumn.
TYPES: dict[Type[Any], TypeFunction] = {  # pylint: disable=possibly-used-before-assignment
    int: integer,
    str: text,
    bytes: blob,
    bool: boolean,
    float: real,
}

# Hook names that `hook` accepts straight from a function's own ``__name__``.
# Every model event (create / update / delete) has a "before" and an
# "after" variant.
VALID_HOOKS_NAME = (
    "before_create", "after_create",
    "before_update", "after_update",
    "before_delete", "after_delete",
)


[docs] class ConstraintEnum(StrEnum): """Constraints for update/delete""" RESTRICT = "restrict" SETNULL = "null" CASCADE = "cascade" NOACT = "no act" DEFAULT = "default"
# Module-level shortcuts for the ConstraintEnum members, so callers can write
# e.g. ``.on_delete(CASCADE)`` without the enum prefix.
RESTRICT, SETNULL, CASCADE, NOACT, DEFAULT = (
    ConstraintEnum.RESTRICT,
    ConstraintEnum.SETNULL,
    ConstraintEnum.CASCADE,
    ConstraintEnum.NOACT,
    ConstraintEnum.DEFAULT,
)
class Constraint:
    """Base constraint class for models.

    A constraint is bound to one column name. It applies itself to a
    BuilderColumn through `apply`, which subclasses must override.
    """

    def __init__(self, column: str) -> None:
        self._column = column

    @property
    def column(self):
        """Name of the column this constraint is attached to."""
        return self._column

    def apply(self, type_: BuilderColumn):
        """Apply this constraint to a column builder.

        Raises:
            NotImplementedError: always; concrete subclasses override this.
        """
        raise NotImplementedError
class Unique(Constraint):
    """Unique constraint."""

    def apply(self, type_: BuilderColumn):
        """Mark the builder column unique by calling its ``unique()`` method."""
        type_.unique()
class Foreign(Constraint):
    """Foreign constraint.

    ``target`` is either a model class or a string reference. A model class
    is resolved to ``"<table name>/<primary key>"``; string targets are
    presumably expected in that same format (TODO confirm) and are passed
    through unchanged.
    """

    def __init__(self, column: str, target: str | Type[Model]) -> None:
        super().__init__(column)
        self._target = target
        self.resolve()
        # The unresolved target (string or model class) as given by the caller.
        self._base = target
        self._on_delete = DEFAULT
        self._on_update = DEFAULT

    @property
    def target(self):
        """Target of the foreign constraint, after `resolve`."""
        return self._target

    def on_delete(self, constraint: ConstraintEnum | str) -> "Foreign":
        """Set the ON DELETE action and return ``self`` for chaining.

        Plain strings are converted with ``ConstraintEnum(...)``, so
        ``"cascade"`` works like ``CASCADE``. Unknown values raise ValueError
        here instead of failing later inside `apply`.
        """
        self._on_delete = ConstraintEnum(constraint)
        return self

    def on_update(self, constraint: ConstraintEnum | str) -> "Foreign":
        """Set the ON UPDATE action and return ``self`` for chaining.

        Strings are converted the same way as in `on_delete`.
        """
        self._on_update = ConstraintEnum(constraint)
        return self

    def resolve(self):
        """Resolve the target to ``"<table>/<primary>"`` when it is a model class.

        Non-model targets (such as strings) are left as they are, so
        calling this again after resolution does nothing.

        Raises:
            ValueError: if the model class has no primary key.
        """
        target = self._target
        # issubclass() raises TypeError for non-class arguments such as a
        # "table/column" string, so only test real classes.
        if not (isinstance(target, type) and issubclass(target, luminadb.BaseModel)):  # type: ignore
            return
        name = target.__table_name__
        primary = target._primary  # pylint: disable=protected-access
        if not primary:
            raise ValueError(f"{target.__name__} does not have primary key")
        self._target = f"{name}/{primary}"

    def apply(self, type_: BuilderColumn):
        """Add the foreign reference, plus any explicitly chosen actions, to the builder."""
        type_.foreign(self._target)  # type: ignore
        # DEFAULT means "nothing chosen": the builder is not called for it.
        if self._on_delete != DEFAULT:
            type_.on_delete(self._on_delete.value)
        if self._on_update != DEFAULT:
            type_.on_update(self._on_update.value)
class Primary(Constraint):
    """Primary constraint.

    If the optional ``auto`` flag is truthy, `apply` also calls
    ``auto_increment()`` on the builder column. It is meant for integer
    primary columns.
    """

    def __init__(self, column: str, auto: bool = False) -> None:
        super().__init__(column)
        # Normalised to a real bool so the `auto` property always returns one.
        self._auto = bool(auto)

    @property
    def auto(self) -> bool:
        """Whether auto-increment was requested."""
        return self._auto

    def apply(self, type_: BuilderColumn):
        """Mark the builder column as primary, and as auto-increment if requested."""
        type_.primary()
        if not self._auto:
            return
        type_.auto_increment()
class Validators:
    """Pair a predicate with the message reported when it fails."""

    def __init__(self, fn: Callable[[Any], bool], if_fail: str) -> None:
        self._callable = fn
        self._reason = if_fail

    def validate(self, instance: BaseModel):
        """Run the predicate against *instance*.

        Returns:
            True when the predicate result is truthy.

        Raises:
            ValidationError: carries the configured reason, with
                ``str(instance)`` attached as an exception note.
        """
        if self._callable(instance):
            return True
        error = ValidationError(self._reason)
        error.add_note(str(instance))
        raise error
@overload
def hook(fn_or_name: Callable[[Model], None]) -> "staticmethod[[Callable[[Model], None]], None]":
    ...


@overload
def hook(fn_or_name: str | None = None):
    ...


def hook(fn_or_name=None):
    """Register a hook.

    Supported forms:

    * ``@hook``: bare. The function's own name must be in VALID_HOOKS_NAME.
    * ``@hook("name")``: explicit name, used when the function's own name
      is not in VALID_HOOKS_NAME. A valid function name takes precedence
      over the explicit one.

    Returns the function wrapped in ``staticmethod``. Its ``_hooks_info``
    attribute is a 1-tuple holding the resolved hook name.

    Raises:
        ValueError: if no hook name can be determined. This happens for a
            bare ``@hook`` / ``@hook()`` on a function whose name is not a
            known hook, or for an empty or non-string explicit name.
    """

    def decorator(func):
        fn_name = func.__name__
        if fn_name in VALID_HOOKS_NAME:
            final_name = fn_name
        elif isinstance(fn_or_name, str) and fn_or_name:
            final_name = fn_or_name
        else:
            # With a bare ``@hook``, fn_or_name is the function object itself.
            # It must not be stored as the hook name.
            final_name = None
        if final_name is None:
            raise ValueError("Hooks name is not valid. Provide with @hook(name)")
        fn = staticmethod(func)
        fn._hooks_info = (final_name,)  # type: ignore
        return fn

    return decorator(fn_or_name) if callable(fn_or_name) else decorator
@overload def validate(fn_or_column: FuncT) -> "staticmethod[[FuncT], bool]": pass @overload def validate(fn_or_column: str, reason: str | None = None): # type: ignore pass
[docs] def validate(fn_or_column, reason=None): """Register a validator""" def decorator(func: Callable): fn = staticmethod(func) name = func.__name__ inferred_col = ( name[len("validate_") :] if name.startswith("validate_") else None ) col = fn_or_column or inferred_col if callable(col): col = inferred_col if col is None: raise ValueError("Validator must have a column name.") msg = reason or func.__doc__ or f"Validation failed for '{col}'" fn._validators_info = (col, msg) # type: ignore return fn return decorator(fn_or_column) if callable(fn_or_column) else decorator
def initiate_hook(cls: Type[BaseModel], member: Callable):
    """Register *member* on *cls* as a hook if `hook` decorated it.

    Members without a ``_hooks_info`` attribute are ignored. Otherwise the
    first entry of that tuple is passed as the hook name to
    ``cls._register("hook", ...)``, and the result is applied to *member*.
    """
    try:
        info = member._hooks_info
    except AttributeError:
        return
    cls._register("hook", info[0])(member)
def initiate_validators(cls: Type[BaseModel], member: Callable):
    """Register *member* on *cls* as a validator if `validate` decorated it.

    Members without a ``_validators_info`` attribute are ignored. Otherwise
    the stored column and message are passed to
    ``cls._register("validator", column, message)``, and the result is
    applied to *member*.
    """
    try:
        info = member._validators_info
    except AttributeError:
        return
    column, message = info[0], info[1]
    cls._register("validator", column, message)(member)