Coverage for wifa_uq/workflow.py: 93%
201 statements
coverage.py v7.13.0, created at 2025-12-19 02:10 +0000
# wifa_uq/workflow.py
"""
Main workflow orchestration for WIFA-UQ.

Supports single-farm and multi-farm configurations with optional physics insights.
"""

import yaml
import xarray as xr
from pathlib import Path
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
import xgboost as xgb
import numpy as np

from wifa_uq.model_error_database.multi_farm_gen import generate_multi_farm_database
from wifa_uq.preprocessing.preprocessing import PreprocessingInputs
from wifa_uq.model_error_database.database_gen import DatabaseGenerator
from wifa_uq.postprocessing.error_predictor.error_predictor import (
    BiasPredictor,
    MainPipeline,
    PCERegressor,
    run_cross_validation,
    run_observation_sensitivity,
    SIRPolynomialRegressor,
)

# from wifa_uq.postprocessing.bayesian_calibration import BayesianCalibrationWrapper
from wifa_uq.postprocessing.calibration import (
    MinBiasCalibrator,
    DefaultParams,
    LocalParameterPredictor,
)

# --- Dynamic Class Loading ---
CLASS_MAP = {
    # Calibrators
    "MinBiasCalibrator": MinBiasCalibrator,
    "DefaultParams": DefaultParams,
    "LocalParameterPredictor": LocalParameterPredictor,
    # Bayesian
    # "BayesianCalibration": BayesianCalibrationWrapper,
    # Predictors
    "BiasPredictor": BiasPredictor,
    # ML Models
    "XGBRegressor": xgb.XGBRegressor,
    "SIRPolynomialRegressor": SIRPolynomialRegressor,
}

CALIBRATION_MODES = {
    "MinBiasCalibrator": "global",
    "DefaultParams": "global",
    "LocalParameterPredictor": "local",
    "BayesianCalibration": "global",
}
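
# A hedged sketch of how a YAML config might reference these registries; the
# keys shown ('calibrator', 'bias_predictor') are the ones looked up via
# get_class_from_map in _run_error_prediction below:
#
#   error_prediction:
#     calibrator: MinBiasCalibrator     # key into CLASS_MAP
#     bias_predictor: BiasPredictor     # key into CLASS_MAP
#
# CALIBRATION_MODES then decides whether that calibrator fits one global
# parameter set or a per-case ("local") parameter model.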


def get_class_from_map(class_name: str):
    if class_name not in CLASS_MAP:
        raise ValueError(
            f"Unknown class '{class_name}' in config. "
            f"Available classes are: {list(CLASS_MAP.keys())}"
        )
    return CLASS_MAP[class_name]
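
# Usage sketch: get_class_from_map("MinBiasCalibrator") returns the class
# object itself (not an instance); callers instantiate it later. An unknown
# name raises a ValueError listing the available CLASS_MAP keys.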


def build_predictor_pipeline(model_name: str, model_params: dict | None = None):
    """
    Factory function to build the predictor pipeline based on config.
    Returns the pipeline and a 'model_type' string for SHAP logic.
    """
    if model_params is None:
        model_params = {}

    if model_name == "Linear":
        from wifa_uq.postprocessing.error_predictor.error_predictor import (
            LinearRegressor,
        )

        print(
            f"Building Linear Regressor (method={model_params.get('method', 'ols')})..."
        )
        pipeline = LinearRegressor(**model_params)
        model_type = "linear"
        return pipeline, model_type

    if model_name == "XGB":
        print("Building XGBoost Regressor pipeline...")
        xgb_params = {
            "max_depth": model_params.get("max_depth", 3),
            "n_estimators": model_params.get("n_estimators", 500),
            "learning_rate": model_params.get("learning_rate", 0.1),
            "random_state": model_params.get("random_state", 42),
        }
        pipeline = Pipeline(
            [
                ("scaler", StandardScaler()),
                ("model", xgb.XGBRegressor(**xgb_params)),
            ]
        )
        model_type = "tree"

    elif model_name == "SIRPolynomial":
        print("Building SIR+Polynomial Regressor pipeline...")
        pipeline = SIRPolynomialRegressor(n_directions=1, degree=2)
        model_type = "sir"

    elif model_name == "PCE":
        print("Building PCE Regressor pipeline...")
        pipeline = PCERegressor(**model_params)
        model_type = "pce"
    else:
        raise ValueError(
            f"Unknown model '{model_name}' in config. "
            f"Available models are: ['XGB', 'SIRPolynomial', 'PCE', 'Linear']"
        )
    return pipeline, model_type
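
# Minimal usage sketch (parameter values are illustrative; the names match
# the branches above):
#
#   pipeline, model_type = build_predictor_pipeline(
#       "XGB", {"max_depth": 4, "n_estimators": 200}
#   )
#   # model_type == "tree"; pipeline is a StandardScaler + XGBRegressor
#   # sklearn Pipeline, ready for pipeline.fit(X, y) / pipeline.predict(X).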


def _is_multi_farm_config(config: dict) -> bool:
    """Check if config specifies multiple farms."""
    return "farms" in config.get("paths", {}) or "farms" in config


def _validate_farm_configs(farms: list[dict]) -> None:
    """
    Validate farm configurations.

    Each farm must have:
    - name: Unique identifier for cross-validation grouping
    - system_config: Path to wind energy system YAML
    """
    required_keys = {"name", "system_config"}
    names_seen = set()

    for i, farm in enumerate(farms):
        # Check required keys
        missing = required_keys - set(farm.keys())
        if missing:
            raise ValueError(
                f"Farm #{i + 1} is missing required keys: {missing}. "
                f"Each farm must have 'name' and 'system_config'."
            )

        # Check for duplicate names
        name = farm["name"]
        if name in names_seen:
            raise ValueError(
                f"Duplicate farm name: '{name}'. Each farm must have a unique name."
            )
        names_seen.add(name)

    print(f"Validated {len(farms)} farm configurations")


def _resolve_farm_paths(farm_config: dict, base_dir: Path) -> dict:
    """
    Resolve relative paths in a farm config to absolute paths.

    Required keys:
    - name: Farm identifier (passed through as-is)
    - system_config: Path to wind energy system YAML

    Optional keys (for explicit path overrides):
    - reference_power: Path to reference power NetCDF
    - reference_resource: Path to reference resource NetCDF
    - wind_farm_layout: Path to wind farm layout YAML
    """
    resolved = {"name": farm_config["name"]}

    path_keys = [
        "system_config",
        "reference_power",
        "reference_resource",
        "wind_farm_layout",
    ]

    for key in path_keys:
        if key in farm_config:
            resolved[key] = base_dir / farm_config[key]

    return resolved
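
# Illustrative example: with base_dir=Path("/project") and
# farm_config={"name": "farm_a", "system_config": "farm_a/system.yaml"},
# the result is {"name": "farm_a",
#                "system_config": Path("/project/farm_a/system.yaml")}.
# Optional keys absent from farm_config are simply omitted from the result.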


def run_workflow(config_path: str | Path):
    """
    Runs the full WIFA-UQ workflow from a configuration file.

    Supports both single-farm and multi-farm configurations.
    """
    config_path = Path(config_path).resolve()
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    base_dir = config_path.parent

    # Detect single vs multi-farm mode
    is_multi_farm = _is_multi_farm_config(config)

    if is_multi_farm:
        return _run_multi_farm_workflow(config, base_dir)
    else:
        return _run_single_farm_workflow(config, base_dir)
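
# Typical entry point (hedged sketch; the config filename is illustrative):
#
#   from wifa_uq.workflow import run_workflow
#   cv_df, y_preds, y_tests = run_workflow("configs/my_case.yaml")
#
# Note that relative paths inside the YAML resolve against the config file's
# own directory (base_dir), not the current working directory.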


def _run_single_farm_workflow(config: dict, base_dir: Path):
    """
    Original single-farm workflow (existing implementation).
    """
    from wifa_uq.model_error_database.path_inference import (
        infer_paths_from_system_config,
        validate_required_paths,
    )

    # --- 0. Resolve Paths ---
    paths_config = config["paths"]
    system_yaml_path = base_dir / paths_config["system_config"]

    # Build explicit paths dict for any paths that were provided
    explicit_paths = {}
    for key in ["reference_power", "reference_resource", "wind_farm_layout"]:
        if key in paths_config and paths_config[key] is not None:
            explicit_paths[key] = base_dir / paths_config[key]

    # Infer missing paths from windIO structure
    resolved_paths = infer_paths_from_system_config(
        system_config_path=system_yaml_path,
        explicit_paths=explicit_paths,
    )

    # Validate all required paths exist
    validate_required_paths(resolved_paths)

    # Extract resolved paths
    ref_power_path = resolved_paths["reference_power"]
    ref_resource_path = resolved_paths["reference_resource"]
    wf_layout_path = resolved_paths["wind_farm_layout"]

    output_dir = base_dir / paths_config["output_dir"]
    output_dir.mkdir(parents=True, exist_ok=True)
    processed_resource_path = output_dir / paths_config["processed_resource_file"]
    database_path = output_dir / paths_config["database_file"]

    print(f"Resolved output directory: {output_dir}")
    print(f"Resolved reference_power: {ref_power_path}")
    print(f"Resolved reference_resource: {ref_resource_path}")
    print(f"Resolved wind_farm_layout: {wf_layout_path}")
    print("Running in SINGLE-FARM mode")

    # === 1. PREPROCESSING STEP ===
    if config["preprocessing"]["run"]:
        print("--- Running Preprocessing ---")
        preprocessor = PreprocessingInputs(
            ref_resource_path=ref_resource_path,
            output_path=processed_resource_path,
            steps=config["preprocessing"].get("steps", []),
        )
        preprocessor.run_pipeline()
        print("Preprocessing complete.")
    else:
        print("--- Skipping Preprocessing (as per config) ---")
        processed_resource_path = ref_resource_path
        if not processed_resource_path.exists():
            raise FileNotFoundError(
                f"Input resource file not found: {processed_resource_path}"
            )
        print(f"Using raw resource file: {processed_resource_path.name}")

    # === 2. DATABASE GENERATION STEP ===
    if config["database_gen"]["run"]:
        print("--- Running Database Generation ---")
        param_config = config["database_gen"]["param_config"]

        db_generator = DatabaseGenerator(
            nsamples=config["database_gen"]["n_samples"],
            param_config=param_config,
            system_yaml_path=system_yaml_path,
            ref_power_path=ref_power_path,
            processed_resource_path=processed_resource_path,
            wf_layout_path=wf_layout_path,
            output_db_path=database_path,
            model=config["database_gen"]["flow_model"],
        )
        database = db_generator.generate_database()
        print("Database generation complete.")
    else:
        print("--- Loading Existing Database (as per config) ---")
        if not database_path.exists():
            raise FileNotFoundError(
                f"Database file not found at {database_path}. "
                "Set 'database_gen.run = true' to generate it."
            )
        database = xr.load_dataset(database_path)
        print(f"Database loaded from {database_path}")

    # Continue with error prediction...
    return _run_error_prediction(config, database, output_dir)
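
# Hedged single-farm config sketch: only keys read above are shown, and all
# values are illustrative placeholders.
#
#   paths:
#     system_config: windio/system.yaml
#     output_dir: results
#     processed_resource_file: resource_processed.nc
#     database_file: database.nc
#     # reference_power / reference_resource / wind_farm_layout are optional
#     # explicit overrides; otherwise they are inferred from the windIO
#     # system config.
#   preprocessing:
#     run: true
#     steps: []
#   database_gen:
#     run: true
#     n_samples: 50
#     flow_model: <model-name>    # passed to DatabaseGenerator as 'model'
#     param_config: {...}         # sampling space for flow-model parameters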


def _run_multi_farm_workflow(config: dict, base_dir: Path):
    """
    Multi-farm workflow - processes multiple farms and combines results.
    """
    paths_config = config.get("paths", {})

    # Get farm configurations
    farms_config = config.get("farms") or paths_config.get("farms")
    if not farms_config:
        raise ValueError("Multi-farm config requires 'farms' list")

    # Validate farm configs
    _validate_farm_configs(farms_config)

    # Resolve output directory
    output_dir = base_dir / paths_config.get("output_dir", "multi_farm_results")
    output_dir.mkdir(parents=True, exist_ok=True)

    database_filename = paths_config.get("database_file", "results_stacked_hh.nc")
    database_path = output_dir / database_filename

    print(f"Resolved output directory: {output_dir}")
    print(f"Running in MULTI-FARM mode with {len(farms_config)} farms")

    # Resolve paths for each farm
    resolved_farms = [_resolve_farm_paths(farm, base_dir) for farm in farms_config]

    # Print farm summary
    print("\nFarms to process:")
    for farm in resolved_farms:
        print(f"  - {farm['name']}: {farm['system_config']}")

    # === DATABASE GENERATION ===
    if config["database_gen"]["run"]:
        print("\n--- Running Multi-Farm Database Generation ---")

        database = generate_multi_farm_database(
            farm_configs=resolved_farms,
            param_config=config["database_gen"]["param_config"],
            n_samples=config["database_gen"]["n_samples"],
            output_dir=output_dir,
            database_file=database_filename,
            model=config["database_gen"]["flow_model"],
            preprocessing_steps=config["preprocessing"].get("steps", []),
            run_preprocessing=config["preprocessing"].get("run", True),
        )

        print("Multi-farm database generation complete.")
    else:
        print("--- Loading Existing Database ---")
        if not database_path.exists():
            raise FileNotFoundError(
                f"Database file not found at {database_path}. "
                "Set 'database_gen.run = true' to generate it."
            )
        database = xr.load_dataset(database_path)
        print(f"Database loaded from {database_path}")
        print(f"Contains {len(np.unique(database.wind_farm.values))} farms")

    # Continue with error prediction...
    return _run_error_prediction(config, database, output_dir)
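
# Hedged multi-farm config sketch (keys as consumed above; values are
# illustrative, and the defaults shown apply when a key is omitted):
#
#   paths:
#     output_dir: multi_farm_results        # default if omitted
#     database_file: results_stacked_hh.nc  # default if omitted
#     farms:
#       - name: farm_a
#         system_config: farm_a/system.yaml
#       - name: farm_b
#         system_config: farm_b/system.yaml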


def _run_error_prediction(config: dict, database: xr.Dataset, output_dir: Path):
    """
    Run error prediction, sensitivity analysis, and physics insights.

    Shared between single-farm and multi-farm workflows.
    """
    sa_config = config.get("sensitivity_analysis", {})
    err_config = config["error_prediction"]
    physics_config = config.get("physics_insights", {})
    model_name = err_config.get("model", "XGB")
    model_params = err_config.get("model_params", {})

    # --- OBSERVATION SENSITIVITY ---
    if sa_config.get("run_observation_sensitivity", False):
        print(f"--- Running Observation Sensitivity for model: {model_name} ---")
        obs_pipeline, obs_model_type = build_predictor_pipeline(
            model_name, model_params
        )

        run_observation_sensitivity(
            database=database,
            features_list=err_config["features"],
            ml_pipeline=obs_pipeline,
            model_type=obs_model_type,
            output_dir=output_dir,
            method=sa_config.get("method", "auto"),
            pce_config=sa_config.get("pce_config", {}),
        )
    else:
        print("--- Skipping Observation Sensitivity (as per config) ---")

    # === ERROR PREDICTION / UQ STEP ===
    fitted_model = None
    fitted_calibrator = None
    y_bias_all = None

    if err_config["run"]:
        print("--- Running Error Prediction ---")

        ml_pipeline, model_type = build_predictor_pipeline(model_name, model_params)

        calibrator_name = err_config["calibrator"]
        calibration_mode = CALIBRATION_MODES.get(calibrator_name, "global")
        Calibrator_cls = get_class_from_map(calibrator_name)
        Predictor_cls = get_class_from_map(err_config["bias_predictor"])
        MainPipeline_cls = MainPipeline

        print(
            f"Running cross-validation with calibrator: {Calibrator_cls.__name__} "
            f"and predictor: {model_name}"
        )

        cv_df, y_preds, y_tests = run_cross_validation(
            xr_data=database,
            ML_pipeline=ml_pipeline,
            model_type=model_type,
            Calibrator_cls=Calibrator_cls,
            BiasPredictor_cls=Predictor_cls,
            MainPipeline_cls=MainPipeline_cls,
            cv_config=err_config["cross_validation"],
            features_list=err_config["features"],
            output_dir=output_dir,
            sa_config=sa_config,
            calibration_mode=calibration_mode,
            local_regressor=err_config.get("local_regressor"),
            local_regressor_params=err_config.get("local_regressor_params", {}),
        )

        print("--- Cross-Validation Results (mean) ---")
        print(cv_df.mean())

        # Save results
        cv_df.to_csv(output_dir / "cv_results.csv")
        np.savez(
            output_dir / "predictions.npz",
            y_preds=np.array(y_preds, dtype=object),
            y_tests=np.array(y_tests, dtype=object),
        )
        print(f"Results saved to {output_dir}")

        # --- FIT FINAL MODEL FOR PHYSICS INSIGHTS ---
        # Re-fit on all data for physics insights analysis
        if physics_config.get("run", False):
            print("--- Fitting Final Model for Physics Insights ---")

            # Build fresh pipeline
            final_pipeline, _ = build_predictor_pipeline(model_name, model_params)

            # Prepare full dataset
            X_df = database.isel(sample=0).to_dataframe().reset_index()
            features = err_config["features"]
            X = X_df[features]

            # Fit calibrator on full data
            if calibration_mode == "local":
                fitted_calibrator = Calibrator_cls(
                    database,
                    feature_names=features,
                    regressor_name=err_config.get("local_regressor"),
                    regressor_params=err_config.get("local_regressor_params", {}),
                )
            else:
                fitted_calibrator = Calibrator_cls(database)
            fitted_calibrator.fit()

            # Get bias values at calibrated parameters
            if calibration_mode == "local":
                # Local mode: one optimal parameter sample per case
                optimal_indices = fitted_calibrator.get_optimal_indices()
                y_bias_all = np.array(
                    [
                        float(
                            database["model_bias_cap"]
                            .isel(case_index=i, sample=idx)
                            .values
                        )
                        for i, idx in enumerate(optimal_indices)
                    ]
                )
            else:
                # Global mode: a single best parameter sample for all cases
                best_idx = fitted_calibrator.best_idx_
                y_bias_all = database["model_bias_cap"].sel(sample=best_idx).values

            # Fit final model
            final_pipeline.fit(X, y_bias_all)
            fitted_model = final_pipeline

        # === PHYSICS INSIGHTS ===
        if physics_config.get("run", False) and fitted_model is not None:
            print("--- Running Physics Insights Analysis ---")
            from wifa_uq.postprocessing.physics_insights import run_physics_insights

            insights_dir = output_dir / "physics_insights"
            insights_dir.mkdir(exist_ok=True)

            run_physics_insights(
                database=database,
                fitted_model=fitted_model,
                calibrator=fitted_calibrator,
                features_list=err_config["features"],
                y_bias=y_bias_all,
                output_dir=insights_dir,
                config=physics_config,
            )

            print("Physics insights analysis complete.")

        print("--- Workflow complete ---")
        return cv_df, y_preds, y_tests

    print("--- Workflow complete (no error prediction) ---")
    return None, None, None