Fix GenerationConfig continuous batching serialization #47038
VectorPeak wants to merge 4 commits into
CI note: the remaining red check appears to be a self-hosted runner/container infrastructure failure rather than a test failure from this PR. The CI recap reports I do not have permission to rerun the failed upstream job from the fork side.
def to_dict(self) -> dict[str, Any]:
    """Serializes this instance to a Python dictionary."""
    output = copy.deepcopy(self.__dict__)
    if self.varlen_compile_config is not None:
        output["varlen_compile_config"] = self.varlen_compile_config.to_dict()
    if self.decode_compile_config is not None:
        output["decode_compile_config"] = self.decode_compile_config.to_dict()
    return output
Rather than manually doing it per key, let's force all dataclasses to resolve as dicts. I copied this from a few lines above:
def convert_dataclass_to_dict(obj):
    if isinstance(obj, dict):
        return {key: convert_dataclass_to_dict(value) for key, value in obj.items()}
    elif is_dataclass(obj):
        # Some of our dataclasses have a custom `to_dict()` method, and we prefer it
        if hasattr(obj, "to_dict"):
            return obj.to_dict()
    else:
        return obj
Thanks for the suggestion! I updated the patch to use a generic dataclass fallback in convert_dataclass_to_dict() and removed the per-key ContinuousBatchingConfig.to_dict() handling.
My initial intent with the manual keys was to avoid widening the behavioral surface, but this shared fallback is cleaner and matches the direction you suggested while still preferring custom to_dict() implementations when present.
Validation rerun locally:
python -m pytest tests/generation/test_configuration_utils.py::GenerationConfigSerializationTest::test_serialize_generation_continuous_batching_config -q
python -m pytest tests/generation/test_configuration_utils.py::GenerationConfigSerializationTest::test_serialize_generation_watermarking_config -q
python -m ruff format --check src/transformers/generation/configuration_utils.py tests/generation/test_configuration_utils.py
python -m ruff check src/transformers/generation/configuration_utils.py tests/generation/test_configuration_utils.py
git diff --check
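The generic dataclass fallback described in the reply above could look roughly like the sketch below. It is not the patch itself: the use of `dataclasses.asdict()` as the fallback is an assumption, and `PlainConfig` and `CustomConfig` are hypothetical stand-ins, not classes from the library.

```python
from dataclasses import asdict, dataclass, is_dataclass
from typing import Any


def convert_dataclass_to_dict(obj: Any) -> Any:
    """Recursively turn dataclass instances into plain dictionaries."""
    if isinstance(obj, dict):
        return {key: convert_dataclass_to_dict(value) for key, value in obj.items()}
    if is_dataclass(obj) and not isinstance(obj, type):
        # Prefer a custom `to_dict()` when the dataclass defines one.
        if hasattr(obj, "to_dict"):
            return convert_dataclass_to_dict(obj.to_dict())
        # Assumed generic fallback: asdict() also recurses into nested dataclasses.
        return asdict(obj)
    return obj


@dataclass
class PlainConfig:  # hypothetical: no custom to_dict()
    block_size: int = 32


@dataclass
class CustomConfig:  # hypothetical: custom to_dict() hides an internal field
    mode: str = "default"
    _internal: bool = True

    def to_dict(self) -> dict[str, Any]:
        return {"mode": self.mode}


result = convert_dataclass_to_dict(
    {"plain": PlainConfig(), "custom": CustomConfig(), "n": 1}
)
print(result)
```

One trade-off of this sketch: `asdict()` ignores any custom `to_dict()` on dataclasses nested inside a dataclass that takes the fallback path, so filtering of internal fields is only guaranteed for objects that reach the helper directly.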
cc @remi-or, to confirm whether saving the continuous-batching config is intended
What does this PR do?
Fixes #47039
This PR fixes a GenerationConfig serialization round-trip loss for ContinuousBatchingConfig.
What Problem This Solves
GenerationConfig.save_pretrained() persists generation settings by writing the result of GenerationConfig.to_json_string() into generation_config.json. During that JSON conversion, nested dataclass values are passed through convert_dataclass_to_dict(), which currently serializes dataclasses by calling to_dict() when that method exists.
The problem is that ContinuousBatchingConfig is a dataclass, but it did not define to_dict(). Because that helper had no fallback return for dataclasses without to_dict(), the continuous batching config could silently fall through as None during JSON conversion. In practice, a configured continuous batching block could therefore be persisted as JSON null.
A user or service can construct a generation config like this:
Before this fix, the saved generation_config.json could contain:
{ "continuous_batching_config": null }
That means the saved config no longer carries the actual continuous batching parameters, including values such as block_size, default_compile_level, and the nested varlen_compile_config / decode_compile_config settings. After a normal save_pretrained() -> from_pretrained() round trip, the loaded GenerationConfig has lost the continuous batching configuration instead of reconstructing it.
This matters for serving and inference setups that rely on saved generation configs: a configuration that was valid in memory can become incomplete after being saved and reloaded, so behavior can drift from the original runtime settings without an explicit error.
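The silent fall-through can be reproduced in isolation. The sketch below is modeled on the helper behavior described in this section rather than copied from the library; `convert_without_fallback` and `BatchingConfigStandIn` are hypothetical names used only for illustration.

```python
import json
from dataclasses import dataclass, is_dataclass
from typing import Any


def convert_without_fallback(obj: Any) -> Any:
    """Helper with the gap described above: no return for plain dataclasses."""
    if isinstance(obj, dict):
        return {key: convert_without_fallback(value) for key, value in obj.items()}
    elif is_dataclass(obj):
        if hasattr(obj, "to_dict"):
            return obj.to_dict()
        # No return on this path, so the function implicitly returns None.
    else:
        return obj


@dataclass
class BatchingConfigStandIn:  # hypothetical dataclass without to_dict()
    block_size: int = 32


state = {"continuous_batching_config": BatchingConfigStandIn(block_size=64)}
saved = json.dumps(convert_without_fallback(state))
print(saved)
```

No exception is raised at any point, which is why the loss only shows up when the saved file is inspected or reloaded.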
Change
This PR fixes both directions of the GenerationConfig round trip: writing ContinuousBatchingConfig into JSON, and restoring it back into typed config objects when the generation config is loaded again.
For the save / serialization path:
- Adds ContinuousBatchingConfig.to_dict() so convert_dataclass_to_dict() has an explicit structured representation to use instead of falling through to None.
- Serializes block_size, default_compile_level, max_cached_graphs, and other continuous batching knobs.
- Delegates varlen_compile_config and decode_compile_config serialization to CompileConfig.to_dict() when those fields are present, so nested compile settings keep the same serialization behavior as standalone CompileConfig objects.
- Preserves the CompileConfig.to_dict() filtering behavior, including not leaking internal implementation fields such as _compile_all_devices into the saved JSON.
For the load / deserialization path:
- Converts a continuous_batching_config dictionary back into a ContinuousBatchingConfig inside GenerationConfig.__init__, matching the existing pattern used by other nested generation config objects.
- Converts varlen_compile_config and decode_compile_config dictionaries back into CompileConfig during ContinuousBatchingConfig.__post_init__, so callers get typed config objects after GenerationConfig.from_pretrained() instead of raw dictionaries.
For coverage:
- Adds a test that saves a GenerationConfig containing ContinuousBatchingConfig, reloads it with GenerationConfig.from_pretrained(), and verifies that the continuous batching fields survive the round trip.
- Uses nested CompileConfig values in the test to cover the deeper round-trip path, not just the top-level ContinuousBatchingConfig object.
- Checks that the reloaded nested configs are CompileConfig instances and that internal compile-only fields are not emitted through to_dict().
This keeps the fix scoped to continuous batching generation config serialization. It does not alter continuous batching scheduling, generation execution, compile defaults, or unrelated generation config fields; it only makes the saved configuration faithfully represent the object that was already present in memory.
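The two directions of the round trip can be sketched with small stand-in dataclasses. This is an illustration of the pattern under stated assumptions, not the library code: `CompileConfigStandIn` and `BatchingConfigStandIn` are hypothetical, and filtering every underscore-prefixed field is a simplification chosen for the example.

```python
import json
from dataclasses import dataclass
from typing import Any, Optional, Union


@dataclass
class CompileConfigStandIn:  # hypothetical stand-in for a nested compile config
    mode: str = "default"
    _compile_all_devices: bool = False  # internal field, must not be saved

    def to_dict(self) -> dict[str, Any]:
        # Simplified filter: drop every underscore-prefixed field.
        return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}


@dataclass
class BatchingConfigStandIn:  # hypothetical stand-in for the batching config
    block_size: int = 32
    varlen_compile_config: Optional[Union[CompileConfigStandIn, dict]] = None

    def __post_init__(self) -> None:
        # Load path: rebuild the typed nested object from a raw dictionary.
        if isinstance(self.varlen_compile_config, dict):
            self.varlen_compile_config = CompileConfigStandIn(**self.varlen_compile_config)

    def to_dict(self) -> dict[str, Any]:
        # Save path: delegate the nested object to its own to_dict().
        output = dict(self.__dict__)
        if self.varlen_compile_config is not None:
            output["varlen_compile_config"] = self.varlen_compile_config.to_dict()
        return output


original = BatchingConfigStandIn(
    block_size=64,
    varlen_compile_config=CompileConfigStandIn(mode="max-autotune"),
)
payload = json.dumps(original.to_dict())
restored = BatchingConfigStandIn(**json.loads(payload))

print(payload)
print(restored == original)
print("_compile_all_devices" in payload)
```

Doing the dictionary-to-object conversion in `__post_init__` means the same constructor works for both in-memory construction and reloading from JSON, so callers never have to branch on whether a nested field is a dictionary.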
Evidence
Local behavior proof after the patch:
The final False verifies that _compile_all_devices is not leaked through nested CompileConfig.to_dict() serialization.
Possible call chain / impact
This PR only changes serialization/deserialization of ContinuousBatchingConfig. It does not change continuous batching runtime scheduling, generation behavior, compile defaults, or unrelated generation config fields.
Code Agent Policy
AI assistance was used for diagnosis, patch drafting, validation planning, and PR text preparation. The human submitter should review all changed lines and understand the diff before checking this box.
Before submitting
Validation run locally:
Limitations:
was not run because make is unavailable in the current Windows PowerShell environment.
failed with an IndexError after selecting files from the previous commit rather than the current uncommitted diff. The default tests_fetcher.py invocation completed, but reported no changed files before the patch was committed.
Who can review?
Continuous batching / generation reviewers from the template are likely relevant after tests pass and the coordination issue has maintainer feedback: @remi-or, @ArthurZucker, @McPatate.