Coverage for wifa_uq / workflow_schema.py: 94%

127 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-19 02:10 +0000

1from pydantic import BaseModel, Field, model_validator 

2from typing import Literal, Optional, Union, Any 

3from pathlib import Path 

4 

5 

class ParamConfigDict(BaseModel):
    """Full parameter configuration with explicit fields.

    ``range`` holds the ``(min, max)`` bounds of the parameter. ``default`` and
    ``short_name`` are optional extras.

    Raises (at validation time):
        ValueError: if ``range`` is given with its lower bound above its upper
            bound, which would otherwise be accepted silently.
    """

    range: tuple[float, float]
    default: Optional[float] = None
    short_name: Optional[str] = None

    @model_validator(mode="after")
    def check_range_order(self):
        """Reject a ``range`` whose lower bound exceeds its upper bound."""
        low, high = self.range
        if low > high:
            raise ValueError(
                f"range must be ordered as (min, max), got ({low}, {high})."
            )
        return self

12 

13 

class PreprocessingConfig(BaseModel):
    """Settings controlling the optional preprocessing stage of the workflow."""

    # Master switch: preprocessing is skipped unless this is true.
    run: bool = False
    # Preprocessing steps to apply; "recalculate_params" is the only accepted value.
    steps: list[Literal["recalculate_params"]] = Field(default_factory=lambda: [])

19 

20 

class DatabaseGenConfig(BaseModel):
    """Configuration for database generation.

    ``param_config`` maps a parameter name to either a ``[min, max]`` list or a
    full ParamConfigDict (range/default/short_name).

    Raises (at validation time):
        ValueError: if a ``[min, max]`` shorthand entry does not contain exactly
            two values, or if its first value exceeds its second.
    """

    run: bool = False
    flow_model: Literal["pywake", "foxes"] = "pywake"
    n_samples: int = 100
    # Accepts either [min, max] list or full dict with range/default/short_name.
    # The list form is typed as list[float], so its length and ordering are
    # checked explicitly in validate_param_ranges below.
    param_config: dict[str, Union[ParamConfigDict, list[float]]] = Field(
        default_factory=dict
    )

    @model_validator(mode="after")
    def validate_param_ranges(self):
        """Ensure every ``[min, max]`` shorthand has two ordered values."""
        for name, spec in self.param_config.items():
            if not isinstance(spec, list):
                # Full ParamConfigDict entries carry their own validation.
                continue
            if len(spec) != 2:
                raise ValueError(
                    f"param_config['{name}'] must be a [min, max] list of two "
                    f"numbers, got {len(spec)} value(s)."
                )
            low, high = spec
            if low > high:
                raise ValueError(
                    f"param_config['{name}'] has min > max: [{low}, {high}]."
                )
        return self

31 

32 

class CrossValidationConfig(BaseModel):
    """Settings for cross-validating the error-prediction model."""

    # Cross-validation is skipped unless this is true.
    run: bool = False
    # Either shuffled k-fold splitting or leave-one-group-out splitting.
    splitting_mode: Literal["kfold_shuffled", "LeaveOneGroupOut"] = "kfold_shuffled"
    # Number of folds.
    n_splits: int = 5
    # Metrics to report; defaults to all three supported metrics.
    metrics: list[Literal["rmse", "r2", "mae"]] = Field(
        default_factory=lambda: ["rmse", "r2", "mae"]
    )
    # LeaveOneGroupOut groups: group name -> farm names belonging to that group.
    groups: Union[dict[str, list[str]], None] = None

44 

45 

class PCEModelParams(BaseModel):
    """Hyperparameters understood by the PCE error model."""

    # Polynomial degree.
    degree: int = 5
    # Marginal distribution choice for the inputs.
    marginals: Literal["kernel", "uniform", "normal"] = "kernel"
    # Dependence model between inputs.
    copula: Literal["independent", "normal"] = "independent"
    # NOTE(review): presumably the q-norm truncation parameter — confirm.
    q: float = 1.0
    # Feature-count cap; allow_high_dim presumably lifts it — confirm with caller.
    max_features: int = 5
    allow_high_dim: bool = False

