Simple ComfyUI nodes to force use of FlashAttention2
# By https://github.com/blepping
# License: Apache2
#
# Usage: Place in ComfyUI's custom_nodes directory.
# It will add BlehFlashAttentionSampler and BlehGlobalFlashAttention nodes.
# Requires FlashAttention2 installed into the Python venv: https://github.com/Dao-AILab/flash-attention
#
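# Example yaml_parameters input for the nodes below (a sketch, not an exhaustive list):
#
#   flashattn_verbose: true   # print shapes/arguments for each attention call
#   softmax_scale_64: 0.125   # override softmax_scale only when dim_head == 64
#   window_size: [-1, -1]     # example of a pass-through key; only valid if your flash-attention build accepts it
#
# Keys other than the flashattn_* options and softmax_scale_<dim_head> are passed
# straight to flash_attn_func with no error checking, so which names are valid depends
# on the installed flash-attention version.
#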
from __future__ import annotations
import contextlib
from typing import TYPE_CHECKING
import comfy.ldm.modules.attention as comfyattn
import yaml
from comfy.samplers import KSAMPLER
from flash_attn import flash_attn_func
if TYPE_CHECKING:
    import collections
    from collections.abc import Callable

    import torch
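

# Drop-in replacement for ComfyUI's optimized_attention. Reshapes q/k/v into the
# (batch, seq_len, heads, dim_head) layout flash_attn_func expects, drops arguments
# FlashAttention cannot use (mask, attn_precision), and reshapes the output back
# unless the caller asked to skip that.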
def attention_flash(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    heads: int,
    *,
    orig_attention: Callable,  # noqa: ARG001
    flashattn_function: collections.abc.Callable = flash_attn_func,
    flashattn_verbose: bool = False,
    **kwargs: dict[str],
) -> torch.Tensor:
    _ = kwargs.pop("mask", None)
    _ = kwargs.pop("attn_precision", None)
    skip_reshape = kwargs.pop("skip_reshape", False)
    skip_output_reshape = kwargs.pop("skip_output_reshape", False)
    batch = q.shape[0]
    dim_head = q.shape[-1] // (1 if skip_reshape else heads)
    if flashattn_verbose:
        print(
            f"\n>> FLASH: reshape={not skip_reshape}, output_reshape={not skip_output_reshape}, dim_head={q.shape[-1]}, heads={heads}, adj_heads={dim_head}, q={q.shape}, k={k.shape}, v={v.shape}, args: {kwargs}\n",
        )
    if skip_reshape:
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    else:
        q, k, v = (t.view(batch, -1, heads, dim_head) for t in (q, k, v))
    softmax_scale_hd = kwargs.pop(f"softmax_scale_{dim_head}", None)
    if softmax_scale_hd is not None:
        kwargs["softmax_scale"] = softmax_scale_hd
    result = flashattn_function(
        q,
        k,
        v,
        causal=False,
        dropout_p=0.0,
        **kwargs,
    )
    if skip_output_reshape:
        return result.transpose(1, 2)
    return result.reshape(batch, -1, heads * dim_head)
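

# Copies a function's __code__/__defaults__/__kwdefaults__/__module__ onto dest (or a
# fresh clone). This lets comfyattn.optimized_attention be patched and later restored
# in place, so modules that imported the function directly still see the change.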
def copy_funattrs(fun, dest=None):
    if dest is None:
        dest = fun.__class__(fun.__code__, fun.__globals__)
    for k in (
        "__code__",
        "__defaults__",
        "__kwdefaults__",
        "__module__",
    ):
        setattr(dest, k, getattr(fun, k))
    return dest
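

# Builds a wrapper with optimized_attention's calling convention that forwards
# everything to attention_flash, baking in the original attention function and any
# extra keyword arguments (e.g. parsed YAML parameters).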
def make_flashattn_wrapper(
    *,
    orig_attn,
    **kwargs: dict,
):
    outer_kwargs = kwargs

    def attn(
        *args: list,
        _flash_outer_kwargs=outer_kwargs,
        _flash_orig_attention=orig_attn,
        _flash_attn=attention_flash,
        **kwargs: dict,
    ) -> torch.Tensor:
        return _flash_attn(
            *args,
            orig_attention=_flash_orig_attention,
            **_flash_outer_kwargs,
            **kwargs,
        )

    return attn
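

# Context manager that patches optimized_attention with the FlashAttention wrapper for
# the duration of the block and always restores the original afterward, even on error.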
@contextlib.contextmanager
def flashattn_context(
    enabled: bool,
    **kwargs: dict,
):
    if not enabled:
        yield None
        return
    orig_attn = copy_funattrs(comfyattn.optimized_attention)
    attn = make_flashattn_wrapper(orig_attn=orig_attn, **kwargs)
    try:
        copy_funattrs(attn, comfyattn.optimized_attention)
        yield None
    finally:
        copy_funattrs(orig_attn, comfyattn.optimized_attention)
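

# Parses the optional yaml_parameters text input. It must be empty/null or a YAML
# object (mapping); the result is returned as a dict of extra keyword arguments.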
def get_yaml_parameters(yaml_parameters: str | None = None) -> dict:
    if not yaml_parameters:
        return {}
    extra_params = yaml.safe_load(yaml_parameters)
    if extra_params is None:
        return {}
    if not isinstance(extra_params, dict):
        raise ValueError(  # noqa: TRY004
            "BlehFlashAttention: yaml_parameters must either be null or an object",
        )
    return extra_params
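

# Node that patches ComfyUI's attention globally. The unpatched function is stashed on
# the class (orig_attn) so a later run with enabled=False can restore it.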
class BlehGlobalFlashAttention:
    DESCRIPTION = "Deprecated: Prefer using BlehFlashAttentionSampler if possible. This node allows globally replacing ComfyUI's attention with FlashAttention2 (performance enhancement). Requires FlashAttention2 to be installed into the ComfyUI Python environment. IMPORTANT: This is not a normal model patch. For settings to apply (including toggling on or off), the node must actually be run. If you toggle it on, run your workflow, and then bypass or mute the node, this will not actually disable FlashAttention."
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "go"
    CATEGORY = "hacks"

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "model": ("MODEL",),
                "enabled": (
                    "BOOLEAN",
                    {"default": True},
                ),
            },
            "optional": {
                "yaml_parameters": (
                    "STRING",
                    {
                        "tooltip": "Allows specifying custom parameters via YAML. These are mostly passed directly to the FlashAttention function with no error checking. Must be empty or a YAML object.",
                        "dynamicPrompts": False,
                        "multiline": True,
                        "defaultInput": True,
                    },
                ),
            },
        }

    orig_attn = None

    @classmethod
    def go(
        cls,
        *,
        model: object,
        enabled: bool,
        yaml_parameters: str | None = None,
    ) -> tuple:
        if not enabled:
            if cls.orig_attn is not None:
                copy_funattrs(cls.orig_attn, comfyattn.optimized_attention)
                cls.orig_attn = None
            return (model,)
        if not cls.orig_attn:
            cls.orig_attn = copy_funattrs(comfyattn.optimized_attention)
        attn = make_flashattn_wrapper(
            orig_attn=cls.orig_attn,
            **get_yaml_parameters(yaml_parameters),
        )
        copy_funattrs(attn, comfyattn.optimized_attention)
        return (model,)
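

# Sampler function used by the SAMPLER returned from BlehFlashAttentionSampler: wraps
# the model so FlashAttention is only active while the current sigma is inside the
# configured start/end window, then delegates to the wrapped sampler.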
def flashattn_sampler(
    model: object,
    x: torch.Tensor,
    sigmas: torch.Tensor,
    *,
    flashattn_sampler_options: tuple,
    **kwargs: dict,
) -> torch.Tensor:
    sampler, start_percent, end_percent, flashattn_kwargs = flashattn_sampler_options
    ms = model.inner_model.inner_model.model_sampling
    start_sigma, end_sigma = (
        round(ms.percent_to_sigma(start_percent), 4),
        round(ms.percent_to_sigma(end_percent), 4),
    )
    del ms

    def model_wrapper(
        x: torch.Tensor,
        sigma: torch.Tensor,
        **extra_args: dict[str],
    ) -> torch.Tensor:
        sigma_float = float(sigma.max().detach().cpu())
        enabled = end_sigma <= sigma_float <= start_sigma
        with flashattn_context(
            enabled=enabled,
            **flashattn_kwargs,
        ):
            return model(x, sigma, **extra_args)

    for k in (
        "inner_model",
        "sigmas",
    ):
        if hasattr(model, k):
            setattr(model_wrapper, k, getattr(model, k))
    return sampler.sampler_function(
        model_wrapper,
        x,
        sigmas,
        **kwargs,
        **sampler.extra_options,
    )
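

# Node that wraps another SAMPLER; the FlashAttention patch is applied via
# flashattn_context only around each model call during sampling.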
class BlehFlashAttentionSampler:
    DESCRIPTION = "Sampler wrapper that enables using FlashAttention2 (performance enhancement) while sampling is in progress. Requires FlashAttention2 to be installed into the ComfyUI Python environment."
    CATEGORY = "sampling/custom_sampling/samplers"
    RETURN_TYPES = ("SAMPLER",)
    FUNCTION = "go"

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "sampler": ("SAMPLER",),
            },
            "optional": {
                "start_percent": (
                    "FLOAT",
                    {
                        "default": 0.0,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.001,
                        "tooltip": "Time the effect becomes active as a percentage of sampling, not steps.",
                    },
                ),
                "end_percent": (
                    "FLOAT",
                    {
                        "default": 1.0,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.001,
                        "tooltip": "Time the effect ends (inclusive) as a percentage of sampling, not steps.",
                    },
                ),
                "yaml_parameters": (
                    "STRING",
                    {
                        "tooltip": "Allows specifying custom parameters via YAML. These are mostly passed directly to the FlashAttention function with no error checking. Must be empty or a YAML object.",
                        "dynamicPrompts": False,
                        "multiline": True,
                        "defaultInput": True,
                    },
                ),
            },
        }

    @classmethod
    def go(
        cls,
        sampler: object,
        *,
        start_percent: float = 0.0,
        end_percent: float = 1.0,
        yaml_parameters: str | None = None,
    ) -> tuple:
        return (
            KSAMPLER(
                flashattn_sampler,
                extra_options={
                    "flashattn_sampler_options": (
                        sampler,
                        start_percent,
                        end_percent,
                        get_yaml_parameters(yaml_parameters),
                    ),
                },
            ),
        )


NODE_CLASS_MAPPINGS = {
    "BlehGlobalFlashAttention": BlehGlobalFlashAttention,
    "BlehFlashAttentionSampler": BlehFlashAttentionSampler,
}