Coverage for wifa_uq / workflow_schema.py: 94%
127 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-19 02:10 +0000
1from pydantic import BaseModel, Field, model_validator
2from typing import Literal, Optional, Union, Any
3from pathlib import Path
class ParamConfigDict(BaseModel):
    """Explicit (dict-form) configuration of a single parameter.

    This is the long form of a ``param_config`` entry; the short form is a
    plain ``[min, max]`` list.
    """

    # Parameter bounds as (min, max).
    range: tuple[float, float]
    # Optional default value for the parameter.
    default: Union[float, None] = Field(default=None)
    # Optional short name for the parameter.
    short_name: Union[str, None] = Field(default=None)
class PreprocessingConfig(BaseModel):
    """Settings for the preprocessing stage of the workflow."""

    # The stage is disabled unless explicitly turned on.
    run: bool = Field(default=False)
    # Steps to execute; "recalculate_params" is the only accepted step name.
    steps: list[Literal["recalculate_params"]] = Field(default_factory=list)
class DatabaseGenConfig(BaseModel):
    """Settings for the database-generation stage."""

    run: bool = Field(default=False)
    # Flow-model backend to use.
    flow_model: Literal["pywake", "foxes"] = Field(default="pywake")
    # Number of samples for database generation.
    n_samples: int = Field(default=100)
    # Per-parameter configuration keyed by name. Each entry is either a short
    # [min, max] list or a full ParamConfigDict (range / default / short_name).
    param_config: dict[str, Union[ParamConfigDict, list[float]]] = Field(
        default_factory=dict
    )
class CrossValidationConfig(BaseModel):
    """Settings controlling cross-validation of the error model."""

    run: bool = Field(default=False)
    # Strategy used to split the data into train/test folds.
    splitting_mode: Literal["kfold_shuffled", "LeaveOneGroupOut"] = Field(
        default="kfold_shuffled"
    )
    # Number of splits (presumably only meaningful for k-fold; confirm).
    n_splits: int = Field(default=5)
    # Metrics to report; by default all three supported metrics are selected.
    metrics: list[Literal["rmse", "r2", "mae"]] = Field(
        default_factory=lambda: ["rmse", "r2", "mae"]
    )
    # LeaveOneGroupOut only: group name -> names of the farms in that group.
    groups: Union[dict[str, list[str]], None] = Field(default=None)
class PCEModelParams(BaseModel):
    """Hyperparameters for the PCE error-prediction model."""

    degree: int = Field(default=5)
    # Marginal distribution family for the inputs.
    marginals: Literal["kernel", "uniform", "normal"] = Field(default="kernel")
    # Dependence structure between inputs.
    copula: Literal["independent", "normal"] = Field(default="independent")
    # Note: default differs from PCESensitivityConfig.q (0.5).
    q: float = Field(default=1.0)
    # NOTE(review): presumably caps the number of input features unless
    # allow_high_dim is set -- confirm against the PCE model implementation.
    max_features: int = Field(default=5)
    allow_high_dim: bool = Field(default=False)
class LinearModelParams(BaseModel):
    """Hyperparameters for the linear error-prediction model."""

    # Fitting method: plain least squares or a regularised variant.
    method: Literal["ols", "ridge", "lasso", "elasticnet"] = Field(default="ols")
    # Regularisation strength (presumably ignored for "ols"; confirm).
    alpha: float = Field(default=1.0)
    # L1/L2 mixing ratio -- consulted only for "elasticnet".
    l1_ratio: float = Field(default=0.5)
class XGBModelParams(BaseModel):
    """Hyperparameters for the XGBoost error-prediction model."""

    max_depth: int = Field(default=3)
    n_estimators: int = Field(default=500)
    learning_rate: float = Field(default=0.1)
    # Optional seed; None means no explicit seed is configured.
    random_state: Union[int, None] = Field(default=None)
class SIRModelParams(BaseModel):
    """Hyperparameters for the SIR + polynomial error-prediction model."""

    # Number of SIR directions to retain.
    n_directions: int = Field(default=1)
    # Degree of the polynomial fitted on those directions.
    degree: int = Field(default=2)