55 

56 

class LinearModelParams(BaseModel):
    """Hyperparameters understood by the linear error model."""

    # Fitting variant: plain least squares or one of the regularised forms.
    method: Literal["ols", "ridge", "lasso", "elasticnet"] = "ols"
    # Regularisation strength.
    alpha: float = 1.0
    # L1/L2 mixing ratio; only relevant when method == "elasticnet".
    l1_ratio: float = 0.5

63 

64 

class XGBModelParams(BaseModel):
    """Hyperparameters understood by the XGBoost error model."""

    # Tree depth limit.
    max_depth: int = 3
    # Number of boosting rounds.
    n_estimators: int = 500
    # Step size shrinkage.
    learning_rate: float = 0.1
    # Seed; None leaves seeding unspecified.
    random_state: Union[int, None] = None

72 

73 

class SIRModelParams(BaseModel):
    """Hyperparameters understood by the SIR + polynomial error model."""

    # Number of SIR directions retained.
    n_directions: int = 1
    # Degree of the polynomial fitted on the reduced directions.
    degree: int = 2

79 

80 

class ErrorPredictionConfig(BaseModel):
    """Settings for error prediction: model, calibrator and cross-validation.

    A ``UserWarning`` is emitted during validation when ``local_regressor`` is
    set while ``calibrator`` is not ``"LocalParameterPredictor"``.
    """

    run: bool = False
    # Feature names for error prediction (required, no default).
    features: list[str]
    model: Literal["XGB", "PCE", "SIRPolynomial", "Linear"] = "XGB"
    # Free-form model parameters. This schema does NOT validate them against
    # the chosen model; the *ModelParams classes in this module list the
    # parameters each model type declares.
    model_params: dict[str, Any] = Field(default_factory=lambda: {})
    calibrator: Literal[
        "MinBiasCalibrator",
        "LocalParameterPredictor",
        "DefaultParams",
        "BayesianCalibration",
    ] = "MinBiasCalibrator"
    # Only used when calibrator is LocalParameterPredictor.
    local_regressor: Union[
        Literal["Ridge", "Linear", "Lasso", "ElasticNet", "RandomForest", "XGB"],
        None,
    ] = None
    local_regressor_params: dict[str, Any] = Field(default_factory=lambda: {})
    bias_predictor: Literal["BiasPredictor"] = "BiasPredictor"
    cross_validation: CrossValidationConfig = Field(
        default_factory=CrossValidationConfig
    )

    @model_validator(mode="after")
    def validate_local_calibrator(self):
        """Warn when local_regressor is set but would be ignored by the calibrator."""
        if self.local_regressor is None or self.calibrator == "LocalParameterPredictor":
            # Nothing to warn about: either no regressor, or it will be used.
            return self

        import warnings

        warnings.warn(
            f"local_regressor is set to '{self.local_regressor}' but calibrator is "
            f"'{self.calibrator}'. local_regressor is only used with LocalParameterPredictor."
        )
        return self

119 

120 

class PCESensitivityConfig(BaseModel):
    """PCE settings used specifically by the sensitivity analysis."""

    # Polynomial degree.
    degree: int = 5
    # Marginal distribution choice for the inputs.
    marginals: Literal["kernel", "uniform", "normal"] = "kernel"
    # Dependence model between inputs.
    copula: Literal["independent", "normal"] = "independent"
    # Note: default differs from PCEModelParams.q (1.0).
    q: float = 0.5

128 

129 

class SensitivityConfig(BaseModel):
    """Switches and method selection for sensitivity analysis."""

    # Independent switches for the two sensitivity analyses.
    run_observation_sensitivity: bool = False
    run_bias_sensitivity: bool = False
    # Analysis method; "auto" leaves the choice to the runner.
    method: Literal["auto", "shap", "pce_sobol", "sir"] = "auto"
    # Optional PCE-specific overrides.
    pce_config: Union[PCESensitivityConfig, None] = None

