Source code for finds.database.sql

"""SQL class wrapper, with convenience methods for pandas DataFrames

Copyright 2022-2024, Terence Lim

MIT License
"""
from typing import List, Dict, Mapping, Any, Tuple
import random
import numpy as np
import pandas as pd
from pandas import DataFrame, Series
import sqlalchemy
from sqlalchemy import text, Integer, SmallInteger, Boolean, Float, String
from sqlalchemy.orm import sessionmaker      
from finds.database import Database

def as_dtypes(df: DataFrame,
              columns: Dict,
              drop_duplicates: List[str] = [],
              sort_values: List[str] = [],
              keep: str = 'first',
              replace: Dict[str, Tuple[Any, Any]] = {}) -> DataFrame:
    """Convert DataFrame dtypes to the given sqlalchemy Column types

    Args:
        df: Input DataFrame to apply new data types from target columns
        columns: Target sqlalchemy column types as dict of {column: type}
        sort_values: List of column names to sort by
        drop_duplicates: List of fields; drop rows duplicated on all of them
        keep: 'first' or 'last' row to keep when dropping duplicates
        replace: Dict of {column label: tuple(old, replacement) values}

    Returns:
        DataFrame with columns and rows transformed

    Notes:
        - Columns of the DataFrame are dropped if not specified in columns input
        - If input is None, then return empty DataFrame with given column types
        - Blank values in boolean and int fields are set to False/0
        - Invalid/blank values in float fields are coerced to NaN
        - Invalid values in int fields are coerced to 0
    """
    if df is None:
        df = DataFrame(columns=list(columns))
    df.columns = df.columns.map(str.lower).map(str.rstrip)  # clean column names
    df = df.reindex(columns=list(columns))   # reorder and keep only these columns
    if len(sort_values):
        df = df.sort_values(sort_values)     # sort_values returns a copy
    if len(drop_duplicates):
        df.drop_duplicates(subset=drop_duplicates, keep=keep, inplace=True)
    for col, v in columns.items():
        try:
            if col in replace:
                df[col] = df[col].replace(*replace[col])
            if isinstance(v, (Integer, SmallInteger)):
                df[col] = df[col].replace(r"(?<=\d)-", "", regex=True)  # crsp dates
                df[col] = df[col].replace('', 0).astype(int)
            elif isinstance(v, Boolean):
                df[col] = df[col].replace('', False).astype(bool)
            elif isinstance(v, Float):
                df[col] = pd.to_numeric(df[col], errors='coerce').astype(float)
            elif isinstance(v, String):
                df[col] = df[col].astype(str).str.encode('ascii', 'ignore')\
                                             .str.decode('ascii')
            else:
                raise Exception('(as_dtypes) Unknown type for column: ' + col)
        except Exception:
            raise Exception('(as_dtypes) bad data in column: ' + col)
    return df
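# A minimal usage sketch of as_dtypes() (column names below are assumed for
# illustration, not taken from the package): coerce a raw DataFrame to target
# sqlalchemy column types.
#
#   target = {'permno': Integer(), 'ret': Float(), 'ticker': String(8)}
#   raw = DataFrame({'permno': ['10001', ''],
#                    'ret': ['0.01', 'bad'],
#                    'ticker': ['AAPL', 'MSFT']})
#   clean = as_dtypes(raw, columns=target, drop_duplicates=['permno'])
#   # the blank permno becomes 0, and the invalid float is coerced to NaN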
class SQL(Database):
    """Interface to sqlalchemy, with convenience functions for dataframes"""

    def __init__(self, user: str, password: str, host: str = 'localhost',
                 port: str = '3306', database: str = '',
                 autocommit: str = 'true', charset: str = 'utf8',
                 temp: str = f"temp{random.randint(0, 8192)}", **kwargs):
        super().__init__(**kwargs)
        self.url = f"mysql+pymysql://{user}:{password}@{host}:{port}/{database}"\
            + f"?charset={charset}&local_infile=1&autocommit={autocommit}"
        self._t = temp        # name of temp table for this process
        self.create_engine()

    @staticmethod
    def create_database(user: str, password: str, host: str = 'localhost',
                        port: str = '3306', database: str = '', **kwargs):
        """Create new database using this user's credentials"""
        url = f"mysql+pymysql://{user}:{password}@{host}:{port}"
        engine = sqlalchemy.create_engine(url)
        with engine.begin() as conn:
            conn.execute(text("COMMIT"))
            conn.execute(text(f"CREATE DATABASE {database}"))

    def create_engine(self):
        """Call and store sqlalchemy.create_engine() and MetaData()"""
        self.engine = sqlalchemy.create_engine(self.url, echo=self._verbose > 0)
        self.metadata = sqlalchemy.MetaData()

    def rollback(self):
        """Call sessionmaker() to rollback current transaction in progress"""
        Session = sessionmaker(self.engine)
        with Session() as session:
            session.rollback()

    def Table(self, key: str, *args, **kwargs) -> sqlalchemy.Table:
        """Wraps sqlalchemy.Table() after removing key from metadata"""
        if key in self.metadata.tables:   # remove from metadata if it existed
            self.metadata.remove(self.metadata.tables[key])
        table = sqlalchemy.Table(key, self.metadata, *args, **kwargs)
        # self.metadata.create_all(self.engine)
        return table
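
    # Hedged sketch of typical usage (table and column names assumed): define
    # a table with a composite primary key and a secondary index, then issue
    # CREATE TABLE for everything registered in metadata.
    #
    #   from sqlalchemy import Column
    #   bench = sql.Table('benchmarks',
    #                     Column('permno', Integer, primary_key=True),
    #                     Column('date', Integer, primary_key=True),
    #                     Column('ret', Float),
    #                     SQL.Index('date'))    # index auto-named 'date'
    #   sql.create_all()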

    def create_all(self):
        """Create all tables in metadata"""
        self.metadata.create_all(self.engine)

    @classmethod
    def Index(cls, *args) -> sqlalchemy.Index:
        """Wraps sqlalchemy.Index() with auto-generated index name from args"""
        return sqlalchemy.Index("_".join(args), *args)

    def remove(self, key: str):
        """Remove a table by key name from metadata instance"""
        if key in self.metadata.tables:
            self.metadata.remove(self.metadata.tables[key])

    def run(self, q) -> Dict | None:
        """Execute a sql command

        Args:
            q: Query string or sqlalchemy executable

        Returns:
            The result set as a dict of {'data', 'columns'}, or None

        Raises:
            RuntimeError: Failed to run query

        Examples:
            >>> sql.run("show databases")
            >>> sql.run("show tables")
            >>> sql.run('select * from testing')
            >>> sql.run('select distinct permno from benchmarks')
            >>> sql.run("show create table _")
            >>> sql.run("describe _")
            >>> sql.run("truncate table _")
        """
        if isinstance(q, str):
            q = text(q)
        for _ in range(2):          # retry once after recreating the engine
            try:
                with self.engine.begin() as conn:
                    try:
                        r = conn.execute(q)
                        return {'data': r.fetchall(), 'columns': r.keys()}
                    except Exception:
                        return None
            except Exception as e:
                self._print(e)
                self.create_engine()   # recreate engine, then retry
        raise RuntimeError('(sql.run) ' + str(q))

    def summary(self, table: str, val: str, key: str = '') -> DataFrame:
        """Return summary statistics for a field, optionally grouped by key

        Args:
            table: Physical name of table
            val: Field name to summarise
            key: Field name to group by, optional

        Returns:
            DataFrame with summary statistics (count, avg, std, max, min)

        Examples:
            >>> sql.summary('annual', 'revt', 'sic')
        """
        if key:
            q = (f"SELECT {key}, COUNT(*) as count, AVG({val}) as avg, "
                 f" STD({val}) as std, MAX({val}) as max, MIN({val}) as min "
                 f" FROM {table} GROUP BY {key}")
            return self.read_dataframe(q).set_index(key).sort_index()
        else:
            q = (f"SELECT COUNT(*) as count, AVG({val}) as avg, "
                 f" MAX({val}) as max, MIN({val}) as min FROM {table}")
            return DataFrame(index=[val], **self.run(q))

    def load_infile(self, table: str, csvfile: str, options: str = ''):
        """Load table from csv file, using mysql's load data local infile

        Args:
            table: Physical name of table to load into
            csvfile: CSV filename
            options: String appended to the SQL load infile query
        """
        q = (f"LOAD DATA LOCAL INFILE '{csvfile}' INTO TABLE {table} "
             f" FIELDS TERMINATED BY ',' ENCLOSED BY '\"'"
             f" LINES TERMINATED BY '\\n' IGNORE 1 ROWS {options};")
        try:
            self._print("(load_infile)", q)
            self.run(q)
        except Exception as e:
            print("(load_infile) Got exception =", e, " Query =", q)
            raise e
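
    # Hedged example (file path assumed): bulk-load a csv file that includes a
    # header row into an existing table, e.g.
    #
    #   sql.load_infile('annual', '/tmp/annual.csv')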

    def load_dataframe(self, table: str, df: DataFrame, index_label: str = '',
                       to_sql: bool = True, replace: bool = False):
        """Load dataframe into sql table, ignoring duplicate primary keys

        Args:
            table: Physical name of table to insert into
            df: Source dataframe
            index_label: Column name to load the index as, '' (default) to ignore
            to_sql: First attempt pandas.to_sql(), which may fail if duplicate
                keys exist; then (or if False) insert ignore from a temp table
            replace: Set True to overwrite the table, else append (default)
        """
        df.columns = df.columns.map(str.lower).map(str.rstrip)
        chunksize = int(1024*1024*32 // len(df.columns))
        try:    # to_sql raises an exception if duplicate keys exist
            assert to_sql
            df.to_sql(table,
                      self.engine,
                      if_exists=('replace' if replace else 'append'),
                      chunksize=chunksize,
                      index=bool(index_label),
                      index_label=index_label)
        except Exception:   # duplicates exist: insert ignore from temp table
            self._print("(load_dataframe) Retrying insert ignore", table)
            self.run('drop table if exists ' + self._t)
            df.to_sql(self._t,
                      self.engine,
                      if_exists='replace',
                      chunksize=chunksize,
                      index=bool(index_label),
                      index_label=index_label)
            # warnings.filterwarnings("ignore", category=pymysql.Warning)
            columns = ", ".join(df.columns)
            q = (f"INSERT IGNORE INTO {table} ({columns})"
                 f" SELECT {columns} FROM {self._t}")
            self.run(q)
            # warnings.filterwarnings("default", category=pymysql.Warning)
            self.run('drop table if exists ' + self._t)
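
    # Hedged note (table name assumed): if the target table defines a primary
    # key, reloading the same rows falls back to the insert-ignore path, so
    # duplicate keys are silently skipped rather than raising, e.g.
    #
    #   sql.load_dataframe('daily', df)    # first load via pandas.to_sql
    #   sql.load_dataframe('daily', df)    # duplicate keys ignored on reload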

    def read_dataframe(self, q: str) -> DataFrame:
        """Return sql query result as a data frame

        Args:
            q: Query string or SQLAlchemy Selectable

        Returns:
            DataFrame of results

        Raises:
            RuntimeError: Failed to run query
        """
        result = self.run(q)
        if result is None:
            raise RuntimeError('(read_dataframe) error in query: ' + str(q))
        return DataFrame(**result)

    def pivot(self, table: str, index: str, columns: str, values: str,
              where: str = '', limit: int | None = None,
              chunksize: int | None = None) -> DataFrame:
        """Return sql query result as a pivoted data frame

        Args:
            table: Physical name of table to retrieve from
            index: Field name to select as dataframe index
            columns: Field name to select as column labels
            values: Field name to select as values
            where: Where clause, optional
            limit: Maximum number of rows (or of chunks, if chunked), optional
            chunksize: To optionally build up results in chunks of this size

        Returns:
            Query result as a pivoted (wide) DataFrame
        """
        if where and not where.strip().upper().startswith('WHERE'):
            where = 'WHERE ' + where
        if isinstance(chunksize, int):    # execute in chunks of index values
            rows = self.read_dataframe(
                f"SELECT DISTINCT {index} FROM {table} {where}")
            rows = np.array(rows[index].astype(str))
            out = DataFrame()
            n_features = len(rows)
            n_splits = n_features // chunksize
            if n_splits * chunksize < n_features:
                n_splits += 1
            for i in range(n_splits):
                row = slice(chunksize * i, min(n_features, chunksize * (i + 1)))
                self._print('slice #', i, 'of', n_splits)
                if isinstance(limit, int) and i >= limit:
                    break
                indexes = "','".join(rows[row])
                clause = where + (" AND " if where else " WHERE ")
                q = (f"SELECT {index}, {columns}, {values} FROM {table} "
                     f" {clause} {index} in ('{indexes}')")
                df = self.read_dataframe(q)
                out = pd.concat([out, df.pivot(index=index,
                                               columns=columns,
                                               values=values)], sort=True)
            return out
        else:    # execute as a single query
            where += (" LIMIT " + str(limit)) if limit else ''
            q = f"SELECT {index}, {columns}, {values} FROM {table} {where}"
            self._print('(pivot)', q)
            return self.read_dataframe(q).pivot(index=index,
                                                columns=columns,
                                                values=values)
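
# Hedged sketch of chunked retrieval with pivot() (table and field names are
# assumed): distinct values of the index field are fetched first, then queried
# in blocks of `chunksize` values and pivoted into one wide frame, e.g.
#
#   wide = sql.pivot('monthly', index='date', columns='permno', values='ret',
#                    where="date >= 20200101", chunksize=500)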
if __name__ == "__main__":
    from secret import credentials
    VERBOSE = 1

    # Create new databases
    SQL.create_database(**credentials['sql'])
    SQL.create_database(**credentials['user'])

    # Show data tables
    sql = SQL(**credentials['sql'], verbose=VERBOSE)
    print(sql.run('show tables'))

    # Show user tables
    user = SQL(**credentials['user'], verbose=VERBOSE)
    print(user.run('show tables'))

    # Test a transaction
    df = DataFrame(data=[[1, 1.5, 'a'], [2, '2.5', None]],
                   columns=['a', 'b', 'c'],
                   index=['d', 'e'])
    user.run('drop table if exists test')
    user.load_dataframe('test', df)
    s = user.run('select * from test')
    print('test:')
    print(DataFrame(**s))
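
    # A hedged extension of the demo above: summary statistics and a pivot on
    # the small 'test' table just created (column names follow that example)
    print(user.summary('test', 'b'))
    print(user.pivot('test', index='a', columns='c', values='b'))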