Coverage for wifa_uq/postprocessing/physics_insights/physics_insights.py: 81%

518 statements  

coverage.py v7.13.0, created at 2025-12-19 02:10 +0000

# wifa_uq/postprocessing/physics_insights/physics_insights.py

"""
Physics Insights Module for WIFA-UQ.

Extracts interpretable physical insights from bias prediction models:
    1. Partial Dependence Analysis - How bias varies with each feature
    2. Feature Interactions - Which feature combinations drive error
    3. Regime Identification - Clustering of high-bias conditions
    4. Parameter-Condition Relationships - How optimal params depend on conditions

These analyses transform ML results into actionable physics understanding.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from sklearn.cluster import KMeans
from sklearn.inspection import partial_dependence
from sklearn.preprocessing import StandardScaler

try:
    import shap

    HAS_SHAP = True
except ImportError:
    HAS_SHAP = False
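# NOTE: shap is an optional dependency. HAS_SHAP gates the SHAP-based
# interaction analysis below; analyze_interactions() returns an empty list
# (rather than raising) when shap is not installed.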

# =============================================================================
# Data Classes for Results
# =============================================================================


@dataclass
class PartialDependenceResult:
    """Results from partial dependence analysis."""

    feature: str
    grid_values: np.ndarray
    pd_values: np.ndarray
    bias_direction: str  # "increases", "decreases", "non-monotonic"
    effect_magnitude: float  # Range of PD values
    physical_interpretation: str


@dataclass
class InteractionResult:
    """Results from interaction analysis."""

    feature_1: str
    feature_2: str
    interaction_strength: float
    description: str


@dataclass
class RegimeResult:
    """Results from regime identification."""

    regime_id: int
    n_cases: int
    mean_bias: float
    feature_centroids: dict[str, float]
    description: str


@dataclass
class ParameterRelationshipResult:
    """Results from parameter-condition analysis."""

    parameter: str
    most_influential_feature: str
    correlation: float
    relationship_type: str  # "positive", "negative", "weak"
    physical_interpretation: str


@dataclass
class PhysicsInsightsReport:
    """Complete physics insights report."""

    partial_dependence: list[PartialDependenceResult] = field(default_factory=list)
    interactions: list[InteractionResult] = field(default_factory=list)
    regimes: list[RegimeResult] = field(default_factory=list)
    parameter_relationships: list[ParameterRelationshipResult] = field(
        default_factory=list
    )
    summary: str = ""

    def to_dict(self) -> dict:
        """Convert to dictionary for serialization."""
        return {
            "partial_dependence": [
                {
                    "feature": r.feature,
                    "bias_direction": r.bias_direction,
                    "effect_magnitude": r.effect_magnitude,
                    "interpretation": r.physical_interpretation,
                }
                for r in self.partial_dependence
            ],
            "interactions": [
                {
                    "features": [r.feature_1, r.feature_2],
                    "strength": r.interaction_strength,
                    "description": r.description,
                }
                for r in self.interactions
            ],
            "regimes": [
                {
                    "regime_id": r.regime_id,
                    "n_cases": r.n_cases,
                    "mean_bias": r.mean_bias,
                    "centroids": r.feature_centroids,
                    "description": r.description,
                }
                for r in self.regimes
            ],
            "parameter_relationships": [
                {
                    "parameter": r.parameter,
                    "most_influential_feature": r.most_influential_feature,
                    "correlation": r.correlation,
                    "interpretation": r.physical_interpretation,
                }
                for r in self.parameter_relationships
            ],
            "summary": self.summary,
        }

    def to_markdown(self) -> str:
        """Generate markdown report."""
        lines = ["# Physics Insights Report\n"]

        if self.partial_dependence:
            lines.append("## 1. How Model Bias Depends on Atmospheric Conditions\n")
            for r in self.partial_dependence:
                lines.append(f"### {r.feature}")
                lines.append(
                    f"- **Direction**: Bias {r.bias_direction} with {r.feature}"
                )
                lines.append(f"- **Effect magnitude**: {r.effect_magnitude:.4f}")
                lines.append(f"- **Interpretation**: {r.physical_interpretation}\n")

        if self.interactions:
            lines.append("## 2. Feature Interactions\n")
            for r in self.interactions:
                lines.append(
                    f"- **{r.feature_1} × {r.feature_2}**: "
                    f"strength = {r.interaction_strength:.4f}"
                )
                lines.append(f"  - {r.description}\n")

        if self.regimes:
            lines.append("## 3. Error Regimes (Failure Modes)\n")
            for r in self.regimes:
                lines.append(f"### Regime {r.regime_id + 1}: {r.description}")
                lines.append(f"- Cases: {r.n_cases}")
                lines.append(f"- Mean bias: {r.mean_bias:.4f}")
                lines.append("- Characteristic conditions:")
                for feat, val in r.feature_centroids.items():
                    lines.append(f"  - {feat}: {val:.3f}")
                lines.append("")

        if self.parameter_relationships:
            lines.append("## 4. Optimal Parameter Dependencies\n")
            lines.append(
                "*How the 'best' wake model parameters vary with conditions:*\n"
            )
            for r in self.parameter_relationships:
                lines.append(f"### {r.parameter}")
                lines.append(
                    f"- Most influential feature: **{r.most_influential_feature}**"
                )
                lines.append(
                    f"- Correlation: {r.correlation:.3f} ({r.relationship_type})"
                )
                lines.append(f"- **Implication**: {r.physical_interpretation}\n")

        if self.summary:
            lines.append("## Summary\n")
            lines.append(self.summary)

        return "\n".join(lines)
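# A minimal sketch of assembling a report by hand (all values hypothetical;
# in normal use run_physics_insights() at the bottom of this module builds
# the report for you):
#
#     report = PhysicsInsightsReport()
#     report.partial_dependence.append(
#         PartialDependenceResult(
#             feature="wind_veer",
#             grid_values=np.linspace(0.0, 0.02, 5),
#             pd_values=np.linspace(-0.01, 0.03, 5),
#             bias_direction="increases",
#             effect_magnitude=0.04,
#             physical_interpretation="...",
#         )
#     )
#     print(report.to_markdown())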

# =============================================================================
# Physical Interpretation Helpers
# =============================================================================

# Domain knowledge for automatic interpretation
FEATURE_PHYSICS = {
    "ABL_height": {
        "high": "stable stratification / nocturnal conditions",
        "low": "convective / well-mixed conditions",
        "unit": "m",
    },
    "wind_veer": {
        "high": "strong directional shear / Ekman spiral",
        "low": "uniform wind direction with height",
        "unit": "deg/m",
    },
    "lapse_rate": {
        "high": "stable stratification (positive dθ/dz)",
        "low": "unstable/neutral conditions",
        "unit": "K/m",
    },
    "turbulence_intensity": {
        "high": "high ambient turbulence / faster wake recovery",
        "low": "low turbulence / slower wake recovery",
        "unit": "-",
    },
    "Blockage_Ratio": {
        "high": "significant upstream blockage",
        "low": "front-row or minimal blockage",
        "unit": "-",
    },
    "Blocking_Distance": {
        "high": "far from upstream turbines",
        "low": "close to upstream turbines",
        "unit": "-",
    },
    "Farm_Length": {
        "high": "deep array (many rows)",
        "low": "shallow array",
        "unit": "D",
    },
    "Farm_Width": {
        "high": "wide array",
        "low": "narrow array",
        "unit": "D",
    },
}

PARAMETER_PHYSICS = {
    "k_b": {
        "name": "Wake expansion coefficient",
        "increases_with": "faster wake recovery / higher turbulence",
        "decreases_with": "slower wake recovery / stable conditions",
    },
    "ss_alpha": {
        "name": "Self-similarity blockage parameter",
        "increases_with": "stronger blockage effects",
        "decreases_with": "weaker blockage effects",
    },
    "ceps": {
        "name": "Added turbulence coefficient",
        "increases_with": "more wake-added turbulence",
        "decreases_with": "less wake-added turbulence",
    },
}

def interpret_pd_direction(
    feature: str,
    direction: str,  # "increases", "decreases", "non-monotonic", or "flat"
) -> str:
    """Generate physical interpretation of partial dependence direction."""
    physics = FEATURE_PHYSICS.get(feature, {})
    high_meaning = physics.get("high", f"high {feature}")
    low_meaning = physics.get("low", f"low {feature}")

    if direction == "increases":
        return (
            f"Model bias increases with {feature}, suggesting the wake model "
            f"systematically underestimates wake effects in {high_meaning} "
            f"(and/or overestimates in {low_meaning})."
        )
    elif direction == "decreases":
        return (
            f"Model bias decreases with {feature}, suggesting the wake model "
            f"systematically overestimates wake effects in {high_meaning} "
            f"(and/or underestimates in {low_meaning})."
        )
    elif direction == "non-monotonic":
        return (
            f"Model bias shows a non-monotonic relationship with {feature}, "
            f"suggesting different error mechanisms at {high_meaning} vs {low_meaning}."
        )
    else:  # "flat"
        return f"Model bias is largely insensitive to {feature}."
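# For example (output abridged), an "increases" direction for a feature listed
# in FEATURE_PHYSICS reads back its domain description:
#
#     interpret_pd_direction("turbulence_intensity", "increases")
#     # -> "Model bias increases with turbulence_intensity, suggesting the
#     #     wake model systematically underestimates wake effects in high
#     #     ambient turbulence / faster wake recovery ..."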

def interpret_parameter_relationship(
    parameter: str,
    feature: str,
    correlation: float,
) -> str:
    """Generate physical interpretation of parameter-feature relationship."""
    param_physics = PARAMETER_PHYSICS.get(parameter, {"name": parameter})
    feat_physics = FEATURE_PHYSICS.get(feature, {})

    param_name = param_physics.get("name", parameter)
    high_meaning = feat_physics.get("high", f"high {feature}")

    if abs(correlation) < 0.3:
        return f"{param_name} shows weak dependence on {feature}."

    if correlation > 0:
        return (
            f"{param_name} should increase in {high_meaning}. "
            f"This suggests the model's default parameterization underestimates "
            f"wake recovery rate in these conditions."
        )
    else:
        return (
            f"{param_name} should decrease in {high_meaning}. "
            f"This suggests the model's default parameterization overestimates "
            f"wake recovery rate in these conditions."
        )
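# The |r| < 0.3 cutoff mirrors the "weak" bucket used in
# analyze_parameter_relationships() below. For instance (hypothetical values),
# r = 0.2 yields the weak-dependence message, while r = 0.6 produces the
# "should increase" guidance:
#
#     interpret_parameter_relationship("k_b", "turbulence_intensity", 0.2)
#     interpret_parameter_relationship("k_b", "turbulence_intensity", 0.6)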

def describe_regime(
    centroids: dict[str, float],
    feature_stats: dict[str, tuple[float, float]],  # {feature: (mean, std)}
    mean_bias: float,
) -> str:
    """Generate description of an error regime based on its characteristics."""
    descriptions = []

    for feature, value in centroids.items():
        if feature not in feature_stats:
            continue
        mean, std = feature_stats[feature]
        if std == 0:
            continue

        z_score = (value - mean) / std
        physics = FEATURE_PHYSICS.get(feature, {})

        if z_score > 1.0:
            descriptions.append(physics.get("high", f"high {feature}"))
        elif z_score < -1.0:
            descriptions.append(physics.get("low", f"low {feature}"))

    if not descriptions:
        if mean_bias > 0:
            return "Mixed conditions with positive bias (model overestimates)"
        else:
            return "Mixed conditions with negative bias (model underestimates)"

    bias_desc = "overestimation" if mean_bias > 0 else "underestimation"
    return f"{', '.join(descriptions[:2])} → {bias_desc}"
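# Sketch: a centroid more than one standard deviation above the dataset mean
# picks up the "high" physics label (all numbers hypothetical):
#
#     describe_regime(
#         centroids={"ABL_height": 1500.0},
#         feature_stats={"ABL_height": (800.0, 300.0)},  # z ≈ 2.3 -> "high"
#         mean_bias=-0.02,
#     )
#     # -> "stable stratification / nocturnal conditions → underestimation"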

def describe_regime_relative(
    centroids: dict[str, float],
    all_cluster_info: list[dict],
    feature_stats: dict[str, tuple[float, float]],
    mean_bias: float,
    cluster_id: int,
) -> str:
    """
    Generate description of an error regime based on RELATIVE differences to other clusters.

    This is more informative when clusters are similar in absolute terms but differ
    from each other in specific ways.
    """
    # First try absolute description (z-scores > 1)
    descriptions = []
    for feature, value in centroids.items():
        if feature not in feature_stats:
            continue
        mean, std = feature_stats[feature]
        if std == 0:
            continue
        z_score = (value - mean) / std
        physics = FEATURE_PHYSICS.get(feature, {})
        if z_score > 1.0:
            descriptions.append(physics.get("high", f"high {feature}"))
        elif z_score < -1.0:
            descriptions.append(physics.get("low", f"low {feature}"))

    # If we found distinctive absolute features, use those
    if descriptions:
        bias_desc = "overestimation" if mean_bias > 0 else "underestimation"
        return f"{', '.join(descriptions[:2])} → {bias_desc}"

    # Otherwise, find what makes THIS cluster different from OTHERS
    other_clusters = [c for c in all_cluster_info if c["id"] != cluster_id]
    if not other_clusters:
        bias_desc = "overestimation" if mean_bias > 0 else "underestimation"
        return f"High-error cases → {bias_desc}"

    # Compute relative differences for each feature
    relative_diffs = []
    for feature, value in centroids.items():
        if feature not in feature_stats:
            continue
        mean, std = feature_stats[feature]
        if std == 0:
            continue

        # Average value in other clusters
        other_avg = np.mean([c["centroids"][feature] for c in other_clusters])
        diff = (value - other_avg) / std if std > 0 else 0
        relative_diffs.append((feature, diff, value, other_avg))

    # Sort by absolute difference
    relative_diffs.sort(key=lambda x: abs(x[1]), reverse=True)

    # Build description from top distinguishing features
    distinguishing = []
    for feature, diff, val, other_avg in relative_diffs[:2]:
        if abs(diff) < 0.3:  # Not meaningfully different
            continue
        physics = FEATURE_PHYSICS.get(feature, {})
        if diff > 0:
            desc = physics.get("high", f"higher {feature}")
            distinguishing.append(f"{desc} (vs other regimes)")
        else:
            desc = physics.get("low", f"lower {feature}")
            distinguishing.append(f"{desc} (vs other regimes)")

    bias_desc = "overestimation" if mean_bias > 0 else "underestimation"

    if distinguishing:
        return f"{'; '.join(distinguishing)} → {bias_desc}"

    # Last resort: describe by bias magnitude
    other_biases = [c["mean_bias"] for c in other_clusters]
    avg_other_bias = np.mean(other_biases)

    if abs(mean_bias) > abs(avg_other_bias) * 1.5:
        return f"Highest error magnitude ({mean_bias:.4f}) → {bias_desc}"
    elif abs(mean_bias) < abs(avg_other_bias) * 0.7:
        return f"Lower error magnitude ({mean_bias:.4f}) → {bias_desc}"
    else:
        return f"Similar conditions, bias={mean_bias:.4f} → {bias_desc}"

# =============================================================================
# Analysis Functions
# =============================================================================

def _manual_partial_dependence(
    model,
    X: pd.DataFrame,
    features: list[str],
    grid_resolution: int = 50,
) -> dict:
    """
    Manual partial dependence calculation as fallback.

    For each feature, creates a grid of values and computes mean prediction
    while averaging over all other features.
    """
    results = {
        "average": [],
        "grid_values": [],
    }

    X_array = X.values

    for feature in features:
        feat_idx = X.columns.get_loc(feature)
        feat_values = X[feature].values

        # Create grid
        grid = np.linspace(feat_values.min(), feat_values.max(), grid_resolution)

        # Compute partial dependence
        pd_values = []
        for grid_val in grid:
            # Create modified X with feature set to grid value
            X_modified = X_array.copy()
            X_modified[:, feat_idx] = grid_val

            # Predict and average
            predictions = model.predict(X_modified)
            pd_values.append(predictions.mean())

        results["average"].append(np.array(pd_values))
        results["grid_values"].append(grid)

    return results
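# A minimal sketch of the fallback on synthetic data (hypothetical model and
# frame; any fitted estimator with .predict() works):
#
#     from sklearn.ensemble import RandomForestRegressor
#     rng = np.random.default_rng(0)
#     X_demo = pd.DataFrame({"a": rng.random(50), "b": rng.random(50)})
#     rf = RandomForestRegressor(n_estimators=10).fit(X_demo, 2 * X_demo["a"])
#     curves = _manual_partial_dependence(rf, X_demo, ["a"], grid_resolution=10)
#     # curves["average"][0] rises roughly linearly with curves["grid_values"][0]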

def analyze_partial_dependence(
    model,
    X: pd.DataFrame,
    features: list[str],
    output_dir: Path,
    grid_resolution: int = 50,
) -> list[PartialDependenceResult]:
    """
    Compute partial dependence and extract physical interpretations.

    Shows how predicted bias changes with each feature, holding others constant.
    """
    print("--- Analyzing Partial Dependence ---")
    results = []

    # Ensure X is a DataFrame with correct types
    X = X.copy()
    for col in X.columns:
        if X[col].dtype == "object":
            X[col] = pd.to_numeric(X[col], errors="coerce")
    X = X.astype(float)

    # Verify model is fitted by trying a prediction
    try:
        _ = model.predict(X.iloc[:1])
    except Exception as e:
        print(f" WARNING: Model does not appear to be fitted: {e}")
        raise ValueError(
            f"Model must be fitted before partial dependence analysis: {e}"
        ) from e

    n_features = len(features)
    n_cols = min(3, n_features)
    n_rows = (n_features + n_cols - 1) // n_cols

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows))
    if n_features == 1:
        axes = np.array([[axes]])
    elif n_rows == 1:
        axes = axes.reshape(1, -1)

    for idx, feature in enumerate(features):
        row, col = divmod(idx, n_cols)
        ax = axes[row, col]

        # Compute PD for this single feature (avoids sklearn multi-feature output issues)
        feature_idx = X.columns.get_loc(feature)

        try:
            pd_result = partial_dependence(
                model, X, features=[feature_idx], grid_resolution=grid_resolution
            )
            pd_values = pd_result["average"][0]
            grid_values = pd_result["grid_values"][0]
        except Exception as e:
            print(f" WARNING: partial_dependence failed for {feature}: {e}")
            print(" Falling back to manual PD calculation...")
            manual_result = _manual_partial_dependence(
                model, X, [feature], grid_resolution
            )
            pd_values = manual_result["average"][0]
            grid_values = manual_result["grid_values"][0]

        # Flatten if needed (sklearn may return 2D arrays)
        pd_values = np.asarray(pd_values).flatten()
        grid_values = np.asarray(grid_values).flatten()

        # Determine direction
        n_points = len(pd_values)
        n_edge = min(5, n_points // 4) if n_points > 4 else 1

        start_val = pd_values[:n_edge].mean()
        end_val = pd_values[-n_edge:].mean()
        mid_val = pd_values[n_points // 2] if n_points > 2 else pd_values.mean()

        value_range = pd_values.max() - pd_values.min()
        threshold = 0.01 * value_range if value_range > 0 else 1e-10

        if end_val > start_val + threshold:
            direction = "increases"
        elif end_val < start_val - threshold:
            direction = "decreases"
        else:
            # Check for non-monotonicity
            if abs(mid_val - start_val) > abs(end_val - start_val):
                direction = "non-monotonic"
            else:
                direction = "flat"

        effect_magnitude = float(value_range) if value_range > 0 else 0.0
        interpretation = interpret_pd_direction(feature, direction)

        results.append(
            PartialDependenceResult(
                feature=feature,
                grid_values=grid_values,
                pd_values=pd_values,
                bias_direction=direction,
                effect_magnitude=float(effect_magnitude),
                physical_interpretation=interpretation,
            )
        )

        # Plot
        ax.plot(grid_values, pd_values, "b-", linewidth=2)
        ax.fill_between(grid_values, pd_values, alpha=0.3)
        ax.axhline(0, color="k", linestyle="--", alpha=0.5)

        # Add direction annotation
        direction_symbol = {
            "increases": "↗",
            "decreases": "↘",
            "non-monotonic": "↝",
            "flat": "→",
        }
        ax.set_title(f"{feature} {direction_symbol.get(direction, '')}")
        ax.set_xlabel(feature)
        ax.set_ylabel("Partial Dependence (Bias)")
        ax.grid(True, alpha=0.3)

        # Add interpretation as text box
        textstr = f"Effect: {effect_magnitude:.4f}\n{direction}"
        props = dict(boxstyle="round", facecolor="wheat", alpha=0.5)
        ax.text(
            0.02,
            0.98,
            textstr,
            transform=ax.transAxes,
            fontsize=8,
            verticalalignment="top",
            bbox=props,
        )

    # Hide unused subplots
    for idx in range(n_features, n_rows * n_cols):
        row, col = divmod(idx, n_cols)
        axes[row, col].set_visible(False)

    plt.suptitle("Partial Dependence: How Bias Varies with Each Feature", fontsize=14)
    plt.tight_layout()
    plt.savefig(output_dir / "partial_dependence.png", dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(
        f" Saved partial dependence plot to {output_dir / 'partial_dependence.png'}"
    )

    return results
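# Sketch of a standalone call (hypothetical objects): `reg` is any fitted
# sklearn regressor or pipeline and `X` holds the feature columns:
#
#     pd_results = analyze_partial_dependence(reg, X, list(X.columns), Path("out"))
#     for r in pd_results:
#         print(r.feature, r.bias_direction, r.effect_magnitude)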

def analyze_interactions(
    model,
    X: pd.DataFrame,
    features: list[str],
    output_dir: Path,
    top_n: int = 5,
) -> list[InteractionResult]:
    """
    Analyze feature interactions using SHAP interaction values.

    Identifies which feature combinations jointly drive bias.
    """
    print("--- Analyzing Feature Interactions ---")

    if not HAS_SHAP:
        print(" SHAP not available, skipping interaction analysis")
        return []

    results = []

    try:
        # Get SHAP interaction values
        explainer = shap.TreeExplainer(model)
        shap_interaction = explainer.shap_interaction_values(
            X.values[:500]
        )  # Limit for speed

        # Average absolute interaction strength
        n_features = len(features)
        interaction_matrix = np.zeros((n_features, n_features))

        for i in range(n_features):
            for j in range(n_features):
                if i != j:
                    interaction_matrix[i, j] = np.abs(shap_interaction[:, i, j]).mean()

        # Find top interactions
        interactions_flat = []
        for i in range(n_features):
            for j in range(i + 1, n_features):
                interactions_flat.append(
                    (
                        features[i],
                        features[j],
                        interaction_matrix[i, j] + interaction_matrix[j, i],
                    )
                )

        interactions_flat.sort(key=lambda x: x[2], reverse=True)

        for feat1, feat2, strength in interactions_flat[:top_n]:
            physics1 = FEATURE_PHYSICS.get(feat1, {}).get("high", f"high {feat1}")
            physics2 = FEATURE_PHYSICS.get(feat2, {}).get("high", f"high {feat2}")

            description = (
                f"Combined effect of {physics1} and {physics2} "
                f"creates bias beyond individual effects."
            )

            results.append(
                InteractionResult(
                    feature_1=feat1,
                    feature_2=feat2,
                    interaction_strength=float(strength),
                    description=description,
                )
            )

        # Plot interaction heatmap
        fig, ax = plt.subplots(figsize=(10, 8))
        im = ax.imshow(interaction_matrix, cmap="YlOrRd")

        ax.set_xticks(range(n_features))
        ax.set_yticks(range(n_features))
        ax.set_xticklabels(features, rotation=45, ha="right")
        ax.set_yticklabels(features)

        # Add values
        for i in range(n_features):
            for j in range(n_features):
                val = interaction_matrix[i, j]
                color = "white" if val > interaction_matrix.max() / 2 else "black"
                ax.text(
                    j,
                    i,
                    f"{val:.3f}",
                    ha="center",
                    va="center",
                    color=color,
                    fontsize=8,
                )

        plt.colorbar(im, ax=ax, label="Interaction Strength")
        ax.set_title(
            "Feature Interaction Strengths\n(Higher = stronger combined effect on bias)"
        )
        plt.tight_layout()
        plt.savefig(
            output_dir / "feature_interactions.png", dpi=150, bbox_inches="tight"
        )
        plt.close(fig)
        print(
            f" Saved interaction plot to {output_dir / 'feature_interactions.png'}"
        )

    except Exception as e:
        print(f" Interaction analysis failed: {e}")

    return results
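# NOTE: shap.TreeExplainer only supports tree ensembles (random forests,
# gradient boosting, etc.); run_physics_insights() below gates this call on
# the model exposing feature_importances_, and any other failure is caught
# and reported rather than aborting the run.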

def analyze_regimes(
    X: pd.DataFrame,
    y_bias: np.ndarray,
    features: list[str],
    output_dir: Path,
    n_clusters: int = 3,
    bias_percentile: float = 75,
) -> list[RegimeResult]:
    """
    Identify distinct error regimes through clustering.

    Clusters high-bias cases to find systematic failure modes.
    """

    print("--- Analyzing Error Regimes ---")
    results = []

    # Focus on high-bias cases
    bias_threshold = np.percentile(np.abs(y_bias), bias_percentile)
    high_bias_mask = np.abs(y_bias) >= bias_threshold

    X_high = X[high_bias_mask].copy()
    y_high = y_bias[high_bias_mask]

    # Hard minimum: need at least 10 high-bias cases
    if len(X_high) < 10:
        print(
            f" WARNING: Only {len(X_high)} high-bias cases. Need at least 10 for meaningful regime analysis."
        )
        print(
            " Skipping regime clustering. Consider lowering bias_percentile or getting more data."
        )

        # Still report the high-bias cases as a single "regime" for visibility
        if len(X_high) > 0:
            centroids = {f: float(X_high[f].mean()) for f in features}
            mean_bias = float(y_high.mean())
            feature_stats = {f: (X[f].mean(), X[f].std()) for f in features}

            results.append(
                RegimeResult(
                    regime_id=0,
                    n_cases=len(X_high),
                    mean_bias=mean_bias,
                    feature_centroids=centroids,
                    description=f"All {len(X_high)} high-bias cases (too few for clustering). Mean bias: {mean_bias:.4f}",
                )
            )

            # Create simple summary plot
            fig, ax = plt.subplots(figsize=(10, 5))

            z_scores = []
            for f in features:
                mean, std = feature_stats[f]
                z = (centroids[f] - mean) / std if std > 0 else 0
                z_scores.append(z)

            colors = ["red" if abs(z) > 1 else "steelblue" for z in z_scores]
            ax.barh(features, z_scores, color=colors)
            ax.axvline(0, color="k", linestyle="-", linewidth=0.5)
            ax.axvline(-1, color="gray", linestyle="--", alpha=0.5)
            ax.axvline(1, color="gray", linestyle="--", alpha=0.5)
            ax.set_xlabel("Z-Score (deviation from dataset mean)")
            ax.set_title(
                f"High-Bias Cases Characteristics (n={len(X_high)})\nRed = distinctive (|Z| > 1)"
            )

            plt.tight_layout()
            plt.savefig(output_dir / "error_regimes.png", dpi=150, bbox_inches="tight")
            plt.close(fig)
            print(f" Saved high-bias summary to {output_dir / 'error_regimes.png'}")

        return results

    # Adjust n_clusters based on available data
    # Rule: at least 5 cases per cluster
    max_clusters = len(X_high) // 5
    if max_clusters < n_clusters:
        print(
            f" Reducing clusters from {n_clusters} to {max_clusters} (need 5+ cases per cluster)"
        )
        n_clusters = max(2, max_clusters)

    # Standardize for clustering
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X_high[features])

    # Cluster
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    cluster_labels = kmeans.fit_predict(X_scaled)

    # Compute feature statistics for interpretation
    feature_stats = {f: (X[f].mean(), X[f].std()) for f in features}

    # First pass: collect all cluster info
    cluster_info = []
    for cluster_id in range(n_clusters):
        mask = cluster_labels == cluster_id
        cluster_X = X_high[mask]
        cluster_y = y_high[mask]
        centroids = {f: float(cluster_X[f].mean()) for f in features}
        mean_bias = float(cluster_y.mean())
        cluster_info.append(
            {
                "id": cluster_id,
                "n_cases": int(mask.sum()),
                "mean_bias": mean_bias,
                "centroids": centroids,
            }
        )

    # Second pass: describe regimes using RELATIVE differences between clusters
    for info in cluster_info:
        description = describe_regime_relative(
            info["centroids"],
            cluster_info,
            feature_stats,
            info["mean_bias"],
            info["id"],
        )

        results.append(
            RegimeResult(
                regime_id=info["id"],
                n_cases=info["n_cases"],
                mean_bias=info["mean_bias"],
                feature_centroids=info["centroids"],
                description=description,
            )
        )

    # Plot regime analysis
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left: PCA projection of clusters
    from sklearn.decomposition import PCA

    pca = PCA(n_components=2)
    X_pca = pca.fit_transform(X_scaled)

    scatter = axes[0].scatter(
        X_pca[:, 0], X_pca[:, 1], c=cluster_labels, cmap="viridis", alpha=0.6, s=30
    )
    axes[0].set_xlabel(f"PC1 ({pca.explained_variance_ratio_[0]:.1%} variance)")
    axes[0].set_ylabel(f"PC2 ({pca.explained_variance_ratio_[1]:.1%} variance)")
    axes[0].set_title(f"High-Bias Cases Clustered into {n_clusters} Regimes")
    plt.colorbar(scatter, ax=axes[0], label="Regime")

    # Right: Regime characteristics bar chart
    regime_data = []
    for r in results:
        for feat, val in r.feature_centroids.items():
            mean, std = feature_stats[feat]
            z_score = (val - mean) / std if std > 0 else 0
            regime_data.append(
                {
                    "Regime": f"Regime {r.regime_id + 1}\n({r.n_cases} cases)",
                    "Feature": feat,
                    "Z-Score": z_score,
                }
            )

    regime_df = pd.DataFrame(regime_data)
    regime_pivot = regime_df.pivot(index="Feature", columns="Regime", values="Z-Score")

    regime_pivot.plot(kind="barh", ax=axes[1], width=0.8)
    axes[1].axvline(0, color="k", linestyle="-", linewidth=0.5)
    axes[1].axvline(-1, color="r", linestyle="--", alpha=0.5, linewidth=0.5)
    axes[1].axvline(1, color="r", linestyle="--", alpha=0.5, linewidth=0.5)
    axes[1].set_xlabel("Z-Score (deviation from mean)")
    axes[1].set_title("Regime Characteristics\n(|Z| > 1 = distinctive)")
    axes[1].legend(loc="best", fontsize=8)

    plt.tight_layout()
    plt.savefig(output_dir / "error_regimes.png", dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" Saved regime analysis to {output_dir / 'error_regimes.png'}")

    return results
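# Clustering notes: with the default bias_percentile=75, the top quartile of
# |bias| cases is clustered; n_clusters is capped at len(X_high) // 5 so each
# regime keeps at least five cases, and features are standardized first so
# KMeans distances are not dominated by large-magnitude features.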

def analyze_parameter_relationships(
    database: xr.Dataset,
    calibrator,
    features: list[str],
    output_dir: Path,
) -> list[ParameterRelationshipResult]:
    """
    Analyze how optimal parameters depend on atmospheric conditions.

    Only applicable for local calibration where optimal params vary per case.
    """
    print("--- Analyzing Parameter-Condition Relationships ---")
    results = []

    # Check if we have local calibration results
    if not hasattr(calibrator, "optimal_params_") or calibrator.optimal_params_ is None:
        print(" No local calibration results available")
        return results

    swept_params = calibrator.swept_params
    optimal_params = calibrator.optimal_params_

    # Get feature values (from sample=0, features don't depend on sample)
    X_df = database.isel(sample=0).to_dataframe().reset_index()
    X = X_df[features]

    n_params = len(swept_params)
    n_features = len(features)

    fig, axes = plt.subplots(
        n_params, n_features, figsize=(4 * n_features, 4 * n_params)
    )
    # Normalize to a 2D array of Axes: plt.subplots returns a bare Axes when
    # n_params == n_features == 1 and a 1D array when only one of them is 1
    axes = np.asarray(axes).reshape(n_params, n_features)

    for p_idx, param in enumerate(swept_params):
        param_values = optimal_params[param]

        best_feature = None
        best_corr = 0

        for f_idx, feature in enumerate(features):
            ax = axes[p_idx, f_idx]
            feature_values = X[feature].values

            # Compute correlation
            corr = np.corrcoef(feature_values, param_values)[0, 1]
            if np.isnan(corr):
                corr = 0

            if abs(corr) > abs(best_corr):
                best_corr = corr
                best_feature = feature

            # Scatter plot with regression line
            ax.scatter(feature_values, param_values, alpha=0.4, s=10)

            # Add regression line
            if abs(corr) > 0.1:
                z = np.polyfit(feature_values, param_values, 1)
                p = np.poly1d(z)
                x_line = np.linspace(feature_values.min(), feature_values.max(), 100)
                ax.plot(x_line, p(x_line), "r-", linewidth=2, label=f"r={corr:.2f}")
                ax.legend(fontsize=8)

            ax.set_xlabel(feature)
            if f_idx == 0:
                ax.set_ylabel(f"Optimal {param}")
            ax.set_title(f"r = {corr:.3f}")
            ax.grid(True, alpha=0.3)

        # Store result for best correlated feature
        if best_feature:
            if abs(best_corr) > 0.5:
                rel_type = "strong positive" if best_corr > 0 else "strong negative"
            elif abs(best_corr) > 0.3:
                rel_type = "moderate positive" if best_corr > 0 else "moderate negative"
            else:
                rel_type = "weak"

            interpretation = interpret_parameter_relationship(
                param, best_feature, best_corr
            )

            results.append(
                ParameterRelationshipResult(
                    parameter=param,
                    most_influential_feature=best_feature,
                    correlation=float(best_corr),
                    relationship_type=rel_type,
                    physical_interpretation=interpretation,
                )
            )

    plt.suptitle(
        "Optimal Parameter vs. Atmospheric Conditions\n"
        "(Shows how model params should vary with conditions)",
        fontsize=12,
    )
    plt.tight_layout()
    plt.savefig(
        output_dir / "parameter_relationships.png", dpi=150, bbox_inches="tight"
    )
    plt.close(fig)
    print(
        f" Saved parameter relationships to {output_dir / 'parameter_relationships.png'}"
    )

    return results

def generate_summary(report: PhysicsInsightsReport) -> str:
    """Generate executive summary of physics insights."""
    lines = []

    # Summarize PD findings
    if report.partial_dependence:
        increasing = [
            r for r in report.partial_dependence if r.bias_direction == "increases"
        ]
        decreasing = [
            r for r in report.partial_dependence if r.bias_direction == "decreases"
        ]

        if increasing:
            feats = ", ".join(r.feature for r in increasing[:2])
            lines.append(f"• Model bias increases with {feats}")
        if decreasing:
            feats = ", ".join(r.feature for r in decreasing[:2])
            lines.append(f"• Model bias decreases with {feats}")

    # Summarize interactions
    if report.interactions:
        top = report.interactions[0]
        lines.append(f"• Strongest interaction: {top.feature_1} × {top.feature_2}")

    # Summarize regimes
    if report.regimes:
        n_regimes = len(report.regimes)
        lines.append(f"• Identified {n_regimes} distinct error regimes")
        for r in report.regimes:
            lines.append(f"  - Regime {r.regime_id + 1}: {r.description}")

    # Summarize parameter insights
    if report.parameter_relationships:
        for r in report.parameter_relationships:
            if "strong" in r.relationship_type:
                lines.append(
                    f"• {r.parameter} should be condition-dependent "
                    f"(varies with {r.most_influential_feature})"
                )

    return "\n".join(lines)

# =============================================================================
# Main Entry Point
# =============================================================================

def run_physics_insights(
    database: xr.Dataset,
    fitted_model,
    calibrator,
    features_list: list[str],
    y_bias: np.ndarray,
    output_dir: Path,
    config: dict | None = None,
) -> PhysicsInsightsReport:
    """
    Run complete physics insights analysis.

    Args:
        database: xarray Dataset with model error database
        fitted_model: Trained ML model (or pipeline with 'model' step)
        calibrator: Fitted calibrator (for parameter relationships)
        features_list: List of feature names used
        y_bias: Bias values (predictions or actuals)
        output_dir: Directory for output plots
        config: Optional configuration dict

    Returns:
        PhysicsInsightsReport with all analysis results
    """
    config = config or {}
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("=" * 60)
    print("PHYSICS INSIGHTS ANALYSIS")
    print("=" * 60)

    # Prepare feature matrix
    X_df = database.isel(sample=0).to_dataframe().reset_index()
    X = X_df[features_list]

    # Extract model from pipeline if needed
    if hasattr(fitted_model, "named_steps") and "model" in fitted_model.named_steps:
        model = fitted_model.named_steps["model"]
    else:
        model = fitted_model

    report = PhysicsInsightsReport()

    # 1. Partial Dependence
    if config.get("partial_dependence", {}).get("enabled", True):
        pd_features = config.get("partial_dependence", {}).get(
            "features", features_list
        )
        report.partial_dependence = analyze_partial_dependence(
            model=fitted_model,  # Use full pipeline for PD
            X=X,
            features=pd_features,
            output_dir=output_dir,
        )

    # 2. Interactions (requires tree model)
    if config.get("interactions", {}).get("enabled", True):
        if hasattr(model, "feature_importances_"):  # Tree-based model
            report.interactions = analyze_interactions(
                model=model,
                X=X,
                features=features_list,
                output_dir=output_dir,
                top_n=config.get("interactions", {}).get("top_n", 5),
            )
        else:
            print(" Skipping interactions (requires tree-based model)")

    # 3. Regime Analysis
    if config.get("regime_analysis", {}).get("enabled", True):
        report.regimes = analyze_regimes(
            X=X,
            y_bias=y_bias,
            features=features_list,
            output_dir=output_dir,
            n_clusters=config.get("regime_analysis", {}).get("n_clusters", 3),
            bias_percentile=config.get("regime_analysis", {}).get(
                "bias_percentile", 75
            ),
        )

    # 4. Parameter Relationships (local calibration only)
    if config.get("parameter_relationships", {}).get("enabled", True):
        if hasattr(calibrator, "optimal_params_"):
            report.parameter_relationships = analyze_parameter_relationships(
                database=database,
                calibrator=calibrator,
                features=features_list,
                output_dir=output_dir,
            )

    # Generate summary
    report.summary = generate_summary(report)

    # Save report
    report_md = report.to_markdown()
    with open(output_dir / "physics_insights_report.md", "w") as f:
        f.write(report_md)
    print(f"\nSaved report to {output_dir / 'physics_insights_report.md'}")

    # Save as JSON for programmatic access
    import json

    with open(output_dir / "physics_insights.json", "w") as f:
        json.dump(report.to_dict(), f, indent=2)
    print(f"Saved JSON to {output_dir / 'physics_insights.json'}")

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(report.summary)

    return report
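

# A runnable sketch of the text-generation helpers in isolation (synthetic,
# hypothetical inputs; the full pipeline needs a fitted model, a calibrator,
# and an xarray error database, none of which is exercised here):
if __name__ == "__main__":
    print(interpret_pd_direction("turbulence_intensity", "increases"))
    print(
        describe_regime(
            centroids={"ABL_height": 1500.0, "wind_veer": 0.01},
            feature_stats={"ABL_height": (800.0, 300.0), "wind_veer": (0.01, 0.005)},
            mean_bias=-0.02,
        )
    )
    demo_report = PhysicsInsightsReport(
        partial_dependence=[
            PartialDependenceResult(
                feature="wind_veer",
                grid_values=np.linspace(0.0, 0.02, 5),
                pd_values=np.linspace(-0.01, 0.03, 5),
                bias_direction="increases",
                effect_magnitude=0.04,
                physical_interpretation=interpret_pd_direction(
                    "wind_veer", "increases"
                ),
            )
        ]
    )
    demo_report.summary = generate_summary(demo_report)
    print(demo_report.to_markdown())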