137 

138 

class PartialDependenceConfig(BaseModel):
    """Settings for the partial-dependence part of physics insights."""

    # On by default.
    enabled: bool = True
    # Number of grid points evaluated per feature.
    grid_resolution: int = 50

144 

145 

class InteractionsConfig(BaseModel):
    """Settings for the feature-interaction part of physics insights."""

    # On by default.
    enabled: bool = True
    # How many of the strongest interactions to report.
    top_n: int = 5

151 

152 

class RegimeAnalysisConfig(BaseModel):
    """Configuration for error regime identification.

    Out-of-range values are rejected at validation time rather than surfacing
    later from the clustering / percentile computations.
    """

    enabled: bool = True
    # Number of regimes (clusters) to identify; at least one is required.
    n_clusters: int = Field(default=3, ge=1)
    # Focus on cases above this percentile; a percentile must lie in [0, 100].
    bias_percentile: float = Field(default=75.0, ge=0.0, le=100.0)

159 

160 

class ParameterRelationshipsConfig(BaseModel):
    """Settings for relating calibrated parameters to operating conditions."""

    # On by default; only meaningful with the LocalParameterPredictor calibrator.
    enabled: bool = True

165 

166 

class PhysicsInsightsConfig(BaseModel):
    """
    Configuration for physics insights analysis.

    Extracts interpretable physical understanding from ML bias correction models,
    answering questions like "why does the wake model fail in certain conditions?"

    Outputs:
        - partial_dependence.png: How bias varies with each atmospheric feature
        - feature_interactions.png: Which feature combinations jointly drive error
        - error_regimes.png: Distinct failure mode clusters
        - parameter_relationships.png: How optimal params vary with conditions
        - physics_insights_report.md: Human-readable summary
        - physics_insights.json: Machine-readable results
    """

    # Master switch for the whole physics-insights stage.
    run: bool = False
    # Per-analysis sub-configurations; each defaults to its own class defaults.
    partial_dependence: PartialDependenceConfig = Field(
        default_factory=lambda: PartialDependenceConfig()
    )
    interactions: InteractionsConfig = Field(
        default_factory=lambda: InteractionsConfig()
    )
    regime_analysis: RegimeAnalysisConfig = Field(
        default_factory=lambda: RegimeAnalysisConfig()
    )
    parameter_relationships: ParameterRelationshipsConfig = Field(
        default_factory=lambda: ParameterRelationshipsConfig()
    )

192 

193 

class FarmConfig(BaseModel):
    """One entry of the ``farms`` list used in multi-farm mode."""

    # Identifier of the farm.
    name: str
    # Path to the farm's windIO system configuration (required).
    system_config: Path
    # Optional overrides; when left as None they are inferred from the windIO structure.
    reference_power: Union[Path, None] = None
    reference_resource: Union[Path, None] = None
    wind_farm_layout: Union[Path, None] = None

203 

204 

class PathsConfig(BaseModel):
    """Input paths (single-farm mode) and output locations (both modes)."""

    # --- Inputs ---
    # Needed in single-farm mode; leave unset in multi-farm mode.
    system_config: Union[Path, None] = None
    # Optional overrides; when left as None they are inferred from the windIO structure.
    reference_power: Union[Path, None] = None
    reference_resource: Union[Path, None] = None
    wind_farm_layout: Union[Path, None] = None

    # --- Outputs ---
    output_dir: Path
    processed_resource_file: str = "processed_physical_inputs.nc"
    database_file: str = "results_stacked_hh.nc"

218 

219 

