Simple ComfyUI nodes to force use of FlashAttention2
# By https://github.com/blepping
# License: Apache2
#
# Usage: Place in ComfyUI's custom_nodes directory.
# It will add BlehFlashAttentionSampler and BlehGlobalFlashAttention nodes.
# Requires FlashAttention2 installed into the Python venv: https://github.com/Dao-AILab/flash-attention
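#
# FlashAttention2 can usually be installed with (check the flash-attention
# README for the right command for your CUDA/PyTorch combination):
#   pip install flash-attn --no-build-isolation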
#
from __future__ import annotations

import contextlib
from typing import TYPE_CHECKING

import comfy.ldm.modules.attention as comfyattn
import yaml
from comfy.samplers import KSAMPLER
from flash_attn import flash_attn_func

if TYPE_CHECKING:
    import collections
    from collections.abc import Callable

    import torch
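

# attention_flash adapts ComfyUI's optimized_attention() calling convention to
# flash_attn_func, which expects tensors shaped (batch, seqlen, heads, head_dim)
# and (per the flash-attention requirements) fp16/bf16 inputs on CUDA.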
def attention_flash(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    heads: int,
    *,
    orig_attention: Callable,  # noqa: ARG001
    flashattn_function: collections.abc.Callable = flash_attn_func,
    flashattn_verbose: bool = False,
    **kwargs: dict[str],
) -> torch.Tensor:
    _ = kwargs.pop("mask", None)
    _ = kwargs.pop("attn_precision", None)
    skip_reshape = kwargs.pop("skip_reshape", False)
    skip_output_reshape = kwargs.pop("skip_output_reshape", False)
    batch = q.shape[0]
    dim_head = q.shape[-1] // (1 if skip_reshape else heads)
    if flashattn_verbose:
        print(
            f"\n>> FLASH: reshape={not skip_reshape}, output_reshape={not skip_output_reshape}, dim_head={q.shape[-1]}, heads={heads}, adj_heads={dim_head}, q={q.shape}, k={k.shape}, v={v.shape}, args: {kwargs}\n",
        )
    if skip_reshape:
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    else:
        q, k, v = (t.view(batch, -1, heads, dim_head) for t in (q, k, v))
    softmax_scale_hd = kwargs.pop(f"softmax_scale_{dim_head}", None)
    if softmax_scale_hd is not None:
        kwargs["softmax_scale"] = softmax_scale_hd
    result = flashattn_function(
        q,
        k,
        v,
        causal=False,
        dropout_p=0.0,
        **kwargs,
    )
    if skip_output_reshape:
        return result.transpose(1, 2)
    return result.reshape(batch, -1, heads * dim_head)
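

# copy_funattrs copies a function's code and defaults onto an existing function
# object in place, so references to comfyattn.optimized_attention that were
# imported elsewhere in ComfyUI pick up the replacement (or the restore) without
# rebinding the module attribute.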
def copy_funattrs(fun, dest=None):
    if dest is None:
        dest = fun.__class__(fun.__code__, fun.__globals__)
    for k in (
        "__code__",
        "__defaults__",
        "__kwdefaults__",
        "__module__",
    ):
        setattr(dest, k, getattr(fun, k))
    return dest


def make_flashattn_wrapper(
    *,
    orig_attn,
    **kwargs: dict,
):
    outer_kwargs = kwargs

    def attn(
        *args: list,
        _flash_outer_kwargs=outer_kwargs,
        _flash_orig_attention=orig_attn,
        _flash_attn=attention_flash,
        **kwargs: dict,
    ) -> torch.Tensor:
        return _flash_attn(
            *args,
            orig_attention=_flash_orig_attention,
            **_flash_outer_kwargs,
            **kwargs,
        )

    return attn


@contextlib.contextmanager
def flashattn_context(
    enabled: bool,
    **kwargs: dict,
):
    if not enabled:
        yield None
        return
    orig_attn = copy_funattrs(comfyattn.optimized_attention)
    attn = make_flashattn_wrapper(orig_attn=orig_attn, **kwargs)
    try:
        copy_funattrs(attn, comfyattn.optimized_attention)
        yield None
    finally:
        copy_funattrs(orig_attn, comfyattn.optimized_attention)


def get_yaml_parameters(yaml_parameters: str | None = None) -> dict:
    if not yaml_parameters:
        return {}
    extra_params = yaml.safe_load(yaml_parameters)
    if extra_params is None:
        return {}
    if not isinstance(extra_params, dict):
        raise ValueError(  # noqa: TRY004
            "BlehFlashAttention: yaml_parameters must either be null or an object",
        )
    return extra_params
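

# Example yaml_parameters input (a sketch): flashattn_verbose and
# softmax_scale_<head_dim> are handled by attention_flash above; any other keys
# are passed straight through and must match flash_attn_func keyword arguments.
#
#   flashattn_verbose: true
#   softmax_scale_64: 0.125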
class BlehGlobalFlashAttention:
    DESCRIPTION = "Deprecated: Prefer using BlehFlashAttentionSampler if possible. This node allows globally replacing ComfyUI's attention with FlashAttention2 (performance enhancement). Requires FlashAttention2 to be installed into the ComfyUI Python environment. IMPORTANT: This is not a normal model patch. For settings to apply (including toggling on or off), the node must actually be run. If you toggle it on, run your workflow, and then bypass or mute the node, FlashAttention will not actually be disabled."
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "go"
    CATEGORY = "hacks"

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "model": ("MODEL",),
                "enabled": (
                    "BOOLEAN",
                    {"default": True},
                ),
            },
            "optional": {
                "yaml_parameters": (
                    "STRING",
                    {
                        "tooltip": "Allows specifying custom parameters via YAML. These are mostly passed directly to the FlashAttention function with no error checking. Must be empty or a YAML object.",
                        "dynamicPrompts": False,
                        "multiline": True,
                        "defaultInput": True,
                    },
                ),
            },
        }

    orig_attn = None

    @classmethod
    def go(
        cls,
        *,
        model: object,
        enabled: bool,
        yaml_parameters: str | None = None,
    ) -> tuple:
        if not enabled:
            if cls.orig_attn is not None:
                copy_funattrs(cls.orig_attn, comfyattn.optimized_attention)
                cls.orig_attn = None
            return (model,)
        if not cls.orig_attn:
            cls.orig_attn = copy_funattrs(comfyattn.optimized_attention)
        attn = make_flashattn_wrapper(
            orig_attn=cls.orig_attn,
            **get_yaml_parameters(yaml_parameters),
        )
        copy_funattrs(attn, comfyattn.optimized_attention)
        return (model,)
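

# flashattn_sampler wraps another SAMPLER's sampler_function and only patches
# optimized_attention while the current sigma lies inside the
# [end_sigma, start_sigma] window derived from start_percent/end_percent.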
def flashattn_sampler(
    model: object,
    x: torch.Tensor,
    sigmas: torch.Tensor,
    *,
    flashattn_sampler_options: tuple,
    **kwargs: dict,
) -> torch.Tensor:
    sampler, start_percent, end_percent, flashattn_kwargs = flashattn_sampler_options
    ms = model.inner_model.inner_model.model_sampling
    start_sigma, end_sigma = (
        round(ms.percent_to_sigma(start_percent), 4),
        round(ms.percent_to_sigma(end_percent), 4),
    )
    del ms

    def model_wrapper(
        x: torch.Tensor,
        sigma: torch.Tensor,
        **extra_args: dict[str],
    ) -> torch.Tensor:
        sigma_float = float(sigma.max().detach().cpu())
        enabled = end_sigma <= sigma_float <= start_sigma
        with flashattn_context(
            enabled=enabled,
            **flashattn_kwargs,
        ):
            return model(x, sigma, **extra_args)

    for k in (
        "inner_model",
        "sigmas",
    ):
        if hasattr(model, k):
            setattr(model_wrapper, k, getattr(model, k))
    return sampler.sampler_function(
        model_wrapper,
        x,
        sigmas,
        **kwargs,
        **sampler.extra_options,
    )


class BlehFlashAttentionSampler:
    DESCRIPTION = "Sampler wrapper that enables using FlashAttention2 (performance enhancement) while sampling is in progress. Requires FlashAttention2 to be installed into the ComfyUI Python environment."
    CATEGORY = "sampling/custom_sampling/samplers"
    RETURN_TYPES = ("SAMPLER",)
    FUNCTION = "go"

    @classmethod
    def INPUT_TYPES(cls) -> dict:
        return {
            "required": {
                "sampler": ("SAMPLER",),
            },
            "optional": {
                "start_percent": (
                    "FLOAT",
                    {
                        "default": 0.0,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.001,
                        "tooltip": "Time the effect becomes active as a percentage of sampling, not steps.",
                    },
                ),
                "end_percent": (
                    "FLOAT",
                    {
                        "default": 1.0,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.001,
                        "tooltip": "Time the effect ends (inclusive) as a percentage of sampling, not steps.",
                    },
                ),
                "yaml_parameters": (
                    "STRING",
                    {
                        "tooltip": "Allows specifying custom parameters via YAML. These are mostly passed directly to the FlashAttention function with no error checking. Must be empty or a YAML object.",
                        "dynamicPrompts": False,
                        "multiline": True,
                        "defaultInput": True,
                    },
                ),
            },
        }

    @classmethod
    def go(
        cls,
        sampler: object,
        *,
        start_percent: float = 0.0,
        end_percent: float = 1.0,
        yaml_parameters: str | None = None,
    ) -> tuple:
        return (
            KSAMPLER(
                flashattn_sampler,
                extra_options={
                    "flashattn_sampler_options": (
                        sampler,
                        start_percent,
                        end_percent,
                        get_yaml_parameters(yaml_parameters),
                    ),
                },
            ),
        )


NODE_CLASS_MAPPINGS = {
    "BlehGlobalFlashAttention": BlehGlobalFlashAttention,
    "BlehFlashAttentionSampler": BlehFlashAttentionSampler,
}
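

if __name__ == "__main__":
    # Minimal standalone smoke test for attention_flash: a sketch, not part of
    # the node functionality. Assumes ComfyUI is importable (for the imports at
    # the top of this file), a CUDA device is available, and FlashAttention2
    # supports the chosen head_dim/dtype (fp16 here).
    import torch

    if torch.cuda.is_available():
        heads, head_dim, seq = 8, 64, 77
        q, k, v = (
            torch.randn(1, seq, heads * head_dim, device="cuda", dtype=torch.float16)
            for _ in range(3)
        )
        out = attention_flash(q, k, v, heads, orig_attention=None, flashattn_verbose=True)
        # Expected output shape: (1, seq, heads * head_dim)
        print("attention_flash output shape:", tuple(out.shape))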