@hest
Created February 4, 2014 06:08
Fast SQLAlchemy counting (avoid query.count() subquery)
from sqlalchemy import func


def get_count(q):
    # List argument is the pre-1.4 calling style; 2.0 takes columns positionally.
    count_q = q.statement.with_only_columns([func.count()]).order_by(None)
    count = q.session.execute(count_q).scalar()
    return count


q = session.query(TestModel).filter(...).order_by(...)

# Slow: SELECT COUNT(*) FROM (SELECT ... FROM TestModel WHERE ...) ...
print(q.count())

# Fast: SELECT COUNT(*) FROM TestModel WHERE ...
print(get_count(q))
@zt50tz

zt50tz commented Nov 8, 2020

Hi all!

If you use ORM queries with a model as the first argument, the other examples in this thread work well.

But in my project I use many queries like:
db.session.query(User.id, UserInfo.name).outerjoin(UserInfo, UserInfo.user_id == User.id)

The previous examples cannot be used with them.

So I wrote a class that processes the query and removes the joins that the count query does not need.
It checks the query's joins and WHERE conditions, then removes the unnecessary joins.

Usage:

q = db.session.query(
    User.id, User.nname, UserInfo.theme_id
).outerjoin(
    UserInfo, UserInfo.user_id == User.id
).filter(
    User.id > 0,
    UserInfo.name.like('123')
)

# where=True: inspect the WHERE clause; a join is kept if its table is referenced in filter()
# fields=False: ignore the selected columns; a join may be dropped even if its table appears in them
count_q = DBRemoveJoin(q, fields=False, where=True, debug=False).process()
count_q = count_q.statement.with_only_columns([func.count('*')]).order_by(None)

print(count_q)

Output:

SELECT count(:count_2) AS count_1
FROM "User" LEFT OUTER JOIN "UserInfo" ON "UserInfo".user_id = "User".id
WHERE "User".id > :id_1 AND "UserInfo".name LIKE :name_1

If you comment out the UserInfo filter condition:

q = db.session.query(
    User.id, User.nname, UserInfo.theme_id
).outerjoin(
    UserInfo, UserInfo.user_id == User.id
).filter(
    User.id > 0,
    # UserInfo.name.like('123')
)

The result will be:

SELECT count(:count_2) AS count_1
FROM "User"
WHERE "User".id > :id_1
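
One caveat worth checking before dropping a join for a count: removing an unreferenced LEFT OUTER JOIN leaves the count unchanged only when the joined table has at most one row per parent row. The sketch below illustrates this with the standard-library sqlite3 module rather than SQLAlchemy, so it runs without extra dependencies; the table layout (`user`, `user_info`) is a simplified assumption modeled on the example above, not the author's actual schema.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE user (id INTEGER PRIMARY KEY, nname TEXT);
    CREATE TABLE user_info (id INTEGER PRIMARY KEY, user_id INTEGER, name TEXT);
    INSERT INTO user (id, nname) VALUES (1, 'a'), (2, 'b');
    -- exactly one user_info row per user: the join cannot multiply rows
    INSERT INTO user_info (user_id, name) VALUES (1, 'x'), (2, 'y');
""")

WITH_JOIN = """
    SELECT count(*) FROM user
    LEFT OUTER JOIN user_info ON user_info.user_id = user.id
    WHERE user.id > 0
"""
WITHOUT_JOIN = "SELECT count(*) FROM user WHERE user.id > 0"


def count(sql):
    """Run a single-value count query and return the number."""
    return conn.execute(sql).fetchone()[0]


# One-to-one data: both forms agree.
one_to_one = (count(WITH_JOIN), count(WITHOUT_JOIN))

# A second user_info row for user 1 makes the joined count diverge.
conn.execute("INSERT INTO user_info (user_id, name) VALUES (1, 'z')")
one_to_many = (count(WITH_JOIN), count(WITHOUT_JOIN))

print(one_to_one, one_to_many)
```

If the relationship can be one-to-many, the two counts answer different questions (joined rows versus parent rows), so it is worth deciding which one the caller actually needs.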

Class code:


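from collections import OrderedDict

try:
    from collections.abc import Iterable
except ImportError:  # Python 2
    from collections import Iterable

import six
from sqlalchemy import Table, Column
from sqlalchemy.orm.util import _ORMJoin, outerjoin


def iterable(arg):
    # collections.Iterable was removed in Python 3.10, hence the import above
    return (
        isinstance(arg, Iterable)
        and not isinstance(arg, six.string_types)
    )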


class DBRemoveJoin:
    q = None

    fields = True
    where = True
    debug = False

    table_main = None
    table_need = []
    table_joins = {}

    column_items = []
    where_items = []

    def __init__(self, q, fields=True, where=True, debug=False):
        self.fields = fields
        self.where = where
        self.debug = debug
        self.q = q.filter()
        self.table_need = []
        self.column_items = []
        self.where_items = []
        self.table_joins = OrderedDict()

    @staticmethod
    def clause_columns_process(clause):
        ret = []
        column_objs = []
        if hasattr(clause, 'left'):
            column_objs.append(clause.left)
        if hasattr(clause, 'right'):
            column_objs.append(clause.right)
        if column_objs:
            for column_obj in column_objs:
                if isinstance(column_obj, Column):
                    where_item = {
                        'name': column_obj.name,
                        'column': column_obj,
                        'table_name': str(column_obj.table.name)
                    }
                    ret.append(where_item)
        else:
            for tbl in clause._from_objects:
                where_item = {
                    'name': None,
                    'column': None,
                    'table_name': str(tbl.name)
                }
                ret.append(where_item)
        return ret

    @staticmethod
    def clause_columns(clauses):
        ret = []
        if not iterable(clauses):
            clauses = [clauses]
        for clause in clauses:
            ret += DBRemoveJoin.clause_columns_process(clause)
        return ret
            
    def table_joins_process_join(self, table):
        if isinstance(table.left, Table):
            self.table_main = table.left
        else:
            self.table_joins_process_join(table.left)

        right_str = str(table.right)
        if not right_str:
            right_str = str(table.right.selectable.name)

        right_el = {
            'name': right_str,
            'table': table.right,
            'onclause': table.onclause.expression,
            'table_need': [],
            'need': False
        }

        clause_columns = self.clause_columns(table.onclause)
        # Compare names with names: table_main is a Table object, not a string
        clause_columns_skip_table_name = [right_el['name'], str(self.table_main.name)]
        for clause_column in clause_columns:
            if clause_column['table_name'] not in clause_columns_skip_table_name:
                right_el['table_need'].append(clause_column['table_name'])

        self.table_joins[right_str] = right_el

    def table_joins_process(self):
        for table in self.q._from_obj:
            if isinstance(table, _ORMJoin):
                self.table_joins_process_join(table)
            else:
                pass

    def table_need_add(self, table_name):
        if table_name in self.table_need:
            return
        self.table_need.append(table_name)
        table_join = self.table_joins.get(table_name)
        if table_join:
            table_join['need'] = True
            for sub_table_name in table_join['table_need']:
                self.table_need_add(sub_table_name)

    def fields_process(self):
        if not self.fields:
            return
        for column in self.q.statement.columns:
            column_obj = list(column.base_columns)[0]
            column_el = {
                'name': column.name,
                'column': column_obj,
                'table_name': str(column_obj.table.name)
            }
            self.column_items.append(column_el)
            self.table_need_add(column_el['table_name'])

    def where_process_item(self, clause):
        if clause is None:
            return
        if iterable(clause):
            for el in clause:
                self.where_process_item(el)
            return
        items = self.clause_columns(clause)
        for item in items:
            self.table_need_add(item['table_name'])

    def where_process(self):
        if not self.where:
            return
        if not iterable(self.q.whereclause):
            clauses = self.q.whereclause
        else:
            clauses = self.q.whereclause.clauses
        self.where_process_item(clauses)

    def query_process(self):
        if not self.table_joins:
            return

        self.q._from_obj = (self.table_main, )
        for table in self.table_joins.values():
            if not table['need']:
                continue
            self.q._from_obj = (outerjoin(self.q._from_obj[0], table['table'], table['onclause']), )

    def process(self):
        self.table_joins_process()
        self.fields_process()
        self.where_process()
        self.query_process()
        return self.q

Keep in mind that this is code from a hobby project.

Sorry for my English (=

@ivcuello

ivcuello commented Jan 5, 2021

If anyone is having issues with the ORM "dropping" the FROM clause altogether: in my case it was solved by referencing a column in the func.count() call. Instead of using a literal like func.count(1), reference a column (func.count(Table.id)) and the ORM won't delete the FROM clause from the resulting query.
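
A dropped FROM clause is easy to miss because the query still succeeds and returns a number. The following sketch shows the effect at the SQL level using the standard-library sqlite3 module (not SQLAlchemy); the table `t` is hypothetical and exists only for illustration.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY)")
conn.executemany("INSERT INTO t (id) VALUES (?)", [(i,) for i in range(1, 6)])

# Counting a column of the table: the FROM clause is present.
with_from = conn.execute("SELECT count(id) FROM t").fetchone()[0]

# Counting a literal with no FROM clause: valid SQL, but it counts
# a single implicit row instead of the table's rows.
without_from = conn.execute("SELECT count(1)").fetchone()[0]

print(with_from, without_from)
```

Printing the compiled statement before executing it is a cheap way to confirm that the FROM clause survived.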

@kigawas

kigawas commented Apr 9, 2021

Nice function! FYI it does not consider the query limit.

So I extended the function for my own use.

count = q.session.execute(count_q).scalar()
return min(q._limit, count) if q._limit else count

THIS METHOD DOES NOT APPLY TO QUERIES WITH LIMIT AND OFFSET

What if I want to know the total count? Normally in an API's pagination, (page, page_size, total_count) is needed.

Say

SELECT  count(*) AS total_count
FROM user
WHERE user.is_deleted = false
 LIMIT 20 OFFSET 5;

If you execute this in a DB, it'll return nothing, which is None in Python. In fact, the count_q cannot include limit and offset:

image

Although you can work around like this, it still isn't ideal due to overheads.

Solution

Avoid using limit and offset in the count query.
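
The behaviour described above can be reproduced with the standard-library sqlite3 module. This is a sketch under assumptions: the `user` table and its 30 rows are made up for the demonstration, and `is_deleted = 0` stands in for `is_deleted = false`. It shows the count query returning no row once LIMIT/OFFSET is attached, and the usual pagination pattern of one unpaginated count plus one paginated page query.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE user (id INTEGER PRIMARY KEY, is_deleted INTEGER)")
conn.executemany(
    "INSERT INTO user (id, is_deleted) VALUES (?, 0)",
    [(i,) for i in range(1, 31)],
)

page_size, offset = 20, 5

# The aggregate yields exactly one row; OFFSET 5 skips past it.
bad = conn.execute(
    "SELECT count(*) FROM user WHERE is_deleted = 0 LIMIT ? OFFSET ?",
    (page_size, offset),
).fetchone()

# Total for the pagination response: same filter, no LIMIT/OFFSET.
total_count = conn.execute(
    "SELECT count(*) FROM user WHERE is_deleted = 0"
).fetchone()[0]

# The page itself: LIMIT/OFFSET belong only here.
page = conn.execute(
    "SELECT id FROM user WHERE is_deleted = 0 ORDER BY id LIMIT ? OFFSET ?",
    (page_size, offset),
).fetchall()

print(bad, total_count, len(page))
```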

@transfluxus

Hi,
does anybody have an idea how this looks in SQLAlchemy 1.4?
I am getting:

sqlalchemy.exc.ArgumentError: Query has only expression-based entities - can't find property named "template".

Here "template" is something from my DB, I guess.

@kigawas

kigawas commented Oct 25, 2021

⚠️ YOU DO NOT NEED THIS ⚠️

https://dba.stackexchange.com/questions/168022/performance-of-count-in-subquery

⚠️ IF YOUR QUERY IS SLOW, YOU SHOULD REMOVE ORDER BY FIRST ⚠️

count_stmt = select(func.count()).select_from(statement.order_by(None).subquery())
await session.scalar(count_stmt)

@davidjb99

davidjb99 commented Nov 8, 2021

This worked perfectly for me after I upgraded from SQLAlchemy 1.3 to 1.4 and a count query went from 80ms to 800ms.

By switching from the built-in .count() to the function in the original gist, the query went back to 80ms. I believe the problem was .count() loading all columns into Python, which is not required and is very slow for thousands of rows; using with_only_columns and removing the subquery took it back to 80ms. No idea what broke it in 1.4.

    count_q = q.statement.with_only_columns([func.count()]).order_by(None)
    count = q.session.execute(count_q).scalar()

I'm not sure why people are putting warnings on this gist or writing very long replies expecting help. If you are stuck, ask on Stack Overflow!

My thanks to @hest

@kigawas

kigawas commented Nov 9, 2021

This worked perfectly for me after I upgraded from SQLAlchemy 1.3 to 1.4 and a count query went from 80ms to 800ms.

By switching from the built-in .count() to the function in the original gist, the query went back to 80ms. I believe the problem was .count() loading all columns into Python, which is not required and is very slow for thousands of rows; using with_only_columns and removing the subquery took it back to 80ms. No idea what broke it in 1.4.

    count_q = q.statement.with_only_columns([func.count()]).order_by(None)
    count = q.session.execute(count_q).scalar()

I'm not sure why people are putting warnings on this gist or writing very long replies expecting help. If you are stuck, ask on Stack Overflow!

My thanks to @hest

This thread is, in general, a kind of misinformation. If you check the Stack Exchange link above, these two SQL statements are exactly the same:

-- on postgres
EXPLAIN ANALYZE SELECT COUNT(*) FROM some_big_table WHERE some_col = 'some_val';
EXPLAIN ANALYZE SELECT COUNT(*) FROM ( SELECT col1, col2, col3, col4 FROM some_big_table WHERE some_col = 'some_val' ) AS sub;

If your query is slow, the first thing to try is removing ORDER BY.
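
The two query shapes can be compared on a small dataset with the standard-library sqlite3 module. This sketch only confirms that both forms return the same number, including when the inner query carries an ORDER BY; it says nothing about their relative cost, which depends on the database's planner and is what EXPLAIN ANALYZE is for. The table contents are made up for the demonstration.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE some_big_table (col1 INTEGER, col2 TEXT, some_col TEXT)"
)
conn.executemany(
    "INSERT INTO some_big_table (col1, col2, some_col) VALUES (?, ?, ?)",
    [(i, str(i), "some_val" if i % 4 == 0 else "other") for i in range(100)],
)

# Direct count on the filtered table.
direct = conn.execute(
    "SELECT count(*) FROM some_big_table WHERE some_col = ?",
    ("some_val",),
).fetchone()[0]

# Count over a subquery that also selects columns and sorts them.
wrapped = conn.execute(
    "SELECT count(*) FROM ("
    "  SELECT col1, col2 FROM some_big_table"
    "  WHERE some_col = ? ORDER BY col2"
    ") AS sub",
    ("some_val",),
).fetchone()[0]

print(direct, wrapped)
```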

@davidjb99

If your query is slow, the first thing to try is removing ORDER BY.

I tried this, and the methods outlined in the Stack Exchange link, and it did not work. Please don't say I'm providing misinformation by outlining what worked for me; it is rude.

@kigawas

kigawas commented Dec 20, 2021

@davidjb99

Sorry for the misunderstanding. I meant that the gist (which was posted in 2014) is likely misinformation now, not your comment.

@kannasuresh99

This is the method I'm using:

def get_count(self, model_fields, filter_clause):
    """Note: filter_clause must not be None (NULL) for this method to work."""
    query = self.session.query().with_entities(*model_fields)
    query = query.filter(filter_clause)
    count_query = (
        query.statement
        .with_only_columns([func.count()])
        .order_by(None)
    )
    result = query.session.execute(count_query).scalar()
    return result

@saulin18

saulin18 commented Apr 3, 2026

use over()
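
This presumably refers to a window function: `count(*) OVER ()` attaches the total number of filtered rows to every row of the page, so one query returns both. A minimal sketch with the standard-library sqlite3 module follows, assuming the bundled SQLite is 3.25 or newer (window function support); the `user` table and its data are made up for illustration.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE user (id INTEGER PRIMARY KEY, is_deleted INTEGER)")
# ids 1..10; odd ids are marked deleted, leaving 5 live rows
conn.executemany(
    "INSERT INTO user (id, is_deleted) VALUES (?, ?)",
    [(i, i % 2) for i in range(1, 11)],
)

# The window is evaluated over all filtered rows before LIMIT/OFFSET
# trims the page, so total_count reflects the full filtered set.
rows = conn.execute(
    """
    SELECT id, count(*) OVER () AS total_count
    FROM user
    WHERE is_deleted = 0
    ORDER BY id
    LIMIT 2 OFFSET 1
    """
).fetchall()

print(rows)
```

Note that when the offset runs past the last row the page is empty, and the total is then not returned at all.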