class WifaUQConfig(BaseModel):
    """
    Main WIFA-UQ workflow configuration.

    Supports two modes:
    1. Single-farm: Specify paths.system_config (other paths auto-inferred)
    2. Multi-farm: Specify farms list with each farm's system_config

    Exactly one mode must be selected. Setting both ``paths.system_config`` and
    a non-empty ``farms`` list, or neither, is rejected at validation time.

    When cross-validation uses ``LeaveOneGroupOut``, multi-farm mode is required.
    Any ``cross_validation.groups`` entries must name farms from ``farms``, with
    each farm assigned to at most one group.

    Example single-farm config:
        paths:
          system_config: wind_energy_system.yaml
          output_dir: results/
        preprocessing:
          run: true
          steps: [recalculate_params]
        ...

    Example multi-farm config:
        paths:
          output_dir: results/multi_farm/
        farms:
          - name: Farm1
            system_config: farm1/system.yaml
          - name: Farm2
            system_config: farm2/system.yaml
        ...
    """

    description: Optional[str] = None
    paths: PathsConfig
    # For multi-farm mode
    farms: Optional[list[FarmConfig]] = None
    preprocessing: PreprocessingConfig = Field(default_factory=PreprocessingConfig)
    database_gen: DatabaseGenConfig = Field(default_factory=DatabaseGenConfig)
    error_prediction: ErrorPredictionConfig
    sensitivity_analysis: SensitivityConfig = Field(default_factory=SensitivityConfig)
    physics_insights: PhysicsInsightsConfig = Field(
        default_factory=PhysicsInsightsConfig
    )

    @model_validator(mode="after")
    def check_paths_or_farms(self):
        """Validate that exactly one of single-farm or multi-farm mode is configured.

        Raises:
            ValueError: if neither or both modes are specified.
        """
        # Reuse is_multi_farm() so "multi-farm" means the same thing everywhere.
        is_multi_farm = self.is_multi_farm()
        is_single_farm = self.paths.system_config is not None

        if not is_multi_farm and not is_single_farm:
            raise ValueError(
                "Configuration must specify either:\n"
                "  - paths.system_config (for single-farm mode), or\n"
                "  - farms list (for multi-farm mode)"
            )

        if is_multi_farm and is_single_farm:
            raise ValueError(
                "Cannot specify both paths.system_config and farms list. "
                "Choose single-farm or multi-farm mode."
            )

        return self

    @model_validator(mode="after")
    def check_logo_groups(self):
        """Validate the LeaveOneGroupOut setup against the farms list.

        Raises:
            ValueError: if LeaveOneGroupOut is requested outside multi-farm mode,
                if a group lists a farm name missing from ``farms``, or if one
                farm is assigned to more than one group.
        """
        cv_config = self.error_prediction.cross_validation
        if cv_config.splitting_mode != "LeaveOneGroupOut":
            return self

        # An empty farms list is not multi-farm mode, so it must be rejected too.
        if not self.is_multi_farm():
            raise ValueError(
                "LeaveOneGroupOut splitting requires multi-farm mode "
                "(a non-empty farms list must be specified)"
            )

        if cv_config.groups is None:
            # Groups are optional - if not specified, each farm becomes its own group
            return self

        known_farms = {farm.name for farm in self.farms}
        assigned_group: dict[str, str] = {}
        for group_name, members in cv_config.groups.items():
            for farm_name in members:
                if farm_name not in known_farms:
                    raise ValueError(
                        f"cross_validation.groups['{group_name}'] references unknown "
                        f"farm '{farm_name}'. Known farms: {sorted(known_farms)}"
                    )
                # Listing a farm twice in the same group is harmless; placing it
                # in two different groups makes the held-out split ambiguous.
                first_group = assigned_group.setdefault(farm_name, group_name)
                if first_group != group_name:
                    raise ValueError(
                        f"Farm '{farm_name}' is assigned to multiple "
                        f"cross_validation groups: '{first_group}' and '{group_name}'."
                    )
        return self

    def is_multi_farm(self) -> bool:
        """Check if this is a multi-farm configuration (non-empty farms list)."""
        return self.farms is not None and len(self.farms) > 0