Source code for luminadb.models

"""Models"""

# pylint: disable=unused-import,unused-argument,cyclic-import,protected-access

from contextlib import contextmanager
from typing import Any, Callable, Self, Type, TypeVar
from dataclasses import asdict, dataclass, fields, is_dataclass, MISSING

from luminadb.models.type_checkers import typecheck

from .helpers import (
    Constraint,
    Unique,
    Primary,
    Foreign,
    TYPES,
    Validators,
    CASCADE,
    DEFAULT,
    NOACT,
    RESTRICT,
    SETNULL,
)

from .helpers import (
    VALID_HOOKS_NAME,
    hook,
    validate,
    initiate_hook,
    initiate_validators,
)
from .query_builder import QueryBuilder
from .errors import ConstraintError
from ..errors import DatabaseExistsError
from ..database import Database, Table
from ..column import text, BuilderColumn
from ..operators import in_

NULL = object()
T = TypeVar("T", bound="BaseModel")

## Model functions


@staticmethod
def noop_autoid():
    """Default no-op function for BaseModel __auto_id__"""
    return None


[docs] class BaseModel: # pylint: disable=too-few-public-methods,too-many-public-methods """Base class for all Models using Model API""" __table_name__ = "" __schema__: tuple[Constraint, ...] = () __validators__: dict[str, list[Validators]] = {} __hooks__: "dict[str, list[Callable[[Self], None] | str]]" = {} __hidden__: tuple[str, ...] = () __auto_id__: Callable[[], Any] = noop_autoid _tbl: Table _primary: str | None
[docs] @classmethod def create_table(cls, db: Database): """Create table according to annotations and schema from `__schema__`""" if not is_dataclass(cls): raise TypeError(f"{cls.__name__} must be a dataclass") columns: list[BuilderColumn] = [] constraints: dict[str, list[Constraint]] = { col: [] for col in cls.__annotations__ } # pylint: disable=no-member cls.__table_name__ = cls.__table_name__ or cls.__name__.lower() _primary = None _primary_auto = False # Extract constraints from __schema__ for constraint in cls.__schema__: # pylint: disable=no-member target_col = constraint.column constraints[target_col].append(constraint) if isinstance(constraint, Primary) and _primary: raise ConstraintError("Cannot apply when a column has 2 primary keys") if isinstance(constraint, Primary): _primary = target_col # remember if primary has auto-increment enabled try: _primary_auto = constraint.auto # type: ignore except TypeError: _primary_auto = False continue # Process fields & constraints for field_def in fields(cls): # Fetch fields using dataclass reflection field_name = field_def.name field_type = field_def.type col = TYPES.get(field_type, text)(field_name) # type: ignore # Check if there's a default value if field_def.default is not MISSING: default_value = field_def.default col = col.default(default_value) # Apply constraints dynamically for constraint in constraints.get(field_name, []): constraint.apply(col) columns.append(col) cls._primary = _primary # Remember whether primary requested auto-increment so create() # can handle instantiation and post-insert assignment accordingly. cls._primary_auto = bool(_primary_auto) try: cls._tbl = db.create_table(cls.__table_name__, columns) except DatabaseExistsError: cls._tbl = db.table(cls.__table_name__.lower())
@classmethod def _execute_hooks(cls, name: str, instance: Self): for hook_fn in cls.__hooks__.get(name, ()): if isinstance(hook_fn, str): getattr(cls, hook_fn)(instance) else: hook_fn(instance) @classmethod def _execute_validators(cls, name: str, instance: Self): for validator_fn in cls.__validators__.get(name, ()): validator_fn.validate(instance) @classmethod def _register(cls, type_: str = "hook", name: str = "", if_fail: str = ""): """Register a hook/validator under a name""" if type_ not in ("hook", "validator"): raise ValueError("Which do you want?") def function(func): if name == "": raise ValueError(f"{type_.title()} name needs to be declared.") if type_ == "hook": if name not in VALID_HOOKS_NAME: raise ValueError("Name of a hook doesn't match with expected value") cls.__hooks__.setdefault(name, []) if cls.__hooks__[name]: cls.__hooks__[name] = [func] else: cls.__hooks__[name].append(func) return func if not is_dataclass(cls): raise TypeError("Dataclass is required for this class") fields_ = tuple((field.name for field in fields(cls))) if name not in fields_: raise ValueError( f"Expected validator to has name as column field. Got {name!r}" ) fail = if_fail or f"{name} fails certain validator" cls.__validators__.setdefault(name, []) validator_entry = Validators(func, fail) if cls.__validators__[name]: cls.__validators__[name] = [validator_entry] else: cls.__validators__[name].append(validator_entry) return func return function
[docs] @classmethod def create(cls, **kwargs): """Create data based on kwargs""" primary: str | None = cls._primary or kwargs.get("id", None) id_present = bool(kwargs.get(cls._primary or "id", None)) # If primary is auto-incremented by DB, do not call model-level # __auto_id__. Instead, instantiate with a temporary None for the # primary so dataclass __init__ accepts it, then perform the insert # and set the real id on the instance from the DB's lastrowid. use_db_autoinc = bool(primary and getattr(cls, "_primary_auto", False) and not id_present) if use_db_autoinc: kwargs[primary] = None # type: ignore # allow dataclass instantiation else: if primary and cls.__auto_id__ and not id_present: # type: ignore kwargs[primary] = cls.__auto_id__() # type: ignore instance = cls(**kwargs) cls._execute_hooks("before_create", instance) for key in kwargs: # Below is a naive changes # When validating input, it skips primary key if it # ... detects SQL is going to handle it internally. # If it's "enough," no changes should be made. # But I don't believe it. if key == primary and use_db_autoinc: continue cls._execute_validators(key, instance) lastrow = cls._tbl.insert(kwargs) if use_db_autoinc: # Assign generated id back to instance attribute try: setattr(instance, primary, lastrow) # type: ignore except (TypeError, AttributeError): pass cls._execute_hooks("after_create", instance) return instance
[docs] def update(self, __primary: str | object = NULL, /, **kwargs): """Update current data""" # pylint: disable=protected-access primary = self._primary or __primary if primary is NULL: raise ValueError( "The table does not have any primary key, cannot update due to undefined selection" ) self._execute_hooks("before_update", self) for key in kwargs: self._execute_validators(key, self) self._tbl.update({primary: getattr(self, primary)}, kwargs) # type: ignore for key, value in kwargs.items(): setattr(self, key, value) self._execute_hooks("after_update", self) return self
[docs] def delete(self, __primary=NULL, /): """Delete current data""" # pylint: disable=protected-access primary = self._primary or __primary if primary is NULL: raise ValueError( "The table does not have any primary key, cannot delete due to undefined selection" ) self._execute_hooks("before_delete", self) self._tbl.delete_one({primary: getattr(self, primary)}) # type: ignore self._execute_hooks("after_delete", self)
[docs] @classmethod def bulk_create(cls, records: list[dict]): """Insert multiple records at once.""" cls._tbl.insert_many(records) return [cls(**record) for record in records] # Return list of instances
[docs] @classmethod def bulk_update(cls, records: list[dict], key: str | object = NULL): """Update multiple records using a primary key or provided key.""" key_ = cls._primary or key if key is NULL: raise ValueError( "The table does not have any primary key, or key parameter is not provided" ) for record in records: if key_ not in record: raise ValueError(f"Missing primary key '{key_}' in record: {record}") cls._tbl.update({key_: record[key_]}, record) # type: ignore
[docs] @classmethod def bulk_delete(cls, keys: list[Any], key: str): """Delete multiple records using a primary key.""" cls._tbl.delete({key: in_(keys)})
[docs] @classmethod def first_or_fail(cls, **kwargs): """Return the first matching record or raise an error if no match is found.""" result = cls.where(**kwargs).limit(1).throw().fetch_one() return result
[docs] @classmethod def first(cls, **kwargs): """Return the first matching record or None if no match is found.""" result = cls.where(**kwargs).limit(1).fetch_one() return result
[docs] @classmethod def one(cls, **kwargs): """Return exactly one record. Raises error if multiple results exist.""" results = cls.where(**kwargs).fetch() if len(results) > 1: raise ValueError(f"Expected one record, but found {len(results)}") return results[0] if results else None
[docs] @classmethod def find(cls, amount: int): """Return models relative to the amount""" results = cls.query().limit(amount).fetch() return results
[docs] @classmethod def find_or_fail(cls, amount: int): """Return models relative to the amount and when returned is equal to 0, throws an error""" results = cls.query().limit(amount).throw().fetch() return results
[docs] @classmethod def all(cls): """Return all values from the table""" return cls.where().fetch()
[docs] @classmethod def count(cls, **kwargs) -> int: """Return count of matching records.""" return cls.where(**kwargs).count()
[docs] @classmethod def exists(cls, **kwargs) -> bool: """Check if any record matches the query.""" return cls.where(**kwargs).limit(1).fetch_one() is not None
[docs] @classmethod @contextmanager def atomic(cls): """Perform operations within a transaction.""" with cls._tbl: # Assuming `transaction()` exists yield
[docs] @classmethod def upsert(cls, key: str, **kwargs): """Insert or update a record based on primary key.""" existing = cls.where(**{key: kwargs[key]}).fetch_one() if existing: return existing.update(**kwargs) return cls.create(**kwargs)
[docs] def to_dict(self): """Convert model instance to dictionary.""" if is_dataclass(self): # always true, though, just in case instance = asdict(self) return {k: v for k, v in instance.items() if k not in self.__hidden__} return {}
[docs] def to_safe_instance(self) -> Self: """Wrap instance that complies with __hidden__.""" if is_dataclass(self): dict_inst = asdict(self).items() instance = { k: (v if not k in self.__hidden__ else None) for k, v in dict_inst } return type(self)(**instance) raise TypeError("This class must be a dataclass")
[docs] def raw(self, query: str, params: list[Any] | tuple[Any, ...] | dict[str, Any]): """Raw SQL query""" return self._tbl._sql.execute(query, params) # pylint: disable=protected-access
[docs] def has_many(self, related: "Type[T]", foreign_key: str | object = NULL): """Ensure related_model has a Foreign key pointing to self""" foreign_key = None if not self._primary: raise ConstraintError( f"The table {self.__table_name__} does not have any primary key " "required for has_many()" ) # Scan __schema__ of related model to find a Foreign key linking back for constraint in related.__schema__: if isinstance(constraint, Foreign): table_ref, _ = constraint.target.split("/") # type: ignore if table_ref == self.__class__.__name__.lower(): foreign_key = constraint.column break if not foreign_key: raise ValueError( f"{related.__name__} does not have a Foreign key pointing " f"to {self.__class__.__name__}" ) # Perform the actual query return related.where(**{foreign_key: getattr(self, self._primary)}).fetch()
[docs] def belongs_to(self, related_model: "Type[T]"): """Retrieve the related model that this instance belongs to.""" # Find the Foreign() constraint that references `related_model` for constraint in self.__schema__: if isinstance(constraint, Foreign) and constraint.target.startswith( # type: ignore related_model.__table_name__ + "/" ): foreign_key = constraint.column referenced_column = constraint.target.split("/")[1] # type: ignore return related_model.where( **{referenced_column: getattr(self, foreign_key)} ).fetch_one() raise ValueError( f"{self.__class__.__name__} does not belong to {related_model.__name__}" )
[docs] def has_one(self, related_model: "Type[T]"): """Retrieve the related model where this instance is referenced.""" # Find the Foreign() constraint in `related_model` that references this model for constraint in related_model.__schema__: if ( isinstance(constraint, Foreign) and constraint.target == f"{self.__table_name__}/{self._primary}" ): foreign_key = constraint.column return related_model.where(**{foreign_key: self._primary}).fetch_one() raise ValueError( f"{related_model.__name__} does not have a one-to-one relationship " f"with {self.__class__.__name__}" )
[docs] @classmethod def get_table(cls): """Return table instance""" return cls._tbl
[docs] @classmethod def where(cls, **kwargs): """Basic select operation""" return QueryBuilder(cls).where(**kwargs)
[docs] @classmethod def query(cls): """Return Query Builder related to this model""" return QueryBuilder(cls)
[docs] def model(db: Database, type_checking: bool = False): """Initiate Model API compatible classes. Requires target to be a dataclass, the app automatically injects dataclass if this isn't a dataclass. Use `type_checking` if you want automatic runtime type checker.""" def outer(cls: Type[T]) -> Type[T]: if not issubclass(cls, BaseModel): raise TypeError(f"Model {cls.__name__} is not subclass of BaseModel.") if not is_dataclass(cls): cls = dataclass(cls) cls.create_table(db) if type_checking: for fn in typecheck(cls): initiate_validators(cls, fn) for member in cls.__dict__.values(): initiate_hook(cls, member) initiate_validators(cls, member) return cls return outer
__all__ = [ "model", "BaseModel", "Unique", "Primary", "Foreign", "QueryBuilder", "CASCADE", "DEFAULT", "NOACT", "SETNULL", "RESTRICT", "validate", "hook", ]