from typing import List
from pilotscope.Anchor.BaseAnchor.BaseAnchorHandler import BaseAnchorHandler
from pilotscope.Anchor.BaseAnchor.BasePullHandler import RecordPullHandler, BasePullHandler
from pilotscope.Anchor.BaseAnchor.BasePushHandler import BasePushHandler
from pilotscope.Common.Util import extract_handlers
from pilotscope.DBInteractor.PilotDataInteractor import PilotDataInteractor
from pilotscope.PilotEnum import *
from pilotscope.PilotEvent import *
from pilotscope.PilotTransData import PilotTransData
# noinspection PyProtectedMember
class PilotScheduler:
    """
    Entry point that runs SQL through a ``PilotDataInteractor`` with the
    user-registered handlers attached, stores the pulled data through a
    ``DataManager`` and fires the registered events.

    Project types in method signatures are written as forward-reference
    strings: several of them (e.g. ``PilotConfig``, ``Event``) are not
    imported explicitly by this module, so they are not evaluated at
    definition time.
    """

    def __init__(self, config: "PilotConfig") -> None:
        """
        :param config: The configuration of PilotScope.
        """
        self.config = config
        # Destination table for collected data; set by register_required_data().
        self.table_name_for_store_data = None
        self.data_manager: DataManager = DataManager(self.config)
        self.db_controller = DBControllerFactory.get_db_controller(self.config)
        self.events = []
        self.user_tasks: List[BasePushHandler] = []
        self.data_interactor = PilotDataInteractor(self.config)

    def init(self):
        """
        Initialize the scheduler for enabling the AI4DB algorithms, triggering the registered events and others.
        This function should be called before executing any sql and after registering all the required data and events.
        """
        self._deal_initial_events()

    def execute(self, sql):
        """
        The function will finish the following tasks:
        1. execute a sql using the registered AI4DB algorithms.
        2. save the collected data into the specific table
        3. try to trigger the registered events

        :param sql: a sql to be executed
        :return: the related records of the sql, or ``None`` when the data interactor returns no result
        """
        data_interactor = self.data_interactor

        # Attach a RecordPullHandler under its own anchor name for this call.
        record_handler = RecordPullHandler(self.config)
        data_interactor._add_anchor(record_handler.anchor_name, record_handler)

        # Attach every handler registered by the user.
        data_interactor._add_anchors(self.user_tasks)

        # Let each user handler refresh the data it injects for this sql.
        for replace_handle in self.user_tasks:
            replace_handle._update_injected_data(sql)

        result = data_interactor.execute(sql, is_reset=False)
        if result is None:
            return None

        self._post_process(result)
        return result.records

    def register_custom_handlers(self, handlers: "List[BaseAnchorHandler]"):
        """
        Register custom AI4DB handlers

        :param handlers: a list of custom handlers; a single handler is also accepted
        :raises RuntimeError: if two handlers in ``handlers`` have the same class type
        """
        # Normalize before validating: the validation iterates over the
        # collection, so a bare handler has to be wrapped first.
        if not isinstance(handlers, list):
            handlers = [handlers]
        # NOTE(review): uniqueness is only checked within this call, not against
        # handlers registered by earlier calls - confirm whether that is intended.
        if not self._is_valid_custom_handlers(handlers):
            raise RuntimeError("pilotscope is not allowed to register identical class type for custom handler")
        self.user_tasks += handlers

    def register_required_data(self, table_name_for_store_data, pull_execution_time=False, pull_physical_plan=False,
                               pull_subquery_2_cards=False, pull_buffer_cache=False, pull_estimated_cost=False):
        """
        Register data need to collect when executing a sql

        :param table_name_for_store_data: the table name for storing the collected data
        :param pull_execution_time: whether to get the execution time of a sql
        :param pull_physical_plan: whether to get the physical plan of a sql
        :param pull_subquery_2_cards: whether to get the sub-plan queries and their cardinality of a sql
        :param pull_buffer_cache: whether to get the buffer cache of table after executing a sql
        :param pull_estimated_cost: whether to get the estimated cost of a sql
        """
        if pull_execution_time:
            self.data_interactor.pull_execution_time()
        if pull_physical_plan:
            self.data_interactor.pull_physical_plan()
        if pull_subquery_2_cards:
            self.data_interactor.pull_subquery_card()
        if pull_buffer_cache:
            self.data_interactor.pull_buffercache()
        if pull_estimated_cost:
            self.data_interactor.pull_estimated_cost()
        self.table_name_for_store_data = table_name_for_store_data

    def register_events(self, events: "List[Event]"):
        """
        Register events into scheduler.

        :param events: the events to be registered; a single event is also accepted
        """
        if not isinstance(events, list):
            events = [events]
        self.events += events

    def _post_process(self, data: "PilotTransData"):
        """
        Store the data collected for one executed sql, then fire the query-finish events.

        :param data: the result returned by the data interactor
        """
        self._store_collected_data_into_table(data)
        self._deal_execution_end_events()

    def _store_collected_data_into_table(self, data: "PilotTransData"):
        """
        Let every extracted handler write its columns into one row and save that row.

        :param data: the result returned by the data interactor
        :raises RuntimeError: if an extracted handler is not a ``BasePullHandler``
        """
        pull_anchors = extract_handlers(self.data_interactor._get_all_handlers(), True)
        column_2_value = {}
        for anchor in pull_anchors:
            if not isinstance(anchor, BasePullHandler):
                raise RuntimeError(
                    "expected a BasePullHandler but got {}".format(type(anchor).__name__))
            anchor.prepare_data_for_writing(column_2_value, data)
        # NOTE(review): table_name_for_store_data is still None when
        # register_required_data() was never called - confirm save_data copes with that.
        self.data_manager.save_data(self.table_name_for_store_data, column_2_value)

    def _deal_initial_events(self):
        """
        Trigger the events that have to run before the first sql is executed.

        Pretraining events are started asynchronously. When the config asks to
        wait (``TrainSwitchMode.WAIT``), every started pretraining thread is joined.
        """
        pretraining_threads = []
        for event in self.events:
            if isinstance(event, PretrainingModelEvent):
                thread = event._async_start()
                if thread is not None:
                    pretraining_threads.append(thread)
            elif isinstance(event, PeriodicModelUpdateEvent):
                if event.execute_before_first_query:
                    event.process(self.db_controller, self.data_manager)
            elif isinstance(event, WorkloadBeforeEvent):
                event._update(self.db_controller, self.data_manager)

        # Wait for all pretraining threads, not only the most recently started one.
        if pretraining_threads and self.config.pretraining_model == TrainSwitchMode.WAIT:
            for thread in pretraining_threads:
                thread.join()

    def _deal_execution_end_events(self):
        """
        Notify every registered ``QueryFinishEvent`` that a sql has finished.
        """
        for event in self.events:
            if isinstance(event, QueryFinishEvent):
                event._update(self.db_controller, self.data_manager)

    def _is_valid_custom_handlers(self, handlers):
        """
        Check that no two handlers share the same class type.

        :param handlers: an iterable collection of handlers
        :return: ``False`` if at least two elements have an identical class type, otherwise ``True``
        """
        distinct_types = {type(handler) for handler in handlers}
        return len(distinct_types) == len(handlers)