"""Table: wrapper around a single table of a luminadb ``Database``."""
# pylint: disable=too-many-arguments,too-many-public-methods,R0801
from contextvars import ContextVar
from sqlite3 import Connection, Error, OperationalError
from typing import (
Any,
Generator,
Iterable,
Literal,
Optional,
overload,
TYPE_CHECKING
)
from luminadb.functions import ParsedFn, Function, count
from luminadb.subquery import SubQuery
from .utils import check_iter, check_one, Row, crunch
from ._debug import if_debug_print
from .column import BuilderColumn, Column
from .errors import TableRemovedError
from .query_builder import (
# extract_single_column,
# fetch_columns,
build_select,
build_insert,
build_delete,
build_update,
)
from .query_builder.typings import Condition
from .query_builder.table_creation import extract_single_column
# from .signature import op
from .typings import (
Data,
Orders,
Query,
# _MasterQuery,
OnlyColumn,
SquashedQueries,
JustAColumn,
)
if TYPE_CHECKING:
from .database import Database
# Module-level helpers shared by every Table instance.

# Default value for the `what` parameter of the ParsedFn overloads of
# Table.select()/Table.select_one(). It is the result of calling
# Function("__NULL__"); presumably a placeholder ParsedFn -- confirm in
# luminadb.functions.
_null = Function("__NULL__")()

