fix(api): preserve hierarchical estimate rules (#36852)

Co-authored-by: root <kinsonnee@gmail.com>
This commit is contained in:
WOLIKIMCHENG
2026-06-01 11:16:09 +08:00
committed by GitHub
parent 72e040ead3
commit 240912cef5
2 changed files with 68 additions and 1 deletion
+10 -1
View File
@@ -148,6 +148,11 @@ class _EstimateRules(BaseModel):
return list(seen.values())
# Rules model for hierarchical estimates: inherits everything from
# _EstimateRules and adds two optional, hierarchical-only fields.
# _HierarchicalProcessRule declares its `rules` field with this type.
class _EstimateHierarchicalRules(_EstimateRules):
# Accepts only "full-doc" or "paragraph"; defaults to None when not supplied.
# The hierarchical branch in DocumentService writes "paragraph" into the dumped
# rules when this value is missing or falsy.
parent_mode: Literal["full-doc", "paragraph"] | None = None
# Optional nested settings validated as _EstimateSegmentation; None when omitted.
subchunk_segmentation: _EstimateSegmentation | None = None
class _SummaryIndexSettingDisabled(BaseModel):
enable: Literal[False] = False
@@ -203,7 +208,7 @@ class _HierarchicalProcessRule(BaseModel):
model_config = ConfigDict(extra="allow")
mode: Literal[ProcessRuleMode.HIERARCHICAL]
rules: _EstimateRules
rules: _EstimateHierarchicalRules
summary_index_setting: _SummaryIndexSetting | None = None
@field_validator("summary_index_setting", mode="before")
@@ -2971,6 +2976,10 @@ class DocumentService:
process_rule_dict = validated.process_rule.model_dump(exclude_none=True)
if validated.process_rule.mode == ProcessRuleMode.AUTOMATIC:
process_rule_dict["rules"] = {}
elif validated.process_rule.mode == ProcessRuleMode.HIERARCHICAL:
rules = process_rule_dict.get("rules")
if isinstance(rules, dict) and not rules.get("parent_mode"):
rules["parent_mode"] = "paragraph"
args["process_rule"] = process_rule_dict
@staticmethod
@@ -1344,6 +1344,27 @@ class TestDocumentServiceEstimateValidation:
assert args["process_rule"]["rules"]["pre_processing_rules"] == [{"id": "remove_stopwords", "enabled": False}]
def test_estimate_args_validate_custom_mode_drops_hierarchical_fields(self):
    """Custom mode keeps only the base rule keys; hierarchical-only keys are dropped."""
    base_rules = {
        "pre_processing_rules": [{"id": "remove_stopwords", "enabled": True}],
        "segmentation": {"separator": "\n", "max_tokens": 128},
    }
    # Keys that are only declared on the hierarchical rules model.
    hierarchical_only = {
        "parent_mode": "full-doc",
        "subchunk_segmentation": {"separator": "###", "max_tokens": 64},
    }
    payload = {
        "info_list": {"data_source_type": "upload_file"},
        "process_rule": {
            "mode": "custom",
            "rules": {**base_rules, **hierarchical_only},
        },
    }

    DocumentService.estimate_args_validate(payload)

    # Compare against a fresh literal so the expectation never aliases the input.
    validated_rules = payload["process_rule"]["rules"]
    assert validated_rules == {
        "pre_processing_rules": [{"id": "remove_stopwords", "enabled": True}],
        "segmentation": {"separator": "\n", "max_tokens": 128},
    }
def test_estimate_args_validate_requires_summary_index_provider_name(self):
args = {
"info_list": {"data_source_type": "upload_file"},
@@ -1360,6 +1381,43 @@ class TestDocumentServiceEstimateValidation:
with pytest.raises(ValueError, match="Field required"):
DocumentService.estimate_args_validate(args)
def test_estimate_args_validate_preserves_hierarchical_fields(self):
    """Hierarchical mode keeps an explicit parent_mode and subchunk_segmentation."""
    hierarchical_rules = {
        "pre_processing_rules": [{"id": "remove_stopwords", "enabled": True}],
        "segmentation": {"separator": "\n", "max_tokens": 512},
        "parent_mode": "full-doc",
        "subchunk_segmentation": {"separator": "###", "max_tokens": 128},
    }
    payload = {
        "info_list": {"data_source_type": "upload_file"},
        "process_rule": {"mode": "hierarchical", "rules": hierarchical_rules},
    }

    DocumentService.estimate_args_validate(payload)

    validated_rules = payload["process_rule"]["rules"]
    assert validated_rules["parent_mode"] == "full-doc"
    assert validated_rules["subchunk_segmentation"] == {"separator": "###", "max_tokens": 128}
def test_estimate_args_validate_hierarchical_defaults_parent_mode_to_paragraph(self):
    """Hierarchical rules submitted without parent_mode come back with "paragraph"."""
    # parent_mode is deliberately absent from the submitted rules.
    rules_without_parent_mode = {
        "pre_processing_rules": [{"id": "remove_stopwords", "enabled": True}],
        "segmentation": {"separator": "\n", "max_tokens": 512},
        "subchunk_segmentation": {"separator": "###", "max_tokens": 128},
    }
    payload = {
        "info_list": {"data_source_type": "upload_file"},
        "process_rule": {"mode": "hierarchical", "rules": rules_without_parent_mode},
    }

    DocumentService.estimate_args_validate(payload)

    validated_rules = payload["process_rule"]["rules"]
    assert validated_rules["parent_mode"] == "paragraph"
    assert validated_rules["subchunk_segmentation"] == {"separator": "###", "max_tokens": 128}
class TestDocumentServiceSaveDocumentAdditionalBranches:
"""Additional unit tests for dataset bootstrap and process-rule branches."""