Last active
February 3, 2025 17:38
-
-
Save dicknetherlands/2f6e8619409fa155a05b3a863f10269a to your computer and use it in GitHub Desktop.
Graphene speedup
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This gist shows how to speed up graphene_django by short-cutting the field and type resolution of the returned
# JSON, and by caching each discovery/decision so that it is not repeated for every field of the same type.
# It relies on trusting the developer to always return the correct types and to respect non-nullability.
# Assumes that:
# 1. You're not using async code
# 2. You're using graphql_sync_dataloader to solve the N+1 problem
# 3. You're not using any graphene middleware other than for authentication
# 4. Your resolvers all respect the schema
# 5. See the code comments for further limitations
# Import the view and execution context defined below, then reference them in urls.py:
urlpatterns = [
    path(
        "gql",
        StreamlinedGraphQLView.as_view(
            graphiql=True,
            schema=graphene.Schema(query=MyQueryObject, mutation=MyMutationObject),
            # Swap in the streamlined execution context for every request served by this view.
            execution_context_class=StreamlinedExecutionContext,
        ),  # this comma is required: without it the following keyword argument is a syntax error
        name="gql",
    ),
]
# Swap graphql-core's schema validator for a stub that always reports "no errors", so the validation work is
# never performed while serving queries. This saves a bit of time on large or complex schemas, at the price of
# having to trust that the schema is valid.
graphql.type.validate.validate_schema = lambda x: []
# the streamlined view | |
class StreamlinedGraphQLView(GraphQLView):
    """GraphQLView variant that authenticates once per request and can JSON-encode SyncFuture values."""

    def get_context(self, request) -> WSGIRequest:
        """
        Build the request context and attach the authenticated user to it.

        StreamlinedExecutionContext bypasses all Graphene middleware for efficiency, and that includes
        authentication middleware such as graphql_jwt. Authentication is therefore performed here, once per
        request, rather than once per field. If your authentication middleware does more than inject the user
        object, that extra logic has to be reproduced in this method.
        """
        ctx = super().get_context(request)
        authenticated_user = authenticate(request=ctx)
        # Fall back to an anonymous user when authentication yields nothing.
        ctx.user = authenticated_user or AnonymousUser()
        return ctx

    def json_encode(self, request, d, pretty=False) -> str:
        """
        Serialise the response payload, resolving any SyncFuture left inside it.

        The StreamlinedExecutionContext may leave non-serialisable SyncFuture objects in the result. No
        isinstance check is made before calling ``result()``: SyncFuture is assumed to be the only non-standard
        type present, which holds as long as developers have not introduced anything else.
        """

        def _unwrap(value):
            # json.dumps calls this for every object it cannot serialise natively.
            return value.result()

        # Same evaluation order as the view default: the query string is consulted only when neither the view
        # setting nor the argument already asked for pretty output.
        wants_pretty = self.pretty or pretty or request.GET.get("pretty")
        if wants_pretty:
            return json.dumps(d, sort_keys=True, indent=2, separators=(",", ": "), default=_unwrap)
        return json.dumps(d, separators=(",", ":"), default=_unwrap)
# this enum is used to help classify fields and what to do with each of them | |
class PathFieldType(Enum):
    """Cached classification of a response path: how to read a field, or how to complete its value."""

    # How the default resolver fetches a field from its source object.
    DICT_KEY = 0  # source is a mapping; the entry is a plain value
    CALLABLE_DICT_KEY = 1  # source is a mapping; the entry is callable
    OBJECT_PROPERTY = 2  # source is an object; the attribute is a plain value
    CALLABLE_OBJECT_PROPERTY = 3  # source is an object; the attribute is callable

    # How a resolved value is turned into response data.
    LIST = 4  # list-typed field with a directly returned value
    OBJECT = 5  # object-typed field with a directly returned value
    RAW_SCALAR = 6  # leaf value that is already an int/float/bool/str
    SERIALIZE_SCALAR = 7  # leaf value that must go through serialize()
    SYNC_FUTURE_OBJECT = 8  # object-typed field returned as a SyncFuture
    SYNC_FUTURE_LIST = 9  # list-typed field returned as a SyncFuture
