-
-
Notifications
You must be signed in to change notification settings - Fork 6.1k
Note bundled flash-linear-attention kernels for gated-deltanet models #6850
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -203,6 +203,45 @@ def _get_user_task_config_attrs(user_config): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "gpt_oss", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Architectures with gated-deltanet (linear attention) layers. Unsloth bundles the | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # flash-linear-attention Triton kernels (unsloth_zoo/_vendored/fla), so no install is | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # needed; transformers uses the much slower pure PyTorch path only when they can't be enabled. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| FLA_MODEL_TYPE_PREFIXES = ("qwen3_next", "qwen3_5", "kimi_linear") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _fla_advised = False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _maybe_advise_fla_install(model_types): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """One-time note when a gated-deltanet model loads without the fast kernels. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| The kernels ship with Unsloth (no install needed); this fires only when they | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| could not be enabled on this platform (e.g. no CUDA, torch < 2.7 or | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| triton < 3.3), i.e. exactly when transformers uses the slow pure PyTorch path. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| global _fla_advised | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if _fla_advised: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if model_types is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(model_types, str): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_types = [model_types] # a lone string would otherwise iterate chars | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not any( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| isinstance(t, str) and t.startswith(FLA_MODEL_TYPE_PREFIXES) for t in model_types | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from transformers.utils.import_utils import is_flash_linear_attention_available | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if is_flash_linear_attention_available(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return # bundled (or user-installed) fast kernels are active | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+227
to
+236
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If To make this more robust, we should isolate the import and check of
Suggested change
References
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _fla_advised = True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| print( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "Unsloth: This model uses gated-deltanet linear attention layers. Unsloth\n" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "bundles the flash-linear-attention kernels, but they could not be enabled\n" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "on this setup (they need CUDA with torch >= 2.7 and triton >= 3.3), so\n" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "transformers will use a slower pure PyTorch path." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _fix_rope_inv_freq(model): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Fix inv_freq corruption caused by transformers v5 meta-device loading. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1263,6 +1302,7 @@ def _dispatch_diffusion(): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| trust_remote_code = trust_remote_code, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model_types_all = ",".join(model_types) + "," | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _maybe_advise_fla_install(model_types) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # ---- Text-diffusion models (e.g. DiffusionGemma) take a transformers-only slow path. ---- | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # These use a custom block-diffusion `generate` and a novel backbone, so we skip Unsloth's | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To make
_maybe_advise_fla_installmore robust and defensive, we should handle cases wheremodel_typesisNone, empty, or a single string (which would otherwise cause character-by-character iteration and fail to match).Additionally, to ensure this advisory is triggered regardless of the loading path, consider calling
_maybe_advise_fla_install(model_types)inFastLanguageModel.from_pretrained(around line 620) right aftermodel_typesis resolved.