first working version

This commit is contained in:
marys
2025-05-27 12:10:20 +02:00
parent a3b80eab6e
commit 61cbc536b6
9 changed files with 180 additions and 0 deletions

0
utils/db/__init__.py Normal file
View File

59
utils/db/db.py Normal file
View File

@@ -0,0 +1,59 @@
import logging
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.pool import NullPool
class DbConnector:
_instances = {}
def __new__(cls, db_id="default"):
# print("==================")
# [print(f"{i} -> {id(cls._instances[i])}") for i in cls._instances]
# print("==================")
if db_id not in cls._instances:
cls._instances[db_id] = super().__new__(cls)
return cls._instances[db_id]
def __init__(self, db_id:str="default"):
self.logger = logging.getLogger("spirent_service_db_connector")
def init(self, db_type:str, host:str, port:int, db_name:str, user:str, password:str):
"""
:param db_type: Example: mysql, postgresql, ...
:param host:
:param port:
:param db_name:
:param user:
:param password:
"""
self.db_type = db_type
self.host = host
self.port = port
self.name = db_name
self.user = user
self.password = password
protocol = {
'mysql': 'mysql+aiomysql',
'postgresql': 'postgresql+asyncpg',
}
sqlalchemy_database_url = f"{protocol[self.db_type]}://{user}:{password}@{host}:{port}/{db_name}"
# for security reason, password not print in logs.
sqlalchemy_database_url_log = f"mysql://{user}:******@{host}:{port}/{db_name}"
self.logger.debug(f"Connecting to db {sqlalchemy_database_url_log}")
self.engine = create_async_engine(sqlalchemy_database_url, poolclass=NullPool)
self.async_session_local = async_sessionmaker(bind=self.engine,
expire_on_commit=False,
class_=AsyncSession,
autoflush=False,
)
def get_db(self):
return self.async_session_local

View File

@@ -0,0 +1,44 @@
from utils.db.db import DbConnector
from models.model_mtr_network_nodes import MtrAdditionalNodesModel
from sqlalchemy import select
class MtrAdditionalNodeQuery:
def __init__(self):
db_con = DbConnector("spirent_mysql")
self.db = db_con.get_db()
async def get_all_nodes(self) -> list:
"""
Retrieve all entries from the mtr_network_nodes database table.
"""
async with self.db() as session:
result = await session.execute(select(MtrAdditionalNodesModel))
all_nodes = result.scalars().all()
#await session.close()
return all_nodes
def get_all_nodes_for_testing(self) -> list:
"""
Retrieve all entries from the mtr_network_nodes database table, exclude rows used for storing parameters for
testing TAS servers.
"""
all_nodes = (self.db.query(MtrAdditionalNodesModel)
.filter(MtrAdditionalNodesModel.test_servers == False)
.filter(MtrAdditionalNodesModel.enabled == True)
.all())
return all_nodes
def get_mtr_config_for_servers(self) -> list:
all_nodes = (self.db.query(MtrAdditionalNodesModel)
.filter(MtrAdditionalNodesModel.test_servers == True)
.filter(MtrAdditionalNodesModel.enabled == True)
.all())
return all_nodes
def get_node_by_id(self, node_id: int) -> MtrAdditionalNodesModel:
return self.db.query(MtrAdditionalNodesModel).filter_by(id=node_id).one()