# the streamlined context | |
class StreamlinedExecutionContext(ExecutionContext): | |
""" | |
ExecutionContext that knows how to work with SyncFuture objects and also speeds up field resolution | |
(by reusing ideas from graphql_sync_dataloader.DeferredExecutionContext). | |
Key differences between this context and the default ExecutionContext: | |
1. no null checks on non-null fields | |
2. abstract types are not supported | |
3. middleware is not supported | |
4. async execution is not supported | |
5. no type checking | |
6. returning an Exception as the value for a field is not supported; Exceptions must be raised | |
7. no lists of lists | |
8. serialize() is not called on any int/float/bool/str values | |
9. if the first non-null field on an object of type X is of type Y, assume the same field on all other objects | |
of type X will also be of type Y | |
10. individual scalars cannot be SyncFutures | |
11. all items in a list must be of the same type | |
""" | |
def __init__(self, *args, **kwargs): | |
super().__init__(*args, **kwargs) | |
self._deferred_callbacks: list[Callable] = [] | |
self._cached_path_types: dict[str, PathFieldType] = {} | |
self._cached_nullable_types: set[str] = set() | |
def _unindexed_path(self, path: Path) -> str: | |
# Get a path that ignores list indices, so that in the cache we can assume all | |
# objects that appear in the same list share the same type | |
return ".".join(p for p in path.as_list() if not isinstance(p, int)) | |
def default_field_resolver( | |
self, path_key: str, source: Any, info: GraphQLResolveInfo, **args: Any | |
) -> Any: | |
field_name = info.field_name | |
if path_key not in self._cached_path_types: | |
if isinstance(source, Mapping): | |
if callable(source.get(field_name)): | |
self._cached_path_types[path_key] = PathFieldType.CALLABLE_DICT_KEY | |
else: | |
self._cached_path_types[path_key] = PathFieldType.DICT_KEY | |
else: | |
if callable(getattr(source, field_name, None)): | |
self._cached_path_types[path_key] = PathFieldType.CALLABLE_OBJECT_PROPERTY | |
else: | |
self._cached_path_types[path_key] = PathFieldType.OBJECT_PROPERTY | |
match self._cached_path_types[path_key]: | |
case PathFieldType.DICT_KEY: | |
return source.get(field_name) | |
case PathFieldType.OBJECT_PROPERTY: | |
return getattr(source, field_name, None) | |
case PathFieldType.CALLABLE_OBJECT_PROPERTY: | |
return getattr(source, field_name)(info, **args) | |
case PathFieldType.CALLABLE_DICT_KEY: | |
return source.get(field_name)(info, **args) | |
case _: | |
raise NotImplementedError() | |
def execute_operation( | |
self, operation: OperationDefinitionNode, root_value: Any | |
) -> Optional[AwaitableOrValue[Any]]: | |
self._cached_path_types.clear() | |
self._cached_nullable_types.clear() | |
self._deferred_callbacks.clear() | |
result = super().execute_operation(operation, root_value) | |
while self._deferred_callbacks: | |
self._deferred_callbacks.pop(0)() | |
return result | |
def execute_field( | |
self, | |
parent_type: GraphQLObjectType, | |
source: Any, | |
field_nodes: list[FieldNode], | |
path: Path, | |
) -> AwaitableOrValue[Any]: | |
field_def = get_field_def(self.schema, parent_type, field_nodes[0]) | |
if not field_def: | |
return Undefined | |
return_type = field_def.type | |
path_key = self._unindexed_path(path) | |
resolve_fn = field_def.resolve or partial(self.default_field_resolver, path_key) | |
info = self.build_resolve_info(field_def, field_nodes, parent_type, path) | |
try: | |
args = get_argument_values(field_def, field_nodes[0], self.variable_values) | |
result = resolve_fn(source, info, **args) | |
return self.streamlined_complete_value(return_type, field_nodes, path, path_key, result) | |
except Exception as raw_error: | |
error = located_error(raw_error, field_nodes, path.as_list()) | |
return self.handle_field_error(error, return_type) | |
def _get_path_key_type( | |
self, path_key: str, return_type: GraphQLOutputType, result: Any | |
) -> tuple[PathFieldType, GraphQLOutputType]: | |
if path_key not in self._cached_path_types: | |
# Unwrap non-null - we don't care about nullability | |
check_return_type = return_type | |
if is_non_null_type(return_type): | |
check_return_type = return_type.of_type | |
self._cached_nullable_types.add(path_key) | |
# Detect and cache actual path type | |
if is_list_type(check_return_type): | |
if isinstance(result, SyncFuture): | |
self._cached_path_types[path_key] = PathFieldType.SYNC_FUTURE_LIST | |
else: | |
self._cached_path_types[path_key] = PathFieldType.LIST | |
elif is_leaf_type(check_return_type): | |
if isinstance(result, (int, float, bool, str)): | |
self._cached_path_types[path_key] = PathFieldType.RAW_SCALAR | |
else: | |
self._cached_path_types[path_key] = PathFieldType.SERIALIZE_SCALAR | |
elif is_object_type(check_return_type): | |
if isinstance(result, SyncFuture): | |
self._cached_path_types[path_key] = PathFieldType.SYNC_FUTURE_OBJECT | |
else: | |
self._cached_path_types[path_key] = PathFieldType.OBJECT | |
else: | |
raise NotImplementedError() | |
if path_key in self._cached_nullable_types: | |
return_type = return_type.of_type | |
return self._cached_path_types[path_key], return_type | |
def _defer_syncfuture( | |
self, | |
return_type: GraphQLOutputType, | |
field_nodes: list[FieldNode], | |
path: Path, | |
path_type: PathFieldType, | |
result: SyncFuture, | |
): | |
if result.deferred_callback is not None: | |
self._deferred_callbacks.append(result.deferred_callback) | |
# noinspection PyShadowingNames | |
def _process_result(): | |
try: | |
future.set_result(self._resolve_syncfuture(return_type, field_nodes, path, path_type, result)) | |
except Exception as raw_error: | |
error = located_error(raw_error, field_nodes, path.as_list()) | |
future.set_result(self.handle_field_error(error, return_type)) | |
future = SyncFuture() | |
result.add_done_callback(_process_result) | |
return future | |
def _resolve_syncfuture( | |
self, | |
return_type: GraphQLOutputType, | |
field_nodes: list[FieldNode], | |
path: Path, | |
path_type: PathFieldType, | |
result: SyncFuture, | |
) -> Any: | |
if path_type == PathFieldType.SYNC_FUTURE_LIST: | |
return [ | |
self.execute_fields( | |
return_type.of_type, | |
item.result(), | |
path.add_key(index, None), | |
self.collect_subfields(return_type.of_type, field_nodes), | |
) | |
if item.done() | |
else self._defer_syncfuture( | |
return_type.of_type, | |
field_nodes, | |
path.add_key(index, None), | |
PathFieldType.SYNC_FUTURE_OBJECT, | |
item, | |
) | |
for index, item in enumerate(result.result()) | |
] | |
elif path_type == PathFieldType.SYNC_FUTURE_OBJECT: | |
return self.execute_fields( | |
return_type, result.result(), path, self.collect_subfields(return_type, field_nodes) | |
) | |
else: | |
raise NotImplementedError() | |
def _process_result( | |
self, | |
return_type: GraphQLOutputType, | |
field_nodes: list[FieldNode], | |
path: Path, | |
path_type: PathFieldType, | |
result: Any, | |
) -> Any: | |
if result is None or result is Undefined: | |
return None | |
match path_type: | |
case PathFieldType.RAW_SCALAR: | |
return result | |
case PathFieldType.SERIALIZE_SCALAR: | |
return return_type.serialize(result) | |
case PathFieldType.OBJECT: | |
return self.execute_fields( | |
return_type, result, path, self.collect_subfields(return_type, field_nodes) | |
) | |
case PathFieldType.SYNC_FUTURE_OBJECT | PathFieldType.SYNC_FUTURE_LIST: | |
if result.done(): | |
return self._resolve_syncfuture(return_type, field_nodes, path, path_type, result) | |
else: | |
return self._defer_syncfuture(return_type, field_nodes, path, path_type, result) | |
case _: | |
raise NotImplementedError() | |
def streamlined_complete_value( | |
self, | |
return_type: GraphQLOutputType, | |
field_nodes: list[FieldNode], | |
path: Path, | |
path_key: str, | |
result: Any, | |
) -> AwaitableOrValue[Any]: | |
path_type, return_type = self._get_path_key_type(path_key, return_type, result) | |
# Depending on path type, return the result | |
if path_type == PathFieldType.LIST: | |
# Shortcut list generation to prevent excessive rediscovery of types by assuming | |
# all items in list are of the same type, so that we can use one path key for all | |
results = [] | |
if result: | |
item_path_key = f"{path_key}_LIST_ITEM" | |
item_path_type, item_return_type = None, None | |
for index, item in enumerate(result): | |
if item is not None and item_return_type is None: | |
item_path_type, item_return_type = self._get_path_key_type( | |
item_path_key, return_type.of_type, item | |
) | |
results.append( | |
self._process_result( | |
item_return_type, | |
field_nodes, | |
path.add_key(index, None), | |
item_path_type, | |
item, | |
) | |
) | |
return results | |
return self._process_result(return_type, field_nodes, path, path_type, result) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment