Skip to content

Instantly share code, notes, and snippets.

@dicknetherlands
Last active February 3, 2025 17:38
Show Gist options
  • Save dicknetherlands/2f6e8619409fa155a05b3a863f10269a to your computer and use it in GitHub Desktop.
Save dicknetherlands/2f6e8619409fa155a05b3a863f10269a to your computer and use it in GitHub Desktop.
Graphene speedup
# This gist shows how to speed up graphene_django by short-cutting the field and type resolution of the returned JSON and
# using a bit of caching to avoid having to repeat our discovery/decision process across multiple fields of the same
# type. It relies on trusting the developer to always return the correct types and respect non-nullability.
# Assumes that:
# 1. You're not using async code
# 2. You're using graphql_sync_dataloader to solve the N+1 problem
# 3. You're not using any graphene middleware other than for authentication
# 4. Your resolvers will all respect the schema
# 5. See code comments for further limitations
# Import the view and execution context defined below, then reference them in urls.py:
urlpatterns = [
    path(
        "gql",
        StreamlinedGraphQLView.as_view(
            graphiql=True,
            schema=graphene.Schema(query=MyQueryObject, mutation=MyMutationObject),
            # Swap in the streamlined execution context for every request handled by this view.
            execution_context_class=StreamlinedExecutionContext,
        ),  # this comma was missing, which made the `name=` keyword below a SyntaxError
        name="gql",
    ),
]
# Replace graphql-core's schema validator with a stub that always reports "no errors", so that
# no time is spent validating the schema. This saves a bit of time on large or complex schemas,
# but an invalid schema will no longer be detected - you have to trust the schema to be valid.
graphql.type.validate.validate_schema = lambda x: []
# the streamlined view
class StreamlinedGraphQLView(GraphQLView):
    """GraphQL view that authenticates once per request and can JSON-encode SyncFuture values."""

    def get_context(self, request) -> WSGIRequest:
        """
        Build the per-request context and attach the authenticated user to it.

        StreamlinedExecutionContext bypasses all Graphene middleware for efficiency, and that
        includes authentication middleware such as graphql_jwt. The user is therefore injected
        here, once per request, instead of once per field. If your authentication middleware does
        anything more than inject the user object, that extra logic has to be added here too.
        """
        context = super().get_context(request)
        user = authenticate(request=context)
        # Fall back to an anonymous user when authentication yields nothing.
        context.user = user if user else AnonymousUser()
        return context

    def json_encode(self, request, d, pretty=False) -> str:
        """
        Serialize the response payload, resolving any SyncFuture objects left in it.

        The StreamlinedExecutionContext may leave non-serializable SyncFuture objects in the
        result, so the fallback encoder simply calls ``.result()`` on whatever the JSON encoder
        cannot handle. No ``isinstance(value, SyncFuture)`` check is made: we trust that a
        SyncFuture is the only non-standard type that can appear in the response.
        """
        # Pretty output is requested by the view setting, the argument or the query string.
        wants_pretty = self.pretty or pretty or request.GET.get("pretty")
        if wants_pretty:
            dump_options = {"sort_keys": True, "indent": 2, "separators": (",", ": ")}
        else:
            dump_options = {"separators": (",", ":")}
        return json.dumps(d, default=lambda future: future.result(), **dump_options)
# this enum is used to help classify fields and what to do with each of them
class PathFieldType(Enum):
    """Classification cached per response path to decide how a field is read or completed."""

    # How the default field resolver reads a field from its parent value.
    DICT_KEY = 0  # parent is a Mapping; the value is used as-is
    CALLABLE_DICT_KEY = 1  # parent is a Mapping; the value is called with (info, **args)
    OBJECT_PROPERTY = 2  # parent is an object; the attribute is used as-is
    CALLABLE_OBJECT_PROPERTY = 3  # parent is an object; the attribute is called with (info, **args)

    # How a resolved value is turned into response data.
    LIST = 4  # list type whose value is not a SyncFuture
    OBJECT = 5  # object type whose value is not a SyncFuture
    RAW_SCALAR = 6  # leaf type whose value is an int/float/bool/str and is returned untouched
    SERIALIZE_SCALAR = 7  # leaf type whose value goes through the type's serialize()
    SYNC_FUTURE_OBJECT = 8  # object type whose value is a SyncFuture
    SYNC_FUTURE_LIST = 9  # list type whose value is a SyncFuture
