# wifa_uq/postprocessing/physics_insights/physics_insights.py
"""
Physics Insights Module for WIFA-UQ.

Extracts interpretable physical insights from bias prediction models:
    1. Partial Dependence Analysis - How bias varies with each feature
    2. Feature Interactions - Which feature combinations drive error
    3. Regime Identification - Clustering of high-bias conditions
    4. Parameter-Condition Relationships - How optimal params depend on conditions

These analyses transform ML results into actionable physics understanding.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from sklearn.cluster import KMeans
from sklearn.inspection import partial_dependence
from sklearn.preprocessing import StandardScaler

try:
    import shap

    HAS_SHAP = True
except ImportError:
    HAS_SHAP = False


# =============================================================================
# Data Classes for Results
# =============================================================================


@dataclass
class PartialDependenceResult:
    """Results from partial dependence analysis."""

    feature: str
    grid_values: np.ndarray
    pd_values: np.ndarray
    bias_direction: str  # "increases", "decreases", "non-monotonic"
    effect_magnitude: float  # Range of PD values
    physical_interpretation: str


@dataclass
class InteractionResult:
    """Results from interaction analysis."""

    feature_1: str
    feature_2: str
    interaction_strength: float
    description: str


@dataclass
class RegimeResult:
    """Results from regime identification."""

    regime_id: int
    n_cases: int
    mean_bias: float
    feature_centroids: dict[str, float]
    description: str


@dataclass
class ParameterRelationshipResult:
    """Results from parameter-condition analysis."""

    parameter: str
    most_influential_feature: str
    correlation: float
    relationship_type: str  # "positive", "negative", "weak"
    physical_interpretation: str


@dataclass
class PhysicsInsightsReport:
    """Complete physics insights report."""

    partial_dependence: list[PartialDependenceResult] = field(default_factory=list)
    interactions: list[InteractionResult] = field(default_factory=list)
    regimes: list[RegimeResult] = field(default_factory=list)
    parameter_relationships: list[ParameterRelationshipResult] = field(
        default_factory=list
    )
    summary: str = ""

    def to_dict(self) -> dict:
        """Convert to dictionary for serialization."""
        return {
            "partial_dependence": [
                {
                    "feature": r.feature,
                    "bias_direction": r.bias_direction,
                    "effect_magnitude": r.effect_magnitude,
                    "interpretation": r.physical_interpretation,
                }
                for r in self.partial_dependence
            ],
            "interactions": [
                {
                    "features": [r.feature_1, r.feature_2],
                    "strength": r.interaction_strength,
                    "description": r.description,
                }
                for r in self.interactions
            ],
            "regimes": [
                {
                    "regime_id": r.regime_id,
                    "n_cases": r.n_cases,
                    "mean_bias": r.mean_bias,
                    "centroids": r.feature_centroids,
                    "description": r.description,
                }
                for r in self.regimes
            ],
            "parameter_relationships": [
                {
                    "parameter": r.parameter,
                    "most_influential_feature": r.most_influential_feature,
                    "correlation": r.correlation,
                    "interpretation": r.physical_interpretation,
                }
                for r in self.parameter_relationships
            ],
            "summary": self.summary,
        }

    def to_markdown(self) -> str:
        """Generate markdown report."""
        lines = ["# Physics Insights Report\n"]

        if self.partial_dependence:
            lines.append("## 1. How Model Bias Depends on Atmospheric Conditions\n")
            for r in self.partial_dependence:
                lines.append(f"### {r.feature}")
                lines.append(
                    f"- **Direction**: Bias {r.bias_direction} with {r.feature}"
                )
                lines.append(f"- **Effect magnitude**: {r.effect_magnitude:.4f}")
                lines.append(f"- **Interpretation**: {r.physical_interpretation}\n")

        if self.interactions:
            lines.append("## 2. Feature Interactions\n")
            for r in self.interactions:
                lines.append(
                    f"- **{r.feature_1} × {r.feature_2}**: "
                    f"strength = {r.interaction_strength:.4f}"
                )
                lines.append(f"  - {r.description}\n")

        if self.regimes:
            lines.append("## 3. Error Regimes (Failure Modes)\n")
            for r in self.regimes:
                lines.append(f"### Regime {r.regime_id + 1}: {r.description}")
                lines.append(f"- Cases: {r.n_cases}")
                lines.append(f"- Mean bias: {r.mean_bias:.4f}")
                lines.append("- Characteristic conditions:")
                for feat, val in r.feature_centroids.items():
                    lines.append(f"  - {feat}: {val:.3f}")
                lines.append("")

        if self.parameter_relationships:
            lines.append("## 4. Optimal Parameter Dependencies\n")
            lines.append(
                "*How the 'best' wake model parameters vary with conditions:*\n"
            )
            for r in self.parameter_relationships:
                lines.append(f"### {r.parameter}")
                lines.append(
                    f"- Most influential feature: **{r.most_influential_feature}**"
                )
                lines.append(
                    f"- Correlation: {r.correlation:.3f} ({r.relationship_type})"
                )
                lines.append(f"- **Implication**: {r.physical_interpretation}\n")

        if self.summary:
            lines.append("## Summary\n")
            lines.append(self.summary)

        return "\n".join(lines)
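

# A minimal, hedged usage sketch (not part of the pipeline): build a report
# with invented values and render both serializations. The helper name
# `_demo_report_serialization` and every number below are illustrative only.
def _demo_report_serialization() -> None:
    result = PartialDependenceResult(
        feature="wind_veer",
        grid_values=np.linspace(0.0, 0.1, 5),
        pd_values=np.array([-0.01, 0.0, 0.005, 0.012, 0.02]),
        bias_direction="increases",
        effect_magnitude=0.03,
        physical_interpretation="illustrative placeholder text",
    )
    report = PhysicsInsightsReport(
        partial_dependence=[result], summary="Bias grows with wind veer."
    )
    print(report.to_markdown())  # human-readable markdown view
    print(report.to_dict())  # JSON-serializable view
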
# =============================================================================
# Physical Interpretation Helpers
# =============================================================================

# Domain knowledge for automatic interpretation
FEATURE_PHYSICS = {
    "ABL_height": {
        "high": "stable stratification / nocturnal conditions",
        "low": "convective / well-mixed conditions",
        "unit": "m",
    },
    "wind_veer": {
        "high": "strong directional shear / Ekman spiral",
        "low": "uniform wind direction with height",
        "unit": "deg/m",
    },
    "lapse_rate": {
        "high": "stable stratification (positive dθ/dz)",
        "low": "unstable/neutral conditions",
        "unit": "K/m",
    },
    "turbulence_intensity": {
        "high": "high ambient turbulence / faster wake recovery",
        "low": "low turbulence / slower wake recovery",
        "unit": "-",
    },
    "Blockage_Ratio": {
        "high": "significant upstream blockage",
        "low": "front-row or minimal blockage",
        "unit": "-",
    },
    "Blocking_Distance": {
        "high": "far from upstream turbines",
        "low": "close to upstream turbines",
        "unit": "-",
    },
    "Farm_Length": {
        "high": "deep array (many rows)",
        "low": "shallow array",
        "unit": "D",
    },
    "Farm_Width": {
        "high": "wide array",
        "low": "narrow array",
        "unit": "D",
    },
}

PARAMETER_PHYSICS = {
    "k_b": {
        "name": "Wake expansion coefficient",
        "increases_with": "faster wake recovery / higher turbulence",
        "decreases_with": "slower wake recovery / stable conditions",
    },
    "ss_alpha": {
        "name": "Self-similarity blockage parameter",
        "increases_with": "stronger blockage effects",
        "decreases_with": "weaker blockage effects",
    },
    "ceps": {
        "name": "Added turbulence coefficient",
        "increases_with": "more wake-added turbulence",
        "decreases_with": "less wake-added turbulence",
    },
}


def interpret_pd_direction(
    feature: str,
    direction: str,  # "increases", "decreases", "non-monotonic", or "flat"
) -> str:
    """Generate physical interpretation of partial dependence direction."""
    physics = FEATURE_PHYSICS.get(feature, {})
    high_meaning = physics.get("high", f"high {feature}")
    low_meaning = physics.get("low", f"low {feature}")

    if direction == "increases":
        return (
            f"Model bias increases with {feature}, suggesting the wake model "
            f"systematically underestimates wake effects in {high_meaning} "
            f"(and/or overestimates in {low_meaning})."
        )
    elif direction == "decreases":
        return (
            f"Model bias decreases with {feature}, suggesting the wake model "
            f"systematically overestimates wake effects in {high_meaning} "
            f"(and/or underestimates in {low_meaning})."
        )
    elif direction == "flat":
        return f"Model bias is largely insensitive to {feature}."
    else:
        return (
            f"Model bias shows a non-monotonic relationship with {feature}, "
            f"suggesting different error mechanisms at {high_meaning} vs {low_meaning}."
        )
def interpret_parameter_relationship(
    parameter: str,
    feature: str,
    correlation: float,
) -> str:
    """Generate physical interpretation of parameter-feature relationship."""
    param_physics = PARAMETER_PHYSICS.get(parameter, {"name": parameter})
    feat_physics = FEATURE_PHYSICS.get(feature, {})

    param_name = param_physics.get("name", parameter)
    high_meaning = feat_physics.get("high", f"high {feature}")

    if abs(correlation) < 0.3:
        return f"{param_name} shows weak dependence on {feature}."

    if correlation > 0:
        return (
            f"{param_name} should increase in {high_meaning}. "
            f"This suggests the model's default parameterization underestimates "
            f"wake recovery rate in these conditions."
        )
    else:
        return (
            f"{param_name} should decrease in {high_meaning}. "
            f"This suggests the model's default parameterization overestimates "
            f"wake recovery rate in these conditions."
        )


def describe_regime(
    centroids: dict[str, float],
    feature_stats: dict[str, tuple[float, float]],  # {feature: (mean, std)}
    mean_bias: float,
) -> str:
    """Generate description of an error regime based on its characteristics."""
    descriptions = []

    for feature, value in centroids.items():
        if feature not in feature_stats:
            continue
        mean, std = feature_stats[feature]
        if std == 0:
            continue

        z_score = (value - mean) / std
        physics = FEATURE_PHYSICS.get(feature, {})

        if z_score > 1.0:
            descriptions.append(physics.get("high", f"high {feature}"))
        elif z_score < -1.0:
            descriptions.append(physics.get("low", f"low {feature}"))

    if not descriptions:
        if mean_bias > 0:
            return "Mixed conditions with positive bias (model overestimates)"
        else:
            return "Mixed conditions with negative bias (model underestimates)"

    bias_desc = "overestimation" if mean_bias > 0 else "underestimation"
    return f"{', '.join(descriptions[:2])} → {bias_desc}"
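

# Hedged sketch of describe_regime (values invented): a centroid two standard
# deviations above the dataset mean in ABL_height picks up the "high" physics
# label, and the negative mean bias maps to "underestimation".
def _demo_describe_regime() -> None:
    centroids = {"ABL_height": 900.0, "turbulence_intensity": 0.08}
    stats = {"ABL_height": (500.0, 200.0), "turbulence_intensity": (0.08, 0.03)}
    # Expected: "stable stratification / nocturnal conditions → underestimation"
    print(describe_regime(centroids, stats, mean_bias=-0.02))
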
def describe_regime_relative(
    centroids: dict[str, float],
    all_cluster_info: list[dict],
    feature_stats: dict[str, tuple[float, float]],
    mean_bias: float,
    cluster_id: int,
) -> str:
    """
    Generate description of an error regime based on RELATIVE differences to other clusters.

    This is more informative when clusters are similar in absolute terms but differ
    from each other in specific ways.
    """
    # First try absolute description (z-scores > 1)
    descriptions = []
    for feature, value in centroids.items():
        if feature not in feature_stats:
            continue
        mean, std = feature_stats[feature]
        if std == 0:
            continue
        z_score = (value - mean) / std
        physics = FEATURE_PHYSICS.get(feature, {})
        if z_score > 1.0:
            descriptions.append(physics.get("high", f"high {feature}"))
        elif z_score < -1.0:
            descriptions.append(physics.get("low", f"low {feature}"))

    # If we found distinctive absolute features, use those
    if descriptions:
        bias_desc = "overestimation" if mean_bias > 0 else "underestimation"
        return f"{', '.join(descriptions[:2])} → {bias_desc}"

    # Otherwise, find what makes THIS cluster different from OTHERS
    other_clusters = [c for c in all_cluster_info if c["id"] != cluster_id]
    if not other_clusters:
        bias_desc = "overestimation" if mean_bias > 0 else "underestimation"
        return f"High-error cases → {bias_desc}"

    # Compute relative differences for each feature
    relative_diffs = []
    for feature, value in centroids.items():
        if feature not in feature_stats:
            continue
        mean, std = feature_stats[feature]
        if std == 0:
            continue

        # Average value in other clusters
        other_avg = np.mean([c["centroids"][feature] for c in other_clusters])
        diff = (value - other_avg) / std if std > 0 else 0
        relative_diffs.append((feature, diff, value, other_avg))

    # Sort by absolute difference
    relative_diffs.sort(key=lambda x: abs(x[1]), reverse=True)

    # Build description from top distinguishing features
    distinguishing = []
    for feature, diff, val, other_avg in relative_diffs[:2]:
        if abs(diff) < 0.3:  # Not meaningfully different
            continue
        physics = FEATURE_PHYSICS.get(feature, {})
        if diff > 0:
            desc = physics.get("high", f"higher {feature}")
            distinguishing.append(f"{desc} (vs other regimes)")
        else:
            desc = physics.get("low", f"lower {feature}")
            distinguishing.append(f"{desc} (vs other regimes)")

    bias_desc = "overestimation" if mean_bias > 0 else "underestimation"

    if distinguishing:
        return f"{'; '.join(distinguishing)} → {bias_desc}"

    # Last resort: describe by bias magnitude
    other_biases = [c["mean_bias"] for c in other_clusters]
    avg_other_bias = np.mean(other_biases)

    if abs(mean_bias) > abs(avg_other_bias) * 1.5:
        return f"Highest error magnitude ({mean_bias:.4f}) → {bias_desc}"
    elif abs(mean_bias) < abs(avg_other_bias) * 0.7:
        return f"Lower error magnitude ({mean_bias:.4f}) → {bias_desc}"
    else:
        return f"Similar conditions, bias={mean_bias:.4f} → {bias_desc}"
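

# Hedged sketch of describe_regime_relative (values invented): two clusters
# that look ordinary in absolute z-scores but differ from each other in
# wind_veer, so the relative-difference path of the function is exercised.
def _demo_describe_regime_relative() -> None:
    stats = {"wind_veer": (0.05, 0.02)}
    clusters = [
        {"id": 0, "mean_bias": 0.01, "centroids": {"wind_veer": 0.057}},
        {"id": 1, "mean_bias": -0.02, "centroids": {"wind_veer": 0.043}},
    ]
    # Expected: "strong directional shear / Ekman spiral (vs other regimes) → overestimation"
    print(
        describe_regime_relative(
            clusters[0]["centroids"], clusters, stats, mean_bias=0.01, cluster_id=0
        )
    )
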
# =============================================================================
# Analysis Functions
# =============================================================================


def _manual_partial_dependence(
    model,
    X: pd.DataFrame,
    features: list[str],
    grid_resolution: int = 50,
) -> dict:
    """
    Manual partial dependence calculation as fallback.

    For each feature, creates a grid of values and computes mean prediction
    while averaging over all other features.
    """
    results = {
        "average": [],
        "grid_values": [],
    }

    X_array = X.values

    for feature in features:
        feat_idx = X.columns.get_loc(feature)
        feat_values = X[feature].values

        # Create grid
        grid = np.linspace(feat_values.min(), feat_values.max(), grid_resolution)

        # Compute partial dependence
        pd_values = []
        for grid_val in grid:
            # Create modified X with feature set to grid value
            X_modified = X_array.copy()
            X_modified[:, feat_idx] = grid_val

            # Predict and average
            predictions = model.predict(X_modified)
            pd_values.append(predictions.mean())

        results["average"].append(np.array(pd_values))
        results["grid_values"].append(grid)

    return results
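

# Hedged sanity check for _manual_partial_dependence on synthetic data: with
# a linear model y ≈ 2*x0, the manual PD curve for x0 should be close to a
# slope-2 line. The model, data, and helper name are illustrative only.
def _demo_manual_partial_dependence() -> None:
    from sklearn.linear_model import LinearRegression

    rng = np.random.default_rng(0)
    X = pd.DataFrame(rng.normal(size=(200, 2)), columns=["x0", "x1"])
    y = 2.0 * X["x0"].to_numpy() + rng.normal(scale=0.1, size=200)
    model = LinearRegression().fit(X, y)

    out = _manual_partial_dependence(model, X, ["x0"], grid_resolution=5)
    print(out["grid_values"][0])  # 5 grid points spanning the range of x0
    print(out["average"][0])  # mean prediction at each grid point
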
def analyze_partial_dependence(
    model,
    X: pd.DataFrame,
    features: list[str],
    output_dir: Path,
    grid_resolution: int = 50,
) -> list[PartialDependenceResult]:
    """
    Compute partial dependence and extract physical interpretations.

    Shows how predicted bias changes with each feature, holding others constant.
    """
    print("--- Analyzing Partial Dependence ---")
    results = []

    # Ensure X is a DataFrame with correct types
    X = X.copy()
    for col in X.columns:
        if X[col].dtype == "object":
            X[col] = pd.to_numeric(X[col], errors="coerce")
    X = X.astype(float)

    # Verify model is fitted by trying a prediction
    try:
        _ = model.predict(X.iloc[:1])
    except Exception as e:
        print(f"  WARNING: Model does not appear to be fitted: {e}")
        raise ValueError(
            f"Model must be fitted before partial dependence analysis: {e}"
        ) from e
    n_features = len(features)
    n_cols = min(3, n_features)
    n_rows = (n_features + n_cols - 1) // n_cols

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows))
    if n_features == 1:
        axes = np.array([[axes]])
    elif n_rows == 1:
        axes = axes.reshape(1, -1)

    for idx, feature in enumerate(features):
        row, col = divmod(idx, n_cols)
        ax = axes[row, col]

        # Compute PD for this single feature (avoids sklearn multi-feature output issues)
        feature_idx = X.columns.get_loc(feature)

        try:
            pd_result = partial_dependence(
                model, X, features=[feature_idx], grid_resolution=grid_resolution
            )
            pd_values = pd_result["average"][0]
            grid_values = pd_result["grid_values"][0]
        except Exception as e:
            print(f"  WARNING: partial_dependence failed for {feature}: {e}")
            print("  Falling back to manual PD calculation...")
            manual_result = _manual_partial_dependence(
                model, X, [feature], grid_resolution
            )
            pd_values = manual_result["average"][0]
            grid_values = manual_result["grid_values"][0]

        # Flatten if needed (sklearn may return 2D arrays)
        pd_values = np.asarray(pd_values).flatten()
        grid_values = np.asarray(grid_values).flatten()

        # Determine direction
        n_points = len(pd_values)
        n_edge = min(5, n_points // 4) if n_points > 4 else 1

        start_val = pd_values[:n_edge].mean()
        end_val = pd_values[-n_edge:].mean()
        mid_val = pd_values[n_points // 2] if n_points > 2 else pd_values.mean()

        value_range = pd_values.max() - pd_values.min()
        threshold = 0.01 * value_range if value_range > 0 else 1e-10

        if end_val > start_val + threshold:
            direction = "increases"
        elif end_val < start_val - threshold:
            direction = "decreases"
        else:
            # Check for non-monotonicity
            if abs(mid_val - start_val) > abs(end_val - start_val):
                direction = "non-monotonic"
            else:
                direction = "flat"

        effect_magnitude = float(value_range) if value_range > 0 else 0.0
        interpretation = interpret_pd_direction(feature, direction)

        results.append(
            PartialDependenceResult(
                feature=feature,
                grid_values=grid_values,
                pd_values=pd_values,
                bias_direction=direction,
                effect_magnitude=effect_magnitude,
                physical_interpretation=interpretation,
            )
        )

        # Plot
        ax.plot(grid_values, pd_values, "b-", linewidth=2)
        ax.fill_between(grid_values, pd_values, alpha=0.3)
        ax.axhline(0, color="k", linestyle="--", alpha=0.5)

        # Add direction annotation
        direction_symbol = {
            "increases": "↗",
            "decreases": "↘",
            "non-monotonic": "↝",
            "flat": "→",
        }
        ax.set_title(f"{feature} {direction_symbol.get(direction, '')}")
        ax.set_xlabel(feature)
        ax.set_ylabel("Partial Dependence (Bias)")
        ax.grid(True, alpha=0.3)

        # Add interpretation as text box
        textstr = f"Effect: {effect_magnitude:.4f}\n{direction}"
        props = dict(boxstyle="round", facecolor="wheat", alpha=0.5)
        ax.text(
            0.02,
            0.98,
            textstr,
            transform=ax.transAxes,
            fontsize=8,
            verticalalignment="top",
            bbox=props,
        )

    # Hide unused subplots
    for idx in range(n_features, n_rows * n_cols):
        row, col = divmod(idx, n_cols)
        axes[row, col].set_visible(False)

    plt.suptitle("Partial Dependence: How Bias Varies with Each Feature", fontsize=14)
    plt.tight_layout()
    plt.savefig(output_dir / "partial_dependence.png", dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(
        f"  Saved partial dependence plot to {output_dir / 'partial_dependence.png'}"
    )

    return results
def analyze_interactions(
    model,
    X: pd.DataFrame,
    features: list[str],
    output_dir: Path,
    top_n: int = 5,
) -> list[InteractionResult]:
    """
    Analyze feature interactions using SHAP interaction values.

    Identifies which feature combinations jointly drive bias.
    """
    print("--- Analyzing Feature Interactions ---")

    if not HAS_SHAP:
        print("  SHAP not available, skipping interaction analysis")
        return []

    results = []

    try:
        # Get SHAP interaction values
        explainer = shap.TreeExplainer(model)
        shap_interaction = explainer.shap_interaction_values(
            X.values[:500]
        )  # Limit for speed

        # Average absolute interaction strength
        n_features = len(features)
        interaction_matrix = np.zeros((n_features, n_features))

        for i in range(n_features):
            for j in range(n_features):
                if i != j:
                    interaction_matrix[i, j] = np.abs(shap_interaction[:, i, j]).mean()

        # Find top interactions
        interactions_flat = []
        for i in range(n_features):
            for j in range(i + 1, n_features):
                interactions_flat.append(
                    (
                        features[i],
                        features[j],
                        interaction_matrix[i, j] + interaction_matrix[j, i],
                    )
                )

        interactions_flat.sort(key=lambda x: x[2], reverse=True)

        for feat1, feat2, strength in interactions_flat[:top_n]:
            physics1 = FEATURE_PHYSICS.get(feat1, {}).get("high", f"high {feat1}")
            physics2 = FEATURE_PHYSICS.get(feat2, {}).get("high", f"high {feat2}")

            description = (
                f"Combined effect of {physics1} and {physics2} "
                f"creates bias beyond individual effects."
            )

            results.append(
                InteractionResult(
                    feature_1=feat1,
                    feature_2=feat2,
                    interaction_strength=float(strength),
                    description=description,
                )
            )

        # Plot interaction heatmap
        fig, ax = plt.subplots(figsize=(10, 8))
        im = ax.imshow(interaction_matrix, cmap="YlOrRd")

        ax.set_xticks(range(n_features))
        ax.set_yticks(range(n_features))
        ax.set_xticklabels(features, rotation=45, ha="right")
        ax.set_yticklabels(features)

        # Add values
        for i in range(n_features):
            for j in range(n_features):
                val = interaction_matrix[i, j]
                color = "white" if val > interaction_matrix.max() / 2 else "black"
                ax.text(
                    j,
                    i,
                    f"{val:.3f}",
                    ha="center",
                    va="center",
                    color=color,
                    fontsize=8,
                )

        plt.colorbar(im, ax=ax, label="Interaction Strength")
        ax.set_title(
            "Feature Interaction Strengths\n(Higher = stronger combined effect on bias)"
        )
        plt.tight_layout()
        plt.savefig(
            output_dir / "feature_interactions.png", dpi=150, bbox_inches="tight"
        )
        plt.close(fig)
        print(
            f"  Saved interaction plot to {output_dir / 'feature_interactions.png'}"
        )

    except Exception as e:
        print(f"  Interaction analysis failed: {e}")

    return results
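

# Hedged sketch of analyze_interactions on synthetic data with a deliberate
# x0*x1 interaction. Requires shap and writes one PNG; the output directory
# name is arbitrary for the demo, not a pipeline convention.
def _demo_analyze_interactions(out_dir: Path = Path("interactions_demo")) -> None:
    if not HAS_SHAP:
        print("shap not installed; demo skipped")
        return
    from sklearn.ensemble import RandomForestRegressor

    rng = np.random.default_rng(0)
    X = pd.DataFrame(rng.normal(size=(300, 3)), columns=["x0", "x1", "x2"])
    y = X["x0"].to_numpy() * X["x1"].to_numpy() + 0.1 * rng.normal(size=300)
    model = RandomForestRegressor(n_estimators=50, random_state=0).fit(X, y)

    out_dir.mkdir(parents=True, exist_ok=True)
    top = analyze_interactions(model, X, ["x0", "x1", "x2"], out_dir, top_n=3)
    if top:  # the x0 × x1 pair should rank first
        print(top[0].feature_1, top[0].feature_2, f"{top[0].interaction_strength:.3f}")
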
def analyze_regimes(
    X: pd.DataFrame,
    y_bias: np.ndarray,
    features: list[str],
    output_dir: Path,
    n_clusters: int = 3,
    bias_percentile: float = 75,
) -> list[RegimeResult]:
    """
    Identify distinct error regimes through clustering.

    Clusters high-bias cases to find systematic failure modes.
    """
    print("--- Analyzing Error Regimes ---")
    results = []

    # Focus on high-bias cases
    bias_threshold = np.percentile(np.abs(y_bias), bias_percentile)
    high_bias_mask = np.abs(y_bias) >= bias_threshold

    X_high = X[high_bias_mask].copy()
    y_high = y_bias[high_bias_mask]

    # Hard minimum: need at least 10 high-bias cases
    if len(X_high) < 10:
        print(
            f"  WARNING: Only {len(X_high)} high-bias cases. Need at least 10 for meaningful regime analysis."
        )
        print(
            "  Skipping regime clustering. Consider lowering bias_percentile or getting more data."
        )

        # Still report the high-bias cases as a single "regime" for visibility
        if len(X_high) > 0:
            centroids = {f: float(X_high[f].mean()) for f in features}
            mean_bias = float(y_high.mean())
            feature_stats = {f: (X[f].mean(), X[f].std()) for f in features}

            results.append(
                RegimeResult(
                    regime_id=0,
                    n_cases=len(X_high),
                    mean_bias=mean_bias,
                    feature_centroids=centroids,
                    description=f"All {len(X_high)} high-bias cases (too few for clustering). Mean bias: {mean_bias:.4f}",
                )
            )

            # Create simple summary plot
            fig, ax = plt.subplots(figsize=(10, 5))

            z_scores = []
            for f in features:
                mean, std = feature_stats[f]
                z = (centroids[f] - mean) / std if std > 0 else 0
                z_scores.append(z)

            colors = ["red" if abs(z) > 1 else "steelblue" for z in z_scores]
            ax.barh(features, z_scores, color=colors)
            ax.axvline(0, color="k", linestyle="-", linewidth=0.5)
            ax.axvline(-1, color="gray", linestyle="--", alpha=0.5)
            ax.axvline(1, color="gray", linestyle="--", alpha=0.5)
            ax.set_xlabel("Z-Score (deviation from dataset mean)")
            ax.set_title(
                f"High-Bias Cases Characteristics (n={len(X_high)})\nRed = distinctive (|Z| > 1)"
            )

            plt.tight_layout()
            plt.savefig(output_dir / "error_regimes.png", dpi=150, bbox_inches="tight")
            plt.close(fig)
            print(f"  Saved high-bias summary to {output_dir / 'error_regimes.png'}")

        return results

    # Adjust n_clusters based on available data
    # Rule: at least 5 cases per cluster
    max_clusters = len(X_high) // 5
    if max_clusters < n_clusters:
        print(
            f"  Reducing clusters from {n_clusters} to {max_clusters} (need 5+ cases per cluster)"
        )
        n_clusters = max(2, max_clusters)

    # Standardize for clustering
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X_high[features])

    # Cluster
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    cluster_labels = kmeans.fit_predict(X_scaled)

    # Compute feature statistics for interpretation
    feature_stats = {f: (X[f].mean(), X[f].std()) for f in features}

    # First pass: collect all cluster info
    cluster_info = []
    for cluster_id in range(n_clusters):
        mask = cluster_labels == cluster_id
        cluster_X = X_high[mask]
        cluster_y = y_high[mask]
        centroids = {f: float(cluster_X[f].mean()) for f in features}
        mean_bias = float(cluster_y.mean())
        cluster_info.append(
            {
                "id": cluster_id,
                "n_cases": int(mask.sum()),
                "mean_bias": mean_bias,
                "centroids": centroids,
            }
        )

    # Second pass: describe regimes using RELATIVE differences between clusters
    for info in cluster_info:
        description = describe_regime_relative(
            info["centroids"],
            cluster_info,
            feature_stats,
            info["mean_bias"],
            info["id"],
        )

        results.append(
            RegimeResult(
                regime_id=info["id"],
                n_cases=info["n_cases"],
                mean_bias=info["mean_bias"],
                feature_centroids=info["centroids"],
                description=description,
            )
        )

    # Plot regime analysis
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left: PCA projection of clusters
    from sklearn.decomposition import PCA

    pca = PCA(n_components=2)
    X_pca = pca.fit_transform(X_scaled)

    scatter = axes[0].scatter(
        X_pca[:, 0], X_pca[:, 1], c=cluster_labels, cmap="viridis", alpha=0.6, s=30
    )
    axes[0].set_xlabel(f"PC1 ({pca.explained_variance_ratio_[0]:.1%} variance)")
    axes[0].set_ylabel(f"PC2 ({pca.explained_variance_ratio_[1]:.1%} variance)")
    axes[0].set_title(f"High-Bias Cases Clustered into {n_clusters} Regimes")
    plt.colorbar(scatter, ax=axes[0], label="Regime")

    # Right: Regime characteristics bar chart
    regime_data = []
    for r in results:
        for feat, val in r.feature_centroids.items():
            mean, std = feature_stats[feat]
            z_score = (val - mean) / std if std > 0 else 0
            regime_data.append(
                {
                    "Regime": f"Regime {r.regime_id + 1}\n({r.n_cases} cases)",
                    "Feature": feat,
                    "Z-Score": z_score,
                }
            )

    regime_df = pd.DataFrame(regime_data)
    regime_pivot = regime_df.pivot(index="Feature", columns="Regime", values="Z-Score")

    regime_pivot.plot(kind="barh", ax=axes[1], width=0.8)
    axes[1].axvline(0, color="k", linestyle="-", linewidth=0.5)
    axes[1].axvline(-1, color="r", linestyle="--", alpha=0.5, linewidth=0.5)
    axes[1].axvline(1, color="r", linestyle="--", alpha=0.5, linewidth=0.5)
    axes[1].set_xlabel("Z-Score (deviation from mean)")
    axes[1].set_title("Regime Characteristics\n(|Z| > 1 = distinctive)")
    axes[1].legend(loc="best", fontsize=8)

    plt.tight_layout()
    plt.savefig(output_dir / "error_regimes.png", dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"  Saved regime analysis to {output_dir / 'error_regimes.png'}")

    return results
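

# Hedged sketch of the thresholding-plus-clustering idea behind
# analyze_regimes, minus the plotting: keep cases above the 75th |bias|
# percentile, standardize the features, then cluster. Data are synthetic.
def _demo_regime_clustering() -> None:
    rng = np.random.default_rng(0)
    X = pd.DataFrame(rng.normal(size=(120, 2)), columns=["ABL_height", "wind_veer"])
    y_bias = rng.normal(size=120)

    threshold = np.percentile(np.abs(y_bias), 75)
    high = np.abs(y_bias) >= threshold
    X_scaled = StandardScaler().fit_transform(X[high])
    labels = KMeans(n_clusters=2, random_state=42, n_init=10).fit_predict(X_scaled)
    print(np.bincount(labels))  # number of high-bias cases per regime
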
def analyze_parameter_relationships(
    database: xr.Dataset,
    calibrator,
    features: list[str],
    output_dir: Path,
) -> list[ParameterRelationshipResult]:
    """
    Analyze how optimal parameters depend on atmospheric conditions.

    Only applicable for local calibration where optimal params vary per case.
    """
    print("--- Analyzing Parameter-Condition Relationships ---")
    results = []

    # Check if we have local calibration results
    if not hasattr(calibrator, "optimal_params_") or calibrator.optimal_params_ is None:
        print("  No local calibration results available")
        return results

    swept_params = calibrator.swept_params
    optimal_params = calibrator.optimal_params_

    # Get feature values (from sample=0, features don't depend on sample)
    X_df = database.isel(sample=0).to_dataframe().reset_index()
    X = X_df[features]

    n_params = len(swept_params)
    n_features = len(features)

    fig, axes = plt.subplots(
        n_params, n_features, figsize=(4 * n_features, 4 * n_params)
    )
    # Normalize to a 2-D array of Axes: plt.subplots squeezes singleton
    # dimensions and returns a bare Axes when both counts are 1.
    axes = np.asarray(axes).reshape(n_params, n_features)
    for p_idx, param in enumerate(swept_params):
        param_values = optimal_params[param]

        best_feature = None
        best_corr = 0

        for f_idx, feature in enumerate(features):
            ax = axes[p_idx, f_idx]
            feature_values = X[feature].values

            # Compute correlation
            corr = np.corrcoef(feature_values, param_values)[0, 1]
            if np.isnan(corr):
                corr = 0

            if abs(corr) > abs(best_corr):
                best_corr = corr
                best_feature = feature

            # Scatter plot with regression line
            ax.scatter(feature_values, param_values, alpha=0.4, s=10)

            # Add regression line
            if abs(corr) > 0.1:
                z = np.polyfit(feature_values, param_values, 1)
                p = np.poly1d(z)
                x_line = np.linspace(feature_values.min(), feature_values.max(), 100)
                ax.plot(x_line, p(x_line), "r-", linewidth=2, label=f"r={corr:.2f}")
                ax.legend(fontsize=8)

            ax.set_xlabel(feature)
            if f_idx == 0:
                ax.set_ylabel(f"Optimal {param}")
            ax.set_title(f"r = {corr:.3f}")
            ax.grid(True, alpha=0.3)

        # Store result for best correlated feature
        if best_feature:
            if abs(best_corr) > 0.5:
                rel_type = "strong positive" if best_corr > 0 else "strong negative"
            elif abs(best_corr) > 0.3:
                rel_type = "moderate positive" if best_corr > 0 else "moderate negative"
            else:
                rel_type = "weak"

            interpretation = interpret_parameter_relationship(
                param, best_feature, best_corr
            )

            results.append(
                ParameterRelationshipResult(
                    parameter=param,
                    most_influential_feature=best_feature,
                    correlation=float(best_corr),
                    relationship_type=rel_type,
                    physical_interpretation=interpretation,
                )
            )

    plt.suptitle(
        "Optimal Parameter vs. Atmospheric Conditions\n"
        "(Shows how model params should vary with conditions)",
        fontsize=12,
    )
    plt.tight_layout()
    plt.savefig(
        output_dir / "parameter_relationships.png", dpi=150, bbox_inches="tight"
    )
    plt.close(fig)
    print(
        f"  Saved parameter relationships to {output_dir / 'parameter_relationships.png'}"
    )

    return results
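

# Hedged sketch of the calibrator contract analyze_parameter_relationships
# relies on: anything exposing `swept_params` and a per-case `optimal_params_`
# mapping works. SimpleNamespace stands in for a real fitted calibrator, the
# size-1 "sample" dimension mirrors the pipeline's layout, and the synthetic
# optimal k_b is made (noisily) proportional to turbulence intensity.
def _demo_parameter_relationships(out_dir: Path = Path("param_rel_demo")) -> None:
    from types import SimpleNamespace

    rng = np.random.default_rng(0)
    n = 60
    ti = rng.uniform(0.02, 0.14, n)
    database = xr.Dataset(
        {"turbulence_intensity": ("case", ti)},
        coords={"case": np.arange(n), "sample": [0]},
    )
    calibrator = SimpleNamespace(
        swept_params=["k_b"],
        optimal_params_={"k_b": 0.04 + 0.5 * ti + rng.normal(scale=0.005, size=n)},
    )
    out_dir.mkdir(parents=True, exist_ok=True)
    for r in analyze_parameter_relationships(
        database, calibrator, ["turbulence_intensity"], out_dir
    ):
        print(r.parameter, r.most_influential_feature, f"r={r.correlation:.2f}")
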
def generate_summary(report: PhysicsInsightsReport) -> str:
    """Generate executive summary of physics insights."""
    lines = []

    # Summarize PD findings
    if report.partial_dependence:
        increasing = [
            r for r in report.partial_dependence if r.bias_direction == "increases"
        ]
        decreasing = [
            r for r in report.partial_dependence if r.bias_direction == "decreases"
        ]

        if increasing:
            feats = ", ".join(r.feature for r in increasing[:2])
            lines.append(f"• Model bias increases with {feats}")
        if decreasing:
            feats = ", ".join(r.feature for r in decreasing[:2])
            lines.append(f"• Model bias decreases with {feats}")

    # Summarize interactions
    if report.interactions:
        top = report.interactions[0]
        lines.append(f"• Strongest interaction: {top.feature_1} × {top.feature_2}")

    # Summarize regimes
    if report.regimes:
        n_regimes = len(report.regimes)
        lines.append(f"• Identified {n_regimes} distinct error regimes")
        for r in report.regimes:
            lines.append(f"  - Regime {r.regime_id + 1}: {r.description}")

    # Summarize parameter insights
    if report.parameter_relationships:
        for r in report.parameter_relationships:
            if "strong" in r.relationship_type:
                lines.append(
                    f"• {r.parameter} should be condition-dependent "
                    f"(varies with {r.most_influential_feature})"
                )

    return "\n".join(lines)


# =============================================================================
# Main Entry Point
# =============================================================================


def run_physics_insights(
    database: xr.Dataset,
    fitted_model,
    calibrator,
    features_list: list[str],
    y_bias: np.ndarray,
    output_dir: Path,
    config: dict | None = None,
) -> PhysicsInsightsReport:
    """
    Run complete physics insights analysis.

    Args:
        database: xarray Dataset with model error database
        fitted_model: Trained ML model (or pipeline with 'model' step)
        calibrator: Fitted calibrator (for parameter relationships)
        features_list: List of feature names used
        y_bias: Bias values (predictions or actuals)
        output_dir: Directory for output plots
        config: Optional configuration dict

    Returns:
        PhysicsInsightsReport with all analysis results
    """
    config = config or {}
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("=" * 60)
    print("PHYSICS INSIGHTS ANALYSIS")
    print("=" * 60)

    # Prepare feature matrix
    X_df = database.isel(sample=0).to_dataframe().reset_index()
    X = X_df[features_list]

    # Extract model from pipeline if needed
    if hasattr(fitted_model, "named_steps") and "model" in fitted_model.named_steps:
        model = fitted_model.named_steps["model"]
    else:
        model = fitted_model

    report = PhysicsInsightsReport()

    # 1. Partial Dependence
    if config.get("partial_dependence", {}).get("enabled", True):
        pd_features = config.get("partial_dependence", {}).get(
            "features", features_list
        )
        report.partial_dependence = analyze_partial_dependence(
            model=fitted_model,  # Use full pipeline for PD
            X=X,
            features=pd_features,
            output_dir=output_dir,
        )

    # 2. Interactions (requires tree model)
    if config.get("interactions", {}).get("enabled", True):
        if hasattr(model, "feature_importances_"):  # Tree-based model
            report.interactions = analyze_interactions(
                model=model,
                X=X,
                features=features_list,
                output_dir=output_dir,
                top_n=config.get("interactions", {}).get("top_n", 5),
            )
        else:
            print("  Skipping interactions (requires tree-based model)")

    # 3. Regime Analysis
    if config.get("regime_analysis", {}).get("enabled", True):
        report.regimes = analyze_regimes(
            X=X,
            y_bias=y_bias,
            features=features_list,
            output_dir=output_dir,
            n_clusters=config.get("regime_analysis", {}).get("n_clusters", 3),
            bias_percentile=config.get("regime_analysis", {}).get(
                "bias_percentile", 75
            ),
        )

    # 4. Parameter Relationships (local calibration only)
    if config.get("parameter_relationships", {}).get("enabled", True):
        if hasattr(calibrator, "optimal_params_"):
            report.parameter_relationships = analyze_parameter_relationships(
                database=database,
                calibrator=calibrator,
                features=features_list,
                output_dir=output_dir,
            )

    # Generate summary
    report.summary = generate_summary(report)
    # Save report (UTF-8: the markdown contains non-ASCII symbols like × and →)
    report_md = report.to_markdown()
    with open(output_dir / "physics_insights_report.md", "w", encoding="utf-8") as f:
        f.write(report_md)
    print(f"\nSaved report to {output_dir / 'physics_insights_report.md'}")

    # Save as JSON for programmatic access
    import json

    with open(output_dir / "physics_insights.json", "w", encoding="utf-8") as f:
        json.dump(report.to_dict(), f, indent=2)
    print(f"Saved JSON to {output_dir / 'physics_insights.json'}")
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(report.summary)

    return report
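

# Hedged end-to-end sketch on synthetic data (the output path is arbitrary).
# A real call would pass the model-error database, the fitted bias model, and
# the calibrator from the WIFA-UQ pipeline; calibrator=None simply disables
# the parameter-relationship step via the hasattr guard above.
def _demo_run_physics_insights(out_dir: Path = Path("physics_insights_demo")) -> None:
    from sklearn.ensemble import RandomForestRegressor

    rng = np.random.default_rng(0)
    n = 80
    database = xr.Dataset(
        {
            "ABL_height": ("case", rng.uniform(200.0, 1500.0, n)),
            "wind_veer": ("case", rng.uniform(0.0, 0.1, n)),
        },
        coords={"case": np.arange(n), "sample": [0]},
    )
    features = ["ABL_height", "wind_veer"]
    X = database.isel(sample=0).to_dataframe().reset_index()[features]

    # Synthetic "bias" that grows weakly with ABL height, plus noise
    y_bias = 1e-5 * (X["ABL_height"].to_numpy() - 800.0) + rng.normal(
        scale=0.01, size=n
    )
    model = RandomForestRegressor(n_estimators=50, random_state=0).fit(X, y_bias)

    run_physics_insights(
        database,
        model,
        calibrator=None,
        features_list=features,
        y_bias=y_bias,
        output_dir=out_dir,
    )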