Source code for pilotscope.DBController.SparkSQLController

import json
import logging
from typing import Union, Dict, Tuple

import numpy as np
import pandas
from pyspark.sql import DataFrame
from pyspark.sql import SparkSession
from pyspark.sql.types import *

from pilotscope.Common.Index import Index
from pilotscope.DBController.BaseDBController import BaseDBController
from pilotscope.Exception.Exception import PilotScopeNotSupportedOperationException
from pilotscope.PilotConfig import SparkConfig
from pilotscope.PilotEnum import PilotEnum, SparkSQLDataSourceEnum

logging.getLogger('pyspark').setLevel(logging.ERROR)
logging.getLogger("py4j").setLevel(logging.ERROR)
logger = logging.getLogger("PilotScope")

SUCCESS = 1
FAILURE = 0


class SparkSQLTypeEnum(PilotEnum):
    String = StringType
    Integer = IntegerType
    Float = FloatType


class SparkIOWriteModeEnum(PilotEnum):
    OVERWRITE = "overwrite"
    APPEND = "append"
    ERROR_IF_EXISTS = "errorifexists"
    IGNORE = "ignore"


def sparkSessionFromConfig(spark_config: SparkConfig):
    session = SparkSession.builder \
        .appName(spark_config.app_name) \
        .master(spark_config.master_url)
    for config_name in spark_config.spark_configs:
        session = session.config(config_name, spark_config.spark_configs[config_name])
    if spark_config.datasource_type == SparkSQLDataSourceEnum.POSTGRESQL:
        session = session.config("spark.jars.packages", spark_config.jdbc)
    return session.getOrCreate()


class SparkColumn(StructField):
    def __init__(self, column_name, column_type: SparkSQLTypeEnum):
        # Spark does not support primary key and auto-increment
        super().__init__(column_name, column_type())


class SparkTable:
    def __init__(self, table_name, metadata, *columns):
        self.table_name = table_name
        self.columns = list(columns)
        self.schema = StructType(list(columns))
        self.df: DataFrame = None

    def load(self, engine, analyze=False):
        self.df = engine.io.read(self.table_name)
        self.df.createOrReplaceTempView(self.table_name)
        self.schema = self.df.schema
        if analyze:
            self.analyzeStats(engine)

    def create(self, engine, analyze=False):
        if engine.has_table(engine.session, self.table_name, where="datasource"):
            # Table exists in the data source, load it directly
            # self.df = engine.io.read(self.table_name)
            self.load(engine, analyze)
        else:
            if engine.has_table(engine.session, self.table_name, where="session"):
                # Table exists in the current session. 
                # To avoid duplicated tables in the same session, an error will be thrown here.
                raise RuntimeError("Duplicated table cannot be created in the current session.")
            else:
                # The table exists neither in the data source nor in the session,
                # so create an empty table and persist it to the data source.
                self.df = engine.session.createDataFrame(data=[], schema=self.schema)
                # engine.io.write(self.df, mode=SparkIOWriteModeEnum.OVERWRITE, target_table_name=self.table_name)
            self.df.createOrReplaceTempView(self.table_name)
            if analyze:
                self.analyzeStats(engine)
        # engine.session.catalog.cacheTable(self.table_name)
        # if analyze:
        #    engine.session.sql("ANALYZE TABLE {} COMPUTE STATISTICS FOR ALL COLUMNS".format(self.table_name))

    # insert a single row into the cached DataFrame and refresh the temp view
    def insert(self, engine, column_2_value, analyze=False, persist=True):
        column_names = list(column_2_value.keys())
        new_row = engine.session.createDataFrame([tuple(column_2_value[col] for col in column_names)], column_names)
        self.df = self.df.union(new_row)
        self.df.createOrReplaceTempView(self.table_name)
        if analyze:
            self.analyzeStats(engine)
        if persist:
            # self.df already contains the newly inserted row, so overwriting the whole table is correct
            self.persist(engine)

    def nrows(self):
        return self.df.count()

    def cache(self, engine):
        engine.session.catalog.cacheTable(self.table_name)

    def analyzeStats(self, engine):
        self.cache(engine)
        engine.session.sql("ANALYZE TABLE {} COMPUTE STATISTICS FOR ALL COLUMNS".format(self.table_name))

    def persist(self, engine):
        engine.io.write(self, mode=SparkIOWriteModeEnum.OVERWRITE)

    def clear_rows(self, engine, persist=False):
        self.df = engine.session.createDataFrame(data=[], schema=self.schema)
        self.df.createOrReplaceTempView(self.table_name)
        self.analyzeStats(engine)
        if persist:
            self.persist(engine)
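
# Illustrative sketch (not part of the original module): how a SparkTable is typically
# created, filled and persisted. It assumes `engine` is an already-connected SparkEngine
# (defined below) backed by a PostgreSQL data source; the table and column names are
# hypothetical. Raw Spark types are passed to SparkColumn, mirroring what
# SparkSQLController._to_db_data_type does with SparkSQLTypeEnum values.
def _example_spark_table_usage(engine):
    columns = [SparkColumn("id", IntegerType),
               SparkColumn("note", StringType)]
    table = SparkTable("pilot_example_table", None, *columns)
    # create() loads the table from the data source if it already exists there
    table.create(engine)
    # insert() appends one row to the cached DataFrame and refreshes the temp view;
    # persist=True overwrites the table in the data source with the updated DataFrame
    table.insert(engine, {"id": 1, "note": "hello"}, persist=True)
    return table.nrows()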


class SparkIO:
    def __init__(self, datasource_type: SparkSQLDataSourceEnum, engine, **datasource_conn_info) -> None:
        self.reader = None
        self.conn_info = datasource_conn_info
        self.datasource_type = datasource_type
        if datasource_type != SparkSQLDataSourceEnum.POSTGRESQL:
            raise RuntimeError("SparkIO has not been tested on any other data source types than 'postgresql'.")
        if datasource_type == SparkSQLDataSourceEnum.POSTGRESQL:
            self.reader = engine.session.read \
                .format("jdbc") \
                .option("driver", "org.postgresql.Driver") \
                .option("url", "jdbc:postgresql://{}:{}/{}".format(self.conn_info['host'], self.conn_info["port"],
                                                                   self.conn_info['db'])) \
                .option("user", self.conn_info['user']) \
                .option("password", self.conn_info['pwd'])

    def read(self, table_name=None, query=None) -> DataFrame:
        assert not (table_name is not None and query is not None)
        assert not (table_name is None and query is None)
        if table_name is not None:
            self.reader = self.reader.option("dbtable", table_name)
        elif query is not None:
            self.reader = self.reader.option("query", query)
        return self.reader.load()

    def write(self, table_or_rows: Union[SparkTable, DataFrame], mode: SparkIOWriteModeEnum, target_table_name=None):
        if isinstance(table_or_rows, SparkTable):
            df = table_or_rows.df
        else:
            df = table_or_rows
        if target_table_name is None and not isinstance(table_or_rows, SparkTable):
            raise Exception("Target table name not specified.")
        else:
            if target_table_name is not None:
                table_name = target_table_name
            elif isinstance(table_or_rows, SparkTable):
                table_name = table_or_rows.table_name
        if self.datasource_type == SparkSQLDataSourceEnum.POSTGRESQL:
            df.cache()
            rows = df.count()  # do not delete this read of df; materializing it here is what makes overwriting the source table possible
            write = df.write \
                .mode(mode.value) \
                .format("jdbc") \
                .option("driver", "org.postgresql.Driver") \
                .option("url", "jdbc:postgresql://{}:{}/{}".format(self.conn_info['host'], self.conn_info["port"],
                                                                   self.conn_info['db'])) \
                .option("user", self.conn_info['user']) \
                .option("password", self.conn_info['pwd']) \
                .option("dbtable", table_name)
            write.save()
            assert rows == df.count()

    def has_table(self, table_name):
        return self.read(table_name="information_schema.tables") \
                   .filter("table_name = '{}'".format(table_name)) \
                   .count() > 0

    def get_all_table_names_in_datasource(self) -> np.ndarray:
        return self.read(table_name="information_schema.tables") \
            .filter("table_schema == 'public'") \
            .filter("table_type == 'BASE TABLE'").toPandas()["table_name"].values


class SparkEngine:
    def __init__(self, config: SparkConfig):
        self.config = config
        self.session = None
        self.io = None

    def connect(self):
        self.session = sparkSessionFromConfig(self.config)
        self.io = SparkIO(self.config.datasource_type, self, host=self.config.db_host, port=self.config.db_port,
                          db=self.config.db,
                          user=self.config.db_user, pwd=self.config.db_user_pwd)
        return self.session

    def _has_table_in_datasource(self, table_name):
        return self.io.has_table(table_name)

    def _has_table_in_session(self, session, table_name):
        return session.catalog.tableExists(table_name)

    def has_table(self, connection: SparkSession, table_name: str, where="datasource") -> bool:
        if where == "datasource":
            return self._has_table_in_datasource(table_name)
        elif where == "session":
            return self._has_table_in_session(connection, table_name)
        else:
            raise ValueError("Unsupported 'where' value: {}".format(where))

    def get_all_table_names_in_datasource(self) -> np.ndarray:
        return self.io.get_all_table_names_in_datasource()

    def clearCachedTables(self):
        self.session.catalog.clearCache()
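
# Illustrative sketch (not part of the original module): connecting a SparkEngine
# directly and probing the data source. It assumes `config` is a fully populated
# SparkConfig; the table name is hypothetical.
def _example_spark_engine_usage(config: SparkConfig):
    engine = SparkEngine(config)
    session = engine.connect()
    in_source = engine.has_table(session, "pilot_example_table", where="datasource")
    return in_source, engine.get_all_table_names_in_datasource()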


class SparkSQLController(BaseDBController):
    def __init__(self, config: SparkConfig, echo=False):
        super().__init__(config, echo)

    def _db_init(self):
        self.name_2_table = {}
        self.engine: SparkEngine = self._create_engine()
        self._connect_if_loss()

    def _create_conn_str(self):
        return ""

    def _create_engine(self):
        return SparkEngine(self.config)

    def load_all_tables_from_datasource(self):
        all_user_created_table_names = self.engine.get_all_table_names_in_datasource()
        for table_name in all_user_created_table_names:
            self.load_table_if_exists_in_datasource(table_name)

    def _connect_if_loss(self):
        if not self._is_connect():
            self.connection_thread.conn = self.engine.connect()
            all_user_created_table_names = self.engine.get_all_table_names_in_datasource()
            for table_name in all_user_created_table_names:
                self.load_table_if_exists_in_datasource(table_name)
                logger.debug("[connect_if_loss] Loaded table '{}'".format(table_name))

    def _disconnect(self):
        if self._get_connection() is not None:
            # try:
            self.persist_tables()
            self.engine.clearCachedTables()
            self.connection_thread.conn.stop()
            self.connection_thread.conn = None
            # except:  # deal with connection already stopped
            #     pass

    def persist_table(self, table_name):
        self.name_2_table[table_name].persist(self.engine)

    def persist_tables(self):
        for table in self.name_2_table.values():
            table.persist(self.engine)

    def exist_table(self, table_name, where="session") -> bool:
        has_table = self.engine.has_table(self._get_connection(), table_name, where)
        if has_table:
            return True
        return False

    def load_table_if_exists_in_datasource(self, table_name):
        if (not self.exist_table(table_name, where="session")) and self.exist_table(table_name, where="datasource"):
            # If the table exists in the data source but not in the current session,
            # then the table will be loaded from the data source.
            table = SparkTable(table_name, None)
            table.load(self.engine)
            self.name_2_table[table_name] = table
    # collect statistics for the given table
    def analyze_table_stats(self, table_name):
        self.load_table_if_exists_in_datasource(table_name)
        self.name_2_table[table_name].analyzeStats(self.engine)

    def analyze_all_table_stats(self):
        for table in self.name_2_table.values():
            table.analyzeStats(self.engine)
    # clear all SparkTable instances in cache
    def clear_all_tables(self):
        conn = self._get_connection()
        conn.catalog.clearCache()
        for table_name in self.name_2_table:
            conn.catalog.dropTempView(table_name)
        self.name_2_table.clear()
    # check whether the given key (config name) is modifiable at runtime
    # and set it to the given value if it is
    def set_hint(self, key, value):
        if self._get_connection().conf.isModifiable(key):
            # self.connection.conf.set(key, value)
            sql = "SET {} = {}".format(key, value)
            self.execute(sql)
        else:
            logger.warning(
                "[set_hint] Configuration '{}' is not modifiable at runtime, nothing changed".format(key))
    def create_table_if_absences(self, table_name, column_2_value, primary_key_column=None,
                                 enable_autoincrement_id_key=True):
        self._connect_if_loss()
        if primary_key_column is not None:
            logger.warning(
                "[create_table_if_absences] Spark SQL does not support specifying a primary key while creating a table.")
            primary_key_column = None
        column_2_type = self._to_db_data_type(column_2_value)
        # metadata_obj = self.metadata
        if not self.exist_table(table_name, where="session"):
            # Only check whether the table exists in the current session.
            # If the table exists in the data source but not in the session,
            # self.exist_table simply returns False here,
            # and table.create will then load it from the data source.
            if self.exist_table(table_name, where="datasource"):
                logger.warning(
                    "[create_table_if_absences] Table '{}' exists in the data source but not in the current session, ".format(
                        table_name) +
                    "so it will be loaded from the data source and your input schema will be ignored.")
            columns = []
            for column, column_type in column_2_type.items():
                columns.append(SparkColumn(column, column_type))
            table = SparkTable(table_name, None, *columns)
            table.create(self.engine)
            self.name_2_table[table_name] = table
        else:
            logger.warning("[create_table_if_absences] Table '{}' exists, nothing changed.".format(table_name))
    def get_table_row_count(self, table_name):
        self.load_table_if_exists_in_datasource(table_name)
        if table_name not in self.name_2_table:
            raise RuntimeError(
                "Table '{}' was found in neither the current session nor the data source.".format(table_name))
        return self.name_2_table[table_name].nrows()
    def insert(self, table_name, column_2_value: dict, persist=True):
        self.load_table_if_exists_in_datasource(table_name)
        table = self.name_2_table[table_name]
        table.insert(self.engine, column_2_value, persist=persist)
    def execute(self, sql, fetch=False, fetch_column_name=False) -> Union[pandas.DataFrame, DataFrame]:
        row = None
        try:
            self._connect_if_loss()
            df = self._get_connection().sql(sql)
            row = df.toPandas()
            if not fetch:
                row = df
        except Exception as e:
            if "PilotScopePullEnd" not in str(e):
                raise e
        return row

    def _unresolvedLogicalPlan(self, query_execution):
        return query_execution.logical()

    def _resolvedLogicalPlan(self, query_execution):
        return query_execution.spark_analyzed()

    def _optimizedLogicalPlan(self, query_execution):
        return query_execution.optimizedPlan()

    def _logicalPlan(self, query_execution):
        return self._optimizedLogicalPlan(query_execution)

    def _physicalPlan(self, query_execution):
        return query_execution.executedPlan()
    def explain_physical_plan(self, sql, comment="") -> Dict:
        sql = "{} {}".format(comment, sql)
        plan = self._physicalPlan(self.execute(sql)._jdf.queryExecution())
        return json.loads(plan.toJSON())[0]
    def get_estimated_cost(self, sql, comment="") -> Tuple[int]:
        raise PilotScopeNotSupportedOperationException(
            "Spark SQL does not support cost estimation. You can use the row count or sizeInBytes instead.")
    def write_knob_to_file(self, key_2_value_knob):
        for k, v in key_2_value_knob.items():
            self.set_hint(k, v)

    def recover_config(self):
        # reset all modifiable runtime configurations
        self._get_connection().sql("RESET")

    def shutdown(self):
        pass

    def start(self):
        pass

    def explain_execution_plan(self, sql, comment=""):
        raise NotImplementedError

    def status(self):
        raise PilotScopeNotSupportedOperationException

    def get_buffercache(self):
        raise PilotScopeNotSupportedOperationException
    def create_index(self, index):
        raise PilotScopeNotSupportedOperationException
    def drop_index(self, index):
        raise PilotScopeNotSupportedOperationException

    def drop_all_indexes(self):
        raise PilotScopeNotSupportedOperationException

    def get_all_indexes_byte(self):
        raise PilotScopeNotSupportedOperationException

    def get_table_indexes_byte(self, table_name):
        raise PilotScopeNotSupportedOperationException

    def get_index_byte(self, index: Index):
        raise PilotScopeNotSupportedOperationException
    def _to_db_data_type(self, column_2_value):
        column_2_type = {}
        for col, data in column_2_value.items():
            data_type = SparkSQLTypeEnum.String
            if type(data) == int:
                data_type = SparkSQLTypeEnum.Integer
            elif type(data) == float:
                data_type = SparkSQLTypeEnum.Float
            elif type(data) == str:
                data_type = SparkSQLTypeEnum.String
            elif type(data) == dict:
                data_type = SparkSQLTypeEnum.String
            elif type(data) == list:
                data_type = SparkSQLTypeEnum.String
            column_2_type[col] = data_type.value
        return column_2_type
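
# Illustrative end-to-end sketch (not part of the original module): typical use of
# SparkSQLController. It assumes `config` is a fully populated SparkConfig (Spark master,
# application name and PostgreSQL connection settings); the table and column names are
# hypothetical.
def _example_controller_usage(config: SparkConfig):
    controller = SparkSQLController(config)
    # creates the table in the current session, or loads it if the data source already has it;
    # the sample values are only used to infer column types
    controller.create_table_if_absences("pilot_example_table", {"id": 1, "note": "hello"})
    controller.insert("pilot_example_table", {"id": 2, "note": "world"}, persist=True)
    rows = controller.get_table_row_count("pilot_example_table")
    # run an arbitrary SQL statement; fetch=True returns a pandas.DataFrame
    result = controller.execute("SELECT COUNT(*) FROM pilot_example_table", fetch=True)
    return rows, result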