class ErrorPredictionConfig(BaseModel):
    """Configuration for error prediction and cross-validation.

    Selects the error model (``model``), the calibrator, the optional local
    regressor used by ``LocalParameterPredictor``, and the cross-validation
    settings.
    """

    run: bool = False
    # Required: names of the input features for the error model.
    features: list[str]
    model: Literal["XGB", "PCE", "SIRPolynomial", "Linear"] = "XGB"
    # Free-form keyword parameters for the chosen model. NOTE: this schema
    # accepts any keys/values here and does NOT validate them against
    # ``model`` (the *ModelParams classes are not applied to this field).
    model_params: dict[str, Any] = Field(default_factory=dict)
    calibrator: Literal[
        "MinBiasCalibrator",
        "LocalParameterPredictor",
        "DefaultParams",
        "BayesianCalibration",
    ] = "MinBiasCalibrator"
    # Only used when calibrator is LocalParameterPredictor.
    local_regressor: Optional[
        Literal["Ridge", "Linear", "Lasso", "ElasticNet", "RandomForest", "XGB"]
    ] = None
    # Keyword parameters for local_regressor (presumably also only relevant
    # with LocalParameterPredictor -- confirm). Not validated here.
    local_regressor_params: dict[str, Any] = Field(default_factory=dict)
    bias_predictor: Literal["BiasPredictor"] = "BiasPredictor"
    cross_validation: CrossValidationConfig = Field(
        default_factory=CrossValidationConfig
    )

    @model_validator(mode="after")
    def validate_local_calibrator(self):
        """Warn when ``local_regressor`` is set but will be ignored.

        ``local_regressor`` only applies to the ``LocalParameterPredictor``
        calibrator. Setting it with another calibrator is not rejected, but a
        ``UserWarning`` is emitted because it is likely a config mistake.

        Returns:
            The validated instance, unchanged.
        """
        # Guard clause: nothing to report in the consistent cases.
        if (
            self.local_regressor is None
            or self.calibrator == "LocalParameterPredictor"
        ):
            return self

        import warnings

        warnings.warn(
            f"local_regressor is set to '{self.local_regressor}' but calibrator is "
            f"'{self.calibrator}'. local_regressor is only used with LocalParameterPredictor.",
            UserWarning,
        )
        return self
class PCESensitivityConfig(BaseModel):
    """PCE settings used specifically for sensitivity analysis."""

    degree: int = Field(default=5)
    marginals: Literal["kernel", "uniform", "normal"] = Field(default="kernel")
    copula: Literal["independent", "normal"] = Field(default="independent")
    # Note: default differs from PCEModelParams.q (1.0).
    q: float = Field(default=0.5)
class SensitivityConfig(BaseModel):
    """Settings for the sensitivity-analysis stage."""

    # Two independent switches; both analyses are off by default.
    run_observation_sensitivity: bool = Field(default=False)
    run_bias_sensitivity: bool = Field(default=False)
    # Sensitivity method; "auto" presumably picks one automatically (confirm).
    method: Literal["auto", "shap", "pce_sobol", "sir"] = Field(default="auto")
    # Optional PCE-specific settings; None when not provided.
    pce_config: Union[PCESensitivityConfig, None] = Field(default=None)
class PartialDependenceConfig(BaseModel):
    """Settings for the partial-dependence part of physics insights."""

    # Enabled by default.
    enabled: bool = Field(default=True)
    # Number of grid points per feature (presumably; confirm).
    grid_resolution: int = Field(default=50)
class InteractionsConfig(BaseModel):
    """Settings for the feature-interaction part of physics insights."""

    enabled: bool = Field(default=True)
    # How many of the strongest interactions to report.
    top_n: int = Field(default=5)
class RegimeAnalysisConfig(BaseModel):
    """Settings for identifying distinct error regimes."""

    enabled: bool = Field(default=True)
    # How many regimes (clusters) to identify.
    n_clusters: int = Field(default=3)
    # Only cases whose bias lies above this percentile are analysed.
    bias_percentile: float = Field(default=75.0)
class ParameterRelationshipsConfig(BaseModel):
    """Settings for relating calibrated parameters to flow conditions."""

    # Has an effect only with the LocalParameterPredictor calibrator.
    enabled: bool = Field(default=True)
class PhysicsInsightsConfig(BaseModel):
    """
    Settings for the physics-insights analysis.

    The analysis aims to turn the ML bias-correction model into interpretable
    physical understanding, e.g. "why does the wake model fail in certain
    conditions?"

    Outputs:
        - partial_dependence.png: how bias varies with each atmospheric feature
        - feature_interactions.png: which feature combinations jointly drive error
        - error_regimes.png: clusters of distinct failure modes
        - parameter_relationships.png: how optimal params vary with conditions
        - physics_insights_report.md: human-readable summary
        - physics_insights.json: machine-readable results
    """

    # Master switch; the whole analysis is off by default.
    run: bool = Field(default=False)

    # Per-analysis sub-settings, each built from its own defaults.
    partial_dependence: PartialDependenceConfig = Field(
        default_factory=PartialDependenceConfig
    )
    interactions: InteractionsConfig = Field(
        default_factory=InteractionsConfig
    )
    regime_analysis: RegimeAnalysisConfig = Field(
        default_factory=RegimeAnalysisConfig
    )
    parameter_relationships: ParameterRelationshipsConfig = Field(
        default_factory=ParameterRelationshipsConfig
    )
