Reconnect to postgresql & mysql external databases when getting disconnected
This commit is contained in:
		
							parent
							
								
									9cd150a048
								
							
						
					
					
						commit
						dfbc125947
					
				|  | @ -7,6 +7,12 @@ from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR | |||
| import os | ||||
| import logging | ||||
| 
 | ||||
| from peewee_migrate import Router | ||||
| from playhouse.db_url import connect | ||||
| 
 | ||||
| from apps.webui.internal.wrappers import PeeweeConnectionState, register_peewee_databases | ||||
| from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL | ||||
| 
 | ||||
| log = logging.getLogger(__name__) | ||||
| log.setLevel(SRC_LOG_LEVELS["DB"]) | ||||
| 
 | ||||
|  | @ -20,6 +26,8 @@ class JSONField(TextField): | |||
|             return json.loads(value) | ||||
| 
 | ||||
| 
 | ||||
| register_peewee_databases() | ||||
| 
 | ||||
| # Check if the file exists | ||||
| if os.path.exists(f"{DATA_DIR}/ollama.db"): | ||||
|     # Rename the file | ||||
|  | @ -29,6 +37,7 @@ else: | |||
|     pass | ||||
| 
 | ||||
| DB = connect(DATABASE_URL) | ||||
| DB._state = PeeweeConnectionState() | ||||
| log.info(f"Connected to a {DB.__class__.__name__} database.") | ||||
| router = Router( | ||||
|     DB, | ||||
|  |  | |||
|  | @ -0,0 +1,59 @@ | |||
| from contextvars import ContextVar | ||||
| 
 | ||||
| from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError, MySQLDatabase, _ConnectionState | ||||
| from playhouse.db_url import register_database | ||||
| from playhouse.pool import PooledPostgresqlDatabase, PooledMySQLDatabase | ||||
| from playhouse.shortcuts import ReconnectMixin | ||||
| from psycopg2 import OperationalError | ||||
| from psycopg2.errors import InterfaceError | ||||
| 
 | ||||
| 
 | ||||
| db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None} | ||||
| db_state = ContextVar("db_state", default=db_state_default.copy()) | ||||
| 
 | ||||
| 
 | ||||
| class PeeweeConnectionState(_ConnectionState): | ||||
|     def __init__(self, **kwargs): | ||||
|         super().__setattr__("_state", db_state) | ||||
|         super().__init__(**kwargs) | ||||
| 
 | ||||
|     def __setattr__(self, name, value): | ||||
|         self._state.get()[name] = value | ||||
| 
 | ||||
|     def __getattr__(self, name): | ||||
|         return self._state.get()[name] | ||||
| 
 | ||||
| 
 | ||||
| class CustomReconnectMixin(ReconnectMixin): | ||||
|     reconnect_errors = ( | ||||
|         # default ReconnectMixin exceptions (MySQL specific) | ||||
|         *ReconnectMixin.reconnect_errors, | ||||
|         # psycopg2 | ||||
|         (OperationalError, 'termin'), | ||||
|         (InterfaceError, 'closed'), | ||||
|         # peewee | ||||
|         (PeeWeeInterfaceError, 'closed'), | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| class ReconnectingPooledPostgresqlDatabase(CustomReconnectMixin, PooledPostgresqlDatabase): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| class ReconnectingMySQLDatabase(CustomReconnectMixin, MySQLDatabase): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| class ReconnectingPooledMySQLDatabase(CustomReconnectMixin, PooledMySQLDatabase): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| def register_peewee_databases(): | ||||
|     register_database(MySQLDatabase, 'mysql') | ||||
|     register_database(PooledMySQLDatabase, 'mysql+pool') | ||||
|     register_database(ReconnectingPostgresqlDatabase, 'postgres', 'postgresql') | ||||
|     register_database(ReconnectingPooledPostgresqlDatabase, 'postgres+pool', 'postgresql+pool') | ||||
		Loading…
	
		Reference in New Issue