In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Bipolar to Monopolar LFP Power Estimation Companion
---------------------------------------------------
This notebook provides an interactive interface to help
you estimate Monopolar LFP Powers from Bipolar LFP Power
using the weights and model described in Fleeting et al.,
2025.

Input type: CSV
    >> Must contain columns headed ['C0-C3', 'C1-C2', 'C2-C3'], holding PSD power in dB.
    >> If it also contains ground-truth monopolar columns ['C0', 'C1', 'C2', 'C3'], the estimates are validated against them and plotted.

Output type: CSV
    >> Saved next to the input file, with "_estimated" appended to the file name.
    >> Conservative format: all other input columns are kept unchanged.
        >> Removed ['C0-C3', 'C1-C2', 'C2-C3'] (and any ground-truth ['C0', 'C1', 'C2', 'C3'])
        >> Added estimated ['C0', 'C1', 'C2', 'C3']
        >> Added standard errors ['C0_SE', 'C1_SE', 'C2_SE', 'C3_SE']

Author: Chance Fleeting
"""

import pandas as pd
import numpy as np
import statsmodels.api as sm
from pathlib import Path
from scipy.stats import norm
import matplotlib.pyplot as plt

## Coefficients from Fleeting et al. 2025
# Model terms, in the order shared by every coefficient vector and covariance
# matrix below: the intercept followed by the three bipolar recordings.
_TERMS = ['const', 'C0-C3', 'C1-C2', 'C2-C3']  # <-------- Please Label 'C0-C3', 'C1-C2', and 'C2-C3' in your CSV

# Regression weights: one column per monopolar target, one row per term.
coef = pd.DataFrame(
    {
        'C0': [3.931926, 0.737293,  0.076522, 0.101754],
        'C1': [4.738096, 0.410390,  0.393244, 0.125366],
        'C2': [4.953907, 0.284265,  0.157266, 0.488739],
        'C3': [3.844596, 0.564074, -0.043996, 0.409325],
    },
    index=_TERMS,
)

# Coefficient covariance (fit uncertainty) per target, listed column by column
# in _TERMS order; each 4x4 matrix is labelled by _TERMS on both axes.
_COV_COLUMNS = {
    'C0': [
        [ 4.15542477e-04, -5.07329624e-04, -2.91715631e-05,  5.56618849e-04],
        [-5.07329624e-04,  1.99340800e-03, -1.45554879e-03, -3.50993309e-04],
        [-2.91715631e-05, -1.45554879e-03,  2.17865725e-03, -9.18408054e-04],
        [ 5.56618849e-04, -3.50993309e-04, -9.18408054e-04,  1.50124537e-03],
    ],
    'C1': [
        [ 0.00041828, -0.00067915,  0.00018877,  0.0005084],
        [-0.00067915,  0.00309587, -0.00265617, -0.00033659],
        [ 0.00018877, -0.00265617,  0.00380632, -0.0012142],
        [ 0.0005084,  -0.00033659, -0.0012142,   0.00172875],
    ],
    'C2': [
        [ 0.00035501, -0.0003559,  -0.0001556,   0.00050323],
        [-0.0003559,   0.00200934, -0.00118672, -0.00053107],
        [-0.0001556,  -0.00118672,  0.00212619, -0.00114476],
        [ 0.00050323, -0.00053107, -0.00114476,  0.00186151],
    ],
    'C3': [
        [ 0.00049102, -0.00087397,  0.00023486,  0.00062031],
        [-0.00087397,  0.00352449, -0.00227479, -0.00108152],
        [ 0.00023486, -0.00227479,  0.00319753, -0.00097603],
        [ 0.00062031, -0.00108152, -0.00097603,  0.00219852],
    ],
}

# One-row frame whose cell coef_cov[target][0] holds that target's labelled
# covariance DataFrame (labels propagate through later matrix products).
coef_cov = pd.DataFrame({
    target: [pd.DataFrame(dict(zip(_TERMS, columns)), index=_TERMS)]
    for target, columns in _COV_COLUMNS.items()
})

# Residual (system) uncertainty: the listed per-target values are squared here.
mse = pd.DataFrame({
    'C0': [3.2663031786532475],
    'C1': [3.280076521374565],
    'C2': [3.58149570047011],
    'C3': [3.7034584336408143],
}) ** 2  # system uncertainty

dv = coef.columns.tolist()    # ['C0','C1','C2','C3']  (targets)
iv = coef.index.tolist()[1:]  # ['C0-C3', 'C1-C2', 'C2-C3']  (predictors, 'const' dropped)
In [2]:
def plot(df_true, df_estimated, df_SE, dv, iv=None, title=None, figsize=(12, 3),
         lims=(-20, 45)):
    """
    Plot actual vs. predicted values for each target, with a 95% prediction
    interval (PI) band drawn around the y = x line and the adjusted R^2 and
    RMSE printed in each panel.

    THIS WILL ONLY TRIGGER IF 'C0', 'C1', 'C2', and 'C3' ARE IN YOUR CSV.
    THIS IS ONLY USED FOR VALIDATION AGAINST GROUND TRUTH.

    Parameters
    ----------
    df_true : DataFrame
        Ground-truth values, one column per name in `dv`.
    df_estimated : DataFrame
        Predictions, one column per name in `dv`, row-aligned with `df_true`.
    df_SE : DataFrame
        One row; the cell in column `name` holds the per-sample standard
        errors for that target. If it cannot be read or its length does not
        match, the band is skipped and a warning is printed.
    dv : list of str
        Target column names; one subplot per name, 4 per row.
    iv : list of str or None
        Predictor names. Its length is the predictor count used by the
        adjusted R^2 (1 if None or not a list); also used in the default title.
    title : str or None
        Figure title and y-label prefix; defaults to "Regression Results over {iv}".
    figsize : tuple
        Matplotlib figure size.
    lims : (float, float)
        Shared x/y axis limits (also the extent of the y = x line). The default
        (-20, 45) reproduces the original fixed range; widen it if your powers
        fall outside, otherwise those points are clipped from view.

    Returns
    -------
    The `matplotlib.pyplot` module (kept for backward compatibility).
    """
    n_targets = len(dv)
    ncols = 4
    nrows = int(np.ceil(n_targets / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    axes_flat = axes.ravel()
    z = norm.ppf(0.975)  # Critical value -- NORMALITY ASSUMPTION for the 95% PI

    plot_min, plot_max = lims
    label = title or f"Regression Results over {iv}"
    # Predictor count for the adjusted-R^2 degrees-of-freedom correction.
    pn = len(iv) if isinstance(iv, list) else 1

    for i, (ax, target_name) in enumerate(zip(axes_flat, dv)):
        actual_vals = df_true[target_name].to_numpy(dtype=float)
        predicted_vals = df_estimated[target_name].to_numpy(dtype=float)

        # 1. Margin of error per sample (z * SE); best-effort, the plot still
        #    renders as a plain scatter if the SE data is unusable.
        try:
            y_err = z * np.asarray(df_SE[target_name].values[0], dtype=float)
            if y_err.ndim != 1 or len(y_err) != len(actual_vals):
                print(f"Warning: Error bar length mismatch for {target_name}")
                y_err = None
        except Exception as e:
            print(f"Could not extract CI for {target_name}: {e}")
            y_err = None

        if y_err is not None:
            # 2. Sort by actual value so fill_between draws one contiguous band.
            sort_idx = np.argsort(actual_vals)
            sorted_actual = actual_vals[sort_idx]
            sorted_y_err = y_err[sort_idx]
            # 3. PI band centred on the y = x line.
            ax.fill_between(
                sorted_actual,
                sorted_actual - sorted_y_err,
                sorted_actual + sorted_y_err,
                color='gray',
                alpha=0.3,
                label='95% PI'
            )

        # 4. Predictions vs. ground truth.
        ax.scatter(actual_vals, predicted_vals, alpha=0.7, s=20,
                   edgecolors='none', label='Predictions')

        # Plot the y=x line
        ax.plot([plot_min, plot_max], [plot_min, plot_max], 'r--',
                lw=2, label='Perfect Prediction')

        # Metrics use only rows where both values are finite, so a few missing
        # ground-truth entries do not turn every statistic into NaN.
        valid = np.isfinite(actual_vals) & np.isfinite(predicted_vals)
        a, p = actual_vals[valid], predicted_vals[valid]
        n = a.size
        ss_res = np.sum((a - p) ** 2)
        ss_tot = np.sum((a - a.mean()) ** 2) if n else 0.0
        dof = n - pn - 1
        # Adjusted R^2 = 1 - (SS_res / SS_tot) * (n - 1) / (n - p - 1).
        # Undefined (NaN) for constant ground truth or too few samples, instead
        # of the inf / divide-by-zero warnings the unguarded formula produces.
        if ss_tot > 0 and dof > 0:
            adj_r2 = 1 - (ss_res / ss_tot) * (n - 1) / dof
        else:
            adj_r2 = np.nan
        rmse = np.sqrt(ss_res / n) if n else np.nan

        # Draw white background text
        text = f"Adj $R^2$ = {adj_r2:.4f}\nRMSE = {rmse:.4f}"
        ax.text(0.05, 0.95, text,
                transform=ax.transAxes, ha='left', va='top', fontsize=9.5,
                bbox=dict(facecolor='white', edgecolor='none', alpha=0.5, pad=0.5))

        ax.set_title(f"{target_name}", fontweight='bold')
        if i == 0:  # only the left-most panel carries the y label
            ax.set_ylabel(label + "\nPredicted Power (dB)")
        ax.set_xlabel("Actual Power (dB)")
        ax.set_xlim(plot_min, plot_max)
        ax.set_ylim(plot_min, plot_max)

        # Minimalist axes
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_visible(True)
        ax.spines['bottom'].set_visible(True)
        ax.spines['left'].set_color('k')
        ax.spines['bottom'].set_color('k')
        ax.spines['left'].set_linewidth(1.5)
        ax.spines['bottom'].set_linewidth(1.5)
        ax.tick_params(left=True, bottom=True, right=False, top=False, color='k')
        ax.grid(False)

    # Hide unused subplots
    for ax in axes_flat[n_targets:]:
        ax.axis("off")

    # Use the figure object rather than pyplot's "current figure" global state.
    fig.suptitle(label)
    fig.tight_layout()
    plt.show()

    return plt
In [3]:
## Body
# INPUT Phase

# NOTE: give a plain or forward-slash path (not r".\file.csv"); a backslash is
# not a path separator on macOS/Linux, so ".\name.csv" would not be found there.
#####_____________________#####
input_file = "Example_Data.csv"   # <-- replace with your filename (CSV)
#####_____________________#####

df = pd.read_csv(input_file)
# Explicit check rather than `assert`, which is stripped under `python -O`.
missing_cols = [c for c in iv if c not in df.columns]
if missing_cols:
    raise ValueError(f"Missing columns: {set(missing_cols)}")

nm = Path(input_file)
# Keep every column that is neither a model input nor a model output; bipolar
# inputs and any ground-truth monopolar columns are replaced by the estimates.
df_remainder = df[[c for c in df.columns if c not in iv + dv]]

# COMPUTATION Phase
# has_constant='add' forces the intercept column. The default ('skip') silently
# omits it whenever a predictor column is a nonzero constant (e.g. a one-row
# CSV), after which `x @ coef` fails with "matrices are not aligned".
x = sm.add_constant(df[iv], has_constant='add')
df_estimated = x @ coef  # label-aligned: x columns <-> coef index

# Per-row prediction SE: sqrt(x_i Sigma x_i^T + sigma^2), with Sigma the
# coefficient covariance (fit uncertainty) and sigma^2 the system variance.
# Computed row-wise as ((x @ Sigma) * x).sum(axis=1) == diag(x @ Sigma @ x.T),
# which avoids materialising an n-by-n matrix (O(n) instead of O(n^2) memory).
# Ignores the lagged component.
df_SE = pd.DataFrame({
    v: [np.sqrt(((x @ coef_cov[v][0]) * x).sum(axis=1).to_numpy() + mse[v][0])]
    for v in dv
})

# If Monopolar are present, COMPARE true and estimate (and plot)
if all(v in df.columns for v in dv):
    df_true = df[dv]
    plot(df_true, df_estimated, df_SE, dv, iv, title=f"Verification (N = {len(df_true)})")

# SAVE Phase
# Append one "<target>_SE" column per target; df_SE's single row holds the arrays.
for target, se_column in df_SE.items():
    df_estimated[target + '_SE'] = se_column.iloc[0]
out_file = nm.with_stem(nm.stem + "_estimated")  # same folder, "_estimated" suffix
pd.concat([df_remainder, df_estimated], axis=1).to_csv(out_file, index=False)
print("Saved estimated values to:", out_file)
Saved estimated values to: Example_Data_estimated.csv