class FarmConfig(BaseModel):
    """One farm entry of a multi-farm configuration."""

    name: str
    system_config: Path
    # Optional overrides; when omitted they are inferred from the windIO
    # structure.
    reference_power: Union[Path, None] = Field(default=None)
    reference_resource: Union[Path, None] = Field(default=None)
    wind_farm_layout: Union[Path, None] = Field(default=None)
class PathsConfig(BaseModel):
    """Input paths (single-farm mode) and output locations (both modes)."""

    # Needed in single-farm mode only; leave unset in multi-farm mode.
    system_config: Union[Path, None] = Field(default=None)
    # Optional overrides; when omitted they are inferred from the windIO
    # structure.
    reference_power: Union[Path, None] = Field(default=None)
    reference_resource: Union[Path, None] = Field(default=None)
    wind_farm_layout: Union[Path, None] = Field(default=None)

    # --- Output locations ---
    output_dir: Path
    processed_resource_file: str = Field(default="processed_physical_inputs.nc")
    database_file: str = Field(default="results_stacked_hh.nc")
class WifaUQConfig(BaseModel):
    """
    Main WIFA-UQ workflow configuration.

    Supports two mutually exclusive modes:
    1. Single-farm: specify paths.system_config (other paths auto-inferred)
    2. Multi-farm: specify a non-empty farms list with each farm's
       system_config

    Example single-farm config:
        paths:
          system_config: wind_energy_system.yaml
          output_dir: results/
        preprocessing:
          run: true
          steps: [recalculate_params]
        ...

    Example multi-farm config:
        paths:
          output_dir: results/multi_farm/
        farms:
          - name: Farm1
            system_config: farm1/system.yaml
          - name: Farm2
            system_config: farm2/system.yaml
        ...
    """

    description: Optional[str] = None
    paths: PathsConfig
    # Multi-farm mode only. None or an empty list both mean single-farm mode.
    farms: Optional[list[FarmConfig]] = None
    preprocessing: PreprocessingConfig = Field(default_factory=PreprocessingConfig)
    database_gen: DatabaseGenConfig = Field(default_factory=DatabaseGenConfig)
    error_prediction: ErrorPredictionConfig
    sensitivity_analysis: SensitivityConfig = Field(default_factory=SensitivityConfig)
    physics_insights: PhysicsInsightsConfig = Field(
        default_factory=PhysicsInsightsConfig
    )

    @model_validator(mode="after")
    def check_paths_or_farms(self):
        """Require exactly one of single-farm or multi-farm mode.

        Raises:
            ValueError: if neither ``paths.system_config`` nor a non-empty
                ``farms`` list is given, or if both are given (pydantic
                surfaces this as a ValidationError).
        """
        is_multi_farm = self.is_multi_farm()
        is_single_farm = self.paths.system_config is not None

        if not is_multi_farm and not is_single_farm:
            raise ValueError(
                "Configuration must specify either:\n"
                "  - paths.system_config (for single-farm mode), or\n"
                "  - farms list (for multi-farm mode)"
            )

        if is_multi_farm and is_single_farm:
            raise ValueError(
                "Cannot specify both paths.system_config and farms list. "
                "Choose single-farm or multi-farm mode."
            )

        return self

    @model_validator(mode="after")
    def check_logo_groups(self):
        """Require multi-farm mode when LeaveOneGroupOut CV is selected.

        ``cross_validation.groups`` itself is optional: if it is not given,
        each farm becomes its own group.

        Raises:
            ValueError: if ``splitting_mode`` is "LeaveOneGroupOut" and there
                is no non-empty ``farms`` list.
        """
        cv_config = self.error_prediction.cross_validation
        # Use the same definition of multi-farm mode as check_paths_or_farms.
        # Checking ``farms is None`` alone let ``farms: []`` (which is
        # single-farm mode) slip through.
        if cv_config.splitting_mode == "LeaveOneGroupOut" and not self.is_multi_farm():
            raise ValueError(
                "LeaveOneGroupOut splitting requires multi-farm mode "
                "(a non-empty farms list must be specified)"
            )
        return self

    def is_multi_farm(self) -> bool:
        """Return True when a non-empty ``farms`` list is configured."""
        return bool(self.farms)