Coverage for wifa_uq/postprocessing/error_predictor/error_predictor.py: 89% (722 statements)
1import numpy as np
2import pandas as pd
3from pathlib import Path
4import matplotlib.pyplot as plt
5import shap
7from sklearn.base import BaseEstimator, RegressorMixin
8from sklearn.pipeline import Pipeline
9from sklearn.preprocessing import StandardScaler, PolynomialFeatures
10from sklearn.linear_model import LinearRegression
11from sklearn.exceptions import NotFittedError
13import xgboost as xgb
14from sklearn.model_selection import KFold, LeaveOneGroupOut
15from sklearn.metrics import mean_absolute_error, r2_score
16from wifa_uq.postprocessing.PCE_tool.pce_utils import construct_PCE_ot
17from wifa_uq.postprocessing.calibration import (
18 MinBiasCalibrator,
19 # DefaultParams,
20 # LocalParameterPredictor,
21)
22from sliced import SlicedInverseRegression
24"""
25This script contains:
26- SIRPolynomialRegressor class (NEW)
27- Calibrator classes
28- BiasPredictor class
29- MainPipeline class
30- Cross validation routine
31- SHAP/SIR sensitivity analysis functions
32- Multi-farm CV visualization functions (NEW)
33"""
36class PCERegressor(BaseEstimator, RegressorMixin):
37 """
38 A scikit-learn compatible regressor that wraps the OpenTURNS-based PCE
39 from PCE_tool.
41 Safety guard:
42 - By default, refuses to run if the number of input features > max_features
43 unless allow_high_dim=True is explicitly set.
44 """
46 def __init__(
47 self,
48 degree=5,
49 marginals="kernel",
50 copula="independent",
51 q=1.0,
52 max_features=5, # safety limit on input dimension
53 allow_high_dim=False, # must be True to allow > max_features
54 ):
55 self.degree = degree
56 self.marginals = marginals
57 self.copula = copula
58 self.q = q
59 self.max_features = max_features
60 self.allow_high_dim = allow_high_dim
62 def fit(self, X, y):
63 X = np.asarray(X)
64 y = np.asarray(y).ravel()
66 n_features = X.shape[1]
68 # --- Safety check on dimensionality ---
69 if n_features > self.max_features and not self.allow_high_dim:
70 raise ValueError(
71 f"PCERegressor refused to run: number of input variables = {n_features}, "
72 f"which exceeds the default safety limit of {self.max_features}. "
73 f"Set allow_high_dim=True or increase max_features to override."
74 )
76 marginals = [self.marginals] * n_features
78 # Construct the PCE metamodel via the PCE_tool helper
79 self.pce_result_ = construct_PCE_ot(
80 input_array=X,
81 output_array=y,
82 marginals=marginals,
83 copula=self.copula,
84 degree=self.degree,
85 q=self.q,
86 )
87 self.metamodel_ = self.pce_result_.getMetaModel()
88 return self
90 def predict(self, X):
91 if not hasattr(self, "metamodel_"):
92 raise NotFittedError("PCERegressor instance is not fitted yet.")
93 X = np.asarray(X)
94 preds = np.zeros(X.shape[0])
95 for i, xi in enumerate(X):
96 preds[i] = self.metamodel_(xi)[0]
97 return preds
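# --- Example (illustrative sketch, not part of the original pipeline) ---
# Minimal usage of PCERegressor on synthetic data, showing the fit/predict
# interface and the dimensionality guard. The helper name
# `_example_pce_regressor` is hypothetical and is never called.
def _example_pce_regressor():
    rng = np.random.default_rng(0)
    X = rng.uniform(-1.0, 1.0, size=(200, 3))
    y = X[:, 0] ** 2 + 0.5 * X[:, 1] - 0.1 * X[:, 2]
    reg = PCERegressor(degree=3, marginals="kernel", copula="independent")
    reg.fit(X, y)  # would raise ValueError if X had more than max_features columns
    return reg.predict(X[:5])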
100class LinearRegressor(BaseEstimator, RegressorMixin):
101 """
102 A simple linear regression wrapper with optional regularization.
104 Supports: 'ols' (ordinary least squares), 'ridge', 'lasso', 'elasticnet'
105 """
107 def __init__(self, method="ols", alpha=1.0, l1_ratio=0.5):
108 self.method = method
109 self.alpha = alpha
110 self.l1_ratio = l1_ratio # Only used for elasticnet
112 def fit(self, X, y):
113 from sklearn.linear_model import LinearRegression, Ridge, Lasso, ElasticNet
115 if self.method == "ols":
116 self.model_ = LinearRegression()
117 elif self.method == "ridge":
118 self.model_ = Ridge(alpha=self.alpha)
119 elif self.method == "lasso":
120 self.model_ = Lasso(alpha=self.alpha)
121 elif self.method == "elasticnet":
122 self.model_ = ElasticNet(alpha=self.alpha, l1_ratio=self.l1_ratio)
123 else:
124 raise ValueError(
125 f"Unknown method '{self.method}'. Use 'ols', 'ridge', 'lasso', or 'elasticnet'."
126 )
128 self.scaler_ = StandardScaler()
129 X_scaled = self.scaler_.fit_transform(X)
130 self.model_.fit(X_scaled, y)
131 return self
133 def predict(self, X):
134 if not hasattr(self, "model_"):
135 raise NotFittedError("LinearRegressor is not fitted yet.")
136 X_scaled = self.scaler_.transform(X)
137 return self.model_.predict(X_scaled)
139 def get_feature_importance(self, feature_names):
140 """Return absolute coefficients as feature importance."""
141 if not hasattr(self, "model_"):
142 raise NotFittedError("LinearRegressor is not fitted yet.")
143 return pd.Series(np.abs(self.model_.coef_), index=feature_names)
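# --- Example (illustrative sketch, not part of the original pipeline) ---
# Minimal usage of LinearRegressor with ridge regularization; the feature
# importance is the absolute value of the standardized coefficients.
# The helper name `_example_linear_regressor` is hypothetical.
def _example_linear_regressor():
    rng = np.random.default_rng(1)
    X = rng.normal(size=(100, 4))
    y = 2.0 * X[:, 0] - 1.0 * X[:, 2] + 0.1 * rng.normal(size=100)
    reg = LinearRegressor(method="ridge", alpha=0.5).fit(X, y)
    return reg.get_feature_importance(["f0", "f1", "f2", "f3"])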
146## ------------------------------------------------------------------ ##
147## NEW REGRESSOR CLASS
148## ------------------------------------------------------------------ ##
149class SIRPolynomialRegressor(BaseEstimator, RegressorMixin):
150 """
151 A scikit-learn compatible regressor that first applies SIR for
152 dimension reduction and then fits a polynomial regression
153 on the reduced dimension(s).
154 """
156 def __init__(self, n_directions=1, degree=2):
157 self.n_directions = n_directions
158 self.degree = degree
160 def fit(self, X, y):
161 # 1. Standard Scaler
162 self.scaler_ = StandardScaler()
163 X_scaled = self.scaler_.fit_transform(X, y)
165 # 2. Sliced Inverse Regression
166 self.sir_ = SlicedInverseRegression(n_directions=self.n_directions)
167 # Sliced package expects y as a 1D array
168 y_ravel = np.ravel(y)
169 X_sir = self.sir_.fit_transform(X_scaled, y_ravel)
171 # Store the absolute coefficients of the first SIR direction;
172 # they are used later as a simple feature-importance ranking
173 self.sir_directions_ = np.abs(self.sir_.directions_[0, :])
175 # 3. Polynomial Regression
176 self.poly_reg_ = Pipeline(
177 [
178 ("poly", PolynomialFeatures(degree=self.degree, include_bias=False)),
179 ("lin_reg", LinearRegression()),
180 ]
181 )
182 self.poly_reg_.fit(X_sir, y)
184 return self
186 def predict(self, X):
187 if not hasattr(self, "scaler_"):
188 raise NotFittedError(
189 "This SIRPolynomialRegressor instance is not fitted yet."
190 )
192 X_scaled = self.scaler_.transform(X)
193 X_sir = self.sir_.transform(X_scaled)
194 return self.poly_reg_.predict(X_sir)
196 def get_feature_importance(self, feature_names):
197 if not hasattr(self, "sir_directions_"):
198 raise NotFittedError(
199 "This SIRPolynomialRegressor instance is not fitted yet."
200 )
202 return pd.Series(self.sir_directions_, index=feature_names)
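# --- Example (illustrative sketch, not part of the original pipeline) ---
# SIR reduces the inputs to a single direction along which the response
# varies; a quadratic is then fitted on the projected coordinate.
# The helper name `_example_sir_polynomial_regressor` is hypothetical.
def _example_sir_polynomial_regressor():
    rng = np.random.default_rng(2)
    X = rng.normal(size=(300, 5))
    beta = np.array([1.0, -0.5, 0.0, 0.0, 2.0])
    t = X @ beta
    y = t + 0.5 * t**2 + 0.1 * rng.normal(size=300)
    reg = SIRPolynomialRegressor(n_directions=1, degree=2).fit(X, y)
    return reg.get_feature_importance([f"x{i}" for i in range(5)])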
205## ------------------------------------------------------------------ ##
208class BiasPredictor:
209 """
210 Predict bias as a function of features and parameter samples
211 """
213 def __init__(self, regressor_pipeline):
214 self.pipeline = regressor_pipeline
216 def fit(self, X_train, y_train):
217 # print('shape of X and Y training: ',X_train.shape, y_train.shape)
218 self.pipeline.fit(X_train, y_train)
219 return self
221 def predict(self, X_test):
222 y_pred = self.pipeline.predict(X_test)
223 return y_pred
226class MainPipeline:
227 """
228 Main pipeline that combines calibration and bias prediction.
230 Supports two calibration modes:
231 1. Global calibration (MinBiasCalibrator, DefaultParams):
232 - Single parameter set for all cases
233 - calibrator.best_idx_ gives the sample index
235 2. Local calibration (LocalParameterPredictor):
236 - Different optimal parameters per case
237 - calibrator.get_optimal_indices() gives per-case indices
238 """
240 def __init__(
241 self,
242 calibrator,
243 bias_predictor,
244 features_list: list,
245 calibration_mode: str = "global",
246 ):
247 """
248 Args:
249 calibrator: Calibrator instance (already initialized with dataset_train)
250 bias_predictor: BiasPredictor instance
251 features_list: List of feature names to use for bias prediction
252 calibration_mode: "global" or "local"
253 """
254 self.calibrator = calibrator
255 self.bias_predictor = bias_predictor
256 self.features_list = features_list
257 self.calibration_mode = calibration_mode
259 if not self.features_list:
260 raise ValueError("features_list cannot be empty.")
262 def fit(self, dataset_train, dataset_test):
263 """
264 Fit the calibrator and bias predictor.
266 Returns:
267 X_test, y_test, idxs (for compatibility with cross-validation)
268 """
269 # 1. Fit calibrator
270 self.calibrator.fit()
272 if self.calibration_mode == "global":
273 return self._fit_global(dataset_train, dataset_test)
274 elif self.calibration_mode == "local":
275 return self._fit_local(dataset_train, dataset_test)
276 else:
277 raise ValueError(f"Unknown calibration_mode: {self.calibration_mode}")
279 def _fit_global(self, dataset_train, dataset_test):
280 """Fit using global calibration (single parameter set)."""
281 idxs = self.calibrator.best_idx_
283 # Select the calibrated sample for train and test
284 dataset_train_cal = dataset_train.sel(sample=idxs)
285 dataset_test_cal = dataset_test.sel(sample=idxs)
287 # Prepare features
288 X_train_df = dataset_train_cal.to_dataframe().reset_index()
289 X_test_df = dataset_test_cal.to_dataframe().reset_index()
291 X_train = self._extract_features(X_train_df)
292 X_test = self._extract_features(X_test_df)
294 y_train = dataset_train_cal["model_bias_cap"].values
295 y_test = dataset_test_cal["model_bias_cap"].values
297 # Fit bias predictor
298 self.bias_predictor.fit(X_train, y_train)
300 # Store for predict()
301 self.X_test_ = X_test
303 return X_test, y_test, idxs
305 def _fit_local(self, dataset_train, dataset_test):
306 """Fit using local calibration (per-case optimal parameters)."""
307 # 1. Get per-case optimal sample indices from the local calibrator
308 train_optimal_indices = self.calibrator.get_optimal_indices()
309 n_train_cases = len(dataset_train.case_index)
311 # 2. Build training feature matrix from sample=0
312 # (features do not depend on sampled parameters)
313 train_base = dataset_train.isel(sample=0)
314 train_df = train_base.to_dataframe().reset_index()
315 X_train = self._extract_features(train_df)
317 # 3. Build training targets: bias at the optimal sample for each case
318 y_train = np.zeros(n_train_cases)
319 for case_idx, sample_idx in enumerate(train_optimal_indices):
320 y_train[case_idx] = float(
321 dataset_train["model_bias_cap"]
322 .isel(case_index=case_idx, sample=sample_idx)
323 .values
324 )
326 # 4. Test features (also from sample=0)
327 X_test_df = dataset_test.isel(sample=0).to_dataframe().reset_index()
328 X_test_features = self._extract_features(X_test_df)
330 # 5. Predict optimal parameters for test cases,
331 # then find the closest sampled parameter set in the database
332 predicted_params = self.calibrator.predict(X_test_features)
333 test_optimal_indices = self._find_closest_samples(
334 dataset_test, predicted_params
335 )
337 # 6. Build test targets: bias at the chosen sample for each test case
338 n_test_cases = len(dataset_test.case_index)
339 y_test = np.zeros(n_test_cases)
340 for case_idx, sample_idx in enumerate(test_optimal_indices):
341 y_test[case_idx] = float(
342 dataset_test["model_bias_cap"]
343 .isel(case_index=case_idx, sample=sample_idx)
344 .values
345 )
347 # 7. Fit the bias predictor on per-case data
348 self.bias_predictor.fit(X_train, y_train)
350 # Store X_test_ so .predict() can be called without args
351 self.X_test_ = X_test_features
353 # Return in the same shape run_cross_validation expects
354 return X_test_features, y_test, test_optimal_indices
356 def _extract_features(self, df):
357 """Extract and clean features from dataframe."""
358 try:
359 X = df[self.features_list].copy()
360 except KeyError as e:
361 print(f"Error: Feature not found in dataset: {e}")
362 print(f"Available columns: {list(df.columns)}")
363 raise
365 # Clean string-like columns
366 for col in X.columns:
367 if X[col].dtype == "object":
368 if X[col].dropna().empty:
369 continue
370 first_item = X[col].dropna().iloc[0]
371 if isinstance(first_item, str):
372 X[col] = X[col].str.replace(r"[\[\]]", "", regex=True).astype(float)
373 else:
374 X[col] = X[col].astype(float)
376 return X
378 def _find_closest_samples(self, dataset, predicted_params):
379 """
380 Find the sample index closest to predicted parameters for each case.
382 Args:
383 dataset: xarray Dataset with 'sample' dimension
384 predicted_params: DataFrame with predicted optimal parameters
386 Returns:
387 Array of sample indices (one per case)
389 Raises:
390 ValueError: If no valid parameters found for distance calculation
391 """
392 n_cases = len(predicted_params)
393 n_samples = len(dataset.sample)
394 swept_params = self.calibrator.swept_params
396 # Validation 1: Check swept_params is not empty
397 if not swept_params:
398 raise ValueError(
399 "No swept parameters defined. Cannot find closest samples. "
400 "Check that the database has 'swept_params' in attrs or that "
401 "parameters were correctly inferred."
402 )
404 # Validation 2: Check which parameters are actually available
405 available_params = []
406 missing_in_dataset = []
407 missing_in_predictions = []
409 for param_name in swept_params:
410 in_dataset = param_name in dataset.coords
411 in_predictions = param_name in predicted_params.columns
413 if in_dataset and in_predictions:
414 available_params.append(param_name)
415 elif not in_dataset:
416 missing_in_dataset.append(param_name)
417 elif not in_predictions:
418 missing_in_predictions.append(param_name)
420 # Validation 3: Ensure we have at least one parameter to use
421 if not available_params:
422 raise ValueError(
423 f"No valid parameters for distance calculation.\n"
424 f" Swept params: {swept_params}\n"
425 f" Missing in dataset.coords: {missing_in_dataset}\n"
426 f" Missing in predicted_params: {missing_in_predictions}"
427 )
429 # Warn about partial matches (some params missing)
430 if missing_in_dataset or missing_in_predictions:
431 import warnings
433 warnings.warn(
434 f"Some swept parameters unavailable for distance calculation:\n"
435 f" Using: {available_params}\n"
436 f" Missing in dataset: {missing_in_dataset}\n"
437 f" Missing in predictions: {missing_in_predictions}",
438 UserWarning,
439 )
441 # Calculate distances using only available parameters
442 closest_indices = np.zeros(n_cases, dtype=int)
444 for case_idx in range(n_cases):
445 target_params = predicted_params.iloc[case_idx]
447 distances = np.zeros(n_samples)
448 for param_name in available_params:
449 sample_values = dataset.coords[param_name].values
450 target_value = target_params[param_name]
452 # Normalize by parameter range to handle different scales
453 param_range = sample_values.max() - sample_values.min()
454 if param_range > 0:
455 normalized_diff = (sample_values - target_value) / param_range
456 else:
457 # All samples have same value for this param
458 normalized_diff = np.zeros_like(sample_values)
460 distances += normalized_diff**2
462 closest_indices[case_idx] = int(np.argmin(distances))
464 return closest_indices
466 def predict(self, X=None):
467 """Predict bias for test data."""
468 if X is None:
469 X = self.X_test_
470 return self.bias_predictor.predict(X)
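# --- Example (illustrative sketch, not part of the original pipeline) ---
# Wiring a MainPipeline for global calibration. It assumes xarray Datasets
# `dataset_train` / `dataset_test` shaped like those used in
# run_cross_validation (a 'sample' dimension, a 'case_index' coordinate,
# a 'model_bias_cap' variable and a 'turbulence_intensity' feature).
# The helper name `_example_main_pipeline` is hypothetical.
def _example_main_pipeline(dataset_train, dataset_test):
    calibrator = MinBiasCalibrator(dataset_train)
    bias_pred = BiasPredictor(
        Pipeline(
            [
                ("scaler", StandardScaler()),
                ("model", xgb.XGBRegressor(max_depth=3, n_estimators=200)),
            ]
        )
    )
    pipe = MainPipeline(
        calibrator,
        bias_pred,
        features_list=["turbulence_intensity"],
        calibration_mode="global",
    )
    X_test, y_test, idxs = pipe.fit(dataset_train, dataset_test)
    return pipe.predict(X_test), y_test, idxs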
473def compute_metrics(y_true, bias_samples, pw, ref, data_driv=None):
474 mse = ((y_true - bias_samples) ** 2).mean()
475 rmse = np.sqrt(mse)
476 mae = mean_absolute_error(y_true, bias_samples)
477 r2 = r2_score(y_true, bias_samples)
479 if pw is not None and ref is not None and data_driv is None:
480 pw_bias = np.mean(pw - ref)
481 pw_bias_corrected = np.mean((pw - bias_samples) - ref)
482 else:
483 pw_bias = None
484 pw_bias_corrected = None
486 if data_driv is not None and ref is not None:
487 data_driv_bias = np.mean(data_driv - ref)
488 else:
489 data_driv_bias = None
491 return {
492 "rmse": rmse,
493 "mse": mse,
494 "mae": mae,
495 "r2": r2,
496 "pw_bias": pw_bias,
497 "pw_bias_corrected": pw_bias_corrected,
498 "data_driv_bias": data_driv_bias,
499 }
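# --- Example (illustrative sketch, not part of the original pipeline) ---
# A tiny worked call of compute_metrics: with a perfect bias prediction the
# RMSE/MAE are zero and R² is 1, while pw_bias / pw_bias_corrected compare
# the raw and bias-corrected power against the reference.
# The helper name `_example_compute_metrics` is hypothetical.
def _example_compute_metrics():
    y_true = np.array([0.10, -0.20, 0.05])
    y_pred = y_true.copy()  # perfect predictor
    pw = np.array([1.00, 0.90, 1.10])
    ref = np.array([0.95, 1.05, 1.00])
    return compute_metrics(y_true, y_pred, pw=pw, ref=ref)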
502## ------------------------------------------------------------------ ##
503## MULTI-FARM CV VISUALIZATION FUNCTIONS (NEW)
504## ------------------------------------------------------------------ ##
507def plot_multi_farm_cv_metrics(
508 cv_results: pd.DataFrame,
509 fold_labels: list,
510 output_dir: Path,
511 splitting_mode: str = "LeaveOneGroupOut",
512):
513 """
514 Create visualizations for multi-farm cross-validation results.
516 Shows per-fold (per-group) performance metrics to understand
517 how well the model generalizes across different wind farms.
519 Args:
520 cv_results: DataFrame with metrics per fold (rmse, r2, mae, etc.)
521 fold_labels: List of strings identifying each fold (e.g., group names left out)
522 output_dir: Directory to save plots
523 splitting_mode: CV splitting mode for title annotation
524 """
525 output_dir = Path(output_dir)
526 n_folds = len(cv_results)
528 # --- 1. Per-Fold Metrics Bar Chart ---
529 fig, axes = plt.subplots(1, 3, figsize=(15, 5))
530 fig.suptitle(
531 f"Cross-Validation Performance by Fold ({splitting_mode})", fontsize=14
532 )
534 metrics = ["rmse", "r2", "mae"]
535 colors = ["#1f77b4", "#2ca02c", "#ff7f0e"]
537 x = np.arange(n_folds)
539 for ax, metric, color in zip(axes, metrics, colors):
540 values = cv_results[metric].values
541 bars = ax.bar(x, values, color=color, alpha=0.7, edgecolor="black")
543 # Add value labels on bars
544 for bar, val in zip(bars, values):
545 height = bar.get_height()
546 ax.annotate(
547 f"{val:.4f}",
548 xy=(bar.get_x() + bar.get_width() / 2, height),
549 xytext=(0, 3),
550 textcoords="offset points",
551 ha="center",
552 va="bottom",
553 fontsize=8,
554 rotation=45,
555 )
557 # Add mean line
558 mean_val = values.mean()
559 ax.axhline(
560 mean_val,
561 color="red",
562 linestyle="--",
563 linewidth=2,
564 label=f"Mean: {mean_val:.4f}",
565 )
567 ax.set_xlabel("Fold (Left-Out Group)")
568 ax.set_ylabel(metric.upper())
569 ax.set_title(f"{metric.upper()} per Fold")
570 ax.set_xticks(x)
571 ax.set_xticklabels(fold_labels, rotation=45, ha="right", fontsize=8)
572 ax.legend(loc="best")
573 ax.grid(axis="y", alpha=0.3)
575 plt.tight_layout()
576 plt.savefig(output_dir / "cv_fold_metrics.png", dpi=150, bbox_inches="tight")
577 plt.close(fig)
578 print(f" Saved per-fold metrics plot to: {output_dir / 'cv_fold_metrics.png'}")
580 # --- 2. Metrics Comparison Heatmap ---
581 fig, ax = plt.subplots(figsize=(10, 6))
583 # Select the metrics shown in the heatmap
584 metrics_for_heatmap = ["rmse", "mae", "r2"]
585 heatmap_data = cv_results[metrics_for_heatmap].copy()
587 # For r2, higher is better; for rmse/mae, lower is better
588 # Normalize so that "better" is always higher for visualization
589 heatmap_normalized = heatmap_data.copy()
590 heatmap_normalized["rmse"] = -heatmap_data["rmse"] # Negate so higher = better
591 heatmap_normalized["mae"] = -heatmap_data["mae"] # Negate so higher = better
593 # Create heatmap
594 im = ax.imshow(heatmap_normalized.T, cmap="RdYlGn", aspect="auto")
596 # Add colorbar
597 cbar = plt.colorbar(im, ax=ax)
598 cbar.set_label(
599 "Performance (normalized, higher = better)", rotation=270, labelpad=15
600 )
602 # Set ticks and labels
603 ax.set_xticks(np.arange(n_folds))
604 ax.set_xticklabels(fold_labels, rotation=45, ha="right")
605 ax.set_yticks(np.arange(len(metrics_for_heatmap)))
606 ax.set_yticklabels([m.upper() for m in metrics_for_heatmap])
608 # Add text annotations with actual values
609 for i in range(len(metrics_for_heatmap)):
610 for j in range(n_folds):
611 ax.text(
612 j,
613 i,
614 f"{heatmap_data.iloc[j, i]:.3f}",
615 ha="center",
616 va="center",
617 color="black",
618 fontsize=8,
619 )
621 ax.set_xlabel("Fold (Left-Out Group)")
622 ax.set_title("Performance Heatmap Across CV Folds\n(Green = Better Performance)")
624 plt.tight_layout()
625 plt.savefig(output_dir / "cv_fold_heatmap.png", dpi=150, bbox_inches="tight")
626 plt.close(fig)
627 print(f" Saved heatmap plot to: {output_dir / 'cv_fold_heatmap.png'}")
629 # --- 3. Box Plot Summary ---
630 fig, ax = plt.subplots(figsize=(8, 6))
632 # Prepare data for box plot
633 metrics_data = [cv_results[m].values for m in ["rmse", "mae", "r2"]]
635 bp = ax.boxplot(metrics_data, labels=["RMSE", "MAE", "R²"], patch_artist=True)
637 # Color the boxes
638 colors_box = ["#1f77b4", "#ff7f0e", "#2ca02c"]
639 for patch, color in zip(bp["boxes"], colors_box):
640 patch.set_facecolor(color)
641 patch.set_alpha(0.7)
643 # Add individual points
644 for i, (metric_vals, color) in enumerate(zip(metrics_data, colors_box)):
645 x_jitter = np.random.normal(i + 1, 0.04, size=len(metric_vals))
646 ax.scatter(
647 x_jitter, metric_vals, alpha=0.6, color=color, edgecolor="black", s=50
648 )
650 ax.set_ylabel("Metric Value")
651 ax.set_title(f"Distribution of CV Metrics Across {n_folds} Folds")
652 ax.grid(axis="y", alpha=0.3)
654 plt.tight_layout()
655 plt.savefig(output_dir / "cv_metrics_boxplot.png", dpi=150, bbox_inches="tight")
656 plt.close(fig)
657 print(f" Saved boxplot to: {output_dir / 'cv_metrics_boxplot.png'}")
660def plot_farm_wise_predictions(
661 y_tests: list,
662 y_preds: list,
663 fold_labels: list,
664 fold_farm_names: list,
665 output_dir: Path,
666):
667 """
668 Create scatter plots showing predictions vs true values, colored by farm/group.
670 Args:
671 y_tests: List of test targets per fold
672 y_preds: List of predictions per fold
673 fold_labels: List of fold identifiers (left-out groups)
674 fold_farm_names: List of arrays with farm names for the test cases in each fold (currently not used by this function)
675 output_dir: Directory to save plots
676 """
677 output_dir = Path(output_dir)
679 # Combine all folds
680 all_y_test = np.concatenate(y_tests)
681 all_y_pred = np.concatenate(y_preds)
682 all_fold_ids = np.concatenate(
683 [np.full(len(y_tests[i]), fold_labels[i]) for i in range(len(y_tests))]
684 )
686 # Get unique fold labels for coloring
687 unique_folds = np.unique(all_fold_ids)
688 n_folds = len(unique_folds)
690 # Create colormap
691 cmap = plt.get_cmap("tab10" if n_folds <= 10 else "tab20")  # plt.cm.get_cmap was removed in newer Matplotlib
692 colors = {fold: cmap(i / n_folds) for i, fold in enumerate(unique_folds)}
694 # --- Main scatter plot colored by fold ---
695 fig, ax = plt.subplots(figsize=(10, 8))
697 for fold_label in unique_folds:
698 mask = all_fold_ids == fold_label
699 ax.scatter(
700 all_y_test[mask],
701 all_y_pred[mask],
702 c=[colors[fold_label]],
703 label=f"Left out: {fold_label}",
704 alpha=0.6,
705 s=30,
706 edgecolor="white",
707 linewidth=0.5,
708 )
710 # Add 1:1 line
711 min_val = min(all_y_test.min(), all_y_pred.min())
712 max_val = max(all_y_test.max(), all_y_pred.max())
713 ax.plot(
714 [min_val, max_val], [min_val, max_val], "k--", linewidth=2, label="1:1 Line"
715 )
717 # Calculate overall metrics
718 overall_rmse = np.sqrt(np.mean((all_y_test - all_y_pred) ** 2))
719 overall_r2 = r2_score(all_y_test, all_y_pred)
721 ax.set_xlabel("True Bias", fontsize=12)
722 ax.set_ylabel("Predicted Bias", fontsize=12)
723 ax.set_title(
724 f"Predictions vs True Values by Left-Out Group\n"
725 f"Overall RMSE: {overall_rmse:.4f}, R²: {overall_r2:.4f}",
726 fontsize=14,
727 )
728 ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=9)
729 ax.grid(alpha=0.3)
730 ax.set_aspect("equal", adjustable="box")
732 plt.tight_layout()
733 plt.savefig(output_dir / "cv_predictions_by_fold.png", dpi=150, bbox_inches="tight")
734 plt.close(fig)
735 print(
736 f" Saved predictions scatter plot to: {output_dir / 'cv_predictions_by_fold.png'}"
737 )
739 # --- Per-fold subplots ---
740 n_cols = min(3, n_folds)
741 n_rows = (n_folds + n_cols - 1) // n_cols
743 fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows))
744 if n_folds == 1:
745 axes = np.array([[axes]])
746 elif n_rows == 1:
747 axes = axes.reshape(1, -1)
749 for idx, fold_label in enumerate(unique_folds):
750 row, col = divmod(idx, n_cols)
751 ax = axes[row, col]
753 mask = all_fold_ids == fold_label
754 y_test_fold = all_y_test[mask]
755 y_pred_fold = all_y_pred[mask]
757 ax.scatter(
758 y_test_fold,
759 y_pred_fold,
760 c=[colors[fold_label]],
761 alpha=0.6,
762 s=30,
763 edgecolor="white",
764 linewidth=0.5,
765 )
767 # 1:1 line
768 min_v = min(y_test_fold.min(), y_pred_fold.min())
769 max_v = max(y_test_fold.max(), y_pred_fold.max())
770 ax.plot([min_v, max_v], [min_v, max_v], "k--", linewidth=1.5)
772 # Per-fold metrics
773 fold_rmse = np.sqrt(np.mean((y_test_fold - y_pred_fold) ** 2))
774 fold_r2 = r2_score(y_test_fold, y_pred_fold) if len(y_test_fold) > 1 else 0
776 ax.set_xlabel("True Bias")
777 ax.set_ylabel("Predicted Bias")
778 ax.set_title(
779 f"Left Out: {fold_label}\nRMSE: {fold_rmse:.4f}, R²: {fold_r2:.4f}"
780 )
781 ax.grid(alpha=0.3)
782 ax.set_aspect("equal", adjustable="box")
784 # Hide unused subplots
785 for idx in range(n_folds, n_rows * n_cols):
786 row, col = divmod(idx, n_cols)
787 axes[row, col].set_visible(False)
789 plt.suptitle("Prediction Quality per CV Fold", fontsize=14, y=1.02)
790 plt.tight_layout()
791 plt.savefig(
792 output_dir / "cv_predictions_per_fold.png", dpi=150, bbox_inches="tight"
793 )
794 plt.close(fig)
795 print(
796 f" Saved per-fold predictions to: {output_dir / 'cv_predictions_per_fold.png'}"
797 )
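# --- Example (illustrative sketch, not part of the original pipeline) ---
# Calling plot_farm_wise_predictions on small synthetic folds. The folder
# "./cv_plots", the fold labels and the data are made up for the sketch.
# The helper name `_example_farm_wise_predictions` is hypothetical.
def _example_farm_wise_predictions():
    out = Path("./cv_plots")
    out.mkdir(exist_ok=True)
    rng = np.random.default_rng(3)
    y_tests = [rng.normal(size=20), rng.normal(size=15)]
    y_preds = [y + 0.1 * rng.normal(size=y.size) for y in y_tests]
    plot_farm_wise_predictions(
        y_tests,
        y_preds,
        fold_labels=["farm_A", "farm_B"],
        fold_farm_names=[np.array(["farm_A"] * 20), np.array(["farm_B"] * 15)],
        output_dir=out,
    )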
800def plot_generalization_matrix(
801 cv_results: pd.DataFrame, fold_labels: list, output_dir: Path
802):
803 """
804 Create a generalization analysis visualization showing how training on
805 certain farm groups affects prediction on others.
807 Args:
808 cv_results: DataFrame with metrics per fold
809 fold_labels: List of fold identifiers (left-out groups)
810 output_dir: Directory to save plots
811 """
812 output_dir = Path(output_dir)
813 n_folds = len(fold_labels)
815 # Create summary table
816 fig, ax = plt.subplots(figsize=(12, max(4, n_folds * 0.5)))
818 # Prepare data for table
819 table_data = []
820 for i, label in enumerate(fold_labels):
821 row = [
822 label,
823 f"{cv_results['rmse'].iloc[i]:.4f}",
824 f"{cv_results['r2'].iloc[i]:.4f}",
825 f"{cv_results['mae'].iloc[i]:.4f}",
826 ]
827 table_data.append(row)
829 # Add mean row
830 table_data.append(
831 [
832 "MEAN",
833 f"{cv_results['rmse'].mean():.4f}",
834 f"{cv_results['r2'].mean():.4f}",
835 f"{cv_results['mae'].mean():.4f}",
836 ]
837 )
839 # Add std row
840 table_data.append(
841 [
842 "STD",
843 f"{cv_results['rmse'].std():.4f}",
844 f"{cv_results['r2'].std():.4f}",
845 f"{cv_results['mae'].std():.4f}",
846 ]
847 )
849 columns = ["Left-Out Group", "RMSE", "R²", "MAE"]
851 # Hide axes
852 ax.axis("off")
854 # Create table
855 table = ax.table(
856 cellText=table_data,
857 colLabels=columns,
858 loc="center",
859 cellLoc="center",
860 )
862 # Style the table
863 table.auto_set_font_size(False)
864 table.set_fontsize(10)
865 table.scale(1.2, 1.5)
867 # Color header row
868 for j, col in enumerate(columns):
869 table[(0, j)].set_facecolor("#4472C4")
870 table[(0, j)].set_text_props(color="white", weight="bold")
872 # Color mean/std rows
873 for j in range(len(columns)):
874 table[(n_folds + 1, j)].set_facecolor("#E2EFDA") # Mean row
875 table[(n_folds + 2, j)].set_facecolor("#FCE4D6") # Std row
877 # Color-code RMSE cells based on value
878 rmse_values = cv_results["rmse"].values
879 rmse_min, rmse_max = rmse_values.min(), rmse_values.max()
881 for i in range(n_folds):
882 # Normalize RMSE (lower is better, so invert for color)
883 if rmse_max > rmse_min:
884 norm_val = (rmse_values[i] - rmse_min) / (rmse_max - rmse_min)
885 else:
886 norm_val = 0.5
888 # Color from green (good) to red (bad)
889 color = plt.cm.RdYlGn(1 - norm_val)
890 table[(i + 1, 1)].set_facecolor(color)
892 plt.title(
893 "Cross-Validation Generalization Summary\n"
894 "(Testing on each group after training on all others)",
895 fontsize=14,
896 pad=20,
897 )
899 plt.tight_layout()
900 plt.savefig(
901 output_dir / "cv_generalization_summary.png", dpi=150, bbox_inches="tight"
902 )
903 plt.close(fig)
904 print(
905 f" Saved generalization summary to: {output_dir / 'cv_generalization_summary.png'}"
906 )
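# --- Example (illustrative sketch, not part of the original pipeline) ---
# Calling the per-fold metric plots on a small hand-made metrics table.
# The folder "./cv_plots" and the fold labels are made up for the sketch.
# The helper name `_example_multi_farm_plots` is hypothetical.
def _example_multi_farm_plots():
    out = Path("./cv_plots")
    out.mkdir(exist_ok=True)
    cv_results = pd.DataFrame(
        {
            "rmse": [0.12, 0.15, 0.10],
            "mae": [0.09, 0.11, 0.08],
            "r2": [0.81, 0.74, 0.88],
        }
    )
    fold_labels = ["farm_A", "farm_B", "farm_C"]
    plot_multi_farm_cv_metrics(
        cv_results, fold_labels, out, splitting_mode="LeaveOneGroupOut"
    )
    plot_generalization_matrix(cv_results, fold_labels, out)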
909## ------------------------------------------------------------------ ##
912def run_observation_sensitivity(
913 database,
914 features_list,
915 ml_pipeline,
916 model_type,
917 output_dir,
918 method: str = "auto",
919 pce_config: dict = None,
920):
921 """
922 Sensitivity analysis on observations.
924 Args:
925 database: xarray Dataset
926 features_list: List of feature names
927 ml_pipeline: ML pipeline (used for shap/sir methods)
928 model_type: "tree", "sir", or "pce" (used to resolve method="auto")
929 output_dir: Where to save plots
930 method: "auto", "shap", "sir", or "pce_sobol"
931 "auto" uses shap for tree models, sir directions for sir models
932 pce_config: Config dict for PCE (only used if method="pce_sobol")
933 """
934 print(f"--- Running Observation Sensitivity (method={method}) ---")
936 # Prepare data (sample 0 = default params)
937 data = database.isel(sample=0).to_dataframe().reset_index()
938 X = data[features_list]
939 y = data["ref_power_cap"].values # observations
941 # Determine method
942 if method == "auto":
943 if model_type == "tree":
944 method = "shap"
945 elif model_type == "sir":
946 method = "sir"
947 elif model_type == "pce":
948 # For PCE models, default "auto" to PCE-based Sobol SA
949 method = "pce_sobol"
950 else:
951 raise ValueError(
952 f"Unknown model_type '{model_type}' for 'auto' sensitivity. "
953 f"Expected one of ['tree', 'sir', 'pce']."
954 )
956 if method == "pce_sobol":
957 # PCE-based Sobol indices
958 from wifa_uq.postprocessing.PCE_tool.pce_utils import run_pce_sensitivity
960 run_pce_sensitivity(
961 database,
962 feature_names=features_list,
963 pce_config=pce_config or {},
964 output_dir=output_dir,
965 )
967 elif method == "shap":
968 # Train model, then SHAP
969 ml_pipeline.fit(X, y)
971 # Get the fitted model and (if available) scaled inputs for SHAP
972 if hasattr(ml_pipeline, "named_steps"):
973 X_scaled = ml_pipeline.named_steps["scaler"].transform(X)
974 model = ml_pipeline.named_steps["model"]
976 else:
977 X_scaled = X
978 model = ml_pipeline
980 explainer = shap.TreeExplainer(model)
981 shap_values = explainer.shap_values(X_scaled)
983 # Plot
984 shap.summary_plot(shap_values, X, feature_names=features_list, show=False)
985 plt.savefig(Path(output_dir) / "observation_sensitivity_shap.png", dpi=150)
986 plt.close()
987 print(f" Saved SHAP plot to {output_dir}/observation_sensitivity_shap.png")
989 elif method == "sir":
990 # Train SIR model, use direction coefficients
991 scaler = StandardScaler()
992 X_scaled = scaler.fit_transform(X)
993 ml_pipeline.fit(X_scaled, y)
995 # Get SIR direction coefficients as importance
996 directions = ml_pipeline.sir_.directions_.flatten()
997 importance = np.abs(directions)
999 # Identify the feature with the largest influence on the first direction
1000 top_idx = np.argmax(importance)
1001 top_feature_name = features_list[top_idx]
1002 print(f" Dominant feature identified: {top_feature_name}")
1004 # Plot
1005 sorted_idx = np.argsort(importance)
1006 plt.figure(figsize=(8, 6))
1007 plt.barh(range(len(features_list)), importance[sorted_idx])
1008 plt.yticks(range(len(features_list)), [features_list[i] for i in sorted_idx])
1009 plt.xlabel("Absolute SIR Direction Coefficient")
1010 plt.title("Observation Sensitivity (SIR)")
1011 plt.tight_layout()
1012 plt.savefig(Path(output_dir) / "observation_sensitivity_sir.png", dpi=150)
1013 plt.close()
1014 print(f" Saved SIR plot to {output_dir}/observation_sensitivity_sir.png")
1016 # 2. Shadow Plot (Error vs. First Eigenvector, Colored by Top Feature)
1017 # Project data onto the first found direction
1018 X_projected = ml_pipeline.sir_.transform(X_scaled)
1019 first_component = X_projected[:, 0]
1021 # Get values of the top feature for coloring
1022 # X is likely a DataFrame here given the setup code
1023 color_values = X.iloc[:, top_idx].values
1025 plt.figure(figsize=(9, 7))
1027 scatter = plt.scatter(
1028 first_component,
1029 y,
1030 c=color_values,
1031 cmap="viridis",
1032 alpha=0.7,
1033 edgecolor="k",
1034 linewidth=0.5,
1035 )
1037 cbar = plt.colorbar(scatter)
1038 cbar.set_label(f"Feature Value: {top_feature_name}", rotation=270, labelpad=15)
1040 plt.xlabel(r"Projected Input $\beta_1^T \mathbf{x}$ (1st SIR Direction)")
1041 plt.ylabel("Observation target (ref_power_cap)")
1042 plt.title(f"SIR Shadow Plot\nColored by dominant feature '{top_feature_name}'")
1043 plt.grid(True, alpha=0.3)
1044 plt.tight_layout()
1046 shadow_plot_path = Path(output_dir) / "observation_sensitivity_sir_shadow.png"
1047 plt.savefig(shadow_plot_path, dpi=150)
1048 plt.close()
1049 print(f" Saved SIR shadow plot to {shadow_plot_path}")
1051 else:
1052 raise ValueError(
1053 f"Unknown method: {method}. Use 'auto', 'shap', 'sir', or 'pce_sobol'"
1054 )
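# --- Example (illustrative sketch, not part of the original pipeline) ---
# Running the SIR observation-sensitivity path. The `database` argument is
# assumed to be an xarray Dataset with a 'sample' dimension, a 'ref_power_cap'
# variable and a 'turbulence_intensity' feature; the output folder name and
# the helper name `_example_observation_sensitivity` are hypothetical.
def _example_observation_sensitivity(database):
    out = Path("./sa_plots")
    out.mkdir(exist_ok=True)
    run_observation_sensitivity(
        database,
        features_list=["turbulence_intensity"],
        ml_pipeline=SIRPolynomialRegressor(n_directions=1, degree=2),
        model_type="sir",
        output_dir=out,
        method="sir",
    )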
1057def run_cross_validation(
1058 xr_data,
1059 ML_pipeline,
1060 model_type,
1061 Calibrator_cls,
1062 BiasPredictor_cls,
1063 MainPipeline_cls,
1064 cv_config: dict,
1065 features_list: list,
1066 output_dir: Path,
1067 sa_config: dict,
1068 calibration_mode: str = "global",
1069 local_regressor: str = None,
1070 local_regressor_params: dict = None,
1071):
1072 validation_data = xr_data[
1073 ["turb_rated_power", "pw_power_cap", "ref_power_cap", "case_index"]
1074 ]
1075 groups = xr_data["wind_farm"].values
1077 splitting_mode = cv_config.get("splitting_mode", "kfold_shuffled")
1079 # Track fold labels for multi-farm visualization
1080 fold_labels = []
1081 is_multi_farm = False
1083 if splitting_mode == "LeaveOneGroupOut":
1084 is_multi_farm = True
1087 groups_cfg = cv_config.get("groups")
1088 if groups_cfg:
1089 # Flatten config into a name -> label mapping
1090 wf_to_group = {}
1091 for group_label, wf_list in groups_cfg.items():
1092 for wf in wf_list:
1093 wf_to_group[wf] = group_label # can stay as string!
1095 default_group = "__OTHER__"
1096 manual_groups = np.array(
1097 [wf_to_group.get(str(w), default_group) for w in groups]
1098 )
1099 else:
1100 # Fully generic fallback: each wind_farm is its own group
1101 manual_groups = groups
1103 cv = LeaveOneGroupOut()
1104 splits = list(cv.split(xr_data.case_index, groups=manual_groups))
1105 n_splits = cv.get_n_splits(groups=manual_groups)
1107 # Extract fold labels (the left-out group for each fold)
1108 unique_groups = np.unique(manual_groups)
1109 for train_idx, test_idx in splits:
1110 # Find which group is left out (present in test but not in train)
1111 test_groups = np.unique(manual_groups[test_idx])
1112 fold_labels.append(
1113 str(test_groups[0]) if len(test_groups) == 1 else str(test_groups)
1114 )
1116 print(f"Using LeaveOneGroupOut with {n_splits} groups: {list(unique_groups)}")
1118 elif splitting_mode == "kfold_shuffled":
1119 n_splits = cv_config.get("n_splits", 5)
1120 cv = KFold(n_splits=n_splits, shuffle=True, random_state=42)
1121 splits = list(cv.split(xr_data.case_index))
1122 fold_labels = [f"Fold {i + 1}" for i in range(n_splits)]
1123 print(f"Using KFold with {n_splits} splits.")
 else:
 raise ValueError(
 f"Unknown splitting_mode '{splitting_mode}'. "
 f"Use 'LeaveOneGroupOut' or 'kfold_shuffled'."
 )
1125 stats_cv, y_preds, y_tests, pw_all, ref_all = [], [], [], [], []
1126 fold_farm_names = [] # Track farm names per fold for visualization
1128 # --- Add lists to store items for SHAP ---
1129 all_models = []
1130 all_xtest_scaled = [] # Only used for tree models
1131 all_features_df = []
1133 # --- Add lists for local calibration parameter prediction tracking ---
1134 all_predicted_params = [] # Predicted parameter values per fold
1135 all_actual_optimal_params = [] # Actual optimal parameter values per fold
1136 swept_params = xr_data.attrs.get("swept_params", [])
1138 for i, (train_idx_locs, test_idx_locs) in enumerate(splits):
1139 # Get the actual case_index *values* at these integer locations
1140 train_indices = xr_data.case_index.values[train_idx_locs]
1141 test_indices = xr_data.case_index.values[test_idx_locs]
1143 dataset_train = xr_data.where(xr_data.case_index.isin(train_indices), drop=True)
1144 dataset_test = xr_data.where(xr_data.case_index.isin(test_indices), drop=True)
1146 # Track farm names for this fold (for visualization)
1147 if "wind_farm" in xr_data.coords:
1148 fold_farm_names.append(xr_data.wind_farm.values[test_idx_locs])
1150 if calibration_mode == "local":
1151 calibrator = Calibrator_cls(
1152 dataset_train,
1153 feature_names=features_list,
1154 regressor_name=local_regressor,
1155 regressor_params=local_regressor_params,
1156 )
1157 else:
1158 calibrator = Calibrator_cls(dataset_train)
1160 bias_pred = BiasPredictor_cls(ML_pipeline)
1162 # --- Pass features_list to MainPipeline ---
1163 main_pipe = MainPipeline_cls(
1164 calibrator,
1165 bias_pred,
1166 features_list=features_list,
1167 calibration_mode=calibration_mode,
1168 )
1170 x_test, y_test, idxs = main_pipe.fit(dataset_train, dataset_test)
1171 y_pred = main_pipe.predict(x_test)
1173 # Build the validation data (pw/ref) for this fold; the sample selection depends on the calibration mode
1178 if calibration_mode == "global":
1179 # Single sample index for all cases
1180 val_data_fold = validation_data.sel(sample=idxs).where(
1181 validation_data.case_index.isin(test_indices), drop=True
1182 )
1184 pw = val_data_fold["pw_power_cap"].values
1185 ref = val_data_fold["ref_power_cap"].values
1186 else:
1187 # Local calibration: idxs is an array of per-case sample indices
1188 # We must build pw/ref per test case using those indices.
1192 idxs = np.asarray(idxs)
1193 if idxs.shape[0] != len(test_indices):
1194 raise ValueError(
1195 f"Local calibration returned {idxs.shape[0]} indices, "
1196 f"but there are {len(test_indices)} test cases."
1197 )
1199 # --- Track predicted vs actual optimal parameters for this fold ---
1200 X_test_features = main_pipe._extract_features(
1201 dataset_test.isel(sample=0).to_dataframe().reset_index()
1202 )
1203 predicted_params_fold = main_pipe.calibrator.predict(X_test_features)
1204 all_predicted_params.append(predicted_params_fold)
1206 # Get actual optimal parameters (from the sample indices we chose)
1207 actual_params_fold = {p: [] for p in swept_params}
1208 for sample_idx in idxs:
1209 for param_name in swept_params:
1210 if param_name in dataset_test.coords:
1211 actual_params_fold[param_name].append(
1212 float(
1213 dataset_test.coords[param_name]
1214 .isel(sample=sample_idx)
1215 .values
1216 )
1217 )
1218 actual_params_df = pd.DataFrame(actual_params_fold)
1219 all_actual_optimal_params.append(actual_params_df)
1221 pw_list = []
1222 ref_list = []
1224 # dataset_test.case_index is the subset used inside MainPipeline
1225 local_case_indices = dataset_test.case_index.values
1227 for local_case_idx, sample_idx in enumerate(idxs):
1228 case_index_val = local_case_indices[local_case_idx]
1230 # Pick the appropriate sample & case_index from validation_data
1231 this_point = validation_data.sel(
1232 sample=int(sample_idx), case_index=case_index_val
1233 )
1235 pw_list.append(float(this_point["pw_power_cap"].values))
1236 ref_list.append(float(this_point["ref_power_cap"].values))
1238 pw = np.array(pw_list)
1239 ref = np.array(ref_list)
1241 stats = compute_metrics(y_test, y_pred, pw=pw, ref=ref)
1242 stats_cv.append(stats)
1244 y_preds.append(y_pred)
1245 y_tests.append(y_test)
1246 pw_all.append(pw)
1247 ref_all.append(ref)
1249 # --- Store model and data for SHAP / SIR global importance ---
1250 all_models.append(main_pipe.bias_predictor.pipeline)
1251 all_features_df.append(x_test)
1252 if model_type == "tree":
1253 X_test_scaled = main_pipe.bias_predictor.pipeline.named_steps[
1254 "scaler"
1255 ].transform(x_test)
1256 all_xtest_scaled.append(X_test_scaled)
1258 cv_results = pd.DataFrame(stats_cv)
1260 # --- START PLOTTING BLOCK (Visualization) ---
1262 # Flatten all fold results into single arrays for plotting
1263 y_preds_flat = np.concatenate(y_preds)
1264 y_tests_flat = np.concatenate(y_tests)
1266 have_power = all(p is not None for p in pw_all) and all(
1267 r is not None for r in ref_all
1268 )
1270 if have_power:
1271 pw_flat = np.concatenate(pw_all)
1272 ref_flat = np.concatenate(ref_all)
1273 corrected_power_flat = pw_flat - y_preds_flat
1275 fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
1276 fig.suptitle("Cross-Validation Model Performance", fontsize=16)
1278 # 1. Predicted vs. True Bias
1279 ax1.scatter(y_tests_flat, y_preds_flat, alpha=0.5, s=10)
1280 min_bias = min(y_tests_flat.min(), y_preds_flat.min())
1281 max_bias = max(y_tests_flat.max(), y_preds_flat.max())
1282 ax1.plot([min_bias, max_bias], [min_bias, max_bias], "r--", label="1:1 Line")
1283 ax1.set_xlabel("True Bias (PyWake - Ref)")
1284 ax1.set_ylabel("Predicted Bias (ML)")
1285 ax1.set_title("ML Model Performance")
1286 ax1.grid(True)
1287 ax1.legend()
1288 ax1.axis("equal")
1290 # 2. Uncorrected Power vs. Reference
1291 ax2.scatter(ref_flat, pw_flat, alpha=0.5, s=10, label="Data")
1292 min_power = min(ref_flat.min(), pw_flat.min())
1293 max_power = max(ref_flat.max(), pw_flat.max())
1294 ax2.plot([min_power, max_power], [min_power, max_power], "r--", label="1:1 Line")
1295 ax2.set_xlabel("Reference Power (Truth)")
1296 ax2.set_ylabel("Uncorrected Power (Calibrated)")
1297 ax2.set_title("Uncorrected Model")
1298 ax2.grid(True)
1299 ax2.legend()
1300 ax2.axis("equal")
1302 # 3. Corrected Power vs. Reference
1303 ax3.scatter(ref_flat, corrected_power_flat, alpha=0.5, s=10, label="Data")
1304 ax3.plot([min_power, max_power], [min_power, max_power], "r--", label="1:1 Line")
1305 ax3.set_xlabel("Reference Power (Truth)")
1306 ax3.set_ylabel("Corrected Power (Calibrated + ML)")
1307 ax3.set_title("Corrected Model")
1308 ax3.grid(True)
1309 ax3.legend()
1310 ax3.axis("equal")
1312 plt.tight_layout(rect=[0, 0.03, 1, 0.95])
1313 plot_path = output_dir / "correction_results.png"
1314 plt.savefig(plot_path, dpi=150)
1315 print(f"Saved correction plot to: {plot_path}")
1316 plt.close(fig) # Close the figure
1318 # --- MULTI-FARM CV VISUALIZATION (NEW) ---
1319 if is_multi_farm or splitting_mode == "LeaveOneGroupOut":
1320 print("--- Generating Multi-Farm CV Visualizations ---")
1322 # 1. Per-fold metrics visualization
1323 plot_multi_farm_cv_metrics(
1324 cv_results=cv_results,
1325 fold_labels=fold_labels,
1326 output_dir=output_dir,
1327 splitting_mode=splitting_mode,
1328 )
1330 # 2. Predictions colored by fold
1331 plot_farm_wise_predictions(
1332 y_tests=y_tests,
1333 y_preds=y_preds,
1334 fold_labels=fold_labels,
1335 fold_farm_names=fold_farm_names,
1336 output_dir=output_dir,
1337 )
1339 # 3. Generalization summary table
1340 plot_generalization_matrix(
1341 cv_results=cv_results, fold_labels=fold_labels, output_dir=output_dir
1342 )
1344 # --- PARAMETER PREDICTION PLOT (Local Calibration Only) ---
1345 if calibration_mode == "local" and all_predicted_params and swept_params:
1346 print("--- Generating Parameter Prediction Quality Plot ---")
1348 # Concatenate all folds
1349 predicted_all = pd.concat(all_predicted_params, axis=0, ignore_index=True)
1350 actual_all = pd.concat(all_actual_optimal_params, axis=0, ignore_index=True)
1352 n_params = len(swept_params)
1353 fig, axes = plt.subplots(1, n_params, figsize=(6 * n_params, 5))
1354 if n_params == 1:
1355 axes = [axes]
1357 fig.suptitle("Local Calibration: Parameter Prediction Quality", fontsize=14)
1359 for idx, param_name in enumerate(swept_params):
1360 ax = axes[idx]
1361 pred_vals = predicted_all[param_name].values
1362 actual_vals = actual_all[param_name].values
1364 ax.scatter(actual_vals, pred_vals, alpha=0.5, s=15)
1366 # 1:1 line
1367 min_val = min(actual_vals.min(), pred_vals.min())
1368 max_val = max(actual_vals.max(), pred_vals.max())
1369 ax.plot([min_val, max_val], [min_val, max_val], "r--", label="1:1 Line")
1371 # Calculate R² for parameter prediction
1372 ss_res = np.sum((actual_vals - pred_vals) ** 2)
1373 ss_tot = np.sum((actual_vals - actual_vals.mean()) ** 2)
1374 r2_param = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0
1376 ax.set_xlabel(f"Actual Optimal {param_name}")
1377 ax.set_ylabel(f"Predicted {param_name}")
1378 ax.set_title(f"{param_name} (R² = {r2_param:.3f})")
1379 ax.legend()
1380 ax.grid(True)
1381 ax.set_aspect("equal", adjustable="box")
1383 plt.tight_layout(rect=[0, 0.03, 1, 0.95])
1384 param_plot_path = output_dir / "local_parameter_prediction.png"
1385 plt.savefig(param_plot_path, dpi=150)
1386 print(f"Saved parameter prediction plot to: {param_plot_path}")
1387 plt.close(fig)
1389 # --- END PLOTTING BLOCK ---
1391 # --- START SHAP ON BIAS BLOCK ---
1392 if sa_config.get("run_bias_sensitivity", False):
1393 print(f"--- Running Bias Sensitivity (Model Type: {model_type}) ---")
1395 if model_type == "tree":
1396 try:
1397 print("--- Calculating Bias SHAP (TreeExplainer) ---")
1398 all_shap_values = []
1400 for i in range(n_splits):
1401 model = all_models[i].named_steps["model"]
1402 X_test_scaled = all_xtest_scaled[i]
1404 explainer = shap.TreeExplainer(model)
1405 shap_values_fold = explainer.shap_values(X_test_scaled)
1406 all_shap_values.append(shap_values_fold)
1408 final_shap_values = np.concatenate(all_shap_values, axis=0)
1410 final_features_df = pd.concat(all_features_df, axis=0)
1411 # Clean string-like feature columns before plotting
1412 for col in final_features_df.columns:
1413 if final_features_df[col].dtype == "object":
1414 if final_features_df[col].dropna().empty:
1415 continue
1416 first_item = final_features_df[col].dropna().iloc[0]
1417 if isinstance(first_item, str):
1418 print(f" Cleaning string column in SHAP: {col}")
1419 final_features_df[col] = (
1420 final_features_df[col]
1421 .str.replace(r"[\[\]]", "", regex=True)
1422 .astype(float)
1423 )
1424 else:
1425 final_features_df[col] = final_features_df[col].astype(
1426 float
1427 )
1428 final_features_df = final_features_df.astype(float) # Final cast
1430 # --- 1. GENERATE AND SAVE BAR PLOT ---
1431 print("--- Calculating Bias SHAP Global Feature Importance ---")
1432 mean_abs_shap = np.mean(np.abs(final_shap_values), axis=0)
1433 shap_scores = pd.Series(
1434 mean_abs_shap, index=final_features_df.columns
1435 ).sort_values(ascending=True)
1437 fig, ax = plt.subplots(figsize=(10, 8))
1438 shap_scores.plot(kind="barh", ax=ax)
1439 ax.set_title(
1440 "Global SHAP Feature Importance (Mean Absolute SHAP Value)"
1441 )
1442 ax.set_xlabel("Mean |SHAP Value| (Impact on Bias Prediction)")
1443 plt.tight_layout()
1445 bar_plot_path = output_dir / "bias_prediction_shap_importance.png"
1446 plt.savefig(bar_plot_path, dpi=150, bbox_inches="tight")
1447 plt.close(fig)
1448 print(f"Saved SHAP importance bar plot to: {bar_plot_path}")
1450 # --- 2. GENERATE AND SAVE BEESWARM PLOT ---
1451 shap.summary_plot(final_shap_values, final_features_df, show=False)
1452 plot_path = output_dir / "bias_prediction_shap.png"
1453 plt.savefig(plot_path, dpi=150, bbox_inches="tight")
1454 plt.close()
1455 print(f"Saved bias SHAP beeswarm plot to: {plot_path}")
1457 except Exception as e:
1458 print(f"Could not run bias SHAP (Tree) analysis: {e}")
1459 raise e
1461 elif model_type == "linear":
1462 try:
1463 print(
1464 "--- Calculating Bias Linear Feature Importance (Coefficients) ---"
1465 )
1466 all_linear_scores = []
1467 feature_names = all_features_df[0].columns
1469 for i in range(n_splits):
1470 model = all_models[i]
1471 fold_scores = model.get_feature_importance(feature_names)
1472 all_linear_scores.append(fold_scores)
1474 # Average importances across folds
1475 importance_scores = pd.concat(all_linear_scores, axis=1).mean(axis=1)
1476 importance_scores = importance_scores.sort_values(ascending=True)
1478 fig, ax = plt.subplots(figsize=(10, 8))
1479 importance_scores.plot(kind="barh", ax=ax)
1480 ax.set_title("Linear Model Feature Importance (Mean |Coefficient|)")
1481 ax.set_xlabel("Mean |Coefficient| (Impact on Bias Prediction)")
1482 plt.tight_layout()
1484 bar_plot_path = output_dir / "bias_prediction_linear_importance.png"
1485 plt.savefig(bar_plot_path, dpi=150, bbox_inches="tight")
1486 plt.close(fig)
1487 print(f"Saved linear importance plot to: {bar_plot_path}")
1489 except Exception as e:
1490 print(f"Could not run bias linear analysis: {e}")
1491 raise e
1492 elif model_type == "sir":
1493 try:
1494 print(
1495 "--- Calculating Bias SIR Feature Importance (Averaged over folds) ---"
1496 )
1497 all_sir_scores = []
1498 feature_names = all_features_df[0].columns # feature names from the first fold
1502 for i in range(n_splits):
1503 model = all_models[i] # This is the SIRPolynomialRegressor instance
1504 fold_scores = model.get_feature_importance(feature_names)
1505 all_sir_scores.append(fold_scores)
1507 # Average the importances across all folds
1508 shap_scores = pd.concat(all_sir_scores, axis=1).mean(axis=1)
1509 shap_scores = shap_scores.sort_values(ascending=True) # Sort for barh
1511 # Generate and save the bar plot
1512 fig, ax = plt.subplots(figsize=(10, 8))
1513 shap_scores.plot(kind="barh", ax=ax)
1514 ax.set_title(
1515 "Global SIR Feature Importance (Mean Absolute Direction Coefficient)"
1516 )
1517 ax.set_xlabel(
1518 "Mean |SIR Direction Coefficient| (Impact on Bias Prediction)"
1519 )
1520 plt.tight_layout()
1522 bar_plot_path = output_dir / "bias_prediction_sir_importance.png"
1523 plt.savefig(bar_plot_path, dpi=150, bbox_inches="tight")
1524 plt.close(fig)
1525 print(f"Saved SIR importance bar plot to: {bar_plot_path}")
1526 print("--- NOTE: Beeswarm plot is not available for SIR model. ---")
1528 except Exception as e:
1529 print(f"Could not run bias SIR analysis: {e}")
1530 raise e
1532 # --- END SHAP ON BIAS BLOCK ---
1534 return cv_results, y_preds, y_tests
1537# Testing
1538if __name__ == "__main__":
1539 import xarray as xr
1541 xr_data = xr.load_dataset("results_stacked_hh.nc")
1542 pipe_xgb = Pipeline(
1543 [
1544 ("scaler", StandardScaler()),
1545 ("model", xgb.XGBRegressor(max_depth=3, n_estimators=500)),
1546 ]
1547 )
1549 # Create a dummy output_dir for testing
1550 test_output_dir = Path("./test_results")
1551 test_output_dir.mkdir(exist_ok=True)
1553 cv_results, y_preds, y_tests = run_cross_validation(
1554 xr_data,
1555 ML_pipeline=pipe_xgb,
1556 model_type="tree", # required by the bias sensitivity block
1557 Calibrator_cls=MinBiasCalibrator,
1558 BiasPredictor_cls=BiasPredictor,
1559 MainPipeline_cls=MainPipeline,
1560 cv_config={"splitting_mode": "kfold_shuffled", "n_splits": 5},
1561 features_list=["turbulence_intensity"], # Example features
1562 output_dir=test_output_dir,
1563 sa_config={"run_bias_sensitivity": True}, # Example config
1564 )
1565 print(cv_results.mean())