Source code for pilotscope.DBController.BaseDBController

import threading
from abc import ABC, abstractmethod

from sqlalchemy import create_engine, String, Integer, Float, MetaData, Table, inspect, select, func, Column
from sqlalchemy_utils import database_exists, create_database

from pilotscope.Common.Index import Index
from pilotscope.Exception.Exception import DatabaseDeepControlException
from pilotscope.PilotConfig import PilotConfig
from pilotscope.PilotEnum import DatabaseEnum


class BaseDBController(ABC):
    """
    Abstract base class of the database controllers.

    It owns a SQLAlchemy engine, a ``MetaData`` registry and one connection per thread
    (stored in ``threading.local``). Subclasses provide the connection string and the
    database-specific operations declared abstract below.
    """

    def __init__(self, config: PilotConfig, echo=True):
        """
        :param config: The config of PilotScope including the config of database.
        :param echo: if true, the more detailed information will be printed when executing the sql statement.
        """
        self.config = config
        self.echo = echo
        # Every thread keeps its own connection in `connection_thread.conn`.
        self.connection_thread = threading.local()
        self._db_init()

    def _db_init(self):
        """
        Initialize the database connection and engine.
        """
        self.engine = self._create_engine()
        self.metadata = MetaData()
        self._connect_if_loss()

    def _create_engine(self):
        """
        Create the database engine. The database is created first if it does not exist yet.

        :return: The created database engine.
        """
        conn_str = self._create_conn_str()
        if not database_exists(conn_str):
            create_database(conn_str, encoding="utf8", template="template0")

        # NOTE(review): presumably `sql_execution_timeout` is given in seconds and the
        # statement_timeout option expects milliseconds -- confirm against PilotConfig.
        statement_timeout = self.config.sql_execution_timeout * 1000
        return create_engine(conn_str, echo=self.echo, pool_size=10, pool_recycle=3600,
                             connect_args={"options": "-c statement_timeout={}".format(statement_timeout)},
                             client_encoding='utf8', isolation_level="AUTOCOMMIT")

    def _get_connection(self):
        """
        Get the connection of DBController.

        The attribute is read directly from the thread-local storage, so an
        ``AttributeError`` is raised if the calling thread has never connected.

        :return: the connection object of sqlalchemy in thread-local data
        :rtype: Connection of sqlalchemy
        """
        return self.connection_thread.conn

    def _connect_if_loss(self):
        """
        If the calling thread has no connection, establish a new connection to the database.
        """
        if not self._is_connect():
            self.connection_thread.conn = self.engine.connect()

    def _reset(self):
        """
        Reset the database connection of the calling thread.

        The current connection (if any) is invalidated, then a new connection is obtained
        from the engine. Calling this from a thread that has never connected is safe.
        """
        # `_is_connect()` also covers threads whose thread-local storage has no `conn`
        # attribute yet; reading `connection_thread.conn` directly would raise there.
        if self._is_connect():
            self.connection_thread.conn.invalidate()
        self.connection_thread.conn = self.engine.connect()

    def _disconnect(self):
        """
        Disconnect from the database.

        This function closes the connection if it is already established.
        """
        if self._is_connect():
            self.connection_thread.conn.close()
            self.connection_thread.conn = None

    def _is_connect(self):
        """
        If self have connected, return True. Otherwise, return False.

        Note that if the DBMS is stopped from outside, the return value of this function will not change.

        :return: if self connected or not
        :rtype: bool
        """
        return hasattr(self.connection_thread, "conn") and self.connection_thread.conn is not None

    @abstractmethod
    def explain_physical_plan(self, sql):
        """
        Get a physical plan from database's optimizer for a given SQL query.

        :param sql: The SQL query to be explained.
        """

    @abstractmethod
    def explain_execution_plan(self, sql):
        """
        Get an execution plan from database's optimizer for a given SQL query.

        :param sql: The SQL query to be explained.
        """

    @abstractmethod
    def _create_conn_str(self):
        """
        Build the connection string that is handed to ``database_exists``, ``create_database``
        and ``create_engine``.
        """

    @abstractmethod
    def execute(self, sql, fetch=False):
        """
        Execute a sql statement.

        :param sql: A SQL statement to be executed.
        :param fetch: fetch result or not. If True, the function will return a list of tuple representing
            the result of the sql.
        """

    @abstractmethod
    def set_hint(self, key, value):
        """
        Set the value of each hint (i.e., the run-time config) when execute SQL queries.
        The hints can be used to control the behavior of the database system in a session.
        For PostgreSQL, you can find all valid hints in https://www.postgresql.org/docs/13/runtime-config.html.
        For Spark, you can find all valid hints (called conf in Spark) in
        https://spark.apache.org/docs/latest/configuration.html#runtime-sql-configuration

        :param key: The key associated with the hint.
        :param value: The value of the hint to be set.
        """
        raise NotImplementedError

    @abstractmethod
    def create_index(self, index: Index):
        """
        Create an index on columns `index.columns` of table `index.table` with name `index.index_name`.

        :param index: a Index object including the information of the index
        """

    @abstractmethod
    def drop_index(self, index: Index):
        """
        Drop an index by its index name.

        :param index: an index that will be dropped
        """

    @abstractmethod
    def drop_all_indexes(self):
        """
        Drop all indexes across all tables in the database.
        This will not delete the system indexes and unique indexes.
        """

    @abstractmethod
    def get_all_indexes_byte(self):
        """
        Get the size of all indexes across all tables in the database in bytes.
        This will include the system indexes and unique indexes.

        :return: the size of all indexes in bytes
        """

    @abstractmethod
    def get_table_indexes_byte(self, table_name):
        """
        Get the size of all indexes on a table in bytes.
        This will include the system indexes and unique indexes.

        :param table_name: a table name that the indexes belong to
        :return: the size of all indexes on the table in bytes
        """

    @abstractmethod
    def get_index_byte(self, index: Index):
        """
        Get the size of an index in bytes by its index name.

        :param index: the index to get size
        :return: the size of the index in bytes
        """

    def get_index_number(self, table):
        """
        Get the number of indexes built on the specified table.

        :param table: name of the table
        :return: the number of index
        """
        return len(self._create_inspect().get_indexes(table))

    def get_existed_indexes(self, table):
        """
        Retrieves the existing index on the specified table.

        NOTE(review): the original documentation states that system indexes and unique indexes
        are not included. This method does no filtering itself; it returns every index the
        SQLAlchemy inspector reports for the table -- confirm against the inspector's behavior.

        :param table: the name of the table
        :return: a list of pilotscope.common.Index
        """
        inspector = self._create_inspect()
        return [Index(columns=db_index["column_names"], table=table, index_name=db_index["name"])
                for db_index in inspector.get_indexes(table)]

    def get_all_indexes(self):
        """
        Get all indexes across all tables in the database.

        :return: a list of pilotscope.common.Index
        """
        inspector = self._create_inspect()
        indexes = []
        for table in inspector.get_table_names():
            indexes.extend(
                Index(columns=db_index["column_names"], table=table, index_name=db_index["name"])
                for db_index in inspector.get_indexes(table))
        return indexes

    def get_estimated_cost(self, sql, comment=""):
        """
        Get an estimated cost of a SQL query.

        The base implementation does nothing and returns ``None``.

        :param sql: The SQL query for which to estimate the cost.
        :param comment: An optional comment to include with the query plan. Useful for debugging.
        :return: The estimated total cost of executing the SQL query.
        """

    def _create_inspect(self):
        """
        Create a SQLAlchemy inspector bound to the engine.
        """
        return inspect(self.engine)

    def create_table_if_absences(self, table_name, column_2_value, primary_key_column=None,
                                 enable_autoincrement_id_key=True):
        """
        Create a table according to parameters if absences. This function will not insert any data into the table.
        The column names and types of the table will be inferred from `column_2_value`.

        :param table_name: the name of the table you want to create
        :param column_2_value: a dict, whose keys are the names of columns and values. This data will be used
            to infer the column names and types of the table.
        :param primary_key_column: A column name in `column_2_value`. The corresponding column will be set as
            primary key. Otherwise, there will be no primary key.
        :param enable_autoincrement_id_key: If it is True, the `primary_key_column` will be autoincrement.
            It is only meaningful when `primary_key_column` is not None.
        :raises RuntimeError: if `primary_key_column` is given but is not a key of `column_2_value`.
        """
        self._connect_if_loss()
        if primary_key_column is not None and primary_key_column not in column_2_value:
            raise RuntimeError("the primary key column {} is not in column_2_value".format(primary_key_column))

        if self.exist_table(table_name):
            return

        columns = []
        for column, column_type in self._to_db_data_type(column_2_value).items():
            if column == primary_key_column:
                columns.append(
                    Column(column, column_type, primary_key=True, autoincrement=enable_autoincrement_id_key))
            else:
                columns.append(Column(column, column_type))
        table = Table(table_name, self.metadata, *columns, extend_existing=True)
        table.create(self.engine)

    def drop_table_if_exist(self, table_name):
        """
        Try to drop table named `table_name`

        :param table_name: the name of the table
        """
        if self.exist_table(table_name):
            table = Table(table_name, self.metadata, autoload_with=self.engine)
            table.drop(self.engine)
            # Also forget the table in the local registry so it is not reported any more.
            self.metadata.remove(table)

    def exist_table(self, table_name) -> bool:
        """
        If the table named `table_name` exist or not

        :param table_name: the name of the table
        :return: the table named `table_name` exist, it returns True; otherwise, it returns False
        """
        # The lookup needs a live connection; the calling thread may never have connected
        # or may have called `_disconnect()` before.
        self._connect_if_loss()
        return self.engine.dialect.has_table(self._get_connection(), table_name)

    def get_all_table_names(self):
        """
        Retrieves a list of all table names in the database.

        :return: A list of table names present in the database.
        """
        self._update_sqla_tables()
        return list(self.metadata.tables.keys())

    def insert(self, table_name, column_2_value: dict):
        """
        Insert a new row into the table with each column's value set as column_2_value.

        :param table_name: the name of the table
        :param column_2_value: a dict where the keys are column names and the values are the values to be inserted
        """
        self._connect_if_loss()
        table = Table(table_name, self.metadata, autoload_with=self.engine)
        self.execute(table.insert().values(column_2_value))

    def get_table_columns(self, table_name):
        """
        Get all column names of a table

        :param table_name: the names of the table
        :return: the list of the names of column
        """
        return [c.key for c in self._get_sqla_table(table_name).c]

    def get_table_row_count(self, table_name):
        """
        Get the row count of the table

        :param table_name: the name of the table
        :return: the row count
        """
        table = self._get_sqla_table(table_name)
        stmt = select(func.count()).select_from(table)
        result = self.execute(stmt, fetch=True)
        return result[0][0]

    def get_column_max(self, table_name, column_name):
        """
        Get the maximum of a column

        :param table_name: the name of the table that the column belongs to
        :param column_name: the name of the column
        :return: the maximum, type of which is same as the data of the column
        """
        table = self._get_sqla_table(table_name)
        stmt = select(func.max(table.c[column_name])).select_from(table)
        result = self.execute(stmt, fetch=True)
        return result[0][0]

    def get_column_min(self, table_name, column_name):
        """
        Get the minimum of a column

        :param table_name: the name of the table that the column belongs to
        :param column_name: the name of the column
        :return: the minimum, type of which is same as the data of the column
        """
        table = self._get_sqla_table(table_name)
        stmt = select(func.min(table.c[column_name])).select_from(table)
        result = self.execute(stmt, fetch=True)
        return result[0][0]

    def shutdown(self):
        """
        shutdown the database

        The base implementation does nothing.
        """

    def start(self):
        """
        start the database

        The base implementation does nothing.
        """

    def restart(self):
        """
        restart the database

        Unless the configured database type is Spark, deep control must be enabled;
        otherwise `DatabaseDeepControlException` is raised before anything is stopped.
        """
        if self.config.db_type != DatabaseEnum.SPARK:
            self._check_enable_deep_control()
        self.shutdown()
        self.start()

    def write_knob_to_file(self, key_2_value_knob):
        """
        Write knobs to config file, you should restart database to make it work.

        The base implementation does nothing.

        :param key_2_value_knob: a dict with keys as the names of the knobs and values as the values to be set.
        """

    def recover_config(self):
        """
        Recover config file of database to the last saved config file by `backup_config()`

        The base implementation does nothing.
        """

    def _to_db_data_type(self, column_2_value):
        """
        Converts Python data types to database data types.

        :param column_2_value: A dictionary mapping column names to their respective values,
            e.g. {'col1': 'value1', 'col2': 'value2'}
        :return: A dictionary mapping column names to SQLAlchemy data types.
        """
        # The lookup is on the exact type: values of any other type (str, dict, list,
        # and also bool, although it subclasses int) are stored as String.
        exact_type_2_db_type = {int: Integer, float: Float}
        return {col: exact_type_2_db_type.get(type(data), String) for col, data in column_2_value.items()}

    def _update_sqla_tables(self):
        """
        Reflect the tables of the database into `self.metadata`. Nothing is returned.
        """
        self.metadata.reflect(self.engine)

    def _get_sqla_table(self, table_name):
        """
        Get SQLAlchemy `Table` object of a table

        :param table_name: the name of the table
        :return: the SQLAlchemy `Table` object of the table
        """
        # Load the table definition from the database through the engine.
        return Table(table_name, self.metadata, autoload_with=self.engine)

    def _check_enable_deep_control(self):
        """
        Raise `DatabaseDeepControlException` if deep control is not enabled in the config.
        """
        if not self.config._enable_deep_control:
            raise DatabaseDeepControlException()