# Stack of open `with table:` levels in the current context. Its length
# chooses between BEGIN TRANSACTION (empty) and SAVEPOINT (non-empty).
# Table's transaction helpers copy the list before changing it and store the
# copy with .set(), so the default list object itself is never mutated.
_tx_stack: ContextVar[list[bool]] = ContextVar(
    "_tx_stack",
    default=[],
)
class Table:  # pylint: disable=too-many-instance-attributes
    """Table. Make sure you remember how the table goes.

    Wraps one table of a ``Database`` and runs every statement on that
    database's sqlite3 connection.

    Write methods (insert/update/delete) commit immediately unless
    ``in_transaction`` is True; then they only mark the table dirty.
    Read methods commit pending dirty writes first unless ``force_dirty`` is
    set. ``with table: ...`` opens a transaction (or a savepoint when another
    ``with`` block is already open in the current context) that is committed
    on a clean exit and rolled back when the block raises.
    """

    def __init__(
        self,
        parent: "Database",  # type: ignore
        table: str,
        columns: Optional[Iterable[Column]] = None,  # type: ignore
    ) -> None:
        """Bind the table to ``parent``'s connection.

        Args:
            parent (Database): Owning database; must still be open.
            table (str): Table name, validated through ``check_one``.
            columns (Optional[Iterable[Column]], optional): Known column
                definitions. Defaults to None (columns unknown).

        Raises:
            ConnectionError: ``parent`` is already closed.
        """
        if parent.closed:
            raise ConnectionError("Connection to database is already closed.")
        self._parent_repr = repr(parent)
        self._db = parent
        self._sql: Connection = parent.sql
        # pylint: disable-next=protected-access
        self._sql_path = parent._path
        self._deleted = False
        self._force_dirty = False
        self._dirty = False
        self._auto = True
        self._table = check_one(table)
        self._prev_autocommit = None
        self._prev_auto = True
        # One (auto_commit, isolation_level) pair per open `with` block on this
        # instance, so nested `with table:` blocks each restore their own state.
        self._saved_states: list[tuple[bool, Optional[str]]] = []
        self._columns: Optional[list[Column]] = list(columns) if columns else None

    def __enter__(self):
        """Switch to manual commits and open a transaction or savepoint."""
        self._prev_auto = self._auto
        self._prev_autocommit = self._sql.isolation_level
        self._saved_states.append((self._prev_auto, self._prev_autocommit))
        self._auto = False
        # isolation_level None stops sqlite3 from issuing implicit BEGINs, so
        # the explicit BEGIN/SAVEPOINT below controls the transaction.
        self._sql.isolation_level = None
        try:
            self._begin_transaction()
        except BaseException:
            # __exit__ is not called when __enter__ fails: undo our changes.
            self._restore_saved_state()
            raise
        return self

    def __exit__(self, exc_type, _, __):
        """Commit on a clean exit, roll back otherwise, then restore state.

        Exceptions raised inside the block are never suppressed.
        """
        try:
            if exc_type is None:
                self._commit_transaction()
            else:
                self._rollback_transaction()
            self._dirty = False
        finally:
            # Restore even if COMMIT/ROLLBACK raised, so the table is not left
            # stuck in manual-commit mode.
            self._restore_saved_state()

    def _restore_saved_state(self):
        """Restore auto-commit and isolation level saved by the matching __enter__."""
        if self._saved_states:
            prev_auto, prev_isolation = self._saved_states.pop()
        else:
            # Defensive fallback: behave like single-level bookkeeping.
            prev_auto, prev_isolation = self._prev_auto, self._prev_autocommit
        self._sql.isolation_level = prev_isolation
        self._auto = prev_auto

    @property
    def deleted(self):
        """Is table deleted"""
        return self._deleted

    @property
    def name(self):
        """Table name"""
        return self._table

    @property
    def force_dirty(self):
        """Force dirty state, whether .selecting() on dirty/uncommitted data is allowed or not"""
        return self._force_dirty

    @force_dirty.setter
    def force_dirty(self, value: bool):
        """Force dirty state, whether .selecting() on dirty/uncommitted data is allowed or not"""
        # Non-bool values are deliberately ignored.
        if not isinstance(value, bool):
            return
        self._force_dirty = value

    @property
    def auto_commit(self):
        """Auto commit state of this instance"""
        return self._auto

    @auto_commit.setter
    def auto_commit(self, value: bool):
        # Non-bool values are deliberately ignored.
        if not isinstance(value, bool):
            return
        self._auto = value

    @property
    def in_transaction(self):
        """Returns True if the table is in an active transaction.

        True when auto-commit is off (e.g. inside ``with table:``) or the
        connection runs with ``isolation_level`` None. Write methods then
        defer their commit and mark the table dirty instead.
        """
        return not self._auto or self._sql.isolation_level is None

    def _finalize(self):
        """Finalization hook; currently does nothing."""

    def _delete_hook(self):
        """Mark the table as deleted if it can no longer be queried."""
        try:
            # One row is enough to prove the table still exists.
            self.select(limit=1)
        except (OperationalError, TableRemovedError):
            # _exec reports a missing table as TableRemovedError rather than
            # OperationalError, so both count as "deleted".
            self._deleted = True

    def _exec(
        self,
        query: str,
        data: dict[str, Any] | list[dict[str, Any]],
        which: Literal["execute", "executemany"] = "execute",
    ):
        """Execute a sql query on a fresh cursor and return that cursor.

        Args:
            query (str): SQL text with placeholders.
            data: Parameters for ``execute``, or a sequence of them for
                ``executemany``.
            which: Cursor method to call. Defaults to "execute".

        Raises:
            TableRemovedError: SQLite reported "no such table".
            sqlite3.Error: Any other database error, annotated with the
                query and its arguments via ``add_note``.
        """
        if_debug_print(query, '\n', data)
        cursor = self._sql.cursor()
        run = cursor.execute if which == "execute" else cursor.executemany
        try:
            run(query, data)
        except Error as exc:
            if str(exc).startswith("no such table:"):
                raise TableRemovedError(f"Table {self._table} doesn't exists anymore") from None
            exc.add_note(f"SQL query: {query}")
            exc.add_note(f"Arguments: {data}")
            exc.add_note(
                f"There's about {1 if isinstance(data, dict) else len(data)} value(s) inserted"
            )
            raise
        return cursor

    def _control(self):
        """Raise TableRemovedError if this table is marked as removed."""
        if self._deleted:
            raise TableRemovedError(f"{self._table} is already removed")

    def _query_control(self):
        """Commit pending dirty writes before a read, unless force_dirty is set."""
        if self._dirty and not self._force_dirty:
            self._sql.commit()
            self._dirty = False

    def _after_write(self):
        """Commit a write now, or mark the table dirty when inside a transaction."""
        if self.in_transaction:
            self._dirty = True
        else:
            self._sql.commit()

    @staticmethod
    def _single_column(what: Any) -> Optional[str]:
        """Return the column key when ``what`` names exactly one column.

        A string other than ``"*"`` or a 1-tuple selects a single column;
        anything else (``"*"``, wider tuples, ParsedFn) yields None.
        """
        if isinstance(what, str):
            return None if what == "*" else what
        if isinstance(what, tuple) and len(what) == 1:
            return what[0]
        return None

    def force_nodelete(self):
        """Force "undelete" table. Used if table was mistakenly assigned as
        deleted."""
        self._deleted = False

    def delete(
        self,
        where: Condition = None,
        limit: int = 0,
        order: Optional[Orders] = None,
    ):
        """Delete row or rows

        Args:
            where (Condition, optional): Condition to determine deletion
                See `Signature` class about conditional stuff. Defaults to None.
            limit (int, optional): Limit deletion by integer. Defaults to 0.
            order (Optional[Orders], optional): Order of deletion. Defaults to None.

        Returns:
            int: Rows affected
        """
        query, data = build_delete(self._table, where, limit, order)  # type: ignore
        self._control()
        cursor = self._exec(query, data)
        rcount = cursor.rowcount
        self._after_write()
        return rcount

    def delete_one(self, where: Condition = None, order: Optional[Orders] = None):
        """Delete a row

        Args:
            where (Condition, optional): Conditional to determine deletion.
                Defaults to None.
            order (Optional[Orders], optional): Order of deletion. Defaults to None.

        Returns:
            int: Rows affected
        """
        return self.delete(where, 1, order)

    def insert(self, data: Data):
        """Insert data to current table

        Args:
            data (Data): Data to insert. Make sure it's compatible with the table.

        Returns:
            int: Last rowid
        """
        query, _ = build_insert(self._table, data)  # type: ignore
        self._control()
        cursor = self._exec(query, data)
        rlastrowid = cursor.lastrowid
        self._after_write()
        return rlastrowid

    def insert_multiple(self, datas: Iterable[Data]):
        """Insert multiple values

        Args:
            datas (Iterable[Data]): Data to be inserted. The first item
                determines the INSERT statement used for all of them. An
                empty input inserts nothing.
        """
        self._control()
        # Materialize so generators work and the first row can be inspected.
        rows = list(datas)
        if not rows:
            return
        query, _ = build_insert(self._table, rows[0])  # type: ignore
        self._exec(query, rows, "executemany")
        self._after_write()

    def insert_many(self, datas: Iterable[Data]):
        """Alias to `insert_multiple`"""
        return self.insert_multiple(datas)

    def update(
        self,
        where: Condition | None = None,
        data: Data | None = None,
        limit: int = 0,
        order: Optional[Orders] = None,
    ):
        """Update rows of current table

        Args:
            data (Data): New data to update
            where (Condition, optional): Condition dictionary.
                See `Signature` about how condition works. Defaults to None.
            limit (int, optional): Limit updates. Defaults to 0.
            order (Optional[Orders], optional): Order of change. Defaults to None.

        Returns:
            int: Rows affected

        Raises:
            ValueError: ``data`` is None.
        """
        if data is None:
            raise ValueError("data parameter must not be None")
        query, data = build_update(
            self._table, data, where, limit, order
        )  # type: ignore
        self._control()
        cursor = self._exec(query, data)
        rcount = cursor.rowcount
        self._after_write()
        return rcount

    def update_one(
        self,
        where: Condition | None = None,
        data: Data | None = None,
        order: Orders | None = None,
    ) -> int:
        """Update 1 data only"""
        return self.update(where, data, 1, order)

    @overload
    def select(
        self,
        where: Condition = None,
        what: OnlyColumn = "*",
        limit: int = 0,
        offset: int = 0,
        order: Optional[Orders] = None,
        flatten: Literal[False] = False,
    ) -> list[Query]:
        ...

    @overload
    def select(
        self,
        where: Condition = None,
        what: OnlyColumn = "*",
        limit: int = 0,
        offset: int = 0,
        order: Optional[Orders] = None,
        flatten: Literal[True] = True,
    ) -> SquashedQueries:
        ...

    @overload
    def select(
        self,
        where: Condition = None,
        what: ParsedFn = _null,
        limit: int = 0,
        offset: int = 0,
        order: Optional[Orders] = None,
        flatten: Literal[False] = False,
    ) -> Any:
        ...

    @overload
    def select(
        self,
        where: Condition = None,
        what: JustAColumn = "_COLUMN",
        limit: int = 0,
        offset: int = 0,
        order: Optional[Orders] = None,
        flatten: Literal[False] = False,
    ) -> list[Any]:
        ...

    def select(
        self,  # pylint: disable=too-many-arguments
        where: Condition = None,
        what: OnlyColumn | ParsedFn | JustAColumn = "*",
        limit: int = 0,
        offset: int = 0,
        order: Optional[Orders] = None,
        flatten: bool = False,
    ):
        """Select data in current table. Bare .select() returns all data.

        Args:
            where (Condition, optional): Conditions to used. Defaults to None.
            what: (OnlyColumn, ParsedFn, optional): Select what you want.
                A single column name (string or 1-tuple) returns a list of that
                column's values; a ParsedFn returns the function's value from
                the first row. Defaults to "*".
            limit (int, optional): Limit of select. Defaults to 0.
            offset (int, optional): Offset. Defaults to 0
            order (Optional[Orders], optional): Selection order. Defaults to None.
            flatten (bool): Flatten returned data into dict of lists. Defaults to False.

        Returns:
            Queries: Selected data
        """
        self._control()
        self._query_control()
        query, data = build_select(
            self._table, where, what, limit, offset, order
        )  # type: ignore
        # Deliberately not wrapped in `with self._sql:`: that context manager
        # commits on exit (defeating force_dirty and committing an open
        # `with table:` transaction) and rolls back pending writes on error.
        rows = self._exec(query, data).fetchall()
        column = self._single_column(what)
        if column is not None:
            return [row[column] for row in rows]
        if flatten:
            return crunch(rows)
        if isinstance(what, ParsedFn):
            return rows[0][what.parse_sql()[0]]
        return rows

    @overload
    def paginate_select(
        self,
        where: Condition = None,
        what: OnlyColumn = "*",
        page: int = 0,
        length: int = 10,
        order: Optional[Orders] = None,
        flatten: Literal[False] = False,
    ) -> Generator[list[Query], None, None]:
        ...

    @overload
    def paginate_select(
        self,
        where: Condition = None,
        what: JustAColumn = "_COLUMN",
        page: int = 0,
        length: int = 10,
        order: Optional[Orders] = None,
        flatten: Literal[False] = False,
    ) -> Generator[list[Any], None, None]:  # type: ignore
        ...

    @overload
    def paginate_select(
        self,
        where: Condition = None,
        what: OnlyColumn = "*",
        page: int = 0,
        length: int = 10,
        order: Optional[Orders] = None,
        flatten: Literal[True] = True,
    ) -> Generator[SquashedQueries, None, None]:
        ...

    def paginate_select(
        self,
        where: Condition = None,
        what: OnlyColumn | JustAColumn = "*",
        page: int = 0,
        length: int = 10,
        order: Optional[Orders] = None,
        flatten: bool = False,
    ):
        """Paginate select

        Args:
            where (Condition, optional): Conditions to use. Defaults to None.
            what (OnlyColumn, optional): Select what you want. Defaults to "*".
            page (int): Zero-based index of the first page to yield; negative
                values are treated as 0.
            length (int, optional): Pagination length. Defaults to 10.
            order (Optional[Orders], optional): Order. Defaults to None.
            flatten (bool): Flatten returned data into dict of lists. Defaults to False.

        Yields:
            Generator[Queries, None, None]: Step-by-step paginated result.
        """
        if page < 0:
            page = 0
        # NOTE(review): this inverts the requested order ("asc"/None -> "desc",
        # anything else -> "asc"). Kept as-is; confirm it is intended.
        order = "desc" if order in ("asc", None) else "asc"  # type: ignore
        self._control()
        self._query_control()
        start = page * length
        # ! A `only` keyword as a string or tuple of 1 element will
        # ! actually be a problem if they left alone because the end result is a list
        # NOTE(review): unlike select(), single-column pages still contain
        # rows, not bare values -- confirm which behavior callers expect.
        just_a_column = self._single_column(what) is not None
        while True:
            query, data = build_select(
                self._table, where, what, length, start, order
            )  # type: ignore
            # No `with self._sql:` here; see the comment in select().
            fetched = self._exec(query, data).fetchmany(length)
            if not fetched:
                return
            # Decide "last page" from the row count BEFORE crunch(): a
            # flattened page is keyed by column, so its len() is not a row count.
            last_page = len(fetched) != length
            if flatten and not just_a_column:
                fetched = crunch(fetched)
            yield fetched
            if last_page:
                return
            start += length

    @overload
    def select_one(
        self,
        where: Condition = None,
        what: ParsedFn = _null,
        order: Optional[Orders] = None,
    ) -> Any:
        ...

    @overload
    def select_one(
        self,
        where: Condition = None,
        what: OnlyColumn = "*",
        order: Optional[Orders] = None,
    ) -> Query:
        ...

    @overload
    def select_one(
        self,
        where: Condition = None,
        what: JustAColumn = "_COLUMN",
        order: Optional[Orders] = None,
    ) -> Any:
        ...

    def select_one(
        self,
        where: Condition = None,
        what: OnlyColumn | JustAColumn | ParsedFn = "*",
        order: Optional[Orders] = None,
    ):
        """Select one data

        Args:
            where (Condition, optional): Condition to use. Defaults to None.
            what: (OnlyColumn, optional): Select what you want. A single column
                name (string or 1-tuple) returns just that value; a ParsedFn
                returns the function's value. Defaults to "*".
            order (Optional[Orders], optional): Order of selection. Defaults to None.

        Returns:
            Any: Selected data; an empty ``Row()`` when no row matched
            (except for ParsedFn selections).
        """
        self._control()
        self._query_control()
        query, data = build_select(
            self._table, where, what, 1, 0, order
        )  # type: ignore
        # No `with self._sql:` here; see the comment in select().
        returned = self._exec(query, data).fetchone()
        if isinstance(what, ParsedFn):
            return returned[what.parse_sql()[0]]
        if not returned:
            return Row()
        column = self._single_column(what)
        if column is not None:
            return returned[column]
        return returned

    def columns(self):
        """Table columns

        Raises:
            AttributeError: No column definitions are known for this table.
        """
        if self._columns is None:
            raise AttributeError("columns are undefined.")
        return tuple(self._columns)

    def add_column(self, column: Column | BuilderColumn):
        """Add column to table

        Raises:
            OperationalError: The column has a primary/unique constraint, is
                NOT NULL without a default, or combines a default with a
                foreign-key constraint.
        """
        column = column.to_column() if isinstance(column, BuilderColumn) else column
        if column.primary or column.unique:
            raise OperationalError(
                "New column cannot have primary or unique constraint"
            )
        if column.nullable is False and column.default is None:
            raise OperationalError(
                "New column cannot be not null while default value is "
                "set to null"
            )
        if column.default is not None and column.foreign:
            raise OperationalError(
                "New column must accept null default value if foreign "
                "constraint is enabled."
            )
        query = f"alter table {self._table} add column {extract_single_column(column)}"
        self._sql.execute(query)
        # Record the column only once SQLite has accepted it.
        if self._columns is not None:
            self._columns.append(column)

    def subquery(self, where: Condition, columns: OnlyColumn | str, limit: int = 0) -> SubQuery:
        """Push subquery to current .select() of other table"""
        return SubQuery(self, columns, where, limit)

    def rename_column(self, old_column: str, new_column: str):
        """Rename existing column to new column"""
        # Validate both identifiers before interpolating them into SQL.
        check_iter((old_column, new_column))
        query = f"alter table {self._table} rename column {old_column} to {new_column}"
        self._sql.execute(query)

    def commit(self):
        """Commit changes"""
        self._sql.commit()
        self._dirty = False

    def rollback(self):
        """Rollback"""
        self._sql.rollback()
        self._dirty = False

    def _begin_transaction(self):
        """Start a transaction, or a savepoint when one is already open.

        The depth is the length of the context-local ``_tx_stack``: 0 issues
        ``BEGIN TRANSACTION``; N > 0 issues ``SAVEPOINT sp_N``.
        """
        # Work on a copy: the list held by the ContextVar (including its
        # default) may be shared with other contexts and must not be mutated.
        stack = list(_tx_stack.get())
        depth = len(stack)
        if depth == 0:
            self._sql.execute("BEGIN TRANSACTION")
        else:
            savepoint_name = f"sp_{depth}"
            self._sql.execute(f"SAVEPOINT {savepoint_name}")
        stack.append(True)
        _tx_stack.set(stack)

    def _commit_transaction(self):
        """Commit the transaction, or release the innermost savepoint."""
        stack = list(_tx_stack.get())
        depth = len(stack)
        try:
            if depth == 1:
                self._sql.commit()
            elif depth > 1:
                savepoint_name = f"sp_{depth-1}"
                self._sql.execute(f"RELEASE SAVEPOINT {savepoint_name}")
        finally:
            # Keep the stack in step with the closing `with` block even when
            # the COMMIT/RELEASE fails.
            if stack:
                stack.pop()
            _tx_stack.set(stack)

    def _rollback_transaction(self):
        """Roll back the transaction, or undo and close the innermost savepoint."""
        stack = list(_tx_stack.get())
        depth = len(stack)
        try:
            if depth == 1:
                self._sql.rollback()
            elif depth > 1:
                savepoint_name = f"sp_{depth-1}"
                self._sql.execute(f"ROLLBACK TO SAVEPOINT {savepoint_name}")
                # ROLLBACK TO rewinds but leaves the savepoint open; release it
                # so SQLite's savepoint stack matches the closed `with` block.
                self._sql.execute(f"RELEASE SAVEPOINT {savepoint_name}")
        finally:
            if stack:
                stack.pop()
            _tx_stack.set(stack)

    def count(self):
        """Count how much objects/rows stored in this table"""
        # ? Might as well uses __len__? But it's quite expensive.
        # `count` below is the module-level SQL function, not this method.
        return self.select(what=count("*"))

    def __repr__(self) -> str:
        return f"<{type(self).__name__}({self._table}) -> {self._db!r}>"
# Public API of this module: only the Table class is exported.
__all__: list[str] = ["Table"]