Move get_sqlalchemy_connection to its own file.

This commit is contained in:
Tim Abbott
2016-07-18 23:12:35 -07:00
parent 6a90dc07dc
commit afaac85dc6
4 changed files with 40 additions and 36 deletions

View File

@@ -0,0 +1,36 @@
from django.db import connection
import sqlalchemy
# This is a Pool that doesn't close connections. Therefore it can be used with
# existing Django database connections.
class NonClosingPool(sqlalchemy.pool.NullPool):
def status(self):
return "NonClosingPool"
def _do_return_conn(self, conn):
pass
def recreate(self):
return self.__class__(creator=self._creator, # type: ignore # __class__
recycle=self._recycle,
use_threadlocal=self._use_threadlocal,
reset_on_return=self._reset_on_return,
echo=self.echo,
logging_name=self._orig_logging_name,
_dispatch=self.dispatch)
sqlalchemy_engine = None
def get_sqlalchemy_connection():
global sqlalchemy_engine
if sqlalchemy_engine is None:
def get_dj_conn():
connection.ensure_connection()
return connection.connection
sqlalchemy_engine = sqlalchemy.create_engine('postgresql://',
creator=get_dj_conn,
poolclass=NonClosingPool,
pool_reset_on_return=False)
sa_connection = sqlalchemy_engine.connect()
sa_connection.execution_options(autocommit=False)
return sa_connection

View File

@@ -7,7 +7,7 @@ from django.test.runner import DiscoverRunner
from django.test.signals import template_rendered
from zerver.lib.cache import bounce_key_prefix_for_testing
from zerver.views.messages import get_sqlalchemy_connection
from zerver.lib.sqlalchemy_utils import get_sqlalchemy_connection
from zerver.lib.test_helpers import get_all_templates
import os

View File

@@ -10,12 +10,13 @@ from zerver.models import (
get_display_recipient, get_recipient, get_realm, get_stream, get_user_profile_by_email,
)
from zerver.lib.actions import create_stream_if_needed, do_add_subscription
from zerver.lib.sqlalchemy_utils import get_sqlalchemy_connection
from zerver.lib.test_helpers import (
AuthedTestCase, POSTRequestMock,
get_user_messages, message_ids, queries_captured,
)
from zerver.views.messages import (
exclude_muting_conditions, get_sqlalchemy_connection,
exclude_muting_conditions,
get_old_messages_backend, ok_to_include_history,
NarrowBuilder, BadNarrowOperator
)

View File

@@ -23,6 +23,7 @@ from zerver.lib.actions import recipient_for_emails, do_update_message_flags, \
extract_recipients, truncate_body
from zerver.lib.cache import generic_bulk_cached_fetch
from zerver.lib.response import json_success, json_error
from zerver.lib.sqlalchemy_utils import get_sqlalchemy_connection
from zerver.lib.utils import statsd
from zerver.lib.validator import \
check_list, check_int, check_dict, check_string, check_bool
@@ -34,7 +35,6 @@ from zerver.models import Message, UserProfile, Stream, Subscription, \
resolve_email_to_domain, get_realm, get_active_streams, \
bulk_get_streams
import sqlalchemy
from sqlalchemy import func
from sqlalchemy.sql import select, join, column, literal_column, literal, and_, \
or_, not_, union_all, alias
@@ -46,39 +46,6 @@ import datetime
from six.moves import map
import six
# This is a Pool that doesn't close connections. Therefore it can be used with
# existing Django database connections.
class NonClosingPool(sqlalchemy.pool.NullPool):
def status(self):
return "NonClosingPool"
def _do_return_conn(self, conn):
pass
def recreate(self):
return self.__class__(creator=self._creator, # type: ignore # __class__
recycle=self._recycle,
use_threadlocal=self._use_threadlocal,
reset_on_return=self._reset_on_return,
echo=self.echo,
logging_name=self._orig_logging_name,
_dispatch=self.dispatch)
sqlalchemy_engine = None
def get_sqlalchemy_connection():
global sqlalchemy_engine
if sqlalchemy_engine is None:
def get_dj_conn():
connection.ensure_connection()
return connection.connection
sqlalchemy_engine = sqlalchemy.create_engine('postgresql://',
creator=get_dj_conn,
poolclass=NonClosingPool,
pool_reset_on_return=False)
sa_connection = sqlalchemy_engine.connect()
sa_connection.execution_options(autocommit=False)
return sa_connection
class BadNarrowOperator(JsonableError):
def __init__(self, desc, status_code=400):
self.desc = desc