60 lines
2.0 KiB
Python
60 lines
2.0 KiB
Python
|
|
import logging
|
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
|
from sqlalchemy.pool import NullPool
|
|
|
|
class DbConnector:
    """Per-``db_id`` singleton that owns an async SQLAlchemy engine and session factory.

    ``DbConnector(db_id)`` returns the same object every time it is called
    with the same ``db_id``. Call :meth:`init` to configure the connection,
    then :meth:`get_db` to obtain the ``async_sessionmaker``.
    """

    # Registry of shared instances, keyed by db_id.
    _instances = {}

    # Maps the public ``db_type`` name to SQLAlchemy's async "dialect+driver" scheme.
    _PROTOCOLS = {
        'mysql': 'mysql+aiomysql',
        'postgresql': 'postgresql+asyncpg',
    }

    def __new__(cls, db_id="default"):
        # Create the instance only on first use of this db_id; later calls hand
        # back the cached object (which keeps any engine set up by init()).
        if db_id not in cls._instances:
            cls._instances[db_id] = super().__new__(cls)
        return cls._instances[db_id]

    def __init__(self, db_id: str = "default"):
        # NOTE: Python runs __init__ on every DbConnector(...) call, even when
        # __new__ returned a cached instance, so only idempotent setup belongs here.
        self.logger = logging.getLogger("spirent_service_db_connector")

    @staticmethod
    def _build_url(driver, user, password, host, port, db_name, *, mask_password=False):
        """Return a SQLAlchemy URL string with percent-encoded credentials.

        ``user`` and ``password`` are escaped so that characters such as ``@``,
        ``:`` or ``/`` cannot corrupt the URL. With ``mask_password=True`` the
        password is replaced by ``******`` so the result is safe to log.
        """
        from urllib.parse import quote

        secret = "******" if mask_password else quote(str(password), safe="")
        return f"{driver}://{quote(str(user), safe='')}:{secret}@{host}:{port}/{db_name}"

    def init(self, db_type: str, host: str, port: int, db_name: str, user: str, password: str):
        """Create the async engine and session factory for this connector.

        :param db_type: Database kind; one of ``"mysql"`` or ``"postgresql"``.
        :param host: Database server host name or address.
        :param port: Database server port.
        :param db_name: Name of the database to connect to.
        :param user: Login user name (percent-encoded into the URL).
        :param password: Login password (percent-encoded; masked in logs).
        :raises KeyError: if ``db_type`` is not a supported database kind.

        Calling ``init`` again replaces ``engine`` and ``async_session_local``;
        the previous engine is not disposed here.
        """
        self.db_type = db_type
        self.host = host
        self.port = port
        self.name = db_name
        self.user = user
        self.password = password

        try:
            driver = self._PROTOCOLS[db_type]
        except KeyError:
            # Keep the original exception type, but explain what went wrong.
            raise KeyError(
                f"unsupported db_type {db_type!r}; expected one of {sorted(self._PROTOCOLS)}"
            ) from None

        sqlalchemy_database_url = self._build_url(driver, user, password, host, port, db_name)
        # For security reasons the password is never written to logs; the
        # logged scheme is the real driver for this db_type.
        self.logger.debug(
            "Connecting to db %s",
            self._build_url(driver, user, password, host, port, db_name, mask_password=True),
        )

        # NullPool disables connection pooling: connections are opened and
        # closed per use instead of being kept around by the engine.
        self.engine = create_async_engine(sqlalchemy_database_url, poolclass=NullPool)
        self.async_session_local = async_sessionmaker(
            bind=self.engine,
            expire_on_commit=False,
            class_=AsyncSession,
            autoflush=False,
        )

    def get_db(self):
        """Return the ``async_sessionmaker`` created by :meth:`init`.

        :raises AttributeError: if :meth:`init` has not been called yet.
        """
        try:
            return self.async_session_local
        except AttributeError:
            raise AttributeError(
                f"{type(self).__name__}.init() must be called before get_db()"
            ) from None
|
|
|
|
|
|
|