Coverage for wifa_uq / workflow.py: 93%

201 statements  

coverage.py v7.13.0, created at 2025-12-19 02:10 +0000

# wifa_uq/workflow.py

"""
Main workflow orchestration for WIFA-UQ.

Supports single-farm and multi-farm configurations with optional physics insights.
"""

import yaml
import xarray as xr
from pathlib import Path
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
import xgboost as xgb
import numpy as np

from wifa_uq.model_error_database.multi_farm_gen import generate_multi_farm_database
from wifa_uq.preprocessing.preprocessing import PreprocessingInputs
from wifa_uq.model_error_database.database_gen import DatabaseGenerator
from wifa_uq.postprocessing.error_predictor.error_predictor import (
    BiasPredictor,
    MainPipeline,
    PCERegressor,
    SIRPolynomialRegressor,
    run_cross_validation,
    run_observation_sensitivity,
)

# from wifa_uq.postprocessing.bayesian_calibration import BayesianCalibrationWrapper
from wifa_uq.postprocessing.calibration import (
    MinBiasCalibrator,
    DefaultParams,
    LocalParameterPredictor,
)

# --- Dynamic Class Loading ---
CLASS_MAP = {
    # Calibrators
    "MinBiasCalibrator": MinBiasCalibrator,
    "DefaultParams": DefaultParams,
    "LocalParameterPredictor": LocalParameterPredictor,
    # Bayesian
    # "BayesianCalibration": BayesianCalibrationWrapper,
    # Predictors
    "BiasPredictor": BiasPredictor,
    # ML Models
    "XGBRegressor": xgb.XGBRegressor,
    "SIRPolynomialRegressor": SIRPolynomialRegressor,
}

CALIBRATION_MODES = {
    "MinBiasCalibrator": "global",
    "DefaultParams": "global",
    "LocalParameterPredictor": "local",
    "BayesianCalibration": "global",
}
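
# Illustrative resolution (the name "MinBiasCalibrator" is just an example
# config value): a class name string from the YAML config maps to a class via
# CLASS_MAP and to a calibration mode via CALIBRATION_MODES; names missing
# from CALIBRATION_MODES fall back to "global" where it is looked up below.
#
#   Calibrator_cls = get_class_from_map("MinBiasCalibrator")
#   mode = CALIBRATION_MODES.get("MinBiasCalibrator", "global")  # -> "global"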


def get_class_from_map(class_name: str):
    if class_name not in CLASS_MAP:
        raise ValueError(
            f"Unknown class '{class_name}' in config. "
            f"Available classes are: {list(CLASS_MAP.keys())}"
        )
    return CLASS_MAP[class_name]


def build_predictor_pipeline(model_name: str, model_params: dict | None = None):
    """
    Factory function to build the predictor pipeline based on config.

    Returns the pipeline and a 'model_type' string for SHAP logic.
    """
    if model_params is None:
        model_params = {}

    if model_name == "Linear":
        from wifa_uq.postprocessing.error_predictor.error_predictor import (
            LinearRegressor,
        )

        print(
            f"Building Linear Regressor (method={model_params.get('method', 'ols')})..."
        )
        pipeline = LinearRegressor(**model_params)
        model_type = "linear"
        return pipeline, model_type

    if model_name == "XGB":
        print("Building XGBoost Regressor pipeline...")
        xgb_params = {
            "max_depth": model_params.get("max_depth", 3),
            "n_estimators": model_params.get("n_estimators", 500),
            "learning_rate": model_params.get("learning_rate", 0.1),
            "random_state": model_params.get("random_state", 42),
        }
        pipeline = Pipeline(
            [
                ("scaler", StandardScaler()),
                ("model", xgb.XGBRegressor(**xgb_params)),
            ]
        )
        model_type = "tree"
    elif model_name == "SIRPolynomial":
        print("Building SIR+Polynomial Regressor pipeline...")
        pipeline = SIRPolynomialRegressor(n_directions=1, degree=2)
        model_type = "sir"
    elif model_name == "PCE":
        print("Building PCE Regressor pipeline...")
        pipeline = PCERegressor(**model_params)
        model_type = "pce"
    else:
        raise ValueError(
            f"Unknown model '{model_name}' in config. "
            "Available models are: ['XGB', 'SIRPolynomial', 'PCE', 'Linear']"
        )
    return pipeline, model_type
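
# Illustrative usage (X_train, y_train, X_test are hypothetical arrays, not
# defined in this module): the factory returns an estimator plus a model_type
# tag that downstream SHAP logic uses to pick an explainer.
#
#   pipeline, model_type = build_predictor_pipeline(
#       "XGB", {"max_depth": 4, "n_estimators": 200}
#   )
#   pipeline.fit(X_train, y_train)    # StandardScaler -> XGBRegressor
#   y_hat = pipeline.predict(X_test)  # model_type == "tree"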


def _is_multi_farm_config(config: dict) -> bool:
    """Check if the config specifies multiple farms."""
    return "farms" in config.get("paths", {}) or "farms" in config
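
# Both config shapes (illustrative values) are detected as multi-farm:
#
#   _is_multi_farm_config({"farms": [{"name": "farm_a"}]})           # True
#   _is_multi_farm_config({"paths": {"farms": []}})                  # True
#   _is_multi_farm_config({"paths": {"system_config": "sys.yaml"}})  # False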


def _validate_farm_configs(farms: list[dict]) -> None:
    """
    Validate farm configurations.

    Each farm must have:
    - name: Unique identifier for cross-validation grouping
    - system_config: Path to wind energy system YAML
    """
    required_keys = {"name", "system_config"}
    names_seen = set()

    for i, farm in enumerate(farms):
        # Check required keys
        missing = required_keys - set(farm.keys())
        if missing:
            raise ValueError(
                f"Farm #{i + 1} is missing required keys: {missing}. "
                "Each farm must have 'name' and 'system_config'."
            )

        # Check for duplicate names
        name = farm["name"]
        if name in names_seen:
            raise ValueError(
                f"Duplicate farm name: '{name}'. Each farm must have a unique name."
            )
        names_seen.add(name)

    print(f"Validated {len(farms)} farm configurations")


def _resolve_farm_paths(farm_config: dict, base_dir: Path) -> dict:
    """
    Resolve relative paths in a farm config to absolute paths.

    Required keys:
    - name: Farm identifier (passed through as-is)
    - system_config: Path to wind energy system YAML

    Optional keys (for explicit path overrides):
    - reference_power: Path to reference power NetCDF
    - reference_resource: Path to reference resource NetCDF
    - wind_farm_layout: Path to wind farm layout YAML
    """
    resolved = {"name": farm_config["name"]}

    path_keys = [
        "system_config",
        "reference_power",
        "reference_resource",
        "wind_farm_layout",
    ]

    for key in path_keys:
        if key in farm_config:
            resolved[key] = base_dir / farm_config[key]

    return resolved
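
# Example (hypothetical paths): with base_dir = Path("/case"), the entry
#   {"name": "farm_a", "system_config": "farm_a/system.yaml"}
# resolves to
#   {"name": "farm_a", "system_config": Path("/case/farm_a/system.yaml")}
# Optional path keys that are absent are simply left out of the result.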


def run_workflow(config_path: str | Path):
    """
    Runs the full WIFA-UQ workflow from a configuration file.

    Supports both single-farm and multi-farm configurations.
    """
    config_path = Path(config_path).resolve()
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    base_dir = config_path.parent

    # Detect single vs multi-farm mode
    is_multi_farm = _is_multi_farm_config(config)

    if is_multi_farm:
        return _run_multi_farm_workflow(config, base_dir)
    else:
        return _run_single_farm_workflow(config, base_dir)
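
# Typical entry point (the config path is hypothetical):
#
#   from wifa_uq.workflow import run_workflow
#   cv_df, y_preds, y_tests = run_workflow("configs/my_case.yaml")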


def _run_single_farm_workflow(config: dict, base_dir: Path):
    """
    Single-farm workflow: resolves paths, runs preprocessing and database
    generation, then hands off to the shared error prediction step.
    """
    from wifa_uq.model_error_database.path_inference import (
        infer_paths_from_system_config,
        validate_required_paths,
    )

    # --- 0. Resolve Paths ---
    paths_config = config["paths"]
    system_yaml_path = base_dir / paths_config["system_config"]

    # Build an explicit paths dict for any paths that were provided
    explicit_paths = {}
    for key in ["reference_power", "reference_resource", "wind_farm_layout"]:
        if key in paths_config and paths_config[key] is not None:
            explicit_paths[key] = base_dir / paths_config[key]

    # Infer missing paths from the windIO structure
    resolved_paths = infer_paths_from_system_config(
        system_config_path=system_yaml_path,
        explicit_paths=explicit_paths,
    )

    # Validate that all required paths exist
    validate_required_paths(resolved_paths)

    # Extract resolved paths
    ref_power_path = resolved_paths["reference_power"]
    ref_resource_path = resolved_paths["reference_resource"]
    wf_layout_path = resolved_paths["wind_farm_layout"]

    output_dir = base_dir / paths_config["output_dir"]
    output_dir.mkdir(parents=True, exist_ok=True)
    processed_resource_path = output_dir / paths_config["processed_resource_file"]
    database_path = output_dir / paths_config["database_file"]

    print(f"Resolved output directory: {output_dir}")
    print(f"Resolved reference_power: {ref_power_path}")
    print(f"Resolved reference_resource: {ref_resource_path}")
    print(f"Resolved wind_farm_layout: {wf_layout_path}")
    print("Running in SINGLE-FARM mode")

    # === 1. PREPROCESSING STEP ===
    if config["preprocessing"]["run"]:
        print("--- Running Preprocessing ---")
        preprocessor = PreprocessingInputs(
            ref_resource_path=ref_resource_path,
            output_path=processed_resource_path,
            steps=config["preprocessing"].get("steps", []),
        )
        preprocessor.run_pipeline()
        print("Preprocessing complete.")
    else:
        print("--- Skipping Preprocessing (as per config) ---")
        processed_resource_path = ref_resource_path
        if not processed_resource_path.exists():
            raise FileNotFoundError(
                f"Input resource file not found: {processed_resource_path}"
            )
        print(f"Using raw resource file: {processed_resource_path.name}")

    # === 2. DATABASE GENERATION STEP ===
    if config["database_gen"]["run"]:
        print("--- Running Database Generation ---")
        param_config = config["database_gen"]["param_config"]

        db_generator = DatabaseGenerator(
            nsamples=config["database_gen"]["n_samples"],
            param_config=param_config,
            system_yaml_path=system_yaml_path,
            ref_power_path=ref_power_path,
            processed_resource_path=processed_resource_path,
            wf_layout_path=wf_layout_path,
            output_db_path=database_path,
            model=config["database_gen"]["flow_model"],
        )
        database = db_generator.generate_database()
        print("Database generation complete.")
    else:
        print("--- Loading Existing Database (as per config) ---")
        if not database_path.exists():
            raise FileNotFoundError(
                f"Database file not found at {database_path}. "
                "Set 'database_gen.run = true' to generate it."
            )
        database = xr.load_dataset(database_path)
        print(f"Database loaded from {database_path}")

    # Continue with error prediction...
    return _run_error_prediction(config, database, output_dir)
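
# A minimal single-farm config, sketched as the dict yaml.safe_load would
# return (file names and values are illustrative; the shape of param_config
# is whatever DatabaseGenerator expects):
#
#   {
#       "paths": {
#           "system_config": "wind_energy_system.yaml",
#           "output_dir": "results",
#           "processed_resource_file": "resource_processed.nc",
#           "database_file": "database.nc",
#       },
#       "preprocessing": {"run": True, "steps": []},
#       "database_gen": {
#           "run": True,
#           "n_samples": 100,
#           "param_config": {...},
#           "flow_model": "...",
#       },
#       "error_prediction": {...},  # consumed by _run_error_prediction below
#   }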


def _run_multi_farm_workflow(config: dict, base_dir: Path):
    """
    Multi-farm workflow: processes multiple farms and combines the results.
    """
    paths_config = config.get("paths", {})

    # Get farm configurations ("farms" may live at the top level or under "paths")
    farms_config = config.get("farms") or paths_config.get("farms")
    if not farms_config:
        raise ValueError("Multi-farm config requires a 'farms' list")

    # Validate farm configs
    _validate_farm_configs(farms_config)

    # Resolve output directory
    output_dir = base_dir / paths_config.get("output_dir", "multi_farm_results")
    output_dir.mkdir(parents=True, exist_ok=True)

    database_filename = paths_config.get("database_file", "results_stacked_hh.nc")
    database_path = output_dir / database_filename

    print(f"Resolved output directory: {output_dir}")
    print(f"Running in MULTI-FARM mode with {len(farms_config)} farms")

    # Resolve paths for each farm
    resolved_farms = [_resolve_farm_paths(farm, base_dir) for farm in farms_config]

    # Print farm summary
    print("\nFarms to process:")
    for farm in resolved_farms:
        print(f"  - {farm['name']}: {farm['system_config']}")

    # === DATABASE GENERATION ===
    if config["database_gen"]["run"]:
        print("\n--- Running Multi-Farm Database Generation ---")

        database = generate_multi_farm_database(
            farm_configs=resolved_farms,
            param_config=config["database_gen"]["param_config"],
            n_samples=config["database_gen"]["n_samples"],
            output_dir=output_dir,
            database_file=database_filename,
            model=config["database_gen"]["flow_model"],
            preprocessing_steps=config["preprocessing"].get("steps", []),
            run_preprocessing=config["preprocessing"].get("run", True),
        )

        print("Multi-farm database generation complete.")
    else:
        print("--- Loading Existing Database ---")
        if not database_path.exists():
            raise FileNotFoundError(
                f"Database file not found at {database_path}. "
                "Set 'database_gen.run = true' to generate it."
            )
        database = xr.load_dataset(database_path)
        print(f"Database loaded from {database_path}")
        print(f"Contains {len(np.unique(database.wind_farm.values))} farms")

    # Continue with error prediction...
    return _run_error_prediction(config, database, output_dir)
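
# The multi-farm variant replaces the single system_config with a "farms"
# list; "output_dir" and "database_file" fall back to defaults when omitted
# (names below are illustrative):
#
#   "paths": {
#       "output_dir": "multi_farm_results",
#       "farms": [
#           {"name": "farm_a", "system_config": "farm_a/system.yaml"},
#           {"name": "farm_b", "system_config": "farm_b/system.yaml"},
#       ],
#   }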


def _run_error_prediction(config: dict, database: xr.Dataset, output_dir: Path):
    """
    Run error prediction, sensitivity analysis, and physics insights.

    Shared between single-farm and multi-farm workflows.
    """
    sa_config = config.get("sensitivity_analysis", {})
    err_config = config["error_prediction"]
    physics_config = config.get("physics_insights", {})
    model_name = err_config.get("model", "XGB")
    model_params = err_config.get("model_params", {})

    # --- OBSERVATION SENSITIVITY ---
    if sa_config.get("run_observation_sensitivity", False):
        print(f"--- Running Observation Sensitivity for model: {model_name} ---")
        obs_pipeline, obs_model_type = build_predictor_pipeline(
            model_name, model_params
        )

        run_observation_sensitivity(
            database=database,
            features_list=err_config["features"],
            ml_pipeline=obs_pipeline,
            model_type=obs_model_type,
            output_dir=output_dir,
            method=sa_config.get("method", "auto"),
            pce_config=sa_config.get("pce_config", {}),
        )
    else:
        print("--- Skipping Observation Sensitivity (as per config) ---")

    # === ERROR PREDICTION / UQ STEP ===
    fitted_model = None
    fitted_calibrator = None
    y_bias_all = None

    if err_config["run"]:
        print("--- Running Error Prediction ---")

        ml_pipeline, model_type = build_predictor_pipeline(model_name, model_params)

        calibrator_name = err_config["calibrator"]
        calibration_mode = CALIBRATION_MODES.get(calibrator_name, "global")
        Calibrator_cls = get_class_from_map(calibrator_name)
        Predictor_cls = get_class_from_map(err_config["bias_predictor"])
        MainPipeline_cls = MainPipeline

        print(
            f"Running cross-validation with calibrator: {Calibrator_cls.__name__} "
            f"and predictor: {model_name}"
        )

        cv_df, y_preds, y_tests = run_cross_validation(
            xr_data=database,
            ML_pipeline=ml_pipeline,
            model_type=model_type,
            Calibrator_cls=Calibrator_cls,
            BiasPredictor_cls=Predictor_cls,
            MainPipeline_cls=MainPipeline_cls,
            cv_config=err_config["cross_validation"],
            features_list=err_config["features"],
            output_dir=output_dir,
            sa_config=sa_config,
            calibration_mode=calibration_mode,
            local_regressor=err_config.get("local_regressor"),
            local_regressor_params=err_config.get("local_regressor_params", {}),
        )

        print("--- Cross-Validation Results (mean) ---")
        print(cv_df.mean())

        # Save results
        cv_df.to_csv(output_dir / "cv_results.csv")
        np.savez(
            output_dir / "predictions.npz",
            y_preds=np.array(y_preds, dtype=object),
            y_tests=np.array(y_tests, dtype=object),
        )
        print(f"Results saved to {output_dir}")

        # --- FIT FINAL MODEL FOR PHYSICS INSIGHTS ---
        # Re-fit on all data for the physics insights analysis
        if physics_config.get("run", False):
            print("--- Fitting Final Model for Physics Insights ---")

            # Build a fresh pipeline
            final_pipeline, _ = build_predictor_pipeline(model_name, model_params)

            # Prepare the full dataset
            X_df = database.isel(sample=0).to_dataframe().reset_index()
            features = err_config["features"]
            X = X_df[features]

            # Fit the calibrator on the full data
            if calibration_mode == "local":
                fitted_calibrator = Calibrator_cls(
                    database,
                    feature_names=features,
                    regressor_name=err_config.get("local_regressor"),
                    regressor_params=err_config.get("local_regressor_params", {}),
                )
            else:
                fitted_calibrator = Calibrator_cls(database)
            fitted_calibrator.fit()

            # Get bias values at the calibrated parameters
            if calibration_mode == "local":
                optimal_indices = fitted_calibrator.get_optimal_indices()
                y_bias_all = np.array(
                    [
                        float(
                            database["model_bias_cap"]
                            .isel(case_index=i, sample=idx)
                            .values
                        )
                        for i, idx in enumerate(optimal_indices)
                    ]
                )
            else:
                best_idx = fitted_calibrator.best_idx_
                y_bias_all = database["model_bias_cap"].sel(sample=best_idx).values

            # Fit the final model
            final_pipeline.fit(X, y_bias_all)
            fitted_model = final_pipeline

        # === PHYSICS INSIGHTS ===
        if physics_config.get("run", False) and fitted_model is not None:
            print("--- Running Physics Insights Analysis ---")
            from wifa_uq.postprocessing.physics_insights import run_physics_insights

            insights_dir = output_dir / "physics_insights"
            insights_dir.mkdir(exist_ok=True)

            run_physics_insights(
                database=database,
                fitted_model=fitted_model,
                calibrator=fitted_calibrator,
                features_list=err_config["features"],
                y_bias=y_bias_all,
                output_dir=insights_dir,
                config=physics_config,
            )

            print("Physics insights analysis complete.")

        print("--- Workflow complete ---")
        return cv_df, y_preds, y_tests

    print("--- Workflow complete (no error prediction) ---")
    return None, None, None
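
# The error_prediction block read above might look like this (values are
# illustrative; "calibrator" and "bias_predictor" must name CLASS_MAP entries,
# and "model" must be one of 'XGB', 'SIRPolynomial', 'PCE', 'Linear'):
#
#   "error_prediction": {
#       "run": True,
#       "model": "XGB",
#       "model_params": {"max_depth": 3},
#       "calibrator": "MinBiasCalibrator",
#       "bias_predictor": "BiasPredictor",
#       "cross_validation": {...},
#       "features": [...],
#   }


if __name__ == "__main__":
    # Minimal CLI sketch, added for illustration: run the workflow against a
    # YAML config supplied on the command line.
    import sys

    run_workflow(sys.argv[1])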