# the streamlined context
class StreamlinedExecutionContext(ExecutionContext):
"""
ExecutionContext that knows how to work with SyncFuture objects and also speeds up field resolution
(by reusing ideas from graphql_sync_dataloader.DeferredExecutionContext).
Key differences between this context and the default ExecutionContext:
1. no null checks on non-null fields
2. abstract types are not supported
3. middleware is not supported
4. async execution is not supported
5. no type checking
6. returning an Exception as the value for a field is not supported; Exceptions must be raised
7. no lists of lists
8. serialize() is not called on any int/float/bool/str values
9. if the first non-null field on an object of type X is of type Y, assume the same field on all other objects
of type X will also be of type Y
10. individual scalars cannot be SyncFutures
11. all items in a list must be of the same type
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._deferred_callbacks: list[Callable] = []
self._cached_path_types: dict[str, PathFieldType] = {}
self._cached_nullable_types: set[str] = set()
def _unindexed_path(self, path: Path) -> str:
# Get a path that ignores list indices, so that in the cache we can assume all
# objects that appear in the same list share the same type
return ".".join(p for p in path.as_list() if not isinstance(p, int))
def default_field_resolver(
self, path_key: str, source: Any, info: GraphQLResolveInfo, **args: Any
) -> Any:
field_name = info.field_name
if path_key not in self._cached_path_types:
if isinstance(source, Mapping):
if callable(source.get(field_name)):
self._cached_path_types[path_key] = PathFieldType.CALLABLE_DICT_KEY
else:
self._cached_path_types[path_key] = PathFieldType.DICT_KEY
else:
if callable(getattr(source, field_name, None)):
self._cached_path_types[path_key] = PathFieldType.CALLABLE_OBJECT_PROPERTY
else:
self._cached_path_types[path_key] = PathFieldType.OBJECT_PROPERTY
match self._cached_path_types[path_key]:
case PathFieldType.DICT_KEY:
return source.get(field_name)
case PathFieldType.OBJECT_PROPERTY:
return getattr(source, field_name, None)
case PathFieldType.CALLABLE_OBJECT_PROPERTY:
return getattr(source, field_name)(info, **args)
case PathFieldType.CALLABLE_DICT_KEY:
return source.get(field_name)(info, **args)
case _:
raise NotImplementedError()
def execute_operation(
self, operation: OperationDefinitionNode, root_value: Any
) -> Optional[AwaitableOrValue[Any]]:
self._cached_path_types.clear()
self._cached_nullable_types.clear()
self._deferred_callbacks.clear()
result = super().execute_operation(operation, root_value)
while self._deferred_callbacks:
self._deferred_callbacks.pop(0)()
return result
def execute_field(
self,
parent_type: GraphQLObjectType,
source: Any,
field_nodes: list[FieldNode],
path: Path,
) -> AwaitableOrValue[Any]:
field_def = get_field_def(self.schema, parent_type, field_nodes[0])
if not field_def:
return Undefined
return_type = field_def.type
path_key = self._unindexed_path(path)
resolve_fn = field_def.resolve or partial(self.default_field_resolver, path_key)
info = self.build_resolve_info(field_def, field_nodes, parent_type, path)
try:
args = get_argument_values(field_def, field_nodes[0], self.variable_values)
result = resolve_fn(source, info, **args)
return self.streamlined_complete_value(return_type, field_nodes, path, path_key, result)
except Exception as raw_error:
error = located_error(raw_error, field_nodes, path.as_list())
return self.handle_field_error(error, return_type)
def _get_path_key_type(
self, path_key: str, return_type: GraphQLOutputType, result: Any
) -> tuple[PathFieldType, GraphQLOutputType]:
if path_key not in self._cached_path_types:
# Unwrap non-null - we don't care about nullability
check_return_type = return_type
if is_non_null_type(return_type):
check_return_type = return_type.of_type
self._cached_nullable_types.add(path_key)
# Detect and cache actual path type
if is_list_type(check_return_type):
if isinstance(result, SyncFuture):
self._cached_path_types[path_key] = PathFieldType.SYNC_FUTURE_LIST
else:
self._cached_path_types[path_key] = PathFieldType.LIST
elif is_leaf_type(check_return_type):
if isinstance(result, (int, float, bool, str)):
self._cached_path_types[path_key] = PathFieldType.RAW_SCALAR
else:
self._cached_path_types[path_key] = PathFieldType.SERIALIZE_SCALAR
elif is_object_type(check_return_type):
if isinstance(result, SyncFuture):
self._cached_path_types[path_key] = PathFieldType.SYNC_FUTURE_OBJECT
else:
self._cached_path_types[path_key] = PathFieldType.OBJECT
else:
raise NotImplementedError()
if path_key in self._cached_nullable_types:
return_type = return_type.of_type
return self._cached_path_types[path_key], return_type
def _defer_syncfuture(
self,
return_type: GraphQLOutputType,
field_nodes: list[FieldNode],
path: Path,
path_type: PathFieldType,
result: SyncFuture,
):
if result.deferred_callback is not None:
self._deferred_callbacks.append(result.deferred_callback)
# noinspection PyShadowingNames
def _process_result():
try:
future.set_result(self._resolve_syncfuture(return_type, field_nodes, path, path_type, result))
except Exception as raw_error:
error = located_error(raw_error, field_nodes, path.as_list())
future.set_result(self.handle_field_error(error, return_type))
future = SyncFuture()
result.add_done_callback(_process_result)
return future
def _resolve_syncfuture(
self,
return_type: GraphQLOutputType,
field_nodes: list[FieldNode],
path: Path,
path_type: PathFieldType,
result: SyncFuture,
) -> Any:
if path_type == PathFieldType.SYNC_FUTURE_LIST:
return [
self.execute_fields(
return_type.of_type,
item.result(),
path.add_key(index, None),
self.collect_subfields(return_type.of_type, field_nodes),
)
if item.done()
else self._defer_syncfuture(
return_type.of_type,
field_nodes,
path.add_key(index, None),
PathFieldType.SYNC_FUTURE_OBJECT,
item,
)
for index, item in enumerate(result.result())
]
elif path_type == PathFieldType.SYNC_FUTURE_OBJECT:
return self.execute_fields(
return_type, result.result(), path, self.collect_subfields(return_type, field_nodes)
)
else:
raise NotImplementedError()
def _process_result(
self,
return_type: GraphQLOutputType,
field_nodes: list[FieldNode],
path: Path,
path_type: PathFieldType,
result: Any,
) -> Any:
if result is None or result is Undefined:
return None
match path_type:
case PathFieldType.RAW_SCALAR:
return result
case PathFieldType.SERIALIZE_SCALAR:
return return_type.serialize(result)
case PathFieldType.OBJECT:
return self.execute_fields(
return_type, result, path, self.collect_subfields(return_type, field_nodes)
)
case PathFieldType.SYNC_FUTURE_OBJECT | PathFieldType.SYNC_FUTURE_LIST:
if result.done():
return self._resolve_syncfuture(return_type, field_nodes, path, path_type, result)
else:
return self._defer_syncfuture(return_type, field_nodes, path, path_type, result)
case _:
raise NotImplementedError()
def streamlined_complete_value(
self,
return_type: GraphQLOutputType,
field_nodes: list[FieldNode],
path: Path,
path_key: str,
result: Any,
) -> AwaitableOrValue[Any]:
path_type, return_type = self._get_path_key_type(path_key, return_type, result)
# Depending on path type, return the result
if path_type == PathFieldType.LIST:
# Shortcut list generation to prevent excessive rediscovery of types by assuming
# all items in list are of the same type, so that we can use one path key for all
results = []
if result:
item_path_key = f"{path_key}_LIST_ITEM"
item_path_type, item_return_type = None, None
for index, item in enumerate(result):
if item is not None and item_return_type is None:
item_path_type, item_return_type = self._get_path_key_type(
item_path_key, return_type.of_type, item
)
results.append(
self._process_result(
item_return_type,
field_nodes,
path.add_key(index, None),
item_path_type,
item,
)
)
return results
return self._process_result(return_type, field_nodes, path, path_type, result)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment